採用自訂 MXFP8 核心,MoE 訓練可快 1.5 倍
透過全面重構,在 Blackwell GPU 上將 MoE 層效能提升至 3.5 倍。
我們致力於打造全球最卓越的 AI(人工智慧)程式碼模型,但訓練大型語言模型的成本可能十分高昂。以我們規模最大的內部模型為例,訓練往往需要動用數以萬計的 GPU、歷時數週。這不僅計算資源花費驚人,也會拖慢改進成果送達使用者的速度。
我們最近開始把 Hopper GPU(H100)升級為 Blackwell GPU(B200),並把這視為深入最佳化訓練工作負載的契機。效能剖析顯示,主要瓶頸在 Mixture‑of‑Experts(MoE)層,該層以MegaBlocks實作,佔前向傳遞時間近 53%,以及反向傳遞時間的 27%。
這也是為什麼在過去幾週裡,我們直接在 GPU 核層從頭重寫了整個 MoE 層,完全不依賴任何 CUDA 程式庫。相反地,我們使用純粹、傳統的 CUDA 與 PTX,再點綴少量的ThunderKittens。結果是,我們在前向與反向傳播的 MoE 層效能上達成了3.5 倍的提升,在 Blackwell 上帶來 1.5 倍的端到端訓練加速,並較我們原本的 Hopper 設定快上 2 倍。我們相信,我們的技術堆疊已經比當前任意組合的開源替代方案更快。

這項改進主要來自從 BF16 切換到MXFP8,而我們幾乎未犧牲訓練品質。不過,這也讓我們體會到,降至低精度說來容易做來難。若未謹慎處理,受限於各類核心(kernel)開銷,MXFP8 訓練相較 BF16 可能僅帶來極小幅的效能提升。此外,MXFP8 的訓練做法並未廣為流傳,意謂著你得自行摸索出正確途徑。
以下是我們的方法,以及我們在 Cursor 進行的機器學習工作一覽。
Microscaling(MX)資料格式快速介紹
降低大型深度學習模型計算成本的常見方式之一,是採用較低精度的活化值(activations)與權重。不過,若將它們轉換為更窄的位元寬度格式(例如 8 位元或更少),除非對數值進行適當縮放,否則會引入不可接受的四捨五入誤差。舉例來說,大型模型中的某些權重可能是 0.0001、0.0005,或 –0.0007,但若直接轉成 FP8,這些數值都會被捨入成同一個數字——零,因為在 FP8E4M3 中可表示的最小正值約為 0.0019。
為了解決此問題,通常會對每個張量套用縮放因子,將張量重新縮放,同時讓其數值維持在目標資料格式可表示的範圍內。這可確保充分利用可用的動態範圍。舉例來說,若某個張量的數值皆介於 –0.0001 到 0.0001 之間,其縮放因子可以設為 4,480,000。如此一來,張量縮放後的數值將落在 [–448, 448] 範圍內,對應到 FP8E4M3 的可表示界限。
Microscaling 更進一步,將縮放套用到張量中更細的子區塊,而不是對整個張量使用單一縮放因子。Microscaling(MX)格式透過定義一組低精度、微縮放的資料格式來標準化此作法。完整細節請見其規格,其中定義了以下符合 MX 的具體格式:

例如,若使用 MXFP8 格式並採用 E4M3 元素類型,則資料由 FP8E4M3 元素構成,且對每個連續的 32 個元素區塊套用 FP8E8M0 縮放因子,如下圖所示。

