跳至内容

使用自定义 MXFP8 内核,将 MoE 训练提速至 1.5 倍

作者: Stuart Sul归入研究

通过为 Blackwell GPU 进行完整重构,实现 MoE 层 3.5 倍的加速。

我们致力于打造全球最出色的 AI 编码模型,但训练大型语言模型成本高昂。以我们最大的内部模型为例,往往需要在数以万计的 GPU 上训练数周。这不仅计算资源消耗巨大,也会放慢改进交付给用户的速度.

我们最近开始从 Hopper GPU(H100)升级到 Blackwell GPU(B200),并借此机会对训练工作负载进行深度优化。剖析发现主要瓶颈在 Mixture-of-Experts(MoE)层,该层使用MegaBlocks实现,占据了前向传播时间的近 53% 和反向传播时间的 27%。

这就是为什么在过去几周里,我们从零开始在 GPU 内核层面重写了整个 MoE 层,完全不依赖任何 CUDA 库。相反,我们使用了纯正的传统 CUDA 和 PTX,并点缀少量ThunderKittens。因此,我们在前向和反向传播的 MoE 层性能上实现了3.5 倍的提升,在 Blackwell 上带来端到端 1.5 倍的训练加速,并相较于我们最初的 Hopper 方案实现了 2 倍加速。我们相信,我们的技术栈比当今任何开源替代方案的组合都更快。

Figure 1. Relative speedups of MXFP8 MoE compared to BF16 (normalized to 1.0).
Figure 1. Relative speedups of MXFP8 MoE compared to BF16 (normalized to 1.0).

我们的大部分改进来自从 BF16 过渡到MXFP8,而且几乎没有损失训练质量。但这也让我们认识到,降到低精度说起来容易做起来难。如果处理不当,由于各种内核开销,采用 MXFP8 训练相较于 BF16 可能只能带来极小的性能提升。并且 MXFP8 的训练配方并未广泛分享,这意味着你需要自己摸索出正确的方法。

以下内容介绍我们的方法,以及我们在 Cursor 所进行的机器学习工作。

Microscaling(MX)数据格式快速介绍

降低大型深度学习模型计算成本的常见方法是使用较低精度的激活和权重。然而,若将它们转换为更窄的位宽格式(例如 8 位或更少),除非对数值进行适当缩放,否则会引入不可接受的舍入误差。比如,大模型中的某些权重可能为 0.0001、0.0005 或 –0.0007,但如果直接转换为 FP8,这些数都会被舍入为同一个数:0,因为在 FP8E4M3 中可表示的最小正值约为 0.0019。

为了解决这一问题,通常会对每个张量应用一个缩放因子(per-tensor scaling factor),在不超出目标数据格式可表示范围的前提下对张量进行重缩放,从而充分利用可用的动态范围。举例来说,如果某个张量的取值都在 –0.0001 到 0.0001 之间,其缩放因子可以设为 4,480,000。这样缩放后的取值范围将变为 [–448, 448],与 FP8E4M3 的可表示边界相对应。

Microscaling 通过将缩放应用于张量的细粒度子块(而非对整个张量使用单一缩放因子)更进一步。Microscaling(MX)格式通过定义一组低精度、微缩放的数据格式来规范这一方法。其规范提供了完整细节,并定义了以下具体的 MX 兼容格式:

Table 1. Concrete MX-compliant formats.
Table 1. Concrete MX-compliant formats.

例如,如果我们使用 MXFP8 格式并选择 E4M3 元素类型,数据将由 FP8E4M3 元素组成,并对每个连续的 32 个元素块应用 FP8E8M0 缩放因子,如下图所示。

Figure 2. MXFP8 quantization example: Each 1x32 block shares a scale.
Figure 2. MXFP8 quantization example: Each 1x32 block shares a scale.

尽管微缩比例(microscaling)能够带来显著的性能提升,但在实际应用中会引入诸多挑战,而且高度依赖底层硬件。我们先来看一看,为什么在 NVIDIA Blackwell GPU 上应用微缩比例会特别具有挑战性。

