def _resolve_pto_lib_path() -> str:
    """Pick the PTO-ISA header root for bisheng's -isystem flag.

    Resolution order:
      1. ``$PTO_LIB_PATH`` (explicit override; an unset or empty value is
         ignored).
      2. The pto-kernels CMake FetchContent mirror at
         ``<repo_root>/build/_deps/libpto_isa_headers-src`` (pinned in
         top-level CMakeLists.txt; populated by any cmake build). It is
         chosen only when ``include/pto/pto-inst.hpp`` exists beneath it.
      3. ``$ASCEND_TOOLKIT_HOME`` (CANN default; may lack newer
         instructions such as TCOLEXPANDDIV on CANN 8.5.0).

    Returns:
        The selected directory as a string.
    """
    override = os.environ.get("PTO_LIB_PATH")
    if override:
        return override

    # examples/jit_cpp/<example>/jit_util_common.py -> parents[3] is the repo
    # root. If this file sits closer to the filesystem root it has fewer than
    # four ancestors; skip the vendored lookup in that case rather than let an
    # IndexError escape, since this function is called at module import time.
    ancestors = Path(__file__).resolve().parents
    if len(ancestors) > 3:
        vendored = ancestors[3] / "build" / "_deps" / "libpto_isa_headers-src"
        if (vendored / "include" / "pto" / "pto-inst.hpp").is_file():
            return str(vendored)
    return ASCEND_TOOLKIT_HOME
+Implements the DeepSeek MHC pre-processing algorithm: softmax per row, +then alternating row/column normalization until the matrix is approximately +doubly-stochastic. + +## Algorithm + +```python +def sinkhorn_normalize(x, repeat=10, eps=1e-6): + x = x.softmax(-1) + eps + x = x / (x.sum(-2, keepdim=True) + eps) + for _ in range(repeat - 1): + x = x / (x.sum(-1, keepdim=True) + eps) + x = x / (x.sum(-2, keepdim=True) + eps) + return x +``` + +Input shape: `(N, K, K)` fp16 — N square matrices of dimension K. +K must be <= 128 (the full K×K matrix lives in UB as fp32). + +## Files + +| File | Description | +|---|---| +| `kernel_sinkhorn.cpp` | PTO-ISA C++ kernel (JIT-compiled via bisheng) | +| `jit_util_sinkhorn.py` | Python wrapper — compiles & exposes the kernel | +| `conftest.py` | pytest NPU device fixture | +| `test_sinkhorn.py` | Correctness tests vs PyTorch reference | +| `bench_sinkhorn.py` | Benchmark PTO vs PyTorch | +| `plot_sinkhorn.py` | Plot benchmark CSVs | + +## Usage + +```bash +# compile + test +python -m pytest test_sinkhorn.py -v --npu=npu:0 + +# benchmark +python bench_sinkhorn.py + +# re-plot from saved CSV +python plot_sinkhorn.py +``` + +## Reproducing + +```bash +cd examples/jit_cpp/sinkhorn + +# tests (73 cases: shapes × repeats × seeds) +python -m pytest test_sinkhorn.py -v --npu=npu:0 + +# benchmark (batch × K grid, repeat=10, eps=1e-6) +python bench_sinkhorn.py --batches 1 4 8 16 32 64 --hidden-dims 4 8 16 32 64 128 + +# plots +python plot_sinkhorn.py +``` diff --git a/examples/jit_cpp/sinkhorn/bench_sinkhorn.py b/examples/jit_cpp/sinkhorn/bench_sinkhorn.py new file mode 100644 index 00000000..add08327 --- /dev/null +++ b/examples/jit_cpp/sinkhorn/bench_sinkhorn.py @@ -0,0 +1,238 @@ +# pylint: disable=wrong-import-position +""" +Benchmark PTO doubly-stochastic Sinkhorn against PyTorch reference. 
def sinkhorn_ref(x, repeat=10, eps=1e-6):
    """Reference doubly-stochastic Sinkhorn normalization built from torch ops.

    The input is promoted to float32, softmaxed over the last axis and offset
    by ``eps``. One column normalization follows, then ``repeat - 1`` rounds
    of row normalization followed by column normalization. Every division
    uses ``sum + eps`` as its denominator. No device transfer happens here,
    so the ops run wherever ``x`` already lives.

    Args:
        x: Tensor whose last two axes form the matrices to normalize.
        repeat: Total number of column normalizations performed.
        eps: Stabilizer added after the softmax and to every denominator.

    Returns:
        The normalized tensor cast to float16.
    """

    def _scaled(values, axis):
        # Divide by the eps-stabilized sum along `axis`, broadcast back.
        return values / (values.sum(axis, keepdim=True) + eps)

    weights = torch.softmax(x.float(), dim=-1) + eps
    weights = _scaled(weights, -2)
    for _ in range(repeat - 1):
        weights = _scaled(_scaled(weights, -1), -2)
    return weights.to(torch.float16)
def _effective_bandwidth_gbs(batch, K, duration_us):
    """Convert a measured duration into effective memory bandwidth in GB/s.

    Traffic is modeled as one read plus one write of each ``K x K`` matrix in
    the batch, at ``BYTES_PER_ELEMENT`` bytes per element.

    Args:
        batch: Number of matrices processed.
        K: Side length of each square matrix.
        duration_us: Measured duration in microseconds.

    Returns:
        Bandwidth in GB/s (1e9 bytes per second), or ``0.0`` when
        ``duration_us`` is zero or negative.
    """
    if duration_us <= 0:
        return 0.0
    # One read and one write of every element: 2 * K * K per matrix.
    bytes_moved = batch * 2 * K * K * BYTES_PER_ELEMENT
    gigabytes = bytes_moved / 1e9
    seconds = duration_us / 1e6
    return gigabytes / seconds
benchmark_npu_us( + warmup, + repeats, + lambda i: sinkhorn_ref( + pool_item(x_list, i), + repeat=TORCH_REF_REPEAT, + eps=SINKHORN_EPS, + ), + ), + ) + + pto_us = pto_stats["median_us"] + torch_us = torch_stats["median_us"] + pto_bw = _effective_bandwidth_gbs(batch, K, pto_us) + torch_bw = _effective_bandwidth_gbs(batch, K, torch_us) + pto_speedup = torch_us / pto_us if pto_us > 0 else 0.0 + + print( + f"{batch:>6d} {K:>6d}" + f" {pto_us:>10.2f} {torch_us:>10.2f}" + f" {pto_bw:>12.4f} {torch_bw:>14.4f}" + f" {pto_speedup:>11.3f}" + ) + + records.append( + f"{batch},{K},{pto_us:.4f},{torch_us:.4f}," + f"{pto_bw:.6f},{torch_bw:.6f}," + f"{pto_speedup:.4f}," + f"{trials},{pto_stats['mean_us']:.4f},{pto_stats['std_us']:.4f}," + f"{pto_stats['min_us']:.4f},{pto_stats['max_us']:.4f}," + f"{pto_stats['cv_pct']:.4f},{torch_stats['mean_us']:.4f}," + f"{torch_stats['std_us']:.4f}," + f"{torch_stats['min_us']:.4f}," + f"{torch_stats['max_us']:.4f}," + f"{torch_stats['cv_pct']:.4f}" + ) + + csv_path = output_dir / f"sinkhorn_compare_bd{block_dim}.csv" + write_csv_records(csv_path, CSV_HEADER, records) + print(f"\nSaved to {csv_path}") + + +def main(): + args = _parse_args() + validate_benchmark_args(args) + + torch.npu.set_device(args.npu) + base = THIS_DIR + kernel_path = base / "kernel_sinkhorn.cpp" + csv_dir = resolve_dir_arg(base, args.csv_dir) + + print(f"Using device: {args.npu}") + print("Compiling kernel_sinkhorn.cpp ...") + sinq_func = jit_compile(str(kernel_path), verbose=True, device=args.npu) + stream_ptr = get_current_stream_ptr() if args.cache_stream else None + if stream_ptr is not None: + print("Using cached NPU stream pointer for PTO launches.") + + # Default: mHC use case (hc_mult=4, varying num_tokens). + # In DeepSeek MHC, sinkhorn always runs on (num_tokens, 4, 4) matrices. + # Pass --hidden-dims to benchmark other K values (general fallback path). 
+ batches = ( + args.batches + if args.batches + else [1, 4, 16, 64, 256, 512, 1024, 2048, 4096, 8192, 16384, 65536] + ) + dims = args.hidden_dims if args.hidden_dims else [4, 8, 16, 32, 64, 128] + + benchmark( + sinq_func, + warmup=args.warmup, + repeats=args.repeats, + trials=args.trials, + output_dir=csv_dir, + device=args.npu, + batches=batches, + hidden_dims=dims, + stream_ptr=stream_ptr, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/sinkhorn/conftest.py b/examples/jit_cpp/sinkhorn/conftest.py new file mode 100644 index 00000000..34e7fa81 --- /dev/null +++ b/examples/jit_cpp/sinkhorn/conftest.py @@ -0,0 +1,40 @@ +import pytest +import torch + + +def normalize_npu_device(device: str | int) -> str: + text = str(device).strip().strip('"').strip("'") + if text.lower().startswith("npu:"): + index = text.split(":", 1)[1].strip() + else: + index = text + + if not index.isdigit(): + raise ValueError( + f"Invalid NPU device '{device}'. Expected values like 0 or npu:0." 
+ ) + return f"npu:{int(index)}" + + +def pytest_addoption(parser): + try: + parser.addoption( + "--npu", + action="store", + default="npu:0", + help="NPU device (examples: 0, npu:0, '0', 'npu:0').", + ) + except ValueError as exc: + if "--npu" not in str(exc): + raise + + +@pytest.fixture(scope="session") +def npu_device(request): + raw = request.config.getoption("--npu") + return normalize_npu_device(raw) + + +@pytest.fixture(scope="session", autouse=True) +def setup_npu_device(npu_device): + torch.npu.set_device(npu_device) diff --git a/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py b/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py new file mode 100644 index 00000000..cde53c40 --- /dev/null +++ b/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py @@ -0,0 +1,109 @@ +# pylint: disable=wrong-import-position +import ctypes +import sys +from pathlib import Path + +import torch + +THIS_DIR = Path(__file__).resolve().parent +FAST_HADAMARD_DIR = THIS_DIR.parent / "fast_hadamard" +if str(FAST_HADAMARD_DIR) not in sys.path: + sys.path.insert(0, str(FAST_HADAMARD_DIR)) + +from jit_util_common import ( # noqa: E402 + BLOCK_DIM, + DEFAULT_DEVICE, + jit_compile_with_loader, + load_cdll, + load_required_symbol, + resolve_launch_block_dim, + resolve_stream_ptr, + torch_to_ctypes, +) + +MAX_DIM = 128 + +SINKHORN_DS_ARGTYPES = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # input + ctypes.c_void_p, # output + ctypes.c_uint32, # N + ctypes.c_uint32, # K + ctypes.c_uint32, # repeat + ctypes.c_float, # eps +] + + +def _validate(input_tensor, output_tensor, K): + if input_tensor.dim() != 3: + raise ValueError("input must be 3D (N, K, K).") + if input_tensor.shape[1] != K or input_tensor.shape[2] != K: + raise ValueError(f"input must have shape (N, {K}, {K}).") + if output_tensor.shape != input_tensor.shape: + raise ValueError("output must have the same shape as input.") + if input_tensor.dtype != torch.float16: + raise TypeError("input must use 
def load_lib(lib_path, block_dim=BLOCK_DIM):
    """Load the compiled Sinkhorn shared library and wrap its launcher.

    Args:
        lib_path: Path of the shared object, handed to ``load_cdll``.
        block_dim: Default launch block dimension; coerced with ``int`` and
            clamped to at least 1.

    Returns:
        A callable ``sinkhorn_ds_func(input_tensor, output_tensor, *, repeat,
        eps, block_dim, stream_ptr)`` that validates its arguments and then
        invokes the ``call_sinkhorn_ds_kernel`` symbol. The resolved default
        block dimension is exposed as ``sinkhorn_ds_func.block_dim``.
    """
    lib = load_cdll(lib_path)
    resolved_block_dim = max(1, int(block_dim))

    kernel = load_required_symbol(
        lib,
        "call_sinkhorn_ds_kernel",
        SINKHORN_DS_ARGTYPES,
    )

    def sinkhorn_ds_func(
        input_tensor,
        output_tensor,
        *,
        repeat=10,
        eps=1e-6,
        block_dim=resolved_block_dim,
        stream_ptr=None,
    ):
        """Validate the tensors and launch the kernel, writing into output_tensor.

        Raises:
            ValueError: For a non-3D input, a negative ``repeat``, or any
                shape/layout/device problem reported by ``_validate``.
            TypeError: If either tensor is not ``torch.float16``.
        """
        # Check the rank before unpacking the shape; otherwise a 2D or 4D
        # input fails with an opaque tuple-unpacking error and the rank check
        # inside _validate can never fire.
        if input_tensor.dim() != 3:
            raise ValueError("input must be 3D (N, K, K).")
        N, K, _ = input_tensor.shape
        _validate(input_tensor, output_tensor, K)
        # The launcher receives raw pointers, so an output of another dtype
        # with a matching shape would be written with fp16 bytes.
        if output_tensor.dtype != torch.float16:
            raise TypeError("output must use torch.float16.")
        # repeat is passed as c_uint32; ctypes would wrap a negative value
        # into a huge iteration count instead of rejecting it.
        if repeat < 0:
            raise ValueError("repeat must be non-negative.")
        kernel(
            resolve_launch_block_dim(block_dim, resolved_block_dim),
            resolve_stream_ptr(stream_ptr),
            torch_to_ctypes(input_tensor),
            torch_to_ctypes(output_tensor),
            N,
            K,
            repeat,
            float(eps),
        )

    sinkhorn_ds_func.block_dim = resolved_block_dim
    return sinkhorn_ds_func
+ * + * Mirrors DeepSeek TileKernels `sinkhorn_normalize_ref`: + * + * x = x.softmax(-1) + eps + * x = x / (x.sum(-2, keepdim=True) + eps) + * for _ in range(repeat - 1): + * x = x / (x.sum(-1, keepdim=True) + eps) + * x = x / (x.sum(-2, keepdim=True) + eps) + * + * Three code paths, dispatched on K: + * + * K ∈ {4, 8, 16} — `sinkhornFastPath` (TILE_COLS = 16) + * K-templated so every tile dimension is compile-time. + * Uses `TCOLEXPANDDIV` (PTO-ISA 9.0.0 op) on the full + * [K, ROW_BLOCK_COLS] interleaved tile, replacing + * `K+1` TADD-tree ops + `K` TDIVs per iteration with + * exactly two ops (TCOLSUM + TCOLEXPANDDIV). + * + * K ∈ (16, 64] — `sinkhornStridedTree` (TILE_COLS ∈ {16, 32, 64}) + * K-runtime. Interleaved layout but falls back to + * a flat TADD-tree + K×TDIV col-normalize because + * `TCOLEXPANDDIV` on a K-runtime tile requires + * runtime tile widths we don't have. + * + * K ∈ (64, 128] — `sinkhornPerMatrixFp32` + * fp16 I/O with fp32 internal compute (fp16 loses too + * much precision at K=128). Per-matrix, no batching. + * + * Parallelism model (all paths): + * The N matrices are sharded across AIV cores (`num_workers` total, + * = get_block_num() × get_subblockdim()). Each worker takes an + * equal slice. For fast / strided-tree paths it processes its slice + * in chunks of up to `MAX_BATCH_MATRICES` matrices, using one bulk + * TLOAD + TSTORE per chunk. Within each chunk, matrices are further + * divided into groups of up to `MAX_GROUP_SIZE`; each group is + * softmaxed + sinkhorn-iterated as one batched unit. 
+ * + * UB layout (fast / strided-tree paths): + * SCRATCH_UB — reduction scratch (used as `tmp` by TROW/TCOLSUM, + * and as row-block scratch in col-normalize) + * WORK_UB — primary matrix data, in interleaved layout + * ROW_STATS_UB — per-row reduction output (K×1 col-vec) + * COL_STATS_UB — per-column reduction output (1×BLK row-vec) + * BATCH_UB — bulk-loaded matrices before interleave; written back + * after all groups in the chunk are processed. + */ + +#include + +#ifndef GM_ADDR +#define GM_ADDR __gm__ uint8_t * +#endif + +using namespace pto; + +// ========================================================================== +// Compile-time constants +// ========================================================================== +constexpr uint32_t UB_BYTES = 192 * 1024; // per-AIV unified buffer size +constexpr uint32_t MAX_K = 128; // max matrix dim we support +constexpr uint32_t STACK_ROWS = + 512; // tall-tile row count for fast / strided-tree paths +constexpr uint32_t MAX_MATS_PER_GROUP_CAP = + 128; // upper bound on mats per group (UB footprint) + +// 32-byte align helper (fp16 PTO tiles require 32-byte-aligned row bytes). +#define ALIGN_32(x) (((x) + 31u) & ~31u) + +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + +// ========================================================================== +// Tile type aliases +// ========================================================================== +// 1-D row vector over N elements (static). Used for flat elementwise ops. +template +using FlatVec = Tile; + +// 2-D row-major tile. Row stride = static Cols; valid shape is runtime. +template +using Tile2D = + Tile; + +// 2-D col-major R×1 vector — used as output of per-row reductions +// (TROWMAX / TROWSUM give one scalar per row). 
// ==========================================================================
// Pipeline-flag helpers
// ==========================================================================
// Each AIV has three pipelines (MTE2 / V / MTE3) that run in parallel.
// Cross-pipe ordering uses set_flag / wait_flag pairs keyed by EVENT_ID.
// `initPipelineFlags` primes the flags so the first wait_flag below
// succeeds immediately (lets us always-wait-then-set in a ring).

// Sets the V->MTE2 and MTE3->V flags on EVENT_ID0. Only EVENT_ID0 is touched
// here; the double-buffered fast path in this file primes EVENT_ID0 and
// EVENT_ID1 inline instead of calling this helper.
// NOTE(review): presumably called once at kernel entry by the single-buffered
// paths; the callers are not visible in this part of the file, so confirm.
AICORE inline void initPipelineFlags() {
  set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
  set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
}

// Waits on the same two EVENT_ID0 flags that initPipelineFlags sets, pairing
// each outstanding set_flag with a wait_flag.
// NOTE(review): presumably the kernel-exit counterpart of initPipelineFlags,
// so that no flag is left set; confirm against the callers.
AICORE inline void drainPipelineFlags() {
  wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
  wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
}
+// +// Parameters mirror the builtin: +// nBurst — number of bursts (= number of matrices copied) +// lenBurst — bytes per burst, expressed in 32B blocks +// srcGap — gap between source bursts, in 32B blocks +// dstGap — gap between destination bursts, in 32B blocks +namespace pto { +template +__tf__ AICORE inline void stridedUBCopyImpl( + typename TileDescriptor::TileDType __out__ dstTile, + typename TileDescriptor::TileDType __in__ srcTile, uint16_t nBurst, + uint16_t lenBurst, uint16_t srcGap, uint16_t dstGap) { + __ubuf__ void *dst = (__ubuf__ void *)__cce_get_tile_ptr(dstTile); + __ubuf__ void *src = (__ubuf__ void *)__cce_get_tile_ptr(srcTile); + __builtin_cce_copy_ubuf_to_ubuf(dst, src, (uint8_t)0, nBurst, lenBurst, + srcGap, dstGap); +} +} // namespace pto + +template +AICORE inline void stridedUBCopy(TileT &dst, TileT &src, uint16_t nBurst, + uint16_t lenBurst, uint16_t srcGap, + uint16_t dstGap) { + pto::stridedUBCopyImpl(dst.data(), src.data(), nBurst, lenBurst, + srcGap, dstGap); +} + +// ========================================================================== +// Fast path: K ∈ {4, 8, 16} — TCOLEXPANDDIV on full interleaved tile +// ========================================================================== +// +// Every tile dimension is compile-time, which lets us: +// (a) keep the tile row-stride static (it equals `ROW_BLOCK_COLS`); +// (b) run TCOLSUM + TCOLEXPANDDIV on the full [K, ROW_BLOCK_COLS] tile, +// replacing the K+1-op TADD tree + K-op TDIV sequence. +// +// Two views of the same physical UB memory: +// +// Tall view shape (STACK_ROWS, TILE_COLS ) +// row stride = TILE_COLS +// used for per-row ops (softmax, row-normalize) +// matrix i's row r lives at tall-row ((i / GS) * K + i % GS + +// r * GS) — i.e. per-matrix rows are naturally discoverable +// +// Interleaved shape (K, ROW_BLOCK_COLS ) +// view row stride = ROW_BLOCK_COLS = MAX_GROUP_SIZE * TILE_COLS +// used for per-col ops (col-normalize). 
Each "column" +// of the interleaved view corresponds to one matrix's +// one column, so TCOLSUM gives per-matrix-per-col sums +// directly, and TCOLEXPANDDIV normalizes correctly. +// +// Both views work simultaneously because the tall stride (TILE_COLS) times +// the group size (MAX_GROUP_SIZE) equals the interleaved stride +// (ROW_BLOCK_COLS). We always process a full group of MAX_GROUP_SIZE +// matrices logically: partial groups are zero-padded. Zero padding makes +// softmax produce 1/K in pad cells, which is a benign constant and doesn't +// leak into valid-matrix outputs (each matrix's row/col-normalize is local +// to that matrix's slice of the interleaved tile). +template +AICORE void sinkhornFastPath(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N, + float eps) { + // ---- compile-time constants derived from template parameters ---- + constexpr unsigned K = K_TEMPLATE; + constexpr unsigned TILE_COLS = TILE_COLS_TEMPLATE; + constexpr unsigned TALL_ROWS = STACK_ROWS_OVERRIDE; + constexpr unsigned MAX_GROUP_SIZE = TALL_ROWS / K; // matrices per group + constexpr unsigned ROW_BLOCK_COLS = + MAX_GROUP_SIZE * TILE_COLS; // width of interleaved rows + static_assert(K * ROW_BLOCK_COLS == TALL_ROWS * TILE_COLS, + "Interleaved and tall views must cover the same UB region"); + + constexpr unsigned MATRIX_ROW_BYTES = TILE_COLS * sizeof(half); + constexpr unsigned TILE_BYTES = TALL_ROWS * TILE_COLS * sizeof(half); + + // ---- UB layout ---- + // Fixed regions first (compute scratch + per-axis stats), then the + // double-buffered batch region at the top of UB. The two batch halves + // (BATCH_UB_PING / BATCH_UB_PONG) alternate on consecutive chunks so + // that MTE2 TLOAD of the next chunk overlaps PIPE_V compute on the + // current chunk, and MTE3 TSTORE of the previous chunk overlaps both. 
+ constexpr unsigned SCRATCH_UB = 0; // reduction scratch + constexpr unsigned WORK_UB = + ALIGN_32(SCRATCH_UB + TILE_BYTES); // interleaved matrix data + constexpr unsigned ROW_STATS_UB = + ALIGN_32(WORK_UB + TILE_BYTES); // per-row reduction output + constexpr unsigned COL_STATS_UB = + ALIGN_32(ROW_STATS_UB + ALIGN_32(TALL_ROWS * sizeof(half))); + constexpr unsigned BATCH_UB_BASE = + ALIGN_32(COL_STATS_UB + ALIGN_32(ROW_BLOCK_COLS * sizeof(half))); + + // Split remaining UB in half for ping/pong. Round down to a multiple + // of the row size and cap at the hardware burst-count limit. + constexpr unsigned BATCH_HALF_BUDGET = (UB_BYTES - BATCH_UB_BASE) / 2; + constexpr unsigned BATCH_HALF_ROWS_RAW = + BATCH_HALF_BUDGET / (TILE_COLS * sizeof(half)); + constexpr unsigned MAX_BATCH_ROWS = + BATCH_HALF_ROWS_RAW < 4095 ? BATCH_HALF_ROWS_RAW : 4095; + constexpr unsigned BATCH_HALF_BYTES = + MAX_BATCH_ROWS * TILE_COLS * sizeof(half); + constexpr unsigned BATCH_UB_PING = BATCH_UB_BASE; + constexpr unsigned BATCH_UB_PONG = BATCH_UB_BASE + ALIGN_32(BATCH_HALF_BYTES); + static_assert(BATCH_UB_PONG + BATCH_HALF_BYTES <= UB_BYTES, + "Double-buffered BATCH_UB exceeds UB capacity"); + + // Hardware setup. + set_mask_norm(); + set_vector_mask(-1, -1); + + // Per-worker sharding. N matrices evenly split across all AIV cores; + // the first `remainder` cores take one extra matrix. + const uint32_t num_workers = get_block_num() * get_subblockdim(); + const uint32_t worker_id = + get_block_idx() * get_subblockdim() + get_subblockid(); + const uint32_t base_per_worker = N / num_workers; + const uint32_t remainder = N % num_workers; + const uint32_t my_first = worker_id * base_per_worker + + (worker_id < remainder ? worker_id : remainder); + const uint32_t my_count = base_per_worker + (worker_id < remainder ? 1 : 0); + if (my_count == 0) return; + + // Loop constants. 
+ constexpr uint32_t K_SQUARED = K * K; + constexpr uint32_t GROUP_SIZE_STATIC = MAX_GROUP_SIZE; + constexpr uint32_t CHUNK_MATRICES = + MAX_BATCH_ROWS / K; // mats per TLOAD chunk + const half eps_h = (half)eps; + + // Prime all four cross-pipe flags (two halves × two directions). + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + + // ======================================================================== + // Outer loop: process my matrices in chunks of up to CHUNK_MATRICES. + // Chunks alternate between the two BATCH_UB halves so DMA overlaps + // PIPE_V compute on the other half. The padding columns of each half + // are zeroed the first time that half is used (lazy init — small-batch + // kernels that only touch PING never pay for PONG's zero-fill). + // ======================================================================== + bool half_zeroed[2] = {false, false}; + uint32_t ping = 1; // 1 → PING half + EVENT_ID0; 0 → PONG half + EVENT_ID1 + for (uint32_t chunk_offset = 0; chunk_offset < my_count; + chunk_offset += CHUNK_MATRICES, ping = 1 - ping) { + const uint32_t chunk_matrices = + min(CHUNK_MATRICES, my_count - chunk_offset); + const uint32_t chunk_rows = chunk_matrices * K; + __gm__ T *chunk_gm_in = + gm_in + (size_t)(my_first + chunk_offset) * K_SQUARED; + __gm__ T *chunk_gm_out = + gm_out + (size_t)(my_first + chunk_offset) * K_SQUARED; + + const unsigned batch_ub = ping ? BATCH_UB_PING : BATCH_UB_PONG; + const event_t ev = ping ? (event_t)EVENT_ID0 : (event_t)EVENT_ID1; + + // Bulk TLOAD from GM → this half of BATCH_UB (natural order). 
+ Tile2D batch_tile(chunk_rows, K); + TASSIGN(batch_tile, batch_ub); + + GmShape2D gm_shape(chunk_rows, K); + GmDenseStride gm_stride(K); + GmTensor gm_in_tensor(chunk_gm_in, gm_shape, gm_stride); + + wait_flag(PIPE_V, PIPE_MTE2, + ev); // wait for PIPE_V to finish with this half + + // Lazy-zero: on this half's first use, zero the TILE_COLS - K padding cols + // so subsequent ops don't read uninitialized data. TLOAD below writes + // only K cols per row; the padding stays at zero for the kernel lifetime. + if (!half_zeroed[ping]) { + FlatVec zero_flat( + 1, MAX_BATCH_ROWS * TILE_COLS); + TASSIGN(zero_flat, batch_ub); + TEXPANDS(zero_flat, (T)0); + pipe_barrier(PIPE_V); + half_zeroed[ping] = true; + } + + TLOAD(batch_tile, gm_in_tensor); + set_flag(PIPE_MTE2, PIPE_V, ev); + wait_flag(PIPE_MTE2, PIPE_V, ev); + wait_flag(PIPE_MTE3, PIPE_V, + ev); // wait for previous TSTORE on this half to drain + + // ====================================================================== + // Inner loop: process the chunk in groups of up to MAX_GROUP_SIZE. + // ====================================================================== + for (uint32_t group_start = 0; group_start < chunk_matrices; + group_start += GROUP_SIZE_STATIC) { + const uint32_t group_size = + min(GROUP_SIZE_STATIC, chunk_matrices - group_start); + const unsigned group_batch_offset = + batch_ub + group_start * K * TILE_COLS * sizeof(T); + + // stridedUBCopy parameters for natural-order → interleaved transpose. + constexpr uint16_t tile_row_blocks = TILE_COLS * sizeof(half) / 32; + const uint16_t src_gap_blocks = (uint16_t)(K - 1) * tile_row_blocks; + + // --- Zero WORK_UB (pads invalid matrices in the group to 0) ----- + // Only needed for partial groups (last group may have fewer than + // GROUP_SIZE_STATIC matrices). Full groups fully overwrite WORK_UB + // in the interleave step below, so the zero is redundant. 
+ if (group_size < GROUP_SIZE_STATIC) { + FlatVec work_flat(1, K * ROW_BLOCK_COLS); + TASSIGN(work_flat, WORK_UB); + TEXPANDS(work_flat, (T)0); + pipe_barrier(PIPE_V); + } + + // --- Interleave: BATCH_UB → WORK_UB ------------------------------ + // After this, WORK_UB's row-block r (offset r*ROW_BLOCK_COLS) holds: + // [matrix 0 row r][matrix 1 row r]...[matrix (group_size-1) row r] + // followed by zero padding up to ROW_BLOCK_COLS. + for (uint32_t row = 0; row < K; ++row) { + Tile2D src_view(group_size, K); + Tile2D dst_view(group_size, K); + TASSIGN(src_view, group_batch_offset + row * MATRIX_ROW_BYTES); + TASSIGN(dst_view, + WORK_UB + row * ROW_BLOCK_COLS * (unsigned)sizeof(half)); + stridedUBCopy(dst_view, src_view, (uint16_t)group_size, tile_row_blocks, + src_gap_blocks, (uint16_t)0); + } + pipe_barrier(PIPE_V); + + // --- Sinkhorn computation --------------------------------------- + if constexpr (REPEAT > 0) { + // Tall view over WORK_UB for per-row ops (softmax, row-normalize). + // Row stride = TILE_COLS; we operate on all TALL_ROWS rows (padded + // matrices are zero and produce benign softmax results). 
+ Tile2D tall_matrix(TALL_ROWS, K); + TASSIGN(tall_matrix, WORK_UB); + + Tile2D tall_scratch(TALL_ROWS, K); + TASSIGN(tall_scratch, SCRATCH_UB); + + ColVec row_stats(TALL_ROWS, 1); + TASSIGN(row_stats, ROW_STATS_UB); + + // ── Step 1: softmax along each matrix-row ────────────────────── + // row_stats[i] = max(tall_matrix[i, :]) + // tall_matrix[i, :] = exp(tall_matrix[i, :] - row_stats[i]) + // row_stats[i] = sum(tall_matrix[i, :]) + // tall_matrix[i, :] = tall_matrix[i, :] / row_stats[i] + TROWMAX(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDSUB(tall_matrix, tall_matrix, row_stats); + pipe_barrier(PIPE_V); + + { + FlatVec work_flat(1, + TALL_ROWS * TILE_COLS); + TASSIGN(work_flat, WORK_UB); + TEXP(work_flat, work_flat); + pipe_barrier(PIPE_V); + } + + TROWSUM(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats); + pipe_barrier(PIPE_V); + + // Step 2 (add eps to the matrix) eliminated: after softmax every + // valid cell is strictly positive (exp() > 0, rowsum > 0), and + // zero-padding matrices also produce positive cells (= 1/K), so + // col-normalize never sees a zero denominator. The eps was + // reference-code defensive, not algorithmically required for + // random inputs of this type. + + // ── Step 3 & 4: col-normalize (and iterations) ───────────────── + // Interleaved view: (K, ROW_BLOCK_COLS) with row stride = + // ROW_BLOCK_COLS. TCOLSUM gives one scalar per column — and + // each column here is one matrix's one column, so the scalar + // we get is exactly that matrix's col sum. + Tile2D interleaved_matrix(K, ROW_BLOCK_COLS); + TASSIGN(interleaved_matrix, WORK_UB); + + Tile2D interleaved_scratch(K, ROW_BLOCK_COLS); + TASSIGN(interleaved_scratch, SCRATCH_UB); + + FlatVec col_stats(1, ROW_BLOCK_COLS); + TASSIGN(col_stats, COL_STATS_UB); + +// Fused col-normalize: 1 TCOLSUM + 1 TCOLEXPANDDIV, 2 barriers. 
+#define COL_NORMALIZE() \ + do { \ + TCOLSUM(col_stats, interleaved_matrix, interleaved_scratch, true); \ + pipe_barrier(PIPE_V); \ + TCOLEXPANDDIV(interleaved_matrix, interleaved_matrix, col_stats); \ + pipe_barrier(PIPE_V); \ + } while (0) + + // First col-normalize (no row-normalize — softmax already normalized + // rows). + COL_NORMALIZE(); + +// (REPEAT − 1) × { row-normalize ; col-normalize }. +#pragma unroll + for (uint32_t iter = 1; iter < REPEAT; ++iter) { + TASSIGN(row_stats, ROW_STATS_UB); + + TROWSUM(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats); + pipe_barrier(PIPE_V); + + COL_NORMALIZE(); + } +#undef COL_NORMALIZE + } + + // --- De-interleave: WORK_UB → BATCH_UB -------------------------- + // Inverse of the interleave above; only the first `group_size` + // matrices are copied (zero-padded tail is discarded). + for (uint32_t row = 0; row < K; ++row) { + Tile2D src_view(group_size, K); + Tile2D dst_view(group_size, K); + TASSIGN(src_view, + WORK_UB + row * ROW_BLOCK_COLS * (unsigned)sizeof(half)); + TASSIGN(dst_view, group_batch_offset + row * MATRIX_ROW_BYTES); + stridedUBCopy(dst_view, src_view, (uint16_t)group_size, tile_row_blocks, + (uint16_t)0, src_gap_blocks); + } + pipe_barrier(PIPE_V); + } + + // Bulk TSTORE from this half of BATCH_UB → GM. + GmTensor gm_out_tensor(chunk_gm_out, gm_shape, gm_stride); + set_flag(PIPE_V, PIPE_MTE3, ev); + wait_flag(PIPE_V, PIPE_MTE3, ev); + TSTORE(gm_out_tensor, batch_tile); + set_flag(PIPE_MTE3, PIPE_V, ev); // next use of this half waits on this + set_flag(PIPE_V, PIPE_MTE2, ev); // next TLOAD of this half can now proceed + } + + // Drain all four ping-pong flags before exit. 
+ wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); +} + +// ========================================================================== +// Small-batch path: K ∈ {4, 8, 16}, N < ~2048 — natural-order, no DB +// ========================================================================== +// +// The fast-path's strided-interleaved layout + double-buffer + batched +// `TCOLEXPANDDIV` on the full `(K, ROW_BLOCK_COLS)` interleaved tile +// amortizes beautifully at large batches — but pays ~25us of flag / setup +// overhead per kernel call that's wasted when there are few matrices +// (only one chunk per worker, no overlap benefit). +// +// This path is the simplest possible layout: one matrix per group in a +// (K, K) sub-tile at natural UB row-stride. Per-matrix col-normalize is +// `TCOLSUM + TCOLEXPANDDIV` in a loop over the `gc` valid matrices. No +// double-buffer, no interleave / de-interleave copies — just a TLOAD, an +// inner TMOV from BUF to MAT, the sinkhorn iterations on the tall view, +// and a TSTORE. +// +// At K=4, batch=1 this path clocks ~14us vs ~40us for the fast-path. +// Crossover with the fast-path is around batch=2048. 
// Small-batch Sinkhorn path: matrices are processed in their natural row
// order (no interleave), one chunk at a time, using a single event id and no
// double-buffering.
//
// Parameters:
//   gm_in / gm_out  global-memory input / output buffers; matrix m starts at
//                   element offset m * K * K (see chunk_gm_in / chunk_gm_out).
//   N               number of K×K matrices.
//   eps             cast to half and added to every cell once, right after
//                   the softmax.
//
// Work split: each worker (block × sub-block) owns the contiguous range
// [my_first, my_first + my_count); the first `remainder` workers take one
// extra matrix. A worker with no matrices returns before touching any flag.
//
// NOTE(review): unlike the PyTorch reference, no eps is added to the row or
// column sums before the divisions in this path — confirm this is intended.
template
AICORE void sinkhornSmallBatch(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N,
                               float eps) {
  constexpr unsigned K = K_TEMPLATE;
  constexpr unsigned TILE_COLS = TILE_COLS_TEMPLATE;
  constexpr unsigned TALL_ROWS = STACK_ROWS;
  constexpr unsigned TILE_BYTES = TALL_ROWS * TILE_COLS * sizeof(half);
  constexpr unsigned MATRIX_ROW_BYTES = TILE_COLS * sizeof(half);

  // UB layout: MAT (working tile) | SCRATCH (reduction tmp) | ROW_STATS |
  // BATCH_UB
  constexpr unsigned MAT_UB = 0;
  constexpr unsigned SCRATCH_UB = ALIGN_32(MAT_UB + TILE_BYTES);
  constexpr unsigned ROW_STATS_UB = ALIGN_32(SCRATCH_UB + TILE_BYTES);
  constexpr unsigned BATCH_UB =
      ALIGN_32(ROW_STATS_UB + ALIGN_32(TALL_ROWS * sizeof(half)));
  // BATCH_UB takes whatever is left of UB, measured in padded tile rows.
  constexpr unsigned BATCH_BUF_ROWS_RAW =
      (UB_BYTES - BATCH_UB) / (TILE_COLS * sizeof(half));
  // Row count is capped at 4095.
  // NOTE(review): presumably a per-transfer row/repeat limit — TODO confirm.
  constexpr unsigned MAX_BATCH_ROWS =
      BATCH_BUF_ROWS_RAW < 4095 ? BATCH_BUF_ROWS_RAW : 4095;

  set_mask_norm();
  set_vector_mask(-1, -1);

  // Contiguous partition of the N matrices over all workers.
  const uint32_t num_workers = get_block_num() * get_subblockdim();
  const uint32_t worker_id =
      get_block_idx() * get_subblockdim() + get_subblockid();
  const uint32_t base_per_worker = N / num_workers;
  const uint32_t remainder = N % num_workers;
  const uint32_t my_first = worker_id * base_per_worker +
                            (worker_id < remainder ? worker_id : remainder);
  const uint32_t my_count = base_per_worker + (worker_id < remainder ? 1 : 0);
  if (my_count == 0) return;

  constexpr uint32_t K_SQUARED = K * K;
  constexpr uint32_t MAX_GROUP_SIZE = TALL_ROWS / K;  // matrices per group
  constexpr uint32_t CHUNK_MATRICES = MAX_BATCH_ROWS / K;
  const half eps_h = (half)eps;

  initPipelineFlags();

  // One chunk = as many matrices as fit in BATCH_UB; each chunk is loaded,
  // normalized group by group, then stored.
  for (uint32_t chunk_offset = 0; chunk_offset < my_count;
       chunk_offset += CHUNK_MATRICES) {
    const uint32_t chunk_matrices =
        min(CHUNK_MATRICES, my_count - chunk_offset);
    const uint32_t chunk_rows = chunk_matrices * K;
    __gm__ T *chunk_gm_in =
        gm_in + (size_t)(my_first + chunk_offset) * K_SQUARED;
    __gm__ T *chunk_gm_out =
        gm_out + (size_t)(my_first + chunk_offset) * K_SQUARED;

    // Zero the BATCH_UB region we're about to load (padding cols stay 0).
    {
      FlatVec zero_flat(1, chunk_rows * TILE_COLS);
      TASSIGN(zero_flat, BATCH_UB);
      TEXPANDS(zero_flat, (T)0);
      pipe_barrier(PIPE_V);
    }

    Tile2D batch_tile(chunk_rows, K);
    TASSIGN(batch_tile, BATCH_UB);
    GmShape2D gm_shape(chunk_rows, K);
    GmDenseStride gm_stride(K);
    GmTensor gm_in_tensor(chunk_gm_in, gm_shape, gm_stride);

    // Load the chunk; the V→MTE2 flag waited on here is re-armed after the
    // store at the bottom of this loop body.
    wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
    TLOAD(batch_tile, gm_in_tensor);
    pipe_barrier(PIPE_ALL);
    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);

    // Process the chunk in groups of up to MAX_GROUP_SIZE matrices.
    for (uint32_t group_start = 0; group_start < chunk_matrices;
         group_start += MAX_GROUP_SIZE) {
      const uint32_t group_size =
          min(MAX_GROUP_SIZE, chunk_matrices - group_start);
      const uint32_t group_rows = group_size * K;
      const uint32_t group_cells = group_rows * TILE_COLS;
      const unsigned group_batch_offset =
          BATCH_UB + group_start * K * TILE_COLS * sizeof(T);

      // Copy this group's data from BATCH_UB → MAT_UB (natural order, same
      // stride). The whole MAT_UB tile is zeroed first so a short final group
      // leaves no stale rows behind.
      {
        FlatVec zero_mat(1, TALL_ROWS * TILE_COLS);
        TASSIGN(zero_mat, MAT_UB);
        TEXPANDS(zero_mat, (T)0);
        pipe_barrier(PIPE_V);
      }
      {
        FlatVec src(1, group_cells);
        FlatVec dst(1, group_cells);
        TASSIGN(src, group_batch_offset);
        TASSIGN(dst, MAT_UB);
        TMOV(dst, src);
        pipe_barrier(PIPE_V);
      }

      // "Tall" view: all matrices of the group stacked row-wise.
      Tile2D tall_matrix(group_rows, K);
      TASSIGN(tall_matrix, MAT_UB);
      Tile2D tall_scratch(group_rows, K);
      TASSIGN(tall_scratch, SCRATCH_UB);
      ColVec row_stats(group_rows, 1);
      TASSIGN(row_stats, ROW_STATS_UB);

      // With REPEAT == 0 the group is copied through unchanged.
      if constexpr (REPEAT > 0) {
        // Softmax on tall view (group_rows, K).
        TROWMAX(row_stats, tall_matrix, tall_scratch);
        pipe_barrier(PIPE_V);

        TROWEXPANDSUB(tall_matrix, tall_matrix, row_stats);
        pipe_barrier(PIPE_V);

        {
          FlatVec flat(1, group_cells);
          TASSIGN(flat, MAT_UB);
          TEXP(flat, flat);
          pipe_barrier(PIPE_V);
        }

        TROWSUM(row_stats, tall_matrix, tall_scratch);
        pipe_barrier(PIPE_V);

        TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats);
        pipe_barrier(PIPE_V);

        // Add eps to every cell of the group (flat view, padding included).
        {
          FlatVec flat(1, group_cells);
          TASSIGN(flat, MAT_UB);
          TADDS(flat, flat, eps_h);
          pipe_barrier(PIPE_V);
        }

// Per-matrix column-normalize. With K compile-time we can call
// `TCOLEXPANDDIV` on each K×K sub-tile — 2 ops per matrix, one
// pipe_barrier between.
// The column sums are written to ROW_STATS_UB, i.e. they reuse the row-stats
// storage; row_stats is re-assigned before each later row pass.
#define PER_MATRIX_COL_NORM() \
  do { \
    for (uint32_t mi = 0; mi < group_size; ++mi) { \
      const unsigned mat_off = MAT_UB + mi * K * MATRIX_ROW_BYTES; \
      Tile2D sub_mat(K, K); \
      TASSIGN(sub_mat, mat_off); \
      Tile2D sub_scratch(K, K); \
      TASSIGN(sub_scratch, SCRATCH_UB); \
      FlatVec col_stats(1, K); \
      TASSIGN(col_stats, ROW_STATS_UB); \
      TCOLSUM(col_stats, sub_mat, sub_scratch, false); \
      pipe_barrier(PIPE_V); \
      TCOLEXPANDDIV(sub_mat, sub_mat, col_stats); \
      pipe_barrier(PIPE_V); \
    } \
  } while (0)

        // First column pass (rows were just normalized by the softmax).
        PER_MATRIX_COL_NORM();

        // Remaining (REPEAT − 1) × { row-normalize ; col-normalize }.
#pragma unroll
        for (uint32_t iter = 1; iter < REPEAT; ++iter) {
          TASSIGN(row_stats, ROW_STATS_UB);
          TROWSUM(row_stats, tall_matrix, tall_scratch);
          pipe_barrier(PIPE_V);

          TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats);
          pipe_barrier(PIPE_V);

          PER_MATRIX_COL_NORM();
        }
#undef PER_MATRIX_COL_NORM
      }

      // Copy back MAT_UB → BATCH_UB.
      {
        FlatVec src(1, group_cells);
        FlatVec dst(1, group_cells);
        TASSIGN(src, MAT_UB);
        TASSIGN(dst, group_batch_offset);
        TMOV(dst, src);
        pipe_barrier(PIPE_V);
      }
    }

    // Store the whole chunk, then re-arm the two flags that the next chunk's
    // store and load wait on.
    GmTensor gm_out_tensor(chunk_gm_out, gm_shape, gm_stride);
    wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
    set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    TSTORE(gm_out_tensor, batch_tile);
    pipe_barrier(PIPE_ALL);
    set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
    set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
  }

  drainPipelineFlags();
}

// ==========================================================================
// Strided-tree fallback: K ∈ (16, 64] — K-runtime path
// ==========================================================================
//
// Same interleaved layout as the fast path, but K is runtime so we can't
// form a compile-time [K, ROW_BLOCK_COLS] tile view. Instead col-normalize
// uses a flat TADD-tree (K − 1 adds to compute per-column sums over the
// row-blocks) followed by K TDIV calls (one per row-block).
// Net, per col-normalize (as implemented below, K ≥ 4): (K − 1) TADDs +
// K TDIVs, with (K − 2) TADD barriers + 1 TDIV barrier, versus 2 ops +
// 2 barriers on the fast path.
//
// Parameters:
//   gm_in / gm_out  global-memory input / output buffers; matrix m starts at
//                   element offset m * K * K.
//   N               number of K×K matrices.
//   K               runtime matrix dimension; the function returns without
//                   doing any work when K == 0 or K > TILE_COLS.
//   eps             cast to half and added to every cell once, right after
//                   the softmax.
//
// NOTE(review): no eps is added to the row or column sums before dividing,
// unlike the PyTorch reference — confirm this is intended.
template
AICORE void sinkhornStridedTree(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N,
                                uint32_t K, float eps) {
  constexpr unsigned TILE_COLS = TILE_COLS_TEMPLATE;
  constexpr unsigned TALL_ROWS = STACK_ROWS;
  constexpr unsigned MATRIX_ROW_BYTES = TILE_COLS * sizeof(half);
  constexpr unsigned TILE_BYTES = TALL_ROWS * TILE_COLS * sizeof(half);

  // UB layout — same regions as the fast path, but COL_STATS slots are
  // allocated per-matrix (row-block flat vectors, not a single wide vector).
  constexpr unsigned SCRATCH_UB = 0;
  constexpr unsigned WORK_UB = ALIGN_32(SCRATCH_UB + TILE_BYTES);
  constexpr unsigned ROW_STATS_UB = ALIGN_32(WORK_UB + TILE_BYTES);
  constexpr unsigned COL_STATS_UB =
      ALIGN_32(ROW_STATS_UB + ALIGN_32(TALL_ROWS * sizeof(half)));
  constexpr unsigned BATCH_UB =
      ALIGN_32(COL_STATS_UB +
               MAX_MATS_PER_GROUP_CAP * ALIGN_32(TILE_COLS * sizeof(half)));

  // BATCH_UB takes whatever is left of UB, measured in padded tile rows and
  // capped at 4095 rows.
  // NOTE(review): presumably a per-transfer row/repeat limit — TODO confirm.
  constexpr unsigned BATCH_BUF_ROWS_RAW =
      (UB_BYTES - BATCH_UB) / (TILE_COLS * sizeof(half));
  constexpr unsigned MAX_BATCH_ROWS =
      BATCH_BUF_ROWS_RAW < 4095 ? BATCH_BUF_ROWS_RAW : 4095;
  static_assert(
      BATCH_UB + MAX_BATCH_ROWS * TILE_COLS * sizeof(half) <= UB_BYTES,
      "BATCH_UB exceeds UB capacity");

  set_mask_norm();
  set_vector_mask(-1, -1);

  // Reject dimensions this instantiation cannot hold in one tile row.
  if (K == 0 || K > TILE_COLS) return;

  // Contiguous partition of the N matrices over all workers; the first
  // `remainder` workers take one extra matrix.
  const uint32_t num_workers = get_block_num() * get_subblockdim();
  const uint32_t worker_id =
      get_block_idx() * get_subblockdim() + get_subblockid();
  const uint32_t base_per_worker = N / num_workers;
  const uint32_t remainder = N % num_workers;
  const uint32_t my_first = worker_id * base_per_worker +
                            (worker_id < remainder ? worker_id : remainder);
  const uint32_t my_count = base_per_worker + (worker_id < remainder ? 1 : 0);
  if (my_count == 0) return;

  const uint32_t K_SQUARED = K * K;
  const uint32_t MAX_GROUP_SIZE = TALL_ROWS / K;
  const uint32_t CHUNK_MATRICES = MAX_BATCH_ROWS / K;
  const half eps_h = (half)eps;

  initPipelineFlags();

  // Outer chunk loop.
  for (uint32_t chunk_offset = 0; chunk_offset < my_count;
       chunk_offset += CHUNK_MATRICES) {
    const uint32_t chunk_matrices =
        min(CHUNK_MATRICES, my_count - chunk_offset);
    const uint32_t chunk_rows = chunk_matrices * K;
    __gm__ T *chunk_gm_in =
        gm_in + (size_t)(my_first + chunk_offset) * K_SQUARED;
    __gm__ T *chunk_gm_out =
        gm_out + (size_t)(my_first + chunk_offset) * K_SQUARED;

    // Zero the part of BATCH_UB this chunk will occupy before loading into it.
    {
      FlatVec batch_flat(1, chunk_rows * TILE_COLS);
      TASSIGN(batch_flat, BATCH_UB);
      TEXPANDS(batch_flat, (T)0);
      pipe_barrier(PIPE_V);
    }

    Tile2D batch_tile(chunk_rows, K);
    TASSIGN(batch_tile, BATCH_UB);
    GmShape2D gm_shape(chunk_rows, K);
    GmDenseStride gm_stride(K);
    GmTensor gm_in_tensor(chunk_gm_in, gm_shape, gm_stride);

    // Load the chunk; the V→MTE2 flag waited on here is re-armed after the
    // store at the bottom of this loop body.
    wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
    TLOAD(batch_tile, gm_in_tensor);
    pipe_barrier(PIPE_ALL);
    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);

    for (uint32_t group_start = 0; group_start < chunk_matrices;
         group_start += MAX_GROUP_SIZE) {
      const uint32_t group_size =
          min(MAX_GROUP_SIZE, chunk_matrices - group_start);
      const uint32_t group_tall_rows = group_size * K;
      const uint32_t group_flat_len = group_tall_rows * TILE_COLS;
      const unsigned group_batch_offset =
          BATCH_UB + group_start * K * TILE_COLS * sizeof(T);
      // Length of one interleaved row-block: one padded row per matrix.
      const uint32_t row_block_cols = group_size * TILE_COLS;
      // Copy strides are expressed in 32-byte blocks.
      constexpr uint16_t tile_row_blocks = TILE_COLS * sizeof(half) / 32;
      const uint16_t src_gap_blocks = (uint16_t)(K - 1) * tile_row_blocks;

      // Zero WORK_UB.
      {
        FlatVec work_flat(1, TALL_ROWS * TILE_COLS);
        TASSIGN(work_flat, WORK_UB);
        TEXPANDS(work_flat, (T)0);
        pipe_barrier(PIPE_V);
      }

      // Interleave: BATCH_UB → WORK_UB.
      // For each row index, row `row` of every matrix in the group is
      // gathered into one contiguous row-block of WORK_UB.
      for (uint32_t row = 0; row < K; ++row) {
        Tile2D src_view(group_size, K);
        Tile2D dst_view(group_size, K);
        TASSIGN(src_view, group_batch_offset + row * MATRIX_ROW_BYTES);
        TASSIGN(dst_view,
                WORK_UB + row * row_block_cols * (unsigned)sizeof(half));
        stridedUBCopy(dst_view, src_view, (uint16_t)group_size, tile_row_blocks,
                      src_gap_blocks, (uint16_t)0);
      }
      pipe_barrier(PIPE_V);

      // With REPEAT == 0 the data is only interleaved and de-interleaved.
      if constexpr (REPEAT > 0) {
        Tile2D tall_matrix(group_tall_rows, K);
        TASSIGN(tall_matrix, WORK_UB);
        Tile2D tall_scratch(group_tall_rows, K);
        TASSIGN(tall_scratch, SCRATCH_UB);
        ColVec row_stats(group_tall_rows, 1);
        TASSIGN(row_stats, ROW_STATS_UB);

        // Softmax.
        TROWMAX(row_stats, tall_matrix, tall_scratch);
        pipe_barrier(PIPE_V);

        TROWEXPANDSUB(tall_matrix, tall_matrix, row_stats);
        pipe_barrier(PIPE_V);

        {
          FlatVec work_flat(1, group_flat_len);
          TASSIGN(work_flat, WORK_UB);
          TEXP(work_flat, work_flat);
          pipe_barrier(PIPE_V);
        }

        TROWSUM(row_stats, tall_matrix, tall_scratch);
        pipe_barrier(PIPE_V);

        TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats);
        pipe_barrier(PIPE_V);

        // Add eps to every cell of the group (flat view, padding included).
        {
          FlatVec work_flat(1, group_flat_len);
          TASSIGN(work_flat, WORK_UB);
          TADDS(work_flat, work_flat, eps_h);
          pipe_barrier(PIPE_V);
        }

        // Col-normalize via flat TADD-tree + K×TDIV.
        //
        // Layout recap: WORK_UB's row-block r (offset r·row_block_cols)
        // is a length-row_block_cols vector holding [mat0_row_r | mat1_row_r |
        // ...]. The per-matrix col sum for matrix i's col c sits at index
        // i·TILE_COLS + c. To compute all per-matrix col sums we sum the
        // K row-blocks element-wise, then divide each row-block by the result.
        //
        // The second partial sum is staged in SCRATCH_UB, which the row
        // reductions also use as their temporary.
        //
        // NOTE(review): the first stage always reads row-blocks 0..3, even
        // when K < 4. Whether blocks beyond K are zero and still inside
        // WORK_UB depends on group_size — confirm K ∈ {1, 2, 3} is either
        // safe or never dispatched here.
        constexpr unsigned CN_SCRATCH_UB = SCRATCH_UB;
#define COL_NORMALIZE_STRIDED_TREE() \
  do { \
    /* First two adds in parallel (writes to disjoint dests). */ \
    { \
      FlatVec a(1, row_block_cols), \
          b(1, row_block_cols), c(1, row_block_cols), d(1, row_block_cols); \
      TASSIGN(a, WORK_UB); \
      TASSIGN(b, WORK_UB + row_block_cols * (unsigned)sizeof(half)); \
      TASSIGN(c, COL_STATS_UB); \
      TADD(c, a, b); \
      TASSIGN(a, WORK_UB + 2 * row_block_cols * (unsigned)sizeof(half)); \
      TASSIGN(b, WORK_UB + 3 * row_block_cols * (unsigned)sizeof(half)); \
      TASSIGN(d, CN_SCRATCH_UB); \
      TADD(d, a, b); \
    } \
    pipe_barrier(PIPE_V); \
    /* Combine the two partial sums. */ \
    { \
      FlatVec a(1, row_block_cols), \
          b(1, row_block_cols); \
      TASSIGN(a, COL_STATS_UB); \
      TASSIGN(b, CN_SCRATCH_UB); \
      TADD(a, a, b); \
    } \
    pipe_barrier(PIPE_V); \
    /* Fold in the remaining (K − 4) row-blocks one by one. */ \
    for (uint32_t b_idx = 4; b_idx < K; ++b_idx) { \
      FlatVec src(1, row_block_cols), \
          dst(1, row_block_cols); \
      TASSIGN(src, WORK_UB + b_idx * row_block_cols * (unsigned)sizeof(half)); \
      TASSIGN(dst, COL_STATS_UB); \
      TADD(dst, dst, src); \
      pipe_barrier(PIPE_V); \
    } \
    /* Divide each row-block by the accumulated col sums. */ \
    for (uint32_t r = 0; r < K; ++r) { \
      FlatVec row_block(1, row_block_cols), \
          sums(1, row_block_cols); \
      TASSIGN(row_block, \
              WORK_UB + r * row_block_cols * (unsigned)sizeof(half)); \
      TASSIGN(sums, COL_STATS_UB); \
      TDIV(row_block, row_block, sums); \
    } \
    pipe_barrier(PIPE_V); \
  } while (0)

        // First col-normalize.
        COL_NORMALIZE_STRIDED_TREE();

// (REPEAT − 1) × { row-normalize ; col-normalize }.
#pragma unroll
        for (uint32_t iter = 1; iter < REPEAT; ++iter) {
          TASSIGN(row_stats, ROW_STATS_UB);

          TROWSUM(row_stats, tall_matrix, tall_scratch);
          pipe_barrier(PIPE_V);

          TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats);
          pipe_barrier(PIPE_V);

          COL_NORMALIZE_STRIDED_TREE();
        }
#undef COL_NORMALIZE_STRIDED_TREE
      }

      // De-interleave: WORK_UB → BATCH_UB.
      // Same copy as the interleave with the gap argument moved to the other
      // position, scattering each row-block back to per-matrix rows.
      for (uint32_t row = 0; row < K; ++row) {
        Tile2D src_view(group_size, K);
        Tile2D dst_view(group_size, K);
        TASSIGN(src_view,
                WORK_UB + row * row_block_cols * (unsigned)sizeof(half));
        TASSIGN(dst_view, group_batch_offset + row * MATRIX_ROW_BYTES);
        stridedUBCopy(dst_view, src_view, (uint16_t)group_size, tile_row_blocks,
                      (uint16_t)0, src_gap_blocks);
      }
      pipe_barrier(PIPE_V);
    }

    // Store the whole chunk, then re-arm the two flags that the next chunk's
    // store and load wait on.
    GmTensor gm_out_tensor(chunk_gm_out, gm_shape, gm_stride);
    wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
    set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    TSTORE(gm_out_tensor, batch_tile);
    pipe_barrier(PIPE_ALL);
    set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
    set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
  }

  drainPipelineFlags();
}

