Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/maca/attention_sink/example_gqa_sink_bwd_varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import tilelang
from tilelang.profiler import do_bench
import tilelang.language as T
from tilelang.utils.target import determine_target, target_is_maca
import argparse
from typing import Optional
import sys
Expand All @@ -12,6 +13,8 @@


def get_bwd_configs():
if target_is_maca(determine_target("auto", return_object=True)):
return 32, 16, 1, 128
sm_major, sm_minor = torch.cuda.get_device_capability()
sm_version = sm_major * 10 + sm_minor
if sm_version == 80:
Expand Down
21 changes: 20 additions & 1 deletion examples/maca/attention_sink/example_gqa_sink_fwd_varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tilelang.language as T
import tilelang.testing
from tilelang.profiler import do_bench
from tilelang.utils.target import determine_target, target_is_maca
from typing import Optional
import sys
import os
Expand All @@ -15,6 +16,12 @@
from varlen_utils import generate_random_padding_mask, generate_qkv


def get_fwd_configs():
    """Pick the forward-kernel launch configuration for the detected target.

    Returns:
        A 4-tuple that the caller unpacks as
        ``(block_M, block_N, num_stages, threads)``: ``(64, 32, 1, 128)``
        when the auto-detected target is MACA, ``(128, 128, 2, 256)``
        otherwise.
    """
    detected_target = determine_target("auto", return_object=True)
    on_maca = target_is_maca(detected_target)
    return (64, 32, 1, 128) if on_maca else (128, 128, 2, 256)


