@@ -280,8 +280,8 @@ template <class Gemm> struct ExampleRunner {
280280 return passed;
281281 }
282282
283- // / Allocates device-side data
284- void allocate (const GroupGEMMOptions &options, const ElementA *block_A_ptr,
283+ // / Allocates device-side data for reference GEMM
284+ void allocate_for_ref_gemm (const GroupGEMMOptions &options, const ElementA *block_A_ptr,
285285 const ElementA *block_B_ptr, ElementOutput*block_C_ptr,
286286 int block_A_size, int block_B_size, int block_C_size) {
287287 int64_t total_elements_A = 0 ;
@@ -338,85 +338,10 @@ template <class Gemm> struct ExampleRunner {
338338 cumsum_device.copy_from_host (cumsum_host);
339339 }
340340
341- // / Initialize operands to be used in the GEMM and reference GEMM
342- void initialize_for_moe_gemm (const GroupGEMMOptions &options) {
341+ // / Initialize operands to be used in the reference GEMM
342+ void initialize (const GroupGEMMOptions &options) {
343343
344- problem_sizes.reset (options.groups );
345- problem_sizes.copy_from_host (options.problem_sizes_host .data ());
346-
347- //
348- // Assign pointers
349- //
350-
351- std::vector<ElementA *> ptr_A_host (1 );
352- std::vector<ElementB *> ptr_B_host (1 );
353- std::vector<ElementC *> ptr_C_host (1 );
354- std::vector<ElementOutput *> ptr_D_host (1 );
355- std::vector<ElementAccumulator *> ptr_alpha_host (options.groups );
356- std::vector<ElementAccumulator *> ptr_beta_host (options.groups );
357-
358- // Compute offsets, alpha & beta over group on host
359-
360- ptr_A_host.at (0 ) = block_A.get ();
361- ptr_B_host.at (0 ) = block_B.get ();
362- ptr_C_host.at (0 ) = block_C.get ();
363- ptr_D_host.at (0 ) = block_D.get ();
364- for (int32_t i = 0 ; i < options.groups ; ++i) {
365- // Fill host vector of alpha & beta with random values if using per-group
366- // values
367- alpha_host.push_back (
368- (options.alpha == FLT_MAX)
369- ? static_cast <ElementAccumulator>((rand () % 5 ) + 1 )
370- : options.alpha );
371- beta_host.push_back ((options.beta == FLT_MAX)
372- ? static_cast <ElementAccumulator>(rand () % 5 )
373- : options.beta );
374- // Fill host ptr vectors with offset addresses into device alpha/beta
375- // blocks
376- ptr_alpha_host.at (i) = block_alpha.get () + i;
377- ptr_beta_host.at (i) = block_beta.get () + i;
378- }
379-
380- // Allocate device memory & copy from host
381- ptr_A.reset (1 );
382- // Per-group alpha and beta
383- ptr_A.copy_from_host (ptr_A_host.data ());
384-
385- ptr_B.reset (1 );
386- ptr_B.copy_from_host (ptr_B_host.data ());
387-
388- ptr_C.reset (1 );
389- ptr_C.copy_from_host (ptr_C_host.data ());
390-
391- ptr_D.reset (1 );
392- ptr_D.copy_from_host (ptr_D_host.data ());
393-
394- stride_A.reset (options.groups );
395- stride_A.copy_from_host (stride_A_host.data ());
396-
397- stride_B.reset (options.groups );
398- stride_B.copy_from_host (stride_B_host.data ());
399-
400- stride_C.reset (options.groups );
401- stride_C.copy_from_host (stride_C_host.data ());
402-
403- stride_D.reset (options.groups );
404- stride_D.copy_from_host (stride_D_host.data ());
405-
406- // Per-group alpha and beta ptrs
407- alpha_device.reset (options.groups );
408- alpha_device.copy_from_host (ptr_alpha_host.data ());
409- beta_device.reset (options.groups );
410- beta_device.copy_from_host (ptr_beta_host.data ());
411-
412- // Per-group alpha and beta values - note these are not directly passed to
413- // kernel - the pointers (alpha_device/beta_device) are passed instead
414- block_alpha.copy_from_host (alpha_host.data ());
415- block_beta.copy_from_host (beta_host.data ());
416- }
417-
418- // / Initialize operands to be used in the GEMM and reference GEMM
419- void initialize_for_ref_gemm (const GroupGEMMOptions &options) {
344+ uint64_t seed = 2020 ;
420345
421346 problem_sizes.reset (options.groups );
422347 problem_sizes.copy_from_host (options.problem_sizes_host .data ());
@@ -438,8 +363,10 @@ template <class Gemm> struct ExampleRunner {
438363 ptr_B_host.at (i) = block_B.get () + offset_B.at (i);
439364 ptr_C_host.at (i) = block_C.get () + offset_C.at (i);
440365 ptr_D_host.at (i) = block_D.get () + offset_D.at (i);
441- // Fill host ptr vectors with offset addresses into device alpha/beta
442- // blocks
366+ // Fill host vector of alpha & beta with random values if using per-group values
367+ alpha_host.push_back ((options.alpha == FLT_MAX) ? static_cast <ElementAccumulator>((rand () % 5 ) + 1 ) : options.alpha );
368+ beta_host.push_back ((options.beta == FLT_MAX) ? static_cast <ElementAccumulator>(rand () % 5 ) : options.beta );
369+ // Fill host ptr vectors with offset addresses into device alpha/beta blocks
443370 ptr_alpha_host.at (i) = block_alpha.get () + i;
444371 ptr_beta_host.at (i) = block_beta.get () + i;
445372 }
@@ -475,9 +402,8 @@ template <class Gemm> struct ExampleRunner {
475402 alpha_device.copy_from_host (ptr_alpha_host.data ());
476403 beta_device.reset (options.groups );
477404 beta_device.copy_from_host (ptr_beta_host.data ());
478-
479- // Per-group alpha and beta values - note these are not directly passed to
480- // kernel - the pointers (alpha_device/beta_device) are passed instead
405+ // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers
406+ // (alpha_device/beta_device) are passed instead
481407 block_alpha.copy_from_host (alpha_host.data ());
482408 block_beta.copy_from_host (beta_host.data ());
483409 }
@@ -486,6 +412,9 @@ template <class Gemm> struct ExampleRunner {
486412 typename Gemm::Arguments
487413 args_from_options (const GroupGEMMOptions &options,
488414 const cutlass::KernelHardwareInfo &hw_info,
415+ const ElementA* A_ptr,
416+ const ElementB* B_ptr,
417+ ElementOutput* D_ptr,
489418 const int gemm_N,
490419 const int gemm_K) {
491420 typename Gemm::Arguments arguments;
@@ -494,8 +423,8 @@ template <class Gemm> struct ExampleRunner {
494423 if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
495424 // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., the same
496425 // alpha/beta applies to all batches. NOTE: alpha = 1 and beta = 0 are hard-coded below.
497- fusion_args.alpha = options. alpha ;
498- fusion_args.beta = options. beta ;
426+ fusion_args.alpha = 1 ;
427+ fusion_args.beta = 0 ;
499428 fusion_args.alpha_ptr = nullptr ;
500429 fusion_args.beta_ptr = nullptr ;
501430 fusion_args.alpha_ptr_array = nullptr ;
@@ -506,12 +435,12 @@ template <class Gemm> struct ExampleRunner {
506435 } else {
507436 // Otherwise at least one of alpha/beta is per-group (FLT_MAX), i.e., alpha/beta can
508437 // differ between batches/groups. NOTE: the per-group pointer arrays are set to nullptr below.
509- fusion_args.alpha = 0 ;
438+ fusion_args.alpha = 1 ;
510439 fusion_args.beta = 0 ;
511440 fusion_args.alpha_ptr = nullptr ;
512441 fusion_args.beta_ptr = nullptr ;
513- fusion_args.alpha_ptr_array = alpha_device. get () ;
514- fusion_args.beta_ptr_array = beta_device. get () ;
442+ fusion_args.alpha_ptr_array = nullptr ;
443+ fusion_args.beta_ptr_array = nullptr ;
515444 // One alpha and beta per each group
516445 fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1 };
517446 fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1 };
@@ -523,11 +452,11 @@ template <class Gemm> struct ExampleRunner {
523452 // Per-GEMM problem shape info may only exist on the device.
524453 if (host_problem_shapes_available) {
525454 arguments = typename Gemm::Arguments{
526- cutlass::gemm::GemmUniversalMode::kGrouped ,
527- ptr_A. get ( ),
528- ptr_B. get ( ),
529- nullptr ,
530- ptr_D. get ( ),
455+ cutlass::gemm::GemmUniversalMode::kGrouped , // this just means grouped GEMM
456+ static_cast < const ElementA**>(( void *)A_ptr ),
457+ static_cast < const ElementB**>(( void *)B_ptr ),
458+ static_cast < const ElementC**>(( void *)D_ptr), // we could also pass nullptr
459+ static_cast <ElementOutput**>(( void *)D_ptr ),
531460 fusion_args,
532461 hw_info,
533462 {1 , RasterOrderOptions::AlongN},
@@ -538,10 +467,10 @@ template <class Gemm> struct ExampleRunner {
538467 } else {
539468 arguments = typename Gemm::Arguments{
540469 cutlass::gemm::GemmUniversalMode::kGrouped ,
541- ptr_A. get ( ),
542- ptr_B. get ( ),
543- nullptr ,
544- ptr_D. get ( ),
470+ static_cast < const ElementA**>(( void *)A_ptr ),
471+ static_cast < const ElementB**>(( void *)B_ptr ),
472+ static_cast < const ElementC**>(( void *)D_ptr) ,
473+ static_cast <ElementOutput**>(( void *)D_ptr ),
545474 fusion_args,
546475 hw_info,
547476 {1 , RasterOrderOptions::AlongN},
@@ -557,12 +486,11 @@ template <class Gemm> struct ExampleRunner {
557486 cutlass::Status run (const GroupGEMMOptions &options,
558487 const cutlass::KernelHardwareInfo &hw_info,
559488 const ElementA *A_ptr, const ElementB *B_ptr,
560- ElementOutput *C_ptr, int A_size, int B_size, int D_size, const int gemm_n, const int gemm_k) {
561- allocate (options, A_ptr, B_ptr, C_ptr, A_size, B_size, D_size);
562- initialize_for_moe_gemm (options);
489+ ElementOutput *D_ptr, int A_size, int B_size, int D_size, const int gemm_n, const int gemm_k) {
490+ allocate_for_ref_gemm (options, A_ptr, B_ptr, D_ptr, A_size, B_size, D_size);
563491
564492 Gemm gemm_op;
565- auto arguments = args_from_options (options, hw_info, gemm_n, gemm_k);
493+ auto arguments = args_from_options (options, hw_info, A_ptr, B_ptr, D_ptr, gemm_n, gemm_k);
566494
567495 size_t workspace_size = Gemm::get_workspace_size (arguments);
568496 cutlass::device_memory::allocation<uint8_t > workspace (workspace_size);
@@ -575,15 +503,14 @@ template <class Gemm> struct ExampleRunner {
575503 CUTLASS_CHECK (gemm_op.run ());
576504
577505 syclcompat::wait ();
578- initialize_for_ref_gemm (options);
506+ initialize (options);
579507 // Verify that the result is correct
580508 bool passed = verify (options);
581509 std::cout << " Disposition: " << (passed ? " Passed" : " Failed" ) << std::endl;
582510 if (!passed)
583511 return cutlass::Status::kErrorInternal ;
584- initialize_for_moe_gemm (options);
585512 syclcompat::wait ();
586- arguments = args_from_options (options, hw_info, gemm_n, gemm_k);
513+ arguments = args_from_options (options, hw_info, A_ptr, B_ptr, D_ptr, gemm_n, gemm_k);
587514 CUTLASS_CHECK (gemm_op.can_implement (arguments));
588515
589516 CUTLASS_CHECK (gemm_op.initialize (arguments, workspace.get ()));
@@ -647,20 +574,15 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights,
647574 hw_info.device_id );
648575
649576 using LayoutA = cutlass::layout::RowMajor;
650- using LayoutB = cutlass::layout::ColumnMajor ;
577+ using LayoutB = cutlass::layout::RowMajor ;
651578 using LayoutC = cutlass::layout::RowMajor;
652579 using LayoutD = cutlass::layout::RowMajor;
653580
654581 using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
655- using GmemTiledCopyB = XE_2D_U16x16x16_LD_T ;
582+ using GmemTiledCopyB = XE_2D_U16x32x32_LD_V ;
656583
657584 // Workgroup-level tile
658585 using TileShape = Shape<_256, _256, _32>;
659- /*
660- using TiledMma =
661- TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
662- Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>;
663- */
664586
665587 using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
666588 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
@@ -710,18 +632,23 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights,
710632
711633
712634int main (int argc, const char **argv) {
713- const int num_experts = 32 ;
635+ const int num_experts = 16 ;
714636
715- int total_rows_for_each_expert[num_experts] = {
637+ /* int total_rows_for_each_expert[num_experts] = {
716638 148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, 845, 191, 424, 30, 97, 57, 324,
717- 62 , 77 , 75 , 144 , 250 , 287 , 629 , 370 , 161 , 101 , 215 , 113 , 224 , 35 };
639+ 62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35}; */
640+
641+ int total_rows_for_each_expert[num_experts];
642+ for (int i = 0 ; i < num_experts; i++) {
643+ total_rows_for_each_expert[i] = 512 ;
644+ }
718645
719646 int num_tokens_incl_duplicated = 0 ;
720647 for (int i = 0 ; i < num_experts; i++) {
721648 num_tokens_incl_duplicated += total_rows_for_each_expert[i];
722649 }
723- int n_moe = 3072 ;
724- int k_moe = 4096 ;
650+ int n_moe = 16384 ;
651+ int k_moe = 5120 ;
725652
726653 cutlass::DeviceAllocation<int32_t > num_rows_per_expert_device;
727654 cutlass::DeviceAllocation<bfloat16_t > activations_data;
0 commit comments