跳至主內容

採用自訂 MXFP8 核心,MoE 訓練可快 1.5 倍

作者: Stuart Sul 屬於 研究

透過全面重構,在 Blackwell GPU 上將 MoE 層效能提升至 3.5 倍。

我們致力於打造全球最卓越的 AI(人工智慧)程式碼模型,但訓練大型語言模型的成本可能十分高昂。以我們規模最大的內部模型為例,訓練往往需要動用數以萬計的 GPU、歷時數週。這不僅計算資源花費驚人,也會拖慢改進成果送達使用者的速度。

我們最近開始把 Hopper GPU(H100)升級為 Blackwell GPU(B200),並把這視為深入最佳化訓練工作負載的契機。效能剖析顯示,主要瓶頸在 Mixture‑of‑Experts(MoE)層,該層以MegaBlocks實作,佔前向傳遞時間近 53%,以及反向傳遞時間的 27%。

這也是為什麼在過去幾週裡,我們直接在 GPU 核層從頭重寫了整個 MoE 層,完全不依賴任何 CUDA 程式庫。相反地,我們使用純粹、傳統的 CUDA 與 PTX,再點綴少量的ThunderKittens。結果是,我們在前向與反向傳播的 MoE 層效能上達成了3.5 倍的提升,在 Blackwell 上帶來 1.5 倍的端到端訓練加速,並較我們原本的 Hopper 設定快上 2 倍。我們相信,我們的技術堆疊已經比當前任意組合的開源替代方案更快。

Figure 1. Relative speedups of MXFP8 MoE compared to BF16 (normalized to 1.0).
Figure 1. Relative speedups of MXFP8 MoE compared to BF16 (normalized to 1.0).

這項改進主要來自從 BF16 切換到MXFP8,而我們幾乎未犧牲訓練品質。不過,這也讓我們體會到,降至低精度說來容易做來難。若未謹慎處理,受限於各類核心(kernel)開銷,MXFP8 訓練相較 BF16 可能僅帶來極小幅的效能提升。此外,MXFP8 的訓練做法並未廣為流傳,意謂著你得自行摸索出正確途徑。

以下是我們的方法,以及我們在 Cursor 進行的機器學習工作一覽。

Microscaling(MX)資料格式快速介紹

降低大型深度學習模型計算成本的常見方式之一,是採用較低精度的活化值(activations)與權重。不過,若將它們轉換為更窄的位元寬度格式(例如 8 位元或更少),除非對數值進行適當縮放,否則會引入不可接受的四捨五入誤差。舉例來說,大型模型中的某些權重可能是 0.0001、0.0005,或 –0.0007,但若直接轉成 FP8,這些數值都會被捨入成同一個數字——零,因為在 FP8E4M3 中可表示的最小正值約為 0.0019。

為了解決此問題,通常會對每個張量套用縮放因子,將張量重新縮放,同時讓其數值維持在目標資料格式可表示的範圍內。這可確保充分利用可用的動態範圍。舉例來說,若某個張量的數值皆介於 –0.0001 到 0.0001 之間,其縮放因子可以設為 4,480,000。如此一來,張量縮放後的數值將落在 [–448, 448] 範圍內,對應到 FP8E4M3 的可表示界限。

Microscaling 更進一步,將縮放套用到張量中更細的子區塊,而不是對整個張量使用單一縮放因子。Microscaling(MX)格式透過定義一組低精度、微縮放的資料格式來標準化此作法。完整細節請見其規格,其中定義了以下符合 MX 的具體格式:

Table 1. Concrete MX-compliant formats.
Table 1. Concrete MX-compliant formats.

例如,若使用 MXFP8 格式並採用 E4M3 元素類型,則資料由 FP8E4M3 元素構成,且對每個連續的 32 個元素區塊套用 FP8E8M0 縮放因子,如下圖所示。

Figure 2. MXFP8 quantization example: Each 1x32 block shares a scale.
Figure 2. MXFP8 quantization example: Each 1x32 block shares a scale.

雖然微尺度化(microscaling)能帶來顯著效能提升,但在實務上導入會面臨多項挑戰,且高度依賴底層硬體。接下來先說明為何在 NVIDIA Blackwell GPU 上導入微尺度化特別不容易。

