Ir para o conteúdo

Atenção Eficiente em Escala

1. Por que eficiência importa em 2026

O custo de um modelo de linguagem grande implantado não é pago majoritariamente durante o treinamento. Um modelo de fundação moderno é treinado uma única vez em um cluster de computação grande e então é servido aos usuários por meses ou anos; ao longo da vida útil de implantação do modelo, o total de computação de inferência excede a computação de treinamento por uma margem ampla, e a despesa operacional dominante é o custo de gerar tokens para os usuários [src_007]. A economia da attention, portanto, precisa ser analisada em tempo de decodificação, não em tempo de treinamento, e o gargalo relevante em tempo de decodificação não é aritmética, mas largura de banda de memória.

⚠️ Armadilha

A afirmação "decodificação é limitada pela largura de banda de memória, não pela computação" é a premissa determinante deste capítulo e é justificada aritmeticamente apenas em §7. Um leitor que ainda não viu a assimetria \(\sim 312\) TFLOP/s vs \(\sim 1{,}5\)\(2\) TB/s deve tratá-la como uma referência adiante para §7, em vez de como uma alegação injustificada.

A pressão se intensificou com o surgimento de aplicações de contexto longo. O Transformer de 2017 foi avaliado em comprimentos de sequência de algumas centenas de tokens, e a primeira geração de modelos de linguagem decoder-only foi implantada em contexto de 2k ou 4k [src_002]. Em 2024, a fronteira de pesos abertos era de 32k, e em 2026 implantações em produção rotineiramente servem janelas de 128k ou mais longas para geração aumentada por recuperação, análise de documentos e fluxos de trabalho agênticos [src_007]. Nesses comprimentos, duas coisas acontecem simultaneamente. A computação de attention quadrática \(\mathcal{O}(T^2 \, d_h)\) se torna uma fração mensurável do custo por passo, e o armazenamento linear-em-\(T\) do KV-cache se torna uma restrição vinculante para quantas requisições concorrentes um acelerador pode servir [src_002, src_007]. Ambas as pressões recaem sobre a layer de attention; nenhuma é sentida pelo FFN.

🔗 Conexão

Onde este capítulo lida com armazenar um KV-cache de \(128\)k, o Capítulo 2 lida com a questão de como o modelo usa \(128\)k tokens de informação posicional — RoPE mais os truques de extrapolação YaRN/NTK que permitem a um modelo treinado em contexto mais curto generalizar para fora.

Este capítulo aborda as três ideias de engenharia que os artigos de 2022–2024 usaram para absorver ambas as pressões sem alterar o que a attention computa. A primeira ideia, grouped-query attention (GQA), ataca o KV-cache compartilhando heads de chave/valor entre grupos de heads de consulta [src_020]. A segunda, FlashAttention, ataca o tempo de execução de attention organizando a computação em tiles para que a matriz de attention \(T \times T\) nunca precise ser materializada na HBM [src_021]. A terceira, os trabalhos subsequentes FlashAttention-2 e FlashAttention-3, atacam a lacuna entre os FLOPs alcançáveis e os de pico em aceleradores modernos reparticionando o trabalho e explorando caminhos assíncronos específicos do hardware [src_022, src_023]. Juntas, essas três mudanças tornam a inferência de contexto 100k em um modelo de pesos abertos da classe 70B viável em um único servidor de oito GPUs [src_007, src_047].

🔗 Conexão

Este capítulo desenvolve os componentes de KV-cache e FlashAttention que o Capítulo 8 montará na pilha moderna de inferência decoder-only — o mesmo capítulo que retoma o speculative decoding em detalhe.

2. O KV-cache: anatomia e fórmula de memória

A geração autorregressiva produz tokens um de cada vez. Em cada passo, o modelo recebe o token mais recente, faz seu embedding e o passa por todas as layers da pilha para produzir uma distribuição de logits para a próxima posição [src_002]. Como o estado oculto de cada token anterior em cada layer já foi computado, é desperdício recomputar chaves e valores para essas posições: seus tensores \(K_{\ell, t}\) e \(V_{\ell, t}\) não mudam de passo para passo, então podem ser armazenados em cache e reutilizados [src_002, src_020]. Esse cache, acumulado entre posições e layers e persistido entre passos de geração, é o KV-cache [src_002].

