Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 195 additions & 18 deletions csrc/xpu/cutlass_kernels/collective/gemm/moe_array_mma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cutlass/fp8_to_fp16.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm {
Expand All @@ -54,21 +55,152 @@ struct MainloopMoE16Group {

} // namespace cutlass::gemm

namespace cutlass {
/// Converts one FP8-E5M2 encoding (1 sign, 5 exponent, 2 mantissa bits) to
/// IEEE fp16. E5M2 uses the same bias (15) and field alignment as the top
/// byte of fp16, so widening the byte into the high half of a 16-bit word is
/// an exact, lossless conversion — subnormals, infinities and NaNs included.
CUTLASS_HOST_DEVICE
static cutlass::half_t convert_e5m2_to_half(uint8_t const& src) {
  uint16_t bits_fp16 = static_cast<uint16_t>(src << 8);
  // half_t::bitcast reinterprets the raw bits without the strict-aliasing
  // violation of reinterpret_cast-ing a uint16_t lvalue as a half_t.
  return cutlass::half_t::bitcast(bits_fp16);
}

/// Fast FP8-E5M2 (1-5-2) -> bf16 (1-8-7) conversion: bit relocation plus one
/// rescaling multiply.
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t convert_e5m2_to_bf16(uint8_t const& src) {
  // 0x7780 is the bf16 value 2^112 (biased exponent 0xEF). Multiplying by it
  // converts an exponent biased like fp16 (bias 15) into one biased like
  // bf16 (bias 127): 127 - 15 = 112.
  constexpr uint16_t fast_cvt_u16 = 0x7780;
  const cutlass::bfloat16_t fast_cvt_bf16 =
      reinterpret_cast<const cutlass::bfloat16_t&>(fast_cvt_u16);

  // Widen the e5m2 byte into fp16 bit layout: s eeeee mm 00000000.
  uint16_t shifted = src << 8;
  // Arithmetic shift by 3 moves exponent+mantissa into the bf16 field
  // positions while smearing the sign into bits 14..12.
  int16_t adjusted = (int16_t)shifted >> 3;
  // 0x8FFF keeps the sign (bit 15) and the relocated payload (bits 11..0),
  // clearing the duplicated sign copies: s 000 eeeee mm 00000.
  uint16_t masked = adjusted & 0x8FFF;
  // The masked pattern is a bf16 whose exponent is still fp16-biased; the
  // 2^112 multiply fixes the bias, and also correctly promotes e5m2
  // subnormals (exponent field 0) to normal bf16 values.
  // NOTE(review): e5m2 Inf/NaN (exponent 0b11111) come out of this path as
  // finite values near 2^16 rather than Inf/NaN — confirm upstream
  // quantization cannot produce them.
  cutlass::bfloat16_t result =
      reinterpret_cast<cutlass::bfloat16_t&>(masked) * fast_cvt_bf16;
  return result;
}

/// Fast FP8-E4M3 (1-4-3, bias 7) -> fp16 conversion: bit relocation plus two
/// rescaling multiplies.
CUTLASS_HOST_DEVICE
static cutlass::half_t convert_e4m3_to_half(uint8_t const& src) {
  // 0x7880 is fp16 1.125 * 2^15; 0x1F1C is fp16 ~1.7773 * 2^-8. Their
  // product is 2^8 * (1 - 2^-12): the net 2^8 re-biases E4M3 (bias 7) to
  // fp16 (bias 15). The split into two multiplies keeps the first product
  // exactly representable (no rounding) so the single rounding of the second
  // multiply lands on the exact target value (ties resolve to even).
  constexpr uint16_t fast_cvt_u16_1 = 0x7880;
  constexpr uint16_t fast_cvt_u16_2 = 0x1F1C;
  const cutlass::half_t fast_cvt_fp16_1 =
      reinterpret_cast<const cutlass::half_t&>(fast_cvt_u16_1);
  const cutlass::half_t fast_cvt_fp16_2 =
      reinterpret_cast<const cutlass::half_t&>(fast_cvt_u16_2);

  // Widen into fp16 layout: s eeee mmm 00000000.
  uint16_t shifted = src << 8;
  // Arithmetic shift by 1 relocates the payload, duplicating the sign into
  // bit 14.
  int16_t adjusted = reinterpret_cast<int16_t&>(shifted) >> 1;
  // 0xBFFF clears the duplicated sign bit: s 0 eeee mmm 0000000.
  uint16_t masked = reinterpret_cast<uint16_t&>(adjusted) & 0xBFFF;
  // First multiply is exact; the second performs the final (rounded) rescale.
  // Subnormal e4m3 inputs (exponent field 0) are promoted correctly by the
  // multiplies.
  // NOTE(review): the e4m3 NaN encoding (S.1111.111) overflows to +/-Inf in
  // the first multiply and stays Inf, i.e. NaN maps to Inf — confirm this is
  // acceptable for the kernel's input domain.
  cutlass::half_t intermediate =
      reinterpret_cast<cutlass::half_t&>(masked) * fast_cvt_fp16_1;
  cutlass::half_t result = intermediate * fast_cvt_fp16_2;
  return result;
}

/// Fast FP8-E4M3 (1-4-3, bias 7) -> bf16 conversion: bit relocation, one
/// rescaling multiply, and an explicit fix-up for the E4M3 NaN encoding.
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t convert_e4m3_to_bf16(uint8_t const& src) {
  // 0x7B80 is the bf16 value 2^120 (biased exponent 0xF7). Multiplying by
  // it re-biases E4M3 (bias 7) to bf16 (bias 127): 127 - 7 = 120.
  constexpr uint16_t fast_cvt_u16 = 0x7B80;
  const cutlass::bfloat16_t fast_cvt_bf16 =
      reinterpret_cast<const cutlass::bfloat16_t&>(fast_cvt_u16);
  // 0x07F0 is the largest relocated magnitude, reached only by the E4M3 NaN
  // pattern (exponent 1111, mantissa 111); used below to detect NaN inputs.
  constexpr uint16_t fast_cmp_u16 = 0x7F0;
  const cutlass::half_t fast_cmp_fp16 =
      reinterpret_cast<const cutlass::half_t&>(fast_cmp_u16);

  // Widen into fp16 layout: s eeee mmm 00000000.
  uint16_t shifted = src << 8;
  // Arithmetic shift by 4 relocates the payload into bf16 field positions,
  // smearing the sign into bits 14..11.
  int16_t adjusted = (int16_t)shifted >> 4;
  // 0x87FF keeps sign (bit 15) + payload (bits 10..0): s 0000 eeee mmm 0000.
  uint16_t masked = adjusted & 0x87FF;
  // For positive fp16 bit patterns numeric order matches bit order, so this
  // abs-compare is exactly "(masked & 0x7FFF) >= 0x07F0", i.e. NaN input.
  bool is_large =
      abs(reinterpret_cast<cutlass::half_t&>(masked)) >= fast_cmp_fp16;

  // Re-bias via the 2^120 multiply; also promotes e4m3 subnormals
  // (exponent field 0) to correct bf16 values.
  cutlass::bfloat16_t result =
      reinterpret_cast<cutlass::bfloat16_t&>(masked) * fast_cvt_bf16;

  if (is_large) {
    // Force a bf16 NaN (all exponent/mantissa bits set), preserving sign.
    uint16_t temp = reinterpret_cast<uint16_t&>(result) | 0x7FFF;
    result = reinterpret_cast<cutlass::bfloat16_t&>(temp);
  }

  return result;
}

/// NumericConverter specialization: FP8-E5M2 -> fp16 via the fast bitwise
/// path. Round style is nominal — the underlying conversion is exact.
template <>
struct NumericConverter<cutlass::half_t, cutlass::float_e5m2_t,
                        FloatRoundStyle::round_to_nearest> {
  using result_type = cutlass::half_t;
  using source_type = cutlass::float_e5m2_t;
  static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest;

  /// Functor form; delegates to the static entry point.
  CUTLASS_HOST_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }

  /// Static conversion entry point.
  CUTLASS_HOST_DEVICE
  static result_type convert(source_type const& s) {
    return convert_e5m2_to_half(s.storage);
  }
};

/// NumericConverter specialization: FP8-E5M2 -> bf16 via the fast bitwise
/// path.
template <>
struct NumericConverter<cutlass::bfloat16_t, cutlass::float_e5m2_t,
                        FloatRoundStyle::round_to_nearest> {
  using result_type = cutlass::bfloat16_t;
  using source_type = cutlass::float_e5m2_t;
  static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest;

  /// Functor form; delegates to the static entry point.
  CUTLASS_HOST_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }

  /// Static conversion entry point.
  CUTLASS_HOST_DEVICE
  static result_type convert(source_type const& s) {
    return convert_e5m2_to_bf16(s.storage);
  }
};

