@@ -398,8 +398,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
398398 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
399399 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
400400 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
401- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
402- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
401+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
402+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
403+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
404+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
405+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
406+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
403407 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
404408 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
405409 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -1428,8 +1432,12 @@ @implementation GGMLMetalClass
14281432 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
14291433 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
14301434 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
1431- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1432- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
1435+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
1436+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
1437+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
1438+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
1439+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
1440+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
14331441 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
14341442 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
14351443 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -3908,38 +3916,6 @@ static int ggml_metal_encode_node(
39083916 default : break ;
39093917 }
39103918
3911- const int64_t neh10 = ne10; // n_embd
3912- const int64_t neh11 = ne21; // n_tokens
3913- const int64_t neh12 = ne02; // n_expert
3914-
3915- const uint64_t nbh10 = ggml_type_size (GGML_TYPE_F16);
3916- const uint64_t nbh11 = nbh10*neh10;
3917- const uint64_t nbh12 = nbh11*neh11;
3918- const uint64_t nbh13 = nbh12*neh12;
3919-
3920- const size_t s_src1 = ggml_type_size (GGML_TYPE_F16)*neh10*neh11*neh12;
3921- id <MTLBuffer > h_src1 = ggml_metal_mem_pool_alloc (mem_pool, s_src1);
3922- if (!h_src1) {
3923- GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_src1);
3924- return 0 ;
3925- }
3926-
3927- const int64_t neh0 = ne0;
3928- const int64_t neh1 = ne21;
3929- const int64_t neh2 = ne02;
3930-
3931- const uint64_t nbh0 = ggml_type_size (GGML_TYPE_F32);
3932- const uint64_t nbh1 = nbh0*neh0;
3933- const uint64_t nbh2 = nbh1*neh1;
3934- // const uint64_t nbh3 = nbh2*neh2;
3935-
3936- const size_t s_dst = ggml_type_size (GGML_TYPE_F32)*neh0*neh1*neh2;
3937- id <MTLBuffer > h_dst = ggml_metal_mem_pool_alloc (mem_pool, s_dst);
3938- if (!h_dst) {
3939- GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_dst);
3940- return 0 ;
3941- }
3942-
39433919 // tokens per expert
39443920 const size_t s_tpe = ggml_type_size (GGML_TYPE_I32)*ne02;
39453921 id <MTLBuffer > h_tpe = ggml_metal_mem_pool_alloc (mem_pool, s_tpe);
@@ -3949,41 +3925,54 @@ static int ggml_metal_encode_node(
39493925 }
39503926
39513927 // id map
3952- // [n_expert_used, n_tokens ]
3953- const size_t s_ids = ggml_type_size (GGML_TYPE_I32)*ne20* ne21;
3928+ // [n_tokens, n_expert ]
3929+ const size_t s_ids = ggml_type_size (GGML_TYPE_I32)*ne21*ne02 ;
39543930 id <MTLBuffer > h_ids = ggml_metal_mem_pool_alloc (mem_pool, s_ids);
39553931 if (!h_ids) {
39563932 GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_ids);
39573933 return 0 ;
39583934 }
39593935
39603936 {
3961- const int nth = MIN (1024 , ne10/4 );
3962-
39633937 ggml_metal_kargs_mul_mm_id_map0 args = {
3938+ ne02,
39643939 ne10,
3965- ne11, // n_expert_used (bcast)
3940+ ne11, // n_expert_used (bcast)
39663941 nb11,
39673942 nb12,
3968- neh11, // n_tokens
3969- nbh11,
3970- ne20, // n_expert_used
3943+ ne21, // n_tokens
3944+ ne20, // n_expert_used
39713945 nb21,
39723946 };
39733947
39743948 id <MTLComputePipelineState > pipeline = nil ;
39753949
3976- pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline ;
3950+ pipeline = nil ;
3951+
3952+ switch (ne20) {
3953+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline ; break ;
3954+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline ; break ;
3955+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline ; break ;
3956+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline ; break ;
3957+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline ; break ;
3958+ case 16 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline ; break ;
3959+ default : GGML_ABORT (" missing specialization for ne20 = %d " , (int ) ne20);
3960+ }
3961+
3962+ GGML_ASSERT (ne02 <= (int ) pipeline.maxTotalThreadsPerThreadgroup );
3963+
3964+ const size_t smem = ne02*ne20*sizeof (uint16_t );
3965+
3966+ GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
39773967
39783968 [encoder setComputePipelineState: pipeline];
39793969 [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3980- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
3981- [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 2 ];
3982- [encoder setBuffer: h_src1 offset: 0 atIndex: 3 ];
3983- [encoder setBuffer: h_tpe offset: 0 atIndex: 4 ];
3984- [encoder setBuffer: h_ids offset: 0 atIndex: 5 ];
3970+ [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 1 ];
3971+ [encoder setBuffer: h_tpe offset: 0 atIndex: 2 ];
3972+ [encoder setBuffer: h_ids offset: 0 atIndex: 3 ];
3973+ [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
39853974
3986- [encoder dispatchThreadgroups: MTLSizeMake (ne02 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth , 1 , 1 )];
3975+ [encoder dispatchThreadgroups: MTLSizeMake (1 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (ne02 , 1 , 1 )];
39873976 }
39883977
39893978 {
@@ -4022,56 +4011,30 @@ static int ggml_metal_encode_node(
40224011 /* .nb01 =*/ nb01,
40234012 /* .nb02 =*/ nb02,
40244013 /* .nb03 =*/ nb03,
4025- /* .neh12 =*/ neh12,
4026- /* .nbh10 =*/ nbh10,
4027- /* .nbh11 =*/ nbh11,
4028- /* .nbh12 =*/ nbh12,
4029- /* .nbh13 =*/ nbh13,
4030- /* .neh0 =*/ neh0,
4031- /* .neh1 =*/ neh1,
4014+ /* .ne11 =*/ ne11, // n_expert_used (bcast)
4015+ /* .nb10 =*/ nb10,
4016+ /* .nb11 =*/ nb11,
4017+ /* .nb12 =*/ nb12,
4018+ /* .nb13 =*/ nb13,
4019+ /* .ne20 =*/ ne20, // n_expert_used
4020+ /* .ne21 =*/ ne21, // n_tokens
4021+ /* .ne0 =*/ ne0,
4022+ /* .ne1 =*/ ne1,
40324023 /* .r2 =*/ r2,
40334024 /* .r3 =*/ r3,
40344025 };
40354026
40364027 [encoder setComputePipelineState: pipeline];
40374028 [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
40384029 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
4039- [encoder setBuffer: h_src1 offset: 0 atIndex: 2 ];
4030+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
40404031 [encoder setBuffer: h_tpe offset: 0 atIndex: 3 ];
4041- [encoder setBuffer: h_dst offset: 0 atIndex: 4 ];
4032+ [encoder setBuffer: h_ids offset: 0 atIndex: 4 ];
4033+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 5 ];
40424034
40434035 [encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
40444036 [encoder dispatchThreadgroups: MTLSizeMake ((ne21 + 31 )/32 , (ne01 + 63 )/64 , ne02) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
40454037 }
4046-
4047- {
4048- GGML_ASSERT (ne0 % 4 == 0 );
4049-
4050- const int nth = MIN (1024 , ne0/4 );
4051-
4052- ggml_metal_kargs_mul_mm_id_map1 args = {
4053- ne20, // n_expert_used
4054- neh0,
4055- neh1,
4056- nbh1,
4057- nbh2,
4058- ne0,
4059- nb1,
4060- nb2,
4061- };
4062-
4063- id <MTLComputePipelineState > pipeline = nil ;
4064-
4065- pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline ;
4066-
4067- [encoder setComputePipelineState: pipeline];
4068- [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
4069- [encoder setBuffer: h_dst offset: 0 atIndex: 1 ];
4070- [encoder setBuffer: h_ids offset: 0 atIndex: 2 ];
4071- [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
4072-
4073- [encoder dispatchThreadgroups: MTLSizeMake (ne20, ne21, 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4074- }
40754038 } else {
40764039 id <MTLComputePipelineState > pipeline = nil ;
40774040
0 commit comments