Commit 6c0cee7

[skip ci] Trimming almost done
1 parent 177c85c commit 6c0cee7

File tree

5 files changed: +1247 -0 lines changed

examples/cute/tutorial/CMakeLists.txt

Lines changed: 5 additions & 0 deletions

@@ -57,6 +57,11 @@ if (CUTLASS_ENABLE_SYCL)
     bgemm_bmg_legacy.cpp
 )
 
+cutlass_example_add_executable(
+  cute_tutorial_moe_gemm
+  moe/moe_example.cpp
+)
+
 cutlass_example_add_executable(
   cute_tutorial_xe_gemm
   xe_gemm.cpp
examples/cute/tutorial/moe/moe_example.cpp

Lines changed: 375 additions & 0 deletions

@@ -0,0 +1,375 @@
/***************************************************************************************************
 * Copyright (c) 2025 Intel Corporation. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief CUTLASS Intel BMG MoE API example based on sycl-tla Group GEMM
*/

#include <cfloat>
#include <random>

#include <sycl/ext/intel/experimental/grf_size_properties.hpp>
#include <sycl/sycl.hpp>

#include <cute/tensor.hpp>
#include <cute/util/compat.hpp>

#include "cutlass/kernel_hardware_info.h"
#include "cutlass/platform/platform.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/initialize_block.hpp"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/sycl_event_manager.hpp"

#include "../../../common/sycl_cute_common.hpp"
#include "moe_grouped_gemm.hpp"
#include "moe_tile_scheduler.hpp"

#pragma clang diagnostic ignored "-Wpass-failed"
#pragma clang diagnostic ignored "-Wdeprecated-declarations"

using namespace cute;

using ElementAccumulator = float;     // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
using ElementA = bfloat16_t;      // <- data type of elements in input matrix A
using ElementB = bfloat16_t;      // <- data type of elements in input matrix B
using ElementOutput = bfloat16_t; // <- data type of elements in output matrix D

///////////////////////////////////////////////////////////////////////////////////////////////////

#define CUTLASS_SYCL_PROFILING_ENABLED

// Problem configuration for the grouped MoE GEMM. Unlike other tutorial
// examples, these options are populated from MoE routing data via parse()
// rather than from the command line.
struct GroupGEMMOptions {

  bool error = false;
  bool help = false;

  float alpha = 1.f;
  float beta = 0.f;
  int iterations;
  int m = 0, n = 0, k = 0, groups;
  int *num_rows_per_expert = nullptr;
  std::vector<typename MoE::ProblemShape::UnderlyingProblemShape>
      problem_sizes_host;

  GroupGEMMOptions()
      : error(false), help(false), alpha(1.f), beta(0.f), iterations(100) {}

  void parse(const int num_experts, const int *num_tokens_per_expert_host,
             int moe_n, int moe_k,
             const int *num_tokens_per_expert_device = nullptr) {
    n = moe_n;
    k = moe_k;
    groups = num_experts;
    iterations = 2;
    num_rows_per_expert = const_cast<int *>(num_tokens_per_expert_device);
    assert(groups > 0);
    problem_sizes_host.clear();
    problem_sizes_host.reserve(groups);
    for (int i = 0; i < groups; i++) {
      problem_sizes_host.push_back({num_tokens_per_expert_host[i], n, k});
    }
  }

  /// Compute achieved GFLOP/s, achieved bandwidth (GB/s), and a roofline
  /// projection of the runtime (ms) over all expert GEMMs.
  std::tuple<double, double, double>
  gflops(double runtime_s,
         std::vector<typename MoE::ProblemShape::UnderlyingProblemShape>
             problem_sizes_host) const {
    // Number of real-valued multiply-adds
    uint64_t fmas = uint64_t();
    uint64_t bytes_loaded = 0;

    for (auto const &problem : problem_sizes_host) {
      auto M = static_cast<uint64_t>(get<0>(problem));
      auto N = static_cast<uint64_t>(get<1>(problem));
      auto K = static_cast<uint64_t>(get<2>(problem));
      fmas += M * N * K;
      bytes_loaded +=
          /* sizeof(cutlass::bfloat16_t) */ 2 * (2 * M * N + N * K + M * K);
    }
    // Two flops per multiply-add
    uint64_t flop = uint64_t(2) * uint64_t(fmas);
    double gflop = double(flop) / double(1.0e9);
    double arithmetic_intensity = double(flop) / double(bytes_loaded);
    // Roofline model: peak memory bandwidth (GB/s) and peak compute (flop/s).
    double peak_mem_bw = 456.0;
    double gflops_attainable = std::min<double>(
        117 * double(1.0e12),
        arithmetic_intensity * (peak_mem_bw * 1024 * 1024 * 1024));
    double projected_time = flop / gflops_attainable;
    return std::make_tuple(gflop / runtime_s,
                           double(bytes_loaded) / 1024 / 1024 / 1024 /
                               runtime_s,
                           projected_time * 1000);
  }
};
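
To make the projection concrete, here is the roofline arithmetic for a single expert GEMM, using a hypothetical M = 1024 with this example's N = 5760 and K = 2880 (456 GB/s and 117 TFLOP/s are the peak-bandwidth and peak-compute figures hard-coded in gflops() above):

// Worked roofline example (a sketch; M = 1024 is a hypothetical value):
//   flop  = 2 * M*N*K               = 3.40e10
//   bytes = 2 * (2*M*N + N*K + M*K) = 6.27e7
//   AI    = flop / bytes            ~ 542 flop/byte
//   542 flop/byte * (456 GiB/s ~ 4.90e11 B/s) ~ 2.65e14 flop/s > 1.17e14,
//   so this shape is compute-bound:
//   projected_time ~ 3.40e10 / 1.17e14 ~ 0.29 ms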

///////////////////////////////////////////////////////////////////////////////////////////////////

// Allocate an r x c device tensor: column-major if LayoutKind == 'C',
// row-major otherwise.
template <typename T, char LayoutKind>
auto make_device_tensor(sycl::queue &Q, int r, int c) {
  using namespace cute;
  auto ptr = make_gmem_ptr(sycl::malloc_device<T>(r * c, Q));
  auto shape = make_shape(r, c);
  if constexpr (LayoutKind == 'C')
    return make_tensor(ptr, make_layout(shape, make_stride(_1{}, r)));
  else
    return make_tensor(ptr, make_layout(shape, make_stride(c, _1{})));
}
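
A usage sketch (hypothetical 128x64 / 64x128 shapes; this helper is not called elsewhere in the example, and the device buffers are not freed here):

// Hypothetical usage of make_device_tensor; layouts printed for illustration.
inline void make_device_tensor_usage_sketch() {
  sycl::queue Q = compat::get_default_queue();
  auto A = make_device_tensor<bfloat16_t, 'R'>(Q, 128, 64); // strides (64, _1)
  auto B = make_device_tensor<bfloat16_t, 'C'>(Q, 64, 128); // strides (_1, 64)
  print(A.layout()); // (128,64):(64,_1)
  print(B.layout()); // (64,128):(_1,64)
}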

// Compile-time check for whether T is a complete type.
template <typename T, size_t = 0> struct is_complete : std::false_type {};

template <typename T> struct is_complete<T, 0 * sizeof(T)> : std::true_type {};

template <typename T>
static constexpr bool is_complete_v = is_complete<T>::value;
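
The trick: sizeof(T) is ill-formed for an incomplete type, so the partial specialization silently drops out of consideration and the primary false_type template is selected. A minimal illustration with a hypothetical forward-declared type:

struct OnlyDeclared; // hypothetical: declared but never defined
static_assert(!is_complete_v<OnlyDeclared>, "incomplete -> primary template");
static_assert(is_complete_v<int>, "complete -> specialization");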

// Some of this code was authored by Peter Caday.
template <typename TA, typename TB, typename TC> auto choose_mma_op() {
  if constexpr (is_complete_v<XE_DPAS_TT<8, TC, TA, TB>>)
    return XE_DPAS_TT<8, TC, TA, TB>{};
  else if constexpr (is_same_v<TA, cute::bfloat16_t>)
    return XE_DPAS_TT<8, float, cute::bfloat16_t>{};
  else /* Use f16 by default, as upconversion sequences are typically faster */
    return XE_DPAS_TT<8, float, cute::half_t>{};
}

template <class TA, class TB, class TD>
auto choose_tiled_mma(TA *A, TB *B, TD *D) {

  auto op = choose_mma_op<TA, TB, TD>();

  using WGTile = Shape<_256, _256, _32>; // 256x256 WG tile size
  using SGLayout =
      Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major

  using MMA = typename TiledMMAHelper<MMA_Atom<decltype(op)>, Layout<WGTile>,
                                      SGLayout>::TiledMMA;

  return MMA{};
}
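
A quick sanity check of the tiling arithmetic (a sketch; the 16-lane subgroup width comes from the sub_group_size<16> property requested at the launch site below):

// 256 x 256 x 32 work-group tile split over an 8 x 4 grid of subgroups:
//   per-subgroup output tile: 256/8 x 256/4 = 32 x 64
//   subgroups per work-group: 8 * 4 = 32
//   work-items per work-group: 32 * 16 = 512,
//   which should match size(mma) in MoEGEMMLauncher below.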

// type tag to define a unique sycl kernel name
template <class, class, class, char, char> class GemmCuteName;

template <char layoutA, char layoutB, class ElementA, class ElementB,
          class ElementS, class ElementD>
void MoEGEMMLauncher(const ElementA *activations, const ElementB *weights,
                     const ElementS *scales, ElementD *outputs,
                     const int gemm_n, const int gemm_k,
                     const int *num_rows_per_expert_device,
                     const int num_experts) {
  // Change device_id to another value if you are running on a machine with
  // multiple GPUs and wish to use a GPU other than that with device ID 0.
  // For example, in a framework, you could query the device ID.
  int sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
  cutlass::KernelHardwareInfo hw_info{0, sm_count};
  // TODO: Add launch code as per the revised implementation
  // Placeholder single-group problem shape: the scheduler is initialized here
  // only to obtain its parameters and grid shape; the real per-expert sizes
  // are read on-device from num_rows_per_expert_device.
  auto dummy_problem_shape = cute::Shape<int, int, int>{1, gemm_k, gemm_n};
  auto dummy_group_problem_shape =
      cutlass::gemm::GroupProblemShape<Shape<int, int, int>>{
          1, &dummy_problem_shape, nullptr};
  using TileShape = Shape<_256, _256, _32>;
  using ClusterShape = Shape<_1, _1, _1>;
  auto scheduler_params = MoE::PersistentTileSchedulerXeMoE<MoE::ProblemShape>::
      to_underlying_arguments(
          dummy_group_problem_shape, TileShape{}, ClusterShape{}, hw_info,
          MoE::PersistentTileSchedulerXeMoE<MoE::ProblemShape>::Arguments{
              1, MoE::RasterOrderOptions::AlongN});
  auto group_distribution =
      MoE::PersistentTileSchedulerXeMoE<MoE::ProblemShape>::get_grid_shape(
          scheduler_params, dummy_group_problem_shape, TileShape{},
          ClusterShape{}, hw_info,
          MoE::PersistentTileSchedulerXeMoE<MoE::ProblemShape>::Arguments{
              1, MoE::RasterOrderOptions::AlongN});
  auto mma = choose_tiled_mma(activations, weights, outputs);
  auto MaxThreadsPerWorkgroup = size(mma);
  dim3 local_range{MaxThreadsPerWorkgroup, 1, 1};

  sycl::range<3> local = {local_range.x, local_range.y, local_range.z};
  sycl::range<3> groups = {group_distribution.x, group_distribution.y,
                           group_distribution.z};
  sycl::range<3> global = {local[0] * groups[0], local[1] * groups[1],
                           local[2] * groups[2]};

  namespace syclex = sycl::ext::oneapi::experimental;
  namespace intelex = sycl::ext::intel::experimental;

  syclex::properties kernel_props{syclex::sub_group_size<16>,
                                  intelex::grf_size<256>};
  sycl::queue Q = compat::get_default_queue();
  auto event = Q.parallel_for<
      GemmCuteName<ElementA, ElementB, ElementD, layoutA, layoutB>>(
      sycl::nd_range<3>(global, local), kernel_props, [=](auto) {
        MoE::MoEGEMM<XE_LOAD_2D<16, 32, 32, 16>,
                     XE_LOAD_2D_VNNI<16, 32, 16, 16>, XE_STORE_2D<16, 8, 16>,
                     'R', 'R', 'R'>(activations, weights, scales, outputs, mma,
                                    num_rows_per_expert_device, num_experts,
                                    gemm_n, gemm_k, scheduler_params);
      });
  EventManager::getInstance().addEvent(event);
  Q.wait_and_throw();
}
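
A brief recap of how the launch shape is assembled above (a sketch, not normative):

// local  = size(mma) work-items per work-group (512 for the tiled MMA above)
// groups = persistent work-group grid from get_grid_shape(), sized via
//          hw_info so each work-group loops over scheduler-issued tiles
// global = groups * local: SYCL's nd_range takes the *global* size, unlike
//          CUDA's grid-of-blocks convention.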

void launcher(int *M_per_expert, int N, int K, const int &num_experts) {
  int n_moe = N;
  int k_moe = K;
  int num_tokens_incl_duplicated = 0;
  for (int i = 0; i < num_experts; i++) {
    num_tokens_incl_duplicated += M_per_expert[i];
  }

  // M-occupancy: fraction of the 256-row M tiles that is actually filled
  // with tokens (1.0 means every tile is full).
  float M_occupancy = 0.f;
  float actual_num_units = 0.f;
  int total_num_M_tiles = 0;
  for (int i = 0; i < num_experts; i++) {
    total_num_M_tiles += (M_per_expert[i] + 255) / 256;
    actual_num_units += M_per_expert[i] / 256.0;
  }
  M_occupancy = actual_num_units / total_num_M_tiles;
  std::cout << "\n\n M-occupancy is " << M_occupancy << std::endl;
  cutlass::DeviceAllocation<int32_t> num_rows_per_expert_device;
  cutlass::DeviceAllocation<bfloat16_t> activations_data;
  cutlass::DeviceAllocation<bfloat16_t> weights_data;
  cutlass::DeviceAllocation<bfloat16_t> output_data;
  size_t A_size = num_tokens_incl_duplicated * k_moe;
  size_t B_size = num_experts * n_moe * k_moe;
  size_t D_size = num_tokens_incl_duplicated * n_moe;
  num_rows_per_expert_device.reset(num_experts);
  num_rows_per_expert_device.copy_from_host(M_per_expert);
  activations_data.reset(A_size);
  weights_data.reset(B_size);
  output_data.reset(D_size);
  uint64_t seed = 2023;
  initialize_block(activations_data, seed + 2023);
  initialize_block(weights_data, seed + 2022);
  initialize_block(output_data, seed + 2021);

  MoEGEMMLauncher<'R', 'R'>(activations_data.get(), weights_data.get(),
                            static_cast<void *>(nullptr), output_data.get(),
                            n_moe, k_moe, num_rows_per_expert_device.get(),
                            num_experts);

  activations_data.release();
  weights_data.release();
  output_data.release();
  num_rows_per_expert_device.release();
}
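
A worked instance of the M-occupancy metric, with a hypothetical row count:

// An expert with 300 routed rows occupies ceil(300/256) = 2 M tiles but
// only 300/256 ~= 1.17 tiles' worth of rows, so on its own it would score
// 1.17 / 2 ~= 0.59; a fully packed expert (a multiple of 256 rows) scores 1.0.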

int main(int argc, const char **argv) {
  constexpr int num_experts = 32;
  constexpr int num_layers = 24;

  // Number of rows (tokens) routed to each of the 32 experts, one row of the
  // array per layer.
  int total_rows_for_each_expert[num_layers][num_experts] = {
      {148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289,
       845, 191, 424, 30, 97, 57, 324, 62, 77, 75, 144,
       250, 287, 629, 370, 161, 101, 215, 113, 224, 35},
      {666, 214, 448, 87, 4, 28, 48, 13, 74, 40, 546,
       397, 487, 350, 26, 95, 517, 487, 295, 58, 637, 97,
       139, 33, 126, 15, 352, 311, 995, 193, 135, 135},
      {1016, 30, 36, 452, 469, 473, 232, 0, 493, 14, 954,
       6, 4, 6, 279, 3, 94, 106, 96, 48, 49, 113,
       142, 169, 75, 99, 25, 220, 249, 289, 4, 1803},
      {350, 229, 703, 154, 8, 64, 80, 339, 2, 56, 5,
       312, 1005, 29, 9, 11, 23, 0, 23, 431, 48, 129,
       496, 476, 8, 1234, 7, 130, 34, 58, 41, 1554},
      {39, 10, 6, 2, 110, 1, 894, 8, 53, 0, 275,
       6, 506, 421, 700, 178, 0, 530, 1623, 15, 231, 74,
       6, 222, 1246, 116, 35, 20, 0, 6, 381, 334},
      {399, 5, 201, 6, 134, 93, 1748, 1, 51, 4, 38,
       336, 53, 88, 328, 724, 15, 388, 706, 52, 19, 55,
       52, 33, 623, 1, 222, 215, 69, 45, 308, 1036},
      {11, 8, 407, 571, 458, 275, 197, 211, 13, 564, 462,
       114, 15, 13, 132, 24, 514, 2, 71, 13, 694, 47,
       16, 203, 610, 40, 0, 1587, 66, 23, 196, 491},
      {0, 230, 116, 136, 315, 643, 6, 183, 37, 26, 960,
       1, 8, 258, 21, 1602, 213, 198, 6, 196, 455, 557,
       47, 282, 493, 18, 101, 11, 616, 45, 268, 0},
      {392, 305, 179, 14, 227, 98, 114, 39, 64, 1456, 465,
       0, 18, 372, 0, 0, 189, 257, 25, 290, 486, 0,
       12, 1534, 468, 4, 555, 35, 146, 0, 161, 143},
      {4, 107, 20, 125, 236, 898, 0, 0, 375, 2, 125,
       0, 0, 1429, 36, 195, 1660, 0, 127, 454, 73, 358,
       47, 79, 32, 20, 1465, 0, 0, 6, 109, 66},
      {19, 0, 0, 0, 2, 1638, 75, 135, 392, 2, 1494, 3, 23, 5, 4, 58,
       0, 0, 71, 1285, 8, 441, 0, 145, 209, 408, 450, 2, 824, 13, 326, 16},
      {4, 2, 14, 0, 30, 206, 41, 131, 0, 429, 16, 895, 35, 21, 44, 128,
       12, 0, 417, 0, 838, 917, 42, 115, 109, 1759, 0, 36, 17, 0, 1790, 0},
      {6, 483, 241, 1327, 17, 11, 480, 9, 880, 58, 4,
       0, 61, 30, 16, 176, 9, 309, 26, 0, 0, 1882,
       4, 281, 475, 783, 197, 0, 19, 15, 6, 243},
      {370, 1222, 0, 6, 108, 929, 2, 7, 157, 348, 149, 106, 2, 5, 25, 33,
       1569, 8, 6, 106, 69, 1298, 0, 2, 529, 520, 0, 421, 0, 25, 26, 0},
      {59, 89, 0, 26, 25, 40, 1873, 141, 527, 371, 262,
       62, 16, 0, 127, 234, 1637, 64, 132, 8, 0, 7,
       161, 1005, 22, 1, 49, 6, 83, 925, 80, 16},
      {269, 617, 30, 4, 90, 26, 0, 16, 154, 212, 21,
       269, 379, 174, 129, 32, 8, 121, 344, 15, 0, 591,
       1494, 6, 737, 50, 112, 856, 483, 25, 454, 330},
      {0, 98, 1488, 22, 73, 0, 0, 343, 77, 4, 0,
       612, 165, 268, 4, 10, 43, 0, 598, 271, 2, 73,
       185, 0, 112, 779, 24, 1626, 0, 0, 0, 1171},
      {0, 0, 0, 189, 266, 1743, 0, 462, 20, 7, 668, 310, 40, 0, 10, 236,
       423, 18, 0, 0, 0, 999, 0, 139, 1754, 8, 619, 3, 23, 0, 102, 9},
      {131, 1753, 0, 113, 24, 94, 2, 12, 108, 0, 0,
       252, 97, 0, 1319, 233, 93, 1254, 195, 152, 14, 413,
       4, 2, 220, 67, 20, 4, 34, 559, 837, 42},
      {55, 76, 0, 8, 0, 3, 1557, 975, 135, 271, 4, 0, 0, 666, 207, 152,
       5, 2, 97, 364, 0, 13, 1423, 771, 159, 31, 223, 0, 431, 7, 409, 4},
      {4, 1026, 1799, 166, 694, 753, 0, 16, 0, 240, 1119, 19, 6, 0, 46, 659,
       10, 0, 112, 808, 181, 0, 28, 22, 90, 0, 176, 0, 37, 5, 10, 22},
      {44, 0, 4, 153, 299, 1357, 6, 23, 0, 12, 4, 419, 73, 24, 16, 24,
       1, 4, 4, 102, 16, 4, 0, 1953, 1850, 0, 908, 4, 0, 13, 708, 23},
      {6, 13, 123, 28, 197, 0, 202, 69, 0, 6, 0, 21, 1434, 1582, 11, 0, 6,
       0, 7, 190, 4, 1700, 6, 434, 1886, 0, 14, 28, 8, 30, 25, 18},
      {5, 27, 1442, 18, 0, 6, 0, 73, 6, 781, 0, 1915, 291, 649, 98, 4,
       33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9}};

  for (int i = 0; i < num_layers; i++) {
    launcher(total_rows_for_each_expert[i], 5760, 2880, num_experts);
    launcher(total_rows_for_each_expert[i], 2880, 2880, num_experts);
  }

  return 0;
}