O tamanho do cache decorre diretamente de sua definição. Fixe uma layer de decoder \(\ell\) e uma posição gerada \(t\). A layer mantém dois tensores de ativação naquela posição, as chaves \(K_{\ell, t}\) e os valores \(V_{\ell, t}\), cada um com formato \((B, h_{kv}, d_h)\), onde \(B\) é o tamanho do batch, \(h_{kv}\) é o número de heads de KV (veremos em §4–§6 que \(h_{kv}\) não precisa ser igual ao número de heads de consulta \(h\)), e \(d_h\) é a dimensão por head [src_020, src_002]. Empilhando ao longo de \(T\) posições geradas e \(L\) layers, e contabilizando tanto o tensor de chaves quanto o de valores, o número total de escalares em cache é

\[ N_{\text{KV}} \;=\; 2 \cdot L \cdot B \cdot T \cdot h_{kv} \cdot d_h. \]

O fator inicial \(2\) é para \(K\) e \(V\) conjuntamente; é uma contagem de tensores, não uma contagem de bytes. Para converter escalares em bytes, multiplicamos pela precisão de armazenamento: em fp16 ou bf16, dois bytes por elemento [src_002, src_020]:

\[ \text{tamanho do KV-cache em bytes} \;=\; 2 \cdot L \cdot B \cdot T \cdot h_{kv} \cdot d_h \cdot \text{bytes\_per\_elem}. \]

Ambos os fatores de \(2\) importam e são independentes. O primeiro \(2\) é a contagem de \(K\)-mais-\(V\) e ainda estaria lá se o modelo usasse fp32. O segundo \(2\) é o número de bytes por elemento do armazenamento em meia precisão e seria \(4\) se o cache estivesse em fp32 ou \(1\) se fosse quantizado para int8 [src_020]. Confundir esses dois fatores é um erro comum de contabilidade em escrita informal; vamos mantê-los separados ao longo deste capítulo.

A fórmula tem três consequências qualitativas imediatas. Primeira, o cache cresce linearmente no comprimento da sequência \(T\), então dobrar a janela de contexto dobra o cache. Segunda, ele cresce linearmente no número de heads de KV \(h_{kv}\), então qualquer mudança arquitetônica que reduza \(h_{kv}\) encolhe o cache exatamente pelo mesmo fator. Terceira, ele cresce linearmente na contagem de layers \(L\), então a pegada de KV-cache de um modelo profundo e estreito pode ser maior que a de um modelo raso e largo mesmo quando suas contagens de parâmetros são equivalentes [src_020].

3. Um cálculo concreto de ordem de grandeza: Llama-3 70B em contexto 128k

Para converter essas leis de escala em um número que um engenheiro pode planejar, percorremos um único exemplo trabalhado. Tome o Llama-3 70B, que tem \(L = 80\) layers de decoder, \(h = 64\) heads de consulta, \(h_{kv} = 8\) heads de KV (o modelo usa GQA com oito grupos; retornaremos à escolha em §6), e dimensão por head \(d_h = 128\) [src_010, src_047]. Suponha que servimos uma única requisição, então \(B = 1\), em um comprimento de contexto \(T = 128\text{k} = 128 \cdot 1024 = 131{,}072\) tokens, com o cache mantido em fp16 então \(\text{bytes\_per\_elem} = 2\).

Substituindo na fórmula de §2:

\[ \begin{aligned} \text{tamanho do KV-cache} &= 2 \cdot L \cdot B \cdot T \cdot h_{kv} \cdot d_h \cdot \text{bytes\_per\_elem} \\ &= 2 \cdot 80 \cdot 1 \cdot 131{,}072 \cdot 8 \cdot 128 \cdot 2 \quad \text{bytes}. \end{aligned} \]

Multiplicando as constantes passo a passo:

  • \(2 \cdot 80 = 160\).
  • \(160 \cdot 131{,}072 = 20{,}971{,}520\).
  • \(20{,}971{,}520 \cdot 8 = 167{,}772{,}160\).
  • \(167{,}772{,}160 \cdot 128 = 21{,}474{,}836{,}480\).
  • \(21{,}474{,}836{,}480 \cdot 2 = 42{,}949{,}672{,}960\) bytes.

Um atalho por-unidade útil se esconde dentro desse total de 11 dígitos. O custo por (token, head de KV, elemento da dimensão por head) é \(2 \cdot 2 = 4\) bytes (a contagem de \(K\)+\(V\) vezes os bytes por elemento); multiplicado por \(d_h = 128\) isso dá \(512\) bytes por token por head de KV, então o resto da cadeia é apenas \(L \cdot B \cdot T \cdot h_{kv}\). Esse atalho mental — cerca de \(512\) bytes por (token, head de KV) com \(d_h = 128\), fp16 — é o que torna o argumento de economia de GQA em §6 aritmético, e não retórico.

🤔 Pause e pense