1. Tensor 内存与 CUDA 核心扼杀了反量化的节奏

在微尺度的 FP8 矩阵乘法中,计算会沿归约维度被拆分为更小的分块步骤。每次完成一个分块的矩阵乘法后,部分结果会使用缩放因子进行反量化,然后再进行累加,之后继续处理下一个分块。

例如,在广为人知的DeepSeek V3 (DSV3)技术报告中,AA矩阵按1×1281 \times 128的块进行缩放,BB矩阵按128×128128 \times 128的块进行缩放。这意味着步骤如下:

1. 执行一次 128 块矩阵乘法:

C_block = A[:, k*128:(k+1)*128] @ B[k*128:(k+1)*128, :]

2. 使用缩放因子进行反量化:

C_block = A_scale[:, k] * C_block * B_scale[k, :]

其中,A_scale 在 K 维度上广播到 128 个元素,B_scale 在两个维度上各广播到 128 个元素

3. 累积反量化后的块:

C += C_block

4. 沿着归约维度继续到下一个块:

k += 1

这种方法在 Hopper 架构上(DSV3 在其上进行训练)自然表现出色,原因是:(1) 张量核心矩阵乘法(通过 wgmma 指令)的结果累积在寄存器中;(2) 你可以对矩阵乘法进行流水化,在用 CUDA 核心执行反量化的同时,异步启动其他张量核心的矩阵乘法。由于所有结果都累积在寄存器中,因此在矩阵乘法之间无需额外的数据搬移。

由于tensor memory(TMEM),在 Blackwell GPU 上情况已不再如此。TMEM 是 Blackwell GPU 新增的一组片上、SM 本地内存。与 Hopper GPU 不同,后者会将张量核心的矩阵乘累加结果直接保存在寄存器中,Blackwell GPU 则将其累加到 TMEM 中(通过 tcgen05.mma 指令)。若要对累加器执行自定义算术运算,必须先将结果从 TMEM 传输到寄存器,用 CUDA 核心处理后再写回 TMEM,且由于这些数据移动是异步的,还需要等待所有这些指令完成。尽管 TMEM 的访问速度快于共享内存(SMEM),这仍会降低张量核心的占用率。

Figure 3. Gantt chart taken from our custom Blackwell attention kernel. First row shows the tensor core activity (QK<sup>T</sup>). Second row shows the CUDA core activity (TMEM → registers, then softmax). TMEM → register latency causes the tensor core pipeline bubble.
Figure 3. Gantt chart taken from our custom Blackwell attention kernel. First row shows the tensor core activity (QKT). Second row shows the CUDA core activity (TMEM → registers, then softmax). TMEM → register latency causes the tensor core pipeline bubble.

乍一看,这提示了一种流水线式的方法:将 TMEM 拆分为若干块,在对其中一块执行矩阵乘法的同时,另一块进行数据搬运和反量化。

但还有另一个问题:虽然 Blackwell 张量核心的 TFLOP/s 相比 Hopper 翻了一倍,但 FP32 CUDA 核心仅提升了约 33%(60 → 80 TFLOP/s)。在 32 分块缩放的情况下,反量化所需的计算量是矩阵乘法的 1/32(两者都涉及乘法与加法;使用 CUDA 核心进行反量化时还必须手动累加),但反量化的速度却只有矩阵乘法的 1/56(理论值为 4,500 TFLOP/s)。结果是,在 CUDA 核心上执行反量化所耗时间几乎是矩阵乘法的 1.76 倍。这太糟糕了。

Hopper(H100 SXM5)Blackwell(B200)

FP8 Tensor Core 吞吐量

1,979 TFLOP/s

4,500 TFLOP/s

FP32 CUDA Core 吞吐量

60 TFLOP/s

80 TFLOP/s

32 组块去量化时间

(相对于矩阵乘法)

1.03×

1.76×

表 2。Hopper 与 Blackwell 的相对反量化成本。

从实证来看,即便采用上述方法的各种变体,我们也无法超过 Hopper 实际可达的 FP8 吞吐量 1,500 TFLOP/s。而且这还没有把量化开销考虑在内。

