112 changes: 86 additions & 26 deletions benchmarks/ops/bench_mamba.py
@@ -2,6 +2,7 @@

import pytest
import torch
import torch.nn.functional as F

from benchmarks.benchmark_base import BenchmarkBase, BenchmarkReport
from tileops.ops.da_cumsum import DaCumsumFwdOp
@@ -25,6 +26,11 @@
# ---------------------------------------------------------------------------
# Optional mamba_ssm Triton baselines
# ---------------------------------------------------------------------------
try:
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd as _mamba_chunk_cumsum_fwd
except ImportError:
_mamba_chunk_cumsum_fwd = None

try:
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd as _mamba_chunk_scan_fwd
except ImportError:
@@ -48,52 +54,101 @@ def da_cumsum_fwd_ref(
A: torch.Tensor,
num_chunks: int,
chunk_len: int,
) -> torch.Tensor:
"""PyTorch reference for da_cumsum_fwd (benchmark-local copy)."""
dt_bias: torch.Tensor | None = None,
dt_softplus: bool = False,
dt_min: float = 0.0,
dt_max: float = float("inf"),
) -> tuple[torch.Tensor, torch.Tensor]:
"""PyTorch reference for da_cumsum_fwd (benchmark-local copy).

Returns:
dt_out: (batch, n_heads, num_chunks, chunk_len) float32
dA_cumsum: (batch, n_heads, num_chunks, chunk_len) float32
"""
b, S, h = dt.shape
Q = chunk_len
C = num_chunks
dt_chunked = dt.float().reshape(b, C, Q, h)
dt_val = dt.float()
if dt_bias is not None:
dt_val = dt_val + dt_bias.float()
if dt_softplus:
dt_val = F.softplus(dt_val)
dt_val = torch.clamp(dt_val, min=dt_min, max=dt_max)
dt_chunked = dt_val.reshape(b, C, Q, h)
dt_out = dt_chunked.permute(0, 3, 1, 2).contiguous() # (b, h, C, Q)
dA = dt_chunked * A.float()
dA_cumsum = dA.cumsum(dim=2)
return dA_cumsum.permute(0, 3, 1, 2).contiguous()
dA_cumsum = dA.cumsum(dim=2).permute(0, 3, 1, 2).contiguous() # (b, h, C, Q)
return dt_out, dA_cumsum
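# Expected input shapes for the reference above (inferred from the unpacking and
# broadcasting in this file, stated as an assumption rather than a spec):
#   dt: (batch, num_chunks * chunk_len, n_heads), A: (n_heads,), dt_bias: (n_heads,)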


class DaCumsumFwdBenchmark(BenchmarkBase[DaCumsumFwdTest]):

def calculate_flops(self) -> Optional[float]:
t = self.workload
b, c, L, h = t.batch, t.num_chunks, t.chunk_len, t.n_heads
# One multiply (dt * A) and one add per element for the inclusive scan
# Total: 2 * b * c * L * h
return float(2 * b * c * L * h)
# Core ops per element: 1 mul (dt*A) + 1 add (cumsum) = 2
# Optional bias add: +1; optional softplus (exp+log+add): +3; clamp (min+max): +2
bias_ops = 1 if t.has_dt_bias else 0
softplus_ops = 3 if t.dt_softplus else 0
ops_per_elem = 2 + bias_ops + softplus_ops + 2 # +2 for clamp always
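        # Illustrative count (a sketch with assumed example sizes, not taken from
        # the PR): batch=1, num_chunks=4, chunk_len=256, n_heads=8 with bias and
        # softplus enabled gives ops_per_elem = 2 + 1 + 3 + 2 = 8, hence
        # 8 * (1 * 4 * 256 * 8) = 65_536 FLOPs.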
return float(ops_per_elem * b * c * L * h)

def calculate_memory(self) -> Optional[float]:
t = self.workload
b, c, L, h = t.batch, t.num_chunks, t.chunk_len, t.n_heads
# float32 throughout
elem = 4
# Reads: dt (b, c*L, h) + A (h,)
reads = (b * c * L * h + h) * elem
# Writes: dA_cumsum (b, h, c, L)
writes = b * h * c * L * elem
elem = 4 # float32
# Reads: dt_raw (b, c*L, h) + A (h,) + optional dt_bias (h,)
reads = (b * c * L * h + h + (h if t.has_dt_bias else 0)) * elem
# Writes: dt_out (b, h, c, L) + dA_cumsum (b, h, c, L)
writes = 2 * b * h * c * L * elem
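        # Illustrative count (a sketch with assumed example sizes, not taken from
        # the PR): batch=1, num_chunks=4, chunk_len=256, n_heads=8 with dt_bias
        # gives reads = (8192 + 8 + 8) * 4 = 32_832 B and
        # writes = 2 * 8192 * 4 = 65_536 B, roughly 96 KiB of traffic in total.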
return float(reads + writes)


@DaCumsumFwdFixture
def test_da_cumsum_fwd_bench(batch, num_chunks, chunk_len, n_heads, tune):
test = DaCumsumFwdTest(batch, num_chunks, chunk_len, n_heads)
def test_da_cumsum_fwd_bench(batch, num_chunks, chunk_len, n_heads, has_dt_bias, dt_softplus, tune):
test = DaCumsumFwdTest(
batch, num_chunks, chunk_len, n_heads,
has_dt_bias=has_dt_bias, dt_softplus=dt_softplus,
)
bm = DaCumsumFwdBenchmark(test)
inputs = test.gen_inputs()

op = DaCumsumFwdOp(batch, num_chunks, chunk_len, n_heads, seq_len=num_chunks * chunk_len, tune=tune)
inputs = test.gen_inputs() # (dt_raw, A, dt_bias)

op = DaCumsumFwdOp(
batch, num_chunks, chunk_len, n_heads,
seq_len=num_chunks * chunk_len,
has_dt_bias=has_dt_bias,
dt_softplus=dt_softplus,
tune=tune,
)
result = bm.profile(op, *inputs)
BenchmarkReport.record(op, locals(), result, tag="tileops")

def baseline(dt, A):
return da_cumsum_fwd_ref(dt, A, num_chunks, chunk_len)
result_bl = bm.profile(baseline, *inputs)
BenchmarkReport.record(op, locals(), result_bl, tag="torch-ref")
# ── Mamba-2 Triton baseline ──
# _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=...)
# returns (dA_cumsum, dt_out) — note reversed order vs TileOPs (dt_out, dA_cumsum)
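    # e.g. when comparing the two side by side (op is invoked as op(*inputs) via
    # bm.profile above; the unpacking below is an illustration, not a documented
    # signature):
    #   dA_cs_mamba, dt_out_mamba = _mamba_chunk_cumsum_fwd(dt_raw, A, chunk_len, ...)
    #   dt_out_tl, dA_cs_tl = op(dt_raw, A, dt_bias)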
if _mamba_chunk_cumsum_fwd is not None:
mamba_dt_bias = inputs[2] if has_dt_bias else None

def mamba_fwd():
return _mamba_chunk_cumsum_fwd(
inputs[0].contiguous(),
inputs[1].contiguous(),
chunk_len,
dt_bias=mamba_dt_bias.contiguous() if mamba_dt_bias is not None else None,
dt_softplus=dt_softplus,
)

result_mamba = bm.profile(mamba_fwd)
BenchmarkReport.record(op, locals(), result_mamba, tag="mamba")
else:
def baseline(dt_raw, A, dt_bias):
return da_cumsum_fwd_ref(
dt_raw, A, num_chunks, chunk_len,
dt_bias=dt_bias if has_dt_bias else None,
dt_softplus=dt_softplus,
)
result_bl = bm.profile(baseline, *inputs)
BenchmarkReport.record(op, locals(), result_bl, tag="torch-ref")


def ssd_chunk_scan_fwd_ref(x, cb, dA_cumsum, C, prev_states, dt, n_groups):
@@ -418,16 +473,21 @@ def ssd_state_passing_fwd_ref(
dA_chunk_cumsum: torch.Tensor,
initial_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""PyTorch reference for ssd_state_passing_fwd (benchmark-local copy)."""
"""PyTorch reference for ssd_state_passing_fwd (benchmark-local copy).

Matches mamba convention: out[:,c] = state *before* processing chunk c,
so out[:,0] = initial_states and final_states = state after chunk C-1.
"""
b, c, h, d = states.shape
out = []
out = [initial_states.float().clone()]
s = initial_states.float()

for ci in range(c):
scale = torch.exp(dA_chunk_cumsum[:, :, ci]).unsqueeze(-1)
u = states[:, ci, :, :].float()
s = scale * s + u
out.append(s.clone())
if ci < c - 1:
out.append(s.clone())

return torch.stack(out, dim=1), s
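

# A hedged, self-contained sketch (the helper name below is hypothetical, not
# part of the PR) of the convention documented in ssd_state_passing_fwd_ref:
# out[:, c] holds the state *before* chunk c, and the final state is returned
# separately.
def _state_passing_convention_demo() -> None:
    b, c, h, d = 1, 2, 1, 3
    states = torch.ones(b, c, h, d)
    dA_chunk_cumsum = torch.zeros(b, h, c)  # exp(0) = 1, so the decay is identity
    initial_states = torch.zeros(b, h, d)
    out, final = ssd_state_passing_fwd_ref(states, dA_chunk_cumsum, initial_states)
    assert torch.equal(out[:, 0], initial_states)            # state before chunk 0
    assert torch.equal(out[:, 1], states[:, 0])              # 1 * 0 + states[:, 0]
    assert torch.equal(final, states[:, 0] + states[:, 1])   # state after chunk 1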

77 changes: 64 additions & 13 deletions tests/ops/test_mamba.py
@@ -1,5 +1,6 @@
import pytest
import torch
import torch.nn.functional as F

from tests.test_base import TestBase, allclose_compare
from tileops.ops.da_cumsum import DaCumsumFwdOp
@@ -28,34 +29,78 @@ def da_cumsum_fwd_ref(
A: torch.Tensor,
num_chunks: int,
chunk_len: int,
) -> torch.Tensor:
"""PyTorch reference for da_cumsum_fwd."""
dt_bias: torch.Tensor | None = None,
dt_softplus: bool = False,
dt_min: float = 0.0,
dt_max: float = float("inf"),
) -> tuple[torch.Tensor, torch.Tensor]:
"""PyTorch reference for da_cumsum_fwd.

Applies the same bias / softplus / clamp pipeline as the kernel, then
computes dt_out and the chunk-local inclusive prefix sum of dA = dt_out * A.

Returns:
dt_out: (batch, n_heads, num_chunks, chunk_len) float32
dA_cumsum: (batch, n_heads, num_chunks, chunk_len) float32
"""
b, S, h = dt.shape
Q = chunk_len
C = num_chunks
dt_chunked = dt.float().reshape(b, C, Q, h)
dA = dt_chunked * A.float()
dA_cumsum = dA.cumsum(dim=2)
return dA_cumsum.permute(0, 3, 1, 2).contiguous()
dt_val = dt.float()
if dt_bias is not None:
dt_val = dt_val + dt_bias.float()
if dt_softplus:
dt_val = F.softplus(dt_val)
dt_val = torch.clamp(dt_val, min=dt_min, max=dt_max)
dt_chunked = dt_val.reshape(b, C, Q, h) # (b, C, Q, h)
dt_out = dt_chunked.permute(0, 3, 1, 2).contiguous() # (b, h, C, Q)
dA = dt_chunked * A.float() # (b, C, Q, h)
dA_cumsum = dA.cumsum(dim=2).permute(0, 3, 1, 2).contiguous() # (b, h, C, Q)
return dt_out, dA_cumsum
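# Illustrative property of the reference above (a sketch with assumed toy values,
# not part of the PR): the cumsum runs over the chunk_len axis only, so it restarts
# at every chunk boundary. With dt == 1, no bias, no softplus and A == -1,
# dA_cumsum[b, h, c, q] == -(q + 1) in every chunk c, independent of c.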


class DaCumsumFwdTest(_DaCumsumFwdTestWorkload, TestBase):
def ref_program(self, dt, A):
return da_cumsum_fwd_ref(dt, A, self.num_chunks, self.chunk_len)
def ref_program(self, dt, A, dt_bias):
return da_cumsum_fwd_ref(
dt, A, self.num_chunks, self.chunk_len,
dt_bias=dt_bias if self.has_dt_bias else None,
dt_softplus=self.dt_softplus,
dt_min=self.dt_min,
dt_max=self.dt_max,
)


@DaCumsumFwdFixture
def test_da_cumsum_fwd(batch, num_chunks, chunk_len, n_heads, tune):
test = DaCumsumFwdTest(batch, num_chunks, chunk_len, n_heads)
def test_da_cumsum_fwd(batch, num_chunks, chunk_len, n_heads, has_dt_bias, dt_softplus, tune):
test = DaCumsumFwdTest(
batch, num_chunks, chunk_len, n_heads,
has_dt_bias=has_dt_bias, dt_softplus=dt_softplus,
)
op = DaCumsumFwdOp(
batch, num_chunks, chunk_len, n_heads,
seq_len=num_chunks * chunk_len,
has_dt_bias=has_dt_bias,
dt_softplus=dt_softplus,
tune=tune,
)
inputs = test.gen_inputs()
test.check(op, *inputs, atol=1e-5, rtol=1e-5)


@pytest.mark.smoke
def test_da_cumsum_fwd_missing_bias_raises():
"""DaCumsumFwdKernel must raise when has_dt_bias=True but dt_bias is None."""
from tileops.kernels.mamba import DaCumsumFwdKernel
kernel = DaCumsumFwdKernel(
batch=1, num_chunks=2, chunk_len=64, n_heads=4,
seq_len=128, has_dt_bias=True,
)
dt = torch.randn(1, 128, 4, dtype=torch.float32, device="cuda")
A = -torch.rand(4, dtype=torch.float32, device="cuda")
with pytest.raises(ValueError, match="dt_bias is required"):
kernel(dt, A, dt_bias=None)


def ssd_chunk_scan_fwd_ref(x, cb, dA_cumsum, C, prev_states, dt, n_groups):
"""Official-aligned PyTorch reference for chunk scan.

@@ -197,16 +242,22 @@ def ssd_state_passing_fwd_ref(
dA_chunk_cumsum: torch.Tensor,
initial_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""PyTorch reference for the inter-chunk recurrent scan."""
"""PyTorch reference for the inter-chunk recurrent scan.

Matches mamba convention: out[:,c] = state *before* processing chunk c,
so out[:,0] = initial_states and final_states = state after chunk C-1.
"""
b, c, h, d = states.shape
out = []
# out[:,0] = s_{-1} = initial_states (state before chunk 0)
out = [initial_states.float().clone()]
s = initial_states.float()

for ci in range(c):
scale = torch.exp(dA_chunk_cumsum[:, :, ci]).unsqueeze(-1)
u = states[:, ci, :, :].float()
s = scale * s + u
out.append(s.clone())
if ci < c - 1:
out.append(s.clone())

return torch.stack(out, dim=1), s