/// NumericConverter specialization: FP8-E4M3 -> fp16 via the fast bitwise
/// path.
template <>
struct NumericConverter<cutlass::half_t, cutlass::float_e4m3_t,
                        FloatRoundStyle::round_to_nearest> {
  using result_type = cutlass::half_t;
  using source_type = cutlass::float_e4m3_t;
  static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest;

  /// Functor form; delegates to the static entry point.
  CUTLASS_HOST_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }

  /// Static conversion entry point.
  CUTLASS_HOST_DEVICE
  static result_type convert(source_type const& s) {
    return convert_e4m3_to_half(s.storage);
  }
};

/// NumericConverter specialization: FP8-E4M3 -> bf16 via the fast bitwise
/// path.
template <>
struct NumericConverter<cutlass::bfloat16_t, cutlass::float_e4m3_t,
                        FloatRoundStyle::round_to_nearest> {
  using result_type = cutlass::bfloat16_t;
  using source_type = cutlass::float_e4m3_t;
  static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest;

  /// Functor form; delegates to the static entry point.
  CUTLASS_HOST_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }

  /// Static conversion entry point.
  CUTLASS_HOST_DEVICE
  static result_type convert(source_type const& s) {
    return convert_e4m3_to_bf16(s.storage);
  }
};
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////

