研究

以 warp decode 提升 MoE 模型推論

透過翻轉平行化軸線,我們實現了快 1.8 倍且更準確的 MoE 模型推論。

Less Wright, Federico Cassano & Zhiyuan Zhang4 分鐘閱讀時間

大多數 MoE 推論系統都會圍繞 專家 來組織 token 生成路徑。這反映了 routing 的運作方式,也一直是大規模場景下的標準作法。然而,對於 Blackwell GPU 上的小批次 decode,我們發現以輸出而非 專家 為中心來組織 kernel,效果更好。我們將這種方法稱為「warp decode」。

我們之所以提出 warp decode,是因為我們思考了 Blackwell 上 MoE decode 實際能達到的最大記憶體頻寬究竟是多少。這讓我們徹底翻轉了平行化軸線。我們不再將 warps 指派給 專家,而是讓每個 warp 負責單一輸出值 (神經元) 。

能同時提升效能與準確度的 kernels 相當少見,而 warp decode 就是其中之一。在 Blackwell 上,它可帶來 1.84 倍的吞吐量提升,同時也提升了準確度,讓輸出 與完整 FP32 reference 的差距縮小 1.4 倍。這也加快了 Composer 的研究與訓練管線,讓我們能更快改進模型,並更頻繁地推出新版本。

傳統的 MoE 路徑

現代 MoE 模型會將每個 token 路由到一部分專門的專家網路,例如在某一層中從 128 個專家裡選出 8 個。標準實作會以這些專家為核心來組織所有運算:先收集各個專家所需的 token,執行運算,再將結果重新組合。

這種方式在 prefill 和大型批次中效果很好,因為每位專家分攤到的共享工作足以抵消整理資料的額外負擔。但在自回歸解碼步驟中,由於一次只會產生一個 token,共享工作不足以合理化這種做法。傳統路徑中的八個階段裡,有五個階段純粹是在為以專家為中心的視角管理資料配置,實際上並未進行任何運算。

我們做了哪些改變

Warp decode 透過將平行化的組織方式從 專家 轉為以輸出為中心,消除了那五個「簿記」步驟。

現代 GPU 會以稱為 warp 的 32 條平行處理通道群組來執行指令。在我們的新方法中,每個 warp 都只負責計算一個輸出值。warp 會直接從記憶體串流所需的權重資料,將路由到的八個 專家 的總和累積到同一個累計值中,並寫回一個結果。

這種 warp 的獨立性,讓 warp decode 無需任何暫存、交接、跨 warp 同步點或中介緩衝區就能執行。整個 MoE 計算層被濃縮成兩個 kernel:moe_gate_up_3d_batchedmoe_down_3d_batched

兩個 kernel 的運作方式

在 gate/up kernel 中,每個 cooperative thread array (CTA) 包含八個 warp,而每個 warp 會為每一組 token 與 routed expert 的配對負責一個中間神經元。warp 會載入 routed expert ID,讀取該神經元對應的 gate 與 up 權重列,並串流處理輸入 activation 向量。MXFP8 權重會即時轉換為 FP32,而兩個點積都會累加到私有暫存器中。

由於這兩個 kernel 融合為單次傳遞,activation 向量只需讀取一次,隨即同時重用於兩個 projection,無需經過 shared memory 暫存。在完成 warp-level reduction 後,warp 會套用 SiLU(gate) × up,並寫出一個中間值。

down kernel 中,每個 warp 會為單一 token 負責一個 output 維度。它會遍歷所有 top-k routed 專家,載入對應的 down-projection 權重列,並串流處理中間 activation,同時將每個 expert 的 routing weight 納入同一個持續累加的 FP32 accumulator。

在所有 專家 都處理完成後,我們會使用 __shfl_xor_sync 做 warp-level butterfly reduction,將 32 個 lane-local partial sums 歸約起來。這會直接編譯成 PTX 的 shfl.sync.bfly 指令;這是一個單一硬體原語,可在同一個 warp 的 lanes 之間交換暫存器,完全繞過 shared memory。

這樣的好處是,我們不需要 L1 往返、bank conflicts 或顯式 barriers,因為同步已透過 lane mask 內建在這條指令中。最終加權的 top-k 組合不再是獨立的 epilogue,而是直接成為 projection 本身的一部分。

在 warp decode 中,每個 warp 都彼此獨立,且在整個生命週期中都只會收到一個固定且穩定的任務:產生一個 output scalar。正是這種 warp 獨立性,消除了傳統路徑所需的 shared memory 暫存、跨 warp 同步,以及中間緩衝區。

管線簡化與加速

warp decode 透過兩項主要機制提升效能:一是移除傳統路徑所需的階段與緩衝區,二是建立 warp 的獨立性,進而改善排程並隱藏延遲。

消除階段

消除階段帶來了大部分的吞吐量增益。我們去除了 padding、scattering 和 combine 步驟。要移除這些階段,必須從根本重新組織平行化方式,而不只是單純融合傳統管線中的各個階段。

消除填充

傳統路徑:將每個專家 的 token 清單填充到 2 的冪次方或 128 位元組的邊界,以符合 grouped kernel 的要求。在解碼時若只有單一 token,這類額外開銷無法攤銷。