Antes de seguir, preveja — com \(B = 8\) requisições concorrentes em vez de \(B = 1\), qual KV-cache por servidor a configuração Llama-3 70B em \(128\)k requer, e cabe nos \(640\) GB de um servidor de oito H100s uma vez que os pesos do modelo também são contabilizados? (Não olhe adiante — escreva a resposta ou diga em voz alta.)

Isso são \(42{,}949{,}672{,}960\) bytes por requisição. Dividindo por \(2^{30} = 1{,}073{,}741{,}824\) para converter para gibibytes,

\[ \text{tamanho do KV-cache} \;=\; \frac{42{,}949{,}672{,}960}{2^{30}} \;=\; 40 \text{ GiB}, \]

ou, equivalentemente, cerca de \(43\) GB em unidades decimais (dividindo por \(10^9\)). A nota de síntese para esta parte registra o mesmo cálculo no nível arredondado de \(\approx 41\) GB [src_007]; dependendo de se usamos gibibytes ou gigabytes, o número arredonda para \(40\) ou \(43\), e escreveremos o resultado arredondado como aproximadamente \(40\)\(43\) GB por sequência para sermos honestos sobre a ambiguidade de unidade.

🎯 Intuição

GiB (\(2^{30}\) bytes) e GB (\(10^9\) bytes) diferem por cerca de \(7{,}4\%\). Datasheets de HBM tipicamente citam GB; ferramentas de SO e muitos frameworks de ML reportam em GiB. O mesmo KV-cache é honestamente \(40\) GiB ou \(43\) GB; a decisão de engenharia não depende de qual unidade você escreve — apenas de lembrar qual delas você está lendo.

Duas observações tornam esse número acionável. Primeira, \(40\) GB é comparável à memória total de pesos em fp16 de um modelo da classe \(20\)B e excede em metade os \(80\) GB de HBM de um único H100 quando os próprios pesos do modelo (que para um modelo \(70\)B em fp16 são \(\approx 140\) GB e já precisam ser fragmentados entre múltiplas GPUs) são contabilizados [src_007, src_047]. Servir uma única requisição \(128\)k, portanto, consome uma fração substancial da memória do cluster, e servir muitas requisições concorrentes esgotaria a HBM muito antes de a computação saturar. Segunda, o tamanho escala com \(h_{kv}\) de uma forma que torna a escolha arquitetônica de grouped-query attention financeiramente determinante: se o mesmo modelo tivesse retido a multi-head attention original com \(h_{kv} = h = 64\), o cache teria sido \(64/8 = 8\times\) maior, ou aproximadamente \(320\)\(340\) GB por sequência, o que não é implantável em nenhum acelerador atual de servidor único [src_020, src_010].

💡 Resultado-chave

Uma única requisição de \(128\)k de contexto no Llama-3 70B consome aproximadamente metade da HBM de uma H100 apenas em KV-cache; o mesmo modelo em MHA precisaria de oito vezes isso e não caberia em nenhum acelerador de servidor único.

4. Recapitulação de Multi-Head Attention

A multi-head attention padrão (MHA), como introduzida no artigo do Transformer de 2017 e recapitulada no Capítulo 1, dá a cada uma das \(h\) heads de consulta suas próprias projeções \(K\) e \(V\) aprendidas independentemente [src_020]. Concretamente, com entrada \(X \in \mathbb{R}^{T \times D}\) e dimensão por head \(d_h = D/h\), MHA computa para cada head \(i \in \{1, \ldots, h\}\)

\[ Q^{(i)} = X W_Q^{(i)}, \qquad K^{(i)} = X W_K^{(i)}, \qquad V^{(i)} = X W_V^{(i)}, \]

com \(W_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{D \times d_h}\) todos distintos entre heads.

🎯 Intuição

Imagine cada head de consulta carregando seu próprio "dicionário de consulta" privado — uma matriz \(K\) única que mapeia tokens em chaves de recuperação, e uma matriz \(V\) única que mapeia tokens em conteúdo recuperado. Com \(h\) heads, o modelo mantém \(h\) dicionários independentes lado a lado. A projeção de saída \(W_O\) (implícita aqui, mas aparecendo no resumo de fechamento de §12) cola as saídas por head novamente em uma única atualização do residual stream. A pergunta de §5–§6 é se os dicionários em si podem ser compartilhados entre heads, ou apenas as leituras sobre eles.

O número de escalares de KV em cache por layer por token é então \(h \cdot d_h \cdot 2\) (para \(K\) e \(V\)), então \(h_{kv} = h\) em MHA e o KV-cache escala como \(h\) [src_020, src_002].

🔗 Conexão

