コンテンツへスキップ

カスタム MXFP8 カーネルで MoE 学習を1.5倍高速化

作成者: Stuart Sul研究

Blackwell GPU向けの全面再構築により、MoEレイヤーを3.5倍高速化。

私たちは世界最高のAIコーディングモデルを構築したいと考えていますが、大規模言語モデルの学習には高いコストがかかります。たとえば、社内で最大規模のモデルでは、数万台のGPUを使用して学習に数週間を要することがあります。これは計算資源の面で高価であるだけでなく、改善をユーザーの皆さまに届けるスピードも遅くしてしまいます。

私たちは最近、Hopper GPU(H100)から Blackwell GPU(B200)へのアップグレードを開始し、それを学習ワークロードを徹底的に最適化する好機と捉えました。プロファイリングの結果、主なボトルネックは MegaBlocks で実装された Mixture-of-Experts(MoE)レイヤーであり、順伝播時間の約53%と逆伝播時間の27%を占めていることが判明しました。

そのため過去数週間にわたり、CUDA ライブラリへの依存を一切持たず、GPU カーネルレベルで MoE レイヤー全体をゼロから書き直しました。代わりに、昔ながらの純粋な CUDA と PTX を用い、そこに少量の ThunderKittens を織り交ぜています。その結果、順伝播・逆伝播の双方で MoE レイヤー性能を 3.5 倍向上し、Blackwell での学習をエンドツーエンドで 1.5 倍、従来の Hopper 構成比で 2 倍高速化しました。現時点で入手可能なオープンソース代替のいかなる組み合わせよりも、私たちのスタックは高速だと考えています。

Figure 1. Relative speedups of MXFP8 MoE compared to BF16 (normalized to 1.0).
Figure 1. Relative speedups of MXFP8 MoE compared to BF16 (normalized to 1.0).

この改善の大部分は、BF16 から MXFP8 への移行によるもので、学習品質の低下はほぼゼロで達成できました。しかし同時に、低精度化は言うは易く行うは難しであることも学びました。注意深く設計しないと、各種カーネルのオーバーヘッドにより、MXFP8 での学習は BF16 に比べて性能向上がごくわずかにとどまる場合があります。さらに MXFP8 の学習レシピは広く共有されておらず、最適な手法は自ら見つけ出す必要があります。

以下では、私たちの取り組むレシピと、Cursor における ML の研究・開発についてご紹介します。

Microscaling(MX)データ形式のクイック紹介

大規模なディープラーニングモデルの計算コストを下げる一般的な方法は、より低精度のアクティベーションと重みを用いることです。しかし、それらを狭いビット幅(例:8ビット以下)の形式に変換すると、値を適切にスケーリングしない限り、許容できない丸め誤差が生じます。たとえば、大規模モデルの一部の重みが 0.0001、0.0005、または –0.0007 であったとして、何も考えずに FP8 に変換すると、FP8E4M3 で表現可能な最小の正の値が約 0.0019 であるため、これらはすべて同じ数値、すなわちゼロへと丸められてしまいます。

この課題に対処するため、各テンソルごとにスケーリング係数を適用し、対象のデータ形式で表現可能な範囲内に収まるようにテンソルを再スケーリングするのが一般的です。これにより、利用可能なダイナミックレンジを最大限活用できます。例えば、テンソルの値がすべて –0.0001 から 0.0001 の間にある場合、スケーリング係数を 4,480,000 に設定できます。これにより、スケーリング後のテンソルの値は [–448, 448] の範囲となり、これは FP8E4M3 の表現可能な境界に対応します。

Microscaling は、テンソル全体に単一のスケール係数を用いるのではなく、テンソルのより細かなサブブロックごとにスケーリングを適用することで、さらに一歩進んだ手法です。Microscaling (MX) 形式は、このアプローチを標準化し、低精度かつマイクロスケールのデータ形式のセットを定義します。詳細はその仕様で提供されており、以下の具体的な MX 準拠形式が定義されています。

Table 1. Concrete MX-compliant formats.
Table 1. Concrete MX-compliant formats.

例えば、MXFP8 形式で E4M3 要素タイプを使用する場合、データは FP8E4M3 要素で構成され、下図に示すとおり、連続する32要素ごとのブロックに対して FP8E8M0 のスケール係数が適用されます。

Figure 2. MXFP8 quantization example: Each 1x32 block shares a scale.
Figure 2. MXFP8 quantization example: Each 1x32 block shares a scale.

