diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp b/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp index 442fbbf846bc..ecd9c402df10 100644 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp +++ b/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp @@ -30,7 +30,8 @@ enum Activation { gelu_and_mul = 0, silu_and_mul = 1, - swiglustep_and_mul = 2 + swiglustep_and_mul = 2, + swiglu_oai_and_mul = 3 }; template ; @@ -1506,6 +1525,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< } c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); } + else if constexpr(ActivationOperation == + Activation::swiglu_oai_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up); + } } else { @@ -1579,6 +1618,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< } c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); } + else if constexpr(ActivationOperation == + Activation::swiglu_oai_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up); + } } else { @@ -2011,6 +2062,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< } c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); } + else if constexpr(ActivationOperation == + Activation::swiglu_oai_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up); + } } else { @@ -2084,6 +2155,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< } c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); } + else if constexpr(ActivationOperation == + Activation::swiglu_oai_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up); + } } else { diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 5b791d3668c8..4e8b954afdb2 100644 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/projects/composablekernel/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -1140,6 +1140,9 @@ struct GridwiseMoeGemmBlockScale BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + static_assert(ActivationOperation != Activation::swiglu_oai_and_mul, + "gridwise_moe_gemm_blockscale does not support swiglu_oai_and_mul; use the " + "non-blockscale gridwise_moe_gemm."); #if defined(__gfx942__) || defined(__gfx950__) constexpr auto b_coherence_flag = NonTemporalLoadB ? AmdBufferCoherenceEnum::WAVE_NT1 @@ -1694,6 +1697,9 @@ struct GridwiseMoeGemmBlockScale BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + static_assert(ActivationOperation != Activation::swiglu_oai_and_mul, + "gridwise_moe_gemm_blockscale does not support swiglu_oai_and_mul; use the " + "non-blockscale gridwise_moe_gemm."); #if defined(__gfx942__) || defined(__gfx950__) constexpr auto b_coherence_flag = NonTemporalLoadB ? AmdBufferCoherenceEnum::WAVE_NT1