1. Tensor 記憶體與 CUDA 核心讓反量化的氣氛全被破壞

在微縮尺度的 FP8 矩陣乘法中,計算會沿著歸約維度拆分為較小的區塊化步驟。每次完成一個區塊的矩陣乘法後,會使用縮放因子將部分結果反量化並累加,然後再繼續處理下一個區塊。

例如,在知名的DeepSeek V3(DSV3)技術報告中,AA矩陣以大小為1×1281 \times 128的區塊進行縮放,而BB矩陣則以大小為128×128128 \times 128的區塊進行縮放。這表示步驟如下:

1. 執行 128 個區塊的矩陣相乘:

C_block = A[:, k*128:(k+1)*128] @ B[k*128:(k+1)*128, :]

2. 使用縮放因子進行反量化:

C_block = A_scale[:, k] * C_block * B_scale[k, :]

其中,A_scale 沿 K 維度對 128 個元素進行廣播,而 B_scale 則在兩個維度上各對 128 個元素進行廣播

3. 累加反量化後的區塊:

C += C_block

4. 沿著歸約維度繼續進行到下一個區塊:

k += 1

此方法在 Hopper 架構(DSV3 於其上訓練)上自然運作良好,原因是:(1) 張量核心的矩陣乘法(透過 wgmma 指令)其結果會累積在暫存器中;(2) 你可以將矩陣乘法管線化,在使用 CUDA 核心進行反量化的同時,非同步啟動其他張量核心的矩陣乘法。由於一切都累積在暫存器中,因此矩陣乘法之間無需額外的資料搬移。

在 Blackwell GPU 上,情況已不再如此,原因是tensor memory(TMEM)。TMEM 是在 Blackwell GPU 上新增的一組晶片內、SM 本地的記憶體。與 Hopper GPU 不同,後者會將張量核心的矩陣乘法結果直接累積在暫存器中,Blackwell GPU 則會將結果累積於 TMEM(透過 tcgen05.mma 指令)。若要對累加器進行自訂算術運算,必須先將結果從 TMEM 傳送到暫存器,以 CUDA 核心處理後再寫回 TMEM;接著還要等待所有這些指令完成,因為資料搬移是非同步的。儘管 TMEM 的存取速度比共享記憶體(SMEM)更快,這仍會大幅降低張量核心的佔用率。

Figure 3. Gantt chart taken from our custom Blackwell attention kernel. First row shows the tensor core activity (QK<sup>T</sup>). Second row shows the CUDA core activity (TMEM → registers, then softmax). TMEM → register latency causes the tensor core pipeline bubble.
Figure 3. Gantt chart taken from our custom Blackwell attention kernel. First row shows the tensor core activity (QKT). Second row shows the CUDA core activity (TMEM → registers, then softmax). TMEM → register latency causes the tensor core pipeline bubble.

乍看之下,這提示了一種管線化的作法:將 TMEM 劃分為多個區塊,當某一個區塊執行矩陣乘法時,另一個區塊同時進行資料搬移與反量化。

但這裡還有另一個問題:雖然 Blackwell 的張量核心相較於 Hopper 的 TFLOP/s 提升了兩倍,但 FP32 CUDA 核心只提升了約 33%(60 → 80 TFLOP/s)。在 32 區塊縮放(32-block scaling)下,反量化的計算量是矩陣乘法的 1/32(兩者都涉及乘法與加法;使用 CUDA 核心進行反量化時還必須手動累加),但反量化的速度卻只有矩陣乘法的 1/56(理論值為 4,500 TFLOP/s)。因此,在 CUDA 核心上進行反量化所花的時間,幾乎可能是矩陣乘法的 1.76 倍。這太糟糕了。

Hopper(H100 SXM5)Blackwell(B200)

FP8 Tensor Core 吞吐量

1,979 TFLOP/s

4,500 TFLOP/s

FP32 CUDA 核心吞吐量

60 TFLOP/s

80 TFLOP/s

32 區塊去量化時間

(相對於矩陣乘法)

1.03x

1.76x

表 2:Hopper 與 Blackwell 的相對反量化成本。

