अनुसंधान

warp decode के साथ बेहतर MoE मॉडल अनुमिति

समानांतरता अक्ष को पलटकर हम 1.8x तेज़ और अधिक सटीक MoE मॉडल अनुमिति हासिल करते हैं।

Less Wright, Federico Cassano & Zhiyuan Zhang12 मिनट में पढ़ें

ज़्यादातर MoE अनुमिति सिस्टम token जनरेशन पाथ को एक्सपर्ट्स के इर्द-गिर्द व्यवस्थित करते हैं। यह रूटिंग के काम करने के तरीके को दर्शाता है और बड़े पैमाने पर यही मानक तरीका रहा है। हालांकि, Blackwell GPUs पर छोटे-बैच डीकोड के लिए, हमें यह मिला कि एक्सपर्ट्स के बजाय आउटपुट के इर्द-गिर्द कर्नेल को व्यवस्थित करना बेहतर काम करता है। हम इस तरीके को “warp decode” कहते हैं।

हम warp decode तक इस बारे में सोचते हुए पहुँचे कि Blackwell पर MoE डिकोड के लिए वास्तव में अधिकतम हासिल की जा सकने वाली मेमोरी बैंडविड्थ क्या है। इससे हमने समानांतरता अक्ष को पूरी तरह पलट दिया। वार्प्स को एक्सपर्ट्स को सौंपने के बजाय, हम हर वार्प को एक ही आउटपुट मान (न्यूरॉन) सौंपते हैं।

ऐसे कर्नेल, जो प्रदर्शन और सटीकता दोनों में सुधार करें, बहुत कम मिलते हैं, और warp decode उनमें से एक है। Blackwell पर यह 1.84x थ्रूपुट सुधार देता है, साथ ही सटीकता भी बढ़ाता है, जिसमें आउटपुट पूर्ण FP32 संदर्भ के 1.4x अधिक करीब होते हैं। इससे Composer के लिए अनुसंधान और प्रशिक्षण पाइपलाइन तेज़ होती है, जिससे हम मॉडल को और तेज़ी से बेहतर बना पाते हैं और नए संस्करण अधिक बार शिप कर पाते हैं।

पारंपरिक MoE पथ

आधुनिक MoE मॉडल हर token को विशेषीकृत एक्सपर्ट नेटवर्क्स के एक उपसमूह से होकर रूट करते हैं — उदाहरण के लिए, किसी दिए गए लेयर पर 128 में से 8 एक्सपर्ट चुने जाते हैं। मानक कार्यान्वयन पूरी गणना को इन्हीं एक्सपर्ट्स के इर्द-गिर्द व्यवस्थित करता है: हर एक्सपर्ट को जिन टोकन की ज़रूरत होती है, उन्हें इकट्ठा करना, गणना चलाना, और फिर परिणामों को दोबारा जोड़ना।

यह प्रीफिल और बड़े बैचों के लिए अच्छी तरह काम करता है, जहाँ प्रति एक्सपर्ट होने वाला साझा काम डेटा को व्यवस्थित करने के ओवरहेड की भरपाई कर देता है। लेकिन ऑटोरेग्रेसिव डीकोड चरण के दौरान, जहाँ हम एक बार में केवल एक token उत्पन्न करते हैं, इसे उचित ठहराने लायक साझा काम नहीं होता। पारंपरिक पथ के आठ में से पाँच चरण सिर्फ विशेषज्ञ-केंद्रित दृश्य के लिए डेटा लेआउट को प्रबंधित करने के लिए होते हैं और इनमें कोई वास्तविक गणना नहीं होती।

हमने क्या बदला

Warp decode, समानांतरता को एक्सपर्ट्स के बजाय आउटपुट के इर्द-गिर्द पुनर्गठित करके, उन पाँच “बुककीपिंग” चरणों को खत्म कर देता है।

आधुनिक GPU निर्देशों को 32 समानांतर प्रोसेसिंग लेन के समूहों में चलाते हैं, जिन्हें वार्प कहा जाता है। हमारे नए तरीके में, हर वार्प को ठीक एक आउटपुट मान की गणना करने के लिए असाइन किया जाता है। वार्प अपने लिए आवश्यक वेट डेटा को सीधे मेमोरी से स्ट्रीम करता है, सभी आठ रूट किए गए एक्सपर्ट्स के योग को एक ही रनिंग टोटल में जोड़ता है, और एक परिणाम लिखता है।

यह वॉर्प इंडिपेंडेंस, warp decode को बिना किसी स्टेजिंग, हैंडऑफ़, क्रॉस-वॉर्प सिंक पॉइंट, या मध्यवर्ती बफ़र के चलने देता है। पूरी MoE compute layer को समेटकर दो कर्नेल में ला दिया गया है: moe_gate_up_3d_batched और moe_down_3d_batched

