[WIP] Add norm fp4quant fusion #2220
base: main
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

A new fused RMSNorm + FP4 quantization API, `rmsnorm_fp4quant`, is introduced.

Changes

- `flashinfer/norm.py`: new `rmsnorm_fp4quant` API backed by a cuDNN graph, plus a `CudnnRMSNormFP4QuantRunner` for autotuned execution-plan selection
- `flashinfer/__init__.py`: re-export `rmsnorm_fp4quant` at the top level
- `tests/norm/test_rmsnorm_fp4_quant.py`: correctness tests for 2D and 3D inputs
- `scripts/task_jit_run_tests_part5.sh`: run the new test file in CI
Sequence diagram

```mermaid
sequenceDiagram
    participant User
    participant API as rmsnorm_fp4quant()
    participant Validate as Input Validation
    participant AutoTuner
    participant Runner as TunableRunner
    participant cuDNN as cuDNN Graph Engine
    participant Result as FP4 Output

    User->>API: call(input, weight, y_fp4, block_scale, eps, block_size)
    API->>Validate: check dtypes, shapes, divisibility, flatten if 3D
    Validate-->>API: ok
    API->>AutoTuner: build OptimizationProfile / sample inputs
    AutoTuner->>Runner: instantiate runner with cuDNN graph & plans
    Runner->>cuDNN: enumerate tactics / run plans for tuning
    cuDNN-->>Runner: timing/results
    AutoTuner-->>API: select best tactic
    API->>Runner: execute chosen tactic with real inputs
    Runner->>cuDNN: execute graph (RMSNorm -> FP4 block quant)
    cuDNN-->>Result: FP4 packed tensor + block scales
    Result-->>User: return quantized output
```
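To make the flow above concrete, here is a minimal usage sketch of the new API. The argument order follows the diagram and the validation code quoted later in this review; the epsilon value and tensor sizes are illustrative assumptions:

```python
import torch
from flashinfer.norm import rmsnorm_fp4quant

batch_size, hidden_size, block_size = 4, 1024, 16

# RMSNorm inputs: fp16 or bf16, input and weight sharing one dtype.
x = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
w = torch.randn(hidden_size, dtype=torch.float16, device="cuda")

# Pre-allocated outputs: two FP4 values packed per uint8 byte, and one
# float8_e4m3fn scale per block of `block_size` elements.
y_fp4 = torch.empty(batch_size, hidden_size // 2, dtype=torch.uint8, device="cuda")
block_scale = torch.empty(
    batch_size, hidden_size // block_size, dtype=torch.float8_e4m3fn, device="cuda"
)

# The fused cuDNN graph applies RMSNorm(x, w), then block-quantizes the
# result to FP4 in one pass; eps=1e-5 is an assumed epsilon.
rmsnorm_fp4quant(x, w, y_fp4, block_scale, eps=1e-5, block_size=block_size)
```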
Estimated code review effort: 4 (Complex), ~50 minutes
Pre-merge checks and finishing touches: 1 check failed (warning), 2 passed.
Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a new fused RMSNorm + FP4 quantization kernel, `rmsnorm_fp4quant`.

Highlights
Code Review
This pull request introduces a fused RMSNorm + FP4 quantization kernel, `rmsnorm_fp4quant`, leveraging cuDNN for enhanced performance. The implementation is well structured, utilizing cuDNN graphs and an autotuner, and the corresponding tests are a valuable addition. My review focuses on a few key areas for improvement:

- There's a correctness issue in `rmsnorm_fp4quant` related to handling 3D input/output tensors, which could lead to runtime errors. The function also lacks validation for the shapes and dtypes of output tensors.
- I've identified some unused code, specifically the `_RMSNormFP4QuantUIDs` enum, which can be removed for better clarity.
- There are minor type hint inconsistencies in the `CudnnRMSNormFP4QuantRunner` class that I've pointed out.
I have provided detailed suggestions to address these points. Overall, this is a solid addition, and these changes will improve its robustness and maintainability.
flashinfer/norm.py (Outdated)
```python
_check_cudnn_rmsnorm_fp4quant_availability()

# Validate input dtype
if input.dtype not in (torch.float16, torch.bfloat16):
    raise ValueError(
        f"Unsupported input dtype: {input.dtype}. "
        f"Only torch.float16 and torch.bfloat16 are supported."
    )

if input.dtype != weight.dtype:
    raise ValueError(
        f"Input and weight must have the same dtype. "
        f"Got input: {input.dtype}, weight: {weight.dtype}"
    )

# Handle input shape
if input.ndim == 2:
    batch_size, hidden_size = input.shape
elif input.ndim == 3:
    batch_size, seq_len, hidden_size = input.shape
    # Flatten to 2D for processing
    input = input.view(batch_size * seq_len, hidden_size)
    batch_size = batch_size * seq_len
else:
    raise ValueError(
        f"Input must be 2D or 3D tensor, got {input.ndim}D tensor with shape {input.shape}"
    )

# Validate dimensions
if hidden_size % block_size != 0:
    raise ValueError(
        f"hidden_size ({hidden_size}) must be divisible by block_size ({block_size})"
    )

if hidden_size % 2 != 0:
    raise ValueError(
        f"hidden_size ({hidden_size}) must be divisible by 2 for FP4 packing"
    )

# Flatten input to 2D for the runner
input_2d = input.view(batch_size, hidden_size)

# Use AutoTuner for execution plan selection
runners = [_get_cudnn_rmsnorm_fp4quant_runner()]
tuner = AutoTuner.get()

# Package inputs for the runner
# Note: eps and block_size are passed as Python values (not tensors)
inputs = [
    input_2d,
    weight,
    y_fp4,
    block_scale,
    eps,
    block_size,
]

runner, tactic = tuner.choose_one(
    "rmsnorm_fp4quant",
    runners,
    TuningConfig(),
    inputs,
)

runner(inputs=inputs, tactic=tactic)
```
The current implementation of rmsnorm_fp4quant has a few issues:

- 3D tensor handling: it doesn't correctly handle 3D output tensors (`y_fp4`, `block_scale`). The docstring states 3D tensors are supported, but they are passed to the runner, which expects 2D-compatible views, leading to a runtime error.
- Missing validation: there is no validation for the shapes and dtypes of the output tensors, which could lead to cryptic errors if incorrect tensors are provided.
- Clarity: the logic for handling 2D/3D inputs and flattening can be simplified and made clearer.
I've provided a refactored version of the function body that addresses these points by:
- Adding comprehensive validation for input and output tensor dtypes and shapes.
- Correctly creating 2D views for 3D input and output tensors before passing them to the runner.
- Restructuring the logic for better readability.
```python
_check_cudnn_rmsnorm_fp4quant_availability()

# Validate input dtypes
if input.dtype not in (torch.float16, torch.bfloat16):
    raise ValueError(
        f"Unsupported input dtype: {input.dtype}. "
        f"Only torch.float16 and torch.bfloat16 are supported."
    )
if input.dtype != weight.dtype:
    raise ValueError(
        f"Input and weight must have the same dtype. "
        f"Got input: {input.dtype}, weight: {weight.dtype}"
    )

# Validate output dtypes
if y_fp4.dtype != torch.uint8:
    raise ValueError(f"y_fp4 must have dtype torch.uint8, but got {y_fp4.dtype}")
if block_scale.dtype != torch.float8_e4m3fn:
    raise ValueError(
        f"block_scale must have dtype torch.float8_e4m3fn, but got {block_scale.dtype}"
    )

# Validate dimensions
hidden_size = input.shape[-1]
if hidden_size % block_size != 0:
    raise ValueError(
        f"hidden_size ({hidden_size}) must be divisible by block_size ({block_size})"
    )
if hidden_size % 2 != 0:
    raise ValueError(
        f"hidden_size ({hidden_size}) must be divisible by 2 for FP4 packing"
    )

# Handle input/output shapes and prepare 2D views for the runner
if input.ndim == 2:
    batch_size, _ = input.shape
    input_2d = input
    y_fp4_2d = y_fp4
    block_scale_2d = block_scale
    # Validate output shapes for 2D input
    expected_y_fp4_shape = (batch_size, hidden_size // 2)
    if y_fp4.shape != expected_y_fp4_shape:
        raise ValueError(
            f"For 2D input, y_fp4 shape must be {expected_y_fp4_shape}, but got {y_fp4.shape}"
        )
    expected_block_scale_shape = (batch_size, hidden_size // block_size)
    if block_scale.shape != expected_block_scale_shape:
        raise ValueError(
            f"For 2D input, block_scale shape must be {expected_block_scale_shape}, but got {block_scale.shape}"
        )
elif input.ndim == 3:
    batch_size, seq_len, _ = input.shape
    flat_batch_size = batch_size * seq_len
    # Validate output shapes for 3D input
    expected_y_fp4_shape = (batch_size, seq_len, hidden_size // 2)
    if y_fp4.shape != expected_y_fp4_shape:
        raise ValueError(
            f"For 3D input, y_fp4 shape must be {expected_y_fp4_shape}, but got {y_fp4.shape}"
        )
    expected_block_scale_shape = (batch_size, seq_len, hidden_size // block_size)
    if block_scale.shape != expected_block_scale_shape:
        raise ValueError(
            f"For 3D input, block_scale shape must be {expected_block_scale_shape}, but got {block_scale.shape}"
        )
    # Flatten to 2D for processing
    input_2d = input.view(flat_batch_size, hidden_size)
    y_fp4_2d = y_fp4.view(flat_batch_size, hidden_size // 2)
    block_scale_2d = block_scale.view(flat_batch_size, hidden_size // block_size)
else:
    raise ValueError(
        f"Input must be 2D or 3D tensor, got {input.ndim}D tensor with shape {input.shape}"
    )

# Use AutoTuner for execution plan selection
runners = [_get_cudnn_rmsnorm_fp4quant_runner()]
tuner = AutoTuner.get()

# Package inputs for the runner
# Note: eps and block_size are passed as Python values (not tensors)
inputs = [
    input_2d,
    weight,
    y_fp4_2d,
    block_scale_2d,
    eps,
    block_size,
]

runner, tactic = tuner.choose_one(
    "rmsnorm_fp4quant",
    runners,
    TuningConfig(),
    inputs,
)

runner(inputs=inputs, tactic=tactic)
```

flashinfer/norm.py (Outdated)

```python
class _RMSNormFP4QuantUIDs(Enum):
    """UIDs for cuDNN RMSNorm + FP4 Quantization graph tensors."""

    X_UID = 0
    WEIGHT_UID = 1
    EPSILON_UID = 2
    Y_FP4_UID = 3
    BLOCK_SCALE_UID = 4
```
```python
class CudnnRMSNormFP4QuantRunner(TunableRunner):
    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
```
The type hint for the `inputs` parameter is `List[torch.Tensor]`, but the list also contains non-tensor values like `eps` (float) and `block_size` (int). This mismatch can be misleading. To accurately reflect the contents of the list, please update the type hint to `List[Any]`:

```diff
-        inputs: List[torch.Tensor],
+        inputs: List[Any],
```
```python
    def forward(
        self,
        inputs: List[torch.Tensor],
```
The type hint for the `inputs` parameter is `List[torch.Tensor]`, but the list also contains non-tensor values like `eps` (float) and `block_size` (int). To accurately reflect the contents of the list, please update the type hint to `List[Any]`:

```diff
-        inputs: List[torch.Tensor],
+        inputs: List[Any],
```
Actionable comments posted: 3
Nitpick comments (7)
flashinfer/norm.py (4)
324-378: Clean up the unused enum and the unused `inv_variance` from the cuDNN RMSNorm call.

- `_RMSNormFP4QuantUIDs` is defined but not referenced anywhere; unless you plan to use UIDs in a follow-up (e.g., for debugging or graph introspection), it's effectively dead code and can be removed along with the `Enum` import.
- In `_create_rmsnorm_fp4quant_execution_plans`, `inv_variance` is never used, which Ruff flags:

```python
y_rmsnorm, inv_variance = graph.rmsnorm(...)
```

If you don't need it, capture it with a dummy variable to silence the warning:

```diff
-    y_rmsnorm, inv_variance = graph.rmsnorm(
+    y_rmsnorm, _inv_variance = graph.rmsnorm(
```

These are small cleanups that keep the new section tidy and lint-clean.
578-615: Address unused arguments in `CudnnRMSNormFP4QuantRunner` to keep linters green.

Ruff reports several unused parameters:

- `profile` in `get_valid_tactics`
- `weight`, `y_fp4`, `block_scale`, `eps` in the `inputs` unpack
- `do_preparation` and `kwargs` in `forward`

You can satisfy the linter without changing behavior by explicitly marking them as used:

```diff
 def get_valid_tactics(
     self,
     inputs: List[torch.Tensor],
     profile: OptimizationProfile,
 ) -> List[int]:
     """Return list of valid execution plan indices for autotuning."""
+    _ = profile  # currently unused
     (
         input_tensor,
-        weight,
-        y_fp4,
-        block_scale,
-        eps,
+        _weight,
+        _y_fp4,
+        _block_scale,
+        _eps,
         block_size,
     ) = inputs
```

and:

```diff
 def forward(
     self,
     inputs: List[torch.Tensor],
     tactic: int = -1,
     do_preparation: bool = False,
     **kwargs,
 ):
     """Execute the RMSNorm + FP4 quantization with specified tactic."""
+    _ = do_preparation, kwargs  # currently unused
```

This aligns with Ruff's ARG002/RUF059 hints without altering the runtime behavior.
661-801: Validate output tensors and reconcile AutoTuner inputs with the expected tensor-only list.

Two areas here are worth tightening up:

1. Output tensor validation

`rmsnorm_fp4quant` enforces dtype/shape on `input` and `weight`, but assumes `y_fp4` and `block_scale` are correctly shaped/dtyped and on the same device. Mis-specified outputs could lead to cryptic cuDNN errors or memory misuse. Consider adding explicit checks along the lines of:

```python
if y_fp4.dtype is not torch.uint8:
    raise ValueError("y_fp4 must be uint8 (packed FP4).")
if block_scale.dtype is not torch.float8_e4m3fn:
    raise ValueError("block_scale must be float8_e4m3fn.")

expected_y_shape = (batch_size, hidden_size // 2)
expected_scale_shape = (batch_size, hidden_size // block_size)
# For 3D, adjust expected_* to include seq_len similarly.
```

and verify that `input.device`, `weight.device`, `y_fp4.device`, and `block_scale.device` all match. This would fail fast with clear messages instead of relying on backend errors.

2. Passing non-tensors to `AutoTuner.choose_one`

`AutoTuner.choose_one` is documented to take `inputs: List[torch.Tensor]` and internally computes `input_shapes = tuple(self._get_input_sizes(inputs))`. In this implementation, `inputs` includes Python scalars:

```python
inputs = [
    input_2d,
    weight,
    y_fp4,
    block_scale,
    eps,         # float
    block_size,  # int
]
```

If `_get_input_sizes` assumes every element is a `torch.Tensor` (e.g., `tuple(t.size() for t in inputs)`), including `eps` and `block_size` could raise an `AttributeError` or skew the cache key. It may be safer to:

- restrict `inputs` to tensors only (e.g., `[input_2d, weight, y_fp4, block_scale]`) and pass `eps` and `block_size` through `**kwargs`, or
- wrap them as 0-D tensors if they must participate in shape-based caching.

Please double-check the current implementation of `_get_input_sizes` in `flashinfer.autotuner.AutoTuner` and adjust accordingly; a sketch of the tensor-only option follows below.
rmsnorm_fp4quantsays it βrequires β¦ a Blackwell GPU (compute capability >= 100)β, but_check_cudnn_rmsnorm_fp4quant_availabilitycurrently only validates cuDNN presence and backend version.To align behavior with the documented requirement (and with the testsβ
requires_blackwell()), consider adding a lightweight capability check:cc_major, cc_minor = torch.cuda.get_device_capability(device) if cc_major * 10 + cc_minor < 100: raise RuntimeError( "rmsnorm_fp4quant requires a Blackwell GPU (compute capability >= 100); " f"found compute capability {cc_major}.{cc_minor}." )This would give users a clear error instead of letting them hit backend failures on unsupported GPUs.
tests/norm/test_rmsnorm_fp4_quant.py (3)
65-114: Optional: drop unused reference quantization outputs or mark them as intentionally unused.

In both tests you assign but never use `ref_fp4` and `ref_scale`:

```python
ref_fp4, ref_scale = ref_block_scale_quantize(ref_rmsnorm, block_size=block_size)
```

Since the assertions only compare `y_dequant` with `ref_rmsnorm.float()`, you can either:

- remove the assignments (and even the call, if you don't need it for debugging), or
- mark them as intentionally unused:

```diff
- ref_fp4, ref_scale = ref_block_scale_quantize(ref_rmsnorm, block_size=block_size)
+ _ref_fp4, _ref_scale = ref_block_scale_quantize(ref_rmsnorm, block_size=block_size)
```

The same pattern applies in the 3D test. This keeps the reference helper available while avoiding unused-variable warnings; a sketch of what such a reference helper computes follows below.
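For reference, a self-contained sketch of what a block-scale FP4 quantize/dequantize helper in the spirit of `ref_block_scale_quantize` might compute. The e2m1 value grid and the 6.0 maximum follow the common NVFP4 convention, which is an assumption here, and this version returns the dequantized tensor rather than packed FP4 codes:

```python
import torch

# Non-negative magnitudes representable in FP4 (e2m1); sign is handled separately.
E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def ref_block_quant_roundtrip(x: torch.Tensor, block_size: int = 16):
    """Quantize each block of the last dim to FP4 with one scale, then dequantize."""
    rows, hidden = x.shape
    blocks = x.float().reshape(rows, hidden // block_size, block_size)
    grid = E2M1_GRID.to(blocks.device)

    # One scale per block: amax / 6.0 (6.0 is the largest e2m1 magnitude),
    # mirroring the fused kernel's float8_e4m3fn block_scale output.
    scale = (blocks.abs().amax(dim=-1, keepdim=True) / 6.0).clamp(min=1e-12)

    # Snap each scaled magnitude to the nearest e2m1 grid point.
    scaled = blocks / scale
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    q = grid[idx] * scaled.sign()

    # A real helper would quantize `scale` to e4m3 first and pack two FP4
    # codes per uint8; this sketch only reproduces the dequantized values.
    dequant = (q * scale).reshape(rows, hidden)
    return dequant, scale.squeeze(-1).to(torch.float8_e4m3fn)
```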
145-159: Make the Blackwell capability check robust on CPU-only environments.

`requires_blackwell()` currently calls `torch.cuda.get_device_capability()` unconditionally via `get_cc()`. On machines without CUDA, this can raise at import time when evaluating the `blackwell_required` skip marker. To keep the tests importable everywhere, you can guard on CUDA availability:

```diff
 def get_cc():
-    """Get CUDA compute capability."""
-    major, minor = torch.cuda.get_device_capability()
-    return major * 10 + minor
+    """Get CUDA compute capability (returns 0 on non-CUDA systems)."""
+    if not torch.cuda.is_available():
+        return 0
+    major, minor = torch.cuda.get_device_capability()
+    return major * 10 + minor
```

`requires_blackwell()` can remain as `return get_cc() >= 100`.
134-143: Optional: narrow the broad `except Exception` in `requires_cudnn_fp4`.

In `requires_cudnn_fp4()`, a bare `except Exception:` hides all errors from the `cudnn` import and version query, including programming mistakes:

```python
try:
    import cudnn

    return cudnn.backend_version() >= 90700
except Exception:
    return False
```

If you only intend to treat import or runtime environment failures as "feature not available", consider narrowing the exception (e.g., `ImportError`, `OSError`, or a tuple of expected failures). Otherwise, this is acceptable for a test-only capability probe.
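A sketch of the narrowed probe; the exact exception set is an assumption (`ImportError`/`OSError` for missing or broken installs, `AttributeError` for bindings too old to expose `backend_version`):

```python
def requires_cudnn_fp4() -> bool:
    """Probe whether a cudnn Python package new enough for FP4 fusion is present."""
    try:
        import cudnn
    except (ImportError, OSError):
        # Missing package or unloadable native libraries: feature unavailable.
        return False
    try:
        return cudnn.backend_version() >= 90700
    except AttributeError:
        # Very old bindings without backend_version(): treat as unavailable.
        return False
```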
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Files selected for processing (4):

- flashinfer/__init__.py (1 hunk)
- flashinfer/norm.py (2 hunks)
- tests/norm/__init__.py (1 hunk)
- tests/norm/test_rmsnorm_fp4_quant.py (1 hunk)
Additional context used

Code graph analysis (2):

flashinfer/__init__.py (1)
- flashinfer/norm.py (1): rmsnorm_fp4quant (662-802)

flashinfer/norm.py (3)
- flashinfer/autotuner.py (8): AutoTuner (335-791), OptimizationProfile (168-183), TunableRunner (194-247), TuningConfig (101-141), get_valid_tactics (196-214), forward (220-244), get (362-365), choose_one (400-534)
- flashinfer/jit/norm.py (1): gen_norm_module (21-33)
- csrc/norm.cu (2): rmsnorm (22-78), rmsnorm (22-22)
GitHub Actions: pre-commit
tests/norm/__init__.py
[error] 1-1: End-of-file fixer modified the file. Please re-run hooks or commit again.
[error] 1-1: Trailing whitespace detected and removed by trailing-whitespace hook.
tests/norm/test_rmsnorm_fp4_quant.py
[error] 20-20: ruff-check: F401 'flashinfer' imported but unused.
flashinfer/norm.py
[error] 19-19: ruff-check: F401 'typing.Tuple' imported but unused.
Ruff (0.14.8)
tests/norm/test_rmsnorm_fp4_quant.py
62-62: Avoid specifying long messages outside the exception class
(TRY003)
94-94: Avoid specifying long messages outside the exception class
(TRY003)
141-141: Do not catch blind exception: Exception
(BLE001)
202-202: Unpacked variable ref_fp4 is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
202-202: Unpacked variable ref_scale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
257-257: Unpacked variable ref_fp4 is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
257-257: Unpacked variable ref_scale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
flashinfer/norm.py
343-346: Avoid specifying long messages outside the exception class
(TRY003)
355-358: Avoid specifying long messages outside the exception class
(TRY003)
363-366: Avoid specifying long messages outside the exception class
(TRY003)
376-378: Avoid specifying long messages outside the exception class
(TRY003)
446-446: Unpacked variable inv_variance is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
582-582: Unused method argument: profile
(ARG002)
587-587: Unpacked variable weight is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
588-588: Unpacked variable y_fp4 is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
589-589: Unpacked variable block_scale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
590-590: Unpacked variable eps is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
613-613: Unused method argument: do_preparation
(ARG002)
614-614: Unused method argument: kwargs
(ARG002)
742-745: Avoid specifying long messages outside the exception class
(TRY003)
748-751: Avoid specifying long messages outside the exception class
(TRY003)
762-764: Avoid specifying long messages outside the exception class
(TRY003)
768-770: Avoid specifying long messages outside the exception class
(TRY003)
773-775: Avoid specifying long messages outside the exception class
(TRY003)
Context from checks skipped due to timeout of 90000 ms (1). You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000 ms).
- GitHub Check: Deploy Docs
Additional comments (1)
flashinfer/__init__.py (1)
93-99: Exposing `rmsnorm_fp4quant` at the top level looks good.

The additional import cleanly surfaces `rmsnorm_fp4quant` in the top-level `flashinfer` API alongside the other norm functions; no further changes needed here.
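The hunk is presumably a one-line addition to the existing norm re-exports, along these lines (the surrounding import list is a sketch, not the file's exact contents):

```python
# flashinfer/__init__.py (illustrative)
from .norm import (
    fused_add_rmsnorm,
    gemma_fused_add_rmsnorm,
    gemma_rmsnorm,
    rmsnorm,
    rmsnorm_fp4quant,  # new in this PR
)
```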
tests/norm/__init__.py (Outdated)

```python
# Tests for normalization operations
```
Fix trailing whitespace / EOF formatting to satisfy pre-commit.
Pre-commit reports both end-of-file-fixer and trailing-whitespace modifying this file. Ensure the file has no trailing spaces and ends with exactly one newline, e.g.:
```python
# Tests for normalization operations
```

with a single newline at the end of the file.
Prompt for AI Agents
In tests/norm/__init__.py around lines 1 to 2, remove any trailing spaces on the
line and ensure the file ends with exactly one newline character (no extra blank
lines at EOF); edit the file so the only content is the comment line with no
trailing whitespace and save with a single terminating newline.
Actionable comments posted: 0
Duplicate comments (6)
flashinfer/norm.py (5)
19-19: Remove the unused `Tuple` import.

This import is flagged by pre-commit (ruff F401) and has been noted in previous reviews. Please remove it to satisfy the linter.

328-335: Remove the unused `_RMSNormFP4QuantUIDs` enum.

This enum is defined but never used in the implementation. As noted in a previous review, removing it will improve code clarity.
583-587: Update the type hint to `List[Any]` for mixed-type inputs.

As noted in a previous review, the `inputs` parameter contains both tensors and scalar values (`eps`, `block_size`), so `List[torch.Tensor]` is misleading.

613-619: Update the type hint to `List[Any]` for mixed-type inputs.

As noted in a previous review, the `inputs` parameter contains both tensors and scalar values (`eps`, `block_size`), so `List[torch.Tensor]` is misleading.
774-823: Fix 3D tensor handling for the output tensors.

As noted in a comprehensive previous review, the current implementation doesn't properly handle 3D output tensors (`y_fp4` and `block_scale`). When the input is 3D, these outputs are also 3D but are passed to the runner without flattening, causing a shape mismatch at lines 637-638 in the `forward` method.

The previous review provided a detailed fix that includes:
- Flattening output tensors to 2D views when input is 3D
- Adding comprehensive shape and dtype validation for outputs
- Clearer logic flow for handling 2D vs 3D cases
tests/norm/test_rmsnorm_fp4_quant.py (1)
20-20: Remove the unused `flashinfer` import.

This import is flagged by pre-commit (ruff F401) and has been noted in a previous review. Only the specific imports from `flashinfer.norm` are needed.
Nitpick comments (2)
flashinfer/norm.py (1)
450-457: Use `_` for the unused return value.

The `inv_variance` output from `graph.rmsnorm` is not used. Consider using `_` to make this explicit:

```diff
- y_rmsnorm, inv_variance = graph.rmsnorm(
+ y_rmsnorm, _ = graph.rmsnorm(
```

tests/norm/test_rmsnorm_fp4_quant.py (1)
155-158: Remove the commented-out parametrize decorators.

These commented-out lines appear to be leftover debugging code and should be removed before merging:

```diff
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-# @pytest.mark.parametrize("batch_size", [])
-# @pytest.mark.parametrize("seq_len", [])
-# @pytest.mark.parametrize("hidden_size", [])
-# @pytest.mark.parametrize("dtype", [])
 def test_rmsnorm_fp4quant_3d(batch_size, seq_len, hidden_size, dtype):
```
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Files selected for processing (4):

- flashinfer/norm.py (2 hunks)
- scripts/task_jit_run_tests_part5.sh (1 hunk)
- tests/norm/__init__.py (1 hunk)
- tests/norm/test_rmsnorm_fp4_quant.py (1 hunk)
Files skipped from review as they are similar to previous changes (1):

- tests/norm/__init__.py
Additional context used

Code graph analysis (1):

flashinfer/norm.py (5)
- flashinfer/api_logging.py (1): flashinfer_api (464-565)
- flashinfer/autotuner.py (7): AutoTuner (335-791), OptimizationProfile (168-183), TunableRunner (194-247), get_valid_tactics (196-214), forward (220-244), get (362-365), choose_one (400-534)
- flashinfer/utils.py (2): backend_requirement (902-1184), supported_compute_capability (819-899)
- include/flashinfer/trtllm/common.h (1): device (83-90)
- csrc/norm.cu (2): rmsnorm (22-78), rmsnorm (22-22)
Ruff (0.14.8)
flashinfer/norm.py
347-350: Avoid specifying long messages outside the exception class
(TRY003)
359-362: Avoid specifying long messages outside the exception class
(TRY003)
367-370: Avoid specifying long messages outside the exception class
(TRY003)
380-382: Avoid specifying long messages outside the exception class
(TRY003)
450-450: Unpacked variable inv_variance is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
586-586: Unused method argument: profile
(ARG002)
591-591: Unpacked variable weight is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
592-592: Unpacked variable y_fp4 is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
593-593: Unpacked variable block_scale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
594-594: Unpacked variable eps is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
617-617: Unused method argument: do_preparation
(ARG002)
618-618: Unused method argument: kwargs
(ARG002)
667-667: Unused function argument: input
(ARG001)
668-668: Unused function argument: weight
(ARG001)
669-669: Unused function argument: y_fp4
(ARG001)
670-670: Unused function argument: block_scale
(ARG001)
671-671: Unused function argument: eps
(ARG001)
672-672: Unused function argument: block_size
(ARG001)
763-766: Avoid specifying long messages outside the exception class
(TRY003)
769-772: Avoid specifying long messages outside the exception class
(TRY003)
783-785: Avoid specifying long messages outside the exception class
(TRY003)
789-791: Avoid specifying long messages outside the exception class
(TRY003)
794-796: Avoid specifying long messages outside the exception class
(TRY003)
tests/norm/test_rmsnorm_fp4_quant.py
63-63: Avoid specifying long messages outside the exception class
(TRY003)
74-74: Do not catch blind exception: Exception
(BLE001)
Context from checks skipped due to timeout of 90000 ms (1). You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000 ms).
- GitHub Check: Deploy Docs
Additional comments (1)
scripts/task_jit_run_tests_part5.sh (1)
15-15: LGTM!

The addition of the new test file to the test suite is correct and follows the existing pattern of running each test file separately.
Description

Related Issues

Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
Pre-commit Checks

- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

Tests

- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit

- New Features: a fused RMSNorm + FP4 quantization API (rmsnorm_fp4quant).
- Tests: coverage for the fused op with 2D and 3D inputs.
- Chores: the new test file added to the test scripts.