A multi-head attention com \(h\) projeções de KV independentes, incluindo a estrutura de softmax escalado e a projeção de saída \(W_O\), é definida no Capítulo 1; este capítulo toma a descrição de §1 como ponto de partida e varia \(h_{kv}\).

Para a configuração do Llama-3 70B de §3, MHA significaria \(h_{kv} = 64\) em vez de \(h_{kv} = 8\). Como a dimensão por head e todos os outros fatores na fórmula do cache são inalterados, o cache é exatamente \(h / h_{kv} = 64/8 = 8\times\) maior sob MHA do que sob o design GQA-8 que o modelo realmente entrega [src_020]. Esse salto de fator oito em um único parâmetro arquitetônico é o que tornou grouped-query attention uma vitória óbvia em escala.

5. Multi-Query Attention (MQA): o extremo

A primeira tentativa publicada de quebrar a escala linear do KV-cache com \(h\) foi multi-query attention (MQA), proposta por Shazeer em 2019 [src_020]. MQA colapsa todas as \(h\) heads de KV em um único par compartilhado: cada head de consulta ainda tem sua própria projeção aprendida \(W_Q^{(i)}\), mas as chaves e valores são produzidos por um único \(W_K \in \mathbb{R}^{D \times d_h}\) e um único \(W_V \in \mathbb{R}^{D \times d_h}\) que são compartilhados entre todas as heads de consulta [src_020]. Na linguagem de §2, MQA define \(h_{kv} = 1\), então o KV-cache encolhe por um fator de \(h\) relativo a MHA [src_020].

O trade-off é qualidade e estabilidade. Ainslie et al. (2023) relatam que modelos treinados com MQA atingem velocidade de inferência próxima a um gargalo de head única — os benchmarks T5-XXL na Tabela 4 mostram tempo de inferência MQA em \(0{,}24\) s por amostra contra \(1{,}51\) s para MHA — mas a pontuação média correspondente nas tarefas cai de \(47{,}2\) em MHA para \(46{,}6\) em MQA, e os autores descrevem MQA como propensa a degradação de qualidade e instabilidade de treinamento, especialmente com entradas longas [src_020]. PaLM se comprometeu com MQA desde o pré-treinamento e absorveu o custo, mas T5 e a primeira família Llama mantiveram MHA. Em 2023, a questão era se algo entre \(h_{kv} = 1\) e \(h_{kv} = h\) poderia manter a maior parte da velocidade e a maior parte da qualidade.

⚠️ Armadilha

A queda de qualidade da MQA é pequena nas médias gerais, mas desigual por tarefa — avaliações de entrada longa e raciocínio few-shot mostram uma lacuna mais ampla do que sumarização. "GQA-8 é essencialmente qualidade MHA" é verdadeiro na mesma escala de ruído de sementes aleatórias; "MQA é essencialmente qualidade MHA" não é.

6. Grouped-Query Attention (GQA): a interpolação que venceu

Grouped-query attention responde a essa pergunta introduzindo um fator de grupo. Com um fator de grupo \(g\), as \(h\) heads de consulta são particionadas em \(h_{kv} = h/g\) grupos, e cada grupo compartilha uma head \(K\) e uma head \(V\) [src_020]. Os dois extremos \(g = 1\) e \(g = h\) recuperam os casos especiais familiares: \(g = 1\)\(h_{kv} = h\), que é MHA, e \(g = h\)\(h_{kv} = 1\), que é MQA [src_020]. O regime interessante é o interior, onde \(g\) é pequeno (então o cache é dramaticamente menor que MHA) mas maior que um (então as heads de consulta não compartilham todas o mesmo subespaço de chave/valor).

Ainslie et al. estudam isso diretamente. Partindo de um checkpoint T5-XXL totalmente treinado em MHA, eles colapsam as heads \(K\) e \(V\) de cada grupo em uma única head compartilhada por meio de média das projeções por head, e então continuam o pré-treinamento por mais 5% da computação de treinamento original para reparar a perturbação [src_020]. Esse procedimento de uptraining recupera a maior parte da qualidade original a uma pequena fração do custo de treinamento original: a variante GQA-8 do T5-XXL alcança média de \(47{,}1\) em benchmarks de sumarização e perguntas e respostas, comparado a \(47{,}2\) para o modelo MHA original e \(46{,}6\) para a variante MQA, com tempo de inferência por amostra de \(0{,}28\) s contra \(1{,}51\) s para MHA — então GQA-8 retém qualidade essencialmente em nível MHA com aproximadamente cinco vezes a vazão [src_020].