就實證而言,我們無法透過上述方法的任何變體,超過 Hopper 實際可達的 FP8 吞吐量 1,500 TFLOP/s。而且這還甚至未將量化的額外開銷計入。

2. 千刀萬剮式的量化

每個矩陣在送入 FP8 矩陣乘法核心之前都必須先進行 FP8 量化。盡可能縮短並隱藏量化時間至關重要;若處理不當,量化核心可能會壓過 GPU 的執行時間。

為了理解量化的主導程度,我們來看一個簡單的矩陣乘法C=A×BC = A \times B,其中:

AAM×KM \times K

BBK×NK \times N

CCM×NM \times N

在用於 MoE 訓練的分組矩陣乘法中,MM相較於KKNN通常要大得多。因此我們取:

M=131,072M = 131,072

K=7,168K = 7,168

N=2,048N = 2,048

由此,矩陣乘法本身所需的浮點運算總量為:

131,072×7168×2048×2FLOP=3.85TFLOP131{,}072 \times 7168 \times 2048 \times 2 \,\text{FLOP} = 3.85 \,\text{TFLOP}

我們的基準測試顯示,Blackwell GPU(B200)在 FP8 矩陣乘法的吞吐量約為 3,300 TFLOP/s。這表示執行該矩陣乘法的預期壁鐘時間為:

3.85 TFLOP / 3,300 TFLOP/s=1.16 毫秒3.85 \text{ TFLOP } / \text{ } 3{,}300 \text{ TFLOP/s} = 1.16 \text{ 毫秒}

不過,我們也需要考量將矩陣 A 與 B 量化所需的時間。由於量化受記憶體頻寬限制,關鍵在於總資料傳輸量。假設原始矩陣為 BF16(2 位元組),且縮放區塊大小為 32。這表示我們必須:

  • 載入 A:131,072×7,168×2B=1.88 GB131{,}072 \times 7{,}168 \times 2\text{B} = 1.88 \text{ GB}
  • 載入 B:7,168×2,048×2B=0.029 GB7{,}168 \times 2{,}048 \times 2\text{B} = 0.029 \text{ GB}
  • 儲存量化後的 A:131,072×7,168×1B=0.94 GB131{,}072 \times 7{,}168 \times 1\text{B} = 0.94 \text{ GB}
  • 儲存量化後的 B:7,168×2,048×1B=0.015 GB7{,}168 \times 2,048 \times 1\text{B} = 0.015 \text{ GB}
  • 為 A 儲存比例因子:131,072×7,168×132×1B=0.029 GB131{,}072 \times 7{,}168 \times \frac{1}{32} \times 1\text{B} = 0.029 \text{ GB}
  • 為 B 儲存縮放係數:7,168×2,048×132×1B=0.0005 GB7{,}168 \times 2{,}048 \times \frac{1}{32} \times 1\text{B} = 0.0005 \text{ GB}

合計約 2.9 GB 的高頻寬記憶體(HBM)讀寫。假設 B200 的可持續 HBM 吞吐量為 6.5 TB/s,則經過最佳化的 FP8 量化核心所需時間為:

2.9 GB / 6.5 TB/s0.44 ms2.9 \text{ GB } / \text{ } 6.5 \text{ TB/s} \approx 0.44 \text{ ms}

這幾乎佔了矩陣乘法時間的 40%。在常見情況下,若我們也為反向傳播將 A 矩陣先轉置並量化,時間會翻倍到 0.88 ms,約佔矩陣乘法時間的 76%

儘管理論上 FP8 的矩陣乘法比 BF16 快 2 倍,但量化所需的時間可能會大幅抵銷這項效能提升。上述分析也偏樂觀,因為它假設區塊縮放(block‑scaled)的 FP8 矩陣乘法能以 3,300 TFLOP/s 運行,這遠高於端到端 MoE 訓練中通常可達到的水準。在最糟情況下,量化花費的時間甚至可能比 FP8 矩陣乘法更長,從而使整體 FP8 計算反而比 BF16 更慢。

確實有一些快速的開源 FP8 量化核心,分別來自 NVIDIA 的TransformerEngine與 PyTorch 的TorchAO。然而,我們的微型基準測試顯示,它們仍有頻寬未被充分發揮,實測約為 4.5 TB/s。僅依賴 SM 內部指令(例如 cp.async.bulk),我們知道 Blackwell 可以輕鬆達到 6~7 TB/s 的 HBM 頻寬。

