
Conversation

@nv-yunzheq
Contributor

@nv-yunzheq nv-yunzheq commented Dec 15, 2025

πŸ“Œ Description

πŸ” Related Issues

πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

βœ… Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

πŸ§ͺ Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a fused RMS normalization + FP4 quantization operation with support for 2D and 3D inputs and automatic performance tuning, exposed at package level.
  • Tests

    • Added comprehensive tests validating the new fused quantization path for 2D and 3D scenarios (including dequantization checks).
  • Chores

    • Updated test runner script to include the new test module and added test package initializer.


@coderabbitai
Contributor

coderabbitai bot commented Dec 15, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

A new fused RMSNorm + FP4 quantization API, rmsnorm_fp4quant, was added with cuDNN graph construction and execution paths, autotuning support, and accompanying tests for 2D/3D inputs. The package-level export exposes the new function.
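For orientation, a hedged usage sketch of the new API (shapes, dtypes, and the block size are inferred from the review discussion below and are illustrative, not verified against the PR):

    import torch
    import flashinfer

    # Assumed sizes for illustration; block_size=16 follows the FP4 block-scale
    # convention discussed in the review comments.
    batch_size, hidden_size, block_size = 8, 4096, 16
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(hidden_size, dtype=torch.bfloat16, device="cuda")

    # Caller-allocated outputs: packed FP4 (two values per uint8) and one
    # float8_e4m3fn scale per block, per the shapes stated in the review.
    y_fp4 = torch.empty(batch_size, hidden_size // 2, dtype=torch.uint8, device="cuda")
    block_scale = torch.empty(
        batch_size, hidden_size // block_size, dtype=torch.float8_e4m3fn, device="cuda"
    )

    flashinfer.rmsnorm_fp4quant(x, w, y_fp4, block_scale, eps=1e-6, block_size=block_size)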

Changes

  • Package export (flashinfer/__init__.py): Exposes rmsnorm_fp4quant at the package level.
  • Core implementation (flashinfer/norm.py): Adds a cuDNN-backed RMSNorm + FP4 quantization pipeline: cuDNN availability/handle utilities, graph building (_build_rmsnorm_fp4quant_graph / _create_rmsnorm_fp4quant_execution_plans), execution (_execute_rmsnorm_fp4quant_graph), a TunableRunner (_get_cudnn_rmsnorm_fp4quant_runner) for tactic selection, and the public rmsnorm_fp4quant API with input validation, 2D/3D handling, and AutoTuner integration.
  • Test package initializer (tests/norm/__init__.py): Adds a package initializer for the norm tests.
  • Tests (tests/norm/test_rmsnorm_fp4_quant.py): Adds tests exercising the fused RMSNorm + FP4 path: reference implementations (RMSNorm, FP4 block-scale quantize/dequantize), environment capability checks, and parametrized 2D/3D test cases validating shapes, dtypes, and numeric behavior.

Sequence Diagram

sequenceDiagram
    participant User
    participant API as rmsnorm_fp4quant()
    participant Validate as Input Validation
    participant AutoTuner
    participant Runner as TunableRunner
    participant cuDNN as cuDNN Graph Engine
    participant Result as FP4 Output

    User->>API: call(input, weight, y_fp4, block_scale, eps, block_size)
    API->>Validate: check dtypes, shapes, divisibility, flatten if 3D
    Validate-->>API: ok
    API->>AutoTuner: build OptimizationProfile / sample inputs
    AutoTuner->>Runner: instantiate runner with cuDNN graph & plans
    Runner->>cuDNN: enumerate tactics / run plans for tuning
    cuDNN-->>Runner: timing/results
    AutoTuner-->>API: select best tactic
    API->>Runner: execute chosen tactic with real inputs
    Runner->>cuDNN: execute graph (RMSNorm -> FP4 block quant)
    cuDNN-->>Result: FP4 packed tensor + block scales
    Result-->>User: return quantized output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • Pay attention to cuDNN graph construction and plan/workspace management in flashinfer/norm.py
  • Review autotuning / TunableRunner integration and tactic selection logic
  • Verify 2D/3D input handling and block-size divisibility checks
  • Inspect new tests (tests/norm/test_rmsnorm_fp4_quant.py) for correctness of reference quantization/dequantization and environment gating
  • Confirm there is no unintended duplication of the same scaffolding blocks

Poem

🐰 I hopped through kernels, bits, and queues,
I braided norms with tiny FP4 hues,
cuDNN hummed a tuneful, tuned ballet,
Blocks of scale marched neatly on their way,
I nibbled bugs β€” now tests applaud πŸŽ‰πŸ₯•

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Description check (⚠️ Warning): The PR description is entirely blank template content, with no actual description of the changes, no related issues linked, and no reviewer notes provided, so it fails to meet the template requirements. Resolution: fill in the Description section explaining the new rmsnorm_fp4quant functionality, link any related issues, and document test coverage and pre-commit status.

βœ… Passed checks (2 passed)
  • Title check (βœ… Passed): The title '[WIP] Add norm fp4quant fusion' is directly related to the main changes in the PR, which add a cuDNN-based RMSNorm with FP4 quantization fusion operation.
  • Docstring coverage (βœ… Passed): Docstring coverage is 94.74%, which meets the required threshold of 80.00%.

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new rmsnorm_fp4quant function designed to fuse RMS normalization and FP4 block scale quantization into a single, highly optimized operation. By utilizing the cuDNN graph API, this change aims to significantly enhance performance, particularly on Blackwell GPUs, for models that employ FP4 quantization. The implementation includes robust checks for cuDNN availability and version, integrates with an autotuning mechanism for optimal execution plan selection, and is accompanied by comprehensive unit tests to ensure functional correctness across various input configurations.
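For readers new to the operation, a minimal unfused sketch of what the fused kernel computes (RMSNorm followed by per-block scaling; the helper name and the E2M1 max of 6.0 are assumptions for illustration, and real FP4 rounding/packing is omitted):

    import torch

    def rmsnorm_block_scale_reference(x, weight, eps=1e-6, block_size=16):
        # RMSNorm: x * rsqrt(mean(x^2) + eps) * weight, computed in fp32.
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        y = x.float() * torch.rsqrt(variance + eps) * weight.float()

        # Block-scale step: one scale per block_size contiguous elements (assumes the
        # hidden size is divisible by block_size), chosen so the block's max magnitude
        # maps to the assumed FP4 (E2M1) max of 6.0.
        blocks = y.view(*y.shape[:-1], -1, block_size)
        scales = blocks.abs().amax(dim=-1, keepdim=True) / 6.0
        q = torch.clamp(blocks / scales.clamp(min=1e-12), -6.0, 6.0)
        return q.view_as(y), scales.squeeze(-1)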

Highlights

  • Fused RMSNorm + FP4 Quantization: Introduced a new rmsnorm_fp4quant function that fuses RMS normalization and FP4 block scale quantization into a single operation, leveraging cuDNN for performance.
  • cuDNN Graph API Integration: The new fused operation utilizes the cuDNN graph API to create optimized execution plans, specifically targeting Blackwell GPUs for enhanced efficiency.
  • Autotuning Support: Integrated the rmsnorm_fp4quant function with the existing AutoTuner to automatically select the most efficient cuDNN execution plan based on input configurations.
  • Comprehensive Unit Tests: Added extensive unit tests for the rmsnorm_fp4quant function, including reference implementations for correctness verification across 2D and 3D input shapes, various batch sizes, hidden sizes, and data types.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with πŸ‘ and πŸ‘Ž on @gemini-code-assist comments. If you're interested in sharing feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a fused RMSNorm with FP4 quantization kernel, rmsnorm_fp4quant, leveraging cuDNN for enhanced performance. The implementation is well-structured, utilizing cuDNN graphs and an autotuner. The addition of corresponding tests is also a valuable contribution. My review focuses on a few key areas for improvement:

  • There's a correctness issue in rmsnorm_fp4quant related to handling 3D input/output tensors, which could lead to runtime errors. The function also lacks validation for the shapes and dtypes of output tensors.
  • I've identified some unused code, specifically the _RMSNormFP4QuantUIDs enum, which can be removed for better clarity.
  • There are minor type hint inconsistencies in the CudnnRMSNormFP4QuantRunner class that I've pointed out.

I have provided detailed suggestions to address these points. Overall, this is a solid addition, and these changes will improve its robustness and maintainability.

Comment on lines 738 to 802
    _check_cudnn_rmsnorm_fp4quant_availability()

    # Validate input dtype
    if input.dtype not in (torch.float16, torch.bfloat16):
        raise ValueError(
            f"Unsupported input dtype: {input.dtype}. "
            f"Only torch.float16 and torch.bfloat16 are supported."
        )

    if input.dtype != weight.dtype:
        raise ValueError(
            f"Input and weight must have the same dtype. "
            f"Got input: {input.dtype}, weight: {weight.dtype}"
        )

    # Handle input shape
    if input.ndim == 2:
        batch_size, hidden_size = input.shape
    elif input.ndim == 3:
        batch_size, seq_len, hidden_size = input.shape
        # Flatten to 2D for processing
        input = input.view(batch_size * seq_len, hidden_size)
        batch_size = batch_size * seq_len
    else:
        raise ValueError(
            f"Input must be 2D or 3D tensor, got {input.ndim}D tensor with shape {input.shape}"
        )

    # Validate dimensions
    if hidden_size % block_size != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by block_size ({block_size})"
        )

    if hidden_size % 2 != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by 2 for FP4 packing"
        )

    # Flatten input to 2D for the runner
    input_2d = input.view(batch_size, hidden_size)

    # Use AutoTuner for execution plan selection
    runners = [_get_cudnn_rmsnorm_fp4quant_runner()]
    tuner = AutoTuner.get()

    # Package inputs for the runner
    # Note: eps and block_size are passed as Python values (not tensors)
    inputs = [
        input_2d,
        weight,
        y_fp4,
        block_scale,
        eps,
        block_size,
    ]

    runner, tactic = tuner.choose_one(
        "rmsnorm_fp4quant",
        runners,
        TuningConfig(),
        inputs,
    )

    runner(inputs=inputs, tactic=tactic)
Contributor


high

The current implementation of rmsnorm_fp4quant has a few issues:

  1. 3D Tensor Handling: It doesn't correctly handle 3D output tensors (y_fp4, block_scale). The docstring states 3D tensors are supported, but they are passed to the runner which expects 2D-compatible views, leading to a runtime error.
  2. Missing Validation: There is no validation for the shapes and dtypes of the output tensors, which could lead to cryptic errors if incorrect tensors are provided.
  3. Clarity: The logic for handling 2D/3D inputs and flattening can be simplified and made clearer.

I've provided a refactored version of the function body that addresses these points by:

  • Adding comprehensive validation for input and output tensor dtypes and shapes.
  • Correctly creating 2D views for 3D input and output tensors before passing them to the runner.
  • Restructuring the logic for better readability.
    _check_cudnn_rmsnorm_fp4quant_availability()

    # Validate input dtypes
    if input.dtype not in (torch.float16, torch.bfloat16):
        raise ValueError(
            f"Unsupported input dtype: {input.dtype}. "
            f"Only torch.float16 and torch.bfloat16 are supported."
        )
    if input.dtype != weight.dtype:
        raise ValueError(
            f"Input and weight must have the same dtype. "
            f"Got input: {input.dtype}, weight: {weight.dtype}"
        )

    # Validate output dtypes
    if y_fp4.dtype != torch.uint8:
        raise ValueError(f"y_fp4 must have dtype torch.uint8, but got {y_fp4.dtype}")
    if block_scale.dtype != torch.float8_e4m3fn:
        raise ValueError(
            f"block_scale must have dtype torch.float8_e4m3fn, but got {block_scale.dtype}"
        )

    # Validate dimensions
    hidden_size = input.shape[-1]
    if hidden_size % block_size != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by block_size ({block_size})"
        )
    if hidden_size % 2 != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by 2 for FP4 packing"
        )

    # Handle input/output shapes and prepare 2D views for the runner
    if input.ndim == 2:
        batch_size, _ = input.shape
        input_2d = input
        y_fp4_2d = y_fp4
        block_scale_2d = block_scale

        # Validate output shapes for 2D input
        expected_y_fp4_shape = (batch_size, hidden_size // 2)
        if y_fp4.shape != expected_y_fp4_shape:
            raise ValueError(
                f"For 2D input, y_fp4 shape must be {expected_y_fp4_shape}, but got {y_fp4.shape}"
            )
        expected_block_scale_shape = (batch_size, hidden_size // block_size)
        if block_scale.shape != expected_block_scale_shape:
            raise ValueError(
                f"For 2D input, block_scale shape must be {expected_block_scale_shape}, but got {block_scale.shape}"
            )
    elif input.ndim == 3:
        batch_size, seq_len, _ = input.shape
        flat_batch_size = batch_size * seq_len

        # Validate output shapes for 3D input
        expected_y_fp4_shape = (batch_size, seq_len, hidden_size // 2)
        if y_fp4.shape != expected_y_fp4_shape:
            raise ValueError(
                f"For 3D input, y_fp4 shape must be {expected_y_fp4_shape}, but got {y_fp4.shape}"
            )
        expected_block_scale_shape = (batch_size, seq_len, hidden_size // block_size)
        if block_scale.shape != expected_block_scale_shape:
            raise ValueError(
                f"For 3D input, block_scale shape must be {expected_block_scale_shape}, but got {block_scale.shape}"
            )

        # Flatten to 2D for processing
        input_2d = input.view(flat_batch_size, hidden_size)
        y_fp4_2d = y_fp4.view(flat_batch_size, hidden_size // 2)
        block_scale_2d = block_scale.view(flat_batch_size, hidden_size // block_size)
    else:
        raise ValueError(
            f"Input must be 2D or 3D tensor, got {input.ndim}D tensor with shape {input.shape}"
        )

    # Use AutoTuner for execution plan selection
    runners = [_get_cudnn_rmsnorm_fp4quant_runner()]
    tuner = AutoTuner.get()

    # Package inputs for the runner
    # Note: eps and block_size are passed as Python values (not tensors)
    inputs = [
        input_2d,
        weight,
        y_fp4_2d,
        block_scale_2d,
        eps,
        block_size,
    ]

    runner, tactic = tuner.choose_one(
        "rmsnorm_fp4quant",
        runners,
        TuningConfig(),
        inputs,
    )

    runner(inputs=inputs, tactic=tactic)

Comment on lines +324 to +331
class _RMSNormFP4QuantUIDs(Enum):
    """UIDs for cuDNN RMSNorm + FP4 Quantization graph tensors."""

    X_UID = 0
    WEIGHT_UID = 1
    EPSILON_UID = 2
    Y_FP4_UID = 3
    BLOCK_SCALE_UID = 4
Contributor


medium

The _RMSNormFP4QuantUIDs enum is defined but appears to be unused throughout the file. To improve code clarity and remove dead code, it's best to remove this enum.

class CudnnRMSNormFP4QuantRunner(TunableRunner):
    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
Contributor


medium

The type hint for the inputs parameter is List[torch.Tensor], but the list also contains non-tensor values like eps (float) and block_size (int). This mismatch can be misleading. To accurately reflect the contents of the list, please update the type hint to List[Any].

Suggested change
inputs: List[torch.Tensor],
inputs: List[Any],


    def forward(
        self,
        inputs: List[torch.Tensor],
Contributor


medium

The type hint for the inputs parameter is List[torch.Tensor], but the list also contains non-tensor values like eps (float) and block_size (int). This mismatch can be misleading. To accurately reflect the contents of the list, please update the type hint to List[Any].

Suggested change
inputs: List[torch.Tensor],
inputs: List[Any],

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (7)
flashinfer/norm.py (4)

324-378: Clean up unused enum and unused inv_variance from cuDNN RMSNorm call.

  • _RMSNormFP4QuantUIDs is defined but not referenced anywhere; unless you plan to use UIDs in a follow-up (e.g., for debugging or graph introspection), it’s effectively dead code and can be removed along with the Enum import.
  • In _create_rmsnorm_fp4quant_execution_plans, inv_variance is never used, which Ruff flags:
y_rmsnorm, inv_variance = graph.rmsnorm(...)

If you don’t need it, capture it with a dummy variable to silence the warning:

-    y_rmsnorm, inv_variance = graph.rmsnorm(
+    y_rmsnorm, _inv_variance = graph.rmsnorm(

These are small cleanups that keep the new section tidy and lint-clean.


578-615: Address unused arguments in CudnnRMSNormFP4QuantRunner to keep linters green.

Ruff reports several unused parameters:

  • profile in get_valid_tactics
  • weight, y_fp4, block_scale, eps in the inputs unpack
  • do_preparation and kwargs in forward

You can satisfy the linter without changing behavior by explicitly marking them as used:

        def get_valid_tactics(
            self,
            inputs: List[torch.Tensor],
-            profile: OptimizationProfile,
+            profile: OptimizationProfile,
        ) -> List[int]:
-            """Return list of valid execution plan indices for autotuning."""
+            """Return list of valid execution plan indices for autotuning."""
+            _ = profile  # currently unused

            (
-                input_tensor,
-                weight,
-                y_fp4,
-                block_scale,
-                eps,
-                block_size,
+                input_tensor,
+                _weight,
+                _y_fp4,
+                _block_scale,
+                _eps,
+                block_size,
            ) = inputs

and:

        def forward(
            self,
            inputs: List[torch.Tensor],
            tactic: int = -1,
-            do_preparation: bool = False,
-            **kwargs,
+            do_preparation: bool = False,
+            **kwargs,
        ):
             """Execute the RMSNorm + FP4 quantization with specified tactic."""
+            _ = do_preparation, kwargs  # currently unused

This aligns with Ruff’s ARG002/RUF059 hints without altering the runtime behavior.


661-801: Validate output tensors and reconcile AutoTuner inputs with expected tensor-only lists.

Two areas here are worth tightening up:

  1. Output tensor validation

    rmsnorm_fp4quant enforces dtype/shape on input and weight, but assumes y_fp4 and block_scale are correctly shaped/dtyped and on the same device. Mis-specified outputs could lead to cryptic cuDNN errors or memory misuse.

    Consider adding explicit checks along the lines of:

    if y_fp4.dtype is not torch.uint8:
        raise ValueError("y_fp4 must be uint8 (packed FP4).")
    
    if block_scale.dtype is not torch.float8_e4m3fn:
        raise ValueError("block_scale must be float8_e4m3fn.")
    
    expected_y_shape = (batch_size, hidden_size // 2)
    expected_scale_shape = (batch_size, hidden_size // block_size)
    # For 3D, adjust expected_* to include seq_len similarly.

    and verify input.device, weight.device, y_fp4.device, and block_scale.device all match. This would fail fast with clear messages instead of relying on backend errors.

  2. Passing non-tensors to AutoTuner.choose_one

    AutoTuner.choose_one is documented to take inputs: List[torch.Tensor] and internally computes input_shapes = tuple(self._get_input_sizes(inputs)). In this implementation, inputs includes Python scalars:

    inputs = [
        input_2d,
        weight,
        y_fp4,
        block_scale,
        eps,         # float
        block_size,  # int
    ]

    If _get_input_sizes assumes every element is a torch.Tensor (e.g., tuple(t.size() for t in inputs)), including eps and block_size could raise an AttributeError or skew the cache key.

    It may be safer to:

    • restrict inputs to tensors only (e.g., [input_2d, weight, y_fp4, block_scale]) and pass eps and block_size through **kwargs, or
    • wrap them as 0-D tensors if they must participate in shape-based caching.

    Please double-check the current implementation of _get_input_sizes in flashinfer.autotuner.AutoTuner and adjust accordingly.
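For illustration, a minimal sketch of the 0-D-tensor option (whether the runner then reads the scalars back via .item() is an assumption, not something verified against the current code):

    # Keep AutoTuner inputs tensor-only so shape-based cache keys stay well-defined.
    inputs = [
        input_2d,
        weight,
        y_fp4,
        block_scale,
        torch.tensor(eps, dtype=torch.float32),       # 0-D tensor instead of a Python float
        torch.tensor(block_size, dtype=torch.int64),  # 0-D tensor instead of a Python int
    ]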


352-367: Optional: add a compute capability check to match β€œBlackwell-only” contract.

Docstring for rmsnorm_fp4quant says it β€œrequires … a Blackwell GPU (compute capability >= 100)”, but _check_cudnn_rmsnorm_fp4quant_availability currently only validates cuDNN presence and backend version.

To align behavior with the documented requirement (and with the tests’ requires_blackwell()), consider adding a lightweight capability check:

cc_major, cc_minor = torch.cuda.get_device_capability(device)
if cc_major * 10 + cc_minor < 100:
    raise RuntimeError(
        "rmsnorm_fp4quant requires a Blackwell GPU (compute capability >= 100); "
        f"found compute capability {cc_major}.{cc_minor}."
    )

This would give users a clear error instead of letting them hit backend failures on unsupported GPUs.

tests/norm/test_rmsnorm_fp4_quant.py (3)

65-114: Optional: drop unused reference quantization outputs or mark them as intentionally unused.

In both tests you assign but never use ref_fp4 and ref_scale:

ref_fp4, ref_scale = ref_block_scale_quantize(ref_rmsnorm, block_size=block_size)

Since the assertions only compare y_dequant with ref_rmsnorm.float(), you can either:

  • remove the assignments (and even the call, if you don’t need it for debugging), or
  • mark them as intentionally unused:
-    ref_fp4, ref_scale = ref_block_scale_quantize(ref_rmsnorm, block_size=block_size)
+    _ref_fp4, _ref_scale = ref_block_scale_quantize(ref_rmsnorm, block_size=block_size)

Same pattern applies in the 3D test. This keeps the reference helper available while avoiding unused-variable warnings.


145-159: Make Blackwell capability check robust on CPU-only environments.

requires_blackwell() currently calls torch.cuda.get_device_capability() unconditionally via get_cc(). On machines without CUDA, this can raise at import time when evaluating the blackwell_required skip marker.

To keep the tests importable everywhere, you can guard on CUDA availability:

def get_cc():
-    """Get CUDA compute capability."""
-    major, minor = torch.cuda.get_device_capability()
-    return major * 10 + minor
+    """Get CUDA compute capability (returns 0 on non-CUDA systems)."""
+    if not torch.cuda.is_available():
+        return 0
+    major, minor = torch.cuda.get_device_capability()
+    return major * 10 + minor

requires_blackwell() can remain as return get_cc() >= 100.


134-143: Optional: narrow the broad except Exception in requires_cudnn_fp4.

In requires_cudnn_fp4(), a bare except Exception: hides all errors from the cudnn import and version query, including programming mistakes:

try:
    import cudnn
    return cudnn.backend_version() >= 90700
except Exception:
    return False

If you only intend to treat import or runtime environment failures as β€œfeature not available”, consider narrowing the exception (e.g., ImportError, OSError, or a tuple of expected failures). Otherwise, this is acceptable for a test-only capability probe.
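As a concrete example, the probe could be narrowed like this (including AttributeError is an assumption, in case an older cudnn build lacks backend_version):

    def requires_cudnn_fp4() -> bool:
        """Best-effort probe: True only if a cuDNN backend >= 9.7 is importable."""
        try:
            import cudnn

            return cudnn.backend_version() >= 90700
        except (ImportError, OSError, AttributeError):
            return False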

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between f0355f7 and 13289e8.

πŸ“’ Files selected for processing (4)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/norm.py (2 hunks)
  • tests/norm/__init__.py (1 hunks)
  • tests/norm/test_rmsnorm_fp4_quant.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/__init__.py (1)
flashinfer/norm.py (1)
  • rmsnorm_fp4quant (662-802)
flashinfer/norm.py (3)
flashinfer/autotuner.py (8)
  • AutoTuner (335-791)
  • OptimizationProfile (168-183)
  • TunableRunner (194-247)
  • TuningConfig (101-141)
  • get_valid_tactics (196-214)
  • forward (220-244)
  • get (362-365)
  • choose_one (400-534)
flashinfer/jit/norm.py (1)
  • gen_norm_module (21-33)
csrc/norm.cu (2)
  • rmsnorm (22-78)
  • rmsnorm (22-22)
πŸͺ› GitHub Actions: pre-commit
tests/norm/__init__.py

[error] 1-1: End-of-file fixer modified the file. Please re-run hooks or commit again.


[error] 1-1: Trailing whitespace detected and removed by trailing-whitespace hook.

tests/norm/test_rmsnorm_fp4_quant.py

[error] 20-20: ruff-check: F401 'flashinfer' imported but unused.

flashinfer/norm.py

[error] 19-19: ruff-check: F401 'typing.Tuple' imported but unused.

πŸͺ› Ruff (0.14.8)
tests/norm/test_rmsnorm_fp4_quant.py

62-62: Avoid specifying long messages outside the exception class

(TRY003)


94-94: Avoid specifying long messages outside the exception class

(TRY003)


141-141: Do not catch blind exception: Exception

(BLE001)


202-202: Unpacked variable ref_fp4 is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


202-202: Unpacked variable ref_scale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


257-257: Unpacked variable ref_fp4 is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


257-257: Unpacked variable ref_scale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

flashinfer/norm.py

343-346: Avoid specifying long messages outside the exception class

(TRY003)


355-358: Avoid specifying long messages outside the exception class

(TRY003)


363-366: Avoid specifying long messages outside the exception class

(TRY003)


376-378: Avoid specifying long messages outside the exception class

(TRY003)


446-446: Unpacked variable inv_variance is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


582-582: Unused method argument: profile

(ARG002)


587-587: Unpacked variable weight is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


588-588: Unpacked variable y_fp4 is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


589-589: Unpacked variable block_scale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


590-590: Unpacked variable eps is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


613-613: Unused method argument: do_preparation

(ARG002)


614-614: Unused method argument: kwargs

(ARG002)


742-745: Avoid specifying long messages outside the exception class

(TRY003)


748-751: Avoid specifying long messages outside the exception class

(TRY003)


762-764: Avoid specifying long messages outside the exception class

(TRY003)


768-770: Avoid specifying long messages outside the exception class

(TRY003)


773-775: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
πŸ”‡ Additional comments (1)
flashinfer/__init__.py (1)

93-99: Expose rmsnorm_fp4quant at top level looks good.

The additional import cleanly surfaces rmsnorm_fp4quant in the top-level flashinfer API alongside other norm functions; no further changes needed here.
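For instance, downstream callers can now import it directly (a trivial usage note, not from the diff itself):

    from flashinfer import rmsnorm_fp4quant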

Comment on lines 1 to 2
# Tests for normalization operations

Contributor


⚠️ Potential issue | 🟑 Minor

Fix trailing whitespace / EOF formatting to satisfy pre-commit.

Pre-commit reports both end-of-file-fixer and trailing-whitespace modifying this file. Ensure the file has no trailing spaces and ends with exactly one newline, e.g.:

# Tests for normalization operations

with a single newline at the end of the file.

🧰 Tools
πŸͺ› GitHub Actions: pre-commit

[error] 1-1: End-of-file fixer modified the file. Please re-run hooks or commit again.


[error] 1-1: Trailing whitespace detected and removed by trailing-whitespace hook.

πŸ€– Prompt for AI Agents
In tests/norm/__init__.py around lines 1 to 2, remove any trailing spaces on the
line and ensure the file ends with exactly one newline character (no extra blank
lines at EOF); edit the file so the only content is the comment line with no
trailing whitespace and save with a single terminating newline.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (6)
flashinfer/norm.py (5)

19-19: Remove unused Tuple import.

This import is flagged by pre-commit (ruff F401) and has been noted in previous reviews. Please remove it to satisfy the linter.


328-335: Remove unused _RMSNormFP4QuantUIDs enum.

This enum is defined but never used in the implementation. As noted in a previous review, removing it will improve code clarity.


583-587: Update type hint to List[Any] for mixed-type inputs.

As noted in a previous review, the inputs parameter contains both tensors and scalar values (eps, block_size), so List[torch.Tensor] is misleading.


613-619: Update type hint to List[Any] for mixed-type inputs.

As noted in a previous review, the inputs parameter contains both tensors and scalar values (eps, block_size), so List[torch.Tensor] is misleading.


774-823: Fix 3D tensor handling for output tensors.

As noted in a comprehensive previous review, the current implementation doesn't properly handle 3D output tensors (y_fp4 and block_scale). When the input is 3D, these outputs are also 3D but are passed to the runner without flattening, causing a shape mismatch at line 637-638 in the forward method.

The previous review provided a detailed fix that includes:

  • Flattening output tensors to 2D views when input is 3D
  • Adding comprehensive shape and dtype validation for outputs
  • Clearer logic flow for handling 2D vs 3D cases
tests/norm/test_rmsnorm_fp4_quant.py (1)

20-20: Remove unused flashinfer import.

This import is flagged by pre-commit (ruff F401) and has been noted in a previous review. Only the specific imports from flashinfer.norm are needed.

🧹 Nitpick comments (2)
flashinfer/norm.py (1)

450-457: Use _ for unused return value.

The inv_variance output from graph.rmsnorm is not used. Consider using _ to make this explicit:

-    y_rmsnorm, inv_variance = graph.rmsnorm(
+    y_rmsnorm, _ = graph.rmsnorm(
tests/norm/test_rmsnorm_fp4_quant.py (1)

155-158: Remove commented-out parametrize decorators.

These commented-out lines appear to be leftover debugging code and should be removed before merging.

 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-# @pytest.mark.parametrize("batch_size", [])
-# @pytest.mark.parametrize("seq_len", [])
-# @pytest.mark.parametrize("hidden_size", [])
-# @pytest.mark.parametrize("dtype", [])
 def test_rmsnorm_fp4quant_3d(batch_size, seq_len, hidden_size, dtype):
πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between 13289e8 and 1be5a03.

πŸ“’ Files selected for processing (4)
  • flashinfer/norm.py (2 hunks)
  • scripts/task_jit_run_tests_part5.sh (1 hunks)
  • tests/norm/__init__.py (1 hunks)
  • tests/norm/test_rmsnorm_fp4_quant.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/norm/__init__.py
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/norm.py (5)
flashinfer/api_logging.py (1)
  • flashinfer_api (464-565)
flashinfer/autotuner.py (7)
  • AutoTuner (335-791)
  • OptimizationProfile (168-183)
  • TunableRunner (194-247)
  • get_valid_tactics (196-214)
  • forward (220-244)
  • get (362-365)
  • choose_one (400-534)
flashinfer/utils.py (2)
  • backend_requirement (902-1184)
  • supported_compute_capability (819-899)
include/flashinfer/trtllm/common.h (1)
  • device (83-90)
csrc/norm.cu (2)
  • rmsnorm (22-78)
  • rmsnorm (22-22)
πŸͺ› Ruff (0.14.8)
flashinfer/norm.py

347-350: Avoid specifying long messages outside the exception class

(TRY003)


359-362: Avoid specifying long messages outside the exception class

(TRY003)


367-370: Avoid specifying long messages outside the exception class

(TRY003)


380-382: Avoid specifying long messages outside the exception class

(TRY003)


450-450: Unpacked variable inv_variance is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


586-586: Unused method argument: profile

(ARG002)


591-591: Unpacked variable weight is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


592-592: Unpacked variable y_fp4 is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


593-593: Unpacked variable block_scale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


594-594: Unpacked variable eps is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


617-617: Unused method argument: do_preparation

(ARG002)


618-618: Unused method argument: kwargs

(ARG002)


667-667: Unused function argument: input

(ARG001)


668-668: Unused function argument: weight

(ARG001)


669-669: Unused function argument: y_fp4

(ARG001)


670-670: Unused function argument: block_scale

(ARG001)


671-671: Unused function argument: eps

(ARG001)


672-672: Unused function argument: block_size

(ARG001)


763-766: Avoid specifying long messages outside the exception class

(TRY003)


769-772: Avoid specifying long messages outside the exception class

(TRY003)


783-785: Avoid specifying long messages outside the exception class

(TRY003)


789-791: Avoid specifying long messages outside the exception class

(TRY003)


794-796: Avoid specifying long messages outside the exception class

(TRY003)

tests/norm/test_rmsnorm_fp4_quant.py

63-63: Avoid specifying long messages outside the exception class

(TRY003)


74-74: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
πŸ”‡ Additional comments (1)
scripts/task_jit_run_tests_part5.sh (1)

15-15: LGTM!

The addition of the new test file to the test suite is correct and follows the existing pattern of running each test file separately.