As implicações econômicas desses números são o que tornou GQA o padrão para as famílias Llama-2 70B, Llama-3 70B, Qwen2 e DeepSeek-V2 [src_010, src_047, src_002]. Em escala 70B a contagem de heads de consulta é grande — Llama-3 70B usa \(h = 64\), como notamos em §3 — e as economias de KV-cache ao ir de MHA para GQA-8 são exatamente o cálculo de fator-oito que percorremos. Ao mesmo tempo, a lacuna de qualidade para MHA está dentro do ruído das sementes aleatórias de pré-treinamento, e a lacuna para MQA é ampla o suficiente para importar em avaliações de contexto longo [src_020]. A combinação de "qualidade quase MHA" e "tamanho de cache quase MQA" é o que GQA compra, e isso explica por que toda versão de fronteira de pesos abertos após 2023 entrega alguma variante de grouped-query attention [src_010, src_002].

🔄 Recapitulação

  • Complete a fórmula: em MQA o tamanho do KV-cache é \(2 \cdot L \cdot B \cdot T \cdot 1 \cdot d_h \cdot \text{bytes\_per\_elem}\) — escreva as expressões análogas para MHA com \(h_{kv} = h\) e para GQA com \(h_{kv} = h/g\).
  • Explique com suas próprias palavras por que a lacuna de qualidade MHA→GQA-8 no T5-XXL está dentro do ruído de sementes mas a lacuna MHA→MQA não está — o que GQA-8 mantém que MQA descarta?
  • Preveja: para um modelo hipotético da classe Llama com \(L = 60\), \(h = 48\), \(h_{kv} = 6\), \(d_h = 128\), fp16, qual é o KV-cache por requisição em \(T = 64\)k, e como o argumento de economia de fator oito de §3 se generaliza?

7. Computação versus largura de banda de memória: por que attention padrão é limitada por HBM

A Seção 6 controla o tamanho do KV-cache; FlashAttention controla o custo de usá-lo. Antes de descrever o próprio FlashAttention, precisamos ser claros sobre qual gargalo ele alivia. A descrição de 2017 da attention como "\(\mathcal{O}(T^2 \, d_h)\) em computação e memória" trata a contagem de FLOPs e a pegada de memória como os recursos relevantes, mas em uma GPU moderna a restrição vinculante em comprimentos de sequência longos não é nenhum dos dois — é a largura de banda da DRAM externa que conecta os streaming multiprocessors (SMs — as unidades de computação paralela da GPU, análogas aos núcleos de uma CPU) à high-bandwidth memory (HBM) no mesmo pacote [src_021]. Em uma GPU moderna, a SRAM em chip (scratchpad por SM, escala de MB, largura de banda efetiva de ~10 TB/s) fica um nível acima da HBM (fora do chip, escala de GB, ~1,5–2 TB/s) — uma diferença de uma ordem de grandeza em largura de banda, como a diferença entre cache L2 e RAM principal em uma CPU.

A razão é estrutural. Em uma A100, a vazão de pico de matmul FP16 é aproximadamente \(312\) TFLOP/s, mas a largura de banda HBM é aproximadamente \(1{,}5\)\(2\) TB/s; uma única multiplicação de matrizes move bytes para dentro e fora da HBM na taxa limitada pela largura de banda enquanto consome apenas uma pequena fração das unidades de computação, e essa assimetria é ainda mais acentuada na H100 [src_021, src_023].

A implementação padrão de scaled dot-product attention \(\text{softmax}(Q K^{\top}/\sqrt{d_h}) V\) aloca uma matriz explícita de attention \(T \times T\) na HBM, materializa-a uma vez durante o softmax e a lê novamente durante a multiplicação com \(V\) — três passagens completas sobre um tensor de tamanho \(T^2\) que a computação não precisa, mas que a implementação insiste [src_021]. Em \(T = 8\text{k}\), essa matriz intermediária é \(64\text{M}\) entradas por head por elemento de batch; em \(T = 128\text{k}\) são \(16\text{B}\) entradas, e o tempo gasto movendo-a através da fronteira HBM-SRAM domina o tempo de execução [src_021].

É por isso que um kernel de attention que é limitado pela largura de banda de memória, não limitado pela computação, é o objeto certo a otimizar: cortar o número de viagens de ida e volta na HBM se traduz diretamente em acelerações de tempo real, independentemente de a contagem de FLOPs ser reduzida [src_021].

🎯 Intuição

Tiling é a resposta natural a um limite de largura de banda. Se um valor é lido uma vez da memória lenta, usado muitas vezes e descartado, o custo é limitado pela largura de banda; se em vez disso o valor é lido uma vez para a memória rápida e reutilizado lá antes de ser descartado, o custo de largura de banda se amortiza entre os usos. FlashAttention aplica exatamente esse padrão à matriz de attention \(T \times T\) — manter os tiles de \(Q\), \(K\), \(V\) residentes na SRAM tempo suficiente para que o trabalho por byte aumente antes de serem despejados.

