Skip to content

Commit d29944f

Browse files
committed
Update 11_bmg_moe_gemm_bf16.cpp
1 parent dad2193 commit d29944f

File tree

1 file changed

+70
-26
lines changed

1 file changed

+70
-26
lines changed

examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -142,21 +142,28 @@ struct GroupGEMMOptions {
142142
}
143143

144144
/// Compute performance in GFLOP/s
145-
double gflops(double runtime_s,
145+
std::tuple<double, double, double> gflops(double runtime_s,
146146
std::vector<typename ProblemShape::UnderlyingProblemShape>
147147
problem_sizes_host) const {
148148
// Number of real-valued multiply-adds
149149
uint64_t fmas = uint64_t();
150+
uint64_t bytes_loaded = 0;
150151

151152
for (auto const &problem : problem_sizes_host) {
152-
fmas += static_cast<uint64_t>(get<0>(problem)) *
153-
static_cast<uint64_t>(get<1>(problem)) *
154-
static_cast<uint64_t>(get<2>(problem));
153+
auto M = static_cast<uint64_t>(get<0>(problem));
154+
auto N = static_cast<uint64_t>(get<1>(problem));
155+
auto K = static_cast<uint64_t>(get<2>(problem));
156+
fmas += M * N * K;
157+
bytes_loaded += /* sizeof(cutlass::bfloat16_t) */ 2 * (2 * M * N + N * K + M * K);
155158
}
156159
// Two flops per multiply-add
157160
uint64_t flop = uint64_t(2) * uint64_t(fmas);
158161
double gflop = double(flop) / double(1.0e9);
159-
return gflop / runtime_s;
162+
double arithmetic_intensity = double(flop) / double(bytes_loaded);
163+
double peak_mwm_bw = 456.0;
164+
double gflops_attainable = std::min<double>(117 * double(1.0e12), arithmetic_intensity * (peak_mwm_bw * 1024 * 1024 * 1024));
165+
double projected_time = flop/gflops_attainable;
166+
return std::make_tuple(gflop / runtime_s, double(bytes_loaded) / 1024 / 1024 / 1024 / runtime_s, projected_time * 1000);
160167
}
161168
};
162169

