- GQAグループ単位の動的ブロック選択により、1Mトークン時の注意計算量を28.4倍削減しH800 GPUでプリフィル14.2倍・デコード7.6倍を実測
- KL損失・勾配デタッチ・Indexer Warmupの3段階安定化で疎Attentionへの移行時の品質劣化を防ぎ、25以上のベンチマークで標準GQAと同等精度を維持
- 109Bマルチモーダルモデル「MiniMax-M3」として実サービスに展開済み、専用GPUカーネルをGitHubで公開し再現性が高い
研究の背景
LLMが長い文脈を処理できるほど、文書要約・長編コード解析・マルチターン会話など実用的なタスクの品質が向上します。しかし標準的なSoftmax Attention(注意機構)は、入力トークン数の2乗に比例して計算量が増える根本的な問題を抱えています。
100万トークン(1Mトークン)の文脈を扱う場合、1,000トークンの場合と比べて単純計算で100万倍の計算コストが生じます。これが長文脈LLMの推論を著しく遅くし、実用展開の最大のボトルネックとなっています。
疎Attention(Sparse Attention)はこの問題に対処する有力な手法ですが、「どのトークンに注意するか」を精度よく動的に選ぶ仕組みの設計が難しく、品質の低下や実装の複雑さが課題でした。MiniMax社が開発したMiniMax Sparse Attention(MSA)は、シンプルさとスケーラビリティを設計の中心に据えることでこの課題を解決した手法です。
MSAの2ブランチ構成

MSAはGQA(Grouped Query Attention、複数のQueryヘッドがKey/Valueを共有する方式)を前提とした、Index BranchとMain Branchの2ブランチ設計です。
Index Branchは「どのブロックを見るか」を決める軽量なスカウト役です。GQAのグループごとに1つのQueryヘッドを用意し、全コンテキストをブロック単位でスキャンします。各ブロック内のトークンスコアをmax-poolingで集約することでブロック全体の重要度を計算し、上位kブロックを選択します。直近のトークンが入るローカルブロックは、スコアに関わらず常に選択に含まれます。
Main Branchは選択されたk×Bkトークン(デフォルトでは2,048トークン)だけを対象に通常のSoftmax Attentionを実行します。GQAグループ内の全Queryヘッドはブロック選択を共有しつつ、独立したQueryプロジェクションを維持するため、細粒度の表現能力は損なわれません。
計算量の観点では、GQAがトークン数Nの2乗に比例してFLOPが増えるのに対し、MSAのMain Branchは選択ブロック数kに対して線形にしか増えません。これが28.4倍というFLOP削減の根拠です。
学習を安定させる3つの仕組み
疎Attentionを素朴に学習させると、Index BranchとMain Branchが乖離したり、補助損失がモデル全体の表現能力を損なったりする問題が起きやすくなります。MSAは3段階の仕組みでこれを防いでいます。
まずKLアライメント損失です。Index BranchのAttention分布を、Main Branchが選択ブロック上で計算した分布に近づけるようKL距離を最小化します。LM損失だけでIndexerを学習させた場合よりも、この直接的な監督信号の方が精度の高いブロック選択を学べることが実験で確認されています。
次に勾配デタッチです。KL損失の勾配をIndex Branchへの入力でStop-gradientにより遮断し、バックボーン(主要モデル)まで伝搬させません。これにより補助損失はIndex Branchの学習だけに効き、バックボーンの汎用能力が劣化しません。デタッチなしで学習すると勾配スパイクが発生し、一般タスクのベンチマーク性能が下がることも確認されています。
最後にIndexer Warmupです。学習の初期段階は両ブランチとも完全Attentionで動かし、Index BranchがMain Branchの分布を十分に学んでから疎Attentionに切り替えます。学習初期にAttentionエントロピーが急落する現象が観察されており、Warmup期間を設けることで長文脈タスクと一般タスクの両方で性能が安定します。
GPUカーネルの3つの最適化
理論的なFLOP削減を実際の速度向上に結びつけるために、MSAは専用GPUカーネルを開発しています。
Exp-free TopKはSoftmaxの順序保存性を活用しています。Softmax変換の前後でスコアの大小順は変わらないため、exp計算なしに生スコアのままTop-kブロックを選択できます。Index Branchのオーバーヘッドを最小化する鍵となる最適化です。
KV外積イテレーションはKVブロックを外側のループとし、Queryをまとめて処理する方式です。従来のQ外積方式と比べて算術強度がおよそBk/3倍に向上し、メモリ帯域幅の効率的な活用を実現しています。
2フェーズフォワードパスは負荷不均衡への対処策です。Attention計算フェーズとCombineフェーズに分離し、部分的な出力と対数和スコアを中間バッファに保持します。クエリごとに選択ブロック数が異なる場合でも、GPU演算の停滞を最小限に抑えます。
実験結果

109Bパラメータ規模のMixture-of-ExpertsモデルでMSAを評価した結果、1Mトークンのコンテキストで注意計算量を28.4倍削減し、H800 GPU上でプリフィル処理14.2倍・デコード7.6倍の速度向上を実測しています。
精度面では、MMLU(言語理解)・数学推論・コード生成・マルチモーダルタスクを含む25以上のベンチマークで標準GQAと同等の結果を維持しています。長文脈特化ベンチマークのHELMETとRULERでも、2,048トークンという厳しいAttentionバジェット制約のもとで完全Attentionベースラインに近い性能を示しました。
また、各GQAグループが異なるストライプ状の選択パターンを学ぶことも可視化実験で確認されています。浅い層では長距離の多様な選択が行われ、深い層になるほど少数の重要領域に絞り込まれる傾向が見られました。ハイブリッドLLMのAttention研究でも議論されている「最初のトークンへの注意集中(Attention Sink)」現象もMSAで観察されており、明示的なSinkパラメータの追加も検討されましたが、デフォルト設計を超える改善は見られなかったと報告されています。
まとめ
MiniMax Sparse AttentionはGQAグループ単位のブロック動的選択という直感的な設計で、1Mトークン規模の長文脈Attentionを実用的な速度で処理する道を切り開きました。3つの学習安定化機構と専用GPUカーネルの組み合わせにより、理論値に近い実測速度向上を実現している点が際立ちます。
既に109Bモデル「MiniMax-M3」として実サービスに展開済みであり、専用GPUカーネルをオープンソースで公開していることから、他チームへの技術波及も期待できます。長文脈LLMの運用コストが障壁となっている現場にとって、MSAは具体的な選択肢の一つとなるでしょう。
