Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,12 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8.";
TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8.";

if (static_cast<RoutingMethodType>(routing_method_type) ==
RoutingMethodType::DeepSeekV3)
{
TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float for DeepSeekV3 Routing method.";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The routing_logits member is of type Optional<TensorView>, so to access its dtype, you need to use .value().dtype(). The current code routing_logits.dtype() will not compile.

For better code organization, you might also consider moving this check to the check_routing() method (around line 754), as it's a check related to routing parameters and other routing_method_type checks are already there.

      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) << "routing_logits must be float for DeepSeekV3 Routing method.";

}

TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
<< "gemm1_weights_scale must be float.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
Expand Down
Loading