Efficient Attention at Scale
1. Why efficiency matters in 2026
Most of the cost of a deployed large language model is not paid during training. A modern foundation model trains once on a large compute cluster and is then served to users for months or years; over the deployed lifetime of the model, the total inference compute exceeds the training compute by a wide margin, and the dominant operational expense is the cost of generating tokens for users [src_007]. The economics of attention therefore have to be analysed at decode time, not at training time, and the relevant bottleneck at decode time is not arithmetic but memory bandwidth.
⚠️ Pitfall
The "decode is memory-bandwidth-bound, not compute-bound" claim is the load-bearing premise of this chapter and is justified arithmetically only in §7. A reader who has not yet seen the \(\sim 312\) TFLOP/s vs \(\sim 1.5\)–\(2\) TB/s asymmetry should treat it as a forward reference to §7 rather than as an unjustified assertion.
The pressure has sharpened with the rise of long-context applications. The 2017 Transformer was evaluated at sequence lengths of a few hundred tokens, and the early generation of decoder-only language models was deployed at 2k or 4k context [src_002]. By 2024 the open-weights frontier was 32k, and by 2026 production deployments routinely serve 128k or longer windows for retrieval-augmented generation, document analysis, and agentic workflows [src_007]. At those lengths, two things happen at once. The quadratic attention computation \(\mathcal{O}(T^2 \, d_h)\) becomes a measurable fraction of per-step cost, and the linear-in-\(T\) KV-cache storage becomes a binding constraint on how many concurrent requests one accelerator can serve [src_002, src_007]. Both pressures fall on the attention layer; neither is felt by the FFN.
🔗 Connection
Where this chapter handles storing a \(128\)k KV-cache, Chapter 2 handles the question of how the model uses \(128\)k tokens of position information — RoPE plus the YaRN/NTK extrapolation tricks that let a model trained at shorter context generalise out.
This chapter takes up the three engineering ideas that the 2022–2024 papers used to absorb both pressures without changing what attention computes. The first idea, grouped-query attention (GQA), attacks the KV-cache by sharing key/value heads across groups of query heads [src_020]. The second, FlashAttention, attacks attention runtime by tiling the computation so that the \(T \times T\) attention matrix never has to be materialised in HBM [src_021]. The third, the FlashAttention-2 and FlashAttention-3 follow-ups, attacks the gap between achievable and peak FLOPs on modern accelerators by re-partitioning the work and exploiting hardware-specific async paths [src_022, src_023]. Together these three changes make 100k-context inference on a 70B-class open-weights model viable on a single eight-GPU server [src_007, src_047].
🔗 Connection
This chapter develops the KV-cache and FlashAttention components that Chapter 8 will assemble into the modern decoder-only inference stack — the same chapter that picks up speculative decoding in detail.
2. The KV-cache: anatomy and memory formula
Autoregressive generation produces tokens one at a time. At each step, the model takes the latest token, embeds it, and runs it through every layer of the stack to produce a logit distribution for the next position [src_002]. Because every previous token's hidden state at every layer has already been computed, it is wasteful to recompute keys and values for those positions: their \(K_{\ell, t}\) and \(V_{\ell, t}\) tensors do not change from step to step, so they can be cached and reused [src_002, src_020]. This cache, accumulated across positions and layers and persisted between generation steps, is the KV-cache [src_002].
The size of the cache follows directly from its definition. Fix one decoder layer \(\ell\) and one generated position \(t\). The layer holds two activation tensors at that position, the keys \(K_{\ell, t}\) and the values \(V_{\ell, t}\), each of shape \((B, h_{kv}, d_h)\), where \(B\) is the batch size, \(h_{kv}\) is the number of KV heads (we will see in §4–§6 that \(h_{kv}\) need not equal the number of query heads \(h\)), and \(d_h\) is the per-head dimension [src_020, src_002]. Stacking across \(T\) generated positions and \(L\) layers, and accounting for both the key tensor and the value tensor, the total number of cached scalars is

\[
\text{cached scalars} = 2 \cdot L \cdot B \cdot T \cdot h_{kv} \cdot d_h.
\]

The leading factor \(2\) is for \(K\) and \(V\) jointly; it is a count of tensors, not a count of bytes. To convert scalars to bytes we multiply by the storage precision: in fp16 or bf16, two bytes per element [src_002, src_020]:

\[
\text{KV-cache bytes} = 2 \cdot L \cdot B \cdot T \cdot h_{kv} \cdot d_h \cdot \text{bytes\_per\_elem}.
\]
Both factors of \(2\) matter and are independent. The first \(2\) is the \(K\)-plus-\(V\) count and would still be there if the model used fp32. The second \(2\) is the bytes-per-element of half-precision storage and would be \(4\) if the cache were held in fp32 or \(1\) if it were quantised to int8 [src_020]. Conflating the two factors is a common bookkeeping error in informal writing; we will keep them separate throughout this chapter.
The formula has three immediate qualitative consequences. First, the cache grows linearly in the sequence length \(T\), so doubling the context window doubles the cache. Second, it grows linearly in the number of KV heads \(h_{kv}\), so any architectural change that reduces \(h_{kv}\) shrinks the cache by exactly the same factor. Third, it grows linearly in the layer count \(L\), so the KV-cache footprint of a deep narrow model can be larger than that of a shallow wide one even when their parameter counts are matched [src_020].
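The formula is compact enough to carry as a helper. The sketch below is ours, not from any cited codebase; the argument names mirror the notation of this section and the function name is hypothetical.

```python
def kv_cache_bytes(L, B, T, h_kv, d_h, bytes_per_elem=2):
    """KV-cache size in bytes: 2 (K and V) x layers x batch x sequence
    length x KV heads x head dimension x storage width per element."""
    return 2 * L * B * T * h_kv * d_h * bytes_per_elem
```

Each of the three consequences above falls out of the signature: every argument enters the product linearly, so halving any one of them halves the result.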
3. A concrete back-of-the-envelope: Llama-3 70B at 128k context
To convert these scaling laws into a number an engineer can plan around, we walk through a single worked example. Take Llama-3 70B, which has \(L = 80\) decoder layers, \(h = 64\) query heads, \(h_{kv} = 8\) KV heads (the model uses GQA with eight groups; we return to the choice in §6), and per-head dimension \(d_h = 128\) [src_010, src_047]. Suppose we serve a single request, so \(B = 1\), at a context length of \(T = 128\text{k} = 128 \cdot 1024 = 131{,}072\) tokens, with the cache held in fp16 so \(\text{bytes\_per\_elem} = 2\).
Plugging into the formula of §2:

\[
2 \cdot 80 \cdot 1 \cdot 131{,}072 \cdot 8 \cdot 128 \cdot 2 \ \text{bytes}.
\]
Multiplying the constants step by step (the \(B = 1\) factor is left implicit):
- \(2 \cdot 80 = 160\).
- \(160 \cdot 131{,}072 = 20{,}971{,}520\).
- \(20{,}971{,}520 \cdot 8 = 167{,}772{,}160\).
- \(167{,}772{,}160 \cdot 128 = 21{,}474{,}836{,}480\).
- \(21{,}474{,}836{,}480 \cdot 2 = 42{,}949{,}672{,}960\) bytes.
A useful per-unit shortcut hides inside that 11-digit total. The per-(token, KV-head, head-dim element) byte cost is \(2 \cdot 2 = 4\) bytes (the \(K\)+\(V\) count times bytes-per-element); multiplied by \(d_h = 128\) this gives \(512\) bytes per token per KV-head, so the rest of the chain is just \(L \cdot B \cdot T \cdot h_{kv}\). That mental shortcut — about \(512\) bytes per (token, KV-head) at \(d_h = 128\), fp16 — is what makes §6's GQA savings argument arithmetic rather than rhetorical.
🤔 Pause and reflect
Before reading on, predict — at \(B = 8\) concurrent requests instead of \(B = 1\), what KV-cache per server does the \(128\)k Llama-3 70B configuration require, and does it fit in the \(640\) GB of an eight-H100 server once the model weights are also accounted for? (Do not look ahead — write the answer down or say it out loud.)
That is \(42{,}949{,}672{,}960\) bytes per request. Dividing by \(2^{30} = 1{,}073{,}741{,}824\) to convert to gibibytes,

\[
\frac{42{,}949{,}672{,}960}{2^{30}} = 40 \ \text{GiB},
\]
or, equivalently, about \(43\) GB in decimal units (dividing by \(10^9\)). The synthesis note for this part records the same calculation at the rounded \(\approx 41\) GB level [src_007]; depending on whether one uses gibibytes or gigabytes the figure rounds to \(40\) or \(43\), and we will write the rounded result as roughly \(40\)–\(43\) GB per sequence to be honest about the unit ambiguity.
🎯 Intuition
GiB (\(2^{30}\) bytes) and GB (\(10^9\) bytes) differ by about \(7.4\%\). HBM datasheets typically quote GB; OS tools and many ML frameworks report in GiB. The same KV-cache is honestly \(40\) GiB or \(43\) GB; the engineering decision does not turn on which unit you write — only on remembering which one you are reading.
Two observations make this number actionable. First, \(40\) GB is comparable to the full fp16 weight memory of a \(20\)B-class model and fills half of a single H100's \(80\) GB of HBM before the model weights themselves (which for a \(70\)B model in fp16 are \(\approx 140\) GB and must already be sharded across multiple GPUs) are even accounted for [src_007, src_047]. Serving a single \(128\)k request therefore consumes a substantial fraction of cluster memory, and serving many concurrent requests would exhaust HBM long before the compute saturates. Second, the size scales with \(h_{kv}\) in a way that makes the architectural choice of grouped-query attention financially load-bearing: had the same model retained the original multi-head attention with \(h_{kv} = h = 64\), the cache would have been \(64/8 = 8\times\) larger, or roughly \(320\)–\(340\) GB per sequence, which is not deployable on any current single-server accelerator [src_020, src_010].
💡 Key result
A single \(128\)k-context Llama-3 70B request consumes roughly half an H100's HBM in KV-cache alone; the same model in MHA would need eight times that and would not fit on any single-server accelerator.
4. Multi-Head Attention recap
Standard multi-head attention (MHA), as introduced in the 2017 Transformer paper and recapped in Chapter 1, gives every one of \(h\) query heads its own independently learned \(K\) and \(V\) projections [src_020]. Concretely, with input \(X \in \mathbb{R}^{T \times D}\) and per-head dimension \(d_h = D/h\), MHA computes for each head \(i \in \{1, \ldots, h\}\)

\[
\text{head}_i = \text{softmax}\!\left(\frac{(X W_Q^{(i)})(X W_K^{(i)})^{\top}}{\sqrt{d_h}}\right) X W_V^{(i)},
\]

with \(W_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{D \times d_h}\) all distinct across heads.
🎯 Intuition
Picture each query head as carrying its own private "lookup dictionary" — a unique \(K\) matrix mapping tokens to retrieval keys, and a unique \(V\) matrix mapping tokens to retrieved content. With \(h\) heads, the model maintains \(h\) independent dictionaries side by side. The output projection \(W_O\) (implicit here; it is defined in Chapter 1) glues the per-head outputs back into a single residual-stream update. The §5–§6 question is whether the dictionaries themselves can be shared across heads, or only the reads on them.
The number of cached KV scalars per layer per token is then \(h \cdot d_h \cdot 2\) (for \(K\) and \(V\)), so \(h_{kv} = h\) in MHA and the KV-cache scales as \(h\) [src_020, src_002].
🔗 Connection
Multi-head attention with \(h\) independent KV projections, including the scaled dot-product softmax structure and the output-projection \(W_O\), is defined in Chapter 1; this chapter takes the §1 description as a starting point and varies \(h_{kv}\).
For the Llama-3 70B configuration of §3, MHA would mean \(h_{kv} = 64\) rather than \(h_{kv} = 8\). Because the per-head dimension and every other factor in the cache formula are unchanged, the cache is exactly \(h / h_{kv} = 64/8 = 8\times\) larger under MHA than under the GQA-8 design the model actually ships with [src_020]. This factor-of-eight swing on a single architectural knob is what made grouped-query attention an obvious win at scale.
5. Multi-Query Attention (MQA): the extreme
The first published attempt to break the linear scaling of the KV-cache with \(h\) was multi-query attention (MQA), proposed by Shazeer in 2019 [src_020]. MQA collapses all \(h\) KV heads into one shared pair: every query head still has its own learned projection \(W_Q^{(i)}\), but the keys and values are produced by a single \(W_K \in \mathbb{R}^{D \times d_h}\) and a single \(W_V \in \mathbb{R}^{D \times d_h}\) that are shared across all query heads [src_020]. In the language of §2, MQA sets \(h_{kv} = 1\), so the KV-cache shrinks by a factor of \(h\) relative to MHA [src_020].
The trade-off is quality and stability. Ainslie et al. (2023) report that MQA delivers dramatically faster inference — the T5-XXL benchmarks in their Table 4 show MQA at \(0.24\) s per sample versus \(1.51\) s for MHA — but the corresponding average task score drops from \(47.2\) on MHA to \(46.6\) on MQA, and the authors describe MQA as prone to quality degradation and training instability, especially with long inputs [src_020]. PaLM committed to MQA from pre-training and absorbed the cost, but T5 and the early Llama family retained MHA. By 2023 the question was whether something between \(h_{kv} = 1\) and \(h_{kv} = h\) could keep most of the speed and most of the quality.
⚠️ Pitfall
MQA's quality drop is small in headline averages but uneven by task — long-input evaluations and few-shot reasoning show a wider gap than summarisation. "GQA-8 is essentially MHA-quality" is true at the same noise scale as random seeds; "MQA is essentially MHA-quality" is not.
6. Grouped-Query Attention (GQA): the interpolation that won
Grouped-query attention answers that question by introducing a group factor. With a group factor \(g\), the \(h\) query heads are partitioned into \(h_{kv} = h/g\) groups, and each group shares one \(K\) and one \(V\) head [src_020]. The two endpoints \(g = 1\) and \(g = h\) recover the familiar special cases: \(g = 1\) gives \(h_{kv} = h\), which is MHA, and \(g = h\) gives \(h_{kv} = 1\), which is MQA [src_020]. The interesting regime is the interior, where \(g\) is small (so the cache is dramatically smaller than MHA) but greater than one (so query heads do not all share the same key/value subspace).
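At decode time, sharing works by broadcasting: each cached KV head serves the \(g\) query heads in its group. Below is a minimal NumPy sketch of one decode step, assuming a (head, position, dim) layout; the shapes are ours, and production kernels index into the group rather than materialising repeated copies.

```python
import numpy as np

h, h_kv, d_h, T = 64, 8, 128, 4096
g = h // h_kv                       # query heads per KV group (8 here)

Q = np.random.randn(h, 1, d_h)      # one decode-step query per query head
K = np.random.randn(h_kv, T, d_h)   # cached keys, one tensor per KV head
V = np.random.randn(h_kv, T, d_h)   # cached values

# Broadcast each shared KV head to the g query heads in its group.
# np.repeat copies memory here; real kernels index instead of duplicating.
K_full = np.repeat(K, g, axis=0)    # (h, T, d_h)
V_full = np.repeat(V, g, axis=0)

scores = Q @ K_full.transpose(0, 2, 1) / np.sqrt(d_h)        # (h, 1, T)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ V_full                                        # (h, 1, d_h)
```

With \(h_{kv} = h\) the repeat is a no-op and the sketch computes MHA; with \(h_{kv} = 1\) it computes MQA: the interpolation claim in executable form.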
Ainslie et al. study this directly. Starting from a fully trained MHA T5-XXL checkpoint, they collapse the \(K\) and \(V\) heads of each group into a single shared head by mean-pooling the per-head projections, and then continue pre-training for an additional 5% of the original training compute to repair the perturbation [src_020]. This uptraining procedure recovers most of the original quality at a small fraction of the original training cost: the GQA-8 variant of T5-XXL averages \(47.1\) across summarisation and question-answering benchmarks, compared to \(47.2\) for the original MHA model and \(46.6\) for the MQA variant, with inference time per sample of \(0.28\) s versus \(1.51\) s for MHA — so GQA-8 retains essentially MHA-level quality at roughly five times the throughput [src_020].
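The pooling step itself is one reshape and one mean. A sketch of the conversion on a per-head-stacked \(K\) projection follows; the \((h, D, d_h)\) layout is our assumption rather than the paper's code, and the \(V\) projection is pooled identically.

```python
import numpy as np

D, h, g = 8192, 64, 8                 # model dim, query heads, heads per group
d_h = D // h

W_K_mha = np.random.randn(h, D, d_h)  # stand-in for a trained MHA checkpoint

# Collapse each group of g per-head K projections into one shared head.
W_K_gqa = W_K_mha.reshape(h // g, g, D, d_h).mean(axis=1)   # (h_kv, D, d_h)
```

The mean-pooled checkpoint is a perturbation of the original function, which is why the 5% uptraining pass is needed to repair quality.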
The economic implications of those numbers are what made GQA the default for the Llama-2 70B, Llama-3 70B, Qwen2, and DeepSeek-V2 families [src_010, src_047, src_002]. At 70B scale the query head count is large — Llama-3 70B uses \(h = 64\), as we noted in §3 — and the KV-cache savings of going from MHA to GQA-8 are exactly the factor-of-eight calculation we walked through. At the same time, the quality gap to MHA is within the noise of pre-training random seeds, and the gap to MQA is wide enough to matter in long-context evaluations [src_020]. The combination of "almost MHA quality" and "almost MQA cache size" is what GQA buys, and it explains why every open-weights frontier release after 2023 ships some variant of grouped-query attention [src_010, src_002].
🔄 Recap
- Complete the formula: in MQA the KV-cache size is \(2 \cdot L \cdot B \cdot T \cdot 1 \cdot d_h \cdot \text{bytes\_per\_elem}\) — write down the analogous expressions for MHA at \(h_{kv} = h\) and for GQA at \(h_{kv} = h/g\).
- Explain in your own words why the MHA→GQA-8 quality gap on T5-XXL is within seed noise but the MHA→MQA gap is not — what is GQA-8 keeping that MQA throws away?
- Predict: for a hypothetical Llama-class model with \(L = 60\), \(h = 48\), \(h_{kv} = 6\), \(d_h = 128\), fp16, what is the per-request KV-cache at \(T = 64\)k, and how does the factor-of-eight savings argument from §3 generalise?
7. Compute versus memory bandwidth: why standard attention is HBM-bound
Section 6 controls the size of the KV-cache; FlashAttention controls the cost of using it. Before describing FlashAttention itself, we have to be clear about which bottleneck it relieves. The 2017 description of attention as "\(\mathcal{O}(T^2 \, d_h)\) in compute and memory" treats the FLOP count and the memory footprint as the relevant resources, but on a modern GPU the binding constraint at long sequence length is neither — it is the bandwidth of the off-chip DRAM that connects the streaming multiprocessors (SMs — the GPU's parallel compute units, analogous to CPU cores) to the high-bandwidth memory (HBM) on the same package [src_021]. On a modern GPU, on-chip SRAM (per-SM scratchpad, ~MB-scale, ~10 TB/s effective bandwidth) sits one level above HBM (off-chip, ~GB-scale, ~1.5–2 TB/s) — an order-of-magnitude bandwidth gap, like the gap between L2 cache and main RAM on a CPU.
The reason is structural. On an A100 the peak FP16 matmul throughput is roughly \(312\) TFLOP/s, but the HBM bandwidth is roughly \(1.5\)–\(2\) TB/s; a single matrix multiply moves bytes in and out of HBM at the bandwidth-limited rate while consuming only a small fraction of the compute units, and this asymmetry is even sharper on H100 [src_021, src_023].
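Dividing the two peaks gives the machine balance, the arithmetic intensity a kernel needs to be compute-bound rather than memory-bound. Under the A100 figures above (a back-of-the-envelope, not a profiled number):

\[
\frac{312 \ \text{TFLOP/s}}{2 \ \text{TB/s}} \approx 156 \ \text{FLOPs per byte of HBM traffic}.
\]

A kernel that performs fewer FLOPs than that per byte it moves stalls on memory. Decode-time attention reads each cached KV byte roughly once per generated token and performs on the order of one multiply-add on it, two orders of magnitude below the threshold.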
The standard implementation of scaled dot-product attention \(\text{softmax}(Q K^{\top}/\sqrt{d_h}) V\) allocates an explicit \(T \times T\) attention matrix in HBM, materialises it once during the softmax, and reads it back during the multiplication with \(V\) — three full passes over a \(T^2\)-sized tensor that the compute does not need but the implementation insists on [src_021]. At \(T = 8\text{k}\) that intermediate matrix is \(64\)M entries per head per batch element; at \(T = 128\text{k}\) it is \(2^{34} \approx 17\) billion entries, and the time spent shuttling it across the HBM-SRAM boundary dominates the runtime [src_021].
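What those three passes look like in code: a single-head, unmasked NumPy rendering of the textbook kernel (our sketch; on a GPU, `S` and `P` below are the tensors that make the HBM round-trips):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Textbook scaled dot-product attention for one head, no masking.
    The full (T, T) matrix is materialised; on a GPU it would be
    written to and re-read from HBM around the softmax."""
    T, d_h = Q.shape
    S = Q @ K.T / np.sqrt(d_h)                  # (T, T) logits: pass 1
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)          # normalise: pass 2
    return P @ V                                # re-read weights: pass 3
```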
This is why an attention kernel that is memory-bandwidth-bound, not compute-bound, is the right object to optimise: cutting the number of HBM round-trips translates directly into wall-clock speedups regardless of whether the FLOP count is reduced [src_021].
🎯 Intuition
Tiling is the natural response to a bandwidth limit. If a value is read once from slow memory, used many times, and discarded, the cost is bandwidth-bound; if instead the value is read once into fast memory and reused there before being discarded, the bandwidth cost amortises across uses. FlashAttention applies exactly this pattern to the \(T \times T\) attention matrix — keep \(Q\), \(K\), \(V\) tiles resident in SRAM long enough that the per-byte work goes up before they are evicted.
FlashAttention is precisely such a kernel.
8. FlashAttention v1 (Dao et al., 2022): IO-aware tiling
FlashAttention reorganises exactly the same computation around the GPU's memory hierarchy [src_021]. The key observation is that softmax is the only operation in the chain that requires all \(T\) key-query inner products simultaneously, and softmax has a numerically stable streaming form: rather than collecting the full row of logits, subtracting the row maximum, exponentiating, and normalising once, one can maintain a running maximum \(m\) and a running denominator \(\ell\) as logits arrive in blocks, updating both whenever a new block is seen [src_021].
🤔 Pause and reflect
Pause here. If you subtract the running maximum \(m\) from each new logit before exponentiating, the exponential is bounded — but does the answer depend on what the next block of logits looks like? Predict whether (and why) softmax is shift-invariant under subtracting any constant from every logit, then check the next sentence. (Do not look ahead — write the answer down or say it out loud.)
This is the online softmax trick, originally due to Milakov and Gimelshein in 2018; FlashAttention adapts it to attention so that the \(T \times T\) matrix never has to be assembled in one place [src_021].
Concretely, the kernel tiles \(Q\) into row blocks and \(K, V\) into column blocks. For each \(Q\) block, it loops over the \(K, V\) blocks, computes the partial inner products \(Q_{\text{block}} K_{\text{block}}^{\top}\) inside on-chip SRAM, runs the streaming softmax update on that partial slice, and accumulates the corresponding \(\text{softmax-weight} \cdot V_{\text{block}}\) contribution into an output buffer that lives in SRAM until the row block is finished [src_021].
The \(T \times T\) matrix is never materialised in HBM, the only HBM traffic is reading \(Q\), \(K\), \(V\) once and writing the output once, and the attention output that comes out of this loop is bit-for-bit equal to what the textbook implementation would have produced — FlashAttention is exact, not an approximation [src_021]. For the backward pass, the kernel re-runs the forward tiling rather than persisting the giant attention matrix, trading a small amount of recomputation for a large reduction in memory pressure [src_021].
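The loop structure is worth seeing once in full. Below is a NumPy rendering of the recurrence, ours rather than the paper's: it streams over \(K, V\) column blocks while maintaining the running maximum `m`, denominator `l`, and unnormalised output `O`. Real kernels also tile the query rows and keep every block in SRAM; NumPy arrays stand in for tiles here.

```python
import numpy as np

def flash_like_attention(Q, K, V, block=128):
    """Online-softmax attention over K/V column blocks. Exact: returns
    the same values as the textbook kernel, without ever holding the
    full (T, T) matrix."""
    T, d_h = Q.shape
    O = np.zeros((T, d_h))
    m = np.full((T, 1), -np.inf)    # running row maxima
    l = np.zeros((T, 1))            # running softmax denominators
    for j in range(0, T, block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = Q @ Kj.T / np.sqrt(d_h)                  # partial logits
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)                    # rescale old state
        P = np.exp(S - m_new)                        # bounded: S - m_new <= 0
        l = scale * l + P.sum(axis=-1, keepdims=True)
        O = scale * O + P @ Vj
        m = m_new
    return O / l                                     # normalise once at the end
```

With `naive_attention` from the §7 sketch in scope and random `Q`, `K`, `V` of matching shapes, `np.allclose(flash_like_attention(Q, K, V), naive_attention(Q, K, V))` holds, which is the exactness claim in executable form.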
The empirical payoff is the headline number of the 2022 paper: roughly a \(3\times\) end-to-end speedup on GPT-2 training and \(2.4\times\) on the long-range arena benchmark, with a \(10\)–\(20\times\) reduction in memory used by the attention layer — large enough to make Path-X (a \(16\)k-token classification task from the Long-Range Arena benchmark) and even Path-256 (\(64\)k tokens, also from Long-Range Arena) trainable for the first time [src_021]. We do not derive the streaming-softmax recurrence line by line in this chapter; the original paper and CS336 Lecture 5 carry that derivation in detail, and readers who want the proof of correctness should consult both [src_021, src_004].
9. FlashAttention-2 (Dao, 2023): work partitioning at the warp level
FlashAttention v1 was bandwidth-aware but not yet hardware-saturating. On the A100 it achieved only \(25\)–\(40\)% of the theoretical FP16 peak FLOPs, because the kernel's work partitioning left some hardware resources idle [src_022]. FlashAttention-2, published in 2023, is an engineering reorganisation that closes most of that gap [src_022].
Three changes carry the speedup. First, the kernel reduces the number of non-matmul floating-point operations: the original implementation rescaled the running output buffer at every block update, but most of those rescalings can be deferred until the row block is complete, replacing many small \(\exp\) and division calls with one final correction [src_022]. This matters because non-matmul FLOPs are not handled by the Tensor Cores (specialised matmul units alongside the general-purpose CUDA cores; on H100/A100 nearly all peak FP16/FP8 throughput lives in Tensor Cores, not in CUDA cores) and so consume a disproportionate share of cycles on a chip whose peak throughput lives in matmul. Second, it parallelises the forward pass across the sequence dimension as well as the batch and head dimensions, so that long-sequence single-batch workloads (which are typical of decode) keep all SMs busy rather than leaving them idle waiting for the next batch element [src_022]. Third, it re-partitions the work inside a thread block so that the four warps of a block share keys and values via shared memory rather than each fetching them from HBM, eliminating redundant loads and reducing intra-block synchronisation [src_022].
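The first of those changes is easy to see against the §8 sketch, which already uses the deferred style. A v1-style loop normalises the output buffer inside the loop instead; the marked line below is the only difference, and on a GPU its per-block divisions and exponentials run on the slow non-matmul path (our reconstruction, simplified):

```python
import numpy as np

def flash_v1_style(Q, K, V, block=128):
    """Same recurrence as the §8 sketch, but keeping O normalised at
    every block update, as FlashAttention v1 did."""
    T, d_h = Q.shape
    O = np.zeros((T, d_h))
    m = np.full((T, 1), -np.inf)
    l = np.zeros((T, 1))
    for j in range(0, T, block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = Q @ Kj.T / np.sqrt(d_h)
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        P = np.exp(S - m_new)
        l_new = np.exp(m - m_new) * l + P.sum(axis=-1, keepdims=True)
        # v1: divide every block; v2 keeps O unnormalised, divides once at end.
        O = (np.exp(m - m_new) * l * O + P @ Vj) / l_new
        m, l = m_new, l_new
    return O    # already normalised
```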
The combined effect on A100 is a \(2\times\) speedup over FlashAttention v1, reaching roughly \(225\) TFLOP/s in FP16 and a model-FLOPs-utilisation of about \(72\)% — the highest sustained attention throughput reported for the A100 generation [src_022]. The semantics of the kernel are unchanged: FlashAttention-2 still computes exact softmax attention, only faster.
10. FlashAttention-3 (Shah et al., 2024): Hopper-specific async and FP8
The H100 generation introduced new hardware paths that FlashAttention-2 did not exploit. The Tensor Memory Accelerator (TMA) supports asynchronous bulk loads from HBM into shared memory; the WGMMA instruction issues warpgroup-level (four cooperating warps acting as one issue unit on Hopper) matmul on the Tensor Cores; and the H100 supports FP8 with hardware acceleration. On the H100, FlashAttention-2 reached only about \(35\)% of peak — a regression in utilisation despite higher absolute throughput, because the kernel was not async-aware and used FP16 throughout [src_023].
FlashAttention-3 redesigns the kernel around these Hopper features. The forward pass uses producer/consumer warp specialisation: dedicated producer warps issue TMA loads of \(K\) and \(V\) tiles while consumer warps run the WGMMA matmul on tiles already in SRAM, so data movement and computation overlap rather than serialising [src_023].
The softmax is interleaved with the asynchronous matmul issue so that the non-matmul softmax work hides behind the GEMM (general matrix multiply — the dense matmul kernel issued on Tensor Cores) latency rather than blocking it [src_023]. Finally, the kernel adds a low-precision path: \(Q\), \(K\), \(V\) are quantised to FP8 with block quantisation (one scale factor per tile) and incoherent processing (a Hadamard-style randomised projection applied before quantisation to flatten outliers), which together reduce the FP8 quantisation error of attention by a factor of \(2.6\times\) relative to the per-tensor-quantised baseline [src_023].
🎯 Intuition
An orthogonal Hadamard projection rotates the activation vector so that any single "outlier" coordinate is spread across many coordinates of the rotated vector. After rotation, no single entry dominates; per-tile scaling factors then bracket a tighter range, and FP8 quantisation loses less. Crucially, the rotation is invertible — the kernel undoes it after the matmul, so the math is exact up to FP8 round-off.
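The effect is easy to demonstrate numerically. The toy below uses a plain (non-randomised) Hadamard matrix via SciPy, a simplification of the randomised transform FlashAttention-3 actually applies; the point is only the outlier flattening.

```python
import numpy as np
from scipy.linalg import hadamard

n = 128
x = np.random.randn(n)
x[7] = 40.0                        # plant one large outlier coordinate

H = hadamard(n) / np.sqrt(n)       # orthogonal: H @ H.T is the identity
y = H @ x                          # rotate before quantisation

print(np.abs(x).max())             # 40.0: one coordinate dominates
print(np.abs(y).max())             # typically ~5-7: energy is spread out

assert np.allclose(H.T @ y, x)     # exactly invertible after the matmul
```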
The combined result is a \(1.5\)–\(2\times\) speedup over FlashAttention-2 on H100, reaching roughly \(740\) TFLOP/s in FP16 (\(\approx 75\)% MFU) and roughly \(1.2\) PFLOP/s in FP8 [src_023]. As with the v1-to-v2 transition, the kernel still computes exact softmax attention; the FP8 path additionally certifies a small, controlled numerical error that keeps end-to-end training and inference accuracy within noise of the FP16 reference [src_023].
💡 Key result
FlashAttention-3 reaches about \(75\%\) MFU on H100 in FP16 and roughly \(1.2\) PFLOP/s in FP8 — closing most of the utilisation gap that v2 left, and certifying a controlled FP8 numerical error inside the same exact-softmax envelope.
🔄 Recap
- Explain why standard scaled dot-product attention is HBM-bound at long \(T\) — what does the kernel read from HBM, and how many round-trips does the textbook implementation cost?
- Compare FlashAttention v1 and FlashAttention-2 on the A100 at the same problem size: what does v2 change about the kernel's work partitioning, and which of those three changes matters most for long-sequence single-batch decode?
- Predict: of the three Hopper-specific paths v3 exploits (TMA-based async loads, WGMMA warpgroup matmul, FP8), which would matter most for a workload that is dominated by single-batch long-decode rather than by training-time long-sequence forward passes?
11. A teaser for speculative decoding
KV-cache shrinking and FlashAttention together are the load-bearing wins of efficient attention; they are what makes long-context, large-model inference economically viable in 2026. They are not the only inference-time tricks worth knowing. Speculative decoding — drafting a short continuation with a cheaper model and verifying it in parallel with the target model — wins another \(2\)–\(3\times\) on top of GQA + FlashAttention by amortising the per-step decode cost across multiple accepted tokens [src_007]. We defer the algorithm to Chapter 8, where it sits alongside the rest of the modern decoder-only inference stack, and to Appendix B, which walks through the gpt-fast reference implementation [src_010]. The point for the present chapter is that speculative decoding multiplies the gains of GQA and FlashAttention; it does not replace them.
🔗 Connection
The speculative-decoding draft/target loop is developed in Chapter 8 alongside the rest of the modern decoder-only inference stack, and Appendix B walks through the gpt-fast reference implementation.
12. Closing summary
The story of efficient attention in the 2022–2026 window is that two architectural bottlenecks — the KV-cache, which is bound by HBM capacity, and the attention matrix materialisation, which is bound by HBM bandwidth — were attacked separately and resolved in parallel.
A compact summary of the KV-cache implications, holding \(L\), \(h\), \(d_h\), \(T\), \(B\), and the storage precision fixed at the Llama-3 70B / 128k / fp16 / single-sequence configuration of §3:
| Variant | \(h_{kv}\) | KV-cache ratio vs MHA | Approximate cache for Llama-3 70B at 128k |
|---|---|---|---|
| MHA | \(h = 64\) | \(1\) | \(\approx 320\)–\(340\) GB |
| GQA-8 (Llama-3 default) | \(8\) | \(1/8\) | \(\approx 40\)–\(43\) GB |
| MQA | \(1\) | \(1/64\) | \(\approx 5\) GB |
(Numbers are derived from the formula of §2 plus the worked example of §3; quality trade-offs from [src_020].)
A parallel summary of the runtime implications on a single A100 / H100, holding the attention layer and the sequence length fixed and varying only the kernel:
| Kernel | Hardware target | Throughput / key property | Speedup vs naive |
|---|---|---|---|
| Naive softmax-attention | A100 | bandwidth-bound, \(T^2\) HBM traffic | \(1\times\) (reference) |
| FlashAttention v1 | A100 | exact, tiled, no \(T \times T\) in HBM | \(\sim 3\times\) on GPT-2 [src_021] |
| FlashAttention-2 | A100 | \(\sim 225\) TFLOP/s, \(\sim 72\)% MFU | \(\sim 2\times\) over v1 [src_022] |
| FlashAttention-3 (FP16) | H100 | \(\sim 740\) TFLOP/s, \(\sim 75\)% MFU | \(\sim 1.5\)–\(2\times\) over v2 [src_023] |
| FlashAttention-3 (FP8) | H100 | \(\sim 1.2\) PFLOP/s | further \(\sim 2\times\) at controlled error [src_023] |
Both tables collapse into the same engineering claim. By 2026, the cost of attention at decode time on a frontier open-weights model is set neither by the textbook \(\mathcal{O}(T^2 \, d_h)\) FLOP count nor by the textbook \(\mathcal{O}(h \cdot d_h)\)-per-token cache — both of those quantities have been reduced by a small-integer constant factor through architectural and kernel-level engineering, while the underlying mathematical specification of softmax attention remains untouched. Chapter 8 picks up the thread where this chapter leaves it, building the modern decoder-only language model out of the components developed across Chapters 1–4: pre-RMSNorm wrappers, RoPE-augmented self-attention, GQA-shaped KV projections, FlashAttention-shaped kernels, and SwiGLU FFNs.
References
- src_002 — Tong Xiao and Jingbo Zhu. Foundations of Large Language Models. arXiv:2501.09223v2, 2025. https://arxiv.org/pdf/2501.09223
- src_004 — Tatsunori Hashimoto and Percy Liang. Stanford CS336: Language Modeling from Scratch (Spring 2025). Course landing page, 2025. https://stanford-cs336.github.io/spring2025/
- src_007 — Hugging Face. Ultra-Scale Playbook (Feb 2025). HuggingFace Spaces, 2025. https://huggingface.co/spaces/nanotron/ultrascale-playbook
- src_010 — Sebastian Raschka. Build a Large Language Model (From Scratch). Manning, 2024. https://github.com/rasbt/LLMs-from-scratch
- src_020 — Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit Sanghai. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv:2305.13245, 2023. https://arxiv.org/pdf/2305.13245
- src_021 — Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arXiv:2205.14135, 2022. https://arxiv.org/pdf/2205.14135
- src_022 — Tri Dao. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691, 2023. https://arxiv.org/pdf/2307.08691
- src_023 — Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, and Tri Dao. FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision. arXiv:2407.08608, 2024. https://arxiv.org/pdf/2407.08608
- src_047 — Dilyan Grigorov. Building Large Language Models from Scratch. Apress, 2026. https://doi.org/10.1007/979-8-8688-2297-1