// ==========================================================================
// Per-matrix fp32 fallback: K ∈ (64, 128]
// ==========================================================================
//
// fp16 precision is insufficient at K=128 (softmax denominators accumulate
// 128 terms, each with ~3 decimal-digit precision). We load fp16, convert
// to fp32 for all internal compute, convert back to fp16 on store. No
// batching — one matrix per worker at a time.
// fp32 Sinkhorn for a single matrix at a time.
//
// Parameters:
//   gm_in / gm_out  global-memory fp16 input / output; matrix m starts at
//                   element offset m * K * K.
//   N               number of K×K matrices.
//   K               runtime matrix dimension; the function returns without
//                   doing any work when K == 0 or K > TILE_DIM (= MAX_K).
//   eps             added to every cell after the softmax and to every
//                   row / column sum before each division.
//
// Work split: matrices are assigned round-robin, worker w handling
// w, w + num_workers, w + 2·num_workers, ...
template
AICORE void sinkhornPerMatrixFp32(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N,
                                  uint32_t K, float eps) {
  constexpr unsigned TILE_DIM = MAX_K;
  constexpr unsigned F32_ROW_BYTES = TILE_DIM * sizeof(float);
  // UB layout: fp16 IO tile | fp32 working tile | fp32 scratch | fp32 stats.
  constexpr unsigned MATRIX_H_UB = 0;  // fp16 IO buffer
  constexpr unsigned MATRIX_F_UB =
      MATRIX_H_UB + TILE_DIM * TILE_DIM * sizeof(half);
  constexpr unsigned SCRATCH_F_UB =
      MATRIX_F_UB + TILE_DIM * TILE_DIM * sizeof(float);
  constexpr unsigned VECTOR_F_UB =
      SCRATCH_F_UB + TILE_DIM * TILE_DIM * sizeof(float);
  static_assert(VECTOR_F_UB + TILE_DIM * sizeof(float) <= UB_BYTES,
                "fp32 fallback UB layout overflows");

  set_mask_norm();
  set_vector_mask(-1, -1);
  if (K == 0 || K > TILE_DIM) return;

  const uint32_t num_workers = get_block_num() * get_subblockdim();
  const uint32_t worker_id =
      get_block_idx() * get_subblockdim() + get_subblockid();
  const uint32_t K_SQUARED = K * K;
  // K rows at the padded row width TILE_DIM.
  const uint32_t flat_len = K * TILE_DIM;

  initPipelineFlags();

  for (uint32_t matrix_idx = worker_id; matrix_idx < N;
       matrix_idx += num_workers) {
    __gm__ T *matrix_gm_in = gm_in + (size_t)matrix_idx * K_SQUARED;
    __gm__ T *matrix_gm_out = gm_out + (size_t)matrix_idx * K_SQUARED;

    // Zero + TLOAD fp16 matrix.
    {
      FlatVec zero_flat(1, flat_len);
      TASSIGN(zero_flat, MATRIX_H_UB);
      TEXPANDS(zero_flat, (T)0);
      pipe_barrier(PIPE_V);
    }

    Tile2D matrix_h(K, K);
    TASSIGN(matrix_h, MATRIX_H_UB);
    GmShape2D gm_shape(K, K);
    GmDenseStride gm_stride(K);
    GmTensor gm_in_tensor(matrix_gm_in, gm_shape, gm_stride);

    // The V→MTE2 flag waited on here is re-armed after the store at the
    // bottom of this loop body.
    wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
    TLOAD(matrix_h, gm_in_tensor);
    pipe_barrier(PIPE_ALL);
    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);

    // Upcast fp16 → fp32.
    {
      FlatVec h_flat(1, flat_len);
      FlatVec f_flat(1, flat_len);
      TASSIGN(h_flat, MATRIX_H_UB);
      TASSIGN(f_flat, MATRIX_F_UB);
      TCVT(f_flat, h_flat, RoundMode::CAST_NONE);
      pipe_barrier(PIPE_V);
    }

    Tile2D matrix(K, K);
    TASSIGN(matrix, MATRIX_F_UB);
    Tile2D scratch(K, K);
    TASSIGN(scratch, SCRATCH_F_UB);
    ColVec row_stats(K, 1);
    TASSIGN(row_stats, VECTOR_F_UB);

    // Softmax.
    TROWMAX(row_stats, matrix, scratch);
    pipe_barrier(PIPE_V);

    TROWEXPANDSUB(matrix, matrix, row_stats);
    pipe_barrier(PIPE_V);

    {
      FlatVec mat_flat(1, flat_len);
      TASSIGN(mat_flat, MATRIX_F_UB);
      TEXP(mat_flat, mat_flat);
      pipe_barrier(PIPE_V);
    }

    TROWSUM(row_stats, matrix, scratch);
    pipe_barrier(PIPE_V);

    TROWEXPANDDIV(matrix, matrix, row_stats);
    pipe_barrier(PIPE_V);

    // softmax(x) + eps, applied over the flat view.
    {
      FlatVec mat_flat(1, flat_len);
      TASSIGN(mat_flat, MATRIX_F_UB);
      TADDS(mat_flat, mat_flat, eps);
      pipe_barrier(PIPE_V);
    }

    // First col-normalize.
    // col_stats shares VECTOR_F_UB with row_stats; each is fully consumed
    // before the other is written.
    {
      FlatVec col_stats(1, K);
      TASSIGN(col_stats, VECTOR_F_UB);

      TCOLSUM(col_stats, matrix, scratch, false);
      pipe_barrier(PIPE_V);

      TADDS(col_stats, col_stats, eps);
      pipe_barrier(PIPE_V);

      TCOLEXPANDDIV(matrix, matrix, col_stats);
      pipe_barrier(PIPE_V);
    }

// (REPEAT − 1) × { row-normalize ; col-normalize }.
#pragma unroll
    for (uint32_t iter = 1; iter < REPEAT; ++iter) {
      TASSIGN(row_stats, VECTOR_F_UB);
      TROWSUM(row_stats, matrix, scratch);
      pipe_barrier(PIPE_V);

      TADDS(row_stats, row_stats, eps);
      pipe_barrier(PIPE_V);

      TROWEXPANDDIV(matrix, matrix, row_stats);
      pipe_barrier(PIPE_V);

      {
        FlatVec col_stats(1, K);
        TASSIGN(col_stats, VECTOR_F_UB);

        TCOLSUM(col_stats, matrix, scratch, false);
        pipe_barrier(PIPE_V);

        TADDS(col_stats, col_stats, eps);
        pipe_barrier(PIPE_V);

        TCOLEXPANDDIV(matrix, matrix, col_stats);
        pipe_barrier(PIPE_V);
      }
    }

    // Downcast fp32 → fp16 and store.
    {
      FlatVec h_flat(1, flat_len);
      FlatVec f_flat(1, flat_len);
      TASSIGN(h_flat, MATRIX_H_UB);
      TASSIGN(f_flat, MATRIX_F_UB);
      TCVT(h_flat, f_flat, RoundMode::CAST_RINT);
      pipe_barrier(PIPE_V);
    }

    // Store, then re-arm the two flags that the next matrix's store and load
    // wait on.
    GmTensor gm_out_tensor(matrix_gm_out, gm_shape, gm_stride);
    wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
    set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    TSTORE(gm_out_tensor, matrix_h);
    pipe_barrier(PIPE_ALL);
    set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
    set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
  }

  drainPipelineFlags();
}