雖然微尺度化(microscaling)能帶來顯著效能提升,但在實務上導入會面臨多項挑戰,且高度依賴底層硬體。接下來先說明為何在 NVIDIA Blackwell GPU 上導入微尺度化特別不容易。
1. Tensor 記憶體與 CUDA 核心讓反量化的氣氛全被破壞
在微縮尺度的 FP8 矩陣乘法中,計算會沿著歸約維度拆分為較小的區塊化步驟。每次完成一個區塊的矩陣乘法後,會使用縮放因子將部分結果反量化並累加,然後再繼續處理下一個區塊。
例如,在知名的DeepSeek V3(DSV3)技術報告中,矩陣以大小為的區塊進行縮放,而矩陣則以大小為的區塊進行縮放。這表示步驟如下:
1. 執行 128 個區塊的矩陣相乘:
C_block = A[:, k*128:(k+1)*128] @ B[k*128:(k+1)*128, :]2. 使用縮放因子進行反量化:
C_block = A_scale[:, k] * C_block * B_scale[k, :]其中,A_scale 沿 K 維度對 128 個元素進行廣播,而 B_scale 則在兩個維度上各對 128 個元素進行廣播
3. 累加反量化後的區塊:
C += C_block4. 沿著歸約維度繼續進行到下一個區塊:
k += 1此方法在 Hopper 架構(DSV3 於其上訓練)上自然運作良好,原因是:(1) 張量核心的矩陣乘法(透過 wgmma 指令)其結果會累積在暫存器中;(2) 你可以將矩陣乘法管線化,在使用 CUDA 核心進行反量化的同時,非同步啟動其他張量核心的矩陣乘法。由於一切都累積在暫存器中,因此矩陣乘法之間無需額外的資料搬移。
在 Blackwell GPU 上,情況已不再如此,原因是tensor memory(TMEM)。TMEM 是在 Blackwell GPU 上新增的一組晶片內、SM 本地的記憶體。與 Hopper GPU 不同,後者會將張量核心的矩陣乘法結果直接累積在暫存器中,Blackwell GPU 則會將結果累積於 TMEM(透過 tcgen05.mma 指令)。若要對累加器進行自訂算術運算,必須先將結果從 TMEM 傳送到暫存器,以 CUDA 核心處理後再寫回 TMEM;接著還要等待所有這些指令完成,因為資料搬移是非同步的。儘管 TMEM 的存取速度比共享記憶體(SMEM)更快,這仍會大幅降低張量核心的佔用率。

