Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ enum Activation
{
gelu_and_mul = 0,
silu_and_mul = 1,
swiglustep_and_mul = 2
swiglustep_and_mul = 2,
swiglu_oai_and_mul = 3
};
Comment on lines 31 to 35

template <typename ALayout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,25 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
up = math::min(math::max(up, -kSwiGluClamp), kSwiGluClamp);
return gate * up;
}

// Parameters of the swiglu_oai_and_mul activation (gpt-oss / OAI form).
// kSwiGluOaiLimit (L = 7.0, hardcoded) is the pre-activation clamp bound used by
// apply_swiglu_oai_activation: gate is capped from above at L, up is clamped to [-L, L].
// kSwiGluOaiAlpha (1.702, per gpt-oss default) scales gate inside the sigmoid.
static constexpr float kSwiGluOaiLimit = 7.0f;
static constexpr float kSwiGluOaiAlpha = 1.702f;

// Helper: apply OAI SwiGLU activation gate*sigmoid(alpha*gate)*(up+1) with pre-activation
// clamp (gate upper-bounded, up symmetric). Mirrors ck_tile::moe::Swiglu in
// ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp. Distinct from swiglustep (no +1, no alpha).
// Used by all four swiglu_oai_and_mul epilogue paths (quant/non-quant x pipeline-A/B).
__host__ __device__ static constexpr float apply_swiglu_oai_activation(float gate, float up)
{
    // Pre-activation clamps: gate only has an upper bound, up is bounded on both sides.
    const float clamped_gate = math::min(gate, kSwiGluOaiLimit);
    const float clamped_up =
        math::min(math::max(up, -kSwiGluOaiLimit), kSwiGluOaiLimit);

    // Logistic sigmoid of alpha * gate, written as 1 / (1 + exp(-alpha * gate)).
    const float denom   = 1.0f + math::exp(kSwiGluOaiAlpha * -clamped_gate);
    const float sigmoid = 1.0f / denom;

    // OAI form: gate * sigmoid(alpha * gate) * (up + 1).
    return clamped_gate * sigmoid * (clamped_up + 1.0f);
}

using mfma_selector =
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>;

Expand Down Expand Up @@ -1506,6 +1525,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
else if constexpr(ActivationOperation ==
Activation::swiglu_oai_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
}
}
else
{
Expand Down Expand Up @@ -1579,6 +1618,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
else if constexpr(ActivationOperation ==
Activation::swiglu_oai_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
}
}
else
{
Expand Down Expand Up @@ -2011,6 +2062,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
else if constexpr(ActivationOperation ==
Activation::swiglu_oai_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
}
}
else
{
Expand Down Expand Up @@ -2084,6 +2155,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
else if constexpr(ActivationOperation ==
Activation::swiglu_oai_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
}
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,9 @@ struct GridwiseMoeGemmBlockScale
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
static_assert(ActivationOperation != Activation::swiglu_oai_and_mul,
"gridwise_moe_gemm_blockscale does not support swiglu_oai_and_mul; use the "
"non-blockscale gridwise_moe_gemm.");
#if defined(__gfx942__) || defined(__gfx950__)
constexpr auto b_coherence_flag = NonTemporalLoadB
? AmdBufferCoherenceEnum::WAVE_NT1
Expand Down Expand Up @@ -1694,6 +1697,9 @@ struct GridwiseMoeGemmBlockScale
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
static_assert(ActivationOperation != Activation::swiglu_oai_and_mul,
"gridwise_moe_gemm_blockscale does not support swiglu_oai_and_mul; use the "
"non-blockscale gridwise_moe_gemm.");
#if defined(__gfx942__) || defined(__gfx950__)
constexpr auto b_coherence_flag = NonTemporalLoadB
? AmdBufferCoherenceEnum::WAVE_NT1
Expand Down
Loading