|
| 1 | +/*************************************************************************************************** |
| 2 | + * Copyright (c) 2025 Intel Corporation. All rights reserved. |
| 3 | + * SPDX-License-Identifier: BSD-3-Clause |
| 4 | + * |
| 5 | + * Redistribution and use in source and binary forms, with or without |
| 6 | + * modification, are permitted provided that the following conditions are met: |
| 7 | + * |
| 8 | + * 1. Redistributions of source code must retain the above copyright notice, |
 * this list of conditions and the following disclaimer.
| 10 | + * |
| 11 | + * 2. Redistributions in binary form must reproduce the above copyright notice, |
| 12 | + * this list of conditions and the following disclaimer in the documentation |
| 13 | + * and/or other materials provided with the distribution. |
| 14 | + * |
| 15 | + * 3. Neither the name of the copyright holder nor the names of its |
| 16 | + * contributors may be used to endorse or promote products derived from |
| 17 | + * this software without specific prior written permission. |
| 18 | + * |
| 19 | + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| 20 | + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 21 | + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
| 30 | + * |
| 31 | + **************************************************************************************************/ |
/*! \file
    \brief CUTLASS Intel BMG MoE API example based on sycl-tla Group GEMM
*/
| 36 | + |
| 37 | +#include "cutlass/util/GPU_Clock.hpp" |
| 38 | + |
| 39 | +#include <cute/tensor.hpp> |
| 40 | +#include <random> |
| 41 | + |
| 42 | +#include <cute/util/compat.hpp> |
| 43 | +#include <sycl/ext/intel/experimental/grf_size_properties.hpp> |
| 44 | +#include <sycl/sycl.hpp> |
| 45 | + |
| 46 | +#include <cute/tensor.hpp> |
| 47 | + |
| 48 | +#include "cutlass/kernel_hardware_info.h" |
| 49 | +#include "cutlass/platform/platform.h" |
| 50 | +#include "cutlass/tensor_ref.h" |
| 51 | +#include "cutlass/util/GPU_Clock.hpp" |
| 52 | +#include "cutlass/util/device_memory.h" |
| 53 | +#include "cutlass/util/initialize_block.hpp" |
| 54 | +#include "cutlass/util/reference/device/gemm_complex.h" |
| 55 | +#include "cutlass/util/reference/device/tensor_compare.h" |
| 56 | +#include "cutlass/util/reference/host/tensor_fill.h" |
| 57 | +#include "cutlass/util/sycl_event_manager.hpp" |
| 58 | + |
| 59 | +#include "../../../common/sycl_cute_common.hpp" |
| 60 | +#include "moe_grouped_gemm.hpp" |
| 61 | +#include "moe_tile_scheduler.hpp" |
| 62 | + |
| 63 | +#pragma clang diagnostic ignored "-Wpass-failed" |
| 64 | +#pragma clang diagnostic ignored "-Wdeprecated-declarations" |
| 65 | + |
| 66 | +#include <cfloat> |
| 67 | + |
using namespace cute;
using namespace MoE;

// Element types for the grouped GEMM: bf16 inputs/outputs with fp32
// accumulation and fp32 epilogue arithmetic.
using ElementAccumulator = float;     // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
using ElementA = bfloat16_t; // <- data type of elements in input matrix A
using ElementB = bfloat16_t; // <- data type of elements in input matrix B
using ElementOutput = bfloat16_t; // <- data type of elements in output matrix D

///////////////////////////////////////////////////////////////////////////////////////////////////

// NOTE(review): this macro is defined AFTER the CUTLASS headers above are
// included, so any header that keys off it will not see it — confirm that
// the placement is intentional.
#define CUTLASS_SYCL_PROFILING_ENABLED
| 80 | + |
// Command line options parsing
struct GroupGEMMOptions {

  bool error = false; // set when option parsing fails
  bool help = false;  // set when usage/help is requested