@@ -455,7 +462,7 @@ template <class Gemm> struct ExampleRunner {
455462
cutlass::gemm::GemmUniversalMode::kGrouped, // this just means grouped GEMM
456463
static_cast<const ElementA**>((void*)A_ptr),
457464
static_cast<const ElementB**>((void*)B_ptr),
458-
static_cast<const ElementC**>((void*)D_ptr), // we could also pass nullptr
465+
nullptr,//static_cast<const ElementC**>((void*)D_ptr), // we could also pass nullptr
459466
static_cast<ElementOutput**>((void*)D_ptr),
460467
fusion_args,
461468
hw_info,
@@ -469,7 +476,7 @@ template <class Gemm> struct ExampleRunner {
469476
cutlass::gemm::GemmUniversalMode::kGrouped,
470477
static_cast<const ElementA**>((void*)A_ptr),
471478
static_cast<const ElementB**>((void*)B_ptr),
472-
static_cast<const ElementC**>((void*)D_ptr),
479+
nullptr, // static_cast<const ElementC**>((void*)D_ptr),
473480
static_cast<ElementOutput**>((void*)D_ptr),
474481
fusion_args,
475482
hw_info,
@@ -525,7 +532,7 @@ template <class Gemm> struct ExampleRunner {
525532

526533
float cute_time = timer.seconds() * 1000;
527534
double cute_average_time = double(cute_time) / double(options.iterations);
528-
double gflops = options.gflops(cute_average_time / 1000.0,
535+
auto [gflops, mem_bw_util, projected_time] = options.gflops(cute_average_time / 1000.0,
529536
options.problem_sizes_host);
530537

531538
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
@@ -538,6 +545,7 @@ template <class Gemm> struct ExampleRunner {
538545
std::cout << " Avg runtime : " << cute_average_time << " ms"
539546
<< std::endl;
540547
std::cout << " GFLOPS : " << gflops << std::endl;
548+
std::cout << " Memory BW utilization : " << mem_bw_util << " GBPs" << std::endl;
541549
}
542550

543551
return cutlass::Status::kSuccess;
@@ -630,26 +638,23 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights,
630638
num_rows_per_expert_obj.release();
631639
}
632640

633-
634-
int main(int argc, const char **argv) {
635-
const int num_experts = 16;
636-
637-
/* int total_rows_for_each_expert[num_experts] = {
638-
148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, 845, 191, 424, 30, 97, 57, 324,
639-
62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35}; */
640-
641-
int total_rows_for_each_expert[num_experts];
642-
for (int i = 0; i < num_experts; i++) {
643-
total_rows_for_each_expert[i] = 512;
644-
}
645-
641+
void launcher(int* M_per_expert, int N, int K, const int& num_experts) {
642+
int n_moe = N;
643+
int k_moe = K;
646644
int num_tokens_incl_duplicated = 0;
647-
for (int i = 0; i < num_experts; i++) {
648-
num_tokens_incl_duplicated += total_rows_for_each_expert[i];
645+
for(int i=0; i < num_experts; i++) {
646+
num_tokens_incl_duplicated += M_per_expert[i];
649647
}
650-
int n_moe = 16384;
651-
int k_moe = 5120;
652648

649+
float M_occupancy = 0.f;
650+
float actual_num_units = 0.f;
651+
int total_num_M_tiles = 0;
652+
for (int i=0; i < num_experts; i++) {
653+
total_num_M_tiles += (M_per_expert[i] + 63)/64;
654+
actual_num_units += M_per_expert[i]/64.0;
655+
}
656+
M_occupancy = actual_num_units / total_num_M_tiles;
657+
std::cout << "\n\n M-occupancy is " << M_occupancy << std::endl;
653658
cutlass::DeviceAllocation<int32_t> num_rows_per_expert_device;
654659
cutlass::DeviceAllocation<bfloat16_t> activations_data;
655660
cutlass::DeviceAllocation<bfloat16_t> weights_data;
@@ -658,7 +663,7 @@ int main(int argc, const char **argv) {
658663
size_t B_size = num_experts * n_moe * k_moe;
659664
size_t D_size = num_tokens_incl_duplicated * n_moe;
660665
num_rows_per_expert_device.reset(num_experts);
661-
num_rows_per_expert_device.copy_from_host(total_rows_for_each_expert);
666+
num_rows_per_expert_device.copy_from_host(M_per_expert);
662667
activations_data.reset(A_size);
663668
weights_data.reset(B_size);
664669
output_data.reset(D_size);
@@ -672,5 +677,44 @@ int main(int argc, const char **argv) {
672677
weights_data.release();
673678
output_data.release();
674679
num_rows_per_expert_device.release();
680+
}
681+
682+
683+
int main(int argc, const char **argv) {
684+
constexpr int num_experts = 32;
685+
constexpr int num_layers = 24;
686+
687+
int total_rows_for_each_expert[num_layers][num_experts] = {
688+
{148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, 845, 191, 424, 30, 97, 57, 324, 62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35},
689+
{666, 214, 448, 87, 4, 28, 48, 13, 74, 40, 546, 397, 487, 350, 26, 95, 517, 487, 295, 58, 637, 97, 139, 33, 126, 15, 352, 311, 995, 193, 135, 135},
690+
{1016, 30, 36, 452, 469, 473, 232, 0, 493, 14, 954, 6, 4, 6, 279, 3, 94, 106, 96, 48, 49, 113, 142, 169, 75, 99, 25, 220, 249, 289, 4, 1803},
691+
{350, 229, 703, 154, 8, 64, 80, 339, 2, 56, 5, 312, 1005, 29, 9, 11, 23, 0, 23, 431, 48, 129, 496, 476, 8, 1234, 7, 130, 34, 58, 41, 1554},
692+
{39, 10, 6, 2, 110, 1, 894, 8, 53, 0, 275, 6, 506, 421, 700, 178, 0, 530, 1623, 15, 231, 74, 6, 222, 1246, 116, 35, 20, 0, 6, 381, 334},
693+
{399, 5, 201, 6, 134, 93, 1748, 1, 51, 4, 38, 336, 53, 88, 328, 724, 15, 388, 706, 52, 19, 55, 52, 33, 623, 1, 222, 215, 69, 45, 308, 1036},
694+
{11, 8, 407, 571, 458, 275, 197, 211, 13, 564, 462, 114, 15, 13, 132, 24, 514, 2, 71, 13, 694, 47, 16, 203, 610, 40, 0, 1587, 66, 23, 196, 491},
695+
{0, 230, 116, 136, 315, 643, 6, 183, 37, 26, 960, 1, 8, 258, 21, 1602, 213, 198, 6, 196, 455, 557, 47, 282, 493, 18, 101, 11, 616, 45, 268, 0},
696+
{392, 305, 179, 14, 227, 98, 114, 39, 64, 1456, 465, 0, 18, 372, 0, 0, 189, 257, 25, 290, 486, 0, 12, 1534, 468, 4, 555, 35, 146, 0, 161, 143},
697+
{4, 107, 20, 125, 236, 898, 0, 0, 375, 2, 125, 0, 0, 1429, 36, 195, 1660, 0, 127, 454, 73, 358, 47, 79, 32, 20, 1465, 0, 0, 6, 109, 66},
698+
{19, 0, 0, 0, 2, 1638, 75, 135, 392, 2, 1494, 3, 23, 5, 4, 58, 0, 0, 71, 1285, 8, 441, 0, 145, 209, 408, 450, 2, 824, 13, 326, 16},
699+
{4, 2, 14, 0, 30, 206, 41, 131, 0, 429, 16, 895, 35, 21, 44, 128, 12, 0, 417, 0, 838, 917, 42, 115, 109, 1759, 0, 36, 17, 0, 1790, 0},
700+
{6, 483, 241, 1327, 17, 11, 480, 9, 880, 58, 4, 0, 61, 30, 16, 176, 9, 309, 26, 0, 0, 1882, 4, 281, 475, 783, 197, 0, 19, 15, 6, 243},
701+
{370, 1222, 0, 6, 108, 929, 2, 7, 157, 348, 149, 106, 2, 5, 25, 33, 1569, 8, 6, 106, 69, 1298, 0, 2, 529, 520, 0, 421, 0, 25, 26, 0},
702+
{59, 89, 0, 26, 25, 40, 1873, 141, 527, 371, 262, 62, 16, 0, 127, 234, 1637, 64, 132, 8, 0, 7, 161, 1005, 22, 1, 49, 6, 83, 925, 80, 16},
703+
{269, 617, 30, 4, 90, 26, 0, 16, 154, 212, 21, 269, 379, 174, 129, 32, 8, 121, 344, 15, 0, 591, 1494, 6, 737, 50, 112, 856, 483, 25, 454, 330},
704+
{0, 98, 1488, 22, 73, 0, 0, 343, 77, 4, 0, 612, 165, 268, 4, 10, 43, 0, 598, 271, 2, 73, 185, 0, 112, 779, 24, 1626, 0, 0, 0, 1171},
705+
{0, 0, 0, 189, 266, 1743, 0, 462, 20, 7, 668, 310, 40, 0, 10, 236, 423, 18, 0, 0, 0, 999, 0, 139, 1754, 8, 619, 3, 23, 0, 102, 9},
706+
{131, 1753, 0, 113, 24, 94, 2, 12, 108, 0, 0, 252, 97, 0, 1319, 233, 93, 1254, 195, 152, 14, 413, 4, 2, 220, 67, 20, 4, 34, 559, 837, 42},
707+
{55, 76, 0, 8, 0, 3, 1557, 975, 135, 271, 4, 0, 0, 666, 207, 152, 5, 2, 97, 364, 0, 13, 1423, 771, 159, 31, 223, 0, 431, 7, 409, 4},
708+
{4, 1026, 1799, 166, 694, 753, 0, 16, 0, 240, 1119, 19, 6, 0, 46, 659, 10, 0, 112, 808, 181, 0, 28, 22, 90, 0, 176, 0, 37, 5, 10, 22},
709+
{44, 0, 4, 153, 299, 1357, 6, 23, 0, 12, 4, 419, 73, 24, 16, 24, 1, 4, 4, 102, 16, 4, 0, 1953, 1850, 0, 908, 4, 0, 13, 708, 23},
710+
{6, 13, 123, 28, 197, 0, 202, 69, 0, 6, 0, 21, 1434, 1582, 11, 0, 6, 0, 7, 190, 4, 1700, 6, 434, 1886, 0, 14, 28, 8, 30, 25, 18},
711+
{5, 27, 1442, 18, 0, 6, 0, 73, 6, 781, 0, 1915, 291, 649, 98, 4, 33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9}
712+
};
713+
714+
for (int i = 0; i < num_layers; i++) {
715+
launcher(total_rows_for_each_expert[i], 5760, 2880, num_experts);
716+
launcher(total_rows_for_each_expert[i], 2880, 2880, num_experts);
717+
}
718+
675719
return 0;
676720
}

0 commit comments

Comments
 (0)