マイクロスケーリングは大きな性能向上を引き出せる一方で、実運用に適用するにはいくつかの課題があり、基盤となるハードウェアへの依存も大きくなります。まず、NVIDIA Blackwell GPUでマイクロスケーリングを適用することが特に難しい理由を見ていきましょう。

1. テンソルメモリとCUDAコアがデクオン化の勢いを殺す

マイクロスケール化された FP8 行列乗算では、計算は縮約次元に沿って小さなブロック単位のステップに分割されます。各ブロック行列乗算の後、部分結果はスケール係数を用いてデ量子化され、次のブロックに進む前に加算(蓄積)されます。

たとえば、よく知られている DeepSeek V3 (DSV3) の技術レポートでは、A 行列は 1 \times 128 のサイズのブロックでスケーリングされ、B 行列は 128 \times 128 のサイズのブロックでスケーリングされています。つまり、手順は次のとおりです。

1. 128ブロックの行列乗算を実行します:

C_block = A[:, k*128:(k+1)*128] @ B[k*128:(k+1)*128, :]

2. スケール係数を用いてデ量子化を適用します:

C_block = A_scale[:, k] * C_block * B_scale[k, :]

ここで、A_scale は K 次元に沿って 128 要素にブロードキャストされ、B_scale は両次元に沿って 128 要素にブロードキャストされます

3. 復量子化したブロックを累積する:

C += C_block

4. 縮約の次元に沿って、次のブロックに進みます。

k += 1

この手法は(DSV3 が学習に用いられた)Hopper アーキテクチャで自然に高い効果を発揮します。理由は、(1) テンソルコアの行列積(wgmma 命令による)の結果がレジスタに累積されること、そして (2) CUDA コアでの量子化解除を行っている間に、他のテンソルコア行列積を非同期に起動してパイプライン化できることです。すべてがレジスタ内で累積されるため、行列積の間で追加のデータ移動は不要です。

これは、tensor memory(TMEM)により、Blackwell GPU ではもはや当てはまりません。TMEM は Blackwell GPU に追加されたオンチップの SM ローカルメモリの新しいセットです。Hopper GPU ではテンソルコアによる行列乗算の結果がレジスタに直接累積されますが、Blackwell GPU ではそれらが TMEM(tcgen05.mma 命令経由)に累積されます。アキュムレータに対してカスタム演算を行うには、TMEM からレジスタへ結果を転送し、CUDA コアで処理し、TMEM に書き戻し、さらにデータ移動が非同期であるためこれらすべての命令の完了を待機する必要があります。TMEM は共有メモリ(SMEM)よりアクセスが高速ではあるものの、これでもテンソルコアの占有率を低下させてしまいます。

Figure 3. Gantt chart taken from our custom Blackwell attention kernel. First row shows the tensor core activity (QK<sup>T</sup>). Second row shows the CUDA core activity (TMEM → registers, then softmax). TMEM → register latency causes the tensor core pipeline bubble.
Figure 3. Gantt chart taken from our custom Blackwell attention kernel. First row shows the tensor core activity (QKT). Second row shows the CUDA core activity (TMEM → registers, then softmax). TMEM → register latency causes the tensor core pipeline bubble.

一見すると、これはパイプライン型の手法を示唆しています。つまり、TMEM を複数のチャンクに分割し、一方のチャンクで行列積を実行している間に、もう一方のチャンクではデータ移動と脱量子化を行うというものです。

しかし、ここには別の問題があります。Blackwell のテンソルコアは Hopper と比べて TFLOP/s が倍増した一方で、FP32 の CUDA コアは約 33%(60 → 80 TFLOP/s)しか向上していません。32 ブロックスケーリングでは、非量子化に費やされる計算量は行列積の 1/32(どちらも乗算と加算を伴い、CUDA コアによる非量子化では手動で加算を蓄積する必要がある)ですが、非量子化の速度は行列積の 1/56(理論上 4,500 TFLOP/s)にすぎません。その結果、CUDA コアでの非量子化は行列積に要する時間のほぼ 1.76 倍かかり得ます。これは最悪です。

Hopper (H100 SXM5)Blackwell (B200)

FP8 Tensor Core スループット

1,979 TFLOP/s

4,500 TFLOP/s

FP32 CUDA Core スループット

60 TFLOP/s

80 TFLOP/s

32ブロックのデクアンタイズ時間

(行列積に対する相対値)

1.03x

1.76x