2. 千刀万剐式的量化

在将矩阵送入 FP8 矩阵乘法内核之前,必须先对每个矩阵进行 FP8 量化。尽可能减少并隐藏量化所需时间至关重要。若处理不当,量化内核会占用并主导 GPU 运行时间。

为了理解量化可以有多主导性,让我们考虑一个简单的矩阵乘法C=A×BC = A \times B,其中:

AAM×KM \times K

BBK×NK \times N

CCM×NM \times N

在用于 MoE 训练的分组矩阵乘法中,MM通常远大于KKNN。因此我们令:

M=131,072M = 131,072

K=7,168K = 7,168

N=2,048N = 2,048

由此,完成矩阵乘法本身所需的浮点运算总数为:

131,072×7168×2048×2FLOP=3.85TFLOP131{,}072 \times 7168 \times 2048 \times 2 \,\text{FLOP} = 3.85 \,\text{TFLOP}

我们的基准测试显示,Blackwell GPU(B200)在 FP8 矩阵乘法上的吞吐量约为 3,300 TFLOP/s。这意味着预计执行该矩阵乘法的实际运行时间为:

3.85 TFLOP / 3,300 TFLOP/s=1.16 毫秒3.85 \text{ TFLOP } / \text{ } 3{,}300 \text{ TFLOP/s} = 1.16 \text{ 毫秒}

不过,我们还需要考虑将矩阵 A 和 B 量化所需的时间。由于量化受到内存带宽的限制,关键因素是数据移动的总量。假设原始矩阵为 BF16(2 字节),且缩放块大小为 32。这意味着我们必须:

  • 负载 A:131,072×7,168×2B=1.88 GB131{,}072 \times 7{,}168 \times 2\text{B} = 1.88 \text{ GB}
  • 负载 B:7,168×2,048×2B=0.029 GB7{,}168 \times 2{,}048 \times 2\text{B} = 0.029 \text{ GB}
  • 存储量化后的 A:131,072×7,168×1B=0.94 GB131{,}072 \times 7{,}168 \times 1\text{B} = 0.94 \text{ GB}
  • 存储量化后的 B:7,168×2,048×1B=0.015 GB7{,}168 \times 2,048 \times 1\text{B} = 0.015 \text{ GB}
  • 为 A 存储缩放因子:131,072×7,168×132×1B=0.029 GB131{,}072 \times 7{,}168 \times \frac{1}{32} \times 1\text{B} = 0.029 \text{ GB}
  • 为 B 存储缩放因子:7,168×2,048×132×1B=0.0005 GB7{,}168 \times 2{,}048 \times \frac{1}{32} \times 1\text{B} = 0.0005 \text{ GB}

总体而言,这大约涉及 2.9 GB 的高带宽内存(HBM)读写。假设在 B200 上可持续的 HBM 吞吐量为 6.5 TB/s,那么一个优化的 FP8 量化内核将需要:

2.9 GB / 6.5 TB/s0.44 ms2.9 \text{ GB } / \text{ } 6.5 \text{ TB/s} \approx 0.44 \text{ ms}

这几乎占据了矩阵乘法时间的 40%。在一个常见场景中,我们还会对 A 矩阵进行转置并量化以用于反向传播,这会使时间翻倍到 0.88 ms,约占矩阵乘法时间的 76%

尽管理论上 FP8 矩阵乘法相比 BF16 矩阵乘法快两倍,但量化所耗时间可能彻底抵消这部分性能提升。以上分析甚至还是偏乐观的,因为它假设块缩放的 FP8 矩阵乘法能以 3,300 TFLOP/s 的速度运行,这远高于端到端 MoE 训练中通常能达到的水平。在最糟情况下,量化所需时间可能比 FP8 矩阵乘法更长,从而使整体 FP8 计算反而比 BF16 更慢。

确实存在一些快速的开源 FP8 量化内核,例如来自 NVIDIA 的TransformerEngine或 Pytorch 的TorchAO。然而,我们的微基准测试显示,它们未充分利用带宽,运行速度约为 4.5 TB/s。若纯粹依赖 SM 内部指令(例如 cp.async.bulk),我们知道 Blackwell 可以轻松达到 6~7 TB/s 的 HBM 带宽。

