Common Kernels
Beyond attention: the essential kernels that make up modern neural networks. LayerNorm, RMSNorm, fused activations, embeddings, and optimizers.
- Implement fused element-wise operations (activation + bias)
- Write efficient LayerNorm and RMSNorm kernels
- Optimize embedding lookups for arbitrary indices
- Understand fused optimizer patterns (Adam)
- Know when to write custom kernels vs using cuDNN
Why Fusion Matters
Most neural network operations are memory-bound—limited by how fast you can move data, not compute. Fusion reduces memory traffic by combining multiple operations into one kernel.
Unfused vs Fused: Memory Traffic
- Unfused (e.g. bias-add kernel, then activation kernel): 2 kernel launches, 6 memory operations
- Fused (one kernel): 1 kernel launch, 2 memory operations → 3x less memory traffic
Fused Bias + Activation in Triton
```python
@triton.jit
def fused_bias_gelu(
    x_ptr, bias_ptr, out_ptr,
    N,                        # total number of elements
    BIAS_SIZE: tl.constexpr,  # length of the bias vector (broadcast period)
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N

    # Load x and bias (bias is broadcast across rows)
    x = tl.load(x_ptr + offsets, mask=mask)
    bias = tl.load(bias_ptr + offsets % BIAS_SIZE, mask=mask)

    # Fused: add bias, then GELU
    x = x + bias
    # GELU tanh approximation:
    # x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    out = x * 0.5 * (1.0 + tl.libdevice.tanh(
        0.7978845608 * (x + 0.044715 * x * x * x)
    ))

    tl.store(out_ptr + offsets, out, mask=mask)
```
The GELU activation (Hendrycks & Gimpel, 2016) uses a tanh approximation for efficiency; PyTorch's nn.GELU(approximate='tanh') computes this form. The exact form uses erf(), which is slower.
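To see why the approximation is acceptable, the two forms can be compared on the CPU (a NumPy sketch; gelu_exact and gelu_tanh are illustrative names, not library functions):

```python
import math

import numpy as np

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF via erf
    return x * 0.5 * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # tanh approximation, same constants as the kernel above
    return x * 0.5 * (1.0 + np.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

x = np.linspace(-4.0, 4.0, 101)
max_err = float(np.max(np.abs(gelu_exact(x) - gelu_tanh(x))))
# Maximum absolute error stays below 1e-3 over this range
```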
LayerNorm and RMSNorm
Layer Normalization (Ba et al., 2016) is ubiquitous in Transformers. RMSNorm (Zhang & Sennrich, 2019) is a simplified variant used in LLaMA and other modern architectures.
LayerNorm vs RMSNorm
| Operation | LayerNorm | RMSNorm |
|---|---|---|
| Formula | γ * (x - μ) / √(σ² + ε) + β | γ * x / √(mean(x²) + ε) |
| Compute mean? | Yes (for centering) | No |
| Learnable bias? | Yes (β) | No |
| Passes over data | 2 (mean, then variance) | 1 (just RMS) |
| Used in | BERT, GPT-2, T5 | LLaMA, Mistral, Gemma |
LayerNorm Kernel Anatomy
1. Load row into registers/shared memory. Each thread block handles one or more rows; load the entire row for the reductions.
2. Compute mean via parallel reduction. Sum all elements and divide by the count; use warp shuffles for efficiency.
3. Compute variance via parallel reduction. Sum (x - mean)² and divide by the count; Welford's algorithm can fuse steps 2 and 3 into a single pass.
4. Normalize and apply the affine transform. Compute (x - mean) * rsqrt(var + eps) * gamma + beta and write the output.
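The four steps map directly onto a NumPy reference (a host-side sketch, useful for checking a kernel's output; layernorm_ref is an illustrative name):

```python
import numpy as np

def layernorm_ref(x, gamma, beta, eps=1e-5):
    # x: (rows, N); gamma, beta: (N,) learnable scale and shift
    mu = x.mean(axis=-1, keepdims=True)                  # step 2: mean
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)   # step 3: variance
    return (x - mu) / np.sqrt(var + eps) * gamma + beta  # step 4: normalize + affine

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
y = layernorm_ref(x, np.ones(8, np.float32), np.zeros(8, np.float32))
# With unit gamma and zero beta, each output row has ~zero mean and ~unit variance
```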
RMSNorm: Simpler and Faster
```python
@triton.jit
def rmsnorm_kernel(
    x_ptr, weight_ptr, out_ptr,
    stride, N,
    eps: tl.constexpr,
    BLOCK: tl.constexpr,
):
    row = tl.program_id(0)
    x_ptr += row * stride
    out_ptr += row * stride

    # Load row, upcasting to FP32 for the reduction
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # Compute RMS (no mean subtraction!)
    x_sq = x * x
    mean_sq = tl.sum(x_sq) / N
    rms = tl.rsqrt(mean_sq + eps)

    # Normalize and scale
    weight = tl.load(weight_ptr + cols, mask=mask)
    out = x * rms * weight
    tl.store(out_ptr + cols, out, mask=mask)
```
Always compute reductions in FP32, even for FP16 inputs. Accumulating many small values in FP16 causes precision loss. Cast back to the output dtype only at the final store.
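The precision problem is easy to demonstrate on the host (a NumPy sketch, not kernel code):

```python
import numpy as np

x = np.full(10_000, 0.01, dtype=np.float16)  # 10k small FP16 values

# Naive FP16 accumulation: once the running sum is large enough,
# adding 0.01 falls below half the FP16 spacing and rounds away entirely.
acc16 = np.float16(0.0)
for v in x:
    acc16 = np.float16(acc16 + v)

# Same FP16 inputs, accumulated in FP32 (what the kernel should do)
acc32 = np.sum(x, dtype=np.float32)
# acc32 is ~100.0; the FP16 accumulator stalls far below the true sum
```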
Embedding Lookups
Embedding tables convert token IDs to vectors. The challenge: indices are arbitrary, causing non-coalesced memory access.
The Coalescing Problem
- Sequential indices [0, 1, 2, 3, 4, 5, 6, 7]: consecutive rows, coalesced access, 1 memory transaction
- Scattered indices [42, 7, 1024, 3, 999, 15, 42, 100]: up to 8 separate memory transactions
Optimization Strategies
| Strategy | When to Use | Trade-off |
|---|---|---|
| Vectorized loads | Embedding dim divisible by 4 | Load float4 instead of float → 4x fewer transactions per row |
| L2 cache persistence | Repeated access to same embeddings | Use cudaAccessPolicyWindow to pin hot embeddings |
| Sorted indices | Batch allows reordering | Sort indices to improve locality, then unsort output |
| Embedding bag | Sum/mean pooling over variable-length sequences | Fuse gather + reduction in one kernel |
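The "sorted indices" strategy from the table can be sketched on the host in NumPy (an illustration of the reorder/unsort bookkeeping, not a kernel):

```python
import numpy as np

vocab, dim = 2048, 8
rng = np.random.default_rng(0)
weight = rng.standard_normal((vocab, dim)).astype(np.float32)
indices = np.array([42, 7, 1024, 3, 999, 15, 42, 100])

# Sort indices so neighboring lookups touch neighboring rows (better locality),
# gather in sorted order, then scatter the rows back to their original positions.
order = np.argsort(indices)
gathered = weight[indices[order]]   # locality-friendly gather
out = np.empty_like(gathered)
out[order] = gathered               # "unsort" the output

# Result is identical to the direct scattered gather weight[indices]
```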
```python
@triton.jit
def embedding_kernel(
    indices_ptr, weight_ptr, out_ptr,
    seq_len, embed_dim, weight_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles one token
    token_idx = tl.program_id(0)

    # Load the vocabulary index for this token
    vocab_idx = tl.load(indices_ptr + token_idx)

    # Pointer to this token's embedding row
    embed_ptr = weight_ptr + vocab_idx * weight_stride

    # Load the embedding row (masked to embed_dim)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < embed_dim
    embedding = tl.load(embed_ptr + offsets, mask=mask)

    # Store to output
    out_offset = token_idx * embed_dim
    tl.store(out_ptr + out_offset + offsets, embedding, mask=mask)
```
Fused Adam
The Adam optimizer (Kingma & Ba, 2014) updates parameters using first and second moment estimates. A naive implementation requires multiple kernel launches and memory passes.
Adam Memory Traffic
| Operation | Unfused (separate kernels) | Fused (one kernel) |
|---|---|---|
| Load gradient | 1 load | 1 kernel: load g, m, v, p; store m, v, p = 7 memory ops |
| Update m = β₁m + (1-β₁)g | Load m, store m | |
| Update v = β₂v + (1-β₂)g² | Load g, load v, store v | |
| Compute m̂, v̂ (bias correction) | Load m, load v | |
| Update p = p - lr * m̂ / (√v̂ + ε) | Load p, store p | |
| Total memory ops | 11 ops (5 kernels) | 7 ops (1 kernel) |
```python
@triton.jit
def fused_adam(
    param_ptr, grad_ptr, m_ptr, v_ptr,
    lr, beta1, beta2, eps, step, N,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N

    # Load all tensors once
    p = tl.load(param_ptr + offsets, mask=mask)
    g = tl.load(grad_ptr + offsets, mask=mask)
    m = tl.load(m_ptr + offsets, mask=mask)
    v = tl.load(v_ptr + offsets, mask=mask)

    # Update moments
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g

    # Bias correction
    m_hat = m / (1 - tl.math.pow(beta1, step))
    v_hat = v / (1 - tl.math.pow(beta2, step))

    # Update parameters
    p = p - lr * m_hat / (tl.sqrt(v_hat) + eps)

    # Store updated values
    tl.store(param_ptr + offsets, p, mask=mask)
    tl.store(m_ptr + offsets, m, mask=mask)
    tl.store(v_ptr + offsets, v, mask=mask)
```
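A host-side NumPy reference of the same update is handy for sanity-checking the kernel (a sketch; adam_step_ref is an illustrative name, hyperparameters mirror the kernel arguments):

```python
import numpy as np

def adam_step_ref(p, g, m, v, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    # First and second moment updates
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # Bias correction
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    # Parameter update
    p = p - lr * m_hat / (np.sqrt(v_hat) + eps)
    return p, m, v

p = np.zeros(4, dtype=np.float32)
g = np.array([1.0, -1.0, 0.5, 0.0], dtype=np.float32)
m = np.zeros_like(p)
v = np.zeros_like(p)
p, m, v = adam_step_ref(p, g, m, v)
# On step 1, m_hat == g and v_hat == g², so the update is ~ -lr * sign(g)
```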
Real-world optimizers use multi-tensor apply: process all model parameters in one kernel launch. This amortizes kernel launch overhead (which can be ~5-10μs per launch) across all parameters.
When to Write Custom Kernels
cuDNN and cuBLAS are highly optimized. Don't reinvent the wheel—but know when custom kernels win.
Use Libraries When...
| Operation | Library | Why |
|---|---|---|
| Matrix multiply (GEMM) | cuBLAS | Tensor Core optimized, auto-tuned per GPU |
| Convolution | cuDNN | Multiple algorithms (Winograd, FFT, implicit GEMM), auto-tuned |
| Batch normalization | cuDNN | Fused forward+backward, running stats handled |
| Attention (standard) | FlashAttention | IO-aware, extensively optimized |
Write Custom When...
| Scenario | Example | Why Custom Wins |
|---|---|---|
| Fusion opportunities | Bias + GELU + Dropout | Libraries can't fuse across op boundaries |
| Non-standard shapes | Very small matrices, odd dimensions | Libraries optimized for common sizes |
| Custom attention patterns | Sliding window, sparse patterns | Off-the-shelf attention kernels often don't support arbitrary masking patterns |
| Research ops | Novel activations, custom losses | No library implementation exists |
Always benchmark your custom kernel against the library version. A "clever" custom kernel that's 20% slower than cuDNN is a waste of engineering time. Profile with realistic batch sizes and shapes.
Citations & Further Reading
Papers
- Layer Normalization. Ba, Kiros & Hinton, 2016. arXiv:1607.06450
- Root Mean Square Layer Normalization. Zhang & Sennrich, 2019. arXiv:1910.07467
- Gaussian Error Linear Units (GELU). Hendrycks & Gimpel, 2016. arXiv:1606.08415
- Adam: A Method for Stochastic Optimization. Kingma & Ba, 2014. arXiv:1412.6980
- LLaMA: Open and Efficient Foundation Language Models (uses RMSNorm). Touvron et al., 2023. arXiv:2302.13971
Documentation
- Triton Tutorials: fused softmax, matrix multiply, and more. triton-lang.org/tutorials
- CUDA C++ Best Practices Guide: memory coalescing, occupancy, optimization. docs.nvidia.com/cuda
- cuDNN Documentation: convolution, normalization, attention APIs. docs.nvidia.com/cudnn
- APEX Fused Optimizers: multi-tensor apply, fused Adam/LAMB. nvidia.github.io/apex