Chapter 5

Attention

The math behind attention kernels. From dot products to softmax to FlashAttention—understand what every operation actually does.

Building on Chapter 4
Chapter 4's tiling wasn't just about matrix multiply; it's a universal optimization pattern, and attention is where it becomes essential. Naive attention materializes an n × n score matrix, so memory grows quadratically with sequence length. With tiling (FlashAttention), the working set fits in shared memory.
What You'll Learn
  1. Implement standard scaled dot-product attention
  2. Explain why naive attention is O(n²) in memory
  3. Describe how FlashAttention achieves O(n) memory
  4. Implement online softmax for numerical stability
  5. Apply tiling strategies to attention computation
Prerequisites

This chapter uses linear algebra and exponentials.

01 - SIMILARITY

The Dot Product: Measuring Relevance

Attention computes how relevant each cached token is to your query. The dot product is the measuring stick—it tells you how much two vectors "agree."

High dot product = vectors point in similar directions = high relevance.
Zero dot product = vectors are perpendicular = no relationship.
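
A tiny NumPy sketch of that intuition (vectors chosen for illustration): an aligned vector scores high, a perpendicular one scores zero.

# Dot product as a relevance score
import numpy as np

query = np.array([1.0, 0.0])
aligned = np.array([2.0, 0.0])       # points the same way as the query
orthogonal = np.array([0.0, 3.0])    # perpendicular to the query

print(query @ aligned)       # 2.0 -> high relevance
print(query @ orthogonal)    # 0.0 -> no relationship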

Interactive: Compute Q * K

Worked example: Q * K = (1 x 0.5) + (0 x 0) + (0.5 x 1) + (-0.5 x 0) = 1.00

In attention, you compute Q*K for every cached token. With a 4096-token context, that's 4096 dot products just to process one query. This is why attention is the bottleneck.
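
As a sketch (the head dimension here is just illustrative), scoring the whole cache is one matrix-vector product:

# Scoring every cached token against a single query
import numpy as np

d = 64                           # head dimension (assumed for illustration)
q = np.random.randn(d)           # query vector for the current token
K = np.random.randn(4096, d)     # one cached key vector per token

scores = K @ q                   # 4096 dot products at once
print(scores.shape)              # (4096,)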

Q and K both have shape [seq_len, d]. What is the shape of QK^T?
[d, d]
[seq_len, seq_len]
[seq_len, d]

02 - NORMALIZATION

Softmax: From Scores to Probabilities

Raw dot products can be any value—positive, negative, huge, tiny. Softmax converts them to a probability distribution: all positive, sums to 1.

softmax(x_i) = exp(x_i) / Σ_j exp(x_j)

The exponential amplifies differences. A score of 10 vs. 5 gets e^10 / e^5 = e^5 ≈ 148x more weight, not 2x. This makes attention "sharp": it focuses on the most relevant tokens.
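
For example, with just two scores the weight ratio is exactly e^(10 - 5):

# A 2x score gap becomes a ~148x weight gap after softmax
import numpy as np

scores = np.array([10.0, 5.0])
weights = np.exp(scores) / np.exp(scores).sum()
print(weights)                     # [0.99331 0.00669]
print(weights[0] / weights[1])     # ~148.4, i.e. e**5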

Interactive: Softmax Visualization

Adjust five raw scores and watch the resulting softmax probability distribution update.
Try setting Score 1 to 100. What happens to the other probabilities?
They decrease proportionally
They collapse to nearly zero
They stay the same

The problem: exp(100) ≈ 2.7 × 10^43. That overflows FP16 (max ≈ 65504). Your kernel crashes. This is why we need the numerical stability trick.
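
A quick way to see the failure (NumPy warns about the overflow):

# exp(100) does not fit in FP16, so the softmax denominator becomes inf
import numpy as np

scores = np.array([100.0, 5.0], dtype=np.float16)
print(np.exp(scores))          # [inf, ~148.4]
print(np.exp(scores).sum())    # inf -- the resulting "probabilities" are nan and 0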

03 - THE TRICK

Online Softmax: Streaming Without Overflow

Two problems with naive softmax:

1. Overflow: Large values explode exp(). Solution: Subtract max first.

2. Memory: You need to see ALL values to compute max. But in attention, you're processing in blocks to stay in fast SRAM. Solution: Online algorithm that updates incrementally.

softmax(x_i) = exp(x_i - max(x)) / Σ_j exp(x_j - max(x))
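
A minimal sketch of the stabilized version: subtracting the max makes every exponent at most zero, so exp() cannot overflow.

# Numerically stable softmax: shift by the max before exponentiating
import numpy as np

def stable_softmax(x):
    shifted = x - x.max()        # the largest exponent is now exactly 0
    e = np.exp(shifted)
    return e / e.sum()

print(stable_softmax(np.array([100.0, 5.0, 3.0])))   # finite, sums to 1
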
Simulation: Online Softmax Streaming

Watch how the algorithm maintains running statistics as new values stream in. This is how FlashAttention processes attention in blocks.

The simulation tracks three quantities as blocks stream in: the running max (m, starting at -inf), the running sum of exponentials (l, starting at 0), and the count of values seen.

# The online softmax update rule
import numpy as np

def update(m_old, l_old, new_block):
    # new_block is a 1-D array of raw scores for this block
    m_block = new_block.max()
    m_new = max(m_old, m_block)

    # Rescale the old sum of exponentials to the new max
    l_new = l_old * np.exp(m_old - m_new)

    # Add the new block's contribution
    l_new += np.sum(np.exp(new_block - m_new))

    return m_new, l_new
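
A quick check (as a sketch): streaming the values through update() in blocks reproduces the same statistics as one pass over the full array.

# Blockwise streaming matches the one-shot max and denominator
import numpy as np

x = np.random.randn(1024) * 10            # raw scores, some of them large
m, l = float("-inf"), 0.0
for block in np.split(x, 8):              # process in 8 blocks, as a kernel would
    m, l = update(m, l, block)

assert np.isclose(m, x.max())
assert np.isclose(l, np.sum(np.exp(x - x.max())))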
Online softmax maintains:
Running max only
Running sum only
Both running max and sum of exp

04 - PRECISION

Floating Point: What FP8 and NVFP4 Actually Are

Your KV cache is quantized. Understanding the bit layout tells you what precision you're trading for memory bandwidth.

Floating Point Formats

FP16: 1 sign bit, 5 exponent bits, 10 mantissa bits.
FP8 (E4M3): 1 sign bit, 4 exponent bits, 3 mantissa bits.
NVFP4 (E2M1): 1 sign bit, 2 exponent bits, 1 mantissa bit.

Why NVFP4 works: Neural network values cluster tightly. A per-block scaling factor (stored in FP8) shifts the representable range to where your values actually are. You get 4x memory reduction vs FP16 with ~1% accuracy loss.
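
As an illustration only (a simplified sketch, not NVIDIA's exact NVFP4 recipe): quantize each block to the E2M1 magnitude grid, with one shared scale chosen so the block's largest magnitude maps to the grid's maximum.

# Per-block scaled 4-bit quantization, NVFP4-style (simplified)
import numpy as np

E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])   # representable E2M1 magnitudes

def quantize_fp4_per_block(values, block_size=16):
    values = np.asarray(values, dtype=np.float32)
    out = np.empty_like(values)
    for start in range(0, len(values), block_size):
        chunk = values[start:start + block_size]
        # One scale per block (NVFP4 stores it in FP8) so the grid covers the block's range
        scale = np.abs(chunk).max() / E2M1_GRID[-1] + 1e-12
        scaled = np.abs(chunk) / scale
        # Snap each magnitude to the nearest representable grid value
        idx = np.abs(scaled[:, None] - E2M1_GRID[None, :]).argmin(axis=1)
        out[start:start + block_size] = np.sign(chunk) * E2M1_GRID[idx] * scale
    return out

x = (np.random.randn(64) * 0.02).astype(np.float32)   # tightly clustered values
print("max abs error:", np.abs(x - quantize_fp4_per_block(x)).max())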

With only 1 mantissa bit (E2M1), how many distinct positive values can NVFP4 represent (excluding zero and special values)?
4
6
8
16
05 - SYNTHESIS

Putting It Together: Full Attention

Now you have all the pieces. Here's the complete attention equation and what each part does:

Visualization: Decode Attention Step-by-Step

Step through each stage of a single decode-attention computation.

What your kernel must do efficiently:

1. Load Q (single vector for decode)
2. Stream through KV cache in blocks (fits in SRAM)
3. Compute dot products, track online softmax statistics
4. Accumulate weighted V vectors
5. Output final attention result

The bottleneck is memory bandwidth—loading all those K and V vectors from HBM. FP8/NVFP4 quantization halves or quarters that traffic.
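
Here is a minimal NumPy sketch of that decode loop (illustrative shapes; a real kernel fuses these steps and keeps each block in on-chip memory):

# Single-query (decode) attention over a streamed KV cache with online softmax
import numpy as np

def decode_attention(q, K, V, block_size=128):
    scale = 1.0 / np.sqrt(q.shape[0])
    m = -np.inf                                # running max of scores
    l = 0.0                                    # running sum of exp(score - m)
    acc = np.zeros(V.shape[1])                 # running weighted sum of V rows

    for start in range(0, K.shape[0], block_size):
        k_blk = K[start:start + block_size]    # step 2: stream a KV block
        v_blk = V[start:start + block_size]

        scores = (k_blk @ q) * scale           # step 3: dot products for this block
        m_new = max(m, scores.max())
        correction = np.exp(m - m_new)         # rescale old statistics to the new max
        p = np.exp(scores - m_new)             # unnormalized block weights

        l = l * correction + p.sum()
        acc = acc * correction + p @ v_blk     # step 4: accumulate weighted V
        m = m_new

    return acc / l                             # step 5: normalize once at the end

# Cross-check against materializing the full attention row at once
q, K, V = np.random.randn(64), np.random.randn(4096, 64), np.random.randn(4096, 64)
s = (K @ q) / np.sqrt(64)
w = np.exp(s - s.max()); w /= w.sum()
assert np.allclose(decode_attention(q, K, V), w @ V)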

PRACTICE

Hands-On Labs

REFERENCES

Citations & Further Reading

Video Resources

Attention in Transformers (3Blue1Brown)

Outstanding visual explanation of attention, QKV, and the transformer architecture.

Watch on YouTube
Let's Build GPT: From Scratch (Andrej Karpathy)

Build a transformer from scratch with detailed attention implementation.

Watch on YouTube

Foundational Papers

  1. Attention Is All You Need - Vaswani et al., NeurIPS 2017
    arXiv:1706.03762
  2. FlashAttention - Dao et al., NeurIPS 2022
    arXiv:2205.14135
  3. FlashAttention-2 - Dao, 2023
    arXiv:2307.08691