template <int Stages, class Schedule, class TileShape_, class ElementA_,
class StrideA_, class ElementB_, class StrideB_, class TiledMma_,
class GmemTiledCopyA_, class SmemLayoutAtomA_, class SmemCopyAtomA_,
class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_,
class SmemCopyAtomB_, class TransformB_>
class StrideA_, class ElementBOptionalTuple_, class StrideB_,
class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_,
class SmemLayoutAtomB_, class SmemCopyAtomB_, class TransformB_>
struct CollectiveMma<MainloopMoE16Group<Stages, Schedule>, TileShape_,
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_,
TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_,
SmemCopyAtomB_, TransformB_> {
ElementA_, StrideA_, ElementBOptionalTuple_, StrideB_,
TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_,
SmemCopyAtomA_, TransformA_, GmemTiledCopyB_,
SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_> {
//
// Type Aliases
//
Expand All @@ -77,7 +209,14 @@ struct CollectiveMma<MainloopMoE16Group<Stages, Schedule>, TileShape_,
using ElementA = ElementA_;
using StrideA = StrideA_;
using InternalStrideA = cute::remove_pointer_t<StrideA>;
using ElementB = ElementB_;
using ElementB =
detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple_>;
static constexpr bool is_B_fp8_type =
std::is_same_v<ElementB, cutlass::float_e5m2_t> ||
std::is_same_v<ElementB, cutlass::float_e4m3_t>;
using StorageTypeB = std::conditional_t<is_B_fp8_type, uint8_t, ElementB>;
using ElementScaleB =
detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple_>;
using StrideB = StrideB_;
using InternalStrideB = cute::remove_pointer_t<StrideB>;
using TiledMma = TiledMma_;
Expand All @@ -92,9 +231,11 @@ struct CollectiveMma<MainloopMoE16Group<Stages, Schedule>, TileShape_,
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;

static_assert(
platform::is_same<ElementA, ElementB>::value,
"MainloopIntelXeXMX16Array requires that A and B have same type.");
static_assert((!is_B_fp8_type &&
platform::is_same<ElementA, ElementB>::value) ||
is_B_fp8_type,
"MainloopIntelXeXMX16Array requires that A and B have same "
"type or B is fp8 dtype.");

static_assert(std::is_same_v<TransformA, cute::identity>,
"Transformation for A is not currently supported on Intel PVC");
Expand Down Expand Up @@ -136,13 +277,17 @@ struct CollectiveMma<MainloopMoE16Group<Stages, Schedule>, TileShape_,
using TensorNKL =
decltype(make_tensor(make_gmem_ptr(static_cast<ElementB const*>(nullptr)),
make_shape(0, 0, 0), InternalStrideB{})); //(n, k)
using MainloopTensors = cute::tuple<TensorMKL, TensorNKL>;
using TensorScale = decltype(make_tensor(
make_gmem_ptr(static_cast<ElementScaleB const*>(nullptr)),
make_shape(0, 0, 0))); //(1, 1)
using MainloopTensors = cute::tuple<TensorMKL, TensorNKL, TensorScale>;
// Host side kernel arguments
struct Arguments {
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
void const* ptr_B_scale;
int64_t const* expert_first_token_offset;
};

Expand All @@ -151,6 +296,7 @@ struct CollectiveMma<MainloopMoE16Group<Stages, Schedule>, TileShape_,
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
void const* ptr_B_scale;
int64_t const* expert_first_token_offset;
};

Expand All @@ -173,8 +319,8 @@ struct CollectiveMma<MainloopMoE16Group<Stages, Schedule>, TileShape_,
auto init_N = get<1>(problem_shape_MNK);
auto init_K = get<2>(problem_shape_MNK);

return Params{args.ptr_A, args.dA, args.ptr_B, args.dB,
args.expert_first_token_offset};
return Params{args.ptr_A, args.dA, args.ptr_B,
args.dB, args.ptr_B_scale, args.expert_first_token_offset};
}

