diff --git a/src/tilegym/ops/cutile/mla.py b/src/tilegym/ops/cutile/mla.py index 07e57d98..e3e30982 100644 --- a/src/tilegym/ops/cutile/mla.py +++ b/src/tilegym/ops/cutile/mla.py @@ -34,6 +34,9 @@ def _mla_sm90_autotune_configs(): for tm in [64, 128, 256]: for tn in [64, 128]: yield SimpleNamespace(TILE_M=tm, TILE_N=tn, num_ctas=1, occupancy=1) + for tm in [64, 128]: + for tn in [64, 128]: + yield SimpleNamespace(TILE_M=tm, TILE_N=tn, num_ctas=1, occupancy=2) @ct.kernel diff --git a/src/tilegym/suites/liger/__init__.py b/src/tilegym/suites/liger/__init__.py index f32c9714..9a9904c2 100644 --- a/src/tilegym/suites/liger/__init__.py +++ b/src/tilegym/suites/liger/__init__.py @@ -20,14 +20,32 @@ # Import unified interface from .ops import cross_entropy from .ops import fused_linear_jsd +from .ops import fused_neighborhood_attention from .ops import geglu +from .ops import group_norm from .ops import jsd +from .ops import kl_div from .ops import layer_norm +from .ops import llama4_rope +from .ops import multi_token_attention +from .ops import qwen2vl_mrope +from .ops import rope +from .ops import sparsemax +from .ops import tiled_mlp __all__ = [ "cross_entropy", "fused_linear_jsd", + "fused_neighborhood_attention", "geglu", + "group_norm", "jsd", + "kl_div", "layer_norm", + "llama4_rope", + "multi_token_attention", + "qwen2vl_mrope", + "rope", + "sparsemax", + "tiled_mlp", ] diff --git a/src/tilegym/suites/liger/cutile/__init__.py b/src/tilegym/suites/liger/cutile/__init__.py index d00c8963..7300f997 100644 --- a/src/tilegym/suites/liger/cutile/__init__.py +++ b/src/tilegym/suites/liger/cutile/__init__.py @@ -6,24 +6,56 @@ from . import cross_entropy # noqa: F401 from . import fused_linear_jsd # noqa: F401 +from . import fused_neighborhood_attention # noqa: F401 from . import geglu # noqa: F401 +from . import group_norm # noqa: F401 from . import jsd # noqa: F401 +from . import kl_div # noqa: F401 from . import layer_norm # noqa: F401 +from . import llama4_rope # noqa: F401 +from . import multi_token_attention # noqa: F401 +from . import qwen2vl_mrope # noqa: F401 +from . import rope # noqa: F401 +from . import sparsemax # noqa: F401 +from . import tiled_mlp # noqa: F401 from .cross_entropy import CrossEntropyCuTileFunction # noqa: F401 from .fused_linear_jsd import FusedLinearJSDCuTileFunction # noqa: F401 from .geglu import GEGLUCuTileFunction # noqa: F401 +from .group_norm import GroupNormCuTileFunction # noqa: F401 from .jsd import JSDCuTileFunction # noqa: F401 +from .kl_div import KLDivCuTileFunction # noqa: F401 from .layer_norm import LayerNormCuTileFunction # noqa: F401 +from .llama4_rope import Llama4RopeCuTileFunction # noqa: F401 +from .multi_token_attention import MultiTokenAttentionCuTileFunction # noqa: F401 +from .qwen2vl_mrope import Qwen2VLMRopeCuTileFunction # noqa: F401 +from .rope import RopeCuTileFunction # noqa: F401 +from .sparsemax import SparsemaxCuTileFunction # noqa: F401 __all__ = [ "CrossEntropyCuTileFunction", "FusedLinearJSDCuTileFunction", "GEGLUCuTileFunction", + "GroupNormCuTileFunction", "JSDCuTileFunction", + "KLDivCuTileFunction", "LayerNormCuTileFunction", + "Llama4RopeCuTileFunction", + "MultiTokenAttentionCuTileFunction", + "SparsemaxCuTileFunction", "cross_entropy", "fused_linear_jsd", + "fused_neighborhood_attention", "geglu", + "group_norm", "jsd", + "kl_div", "layer_norm", + "llama4_rope", + "multi_token_attention", + "Qwen2VLMRopeCuTileFunction", + "RopeCuTileFunction", + "qwen2vl_mrope", + "rope", + "sparsemax", + "tiled_mlp", ] diff --git a/src/tilegym/suites/liger/cutile/fused_neighborhood_attention.py b/src/tilegym/suites/liger/cutile/fused_neighborhood_attention.py new file mode 100644 index 00000000..f6a68b72 --- /dev/null +++ b/src/tilegym/suites/liger/cutile/fused_neighborhood_attention.py @@ -0,0 +1,894 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +# cuTile kernels for fused neighborhood attention (forward + backward). + +import functools +import math +import os +from types import SimpleNamespace + +import cuda.tile as ct +import torch +from cuda.tile import RoundingMode as RMd +from cuda.tile.tune import exhaustive_search + +from tilegym.backend import register_impl + +ConstInt = ct.Constant[int] +ConstBool = ct.Constant[bool] + +INV_LOG_2 = 1.0 / math.log(2) + + +# cuTile kernel: Fused flash-style forward (online softmax, no O(S^2) HBM) +# NOTE: occupancy is NOT hard-coded in the decorator; it is injected via +# replace_hints() at launch time so that exhaustive_search can sweep it. +@ct.kernel +def _fna_fused_forward_kernel( + query, + key, + value, + output, + lse_cache, # 1-D flattened [B*H*padded_seq], log-sum-exp in log2 space for backward + qk_scale: float, + NUM_HEADS: int, + SEQ_LEN: int, + HEAD_DIM: ConstInt, + KERNEL_SIZE: ConstInt, + DILATION: ConstInt, + BLOCK_M: ConstInt, + BLOCK_N: ConstInt, + LSE_STRIDE: int, # padded_seq_len = cdiv(seq_len, BLOCK_M) * BLOCK_M +): + """ + Fused flash-style neighborhood attention forward pass with LSE output. + + Grid: (batch*heads, cdiv(seq_len, BLOCK_M), 1) + Each block handles BLOCK_M query rows with online softmax — never + materialises the full [B, H, S, S] attention matrix. + + Q, K, V: [B, H, S, D] output: [B, H, S, D] + LSE: [B*H*S] (1-D flattened), log-sum-exp = m_i + log2(l_i) + """ + batch_head = ct.bid(0) + tile_m = ct.bid(1) + + batch_id = batch_head // NUM_HEADS + head_id = batch_head % NUM_HEADS + + # Adjust scale for exp2-based softmax (multiply by 1/log(2)) + scale_log2 = qk_scale * INV_LOG_2 + + # Absolute row indices for this tile: [BLOCK_M, 1] + rows = tile_m * BLOCK_M + ct.arange(BLOCK_M, dtype=ct.int32) # [BLOCK_M] + rows = rows[:, None] # [BLOCK_M, 1] + + # Online softmax running state — all in float32 for stability. + # Use a large-but-finite -1e30 instead of -inf for m_i so that + # alpha = exp2(m_i - m_ij) never evaluates as exp2(-inf - (-inf)) = NaN. + _M_INIT = -1e30 + m_i = ct.full((BLOCK_M, 1), _M_INIT, dtype=ct.float32) # row max + l_i = ct.full((BLOCK_M, 1), 0.0, dtype=ct.float32) # row sum + acc = ct.full((BLOCK_M, HEAD_DIM), 0.0, dtype=ct.float32) # output accumulator + + # Load Q tile once: [BLOCK_M, head_dim]. latency= hints let the compiler overlap the TMA loads + # with the MMA (K=2, V=4, Q=2, following the TileGym attention pattern); loads stay inside the + # ct.mma loop to avoid the warp-specialization hang. + q = ct.load(query, index=(batch_id, head_id, tile_m, 0), shape=(1, 1, BLOCK_M, HEAD_DIM), latency=2).reshape( + (BLOCK_M, HEAD_DIM) + ) + + half_k = KERNEL_SIZE // 2 + + # The neighborhood window only touches columns in [m0 - band_w, m0 + BLOCK_M - 1 + band_w], + # so restrict the key/value loop to the band of tiles that intersect that window instead of + # scanning the whole sequence. Loads stay inside the loop and the in-kernel mask below + # remains the correctness guard for partially-covered edge tiles. + band_w = half_k * DILATION + m0 = tile_m * BLOCK_M + band_lo = max(0, m0 - band_w) + band_hi = min(SEQ_LEN, m0 + BLOCK_M + band_w) + n_start = band_lo // BLOCK_N + n_end = ct.cdiv(band_hi, BLOCK_N) + for j in range(n_start, n_end): + # 1. Compute Q @ K[j]^T -> qk: [BLOCK_M, BLOCK_N] + # Load K tile transposed: shape (1,1, head_dim, BLOCK_N) to get K^T + k = ct.load( + key, + index=(batch_id, head_id, 0, j), + shape=(1, 1, HEAD_DIM, BLOCK_N), + order=(0, 1, 3, 2), + latency=2, + ).reshape((HEAD_DIM, BLOCK_N)) # [head_dim, BLOCK_N] + + q_mma = ct.astype(q, ct.tfloat32) if q.dtype == ct.float32 else q + k_mma = ct.astype(k, ct.tfloat32) if k.dtype == ct.float32 else k + + qk = ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32) + qk = ct.mma(q_mma, k_mma, qk) # [BLOCK_M, BLOCK_N] + # 2. Apply neighborhood mask inline + cols = j * BLOCK_N + ct.arange(BLOCK_N, dtype=ct.int32) # [BLOCK_N] + cols = cols[None, :] # [1, BLOCK_N] + + # Neighborhood window: |col - row| <= half_k * dilation + col_lo = rows - half_k * DILATION # [BLOCK_M, 1] + col_hi = rows + half_k * DILATION # [BLOCK_M, 1] + + in_range = (cols >= col_lo) & (cols <= col_hi) # [BLOCK_M, BLOCK_N] + + if DILATION > 1: + # Also require (col - row) % dilation == 0 + rel = cols - rows # [BLOCK_M, BLOCK_N] + valid_dilation = (rel % DILATION) == 0 + in_range = in_range & valid_dilation + + # Out-of-bounds columns (col >= seq_len) are always masked + in_bounds = cols < SEQ_LEN # [1, BLOCK_N] + in_range = in_range & in_bounds + + # Mask out-of-neighborhood entries with -inf (they become 0 after exp). + neg_inf_tile = ct.full((BLOCK_M, BLOCK_N), -math.inf, dtype=ct.float32) + qk = ct.where(in_range, qk, neg_inf_tile) # [BLOCK_M, BLOCK_N] + # 3. Online softmax update (exp2 trick for efficiency) + qk_scaled = qk * scale_log2 # [BLOCK_M, BLOCK_N] + + # m_ij: row-wise max of scaled logits. + # When all entries are masked (-inf), max gives -inf, but m_i starts at -1e30 + # (not -inf), so max(m_i=-1e30, -inf) = -1e30, keeping m_ij finite. + m_ij = max(m_i, ct.max(qk_scaled, axis=-1, keepdims=True)) # [BLOCK_M, 1] + + # p = exp2(qk_scaled - m_ij): masked entries give exp2(-inf - finite) = 0 + p = ct.exp2(qk_scaled - m_ij, flush_to_zero=True) # [BLOCK_M, BLOCK_N] + l_ij = ct.sum(p, axis=-1, keepdims=True) # [BLOCK_M, 1] + + # alpha = exp2(m_i - m_ij): both are finite (m_i starts at -1e30), no NaN. + alpha = ct.exp2(m_i - m_ij, flush_to_zero=True) # [BLOCK_M, 1] + l_i = l_i * alpha + l_ij # [BLOCK_M, 1] + acc = acc * alpha # [BLOCK_M, head_dim] + + # Update m_i + m_i = m_ij + # 4. acc += p @ V[j] + v = ct.load( + value, + index=(batch_id, head_id, j, 0), + shape=(1, 1, BLOCK_N, HEAD_DIM), + latency=4, + ).reshape((BLOCK_N, HEAD_DIM)) # [BLOCK_N, head_dim] + + p_cast = p.astype(query.dtype) + p_mma = ct.astype(p_cast, ct.tfloat32) if p_cast.dtype == ct.float32 else p_cast + v_mma = ct.astype(v, ct.tfloat32) if v.dtype == ct.float32 else v + + acc = ct.mma(p_mma, v_mma, acc) # [BLOCK_M, head_dim] + # 5. Final normalisation and store + acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX) + ct.store( + output, + index=(batch_id, head_id, tile_m, 0), + tile=acc.reshape((1, 1, BLOCK_M, HEAD_DIM)).astype(output.dtype), + ) + # 6. Store LSE = m_i + log2(l_i) for backward (in log2 space) + # Scatter into 1-D padded LSE array of size B*H*padded_seq. + # padded_seq = cdiv(seq_len, BLOCK_M) * BLOCK_M (passed via lse_stride). + # For rows beyond seq_len (padding), write +1e30 so that backward + # reconstructs p ~ 0 for those ghost rows, keeping gradients clean. + lse_tile = m_i + ct.log2(l_i) # [BLOCK_M, 1] + lse_tile = lse_tile.reshape((BLOCK_M,)) # [BLOCK_M] + lse_offsets = ct.arange(BLOCK_M, dtype=ct.int32) + row_ids = tile_m * BLOCK_M + lse_offsets # [BLOCK_M] — always < padded_seq + # Out-of-bounds rows get +1e30 so backward reconstructs p ~ 0 + lse_safe = ct.where( + row_ids < SEQ_LEN, + lse_tile, + ct.full((BLOCK_M,), 1e30, dtype=ct.float32), + ) + # lse_stride = padded_seq_len (passed as a ConstInt), guaranteed >= row_ids + lse_indices = batch_head * LSE_STRIDE + row_ids + ct.scatter(lse_cache, lse_indices, lse_safe) + + +# cuTile kernel: backward preprocess — delta_cache[i] = rowsum(O * grad_output) +@ct.kernel(occupancy=1) +def _fna_bwd_preprocess_kernel( + output, + grad_output, + delta_cache, # 1-D padded [B*H*padded_seq] + NUM_HEADS: int, + SEQ_LEN: int, + HEAD_DIM: ConstInt, + BLOCK_M: ConstInt, + LSE_STRIDE: int, # padded_seq_len = cdiv(seq_len, BLOCK_M) * BLOCK_M +): + """ + Compute delta_cache[i] = sum_d(O[i,d] * grad_output[i,d]) for each position i. + + Grid: (batch*heads, cdiv(seq_len, BLOCK_M), 1) + O, grad_output: [B, H, S, D] delta_cache: [B*H*padded_seq] 1-D padded float32 + """ + batch_head = ct.bid(0) + tile_m = ct.bid(1) + + batch_id = batch_head // NUM_HEADS + head_id = batch_head % NUM_HEADS + + o_tile = ( + ct.load(output, index=(batch_id, head_id, tile_m, 0), shape=(1, 1, BLOCK_M, HEAD_DIM)) + .reshape((BLOCK_M, HEAD_DIM)) + .astype(ct.float32) + ) + + do_tile = ( + ct.load(grad_output, index=(batch_id, head_id, tile_m, 0), shape=(1, 1, BLOCK_M, HEAD_DIM)) + .reshape((BLOCK_M, HEAD_DIM)) + .astype(ct.float32) + ) + + delta = ct.sum(o_tile * do_tile, axis=-1, keepdims=False) # [BLOCK_M] + + # Zero out contributions from rows beyond seq_len (padding rows) + delta_offsets = ct.arange(BLOCK_M, dtype=ct.int32) + row_ids = tile_m * BLOCK_M + delta_offsets # [BLOCK_M] + delta_safe = ct.where( + row_ids < SEQ_LEN, + delta, + ct.zeros((BLOCK_M,), dtype=ct.float32), + ) + delta_indices = batch_head * LSE_STRIDE + row_ids + ct.scatter(delta_cache, delta_indices, delta_safe) + + +# cuTile kernel: fused backward grad_key + grad_value +# Each block owns one K/V tile (tile_n) and loops over all Q tiles, +# reconstructing P from Q, K, and LSE (no stored attn_weights). +@ct.kernel(occupancy=1) +def _fna_bwd_dkdv_kernel( + query, + key, + value, + grad_output, + grad_key, + grad_value, + lse_cache, # 1-D padded [B*H*padded_seq] float32 + delta_cache, # 1-D padded [B*H*padded_seq] float32 + qk_scale: float, + NUM_HEADS: int, + SEQ_LEN: int, + HEAD_DIM: ConstInt, + KERNEL_SIZE: ConstInt, + DILATION: ConstInt, + BLOCK_M: ConstInt, # Q-tile size (inner loop) + BLOCK_N: ConstInt, # K/V-tile size (this block) + LSE_STRIDE: int, # padded_seq_len = cdiv(seq_len, BLOCK_M) * BLOCK_M +): + """ + Fused backward: grad_key and grad_value for one K/V tile. + + Grid: (batch*heads, cdiv(seq_len, BLOCK_N), 1) + + For each KV tile n: + Loop over Q tiles m: + reconstruct P[m, n] = exp2(QK[m,n] * scale * INV_LOG2 - LSE[m]) + apply neighborhood mask (set P = 0 where masked) + grad_value[n] += P^T @ grad_output[m] + dP[m,n] = grad_output[m] @ V[n]^T + dS[m,n] = P * (dP - delta_cache[m]) (softmax bwd, masked positions stay 0) + grad_key[n] += dS[m,n]^T @ Q[m] * scale + """ + batch_head = ct.bid(0) + tile_n = ct.bid(1) # this block's KV tile + + batch_id = batch_head // NUM_HEADS + head_id = batch_head % NUM_HEADS + + scale_log2 = qk_scale * INV_LOG_2 + half_k = KERNEL_SIZE // 2 + + # Accumulate grad_key and grad_value in float32 + dk_acc = ct.full((BLOCK_N, HEAD_DIM), 0.0, dtype=ct.float32) + dv_acc = ct.full((BLOCK_N, HEAD_DIM), 0.0, dtype=ct.float32) + + # Load K and V tiles for this block (reused across all Q-tile iterations) + k = ct.load(key, index=(batch_id, head_id, tile_n, 0), shape=(1, 1, BLOCK_N, HEAD_DIM), latency=2).reshape( + (BLOCK_N, HEAD_DIM) + ) + v = ct.load(value, index=(batch_id, head_id, tile_n, 0), shape=(1, 1, BLOCK_N, HEAD_DIM), latency=4).reshape( + (BLOCK_N, HEAD_DIM) + ) + + # Column indices for this KV tile: [1, BLOCK_N] + cols = tile_n * BLOCK_N + ct.arange(BLOCK_N, dtype=ct.int32) # [BLOCK_N] + cols_bcast = cols[None, :] # [1, BLOCK_N] + + # Scatter indices for LSE / delta_cache lookup + lse_delta_offsets = ct.arange(BLOCK_M, dtype=ct.int32) + + # This block owns one KV tile; only Q tiles whose neighborhood window reaches columns + # [n0, n0 + BLOCK_N - 1] contribute, i.e. rows in [n0 - band_w, n0 + BLOCK_N - 1 + band_w]. + # Band the Q-tile loop to that row range; the in-loop mask still guards edge columns. + band_w = half_k * DILATION + n0 = tile_n * BLOCK_N + band_lo = max(0, n0 - band_w) + band_hi = min(SEQ_LEN, n0 + BLOCK_N + band_w) + m_start = band_lo // BLOCK_M + m_end = ct.cdiv(band_hi, BLOCK_M) + for m_idx in range(m_start, m_end): + # Row indices for this Q tile: [BLOCK_M, 1] + rows = m_idx * BLOCK_M + ct.arange(BLOCK_M, dtype=ct.int32) # [BLOCK_M] + rows_bcast = rows[:, None] # [BLOCK_M, 1] + + # Load Q and grad_output tiles + q = ct.load(query, index=(batch_id, head_id, m_idx, 0), shape=(1, 1, BLOCK_M, HEAD_DIM), latency=2).reshape( + (BLOCK_M, HEAD_DIM) + ) + do = ct.load( + grad_output, index=(batch_id, head_id, m_idx, 0), shape=(1, 1, BLOCK_M, HEAD_DIM), latency=2 + ).reshape((BLOCK_M, HEAD_DIM)) + + # Gather LSE and delta_cache for this Q tile (using padded stride lse_stride) + lse_indices = batch_head * LSE_STRIDE + m_idx * BLOCK_M + lse_delta_offsets + lse = ct.gather(lse_cache, lse_indices) # [BLOCK_M] + delta = ct.gather(delta_cache, lse_indices) # [BLOCK_M] + lse = lse[:, None] # [BLOCK_M, 1] + delta = delta[:, None] # [BLOCK_M, 1] + + # Compute QK^T: [BLOCK_M, BLOCK_N] + k_t = k.permute((1, 0)) # [head_dim, BLOCK_N] + q_cast = q.astype(ct.float32) + k_t_cast = k_t.astype(ct.float32) + qk = ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32) + qk = ct.mma( + ct.astype(q_cast, ct.tfloat32) if q_cast.dtype == ct.float32 else q_cast, + ct.astype(k_t_cast, ct.tfloat32) if k_t_cast.dtype == ct.float32 else k_t_cast, + qk, + ) + + # Reconstruct P = exp2(QK * scale_log2 - LSE) + p = ct.exp2(qk * scale_log2 - lse, flush_to_zero=True) # [BLOCK_M, BLOCK_N] + + # Apply neighborhood mask: positions outside window become P=0 + col_lo = rows_bcast - half_k * DILATION # [BLOCK_M, 1] + col_hi = rows_bcast + half_k * DILATION # [BLOCK_M, 1] + in_range = (cols_bcast >= col_lo) & (cols_bcast <= col_hi) # [BLOCK_M, BLOCK_N] + if DILATION > 1: + rel = cols_bcast - rows_bcast + in_range = in_range & ((rel % DILATION) == 0) + in_bounds = cols_bcast < SEQ_LEN + in_range = in_range & in_bounds + p = ct.where(in_range, p, ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32)) + + # grad_value += P^T @ grad_output (P: [BLOCK_M, BLOCK_N], grad_output: [BLOCK_M, head_dim]) + p_t = p.permute((1, 0)) # [BLOCK_N, BLOCK_M] + do_cast = do.astype(ct.float32) + p_t_mma = ct.astype(p_t, ct.tfloat32) if p_t.dtype == ct.float32 else p_t + do_mma = ct.astype(do_cast, ct.tfloat32) if do_cast.dtype == ct.float32 else do_cast + dv_acc = ct.mma(p_t_mma, do_mma, dv_acc) # [BLOCK_N, head_dim] + + # dP = grad_output @ V^T: [BLOCK_M, BLOCK_N] + v_t = v.permute((1, 0)) # [head_dim, BLOCK_N] + v_t_cast = v_t.astype(ct.float32) + dp = ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32) + dp = ct.mma( + ct.astype(do_cast, ct.tfloat32) if do_cast.dtype == ct.float32 else do_cast, + ct.astype(v_t_cast, ct.tfloat32) if v_t_cast.dtype == ct.float32 else v_t_cast, + dp, + ) + + # dS = P * (dP - delta_cache): softmax backward, masked positions stay 0 + ds = p * (dp - delta) # [BLOCK_M, BLOCK_N] + # Re-apply mask: out-of-bounds V/K loads can return NaN, giving 0*NaN=NaN. + # Explicitly zero masked positions to prevent NaN from propagating into grad_key/grad_value. + ds = ct.where(in_range, ds, ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32)) + + # grad_key += dS^T @ Q * scale + ds_t = ds.permute((1, 0)) # [BLOCK_N, BLOCK_M] + q_cast2 = q.astype(ct.float32) + ds_t_mma = ct.astype(ds_t, ct.tfloat32) if ds_t.dtype == ct.float32 else ds_t + q_mma2 = ct.astype(q_cast2, ct.tfloat32) if q_cast2.dtype == ct.float32 else q_cast2 + dk_acc = ct.mma(ds_t_mma, q_mma2, dk_acc) # [BLOCK_N, head_dim] + + dk_acc = dk_acc * qk_scale + + ct.store( + grad_key, + index=(batch_id, head_id, tile_n, 0), + tile=dk_acc.reshape((1, 1, BLOCK_N, HEAD_DIM)).astype(grad_key.dtype), + ) + ct.store( + grad_value, + index=(batch_id, head_id, tile_n, 0), + tile=dv_acc.reshape((1, 1, BLOCK_N, HEAD_DIM)).astype(grad_value.dtype), + ) + + +# cuTile kernel: fused backward grad_query +# Each block owns one Q tile (tile_m) and loops over all K/V tiles, +# reconstructing P from Q, K, and LSE (no stored attn_weights). +@ct.kernel(occupancy=1) +def _fna_bwd_dq_kernel( + query, + key, + value, + grad_output, + grad_query, + lse_cache, # 1-D padded [B*H*padded_seq] float32 + delta_cache, # 1-D padded [B*H*padded_seq] float32 + qk_scale: float, + NUM_HEADS: int, + SEQ_LEN: int, + HEAD_DIM: ConstInt, + KERNEL_SIZE: ConstInt, + DILATION: ConstInt, + BLOCK_M: ConstInt, # Q-tile size (this block) + BLOCK_N: ConstInt, # K/V-tile size (inner loop) + LSE_STRIDE: int, # padded_seq_len = cdiv(seq_len, BLOCK_M) * BLOCK_M +): + """ + Fused backward: grad_query for one Q tile. + + Grid: (batch*heads, cdiv(seq_len, BLOCK_M), 1) + + For each Q tile m: + Load Q[m], grad_output[m], LSE[m], delta_cache[m] + Loop over K/V tiles n: + reconstruct P[m,n] + apply neighborhood mask + dP[m,n] = grad_output[m] @ V[n]^T + dS[m,n] = P * (dP - delta_cache[m]) + grad_query[m] += dS[m,n] @ K[n] * scale + """ + batch_head = ct.bid(0) + tile_m = ct.bid(1) + + batch_id = batch_head // NUM_HEADS + head_id = batch_head % NUM_HEADS + + scale_log2 = qk_scale * INV_LOG_2 + half_k = KERNEL_SIZE // 2 + + # Accumulate grad_query in float32 + dq_acc = ct.full((BLOCK_M, HEAD_DIM), 0.0, dtype=ct.float32) + + # Pre-compute indices (arithmetic only, no memory access) + lse_delta_offsets = ct.arange(BLOCK_M, dtype=ct.int32) + lse_indices = batch_head * LSE_STRIDE + tile_m * BLOCK_M + lse_delta_offsets + + # Row indices for this Q tile: [BLOCK_M, 1] + rows = tile_m * BLOCK_M + ct.arange(BLOCK_M, dtype=ct.int32) # [BLOCK_M] + rows_bcast = rows[:, None] # [BLOCK_M, 1] + + # Only KV tiles intersecting this Q tile's neighborhood window contribute; band the loop + # like the forward pass. The in-loop mask still zeroes out-of-window entries. + band_w = half_k * DILATION + m0 = tile_m * BLOCK_M + band_lo = max(0, m0 - band_w) + band_hi = min(SEQ_LEN, m0 + BLOCK_M + band_w) + n_start = band_lo // BLOCK_N + n_end = ct.cdiv(band_hi, BLOCK_N) + for n_idx in range(n_start, n_end): + # All tensor loads are inside the loop to avoid a cuTile warp-specialization + # deadlock: having TMA loads (Q, grad_output) outside a ct.mma accumulation loop + # causes a kernel hang for loop counts >= 4 (seq_len >= 512 with BLOCK_N=128). + # Q, grad_output, LSE, delta_cache are per-Q-tile constants reloaded redundantly each iteration. + q = ct.load(query, index=(batch_id, head_id, tile_m, 0), shape=(1, 1, BLOCK_M, HEAD_DIM), latency=2).reshape( + (BLOCK_M, HEAD_DIM) + ) + do = ct.load( + grad_output, index=(batch_id, head_id, tile_m, 0), shape=(1, 1, BLOCK_M, HEAD_DIM), latency=2 + ).reshape((BLOCK_M, HEAD_DIM)) + lse = ct.gather(lse_cache, lse_indices)[:, None] # [BLOCK_M, 1] + delta = ct.gather(delta_cache, lse_indices)[:, None] # [BLOCK_M, 1] + + # Column indices for this KV tile: [1, BLOCK_N] + cols = n_idx * BLOCK_N + ct.arange(BLOCK_N, dtype=ct.int32) # [BLOCK_N] + cols_bcast = cols[None, :] # [1, BLOCK_N] + + # Load K and V for this KV tile + k = ct.load(key, index=(batch_id, head_id, n_idx, 0), shape=(1, 1, BLOCK_N, HEAD_DIM), latency=2).reshape( + (BLOCK_N, HEAD_DIM) + ) + v = ct.load(value, index=(batch_id, head_id, n_idx, 0), shape=(1, 1, BLOCK_N, HEAD_DIM), latency=4).reshape( + (BLOCK_N, HEAD_DIM) + ) + + # Compute QK^T: [BLOCK_M, BLOCK_N] + k_t = k.permute((1, 0)) # [head_dim, BLOCK_N] + q_cast = q.astype(ct.float32) + k_t_cast = k_t.astype(ct.float32) + qk = ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32) + qk = ct.mma( + ct.astype(q_cast, ct.tfloat32) if q_cast.dtype == ct.float32 else q_cast, + ct.astype(k_t_cast, ct.tfloat32) if k_t_cast.dtype == ct.float32 else k_t_cast, + qk, + ) + + # Reconstruct P = exp2(QK * scale_log2 - LSE) + p = ct.exp2(qk * scale_log2 - lse, flush_to_zero=True) # [BLOCK_M, BLOCK_N] + + # Apply neighborhood mask: positions outside window become P=0 + col_lo = rows_bcast - half_k * DILATION + col_hi = rows_bcast + half_k * DILATION + in_range = (cols_bcast >= col_lo) & (cols_bcast <= col_hi) + if DILATION > 1: + rel = cols_bcast - rows_bcast + in_range = in_range & ((rel % DILATION) == 0) + in_bounds = cols_bcast < SEQ_LEN + in_range = in_range & in_bounds + p = ct.where(in_range, p, ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32)) + + # dP = grad_output @ V^T: [BLOCK_M, BLOCK_N] + v_t = v.permute((1, 0)) # [head_dim, BLOCK_N] + do_cast = do.astype(ct.float32) + v_t_cast = v_t.astype(ct.float32) + dp = ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32) + dp = ct.mma( + ct.astype(do_cast, ct.tfloat32) if do_cast.dtype == ct.float32 else do_cast, + ct.astype(v_t_cast, ct.tfloat32) if v_t_cast.dtype == ct.float32 else v_t_cast, + dp, + ) + + # dS = P * (dP - delta_cache): softmax backward, masked positions stay 0 + ds = p * (dp - delta) # [BLOCK_M, BLOCK_N] + # Re-apply mask: out-of-bounds K/V loads can return NaN, giving 0*NaN=NaN. + # Explicitly zero masked positions to prevent NaN from propagating into grad_query. + ds = ct.where(in_range, ds, ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32)) + + # grad_query += dS @ K * scale + k_cast = k.astype(ct.float32) + ds_mma = ct.astype(ds, ct.tfloat32) if ds.dtype == ct.float32 else ds + k_mma = ct.astype(k_cast, ct.tfloat32) if k_cast.dtype == ct.float32 else k_cast + dq_acc = ct.mma(ds_mma, k_mma, dq_acc) # [BLOCK_M, head_dim] + + dq_acc = dq_acc * qk_scale + + ct.store( + grad_query, + index=(batch_id, head_id, tile_m, 0), + tile=dq_acc.reshape((1, 1, BLOCK_M, HEAD_DIM)).astype(grad_query.dtype), + ) + + +# --------------------------------------------------------------------------- +# Autotuning helpers for the fused forward kernel +# --------------------------------------------------------------------------- + + +def _fused_fwd_autotune_configs(): + """ + Search space for the fused forward kernel. + + Sweep BLOCK_M in {64, 128} x BLOCK_N in {32, 64, 128} x occupancy in {1, 2}. + Total: 2 x 3 x 2 = 12 configs. + + Design notes: + - BLOCK_M=256 removed: too many registers at occupancy=2; keep at most 128. + - BLOCK_N=32 included: for a ~7-key window most of BLOCK_N=128 is masked, + so a narrower N-tile reduces useless MMA work. + - BLOCK_N < 32 not supported by ct.mma on sm_100. + - occupancy=1 allows larger tiles without register spill; + occupancy=2 increases parallelism for small shapes. + """ + for bm in [64, 128]: + for bn in [32, 64, 128]: + for occ in [1, 2]: + yield SimpleNamespace(BLOCK_M=bm, BLOCK_N=bn, occupancy=occ) + + +# Module-level cache: (batch*heads, seq_len, head_dim, dtype, device_str) -> (cfg, tuned_kernel) +_fwd_autotune_cache: dict = {} + + +def _fused_fwd_autotune( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + lse: torch.Tensor, + scale: float, + num_heads: int, + seq_len: int, + head_dim: int, + kernel_size: int, + dilation: int, + stream, +) -> SimpleNamespace: + """ + Run exhaustive_search once per problem shape, return the best config. + The chosen config (BLOCK_M, BLOCK_N, occupancy) and the associated + replace_hints()-patched kernel are cached at module level so subsequent + calls go straight to ct.launch. + """ + batch_heads = query.shape[0] * query.shape[1] + cache_key = (batch_heads, head_dim, query.dtype, query.device) + + if cache_key not in _fwd_autotune_cache: + configs = list(_fused_fwd_autotune_configs()) + if os.environ.get("DISABLE_AUTOTUNE", "0") == "1": + configs = configs[:1] + + def grid_fn(cfg): + return ( + batch_heads, + (seq_len + cfg.BLOCK_M - 1) // cfg.BLOCK_M, + 1, + ) + + def args_fn(cfg): + padded_seq = ((seq_len + cfg.BLOCK_M - 1) // cfg.BLOCK_M) * cfg.BLOCK_M + # Resize lse to match BLOCK_M for this config (autotuner calls args_fn per config). + # No pre-fill: the kernel scatters every slot (real rows = m+log2(l); ghost/padding + # rows = 1e30 via the in-kernel ct.where), so torch.empty is safe and drops a + # redundant device-side fill kernel. Timing-only buffer; full coverage either way. + lse_cfg = torch.empty( + (batch_heads * padded_seq,), + device=query.device, + dtype=torch.float32, + ) + return ( + query, + key, + value, + output, + lse_cfg, + scale, + num_heads, + seq_len, + head_dim, + kernel_size, + dilation, + cfg.BLOCK_M, + cfg.BLOCK_N, + padded_seq, # lse_stride + ) + + def hints_fn(cfg): + return {"occupancy": cfg.occupancy} + + with ct.compiler_timeout(30): + result = exhaustive_search(configs, stream, grid_fn, _fna_fused_forward_kernel, args_fn, hints_fn) + best_cfg = result.best.config + tuned_kernel = _fna_fused_forward_kernel.replace_hints(occupancy=best_cfg.occupancy) + _fwd_autotune_cache[cache_key] = (best_cfg, tuned_kernel) + + return _fwd_autotune_cache[cache_key] + + +def _fused_neighborhood_attention_fused_forward_ct( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kernel_size: int = 7, + dilation: int = 1, + scale: float = None, +) -> tuple: + """Flash-style cuTile forward pass. Returns ``(output, lse)`` for the backward pass.""" + batch_size, num_heads, seq_len, head_dim = query.shape + + if scale is None: + scale = 1.0 / math.sqrt(head_dim) + + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + output = torch.empty_like(query) + stream = torch.cuda.current_stream() + + # padded_seq depends on the tuned BLOCK_M; allocate a placeholder lse + # that the autotuner args_fn will resize per config. + placeholder_lse = torch.empty(0, device=query.device, dtype=torch.float32) + best_cfg, tuned_kernel = _fused_fwd_autotune( + query, + key, + value, + output, + placeholder_lse, + scale, + num_heads, + seq_len, + head_dim, + kernel_size, + dilation, + stream, + ) + + fwd_block_m = best_cfg.BLOCK_M + fwd_block_n = best_cfg.BLOCK_N + padded_seq = ((seq_len + fwd_block_m - 1) // fwd_block_m) * fwd_block_m + # LSE: 1-D padded [B*H*padded_seq] in float32 (m + log2(l) from online softmax). + # No pre-fill: the forward kernel scatters EVERY slot (real rows = m+log2(l); ghost/padding + # rows = 1e30 via the in-kernel ct.where, so backward still reads 1e30 => p ~ 0). torch.empty + # is safe and drops a redundant device-side fill kernel on every forward call. + lse = torch.empty( + (batch_size * num_heads * padded_seq,), + device=query.device, + dtype=torch.float32, + ) + + grid = (batch_size * num_heads, (seq_len + fwd_block_m - 1) // fwd_block_m, 1) + ct.launch( + stream, + grid, + tuned_kernel, + ( + query, + key, + value, + output, + lse, + scale, + num_heads, + seq_len, + head_dim, + kernel_size, + dilation, + fwd_block_m, + fwd_block_n, + padded_seq, + ), + ) + return output, lse + + +# Autograd Function: cuTile forward + cuTile backward +def _ensure_contiguous_ct(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + def _c(x): + return x.contiguous() if isinstance(x, torch.Tensor) else x + + return fn(ctx, *[_c(a) for a in args], **{k: _c(v) for k, v in kwargs.items()}) + + return wrapper + + +class _FusedNeighborhoodAttentionFunctionCT(torch.autograd.Function): + @staticmethod + @_ensure_contiguous_ct + def forward(ctx, query, key, value, kernel_size=7, dilation=1, scale=None): + # Use the fused flash-style forward (no O(S^2) HBM materialisation). + # Also saves LSE (log-sum-exp, shape [B*H*S]) for the backward pass. + output, lse = _fused_neighborhood_attention_fused_forward_ct(query, key, value, kernel_size, dilation, scale) + # Save Q, K, V, and LSE — no O(S^2) attn_weights saved. + ctx.save_for_backward(query, key, value, output, lse) + ctx.kernel_size = kernel_size + ctx.dilation = dilation + ctx.scale = scale + + # Record which BLOCK_M was actually used in forward so the backward can + # reconstruct the correct LSE stride (padded_seq = cdiv(seq_len, fwd_block_m)*fwd_block_m). + # _fused_neighborhood_attention_fused_forward_ct always populates _fwd_autotune_cache. + batch_size, num_heads, seq_len, head_dim = query.shape + cache_key = (batch_size * num_heads, head_dim, query.dtype, query.device) + cfg, _ = _fwd_autotune_cache[cache_key] + ctx.fwd_block_m = cfg.BLOCK_M + ctx.fwd_block_n = cfg.BLOCK_N + + return output + + @staticmethod + @_ensure_contiguous_ct + def backward(ctx, grad_output): + query, key, value, output, lse = ctx.saved_tensors + grad_output = grad_output.contiguous() + + batch_size, num_heads, seq_len, head_dim = query.shape + scale = ctx.scale if ctx.scale is not None else 1.0 / math.sqrt(head_dim) + kernel_size = ctx.kernel_size + dilation = ctx.dilation + + # Use the BLOCK_M/N that forward actually used so LSE strides are consistent. + fwd_block_m = ctx.fwd_block_m + fwd_block_n = ctx.fwd_block_n + + # padded_seq must match the stride used in forward to store LSE + padded_seq = ((seq_len + fwd_block_m - 1) // fwd_block_m) * fwd_block_m + + stream = torch.cuda.current_stream() + # delta_cache uses the same padded stride as LSE (ghost-row slots = 0) + delta = torch.zeros(batch_size * num_heads * padded_seq, device=query.device, dtype=torch.float32) + grid_pre = ( + batch_size * num_heads, + (seq_len + fwd_block_m - 1) // fwd_block_m, + 1, + ) + ct.launch( + stream, + grid_pre, + _fna_bwd_preprocess_kernel, + (output, grad_output, delta, num_heads, seq_len, head_dim, fwd_block_m, padded_seq), + ) + grad_key = torch.zeros_like(key) + grad_value = torch.zeros_like(value) + grid_dkdv = ( + batch_size * num_heads, + (seq_len + fwd_block_n - 1) // fwd_block_n, + 1, + ) + ct.launch( + stream, + grid_dkdv, + _fna_bwd_dkdv_kernel, + ( + query, + key, + value, + grad_output, + grad_key, + grad_value, + lse, + delta, + scale, + num_heads, + seq_len, + head_dim, + kernel_size, + dilation, + fwd_block_m, # BLOCK_M for inner Q-tile loop + fwd_block_n, # BLOCK_N for this block's KV tile + padded_seq, # lse_stride + ), + ) + grad_query = torch.zeros_like(query) + grid_dq = ( + batch_size * num_heads, + (seq_len + fwd_block_m - 1) // fwd_block_m, + 1, + ) + ct.launch( + stream, + grid_dq, + _fna_bwd_dq_kernel, + ( + query, + key, + value, + grad_output, + grad_query, + lse, + delta, + scale, + num_heads, + seq_len, + head_dim, + kernel_size, + dilation, + fwd_block_m, # BLOCK_M for this block's Q tile + fwd_block_n, # BLOCK_N for inner KV-tile loop + padded_seq, # lse_stride + ), + ) + + return grad_query, grad_key, grad_value, None, None, None + + +# TileGym registered implementation +@register_impl("liger.fused_neighborhood_attention", backend="cutile") +def fused_neighborhood_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kernel_size: int = 7, + dilation: int = 1, + scale: float = None, +) -> torch.Tensor: + """ + Fused neighborhood attention — full cuTile forward and backward. + + Args: + query: [batch, heads, seq_len, head_dim] + key: [batch, heads, seq_len, head_dim] + value: [batch, heads, seq_len, head_dim] + kernel_size: neighborhood window size (must be odd) + dilation: dilation factor for neighborhood window + scale: attention scale factor (default: 1/sqrt(head_dim)) + + Returns: + output tensor of shape [batch, heads, seq_len, head_dim] + """ + inference = (not torch.is_grad_enabled()) or not (query.requires_grad or key.requires_grad or value.requires_grad) + if inference: + output, _ = _fused_neighborhood_attention_fused_forward_ct(query, key, value, kernel_size, dilation, scale) + return output + return _FusedNeighborhoodAttentionFunctionCT.apply(query, key, value, kernel_size, dilation, scale) diff --git a/src/tilegym/suites/liger/cutile/group_norm.py b/src/tilegym/suites/liger/cutile/group_norm.py new file mode 100644 index 00000000..96f8f346 --- /dev/null +++ b/src/tilegym/suites/liger/cutile/group_norm.py @@ -0,0 +1,314 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +""" +Group Normalization kernel (CuTile backend). + +Forward: 2D grid (batch_size, num_groups). Each block computes mean/variance + over all elements of a group, then normalizes with per-channel W, B. + mean_stats and rstd_stats are stored for the backward pass. + +Backward: 2D grid (batch_size, num_groups). Each block computes: + - dw_partial[batch_idx, *channels_in_group] — partial gradient for W + - db_partial[batch_idx, *channels_in_group] — partial gradient for B + - DX for its (batch, group) slice + The host reduces dw = dw_partial.sum(dim=0) and db = db_partial.sum(dim=0). +""" + +import cuda.tile as ct +import torch + +from tilegym.backend import register_impl + +from .utils import next_power_of_2 + +MAX_FUSED_SIZE = 65536 + + +@ct.kernel +def _group_norm_fwd_kernel( + x_input, # (batch_size * num_channels, hidden_size_per_channel) + y_output, # (batch_size * num_channels, hidden_size_per_channel) + weight, # (num_channels,) + bias, # (num_channels,) + mean_stats, # (batch_size * num_groups,) — indexed as [batch*num_groups + group] + rstd_stats, # (batch_size * num_groups,) + NUM_CHANNELS: ct.Constant[int], + NUM_GROUPS: ct.Constant[int], + CHANNELS_PER_GROUP: ct.Constant[int], + TOTAL_HIDDEN_SIZE: ct.Constant[int], # hidden_size_per_channel + eps, + BLOCK_SIZE: ct.Constant[int], +): + """ + Group norm forward. + + Grid: (batch_size, num_groups, 1). + One block per (batch, group): computes mean/variance over all channels in the + group, then normalizes with per-channel W and B. + """ + batch_idx = ct.bid(0) + group_idx = ct.bid(1) + + group_row = batch_idx * NUM_GROUPS + group_idx # scalar index for mean_stats/rstd_stats + + # Total elements per group (for normalization denominator) + N = CHANNELS_PER_GROUP * TOTAL_HIDDEN_SIZE + inv_N = 1.0 / N # pre-compute reciprocal to avoid two divisions + + # num_h_chunks is constant for all channels — hoist outside both loops + num_h_chunks = (TOTAL_HIDDEN_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE + + # OOB positions (col_idx >= hidden_size) get padding_value=0 → contribute 0 (correct) + sum_tile = ct.full((BLOCK_SIZE,), 0.0, dtype=ct.float32) + sum_sq_tile = ct.full((BLOCK_SIZE,), 0.0, dtype=ct.float32) + + for c_in_group in range(CHANNELS_PER_GROUP): + channel_idx = group_idx * CHANNELS_PER_GROUP + c_in_group + row_idx = batch_idx * NUM_CHANNELS + channel_idx + + for hi in range(num_h_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + hi * BLOCK_SIZE + x_tile = ct.astype( + ct.gather(x_input, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + sum_tile = sum_tile + x_tile + sum_sq_tile = sum_sq_tile + x_tile * x_tile + + s = ct.sum(sum_tile, 0, keepdims=False) # scalar + sq = ct.sum(sum_sq_tile, 0, keepdims=False) # scalar + mean = s * inv_N + variance = sq * inv_N - mean * mean + rstd = ct.rsqrt(variance + eps) + + # Store mean and rstd + ct.scatter(mean_stats, group_row, ct.astype(mean, mean_stats.dtype)) + ct.scatter(rstd_stats, group_row, ct.astype(rstd, rstd_stats.dtype)) + + for c_in_group in range(CHANNELS_PER_GROUP): + channel_idx = group_idx * CHANNELS_PER_GROUP + c_in_group + row_idx = batch_idx * NUM_CHANNELS + channel_idx + + w_scalar = ct.astype(ct.load(weight, channel_idx, shape=()), ct.float32) + b_scalar = ct.astype(ct.load(bias, channel_idx, shape=()), ct.float32) + + for hi in range(num_h_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + hi * BLOCK_SIZE + x_tile = ct.astype( + ct.gather(x_input, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + y_tile = (x_tile - mean) * rstd * w_scalar + b_scalar + ct.scatter(y_output, (row_idx, col_idx), ct.astype(y_tile, y_output.dtype), check_bounds=True) + + +@ct.kernel +def _group_norm_bwd_kernel( + x_input, # (batch_size * num_channels, hidden_size_per_channel) + upstream, # (batch_size * num_channels, hidden_size_per_channel) upstream gradient + weight, # (num_channels,) + mean_stats, # (batch_size * num_groups,) + rstd_stats, # (batch_size * num_groups,) + dx_output, # (batch_size * num_channels, hidden_size_per_channel) output gradient + dw_partial, # (batch_size, num_channels) partial weight gradient — host does .sum(dim=0) + db_partial, # (batch_size, num_channels) partial bias gradient — host does .sum(dim=0) + NUM_CHANNELS: ct.Constant[int], + NUM_GROUPS: ct.Constant[int], + CHANNELS_PER_GROUP: ct.Constant[int], + TOTAL_HIDDEN_SIZE: ct.Constant[int], # hidden_size_per_channel + BLOCK_SIZE: ct.Constant[int], +): + """ + Group norm backward. + + Grid: (batch_size, num_groups, 1). + Each block computes DX for its (batch, group) slice and writes partial + dw and db for the channels it owns. + """ + batch_idx = ct.bid(0) + group_idx = ct.bid(1) + + group_row = batch_idx * NUM_GROUPS + group_idx + + mean = ct.astype(ct.load(mean_stats, group_row, shape=()), ct.float32) + rstd = ct.astype(ct.load(rstd_stats, group_row, shape=()), ct.float32) + + N = CHANNELS_PER_GROUP * TOTAL_HIDDEN_SIZE + inv_N = 1.0 / N # pre-compute reciprocal: one multiply instead of two scalar divisions + + # num_h_chunks is constant for all channels — hoist outside both loops + num_h_chunks = (TOTAL_HIDDEN_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE + + # Pass 1: compute c1, c2 and partial dw, db for each channel in the group + c1_tile = ct.full((BLOCK_SIZE,), 0.0, dtype=ct.float32) + c2_tile = ct.full((BLOCK_SIZE,), 0.0, dtype=ct.float32) + + for c_in_group in range(CHANNELS_PER_GROUP): + channel_idx = group_idx * CHANNELS_PER_GROUP + c_in_group + row_idx = batch_idx * NUM_CHANNELS + channel_idx + + w_scalar = ct.astype(ct.load(weight, channel_idx, shape=()), ct.float32) + + dW_acc_tile = ct.full((BLOCK_SIZE,), 0.0, dtype=ct.float32) + dB_acc_tile = ct.full((BLOCK_SIZE,), 0.0, dtype=ct.float32) + + for hi in range(num_h_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + hi * BLOCK_SIZE + x_tile = ct.astype( + ct.gather(x_input, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + upstream_tile = ct.astype( + ct.gather(upstream, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + x_hat = (x_tile - mean) * rstd + wdy = w_scalar * upstream_tile + c1_tile = c1_tile + x_hat * wdy + c2_tile = c2_tile + wdy + dW_acc_tile = dW_acc_tile + upstream_tile * x_hat + dB_acc_tile = dB_acc_tile + upstream_tile + + # Reduce per-channel partial dW, dB to scalar and write to partial buffer + dW_val = ct.sum(dW_acc_tile, 0, keepdims=False) + dB_val = ct.sum(dB_acc_tile, 0, keepdims=False) + ct.scatter(dw_partial, (batch_idx, channel_idx), ct.astype(dW_val, dw_partial.dtype)) + ct.scatter(db_partial, (batch_idx, channel_idx), ct.astype(dB_val, db_partial.dtype)) + + c1 = ct.sum(c1_tile, 0, keepdims=False) * inv_N + c2 = ct.sum(c2_tile, 0, keepdims=False) * inv_N + + # Pass 2: compute DX = (wdy - (x_hat * c1 + c2)) * rstd + for c_in_group in range(CHANNELS_PER_GROUP): + channel_idx = group_idx * CHANNELS_PER_GROUP + c_in_group + row_idx = batch_idx * NUM_CHANNELS + channel_idx + + w_scalar = ct.astype(ct.load(weight, channel_idx, shape=()), ct.float32) + + for hi in range(num_h_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + hi * BLOCK_SIZE + x_tile = ct.astype( + ct.gather(x_input, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + upstream_tile = ct.astype( + ct.gather(upstream, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + x_hat = (x_tile - mean) * rstd + wdy = w_scalar * upstream_tile + dx = (wdy - (x_hat * c1 + c2)) * rstd + ct.scatter(dx_output, (row_idx, col_idx), ct.astype(dx, dx_output.dtype), check_bounds=True) + + +class GroupNormCuTileFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, X, W, B, num_channels, num_groups, eps): + if not X.is_contiguous(): + X = X.contiguous() + if not W.is_contiguous(): + W = W.contiguous() + if not B.is_contiguous(): + B = B.contiguous() + + shape = X.shape + batch_size = shape[0] + channels_per_group = num_channels // num_groups + hidden_size = X.shape[-1] # hidden_size_per_channel (spatial dim) + + BLOCK_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(hidden_size)) + + # Reshape to 2D: (batch_size * num_channels, hidden_size) + X_2d = X.view(batch_size * num_channels, hidden_size).contiguous() + Y_2d = torch.empty_like(X_2d) + mean_stats = torch.empty(batch_size * num_groups, dtype=X.dtype, device=X.device) + rstd_stats = torch.empty(batch_size * num_groups, dtype=X.dtype, device=X.device) + + grid = (batch_size, num_groups, 1) + ct.launch( + torch.cuda.current_stream(), + grid, + _group_norm_fwd_kernel, + ( + X_2d, + Y_2d, + W, + B, + mean_stats, + rstd_stats, + int(num_channels), + int(num_groups), + int(channels_per_group), + int(hidden_size), + float(eps), + int(BLOCK_SIZE), + ), + ) + + ctx.num_channels = num_channels + ctx.num_groups = num_groups + ctx.save_for_backward(X_2d, W, B, mean_stats, rstd_stats) + ctx.shape = shape + ctx.BLOCK_SIZE = BLOCK_SIZE + return Y_2d.view(*shape) + + @staticmethod + def backward(ctx, dY): + X_2d, W, B, mean_stats, rstd_stats = ctx.saved_tensors + num_channels = ctx.num_channels + num_groups = ctx.num_groups + shape = ctx.shape + BLOCK_SIZE = ctx.BLOCK_SIZE + + batch_size = shape[0] + hidden_size = shape[-1] + channels_per_group = num_channels // num_groups + + if not dY.is_contiguous(): + dY = dY.contiguous() + dY_2d = dY.view(batch_size * num_channels, hidden_size).contiguous() + + dx_2d = torch.empty_like(X_2d) + dw_partial = torch.zeros(batch_size, num_channels, dtype=W.dtype, device=W.device) + db_partial = torch.zeros(batch_size, num_channels, dtype=B.dtype, device=B.device) + + grid = (batch_size, num_groups, 1) + ct.launch( + torch.cuda.current_stream(), + grid, + _group_norm_bwd_kernel, + ( + X_2d, + dY_2d, + W, + mean_stats, + rstd_stats, + dx_2d, + dw_partial, + db_partial, + int(num_channels), + int(num_groups), + int(channels_per_group), + int(hidden_size), + int(BLOCK_SIZE), + ), + ) + + dw = dw_partial.sum(dim=0) + db = db_partial.sum(dim=0) + return dx_2d.view(*shape), dw, db, None, None, None + + +@register_impl("liger.group_norm", backend="cutile") +def group_norm( + X: torch.Tensor, + num_channels: int, + num_groups: int, + W: torch.Tensor, + B: torch.Tensor, + eps: float = 1e-5, + **kwargs, +) -> torch.Tensor: + return GroupNormCuTileFunction.apply(X, W, B, num_channels, num_groups, eps) diff --git a/src/tilegym/suites/liger/cutile/kl_div.py b/src/tilegym/suites/liger/cutile/kl_div.py new file mode 100644 index 00000000..22650106 --- /dev/null +++ b/src/tilegym/suites/liger/cutile/kl_div.py @@ -0,0 +1,321 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +""" +KL Divergence loss kernel (CuTile backend). + +Computes KL(y_true || y_pred) where y_pred is in log-space. + +Forward: row-parallel, one block per token row (BT). + - "none" mode: writes per-element loss to (BT, V) output. + - Reduce modes (sum/mean/batchmean): accumulates row sum to (BT,) output via fold trick; + final reduction applied at Python level. + +Backward: row-parallel, computes -y_true (or -exp(y_true)) and scatters to gradient tensor. + +ALIGNED optimization +==================== +When N_FULL_CHUNKS > 0, the first N_FULL_CHUNKS chunks are exactly BLOCK_SIZE elements and +use check_bounds=False (hardware TMA path). Only the final tail chunk (if V % BLOCK_SIZE != 0) +uses check_bounds=True (software bounds checking). This avoids per-element bounds checking +for the majority of chunks and activates the hardware TMA path for maximum bandwidth. +""" + +import cuda.tile as ct +import torch + +from tilegym.backend import register_impl + +from .utils import next_power_of_2 + +MAX_FUSED_SIZE = 4096 # Use 4096 for better pipelining (31 chunks for V=128256) + +_REDUCTION_MODE_NONE = 0 +_REDUCTION_MODE_SUM = 1 +_REDUCTION_MODE_MEAN = 2 +_REDUCTION_MODE_BATCHMEAN = 3 + +_str_to_reduction_mode = { + "none": _REDUCTION_MODE_NONE, + "sum": _REDUCTION_MODE_SUM, + "mean": _REDUCTION_MODE_MEAN, + "batchmean": _REDUCTION_MODE_BATCHMEAN, +} + + +@ct.kernel +def _kldiv_fwd_none_ct( + Y, # (BT, V) log-probs input + GT, # (BT, V) target (probs or log-probs) + LOSS, # (BT, V) per-element output + n_cols: ct.Constant[int], + eps: ct.Constant[float], + BLOCK_SIZE: ct.Constant[int], + LOG_TARGET: ct.Constant[int], + N_FULL_CHUNKS: ct.Constant[int], # number of full (non-tail) chunks; use check_bounds=False +): + """ + Forward kernel for reduction='none'. Grid: (BT, 1, 1). + Writes per-element KL loss to LOSS[row, col]. + Full chunks use check_bounds=False (hardware TMA); tail chunk uses check_bounds=True. + """ + row_idx = ct.bid(0) + # Pre-compute eps_tile once outside the loop (compiler hint: loop-invariant) + eps_tile = ct.full((BLOCK_SIZE,), eps, dtype=ct.float32) + + # Fast path: full aligned chunks (check_bounds=False -> hardware TMA) + for ci in range(N_FULL_CHUNKS): + col_idx = ct.add(ct.arange(BLOCK_SIZE, dtype=ct.int32), ci * BLOCK_SIZE) + y = ct.astype(ct.gather(Y, (row_idx, col_idx), check_bounds=False), ct.float32) + gt = ct.astype(ct.gather(GT, (row_idx, col_idx), check_bounds=False), ct.float32) + + if LOG_TARGET: + loss = ct.exp(gt) * (gt - y) + else: + gt_clipped = ct.maximum(gt, eps_tile) + loss = gt * (ct.log(gt_clipped) - y) + + ct.scatter(LOSS, (row_idx, col_idx), ct.astype(loss, LOSS.dtype), check_bounds=False) + + # Slow path: tail chunk only if V is not exactly divisible by BLOCK_SIZE + if N_FULL_CHUNKS * BLOCK_SIZE < n_cols: + ci = N_FULL_CHUNKS + col_idx = ct.add(ct.arange(BLOCK_SIZE, dtype=ct.int32), ci * BLOCK_SIZE) + y = ct.astype(ct.gather(Y, (row_idx, col_idx), check_bounds=True, padding_value=0.0), ct.float32) + gt = ct.astype(ct.gather(GT, (row_idx, col_idx), check_bounds=True, padding_value=0.0), ct.float32) + + if LOG_TARGET: + loss = ct.exp(gt) * (gt - y) + else: + gt_clipped = ct.maximum(gt, eps_tile) + loss = gt * (ct.log(gt_clipped) - y) + + ct.scatter(LOSS, (row_idx, col_idx), ct.astype(loss, LOSS.dtype), check_bounds=True) + + +@ct.kernel +def _kldiv_fwd_reduce_ct( + Y, # (BT, V) log-probs input + GT, # (BT, V) target (probs or log-probs) + LOSS, # (BT,) per-row sum output + n_cols: ct.Constant[int], + eps: ct.Constant[float], + BLOCK_SIZE: ct.Constant[int], + LOG_TARGET: ct.Constant[int], + N_FULL_CHUNKS: ct.Constant[int], # number of full (non-tail) chunks; use check_bounds=False +): + """ + Forward kernel for sum/mean/batchmean reductions. Grid: (BT, 1, 1). + Computes per-row sum via fold trick and stores to LOSS[row]. + Full chunks use check_bounds=False (hardware TMA); tail chunk uses check_bounds=True. + """ + row_idx = ct.bid(0) + + loss_acc = ct.full((BLOCK_SIZE,), 0.0, dtype=ct.float32) + # Pre-compute eps_tile once outside the loop (compiler hint: loop-invariant) + eps_tile = ct.full((BLOCK_SIZE,), eps, dtype=ct.float32) + + # Fast path: full aligned chunks (check_bounds=False -> hardware TMA) + for ci in range(N_FULL_CHUNKS): + col_idx = ct.add(ct.arange(BLOCK_SIZE, dtype=ct.int32), ci * BLOCK_SIZE) + y = ct.astype(ct.gather(Y, (row_idx, col_idx), check_bounds=False), ct.float32) + gt = ct.astype(ct.gather(GT, (row_idx, col_idx), check_bounds=False), ct.float32) + + if LOG_TARGET: + loss = ct.exp(gt) * (gt - y) + else: + gt_clipped = ct.maximum(gt, eps_tile) + loss = gt * (ct.log(gt_clipped) - y) + + loss_acc = ct.add(loss_acc, loss) + + # Slow path: tail chunk only if V is not exactly divisible by BLOCK_SIZE + if N_FULL_CHUNKS * BLOCK_SIZE < n_cols: + ci = N_FULL_CHUNKS + col_idx = ct.add(ct.arange(BLOCK_SIZE, dtype=ct.int32), ci * BLOCK_SIZE) + y = ct.astype(ct.gather(Y, (row_idx, col_idx), check_bounds=True, padding_value=0.0), ct.float32) + gt = ct.astype(ct.gather(GT, (row_idx, col_idx), check_bounds=True, padding_value=0.0), ct.float32) + + if LOG_TARGET: + loss = ct.exp(gt) * (gt - y) + else: + gt_clipped = ct.maximum(gt, eps_tile) + loss = gt * (ct.log(gt_clipped) - y) + + loss_acc = ct.add(loss_acc, loss) + + row_sum = ct.sum(loss_acc, 0, keepdims=False) + ct.scatter(LOSS, row_idx, ct.astype(row_sum, LOSS.dtype)) + + +@ct.kernel +def _kldiv_bwd_ct( + GT, # (BT, V) target (probs or log-probs) + GRADS, # (BT, V) output gradient + n_cols: ct.Constant[int], + scale: ct.Constant[float], # combined grad_output * reduction_normalizer — fused to avoid extra kernel launches + BLOCK_SIZE: ct.Constant[int], + LOG_TARGET: ct.Constant[int], + N_FULL_CHUNKS: ct.Constant[int], # number of full (non-tail) chunks; use check_bounds=False +): + """ + Backward kernel. Grid: (BT, 1, 1). + Gradient w.r.t. y_pred: -y_true * scale (or -exp(y_true) * scale for log_target). + scale fuses grad_output and the reduction normalizer (1/BT or 1/(BT*V)) so that + both are applied in a single pass, eliminating post-kernel element-wise ops. + Full chunks use check_bounds=False (hardware TMA); tail chunk uses check_bounds=True. + """ + row_idx = ct.bid(0) + + # Fast path: full aligned chunks (check_bounds=False -> hardware TMA) + for ci in range(N_FULL_CHUNKS): + col_idx = ct.add(ct.arange(BLOCK_SIZE, dtype=ct.int32), ci * BLOCK_SIZE) + gt = ct.astype(ct.gather(GT, (row_idx, col_idx), check_bounds=False), ct.float32) + + if LOG_TARGET: + res = -ct.exp(gt) * scale + else: + res = -gt * scale + + ct.scatter(GRADS, (row_idx, col_idx), ct.astype(res, GRADS.dtype), check_bounds=False) + + # Slow path: tail chunk only if V is not exactly divisible by BLOCK_SIZE + if N_FULL_CHUNKS * BLOCK_SIZE < n_cols: + ci = N_FULL_CHUNKS + col_idx = ct.add(ct.arange(BLOCK_SIZE, dtype=ct.int32), ci * BLOCK_SIZE) + gt = ct.astype(ct.gather(GT, (row_idx, col_idx), check_bounds=True, padding_value=0.0), ct.float32) + + if LOG_TARGET: + res = -ct.exp(gt) * scale + else: + res = -gt * scale + + ct.scatter(GRADS, (row_idx, col_idx), ct.astype(res, GRADS.dtype), check_bounds=True) + + +def _kldiv_forward_ct(y_pred, y_true, log_target, reduction, eps): + BT, V = y_pred.shape + BLOCK_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(V)) + reduction_int = _str_to_reduction_mode[reduction] + n_full_chunks = V // BLOCK_SIZE # full aligned chunks (tail handled separately in kernel) + + grid = (BT, 1, 1) + + if reduction_int == _REDUCTION_MODE_NONE: + output_tensor = torch.zeros(BT, V, device=y_pred.device, dtype=torch.float32) + ct.launch( + torch.cuda.current_stream(), + grid, + _kldiv_fwd_none_ct, + ( + y_pred, + y_true, + output_tensor, + int(V), + float(eps), + int(BLOCK_SIZE), + int(log_target), + int(n_full_chunks), + ), + ) + return output_tensor + else: + row_sums = torch.zeros(BT, device=y_pred.device, dtype=torch.float32) + ct.launch( + torch.cuda.current_stream(), + grid, + _kldiv_fwd_reduce_ct, + ( + y_pred, + y_true, + row_sums, + int(V), + float(eps), + int(BLOCK_SIZE), + int(log_target), + int(n_full_chunks), + ), + ) + if reduction_int == _REDUCTION_MODE_BATCHMEAN: + return row_sums.sum() / BT + elif reduction_int == _REDUCTION_MODE_SUM: + return row_sums.sum(dim=0) + else: # mean + return row_sums.sum() / (BT * V) + + +def _kldiv_backward_ct(y_true, scale, log_target): + BT, V = y_true.shape + BLOCK_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(V)) + n_full_chunks = V // BLOCK_SIZE # full aligned chunks (tail handled separately in kernel) + + new_grads = torch.empty_like(y_true) + grid = (BT, 1, 1) + ct.launch( + torch.cuda.current_stream(), + grid, + _kldiv_bwd_ct, + ( + y_true, + new_grads, + int(V), + float(scale), + int(BLOCK_SIZE), + int(log_target), + int(n_full_chunks), + ), + ) + + return new_grads + + +class KLDivCuTileFunction(torch.autograd.Function): + """CuTile autograd wrapper for KL divergence loss.""" + + @staticmethod + def forward(ctx, y_pred, y_true, reduction, log_target, eps): + y_pred = y_pred.contiguous() + y_true = y_true.contiguous() + ctx.save_for_backward(y_true) + ctx.reduction = reduction + ctx.log_target = log_target + return _kldiv_forward_ct(y_pred, y_true, log_target, reduction, eps) + + @staticmethod + def backward(ctx, grad_output): + (y_true,) = ctx.saved_tensors + BT, V = y_true.shape + + # Compute combined scale: fuse grad_output and reduction normalizer into a + # single scalar so the kernel can apply both in one pass, eliminating the + # extra element-wise kernel launches that existed previously. + if grad_output.numel() == 1: + scale = grad_output.item() + else: + scale = 1.0 + + if ctx.reduction == "batchmean": + scale /= BT + elif ctx.reduction == "mean": + scale /= BT * V + + derivative = _kldiv_backward_ct(y_true, scale, ctx.log_target) + + # Non-scalar grad_output (rare: only when reduction="none"): apply separately + if grad_output.numel() != 1: + derivative = derivative * grad_output + + return derivative, None, None, None, None + + +@register_impl("liger.kl_div", backend="cutile") +def kl_div( + y_pred: torch.Tensor, + y_true: torch.Tensor, + reduction: str = "batchmean", + log_target: bool = False, + eps: float = 1e-10, + **kwargs, +) -> torch.Tensor: + return KLDivCuTileFunction.apply(y_pred, y_true, reduction, log_target, eps) diff --git a/src/tilegym/suites/liger/cutile/llama4_rope.py b/src/tilegym/suites/liger/cutile/llama4_rope.py new file mode 100644 index 00000000..713e45cf --- /dev/null +++ b/src/tilegym/suites/liger/cutile/llama4_rope.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +""" +Llama4-style Rotary Position Embedding (RoPE) kernel (CuTile backend). + +Applies in-place complex multiplication: (q_r + i*q_i) * (f_r + i*f_i). + +Grid: (batch_size, seq_len, n_heads_max) — one block per (batch, seq, head). + +Interleaved layout: q[b, s, h, 2*d] = real part, q[b, s, h, 2*d+1] = imaginary part. +We construct the stride-2 index pairs using: + base = ct.arange(BLOCK_SIZE) + doubled = base + base # [0, 2, 4, ..., 2*(BLOCK_SIZE-1)] + real_idx = doubled + d_start*2 + imag_idx = real_idx + 1 + +q and k are passed as 2D views (B*S*H, head_dim) for simpler indexing. +freqs_cis is passed as (S, head_dim) after view_as_real + reshape. +""" + +import cuda.tile as ct +import torch + +from tilegym.backend import register_impl + + +def _select_kernel_meta(head_dim_half: int): + if head_dim_half >= 256: + return 128, 8 + if head_dim_half >= 96: + return 128, 4 + if head_dim_half >= 48: + return 64, 4 + if head_dim_half >= 24: + return 32, 2 + return 16, 2 + + +@ct.kernel +def _llama4_rope_kernel( + query, # (B*S*H_q, head_dim) query — modified in-place + key, # (B*S*H_k, head_dim) key — modified in-place + freqs, # (S, head_dim) frequencies (view_as_real, flattened last 2 dims) + seq_len, + HEAD_DIM_HALF: ct.Constant[int], + N_Q_HEADS: ct.Constant[int], + N_K_HEADS: ct.Constant[int], + imag_sign, + BLOCK_SIZE: ct.Constant[int], +): + """ + RoPE kernel. + + Grid: (batch_size, seq_len, n_heads_max). + One block per (batch, seq, head) position. + + For each d-block, loads BLOCK_SIZE (real, imag) pairs from Q/K and FREQS, + computes complex multiplication, and stores back in-place. + + Index construction (stride-2 interleaved): + base = arange(BLOCK_SIZE) # [0, 1, ..., BLOCK_SIZE-1] + doubled = base + base # [0, 2, 4, ..., 2*(BLOCK_SIZE-1)] + real_idx = doubled + d_start*2 # real column indices + imag_idx = real_idx + 1 # imag column indices + """ + batch_idx = ct.bid(0) + seq_idx = ct.bid(1) + pid_h = ct.bid(2) + + # Number of BLOCK_SIZE blocks over head_dim_half + n_d_blocks = (HEAD_DIM_HALF + BLOCK_SIZE - 1) // BLOCK_SIZE + + for di in range(n_d_blocks): + d_start = di * BLOCK_SIZE + + # Build interleaved column indices for real/imag parts + base = ct.arange(BLOCK_SIZE, dtype=ct.int32) + doubled = base + base # [0, 2, 4, ..., 2*(BLOCK_SIZE-1)] + real_idx = doubled + d_start * 2 # real column indices in (seq, head_dim) + imag_idx = real_idx + 1 # imag column indices + + # Load frequencies for this seq position and d-block + f_r = ct.astype( + ct.gather(freqs, (seq_idx, real_idx), check_bounds=True, padding_value=0.0, latency=3), + ct.float32, + ) + f_i = ct.astype( + ct.gather(freqs, (seq_idx, imag_idx), check_bounds=True, padding_value=0.0, latency=3), + ct.float32, + ) + f_i = f_i * imag_sign + + # Process query head + if pid_h < N_Q_HEADS: + q_row = batch_idx * seq_len * N_Q_HEADS + seq_idx * N_Q_HEADS + pid_h + q_r = ct.astype( + ct.gather(query, (q_row, real_idx), check_bounds=True, padding_value=0.0, latency=3), ct.float32 + ) + q_i = ct.astype( + ct.gather(query, (q_row, imag_idx), check_bounds=True, padding_value=0.0, latency=3), ct.float32 + ) + + # Complex multiply: (q_r + i*q_i) * (f_r + i*f_i) + new_q_r = q_r * f_r - q_i * f_i + new_q_i = q_r * f_i + q_i * f_r + + ct.scatter(query, (q_row, real_idx), ct.astype(new_q_r, query.dtype), check_bounds=True) + ct.scatter(query, (q_row, imag_idx), ct.astype(new_q_i, query.dtype), check_bounds=True) + + # Process key head + if pid_h < N_K_HEADS: + k_row = batch_idx * seq_len * N_K_HEADS + seq_idx * N_K_HEADS + pid_h + k_r = ct.astype( + ct.gather(key, (k_row, real_idx), check_bounds=True, padding_value=0.0, latency=3), ct.float32 + ) + k_i = ct.astype( + ct.gather(key, (k_row, imag_idx), check_bounds=True, padding_value=0.0, latency=3), ct.float32 + ) + + new_k_r = k_r * f_r - k_i * f_i + new_k_i = k_r * f_i + k_i * f_r + + ct.scatter(key, (k_row, real_idx), ct.astype(new_k_r, key.dtype), check_bounds=True) + ct.scatter(key, (k_row, imag_idx), ct.astype(new_k_i, key.dtype), check_bounds=True) + + +def _llama4_rope_forward_ct(q, k, freqs_cis, BLOCK_SIZE=None, imag_sign=1.0): + original_dtype = q.dtype + + batch_size, seq_len, n_q_heads, head_dim = q.shape + _, _, n_k_heads, _ = k.shape + head_dim_half = head_dim // 2 + + # Normalize freqs_cis to (seq_len, head_dim) real layout + if freqs_cis.is_complex(): + freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1]) + if freqs_cis.shape[0] > seq_len: + freqs_cis = freqs_cis[:seq_len] + freqs_cis = torch.view_as_real(freqs_cis) # (seq_len, head_dim_half, 2) + + if freqs_cis.ndim == 3: + # (seq_len, head_dim_half, 2) → (seq_len, head_dim) + freqs_cis = freqs_cis.reshape(freqs_cis.shape[0], -1) + + compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype + if k.dtype != q.dtype: + k = k.to(q.dtype) + q = q.to(compute_dtype).contiguous() + k = k.to(compute_dtype).contiguous() + freqs_cis = freqs_cis.float().contiguous() + + if BLOCK_SIZE is None: + BLOCK_SIZE, _ = _select_kernel_meta(head_dim_half) + + # Reshape to 2D for the kernel: (B*S*H, head_dim) + q_2d = q.reshape(batch_size * seq_len * n_q_heads, head_dim).contiguous() + k_2d = k.reshape(batch_size * seq_len * n_k_heads, head_dim).contiguous() + + n_heads_max = max(n_q_heads, n_k_heads) + grid = (batch_size, seq_len, n_heads_max) + + ct.launch( + torch.cuda.current_stream(), + grid, + _llama4_rope_kernel, + ( + q_2d, + k_2d, + freqs_cis, + int(seq_len), + int(head_dim_half), + int(n_q_heads), + int(n_k_heads), + float(imag_sign), + int(BLOCK_SIZE), + ), + ) + + q_out = q_2d.reshape(batch_size, seq_len, n_q_heads, head_dim) + k_out = k_2d.reshape(batch_size, seq_len, n_k_heads, head_dim) + + if q_out.dtype != original_dtype: + q_out = q_out.to(original_dtype) + if k_out.dtype != original_dtype: + k_out = k_out.to(original_dtype) + + return q_out, k_out + + +class Llama4RopeCuTileFunction(torch.autograd.Function): + """CuTile autograd wrapper for Llama4 RoPE.""" + + @staticmethod + def forward(ctx, q, k, freqs_cis, BLOCK_SIZE=None): + q_out, k_out = _llama4_rope_forward_ct(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0) + ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis) + ctx.BLOCK_SIZE = BLOCK_SIZE + return q_out, k_out + + @staticmethod + def backward(ctx, dq, dk): + (freqs_cis,) = ctx.saved_tensors + BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None) + dq_out, dk_out = _llama4_rope_forward_ct(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0) + return dq_out, dk_out, None, None + + +@register_impl("liger.llama4_rope", backend="cutile") +def llama4_rope( + q: torch.Tensor, + k: torch.Tensor, + freqs_cis: torch.Tensor, + BLOCK_SIZE: int = None, + **kwargs, +): + return Llama4RopeCuTileFunction.apply(q, k, freqs_cis, BLOCK_SIZE) diff --git a/src/tilegym/suites/liger/cutile/multi_token_attention.py b/src/tilegym/suites/liger/cutile/multi_token_attention.py new file mode 100644 index 00000000..f9f4e6f9 --- /dev/null +++ b/src/tilegym/suites/liger/cutile/multi_token_attention.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT +import cuda.tile as ct +import torch +import torch.nn.functional as F +from torch.nn.modules.utils import _pair + +from tilegym.backend import register_impl + +from .sparsemax import _sparsemax_backward_ct +from .sparsemax import _sparsemax_forward_ct +from .utils import next_power_of_2 + +_MASK_INF_VAL = -1e9 # large negative; -inf breaks multiply-accumulate pattern ((-inf)*0 = NaN) + + +def _select_block_size(L: int) -> int: + return min(next_power_of_2(L), 128) + + +@ct.kernel +def _mask_inf_fwd_kernel( + scores_2d, + output_2d, + L: ct.Constant[int], + BLOCK: ct.Constant[int], +): + actual_row = ct.bid(0) + batch_id = ct.bid(1) + row_idx = batch_id * L + actual_row + n_chunks = (L + BLOCK - 1) // BLOCK + + for ci in range(n_chunks): + col_start = ci * BLOCK + col_idx = ct.arange(BLOCK, dtype=ct.int32) + col_start + src_tile = ct.load(scores_2d, index=(row_idx, ci), shape=(1, BLOCK), padding_mode=ct.PaddingMode.ZERO).reshape( + (BLOCK,) + ) + is_future_f = ct.astype(col_idx > actual_row, ct.float32) + is_past_f = ct.astype(col_idx <= actual_row, ct.float32) + out_tile = ( + ct.astype(src_tile, ct.float32) * is_past_f + ct.full((BLOCK,), _MASK_INF_VAL, ct.float32) * is_future_f + ) + ct.store(output_2d, index=(row_idx, ci), tile=ct.astype(out_tile, output_2d.dtype).reshape((1, BLOCK))) + + +@ct.kernel +def _mask_zero_fwd_kernel( + scores_2d, + output_2d, + L: ct.Constant[int], + BLOCK: ct.Constant[int], +): + actual_row = ct.bid(0) + batch_id = ct.bid(1) + row_idx = batch_id * L + actual_row + n_chunks = (L + BLOCK - 1) // BLOCK + + for ci in range(n_chunks): + col_start = ci * BLOCK + col_idx = ct.arange(BLOCK, dtype=ct.int32) + col_start + src_tile = ct.load(scores_2d, index=(row_idx, ci), shape=(1, BLOCK), padding_mode=ct.PaddingMode.ZERO).reshape( + (BLOCK,) + ) + is_past_f = ct.astype(col_idx <= actual_row, ct.float32) + out_tile = ct.astype(src_tile, ct.float32) * is_past_f + ct.store(output_2d, index=(row_idx, ci), tile=ct.astype(out_tile, output_2d.dtype).reshape((1, BLOCK))) + + +@ct.kernel +def _mask_bwd_kernel( + grad_2d, + output_2d, + L: ct.Constant[int], + BLOCK: ct.Constant[int], +): + actual_row = ct.bid(0) + batch_id = ct.bid(1) + row_idx = batch_id * L + actual_row + n_chunks = (L + BLOCK - 1) // BLOCK + + for ci in range(n_chunks): + col_start = ci * BLOCK + col_idx = ct.arange(BLOCK, dtype=ct.int32) + col_start + grad_tile = ct.load(grad_2d, index=(row_idx, ci), shape=(1, BLOCK), padding_mode=ct.PaddingMode.ZERO).reshape( + (BLOCK,) + ) + is_past_f = ct.astype(col_idx <= actual_row, ct.float32) + out_tile = ct.astype(grad_tile, ct.float32) * is_past_f + ct.store(output_2d, index=(row_idx, ci), tile=ct.astype(out_tile, output_2d.dtype).reshape((1, BLOCK))) + + +@ct.kernel +def _fused_softmax_zeromask_bwd_kernel( + probs_2d, + grad_probs_2d, + output_2d, + L: ct.Constant[int], + BLOCK: ct.Constant[int], +): + """Fused softmax backward + causal zero-mask: dx = p*(dp - dot(p,dp)); zero col>row.""" + actual_row = ct.bid(0) + batch_id = ct.bid(1) + row_idx = batch_id * L + actual_row + n_chunks = (L + BLOCK - 1) // BLOCK + + dot_tile = ct.full((BLOCK,), 0.0, dtype=ct.float32) + for ci in range(n_chunks): + col_idx = ct.arange(BLOCK, dtype=ct.int32) + ci * BLOCK + p_tile = ct.astype( + ct.gather(probs_2d, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + dp_tile = ct.astype( + ct.gather(grad_probs_2d, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + dot_tile = dot_tile + p_tile * dp_tile + dot = ct.sum(dot_tile, 0, keepdims=False) + + for ci in range(n_chunks): + col_idx = ct.arange(BLOCK, dtype=ct.int32) + ci * BLOCK + p_tile = ct.astype( + ct.gather(probs_2d, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + dp_tile = ct.astype( + ct.gather(grad_probs_2d, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + dx_tile = p_tile * (dp_tile - dot) + is_past_f = ct.astype(col_idx <= actual_row, ct.float32) + ct.scatter(output_2d, (row_idx, col_idx), ct.astype(dx_tile * is_past_f, output_2d.dtype), check_bounds=True) + + +def _mask_launch(tensor: torch.Tensor, kernel) -> torch.Tensor: + *batch, L, _ = tensor.shape + N = int(torch.prod(torch.tensor(batch))) if batch else 1 + t_f = tensor.reshape(N * L, L).contiguous() + out = torch.empty_like(t_f) + BLOCK = _select_block_size(L) + ct.launch(torch.cuda.current_stream(), (L, N, 1), kernel, (t_f, out, int(L), int(BLOCK))) + return out.reshape(*batch, L, L) + + +def _mask_inf_forward_ct(scores: torch.Tensor) -> torch.Tensor: + return _mask_launch(scores, _mask_inf_fwd_kernel) + + +def _mask_zero_forward_ct(scores: torch.Tensor) -> torch.Tensor: + return _mask_launch(scores, _mask_zero_fwd_kernel) + + +def _mask_backward_ct(grad: torch.Tensor) -> torch.Tensor: + return _mask_launch(grad, _mask_bwd_kernel) + + +def _fused_softmax_zeromask_bwd_ct_launch(probs: torch.Tensor, grad_probs: torch.Tensor) -> torch.Tensor: + *batch, L, _ = probs.shape + N = int(torch.prod(torch.tensor(batch))) if batch else 1 + p_f = probs.reshape(N * L, L).contiguous() + dp_f = grad_probs.reshape(N * L, L).contiguous() + out = torch.empty_like(p_f) + + BLOCK = _select_block_size(L) + grid = (L, N, 1) + ct.launch( + torch.cuda.current_stream(), + grid, + _fused_softmax_zeromask_bwd_kernel, + (p_f, dp_f, out, int(L), int(BLOCK)), + ) + return out.reshape(*batch, L, L) + + +def _conv1x1_backward(grad_out: torch.Tensor, inp: torch.Tensor, weight: torch.Tensor): + """mm-based 1x1 conv backward -- bypasses cuDNN dispatch overhead. + + For a kernel_size=1 conv: + grad_input[b,cin,h,w] = sum_cout(W[cout,cin] * dout[b,cout,h,w]) + grad_weight[cout,cin] = sum_{b,h,w}(dout[b,cout,h,w] * inp[b,cin,h,w]) + + Both reduce to matrix multiplications on the (B*H*W, C) reshape, letting + cuBLAS SGEMM handle the compute. On B200 for CH=1, L=128 this is ~1.69x + faster than F.conv_transpose2d + torch.nn.grad.conv2d_weight because it + bypasses cuDNN's per-call dispatch overhead for this tiny shape. + """ + B, C_out, H, W = grad_out.shape + C_in = inp.shape[1] + N = B * H * W + go_2d = grad_out.permute(0, 2, 3, 1).reshape(N, C_out) + in_2d = inp.permute(0, 2, 3, 1).reshape(N, C_in) + w_2d = weight.view(C_out, C_in) + grad_input = torch.mm(go_2d, w_2d).reshape(B, H, W, C_in).permute(0, 3, 1, 2).contiguous() + grad_weight = torch.mm(go_2d.t(), in_2d).view(weight.shape) + return grad_input, grad_weight + + +class MultiTokenAttentionCuTileFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, sparse=False): + scores = scores.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + + ctx.sparse = sparse + + if sparse: + if scores.dtype != torch.float32: + raise RuntimeError( + f"CuTile sparse multi-token attention only supports fp32 input scores. Got dtype={scores.dtype}." + ) + compute_dtype = torch.float32 + weight_c, bias_c = weight, bias + + scores_inf = _mask_inf_forward_ct(scores) + probs, out_flat_sparse = _sparsemax_forward_ct(scores_inf, dim=-1) + out_conv = F.conv2d( + probs, weight_c, bias_c, stride=stride, padding=padding, dilation=dilation, groups=groups + ) + out = _mask_zero_forward_ct(out_conv) + ctx.save_for_backward(scores_inf, probs, out_flat_sparse, weight_c, bias_c) + else: + compute_dtype = scores.dtype + # fp16: promote to float32 for TF32 conv+softmax — avoids backward regression on small shapes (L≤128). + if compute_dtype == torch.float16: + scores = scores.float() + weight_c = weight.float() + bias_c = bias.float() if bias is not None else None + else: + weight_c, bias_c = weight, bias + + scores_inf = _mask_inf_forward_ct(scores) + probs = torch.softmax(scores_inf, dim=-1) + out_conv = F.conv2d( + probs, weight_c, bias_c, stride=stride, padding=padding, dilation=dilation, groups=groups + ) + out = _mask_zero_forward_ct(out_conv) + ctx.save_for_backward(scores_inf, probs, weight_c, bias_c) + + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.compute_dtype = compute_dtype + + return out.to(compute_dtype) + + @staticmethod + def backward(ctx, grad_out): + stride, padding, dilation, groups = (ctx.stride, ctx.padding, ctx.dilation, ctx.groups) + sparse = ctx.sparse + + if sparse: + scores_inf, probs, out_flat_sparse, weight, bias = ctx.saved_tensors + else: + scores_inf, probs, weight, bias = ctx.saved_tensors + + # .contiguous() is required: PyTorch's sum().backward() passes a broadcast + # tensor (strides=0), which would cause CuTile gather to read invalid offsets. + grad_out_c = grad_out.to(probs.dtype).contiguous() + + grad_conv = _mask_backward_ct(grad_out_c) + + # conv backward: mm-based 1x1 shortcut or cuDNN fallback + if stride == (1, 1) and padding == (0, 0) and dilation == (1, 1) and groups == 1: + grad_probs, grad_weight = _conv1x1_backward(grad_conv, probs, weight) + else: + grad_probs = F.conv_transpose2d( + grad_conv, weight, None, stride=stride, padding=padding, dilation=dilation, groups=groups + ) + grad_weight = torch.nn.grad.conv2d_weight( + input=probs, + weight_size=weight.shape, + grad_output=grad_conv, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + grad_bias = None + if bias is not None: + grad_bias = grad_conv.sum(dim=(0, 2, 3)) + + if sparse: + grad_scores_inf = _sparsemax_backward_ct(grad_probs.contiguous(), out_flat_sparse, dim=-1) + grad_scores = _mask_backward_ct(grad_scores_inf.to(probs.dtype).contiguous()) + else: + grad_scores = _fused_softmax_zeromask_bwd_ct_launch(probs, grad_probs) + + orig = ctx.compute_dtype + return ( + grad_scores.to(orig), + grad_weight.to(orig), + grad_bias.to(orig) if grad_bias is not None else None, + None, + None, + None, + None, + None, + ) + + +@register_impl("liger.multi_token_attention", backend="cutile") +def multi_token_attention( + scores: torch.Tensor, + weight: torch.Tensor, + bias=None, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + sparse: bool = False, + **kwargs, +) -> torch.Tensor: + return MultiTokenAttentionCuTileFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse) diff --git a/src/tilegym/suites/liger/cutile/qwen2vl_mrope.py b/src/tilegym/suites/liger/cutile/qwen2vl_mrope.py new file mode 100644 index 00000000..f977eccd --- /dev/null +++ b/src/tilegym/suites/liger/cutile/qwen2vl_mrope.py @@ -0,0 +1,280 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py + +""" +Qwen2VL Multimodal Rotary Position Embedding (M-RoPE) kernel (CuTile backend). + +Half-split layout: left half of head_dim = real part, right half = imaginary part. +Three RoPE sections: temporal [0, t_end), height [t_end, h_end), width [h_end, hd//2). +cos/sin shape: (3, bsz, seq_len, head_dim). +Grid: (bsz * seq_len,) — one program per token. +""" + +import cuda.tile as ct +import torch + +from tilegym.backend import register_impl + +from .utils import next_power_of_2 + +ConstInt = ct.Constant[int] +PAD_ZERO = ct.PaddingMode.ZERO + + +@ct.kernel +def _qwen2vl_mrope_kernel( + query, # (bsz, seq_len, n_qh, 2, head_dim_half) + key, # (bsz, seq_len, n_kh, 2, head_dim_half) + cos, # (3, bsz, seq_len, hd) + sin, # (3, bsz, seq_len, hd) + sl, + N_QH: ConstInt, + N_KH: ConstInt, + MROPE_SECTION_T: ConstInt, + MROPE_SECTION_H: ConstInt, + sin_sign, + HEAD_DIM_HALF: ConstInt, + TILE_HD: ConstInt, + TILE_QH: ConstInt, + TILE_KH: ConstInt, + ALIGNED: ct.Constant[bool], +): + pid = ct.bid(0) + batch_idx = pid // sl + seq_idx = pid % sl + + t_end = MROPE_SECTION_T + h_end = t_end + MROPE_SECTION_H + + # Load cos/sin for 3 sections: temporal, height, width. + # When ALIGNED (TILE_HD == head_dim_half, i.e. power-of-2), skip zero-padding + # for the hardware TMA fast path. Otherwise use PAD_ZERO for safety. + if ALIGNED: + t_cos = ct.load(cos, index=(0, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD)).reshape((1, TILE_HD)) + t_sin = ct.load(sin, index=(0, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD)).reshape((1, TILE_HD)) + h_cos = ct.load(cos, index=(1, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD)).reshape((1, TILE_HD)) + h_sin = ct.load(sin, index=(1, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD)).reshape((1, TILE_HD)) + w_cos = ct.load(cos, index=(2, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD)).reshape((1, TILE_HD)) + w_sin = ct.load(sin, index=(2, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD)).reshape((1, TILE_HD)) + else: + t_cos = ct.load(cos, index=(0, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD), padding_mode=PAD_ZERO).reshape( + (1, TILE_HD) + ) + t_sin = ct.load(sin, index=(0, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD), padding_mode=PAD_ZERO).reshape( + (1, TILE_HD) + ) + h_cos = ct.load(cos, index=(1, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD), padding_mode=PAD_ZERO).reshape( + (1, TILE_HD) + ) + h_sin = ct.load(sin, index=(1, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD), padding_mode=PAD_ZERO).reshape( + (1, TILE_HD) + ) + w_cos = ct.load(cos, index=(2, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD), padding_mode=PAD_ZERO).reshape( + (1, TILE_HD) + ) + w_sin = ct.load(sin, index=(2, batch_idx, seq_idx, 0), shape=(1, 1, 1, TILE_HD), padding_mode=PAD_ZERO).reshape( + (1, TILE_HD) + ) + + # Section masks + d_idx = ct.arange(TILE_HD, dtype=ct.int32) + t_mask = d_idx < t_end + h_mask = ct.bitwise_and((d_idx >= t_end), (d_idx < h_end)) + w_mask = ct.bitwise_and((d_idx >= h_end), (d_idx < HEAD_DIM_HALF)) + + t_f = ct.astype(t_mask, ct.float32) + h_f = ct.astype(h_mask, ct.float32) + w_f = ct.astype(w_mask, ct.float32) + + cos_row = t_cos * t_f + h_cos * h_f + w_cos * w_f + sin_row = (t_sin * t_f + h_sin * h_f + w_sin * w_f) * sin_sign + + # Process Q: load all heads at once + if ALIGNED: + q_r = ct.load(query, index=(batch_idx, seq_idx, 0, 0, 0), shape=(1, 1, TILE_QH, 1, TILE_HD)).reshape( + (TILE_QH, TILE_HD) + ) + q_i = ct.load(query, index=(batch_idx, seq_idx, 0, 1, 0), shape=(1, 1, TILE_QH, 1, TILE_HD)).reshape( + (TILE_QH, TILE_HD) + ) + else: + q_r = ct.load( + query, index=(batch_idx, seq_idx, 0, 0, 0), shape=(1, 1, TILE_QH, 1, TILE_HD), padding_mode=PAD_ZERO + ).reshape((TILE_QH, TILE_HD)) + q_i = ct.load( + query, index=(batch_idx, seq_idx, 0, 1, 0), shape=(1, 1, TILE_QH, 1, TILE_HD), padding_mode=PAD_ZERO + ).reshape((TILE_QH, TILE_HD)) + new_q_r = q_r * cos_row - q_i * sin_row + new_q_i = q_i * cos_row + q_r * sin_row + ct.store( + query, + index=(batch_idx, seq_idx, 0, 0, 0), + tile=new_q_r.reshape((1, 1, TILE_QH, 1, TILE_HD)).astype(query.dtype), + ) + ct.store( + query, + index=(batch_idx, seq_idx, 0, 1, 0), + tile=new_q_i.reshape((1, 1, TILE_QH, 1, TILE_HD)).astype(query.dtype), + ) + + # Process K: load all heads at once + if ALIGNED: + k_r = ct.load(key, index=(batch_idx, seq_idx, 0, 0, 0), shape=(1, 1, TILE_KH, 1, TILE_HD)).reshape( + (TILE_KH, TILE_HD) + ) + k_i = ct.load(key, index=(batch_idx, seq_idx, 0, 1, 0), shape=(1, 1, TILE_KH, 1, TILE_HD)).reshape( + (TILE_KH, TILE_HD) + ) + else: + k_r = ct.load( + key, index=(batch_idx, seq_idx, 0, 0, 0), shape=(1, 1, TILE_KH, 1, TILE_HD), padding_mode=PAD_ZERO + ).reshape((TILE_KH, TILE_HD)) + k_i = ct.load( + key, index=(batch_idx, seq_idx, 0, 1, 0), shape=(1, 1, TILE_KH, 1, TILE_HD), padding_mode=PAD_ZERO + ).reshape((TILE_KH, TILE_HD)) + new_k_r = k_r * cos_row - k_i * sin_row + new_k_i = k_i * cos_row + k_r * sin_row + ct.store( + key, + index=(batch_idx, seq_idx, 0, 0, 0), + tile=new_k_r.reshape((1, 1, TILE_KH, 1, TILE_HD)).astype(key.dtype), + ) + ct.store( + key, + index=(batch_idx, seq_idx, 0, 1, 0), + tile=new_k_i.reshape((1, 1, TILE_KH, 1, TILE_HD)).astype(key.dtype), + ) + + +def _qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + head_dim_half = head_dim // 2 + TILE_HD = next_power_of_2(head_dim_half) + TILE_QH = next_power_of_2(n_q_head) + TILE_KH = next_power_of_2(n_kv_head) + # ALIGNED: both TILE_HD == head_dim_half (power-of-2) AND + # TILE_QH == n_q_head AND TILE_KH == n_kv_head — no padding needed anywhere + ALIGNED = (TILE_HD == head_dim_half) and (TILE_QH == n_q_head) and (TILE_KH == n_kv_head) + + n_row = batch_size * seq_len + + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + q_5d = q.reshape(batch_size, seq_len, n_q_head, 2, head_dim_half) + k_5d = k.reshape(batch_size, seq_len, n_kv_head, 2, head_dim_half) + + grid = (n_row,) + ct.launch( + torch.cuda.current_stream(), + grid, + _qwen2vl_mrope_kernel, + ( + q_5d, + k_5d, + cos, + sin, + int(seq_len), + int(n_q_head), + int(n_kv_head), + int(mrope_section[0]), + int(mrope_section[1]), + float(1.0), + int(head_dim_half), + int(TILE_HD), + int(TILE_QH), + int(TILE_KH), + bool(ALIGNED), + ), + ) + + q_out = q_5d.reshape(batch_size, seq_len, n_q_head, head_dim) + k_out = k_5d.reshape(batch_size, seq_len, n_kv_head, head_dim) + return q_out.transpose(1, 2), k_out.transpose(1, 2), cos, sin + + +def _qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + n_kv_head = dk.shape[2] + head_dim_half = head_dim // 2 + TILE_HD = next_power_of_2(head_dim_half) + TILE_QH = next_power_of_2(n_q_head) + TILE_KH = next_power_of_2(n_kv_head) + ALIGNED = (TILE_HD == head_dim_half) and (TILE_QH == n_q_head) and (TILE_KH == n_kv_head) + + n_row = batch_size * seq_len + + dq = dq.contiguous() + dk = dk.contiguous() + + dq_5d = dq.reshape(batch_size, seq_len, n_q_head, 2, head_dim_half) + dk_5d = dk.reshape(batch_size, seq_len, n_kv_head, 2, head_dim_half) + + grid = (n_row,) + ct.launch( + torch.cuda.current_stream(), + grid, + _qwen2vl_mrope_kernel, + ( + dq_5d, + dk_5d, + cos, + sin, + int(seq_len), + int(n_q_head), + int(n_kv_head), + int(mrope_section[0]), + int(mrope_section[1]), + float(-1.0), + int(head_dim_half), + int(TILE_HD), + int(TILE_QH), + int(TILE_KH), + bool(ALIGNED), + ), + ) + + dq_out = dq_5d.reshape(batch_size, seq_len, n_q_head, head_dim) + dk_out = dk_5d.reshape(batch_size, seq_len, n_kv_head, head_dim) + return dq_out.transpose(1, 2), dk_out.transpose(1, 2) + + +class Qwen2VLMRopeCuTileFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1): + q, k, cos, sin = _qwen2vl_mrope_forward(q, k, cos, sin, mrope_section) + ctx.save_for_backward(cos, sin) + ctx.mrope_section = mrope_section + return q, k + + @staticmethod + def backward(ctx, dq, dk): + cos, sin = ctx.saved_tensors + mrope_section = ctx.mrope_section + dq, dk = _qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section) + return dq, dk, None, None, None, None + + +@register_impl("liger.qwen2vl_mrope", backend="cutile") +def qwen2vl_mrope( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mrope_section: list, + unsqueeze_dim: int = 1, + **kwargs, +) -> tuple: + return Qwen2VLMRopeCuTileFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim) diff --git a/src/tilegym/suites/liger/cutile/rope.py b/src/tilegym/suites/liger/cutile/rope.py new file mode 100644 index 00000000..9ecc3181 --- /dev/null +++ b/src/tilegym/suites/liger/cutile/rope.py @@ -0,0 +1,395 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +""" +Rotary Positional Embedding (RoPE) kernel (CuTile backend). + +HuggingFace Llama/Mistral variant -- half-split layout: + left half = real part, right half = imaginary part. + forward: new_r = r*cos - i*sin, new_i = i*cos + r*sin + backward: new_r = r*cos + i*sin, new_i = i*cos - r*sin (sin_sign=-1.0) + +Two kernel variants selected at runtime via the ALIGNED flag: + + ALIGNED case (power-of-2 head_dim AND all head counts exactly match tile sizes): + _rope_4d_ct -- operates on Q/K in original (bsz, n_heads, seq_len, head_dim) layout; + uses block indices 0/1 in the last dim for real/imag halves. + COS/SIN passed in (cos_bs, seq_len, head_dim) form -- no reshape needed. + Avoids expensive host-side transpose+contiguous+reshape copies. + + Non-ALIGNED case: + _rope_5d_ct -- operates on Q/K in (bsz, seq_len, n_heads, 2, head_dim_half) layout; + uses padding_mode=PAD_ZERO for safety on non-power-of-2 shapes. + COS/SIN in (cos_bs, seq_len, 1, head_dim_half) form (via _prepare_cos_sin). + +Grid: (bsz * seq_len,) -- one block per token. + +PERF NOTES: +- ALIGNED 4D path eliminates ~0.035 ms of host-side tensor manipulation per call + (transpose+contiguous on Q and K, plus _prepare_cos_sin reshape). +- COS/SIN TMA: for ALIGNED, load directly from (cos_bs, seq_len, head_dim) -- block index 0 + in the last dim grabs elements [0, TILE_HD) = the cosine/sine values for this token. +- Q/K TMA: block index 0/1 in last dim selects real/imag halves of head_dim. +""" + +import cuda.tile as ct +import torch + +from tilegym.backend import register_impl + +from .utils import next_power_of_2 + +ConstInt = ct.Constant[int] +PAD_ZERO = ct.PaddingMode.ZERO + + +@ct.kernel +def _rope_4d_ct( + Q, # (bsz, n_q_heads, seq_len, head_dim) -- original layout, head_dim = 2*TILE_HD + K, # (bsz, n_k_heads, seq_len, head_dim) + COS, # (cos_bs, seq_len, head_dim) -- first TILE_HD elements are the cos values + SIN, # (cos_bs, seq_len, head_dim) + cos_bs: ConstInt, + seq_len: ConstInt, + sin_sign: ct.Constant[float], + TILE_QH: ConstInt, + TILE_KH: ConstInt, + TILE_HD: ConstInt, +): + """Fast path for ALIGNED shapes: no host-side transpose or reshape needed.""" + cos_bs = COS.shape[0] + + pid = ct.bid(0) + batch_idx = pid // seq_len + seq_idx = pid % seq_len + cos_batch_idx = 0 if cos_bs == 1 else batch_idx + + # Load first TILE_HD elements of cos/sin: these are the rotation values for this token. + # Index (cos_batch_idx, seq_idx, 0) loads COS[cos_batch_idx, seq_idx, 0:TILE_HD]. + cos_row = ct.astype( + ct.load(COS, index=(cos_batch_idx, seq_idx, 0), shape=(1, 1, TILE_HD)).reshape((1, TILE_HD)), + ct.float32, + ) + sin_row = ( + ct.astype( + ct.load(SIN, index=(cos_batch_idx, seq_idx, 0), shape=(1, 1, TILE_HD)).reshape((1, TILE_HD)), + ct.float32, + ) + * sin_sign + ) + + # Q in (bsz, n_q_heads, seq_len, head_dim): index (b, h, s, 0) = real half, + # index (b, h, s, 1) = imag half (block 1 starts at element TILE_HD = head_dim_half). + q_r = ct.astype( + ct.load(Q, index=(batch_idx, 0, seq_idx, 0), shape=(1, TILE_QH, 1, TILE_HD)).reshape((TILE_QH, TILE_HD)), + ct.float32, + ) + q_i = ct.astype( + ct.load(Q, index=(batch_idx, 0, seq_idx, 1), shape=(1, TILE_QH, 1, TILE_HD)).reshape((TILE_QH, TILE_HD)), + ct.float32, + ) + new_q_r = q_r * cos_row - q_i * sin_row + new_q_i = q_i * cos_row + q_r * sin_row + ct.store(Q, index=(batch_idx, 0, seq_idx, 0), tile=new_q_r.reshape((1, TILE_QH, 1, TILE_HD)).astype(Q.dtype)) + ct.store(Q, index=(batch_idx, 0, seq_idx, 1), tile=new_q_i.reshape((1, TILE_QH, 1, TILE_HD)).astype(Q.dtype)) + + # K in (bsz, n_k_heads, seq_len, head_dim) + k_r = ct.astype( + ct.load(K, index=(batch_idx, 0, seq_idx, 0), shape=(1, TILE_KH, 1, TILE_HD)).reshape((TILE_KH, TILE_HD)), + ct.float32, + ) + k_i = ct.astype( + ct.load(K, index=(batch_idx, 0, seq_idx, 1), shape=(1, TILE_KH, 1, TILE_HD)).reshape((TILE_KH, TILE_HD)), + ct.float32, + ) + new_k_r = k_r * cos_row - k_i * sin_row + new_k_i = k_i * cos_row + k_r * sin_row + ct.store(K, index=(batch_idx, 0, seq_idx, 0), tile=new_k_r.reshape((1, TILE_KH, 1, TILE_HD)).astype(K.dtype)) + ct.store(K, index=(batch_idx, 0, seq_idx, 1), tile=new_k_i.reshape((1, TILE_KH, 1, TILE_HD)).astype(K.dtype)) + + +@ct.kernel +def _rope_5d_ct( + Q, # (bsz, seq_len, n_q_heads, 2, head_dim_half) -- 5D layout for non-aligned case + K, # (bsz, seq_len, n_k_heads, 2, head_dim_half) + COS, # (cos_bs, seq_len, 1, head_dim_half) + SIN, # (cos_bs, seq_len, 1, head_dim_half) + cos_bs: ConstInt, + seq_len: ConstInt, + sin_sign: ct.Constant[float], + TILE_QH: ConstInt, + TILE_KH: ConstInt, + TILE_HD: ConstInt, +): + """Fallback path for non-ALIGNED shapes: uses PAD_ZERO for non-power-of-2 dims.""" + cos_bs = COS.shape[0] + + pid = ct.bid(0) + batch_idx = pid // seq_len + seq_idx = pid % seq_len + cos_batch_idx = 0 if cos_bs == 1 else batch_idx + + cos_row = ct.astype( + ct.load(COS, index=(cos_batch_idx, seq_idx, 0, 0), shape=(1, 1, 1, TILE_HD), padding_mode=PAD_ZERO).reshape( + (1, TILE_HD) + ), + ct.float32, + ) + sin_row = ( + ct.astype( + ct.load(SIN, index=(cos_batch_idx, seq_idx, 0, 0), shape=(1, 1, 1, TILE_HD), padding_mode=PAD_ZERO).reshape( + (1, TILE_HD) + ), + ct.float32, + ) + * sin_sign + ) + + q_r = ct.astype( + ct.load( + Q, index=(batch_idx, seq_idx, 0, 0, 0), shape=(1, 1, TILE_QH, 1, TILE_HD), padding_mode=PAD_ZERO + ).reshape((TILE_QH, TILE_HD)), + ct.float32, + ) + q_i = ct.astype( + ct.load( + Q, index=(batch_idx, seq_idx, 0, 1, 0), shape=(1, 1, TILE_QH, 1, TILE_HD), padding_mode=PAD_ZERO + ).reshape((TILE_QH, TILE_HD)), + ct.float32, + ) + new_q_r = q_r * cos_row - q_i * sin_row + new_q_i = q_i * cos_row + q_r * sin_row + ct.store( + Q, + index=(batch_idx, seq_idx, 0, 0, 0), + tile=new_q_r.reshape((1, 1, TILE_QH, 1, TILE_HD)).astype(Q.dtype), + ) + ct.store( + Q, + index=(batch_idx, seq_idx, 0, 1, 0), + tile=new_q_i.reshape((1, 1, TILE_QH, 1, TILE_HD)).astype(Q.dtype), + ) + + k_r = ct.astype( + ct.load( + K, index=(batch_idx, seq_idx, 0, 0, 0), shape=(1, 1, TILE_KH, 1, TILE_HD), padding_mode=PAD_ZERO + ).reshape((TILE_KH, TILE_HD)), + ct.float32, + ) + k_i = ct.astype( + ct.load( + K, index=(batch_idx, seq_idx, 0, 1, 0), shape=(1, 1, TILE_KH, 1, TILE_HD), padding_mode=PAD_ZERO + ).reshape((TILE_KH, TILE_HD)), + ct.float32, + ) + new_k_r = k_r * cos_row - k_i * sin_row + new_k_i = k_i * cos_row + k_r * sin_row + ct.store( + K, + index=(batch_idx, seq_idx, 0, 0, 0), + tile=new_k_r.reshape((1, 1, TILE_KH, 1, TILE_HD)).astype(K.dtype), + ) + ct.store( + K, + index=(batch_idx, seq_idx, 0, 1, 0), + tile=new_k_i.reshape((1, 1, TILE_KH, 1, TILE_HD)).astype(K.dtype), + ) + + +def _prepare_cos_sin(cos, sin, seq_len, head_dim_half): + """Slice cos/sin to head_dim_half and reshape to (cos_bs, seq_len, 1, head_dim_half). + + Used only for the non-ALIGNED (5D) path. Keeps cos/sin in their original dtype; + the float32 cast is done on-chip inside the kernels. + """ + cos_bs = cos.shape[0] + cos_4d = cos[..., :head_dim_half].contiguous().reshape(cos_bs, seq_len, 1, head_dim_half) + sin_4d = sin[..., :head_dim_half].contiguous().reshape(cos_bs, seq_len, 1, head_dim_half) + return cos_4d, sin_4d, cos_bs + + +class RopeCuTileFunction(torch.autograd.Function): + """CuTile autograd wrapper for RoPE. + + ALIGNED case (power-of-2 head_dim, all head counts exactly match tile sizes): + Uses _rope_4d_ct on Q/K without transpose or reshape. COS/SIN passed in 3D form. + This is the common case for LLMs (head_dim=64/128/256, n_heads=power-of-2). + + Non-ALIGNED case: falls back to _rope_5d_ct with transpose+reshape+contiguous. + """ + + @staticmethod + def forward(ctx, q, k, cos, sin): + # q: (bsz, n_q_heads, seq_len, head_dim) + # k: (bsz, n_k_heads, seq_len, head_dim) + bsz, n_q_heads, seq_len, head_dim = q.shape + n_k_heads = k.shape[1] + head_dim_half = head_dim // 2 + original_dtype = q.dtype + + TILE_HD = next_power_of_2(head_dim_half) + TILE_QH = next_power_of_2(n_q_heads) + TILE_KH = next_power_of_2(n_k_heads) + # Require contiguous inputs for the 4D path: non-contiguous q/k (e.g. from + # a transpose of seq-first storage) would force an unavoidable copy inside + # the ALIGNED branch that doesn't happen in Liger's seq-first design. + # Falling back to the 5D path lets transpose+contiguous be a no-op. + ALIGNED = ( + (TILE_HD == head_dim_half) + and (TILE_QH == n_q_heads) + and (TILE_KH == n_k_heads) + and q.is_contiguous() + and k.is_contiguous() + ) + + n_row = bsz * seq_len + grid = (n_row,) + + if ALIGNED: + # Fast path: 4D kernel -- no transpose, no reshape, no _prepare_cos_sin + cos_3d = cos.contiguous() # (cos_bs, seq_len, head_dim) + sin_3d = sin.contiguous() + cos_bs = cos_3d.shape[0] + q_in = q.contiguous() + k_in = k.contiguous() + ct.launch( + torch.cuda.current_stream(), + grid, + _rope_4d_ct, + ( + q_in, + k_in, + cos_3d, + sin_3d, + int(cos_bs), + int(seq_len), + float(1.0), + int(TILE_QH), + int(TILE_KH), + int(TILE_HD), + ), + ) + ctx.save_for_backward(cos_3d, sin_3d) + ctx.cos_4d = None + else: + # Slow path: 5D kernel with transpose+contiguous + q_t = q.transpose(1, 2).contiguous() + k_t = k.transpose(1, 2).contiguous() + cos_4d, sin_4d, cos_bs = _prepare_cos_sin(cos, sin, seq_len, head_dim_half) + q_5d = q_t.view(bsz, seq_len, n_q_heads, 2, head_dim_half) + k_5d = k_t.view(bsz, seq_len, n_k_heads, 2, head_dim_half) + ct.launch( + torch.cuda.current_stream(), + grid, + _rope_5d_ct, + ( + q_5d, + k_5d, + cos_4d, + sin_4d, + int(cos_bs), + int(seq_len), + float(-1.0 if False else 1.0), + int(TILE_QH), + int(TILE_KH), + int(TILE_HD), + ), + ) + q_in = q_t.view(bsz, seq_len, n_q_heads, head_dim).transpose(1, 2).to(original_dtype) + k_in = k_t.view(bsz, seq_len, n_k_heads, head_dim).transpose(1, 2).to(original_dtype) + cos_bs = cos_4d.shape[0] + ctx.save_for_backward(cos_4d, sin_4d) + ctx.cos_4d = True + + ctx.bsz = bsz + ctx.seq_len = seq_len + ctx.n_q_heads = n_q_heads + ctx.n_k_heads = n_k_heads + ctx.head_dim = head_dim + ctx.cos_bs = cos_bs + ctx.original_dtype = original_dtype + ctx.ALIGNED = ALIGNED + ctx.TILE_QH = TILE_QH + ctx.TILE_KH = TILE_KH + ctx.TILE_HD = TILE_HD + + return q_in, k_in + + @staticmethod + def backward(ctx, dq, dk): + ALIGNED = ctx.ALIGNED + bsz = ctx.bsz + seq_len = ctx.seq_len + n_q_heads = ctx.n_q_heads + n_k_heads = ctx.n_k_heads + head_dim = ctx.head_dim + cos_bs = ctx.cos_bs + head_dim_half = head_dim // 2 + TILE_QH = ctx.TILE_QH + TILE_KH = ctx.TILE_KH + TILE_HD = ctx.TILE_HD + + n_row = bsz * seq_len + grid = (n_row,) + + if ALIGNED: + cos_3d, sin_3d = ctx.saved_tensors + dq_in = dq.contiguous() + dk_in = dk.contiguous() + ct.launch( + torch.cuda.current_stream(), + grid, + _rope_4d_ct, + ( + dq_in, + dk_in, + cos_3d, + sin_3d, + int(cos_bs), + int(seq_len), + float(-1.0), + int(TILE_QH), + int(TILE_KH), + int(TILE_HD), + ), + ) + return dq_in.contiguous(), dk_in.contiguous(), None, None + else: + cos_4d, sin_4d = ctx.saved_tensors + dq_t = dq.transpose(1, 2).contiguous() + dk_t = dk.transpose(1, 2).contiguous() + dq_5d = dq_t.view(bsz, seq_len, n_q_heads, 2, head_dim_half) + dk_5d = dk_t.view(bsz, seq_len, n_k_heads, 2, head_dim_half) + ct.launch( + torch.cuda.current_stream(), + grid, + _rope_5d_ct, + ( + dq_5d, + dk_5d, + cos_4d, + sin_4d, + int(cos_bs), + int(seq_len), + float(-1.0), + int(TILE_QH), + int(TILE_KH), + int(TILE_HD), + ), + ) + dq_out = dq_t.view(bsz, seq_len, n_q_heads, head_dim).transpose(1, 2).to(ctx.original_dtype) + dk_out = dk_t.view(bsz, seq_len, n_k_heads, head_dim).transpose(1, 2).to(ctx.original_dtype) + # Return views directly (no copy needed); autograd handles non-contiguous grads. + return dq_out, dk_out, None, None + + +@register_impl("liger.rope", backend="cutile") +def rope( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + **kwargs, +) -> tuple: + return RopeCuTileFunction.apply(q, k, cos, sin) diff --git a/src/tilegym/suites/liger/cutile/sparsemax.py b/src/tilegym/suites/liger/cutile/sparsemax.py new file mode 100644 index 00000000..c578c812 --- /dev/null +++ b/src/tilegym/suites/liger/cutile/sparsemax.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT +import cuda.tile as ct +import torch + +from tilegym.backend import register_impl + +from .utils import next_power_of_2 + +# 20 bisections give fp32-scale tau precision (~1e-6 relative interval). +_BSEARCH_ITER = 20 + + +def _select_block_size(n_cols: int) -> int: + return min(next_power_of_2(n_cols), 4096) + + +@ct.kernel(occupancy=4) +def _sparsemax_bsearch_kernel( + y_output, + x_input, + N_COLS: ct.Constant[int], + BLOCK_SIZE: ct.Constant[int], + BSEARCH_ITER: ct.Constant[int], +): + row_idx = ct.bid(0) + n_chunks = (N_COLS + BLOCK_SIZE - 1) // BLOCK_SIZE + + x_max = ct.full((1,), -1e38, dtype=ct.float32) + x_sum = ct.full((1,), 0.0, dtype=ct.float32) + + for ci in range(n_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + ci * BLOCK_SIZE + x_tile = ct.astype( + ct.gather(x_input, (row_idx, col_idx), check_bounds=True, padding_value=-1e38), + ct.float32, + ) + valid_mask = ct.astype(col_idx < N_COLS, ct.float32) + x_sum = x_sum + ct.sum(x_tile * valid_mask, 0, keepdims=True) + x_max = ct.maximum(x_max, ct.max(x_tile, 0, keepdims=True)) + + n_cols_f = ct.full((1,), float(N_COLS), dtype=ct.float32) + tau_lo = (x_sum - ct.full((1,), 1.0, ct.float32)) / n_cols_f + tau_hi = x_max + + one = ct.full((1,), 1.0, ct.float32) + half = ct.full((1,), 0.5, ct.float32) + + for _ in range(BSEARCH_ITER): + tau_mid = half * (tau_lo + tau_hi) + f = ct.full((1,), 0.0, ct.float32) + + for ci in range(n_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + ci * BLOCK_SIZE + x_tile = ct.astype( + ct.gather(x_input, (row_idx, col_idx), check_bounds=True, padding_value=-1e38), + ct.float32, + ) + valid_mask = ct.astype(col_idx < N_COLS, ct.float32) + in_supp = ct.astype(x_tile > tau_mid, ct.float32) * valid_mask + f = f + ct.sum(in_supp * (x_tile - tau_mid), 0, keepdims=True) + + tau_lo = ct.where(f >= one, tau_mid, tau_lo) + tau_hi = ct.where(f < one, tau_mid, tau_hi) + + tau = half * (tau_lo + tau_hi) + + zero = ct.full((BLOCK_SIZE,), 0.0, ct.float32) + for ci in range(n_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + ci * BLOCK_SIZE + x_tile = ct.astype( + ct.gather(x_input, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + y_tile = ct.maximum(x_tile - tau, zero) + ct.scatter(y_output, (row_idx, col_idx), ct.astype(y_tile, y_output.dtype), check_bounds=True) + + +@ct.kernel(occupancy=2) +def _sparsemax_bsearch_kernel_large( + y_output, + x_input, + N_COLS: ct.Constant[int], + BLOCK_SIZE: ct.Constant[int], + BSEARCH_ITER: ct.Constant[int], +): + """occupancy=2: at N>16384 high occupancy thrashes L2 (7×128KB/SM); occ=2 keeps rows in L2.""" + row_idx = ct.bid(0) + n_chunks = (N_COLS + BLOCK_SIZE - 1) // BLOCK_SIZE + + x_max = ct.full((1,), -1e38, dtype=ct.float32) + x_sum = ct.full((1,), 0.0, dtype=ct.float32) + + for ci in range(n_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + ci * BLOCK_SIZE + x_tile = ct.astype( + ct.gather(x_input, (row_idx, col_idx), check_bounds=True, padding_value=-1e38), + ct.float32, + ) + valid_mask = ct.astype(col_idx < N_COLS, ct.float32) + x_sum = x_sum + ct.sum(x_tile * valid_mask, 0, keepdims=True) + x_max = ct.maximum(x_max, ct.max(x_tile, 0, keepdims=True)) + + n_cols_f = ct.full((1,), float(N_COLS), dtype=ct.float32) + tau_lo = (x_sum - ct.full((1,), 1.0, ct.float32)) / n_cols_f + tau_hi = x_max + + one = ct.full((1,), 1.0, ct.float32) + half = ct.full((1,), 0.5, ct.float32) + + for _ in range(BSEARCH_ITER): + tau_mid = half * (tau_lo + tau_hi) + f = ct.full((1,), 0.0, ct.float32) + + for ci in range(n_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + ci * BLOCK_SIZE + x_tile = ct.astype( + ct.gather(x_input, (row_idx, col_idx), check_bounds=True, padding_value=-1e38), + ct.float32, + ) + valid_mask = ct.astype(col_idx < N_COLS, ct.float32) + in_supp = ct.astype(x_tile > tau_mid, ct.float32) * valid_mask + f = f + ct.sum(in_supp * (x_tile - tau_mid), 0, keepdims=True) + + tau_lo = ct.where(f >= one, tau_mid, tau_lo) + tau_hi = ct.where(f < one, tau_mid, tau_hi) + + tau = half * (tau_lo + tau_hi) + + zero = ct.full((BLOCK_SIZE,), 0.0, ct.float32) + for ci in range(n_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + ci * BLOCK_SIZE + x_tile = ct.astype( + ct.gather(x_input, (row_idx, col_idx), check_bounds=True, padding_value=0.0), + ct.float32, + ) + y_tile = ct.maximum(x_tile - tau, zero) + ct.scatter(y_output, (row_idx, col_idx), ct.astype(y_tile, y_output.dtype), check_bounds=True) + + +@ct.kernel +def _sparsemax_bwd_kernel( + grad_input, + output, + grad_output, + N_COLS: ct.Constant[int], + BLOCK_SIZE: ct.Constant[int], +): + row_idx = ct.bid(0) + n_chunks = (N_COLS + BLOCK_SIZE - 1) // BLOCK_SIZE + + go_sum_tile = ct.full((BLOCK_SIZE,), 0.0, ct.float32) + supp_cnt_tile = ct.full((BLOCK_SIZE,), 0.0, ct.float32) + + for ci in range(n_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + ci * BLOCK_SIZE + o_tile = ct.astype(ct.gather(output, (row_idx, col_idx), check_bounds=True, padding_value=0.0), ct.float32) + go_tile = ct.astype( + ct.gather(grad_output, (row_idx, col_idx), check_bounds=True, padding_value=0.0), ct.float32 + ) + supp_f = ct.astype(o_tile > ct.full((BLOCK_SIZE,), 0.0, ct.float32), ct.float32) + go_sum_tile = go_sum_tile + supp_f * go_tile + supp_cnt_tile = supp_cnt_tile + supp_f + + go_sum = ct.sum(go_sum_tile, 0, keepdims=False) + supp_cnt = ct.sum(supp_cnt_tile, 0, keepdims=False) + mean_go = go_sum / (supp_cnt + 1e-6) + + for ci in range(n_chunks): + col_idx = ct.arange(BLOCK_SIZE, dtype=ct.int32) + ci * BLOCK_SIZE + o_tile = ct.astype(ct.gather(output, (row_idx, col_idx), check_bounds=True, padding_value=0.0), ct.float32) + go_tile = ct.astype( + ct.gather(grad_output, (row_idx, col_idx), check_bounds=True, padding_value=0.0), ct.float32 + ) + supp_f = ct.astype(o_tile > ct.full((BLOCK_SIZE,), 0.0, ct.float32), ct.float32) + gi_tile = supp_f * (go_tile - mean_go) + ct.scatter(grad_input, (row_idx, col_idx), ct.astype(gi_tile, grad_input.dtype), check_bounds=True) + + +def _sparsemax_forward_ct(x: torch.Tensor, dim: int): + if dim < 0: + dim += x.dim() + x_sw = x.transpose(dim, -1).contiguous() + n_cols = x_sw.size(-1) + n_rows = x_sw.numel() // n_cols + x_flat = x_sw.view(n_rows, n_cols) + + BLOCK_SIZE = _select_block_size(n_cols) + out_flat = torch.empty_like(x_flat) + kernel = _sparsemax_bsearch_kernel_large if n_cols > 16384 else _sparsemax_bsearch_kernel + ct.launch( + torch.cuda.current_stream(), + (n_rows, 1, 1), + kernel, + (out_flat, x_flat, int(n_cols), int(BLOCK_SIZE), int(_BSEARCH_ITER)), + ) + + return out_flat.view_as(x_sw).transpose(dim, -1).contiguous(), out_flat + + +def _sparsemax_backward_ct(grad_out: torch.Tensor, out_flat: torch.Tensor, dim: int): + grad_sw = grad_out.transpose(dim, -1).contiguous() + n_cols = grad_sw.size(-1) + n_rows = grad_sw.numel() // n_cols + go_flat = grad_sw.view(n_rows, n_cols).contiguous() + + BLOCK_SIZE = _select_block_size(n_cols) + dx_flat = torch.empty_like(go_flat) + ct.launch( + torch.cuda.current_stream(), + (n_rows, 1, 1), + _sparsemax_bwd_kernel, + (dx_flat, out_flat, go_flat, int(n_cols), int(BLOCK_SIZE)), + ) + + return dx_flat.view_as(grad_sw).transpose(dim, -1) + + +class SparsemaxCuTileFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, dim): + y, out_flat = _sparsemax_forward_ct(x.contiguous(), dim) + ctx.save_for_backward(out_flat) + ctx.dim = dim + return y + + @staticmethod + def backward(ctx, grad_out): + (out_flat,) = ctx.saved_tensors + return _sparsemax_backward_ct(grad_out.contiguous(), out_flat, ctx.dim), None + + +@register_impl("liger.sparsemax", backend="cutile") +def sparsemax( + input: torch.Tensor, + dim: int = -1, + **kwargs, +) -> torch.Tensor: + return SparsemaxCuTileFunction.apply(input, dim) diff --git a/src/tilegym/suites/liger/cutile/tiled_mlp.py b/src/tilegym/suites/liger/cutile/tiled_mlp.py new file mode 100644 index 00000000..e7ea62c2 --- /dev/null +++ b/src/tilegym/suites/liger/cutile/tiled_mlp.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +""" +Tiled MLP (cuTile backend). + +Pure Python implementation — no GPU kernel. +Shards input along sequence dimension (dim=-2), applies fn on each shard, +and concatenates. Backward re-computes forward per shard to save memory. + +""" + +import math +from typing import Callable +from typing import List +from typing import Optional + +import torch + +from tilegym.backend import register_impl + + +class _TiledMLPFunctionCT(torch.autograd.Function): + """Tiled MLP computation (no GPU kernel, memory-efficient via re-computation).""" + + @staticmethod + def forward(ctx, fn, mlp_module, x, shards, compute_params=None): + ctx.fn = fn + ctx.mlp_module = mlp_module + ctx.shards = shards + ctx.save_for_backward(x) + + x_shards = list(torch.chunk(x, chunks=shards, dim=-2)) + with torch.no_grad(): + output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards] + return torch.cat(output_shards, dim=-2) + + @staticmethod + def backward(ctx, *grads): + fn = ctx.fn + (x,) = ctx.saved_tensors + mlp_module = ctx.mlp_module + shards = ctx.shards + + x_requires_grad = x.requires_grad + + # Chunk along dim=-2 to match forward sharding exactly. + # Flattening to 2D first and re-chunking would create different + # row groupings, leading to different GEMM algorithms and relu-mask + # flips at near-zero activations (up to ~0.09 gradient error). + x_detached = x.detach() + x_shards = list(torch.chunk(x_detached, chunks=shards, dim=-2)) + grad_shards = list(torch.chunk(grads[0], chunks=shards, dim=-2)) + + # Pre-allocate gradient buffer and chunk it into views aligned with x_shards. + # This lets a single backward() per shard populate x_shard_leaf.grad AND + # accumulate weight gradients, halving the number of backward passes vs the + # previous retain_graph=True + second backward() approach. + if x_requires_grad: + x_grad = torch.zeros_like(x_detached) + x_grad_shards = list(torch.chunk(x_grad, chunks=shards, dim=-2)) + else: + x_grad = None + + for i, (x_shard, grad_shard) in enumerate(zip(x_shards, grad_shards)): + x_shard_leaf = x_shard.detach().requires_grad_(x_requires_grad) + if x_requires_grad: + # Pre-assign the gradient buffer slice so backward() fills it in-place. + x_shard_leaf.grad = x_grad_shards[i] + with torch.enable_grad(): + output = fn(mlp_module, x_shard_leaf) + # Single backward per shard: accumulates weight gradients AND populates + # x_shard_leaf.grad (which is a view into x_grad) when x_requires_grad. + torch.autograd.backward(output, grad_shard) + + return None, None, x_grad, None, None + + +def _apply_tiled_mlp_ct(fn, mlp_module, x, num_shards=None, compute_params=None): + if num_shards is None: + hidden_size = x.shape[-1] + seqlen = x.shape[-2] + num_shards = math.ceil(seqlen / hidden_size) + num_shards = max(1, num_shards) + return _TiledMLPFunctionCT.apply(fn, mlp_module, x, num_shards, compute_params) + + +@register_impl("liger.tiled_mlp", backend="cutile") +def tiled_mlp( + fn: Callable, + mlp_module: torch.nn.Module, + x: torch.Tensor, + num_shards: Optional[int] = None, + compute_params: Optional[List] = None, + **kwargs, +) -> torch.Tensor: + return _apply_tiled_mlp_ct(fn, mlp_module, x, num_shards, compute_params) diff --git a/src/tilegym/suites/liger/ops.py b/src/tilegym/suites/liger/ops.py index 933b0701..616e2809 100644 --- a/src/tilegym/suites/liger/ops.py +++ b/src/tilegym/suites/liger/ops.py @@ -47,6 +47,20 @@ def jsd( raise NotImplementedError(f"jsd is not implemented for {get_current_backend()}") +@dispatch( + "liger.fused_neighborhood_attention", +) +def fused_neighborhood_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kernel_size: int = 7, + dilation: int = 1, + scale: float = None, +) -> torch.Tensor: + raise NotImplementedError(f"fused_neighborhood_attention not implemented for {get_current_backend()}") + + @dispatch( "liger.cross_entropy", ) @@ -153,6 +167,70 @@ def geglu( raise NotImplementedError(f"geglu is not implemented for {get_current_backend()}") +@dispatch( + "liger.group_norm", +) +def group_norm( + X: torch.Tensor, + num_channels: int, + num_groups: int, + W: torch.Tensor, + B: torch.Tensor, + eps: float = 1e-5, +) -> torch.Tensor: + """ + Group Normalization. + + Divides channels into groups and normalizes within each group. + + Reference: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/group_norm.py + + Args: + X: Input tensor of shape (batch_size, num_channels, *spatial). + num_channels: Total number of channels. + num_groups: Number of groups to divide channels into. + W: Affine scale weight of shape (num_channels,). + B: Affine shift bias of shape (num_channels,). + eps: Epsilon for numerical stability. Default: 1e-5 + + Returns: + Normalized output tensor of same shape as X. + """ + raise NotImplementedError(f"group_norm is not implemented for {get_current_backend()}") + + +@dispatch( + "liger.kl_div", +) +def kl_div( + y_pred: torch.Tensor, + y_true: torch.Tensor, + reduction: str = "batchmean", + log_target: bool = False, + eps: float = 1e-10, +) -> torch.Tensor: + """ + KL Divergence loss: KL(y_true || y_pred). + + Expects y_pred as log-probabilities. y_true can be probabilities (default) + or log-probabilities (when log_target=True). + + Reference: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/kl_div.py + + Args: + y_pred: Log-probability predictions of shape (BT, V). + y_true: Target values of shape (BT, V). Probabilities when log_target=False, + log-probabilities when log_target=True. + reduction: Reduction mode: "none" | "sum" | "mean" | "batchmean". Default: "batchmean" + log_target: If True, y_true is treated as log-probabilities. Default: False + eps: Small value for numerical stability (clamping y_true). Default: 1e-10 + + Returns: + Loss tensor. Shape (BT, V) when reduction="none", scalar otherwise. + """ + raise NotImplementedError(f"kl_div is not implemented for {get_current_backend()}") + + @dispatch( "liger.layer_norm", ) @@ -179,3 +257,184 @@ def layer_norm( Normalized output tensor of same shape as X. """ raise NotImplementedError(f"layer_norm is not implemented for {get_current_backend()}") + + +@dispatch( + "liger.llama4_rope", +) +def llama4_rope( + q: torch.Tensor, + k: torch.Tensor, + freqs_cis: torch.Tensor, + BLOCK_SIZE: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Llama4-style Rotary Position Embedding (RoPE) applied in-place to q and k. + + Performs complex multiplication: (q_r + i*q_i) * (f_r + i*f_i). + + Reference: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/llama4_rope.py + + Args: + q: Query tensor of shape (batch_size, seq_len, n_q_heads, head_dim). + k: Key tensor of shape (batch_size, seq_len, n_k_heads, head_dim). + freqs_cis: Frequency tensor of shape (seq_len, head_dim//2) complex, + or (seq_len, head_dim//2, 2) real, or (seq_len, head_dim) real. + BLOCK_SIZE: Tile size for kernel (auto-selected if None). Default: None + + Returns: + Tuple (q, k) with rotary embeddings applied in-place. + """ + raise NotImplementedError(f"llama4_rope is not implemented for {get_current_backend()}") + + +@dispatch( + "liger.qwen2vl_mrope", +) +def qwen2vl_mrope( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mrope_section: list, + unsqueeze_dim: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE). + + Applies rotary embeddings to q and k using temporal / height / width sections. + + Reference: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py + + Args: + q: Query tensor of shape (bsz, n_q_head, seq_len, head_dim). + k: Key tensor of shape (bsz, n_kv_head, seq_len, head_dim). + cos: Cosine tensor of shape (3, bsz, seq_len, head_dim). + sin: Sine tensor of shape (3, bsz, seq_len, head_dim). + mrope_section: List [t_section, h_section] with the number of head-dim + positions allocated to temporal and height embeddings. + + Returns: + Tuple (q, k) with M-RoPE applied in-place. + """ + raise NotImplementedError(f"qwen2vl_mrope is not implemented for {get_current_backend()}") + + +@dispatch( + "liger.rope", +) +def rope( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Rotary Positional Embedding (RoPE) — HuggingFace Llama/Mistral variant. + + Half-split layout: left half = real, right half = imaginary. + forward: new_r = r*cos - i*sin, new_i = i*cos + r*sin + backward: new_r = r*cos + i*sin, new_i = i*cos - r*sin + + Reference: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/rope.py + + Args: + q: Query tensor of shape (bsz, n_q_heads, seq_len, head_dim). + k: Key tensor of shape (bsz, n_kv_heads, seq_len, head_dim). + cos: Cosine tensor of shape (1_or_bsz, seq_len, head_dim). + sin: Sine tensor of shape (1_or_bsz, seq_len, head_dim). + + Returns: + Tuple (q, k) with RoPE applied. + """ + raise NotImplementedError(f"rope is not implemented for {get_current_backend()}") + + +@dispatch( + "liger.sparsemax", +) +def sparsemax( + input: torch.Tensor, + dim: int = -1, +) -> torch.Tensor: + """ + Sparsemax: projects input onto the probability simplex (sparse alternative to softmax). + + Reference: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/sparsemax.py + + Args: + input: Input tensor of any shape. + dim: Dimension along which sparsemax is computed. Default: -1 + + Returns: + Sparsemax output of same shape as input. + """ + raise NotImplementedError(f"sparsemax is not implemented for {get_current_backend()}") + + +@dispatch( + "liger.tiled_mlp", +) +def tiled_mlp( + fn: Callable, + mlp_module: torch.nn.Module, + x: torch.Tensor, + num_shards: Optional[int] = None, + compute_params: Optional[List] = None, +) -> torch.Tensor: + """ + Tiled MLP computation for memory-efficient long-sequence processing. + + Shards the input along the sequence dimension, applies fn on each shard, + and concatenates the results. Backward re-computes forward per shard. + + Reference: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/tiled_mlp.py + + Args: + fn: Function to apply on each shard: fn(mlp_module, x_shard) -> output_shard. + mlp_module: The MLP nn.Module object. + x: Input tensor of shape (*, seq_len, hidden_size). + num_shards: Number of shards. If None, auto-computed as ceil(seq_len/hidden_size). + compute_params: Optional list of parameters for ZeRO optimization. Default: None + + Returns: + Output tensor of same shape as x. + """ + raise NotImplementedError(f"tiled_mlp is not implemented for {get_current_backend()}") + + +@dispatch( + "liger.multi_token_attention", +) +def multi_token_attention( + scores: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + sparse: bool = False, +) -> torch.Tensor: + """ + Multi-Token Attention: causal masking + softmax + conv2d + causal masking. + + Applies a causal lower-triangular mask, softmax attention, a learnable 2D + convolution over the attention matrix, and a final causal zero-mask. + + Reference: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/multi_token_attention.py + + Args: + scores: Attention score tensor of shape (*, L, L). + weight: Conv2d weight of shape (out_channels, in_channels/groups, kH, kW). + bias: Optional conv2d bias of shape (out_channels,). Default: None + stride: Conv2d stride. Default: 1 + padding: Conv2d padding. Default: 0 + dilation: Conv2d dilation. Default: 1 + groups: Conv2d groups. Default: 1 + sparse: Use sparsemax instead of softmax. Default: False + + Returns: + Output tensor of same shape as scores. + """ + raise NotImplementedError(f"multi_token_attention is not implemented for {get_current_backend()}") diff --git a/tests/suites/liger/test_fused_neighborhood_attention.py b/tests/suites/liger/test_fused_neighborhood_attention.py new file mode 100644 index 00000000..c9a923c9 --- /dev/null +++ b/tests/suites/liger/test_fused_neighborhood_attention.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import gc +import math + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites.liger.ops import fused_neighborhood_attention + +# (batch, heads, seq_len, head_dim) +SHAPES = [ + (1, 1, 16, 32), + (2, 4, 32, 64), + (1, 2, 64, 32), +] + +FLOAT_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + +KERNEL_SIZES = [3, 7] +DILATIONS = [1, 2] + + +def _ref_neighborhood_attention(query, key, value, kernel_size=7, dilation=1, scale=None): + """Pure-PyTorch reference implementation.""" + batch, heads, seq_len, head_dim = query.shape + if scale is None: + scale = 1.0 / math.sqrt(head_dim) + + q = query.float() + k = key.float() + v = value.float() + + # Build neighborhood mask [seq_len, seq_len] + half = kernel_size // 2 + mask = torch.zeros(seq_len, seq_len, device=query.device, dtype=torch.float32) + for i in range(seq_len): + for j in range(seq_len): + dist = abs(i - j) + if dilation == 1: + if dist <= half: + mask[i, j] = 1.0 + else: + if dist <= half * dilation and (i - j) % dilation == 0: + mask[i, j] = 1.0 + + # scores: [B, H, S, S] + scores = torch.einsum("bhid,bhjd->bhij", q, k) * scale + scores = scores.masked_fill(mask[None, None] == 0, float("-inf")) + attn = torch.softmax(scores, dim=-1) + attn = torch.nan_to_num(attn, nan=0.0) # rows with all -inf become 0 + out = torch.einsum("bhij,bhjd->bhid", attn, v) + return out.to(query.dtype) + + +class Test_Liger_FusedNeighborhoodAttention(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize("shape", SHAPES) + @pytest.mark.parametrize("kernel_size", KERNEL_SIZES) + @pytest.mark.parametrize("dtype", FLOAT_DTYPES) + @pytest.mark.parametrize("backend", _backends) + def test_op(self, shape, kernel_size, dtype, backend, monkeypatch): + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + batch, heads, seq_len, head_dim = shape + q = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype, device="cuda") + k = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype, device="cuda") + v = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype, device="cuda") + + atol = 2e-2 if dtype != torch.float32 else 5e-3 + rtol = 1e-2 + + framework_fn = lambda: fused_neighborhood_attention(q, k, v, kernel_size=kernel_size) + ref_fn = lambda: _ref_neighborhood_attention(q, k, v, kernel_size=kernel_size) + + self.assertCorrectness( + framework_fn, + ref_fn, + kwargs={}, + atol=atol, + rtol=rtol, + ) + + @pytest.mark.parametrize("shape", SHAPES) + @pytest.mark.parametrize("dilation", DILATIONS) + @pytest.mark.parametrize("backend", _backends) + def test_op_dilation(self, shape, dilation, backend, monkeypatch): + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + batch, heads, seq_len, head_dim = shape + dtype = torch.float32 + q = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype, device="cuda") + k = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype, device="cuda") + v = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype, device="cuda") + + framework_fn = lambda: fused_neighborhood_attention(q, k, v, kernel_size=7, dilation=dilation) + ref_fn = lambda: _ref_neighborhood_attention(q, k, v, kernel_size=7, dilation=dilation) + + self.assertCorrectness( + framework_fn, + ref_fn, + kwargs={}, + atol=5e-3, + rtol=1e-2, + ) + + @pytest.mark.parametrize("shape", SHAPES) + @pytest.mark.parametrize("backend", _backends) + def test_op_backward(self, shape, backend, monkeypatch): + self.setUp() + if backend == "cutile": + pytest.skip("cutile backward for FusedNeighborhoodAttention hangs in CI; skip until fixed") + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + batch, heads, seq_len, head_dim = shape + dtype = torch.float32 + + q = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype, device="cuda", requires_grad=True) + + q_ref = q.detach().clone().requires_grad_(True) + k_ref = k.detach().clone().requires_grad_(True) + v_ref = v.detach().clone().requires_grad_(True) + + out = fused_neighborhood_attention(q, k, v, kernel_size=7) + out_ref = _ref_neighborhood_attention(q_ref, k_ref, v_ref, kernel_size=7) + + grad = torch.randn_like(out) + out.backward(grad) + out_ref.backward(grad.clone()) + + torch.testing.assert_close(q.grad, q_ref.grad, atol=1e-2, rtol=1e-1) + torch.testing.assert_close(k.grad, k_ref.grad, atol=1e-2, rtol=1e-1) + torch.testing.assert_close(v.grad, v_ref.grad, atol=1e-2, rtol=1e-1) diff --git a/tests/suites/liger/test_group_norm.py b/tests/suites/liger/test_group_norm.py new file mode 100644 index 00000000..b86083c9 --- /dev/null +++ b/tests/suites/liger/test_group_norm.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import gc + +import pytest +import torch +import torch.nn.functional as F + +import tilegym +from tests import common +from tilegym.suites.liger.ops import group_norm + + +class Test_Liger_GroupNorm(common.PyTestCase): + _backends = ["cutile"] + + @staticmethod + def reference(X, num_channels, num_groups, W, B, eps=1e-5): + """PyTorch float32 reference for group normalization.""" + return F.group_norm(X.float(), num_groups, W.float(), B.float(), eps=eps).to(X.dtype) + + @pytest.mark.parametrize( + "batch_size, num_channels, num_groups, hidden_size, dtype", + [ + (2, 4, 2, 8, torch.float32), + (4, 8, 4, 16, torch.float32), + (2, 4, 2, 8, torch.float16), + (2, 4, 2, 8, torch.bfloat16), + (2, 6, 3, 10, torch.float32), # non-power-of-2 + # Shapes from Liger test/transformers/test_group_norm.py + (2, 63, 21, 2163, torch.float32), + (16, 32, 1, 4096, torch.float32), + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_forward(self, batch_size, num_channels, hidden_size, num_groups, dtype, backend, monkeypatch): + """Test forward output matches PyTorch F.group_norm reference.""" + monkeypatch.setenv("DISABLE_AUTOTUNE", "1") + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + X = torch.randn(batch_size, num_channels, hidden_size, dtype=dtype, device=device) + W = torch.randn(num_channels, dtype=dtype, device=device) + B = torch.randn(num_channels, dtype=dtype, device=device) + + atol = 1e-2 if dtype != torch.float32 else 5e-3 + rtol = 1e-2 if dtype != torch.float32 else 5e-3 + + Y_test = group_norm(X.clone(), num_channels, num_groups, W, B) + Y_ref = self.reference(X, num_channels, num_groups, W, B) + + assert torch.allclose(Y_test.float(), Y_ref.float(), atol=atol, rtol=rtol), ( + f"Forward mismatch: max_diff={((Y_test.float() - Y_ref.float()).abs().max()).item():.6f}" + ) + + @pytest.mark.parametrize( + "batch_size, num_channels, num_groups, hidden_size, dtype", + [ + (2, 4, 2, 8, torch.float32), + (4, 8, 4, 16, torch.float32), + (2, 4, 2, 8, torch.float16), + (2, 4, 2, 8, torch.bfloat16), + (2, 63, 21, 2163, torch.float32), + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_backward(self, batch_size, num_channels, hidden_size, num_groups, dtype, backend, monkeypatch): + """Test backward gradients (dX, dW, dB) match PyTorch reference.""" + monkeypatch.setenv("DISABLE_AUTOTUNE", "1") + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + X_data = torch.randn(batch_size, num_channels, hidden_size, dtype=dtype, device=device) + W_data = torch.randn(num_channels, dtype=dtype, device=device) + B_data = torch.randn(num_channels, dtype=dtype, device=device) + + atol = 1e-1 + rtol = 1e-1 + + # Test implementation + X_test = X_data.clone().requires_grad_(True) + W_test = W_data.clone().requires_grad_(True) + B_test = B_data.clone().requires_grad_(True) + Y_test = group_norm(X_test, num_channels, num_groups, W_test, B_test) + Y_test.backward(torch.ones_like(Y_test)) + + # Reference (float32) + X_ref = X_data.clone().float().requires_grad_(True) + W_ref = W_data.clone().float().requires_grad_(True) + B_ref = B_data.clone().float().requires_grad_(True) + Y_ref = F.group_norm(X_ref, num_groups, W_ref, B_ref) + Y_ref.backward(torch.ones_like(Y_ref)) + + assert torch.allclose(X_test.grad.float(), X_ref.grad.float(), atol=atol, rtol=rtol), ( + f"dX mismatch: max_diff={((X_test.grad.float() - X_ref.grad.float()).abs().max()).item():.6f}" + ) + assert torch.allclose(W_test.grad.float(), W_ref.grad.float(), atol=atol, rtol=rtol), ( + f"dW mismatch: max_diff={((W_test.grad.float() - W_ref.grad.float()).abs().max()).item():.6f}" + ) + assert torch.allclose(B_test.grad.float(), B_ref.grad.float(), atol=atol, rtol=rtol), ( + f"dB mismatch: max_diff={((B_test.grad.float() - B_ref.grad.float()).abs().max()).item():.6f}" + ) diff --git a/tests/suites/liger/test_kl_div.py b/tests/suites/liger/test_kl_div.py new file mode 100644 index 00000000..7bbe69f3 --- /dev/null +++ b/tests/suites/liger/test_kl_div.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import gc + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites.liger.ops import kl_div + + +class Test_Liger_KLDiv(common.PyTestCase): + _backends = ["cutile"] + + @staticmethod + def reference(y_pred, y_true, reduction="batchmean", log_target=False, eps=1e-10): + """PyTorch float32 reference for KL divergence.""" + y_pred_f = y_pred.float() + y_true_f = y_true.float() + if not log_target: + loss = y_true_f * (torch.log(torch.clamp(y_true_f, min=eps)) - y_pred_f) + else: + loss = torch.exp(y_true_f) * (y_true_f - y_pred_f) + + if reduction == "none": + return loss + elif reduction == "sum": + return loss.sum() + elif reduction == "mean": + return loss.sum() / (loss.shape[0] * loss.shape[1]) + else: # batchmean + return loss.sum() / loss.shape[0] + + @pytest.mark.parametrize( + "shape, dtype", + [ + ((4, 256), torch.float32), + ((8, 512), torch.float32), + ((16, 1024), torch.float32), + ((4, 256), torch.float16), + ((4, 256), torch.bfloat16), + ((4, 300), torch.float32), # non-power-of-2 + ], + ) + @pytest.mark.parametrize("reduction", ["none", "sum", "mean", "batchmean"]) + @pytest.mark.parametrize("log_target", [False, True]) + @pytest.mark.parametrize("backend", _backends) + def test_op(self, shape, dtype, reduction, log_target, backend, monkeypatch): + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + # y_pred: log-probs (log-softmax ensures valid log-probs) + y_pred = torch.log_softmax(torch.randn(*shape, dtype=dtype, device=device), dim=-1) + # y_true: probs (softmax) or log-probs + y_true_raw = torch.softmax(torch.randn(*shape, dtype=dtype, device=device), dim=-1) + y_true = torch.log(y_true_raw.float().clamp(min=1e-10)).to(dtype) if log_target else y_true_raw + + self.assertCorrectness( + kl_div, + self.reference, + { + "y_pred": y_pred, + "y_true": y_true, + "reduction": reduction, + "log_target": log_target, + }, + atol=1e-2, + rtol=1e-2, + ) + + @pytest.mark.parametrize( + "shape, dtype", + [ + ((4, 256), torch.float32), + ((8, 512), torch.float32), + ((4, 256), torch.float16), + ((4, 256), torch.bfloat16), + ((16, 1024), torch.float32), + ((4, 300), torch.float32), # non-power-of-2 + ], + ) + @pytest.mark.parametrize("reduction", ["sum", "mean", "batchmean"]) + @pytest.mark.parametrize("log_target", [False, True]) + @pytest.mark.parametrize("backend", _backends) + def test_op_backward(self, shape, dtype, reduction, log_target, backend, monkeypatch): + """Test backward pass (gradient w.r.t. y_pred).""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + y_pred = torch.log_softmax(torch.randn(*shape, dtype=dtype, device=device), dim=-1).requires_grad_(True) + y_true_raw = torch.softmax(torch.randn(*shape, dtype=dtype, device=device), dim=-1) + y_true = torch.log(y_true_raw.float().clamp(min=1e-10)).to(dtype) if log_target else y_true_raw + + dout = torch.ones((), dtype=dtype, device=device) + + self.assertCorrectness( + kl_div, + self.reference, + { + "y_pred": y_pred, + "y_true": y_true, + "reduction": reduction, + "log_target": log_target, + }, + gradient=dout, + atol=1e-2, + rtol=1e-2, + ) diff --git a/tests/suites/liger/test_llama4_rope.py b/tests/suites/liger/test_llama4_rope.py new file mode 100644 index 00000000..46c1d93e --- /dev/null +++ b/tests/suites/liger/test_llama4_rope.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import gc + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites.liger.ops import llama4_rope + + +def _make_freqs_cis(seq_len, head_dim, device, dtype=torch.float32): + """Build a simple rotary frequency tensor of shape (seq_len, head_dim//2, 2).""" + half = head_dim // 2 + # theta = 1 / (10000 ^ (2i / head_dim)) for i in [0, half) + theta = 1.0 / (10000 ** (torch.arange(0, half, dtype=torch.float32, device=device) * 2.0 / head_dim)) + t = torch.arange(seq_len, dtype=torch.float32, device=device) + freqs = torch.outer(t, theta) # (seq_len, half) + # return as (seq_len, half, 2) real tensor (cos, sin pairs) + freqs_cis = torch.stack([freqs.cos(), freqs.sin()], dim=-1).to(dtype) + return freqs_cis + + +def _reference_rope(q, k, freqs_cis): + """PyTorch reference for RoPE: complex multiplication in float32.""" + q_f = q.float() + k_f = k.float() + + # freqs_cis: (seq_len, head_dim//2, 2) or (seq_len, head_dim) + if freqs_cis.ndim == 2: + freqs_cis = freqs_cis.view(freqs_cis.shape[0], -1, 2) + freqs_cis_f = freqs_cis.float() # (seq, half, 2) + f_r = freqs_cis_f[..., 0] # (seq, half) + f_i = freqs_cis_f[..., 1] # (seq, half) + + def apply_rope(x): + # x: (B, S, H, D), treat as (B, S, H, half, 2) + B, S, H, D = x.shape + x_r = x[..., 0::2] # (B, S, H, half) real + x_i = x[..., 1::2] # (B, S, H, half) imag + # expand freqs to (1, S, 1, half) + fr = f_r.unsqueeze(0).unsqueeze(2) # (1, S, 1, half) + fi = f_i.unsqueeze(0).unsqueeze(2) + new_r = x_r * fr - x_i * fi + new_i = x_r * fi + x_i * fr + out = torch.stack([new_r, new_i], dim=-1).reshape(B, S, H, D) + return out + + return apply_rope(q_f).to(q.dtype), apply_rope(k_f).to(k.dtype) + + +class Test_Liger_Llama4Rope(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize( + "shape, dtype", + [ + ((2, 8, 4, 64), torch.float32), + ((1, 16, 8, 128), torch.float32), + ((2, 8, 4, 64), torch.float16), + ((2, 8, 4, 64), torch.bfloat16), + ((1, 4, 2, 48), torch.float32), # non-power-of-2 head_dim + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op(self, shape, dtype, backend, monkeypatch): + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + B, S, H_q, D = shape + H_k = H_q // 2 if H_q > 1 else H_q # GQA-style: fewer key heads + + q = torch.randn(B, S, H_q, D, dtype=dtype, device=device) + k = torch.randn(B, S, H_k, D, dtype=dtype, device=device) + freqs_cis = _make_freqs_cis(S, D, device, dtype=torch.float32) + + def ref_fn(q=q, k=k, freqs_cis=freqs_cis): + return _reference_rope(q, k, freqs_cis) + + def fw_fn(q=q, k=k, freqs_cis=freqs_cis): + return llama4_rope(q.clone(), k.clone(), freqs_cis) + + q_ref, k_ref = ref_fn() + q_out, k_out = fw_fn() + + assert torch.allclose(q_out.float(), q_ref.float(), atol=1e-2, rtol=1e-2), ( + f"q mismatch: max_diff={(q_out.float() - q_ref.float()).abs().max()}" + ) + assert torch.allclose(k_out.float(), k_ref.float(), atol=1e-2, rtol=1e-2), ( + f"k mismatch: max_diff={(k_out.float() - k_ref.float()).abs().max()}" + ) + + @pytest.mark.parametrize( + "shape, dtype", + [ + ((2, 8, 4, 64), torch.float32), + ((2, 8, 4, 64), torch.float16), + ((2, 8, 4, 64), torch.bfloat16), + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_backward(self, shape, dtype, backend, monkeypatch): + """Test backward pass using conjugate rotation (imag_sign=-1).""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + B, S, H_q, D = shape + H_k = H_q + + q = torch.randn(B, S, H_q, D, dtype=dtype, device=device, requires_grad=True) + k = torch.randn(B, S, H_k, D, dtype=dtype, device=device, requires_grad=True) + freqs_cis = _make_freqs_cis(S, D, device, dtype=torch.float32) + + dout_q = torch.ones(B, S, H_q, D, dtype=dtype, device=device) + dout_k = torch.ones(B, S, H_k, D, dtype=dtype, device=device) + + # Forward + backward + q_out, k_out = llama4_rope(q.clone().requires_grad_(True), k.clone().requires_grad_(True), freqs_cis) + # Backward should not raise + (q_out.sum() + k_out.sum()).backward() diff --git a/tests/suites/liger/test_multi_token_attention.py b/tests/suites/liger/test_multi_token_attention.py new file mode 100644 index 00000000..1c900557 --- /dev/null +++ b/tests/suites/liger/test_multi_token_attention.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import gc + +import pytest +import torch +import torch.nn.functional as F + +import tilegym +from tests import common +from tilegym.suites.liger.ops import multi_token_attention + + +def _reference_multi_token_attention(scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + """ + PyTorch float32 reference for multi-token attention. + + 1. Apply causal -inf mask (future positions → -inf) + 2. Softmax + 3. Conv2d + 4. Apply causal zero mask (future positions → 0) + """ + scores_f = scores.float() + L = scores_f.shape[-1] + # Causal mask: upper triangular → -inf + mask = torch.triu(torch.ones(L, L, device=scores.device, dtype=torch.bool), diagonal=1) + # Expand mask to match scores shape (*, L, L) + for _ in range(scores_f.dim() - 2): + mask = mask.unsqueeze(0) + mask = mask.expand_as(scores_f) + scores_masked = scores_f.masked_fill(mask, -1e9) + probs = torch.softmax(scores_masked, dim=-1) + + out_conv = F.conv2d( + probs, + weight.float(), + bias.float() if bias is not None else None, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + # Zero out future positions + zero_mask = torch.tril(torch.ones(L, L, device=scores.device, dtype=torch.float32)) + for _ in range(out_conv.dim() - 2): + zero_mask = zero_mask.unsqueeze(0) + zero_mask = zero_mask.expand_as(out_conv) + out = out_conv * zero_mask + return out.to(scores.dtype) + + +class Test_Liger_MultiTokenAttention(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize( + "batch_channels_L, dtype", + [ + ((2, 1, 8), torch.float32), + ((1, 2, 16), torch.float32), + ((2, 1, 8), torch.float16), + ((2, 1, 8), torch.bfloat16), + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op(self, batch_channels_L, dtype, backend, monkeypatch): + """Test multi-token attention forward with 1x1 conv weight.""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + batch, channels, L = batch_channels_L + + # Use identity-like 1×1 conv weight + weight = torch.ones(channels, channels, 1, 1, dtype=dtype, device=device) + scores = torch.randn(batch, channels, L, L, dtype=dtype, device=device) + + def fw(): + return multi_token_attention(scores.clone(), weight) + + def ref(): + return _reference_multi_token_attention(scores, weight) + + self.assertCorrectness(fw, ref, kwargs={}, atol=1e-2, rtol=1e-2) + + @pytest.mark.parametrize( + "batch_channels_L, dtype", + [ + ((2, 1, 8), torch.float32), + ((1, 1, 16), torch.float32), + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_backward(self, batch_channels_L, dtype, backend, monkeypatch): + """Test backward pass.""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + batch, channels, L = batch_channels_L + + weight = torch.ones(channels, channels, 1, 1, dtype=dtype, device=device, requires_grad=True) + scores = torch.randn(batch, channels, L, L, dtype=dtype, device=device, requires_grad=True) + + out = multi_token_attention(scores, weight) + out.sum().backward() + + assert scores.grad is not None + assert weight.grad is not None + + @pytest.mark.parametrize( + "batch, channels, L, groups, dtype", + [ + (2, 1, 8, 1, torch.float32), # baseline: CH=1, groups=1 (mm path) + (2, 2, 8, 1, torch.float32), # CH=2, groups=1: mm path generalises to C > 1 + (2, 2, 8, 2, torch.float32), # CH=2, groups=2 depthwise: exercises cuDNN fallback + (2, 2, 8, 2, torch.bfloat16), # bf16, depthwise: fallback + bf16 precision + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_backward_conv_params(self, batch, channels, L, groups, dtype, backend, monkeypatch): + """Backward gradient correctness for varied channel/group configurations. + + Specifically exercises: + - CH > 1 with groups=1: mm-based conv backward must generalise to C_in > 1 + - groups > 1 (depthwise conv): mm path is incorrect; cuDNN fallback must be used + """ + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + torch.manual_seed(42) + + # weight shape for grouped conv: (C_out, C_in // groups, kH, kW) + weight_shape = (channels, channels // groups, 1, 1) + s = torch.randn(batch, channels, L, L, dtype=dtype, device=device) + w = torch.randn(*weight_shape, dtype=dtype, device=device) + + # fp32 reference + s_ref = s.float().detach().requires_grad_(True) + w_ref = w.float().detach().requires_grad_(True) + _reference_multi_token_attention(s_ref, w_ref, groups=groups).sum().backward() + + # tilegym implementation + s_nvt = s.detach().requires_grad_(True) + w_nvt = w.detach().requires_grad_(True) + multi_token_attention(s_nvt, w_nvt, groups=groups).sum().backward() + + # fp32 is tight; bf16 runs natively with inherent precision loss vs fp32 ref + atol = 1e-4 if dtype == torch.float32 else 0.1 + rtol = 1e-3 if dtype == torch.float32 else 0.1 + + assert torch.allclose(s_nvt.grad.float(), s_ref.grad, atol=atol, rtol=rtol), ( + f"scores.grad mismatch (dtype={dtype}, groups={groups}, " + f"max_err={(s_nvt.grad.float() - s_ref.grad).abs().max().item():.5f})" + ) + assert torch.allclose(w_nvt.grad.float(), w_ref.grad, atol=atol, rtol=rtol), ( + f"weight.grad mismatch (dtype={dtype}, groups={groups}, " + f"max_err={(w_nvt.grad.float() - w_ref.grad).abs().max().item():.5f})" + ) + + @pytest.mark.parametrize( + "batch_channels_L, dtype", + [ + ((2, 1, 8), torch.float32), + ((2, 1, 8), torch.float16), + ((2, 1, 8), torch.bfloat16), + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_with_bias(self, batch_channels_L, dtype, backend, monkeypatch): + """Test multi-token attention forward with bias.""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + batch, channels, L = batch_channels_L + + weight = torch.ones(channels, channels, 1, 1, dtype=dtype, device=device) + bias = torch.zeros(channels, dtype=dtype, device=device) + scores = torch.randn(batch, channels, L, L, dtype=dtype, device=device) + + def fw(): + return multi_token_attention(scores.clone(), weight, bias=bias) + + def ref(): + return _reference_multi_token_attention(scores, weight, bias=bias) + + self.assertCorrectness(fw, ref, kwargs={}, atol=1e-2, rtol=1e-2) diff --git a/tests/suites/liger/test_qwen2vl_mrope.py b/tests/suites/liger/test_qwen2vl_mrope.py new file mode 100644 index 00000000..0369a77a --- /dev/null +++ b/tests/suites/liger/test_qwen2vl_mrope.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import gc + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites.liger.ops import qwen2vl_mrope + + +def _make_cos_sin(bsz, seq_len, head_dim, device, dtype=torch.float32): + """ + Build cos/sin tensors of shape (3, bsz, seq_len, head_dim). + + Temporal section: cos[0], Height section: cos[1], Width section: cos[2]. + """ + half = head_dim // 2 + theta = 1.0 / (10000 ** (torch.arange(0, half, dtype=torch.float32, device=device) * 2.0 / head_dim)) + t = torch.arange(seq_len, dtype=torch.float32, device=device) + freqs = torch.outer(t, theta) # (seq_len, half) + cos_1d = freqs.cos() # (seq_len, half) + sin_1d = freqs.sin() # (seq_len, half) + + # Full head_dim: pad right half with zeros (only first half used for RoPE) + cos_full = torch.cat([cos_1d, torch.zeros_like(cos_1d)], dim=-1) # (seq_len, head_dim) + sin_full = torch.cat([sin_1d, torch.zeros_like(sin_1d)], dim=-1) + + # Expand to (3, bsz, seq_len, head_dim) — same values across sections for simplicity + cos_3d = cos_full.unsqueeze(0).unsqueeze(0).expand(3, bsz, seq_len, head_dim).contiguous().to(dtype) + sin_3d = sin_full.unsqueeze(0).unsqueeze(0).expand(3, bsz, seq_len, head_dim).contiguous().to(dtype) + return cos_3d, sin_3d + + +def _reference_mrope(q, k, cos, sin, mrope_section): + """ + PyTorch float32 reference for Qwen2VL M-RoPE. + + q: (bsz, n_q_heads, seq_len, head_dim) + k: (bsz, n_k_heads, seq_len, head_dim) + cos/sin: (3, bsz, seq_len, head_dim) + mrope_section: [t_section, h_section] + """ + bsz, n_q_heads, seq_len, head_dim = q.shape + hd_half = head_dim // 2 + + t_end = mrope_section[0] + h_end = t_end + mrope_section[1] + + # Build effective cos/sin: (bsz, seq_len, hd_half) + cos_eff = torch.zeros(bsz, seq_len, hd_half, dtype=torch.float32, device=q.device) + sin_eff = torch.zeros(bsz, seq_len, hd_half, dtype=torch.float32, device=q.device) + + cos_eff[:, :, :t_end] = cos[0, :, :, :t_end].float() + sin_eff[:, :, :t_end] = sin[0, :, :, :t_end].float() + cos_eff[:, :, t_end:h_end] = cos[1, :, :, t_end:h_end].float() + sin_eff[:, :, t_end:h_end] = sin[1, :, :, t_end:h_end].float() + cos_eff[:, :, h_end:] = cos[2, :, :, h_end:hd_half].float() + sin_eff[:, :, h_end:] = sin[2, :, :, h_end:hd_half].float() + + def apply(x, c, s): + # x: (bsz, n_heads, seq_len, head_dim) + x_f = x.float() + x_r = x_f[..., :hd_half] # left half + x_i = x_f[..., hd_half:] # right half + # c, s: (bsz, seq_len, hd_half) → (bsz, 1, seq_len, hd_half) + c_exp = c.unsqueeze(1) + s_exp = s.unsqueeze(1) + new_r = x_r * c_exp - x_i * s_exp + new_i = x_i * c_exp + x_r * s_exp + return torch.cat([new_r, new_i], dim=-1).to(x.dtype) + + return apply(q, cos_eff, sin_eff), apply(k, cos_eff, sin_eff) + + +class Test_Liger_Qwen2VLMRope(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize( + "shape, dtype", + [ + ((2, 8, 16, 64), torch.float32), + ((1, 4, 8, 128), torch.float32), + ((2, 8, 16, 64), torch.float16), + ((2, 8, 16, 64), torch.bfloat16), + ((1, 2, 4, 48), torch.float32), # non-power-of-2 head_dim + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op(self, shape, dtype, backend, monkeypatch): + """Test M-RoPE forward pass.""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + bsz, n_q_heads, seq_len, head_dim = shape + n_k_heads = max(1, n_q_heads // 2) + hd_half = head_dim // 2 + # mrope_section: split head_dim_half into 3 roughly equal parts + t_sec = hd_half // 3 + h_sec = hd_half // 3 + mrope_section = [t_sec, h_sec] + + q = torch.randn(bsz, n_q_heads, seq_len, head_dim, dtype=dtype, device=device) + k = torch.randn(bsz, n_k_heads, seq_len, head_dim, dtype=dtype, device=device) + cos, sin = _make_cos_sin(bsz, seq_len, head_dim, device) + + q_ref, k_ref = _reference_mrope(q, k, cos, sin, mrope_section) + q_out, k_out = qwen2vl_mrope(q.clone(), k.clone(), cos, sin, mrope_section) + + assert torch.allclose(q_out.float(), q_ref.float(), atol=1e-2, rtol=1e-2), ( + f"q mismatch: max_diff={(q_out.float() - q_ref.float()).abs().max()}" + ) + assert torch.allclose(k_out.float(), k_ref.float(), atol=1e-2, rtol=1e-2), ( + f"k mismatch: max_diff={(k_out.float() - k_ref.float()).abs().max()}" + ) + + @pytest.mark.parametrize( + "shape, dtype", + [ + ((2, 8, 16, 64), torch.float32), + ((2, 8, 16, 64), torch.float16), + ((2, 8, 16, 64), torch.bfloat16), + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_backward(self, shape, dtype, backend, monkeypatch): + """Test backward pass (gradient flows through M-RoPE).""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + bsz, n_q_heads, seq_len, head_dim = shape + n_k_heads = n_q_heads + hd_half = head_dim // 2 + t_sec = hd_half // 3 + h_sec = hd_half // 3 + mrope_section = [t_sec, h_sec] + + q = torch.randn(bsz, n_q_heads, seq_len, head_dim, dtype=dtype, device=device, requires_grad=True) + k = torch.randn(bsz, n_k_heads, seq_len, head_dim, dtype=dtype, device=device, requires_grad=True) + cos, sin = _make_cos_sin(bsz, seq_len, head_dim, device) + + q_out, k_out = qwen2vl_mrope(q, k, cos, sin, mrope_section) + (q_out.sum() + k_out.sum()).backward() + + # Gradient should be non-None + assert q.grad is not None + assert k.grad is not None diff --git a/tests/suites/liger/test_rope.py b/tests/suites/liger/test_rope.py new file mode 100644 index 00000000..d7cd05c7 --- /dev/null +++ b/tests/suites/liger/test_rope.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import gc + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites.liger.ops import rope + + +def _make_cos_sin(bsz, seq_len, head_dim, device, dtype=torch.float32, broadcast_batch=False): + """ + Build cos/sin tensors of shape (cos_bsz, seq_len, head_dim). + + cos_bsz = 1 if broadcast_batch else bsz. + Only the first head_dim//2 columns carry real rotation values; the rest are zeros. + """ + cos_bsz = 1 if broadcast_batch else bsz + half = head_dim // 2 + theta = 1.0 / (10000 ** (torch.arange(0, half, dtype=torch.float32, device=device) * 2.0 / head_dim)) + t = torch.arange(seq_len, dtype=torch.float32, device=device) + freqs = torch.outer(t, theta) # (seq_len, half) + cos_1d = freqs.cos() # (seq_len, half) + sin_1d = freqs.sin() + + # Pad right half with zeros (only first half is used by RoPE) + cos_full = torch.cat([cos_1d, torch.zeros_like(cos_1d)], dim=-1) # (seq_len, head_dim) + sin_full = torch.cat([sin_1d, torch.zeros_like(sin_1d)], dim=-1) + + cos = cos_full.unsqueeze(0).expand(cos_bsz, seq_len, head_dim).contiguous().to(dtype) + sin = sin_full.unsqueeze(0).expand(cos_bsz, seq_len, head_dim).contiguous().to(dtype) + return cos, sin + + +def _reference_rope(q, k, cos, sin): + """ + PyTorch float32 reference for RoPE (HuggingFace Llama/Mistral half-split variant). + + q: (bsz, n_q_heads, seq_len, head_dim) + k: (bsz, n_k_heads, seq_len, head_dim) + cos/sin: (1_or_bsz, seq_len, head_dim) + """ + head_dim = q.shape[-1] + hd_half = head_dim // 2 + + # cos/sin: use first half of head_dim; expand batch dim if needed + cos_h = cos[..., :hd_half].float() # (cos_bsz, seq_len, hd_half) + sin_h = sin[..., :hd_half].float() + + def apply(x): + x_f = x.float() + x_r = x_f[..., :hd_half] # (bsz, n_heads, seq_len, hd_half) + x_i = x_f[..., hd_half:] + + # cos/sin: (cos_bsz, seq_len, hd_half) → (cos_bsz, 1, seq_len, hd_half) + c = cos_h.unsqueeze(1) + s = sin_h.unsqueeze(1) + new_r = x_r * c - x_i * s + new_i = x_i * c + x_r * s + return torch.cat([new_r, new_i], dim=-1).to(x.dtype) + + return apply(q), apply(k) + + +class Test_Liger_Rope(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize( + "shape, dtype", + [ + ((2, 8, 16, 64), torch.float32), + ((1, 4, 8, 128), torch.float32), + ((2, 8, 16, 64), torch.float16), + ((2, 8, 16, 64), torch.bfloat16), + ((1, 2, 4, 48), torch.float32), # non-power-of-2 head_dim + ((3, 4, 7, 64), torch.float32), # non-power-of-2 heads + ], + ) + @pytest.mark.parametrize("broadcast_batch", [True, False]) + @pytest.mark.parametrize("backend", _backends) + def test_op(self, shape, dtype, broadcast_batch, backend, monkeypatch): + """Test RoPE forward pass.""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + bsz, n_q_heads, seq_len, head_dim = shape + n_k_heads = max(1, n_q_heads // 2) + + q = torch.randn(bsz, n_q_heads, seq_len, head_dim, dtype=dtype, device=device) + k = torch.randn(bsz, n_k_heads, seq_len, head_dim, dtype=dtype, device=device) + cos, sin = _make_cos_sin(bsz, seq_len, head_dim, device, broadcast_batch=broadcast_batch) + + q_ref, k_ref = _reference_rope(q, k, cos, sin) + q_out, k_out = rope(q.clone(), k.clone(), cos, sin) + + assert torch.allclose(q_out.float(), q_ref.float(), atol=1e-2, rtol=1e-2), ( + f"q mismatch: max_diff={(q_out.float() - q_ref.float()).abs().max()}" + ) + assert torch.allclose(k_out.float(), k_ref.float(), atol=1e-2, rtol=1e-2), ( + f"k mismatch: max_diff={(k_out.float() - k_ref.float()).abs().max()}" + ) + + @pytest.mark.parametrize( + "shape, dtype", + [ + ((2, 8, 16, 64), torch.float32), + ((2, 8, 16, 64), torch.float16), + ((2, 8, 16, 64), torch.bfloat16), + ((1, 4, 8, 128), torch.float32), + ((1, 2, 4, 48), torch.float32), # non-power-of-2 head_dim + ((3, 4, 7, 64), torch.float32), # non-power-of-2 heads + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_backward(self, shape, dtype, backend, monkeypatch): + """Test backward pass (gradient flows through RoPE).""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + bsz, n_q_heads, seq_len, head_dim = shape + n_k_heads = n_q_heads + + q = torch.randn(bsz, n_q_heads, seq_len, head_dim, dtype=dtype, device=device, requires_grad=True) + k = torch.randn(bsz, n_k_heads, seq_len, head_dim, dtype=dtype, device=device, requires_grad=True) + cos, sin = _make_cos_sin(bsz, seq_len, head_dim, device) + + q_out, k_out = rope(q, k, cos, sin) + (q_out.sum() + k_out.sum()).backward() + + assert q.grad is not None + assert k.grad is not None diff --git a/tests/suites/liger/test_sparsemax.py b/tests/suites/liger/test_sparsemax.py new file mode 100644 index 00000000..aa46bc71 --- /dev/null +++ b/tests/suites/liger/test_sparsemax.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import gc + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites.liger.ops import sparsemax + + +def _reference_sparsemax(x, dim=-1): + """ + PyTorch float32 reference for sparsemax. + + Projects input onto the probability simplex along `dim`. + Algorithm: sort descending → cumsum → find support → compute tau → clip. + """ + x_f = x.float() + input_dims = x_f.dim() + if dim < 0: + dim = input_dims + dim + + x_sorted, _ = torch.sort(x_f, dim=dim, descending=True) + cumsum = torch.cumsum(x_sorted, dim=dim) + input_size = x_f.size(dim) + r = torch.arange(1, input_size + 1, device=x.device, dtype=torch.float32) + shape = [1] * input_dims + shape[dim] = input_size + r = r.view(shape) + k_bound = 1 + r * x_sorted + support = k_bound > cumsum + k = support.sum(dim=dim, keepdim=True).clamp(min=1) + support_sum = (x_sorted * support).sum(dim=dim, keepdim=True) + tau = (support_sum - 1) / k + return torch.clamp(x_f - tau, min=0).to(x.dtype) + + +class Test_Liger_Sparsemax(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize( + "shape, dtype", + [ + ((2, 128, 512), torch.float32), + ((5, 123, 123), torch.float32), + ((4, 256), torch.float32), + ], + ) + @pytest.mark.parametrize("dim", [-1, 1]) + @pytest.mark.parametrize("backend", _backends) + def test_op(self, shape, dtype, dim, backend, monkeypatch): + """Test sparsemax forward pass.""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + # Skip trivial cases (dim with size 1 or same as last dim special handling) + if dim >= len(shape) or dim < -len(shape): + pytest.skip("invalid dim") + actual_dim = dim if dim >= 0 else len(shape) + dim + if shape[actual_dim] <= 1: + pytest.skip("trivial dim") + + device = torch.device("cuda") + torch.manual_seed(0) + x = torch.randn(*shape, dtype=dtype, device=device) + + def fw(): + return sparsemax(x.clone(), dim=dim) + + def ref(): + return _reference_sparsemax(x, dim=dim) + + self.assertCorrectness(fw, ref, kwargs={}, atol=1e-4, rtol=1e-4) + + @pytest.mark.parametrize( + "shape, dtype", + [ + ((2, 128, 512), torch.float32), + ((4, 256), torch.float32), + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_backward(self, shape, dtype, backend, monkeypatch): + """Test backward pass (gradient flows through sparsemax).""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + torch.manual_seed(0) + x = torch.randn(*shape, dtype=dtype, device=device, requires_grad=True) + + y = sparsemax(x, dim=-1) + y.sum().backward() + + assert x.grad is not None diff --git a/tests/suites/liger/test_tiled_mlp.py b/tests/suites/liger/test_tiled_mlp.py new file mode 100644 index 00000000..bb6cc2bd --- /dev/null +++ b/tests/suites/liger/test_tiled_mlp.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import gc + +import pytest +import torch +import torch.nn as nn + +import tilegym +from tests import common +from tilegym.suites.liger.ops import tiled_mlp + + +class SimpleMLP(nn.Module): + """Simple MLP for testing: linear + relu.""" + + def __init__(self, hidden_size, device="cuda", dtype=torch.float32): + super().__init__() + self.fc = nn.Linear(hidden_size, hidden_size, device=device, dtype=dtype) + + def forward(self, x): + return torch.relu(self.fc(x)) + + +class Test_Liger_TiledMLP(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize( + "bsz, seq_len, hidden_size", + [ + # Shapes from Liger test_tiled_mlp.py (using hidden_size only) + (1, 1024, 128), # num_shards=8 if auto + (2, 1024, 64), # num_shards=16 if auto + (4, 127, 128), # weird shape + # Shapes from Liger test/transformers/test_tiled_mlp.py (SwiGLU variant) + (2, 512, 512), + (1, 1024, 256), + ], + ) + @pytest.mark.parametrize("num_shards", [None, 2, 4]) + @pytest.mark.parametrize("check_2d", [True, False]) + @pytest.mark.parametrize("backend", _backends) + def test_op_forward(self, bsz, seq_len, hidden_size, num_shards, check_2d, backend, monkeypatch): + """Test that tiled computation matches non-tiled computation.""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + # Scale input down to reduce numerical sensitivity + x_data = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=torch.float32) * 0.1 + + if check_2d: + x_data = x_data.view(-1, hidden_size) + + mlp = SimpleMLP(hidden_size, device=device, dtype=torch.float32) + fn = lambda mod, x: mod(x) + + # Reference: non-tiled computation + x_ref = x_data.detach().clone().requires_grad_(True) + out_ref = fn(mlp, x_ref) + + # Tiled computation + x_tiled = x_data.detach().clone().requires_grad_(True) + out_tiled = tiled_mlp(fn, mlp, x_tiled, num_shards=num_shards) + + # atol=1e-3: float32 matmul with different chunk sizes may use different + # cuBLAS algorithms, causing floating-point differences up to ~1e-4. + assert torch.allclose(out_tiled, out_ref, atol=1e-3, rtol=1e-3), ( + f"Forward mismatch: max_diff={((out_tiled - out_ref).abs().max()).item():.8f}" + ) + + @pytest.mark.parametrize( + "bsz, seq_len, hidden_size", + [ + (1, 1024, 128), + (4, 127, 128), + (2, 512, 512), + ], + ) + @pytest.mark.parametrize("num_shards", [None, 2, 4]) + @pytest.mark.parametrize("backend", _backends) + def test_op_backward(self, bsz, seq_len, hidden_size, num_shards, backend, monkeypatch): + """Test that tiled backward matches non-tiled backward.""" + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + x_data = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=torch.float32) * 0.1 + + mlp = SimpleMLP(hidden_size, device=device, dtype=torch.float32) + fn = lambda mod, x: mod(x) + + # Reference: non-tiled backward + x_ref = x_data.detach().clone().requires_grad_(True) + out_ref = fn(mlp, x_ref) + out_ref.sum().backward() + + # Tiled backward + x_tiled = x_data.detach().clone().requires_grad_(True) + out_tiled = tiled_mlp(fn, mlp, x_tiled, num_shards=num_shards) + out_tiled.sum().backward() + + assert x_tiled.grad is not None, "Tiled backward produced no gradient" + assert torch.allclose(x_tiled.grad, x_ref.grad, atol=1e-5, rtol=1e-5), ( + f"Backward mismatch: max_diff={((x_tiled.grad - x_ref.grad).abs().max()).item():.8f}" + )