diff --git a/examples/jit_cpp/fast_inverse/.gitignore b/examples/jit_cpp/fast_inverse/.gitignore new file mode 100644 index 00000000..a937d7e5 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/.gitignore @@ -0,0 +1,2 @@ +benchmark_results +*.so diff --git a/examples/jit_cpp/fast_inverse/README.md b/examples/jit_cpp/fast_inverse/README.md new file mode 100644 index 00000000..07621c5e --- /dev/null +++ b/examples/jit_cpp/fast_inverse/README.md @@ -0,0 +1,111 @@ +## fast_inverse — JIT triangular matrix inverse (recursive unroll) + +JIT-compiled example of `kernel_tri_inv_rec_unroll`, which inverts a batch of +upper-triangular fp16 matrices stored in a multi-dimensional tensor. + +### Algorithm + +Given an input tensor whose last two dimensions form an n×n upper-triangular +matrix U (off-diagonal part only; the diagonal is assumed to be all-ones), the +kernel computes the inverse of (U + I) for every matrix in the batch. + +The implementation uses a two-phase recursive approach on Ascend cube cores: + +1. **Inv-trick phase** – inverts each 16×16 diagonal fractal block via a + Neumann-series expansion (`X = (I − M) + (I − M)·M + …`). +2. **Unrolled recursion phase** – assembles partial inverses of progressively + larger sub-blocks until the full matrix is inverted. 
+ +### Files + +| File | Purpose | +|------|---------| +| `fast_inverse.cpp` | Thin JIT wrapper: includes the kernel and exposes `call_kernel` | +| `jit_util_fast_inverse.py` | Compiles the kernel with `bisheng` and loads it via `ctypes` | +| `run_fast_inverse.py` | Correctness test suite, including aligned and varlen BSND coverage | +| `run_fast_inverse_varlen_like_triton.py` | Standalone varlen runner that mirrors the Triton `test_solve_tril_varlen` input generation in pure PyTorch | +| `benchmark_bsnd_fast_inverse.py` | Benchmarks fixed BSND vs varlen-uniform BSND and plots effective bandwidth | + +### Usage + +```bash +export PTO_LIB_PATH=/sources/pto-isa/ # need latest header, not CANN 8.5.0 default + +cd examples/jit_cpp/fast_inverse +python run_fast_inverse.py +``` + +The script compiles `fast_inverse.cpp` on first run (takes ~60 s), then +executes correctness checks across a range of matrix sizes (16, 32, 64, 128) +and batch configurations. + +To run the standalone Triton-like varlen coverage: + +```bash +export PTO_LIB_PATH=/sources/pto-isa/ + +cd examples/jit_cpp/fast_inverse +python run_fast_inverse_varlen_like_triton.py +``` + +That script: + +- uses the same varlen case list and input-generation structure as + `flash-linear-attention/tests/ops/test_solve_tril.py::test_solve_tril_varlen` +- keeps PTO inputs in `float16` +- emulates `chunk_scaled_dot_kkt_fwd` in PyTorch because Triton is not available +- prints a simple pytest-like `PASS` / `FAIL` report plus a final summary + +### Supported matrix sizes + +`matrix_size` (last dimension of the input tensor) must be one of: **16, 32, +64, 128**. + +### Layout conventions + +In general, the input to the `fast_inverse` kernels is a set of triangular matrices of size `D × D`. Depending on how these matrices are stored in memory, we have either the `contiguous` layout or the so-called `BSND` layout. The main input is a batch of sequences, and each sequence is then split into "chunks" of length `chunk_size`. 
This `chunk_size` is the same as the matrix size `D`. + +Both layouts depend on the following parameters: +- The parameter `B` denotes the batch-size (or batch-dimension). This is always the first dimension of the input tensor. +- The parameter `N` or `H` (used interchangeably) is the number of heads. +- `D` is equal to the `chunk_size`. +- `S` is the total sum of all sequence lengths combined. +`BSND` can be thought of as the "raw" input tensor. The `contiguous` layout can be obtained, for example, by transposing the `N` and `S` dimensions, and by "chunking" the `S` dimension into chunks of size `D`. The tensor is thereby transformed from shape `(B,S,N,D)` to `(B,N,S/D,D,D)`, where we assume that `D` divides `S` for simplicity. + +The kernel is told whether the input is in `BSND` layout or in `contiguous` layout through the input argument `num_bsnd_heads`. If it is equal to zero, the layout is assumed to be `contiguous`; if it is positive, the layout is assumed to be `BSND` with that many heads. + +| `num_bsnd_heads` | Memory layout | +|-----------------|---------------| +| `0` (default) | Each matrix stored consecutively in row-major order (`B × … × D × D`) | +| `> 0` | BSND layout: `(B, S, N, D)` where `S` is chunked into tiles of size D and N heads are interleaved | + +### Varlen BSND mode + +The standalone example also supports variable-length BSND inputs with the same +external signature as the Triton reference path: callers provide packed BSND +data plus `cu_seqlens`, and the PTO kernel derives each chunk row-start and +tail size internally on NPU. The kernel still inverts dense `D x D` tiles, but +tail chunks load/store only their valid prefix. 
+ +### Benchmark + +To compare the original fixed-length BSND path against the new varlen path in a +matched-size sanity check: + +```bash +export PTO_LIB_PATH=/sources/pto-isa/ + +cd examples/jit_cpp/fast_inverse +python benchmark_bsnd_fast_inverse.py --chunk-size 64 +``` + +The benchmark script: + +- runs only the PTO-ISA BSND kernel +- compares `bsnd-fixed` against `bsnd-varlen-uniform` +- uses uniform `cu_seqlens=[0, T, 2T, ...]` so both paths process the same + total data size +- reports numerical agreement between the two outputs +- also generates a true-varlen benchmark that plots scattered bandwidth points + against aggregated sequence length +- writes all CSV and PNG artifacts into `examples/jit_cpp/fast_inverse/benchmark_results/` diff --git a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py new file mode 100644 index 00000000..257101ab --- /dev/null +++ b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py @@ -0,0 +1,981 @@ +#!/usr/bin/env python3 +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# All rights reserved. +# See LICENSE in the root of the software repository: +# https://github.com/huawei-csl/pto-kernels/ +# for the full License text. +# -------------------------------------------------------------------------------- + +""" +Benchmark the standalone BSND fast-inverse kernel. + +This script benchmarks the PTO-ISA BSND kernel in two modes using Triton-unit- +test-like inputs: + +1. `bsnd-fixed`: + Original aligned BSND layout with shape `(B, T, H, D)`. +2. `bsnd-varlen-uniform`: + The new varlen path using packed shape `(1, B*T, H, D)` with uniform + `cu_seqlens = [0, T, 2T, ...]`. + +The two modes use the same total token count and the same underlying `k` / `beta` +inputs. 
`A` is generated in eager PyTorch with an emulation of +`chunk_scaled_dot_kkt_fwd`, then each valid chunk is transposed before launch so +the PTO kernel still sees its expected upper-triangular layout. The script also +checks that both modes produce numerically matching results after transposing +outputs back to the lower-triangular convention used by the Triton tests. +""" + +from __future__ import annotations + +import argparse +import csv +import math +import os +import time +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +import torch_npu # noqa: F401 + +from host_metadata_util import ( + build_chunk_sequence_prefix_cpp, + build_varlen_chunk_metadata_cpp, +) +from jit_util_fast_inverse import jit_compile + + +DEFAULT_SEQLENS = (512, 1024, 2048, 4096, 8192, 16384) +DEFAULT_CACHE_SIZE = 256 * 1024 * 1024 +DEFAULT_FEATURE_DIM = 64 +NPU_DEVICE = os.getenv("GDN_TRI_INVERSE_NPU_DEVICE", "npu:0") +THIS_DIR = Path(__file__).resolve().parent +RESULTS_DIR = THIS_DIR / "benchmark_results" +DEFAULT_TRUE_VARLEN_SAMPLES = 6 + + +def parse_int_list(spec: str) -> tuple[int, ...]: + parts = [p.strip() for p in spec.split(",") if p.strip()] + if not parts: + raise argparse.ArgumentTypeError("expected at least one integer") + try: + return tuple(int(p, 10) for p in parts) + except ValueError as exc: + raise argparse.ArgumentTypeError(f"invalid integer list {spec!r}: {exc}") from exc + + +def make_minus_identity(matrix_size: int, device: str) -> torch.Tensor: + minus_identity = torch.zeros( + matrix_size, + matrix_size, + dtype=torch.half, + device=device, + ) + minus_identity.fill_diagonal_(-1) + return minus_identity + + +def count_varlen_chunks( + cu_seqlens: torch.Tensor, + chunk_size: int, +) -> int: + cu_seqlens_list = [int(x) for x in cu_seqlens.detach().cpu().tolist()] + return sum( + (cu_seqlens_list[i + 1] - cu_seqlens_list[i] + chunk_size - 1) // chunk_size + for i in 
range(len(cu_seqlens_list) - 1) + ) + + +def chunk_scaled_dot_kkt_fwd_emulated( + k: torch.Tensor, + beta: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_size: int, +) -> torch.Tensor: + total_tokens = int(cu_seqlens[-1].item()) + num_heads = k.shape[2] + A = torch.zeros((1, total_tokens, num_heads, chunk_size), dtype=k.dtype, device=k.device) + + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + for chunk_start in range(bos, eos, chunk_size): + chunk_end = min(chunk_start + chunk_size, eos) + actual_size = chunk_end - chunk_start + k_chunk = k[:, chunk_start:chunk_end].transpose(1, 2).to(torch.float32) + beta_chunk = ( + beta[:, chunk_start:chunk_end] + .transpose(1, 2) + .unsqueeze(-1) + .to(torch.float32) + ) + scores = torch.matmul(k_chunk, k_chunk.transpose(-1, -2)) + scores = torch.tril(scores * beta_chunk, diagonal=-1).to(k.dtype) + A[:, chunk_start:chunk_end, :, :actual_size] = scores.transpose(1, 2) + + return A + + +def transpose_valid_chunks( + A: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_size: int, +) -> torch.Tensor: + transposed = torch.zeros_like(A) + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + for chunk_start in range(bos, eos, chunk_size): + actual_size = min(chunk_size, eos - chunk_start) + chunk = A[:, chunk_start : chunk_start + actual_size, :, :actual_size] + transposed[:, chunk_start : chunk_start + actual_size, :, :actual_size] = chunk.transpose( + 1, 3 + ) + return transposed + + +def build_fixed_bsnd_input( + batch_size: int, + seqlen: int, + num_heads: int, + chunk_size: int, + feature_dim: int, + device: str, +) -> tuple[torch.Tensor, torch.Tensor]: + total_tokens = batch_size * seqlen + cu_seqlens = torch.arange( + 0, + total_tokens + 1, + seqlen, + dtype=torch.int32, + device=device, + ) + k = F.normalize( + torch.randn((1, total_tokens, num_heads, feature_dim), dtype=torch.float16, device=device), + dim=-1, + ) + beta = torch.randn((1, 
total_tokens, num_heads), dtype=torch.float16, device=device).sigmoid() + A = transpose_valid_chunks( + chunk_scaled_dot_kkt_fwd_emulated(k, beta, cu_seqlens, chunk_size), + cu_seqlens, + chunk_size, + ) + return A.reshape(batch_size, seqlen, num_heads, chunk_size).contiguous(), cu_seqlens + + +def build_uniform_varlen_input( + fixed_input: torch.Tensor, + batch_size: int, + seqlen: int, + chunk_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + total_tokens = batch_size * seqlen + packed_input = fixed_input.reshape(1, total_tokens, fixed_input.shape[2], chunk_size).contiguous() + cu_seqlens = torch.arange( + 0, + total_tokens + 1, + seqlen, + dtype=torch.int32, + device=fixed_input.device, + ) + return packed_input, cu_seqlens + + +def sample_true_varlen_lengths( + batch_size: int, + aggregated_tokens: int, + rng: np.random.Generator, +) -> list[int]: + if aggregated_tokens < batch_size: + raise ValueError("aggregated_tokens must be >= batch_size.") + + remaining = aggregated_tokens - batch_size + while True: + weights = rng.dirichlet(np.ones(batch_size)) + extras = np.floor(weights * remaining).astype(np.int64) + deficit = remaining - int(extras.sum()) + if deficit > 0: + extras[:deficit] += 1 + lengths = (extras + 1).tolist() + if any(length != lengths[0] for length in lengths): + return lengths + + +def build_true_varlen_input( + seq_lens: list[int], + num_heads: int, + chunk_size: int, + feature_dim: int, + device: str, +) -> tuple[torch.Tensor, torch.Tensor]: + cu_seqlens = np.cumsum([0, *seq_lens], dtype=np.int64) + cu_seqlens_tensor = torch.tensor(cu_seqlens.tolist(), dtype=torch.int32, device=device) + total_tokens = int(cu_seqlens[-1]) + k = F.normalize( + torch.randn((1, total_tokens, num_heads, feature_dim), dtype=torch.float16, device=device), + dim=-1, + ) + beta = torch.randn((1, total_tokens, num_heads), dtype=torch.float16, device=device).sigmoid() + packed_input = transpose_valid_chunks( + chunk_scaled_dot_kkt_fwd_emulated(k, beta, 
cu_seqlens_tensor, chunk_size), + cu_seqlens_tensor, + chunk_size, + ) + return packed_input.contiguous(), cu_seqlens_tensor + + +def make_fixed_runner( + tri_inv_func, + tensor_in: torch.Tensor, +) -> tuple[callable, torch.Tensor]: + matrix_size = tensor_in.shape[-1] + num_bsnd_heads = tensor_in.shape[-2] + num_matrices = tensor_in.numel() // (matrix_size * matrix_size) + tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) + minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) + + def run(): + tri_inv_func( + tensor_out, + tensor_in, + minus_identity, + matrix_size, + num_matrices, + num_bsnd_heads, + ) + + return run, tensor_out + + +def make_varlen_runner( + tri_inv_func, + tensor_in: torch.Tensor, + cu_seqlens: torch.Tensor, +) -> tuple[callable, torch.Tensor]: + matrix_size = tensor_in.shape[-1] + num_bsnd_heads = tensor_in.shape[-2] + num_matrices = count_varlen_chunks(cu_seqlens, matrix_size) * num_bsnd_heads + tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) + minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) + + def run(): + tri_inv_func( + tensor_out, + tensor_in, + minus_identity, + matrix_size, + num_matrices, + num_bsnd_heads, + cu_seqlens=cu_seqlens, + ) + + return run, tensor_out + + +def make_varlen_runner_host_metadata( + tri_inv_func, + tensor_in: torch.Tensor, + chunk_indices: torch.Tensor, + chunk_valid_sizes: torch.Tensor, +) -> tuple[callable, torch.Tensor]: + matrix_size = tensor_in.shape[-1] + num_bsnd_heads = tensor_in.shape[-2] + num_matrices = int(chunk_indices.numel()) * num_bsnd_heads + tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) + minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) + + def run(): + tri_inv_func( + tensor_out, + tensor_in, + minus_identity, + matrix_size, + num_matrices, + num_bsnd_heads, + chunk_indices=chunk_indices, + chunk_valid_sizes=chunk_valid_sizes, + ) + + return run, tensor_out + + +def 
make_varlen_runner_prefix_metadata( + tri_inv_func, + tensor_in: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_sequence_prefix: torch.Tensor, +) -> tuple[callable, torch.Tensor]: + matrix_size = tensor_in.shape[-1] + num_bsnd_heads = tensor_in.shape[-2] + num_matrices = count_varlen_chunks(cu_seqlens, matrix_size) * num_bsnd_heads + tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) + minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) + + def run(): + tri_inv_func( + tensor_out, + tensor_in, + minus_identity, + matrix_size, + num_matrices, + num_bsnd_heads, + cu_seqlens=cu_seqlens, + chunk_sequence_prefix=chunk_sequence_prefix, + ) + + return run, tensor_out + + +def benchmark_ms( + fn, + warmup_iters: int, + benchmark_iters: int, + device: str, +) -> list[float]: + start_events = [torch.npu.Event(enable_timing=True) for _ in range(benchmark_iters)] + end_events = [torch.npu.Event(enable_timing=True) for _ in range(benchmark_iters)] + + torch.npu.synchronize() + for _ in range(warmup_iters): + fn() + torch.npu.synchronize() + + cache = torch.ones(DEFAULT_CACHE_SIZE, dtype=torch.int8, device=device) + times_ms: list[float] = [] + for idx in range(benchmark_iters): + cache.zero_() + torch.npu.synchronize() + start_events[idx].record() + fn() + end_events[idx].record() + end_events[idx].synchronize() + times_ms.append(start_events[idx].elapsed_time(end_events[idx])) + return times_ms + + +def build_host_metadata_on_npu( + cu_seqlens: torch.Tensor, + chunk_size: int, + device: str, +) -> tuple[torch.Tensor, torch.Tensor]: + chunk_indices_cpu, chunk_valid_sizes_cpu = build_varlen_chunk_metadata_cpp( + cu_seqlens, + chunk_size, + ) + return ( + chunk_indices_cpu.to(device=device).contiguous(), + chunk_valid_sizes_cpu.to(device=device).contiguous(), + ) + + +def build_prefix_metadata_on_npu( + cu_seqlens: torch.Tensor, + chunk_size: int, + device: str, +) -> torch.Tensor: + return build_chunk_sequence_prefix_cpp(cu_seqlens, 
chunk_size).to( + device=device + ).contiguous() + + +def benchmark_host_metadata_prep_ms( + cu_seqlens: torch.Tensor, + chunk_size: int, + benchmark_iters: int, + device: str, +) -> list[float]: + times_ms: list[float] = [] + cache = torch.ones(DEFAULT_CACHE_SIZE, dtype=torch.int8, device=device) + for _ in range(benchmark_iters): + cache.zero_() + torch.npu.synchronize() + start = time.perf_counter() + build_host_metadata_on_npu(cu_seqlens, chunk_size, device) + torch.npu.synchronize() + times_ms.append((time.perf_counter() - start) * 1000.0) + return times_ms + + +def benchmark_prefix_metadata_prep_ms( + cu_seqlens: torch.Tensor, + chunk_size: int, + benchmark_iters: int, + device: str, +) -> list[float]: + times_ms: list[float] = [] + cache = torch.ones(DEFAULT_CACHE_SIZE, dtype=torch.int8, device=device) + for _ in range(benchmark_iters): + cache.zero_() + torch.npu.synchronize() + start = time.perf_counter() + build_prefix_metadata_on_npu(cu_seqlens, chunk_size, device) + torch.npu.synchronize() + times_ms.append((time.perf_counter() - start) * 1000.0) + return times_ms + + +def add_bandwidth_fields(row: dict[str, float | int | str], input_dtype_bytes: int = 2) -> None: + size_elems = int(row.get("valid_numel", row["numel"])) + mem_bytes = size_elems * (input_dtype_bytes + 4) + row["mem_bytes"] = mem_bytes + row["bw_gbs"] = (mem_bytes / 1e9) / (float(row["time_us"]) / 1e6) + + +def accuracy_metrics(reference: torch.Tensor, candidate: torch.Tensor) -> tuple[float, float]: + ref = reference.detach().cpu().to(torch.float64) + cand = candidate.detach().cpu().to(torch.float64) + diff = ref - cand + max_abs = diff.abs().max().item() + denom = torch.sum(ref * ref).item() + rel_frob = 0.0 if denom == 0 else math.sqrt(torch.sum(diff * diff).item() / denom) + return max_abs, rel_frob + + +def write_csv(csv_path: Path, rows: list[dict[str, float | int | str]]) -> None: + csv_path.parent.mkdir(parents=True, exist_ok=True) + fieldnames = [ + "inverse_type", + 
"metadata_strategy", + "dtype", + "B", + "T", + "aggregated_T", + "padded_T", + "H", + "numel", + "valid_numel", + "chunk_size", + "time_us", + "kernel_time_us", + "metadata_time_us", + "mem_bytes", + "bw_gbs", + "max_abs_diff_to_fixed", + "rel_frob_diff_to_fixed", + "sample_id", + "seq_lens", + ] + with csv_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow(row) + + +def plot_bandwidth(plot_path: Path, rows: list[dict[str, float | int | str]], batch_size: int, num_heads: int, chunk_size: int) -> None: + plot_path.parent.mkdir(parents=True, exist_ok=True) + fixed_rows = [row for row in rows if row["inverse_type"] == "bsnd-fixed"] + varlen_device_rows = [ + row + for row in rows + if row["inverse_type"] == "bsnd-varlen-uniform" + and row["metadata_strategy"] == "device-cu_seqlens" + ] + varlen_host_rows = [ + row + for row in rows + if row["inverse_type"] == "bsnd-varlen-uniform" + and row["metadata_strategy"] == "host-cpp" + ] + varlen_prefix_rows = [ + row + for row in rows + if row["inverse_type"] == "bsnd-varlen-uniform" + and row["metadata_strategy"] == "device-chunk-prefix" + ] + + fig, ax = plt.subplots(figsize=(7.5, 5.0)) + ax.plot( + [int(row["T"]) / 1000.0 for row in fixed_rows], + [float(row["bw_gbs"]) for row in fixed_rows], + marker="o", + linewidth=2, + label="BSND fixed", + ) + ax.plot( + [int(row["T"]) / 1000.0 for row in varlen_device_rows], + [float(row["bw_gbs"]) for row in varlen_device_rows], + marker="s", + linewidth=2, + label="BSND varlen device metadata", + ) + ax.plot( + [int(row["T"]) / 1000.0 for row in varlen_host_rows], + [float(row["bw_gbs"]) for row in varlen_host_rows], + marker="^", + linewidth=2, + label="BSND varlen host metadata", + ) + ax.plot( + [int(row["T"]) / 1000.0 for row in varlen_prefix_rows], + [float(row["bw_gbs"]) for row in varlen_prefix_rows], + marker="d", + linewidth=2, + label="BSND varlen 
prefix metadata", + ) + ax.set_xlabel("Sequence length T (K)") + ax.set_ylabel("Effective bandwidth (GB/s)") + ax.set_title( + f"Fast inverse BSND bandwidth\n" + f"(batch={batch_size}, head={num_heads}, chunk_size={chunk_size})" + ) + ax.set_ylim(bottom=0) + ax.grid(alpha=0.25) + ax.legend() + fig.tight_layout() + fig.savefig(plot_path, dpi=150) + plt.close(fig) + + +def plot_true_varlen_scatter( + plot_path: Path, + rows: list[dict[str, float | int | str]], + batch_size: int, + num_heads: int, + chunk_size: int, +) -> None: + plot_path.parent.mkdir(parents=True, exist_ok=True) + fig, ax = plt.subplots(figsize=(7.5, 5.0)) + ax.scatter( + [int(row["aggregated_T"]) for row in rows], + [float(row["bw_gbs"]) for row in rows], + alpha=0.8, + s=32, + ) + ax.set_xlabel("Aggregated sequence length") + ax.set_ylabel("Effective bandwidth (GB/s)") + ax.set_title( + f"Fast inverse true-varlen BSND bandwidth\n" + f"(batch={batch_size}, head={num_heads}, chunk_size={chunk_size})" + ) + ax.set_ylim(bottom=0) + ax.grid(alpha=0.25) + fig.tight_layout() + fig.savefig(plot_path, dpi=150) + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark standalone BSND fast-inverse kernel.") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--repeats", type=int, default=20) + parser.add_argument("--B", type=int, default=32, help="Dense BSND batch size.") + parser.add_argument("--H", type=int, default=4, help="Number of BSND heads.") + parser.add_argument("--chunk-size", type=int, default=64) + parser.add_argument( + "--feature-dim", + type=int, + default=DEFAULT_FEATURE_DIM, + help="Feature dimension used to generate Triton-like `k` inputs.", + ) + parser.add_argument( + "--seqlens", + type=parse_int_list, + default=DEFAULT_SEQLENS, + metavar="T[,T,...]", + help=( + "Comma-separated dense per-sequence lengths to benchmark " + f"(default: {','.join(map(str, DEFAULT_SEQLENS))})" + ), + ) + parser.add_argument( + "--csv", + 
type=str, + default="", + help="Optional CSV output path. Defaults to bench_results_bsnd_fast_inverse_.csv", + ) + parser.add_argument( + "--plot", + type=str, + default="", + help="Optional plot output path. Defaults to bench_results_bsnd_fast_inverse_bw_.png", + ) + parser.add_argument( + "--true-varlen-csv", + type=str, + default="", + help="Optional CSV path for true-varlen benchmark points.", + ) + parser.add_argument( + "--true-varlen-plot", + type=str, + default="", + help="Optional scatter plot path for true-varlen benchmark points.", + ) + parser.add_argument( + "--true-varlen-samples", + type=int, + default=DEFAULT_TRUE_VARLEN_SAMPLES, + help="Number of random true-varlen batches per aggregated sequence length.", + ) + args = parser.parse_args() + + torch.npu.set_device(NPU_DEVICE) + + src = THIS_DIR / "fast_inverse.cpp" + print(f"Compiling {src} ...") + tri_inv_func = jit_compile(str(src)) + print("Compilation successful.\n") + + RESULTS_DIR.mkdir(parents=True, exist_ok=True) + csv_path = ( + Path(args.csv) + if args.csv + else RESULTS_DIR / f"bench_results_bsnd_fast_inverse_{args.chunk_size}.csv" + ) + plot_path = ( + Path(args.plot) + if args.plot + else RESULTS_DIR / f"bench_results_bsnd_fast_inverse_bw_{args.chunk_size}.png" + ) + true_varlen_csv_path = ( + Path(args.true_varlen_csv) + if args.true_varlen_csv + else RESULTS_DIR / f"bench_results_bsnd_fast_inverse_true_varlen_{args.chunk_size}.csv" + ) + true_varlen_plot_path = ( + Path(args.true_varlen_plot) + if args.true_varlen_plot + else RESULTS_DIR / f"bench_results_bsnd_fast_inverse_true_varlen_bw_{args.chunk_size}.png" + ) + + rows: list[dict[str, float | int | str]] = [] + true_varlen_rows: list[dict[str, float | int | str]] = [] + rng = np.random.default_rng(42) + + for seqlen in args.seqlens: + if seqlen % args.chunk_size != 0: + print( + f"Skipping T={seqlen}: requires T to be a multiple of chunk_size={args.chunk_size} " + "for matched fixed vs uniform-varlen comparison." 
+ ) + continue + + total_tokens = args.B * seqlen + print( + f"Profiling T={seqlen}, total_tokens={total_tokens}, " + f"B={args.B}, H={args.H}, chunk_size={args.chunk_size}, feature_dim={args.feature_dim}" + ) + + fixed_input, uniform_cu_seqlens = build_fixed_bsnd_input( + batch_size=args.B, + seqlen=seqlen, + num_heads=args.H, + chunk_size=args.chunk_size, + feature_dim=args.feature_dim, + device=NPU_DEVICE, + ) + varlen_input, cu_seqlens = build_uniform_varlen_input( + fixed_input, + batch_size=args.B, + seqlen=seqlen, + chunk_size=args.chunk_size, + ) + cu_seqlens_cpu = cu_seqlens.cpu() + + print(f" uniform cu_seqlens: {cu_seqlens.cpu().tolist()}") + + fixed_run, fixed_out = make_fixed_runner(tri_inv_func, fixed_input) + varlen_run_device, varlen_out_device = make_varlen_runner( + tri_inv_func, + varlen_input, + cu_seqlens, + ) + chunk_sequence_prefix = build_prefix_metadata_on_npu( + cu_seqlens_cpu, + args.chunk_size, + NPU_DEVICE, + ) + varlen_run_prefix, varlen_out_prefix = make_varlen_runner_prefix_metadata( + tri_inv_func, + varlen_input, + cu_seqlens, + chunk_sequence_prefix, + ) + chunk_indices, chunk_valid_sizes = build_host_metadata_on_npu( + cu_seqlens_cpu, + args.chunk_size, + NPU_DEVICE, + ) + varlen_run_host, varlen_out_host = make_varlen_runner_host_metadata( + tri_inv_func, + varlen_input, + chunk_indices, + chunk_valid_sizes, + ) + + fixed_run() + varlen_run_device() + varlen_run_prefix() + varlen_run_host() + torch.npu.synchronize() + + packed_fixed_out = transpose_valid_chunks( + fixed_out.reshape(1, total_tokens, args.H, args.chunk_size), + uniform_cu_seqlens, + args.chunk_size, + ) + packed_varlen_out_device = transpose_valid_chunks( + varlen_out_device, + cu_seqlens, + args.chunk_size, + ) + packed_varlen_out_prefix = transpose_valid_chunks( + varlen_out_prefix, + cu_seqlens, + args.chunk_size, + ) + packed_varlen_out_host = transpose_valid_chunks( + varlen_out_host, + cu_seqlens, + args.chunk_size, + ) + max_abs_diff_device, 
rel_frob_diff_device = accuracy_metrics( + packed_fixed_out, + packed_varlen_out_device, + ) + max_abs_diff_host, rel_frob_diff_host = accuracy_metrics( + packed_fixed_out, + packed_varlen_out_host, + ) + max_abs_diff_prefix, rel_frob_diff_prefix = accuracy_metrics( + packed_fixed_out, + packed_varlen_out_prefix, + ) + print( + f" accuracy vs fixed: device max_abs_diff={max_abs_diff_device:.3e}, " + f"device rel_frob_diff={rel_frob_diff_device:.3e}, " + f"prefix max_abs_diff={max_abs_diff_prefix:.3e}, " + f"prefix rel_frob_diff={rel_frob_diff_prefix:.3e}, " + f"host max_abs_diff={max_abs_diff_host:.3e}, " + f"host rel_frob_diff={rel_frob_diff_host:.3e}" + ) + + fixed_times_ms = benchmark_ms( + fixed_run, + warmup_iters=args.warmup, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + varlen_device_times_ms = benchmark_ms( + varlen_run_device, + warmup_iters=args.warmup, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + prefix_metadata_times_ms = benchmark_prefix_metadata_prep_ms( + cu_seqlens_cpu, + args.chunk_size, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + varlen_prefix_kernel_times_ms = benchmark_ms( + varlen_run_prefix, + warmup_iters=args.warmup, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + host_metadata_times_ms = benchmark_host_metadata_prep_ms( + cu_seqlens_cpu, + args.chunk_size, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + varlen_host_kernel_times_ms = benchmark_ms( + varlen_run_host, + warmup_iters=args.warmup, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + + fixed_row = { + "inverse_type": "bsnd-fixed", + "metadata_strategy": "none", + "dtype": "fp16", + "B": args.B, + "T": seqlen, + "aggregated_T": total_tokens, + "padded_T": total_tokens, + "H": args.H, + "numel": fixed_input.numel(), + "valid_numel": fixed_input.numel(), + "chunk_size": args.chunk_size, + "time_us": int(round(np.mean(fixed_times_ms) * 1000.0)), + "kernel_time_us": int(round(np.mean(fixed_times_ms) * 1000.0)), + 
"metadata_time_us": 0, + "max_abs_diff_to_fixed": 0.0, + "rel_frob_diff_to_fixed": 0.0, + "sample_id": "", + "seq_lens": "", + } + add_bandwidth_fields(fixed_row) + + varlen_device_row = { + "inverse_type": "bsnd-varlen-uniform", + "metadata_strategy": "device-cu_seqlens", + "dtype": "fp16", + "B": args.B, + "T": seqlen, + "aggregated_T": total_tokens, + "padded_T": total_tokens, + "H": args.H, + "numel": varlen_input.numel(), + "valid_numel": total_tokens * args.H * args.chunk_size, + "chunk_size": args.chunk_size, + "time_us": int(round(np.mean(varlen_device_times_ms) * 1000.0)), + "kernel_time_us": int(round(np.mean(varlen_device_times_ms) * 1000.0)), + "metadata_time_us": 0, + "max_abs_diff_to_fixed": max_abs_diff_device, + "rel_frob_diff_to_fixed": rel_frob_diff_device, + "sample_id": "", + "seq_lens": ",".join([str(seqlen)] * args.B), + } + add_bandwidth_fields(varlen_device_row) + + avg_prefix_metadata_us = int(round(np.mean(prefix_metadata_times_ms) * 1000.0)) + avg_prefix_kernel_us = int(round(np.mean(varlen_prefix_kernel_times_ms) * 1000.0)) + varlen_prefix_row = { + "inverse_type": "bsnd-varlen-uniform", + "metadata_strategy": "device-chunk-prefix", + "dtype": "fp16", + "B": args.B, + "T": seqlen, + "aggregated_T": total_tokens, + "padded_T": total_tokens, + "H": args.H, + "numel": varlen_input.numel(), + "valid_numel": total_tokens * args.H * args.chunk_size, + "chunk_size": args.chunk_size, + "time_us": avg_prefix_metadata_us + avg_prefix_kernel_us, + "kernel_time_us": avg_prefix_kernel_us, + "metadata_time_us": avg_prefix_metadata_us, + "max_abs_diff_to_fixed": max_abs_diff_prefix, + "rel_frob_diff_to_fixed": rel_frob_diff_prefix, + "sample_id": "", + "seq_lens": ",".join([str(seqlen)] * args.B), + } + add_bandwidth_fields(varlen_prefix_row) + + avg_host_metadata_us = int(round(np.mean(host_metadata_times_ms) * 1000.0)) + avg_host_kernel_us = int(round(np.mean(varlen_host_kernel_times_ms) * 1000.0)) + varlen_host_row = { + "inverse_type": 
"bsnd-varlen-uniform", + "metadata_strategy": "host-cpp", + "dtype": "fp16", + "B": args.B, + "T": seqlen, + "aggregated_T": total_tokens, + "padded_T": total_tokens, + "H": args.H, + "numel": varlen_input.numel(), + "valid_numel": total_tokens * args.H * args.chunk_size, + "chunk_size": args.chunk_size, + "time_us": avg_host_metadata_us + avg_host_kernel_us, + "kernel_time_us": avg_host_kernel_us, + "metadata_time_us": avg_host_metadata_us, + "max_abs_diff_to_fixed": max_abs_diff_host, + "rel_frob_diff_to_fixed": rel_frob_diff_host, + "sample_id": "", + "seq_lens": ",".join([str(seqlen)] * args.B), + } + add_bandwidth_fields(varlen_host_row) + + rows.extend([fixed_row, varlen_device_row, varlen_prefix_row, varlen_host_row]) + print( + f" fixed: time_us={fixed_row['time_us']}, bw_gbs={fixed_row['bw_gbs']:.2f} | " + f"varlen-device: time_us={varlen_device_row['time_us']}, " + f"bw_gbs={varlen_device_row['bw_gbs']:.2f} | " + f"varlen-prefix: time_us={varlen_prefix_row['time_us']} " + f"(meta={varlen_prefix_row['metadata_time_us']}, kernel={varlen_prefix_row['kernel_time_us']}), " + f"bw_gbs={varlen_prefix_row['bw_gbs']:.2f} | " + f"varlen-host: time_us={varlen_host_row['time_us']} " + f"(meta={varlen_host_row['metadata_time_us']}, kernel={varlen_host_row['kernel_time_us']}), " + f"bw_gbs={varlen_host_row['bw_gbs']:.2f}" + ) + device_metadata_overhead_us = ( + varlen_device_row["kernel_time_us"] - varlen_host_row["kernel_time_us"] + ) + prefix_metadata_overhead_us = ( + varlen_device_row["kernel_time_us"] - varlen_prefix_row["kernel_time_us"] + ) + print( + f" metadata overhead comparison: device_vs_host_kernel_delta_us={device_metadata_overhead_us}, " + f"device_vs_prefix_kernel_delta_us={prefix_metadata_overhead_us}, " + f"prefix_metadata_us={varlen_prefix_row['metadata_time_us']}, " + f"host_cpp_metadata_us={varlen_host_row['metadata_time_us']}" + ) + + for sample_idx in range(args.true_varlen_samples): + seq_lens = sample_true_varlen_lengths(args.B, total_tokens, 
rng) + packed_input, cu_seqlens = build_true_varlen_input( + seq_lens=seq_lens, + num_heads=args.H, + chunk_size=args.chunk_size, + feature_dim=args.feature_dim, + device=NPU_DEVICE, + ) + varlen_run_true, _ = make_varlen_runner( + tri_inv_func, + packed_input, + cu_seqlens, + ) + times_ms = benchmark_ms( + varlen_run_true, + warmup_iters=args.warmup, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + row = { + "inverse_type": "bsnd-varlen-true", + "metadata_strategy": "device-cu_seqlens", + "dtype": "fp16", + "B": args.B, + "T": seqlen, + "aggregated_T": total_tokens, + "padded_T": int(packed_input.shape[1]), + "H": args.H, + "numel": packed_input.numel(), + "valid_numel": total_tokens * args.H * args.chunk_size, + "chunk_size": args.chunk_size, + "time_us": int(round(np.mean(times_ms) * 1000.0)), + "kernel_time_us": int(round(np.mean(times_ms) * 1000.0)), + "metadata_time_us": 0, + "max_abs_diff_to_fixed": "", + "rel_frob_diff_to_fixed": "", + "sample_id": sample_idx, + "seq_lens": ",".join(map(str, seq_lens)), + } + add_bandwidth_fields(row) + true_varlen_rows.append(row) + print( + f" true-varlen sample={sample_idx}: aggregated_T={total_tokens}, " + f"padded_T={row['padded_T']}, bw_gbs={row['bw_gbs']:.2f}" + ) + + if not rows: + raise RuntimeError("No benchmark rows were generated.") + + write_csv(csv_path, rows) + plot_bandwidth( + plot_path, + rows, + batch_size=args.B, + num_heads=args.H, + chunk_size=args.chunk_size, + ) + write_csv(true_varlen_csv_path, true_varlen_rows) + plot_true_varlen_scatter( + true_varlen_plot_path, + true_varlen_rows, + batch_size=args.B, + num_heads=args.H, + chunk_size=args.chunk_size, + ) + print(f"\nWrote CSV: {csv_path}") + print(f"Wrote plot: {plot_path}") + print(f"Wrote true-varlen CSV: {true_varlen_csv_path}") + print(f"Wrote true-varlen plot: {true_varlen_plot_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/fast_inverse/fast_inverse.cpp 
b/examples/jit_cpp/fast_inverse/fast_inverse.cpp new file mode 100644 index 00000000..42ba6cce --- /dev/null +++ b/examples/jit_cpp/fast_inverse/fast_inverse.cpp @@ -0,0 +1,48 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +// Include the triangular inverse kernel implementation. +// The build script adds csrc/kernel/ to the include path so that +// kernel_utils.h (included by kernel_tri_inv_rec_unroll.cpp) is found. +#include "kernel_tri_inv_rec_unroll.cpp" + +/** + * @brief JIT entry point for the triangular inverse (recursive unroll) kernel. + * + * @param blockDim Number of AI-Core blocks to launch. + * @param stream NPU stream handle. + * @param tensor_out fp32 output buffer (same element count as tensor_in). + * @param tensor_in fp16 input buffer holding the upper-triangular matrices + * (diagonal is assumed to be all-ones). + * @param minus_identity_in fp16 buffer of size matrix_size×matrix_size + * pre-filled with -I (negative identity). + * @param matrix_size Side length of each square matrix (16 / 32 / 64 / 128). + * @param num_matrices Total number of matrices to invert. + * @param num_bsnd_heads 0 for standard (B…ND) layout; + * N (number of heads) for BSND layout. + * @param cu_seqlens Optional int32 pointer used only for varlen BSND when the + * device kernel derives chunk metadata itself. + * @param chunk_sequence_prefix Optional int32 pointer containing a compact + * per-sequence cumulative chunk-count prefix. + * @param chunk_indices Optional int32 pointer containing per-chunk row starts + * for the host-precomputed varlen path. + * @param chunk_valid_sizes Optional int32 pointer containing each chunk's + * runtime size for the host-precomputed varlen path. 
+ */ +extern "C" void call_kernel(uint32_t blockDim, void* stream, void* tensor_out, + void* tensor_in, void* minus_identity_in, + uint32_t matrix_size, uint32_t num_matrices, + uint32_t num_bsnd_heads, void* cu_seqlens, + void* chunk_sequence_prefix, + void* chunk_indices, void* chunk_valid_sizes) { + tri_inv_rec_unroll_fp16<<>>( + tensor_out, tensor_in, minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, cu_seqlens, chunk_sequence_prefix, chunk_indices, + chunk_valid_sizes); +} diff --git a/examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp b/examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp new file mode 100644 index 00000000..bfef9b29 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp @@ -0,0 +1,58 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include +#include + +extern "C" uint32_t count_varlen_chunks_host_cpp(const int32_t* cu_seqlens, + uint32_t num_sequences, + uint32_t chunk_size) { + uint32_t total_chunks = 0; + for (uint32_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { + const uint32_t seq_start = static_cast(cu_seqlens[seq_idx]); + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + const uint32_t seq_len = seq_end - seq_start; + total_chunks += (seq_len + chunk_size - 1) / chunk_size; + } + return total_chunks; +} + +extern "C" void build_varlen_chunk_metadata_host_cpp( + const int32_t* cu_seqlens, uint32_t num_sequences, uint32_t chunk_size, + int32_t* chunk_indices, int32_t* chunk_valid_sizes) { + uint32_t chunk_idx = 0; + for (uint32_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { + const uint32_t seq_start = static_cast(cu_seqlens[seq_idx]); + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + for (uint32_t row_start = seq_start; row_start < seq_end; + row_start += chunk_size) { + const uint32_t 
valid_size = + std::min(chunk_size, static_cast(seq_end - row_start)); + chunk_indices[chunk_idx] = static_cast(row_start); + chunk_valid_sizes[chunk_idx] = static_cast(valid_size); + ++chunk_idx; + } + } +} + +extern "C" void build_chunk_sequence_prefix_host_cpp( + const int32_t* cu_seqlens, uint32_t num_sequences, uint32_t chunk_size, + int32_t* chunk_sequence_prefix) { + chunk_sequence_prefix[0] = static_cast(num_sequences); + chunk_sequence_prefix[1] = 0; + + uint32_t total_chunks = 0; + for (uint32_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { + const uint32_t seq_start = static_cast(cu_seqlens[seq_idx]); + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + const uint32_t seq_len = seq_end - seq_start; + total_chunks += (seq_len + chunk_size - 1) / chunk_size; + chunk_sequence_prefix[seq_idx + 2] = static_cast(total_chunks); + } +} diff --git a/examples/jit_cpp/fast_inverse/host_metadata_util.py b/examples/jit_cpp/fast_inverse/host_metadata_util.py new file mode 100644 index 00000000..b0d13383 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/host_metadata_util.py @@ -0,0 +1,133 @@ +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# All rights reserved. +# See LICENSE in the root of the software repository: +# https://github.com/huawei-csl/pto-kernels/ +# for the full License text. 
# --------------------------------------------------------------------------------

from __future__ import annotations

import ctypes
import operator
import os
import subprocess

import torch


_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_SRC = os.path.join(_THIS_DIR, "host_chunk_metadata.cpp")
_LIB = os.path.join(_THIS_DIR, "host_chunk_metadata.so")
_HOST_LIB = None

# Largest value representable by the uint32_t ``chunk_size`` parameter of the
# native helpers; ctypes would silently wrap anything outside this range.
_UINT32_MAX = 0xFFFFFFFF


def _torch_to_ctypes(tensor: torch.Tensor) -> ctypes.c_void_p:
    """Return the tensor's data pointer wrapped for a ctypes call.

    The tensor must stay alive for as long as the pointer is in use.
    """
    return ctypes.c_void_p(tensor.data_ptr())


def compile_host_metadata_cpp(timeout: int = 60) -> str:
    """Compile ``host_chunk_metadata.cpp`` into a shared library.

    The compiler is taken from the ``CXX`` environment variable (default
    ``g++``).

    Args:
        timeout: Seconds to wait for the compiler before giving up.

    Returns:
        Path of the generated shared library.

    Raises:
        RuntimeError: If the compiler is missing, fails, or times out.
    """
    compiler = os.environ.get("CXX", "g++")
    command = [
        compiler,
        "-O3",
        "-std=c++17",
        "-shared",
        "-fPIC",
        _SRC,
        "-o",
        _LIB,
    ]
    try:
        subprocess.run(command, timeout=timeout, check=True)
    except (subprocess.SubprocessError, OSError) as exc:
        # SubprocessError: non-zero exit or timeout; OSError: compiler not found.
        raise RuntimeError(f"Host metadata compilation failed: {exc}") from exc
    return _LIB


def load_host_metadata_lib():
    """Compile (once per process) and load the host metadata library.

    Returns:
        The ``ctypes.CDLL`` handle with argument/return types declared for the
        three exported functions. The handle is cached in ``_HOST_LIB``.
    """
    global _HOST_LIB
    if _HOST_LIB is not None:
        return _HOST_LIB

    lib_path = compile_host_metadata_cpp()
    lib = ctypes.CDLL(os.path.abspath(lib_path))
    lib.count_varlen_chunks_host_cpp.argtypes = [
        ctypes.c_void_p,
        ctypes.c_uint32,
        ctypes.c_uint32,
    ]
    lib.count_varlen_chunks_host_cpp.restype = ctypes.c_uint32
    lib.build_varlen_chunk_metadata_host_cpp.argtypes = [
        ctypes.c_void_p,
        ctypes.c_uint32,
        ctypes.c_uint32,
        ctypes.c_void_p,
        ctypes.c_void_p,
    ]
    lib.build_varlen_chunk_metadata_host_cpp.restype = None
    lib.build_chunk_sequence_prefix_host_cpp.argtypes = [
        ctypes.c_void_p,
        ctypes.c_uint32,
        ctypes.c_uint32,
        ctypes.c_void_p,
    ]
    lib.build_chunk_sequence_prefix_host_cpp.restype = None
    _HOST_LIB = lib
    return lib


def _validated_chunk_size(chunk_size: int) -> int:
    """Return ``chunk_size`` as an int after checking it fits a positive uint32.

    The native code divides by ``chunk_size`` and advances its loops by it, so
    zero would crash or hang, and a negative value would wrap to a huge
    unsigned number when converted by ctypes.
    """
    size = operator.index(chunk_size)
    if not 0 < size <= _UINT32_MAX:
        raise ValueError(
            f"chunk_size must be in [1, {_UINT32_MAX}], got {chunk_size}."
        )
    return size


def _host_cu_seqlens(cu_seqlens: torch.Tensor | list[int]) -> torch.Tensor:
    """Return ``cu_seqlens`` as a validated flat, contiguous CPU int32 tensor.

    Raises:
        ValueError: If fewer than two offsets are given, or the offsets are
            negative or decreasing (the native code sizes and fills buffers
            from the differences between consecutive offsets).
    """
    if isinstance(cu_seqlens, torch.Tensor):
        host = cu_seqlens.detach().to(device="cpu", dtype=torch.int32).contiguous()
    else:
        host = torch.tensor(cu_seqlens, dtype=torch.int32)
    host = host.reshape(-1)

    if host.numel() < 2:
        raise ValueError("cu_seqlens must contain at least 2 entries.")
    if int(host[0]) < 0 or bool((host[1:] < host[:-1]).any()):
        raise ValueError("cu_seqlens must be non-negative and non-decreasing.")
    return host


def build_varlen_chunk_metadata_cpp(
    cu_seqlens: torch.Tensor | list[int],
    chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute per-chunk row starts and valid sizes on the host.

    Args:
        cu_seqlens: Cumulative sequence offsets (tensor on any device, or a
            list of ints).
        chunk_size: Maximum number of rows per chunk; must be positive.

    Returns:
        ``(chunk_indices, chunk_valid_sizes)``: two CPU int32 tensors with one
        entry per chunk.

    Raises:
        ValueError: If ``chunk_size`` or ``cu_seqlens`` is invalid.
        RuntimeError: If the native helper cannot be compiled.
    """
    # Validate before compiling/loading so bad arguments fail fast and never
    # reach native code.
    chunk_size = _validated_chunk_size(chunk_size)
    cu_seqlens_cpu = _host_cu_seqlens(cu_seqlens)
    lib = load_host_metadata_lib()

    num_sequences = cu_seqlens_cpu.numel() - 1
    num_chunks = int(
        lib.count_varlen_chunks_host_cpp(
            _torch_to_ctypes(cu_seqlens_cpu),
            num_sequences,
            chunk_size,
        )
    )
    chunk_indices = torch.empty(num_chunks, dtype=torch.int32)
    chunk_valid_sizes = torch.empty(num_chunks, dtype=torch.int32)
    lib.build_varlen_chunk_metadata_host_cpp(
        _torch_to_ctypes(cu_seqlens_cpu),
        num_sequences,
        chunk_size,
        _torch_to_ctypes(chunk_indices),
        _torch_to_ctypes(chunk_valid_sizes),
    )
    return chunk_indices, chunk_valid_sizes


