조사

warp decode로 더 나은 MoE 모델 추론

병렬성 축을 뒤집어 MoE 모델 추론 속도를 1.8배 높이고 정확도도 개선합니다.

Less Wright, Federico Cassano & Zhiyuan Zhang읽는 데 8분 소요

대부분의 MoE 추론 시스템은 토큰 생성 경로를 expert 중심으로 구성합니다. 이는 라우팅이 동작하는 방식과 맞닿아 있으며, 대규모 환경에서는 표준적인 접근법이었습니다. 하지만 Blackwell GPU에서의 소규모 배치 디코드에서는 커널을 expert가 아니라 출력 중심으로 구성하는 편이 더 낫다는 점을 확인했습니다. 우리는 이 접근법을 “warp decode”라고 부릅니다.

warp decode는 Blackwell에서 MoE 디코드가 실제로 달성할 수 있는 최대 메모리 대역폭이 무엇인지 고민하는 과정에서 도달한 방식입니다. 그 결과 병렬성 축 자체를 완전히 뒤집게 되었습니다. 워프를 expert에 할당하는 대신, 각 워프를 하나의 출력 값(뉴런)에 할당합니다.

성능과 정확도를 모두 개선하는 커널은 드물며, warp decode는 그중 하나입니다. Blackwell에서 이 방식은 처리량을 1.84배 높이는 동시에, 출력이 전체 FP32 reference에 1.4배 더 가까워지도록 정확도도 개선합니다. 덕분에 Composer의 조사 및 학습 파이프라인이 빨라져, 모델을 더 빠르게 개선하고 새 버전을 더 자주 배포할 수 있습니다.

기존 MoE 경로

최신 MoE 모델은 각 토큰을 특화된 expert 네트워크의 일부로 라우팅하며, 예를 들어 특정 레이어에서는 128개 중 8개를 선택합니다. 표준 구현은 각 expert에게 필요한 토큰을 모은 뒤 연산을 수행하고 결과를 다시 조합하는 방식으로, 모든 계산을 expert 중심으로 구성합니다.

이 방식은 expert별로 공유되는 작업이 데이터 구성 오버헤드를 상쇄할 수 있는 프리필과 대규모 배치에서는 잘 동작합니다. 하지만 한 번에 토큰 하나만 생성하는 자기회귀 디코드 단계에서는 이를 정당화할 만큼 공유 작업이 충분하지 않습니다. 기존 경로의 8단계 중 5단계는 expert 중심 관점에 맞게 데이터 레이아웃을 관리하기 위해서만 존재하며, 실제 연산은 수행하지 않습니다.

변경 사항

Warp decode는 병렬 처리를 expert가 아니라 출력값을 중심으로 재구성해, 그 다섯 가지 “bookkeeping” 단계를 없앴습니다.

최신 GPU는 warp라고 하는 32개의 병렬 처리 레인 묶음으로 명령을 실행합니다. 새로운 접근 방식에서는 각 warp가 정확히 하나의 출력값만 계산하도록 할당됩니다. 각 warp는 필요한 가중치 데이터를 메모리에서 직접 읽어 오고, 라우팅된 8개 expert의 합계를 하나의 누적 합으로 집계한 뒤, 결과 하나를 기록합니다.

이처럼 warp가 서로 독립적이기 때문에, warp decode는 별도의 스테이징, 전환, warp 간 동기화 지점, 중간 버퍼 없이 실행될 수 있습니다. 전체 MoE 계산 레이어는 moe_gate_up_3d_batchedmoe_down_3d_batched 두 개의 커널로 압축됩니다.

두 커널이 작동하는 방식

gate/up 커널에서는 각 cooperative thread array(CTA)가 8개의 warp로 구성되며, 각 warp는 토큰과 라우팅된 expert의 각 조합마다 하나의 중간 뉴런을 맡습니다. warp는 라우팅된 expert ID를 로드하고, 해당 뉴런의 gate 및 up 가중치 행을 읽은 뒤 입력 활성화 벡터를 스트리밍 방식으로 순회합니다. MXFP8 가중치는 즉석에서 FP32로 변환되며, 두 dot product는 모두 private 레지스터에 누적됩니다.

두 커널이 한 번의 패스로 fused되어 있으므로, 활성화 벡터는 한 번만 읽고 두 프로젝션에 바로 재사용되며 shared memory에 따로 스테이징할 필요가 없습니다. warp 수준 리덕션이 끝나면 warp는 SiLU(gate) × up을 적용해 하나의 중간 값을 기록합니다.

down 커널에서는 각 warp가 한 토큰의 한 출력 차원을 맡습니다. 각 warp는 모든 top-k 라우팅 expert를 순회하면서, 해당 down-projection 가중치 행을 로드하고 중간 활성화를 스트리밍 방식으로 훑는 동시에 각 expert의 라우팅 가중치를 하나의 FP32 누산기에 계속 반영합니다.

모든 expert 처리가 끝나면 __shfl_xor_sync를 사용하는 warp 수준 butterfly reduction으로 32개의 lane별 부분합을 줄입니다. 이는 PTX shfl.sync.bfly 명령으로 직접 컴파일되며, shared memory를 완전히 거치지 않고 warp 내부 lane 간 레지스터를 교환하는 단일 하드웨어 프리미티브입니다.

여기서의 핵심 이점은 동기화가 lane mask를 통해 명령 자체에 내장되어 있으므로 L1 왕복, bank conflict, 명시적 배리어가 필요 없다는 점입니다. 별도의 에필로그 없이도 최종 가중 top-k 결합이 프로젝션 자체에 포함됩니다.

warp decode에서 각 warp는 서로 독립적이며, 전체 수명 동안 하나의 고정된 작업만 맡습니다. 즉, 하나의 출력 스칼라를 생성합니다. 이런 warp 독립성 덕분에 기존 경로에 필요했던 shared memory 스테이징, warp 간 동기화, 중간 버퍼가 사라집니다.

파이프라인 단순화 및 가속화

Warp decode는 두 가지 주요 메커니즘을 통해 성능을 향상합니다. 첫째, 기존 경로에 필요했던 단계와 버퍼를 제거하고, 둘째, 워프 간 독립성을 확보해 더 나은 스케줄링과 지연 시간 은닉이 가능해집니다.

단계 제거

처리량 향상 효과의 대부분은 단계 제거에서 나옵니다. 패딩, 스캐터링, 결합 단계를 제거합니다. 이러한 단계를 없애려면 기존 파이프라인의 단계를 단순히 합치는 수준이 아니라, 병렬성을 근본부터 다시 구성해야 합니다.

패딩 제거

기존 경로: grouped 커널 요구 사항을 맞추기 위해 각 expert의 token 목록을 2의 거듭제곱 또는 128바이트 경계에 맞춰 패딩합니다. 디코드 시점에는 token이 하나뿐이므로, 이 오버헤드는 분산시켜 상쇄할 수 없습니다.

워프 디코드 경로: expert별 배치를 아예 만들지 않으므로 이 오버헤드를 완전히 없앱니다.

scatter 및 결합 제거

기존 경로: 각 expert가 작업을 마치면 GPU 메모리에 8개의 중간 결과를 쓴 뒤, 이를 결합하기 위해 별도의 리덕션 단계를 실행합니다.

워프 디코드 경로: 각 expert의 라우팅 weight는 워프 내에서 누적 중인 누산기에 바로 반영됩니다. 8개의 중간 결과가 메모리에 생성되지 않으므로, 이후 리덕션 패스에 필요한 쓰기와 읽기 오버헤드를 모두 줄일 수 있습니다.

버퍼 제거

이 재구성은 기존 경로가 expert 중심 레이아웃 때문에 필요로 하던 두 개의 중간 메모리 버퍼도 없앱니다.

첫 번째는 activation gather buffer로, 입력 activation 벡터를 복사해 expert-major 레이아웃으로 다시 정렬한 버퍼입니다. 배치 크기가 1일 때는 이미 존재하는 데이터를 통째로 한 번 더 복사하는 셈입니다. 두 번째는 expert별 output 버퍼입니다. expert가 8개이고 hidden dimension이 2048이면, BF16 기준으로 토큰당 8 × 2048 × 2 bytes = 32 KB에 해당하며, 메모리를 할당해 쓰고, 곧바로 한 번 읽은 뒤 버려집니다.

Warp decode는 8개 expert의 기여물을 32개 warp lane 전반의 register accumulator에 누적하는 방식으로 이 두 버퍼를 모두 없애며, 마지막 단일 스칼라 쓰기 전까지는 어떤 데이터도 global memory로 나가지 않습니다. 토큰당 32+ KB의 중간 버퍼 트래픽을 없애면 실제로 성능을 좌우하는 weight row에 L2 cache 용량을 더 할당할 수 있습니다.

워프 독립성

이러한 재구성은 유지되는 계산을 더 빠르게 만들기도 하는데, 커널이 설계상 “embarrassingly parallel”하기 때문입니다. 즉, 모든 워프가 다른 워프와 완전히 독립적으로 동작합니다. 각 워프는 정확히 하나의 출력 스칼라를 담당하고 자신에게 필요한 가중치 행만 읽으므로, 워프 간에는 공유되는 변경 가능 상태가 없습니다.

