warp decodeによるMoEモデル推論の改善
並列化の軸を反転することで、MoEモデル推論を1.8倍高速化し、精度も向上させます。
ほとんどのMoE推論システムでは、トークン生成の経路をエキスパート中心に構成しています。これはルーティングの仕組みに沿ったものであり、大規模環境では標準的なアプローチでした。ところが、Blackwell GPUでの小バッチデコードでは、エキスパートではなく出力を中心にカーネルを構成するほうが適していることがわかりました。私たちはこのアプローチを「warp decode」と呼んでいます。
warp decodeにたどり着いたきっかけは、BlackwellにおけるMoEデコードで実際に達成可能な最大メモリ帯域幅はどの程度かを突き詰めて考えたことでした。そこから、並列化の軸を完全に反転させる発想に至りました。warpをエキスパートに割り当てるのではなく、各warpを単一の出力値 (ニューロン) に割り当てます。
パフォーマンスと精度の両方を改善できるカーネルはまれですが、warp decodeはその一例です。Blackwellでは、スループットを1.84倍向上させると同時に、出力が完全なFP32参照値に対して1.4倍近くなり、精度も改善します。これにより、Composer向けの研究および学習パイプラインが高速化され、モデルをより速く改善し、新しいバージョンをより頻繁にリリースできるようになります。
従来型のMoEパス
最新のMoEモデルでは、各トークンを specialized なエキスパートネットワークの一部に振り分け、たとえばあるレイヤーでは128個のうち8個を選択します。標準的な実装では、各エキスパートに必要なトークンを集めて演算を実行し、その後に結果を組み立て直す形で、すべての計算がエキスパート中心に構成されています。
これは、各エキスパートあたりで共有される処理によってデータ配置のオーバーヘッドをならせるプリフィルや大規模バッチではうまく機能します。しかし、1回に1トークンずつしか生成しない自己回帰デコードの段階では、それを正当化できるほど共有処理がありません。従来のパスの8段階のうち5段階は、エキスパート中心の処理に合わせてデータレイアウトを管理するためだけに存在しており、実際の計算は一切行っていません。
変更点
Warp decode は、並列化の単位をエキスパートではなく出力に合わせて再編成することで、5 つの「ブックキーピング」ステップをなくします。
最新の GPU は、warp と呼ばれる 32 レーンの並列実行単位で命令を実行します。新しいアプローチでは、各 warp がちょうど 1 つの出力値の計算を担当します。warp は必要な重みデータをメモリから直接ストリーミングし、ルーティングされた 8 つのエキスパートすべてにわたる合計を単一の累積値に集約して、1 つの結果を書き込みます。
このワープの独立性により、warp decode はステージング、ハンドオフ、warp 間の同期ポイント、中間バッファを一切必要とせずに実行できます。MoE の計算レイヤー全体は、moe_gate_up_3d_batched と moe_down_3d_batched の 2 つのカーネルに集約されています。
2つのカーネルの仕組み
gate/upカーネルでは、各cooperative thread array (CTA) は8つのwarpで構成され、各warpはトークンとルーティング先エキスパートの組み合わせごとに1つの中間ニューロンを受け持ちます。warpはルーティング先エキスパートのIDを読み込み、そのニューロンに対応するgateとupの重み行を読み取り、入力アクティベーションベクトル全体をストリーミングしながら処理します。MXFP8の重みはその場でFP32に変換され、2つのドット積はどちらもプライベートレジスタに累積されます。
2つのカーネルは1回のパスに融合されているため、アクティベーションベクトルは一度だけ読み込まれ、共有メモリにステージングすることなく、そのまま両方の射影にすぐ使い回されます。warpレベルのリダクション後、warpはSiLU(gate) × upを適用し、1つの中間値を書き出します。
downカーネルでは、各warpが1つのトークンに対する1つの出力次元を担当します。すべてのtop-kルーティングエキスパートを順に処理し、該当するdown projectionの重み行を読み込み、中間アクティベーションをストリーミングしながら、各エキスパートのルーティング重みを1つのFP32アキュムレータに織り込みます。
すべてのエキスパートの処理が終わると、__shfl_xor_syncを使ったwarpレベルのバタフライリダクションで、32個のlaneローカルな部分和をまとめます。これはPTXのshfl.sync.bfly命令に直接コンパイルされます。この命令は、warp内のlane間でレジスタをやり取りする単一のハードウェアプリミティブで、共有メモリを完全にバイパスします。
ここでの利点は、L1との往復アクセスやバンク競合、明示的なバリアが不要なことです。同期はlaneマスクを介して命令自体に組み込まれているためです。別個のエピローグを設ける代わりに、最終的な重み付きtop-k結合は射影そのものの一部になります。
warp decode内の各warpは独立しており、その生存期間を通して一貫した単一の役割、つまり1つの出力スカラーを生成する役割だけを担います。このワープの独立性によって、従来のパスで必要だった共有メモリへのステージング、warp間の同期、中間バッファが不要になります。
パイプラインの簡素化と高速化
Warp decode は、従来のパスで必要だったステージやバッファを削減することと、warp 間の独立性を生み出して、より適切なスケジューリングとレイテンシ隠蔽を可能にすることという、2 つの主要な仕組みによってパフォーマンスを向上させます。
ステージの排除
ステージの排除によって、スループット向上の大部分が得られます。padding、scattering、そして combine ステップを排除します。これらのステージをなくすには、従来のパイプラインのステージを単に統合するだけでなく、並列性を根本から組み替える必要があります。
パディングの排除
従来のパス: grouped kernel の要件に合わせるため、各エキスパート の token リストを 2 のべき乗、または 128 バイト境界にそろうようパディングします。デコード時に token が 1 つしかない場合、これは償却できないオーバーヘッドになります。
Warp decode パス: エキスパート ごとのバッチを一切作成しないため、このオーバーヘッドを完全に回避します。
scatter と combine の排除
従来のパス: 各エキスパート の処理が終わるたびに、8つの中間結果を GPU メモリに書き出し、その後、それらを結合するために別個のリダクションステップを実行します。
Warp decode パス: 各エキスパート のルーティング重みは、warp 内の実行中のアキュムレータに取り込まれます。8つの中間結果はメモリ上に生成されないため、後続のリダクションパスで発生する書き込みコストと読み取りコストの両方を削減できます。
バッファの排除
この再編成では、従来のパスでエキスパート中心のレイアウトに起因して必要だった2つの中間メモリバッファも不要になります。
1つ目は activation gather buffer で、入力アクティベーションベクトルをコピーしてエキスパート優先レイアウトに並べ替えたものです。バッチサイズが1の場合、これはすでに存在しているデータを丸ごとコピーすることになります。2つ目はエキスパートごとの出力バッファです。エキスパート数が8、隠れ次元が2048の場合、これは BF16 でトークンあたり 8 × 2048 × 2 bytes = 32 KB となり、確保され、書き込まれ、すぐに一度だけ読み出されたあと破棄されます。
Warp decode では、32本の warp レーンにまたがるレジスタアキュムレータに8つのエキスパートの寄与を畳み込むことで、この両方を排除します。このため、最後の単一スカラー書き込みまでは、グローバルメモリに何も到達しません。トークンあたり 32 KB 超の中間バッファトラフィックを削減することで、実際にパフォーマンスを左右する重み行のために L2 キャッシュ容量を空けられます。
ワープの独立性
この再編成により、残された計算も高速になります。これは、カーネルが設計上「embarrassingly parallel」だからです。つまり、各ワープはほかのすべてのワープから完全に独立しています。各ワープは出力スカラーをちょうど 1 つだけ担当し、必要な重み行だけを読み取るため、ワープ間に共有された可変状態はありません。
単一のワープのレベルでは、この独立性は完全です。入力アクティベーションは読み取り専用で、アキュムレータはプライベートレジスタ内に保持され、出力は一意のアドレスに書き込まれます。ハードウェアスケジューラの観点では、出力次元全体は独立した作業項目が並ぶフラットなプールです。
GPU のワープスケジューラは、正しさの制約を気にすることなく、任意のワープを任意のタイミングで任意の順序で発行できます。あるワープがメモリロード待ちで停止すると、スケジューラは即座に別のワープへ切り替えます。B200 の 148 基の streaming multiprocessors 全体で数千のワープが同時に実行されるため、メモリレイテンシはほぼ完全に、ほかのワープによる役に立つ計算の陰に隠れます。
このカーネルは線形にスケールするため、出力次元を 2 倍にすると、追加の同期なしで独立したワープ数も 2 倍になります。同じことはトークンのバッチ次元にも当てはまるため、スケジューラには、ノード間にエッジのない 1 つのフラットな作業空間として見えます。これは、エキスパート向けの GEMM カーネルでブロック内の協調が必要になる従来のパスとは対照的です。
結果
大規模環境におけるエンドツーエンドのデコードスループット
NVIDIA B200 GPU 上で Qwen-3 スタイルのモデルを実行する社内推論システムでテストしたところ、一貫したスループット向上が確認されました。スループット向上はすべてのコンテキスト長バケットでほぼ一定であり、これはプロンプト長に依存しない、純粋に生成時の改善であることを示しています。
精度向上
中間のアクティベーション量子化ステップをなくすことで、品質に定量的に確認できる影響が現れます。BF16アクティベーションをMXFP8に変換してから再び戻すと、丸め誤差の下限が生じ、それがモデルの各層を通じて蓄積されます。Warp decodeでは、アクティベーションは全工程でBF16、アキュムレータはFP32のまま維持されるため、リダクション処理が精度の落ちた入力に対して行われることはありません。その結果、warp decodeの出力は、従来の経路よりも完全な32ビット精度の基準値に対して1.4倍近くなります。
ハードウェア効率
warp decode の開発ではまず、ハードウェアの最大スループットにどこまで迫れるかを検証しました。B200 における連続メモリ読み出しの実測ピークは 6.8 TB/s です (copy kernel を使って測定) 。warp decode は B=32 で 3.95 TB/s を維持し、これはそのピークの 58% に相当します。残る差は、エキスパート ルーティング が生むランダムアクセスパターンによるメモリレイテンシのコストを反映している可能性が高いと考えられます。これは、各 token が 5、8、14、19 など、隣接していない エキスパート にルーティングされる場合があるためです。
一方、ピークスループットは連続した (0,1,2,3) メモリ読み出しで測定されます。reference implementation との一致度は、すべてのバッチサイズで非常に高い水準でした。最小コサイン類似度は > 0.999996、絶対誤差の最大値は 0.001953 でした。
warp decodeとComposerの学習
warp decodeは、エキスパート中心の実行を全面的に置き換えるものではありません。prefill や大規模バッチ推論のような高負荷のワークロードでは、多くのトークンが同じエキスパートを共有するため、依然としてエキスパート中心のパッキングが有効です。トークンを整理するコストも、十分な実計算量に対してなら償却できるため、見合うものになります。
warp decodeが威力を発揮するのは、そうしたオーバーヘッドを正当化できるほど、エキスパートごとの共有作業が十分にない場合で、これはMoEデコードでよくある状況です。そのため、これはComposerを継続的に改善していくうえで重要な要素となっています。事前学習データやRLへの投資がモデルの出力品質を左右する一方で、warp decodeのような推論への投資は、その出力がどれだけ速く、正確に開発者へ届くかを左右します。