diff --git a/benchmarks/ops/bench_mamba.py b/benchmarks/ops/bench_mamba.py index eddc413e2..ca79851ce 100644 --- a/benchmarks/ops/bench_mamba.py +++ b/benchmarks/ops/bench_mamba.py @@ -2,6 +2,7 @@ import pytest import torch +import torch.nn.functional as F from benchmarks.benchmark_base import BenchmarkBase, BenchmarkReport from tileops.ops.da_cumsum import DaCumsumFwdOp @@ -25,6 +26,11 @@ # --------------------------------------------------------------------------- # Optional mamba_ssm Triton baselines # --------------------------------------------------------------------------- +try: + from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd as _mamba_chunk_cumsum_fwd +except ImportError: + _mamba_chunk_cumsum_fwd = None + try: from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd as _mamba_chunk_scan_fwd except ImportError: @@ -48,15 +54,31 @@ def da_cumsum_fwd_ref( A: torch.Tensor, num_chunks: int, chunk_len: int, -) -> torch.Tensor: - """PyTorch reference for da_cumsum_fwd (benchmark-local copy).""" + dt_bias: torch.Tensor | None = None, + dt_softplus: bool = False, + dt_min: float = 0.0, + dt_max: float = float("inf"), +) -> tuple[torch.Tensor, torch.Tensor]: + """PyTorch reference for da_cumsum_fwd (benchmark-local copy). + + Returns: + dt_out: (batch, n_heads, num_chunks, chunk_len) float32 + dA_cumsum: (batch, n_heads, num_chunks, chunk_len) float32 + """ b, S, h = dt.shape Q = chunk_len C = num_chunks - dt_chunked = dt.float().reshape(b, C, Q, h) + dt_val = dt.float() + if dt_bias is not None: + dt_val = dt_val + dt_bias.float() + if dt_softplus: + dt_val = F.softplus(dt_val) + dt_val = torch.clamp(dt_val, min=dt_min, max=dt_max) + dt_chunked = dt_val.reshape(b, C, Q, h) + dt_out = dt_chunked.permute(0, 3, 1, 2).contiguous() # (b, h, C, Q) dA = dt_chunked * A.float() - dA_cumsum = dA.cumsum(dim=2) - return dA_cumsum.permute(0, 3, 1, 2).contiguous() + dA_cumsum = dA.cumsum(dim=2).permute(0, 3, 1, 2).contiguous() # (b, h, C, Q) + return dt_out, dA_cumsum class DaCumsumFwdBenchmark(BenchmarkBase[DaCumsumFwdTest]): @@ -64,36 +86,69 @@ class DaCumsumFwdBenchmark(BenchmarkBase[DaCumsumFwdTest]): def calculate_flops(self) -> Optional[float]: t = self.workload b, c, L, h = t.batch, t.num_chunks, t.chunk_len, t.n_heads - # One multiply (dt * A) and one add per element for the inclusive scan - # Total: 2 * b * c * L * h - return float(2 * b * c * L * h) + # Core ops per element: 1 mul (dt*A) + 1 add (cumsum) = 2 + # Optional bias add: +1; optional softplus (exp+log+add): +3; clamp (min+max): +2 + bias_ops = 1 if t.has_dt_bias else 0 + softplus_ops = 3 if t.dt_softplus else 0 + ops_per_elem = 2 + bias_ops + softplus_ops + 2 # +2 for clamp always + return float(ops_per_elem * b * c * L * h) def calculate_memory(self) -> Optional[float]: t = self.workload b, c, L, h = t.batch, t.num_chunks, t.chunk_len, t.n_heads - # float32 throughout - elem = 4 - # Reads: dt (b, c*L, h) + A (h,) - reads = (b * c * L * h + h) * elem - # Writes: dA_cumsum (b, h, c, L) - writes = b * h * c * L * elem + elem = 4 # float32 + # Reads: dt_raw (b, c*L, h) + A (h,) + optional dt_bias (h,) + reads = (b * c * L * h + h + (h if t.has_dt_bias else 0)) * elem + # Writes: dt_out (b, h, c, L) + dA_cumsum (b, h, c, L) + writes = 2 * b * h * c * L * elem return float(reads + writes) @DaCumsumFwdFixture -def test_da_cumsum_fwd_bench(batch, num_chunks, chunk_len, n_heads, tune): - test = DaCumsumFwdTest(batch, num_chunks, chunk_len, n_heads) +def test_da_cumsum_fwd_bench(batch, num_chunks, chunk_len, 
n_heads, has_dt_bias, dt_softplus, tune): + test = DaCumsumFwdTest( + batch, num_chunks, chunk_len, n_heads, + has_dt_bias=has_dt_bias, dt_softplus=dt_softplus, + ) bm = DaCumsumFwdBenchmark(test) - inputs = test.gen_inputs() - - op = DaCumsumFwdOp(batch, num_chunks, chunk_len, n_heads, seq_len=num_chunks * chunk_len, tune=tune) + inputs = test.gen_inputs() # (dt_raw, A, dt_bias) + + op = DaCumsumFwdOp( + batch, num_chunks, chunk_len, n_heads, + seq_len=num_chunks * chunk_len, + has_dt_bias=has_dt_bias, + dt_softplus=dt_softplus, + tune=tune, + ) result = bm.profile(op, *inputs) BenchmarkReport.record(op, locals(), result, tag="tileops") - def baseline(dt, A): - return da_cumsum_fwd_ref(dt, A, num_chunks, chunk_len) - result_bl = bm.profile(baseline, *inputs) - BenchmarkReport.record(op, locals(), result_bl, tag="torch-ref") + # ── Mamba-2 Triton baseline ── + # _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=...) + # returns (dA_cumsum, dt_out) — note reversed order vs TileOPs (dt_out, dA_cumsum) + if _mamba_chunk_cumsum_fwd is not None: + mamba_dt_bias = inputs[2] if has_dt_bias else None + + def mamba_fwd(): + return _mamba_chunk_cumsum_fwd( + inputs[0].contiguous(), + inputs[1].contiguous(), + chunk_len, + dt_bias=mamba_dt_bias.contiguous() if mamba_dt_bias is not None else None, + dt_softplus=dt_softplus, + ) + + result_mamba = bm.profile(mamba_fwd) + BenchmarkReport.record(op, locals(), result_mamba, tag="mamba") + else: + def baseline(dt_raw, A, dt_bias): + return da_cumsum_fwd_ref( + dt_raw, A, num_chunks, chunk_len, + dt_bias=dt_bias if has_dt_bias else None, + dt_softplus=dt_softplus, + ) + result_bl = bm.profile(baseline, *inputs) + BenchmarkReport.record(op, locals(), result_bl, tag="torch-ref") def ssd_chunk_scan_fwd_ref(x, cb, dA_cumsum, C, prev_states, dt, n_groups): @@ -418,16 +473,21 @@ def ssd_state_passing_fwd_ref( dA_chunk_cumsum: torch.Tensor, initial_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """PyTorch reference for ssd_state_passing_fwd (benchmark-local copy).""" + """PyTorch reference for ssd_state_passing_fwd (benchmark-local copy). + + Matches mamba convention: out[:,c] = state *before* processing chunk c, + so out[:,0] = initial_states and final_states = state after chunk C-1. + """ b, c, h, d = states.shape - out = [] + out = [initial_states.float().clone()] s = initial_states.float() for ci in range(c): scale = torch.exp(dA_chunk_cumsum[:, :, ci]).unsqueeze(-1) u = states[:, ci, :, :].float() s = scale * s + u - out.append(s.clone()) + if ci < c - 1: + out.append(s.clone()) return torch.stack(out, dim=1), s diff --git a/tests/ops/test_mamba.py b/tests/ops/test_mamba.py index f6e13e7d9..9201e6f1f 100644 --- a/tests/ops/test_mamba.py +++ b/tests/ops/test_mamba.py @@ -1,5 +1,6 @@ import pytest import torch +import torch.nn.functional as F from tests.test_base import TestBase, allclose_compare from tileops.ops.da_cumsum import DaCumsumFwdOp @@ -28,34 +29,78 @@ def da_cumsum_fwd_ref( A: torch.Tensor, num_chunks: int, chunk_len: int, -) -> torch.Tensor: - """PyTorch reference for da_cumsum_fwd.""" + dt_bias: torch.Tensor | None = None, + dt_softplus: bool = False, + dt_min: float = 0.0, + dt_max: float = float("inf"), +) -> tuple[torch.Tensor, torch.Tensor]: + """PyTorch reference for da_cumsum_fwd. + + Applies the same bias / softplus / clamp pipeline as the kernel, then + computes dt_out and the chunk-local inclusive prefix sum of dA = dt_out * A. 
+ + Returns: + dt_out: (batch, n_heads, num_chunks, chunk_len) float32 + dA_cumsum: (batch, n_heads, num_chunks, chunk_len) float32 + """ b, S, h = dt.shape Q = chunk_len C = num_chunks - dt_chunked = dt.float().reshape(b, C, Q, h) - dA = dt_chunked * A.float() - dA_cumsum = dA.cumsum(dim=2) - return dA_cumsum.permute(0, 3, 1, 2).contiguous() + dt_val = dt.float() + if dt_bias is not None: + dt_val = dt_val + dt_bias.float() + if dt_softplus: + dt_val = F.softplus(dt_val) + dt_val = torch.clamp(dt_val, min=dt_min, max=dt_max) + dt_chunked = dt_val.reshape(b, C, Q, h) # (b, C, Q, h) + dt_out = dt_chunked.permute(0, 3, 1, 2).contiguous() # (b, h, C, Q) + dA = dt_chunked * A.float() # (b, C, Q, h) + dA_cumsum = dA.cumsum(dim=2).permute(0, 3, 1, 2).contiguous() # (b, h, C, Q) + return dt_out, dA_cumsum class DaCumsumFwdTest(_DaCumsumFwdTestWorkload, TestBase): - def ref_program(self, dt, A): - return da_cumsum_fwd_ref(dt, A, self.num_chunks, self.chunk_len) + def ref_program(self, dt, A, dt_bias): + return da_cumsum_fwd_ref( + dt, A, self.num_chunks, self.chunk_len, + dt_bias=dt_bias if self.has_dt_bias else None, + dt_softplus=self.dt_softplus, + dt_min=self.dt_min, + dt_max=self.dt_max, + ) @DaCumsumFwdFixture -def test_da_cumsum_fwd(batch, num_chunks, chunk_len, n_heads, tune): - test = DaCumsumFwdTest(batch, num_chunks, chunk_len, n_heads) +def test_da_cumsum_fwd(batch, num_chunks, chunk_len, n_heads, has_dt_bias, dt_softplus, tune): + test = DaCumsumFwdTest( + batch, num_chunks, chunk_len, n_heads, + has_dt_bias=has_dt_bias, dt_softplus=dt_softplus, + ) op = DaCumsumFwdOp( batch, num_chunks, chunk_len, n_heads, seq_len=num_chunks * chunk_len, + has_dt_bias=has_dt_bias, + dt_softplus=dt_softplus, tune=tune, ) inputs = test.gen_inputs() test.check(op, *inputs, atol=1e-5, rtol=1e-5) +@pytest.mark.smoke +def test_da_cumsum_fwd_missing_bias_raises(): + """DaCumsumFwdKernel must raise when has_dt_bias=True but dt_bias is None.""" + from tileops.kernels.mamba import DaCumsumFwdKernel + kernel = DaCumsumFwdKernel( + batch=1, num_chunks=2, chunk_len=64, n_heads=4, + seq_len=128, has_dt_bias=True, + ) + dt = torch.randn(1, 128, 4, dtype=torch.float32, device="cuda") + A = -torch.rand(4, dtype=torch.float32, device="cuda") + with pytest.raises(ValueError, match="dt_bias is required"): + kernel(dt, A, dt_bias=None) + + def ssd_chunk_scan_fwd_ref(x, cb, dA_cumsum, C, prev_states, dt, n_groups): """Official-aligned PyTorch reference for chunk scan. @@ -197,16 +242,22 @@ def ssd_state_passing_fwd_ref( dA_chunk_cumsum: torch.Tensor, initial_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """PyTorch reference for the inter-chunk recurrent scan.""" + """PyTorch reference for the inter-chunk recurrent scan. + + Matches mamba convention: out[:,c] = state *before* processing chunk c, + so out[:,0] = initial_states and final_states = state after chunk C-1. 
+ """ b, c, h, d = states.shape - out = [] + # out[:,0] = s_{-1} = initial_states (state before chunk 0) + out = [initial_states.float().clone()] s = initial_states.float() for ci in range(c): scale = torch.exp(dA_chunk_cumsum[:, :, ci]).unsqueeze(-1) u = states[:, ci, :, :].float() s = scale * s + u - out.append(s.clone()) + if ci < c - 1: + out.append(s.clone()) return torch.stack(out, dim=1), s diff --git a/tileops/kernels/mamba/da_cumsum.py b/tileops/kernels/mamba/da_cumsum.py index ba36ab139..3336f25f9 100644 --- a/tileops/kernels/mamba/da_cumsum.py +++ b/tileops/kernels/mamba/da_cumsum.py @@ -2,24 +2,25 @@ Mamba-2 dA_cumsum forward kernel. Inputs: - dt: (batch, seq_len, n_heads) -- per-position discretization factor (float32) + dt: (batch, seq_len, n_heads) -- raw per-position dt (float32) A: (n_heads,) -- State Space Model (SSM) decay parameter (float32) + dt_bias: (n_heads,) -- optional per-head dt bias (float32) -Output: - dA_cumsum: (batch, n_heads, num_chunks, chunk_len) -- float32 +Outputs: + dt_out: (batch, n_heads, num_chunks, chunk_len) -- float32, processed dt after bias/softplus/clamp + dA_cumsum: (batch, n_heads, num_chunks, chunk_len) -- float32, inclusive prefix sum of dA = dt_out * A -For each (b, h, c, l), the kernel computes the inclusive prefix sum within each chunk: +For each (b, h, c, l), the kernel computes: - dA_cumsum[b, h, c, l] = sum_{i=0}^{l} dt[b, c*Q + i, h] * A[h] + dt_val = dt[b, c*Q + l, h] + if has_dt_bias: dt_val += dt_bias[h] + if dt_softplus: dt_val = softplus(dt_val) # with bypass for dt_val > 20 + dt_val = clamp(dt_val, dt_min, dt_max) + dt_out[b,h,c,l] = dt_val + dA_cumsum[b,h,c,l] = sum_{i=0}^{l} dt_out[b,h,c,i] * A[h] -This matches A_cumsum in the Mamba-2 ssd_minimal_discrete reference: - - A_cumsum = torch.cumsum(dt * A, dim=-1) # per-chunk, inclusive prefix sum - -The output is consumed by: - - ssd_chunk_scan_fwd: exp(dA_cumsum[l] - dA_cumsum[s]) scales the intra-chunk causal path - - ssd_chunk_state_fwd: exp(dA_cumsum[Q-1] - dA_cumsum[l]) * dt[l] gives per-position decay - - ssd_state_passing_fwd: dA_cumsum[..., Q-1] is the per-chunk scalar inter-chunk decay +This matches _chunk_cumsum_fwd_kernel in the Mamba-2 Triton reference +(mamba_ssm/ops/triton/ssd_chunk_state.py). Alignment with Mamba-2 paper: In ssd_minimal_discrete, A already absorbs dt (A = dt * A_log), so A_cumsum = cumsum(A). 
@@ -31,7 +32,8 @@ B = batch, S = seq_len = C * Q, H = n_heads, C = num_chunks, Q = chunk_len """ -from typing import Callable, Optional +import functools +from typing import Callable, Optional, Tuple import tilelang import tilelang.language as T @@ -42,12 +44,17 @@ __all__ = ["DaCumsumFwdKernel"] +@functools.lru_cache(maxsize=32) def _da_cumsum_fwd_kernel( batch: int, num_chunks: int, chunk_len: int, n_heads: int, seq_len: int, + dt_softplus: bool = False, + has_dt_bias: bool = False, + dt_min: float = 0.0, + dt_max: float = float("inf"), ) -> Callable: accum_dtype = "float" @@ -57,15 +64,17 @@ def _da_cumsum_fwd_kernel( H = n_heads S = seq_len - @tilelang.jit(out_idx=[-1]) + @tilelang.jit(out_idx=[-2, -1]) def kernel_func(threads: int): @T.prim_func def main( - dt: T.Tensor((B, S, H), accum_dtype), # type: ignore - A: T.Tensor((H,), accum_dtype), # type: ignore - dA_cumsum: T.Tensor((B, H, C, Q), accum_dtype), # type: ignore + dt: T.Tensor((B, S, H), accum_dtype), # type: ignore # raw dt input + A: T.Tensor((H,), accum_dtype), # type: ignore + dt_bias: T.Tensor((H,), accum_dtype), # type: ignore # may be dummy zeros if not has_dt_bias + dt_out: T.Tensor((B, H, C, Q), accum_dtype), # type: ignore # output: processed dt + dA_cumsum: T.Tensor((B, H, C, Q), accum_dtype), # type: ignore # output: inclusive cumsum ): - # Grid: one block per (batch, head, chunk) + # Grid: one block per (batch, head, chunk). # The serial scan over Q positions runs within each block. with T.Kernel(B, H, C, threads=threads) as (bb, bh, bc): # Load the per-head decay parameter once (scalar, constant across chunk). @@ -77,14 +86,36 @@ def main( for l in T.serial(Q): seq_idx = bc * Q + l - # Zero-pad positions that fall beyond the actual sequence length - # (handles the tail chunk when S is not a multiple of Q). in_bounds = seq_idx < S + + # Step 1: load raw dt; zero-pad out-of-bounds tail positions. dt_val = T.if_then_else( in_bounds, dt[bb, seq_idx, bh], T.float32(0.0), ) + + # Step 2: add per-head bias (compile-time conditional). + if has_dt_bias: + dt_val = dt_val + dt_bias[bh] + + # Step 3: softplus with large-value bypass (compile-time conditional). + # Uses log(1 + exp(x)) for x <= 20; identity for x > 20 to avoid overflow. + if dt_softplus: + dt_val = T.if_then_else( + dt_val <= T.float32(20.0), + T.log(T.float32(1.0) + T.exp(dt_val)), + dt_val, + ) + + # Step 4: clamp to [dt_min, dt_max]. + dt_val = T.min(T.max(dt_val, T.float32(dt_min)), T.float32(dt_max)) + + # Step 5: re-apply out-of-bounds zero mask after bias/softplus/clamp. + dt_val = T.if_then_else(in_bounds, dt_val, T.float32(0.0)) + + # Step 6: store processed dt and accumulate dA_cumsum. 
+ dt_out[bb, bh, bc, l] = dt_val running[0] = running[0] + dt_val * dA_head dA_cumsum[bb, bh, bc, l] = running[0] @@ -101,12 +132,18 @@ def _da_cumsum_fwd_wrapped( n_heads: int, seq_len: int, threads: int, + dt_softplus: bool, + has_dt_bias: bool, + dt_min: float, + dt_max: float, dt: torch.Tensor, A: torch.Tensor, -) -> torch.Tensor: - return _da_cumsum_fwd_kernel(batch, num_chunks, chunk_len, n_heads, seq_len)( - threads, - )(dt, A) + dt_bias: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _da_cumsum_fwd_kernel( + batch, num_chunks, chunk_len, n_heads, seq_len, + dt_softplus, has_dt_bias, dt_min, dt_max, + )(threads)(dt, A, dt_bias) @_da_cumsum_fwd_wrapped.register_fake @@ -117,24 +154,33 @@ def _( n_heads: int, seq_len: int, threads: int, + dt_softplus: bool, + has_dt_bias: bool, + dt_min: float, + dt_max: float, dt: torch.Tensor, A: torch.Tensor, -) -> torch.Tensor: - return dt.new_empty((batch, n_heads, num_chunks, chunk_len), dtype=torch.float32) + dt_bias: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + dt_out = dt.new_empty((batch, n_heads, num_chunks, chunk_len), dtype=torch.float32) + dA_cumsum = dt.new_empty((batch, n_heads, num_chunks, chunk_len), dtype=torch.float32) + return dt_out, dA_cumsum class DaCumsumFwdKernel(Kernel): """Mamba-2 dA_cumsum forward kernel. - Computes the chunk-local inclusive prefix sum of dA = dt * A: - - dA_cumsum[b, h, c, l] = sum_{i=0}^{l} dt[b, c*Q+i, h] * A[h] + Applies optional per-head bias, optional softplus activation, and clamping to + raw dt values, then computes the chunk-local inclusive prefix sum of dA = dt * A. - This matches A_cumsum from the Mamba-2 ssd_minimal_discrete reference. + Inputs: + dt (batch, seq_len, n_heads) float32 — raw dt values. + A (n_heads,) float32 — State Space Model (SSM) decay parameters. + dt_bias (n_heads,) float32 — per-head dt bias; required when has_dt_bias=True. - Inputs: dt (batch, seq_len, n_heads), float32 - A (n_heads,), float32 - Output: dA_cumsum (batch, n_heads, num_chunks, chunk_len), float32 + Outputs: + dt_out (batch, n_heads, num_chunks, chunk_len) float32 — processed dt. + dA_cumsum (batch, n_heads, num_chunks, chunk_len) float32 — inclusive prefix sum. """ supported_archs: list[int] = [80, 86, 89, 90] @@ -146,6 +192,10 @@ def __init__( chunk_len: int, n_heads: int, seq_len: int, + dt_softplus: bool = False, + has_dt_bias: bool = False, + dt_min: float = 0.0, + dt_max: float = float("inf"), config: Optional[dict] = None, tune: bool = False, ) -> None: @@ -155,9 +205,15 @@ def __init__( self.chunk_len = chunk_len self.n_heads = n_heads self.seq_len = seq_len - # All inputs and output are always float32; no separate dtype parameter needed. + self.dt_softplus = dt_softplus + self.has_dt_bias = has_dt_bias + self.dt_min = dt_min + self.dt_max = dt_max self.dtype = torch.float32 - self.kernel = _da_cumsum_fwd_kernel(batch, num_chunks, chunk_len, n_heads, seq_len) + self.kernel = _da_cumsum_fwd_kernel( + batch, num_chunks, chunk_len, n_heads, seq_len, + dt_softplus, has_dt_bias, dt_min, dt_max, + ) self.init_config(config, tune) @property @@ -168,22 +224,42 @@ def default_config(self) -> dict: @property def autotune_configs(self) -> list[dict]: - # For small batch/head configs, intra-block parallelism can improve occupancy. - # Warp-level scan with __shfl_up could reduce Q=64 from 64 serial steps to ~6. 
- return [ - {"threads": 1}, - {"threads": 32}, - {"threads": 64}, - {"threads": 128}, - ] + # The inner scan is T.serial(Q): every thread executes the same loop + # and writes to the same locations. Multiple threads cause redundant + # work and write contention with no benefit. threads=1 is the only + # valid configuration until the scan is parallelised (e.g. warp-level + # __shfl_up reduce). + return [{"threads": 1}] def forward( self, dt: torch.Tensor, A: torch.Tensor, - ) -> torch.Tensor: + dt_bias: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Run the dA_cumsum forward pass. + + Args: + dt: (batch, seq_len, n_heads) float32 — raw dt values. + A: (n_heads,) float32 — SSM decay parameters. + dt_bias: (n_heads,) float32, optional — per-head dt bias. + Required when the kernel was constructed with has_dt_bias=True. + + Returns: + dt_out: (batch, n_heads, num_chunks, chunk_len) float32 — processed dt. + dA_cumsum: (batch, n_heads, num_chunks, chunk_len) float32 — inclusive prefix sum. + """ + dt = dt.contiguous() + A = A.contiguous() + if self.has_dt_bias and dt_bias is None: + raise ValueError("dt_bias is required when has_dt_bias=True") + # Allocate a dummy zero bias when has_dt_bias=False so the kernel + # signature stays fixed regardless of the compile-time flag. + dt_bias = dt.new_zeros(self.n_heads) if dt_bias is None else dt_bias.contiguous() + return _da_cumsum_fwd_wrapped( self.batch, self.num_chunks, self.chunk_len, self.n_heads, self.seq_len, self.config["threads"], - dt, A, + self.dt_softplus, self.has_dt_bias, self.dt_min, self.dt_max, + dt, A, dt_bias, ) diff --git a/tileops/kernels/mamba/ssd_chunk_scan.py b/tileops/kernels/mamba/ssd_chunk_scan.py index 6ee0bddf4..988cb46df 100644 --- a/tileops/kernels/mamba/ssd_chunk_scan.py +++ b/tileops/kernels/mamba/ssd_chunk_scan.py @@ -52,6 +52,7 @@ N = d_state, C = num_chunks, Q = chunk_len """ +import functools import itertools from typing import Callable, Optional @@ -64,6 +65,7 @@ __all__ = ["SSDChunkScanFwdKernel"] +@functools.lru_cache(maxsize=32) def _ssd_chunk_scan_fwd_kernel( batch: int, num_chunks: int, @@ -135,18 +137,6 @@ def main( acc = T.alloc_fragment((block_l, block_p), accum_dtype) T.clear(acc) - # load target-side dA_cumsum[b,h,c,l] for this l-tile - # alloc_shared so it is visible across all thread-parallel loops - dA_l = T.alloc_shared((block_l,), accum_dtype) - for ll in T.Parallel(block_l): - l_abs = l0 + ll - dA_l[ll] = T.if_then_else( - l_abs < Q, - dA_cumsum[bz, bh, bc, l_abs], - T.float32(0.0), - ) - T.sync_threads() - # ===================================================== # PART 1: history path # acc[l,p] += exp(dA_l[l]) * sum_n C[l,g,n] * prev_states[h,p,n] @@ -171,7 +161,9 @@ def main( ) # prev_states[b, c, h, p, n] layout: [B, C, H, P, N] float32 - for nn, pp in T.Parallel(block_n, block_p): + # Iterate (block_p, block_n) so consecutive threads vary nn (the contiguous N + # dim), giving coalesced 128-byte loads instead of strided-by-N accesses. 
+ for pp, nn in T.Parallel(block_p, block_n): n_abs = n0 + nn p_abs = p0 + pp state_tile[nn, pp] = T.if_then_else( @@ -183,6 +175,17 @@ def main( # hist_acc += c_tile @ state_tile T.gemm(c_tile, state_tile, hist_acc) + # Load dA_l here, just before it's needed (avoids a sync before the N-loop) + dA_l = T.alloc_shared((block_l,), accum_dtype) + for ll in T.Parallel(block_l): + l_abs = l0 + ll + dA_l[ll] = T.if_then_else( + l_abs < Q, + dA_cumsum[bz, bh, bc, l_abs], + T.float32(0.0), + ) + T.sync_threads() + # scale by exp(dA_l[l]) and accumulate into acc for ll, pp in T.Parallel(block_l, block_p): l_abs = l0 + ll @@ -387,8 +390,8 @@ def default_config(self) -> dict: def autotune_configs(self) -> list[dict]: block_l = [32, 64] block_p = [32, 64] - block_n = [16, 32] - block_s = [32, 64] + block_n = [32, 64, 128] + block_s = [64, 128] threads = [128, 256] return [ {"block_l": c[0], "block_p": c[1], "block_n": c[2], "block_s": c[3], "threads": c[4]} diff --git a/tileops/kernels/mamba/ssd_chunk_state.py b/tileops/kernels/mamba/ssd_chunk_state.py index c4bee8c6e..0267e2d7b 100644 --- a/tileops/kernels/mamba/ssd_chunk_state.py +++ b/tileops/kernels/mamba/ssd_chunk_state.py @@ -64,6 +64,7 @@ N = d_state, C = num_chunks, Q = chunk_len """ +import functools import itertools from typing import Callable, Optional @@ -76,6 +77,7 @@ __all__ = ["SSDChunkStateFwdKernel"] +@functools.lru_cache(maxsize=32) def _ssd_chunk_state_fwd_kernel( batch: int, num_chunks: int, @@ -128,11 +130,21 @@ def main( ) as (bhc, bp, bn): # -------------------------------------------------------- - # 1. Decode fused axis + # 1. Decode fused axis (b, c, h — h is fastest-changing) + # + # Consecutive CTAs share the same (b, c), so they cover the + # same chunk rows in Bmat. When HEADS_PER_GROUP > 1, the + # HEADS_PER_GROUP consecutive h values that belong to the same + # group map to the same bg and therefore load identical b_tile + # data. Those loads are served from L2 after the first CTA + # warms the cache, reducing effective Bmat bandwidth by up to + # HEADS_PER_GROUP×. The alternative b,h,c order (c fastest) + # shifts chunk_start on every CTA step so no Bmat rows are + # reused between consecutive CTAs. # -------------------------------------------------------- - bz = bhc // (H * C) - bh = (bhc % (H * C)) // C - bc = bhc % C + bz = bhc // (C * H) + bc = (bhc % (C * H)) // H + bh = bhc % H n0 = bn * block_n p0 = bp * block_p @@ -162,20 +174,29 @@ def main( # x_scaled[l, p] = x[l, p] * w(l) (row-scaled x, dtype) # b_tile[l, n] = B[l, n] (unscaled, dtype) # - # w_tile[block_l] holds the per-position scalar weight in shared - # memory. It is filled by T.Parallel(block_l) — one load of - # dA_cumsum[l] and dt[l] per l — then T.sync_threads() makes it - # visible to the subsequent T.Parallel(block_l, block_p) loop, - # avoiding block_p redundant global loads per l. + # w_tile[block_l] holds the per-position scalar weight in + # shared memory. It is filled by T.Parallel(block_l) — one + # load of dA_cumsum[l] and dt[l] per l — then + # T.sync_threads() makes it visible to the subsequent + # T.Parallel(block_l, block_p) loop, avoiding block_p + # redundant global loads per l. 
+ # + # x_scaled is written directly as dtype (cast in-place): + # x_scaled[ll, pp] = cast(float(x[ll, pp]) * w_tile[ll]) + # This eliminates the x_scaled_f32 register fragment + # (block_l * block_p / 32 fp32 regs per thread) and the + # separate T.copy cast step, freeing register budget for + # larger output tiles without changing the shared-memory + # footprint or numerical behavior (the multiply is still + # done in fp32 before truncation to dtype). # # GEMM: acc[p, n] += x_scaled^T @ b_tile # i.e. (block_l x block_p)^T @ (block_l x block_n) # = (block_p x block_l) @ (block_l x block_n) # -------------------------------------------------------- - w_tile = T.alloc_shared((block_l,), accum_dtype) - x_scaled_f32 = T.alloc_fragment((block_l, block_p), accum_dtype) - x_scaled = T.alloc_shared((block_l, block_p), dtype) - b_tile = T.alloc_shared((block_l, block_n), dtype) + w_tile = T.alloc_shared((block_l,), accum_dtype) + x_scaled = T.alloc_shared((block_l, block_p), dtype) + b_tile = T.alloc_shared((block_l, block_n), dtype) # -------------------------------------------------------- # 5. Reduce over chunk positions in L-tiles @@ -212,8 +233,9 @@ def main( w_tile[ll] = T.exp(T.min(dA_end - dA_l, T.float32(0.0))) * dt_l T.sync_threads() - # 5.1 Compute x_scaled[ll, pp] = x[ll, pp] * w_tile[ll] - # w_tile[ll] is read from shared — one value reused block_p times. + # 5.1 Compute x_scaled[ll, pp] = cast(float(x[ll,pp]) * w_tile[ll]) + # Written directly to shared memory as dtype, bypassing + # the intermediate fp32 register fragment. for ll, pp in T.Parallel(block_l, block_p): l_idx = l0 + ll p_idx = p0 + pp @@ -222,10 +244,7 @@ def main( T.cast(x[bz, chunk_start + l_idx, bh, p_idx], accum_dtype), T.float32(0.0), ) - x_scaled_f32[ll, pp] = x_val * w_tile[ll] - - # Cast scaled-x to kernel dtype for tensor-core GEMM - T.copy(x_scaled_f32, x_scaled) + x_scaled[ll, pp] = T.cast(x_val * w_tile[ll], dtype) # 5.2 Cooperative load: B # b_tile[ll, nn] = Bmat[bz, chunk_start + l0 + ll, bg, n0 + nn] @@ -360,8 +379,12 @@ def __init__( @property def default_config(self) -> dict: + # (block_p=64, block_n=64, block_l=64) gives the highest arithmetic + # intensity for the GEMM: K=64 doubles MMA phases vs K=32, and the + # removal of the x_scaled_f32 register fragment (saving block_l * + # block_p / 32 fp32 regs per thread) makes this size register-safe. return { - "block_n": 32, + "block_n": 64, "block_p": 64, "block_l": 64, "threads": 128, @@ -369,8 +392,17 @@ def default_config(self) -> dict: @property def autotune_configs(self) -> list[dict]: - block_n = [16, 32] - block_p = [32, 64] + # Grid rationale: + # block_p in {16, 32, 64}: 16 is the minimum MMA M-atom; 64 only + # viable after the x_scaled_f32 fragment was removed. + # block_n in {64, 128}: aligns to common d_state values; 128 covers + # the full d_state in one tile and maximises N-reuse per L-tile. + # block_l in {32, 64}: larger K improves GEMM arithmetic intensity; + # 32 keeps shared-memory pressure low for small d_head configs. + # threads in {128, 256}: 128 warps = 4, 256 warps = 8; higher thread + # count hides latency but increases register/shared pressure. 
+ block_n = [64, 128] + block_p = [16, 32, 64] block_l = [32, 64] threads = [128, 256] _configs = list(itertools.product(block_n, block_p, block_l, threads)) diff --git a/tileops/kernels/mamba/ssd_decode.py b/tileops/kernels/mamba/ssd_decode.py index ee8822065..0191d7b80 100644 --- a/tileops/kernels/mamba/ssd_decode.py +++ b/tileops/kernels/mamba/ssd_decode.py @@ -39,6 +39,7 @@ B = batch, H = n_heads, P = d_head, N = d_state, G = n_groups """ +import functools from typing import Callable, Optional import tilelang @@ -107,6 +108,7 @@ # matching the behaviour of the optimised selective_state_update path. # ============================================================================= +@functools.lru_cache(maxsize=32) def _ssd_decode_kernel( batch: int, n_heads: int, diff --git a/tileops/kernels/mamba/ssd_state_passing.py b/tileops/kernels/mamba/ssd_state_passing.py index 4fd0e3601..0a2528b12 100644 --- a/tileops/kernels/mamba/ssd_state_passing.py +++ b/tileops/kernels/mamba/ssd_state_passing.py @@ -34,6 +34,7 @@ B = batch, C = num_chunks, H = n_heads, D = d_state """ +import functools import itertools from typing import Callable, Optional @@ -46,6 +47,7 @@ __all__ = ["SSDStatePassingFwdKernel"] +@functools.lru_cache(maxsize=32) def _ssd_state_passing_fwd_kernel( batch: int, num_chunks: int, @@ -97,6 +99,8 @@ def main( # ------------------------------------------------------------ # 2) Initialize running state from initial_states (or zero) + # and write s_{-1} = initial_states to out[:,0,:,:] + # (mamba convention: out[:,c] = state *before* chunk c) # ------------------------------------------------------------ for i in T.Parallel(block_d): di = d0 + i @@ -109,9 +113,17 @@ def main( else: s_frag[i] = T.float32(0.0) + # Write s_{-1} to out[:,0,:,:] + for i in T.Parallel(block_d): + di = d0 + i + if di < D: + out[bb, 0, bh, di] = s_frag[i] + # ------------------------------------------------------------ # 3) Scan over chunks serially # s_c = exp(dA_c) * s_{c-1} + u_c + # out[:,c+1] = s_c for c in [0, C-2] + # final_states = s_{C-1} # ------------------------------------------------------------ for c in T.serial(C): # load scalar dA_c and compute scale @@ -127,18 +139,19 @@ def main( T.float32(0.0), ) - # recurrent update + # recurrent update: s_c = scale * s_{c-1} + u_c for i in T.Parallel(block_d): s_frag[i] = scale * s_frag[i] + u_frag[i] - # write per-chunk output - for i in T.Parallel(block_d): - di = d0 + i - if di < D: - out[bb, c, bh, di] = s_frag[i] + # write s_c to out[:,c+1,:,:] for c < C-1 + if c < C - 1: + for i in T.Parallel(block_d): + di = d0 + i + if di < D: + out[bb, c + 1, bh, di] = s_frag[i] # ------------------------------------------------------------ - # 4) Write final state + # 4) Write final state s_{C-1} # ------------------------------------------------------------ for i in T.Parallel(block_d): di = d0 + i @@ -150,6 +163,7 @@ def main( return kernel_func + @torch.library.custom_op("top::ssd_state_passing_fwd", mutates_args=()) def _ssd_state_passing_fwd_wrapped( batch: int, @@ -165,9 +179,8 @@ def _ssd_state_passing_fwd_wrapped( initial_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: return _ssd_state_passing_fwd_kernel( - batch, num_chunks, n_heads, d_state, has_initial_states, dtype)( - block_d, threads, - )(states, dA_chunk_cumsum, initial_states) + batch, num_chunks, n_heads, d_state, has_initial_states, dtype, + )(block_d, threads)(states, dA_chunk_cumsum, initial_states) @_ssd_state_passing_fwd_wrapped.register_fake diff --git 
a/tileops/ops/da_cumsum.py b/tileops/ops/da_cumsum.py index b7aa62072..851fad6ca 100644 --- a/tileops/ops/da_cumsum.py +++ b/tileops/ops/da_cumsum.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import torch @@ -13,17 +13,20 @@ class DaCumsumFwdOp(Op): """Mamba-2 dA_cumsum forward operator. - Computes the chunk-local inclusive prefix sum of dA = dt * A: - - dA_cumsum[b, h, c, l] = sum_{i=0}^{l} dt[b, c*Q+i, h] * A[h] + Applies optional per-head bias, optional softplus activation, and clamping to + raw dt values, then computes the chunk-local inclusive prefix sum of dA = dt * A. Args: - batch: Batch size. - num_chunks: Number of chunks (seq_len / chunk_len). - chunk_len: Tokens per chunk. - n_heads: Number of attention heads. - seq_len: Total sequence length (= num_chunks * chunk_len). - tune: Whether to autotune tile config on construction. + batch: Batch size. + num_chunks: Number of chunks (seq_len / chunk_len). + chunk_len: Tokens per chunk. + n_heads: Number of attention heads. + seq_len: Total sequence length (= num_chunks * chunk_len). + dt_softplus: Whether to apply softplus (with bypass for dt > 20) to dt. + has_dt_bias: Whether a per-head dt_bias is added before softplus/clamp. + dt_min: Lower clamp bound applied after bias and softplus. + dt_max: Upper clamp bound applied after bias and softplus. + tune: Whether to autotune tile config on construction. """ def __init__( @@ -33,6 +36,10 @@ def __init__( chunk_len: int, n_heads: int, seq_len: int, + dt_softplus: bool = False, + has_dt_bias: bool = False, + dt_min: float = 0.0, + dt_max: float = float("inf"), tune: bool = False, kernel_map: Optional[Dict[str, Kernel]] = None, ): @@ -41,10 +48,19 @@ def __init__( self.chunk_len = chunk_len self.n_heads = n_heads self.seq_len = seq_len + self.dt_softplus = dt_softplus + self.has_dt_bias = has_dt_bias + self.dt_min = dt_min + self.dt_max = dt_max self.dtype = torch.float32 self.dispatch_kernel(kernel_map) self.kernel = self.kernel_map["da_cumsum_fwd"]( - batch, num_chunks, chunk_len, n_heads, seq_len, tune=tune, + batch, num_chunks, chunk_len, n_heads, seq_len, + dt_softplus=dt_softplus, + has_dt_bias=has_dt_bias, + dt_min=dt_min, + dt_max=dt_max, + tune=tune, ) @property @@ -55,15 +71,19 @@ def forward( self, dt: torch.Tensor, A: torch.Tensor, - ) -> torch.Tensor: + dt_bias: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: """Run the dA_cumsum forward pass. Args: - dt: (batch, seq_len, n_heads) float32 - A: (n_heads,) float32 + dt: (batch, seq_len, n_heads) float32 — raw dt values. + A: (n_heads,) float32 — SSM decay parameters. + dt_bias: (n_heads,) float32, optional — per-head dt bias. + Required when the op was constructed with has_dt_bias=True. Returns: - dA_cumsum: (batch, n_heads, num_chunks, chunk_len) float32 + dt_out: (batch, n_heads, num_chunks, chunk_len) float32 — processed dt. + dA_cumsum: (batch, n_heads, num_chunks, chunk_len) float32 — inclusive prefix sum. 
""" if not dt.is_cuda: raise ValueError("dt must be a CUDA tensor") @@ -73,4 +93,4 @@ def forward( dt = dt.contiguous() A = A.contiguous() - return self.kernel(dt, A) + return self.kernel(dt, A, dt_bias) diff --git a/workloads/mamba.py b/workloads/mamba.py index 9af217962..8d23e7fc2 100644 --- a/workloads/mamba.py +++ b/workloads/mamba.py @@ -8,11 +8,21 @@ class DaCumsumFwdFixture(FixtureBase): def get_params(cls): import pytest return [ - ("batch, num_chunks, chunk_len, n_heads, tune", [ - pytest.param(1, 2, 64, 4, False, marks=pytest.mark.smoke), - pytest.param(2, 4, 64, 8, False, marks=pytest.mark.full), - pytest.param(1, 2, 128, 4, False, marks=pytest.mark.full), - pytest.param(2, 4, 128, 16, False, marks=pytest.mark.full), + ("batch, num_chunks, chunk_len, n_heads, has_dt_bias, dt_softplus, tune", [ + # feature: no bias, no softplus (baseline path) + pytest.param(1, 2, 64, 4, False, False, False, marks=pytest.mark.smoke), + # feature: bias only (has_dt_bias branch, no softplus) + pytest.param(1, 2, 64, 4, True, False, False, marks=pytest.mark.smoke), + # feature: softplus only (no bias, dt_softplus branch) + pytest.param(1, 2, 64, 4, False, True, False, marks=pytest.mark.smoke), + # feature: bias + softplus (full pipeline) + pytest.param(1, 2, 64, 4, True, True, False, marks=pytest.mark.full), + # shape: larger batch and chunk count + pytest.param(2, 4, 64, 8, False, False, False, marks=pytest.mark.full), + # shape: larger chunk_len tile + pytest.param(1, 2, 128, 4, False, False, False, marks=pytest.mark.full), + # shape + feature: large shape with full pipeline + pytest.param(2, 4, 128, 16, True, True, False, marks=pytest.mark.full), ]), ] @@ -23,19 +33,33 @@ def __init__( num_chunks: int, chunk_len: int, n_heads: int, + has_dt_bias: bool = False, + dt_softplus: bool = False, + dt_min: float = 0.0, + dt_max: float = float("inf"), ): self.batch = batch self.num_chunks = num_chunks self.chunk_len = chunk_len self.n_heads = n_heads + self.has_dt_bias = has_dt_bias + self.dt_softplus = dt_softplus + self.dt_min = dt_min + self.dt_max = dt_max def gen_inputs(self): b, C, Q, h = self.batch, self.num_chunks, self.chunk_len, self.n_heads seq_len = C * Q - # dt > 0 (softplus output in Mamba-2), A <= 0 (negative decay) - dt = torch.rand(b, seq_len, h, dtype=torch.float32, device="cuda") * 0.1 + 0.01 + # Raw dt values; softplus maps R -> R+, so randn covers both sides of the nonlinearity. + # A <= 0 (negative decay) + dt_raw = torch.randn(b, seq_len, h, dtype=torch.float32, device="cuda") A = -torch.rand(h, dtype=torch.float32, device="cuda") - return dt, A + # dt_bias is random when used; zeros when not (kernel ignores it in that case). + if self.has_dt_bias: + dt_bias = torch.randn(h, dtype=torch.float32, device="cuda") * 0.5 + else: + dt_bias = torch.zeros(h, dtype=torch.float32, device="cuda") + return dt_raw, A, dt_bias class SSDChunkScanFwdFixture(FixtureBase):