Skip to content
Open
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
cc3f5c2
[Refactor] refactor gemm_sp following 5d729eee
botbw Apr 16, 2026
0bce4d0
[Doc] update doc
botbw Apr 16, 2026
eb55efe
[Refactor] remove gemm_sp CUTLASS templates
botbw Apr 16, 2026
a8c2351
Merge branch 'main' of https://github.com/tile-ai/tilelang into refac…
LeiWang1999 Apr 17, 2026
de32b3b
[cute] add mma sp tempaltes for sm80
botbw Apr 17, 2026
5d9329d
[templates] use templates in codegen
botbw Apr 19, 2026
948eac6
Add WGMMA_SP templates
botbw Apr 27, 2026
cca3c67
Pass layout_map to py lowering
botbw Apr 27, 2026
d28ba8d
Fix type
botbw Apr 27, 2026
2260ece
Add wgmma.sp checking
botbw Apr 28, 2026
70acbbc
Implement wgmma_sp_ss
botbw May 8, 2026
117b855
Add wgmma_sp_rs
botbw May 8, 2026
717fdcb
Add sparse selector
botbw May 8, 2026
760fa33
Fix layout and param pass
botbw May 9, 2026
c6afc18
Fix transpose metadata
botbw May 9, 2026
7d54fbc
Add mma.sp fp8
botbw May 9, 2026
e3e1af5
Fix integer rounding
botbw May 9, 2026
441cc3e
Fix metadata layout
botbw May 10, 2026
edcfc0f
Update compress and test cases
botbw May 10, 2026
d4c8445
Remove debug print statements from sparse compress and wgmma emitter
botbw May 10, 2026
2c106cf
Clean up gemm_sp test file
botbw May 10, 2026
c874289
Remove unused tvm import from gemm_sp_wgmma
botbw May 10, 2026
4012553
Update examples, tests, and docs for new compress() API
botbw May 10, 2026
3312279
Refactor sparse constants into sparse_config.py and clean up examples
botbw May 10, 2026
040fd15
Compute e_factor and e_replicate_factor instead of hardcoding tables
botbw May 10, 2026
ff6a302
Use DataType keys in SPARSE_PARAMS
botbw May 10, 2026
e91ee42
Move sparse_config.py from utils/ to intrinsics/
botbw May 11, 2026
6870ae3
Rename sparse_config.py to sparse_params.py
botbw May 11, 2026
8326178
Fix sparse.py breaking TIR type hints by removing future annotations
botbw May 11, 2026
fd8ca5b
Fix compress utils
botbw May 11, 2026
1c90464
Update example
botbw May 11, 2026
771254e
Add a compress benchmark
botbw May 11, 2026
cd52cc5
Remove print
botbw May 11, 2026
bd11d73
Merge branch 'main' into refactor_gemm_sp
botbw May 12, 2026
9796a76
Remove unused MetaType
botbw May 12, 2026
e2f2792
Polish
botbw May 12, 2026
d55434c
Refactor gemm_sp op
botbw May 12, 2026
69eb8ce
Update doc && remove unused file
botbw May 12, 2026
c10a262
Fix bug spotted by coderabbit
botbw May 12, 2026
0f3f7e0
Fix
botbw May 12, 2026
7eecf89
Add note
botbw May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 9 additions & 24 deletions benchmark/matmul/benchmark_matmul_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,14 @@
import torch
from triton.testing import do_bench

import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.contrib import nvcc
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import get_e_factor

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

arch = nvcc.get_target_compute_version()

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}