template <class ProblemShape>
Expand Down Expand Up @@ -261,7 +407,7 @@ struct CollectiveMma<MainloopMoE16Group<Stages, Schedule>, TileShape_,

Tensor tCrA = make_tensor<ElementA>(
make_fragment_layout(tiled_copy_a, tCgA(_, _, _, 0).shape()));
Tensor tCrB = make_tensor<ElementB>(
Tensor tCrB = make_tensor<StorageTypeB>(
make_fragment_layout(tiled_copy_b, tCgB(_, _, _, 0).shape()));

// Retile registers for copies
Expand Down Expand Up @@ -344,9 +490,27 @@ struct CollectiveMma<MainloopMoE16Group<Stages, Schedule>, TileShape_,
prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k));
}

cute::gemm(tiled_mma, tCrA, tCrB, accum);
if constexpr (is_B_fp8_type) {
constexpr int numel = decltype(size(tCrB))::value;
cutlass::NumericArrayConverter<ElementA, ElementB, numel> convert_op;
auto frag = convert_op(
*reinterpret_cast<const cutlass::Array<ElementB, numel>*>(
tCrB.data()));
Tensor tCrB_xx16 =
make_tensor(make_rmem_ptr<ElementA>(&frag), tCrB.layout());
cute::gemm(tiled_mma, tCrA, tCrB_xx16, accum);
} else {
cute::gemm(tiled_mma, tCrA, tCrB, accum);
}
barrier_wait(barrier_scope);
}
if constexpr (is_B_fp8_type) {
ElementAccumulator B_scale = ElementAccumulator(get<2>(load_tensors)[0]);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum); ++i) {
accum(i) *= B_scale;
}
}
}

template <typename ProblemShape_MNKL>
Expand Down Expand Up @@ -382,7 +546,20 @@ struct CollectiveMma<MainloopMoE16Group<Stages, Schedule>, TileShape_,
make_shape(N, K, (int32_t)1),
mainloop_params.dB[next_group]);

return cute::make_tuple(mA, mB);
if constexpr (is_B_fp8_type) {
ElementScaleB const* ptr_B_scale_curr_batch =
reinterpret_cast<ElementScaleB const*>(mainloop_params.ptr_B_scale) +
real_group;
auto ShapeScaleB = make_shape(1, 1, (int32_t)1);
Tensor mB_scale =
make_tensor(make_gmem_ptr(ptr_B_scale_curr_batch), ShapeScaleB);
return cute::make_tuple(mA, mB, mB_scale);
} else {
Tensor mB_scale_empty =
make_tensor(make_gmem_ptr(static_cast<ElementScaleB const*>(nullptr)),
make_shape(0, 0, (int32_t)0));
return cute::make_tuple(mA, mB, mB_scale_empty);
}
}
};

Expand Down
50 changes: 50 additions & 0 deletions csrc/xpu/cutlass_kernels/collective/gemm/moe_dtype_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class moe_policy_base {
using ElementB = float;
using ElementOutput = float;
using ElementScale = float;
using GmemTiledCopyA = cute::XE_2D_TF32x32x16_LD_N;
using GmemTiledCopyB = cute::XE_2D_U32x32x16_LD_N;
using MMAOperation = cute::XE_8x16x8_F32TF32TF32F32_TT;
};