  float alpha = 1.f; // GEMM alpha scale
  float beta = 0.f;  // GEMM beta scale
  int iterations;    // number of timed iterations
  int m = 0, n = 0, k = 0, groups; // GEMM sizes; groups == number of experts
  // Non-owning device pointer to the per-expert row counts (see parse()).
  int *num_rows_per_expert = nullptr;
  // Host-side (M, N, K) problem shape for each expert/group.
  std::vector<typename MoE::ProblemShape::UnderlyingProblemShape>
      problem_sizes_host;

  GroupGEMMOptions()
      : error(false), help(false), alpha(1.f), beta(0.f), iterations(100) {}

  // Populates the options from an MoE routing description.
  //
  // \param num_experts                   number of experts (groups)
  // \param num_tokens_per_expert_host    host array: M (rows) per expert
  // \param moe_n, moe_k                  N and K shared by every expert
  // \param num_tokens_per_expert_device  optional device copy of the M array;
  //        stored as a non-owning pointer (const_cast — never written through
  //        here)
  void parse(const int num_experts, const int *num_tokens_per_expert_host,
             int moe_n, int moe_k,
             const int *num_tokens_per_expert_device = nullptr) {
    n = moe_n;
    k = moe_k;
    groups = num_experts;
    iterations = 2; // overrides the default of 100 for short runs
    num_rows_per_expert = const_cast<int *>(num_tokens_per_expert_device);
    assert(groups > 0);
    problem_sizes_host.clear();
    problem_sizes_host.reserve(groups);
    for (int i = 0; i < groups; i++) {
      problem_sizes_host.push_back({num_tokens_per_expert_host[i], n, k});
    }
  }

  /// Compute performance in GFLOP/s.
  /// Returns {achieved GFLOP/s, achieved GiB/s, roofline-projected time in ms}.
  std::tuple<double, double, double>
  gflops(double runtime_s,
         std::vector<typename MoE::ProblemShape::UnderlyingProblemShape>
             problem_sizes_host) const {
    // Number of real-valued multiply-adds
    uint64_t fmas = uint64_t();
    uint64_t bytes_loaded = 0;

    for (auto const &problem : problem_sizes_host) {
      auto M = static_cast<uint64_t>(get<0>(problem));
      auto N = static_cast<uint64_t>(get<1>(problem));
      auto K = static_cast<uint64_t>(get<2>(problem));
      fmas += M * N * K;
      // 2 bytes per bf16 element; 2*M*N presumably counts the output twice
      // (read + write) — TODO confirm against the kernel's epilogue.
      bytes_loaded +=
          /* sizeof(cutlass::bfloat16_t) */ 2 * (2 * M * N + N * K + M * K);
    }
    // Two flops per multiply-add
    uint64_t flop = uint64_t(2) * uint64_t(fmas);
    double gflop = double(flop) / double(1.0e9);
    double arithmetic_intensity = double(flop) / double(bytes_loaded);
    // Roofline model: attainable flop/s = min(peak compute, AI * peak BW).
    // 456 GB/s and 117 TFLOPS look like Intel BMG peak figures — TODO confirm.
    double peak_mwm_bw = 456.0;
    double gflops_attainable = std::min<double>(
        117 * double(1.0e12),
        arithmetic_intensity * (peak_mwm_bw * 1024 * 1024 * 1024));
    double projected_time = flop / gflops_attainable; // seconds
    return std::make_tuple(gflop / runtime_s,
                           double(bytes_loaded) / 1024 / 1024 / 1024 /
                               runtime_s,
                           projected_time * 1000); // ms
  }
};
| 146 | + |
| 147 | +/////////////////////////////////////////////////////////////////////////////////////////////////// |
| 148 | + |
| 149 | +template <typename T, char LayoutKind> |
| 150 | +auto make_device_tensor(sycl::queue &Q, int r, int c) { |
| 151 | + using namespace cute; |
| 152 | + auto ptr = make_gmem_ptr(sycl::malloc_device<T>(r * c, Q)); |
| 153 | + auto shape = make_shape(r, c); |
| 154 | + if constexpr (LayoutKind == 'C') |
| 155 | + return make_tensor(ptr, make_layout(shape, make_stride(_1{}, r))); |
| 156 | + else |
| 157 | + return make_tensor(ptr, make_layout(shape, make_stride(c, _1{}))); |
| 158 | +} |
| 159 | + |
// Compile-time test for whether T is a complete type at this point of
// instantiation. The partial specialization below is viable only when
// sizeof(T) is well-formed (i.e. T is complete); `0 * sizeof(T)` evaluates to
// the primary template's default argument (0), so the specialization wins
// exactly when T is complete.
template <typename T, size_t = 0> struct is_complete : std::false_type {};

