Pesquisa

Melhor inferência em modelos MoE com warp decode

Ao inverter o eixo do paralelismo, alcançamos uma inferência em modelos MoE 1,8x mais rápida e mais precisa.

Less Wright, Federico Cassano & Zhiyuan Zhang11 min de leitura

A maioria dos sistemas de inferência de MoE organiza o caminho de geração de tokens em torno dos especialistas. Isso reflete como o roteamento funciona e tem sido a abordagem padrão em escala. Para decodificação em lotes pequenos nas GPUs Blackwell, porém, constatamos que organizar o kernel em torno das saídas, em vez dos especialistas, funciona melhor. Chamamos essa abordagem de “warp decode”.

Chegamos ao warp decode ao pensar sobre qual é, na prática, a largura de banda de memória máxima atingível para decodificação de MoE no Blackwell. Isso nos levou a inverter completamente o eixo do paralelismo. Em vez de atribuir warps aos especialistas, atribuímos cada warp a um único valor de saída (neurônio).

Kernels que melhoram tanto a performance quanto a precisão são raros, e o warp decode é um deles. No Blackwell, ele oferece um ganho de 1,84x em throughput e também melhora a precisão, com saídas 1,4x mais próximas de uma referência FP32 completa. Isso acelera o pipeline de pesquisa e treinamento do Composer, permitindo melhorar o modelo mais rápido e entregar novas versões com mais frequência."

O caminho convencional do MoE

Modelos modernos de MoE roteiam cada token por um subconjunto de redes de especialistas, selecionando, por exemplo, 8 de 128 em uma determinada camada. A implementação padrão organiza toda a computação em torno desses especialistas: reúne os tokens de que cada especialista precisa, executa os cálculos e recompõe os resultados.

Isso funciona bem para prefill e lotes grandes, em que o trabalho compartilhado por especialista compensa a sobrecarga de organizar os dados. Mas, durante a etapa de decodificação autorregressiva, em que produzimos apenas um token por vez, não há trabalho compartilhado suficiente para justificar isso. Cinco das oito etapas do caminho tradicional existem apenas para gerenciar o layout dos dados na visão centrada no especialista e não realizam nenhum cálculo de fato.

O que mudamos

O warp decode elimina essas cinco etapas de “controle” ao reorganizar o paralelismo em torno das saídas, e não dos especialistas.

GPUs modernas executam instruções em grupos de 32 lanes de processamento paralelo, chamados de warp. Em nossa nova abordagem, cada warp fica responsável por calcular exatamente um valor de saída. O warp busca diretamente da memória os dados de peso de que precisa, agrega os totais de todos os oito especialistas roteados em um único acumulado e grava um resultado.

Essa independência de warp permite que o warp decode seja executado sem qualquer staging, handoffs, pontos de sincronização entre warps ou buffers intermediários. Toda a camada de computação do MoE é condensada em dois kernels, moe_gate_up_3d_batched e moe_down_3d_batched.

Como os dois kernels funcionam

No kernel de gate/up, cada cooperative thread array (CTA) tem oito warps, e cada warp fica responsável por um neurônio intermediário para cada combinação de um token com um especialista roteado. O warp carrega o ID do especialista roteado, lê as linhas de pesos de gate e up desse neurônio e percorre o vetor de ativação de entrada em streaming. Os pesos MXFP8 são convertidos para FP32 em tempo real, e ambos os produtos escalares são acumulados em registradores privados.

Como os dois kernels são fundidos em uma única passagem, o vetor de ativação é lido uma vez e reutilizado imediatamente nas duas projeções, sem nenhuma etapa em memória compartilhada. Após uma redução em nível de warp, o warp aplica SiLU(gate) × up e grava um valor intermediário.

No down kernel, cada warp fica responsável por uma dimensão de saída de um token. Ele percorre todos os especialistas roteados do top-k, carregando a linha de pesos relevante da projeção down e percorrendo as ativações intermediárias em streaming, enquanto incorpora o peso de roteamento de cada especialista em um único acumulador FP32.

Depois que todos os especialistas são processados, reduzimos as 32 somas parciais locais das lanes com uma redução butterfly em nível de warp usando __shfl_xor_sync. Isso é compilado diretamente para a instrução PTX shfl.sync.bfly, um primitivo único de hardware que troca registradores entre lanes dentro do warp, contornando totalmente a memória compartilhada.

A vantagem aqui é que não precisamos de acessos de ida e volta à L1, conflitos de banco nem barreiras explícitas, porque a sincronização já vem embutida na instrução por meio da máscara de lane. Em vez de um epílogo separado, a combinação final ponderada do top-k passa a fazer parte da própria projeção.

Cada warp no warp decode é independente e recebe uma única atribuição estável durante toda a sua execução: produzir um escalar de saída. É essa independência de warp que elimina a etapa em memória compartilhada, a sincronização entre warps e os buffers intermediários exigidos pelo caminho tradicional.

Simplificação e aceleração do pipeline

O warp decode alcança melhorias de performance por dois mecanismos principais: eliminando estágios e buffers exigidos pelo caminho tradicional e criando independência entre warps, o que permite melhor escalonamento e ocultação de latência.

Eliminação de estágios

As eliminações de estágios respondem pela maior parte do ganho de throughput. Eliminamos o padding, o scattering e a etapa de combinação. Remover esses estágios exige reorganizar o paralelismo do zero, em vez de apenas fundir estágios do pipeline tradicional.

Eliminação de padding

Caminho tradicional: Preenche a lista de tokens de cada especialista até limites de potência de 2 ou de 128 bytes, para atender aos requisitos de kernels agrupados. Na decodificação, com um único token, isso se torna uma sobrecarga que não pode ser amortizada.