表 2. Hopper と Blackwell における相対的なデクオンタイズ化コスト。

実測では、上記手法のいかなるバリエーションを用いても、Hopper の現実的な FP8 スループットである 1,500 TFLOP/s を上回ることはできませんでした。しかも、これは量子化のオーバーヘッドすら考慮していません。

2. 夥しい量子化による死

あらゆる行列は、FP8 行列積カーネルに渡す前に FP8 量子化を行う必要があります。量子化にかかる時間は、可能な限り削減し目立たなくすることが不可欠です。注意深く扱わないと、量子化カーネルが GPU の実行時間の大半を占めてしまうことがあります。

量子化がどれほど支配的になり得るかを理解するために、単純な行列積 C = A x B を考えてみましょう。ここで、

  • A は M×K
  • B は K×N
  • C は M×N

MoE の学習におけるグループ化された行列乗算では、K や N と比べて M が非常に大きいのが一般的です。そのため、次のようにしましょう:

  • M = 131,072
  • K = 7,168
  • N = 2,048

これにより、行列乗算そのものに必要な浮動小数点演算の総数は次のとおりです。

131,072×7168×2048×2FLOP=3.85TFLOP131{,}072 \times 7168 \times 2048 \times 2 \,\text{FLOP} = 3.85 \,\text{TFLOP}

当社のベンチマークでは、Blackwell GPU(B200)は FP8 行列乗算のスループットが約 3,300 TFLOP/s であることが示されています。これは、行列乗算の実行に必要な実時間(ウォールクロック時間)の見込みが次のとおりであることを意味します:

3.85 TFLOP / 3,300 TFLOP/s=1.16 ms3.85 \text{ TFLOP } / \text{ } 3{,}300 \text{ TFLOP/s} = 1.16 \text{ ms}

ただし、A 行列と B 行列を量子化するのに要する時間も考慮する必要があります。量子化はメモリ帯域幅に律速されるため、重要なのは総データ移動量です。元の行列が BF16(2 バイト)で、スケールのブロックサイズが 32 であると仮定しましょう。これは次のことを意味します:

  • 負荷 A: 131,072 \times 7,168 \times 2\text{B} = 1.88 \text{ GB}

  • ロード B: 7,168 \times 2,048 \times 2\text{B} = 0.029 \text{ GB}

  • 量子化済み A を保存: 131,072 \times 7,168 \times 1\text{B} = 0.94 \text{ GB}

  • 量子化した B を保存: 7,168 \times 2,048 \times 1\text{B} = 0.015 \text{ GB}

  • A のスケール係数を保存: 131,072 \times 7,168 \times \frac{1}{32} \times 1\text{B} = 0.029 \text{ GB}

  • B のスケール係数を保存: 7,168 \times 2,048 \times \frac{1}{32} \times 1\text{B} = 0.0005 \text{ GB}

合計で、これはおよそ 2.9 GB の High Bandwidth Memory (HBM) の読み書きに相当します。B200 上で持続的に 6.5 TB/s の HBM スループットが得られると仮定すると、最適化された FP8 量子化カーネルに必要な時間は次のとおりです:

2.9 GB / 6.5 TB/s0.44 ms2.9 \text{ GB } / \text{ } 6.5 \text{ TB/s} \approx 0.44 \text{ ms}

これは行列積の時間のほぼ40%に相当します。バックワード伝播のためにA行列を転置して量子化する一般的なケースでは、これが2倍になって0.88 msとなり、行列積の処理時間の約76%に達します!

理論上は、FP8 の行列積は BF16 の行列積よりも 2 倍高速ですが、量子化に要する時間が実際には性能向上を完全に打ち消してしまう可能性があります。上記の分析は、ブロックスケールされた FP8 の行列積が 3,300 TFLOP/s で動作するという仮定に基づくため、楽観的でもあります。これはエンドツーエンドの MoE 学習で一般的に達成される水準を大きく上回ります。最悪の場合、量子化にかかる時間が FP8 の行列積よりも長くなり、結果として FP8 の計算全体が BF16 よりも遅くなることもありえます。

高速なオープンソースのFP8量子化カーネルはいくつか存在します。具体的には、NVIDIAの TransformerEngine や Pytorch の TorchAO です。しかし、当社のマイクロベンチマークでは、これらは帯域幅を活用しきれておらず、約 4.5 TB/s で動作していることが示されました。SM 内命令(例:cp.async.bulk)のみに依拠した場合、Blackwell は容易に 6~7 TB/s の HBM 帯域幅で動作できることが分かっています。