此外,NVIDIA 的 MXFP8 tcgen05.mma 指令(张量核心矩阵乘法的 PTX 指令)需要一种略显反直觉的缩放因子布局。在 32-block 缩放下,TransformerEngine 或 TorchAO 量化内核会以朴素的 M x N / 32 布局返回缩放因子。随后必须进行重排,或者在 PyTorch 中完成,或者将重排逻辑融合进其他内核,这两种方式都会负面影响性能。实际中,你并不希望在矩阵乘法内核内处理缩放因子。加载它们最快的方式是走 HBM → SMEM(cp.async.bulk)路径,然后再走 SMEM → TMEM(tcgen05.cp)路径;一旦缩放因子绕行通过寄存器 tile,tensor 的节奏就没了。

接下来,我们将从量化方案入手,说明我们如何解决上述挑战。

选择合适的低精度方案

为了达到与 BF16 训练质量相匹配的效果,我们进行了一系列低精度实验,评估各方案与 BF16 的偏离程度。基于结果,我们找出了在我们的工作负载下,训练损失收敛几乎与 BF16 相同的方案。

具体而言,我们使用 MXFP8 格式,其中元素数据类型为 FP8E4M3(4 位指数、3 位尾数),尺度(scale)数据类型为 FPE8M0(8 位指数),缩放块大小为 32。我们还采用论文“Recipes for Pre-training LLMs with MXFP8”中的 MXFP8 量化方案。令:

  • BF16(或 FP32)向量V={Vii=0,,31}V = \{ V_i \mid i = 0, \ldots, 31 \}
  • 对应的 FP8E4M3 向量Q={Qii=0,,31}Q = \{ Q_i \mid i = 0, \ldots, 31 \}
  • FP8E8M0 规模SS

我们按如下方式计算QQSS

S=cast_to_fp8e8m0(absolute_max(V)/448)S = \text{\texttt{cast\_to\_fp8e8m0}}\big(\text{\texttt{absolute\_max}}(V) / 448\big)
Qi=cast_to_fp8e4m3(Vi/S)Q_i = \text{\texttt{cast\_to\_fp8e4m3}}\big(V_i / S\big)

其中,cast_to_fp8e8m0向最近的 2 的幂进行取整,并将最小值钳制为21272^{-127};而 cast_to_fp8e4m3对超出范围的数进行饱和处理,并四舍五入到最接近值,遇到中点时取偶。

由此,我们的 FP8 训练损失与 BF16 训练损失相匹配:

Figure 4. BF16 vs MXFP8 Training Loss over 10k steps: nearly indistinguishable.
Figure 4. BF16 vs MXFP8 Training Loss over 10k steps: nearly indistinguishable.
Figure 5. BF16 vs MXFP8 Training Loss (9k–10k steps) with a specific data point at step 9832.
Figure 5. BF16 vs MXFP8 Training Loss (9k–10k steps) with a specific data point at step 9832.

采用 tcgen05 MXFP8 分块缩放的矩阵乘法

在 NVIDIA Blackwell 架构上,MX 格式被固化在张量核心中。通过 tcgen05.mma...block_scale 指令启用块缩放,并由硬件处理。由于一切都在张量核心的矩阵乘法过程中完成,这就无需将数据从 TMEM 中移出进行反量化。因此,我们的 MXFP8 矩阵乘法内核设计必须围绕 tcgen05.mma 展开,并在其内部以及约束条件内进行,以获得最大性能。

这里有几个重要考量。首先,tcgen05.mma 指令只需要单个线程即可异步发起;这与 Hopper 形成对比,后者的 wgmma 指令需要整个 warpgroup(128 个线程)才能异步发起。因此,我们不得不偏离 Hopper 内核中常见的生产者-消费者模式,在那种模式下有 256+ 个线程专用于发起矩阵乘法。