FlashAttention é precisamente esse tipo de kernel.

8. FlashAttention v1 (Dao et al., 2022): tiling consciente de IO

FlashAttention reorganiza exatamente a mesma computação em torno da hierarquia de memória da GPU [src_021]. A observação chave é que o softmax é a única operação na cadeia que requer todos os \(T\) produtos internos chave-consulta simultaneamente, e o softmax tem uma forma de streaming numericamente estável: em vez de coletar a linha completa de logits, normalizar uma vez e exponenciar, pode-se manter um máximo corrente \(m\) e um denominador corrente \(\ell\) à medida que os logits chegam em blocos, atualizando ambos sempre que um novo bloco é visto [src_021].

🤔 Pause e pense

Pare aqui. Se você subtrair o máximo corrente \(m\) de cada novo logit antes de exponenciar, o exponencial fica limitado — mas a resposta depende de como o próximo bloco de logits se parece? Preveja se (e por quê) o softmax é invariante sob a subtração de qualquer constante de todos os logits, depois confira a próxima frase. (Não olhe adiante — escreva a resposta ou diga em voz alta.)

Esse é o truque do online softmax, originalmente devido a Milakov e Gimelshein em 2018; FlashAttention o adapta à attention para que a matriz \(T \times T\) nunca precise ser montada em um único lugar [src_021].

Concretamente, o kernel divide \(Q\) em blocos de linhas e \(K, V\) em blocos de colunas. Para cada bloco \(Q\), ele itera sobre os blocos \(K, V\), computa os produtos internos parciais \(Q_{\text{block}} K_{\text{block}}^{\top}\) dentro da SRAM em chip, executa a atualização de softmax em streaming nessa fatia parcial e acumula a contribuição correspondente \(\text{softmax-weight} \cdot V_{\text{block}}\) em um buffer de saída que vive na SRAM até o bloco de linhas estar terminado [src_021].

A matriz \(T \times T\) nunca é materializada na HBM, o único tráfego HBM é ler \(Q\), \(K\), \(V\) uma vez e escrever a saída uma vez, e a saída de attention que sai desse loop é bit a bit igual ao que a implementação do livro-texto teria produzido — FlashAttention é exato, não uma aproximação [src_021]. Para a passagem de retropropagação, o kernel re-executa o tiling da passagem direta em vez de persistir a matriz gigante de attention, trocando uma pequena quantidade de recomputação por uma grande redução na pressão de memória [src_021].

O ganho empírico é o número principal do artigo de 2022: aproximadamente um ganho de velocidade ponta a ponta de \(3\times\) no treinamento do GPT-2 e \(2{,}4\times\) no benchmark long-range arena, com uma redução de \(10\)\(20\times\) na memória usada pela layer de attention — grande o suficiente para tornar o Path-X (uma tarefa de classificação de \(16\)k tokens do benchmark Long-Range Arena) e até o Path-256 (\(64\)k tokens, também do Long-Range Arena) treináveis pela primeira vez [src_021]. Não derivamos a recorrência de softmax em streaming linha por linha neste capítulo; o artigo original e o CS336 Lecture 5 carregam essa derivação em detalhe, e os leitores que quiserem a prova de correção devem consultar ambos [src_021, src_004].

9. FlashAttention-2 (Dao, 2023): particionamento de trabalho no nível de warp

O FlashAttention v1 era consciente da largura de banda, mas ainda não saturava o hardware. Na A100, alcançava apenas \(25\)\(40\)% dos FLOPs FP16 de pico teórico, porque o particionamento de trabalho do kernel deixava alguns recursos de hardware ociosos [src_022]. FlashAttention-2, publicado em 2023, é uma reorganização de engenharia que fecha a maior parte dessa lacuna [src_022].

Três mudanças carregam a aceleração. Primeira, o kernel reduz o número de operações de ponto flutuante não-matmul: a implementação original re-escalava o buffer de saída corrente a cada atualização de bloco, mas a maioria dessas re-escalagens pode ser adiada até o bloco de linhas estar completo, substituindo muitas pequenas chamadas de \(\exp\) e divisão por uma correção final [src_022]. Isso importa porque FLOPs não-matmul não são manipulados pelos Tensor Cores (unidades especializadas de matmul que coexistem com os CUDA cores de propósito geral; em H100/A100 quase toda a vazão de pico em FP16/FP8 vive nos Tensor Cores, não nos CUDA cores) e, portanto, consomem uma parcela desproporcional de ciclos em um chip cuja vazão de pico vive em matmul. Segunda, ela paraleliza a passagem direta ao longo da dimensão de sequência, bem como das dimensões de batch e head, para que cargas de trabalho de sequência longa com batch único (que são típicas da decodificação) mantenham todos os SMs ocupados em vez de deixá-los ociosos esperando o próximo elemento de batch [src_022]. Terceira, ela reparticiona o trabalho dentro de um bloco de threads para que os quatro warps de um bloco compartilhem chaves e valores via memória compartilhada em vez de cada um buscá-los da HBM, eliminando cargas redundantes e reduzindo a sincronização intra-bloco [src_022].