乍看之下,這提示了一種管線化的作法:將 TMEM 劃分為多個區塊,當某一個區塊執行矩陣乘法時,另一個區塊同時進行資料搬移與反量化。
但這裡還有另一個問題:雖然 Blackwell 的張量核心相較於 Hopper 的 TFLOP/s 提升了兩倍,但 FP32 CUDA 核心只提升了約 33%(60 → 80 TFLOP/s)。在 32 區塊縮放(32-block scaling)下,反量化的計算量是矩陣乘法的 1/32(兩者都涉及乘法與加法;使用 CUDA 核心進行反量化時還必須手動累加),但反量化的速度卻只有矩陣乘法的 1/56(理論值為 4,500 TFLOP/s)。因此,在 CUDA 核心上進行反量化所花的時間,幾乎可能是矩陣乘法的 1.76 倍。這太糟糕了。
| Hopper(H100 SXM5) | Blackwell(B200) | |
|---|---|---|
FP8 Tensor Core 吞吐量 | 1,979 TFLOP/s | 4,500 TFLOP/s |
FP32 CUDA 核心吞吐量 | 60 TFLOP/s | 80 TFLOP/s |
32 區塊去量化時間 (相對於矩陣乘法) | 1.03x | 1.76x |
表 2:Hopper 與 Blackwell 的相對反量化成本。
就實證而言,我們無法透過上述方法的任何變體,超過 Hopper 實際可達的 FP8 吞吐量 1,500 TFLOP/s。而且這還甚至未將量化的額外開銷計入。
2. 千刀萬剮式的量化
每個矩陣在送入 FP8 矩陣乘法核心之前都必須先進行 FP8 量化。盡可能縮短並隱藏量化時間至關重要;若處理不當,量化核心可能會壓過 GPU 的執行時間。
為了理解量化的主導程度,我們來看一個簡單的矩陣乘法,其中:
為
為
為
在用於 MoE 訓練的分組矩陣乘法中,相較於和通常要大得多。因此我們取:
由此,矩陣乘法本身所需的浮點運算總量為:
我們的基準測試顯示,Blackwell GPU(B200)在 FP8 矩陣乘法的吞吐量約為 3,300 TFLOP/s。這表示執行該矩陣乘法的預期壁鐘時間為:
不過,我們也需要考量將矩陣 A 與 B 量化所需的時間。由於量化受記憶體頻寬限制,關鍵在於總資料傳輸量。假設原始矩陣為 BF16(2 位元組),且縮放區塊大小為 32。這表示我們必須:
- 載入 A:
- 載入 B:
- 儲存量化後的 A:
- 儲存量化後的 B:
- 為 A 儲存比例因子:
- 為 B 儲存縮放係數:
合計約 2.9 GB 的高頻寬記憶體(HBM)讀寫。假設 B200 的可持續 HBM 吞吐量為 6.5 TB/s,則經過最佳化的 FP8 量化核心所需時間為:
這幾乎佔了矩陣乘法時間的 40%。在常見情況下,若我們也為反向傳播將 A 矩陣先轉置並量化,時間會翻倍到 0.88 ms,約佔矩陣乘法時間的 76%!
儘管理論上 FP8 的矩陣乘法比 BF16 快 2 倍,但量化所需的時間可能會大幅抵銷這項效能提升。上述分析也偏樂觀,因為它假設區塊縮放(block‑scaled)的 FP8 矩陣乘法能以 3,300 TFLOP/s 運行,這遠高於端到端 MoE 訓練中通常可達到的水準。在最糟情況下,量化花費的時間甚至可能比 FP8 矩陣乘法更長,從而使整體 FP8 計算反而比 BF16 更慢。
確實有一些快速的開源 FP8 量化核心,分別來自 NVIDIA 的TransformerEngine與 PyTorch 的TorchAO。然而,我們的微型基準測試顯示,它們仍有頻寬未被充分發揮,實測約為 4.5 TB/s。僅依賴 SM 內部指令(例如 cp.async.bulk),我們知道 Blackwell 可以輕鬆達到 6~7 TB/s 的 HBM 頻寬。
此外,NVIDIA 的 MXFP8 tcgen05.mma 指令(用於張量核心矩陣乘法的 PTX 指令)需要一種略顯不直觀的縮放因子配置。在 32 區塊縮放時,TransformerEngine 或 TorchAO 的量化核心會以較為單純的 M x N / 32 佈局回傳縮放因子。接著必須重整形(可在 PyTorch 中進行,或將重整形邏輯融合進其他核心),兩種做法都會拖慢效能。實務上,你確實不希望在矩陣乘法核心內處理縮放因子。載入它們最快的方式是走 HBM → SMEM(cp.async.bulk)路徑,再走 SMEM → TMEM(tcgen05.cp)路徑;一旦縮放因子繞道經過暫存器分塊,整個張量的節奏就沒了。
接下來,我們將說明如何解決上述挑戰,先從我們對量化(quantization)的做法談起。
選擇合適的低精度方案
為了達到與 BF16 訓練品質相當的水準,我們進行了一系列低精度實驗,測量各種方法相對於 BF16 的偏差。基於這些結果,我們找出一種在我們的工作負載下,其訓練損失收斂幾乎與 BF16 相同的做法。
更具體來說,我們使用 MXFP8 格式:元素資料型別為 FP8E4M3(4 個指數位、3 個尾數位)、尺度(scale)資料型別為 FPE8M0(8 個指數位),且縮放區塊大小為 32。同時,我們採用論文「Recipes for Pre-training LLMs with MXFP8」中的 MXFP8 量化方案。令:
- BF16(或 FP32)向量
- 對應的 FP8E4M3 向量
- FP8E8M0 尺度
我們按以下方式計算與:
其中,cast_to_fp8e8m0 會向上進位到最接近的 2 的冪次,並將最小值夾限為;而 cast_to_fp8e4m3 會對超出範圍的數值進行飽和處理,並四捨五入至最接近值;遇到中點時採用「就近取偶」規則。
透過這樣做,我們的 FP8 訓練損失與 BF16 訓練損失一致:


採用 tcgen05 MXFP8 區塊縮放矩陣乘法
在 NVIDIA Blackwell 架構上,MX 格式已內建於張量核心。區塊縮放透過 tcgen05.mma...block_scale 指令啟用,並由硬體處理。由於一切都在張量核心的矩陣乘法過程中完成,無需為反量化將資料自 TMEM 搬出。因此,我們的 MXFP8 矩陣乘法核心設計必須圍繞 tcgen05.mma,並在其範疇限制內運作,以達到最佳效能。
有幾點重要注意事項。首先,tcgen05.mma 指令只需單一執行緒即可非同步啟動。這與 Hopper 不同;在 Hopper 中,wgmma 指令需要整個 warpgroup(128 個執行緒)才能非同步啟動。因此,我們必須偏離 Hopper 核心中常見的生產者—消費者模型,在該模型中有 256 個以上的執行緒專門負責啟動矩陣乘法。
其次,tcgen05.mma 支援 2-CTA 矩陣乘法,兩個 SM 透過共享 B 矩陣協同執行矩陣相乘。這可同時降低記憶體流量與共享記憶體使用量,讓矩陣乘法管線更深。基準測試顯示,與未叢集的版本相比,2-CTA 模式在 MXFP8 矩陣乘法上可帶來約 15~20% 的效能提升,對達到峰值效能至關重要。
第三,如前所述,tcgen05.mma 會將結果累積在 TMEM,而非暫存器。這雖可降低暫存器壓力,但也會因 tcgen05.ld 與 tcgen05.st 指令在 TMEM 與暫存器之間引入額外的資料搬移。我們必須將這些搬移降到最低。
最後,tcgen05.mma 指令要求縮放係數(scale factor)位於 TMEM。然而,沒有直接的方法可將縮放係數從 HBM 載入到 TMEM。最快的作法是先使用 cp.async.bulk.tensor 指令(利用 Tensor Memory Accelerator,簡稱 TMA)把資料從 HBM 載入到晶片上的 SMEM,接著再使用 tcgen05.cp 指令將其從 SMEM 傳送到 TMEM。為了讓此流程可行,縮放係數必須以 tcgen05.mma 所要求(預期)的記憶體配置(layout)儲存,本文稍後會加以說明。
在這些考量中涉及的所有指令 — tcgen05.mma、cp.async.bulk.tensor、tcgen05.cp、tcgen05.ld 和 tcgen05.st — 都由單一執行緒以非同步方式啟動。這使我們能套用 warp 專門化,並設計具管線化的資料流,使用 TMEM 與 SMEM 作為循環緩衝區。
為了做到這點,我們先將 TMEM 與 SMEM 進行分割。在 Blackwell 上,提供的是 128×512 的 TMEM(每格 32 位元),以及每個 threadblock 擁有 227 KB 的連續 SMEM。我們把 TMEM 劃分為 5 個槽位,用於儲存 A 與 B 的縮放值,並預留空間給矩陣乘法累加(MMA)。同樣地,在 SMEM 中我們保留空間以便將 MMA 結果回寫到 HBM,並將其餘空間分割為 5 個槽位,用於載入輸入分塊(tiles)與縮放因子。下方說明其配置。


在這樣的設定下,我們設計了一條管線:部分 warp 會持續從 HBM 載入輸入資料磚並縮放後放入 SMEM;另有一些 warp 會將縮放參數從 SMEM 移至 TMEM;還有一些負責啟動 MMA;此外,部分 warp 會偶爾把 TMEM 的累加器載入至暫存器、存回 SMEM,並透過 TMA 寫回 HBM。