さらに、NVIDIA の MXFP8 のテンソルコア行列乗算用 PTX 命令である tcgen05.mma 命令は、やや直感に反する スケールファクタのレイアウトを必要とします。32 ブロックスケーリングでは、TransformerEngine や TorchAO の量子化カーネルは、スケールファクタを素朴な M x N / 32 のレイアウトで返します。これはその後、PyTorch 側でリシェイプするか、あるいはリシェイプのロジックを他のカーネルに融合する必要があり、いずれも性能に悪影響があります。実際のところ、行列乗算カーネル内でスケールファクタを処理するのは避けたいところです。最速のロード経路は、HBM → SMEM(cp.async.bulk)を経由し、その後 SMEM → TMEM(tcgen05.cp)へと進む方法です;スケールが一度でもレジスタタイルに回り道すると、その時点でテンソルのリズムは崩れます。

次に、前述の課題への対処方法を、まず量子化へのアプローチから順に説明します。

最適な低精度レシピの選び方

BF16 の学習品質に合わせるため、低精度での一連の実験を行い、各レシピが BF16 からどれだけ乖離するかを測定しました。これらの結果から、当社のワークロードにおいて BF16 とほぼ同一の学習損失収束を示す手法を特定しました。

具体的には、要素のデータ型として FP8E4M3(指数ビット 4、仮数ビット 3)を用いる MXFP8 形式、スケールのデータ型として FPE8M0(指数ビット 8)、スケーリングのブロックサイズを 32 としています。また、論文 “Recipes for Pre-training LLMs with MXFP8” の MXFP8 量子化レシピも採用しています。次を満たすように定義します:

  • BF16(または FP32)ベクトル V = \{ V_i \mid i = 0, \ldots, 31 \}
  • 対応する FP8E4M3 ベクトル Q = { Q_i \mid i = 0, \ldots, 31 }
  • FP8E8M0 スケール S

次のようにQとSを計算します。

S=cast_to_fp8e8m0(absolute_max(V)/448)S = \text{\texttt{cast\_to\_fp8e8m0}}\big(\text{\texttt{absolute\_max}}(V) / 448\big)
Qi=cast_to_fp8e4m3(Vi/S)Q_i = \text{\texttt{cast\_to\_fp8e4m3}}\big(V_i / S\big)

ここで、cast_to_fp8e8m0 は最も近い 2 の冪に切り上げ、最小値を 2^{-127} にクランプします。また、cast_to_fp8e4m3 は範囲外の値をサチュレートし、最も近い値(最近接、同値は偶数に丸め)に丸めます。

これにより、FP8 の学習損失は BF16 の学習損失と一致しました。

Figure 4. BF16 vs MXFP8 Training Loss over 10k steps: nearly indistinguishable.
Figure 4. BF16 vs MXFP8 Training Loss over 10k steps: nearly indistinguishable.
Figure 5. BF16 vs MXFP8 Training Loss (9k–10k steps) with a specific data point at step 9832.
Figure 5. BF16 vs MXFP8 Training Loss (9k–10k steps) with a specific data point at step 9832.

tcgen05 MXFP8 ブロックスケール行列乗算の採用

NVIDIA Blackwell アーキテクチャでは、MX フォーマットが Tensor Core に組み込まれています。ブロックスケーリングは tcgen05.mma...block_scale 命令で呼び出され、ハードウェアで処理されます。これにより、テンソルコアの行列積演算中にすべてが完結するため、TMEM からデータを取り出して非量子化する必要がなくなります。したがって、MXFP8 行列積カーネルの設計は tcgen05.mma を中心に据え、最大性能を得るためにその 仕様 の制約内で動作させる必要があります。

いくつか重要な留意点があります。まず、tcgen05.mma 命令は非同期に起動するのに単一スレッドだけを必要とします。これは Hopper とは対照的で、Hopper では wgmma 命令が非同期起動にワープグループ(128 スレッド)全体を必要とします。結果として、Hopper のカーネルで一般的なプロデューサ/コンシューマ パターン、すなわち 256 以上のスレッドを行列乗算の起動に割り当てる方式からは外れる必要があります。

次に、tcgen05.mma は 2-CTA の行列乗算をサポートしており、2 つの SM が B 行列を共有して協調実行します。これによりメモリトラフィックと共有メモリ使用量が削減され、より深い行列乗算パイプラインが可能になります。ベンチマークでは、この 2-CTA モードは非クラスター版と比べて MXFP8 の行列乗算で約 15~20% の高速化を示し、ピーク性能の達成に不可欠であることがわかりました。