दो कर्नेल कैसे काम करते हैं

gate/up कर्नेल में, प्रत्येक cooperative thread array (CTA) में आठ वार्प होते हैं, और प्रत्येक वार्प के हिस्से में token और रूट किए गए एक्सपर्ट की हर जोड़ी के लिए एक मध्यवर्ती न्यूरॉन आता है। वार्प रूट किए गए एक्सपर्ट की ID लोड करता है, उस न्यूरॉन के लिए gate और up वेट की पंक्तियाँ पढ़ता है, और इनपुट सक्रियण वेक्टर पर स्ट्रीम करता है। MXFP8 वेट्स को चलते-चलते FP32 में रूपांतरित किया जाता है, और दोनों dot product निजी रजिस्टरों में एकत्रित होते हैं।

क्योंकि दोनों कर्नेल एक ही pass में fused हैं, सक्रियण वेक्टर को एक बार पढ़ा जाता है और shared memory स्टेजिंग के बिना तुरंत दोनों projection के लिए फिर से इस्तेमाल कर लिया जाता है। वार्प-स्तरीय रिडक्शन के बाद, वार्प SiLU(gate) × up लागू करता है और एक मध्यवर्ती मान लिखता है।

down कर्नेल में, प्रत्येक वार्प एक token के लिए एक आउटपुट आयाम का जिम्मा संभालता है। यह सभी शीर्ष-k रूट किए गए एक्सपर्ट्स पर लूप करता है, प्रासंगिक down-projection वेट पंक्ति को लोड करता है और मध्यवर्ती एक्टिवेशन्स पर स्ट्रीम करता है, साथ ही हर एक्सपर्ट के रूटिंग वेट को एक ही चल रहे FP32 एक्यूम्युलेटर में जोड़ता जाता है।

सभी एक्सपर्ट्स की प्रक्रिया पूरी होने के बाद, हम __shfl_xor_sync का इस्तेमाल करके वार्प-स्तरीय butterfly रिडक्शन से 32 लेन-स्थानीय आंशिक योगों को reduce करते हैं। यह सीधे PTX निर्देश shfl.sync.bfly में compile होता है, जो एकल hardware primitive है और shared memory को पूरी तरह bypass करते हुए वार्प के भीतर लेन के बीच रजिस्टरों का आदान-प्रदान करता है।

यहाँ फायदा यह है कि हमें L1 round-trips, bank conflicts, या explicit barriers की ज़रूरत नहीं पड़ती, क्योंकि synchronization लेन mask के ज़रिए निर्देश में ही शामिल है। अलग epilogue की बजाय, अंतिम weighted शीर्ष-k संयोजन projection का ही हिस्सा बन जाता है।

warp decode में प्रत्येक वार्प स्वतंत्र होता है और उसे अपने पूरे lifetime के लिए एक ही स्थिर assignment मिलता है: एक आउटपुट स्केलर तैयार करना। यही वॉर्प इंडिपेंडेंस shared memory स्टेजिंग, cross-warp synchronization, और उन मध्यवर्ती बफ़र को खत्म कर देती है जिनकी पारंपरिक पथ को आवश्यकता होती है।

पाइपलाइन का सरलीकरण और त्वरण

Warp decode दो मुख्य तरीकों से प्रदर्शन में सुधार लाता है: पारंपरिक पथ में आवश्यक चरणों और बफ़रों को हटाकर, और वॉर्प इंडिपेंडेंस बनाकर, जो बेहतर अनुसूचीकरण और विलंबता को छिपाने में मदद करता है।

चरण हटाना

चरणों को हटाने से थ्रूपुट वृद्धि का अधिकांश लाभ मिलता है। हम padding, scattering, और संयोजन चरण को हटाते हैं। इन चरणों को हटाने के लिए, पारंपरिक पाइपलाइन के चरणों को केवल मिलाने भर के बजाय, समानांतरता को शुरू से ही पुनर्गठित करना पड़ता है।

पैडिंग को समाप्त करना

पारंपरिक पथ: grouped कर्नेल की आवश्यकताओं के अनुरूप, हर एक्सपर्ट की token सूची को 2 की घात या 128 byte की सीमाओं तक पैड करता है। डीकोड के समय, जब केवल एक token होता है, तो यह ऐसा ओवरहेड बन जाता है जिसकी भरपाई नहीं की जा सकती।

