diff --git a/.gitignore b/.gitignore index 33093216..181edaca 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ pto_kernels_*.so *.whl __pycache__/ extra-info/ -lib*.so +*.so dist/ build *.egg-info/ diff --git a/examples/jit_cpp/linear_attention/README.md b/examples/jit_cpp/linear_attention/README.md new file mode 100644 index 00000000..de565e0a --- /dev/null +++ b/examples/jit_cpp/linear_attention/README.md @@ -0,0 +1,151 @@ +# PTO-ISA Linear Attention + +This directory contains a self-contained PTO-ISA linear attention example and a local step-by-step optimization tutorial. + +## What Is Here + +- `linear_attention.cpp`: the current optimized kernel +- `run_linear_attention.py`: correctness sweep against a PyTorch reference +- `benchmark_linear_attention.py`: throughput and bandwidth benchmark +- `optimization_lession.md`: reusable optimization notes for future PTO-ISA kernels +- `optimize_step_by_step/`: a tutorial ladder from naive fixed-shape code to the current fast path + +## Main Example + +The main example: +- compiles `linear_attention.cpp` with `bisheng` +- loads the generated `.so` via `ctypes` +- runs the kernel from PyTorch on NPU +- supports both a cached causal-mask path and a fast on-the-fly `TTRI` mask path +- checks numerical correctness with `torch.testing.assert_close` +- reports effective TFLOP/s and GiB/s on larger shapes + +Run correctness: + +```bash +python run_linear_attention.py +``` + +Run the default benchmark table: + +```bash +python benchmark_linear_attention.py --warmup 2 --repeats 5 --mask-variant cached_mask +python benchmark_linear_attention.py --warmup 2 --repeats 5 --mask-variant fast_onthefly +``` + +Quick smoke benchmark: + +```bash +python benchmark_linear_attention.py --quick --warmup 1 --repeats 3 +``` + +Throughput hunt: + +```bash +python benchmark_linear_attention.py --throughput-hunt --warmup 2 --repeats 5 +``` + +## Current Kernel Shape + +The current kernel keeps: +- dynamic `B` and `L` +- compile-time `H`, 
`D`, and `C` +- fixed `block_dim = num_cores` +- an explicit in-kernel loop over logical work items + +The current fast path is `C=128, D=128`. + +The directory now contains two PTO execution styles: +- legacy fused `head_first` `(B, H, T, D)` for the highest throughput reference path +- native `seq_first` `(B, T, H, D)` including gated and packed-varlen support without transpose or Python padding + +The main performance ideas now present in `linear_attention.cpp` are: +- precomputed causal mask passed from PyTorch and applied with vector tile ops +- an additional fast on-the-fly mask variant that builds the same triangular tile in UB with `TTRI` +- shared L0C reuse so larger tiles fit without changing the math +- cube-side `K=128 -> 2 x 64` L0 ping-pong inside the GEMM helper +- 2-slot cube/vector workspace pipeline for chunk overlap +- in-place mask application on `acc_ub` to reduce UB pressure +- two L1 hidden-state buffers so the next prefix-state tile can be prefetched early +- static strided full-chunk PTO loads/stores for native `seq_first` inputs, with dynamic `TLOAD`/`TSTORE` reserved for only true varlen tail chunks + +## Step-By-Step Tutorial + +`optimize_step_by_step/` mirrors the optimization path as runnable local examples: + +1. `01_naive_static_shape` +2. `02_naive_dynamic_shape` +3. `03_cached_mask` +4. `03a_fast_mask_construct` +5. `04_chunk128` +6. `05_l0_double_buffer` +7. `06_two_slot_cv_pipeline` +8. `07_l1_prefetching` + +The tutorial keeps each kernel source self-contained, but now shares common Python compile / test / benchmark helpers through `optimize_step_by_step/common/`. + +Start there if you want to understand how the kernel evolved, or if you want a smaller teaching version before reading the main optimized kernel. 
+ +## Measured Results + +Command used: + +```bash +python benchmark_linear_attention.py --warmup 2 --repeats 5 --mask-variant cached_mask +python benchmark_linear_attention.py --warmup 2 --repeats 5 --mask-variant fast_onthefly +``` + +Current measured default-shape comparison on this machine: + +| Shape `(B,H,L,D,C)` | Cached ms | Cached TFLOP/s | Fast on-the-fly ms | Fast on-the-fly TFLOP/s | Faster variant | +| --- | ---: | ---: | ---: | ---: | --- | +| `(32, 20, 2048, 128, 128)` | `2.332` | `73.66` | `2.326` | `73.85` | `fast_onthefly` | +| `(24, 20, 4096, 128, 128)` | `3.442` | `74.87` | `3.379` | `76.26` | `fast_onthefly` | +| `(12, 20, 8192, 128, 128)` | `3.373` | `76.39` | `3.372` | `76.42` | `fast_onthefly` | +| `(24, 20, 6144, 128, 128)` | `4.985` | `77.55` | `5.017` | `77.05` | `cached_mask` | + +Best measured points from those two runs: +- `cached_mask`: `77.55 TFLOP/s` / `564.23 GiB/s` at `(24, 20, 6144, 128, 128)` +- `fast_onthefly`: `77.05 TFLOP/s` / `560.62 GiB/s` at `(24, 20, 6144, 128, 128)` + +Feature-extension quick table on this machine +using `python benchmark_linear_attention.py --quick --repeats 10 --warmup 3` +and the corresponding `--seq-first`, `--use-g`, and `--varlen-uniform` modes: + +| Shape `(B,H,L,D,C)` | PTO path | Median ms | TFLOP/s | GiB/s | +| --- | --- | ---: | ---: | ---: | +| `(8, 20, 1024, 128, 128)` | `legacy_head_first` | `0.416` | `51.62` | `375.66` | +| `(8, 20, 1024, 128, 128)` | `seq_first` | `0.549` | `39.10` | `284.54` | +| `(8, 20, 1024, 128, 128)` | `seq_first_gated` | `0.535` | `40.11` | `291.90` | +| `(8, 20, 1024, 128, 128)` | `seq_first_varlen_uniform` | `0.529` | `40.61` | `295.50` | +| `(16, 20, 1024, 128, 128)` | `legacy_head_first` | `0.710` | `60.49` | `440.13` | +| `(16, 20, 1024, 128, 128)` | `seq_first` | `0.880` | `48.80` | `355.07` | +| `(16, 20, 1024, 128, 128)` | `seq_first_gated` | `0.872` | `49.24` | `358.30` | +| `(16, 20, 1024, 128, 128)` | `seq_first_varlen_uniform` | `0.872` | `49.27` | 
`358.50` | + +Native `seq_first` larger-shape table on this machine +using `python benchmark_linear_attention.py --throughput-hunt --repeats 5 --warmup 2 --seq-first`: + +| Shape `(B,H,L,D,C)` | PTO path | Median ms | TFLOP/s | GiB/s | +| --- | --- | ---: | ---: | ---: | +| `(24, 20, 2048, 128, 128)` | `seq_first` | `2.154` | `59.81` | `435.15` | +| `(48, 20, 1024, 128, 128)` | `seq_first` | `2.160` | `59.66` | `434.09` | +| `(12, 20, 8192, 128, 128)` | `seq_first` | `4.085` | `63.08` | `458.99` | +| `(24, 20, 1536, 128, 128)` | `seq_first` | `1.661` | `58.19` | `423.39` | + +Notes: +- device-local results will vary +- bandwidth here excludes workspace traffic; the cached-mask rows include mask tensor traffic while the fast on-the-fly rows do not +- the same kernel family at `C=64, D=128` is roughly in the `28-31 TFLOP/s` range on large shapes, both for cached-mask and fast on-the-fly masking +- on this machine, the fast on-the-fly `TTRI` path is effectively tied with cached-mask at `C=128`, winning 3 of the 4 default benchmark shapes by a small margin +- the feature-extension rows above benchmark the native `seq_first` / gated / varlen PTO path with precomputed chunk states `h`; after optimization they now reach roughly `49 TFLOP/s` on the quick shapes and `~63 TFLOP/s` on larger seq-first workloads +- the native `seq_first` path still trails the legacy fused `head_first` reference on the smallest quick shapes, but it no longer needs any transpose or Python-side padding and now lands in the same broad throughput class on larger inputs + +## Reading Order + +If you are new to this directory: + +1. Read `optimize_step_by_step/README.md` +2. Run `01` and `02`, including their `numpy_sim.py` +3. Read the current `linear_attention.cpp` +4. 
Use `optimization_lession.md` as the checklist for future optimization work \ No newline at end of file diff --git a/examples/jit_cpp/linear_attention/benchmark_linear_attention.py b/examples/jit_cpp/linear_attention/benchmark_linear_attention.py new file mode 100644 index 00000000..e47a2de5 --- /dev/null +++ b/examples/jit_cpp/linear_attention/benchmark_linear_attention.py @@ -0,0 +1,321 @@ +import argparse +import os +from statistics import median + +import torch +import torch_npu # noqa: F401 + +from jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile +from run_linear_attention import _apply_gating, _build_precomputed_h + +DTYPE = torch.float16 +_DEFAULT_MAX_CACHE_SIZE = 256 * 1024 * 1024 + +# Larger presets intended to drive better utilization while keeping H/D/C static +# within each compiled kernel. +DEFAULT_SHAPES = [ + (32, 20, 2048, 128, 128), + (24, 20, 4096, 128, 128), + (12, 20, 8192, 128, 128), + (24, 20, 6144, 128, 128), +] + +QUICK_SHAPES = [ + (8, 20, 1024, 128, 128), + (16, 20, 1024, 128, 128), +] + +THROUGHPUT_HUNT_SHAPES = [ + (24, 20, 2048, 128, 128), + (48, 20, 1024, 128, 128), + (12, 20, 8192, 128, 128), + (24, 20, 1536, 128, 128), +] + + +def parse_shapes(shape_text: str): + shapes = [] + for item in shape_text.split(";"): + item = item.strip() + if not item: + continue + parts = [int(x) for x in item.split("x")] + if len(parts) != 5: + raise ValueError( + "Each shape must be formatted as BxHxLxDxC, e.g. 
def _chunk_count(seq: int, chunk: int) -> int:
    """Return how many full chunks of size ``chunk`` tile a sequence of length ``seq``.

    Raises:
        ValueError: if ``chunk`` is not positive, or ``seq`` is not a multiple
            of ``chunk``.
    """
    if chunk <= 0:
        # Without this guard the modulo below raises a bare ZeroDivisionError.
        raise ValueError("Chunk size C must be a positive integer.")
    if seq % chunk != 0:
        raise ValueError("This benchmark requires L to be a multiple of C.")
    return seq // chunk


def estimate_flops(batch: int, heads: int, seq: int, hidden: int, chunk: int) -> int:
    """Estimate the floating-point operations for one kernel launch.

    Each chunk is charged ``4 * C * D * (C + D)`` FLOPs, and the result is
    scaled by ``batch * heads * (seq // chunk)``.

    Raises:
        ValueError: if ``chunk`` is not positive or ``seq`` is not a multiple
            of ``chunk``.
    """
    chunk_num = _chunk_count(seq, chunk)
    flops_per_chunk = 4 * chunk * hidden * (chunk + hidden)
    return batch * heads * chunk_num * flops_per_chunk


def estimate_gm_bytes(
    batch: int,
    heads: int,
    seq: int,
    hidden: int,
    chunk: int,
    *,
    include_mask_bytes: bool,
) -> int:
    """Estimate the bytes moved for q, k, v and the output in one launch.

    Four ``(chunk, hidden)`` tiles of 2-byte elements are counted per chunk.
    When ``include_mask_bytes`` is true, one ``(chunk, chunk)`` 2-byte mask is
    added once (not per batch or head).

    Raises:
        ValueError: if ``chunk`` is not positive or ``seq`` is not a multiple
            of ``chunk``.
    """
    chunk_num = _chunk_count(seq, chunk)
    qkv_and_output_bytes = chunk_num * (4 * chunk * hidden * 2)
    causal_mask_bytes = chunk * chunk * 2 if include_mask_bytes else 0
    return batch * heads * qkv_and_output_bytes + causal_mask_bytes


def _make_normalized_qkv(shape):
    """Draw random q, k, v of ``shape`` on the NPU and L2-normalise q and k.

    The tensors are sampled in q, k, v order so a fixed seed gives the same
    values as sampling them one by one. ``v`` is returned un-normalised.
    """
    q, k, v = (torch.randn(shape, device="npu", dtype=DTYPE) for _ in range(3))
    # The small epsilon keeps the division finite for an all-zero row.
    q = q / (q.pow(2).sum(dim=-1, keepdim=True).sqrt() + 1e-6)
    k = k / (k.pow(2).sum(dim=-1, keepdim=True).sqrt() + 1e-6)
    return q, k, v


def make_inputs(batch: int, heads: int, seq: int, hidden: int):
    """Return random ``(batch, heads, seq, hidden)`` q, k, v (head-first layout)."""
    return _make_normalized_qkv((batch, heads, seq, hidden))


def make_inputs_seq_first(batch: int, heads: int, seq: int, hidden: int):
    """Return random ``(batch, seq, heads, hidden)`` q, k, v (seq-first layout)."""
    return _make_normalized_qkv((batch, seq, heads, hidden))
jit_compile(src, num_heads=heads, hidden_size=hidden, chunk_size=chunk) + causal_mask = get_causal_mask(chunk, DTYPE, 0) + cache = torch.ones(_DEFAULT_MAX_CACHE_SIZE, dtype=torch.int8, device="npu") + if mask_variant not in {"cached_mask", "fast_onthefly"}: + raise ValueError(f"Unsupported mask_variant: {mask_variant}") + use_fast_mask = mask_variant == "fast_onthefly" + + if not seq_first and not use_g and not varlen_uniform: + q, k, v = make_inputs(batch, heads, seq, hidden) + workspace_1 = torch.zeros((BLOCK_DIM, 2, chunk, chunk), device="npu", dtype=DTYPE) + workspace_2 = torch.zeros((BLOCK_DIM, 2, hidden, hidden), device="npu", dtype=DTYPE) + out = torch.zeros((batch, heads, seq, hidden), device="npu", dtype=DTYPE) + + def launch(): + kernel( + q, + k, + v, + workspace_1, + workspace_2, + causal_mask, + out, + use_fast_mask=use_fast_mask, + block_dim=BLOCK_DIM, + ) + else: + q, k, v = make_inputs_seq_first(batch, heads, seq, hidden) + g = torch.zeros((batch, seq, heads), device="npu", dtype=torch.float32) if use_g else None + cu_seqlens = None + if varlen_uniform: + total_t = batch * seq + cu_seqlens = torch.arange(0, total_t + 1, seq, device="npu", dtype=torch.int32) + q = q.reshape(1, total_t, heads, hidden).contiguous() + k = k.reshape(1, total_t, heads, hidden).contiguous() + v = v.reshape(1, total_t, heads, hidden).contiguous() + if g is not None: + g = g.reshape(1, total_t, heads).contiguous() + batch_for_kernel = batch + else: + batch_for_kernel = batch + + q_scaled, k_scaled = _apply_gating(q, k, g, head_first=False) + h_states = _build_precomputed_h( + k_scaled, + v, + chunk, + head_first=False, + cu_seqlens=cu_seqlens, + ).contiguous() + workspace_1 = torch.zeros((BLOCK_DIM, 2, chunk, chunk), device="npu", dtype=DTYPE) + out = torch.zeros_like(v) + + def launch(): + kernel( + q_scaled, + k_scaled, + v, + workspace_1, + h_states, + causal_mask, + out, + cu_seqlens=cu_seqlens, + seq_first=True, + use_precomputed_h=True, + use_fast_mask=use_fast_mask, + 
def _build_arg_parser():
    """Create the command-line parser for the benchmark script."""
    parser = argparse.ArgumentParser(
        description="Benchmark the standalone PTO-ISA linear attention kernel."
    )
    parser.add_argument("--warmup", type=int, default=5)
    parser.add_argument("--repeats", type=int, default=20)
    parser.add_argument(
        "--shapes",
        type=str,
        default="",
        help="Semicolon-separated BxHxLxDxC list, e.g. 16x20x1024x128x64;8x20x4096x128x64",
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Run a shorter preset shape list.",
    )
    parser.add_argument(
        "--throughput-hunt",
        action="store_true",
        help="Run a larger-shape preset to search for higher steady-state utilization.",
    )
    parser.add_argument("--seq-first", action="store_true", help="Benchmark native (B, T, H, D) mode.")
    parser.add_argument("--use-g", action="store_true", help="Benchmark gated mode using uniform zero gate.")
    parser.add_argument(
        "--varlen-uniform",
        action="store_true",
        help="Benchmark the seq-first varlen path with uniform cu_seqlens.",
    )
    parser.add_argument(
        "--mask-variant",
        choices=["cached_mask", "fast_onthefly", "both"],
        default="cached_mask",
        help="Choose cached-mask, fast on-the-fly, or run both.",
    )
    return parser


def main():
    """Parse the CLI, benchmark every requested shape, and print a summary.

    Shape selection precedence is ``--shapes``, then ``--throughput-hunt``,
    then ``--quick``, then the default preset. Invalid arguments are reported
    through ``argparse`` (exit status 2) before any device work or kernel
    compilation starts.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Fail fast: median() of zero samples would otherwise raise only after the
    # kernel has been compiled and warmed up.
    if args.repeats < 1:
        parser.error("--repeats must be at least 1.")
    if args.warmup < 0:
        parser.error("--warmup must be non-negative.")

    if args.shapes:
        try:
            shapes = parse_shapes(args.shapes)
        except ValueError as exc:
            # Turn a malformed shape string into a normal usage error.
            parser.error(f"invalid --shapes value: {exc}")
    elif args.throughput_hunt:
        shapes = THROUGHPUT_HUNT_SHAPES
    elif args.quick:
        shapes = QUICK_SHAPES
    else:
        shapes = DEFAULT_SHAPES

    torch.manual_seed(0)
    torch.npu.set_device("npu:0")

    src = os.path.join(os.path.dirname(os.path.abspath(__file__)), "linear_attention.cpp")

    header = (
        f"{'mask variant':>18} {'shape (B,H,L,D,C)':>24} {'ms':>9} {'TFLOP/s':>10} {'GiB/s':>10}"
    )
    print(header)
    print("-" * len(header))

    results = []
    mask_variants = (
        ["cached_mask", "fast_onthefly"]
        if args.mask_variant == "both"
        else [args.mask_variant]
    )

    for batch, heads, seq, hidden, chunk in shapes:
        for mask_variant in mask_variants:
            print(f"Running {batch}x{heads}x{seq}x{hidden}x{chunk} [{mask_variant}] ...")
            result = benchmark_shape(
                src,
                batch=batch,
                heads=heads,
                seq=seq,
                hidden=hidden,
                chunk=chunk,
                warmup=args.warmup,
                repeats=args.repeats,
                # The varlen benchmark only exists for the seq-first layout.
                seq_first=args.seq_first or args.varlen_uniform,
                use_g=args.use_g,
                varlen_uniform=args.varlen_uniform,
                mask_variant=mask_variant,
            )
            results.append(result)
            print(
                f"{result['mask_variant']:>18} "
                f"{str(result['shape']):>24} "
                f"{result['median_ms']:>9.3f} "
                f"{result['tflops']:>10.2f} "
                f"{result['gib_s']:>10.2f}"
            )

    if results:
        best_tflops = max(results, key=lambda x: x["tflops"])
        best_bw = max(results, key=lambda x: x["gib_s"])
        print("\nBest throughput:")
        print(
            f"  TFLOP/s: {best_tflops['tflops']:.2f} at shape {best_tflops['shape']}"
        )
        print(f"  GiB/s:   {best_bw['gib_s']:.2f} at shape {best_bw['shape']}")


if __name__ == "__main__":
    main()
f"-I{PTO_LIB_PATH}/include", + f"-I{ASCEND_TOOLKIT_HOME}/include", + f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc", + f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc/runtime", + f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc/profiling", + f"-DLINEAR_ATTN_H={num_heads}", + f"-DLINEAR_ATTN_D={hidden_size}", + f"-DLINEAR_ATTN_C={chunk_size}", + ] + + command = ["bisheng", *flags, kernel_cpp, "-o", lib_path] + if verbose: + print("compile command:", " ".join(command)) + + try: + subprocess.run(command, timeout=timeout, check=True) + except Exception as exc: + raise RuntimeError(f"Compile failed: {exc}") from exc + + if verbose: + print(f"generated {lib_path}") + return lib_path + + +def torch_to_ctypes(tensor: torch.Tensor) -> ctypes.c_void_p: + return ctypes.c_void_p(tensor.data_ptr()) + + +def optional_torch_to_ctypes(tensor: torch.Tensor | None) -> ctypes.c_void_p: + if tensor is None: + return ctypes.c_void_p() + return torch_to_ctypes(tensor) + + +@lru_cache(maxsize=None) +def get_causal_mask(chunk_size: int, dtype: torch.dtype, device_index: int): + vec_num = 2 + if chunk_size % vec_num != 0: + raise ValueError("chunk_size must be divisible by 2 for the causal mask.") + half_chunk = chunk_size // vec_num + mask = torch.zeros( + (vec_num, half_chunk, chunk_size), + device=f"npu:{device_index}", + dtype=dtype, + ) + for vid in range(vec_num): + rows = torch.arange(vid * half_chunk, (vid + 1) * half_chunk, device=mask.device) + cols = torch.arange(chunk_size, device=mask.device) + mask[vid] = (rows[:, None] >= cols[None, :]).to(dtype) + return mask.contiguous() + + +def load_lib(lib_path: str): + lib = ctypes.CDLL(os.path.abspath(lib_path)) + + lib.call_kernel.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ] + lib.call_kernel.restype = None + + def 
@lru_cache(maxsize=None)
def jit_compile(
    src_path: str,
    num_heads: int,
    hidden_size: int,
    chunk_size: int,
    verbose: bool = True,
    clean_up: bool = False,
):
    """Compile the kernel source for one static configuration and load it.

    Results are memoised on the full argument tuple, so repeating a call with
    identical arguments returns the already-loaded callable.

    Args:
        src_path: path to the kernel ``.cpp`` file.
        num_heads: head count forwarded to ``compile_cpp``.
        hidden_size: hidden size forwarded to ``compile_cpp``.
        chunk_size: chunk size forwarded to ``compile_cpp``.
        verbose: forwarded to ``compile_cpp``.
        clean_up: when true, delete the shared object once it has been loaded.

    Returns:
        The callable produced by ``load_lib`` for the compiled library.
    """
    shared_object = compile_cpp(
        src_path,
        num_heads=num_heads,
        hidden_size=hidden_size,
        chunk_size=chunk_size,
        verbose=verbose,
    )
    kernel_fn = load_lib(shared_object)
    if not clean_up:
        return kernel_fn
    # The library is already loaded at this point, so the file can go.
    os.remove(shared_object)
    return kernel_fn
LINEAR_ATTN_H +#define LINEAR_ATTN_H 2 +#endif + +#ifndef LINEAR_ATTN_D +#define LINEAR_ATTN_D 128 +#endif + +#ifndef LINEAR_ATTN_C +#define LINEAR_ATTN_C 64 +#endif + +template +using L1Mat = Tile; + +template +using L1MatTrans = + Tile; + +template +using UbVec = Tile; + +// PTO 8.5.0 bakes `diagonal` into the TTRI template arguments, while +// pto-isa-master passes it as a runtime argument. Keep one call site that +// accepts either form so the example builds against both header versions. +template +AICORE inline auto TTriCompatImpl(TileData &dst, int diagonal_value, int) + -> decltype(TTRI(dst, diagonal_value), void()) { + TTRI(dst, diagonal_value); +} + +template +AICORE inline auto TTriCompatImpl(TileData &dst, int, long) + -> decltype(TTRI(dst), void()) { + TTRI(dst); +} + +template +AICORE inline void TTriCompat(TileData &dst) { + TTriCompatImpl(dst, diagonal, 0); +} + +template +AICORE inline void SetCrossFlag(int32_t flag, int32_t mode) { + const int config = 1 | (mode << 4) | (flag << 8); + ffts_cross_core_sync(Pipe, config); +} + +AICORE inline void WaitCrossFlag(int32_t flag) { wait_flag_dev(flag); } + +template +AICORE inline void SetFlag(uint32_t id) { + set_flag(Src, Dst, static_cast(id)); +} + +template +AICORE inline void WaitFlag(uint32_t id) { + wait_flag(Src, Dst, static_cast(id)); +} + +template +AICORE inline void BuildLowerTriMask(TileData &mask_tile, int64_t vector_id) { + if (vector_id == 0) { + TTriCompat(mask_tile); + } else { + TTriCompat( + mask_tile); + } + pipe_barrier(PIPE_ALL); +} + +template +AICORE inline void MatmulL1( + TileAcc &dst, + std::conditional_t, L1Mat> &a_l1, + std::conditional_t, L1Mat> &b_l1, + bool init) { + if constexpr ((K % 64 == 0) && (K > 64)) { + constexpr int KStep = 64; + constexpr int Parts = K / KStep; + constexpr uintptr_t AStepBytes = M * KStep * sizeof(half); + constexpr uintptr_t BStepBytes = KStep * N * sizeof(half); + + TileLeft a_l0[2]; + TileRight b_l0[2]; + TASSIGN(a_l0[0], static_cast(0)); + 
TASSIGN(a_l0[1], AStepBytes); + TASSIGN(b_l0[0], static_cast(0)); + TASSIGN(b_l0[1], BStepBytes); + + SetFlag(0); + SetFlag(1); + + for (int part = 0; part < Parts; ++part) { + const int buf = part & 1; + WaitFlag(buf); + + if constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0[buf], a_view, 0, part * KStep); + } else { + TEXTRACT(a_l0[buf], a_l1, 0, part * KStep); + } + + if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0[buf], b_view, part * KStep, 0); + } else { + TEXTRACT(b_l0[buf], b_l1, part * KStep, 0); + } + + SetFlag(buf); + WaitFlag(buf); + + if (init && part == 0) { + TMATMUL(dst, a_l0[buf], b_l0[buf]); + } else { + TMATMUL_ACC(dst, dst, a_l0[buf], b_l0[buf]); + } + + SetFlag(buf); + } + + WaitFlag(0); + WaitFlag(1); + pipe_barrier(PIPE_ALL); + } else { + TileLeft a_l0; + TileRight b_l0; + TASSIGN(a_l0, 0x0); + TASSIGN(b_l0, 0x0); + + if constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0, a_view, 0, 0); + } else { + TEXTRACT(a_l0, a_l1, 0, 0); + } + + if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0, b_view, 0, 0); + } else { + TEXTRACT(b_l0, b_l1, 0, 0); + } + + pipe_barrier(PIPE_ALL); + if (init) { + TMATMUL(dst, a_l0, b_l0); + } else { + TMATMUL_ACC(dst, dst, a_l0, b_l0); + } + pipe_barrier(PIPE_ALL); + } +} + +struct LinearAttnSeqInfo { + uint32_t bos; + uint32_t seq_len; + uint32_t chunk_offset; + uint32_t token_base_offset; + uint32_t row_stride; +}; + +AICORE inline uint32_t DivCeilU32(uint32_t x, uint32_t y) { + return (x + y - 1) / y; +} + +AICORE inline LinearAttnSeqInfo GetLinearAttnSeqInfo( + uint32_t seq_idx, uint32_t head_idx, uint32_t num_heads, + uint32_t hidden_size, uint32_t chunk_size, uint32_t fixed_seq_len, + bool seq_first, __gm__ int32_t *cu_seqlens) { + if (!seq_first) { + const uint32_t chunk_num = DivCeilU32(fixed_seq_len, chunk_size); + return { + seq_idx * fixed_seq_len, + 
fixed_seq_len, + seq_idx * chunk_num, + ((seq_idx * num_heads + head_idx) * fixed_seq_len) * hidden_size, + hidden_size, + }; + } + + if (cu_seqlens == nullptr) { + const uint32_t bos = seq_idx * fixed_seq_len; + const uint32_t chunk_num = DivCeilU32(fixed_seq_len, chunk_size); + return { + bos, + fixed_seq_len, + seq_idx * chunk_num, + bos * num_heads * hidden_size + head_idx * hidden_size, + num_heads * hidden_size, + }; + } + + uint32_t bos = 0; + uint32_t chunk_offset = 0; + for (uint32_t i = 0; i < seq_idx; ++i) { + const uint32_t seq_start = static_cast(cu_seqlens[i]); + const uint32_t seq_end = static_cast(cu_seqlens[i + 1]); + chunk_offset += DivCeilU32(seq_end - seq_start, chunk_size); + } + bos = static_cast(cu_seqlens[seq_idx]); + const uint32_t eos = static_cast(cu_seqlens[seq_idx + 1]); + return { + bos, + eos - bos, + chunk_offset, + bos * num_heads * hidden_size + head_idx * hidden_size, + num_heads * hidden_size, + }; +} + +template +AICORE void main_kernel_precomputed(__gm__ half *q, __gm__ half *k, + __gm__ half *v, __gm__ half *workspace_1, + __gm__ half *h, __gm__ half *causal_mask, + __gm__ half *o, __gm__ int32_t *cu_seqlens, + int64_t batch_size, int64_t seq_len, + bool seq_first, uint32_t use_fast_mask, + uint64_t ffts_addr) { + constexpr int32_t StageCount = 2; + constexpr bool UseTwoStagePipeline = (ChunkSize >= 128); + constexpr bool InplaceMaskApply = true; + constexpr int32_t HalfChunk = ChunkSize / 2; + constexpr int32_t ChunkElems = ChunkSize * HiddenSize; + constexpr int32_t Workspace1SlotElems = ChunkSize * ChunkSize; + constexpr int32_t Workspace1Elems = StageCount * Workspace1SlotElems; + constexpr int32_t HiddenElems = HiddenSize * HiddenSize; + + constexpr int32_t QL1Addr = 0; + constexpr int32_t KL1Addr = QL1Addr + ChunkElems * sizeof(half); + constexpr int32_t VL1Addr = KL1Addr + ChunkElems * sizeof(half); + constexpr int32_t HL1Addr = VL1Addr + ChunkElems * sizeof(half); + constexpr int32_t AccL1Addr = HL1Addr + HiddenElems * 
sizeof(half); + constexpr int32_t HNextL1Addr = AccL1Addr + Workspace1SlotElems * sizeof(half); + constexpr int32_t SharedL0Addr = 0; + constexpr int32_t AccUbAddr = 0; + constexpr int32_t MaskUbAddr = AccUbAddr + HalfChunk * ChunkSize * sizeof(half); + + using ChunkGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using ChunkGlobalDynShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>; + using ChunkGlobalDynStride = Stride<1, 1, 1, DYNAMIC, 1>; + using ChunkGlobalDyn = + GlobalTensor; + using SeqChunkGlobalStride = Stride<1, 1, 1, -1, 1>; + using SeqChunkGlobal = + GlobalTensor, + SeqChunkGlobalStride, Layout::ND>; + using AccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfAccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfMaskGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using OutGlobalDyn = + GlobalTensor; + + using ChunkL1Dyn = Tile; + using OutL0Dyn = + TileAcc; + + const int64_t total_work = batch_size * NumHeads; + const int64_t cid = get_block_idx(); + const int64_t vid = get_subblockid(); + set_ffts_base_addr(ffts_addr); + + L1Mat q_l1; + L1Mat k_l1; + L1Mat v_l1; + L1Mat h_l1; + L1Mat h_next_l1; + L1Mat acc_l1; + TASSIGN(q_l1, QL1Addr); + TASSIGN(k_l1, KL1Addr); + TASSIGN(v_l1, VL1Addr); + TASSIGN(h_l1, HL1Addr); + TASSIGN(h_next_l1, HNextL1Addr); + TASSIGN(acc_l1, AccL1Addr); + + TileAcc acc_l0; + TileAcc o_l0; + TASSIGN(acc_l0, SharedL0Addr); + TASSIGN(o_l0, SharedL0Addr); + + UbVec acc_ub; + UbVec mask_ub; + TASSIGN(acc_ub, AccUbAddr); + TASSIGN(mask_ub, MaskUbAddr); + +#if defined(__DAV_C220_CUBE__) + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const uint32_t by = static_cast(pid % NumHeads); + const uint32_t bz = static_cast(pid / NumHeads); + const LinearAttnSeqInfo seq_info = + 
GetLinearAttnSeqInfo(bz, by, NumHeads, HiddenSize, ChunkSize, + static_cast(seq_len), seq_first, + cu_seqlens); + const uint32_t chunk_num = DivCeilU32(seq_info.seq_len, ChunkSize); + const int64_t workspace1_base = cid * Workspace1Elems; + + if constexpr (UseTwoStagePipeline) { + const int32_t flag_base = static_cast((work_idx & 3) * 6); + int32_t h_buf = 0; + WaitCrossFlag(flag_base + 4); + + { + const uint32_t token_offset = seq_info.token_base_offset; + const uint32_t valid_rows = + min(seq_info.seq_len, static_cast(ChunkSize)); + if (valid_rows == ChunkSize && seq_info.row_stride == HiddenSize) { + ChunkGlobal q_global(q + token_offset); + ChunkGlobal k_global(k + token_offset); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + } else if (valid_rows == ChunkSize) { + SeqChunkGlobal q_global(q + token_offset, {}, + {static_cast(seq_info.row_stride)}); + SeqChunkGlobal k_global(k + token_offset, {}, + {static_cast(seq_info.row_stride)}); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + } else { + ChunkL1Dyn q_dyn(valid_rows, HiddenSize); + ChunkL1Dyn k_dyn(valid_rows, HiddenSize); + TASSIGN(q_dyn, QL1Addr); + TASSIGN(k_dyn, KL1Addr); + ChunkGlobalDyn q_global_dyn(q + token_offset, + {1, 1, 1, static_cast(valid_rows), + HiddenSize}, + {1, 1, 1, + static_cast(seq_info.row_stride), 1}); + ChunkGlobalDyn k_global_dyn(k + token_offset, + {1, 1, 1, static_cast(valid_rows), + HiddenSize}, + {1, 1, 1, + static_cast(seq_info.row_stride), 1}); + TLOAD(q_dyn, q_global_dyn); + TLOAD(k_dyn, k_global_dyn); + } + HiddenGlobal h_global(h + (seq_info.chunk_offset * NumHeads + by) * HiddenElems); + TLOAD(h_l1, h_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base, 2); + } + + for (uint32_t i = 0; i < chunk_num; ++i) { + const int32_t slot = static_cast(i & 1); + const int32_t next_slot = slot ^ 1; + const uint32_t 
row_start = i * ChunkSize; + const uint32_t valid_rows = + min(static_cast(seq_info.seq_len - row_start), + static_cast(ChunkSize)); + const uint32_t token_offset = seq_info.token_base_offset + + row_start * seq_info.row_stride; + + if (i + 1 < chunk_num) { + const uint32_t next_row_start = (i + 1) * ChunkSize; + const uint32_t next_valid_rows = + min(static_cast(seq_info.seq_len - next_row_start), + static_cast(ChunkSize)); + const uint32_t next_token_offset = seq_info.token_base_offset + + next_row_start * seq_info.row_stride; + const int64_t next_workspace1_base = + workspace1_base + next_slot * Workspace1SlotElems; + + if (next_valid_rows == ChunkSize && seq_info.row_stride == HiddenSize) { + ChunkGlobal q_global(q + next_token_offset); + ChunkGlobal k_global(k + next_token_offset); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + } else if (next_valid_rows == ChunkSize) { + SeqChunkGlobal q_global(q + next_token_offset, {}, + {static_cast(seq_info.row_stride)}); + SeqChunkGlobal k_global(k + next_token_offset, {}, + {static_cast(seq_info.row_stride)}); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + } else { + ChunkL1Dyn q_dyn(next_valid_rows, HiddenSize); + ChunkL1Dyn k_dyn(next_valid_rows, HiddenSize); + TASSIGN(q_dyn, QL1Addr); + TASSIGN(k_dyn, KL1Addr); + ChunkGlobalDyn q_global_dyn( + q + next_token_offset, + {1, 1, 1, static_cast(next_valid_rows), HiddenSize}, + {1, 1, 1, static_cast(seq_info.row_stride), 1}); + ChunkGlobalDyn k_global_dyn( + k + next_token_offset, + {1, 1, 1, static_cast(next_valid_rows), HiddenSize}, + {1, 1, 1, static_cast(seq_info.row_stride), 1}); + TLOAD(q_dyn, q_global_dyn); + TLOAD(k_dyn, k_global_dyn); + } + pipe_barrier(PIPE_ALL); + + MatmulL1( + acc_l0, q_l1, k_l1, true); + AccGlobal acc_global(workspace_1 + next_workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + next_slot, 2); + } + + WaitCrossFlag(flag_base + 2 + slot); + AccGlobal masked_acc_global(workspace_1 + 
workspace1_base + + slot * Workspace1SlotElems); + TLOAD(acc_l1, masked_acc_global); + + if (valid_rows == ChunkSize && seq_info.row_stride == HiddenSize) { + ChunkGlobal q_global(q + token_offset); + ChunkGlobal v_global(v + token_offset); + TLOAD(q_l1, q_global); + TLOAD(v_l1, v_global); + } else if (valid_rows == ChunkSize) { + SeqChunkGlobal q_global(q + token_offset, {}, + {static_cast(seq_info.row_stride)}); + SeqChunkGlobal v_global(v + token_offset, {}, + {static_cast(seq_info.row_stride)}); + TLOAD(q_l1, q_global); + TLOAD(v_l1, v_global); + } else { + ChunkL1Dyn q_dyn(valid_rows, HiddenSize); + ChunkL1Dyn v_dyn(valid_rows, HiddenSize); + TASSIGN(q_dyn, QL1Addr); + TASSIGN(v_dyn, VL1Addr); + ChunkGlobalDyn q_global_dyn( + q + token_offset, + {1, 1, 1, static_cast(valid_rows), HiddenSize}, + {1, 1, 1, static_cast(seq_info.row_stride), 1}); + ChunkGlobalDyn v_global_dyn( + v + token_offset, + {1, 1, 1, static_cast(valid_rows), HiddenSize}, + {1, 1, 1, static_cast(seq_info.row_stride), 1}); + TLOAD(q_dyn, q_global_dyn); + TLOAD(v_dyn, v_global_dyn); + } + + if (i + 1 < chunk_num) { + HiddenGlobal next_h_global( + h + ((seq_info.chunk_offset + i + 1) * NumHeads + by) * HiddenElems); + if (h_buf == 0) { + TLOAD(h_next_l1, next_h_global); + } else { + TLOAD(h_l1, next_h_global); + } + } + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, v_l1, + true); + if (h_buf == 0) { + MatmulL1(o_l0, q_l1, + h_l1, false); + } else { + MatmulL1(o_l0, q_l1, + h_next_l1, + false); + } + + if (valid_rows == ChunkSize && seq_info.row_stride == HiddenSize) { + ChunkGlobal o_global(o + token_offset); + TSTORE(o_global, o_l0); + } else if (valid_rows == ChunkSize) { + SeqChunkGlobal o_global(o + token_offset, {}, + {static_cast(seq_info.row_stride)}); + TSTORE(o_global, o_l0); + } else { + OutL0Dyn o_tail(valid_rows, HiddenSize); + TASSIGN(o_tail, SharedL0Addr); + OutGlobalDyn o_global_dyn(o + token_offset, + {1, 1, 1, static_cast(valid_rows), + HiddenSize}, + {1, 1, 1, + 
static_cast(seq_info.row_stride), 1}); + TSTORE(o_global_dyn, o_tail); + } + pipe_barrier(PIPE_ALL); + + if (i + 1 < chunk_num) { + h_buf ^= 1; + } + } + SetCrossFlag(flag_base + 5, 2); + } else { + for (uint32_t i = 0; i < chunk_num; ++i) { + const uint32_t row_start = i * ChunkSize; + const uint32_t valid_rows = + min(static_cast(seq_info.seq_len - row_start), + static_cast(ChunkSize)); + const uint32_t token_offset = seq_info.token_base_offset + + row_start * seq_info.row_stride; + + if (valid_rows == ChunkSize && seq_info.row_stride == HiddenSize) { + ChunkGlobal q_global(q + token_offset); + ChunkGlobal k_global(k + token_offset); + ChunkGlobal v_global(v + token_offset); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + } else if (valid_rows == ChunkSize) { + SeqChunkGlobal q_global(q + token_offset, {}, + {static_cast(seq_info.row_stride)}); + SeqChunkGlobal k_global(k + token_offset, {}, + {static_cast(seq_info.row_stride)}); + SeqChunkGlobal v_global(v + token_offset, {}, + {static_cast(seq_info.row_stride)}); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + } else { + ChunkL1Dyn q_dyn(valid_rows, HiddenSize); + ChunkL1Dyn k_dyn(valid_rows, HiddenSize); + ChunkL1Dyn v_dyn(valid_rows, HiddenSize); + TASSIGN(q_dyn, QL1Addr); + TASSIGN(k_dyn, KL1Addr); + TASSIGN(v_dyn, VL1Addr); + ChunkGlobalDyn q_global_dyn( + q + token_offset, + {1, 1, 1, static_cast(valid_rows), HiddenSize}, + {1, 1, 1, static_cast(seq_info.row_stride), 1}); + ChunkGlobalDyn k_global_dyn( + k + token_offset, + {1, 1, 1, static_cast(valid_rows), HiddenSize}, + {1, 1, 1, static_cast(seq_info.row_stride), 1}); + ChunkGlobalDyn v_global_dyn( + v + token_offset, + {1, 1, 1, static_cast(valid_rows), HiddenSize}, + {1, 1, 1, static_cast(seq_info.row_stride), 1}); + TLOAD(q_dyn, q_global_dyn); + TLOAD(k_dyn, k_global_dyn); + TLOAD(v_dyn, v_global_dyn); + } + + HiddenGlobal h_global( + h + ((seq_info.chunk_offset + i) * NumHeads + by) * 
HiddenElems); + TLOAD(h_l1, h_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(0, 2); + + WaitCrossFlag(1); + TLOAD(acc_l1, acc_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, v_l1, + true); + MatmulL1(o_l0, q_l1, h_l1, + false); + + if (valid_rows == ChunkSize && seq_info.row_stride == HiddenSize) { + ChunkGlobal o_global(o + token_offset); + TSTORE(o_global, o_l0); + } else if (valid_rows == ChunkSize) { + SeqChunkGlobal o_global(o + token_offset, {}, + {static_cast(seq_info.row_stride)}); + TSTORE(o_global, o_l0); + } else { + OutL0Dyn o_tail(valid_rows, HiddenSize); + TASSIGN(o_tail, SharedL0Addr); + OutGlobalDyn o_global_dyn(o + token_offset, + {1, 1, 1, static_cast(valid_rows), + HiddenSize}, + {1, 1, 1, + static_cast(seq_info.row_stride), 1}); + TSTORE(o_global_dyn, o_tail); + } + pipe_barrier(PIPE_ALL); + } + } + } +#endif + +#if defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + HalfMaskGlobal mask_global(causal_mask + vid * HalfChunk * ChunkSize); + if (use_fast_mask != 0) { + BuildLowerTriMask(mask_ub, vid); + } else { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const uint32_t by = static_cast(pid % NumHeads); + const uint32_t bz = static_cast(pid / NumHeads); + const LinearAttnSeqInfo seq_info = + GetLinearAttnSeqInfo(bz, by, NumHeads, HiddenSize, ChunkSize, + static_cast(seq_len), seq_first, + cu_seqlens); + const uint32_t chunk_num = DivCeilU32(seq_info.seq_len, ChunkSize); + const int64_t workspace1_base = cid * Workspace1Elems; + + if constexpr (UseTwoStagePipeline) { + const int32_t flag_base = static_cast((work_idx & 3) * 6); + SetCrossFlag(flag_base 
+ 4, 2); + for (uint32_t i = 0; i < chunk_num; ++i) { + const int32_t slot = static_cast(i & 1); + WaitCrossFlag(flag_base + slot); + HalfAccGlobal acc_global(workspace_1 + workspace1_base + + slot * Workspace1SlotElems + + vid * HalfChunk * ChunkSize); + TLOAD(acc_ub, acc_global); + pipe_barrier(PIPE_ALL); + if constexpr (InplaceMaskApply) { + TMUL(acc_ub, acc_ub, mask_ub); + } + pipe_barrier(PIPE_ALL); + TSTORE(acc_global, acc_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + 2 + slot, 2); + } + WaitCrossFlag(flag_base + 5); + } else { + for (uint32_t i = 0; i < chunk_num; ++i) { + WaitCrossFlag(0); + HalfAccGlobal acc_global(workspace_1 + workspace1_base + + vid * HalfChunk * ChunkSize); + TLOAD(acc_ub, acc_global); + pipe_barrier(PIPE_ALL); + TMUL(acc_ub, acc_ub, mask_ub); + pipe_barrier(PIPE_ALL); + TSTORE(acc_global, acc_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + } + } + } +#endif +} +/* main_kernel: fused linear-attention kernel for the dense layout implied by qkv_base below, i.e. (batch, head, seq, hidden). Each core walks work items pid (work_idx times block_num, offset by cid), one per (batch, head) pair. The cube section forms the per-chunk score tile from Q and K and the chunk state tile from K and V; the vector section applies the causal mask and accumulates the running hidden state; the cube section then finishes O from the masked scores with V and from Q with the carried state. Tiles are exchanged through workspace_1 (scores) and workspace_2 (state) and ordered by cross-core flags. ChunkSize >= 128 selects the two-slot pipelined variant and in-place masking. The trailing bool of MatmulL1 presumably selects fresh versus accumulating output; TODO confirm against the helper. */ +template +AICORE void main_kernel(__gm__ half *q, __gm__ half *k, __gm__ half *v, + __gm__ half *workspace_1, __gm__ half *workspace_2, + __gm__ half *causal_mask, __gm__ half *o, + int64_t batch_size, int64_t seq_len, uint32_t use_fast_mask, + uint64_t ffts_addr) { + constexpr int32_t StageCount = 2; + constexpr bool UseTwoStagePipeline = (ChunkSize >= 128); + constexpr bool InplaceMaskApply = (ChunkSize >= 128); + constexpr int32_t VecNum = 2; + constexpr int32_t HalfChunk = ChunkSize / VecNum; + constexpr int32_t HalfHidden = HiddenSize / VecNum; + constexpr int32_t ChunkElems = ChunkSize * HiddenSize; + constexpr int32_t Workspace1SlotElems = ChunkSize * ChunkSize; + constexpr int32_t Workspace2SlotElems = HiddenSize * HiddenSize; + constexpr int32_t Workspace1Elems = StageCount * Workspace1SlotElems; + constexpr int32_t Workspace2Elems = StageCount * Workspace2SlotElems; + + constexpr int32_t QL1Addr = 0; + constexpr int32_t KL1Addr = QL1Addr + ChunkElems * sizeof(half); + constexpr int32_t VL1Addr = KL1Addr + ChunkElems * sizeof(half); + constexpr 
int32_t HL1Addr = VL1Addr + ChunkElems * sizeof(half); + constexpr int32_t AccL1Addr = HL1Addr + Workspace2SlotElems * sizeof(half); + constexpr int32_t HNextL1Addr = AccL1Addr + Workspace1SlotElems * sizeof(half); + /* acc_l0, h_l0 and o_l0 are all assigned to this single address below, so only one of them holds live data at a time. */ + constexpr int32_t SharedL0Addr = 0; + /* UB layout in bytes: hsum tile, then acc tile, then h tile, then (unless aliased) the mask tile. PreloadMask holds when a dedicated mask tile still fits in 72 KiB; otherwise AliasMaskIntoH places the mask at HUbAddr and the vector loop reloads it from causal_mask every chunk. NOTE(review): when use_fast_mask != 0 and PreloadMask is false, mask_ub is built once by BuildLowerTriMask but shares HUbAddr with h_ub, which TLOAD(h_ub, ...) overwrites each chunk, and the per-chunk reload is skipped for the fast mask. Looks like a latent bug for such tile sizes; confirm. */ + constexpr int32_t HsumUbAddr = 0; + constexpr int32_t AccUbAddr = + HsumUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t HUbAddr = AccUbAddr + HalfChunk * ChunkSize * sizeof(half); + constexpr int32_t RawUBBytes = + (HalfHidden * HiddenSize + HalfChunk * ChunkSize + HalfHidden * HiddenSize + + HalfChunk * ChunkSize + + (InplaceMaskApply ? 0 : HalfChunk * ChunkSize)) * + sizeof(half); + constexpr bool PreloadMask = RawUBBytes <= 72 * 1024; + constexpr bool AliasMaskIntoH = + !PreloadMask && (HalfHidden * HiddenSize >= HalfChunk * ChunkSize); + constexpr int32_t MaskUbAddr = + AliasMaskIntoH ? HUbAddr : HUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t MaskedAccUbAddr = + InplaceMaskApply ? AccUbAddr : MaskUbAddr + HalfChunk * ChunkSize * sizeof(half); + + constexpr int32_t L0CBytes = + (Workspace2SlotElems > Workspace1SlotElems + ? (Workspace2SlotElems > ChunkElems ? Workspace2SlotElems : ChunkElems) + : (Workspace1SlotElems > ChunkElems ? Workspace1SlotElems : ChunkElems)) * + sizeof(float); + constexpr int32_t UBBytes = + (HalfHidden * HiddenSize + HalfChunk * ChunkSize + + (AliasMaskIntoH ? HalfHidden * HiddenSize + : HalfHidden * HiddenSize + HalfChunk * ChunkSize) + + (InplaceMaskApply ? 0 : HalfChunk * ChunkSize)) * + sizeof(half); + constexpr int32_t L1Bytes = + UseTwoStagePipeline ? 
(HNextL1Addr + Workspace2SlotElems * sizeof(half)) + : (AccL1Addr + Workspace1SlotElems * sizeof(half)); + static_assert((HiddenSize % 2) == 0, "HiddenSize must be even."); + static_assert((ChunkSize % 2) == 0, "ChunkSize must be even."); + static_assert(L0CBytes <= 112 * 1024, + "Tile sizes exceed the validated L0C budget for this minimum kernel."); + static_assert(L1Bytes <= 192 * 1024, + "Tile sizes exceed the validated L1 budget for this minimum kernel."); + static_assert(PreloadMask || AliasMaskIntoH, + "Current minimum kernel requires either a preloaded mask or H UB large enough to alias the mask."); + static_assert(UBBytes <= 72 * 1024, + "Tile sizes exceed the validated UB budget for this minimum kernel."); + + using ChunkGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using AccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfAccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfHiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfMaskGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + + const int64_t total_work = batch_size * NumHeads; + const int64_t chunk_num = seq_len / ChunkSize; /* Integer division: rows beyond the last full chunk are never processed. NOTE(review): with chunk_num == 0 the two-stage cube prologue still loads chunk 0 and sets the first flag while the vector loop never waits on it; confirm callers guarantee seq_len is a positive multiple of ChunkSize. */ + const int64_t cid = get_block_idx(); + const int64_t vid = get_subblockid(); + set_ffts_base_addr(ffts_addr); + + L1Mat q_l1; + L1Mat k_l1; + L1Mat v_l1; + L1Mat h_l1; + L1Mat h_next_l1; + L1Mat acc_l1; + TASSIGN(q_l1, QL1Addr); + TASSIGN(k_l1, KL1Addr); + TASSIGN(v_l1, VL1Addr); + TASSIGN(h_l1, HL1Addr); + TASSIGN(h_next_l1, HNextL1Addr); + TASSIGN(acc_l1, AccL1Addr); + + TileAcc acc_l0; + TileAcc h_l0; + TileAcc o_l0; + TASSIGN(acc_l0, SharedL0Addr); + TASSIGN(h_l0, SharedL0Addr); + TASSIGN(o_l0, SharedL0Addr); + + UbVec hsum_ub; + UbVec h_ub; + UbVec acc_ub; + UbVec mask_ub; + UbVec masked_acc_ub; + TASSIGN(hsum_ub, HsumUbAddr); + TASSIGN(acc_ub, AccUbAddr); + TASSIGN(h_ub, HUbAddr); + TASSIGN(mask_ub, MaskUbAddr); + 
TASSIGN(masked_acc_ub, MaskedAccUbAddr); +/* Cube section. Flag offsets are relative to flag_base, which rotates over four groups of six flags by work_idx. Two-stage path: once the vector side signals flag offset 4, the prologue produces chunk 0 into slot 0. Each loop iteration first produces the following chunk into the other slot, then waits on flag offset 2 plus slot for the masked scores of chunk i, prefetches the next carried state into whichever of h_l1 or h_next_l1 is idle (tracked by h_buf), and stores O for chunk i. Flag offset 5 closes the work item. */ +#if defined(__DAV_C220_CUBE__) + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t by = pid % NumHeads; + const int64_t bz = pid / NumHeads; + const int64_t qkv_base = ((bz * NumHeads + by) * seq_len) * HiddenSize; + const int64_t workspace1_base = cid * Workspace1Elems; + const int64_t workspace2_base = cid * Workspace2Elems; + + if constexpr (UseTwoStagePipeline) { + const int32_t flag_base = static_cast((work_idx & 3) * 6); + int32_t h_buf = 0; + WaitCrossFlag(flag_base + 4); + HiddenGlobal zero_h_global(workspace_2 + workspace2_base + Workspace2SlotElems); + TLOAD(h_l1, zero_h_global); + pipe_barrier(PIPE_ALL); + + { + const int64_t chunk_base = qkv_base; + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, v_l1, + true); + HiddenGlobal h_out_global(workspace_2 + workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base, 2); + } + + for (int64_t i = 0; i < chunk_num; ++i) { + const int32_t slot = static_cast(i & 1); + const int32_t next_slot = slot ^ 1; + const int64_t chunk_base = qkv_base + i * ChunkElems; + + if (i + 1 < chunk_num) { + const int64_t next_chunk_base = qkv_base + (i + 1) * ChunkElems; + const int64_t next_workspace1_base = + workspace1_base + next_slot * Workspace1SlotElems; + const int64_t next_workspace2_base = + workspace2_base + next_slot * Workspace2SlotElems; + + ChunkGlobal q_global(q + next_chunk_base); + ChunkGlobal k_global(k + next_chunk_base); + 
ChunkGlobal v_global(v + next_chunk_base); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + pipe_barrier(PIPE_ALL); + + MatmulL1( + acc_l0, q_l1, k_l1, true); + AccGlobal acc_global(workspace_1 + next_workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, + v_l1, true); + HiddenGlobal h_out_global(workspace_2 + next_workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + next_slot, 2); + } + + WaitCrossFlag(flag_base + 2 + slot); + AccGlobal masked_acc_global(workspace_1 + workspace1_base + + slot * Workspace1SlotElems); + TLOAD(acc_l1, masked_acc_global); + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal v_global(v + chunk_base); + TLOAD(q_l1, q_global); + TLOAD(v_l1, v_global); + if (i + 1 < chunk_num) { + HiddenGlobal next_h_global(workspace_2 + workspace2_base + + slot * Workspace2SlotElems); + if (h_buf == 0) { + TLOAD(h_next_l1, next_h_global); + } else { + TLOAD(h_l1, next_h_global); + } + } + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, + v_l1, true); + if (h_buf == 0) { + MatmulL1(o_l0, q_l1, + h_l1, false); + } else { + MatmulL1(o_l0, q_l1, + h_next_l1, + false); + } + + ChunkGlobal o_global(o + chunk_base); + TSTORE(o_global, o_l0); + pipe_barrier(PIPE_ALL); + + if (i + 1 < chunk_num) { + h_buf ^= 1; + } + } + SetCrossFlag(flag_base + 5, 2); + } else { + WaitCrossFlag(1); + /* Single-slot path: flag 1 is first awaited for the zeroed state written by the vector side; after that flag 0 hands each chunk to the vector side and flag 1 returns it. */ + for (int64_t i = 0; i < chunk_num; ++i) { + const int64_t chunk_base = qkv_base + i * ChunkElems; + + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + HiddenGlobal h_global(workspace_2 + workspace2_base); + + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + TLOAD(h_l1, h_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + 
MatmulL1(h_l0, k_l1, v_l1, + true); + HiddenGlobal h_out_global(workspace_2 + workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(0, 2); + + WaitCrossFlag(1); + AccGlobal masked_acc_global(workspace_1 + workspace1_base); + TLOAD(acc_l1, masked_acc_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, + v_l1, true); + MatmulL1(o_l0, q_l1, h_l1, + false); + + ChunkGlobal o_global(o + chunk_base); + TSTORE(o_global, o_l0); + pipe_barrier(PIPE_ALL); + } + } + } +#endif +/* Vector section. vid (from get_subblockid) selects which part of each tile this sub-core handles: offsets are vid * HalfChunk * ChunkSize for scores and vid * HalfHidden * HiddenSize for state. hsum_ub is zeroed per work item, published to workspace_2, and then accumulates each chunk state before being written back for the cube side. */ +#if defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + HalfMaskGlobal mask_global(causal_mask + vid * HalfChunk * ChunkSize); + if (use_fast_mask != 0) { + BuildLowerTriMask(mask_ub, vid); + } else if constexpr (PreloadMask) { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t workspace1_base = cid * Workspace1Elems; + const int64_t workspace2_base = cid * Workspace2Elems; + + TEXPANDS(hsum_ub, 0.0f); + pipe_barrier(PIPE_ALL); + if constexpr (UseTwoStagePipeline) { + const int32_t flag_base = static_cast((work_idx & 3) * 6); + HalfHiddenGlobal init_h_global_0(workspace_2 + workspace2_base + + vid * HalfHidden * HiddenSize); + HalfHiddenGlobal init_h_global_1(workspace_2 + workspace2_base + + Workspace2SlotElems + + vid * HalfHidden * HiddenSize); + TSTORE(init_h_global_0, hsum_ub); + TSTORE(init_h_global_1, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + 4, 2); + + for (int64_t i = 0; i < chunk_num; ++i) { + const int32_t slot = static_cast(i & 1); + WaitCrossFlag(flag_base + slot); + + const int64_t slot_workspace1_base = + workspace1_base + slot * Workspace1SlotElems; + const int64_t slot_workspace2_base = + workspace2_base + slot * Workspace2SlotElems; + HalfAccGlobal acc_global(workspace_1 + 
slot_workspace1_base + + vid * HalfChunk * ChunkSize); + HalfHiddenGlobal h_global(workspace_2 + slot_workspace2_base + + vid * HalfHidden * HiddenSize); + TLOAD(acc_ub, acc_global); + TLOAD(h_ub, h_global); + pipe_barrier(PIPE_ALL); + + // Precompute the chunk carry state H_t = sum_{j<=t}(K_j^T V_j) on the + // vector core, then write it back for the cube core output stage. + TADD(hsum_ub, hsum_ub, h_ub); + pipe_barrier(PIPE_ALL); + if ((use_fast_mask == 0) && !PreloadMask) { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + if constexpr (InplaceMaskApply) { + TMUL(acc_ub, acc_ub, mask_ub); + } else { + TMUL(masked_acc_ub, acc_ub, mask_ub); + } + pipe_barrier(PIPE_ALL); + if constexpr (InplaceMaskApply) { + TSTORE(acc_global, acc_ub); + } else { + TSTORE(acc_global, masked_acc_ub); + } + TSTORE(h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + 2 + slot, 2); + } + WaitCrossFlag(flag_base + 5); + } else { + HalfHiddenGlobal init_h_global(workspace_2 + workspace2_base + + vid * HalfHidden * HiddenSize); + TSTORE(init_h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + + for (int64_t i = 0; i < chunk_num; ++i) { + WaitCrossFlag(0); + + HalfAccGlobal acc_global(workspace_1 + workspace1_base + + vid * HalfChunk * ChunkSize); + HalfHiddenGlobal h_global(workspace_2 + workspace2_base + + vid * HalfHidden * HiddenSize); + TLOAD(acc_ub, acc_global); + TLOAD(h_ub, h_global); + pipe_barrier(PIPE_ALL); + + // Precompute the chunk carry state H_t = sum_{j<=t}(K_j^T V_j) on the + // vector core, then write it back for the cube core output stage. 
+ TADD(hsum_ub, hsum_ub, h_ub); + pipe_barrier(PIPE_ALL); + if ((use_fast_mask == 0) && !PreloadMask) { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + if constexpr (InplaceMaskApply) { + TMUL(acc_ub, acc_ub, mask_ub); + } else { + TMUL(masked_acc_ub, acc_ub, mask_ub); + } + pipe_barrier(PIPE_ALL); + if constexpr (InplaceMaskApply) { + TSTORE(acc_global, acc_ub); + } else { + TSTORE(acc_global, masked_acc_ub); + } + TSTORE(h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + } + } + } +#endif +} +/* launch_linear_attention: device entry point. Reinterprets the byte pointers as half pointers and dispatches on use_precomputed_h: nonzero runs main_kernel_precomputed, which also receives cu_seqlens and the flag seq_first != 0; zero runs main_kernel, which takes neither cu_seqlens nor seq_first. Both paths receive use_fast_mask and ffts_addr. */ +extern "C" __global__ AICORE void launch_linear_attention( + __gm__ uint8_t *q, __gm__ uint8_t *k, __gm__ uint8_t *v, + __gm__ uint8_t *workspace_1, __gm__ uint8_t *workspace_2, + __gm__ uint8_t *causal_mask, __gm__ uint8_t *o, + __gm__ int32_t *cu_seqlens, int64_t batch_size, int64_t seq_len, + uint32_t seq_first, uint32_t use_precomputed_h, uint32_t use_fast_mask, + uint64_t ffts_addr) { + if (use_precomputed_h != 0) { + main_kernel_precomputed( + reinterpret_cast<__gm__ half *>(q), reinterpret_cast<__gm__ half *>(k), + reinterpret_cast<__gm__ half *>(v), + reinterpret_cast<__gm__ half *>(workspace_1), + reinterpret_cast<__gm__ half *>(workspace_2), + reinterpret_cast<__gm__ half *>(causal_mask), + reinterpret_cast<__gm__ half *>(o), cu_seqlens, batch_size, seq_len, + seq_first != 0, use_fast_mask, ffts_addr); + return; + } + + main_kernel( + reinterpret_cast<__gm__ half *>(q), reinterpret_cast<__gm__ half *>(k), + reinterpret_cast<__gm__ half *>(v), + reinterpret_cast<__gm__ half *>(workspace_1), + reinterpret_cast<__gm__ half *>(workspace_2), + reinterpret_cast<__gm__ half *>(causal_mask), + reinterpret_cast<__gm__ half *>(o), batch_size, seq_len, use_fast_mask, + ffts_addr); +} +/* call_kernel: launcher exported with C linkage. Obtains the cross-core control address from rtGetC2cCtrlAddr and forwards it as ffts_addr, together with every other argument, to launch_linear_attention. NOTE(review): the result of rtGetC2cCtrlAddr is not checked; on failure ffts_addr would presumably stay 0. Confirm whether that needs handling. */ +extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *q, + uint8_t *k, uint8_t *v, uint8_t *workspace_1, + uint8_t *workspace_2, uint8_t *causal_mask, + uint8_t *o, int32_t *cu_seqlens, + int64_t batch_size, int64_t seq_len, + uint32_t 
seq_first, uint32_t use_precomputed_h, + uint32_t use_fast_mask) { + uint32_t ffts_len = 0; + uint64_t ffts_addr = 0; + rtGetC2cCtrlAddr(&ffts_addr, &ffts_len); + launch_linear_attention<<>>( + q, k, v, workspace_1, workspace_2, causal_mask, o, cu_seqlens, + batch_size, seq_len, seq_first, use_precomputed_h, use_fast_mask, + ffts_addr); +} diff --git a/examples/jit_cpp/linear_attention/optimization_lession.md b/examples/jit_cpp/linear_attention/optimization_lession.md new file mode 100644 index 00000000..abfd7cbc --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimization_lession.md @@ -0,0 +1,374 @@ +# Linear Attention Optimization Lessons + +This note records the optimization lessons learned from the self-contained PTO-ISA examples in this directory. It is meant to be a practical reference for future work on other PTO-ISA kernels, not just for `linear_attention`. + +The file name intentionally matches the requested spelling: `optimization_lession.md`. + +## How To Use This Note + +Use this file as: +- a checklist before starting a new kernel optimization task +- a reminder of which changes gave real speedups here +- a warning list of common correctness failures and deadlock traps +- a template for planning and recording future experiments + +If you want the concrete runnable history behind these lessons, read `optimize_step_by_step/README.md` and the numbered tutorial directories beside it. 
+ +## Current Reference Point + +The current directory gives you two complementary references: +- `linear_attention.cpp`: the current optimized kernel +- `optimize_step_by_step/`: the local optimization ladder from naive code to the current fast path + +Current kernel shape: +- dynamic `B` and `L` +- compile-time `H`, `D`, and `C` +- fixed `block_dim = num_cores` +- persistent-kernel style work loop inside the kernel + +Current fast configuration: +- `C=128, D=128` +- precomputed causal mask +- shared L0C reuse +- cube-side L0 ping-pong +- 2-slot cube/vector staging +- in-place mask apply +- dual `H`-state L1 buffers +- native `seq_first` support keeps full-chunk strided PTO loads/stores on the fast path and uses dynamic `TLOAD`/`TSTORE` only for true varlen tails + +Current validated performance class in this directory: +- roughly `77 TFLOP/s` for the legacy fused `head_first` reference on large enough benchmark shapes +- roughly `63 TFLOP/s` for the native `seq_first` path on larger benchmark shapes without transpose or padding + +## Core Lessons + +### 1. Start With The Simplest Correct Structure + +The early tutorial steps were useful because they made the dataflow obvious: +- load `Q`, `K`, `V` +- form `QK^T` +- form `K^T V` +- apply the causal rule +- accumulate the running hidden state +- finish `O = masked_scores @ V + Q @ H` + +That simple structure is the right starting point even when it is slow. Optimization was much easier once the kernel had a clear baseline and a matching NumPy/PyTorch explanation. + +Rule for future kernels: +- get one small, readable, correctness-checked version working first +- only then start adding buffering, staging, and flag choreography + +### 2. Keep Hot Dimensions Compile-Time Specialized + +The biggest stability and codegen wins came from keeping the inner tile shape fixed at compile time. 
+ +In this kernel that meant: +- `H`, `D`, and `C` as compile-time constants +- `B` and `L` as runtime dimensions + +Why it helped: +- fewer dynamic branches in inner loops +- simpler on-chip allocation +- more predictable tile lowering and instruction scheduling + +Rule for future kernels: +- keep the dimensions that determine tile shape and on-chip layout compile-time if you can +- push only outer problem-size dimensions to runtime + +### 3. Fixed Launch Shape Plus In-Kernel Work Loop Is A Good Default + +Switching to: +- fixed `block_dim = num_cores` +- logical work mapping inside the kernel + +was the right dynamic-shape structure. + +The key pattern is: +- `work_id = work_idx * block_num + cid` +- skip when `work_id >= total_work` + +Why it helped: +- host launch stays stable +- runtime shape changes do not require changing launch geometry +- the kernel becomes more like a persistent worker loop + +Rule for future kernels: +- if the logical workload varies but the per-core kernel structure stays the same, prefer a fixed launch plus in-kernel work assignment + +### 4. Budget L1, L0C, And UB Explicitly + +The kernel only became robust once memory use was treated as a first-class design constraint. + +Practices that helped: +- explicit byte accounting for L1, L0C, and UB +- `static_assert` guards for invalid tile choices +- separating "one-slot" and "two-slot" workspace footprints +- designing tile sizes around real on-chip budget, not just around the math + +Rule for future kernels: +- write the memory budget down in bytes +- fail early at compile time when a tile choice cannot fit +- treat on-chip memory planning as part of the algorithm design + +### 5. 
Remove Scalar Work From The Vector Path Early + +One of the biggest speedups came from deleting the scalar causal-mask loop and replacing it with: +- a precomputed triangular mask tensor +- a vector `TMUL` + +Why it helped: +- removed per-element scalar control flow +- let the vector unit handle masking as a tile operation +- made the vector side much simpler to pipeline later + +Rule for future kernels: +- whenever you see elementwise scalar loops on the vector side, first ask whether they can become tile-vector operations + +### 6. Reuse On-Chip Storage Aggressively + +Another important step was changing the L0C layout from: +- separate regions for score, state, and output + +to: +- one shared region reused across serialized cube stages + +Why it helped: +- larger tile choices became legal +- `C=128, D=128` fit without changing the math +- arithmetic intensity improved + +Rule for future kernels: +- if two stages never need the same buffer live at the same time, consider aliasing them onto the same on-chip region + +### 7. Optimize The Cube Microkernel Before Redesigning The Whole Kernel + +The first major structural speedup came from improving the local cube helper: +- split `K=128` into `2 x 64` +- ping-pong two L0 buffers +- overlap extract with cube compute + +Why it helped: +- the inner GEMM path stopped looking like a single serial block +- the outer algorithm stayed unchanged + +Rule for future kernels: +- before rewriting the whole kernel, inspect the most repeated GEMM-like helper and see whether load/compute overlap can be introduced there first + +### 8. 
Inter-Core Producer/Consumer Pipelines Give Large Wins + +The next big jump came from moving from: +- one chunk of cube work, then one chunk of vector work + +to: +- cube producing chunk `i + 1` while vector consumes chunk `i` + +The working version used: +- two workspace slots +- stage-aware cross-core flags +- an explicit end-of-work-item acknowledgment + +Why it helped: +- reduced chunk-to-chunk bubbles between cube and vector +- let both sides stay busier on long sequences + +Rule for future kernels: +- if cube and vector naturally form a producer/consumer pair, a small staged workspace is often worth more than another tiny local instruction tweak + +### 9. Reduce Temporary Tiles When UB Is Tight + +Applying the mask in-place on `acc_ub` removed one extra UB tile. + +Why it helped: +- lowered UB pressure +- made mask preload possible again +- cut some unnecessary data motion + +Rule for future kernels: +- once the functional structure is stable, inspect temporary tiles and ask which ones can safely become in-place updates + +### 10. Prefetch The Next Recurrent State Early + +The final major improvement here was adding a second `H`-state L1 buffer so the next prefix-state tile could be loaded while the current chunk still had work left. + +Why it helped: +- hid part of the recurrent-state load cost +- reduced bubbles in the output stage + +Rule for future kernels: +- in recurrent or iterative kernels, the next state load is often a good prefetch target once the main pipeline exists + +### 11. Compiler Flags Matter, But Only Measured Ones Count + +This directory also showed that not every seemingly stronger compiler option helps. 
+ +The currently proven settings keep: +- stack sizing flags +- overflow-record flags +- `-cce-aicore-dcci-insert-for-scalar=false` + +The local sweep showed this kernel was faster without: +- `-cce-aicore-addr-transform` +- `-DL2_CACHE_HINT` + +Rule for future kernels: +- treat compiler flags as experiments, not assumptions +- keep only settings that survive correctness and benchmark comparison + +### 12. Prefer Static Strided Tiles Over Dynamic Tiles For The Common Seq-First Case + +The native `(B, T, H, D)` path improved once the common full-chunk case stopped +using dynamic PTO tensors everywhere. + +The useful split was: +- full chunk, dense row stride: use the normal static tile path +- full chunk, strided `seq_first` row layout: use static strided PTO tensors +- only true tail chunks use dynamic `TLOAD` / `TSTORE` + +Why it helped: +- dynamic tile metadata stopped sitting on the hot path for every `seq_first` + chunk +- full-chunk `seq_first` execution became closer to the legacy fixed-layout PTO + fast path +- true packed varlen tails still stayed native and correct + +Rule for future kernels: +- if a layout is strided but regular, try static strided global tensors before + falling back to dynamic tiles for the whole kernel +- reserve dynamic load/store machinery for the irregular tail path only + +### 13. Keep Experimental Fast Paths Only If They Beat The Best Structured Path + +During the `seq_first` work, a larger fused PTO fast path was prototyped to try +to match the old `head_first` kernel more directly. 
+ +What happened: +- it increased code size and complexity +- it did not beat the optimized precomputed `seq_first` pipeline consistently +- the simpler staged `seq_first` path scaled better and was easier to keep + correct with native varlen support + +Rule for future kernels: +- when an experiment adds a second major kernel structure, demand a clear and + repeated benchmark win before keeping it +- if the improvement is not durable, revert it and keep the smaller design + +## Common Failure Modes + +These were recurring problems during the work here: + +### Deadlocks + +Typical causes: +- reusing a staged buffer before the peer core released it +- narrowing dependencies too aggressively +- forgetting an end-of-work-item handshake when the same physical core later serves a different logical job + +Guardrail: +- if you add or change a pipeline stage, re-check all producer/consumer ownership transitions explicitly + +### Silent Numerical Regressions + +Typical causes: +- wrong byte offsets +- wrong aliasing assumptions in L0C or UB +- reordered accumulation without checking tolerance impact + +Guardrail: +- keep full correctness sweeps, not just one smoke shape + +### Overfitting To Benchmark Noise + +Typical cause: +- keeping a change because one run was slightly faster + +Guardrail: +- compare repeated measurements on the same shape set +- revert marginal gains if they do not repeat cleanly + +### Complexity Without Throughput Gain + +Typical cause: +- adding a pipeline or microkernel that looks more advanced but does not improve the dominant bottleneck + +Guardrail: +- only keep structural complexity when the measured benefit is clear + +## Practical Optimization Order + +For a new PTO-ISA kernel, a good order is: + +1. Get a small, direct, correctness-checked baseline. +2. Move runtime variability out of the hot inner tile logic. +3. Remove scalar work from vector code. +4. Revisit tile shape and on-chip memory reuse. +5. Improve the local cube microkernel. +6. 
Add staged producer/consumer overlap between pipelines. +7. Reduce temporary buffers and prefetch recurrent state. +8. Sweep compiler flags only after the kernel structure is stable. +9. Tune benchmark shapes to expose steady-state throughput. + +## What To Measure + +For each experiment, keep the same checklist: + +- correctness on a full shape sweep +- at least one small smoke benchmark +- at least one larger steady-state benchmark +- best TFLOP/s shape +- best GiB/s shape +- whether bandwidth includes or excludes workspace traffic + +If possible, keep: +- one fixed quick shape set for iteration +- one fixed large-shape table for decisions + +## Experiment Template + +Record each attempt with: + +- `ID`: short experiment name +- `Goal`: the bottleneck being targeted +- `Hypothesis`: why it might help +- `Change`: exact implementation change +- `Check`: correctness and benchmark commands +- `Status`: `todo`, `doing`, `done`, `reverted`, `dropped` +- `Result`: measured outcome and short conclusion + +Recommended workflow: + +1. Pick one experiment only. +2. Record the benchmark shapes before editing. +3. Run correctness first. +4. Run the same benchmark set before and after. +5. Keep or drop the change based on repeated evidence. + +## Local Progression Summary + +The local tutorial ladder in `optimize_step_by_step/` also captures the high-level progression: + +1. naive static shape +2. dynamic work mapping +3. cached causal mask +4. larger chunk size +5. cube L0 ping-pong +6. two-slot cube/vector pipeline +7. L1 hidden-state prefetch + +That sequence is a useful default mental model for future optimization tasks: +- first remove obvious scalar waste +- then improve tile size and memory reuse +- then overlap local stages +- then overlap whole pipelines + +## Closing Thought + +The biggest gains in this directory did not come from changing the algorithm. 
They came from: + +- reducing scalar work +- specializing the hot path +- planning on-chip memory explicitly +- reusing buffers aggressively +- overlapping cube, vector, and memory movement +- keeping only measured improvements + +For future PTO-ISA kernels, the main takeaway is simple: start from a clear baseline, optimize one bottleneck at a time, and only keep structural complexity that earns its place in the benchmark table. diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/README.md b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/README.md new file mode 100644 index 00000000..da1dc046 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/README.md @@ -0,0 +1,23 @@ +# Step 01: Naive Static Shape + +This is the beginner-friendly fixed-shape starting point. + +What it teaches: +- the smallest end-to-end PTO-ISA linear attention example +- how workspace buffers are laid out for one fixed `(B, H, L, D, C)` configuration +- how the kernel maps almost one-to-one to the `numpy_sim.py` logic +- why static-shape kernels are simple but inflexible + +Files: +- `linear_attention.cpp`: fixed-shape PTO kernel +- `jit_util_linear_attention.py`: JIT compile/load helper +- `run_linear_attention.py`: correctness check +- `benchmark_linear_attention.py`: simple fixed-shape benchmark +- `numpy_sim.py`: sequential NumPy emulation of the same indexing and workspace logic + +Suggested run order: +```bash +python numpy_sim.py +python run_linear_attention.py +python benchmark_linear_attention.py +``` diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/benchmark_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/benchmark_linear_attention.py new file mode 100644 index 00000000..75217223 --- /dev/null +++ 
b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/benchmark_linear_attention.py @@ -0,0 +1,36 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +import torch +import torch_npu # noqa: F401 + +from jit_util_linear_attention import jit_compile +from linear_attention_shared import kernel_src_path, make_inputs, measure_kernel_ms + +B, H, L, D, C = 2, 2, 512, 128, 64 +DTYPE = torch.float16 + + +def main(): + torch.npu.set_device("npu:0") + src = kernel_src_path(__file__) + kernel = jit_compile(src) + q, k, v = make_inputs(B, H, L, D) + workspace_1 = torch.zeros((B, H, C, C), device="npu", dtype=DTYPE) + workspace_2 = torch.zeros((B, H, D, D), device="npu", dtype=DTYPE) + output = torch.zeros((B, H, L, D), device="npu", dtype=DTYPE) + + median_ms = measure_kernel_ms( + lambda: kernel(q, k, v, workspace_1, workspace_2, output, block_dim=B * H), + warmup=3, + repeats=5, + ) + print("shape", (B, H, L, D, C), "median_ms", median_ms) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/jit_util_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/jit_util_linear_attention.py new file mode 100644 index 00000000..d69f5ab8 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/jit_util_linear_attention.py @@ -0,0 +1,34 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from functools import lru_cache + +from jit_shared import compile_cpp as shared_compile_cpp +from jit_shared import load_static_nomask_lib + + +def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 180) -> str: + return shared_compile_cpp( + 
kernel_cpp, + output_name="linear_attention_jit.so", + std="c++17", + verbose=verbose, + timeout=timeout, + ) + + +def load_lib(lib_path: str): + return load_static_nomask_lib(lib_path) + + +@lru_cache(maxsize=None) +def jit_compile(src_path: str, verbose: bool = True, clean_up: bool = False): + lib_path = compile_cpp(src_path, verbose=verbose) + func = load_lib(lib_path) + if clean_up: + Path(lib_path).unlink(missing_ok=True) + return func diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/linear_attention.cpp b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/linear_attention.cpp new file mode 100644 index 00000000..f33305de --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/linear_attention.cpp @@ -0,0 +1,260 @@ +#include +#include +#include +#include + +using namespace pto; + +constexpr int kBatch = 2; +constexpr int kHeads = 2; +constexpr int kSeqLen = 512; +constexpr int kHidden = 128; +constexpr int kChunk = 64; +constexpr int kTotalWork = kBatch * kHeads; +constexpr int kChunkCount = kSeqLen / kChunk; +constexpr int kVecParts = 2; +constexpr int kHalfChunk = kChunk / kVecParts; +constexpr int kHalfHidden = kHidden / kVecParts; + +template +using L1Mat = Tile; + +template +using L1MatTrans = + Tile; + +template +using UbVec = Tile; + +template +AICORE inline void SetCrossFlag(int32_t flag, int32_t mode) { + const int config = 1 | (mode << 4) | (flag << 8); + ffts_cross_core_sync(Pipe, config); +} + +AICORE inline void WaitCrossFlag(int32_t flag) { wait_flag_dev(flag); } + +template +AICORE inline void MatmulL1( + TileAcc &dst, + std::conditional_t, L1Mat> &a_l1, + std::conditional_t, L1Mat> &b_l1, + bool init) { + TileLeft a_l0; + TileRight b_l0; + TASSIGN(a_l0, 0x0); + TASSIGN(b_l0, 0x0); + + if constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0, a_view, 0, 0); + } else { + TEXTRACT(a_l0, a_l1, 0, 0); + } + 
+ if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0, b_view, 0, 0); + } else { + TEXTRACT(b_l0, b_l1, 0, 0); + } + + pipe_barrier(PIPE_ALL); + if (init) { + TMATMUL(dst, a_l0, b_l0); + } else { + TMATMUL_ACC(dst, dst, a_l0, b_l0); + } + pipe_barrier(PIPE_ALL); +} + +AICORE void main_kernel(__gm__ half *q, __gm__ half *k, __gm__ half *v, + __gm__ half *workspace_scores, + __gm__ half *workspace_state, __gm__ half *o, + uint64_t ffts_addr) { + constexpr int kChunkElems = kChunk * kHidden; + constexpr int kScoreElems = kChunk * kChunk; + constexpr int kStateElems = kHidden * kHidden; + constexpr int kScoreL1Addr = 0; + constexpr int kKeyL1Addr = kScoreL1Addr + kChunkElems * sizeof(half); + constexpr int kValueL1Addr = kKeyL1Addr + kChunkElems * sizeof(half); + constexpr int kStateL1Addr = kValueL1Addr + kChunkElems * sizeof(half); + constexpr int kMaskedScoreL1Addr = kStateL1Addr + kStateElems * sizeof(half); + constexpr int kScoreL0Addr = 0; + constexpr int kStateL0Addr = kScoreL0Addr + kScoreElems * sizeof(float); + constexpr int kOutputL0Addr = kStateL0Addr + kStateElems * sizeof(float); + constexpr int kPrefixStateUbAddr = 0; + constexpr int kScoreUbAddr = kPrefixStateUbAddr + kHalfHidden * kHidden * sizeof(half); + constexpr int kStateUbAddr = kScoreUbAddr + kHalfChunk * kChunk * sizeof(half); + constexpr int kZeroScoreUbAddr = kStateUbAddr + kHalfHidden * kHidden * sizeof(half); + + using ChunkGlobal = + GlobalTensor, + BaseShape2D, Layout::ND>; + using ScoreGlobal = + GlobalTensor, + BaseShape2D, Layout::ND>; + using StateGlobal = + GlobalTensor, + BaseShape2D, Layout::ND>; + using HalfScoreGlobal = + GlobalTensor, + BaseShape2D, Layout::ND>; + using HalfStateGlobal = + GlobalTensor, + BaseShape2D, Layout::ND>; + + const int64_t core_id = get_block_idx(); + const int64_t vector_id = get_subblockid(); + if (core_id >= kTotalWork) { + return; + } + set_ffts_base_addr(ffts_addr); + + const int64_t head_id = core_id % 
kHeads; + const int64_t batch_id = core_id / kHeads; + const int64_t qkv_base = ((batch_id * kHeads + head_id) * kSeqLen) * kHidden; + const int64_t score_workspace_base = core_id * kScoreElems; + const int64_t state_workspace_base = core_id * kStateElems; + + L1Mat q_chunk_l1; + L1Mat k_chunk_l1; + L1Mat v_chunk_l1; + L1Mat prefix_state_l1; + L1Mat masked_score_l1; + TASSIGN(q_chunk_l1, kScoreL1Addr); + TASSIGN(k_chunk_l1, kKeyL1Addr); + TASSIGN(v_chunk_l1, kValueL1Addr); + TASSIGN(prefix_state_l1, kStateL1Addr); + TASSIGN(masked_score_l1, kMaskedScoreL1Addr); + + TileAcc raw_score_l0; + TileAcc state_update_l0; + TileAcc output_l0; + TASSIGN(raw_score_l0, kScoreL0Addr); + TASSIGN(state_update_l0, kStateL0Addr); + TASSIGN(output_l0, kOutputL0Addr); + + UbVec running_state_ub; + UbVec state_delta_ub; + UbVec score_ub; + UbVec zero_score_ub; + TASSIGN(running_state_ub, kPrefixStateUbAddr); + TASSIGN(score_ub, kScoreUbAddr); + TASSIGN(state_delta_ub, kStateUbAddr); + TASSIGN(zero_score_ub, kZeroScoreUbAddr); + +#if defined(__DAV_C220_CUBE__) + WaitCrossFlag(1); + for (int chunk_index = 0; chunk_index < kChunkCount; ++chunk_index) { + const int64_t chunk_base = qkv_base + chunk_index * kChunkElems; + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + StateGlobal prefix_state_global(workspace_state + state_workspace_base); + ScoreGlobal score_global(workspace_scores + score_workspace_base); + ChunkGlobal output_global(o + chunk_base); + + // Load the current Q/K/V chunk and the prefix state from the previous step. + TLOAD(q_chunk_l1, q_global); + TLOAD(k_chunk_l1, k_global); + TLOAD(v_chunk_l1, v_global); + TLOAD(prefix_state_l1, prefix_state_global); + pipe_barrier(PIPE_ALL); + + // First cube matmul: chunk-local QK^T scores. 
+ MatmulL1(raw_score_l0, q_chunk_l1, + k_chunk_l1, true); + TSTORE(score_global, raw_score_l0); + pipe_barrier(PIPE_ALL); + + // Second cube matmul: K^T V contribution to the running hidden state. + MatmulL1(state_update_l0, k_chunk_l1, + v_chunk_l1, true); + TSTORE(prefix_state_global, state_update_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(0, 2); + + // Wait for vector core to apply the causal mask and accumulate the state. + WaitCrossFlag(1); + TLOAD(masked_score_l1, score_global); + pipe_barrier(PIPE_ALL); + + // Final output: masked_scores @ V + Q @ prefix_state. + MatmulL1(output_l0, masked_score_l1, + v_chunk_l1, true); + MatmulL1(output_l0, q_chunk_l1, + prefix_state_l1, false); + TSTORE(output_global, output_l0); + pipe_barrier(PIPE_ALL); + } +#endif + +#if defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + TEXPANDS(running_state_ub, 0.0f); + TEXPANDS(zero_score_ub, 0.0f); + pipe_barrier(PIPE_ALL); + HalfStateGlobal state_slice_global(workspace_state + state_workspace_base + + vector_id * kHalfHidden * kHidden); + TSTORE(state_slice_global, running_state_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + + for (int chunk_index = 0; chunk_index < kChunkCount; ++chunk_index) { + WaitCrossFlag(0); + HalfScoreGlobal score_slice_global(workspace_scores + score_workspace_base + + vector_id * kHalfChunk * kChunk); + TLOAD(score_ub, score_slice_global); + TLOAD(state_delta_ub, state_slice_global); + pipe_barrier(PIPE_ALL); + + // This scalar loop is intentionally naive: it makes the triangular mask + // visible to readers instead of hiding it in a precomputed tensor. 
+ for (int row = 0; row < kHalfChunk; ++row) { + for (int col = 0; col < kChunk; ++col) { + if (vector_id * kHalfChunk + row < col) { + score_ub.SetValue(row * kChunk + col, + zero_score_ub.GetValue(row * kChunk + col)); + } + } + } + pipe_barrier(PIPE_ALL); + + TADD(running_state_ub, running_state_ub, state_delta_ub); + pipe_barrier(PIPE_ALL); + TSTORE(score_slice_global, score_ub); + TSTORE(state_slice_global, running_state_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + } +#endif +} + +extern "C" __global__ AICORE void launch_linear_attention( + __gm__ uint8_t *q, __gm__ uint8_t *k, __gm__ uint8_t *v, + __gm__ uint8_t *workspace_scores, __gm__ uint8_t *workspace_state, + __gm__ uint8_t *o, uint64_t ffts_addr) { + main_kernel(reinterpret_cast<__gm__ half *>(q), + reinterpret_cast<__gm__ half *>(k), + reinterpret_cast<__gm__ half *>(v), + reinterpret_cast<__gm__ half *>(workspace_scores), + reinterpret_cast<__gm__ half *>(workspace_state), + reinterpret_cast<__gm__ half *>(o), ffts_addr); +} + +extern "C" void call_kernel(uint32_t block_dim, void *stream, uint8_t *q, + uint8_t *k, uint8_t *v, uint8_t *workspace_scores, + uint8_t *workspace_state, uint8_t *o) { + uint32_t ffts_len = 0; + uint64_t ffts_addr = 0; + rtGetC2cCtrlAddr(&ffts_addr, &ffts_len); + launch_linear_attention<<>>( + q, k, v, workspace_scores, workspace_state, o, ffts_addr); +} diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/numpy_sim.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/numpy_sim.py new file mode 100644 index 00000000..5dacbea0 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/numpy_sim.py @@ -0,0 +1,57 @@ +import numpy as np + +B, H, L, D, C = 1, 2, 128, 128, 64 + + +def ref_linear_attention(q, k, v): + h = np.zeros((B, H, D, D), dtype=np.float32) + o = np.zeros((B, H, L, D), dtype=np.float32) + for t in range(L): + h += np.einsum("bhi,bhj->bhij", k[:, 
:, t].astype(np.float32), v[:, :, t].astype(np.float32)) + o[:, :, t] = np.einsum("bhi,bhij->bhj", q[:, :, t].astype(np.float32), h) + return o.astype(np.float16) + + +def step01_numpy_sim(q, k, v): + chunk_num = L // C + workspace_1 = np.zeros((B, H, C, C), dtype=np.float16) + workspace_2 = np.zeros((B, H, D, D), dtype=np.float16) + out = np.zeros((B, H, L, D), dtype=np.float16) + causal_mask = np.tril(np.ones((C, C), dtype=np.float32)) + + # Real hardware runs `B * H` work items in parallel across cores. + # This tutorial uses a plain sequential loop so the indexing is easy to see. + for bz in range(B): + for by in range(H): + h_state = np.zeros((D, D), dtype=np.float32) + for chunk_idx in range(chunk_num): + l0 = chunk_idx * C + l1 = l0 + C + q_tile = q[bz, by, l0:l1].astype(np.float32) + k_tile = k[bz, by, l0:l1].astype(np.float32) + v_tile = v[bz, by, l0:l1].astype(np.float32) + + acc = (q_tile @ k_tile.T) * causal_mask + workspace_1[bz, by] = acc.astype(np.float16) + + out[bz, by, l0:l1] = (acc @ v_tile + q_tile @ h_state).astype(np.float16) + h_state = h_state + k_tile.T @ v_tile + workspace_2[bz, by] = h_state.astype(np.float16) + return out + + +def main(): + np.random.seed(0) + q = np.random.randn(B, H, L, D).astype(np.float16) + k = np.random.randn(B, H, L, D).astype(np.float16) + v = np.random.randn(B, H, L, D).astype(np.float16) + q = q / (np.linalg.norm(q.astype(np.float32), axis=-1, keepdims=True) + 1e-6) + k = k / (np.linalg.norm(k.astype(np.float32), axis=-1, keepdims=True) + 1e-6) + ref = ref_linear_attention(q, k, v) + sim = step01_numpy_sim(q, k, v) + np.testing.assert_allclose(sim, ref, rtol=1e-2, atol=1e-2) + print('step01 numpy simulation passed') + + +if __name__ == '__main__': + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/run_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/run_linear_attention.py new file mode 100644 index 
00000000..e99c3c17 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/01_naive_static_shape/run_linear_attention.py @@ -0,0 +1,37 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +import torch +import torch_npu # noqa: F401 + +from jit_util_linear_attention import jit_compile +from linear_attention_shared import run_correctness_cases + +B = 2 +H = 2 +L = 512 +D = 128 +C = 64 + + +def run_kernel(src: str, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, chunk_size: int): + del chunk_size + kernel = jit_compile(src) + workspace_1 = torch.zeros((B, H, C, C), device=q.device, dtype=torch.float16) + workspace_2 = torch.zeros((B, H, D, D), device=q.device, dtype=torch.float16) + output = torch.zeros((B, H, L, D), device=q.device, dtype=torch.float16) + kernel(q, k, v, workspace_1, workspace_2, output, block_dim=B * H) + torch.npu.synchronize() + return output + + +def main(): + run_correctness_cases(__file__, [(B, H, L, D, C)], run_kernel) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/README.md b/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/README.md new file mode 100644 index 00000000..3062c591 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/README.md @@ -0,0 +1,23 @@ +# Step 02: Naive Dynamic Shape + +This step is the beginner-friendly dynamic-shape version. 
+ +What it teaches: +- how `B` and `L` move from compile time to runtime +- why the launch `block_dim` becomes fixed to the device core count +- how the kernel loops internally over work items when `B * H` is larger than the number of cores +- how the dynamic kernel still stays close to the simple NumPy dataflow + +Files: +- `linear_attention.cpp`: minimal dynamic PTO kernel +- `jit_util_linear_attention.py`: dynamic-shape JIT helper +- `run_linear_attention.py`: correctness sweep +- `benchmark_linear_attention.py`: early benchmark script +- `numpy_sim.py`: sequential NumPy emulation of dynamic work partitioning + +Suggested run order: +```bash +python numpy_sim.py +python run_linear_attention.py +python benchmark_linear_attention.py --quick +``` diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/benchmark_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/benchmark_linear_attention.py new file mode 100644 index 00000000..f7672586 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/benchmark_linear_attention.py @@ -0,0 +1,53 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from jit_util_linear_attention import BLOCK_DIM, jit_compile +from linear_attention_shared import benchmark_cli, benchmark_dynamic_kernel + +DEFAULT_SHAPES = [ + (16, 20, 1024, 128, 64), + (16, 20, 2048, 128, 64), + (32, 20, 1024, 128, 64), + (8, 20, 4096, 128, 64), + (16, 20, 1024, 256, 64), +] + +QUICK_SHAPES = [ + (8, 20, 1024, 128, 64), + (16, 20, 1024, 128, 64), +] + + +def benchmark_shape(src: str, *, batch: int, heads: int, seq_len: int, hidden: int, chunk: int, warmup: int, repeats: int): + return benchmark_dynamic_kernel( + src, + batch=batch, + heads=heads, + seq_len=seq_len, + hidden=hidden, + chunk=chunk, + 
warmup=warmup, + repeats=repeats, + jit_compile=jit_compile, + block_dim=BLOCK_DIM, + stage_count=1, + use_mask=False, + include_workspace_bytes=True, + ) + + +def main(): + benchmark_cli( + script_file=__file__, + default_shapes=DEFAULT_SHAPES, + quick_shapes=QUICK_SHAPES, + benchmark_shape=benchmark_shape, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/jit_util_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/jit_util_linear_attention.py new file mode 100644 index 00000000..5beb2b82 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/jit_util_linear_attention.py @@ -0,0 +1,59 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from functools import lru_cache + +from jit_shared import BLOCK_DIM, compile_cpp as shared_compile_cpp +from jit_shared import load_dynamic_nomask_lib + + +def compile_cpp( + kernel_cpp: str, + num_heads: int, + hidden_size: int, + chunk_size: int, + verbose: bool = False, + timeout: int = 180, +) -> str: + return shared_compile_cpp( + kernel_cpp, + output_name=f"linear_attention_H{num_heads}_D{hidden_size}_C{chunk_size}_jit.so", + std="c++17", + defines=[ + f"-DLINEAR_ATTN_H={num_heads}", + f"-DLINEAR_ATTN_D={hidden_size}", + f"-DLINEAR_ATTN_C={chunk_size}", + ], + verbose=verbose, + timeout=timeout, + ) + + +def load_lib(lib_path: str): + return load_dynamic_nomask_lib(lib_path) + + +@lru_cache(maxsize=None) +def jit_compile( + src_path: str, + num_heads: int, + hidden_size: int, + chunk_size: int, + verbose: bool = True, + clean_up: bool = False, +): + lib_path = compile_cpp( + src_path, + num_heads=num_heads, + hidden_size=hidden_size, + chunk_size=chunk_size, + verbose=verbose, + ) + func = load_lib(lib_path) + if 
clean_up: + Path(lib_path).unlink(missing_ok=True) + return func diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/linear_attention.cpp b/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/linear_attention.cpp new file mode 100644 index 00000000..c9554198 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/02_naive_dynamic_shape/linear_attention.cpp @@ -0,0 +1,321 @@ +#include +#include +#include + +using namespace pto; + +#ifndef LINEAR_ATTN_H +#define LINEAR_ATTN_H 2 +#endif + +#ifndef LINEAR_ATTN_D +#define LINEAR_ATTN_D 128 +#endif + +#ifndef LINEAR_ATTN_C +#define LINEAR_ATTN_C 64 +#endif + +template +using L1Mat = Tile; + +template +using L1MatTrans = + Tile; + +template +using UbVec = Tile; + +template +AICORE inline void SetCrossFlag(int32_t flag, int32_t mode) { + const int config = 1 | (mode << 4) | (flag << 8); + ffts_cross_core_sync(Pipe, config); +} + +AICORE inline void WaitCrossFlag(int32_t flag) { wait_flag_dev(flag); } + +template +AICORE inline void MatmulL1( + TileAcc &dst, + std::conditional_t, L1Mat> &a_l1, + std::conditional_t, L1Mat> &b_l1, + bool init) { + TileLeft a_l0; + TileRight b_l0; + TASSIGN(a_l0, 0x0); + TASSIGN(b_l0, 0x0); + + if constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0, a_view, 0, 0); + } else { + TEXTRACT(a_l0, a_l1, 0, 0); + } + + if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0, b_view, 0, 0); + } else { + TEXTRACT(b_l0, b_l1, 0, 0); + } + + pipe_barrier(PIPE_ALL); + if (init) { + TMATMUL(dst, a_l0, b_l0); + } else { + TMATMUL_ACC(dst, dst, a_l0, b_l0); + } + pipe_barrier(PIPE_ALL); +} + +template +AICORE void main_kernel(__gm__ half *q, __gm__ half *k, __gm__ half *v, + __gm__ half *workspace_scores, + __gm__ half *workspace_state, __gm__ half *o, + int64_t batch_size, int64_t seq_len, + uint64_t ffts_addr) { + constexpr int kVecParts = 2; + 
constexpr int kHalfChunk = ChunkSize / kVecParts; + constexpr int kHalfHidden = HiddenSize / kVecParts; + constexpr int kChunkElems = ChunkSize * HiddenSize; + constexpr int kScoreElems = ChunkSize * ChunkSize; + constexpr int kStateElems = HiddenSize * HiddenSize; + constexpr int kQueryL1Addr = 0; + constexpr int kKeyL1Addr = kQueryL1Addr + kChunkElems * sizeof(half); + constexpr int kValueL1Addr = kKeyL1Addr + kChunkElems * sizeof(half); + constexpr int kStateL1Addr = kValueL1Addr + kChunkElems * sizeof(half); + constexpr int kMaskedScoreL1Addr = kStateL1Addr + kStateElems * sizeof(half); + constexpr int kScoreL0Addr = 0; + constexpr int kStateL0Addr = kScoreL0Addr + kScoreElems * sizeof(float); + constexpr int kOutputL0Addr = kStateL0Addr + kStateElems * sizeof(float); + constexpr int kPrefixStateUbAddr = 0; + constexpr int kScoreUbAddr = kPrefixStateUbAddr + kHalfHidden * HiddenSize * sizeof(half); + constexpr int kStateUbAddr = kScoreUbAddr + kHalfChunk * ChunkSize * sizeof(half); + constexpr int kZeroScoreUbAddr = kStateUbAddr + kHalfHidden * HiddenSize * sizeof(half); + + static_assert((HiddenSize % 2) == 0, "HiddenSize must be even."); + static_assert((ChunkSize % 2) == 0, "ChunkSize must be even."); + + using ChunkGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using ScoreGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using StateGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfScoreGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfStateGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + + // Step 02 is the first dynamic-shape kernel: + // - H, D, and C are still compile-time constants + // - B and L now arrive at runtime + // The important consequence is that we can no longer launch exactly one core + // per logical (batch, head) job like step 01. Instead, Python launches a + // fixed block_dim equal to the device core count, and the kernel itself + // loops over as many logical jobs as needed. 
+ const int64_t total_work = batch_size * NumHeads; + const int64_t chunk_count = seq_len / ChunkSize; + const int64_t core_id = get_block_idx(); + const int64_t vector_id = get_subblockid(); + set_ffts_base_addr(ffts_addr); + + L1Mat q_chunk_l1; + L1Mat k_chunk_l1; + L1Mat v_chunk_l1; + L1Mat prefix_state_l1; + L1Mat masked_score_l1; + TASSIGN(q_chunk_l1, kQueryL1Addr); + TASSIGN(k_chunk_l1, kKeyL1Addr); + TASSIGN(v_chunk_l1, kValueL1Addr); + TASSIGN(prefix_state_l1, kStateL1Addr); + TASSIGN(masked_score_l1, kMaskedScoreL1Addr); + + TileAcc raw_score_l0; + TileAcc state_update_l0; + TileAcc output_l0; + TASSIGN(raw_score_l0, kScoreL0Addr); + TASSIGN(state_update_l0, kStateL0Addr); + TASSIGN(output_l0, kOutputL0Addr); + + UbVec running_state_ub; + UbVec state_delta_ub; + UbVec score_ub; + UbVec zero_score_ub; + TASSIGN(running_state_ub, kPrefixStateUbAddr); + TASSIGN(score_ub, kScoreUbAddr); + TASSIGN(state_delta_ub, kStateUbAddr); + TASSIGN(zero_score_ub, kZeroScoreUbAddr); + +#if defined(__DAV_C220_CUBE__) + // In step 01, the launch is static: core_id directly identifies one fixed + // (batch, head) pair because B and H are baked into the program. + // + // In step 02, the launch stays fixed even when B changes. We therefore build + // a small "grid-stride loop" inside the kernel: + // work_idx = which round of jobs this core is handling + // work_id = the logical (batch, head) job assigned in that round + // + // Example: + // total_work = 100 logical jobs, block_num = 24 physical cores + // round 0 handles jobs 0..23 + // round 1 handles jobs 24..47 + // ... + // round 4 handles jobs 96..99, and the extra cores simply skip via + // "if (work_id >= total_work) continue". 
+ for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t work_id = work_idx * block_num + core_id; + if (work_id >= total_work) { + continue; + } + + const int64_t head_id = work_id % NumHeads; + const int64_t batch_id = work_id / NumHeads; + // Once we have a flat logical work_id, convert it back to the mathematical + // indices used by the algorithm: first choose the head, then the batch. + const int64_t qkv_base = ((batch_id * NumHeads + head_id) * seq_len) * HiddenSize; + const int64_t score_workspace_base = core_id * kScoreElems; + const int64_t state_workspace_base = core_id * kStateElems; + + WaitCrossFlag(1); + for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) { + const int64_t chunk_base = qkv_base + chunk_index * kChunkElems; + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + StateGlobal prefix_state_global(workspace_state + state_workspace_base); + ScoreGlobal score_global(workspace_scores + score_workspace_base); + ChunkGlobal output_global(o + chunk_base); + + // Load the current Q/K/V chunk and the prefix state that the vector core + // last stored in this physical core's workspace slot. 
+ TLOAD(q_chunk_l1, q_global); + TLOAD(k_chunk_l1, k_global); + TLOAD(v_chunk_l1, v_global); + TLOAD(prefix_state_l1, prefix_state_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(raw_score_l0, + q_chunk_l1, + k_chunk_l1, true); + TSTORE(score_global, raw_score_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(state_update_l0, + k_chunk_l1, + v_chunk_l1, true); + TSTORE(prefix_state_global, state_update_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(0, 2); + + WaitCrossFlag(1); + TLOAD(masked_score_l1, score_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(output_l0, + masked_score_l1, + v_chunk_l1, true); + MatmulL1(output_l0, + q_chunk_l1, + prefix_state_l1, + false); + TSTORE(output_global, output_l0); + pipe_barrier(PIPE_ALL); + } + } +#endif + +#if defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + // Vector cores use the same logical work assignment as cube cores so both + // sides stay synchronized on which (batch, head) job is currently in the + // shared workspace for this physical core. 
+ const int64_t work_id = work_idx * block_num + core_id; + if (work_id >= total_work) { + continue; + } + + const int64_t score_workspace_base = + core_id * kScoreElems + vector_id * kHalfChunk * ChunkSize; + const int64_t state_workspace_base = + core_id * kStateElems + vector_id * kHalfHidden * HiddenSize; + + TEXPANDS(running_state_ub, 0.0f); + TEXPANDS(zero_score_ub, 0.0f); + pipe_barrier(PIPE_ALL); + HalfStateGlobal state_slice_global(workspace_state + state_workspace_base); + TSTORE(state_slice_global, running_state_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + + for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) { + WaitCrossFlag(0); + HalfScoreGlobal score_slice_global(workspace_scores + score_workspace_base); + TLOAD(score_ub, score_slice_global); + TLOAD(state_delta_ub, state_slice_global); + pipe_barrier(PIPE_ALL); + + for (int row = 0; row < kHalfChunk; ++row) { + for (int col = 0; col < ChunkSize; ++col) { + if (vector_id * kHalfChunk + row < col) { + score_ub.SetValue(row * ChunkSize + col, + zero_score_ub.GetValue(row * ChunkSize + col)); + } + } + } + pipe_barrier(PIPE_ALL); + + TADD(running_state_ub, running_state_ub, state_delta_ub); + pipe_barrier(PIPE_ALL); + TSTORE(score_slice_global, score_ub); + TSTORE(state_slice_global, running_state_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + } + } +#endif +} + +extern "C" __global__ AICORE void launch_linear_attention( + __gm__ uint8_t *q, __gm__ uint8_t *k, __gm__ uint8_t *v, + __gm__ uint8_t *workspace_scores, __gm__ uint8_t *workspace_state, + __gm__ uint8_t *o, int64_t batch_size, int64_t seq_len, + uint64_t ffts_addr) { + main_kernel( + reinterpret_cast<__gm__ half *>(q), reinterpret_cast<__gm__ half *>(k), + reinterpret_cast<__gm__ half *>(v), + reinterpret_cast<__gm__ half *>(workspace_scores), + reinterpret_cast<__gm__ half *>(workspace_state), + reinterpret_cast<__gm__ half *>(o), batch_size, seq_len, ffts_addr); +} + +extern "C" void call_kernel(uint32_t 
import math

import numpy as np

BLOCK_DIM = 24


def ref_linear_attention(q, k, v):
    """Compute causal linear attention one position at a time (reference).

    At every position ``t`` the running state accumulates the outer product
    of ``k[:, :, t]`` and ``v[:, :, t]``, and the output row is
    ``q[:, :, t]`` contracted with that state, so a position sees itself and
    everything before it.

    Args:
        q: array indexed as ``(b, h, l, d)``; its shape sizes the state and
            the output.
        k: array indexed the same way as ``q``.
        v: array indexed the same way as ``q``.

    Returns:
        A ``float16`` array with the shape of ``q``. Accumulation is done in
        ``float32`` and only the final result is rounded.
    """
    b, h, l, d = q.shape
    # Cast once up front instead of re-casting a slice on every time step.
    q32 = q.astype(np.float32)
    k32 = k.astype(np.float32)
    v32 = v.astype(np.float32)
    state = np.zeros((b, h, d, d), dtype=np.float32)
    out = np.zeros((b, h, l, d), dtype=np.float32)
    for t in range(l):
        state += np.einsum("bhi,bhj->bhij", k32[:, :, t], v32[:, :, t])
        out[:, :, t] = np.einsum("bhi,bhij->bhj", q32[:, :, t], state)
    return out.astype(np.float16)


def step02_numpy_sim(q, k, v, chunk_size):
    """Emulate the step-02 chunked schedule sequentially in NumPy.

    Each ``(batch, head)`` pair is one logical work item. Work items are
    dealt out to ``BLOCK_DIM`` emulated cores in rounds, and every work item
    walks its sequence chunk by chunk while carrying a ``(d, d)`` prefix
    state.

    Args:
        q: array indexed as ``(b, h, l, d)``.
        k: array indexed the same way as ``q``.
        v: array indexed the same way as ``q``.
        chunk_size: positive chunk length; must divide ``l`` exactly.

    Returns:
        A ``float16`` array of shape ``(b, h, l, d)``.

    Raises:
        ValueError: if ``chunk_size`` is not positive or does not divide the
            sequence length. Previously the trailing positions were silently
            left as zeros in the output.
    """
    b, h, l, d = q.shape
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if l % chunk_size != 0:
        raise ValueError(
            f"sequence length {l} is not a multiple of chunk_size {chunk_size}"
        )

    total_work = b * h
    chunk_num = l // chunk_size
    # Per-core scratch buffers. They are written to mirror what each core
    # would hold, but this simulation never reads them back.
    workspace_1 = np.zeros((BLOCK_DIM, chunk_size, chunk_size), dtype=np.float16)
    workspace_2 = np.zeros((BLOCK_DIM, d, d), dtype=np.float16)
    out = np.zeros((b, h, l, d), dtype=np.float16)
    causal_mask = np.tril(np.ones((chunk_size, chunk_size), dtype=np.float32))

    # The real kernel launches one fixed block per core and loops over
    # `work_idx` inside the kernel. We emulate that with a sequential nested
    # loop instead of actual parallel execution.
    for work_idx in range(math.ceil(total_work / BLOCK_DIM)):
        for cid in range(BLOCK_DIM):
            pid = work_idx * BLOCK_DIM + cid
            if pid >= total_work:
                continue
            by = pid % h
            bz = pid // h
            h_state = np.zeros((d, d), dtype=np.float32)
            for chunk_idx in range(chunk_num):
                l0 = chunk_idx * chunk_size
                l1 = l0 + chunk_size
                q_tile = q[bz, by, l0:l1].astype(np.float32)
                k_tile = k[bz, by, l0:l1].astype(np.float32)
                v_tile = v[bz, by, l0:l1].astype(np.float32)

                # Intra-chunk scores, restricted to the lower triangle.
                acc = (q_tile @ k_tile.T) * causal_mask
                workspace_1[cid] = acc.astype(np.float16)

                # The output uses the state from *before* this chunk; the
                # chunk's own contribution comes from the masked scores.
                out[bz, by, l0:l1] = (acc @ v_tile + q_tile @ h_state).astype(np.float16)
                h_state = h_state + k_tile.T @ v_tile
                workspace_2[cid] = h_state.astype(np.float16)
    return out


def _normalized_fp16(x):
    """Scale the last axis of ``x`` to (almost) unit norm and return float16.

    The division is done in float32; the result is rounded back to float16
    so the inputs stay half-precision values.
    """
    x32 = x.astype(np.float32)
    scale = np.linalg.norm(x32, axis=-1, keepdims=True) + 1e-6
    return (x32 / scale).astype(np.float16)


def main():
    """Check the chunked simulation against the token-by-token reference."""
    np.random.seed(0)
    for shape in [(1, 2, 256, 128, 64), (4, 2, 512, 128, 64)]:
        b, h, l, d, c = shape
        q = np.random.randn(b, h, l, d).astype(np.float16)
        k = np.random.randn(b, h, l, d).astype(np.float16)
        v = np.random.randn(b, h, l, d).astype(np.float16)
        # Dividing a float16 array by a float32 norm promotes the result to
        # float32; round back so q and k remain float16 inputs.
        q = _normalized_fp16(q)
        k = _normalized_fp16(k)
        ref = ref_linear_attention(q, k, v)
        sim = step02_numpy_sim(q, k, v, c)
        np.testing.assert_allclose(sim, ref, rtol=1e-2, atol=1e-2)
        print('passed', shape)


if __name__ == '__main__':
    main()
from jit_util_linear_attention import BLOCK_DIM, jit_compile
from linear_attention_shared import run_correctness_cases, run_dynamic_kernel


def run_kernel(src: str, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, chunk_size: int):
    """Run the step-02 kernel through the shared dynamic-kernel helper.

    The helper is given this step's launch settings: the step-local
    ``jit_compile``, ``BLOCK_DIM`` blocks, a single stage, and no mask.
    Whatever the helper returns is returned unchanged.
    """
    launch_settings = {
        "jit_compile": jit_compile,
        "block_dim": BLOCK_DIM,
        "stage_count": 1,
        "use_mask": False,
    }
    return run_dynamic_kernel(src, q, k, v, chunk_size, **launch_settings)


def main():
    """Run the correctness sweep for this step over a fixed list of shapes."""
    # One 5-tuple per case, handed to the shared correctness driver as-is.
    shape_sweep = [
        (1, 2, 64, 128, 64),
        (1, 2, 256, 128, 64),
        (4, 2, 128, 128, 64),
        (8, 2, 512, 128, 64),
        (10, 2, 512, 128, 64),
        (16, 2, 256, 128, 64),
        (32, 2, 128, 128, 64),
        (1, 2, 1024, 128, 64),
        (8, 2, 2048, 128, 64),
        (2, 2, 4096, 128, 64),
        (16, 2, 1024, 128, 64),
        (50, 20, 128, 128, 64),
    ]
    run_correctness_cases(__file__, shape_sweep, run_kernel)


if __name__ == "__main__":
    main()
from pathlib import Path
import sys

COMMON_DIR = Path(__file__).resolve().parents[1] / "common"
if str(COMMON_DIR) not in sys.path:
    sys.path.insert(0, str(COMMON_DIR))

from jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile
from linear_attention_shared import benchmark_cli, benchmark_dynamic_kernel

# Shape tables handed to the shared benchmark CLI; one 5-tuple per case.
DEFAULT_SHAPES = [
    (16, 20, 1024, 128, 64),
    (16, 20, 2048, 128, 64),
    (32, 20, 1024, 128, 64),
    (8, 20, 4096, 128, 64),
]

QUICK_SHAPES = [
    (8, 20, 1024, 128, 64),
    (16, 20, 1024, 128, 64),
]


def benchmark_shape(
    src: str,
    *,
    batch: int,
    heads: int,
    seq_len: int,
    hidden: int,
    chunk: int,
    warmup: int,
    repeats: int,
):
    """Benchmark one shape of the cached-mask kernel.

    Forwards the problem size and timing settings to the shared
    ``benchmark_dynamic_kernel`` helper together with this step's fixed
    launch settings (single stage, mask enabled and built by
    ``get_causal_mask``, workspace bytes excluded) and returns its result.
    """
    problem = {
        "batch": batch,
        "heads": heads,
        "seq_len": seq_len,
        "hidden": hidden,
        "chunk": chunk,
    }
    timing = {"warmup": warmup, "repeats": repeats}
    step_settings = {
        "jit_compile": jit_compile,
        "block_dim": BLOCK_DIM,
        "stage_count": 1,
        "use_mask": True,
        "include_workspace_bytes": False,
        "mask_factory": get_causal_mask,
    }
    return benchmark_dynamic_kernel(src, **problem, **timing, **step_settings)


def main():
    """Entry point: hand the shape tables and the per-shape hook to the CLI."""
    benchmark_cli(
        script_file=__file__,
        default_shapes=DEFAULT_SHAPES,
        quick_shapes=QUICK_SHAPES,
        benchmark_shape=benchmark_shape,
    )


if __name__ == "__main__":
    main()
from functools import lru_cache
from pathlib import Path
import sys

COMMON_DIR = Path(__file__).resolve().parents[1] / "common"
if str(COMMON_DIR) not in sys.path:
    sys.path.insert(0, str(COMMON_DIR))

from jit_shared import BLOCK_DIM, STEP03_KERNEL_FLAGS, compile_cpp as shared_compile_cpp
from jit_shared import get_causal_mask, load_dynamic_mask_lib


def compile_cpp(
    kernel_cpp: str,
    num_heads: int,
    hidden_size: int,
    chunk_size: int,
    verbose: bool = False,
    timeout: int = 180,
) -> str:
    """Compile ``kernel_cpp`` for one (heads, hidden, chunk) configuration.

    The three sizes are passed to the shared compile helper as ``-D``
    preprocessor defines and are also encoded in the output file name.
    Returns whatever the shared helper returns (annotated as ``str``).
    """
    macro_values = (
        ("LINEAR_ATTN_H", num_heads),
        ("LINEAR_ATTN_D", hidden_size),
        ("LINEAR_ATTN_C", chunk_size),
    )
    define_flags = [f"-D{macro}={value}" for macro, value in macro_values]
    library_name = f"linear_attention_H{num_heads}_D{hidden_size}_C{chunk_size}_jit.so"
    return shared_compile_cpp(
        kernel_cpp,
        output_name=library_name,
        std="gnu++17",
        defines=define_flags,
        extra_flags=STEP03_KERNEL_FLAGS,
        verbose=verbose,
        timeout=timeout,
    )


def load_lib(lib_path: str):
    """Load a compiled library with the shared mask-taking loader."""
    return load_dynamic_mask_lib(lib_path)


@lru_cache(maxsize=None)
def jit_compile(
    src_path: str,
    num_heads: int,
    hidden_size: int,
    chunk_size: int,
    verbose: bool = True,
    clean_up: bool = False,
):
    """Compile and load the kernel, memoised per argument combination.

    Because of ``lru_cache``, a repeated call with the same arguments
    returns the previously loaded object without recompiling. When
    ``clean_up`` is true the compiled library file is deleted after it has
    been loaded.
    """
    compiled_path = compile_cpp(
        src_path,
        num_heads=num_heads,
        hidden_size=hidden_size,
        chunk_size=chunk_size,
        verbose=verbose,
    )
    kernel = load_lib(compiled_path)
    if clean_up:
        Path(compiled_path).unlink(missing_ok=True)
    return kernel
+1,321 @@ +#include +#include +#include +#include + +using namespace pto; + +// Step 03 keeps the naive overall schedule from step 02, but replaces the +// scalar triangular-mask loop with a precomputed mask tensor from PyTorch. +// That lets the vector core apply causality with one tile-wise multiply. + +#ifndef LINEAR_ATTN_H +#define LINEAR_ATTN_H 2 +#endif + +#ifndef LINEAR_ATTN_D +#define LINEAR_ATTN_D 128 +#endif + +#ifndef LINEAR_ATTN_C +#define LINEAR_ATTN_C 64 +#endif + +template +using L1Mat = Tile; + +template +using L1MatTrans = + Tile; + +template +using UbVec = Tile; + +template +AICORE inline void SetCrossFlag(int32_t flag, int32_t mode) { + const int config = 1 | (mode << 4) | (flag << 8); + ffts_cross_core_sync(Pipe, config); +} + +AICORE inline void WaitCrossFlag(int32_t flag) { wait_flag_dev(flag); } + +template +AICORE inline void MatmulL1( + TileAcc &dst, + std::conditional_t, L1Mat> &a_l1, + std::conditional_t, L1Mat> &b_l1, + bool init) { + // For these early steps we use a single, easy-to-follow "load to L0 then + // matmul" helper. Later steps optimize the internals of this helper. 
+ TileLeft a_l0; + TileRight b_l0; + TASSIGN(a_l0, 0x0); + TASSIGN(b_l0, 0x0); + + if constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0, a_view, 0, 0); + } else { + TEXTRACT(a_l0, a_l1, 0, 0); + } + + if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0, b_view, 0, 0); + } else { + TEXTRACT(b_l0, b_l1, 0, 0); + } + + pipe_barrier(PIPE_ALL); + if (init) { + TMATMUL(dst, a_l0, b_l0); + } else { + TMATMUL_ACC(dst, dst, a_l0, b_l0); + } + pipe_barrier(PIPE_ALL); +} + +template +AICORE void main_kernel(__gm__ half *q, __gm__ half *k, __gm__ half *v, + __gm__ half *workspace_1, __gm__ half *workspace_2, + __gm__ half *causal_mask, __gm__ half *o, + int64_t batch_size, int64_t seq_len, + uint64_t ffts_addr) { + constexpr int32_t VecNum = 2; + constexpr int32_t HalfChunk = ChunkSize / VecNum; + constexpr int32_t HalfHidden = HiddenSize / VecNum; + constexpr int32_t ChunkElems = ChunkSize * HiddenSize; + constexpr int32_t Workspace1Elems = ChunkSize * ChunkSize; + constexpr int32_t Workspace2Elems = HiddenSize * HiddenSize; + + constexpr int32_t QL1Addr = 0; + constexpr int32_t KL1Addr = QL1Addr + ChunkElems * sizeof(half); + constexpr int32_t VL1Addr = KL1Addr + ChunkElems * sizeof(half); + constexpr int32_t HL1Addr = VL1Addr + ChunkElems * sizeof(half); + constexpr int32_t AccL1Addr = HL1Addr + Workspace2Elems * sizeof(half); + + constexpr int32_t AccL0Addr = 0; + constexpr int32_t HL0Addr = AccL0Addr + Workspace1Elems * sizeof(float); + constexpr int32_t OL0Addr = HL0Addr + Workspace2Elems * sizeof(float); + + constexpr int32_t HsumUbAddr = 0; + constexpr int32_t AccUbAddr = HsumUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t HUbAddr = AccUbAddr + HalfChunk * ChunkSize * sizeof(half); + constexpr int32_t MaskUbAddr = HUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t MaskedAccUbAddr = + MaskUbAddr + HalfChunk * ChunkSize * sizeof(half); + + constexpr 
int32_t L0CBytes = + (Workspace1Elems + Workspace2Elems + ChunkElems) * sizeof(float); + static_assert((HiddenSize % 2) == 0, "HiddenSize must be even."); + static_assert((ChunkSize % 2) == 0, "ChunkSize must be even."); + static_assert(L0CBytes <= 112 * 1024, + "Tile sizes exceed the validated L0C budget for this minimum kernel."); + + using ChunkGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using AccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfAccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfHiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfMaskGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + + const int64_t total_work = batch_size * NumHeads; + const int64_t chunk_num = seq_len / ChunkSize; + const int64_t cid = get_block_idx(); + const int64_t vid = get_subblockid(); + set_ffts_base_addr(ffts_addr); + + L1Mat q_l1; + L1Mat k_l1; + L1Mat v_l1; + L1Mat h_l1; + L1Mat acc_l1; + TASSIGN(q_l1, QL1Addr); + TASSIGN(k_l1, KL1Addr); + TASSIGN(v_l1, VL1Addr); + TASSIGN(h_l1, HL1Addr); + TASSIGN(acc_l1, AccL1Addr); + + TileAcc acc_l0; + TileAcc h_l0; + TileAcc o_l0; + TASSIGN(acc_l0, AccL0Addr); + TASSIGN(h_l0, HL0Addr); + TASSIGN(o_l0, OL0Addr); + + UbVec hsum_ub; + UbVec h_ub; + UbVec acc_ub; + UbVec mask_ub; + UbVec masked_acc_ub; + TASSIGN(hsum_ub, HsumUbAddr); + TASSIGN(acc_ub, AccUbAddr); + TASSIGN(h_ub, HUbAddr); + TASSIGN(mask_ub, MaskUbAddr); + TASSIGN(masked_acc_ub, MaskedAccUbAddr); + +#if defined(__DAV_C220_CUBE__) + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t by = pid % NumHeads; + const int64_t bz = pid / NumHeads; + const int64_t qkv_base = ((bz * NumHeads + by) * seq_len) * HiddenSize; + const int64_t workspace1_base = cid * Workspace1Elems; + 
const int64_t workspace2_base = cid * Workspace2Elems; + + WaitCrossFlag(1); + + for (int64_t i = 0; i < chunk_num; ++i) { + const int64_t chunk_base = qkv_base + i * ChunkElems; + + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + HiddenGlobal h_global(workspace_2 + workspace2_base); + + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + TLOAD(h_l1, h_global); + pipe_barrier(PIPE_ALL); + + // Cube computes two intermediates for this chunk: + // 1) chunk-local scores Q K^T + // 2) hidden-state update K^T V + MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, v_l1, + true); + HiddenGlobal h_out_global(workspace_2 + workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(0, 2); + + // Vector core overwrites workspace_1 with the masked scores, then cube + // finishes O = masked_scores @ V + Q @ prefix_state. + WaitCrossFlag(1); + AccGlobal masked_acc_global(workspace_1 + workspace1_base); + TLOAD(acc_l1, masked_acc_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, + v_l1, true); + MatmulL1(o_l0, q_l1, + h_l1, false); + + ChunkGlobal o_global(o + chunk_base); + TSTORE(o_global, o_l0); + pipe_barrier(PIPE_ALL); + } + } +#endif + +#if defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + // This is the key change in step 03: each vector sub-core loads its own half + // of the triangular mask once, outside the chunk loop, and reuses it. 
+ HalfMaskGlobal mask_global(causal_mask + vid * HalfChunk * ChunkSize); + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t workspace1_base = + cid * Workspace1Elems + vid * HalfChunk * ChunkSize; + const int64_t workspace2_base = + cid * Workspace2Elems + vid * HalfHidden * HiddenSize; + + TEXPANDS(hsum_ub, 0.0f); + pipe_barrier(PIPE_ALL); + HalfHiddenGlobal init_h_global(workspace_2 + workspace2_base); + TSTORE(init_h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + + for (int64_t i = 0; i < chunk_num; ++i) { + WaitCrossFlag(0); + + HalfAccGlobal acc_global(workspace_1 + workspace1_base); + HalfHiddenGlobal h_global(workspace_2 + workspace2_base); + TLOAD(acc_ub, acc_global); + TLOAD(h_ub, h_global); + pipe_barrier(PIPE_ALL); + // Elementwise multiply is much cheaper than the scalar if-statements from + // step 02, but the numerical effect is identical. 
+ TMUL(masked_acc_ub, acc_ub, mask_ub); + + TADD(hsum_ub, hsum_ub, h_ub); + pipe_barrier(PIPE_ALL); + TSTORE(acc_global, masked_acc_ub); + TSTORE(h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + } + } +#endif +} + +extern "C" __global__ AICORE void launch_linear_attention( + __gm__ uint8_t *q, __gm__ uint8_t *k, __gm__ uint8_t *v, + __gm__ uint8_t *workspace_1, __gm__ uint8_t *workspace_2, + __gm__ uint8_t *causal_mask, __gm__ uint8_t *o, int64_t batch_size, + int64_t seq_len, uint64_t ffts_addr) { + main_kernel( + reinterpret_cast<__gm__ half *>(q), reinterpret_cast<__gm__ half *>(k), + reinterpret_cast<__gm__ half *>(v), + reinterpret_cast<__gm__ half *>(workspace_1), + reinterpret_cast<__gm__ half *>(workspace_2), + reinterpret_cast<__gm__ half *>(causal_mask), + reinterpret_cast<__gm__ half *>(o), batch_size, seq_len, ffts_addr); +} + +extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *q, + uint8_t *k, uint8_t *v, uint8_t *workspace_1, + uint8_t *workspace_2, uint8_t *causal_mask, + uint8_t *o, + int64_t batch_size, int64_t seq_len) { + uint32_t ffts_len = 0; + uint64_t ffts_addr = 0; + rtGetC2cCtrlAddr(&ffts_addr, &ffts_len); + launch_linear_attention<<>>( + q, k, v, workspace_1, workspace_2, causal_mask, o, batch_size, seq_len, + ffts_addr); +} diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/03_cached_mask/run_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/03_cached_mask/run_linear_attention.py new file mode 100644 index 00000000..7ff97556 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/03_cached_mask/run_linear_attention.py @@ -0,0 +1,33 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile +from linear_attention_shared import 
def run_kernel(src, q, k, v, chunk_size):
    """Run the step-03 kernel through the shared dynamic-kernel helper.

    Uses this step's launch settings: the step-local ``jit_compile``,
    ``BLOCK_DIM`` blocks, a single stage, and a mask produced by
    ``get_causal_mask``. The helper's return value is passed through.
    """
    launch_settings = {
        "jit_compile": jit_compile,
        "block_dim": BLOCK_DIM,
        "stage_count": 1,
        "use_mask": True,
        "mask_factory": get_causal_mask,
    }
    return run_dynamic_kernel(src, q, k, v, chunk_size, **launch_settings)


def main():
    """Run the correctness sweep for this step over a fixed list of shapes."""
    # One 5-tuple per case, handed to the shared correctness driver as-is.
    shape_sweep = [
        (1, 2, 64, 128, 64),
        (1, 2, 256, 128, 64),
        (4, 2, 128, 128, 64),
        (8, 2, 512, 128, 64),
        (10, 2, 512, 128, 64),
        (16, 2, 256, 128, 64),
        (32, 2, 128, 128, 64),
        (1, 2, 1024, 128, 64),
        (8, 2, 2048, 128, 64),
        (2, 2, 4096, 128, 64),
        (16, 2, 1024, 128, 64),
        (50, 20, 128, 128, 64),
    ]
    run_correctness_cases(__file__, shape_sweep, run_kernel)


if __name__ == "__main__":
    main()
from pathlib import Path
import sys

COMMON_DIR = Path(__file__).resolve().parents[1] / "common"
if str(COMMON_DIR) not in sys.path:
    sys.path.insert(0, str(COMMON_DIR))

from jit_util_linear_attention import BLOCK_DIM, jit_compile
from linear_attention_shared import benchmark_cli, benchmark_dynamic_kernel

# Shape tables handed to the shared benchmark CLI; one 5-tuple per case.
DEFAULT_SHAPES = [
    (16, 20, 1024, 128, 64),
    (16, 20, 2048, 128, 64),
    (32, 20, 1024, 128, 64),
    (8, 20, 4096, 128, 64),
]

QUICK_SHAPES = [
    (8, 20, 1024, 128, 64),
    (16, 20, 1024, 128, 64),
]


def benchmark_shape(
    src: str,
    *,
    batch: int,
    heads: int,
    seq_len: int,
    hidden: int,
    chunk: int,
    warmup: int,
    repeats: int,
):
    """Benchmark one shape of the on-the-fly-mask kernel.

    Forwards the problem size and timing settings to the shared
    ``benchmark_dynamic_kernel`` helper together with this step's fixed
    launch settings (single stage, no mask argument, workspace bytes
    excluded) and returns its result.
    """
    problem = {
        "batch": batch,
        "heads": heads,
        "seq_len": seq_len,
        "hidden": hidden,
        "chunk": chunk,
    }
    timing = {"warmup": warmup, "repeats": repeats}
    step_settings = {
        "jit_compile": jit_compile,
        "block_dim": BLOCK_DIM,
        "stage_count": 1,
        "use_mask": False,
        "include_workspace_bytes": False,
    }
    return benchmark_dynamic_kernel(src, **problem, **timing, **step_settings)


def main():
    """Entry point: hand the shape tables and the per-shape hook to the CLI."""
    benchmark_cli(
        script_file=__file__,
        default_shapes=DEFAULT_SHAPES,
        quick_shapes=QUICK_SHAPES,
        benchmark_shape=benchmark_shape,
    )


if __name__ == "__main__":
    main()
from functools import lru_cache
from pathlib import Path
import sys

COMMON_DIR = Path(__file__).resolve().parents[1] / "common"
if str(COMMON_DIR) not in sys.path:
    sys.path.insert(0, str(COMMON_DIR))

from jit_shared import BLOCK_DIM, STEP03_KERNEL_FLAGS, compile_cpp as shared_compile_cpp
from jit_shared import load_dynamic_nomask_lib


def compile_cpp(
    kernel_cpp: str,
    num_heads: int,
    hidden_size: int,
    chunk_size: int,
    verbose: bool = False,
    timeout: int = 180,
) -> str:
    """Compile ``kernel_cpp`` for one (heads, hidden, chunk) configuration.

    The three sizes are passed to the shared compile helper as ``-D``
    preprocessor defines and are also encoded in the output file name.
    Returns whatever the shared helper returns (annotated as ``str``).
    """
    macro_values = (
        ("LINEAR_ATTN_H", num_heads),
        ("LINEAR_ATTN_D", hidden_size),
        ("LINEAR_ATTN_C", chunk_size),
    )
    define_flags = [f"-D{macro}={value}" for macro, value in macro_values]
    library_name = f"linear_attention_H{num_heads}_D{hidden_size}_C{chunk_size}_jit.so"
    return shared_compile_cpp(
        kernel_cpp,
        output_name=library_name,
        std="gnu++17",
        defines=define_flags,
        extra_flags=STEP03_KERNEL_FLAGS,
        verbose=verbose,
        timeout=timeout,
    )


def load_lib(lib_path: str):
    """Load a compiled library with the shared no-mask loader."""
    return load_dynamic_nomask_lib(lib_path)


@lru_cache(maxsize=None)
def jit_compile(
    src_path: str,
    num_heads: int,
    hidden_size: int,
    chunk_size: int,
    verbose: bool = True,
    clean_up: bool = False,
):
    """Compile and load the kernel, memoised per argument combination.

    Because of ``lru_cache``, a repeated call with the same arguments
    returns the previously loaded object without recompiling. When
    ``clean_up`` is true the compiled library file is deleted after it has
    been loaded.
    """
    compiled_path = compile_cpp(
        src_path,
        num_heads=num_heads,
        hidden_size=hidden_size,
        chunk_size=chunk_size,
        verbose=verbose,
    )
    kernel = load_lib(compiled_path)
    if clean_up:
        Path(compiled_path).unlink(missing_ok=True)
    return kernel
b/examples/jit_cpp/linear_attention/optimize_step_by_step/03a_fast_mask_construct/linear_attention.cpp new file mode 100644 index 00000000..69798ef8 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/03a_fast_mask_construct/linear_attention.cpp @@ -0,0 +1,327 @@ +#include +#include +#include + +using namespace pto; + +// Step 03a keeps the dynamic-shape kernel from step 02, but replaces the slow +// scalar mask loop with an on-chip vectorized mask build. Each vector sub-core +// constructs its lower-triangular half mask once with the higher-level TTRI +// PTO-ISA wrapper, then reuses the same cheap elementwise multiply as step 03. + +#ifndef LINEAR_ATTN_H +#define LINEAR_ATTN_H 2 +#endif + +#ifndef LINEAR_ATTN_D +#define LINEAR_ATTN_D 128 +#endif + +#ifndef LINEAR_ATTN_C +#define LINEAR_ATTN_C 64 +#endif + +template +using L1Mat = Tile; + +template +using L1MatTrans = + Tile; + +template +using UbVec = Tile; + +// PTO 8.5.0 bakes `diagonal` into the TTRI template arguments, while +// pto-isa-master passes it as a runtime argument. Keep one call site that +// accepts either form so the example builds against both header versions. 
+template +AICORE inline auto TTriCompatImpl(TileData &dst, int diagonal_value, int) + -> decltype(TTRI(dst, diagonal_value), void()) { + TTRI(dst, diagonal_value); +} + +template +AICORE inline auto TTriCompatImpl(TileData &dst, int, long) + -> decltype(TTRI(dst), void()) { + TTRI(dst); +} + +template +AICORE inline void TTriCompat(TileData &dst) { + TTriCompatImpl(dst, diagonal, 0); +} + +template +AICORE inline void SetCrossFlag(int32_t flag, int32_t mode) { + const int config = 1 | (mode << 4) | (flag << 8); + ffts_cross_core_sync(Pipe, config); +} + +AICORE inline void WaitCrossFlag(int32_t flag) { wait_flag_dev(flag); } + +template +AICORE inline void MatmulL1( + TileAcc &dst, + std::conditional_t, L1Mat> &a_l1, + std::conditional_t, L1Mat> &b_l1, + bool init) { + TileLeft a_l0; + TileRight b_l0; + TASSIGN(a_l0, 0x0); + TASSIGN(b_l0, 0x0); + + if constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0, a_view, 0, 0); + } else { + TEXTRACT(a_l0, a_l1, 0, 0); + } + + if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0, b_view, 0, 0); + } else { + TEXTRACT(b_l0, b_l1, 0, 0); + } + + pipe_barrier(PIPE_ALL); + if (init) { + TMATMUL(dst, a_l0, b_l0); + } else { + TMATMUL_ACC(dst, dst, a_l0, b_l0); + } + pipe_barrier(PIPE_ALL); +} + +template +AICORE void main_kernel(__gm__ half *q, __gm__ half *k, __gm__ half *v, + __gm__ half *workspace_scores, + __gm__ half *workspace_state, __gm__ half *o, + int64_t batch_size, int64_t seq_len, + uint64_t ffts_addr) { + constexpr int kVecParts = 2; + constexpr int kHalfChunk = ChunkSize / kVecParts; + constexpr int kHalfHidden = HiddenSize / kVecParts; + constexpr int kChunkElems = ChunkSize * HiddenSize; + constexpr int kScoreElems = ChunkSize * ChunkSize; + constexpr int kStateElems = HiddenSize * HiddenSize; + constexpr int kQueryL1Addr = 0; + constexpr int kKeyL1Addr = kQueryL1Addr + kChunkElems * sizeof(half); + constexpr int kValueL1Addr = kKeyL1Addr + 
kChunkElems * sizeof(half); + constexpr int kStateL1Addr = kValueL1Addr + kChunkElems * sizeof(half); + constexpr int kMaskedScoreL1Addr = kStateL1Addr + kStateElems * sizeof(half); + constexpr int kScoreL0Addr = 0; + constexpr int kStateL0Addr = kScoreL0Addr + kScoreElems * sizeof(float); + constexpr int kOutputL0Addr = kStateL0Addr + kStateElems * sizeof(float); + constexpr int kPrefixStateUbAddr = 0; + constexpr int kScoreUbAddr = kPrefixStateUbAddr + kHalfHidden * HiddenSize * sizeof(half); + constexpr int kStateUbAddr = kScoreUbAddr + kHalfChunk * ChunkSize * sizeof(half); + constexpr int kZeroScoreUbAddr = kStateUbAddr + kHalfHidden * HiddenSize * sizeof(half); + constexpr int kMaskUbAddr = kZeroScoreUbAddr + kHalfChunk * ChunkSize * sizeof(half); + constexpr int kMaskedScoreUbAddr = kMaskUbAddr + kHalfChunk * ChunkSize * sizeof(half); + + static_assert((HiddenSize % 2) == 0, "HiddenSize must be even."); + static_assert((ChunkSize % 2) == 0, "ChunkSize must be even."); + + using ChunkGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using ScoreGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using StateGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfScoreGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfStateGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + + const int64_t total_work = batch_size * NumHeads; + const int64_t chunk_count = seq_len / ChunkSize; + const int64_t core_id = get_block_idx(); + const int64_t vector_id = get_subblockid(); + set_ffts_base_addr(ffts_addr); + + L1Mat q_chunk_l1; + L1Mat k_chunk_l1; + L1Mat v_chunk_l1; + L1Mat prefix_state_l1; + L1Mat masked_score_l1; + TASSIGN(q_chunk_l1, kQueryL1Addr); + TASSIGN(k_chunk_l1, kKeyL1Addr); + TASSIGN(v_chunk_l1, kValueL1Addr); + TASSIGN(prefix_state_l1, kStateL1Addr); + TASSIGN(masked_score_l1, kMaskedScoreL1Addr); + + TileAcc raw_score_l0; + TileAcc state_update_l0; + TileAcc output_l0; + TASSIGN(raw_score_l0, kScoreL0Addr); + 
TASSIGN(state_update_l0, kStateL0Addr); + TASSIGN(output_l0, kOutputL0Addr); + + UbVec running_state_ub; + UbVec state_delta_ub; + UbVec score_ub; + UbVec zero_score_ub; + UbVec mask_ub; + UbVec masked_score_ub; + TASSIGN(running_state_ub, kPrefixStateUbAddr); + TASSIGN(score_ub, kScoreUbAddr); + TASSIGN(state_delta_ub, kStateUbAddr); + TASSIGN(zero_score_ub, kZeroScoreUbAddr); + TASSIGN(mask_ub, kMaskUbAddr); + TASSIGN(masked_score_ub, kMaskedScoreUbAddr); + +#if defined(__DAV_C220_CUBE__) + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t work_id = work_idx * block_num + core_id; + if (work_id >= total_work) { + continue; + } + + const int64_t head_id = work_id % NumHeads; + const int64_t batch_id = work_id / NumHeads; + const int64_t qkv_base = ((batch_id * NumHeads + head_id) * seq_len) * HiddenSize; + const int64_t score_workspace_base = core_id * kScoreElems; + const int64_t state_workspace_base = core_id * kStateElems; + + WaitCrossFlag(1); + for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) { + const int64_t chunk_base = qkv_base + chunk_index * kChunkElems; + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + StateGlobal prefix_state_global(workspace_state + state_workspace_base); + ScoreGlobal score_global(workspace_scores + score_workspace_base); + ChunkGlobal output_global(o + chunk_base); + + TLOAD(q_chunk_l1, q_global); + TLOAD(k_chunk_l1, k_global); + TLOAD(v_chunk_l1, v_global); + TLOAD(prefix_state_l1, prefix_state_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(raw_score_l0, + q_chunk_l1, + k_chunk_l1, true); + TSTORE(score_global, raw_score_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(state_update_l0, + k_chunk_l1, + v_chunk_l1, true); + TSTORE(prefix_state_global, state_update_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(0, 2); + + WaitCrossFlag(1); + TLOAD(masked_score_l1, score_global); + 
pipe_barrier(PIPE_ALL); + + MatmulL1(output_l0, + masked_score_l1, + v_chunk_l1, true); + MatmulL1(output_l0, + q_chunk_l1, + prefix_state_l1, + false); + TSTORE(output_global, output_l0); + pipe_barrier(PIPE_ALL); + } + } +#endif + +#if defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + // Build the lower-triangular mask once per vector sub-core and reuse it for + // every chunk. The two vector sub-cores cover global rows [0, 31] and [32, 63]. + TEXPANDS(zero_score_ub, 0.0f); + if (vector_id == 0) { + TTriCompat(mask_ub); + } else { + TTriCompat(mask_ub); + } + pipe_barrier(PIPE_ALL); + + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t work_id = work_idx * block_num + core_id; + if (work_id >= total_work) { + continue; + } + + const int64_t score_workspace_base = + core_id * kScoreElems + vector_id * kHalfChunk * ChunkSize; + const int64_t state_workspace_base = + core_id * kStateElems + vector_id * kHalfHidden * HiddenSize; + + TEXPANDS(running_state_ub, 0.0f); + pipe_barrier(PIPE_ALL); + HalfStateGlobal state_slice_global(workspace_state + state_workspace_base); + TSTORE(state_slice_global, running_state_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + + for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) { + WaitCrossFlag(0); + HalfScoreGlobal score_slice_global(workspace_scores + score_workspace_base); + TLOAD(score_ub, score_slice_global); + TLOAD(state_delta_ub, state_slice_global); + pipe_barrier(PIPE_ALL); + + TMUL(masked_score_ub, score_ub, mask_ub); + pipe_barrier(PIPE_ALL); + + TADD(running_state_ub, running_state_ub, state_delta_ub); + pipe_barrier(PIPE_ALL); + TSTORE(score_slice_global, masked_score_ub); + TSTORE(state_slice_global, running_state_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + } + } +#endif +} + +extern "C" __global__ AICORE void launch_linear_attention( + __gm__ uint8_t *q, __gm__ uint8_t *k, __gm__ uint8_t *v, + 
__gm__ uint8_t *workspace_scores, __gm__ uint8_t *workspace_state, + __gm__ uint8_t *o, int64_t batch_size, int64_t seq_len, + uint64_t ffts_addr) { + main_kernel( + reinterpret_cast<__gm__ half *>(q), reinterpret_cast<__gm__ half *>(k), + reinterpret_cast<__gm__ half *>(v), + reinterpret_cast<__gm__ half *>(workspace_scores), + reinterpret_cast<__gm__ half *>(workspace_state), + reinterpret_cast<__gm__ half *>(o), batch_size, seq_len, ffts_addr); +} + +extern "C" void call_kernel(uint32_t block_dim, void *stream, uint8_t *q, + uint8_t *k, uint8_t *v, uint8_t *workspace_scores, + uint8_t *workspace_state, uint8_t *o, + int64_t batch_size, int64_t seq_len) { + uint32_t ffts_len = 0; + uint64_t ffts_addr = 0; + rtGetC2cCtrlAddr(&ffts_addr, &ffts_len); + launch_linear_attention<<>>( + q, k, v, workspace_scores, workspace_state, o, batch_size, seq_len, + ffts_addr); +} diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/03a_fast_mask_construct/run_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/03a_fast_mask_construct/run_linear_attention.py new file mode 100644 index 00000000..3a926995 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/03a_fast_mask_construct/run_linear_attention.py @@ -0,0 +1,48 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +import torch +import torch_npu # noqa: F401 + +from jit_util_linear_attention import BLOCK_DIM, jit_compile +from linear_attention_shared import run_correctness_cases, run_dynamic_kernel + + +def run_kernel(src: str, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, chunk_size: int): + return run_dynamic_kernel( + src, + q, + k, + v, + chunk_size, + jit_compile=jit_compile, + block_dim=BLOCK_DIM, + stage_count=1, + use_mask=False, + ) + + +def main(): + test_configs = [ + (1, 2, 64, 128, 64), + (1, 2, 256, 128, 64), + (4, 2, 
128, 128, 64), + (8, 2, 512, 128, 64), + (10, 2, 512, 128, 64), + (16, 2, 256, 128, 64), + (32, 2, 128, 128, 64), + (1, 2, 1024, 128, 64), + (8, 2, 2048, 128, 64), + (2, 2, 4096, 128, 64), + (16, 2, 1024, 128, 64), + (50, 20, 128, 128, 64), + ] + run_correctness_cases(__file__, test_configs, run_kernel) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/README.md b/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/README.md new file mode 100644 index 00000000..54c5255d --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/README.md @@ -0,0 +1,7 @@ +# Step 04: Increase Chunk Size To 128 + +This step corresponds to commit `bd954f9`. + +What changed: +- the kernel is reworked to fit `C=128, D=128` within the validated on-chip memory budget +- arithmetic intensity improves, which moves the kernel from the `~30 TFLOP/s` class toward the `~50 TFLOP/s` class diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/benchmark_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/benchmark_linear_attention.py new file mode 100644 index 00000000..09e89533 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/benchmark_linear_attention.py @@ -0,0 +1,46 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile +from linear_attention_shared import benchmark_cli, benchmark_dynamic_kernel + +DEFAULT_SHAPES = [(16, 20, 1024, 128, 128), (16, 20, 2048, 128, 128), (32, 20, 1024, 128, 128), (8, 20, 4096, 128, 128)] + +QUICK_SHAPES = [(8, 20, 1024, 128, 128), (16, 20, 1024, 128, 128)] + + + +def benchmark_shape(src: str, *, batch: int, heads: int, seq_len: int, hidden: int, 
chunk: int, warmup: int, repeats: int): + return benchmark_dynamic_kernel( + src, + batch=batch, + heads=heads, + seq_len=seq_len, + hidden=hidden, + chunk=chunk, + warmup=warmup, + repeats=repeats, + jit_compile=jit_compile, + block_dim=BLOCK_DIM, + stage_count=1, + use_mask=True, + include_workspace_bytes=False, + mask_factory=get_causal_mask, + ) + + +def main(): + benchmark_cli( + script_file=__file__, + default_shapes=DEFAULT_SHAPES, + quick_shapes=QUICK_SHAPES, + benchmark_shape=benchmark_shape + ) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/jit_util_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/jit_util_linear_attention.py new file mode 100644 index 00000000..e2870489 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/jit_util_linear_attention.py @@ -0,0 +1,60 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from functools import lru_cache + +from jit_shared import BLOCK_DIM, OPTIMIZED_KERNEL_FLAGS, compile_cpp as shared_compile_cpp +from jit_shared import get_causal_mask, load_dynamic_mask_lib + + +def compile_cpp( + kernel_cpp: str, + num_heads: int, + hidden_size: int, + chunk_size: int, + verbose: bool = False, + timeout: int = 180, +) -> str: + return shared_compile_cpp( + kernel_cpp, + output_name=f"linear_attention_H{num_heads}_D{hidden_size}_C{chunk_size}_jit.so", + std="gnu++17", + defines=[ + f"-DLINEAR_ATTN_H={num_heads}", + f"-DLINEAR_ATTN_D={hidden_size}", + f"-DLINEAR_ATTN_C={chunk_size}", + ], + extra_flags=OPTIMIZED_KERNEL_FLAGS, + verbose=verbose, + timeout=timeout, + ) + + +def load_lib(lib_path: str): + return load_dynamic_mask_lib(lib_path) + + +@lru_cache(maxsize=None) +def jit_compile( + src_path: str, + num_heads: int, + hidden_size: int, + 
chunk_size: int, + verbose: bool = True, + clean_up: bool = False, +): + lib_path = compile_cpp( + src_path, + num_heads=num_heads, + hidden_size=hidden_size, + chunk_size=chunk_size, + verbose=verbose, + ) + func = load_lib(lib_path) + if clean_up: + Path(lib_path).unlink(missing_ok=True) + return func diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/linear_attention.cpp b/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/linear_attention.cpp new file mode 100644 index 00000000..dc2a95e8 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/linear_attention.cpp @@ -0,0 +1,347 @@ +#include +#include +#include +#include + +using namespace pto; + +// Step 04 keeps the cached-mask idea from step 03, but moves from C=64 to +// C=128. The bigger chunk increases arithmetic intensity, so the matmuls do +// more useful work per load/store. + +#ifndef LINEAR_ATTN_H +#define LINEAR_ATTN_H 2 +#endif + +#ifndef LINEAR_ATTN_D +#define LINEAR_ATTN_D 128 +#endif + +#ifndef LINEAR_ATTN_C +#define LINEAR_ATTN_C 64 +#endif + +template +using L1Mat = Tile; + +template +using L1MatTrans = + Tile; + +template +using UbVec = Tile; + +template +AICORE inline void SetCrossFlag(int32_t flag, int32_t mode) { + const int config = 1 | (mode << 4) | (flag << 8); + ffts_cross_core_sync(Pipe, config); +} + +AICORE inline void WaitCrossFlag(int32_t flag) { wait_flag_dev(flag); } + +template +AICORE inline void MatmulL1( + TileAcc &dst, + std::conditional_t, L1Mat> &a_l1, + std::conditional_t, L1Mat> &b_l1, + bool init) { + // Still the simple single-stage helper: later steps will optimize this + // internal load/compute sequence without changing the outer algorithm. 
+ TileLeft a_l0; + TileRight b_l0; + TASSIGN(a_l0, 0x0); + TASSIGN(b_l0, 0x0); + + if constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0, a_view, 0, 0); + } else { + TEXTRACT(a_l0, a_l1, 0, 0); + } + + if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0, b_view, 0, 0); + } else { + TEXTRACT(b_l0, b_l1, 0, 0); + } + + pipe_barrier(PIPE_ALL); + if (init) { + TMATMUL(dst, a_l0, b_l0); + } else { + TMATMUL_ACC(dst, dst, a_l0, b_l0); + } + pipe_barrier(PIPE_ALL); +} + +template +AICORE void main_kernel(__gm__ half *q, __gm__ half *k, __gm__ half *v, + __gm__ half *workspace_1, __gm__ half *workspace_2, + __gm__ half *causal_mask, __gm__ half *o, + int64_t batch_size, int64_t seq_len, + uint64_t ffts_addr) { + constexpr int32_t VecNum = 2; + constexpr int32_t HalfChunk = ChunkSize / VecNum; + constexpr int32_t HalfHidden = HiddenSize / VecNum; + constexpr int32_t ChunkElems = ChunkSize * HiddenSize; + constexpr int32_t Workspace1Elems = ChunkSize * ChunkSize; + constexpr int32_t Workspace2Elems = HiddenSize * HiddenSize; + + constexpr int32_t QL1Addr = 0; + constexpr int32_t KL1Addr = QL1Addr + ChunkElems * sizeof(half); + constexpr int32_t VL1Addr = KL1Addr + ChunkElems * sizeof(half); + constexpr int32_t HL1Addr = VL1Addr + ChunkElems * sizeof(half); + constexpr int32_t AccL1Addr = HL1Addr + Workspace2Elems * sizeof(half); + + // With C=128 the score tile, state tile, and output tile no longer all fit in + // L0C at the same time. This step reuses one shared L0C address range. 
+ constexpr int32_t SharedL0Addr = 0; + + constexpr int32_t HsumUbAddr = 0; + constexpr int32_t AccUbAddr = + HsumUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t HUbAddr = AccUbAddr + HalfChunk * ChunkSize * sizeof(half); + constexpr int32_t RawUBBytes = + (HalfHidden * HiddenSize + HalfChunk * ChunkSize + HalfHidden * HiddenSize + + HalfChunk * ChunkSize + HalfChunk * ChunkSize) * + sizeof(half); + // A larger chunk also pressures UB more. If the mask fits, preload it once; + // otherwise alias it onto the H buffer and reload when needed. + constexpr bool PreloadMask = RawUBBytes <= 72 * 1024; + constexpr bool AliasMaskIntoH = + !PreloadMask && (HalfHidden * HiddenSize >= HalfChunk * ChunkSize); + constexpr int32_t MaskUbAddr = + AliasMaskIntoH ? HUbAddr : HUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t MaskedAccUbAddr = + MaskUbAddr + HalfChunk * ChunkSize * sizeof(half); + + constexpr int32_t L0CBytes = + (Workspace2Elems > Workspace1Elems + ? (Workspace2Elems > ChunkElems ? Workspace2Elems : ChunkElems) + : (Workspace1Elems > ChunkElems ? Workspace1Elems : ChunkElems)) * + sizeof(float); + constexpr int32_t UBBytes = + (HalfHidden * HiddenSize + HalfChunk * ChunkSize + + (AliasMaskIntoH ? 
HalfHidden * HiddenSize + : HalfHidden * HiddenSize + HalfChunk * ChunkSize) + + HalfChunk * ChunkSize) * + sizeof(half); + static_assert((HiddenSize % 2) == 0, "HiddenSize must be even."); + static_assert((ChunkSize % 2) == 0, "ChunkSize must be even."); + static_assert(L0CBytes <= 112 * 1024, + "Tile sizes exceed the validated L0C budget for this minimum kernel."); + static_assert(PreloadMask || AliasMaskIntoH, + "Current minimum kernel requires either a preloaded mask or H UB large enough to alias the mask."); + static_assert(UBBytes <= 72 * 1024, + "Tile sizes exceed the validated UB budget for this minimum kernel."); + + using ChunkGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using AccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfAccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfHiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfMaskGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + + const int64_t total_work = batch_size * NumHeads; + const int64_t chunk_num = seq_len / ChunkSize; + const int64_t cid = get_block_idx(); + const int64_t vid = get_subblockid(); + set_ffts_base_addr(ffts_addr); + + L1Mat q_l1; + L1Mat k_l1; + L1Mat v_l1; + L1Mat h_l1; + L1Mat acc_l1; + TASSIGN(q_l1, QL1Addr); + TASSIGN(k_l1, KL1Addr); + TASSIGN(v_l1, VL1Addr); + TASSIGN(h_l1, HL1Addr); + TASSIGN(acc_l1, AccL1Addr); + + TileAcc acc_l0; + TileAcc h_l0; + TileAcc o_l0; + TASSIGN(acc_l0, SharedL0Addr); + TASSIGN(h_l0, SharedL0Addr); + TASSIGN(o_l0, SharedL0Addr); + + UbVec hsum_ub; + UbVec h_ub; + UbVec acc_ub; + UbVec mask_ub; + UbVec masked_acc_ub; + TASSIGN(hsum_ub, HsumUbAddr); + TASSIGN(acc_ub, AccUbAddr); + TASSIGN(h_ub, HUbAddr); + TASSIGN(mask_ub, MaskUbAddr); + TASSIGN(masked_acc_ub, MaskedAccUbAddr); + +#if defined(__DAV_C220_CUBE__) + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + 
++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t by = pid % NumHeads; + const int64_t bz = pid / NumHeads; + const int64_t qkv_base = ((bz * NumHeads + by) * seq_len) * HiddenSize; + const int64_t workspace1_base = cid * Workspace1Elems; + const int64_t workspace2_base = cid * Workspace2Elems; + + WaitCrossFlag(1); + + for (int64_t i = 0; i < chunk_num; ++i) { + const int64_t chunk_base = qkv_base + i * ChunkElems; + + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + HiddenGlobal h_global(workspace_2 + workspace2_base); + + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + TLOAD(h_l1, h_global); + pipe_barrier(PIPE_ALL); + + // The math is unchanged from step 03; the gain here comes from the larger + // 128x128 score tile and the memory layout changes above. + MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, v_l1, + true); + HiddenGlobal h_out_global(workspace_2 + workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(0, 2); + + WaitCrossFlag(1); + AccGlobal masked_acc_global(workspace_1 + workspace1_base); + TLOAD(acc_l1, masked_acc_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, + v_l1, true); + MatmulL1(o_l0, q_l1, + h_l1, false); + + ChunkGlobal o_global(o + chunk_base); + TSTORE(o_global, o_l0); + pipe_barrier(PIPE_ALL); + } + } +#endif + +#if defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + HalfMaskGlobal mask_global(causal_mask + vid * HalfChunk * ChunkSize); + if constexpr (PreloadMask) { + // Best case: the whole mask slice stays resident in UB across all chunks. 
+ TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t workspace1_base = + cid * Workspace1Elems + vid * HalfChunk * ChunkSize; + const int64_t workspace2_base = + cid * Workspace2Elems + vid * HalfHidden * HiddenSize; + + TEXPANDS(hsum_ub, 0.0f); + pipe_barrier(PIPE_ALL); + HalfHiddenGlobal init_h_global(workspace_2 + workspace2_base); + TSTORE(init_h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + + for (int64_t i = 0; i < chunk_num; ++i) { + WaitCrossFlag(0); + + HalfAccGlobal acc_global(workspace_1 + workspace1_base); + HalfHiddenGlobal h_global(workspace_2 + workspace2_base); + TLOAD(acc_ub, acc_global); + TLOAD(h_ub, h_global); + pipe_barrier(PIPE_ALL); + + TADD(hsum_ub, hsum_ub, h_ub); + pipe_barrier(PIPE_ALL); + if constexpr (!PreloadMask) { + // Fallback for tighter UB budgets at larger chunk sizes. 
+ TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + TMUL(masked_acc_ub, acc_ub, mask_ub); + pipe_barrier(PIPE_ALL); + TSTORE(acc_global, masked_acc_ub); + TSTORE(h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + } + } +#endif +} + +extern "C" __global__ AICORE void launch_linear_attention( + __gm__ uint8_t *q, __gm__ uint8_t *k, __gm__ uint8_t *v, + __gm__ uint8_t *workspace_1, __gm__ uint8_t *workspace_2, + __gm__ uint8_t *causal_mask, __gm__ uint8_t *o, int64_t batch_size, + int64_t seq_len, uint64_t ffts_addr) { + main_kernel( + reinterpret_cast<__gm__ half *>(q), reinterpret_cast<__gm__ half *>(k), + reinterpret_cast<__gm__ half *>(v), + reinterpret_cast<__gm__ half *>(workspace_1), + reinterpret_cast<__gm__ half *>(workspace_2), + reinterpret_cast<__gm__ half *>(causal_mask), + reinterpret_cast<__gm__ half *>(o), batch_size, seq_len, ffts_addr); +} + +extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *q, + uint8_t *k, uint8_t *v, uint8_t *workspace_1, + uint8_t *workspace_2, uint8_t *causal_mask, + uint8_t *o, + int64_t batch_size, int64_t seq_len) { + uint32_t ffts_len = 0; + uint64_t ffts_addr = 0; + rtGetC2cCtrlAddr(&ffts_addr, &ffts_len); + launch_linear_attention<<>>( + q, k, v, workspace_1, workspace_2, causal_mask, o, batch_size, seq_len, + ffts_addr); +} diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/run_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/run_linear_attention.py new file mode 100644 index 00000000..db5cd229 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/04_chunk128/run_linear_attention.py @@ -0,0 +1,33 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile +from linear_attention_shared 
import run_correctness_cases, run_dynamic_kernel + + +def run_kernel(src, q, k, v, chunk_size): + return run_dynamic_kernel( + src, + q, + k, + v, + chunk_size, + jit_compile=jit_compile, + block_dim=BLOCK_DIM, + stage_count=1, + use_mask=True, + mask_factory=get_causal_mask, + ) + + +def main(): + test_configs = [(1, 2, 64, 128, 64), (1, 2, 256, 128, 64), (4, 2, 128, 128, 64), (8, 2, 512, 128, 64), (10, 2, 512, 128, 64), (16, 2, 256, 128, 64), (32, 2, 128, 128, 64), (1, 2, 1024, 128, 64), (8, 2, 2048, 128, 64), (2, 2, 4096, 128, 64), (16, 2, 1024, 128, 64), (50, 20, 128, 128, 64), (1, 2, 128, 128, 128), (4, 2, 512, 128, 128), (16, 20, 1024, 128, 128)] + run_correctness_cases(__file__, test_configs, run_kernel) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/README.md b/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/README.md new file mode 100644 index 00000000..91b5023b --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/README.md @@ -0,0 +1,7 @@ +# Step 05: L0 Double Buffer + +This step corresponds to commit `7b811b0`. 
+ +What changed: +- the cube-side matmul helper splits `K=128` into `2 x 64` phases +- L0A/L0B ping-pong buffering hides part of the extract latency behind cube compute diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/benchmark_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/benchmark_linear_attention.py new file mode 100644 index 00000000..eb3890ea --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/benchmark_linear_attention.py @@ -0,0 +1,49 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile +from linear_attention_shared import benchmark_cli, benchmark_dynamic_kernel + +DEFAULT_SHAPES = [(16, 20, 1024, 128, 128), (16, 20, 2048, 128, 128), (32, 20, 1024, 128, 128), (8, 20, 4096, 128, 128)] + +QUICK_SHAPES = [(8, 20, 1024, 128, 128), (16, 20, 1024, 128, 128)] + +THROUGHPUT_HUNT_SHAPES = [(24, 20, 2048, 128, 128), (48, 20, 1024, 128, 128), (12, 20, 8192, 128, 128), (24, 20, 1536, 128, 128)] + + + +def benchmark_shape(src: str, *, batch: int, heads: int, seq_len: int, hidden: int, chunk: int, warmup: int, repeats: int): + return benchmark_dynamic_kernel( + src, + batch=batch, + heads=heads, + seq_len=seq_len, + hidden=hidden, + chunk=chunk, + warmup=warmup, + repeats=repeats, + jit_compile=jit_compile, + block_dim=BLOCK_DIM, + stage_count=1, + use_mask=True, + include_workspace_bytes=False, + mask_factory=get_causal_mask, + ) + + +def main(): + benchmark_cli( + script_file=__file__, + default_shapes=DEFAULT_SHAPES, + quick_shapes=QUICK_SHAPES, + benchmark_shape=benchmark_shape, + throughput_hunt_shapes=THROUGHPUT_HUNT_SHAPES + ) + + +if __name__ == "__main__": + main() diff --git 
a/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/jit_util_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/jit_util_linear_attention.py new file mode 100644 index 00000000..e2870489 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/jit_util_linear_attention.py @@ -0,0 +1,60 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from functools import lru_cache + +from jit_shared import BLOCK_DIM, OPTIMIZED_KERNEL_FLAGS, compile_cpp as shared_compile_cpp +from jit_shared import get_causal_mask, load_dynamic_mask_lib + + +def compile_cpp( + kernel_cpp: str, + num_heads: int, + hidden_size: int, + chunk_size: int, + verbose: bool = False, + timeout: int = 180, +) -> str: + return shared_compile_cpp( + kernel_cpp, + output_name=f"linear_attention_H{num_heads}_D{hidden_size}_C{chunk_size}_jit.so", + std="gnu++17", + defines=[ + f"-DLINEAR_ATTN_H={num_heads}", + f"-DLINEAR_ATTN_D={hidden_size}", + f"-DLINEAR_ATTN_C={chunk_size}", + ], + extra_flags=OPTIMIZED_KERNEL_FLAGS, + verbose=verbose, + timeout=timeout, + ) + + +def load_lib(lib_path: str): + return load_dynamic_mask_lib(lib_path) + + +@lru_cache(maxsize=None) +def jit_compile( + src_path: str, + num_heads: int, + hidden_size: int, + chunk_size: int, + verbose: bool = True, + clean_up: bool = False, +): + lib_path = compile_cpp( + src_path, + num_heads=num_heads, + hidden_size=hidden_size, + chunk_size=chunk_size, + verbose=verbose, + ) + func = load_lib(lib_path) + if clean_up: + Path(lib_path).unlink(missing_ok=True) + return func diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/linear_attention.cpp b/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/linear_attention.cpp new file mode 100644 index 
00000000..639da86d --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/linear_attention.cpp @@ -0,0 +1,408 @@ +#include +#include +#include +#include + +using namespace pto; + +// Step 05 keeps the step-04 outer schedule, but upgrades the cube microkernel. +// When K=128, the cube core now alternates between two L0 buffers so that the +// next 64-wide slice can be extracted while the current slice is being used. + +#ifndef LINEAR_ATTN_H +#define LINEAR_ATTN_H 2 +#endif + +#ifndef LINEAR_ATTN_D +#define LINEAR_ATTN_D 128 +#endif + +#ifndef LINEAR_ATTN_C +#define LINEAR_ATTN_C 64 +#endif + +template +using L1Mat = Tile; + +template +using L1MatTrans = + Tile; + +template +using UbVec = Tile; + +template +AICORE inline void SetCrossFlag(int32_t flag, int32_t mode) { + const int config = 1 | (mode << 4) | (flag << 8); + ffts_cross_core_sync(Pipe, config); +} + +AICORE inline void WaitCrossFlag(int32_t flag) { wait_flag_dev(flag); } + +template +AICORE inline void SetFlag(uint32_t id) { + set_flag(Src, Dst, static_cast(id)); +} + +template +AICORE inline void WaitFlag(uint32_t id) { + wait_flag(Src, Dst, static_cast(id)); +} + +template +AICORE inline void MatmulL1( + TileAcc &dst, + std::conditional_t, L1Mat> &a_l1, + std::conditional_t, L1Mat> &b_l1, + bool init) { + if constexpr ((K % 64 == 0) && (K > 64)) { + // New in step 05: split the K dimension into 64-wide pieces and ping-pong + // between two L0 buffers. This is the first "real" pipeline inside a + // helper function, but the mathematical result is still one GEMM. 
+ constexpr int KStep = 64; + constexpr int Parts = K / KStep; + constexpr uintptr_t AStepBytes = M * KStep * sizeof(half); + constexpr uintptr_t BStepBytes = KStep * N * sizeof(half); + + TileLeft a_l0[2]; + TileRight b_l0[2]; + TASSIGN(a_l0[0], static_cast(0)); + TASSIGN(a_l0[1], AStepBytes); + TASSIGN(b_l0[0], static_cast(0)); + TASSIGN(b_l0[1], BStepBytes); + + // These flags hand ownership of each L0 buffer back and forth between the + // extractor (MTE1) and the cube compute pipeline (M). + SetFlag(0); + SetFlag(1); + + for (int part = 0; part < Parts; ++part) { + const int buf = part & 1; + WaitFlag(buf); + + if constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0[buf], a_view, 0, part * KStep); + } else { + TEXTRACT(a_l0[buf], a_l1, 0, part * KStep); + } + + if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0[buf], b_view, part * KStep, 0); + } else { + TEXTRACT(b_l0[buf], b_l1, part * KStep, 0); + } + + SetFlag(buf); + WaitFlag(buf); + + if (init && part == 0) { + TMATMUL(dst, a_l0[buf], b_l0[buf]); + } else { + TMATMUL_ACC(dst, dst, a_l0[buf], b_l0[buf]); + } + + SetFlag(buf); + } + + WaitFlag(0); + WaitFlag(1); + pipe_barrier(PIPE_ALL); + } else { + // Small-K fallback: keep the simpler one-shot path for readability. 
+ TileLeft a_l0; + TileRight b_l0; + TASSIGN(a_l0, 0x0); + TASSIGN(b_l0, 0x0); + + if constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0, a_view, 0, 0); + } else { + TEXTRACT(a_l0, a_l1, 0, 0); + } + + if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0, b_view, 0, 0); + } else { + TEXTRACT(b_l0, b_l1, 0, 0); + } + + pipe_barrier(PIPE_ALL); + if (init) { + TMATMUL(dst, a_l0, b_l0); + } else { + TMATMUL_ACC(dst, dst, a_l0, b_l0); + } + pipe_barrier(PIPE_ALL); + } +} + +template +AICORE void main_kernel(__gm__ half *q, __gm__ half *k, __gm__ half *v, + __gm__ half *workspace_1, __gm__ half *workspace_2, + __gm__ half *causal_mask, __gm__ half *o, + int64_t batch_size, int64_t seq_len, + uint64_t ffts_addr) { + constexpr int32_t VecNum = 2; + constexpr int32_t HalfChunk = ChunkSize / VecNum; + constexpr int32_t HalfHidden = HiddenSize / VecNum; + constexpr int32_t ChunkElems = ChunkSize * HiddenSize; + constexpr int32_t Workspace1Elems = ChunkSize * ChunkSize; + constexpr int32_t Workspace2Elems = HiddenSize * HiddenSize; + + constexpr int32_t QL1Addr = 0; + constexpr int32_t KL1Addr = QL1Addr + ChunkElems * sizeof(half); + constexpr int32_t VL1Addr = KL1Addr + ChunkElems * sizeof(half); + constexpr int32_t HL1Addr = VL1Addr + ChunkElems * sizeof(half); + constexpr int32_t AccL1Addr = HL1Addr + Workspace2Elems * sizeof(half); + + constexpr int32_t SharedL0Addr = 0; + + constexpr int32_t HsumUbAddr = 0; + constexpr int32_t AccUbAddr = + HsumUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t HUbAddr = AccUbAddr + HalfChunk * ChunkSize * sizeof(half); + constexpr int32_t RawUBBytes = + (HalfHidden * HiddenSize + HalfChunk * ChunkSize + HalfHidden * HiddenSize + + HalfChunk * ChunkSize + HalfChunk * ChunkSize) * + sizeof(half); + constexpr bool PreloadMask = RawUBBytes <= 72 * 1024; + constexpr bool AliasMaskIntoH = + !PreloadMask && (HalfHidden * HiddenSize >= HalfChunk * 
ChunkSize); + constexpr int32_t MaskUbAddr = + AliasMaskIntoH ? HUbAddr : HUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t MaskedAccUbAddr = + MaskUbAddr + HalfChunk * ChunkSize * sizeof(half); + + constexpr int32_t L0CBytes = + (Workspace2Elems > Workspace1Elems + ? (Workspace2Elems > ChunkElems ? Workspace2Elems : ChunkElems) + : (Workspace1Elems > ChunkElems ? Workspace1Elems : ChunkElems)) * + sizeof(float); + constexpr int32_t UBBytes = + (HalfHidden * HiddenSize + HalfChunk * ChunkSize + + (AliasMaskIntoH ? HalfHidden * HiddenSize + : HalfHidden * HiddenSize + HalfChunk * ChunkSize) + + HalfChunk * ChunkSize) * + sizeof(half); + static_assert((HiddenSize % 2) == 0, "HiddenSize must be even."); + static_assert((ChunkSize % 2) == 0, "ChunkSize must be even."); + static_assert(L0CBytes <= 112 * 1024, + "Tile sizes exceed the validated L0C budget for this minimum kernel."); + static_assert(PreloadMask || AliasMaskIntoH, + "Current minimum kernel requires either a preloaded mask or H UB large enough to alias the mask."); + static_assert(UBBytes <= 72 * 1024, + "Tile sizes exceed the validated UB budget for this minimum kernel."); + + using ChunkGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using AccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfAccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfHiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfMaskGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + + const int64_t total_work = batch_size * NumHeads; + const int64_t chunk_num = seq_len / ChunkSize; + const int64_t cid = get_block_idx(); + const int64_t vid = get_subblockid(); + set_ffts_base_addr(ffts_addr); + + L1Mat q_l1; + L1Mat k_l1; + L1Mat v_l1; + L1Mat h_l1; + L1Mat acc_l1; + TASSIGN(q_l1, QL1Addr); + TASSIGN(k_l1, KL1Addr); + TASSIGN(v_l1, VL1Addr); + TASSIGN(h_l1, HL1Addr); + TASSIGN(acc_l1, 
AccL1Addr); + + TileAcc acc_l0; + TileAcc h_l0; + TileAcc o_l0; + TASSIGN(acc_l0, SharedL0Addr); + TASSIGN(h_l0, SharedL0Addr); + TASSIGN(o_l0, SharedL0Addr); + + UbVec hsum_ub; + UbVec h_ub; + UbVec acc_ub; + UbVec mask_ub; + UbVec masked_acc_ub; + TASSIGN(hsum_ub, HsumUbAddr); + TASSIGN(acc_ub, AccUbAddr); + TASSIGN(h_ub, HUbAddr); + TASSIGN(mask_ub, MaskUbAddr); + TASSIGN(masked_acc_ub, MaskedAccUbAddr); + +#if defined(__DAV_C220_CUBE__) + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t by = pid % NumHeads; + const int64_t bz = pid / NumHeads; + const int64_t qkv_base = ((bz * NumHeads + by) * seq_len) * HiddenSize; + const int64_t workspace1_base = cid * Workspace1Elems; + const int64_t workspace2_base = cid * Workspace2Elems; + + WaitCrossFlag(1); + + for (int64_t i = 0; i < chunk_num; ++i) { + const int64_t chunk_base = qkv_base + i * ChunkElems; + + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + HiddenGlobal h_global(workspace_2 + workspace2_base); + + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + TLOAD(h_l1, h_global); + pipe_barrier(PIPE_ALL); + + // Compared with step 04, the surrounding code barely changes. The speedup + // mainly comes from the improved MatmulL1 helper above. 
+ MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, v_l1, + true); + HiddenGlobal h_out_global(workspace_2 + workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(0, 2); + + WaitCrossFlag(1); + AccGlobal masked_acc_global(workspace_1 + workspace1_base); + TLOAD(acc_l1, masked_acc_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, + v_l1, true); + MatmulL1(o_l0, q_l1, + h_l1, false); + + ChunkGlobal o_global(o + chunk_base); + TSTORE(o_global, o_l0); + pipe_barrier(PIPE_ALL); + } + } +#endif + +#if defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + HalfMaskGlobal mask_global(causal_mask + vid * HalfChunk * ChunkSize); + if constexpr (PreloadMask) { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t workspace1_base = + cid * Workspace1Elems + vid * HalfChunk * ChunkSize; + const int64_t workspace2_base = + cid * Workspace2Elems + vid * HalfHidden * HiddenSize; + + TEXPANDS(hsum_ub, 0.0f); + pipe_barrier(PIPE_ALL); + HalfHiddenGlobal init_h_global(workspace_2 + workspace2_base); + TSTORE(init_h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + + for (int64_t i = 0; i < chunk_num; ++i) { + WaitCrossFlag(0); + + HalfAccGlobal acc_global(workspace_1 + workspace1_base); + HalfHiddenGlobal h_global(workspace_2 + workspace2_base); + TLOAD(acc_ub, acc_global); + TLOAD(h_ub, h_global); + pipe_barrier(PIPE_ALL); + + TADD(hsum_ub, hsum_ub, h_ub); + pipe_barrier(PIPE_ALL); + if constexpr (!PreloadMask) { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + TMUL(masked_acc_ub, acc_ub, mask_ub); + pipe_barrier(PIPE_ALL); + TSTORE(acc_global, 
masked_acc_ub); + TSTORE(h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + } + } +#endif +} + +extern "C" __global__ AICORE void launch_linear_attention( + __gm__ uint8_t *q, __gm__ uint8_t *k, __gm__ uint8_t *v, + __gm__ uint8_t *workspace_1, __gm__ uint8_t *workspace_2, + __gm__ uint8_t *causal_mask, __gm__ uint8_t *o, int64_t batch_size, + int64_t seq_len, uint64_t ffts_addr) { + main_kernel( + reinterpret_cast<__gm__ half *>(q), reinterpret_cast<__gm__ half *>(k), + reinterpret_cast<__gm__ half *>(v), + reinterpret_cast<__gm__ half *>(workspace_1), + reinterpret_cast<__gm__ half *>(workspace_2), + reinterpret_cast<__gm__ half *>(causal_mask), + reinterpret_cast<__gm__ half *>(o), batch_size, seq_len, ffts_addr); +} + +extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *q, + uint8_t *k, uint8_t *v, uint8_t *workspace_1, + uint8_t *workspace_2, uint8_t *causal_mask, + uint8_t *o, + int64_t batch_size, int64_t seq_len) { + uint32_t ffts_len = 0; + uint64_t ffts_addr = 0; + rtGetC2cCtrlAddr(&ffts_addr, &ffts_len); + launch_linear_attention<<>>( + q, k, v, workspace_1, workspace_2, causal_mask, o, batch_size, seq_len, + ffts_addr); +} diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/run_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/run_linear_attention.py new file mode 100644 index 00000000..db5cd229 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/05_l0_double_buffer/run_linear_attention.py @@ -0,0 +1,33 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile +from linear_attention_shared import run_correctness_cases, run_dynamic_kernel + + +def run_kernel(src, q, k, v, chunk_size): + return run_dynamic_kernel( 
+ src, + q, + k, + v, + chunk_size, + jit_compile=jit_compile, + block_dim=BLOCK_DIM, + stage_count=1, + use_mask=True, + mask_factory=get_causal_mask, + ) + + +def main(): + test_configs = [(1, 2, 64, 128, 64), (1, 2, 256, 128, 64), (4, 2, 128, 128, 64), (8, 2, 512, 128, 64), (10, 2, 512, 128, 64), (16, 2, 256, 128, 64), (32, 2, 128, 128, 64), (1, 2, 1024, 128, 64), (8, 2, 2048, 128, 64), (2, 2, 4096, 128, 64), (16, 2, 1024, 128, 64), (50, 20, 128, 128, 64), (1, 2, 128, 128, 128), (4, 2, 512, 128, 128), (16, 20, 1024, 128, 128)] + run_correctness_cases(__file__, test_configs, run_kernel) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/README.md b/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/README.md new file mode 100644 index 00000000..267a30b5 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/README.md @@ -0,0 +1,8 @@ +# Step 06: Two-Slot Cube-Vector Pipeline + +This step corresponds to commit `3350511`. 
+ +What changed: +- the per-core workspaces are doubled to two slots +- cube can prepare chunk `i + 1` while vector finishes chunk `i` +- explicit cross-core handshakes keep the staged buffers safe when `B * H` exceeds the core count diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/benchmark_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/benchmark_linear_attention.py new file mode 100644 index 00000000..a658fa8b --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/benchmark_linear_attention.py @@ -0,0 +1,49 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile +from linear_attention_shared import benchmark_cli, benchmark_dynamic_kernel + +DEFAULT_SHAPES = [(16, 20, 1024, 128, 128), (16, 20, 2048, 128, 128), (32, 20, 1024, 128, 128), (8, 20, 4096, 128, 128)] + +QUICK_SHAPES = [(8, 20, 1024, 128, 128), (16, 20, 1024, 128, 128)] + +THROUGHPUT_HUNT_SHAPES = [(24, 20, 2048, 128, 128), (48, 20, 1024, 128, 128), (12, 20, 8192, 128, 128), (24, 20, 1536, 128, 128)] + + + +def benchmark_shape(src: str, *, batch: int, heads: int, seq_len: int, hidden: int, chunk: int, warmup: int, repeats: int): + return benchmark_dynamic_kernel( + src, + batch=batch, + heads=heads, + seq_len=seq_len, + hidden=hidden, + chunk=chunk, + warmup=warmup, + repeats=repeats, + jit_compile=jit_compile, + block_dim=BLOCK_DIM, + stage_count=2, + use_mask=True, + include_workspace_bytes=False, + mask_factory=get_causal_mask, + ) + + +def main(): + benchmark_cli( + script_file=__file__, + default_shapes=DEFAULT_SHAPES, + quick_shapes=QUICK_SHAPES, + benchmark_shape=benchmark_shape, + throughput_hunt_shapes=THROUGHPUT_HUNT_SHAPES + ) + + +if __name__ == 
"__main__": + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/jit_util_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/jit_util_linear_attention.py new file mode 100644 index 00000000..e2870489 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/jit_util_linear_attention.py @@ -0,0 +1,60 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from functools import lru_cache + +from jit_shared import BLOCK_DIM, OPTIMIZED_KERNEL_FLAGS, compile_cpp as shared_compile_cpp +from jit_shared import get_causal_mask, load_dynamic_mask_lib + + +def compile_cpp( + kernel_cpp: str, + num_heads: int, + hidden_size: int, + chunk_size: int, + verbose: bool = False, + timeout: int = 180, +) -> str: + return shared_compile_cpp( + kernel_cpp, + output_name=f"linear_attention_H{num_heads}_D{hidden_size}_C{chunk_size}_jit.so", + std="gnu++17", + defines=[ + f"-DLINEAR_ATTN_H={num_heads}", + f"-DLINEAR_ATTN_D={hidden_size}", + f"-DLINEAR_ATTN_C={chunk_size}", + ], + extra_flags=OPTIMIZED_KERNEL_FLAGS, + verbose=verbose, + timeout=timeout, + ) + + +def load_lib(lib_path: str): + return load_dynamic_mask_lib(lib_path) + + +@lru_cache(maxsize=None) +def jit_compile( + src_path: str, + num_heads: int, + hidden_size: int, + chunk_size: int, + verbose: bool = True, + clean_up: bool = False, +): + lib_path = compile_cpp( + src_path, + num_heads=num_heads, + hidden_size=hidden_size, + chunk_size=chunk_size, + verbose=verbose, + ) + func = load_lib(lib_path) + if clean_up: + Path(lib_path).unlink(missing_ok=True) + return func diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/linear_attention.cpp 
b/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/linear_attention.cpp new file mode 100644 index 00000000..49d01e95 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/linear_attention.cpp @@ -0,0 +1,557 @@ +#include +#include +#include +#include + +using namespace pto; + +// Step 06 keeps the step-05 cube microkernel and adds a larger structural +// pipeline: cube works on chunk i+1 while vector finishes chunk i. The two +// workspace "slots" below are the staging area that makes this overlap safe. + +#ifndef LINEAR_ATTN_H +#define LINEAR_ATTN_H 2 +#endif + +#ifndef LINEAR_ATTN_D +#define LINEAR_ATTN_D 128 +#endif + +#ifndef LINEAR_ATTN_C +#define LINEAR_ATTN_C 64 +#endif + +template +using L1Mat = Tile; + +template +using L1MatTrans = + Tile; + +template +using UbVec = Tile; + +template +AICORE inline void SetCrossFlag(int32_t flag, int32_t mode) { + const int config = 1 | (mode << 4) | (flag << 8); + ffts_cross_core_sync(Pipe, config); +} + +AICORE inline void WaitCrossFlag(int32_t flag) { wait_flag_dev(flag); } + +template +AICORE inline void SetFlag(uint32_t id) { + set_flag(Src, Dst, static_cast(id)); +} + +template +AICORE inline void WaitFlag(uint32_t id) { + wait_flag(Src, Dst, static_cast(id)); +} + +template +AICORE inline void MatmulL1( + TileAcc &dst, + std::conditional_t, L1Mat> &a_l1, + std::conditional_t, L1Mat> &b_l1, + bool init) { + if constexpr ((K % 64 == 0) && (K > 64)) { + constexpr int KStep = 64; + constexpr int Parts = K / KStep; + constexpr uintptr_t AStepBytes = M * KStep * sizeof(half); + constexpr uintptr_t BStepBytes = KStep * N * sizeof(half); + + TileLeft a_l0[2]; + TileRight b_l0[2]; + TASSIGN(a_l0[0], static_cast(0)); + TASSIGN(a_l0[1], AStepBytes); + TASSIGN(b_l0[0], static_cast(0)); + TASSIGN(b_l0[1], BStepBytes); + + SetFlag(0); + SetFlag(1); + + for (int part = 0; part < Parts; ++part) { + const int buf = part & 1; + WaitFlag(buf); + + if 
constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0[buf], a_view, 0, part * KStep); + } else { + TEXTRACT(a_l0[buf], a_l1, 0, part * KStep); + } + + if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0[buf], b_view, part * KStep, 0); + } else { + TEXTRACT(b_l0[buf], b_l1, part * KStep, 0); + } + + SetFlag(buf); + WaitFlag(buf); + + if (init && part == 0) { + TMATMUL(dst, a_l0[buf], b_l0[buf]); + } else { + TMATMUL_ACC(dst, dst, a_l0[buf], b_l0[buf]); + } + + SetFlag(buf); + } + + WaitFlag(0); + WaitFlag(1); + pipe_barrier(PIPE_ALL); + } else { + TileLeft a_l0; + TileRight b_l0; + TASSIGN(a_l0, 0x0); + TASSIGN(b_l0, 0x0); + + if constexpr (TransposeA) { + L1MatTrans a_view; + TRESHAPE(a_view, a_l1); + TEXTRACT(a_l0, a_view, 0, 0); + } else { + TEXTRACT(a_l0, a_l1, 0, 0); + } + + if constexpr (TransposeB) { + L1MatTrans b_view; + TRESHAPE(b_view, b_l1); + TEXTRACT(b_l0, b_view, 0, 0); + } else { + TEXTRACT(b_l0, b_l1, 0, 0); + } + + pipe_barrier(PIPE_ALL); + if (init) { + TMATMUL(dst, a_l0, b_l0); + } else { + TMATMUL_ACC(dst, dst, a_l0, b_l0); + } + pipe_barrier(PIPE_ALL); + } +} + +template +AICORE void main_kernel(__gm__ half *q, __gm__ half *k, __gm__ half *v, + __gm__ half *workspace_1, __gm__ half *workspace_2, + __gm__ half *causal_mask, __gm__ half *o, + int64_t batch_size, int64_t seq_len, + uint64_t ffts_addr) { + // Two slots are enough for a producer/consumer pipeline: + // one slot is owned by cube, the other by vector. 
+ constexpr int32_t StageCount = 2; + constexpr bool UseTwoStagePipeline = (ChunkSize >= 128); + constexpr int32_t VecNum = 2; + constexpr int32_t HalfChunk = ChunkSize / VecNum; + constexpr int32_t HalfHidden = HiddenSize / VecNum; + constexpr int32_t ChunkElems = ChunkSize * HiddenSize; + constexpr int32_t Workspace1SlotElems = ChunkSize * ChunkSize; + constexpr int32_t Workspace2SlotElems = HiddenSize * HiddenSize; + constexpr int32_t Workspace1Elems = StageCount * Workspace1SlotElems; + constexpr int32_t Workspace2Elems = StageCount * Workspace2SlotElems; + + constexpr int32_t QL1Addr = 0; + constexpr int32_t KL1Addr = QL1Addr + ChunkElems * sizeof(half); + constexpr int32_t VL1Addr = KL1Addr + ChunkElems * sizeof(half); + constexpr int32_t HL1Addr = VL1Addr + ChunkElems * sizeof(half); + constexpr int32_t AccL1Addr = HL1Addr + Workspace2SlotElems * sizeof(half); + + constexpr int32_t SharedL0Addr = 0; + + constexpr int32_t HsumUbAddr = 0; + constexpr int32_t AccUbAddr = + HsumUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t HUbAddr = AccUbAddr + HalfChunk * ChunkSize * sizeof(half); + constexpr int32_t RawUBBytes = + (HalfHidden * HiddenSize + HalfChunk * ChunkSize + HalfHidden * HiddenSize + + HalfChunk * ChunkSize + HalfChunk * ChunkSize) * + sizeof(half); + constexpr bool PreloadMask = RawUBBytes <= 72 * 1024; + constexpr bool AliasMaskIntoH = + !PreloadMask && (HalfHidden * HiddenSize >= HalfChunk * ChunkSize); + constexpr int32_t MaskUbAddr = + AliasMaskIntoH ? HUbAddr : HUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t MaskedAccUbAddr = + MaskUbAddr + HalfChunk * ChunkSize * sizeof(half); + + constexpr int32_t L0CBytes = + (Workspace2SlotElems > Workspace1SlotElems + ? (Workspace2SlotElems > ChunkElems ? Workspace2SlotElems : ChunkElems) + : (Workspace1SlotElems > ChunkElems ? 
Workspace1SlotElems : ChunkElems)) * + sizeof(float); + constexpr int32_t UBBytes = + (HalfHidden * HiddenSize + HalfChunk * ChunkSize + + (AliasMaskIntoH ? HalfHidden * HiddenSize + : HalfHidden * HiddenSize + HalfChunk * ChunkSize) + + HalfChunk * ChunkSize) * + sizeof(half); + static_assert((HiddenSize % 2) == 0, "HiddenSize must be even."); + static_assert((ChunkSize % 2) == 0, "ChunkSize must be even."); + static_assert(L0CBytes <= 112 * 1024, + "Tile sizes exceed the validated L0C budget for this minimum kernel."); + static_assert(PreloadMask || AliasMaskIntoH, + "Current minimum kernel requires either a preloaded mask or H UB large enough to alias the mask."); + static_assert(UBBytes <= 72 * 1024, + "Tile sizes exceed the validated UB budget for this minimum kernel."); + + using ChunkGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using AccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfAccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfHiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfMaskGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + + const int64_t total_work = batch_size * NumHeads; + const int64_t chunk_num = seq_len / ChunkSize; + const int64_t cid = get_block_idx(); + const int64_t vid = get_subblockid(); + set_ffts_base_addr(ffts_addr); + + L1Mat q_l1; + L1Mat k_l1; + L1Mat v_l1; + L1Mat h_l1; + L1Mat acc_l1; + TASSIGN(q_l1, QL1Addr); + TASSIGN(k_l1, KL1Addr); + TASSIGN(v_l1, VL1Addr); + TASSIGN(h_l1, HL1Addr); + TASSIGN(acc_l1, AccL1Addr); + + TileAcc acc_l0; + TileAcc h_l0; + TileAcc o_l0; + TASSIGN(acc_l0, SharedL0Addr); + TASSIGN(h_l0, SharedL0Addr); + TASSIGN(o_l0, SharedL0Addr); + + UbVec hsum_ub; + UbVec h_ub; + UbVec acc_ub; + UbVec mask_ub; + UbVec masked_acc_ub; + TASSIGN(hsum_ub, HsumUbAddr); + TASSIGN(acc_ub, AccUbAddr); + TASSIGN(h_ub, HUbAddr); + TASSIGN(mask_ub, MaskUbAddr); + 
TASSIGN(masked_acc_ub, MaskedAccUbAddr); + +#if defined(__DAV_C220_CUBE__) + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t by = pid % NumHeads; + const int64_t bz = pid / NumHeads; + const int64_t qkv_base = ((bz * NumHeads + by) * seq_len) * HiddenSize; + const int64_t workspace1_base = cid * Workspace1Elems; + const int64_t workspace2_base = cid * Workspace2Elems; + + if constexpr (UseTwoStagePipeline) { + // Each in-flight work item gets its own small ring of cross-core flags so + // different logical jobs do not accidentally wake each other up. + const int32_t flag_base = static_cast((work_idx & 3) * 6); + WaitCrossFlag(flag_base + 4); + HiddenGlobal zero_h_global(workspace_2 + workspace2_base + Workspace2SlotElems); + TLOAD(h_l1, zero_h_global); + pipe_barrier(PIPE_ALL); + + { + // Prefill slot 0 so the sliding-window pipeline has an initial chunk to + // consume before it starts looking one chunk ahead. 
+ const int64_t chunk_base = qkv_base; + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, v_l1, + true); + HiddenGlobal h_out_global(workspace_2 + workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base, 2); + } + + for (int64_t i = 0; i < chunk_num; ++i) { + const int32_t slot = static_cast(i & 1); + const int32_t next_slot = slot ^ 1; + const int64_t chunk_base = qkv_base + i * ChunkElems; + + if (i + 1 < chunk_num) { + // Producer side: cube prepares chunk i+1 in the other slot while + // vector is still busy with slot "slot". + const int64_t next_chunk_base = qkv_base + (i + 1) * ChunkElems; + const int64_t next_workspace1_base = + workspace1_base + next_slot * Workspace1SlotElems; + const int64_t next_workspace2_base = + workspace2_base + next_slot * Workspace2SlotElems; + + ChunkGlobal q_global(q + next_chunk_base); + ChunkGlobal k_global(k + next_chunk_base); + ChunkGlobal v_global(v + next_chunk_base); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + pipe_barrier(PIPE_ALL); + + MatmulL1( + acc_l0, q_l1, k_l1, true); + AccGlobal acc_global(workspace_1 + next_workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, + v_l1, true); + HiddenGlobal h_out_global(workspace_2 + next_workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + next_slot, 2); + } + + // Consumer hand-off: do not reuse this slot until vector reports that + // masking and state accumulation for the slot are finished. 
+ WaitCrossFlag(flag_base + 2 + slot); + AccGlobal masked_acc_global(workspace_1 + workspace1_base + + slot * Workspace1SlotElems); + TLOAD(acc_l1, masked_acc_global); + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal v_global(v + chunk_base); + TLOAD(q_l1, q_global); + TLOAD(v_l1, v_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, + v_l1, true); + MatmulL1(o_l0, q_l1, h_l1, + false); + + ChunkGlobal o_global(o + chunk_base); + TSTORE(o_global, o_l0); + pipe_barrier(PIPE_ALL); + + if (i + 1 < chunk_num) { + // Load the next prefix state after vector has written it back for this + // slot; that state will be needed when chunk i+1 reaches output stage. + HiddenGlobal next_h_global(workspace_2 + workspace2_base + + slot * Workspace2SlotElems); + TLOAD(h_l1, next_h_global); + pipe_barrier(PIPE_ALL); + } + } + SetCrossFlag(flag_base + 5, 2); + } else { + WaitCrossFlag(1); + + for (int64_t i = 0; i < chunk_num; ++i) { + const int64_t chunk_base = qkv_base + i * ChunkElems; + + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + HiddenGlobal h_global(workspace_2 + workspace2_base); + + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + TLOAD(h_l1, h_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, v_l1, + true); + HiddenGlobal h_out_global(workspace_2 + workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(0, 2); + + WaitCrossFlag(1); + AccGlobal masked_acc_global(workspace_1 + workspace1_base); + TLOAD(acc_l1, masked_acc_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, + v_l1, true); + MatmulL1(o_l0, q_l1, h_l1, + false); + + ChunkGlobal o_global(o + chunk_base); + TSTORE(o_global, o_l0); + pipe_barrier(PIPE_ALL); + } + } + } +#endif + +#if 
defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + HalfMaskGlobal mask_global(causal_mask + vid * HalfChunk * ChunkSize); + if constexpr (PreloadMask) { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t workspace1_base = cid * Workspace1Elems; + const int64_t workspace2_base = cid * Workspace2Elems; + + TEXPANDS(hsum_ub, 0.0f); + pipe_barrier(PIPE_ALL); + if constexpr (UseTwoStagePipeline) { + const int32_t flag_base = static_cast((work_idx & 3) * 6); + HalfHiddenGlobal init_h_global_0(workspace_2 + workspace2_base + + vid * HalfHidden * HiddenSize); + HalfHiddenGlobal init_h_global_1(workspace_2 + workspace2_base + + Workspace2SlotElems + + vid * HalfHidden * HiddenSize); + TSTORE(init_h_global_0, hsum_ub); + TSTORE(init_h_global_1, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + 4, 2); + + for (int64_t i = 0; i < chunk_num; ++i) { + const int32_t slot = static_cast(i & 1); + WaitCrossFlag(flag_base + slot); + + const int64_t slot_workspace1_base = + workspace1_base + slot * Workspace1SlotElems; + const int64_t slot_workspace2_base = + workspace2_base + slot * Workspace2SlotElems; + HalfAccGlobal acc_global(workspace_1 + slot_workspace1_base + + vid * HalfChunk * ChunkSize); + HalfHiddenGlobal h_global(workspace_2 + slot_workspace2_base + + vid * HalfHidden * HiddenSize); + TLOAD(acc_ub, acc_global); + TLOAD(h_ub, h_global); + pipe_barrier(PIPE_ALL); + + // Vector consumes whichever slot cube just produced, updates the running + // hidden state, and then releases that slot back to cube. 
+ TADD(hsum_ub, hsum_ub, h_ub); + pipe_barrier(PIPE_ALL); + if constexpr (!PreloadMask) { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + TMUL(masked_acc_ub, acc_ub, mask_ub); + pipe_barrier(PIPE_ALL); + TSTORE(acc_global, masked_acc_ub); + TSTORE(h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + 2 + slot, 2); + } + WaitCrossFlag(flag_base + 5); + } else { + HalfHiddenGlobal init_h_global(workspace_2 + workspace2_base + + vid * HalfHidden * HiddenSize); + TSTORE(init_h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + + for (int64_t i = 0; i < chunk_num; ++i) { + WaitCrossFlag(0); + + HalfAccGlobal acc_global(workspace_1 + workspace1_base + + vid * HalfChunk * ChunkSize); + HalfHiddenGlobal h_global(workspace_2 + workspace2_base + + vid * HalfHidden * HiddenSize); + TLOAD(acc_ub, acc_global); + TLOAD(h_ub, h_global); + pipe_barrier(PIPE_ALL); + + TADD(hsum_ub, hsum_ub, h_ub); + pipe_barrier(PIPE_ALL); + if constexpr (!PreloadMask) { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + TMUL(masked_acc_ub, acc_ub, mask_ub); + pipe_barrier(PIPE_ALL); + TSTORE(acc_global, masked_acc_ub); + TSTORE(h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + } + } + } +#endif +} + +extern "C" __global__ AICORE void launch_linear_attention( + __gm__ uint8_t *q, __gm__ uint8_t *k, __gm__ uint8_t *v, + __gm__ uint8_t *workspace_1, __gm__ uint8_t *workspace_2, + __gm__ uint8_t *causal_mask, __gm__ uint8_t *o, int64_t batch_size, + int64_t seq_len, uint64_t ffts_addr) { + main_kernel( + reinterpret_cast<__gm__ half *>(q), reinterpret_cast<__gm__ half *>(k), + reinterpret_cast<__gm__ half *>(v), + reinterpret_cast<__gm__ half *>(workspace_1), + reinterpret_cast<__gm__ half *>(workspace_2), + reinterpret_cast<__gm__ half *>(causal_mask), + reinterpret_cast<__gm__ half *>(o), batch_size, seq_len, ffts_addr); +} + +extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *q, + uint8_t *k, 
uint8_t *v, uint8_t *workspace_1, + uint8_t *workspace_2, uint8_t *causal_mask, + uint8_t *o, + int64_t batch_size, int64_t seq_len) { + uint32_t ffts_len = 0; + uint64_t ffts_addr = 0; + rtGetC2cCtrlAddr(&ffts_addr, &ffts_len); + launch_linear_attention<<>>( + q, k, v, workspace_1, workspace_2, causal_mask, o, batch_size, seq_len, + ffts_addr); +} diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/run_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/run_linear_attention.py new file mode 100644 index 00000000..db8b9633 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/06_two_slot_cv_pipeline/run_linear_attention.py @@ -0,0 +1,33 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile +from linear_attention_shared import run_correctness_cases, run_dynamic_kernel + + +def run_kernel(src, q, k, v, chunk_size): + return run_dynamic_kernel( + src, + q, + k, + v, + chunk_size, + jit_compile=jit_compile, + block_dim=BLOCK_DIM, + stage_count=2, + use_mask=True, + mask_factory=get_causal_mask, + ) + + +def main(): + test_configs = [(1, 2, 64, 128, 64), (1, 2, 256, 128, 64), (4, 2, 128, 128, 64), (8, 2, 512, 128, 64), (10, 2, 512, 128, 64), (16, 2, 256, 128, 64), (32, 2, 128, 128, 64), (1, 2, 1024, 128, 64), (8, 2, 2048, 128, 64), (2, 2, 4096, 128, 64), (16, 2, 1024, 128, 64), (50, 20, 128, 128, 64), (1, 2, 128, 128, 128), (4, 2, 512, 128, 128), (16, 20, 1024, 128, 128)] + run_correctness_cases(__file__, test_configs, run_kernel) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/README.md b/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/README.md new 
file mode 100644 index 00000000..1b511292 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/README.md @@ -0,0 +1,29 @@ +# Step 07: L1 Prefetching + +This step corresponds to commit `26aac37`. + +What changed: +- a second `H`-state L1 tile is kept on the cube side +- the next accumulated hidden-state tile is prefetched while the current output path is already loading `Q`, `V`, and masked attention + +This is the current best educational endpoint in the optimization ladder. + +Important benchmarking note: +- the small `--quick` benchmark uses only the `(8, 20, 1024, 128, 128)` and `(16, 20, 1024, 128, 128)` shapes, so it typically reports results in the mid-60s `TFLOP/s` range +- the kernel in this directory is intentionally kept identical to the current main example, so the full benchmark table reaches the same large-shape performance class + +To reproduce the main-example-style result, run: + +```bash +python benchmark_linear_attention.py --warmup 2 --repeats 5 +``` + +On this machine, that full-table run measured: +- `77.71 TFLOP/s` / `565.43 GiB/s` at `(12, 20, 8192, 128, 128)` + +For comparison, the current main example measured immediately afterwards with the same command reached: +- `77.97 TFLOP/s` / `567.34 GiB/s` at `(24, 20, 6144, 128, 128)` + +That small difference is within normal run-to-run noise (note that the two figures come from different shapes). The important point is that this final tutorial step reaches the same `~78 TFLOP/s` performance class as the main example when both are benchmarked with the full table. + +Run those large-shape benchmarks one process at a time so the NPU is not oversubscribed by multiple concurrent benchmark jobs.
diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/benchmark_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/benchmark_linear_attention.py new file mode 100644 index 00000000..80e6c66c --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/benchmark_linear_attention.py @@ -0,0 +1,49 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile +from linear_attention_shared import benchmark_cli, benchmark_dynamic_kernel + +DEFAULT_SHAPES = [(24, 20, 2048, 128, 128), (48, 20, 1024, 128, 128), (12, 20, 8192, 128, 128), (24, 20, 1536, 128, 128)] + +QUICK_SHAPES = [(8, 20, 1024, 128, 128), (16, 20, 1024, 128, 128)] + +THROUGHPUT_HUNT_SHAPES = [(32, 20, 2048, 128, 128), (24, 20, 4096, 128, 128), (12, 20, 8192, 128, 128), (24, 20, 6144, 128, 128)] + + + +def benchmark_shape(src: str, *, batch: int, heads: int, seq_len: int, hidden: int, chunk: int, warmup: int, repeats: int): + return benchmark_dynamic_kernel( + src, + batch=batch, + heads=heads, + seq_len=seq_len, + hidden=hidden, + chunk=chunk, + warmup=warmup, + repeats=repeats, + jit_compile=jit_compile, + block_dim=BLOCK_DIM, + stage_count=2, + use_mask=True, + include_workspace_bytes=False, + mask_factory=get_causal_mask, + ) + + +def main(): + benchmark_cli( + script_file=__file__, + default_shapes=DEFAULT_SHAPES, + quick_shapes=QUICK_SHAPES, + benchmark_shape=benchmark_shape, + throughput_hunt_shapes=THROUGHPUT_HUNT_SHAPES + ) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/jit_util_linear_attention.py b/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/jit_util_linear_attention.py new file mode 
100644 index 00000000..e2870489 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/jit_util_linear_attention.py @@ -0,0 +1,60 @@ +from pathlib import Path +import sys + +COMMON_DIR = Path(__file__).resolve().parents[1] / "common" +if str(COMMON_DIR) not in sys.path: + sys.path.insert(0, str(COMMON_DIR)) + +from functools import lru_cache + +from jit_shared import BLOCK_DIM, OPTIMIZED_KERNEL_FLAGS, compile_cpp as shared_compile_cpp +from jit_shared import get_causal_mask, load_dynamic_mask_lib + + +def compile_cpp( + kernel_cpp: str, + num_heads: int, + hidden_size: int, + chunk_size: int, + verbose: bool = False, + timeout: int = 180, +) -> str: + return shared_compile_cpp( + kernel_cpp, + output_name=f"linear_attention_H{num_heads}_D{hidden_size}_C{chunk_size}_jit.so", + std="gnu++17", + defines=[ + f"-DLINEAR_ATTN_H={num_heads}", + f"-DLINEAR_ATTN_D={hidden_size}", + f"-DLINEAR_ATTN_C={chunk_size}", + ], + extra_flags=OPTIMIZED_KERNEL_FLAGS, + verbose=verbose, + timeout=timeout, + ) + + +def load_lib(lib_path: str): + return load_dynamic_mask_lib(lib_path) + + +@lru_cache(maxsize=None) +def jit_compile( + src_path: str, + num_heads: int, + hidden_size: int, + chunk_size: int, + verbose: bool = True, + clean_up: bool = False, +): + lib_path = compile_cpp( + src_path, + num_heads=num_heads, + hidden_size=hidden_size, + chunk_size=chunk_size, + verbose=verbose, + ) + func = load_lib(lib_path) + if clean_up: + Path(lib_path).unlink(missing_ok=True) + return func diff --git a/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/linear_attention.cpp b/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/linear_attention.cpp new file mode 100644 index 00000000..dde6e764 --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/07_l1_prefetching/linear_attention.cpp @@ -0,0 +1,595 @@ +#include +#include +#include +#include + +using namespace pto; + +// Step 07 keeps the 
// GEMM helper: multiplies the L1-resident operand tiles `a_l1` and `b_l1` and
// either overwrites or accumulates into the accumulator tile `dst`.
//
// M, N, K, TransposeA and TransposeB are compile-time template parameters.
// NOTE(review): presumably M/N/K are the GEMM extents and TransposeA/B select
// a transposed view of the corresponding operand -- confirm against the
// template declaration.
//
//   dst   accumulator tile receiving the product.
//   a_l1  left operand tile (its type is chosen by std::conditional_t).
//   b_l1  right operand tile (its type is chosen by std::conditional_t).
//   init  true: the first multiply overwrites dst (TMATMUL);
//         false: every multiply accumulates into dst (TMATMUL_ACC).
template
AICORE inline void MatmulL1(
    TileAcc &dst,
    std::conditional_t, L1Mat> &a_l1,
    std::conditional_t, L1Mat> &b_l1,
    bool init) {
  // Split path: taken only when K is a multiple of 64 and larger than 64.
  // K is processed in slices of KStep using two alternating A/B tile buffers.
  if constexpr ((K % 64 == 0) && (K > 64)) {
    constexpr int KStep = 64;
    constexpr int Parts = K / KStep;
    // Byte size of one K-slice of A (M x KStep) and of B (KStep x N) in half.
    constexpr uintptr_t AStepBytes = M * KStep * sizeof(half);
    constexpr uintptr_t BStepBytes = KStep * N * sizeof(half);

    // Two buffers per operand: buffer 0 at offset 0, buffer 1 one slice later.
    TileLeft a_l0[2];
    TileRight b_l0[2];
    TASSIGN(a_l0[0], static_cast(0));
    TASSIGN(a_l0[1], AStepBytes);
    TASSIGN(b_l0[0], static_cast(0));
    TASSIGN(b_l0[1], BStepBytes);

    // Pre-set flag 0 and flag 1 so the first WaitFlag on each buffer passes.
    // NOTE(review): the pipes these flags connect are template arguments that
    // are not shown here; presumably "multiply done -> buffer free".
    SetFlag(0);
    SetFlag(1);

    for (int part = 0; part < Parts; ++part) {
      // Alternate between buffer 0 and buffer 1 on successive K-slices.
      const int buf = part & 1;
      // Wait until this buffer has been released (by the initial SetFlag or by
      // the SetFlag at the end of the iteration that last used it).
      WaitFlag(buf);

      // Extract the current K-slice of A; the slice offset part * KStep is the
      // second coordinate for A.
      if constexpr (TransposeA) {
        L1MatTrans a_view;
        TRESHAPE(a_view, a_l1);
        TEXTRACT(a_l0[buf], a_view, 0, part * KStep);
      } else {
        TEXTRACT(a_l0[buf], a_l1, 0, part * KStep);
      }

      // Extract the matching K-slice of B; here the slice offset is the first
      // coordinate.
      if constexpr (TransposeB) {
        L1MatTrans b_view;
        TRESHAPE(b_view, b_l1);
        TEXTRACT(b_l0[buf], b_view, part * KStep, 0);
      } else {
        TEXTRACT(b_l0[buf], b_l1, part * KStep, 0);
      }

      // Hand the freshly extracted buffer over to the multiply.
      // NOTE(review): presumably an "extract done -> multiply may start"
      // handshake; confirm the pipe arguments.
      SetFlag(buf);
      WaitFlag(buf);

      // Only the very first slice may overwrite dst; every later slice (and
      // every slice when init is false) accumulates.
      if (init && part == 0) {
        TMATMUL(dst, a_l0[buf], b_l0[buf]);
      } else {
        TMATMUL_ACC(dst, dst, a_l0[buf], b_l0[buf]);
      }

      // Release the buffer so a later iteration can refill it.
      SetFlag(buf);
    }

    // Consume the outstanding flag on both buffers, then fully synchronize.
    WaitFlag(0);
    WaitFlag(1);
    pipe_barrier(PIPE_ALL);
  } else {
    // Single-shot path: one A tile and one B tile, both assigned address 0x0.
    TileLeft a_l0;
    TileRight b_l0;
    TASSIGN(a_l0, 0x0);
    TASSIGN(b_l0, 0x0);

    if constexpr (TransposeA) {
      L1MatTrans a_view;
      TRESHAPE(a_view, a_l1);
      TEXTRACT(a_l0, a_view, 0, 0);
    } else {
      TEXTRACT(a_l0, a_l1, 0, 0);
    }

    if constexpr (TransposeB) {
      L1MatTrans b_view;
      TRESHAPE(b_view, b_l1);
      TEXTRACT(b_l0, b_view, 0, 0);
    } else {
      TEXTRACT(b_l0, b_l1, 0, 0);
    }

    // Full barriers bracket the multiply: extraction must be complete before
    // it starts, and it must be complete before the caller continues.
    pipe_barrier(PIPE_ALL);
    if (init) {
      TMATMUL(dst, a_l0, b_l0);
    } else {
      TMATMUL_ACC(dst, dst, a_l0, b_l0);
    }
    pipe_barrier(PIPE_ALL);
  }
}
+ constexpr bool InplaceMaskApply = (ChunkSize >= 128); + constexpr int32_t VecNum = 2; + constexpr int32_t HalfChunk = ChunkSize / VecNum; + constexpr int32_t HalfHidden = HiddenSize / VecNum; + constexpr int32_t ChunkElems = ChunkSize * HiddenSize; + constexpr int32_t Workspace1SlotElems = ChunkSize * ChunkSize; + constexpr int32_t Workspace2SlotElems = HiddenSize * HiddenSize; + constexpr int32_t Workspace1Elems = StageCount * Workspace1SlotElems; + constexpr int32_t Workspace2Elems = StageCount * Workspace2SlotElems; + + constexpr int32_t QL1Addr = 0; + constexpr int32_t KL1Addr = QL1Addr + ChunkElems * sizeof(half); + constexpr int32_t VL1Addr = KL1Addr + ChunkElems * sizeof(half); + constexpr int32_t HL1Addr = VL1Addr + ChunkElems * sizeof(half); + constexpr int32_t AccL1Addr = HL1Addr + Workspace2SlotElems * sizeof(half); + // New in step 07: reserve a second L1 buffer for the hidden state so the next + // prefix-state tile can be prefetched before it is needed. + constexpr int32_t HNextL1Addr = AccL1Addr + Workspace1SlotElems * sizeof(half); + + constexpr int32_t SharedL0Addr = 0; + + constexpr int32_t HsumUbAddr = 0; + constexpr int32_t AccUbAddr = + HsumUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t HUbAddr = AccUbAddr + HalfChunk * ChunkSize * sizeof(half); + constexpr int32_t RawUBBytes = + (HalfHidden * HiddenSize + HalfChunk * ChunkSize + HalfHidden * HiddenSize + + HalfChunk * ChunkSize + + (InplaceMaskApply ? 0 : HalfChunk * ChunkSize)) * + sizeof(half); + constexpr bool PreloadMask = RawUBBytes <= 72 * 1024; + constexpr bool AliasMaskIntoH = + !PreloadMask && (HalfHidden * HiddenSize >= HalfChunk * ChunkSize); + constexpr int32_t MaskUbAddr = + AliasMaskIntoH ? HUbAddr : HUbAddr + HalfHidden * HiddenSize * sizeof(half); + constexpr int32_t MaskedAccUbAddr = + InplaceMaskApply ? AccUbAddr : MaskUbAddr + HalfChunk * ChunkSize * sizeof(half); + + constexpr int32_t L0CBytes = + (Workspace2SlotElems > Workspace1SlotElems + ? 
(Workspace2SlotElems > ChunkElems ? Workspace2SlotElems : ChunkElems) + : (Workspace1SlotElems > ChunkElems ? Workspace1SlotElems : ChunkElems)) * + sizeof(float); + constexpr int32_t UBBytes = + (HalfHidden * HiddenSize + HalfChunk * ChunkSize + + (AliasMaskIntoH ? HalfHidden * HiddenSize + : HalfHidden * HiddenSize + HalfChunk * ChunkSize) + + (InplaceMaskApply ? 0 : HalfChunk * ChunkSize)) * + sizeof(half); + constexpr int32_t L1Bytes = + UseTwoStagePipeline ? (HNextL1Addr + Workspace2SlotElems * sizeof(half)) + : (AccL1Addr + Workspace1SlotElems * sizeof(half)); + static_assert((HiddenSize % 2) == 0, "HiddenSize must be even."); + static_assert((ChunkSize % 2) == 0, "ChunkSize must be even."); + static_assert(L0CBytes <= 112 * 1024, + "Tile sizes exceed the validated L0C budget for this minimum kernel."); + static_assert(L1Bytes <= 192 * 1024, + "Tile sizes exceed the validated L1 budget for this minimum kernel."); + static_assert(PreloadMask || AliasMaskIntoH, + "Current minimum kernel requires either a preloaded mask or H UB large enough to alias the mask."); + static_assert(UBBytes <= 72 * 1024, + "Tile sizes exceed the validated UB budget for this minimum kernel."); + + using ChunkGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using AccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfAccGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfHiddenGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + using HalfMaskGlobal = + GlobalTensor, + BaseShape2D, + Layout::ND>; + + const int64_t total_work = batch_size * NumHeads; + const int64_t chunk_num = seq_len / ChunkSize; + const int64_t cid = get_block_idx(); + const int64_t vid = get_subblockid(); + set_ffts_base_addr(ffts_addr); + + L1Mat q_l1; + L1Mat k_l1; + L1Mat v_l1; + L1Mat h_l1; + L1Mat h_next_l1; + L1Mat acc_l1; + TASSIGN(q_l1, QL1Addr); + TASSIGN(k_l1, KL1Addr); + TASSIGN(v_l1, VL1Addr); + 
TASSIGN(h_l1, HL1Addr); + TASSIGN(h_next_l1, HNextL1Addr); + TASSIGN(acc_l1, AccL1Addr); + + TileAcc acc_l0; + TileAcc h_l0; + TileAcc o_l0; + TASSIGN(acc_l0, SharedL0Addr); + TASSIGN(h_l0, SharedL0Addr); + TASSIGN(o_l0, SharedL0Addr); + + UbVec hsum_ub; + UbVec h_ub; + UbVec acc_ub; + UbVec mask_ub; + UbVec masked_acc_ub; + TASSIGN(hsum_ub, HsumUbAddr); + TASSIGN(acc_ub, AccUbAddr); + TASSIGN(h_ub, HUbAddr); + TASSIGN(mask_ub, MaskUbAddr); + TASSIGN(masked_acc_ub, MaskedAccUbAddr); + +#if defined(__DAV_C220_CUBE__) + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t by = pid % NumHeads; + const int64_t bz = pid / NumHeads; + const int64_t qkv_base = ((bz * NumHeads + by) * seq_len) * HiddenSize; + const int64_t workspace1_base = cid * Workspace1Elems; + const int64_t workspace2_base = cid * Workspace2Elems; + + if constexpr (UseTwoStagePipeline) { + const int32_t flag_base = static_cast((work_idx & 3) * 6); + // h_buf tells us which L1 buffer currently holds the "ready to use" + // prefix state for the output matmul. 
+ int32_t h_buf = 0; + WaitCrossFlag(flag_base + 4); + HiddenGlobal zero_h_global(workspace_2 + workspace2_base + Workspace2SlotElems); + TLOAD(h_l1, zero_h_global); + pipe_barrier(PIPE_ALL); + + { + const int64_t chunk_base = qkv_base; + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, v_l1, + true); + HiddenGlobal h_out_global(workspace_2 + workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base, 2); + } + + for (int64_t i = 0; i < chunk_num; ++i) { + const int32_t slot = static_cast(i & 1); + const int32_t next_slot = slot ^ 1; + const int64_t chunk_base = qkv_base + i * ChunkElems; + + if (i + 1 < chunk_num) { + // As in step 06, cube keeps one chunk ahead in the opposite slot. 
+ const int64_t next_chunk_base = qkv_base + (i + 1) * ChunkElems; + const int64_t next_workspace1_base = + workspace1_base + next_slot * Workspace1SlotElems; + const int64_t next_workspace2_base = + workspace2_base + next_slot * Workspace2SlotElems; + + ChunkGlobal q_global(q + next_chunk_base); + ChunkGlobal k_global(k + next_chunk_base); + ChunkGlobal v_global(v + next_chunk_base); + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + pipe_barrier(PIPE_ALL); + + MatmulL1( + acc_l0, q_l1, k_l1, true); + AccGlobal acc_global(workspace_1 + next_workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, + v_l1, true); + HiddenGlobal h_out_global(workspace_2 + next_workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + next_slot, 2); + } + + WaitCrossFlag(flag_base + 2 + slot); + AccGlobal masked_acc_global(workspace_1 + workspace1_base + + slot * Workspace1SlotElems); + TLOAD(acc_l1, masked_acc_global); + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal v_global(v + chunk_base); + TLOAD(q_l1, q_global); + TLOAD(v_l1, v_global); + if (i + 1 < chunk_num) { + // Step-07-specific optimization: overlap the next prefix-state load + // with the current chunk's Q/V reload. One of h_l1 / h_next_l1 is + // being consumed now, while the other becomes the "next" buffer. 
+ HiddenGlobal next_h_global(workspace_2 + workspace2_base + + slot * Workspace2SlotElems); + if (h_buf == 0) { + TLOAD(h_next_l1, next_h_global); + } else { + TLOAD(h_l1, next_h_global); + } + } + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, + v_l1, true); + if (h_buf == 0) { + MatmulL1(o_l0, q_l1, + h_l1, false); + } else { + MatmulL1(o_l0, q_l1, + h_next_l1, + false); + } + + ChunkGlobal o_global(o + chunk_base); + TSTORE(o_global, o_l0); + pipe_barrier(PIPE_ALL); + + if (i + 1 < chunk_num) { + // Swap roles: the buffer we just prefetched becomes the current one + // for the next loop iteration. + h_buf ^= 1; + } + } + SetCrossFlag(flag_base + 5, 2); + } else { + WaitCrossFlag(1); + + for (int64_t i = 0; i < chunk_num; ++i) { + const int64_t chunk_base = qkv_base + i * ChunkElems; + + ChunkGlobal q_global(q + chunk_base); + ChunkGlobal k_global(k + chunk_base); + ChunkGlobal v_global(v + chunk_base); + HiddenGlobal h_global(workspace_2 + workspace2_base); + + TLOAD(q_l1, q_global); + TLOAD(k_l1, k_global); + TLOAD(v_l1, v_global); + TLOAD(h_l1, h_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(acc_l0, q_l1, k_l1, + true); + AccGlobal acc_global(workspace_1 + workspace1_base); + TSTORE(acc_global, acc_l0); + pipe_barrier(PIPE_ALL); + + MatmulL1(h_l0, k_l1, v_l1, + true); + HiddenGlobal h_out_global(workspace_2 + workspace2_base); + TSTORE(h_out_global, h_l0); + pipe_barrier(PIPE_ALL); + SetCrossFlag(0, 2); + + WaitCrossFlag(1); + AccGlobal masked_acc_global(workspace_1 + workspace1_base); + TLOAD(acc_l1, masked_acc_global); + pipe_barrier(PIPE_ALL); + + MatmulL1(o_l0, acc_l1, + v_l1, true); + MatmulL1(o_l0, q_l1, h_l1, + false); + + ChunkGlobal o_global(o + chunk_base); + TSTORE(o_global, o_l0); + pipe_barrier(PIPE_ALL); + } + } + } +#endif + +#if defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + HalfMaskGlobal mask_global(causal_mask + vid * HalfChunk * ChunkSize); + if constexpr (PreloadMask) { + TLOAD(mask_ub, mask_global); + 
pipe_barrier(PIPE_ALL); + } + + for (int64_t work_idx = 0; work_idx < (total_work + block_num - 1) / block_num; + ++work_idx) { + const int64_t pid = work_idx * block_num + cid; + if (pid >= total_work) { + continue; + } + + const int64_t workspace1_base = cid * Workspace1Elems; + const int64_t workspace2_base = cid * Workspace2Elems; + + TEXPANDS(hsum_ub, 0.0f); + pipe_barrier(PIPE_ALL); + if constexpr (UseTwoStagePipeline) { + const int32_t flag_base = static_cast((work_idx & 3) * 6); + HalfHiddenGlobal init_h_global_0(workspace_2 + workspace2_base + + vid * HalfHidden * HiddenSize); + HalfHiddenGlobal init_h_global_1(workspace_2 + workspace2_base + + Workspace2SlotElems + + vid * HalfHidden * HiddenSize); + TSTORE(init_h_global_0, hsum_ub); + TSTORE(init_h_global_1, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + 4, 2); + + for (int64_t i = 0; i < chunk_num; ++i) { + const int32_t slot = static_cast(i & 1); + WaitCrossFlag(flag_base + slot); + + const int64_t slot_workspace1_base = + workspace1_base + slot * Workspace1SlotElems; + const int64_t slot_workspace2_base = + workspace2_base + slot * Workspace2SlotElems; + HalfAccGlobal acc_global(workspace_1 + slot_workspace1_base + + vid * HalfChunk * ChunkSize); + HalfHiddenGlobal h_global(workspace_2 + slot_workspace2_base + + vid * HalfHidden * HiddenSize); + TLOAD(acc_ub, acc_global); + TLOAD(h_ub, h_global); + pipe_barrier(PIPE_ALL); + + TADD(hsum_ub, hsum_ub, h_ub); + pipe_barrier(PIPE_ALL); + if constexpr (!PreloadMask) { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + if constexpr (InplaceMaskApply) { + // Reusing acc_ub avoids one extra temporary tile in UB. 
+ TMUL(acc_ub, acc_ub, mask_ub); + } else { + TMUL(masked_acc_ub, acc_ub, mask_ub); + } + pipe_barrier(PIPE_ALL); + if constexpr (InplaceMaskApply) { + TSTORE(acc_global, acc_ub); + } else { + TSTORE(acc_global, masked_acc_ub); + } + TSTORE(h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(flag_base + 2 + slot, 2); + } + WaitCrossFlag(flag_base + 5); + } else { + HalfHiddenGlobal init_h_global(workspace_2 + workspace2_base + + vid * HalfHidden * HiddenSize); + TSTORE(init_h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + + for (int64_t i = 0; i < chunk_num; ++i) { + WaitCrossFlag(0); + + HalfAccGlobal acc_global(workspace_1 + workspace1_base + + vid * HalfChunk * ChunkSize); + HalfHiddenGlobal h_global(workspace_2 + workspace2_base + + vid * HalfHidden * HiddenSize); + TLOAD(acc_ub, acc_global); + TLOAD(h_ub, h_global); + pipe_barrier(PIPE_ALL); + + TADD(hsum_ub, hsum_ub, h_ub); + pipe_barrier(PIPE_ALL); + if constexpr (!PreloadMask) { + TLOAD(mask_ub, mask_global); + pipe_barrier(PIPE_ALL); + } + if constexpr (InplaceMaskApply) { + TMUL(acc_ub, acc_ub, mask_ub); + } else { + TMUL(masked_acc_ub, acc_ub, mask_ub); + } + pipe_barrier(PIPE_ALL); + if constexpr (InplaceMaskApply) { + TSTORE(acc_global, acc_ub); + } else { + TSTORE(acc_global, masked_acc_ub); + } + TSTORE(h_global, hsum_ub); + pipe_barrier(PIPE_ALL); + SetCrossFlag(1, 2); + } + } + } +#endif +} + +extern "C" __global__ AICORE void launch_linear_attention( + __gm__ uint8_t *q, __gm__ uint8_t *k, __gm__ uint8_t *v, + __gm__ uint8_t *workspace_1, __gm__ uint8_t *workspace_2, + __gm__ uint8_t *causal_mask, __gm__ uint8_t *o, int64_t batch_size, + int64_t seq_len, uint64_t ffts_addr) { + main_kernel( + reinterpret_cast<__gm__ half *>(q), reinterpret_cast<__gm__ half *>(k), + reinterpret_cast<__gm__ half *>(v), + reinterpret_cast<__gm__ half *>(workspace_1), + reinterpret_cast<__gm__ half *>(workspace_2), + reinterpret_cast<__gm__ half *>(causal_mask), + reinterpret_cast<__gm__ 
def main():
    """Run the correctness sweep for this step's kernel.

    Every case is a ``(B, H, L, D, C)`` tuple; all of them use ``D = 128``.
    The chunk-64 cases run first, followed by the chunk-128 cases.
    """
    hidden = 128
    # (B, H, L) triples checked with chunk size 64.
    chunk64_shapes = [
        (1, 2, 64),
        (1, 2, 256),
        (4, 2, 128),
        (8, 2, 512),
        (10, 2, 512),
        (16, 2, 256),
        (32, 2, 128),
        (1, 2, 1024),
        (8, 2, 2048),
        (2, 2, 4096),
        (16, 2, 1024),
        (50, 20, 128),
    ]
    # (B, H, L) triples checked with chunk size 128.
    chunk128_shapes = [
        (1, 2, 128),
        (4, 2, 512),
        (16, 20, 1024),
    ]

    sweep = [(b, h, seq, hidden, 64) for b, h, seq in chunk64_shapes]
    sweep += [(b, h, seq, hidden, 128) for b, h, seq in chunk128_shapes]
    run_correctness_cases(__file__, sweep, run_kernel)
a/examples/jit_cpp/linear_attention/optimize_step_by_step/README.md b/examples/jit_cpp/linear_attention/optimize_step_by_step/README.md new file mode 100644 index 00000000..698a9b9b --- /dev/null +++ b/examples/jit_cpp/linear_attention/optimize_step_by_step/README.md @@ -0,0 +1,87 @@ +# Linear Attention Step-By-Step Optimization + +This folder turns the historical optimization trail of the `jit_cpp/linear_attention` example into a runnable tutorial ladder. + +Each numbered directory contains a runnable snapshot of one major optimization step, copied onto the current `linear_attn` branch for teaching purposes. + +## Learning Path + +Suggested reading order: +1. `01_naive_static_shape` +2. `02_naive_dynamic_shape` +3. `03_cached_mask` +4. `03a_fast_mask_construct` +5. `04_chunk128` +6. `05_l0_double_buffer` +7. `06_two_slot_cv_pipeline` +8. `07_l1_prefetching` + +## What Each Step Teaches +- `01_naive_static_shape`: the smallest fixed-shape PTO-ISA kernel; easiest place to understand workspace layout and tensor indexing +- `02_naive_dynamic_shape`: move `B` and `L` to runtime, keep launch shape fixed to the number of cores, and loop over work items inside the kernel +- `03_cached_mask`: precompute the triangular mask in PyTorch and apply it with vector tile ops instead of scalar loops +- `03a_fast_mask_construct`: keep mask construction inside the kernel, but build the triangular mask once with vector writes and reuse it for all chunks +- `04_chunk128`: raise chunk size from `64` to `128` to increase arithmetic intensity +- `05_l0_double_buffer`: split `K=128` into `2 x 64` cube phases and overlap extract with compute +- `06_two_slot_cv_pipeline`: let cube prepare chunk `i + 1` while vector finishes chunk `i` +- `07_l1_prefetching`: keep two `H`-state L1 tiles so the next hidden-state chunk is loaded early + +## Quick Validated Progression + +These are the short smoke-test numbers produced while verifying the intermediate tutorial samples on this machine. 
They use the small `--quick` benchmark shapes so the whole ladder can be checked end to end in a reasonable amount of time. The first two rows below are from the newly simplified beginner kernels. + +| Step | Quick validated result | +| --- | --- | +| `01_naive_static_shape` | fixed-shape smoke benchmark: `(2, 2, 512, 128, 64)` in `0.275 ms` | +| `02_naive_dynamic_shape` | `5.21 TFLOP/s` at `(16, 20, 1024, 128, 64)` | +| `03_cached_mask` | `28.47 TFLOP/s` at `(16, 20, 1024, 128, 64)` | +| `03a_fast_mask_construct` | `29.79 TFLOP/s` at `(16, 20, 1024, 128, 64)` | +| `04_chunk128` | `49.73 TFLOP/s` at `(16, 20, 1024, 128, 128)` | +| `05_l0_double_buffer` | `52.57 TFLOP/s` at `(16, 20, 1024, 128, 128)` | +| `06_two_slot_cv_pipeline` | `63.15 TFLOP/s` at `(16, 20, 1024, 128, 128)` | + +The new `03a` result was measured with `python benchmark_linear_attention.py --shapes 16x20x1024x128x64 --warmup 2 --repeats 5`. +For the same command on this machine, step `02` measured `5.21 TFLOP/s` and step `03` measured `28.47 TFLOP/s`, so the fast on-the-fly mask closes essentially all of the gap to the cached-mask version without loading a mask tensor from global memory. + +## Final Step Full Benchmark + +The final tutorial step is intentionally identical to the current main example kernel and JIT helper. It should therefore be compared with the same full benchmark command, not with `--quick`. + +| Target | Command | Best validated result | +| --- | --- | --- | +| `07_l1_prefetching` | `python benchmark_linear_attention.py --warmup 2 --repeats 5` | `77.71 TFLOP/s` / `565.43 GiB/s` at `(12, 20, 8192, 128, 128)` | +| main example | `python benchmark_linear_attention.py --warmup 2 --repeats 5` | `77.97 TFLOP/s` / `567.34 GiB/s` at `(24, 20, 6144, 128, 128)` | + +Those two results were measured sequentially with the same command. 
The small gap is normal benchmark noise; the important point is that the final tutorial step reaches the same `~78 TFLOP/s` performance class as the current main example. + +## How To Read The Early Steps +- `01` and `02` include `numpy_sim.py`, which intentionally hides the real flag/synchronization details and replaces parallel core execution with a sequential loop. +- Those NumPy simulations focus on the dataflow: which tiles are loaded, how workspace is updated, and how chunked causal masking interacts with the running hidden state. + +## Notes +- `01` and `02` also include NumPy simulations that explain the tensor indexing and workspace layout without the real PTO synchronization details. +- The early steps were rewritten into smaller teaching kernels so they stay close to the NumPy emulation and avoid distracting helper boilerplate. +- JIT outputs are redirected into a local `compiled_lib/` subdirectory so the tutorial folders stay tidy. +- The later steps intentionally keep the code close to the optimized working snapshots, while the step README files explain the key optimization idea. +- Benchmark large-shape steps one process at a time. Running multiple heavy NPU benchmarks concurrently can lower measured TFLOP/s and make the step-to-step comparison misleading. 
def _verify_aicore_predefines(kernel_cpp: str, include_flags: list[str], timeout: int) -> None:
    """Check that ``bisheng`` exposes the expected dav-c220 AICORE predefines.

    Runs a preprocessor-only macro dump (``-dM -E``) of ``kernel_cpp`` and
    looks for the expected macro names in its standard output.

    Args:
        kernel_cpp: Kernel source file to preprocess.
        include_flags: ``-I...`` flags shared with the real compile command.
        timeout: Caller's compile timeout in seconds; the probe is capped at
            30 seconds.

    Raises:
        RuntimeError: If any expected macro is missing from the dump.
        subprocess.CalledProcessError: If the probe exits with a non-zero status.
        subprocess.TimeoutExpired: If the probe exceeds its timeout.
    """
    probe = [
        "bisheng",
        "-dM",
        "-E",
        "-xcce",
        f"--cce-aicore-arch={AICORE_ARCH}",
        *include_flags,
        kernel_cpp,
    ]
    result = subprocess.run(
        probe,
        check=True,
        timeout=min(timeout, 30),
        capture_output=True,
        text=True,
    )
    macros = result.stdout
    expected = ("#define __CCE_AICORE__ 220", "__DAV_C220_CUBE__", "__DAV_C220_VEC__")
    if not all(token in macros for token in expected):
        raise RuntimeError(
            "bisheng did not expose the expected dav-c220 AICORE predefines "
            "for this compile command. That can cause  to "
            "skip PTO tile/instruction headers in some preprocessing paths."
        )


def compile_cpp(
    kernel_cpp: str,
    *,
    output_name: str,
    std: str,
    defines: list[str] | None = None,
    extra_flags: list[str] | None = None,
    verbose: bool = False,
    timeout: int = 180,
) -> str:
    """Compile a kernel source into a shared library with ``bisheng``.

    The library is written to a ``compiled_lib`` directory next to
    ``kernel_cpp`` (created on demand).

    Args:
        kernel_cpp: Path of the kernel source file.
        output_name: File name of the shared library to produce.
        std: Value for the ``-std=`` flag.
        defines: Extra ``-D...`` flags, appended after the include flags.
        extra_flags: Extra compiler flags, placed before the include flags.
        verbose: Print the compile command and the generated library path.
        timeout: Timeout in seconds for the compile command.

    Returns:
        Path of the generated shared library.

    Raises:
        RuntimeError: If the predefine probe or the compilation fails. When the
            failing command captured its standard error, that text is appended
            to the message.
    """
    lib_dir = os.path.join(os.path.dirname(kernel_cpp), "compiled_lib")
    os.makedirs(lib_dir, exist_ok=True)
    lib_path = os.path.join(lib_dir, output_name)
    include_flags = [
        f"-I{PTO_LIB_PATH}/include",
        f"-I{ASCEND_TOOLKIT_HOME}/include",
        f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc",
        f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc/runtime",
        f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc/profiling",
    ]

    flags = [
        "-fPIC",
        "-shared",
        "-xcce",
        "-DMEMORY_BASE",
        "-O2",
        f"-std={std}",
        f"--cce-aicore-arch={AICORE_ARCH}",
        *(extra_flags or []),
        *include_flags,
        *(defines or []),
    ]

    command = ["bisheng", *flags, kernel_cpp, "-o", lib_path]
    if verbose:
        print("compile command:", " ".join(command))

    try:
        _verify_aicore_predefines(kernel_cpp, include_flags, timeout)
        subprocess.run(command, timeout=timeout, check=True)
    except subprocess.CalledProcessError as exc:
        # The predefine probe captures its output, so on failure the compiler
        # diagnostics exist only on the exception; without this they are lost.
        captured = exc.stderr or ""
        if isinstance(captured, bytes):
            captured = captured.decode(errors="replace")
        captured = captured.strip()
        detail = f"\n{captured}" if captured else ""
        raise RuntimeError(f"Compile failed: {exc}{detail}") from exc
    except Exception as exc:
        raise RuntimeError(f"Compile failed: {exc}") from exc

    if verbose:
        print(f"generated {lib_path}")
    return lib_path
def load_static_nomask_lib(lib_path: str):
    """Load a compiled kernel library whose ``call_kernel`` takes no mask or shape arguments.

    The C entry point is declared as ``(uint32 block_dim, void* stream,`` then
    six ``void*`` tensor pointers ``)`` with no return value.

    Args:
        lib_path: Path of the shared library to load.

    Returns:
        A Python callable ``(q, k, v, workspace_1, workspace_2, o,
        block_dim=None, stream_ptr=None)`` that launches the kernel.
    """
    lib = _load_cdll(lib_path)
    # One launch-size word, then the stream pointer plus six tensor pointers.
    lib.call_kernel.argtypes = [ctypes.c_uint32, *([ctypes.c_void_p] * 7)]
    lib.call_kernel.restype = None

    def linear_attention_func(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        workspace_1: torch.Tensor,
        workspace_2: torch.Tensor,
        o: torch.Tensor,
        block_dim: int | None = None,
        stream_ptr=None,
    ):
        # Default launch size: the product of q's first two dimensions.
        launch_dim = q.shape[0] * q.shape[1] if block_dim is None else block_dim
        # Default stream: the current NPU stream.
        stream = (
            torch.npu.current_stream()._as_parameter_
            if stream_ptr is None
            else stream_ptr
        )

        tensor_args = [
            torch_to_ctypes(tensor)
            for tensor in (q, k, v, workspace_1, workspace_2, o)
        ]
        lib.call_kernel(launch_dim, stream, *tensor_args)

    return linear_attention_func
def load_dynamic_mask_lib(lib_path: str):
    """Load a compiled kernel library whose ``call_kernel`` takes a causal mask and runtime sizes.

    The C entry point is declared as ``(uint32 block_dim, void* stream,`` seven
    ``void*`` tensor pointers, ``int64, int64)`` with no return value. The two
    trailing integers are filled from ``q.shape[0]`` and ``q.shape[2]``.

    Args:
        lib_path: Path of the shared library to load.

    Returns:
        A Python callable ``(q, k, v, workspace_1, workspace_2, causal_mask, o,
        block_dim=None, stream_ptr=None)`` that launches the kernel.
    """
    lib = _load_cdll(lib_path)
    pointer_args = [ctypes.c_void_p] * 8  # stream + seven tensors
    size_args = [ctypes.c_int64] * 2
    lib.call_kernel.argtypes = [ctypes.c_uint32, *pointer_args, *size_args]
    lib.call_kernel.restype = None

    def linear_attention_func(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        workspace_1: torch.Tensor,
        workspace_2: torch.Tensor,
        causal_mask: torch.Tensor,
        o: torch.Tensor,
        block_dim: int | None = None,
        stream_ptr=None,
    ):
        # Fall back to the module-wide launch size and the current NPU stream.
        launch_dim = BLOCK_DIM if block_dim is None else block_dim
        stream = (
            torch.npu.current_stream()._as_parameter_
            if stream_ptr is None
            else stream_ptr
        )

        tensor_args = [
            torch_to_ctypes(tensor)
            for tensor in (q, k, v, workspace_1, workspace_2, causal_mask, o)
        ]
        lib.call_kernel(launch_dim, stream, *tensor_args, q.shape[0], q.shape[2])

    return linear_attention_func
kernel_src_path(script_file: str) -> str: + return os.path.join(os.path.dirname(os.path.abspath(script_file)), "linear_attention.cpp") + + +def ref_linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + batch, heads, seq_len, hidden = q.shape + q = q.float() + k = k.float() + v = v.float() + + state = torch.zeros((batch, heads, hidden, hidden), device=q.device, dtype=torch.float32) + output = torch.zeros((batch, heads, seq_len, hidden), device=q.device, dtype=torch.float32) + + for index in range(seq_len): + q_t = q[:, :, index, :] + k_t = k[:, :, index, :] + v_t = v[:, :, index, :] + state = state + torch.einsum("bhi,bhj->bhij", k_t, v_t) + output[:, :, index, :] = torch.einsum("bhi,bhij->bhj", q_t, state) + + return output.to(DTYPE) + + +def make_inputs(batch: int, heads: int, seq_len: int, hidden: int): + q = torch.randn((batch, heads, seq_len, hidden), device="npu", dtype=DTYPE) + k = torch.randn((batch, heads, seq_len, hidden), device="npu", dtype=DTYPE) + v = torch.randn((batch, heads, seq_len, hidden), device="npu", dtype=DTYPE) + q = q / (q.pow(2).sum(dim=-1, keepdim=True).sqrt() + 1e-6) + k = k / (k.pow(2).sum(dim=-1, keepdim=True).sqrt() + 1e-6) + return q, k, v + + +def atol_for_seq(seq_len: int) -> float: + if seq_len >= 4096: + return 4e-2 + if seq_len >= 2048: + return 2e-2 + return 1e-2 + + +def validate_output(output: torch.Tensor, reference: torch.Tensor, seq_len: int): + torch.testing.assert_close( + output.cpu(), + reference.cpu(), + rtol=RTOL, + atol=atol_for_seq(seq_len), + ) + + +def _workspace_shape(block_dim: int, stage_count: int, rows: int, cols: int): + if stage_count == 1: + return (block_dim, rows, cols) + return (block_dim, stage_count, rows, cols) + + +def run_dynamic_kernel( + src: str, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + chunk_size: int, + *, + jit_compile, + block_dim: int, + stage_count: int, + use_mask: bool, + mask_factory=None, +): + batch, heads, seq_len, hidden = q.shape 
def run_correctness_cases(script_file: str, test_configs, run_kernel):
    """Check ``run_kernel`` against the PyTorch reference for every configuration.

    Args:
        script_file: Path of the calling script; the kernel source is located
            relative to it via ``kernel_src_path``.
        test_configs: Iterable of ``(batch, heads, seq_len, hidden, chunk)``.
        run_kernel: Callable invoked as ``run_kernel(src, q, k, v, chunk)``.

    Raises:
        AssertionError: Propagated from ``validate_output`` on a mismatch.
    """
    torch.manual_seed(0)
    torch.npu.set_device("npu:0")
    kernel_source = kernel_src_path(script_file)

    for config in test_configs:
        batch, heads, seq_len, hidden, chunk = config
        banner = (
            f"Testing B={batch}, H={heads}, L={seq_len}, D={hidden}, C={chunk} "
            f"(B*H={batch * heads})"
        )
        print(banner)

        q, k, v = make_inputs(batch, heads, seq_len, hidden)
        actual = run_kernel(kernel_source, q, k, v, chunk)
        expected = ref_linear_attention(q, k, v)
        validate_output(actual, expected, seq_len)
        print(" passed!")

    print("Kernel Output Match!")
def estimate_flops(batch: int, heads: int, seq_len: int, hidden: int, chunk: int) -> int:
    """Return the modelled FLOP count for one forward pass.

    Each chunk is charged ``4 * C * D * (C + D)`` FLOPs.

    Raises:
        ValueError: If ``seq_len`` is not a multiple of ``chunk``.
    """
    chunk_count, leftover = divmod(seq_len, chunk)
    if leftover != 0:
        raise ValueError("This benchmark requires L to be a multiple of C.")
    per_chunk = 4 * chunk * hidden * (chunk + hidden)
    return batch * heads * chunk_count * per_chunk


def estimate_gm_bytes(
    batch: int,
    heads: int,
    seq_len: int,
    hidden: int,
    chunk: int,
    *,
    include_workspace: bool,
    include_mask: bool,
) -> int:
    """Return the modelled number of bytes moved for one forward pass.

    With ``include_workspace`` the per-chunk workspace traffic is counted and
    ``include_mask`` has no effect.  Otherwise only q/k/v/output traffic is
    counted, plus one ``C x C`` mask when ``include_mask`` is set.

    Raises:
        ValueError: If ``seq_len`` is not a multiple of ``chunk``.
    """
    chunk_count, leftover = divmod(seq_len, chunk)
    if leftover != 0:
        raise ValueError("This benchmark requires L to be a multiple of C.")

    if not include_workspace:
        io_bytes = chunk_count * (4 * chunk * hidden * 2)
        mask_bytes = chunk * chunk * 2 if include_mask else 0
        return batch * heads * io_bytes + mask_bytes

    init_bytes = 2 * hidden * hidden
    per_chunk = 8 * chunk * hidden + 8 * chunk * chunk + 8 * hidden * hidden
    return batch * heads * (init_bytes + chunk_count * per_chunk)


def measure_kernel_ms(run_once, warmup: int, repeats: int) -> float:
    """Return the median device time in milliseconds of ``run_once``.

    ``run_once`` is called ``warmup`` times untimed, then ``repeats`` times
    between a pair of NPU timing events, synchronising after every timed call.
    """
    for _ in range(warmup):
        run_once()
    torch.npu.synchronize()

    def _timed_call() -> float:
        begin = torch.npu.Event(enable_timing=True)
        finish = torch.npu.Event(enable_timing=True)
        begin.record()
        run_once()
        finish.record()
        # Wait for completion before reading the elapsed time.
        torch.npu.synchronize()
        return begin.elapsed_time(finish)

    return median([_timed_call() for _ in range(repeats)])
def benchmark_cli(
    *,
    script_file: str,
    default_shapes,
    quick_shapes,
    benchmark_shape,
    throughput_hunt_shapes=None,
    description: str = "Benchmark the standalone PTO-ISA linear attention kernel.",
):
    """Command-line driver shared by the benchmark scripts.

    Parses ``--warmup``, ``--repeats``, ``--shapes`` and ``--quick`` (plus
    ``--throughput-hunt`` when ``throughput_hunt_shapes`` is given), runs
    ``benchmark_shape`` once per selected shape and prints a result table.

    Shape selection priority: explicit ``--shapes``, then ``--throughput-hunt``,
    then ``--quick``, then ``default_shapes``.

    Args:
        script_file: Path of the calling script, used to locate the kernel source.
        default_shapes: Shapes used when no preset flag is given.
        quick_shapes: Shapes used for ``--quick``.
        benchmark_shape: Callable invoked as ``benchmark_shape(src, batch=...,
            heads=..., seq_len=..., hidden=..., chunk=..., warmup=..., repeats=...)``
            returning a dict with ``shape``, ``median_ms``, ``tflops`` and ``gib_s``.
        throughput_hunt_shapes: Optional preset enabling ``--throughput-hunt``.
        description: Text shown in ``--help``.
    """
    cli = argparse.ArgumentParser(description=description)
    cli.add_argument("--warmup", type=int, default=5)
    cli.add_argument("--repeats", type=int, default=20)
    cli.add_argument(
        "--shapes",
        type=str,
        default="",
        help="Semicolon-separated BxHxLxDxC list, e.g. 16x20x1024x128x64;8x20x4096x128x64",
    )
    cli.add_argument("--quick", action="store_true", help="Run a shorter preset shape list.")
    if throughput_hunt_shapes is not None:
        cli.add_argument(
            "--throughput-hunt",
            action="store_true",
            help="Run a larger-shape preset to search for higher steady-state utilization.",
        )
    options = cli.parse_args()

    torch.manual_seed(0)
    torch.npu.set_device("npu:0")
    kernel_source = kernel_src_path(script_file)

    def _select_shapes():
        if options.shapes:
            return parse_shapes(options.shapes)
        # The attribute only exists when the flag was registered above.
        if getattr(options, "throughput_hunt", False):
            return throughput_hunt_shapes
        if options.quick:
            return quick_shapes
        return default_shapes

    table_header = f"{'shape (B,H,L,D,C)':>24} {'ms':>9} {'TFLOP/s':>10} {'GiB/s':>10}"
    print(table_header)
    print("-" * len(table_header))

    collected = []
    for batch, heads, seq_len, hidden, chunk in _select_shapes():
        print(f"Running {batch}x{heads}x{seq_len}x{hidden}x{chunk} ...")
        row = benchmark_shape(
            kernel_source,
            batch=batch,
            heads=heads,
            seq_len=seq_len,
            hidden=hidden,
            chunk=chunk,
            warmup=options.warmup,
            repeats=options.repeats,
        )
        collected.append(row)
        print(
            f"{str(row['shape']):>24} "
            f"{row['median_ms']:>9.3f} "
            f"{row['tflops']:>10.2f} "
            f"{row['gib_s']:>10.2f}"
        )

    if not collected:
        return

    top_compute = max(collected, key=lambda row: row["tflops"])
    top_bandwidth = max(collected, key=lambda row: row["gib_s"])
    print("\nBest throughput:")
    print(f" TFLOP/s: {top_compute['tflops']:.2f} at shape {top_compute['shape']}")
    print(f" GiB/s: {top_bandwidth['gib_s']:.2f} at shape {top_bandwidth['shape']}")
jit_util_linear_attention import BLOCK_DIM, get_causal_mask, jit_compile + +DTYPE = torch.float16 +RTOL = 1e-2 + + +def _to_seq_first(x: torch.Tensor, head_first: bool) -> torch.Tensor: + return x.transpose(1, 2).contiguous() if head_first else x.contiguous() + + +def _from_seq_first(x: torch.Tensor, head_first: bool) -> torch.Tensor: + return x.transpose(1, 2).contiguous() if head_first else x.contiguous() + + +def _apply_gating( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor | None, + *, + head_first: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + if g is None: + return q, k + gate = torch.exp(g.float()).to(q.dtype) + inv_gate = torch.exp(-g.float()).to(k.dtype) + if head_first: + return q * gate.unsqueeze(-1), k * inv_gate.unsqueeze(-1) + return q * gate.unsqueeze(-1), k * inv_gate.unsqueeze(-1) + + +def _build_precomputed_h( + k: torch.Tensor, + v: torch.Tensor, + chunk_size: int, + *, + head_first: bool, + cu_seqlens: torch.Tensor | None = None, +) -> torch.Tensor: + k_seq = _to_seq_first(k, head_first).float() + v_seq = _to_seq_first(v, head_first).float() + _, total_t, num_heads, hidden = k_seq.shape + states = [] + + if cu_seqlens is None: + batch = k_seq.shape[0] + state = torch.zeros((batch, num_heads, hidden, hidden), device=k.device, dtype=torch.float32) + chunk_num = math.ceil(k_seq.shape[1] / chunk_size) + for i in range(chunk_num): + states.append(state.to(DTYPE)) + start = i * chunk_size + end = min(start + chunk_size, k_seq.shape[1]) + state = state + torch.einsum( + "bthd,bthe->bhde", + k_seq[:, start:end], + v_seq[:, start:end], + ) + return torch.stack(states, dim=1).contiguous().view(batch * chunk_num, num_heads, hidden, hidden) + + if head_first: + raise ValueError("cu_seqlens is only supported with seq-first inputs.") + + state = torch.zeros((num_heads, hidden, hidden), device=k.device, dtype=torch.float32) + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + state.zero_() + for start in 
range(bos, eos, chunk_size): + states.append(state.to(DTYPE)) + end = min(start + chunk_size, eos) + state = state + torch.einsum( + "thd,the->hde", + k_seq[0, start:end], + v_seq[0, start:end], + ) + return torch.stack(states, dim=0).contiguous() + + +def ref_linear_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + g: torch.Tensor | None = None, + head_first: bool = True, + cu_seqlens: torch.Tensor | None = None, +) -> torch.Tensor: + q_scaled, k_scaled = _apply_gating(q, k, g, head_first=head_first) + q_seq = _to_seq_first(q_scaled, head_first).float() + k_seq = _to_seq_first(k_scaled, head_first).float() + v_seq = _to_seq_first(v, head_first).float() + out = torch.zeros_like(v_seq, dtype=torch.float32) + + if cu_seqlens is None: + batch, seq_len, num_heads, hidden = q_seq.shape + for b in range(batch): + h = torch.zeros((num_heads, hidden, hidden), device=q.device, dtype=torch.float32) + for i in range(seq_len): + k_i = k_seq[b, i] + v_i = v_seq[b, i] + h = h + torch.einsum("hd,he->hde", k_i, v_i) + out[b, i] = torch.einsum("hd,hde->he", q_seq[b, i], h) + else: + _, _, num_heads, hidden = q_seq.shape + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + h = torch.zeros((num_heads, hidden, hidden), device=q.device, dtype=torch.float32) + for i in range(bos, eos): + k_i = k_seq[0, i] + v_i = v_seq[0, i] + h = h + torch.einsum("hd,he->hde", k_i, v_i) + out[0, i] = torch.einsum("hd,hde->he", q_seq[0, i], h) + + return _from_seq_first(out.to(DTYPE), head_first) + + +@lru_cache(maxsize=None) +def _compiled_kernel(src: str, h: int, d: int, c: int): + return jit_compile(src, num_heads=h, hidden_size=d, chunk_size=c) + + +def run_kernel( + src: str, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + chunk_size: int, + *, + g: torch.Tensor | None = None, + head_first: bool = True, + cu_seqlens: torch.Tensor | None = None, + mask_variant: str = "cached_mask", +): + if q.shape != k.shape or q.shape != 
v.shape: + raise ValueError("q, k, v must have identical shapes.") + if cu_seqlens is not None and head_first: + raise ValueError("cu_seqlens is only supported with seq-first inputs.") + + num_heads = q.shape[1] if head_first else q.shape[2] + hidden = q.shape[-1] + linear_attention_func = _compiled_kernel(src, num_heads, hidden, chunk_size) + causal_mask = get_causal_mask(chunk_size, DTYPE, q.device.index or 0) + if mask_variant not in {"cached_mask", "fast_onthefly"}: + raise ValueError(f"Unsupported mask_variant: {mask_variant}") + use_fast_mask = mask_variant == "fast_onthefly" + + if g is None and head_first and cu_seqlens is None and q.shape[2] % chunk_size == 0: + b, _, l, d = q.shape + workspace_1 = torch.zeros((BLOCK_DIM, 2, chunk_size, chunk_size), device=q.device, dtype=DTYPE) + workspace_2 = torch.zeros((BLOCK_DIM, 2, d, d), device=q.device, dtype=DTYPE) + o = torch.zeros((b, num_heads, l, d), device=q.device, dtype=DTYPE) + linear_attention_func( + q, + k, + v, + workspace_1, + workspace_2, + causal_mask, + o, + use_fast_mask=use_fast_mask, + block_dim=BLOCK_DIM, + ) + torch.npu.synchronize() + return o + + q_scaled, k_scaled = _apply_gating(q, k, g, head_first=head_first) + h_states = _build_precomputed_h( + k_scaled, + v, + chunk_size, + head_first=head_first, + cu_seqlens=cu_seqlens, + ) + workspace_1 = torch.zeros((BLOCK_DIM, 2, chunk_size, chunk_size), device=q.device, dtype=DTYPE) + o = torch.zeros_like(v) + linear_attention_func( + q_scaled.contiguous(), + k_scaled.contiguous(), + v.contiguous(), + workspace_1, + h_states.contiguous(), + causal_mask, + o, + cu_seqlens=cu_seqlens.contiguous() if cu_seqlens is not None else None, + seq_first=not head_first, + use_precomputed_h=True, + use_fast_mask=use_fast_mask, + batch_size_override=(len(cu_seqlens) - 1) if cu_seqlens is not None else None, + block_dim=BLOCK_DIM, + ) + torch.npu.synchronize() + return o + + +def _make_normalized(shape: tuple[int, ...]) -> torch.Tensor: + x = torch.randn(shape, 
def main():
    """Run the correctness sweep: every test case against both mask variants."""
    torch.manual_seed(0)
    torch.npu.set_device("npu:0")

    here = os.path.dirname(os.path.abspath(__file__))
    src = os.path.join(here, "linear_attention.cpp")

    # (label, shape, chunk, head_first, gate mode, cu_seqlens)
    cases = [
        ("head_first fixed", (4, 2, 128, 128), 64, True, None, None),
        ("seq_first fixed", (4, 128, 2, 128), 64, False, None, None),
        ("seq_first gated", (4, 128, 2, 128), 64, False, "random", None),
        ("seq_first uniform-zero gated", (4, 128, 2, 128), 64, False, "zeros", None),
        ("seq_first varlen gated", (1, 161, 2, 128), 64, False, "random", [0, 17, 96, 161]),
    ]
    gate_factories = {"random": torch.randn, "zeros": torch.zeros}

    for label, shape, chunk, head_first, gate_mode, cu_seqlens in cases:
        # Creation order (q, k, v, g) keeps the random stream unchanged.
        q = _make_normalized(shape)
        k = _make_normalized(shape)
        v = torch.randn(shape, device="npu", dtype=DTYPE)

        make_gate = gate_factories.get(gate_mode)
        g = None
        if make_gate is not None:
            # The gate has one value per (batch, position, head).
            g = make_gate(shape[:-1], device="npu", dtype=torch.float32)

        cu_tensor = None
        if cu_seqlens is not None:
            cu_tensor = torch.tensor(cu_seqlens, device="npu", dtype=torch.int32)

        ref_o = ref_linear_attention(
            q, k, v, g=g, head_first=head_first, cu_seqlens=cu_tensor
        )

        # The time axis is axis 2 for head-first inputs, axis 1 otherwise.
        total_t = shape[2] if head_first else shape[1]
        atol = 4e-2 if total_t >= 4096 else 2e-2 if total_t >= 2048 else 1e-2

        for mask_variant in ("cached_mask", "fast_onthefly"):
            print(f"Testing {label} [{mask_variant}] shape={shape} C={chunk}")
            o = run_kernel(
                src,
                q,
                k,
                v,
                chunk,
                g=g,
                head_first=head_first,
                cu_seqlens=cu_tensor,
                mask_variant=mask_variant,
            )
            torch.testing.assert_close(o.cpu(), ref_o.cpu(), rtol=RTOL, atol=atol)
            print(" passed!")

    print("Kernel Output Match!")
def parse_shapes(shape_text: str):
    """Parse a ``;``-separated list of ``BxHxLxDxC`` shapes into 5-tuples of ints.

    Blank entries (for example from a trailing ``;``) are ignored.

    Raises:
        ValueError: If an entry does not have exactly five ``x``-separated
            fields, or a field is not an integer.
    """
    parsed = []
    for entry in map(str.strip, shape_text.split(";")):
        if entry == "":
            continue
        dims = tuple(map(int, entry.split("x")))
        if len(dims) != 5:
            raise ValueError(
                "Each shape must be formatted as BxHxLxDxC, e.g. 16x20x1024x128x128"
            )
        parsed.append(dims)
    return parsed
def load_pto_helpers():
    """Import the PTO JIT helper module from ``PTO_UTIL`` and return it.

    The module is loaded by file path, so this script does not need the PTO
    directory on ``sys.path``.

    Raises:
        ImportError: If no import spec or loader can be created for
            ``PTO_UTIL``.  The original code reached ``module_from_spec``
            with a ``None`` spec and guarded the loader only with an
            ``assert``, which is removed under ``python -O``.
    """
    spec = importlib.util.spec_from_file_location("pto_jit_util", PTO_UTIL)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load PTO JIT helpers from {PTO_UTIL}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def benchmark_triton_shape(
    kernel_name: str,
    batch: int,
    heads: int,
    seq: int,
    hidden: int,
    chunk: int,
    warmup: int,
    repeats: int,
    *,
    use_cached_mask: bool,
):
    """Time the Triton ``chunk_o`` kernel for one shape and summarise the result.

    Chunk states, and the causal mask when ``use_cached_mask`` is set, are
    built before timing starts, so only the ``chunk_o`` call is measured.

    Returns:
        The dict produced by ``summarize_result`` under ``kernel_name``.
    """
    q, k, v = make_inputs(batch, heads, seq, hidden)
    precomputed_h = build_chunk_states(k, v, chunk)
    precomputed_mask = None
    if use_cached_mask:
        precomputed_mask = get_causal_mask(chunk, DTYPE, q.device.index or 0)

    def run():
        return chunk_o(
            q,
            k,
            v,
            chunk_size=chunk,
            precomputed_h=precomputed_h,
            precomputed_mask=precomputed_mask,
            use_cached_mask=use_cached_mask,
        )

    med_ms = benchmark_callable(run, warmup, repeats)
    return summarize_result(kernel_name, batch, heads, seq, hidden, chunk, med_ms)
def render_markdown(results):
    """Render benchmark results as a Markdown report and return it as a string.

    The report has a per-result table, a table comparing each kernel with
    the ``pto_cpp`` entry of the same shape, a cached vs on-the-fly Triton
    table, and a fixed list of notes.

    Args:
        results: Sequence of dicts with ``kernel``, ``shape``, ``median_ms``,
            ``tflops`` and ``gib_s`` keys.
    """

    def _result_row(entry):
        return (
            f"| {entry['kernel']} | `{entry['shape']}` | "
            f"{entry['median_ms']:.3f} | {entry['tflops']:.2f} | {entry['gib_s']:.2f} |"
        )

    # shape -> {kernel name -> result}; insertion order drives row order.
    by_shape = {}
    for entry in results:
        by_shape.setdefault(entry["shape"], {})[entry["kernel"]] = entry

    report = [
        "# Triton-Ascend `chunk_o` Performance",
        "",
        "| Kernel | Shape `(B,H,L,D,C)` | Median ms | TFLOP/s | GiB/s |",
        "| --- | --- | ---: | ---: | ---: |",
    ]
    report += [_result_row(entry) for entry in results]

    report += [
        "",
        "## PTO / Kernel Speedup",
        "",
        "| Shape `(B,H,L,D,C)` | Kernel | PTO / Kernel speedup | Kernel - PTO TFLOP/s delta |",
        "| --- | --- | ---: | ---: |",
    ]
    for shape, kernels in by_shape.items():
        if "pto_cpp" not in kernels:
            continue
        baseline = kernels["pto_cpp"]
        for kernel_name, entry in kernels.items():
            if kernel_name == "pto_cpp":
                continue
            # Ratio of kernel time to PTO time: above 1 means PTO is faster.
            ratio = entry["median_ms"] / baseline["median_ms"]
            delta = entry["tflops"] - baseline["tflops"]
            report.append(f"| `{shape}` | `{kernel_name}` | {ratio:.2f}x | {delta:+.2f} |")

    report += [
        "",
        "## Triton Cached vs On-The-Fly",
        "",
        "| Shape `(B,H,L,D,C)` | Cached ms | On-the-fly ms | On-the-fly / Cached | Cached TFLOP/s delta |",
        "| --- | ---: | ---: | ---: | ---: |",
    ]
    for shape, kernels in by_shape.items():
        if not ("triton_mask_cached" in kernels and "triton_mask_onthefly" in kernels):
            continue
        cached = kernels["triton_mask_cached"]
        onthefly = kernels["triton_mask_onthefly"]
        slowdown = onthefly["median_ms"] / cached["median_ms"]
        gain = cached["tflops"] - onthefly["tflops"]
        report.append(
            f"| `{shape}` | {cached['median_ms']:.3f} | {onthefly['median_ms']:.3f} | "
            f"{slowdown:.2f}x | "
            f"{gain:+.2f} |"
        )

    report += [
        "",
        "Notes:",
        "- Reported TFLOP/s and GiB/s are computed from the same algorithm-level model for both kernels.",
        "- The Triton kernel is forward-only, head-first only, and currently omits gating and varlen support.",
        "- `triton_mask_onthefly` computes the causal mask inside the Triton kernel but still uses precomputed chunk states `h`.",
        "- `triton_mask_cached` is benchmarked with precomputed chunk states `h` and a cached causal mask, so state construction and mask setup are excluded from timed measurements.",
        "- The copied vLLM-style kernel is benchmarked with precomputed `h` state and pre-transposed inputs, so transpose/setup cost is excluded as requested.",
        "- On this device, the copied vLLM-style kernel compiled and ran for `C=64`, but the unmodified `BT=C=128` configuration overflowed UB and was not benchmarked.",
        "- `TRITON_ALL_BLOCKS_PARALLEL` is intentionally left disabled here because it produced incorrect outputs for this kernel.",
    ]
    return "\n".join(report) + "\n"
# Forward kernel for the chunkwise output.  For every chunk it computes
#     o = scale * (q @ h_chunk + mask(q @ k^T) @ v)
# where h_chunk is the state read from `h` for that chunk and mask keeps the
# lower triangle (row >= column) of the in-chunk score tile.
#
# Each program handles one (batch*head index, V-tile) pair and loops over all
# NT chunks of that sequence.  T, NT, total_bh and scale are excluded from
# specialization via `do_not_specialize`.
# NOTE(review): the pointer arithmetic treats q/k as flat [total_bh, T, K],
# v/o as [total_bh, T, V] and h as [total_bh * NT, K, V] -- confirm callers
# always pass dense tensors in that order.
@triton.jit(do_not_specialize=["T", "NT", "total_bh", "scale"])
def _chunk_o_fwd_kernel(
    q,
    k,
    v,
    h,
    mask,
    o,
    scale,
    T,
    NT,
    total_bh,
    K: tl.constexpr,
    V: tl.constexpr,
    C: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_PRECOMPUTED_MASK: tl.constexpr,
):
    # Decompose the program id into (batch*head index, V-tile index).
    pid = tl.program_id(0)
    NV: tl.constexpr = tl.cdiv(V, BV)
    i_bh = pid // NV
    i_v = pid % NV

    # Programs beyond the last (batch, head) pair have nothing to do.
    if i_bh >= total_bh:
        return

    # Advance the base pointers to this (batch, head) slice.
    q += i_bh * T * K
    k += i_bh * T * K
    v += i_bh * T * V
    o += i_bh * T * V

    for i_c in range(NT):
        chunk_start = i_c * C
        # State for this chunk; the index is widened to int64 before scaling.
        h_base = h + ((i_bh * NT + i_c).to(tl.int64) * K * V)
        # The whole chunk of v for this V-tile is loaded once and reused by
        # every row tile below.
        p_v = tl.make_block_ptr(
            v, (T, V), (V, 1), (chunk_start, i_v * BV), (C, BV), (1, 0)
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # Row tiles of BT rows inside the current chunk.
        for i_t in range(tl.cdiv(C, BT)):
            row_start = chunk_start + i_t * BT
            p_o = tl.make_block_ptr(
                o, (T, V), (V, 1), (row_start, i_v * BV), (BT, BV), (1, 0)
            )
            # Rows [i_t * BT, i_t * BT + BT) of the C x C mask.
            p_mask = tl.make_block_ptr(
                mask, (C, C), (C, 1), (i_t * BT, 0), (BT, C), (1, 0)
            )
            # float32 accumulators: b_o for the output tile, b_a for the
            # in-chunk q @ k^T scores.
            b_o = tl.zeros([BT, BV], dtype=tl.float32)
            b_a = tl.zeros([BT, C], dtype=tl.float32)

            # Reduce over the K dimension in tiles of BK.
            for i_k in range(tl.cdiv(K, BK)):
                p_q = tl.make_block_ptr(
                    q, (T, K), (K, 1), (row_start, i_k * BK), (BT, BK), (1, 0)
                )
                # Shape (K, T) with strides (1, K): a transposed view of k.
                p_k = tl.make_block_ptr(
                    k, (K, T), (1, K), (i_k * BK, chunk_start), (BK, C), (0, 1)
                )
                p_h = tl.make_block_ptr(
                    h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
                )
                b_q = tl.load(p_q, boundary_check=(0, 1))
                b_k = tl.load(p_k, boundary_check=(0, 1))
                b_h = tl.load(p_h, boundary_check=(0, 1))
                b_o += tl.dot(b_q, b_h)
                b_a += tl.dot(b_q, b_k)

            # Apply the causal mask only after the full K reduction.
            if USE_PRECOMPUTED_MASK:
                b_mask = tl.load(p_mask, boundary_check=(0, 1))
                b_a *= b_mask.to(b_a.dtype)
            else:
                # Build the mask from chunk-local row/column indices.
                row_offsets = (i_t * BT + tl.arange(0, BT)).to(tl.float32)
                col_offsets = tl.arange(0, C).to(tl.float32)
                b_a = tl.where(row_offsets[:, None] >= col_offsets[None, :], b_a, 0)
            # Scores are cast to v's dtype for the second matmul.
            b_o += tl.dot(b_a.to(b_v.dtype), b_v)
            b_o *= scale

            tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
_require_head_first(x: torch.Tensor, name: str) -> None: + if x.ndim != 4: + raise ValueError(f"{name} must be rank-4, got {tuple(x.shape)}") + + +def ref_chunk_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + chunk_size: int, + *, + scale: float = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, +): + _require_head_first(q, "q") + _require_head_first(k, "k") + _require_head_first(v, "v") + + b, h, t, d_k = q.shape + d_v = v.shape[-1] + qf = q.float() + kf = k.float() + vf = v.float() + + state = torch.zeros((b, h, d_k, d_v), device=q.device, dtype=torch.float32) + if initial_state is not None: + state.copy_(initial_state.float()) + + out = torch.zeros((b, h, t, d_v), device=q.device, dtype=torch.float32) + nt = math.ceil(t / chunk_size) + for i_t in range(nt): + start = i_t * chunk_size + end = min(start + chunk_size, t) + q_tile = qf[:, :, start:end, :] + k_tile = kf[:, :, start:end, :] + v_tile = vf[:, :, start:end, :] + attn = torch.matmul(q_tile, k_tile.transpose(-1, -2)).tril() + out[:, :, start:end, :] = ( + torch.matmul(q_tile, state) + torch.matmul(attn, v_tile) + ) * scale + state = state + torch.matmul(k_tile.transpose(-1, -2), v_tile) + + if output_final_state: + return out.to(v.dtype), state + return out.to(v.dtype) + + +def build_chunk_states( + k: torch.Tensor, + v: torch.Tensor, + chunk_size: int, + *, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, +): + _require_head_first(k, "k") + _require_head_first(v, "v") + + b, h, t, d_k = k.shape + d_v = v.shape[-1] + nt = math.ceil(t / chunk_size) + state = torch.zeros((b, h, d_k, d_v), device=k.device, dtype=torch.float32) + if initial_state is not None: + state.copy_(initial_state.float()) + + states = [] + for i_t in range(nt): + states.append(state.to(v.dtype)) + start = i_t * chunk_size + end = min(start + chunk_size, t) + state = state + torch.matmul( + k[:, :, start:end, :].float().transpose(-1, -2), + v[:, :, 
start:end, :].float(), + ) + + stacked = torch.stack(states, dim=2).contiguous() + if output_final_state: + return stacked, state + return stacked + + +@lru_cache(maxsize=None) +def get_causal_mask(chunk_size: int, dtype: torch.dtype, device_index: int) -> torch.Tensor: + if chunk_size <= 0: + raise ValueError(f"chunk_size must be positive, got {chunk_size}") + mask = torch.ones( + (chunk_size, chunk_size), + device=f"npu:{device_index}", + dtype=dtype, + ) + return torch.tril(mask).contiguous() + + +def _normalize_precomputed_h( + h: torch.Tensor, + b: int, + heads: int, + nt: int, + d_k: int, + d_v: int, +) -> torch.Tensor: + expected_5d = (b, heads, nt, d_k, d_v) + expected_4d = (b * nt, heads, d_k, d_v) + if tuple(h.shape) == expected_5d: + return h.contiguous().view(b * nt, heads, d_k, d_v) + if tuple(h.shape) == expected_4d: + return h.contiguous() + raise ValueError( + f"precomputed_h must have shape {expected_5d} or {expected_4d}, got {tuple(h.shape)}" + ) + + +def chunk_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + chunk_size: int, + *, + scale: float = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + precomputed_h: Optional[torch.Tensor] = None, + precomputed_mask: Optional[torch.Tensor] = None, + use_cached_mask: bool = True, +): + # TODO: support seq_first layout: (B, T, H, D). + # TODO: support gated and varlen variants when the baseline is proven out. 
+ _require_head_first(q, "q") + _require_head_first(k, "k") + _require_head_first(v, "v") + + if q.shape != k.shape: + raise ValueError(f"q and k must have the same shape, got {q.shape} vs {k.shape}") + if q.shape[:3] != v.shape[:3]: + raise ValueError( + f"q/k and v must agree on (B, H, T), got {q.shape[:3]} vs {v.shape[:3]}" + ) + if chunk_size <= 0: + raise ValueError(f"chunk_size must be positive, got {chunk_size}") + if q.device.type != "npu": + raise ValueError(f"expected NPU tensors, got {q.device}") + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + b, h, t, d_k = q.shape + d_v = v.shape[-1] + total_bh = b * h + nt = math.ceil(t / chunk_size) + + if initial_state is not None: + expected = (b, h, d_k, d_v) + if tuple(initial_state.shape) != expected: + raise ValueError( + f"initial_state must have shape {expected}, got {tuple(initial_state.shape)}" + ) + initial_state = initial_state.contiguous() + + if precomputed_h is None: + built = build_chunk_states( + k, + v, + chunk_size, + initial_state=initial_state, + output_final_state=output_final_state, + ) + if output_final_state: + h_states, final_state = built + else: + h_states, final_state = built, None + precomputed_h = h_states + else: + if output_final_state: + raise ValueError( + "output_final_state=True is not supported together with externally supplied precomputed_h" + ) + final_state = None + + precomputed_h = _normalize_precomputed_h(precomputed_h, b, h, nt, d_k, d_v) + if precomputed_mask is not None: + if not use_cached_mask: + raise ValueError("precomputed_mask requires use_cached_mask=True") + expected_mask = (chunk_size, chunk_size) + if tuple(precomputed_mask.shape) != expected_mask: + raise ValueError( + f"precomputed_mask must have shape {expected_mask}, got {tuple(precomputed_mask.shape)}" + ) + if precomputed_mask.device != q.device: + raise ValueError( + f"precomputed_mask must be on {q.device}, got {precomputed_mask.device}" + ) + precomputed_mask = 
precomputed_mask.contiguous() + elif use_cached_mask: + precomputed_mask = get_causal_mask(chunk_size, q.dtype, q.device.index or 0) + else: + precomputed_mask = torch.empty((chunk_size, chunk_size), device=q.device, dtype=q.dtype) + out = torch.empty_like(v) + tile_rows = min(64, chunk_size) + bk = min(64, triton.next_power_of_2(d_k)) + bv = min(64, triton.next_power_of_2(d_v)) + grid = (total_bh * triton.cdiv(d_v, bv),) + _chunk_o_fwd_kernel[grid]( + q=q, + k=k, + v=v, + h=precomputed_h, + mask=precomputed_mask, + o=out, + scale=scale, + T=t, + NT=nt, + total_bh=total_bh, + K=d_k, + V=d_v, + C=chunk_size, + BT=tile_rows, + BK=bk, + BV=bv, + USE_PRECOMPUTED_MASK=use_cached_mask, + num_warps=4, + num_stages=2, + ) + return (out, final_state) if output_final_state else out + + +chunk_linear_attention = chunk_o diff --git a/examples/jit_cpp/linear_attention/triton_baseline/chunk_o_vllm_adapted.py b/examples/jit_cpp/linear_attention/triton_baseline/chunk_o_vllm_adapted.py new file mode 100644 index 00000000..0c2d59f4 --- /dev/null +++ b/examples/jit_cpp/linear_attention/triton_baseline/chunk_o_vllm_adapted.py @@ -0,0 +1,265 @@ +import math +from typing import Literal + +import torch +import torch_npu # noqa: F401 +import triton +import triton.language as tl + +# adapted from https://github.com/vllm-project/vllm-ascend/blob/v0.18.0rc1/vllm_ascend/ops/triton/fla/chunk_o.py + + +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + lens = cu_seqlens[1:] - cu_seqlens[:-1] + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(lens, chunk_size)]).cumsum(-1) + + +@triton.jit +def safe_exp(x): + return tl.exp(tl.where(x <= 0, x, float("-inf"))) + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["chunk_offsets", "scale", "T", "H", "Hg", "K", "V"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + 
cu_seqlens, + chunk_offsets, + scale, + T, + H, + Hg, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + T_max = T + + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int64) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + + for i_t in range(NT): + i_tg = boh + i_t + h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + + b_o += tl.dot(b_q, b_h) + b_A += tl.dot(b_q, b_k) + + if USE_G: + offs_t = i_t * BT + tl.arange(0, BT) + mask_t = offs_t < T + g_ptr = g + bos + i_h * T_max + b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0) + + b_o = b_o * tl.exp(b_g)[:, None] + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + + o_i = tl.arange(0, BT).to(tl.float32) + m_A = o_i[:, None] >= o_i[None, :] + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = b_o * scale + 
tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor | None = None, + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, +) -> torch.Tensor: + b, t, hg, k_dim, v_dim = *q.shape, v.shape[-1] + h_dim = v.shape[-2] + bt = chunk_size + + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + if cu_seqlens is None: + n, chunk_offsets = b, None + else: + n, chunk_offsets = len(cu_seqlens) - 1, prepare_chunk_offsets(cu_seqlens, bt) + bk = min(64, triton.next_power_of_2(k_dim)) + bv = min(64, triton.next_power_of_2(v_dim)) + + def grid(meta): + return (triton.cdiv(v_dim, meta["BV"]), n * h_dim) + + if g is not None: + g = g.transpose(1, 2).contiguous() + + chunk_fwd_kernel_o[grid]( + q=q, + k=k, + v=v, + h=h, + g=g, + o=o, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=t, + H=h_dim, + Hg=hg, + K=k_dim, + V=v_dim, + BT=bt, + BK=bk, + BV=bv, + num_warps=4, + num_stages=2, + ) + return o + + +def build_chunk_states( + k: torch.Tensor, + v: torch.Tensor, + chunk_size: int, +) -> torch.Tensor: + b, h, t, d_k = k.shape + d_v = v.shape[-1] + nt = math.ceil(t / chunk_size) + state = torch.zeros((b, h, d_k, d_v), device=k.device, dtype=torch.float32) + states = [] + for i_t in range(nt): + states.append(state.to(v.dtype)) + start = i_t * chunk_size + end = min(start + chunk_size, t) + state = state + torch.matmul( + k[:, :, start:end, :].float().transpose(-1, -2), + v[:, :, start:end, :].float(), + ) + return torch.stack(states, dim=1).contiguous().view(b * nt, h, d_k, d_v) + + +def prepare_vllm_equivalent_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + chunk_size: int, + *, + g_mode: Literal["none", "uniform_zero"], + varlen_mode: Literal["static", "varlen_equiv"], +): + if q.ndim != 4 or k.ndim != 
4 or v.ndim != 4: + raise ValueError("q, k, v must all be rank-4 tensors") + if q.shape != k.shape or q.shape[:3] != v.shape[:3]: + raise ValueError("q, k, v must agree on (B, H, T)") + + b, h, t, d = q.shape + q_seq = q.transpose(1, 2).contiguous() + k_seq = k.transpose(1, 2).contiguous() + v_seq = v.transpose(1, 2).contiguous() + h_states = build_chunk_states(k, v, chunk_size) + + g = None + if g_mode == "uniform_zero": + g = torch.zeros((b, t, h), device=q.device, dtype=torch.float32) + elif g_mode != "none": + raise ValueError(f"Unsupported g_mode: {g_mode}") + + if varlen_mode == "static": + return { + "q": q_seq, + "k": k_seq, + "v": v_seq, + "h": h_states, + "g": g, + "cu_seqlens": None, + "restore": lambda o: o.transpose(1, 2).contiguous(), + } + + if varlen_mode != "varlen_equiv": + raise ValueError(f"Unsupported varlen_mode: {varlen_mode}") + + total_t = b * t + cu_seqlens = torch.arange(0, total_t + 1, t, device=q.device, dtype=torch.long) + q_flat = q_seq.reshape(1, total_t, h, d).contiguous() + k_flat = k_seq.reshape(1, total_t, h, d).contiguous() + v_flat = v_seq.reshape(1, total_t, h, v.shape[-1]).contiguous() + if g is not None: + g = g.reshape(1, total_t, h).contiguous() + + return { + "q": q_flat, + "k": k_flat, + "v": v_flat, + "h": h_states, + "g": g, + "cu_seqlens": cu_seqlens, + "restore": lambda o: o.view(b, t, h, v.shape[-1]).transpose(1, 2).contiguous(), + } + + +def chunk_o_vllm_adapted( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + chunk_size: int, + *, + scale: float = 1.0, + g_mode: Literal["none", "uniform_zero"] = "none", + varlen_mode: Literal["static", "varlen_equiv"] = "static", +) -> torch.Tensor: + prepared = prepare_vllm_equivalent_inputs( + q, k, v, chunk_size, g_mode=g_mode, varlen_mode=varlen_mode + ) + out = chunk_fwd_o( + q=prepared["q"], + k=prepared["k"], + v=prepared["v"], + h=prepared["h"], + g=prepared["g"], + scale=scale, + cu_seqlens=prepared["cu_seqlens"], + chunk_size=chunk_size, + ) + return 
prepared["restore"](out) diff --git a/examples/jit_cpp/linear_attention/triton_baseline/performance_summary.md b/examples/jit_cpp/linear_attention/triton_baseline/performance_summary.md new file mode 100644 index 00000000..0c15606d --- /dev/null +++ b/examples/jit_cpp/linear_attention/triton_baseline/performance_summary.md @@ -0,0 +1,65 @@ +# Triton-Ascend `chunk_o` Performance + +| Kernel | Shape `(B,H,L,D,C)` | Median ms | TFLOP/s | GiB/s | +| --- | --- | ---: | ---: | ---: | +| triton_mask_onthefly | `(8, 20, 1024, 128, 64)` | 1.305 | 12.34 | 119.69 | +| triton_mask_cached | `(8, 20, 1024, 128, 64)` | 1.323 | 12.17 | 118.10 | +| vllm_static_no_g | `(8, 20, 1024, 128, 64)` | 1.470 | 10.96 | 106.29 | +| vllm_static_uniform_g | `(8, 20, 1024, 128, 64)` | 1.565 | 10.29 | 99.82 | +| vllm_varlen_no_g | `(8, 20, 1024, 128, 64)` | 2.030 | 7.93 | 76.97 | +| vllm_varlen_uniform_g | `(8, 20, 1024, 128, 64)` | 2.090 | 7.71 | 74.78 | +| pto_cpp | `(8, 20, 1024, 128, 64)` | 0.583 | 27.61 | 267.87 | +| triton_mask_onthefly | `(16, 20, 1024, 128, 64)` | 2.260 | 14.25 | 138.28 | +| triton_mask_cached | `(16, 20, 1024, 128, 64)` | 2.312 | 13.93 | 135.14 | +| vllm_static_no_g | `(16, 20, 1024, 128, 64)` | 2.681 | 12.01 | 116.56 | +| vllm_static_uniform_g | `(16, 20, 1024, 128, 64)` | 2.758 | 11.68 | 113.31 | +| vllm_varlen_no_g | `(16, 20, 1024, 128, 64)` | 3.200 | 10.07 | 97.66 | +| vllm_varlen_uniform_g | `(16, 20, 1024, 128, 64)` | 3.252 | 9.90 | 96.08 | +| pto_cpp | `(16, 20, 1024, 128, 64)` | 1.121 | 28.74 | 278.82 | +| triton_mask_onthefly | `(24, 20, 2048, 128, 64)` | 6.096 | 15.85 | 153.79 | +| triton_mask_cached | `(24, 20, 2048, 128, 64)` | 6.264 | 15.43 | 149.67 | +| vllm_static_no_g | `(24, 20, 2048, 128, 64)` | 7.458 | 12.96 | 125.71 | +| vllm_static_uniform_g | `(24, 20, 2048, 128, 64)` | 7.529 | 12.83 | 124.51 | +| vllm_varlen_no_g | `(24, 20, 2048, 128, 64)` | 7.978 | 12.11 | 117.51 | +| vllm_varlen_uniform_g | `(24, 20, 2048, 128, 64)` | 8.007 | 12.07 | 117.08 | 
+| pto_cpp | `(24, 20, 2048, 128, 64)` | 3.072 | 31.46 | 305.18 | + +## PTO / Kernel Speedup + +| Shape `(B,H,L,D,C)` | Kernel | PTO / Kernel speedup | Kernel - PTO TFLOP/s delta | +| --- | --- | ---: | ---: | +| `(8, 20, 1024, 128, 64)` | `triton_mask_onthefly` | 2.24x | -15.27 | +| `(8, 20, 1024, 128, 64)` | `triton_mask_cached` | 2.27x | -15.44 | +| `(8, 20, 1024, 128, 64)` | `vllm_static_no_g` | 2.52x | -16.66 | +| `(8, 20, 1024, 128, 64)` | `vllm_static_uniform_g` | 2.68x | -17.32 | +| `(8, 20, 1024, 128, 64)` | `vllm_varlen_no_g` | 3.48x | -19.68 | +| `(8, 20, 1024, 128, 64)` | `vllm_varlen_uniform_g` | 3.58x | -19.90 | +| `(16, 20, 1024, 128, 64)` | `triton_mask_onthefly` | 2.02x | -14.49 | +| `(16, 20, 1024, 128, 64)` | `triton_mask_cached` | 2.06x | -14.81 | +| `(16, 20, 1024, 128, 64)` | `vllm_static_no_g` | 2.39x | -16.73 | +| `(16, 20, 1024, 128, 64)` | `vllm_static_uniform_g` | 2.46x | -17.06 | +| `(16, 20, 1024, 128, 64)` | `vllm_varlen_no_g` | 2.85x | -18.67 | +| `(16, 20, 1024, 128, 64)` | `vllm_varlen_uniform_g` | 2.90x | -18.84 | +| `(24, 20, 2048, 128, 64)` | `triton_mask_onthefly` | 1.98x | -15.61 | +| `(24, 20, 2048, 128, 64)` | `triton_mask_cached` | 2.04x | -16.03 | +| `(24, 20, 2048, 128, 64)` | `vllm_static_no_g` | 2.43x | -18.50 | +| `(24, 20, 2048, 128, 64)` | `vllm_static_uniform_g` | 2.45x | -18.62 | +| `(24, 20, 2048, 128, 64)` | `vllm_varlen_no_g` | 2.60x | -19.34 | +| `(24, 20, 2048, 128, 64)` | `vllm_varlen_uniform_g` | 2.61x | -19.39 | + +## Triton Cached vs On-The-Fly + +| Shape `(B,H,L,D,C)` | Cached ms | On-the-fly ms | On-the-fly / Cached | Cached TFLOP/s delta | +| --- | ---: | ---: | ---: | ---: | +| `(8, 20, 1024, 128, 64)` | 1.323 | 1.305 | 0.99x | -0.16 | +| `(16, 20, 1024, 128, 64)` | 2.312 | 2.260 | 0.98x | -0.32 | +| `(24, 20, 2048, 128, 64)` | 6.264 | 6.096 | 0.97x | -0.43 | + +Notes: +- Reported TFLOP/s and GiB/s are computed from the same algorithm-level model for all kernels.
+- The Triton kernel is forward-only, head-first only, and currently omits gating and varlen support. +- `triton_mask_onthefly` computes the causal mask inside the Triton kernel but still uses precomputed chunk states `h`. +- `triton_mask_cached` is benchmarked with precomputed chunk states `h` and a cached causal mask, so state construction and mask setup are excluded from timed measurements. +- The copied vLLM-style kernel is benchmarked with precomputed `h` state and pre-transposed inputs, so transpose/setup cost is excluded as requested. +- On this device, the copied vLLM-style kernel compiled and ran for `C=64`, but the unmodified `BT=C=128` configuration overflowed UB and was not benchmarked. +- `TRITON_ALL_BLOCKS_PARALLEL` is intentionally left disabled here because it produced incorrect outputs for this kernel. diff --git a/examples/jit_cpp/linear_attention/triton_baseline/test_chunk_o.py b/examples/jit_cpp/linear_attention/triton_baseline/test_chunk_o.py new file mode 100644 index 00000000..438cc9dd --- /dev/null +++ b/examples/jit_cpp/linear_attention/triton_baseline/test_chunk_o.py @@ -0,0 +1,190 @@ +import pytest +import torch +import torch_npu # noqa: F401 + +from chunk_o import build_chunk_states, chunk_o, get_causal_mask, ref_chunk_o +from chunk_o_vllm_adapted import chunk_o_vllm_adapted + + +DTYPE = torch.float16 +RTOL = 1e-2 + + +def make_inputs(b: int, h: int, l: int, d: int): + q = torch.randn((b, h, l, d), device="npu", dtype=DTYPE) + k = torch.randn((b, h, l, d), device="npu", dtype=DTYPE) + v = torch.randn((b, h, l, d), device="npu", dtype=DTYPE) + q = q / (q.pow(2).sum(dim=-1, keepdim=True).sqrt() + 1e-6) + k = k / (k.pow(2).sum(dim=-1, keepdim=True).sqrt() + 1e-6) + return q, k, v + + +def pick_atol(seq_len: int) -> float: + if seq_len >= 4096: + return 4e-2 + if seq_len >= 2048: + return 2e-2 + return 1e-2 + + +@pytest.mark.parametrize( + ("b", "h", "l", "d", "c"), + [ + (1, 2, 64, 128, 64), + (1, 2, 256, 128, 64), + (4, 2, 128, 128, 64), + 
(8, 2, 512, 128, 64), + (1, 2, 300, 128, 64), + (2, 2, 153, 64, 64), + (4, 4, 257, 128, 128), + ], +) +def test_chunk_o_forward_matches_reference(b: int, h: int, l: int, d: int, c: int): + torch.manual_seed(0) + torch.npu.set_device("npu:0") + + q, k, v = make_inputs(b, h, l, d) + out = chunk_o(q, k, v, chunk_size=c) + ref = ref_chunk_o(q, k, v, chunk_size=c) + + torch.npu.synchronize() + torch.testing.assert_close(out.cpu(), ref.cpu(), rtol=RTOL, atol=pick_atol(l)) + + +@pytest.mark.parametrize( + ("b", "h", "l", "d", "c"), + [ + (1, 2, 64, 128, 64), + (1, 2, 300, 128, 64), + (4, 4, 257, 128, 128), + ], +) +def test_chunk_o_onthefly_mask_matches_reference( + b: int, h: int, l: int, d: int, c: int +): + torch.manual_seed(5) + torch.npu.set_device("npu:0") + + q, k, v = make_inputs(b, h, l, d) + out = chunk_o(q, k, v, chunk_size=c, use_cached_mask=False) + ref = ref_chunk_o(q, k, v, chunk_size=c) + + torch.npu.synchronize() + torch.testing.assert_close(out.cpu(), ref.cpu(), rtol=RTOL, atol=pick_atol(l)) + + +@pytest.mark.parametrize( + ("b", "h", "l", "d", "c"), + [ + (1, 2, 65, 64, 64), + (2, 4, 192, 128, 64), + ], +) +def test_chunk_o_final_state_matches_reference( + b: int, h: int, l: int, d: int, c: int +): + torch.manual_seed(1) + torch.npu.set_device("npu:0") + + q, k, v = make_inputs(b, h, l, d) + h0 = torch.randn((b, h, d, d), device="npu", dtype=torch.float32) + + out, final_state = chunk_o( + q, k, v, chunk_size=c, initial_state=h0, output_final_state=True + ) + ref_out, ref_final_state = ref_chunk_o( + q, + k, + v, + chunk_size=c, + initial_state=h0, + output_final_state=True, + ) + + torch.npu.synchronize() + atol = pick_atol(l) + torch.testing.assert_close(out.cpu(), ref_out.cpu(), rtol=RTOL, atol=atol) + torch.testing.assert_close( + final_state.cpu(), ref_final_state.cpu(), rtol=RTOL, atol=atol + ) + + +@pytest.mark.parametrize( + ("b", "h", "l", "d", "c"), + [ + (1, 2, 256, 128, 64), + (4, 4, 257, 128, 128), + ], +) +def 
test_chunk_o_precomputed_h_matches_reference( + b: int, h: int, l: int, d: int, c: int +): + torch.manual_seed(3) + torch.npu.set_device("npu:0") + + q, k, v = make_inputs(b, h, l, d) + precomputed_h = build_chunk_states(k, v, c) + out = chunk_o(q, k, v, chunk_size=c, precomputed_h=precomputed_h) + ref = ref_chunk_o(q, k, v, chunk_size=c) + + torch.npu.synchronize() + torch.testing.assert_close(out.cpu(), ref.cpu(), rtol=RTOL, atol=pick_atol(l)) + + +@pytest.mark.parametrize( + ("b", "h", "l", "d", "c"), + [ + (1, 2, 256, 128, 64), + (4, 4, 257, 128, 128), + ], +) +def test_chunk_o_precomputed_mask_matches_reference( + b: int, h: int, l: int, d: int, c: int +): + torch.manual_seed(4) + torch.npu.set_device("npu:0") + + q, k, v = make_inputs(b, h, l, d) + precomputed_mask = get_causal_mask(c, DTYPE, q.device.index or 0) + out = chunk_o(q, k, v, chunk_size=c, precomputed_mask=precomputed_mask) + ref = ref_chunk_o(q, k, v, chunk_size=c) + + torch.npu.synchronize() + torch.testing.assert_close(out.cpu(), ref.cpu(), rtol=RTOL, atol=pick_atol(l)) + + +@pytest.mark.parametrize("g_mode", ["none", "uniform_zero"]) +@pytest.mark.parametrize("varlen_mode", ["static", "varlen_equiv"]) +@pytest.mark.parametrize( + ("b", "h", "l", "d", "c"), + [ + (1, 2, 256, 128, 64), + (2, 2, 300, 128, 64), + (4, 4, 512, 128, 64), + ], +) +def test_vllm_adapted_chunk_o_matches_reference( + b: int, + h: int, + l: int, + d: int, + c: int, + g_mode: str, + varlen_mode: str, +): + torch.manual_seed(2) + torch.npu.set_device("npu:0") + + q, k, v = make_inputs(b, h, l, d) + out = chunk_o_vllm_adapted( + q, + k, + v, + chunk_size=c, + g_mode=g_mode, + varlen_mode=varlen_mode, + ) + ref = ref_chunk_o(q, k, v, chunk_size=c) + + torch.npu.synchronize() + torch.testing.assert_close(out.cpu(), ref.cpu(), rtol=RTOL, atol=pick_atol(l))