第三に、先述のとおり、tcgen05.mma はレジスタではなく TMEM に結果を蓄積します。これはレジスタ使用圧を下げますが、同時に tcgen05.ldtcgen05.st 命令を介した TMEM とレジスタ間のデータ移動を追加で発生させます。これらのデータ移動は最小限に抑える必要があります。

最後に、tcgen05.mma 命令ではスケール係数が TMEM に常駐している必要があります。しかし、HBM から TMEM にスケールを直接ロードする方法はありません。最速の手順は、まず cp.async.bulk.tensor 命令(Tensor Memory Accelerator、略して TMA を活用)を用いて HBM からオンチップの SMEM にデータをロードし、その後 tcgen05.cp 命令で SMEM から TMEM へ転送することです。これを機能させるためには、スケール係数は本記事の後半で説明するように、tcgen05.mma が想定するレイアウトで保存されている必要があります。

これらの検討に関わるすべての命令(tcgen05.mmacp.async.bulk.tensortcgen05.cptcgen05.ld、およびtcgen05.st)は、単一スレッドによって非同期に起動されます。これにより、ワープ特化(warp specialization)を適用し、TMEMとSMEMをリングバッファとして用いたパイプライン化データフローを設計できます。

これを行うために、まず TMEM と SMEM を分割します。Blackwell では、128×512 の TMEM(セルあたり 32 ビット)と、スレッドブロックあたり 227 KB の連続した SMEM が与えられます。TMEM は A と B のスケール格納用に 5 スロットに分割し、行列乗算の累積(MMA)のための領域を残します。同様に、SMEM では MMA の結果を HBM に書き戻すための領域を確保し、残りを入力タイルとスケール係数をロードするための 5 スロットに分割します。レイアウトは以下に示します。

Figure 6. Simplified TMEM allocation: accumulator region plus 5 slots each for A and B scales
Figure 6. Simplified TMEM allocation: accumulator region plus 5 slots each for A and B scales
Figure 7. Simplified SMEM allocation: 5 slots reserved for input tiles and scales
Figure 7. Simplified SMEM allocation: 5 slots reserved for input tiles and scales

このセットアップでは、特定のワープがHBMからSMEMへ入力タイルとスケールを継続的にロードし、別のワープがスケールをSMEMからTMEMへ移動し、さらに別のワープがMMAを起動し、また一部のワープはときどきTMEMのアキュムレータをレジスタへロードしてSMEMに格納し、TMA経由でHBMに書き戻すというパイプラインを設計します。

Figure 8. Simplified MXFP8 Matrix Multiplication Pipeline
Figure 8. Simplified MXFP8 Matrix Multiplication Pipeline

具体的には、各スレッドブロックに 3 つの warpgroup(384 スレッド)を割り当て、warpgroup を 2 種類に特化させます。2 つの warpgroup は TMEM → レジスタ → SMEM → HBM のデータフローのみを実行し、これはレジスタプレッシャーが最も高い処理です。もう一方の warpgroup は warp 専門化を行います。Warp 0 は入力タイルを HBM から SMEM にロードし、warp 1 はスケールを HBM から SMEM にロードし、warp 2 はスケールを SMEM から TMEM にロードし、warp 3 はテンソルコアの行列乗算を起動します。また、永続グリッドパターンも実装しており、各 SM(Blackwell GPU では 148)に 1 つのスレッドブロックを割り当て、結果を HBM に書き戻している間に新しい入力タイルをロードできるようにしています。

このパイプラインの擬似コードは次のとおりです。