warp decode 路徑:完全避免了這類額外開銷,因為它從不為各專家 形成批次。

消除 scatter 與 combine

傳統路徑:每個專家 完成後,都會將八個中間結果寫入 GPU 記憶體,然後再執行一個獨立的 reduction 步驟將它們合併。

warp decode 路徑:每個專家 的 routing 權重會直接併入 warp 內持續更新的累加器。這八個中間結果完全不會實際寫入記憶體,因此可省下後續 reduction pass 的寫入與讀取成本。

消除緩衝區

這項重組也消除了傳統路徑因其以專家為中心的版面配置而必須使用的兩個中介記憶體緩衝區。

第一個是 activation gather buffer,也就是將輸入 activation 向量複製並重新排列為以專家為主的版面配置。在 batch size 為 1 時,這其實是把已經存在的資料完整複製一遍。第二個是每個專家的輸出緩衝區。若有八個專家且 hidden dimension 為 2048,在 BF16 下,這相當於每個 token 需要 8 × 2048 × 2 bytes = 32 KB,先配置、寫入、隨即讀取一次,然後丟棄。

Warp decode 會將八個專家的貢獻整合到分布於 32 個 warp lane 上的 register accumulator 中,因此在最後一次單一純量寫入之前,資料都不會進入全域記憶體,從而消除這兩種緩衝區。每個 token 移除超過 32 KB 的中介緩衝區傳輸量後,就能釋放 L2 cache 容量,留給真正決定效能的權重列。

Warp 獨立性

這項重組也讓保留下來的運算更快,因為這個 kernel 在設計上就是「極易平行化」:每個 warp 都完全獨立於其他 warp。由於每個 warp 恰好只負責一個輸出純量,且只會讀取自己需要的權重列,因此 warp 之間不存在共享的可變狀態。

在單一 warp 的層級上,這種獨立性是徹底的。輸入活化值是唯讀的,累加器位於私有暫存器中,而輸出則會寫入唯一的位址。從硬體排程器的角度來看,整個輸出維度就是一個由彼此獨立的工作項目組成的平坦池。

GPU 的 warp 排程器可以在任何時間、以任何順序排程任何 warp,而不受正確性約束。當某個 warp 因等待記憶體載入而停滯時,排程器會立刻切換到另一個 warp。在 B200 的 148 個串流多處理器上同時有數千個 warp 在執行時,記憶體延遲幾乎完全被其他 warp 的有效運算所隱藏。

這個 kernel 也能線性擴展;也就是說,將輸出維度加倍時,獨立 warp 的數量也會加倍,且不需要額外同步。在 token batch 維度上也是如此,因此排程器看到的是一個平坦的工作命名空間,節點之間沒有任何邊。這與傳統路徑形成對比,因為專家層級的 GEMM kernel 需要區塊內協調。

結果

大規模下的端到端解碼吞吐量

在我們的內部推論系統上,於 NVIDIA B200 GPU 執行 Qwen-3 風格模型的測試顯示,吞吐量有穩定提升。這項吞吐量提升在所有上下文長度分組中都維持一致,證實這純粹是生成階段的改進,且不受提示詞長度影響。

準確度改進

移除中間的活化量化步驟,會對品質帶來可量化的影響。將 BF16 活化轉換成 MXFP8 再轉回來,會引入一個會在模型各層中累積的捨入誤差下限。Warp decode 會在整個過程中將活化維持在 BF16,並將累加器維持在 FP32,因此歸約運算不會作用在已劣化的輸入上。結果是,warp decode 產生的輸出與完整 32 位元真值的接近程度,比傳統路徑高出 1.4 倍。

硬體效率

我們在開始開發 warp decode 時,先思考它能逼近硬體最大吞吐量到什麼程度。B200 在連續記憶體讀取上的實測峰值為 6.8 TB/s (使用 copy kernel 測得) 。在 B=32 時,warp decode 可穩定維持 3.95 TB/s,相當於該峰值的 58%。剩餘的差距,很可能反映了 expert routing 造成的隨機存取模式所帶來的記憶體延遲成本,因為每個 token 都可能被路由到不相鄰的 專家,例如 5、8、14、19 等。

相較之下,峰值吞吐量是透過連續的 (0,1,2,3) 記憶體讀取測得。與參考實作相比,在所有批次大小下的正確性都非常高:最小 cosine similarity > 0.999996,最大絕對差為 0.001953。

Warp decode 與 Composer 訓練

Warp decode 並不是 expert-centric execution 的通用替代方案。像 prefill 和 large-batch inference 這類高吞吐量工作負載,仍然會受益於 expert-centric packing,因為許多 token 會共用相同的 expert,而重新組織它們的成本可以分攤到足夠多的實際運算上,因此是值得的。

當每個 expert 可共享的工作量不足以支撐這項額外開銷時,Warp decode 就會更有優勢,而 MoE decode 往往正是這種情況。這使它成為我們持續改進 Composer 的重要一環。雖然對預訓練資料和 RL 的投入會決定模型輸出的品質,但像 warp decode 這類在推論上的投入,則決定了這些輸出能以多快、多準確的方式交付給開發者。

分類: 研究

作者s: Less Wright, Federico Cassano & Zhiyuan Zhang