@@ -142,21 +142,28 @@ struct GroupGEMMOptions {
   }
 
   /// Compute performance in GFLOP/s
-  double gflops(double runtime_s,
+  std::tuple<double, double, double> gflops(double runtime_s,
                 std::vector<typename ProblemShape::UnderlyingProblemShape>
                     problem_sizes_host) const {
     // Number of real-valued multiply-adds
     uint64_t fmas = uint64_t();
+    uint64_t bytes_loaded = 0;
 
     for (auto const &problem : problem_sizes_host) {
-      fmas += static_cast<uint64_t>(get<0>(problem)) *
-              static_cast<uint64_t>(get<1>(problem)) *
-              static_cast<uint64_t>(get<2>(problem));
+      auto M = static_cast<uint64_t>(get<0>(problem));
+      auto N = static_cast<uint64_t>(get<1>(problem));
+      auto K = static_cast<uint64_t>(get<2>(problem));
+      fmas += M * N * K;
+      bytes_loaded += /* sizeof(cutlass::bfloat16_t) */ 2 * (2 * M * N + N * K + M * K);
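+      // NOTE (assumption): counts bf16 bytes (2 B/element) for A (M*K), B (N*K),
+      // and 2*M*N for the output, presumably a C read plus a D write per element.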
     }
     // Two flops per multiply-add
     uint64_t flop = uint64_t(2) * uint64_t(fmas);
     double gflop = double(flop) / double(1.0e9);
-    return gflop / runtime_s;
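+    // Roofline sketch: attainable FLOP/s = min(peak compute, arithmetic intensity * peak memory BW).
+    // The 117 TFLOP/s and 456 GB/s figures below appear to be device-specific peaks; adjust for the target GPU.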
+    double arithmetic_intensity = double(flop) / double(bytes_loaded);
+    double peak_mem_bw = 456.0;
+    double gflops_attainable = std::min<double>(117 * double(1.0e12), arithmetic_intensity * (peak_mem_bw * 1024 * 1024 * 1024));
+    double projected_time = flop / gflops_attainable;
+    return std::make_tuple(gflop / runtime_s, double(bytes_loaded) / 1024 / 1024 / 1024 / runtime_s, projected_time * 1000);
   }
 };
 
@@ -455,7 +462,7 @@ template <class Gemm> struct ExampleRunner {
         cutlass::gemm::GemmUniversalMode::kGrouped, // this just means grouped GEMM
         static_cast<const ElementA**>((void*)A_ptr),
         static_cast<const ElementB**>((void*)B_ptr),
-        static_cast<const ElementC**>((void*)D_ptr), // we could also pass nullptr
+        nullptr, // static_cast<const ElementC**>((void*)D_ptr), // we could also pass nullptr
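+        // (Assumption) nullptr is acceptable here because the epilogue presumably never reads the C source operand in this example.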
         static_cast<ElementOutput**>((void*)D_ptr),
         fusion_args,
         hw_info,
@@ -469,7 +476,7 @@ template <class Gemm> struct ExampleRunner {
         cutlass::gemm::GemmUniversalMode::kGrouped,
         static_cast<const ElementA**>((void*)A_ptr),
         static_cast<const ElementB**>((void*)B_ptr),
-        static_cast<const ElementC**>((void*)D_ptr),
+        nullptr, // static_cast<const ElementC**>((void*)D_ptr),
         static_cast<ElementOutput**>((void*)D_ptr),
         fusion_args,
         hw_info,
@@ -525,7 +532,7 @@ template <class Gemm> struct ExampleRunner {
 
     float cute_time = timer.seconds() * 1000;
     double cute_average_time = double(cute_time) / double(options.iterations);
-    double gflops = options.gflops(cute_average_time / 1000.0,
+    auto [gflops, mem_bw_util, projected_time] = options.gflops(cute_average_time / 1000.0,
                                    options.problem_sizes_host);
 
     std::cout << "Problem Sizes, Alpha, Beta " << std::endl;
@@ -538,6 +545,7 @@ template <class Gemm> struct ExampleRunner {
     std::cout << "  Avg runtime : " << cute_average_time << " ms"
               << std::endl;
     std::cout << "  GFLOPS : " << gflops << std::endl;
+    std::cout << "  Memory BW : " << mem_bw_util << " GB/s" << std::endl;
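+    // projected_time (the roofline-projected runtime in ms) is also returned and could be printed here for comparison.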
   }
 
   return cutlass::Status::kSuccess;
@@ -630,26 +638,23 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights,
   num_rows_per_expert_obj.release();
 }
 
-
-int main(int argc, const char **argv) {
-  const int num_experts = 16;
-
-  /* int total_rows_for_each_expert[num_experts] = {
-      148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, 845, 191, 424, 30, 97, 57, 324,
-      62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35}; */
-
-  int total_rows_for_each_expert[num_experts];
-  for (int i = 0; i < num_experts; i++) {
-    total_rows_for_each_expert[i] = 512;
-  }
-
+void launcher(int *M_per_expert, int N, int K, const int &num_experts) {
+  int n_moe = N;
+  int k_moe = K;
   int num_tokens_incl_duplicated = 0;
-  for (int i = 0; i < num_experts; i++) {
-    num_tokens_incl_duplicated += total_rows_for_each_expert[i];
+  for (int i = 0; i < num_experts; i++) {
+    num_tokens_incl_duplicated += M_per_expert[i];
   }
-  int n_moe = 16384;
-  int k_moe = 5120;
 
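+  // M-occupancy: average fill fraction of the 64-row M tiles across experts;
+  // values well below 1.0 indicate padded rows, i.e. wasted work in the grouped GEMM.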
+  float M_occupancy = 0.f;
+  float actual_num_units = 0.f;
+  int total_num_M_tiles = 0;
+  for (int i = 0; i < num_experts; i++) {
+    total_num_M_tiles += (M_per_expert[i] + 63) / 64;
+    actual_num_units += M_per_expert[i] / 64.0;
+  }
+  M_occupancy = actual_num_units / total_num_M_tiles;
+  std::cout << "\n\nM-occupancy is " << M_occupancy << std::endl;
   cutlass::DeviceAllocation<int32_t> num_rows_per_expert_device;
   cutlass::DeviceAllocation<bfloat16_t> activations_data;
   cutlass::DeviceAllocation<bfloat16_t> weights_data;
@@ -658,7 +663,7 @@ int main(int argc, const char **argv) {
   size_t B_size = num_experts * n_moe * k_moe;
   size_t D_size = num_tokens_incl_duplicated * n_moe;
   num_rows_per_expert_device.reset(num_experts);
-  num_rows_per_expert_device.copy_from_host(total_rows_for_each_expert);
+  num_rows_per_expert_device.copy_from_host(M_per_expert);
   activations_data.reset(A_size);
   weights_data.reset(B_size);
   output_data.reset(D_size);
@@ -672,5 +677,44 @@ int main(int argc, const char **argv) {
   weights_data.release();
   output_data.release();
   num_rows_per_expert_device.release();
+}
+
+
+int main(int argc, const char **argv) {
+  constexpr int num_experts = 32;
+  constexpr int num_layers = 24;
+
+  int total_rows_for_each_expert[num_layers][num_experts] = {
+      {148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, 845, 191, 424, 30, 97, 57, 324, 62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35},
+      {666, 214, 448, 87, 4, 28, 48, 13, 74, 40, 546, 397, 487, 350, 26, 95, 517, 487, 295, 58, 637, 97, 139, 33, 126, 15, 352, 311, 995, 193, 135, 135},
+      {1016, 30, 36, 452, 469, 473, 232, 0, 493, 14, 954, 6, 4, 6, 279, 3, 94, 106, 96, 48, 49, 113, 142, 169, 75, 99, 25, 220, 249, 289, 4, 1803},
+      {350, 229, 703, 154, 8, 64, 80, 339, 2, 56, 5, 312, 1005, 29, 9, 11, 23, 0, 23, 431, 48, 129, 496, 476, 8, 1234, 7, 130, 34, 58, 41, 1554},
+      {39, 10, 6, 2, 110, 1, 894, 8, 53, 0, 275, 6, 506, 421, 700, 178, 0, 530, 1623, 15, 231, 74, 6, 222, 1246, 116, 35, 20, 0, 6, 381, 334},
+      {399, 5, 201, 6, 134, 93, 1748, 1, 51, 4, 38, 336, 53, 88, 328, 724, 15, 388, 706, 52, 19, 55, 52, 33, 623, 1, 222, 215, 69, 45, 308, 1036},
+      {11, 8, 407, 571, 458, 275, 197, 211, 13, 564, 462, 114, 15, 13, 132, 24, 514, 2, 71, 13, 694, 47, 16, 203, 610, 40, 0, 1587, 66, 23, 196, 491},
+      {0, 230, 116, 136, 315, 643, 6, 183, 37, 26, 960, 1, 8, 258, 21, 1602, 213, 198, 6, 196, 455, 557, 47, 282, 493, 18, 101, 11, 616, 45, 268, 0},
+      {392, 305, 179, 14, 227, 98, 114, 39, 64, 1456, 465, 0, 18, 372, 0, 0, 189, 257, 25, 290, 486, 0, 12, 1534, 468, 4, 555, 35, 146, 0, 161, 143},
+      {4, 107, 20, 125, 236, 898, 0, 0, 375, 2, 125, 0, 0, 1429, 36, 195, 1660, 0, 127, 454, 73, 358, 47, 79, 32, 20, 1465, 0, 0, 6, 109, 66},
+      {19, 0, 0, 0, 2, 1638, 75, 135, 392, 2, 1494, 3, 23, 5, 4, 58, 0, 0, 71, 1285, 8, 441, 0, 145, 209, 408, 450, 2, 824, 13, 326, 16},
+      {4, 2, 14, 0, 30, 206, 41, 131, 0, 429, 16, 895, 35, 21, 44, 128, 12, 0, 417, 0, 838, 917, 42, 115, 109, 1759, 0, 36, 17, 0, 1790, 0},
+      {6, 483, 241, 1327, 17, 11, 480, 9, 880, 58, 4, 0, 61, 30, 16, 176, 9, 309, 26, 0, 0, 1882, 4, 281, 475, 783, 197, 0, 19, 15, 6, 243},
+      {370, 1222, 0, 6, 108, 929, 2, 7, 157, 348, 149, 106, 2, 5, 25, 33, 1569, 8, 6, 106, 69, 1298, 0, 2, 529, 520, 0, 421, 0, 25, 26, 0},
+      {59, 89, 0, 26, 25, 40, 1873, 141, 527, 371, 262, 62, 16, 0, 127, 234, 1637, 64, 132, 8, 0, 7, 161, 1005, 22, 1, 49, 6, 83, 925, 80, 16},
+      {269, 617, 30, 4, 90, 26, 0, 16, 154, 212, 21, 269, 379, 174, 129, 32, 8, 121, 344, 15, 0, 591, 1494, 6, 737, 50, 112, 856, 483, 25, 454, 330},
+      {0, 98, 1488, 22, 73, 0, 0, 343, 77, 4, 0, 612, 165, 268, 4, 10, 43, 0, 598, 271, 2, 73, 185, 0, 112, 779, 24, 1626, 0, 0, 0, 1171},
+      {0, 0, 0, 189, 266, 1743, 0, 462, 20, 7, 668, 310, 40, 0, 10, 236, 423, 18, 0, 0, 0, 999, 0, 139, 1754, 8, 619, 3, 23, 0, 102, 9},
+      {131, 1753, 0, 113, 24, 94, 2, 12, 108, 0, 0, 252, 97, 0, 1319, 233, 93, 1254, 195, 152, 14, 413, 4, 2, 220, 67, 20, 4, 34, 559, 837, 42},
+      {55, 76, 0, 8, 0, 3, 1557, 975, 135, 271, 4, 0, 0, 666, 207, 152, 5, 2, 97, 364, 0, 13, 1423, 771, 159, 31, 223, 0, 431, 7, 409, 4},
+      {4, 1026, 1799, 166, 694, 753, 0, 16, 0, 240, 1119, 19, 6, 0, 46, 659, 10, 0, 112, 808, 181, 0, 28, 22, 90, 0, 176, 0, 37, 5, 10, 22},
+      {44, 0, 4, 153, 299, 1357, 6, 23, 0, 12, 4, 419, 73, 24, 16, 24, 1, 4, 4, 102, 16, 4, 0, 1953, 1850, 0, 908, 4, 0, 13, 708, 23},
+      {6, 13, 123, 28, 197, 0, 202, 69, 0, 6, 0, 21, 1434, 1582, 11, 0, 6, 0, 7, 190, 4, 1700, 6, 434, 1886, 0, 14, 28, 8, 30, 25, 18},
+      {5, 27, 1442, 18, 0, 6, 0, 73, 6, 781, 0, 1915, 291, 649, 98, 4, 33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9}
+  };
+
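+  // (Assumption) The two launches per layer presumably correspond to the up-projection (N=5760, K=2880)
+  // and down-projection (N=2880, K=2880) GEMMs of each MoE block, with per-expert row counts from the table above.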
+  for (int i = 0; i < num_layers; i++) {
+    launcher(total_rows_for_each_expert[i], 5760, 2880, num_experts);
+    launcher(total_rows_for_each_expert[i], 2880, 2880, num_experts);
+  }
+
   return 0;
 }