if (warpgroup_id < 2) {
    for (int i = 0; i < num_tiles; i++) {
        mbarrier_wait_for_final_matmul_completion(); // mbarrier.try_wait
        async_load_from_TMEM(reg, TMEM); // tcgen05.ld
        wait_for_load_completion(); // tcgen05.wait
        // This is iterated in the actual implementation to save SMEM 
        store_to_SMEM(SMEM, reg);
        TMA_async_store(HBM, SMEM); // cp.async.bulk.tensor
    }
} else {
    if (warp_id == 0) {
        for (int i = 0; i < num_tiles; i++) {
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_matmul_completion();
                // load input tiles
                TMA_async_load(SMEM, HBM);
            }
        }
    } else if (warp_id == 1) {
        for (int i = 0; i < num_tiles; i++) {
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_tcgen05_cp_completion();
                // load scales (HBM -> SMEM)
                TMA_load(SMEM, HBM);
            }
        }
    } else if (warp_id == 2) {
        for (int i = 0; i < num_tiles; i++) {
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_matmul_completion();
                mbarrier_wait_for_scale_SMEM_load();
                // load scales (SMEM -> TMEM)
                load_to_TMEM(TMEM, SMEM); // tcgen05.cp
            }
        }
    } else if (cta_rank == 0) { // 2-CTA MMA is launched by a single CTA
        for (int i = 0; i < num_tiles; i++) {
            mbarrier_wait_for_TMEM_clear();
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_input_SMEM_load();
                mbarrier_wait_for_scale_TMEM_load();
                // tcgen05.mma.cta_group::2.mxf8f6f4.block_scale
                launch_matmul(SMEM, SMEM, TMEM);
            }
        }
    }
}

Blackwell GPU におけるブロックスケール行列乗算で避けられない制約の1つは、TMEM のサイズです。マイクロベンチマークでは、Blackwell のテンソルコアは、128x512 の TMEM をアキュムレータとしてフルに使用した場合に最高のスループットを達成することが示されています。2-CTA の FP8 行列乗算では、これは 256x32x256 の tcgen05.mma 命令を常時 2 本実行することに相当します。各 tcgen05.mma は CTA ごとに 128x256 の TMEM 領域を消費するため、2 本を組み合わせると、両方の CTA にわたって 128x512 の TMEM アレイを完全に占有します。

しかし、スケール係数もTMEM内に常駐させる必要がある場合、TMEMの128x256領域のみを用いて、一度に実行できるのは単一の256x32x256のtcgen05.mma命令に限られます。その結果として、性能低下は避けられません。たとえば、この制約下では16,384x16,384x16,384のFP8行列積のスループットは3,200 TFLOP/sから3,040 TFLOP/sへと低下します。

これらのスループット値は、純粋な FP8 行列乗算にのみ適用されます。MXFP8 のブロックスケーリングを用いると、TMEM のパイプライン処理によるオーバーヘッドのため、スループットは必然的にさらに低下します。実運用では、ブロックスケーリングされた MXFP8 行列乗算カーネルに対し、L2 キャッシュのクリアを行ったうえで約 2,750 TFLOP/s を達成しています。それでも、これは標準的な BF16 行列乗算(最適形状で通常 1,500~1,550 TFLOP/s)と比べておよそ 1.83 倍高速です。悪くない滑り出しです!

MXFP8 のグループ化行列乗算への拡張

スタンドアロンの MXFP8 行列積カーネルは有用な第一歩ですが、MXFP8 の MoE 学習中における用途は限定的です(例:共有エキスパートのシナリオ)。MXFP8 で MoE を完全にサポートするには、グループ化された行列積カーネルが必要です。具体的には次のとおりです:

  1. グループ化された順伝播(Fprop)/データ勾配(Dgrad)
  2. グループ化された重み勾配(Wgrad)

これらのバリアントは、私たちがこれらのカーネルをゼロから構築している理由の一部であることにご留意ください。現時点までに、32ブロックスケーリングに対応した MXFP8 MoE の学習を完全にサポートするオープンソースの代替手段は見つかっていません。

カーネルレベルでは、Grouped Fprop と Dgrad は同じ構造を共有します。唯一の違いは、入力テンソルが up と gate の両方の射影を通過するため、Dgrad では加算が必要になる点です。しかしこれは、cp.async.bulk.tensor 命令を cp.reduce.async.bulk.tensor に置き換えることで容易に実装できます。後者は HBM へのアトミックな加算ストアを非同期で実行できます。

次の行列が与えられているとします:

  • Anum_tokens x in_dim
  • W: E x in_dim x out_dim、ここで E は現在のランク上のエキスパート数
  • 出力: num_tokens x out_dim

また、トークンがエキスパートのインデックス順に並んでいると仮定すると、グループ化された Fprop/Dgrad カーネルは次を実行します:

for i in range(num_routed_experts):
    start = 0 if i == 0 else end
    end = start + assigned_tokens_per_expert[i]
    O[start:end, :] = A[start:end, :] @ W[i, :, :]

一方で Grouped Wgrad は、エキスパートの分割が M 軸ではなく K(縮約)軸に沿って行われる点が異なります。これは次を計算します:

