Changes from all commits (41 commits)
cc3f5c2
[Refactor] refactor gemm_sp following 5d729eee
botbw Apr 16, 2026
0bce4d0
[Doc] update doc
botbw Apr 16, 2026
eb55efe
[Refactor] remove gemm_sp CUTLASS templates
botbw Apr 16, 2026
a8c2351
Merge branch 'main' of https://github.com/tile-ai/tilelang into refac…
LeiWang1999 Apr 17, 2026
de32b3b
[cute] add mma sp templates for sm80
botbw Apr 17, 2026
5d9329d
[templates] use templates in codegen
botbw Apr 19, 2026
948eac6
Add WGMMA_SP templates
botbw Apr 27, 2026
cca3c67
Pass layout_map to py lowering
botbw Apr 27, 2026
d28ba8d
Fix type
botbw Apr 27, 2026
2260ece
Add wgmma.sp checking
botbw Apr 28, 2026
70acbbc
Implement wgmma_sp_ss
botbw May 8, 2026
117b855
Add wgmma_sp_rs
botbw May 8, 2026
717fdcb
Add sparse selector
botbw May 8, 2026
760fa33
Fix layout and param pass
botbw May 9, 2026
c6afc18
Fix transpose metadata
botbw May 9, 2026
7d54fbc
Add mma.sp fp8
botbw May 9, 2026
e3e1af5
Fix integer rounding
botbw May 9, 2026
441cc3e
Fix metadata layout
botbw May 10, 2026
edcfc0f
Update compress and test cases
botbw May 10, 2026
d4c8445
Remove debug print statements from sparse compress and wgmma emitter
botbw May 10, 2026
2c106cf
Clean up gemm_sp test file
botbw May 10, 2026
c874289
Remove unused tvm import from gemm_sp_wgmma
botbw May 10, 2026
4012553
Update examples, tests, and docs for new compress() API
botbw May 10, 2026
3312279
Refactor sparse constants into sparse_config.py and clean up examples
botbw May 10, 2026
040fd15
Compute e_factor and e_replicate_factor instead of hardcoding tables
botbw May 10, 2026
ff6a302
Use DataType keys in SPARSE_PARAMS
botbw May 10, 2026
e91ee42
Move sparse_config.py from utils/ to intrinsics/
botbw May 11, 2026
6870ae3
Rename sparse_config.py to sparse_params.py
botbw May 11, 2026
8326178
Fix sparse.py breaking TIR type hints by removing future annotations
botbw May 11, 2026
fd8ca5b
Fix compress utils
botbw May 11, 2026
1c90464
Update example
botbw May 11, 2026
771254e
Add a compress benchmark
botbw May 11, 2026
cd52cc5
Remove print
botbw May 11, 2026
bd11d73
Merge branch 'main' into refactor_gemm_sp
botbw May 12, 2026
9796a76
Remove unused MetaType
botbw May 12, 2026
e2f2792
Polish
botbw May 12, 2026
d55434c
Refactor gemm_sp op
botbw May 12, 2026
69eb8ce
Update doc && remove unused file
botbw May 12, 2026
c10a262
Fix bug spotted by coderabbit
botbw May 12, 2026
0f3f7e0
Fix
botbw May 12, 2026
7eecf89
Add note
botbw May 13, 2026
33 changes: 9 additions & 24 deletions benchmark/matmul/benchmark_matmul_sp.py
@@ -4,21 +4,14 @@
import torch
from triton.testing import do_bench

import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.contrib import nvcc
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import get_e_factor

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

arch = nvcc.get_target_compute_version()

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}


def ref_program(A, B):
"""
@@ -88,7 +81,7 @@ def get_configs(M, N, K):
return configs


def matmul_sp(M, N, K, in_dtype, accum_dtype):
def matmul_sp(M, N, K, in_dtype, accum_dtype, e_dtype):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
@@ -164,8 +157,8 @@ def kernel(
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
e_factor, e_dtype = ARCH_INFO[arch]
# accumulate in float for better numerical accuracy
e_factor = get_e_factor(in_dtype, e_dtype)

@T.prim_func
def main(
@@ -200,15 +193,8 @@ def main(

# Clear out the accumulation buffer
T.clear(C_local)
T.disable_warp_group_reg_alloc()

T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
}
)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
@@ -219,11 +205,13 @@
T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared
T.gemm_sp_v2(
T.gemm_sp(
A_shared,
E_shared,
B_shared,
C_local,
transpose_A=False,
transpose_E=False,
transpose_B=False,
policy=policy,
)
@@ -242,8 +230,8 @@ def main(
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--disable_cache", action="store_true")
parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
parser.add_argument("--e_dtype", type=str, default="int16", choices=["int16", "int8", "int32"], help="Metadata E datatype")
parser.add_argument(
"--bench_torch_sparse",
type=str,
@@ -253,16 +241,13 @@
)
args = parser.parse_args()

if args.disable_cache:
tilelang.disable_cache()

M, N, K = args.m, args.n, args.k

# Compute total floating-point operations to measure throughput
total_flops = 2 * M * N * K

# matmul(...) returns (best_latency, best_config, ref_latency)
best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype)
best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype, e_dtype=args.e_dtype)
best_latency = best_result.latency
best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
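For reviewers, a worked example of the throughput arithmetic this benchmark reports. This is a standalone sketch: only the formula total_flops = 2 * M * N * K is taken from the code above, and the latency value is a hypothetical stand-in for best_result.latency.

# Worked example of the benchmark's throughput math; the latency here is
# a hypothetical stand-in for best_result.latency, not a measured number.
M = N = K = 16384
total_flops = 2 * M * N * K                 # one multiply + one add per MAC
latency_ms = 5.0                            # hypothetical autotuned latency
tflops = total_flops / (latency_ms * 1e-3) / 1e12
print(f"{tflops:.1f} TFLOPS")               # ~1759.2 TFLOPS for a 5 ms kernel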
84 changes: 84 additions & 0 deletions benchmark/matmul/benchmark_matmul_sp_compress.py
@@ -0,0 +1,84 @@
import argparse
from typing import Optional

import torch
from tilelang.profiler import do_bench
from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse, torch_compress

SUPPORTED_DTYPE_NAMES = ["float16", "bfloat16", "float32", "int8"]
SUPPORTED_META_DTYPE_NAMES = ["int8", "int16", "int32"]


def _resolve_torch_dtype(name: str) -> torch.dtype:
dtype = getattr(torch, name, None)
if dtype is None:
raise ValueError(f"Unsupported torch dtype: {name}")
return dtype


def _generate_semi_sparse(m: int, k: int, dtype: torch.dtype, device: str = "cuda") -> torch.Tensor:
if dtype in (torch.int8, torch.uint8):
return randint_semi_sparse(m, k, low=-64, high=64, dtype=dtype, device=device)
return randn_semi_sparse(m, k, dtype=dtype, device=device)


def _compress_bytes(input_tensor: torch.Tensor, sparse_tensor: torch.Tensor, meta_tensor: torch.Tensor) -> int:
return (
input_tensor.numel() * input_tensor.element_size()
+ sparse_tensor.numel() * sparse_tensor.element_size()
+ meta_tensor.numel() * meta_tensor.element_size()
)


def benchmark_compress(
m: int,
k: int,
dtype: torch.dtype,
meta_dtype: Optional[torch.dtype] = None, # noqa: FA100
):
a0 = _generate_semi_sparse(m, k, dtype)

sparse0, meta0 = compress(a0, meta_dtype=meta_dtype)
ref_sparse0, ref_meta0 = torch_compress(a0, meta_dtype=meta_dtype)

bytes_per_compress = _compress_bytes(a0, sparse0, meta0)
bytes_per_torch = _compress_bytes(a0, ref_sparse0, ref_meta0)

tl_latency_ms = do_bench(lambda: compress(a0, meta_dtype=meta_dtype))
torch_latency_ms = do_bench(lambda: torch_compress(a0, meta_dtype=meta_dtype))

tl_latency_s = tl_latency_ms * 1e-3
torch_latency_s = torch_latency_ms * 1e-3
tl_throughput_gbps = bytes_per_compress / tl_latency_s / 1e9
torch_throughput_gbps = bytes_per_torch / torch_latency_s / 1e9

return bytes_per_compress, bytes_per_torch, tl_latency_ms, torch_latency_ms, tl_throughput_gbps, torch_throughput_gbps


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark two TileLang compress operators by memory throughput")
parser.add_argument("--m", type=int, default=16384, help="Matrix rows")
parser.add_argument("--k", type=int, default=16384, help="Matrix columns")
parser.add_argument("--dtype", type=str, default="float16", choices=SUPPORTED_DTYPE_NAMES, help="Input dtype")
parser.add_argument(
"--meta_dtype",
type=str,
default=None,
choices=SUPPORTED_META_DTYPE_NAMES,
help="Metadata dtype (defaults to the library choice for the input dtype)",
)
args = parser.parse_args()

dtype = _resolve_torch_dtype(args.dtype)
meta_dtype = _resolve_torch_dtype(args.meta_dtype) if args.meta_dtype is not None else None

bytes_per_compress, bytes_per_torch, tl_latency_ms, torch_latency_ms, tl_gbps, torch_gbps = benchmark_compress(
args.m,
args.k,
dtype=dtype,
meta_dtype=meta_dtype,
)

print(f"M={args.m} K={args.k} dtype={args.dtype} meta={args.meta_dtype or 'default'}")
print(f"tilelang: {tl_latency_ms:.4f} ms, {tl_gbps:.3f} GB/s (bytes={bytes_per_compress})")
print(f"torch: {torch_latency_ms:.4f} ms, {torch_gbps:.3f} GB/s (bytes={bytes_per_torch})")
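As background for compress / torch_compress above: 2:4 semi-structured sparsity stores two values out of every contiguous group of four along K, plus metadata recording which two positions they occupied, so the compressed operand has shape (M, K/2) with a small metadata tensor alongside; that is why _compress_bytes counts input, values, and metadata separately. Below is a minimal PyTorch sketch of those semantics. The magnitude-based selection and the unpacked index tensor are illustrative only; the actual metadata layout produced by compress (and consumed by T.gemm_sp) is hardware-specific.

import torch

def compress_2to4_sketch(a: torch.Tensor):
    # Keep the two largest-magnitude values in each group of four along
    # the last dim; on an already 2:4-sparse input these are the nonzeros.
    m, k = a.shape
    groups = a.reshape(m, k // 4, 4)
    idx = groups.abs().topk(2, dim=-1).indices.sort(dim=-1).values
    vals = torch.gather(groups, -1, idx).reshape(m, k // 2)
    # The real compress packs idx into a hardware-specific metadata tensor.
    return vals, idx.reshape(m, k // 2)

a = torch.randn(16, 32)
vals, idx = compress_2to4_sketch(a)
assert vals.shape == (16, 16) and idx.shape == (16, 16)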