Warp decode पथ: प्रति-एक्सपर्ट बैच कभी बनाए ही नहीं जाते, इसलिए यह इस ओवरहेड से पूरी तरह बचता है।

scatter और combine को हटाना

पारंपरिक पथ: हर एक्सपर्ट के काम पूरा करने के बाद, वह GPU मेमोरी में आठ मध्यवर्ती परिणाम लिखता है, फिर उन्हें संयोजित करने के लिए एक अलग रिडक्शन चरण चलाया जाता है।

Warp decode पथ: हर एक्सपर्ट का रूटिंग वेट वार्प के भीतर चल रहे एक्यूम्युलेटर में ही समाहित कर दिया जाता है। ये आठ मध्यवर्ती परिणाम कभी मेमोरी में बनते ही नहीं, जिससे बाद के रिडक्शन पास में लिखने और पढ़ने—दोनों की लागत बच जाती है।

बफ़र हटाना

यह पुनर्गठन उन दो मध्यवर्ती मेमोरी बफ़रों को भी हटा देता है, जिनकी ज़रूरत पारंपरिक पथ में उसके विशेषज्ञ-केंद्रित लेआउट के कारण पड़ती है।

पहला एक सक्रियण gather बफ़र है, यानी इनपुट सक्रियण वेक्टर, जिसे कॉपी करके एक्सपर्ट-मेजर लेआउट में फिर से व्यवस्थित किया जाता है। बैच साइज़ 1 पर, यह उस डेटा की पूरी कॉपी है जो पहले से मौजूद है। दूसरा प्रति-एक्सपर्ट आउटपुट बफ़र है। आठ एक्सपर्ट्स और hidden dimension 2048 के साथ, यह BF16 में प्रति token 8 × 2048 × 2 bytes = 32 KB होता है, जिसे allocate किया जाता है, लिखा जाता है, तुरंत एक बार पढ़ा जाता है, और फिर छोड़ दिया जाता है।

Warp decode दोनों को खत्म कर देता है, क्योंकि यह 32 वार्प लेन में आठ एक्सपर्ट्स के योगदान को एक register एक्यूम्युलेटर में समेट देता है, जहाँ अंतिम single-scalar write तक कुछ भी global memory में नहीं जाता। प्रति token 32+ KB मध्यवर्ती बफ़र ट्रैफ़िक हटाने से L2 cache की क्षमता उन weight rows के लिए खाली हो जाती है, जो वास्तव में प्रदर्शन तय करती हैं।

वॉर्प इंडिपेंडेंस

यह पुनर्गठन बची हुई गणना को और तेज़ बना देता है, क्योंकि डिज़ाइन के अनुसार यह कर्नेल “embarrassingly parallel” है: हर वार्प दूसरे हर वार्प से पूरी तरह स्वतंत्र है। चूँकि हर वार्प ठीक एक आउटपुट स्केलर संभालता है और केवल वही वेट पंक्तियाँ पढ़ता है जिनकी उसे आवश्यकता होती है, इसलिए वार्प्स के बीच कोई साझा परिवर्तनीय अवस्था नहीं होती।

एक अकेले वार्प के स्तर पर यह स्वतंत्रता पूरी तरह बनी रहती है। इनपुट एक्टिवेशन्स केवल-पढ़ने योग्य होते हैं, एक्यूम्युलेटर निजी रजिस्टरों में रहता है, और आउटपुट राइट एक अद्वितीय पते पर होती है। हार्डवेयर शेड्यूलर के नज़रिए से, पूरा आउटपुट डाइमेंशन स्वतंत्र कार्य-आइटमों का एक समतल पूल है।

GPU का वार्प शेड्यूलर किसी भी समय, किसी भी क्रम में, कोई भी वार्प जारी कर सकता है, और शुद्धता के लिहाज़ से उस पर कोई बाध्यता नहीं होती। जब एक वार्प मेमोरी लोड की प्रतीक्षा में रुक जाता है, तो शेड्यूलर तुरंत दूसरे वार्प पर स्विच कर जाता है। B200 के 148 स्ट्रीमिंग मल्टीप्रोसेसर्स में एक साथ चल रहे हज़ारों वार्प्स के साथ, मेमोरी विलंबता लगभग पूरी तरह दूसरे वार्प्स की उपयोगी गणना के पीछे छिप जाती है।

यह कर्नेल रैखिक रूप से स्केल भी करता है, यानी आउटपुट डाइमेंशन दोगुना करने पर स्वतंत्र वार्प्स की संख्या भी दोगुनी हो जाती है और कोई अतिरिक्त सिंक्रोनाइज़ेशन नहीं जुड़ता। token batch डाइमेंशन में भी यही बात लागू होती है, इसलिए शेड्यूलर को कार्य का एक समतल नेमस्पेस दिखता है, जिसमें नोड्स के बीच कोई किनारे नहीं होते। यह पारंपरिक पथ के विपरीत है, जहाँ एक्सपर्ट-स्तरीय GEMM कर्नेल्स को intra-block समन्वय की आवश्यकता होती है।