O efeito combinado na A100 é uma aceleração de \(2\times\) sobre FlashAttention v1, alcançando aproximadamente \(225\) TFLOP/s em FP16 e uma utilização de FLOPs do modelo de cerca de \(72\)% — a maior vazão sustentada de attention reportada para a geração A100 [src_022]. A semântica do kernel é inalterada: FlashAttention-2 ainda computa attention softmax exata, apenas mais rápido.

10. FlashAttention-3 (Shah et al., 2024): assíncrono específico do Hopper e FP8

A geração H100 introduziu novos caminhos de hardware que o FlashAttention-2 não explorava. O Tensor Memory Accelerator (TMA) suporta cargas em massa assíncronas da HBM para a memória compartilhada; a instrução WGMMA emite matmul em nível de warpgroup (quatro warps cooperando como uma única unidade de emissão no Hopper) nos Tensor Cores; e o H100 suporta FP8 com aceleração de hardware. No H100, FlashAttention-2 atingiu apenas cerca de \(35\)% do pico — uma regressão na utilização apesar da maior vazão absoluta, porque o kernel não era consciente de assincronia e usava FP16 em todo lugar [src_023].

FlashAttention-3 redesenha o kernel em torno desses recursos do Hopper. A passagem direta usa especialização de warps produtor/consumidor: warps produtores dedicados emitem cargas TMA de tiles \(K\) e \(V\) enquanto warps consumidores executam o matmul WGMMA em tiles já na SRAM, então o movimento de dados e a computação se sobrepõem em vez de serializar [src_023].

O softmax é intercalado com a emissão assíncrona de matmul para que o trabalho de softmax não-matmul se esconda atrás da latência do GEMM (general matrix multiply — o kernel denso de matmul emitido nos Tensor Cores) em vez de bloqueá-la [src_023]. Finalmente, o kernel adiciona um caminho de baixa precisão: \(Q\), \(K\), \(V\) são quantizados para FP8 com quantização por bloco (um fator de escala por tile) e processamento incoerente (uma projeção aleatorizada estilo Hadamard aplicada antes da quantização para achatar outliers), que juntos reduzem o erro de quantização FP8 da attention por um fator de \(2{,}6\times\) relativo à linha de base quantizada por tensor [src_023].

🎯 Intuição

Uma projeção Hadamard ortogonal rotaciona o vetor de ativação de modo que qualquer coordenada "outlier" isolada seja espalhada por muitas coordenadas do vetor rotacionado. Após a rotação, nenhuma entrada isolada domina; os fatores de escala por tile então abrangem uma faixa mais estreita, e a quantização FP8 perde menos. Crucialmente, a rotação é invertível — o kernel a desfaz após o matmul, então a matemática é exata até o arredondamento de FP8.

O resultado combinado é uma aceleração de \(1{,}5\)\(2\times\) sobre FlashAttention-2 no H100, alcançando aproximadamente \(740\) TFLOP/s em FP16 (\(\approx 75\)% MFU) e aproximadamente \(1{,}2\) PFLOP/s em FP8 [src_023]. Como na transição de v1 para v2, o kernel ainda computa attention softmax exata; o caminho FP8 adicionalmente certifica um pequeno erro numérico controlado que mantém a precisão de treinamento e inferência ponta a ponta dentro do ruído da referência FP16 [src_023].

💡 Resultado-chave

FlashAttention-3 atinge cerca de \(75\%\) de MFU na H100 em FP16 e aproximadamente \(1{,}2\) PFLOP/s em FP8 — fechando a maior parte da lacuna de utilização que a v2 deixou, e certificando um erro numérico FP8 controlado dentro do mesmo envelope de softmax exato.

