
Conversation

@bkryu (Collaborator) commented Dec 12, 2025

📌 Description

bench_gpu_time_with_cudagraph uses CUDA events with multiple kernel iterations within the graph to amortize timing overhead. However, the L2 cache becomes hot after the first iteration, leading to misleadingly good performance numbers for memory-bound kernels. (Note: this is not an issue for CUPTI-based measurements with bench_gpu_time_with_cupti, where the measurement overhead is small and the L2 cache can therefore be flushed between kernel launches.)

This PR implements rotating buffers that cycle through different memory regions across iterations, ensuring a cold L2 cache for each kernel invocation.
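
For intuition, here is a minimal sketch of the rotation idea (illustrative only, not the exact code added by this PR): keep several independently allocated clones of the inputs and cycle through them while capturing the graph, so consecutive in-graph launches read from different memory regions and never find their data already resident in L2.

    import torch

    def capture_with_rotating_buffers(fn, input_copies, num_iters_within_graph=10):
        # input_copies: list of independently allocated (args, kwargs) clones whose
        # combined footprint comfortably exceeds the L2 cache.
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            for i in range(num_iters_within_graph):
                args, kwargs = input_copies[i % len(input_copies)]
                fn(*args, **kwargs)  # each launch touches a different buffer copy
        return g  # one replay of g executes num_iters_within_graph cold-L2 launches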

🔍 Related Issues

#2187

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used my preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Chores
    • Enhanced benchmarking infrastructure with a rotating-buffer subsystem to more reliably measure cold-L2 GPU performance and warn when rotation isn't applicable.
    • Expanded test harness and benchmark APIs to accept explicit input arguments and rotate buffers for per-backend reproducibility and autotuning.
    • Improved GPU timing utilities to propagate input args/kwargs across CUPTI, CUDA-graph and event paths.


@bkryu bkryu added this to the 2025 Dec milestone Dec 12, 2025
@coderabbitai bot (Contributor) commented Dec 12, 2025

Walkthrough

Refactors benchmark/test wrappers to accept explicit per-run inputs and introduces rotating-buffer utilities for cold‑L2 benchmarking in flashinfer/testing/utils.py. Threads rotate_buffers and input_args/input_kwargs through bench_gpu_time, CUDA-graph, CUPTI, and CUDA-event timing paths, and updates attention, gemm, and moe benchmark routines to pass richer argument sets into their run_backend/run_backend_wrapper functions.

Changes

  • Testing utilities & benchmarking infra (flashinfer/testing/utils.py): Added rotating-buffer utilities (get_l2_cache_size, calculate_rotation_count, _create_rotated_buffer_copies, _clone_structure, _extract_gpu_tensors, etc.). Extended the bench timing APIs (bench_gpu_time, bench_gpu_time_with_cudagraph, bench_gpu_time_with_cupti, bench_gpu_time_with_cuda_event) to accept rotate_buffers, input_args, and input_kwargs, and implemented rotation-aware invocation and CUDA-graph replay (see the usage sketch after this list).
  • Attention benchmarks (benchmarks/routines/attention.py): Expanded multiple run_backend_wrapper signatures to accept explicit Q/K/V, KV cache variants, workspace buffers, block tables, and sequence-length indicators; updated all call sites to pass input_args and use bench_gpu_time(..., rotate_buffers=True, input_args=...).
  • GEMM benchmarks (benchmarks/routines/gemm.py): Extended run_backend signatures across FP8/FP4 tests to accept quantized inputs, scales, and index/indptr args; propagated input_args through autotune warmups and bench_gpu_time(..., rotate_buffers=True, input_args=...); adjusted reference checks to convert FP8 refs to float32 and relaxed some cosine-similarity tolerances.
  • MOE benchmarks (benchmarks/routines/moe.py): Expanded run_fp4_moe, run_cutlass (multiple variants), run_fp8_block_moe, and run_fp8_per_tensor_moe signatures to accept routing logits/biases, quantized inputs, and per-path weights/scales; introduced input_args_for_bench and threaded rotate_buffers=True + input_args through warmup and benchmark invocations.
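
To make the extended entry point concrete, here is a minimal usage sketch. It relies only on the parameter names summarized above (fn, use_cuda_graph, rotate_buffers, input_args); run_my_kernel and the q/k/v tensors are hypothetical placeholders, and the return value is assumed to be the list of per-iteration times in milliseconds, following the helper's existing *_ms conventions.

    import torch
    from flashinfer.testing.utils import bench_gpu_time

    # Hypothetical inputs for a memory-bound kernel (shapes are illustrative only).
    q = torch.randn(32, 8, 128, device="cuda", dtype=torch.float16)
    k = torch.randn(32, 8, 128, device="cuda", dtype=torch.float16)
    v = torch.randn(32, 8, 128, device="cuda", dtype=torch.float16)

    def run_my_kernel(q, k, v):  # stand-in for a real run_backend_wrapper
        return torch.matmul(q, k.transpose(-1, -2)) @ v

    # Passing tensors through input_args (instead of closing over them) lets the
    # harness clone them and rotate through the copies inside the captured graph.
    times = bench_gpu_time(
        fn=run_my_kernel,
        use_cuda_graph=True,
        rotate_buffers=True,  # cold-L2 rotation; only takes effect in the CUDA-graph path
        input_args=(q, k, v),
    )
    print(f"median time: {sorted(times)[len(times) // 2]:.3f} ms")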

Sequence Diagram(s)

sequenceDiagram
    participant Bench as Bench Runner
    participant BenchAPI as bench_gpu_time
    participant Rotator as RotatingBufferMgr
    participant CUGraph as CUDA Graph / Capture
    participant CUPTI as CUPTI timing (optional)
    participant GPU as GPU / Backend

    Bench->>BenchAPI: request timing(fn, rotate_buffers=True, input_args)
    BenchAPI->>Rotator: detect GPU tensors & compute rotations
    Rotator-->>BenchAPI: rotated input copies (N)
    BenchAPI->>CUGraph: capture graph with rotated inputs (iter 1..N)
    CUGraph->>GPU: replay captures per rotation
    alt CUPTI available
      BenchAPI->>CUPTI: measure with input_args (rotated if enabled)
      CUPTI->>GPU: timed runs
    end
    BenchAPI->>Bench: aggregated timing results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • Areas needing extra attention:
    • flashinfer/testing/utils.py: correctness of GPU-tensor extraction, rotation-copy semantics (non-contiguous tensors), and CUDA-graph replay correctness.
    • Threading of new input_args across all autotune/warmup loops in attention/gemm/moe files to ensure no missing or misordered parameters.
    • FP8→float32 conversion and adjusted tolerance logic in GEMM tests.

Possibly related PRs

Suggested reviewers

  • aleozlx
  • cyx-6
  • nvmbreughe
  • jimmyzho
  • jiahanc

Poem

🐰 I hopped through buffers, cloned and spun,
Passing args explicit, one by one.
Graphs replayed with tensors new,
Cold‑L2 sings in me and you.
Benchmarks hum — a jitter-free run! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 59.38%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The PR title clearly and specifically describes the main feature: adding cold-L2-cache benchmarking with rotating buffers, which directly aligns with the primary changes.
  • Description check ✅ Passed: The PR description includes motivation (the L2 cache warming issue), the solution (rotating buffers), a related issue link, and pre-commit checklist completion. However, the tests checklist items remain unchecked, indicating tests may not have been added or confirmed passing.

@gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the accuracy of GPU performance benchmarking by introducing a "cold L2 cache" measurement capability. By implementing rotating buffers, the system now cycles through distinct memory regions for each kernel execution within a benchmark, preventing misleadingly high performance figures that can arise from a "hot" L2 cache. This is particularly vital for memory-bound kernels where cache effects can heavily influence observed speeds, ensuring that benchmark results more faithfully represent real-world cold-start scenarios.

Highlights

  • Cold L2 Cache Benchmarking: Implemented a rotating buffer mechanism to ensure a "cold" L2 cache for each kernel invocation during benchmarking, providing more accurate performance measurements for memory-bound kernels by preventing misleadingly high performance from hot cache effects.
  • New Utility Functions: Added several helper functions in flashinfer/testing/utils.py (get_l2_cache_size, _calculate_tensor_bytes, _extract_gpu_tensors, calculate_rotation_count, _clone_structure, _create_rotated_buffer_copies) to manage L2 cache size queries, tensor byte calculation, GPU tensor extraction, rotation count determination, and deep cloning of tensor structures for buffer rotation.
  • Enhanced Benchmarking API: The core benchmarking functions (bench_gpu_time_with_cuda_event, bench_gpu_time_with_cupti, bench_gpu_time_with_cudagraph) now accept explicit input_args and input_kwargs, improving flexibility and enabling the new rotating buffer feature for passing arguments to the benchmarked function.
  • Integration into Benchmarks: Existing benchmark routines for attention, GEMM, and MoE operations have been updated to leverage the new rotate_buffers option in the bench_gpu_time wrapper, allowing for cold L2 cache analysis across various operations.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new 'rotating buffer' mechanism for cold-L2 cache benchmarking across various attention, GEMM, and MoE routines.

The flashinfer/testing/utils.py file was updated to include utilities for calculating L2 cache size, extracting GPU tensors, cloning tensor structures while preserving strides, and creating multiple copies of input arguments for buffer rotation. The core benchmarking functions (bench_gpu_time_with_cuda_event, bench_gpu_time_with_cudagraph, bench_gpu_time_with_cupti, and the main bench_gpu_time wrapper) were refactored to accept explicit input_args and input_kwargs, and a new rotate_buffers flag. When rotate_buffers is enabled with CUDA graphs, the system automatically determines the necessary number of buffer copies based on L2 cache size and input tensor sizes, then rotates through these copies during graph capture to ensure cold-L2 cache conditions for each kernel invocation.

Correspondingly, the run_backend_wrapper and run_backend functions in benchmarks/routines/attention.py, benchmarks/routines/gemm.py, and benchmarks/routines/moe.py were updated to accept all relevant input tensors and parameters as explicit arguments, facilitating their use with the new rotating buffer mechanism in the benchmarking utilities. Additionally, error returns for unsupported backends were changed from res to None in the attention benchmarks.
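
In code terms, the flow described above composes the new helpers roughly as follows (a sketch using the function names from this PR; the actual wiring inside bench_gpu_time_with_cudagraph may differ in details such as device selection and warnings):

    # Inside the CUDA-graph timing path, when rotate_buffers=True:
    gpu_tensors = _extract_gpu_tensors((input_args, input_kwargs))  # find CUDA tensors
    num_rotations = calculate_rotation_count(gpu_tensors)           # 5x-L2 heuristic
    copies = _create_rotated_buffer_copies(input_args, input_kwargs, num_rotations)

    # Graph capture then launches the kernel num_iters_within_graph times, with
    # iteration i reading from copies[i % num_rotations], so no in-graph launch
    # re-reads data that a previous launch left warm in L2.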

@bkryu (Collaborator, Author) commented Dec 12, 2025

/bot run

@coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
benchmarks/routines/gemm.py (1)

841-982: Fix mm_fp4: mat2_fp4_trtllm / mat2_inv_s_trtllm can be undefined when trtllm isn’t selected.
Because these are now included in input_args, they must exist regardless of backend choice.

Minimal fix: initialize defaults before the conditional.

@@
-    if "trtllm" in backends:
+    # Defaults so we can always thread these through input_args
+    mat2_fp4_trtllm = mat2_fp4
+    mat2_inv_s_trtllm = mat2_inv_s
+    if "trtllm" in backends:
         mat2_fp4_trtllm, mat2_inv_s_trtllm = flashinfer.nvfp4_quantize(
             mat2,
             global_sf_mat2,
             sfLayout=flashinfer.SfLayout.layout_128x4,
             do_shuffle=True,
         )
flashinfer/testing/utils.py (1)

880-1012: Critical: CUPTI-missing fallback drops input_args/input_kwargs causing incorrect behavior.
When CUPTI isn't installed or is < 13, and use_cuda_graph=False, the fallback to bench_gpu_time_with_cuda_event is called without passing the provided input_args and input_kwargs, while the use_cuda_graph=True fallback correctly passes them to bench_gpu_time_with_cudagraph. This causes the function to ignore user-provided arguments.

Fix:

         else:
             return bench_gpu_time_with_cuda_event(
                 fn=fn,
                 dry_run_iters=dry_run_iters,
                 repeat_iters=repeat_iters,
                 dry_run_time_ms=dry_run_time_ms,
                 repeat_time_ms=repeat_time_ms,
                 l2_flush=l2_flush,
                 l2_flush_size_mb=l2_flush_size_mb,
                 l2_flush_device=l2_flush_device,
                 sleep_after_run=sleep_after_run,
+                input_args=input_args,
+                input_kwargs=input_kwargs,
             )
benchmarks/routines/attention.py (1)

499-590: Guard unsupported backends and fix trtllm-gen KV stride selection (decode).

The decode routine supports only ["fa2", "fa2_tc", "trtllm-gen", "cudnn", "trtllm-native"], but the argument parser allows ["fa2", "fa2_tc", "fa3", "cudnn", "cutlass", "trtllm-gen", "trtllm-native", "trtllm-gen-native"]. Three unsupported backends slip through:

  • In the refcheck path (lines 542–552), calling .detach() on None will crash.
  • kv_cache_for_trt is prepared and FP8-converted (lines 358–497) but never passed to run_backend_wrapper; trtllm-gen receives the incorrectly-strided kv_cache instead.
  • bench_gpu_time is called unconditionally (line 566), benchmarking unsupported backends as no-ops.

The proposed patch correctly:

  • Conditionally passes kv_cache_for_trt to trtllm-gen
  • Guards refcheck against None with an early continue
  • Skips benchmarking of unsupported backends
🧹 Nitpick comments (3)
benchmarks/routines/attention.py (2)

22-43: Prefer warnings.warn (and consider de-duping) for deprecated backend renames.
print() is fine for a CLI, but warnings.warn(..., DeprecationWarning/UserWarning) is easier to filter and won’t get lost in perf logs. Also consider de-duping if --backends contains repeated deprecated entries.

Also applies to: 180-185


992-1099: rotate_buffers=True + “fat” input_args may defeat cold-L2 intent or blow up memory.
Right now input_args includes large scratch buffers (e.g., workspace_buffer) plus index tables; the rotation heuristic counts all CUDA tensors, which can (a) disable rotation because “inputs exceed 5×L2” even if the kernel’s real working-set doesn’t, or (b) clone huge scratch tensors many times if rotation triggers.

Suggestion: pass only the tensors that materially contribute to the memory-working-set into input_args (or add an exclusion mechanism in bench_gpu_time_with_cudagraph).

Also applies to: 1485-1607, 1942-2039

benchmarks/routines/moe.py (1)

928-952: Avoid selected_experts.to(torch.int) inside the timed callable.
Move the cast out of run_cutlass and pass the already-cast tensor in input_args_for_bench to avoid measuring conversion overhead.

Example:

@@
-    routing_weights, selected_experts = _compute_routing(router_logits, top_k)
+    routing_weights, selected_experts = _compute_routing(router_logits, top_k)
+    selected_experts_i32 = selected_experts.to(torch.int)
@@
-        def run_cutlass(x, selected_experts, routing_weights, w31_local, w2_local, out):
+        def run_cutlass(x, selected_experts_i32, routing_weights, w31_local, w2_local, out):
             return cutlass_fused_moe(
                 x,
-                selected_experts.to(torch.int),
+                selected_experts_i32,
                 routing_weights,
@@
-        input_args_for_bench = (x, selected_experts, routing_weights, w31_local, w2_local, out)
+        input_args_for_bench = (x, selected_experts_i32, routing_weights, w31_local, w2_local, out)

Also applies to: 988-1019, 1078-1104, 1123-1135

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1ac4e1d and 6194eac.

📒 Files selected for processing (4)
  • benchmarks/routines/attention.py (12 hunks)
  • benchmarks/routines/gemm.py (12 hunks)
  • benchmarks/routines/moe.py (15 hunks)
  • flashinfer/testing/utils.py (15 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/routines/gemm.py (1)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (1399-1537)
benchmarks/routines/attention.py (4)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (1399-1537)
benchmarks/bench_append_paged_kv_cache.py (1)
  • fn (116-127)
benchmarks/bench_append_paged_mla_kv_cache.py (1)
  • fn (100-111)
benchmarks/bench_fused_add_rmsnorm.py (1)
  • fn (41-42)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
benchmarks/routines/gemm.py (1)

265-305: Signature expansion + bench_gpu_time(input_args=...) wiring looks consistent.
The explicit argument passing should play nicely with the new bench_gpu_time entry point and avoids closure capture surprises.

Also applies to: 445-491, 648-700

benchmarks/routines/moe.py (1)

674-770: FP4 MoE runner signature threading into bench_gpu_time(input_args=...) is consistent.
Good alignment with the new benchmarking interface.

flashinfer/testing/utils.py (1)

731-877: Plumbing of input_args/input_kwargs through event/graph/unified entrypoint looks good.
The call_fn() abstraction makes behavior consistent across backends and aligns with the repo-wide signature expansions.

Also applies to: 1170-1397, 1399-1537

Comment on lines +38 to +211
def get_l2_cache_size(device=None) -> int:
    """
    Get L2 cache size in bytes for the given CUDA device.

    Args:
        device: CUDA device (int, torch.device, or None for current device).

    Returns:
        L2 cache size in bytes.
    """
    if device is None:
        device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    return props.L2_cache_size


def _calculate_tensor_bytes(tensors: List[torch.Tensor]) -> int:
    """
    Calculate total bytes of tensors residing on GPU.
    Assumes all tensors are on the same device.

    Args:
        tensors: List of torch.Tensor objects.

    Returns:
        Total bytes occupied by GPU tensors (CPU tensors are ignored).
    """
    total = 0
    for t in tensors:
        if isinstance(t, torch.Tensor) and t.is_cuda:
            total += t.numel() * t.element_size()
    return total


def _extract_gpu_tensors(obj) -> List[torch.Tensor]:
    """
    Recursively extract all GPU-resident tensors from a nested structure
    of lists, tuples, and dicts.

    Args:
        obj: Object to extract tensors from (can be tensor, list, tuple, dict, or other).

    Returns:
        Flat list of tensors on GPU found in the structure.
    """
    tensors = []
    if isinstance(obj, torch.Tensor) and obj.is_cuda:
        tensors.append(obj)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            tensors.extend(_extract_gpu_tensors(item))
    elif isinstance(obj, dict):
        for v in obj.values():
            tensors.extend(_extract_gpu_tensors(v))
    return tensors


def calculate_rotation_count(
    tensors: List[torch.Tensor], device=None, min_rotations: int = 2
) -> int:
    """
    Calculate the number of buffer copies needed to ensure cold L2 cache.

    The function uses conservative thresholds to account for:
    - LRU eviction being gradual (not all data evicted when capacity exceeded)
    - Cache associativity effects (some data may persist in non-conflicting sets)
    - Hardware prefetching behavior

    Returns 1 (no rotation needed) only when tensor size substantially exceeds
    L2 cache (>= 5x), ensuring cache effects are truly negligible.

    Args:
        tensors: List of tensors to consider for rotation (must be on GPU).
        device: Device for L2 cache query (None for current device).
        min_rotations: Minimum number of rotations when rotation is needed.

    Returns:
        Number of buffer copies needed (1 means no rotation needed).
    """
    l2_size = get_l2_cache_size(device)
    total_bytes = _calculate_tensor_bytes(tensors)

    if total_bytes == 0:
        return 1  # No tensors to rotate

    # Use aggressive threshold: only skip rotation if tensors far exceed L2 (5x)
    # This ensures cache effects are truly negligible even with prefetching
    safe_cache_threshold = l2_size * 5
    if total_bytes >= safe_cache_threshold:
        return 1  # Tensors far exceed L2, no rotation needed

    # Conservative formula: ensure between any two uses of the same buffer,
    # we've accessed enough data to fully flush L2 with margin
    # Using safe_cache_threshold ensures we account for all cache effects
    num_rotations = math.ceil(safe_cache_threshold / total_bytes) + 1

    return max(min_rotations, num_rotations)


def _clone_structure(obj):
    """
    Deep clone a nested structure, cloning GPU tensors with detach().clone()
    while preserving scalars, booleans, and other non-tensor values.

    For non-contiguous tensors (e.g., created with as_strided), this function
    preserves the stride pattern using torch.empty_strided() + copy_(). This is
    important for backends like cuDNN that expect specific memory layouts.

    Args:
        obj: Object to clone (tensor, list, tuple, dict, or other).

    Returns:
        Cloned structure with GPU tensors cloned, other values preserved.
    """
    if isinstance(obj, torch.Tensor):
        if obj.is_cuda:
            if obj.is_contiguous():
                return obj.detach().clone()
            else:
                # Preserve stride pattern for non-contiguous tensors
                # (e.g., as_strided views used by cuDNN paged attention)
                result = torch.empty_strided(
                    obj.size(),
                    obj.stride(),
                    dtype=obj.dtype,
                    device=obj.device,
                )
                result.copy_(obj.detach())
                return result
        else:
            return obj  # CPU tensors returned as-is
    elif isinstance(obj, list):
        return [_clone_structure(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(_clone_structure(item) for item in obj)
    elif isinstance(obj, dict):
        return {k: _clone_structure(v) for k, v in obj.items()}
    else:
        # Non-tensor, non-container: return as-is (e.g., int, float, str, bool, None)
        return obj


def _create_rotated_buffer_copies(
    input_args: Tuple, input_kwargs: dict, num_rotations: int
) -> List[Tuple[Tuple, dict]]:
    """
    Create multiple copies of input_args and input_kwargs for buffer rotation.

    The first copy (index 0) uses the original args/kwargs.
    Subsequent copies clone all GPU tensors while preserving other values.

    Args:
        input_args: Positional arguments tuple.
        input_kwargs: Keyword arguments dict.
        num_rotations: Number of buffer copies to create.

    Returns:
        List of (args, kwargs) tuples, one for each rotation index.
    """
    if num_rotations <= 1:
        return [(input_args, input_kwargs)]

    copies = []
    # First copy uses original args/kwargs
    copies.append((input_args, input_kwargs))

    # Create cloned copies for remaining rotations
    for _ in range(num_rotations - 1):
        cloned_args = _clone_structure(input_args)
        cloned_kwargs = _clone_structure(input_kwargs)
        copies.append((cloned_args, cloned_kwargs))

    return copies


⚠️ Potential issue | 🟠 Major

Cap rotation count to avoid pathological OOM (tiny inputs).
For very small total_bytes, ceil((5×L2)/total_bytes) can explode (thousands+ copies). Add a max_rotations cap and warn/clamp.
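
For a sense of scale (illustrative numbers, not taken from the PR): on a GPU with a 50 MB L2, safe_cache_threshold is 250 MB, so a 64 KB total input asks for roughly 4,000 buffer copies (ceil(250 MB / 64 KB) + 1), which is clearly worth clamping.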

Example patch:

@@
-def calculate_rotation_count(
-    tensors: List[torch.Tensor], device=None, min_rotations: int = 2
-) -> int:
+def calculate_rotation_count(
+    tensors: List[torch.Tensor],
+    device=None,
+    min_rotations: int = 2,
+    max_rotations: int = 64,
+) -> int:
@@
-    num_rotations = math.ceil(safe_cache_threshold / total_bytes) + 1
-
-    return max(min_rotations, num_rotations)
+    num_rotations = math.ceil(safe_cache_threshold / total_bytes) + 1
+    num_rotations = max(min_rotations, num_rotations)
+    if num_rotations > max_rotations:
+        warnings.warn(
+            f"Requested {num_rotations} rotating buffer copies (inputs={total_bytes} bytes, "
+            f"L2={l2_size} bytes); clamping to max_rotations={max_rotations}.",
+            category=UserWarning,
+            stacklevel=2,
+        )
+        num_rotations = max_rotations
+    return num_rotations
🤖 Prompt for AI Agents
flashinfer/testing/utils.py around lines 38 to 211: the current
calculate_rotation_count can produce an extremely large num_rotations when
total_bytes is tiny (ceil((5*L2)/total_bytes) explodes), risking OOM; clamp the
computed rotations with a reasonable max_rotations (either hard-coded, e.g. 16
or 32, or passed as an optional parameter) and log/warn when you clamp so
callers know rotation was limited; implement by adding a max_rotations argument
(default 16), compute num_rotations as now, then set num_rotations =
min(num_rotations, max_rotations) and emit a warning via warnings.warn or logger
if clamping occurred, finally return max(min_rotations, num_rotations).

@flashinfer-bot (Collaborator) commented

GitLab MR !193 has been created, and the CI pipeline #40118719 is currently running. I'll report back once the pipeline job completes.

    num_iters_within_graph (int): Kernel calls per graph (CUDA graph mode only,
        default: 10).
    rotate_buffers (bool): If True, use rotating buffers for cold-L2 benchmarking
        (CUDA graph mode only, default: False).
A collaborator commented:

Why do we have this constraint? From my understanding rotate_buffer should also work without CUDAGraph?

@bkryu (Collaborator, Author) replied Dec 12, 2025

There is no technical constraint, but I did not implement it because, when running outside of a CUDA graph, we can simply flush the L2 between runs with a buffer.zero_() (here):

    torch.cuda.synchronize()
    for iter_idx in range(repeat_iters):
        if l2_flush:
            buffer.zero_()
        start_events[iter_idx].record()
        call_fn()
        end_events[iter_idx].record()

It is the CUDA graph + CUDA events measurement that primarily requires rotating buffers, because we launch the kernel num_iters_within_graph times within the CUDA graph and cannot flush the cache in between.

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        if rotate_buffers and num_rotations > 1:
            # Capture with rotating buffers: use buffer[iter % num_rotations]
            for iter_idx in range(num_iters_within_graph):
                buf_idx = iter_idx % num_rotations
                call_fn_with_rotation(buf_idx)
        else:
            # Non-rotating capture (uses original args if provided)
            for _ in range(num_iters_within_graph):
                call_fn()
    torch.cuda.synchronize()
...
...
    torch.cuda.synchronize()
    for iter_idx in range(repeat_iters):
        if l2_flush:
            buffer.zero_()
        start_events[iter_idx].record()
        g.replay()
        end_events[iter_idx].record()

However, I can see the untidiness of users needing to think about L2 flush vs. rotating buffers.

Do you have any thoughts on whether or how we might want to clean this up?

@coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/testing/utils.py (1)

881-1011: Bug: CUPTI-unavailable fallback drops input_args/input_kwargs.

This makes bench_gpu_time(enable_cupti=True, ...) fail for fn that needs explicit args when CUPTI isn’t present.

         else:
             return bench_gpu_time_with_cuda_event(
                 fn=fn,
                 dry_run_iters=dry_run_iters,
                 repeat_iters=repeat_iters,
                 dry_run_time_ms=dry_run_time_ms,
                 repeat_time_ms=repeat_time_ms,
                 l2_flush=l2_flush,
                 l2_flush_size_mb=l2_flush_size_mb,
                 l2_flush_device=l2_flush_device,
                 sleep_after_run=sleep_after_run,
+                input_args=input_args,
+                input_kwargs=input_kwargs,
             )
♻️ Duplicate comments (1)
flashinfer/testing/utils.py (1)

95-135: Cap calculate_rotation_count to avoid pathological OOM for tiny inputs.

This still has the unbounded ceil((5×L2)/total_bytes) blow-up risk for small total_bytes (thousands+ copies). Please clamp and warn (as previously suggested).

 def calculate_rotation_count(
-    tensors: List[torch.Tensor], device=None, min_rotations: int = 2
+    tensors: List[torch.Tensor],
+    device=None,
+    min_rotations: int = 2,
+    max_rotations: int = 64,
 ) -> int:
@@
-    num_rotations = math.ceil(safe_cache_threshold / total_bytes) + 1
-
-    return max(min_rotations, num_rotations)
+    num_rotations = math.ceil(safe_cache_threshold / total_bytes) + 1
+    num_rotations = max(min_rotations, num_rotations)
+    if num_rotations > max_rotations:
+        warnings.warn(
+            f"Requested {num_rotations} rotating buffer copies (inputs={total_bytes} bytes, "
+            f"L2={l2_size} bytes); clamping to max_rotations={max_rotations}.",
+            category=UserWarning,
+            stacklevel=2,
+        )
+        num_rotations = max_rotations
+    return num_rotations
🧹 Nitpick comments (2)
flashinfer/testing/utils.py (2)

54-92: Dedup extracted GPU tensors to avoid double-counting and inconsistent rotation sizing.

If the same tensor object is referenced multiple times in input_args/input_kwargs, you’ll currently count it multiple times.

 def _extract_gpu_tensors(obj) -> List[torch.Tensor]:
@@
-    tensors = []
-    if isinstance(obj, torch.Tensor) and obj.is_cuda:
-        tensors.append(obj)
+    tensors: List[torch.Tensor] = []
+    seen: set[int] = set()
+    def visit(x):
+        if isinstance(x, torch.Tensor) and x.is_cuda:
+            k = id(x)
+            if k not in seen:
+                seen.add(k)
+                tensors.append(x)
+        elif isinstance(x, (list, tuple)):
+            for item in x:
+                visit(item)
+        elif isinstance(x, dict):
+            for v in x.values():
+                visit(v)
+
+    visit(obj)
-    elif isinstance(obj, (list, tuple)):
-        for item in obj:
-            tensors.extend(_extract_gpu_tensors(item))
-    elif isinstance(obj, dict):
-        for v in obj.values():
-            tensors.extend(_extract_gpu_tensors(v))
     return tensors

1275-1544: Use the inputs’ CUDA device for L2 sizing (and optionally guard misconfiguration of rotate_buffers).

Right now rotation sizing is keyed off l2_flush_device, which may differ from the tensors’ device in multi-GPU situations.

-            num_rotations = calculate_rotation_count(gpu_tensors, l2_flush_device)
+            num_rotations = calculate_rotation_count(gpu_tensors, gpu_tensors[0].device)

Optional (to reduce user confusion): warn if rotate_buffers=True while use_cuda_graph=False in bench_gpu_time, since it won’t do anything in the CUDA-event path.
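
A minimal sketch of such a guard (hypothetical placement inside bench_gpu_time; names follow the parameters already used in this file):

    import warnings  # shown here only to keep the snippet self-contained

    if rotate_buffers and not use_cuda_graph:
        warnings.warn(
            "rotate_buffers=True has no effect in the CUDA-event path; "
            "use use_cuda_graph=True for rotation, or rely on l2_flush between iterations.",
            UserWarning,
            stacklevel=2,
        )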

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6194eac and 531fd87.

📒 Files selected for processing (1)
  • flashinfer/testing/utils.py (15 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/testing/utils.py (1)

731-879: CUDA-event path arg plumbing looks consistent.

Comment on lines +38 to +52
def get_l2_cache_size(device=None) -> int:
    """
    Get L2 cache size in bytes for the given CUDA device.

    Args:
        device: CUDA device (int, torch.device, or None for current device).

    Returns:
        L2 cache size in bytes.
    """
    if device is None:
        device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    return props.L2_cache_size


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, check the exact content of the file and confirm the code snippet
head -60 flashinfer/testing/utils.py | tail -30

Repository: flashinfer-ai/flashinfer

Length of output: 922


🏁 Script executed:

# Search for get_l2_cache_size usage in the codebase
rg "get_l2_cache_size" --type py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 630


🏁 Script executed:

# Check for other device handling patterns in the file
rg "torch\.cuda\." flashinfer/testing/utils.py -B 1 -A 1

Repository: flashinfer-ai/flashinfer

Length of output: 3540


🏁 Script executed:

# Check if there are CUDA availability checks elsewhere in the codebase
rg "torch\.cuda\.is_available" --type py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 1583


🏁 Script executed:

# Check the context of get_l2_cache_size calls - especially in calculate_rotation_count
rg "calculate_rotation_count" --type py -B 5 -A 10

Repository: flashinfer-ai/flashinfer

Length of output: 3124


🏁 Script executed:

# Look at the full function that calls get_l2_cache_size
sed -n '60,100p' flashinfer/testing/utils.py

Repository: flashinfer-ai/flashinfer

Length of output: 1264


🏁 Script executed:

# Check if there's any CUDA requirement or guard at module level
head -30 flashinfer/testing/utils.py

Repository: flashinfer-ai/flashinfer

Length of output: 877


🌐 Web query:

torch.cuda.get_device_properties PyTorch device type handling str torch.device

💡 Result:

Short answer

  • torch.cuda.get_device_properties accepts a torch.device, an int (legacy ordinal), or a device string (e.g. "cuda", "cuda:0"). If device is None it uses the current CUDA device. The function is intended for CUDA devices (not CPU/mps), so pass a CUDA device (or you'll get an error). [1][2]

Example

  • torch.cuda.get_device_properties(torch.device("cuda:0"))
  • torch.cuda.get_device_properties("cuda:0")
  • torch.cuda.get_device_properties(0) # legacy ordinal
  • torch.cuda.get_device_properties(None) # uses current device

Refs

  • torch.cuda.get_device_properties docs (accepts torch.device / int / str; None → current device). [1]
  • torch.device description (device string formats like "cuda", "cuda:0", and semantics). [2]

🏁 Script executed:

# Verify the exact signature and behavior by checking PyTorch usage
# Let's see if there are any existing device handling patterns in flashinfer
rg "torch\.device\(" flashinfer/testing/utils.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if get_device_properties is called elsewhere with different device types
rg "get_device_properties" --type py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 5383


🏁 Script executed:

# Let's verify the actual torch.cuda behavior with a focused test
python3 << 'EOF'
import torch

# Check if torch.cuda.get_device_properties accepts various types
# (Note: we're just checking the signature, not actually running on GPU)
try:
    # Check the function signature
    import inspect
    sig = inspect.signature(torch.cuda.get_device_properties)
    print("Signature:", sig)
    print("\nDocstring snippet:")
    print(torch.cuda.get_device_properties.__doc__[:500])
except Exception as e:
    print(f"Error: {e}")
EOF

Repository: flashinfer-ai/flashinfer

Length of output: 154


Add CUDA availability check to get_l2_cache_size (optional defensive programming).

The function will fail in CPU-only environments. While get_l2_cache_size is currently only called from benchmarking code that already requires CUDA, a defensive check makes the function more robust if reused elsewhere:

 def get_l2_cache_size(device=None) -> int:
     """
     Get L2 cache size in bytes for the given CUDA device.
 
     Args:
         device: CUDA device (int, torch.device, or None for current device).
 
     Returns:
         L2 cache size in bytes.
     """
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available; cannot query L2 cache size.")
     if device is None:
         device = torch.cuda.current_device()
     props = torch.cuda.get_device_properties(device)
     return props.L2_cache_size

Note: torch.cuda.get_device_properties already accepts int, torch.device, and str device specifiers directly, so no device type normalization is needed.

🤖 Prompt for AI Agents
In flashinfer/testing/utils.py around lines 38 to 52, the function
get_l2_cache_size calls torch.cuda APIs and will raise on CPU-only systems; add
a defensive CUDA availability check at the top of the function (e.g., if not
torch.cuda.is_available(): raise RuntimeError("CUDA is not available; cannot get
L2 cache size")), keep the existing device handling (no normalization needed),
and return the L2 cache size as before.

Comment on lines +137 to +210
def _clone_structure(obj):
    """
    Deep clone a nested structure, cloning GPU tensors with detach().clone()
    while preserving scalars, booleans, and other non-tensor values.

    For non-contiguous tensors (e.g., created with as_strided), this function
    preserves the stride pattern using torch.empty_strided() + copy_(). This is
    important for backends like cuDNN that expect specific memory layouts.

    Args:
        obj: Object to clone (tensor, list, tuple, dict, or other).

    Returns:
        Cloned structure with GPU tensors cloned, other values preserved.
    """
    if isinstance(obj, torch.Tensor):
        if obj.is_cuda:
            if obj.is_contiguous():
                return obj.detach().clone()
            else:
                # Preserve stride pattern for non-contiguous tensors
                # (e.g., as_strided views used by cuDNN paged attention)
                result = torch.empty_strided(
                    obj.size(),
                    obj.stride(),
                    dtype=obj.dtype,
                    device=obj.device,
                )
                result.copy_(obj.detach())
                return result
        else:
            return obj  # CPU tensors returned as-is
    elif isinstance(obj, list):
        return [_clone_structure(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(_clone_structure(item) for item in obj)
    elif isinstance(obj, dict):
        return {k: _clone_structure(v) for k, v in obj.items()}
    else:
        # Non-tensor, non-container: return as-is (e.g., int, float, str, bool, None)
        return obj


def _create_rotated_buffer_copies(
    input_args: Tuple, input_kwargs: dict, num_rotations: int
) -> List[Tuple[Tuple, dict]]:
    """
    Create multiple copies of input_args and input_kwargs for buffer rotation.

    The first copy (index 0) uses the original args/kwargs.
    Subsequent copies clone all GPU tensors while preserving other values.

    Args:
        input_args: Positional arguments tuple.
        input_kwargs: Keyword arguments dict.
        num_rotations: Number of buffer copies to create.

    Returns:
        List of (args, kwargs) tuples, one for each rotation index.
    """
    if num_rotations <= 1:
        return [(input_args, input_kwargs)]

    copies = []
    # First copy uses original args/kwargs
    copies.append((input_args, input_kwargs))

    # Create cloned copies for remaining rotations
    for _ in range(num_rotations - 1):
        cloned_args = _clone_structure(input_args)
        cloned_kwargs = _clone_structure(input_kwargs)
        copies.append((cloned_args, cloned_kwargs))

    return copies

⚠️ Potential issue | 🟠 Major

Preserve tensor aliasing when cloning inputs (memoize clones; share memo across args+kwargs).

Right now, if a tensor is referenced multiple times (or shared between input_args and input_kwargs), each occurrence becomes a distinct clone in rotated copies. That can change semantics for alias-sensitive kernels and also inflates memory.

-def _clone_structure(obj):
+def _clone_structure(obj, _memo: Optional[dict[int, Any]] = None):
@@
-    if isinstance(obj, torch.Tensor):
+    if _memo is None:
+        _memo = {}
+    if isinstance(obj, torch.Tensor):
+        k = id(obj)
+        if k in _memo:
+            return _memo[k]
         if obj.is_cuda:
             if obj.is_contiguous():
-                return obj.detach().clone()
+                cloned = obj.detach().clone()
+                _memo[k] = cloned
+                return cloned
             else:
@@
-                result.copy_(obj.detach())
-                return result
+                result.copy_(obj.detach())
+                _memo[k] = result
+                return result
@@
-    elif isinstance(obj, list):
-        return [_clone_structure(item) for item in obj]
+    elif isinstance(obj, list):
+        return [_clone_structure(item, _memo) for item in obj]
     elif isinstance(obj, tuple):
-        return tuple(_clone_structure(item) for item in obj)
+        return tuple(_clone_structure(item, _memo) for item in obj)
     elif isinstance(obj, dict):
-        return {k: _clone_structure(v) for k, v in obj.items()}
+        return {k: _clone_structure(v, _memo) for k, v in obj.items()}
@@
 def _create_rotated_buffer_copies(
@@
     for _ in range(num_rotations - 1):
-        cloned_args = _clone_structure(input_args)
-        cloned_kwargs = _clone_structure(input_kwargs)
+        memo: dict[int, Any] = {}
+        cloned_args = _clone_structure(input_args, memo)
+        cloned_kwargs = _clone_structure(input_kwargs, memo)
         copies.append((cloned_args, cloned_kwargs))