Caminho de decodificação por warp: Evita totalmente essa sobrecarga ao nunca formar lotes por especialista.

Eliminação de scatter e combine

Caminho tradicional: Depois que cada especialista termina, ele grava oito resultados intermediários na memória da GPU e, em seguida, executa uma etapa de redução separada para combiná-los.

Caminho de decodificação por warp: O peso de roteamento de cada especialista é incorporado ao acumulador corrente dentro do warp. Os oito resultados intermediários nunca chegam a ser materializados na memória, economizando tanto os custos de gravação quanto os de leitura de uma etapa posterior de redução.

Eliminação de buffers

A reorganização também elimina dois buffers de memória intermediários que o caminho tradicional exige como consequência de seu layout orientado por especialistas.

O primeiro é um buffer de agrupamento de ativações, ou seja, o vetor de ativação de entrada copiado e reorganizado em um layout com especialistas como dimensão principal. Com batch size 1, isso é uma cópia completa de dados que já existem. O segundo é um buffer de saída por especialista. Com oito especialistas e dimensão oculta 2048, isso representa 8 × 2048 × 2 bytes = 32 KB por token em BF16, que são alocados, gravados, lidos imediatamente uma vez e descartados.

O Warp decode elimina ambos ao consolidar as oito contribuições dos especialistas em um acumulador em registrador ao longo das 32 lanes do warp, onde nada chega à memória global até a gravação final de um único escalar. Remover mais de 32 KB de tráfego de buffers intermediários por token libera capacidade do cache L2 para as linhas de pesos que realmente determinam a performance.

Independência de Warp

A reorganização também torna a computação restante mais rápida porque o kernel é “embarrassingly parallel” por definição: cada warp é completamente independente de qualquer outro. Como cada warp produz exatamente um escalar de saída e lê apenas as linhas de pesos de que precisa, não há estado mutável compartilhado entre warps.

No nível de um único warp, essa independência é total. As ativações de entrada são apenas para leitura, o acumulador fica em registradores privados, e a escrita da saída vai para um endereço exclusivo. Na perspectiva do escalonador de hardware, toda a dimensão de saída é um único conjunto de itens de trabalho independentes.

O escalonador de warps da GPU pode despachar qualquer warp a qualquer momento, em qualquer ordem, sem nenhuma restrição de corretude. Quando um warp fica parado esperando um carregamento de memória, o escalonador imediatamente alterna para outro. Com milhares de warps em execução nos 148 multiprocessadores de streaming de uma B200, a latência de memória fica quase totalmente encoberta pela computação útil de outros warps.

O kernel também escala linearmente, de modo que dobrar a dimensão de saída dobra o número de warps independentes sem adicionar sincronização. O mesmo vale para a dimensão do lote de tokens, então o escalonador vê um único namespace plano de trabalho, sem arestas entre nós. Isso contrasta com a abordagem tradicional, em que kernels GEMM no nível do expert exigem coordenação dentro do bloco.

Resultados

Throughput de decodificação ponta a ponta em escala

Testing em nosso sistema interno de inferência, executando um modelo no estilo Qwen-3 em GPUs NVIDIA B200, resultou em um ganho consistente de throughput. O ganho de throughput se mantém estável em todas as faixas de comprimento de contexto, confirmando que esta é uma melhoria puramente no tempo de geração e que não depende do comprimento do prompt.

Maior precisão

Remover a etapa intermediária de quantização das ativações tem um impacto mensurável na qualidade. Converter ativações de BF16 para MXFP8 e depois de volta introduz um patamar de erro de arredondamento que se acumula ao longo das camadas do modelo. O warp decode mantém as ativações em BF16 o tempo todo e os acumuladores em FP32, de modo que a redução nunca opera sobre entradas degradadas. Como resultado, o warp decode produz saídas 1,4x mais próximas da referência completa em 32 bits do que o caminho clássico.

Eficiência do hardware

Começamos a desenvolver o warp decode avaliando quão perto conseguiríamos chegar do throughput máximo do hardware. O pico medido do B200 para leituras contíguas de memória é de 6,8 TB/s (medido com um kernel de cópia). O warp decode mantém 3,95 TB/s em B=32, ou 58% desse pico. A diferença restante provavelmente reflete o custo de latência de memória dos padrões de acesso aleatório criados pelo roteamento de especialistas, já que cada token pode ser roteado para especialistas não adjacentes, como 5, 8, 14, 19 etc.

Já o throughput de pico é medido com leituras contíguas de memória (0,1,2,3). A aderência à implementação de referência foi excelente em todos os tamanhos de lote: similaridade cosseno mínima > 0,999996, diferença absoluta máxima de 0,001953.

Warp decode e treinamento do Composer

Warp decode não substitui de forma geral a execução centrada em especialistas. Cargas de trabalho de maior volume, como prefill e inferência em grandes lotes, ainda se beneficiam do empacotamento centrado em especialistas, porque muitos tokens compartilham o mesmo especialista, e o custo de organizá-los é amortizado por uma quantidade de computação real suficiente para compensar.

Warp decode se destaca quando não há trabalho compartilhado suficiente por especialista para justificar essa sobrecarga, como costuma acontecer na decodificação MoE. Isso o torna uma parte importante de como continuamos a aprimorar o Composer. Embora os investimentos em dados de pré-treinamento e RL determinem a qualidade das saídas do modelo, investimentos em inferência como warp decode determinam com que rapidez e precisão essas saídas chegam aos desenvolvedores.

Arquivado em: Pesquisa

Autors: Less Wright, Federico Cassano & Zhiyuan Zhang