此外,NVIDIA 的 MXFP8 tcgen05.mma 指令(用於張量核心矩陣乘法的 PTX 指令)需要一種略顯不直觀的縮放因子配置。在 32 區塊縮放時,TransformerEngine 或 TorchAO 的量化核心會以較為單純的 M x N / 32 佈局回傳縮放因子。接著必須重整形(可在 PyTorch 中進行,或將重整形邏輯融合進其他核心),兩種做法都會拖慢效能。實務上,你確實不希望在矩陣乘法核心內處理縮放因子。載入它們最快的方式是走 HBM → SMEM(cp.async.bulk)路徑,再走 SMEM → TMEM(tcgen05.cp)路徑;一旦縮放因子繞道經過暫存器分塊,整個張量的節奏就沒了。

接下來,我們將說明如何解決上述挑戰,先從我們對量化(quantization)的做法談起。

選擇合適的低精度方案

為了達到與 BF16 訓練品質相當的水準,我們進行了一系列低精度實驗,測量各種方法相對於 BF16 的偏差。基於這些結果,我們找出一種在我們的工作負載下,其訓練損失收斂幾乎與 BF16 相同的做法。

更具體來說,我們使用 MXFP8 格式:元素資料型別為 FP8E4M3(4 個指數位、3 個尾數位)、尺度(scale)資料型別為 FPE8M0(8 個指數位),且縮放區塊大小為 32。同時,我們採用論文「Recipes for Pre-training LLMs with MXFP8」中的 MXFP8 量化方案。令:

  • BF16(或 FP32)向量V={Vii=0,,31}V = \{ V_i \mid i = 0, \ldots, 31 \}
  • 對應的 FP8E4M3 向量Q={Qii=0,,31}Q = \{ Q_i \mid i = 0, \ldots, 31 \}
  • FP8E8M0 尺度SS

我們按以下方式計算QQSS

S=cast_to_fp8e8m0(absolute_max(V)/448)S = \text{\texttt{cast\_to\_fp8e8m0}}\big(\text{\texttt{absolute\_max}}(V) / 448\big)
Qi=cast_to_fp8e4m3(Vi/S)Q_i = \text{\texttt{cast\_to\_fp8e4m3}}\big(V_i / S\big)

其中,cast_to_fp8e8m0 會向進位到最接近的 2 的冪次,並將最小值夾限為21272^{-127};而 cast_to_fp8e4m3 會對超出範圍的數值進行飽和處理,並四捨五入至最接近值;遇到中點時採用「就近取偶」規則。

透過這樣做,我們的 FP8 訓練損失與 BF16 訓練損失一致:

Figure 4. BF16 vs MXFP8 Training Loss over 10k steps: nearly indistinguishable.
Figure 4. BF16 vs MXFP8 Training Loss over 10k steps: nearly indistinguishable.
Figure 5. BF16 vs MXFP8 Training Loss (9k–10k steps) with a specific data point at step 9832.
Figure 5. BF16 vs MXFP8 Training Loss (9k–10k steps) with a specific data point at step 9832.

採用 tcgen05 MXFP8 區塊縮放矩陣乘法

在 NVIDIA Blackwell 架構上,MX 格式已內建於張量核心。區塊縮放透過 tcgen05.mma...block_scale 指令啟用,並由硬體處理。由於一切都在張量核心的矩陣乘法過程中完成,無需為反量化將資料自 TMEM 搬出。因此,我們的 MXFP8 矩陣乘法核心設計必須圍繞 tcgen05.mma,並在範疇限制內運作,以達到最佳效能。

有幾點重要注意事項。首先,tcgen05.mma 指令只需單一執行緒即可非同步啟動。這與 Hopper 不同;在 Hopper 中,wgmma 指令需要整個 warpgroup(128 個執行緒)才能非同步啟動。因此,我們必須偏離 Hopper 核心中常見的生產者—消費者模型,在該模型中有 256 個以上的執行緒專門負責啟動矩陣乘法。