@tilelang.jit(
out_idx=[7],
pass_configs={
Expand Down Expand Up @@ -352,8 +359,20 @@ def main(
UQ = q_unpad.shape[0]
UKV = k_unpad.shape[0]

block_M, block_N, num_stages, threads = get_fwd_configs()
kernel = flashattn_sink(
batch, groups, UQ, UKV, heads, dim, is_causal, window_size=window_size, block_M=128, block_N=128, num_stages=2, threads=256
batch,
groups,
UQ,
UKV,
heads,
dim,
is_causal,
window_size=window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
)

out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, sinks)
Expand Down
3 changes: 0 additions & 3 deletions examples/maca/attention_sink/test_example_attention_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ def test_example_mha_sink_fwd_bhsd_sliding_window():
example_mha_sink_fwd_bhsd.main(window_size=128)


@tilelang.testing.pytest.mark.xfail
def test_example_mha_sink_bwd_bhsd():
    """Run ``example_mha_sink_bwd_bhsd.main`` with its default arguments."""
    example_mha_sink_bwd_bhsd.main()


@tilelang.testing.pytest.mark.xfail
def test_example_mha_sink_bwd_bhsd_sliding_window():
    """Run ``example_mha_sink_bwd_bhsd.main`` with ``window_size=128``."""
    example_mha_sink_bwd_bhsd.main(window_size=128)

Expand All @@ -33,7 +31,6 @@ def test_example_gqa_sink_bwd_bhsd_sliding_window():
example_gqa_sink_bwd_bhsd.main(window_size=128)


@tilelang.testing.pytest.mark.xfail
def test_example_gqa_sink_varlen():
    """Run the varlen forward example, then the varlen backward example.

    Both ``main`` entry points are invoked with their default arguments.
    """
    example_gqa_sink_fwd_varlen.main()  # non-causal
    example_gqa_sink_bwd_varlen.main()  # causal
Expand Down
6 changes: 4 additions & 2 deletions examples/maca/convolution/example_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ def main(argv=None):

args = parser.parse_args(argv)
N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
a = torch.randn(N, H, W, C).cuda().half()
b = torch.randn(K, K, C, F).cuda().half()
# MACA may default to non-standard 4D strides here; TileLang expects
# dense NHWC/HWCF tensors.
a = torch.randn(N, H, W, C).cuda().half().contiguous()
b = torch.randn(K, K, C, F).cuda().half().contiguous()

block_m = 64
block_n = 128
Expand Down
2 changes: 0 additions & 2 deletions examples/maca/convolution/test_example_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import example_convolution_autotune


# TODO(@cy): TMA with convolution must be fixed in future.
@tilelang.testing.pytest.mark.xfail
def test_example_convolution():
    """Run ``example_convolution.main`` with an empty argument list."""
    # An explicit empty list is passed as argv to main().
    example_convolution.main([])

Expand Down
24 changes: 16 additions & 8 deletions examples/maca/deepseek_mla/example_mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tilelang.language as T
from einops import rearrange, einsum
import argparse
from tilelang.utils.target import determine_target, target_is_maca


@tilelang.jit(
Expand All @@ -20,6 +21,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
is_maca = target_is_maca(determine_target("auto", return_object=True))
pipeline_stages = 1 if is_maca else 2
main_threads = 64 if is_maca else 256

@T.prim_func
def main_split(
Expand All @@ -32,7 +36,7 @@ def main_split(
Output: T.Tensor([batch, heads, dim], dtype),
):
# flash_attn_split
with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz):
with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=main_threads) as (bid, hid, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
Expand All @@ -58,7 +62,7 @@ def main_split(
T.fill(scores_max, -T.infinity(accum_dtype))

loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=2):
for k in T.Pipelined(loop_range, num_stages=pipeline_stages):
kv_start = (seqlen_kv // num_split) * bz + k * block_N
kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
Expand Down Expand Up @@ -129,7 +133,7 @@ def main_no_split(
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=main_threads) as (hid, bid):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
Expand All @@ -153,7 +157,7 @@ def main_no_split(
T.fill(scores_max, -T.infinity(accum_dtype))

loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
for k in T.Pipelined(loop_range, num_stages=pipeline_stages):
T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
Expand Down Expand Up @@ -232,8 +236,10 @@ def main(
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 64
BLOCK_H = min(64, heads // kv_heads)
target = determine_target("auto", return_object=True)
is_maca = target_is_maca(target)
BLOCK_N = 16 if is_maca else 64
BLOCK_H = min(16 if is_maca else 64, heads // kv_heads)
num_split = 1
softmax_scale = (dim + pe_dim) ** -0.5

Expand All @@ -253,8 +259,10 @@ def run_regression_perf(
dim=512,
pe_dim=64,
):
BLOCK_N = 64
BLOCK_H = min(64, heads // kv_heads)
target = determine_target("auto", return_object=True)
is_maca = target_is_maca(target)
BLOCK_N = 16 if is_maca else 64
BLOCK_H = min(16 if is_maca else 64, heads // kv_heads)
num_split = 1
softmax_scale = (dim + pe_dim) ** -0.5

Expand Down
1 change: 0 additions & 1 deletion examples/maca/deepseek_mla/test_example_mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import example_mla_decode


@tilelang.testing.pytest.mark.xfail
def test_example_mla_decode():
    """Run ``example_mla_decode.main`` with its default arguments."""
    example_mla_decode.main()

Expand Down
34 changes: 25 additions & 9 deletions examples/maca/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tilelang
import tilelang.language as T
from tilelang.contrib import nvcc
from tilelang.utils.target import determine_target, target_is_maca
import argparse
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input
Expand All @@ -20,6 +21,22 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
return padding_mask


def is_maca_target():
    """Report whether the auto-detected target is a MACA target.

    Returns:
        The result of ``target_is_maca`` applied to the target object
        obtained from ``determine_target("auto", return_object=True)``.
    """
    detected_target = determine_target("auto", return_object=True)
    return target_is_maca(detected_target)


def get_varlen_fwd_configs():
    """Select the forward tile sizes for the current target.

    Returns:
        A 2-tuple that the caller unpacks as ``(block_M, block_N)``:
        ``(64, 32)`` when ``is_maca_target()`` is true, ``(128, 64)``
        otherwise.
    """
    return (64, 32) if is_maca_target() else (128, 64)


def get_varlen_bwd_configs(use_atomic):
    """Select the backward-kernel configuration for the current target.

    Args:
        use_atomic: The caller's requested atomic-add setting. It is passed
            through unchanged on non-MACA targets and ignored on MACA.

    Returns:
        A 5-tuple that the caller unpacks as
        ``(block_M, block_N, threads, num_stages, use_atomic)``:
        ``(32, 32, 128, 1, False)`` when ``is_maca_target()`` is true,
        ``(128, 32, 256, 2, use_atomic)`` otherwise.
    """
    if not is_maca_target():
        return 128, 32, 256, 2, use_atomic
    # On MACA the returned use_atomic is always False, whatever was requested.
    return 32, 32, 128, 1, False


@tilelang.jit(
out_idx=[5, 6],
pass_configs={
Expand Down Expand Up @@ -491,8 +508,7 @@ def forward(
):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
block_M, block_N = get_varlen_fwd_configs()
q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
Expand All @@ -506,6 +522,7 @@ def forward(
ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k)
ctx.batch = BATCH
ctx.causal = causal
_, _, _, _, use_atomic = get_varlen_bwd_configs(use_atomic)
ctx.use_atomic = use_atomic
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
Expand All @@ -530,12 +547,11 @@ def maybe_contiguous(x):
return x

do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)]
block_M = 128
block_N = 32
block_M, block_N, threads, num_stages, use_atomic = get_varlen_bwd_configs(ctx.use_atomic)
mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V)
delta = mod_prep(o, do, cu_seqlens_q)

if ctx.use_atomic:
if use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH,
total_q,
Expand All @@ -548,8 +564,8 @@ def maybe_contiguous(x):
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
threads=threads,
num_stages=num_stages,
groups=groups,
)
dq = torch.zeros_like(q, dtype=torch.float32)
Expand All @@ -569,8 +585,8 @@ def maybe_contiguous(x):
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
threads=threads,
num_stages=num_stages,
groups=groups,
)
mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import example_gqa_fwd_varlen


@tilelang.testing.pytest.mark.xfail
def test_example_gqa_bwd_tma_reduce_varlen():
    """Run ``example_gqa_bwd_tma_reduce_varlen.main`` with its default arguments."""
    example_gqa_bwd_tma_reduce_varlen.main()

Expand Down
60 changes: 34 additions & 26 deletions examples/maca/gdn/example_chunk_delta_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench
from tilelang.utils.target import determine_target, target_is_maca

print(tilelang.__file__, flush=True)

Expand All @@ -16,9 +17,10 @@

print(fla.__file__, flush=True)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu
except ImportError:
print("fla not found, using tilelang implementation")
except Exception as exc:
print(f"fla unavailable, using tilelang implementation: {exc}")
fla = None
chunk_gated_delta_rule_bwd_dhu = None

import torch
import torch.nn.functional as F
Expand All @@ -42,24 +44,24 @@ def prepare_input(
gate_dtype,
state_dtype,
):
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous()
K = F.normalize(K, dim=-1, p=2).contiguous()
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous()
# Note: G should be in logspace and do chunkwise cumsum
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum

G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
except Exception as exc:
print(f"fla unavailable, skip cumsum: {exc}")

h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda().contiguous()
dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda().contiguous()
dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous()
dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous()
return Q, K, W, G, h0, dht, dO, dv


Expand All @@ -76,14 +78,14 @@ def prepare_input_fake(
gate_dtype,
state_dtype,
):
Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous()
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous()
W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
h0 = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda()
dht = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda()
dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
dv = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
h0 = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda().contiguous()
dht = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda().contiguous()
dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda().contiguous()
dv = torch.ones(B, S, H, DV, dtype=input_dtype).cuda().contiguous()
return Q, K, W, G, h0, dht, dO, dv


Expand Down Expand Up @@ -206,6 +208,10 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
threads=256,
num_stages=0,
):
is_maca = target_is_maca(determine_target("auto", return_object=True))
if is_maca:
block_DV = min(block_DV, 16)

block_S = chunk_size
# Should support cu_seqlen
BS = S // block_S
Expand Down Expand Up @@ -265,14 +271,16 @@ def kernel(
Q_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
Q_fragment_t = T.alloc_fragment((DK, block_S), dtype=accum_dtype)

T.use_swizzle(10)
if not is_maca:
T.use_swizzle(10)

T.annotate_layout(
{
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
}
)
if not is_maca:
T.annotate_layout(
{
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
}
)

if use_final_state_gradient:
T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared)
Expand Down
Loading
Loading