def build_chunk_sequence_prefix_cpp(
    cu_seqlens: torch.Tensor | list[int],
    chunk_size: int,
) -> torch.Tensor:
    """Compute the compact per-sequence cumulative chunk-count prefix.

    Args:
        cu_seqlens: Cumulative sequence offsets (tensor on any device, or a
            list of ints).
        chunk_size: Maximum number of rows per chunk; must be positive.

    Returns:
        A CPU int32 tensor with ``num_sequences + 2`` entries, filled by the
        native ``build_chunk_sequence_prefix_host_cpp`` helper.

    Raises:
        ValueError: If ``chunk_size`` or ``cu_seqlens`` is invalid.
        RuntimeError: If the native helper cannot be compiled.
    """
    chunk_size = _validated_chunk_size(chunk_size)
    cu_seqlens_cpu = _host_cu_seqlens(cu_seqlens)
    lib = load_host_metadata_lib()

    num_sequences = cu_seqlens_cpu.numel() - 1
    chunk_sequence_prefix = torch.empty(num_sequences + 2, dtype=torch.int32)
    lib.build_chunk_sequence_prefix_host_cpp(
        _torch_to_ctypes(cu_seqlens_cpu),
        num_sequences,
        chunk_size,
        _torch_to_ctypes(chunk_sequence_prefix),
    )
    return chunk_sequence_prefix