其次,tcgen05.mma 支援 2-CTA 矩陣乘法,兩個 SM 透過共享 B 矩陣協同執行矩陣相乘。這可同時降低記憶體流量與共享記憶體使用量,讓矩陣乘法管線更深。基準測試顯示,與未叢集的版本相比,2-CTA 模式在 MXFP8 矩陣乘法上可帶來約 15~20% 的效能提升,對達到峰值效能至關重要。

第三,如前所述,tcgen05.mma 會將結果累積在 TMEM,而非暫存器。這雖可降低暫存器壓力,但也會因 tcgen05.ldtcgen05.st 指令在 TMEM 與暫存器之間引入額外的資料搬移。我們必須將這些搬移降到最低。

最後,tcgen05.mma 指令要求縮放係數(scale factor)位於 TMEM。然而,沒有直接的方法可將縮放係數從 HBM 載入到 TMEM。最快的作法是先使用 cp.async.bulk.tensor 指令(利用 Tensor Memory Accelerator,簡稱 TMA)把資料從 HBM 載入到晶片上的 SMEM,接著再使用 tcgen05.cp 指令將其從 SMEM 傳送到 TMEM。為了讓此流程可行,縮放係數必須以 tcgen05.mma 所要求(預期)的記憶體配置(layout)儲存,本文稍後會加以說明。

在這些考量中涉及的所有指令 — tcgen05.mmacp.async.bulk.tensortcgen05.cptcgen05.ldtcgen05.st — 都由單一執行緒以非同步方式啟動。這使我們能套用 warp 專門化,並設計具管線化的資料流,使用 TMEM 與 SMEM 作為循環緩衝區。

為了做到這點,我們先將 TMEM 與 SMEM 進行分割。在 Blackwell 上,提供的是 128×512 的 TMEM(每格 32 位元),以及每個 threadblock 擁有 227 KB 的連續 SMEM。我們把 TMEM 劃分為 5 個槽位,用於儲存 A 與 B 的縮放值,並預留空間給矩陣乘法累加(MMA)。同樣地,在 SMEM 中我們保留空間以便將 MMA 結果回寫到 HBM,並將其餘空間分割為 5 個槽位,用於載入輸入分塊(tiles)與縮放因子。下方說明其配置。

Figure 6. Simplified TMEM allocation: accumulator region plus 5 slots each for A and B scales
Figure 6. Simplified TMEM allocation: accumulator region plus 5 slots each for A and B scales
Figure 7. Simplified SMEM allocation: 5 slots reserved for input tiles and scales
Figure 7. Simplified SMEM allocation: 5 slots reserved for input tiles and scales

在這樣的設定下,我們設計了一條管線:部分 warp 會持續從 HBM 載入輸入資料磚並縮放後放入 SMEM;另有一些 warp 會將縮放參數從 SMEM 移至 TMEM;還有一些負責啟動 MMA;此外,部分 warp 會偶爾把 TMEM 的累加器載入至暫存器、存回 SMEM,並透過 TMA 寫回 HBM。

Figure 8. Simplified MXFP8 Matrix Multiplication Pipeline
Figure 8. Simplified MXFP8 Matrix Multiplication Pipeline

更具體地說,我們為每個 threadblock 指派 3 個 warpgroup(共 384 個執行緒),並將這些 warpgroup 分工為兩類:其中 2 個 warpgroup 僅負責 TMEM → register → SMEM → HBM 的資料流,這部分有最高的暫存器壓力。另一個 warpgroup 採用 warp 專門化:warp 0 將輸入分塊從 HBM 載入至 SMEM,warp 1 將 scale 從 HBM 載入至 SMEM,warp 2 將 scale 從 SMEM 載入至 TMEM,而 warp 3 啟動 tensor core 的矩陣乘法。我們也實作了常駐網格(persistent grid)樣式,為每個 SM(Blackwell GPU 上為 148 個)指派單一 threadblock,使在將結果寫回 HBM 的同時可並行載入新的輸入分塊。

此管線的偽代碼如下:

if (warpgroup_id < 2) {
    for (int i = 0; i < num_tiles; i++) {
        mbarrier_wait_for_final_matmul_completion(); // mbarrier.try_wait
        async_load_from_TMEM(reg, TMEM); // tcgen05.ld
        wait_for_load_completion(); // tcgen05.wait
        // This is iterated in the actual implementation to save SMEM 
        store_to_SMEM(SMEM, reg);
        TMA_async_store(HBM, SMEM); // cp.async.bulk.tensor
    }
} else {
    if (warp_id == 0) {
        for (int i = 0; i < num_tiles; i++) {
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_matmul_completion();
                // load input tiles
                TMA_async_load(SMEM, HBM);
            }
        }
    } else if (warp_id == 1) {
        for (int i = 0; i < num_tiles; i++) {
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_tcgen05_cp_completion();
                // load scales (HBM -> SMEM)
                TMA_load(SMEM, HBM);
            }
        }
    } else if (warp_id == 2) {
        for (int i = 0; i < num_tiles; i++) {
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_matmul_completion();
                mbarrier_wait_for_scale_SMEM_load();
                // load scales (SMEM -> TMEM)
                load_to_TMEM(TMEM, SMEM); // tcgen05.cp
            }
        }
    } else if (cta_rank == 0) { // 2-CTA MMA is launched by a single CTA
        for (int i = 0; i < num_tiles; i++) {
            mbarrier_wait_for_TMEM_clear();
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_input_SMEM_load();
                mbarrier_wait_for_scale_TMEM_load();
                // tcgen05.mma.cta_group::2.mxf8f6f4.block_scale
                launch_matmul(SMEM, SMEM, TMEM);
            }
        }
    }
}

在 Blackwell GPU 上進行區塊縮放矩陣乘法時,無可避免的一項限制是 TMEM 的大小。我們的微基準測試顯示,當完整的 128×512 TMEM 作為累加器使用時,Blackwell 張量核心可達到最高吞吐量。對於 2-CTA 的 FP8 矩陣乘法,這對應於持續執行兩條 256×32×256 的 tcgen05.mma 指令。每條 tcgen05.mma 會在每個 CTA 消耗一個 128×256 的 TMEM 區域,因此兩條指令合計會在兩個 CTA 上完整占用 128×512 的 TMEM 陣列。

然而,當縮放因子也必須駐留於 TMEM 時,我們一次只能在 TMEM 的 128×256 區域內執行一條 256×32×256 的 tcgen05.mma 指令。因此,效能下滑在所難免。舉例而言,在此限制下,16,384×16,384×16,384 的 FP8 矩陣乘法吞吐量會從 3,200 TFLOP/s 降至 3,040 TFLOP/s。

這些吞吐量數值僅適用於純 FP8 矩陣乘法。使用 MXFP8 區塊式縮放時,因 TMEM 管線化的額外負擔,吞吐量不可避免會進一步下降。實務上,對採用區塊式縮放的 MXFP8 矩陣乘法核心在清空 L2 快取後,我們可達約 2,750 TFLOP/s。即便如此,這仍比標準 BF16 矩陣乘法快約 1.83 倍;在最佳形狀下,後者通常約為 1,500~1,550 TFLOP/s。開局還不錯!

擴充至 MXFP8 分組矩陣乘法

獨立的 MXFP8 矩陣乘法核心是實用的第一步,但在使用 MXFP8 進行 MoE 訓練時,其適用情境有限(例如共享專家場景)。要在 MXFP8 中完整支援 MoE,我們需要分組矩陣乘法核心,具體而言:

  1. 分組前向傳播(Fprop)/資料梯度(Dgrad)
  2. 群組化權重梯度(Wgrad)

請注意,正因為存在這些變體,我們才從頭打造這些核心。迄今為止,我們尚未找到任何能完全支援以 32 區塊縮放進行 MXFP8 MoE 訓練的開源替代方案。

在核心層級,分組的 Fprop 與 Dgrad 具有相同的結構。唯一的差異是 Dgrad 需要進行累加,因為輸入張量會同時通過 up 與 gate 投影。不過,這可以很容易地透過將 cp.async.bulk.tensor 指令替換為 cp.reduce.async.bulk.tensor 來實作;後者可非同步地對 HBM 執行具原子性的加法寫入。

已知下列矩陣:

  • Anum_tokens × in_dim
  • WE × in_dim × out_dim,其中 E 為當前 rank 上的專家數
  • Onum_tokens x out_dim