// ==========================================================================
// Dispatch
// ==========================================================================
// Selects the kernel path from the runtime (K, N) pair. The branches are
// tested top to bottom, so the exact-K cases must stay ahead of the K-range
// fallbacks. A K above MAX_K matches no branch and the call does nothing.
template
AICORE void dispatchByK(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N,
                        uint32_t K, float eps) {
  // K ∈ {4, 8, 16}: split by batch size. At small batches the fast-path's
  // double-buffer / interleave overhead dominates — use the simple
  // per-matrix natural-order path instead. Empirical crossovers scale
  // ~inversely with K (smaller K packs more matrices per group, so the
  // interleave-layout benefit kicks in at higher batch counts). Thresholds
  // were measured via msprof + direct Event timing on 910B.
  if (K == 4 && N >= 2048)
    sinkhornFastPath(gm_in, gm_out, N, eps);
  else if (K == 8 && N >= 1024)
    sinkhornFastPath(gm_in, gm_out, N, eps);
  else if (K == 16 && N >= 512)
    sinkhornFastPath(gm_in, gm_out, N, eps);
  else if (K == 4)
    sinkhornSmallBatch(gm_in, gm_out, N, eps);
  else if (K == 8)
    sinkhornSmallBatch(gm_in, gm_out, N, eps);
  else if (K == 16)
    sinkhornSmallBatch(gm_in, gm_out, N, eps);
  // For K=32/64, the strided-tree path is already competitive at large
  // batches; smallBatch only wins at very small N where the per-matrix
  // loop has few iterations. Thresholds follow the same ~8192/K scaling.
  else if (K == 32 && N < 256)
    sinkhornSmallBatch(gm_in, gm_out, N, eps);
  else if (K == 64 && N < 128)
    sinkhornSmallBatch(gm_in, gm_out, N, eps);
  // Other K values (odd K, K > 16 && K < 32, etc.) fall through to
  // strided-tree.
  else if (K > 0 && K <= 16)
    sinkhornStridedTree(gm_in, gm_out, N, K, eps);
  else if (K <= 32)
    sinkhornStridedTree(gm_in, gm_out, N, K, eps);
  else if (K <= 64)
    sinkhornStridedTree(gm_in, gm_out, N, K, eps);
  else if (K <= MAX_K)
    sinkhornPerMatrixFp32(gm_in, gm_out, N, K, eps);
}

// Specialize on `repeat` so that the per-iteration unroll constant is known.
// NOTE(review): the template arguments of each case are not visible in this
// view; presumably every listed value instantiates dispatchByK with the
// matching REPEAT. Confirm which REPEAT the `default` case uses, since any
// `repeat` not listed (2, 4, 6, ...) ends up there.
template
AICORE void dispatchByRepeat(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N,
                             uint32_t K, uint32_t repeat, float eps) {
  switch (repeat) {
    case 0:
      dispatchByK(gm_in, gm_out, N, K, eps);
      break;
    case 1:
      dispatchByK(gm_in, gm_out, N, K, eps);
      break;
    case 3:
      dispatchByK(gm_in, gm_out, N, K, eps);
      break;
    case 5:
      dispatchByK(gm_in, gm_out, N, K, eps);
      break;
    case 8:
      dispatchByK(gm_in, gm_out, N, K, eps);
      break;
    case 10:
      dispatchByK(gm_in, gm_out, N, K, eps);
      break;
    case 20:
      dispatchByK(gm_in, gm_out, N, K, eps);
      break;
    default:
      dispatchByK(gm_in, gm_out, N, K, eps);
      break;
  }
}
#endif  // __CCE_AICORE__ == 220 && __DAV_C220_VEC__

// ==========================================================================
// C ABI
// ==========================================================================
// Kernel entry point. On the 220 vector-core build it forwards to
// dispatchByRepeat with the buffers reinterpreted as half; on every other
// build the body only marks its arguments as unused.
extern "C" __global__ AICORE void sinkhorn_ds_fp16(GM_ADDR input,
                                                   GM_ADDR output, uint32_t N,
                                                   uint32_t K, uint32_t repeat,
                                                   float eps) {
#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__)
  dispatchByRepeat((__gm__ half *)input, (__gm__ half *)output, N, K,
                   repeat, eps);
#else
  (void)input;
  (void)output;
  (void)N;
  (void)K;
  (void)repeat;
  (void)eps;
#endif
}

// Host-side launch. Ascend 910B runs 2 AIV cores per cube core.
// Host-callable wrapper that launches sinkhorn_ds_fp16.
//
// Parameters:
//   cube_core_num   core count supplied by the caller.
//   stream          opaque stream handle supplied by the caller.
//   input / output  device buffers forwarded unchanged to the kernel.
//   N, K, repeat, eps  forwarded unchanged to the kernel.
//
// NOTE(review): the launch configuration between the chevrons is not visible
// in this view; presumably it is built from cube_core_num and stream —
// confirm against the original file.
extern "C" void call_sinkhorn_ds_kernel(uint32_t cube_core_num, void *stream,
                                        uint8_t *input, uint8_t *output,
                                        uint32_t N, uint32_t K, uint32_t repeat,
                                        float eps) {
  sinkhorn_ds_fp16<<>>(input, output, N, K,
                       repeat, eps);
}
def _parse_args():
    """Build the CLI parser and return the parsed plotting options."""
    cli = argparse.ArgumentParser(
        description="Plot Sinkhorn benchmark comparison from CSV files."
    )
    return add_common_plot_args(cli).parse_args()


def _make_per_batch_line_plot(
    rows, block_dim, output_path, series, y_label, title, log_y=False
):
    """Draw one panel per batch size with K on a log2 x-axis and save it.

    ``series`` is a sequence of ``(csv_key, legend_label, color, line_style)``
    tuples; every entry becomes one curve in each panel. Panels are laid out
    on a grid of at most four columns and unused grid cells are hidden.
    """
    import matplotlib

    # Select the non-interactive backend before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib import ticker
    from plot_common import (
        group_by_batch,
        normalize_axes,
        save_figure,
        format_log2_ticks,
    )

    batch_sizes = sorted({int(record["batch"]) for record in rows})
    metric_keys = [metric for metric, _, _, _ in series]
    per_batch = group_by_batch(rows, metric_keys)

    n_panels = len(batch_sizes)
    ncols = min(n_panels, 4)
    nrows = (n_panels + ncols - 1) // ncols
    fig, raw_axes = plt.subplots(
        nrows, ncols, figsize=(2.8 * ncols, 2.8 * nrows)
    )
    panels = normalize_axes(raw_axes)

    for idx in range(min(n_panels, len(panels))):
        batch = batch_sizes[idx]
        ax = panels[idx]
        k_values = sorted(per_batch[batch].keys())
        for metric, label, color, style in series:
            y_values = [per_batch[batch][k][metric] for k in k_values]
            ax.plot(
                k_values,
                y_values,
                style,
                color=color,
                label=label,
                linewidth=2,
                markersize=5,
            )
        ax.set_xscale("log", base=2)
        if log_y:
            ax.set_yscale("log")
        ax.set_title(f"batch = {batch}", fontsize=11, fontweight="bold")
        ax.set_xlabel("K")
        ax.set_ylabel(y_label)
        ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_log2_ticks))
        ax.grid(True, alpha=0.3)
        # A single legend, on the first panel, is enough for the whole figure.
        if idx == 0:
            ax.legend(fontsize=8)

    # Hide the grid cells that have no batch assigned to them.
    for spare_idx in range(n_panels, len(panels)):
        panels[spare_idx].set_visible(False)

    fig.suptitle(f"{title} (BLOCK_DIM={block_dim})", fontsize=14, fontweight="bold")
    fig.tight_layout()
    save_figure(fig, output_path)
    plt.close(fig)


def plot_sinkhorn(csv_path: Path, plot_dir: Path):
    """Render the duration, bandwidth and speedup plots for one benchmark CSV.

    Returns without plotting when matplotlib is unavailable or the CSV yields
    no rows.
    """
    if not ensure_matplotlib():
        return

    rows = load_nonempty_rows(csv_path)
    if rows is None:
        return

    block_dim = block_dim_from_path(csv_path, CSV_PREFIX)
    ensure_plot_dir(plot_dir)

    # Duration uses a log-scaled y-axis, bandwidth a linear one.
    line_plots = (
        (DURATION_LINE_PLOT, True),
        (BANDWIDTH_LINE_PLOT, False),
    )
    for spec, use_log_y in line_plots:
        target = plot_dir / spec["filename"].format(block_dim=block_dim)
        _make_per_batch_line_plot(
            rows,
            block_dim,
            target,
            spec["series"],
            spec["y_label"],
            spec["title"],
            log_y=use_log_y,
        )

    for heatmap in HEATMAPS:
        target = plot_dir / heatmap["filename"].format(block_dim=block_dim)
        bar_label = heatmap.get("colorbar_label", "log2(speedup)")
        make_speedup_heatmap(
            rows,
            block_dim,
            target,
            heatmap["key"],
            heatmap["title"],
            colorbar_label=bar_label,
        )

    print(f"Plotted {csv_path.name}")


def main():
    """Resolve the CSV / plot directories and plot every matching CSV file."""
    args = _parse_args()
    csv_dir = resolve_dir_arg(THIS_DIR, args.csv_dir)
    plot_dir = resolve_dir_arg(THIS_DIR, args.plot_dir)

    plot_csv_collection(
        csv_dir,
        plot_dir,
        pattern=CSV_PATTERN,
        prefix=CSV_PREFIX,
        warning="no Sinkhorn benchmark CSV files found",
        plot_csv_fn=plot_sinkhorn,
    )


if __name__ == "__main__":
    main()
def sinkhorn_ref(x: torch.Tensor, repeat: int = 10, eps: float = 1e-6) -> torch.Tensor:
    """Reference Sinkhorn normalization, computed in fp32 and returned as fp16.

    Applies a softmax over the last dimension, adds ``eps`` to every cell,
    normalizes over dim -2 once, then runs ``repeat - 1`` rounds of
    normalization over dim -1 followed by dim -2. ``eps`` is added to every
    sum before dividing. The input tensor is not modified.
    """
    probs = torch.softmax(x.float(), dim=-1) + eps
    probs = probs / (probs.sum(dim=-2, keepdim=True) + eps)

    rounds_left = repeat - 1
    while rounds_left > 0:
        last_dim_totals = probs.sum(dim=-1, keepdim=True) + eps
        probs = probs / last_dim_totals
        second_last_totals = probs.sum(dim=-2, keepdim=True) + eps
        probs = probs / second_last_totals
        rounds_left -= 1

    return probs.to(torch.float16)
# (batch, K) shapes, grouped by the dispatch path they are meant to reach.
# Entries tagged "boundary" sit directly on an N threshold for that K.
DISPATCH_SHAPES = [
    # K=4, N < 2048: small-batch path
    (1, 4), (4, 4), (32, 4), (100, 4), (1000, 4),
    (2047, 4),  # boundary: largest small-batch N
    # K=4, N >= 2048: fast path
    (2048, 4),  # boundary: smallest fast-path N
    (2049, 4), (4096, 4), (8192, 4),
    # K=8, N < 1024: small-batch path
    (1, 8), (8, 8), (64, 8), (500, 8),
    (1023, 8),  # boundary
    # K=8, N >= 1024: fast path
    (1024, 8),  # boundary
    (1025, 8), (2048, 8),
    # K=16, N < 512: small-batch path
    (1, 16), (16, 16), (256, 16),
    (511, 16),  # boundary
    # K=16, N >= 512: fast path
    (512, 16),  # boundary
    (513, 16), (1024, 16),
    # K=32: small-batch below N=256, strided-tree from N=256
    (1, 32), (64, 32),
    (255, 32),  # boundary
    (256, 32),  # boundary, strided-tree
    (512, 32),
    # K=64: small-batch below N=128, strided-tree from N=128
    (1, 64), (32, 64),
    (127, 64),  # boundary
    (128, 64),  # boundary, strided-tree
    (256, 64),
    # K in (0, 16] other than 4/8/16: strided-tree
    (1, 5), (8, 7), (16, 12), (4, 13), (32, 15),
    # K in (16, 32): strided-tree
    (1, 17), (64, 20), (32, 24), (8, 30),
    # K in (32, 64): strided-tree
    (1, 33), (16, 48), (8, 50), (4, 60),
    # K in (64, 128]: fp32 fallback
    (1, 65), (2, 80), (4, 96), (2, 100), (8, 128),
]
# Breadth over depth: every dispatch shape runs once with repeat=10, seed=0.
DISPATCH_CASES = [(*shape, 10, 0) for shape in DISPATCH_SHAPES]

# A few representative shapes get the full repeat x seed grid, so numerical
# regressions show up regardless of which path a shape takes.
DENSE_SHAPES = [
    (1, 4), (1, 8), (1, 16), (1, 32), (1, 64), (1, 128),
    (4, 4), (4, 16), (8, 8), (16, 16), (32, 4), (64, 8),
]
DENSE_CASES = [
    (*shape, repeat, seed)
    for shape in DENSE_SHAPES
    for repeat in (1, 5, 10)
    for seed in (0, 42)
]

# DISPATCH_CASES and DENSE_CASES share some small shapes, so the combined
# list is de-duplicated where it is built.
# Merge both case lists, drop duplicates, and sort so the parametrization
# order is stable between runs.
TEST_CASES = sorted(set(DISPATCH_CASES + DENSE_CASES))


@pytest.fixture(scope="session")
def sinkhorn_kernel(npu_device):
    """JIT-compile kernel_sinkhorn.cpp once per test session and return the callable."""
    return jit_compile(str(KERNEL_CPP), verbose=True, device=npu_device)


@pytest.mark.parametrize("batch,K,repeat,seed", TEST_CASES)
def test_sinkhorn_ds_matches_reference(
    sinkhorn_kernel, npu_device, batch, K, repeat, seed
):
    """Compare the kernel output with ``sinkhorn_ref`` on seeded random input."""
    torch.manual_seed(seed)
    x = torch.randn(batch, K, K, device=npu_device, dtype=DTYPE)
    out = torch.empty_like(x)

    sinkhorn_kernel(x, out, repeat=repeat, eps=1e-6)
    # Wait for the kernel to finish before `out` is copied to the host.
    torch.npu.synchronize()

    # The reference runs on a CPU copy of the same input.
    ref = sinkhorn_ref(x.cpu(), repeat=repeat, eps=1e-6)

    torch.testing.assert_close(out.cpu(), ref, rtol=1e-2, atol=1e-5)


# Doubly-stochastic check across one representative shape per dispatch path.
DOUBLY_STOCHASTIC_SHAPES = [
    # smallBatch K ∈ {4,8,16,32,64}
    (4, 4),
    (4, 8),
    (4, 16),
    (4, 32),
    (4, 64),
    # fastPath K ∈ {4,8,16}
    (2048, 4),
    (1024, 8),
    (512, 16),
    # stridedTree (odd / non-{4,8,16,32,64} K)
    (4, 7),
    (4, 20),
    (4, 48),
    # stridedTree at large N for K ∈ {32, 64}
    (256, 32),
    (128, 64),
    # fp32 fallback
    (4, 96),
    (4, 128),
]


@pytest.mark.parametrize("batch,K", DOUBLY_STOCHASTIC_SHAPES)
def test_output_is_doubly_stochastic(sinkhorn_kernel, npu_device, batch, K):
    """After 20 iterations every row and column should sum to roughly the same value.

    A doubly-stochastic matrix has all row and column sums equal to 1. The
    assertions below are weaker: they only require the row sums, and
    separately the column sums, to be uniform within each matrix (standard
    deviation below 0.05); the absolute value of the sums is not checked.
    """
    torch.manual_seed(123)
    x = torch.randn(batch, K, K, device=npu_device, dtype=DTYPE)
    out = torch.empty_like(x)

    sinkhorn_kernel(x, out, repeat=20, eps=1e-6)
    torch.npu.synchronize()

    # Sum in fp32 so the check is not dominated by fp16 accumulation error.
    out_f = out.float()
    row_sums = out_f.sum(dim=-1)  # (batch, K)
    col_sums = out_f.sum(dim=-2)  # (batch, K)

    # All row sums should be approximately equal
    assert row_sums.std(dim=-1).max() < 0.05, f"Row sums not uniform: {row_sums}"
    # All col sums should be approximately equal
    assert col_sums.std(dim=-1).max() < 0.05, f"Col sums not uniform: {col_sums}"