@@ -14,21 +14,18 @@ namespace gpu::cutlass_kernel {
1414
1515namespace grouped_gemm {
1616void kernel_functor (sycl::queue& stream, void * ptr_A, void * ptr_B, void * ptr_D,
17- void * ptr_alpha, void * ptr_beta, void * offset, int64_t N,
18- int64_t K, int64_t groups);
17+ void * offset, int32_t N, int32_t K, int32_t groups);
1918}
2019
2120/* gemm2(group_A, w2, output, offset) */
2221
2322at::Tensor grouped_gemm_func (at::Tensor& ptr_A, at::Tensor& ptr_B,
24- at::Tensor& ptr_D, at::Tensor& ptr_alpha,
25- at::Tensor& ptr_beta, at::Tensor& offset,
23+ at::Tensor& ptr_D, at::Tensor& tokens_per_expert,
2624 int64_t N, int64_t K, int64_t groups) {
2725 auto & dpcpp_queue = vllm::xpu::vllmGetQueue ();
2826 grouped_gemm::kernel_functor (dpcpp_queue, ptr_A.data_ptr (), ptr_B.data_ptr (),
29- ptr_D.data_ptr (), ptr_alpha.data_ptr (),
30- ptr_beta.data_ptr (), offset.data_ptr (), N, K,
31- groups);
27+ ptr_D.data_ptr (), tokens_per_expert.data_ptr (), (int32_t )N, (int32_t )K,
28+ (int32_t )groups);
3229 return ptr_D;
3330}
3431
0 commit comments