Training Kernels
Understanding backward passes, mixed-precision training, and memory optimization techniques that make training large models possible on limited GPU memory.
- Implement backward passes for common operations
- Apply mixed-precision training with loss scaling
- Use activation checkpointing to trade compute for memory
- Implement gradient accumulation for larger effective batch sizes
- Debug training-specific numerical issues
The Computational Graph
Training neural networks requires computing gradients via automatic differentiation. For every forward operation, there's a corresponding backward operation that propagates gradients.
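As a concrete example, ReLU's backward simply masks the upstream gradient with the sign of the forward input. A minimal NumPy sketch (function names are illustrative):

```python
import numpy as np

def relu_forward(x):
    # Save the input (the mask) for the backward pass
    out = np.maximum(x, 0.0)
    cache = x
    return out, cache

def relu_backward(d_out, cache):
    # Gradient flows only where the forward input was positive
    x = cache
    return d_out * (x > 0)

x = np.array([-1.0, 2.0, -3.0, 4.0])
out, cache = relu_forward(x)
dx = relu_backward(np.ones_like(x), cache)
# dx is [0., 1., 0., 1.]
```

Note that the forward pass must save whatever the backward pass needs — this is exactly why training stores activations that inference can discard.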
Memory Requirements
Training consumes far more memory than inference because you must store:

- Model weights
- Gradients (same size as the weights)
- Optimizer states (Adam keeps momentum and variance, 2x the weights)
- Activations saved from the forward pass for use in backward

For a 7B parameter model in FP32, ZeRO-style accounting breaks static memory down to 4 bytes/param for weights, 4 bytes/param for gradients, and 8 bytes/param for Adam's momentum and variance: 16 bytes/param, or ~112 GB before counting activations.
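This accounting is easy to turn into a quick estimator; a sketch assuming FP32 weights and gradients plus Adam's two FP32 states, i.e. 16 bytes per parameter (function name is illustrative):

```python
def training_memory_gb(n_params, bytes_weights=4, bytes_grads=4, bytes_opt=8):
    """Static training memory (weights + grads + optimizer states), in GB.

    Excludes activations, which depend on batch size and sequence length.
    """
    return n_params * (bytes_weights + bytes_grads + bytes_opt) / 1e9

print(training_memory_gb(7e9))  # 7B params, FP32 + Adam -> 112.0 GB
```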
Activations scale with batch size and sequence length. For transformers, attention activations grow as O(seq_len²) per layer unless you use FlashAttention, which avoids materializing the full attention matrix.
Forward vs Backward Compute
A common rule of thumb: the backward pass costs ~2x the forward pass, because it computes gradients with respect to both the inputs and the weights. For a linear layer Y = XW:
```python
# Forward: 1 matmul
Y = X @ W        # (B, in) @ (in, out) = (B, out)

# Backward: 2 matmuls
dX = dY @ W.T    # gradient w.r.t. input
dW = X.T @ dY    # gradient w.r.t. weights
```
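These two formulas can be sanity-checked against a finite difference; a quick NumPy check with illustrative shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))   # (B, in)
W = rng.standard_normal((3, 2))   # (in, out)
dY = rng.standard_normal((4, 2))  # upstream gradient, same shape as Y

# Analytic gradients from the formulas above
dX = dY @ W.T
dW = X.T @ dY

# Finite-difference check on one entry of W, for the scalar loss <Y, dY>
eps = 1e-6
def loss(W_):
    return np.sum((X @ W_) * dY)

W_pert = W.copy()
W_pert[0, 0] += eps
numeric = (loss(W_pert) - loss(W)) / eps
assert abs(numeric - dW[0, 0]) < 1e-4
```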
Writing Backward Passes
Understanding backward kernels helps you write custom operations and debug gradient issues. Each operation must define how gradients flow backwards.
Linear Layer Backward
The linear layer backward pass requires saving the input tensor from the forward pass:
```python
@triton.jit
def linear_backward_dx(
    dY_ptr, W_ptr, dX_ptr,
    M, N, K,  # M=batch, N=in_features, K=out_features
    stride_dy_m, stride_dy_k,
    stride_w_n, stride_w_k,
    stride_dx_m, stride_dx_n,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """dX = dY @ W.T"""
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Accumulate dX[m, n] = sum_k(dY[m, k] * W[n, k])
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # Masks omitted for brevity: assumes M, N, K are multiples of the block sizes
        dy = tl.load(dY_ptr + offs_m[:, None] * stride_dy_m
                            + (k + offs_k[None, :]) * stride_dy_k)
        w = tl.load(W_ptr + offs_n[:, None] * stride_w_n
                          + (k + offs_k[None, :]) * stride_w_k)
        acc += tl.dot(dy, tl.trans(w))

    tl.store(dX_ptr + offs_m[:, None] * stride_dx_m
                    + offs_n[None, :] * stride_dx_n, acc)
```
LayerNorm Backward
LayerNorm backward is more complex because normalization creates dependencies between elements. The gradients must account for the mean and variance computations:
```python
# Forward: y = (x - mean) / std * gamma + beta
# where std = sqrt(var + eps)

# Backward requires 3 passes:
# 1. Compute gradient contributions
dx_hat = dY * gamma  # (B, D)
dvar = sum(dx_hat * (x - mean) * -0.5 * std**-3)
dmean = sum(dx_hat * -1 / std) + dvar * sum(-2 * (x - mean)) / D

# 2. Compute dx
dx = dx_hat / std + dvar * 2 * (x - mean) / D + dmean / D

# 3. Compute dgamma, dbeta
dgamma = sum(dY * (x - mean) / std, dim=0)  # reduce over batch
dbeta = sum(dY, dim=0)
```
Division by std can cause issues when variance is very small. Always use the same epsilon in backward as forward, and consider using numerically stable formulations.
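The three-pass derivation collapses into a commonly used fused form; a NumPy sketch with a finite-difference check (shapes and names are illustrative, and the fused expression is algebraically equivalent to the three passes):

```python
import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    std = np.sqrt(var + eps)
    x_hat = (x - mean) / std
    return x_hat * gamma + beta, (x_hat, std, gamma)

def layernorm_backward(dy, cache):
    x_hat, std, gamma = cache
    dx_hat = dy * gamma
    # Fused form of the three passes: subtract the mean gradient and the
    # projection onto x_hat, then rescale by 1/std
    dx = (dx_hat - dx_hat.mean(axis=-1, keepdims=True)
          - x_hat * (dx_hat * x_hat).mean(axis=-1, keepdims=True)) / std
    dgamma = (dy * x_hat).sum(axis=0)  # reduce over batch
    dbeta = dy.sum(axis=0)
    return dx, dgamma, dbeta

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
gamma, beta = rng.standard_normal(8), rng.standard_normal(8)
dy = rng.standard_normal((2, 8))

y, cache = layernorm_forward(x, gamma, beta)
dx, dgamma, dbeta = layernorm_backward(dy, cache)

# Finite-difference check on one input element, for the scalar loss <y, dy>
eps = 1e-5
xp = x.copy()
xp[0, 0] += eps
yp, _ = layernorm_forward(xp, gamma, beta)
numeric = ((yp - y) * dy).sum() / eps
assert abs(numeric - dx[0, 0]) < 1e-3
```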
Attention Backward Memory
Standard attention backward is memory-intensive because it requires the full attention matrix. For sequence length L, batch size B, H heads, and head dimension d (sizes below assume d=128 and FP32 storage):

| What to Store | Size (elements) | L=2048, B=8, H=32 |
|---|---|---|
| Q, K, V tensors | 3 × B × H × L × d | ~0.8 GB |
| Attention matrix | B × H × L × L | ~4 GB |
| Softmax output | B × H × L × L | ~4 GB |
This is why FlashAttention recomputes attention during backward instead of storing it—trading compute for memory.
FP16/BF16 Training with Loss Scaling
Mixed-precision training uses lower precision (FP16 or BF16) for forward/backward passes while maintaining FP32 master weights. This reduces memory and leverages Tensor Cores for speed.
The Precision Hierarchy
| Format | Bits | Max Value | Use Case |
|---|---|---|---|
| FP32 | 32 | ~3.4×10³⁸ | Master weights, optimizer states |
| FP16 | 16 | 65504 | Forward/backward (requires loss scaling) |
| BF16 | 16 | ~3.4×10³⁸ | Forward/backward (no loss scaling needed) |
| TF32 | 19 (internal) | ~3.4×10³⁸ | Tensor Core matmuls on Ampere+ |
Loss Scaling
FP16 can't represent gradients smaller than ~6×10⁻⁸. Small gradients underflow to zero, causing training to stall. Loss scaling multiplies the loss by a large factor, scaling up all gradients:
```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()

    # Forward in FP16
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Backward with scaled loss
    scaler.scale(loss).backward()  # loss × scale_factor

    # Unscale gradients, check for inf/nan, step
    scaler.step(optimizer)  # only steps if grads valid
    scaler.update()         # adjust scale factor
```
Dynamic Loss Scaling
The scale factor is adjusted dynamically:
- No overflow for N steps → increase scale (typically 2x)
- Overflow detected → decrease scale (typically 0.5x), skip update
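This policy is simple to state in code; a sketch of a GradScaler-like scaler (class name and defaults are illustrative, not PyTorch's exact internals):

```python
class DynamicLossScaler:
    """Sketch of dynamic loss scaling (hypothetical class, not PyTorch's)."""

    def __init__(self, scale=2.0**16, growth=2.0, backoff=0.5, interval=2000):
        self.scale = scale
        self.growth = growth      # multiply scale after `interval` clean steps
        self.backoff = backoff    # multiply scale on overflow
        self.interval = interval
        self.good_steps = 0

    def update(self, found_inf):
        if found_inf:
            # Overflow: shrink the scale; the caller skips this weight update
            self.scale *= self.backoff
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.interval:
                # No overflow for `interval` steps: grow the scale
                self.scale *= self.growth
                self.good_steps = 0
```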
BF16 has the same exponent range as FP32, so gradients rarely underflow. Most modern training uses BF16 without loss scaling when hardware supports it (Ampere+ GPUs, TPUs).
What Stays in FP32
Even in mixed-precision training, some operations must remain in FP32:
- Softmax — accumulation over many elements
- LayerNorm variance — requires high precision for stability
- Loss computation — small loss values can underflow
- Optimizer states — momentum/variance accumulate over many steps
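Softmax illustrates why: upcasting to FP32 and subtracting the row max keeps exp() finite even for large FP16 logits. A NumPy sketch (function name is illustrative):

```python
import numpy as np

def softmax_fp32_accum(logits_fp16):
    # Upcast to FP32, subtract the row max so exp() cannot overflow,
    # and accumulate the normalizing sum in FP32
    x = logits_fp16.astype(np.float32)
    x -= x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([[12.0, 100.0, 200.0]], dtype=np.float16)
p = softmax_fp32_accum(logits)
# A naive float16 exp(200) would overflow to inf; here each row sums to 1
```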
Trading Compute for Memory
Activation checkpointing (gradient checkpointing) saves memory by not storing all intermediate activations. Instead, it recomputes them during the backward pass.
How It Works
Without checkpointing, the activations of all n layers are saved during forward and held until backward consumes them. With checkpointing every 2 layers, only the activations at segment boundaries are stored; the missing intermediate activations are recomputed from the nearest checkpoint when backward reaches that segment.
Implementation
```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformer(nn.Module):
    def __init__(self, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock() for _ in range(n_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            # checkpoint() doesn't save activations for this segment;
            # instead, it recomputes them during backward
            x = checkpoint(layer, x, use_reentrant=False)
        return x
```
Memory vs Compute Tradeoff
| Strategy | Memory | Compute Overhead | When to Use |
|---|---|---|---|
| No checkpointing | O(n) | 0% | Memory not a constraint |
| Checkpoint every k layers | O(n/k) | ~(k-1)/k × forward | Moderate memory pressure |
| Checkpoint every layer | O(1) per layer | ~100% (2x forward) | Severe memory constraints |
| Selective checkpointing | Variable | Variable | Checkpoint attention, keep FFN |
Not all operations benefit equally from checkpointing. Attention has O(L²) activations but O(L²) recompute cost. FFN has O(L×d) activations but cheap recompute. Selective checkpointing only checkpoints the most memory-intensive operations.
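A rough per-layer comparison makes the asymmetry concrete; a sketch assuming FP16 activations and a LLaMA-7B-like FFN width (d_ff=11008 is an assumption, as are the function names):

```python
def attn_activation_bytes(B, H, L, bytes_per=2):
    # The O(L^2) attention matrix dominates attention's saved activations
    return B * H * L * L * bytes_per

def ffn_activation_bytes(B, L, d_ff, bytes_per=2):
    # FFN hidden activations are only O(L * d_ff)
    return B * L * d_ff * bytes_per

B, H, L, d_ff = 8, 32, 2048, 11008
print(attn_activation_bytes(B, H, L) / 1e9)   # ~2.1 GB per layer
print(ffn_activation_bytes(B, L, d_ff) / 1e9) # ~0.36 GB per layer
```

Under these assumptions the attention matrix costs several times more memory than the FFN hiddens per layer, which is why selective checkpointing targets attention first.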
Simulating Larger Batch Sizes
Gradient accumulation lets you train with effective batch sizes larger than what fits in memory by accumulating gradients over multiple forward-backward passes before updating weights.
The Pattern
With 4 accumulation steps and micro-batch size 8, effective batch size = 32.
Implementation
```python
accumulation_steps = 4

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    # Forward + backward (gradients accumulate)
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    loss = loss / accumulation_steps  # scale loss

    scaler.scale(loss).backward()

    # Only update every N steps
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```
Since gradients are summed (not averaged) across accumulation steps, you must divide the loss by the number of accumulation steps to get correct gradient magnitudes. Alternatively, ensure your loss is already a mean over the batch.
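The equivalence is easy to verify on a toy model: with equal-size micro-batches, summing per-micro-batch mean gradients divided by the number of steps reproduces the full-batch mean gradient. A NumPy check (the model and data are illustrative):

```python
import numpy as np

# Toy model: loss = mean((w*x - t)^2) over a batch, so dL/dw = mean(2*(w*x - t)*x)
w = 0.5
x = np.array([1.0, 2.0, 3.0, 4.0])
t = np.array([2.0, 1.0, 0.0, 3.0])

def grad(w, x, t):
    return np.mean(2 * (w * x - t) * x)

full = grad(w, x, t)  # gradient of the full batch of 4

# Two micro-batches of 2, each contribution divided by accumulation_steps=2
accum = grad(w, x[:2], t[:2]) / 2 + grad(w, x[2:], t[2:]) / 2
assert abs(accum - full) < 1e-12
```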
BatchNorm Interaction
BatchNorm computes statistics over micro-batches, not the effective batch. This can cause training instability with small micro-batches. Solutions:
- SyncBatchNorm — synchronize statistics across GPUs
- GroupNorm/LayerNorm — statistics independent of batch size
- Ghost BatchNorm — compute stats over virtual sub-batches
Stability and Debugging
Training instability manifests as loss spikes, NaN losses, or gradients that explode/vanish. Understanding the numerical causes helps you fix them.
Gradient Clipping
Gradient clipping prevents exploding gradients by capping their magnitude:
```python
# After backward, before optimizer step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Or clip by value
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
```
Common max_norm values: 1.0 for transformers, 0.1-0.5 for RNNs. Monitor gradient norms—if clipping activates frequently, investigate the cause.
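Clipping by global norm rescales every gradient by the same factor, preserving the overall gradient direction; a NumPy sketch of the computation (function name here is illustrative, not PyTorch's API):

```python
import numpy as np

def clip_grad_norm(grads, max_norm):
    """Scale all gradients uniformly when their global L2 norm exceeds max_norm."""
    total_norm = np.sqrt(sum(np.sum(g**2) for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]  # global norm = 5
clipped, norm = clip_grad_norm(grads, max_norm=1.0)
# norm is 5.0; after clipping, the global norm is 1.0
```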
Loss Spikes
Sudden loss increases often indicate numerical issues:
| Symptom | Likely Cause | Fix |
|---|---|---|
| Spike then recovery | Outlier batch, FP16 overflow | Skip batch, adjust loss scale |
| Spike then NaN | Gradient explosion | Gradient clipping, lower LR |
| Gradual divergence | LR too high, bad hyperparams | LR warmup, hyperparameter search |
| NaN from start | Weight init, input data issues | Check data pipeline, init scale |
Debugging NaN Gradients
```python
def check_nan_hook(module, grad_input, grad_output):
    for i, grad in enumerate(grad_output):
        if grad is not None and torch.isnan(grad).any():
            print(f"NaN in {module.__class__.__name__} grad_output[{i}]")
            print(f"  grad stats: min={grad.min()}, max={grad.max()}")

# Register on all modules (register_backward_hook is deprecated)
for module in model.modules():
    module.register_full_backward_hook(check_nan_hook)
```
PyTorch's anomaly detection tracks which operation produced NaN: torch.autograd.set_detect_anomaly(True). It slows training significantly—use only for debugging.
Weight Initialization
Proper initialization prevents vanishing/exploding gradients at training start. Common schemes:
- He/Kaiming — std = sqrt(2/fan_in), good for ReLU
- Xavier/Glorot — std = sqrt(2/(fan_in + fan_out)), good for tanh/sigmoid
- Small init for residuals — scale final layer of residual blocks by 1/sqrt(n_layers)
GPT-style models often scale residual path weights by 1/sqrt(2 * n_layers), as noted in the GPT-2 paper.
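The He scaling can be checked empirically: with std = sqrt(2/fan_in), the second moment of activations is approximately preserved through a linear + ReLU layer. A NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, batch = 512, 1024

# He/Kaiming: std = sqrt(2/fan_in); the factor 2 compensates for ReLU
# zeroing out roughly half the pre-activations
W = rng.standard_normal((fan_in, fan_in)) * np.sqrt(2.0 / fan_in)
x = rng.standard_normal((batch, fan_in))

h = np.maximum(x @ W, 0.0)  # one linear + ReLU layer
print((x**2).mean(), (h**2).mean())  # both approximately 1.0
```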
Key Takeaways
- Training needs ~4x model weights for Adam (weights + gradients + 2x optimizer states)
- Activations scale with batch size, sequence length, and model depth
- Backward pass costs ~2x forward due to computing gradients for both inputs and weights
- Mixed precision — FP16/BF16 halves memory for activations and gradients
- Activation checkpointing — trade compute for memory by recomputing activations
- Gradient accumulation — simulate larger batches without more memory
- Use loss scaling with FP16 (not needed for BF16)
- Gradient clipping prevents explosion—common max_norm = 1.0
- Proper initialization scales with network depth