Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def _ragged_block_scaled_bmm_kernel(
c, # Output matrix C [total_m, N]
m_indptr, # Segment offsets [Q+1], flattened 1D
q, # Number of batches
max_m, # Max segment size
max_m, # Host-side max segment size hint (kept for autotune cache key)
max_m_device, # 1-element int32 tensor (shape (1,)) — device-side ground truth for max(valid_m)
n, # Output N dimension
HAS_A_SCALE: ct.Constant[int], # Whether a_scale is provided (0 or 1)
BLOCK_M: ct.Constant[int],
Expand All @@ -47,11 +48,17 @@ def _ragged_block_scaled_bmm_kernel(

Uses persistent scheduling with static grid and GROUP_SIZE_M tile swizzling.
Uses Array.slice + TMA (ct.load/ct.store) for A and C access.

Defense-in-depth: the per-tile loop bound is computed from the device-side
`max_m_device` rather than the host-side `max_m`. This prevents silent output corruption when the caller passes
too small a host-side max_m hint.
"""
pid = ct.bid(0)

num_k_tiles = ct.num_tiles(a, axis=1, shape=(BLOCK_M, BLOCK_K))
num_pid_m = ct.cdiv(max_m, BLOCK_M)
# Override host max_m with device truth (see docstring).
max_m_runtime = ct.load(max_m_device, index=(0,), shape=(1,)).item()
num_pid_m = ct.cdiv(max_m_runtime, BLOCK_M)
num_pid_n = ct.cdiv(n, BLOCK_N)
tiles_per_batch = num_pid_m * num_pid_n
total_tiles = tiles_per_batch * q
Expand Down Expand Up @@ -179,7 +186,8 @@ def _ragged_block_scaled_bmm_swap_ab_kernel(
c, # Output matrix C [total_m, N]
m_indptr, # Segment offsets [Q+1], flattened 1D
q,
max_m,
max_m, # Host-side max segment size hint (kept for autotune cache key)
max_m_device, # 1-element int32 tensor (shape (1,)) — device-side ground truth for max(valid_m)
n,
HAS_A_SCALE: ct.Constant[int],
BLOCK_M: ct.Constant[int],
Expand All @@ -190,11 +198,15 @@ def _ragged_block_scaled_bmm_swap_ab_kernel(
"""
cuTile kernel for ragged block-scaled BMM with swap_ab optimization.
Uses Array.slice + TMA (ct.load/ct.store) for A and C access.

Defense-in-depth: same as `_ragged_block_scaled_bmm_kernel` — the per-tile
loop bound is computed from `max_m_device` (device truth), not the host hint.
"""
pid = ct.bid(0)

num_k_tiles = ct.num_tiles(a, axis=1, shape=(BLOCK_M, BLOCK_K))
num_pid_m = ct.cdiv(max_m, BLOCK_M)
max_m_runtime = ct.load(max_m_device, index=(0,), shape=(1,)).item()
num_pid_m = ct.cdiv(max_m_runtime, BLOCK_M)
num_pid_n = ct.cdiv(n, BLOCK_N)
tiles_per_batch = num_pid_m * num_pid_n
total_tiles = tiles_per_batch * q
Expand Down Expand Up @@ -427,6 +439,12 @@ def ragged_block_scaled_bmm(
):
"""
cuTile implementation of ragged block-scaled BMM.

`max_m_device` is an optional [1]-shape int tensor with the device-side
ground truth for max(per-batch valid_m). When provided, the kernel uses it
for its persistent-loop bound — preventing silent corruption if the host-side `max_m` hint
underestimates the actual per-batch max. When None, a fallback tensor is
materialized from `max_m`.
"""
# Validate inputs
assert transpose_a == False and transpose_b == True, "Only NT layout is supported"
Expand Down Expand Up @@ -458,6 +476,11 @@ def ragged_block_scaled_bmm(
out_dtype = torch.bfloat16
c = torch.empty((total_m, N), device=a.device, dtype=out_dtype)

# Materialize fallback max_m_device if the caller didn't pass one. The
# kernel always reads its grid bound from a device tensor (defense-in-depth).
if max_m_device is None:
max_m_device = torch.tensor([max_m], dtype=torch.int32, device=a.device)

# Get kernel configs
default_configs = _get_default_kernel_configs(total_m, Q, VEC_SIZE)
kernel_configs = get_kernel_configs(default_configs, kwargs.get("kernel_configs"))
Expand Down Expand Up @@ -509,6 +532,7 @@ def ragged_block_scaled_bmm(
m_indptr,
Q,
max_m,
max_m_device,
N,
has_a_scale,
BLOCK_M,
Expand Down
73 changes: 63 additions & 10 deletions src/tilegym/suites/flashinfer/cutile/gemm/ragged_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def _ragged_bmm_kernel(
c, # Output matrix C [total_m, N]
m_indptr, # Segment offsets [Q+1], flattened 1D
q, # Number of batches
max_m, # Max segment size
max_m, # Host-side max segment size hint (kept for autotune cache key)
max_m_device, # 1-element int32 tensor (shape (1,)) — device-side ground truth for max(valid_m)
n, # Output N dimension
TRANSPOSE_A: ct.Constant[int], # Whether A is transposed (0 or 1)
TRANSPOSE_B: ct.Constant[int], # Whether B is transposed (0 or 1)
Expand All @@ -44,14 +45,21 @@ def _ragged_bmm_kernel(

Uses persistent scheduling with static grid and GROUP_SIZE_M tile swizzling.
Uses Array.slice + TMA (ct.load/ct.store) for A and C access.

Defense-in-depth: the per-tile loop bound is computed from the device-side
`max_m_device` rather than the host-side `max_m`. This prevents silent output corruption when the caller passes
too small a host-side max_m hint (rows beyond cdiv(max_m, BLOCK_M)*BLOCK_M
would otherwise never be processed).
"""
pid = ct.bid(0)

if TRANSPOSE_A == 1:
num_k_tiles = ct.num_tiles(a, axis=0, shape=(BLOCK_K, BLOCK_M))
else:
num_k_tiles = ct.num_tiles(a, axis=1, shape=(BLOCK_M, BLOCK_K))
num_pid_m = ct.cdiv(max_m, BLOCK_M)
# Override host max_m with device truth (see docstring).
max_m_runtime = ct.load(max_m_device, index=(0,), shape=(1,)).item()
num_pid_m = ct.cdiv(max_m_runtime, BLOCK_M)
num_pid_n = ct.cdiv(n, BLOCK_N)
tiles_per_batch = num_pid_m * num_pid_n
total_tiles = tiles_per_batch * q
Expand Down Expand Up @@ -157,7 +165,8 @@ def _ragged_bmm_swap_ab_kernel(
c, # Output matrix C [total_m, N]
m_indptr, # Segment offsets [Q+1], flattened 1D
q, # Number of batches
max_m, # Max segment size
max_m, # Host-side max segment size hint (kept for autotune cache key)
max_m_device, # 1-element int32 tensor (shape (1,)) — device-side ground truth for max(valid_m)
n, # Output N dimension
TRANSPOSE_A: ct.Constant[int], # Whether A is transposed (0 or 1)
TRANSPOSE_B: ct.Constant[int], # Whether B is transposed (0 or 1)
Expand All @@ -173,14 +182,18 @@ def _ragged_bmm_swap_ab_kernel(
when M dimension is small. Equivalent to: dot(B^T.T, A.T).T = A @ B^T

Uses Array.slice + TMA (ct.load/ct.store) for A and C access.

Defense-in-depth: same as `_ragged_bmm_kernel` — the per-tile loop bound
is computed from `max_m_device` (device truth), not the host hint.
"""
pid = ct.bid(0)

if TRANSPOSE_A == 1:
num_k_tiles = ct.num_tiles(a, axis=0, shape=(BLOCK_K, BLOCK_M))
else:
num_k_tiles = ct.num_tiles(a, axis=1, shape=(BLOCK_M, BLOCK_K))
num_pid_m = ct.cdiv(max_m, BLOCK_M)
max_m_runtime = ct.load(max_m_device, index=(0,), shape=(1,)).item()
num_pid_m = ct.cdiv(max_m_runtime, BLOCK_M)
num_pid_n = ct.cdiv(n, BLOCK_N)
tiles_per_batch = num_pid_m * num_pid_n
total_tiles = tiles_per_batch * q
Expand Down Expand Up @@ -412,7 +425,9 @@ def _get_default_kernel_configs():
}


def _ragged_bmm_autotune_standard(stream, a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b):
def _ragged_bmm_autotune_standard(
stream, a, b, c, m_indptr, Q, max_m, max_m_device, N, total_m, transpose_a, transpose_b
):
"""
Autotuned launch for standard ragged BMM kernel.
"""
Expand All @@ -434,6 +449,7 @@ def args_fn(cfg):
m_indptr,
Q,
max_m,
max_m_device,
N,
transpose_a_int,
transpose_b_int,
Expand Down Expand Up @@ -475,7 +491,9 @@ def hints_fn(cfg):
ct.launch(stream, grid_fn(best_cfg), tuned_kernel, args_fn(best_cfg))


def _ragged_bmm_autotune_swap_ab(stream, a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b):
def _ragged_bmm_autotune_swap_ab(
stream, a, b, c, m_indptr, Q, max_m, max_m_device, N, total_m, transpose_a, transpose_b
):
"""
Autotuned launch for swap_ab ragged BMM kernel.
"""
Expand All @@ -497,6 +515,7 @@ def args_fn(cfg):
m_indptr,
Q,
max_m,
max_m_device,
N,
transpose_a_int,
transpose_b_int,
Expand Down Expand Up @@ -563,8 +582,13 @@ def ragged_bmm(
a: Input matrix A, flattened [total_m, K] or [K, total_m] if transpose_a
b: Input matrix B, batched [Q, N, K] or [Q, K, N] if not transpose_b
m_indptr: Segment offsets tensor [Q+1]
max_m: Maximum segment size
max_m_device: Optional device tensor with max_m (unused in cuTile, kept for API compatibility)
max_m: Host-side maximum segment size hint (used for grid sizing and
autotune cache key). Should be >= max per-batch valid_m.
max_m_device: Optional [1]-shape int tensor with the device-side ground
truth for max(valid_m). When provided, the kernel uses this value
for its persistent-loop bound — making the kernel robust to a
host-side max_m underestimate. When None, a fallback tensor is
materialized from `max_m`.
transpose_a: Whether A is transposed
transpose_b: Whether B is transposed
out_dtype: Output dtype
Expand Down Expand Up @@ -594,6 +618,12 @@ def ragged_bmm(
out_dtype = a.dtype
c = torch.empty((total_m, N), device=a.device, dtype=out_dtype)

# Materialize fallback max_m_device if the caller didn't pass one. The
# kernel always reads its grid bound from a device tensor (defense-in-depth),
# so we keep the call sites uniform.
if max_m_device is None:
max_m_device = torch.tensor([max_m], dtype=torch.int32, device=a.device)

# Check if autotune is enabled
enable_autotune = is_autotune_enabled()

Expand All @@ -604,11 +634,33 @@ def ragged_bmm(
if enable_autotune:
if use_swap_ab:
_ragged_bmm_autotune_swap_ab(
torch.cuda.current_stream(), a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b
torch.cuda.current_stream(),
a,
b,
c,
m_indptr,
Q,
max_m,
max_m_device,
N,
total_m,
transpose_a,
transpose_b,
)
else:
_ragged_bmm_autotune_standard(
torch.cuda.current_stream(), a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b
torch.cuda.current_stream(),
a,
b,
c,
m_indptr,
Q,
max_m,
max_m_device,
N,
total_m,
transpose_a,
transpose_b,
)
else:
# Use fixed default configs
Expand Down Expand Up @@ -659,6 +711,7 @@ def ragged_bmm(
m_indptr,
Q,
max_m,
max_m_device,
N,
transpose_a_int,
transpose_b_int,
Expand Down
20 changes: 15 additions & 5 deletions tests/benchmark/bench_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ def get_supported_backends():
return [p for p in ALL_BACKENDS if p is not None]


def create_benchmark_config(M, use_tma=True, use_chunked=False, use_multi_wave=False):
def create_benchmark_config(M, dtype, use_tma=True, use_chunked=False, use_multi_wave=False):
"""Create a benchmark configuration for given parameters"""
available_backends = get_supported_backends()
if not available_backends:
return None

backends, names, styles = zip(*available_backends)
dtype_name = str(dtype).split(".")[-1]

return triton.testing.Benchmark(
x_names=["N"],
Expand All @@ -55,15 +56,24 @@ def create_benchmark_config(M, use_tma=True, use_chunked=False, use_multi_wave=F
line_names=list(names),
styles=list(styles),
ylabel="GB/s",
plot_name=f"softmax-performance-tma-{use_tma}-chunked-{use_chunked}-multi-wave-{use_multi_wave}-GBps",
args={"M": M, "use_tma": use_tma, "use_chunked": use_chunked, "use_multi_wave": use_multi_wave},
plot_name=(
f"softmax-performance-{dtype_name}-tma-{use_tma}-chunked-{use_chunked}-multi-wave-{use_multi_wave}-GBps"
),
args={
"M": M,
"dtype": dtype,
"use_tma": use_tma,
"use_chunked": use_chunked,
"use_multi_wave": use_multi_wave,
},
)


@triton.testing.perf_report(
[
create_benchmark_config(M, use_tma, use_chunked, use_multi_wave)
create_benchmark_config(M, dtype, use_tma, use_chunked, use_multi_wave)
for M in [4096]
for dtype in [torch.float32, torch.bfloat16]
for use_tma, use_chunked, use_multi_wave in [
(False, False, False), # baseline
(True, False, False), # TMA softmax
Expand All @@ -72,7 +82,7 @@ def create_benchmark_config(M, use_tma=True, use_chunked=False, use_multi_wave=F
]
]
)
def bench_softmax(M, N, backend, use_tma, use_chunked, use_multi_wave, dtype=torch.float32, device=DEVICE):
def bench_softmax(M, N, backend, dtype, use_tma, use_chunked, use_multi_wave, device=DEVICE):
# Create data
x = torch.randn(M, N, dtype=dtype, device=device)

Expand Down
27 changes: 26 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,27 @@ def benchmark_fn_cupti(
stats = torch.cuda.memory_stats()
peak_mem_mb = stats["allocated_bytes.all.peak"] // (1024 * 1024)

return {
# Extract kernel names and times from the last profiler run.
# ``prof`` is still in scope here (the last loop iteration's profiler), so we
# can call key_averages() without launching an extra profiled run.
# All kernels with self_device_time_total > 0 are captured regardless of
# kernel_filter, giving a complete picture of every GPU kernel that ran.
# kernel_name is set to the single most time-consuming kernel overall.
cupti_kernel_times = []
for _item in prof.key_averages():
if _item.self_device_time_total > 0:
cupti_kernel_times.append(
{
"name": _item.key,
"self_time_us": _item.self_device_time_total,
"total_time_us": _item.device_time_total,
"count": int(_item.count),
}
)
cupti_kernel_times.sort(key=lambda x: x["self_time_us"], reverse=True)
kernel_name = cupti_kernel_times[0]["name"] if cupti_kernel_times else None

res = {
"mean": times.mean().item(),
"std": times.std().item(),
"rel_std": (times.std() / times.mean()).item() * 100 if times.mean().item() > 0 else 0,
Expand All @@ -1320,6 +1340,11 @@ def benchmark_fn_cupti(
"nrep": len(times),
"peak_mem_mb": peak_mem_mb,
}
if kernel_name is not None:
res["kernel_name"] = kernel_name
if cupti_kernel_times:
res["kernel_times"] = cupti_kernel_times
return res


def benchmark_framework(framework_name, framework_fn, **benchmark_kwargs):
Expand Down
5 changes: 4 additions & 1 deletion tests/ops/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,10 @@ def test_perf(
"use_tma": use_tma,
}
if backend == "pytorch":
backend_fn = lambda: self.reference(a, b, transpose_a, transpose_b)
# Detach so the output does not require grad; this keeps the benchmark to the forward pass only.
_a = a.detach()
_b = b.detach()
backend_fn = lambda: self.reference(_a, _b, transpose_a, transpose_b)
elif tilegym.is_backend_available(backend):
tilegym.set_backend(backend)
if backend == "cutile" and transpose_b:
Expand Down
Loading