for i in range(num_routed_experts):
    start = 0 if i == 0 else end
    end = start + assigned_tokens_per_expert[i]
    W_grad[i, :, :] = A.T[:, start:end] @ O[start:end, :]

カーネル抽象化

カーネルレベルでは、一般的な作業単位は、指定された行・列・縮約範囲にわたる行列の乗算・加算(multiply-accumulate)です。この単位を抽象化したことは、グループ化された行列乗算カーネルの実装に非常に有用で、元の MXFP8 行列乗算カーネルを最小限の変更で再利用できるようになりました。そこで、私たちは元のカーネルを切り出しました:

// All units of 128
int expert_idx = ... // 0 for 2D case
int row_block_idx = ...
int col_block_idx = ...
int reduction_block_start_idx = ...
int reduction_block_end_idx = ...

// Based on MXFP8 matrix multiplication implemented above
// Runs at 256x128 granularity when possible (if not, 128x128)
run_mxfp8_matmul(
    expert_idx,
    row_block_idx, 
    col_block_idx, 
    reduction_block_start_idx, 
    reduction_block_end_idx
);

この抽象化により、グループ化された MXFP8 の行列積バリアントの実装は、上記の抽象化を起動する前に、適切なループ構造を設計し、インデックスを正しく割り当てる作業へと還元されます。ただし、単純なループ処理は、L2 キャッシュの利用効率が低いためにパフォーマンスを大きく低下させる可能性があります。

専門家単位のスーパーグルーピングによるL2キャッシュ最適化

高い L2 キャッシュの活用率を維持することは極めて重要です。ベンチマークでは、非効率な HBM アクセスパターンにより、グループ化された行列積カーネルで性能が約 50% 低下する可能性があることが分かりました。これに対処するため、L2 の再利用を最大化する supergroupingThunderKittens カーネル由来のヒューリスティクス)を適用しました。これは、任意の時点で 148 個の SM 全体が計算する出力行列の領域を可能な限り正方形に近づけることで、L2 の再利用を確保します。詳細は上記リンク先のカーネルコードをご覧ください。

グループ化行列乗算カーネルの主要な改良点は、出力行列全体ではなく現在のエキスパートに属するサブ行列のみを考慮して、エキスパートごとにスーパーグループ化を適用したことでした。これは特にグループ化Wgradにおいて効果的で、エキスパート分割により縮約軸が狭くなることが多いためです。縮約軸が狭いとテンソルコアの利用率が低下し、メモリ帯域幅が主要なボトルネックとなります。

適切な行列乗算カーネルの抽象化と、エキスパート単位の L2 キャッシュ最適化により、グループ化した MXFP8 行列乗算カーネルで約 2,650 TFLOP/s を達成しました—非グループ化版と比べても低下はわずか 4% です。素晴らしい成果です!

グループ化行列乗算のベンチマーク

Blackwell上でFprop/Dgrad/Wgradをグルーピングして扱える最も近いオープンソース代替は、DeepSeekのDeepGEMMです。DeepGEMMは、AおよびB行列に対して1x128および128x128のスケールブロックを用いる点でやや異なり、その代償として精度が低下します。とはいえ、利用可能な唯一の選択肢であるため、当社はDeepGEMMを内部のモデル学習に統合し、その性能をプロファイリングしました。DeepGEMMはマイクロベンチマークでは特定の入力形状において優れた結果を示した一方で、エンドツーエンドのベンチマークでは、当社のワークロードにおいて当社のカーネルがそれを上回ることが確認されました。

DeepSeek DeepGEMM当社

グループ化 Fprop / Dgrad

0.67 ms

0.43 ms

グループ化 Wgrad

0.71 ms

0.65 ms

表 3. 社内モデル学習中の、グループ化された Fprop/Dgrad および Wgrad カーネルの平均レイテンシ。

また、上記のベンチマークには量子化に要する時間が含まれていない点にも注意が必要です。DeepGEMM は最適化された量子化カーネルを提供しておらず、それがない場合、エンドツーエンドの性能は急速に低下し得ます。最悪の場合、BF16 の学習よりも遅くなることさえあります。

史上最速の MXFP8 量子化カーネルを構築する