# --- patch metadata (kept verbatim from the enclosing diff) ---
# diff --git a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py
# new file mode 100644
# index 00000000..a8c86c7d
# --- /dev/null
# +++ b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py
# @@ -0,0 +1,170 @@
# --------------------------------------------------------------------------------
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# All rights reserved.
# See LICENSE in the root of the software repository:
# https://github.com/huawei-csl/pto-kernels/
# for the full License text.
# --------------------------------------------------------------------------------

import ctypes
import os
import subprocess

import torch

# ---------------------------------------------------------------------------
# Environment / paths
# ---------------------------------------------------------------------------
# An explicit PTO_LIB_PATH wins. ASCEND_TOOLKIT_HOME is consulted -- and
# therefore required -- only when PTO_LIB_PATH is absent; looking it up
# unconditionally would raise KeyError even for users who set PTO_LIB_PATH.
PTO_LIB_PATH = (
    os.environ["PTO_LIB_PATH"]
    if "PTO_LIB_PATH" in os.environ
    else os.environ["ASCEND_TOOLKIT_HOME"]
)

# Directory of this file → repo-root/examples/jit_cpp/fast_inverse
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
# csrc/kernel lives three levels up from this file
_CSRC_KERNEL_DIR = os.path.abspath(os.path.join(_THIS_DIR, "../../../csrc/kernel"))