template <typename T> struct is_complete<T, 0 * sizeof(T)> : std::true_type {};

// Convenience variable template mirroring the std *_v naming convention.
template <typename T>
static constexpr bool is_complete_v = is_complete<T>::value;
| 166 | + |
// Some of this code has been authored by Peter Caday.
// Selects the DPAS MMA atom for operand types TA/TB with accumulator TC.
// is_complete_v detects whether a specialization exists for the exact type
// combination; otherwise fall back to a supported operand type.
template <typename TA, typename TB, typename TC> auto choose_mma_op() {
  if constexpr (is_complete_v<XE_DPAS_TT<8, TC, TA, TB>>)
    return XE_DPAS_TT<8, TC, TA, TB>{};
  else if constexpr (is_same_v<TA, cute::bfloat16_t>)
    // NOTE(review): only three template arguments here — presumably the B
    // type defaults to the A type in XE_DPAS_TT; confirm in its declaration.
    return XE_DPAS_TT<8, float, cute::bfloat16_t>{};
  else /* Use f16 by default as upconversion sequences are typically faster */
    return XE_DPAS_TT<8, float, cute::half_t>{};
}
| 176 | + |
| 177 | +template <class TA, class TB, class TD> |
| 178 | +auto choose_tiled_mma(TA *A, TB *B, TD *D) { |
| 179 | + |
| 180 | + auto op = choose_mma_op<TA, TB, TD>(); |
| 181 | + |
| 182 | + using WGTile = Shape<_256, _256, _32>; // 256x256 WG tile size |
| 183 | + using SGLayout = |
| 184 | + Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major |
| 185 | + |
| 186 | + using MMA = typename TiledMMAHelper<MMA_Atom<decltype(op)>, Layout<WGTile>, |
| 187 | + SGLayout>::TiledMMA; |
| 188 | + |
| 189 | + return MMA{}; |
| 190 | +} |
| 191 | + |
// type tag to define a unique sycl kernel name
template <class, class, class, char, char> class GemmCuteName;

// Configures the persistent tile scheduler and launches the MoE grouped GEMM
// kernel over all experts in a single ND-range launch.
//
// \param activations  device pointer to A (all experts' rows concatenated)
// \param weights      device pointer to B (one N x K weight matrix per expert)
// \param scales       optional per-expert scales (may be null)
// \param outputs      device pointer to D
// \param gemm_n, gemm_k           N and K shared by every expert
// \param num_rows_per_expert_device  device array: M (rows) per expert
// \param num_experts  number of experts (groups)
template <char layoutA, char layoutB, class ElementA, class ElementB,
          class ElementS, class ElementD>
