Mejor inferencia de modelos MoE con warp decode
Al invertir el eje del paralelismo, logramos una inferencia de modelos MoE 1.8x más rápida y más precisa.
La mayoría de los sistemas de inferencia de MoE organizan la generación de tokens en torno a los expertos. Esto refleja cómo funciona el enrutamiento y ha sido el enfoque estándar a escala. Sin embargo, para la decodificación con lotes pequeños en GPUs Blackwell, descubrimos que organizar el kernel en torno a los resultados, en lugar de a los expertos, funciona mejor. A este enfoque lo llamamos “warp decode”.
Llegamos a warp decode al preguntarnos cuál es realmente el ancho de banda de memoria máximo alcanzable para la decodificación de MoE en Blackwell. Eso nos llevó a invertir por completo el eje del paralelismo. En vez de asignar warps a expertos, asignamos cada warp a un único valor de resultado (neurona).
Los kernels que mejoran tanto el rendimiento como la precisión son poco comunes, y warp decode es uno de ellos. En Blackwell, ofrece una mejora de 1.84x en el rendimiento y, además, mejora la precisión, con resultados 1.4x más cercanos a una referencia completa en FP32. Esto acelera el pipeline de investigación y entrenamiento de Composer, lo que nos permite mejorar el modelo más rápido y publicar nuevas versiones con mayor frecuencia.
La ruta convencional de MoE
Los modelos MoE modernos dirigen cada token a través de un subconjunto de redes expertas especializadas, seleccionando, por ejemplo, 8 de 128 en una capa determinada. La implementación estándar organiza todo el cómputo en torno a esos expertos: recopila los tokens que necesita cada experto, ejecuta los cálculos y vuelve a ensamblar los resultados.
Esto funciona bien para el prefill y los lotes grandes, donde el trabajo compartido por experto amortiza la sobrecarga de organizar los datos. Pero durante el paso de decodificación autorregresiva, en el que solo producimos un token cada vez, no hay suficiente trabajo compartido como para justificarlo. Cinco de las ocho etapas de la ruta tradicional existen únicamente para gestionar la disposición de los datos desde una perspectiva centrada en los expertos y no realizan ningún cómputo real.
Qué cambiamos
Warp decode elimina esos cinco pasos de “bookkeeping” al reorganizar el paralelismo en torno a los resultados en lugar de a los expertos.
Las GPU modernas ejecutan instrucciones en grupos de 32 carriles de procesamiento en paralelo, llamados warps. En nuestro nuevo enfoque, a cada warp se le asigna el cálculo de exactamente un valor de resultado. El warp lee directamente de la memoria los datos de pesos que necesita, agrega los totales de los ocho expertos enrutados en un único acumulado y escribe un único resultado.
Esta independencia de los warps permite que warp decode se ejecute sin etapas intermedias, traspasos, puntos de sincronización entre warps ni búferes intermedios. Toda la capa de cómputo de MoE queda reducida a dos kernels, moe_gate_up_3d_batched y moe_down_3d_batched.
Cómo funcionan los dos kernels
En el kernel de gate/up, cada cooperative thread array (CTA) tiene ocho warps, y cada warp se encarga de una neurona intermedia para cada combinación de un token y un experto enrutado. El warp carga el ID del experto enrutado, lee las filas de pesos de gate y up para esa neurona, y recorre en streaming el vector de activación de entrada. Los pesos MXFP8 se convierten a FP32 sobre la marcha, y ambos productos escalares se acumulan en registros privados.
Como los dos kernels están fusionados en una sola pasada, el vector de activación se lee una vez y se reutiliza de inmediato para ambas proyecciones, sin ninguna etapa intermedia en memoria compartida. Después de una reducción a nivel de warp, el warp aplica SiLU(gate) × up y escribe un valor intermedio.
En el down kernel, cada warp se encarga de una dimensión de salida para un token. Recorre todos los expertos enrutados del top-k, cargando la fila de pesos relevante de la proyección down y recorriendo en streaming las activaciones intermedias, mientras integra el peso de enrutamiento de cada experto en un único acumulador FP32.
Después de procesar todos los expertos, reducimos las 32 sumas parciales locales de cada carril con una reducción butterfly a nivel de warp usando __shfl_xor_sync. Esto se compila directamente a la instrucción PTX shfl.sync.bfly, una única primitiva de hardware que intercambia registros entre carriles dentro del warp, sin pasar por la memoria compartida.
La ventaja aquí es que no necesitamos accesos de ida y vuelta a L1, conflictos de bancos ni barreras explícitas, porque la sincronización viene integrada en la instrucción mediante la máscara de carriles. En lugar de un epílogo independiente, la combinación ponderada final del top-k pasa a formar parte de la propia proyección.
Cada warp en warp decode es independiente y recibe una única asignación estable durante toda su vida útil: producir un escalar de salida. Esta independencia de los warps es lo que elimina la etapa intermedia en memoria compartida, la sincronización entre warps y los búferes intermedios que requiere la ruta tradicional.
Simplificación y aceleración del pipeline
La warp decode logra mejoras de rendimiento mediante dos mecanismos principales: eliminar etapas y búferes que requería la ruta tradicional, y crear independencia entre warps, lo que permite una mejor planificación y ocultación de la latencia.
Eliminación de etapas
La eliminación de etapas aporta la mayor parte de la mejora de rendimiento. Eliminamos el relleno, la dispersión y la etapa de combinación. Suprimir estas etapas requiere reorganizar el paralelismo desde la base, en lugar de limitarse a fusionar etapas de la canalización tradicional.
Eliminación del relleno
Ruta tradicional: Rellena la lista de tokens de cada experto hasta alinearla a límites de potencia de 2 o de 128 bytes, para cumplir los requisitos de los kernels agrupados. Durante la decodificación con un solo token, esto supone una sobrecarga que no puede amortizarse.
Ruta de warp decode: Evita por completo esta sobrecarga al no formar nunca lotes por experto.
Eliminación de scatter y combinación
Ruta tradicional: Después de que cada experto termina, escribe ocho resultados intermedios en la memoria de la GPU y luego ejecuta un paso de reducción independiente para combinarlos.
Ruta de warp decode: El peso de enrutamiento de cada experto se integra en el acumulador dentro del warp. Los ocho resultados intermedios nunca llegan a materializarse en memoria, lo que evita tanto el coste de escritura como el de lectura de una pasada de reducción posterior.
Eliminación de búferes
La reorganización también elimina dos búferes de memoria intermedios que la ruta tradicional requiere como consecuencia de su diseño centrado en los expertos.
El primero es un búfer de recolección de activaciones: el vector de activación de entrada se copia y se reorganiza en un formato con los expertos como dimensión principal. Con un tamaño de lote de 1, esto supone una copia completa de datos que ya existen. El segundo es un búfer de resultado por experto. Con ocho expertos y una dimensión oculta de 2048, esto equivale a 8 × 2048 × 2 bytes = 32 KB por token en BF16, que se asignan, se escriben, se leen inmediatamente una vez y luego se descartan.
La warp decode elimina ambos al combinar las ocho contribuciones de los expertos en un acumulador en registros a lo largo de los 32 carriles del warp, donde nada llega a la memoria global hasta la escritura final de un único escalar. Eliminar más de 32 KB de tráfico de búferes intermedios por token libera capacidad de la caché L2 para las filas de pesos que realmente determinan el rendimiento.
Independencia de los warps
La reorganización también acelera el cómputo resultante porque, por diseño, el kernel es “trivialmente paralelizable”: cada warp es completamente independiente de los demás. Como cada warp tiene exactamente un escalar de salida y solo lee las filas de pesos que necesita, no existe estado mutable compartido entre warps.
A nivel de un solo warp, esta independencia es total. Las activaciones de entrada son de solo lectura, el acumulador vive en registros privados y la escritura de salida va a una dirección única. Desde la perspectiva del planificador de hardware, toda la dimensión de salida es un conjunto plano de elementos de trabajo independientes.
El planificador de warps de la GPU puede emitir cualquier warp en cualquier momento y en cualquier orden, sin restricciones de corrección. Cuando un warp se bloquea esperando una carga de memoria, el planificador cambia de inmediato a otro. Con miles de warps en vuelo distribuidos entre los 148 multiprocesadores de streaming de una B200, la latencia de memoria queda casi por completo oculta tras el cómputo útil de otros warps.
El kernel también escala linealmente: duplicar la dimensión de salida duplica el número de warps independientes sin añadir sincronización. Lo mismo ocurre en la dimensión del lote de tokens, por lo que el planificador ve un único espacio de trabajo plano, sin dependencias entre nodos. Esto contrasta con el enfoque tradicional, donde los kernels GEMM a nivel de experto requieren coordinación dentro del bloque.
Resultados
Rendimiento de decodificación de extremo a extremo a escala
El Testing en nuestro sistema interno de inferencia, ejecutando un modelo de estilo Qwen-3 en GPU NVIDIA B200, mostró una mejora constante del rendimiento. Esta mejora del rendimiento se mantiene uniforme en todos los rangos de longitud de contexto, lo que confirma que se trata de una mejora puramente en el tiempo de generación y que no depende de la longitud del prompt.
Mayor precisión
Eliminar el paso intermedio de cuantización de activaciones tiene un impacto medible en la calidad. Convertir activaciones BF16 a MXFP8 y de vuelta introduce un nivel mínimo de error de redondeo que se acumula a lo largo de las capas del modelo. Warp decode mantiene las activaciones en BF16 en todo momento y los acumuladores en FP32, por lo que la reducción nunca opera sobre entradas degradadas. El resultado es que warp decode produce resultados 1.4x más cercanos a la referencia completa de 32 bits que la ruta clásica.
Eficiencia de hardware
Empezamos a desarrollar warp decode preguntándonos hasta qué punto podíamos acercarnos al rendimiento máximo del hardware. El pico medido del B200 para lecturas contiguas de memoria es de 6.8 TB/s (medido con un kernel de copia). Warp decode alcanza de forma sostenida 3.95 TB/s con B=32, es decir, el 58% de ese pico. La brecha restante probablemente refleje la penalización por latencia de memoria de los patrones de acceso aleatorio que genera el enrutamiento de expertos, ya que cada token puede dirigirse a expertos no adyacentes como 5, 8, 14, 19, etc.
En cambio, el pico de rendimiento se mide con lecturas contiguas de memoria (0,1,2,3). La concordancia con la implementación de referencia fue muy alta en todos los tamaños de lote: similitud coseno mínima > 0.999996, diferencia absoluta máxima 0.001953.
Warp decode y entrenamiento de Composer
Warp decode no es un sustituto general de la ejecución centrada en expertos. Las cargas de trabajo de mayor volumen, como prefill y la inferencia en lotes grandes, siguen beneficiándose del empaquetado centrado en expertos porque muchos tokens comparten el mismo experto, y el costo de organizarlos se amortiza con suficiente cómputo efectivo como para que valga la pena.
Warp decode sobresale cuando no hay suficiente trabajo compartido por experto como para justificar esa sobrecarga, como suele ocurrir con la decodificación de MoE. Esto lo convierte en una parte importante de cómo seguimos mejorando Composer. Mientras que las inversiones en datos de preentrenamiento y RL determinan la calidad de los resultados del modelo, las inversiones en inferencia como warp decode determinan con qué rapidez y precisión esos resultados llegan a los desarrolladores.