# Default launch width: the device's reported cube-core count, or 20 when the
# device properties do not expose "cube_core_num".
BLOCK_DIM = int(getattr(torch.npu.get_device_properties("npu:0"), "cube_core_num", 20))


# ---------------------------------------------------------------------------
# Compilation
# ---------------------------------------------------------------------------

def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 180) -> str:
    """Compile *kernel_cpp* with bisheng and return the path to the .so.

    The library is always written as ``fast_inverse_jit.so`` next to
    *kernel_cpp*.

    Raises:
        RuntimeError: If bisheng is missing, fails, or exceeds *timeout* seconds.
    """
    lib_path = os.path.join(os.path.dirname(kernel_cpp), "fast_inverse_jit.so")

    flags = [
        "-fPIC",
        "-shared",
        "-xcce",
        "-DMEMORY_BASE",
        "-O2",
        "-std=c++17",
        # Resolve kernel_utils.h (included by kernel_tri_inv_rec_unroll.cpp)
        f"-I{_CSRC_KERNEL_DIR}",
        # PTO-ISA headers
        f"-I{PTO_LIB_PATH}/include",
        # Target the Ascend 910B cube core
        "--cce-soc-version=Ascend910B4",
        "--cce-soc-core-type=CubeCore",
    ]

    command = ["bisheng", *flags, kernel_cpp, "-o", lib_path]
    if verbose:
        print("Compiling fast_inverse kernel:")
        print("  ", " ".join(command))

    try:
        subprocess.run(command, timeout=timeout, check=True)
    except (subprocess.SubprocessError, OSError) as exc:
        # SubprocessError: non-zero exit or timeout; OSError: bisheng not found.
        raise RuntimeError(f"Compilation failed: {exc}") from exc

    if verbose:
        print(f"Generated: {lib_path}")
    return lib_path


# ---------------------------------------------------------------------------
# Loading
# ---------------------------------------------------------------------------

def _torch_to_ctypes(tensor: torch.Tensor) -> ctypes.c_void_p:
    """Return the tensor's data pointer wrapped for a ctypes call."""
    return ctypes.c_void_p(tensor.data_ptr())


def _optional_ptr(tensor: "torch.Tensor | None") -> ctypes.c_void_p:
    """Return the tensor's data pointer, or a NULL pointer when it is None."""
    return ctypes.c_void_p() if tensor is None else _torch_to_ctypes(tensor)


def _check_int32_metadata(name: str, tensor: "torch.Tensor | None") -> None:
    """Validate an optional metadata tensor that the kernel reads as int32.

    Raises:
        TypeError: If the tensor is not int32.
        ValueError: If the tensor is not contiguous.
    """
    if tensor is None:
        return
    if tensor.dtype != torch.int32:
        raise TypeError(f"{name} must be int32.")
    if not tensor.is_contiguous():
        raise ValueError(f"{name} must be contiguous.")


def load_lib(lib_path: str):
    """Load the compiled .so and return a Python callable for the kernel."""
    lib = ctypes.CDLL(os.path.abspath(lib_path))

    lib.call_kernel.argtypes = [
        ctypes.c_uint32,  # blockDim
        ctypes.c_void_p,  # stream
        ctypes.c_void_p,  # tensor_out (fp32)
        ctypes.c_void_p,  # tensor_in  (fp16)
        ctypes.c_void_p,  # minus_identity_in (fp16)
        ctypes.c_uint32,  # matrix_size
        ctypes.c_uint32,  # num_matrices
        ctypes.c_uint32,  # num_bsnd_heads
        ctypes.c_void_p,  # cu_seqlens (optional int32 metadata)
        ctypes.c_void_p,  # chunk_sequence_prefix (optional int32 metadata)
        ctypes.c_void_p,  # chunk_indices (optional int32 metadata)
        ctypes.c_void_p,  # chunk_valid_sizes (optional int32 metadata)
    ]
    lib.call_kernel.restype = None

    def tri_inv_func(
        tensor_out: torch.Tensor,
        tensor_in: torch.Tensor,
        minus_identity: torch.Tensor,
        matrix_size: int,
        num_matrices: int,
        num_bsnd_heads: int = 0,
        cu_seqlens: torch.Tensor | None = None,
        chunk_sequence_prefix: torch.Tensor | None = None,
        chunk_indices: torch.Tensor | None = None,
        chunk_valid_sizes: torch.Tensor | None = None,
        block_dim: int = BLOCK_DIM,
        stream_ptr=None,
    ):
        """Launch the triangular-inverse kernel on the given buffers.

        Optional metadata tensors are passed as NULL when omitted.
        ``chunk_indices`` and ``chunk_valid_sizes`` must be given together.
        The launch width is ``min(block_dim, num_matrices)``; nothing is
        launched when ``num_matrices`` is not positive.

        Raises:
            TypeError: If a metadata tensor is not int32.
            ValueError: If a metadata tensor is not contiguous, or only one of
                ``chunk_indices`` / ``chunk_valid_sizes`` is provided.
        """
        _check_int32_metadata("cu_seqlens", cu_seqlens)
        _check_int32_metadata("chunk_sequence_prefix", chunk_sequence_prefix)
        _check_int32_metadata("chunk_indices", chunk_indices)
        _check_int32_metadata("chunk_valid_sizes", chunk_valid_sizes)
        if (chunk_indices is None) != (chunk_valid_sizes is None):
            raise ValueError("chunk_indices and chunk_valid_sizes must be provided together.")

        # An empty batch would otherwise request a launch with zero blocks.
        if num_matrices <= 0:
            return

        if stream_ptr is None:
            stream_ptr = torch.npu.current_stream()._as_parameter_  # noqa

        # Never launch more blocks than there are matrices to process.
        effective_block_dim = min(block_dim, num_matrices)
        lib.call_kernel(
            effective_block_dim,
            stream_ptr,
            _torch_to_ctypes(tensor_out),
            _torch_to_ctypes(tensor_in),
            _torch_to_ctypes(minus_identity),
            matrix_size,
            num_matrices,
            num_bsnd_heads,
            _optional_ptr(cu_seqlens),
            _optional_ptr(chunk_sequence_prefix),
            _optional_ptr(chunk_indices),
            _optional_ptr(chunk_valid_sizes),
        )

    return tri_inv_func


# ---------------------------------------------------------------------------
# Convenience: compile + load in one call
# ---------------------------------------------------------------------------

def jit_compile(src_path: str, verbose: bool = True, clean_up: bool = False):
    """Compile *src_path* and return the kernel callable.

    With ``clean_up=True`` the generated .so file is removed from disk once it
    has been loaded.
    """
    lib_path = compile_cpp(src_path, verbose=verbose)
    func = load_lib(lib_path)
    if clean_up:
        os.remove(lib_path)
    return func