단일 워프 수준에서는 이러한 독립성이 완전합니다. 입력 활성화는 읽기 전용이며, 누산기는 전용 레지스터에 저장되고, 출력 쓰기는 고유한 주소에 기록됩니다. 하드웨어 스케줄러의 관점에서 보면, 전체 출력 차원은 서로 독립적인 작업 항목들로 이루어진 평면적인 pool입니다.

GPU의 워프 스케줄러는 정확성 제약 없이 어떤 워프든 언제든, 어떤 순서로든 실행할 수 있습니다. 한 워프가 메모리 로드를 기다리며 멈추면, 스케줄러는 즉시 다른 워프로 전환합니다. B200의 148개 streaming multiprocessor 전반에서 수천 개의 워프가 동시에 실행되는 상황에서는, 메모리 지연 시간이 다른 워프의 유의미한 계산에 의해 거의 완전히 가려집니다.

이 커널은 선형적으로 확장되므로, 출력 차원을 두 배로 늘리면 추가 동기화 없이 독립적인 워프 수도 두 배가 됩니다. 토큰 배치 차원에서도 마찬가지이므로, 스케줄러는 노드 간 연결이 없는 하나의 평면적인 작업 네임스페이스를 보게 됩니다. 이는 expert 수준의 GEMM 커널에서 블록 내부 조정이 필요한 기존 방식과는 대조적입니다.

결과

대규모 환경에서의 엔드투엔드 디코드 처리량

NVIDIA B200 GPU에서 Qwen-3 스타일 모델을 실행하는 당사 내부 추론 시스템에서 테스트한 결과, 처리량이 일관되게 향상되었습니다. 처리량 증가는 모든 컨텍스트 길이 구간에서 거의 일정하게 나타나며, 이는 프롬프트 길이에 의존하지 않는 순수한 생성 시간 개선임을 보여줍니다.

개선된 정확도

중간 활성화 양자화 단계를 제거하면 품질에 측정 가능한 효과가 있습니다. BF16 활성화를 MXFP8로 변환한 뒤 다시 되돌리면 모델의 여러 레이어에 걸쳐 누적되는 반올림 오차 하한선이 생깁니다. Warp decode는 전체 과정에서 활성화는 BF16으로, 누산기는 FP32로 유지하므로 리덕션이 성능이 저하된 입력에 대해 수행되지 않습니다. 그 결과 warp decode의 출력은 기존 경로보다 완전한 32비트 기준값에 1.4배 더 가깝습니다.

하드웨어 효율성

warp decode 개발은 하드웨어의 최대 처리량에 얼마나 근접할 수 있는지 살펴보는 것에서 시작했습니다. 연속 메모리 읽기에서 B200의 실측 최고치는 6.8 TB/s입니다(copy kernel로 측정). warp decode는 B=32에서 3.95 TB/s를 유지하며, 이는 해당 최고치의 58%입니다. 남아 있는 격차는 expert routing이 만들어내는 랜덤 접근 패턴의 메모리 지연 비용 때문일 가능성이 높습니다. 각 token이 5, 8, 14, 19처럼 서로 인접하지 않은 expert로 라우팅될 수 있기 때문입니다.

반면 최고 처리량은 연속적인 (0,1,2,3) 메모리 읽기를 사용해 측정합니다. 참조 구현과의 정확성은 모든 배치 크기에서 매우 높았습니다. 최소 코사인 유사도는 0.999996 초과였고, 최대 절대 차이는 0.001953이었습니다.

Warp decode와 Composer 학습

Warp decode는 expert 중심 실행을 전반적으로 대체하는 범용 기술이 아닙니다. prefill이나 대규모 배치 추론처럼 처리량이 많은 워크로드는 여전히 expert 중심 패킹의 이점을 누립니다. 많은 토큰이 동일한 expert를 공유하고, 이를 정리하는 데 드는 비용이 충분한 실제 연산 전반에 분산되어 그만한 가치가 있기 때문입니다.

expert별로 공유되는 작업량이 그 오버헤드를 감수할 만큼 충분하지 않을 때는 Warp decode가 더 효과적이며, 이는 MoE 디코드에서 흔히 나타나는 경우입니다. 그래서 Warp decode는 Composer를 지속적으로 개선하는 데 중요한 역할을 합니다. 사전 학습 데이터와 RL에 대한 투자가 모델 출력의 품질을 결정한다면, warp decode와 같은 추론 관련 투자는 그 출력이 얼마나 빠르고 정확하게 개발자에게 전달되는지를 결정합니다.

카테고리: 조사

작성자s: Less Wright, Federico Cassano & Zhiyuan Zhang