void MoEGEMMLauncher(const ElementA *activations, const ElementB *weights,
                     const ElementS *scales, ElementD *outputs,
                     const int gemm_n, const int gemm_k,
                     const int *num_rows_per_expert_device,
                     const int num_experts) {
  // Change device_id to another value if you are running on a machine with
  // multiple GPUs and wish to use a GPU other than that with device ID 0.
  // For example, in a framework, you could query device ID.
  int sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
  cutlass::KernelHardwareInfo hw_info{0, sm_count};
  // NOTE(review): a single dummy {1, k, n} problem is fed to the scheduler —
  // the real per-expert shapes are resolved inside the kernel from
  // num_rows_per_expert_device. The original author left this as an
  // unexplained hack; verify the scheduler only uses it for grid sizing.
  auto dummy_problem_shape = cute::Shape<int, int, int>{1, gemm_k, gemm_n};
  auto dummy_group_problem_shape =
      cutlass::gemm::GroupProblemShape<Shape<int, int, int>>{
          1, &dummy_problem_shape, nullptr};
  using TileShape = Shape<_256, _256, _32>;
  using ClusterShape = Shape<_1, _1, _1>;
  auto scheduler_params =
      PersistentTileSchedulerXeMoE<ProblemShape>::to_underlying_arguments(
          dummy_group_problem_shape, TileShape{}, ClusterShape{}, hw_info,
          PersistentTileSchedulerXeMoE<ProblemShape>::Arguments{
              1, RasterOrderOptions::AlongN});
  // Grid (workgroup) distribution for the persistent scheduler.
  auto group_distribution =
      PersistentTileSchedulerXeMoE<ProblemShape>::get_grid_shape(
          scheduler_params, dummy_group_problem_shape, TileShape{},
          ClusterShape{}, hw_info,
          PersistentTileSchedulerXeMoE<ProblemShape>::Arguments{
              1, RasterOrderOptions::AlongN});
  auto mma = choose_tiled_mma(activations, weights, outputs);
  // One workgroup per MMA tile; size(mma) threads per workgroup.
  auto MaxThreadsPerWorkgroup = size(mma);
  dim3 local_range{MaxThreadsPerWorkgroup, 1, 1};

  sycl::range<3> local = {local_range.x, local_range.y, local_range.z};
  sycl::range<3> groups = {group_distribution.x, group_distribution.y,
                           group_distribution.z};
  // SYCL global range is the total thread count (local * number of groups).
  sycl::range<3> global = {local[0] * groups[0], local[1] * groups[1],
                           local[2] * groups[2]};

  namespace syclex = sycl::ext::oneapi::experimental;
  namespace intelex = sycl::ext::intel::experimental;

  // 16-wide subgroups with the large (256-entry) GRF mode.
  syclex::properties kernel_props{syclex::sub_group_size<16>,
                                  intelex::grf_size<256>};
  sycl::queue Q = compat::get_default_queue();
  auto event = Q.parallel_for<
      GemmCuteName<ElementA, ElementB, ElementD, layoutA, layoutB>>(
      sycl::nd_range<3>(global, local), kernel_props, [=](auto) {
        // 2D block load/VNNI-load/store copy atoms; 'R' = row-major A/B/D.
        MoE::MoEGEMM<XE_LOAD_2D<16, 32, 32, 16>,
                     XE_LOAD_2D_VNNI<16, 32, 16, 16>, XE_STORE_2D<16, 8, 16>,
                     'R', 'R', 'R'>(activations, weights, scales, outputs, mma,
                                    num_rows_per_expert_device, num_experts,
                                    gemm_n, gemm_k, scheduler_params);
      });
  EventManager::getInstance().addEvent(event);
  Q.wait_and_throw();
}
| 254 | + |
| 255 | +void launcher(int *M_per_expert, int N, int K, const int &num_experts) { |
| 256 | + int n_moe = N; |
| 257 | + int k_moe = K; |
| 258 | + int num_tokens_incl_duplicated = 0; |
| 259 | + for (int i = 0; i < num_experts; i++) { |
| 260 | + num_tokens_incl_duplicated += M_per_expert[i]; |
| 261 | + } |
| 262 | + |
| 263 | + float M_occupancy = 0.f; |
| 264 | + float actual_num_units = 0.f; |
| 265 | + int total_num_M_tiles = 0; |
| 266 | + for (int i = 0; i < num_experts; i++) { |
| 267 | + total_num_M_tiles += (M_per_expert[i] + 255) / 256; |
| 268 | + actual_num_units += M_per_expert[i] / 256.0; |
| 269 | + } |
| 270 | + M_occupancy = actual_num_units / total_num_M_tiles; |
| 271 | + std::cout << "\n\n M-occupancy is " << M_occupancy << std::endl; |
| 272 | + cutlass::DeviceAllocation<int32_t> num_rows_per_expert_device; |
| 273 | + cutlass::DeviceAllocation<bfloat16_t> activations_data; |
| 274 | + cutlass::DeviceAllocation<bfloat16_t> weights_data; |
| 275 | + cutlass::DeviceAllocation<bfloat16_t> output_data; |
| 276 | + size_t A_size = num_tokens_incl_duplicated * k_moe; |
| 277 | + size_t B_size = num_experts * n_moe * k_moe; |
| 278 | + size_t D_size = num_tokens_incl_duplicated * n_moe; |
| 279 | + num_rows_per_expert_device.reset(num_experts); |
| 280 | + num_rows_per_expert_device.copy_from_host(M_per_expert); |
| 281 | + activations_data.reset(A_size); |
| 282 | + weights_data.reset(B_size); |
| 283 | + output_data.reset(D_size); |
| 284 | + uint64_t seed = 2023; |
| 285 | + initialize_block(activations_data, seed + 2023); |
| 286 | + initialize_block(weights_data, seed + 2022); |
| 287 | + initialize_block(output_data, seed + 2021); |
| 288 | + |
| 289 | + MoEGEMMLauncher<'R', 'R'>(activations_data.get(), weights_data.get(), |
| 290 | + static_cast<void *>(nullptr), output_data.get(), |
| 291 | + n_moe, k_moe, num_rows_per_expert_device.get(), |
| 292 | + num_experts); |
| 293 | + |
| 294 | + activations_data.release(); |
| 295 | + weights_data.release(); |
| 296 | + output_data.release(); |
| 297 | + num_rows_per_expert_device.release(); |
| 298 | +} |
| 299 | + |
// Runs the MoE grouped GEMM over a captured token-routing trace: for each of
// 24 layers, a 32-expert distribution of routed token counts drives two
// grouped GEMMs (N=5760,K=2880 and N=2880,K=2880 — presumably the MLP
// up/gate and down projections; confirm against the target model).
int main(int argc, const char **argv) {
  constexpr int num_experts = 32;
  constexpr int num_layers = 24;

  // Token counts routed to each expert, one row per layer and one entry per
  // expert (values appear to come from a real routing trace; zeros are
  // experts that received no tokens in that layer).
  int total_rows_for_each_expert[num_layers][num_experts] = {
      {148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289,
       845, 191, 424, 30, 97, 57, 324, 62, 77, 75, 144,
       250, 287, 629, 370, 161, 101, 215, 113, 224, 35},
      {666, 214, 448, 87, 4, 28, 48, 13, 74, 40, 546,
       397, 487, 350, 26, 95, 517, 487, 295, 58, 637, 97,
       139, 33, 126, 15, 352, 311, 995, 193, 135, 135},
      {1016, 30, 36, 452, 469, 473, 232, 0, 493, 14, 954,
       6, 4, 6, 279, 3, 94, 106, 96, 48, 49, 113,
       142, 169, 75, 99, 25, 220, 249, 289, 4, 1803},
      {350, 229, 703, 154, 8, 64, 80, 339, 2, 56, 5,
       312, 1005, 29, 9, 11, 23, 0, 23, 431, 48, 129,
       496, 476, 8, 1234, 7, 130, 34, 58, 41, 1554},
      {39, 10, 6, 2, 110, 1, 894, 8, 53, 0, 275,
       6, 506, 421, 700, 178, 0, 530, 1623, 15, 231, 74,
       6, 222, 1246, 116, 35, 20, 0, 6, 381, 334},
      {399, 5, 201, 6, 134, 93, 1748, 1, 51, 4, 38,
       336, 53, 88, 328, 724, 15, 388, 706, 52, 19, 55,
       52, 33, 623, 1, 222, 215, 69, 45, 308, 1036},
      {11, 8, 407, 571, 458, 275, 197, 211, 13, 564, 462,
       114, 15, 13, 132, 24, 514, 2, 71, 13, 694, 47,
       16, 203, 610, 40, 0, 1587, 66, 23, 196, 491},
      {0, 230, 116, 136, 315, 643, 6, 183, 37, 26, 960,
       1, 8, 258, 21, 1602, 213, 198, 6, 196, 455, 557,
       47, 282, 493, 18, 101, 11, 616, 45, 268, 0},
      {392, 305, 179, 14, 227, 98, 114, 39, 64, 1456, 465,
       0, 18, 372, 0, 0, 189, 257, 25, 290, 486, 0,
       12, 1534, 468, 4, 555, 35, 146, 0, 161, 143},
      {4, 107, 20, 125, 236, 898, 0, 0, 375, 2, 125,
       0, 0, 1429, 36, 195, 1660, 0, 127, 454, 73, 358,
       47, 79, 32, 20, 1465, 0, 0, 6, 109, 66},
      {19, 0, 0, 0, 2, 1638, 75, 135, 392, 2, 1494, 3, 23, 5, 4, 58,
       0, 0, 71, 1285, 8, 441, 0, 145, 209, 408, 450, 2, 824, 13, 326, 16},
      {4, 2, 14, 0, 30, 206, 41, 131, 0, 429, 16, 895, 35, 21, 44, 128,
       12, 0, 417, 0, 838, 917, 42, 115, 109, 1759, 0, 36, 17, 0, 1790, 0},
      {6, 483, 241, 1327, 17, 11, 480, 9, 880, 58, 4,
       0, 61, 30, 16, 176, 9, 309, 26, 0, 0, 1882,
       4, 281, 475, 783, 197, 0, 19, 15, 6, 243},
      {370, 1222, 0, 6, 108, 929, 2, 7, 157, 348, 149, 106, 2, 5, 25, 33,
       1569, 8, 6, 106, 69, 1298, 0, 2, 529, 520, 0, 421, 0, 25, 26, 0},
      {59, 89, 0, 26, 25, 40, 1873, 141, 527, 371, 262,
       62, 16, 0, 127, 234, 1637, 64, 132, 8, 0, 7,
       161, 1005, 22, 1, 49, 6, 83, 925, 80, 16},
      {269, 617, 30, 4, 90, 26, 0, 16, 154, 212, 21,
       269, 379, 174, 129, 32, 8, 121, 344, 15, 0, 591,
       1494, 6, 737, 50, 112, 856, 483, 25, 454, 330},
      {0, 98, 1488, 22, 73, 0, 0, 343, 77, 4, 0,
       612, 165, 268, 4, 10, 43, 0, 598, 271, 2, 73,
       185, 0, 112, 779, 24, 1626, 0, 0, 0, 1171},
      {0, 0, 0, 189, 266, 1743, 0, 462, 20, 7, 668, 310, 40, 0, 10, 236,
       423, 18, 0, 0, 0, 999, 0, 139, 1754, 8, 619, 3, 23, 0, 102, 9},
      {131, 1753, 0, 113, 24, 94, 2, 12, 108, 0, 0,
       252, 97, 0, 1319, 233, 93, 1254, 195, 152, 14, 413,
       4, 2, 220, 67, 20, 4, 34, 559, 837, 42},
      {55, 76, 0, 8, 0, 3, 1557, 975, 135, 271, 4, 0, 0, 666, 207, 152,
       5, 2, 97, 364, 0, 13, 1423, 771, 159, 31, 223, 0, 431, 7, 409, 4},
      {4, 1026, 1799, 166, 694, 753, 0, 16, 0, 240, 1119, 19, 6, 0, 46, 659,
       10, 0, 112, 808, 181, 0, 28, 22, 90, 0, 176, 0, 37, 5, 10, 22},
      {44, 0, 4, 153, 299, 1357, 6, 23, 0, 12, 4, 419, 73, 24, 16, 24,
       1, 4, 4, 102, 16, 4, 0, 1953, 1850, 0, 908, 4, 0, 13, 708, 23},
      {6, 13, 123, 28, 197, 0, 202, 69, 0, 6, 0, 21, 1434, 1582, 11, 0, 6,
       0, 7, 190, 4, 1700, 6, 434, 1886, 0, 14, 28, 8, 30, 25, 18},
      {5, 27, 1442, 18, 0, 6, 0, 73, 6, 781, 0, 1915, 291, 649, 98, 4,
       33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9}};

  // Two grouped GEMMs per layer with the same routing distribution.
  for (int i = 0; i < num_layers; i++) {
    launcher(total_rows_for_each_expert[i], 5760, 2880, num_experts);
    launcher(total_rows_for_each_expert[i], 2880, 2880, num_experts);
  }

  return 0;
}
0 commit comments