diff --git a/examples/12_bmg_moe_gemm_cute_interface/12_bmg_moe_gemm_cute_interface.cpp b/examples/12_bmg_moe_gemm_cute_interface/12_bmg_moe_gemm_cute_interface.cpp new file mode 100644 index 0000000000..1632bf6b1e --- /dev/null +++ b/examples/12_bmg_moe_gemm_cute_interface/12_bmg_moe_gemm_cute_interface.cpp @@ -0,0 +1,424 @@ +/*************************************************************************************************** + * Copyright (c) 2025 Intel Corporation. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief CUTLASS Intel BMG MoE API example based on sycl-tla Group GEMM + +*/ + +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include +#include +#include + +#include + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/platform/platform.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/initialize_block.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/sycl_event_manager.hpp" + +#include "moe_grouped_gemm.hpp" +#include "moe_tile_scheduler.hpp" + +#pragma clang diagnostic ignored "-Wpass-failed" +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +using namespace cute; +using namespace MoE; + +using ElementAccumulator = float; // <- data type of accumulator + +struct VerificationHelper { + + bool error = false; + bool help = false; + + float alpha = 1.f; + float beta = 0.f; + int iterations; + int m = 0, n = 0, k = 0, groups; + int *num_rows_per_expert = nullptr; + std::vector + problem_sizes_host; + + VerificationHelper() + : error(false), help(false), alpha(1.f), beta(0.f), iterations(100) {} + + void parse(const int num_experts, const int *num_tokens_per_expert_host, + int moe_n, int moe_k, + const int *num_tokens_per_expert_device = nullptr) { + n = moe_n; + k = moe_k; + groups = num_experts; + iterations = 2; + num_rows_per_expert = const_cast(num_tokens_per_expert_device); + assert(groups > 0); + problem_sizes_host.clear(); + problem_sizes_host.reserve(groups); + for (int i = 0; i < groups; i++) { + problem_sizes_host.push_back({num_tokens_per_expert_host[i], n, k}); + m += num_tokens_per_expert_host[i]; + } + } + + /// Compute performance in GFLOP/s + std::tuple + gflops(double runtime_s, + std::vector + problem_sizes_host) const { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + uint64_t bytes_loaded = 0; + + for (auto const &problem : problem_sizes_host) { + auto M = static_cast(get<0>(problem)); + auto N = static_cast(get<1>(problem)); + auto K = static_cast(get<2>(problem)); + fmas += M * N * K; + bytes_loaded += + /* sizeof(cutlass::bfloat16_t) */ 2 * (2 * M * N + N * K + M * K); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + double arithmetic_intensity = double(flop) / double(bytes_loaded); + double peak_mwm_bw = 456.0; + double gflops_attainable = std::min( + 117 * double(1.0e12), + arithmetic_intensity * (peak_mwm_bw * 1024 * 1024 * 1024)); + double projected_time = flop / gflops_attainable; + return std::make_tuple(gflop / runtime_s, + double(bytes_loaded) / 1024 / 1024 / 1024 / + runtime_s, + projected_time * 1000); + } + + template && + is_any_of_v && + is_any_of_v>> + bool verify(const ElementA *activations, const ElementB *weights, + ElementD *outputs) { + cutlass::DeviceAllocation output_ref; + cutlass::DeviceAllocation unused_c_matrix; + output_ref.reset(m * n); + unused_c_matrix.reset(m * n); + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + bool passed = true; + // Verify against individual reference GEMMs + int cumulative_sum = 0; + for (int32_t i = 0; i < groups; ++i) { + auto problem = problem_sizes_host.at(i); + auto M = get<0>(problem); + 
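// Assumed packing (set up in the launcher below): expert i's rows are stored
+      // contiguously, so its A and D views start at row offset `cumulative_sum`
+      // and its B is the i-th K x N weight matrix.
+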
cutlass::TensorRef ref_A(activations + cumulative_sum * k, + LayoutA::packed({M, k})); + cutlass::TensorRef ref_B(weights + i * n * k, LayoutB::packed({k, n})); + cutlass::TensorRef ref_C(unused_c_matrix.get() + cumulative_sum * n, + LayoutC::packed({M, n})); + cutlass::TensorRef ref_D(output_ref.get() + cumulative_sum * n, + LayoutD::packed({M, n})); + + // + // Compute reference output + // + cutlass::reference::device::GemmComplex( + {M, n, k}, 1.0, ref_A, cutlass::ComplexTransform::kNone, ref_B, + cutlass::ComplexTransform::kNone, 0.0, ref_C, ref_D, + ElementAccumulator(0), + 1, // batch_count + M * k, // batch_stride_A + k * n, // batch_stride_B + M * n, // batch_stride_C + M * n // batch_stride_D + ); + + // Wait for kernel to finish + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or + // not + passed &= cutlass::reference::device::BlockCompareEqual( + output_ref.get() + cumulative_sum * n, outputs + cumulative_sum * n, + M * n); + if (!passed) { + break; + } + cumulative_sum += M; + } + return passed; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct is_complete : std::false_type {}; + +template struct is_complete : std::true_type {}; + +template +static constexpr bool is_complete_v = is_complete::value; + +template auto choose_tiled_mma(TA *A, TB *B) { + using TA_non_CV = cutlass::platform::remove_cv_t; + using TB_non_CV = cutlass::platform::remove_cv_t; + auto op = XE_DPAS_TT<8, float, TA_non_CV, TB_non_CV>{}; + + using WGTile = Shape<_256, _256, _32>; // 256x256 WG tile size + using SGLayout = + Layout, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major + + using MMA = typename TiledMMAHelper, Layout, + SGLayout>::TiledMMA; + + return MMA{}; +} + +// type tag to define a unique sycl kernel name +template class GemmCuteName; + +template +void MoEGEMMLauncher(const ElementA *activations, const ElementB *weights, + const ElementS *scales, ElementD *outputs, + const int gemm_n, const int gemm_k, + const int *num_rows_per_expert_device, + const int *num_tokens_per_expert_host, + const int num_experts) { + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. + // For example, in a framework, you could query device ID. + int sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + cutlass::KernelHardwareInfo hw_info{0, sm_count}; + auto dummy_problem_shape = cute::Shape{1, gemm_k, gemm_n}; + // The GroupedGEMM API requires creation of a vector of ProblemShape objects + // for each GEMM problem, which is used in the GroupedGEMM tile-scheduler. If + // there are 32 groups, then a vector of 32 `ProblemShape` objects is created. + // Since these would not be known at compile time for a framework, they would + // have to be created at run-time instead. However, for MoEGEMM, I just + // provide one dummy shape, and then the custom code in tile scheduler can + // derive the shape of each GEMM problem. 
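+  // Note: this placeholder shape mainly exists to satisfy the Params API; the
+  // per-expert problem sizes are read on the device from
+  // `num_rows_per_expert_device` by the custom tile scheduler
+  // (see moe_tile_scheduler.hpp).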
+ auto dummy_group_problem_shape = + cutlass::gemm::GroupProblemShape>{ + 1, &dummy_problem_shape, nullptr}; + using TileShape = Shape<_256, _256, _32>; + using ClusterShape = Shape<_1, _1, _1>; + auto scheduler_params = + PersistentTileSchedulerXeMoE::to_underlying_arguments( + dummy_group_problem_shape, TileShape{}, ClusterShape{}, hw_info, + PersistentTileSchedulerXeMoE::Arguments{ + 1, RasterOrderOptions::AlongN}); + auto group_distribution = + PersistentTileSchedulerXeMoE::get_grid_shape( + scheduler_params, dummy_group_problem_shape, TileShape{}, + ClusterShape{}, hw_info, + PersistentTileSchedulerXeMoE::Arguments{ + 1, RasterOrderOptions::AlongN}); + auto mma = choose_tiled_mma(activations, weights); + auto MaxThreadsPerWorkgroup = size(mma); + dim3 local_range{MaxThreadsPerWorkgroup, 1, 1}; + + sycl::range<3> local = {local_range.x, local_range.y, local_range.z}; + sycl::range<3> groups = {group_distribution.x, group_distribution.y, + group_distribution.z}; + sycl::range<3> global = {local[0] * groups[0], local[1] * groups[1], + local[2] * groups[2]}; + + namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; + + syclex::properties kernel_props{syclex::sub_group_size<16>, + intelex::grf_size<256>}; + sycl::queue Q = compat::get_default_queue(); + GPU_Clock timer; + timer.start(); + auto event = Q.parallel_for< + GemmCuteName>( + sycl::nd_range<3>(global, local), kernel_props, [=](auto) { + MoE::MoEGEMM, + XE_LOAD_2D_VNNI<16, 32, 16, 16>, XE_STORE_2D<16, 8, 32>, + 'R', 'R', 'R'>(activations, weights, scales, outputs, mma, + num_rows_per_expert_device, num_experts, + gemm_n, gemm_k, scheduler_params); + }); + EventManager::getInstance().addEvent(event); + Q.wait_and_throw(); + float cute_time = timer.seconds() * 1000; + double cute_average_time = double(cute_time); + + VerificationHelper helper; + helper.parse(num_experts, num_tokens_per_expert_host, gemm_n, gemm_k); + assert(helper.verify(activations, weights, outputs)); + + auto [gflops, mem_bw_util, projected_time] = + helper.gflops(cute_average_time / 1000.0, helper.problem_sizes_host); + + std::cout << " Problem Sizes" << std::endl; + for (int32_t i = 0; i < num_experts; ++i) { + std::cout << " " << num_tokens_per_expert_host[i] << std::endl; + } + std::cout << " Groups : " << num_experts << std::endl; + std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl; + std::cout << " GFLOPS : " << gflops << std::endl; + std::cout << " Memory BW utilization : " << mem_bw_util << " GBPs" + << std::endl; +} + +void launcher(int *M_per_expert, int N, int K, const int &num_experts) { + int n_moe = N; + int k_moe = K; + int num_tokens_incl_duplicated = 0; + for (int i = 0; i < num_experts; i++) { + num_tokens_incl_duplicated += M_per_expert[i]; + } + + float M_occupancy = 0.f; + float actual_num_units = 0.f; + int total_num_M_tiles = 0; + for (int i = 0; i < num_experts; i++) { + total_num_M_tiles += (M_per_expert[i] + 255) / 256; + actual_num_units += M_per_expert[i] / 256.0; + } + M_occupancy = actual_num_units / total_num_M_tiles; + std::cout << "\n\n M-occupancy is " << M_occupancy << std::endl; + cutlass::DeviceAllocation num_rows_per_expert_device; + cutlass::DeviceAllocation activations_data; + cutlass::DeviceAllocation weights_data; + cutlass::DeviceAllocation output_data; + size_t A_size = num_tokens_incl_duplicated * k_moe; + size_t B_size = num_experts * n_moe * k_moe; + size_t D_size = num_tokens_incl_duplicated * n_moe; + 
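// Buffer layout: A holds (sum_i M_i) x K activations, B packs one K x N weight
+  // matrix per expert back to back, and D holds (sum_i M_i) x N outputs; all are
+  // filled with random data purely for this benchmark/verification run.
+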
num_rows_per_expert_device.reset(num_experts); + num_rows_per_expert_device.copy_from_host(M_per_expert); + activations_data.reset(A_size); + weights_data.reset(B_size); + output_data.reset(D_size); + uint64_t seed = 2023; + initialize_block(activations_data, seed + 2023); + initialize_block(weights_data, seed + 2022); + initialize_block(output_data, seed + 2021); + + MoEGEMMLauncher<'R', 'R'>(activations_data.get(), weights_data.get(), + static_cast(nullptr), output_data.get(), + n_moe, k_moe, num_rows_per_expert_device.get(), + M_per_expert, num_experts); +} + +int main(int argc, const char **argv) { + constexpr int num_experts = 32; + constexpr int num_layers = 24; + + int total_rows_for_each_expert[num_layers][num_experts] = { + {148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, + 845, 191, 424, 30, 97, 57, 324, 62, 77, 75, 144, + 250, 287, 629, 370, 161, 101, 215, 113, 224, 35}, + {666, 214, 448, 87, 4, 28, 48, 13, 74, 40, 546, + 397, 487, 350, 26, 95, 517, 487, 295, 58, 637, 97, + 139, 33, 126, 15, 352, 311, 995, 193, 135, 135}, + {1016, 30, 36, 452, 469, 473, 232, 0, 493, 14, 954, + 6, 4, 6, 279, 3, 94, 106, 96, 48, 49, 113, + 142, 169, 75, 99, 25, 220, 249, 289, 4, 1803}, + {350, 229, 703, 154, 8, 64, 80, 339, 2, 56, 5, + 312, 1005, 29, 9, 11, 23, 0, 23, 431, 48, 129, + 496, 476, 8, 1234, 7, 130, 34, 58, 41, 1554}, + {39, 10, 6, 2, 110, 1, 894, 8, 53, 0, 275, + 6, 506, 421, 700, 178, 0, 530, 1623, 15, 231, 74, + 6, 222, 1246, 116, 35, 20, 0, 6, 381, 334}, + {399, 5, 201, 6, 134, 93, 1748, 1, 51, 4, 38, + 336, 53, 88, 328, 724, 15, 388, 706, 52, 19, 55, + 52, 33, 623, 1, 222, 215, 69, 45, 308, 1036}, + {11, 8, 407, 571, 458, 275, 197, 211, 13, 564, 462, + 114, 15, 13, 132, 24, 514, 2, 71, 13, 694, 47, + 16, 203, 610, 40, 0, 1587, 66, 23, 196, 491}, + {0, 230, 116, 136, 315, 643, 6, 183, 37, 26, 960, + 1, 8, 258, 21, 1602, 213, 198, 6, 196, 455, 557, + 47, 282, 493, 18, 101, 11, 616, 45, 268, 0}, + {392, 305, 179, 14, 227, 98, 114, 39, 64, 1456, 465, + 0, 18, 372, 0, 0, 189, 257, 25, 290, 486, 0, + 12, 1534, 468, 4, 555, 35, 146, 0, 161, 143}, + {4, 107, 20, 125, 236, 898, 0, 0, 375, 2, 125, + 0, 0, 1429, 36, 195, 1660, 0, 127, 454, 73, 358, + 47, 79, 32, 20, 1465, 0, 0, 6, 109, 66}, + {19, 0, 0, 0, 2, 1638, 75, 135, 392, 2, 1494, 3, 23, 5, 4, 58, + 0, 0, 71, 1285, 8, 441, 0, 145, 209, 408, 450, 2, 824, 13, 326, 16}, + {4, 2, 14, 0, 30, 206, 41, 131, 0, 429, 16, 895, 35, 21, 44, 128, + 12, 0, 417, 0, 838, 917, 42, 115, 109, 1759, 0, 36, 17, 0, 1790, 0}, + {6, 483, 241, 1327, 17, 11, 480, 9, 880, 58, 4, + 0, 61, 30, 16, 176, 9, 309, 26, 0, 0, 1882, + 4, 281, 475, 783, 197, 0, 19, 15, 6, 243}, + {370, 1222, 0, 6, 108, 929, 2, 7, 157, 348, 149, 106, 2, 5, 25, 33, + 1569, 8, 6, 106, 69, 1298, 0, 2, 529, 520, 0, 421, 0, 25, 26, 0}, + {59, 89, 0, 26, 25, 40, 1873, 141, 527, 371, 262, + 62, 16, 0, 127, 234, 1637, 64, 132, 8, 0, 7, + 161, 1005, 22, 1, 49, 6, 83, 925, 80, 16}, + {269, 617, 30, 4, 90, 26, 0, 16, 154, 212, 21, + 269, 379, 174, 129, 32, 8, 121, 344, 15, 0, 591, + 1494, 6, 737, 50, 112, 856, 483, 25, 454, 330}, + {0, 98, 1488, 22, 73, 0, 0, 343, 77, 4, 0, + 612, 165, 268, 4, 10, 43, 0, 598, 271, 2, 73, + 185, 0, 112, 779, 24, 1626, 0, 0, 0, 1171}, + {0, 0, 0, 189, 266, 1743, 0, 462, 20, 7, 668, 310, 40, 0, 10, 236, + 423, 18, 0, 0, 0, 999, 0, 139, 1754, 8, 619, 3, 23, 0, 102, 9}, + {131, 1753, 0, 113, 24, 94, 2, 12, 108, 0, 0, + 252, 97, 0, 1319, 233, 93, 1254, 195, 152, 14, 413, + 4, 2, 220, 67, 20, 4, 34, 559, 837, 42}, + {55, 76, 0, 8, 0, 3, 1557, 975, 135, 271, 4, 0, 0, 
666, 207, 152, + 5, 2, 97, 364, 0, 13, 1423, 771, 159, 31, 223, 0, 431, 7, 409, 4}, + {4, 1026, 1799, 166, 694, 753, 0, 16, 0, 240, 1119, 19, 6, 0, 46, 659, + 10, 0, 112, 808, 181, 0, 28, 22, 90, 0, 176, 0, 37, 5, 10, 22}, + {44, 0, 4, 153, 299, 1357, 6, 23, 0, 12, 4, 419, 73, 24, 16, 24, + 1, 4, 4, 102, 16, 4, 0, 1953, 1850, 0, 908, 4, 0, 13, 708, 23}, + {6, 13, 123, 28, 197, 0, 202, 69, 0, 6, 0, 21, 1434, 1582, 11, 0, 6, + 0, 7, 190, 4, 1700, 6, 434, 1886, 0, 14, 28, 8, 30, 25, 18}, + {5, 27, 1442, 18, 0, 6, 0, 73, 6, 781, 0, 1915, 291, 649, 98, 4, + 33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9}}; + + for (int i = 0; i < num_layers; i++) { + launcher(total_rows_for_each_expert[i], 5760, 2880, num_experts); + launcher(total_rows_for_each_expert[i], 2880, 2880, num_experts); + } + + return 0; +} diff --git a/examples/12_bmg_moe_gemm_cute_interface/CMakeLists.txt b/examples/12_bmg_moe_gemm_cute_interface/CMakeLists.txt new file mode 100644 index 0000000000..000e98a28e --- /dev/null +++ b/examples/12_bmg_moe_gemm_cute_interface/CMakeLists.txt @@ -0,0 +1,37 @@ +# Copyright (c) 2025 Intel Corporation. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +if(NOT "${DPCPP_HOST_COMPILER}" MATCHES "g\\+\\+") + cutlass_example_add_executable( + 12_bmg_moe_gemm_cute_interface + 12_bmg_moe_gemm_cute_interface.cpp + ) + if(NOT DPCPP_SYCL_TARGET STREQUAL "spir64") + target_link_options( 12_bmg_moe_gemm_cute_interface PRIVATE -Xs "-options \"-igc_opts 'VectorAliasBBThreshold=10000'\"" ) + endif() +endif() diff --git a/examples/12_bmg_moe_gemm_cute_interface/moe_gemms.hpp b/examples/12_bmg_moe_gemm_cute_interface/moe_gemms.hpp new file mode 100644 index 0000000000..7b93c60d37 --- /dev/null +++ b/examples/12_bmg_moe_gemm_cute_interface/moe_gemms.hpp @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include +#include + +#include + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/platform/platform.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/sycl_event_manager.hpp" + +#pragma clang diagnostic ignored "-Wpass-failed" +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +template struct is_16_bit_fp : std::false_type {}; + +template <> struct is_16_bit_fp : std::true_type {}; +template <> struct is_16_bit_fp : std::true_type {}; + +template +inline constexpr bool is_16_bit_fp_v = + is_16_bit_fp>>::value; + +static_assert(is_16_bit_fp_v); +static_assert(is_16_bit_fp_v); + +namespace MoE { + +using namespace cute; + +template < + class GmemTiledCopyA, class GmemTiledCopyB, class GmemTiledCopyD, + class ATensor, class BTensor, class DTensor, class TiledMMA, + class = std::enable_if_t && + is_16_bit_fp_v>> +CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K) + BTensor const &B, // (N,K) + DTensor &D, // (M,N) + Coord blk_coord, + TiledMMA const &mma) { + auto item = sycl::ext::oneapi::this_work_item::get_nd_item<2>(); + auto wg_m = get<0>(blk_coord); + auto wg_n = get<1>(blk_coord); + auto local_id = int(item.get_local_id(0)); + + Tensor cA = make_identity_tensor(A.shape()); // (M,K) + Tensor cB = make_identity_tensor(B.shape()); // (N,K) + Tensor cD = make_identity_tensor(D.shape()); // (M,N) + + auto wg_coord = make_coord(wg_m, wg_n, 0); + auto wg_tile = mma.tile_mnk(); + + Tensor gA = local_tile(cA, select<0, 2>(wg_tile), make_coord(wg_m, _)); + Tensor gB = local_tile(cB, select<1, 2>(wg_tile), make_coord(wg_n, _)); + Tensor gD = local_tile(cD, wg_tile, wg_coord, Step<_1, _1, X>{}); + + auto 
thr_mma = mma.get_slice(local_id); + + auto tiled_copy_a = get_block_2d_copy_A(mma, A); + auto tiled_copy_b = get_block_2d_copy_B(mma, B); + auto tiled_copy_d = make_block_2d_copy_CD(GmemTiledCopyD{}, mma, D); + + auto thr_copy_a = tiled_copy_a.get_slice(local_id); + auto thr_copy_b = tiled_copy_b.get_slice(local_id); + auto thr_copy_d = tiled_copy_d.get_slice(local_id); + + auto tCrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0)); + auto tCrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0)); + auto tCrD = thr_mma.partition_sg_fragment_C(gD); + auto tCrD_final = thr_copy_d.partition_sg_fragment_S(gD); + + auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_, _, 0)); + auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_, _, 0)); + + Tensor tAgA = thr_copy_a.partition_S(gA); + Tensor tBgB = thr_copy_b.partition_S(gB); + auto tCgD = thr_copy_d.partition_D(gD); + + auto prefetch_a = make_block_2d_prefetch(tiled_copy_a); + auto prefetch_b = make_block_2d_prefetch(tiled_copy_b); + + auto thr_prefetch_A = prefetch_a.get_slice(local_id); + auto thr_prefetch_B = prefetch_b.get_slice(local_id); + + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + + constexpr int barrier_scope = 2; + int k_start_idx = 0; + int prefetch_k = k_start_idx; + const int prefetch_dist = 3; + int k_tile_count = ceil_div(shape<1>(A), get<2>(wg_tile)); + + CUTE_UNROLL + for (; prefetch_k < prefetch_dist; prefetch_k++) { + prefetch(prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + for (int k_tile = k_start_idx; k_tile < k_tile_count; + k_tile++, prefetch_k++) { + barrier_arrive(barrier_scope); + + copy(tiled_copy_a, tAgA(_, _, _, k_tile), tArA); + copy(tiled_copy_b, tBgB(_, _, _, k_tile), tBrB); + + if (prefetch_k < k_tile_count) { + prefetch(prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + reorder(tArA, tCrA); + reorder(tBrB, tCrB); + + cute::gemm(mma, tCrA, tCrB, tCrD); + barrier_wait(barrier_scope); + } + reorder(tCrD, tCrD_final); + copy(tiled_copy_d, tCrD_final, tCgD); +} + +} // namespace MoE diff --git a/examples/12_bmg_moe_gemm_cute_interface/moe_grouped_gemm.hpp b/examples/12_bmg_moe_gemm_cute_interface/moe_grouped_gemm.hpp new file mode 100644 index 0000000000..d0128f58a8 --- /dev/null +++ b/examples/12_bmg_moe_gemm_cute_interface/moe_grouped_gemm.hpp @@ -0,0 +1,147 @@ +/*************************************************************************************************** + * Copyright 2025 Intel corporation. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/platform/platform.h" +#include "moe_gemms.hpp" +#include "moe_tile_scheduler.hpp" +#include + +#pragma clang diagnostic ignored "-Wpass-failed" +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +namespace MoE { +using namespace cute; + +using ProblemShapeMNKL = Shape; +using ProblemShape = cutlass::gemm::GroupProblemShape>; +using TileScheduler = typename MoE::PersistentTileSchedulerXeMoE; +using RasterOrderOptions = typename TileScheduler::RasterOrderOptions; + +template +CUTE_DEVICE auto make_moe_tensor(T *ptr, int r, int c) { + auto shape = make_shape(r, c); + if constexpr (LayoutKind == 'C') + return make_tensor(make_gmem_ptr(ptr), + make_layout(shape, make_stride(_1{}, r))); + else + return make_tensor(make_gmem_ptr(ptr), + make_layout(shape, make_stride(c, _1{}))); +} + +template +CUTE_DEVICE void +MoEGEMM(const ElementA *Activations, const ElementB *Weights, + const ElementS *Scales, ElementD *Outputs, TiledMMA const &mma, + const int32_t *M_per_group, const int32_t num_experts, const int32_t N, + const int32_t K, + PersistentTileSchedulerSm90GroupParams scheduler_params) { + + TileScheduler scheduler{scheduler_params}; + + // TODO : Modify tile-scheduler to not require configure + scheduler.configure(const_cast(M_per_group), N, K, num_experts); + auto work_tile_info = scheduler.initial_work_tile_info(Shape<_1, _1, _1>{}); + constexpr char actual_layout_of_B = LayoutKindB ^ ('R' ^ 'C'); + bool did_group_change = true; + int32_t curr_group = 0; + int32_t prev_group = 0; + int32_t cumulative_M = 0; + int32_t M = 0; + + if (work_tile_info.is_valid()) { + // We don't really need this conditional outside the while loop. + // It simply helps initialize tensors. If using nullptr would be + // fine for their initialization, then we can remove this conditional. 
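+    // L_idx reported by the scheduler is the expert (group) index for this tile.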
+ curr_group = work_tile_info.L_idx; + M = M_per_group[curr_group]; + } + + auto A_tensor = make_moe_tensor( + const_cast(Activations), M, K); + auto B_tensor = make_moe_tensor( + const_cast(Weights), N, K); + auto D_tensor = make_moe_tensor(Outputs, M, N); + + while (work_tile_info.is_valid()) { + auto m_coord = work_tile_info.M_idx; + auto n_coord = work_tile_info.N_idx; + auto tile_coord = make_coord(m_coord, n_coord, _, 0); + + if (did_group_change) { + curr_group = work_tile_info.L_idx; + M = M_per_group[curr_group]; + // recompute each time because the groups don't necessarily increment by 1 + for (int i = prev_group; i < curr_group; i++) { + cumulative_M += M_per_group[i]; + } + prev_group = curr_group; + + ElementA *ptr_A_curr_batch = + const_cast(Activations) + cumulative_M * K; + ElementB *ptr_B_curr_batch = + const_cast(Weights) + curr_group * K * N; + ElementD *ptr_D_curr_batch = Outputs + cumulative_M * N; + + A_tensor = make_moe_tensor(ptr_A_curr_batch, M, K); + B_tensor = + make_moe_tensor(ptr_B_curr_batch, N, K); + D_tensor = make_moe_tensor(ptr_D_curr_batch, M, N); + did_group_change = false; + } + + // After adding scaledMM mainloops, add something like + // if constexpr (!cute::is_void_v) { + // moe_gemm( + // A_tensor, B_tensor, Scales, D_tensor, tile_coord, mma); + // } else { + moe_gemm( + A_tensor, B_tensor, D_tensor, tile_coord, mma); + + // Get next work tile + auto [next_work_tile_info, temp] = + scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + did_group_change = curr_group != work_tile_info.L_idx; + } // end while loop +} + +} // namespace MoE diff --git a/examples/12_bmg_moe_gemm_cute_interface/moe_tile_scheduler.hpp b/examples/12_bmg_moe_gemm_cute_interface/moe_tile_scheduler.hpp new file mode 100644 index 0000000000..d2f486dc7f --- /dev/null +++ b/examples/12_bmg_moe_gemm_cute_interface/moe_tile_scheduler.hpp @@ -0,0 +1,347 @@ +/*************************************************************************************************** + * Copyright (c) 2025 Intel Corporation. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/gemm_coord.hpp" +#include "cutlass/kernel_hardware_info.hpp" + +namespace MoE { +using namespace cutlass::gemm::kernel::detail; +using namespace cutlass; +using namespace cutlass::gemm; +using namespace cute; +/////////////////////////////////////////////////////////////////////////////// +// Adapted from xe_tile_scheduler_group.hpp +// Persistent Thread Block (TB) scheduler for MoE GEMM +template class PersistentTileSchedulerXeMoE { + // + // Data members + // + +private: + uint64_t current_work_linear_idx_ = 0; + uint64_t total_grid_size_ = 0; + int32_t *num_rows_per_expert_ = nullptr; + int32_t K_ = 0; + int32_t N_ = 0; + int32_t num_experts_ = 0; + + // Tracking current group, its starting linear idx and total tiles + struct GroupInfo { + int group_idx = 0; + uint64_t start_linear_idx = 0; + uint64_t total_tiles = 0; + } current_group_info_; + +public: + struct WorkTileInfo { + int32_t M_idx = 0; + int32_t N_idx = 0; + int32_t L_idx = 0; + bool is_valid_tile = false; + + CUTLASS_HOST_DEVICE + bool is_valid() const { return is_valid_tile; } + + CUTLASS_HOST_DEVICE + static WorkTileInfo invalid_work_tile() { return {-1, -1, -1, false}; } + }; + + using ProblemShape = typename GroupProblemShape::UnderlyingProblemShape; + using Params = PersistentTileSchedulerSm90GroupParams; + using RasterOrder = typename Params::RasterOrder; + using RasterOrderOptions = typename Params::RasterOrderOptions; + + struct Arguments { + int max_swizzle_size = 1; + // Not applying Heuristics for Grouped problems, since largest dimension can + // change per group + RasterOrderOptions raster_order = RasterOrderOptions::AlongM; + }; + + // Sink scheduler params as a member + Params scheduler_params; + + // + // Methods + // + + CUTLASS_HOST_DEVICE void configure(int32_t *num_rows_per_expert, int32_t N, + int32_t K, int32_t num_experts) { + num_rows_per_expert_ = num_rows_per_expert; + N_ = N; + K_ = K; + num_experts_ = num_experts; + } + + // Given the inputs, computes the total number of output blocks this problem + // will compute over Note that this is only the logical size of our grid, not + // the physical grid we will actually launch. 
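+  // In this persistent variant the launched grid is just hw_info.sm_count
+  // work-groups; remaining tiles are covered by advance_to_next_work().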
+ template + CUTLASS_HOST_DEVICE static dim3 + get_tiled_cta_shape_mnl(const KernelHardwareInfo &hw_info, + ClusterShape cluster_shape) { + uint32_t total_ctas = 0; + uint32_t cta_in_N_dim = + 1; // We linearize the blocks across all the problems here + + total_ctas = hw_info.sm_count; + + return Params::get_tiled_cta_shape_mnl(to_gemm_coord(cluster_shape), + total_ctas, cta_in_N_dim); + } + + template + static Params to_underlying_arguments( + GroupProblemShape problem_shapes, TileShape tile_shape, + ClusterShape cluster_shape, KernelHardwareInfo const &hw_info, + Arguments const &arguments, [[maybe_unused]] void *workspace = nullptr, + [[maybe_unused]] const uint32_t epilogue_subtile = 1, + [[maybe_unused]] uint32_t ktile_start_alignment_count = 1u) { + + // We only need the tile and cluster shape during scheduler setup, so let + // FTAD do the magic + static_assert(cute::is_static::value); + static_assert(cute::is_static::value); + + dim3 problem_blocks = get_tiled_cta_shape_mnl(hw_info, cluster_shape); + + Params params; + params.initialize(problem_blocks, problem_shapes, to_gemm_coord(tile_shape), + to_gemm_coord(cluster_shape), hw_info, + arguments.max_swizzle_size, arguments.raster_order); + + return params; + } + + // Given the inputs, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE static dim3 + get_grid_shape([[maybe_unused]] Params const ¶ms, + GroupProblemShape problem_shapes, TileShape tile_shape, + ClusterShape cluster_shape, KernelHardwareInfo hw_info, + Arguments arguments, bool truncate_by_problem_size = true) { + + dim3 problem_blocks = get_tiled_cta_shape_mnl(hw_info, cluster_shape); + + return Params::get_grid_shape(problem_blocks, to_gemm_coord(cluster_shape), + hw_info, arguments.max_swizzle_size, + arguments.raster_order, + /* truncate_by_problem_size = */ true); + } + + PersistentTileSchedulerXeMoE() = default; + + CUTLASS_DEVICE explicit PersistentTileSchedulerXeMoE(Params const ¶ms_) + : scheduler_params(params_) { + // MSVC requires protecting use of CUDA-specific nonstandard syntax, + // like blockIdx and gridDim, with __CUDA_ARCH__. 
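+    // For SYCL device compilation the equivalent guard is __SYCL_DEVICE_ONLY__;
+    // BlockIdx*/GridDim* below are CUTLASS portability wrappers over the
+    // corresponding work-group queries.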
+#if defined(__CUDA_ARCH__) || defined __SYCL_DEVICE_ONLY__ + if (scheduler_params.raster_order_ == RasterOrder::AlongN) { + current_work_linear_idx_ = + uint64_t(BlockIdxX()) + uint64_t(BlockIdxY()) * uint64_t(GridDimX()); + } else { + current_work_linear_idx_ = + uint64_t(BlockIdxX()) * uint64_t(GridDimY()) + uint64_t(BlockIdxY()); + } + + total_grid_size_ = + uint64_t(GridDimX()) * uint64_t(GridDimY()) * uint64_t(GridDimZ()); + +#else + CUTLASS_ASSERT(false && "This line should never be reached"); +#endif + } + + CUTLASS_DEVICE + WorkTileInfo get_current_work() { + return get_current_work_for_linear_idx(current_work_linear_idx_); + } + + CUTLASS_DEVICE + WorkTileInfo get_current_work_for_linear_idx(uint64_t linear_idx) { + if (scheduler_params.pre_processed_problem_shapes && + linear_idx >= scheduler_params.blocks_across_problem_) { + return WorkTileInfo::invalid_work_tile(); + } + + return get_work_idx_m_and_n( + linear_idx, current_group_info_, scheduler_params.problem_shapes_, + scheduler_params.cta_shape_, scheduler_params.cluster_shape_, + scheduler_params.divmod_cluster_shape_major_, + scheduler_params.divmod_cluster_shape_minor_, + scheduler_params.divmod_cta_shape_m_, + scheduler_params.divmod_cta_shape_n_, + scheduler_params.log_swizzle_size_, scheduler_params.raster_order_); + } + + CUTLASS_DEVICE + void advance_to_next_work(uint32_t advance_count = 1) { + current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); + } + + // get work_idx_m, work_idx_n from linear_idx while applying swizzle + CUTLASS_DEVICE + WorkTileInfo + get_work_idx_m_and_n(uint64_t linear_idx, struct GroupInfo &group_info, + GroupProblemShape &problem_shapes, GemmCoord cta_shape, + cutlass::gemm::GemmCoord cluster_shape, + FastDivmodU64Pow2 const &divmod_cluster_shape_major, + FastDivmodU64Pow2 const &divmod_cluster_shape_minor, + FastDivmodU64 const &divmod_cta_shape_m, + FastDivmodU64 const &divmod_cta_shape_n, + int32_t log_swizzle_size, RasterOrder raster_order) { + + bool valid_tile = true; + uint64_t ctas_along_m, ctas_along_n; + int total_problem_groups = num_experts_; + ctas_along_m = divmod_cta_shape_m.divide( + cute::shape<0>( + ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + + divmod_cta_shape_m.divisor - 1); + ctas_along_n = divmod_cta_shape_n.divide( + cute::shape<1>( + ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + + divmod_cta_shape_n.divisor - 1); + + auto problem_blocks_m = + round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m()); + auto problem_blocks_n = + round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); + group_info.total_tiles = problem_blocks_m * problem_blocks_n; + + while (group_info.start_linear_idx + group_info.total_tiles <= linear_idx) { + group_info.group_idx++; + + if (group_info.group_idx >= total_problem_groups) + return WorkTileInfo::invalid_work_tile(); + + group_info.start_linear_idx += group_info.total_tiles; + ctas_along_m = divmod_cta_shape_m.divide( + cute::shape<0>(ProblemShape( + num_rows_per_expert_[group_info.group_idx], N_, K_)) + + divmod_cta_shape_m.divisor - 1); + ctas_along_n = divmod_cta_shape_n.divide( + cute::shape<1>(ProblemShape( + num_rows_per_expert_[group_info.group_idx], N_, K_)) + + divmod_cta_shape_n.divisor - 1); + + problem_blocks_m = + round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m()); + problem_blocks_n = + round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); + group_info.total_tiles = problem_blocks_m * problem_blocks_n; + } 
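+    // `group_info` now identifies the expert whose tile range contains linear_idx;
+    // the in-group remainder is decoded below into (M, N) tile coordinates using
+    // the usual swizzle / raster-order arithmetic.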
+ + uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0; + uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide( + linear_idx - group_info.start_linear_idx); + divmod_cluster_shape_major(cluster_id, cluster_major_offset, + blk_per_grid_dim); + + // With static schedulers, we launch grid such that all cluster are linear + // (1-D) order, i.e., there can only be one cluster in the minor dimension. + // get_grid_shape() in scheduler params put cluster_shape.m/n() as the minor + // dimension based on raster order AlongN/M resp. Therefore, the offset of a + // CTA (inside a cluster) in the minor dimension can be directly be inferred + // by the blockIdx along the minor dimension. + if (raster_order == RasterOrder::AlongN) { + cluster_minor_offset = BlockIdxX(); + } else { + cluster_minor_offset = BlockIdxY(); + } + + uint64_t cluster_idx_minor, cluster_idx_major; + + uint64_t cluster_idx_minor_div_swizzle, extra, offset; + + offset = cluster_id & ((1 << log_swizzle_size) - 1); + extra = cluster_id >> log_swizzle_size; + + uint64_t curr_group_cluster_blk_major; + if (raster_order == RasterOrder::AlongN) { + curr_group_cluster_blk_major = + divmod_cluster_shape_major.divide(problem_blocks_n); + } else { + curr_group_cluster_blk_major = + divmod_cluster_shape_major.divide(problem_blocks_m); + } + cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major; + cluster_idx_major = extra % curr_group_cluster_blk_major; + + cluster_idx_minor = + cluster_idx_minor_div_swizzle * (1 << log_swizzle_size) + offset; + + auto minor_work_idx = static_cast( + cluster_idx_minor * divmod_cluster_shape_minor.divisor + + cluster_minor_offset); + auto major_work_idx = static_cast( + cluster_idx_major * divmod_cluster_shape_major.divisor + + cluster_major_offset); + + if (raster_order == RasterOrder::AlongN) { + return {minor_work_idx, major_work_idx, group_info.group_idx, valid_tile}; + } else { + return {major_work_idx, minor_work_idx, group_info.group_idx, valid_tile}; + } + } + + // Returns whether the current WorkTileInfo passed in should continue to be + // used. Since this scheduler only schedules work in units of single, full + // output tiles, the WorkTileInfo passed in should not be used after having + // been processed. + CUTLASS_DEVICE + static bool continue_current_work(WorkTileInfo &) { return false; } + + // Kernel helper function to get next work tile + CUTLASS_DEVICE + auto fetch_next_work(WorkTileInfo work_tile_info) { + if (continue_current_work(work_tile_info)) { + return cute::make_tuple(work_tile_info, true); + } + + advance_to_next_work(); + return cute::make_tuple(get_current_work(), true); + } + + // Returns the initial work tile info that will be computed over + template + CUTLASS_DEVICE WorkTileInfo initial_work_tile_info(ClusterShape) { + return get_current_work(); + } +}; + +} // namespace MoE diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 6aaca98f03..8407747cd9 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -114,6 +114,7 @@ if(CUTLASS_ENABLE_SYCL) 09_bmg_grouped_gemm_f8 10_bmg_grouped_gemm_mixed_dtype 11_xe20_cutlass_library + 12_bmg_moe_gemm_cute_interface ) add_subdirectory(${EXAMPLE}) endforeach()