更具體地說,我們為每個 threadblock 指派 3 個 warpgroup(共 384 個執行緒),並將這些 warpgroup 分工為兩類:其中 2 個 warpgroup 僅負責 TMEM → register → SMEM → HBM 的資料流,這部分有最高的暫存器壓力。另一個 warpgroup 採用 warp 專門化:warp 0 將輸入分塊從 HBM 載入至 SMEM,warp 1 將 scale 從 HBM 載入至 SMEM,warp 2 將 scale 從 SMEM 載入至 TMEM,而 warp 3 啟動 tensor core 的矩陣乘法。我們也實作了常駐網格(persistent grid)樣式,為每個 SM(Blackwell GPU 上為 148 個)指派單一 threadblock,使在將結果寫回 HBM 的同時可並行載入新的輸入分塊。
此管線的偽代碼如下:
if (warpgroup_id < 2) {
for (int i = 0; i < num_tiles; i++) {
mbarrier_wait_for_final_matmul_completion(); // mbarrier.try_wait
async_load_from_TMEM(reg, TMEM); // tcgen05.ld
wait_for_load_completion(); // tcgen05.wait
// This is iterated in the actual implementation to save SMEM
store_to_SMEM(SMEM, reg);
TMA_async_store(HBM, SMEM); // cp.async.bulk.tensor
}
} else {
if (warp_id == 0) {
for (int i = 0; i < num_tiles; i++) {
for (int j = 0; j < num_iters; j++) {
mbarrier_wait_for_matmul_completion();
// load input tiles
TMA_async_load(SMEM, HBM);
}
}
} else if (warp_id == 1) {
for (int i = 0; i < num_tiles; i++) {
for (int j = 0; j < num_iters; j++) {
mbarrier_wait_for_tcgen05_cp_completion();
// load scales (HBM -> SMEM)
TMA_load(SMEM, HBM);
}
}
} else if (warp_id == 2) {
for (int i = 0; i < num_tiles; i++) {
for (int j = 0; j < num_iters; j++) {
mbarrier_wait_for_matmul_completion();
mbarrier_wait_for_scale_SMEM_load();
// load scales (SMEM -> TMEM)
load_to_TMEM(TMEM, SMEM); // tcgen05.cp
}
}
} else if (cta_rank == 0) { // 2-CTA MMA is launched by a single CTA
for (int i = 0; i < num_tiles; i++) {
mbarrier_wait_for_TMEM_clear();
for (int j = 0; j < num_iters; j++) {
mbarrier_wait_for_input_SMEM_load();
mbarrier_wait_for_scale_TMEM_load();
// tcgen05.mma.cta_group::2.mxf8f6f4.block_scale
launch_matmul(SMEM, SMEM, TMEM);
}
}
}
}在 Blackwell GPU 上進行區塊縮放矩陣乘法時,無可避免的一項限制是 TMEM 的大小。我們的微基準測試顯示,當完整的 128×512 TMEM 作為累加器使用時,Blackwell 張量核心可達到最高吞吐量。對於 2-CTA 的 FP8 矩陣乘法,這對應於持續執行兩條 256×32×256 的 tcgen05.mma 指令。每條 tcgen05.mma 會在每個 CTA 消耗一個 128×256 的 TMEM 區域,因此兩條指令合計會在兩個 CTA 上完整占用 128×512 的 TMEM 陣列。
然而,當縮放因子也必須駐留於 TMEM 時,我們一次只能在 TMEM 的 128×256 區域內執行一條 256×32×256 的 tcgen05.mma 指令。因此,效能下滑在所難免。舉例而言,在此限制下,16,384×16,384×16,384 的 FP8 矩陣乘法吞吐量會從 3,200 TFLOP/s 降至 3,040 TFLOP/s。
這些吞吐量數值僅適用於純 FP8 矩陣乘法。使用 MXFP8 區塊式縮放時,因 TMEM 管線化的額外負擔,吞吐量不可避免會進一步下降。實務上,對採用區塊式縮放的 MXFP8 矩陣乘法核心在清空 L2 快取後,我們可達約 2,750 TFLOP/s。即便如此,這仍比標準 BF16 矩陣乘法快約 1.83 倍;在最佳形狀下,後者通常約為 1,500~1,550 TFLOP/s。開局還不錯!
擴充至 MXFP8 分組矩陣乘法
獨立的 MXFP8 矩陣乘法核心是實用的第一步,但在使用 MXFP8 進行 MoE 訓練時,其適用情境有限(例如共享專家場景)。要在 MXFP8 中完整支援 MoE,我們需要分組矩陣乘法核心,具體而言:
- 分組前向傳播(Fprop)/資料梯度(Dgrad)
- 群組化權重梯度(Wgrad)
請注意,正因為存在這些變體,我們才從頭打造這些核心。迄今為止,我們尚未找到任何能完全支援以 32 區塊縮放進行 MXFP8 MoE 訓練的開源替代方案。
在核心層級,分組的 Fprop 與 Dgrad 具有相同的結構。唯一的差異是 Dgrad 需要進行累加,因為輸入張量會同時通過 up 與 gate 投影。不過,這可以很容易地透過將 cp.async.bulk.tensor 指令替換為 cp.reduce.async.bulk.tensor 來實作;後者可非同步地對 HBM 執行具原子性的加法寫入。
已知下列矩陣:
- A:
num_tokens×in_dim - W:
E×in_dim×out_dim,其中E為當前 rank 上的專家數 - O:
num_tokensxout_dim
在假設 tokens 依專家索引排序的情況下,分組的 Fprop/Dgrad 核心會執行:
for i in range(num_routed_experts):
start = 0 if i == 0 else end
end = start + assigned_tokens_per_expert[i]
O[start:end, :] = A[start:end, :] @ W[i, :, :]相較之下,Grouped Wgrad 的差異在於專家切分是沿著 K(歸約)軸,而不是沿著 M 軸。其計算如下:
for i in range(num_routed_experts):
start = 0 if i == 0 else end
end = start + assigned_tokens_per_expert[i]
W_grad[i, :, :] = A.T[:, start:end] @ O[start:end, :]核心抽象層
在核心層級,常見的工作單位是在指定的列、行與歸約範圍上執行矩陣乘加運算。將此單位抽象化,對於實作分組矩陣乘法核心極為有用,也讓我們能以最少變更重用原有的 MXFP8 矩陣乘法核心。因此,我們將原始核心抽取如下:
// All units of 128
int expert_idx = ... // 0 for 2D case
int row_block_idx = ...
int col_block_idx = ...
int reduction_block_start_idx = ...
int reduction_block_end_idx = ...
// Based on MXFP8 matrix multiplication implemented above
// Runs at 256x128 granularity when possible (if not, 128x128)
run_mxfp8_matmul(
expert_idx,
row_block_idx,
col_block_idx,
reduction_block_start_idx,
reduction_block_end_idx
);透過此抽象化,實作分組的 MXFP8 矩陣乘法變體歸結為在啟動上述抽象之前設計合適的迴圈結構,並正確指派索引。不過,天真/樸素的迴圈方式可能因 L2 快取利用率不佳而顯著降低效能。
透過按專家劃分的超分組進行 L2 快取優化
維持高 L2 快取使用率至關重要。我們的基準測試顯示,低效的 HBM 存取模式可能使分組矩陣乘法的核心效能下降近 50%。為了解決這個問題,我們採用了 supergrouping——ThunderKittens 核心中的啟發式方法——透過讓任一時間點由全部 148 個 SM 所計算的輸出矩陣區域盡可能接近正方形,來最大化 L2 的重複利用。你可以閱讀上方連結的核心程式碼以進一步了解。
我們對分組矩陣乘法核心的一項關鍵改進,是對每位專家套用超分組(supergrouping),只針對當前專家所屬的子矩陣進行處理,而非整個輸出矩陣。這在分組的 Wgrad 上特別有效,因為受專家分割影響,歸約軸往往較窄。歸約軸越窄,張量核心的利用率就越低,記憶體頻寬因而成為主要瓶頸。
透過合適的矩陣乘法核心抽象與按專家維度進行的 L2 快取最佳化,我們在分組的 MXFP8 矩陣乘法核心上達成約 2,650 TFLOP/s——相較於未分組版本僅下降 4%。太好了!
群組式矩陣乘法效能基準
在 Blackwell 上,針對分組的 Fprop/Dgrad/Wgrad,最接近的開源替代方案是 DeepSeek 的 DeepGEMM。DeepGEMM 與我們略有不同,因其對 A 與 B 矩陣採用 1×128 與 128×128 的尺度分塊,代價是精度有所下降。即便如此,它仍是唯一可用的替代方案,因此我們將 DeepGEMM 整合進內部模型訓練,並對其效能進行分析。雖然 DeepGEMM 在我們的微基準測試中對某些輸入形狀表現出色,但端到端基準測試顯示,在我們的工作負載上,我們的核心效能優於它。
| DeepSeek DeepGEMM | 我們的方法 | |
|---|---|---|
分組 Fprop/Dgrad | 0.67 ms | 0.43 ms |
分組 Wgrad | 0.71 ms | 0.65 ms |
表 3:內部模型訓練期間,分組 Fprop/Dgrad 與 Wgrad 核心的平均延遲
另需注意,上述基準測試不包含量化時間。DeepGEMM 並未提供最佳化的量化核心,缺少這些核心時,端到端效能可能快速下滑;在最糟的情況下,甚至可能比 BF16 訓練更慢。
打造史上最快的 MXFP8 量化內核
如前所述,現有的 MXFP8 量化核心不僅效果不佳,還需要重新調整形狀以符合 tcgen05.mma 的縮放因子佈局,導致額外的執行時開銷。因此,我們的目標是設計一個量化核心,能夠:
- 充分吃滿記憶體頻寬;對於 M 維度很大的形狀,最好能達到甚至超過 6 TB/s。
- 產生符合精確佈局之
tcgen05.mma預期的縮放矩陣,使我們的 MXFP8 核心能在無需中間轉換的情況下,執行 HBM → SMEM → TMEM 的資料流。
結果證明,並沒有哪一種能立刻帶來峰值效能的「魔法」設計;真正有效的是選對基礎原語,外加靠反覆試驗摸索出的眾多微調最佳化。最大的收穫來自移除 TMA swizzling,改用手動 swizzling 模式以降低同一個 warp 內的額外負擔;依賴 warp 排程器與 threadblock 之間的非同步,而不是手動在同一個 threadblock 內做重疊;以及盡量降低 SMEM 與暫存器的用量,以提高 SM 的佔用率。除此之外,還有一整套常見的最佳化手法,包括使用 TMA、消除 SMEM bank 衝突、善用快速向量內建函式等。不過,這些其實很難推而廣之。
透過這些最佳化,我們實作了一個 MXFP8 量化核心,能維持 6.2+ TB/s 的效能,並以可直接相容 tcgen05.mma 的版面配置產生縮放矩陣。就我們所知,這是目前用於 MoE 訓練最快的 MXFP8 量化核心。
| NVIDIA TransformerEngine | PyTorch TorchAO | 我們的方案 | |
|---|---|---|---|
Naive(基本實作) | 5236.35 GB/s | 5245.15 GB/s | 不適用 |
使用 reshape | 4430.27 GB/s | 4524.45 GB/s | 6212.21 GB/s |
表 4。依記憶體頻寬使用率比較 MXFP8 量化核心(E4M3,32 區塊縮放)。「包含 reshape」指包含為 tcgen05.mma 進行縮放因子重塑所花費的時間。
對 MXFP8 量化的其他優化包括:建立一個融合式量化核心,可同時輸出未轉置與已轉置的結果(後者用於反向傳播),並將 MXFP8 量化邏輯直接融入其他核心,以盡可能減少對 HBM 的存取。例如,我們將 MXFP8 的反量化/量化邏輯分別整合到融合式 MXFP8 SwiGLU 核心的前言(prologue)與收尾(epilogue)階段。
速度提升!
透過上述所有最佳化,我們在 MoE 層的前向與反向傳遞上達成3.5 倍的加速,同時維持相同的訓練品質。在我們的一個內部模型上測試時,這在 Blackwell 上帶來了端到端 1.5 倍的訓練加速,並相較於我們原本的 Hopper 設定達到 2 倍加速;以上以每張 GPU 每秒 tokens 數(TPS/GPU)衡量。我們相信,我們的技術堆疊在 MXFP8 的 MoE 訓練速度上,快於任何開源替代方案的組合。

| Hopper BF16 | Blackwell BF16 | Blackwell MXFP8 | |
|---|---|---|---|
MoE 前向傳播(毫秒) | 32.36 毫秒 | 25.96 毫秒 | 9.45 毫秒 |
MoE 反向傳播(ms) | 63.24 ms | 59.17 ms | 17.04 ms |
端到端(TPS/GPU) | 12k TPS/GPU | 16k TPS/GPU | 24k TPS/GPU |
表 5。效能比較(內部模型)。注意:Blackwell BF16 MoE 為 Hopper 程式碼的直接移植
最後想法
我們前方仍有大量工作,且還有許多運算核心有待優化。目前我們專注於打造更高效的多 GPU 通訊、改進自訂的 attention 核心,並為未來的 MoE 訓練作業準備轉換至 FP4。
如果你有興趣撰寫與優化高效能核心、訓練大型程式碼模型,或對我們的整體工作方向感到興趣,歡迎與我們聯繫。請寄信至hiring@anysphere.co。
特別感謝 Benjamin Spector、Sasha Rush、Less Wright、Federico Cassano、Rohan Shah、Jacob Jackson、Sujay Jayakar 與 Shengtong Zhang 閱讀本文並提供寶貴回饋。