Skip to content

Commit 4ba6b5f

Browse files
committed
Optimized & refactored MoEGEMM
1 parent e50cca0 commit 4ba6b5f

File tree

7 files changed

+129
-396
lines changed

7 files changed

+129
-396
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
171171
set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library")
172172

173173
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
174-
set(CUTLASS_REVISION "9baca2cff3a28590fcd03e55515e2d91ff2cbc8b" CACHE STRING "CUTLASS revision to use")
174+
set(CUTLASS_REVISION "2eeb05da5b801b34114b6b394dcef836fc9a7cc9" CACHE STRING "CUTLASS revision to use")
175175

176176
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
177177
FetchContent_Declare(

csrc/xpu/cutlass_kernels/grouped_gemm.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,18 @@ namespace gpu::cutlass_kernel {
1414

1515
namespace grouped_gemm {
1616
void kernel_functor(sycl::queue& stream, void* ptr_A, void* ptr_B, void* ptr_D,
17-
void* ptr_alpha, void* ptr_beta, void* offset, int64_t N,
18-
int64_t K, int64_t groups);
17+
void* offset, int32_t N, int32_t K, int32_t groups);
1918
}
2019

2120
/* gemm2(group_A, w2, output, offset) */
2221

2322
at::Tensor grouped_gemm_func(at::Tensor& ptr_A, at::Tensor& ptr_B,
24-
at::Tensor& ptr_D, at::Tensor& ptr_alpha,
25-
at::Tensor& ptr_beta, at::Tensor& offset,
23+
at::Tensor& ptr_D, at::Tensor& tokens_per_expert,
2624
int64_t N, int64_t K, int64_t groups) {
2725
auto& dpcpp_queue = vllm::xpu::vllmGetQueue();
2826
grouped_gemm::kernel_functor(dpcpp_queue, ptr_A.data_ptr(), ptr_B.data_ptr(),
29-
ptr_D.data_ptr(), ptr_alpha.data_ptr(),
30-
ptr_beta.data_ptr(), offset.data_ptr(), N, K,
31-
groups);
27+
ptr_D.data_ptr(), tokens_per_expert.data_ptr(), (int32_t)N, (int32_t)K,
28+
(int32_t)groups);
3229
return ptr_D;
3330
}
3431

0 commit comments

Comments
 (0)