其次,tcgen05.mma 支持 2-CTA 矩阵乘法:两个 SM 通过共享 B 矩阵协同执行一次矩阵乘法。这既减少了内存流量,也降低了共享内存占用,从而允许更深的矩阵乘法流水线。我们的基准测试表明,相较于非集群版本,这种 2-CTA 模式在 MXFP8 矩阵乘法中可带来约 15~20% 的加速,对追求峰值性能至关重要。

第三,如前所述,tcgen05.mma 会将结果累积在 TMEM 中,而不是寄存器中。这样可以降低寄存器压力,但也会通过 tcgen05.ldtcgen05.st 指令在 TMEM 与寄存器之间引入额外的数据移动。我们必须确保将这些数据移动降到最低。

最后,tcgen05.mma 指令要求缩放因子驻留在 TMEM 中。然而,没有直接的方法可以将缩放值从 HBM 加载到 TMEM。最快的方法是先使用 cp.async.bulk.tensor 指令(利用 Tensor Memory Accelerator,简称 TMA)将数据从 HBM 加载到片上 SMEM,然后再使用 tcgen05.cp 指令将其从 SMEM 传输到 TMEM。要使这一流程可行,缩放因子必须按照 tcgen05.mma 所需的布局进行存储,本文稍后会进行说明。

上述这些考虑中涉及的所有指令 — tcgen05.mmacp.async.bulk.tensortcgen05.cptcgen05.ldtcgen05.st — 都由单个线程异步发起。这使我们能够应用 warp 专用化,并设计流水线式数据流,使用 TMEM 和 SMEM 作为环形缓冲区。

为此,我们首先对 TMEM 和 SMEM 进行分区。在 Blackwell 上,我们可用 128x512 的 TMEM(每个单元 32 位)以及每个 threadblock 连续的 227 KB SMEM。我们将 TMEM 划分为用于 A 和 B 缩放因子存储的 5 个槽位,同时为矩阵乘法累加(MMA)保留空间。类似地,在 SMEM 中我们预留空间用于将 MMA 结果回写到 HBM,并将其余部分划分为 5 个槽位,用于加载输入 tile 和缩放因子。如下图所示为布局示意。

Figure 6. Simplified TMEM allocation: accumulator region plus 5 slots each for A and B scales
Figure 6. Simplified TMEM allocation: accumulator region plus 5 slots each for A and B scales
Figure 7. Simplified SMEM allocation: 5 slots reserved for input tiles and scales
Figure 7. Simplified SMEM allocation: 5 slots reserved for input tiles and scales

在这种设置下,我们设计了一条流水线:部分 warp 持续从 HBM 将输入 tile 和 scale 加载并搬运到 SMEM,另一些将 scale 从 SMEM 移动到 TMEM,另外一些发起 MMA 运算,还有一些会不时将 TMEM 累加器加载到寄存器中、存入 SMEM,并通过 TMA 将其写回 HBM。

Figure 8. Simplified MXFP8 Matrix Multiplication Pipeline
Figure 8. Simplified MXFP8 Matrix Multiplication Pipeline

具体来说,我们为每个 threadblock 分配 3 个 warpgroup(384 个线程),并将这些 warpgroup 划分为两类:其中 2 个 warpgroup 仅执行 TMEM → 寄存器 → SMEM → HBM 的数据流动,这会带来最高的寄存器压力。另一个 warpgroup 执行 warp-specialization:warp 0 将输入 tile 从 HBM 加载到 SMEM,warp 1 将 scale 从 HBM 加载到 SMEM,warp 2 将 scale 从 SMEM 加载到 TMEM,warp 3 触发张量核心的矩阵乘操作。我们还实现了持久网格模式(persistent grid pattern),为每个 SM(Blackwell GPU 上为 148 个)分配单个 threadblock,使得在将结果写回 HBM 的同时可以加载新的输入 tile。

该流水线的伪代码如下所示:

if (warpgroup_id < 2) {
    for (int i = 0; i < num_tiles; i++) {
        mbarrier_wait_for_final_matmul_completion(); // mbarrier.try_wait
        async_load_from_TMEM(reg, TMEM); // tcgen05.ld
        wait_for_load_completion(); // tcgen05.wait
        // This is iterated in the actual implementation to save SMEM 
        store_to_SMEM(SMEM, reg);
        TMA_async_store(HBM, SMEM); // cp.async.bulk.tensor
    }
} else {
    if (warp_id == 0) {
        for (int i = 0; i < num_tiles; i++) {
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_matmul_completion();
                // load input tiles
                TMA_async_load(SMEM, HBM);
            }
        }
    } else if (warp_id == 1) {
        for (int i = 0; i < num_tiles; i++) {
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_tcgen05_cp_completion();
                // load scales (HBM -> SMEM)
                TMA_load(SMEM, HBM);
            }
        }
    } else if (warp_id == 2) {
        for (int i = 0; i < num_tiles; i++) {
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_matmul_completion();
                mbarrier_wait_for_scale_SMEM_load();
                // load scales (SMEM -> TMEM)
                load_to_TMEM(TMEM, SMEM); // tcgen05.cp
            }
        }
    } else if (cta_rank == 0) { // 2-CTA MMA is launched by a single CTA
        for (int i = 0; i < num_tiles; i++) {
            mbarrier_wait_for_TMEM_clear();
            for (int j = 0; j < num_iters; j++) {
                mbarrier_wait_for_input_SMEM_load();
                mbarrier_wait_for_scale_TMEM_load();
                // tcgen05.mma.cta_group::2.mxf8f6f4.block_scale
                launch_matmul(SMEM, SMEM, TMEM);
            }
        }
    }
}

在 Blackwell GPU 上进行块尺度矩阵乘法时,一个不可避免的限制是 TMEM 的大小。我们的微基准测试显示,当将完整的 128x512 TMEM 用作累加器时,Blackwell 张量核心可实现最高吞吐量。对于 2-CTA 的 FP8 矩阵乘法,这对应于持续运行两个 256x32x256 的 tcgen05.mma 指令。每条 tcgen05.mma 会为每个 CTA 消耗一个 128x256 的 TMEM 区域,因此两条指令合在一起会在两个 CTA 之间完整占用 128x512 的 TMEM 阵列。

然而,当缩放因子也必须驻留在 TMEM 中时,我们在仅使用 128x256 的 TMEM 区域时,每次只能执行一条 256x32x256 的 tcgen05.mma 指令。其结果是性能下降不可避免。举例来说,在该约束下,16,384x16,384x16,384 的 FP8 矩阵乘法吞吐量将从 3,200 TFLOP/s 降至 3,040 TFLOP/s。

这些吞吐量数据仅适用于纯 FP8 矩阵乘法。使用 MXFP8 分块缩放时,由于 TMEM 流水线开销,吞吐量不可避免地进一步下降。实际中,在对 L2 缓存进行清空的条件下,我们在分块缩放的 MXFP8 矩阵乘法内核上可达到约 2,750 TFLOP/s。即便如此,这仍然比标准 BF16 矩阵乘法快约 1.83 倍;在最优形状下,BF16 通常能达到 1,500~1,550 TFLOP/s。开局还不错!

扩展至 MXFP8 分组矩阵乘法

一个独立的 MXFP8 矩阵乘法内核是有用的第一步,但在 MXFP8 的 MoE 训练中,它的应用受限(例如共享专家的场景)。要在 MXFP8 中全面支持 MoE,我们需要分组矩阵乘法内核,具体而言:

  1. 分组前向传播(Fprop)/ 数据梯度(Dgrad)
  2. 分组权重梯度(Wgrad)

请注意,这些变体也是我们从零开始构建这些内核的部分原因。迄今为止,我们尚未发现任何开源替代方案能够完全支持采用 32-block 缩放的 MXFP8 MoE 训练。

在内核层面,分组的 Fprop 和 Dgrad 具有相同的结构。唯一的区别是 Dgrad 需要进行累加,因为输入张量会同时经过 up 投影和 gate 投影。但这可以很容易地通过将 cp.async.bulk.tensor 指令替换为 cp.reduce.async.bulk.tensor 来实现,后者可以异步对 HBM 执行原子加并存储(add-store)。

