Skip to content

Commit dad2193

Browse files
committed
Eliminate all H2D & D2H copies
1 parent 5b6f880 commit dad2193

File tree

3 files changed

+49
-122
lines changed

3 files changed

+49
-122
lines changed

examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp

Lines changed: 45 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,8 @@ template <class Gemm> struct ExampleRunner {
280280
return passed;
281281
}
282282

283-
/// Allocates device-side data
284-
void allocate(const GroupGEMMOptions &options, const ElementA *block_A_ptr,
283+
/// Allocates device-side data for reference GEMM
284+
void allocate_for_ref_gemm(const GroupGEMMOptions &options, const ElementA *block_A_ptr,
285285
const ElementA *block_B_ptr, ElementOutput*block_C_ptr,
286286
int block_A_size, int block_B_size, int block_C_size) {
287287
int64_t total_elements_A = 0;
@@ -338,85 +338,10 @@ template <class Gemm> struct ExampleRunner {
338338
cumsum_device.copy_from_host(cumsum_host);
339339
}
340340

341-
/// Initialize operands to be used in the GEMM and reference GEMM
342-
void initialize_for_moe_gemm(const GroupGEMMOptions &options) {
341+
/// Initialize operands to be used in the reference GEMM
342+
void initialize(const GroupGEMMOptions &options) {
343343

344-
problem_sizes.reset(options.groups);
345-
problem_sizes.copy_from_host(options.problem_sizes_host.data());
346-
347-
//
348-
// Assign pointers
349-
//
350-
351-
std::vector<ElementA *> ptr_A_host(1);
352-
std::vector<ElementB *> ptr_B_host(1);
353-
std::vector<ElementC *> ptr_C_host(1);
354-
std::vector<ElementOutput *> ptr_D_host(1);
355-
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
356-
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
357-
358-
// Compute offsets, alpha & beta over group on host
359-
360-
ptr_A_host.at(0) = block_A.get();
361-
ptr_B_host.at(0) = block_B.get();
362-
ptr_C_host.at(0) = block_C.get();
363-
ptr_D_host.at(0) = block_D.get();
364-
for (int32_t i = 0; i < options.groups; ++i) {
365-
// Fill host vector of alpha & beta with random values if using per-group
366-
// values
367-
alpha_host.push_back(
368-
(options.alpha == FLT_MAX)
369-
? static_cast<ElementAccumulator>((rand() % 5) + 1)
370-
: options.alpha);
371-
beta_host.push_back((options.beta == FLT_MAX)
372-
? static_cast<ElementAccumulator>(rand() % 5)
373-
: options.beta);
374-
// Fill host ptr vectors with offset addresses into device alpha/beta
375-
// blocks
376-
ptr_alpha_host.at(i) = block_alpha.get() + i;
377-
ptr_beta_host.at(i) = block_beta.get() + i;
378-
}
379-
380-
// Allocate device memory & copy from host
381-
ptr_A.reset(1);
382-
// Per-group alpha and beta
383-
ptr_A.copy_from_host(ptr_A_host.data());
384-
385-
ptr_B.reset(1);
386-
ptr_B.copy_from_host(ptr_B_host.data());
387-
388-
ptr_C.reset(1);
389-
ptr_C.copy_from_host(ptr_C_host.data());
390-
391-
ptr_D.reset(1);
392-
ptr_D.copy_from_host(ptr_D_host.data());
393-
394-
stride_A.reset(options.groups);
395-
stride_A.copy_from_host(stride_A_host.data());
396-
397-
stride_B.reset(options.groups);
398-
stride_B.copy_from_host(stride_B_host.data());
399-
400-
stride_C.reset(options.groups);
401-
stride_C.copy_from_host(stride_C_host.data());
402-
403-
stride_D.reset(options.groups);
404-
stride_D.copy_from_host(stride_D_host.data());
405-
406-
// Per-group alpha and beta ptrs
407-
alpha_device.reset(options.groups);
408-
alpha_device.copy_from_host(ptr_alpha_host.data());
409-
beta_device.reset(options.groups);
410-
beta_device.copy_from_host(ptr_beta_host.data());
411-
412-
// Per-group alpha and beta values - note these are not directly passed to
413-
// kernel - the pointers (alpha_device/beta_device) are passed instead
414-
block_alpha.copy_from_host(alpha_host.data());
415-
block_beta.copy_from_host(beta_host.data());
416-
}
417-
418-
/// Initialize operands to be used in the GEMM and reference GEMM
419-
void initialize_for_ref_gemm(const GroupGEMMOptions &options) {
344+
uint64_t seed = 2020;
420345

421346
problem_sizes.reset(options.groups);
422347
problem_sizes.copy_from_host(options.problem_sizes_host.data());
@@ -438,8 +363,10 @@ template <class Gemm> struct ExampleRunner {
438363
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
439364
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
440365
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
441-
// Fill host ptr vectors with offset addresses into device alpha/beta
442-
// blocks
366+
// Fill host vector of alpha & beta with random values if using per-group values
367+
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
368+
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
369+
// Fill host ptr vectors with offset addresses into device alpha/beta blocks
443370
ptr_alpha_host.at(i) = block_alpha.get() + i;
444371
ptr_beta_host.at(i) = block_beta.get() + i;
445372
}
@@ -475,9 +402,8 @@ template <class Gemm> struct ExampleRunner {
475402
alpha_device.copy_from_host(ptr_alpha_host.data());
476403
beta_device.reset(options.groups);
477404
beta_device.copy_from_host(ptr_beta_host.data());
478-
479-
// Per-group alpha and beta values - note these are not directly passed to
480-
// kernel - the pointers (alpha_device/beta_device) are passed instead
405+
// Per-group alpha and beta values - note these are not directly passed to kernel - the pointers
406+
// (alpha_device/beta_device) are passed instead
481407
block_alpha.copy_from_host(alpha_host.data());
482408
block_beta.copy_from_host(beta_host.data());
483409
}
@@ -486,6 +412,9 @@ template <class Gemm> struct ExampleRunner {
486412
typename Gemm::Arguments
487413
args_from_options(const GroupGEMMOptions &options,
488414
const cutlass::KernelHardwareInfo &hw_info,
415+
const ElementA* A_ptr,
416+
const ElementB* B_ptr,
417+
ElementOutput* D_ptr,
489418
const int gemm_N,
490419
const int gemm_K) {
491420
typename Gemm::Arguments arguments;
@@ -494,8 +423,8 @@ template <class Gemm> struct ExampleRunner {
494423
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
495424
// If both alpha/beta are provided (via cmd line args) and are scalar,
496425
// i.e., same alpha/beta applies to all batches.
497-
fusion_args.alpha = options.alpha;
498-
fusion_args.beta = options.beta;
426+
fusion_args.alpha = 1;
427+
fusion_args.beta = 0;
499428
fusion_args.alpha_ptr = nullptr;
500429
fusion_args.beta_ptr = nullptr;
501430
fusion_args.alpha_ptr_array = nullptr;
@@ -506,12 +435,12 @@ template <class Gemm> struct ExampleRunner {
506435
} else {
507436
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ
508437
// between batches/groups.
509-
fusion_args.alpha = 0;
438+
fusion_args.alpha = 1;
510439
fusion_args.beta = 0;
511440
fusion_args.alpha_ptr = nullptr;
512441
fusion_args.beta_ptr = nullptr;
513-
fusion_args.alpha_ptr_array = alpha_device.get();
514-
fusion_args.beta_ptr_array = beta_device.get();
442+
fusion_args.alpha_ptr_array = nullptr;
443+
fusion_args.beta_ptr_array = nullptr;
515444
// One alpha and beta per each group
516445
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
517446
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
@@ -523,11 +452,11 @@ template <class Gemm> struct ExampleRunner {
523452
// Per-GEMM problem shape info may only exist on the device.
524453
if (host_problem_shapes_available) {
525454
arguments = typename Gemm::Arguments{
526-
cutlass::gemm::GemmUniversalMode::kGrouped,
527-
ptr_A.get(),
528-
ptr_B.get(),
529-
nullptr,
530-
ptr_D.get(),
455+
cutlass::gemm::GemmUniversalMode::kGrouped, // this just means grouped GEMM
456+
static_cast<const ElementA**>((void*)A_ptr),
457+
static_cast<const ElementB**>((void*)B_ptr),
458+
static_cast<const ElementC**>((void*)D_ptr), // we could also pass nullptr
459+
static_cast<ElementOutput**>((void*)D_ptr),
531460
fusion_args,
532461
hw_info,
533462
{1, RasterOrderOptions::AlongN},
@@ -538,10 +467,10 @@ template <class Gemm> struct ExampleRunner {
538467
} else {
539468
arguments = typename Gemm::Arguments{
540469
cutlass::gemm::GemmUniversalMode::kGrouped,
541-
ptr_A.get(),
542-
ptr_B.get(),
543-
nullptr,
544-
ptr_D.get(),
470+
static_cast<const ElementA**>((void*)A_ptr),
471+
static_cast<const ElementB**>((void*)B_ptr),
472+
static_cast<const ElementC**>((void*)D_ptr),
473+
static_cast<ElementOutput**>((void*)D_ptr),
545474
fusion_args,
546475
hw_info,
547476
{1, RasterOrderOptions::AlongN},
@@ -557,12 +486,11 @@ template <class Gemm> struct ExampleRunner {
557486
cutlass::Status run(const GroupGEMMOptions &options,
558487
const cutlass::KernelHardwareInfo &hw_info,
559488
const ElementA *A_ptr, const ElementB *B_ptr,
560-
ElementOutput *C_ptr, int A_size, int B_size, int D_size, const int gemm_n, const int gemm_k) {
561-
allocate(options, A_ptr, B_ptr, C_ptr, A_size, B_size, D_size);
562-
initialize_for_moe_gemm(options);
489+
ElementOutput *D_ptr, int A_size, int B_size, int D_size, const int gemm_n, const int gemm_k) {
490+
allocate_for_ref_gemm(options, A_ptr, B_ptr, D_ptr, A_size, B_size, D_size);
563491

564492
Gemm gemm_op;
565-
auto arguments = args_from_options(options, hw_info, gemm_n, gemm_k);
493+
auto arguments = args_from_options(options, hw_info, A_ptr, B_ptr, D_ptr, gemm_n, gemm_k);
566494

567495
size_t workspace_size = Gemm::get_workspace_size(arguments);
568496
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
@@ -575,15 +503,14 @@ template <class Gemm> struct ExampleRunner {
575503
CUTLASS_CHECK(gemm_op.run());
576504

577505
syclcompat::wait();
578-
initialize_for_ref_gemm(options);
506+
initialize(options);
579507
// Verify that the result is correct
580508
bool passed = verify(options);
581509
std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
582510
if (!passed)
583511
return cutlass::Status::kErrorInternal;
584-
initialize_for_moe_gemm(options);
585512
syclcompat::wait();
586-
arguments = args_from_options(options, hw_info, gemm_n, gemm_k);
513+
arguments = args_from_options(options, hw_info, A_ptr, B_ptr, D_ptr, gemm_n, gemm_k);
587514
CUTLASS_CHECK(gemm_op.can_implement(arguments));
588515

589516
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
@@ -647,20 +574,15 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights,
647574
hw_info.device_id);
648575

649576
using LayoutA = cutlass::layout::RowMajor;
650-
using LayoutB = cutlass::layout::ColumnMajor;
577+
using LayoutB = cutlass::layout::RowMajor;
651578
using LayoutC = cutlass::layout::RowMajor;
652579
using LayoutD = cutlass::layout::RowMajor;
653580

654581
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
655-
using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;
582+
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
656583

657584
// Workgroup-level tile
658585
using TileShape = Shape<_256, _256, _32>;
659-
/*
660-
using TiledMma =
661-
TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
662-
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>;
663-
*/
664586

665587
using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
666588
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
@@ -710,18 +632,23 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights,
710632

711633

712634
int main(int argc, const char **argv) {
713-
const int num_experts = 32;
635+
const int num_experts = 16;
714636

715-
int total_rows_for_each_expert[num_experts] = {
637+
/* int total_rows_for_each_expert[num_experts] = {
716638
148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, 845, 191, 424, 30, 97, 57, 324,
717-
62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35};
639+
62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35}; */
640+
641+
int total_rows_for_each_expert[num_experts];
642+
for (int i = 0; i < num_experts; i++) {
643+
total_rows_for_each_expert[i] = 512;
644+
}
718645

719646
int num_tokens_incl_duplicated = 0;
720647
for (int i = 0; i < num_experts; i++) {
721648
num_tokens_incl_duplicated += total_rows_for_each_expert[i];
722649
}
723-
int n_moe = 3072;
724-
int k_moe = 4096;
650+
int n_moe = 16384;
651+
int k_moe = 5120;
725652

726653
cutlass::DeviceAllocation<int32_t> num_rows_per_expert_device;
727654
cutlass::DeviceAllocation<bfloat16_t> activations_data;

include/cutlass/epilogue/collective/xe_array_epilogue.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ template <typename ProblemShape_MNKL>
494494
TensorD mD_mnl;
495495
if constexpr (is_source_supported) {
496496
ElementC const *ptr_C_curr_batch =
497-
reinterpret_cast<ElementC const *>(params.ptr_C[0]) +
497+
reinterpret_cast<ElementC const *>((void*)(params.ptr_C)) +
498498
cumulative_M * N;
499499
mC_mnl = make_tensor(
500500
make_gmem_ptr(ptr_C_curr_batch),
@@ -504,7 +504,7 @@ template <typename ProblemShape_MNKL>
504504

505505
if constexpr (is_destination_supported) {
506506
ElementD *ptr_D_curr_batch =
507-
reinterpret_cast<ElementD *>(params.ptr_D[0]) +
507+
reinterpret_cast<ElementD *>((void*)(params.ptr_D)) +
508508
cumulative_M * N;
509509
mD_mnl = make_tensor(
510510
make_gmem_ptr(ptr_D_curr_batch),

include/cutlass/gemm/collective/xe_array_mma.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,10 +311,10 @@ template <typename ProblemShape_MNKL>
311311
const int32_t K = get<2>(problem_shape_mnkl);
312312

313313
ElementA const *ptr_A_curr_batch =
314-
reinterpret_cast<ElementA const *>(mainloop_params.ptr_A[0]) +
314+
reinterpret_cast<ElementA const *>((void*)(mainloop_params.ptr_A)) +
315315
cumulative_M * K;
316316
ElementB const *ptr_B_curr_batch =
317-
reinterpret_cast<ElementB const *>(mainloop_params.ptr_B[0]) +
317+
reinterpret_cast<ElementB const *>((void*)(mainloop_params.ptr_B)) +
318318
next_group * K * N;
319319

320320
Tensor mA = make_tensor(

0 commit comments

Comments (0)