def ref_program(A, B):
"""
Expand Down Expand Up @@ -88,7 +81,7 @@ def get_configs(M, N, K):
return configs


def matmul_sp(M, N, K, in_dtype, accum_dtype):
def matmul_sp(M, N, K, in_dtype, accum_dtype, e_dtype):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
Expand Down Expand Up @@ -164,8 +157,8 @@ def kernel(
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
e_factor, e_dtype = ARCH_INFO[arch]
# accumulate in float for better numerical accuracy
e_factor = get_e_factor(in_dtype, e_dtype)

@T.prim_func
def main(
Expand Down Expand Up @@ -200,15 +193,8 @@ def main(

# Clear out the accumulation buffer
T.clear(C_local)
T.disable_warp_group_reg_alloc()

T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
}
)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
Expand All @@ -219,11 +205,13 @@ def main(
T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared
T.gemm_sp_v2(
T.gemm_sp(
A_shared,
E_shared,
B_shared,
C_local,
transpose_A=False,
transpose_E=False,
transpose_B=False,
policy=policy,
)
Expand All @@ -242,8 +230,8 @@ def main(
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--disable_cache", action="store_true")
parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
parser.add_argument("--e_dtype", type=str, default="int16", choices=["int16", "int8", "int32"], help="Metadata E datatype")
parser.add_argument(
"--bench_torch_sparse",
type=str,
Expand All @@ -253,16 +241,13 @@ def main(
)
args = parser.parse_args()

if args.disable_cache:
tilelang.disable_cache()

M, N, K = args.m, args.n, args.k

# Compute total floating-point operations to measure throughput
total_flops = 2 * M * N * K

# matmul(...) returns (best_latency, best_config, ref_latency)
best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype)
best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype, e_dtype=args.e_dtype)
best_latency = best_result.latency
best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
Expand Down
84 changes: 84 additions & 0 deletions benchmark/matmul/benchmark_matmul_sp_compress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import argparse
from typing import Optional

import torch
from tilelang.profiler import do_bench
from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse, torch_compress

SUPPORTED_DTYPE_NAMES = ["float16", "bfloat16", "float32", "int8"]
SUPPORTED_META_DTYPE_NAMES = ["int8", "int16", "int32"]


def _resolve_torch_dtype(name: str) -> torch.dtype:
dtype = getattr(torch, name, None)
if dtype is None:
raise ValueError(f"Unsupported torch dtype: {name}")
return dtype


def _generate_semi_sparse(m: int, k: int, dtype: torch.dtype, device: str = "cuda") -> torch.Tensor:
    """Create an (m, k) tensor with a 2:4 semi-sparse pattern in the given dtype."""
    # Floating-point dtypes draw from the normal-distribution generator;
    # 8-bit integer dtypes need a bounded uniform draw instead.
    if dtype not in (torch.int8, torch.uint8):
        return randn_semi_sparse(m, k, dtype=dtype, device=device)
    return randint_semi_sparse(m, k, low=-64, high=64, dtype=dtype, device=device)


def _compress_bytes(input_tensor: torch.Tensor, sparse_tensor: torch.Tensor, meta_tensor: torch.Tensor) -> int:
return (
input_tensor.numel() * input_tensor.element_size()
+ sparse_tensor.numel() * sparse_tensor.element_size()
+ meta_tensor.numel() * meta_tensor.element_size()
)


def benchmark_compress(
    m: int,
    k: int,
    dtype: torch.dtype,
    meta_dtype: Optional[torch.dtype] = None,  # noqa: FA100
):
    """Benchmark tilelang's ``compress`` against ``torch_compress`` on an (m, k) input.

    Returns a 6-tuple:
    (tilelang bytes/call, torch bytes/call, tilelang latency ms,
     torch latency ms, tilelang GB/s, torch GB/s).
    """
    dense = _generate_semi_sparse(m, k, dtype)

    # One warm-up call per implementation also yields the output tensors,
    # whose sizes determine the bytes moved per call.
    tl_sparse, tl_meta = compress(dense, meta_dtype=meta_dtype)
    ref_sparse, ref_meta = torch_compress(dense, meta_dtype=meta_dtype)

    tl_bytes = _compress_bytes(dense, tl_sparse, tl_meta)
    ref_bytes = _compress_bytes(dense, ref_sparse, ref_meta)

    tl_ms = do_bench(lambda: compress(dense, meta_dtype=meta_dtype))
    ref_ms = do_bench(lambda: torch_compress(dense, meta_dtype=meta_dtype))

    # Throughput in GB/s: bytes / seconds / 1e9; latencies are milliseconds.
    tl_seconds = tl_ms * 1e-3
    ref_seconds = ref_ms * 1e-3
    tl_gbps = tl_bytes / tl_seconds / 1e9
    ref_gbps = ref_bytes / ref_seconds / 1e9

    return tl_bytes, ref_bytes, tl_ms, ref_ms, tl_gbps, ref_gbps


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark two TileLang compress operators by memory throughput")
parser.add_argument("--m", type=int, default=16384, help="Matrix rows")
parser.add_argument("--k", type=int, default=16384, help="Matrix columns")
parser.add_argument("--dtype", type=str, default="float16", choices=SUPPORTED_DTYPE_NAMES, help="Input dtype")
parser.add_argument(
"--meta_dtype",
type=str,
default=None,
choices=SUPPORTED_META_DTYPE_NAMES,
help="Metadata dtype (defaults to the library choice for the input dtype)",
)
args = parser.parse_args()

dtype = _resolve_torch_dtype(args.dtype)
meta_dtype = _resolve_torch_dtype(args.meta_dtype) if args.meta_dtype is not None else None

bytes_per_compress, bytes_per_torch, tl_latency_ms, torch_latency_ms, tl_gbps, torch_gbps = benchmark_compress(
args.m,
args.k,
dtype=dtype,
meta_dtype=meta_dtype,
)

print(f"M={args.m} K={args.k} dtype={args.dtype} meta={args.meta_dtype or 'default'}")
print(f"tilelang: {tl_latency_ms:.4f} ms, {tl_gbps:.3f} GB/s (bytes={bytes_per_compress})")
print(f"torch: {torch_latency_ms:.4f} ms, {torch_gbps:.3f} GB/s (bytes={bytes_per_torch})")
92 changes: 26 additions & 66 deletions docs/deeplearning_operators/matmul_sparse.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,50 +38,35 @@ To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into

Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`).

A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression.
A compressor is provided in `tilelang.utils.sparse`. Pass in a dense 2:4-sparse tensor and optionally a metadata dtype to get back the compressed values and metadata:

```python
from tilelang.utils.sparse import compress
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
A_sparse, E = compress(A) # default: int16 metadata for fp16/bf16
A_sparse, E = compress(A.t().contiguous()) # compress the transposed layout
```

Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern.
Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern. The metadata uses a natural row-major layout that `T.gemm_sp` consumes directly — no additional layout annotation is needed.

> NOTE: When using CUTLASS compressor, there is no naive position correspondence between the positions in `A_sparse`/`A` and `E`. (i.e. the 4-element group at [n, k] doesn't match the 4-bit metadata at [n, k] if you consider metadata as int4 tensor)
The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads).
For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**.
## `T.gemm_sp`

## `T.gemm_sp` with CUTLASS's compressor
A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires loading the metadata into shared memory and passing it to `T.gemm_sp`.

:::{warning}

It is strongly recommended to use T.gemm_sp_v2 due to its greater flexibility and faster compilation time.

:::

A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata.

Check comments in below kernel code for required modification.
The default metadata dtype for fp16/bf16 is `int16` with an E-factor of 16 (one `int16` value covers 16 K-elements). For int8/float8 the default is `int32` with E-factor 32.

```python
def matmul_sp_sm80(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
trans_A,
trans_B,
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter

def matmul_sp(
M, N, K,
block_M, block_N, block_K,
in_dtype, out_dtype, accum_dtype,
num_stages, threads,
trans_A, trans_B,
):
is_8_bit = "8" in in_dtype
metadata_dtype = 'int32' if is_8_bit else 'int16'
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes
metadata_dtype = "int32" if is_8_bit else "int16"
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
Expand All @@ -91,23 +76,16 @@ def matmul_sp_sm80(

@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({ # Annotate reordered cutlass metadata layout
E:
make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
Expand All @@ -119,19 +97,15 @@ def matmul_sp_sm80(
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
T.copy(C_frag, C[by * block_M, bx * block_N])

return main
```

Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`.

## `T.gemm_sp_v2` with a custom compressor

To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp`.
## `T.gemm_sp` with a custom compressor

Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors.
`T.gemm_sp` lowers directly to PTX, removing the need for a fixed metadata layout. It can operate without `T.annotate_layout`, and supports user-defined layouts and compressors.

The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs.

Expand Down Expand Up @@ -172,7 +146,7 @@ def decode_metadata(meta: torch.Tensor) -> torch.Tensor:

The compressor can be implement at either `PyTorch`/`NumPy` level or kernel level.

For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function.
For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel.

Expand Down Expand Up @@ -245,17 +219,3 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):

return kernel
```

## A note on `gemm_sp` and `gemm_sp_v2`

Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout.

However, fixing a specific layout introduces several potential issues:

1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling.

2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically.

3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.)

`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm`. It lowers directly to PTX, removing the need for a fixed metadata layout.
Loading