给定如下矩阵:

  • Anum_tokens x in_dim
  • WE × in_dim × out_dim,其中 E 为当前 rank 上的专家数量
  • 输出num_tokens x out_dim

并假设 tokens 按照专家索引排序,分组的 Fprop/Dgrad 内核执行:

for i in range(num_routed_experts):
    start = 0 if i == 0 else end
    end = start + assigned_tokens_per_expert[i]
    O[start:end, :] = A[start:end, :] @ W[i, :, :]

另一方面,Grouped Wgrad 的不同之处在于,其专家划分发生在 K(归约)轴上而非 M 轴上。它计算:

for i in range(num_routed_experts):
    start = 0 if i == 0 else end
    end = start + assigned_tokens_per_expert[i]
    W_grad[i, :, :] = A.T[:, start:end] @ O[start:end, :]

内核抽象

在内核层面,通用的工作单元是针对指定的行、列与归约范围执行矩阵乘-累加操作。将这一单元抽象化对于实现分组矩阵乘法内核极为有用,并使我们能以最小改动复用原有的 MXFP8 矩阵乘法内核。为此,我们将原始内核抽离如下:

// All units of 128
int expert_idx = ... // 0 for 2D case
int row_block_idx = ...
int col_block_idx = ...
int reduction_block_start_idx = ...
int reduction_block_end_idx = ...

// Based on MXFP8 matrix multiplication implemented above
// Runs at 256x128 granularity when possible (if not, 128x128)
run_mxfp8_matmul(
    expert_idx,
    row_block_idx, 
    col_block_idx, 
    reduction_block_start_idx, 
    reduction_block_end_idx
);

借助这一抽象,实现分组的 MXFP8 矩阵乘法变体只需在启动上述抽象之前设计合适的循环结构并正确分配索引。然而,天真的循环方式会因 L2 缓存利用率低而显著降低性能。

通过专家分组超聚合进行 L2 缓存优化

保持高 L2 缓存利用率至关重要。在我们的基准测试中,低效的 HBM 访问模式会使分组矩阵乘法内核的性能下降近 50%。为解决这一问题,我们采用了supergrouping来自 ThunderKittens 内核的启发式方法),通过确保在任意时刻由全部 148 个 SM 计算的输出矩阵区域尽可能接近正方形,从而最大化 L2 复用。你可以通过阅读上方链接的内核代码了解更多信息。

我们对分组矩阵乘法内核的一项关键改进是在每个专家的粒度上应用超级分组,只考虑当前专家对应的子矩阵,而不是整个输出矩阵。该方法对分组 Wgrad 尤其有效,因为由于专家划分,归约轴通常较窄。归约轴较窄会降低张量核心的利用率,使内存带宽成为主要瓶颈。

借助合适的矩阵乘法内核抽象和按专家优化的 L2 缓存,我们在分组 MXFP8 矩阵乘法内核上实现了约 2,650 TFLOP/s——相比非分组版本仅下降 4%。太棒了!

分组矩阵乘法基准测试

在 Blackwell 上,用于分组 Fprop/Dgrad/Wgrad 的最接近的开源替代方案是 DeepSeek 的 DeepGEMM。DeepGEMM 的不同之处在于其对 A 与 B 矩阵分别采用 1x128 和 128x128 的分块尺度,这会以精度降低为代价。尽管如此,它仍是唯一可用的替代方案,因此我们将 DeepGEMM 集成到内部模型训练中并对其性能进行了分析。虽然 DeepGEMM 在我们的微基准测试中对某些输入形状表现突出,但端到端基准结果显示,在我们的工作负载上,我们的内核性能优于它。

DeepSeek DeepGEMM我们的

分组 Fprop / Dgrad

0.67 ms

0.43 ms

分组 Wgrad

0.71 ms

0.65 ms

表 3。内部模型训练期间分组 Fprop/Dgrad 与 Wgrad 内核的平均时延。

同样需要注意的是,以上基准测试不包含量化时间。DeepGEMM 未提供优化的量化内核;缺少这一步时,端到端性能会迅速下降。最糟情况下,甚至可能比 BF16 训练更慢。