# --- patch metadata (kept verbatim from the enclosing diff) ---
# diff --git a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp
# new file mode
100644 index 00000000..3d093648 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp @@ -0,0 +1,976 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#ifndef MEMORY_BASE +#define MEMORY_BASE +#endif +#include + +#include "kernel_utils.h" + +#define GM_ADDR __gm__ uint8_t* // To avoid #include "kernel_operator.h" +using namespace pto; +using namespace kernel_utils; + +#define BSND_OFFSET(tile_id, N, S, D) \ + (((tile_id) / (N)) * (S) * (N) * (D) + ((tile_id) % (N)) * (D)) + +/* + * For aligned BSND, tile_id enumerates chunk-major then head-major and maps to + * a fixed-stride address inside the dense BSND tensor. + */ +AICORE inline uint32_t GetBSNDFixedTileOffset(uint32_t tile_id, + uint32_t num_bsnd_heads, + uint32_t matrix_size) { + return BSND_OFFSET(tile_id, num_bsnd_heads, matrix_size, matrix_size); +} + +struct BSNDVarlenTileInfo { + uint32_t bsnd_offset; + uint32_t valid_size; +}; + +/* + * For cu_seqlens-based varlen BSND, tile_id still enumerates chunk-major then + * head-major. We recover the owning sequence by scanning cu_seqlens and + * counting chunks per sequence. 
+ */ +AICORE inline BSNDVarlenTileInfo GetBSNDVarlenTileInfoFromCuSeqlens( + uint32_t tile_id, uint32_t num_bsnd_heads, uint32_t matrix_size, + __gm__ int32_t* cu_seqlens) { + const uint32_t head_idx = tile_id % num_bsnd_heads; + const uint32_t chunk_idx = tile_id / num_bsnd_heads; + + uint32_t seq_start = static_cast(cu_seqlens[0]); + uint32_t accumulated_chunks = 0; + for (uint32_t seq_idx = 0;; ++seq_idx) { + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + const uint32_t seq_len = seq_end - seq_start; + const uint32_t seq_num_chunks = CeilDiv(seq_len, matrix_size); + if (chunk_idx < accumulated_chunks + seq_num_chunks) { + const uint32_t local_chunk_idx = chunk_idx - accumulated_chunks; + const uint32_t row_start = seq_start + local_chunk_idx * matrix_size; + const uint32_t valid_size = + min(static_cast(seq_end - row_start), matrix_size); + return {row_start * num_bsnd_heads * matrix_size + head_idx * matrix_size, + valid_size}; + } + accumulated_chunks += seq_num_chunks; + seq_start = seq_end; + } +} + +AICORE inline BSNDVarlenTileInfo GetBSNDVarlenTileInfoFromChunkMetadata( + uint32_t tile_id, uint32_t num_bsnd_heads, uint32_t matrix_size, + __gm__ int32_t* chunk_indices, __gm__ int32_t* chunk_valid_sizes) { + const uint32_t head_idx = tile_id % num_bsnd_heads; + const uint32_t chunk_idx = tile_id / num_bsnd_heads; + const uint32_t row_start = static_cast(chunk_indices[chunk_idx]); + const uint32_t valid_size = + static_cast(chunk_valid_sizes[chunk_idx]); + return {row_start * num_bsnd_heads * matrix_size + head_idx * matrix_size, + valid_size}; +} + +AICORE inline BSNDVarlenTileInfo GetBSNDVarlenTileInfoFromChunkPrefix( + uint32_t tile_id, uint32_t num_bsnd_heads, uint32_t matrix_size, + __gm__ int32_t* cu_seqlens, __gm__ int32_t* chunk_sequence_prefix) { + const uint32_t head_idx = tile_id % num_bsnd_heads; + const uint32_t chunk_idx = tile_id / num_bsnd_heads; + const uint32_t num_sequences = + static_cast(chunk_sequence_prefix[0]); + + 
uint32_t left = 0; + uint32_t right = num_sequences; + while (left < right) { + const uint32_t mid = (left + right) / 2; + const uint32_t chunk_end = + static_cast(chunk_sequence_prefix[mid + 2]); + if (chunk_idx < chunk_end) { + right = mid; + } else { + left = mid + 1; + } + } + + const uint32_t seq_idx = left; + const uint32_t chunk_base = + static_cast(chunk_sequence_prefix[seq_idx + 1]); + const uint32_t local_chunk_idx = chunk_idx - chunk_base; + const uint32_t seq_start = static_cast(cu_seqlens[seq_idx]); + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + const uint32_t row_start = seq_start + local_chunk_idx * matrix_size; + const uint32_t valid_size = + min(static_cast(seq_end - row_start), matrix_size); + return {row_start * num_bsnd_heads * matrix_size + head_idx * matrix_size, + valid_size}; +} + +/* + * @brief: Takes as input two matrices of size MatrixSize * MatrixSize each. + * The src matrix lies in L1, while the dst matrix lies either in L0A or L0B. + * This kernel copies only the diagonal blocks (fractals) of size FractalSize * + * FractalSize from the src matrix to the dst matrix. + * + * @tparam InputT Input data type (fp16). + * @tparam FractalSize Size of each fractal matrix (diagonal block). + * @tparam MatrixSize Size of the entire input/output matrices. + * @tparam SrcL1TileT The actual tile type of the src matrix. + * @tparam DstL0TileT The actual tile type of the dst matrix. + * + * @param src Tile in L1 memory. + * @param dst Tile in L0A or L0B memory. + */ +template +AICORE inline void CopyDiagonalFractalsL1ToL0(SrcL1TileT src, DstL0TileT dst) { + constexpr uint32_t NumFractals = MatrixSize / FractalSize; + constexpr bool is_left = + std::is_same_v>; + constexpr TileType LeftOrRight = is_left ? TileType::Left : TileType::Right; + constexpr SLayout InnerLayout = + is_left ? 
SLayout::RowMajor : SLayout::ColMajor; + + Tile + fractals[NumFractals]; + const std::uintptr_t starting_address = + reinterpret_cast(dst.data()); + for (uint32_t i = 0; i < NumFractals; ++i) { + TASSIGN(fractals[i], starting_address + i * FractalSize * + (MatrixSize + FractalSize) * + sizeof(InputT)); + TEXTRACT(fractals[i], src, i * FractalSize, i * FractalSize); + } +} + +/* + * @brief: Takes as input two matrices of size MatrixSize * MatrixSize each, + * and an integer block_size. The src matrix lies in L1, while the dst matrix + * either in L0A or L0B. This method copies some of the diagonal blocks from the + * input to the output as follows: + * - If dst is in L0A (left): copy even diagonal blocks 0, 2, 4, ... + * - If dst is in L0B (right): copy odd blocks 1, 3, 5, ... + * Important note: the dst matrix should be initialized to all-zeros before + * calling this method + * + * @tparam InputT Input data type (fp16). + * @tparam FractalSize Size of each fractal matrix (diagonal block). + * @tparam MatrixSize Size of the entire input/output matrices. + * @tparam SrcL1TileT The actual tile type of the src matrix. + * @tparam DstL0TileT The actual tile type of the dst matrix. + * + * @param src Tile in L1 memory. + * @param dst Tile in L0A or L0B memory. + * @param block_size Size of diagonal blocks. Needs: block_size >= FractalSize. + */ +template +AICORE inline void CopyOddOrEvenBlocksL1ToL0(SrcL1TileT src, DstL0TileT dst, + uint32_t block_size) { + constexpr bool is_left = + std::is_same_v>; + constexpr TileType LeftOrRight = is_left ? TileType::Left : TileType::Right; + constexpr SLayout InnerLayout = + is_left ? SLayout::RowMajor : SLayout::ColMajor; + + const uint32_t starting_block_index = is_left ? 
0 : 1; + + const uint32_t num_blocks = MatrixSize / block_size; + const uint32_t num_fractals_per_block = block_size / FractalSize; + + Tile + fractals[MatrixSize / FractalSize]; + + const std::uintptr_t starting_address = + reinterpret_cast(dst.data()); + for (uint32_t i = 0; i < num_fractals_per_block; ++i) { + for (uint32_t j = 0; j < num_fractals_per_block; ++j) { + for (uint32_t b = starting_block_index; b < num_blocks; b += 2) { + const uint32_t offset = + b * (MatrixSize + FractalSize) * block_size + + i * MatrixSize * FractalSize + + j * FractalSize * FractalSize; + TASSIGN(fractals[b], starting_address + offset * sizeof(InputT)); + TEXTRACT(fractals[b], src, b * block_size + i * FractalSize, + b * block_size + j * FractalSize); + } + } + } +} + +/* + * @brief: Prepares Identity and Zeros matrix. + * + * @tparam TileL1AB The type of the input tiles in L1. + * @tparam TileL0A The type of the input tiles in L0A. + * @tparam TileL0B The type of the input tiles in L0B. + * @tparam TileL0C The type of the input tiles in L0C. + * + * @param I_neg_l1_tile Tile containing the -I (negative identity) matrix. + * @param Zero_l1_tile Tile to store the all-zero matrix. + * @param I_l1_tile Tile to store the identity matrix. + * @param a_l0_tile Tile in L0A for matmuls. + * @param b_l0_tile Tile in L0B for matmuls. + * @param c_l0_tile Tile in L0C for matmuls. 
+ */ +template +AICORE inline void PrepareAuxiliaryMatrices( + TileL1AB I_neg_l1_tile, TileL1AB Zero_l1_tile, TileL1AB I_l1_tile, + TileL0A a_l0_tile, TileL0B b_l0_tile, TileL0C c_l0_tile) { + TMOV(a_l0_tile, I_neg_l1_tile); + TMOV(b_l0_tile, I_neg_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + wait_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + + TMATMUL(c_l0_tile, a_l0_tile, b_l0_tile); + set_flag(PIPE_M, PIPE_FIX, static_cast(0)); + wait_flag(PIPE_M, PIPE_FIX, static_cast(0)); + + TMOV(I_l1_tile, c_l0_tile); + set_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + + TMOV(b_l0_tile, I_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + wait_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + + TMATMUL_ACC(c_l0_tile, c_l0_tile, a_l0_tile, b_l0_tile); + set_flag(PIPE_M, PIPE_FIX, static_cast(0)); + wait_flag(PIPE_M, PIPE_FIX, static_cast(0)); + + TMOV(Zero_l1_tile, c_l0_tile); + set_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); +} + +/* + * @brief: Inverts a single matrix / tile of the global tensor. 
+ */ +template +AICORE inline void InvertSingleTile(TileL1AB X_l1_tile, TileL1AB I_l1_tile, + TileL1AB I_neg_l1_tile, + TileL1AB M_neg_l1_tile, + TileL1AB Zero_l1_tile, TileL1AB Y_l1_tile, + TileL0A* a_l0_tile, TileL0B* b_l0_tile, + TileL0C* c_l0_tile, + const uint32_t tile_id) { + const event_t event_0 = static_cast(tile_id); + const event_t event_1 = static_cast(tile_id + NumTilesPerCubeIter); + + TMOV(b_l0_tile[0], Y_l1_tile); + TMOV(a_l0_tile[0], I_neg_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + TMOV(a_l0_tile[1], Zero_l1_tile); + TMOV(b_l0_tile[1], Zero_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_M, PIPE_MTE1, event_1); + wait_flag(PIPE_M, PIPE_MTE1, event_1); + CopyDiagonalFractalsL1ToL0(Y_l1_tile, + a_l0_tile[1]); + CopyDiagonalFractalsL1ToL0(Y_l1_tile, + b_l0_tile[1]); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(M_neg_l1_tile, c_l0_tile[0]); + set_flag(PIPE_FIX, PIPE_M, event_0); + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[1]); + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(Y_l1_tile, c_l0_tile[1]); + set_flag(PIPE_FIX, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); + + wait_flag(PIPE_M, PIPE_MTE1, event_0); + TMOV(b_l0_tile[0], I_neg_l1_tile); + TMOV(a_l0_tile[0], I_neg_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL(c_l0_tile[0], a_l0_tile[1], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, 
PIPE_M, event_0); + + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(X_l1_tile, c_l0_tile[0]); + + set_flag(PIPE_FIX, PIPE_M, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_M, event_1); + set_flag(PIPE_M, PIPE_MTE1, event_1); + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + for (uint32_t block_size = 1; block_size < FractalSize / 2; block_size *= 2) { + wait_flag(PIPE_M, PIPE_MTE1, event_0); + TMOV(b_l0_tile[0], I_l1_tile); + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); + TMOV(a_l0_tile[0], X_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); + TMOV(b_l0_tile[1], Y_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_MTE1, PIPE_M, event_0); + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + + if (block_size < FractalSize / 4) { + wait_flag(PIPE_M, PIPE_MTE1, event_1); + TMOV(a_l0_tile[1], Y_l1_tile); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); + TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[1]); + set_flag(PIPE_M, PIPE_MTE1, event_1); + set_flag(PIPE_M, PIPE_FIX, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(Y_l1_tile, c_l0_tile[1]); + set_flag(PIPE_FIX, PIPE_M, event_1); + } + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[0], b_l0_tile[1]); + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_M, PIPE_FIX, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(X_l1_tile, 
c_l0_tile[0]); + set_flag(PIPE_FIX, PIPE_M, event_0); + set_flag(PIPE_FIX, PIPE_MTE1, event_0); + } + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); + wait_flag(PIPE_M, PIPE_MTE1, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); + wait_flag(PIPE_M, PIPE_MTE1, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + + TMOV(b_l0_tile[1], M_neg_l1_tile); + TMOV(a_l0_tile[0], I_l1_tile); + + if constexpr (MatrixSize > FractalSize) { + set_flag(PIPE_FIX, PIPE_M, event_1); + } + set_flag(PIPE_M, PIPE_MTE1, event_1); + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + set_flag(PIPE_FIX, PIPE_M, event_0); + for (uint32_t block_size = FractalSize; block_size < MatrixSize; + block_size *= 2) { + wait_flag(PIPE_M, PIPE_MTE1, event_0); + TMOV(a_l0_tile[1], Zero_l1_tile); + + wait_flag(PIPE_M, PIPE_MTE1, event_1); + TMOV(b_l0_tile[0], I_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); + CopyOddOrEvenBlocksL1ToL0(X_l1_tile, + a_l0_tile[1], + block_size); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); + TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_MTE1, event_1); + + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[1], b_l0_tile[1]); + set_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(Y_l1_tile, c_l0_tile[0]); + set_flag(PIPE_FIX, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + + wait_flag(PIPE_M, PIPE_MTE1, event_1); + TMOV(b_l0_tile[0], Zero_l1_tile); + CopyOddOrEvenBlocksL1ToL0(X_l1_tile, + b_l0_tile[0], + block_size); + + wait_flag(PIPE_M, PIPE_MTE1, event_0); + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); + TMOV(a_l0_tile[1], Y_l1_tile); + 
set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + TMATMUL_ACC(c_l0_tile[1], c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_1); + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + + if (block_size < MatrixSize / 2) { + TMOV(X_l1_tile, c_l0_tile[1]); + set_flag(PIPE_FIX, PIPE_M, event_1); + } + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + } + wait_flag(PIPE_M, PIPE_MTE1, event_0); + wait_flag(PIPE_M, PIPE_MTE1, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); +} + +/* + * @brief: Runs the main kernel (inverts all matrices in the tensor) + */ +template +AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, + __gm__ InputT* M, __gm__ InputT* I_neg, + uint32_t total_tiles, + uint32_t num_bsnd_heads = 0, + __gm__ int32_t* cu_seqlens = nullptr, + __gm__ int32_t* chunk_sequence_prefix = + nullptr, + __gm__ int32_t* chunk_indices = nullptr, + __gm__ int32_t* chunk_valid_sizes = + nullptr) { + constexpr uint32_t TileLen = MatrixSize * MatrixSize; + constexpr uint32_t FractalSize = 16; + constexpr uint32_t NumL0Buffers = 2; + + if (get_block_idx() * NumTilesPerCubeIter >= total_tiles) { + return; + } + + using GlobalTileShapeIn = + TileShape2D; + using GlobalTileStridesIn = typename std::conditional< + !IsBSND, BaseShape2D, + Stride<1, 1, 1, -1, 1>>::type; + using GlobalTileIn = + GlobalTensor; + using GlobalTileDynShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>; + using GlobalTileDynStride = Stride<1, 1, 1, DYNAMIC, 1>; + using GlobalTileInDyn = + GlobalTensor; + + using GlobalTileStridesINeg = + BaseShape2D; + using GlobalTileINeg = GlobalTensor; + + using GlobalTileShapeOut = + TileShape2D; + using GlobalTileStridesOut = typename std::conditional< + !IsBSND, BaseShape2D, + Stride<1, 1, 1, -1, 1>>::type; + using GlobalTileOut = GlobalTensor; + using GlobalTileOutDyn = + GlobalTensor; + + using 
TileL1AB = + Tile; + using TileL1ABDyn = Tile; + using TileL0CDyn = TileAcc; + + using TileL0A = TileLeft; + using TileL0B = TileRight; + using TileL0C = TileAcc; + + GlobalTileINeg I_neg_global_in(I_neg); + + TileL1AB X_l1_tile; + TileL1AB I_l1_tile; + TileL1AB I_neg_l1_tile; + TileL1AB M_neg_l1_tile; + TileL1AB Zero_l1_tile; + TileL1AB Y_l1_tile[NumTilesPerCubeIter]; + + TileL0A a_l0_tile[NumL0Buffers]; + TileL0B b_l0_tile[NumL0Buffers]; + TileL0C c_l0_tile[NumL0Buffers]; + + TASSIGN(I_l1_tile, 0x0); + TASSIGN(I_neg_l1_tile, 0x0 + TileLen * sizeof(InputT)); + TASSIGN(Zero_l1_tile, 0x0 + 2 * TileLen * sizeof(InputT)); + TASSIGN(M_neg_l1_tile, 0x0 + 3 * TileLen * sizeof(InputT)); + TASSIGN(X_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + TASSIGN(Y_l1_tile[tile_id], 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + } + + for (uint32_t buffer_num = 0; buffer_num < NumL0Buffers; ++buffer_num) { + TASSIGN(a_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(b_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(c_l0_tile[buffer_num], + 0x0 + buffer_num * TileLen * sizeof(OutputT)); + } + TLOAD(I_neg_l1_tile, I_neg_global_in); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + + PrepareAuxiliaryMatrices( + I_neg_l1_tile, Zero_l1_tile, I_l1_tile, a_l0_tile[0], b_l0_tile[0], + c_l0_tile[0]); + + const uint32_t max_iters_per_aic = + CeilDiv(total_tiles, (uint32_t)(NumTilesPerCubeIter * get_block_num())); + + uint32_t bsnd_tile_offsets[NumTilesPerCubeIter] = {0}; + uint32_t bsnd_tile_valid_sizes[NumTilesPerCubeIter] = {0}; + + uint32_t next_tile_id_that_waits_for_pipe_fix_pipe_m = 0; + set_flag(PIPE_FIX, PIPE_M, + static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + } 
+ for (uint32_t cube_iter = 0; cube_iter < max_iters_per_aic; ++cube_iter) { + const uint32_t global_index = + (cube_iter * get_block_num() + get_block_idx()) * NumTilesPerCubeIter; + if (global_index >= total_tiles) { + break; + } + for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && + (global_index + tile_id < total_tiles); + ++tile_id) { + if constexpr (IsBSND) { + const uint32_t global_tile_id = global_index + tile_id; + if (chunk_indices != nullptr && chunk_valid_sizes != nullptr) { + const BSNDVarlenTileInfo tile_info = + GetBSNDVarlenTileInfoFromChunkMetadata( + global_tile_id, num_bsnd_heads, MatrixSize, chunk_indices, + chunk_valid_sizes); + bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; + bsnd_tile_valid_sizes[tile_id] = tile_info.valid_size; + } else if (chunk_sequence_prefix != nullptr && cu_seqlens != nullptr) { + const BSNDVarlenTileInfo tile_info = + GetBSNDVarlenTileInfoFromChunkPrefix( + global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens, + chunk_sequence_prefix); + bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; + bsnd_tile_valid_sizes[tile_id] = tile_info.valid_size; + } else if (cu_seqlens != nullptr) { + const BSNDVarlenTileInfo tile_info = GetBSNDVarlenTileInfoFromCuSeqlens( + global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens); + bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; + bsnd_tile_valid_sizes[tile_id] = tile_info.valid_size; + } else { + bsnd_tile_offsets[tile_id] = + GetBSNDFixedTileOffset(global_tile_id, num_bsnd_heads, MatrixSize); + bsnd_tile_valid_sizes[tile_id] = MatrixSize; + } + const uint32_t bsnd_offset = bsnd_tile_offsets[tile_id]; + const uint32_t valid_size = bsnd_tile_valid_sizes[tile_id]; + const int row_stride = static_cast(MatrixSize * num_bsnd_heads); + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + if (valid_size < MatrixSize) { + TileL1ABDyn Y_dyn_l1_tile(valid_size, valid_size); + TASSIGN(Y_dyn_l1_tile, + 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + GlobalTileInDyn 
M_global_in_dyn( + M + bsnd_offset, + {1, 1, 1, static_cast(valid_size), static_cast(valid_size)}, + {1, 1, 1, row_stride, 1}); + TLOAD(Y_dyn_l1_tile, M_global_in_dyn); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + TFILLPAD(Y_dyn_l1_tile, Y_dyn_l1_tile); + } else { + GlobalTileIn M_global_in(M + bsnd_offset, {}, {row_stride}); + TLOAD(Y_l1_tile[tile_id], M_global_in); + } + } else { + GlobalTileIn M_global_in(M + (global_index + tile_id) * TileLen); + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + TLOAD(Y_l1_tile[tile_id], M_global_in); + } + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + } + + constexpr uint32_t final_c_buffer_index = MatrixSize > FractalSize ? 1 : 0; + for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && + (global_index + tile_id < total_tiles); + ++tile_id) { + wait_flag(PIPE_FIX, PIPE_M, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + + InvertSingleTile( + X_l1_tile, I_l1_tile, I_neg_l1_tile, M_neg_l1_tile, Zero_l1_tile, + Y_l1_tile[tile_id], a_l0_tile, b_l0_tile, c_l0_tile, tile_id); + + set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + + if constexpr (IsBSND) { + const uint32_t bsnd_offset = bsnd_tile_offsets[tile_id]; + const uint32_t valid_size = bsnd_tile_valid_sizes[tile_id]; + const int row_stride = static_cast(MatrixSize * num_bsnd_heads); + if (valid_size < MatrixSize) { + const event_t event_0 = static_cast(tile_id); + const event_t event_1 = + static_cast(tile_id + NumTilesPerCubeIter); + TileL0CDyn c_l0_tail_tile(valid_size, valid_size); + TASSIGN(c_l0_tail_tile, + 0x0 + final_c_buffer_index * TileLen * sizeof(OutputT)); + if constexpr (final_c_buffer_index == 1) { + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + } else { + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + } + set_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); + 
wait_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); + GlobalTileOutDyn M_inv_global_out_dyn( + M_inv + bsnd_offset, + {1, 1, 1, static_cast(valid_size), static_cast(valid_size)}, + {1, 1, 1, row_stride, 1}); + TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); + } else { + GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, + {row_stride}); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } + } else { + GlobalTileOut M_inv_global_out(M_inv + + (global_index + tile_id) * TileLen); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } + next_tile_id_that_waits_for_pipe_fix_pipe_m = + (tile_id + 1) % NumTilesPerCubeIter; + set_flag( + PIPE_FIX, PIPE_M, + static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); + } + } + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + } + wait_flag(PIPE_FIX, PIPE_M, + static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); +} + +/* + * @brief: Varlen BSND kernel. + * + * The input/output tensors stay unpadded. For tail chunks with size + * `actual_size < MatrixSize`, the kernel: + * 1. derives the chunk row-start and runtime size from `cu_seqlens` + * 2. loads only the valid `actual_size x actual_size` prefix via dynamic TLOAD + * 3. zero-fills the remaining rows/cols in-place via TFILLPAD_INPLACE + * 4. runs the original dense recursive inverse on the materialized full tile + * 5. 
stores only the valid `actual_size x actual_size` prefix back to GM + */ +template +AICORE inline void TriInvRecUnrollKernelBSNDVarlen( + __gm__ OutputT* M_inv, __gm__ InputT* M, __gm__ InputT* I_neg, + uint32_t total_tiles, uint32_t num_bsnd_heads, __gm__ int32_t* cu_seqlens) { + constexpr uint32_t TileLen = MatrixSize * MatrixSize; + constexpr uint32_t FractalSize = 16; + constexpr uint32_t NumL0Buffers = 2; + + if (get_block_idx() * NumTilesPerCubeIter >= total_tiles) { + return; + } + + using GlobalTileShapeIn = + TileShape2D; + using GlobalTileStridesIn = Stride<1, 1, 1, -1, 1>; + using GlobalTileIn = + GlobalTensor; + + using GlobalTileDynShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>; + using GlobalTileDynStride = Stride<1, 1, 1, DYNAMIC, 1>; + using GlobalTileInDyn = + GlobalTensor; + using GlobalTileOutDyn = + GlobalTensor; + + using GlobalTileStridesINeg = + BaseShape2D; + using GlobalTileINeg = GlobalTensor; + + using GlobalTileShapeOut = + TileShape2D; + using GlobalTileStridesOut = Stride<1, 1, 1, -1, 1>; + using GlobalTileOut = GlobalTensor; + + using TileL1AB = + Tile; + using TileL1ABDyn = Tile; + + using TileL0A = TileLeft; + using TileL0B = TileRight; + using TileL0C = TileAcc; + using TileL0CDyn = TileAcc; + + GlobalTileINeg I_neg_global_in(I_neg); + + TileL1AB X_l1_tile; + TileL1AB I_l1_tile; + TileL1AB I_neg_l1_tile; + TileL1AB M_neg_l1_tile; + TileL1AB Zero_l1_tile; + TileL1AB Y_l1_tile[NumTilesPerCubeIter]; + + TileL0A a_l0_tile[NumL0Buffers]; + TileL0B b_l0_tile[NumL0Buffers]; + TileL0C c_l0_tile[NumL0Buffers]; + + TASSIGN(I_l1_tile, 0x0); + TASSIGN(I_neg_l1_tile, 0x0 + TileLen * sizeof(InputT)); + TASSIGN(Zero_l1_tile, 0x0 + 2 * TileLen * sizeof(InputT)); + TASSIGN(M_neg_l1_tile, 0x0 + 3 * TileLen * sizeof(InputT)); + TASSIGN(X_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + TASSIGN(Y_l1_tile[tile_id], 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + } + + for (uint32_t 
buffer_num = 0; buffer_num < NumL0Buffers; ++buffer_num) { + TASSIGN(a_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(b_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(c_l0_tile[buffer_num], + 0x0 + buffer_num * TileLen * sizeof(OutputT)); + } + TLOAD(I_neg_l1_tile, I_neg_global_in); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + + PrepareAuxiliaryMatrices( + I_neg_l1_tile, Zero_l1_tile, I_l1_tile, a_l0_tile[0], b_l0_tile[0], + c_l0_tile[0]); + + const uint32_t max_iters_per_aic = + CeilDiv(total_tiles, (uint32_t)(NumTilesPerCubeIter * get_block_num())); + constexpr uint32_t final_c_buffer_index = MatrixSize > FractalSize ? 1 : 0; + + for (uint32_t cube_iter = 0; cube_iter < max_iters_per_aic; ++cube_iter) { + const uint32_t global_index = + (cube_iter * get_block_num() + get_block_idx()) * NumTilesPerCubeIter; + if (global_index >= total_tiles) { + break; + } + + for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && + (global_index + tile_id < total_tiles); + ++tile_id) { + const uint32_t global_tile_id = global_index + tile_id; + const BSNDVarlenTileInfo tile_info = GetBSNDVarlenTileInfoFromCuSeqlens( + global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens); + const uint32_t valid_size = tile_info.valid_size; + const uint32_t bsnd_offset = tile_info.bsnd_offset; + const int row_stride = static_cast(MatrixSize * num_bsnd_heads); + + if (valid_size == MatrixSize) { + GlobalTileIn M_global_in(M + bsnd_offset, {}, {row_stride}); + TLOAD(Y_l1_tile[tile_id], M_global_in); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + } else { + TileL1ABDyn Y_dyn_l1_tile(valid_size, valid_size); + TASSIGN(Y_dyn_l1_tile, + 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + GlobalTileInDyn M_global_in_dyn(M + bsnd_offset, + {1, 1, 1, valid_size, valid_size}, + {1, 1, 1, row_stride, 1}); + 
TLOAD(Y_dyn_l1_tile, M_global_in_dyn); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + TFILLPAD(Y_dyn_l1_tile, Y_dyn_l1_tile); + } + + InvertSingleTile( + X_l1_tile, I_l1_tile, I_neg_l1_tile, M_neg_l1_tile, Zero_l1_tile, + Y_l1_tile[tile_id], a_l0_tile, b_l0_tile, c_l0_tile, tile_id); + + if (valid_size == MatrixSize) { + GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, {row_stride}); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } else { + const event_t event_0 = static_cast(tile_id); + const event_t event_1 = static_cast(tile_id + NumTilesPerCubeIter); + TileL0CDyn c_l0_tail_tile(valid_size, valid_size); + TASSIGN(c_l0_tail_tile, + 0x0 + final_c_buffer_index * TileLen * sizeof(OutputT)); + if constexpr (final_c_buffer_index == 1) { + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + } else { + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + } + set_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); + wait_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); + GlobalTileOutDyn M_inv_global_out_dyn( + M_inv + bsnd_offset, {1, 1, 1, valid_size, valid_size}, + {1, 1, 1, row_stride, 1}); + TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); + } + } + } +} + +template +AICORE void runKernelTriInvRecUnroll(__gm__ OutputT* M_inv, __gm__ InputT* M, + __gm__ InputT* I_neg, uint32_t total_tiles, + uint32_t num_bsnd_heads = 0, + __gm__ int32_t* cu_seqlens = nullptr, + __gm__ int32_t* chunk_sequence_prefix = + nullptr, + __gm__ int32_t* chunk_indices = nullptr, + __gm__ int32_t* chunk_valid_sizes = + nullptr) { +#if (__CHECK_FEATURE_AT_PRECOMPILE) || \ + (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) + TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, num_bsnd_heads, + cu_seqlens, chunk_sequence_prefix, + chunk_indices, + chunk_valid_sizes); +#else +// Nothing to do on AIV +#endif +} + +template +AICORE void 
run_tri_inv_rec_unroll(__gm__ float* tensor_out, + __gm__ InputT* tensor_in, + __gm__ InputT* minus_identity_in, + uint32_t matrix_size, uint32_t num_matrices, + uint32_t num_bsnd_heads, + __gm__ int32_t* cu_seqlens, + __gm__ int32_t* chunk_sequence_prefix, + __gm__ int32_t* chunk_indices, + __gm__ int32_t* chunk_valid_sizes) { + static_assert(std::is_same_v, + "tri_inv_rec_unroll supports only fp16."); + switch (matrix_size) { + case 16: + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, + cu_seqlens, chunk_sequence_prefix, chunk_indices, + chunk_valid_sizes); + break; + case 32: + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, + cu_seqlens, chunk_sequence_prefix, chunk_indices, + chunk_valid_sizes); + break; + case 64: + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, + cu_seqlens, chunk_sequence_prefix, chunk_indices, + chunk_valid_sizes); + break; + case 128: + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, + cu_seqlens, chunk_sequence_prefix, chunk_indices, + chunk_valid_sizes); + break; + } +} + +extern "C" __global__ AICORE void tri_inv_rec_unroll_fp16( + __gm__ void* tensor_out, __gm__ void* tensor_in, + __gm__ void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, + uint32_t num_bsnd_heads, __gm__ void* cu_seqlens, + __gm__ void* chunk_sequence_prefix, + __gm__ void* chunk_indices, __gm__ void* chunk_valid_sizes) { + if (num_bsnd_heads == 0) { + if (num_matrices <= get_block_num()) { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); + } else if (num_matrices <= 2 * get_block_num()) { + 
run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); + } else { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); + } + } else { + if (num_matrices <= get_block_num()) { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); + } else if (num_matrices <= 2 * get_block_num()) { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); + } else { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); + } + } +} diff --git a/examples/jit_cpp/fast_inverse/metadata_overhead.md b/examples/jit_cpp/fast_inverse/metadata_overhead.md new file mode 100644 index 00000000..d6b4ddb6 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/metadata_overhead.md @@ -0,0 +1,148 @@ +# Metadata Overhead Comparison + +This note compares three varlen BSND metadata strategies for 
the fast-inverse PTO kernel. + +## Strategies + +### 1. Device-side metadata from `cu_seqlens` + +Files: +- `kernel_tri_inv_rec_unroll.cpp` +- `fast_inverse.cpp` +- `jit_util_fast_inverse.py` + +Behavior: +- Python passes only `cu_seqlens` for the varlen path. +- The NPU kernel derives each chunk's row offset and `valid_size` by scanning `cu_seqlens` inside `GetBSNDVarlenTileInfoFromCuSeqlens()`. + +Pros: +- Matches the Triton-style deployment API. +- No host-side metadata buffers to build or upload. +- Best end-to-end latency in the current measurements. + +Cons: +- Adds a small amount of device-side work per tile. + +### 2. Device-side compact chunk-prefix metadata + +Files: +- `host_chunk_metadata.cpp` +- `host_metadata_util.py` +- `kernel_tri_inv_rec_unroll.cpp` +- `fast_inverse.cpp` +- `jit_util_fast_inverse.py` + +Behavior: +- A small host C++ helper builds a compact per-sequence cumulative chunk-count prefix. +- Python uploads that prefix together with `cu_seqlens`. +- The NPU kernel uses the prefix to binary-search the owning sequence for each chunk, instead of walking all prior sequences. + +Pros: +- Reduces in-kernel metadata work compared with full `cu_seqlens` walking. +- Metadata payload is much smaller than full per-chunk host metadata. +- Better end-to-end than the full host metadata path. + +Cons: +- Still requires host preprocessing and one extra metadata upload. +- Still slower end-to-end than pure device-side `cu_seqlens` scanning in the current measurements. + +### 3. Host-side C++ metadata precompute + +Files: +- `host_chunk_metadata.cpp` +- `host_metadata_util.py` +- `kernel_tri_inv_rec_unroll.cpp` +- `fast_inverse.cpp` +- `jit_util_fast_inverse.py` + +Behavior: +- A small host C++ helper builds `chunk_indices` and `chunk_valid_sizes` from `cu_seqlens`. +- Python uploads those buffers to NPU memory. +- The NPU kernel uses the precomputed metadata directly and skips the in-kernel `cu_seqlens` scan. 
+ +Pros: +- Simpler varlen metadata lookup inside the kernel. +- Kernel-only time is slightly lower or roughly equal to the device-side scan path. + +Cons: +- Host metadata build plus host-to-device upload dominates the savings. +- Worse end-to-end latency in the current measurements. + +## Quick Perf Summary + +Benchmark setup: +- script: `benchmark_bsnd_fast_inverse.py` +- input style: Triton-unit-test-like `k` / `beta` generation +- config: `B=32`, `H=4`, `feature_dim=64` +- seqlens: `2048,8192` +- repeats: `10` +- warmup: `3` +- true-varlen samples: `0` + +### `chunk_size=64` + +| T | Device scan total | Prefix total | Prefix kernel | Prefix metadata | Host total | Host kernel | Host metadata | +|---|---:|---:|---:|---:|---:|---:|---:| +| 2048 | 564 us | 746 us | 559 us | 187 us | 836 us | 559 us | 277 us | +| 8192 | 2071 us | 2235 us | 2049 us | 186 us | 2340 us | 2048 us | 292 us | + +Takeaway: +- Prefix metadata cuts host metadata overhead from about `277-292 us` down to about `186-187 us`. +- Kernel-only time for prefix is slightly better than full device-side scanning, by about `5-22 us`. +- End to end, plain device-side `cu_seqlens` scanning is still best. + +### `chunk_size=128` + +| T | Device scan total | Prefix total | Prefix kernel | Prefix metadata | Host total | Host kernel | Host metadata | +|---|---:|---:|---:|---:|---:|---:|---:| +| 2048 | 1085 us | 1298 us | 1084 us | 214 us | 1363 us | 1080 us | 283 us | +| 8192 | 4065 us | 4253 us | 4056 us | 197 us | 4351 us | 4063 us | 288 us | + +Takeaway: +- Prefix metadata cuts host metadata overhead from about `283-288 us` down to about `197-214 us`. +- Kernel-only improvement versus device scan is tiny, around `1-9 us`. +- End to end, the pure device-side scan still wins. + +## Conclusion + +For the current implementation and tested shapes, the device-side `cu_seqlens` scan is still the best overall strategy. 
+ +Reason: +- The compact prefix path does reduce kernel-side metadata work and is clearly better than full host per-chunk metadata. +- But the saved kernel time is still much smaller than the cost of building and uploading the prefix. +- The full host per-chunk metadata path remains the slowest end-to-end option. + +## How To Reproduce + +From `examples/jit_cpp/fast_inverse/`: + +```bash +export PTO_LIB_PATH=/sources/pto-isa + +python benchmark_bsnd_fast_inverse.py \ + --chunk-size 64 \ + --seqlens 2048,8192 \ + --repeats 10 \ + --warmup 3 \ + --true-varlen-samples 0 + +python benchmark_bsnd_fast_inverse.py \ + --chunk-size 128 \ + --seqlens 2048,8192 \ + --repeats 10 \ + --warmup 3 \ + --true-varlen-samples 0 +``` + +The benchmark writes: +- `benchmark_results/bench_results_bsnd_fast_inverse_64.csv` +- `benchmark_results/bench_results_bsnd_fast_inverse_128.csv` +- `benchmark_results/bench_results_bsnd_fast_inverse_bw_64.png` +- `benchmark_results/bench_results_bsnd_fast_inverse_bw_128.png` + +Relevant CSV fields: +- `metadata_strategy` +- `time_us` +- `kernel_time_us` +- `metadata_time_us` +- `bw_gbs` diff --git a/examples/jit_cpp/fast_inverse/run_fast_inverse.py b/examples/jit_cpp/fast_inverse/run_fast_inverse.py new file mode 100644 index 00000000..b7dbad99 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/run_fast_inverse.py @@ -0,0 +1,443 @@ +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# All rights reserved. +# See LICENSE in the root of the software repository: +# https://github.com/huawei-csl/pto-kernels/ +# for the full License text. +# -------------------------------------------------------------------------------- + +""" +Correctness tests for the JIT-compiled triangular inverse (recursive unroll) +kernel. 
def random_triu_matrix(n, block_dim_x, block_dim_y, scale=0.1):
    """Return a batch of strictly upper-triangular uniform-random matrices.

    The result has shape ``(block_dim_x, block_dim_y, n, n)``. Entries above
    the main diagonal come from ``torch.rand`` multiplied by ``scale``; the
    diagonal and everything below it are zero.
    """
    samples = torch.rand((block_dim_x, block_dim_y, n, n))
    strict_upper = torch.triu(samples, diagonal=1)
    return scale * strict_upper