在假設 tokens 依專家索引排序的情況下,分組的 Fprop/Dgrad 核心會執行:

for i in range(num_routed_experts):
    start = 0 if i == 0 else end
    end = start + assigned_tokens_per_expert[i]
    O[start:end, :] = A[start:end, :] @ W[i, :, :]

相較之下,Grouped Wgrad 的差異在於專家切分是沿著 K(歸約)軸,而不是沿著 M 軸。其計算如下:

for i in range(num_routed_experts):
    start = 0 if i == 0 else end
    end = start + assigned_tokens_per_expert[i]
    W_grad[i, :, :] = A.T[:, start:end] @ O[start:end, :]

核心抽象層

在核心層級,常見的工作單位是在指定的列、行與歸約範圍上執行矩陣乘加運算。將此單位抽象化,對於實作分組矩陣乘法核心極為有用,也讓我們能以最少變更重用原有的 MXFP8 矩陣乘法核心。因此,我們將原始核心抽取如下:

// All units of 128
int expert_idx = ... // 0 for 2D case
int row_block_idx = ...
int col_block_idx = ...
int reduction_block_start_idx = ...
int reduction_block_end_idx = ...

// Based on MXFP8 matrix multiplication implemented above
// Runs at 256x128 granularity when possible (if not, 128x128)
run_mxfp8_matmul(
    expert_idx,
    row_block_idx, 
    col_block_idx, 
    reduction_block_start_idx, 
    reduction_block_end_idx
);

透過此抽象化,實作分組的 MXFP8 矩陣乘法變體歸結為在啟動上述抽象之前設計合適的迴圈結構,並正確指派索引。不過,天真/樸素的迴圈方式可能因 L2 快取利用率不佳而顯著降低效能。

透過按專家劃分的超分組進行 L2 快取優化

維持高 L2 快取使用率至關重要。我們的基準測試顯示,低效的 HBM 存取模式可能使分組矩陣乘法的核心效能下降近 50%。為了解決這個問題,我們採用了 supergrouping——ThunderKittens 核心中的啟發式方法——透過讓任一時間點由全部 148 個 SM 所計算的輸出矩陣區域盡可能接近正方形,來最大化 L2 的重複利用。你可以閱讀上方連結的核心程式碼以進一步了解。

我們對分組矩陣乘法核心的一項關鍵改進,是對每位專家套用超分組(supergrouping),只針對當前專家所屬的子矩陣進行處理,而非整個輸出矩陣。這在分組的 Wgrad 上特別有效,因為受專家分割影響,歸約軸往往較窄。歸約軸越窄,張量核心的利用率就越低,記憶體頻寬因而成為主要瓶頸。

透過合適的矩陣乘法核心抽象與按專家維度進行的 L2 快取最佳化,我們在分組的 MXFP8 矩陣乘法核心上達成約 2,650 TFLOP/s——相較於未分組版本僅下降 4%。太好了!

群組式矩陣乘法效能基準

在 Blackwell 上,針對分組的 Fprop/Dgrad/Wgrad,最接近的開源替代方案是 DeepSeek 的 DeepGEMM。DeepGEMM 與我們略有不同,因其對 A 與 B 矩陣採用 1×128 與 128×128 的尺度分塊,代價是精度有所下降。即便如此,它仍是唯一可用的替代方案,因此我們將 DeepGEMM 整合進內部模型訓練,並對其效能進行分析。雖然 DeepGEMM 在我們的微基準測試中對某些輸入形狀表現出色,但端到端基準測試顯示,在我們的工作負載上,我們的核心效能優於它。

DeepSeek DeepGEMM我們的方法

分組 Fprop/Dgrad

0.67 ms

0.43 ms

分組 Wgrad

0.71 ms

0.65 ms

表 3:內部模型訓練期間,分組 Fprop/Dgrad 與 Wgrad 核心的平均延遲

另需注意,上述基準測試不包含量化時間。DeepGEMM 並未提供最佳化的量化核心,缺少這些核心時,端到端效能可能快速下滑;在最糟的情況下,甚至可能比 BF16 訓練更慢。

打造史上最快的 MXFP8 量化內核