परिणाम

बड़े पैमाने पर एंड-टू-एंड डीकोड थ्रूपुट

हमारी आंतरिक अनुमिति प्रणाली पर NVIDIA B200 GPUs के साथ Qwen-3 शैली का मॉडल चलाकर की गई Testing में सुसंगत थ्रूपुट वृद्धि मिली। यह थ्रूपुट वृद्धि सभी संदर्भ-लंबाई बकेट्स में एक-सी रहती है, जिससे पुष्टि होती है कि यह केवल जनरेशन-समय का सुधार है और प्रॉम्प्ट की लंबाई पर निर्भर नहीं करता।

सुधारित सटीकता

मध्यवर्ती सक्रियण क्वांटाइज़ेशन चरण को हटाने से गुणवत्ता पर मापने योग्य प्रभाव पड़ता है। BF16 एक्टिवेशन्स को MXFP8 में और फिर वापस रूपांतरित करने से राउंडिंग त्रुटि की एक न्यूनतम सीमा बनती है, जो मॉडल की परतों में जमा होती जाती है। warp decode पूरे समय एक्टिवेशन्स को BF16 में और एक्यूम्युलेटर्स को FP32 में रखता है, इसलिए रिडक्शन कभी भी खराब हो चुके इनपुट्स पर काम नहीं करता। नतीजतन, warp decode ऐसे आउटपुट देता है जो क्लासिकल पथ की तुलना में पूर्ण 32-बिट ग्राउंड ट्रुथ के 1.4x अधिक करीब होते हैं।

हार्डवेयर दक्षता

हमने यह सवाल पूछते हुए warp decode विकसित करना शुरू किया कि हम हार्डवेयर के अधिकतम थ्रूपुट के कितने करीब पहुँच सकते हैं। कॉन्टिग्युअस मेमोरी रीड्स के लिए B200 का मापा गया पीक 6.8 TB/s है (इसे copy कर्नेल इस्तेमाल करके मापा गया है)। B=32 पर warp decode 3.95 TB/s बनाए रखता है, यानी उस पीक का 58%। बचा हुआ अंतर संभवतः उन रैंडम पहुँच पैटर्न की मेमोरी लेटेंसी लागत को दर्शाता है, जो एक्सपर्ट रूटिंग बनाती है, क्योंकि हर token को 5, 8, 14, 19 आदि जैसे गैर-सन्निहित एक्सपर्ट्स तक रूट किया जा सकता है।

इसके विपरीत, पीक थ्रूपुट को कॉन्टिग्युअस (0,1,2,3) मेमोरी रीड्स इस्तेमाल करके मापा जाता है। संदर्भ इम्प्लीमेंटेशन के मुकाबले शुद्धता सभी बैच साइज़ में बहुत सटीक रही: न्यूनतम कोसाइन समानता > 0.999996, अधिकतम निरपेक्ष अंतर 0.001953।

Warp decode और Composer प्रशिक्षण

Warp decode विशेषज्ञ-केंद्रित निष्पादन का कोई सार्वभौमिक विकल्प नहीं है। प्रीफिल और बड़े-बैच अनुमिति जैसे उच्च-मात्रा वाले वर्कलोड अब भी विशेषज्ञ-केंद्रित पैकिंग से लाभ उठाते हैं, क्योंकि कई टोकन एक ही एक्सपर्ट को साझा करते हैं, और उन्हें व्यवस्थित करने की लागत पर्याप्त वास्तविक गणना पर बँट जाती है, जिससे यह तरीका फायदेमंद बनता है।

Warp decode तब बेहतर काम करता है, जब प्रति एक्सपर्ट इतना साझा कार्य नहीं होता कि उस ओवरहेड को उचित ठहराया जा सके, जैसा कि अक्सर MoE डिकोड में होता है। इसी वजह से यह Composer को लगातार बेहतर बनाने के हमारे तरीके का एक अहम हिस्सा है। जहाँ प्रीट्रेनिंग डेटा और RL में निवेश मॉडल के आउटपुट की गुणवत्ता तय करते हैं, वहीं Warp decode जैसे अनुमिति-संबंधी निवेश यह तय करते हैं कि वे आउटपुट डेवलपर्स तक कितनी तेज़ी और सटीकता से पहुँचते हैं।