打造史上最快的 MXFP8 量化内核

如前所述,现有的 MXFP8 量化内核不仅效果不佳,还需要重排以匹配 tcgen05.mma 的尺度因子布局,从而带来额外的运行时开销。因此,我们的目标是设计一个能够:

  • 充分饱和内存带宽,优选在大型 M 维度形状下超过 6 TB/s。
  • 生成尺度矩阵以完全相同的布局为……所需tcgen05.mma,使我们的 MXFP8 内核能够在无需中间转换的情况下执行 HBM → SMEM → TMEM 数据流。

事实证明,并不存在一个能瞬间带来峰值性能的“魔法”设计;真正奏效的是选对基础原语并通过反复试验摸索出的诸多微优化。收益最大的是:去掉 TMA swizzling,改用手动 swizzling 模式以降低 warp 内部开销;依赖 warp 调度器与线程块之间的异步性,而非在单个线程块内手工重叠;并尽量减少 SMEM 与寄存器占用以提升 SM 占用率。除此之外,还有一系列标准优化,比如使用 TMA、消除 SMEM 银行冲突、利用快速向量内建函数等。但这些其实很难被普遍化。

通过这些优化,我们实现了一个 MXFP8 量化内核,在生成与 tcgen05.mma 直接兼容布局的缩放矩阵的同时,持续达到 6.2+ TB/s 的带宽表现。据我们所知,这是当前用于 MoE 训练的最快 MXFP8 量化内核。

NVIDIA TransformerEnginePyTorch TorchAO我们的

朴素

5236.35 GB/s

5245.15 GB/s

不适用

使用 reshape

4430.27 GB/s

4524.45 GB/s

6212.21 GB/s

表 4. 按内存带宽利用率比较 MXFP8 量化内核(E4M3,32 块缩放)。“包含重塑”包括为 tcgen05.mma 重塑缩放因子的时间。

对 MXFP8 量化的其他优化包括:构建一个融合的量化内核,可同时输出未转置和转置结果(后者用于反向传播);并将 MXFP8 量化逻辑直接融合进其他内核,尽可能减少对 HBM 的访问。例如,我们在融合的 MXFP8 SwiGLU 内核的序言和尾声阶段挂接了 MXFP8 的反量化/量化逻辑。

速度提升!

凭借上述所有优化,我们在保持相同训练质量的同时,达成了MoE 层前向与反向传递均提速 3.5 倍。在我们的一款内部模型上测试,这在 Blackwell 上带来了端到端训练 1.5 倍的加速,并且相较于我们原先的 Hopper 配置实现了 2 倍的加速,以上以每块 GPU 的每秒处理 token 数(TPS/GPU)计量。我们相信,与任何开源替代组合相比,我们的技术栈能够以更快速度完成 MXFP8 的 MoE 训练。

Figure 9. End-to-end training TPS per GPU (internal model).
Figure 9. End-to-end training TPS per GPU (internal model).
Hopper BF16Blackwell BF16Blackwell MXFP8

MoE 前向(毫秒)

32.36 毫秒

25.96 毫秒

9.45 毫秒

MoE 反向传播(毫秒)

63.24 ms

59.17 ms

17.04 ms

端到端(TPS/GPU)

12k TPS/GPU

16k TPS/GPU

24k TPS/GPU

表 5. 性能对比(内部模型)。注意:Blackwell BF16 MoE 是从 Hopper 代码直接移植而来

结语

我们仍有大量工作要完成,许多内核也有待优化。我们目前专注于构建更高效的多 GPU 通信、改进自定义注意力内核,并为未来的 MoE 训练切换到 FP4 做准备。

如果你对编写和优化高性能内核、训练大型代码模型,或更广泛地对我们的工作感兴趣,欢迎与我们联系。请发送邮件至hiring@anysphere.co

特别感谢 Benjamin Spector、Sasha Rush、Less Wright、Federico Cassano、Rohan Shah、Jacob Jackson、Sujay Jayakar 和 Shengtong Zhang 阅读此文并提供宝贵反馈。

归类于: 研究

作者: Stuart Sul