Expand All @@ -23,6 +25,8 @@ class moe_bf16_policy : public moe_policy_base {
using ElementB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementScale = cutlass::bfloat16_t;
using GmemTiledCopyA = cute::XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = cute::XE_2D_U16x32x32_LD_V;
using MMAOperation = cute::XE_8x16x16_F32BF16BF16F32_TT;
};

Expand All @@ -32,8 +36,54 @@ class moe_fp16_policy : public moe_policy_base {
using ElementB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementScale = cutlass::half_t;
using GmemTiledCopyA = cute::XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = cute::XE_2D_U16x32x32_LD_V;
using MMAOperation = cute::XE_8x16x16_F32F16F16F32_TT;
};

// Dtype policy: FP8-E4M3 B operand with fp16 A / output / scale.
class moe_e4m3fp16_policy : public moe_policy_base {
 public:
  // Operand and result element types.
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::float_e4m3_t;  // 8-bit storage for B
  using ElementOutput = cutlass::half_t;
  using ElementScale = cutlass::half_t;
  // Global-memory copy atoms: 16-bit loads for A, 8-bit loads for fp8 B.
  using GmemTiledCopyA = cute::XE_2D_U16x32x32_LD_N;
  using GmemTiledCopyB = cute::XE_2D_U8x32x32_LD_V;
  // fp16 MMA atom — B is widened to fp16 in the mainloop before the MMA.
  using MMAOperation = cute::XE_8x16x16_F32F16F16F32_TT;
};

// Dtype policy: FP8-E5M2 B operand with fp16 A / output / scale.
class moe_e5m2fp16_policy : public moe_policy_base {
 public:
  // Operand and result element types.
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::float_e5m2_t;  // 8-bit storage for B
  using ElementOutput = cutlass::half_t;
  using ElementScale = cutlass::half_t;
  // Global-memory copy atoms: 16-bit loads for A, 8-bit loads for fp8 B.
  using GmemTiledCopyA = cute::XE_2D_U16x32x32_LD_N;
  using GmemTiledCopyB = cute::XE_2D_U8x32x32_LD_V;
  // fp16 MMA atom — B is widened to fp16 in the mainloop before the MMA.
  using MMAOperation = cute::XE_8x16x16_F32F16F16F32_TT;
};

// Dtype policy: FP8-E4M3 B operand with bf16 A / output / scale.
class moe_e4m3bf16_policy : public moe_policy_base {
 public:
  // Operand and result element types.
  using ElementA = cutlass::bfloat16_t;
  using ElementB = cutlass::float_e4m3_t;  // 8-bit storage for B
  using ElementOutput = cutlass::bfloat16_t;
  using ElementScale = cutlass::bfloat16_t;
  // Global-memory copy atoms: 16-bit loads for A, 8-bit loads for fp8 B.
  using GmemTiledCopyA = cute::XE_2D_U16x32x32_LD_N;
  using GmemTiledCopyB = cute::XE_2D_U8x32x32_LD_V;
  // bf16 MMA atom — B is widened to bf16 in the mainloop before the MMA.
  using MMAOperation = cute::XE_8x16x16_F32BF16BF16F32_TT;
};

// Dtype policy: FP8-E5M2 B operand with bf16 A / output / scale.
class moe_e5m2bf16_policy : public moe_policy_base {
 public:
  // Operand and result element types.
  using ElementA = cutlass::bfloat16_t;
  using ElementB = cutlass::float_e5m2_t;  // 8-bit storage for B
  using ElementOutput = cutlass::bfloat16_t;
  using ElementScale = cutlass::bfloat16_t;
  // Global-memory copy atoms: 16-bit loads for A, 8-bit loads for fp8 B.
  using GmemTiledCopyA = cute::XE_2D_U16x32x32_LD_N;
  using GmemTiledCopyB = cute::XE_2D_U8x32x32_LD_V;
  // bf16 MMA atom — B is widened to bf16 in the mainloop before the MMA.
  using MMAOperation = cute::XE_8x16x16_F32BF16BF16F32_TT;
};

} // namespace grouped_gemm
} // namespace gpu::cutlass_kernel
Loading