def ones_triu_matrix(n, block_dim_x, block_dim_y):
    """Return a ``(block_dim_x, block_dim_y, n, n)`` batch with ones strictly
    above the main diagonal and zeros elsewhere."""
    batch_shape = (block_dim_x, block_dim_y, n, n)
    return torch.triu(torch.ones(batch_shape), diagonal=1)


def block_ones_triu_matrix(n, block_dim_x, block_dim_y):
    """Return block-diagonal matrices whose 16x16 diagonal blocks are strictly
    upper-triangular all-ones.

    Only the ``n // 16`` complete blocks are filled; any trailing rows/columns
    stay zero. The result is a float64 tensor of shape
    ``(block_dim_x, block_dim_y, n, n)``.
    """
    tile = 16
    dense = np.zeros((block_dim_x, block_dim_y, n, n))
    # Fill every complete diagonal block for the whole batch at once.
    for lo in range(0, (n // tile) * tile, tile):
        dense[:, :, lo:lo + tile, lo:lo + tile] = 1.0
    # np.triu acts on the last two axes, leaving only the strict upper part.
    return torch.from_numpy(np.triu(dense, 1))


def block_random_triu_matrix(n, block_dim_x, block_dim_y, scale=0.2):
    """Return block-diagonal matrices that repeat one random 16x16 strictly
    upper-triangular block.

    A single block is drawn with ``np.random.rand`` (scaled by ``scale``) and
    copied to every diagonal position of every matrix in the batch. The result
    is a float64 tensor of shape ``(block_dim_x, block_dim_y, n, n)``.
    """
    fractal = np.triu(scale * np.random.rand(16, 16), k=1)
    dense = np.zeros((block_dim_x, block_dim_y, n, n))
    for x, y in np.ndindex(block_dim_x, block_dim_y):
        for lo in range(0, n, 16):
            dense[x, y, lo:lo + 16, lo:lo + 16] = fractal
    return torch.from_numpy(dense)
dtype=np.double) + out = np.zeros(U.shape, dtype=np.double) + for x in range(U.shape[0]): + for y in range(U.shape[1]): + out[x, y] = np.linalg.inv(U[x, y].numpy().astype(np.double) + identity) + return torch.from_numpy(out) + + +def invert_single_chunk_ref(U: torch.Tensor) -> torch.Tensor: + """Invert one upper-triangular chunk U where U is (..., m, m).""" + m = U.shape[-1] + return torch.from_numpy( + np.linalg.inv(U.numpy().astype(np.double) + np.eye(m, dtype=np.double)) + ) + + +# --------------------------------------------------------------------------- +# Kernel helpers +# --------------------------------------------------------------------------- +def _make_minus_identity(matrix_size: int, device: str) -> torch.Tensor: + I_neg = torch.zeros(matrix_size, matrix_size, dtype=torch.half, device=device) + I_neg.fill_diagonal_(-1) + return I_neg + + +def _count_varlen_chunks( + cu_seqlens: torch.Tensor | list[int], + chunk_size: int, +) -> int: + if isinstance(cu_seqlens, torch.Tensor): + cu_seqlens_list = [int(x) for x in cu_seqlens.detach().cpu().tolist()] + else: + cu_seqlens_list = [int(x) for x in cu_seqlens] + return sum( + (cu_seqlens_list[i + 1] - cu_seqlens_list[i] + chunk_size - 1) // chunk_size + for i in range(len(cu_seqlens_list) - 1) + ) + + +def _run_kernel(tri_inv_func, U_fp16: torch.Tensor): + """ + Allocate output, build -I, run kernel, return fp64 CPU result. + + U_fp16 : (block_dim_x, block_dim_y, n, n) half tensor on NPU. 
def _run_kernel_bsnd(
    tri_inv_func,
    U_bsnd_fp16: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
):
    """
    Run the kernel in BSND mode and return fp64 CPU result.

    U_bsnd_fp16 : (B, S, N, D) half tensor on NPU where each (D, D) block
                  along the S dimension is one matrix to invert.
    cu_seqlens  : optional int32 tensor containing cumulative sequence lengths
                  for varlen BSND inputs.

    The tile count passed to the kernel is ``chunks * heads`` in varlen mode
    and ``numel / D**2`` otherwise.
    """
    matrix_size = U_bsnd_fp16.shape[-1]
    num_bsnd_heads = U_bsnd_fp16.shape[-2]
    device = U_bsnd_fp16.device

    if cu_seqlens is not None:
        num_matrices = _count_varlen_chunks(cu_seqlens, matrix_size) * num_bsnd_heads
        # The kernel receives cu_seqlens as a contiguous int32 buffer on the
        # same device as the input tensor.
        cu_seqlens = cu_seqlens.to(device=device, dtype=torch.int32).contiguous()
    else:
        num_matrices = U_bsnd_fp16.numel() // (matrix_size * matrix_size)

    tensor_out = torch.zeros_like(U_bsnd_fp16, dtype=torch.float32)
    I_neg = _make_minus_identity(matrix_size, str(device))

    # Synchronize on both sides of the launch, as in the original.
    torch.npu.synchronize()
    tri_inv_func(
        tensor_out,
        U_bsnd_fp16,
        I_neg,
        matrix_size,
        num_matrices,
        num_bsnd_heads,
        cu_seqlens=cu_seqlens,
    )
    torch.npu.synchronize()

    return tensor_out.cpu().to(torch.float64)


def _build_varlen_bsnd_case(
    gen,
    cu_seqlens: list[int],
    num_heads: int,
    chunk_size: int,
):
    """
    Build an unpadded BSND tensor plus reference output for varlen testing.

    Each sequence contributes only its true rows in the packed BSND tensor.

    gen        : matrix generator called as ``gen(chunk_size, num_chunks,
                 num_heads)``; indexed here as (chunk, head, row, col).
    cu_seqlens : cumulative sequence boundaries as a Python list.

    Returns ``(U, golden, cu_seqlens_tensor)`` where ``U`` and ``golden`` are
    float64 tensors of shape (1, total_tokens, num_heads, chunk_size) and
    ``cu_seqlens_tensor`` is int32. ``U`` only holds fp16-representable
    values, so converting it to half is lossless.
    """
    boundaries = list(zip(cu_seqlens, cu_seqlens[1:]))
    seq_lens = [seq_end - seq_start for seq_start, seq_end in boundaries]
    print(
        f" varlen sequence lengths: {seq_lens} "
        f"(chunk_size={chunk_size}, num_heads={num_heads})"
    )

    total_tokens = cu_seqlens[-1]
    # Reuse the shared helper so the chunk count always matches the one
    # _run_kernel_bsnd passes to the kernel.
    num_chunks = _count_varlen_chunks(cu_seqlens, chunk_size)

    # Round through fp16 before building the reference: the kernel is fed the
    # half-precision tensor, so the golden inverse must be computed from the
    # same rounded values (as _test_case and _test_case_bsnd already do).
    chunk_mats = (
        gen(chunk_size, num_chunks, num_heads).to(torch.half).to(torch.float64)
    )

    U = torch.zeros((1, total_tokens, num_heads, chunk_size), dtype=torch.float64)
    golden = torch.zeros_like(U)

    chunk_idx = 0
    for seq_start, seq_end in boundaries:
        for chunk_start in range(seq_start, seq_end, chunk_size):
            # The last chunk of a sequence may be shorter than chunk_size.
            actual_size = min(chunk_size, seq_end - chunk_start)
            rows = slice(chunk_start, chunk_start + actual_size)
            chunk = chunk_mats[chunk_idx]
            for head_idx in range(num_heads):
                # Leading principal sub-matrix of the generated chunk.
                U_valid = chunk[head_idx, :actual_size, :actual_size]
                U[0, rows, head_idx, :actual_size] = U_valid
                golden[0, rows, head_idx, :actual_size] = invert_single_chunk_ref(
                    U_valid
                )
            chunk_idx += 1

    return (
        U,
        golden,
        torch.tensor(cu_seqlens, dtype=torch.int32),
    )
def _test_case_bsnd(
    tri_inv_func,
    U: torch.Tensor,
    B: int,
    S: int,
    N: int,
    D: int,
    atol: float,
    rtol: float,
    ftol: float,
    label: str,
):
    """
    Check one fixed-length BSND case against the reference inverse.

    ``U`` is the raw generator output of shape (B*S//D, N, D, D).  It is
    rounded to fp16, and both it and the reference from ``linalg_inv_ref``
    are rearranged to (B, S, N, D) by swapping dims 1 and 2 and reshaping.
    The kernel result must satisfy ``np.allclose`` with ``atol``/``rtol`` and
    have a relative Frobenius error of at most ``ftol``; otherwise an
    AssertionError carrying ``label`` is raised.  A PASS line is printed on
    success.
    """

    def to_bsnd(mats: torch.Tensor) -> torch.Tensor:
        # (chunks, N, D, D) -> (chunks, D, N, D) -> (B, S, N, D)
        return mats.transpose(1, 2).contiguous().reshape(B, S, N, D)

    U_fp16 = U.to(torch.half)
    golden = to_bsnd(linalg_inv_ref(U_fp16))
    U_bsnd = to_bsnd(U_fp16)

    actual = _run_kernel_bsnd(tri_inv_func, U_bsnd.npu())

    # Relative Frobenius error: ||golden - actual||_F / ||golden||_F.
    squared_error = torch.sum((golden - actual) ** 2)
    squared_norm = torch.sum(golden ** 2)
    frob = torch.sqrt(squared_error / squared_norm).item()

    elementwise_ok = np.allclose(
        actual.numpy(),
        golden.numpy(),
        atol=atol,
        rtol=rtol,
    )
    assert elementwise_ok, f"[{label}] allclose failed — shape {U_bsnd.shape}, rtol={rtol}"
    assert frob <= ftol, f"[{label}] Frobenius error {frob:.2e} > {ftol:.2e}"

    print(f" PASS {label} frob={frob:.2e}")
def run_tests(tri_inv_func):
    """
    Run the standard, BSND and BSND-varlen correctness sweeps.

    Every configuration is attempted with each of the four input generators.
    A case counts as failed when it raises AssertionError: the failure is
    printed and the sweep continues.  Any other exception is not caught here
    and propagates to the caller.

    Returns True when every attempted case passed, False otherwise.
    """
    # (name, generator, atol, rtol, ftol)
    cases = [
        ("block_ones", block_ones_triu_matrix, 0, 0, 0),
        ("ones", ones_triu_matrix, 0, 0, 0),
        ("block_random", block_random_triu_matrix, 5e-5, 0.1, 1e-4),
        ("random", random_triu_matrix, 5e-5, 0.1, 1e-4),
    ]

    total = 0
    passed = 0

    def attempt(label, check, *args):
        # Run one case, tally it, and report an assertion failure without
        # stopping the sweep.
        nonlocal total, passed
        total += 1
        try:
            check(*args)
        except AssertionError as err:
            print(f" FAIL {label}: {err}")
        else:
            passed += 1

    def standard_case(gen, n, bdx, bdy, atol, rtol, ftol, label):
        # Input generation happens inside the attempt, as part of the case.
        _test_case(tri_inv_func, gen(n, bdx, bdy), atol, rtol, ftol, label)

    def bsnd_case(gen, B, S, N, D, atol, rtol, ftol, label):
        raw = gen(D, B * S // D, N)
        _test_case_bsnd(tri_inv_func, raw, B, S, N, D, atol, rtol, ftol, label)

    print("=== Standard layout ===")
    sizes = [16, 32, 64, 128]
    x_dims = [1, 2, 4]
    y_dims = [2, 4]

    for n in sizes:
        for bdx in x_dims:
            for bdy in y_dims:
                for name, gen, atol, rtol, ftol in cases:
                    label = f"n={n} x={bdx} y={bdy} [{name}]"
                    attempt(label, standard_case, gen, n, bdx, bdy, atol, rtol, ftol, label)

    print("\n=== BSND layout ===")
    for B in [1, 4]:
        for S in [128, 256]:
            for N in [4, 8]:
                for D in [16, 32, 64, 128]:
                    if S % D != 0:
                        # Only sequence lengths that are whole multiples of D.
                        continue
                    for name, gen, atol, rtol, ftol in cases:
                        label = f"B={B} S={S} N={N} D={D} [{name}]"
                        attempt(label, bsnd_case, gen, B, S, N, D, atol, rtol, ftol, label)

    print("\n=== BSND varlen layout ===")
    # (num_heads, chunk_size, cumulative sequence lengths)
    varlen_configs = [
        (4, 16, [0, 15]),
        (4, 32, [0, 256, 500, 1000]),
        (4, 64, [0, 15, 100, 300, 1200, 2000]),
        (4, 16, [0, 1, 100, 300, 1200, 2048]),
        (4, 32, [0, 200, 512, 1200, 2048]),
    ]

    for N, D, cu_seqlens in varlen_configs:
        for name, gen, atol, rtol, ftol in cases:
            label = f"N={N} D={D} cu={cu_seqlens} [{name}]"
            attempt(
                label,
                _test_case_bsnd_varlen,
                tri_inv_func,
                gen,
                cu_seqlens,
                N,
                D,
                atol,
                rtol,
                ftol,
                label,
            )

    print(f"\n{passed}/{total} tests passed.")
    return passed == total


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    torch.npu.set_device("npu:0")

    # The kernel source sits next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    src = os.path.join(script_dir, "fast_inverse.cpp")
    print(f"Compiling {src} ...")
    tri_inv_func = jit_compile(src)
    print("Compilation successful.\n")

    ok = run_tests(tri_inv_func)
    # Exit status 0 only when every case passed.
    raise SystemExit(0 if ok else 1)
emulates `chunk_scaled_dot_kkt_fwd` in PyTorch because Triton is unavailable + +Run from the fast_inverse/ directory: + + export PTO_LIB_PATH=/sources/pto-isa + python run_fast_inverse_varlen_like_triton.py +""" + +from __future__ import annotations + +import os + +import numpy as np +import torch +import torch.nn.functional as F +import torch_npu # noqa: F401 + +from jit_util_fast_inverse import jit_compile + + +torch.manual_seed(42) +np.random.seed(42) + + +def _make_minus_identity(matrix_size: int, device: torch.device) -> torch.Tensor: + minus_identity = torch.zeros( + (matrix_size, matrix_size), + dtype=torch.float16, + device=device, + ) + minus_identity.fill_diagonal_(-1) + return minus_identity + + +def _count_varlen_chunks(cu_seqlens: torch.Tensor, chunk_size: int) -> int: + return sum( + (int(eos) - int(bos) + chunk_size - 1) // chunk_size + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False) + ) + + +def _chunk_scaled_dot_kkt_fwd_emulated( + k: torch.Tensor, + beta: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_size: int, +) -> torch.Tensor: + t_total = int(cu_seqlens[-1].item()) + num_heads = k.shape[2] + A = torch.zeros((1, t_total, num_heads, chunk_size), dtype=k.dtype, device=k.device) + + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + for chunk_start in range(bos, eos, chunk_size): + chunk_end = min(chunk_start + chunk_size, eos) + actual_size = chunk_end - chunk_start + k_chunk = k[:, chunk_start:chunk_end].transpose(1, 2).to(torch.float32) + beta_chunk = ( + beta[:, chunk_start:chunk_end] + .transpose(1, 2) + .unsqueeze(-1) + .to(torch.float32) + ) + scores = torch.matmul(k_chunk, k_chunk.transpose(-1, -2)) + scores = torch.tril(scores * beta_chunk, diagonal=-1).to(k.dtype) + A[:, chunk_start:chunk_end, :, :actual_size] = scores.transpose(1, 2) + + return A + + +def _reference_inverse(A: torch.Tensor, cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor: + A_cpu = 
A.cpu().to(torch.float64) + ref = torch.zeros_like(A_cpu, dtype=torch.float64) + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + for chunk_start in range(bos, eos, chunk_size): + actual_size = min(chunk_size, eos - chunk_start) + ref[:, chunk_start : chunk_start + actual_size, :, :actual_size] = torch.inverse( + A_cpu[:, chunk_start : chunk_start + actual_size, :, :actual_size].transpose(1, 2) + + torch.eye(actual_size, dtype=torch.float64)[None, None, ...] + ).transpose(1, 2) + return ref + + +def _transpose_valid_chunks( + A: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_size: int, +) -> torch.Tensor: + transposed = torch.zeros_like(A) + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + for chunk_start in range(bos, eos, chunk_size): + actual_size = min(chunk_size, eos - chunk_start) + chunk = A[:, chunk_start : chunk_start + actual_size, :, :actual_size] + transposed[:, chunk_start : chunk_start + actual_size, :, :actual_size] = chunk.transpose( + 1, 3 + ) + return transposed + + +def _run_pto_varlen(tri_inv_func, A: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + chunk_size = A.shape[-1] + num_heads = A.shape[-2] + num_matrices = _count_varlen_chunks(cu_seqlens, chunk_size) * num_heads + tensor_out = torch.zeros_like(A, dtype=torch.float32) + minus_identity = _make_minus_identity(chunk_size, A.device) + + torch.npu.synchronize() + tri_inv_func( + tensor_out, + A, + minus_identity, + chunk_size, + num_matrices, + num_heads, + cu_seqlens=cu_seqlens, + ) + torch.npu.synchronize() + return tensor_out.cpu().to(torch.float64) + + +def _run_case( + tri_inv_func, + H: int, + D: int, + chunk_size: int, + cu_seqlens_list: list[int], + atol: float = 5e-4, + rtol: float = 5e-2, + ftol: float = 1e-4, +) -> None: + device = torch.device("npu:0") + T = cu_seqlens_list[-1] + cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) + + # Match the Triton varlen test 
def main() -> int:
    """
    Compile the kernel, run every varlen case, and report the outcome.

    If ``PTO_LIB_PATH`` is unset and ``/sources/pto-isa`` exists, that path is
    used as a fallback.  Each case prints PASS or FAIL; any exception raised
    by a case is reported as a failure and the remaining cases still run.

    Returns 0 when all cases passed, 1 otherwise.
    """
    fallback = "/sources/pto-isa"
    if "PTO_LIB_PATH" not in os.environ and os.path.exists(fallback):
        os.environ["PTO_LIB_PATH"] = fallback

    torch.npu.set_device("npu:0")

    # The kernel source sits next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    src = os.path.join(script_dir, "fast_inverse.cpp")
    print(f"Compiling {src} ...")
    tri_inv_func = jit_compile(src)
    print("Compilation successful.\n")

    # (H, D, chunk_size, cumulative sequence lengths)
    cases = [
        (4, 64, 16, [0, 15]),
        (4, 64, 32, [0, 256, 500, 1000]),
        (4, 100, 64, [0, 15, 100, 300, 1200, 2000]),
        (4, 64, 16, [0, 1, 100, 300, 1200, 2048]),
        (4, 128, 32, [0, 200, 512, 1200, 2048]),
    ]

    passed = 0
    print("=== Varlen Like Triton ===")
    for H, D, chunk_size, cu_seqlens in cases:
        label = f"H={H} D={D} chunk_size={chunk_size} cu_seqlens={cu_seqlens}"
        try:
            _run_case(tri_inv_func, H, D, chunk_size, cu_seqlens)
            print(f" PASS {label}")
            passed += 1
        except Exception as err:
            # Top-level boundary: report the failure and keep going.
            print(f" FAIL {label}: {err}")

    total = len(cases)
    print(f"\n{passed}/{total} cases passed.")
    return int(passed != total)


if __name__ == "__main__":
    # Propagate main()'s status as the process exit code.
    raise SystemExit(main())