用 warp decode 实现更好的 MoE 模型推理
通过翻转并行性轴,我们将 MoE 模型推理提速 1.8 倍,同时提升精度。
大多数 MoE 推理系统都会围绕专家来组织 token 生成路径。这与路由机制的工作方式一致,也一直是大规模场景下的标准做法。然而,在 Blackwell GPU 上进行小批量解码时,我们发现,围绕输出而非专家来组织 kernel 效果更好。我们将这种方法称为“warp decode”。
warp decode 的提出,源于我们对 Blackwell 上 MoE 解码理论最大可达内存带宽的思考。由此,我们彻底翻转了并行性轴。不再将 warp 分配给专家,而是让每个 warp 对应一个输出值 (神经元) 。
能够同时提升性能和精度的 kernel 并不多见,warp decode 就是其中之一。在 Blackwell 上,它将吞吐量提升了 1.84 倍,同时也提高了精度,使输出结果与完整 FP32 参考值相比接近了 1.4 倍。这加快了 Composer 的研究与训练流程,让我们能够更快改进模型,并更频繁地发布新版本。
传统的 MoE 路径
现代 MoE 模型会将每个 token 路由到一部分专用的专家网络中,例如在某一层从 128 个专家中选择 8 个。标准实现会围绕这些专家来组织全部计算:先收集每个专家所需的 token,完成计算,再将结果重新组装起来。
这在 prefill 和大批量场景下效果很好,因为每个专家上的共享工作足以摊薄整理数据的额外开销。但在自回归解码阶段,由于一次只生成一个 token,就没有足够的共享工作来支撑这种做法。传统路径中的八个阶段里,有五个阶段纯粹是为了管理以专家为中心视角下的数据布局,本身并不进行任何实际计算。
我们做了哪些改变
Warp decode 不再围绕专家,而是围绕输出重新组织并行计算,从而省去了那五个“簿记”步骤。
现代 GPU 会以由 32 条并行处理通道组成的组来执行指令,这样的一组称为一个 warp。在我们的新方案中,每个 warp 都只负责计算一个输出值。warp 会直接从内存中流式读取所需的权重数据,将所有 8 个路由专家的结果汇总到一个持续累加的总值中,最后写出一个结果。
这种 warp 级别的独立性让 warp decode 无需任何暂存、切换、跨 warp 同步点或中间缓冲区即可运行。整个 MoE 计算层被压缩为两个 kernel:moe_gate_up_3d_batched 和 moe_down_3d_batched。
两个 kernel 如何工作
在 gate/up kernel 中,每个 cooperative thread array (CTA) 由 8 个 warp 组成,每个 warp 负责一个 token 与一个路由专家配对对应的中间神经元。该 warp 会加载路由专家 ID,读取该神经元对应的 gate 和 up 权重行,并流式扫描输入激活向量。MXFP8 权重会在计算过程中即时转换为 FP32,两个点积都累加到私有寄存器中。
由于这两个 kernel 融合在同一次计算中,激活向量只需读取一次,就能立即复用于两个投影,完全不需要共享内存暂存。在完成 warp 级归约后,warp 会应用 SiLU(gate) × up,并写出一个中间值。
在 down kernel 中,每个 warp 负责一个 token 的一个输出维度。它会遍历所有 top-k 路由专家,加载对应的 down-projection 权重行,并流式处理中间激活,同时将每个 专家 的路由权重累积到同一个持续更新的 FP32 累加器中。
在处理完所有 专家 后,我们会使用 __shfl_xor_sync 做 warp 级蝶形归约,将 32 个 lane 的局部部分和归约起来。这会直接编译为 PTX shfl.sync.bfly 指令。它是一条单一的硬件原语,用于在 warp 内各个 lane 之间交换寄存器,完全绕过共享内存。
这样做的好处在于,我们不再需要 L1 往返、bank 冲突或显式屏障,因为同步已经通过 lane mask 内置在这条指令里。最终的加权 top-k 组合也不再是单独的 epilogue,而是直接成为投影本身的一部分。
warp decode 中的每个 warp 都彼此独立,并且在整个生命周期内始终承担一个固定且稳定的任务:生成一个输出标量。正是这种 warp 独立性,消除了传统路径所需的共享内存暂存、跨 warp 同步以及中间缓冲区。
流水线简化与加速
Warp decode 主要通过两种机制提升性能:一是去掉传统路径所需的阶段和缓冲区,二是实现 warp 的独立性,从而带来更优的调度效果和更好的延迟隐藏能力。
阶段消除
阶段消除带来了绝大部分吞吐量提升。我们去掉了填充、分发以及 combine 步骤。要移除这些阶段,需要从底层重新组织并行方式,而不只是简单地合并传统流水线中的各个阶段。
消除填充
传统路径:将每个 专家 的 token 列表填充到 2 的幂或 128 字节边界,以满足分组 kernel 的要求。在解码阶段,如果只有单个 token,这种开销就无法摊销。
warp decode 路径:由于完全不会形成按 专家 划分的批次,因此彻底避免了这类开销。
消除 scatter 和 combine 操作
传统路径: 每个专家完成后,都会将八个中间结果写入 GPU 内存,然后再运行一个单独的归约步骤将它们combine。
warp decode 路径: 每个专家的路由权重会在 warp 内直接并入正在累加的结果中。这八个中间结果不会真正落到内存里,因此也省去了后续一次归约遍历的写入和读取开销。
消除缓冲区
这种重组还消除了传统路径由于其以 专家 为中心的布局而需要的两个中间内存缓冲区。
第一个是激活收集缓冲区,即输入激活向量会被复制并重新排列为以 专家 为主的布局。在批大小为 1 时,这实际上是对已存在数据的一次完整复制。第二个是每个 专家 的输出缓冲区。对于 8 个 专家 和 2048 的隐藏维度,在 BF16 下,这意味着每个 token 需要 8 × 2048 × 2 bytes = 32 KB,先分配、写入,随后立即读取一次,然后丢弃。
Warp decode 通过在 32 条 warp lane 上将这 8 个 专家 的贡献汇总到一个寄存器累加器中,消除了这两个缓冲区;在最终写出单个标量之前,数据都不会进入全局内存。每个 token 减少 32+ KB 的中间缓冲区传输后,就能为真正决定性能的权重行释放更多 L2 缓存容量。
Warp 独立性
这种重组也让保留下来的计算更快,因为该 kernel 在设计上就是“天然易于并行”的:每个 warp 都与其他 warp 完全独立。由于每个 warp 恰好只负责一个输出标量,并且只读取自己所需的权重行,因此 warp 之间不存在共享的可变状态。
在单个 warp 的层面上,这种独立性是完全的。输入激活值是只读的,累加器保存在私有寄存器中,输出则写入唯一的地址。从硬件调度器的视角来看,整个输出维度就是一组扁平排列、彼此独立的工作项。
GPU 的 warp 调度器可以在任何时间、按任意顺序调度任意 warp,而不受正确性约束。当某个 warp 因等待内存加载而停顿时,调度器会立即切换到另一个 warp。在 B200 的 148 个流式多处理器上,同时运行着数千个 warp,因此内存延迟几乎完全被其他 warp 的有效计算掩盖。
该 kernel 还具备线性扩展性:输出维度翻倍时,独立 warp 的数量也随之翻倍,且无需额外同步。token 批次维度也是如此,因此调度器看到的是一个扁平的工作命名空间,节点之间没有任何依赖边。这与传统路径形成鲜明对比,后者中的专家级 GEMM kernel 需要块内协同。
结果
大规模下的端到端解码吞吐量
在我们的内部推理系统上,对运行于 NVIDIA B200 GPU 的 Qwen-3 风格模型进行测试后,结果显示吞吐量获得了稳定的吞吐量提升。吞吐量提升在所有上下文长度分组中都基本一致,说明这是一项纯粹发生在生成阶段的改进,不依赖于提示长度。
精度提升
移除中间激活量化步骤,会对输出质量带来可量化的影响。将 BF16 激活转换为 MXFP8 再转回 BF16,会引入舍入误差下限,并在模型各层中不断累积。Warp decode 在整个过程中都将激活保持为 BF16,并将累加器保持为 FP32,因此归约操作不会基于已退化的输入进行。最终,与传统路径相比,warp decode 生成的输出与完整 32 位真值的接近程度提升了 1.4 倍。
硬件效率
我们在开发 warp decode 时,首先考虑的是:它究竟能在多大程度上逼近硬件的最大吞吐量。B200 在连续内存读取上的实测峰值为 6.8 TB/s (通过 copy kernel 测得) 。在 B=32 时,warp decode 可稳定达到 3.95 TB/s,相当于该峰值的 58%。剩余的差距很可能反映了 专家 路由 所带来的随机访问模式的内存延迟开销,因为每个 token 都可能被路由到并不相邻的 专家,例如 5、8、14、19 等。
相比之下,峰值吞吐量是通过连续的 (0,1,2,3) 内存读取测得的。与 reference implementation 对比时,在所有 批大小 下的正确性都非常高:最小余弦相似度 > 0.999996,最大绝对差异为 0.001953。
Warp decode 与 Composer 训练
Warp decode 并不能普遍取代以专家为中心的执行方式。对于 prefill 和大批量推理这类高吞吐量工作负载,以专家为中心的打包方式仍然更有优势,因为许多 token 会共享同一个专家,而整理这些 token 的成本能够被足够多的实际计算摊薄,因此是划算的。
当每个专家可共享的工作量不足以支撑这部分额外开销时,Warp decode 就会更具优势,而 MoE decode 往往正是这种情况。这也使它成为我们持续改进 Composer 的重要一环。虽然对预训练数据和 RL 的投入决定了模型输出的质量,但像 warp decode 这样的推理层投入,则决定了这些输出能够以多快、多准确的速度触达开发者。