Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ class FusedMoeLauncher {
args->local_expert_offset + args->local_num_experts <= args->num_experts)
<< "expert offset and count must be within valid range";

if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
auto const routing_logits_dtype =
routing_logits.has_value() ? routing_logits.value().dtype() : dl_float32;
TVM_FFI_ICHECK_EQ(routing_logits_dtype, dl_float32)
<< "routing_logits must be float for DeepSeekV3 Routing method.";
}

check_routing_logits_shape();

if (routing_bias.has_value()) {
Expand Down Expand Up @@ -430,6 +437,12 @@ class Bf16MoeLauncher : public FusedMoeLauncher {
void check_routing() const override {
FusedMoeLauncher::check_routing_common();

if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
auto const routing_logits_dtype =
routing_logits.has_value() ? routing_logits.value().dtype() : dl_float32;
TVM_FFI_ICHECK_EQ(routing_logits_dtype, dl_float32)
<< "routing_logits must be float for DeepSeekV3 Routing method.";
}
// TODO n_group, topk_group validation?
}

Expand Down Expand Up @@ -784,6 +797,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
TVM_FFI_ICHECK_GT(args->num_experts, args->top_k) << "num_experts must be greater than top_k";
TVM_FFI_ICHECK_LE(args->local_num_experts + args->local_expert_offset, args->num_experts)
<< "num_experts must be greater or equal to local_num_experts + local_expert_offset";

if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
auto const routing_logits_dtype =
routing_logits.has_value() ? routing_logits.value().dtype() : dl_float32;
TVM_FFI_ICHECK_EQ(routing_logits_dtype, dl_float32)
<< "routing_logits must be float for DeepSeekV3 Routing method.";
}
}

void prepare_routing() override {
Expand Down Expand Up @@ -815,6 +835,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
void check_moe() const override {
FusedMoeLauncher::check_moe_common();

if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
auto const routing_logits_dtype =
routing_logits.has_value() ? routing_logits.value().dtype() : dl_float32;
TVM_FFI_ICHECK_EQ(routing_logits_dtype, dl_float32)
<< "routing_logits must be float for DeepSeekV3 Routing method.";
}

TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8.";
TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
<< "hidden_states_scale must be float.";
Expand Down Expand Up @@ -1008,6 +1035,12 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
// Routing validation for this launcher: delegates to the shared base-class
// routing checks and then, for the DeepSeekV3 routing method only, requires
// routing_logits (when present) to have dtype dl_float32.
void check_routing() const override {
  // Shared routing validation from the base class runs first.
  FusedMoeLauncher::check_routing_common();

  bool const uses_deepseek_v3 =
      static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3;
  if (!uses_deepseek_v3) {
    return;
  }

  // A missing routing_logits tensor is treated as dl_float32, i.e. it passes.
  auto const routing_logits_dtype =
      routing_logits.has_value() ? routing_logits.value().dtype() : dl_float32;
  TVM_FFI_ICHECK_EQ(routing_logits_dtype, dl_float32)
      << "routing_logits must be float for DeepSeekV3 Routing method.";
}

void prepare_routing() override {
Expand Down