前述のとおり、既存の MXFP8 量子化カーネルは最適とは言えないだけでなく、tcgen05.mma のスケールファクタのレイアウトに合わせるためのリシェイプが必要で、実行時の追加オーバーヘッドを招きます。したがって私たちの目標は、次の要件を満たす量子化カーネルを設計することでした。

  • メモリ帯域幅を完全に使い切り、できれば大きな M 次元のシェイプに対して 6 TB/s を超えることが望ましい。
  • スケール行列を正確なレイアウトtcgen05.mma が期待する形で tcgen05.mma生成し、MXFP8 カーネルが HBM → SMEM → TMEM のデータフローを中間変換なしで実行できるようにします。

瞬時に最高性能をもたらす単一の「魔法の」設計は存在せず、実際には適切なプリミティブの組み合わせと、試行錯誤から見出した数多くのマイクロ最適化の積み重ねでした。最大の成果は、ワープ内のオーバーヘッドを減らすためにTMAのスウィズルをやめ、手動のスウィズルパターンに切り替えたこと、手動によるスレッドブロック内オーバーラップではなくワープスケジューラとスレッドブロック間の非同期性に委ねたこと、そしてSMEMとレジスタ使用量を抑えてSMの占有率を高めたことから得られました。加えて、TMAの活用、SMEMバンク競合の排除、高速ベクトル組み込み関数の活用など、標準的な最適化も一通り行いました。ただし、一般化できることは実のところあまり多くありません。

これらの最適化により、tcgen05.mmaと直接互換のレイアウトでスケール行列を生成しつつ、6.2+ TB/sを維持するMXFP8量子化カーネルを実装しました。私たちの知る限り、これはMoEトレーニング向けに利用可能な最速のMXFP8量子化カーネルです。

NVIDIA TransformerEnginePyTorch TorchAO当社

ナイーブ

5236.35 GB/s

5245.15 GB/s

該当なし

reshape あり

4430.27 GB/s

4524.45 GB/s

6212.21 GB/s

表4. メモリ帯域幅の活用率による MXFP8 量子化カーネル比較(E4M3、32 ブロックスケーリング)。「reshape あり」には、tcgen05.mma 用にスケール係数を reshape する時間を含む。

MXFP8 量子化に対するその他の最適化としては、非転置と転置の両方の結果を出力する融合量子化カーネル(後者はバックワードパスに必要)を構築したことに加え、HBM へのアクセスを可能な限り最小化するために、MXFP8 の量子化ロジックを他のカーネルへ直接融合したことが挙げられます。例えば、融合 MXFP8 SwiGLU カーネルのプロローグとエピローグに、MXFP8 の非量子化/量子化ロジックを付加しました。

さらなる高速化!

上記の最適化により、順伝播と逆伝播の両方で MoE レイヤーの実行を 3.5 倍高速化しつつ、同等の学習品質を維持できました。社内モデルの一つで検証したところ、Blackwell ではエンドツーエンドで 1.5 倍、従来の Hopper 構成と比べては 2 倍の学習スループット向上が確認されました(GPU あたり毎秒トークン数(TPS/GPU)で測定)。当社のスタックは、オープンソース代替のいかなる組み合わせよりも高速に MXFP8 の MoE 学習を実行できると考えています。

Figure 9. End-to-end training TPS per GPU (internal model).
Figure 9. End-to-end training TPS per GPU (internal model).
Hopper BF16Blackwell BF16Blackwell MXFP8

MoE フォワード(ms)

32.36 ms

25.96 ms

9.45 ms

MoE 逆伝播(ms)

63.24 ms

59.17 ms

17.04 ms

エンドツーエンド(TPS/GPU)

12k TPS/GPU

16k TPS/GPU

24k TPS/GPU

表5. パフォーマンス比較(内部モデル)。注:Blackwell BF16 MoE は Hopper のコードをそのまま移植したものです

締めくくり

まだやるべきことは多く、最適化すべきカーネルも数多く残っています。現在は、より効率的なマルチGPU間通信の構築、カスタムAttentionカーネルの改良、そして今後のMoEトレーニング実行に向けたFP4への移行準備に注力しています。

高性能カーネルの作成と最適化、大規模なコーディングモデルの学習、または当社の取り組み全般に関心がある方は、ぜひご連絡ください。hiring@anysphere.co までご連絡をお待ちしています。

本記事をお読みいただき、貴重なフィードバックをお寄せくださった Benjamin Spector、Sasha Rush、Less Wright、Federico Cassano、Rohan Shah、Jacob Jackson、Sujay Jayakar、Shengtong Zhang に深く感謝いたします。

カテゴリー: 研究

著者: Stuart Sul