Add data type check for deepseek fp4 moe #2165
Walkthrough
Added guarded runtime checks in multiple MoE launcher code paths.
Estimated code review effort: 3 (Moderate), ~20 minutes
Summary of Changes
This pull request enhances the robustness of the FlashInfer library by adding a critical data type check.
Code Review
This pull request adds a data type check for routing_logits when using the DeepSeekV3 routing method with the Fp8BlockScaleLauncher. This improves the robustness of the API. However, the added code contains a syntax error that will prevent compilation. I've provided a comment with a suggested fix for this issue. Note that while the PR title mentions fp4, the change is applied to the fp8 launcher.
if (static_cast<RoutingMethodType>(routing_method_type) ==
    RoutingMethodType::DeepSeekV3)
{
  TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float for DeepSeekV3 Routing method.";
The routing_logits member is of type Optional<TensorView>, so to access its dtype, you need to use .value().dtype(). The current code routing_logits.dtype() will not compile.
For better code organization, you might also consider moving this check to the check_routing() method (around line 754), as it's a check related to routing parameters and other routing_method_type checks are already there.
TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) << "routing_logits must be float for DeepSeekV3 Routing method.";
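The point about unwrapping the optional can be illustrated with a minimal sketch. It uses `std::optional` as a stand-in for `Optional<TensorView>`; the `Tensor` and `DType` types and the `logits_dtype_ok` helper are hypothetical names introduced only for this illustration and are not part of FlashInfer or TVM FFI.

```cpp
#include <cassert>   // used by the usage assertions described below
#include <optional>

// Hypothetical stand-ins for illustration only; not FlashInfer or TVM FFI types.
enum class DType { Float32, BFloat16 };

struct Tensor {
  DType dtype_;
  DType dtype() const { return dtype_; }
};

// Returns true when routing_logits is absent or holds a float32 tensor.
bool logits_dtype_ok(const std::optional<Tensor>& routing_logits) {
  // Calling routing_logits.dtype() directly would not compile, because the
  // optional wrapper itself has no dtype() member. Unwrap it first.
  if (!routing_logits.has_value()) {
    return true;  // nothing to validate when no logits were passed
  }
  return routing_logits.value().dtype() == DType::Float32;
}
```

Under these assumptions, `logits_dtype_ok(std::nullopt)` and `logits_dtype_ok(Tensor{DType::Float32})` both hold, while `logits_dtype_ok(Tensor{DType::BFloat16})` does not.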
Actionable comments posted: 0
♻️ Duplicate comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
838-843: Remove this redundant check; it is already validated in check_routing().
Per a past review comment, this check was intended to be moved to check_routing(), not duplicated. Since the dtype validation now exists in check_routing_common() (lines 215-220), which is called before check_moe(), this block should be removed.
 void check_moe() const override {
   FusedMoeLauncher::check_moe_common();
 
-  if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
-    auto const routing_logits_dtype =
-        routing_logits.has_value() ? routing_logits.value().dtype() : dl_float32;
-    TVM_FFI_ICHECK_EQ(routing_logits_dtype, dl_float32)
-        << "routing_logits must be float for DeepSeekV3 Routing method.";
-  }
-
   TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8.";
🧹 Nitpick comments (3)
csrc/trtllm_fused_moe_kernel_launcher.cu (3)
440-445: Duplicate check, already handled by check_routing_common().
The call to FusedMoeLauncher::check_routing_common() on line 438 already performs this exact DeepSeekV3 dtype validation. This block can be removed to avoid redundancy.
 void check_routing() const override {
   FusedMoeLauncher::check_routing_common();
-  if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
-    auto const routing_logits_dtype =
-        routing_logits.has_value() ? routing_logits.value().dtype() : dl_float32;
-    TVM_FFI_ICHECK_EQ(routing_logits_dtype, dl_float32)
-        << "routing_logits must be float for DeepSeekV3 Routing method.";
-  }
   // TODO n_group, topk_group validation?
 }
800-806: Duplicate check; check_routing_common() already validates this.
Line 768 invokes check_routing_common(), which performs this same DeepSeekV3 dtype validation. Consider removing this block to reduce redundancy.
   TVM_FFI_ICHECK_LE(args->local_num_experts + args->local_expert_offset, args->num_experts)
       << "num_experts must be greater or equal to local_num_experts + local_expert_offset";
-
-  if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
-    auto const routing_logits_dtype =
-        routing_logits.has_value() ? routing_logits.value().dtype() : dl_float32;
-    TVM_FFI_ICHECK_EQ(routing_logits_dtype, dl_float32)
-        << "routing_logits must be float for DeepSeekV3 Routing method.";
-  }
 }
1038-1043: Duplicate check; check_routing_common() already handles this.
Since FusedMoeLauncher::check_routing_common() is called on line 1037 and already includes the DeepSeekV3 dtype validation, this block is redundant.
 void check_routing() const override {
   // First call base class common routing checks
   FusedMoeLauncher::check_routing_common();
-  if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
-    auto const routing_logits_dtype =
-        routing_logits.has_value() ? routing_logits.value().dtype() : dl_float32;
-    TVM_FFI_ICHECK_EQ(routing_logits_dtype, dl_float32)
-        << "routing_logits must be float for DeepSeekV3 Routing method.";
-  }
 }
📒 Files selected for processing (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (5 hunks)
🔇 Additional comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
215-220: Correct placement for the dtype check.
Adding the DeepSeekV3 dtype validation in check_routing_common() ensures all derived launchers inherit this check automatically. The conditional fallback to dl_float32 when routing_logits is absent is also appropriate.
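The idea of hoisting the check into a shared base-class method can be sketched as follows. This is a simplified illustration under stated assumptions, not the actual launcher code: `LauncherBase`, `Fp8Launcher`, `Fp4Launcher`, `DType`, `RoutingMethod`, and `rejects` are hypothetical names, and a plain exception stands in for `TVM_FFI_ICHECK_EQ`.

```cpp
#include <cassert>   // used by the usage assertions described below
#include <optional>
#include <stdexcept>

// Hypothetical stand-ins for illustration only.
enum class DType { Float32, BFloat16 };
enum class RoutingMethod { Default, DeepSeekV3 };

class LauncherBase {
 public:
  LauncherBase(RoutingMethod method, std::optional<DType> logits_dtype)
      : method_(method), logits_dtype_(logits_dtype) {}
  virtual ~LauncherBase() = default;
  virtual void check_routing() const = 0;

 protected:
  // Shared validation: written once, reused by every derived launcher.
  void check_routing_common() const {
    if (method_ == RoutingMethod::DeepSeekV3) {
      // Fall back to Float32 when no logits were provided, so the check passes.
      DType const dtype = logits_dtype_.value_or(DType::Float32);
      if (dtype != DType::Float32) {
        throw std::invalid_argument(
            "routing_logits must be float for DeepSeekV3 Routing method.");
      }
    }
  }

 private:
  RoutingMethod method_;
  std::optional<DType> logits_dtype_;
};

// Derived launchers only call the common check; they do not repeat it.
class Fp8Launcher : public LauncherBase {
 public:
  using LauncherBase::LauncherBase;
  void check_routing() const override { check_routing_common(); }
};

class Fp4Launcher : public LauncherBase {
 public:
  using LauncherBase::LauncherBase;
  void check_routing() const override { check_routing_common(); }
};

// Helper for the sketch: true when the launcher's routing check fails.
bool rejects(const LauncherBase& launcher) {
  try {
    launcher.check_routing();
  } catch (const std::invalid_argument&) {
    return true;
  }
  return false;
}
```

In this sketch, `rejects(Fp8Launcher(RoutingMethod::DeepSeekV3, DType::BFloat16))` is true, while a `Float32` dtype, absent logits, or a non-DeepSeekV3 routing method all pass. Keeping the check in one place means a future launcher gets it by calling the common method, which is the redundancy argument made in the nitpick comments above.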
/bot run
yzh119 left a comment
LGTM, cc @jiahanc for viz.
📌 Description
Dtype checks are important for ensuring that API callers use FlashInfer correctly. Warnings or errors should be raised when parameters do not match a kernel's requirements.
Otherwise, frameworks that call FlashInfer spend a lot of effort on debugging (e.g. PR13761, PR14350, PR14135).
FlashInfer is a great product, and we really hope it becomes even better. Please pay attention to this kind of check. Thanks a lot. cc @yzh119
🔍 Related Issues