- NVIDIA Blackwellの非対称ハードウェアスケーリングに対応し、cuDNN比1.3倍・Triton比2.7倍の高速化を実現
- 完全非同期MMA操作とsoftmax再スケーリングの最適化でBF16精度で1613 TFLOPs/sを達成
- CuTe-DSLによるPython実装でコンパイル時間を従来比20〜30倍短縮し、開発効率も向上
研究の背景:非対称スケーリングという壁
大規模言語モデル(LLM)の推論・学習において、Attention(注意機構)はボトルネックの一つです。入力トークン数の2乗に比例して計算量が増えるため、長文処理では特に顕著なコストになります。
FlashAttention-2、FlashAttention-3と進化してきたこのシリーズは、メモリ効率と演算効率を両立する手法として広く普及しています。しかし、NVIDIAが2024年末から展開しているBlackwellアーキテクチャ(B200、GB200)は、従来世代とは異なる特性を持ちます。
Blackwellではテンソルコアの演算性能が前世代比2倍に強化されています。ところが共有メモリ(SRAM)の帯域幅や指数関数演算ユニットの強化は相対的に緩やかであり、この非対称な性能バランスが新たな最適化の課題を生み出しています。FlashAttention-4は、この「非対称ハードウェアスケーリング」に正面から向き合うことをコンセプトに設計されました。
提案手法:3つの協調最適化
本研究はアルゴリズムとカーネルパイプラインを一体で設計する「協調設計(co-design)」という方針を採り、3つの主要な最適化を組み合わせています。

1つ目は完全非同期MMA操作とタイルサイズの拡大です。行列乗算演算(MMA)を完全に非同期化し、処理パイプラインを再設計することで、テンソルコアの演算が他の処理に待たされる時間を最小化します。またBlackwellの大きなレジスタファイルを活かしてタイルサイズを拡大し、1回の演算で処理できるデータ量を増やしています。
2つ目は指数関数演算とsoftmaxの再スケーリング最適化です。Attentionの核心にあるsoftmax計算には多くの指数関数演算が含まれますが、テンソルコア以外で処理されるこれらの「非matmul演算」がボトルネックになりやすいです。FlashAttention-4では条件付きsoftmax再スケーリングのアルゴリズムを改良し、指数関数の呼び出し回数を削減することでこの非対称性に対処しています。
3つ目は逆伝播(バックワードパス)の効率化です。テンソルメモリと2-CTA(Cooperative Thread Array)MMAモードを活用して共有メモリのトラフィックを削減し、アトミック演算のオーバーヘッドも抑えています。学習時に必要な勾配計算の効率が向上することで、ファインチューニングや事前学習での恩恵も大きくなります。
実験結果:BF16で1613 TFLOPs/sを達成
B200 GPUでの評価において、FlashAttention-4はBF16精度で最大1613 TFLOPs/sを記録しました。これはFlashAttention-4が到達できる理論性能の71%に相当し、現実的なワークロードで非常に高い演算効率を示しています。

既存の最適化ライブラリとの比較では、NVIDIAが提供するcuDNN 9.13に対して最大1.3倍の高速化、PythonベースのGPUカーネル記述言語であるTritonに対しては最大2.7倍の高速化を実現しています。
同様の課題に取り組む研究として、強化学習でGPUカーネルを自動最適化するCUDA Agentも注目されていますが、FlashAttention-4はAttentionという特定演算に特化した手動設計により、より高い性能を引き出しています。
CuTe-DSLによる開発効率の革新
FlashAttention-4のもう一つの貢献が、実装方法の刷新です。従来のFlashAttentionシリーズはC++テンプレートを駆使した実装でしたが、FlashAttention-4ではCuTe-DSL(Domain-Specific Language)をPythonに組み込む形で実装しています。
CuTe-DSLはNVIDIAが開発したテンソル操作の記述言語で、低レベルのGPUカーネル最適化をより抽象的な記述で行えます。Python上での実装に切り替えることで、コンパイル時間が従来比で20〜30倍短縮されました。C++テンプレートのビルドには数十分かかるケースもありましたが、Python実装ではその障壁が大幅に低くなります。
完全な表現力(expressiveness)は維持しつつ開発サイクルが高速化されるため、研究者や実装者がアルゴリズムの改良を素早く試せる環境が整います。
まとめと今後の展望
FlashAttention-4はFlashAttention-2、3の系譜を継ぎながら、Blackwellという新世代ハードウェアに対してアルゴリズムとカーネルを一体的に再設計した論文です。性能面での数値改善だけでなく、CuTe-DSLを用いたPython実装によって開発効率も大きく向上させた点は、次世代のGPUカーネル研究の方向性を示しています。
NVIDIAのBlackwell GPUはデータセンターへの実導入が進みつつある段階であり、LLMの学習・推論インフラにFlashAttention-4が組み込まれれば、スループットの改善が直接コスト削減や長文処理能力の向上につながります。一方で、本手法はBlackwellの特性に特化した設計であるため、将来の新アーキテクチャへの対応は再設計が必要になる可能性があります。AlgorithmとHardwareの協調設計というアプローチが、今後の効率化研究の主流になっていくかどうかが注目点です。