🔄 Recapitulação

  • Explique por que a scaled dot-product attention padrão é limitada por HBM em \(T\) longo — o que o kernel lê da HBM, e quantas viagens de ida e volta a implementação do livro-texto custa?
  • Compare FlashAttention v1 e FlashAttention-2 na A100 no mesmo tamanho de problema: o que a v2 muda sobre o particionamento de trabalho do kernel, e qual dessas três mudanças mais importa para decodificação de sequência longa com batch único?
  • Preveja: dos três caminhos específicos do Hopper que a v3 explora (cargas assíncronas baseadas em TMA, matmul de warpgroup WGMMA, FP8), qual mais importaria para uma carga de trabalho dominada por decodificação de sequência longa com batch único, em vez de passagens diretas de sequência longa em tempo de treinamento?

11. Um teaser para speculative decoding

Encolher o KV-cache e FlashAttention juntos são as vitórias determinantes da attention eficiente; eles são o que torna a inferência de contexto longo e modelo grande economicamente viável em 2026. Eles não são os únicos truques de tempo de inferência que vale conhecer. O speculative decoding — esboçar uma curta continuação com um modelo mais barato e verificá-la em paralelo com o modelo alvo — ganha mais \(2\)\(3\times\) sobre GQA + FlashAttention amortizando o custo de decodificação por passo entre múltiplos tokens aceitos [src_007]. Adiamos o algoritmo para o Capítulo 8, onde ele se senta junto com o resto da pilha moderna de inferência decoder-only, e para o Apêndice B, que percorre a implementação de referência do gpt-fast [src_010]. O ponto para o presente capítulo é que speculative decoding multiplica os ganhos de GQA e FlashAttention; ele não os substitui.

🔗 Conexão

O loop draft/target do speculative decoding é desenvolvido no Capítulo 8 junto com o restante da pilha moderna de inferência decoder-only, e o Apêndice B percorre a implementação de referência do gpt-fast.

12. Resumo de fechamento

A história da attention eficiente na janela 2022–2026 é que dois gargalos arquitetônicos — o KV-cache, que é limitado pela capacidade da HBM, e a materialização da matriz de attention, que é limitada pela largura de banda da HBM — foram atacados separadamente e resolvidos em paralelo.

Um resumo compacto das implicações do KV-cache, mantendo \(L\), \(h\), \(d_h\), \(T\), \(B\) e a precisão de armazenamento fixados na configuração Llama-3 70B / 128k / fp16 / sequência única de §3:

Variante \(h_{kv}\) Razão de KV-cache vs MHA Cache aproximado para Llama-3 70B em 128k
MHA \(h = 64\) \(1\) \(\approx 320\)\(340\) GB
GQA-8 (padrão Llama-3) \(8\) \(1/8\) \(\approx 40\)\(43\) GB
MQA \(1\) \(1/64\) \(\approx 5\) GB

(Os números são derivados da fórmula de §2 mais o exemplo trabalhado de §3; trade-offs de qualidade de [src_020].)

Um resumo paralelo das implicações de tempo de execução em uma única A100 / H100, mantendo a layer de attention e o comprimento de sequência fixos e variando apenas o kernel:

Kernel Hardware-alvo Vazão aproximada de attention Aceleração vs ingênua
Softmax-attention ingênua A100 limitada pela largura de banda, tráfego HBM \(T^2\) \(1\times\) (referência)
FlashAttention v1 A100 exata, em tiles, sem \(T \times T\) na HBM \(\sim 3\times\) no GPT-2 [src_021]
FlashAttention-2 A100 \(\sim 225\) TFLOP/s, \(\sim 72\)% MFU \(\sim 2\times\) sobre v1 [src_022]
FlashAttention-3 (FP16) H100 \(\sim 740\) TFLOP/s, \(\sim 75\)% MFU \(\sim 1{,}5\)\(2\times\) sobre v2 [src_023]
FlashAttention-3 (FP8) H100 \(\sim 1{,}2\) PFLOP/s mais \(\sim 2\times\) a erro controlado [src_023]

Ambas as tabelas colapsam na mesma alegação de engenharia. Em 2026, o custo de attention em tempo de decodificação em um modelo de pesos abertos de fronteira não é definido nem pela contagem de FLOPs \(\mathcal{O}(T^2 \, d_h)\) do livro-texto nem pelo cache \(\mathcal{O}(h \cdot d_h)\)-por-token do livro-texto — ambas essas quantidades foram reduzidas por um pequeno fator constante inteiro através de engenharia arquitetônica e em nível de kernel, enquanto a especificação matemática subjacente da attention softmax permanece intocada. O Capítulo 8 retoma o fio onde este capítulo o deixa, construindo o moderno modelo de linguagem decoder-only a partir dos componentes desenvolvidos ao longo dos Capítulos 1–4: invólucros pre-RMSNorm, self-attention aumentada por RoPE, projeções de KV no formato GQA, kernels no formato FlashAttention e FFNs SwiGLU.

Referências