From 00dc383e9028082bc9231cc0b015f078aac986b7 Mon Sep 17 00:00:00 2001 From: Jinman Xie Date: Wed, 27 May 2026 15:28:02 -0700 Subject: [PATCH 1/4] benchmark_fn_cupti: surface kernel_name and kernel_times in result dict --- tests/common.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/common.py b/tests/common.py index 30d699f6..c2d66888 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1310,7 +1310,27 @@ def benchmark_fn_cupti( stats = torch.cuda.memory_stats() peak_mem_mb = stats["allocated_bytes.all.peak"] // (1024 * 1024) - return { + # Extract kernel names and times from the last profiler run. + # ``prof`` is still in scope here (the last loop iteration's profiler), so we + # can call key_averages() without launching an extra profiled run. + # All kernels with self_device_time_total > 0 are captured regardless of + # kernel_filter, giving a complete picture of every GPU kernel that ran. + # kernel_name is set to the single most time-consuming kernel overall. + cupti_kernel_times = [] + for _item in prof.key_averages(): + if _item.self_device_time_total > 0: + cupti_kernel_times.append( + { + "name": _item.key, + "self_time_us": _item.self_device_time_total, + "total_time_us": _item.device_time_total, + "count": int(_item.count), + } + ) + cupti_kernel_times.sort(key=lambda x: x["self_time_us"], reverse=True) + kernel_name = cupti_kernel_times[0]["name"] if cupti_kernel_times else None + + res = { "mean": times.mean().item(), "std": times.std().item(), "rel_std": (times.std() / times.mean()).item() * 100 if times.mean().item() > 0 else 0, @@ -1320,6 +1340,11 @@ def benchmark_fn_cupti( "nrep": len(times), "peak_mem_mb": peak_mem_mb, } + if kernel_name is not None: + res["kernel_name"] = kernel_name + if cupti_kernel_times: + res["kernel_times"] = cupti_kernel_times + return res def benchmark_framework(framework_name, framework_fn, **benchmark_kwargs): From f22ee8658e57268c6d67d87fe15f45689d0e5e55 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Tue, 2 Jun 2026 22:45:19 -0700 Subject: [PATCH 2/4] bench_softmax: add bf16 coverage --- tests/benchmark/bench_softmax.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/benchmark/bench_softmax.py b/tests/benchmark/bench_softmax.py index 50815bab..7c8b9c54 100644 --- a/tests/benchmark/bench_softmax.py +++ b/tests/benchmark/bench_softmax.py @@ -39,13 +39,14 @@ def get_supported_backends(): return [p for p in ALL_BACKENDS if p is not None] -def create_benchmark_config(M, use_tma=True, use_chunked=False, use_multi_wave=False): +def create_benchmark_config(M, dtype, use_tma=True, use_chunked=False, use_multi_wave=False): """Create a benchmark configuration for given parameters""" available_backends = get_supported_backends() if not available_backends: return None backends, names, styles = zip(*available_backends) + dtype_name = str(dtype).split(".")[-1] return triton.testing.Benchmark( x_names=["N"], @@ -55,15 +56,24 @@ def create_benchmark_config(M, use_tma=True, use_chunked=False, use_multi_wave=F line_names=list(names), styles=list(styles), ylabel="GB/s", - plot_name=f"softmax-performance-tma-{use_tma}-chunked-{use_chunked}-multi-wave-{use_multi_wave}-GBps", - args={"M": M, "use_tma": use_tma, "use_chunked": use_chunked, "use_multi_wave": use_multi_wave}, + plot_name=( + f"softmax-performance-{dtype_name}-tma-{use_tma}-chunked-{use_chunked}-multi-wave-{use_multi_wave}-GBps" + ), + args={ + "M": M, + "dtype": dtype, + "use_tma": use_tma, + "use_chunked": use_chunked, + "use_multi_wave": use_multi_wave, + }, ) @triton.testing.perf_report( [ - create_benchmark_config(M, use_tma, use_chunked, use_multi_wave) + create_benchmark_config(M, dtype, use_tma, use_chunked, use_multi_wave) for M in [4096] + for dtype in [torch.float32, torch.bfloat16] for use_tma, use_chunked, use_multi_wave in [ (False, False, False), # baseline (True, False, False), # TMA softmax @@ -72,7 +82,7 @@ def create_benchmark_config(M, use_tma=True, use_chunked=False, use_multi_wave=F ] ] ) -def bench_softmax(M, N, backend, use_tma, use_chunked, use_multi_wave, dtype=torch.float32, device=DEVICE): +def bench_softmax(M, N, backend, dtype, use_tma, use_chunked, use_multi_wave, device=DEVICE): # Create data x = torch.randn(M, N, dtype=dtype, device=device) From fa7f49811bf15993c68750f8da0b4fdf81b32957 Mon Sep 17 00:00:00 2001 From: Yifei Song Date: Tue, 2 Jun 2026 23:52:05 -0700 Subject: [PATCH 3/4] fix(flashinfer/moe): pass actual per-expert max_m to ragged_bmm --- .../cutile/gemm/ragged_block_scaled_bmm.py | 32 +++++++- .../flashinfer/cutile/gemm/ragged_bmm.py | 73 ++++++++++++++++--- 2 files changed, 91 insertions(+), 14 deletions(-) diff --git a/src/tilegym/suites/flashinfer/cutile/gemm/ragged_block_scaled_bmm.py b/src/tilegym/suites/flashinfer/cutile/gemm/ragged_block_scaled_bmm.py index 05be62d4..8421c670 100644 --- a/src/tilegym/suites/flashinfer/cutile/gemm/ragged_block_scaled_bmm.py +++ b/src/tilegym/suites/flashinfer/cutile/gemm/ragged_block_scaled_bmm.py @@ -28,7 +28,8 @@ def _ragged_block_scaled_bmm_kernel( c, # Output matrix C [total_m, N] m_indptr, # Segment offsets [Q+1], flattened 1D q, # Number of batches - max_m, # Max segment size + max_m, # Host-side max segment size hint (kept for autotune cache key) + max_m_device, # 1-element int32 tensor (shape (1,)) — device-side ground truth for max(valid_m) n, # Output N dimension HAS_A_SCALE: ct.Constant[int], # Whether a_scale is provided (0 or 1) BLOCK_M: ct.Constant[int], @@ -47,11 +48,17 @@ def _ragged_block_scaled_bmm_kernel( Uses persistent scheduling with static grid and GROUP_SIZE_M tile swizzling. Uses Array.slice + TMA (ct.load/ct.store) for A and C access. + + Defense-in-depth: the per-tile loop bound is computed from the device-side + `max_m_device` rather than the host-side `max_m`. This prevents silent output corruption when the caller passes + too small a host-side max_m hint. """ pid = ct.bid(0) num_k_tiles = ct.num_tiles(a, axis=1, shape=(BLOCK_M, BLOCK_K)) - num_pid_m = ct.cdiv(max_m, BLOCK_M) + # Override host max_m with device truth (see docstring). + max_m_runtime = ct.load(max_m_device, index=(0,), shape=(1,)).item() + num_pid_m = ct.cdiv(max_m_runtime, BLOCK_M) num_pid_n = ct.cdiv(n, BLOCK_N) tiles_per_batch = num_pid_m * num_pid_n total_tiles = tiles_per_batch * q @@ -179,7 +186,8 @@ def _ragged_block_scaled_bmm_swap_ab_kernel( c, # Output matrix C [total_m, N] m_indptr, # Segment offsets [Q+1], flattened 1D q, - max_m, + max_m, # Host-side max segment size hint (kept for autotune cache key) + max_m_device, # 1-element int32 tensor (shape (1,)) — device-side ground truth for max(valid_m) n, HAS_A_SCALE: ct.Constant[int], BLOCK_M: ct.Constant[int], @@ -190,11 +198,15 @@ def _ragged_block_scaled_bmm_swap_ab_kernel( """ cuTile kernel for ragged block-scaled BMM with swap_ab optimization. Uses Array.slice + TMA (ct.load/ct.store) for A and C access. + + Defense-in-depth: same as `_ragged_block_scaled_bmm_kernel` — the per-tile + loop bound is computed from `max_m_device` (device truth), not the host hint. """ pid = ct.bid(0) num_k_tiles = ct.num_tiles(a, axis=1, shape=(BLOCK_M, BLOCK_K)) - num_pid_m = ct.cdiv(max_m, BLOCK_M) + max_m_runtime = ct.load(max_m_device, index=(0,), shape=(1,)).item() + num_pid_m = ct.cdiv(max_m_runtime, BLOCK_M) num_pid_n = ct.cdiv(n, BLOCK_N) tiles_per_batch = num_pid_m * num_pid_n total_tiles = tiles_per_batch * q @@ -427,6 +439,12 @@ def ragged_block_scaled_bmm( ): """ cuTile implementation of ragged block-scaled BMM. + + `max_m_device` is an optional [1]-shape int tensor with the device-side + ground truth for max(per-batch valid_m). When provided, the kernel uses it + for its persistent-loop bound — preventing silent corruption if the host-side `max_m` hint + underestimates the actual per-batch max. When None, a fallback tensor is + materialized from `max_m`. """ # Validate inputs assert transpose_a == False and transpose_b == True, "Only NT layout is supported" @@ -458,6 +476,11 @@ def ragged_block_scaled_bmm( out_dtype = torch.bfloat16 c = torch.empty((total_m, N), device=a.device, dtype=out_dtype) + # Materialize fallback max_m_device if the caller didn't pass one. The + # kernel always reads its grid bound from a device tensor (defense-in-depth). + if max_m_device is None: + max_m_device = torch.tensor([max_m], dtype=torch.int32, device=a.device) + # Get kernel configs default_configs = _get_default_kernel_configs(total_m, Q, VEC_SIZE) kernel_configs = get_kernel_configs(default_configs, kwargs.get("kernel_configs")) @@ -509,6 +532,7 @@ def ragged_block_scaled_bmm( m_indptr, Q, max_m, + max_m_device, N, has_a_scale, BLOCK_M, diff --git a/src/tilegym/suites/flashinfer/cutile/gemm/ragged_bmm.py b/src/tilegym/suites/flashinfer/cutile/gemm/ragged_bmm.py index f26b1b39..1568f165 100644 --- a/src/tilegym/suites/flashinfer/cutile/gemm/ragged_bmm.py +++ b/src/tilegym/suites/flashinfer/cutile/gemm/ragged_bmm.py @@ -25,7 +25,8 @@ def _ragged_bmm_kernel( c, # Output matrix C [total_m, N] m_indptr, # Segment offsets [Q+1], flattened 1D q, # Number of batches - max_m, # Max segment size + max_m, # Host-side max segment size hint (kept for autotune cache key) + max_m_device, # 1-element int32 tensor (shape (1,)) — device-side ground truth for max(valid_m) n, # Output N dimension TRANSPOSE_A: ct.Constant[int], # Whether A is transposed (0 or 1) TRANSPOSE_B: ct.Constant[int], # Whether B is transposed (0 or 1) @@ -44,6 +45,11 @@ def _ragged_bmm_kernel( Uses persistent scheduling with static grid and GROUP_SIZE_M tile swizzling. Uses Array.slice + TMA (ct.load/ct.store) for A and C access. + + Defense-in-depth: the per-tile loop bound is computed from the device-side + `max_m_device` rather than the host-side `max_m`. This prevents silent output corruption when the caller passes + too small a host-side max_m hint (rows beyond cdiv(max_m, BLOCK_M)*BLOCK_M + would otherwise never be processed). """ pid = ct.bid(0) @@ -51,7 +57,9 @@ def _ragged_bmm_kernel( num_k_tiles = ct.num_tiles(a, axis=0, shape=(BLOCK_K, BLOCK_M)) else: num_k_tiles = ct.num_tiles(a, axis=1, shape=(BLOCK_M, BLOCK_K)) - num_pid_m = ct.cdiv(max_m, BLOCK_M) + # Override host max_m with device truth (see docstring). + max_m_runtime = ct.load(max_m_device, index=(0,), shape=(1,)).item() + num_pid_m = ct.cdiv(max_m_runtime, BLOCK_M) num_pid_n = ct.cdiv(n, BLOCK_N) tiles_per_batch = num_pid_m * num_pid_n total_tiles = tiles_per_batch * q @@ -157,7 +165,8 @@ def _ragged_bmm_swap_ab_kernel( c, # Output matrix C [total_m, N] m_indptr, # Segment offsets [Q+1], flattened 1D q, # Number of batches - max_m, # Max segment size + max_m, # Host-side max segment size hint (kept for autotune cache key) + max_m_device, # 1-element int32 tensor (shape (1,)) — device-side ground truth for max(valid_m) n, # Output N dimension TRANSPOSE_A: ct.Constant[int], # Whether A is transposed (0 or 1) TRANSPOSE_B: ct.Constant[int], # Whether B is transposed (0 or 1) @@ -173,6 +182,9 @@ def _ragged_bmm_swap_ab_kernel( when M dimension is small. Equivalent to: dot(B^T.T, A.T).T = A @ B^T Uses Array.slice + TMA (ct.load/ct.store) for A and C access. + + Defense-in-depth: same as `_ragged_bmm_kernel` — the per-tile loop bound + is computed from `max_m_device` (device truth), not the host hint. """ pid = ct.bid(0) @@ -180,7 +192,8 @@ def _ragged_bmm_swap_ab_kernel( num_k_tiles = ct.num_tiles(a, axis=0, shape=(BLOCK_K, BLOCK_M)) else: num_k_tiles = ct.num_tiles(a, axis=1, shape=(BLOCK_M, BLOCK_K)) - num_pid_m = ct.cdiv(max_m, BLOCK_M) + max_m_runtime = ct.load(max_m_device, index=(0,), shape=(1,)).item() + num_pid_m = ct.cdiv(max_m_runtime, BLOCK_M) num_pid_n = ct.cdiv(n, BLOCK_N) tiles_per_batch = num_pid_m * num_pid_n total_tiles = tiles_per_batch * q @@ -412,7 +425,9 @@ def _get_default_kernel_configs(): } -def _ragged_bmm_autotune_standard(stream, a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b): +def _ragged_bmm_autotune_standard( + stream, a, b, c, m_indptr, Q, max_m, max_m_device, N, total_m, transpose_a, transpose_b +): """ Autotuned launch for standard ragged BMM kernel. """ @@ -434,6 +449,7 @@ def args_fn(cfg): m_indptr, Q, max_m, + max_m_device, N, transpose_a_int, transpose_b_int, @@ -475,7 +491,9 @@ def hints_fn(cfg): ct.launch(stream, grid_fn(best_cfg), tuned_kernel, args_fn(best_cfg)) -def _ragged_bmm_autotune_swap_ab(stream, a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b): +def _ragged_bmm_autotune_swap_ab( + stream, a, b, c, m_indptr, Q, max_m, max_m_device, N, total_m, transpose_a, transpose_b +): """ Autotuned launch for swap_ab ragged BMM kernel. """ @@ -497,6 +515,7 @@ def args_fn(cfg): m_indptr, Q, max_m, + max_m_device, N, transpose_a_int, transpose_b_int, @@ -563,8 +582,13 @@ def ragged_bmm( a: Input matrix A, flattened [total_m, K] or [K, total_m] if transpose_a b: Input matrix B, batched [Q, N, K] or [Q, K, N] if not transpose_b m_indptr: Segment offsets tensor [Q+1] - max_m: Maximum segment size - max_m_device: Optional device tensor with max_m (unused in cuTile, kept for API compatibility) + max_m: Host-side maximum segment size hint (used for grid sizing and + autotune cache key). Should be >= max per-batch valid_m. + max_m_device: Optional [1]-shape int tensor with the device-side ground + truth for max(valid_m). When provided, the kernel uses this value + for its persistent-loop bound — making the kernel robust to a + host-side max_m underestimate. When None, a fallback tensor is + materialized from `max_m`. transpose_a: Whether A is transposed transpose_b: Whether B is transposed out_dtype: Output dtype @@ -594,6 +618,12 @@ def ragged_bmm( out_dtype = a.dtype c = torch.empty((total_m, N), device=a.device, dtype=out_dtype) + # Materialize fallback max_m_device if the caller didn't pass one. The + # kernel always reads its grid bound from a device tensor (defense-in-depth), + # so we keep the call sites uniform. + if max_m_device is None: + max_m_device = torch.tensor([max_m], dtype=torch.int32, device=a.device) + # Check if autotune is enabled enable_autotune = is_autotune_enabled() @@ -604,11 +634,33 @@ def ragged_bmm( if enable_autotune: if use_swap_ab: _ragged_bmm_autotune_swap_ab( - torch.cuda.current_stream(), a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b + torch.cuda.current_stream(), + a, + b, + c, + m_indptr, + Q, + max_m, + max_m_device, + N, + total_m, + transpose_a, + transpose_b, ) else: _ragged_bmm_autotune_standard( - torch.cuda.current_stream(), a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b + torch.cuda.current_stream(), + a, + b, + c, + m_indptr, + Q, + max_m, + max_m_device, + N, + total_m, + transpose_a, + transpose_b, ) else: # Use fixed default configs @@ -659,6 +711,7 @@ def ragged_bmm( m_indptr, Q, max_m, + max_m_device, N, transpose_a_int, transpose_b_int, From 8d0f0fb5a0468226c6e33e81fb683fcfc4cdba38 Mon Sep 17 00:00:00 2001 From: Zhiwei Fang Date: Wed, 3 Jun 2026 16:15:55 -0700 Subject: [PATCH 4/4] Remove grad from test_matmul.py pytorch backend --- tests/ops/test_matmul.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/ops/test_matmul.py b/tests/ops/test_matmul.py index 89a1f9c1..826b83bc 100644 --- a/tests/ops/test_matmul.py +++ b/tests/ops/test_matmul.py @@ -186,7 +186,10 @@ def test_perf( "use_tma": use_tma, } if backend == "pytorch": - backend_fn = lambda: self.reference(a, b, transpose_a, transpose_b) + # Detach so the output does not require grad; this keeps the benchmark to the forward pass only. + _a = a.detach() + _b = b.detach() + backend_fn = lambda: self.reference(_a, _b, transpose_a, transpose_b) elif tilegym.is_backend_available(backend): tilegym.set_backend(backend) if backend == "cutile" and transpose_b: