Une meilleure inférence de modèles MoE avec warp decode
En inversant l’axe du parallélisme, nous obtenons une inférence de modèles MoE 1,8x plus rapide et plus précise.
La plupart des systèmes d’inférence MoE organisent la génération de tokens autour des experts. Cela reflète le fonctionnement du routage et constitue l’approche standard à grande échelle. Pour le décodage en petits lots sur les GPU Blackwell, nous avons toutefois constaté qu’il est plus efficace d’organiser le kernel autour des sorties plutôt que des experts. Nous appelons cette approche le « warp decode ».
Nous sommes arrivés au warp decode en nous demandant quelle était réellement la bande passante mémoire maximale atteignable pour le décodage MoE sur Blackwell. Cela nous a amenés à inverser complètement l’axe du parallélisme. Au lieu d’assigner des warps aux experts, nous assignons chaque warp à une seule valeur de sortie (neurone).
Les kernels qui améliorent à la fois les performances et la précision sont rares, et le warp decode en fait partie. Sur Blackwell, il offre un débit 1,84x supérieur tout en améliorant la précision, avec des sorties 1,4x plus proches d’une référence FP32 complète. Cela accélère le pipeline de recherche et d’entraînement de Composer, ce qui nous permet d’améliorer le modèle plus vite et de livrer de nouvelles versions plus souvent.
Le chemin traditionnel du MoE
Les modèles MoE modernes acheminent chaque token vers un sous-ensemble de réseaux d’experts spécialisés, en en sélectionnant par exemple 8 sur 128 dans une couche donnée. L’implémentation standard organise tous les calculs autour de ces experts : elle regroupe les token dont chaque expert a besoin, exécute les calculs, puis réassemble les résultats.
Cette approche fonctionne bien pour le prefill et les grands lots, où le travail mutualisé par expert amortit le surcoût lié à l’organisation des données. Mais pendant l’étape de décodage autorégressif, où l’on ne produit qu’un token à la fois, ce travail mutualisé n’est pas suffisant pour le justifier. Cinq des huit étapes du chemin traditionnel servent uniquement à gérer la disposition des données dans cette approche centrée sur les experts et n’effectuent aucun calcul réel.
Ce que nous avons changé
Warp decode élimine ces cinq étapes de « gestion annexe » en réorganisant le parallélisme autour des valeurs de sortie plutôt que des experts.
Les GPU modernes exécutent les instructions par groupes de 32 voies de traitement parallèles, appelés des warps. Dans notre nouvelle approche, chaque warp se voit attribuer exactement une valeur de sortie à calculer. Le warp récupère en flux les données de poids dont il a besoin directement depuis la mémoire, cumule les contributions des huit experts routés dans un seul accumulateur, puis écrit un unique résultat.
Cette indépendance des warps permet à warp decode de s’exécuter sans mise en tampon préalable, relais, points de synchronisation entre warps ni tampons intermédiaires. Toute la couche de calcul MoE est condensée en deux kernels, moe_gate_up_3d_batched et moe_down_3d_batched.
Comment les deux kernels fonctionnent
Dans le kernel gate/up, chaque cooperative thread array (CTA) comprend huit warps, et chaque warp se voit attribuer un neurone intermédiaire pour chaque paire formée d’un token et d’un expert routé. Le warp charge l’ID de l’expert routé, lit les lignes de poids gate et up pour ce neurone, puis parcourt le vecteur d’activation d’entrée. Les poids MXFP8 sont convertis en FP32 à la volée, et les deux produits scalaires s’accumulent dans des registres privés.
Comme les deux kernels sont fusionnés en un seul passage, le vecteur d’activation n’est lu qu’une seule fois, puis immédiatement réutilisé pour les deux projections, sans aucune mise en tampon en mémoire partagée. Après une réduction au niveau du warp, le warp applique SiLU(gate) × up et écrit une valeur intermédiaire.
Dans le kernel down, chaque warp se voit attribuer une dimension de sortie pour un token. Il parcourt tous les experts routés du top-k, en chargeant la ligne de poids de projection down correspondante et en parcourant les activations intermédiaires, tout en intégrant le poids de routage de chaque expert dans un unique accumulateur FP32.
Une fois tous les experts traités, nous réduisons les 32 sommes partielles locales aux voies de traitement parallèles au moyen d’une réduction butterfly au niveau du warp avec __shfl_xor_sync. Cela se compile directement en l’instruction PTX shfl.sync.bfly, une primitive matérielle unique qui échange des registres entre les voies de traitement parallèles au sein du warp, sans passer du tout par la mémoire partagée.
Le principal avantage, ici, est que nous évitons les allers-retours vers le cache L1, les conflits de banques et les barrières explicites, car la synchronisation est directement intégrée à l’instruction via le masque de voies de traitement parallèles. Au lieu d’un épilogue distinct, la combinaison pondérée finale du top-k fait directement partie de la projection elle-même.
Chaque warp de warp decode est indépendant et reçoit une affectation unique et stable pendant toute sa durée de vie : produire un seul scalaire de sortie. C’est cette indépendance des warps qui élimine la mise en tampon en mémoire partagée, la synchronisation entre warps et les tampons intermédiaires requis par le chemin traditionnel.
Simplification et accélération du pipeline
Warp decode améliore les performances grâce à deux mécanismes principaux : la suppression d’étapes et de tampons qu’exigeait le chemin traditionnel, et la création d’une indépendance entre les warps, qui permet un ordonnancement plus efficace et un meilleur masquage de la latence.
Suppression d’étapes
La suppression d’étapes apporte l’essentiel du gain de débit. Nous éliminons le padding, la dispersion et l’étape de combinaison. Supprimer ces étapes exige de repenser le parallélisme de fond en comble, plutôt que de simplement fusionner des étapes du pipeline traditionnel.
Suppression du padding
chemin traditionnel : aligne la liste de tokens de chaque expert sur des limites en puissance de 2 ou de 128 octets afin de respecter les contraintes des kernels groupés. Lors du décodage d’un seul token, ce surcoût ne peut pas être amorti.
Chemin de décodage warp decode : évite complètement ce surcoût en ne formant jamais de lots par expert.
Élimination du scatter et de la combinaison
Chemin traditionnel : Une fois qu’un expert a terminé, il écrit huit résultats intermédiaires dans la mémoire GPU, puis effectue une étape de réduction distincte pour les combiner.
chemin de décodage warp decode : Le poids de routage de chaque expert est intégré à l’accumulateur courant au sein du warp. Les huit résultats intermédiaires ne sont jamais matérialisés en mémoire, ce qui évite à la fois les coûts d’écriture et de lecture d’un passage de réduction ultérieur.
Élimination des tampons
La réorganisation supprime également deux tampons mémoire intermédiaires dont le chemin traditionnel a besoin en raison de sa disposition centrée sur les experts.
Le premier est un tampon de collecte des activations : le vecteur d’activation d’entrée y est copié puis réorganisé dans une disposition par expert. Avec une taille de lot de 1, il s’agit d’une copie complète de données déjà présentes. Le second est un tampon de sortie par expert. Avec huit experts et une dimension cachée de 2048, cela représente 8 × 2048 × 2 octets = 32 KB par token en BF16, alloués, écrits, relus immédiatement une fois, puis abandonnés.
warp decode élimine les deux en regroupant les huit contributions des experts dans un accumulateur en registres réparti sur les 32 voies du warp, sans rien écrire en mémoire globale avant l’écriture finale d’une seule valeur scalaire. La suppression de plus de 32 KB de trafic de tampons intermédiaires par token libère de la capacité dans le cache L2 pour les lignes de poids qui déterminent réellement les performances.
Indépendance des warps
La réorganisation accélère également le calcul restant, car le kernel est « parfaitement parallélisable » par conception : chaque warp est entièrement indépendant de tous les autres. Comme chaque warp produit exactement un scalaire de sortie et ne lit que les lignes de poids dont il a besoin, il n’existe aucun état mutable partagé entre les warps.
À l’échelle d’un seul warp, cette indépendance est totale. Les activations d’entrée sont en lecture seule, l’accumulateur réside dans des registres privés et l’écriture de la sortie se fait à une adresse unique. Du point de vue de l’ordonnanceur matériel, toute la dimension de sortie forme un ensemble plat d’unités de travail indépendantes.
L’ordonnanceur de warps du GPU peut exécuter n’importe quel warp à tout moment, dans n’importe quel ordre, sans aucune contrainte de correction. Lorsqu’un warp est bloqué en attendant un chargement mémoire, l’ordonnanceur bascule immédiatement sur un autre. Avec des milliers de warps en cours d’exécution sur les 148 multiprocesseurs de streaming d’un B200, la latence mémoire est presque entièrement masquée par le calcul utile effectué par les autres warps.
Le kernel passe aussi linéairement à l’échelle : doubler la dimension de sortie double le nombre de warps indépendants, sans synchronisation supplémentaire. Il en va de même pour la dimension du lot de tokens, de sorte que l’ordonnanceur voit un espace de travail plat, sans arêtes entre les nœuds. Cela contraste avec le chemin traditionnel, où les kernels GEMM au niveau expert nécessitent une coordination intra-bloc.
Résultats
Débit de décodage de bout en bout à grande échelle
Des tests réalisés sur notre système d’inférence interne, avec un modèle de type Qwen-3 exécuté sur des GPU NVIDIA B200, ont montré un gain de débit constant. Ce gain de débit reste uniforme dans toutes les tranches de longueur de contexte, ce qui confirme qu’il s’agit d’une amélioration qui intervient uniquement au moment de la génération et ne dépend pas de la longueur du prompt.
Précision accrue
La suppression de l’étape intermédiaire de quantification des activations améliore la qualité de manière mesurable. La conversion des activations de BF16 vers MXFP8, puis de nouveau vers BF16, introduit un seuil incompressible d’erreur d’arrondi qui s’accumule au fil des couches du modèle. Warp decode conserve les activations en BF16 tout au long du processus et les accumulateurs en FP32, de sorte que la réduction ne s’effectue jamais sur des entrées dégradées. Résultat : warp decode produit des sorties 1,4 fois plus proches de la référence en pleine précision (32 bits) que le chemin traditionnel.
Efficacité matérielle
Nous avons commencé à développer warp decode en nous demandant à quel point nous pouvions nous rapprocher du débit maximal du matériel. Le pic mesuré du B200 pour les lectures mémoire contiguës est de 6,8 TB/s (mesuré à l’aide d’un kernel de copie). Warp decode atteint 3,95 TB/s à B=32, soit 58 % de ce pic. L’écart restant reflète probablement le coût en latence mémoire des patterns d’accès aléatoire induits par le routage vers les experts, puisque chaque token peut être routé vers des experts non adjacents comme 5, 8, 14, 19, etc.
À l’inverse, le débit maximal est mesuré à l’aide de lectures mémoire contiguës (0,1,2,3). La conformité à l’implémentation de référence était très bonne pour toutes les tailles de lot : similarité cosinus minimale > 0,999996, différence absolue maximale de 0,001953.
Warp decode et entraînement de Composer
Le Warp decode n’a pas vocation à remplacer de façon générale l’exécution centrée sur les experts. Les charges de travail à fort volume, comme le prefill et l’inférence en grands lots, bénéficient toujours d’un regroupement centré sur les experts, car de nombreux tokens partagent le même expert, et le coût de ce regroupement est amorti sur un volume de calcul réel suffisant pour en valoir la peine.
Le Warp decode l’emporte lorsqu’il n’y a pas assez de travail partagé par expert pour justifier ce surcoût, comme c’est souvent le cas lors du décodage MoE. Cela en fait une composante importante de l’amélioration continue de Composer. Alors que les investissements dans les données de préentraînement et le RL déterminent la qualité des sorties du modèle, les investissements dans l’inférence, comme le Warp decode, déterminent la vitesse et la précision avec lesquelles ces sorties parviennent aux développeurs.