如前所述,現有的 MXFP8 量化核心不僅效果不佳,還需要重新調整形狀以符合 tcgen05.mma 的縮放因子佈局,導致額外的執行時開銷。因此,我們的目標是設計一個量化核心,能夠:

  • 充分吃滿記憶體頻寬;對於 M 維度很大的形狀,最好能達到甚至超過 6 TB/s。
  • 產生符合精確佈局tcgen05.mma 預期的縮放矩陣,使我們的 MXFP8 核心能在無需中間轉換的情況下,執行 HBM → SMEM → TMEM 的資料流。

結果證明,並沒有哪一種能立刻帶來峰值效能的「魔法」設計;真正有效的是選對基礎原語,外加靠反覆試驗摸索出的眾多微調最佳化。最大的收穫來自移除 TMA swizzling,改用手動 swizzling 模式以降低同一個 warp 內的額外負擔;依賴 warp 排程器與 threadblock 之間的非同步,而不是手動在同一個 threadblock 內做重疊;以及盡量降低 SMEM 與暫存器的用量,以提高 SM 的佔用率。除此之外,還有一整套常見的最佳化手法,包括使用 TMA、消除 SMEM bank 衝突、善用快速向量內建函式等。不過,這些其實很難推而廣之。

透過這些最佳化,我們實作了一個 MXFP8 量化核心,能維持 6.2+ TB/s 的效能,並以可直接相容 tcgen05.mma 的版面配置產生縮放矩陣。就我們所知,這是目前用於 MoE 訓練最快的 MXFP8 量化核心。

NVIDIA TransformerEnginePyTorch TorchAO我們的方案

Naive(基本實作)

5236.35 GB/s

5245.15 GB/s

不適用

使用 reshape

4430.27 GB/s

4524.45 GB/s

6212.21 GB/s

表 4。依記憶體頻寬使用率比較 MXFP8 量化核心(E4M3,32 區塊縮放)。「包含 reshape」指包含為 tcgen05.mma 進行縮放因子重塑所花費的時間。

對 MXFP8 量化的其他優化包括:建立一個融合式量化核心,可同時輸出未轉置與已轉置的結果(後者用於反向傳播),並將 MXFP8 量化邏輯直接融入其他核心,以盡可能減少對 HBM 的存取。例如,我們將 MXFP8 的反量化/量化邏輯分別整合到融合式 MXFP8 SwiGLU 核心的前言(prologue)與收尾(epilogue)階段。

速度提升!

透過上述所有最佳化,我們在 MoE 層的前向與反向傳遞上達成3.5 倍的加速,同時維持相同的訓練品質。在我們的一個內部模型上測試時,這在 Blackwell 上帶來了端到端 1.5 倍的訓練加速,並相較於我們原本的 Hopper 設定達到 2 倍加速;以上以每張 GPU 每秒 tokens 數(TPS/GPU)衡量。我們相信,我們的技術堆疊在 MXFP8 的 MoE 訓練速度上,快於任何開源替代方案的組合。

Figure 9. End-to-end training TPS per GPU (internal model).
Figure 9. End-to-end training TPS per GPU (internal model).
Hopper BF16Blackwell BF16Blackwell MXFP8

MoE 前向傳播(毫秒)

32.36 毫秒

25.96 毫秒

9.45 毫秒

MoE 反向傳播(ms)

63.24 ms

59.17 ms

17.04 ms

端到端(TPS/GPU)

12k TPS/GPU

16k TPS/GPU

24k TPS/GPU

表 5。效能比較(內部模型)。注意:Blackwell BF16 MoE 為 Hopper 程式碼的直接移植

最後想法

我們前方仍有大量工作,且還有許多運算核心有待優化。目前我們專注於打造更高效的多 GPU 通訊、改進自訂的 attention 核心,並為未來的 MoE 訓練作業準備轉換至 FP4

如果你有興趣撰寫與優化高效能核心、訓練大型程式碼模型,或對我們的整體工作方向感到興趣,歡迎與我們聯繫。請寄信至hiring@anysphere.co

特別感謝 Benjamin Spector、Sasha Rush、Less Wright、Federico Cassano、Rohan Shah、Jacob Jackson、Sujay Jayakar 與 Shengtong Zhang 閱讀本文並提供寶貴回饋。

歸類於: 研究

作者: Stuart Sul