Bessere MoE-Modellinferenz mit warp decode
Durch das Umdrehen der Parallelitätsachse erreichen wir eine 1,8x schnellere und genauere MoE-Modellinferenz.
Die meisten MoE-Inferenzsysteme organisieren die Token-Generierung entlang der Experten. Das entspricht der Funktionsweise des Routings und war im großen Maßstab bislang der Standardansatz. Für Decoding mit kleinen Batches auf Blackwell-GPUs haben wir jedoch festgestellt, dass es besser funktioniert, den Kernel entlang der Ausgaben statt der Experten zu organisieren. Wir nennen diesen Ansatz „warp decode“.
Zu warp decode kamen wir, als wir darüber nachdachten, wie hoch die maximal erreichbare Speicherbandbreite für MoE-Decode auf Blackwell tatsächlich ist. Das brachte uns dazu, die Parallelitätsachse komplett umzudrehen. Statt Warps Experten zuzuweisen, ordnen wir jeden Warp einem einzelnen Ausgabewert (Neuron) zu.
Kernel, die sowohl die Performance als auch die Genauigkeit verbessern, sind selten — und warp decode ist einer davon. Auf Blackwell liefert es einen 1,84x höheren Durchsatz und verbessert zugleich die Genauigkeit: Die Ausgaben liegen 1,4x näher an einer vollständigen FP32-Referenz. Das beschleunigt die Research- und Training-Pipeline für Composer, sodass wir das Modell schneller verbessern und neue Versionen häufiger ausliefern können.
Der herkömmliche MoE-Pfad
Moderne MoE-Modelle leiten jedes Token durch eine Teilmenge spezialisierter Expertennetzwerke und wählen dabei in einer bestimmten Schicht zum Beispiel 8 von 128 aus. In der Standardimplementierung ist die gesamte Berechnung auf diese Experten ausgerichtet: Die Token, die jeder Experte benötigt, werden gesammelt, die Berechnungen werden ausgeführt und die Ergebnisse anschließend wieder zusammengesetzt.
Das funktioniert gut für Prefill und große Batches, bei denen die gemeinsame Arbeit pro Experte den Overhead für die Datenorganisation amortisiert. Im autoregressiven Decode-Schritt, bei dem jeweils nur ein Token erzeugt wird, gibt es dafür jedoch nicht genug gemeinsame Arbeit. Fünf der acht Stufen im traditionellen Pfad dienen ausschließlich dazu, das Datenlayout für die expertenzentrierte Sicht zu verwalten, und führen keine eigentliche Berechnung aus.
Was wir geändert haben
Warp Decode eliminiert diese fünf „Verwaltungsschritte“, indem es die Parallelisierung nicht mehr um Experten, sondern um Ausgaben herum organisiert.
Moderne GPUs führen Instruktionen in Gruppen von 32 parallelen Verarbeitungslanes aus, die als Warp bezeichnet werden. In unserem neuen Ansatz ist jedem Warp genau ein Ausgabewert zur Berechnung zugewiesen. Der Warp streamt die benötigten Gewichtungsdaten direkt aus dem Speicher, summiert die Werte aller acht weitergeleiteten Experten in einer einzigen laufend aktualisierten Gesamtsumme auf und schreibt ein einziges Ergebnis.
Durch diese Unabhängigkeit der Warps kann Warp Decode ohne Staging, Übergaben, Warp-übergreifende Synchronisationspunkte oder Zwischenpuffer ausgeführt werden. Die gesamte MoE-Rechenschicht wird auf zwei Kernel reduziert: moe_gate_up_3d_batched und moe_down_3d_batched.
Wie die beiden Kernel funktionieren
Im Gate/Up-Kernel besteht jedes Cooperative Thread Array (CTA) aus acht Warps, und jeder Warp übernimmt ein Zwischennneuron für jede Kombination aus einem Token und einem gerouteten Experten. Der Warp lädt die ID des gerouteten Experten, liest die Gate- und Up-Gewichtszeilen für dieses Neuron und streamt über den Eingabeaktivierungsvektor. MXFP8-Gewichte werden on the fly in FP32 umgewandelt, und beide Skalarprodukte werden in privaten Registern akkumuliert.
Da die beiden Kernel zu einem einzigen Durchlauf zusammengeführt sind, wird der Aktivierungsvektor einmal gelesen und sofort für beide Projektionen wiederverwendet, ganz ohne Staging im Shared Memory. Nach einer Reduktion auf Warp-Ebene wendet der Warp SiLU(gate) × up an und schreibt einen Zwischenwert.
Im Down-Kernel übernimmt jeder Warp eine Ausgabedimension für ein Token. Er durchläuft alle Top-k-gerouteten Experten, lädt die relevante Gewichtszeile der Down-Projection und streamt über die Zwischenaktivierungen, während das Routing-Gewicht jedes Experten in einen einzigen laufenden FP32-Akkumulator einfließt.
Nachdem alle Experten verarbeitet wurden, reduzieren wir die 32 lane-lokalen Teilsummen mit einer Butterfly-Reduktion auf Warp-Ebene mithilfe von __shfl_xor_sync. Dies wird direkt zur PTX-Instruktion shfl.sync.bfly kompiliert, einer einzelnen Hardware-Primitive, die Register zwischen Lanes innerhalb des Warps austauscht und Shared Memory vollständig umgeht.
Der Vorteil dabei ist, dass wir keine L1-Roundtrips, Bankkonflikte oder expliziten Barrieren brauchen, weil die Synchronisierung über die Lane-Maske direkt in die Instruktion eingebaut ist. Statt eines separaten Epilogs wird die endgültige gewichtete Top-k-Kombination Teil der Projektion selbst.
Jeder Warp in Warp Decode ist unabhängig und erhält für seine gesamte Laufzeit eine einzige, stabile Aufgabe: einen Ausgabeskalar zu erzeugen. Diese Unabhängigkeit der Warps beseitigt das Shared-Memory-Staging, die Warp-übergreifende Synchronisierung und die Zwischenpuffer, die der herkömmliche Pfad erfordert.
Vereinfachung und Beschleunigung der Pipeline
Warp Decode erzielt Leistungssteigerungen durch zwei Hauptmechanismen: Zum einen entfallen Stufen und Puffer, die der herkömmliche Pfad benötigt, zum anderen entsteht Warp-Unabhängigkeit, die besseres Scheduling und eine effektivere Verdeckung von Latenzen ermöglicht.
Eliminierung von Stufen
Die Eliminierung von Stufen macht den größten Teil des Durchsatzgewinns aus. Wir eliminieren Padding, Scattering und den Schritt combine. Das Entfernen dieser Stufen erfordert eine grundlegende Neuorganisation der Parallelisierung, anstatt lediglich Stufen der traditionellen Pipeline zusammenzulegen.
Eliminierung von Padding
Traditioneller Pfad: Füllt die Token-Liste jedes Experten bis zur nächsten Zweierpotenz bzw. auf 128-Byte-Grenzen auf, um die Anforderungen gruppierter Kernel zu erfüllen. Beim Decoding mit nur einem Token lässt sich dieser Overhead nicht amortisieren.
Warp-Decode-Pfad: Vermeidet diesen Overhead vollständig, da dabei nie expertenweise Batches gebildet werden.
Wegfall von Scatter und Combine
Traditioneller Pfad: Nachdem jeder Experte fertig ist, schreibt er acht Zwischenergebnisse in den GPU-Speicher und führt dann einen separaten Reduktionsschritt aus, um sie zusammenzuführen.
Warp-Decode-Pfad: Das Routing-Gewicht für jeden Experten wird innerhalb des Warp direkt in den laufenden Akkumulator eingerechnet. Die acht Zwischenergebnisse werden nie im Speicher materialisiert, wodurch sowohl die Schreib- als auch die Lesekosten eines nachfolgenden Reduktionsdurchlaufs entfallen.
Eliminierung von Puffern
Die Umstrukturierung beseitigt außerdem zwei Zwischenpuffer im Speicher, die der herkömmliche Pfad aufgrund seines expertenzentrierten Layouts benötigt.
Der erste ist ein Activation-Gather-Puffer, also der Eingabe-Aktivierungsvektor, der in ein expertenorientiertes Layout kopiert und umgeordnet wird. Bei einer Batch-Größe von 1 ist das eine vollständige Kopie von Daten, die bereits vorhanden sind. Der zweite ist ein Output-Puffer pro Experte. Bei acht Experten und einer Hidden-Dimension von 2048 entspricht das 8 × 2048 × 2 Byte = 32 KB pro Token in BF16, die allokiert, geschrieben, sofort einmal gelesen und dann verworfen werden.
Warp Decode eliminiert beide, indem die acht Expertenbeiträge über 32 Warp-Lanes hinweg in einem Register-Akkumulator zusammengeführt werden; nichts gelangt in den globalen Speicher, bis am Ende ein einzelner Skalar geschrieben wird. Das Entfernen von mehr als 32 KB Zwischenpuffer-Traffic pro Token schafft L2-Cache-Kapazität für die Gewichtszeilen, die die Performance tatsächlich bestimmen.
Warp-Unabhängigkeit
Die Umstrukturierung beschleunigt auch die verbleibende Berechnung, weil der Kernel von Haus aus „embarrassingly parallel“ ist: Jeder Warp ist vollständig unabhängig von allen anderen. Da jeder Warp genau einen Ausgabe-Skalar besitzt und nur die Gewichtungszeilen liest, die er benötigt, gibt es zwischen den Warps keinen gemeinsam genutzten veränderlichen Zustand.
Auf der Ebene eines einzelnen Warps ist diese Unabhängigkeit vollständig. Die Eingabeaktivierungen sind schreibgeschützt, der Akkumulator liegt in privaten Registern, und der Schreibzugriff auf die Ausgabe erfolgt an eine eindeutige Adresse. Aus Sicht des Hardware-Schedulers ist die gesamte Ausgabedimension ein flacher Pool unabhängiger Arbeitseinheiten.
Der Warp-Scheduler der GPU kann jeden Warp jederzeit und in beliebiger Reihenfolge ausführen, ohne dass dabei Korrektheitsbedingungen zu beachten wären. Wenn ein Warp ins Stocken gerät, während er auf einen Speicherladevorgang wartet, wechselt der Scheduler sofort zu einem anderen. Bei Tausenden von Warps, die auf den 148 Streaming-Multiprozessoren einer B200 gleichzeitig aktiv sind, wird die Speicherlatenz fast vollständig durch nützliche Berechnungen anderer Warps verdeckt.
Der Kernel skaliert außerdem linear: Eine Verdopplung der Ausgabedimension verdoppelt auch die Anzahl unabhängiger Warps, ganz ohne zusätzliche Synchronisierung. Dasselbe gilt für die Token-Batch-Dimension, sodass der Scheduler einen einzigen flachen Namespace von Arbeit ohne Kanten zwischen den Knoten sieht. Das steht im Gegensatz zum traditionellen Pfad, bei dem GEMM-Kernel auf Expertenniveau eine Koordination innerhalb des Blocks erfordern.
Ergebnisse
End-to-End-Dekodierungsdurchsatz in großem Maßstab
Testing auf unserem internen Inferenzsystem mit einem Modell im Qwen-3-Stil auf NVIDIA-B200-GPUs zeigte einen konsistenten Durchsatzgewinn. Der Durchsatzgewinn bleibt über alle Buckets der Kontextlänge hinweg konstant und bestätigt damit, dass es sich um eine reine Verbesserung während der Generierung handelt, die nicht von der Prompt-Länge abhängt.
Verbesserte Genauigkeit
Das Entfernen des Zwischenschritts der Aktivierungsquantisierung hat einen messbaren Einfluss auf die Qualität. Die Konvertierung von BF16-Aktivierungen nach MXFP8 und zurück führt zu einer Untergrenze bei Rundungsfehlern, die sich über die Schichten des Modells hinweg aufsummiert. Warp Decode belässt Aktivierungen durchgehend in BF16 und Akkumulatoren in FP32, sodass die Reduktion nie auf beeinträchtigten Eingaben arbeitet. Das Ergebnis: Warp Decode erzeugt Ausgaben, die der vollständigen 32-Bit-Ground-Truth 1,4-mal näher kommen als beim klassischen Pfad.
Hardware-Effizienz
Wir haben warp decode mit der Frage entwickelt, wie nah wir an den maximalen Durchsatz der Hardware herankommen können. Der gemessene Spitzenwert des B200 für sequenzielle Speicherlesezugriffe liegt bei 6,8 TB/s (gemessen mit einem Copy-Kernel). Bei B=32 erreicht warp decode konstant 3,95 TB/s, also 58 % dieses Spitzenwerts. Die verbleibende Lücke dürfte auf die Speicherlatenz zurückzuführen sein, die durch die zufälligen Zugriffsmuster des Experten-Routings entsteht, da jedes Token an nicht benachbarte Experten wie 5, 8, 14, 19 usw. weitergeleitet werden kann.
Zum Vergleich: Der Spitzendurchsatz wird mit sequenziellen Speicherlesezugriffen (0,1,2,3) gemessen. Die Übereinstimmung mit der Referenzimplementierung war über alle Batch-Größen hinweg sehr hoch: minimale Kosinus-Ähnlichkeit > 0,999996, maximale absolute Differenz 0,001953.
Warp Decode und Composer-Training
Warp Decode ist kein allgemeiner Ersatz für die expertenzentrierte Ausführung. Umfangreichere Workloads wie Prefill und Inferenz mit großen Batches profitieren weiterhin von expertenzentriertem Packing, weil viele Token denselben Experten nutzen und sich der Aufwand für ihre Organisation über ausreichend tatsächliche Rechenarbeit amortisiert.
Warp Decode ist dann im Vorteil, wenn es pro Experte nicht genug gemeinsame Arbeit gibt, um diesen Overhead zu rechtfertigen – was bei MoE-Decode häufig der Fall ist. Dadurch ist es ein wichtiger Bestandteil davon, wie wir Composer kontinuierlich verbessern. Während Investitionen in Pretraining-Daten und RL die Qualität der Modellausgaben bestimmen, entscheiden Inferenzinvestitionen wie Warp Decode darüber, wie schnell und präzise diese Ausgaben Entwickler erreichen.