From d9c72dd587479d0d422f60f1c87b28f38a3c3e38 Mon Sep 17 00:00:00 2001 From: Mocchibird Date: Tue, 21 Apr 2026 15:39:16 +0000 Subject: [PATCH 1/5] [Feat] Implement doubly-stochastic Sinkhorn normalization kernel and associated benchmarks, tests, and documentation --- examples/jit_cpp/sinkhorn/.gitignore | 3 + examples/jit_cpp/sinkhorn/README.md | 60 ++++ examples/jit_cpp/sinkhorn/bench_sinkhorn.py | 210 ++++++++++++++ examples/jit_cpp/sinkhorn/conftest.py | 40 +++ .../jit_cpp/sinkhorn/jit_util_sinkhorn.py | 110 +++++++ examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp | 274 ++++++++++++++++++ examples/jit_cpp/sinkhorn/plot_sinkhorn.py | 113 ++++++++ examples/jit_cpp/sinkhorn/test_sinkhorn.py | 94 ++++++ 8 files changed, 904 insertions(+) create mode 100644 examples/jit_cpp/sinkhorn/.gitignore create mode 100644 examples/jit_cpp/sinkhorn/README.md create mode 100644 examples/jit_cpp/sinkhorn/bench_sinkhorn.py create mode 100644 examples/jit_cpp/sinkhorn/conftest.py create mode 100644 examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py create mode 100644 examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp create mode 100644 examples/jit_cpp/sinkhorn/plot_sinkhorn.py create mode 100644 examples/jit_cpp/sinkhorn/test_sinkhorn.py diff --git a/examples/jit_cpp/sinkhorn/.gitignore b/examples/jit_cpp/sinkhorn/.gitignore new file mode 100644 index 00000000..289776ee --- /dev/null +++ b/examples/jit_cpp/sinkhorn/.gitignore @@ -0,0 +1,3 @@ +*.png +*.csv +*.so \ No newline at end of file diff --git a/examples/jit_cpp/sinkhorn/README.md b/examples/jit_cpp/sinkhorn/README.md new file mode 100644 index 00000000..4ad25d50 --- /dev/null +++ b/examples/jit_cpp/sinkhorn/README.md @@ -0,0 +1,60 @@ +# Doubly-Stochastic Sinkhorn Normalization + +PTO-ISA kernel for doubly-stochastic Sinkhorn normalization on Ascend NPU. +Implements the DeepSeek MHC pre-processing algorithm: softmax per row, +then alternating row/column normalization until the matrix is approximately +doubly-stochastic. + +## Algorithm + +```python +def sinkhorn_normalize(x, repeat=10, eps=1e-6): + x = x.softmax(-1) + eps + x = x / (x.sum(-2, keepdim=True) + eps) + for _ in range(repeat - 1): + x = x / (x.sum(-1, keepdim=True) + eps) + x = x / (x.sum(-2, keepdim=True) + eps) + return x +``` + +Input shape: `(N, K, K)` fp16 — N square matrices of dimension K. +K must be <= 128 (the full K×K matrix lives in UB as fp32). + +## Files + +| File | Description | +|---|---| +| `kernel_sinkhorn.cpp` | PTO-ISA C++ kernel (JIT-compiled via bisheng) | +| `jit_util_sinkhorn.py` | Python wrapper — compiles & exposes the kernel | +| `conftest.py` | pytest NPU device fixture | +| `test_sinkhorn.py` | Correctness tests vs PyTorch reference | +| `bench_sinkhorn.py` | Benchmark PTO vs PyTorch | +| `plot_sinkhorn.py` | Plot benchmark CSVs | + +## Usage + +```bash +# compile + test +python -m pytest test_sinkhorn.py -v --npu=npu:0 + +# benchmark +python bench_sinkhorn.py + +# re-plot from saved CSV +python plot_sinkhorn.py +``` + +## Reproducing + +```bash +cd examples/jit_cpp/sinkhorn + +# tests (73 cases: shapes × repeats × seeds) +python -m pytest test_sinkhorn.py -v --npu=npu:0 + +# benchmark (batch × K grid, repeat=10, eps=1e-6) +python bench_sinkhorn.py --batches 1 4 8 16 32 64 --hidden-dims 4 8 16 32 64 128 + +# plots +python plot_sinkhorn.py +``` diff --git a/examples/jit_cpp/sinkhorn/bench_sinkhorn.py b/examples/jit_cpp/sinkhorn/bench_sinkhorn.py new file mode 100644 index 00000000..a0daa613 --- /dev/null +++ b/examples/jit_cpp/sinkhorn/bench_sinkhorn.py @@ -0,0 +1,210 @@ +# pylint: disable=wrong-import-position +""" +Benchmark PTO doubly-stochastic Sinkhorn against PyTorch reference. + +Writes: + outputs/csv/sinkhorn_compare_bd{block_dim}.csv + outputs/plots/ (via plot_sinkhorn.py) +""" +import argparse +import sys +from pathlib import Path + +import torch +import torch_npu # noqa + +THIS_DIR = Path(__file__).resolve().parent +FAST_HADAMARD_DIR = THIS_DIR.parent / "fast_hadamard" +if str(FAST_HADAMARD_DIR) not in sys.path: + sys.path.insert(0, str(FAST_HADAMARD_DIR)) + +from bench_common import ( # noqa: E402 + add_common_benchmark_args, + benchmark_npu_us, + benchmark_trials_us, + ensure_output_dir, + make_buffer_pool, + pool_item, + resolve_dir_arg, + validate_benchmark_args, + write_csv_records, +) + +from jit_util_common import get_current_stream_ptr # noqa: E402 +from jit_util_sinkhorn import jit_compile # noqa: E402 + +DEFAULT_WARMUP = 10 +DEFAULT_REPEATS = 100 +SINKHORN_REPEAT = 10 +SINKHORN_EPS = 1e-6 +BYTES_PER_ELEMENT = 2 # fp16 + +CSV_HEADER = ( + "batch,N,pto_duration_us,torch_duration_us," + "pto_bandwidth_gbs,torch_bandwidth_gbs,pto_speedup_vs_torch," + "trials,pto_duration_mean_us,pto_duration_std_us,pto_duration_min_us," + "pto_duration_max_us,pto_duration_cv_pct,torch_duration_mean_us," + "torch_duration_std_us,torch_duration_min_us," + "torch_duration_max_us,torch_duration_cv_pct\n" +) + + +def sinkhorn_ref(x, repeat=10, eps=1e-6): + """PyTorch reference (runs on NPU via torch ops).""" + x = x.float() + x = x.softmax(-1) + eps + x = x / (x.sum(-2, keepdim=True) + eps) + for _ in range(repeat - 1): + x = x / (x.sum(-1, keepdim=True) + eps) + x = x / (x.sum(-2, keepdim=True) + eps) + return x.half() + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark PTO Sinkhorn (doubly-stochastic) against PyTorch reference." + ) + parser.add_argument( + "--no-cache-stream", dest="cache_stream", action="store_false", + help="Disable cached stream pointer reuse for PTO launches.", + ) + parser.set_defaults(cache_stream=True) + return add_common_benchmark_args( + parser, default_warmup=DEFAULT_WARMUP, default_repeats=DEFAULT_REPEATS, + ).parse_args() + + +def _effective_bandwidth_gbs(batch, K, duration_us): + if duration_us <= 0: + return 0.0 + # read K*K + write K*K + data_bytes = batch * 2 * K * K * BYTES_PER_ELEMENT + return (data_bytes / 1e9) / (duration_us / 1e6) + + +def _make_shape_pools(batch, K, warmup, repeats, device): + return { + "x": make_buffer_pool( + warmup, repeats, + lambda: torch.randn(batch, K, K, device=device, dtype=torch.float16), + ), + "y": make_buffer_pool( + warmup, repeats, + lambda: torch.empty(batch, K, K, device=device, dtype=torch.float16), + ), + } + + +def benchmark( + sinq_func, *, warmup, repeats, trials, output_dir, device, batches, + hidden_dims, stream_ptr=None, +): + ensure_output_dir(output_dir) + block_dim = sinq_func.block_dim + + print(f"\n{'=' * 92}") + print(f"SINKHORN DS BENCHMARK (BLOCK_DIM={block_dim}, repeat={SINKHORN_REPEAT})") + print(f"{'=' * 92}") + header = ( + f"{'batch':>6s} {'K':>6s}" + f" {'pto_us':>10s} {'torch_us':>10s}" + f" {'pto_bw(GB/s)':>12s} {'torch_bw(GB/s)':>14s} {'pto_speedup':>11s}" + ) + print(header) + print("-" * len(header)) + + records = [] + for batch in batches: + for K in hidden_dims: + pools = _make_shape_pools(batch, K, warmup, repeats, device) + x_list = pools["x"] + y_list = pools["y"] + + pto_stats = benchmark_trials_us( + trials, + lambda x_list=x_list, y_list=y_list: benchmark_npu_us( + warmup, repeats, + lambda i: sinq_func( + pool_item(x_list, i), pool_item(y_list, i), + repeat=SINKHORN_REPEAT, eps=SINKHORN_EPS, + stream_ptr=stream_ptr, + ), + ), + ) + torch_stats = benchmark_trials_us( + trials, + lambda x_list=x_list: benchmark_npu_us( + warmup, repeats, + lambda i: sinkhorn_ref( + pool_item(x_list, i), + repeat=SINKHORN_REPEAT, eps=SINKHORN_EPS, + ), + ), + ) + + pto_us = pto_stats["median_us"] + torch_us = torch_stats["median_us"] + pto_bw = _effective_bandwidth_gbs(batch, K, pto_us) + torch_bw = _effective_bandwidth_gbs(batch, K, torch_us) + pto_speedup = torch_us / pto_us if pto_us > 0 else 0.0 + + print( + f"{batch:>6d} {K:>6d}" + f" {pto_us:>10.2f} {torch_us:>10.2f}" + f" {pto_bw:>12.4f} {torch_bw:>14.4f}" + f" {pto_speedup:>11.3f}" + ) + + records.append( + f"{batch},{K},{pto_us:.4f},{torch_us:.4f}," + f"{pto_bw:.6f},{torch_bw:.6f}," + f"{pto_speedup:.4f}," + f"{trials},{pto_stats['mean_us']:.4f},{pto_stats['std_us']:.4f}," + f"{pto_stats['min_us']:.4f},{pto_stats['max_us']:.4f}," + f"{pto_stats['cv_pct']:.4f},{torch_stats['mean_us']:.4f}," + f"{torch_stats['std_us']:.4f}," + f"{torch_stats['min_us']:.4f}," + f"{torch_stats['max_us']:.4f}," + f"{torch_stats['cv_pct']:.4f}" + ) + + csv_path = output_dir / f"sinkhorn_compare_bd{block_dim}.csv" + write_csv_records(csv_path, CSV_HEADER, records) + print(f"\nSaved to {csv_path}") + + +def main(): + args = _parse_args() + validate_benchmark_args(args) + + torch.npu.set_device(args.npu) + base = THIS_DIR + kernel_path = base / "kernel_sinkhorn.cpp" + csv_dir = resolve_dir_arg(base, args.csv_dir) + + print(f"Using device: {args.npu}") + print("Compiling kernel_sinkhorn.cpp ...") + sinq_func = jit_compile(str(kernel_path), verbose=True, device=args.npu) + stream_ptr = get_current_stream_ptr() if args.cache_stream else None + if stream_ptr is not None: + print("Using cached NPU stream pointer for PTO launches.") + + # Override default grids for sinkhorn: batch=N (matrices), K=dim + batches = args.batches if args.batches else [1, 4, 8, 16, 32, 64] + dims = args.hidden_dims if args.hidden_dims else [4, 8, 16, 32, 64, 128] + + benchmark( + sinq_func, + warmup=args.warmup, + repeats=args.repeats, + trials=args.trials, + output_dir=csv_dir, + device=args.npu, + batches=batches, + hidden_dims=dims, + stream_ptr=stream_ptr, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/sinkhorn/conftest.py b/examples/jit_cpp/sinkhorn/conftest.py new file mode 100644 index 00000000..34e7fa81 --- /dev/null +++ b/examples/jit_cpp/sinkhorn/conftest.py @@ -0,0 +1,40 @@ +import pytest +import torch + + +def normalize_npu_device(device: str | int) -> str: + text = str(device).strip().strip('"').strip("'") + if text.lower().startswith("npu:"): + index = text.split(":", 1)[1].strip() + else: + index = text + + if not index.isdigit(): + raise ValueError( + f"Invalid NPU device '{device}'. Expected values like 0 or npu:0." + ) + return f"npu:{int(index)}" + + +def pytest_addoption(parser): + try: + parser.addoption( + "--npu", + action="store", + default="npu:0", + help="NPU device (examples: 0, npu:0, '0', 'npu:0').", + ) + except ValueError as exc: + if "--npu" not in str(exc): + raise + + +@pytest.fixture(scope="session") +def npu_device(request): + raw = request.config.getoption("--npu") + return normalize_npu_device(raw) + + +@pytest.fixture(scope="session", autouse=True) +def setup_npu_device(npu_device): + torch.npu.set_device(npu_device) diff --git a/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py b/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py new file mode 100644 index 00000000..233e0b28 --- /dev/null +++ b/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py @@ -0,0 +1,110 @@ +# pylint: disable=wrong-import-position +import ctypes +import sys +from pathlib import Path + +import torch + +THIS_DIR = Path(__file__).resolve().parent +FAST_HADAMARD_DIR = THIS_DIR.parent / "fast_hadamard" +if str(FAST_HADAMARD_DIR) not in sys.path: + sys.path.insert(0, str(FAST_HADAMARD_DIR)) + +from jit_util_common import ( # noqa: E402 + BLOCK_DIM, + DEFAULT_DEVICE, + jit_compile_with_loader, + load_cdll, + load_required_symbol, + resolve_launch_block_dim, + resolve_stream_ptr, + torch_to_ctypes, +) + +MAX_DIM = 128 + +SINKHORN_DS_ARGTYPES = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # input + ctypes.c_void_p, # output + ctypes.c_uint32, # N + ctypes.c_uint32, # K + ctypes.c_uint32, # repeat + ctypes.c_float, # eps +] + + +def _validate(input_tensor, output_tensor, K): + if input_tensor.dim() != 3: + raise ValueError("input must be 3D (N, K, K).") + N = input_tensor.shape[0] + if input_tensor.shape[1] != K or input_tensor.shape[2] != K: + raise ValueError(f"input must have shape (N, {K}, {K}).") + if output_tensor.shape != input_tensor.shape: + raise ValueError("output must have the same shape as input.") + if input_tensor.dtype != torch.float16: + raise TypeError("input must use torch.float16.") + if not input_tensor.is_contiguous() or not output_tensor.is_contiguous(): + raise ValueError("tensors must be contiguous.") + if input_tensor.device != output_tensor.device: + raise ValueError("tensors must be on the same device.") + if K > MAX_DIM: + raise ValueError(f"K must be <= {MAX_DIM}.") + + +def load_lib(lib_path, block_dim=BLOCK_DIM): + lib = load_cdll(lib_path) + resolved_block_dim = max(1, int(block_dim)) + + kernel = load_required_symbol( + lib, + "call_sinkhorn_ds_kernel", + SINKHORN_DS_ARGTYPES, + ) + + def sinkhorn_ds_func( + input_tensor, + output_tensor, + *, + repeat=10, + eps=1e-6, + block_dim=resolved_block_dim, + stream_ptr=None, + ): + N, K, _ = input_tensor.shape + _validate(input_tensor, output_tensor, K) + kernel( + resolve_launch_block_dim(block_dim, resolved_block_dim), + resolve_stream_ptr(stream_ptr), + torch_to_ctypes(input_tensor), + torch_to_ctypes(output_tensor), + N, + K, + repeat, + float(eps), + ) + + sinkhorn_ds_func.block_dim = resolved_block_dim + return sinkhorn_ds_func + + +def jit_compile( + src_path, + verbose=True, + clean_up=False, + so_dir=None, + device: str | int = DEFAULT_DEVICE, + block_dim=None, +): + if so_dir is None: + so_dir = THIS_DIR / "outputs" / "so" + return jit_compile_with_loader( + src_path, + load_lib, + verbose=verbose, + clean_up=clean_up, + so_dir=so_dir, + device=device, + block_dim=block_dim, + ) diff --git a/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp b/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp new file mode 100644 index 00000000..fdaa0a19 --- /dev/null +++ b/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp @@ -0,0 +1,274 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +/** + * Doubly-stochastic Sinkhorn normalization kernel (fp16 I/O, fp32 internal). + * + * Implements the DeepSeek MHC pre-processing sinkhorn: + * 1. softmax per row + eps + * 2. column-normalize (+ eps) + * 3. repeat (repeat-1) times: row-normalize (+ eps), column-normalize (+ eps) + * + * Design: + * - One vector core per (K, K) matrix. + * - The entire matrix lives in UB as fp32 via a 2D tile with static dims + * MAX_DIM × MAX_DIM but dynamic dims (K, K). All reductions (TROWSUM, + * TROWMAX, TCOLSUM) respect the dynamic K, ignoring padding. + * - K must be <= MAX_DIM (128). + */ + +#include + +// clang-format off +#ifndef GM_ADDR +#define GM_ADDR __gm__ uint8_t* +#endif +// clang-format on + +using namespace pto; + +constexpr uint32_t UB_USABLE_BYTES = 192 * 1024; +constexpr uint32_t MAX_DIM = 128; + +namespace UbOfs { +constexpr unsigned MAT_HALF = 0; +constexpr unsigned MAT_FP32 = MAT_HALF + MAX_DIM * MAX_DIM * sizeof(half); +constexpr unsigned TMP = MAT_FP32 + MAX_DIM * MAX_DIM * sizeof(float); +constexpr unsigned VEC_BUF = TMP + MAX_DIM * MAX_DIM * sizeof(float); +constexpr unsigned TOTAL = VEC_BUF + MAX_DIM * sizeof(float); +} // namespace UbOfs + +static_assert(UbOfs::TOTAL <= UB_USABLE_BYTES, "Sinkhorn DS UB exceeds 192 KB."); + +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + +using StrideDim5 = pto::Stride<1, 1, 1, 1, 1>; + +template +using Vec1D = Tile; + +template +using Global1D = GlobalTensor, StrideDim5>; + +template +using Tile2D = Tile; + +using DynStride = Stride<1, 1, 1, DYNAMIC, 1>; +template +using Shape2D = TileShape2D; +template +using Global2D = GlobalTensor, DynStride, Layout::ND>; + +template +using ColVec = Tile; + +AICORE inline void initPipeFlags() { + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); +} + +AICORE inline void drainPipeFlags() { + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); +} + +// Column-normalize: divide mat[r,:] by vec[:] for each row. +// (Replaces unavailable TCOLEXPANDDIV.) +AICORE void colNormDiv(uint32_t K) { + constexpr unsigned rowBytes = MAX_DIM * sizeof(float); + for (uint32_t r = 0; r < K; ++r) { + Vec1D row(1, K); + Vec1D vec(1, K); + TASSIGN(row, UbOfs::MAT_FP32 + r * rowBytes); + TASSIGN(vec, UbOfs::VEC_BUF); + TDIV(row, row, vec); + pipe_barrier(PIPE_V); + } +} + +template +AICORE void runSinkhornDS(__gm__ T *input, __gm__ T *output, + uint32_t N, uint32_t K, + uint32_t repeat, float eps) { + set_mask_norm(); + set_vector_mask(-1, -1); + if (K == 0 || K > MAX_DIM) return; + + const uint32_t num_workers = get_block_num() * get_subblockdim(); + const uint32_t wid = get_block_idx() * get_subblockdim() + get_subblockid(); + const uint32_t KK = K * K; + // Flat count covering the 2D buffer (row stride = MAX_DIM). + const uint32_t flat = K * MAX_DIM; + + initPipeFlags(); + + for (uint32_t bi = wid; bi < N; bi += num_workers) { + __gm__ T *gm_in = input + static_cast(bi) * KK; + __gm__ T *gm_out = output + static_cast(bi) * KK; + + // ---- Zero fp16 buffer, then load (K, K) ---- + { + Vec1D zHalf(1, flat); + TASSIGN(zHalf, UbOfs::MAT_HALF); + TEXPANDS(zHalf, (T)0); + pipe_barrier(PIPE_V); + } + + Tile2D matHalf(K, K); + TASSIGN(matHalf, UbOfs::MAT_HALF); + Shape2D inShape(K, K); + DynStride inStride(K); + Global2D gIn(gm_in, inShape, inStride); + + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(matHalf, gIn); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // fp16 → fp32 + Vec1D hFlat(1, flat); + Vec1D fFlat(1, flat); + TASSIGN(hFlat, UbOfs::MAT_HALF); + TASSIGN(fFlat, UbOfs::MAT_FP32); + TCVT(fFlat, hFlat, RoundMode::CAST_NONE); + pipe_barrier(PIPE_V); + + // 2D view with dynamic (K, K) — reductions respect this. + Tile2D mat(K, K); + TASSIGN(mat, UbOfs::MAT_FP32); + Tile2D tmp(K, K); + TASSIGN(tmp, UbOfs::TMP); + + // ============================================================ + // Softmax per row: max-subtract, exp, sum, divide + // ============================================================ + ColVec vecCol(K, 1); + TASSIGN(vecCol, UbOfs::VEC_BUF); + + // Row max → ColVec(K, 1) + TROWMAX(vecCol, mat, tmp); + pipe_barrier(PIPE_V); + + // Subtract max per row + TROWEXPANDSUB(mat, mat, vecCol); + pipe_barrier(PIPE_V); + + // Exp (flat — includes padding, but padding was 0, exp(0-max)=exp(-max)≈0) + TEXP(fFlat, fFlat); + pipe_barrier(PIPE_V); + + // Row sum → ColVec(K, 1) + TROWSUM(vecCol, mat, tmp); + pipe_barrier(PIPE_V); + + // Add eps to row sums (as 1D view of VEC_BUF) + { + Vec1D vecFlat(1, K); + TASSIGN(vecFlat, UbOfs::VEC_BUF); + TADDS(vecFlat, vecFlat, eps); + pipe_barrier(PIPE_V); + } + + // Divide by (row_sum + eps) + TROWEXPANDDIV(mat, mat, vecCol); + pipe_barrier(PIPE_V); + + // Add eps to all elements + TADDS(fFlat, fFlat, eps); + pipe_barrier(PIPE_V); + + // ============================================================ + // Column normalize + // ============================================================ + { + Vec1D colSums(1, K); + TASSIGN(colSums, UbOfs::VEC_BUF); + TCOLSUM(colSums, mat, tmp, false); + pipe_barrier(PIPE_V); + TADDS(colSums, colSums, eps); + pipe_barrier(PIPE_V); + } + colNormDiv(K); + + // ============================================================ + // Iterate (repeat-1) times: row-norm + col-norm + // ============================================================ + for (uint32_t it = 1; it < repeat; ++it) { + // Row normalize + TASSIGN(vecCol, UbOfs::VEC_BUF); + TROWSUM(vecCol, mat, tmp); + pipe_barrier(PIPE_V); + { + Vec1D vecFlat(1, K); + TASSIGN(vecFlat, UbOfs::VEC_BUF); + TADDS(vecFlat, vecFlat, eps); + pipe_barrier(PIPE_V); + } + TROWEXPANDDIV(mat, mat, vecCol); + pipe_barrier(PIPE_V); + + // Column normalize + { + Vec1D colSums(1, K); + TASSIGN(colSums, UbOfs::VEC_BUF); + TCOLSUM(colSums, mat, tmp, false); + pipe_barrier(PIPE_V); + TADDS(colSums, colSums, eps); + pipe_barrier(PIPE_V); + } + colNormDiv(K); + } + + // ============================================================ + // Store: fp32 → fp16 → HBM + // ============================================================ + TCVT(hFlat, fFlat, RoundMode::CAST_RINT); + pipe_barrier(PIPE_V); + + Tile2D outHalf(K, K); + TASSIGN(outHalf, UbOfs::MAT_HALF); + Shape2D outShape(K, K); + DynStride outStride(K); + Global2D gOut(gm_out, outShape, outStride); + + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(gOut, outHalf); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + } + + drainPipeFlags(); +} + +#endif + +extern "C" __global__ AICORE void sinkhorn_ds_fp16(GM_ADDR input, + GM_ADDR output, + uint32_t N, uint32_t K, + uint32_t repeat, + float eps) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + runSinkhornDS((__gm__ half *)input, (__gm__ half *)output, + N, K, repeat, eps); +#else + (void)input; (void)output; (void)N; (void)K; (void)repeat; (void)eps; +#endif +} + +extern "C" void call_sinkhorn_ds_kernel(uint32_t blockDim, void *stream, + uint8_t *input, uint8_t *output, + uint32_t N, uint32_t K, + uint32_t repeat, float eps) { + sinkhorn_ds_fp16<<>>( + input, output, N, K, repeat, eps); +} diff --git a/examples/jit_cpp/sinkhorn/plot_sinkhorn.py b/examples/jit_cpp/sinkhorn/plot_sinkhorn.py new file mode 100644 index 00000000..760be017 --- /dev/null +++ b/examples/jit_cpp/sinkhorn/plot_sinkhorn.py @@ -0,0 +1,113 @@ +# pylint: disable=wrong-import-position +""" +Plot Sinkhorn benchmark comparison from CSV files. + +Reads: outputs/csv/sinkhorn_compare_bd*.csv +Writes: outputs/plots/sinkhorn_*.png +""" +import argparse +import sys +from pathlib import Path + +THIS_DIR = Path(__file__).resolve().parent +FAST_HADAMARD_DIR = THIS_DIR.parent / "fast_hadamard" +if str(FAST_HADAMARD_DIR) not in sys.path: + sys.path.insert(0, str(FAST_HADAMARD_DIR)) + +from plot_common import ( # noqa: E402 + add_common_plot_args, + block_dim_from_path, + ensure_matplotlib, + ensure_plot_dir, + load_nonempty_rows, + make_batched_line_plot, + make_speedup_heatmap, + plot_csv_collection, + resolve_dir_arg, +) + +CSV_PREFIX = "sinkhorn_compare_bd" +CSV_PATTERN = f"{CSV_PREFIX}*.csv" + +DURATION_LINE_PLOT = { + "filename": "sinkhorn_duration_bd{block_dim}.png", + "series": ( + ("pto_duration_us", "PTO Sinkhorn", "#dc2626", "s--"), + ("torch_duration_us", "PyTorch Reference", "#2563eb", "o-"), + ), + "y_label": "Duration (us)", + "title": "Sinkhorn Duration: PTO vs PyTorch Reference", +} + +BANDWIDTH_LINE_PLOT = { + "filename": "sinkhorn_bandwidth_bd{block_dim}.png", + "series": ( + ("pto_bandwidth_gbs", "PTO Sinkhorn", "#dc2626", "s--"), + ("torch_bandwidth_gbs", "PyTorch Reference", "#2563eb", "o-"), + ), + "y_label": "Effective Bandwidth (GB/s)", + "title": "Sinkhorn Effective Bandwidth: PTO vs PyTorch Reference", +} + +HEATMAPS = ( + { + "filename": "sinkhorn_speedup_heatmap_bd{block_dim}.png", + "key": "pto_speedup_vs_torch", + "title": "Sinkhorn PTO Speedup over PyTorch Reference", + "colorbar_label": "log2(PTO speedup)", + }, +) + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Plot Sinkhorn benchmark comparison from CSV files." + ) + return add_common_plot_args(parser).parse_args() + + +def plot_sinkhorn(csv_path: Path, plot_dir: Path): + if not ensure_matplotlib(): + return + + rows = load_nonempty_rows(csv_path) + if rows is None: + return + + block_dim = block_dim_from_path(csv_path, CSV_PREFIX) + ensure_plot_dir(plot_dir) + + for plot in (DURATION_LINE_PLOT, BANDWIDTH_LINE_PLOT): + make_batched_line_plot( + rows, block_dim, + plot_dir / plot["filename"].format(block_dim=block_dim), + plot["series"], plot["y_label"], plot["title"], + ) + + for heatmap in HEATMAPS: + make_speedup_heatmap( + rows, block_dim, + plot_dir / heatmap["filename"].format(block_dim=block_dim), + heatmap["key"], heatmap["title"], + colorbar_label=heatmap.get("colorbar_label", "log2(speedup)"), + ) + + print(f"Plotted {csv_path.name}") + + +def main(): + args = _parse_args() + base = THIS_DIR + csv_dir = resolve_dir_arg(base, args.csv_dir) + plot_dir = resolve_dir_arg(base, args.plot_dir) + + plot_csv_collection( + csv_dir, plot_dir, + pattern=CSV_PATTERN, prefix=CSV_PREFIX, + warning="no Sinkhorn benchmark CSV files found", + plot_csv_fn=plot_sinkhorn, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/sinkhorn/test_sinkhorn.py b/examples/jit_cpp/sinkhorn/test_sinkhorn.py new file mode 100644 index 00000000..341caf81 --- /dev/null +++ b/examples/jit_cpp/sinkhorn/test_sinkhorn.py @@ -0,0 +1,94 @@ +""" +Correctness tests for the doubly-stochastic Sinkhorn normalization kernel. + +Reference: DeepSeek MHC sinkhorn_normalize_ref + x = x.softmax(-1) + eps + x = x / (x.sum(-2, keepdim=True) + eps) + for _ in range(repeat - 1): + x = x / (x.sum(-1, keepdim=True) + eps) + x = x / (x.sum(-2, keepdim=True) + eps) +""" +from pathlib import Path + +import pytest +import torch +import torch_npu # noqa + +from jit_util_sinkhorn import jit_compile + +DTYPE = torch.float16 +KERNEL_CPP = Path(__file__).resolve().parent / "kernel_sinkhorn.cpp" + + +def sinkhorn_ref(x: torch.Tensor, repeat: int = 10, eps: float = 1e-6) -> torch.Tensor: + """Pure-PyTorch reference (fp32 internal).""" + x = x.float() + x = x.softmax(-1) + eps + x = x / (x.sum(-2, keepdim=True) + eps) + for _ in range(repeat - 1): + x = x / (x.sum(-1, keepdim=True) + eps) + x = x / (x.sum(-2, keepdim=True) + eps) + return x.to(torch.float16) + + +TEST_SHAPES = [ + (1, 4), + (1, 8), + (1, 16), + (1, 32), + (1, 64), + (1, 128), + (4, 4), + (4, 16), + (8, 8), + (16, 16), + (32, 4), + (64, 8), +] +TEST_REPEATS = [1, 5, 10] +TEST_SEEDS = [0, 42] +TEST_CASES = [ + (N, K, repeat, seed) + for N, K in TEST_SHAPES + for repeat in TEST_REPEATS + for seed in TEST_SEEDS +] + + +@pytest.fixture(scope="session") +def sinkhorn_kernel(npu_device): + return jit_compile(str(KERNEL_CPP), verbose=True, device=npu_device) + + +@pytest.mark.parametrize("N,K,repeat,seed", TEST_CASES) +def test_sinkhorn_ds_matches_reference(sinkhorn_kernel, npu_device, N, K, repeat, seed): + torch.manual_seed(seed) + x = torch.randn(N, K, K, device=npu_device, dtype=DTYPE) + out = torch.empty_like(x) + + sinkhorn_kernel(x, out, repeat=repeat, eps=1e-6) + torch.npu.synchronize() + + ref = sinkhorn_ref(x.cpu(), repeat=repeat, eps=1e-6) + + torch.testing.assert_close(out.cpu(), ref, rtol=1e-2, atol=1e-5) + + +def test_output_is_doubly_stochastic(sinkhorn_kernel, npu_device): + """After enough iterations, rows and columns should approximately sum to 1/K.""" + torch.manual_seed(123) + K = 8 + x = torch.randn(4, K, K, device=npu_device, dtype=DTYPE) + out = torch.empty_like(x) + + sinkhorn_kernel(x, out, repeat=20, eps=1e-6) + torch.npu.synchronize() + + out_f = out.float() + row_sums = out_f.sum(dim=-1) # (4, K) + col_sums = out_f.sum(dim=-2) # (4, K) + + # All row sums should be approximately equal + assert row_sums.std(dim=-1).max() < 0.05, f"Row sums not uniform: {row_sums}" + # All col sums should be approximately equal + assert col_sums.std(dim=-1).max() < 0.05, f"Col sums not uniform: {col_sums}" From a88d98eaefd30e3dd83b99a8f52e1aab9a8c7fdf Mon Sep 17 00:00:00 2001 From: Mocchibird Date: Tue, 21 Apr 2026 15:50:38 +0000 Subject: [PATCH 2/5] better plotting --- examples/jit_cpp/sinkhorn/.gitignore | 2 +- examples/jit_cpp/sinkhorn/bench_sinkhorn.py | 41 +++++--- .../jit_cpp/sinkhorn/jit_util_sinkhorn.py | 2 +- examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp | 50 +++++----- examples/jit_cpp/sinkhorn/plot_sinkhorn.py | 96 +++++++++++++++++-- examples/jit_cpp/sinkhorn/test_sinkhorn.py | 1 + 6 files changed, 147 insertions(+), 45 deletions(-) diff --git a/examples/jit_cpp/sinkhorn/.gitignore b/examples/jit_cpp/sinkhorn/.gitignore index 289776ee..08445025 100644 --- a/examples/jit_cpp/sinkhorn/.gitignore +++ b/examples/jit_cpp/sinkhorn/.gitignore @@ -1,3 +1,3 @@ *.png *.csv -*.so \ No newline at end of file +*.so diff --git a/examples/jit_cpp/sinkhorn/bench_sinkhorn.py b/examples/jit_cpp/sinkhorn/bench_sinkhorn.py index a0daa613..689fc3e9 100644 --- a/examples/jit_cpp/sinkhorn/bench_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/bench_sinkhorn.py @@ -65,12 +65,16 @@ def _parse_args(): description="Benchmark PTO Sinkhorn (doubly-stochastic) against PyTorch reference." ) parser.add_argument( - "--no-cache-stream", dest="cache_stream", action="store_false", + "--no-cache-stream", + dest="cache_stream", + action="store_false", help="Disable cached stream pointer reuse for PTO launches.", ) parser.set_defaults(cache_stream=True) return add_common_benchmark_args( - parser, default_warmup=DEFAULT_WARMUP, default_repeats=DEFAULT_REPEATS, + parser, + default_warmup=DEFAULT_WARMUP, + default_repeats=DEFAULT_REPEATS, ).parse_args() @@ -85,19 +89,29 @@ def _effective_bandwidth_gbs(batch, K, duration_us): def _make_shape_pools(batch, K, warmup, repeats, device): return { "x": make_buffer_pool( - warmup, repeats, + warmup, + repeats, lambda: torch.randn(batch, K, K, device=device, dtype=torch.float16), ), "y": make_buffer_pool( - warmup, repeats, + warmup, + repeats, lambda: torch.empty(batch, K, K, device=device, dtype=torch.float16), ), } def benchmark( - sinq_func, *, warmup, repeats, trials, output_dir, device, batches, - hidden_dims, stream_ptr=None, + sinq_func, + *, + warmup, + repeats, + trials, + output_dir, + device, + batches, + hidden_dims, + stream_ptr=None, ): ensure_output_dir(output_dir) block_dim = sinq_func.block_dim @@ -123,10 +137,13 @@ def benchmark( pto_stats = benchmark_trials_us( trials, lambda x_list=x_list, y_list=y_list: benchmark_npu_us( - warmup, repeats, + warmup, + repeats, lambda i: sinq_func( - pool_item(x_list, i), pool_item(y_list, i), - repeat=SINKHORN_REPEAT, eps=SINKHORN_EPS, + pool_item(x_list, i), + pool_item(y_list, i), + repeat=SINKHORN_REPEAT, + eps=SINKHORN_EPS, stream_ptr=stream_ptr, ), ), @@ -134,10 +151,12 @@ def benchmark( torch_stats = benchmark_trials_us( trials, lambda x_list=x_list: benchmark_npu_us( - warmup, repeats, + warmup, + repeats, lambda i: sinkhorn_ref( pool_item(x_list, i), - repeat=SINKHORN_REPEAT, eps=SINKHORN_EPS, + repeat=SINKHORN_REPEAT, + eps=SINKHORN_EPS, ), ), ) diff --git a/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py b/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py index 233e0b28..fd88ec64 100644 --- a/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py @@ -31,7 +31,7 @@ ctypes.c_uint32, # N ctypes.c_uint32, # K ctypes.c_uint32, # repeat - ctypes.c_float, # eps + ctypes.c_float, # eps ] diff --git a/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp b/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp index fdaa0a19..ea700c46 100644 --- a/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp +++ b/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp @@ -39,12 +39,13 @@ constexpr uint32_t MAX_DIM = 128; namespace UbOfs { constexpr unsigned MAT_HALF = 0; constexpr unsigned MAT_FP32 = MAT_HALF + MAX_DIM * MAX_DIM * sizeof(half); -constexpr unsigned TMP = MAT_FP32 + MAX_DIM * MAX_DIM * sizeof(float); -constexpr unsigned VEC_BUF = TMP + MAX_DIM * MAX_DIM * sizeof(float); -constexpr unsigned TOTAL = VEC_BUF + MAX_DIM * sizeof(float); +constexpr unsigned TMP = MAT_FP32 + MAX_DIM * MAX_DIM * sizeof(float); +constexpr unsigned VEC_BUF = TMP + MAX_DIM * MAX_DIM * sizeof(float); +constexpr unsigned TOTAL = VEC_BUF + MAX_DIM * sizeof(float); } // namespace UbOfs -static_assert(UbOfs::TOTAL <= UB_USABLE_BYTES, "Sinkhorn DS UB exceeds 192 KB."); +static_assert(UbOfs::TOTAL <= UB_USABLE_BYTES, + "Sinkhorn DS UB exceeds 192 KB."); #if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) @@ -57,7 +58,8 @@ template using Global1D = GlobalTensor, StrideDim5>; template -using Tile2D = Tile; +using Tile2D = + Tile; using DynStride = Stride<1, 1, 1, DYNAMIC, 1>; template @@ -66,7 +68,8 @@ template using Global2D = GlobalTensor, DynStride, Layout::ND>; template -using ColVec = Tile; +using ColVec = + Tile; AICORE inline void initPipeFlags() { set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); @@ -93,9 +96,8 @@ AICORE void colNormDiv(uint32_t K) { } template -AICORE void runSinkhornDS(__gm__ T *input, __gm__ T *output, - uint32_t N, uint32_t K, - uint32_t repeat, float eps) { +AICORE void runSinkhornDS(__gm__ T *input, __gm__ T *output, uint32_t N, + uint32_t K, uint32_t repeat, float eps) { set_mask_norm(); set_vector_mask(-1, -1); if (K == 0 || K > MAX_DIM) return; @@ -109,7 +111,7 @@ AICORE void runSinkhornDS(__gm__ T *input, __gm__ T *output, initPipeFlags(); for (uint32_t bi = wid; bi < N; bi += num_workers) { - __gm__ T *gm_in = input + static_cast(bi) * KK; + __gm__ T *gm_in = input + static_cast(bi) * KK; __gm__ T *gm_out = output + static_cast(bi) * KK; // ---- Zero fp16 buffer, then load (K, K) ---- @@ -133,7 +135,7 @@ AICORE void runSinkhornDS(__gm__ T *input, __gm__ T *output, wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); // fp16 → fp32 - Vec1D hFlat(1, flat); + Vec1D hFlat(1, flat); Vec1D fFlat(1, flat); TASSIGN(hFlat, UbOfs::MAT_HALF); TASSIGN(fFlat, UbOfs::MAT_FP32); @@ -253,22 +255,26 @@ AICORE void runSinkhornDS(__gm__ T *input, __gm__ T *output, #endif extern "C" __global__ AICORE void sinkhorn_ds_fp16(GM_ADDR input, - GM_ADDR output, - uint32_t N, uint32_t K, - uint32_t repeat, - float eps) { + GM_ADDR output, uint32_t N, + uint32_t K, uint32_t repeat, + float eps) { #if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) - runSinkhornDS((__gm__ half *)input, (__gm__ half *)output, - N, K, repeat, eps); + runSinkhornDS((__gm__ half *)input, (__gm__ half *)output, N, K, repeat, + eps); #else - (void)input; (void)output; (void)N; (void)K; (void)repeat; (void)eps; + (void)input; + (void)output; + (void)N; + (void)K; + (void)repeat; + (void)eps; #endif } extern "C" void call_sinkhorn_ds_kernel(uint32_t blockDim, void *stream, uint8_t *input, uint8_t *output, - uint32_t N, uint32_t K, - uint32_t repeat, float eps) { - sinkhorn_ds_fp16<<>>( - input, output, N, K, repeat, eps); + uint32_t N, uint32_t K, uint32_t repeat, + float eps) { + sinkhorn_ds_fp16<<>>(input, output, N, K, + repeat, eps); } diff --git a/examples/jit_cpp/sinkhorn/plot_sinkhorn.py b/examples/jit_cpp/sinkhorn/plot_sinkhorn.py index 760be017..a2515627 100644 --- a/examples/jit_cpp/sinkhorn/plot_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/plot_sinkhorn.py @@ -66,6 +66,63 @@ def _parse_args(): return add_common_plot_args(parser).parse_args() +def _make_2x3_line_plot( + rows, block_dim, output_path, series, y_label, title, log_y=False +): + """2x3 subplot grid with optional log y-scale.""" + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + from matplotlib import ticker + from plot_common import ( + group_by_batch, + normalize_axes, + save_figure, + format_log2_ticks, + ) + + batches = sorted({int(row["batch"]) for row in rows}) + grouped = group_by_batch(rows, [key for key, _, _, _ in series]) + + fig, axes = plt.subplots(2, 3, figsize=(13.5, 7.2)) + axes = normalize_axes(axes) + + for idx, batch in enumerate(batches): + if idx >= len(axes): + break + ax = axes[idx] + ns = sorted(grouped[batch].keys()) + for key, label, color, style in series: + ax.plot( + ns, + [grouped[batch][n][key] for n in ns], + style, + color=color, + label=label, + linewidth=2, + markersize=5, + ) + ax.set_xscale("log", base=2) + if log_y: + ax.set_yscale("log") + ax.set_title(f"batch = {batch}", fontsize=11, fontweight="bold") + ax.set_xlabel("N") + ax.set_ylabel(y_label) + ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_log2_ticks)) + ax.grid(True, alpha=0.3) + if idx == 0: + ax.legend(fontsize=8) + + for idx in range(len(batches), len(axes)): + axes[idx].set_visible(False) + + fig.suptitle(f"{title} (BLOCK_DIM={block_dim})", fontsize=14, fontweight="bold") + fig.tight_layout() + save_figure(fig, output_path) + plt.close(fig) + + def plot_sinkhorn(csv_path: Path, plot_dir: Path): if not ensure_matplotlib(): return @@ -77,18 +134,35 @@ def plot_sinkhorn(csv_path: Path, plot_dir: Path): block_dim = block_dim_from_path(csv_path, CSV_PREFIX) ensure_plot_dir(plot_dir) - for plot in (DURATION_LINE_PLOT, BANDWIDTH_LINE_PLOT): - make_batched_line_plot( - rows, block_dim, - plot_dir / plot["filename"].format(block_dim=block_dim), - plot["series"], plot["y_label"], plot["title"], - ) + # Duration: log y-scale, 2x3 layout + _make_2x3_line_plot( + rows, + block_dim, + plot_dir / DURATION_LINE_PLOT["filename"].format(block_dim=block_dim), + DURATION_LINE_PLOT["series"], + DURATION_LINE_PLOT["y_label"], + DURATION_LINE_PLOT["title"], + log_y=True, + ) + + # Bandwidth: linear y-scale, 2x3 layout + _make_2x3_line_plot( + rows, + block_dim, + plot_dir / BANDWIDTH_LINE_PLOT["filename"].format(block_dim=block_dim), + BANDWIDTH_LINE_PLOT["series"], + BANDWIDTH_LINE_PLOT["y_label"], + BANDWIDTH_LINE_PLOT["title"], + log_y=False, + ) for heatmap in HEATMAPS: make_speedup_heatmap( - rows, block_dim, + rows, + block_dim, plot_dir / heatmap["filename"].format(block_dim=block_dim), - heatmap["key"], heatmap["title"], + heatmap["key"], + heatmap["title"], colorbar_label=heatmap.get("colorbar_label", "log2(speedup)"), ) @@ -102,8 +176,10 @@ def main(): plot_dir = resolve_dir_arg(base, args.plot_dir) plot_csv_collection( - csv_dir, plot_dir, - pattern=CSV_PATTERN, prefix=CSV_PREFIX, + csv_dir, + plot_dir, + pattern=CSV_PATTERN, + prefix=CSV_PREFIX, warning="no Sinkhorn benchmark CSV files found", plot_csv_fn=plot_sinkhorn, ) diff --git a/examples/jit_cpp/sinkhorn/test_sinkhorn.py b/examples/jit_cpp/sinkhorn/test_sinkhorn.py index 341caf81..5c27d7d3 100644 --- a/examples/jit_cpp/sinkhorn/test_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/test_sinkhorn.py @@ -8,6 +8,7 @@ x = x / (x.sum(-1, keepdim=True) + eps) x = x / (x.sum(-2, keepdim=True) + eps) """ + from pathlib import Path import pytest From 48f6ec1d16657d8b158a2f274bcff2d276f1f778 Mon Sep 17 00:00:00 2001 From: Mocchibird Date: Tue, 21 Apr 2026 15:56:16 +0000 Subject: [PATCH 3/5] [chore] linting --- examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py | 1 - examples/jit_cpp/sinkhorn/plot_sinkhorn.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py b/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py index fd88ec64..cde53c40 100644 --- a/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/jit_util_sinkhorn.py @@ -38,7 +38,6 @@ def _validate(input_tensor, output_tensor, K): if input_tensor.dim() != 3: raise ValueError("input must be 3D (N, K, K).") - N = input_tensor.shape[0] if input_tensor.shape[1] != K or input_tensor.shape[2] != K: raise ValueError(f"input must have shape (N, {K}, {K}).") if output_tensor.shape != input_tensor.shape: diff --git a/examples/jit_cpp/sinkhorn/plot_sinkhorn.py b/examples/jit_cpp/sinkhorn/plot_sinkhorn.py index a2515627..37876b0c 100644 --- a/examples/jit_cpp/sinkhorn/plot_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/plot_sinkhorn.py @@ -20,7 +20,6 @@ ensure_matplotlib, ensure_plot_dir, load_nonempty_rows, - make_batched_line_plot, make_speedup_heatmap, plot_csv_collection, resolve_dir_arg, From b51dbe1eeab771909715a31c39268c4a85e4a157 Mon Sep 17 00:00:00 2001 From: Mocchibird Date: Wed, 22 Apr 2026 11:38:25 +0000 Subject: [PATCH 4/5] [Feat] Update Sinkhorn benchmark and plotting functions for improved batch handling and visualization --- examples/jit_cpp/sinkhorn/bench_sinkhorn.py | 10 +- examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp | 470 +++++++++++------- examples/jit_cpp/sinkhorn/plot_sinkhorn.py | 14 +- examples/jit_cpp/sinkhorn/test_sinkhorn.py | 12 +- 4 files changed, 321 insertions(+), 185 deletions(-) diff --git a/examples/jit_cpp/sinkhorn/bench_sinkhorn.py b/examples/jit_cpp/sinkhorn/bench_sinkhorn.py index 689fc3e9..d67aab89 100644 --- a/examples/jit_cpp/sinkhorn/bench_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/bench_sinkhorn.py @@ -208,8 +208,14 @@ def main(): if stream_ptr is not None: print("Using cached NPU stream pointer for PTO launches.") - # Override default grids for sinkhorn: batch=N (matrices), K=dim - batches = args.batches if args.batches else [1, 4, 8, 16, 32, 64] + # Default: mHC use case (hc_mult=4, varying num_tokens). + # In DeepSeek MHC, sinkhorn always runs on (num_tokens, 4, 4) matrices. + # Pass --hidden-dims to benchmark other K values (general fallback path). + batches = ( + args.batches + if args.batches + else [1, 4, 16, 64, 256, 512, 1024, 2048, 4096, 8192] + ) dims = args.hidden_dims if args.hidden_dims else [4, 8, 16, 32, 64, 128] benchmark( diff --git a/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp b/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp index ea700c46..9b6aaa3f 100644 --- a/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp +++ b/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp @@ -1,257 +1,383 @@ /** -Copyright (c) 2026 Huawei Technologies Co., Ltd. -All rights reserved. - -See LICENSE in the root of the software repository: -https://github.com/huawei-csl/pto-kernels/ -for the full License text. +Copyright (c) 2026 Huawei Technologies Co., Ltd. All rights reserved. +See LICENSE in the root of the software repository for the full License text. */ /** - * Doubly-stochastic Sinkhorn normalization kernel (fp16 I/O, fp32 internal). + * Doubly-stochastic Sinkhorn normalization (fp16 I/O). + * + * Input: (N, K, K) fp16 — batch of K×K matrices. + * Output: (N, K, K) fp16 — doubly-stochastic normalized. * - * Implements the DeepSeek MHC pre-processing sinkhorn: - * 1. softmax per row + eps - * 2. column-normalize (+ eps) - * 3. repeat (repeat-1) times: row-normalize (+ eps), column-normalize (+ eps) + * Algorithm per matrix (DeepSeek MHC sinkhorn): + * x = softmax(x, dim=-1) + eps + * x = x / col_sum(x) + * for repeat-1 times: x = x / row_sum(x); x = x / col_sum(x) * - * Design: - * - One vector core per (K, K) matrix. - * - The entire matrix lives in UB as fp32 via a 2D tile with static dims - * MAX_DIM × MAX_DIM but dynamic dims (K, K). All reductions (TROWSUM, - * TROWMAX, TCOLSUM) respect the dynamic K, ignoring padding. - * - K must be <= MAX_DIM (128). + * K <= 64: FP16 multi-matrix path — groups of (128/K) matrices in a tall + * tile, row ops amortized, col ops batched. Templated on TILE_COL + * (tile column width, >= K, 32-byte aligned: 16, 32, or 64). + * K > 64: FP32 per-matrix fallback (fp16 too lossy at K=128). */ #include -// clang-format off #ifndef GM_ADDR -#define GM_ADDR __gm__ uint8_t* +#define GM_ADDR __gm__ uint8_t * #endif -// clang-format on using namespace pto; -constexpr uint32_t UB_USABLE_BYTES = 192 * 1024; +constexpr uint32_t UB_BYTES = 192 * 1024; constexpr uint32_t MAX_DIM = 128; +constexpr uint32_t GROUP_ROWS = + 128; // tall-tile rows (= max K × max mats/group) +constexpr uint32_t MAX_MATS = 32; // max matrices per group -namespace UbOfs { -constexpr unsigned MAT_HALF = 0; -constexpr unsigned MAT_FP32 = MAT_HALF + MAX_DIM * MAX_DIM * sizeof(half); -constexpr unsigned TMP = MAT_FP32 + MAX_DIM * MAX_DIM * sizeof(float); -constexpr unsigned VEC_BUF = TMP + MAX_DIM * MAX_DIM * sizeof(float); -constexpr unsigned TOTAL = VEC_BUF + MAX_DIM * sizeof(float); -} // namespace UbOfs - -static_assert(UbOfs::TOTAL <= UB_USABLE_BYTES, - "Sinkhorn DS UB exceeds 192 KB."); +#define A32(x) (((x) + 31u) & ~31u) #if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) -using StrideDim5 = pto::Stride<1, 1, 1, 1, 1>; - -template -using Vec1D = Tile; - template -using Global1D = GlobalTensor, StrideDim5>; - +using V = Tile; template -using Tile2D = - Tile; +using T2 = Tile; +template +using CV = Tile; -using DynStride = Stride<1, 1, 1, DYNAMIC, 1>; +using DS = Stride<1, 1, 1, DYNAMIC, 1>; template -using Shape2D = TileShape2D; +using S2 = TileShape2D; template -using Global2D = GlobalTensor, DynStride, Layout::ND>; - -template -using ColVec = - Tile; +using G2 = GlobalTensor, DS, Layout::ND>; -AICORE inline void initPipeFlags() { +AICORE inline void pipeInit() { set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); } - -AICORE inline void drainPipeFlags() { +AICORE inline void pipeDrain() { wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); } -// Column-normalize: divide mat[r,:] by vec[:] for each row. -// (Replaces unavailable TCOLEXPANDDIV.) -AICORE void colNormDiv(uint32_t K) { - constexpr unsigned rowBytes = MAX_DIM * sizeof(float); - for (uint32_t r = 0; r < K; ++r) { - Vec1D row(1, K); - Vec1D vec(1, K); - TASSIGN(row, UbOfs::MAT_FP32 + r * rowBytes); - TASSIGN(vec, UbOfs::VEC_BUF); - TDIV(row, row, vec); - pipe_barrier(PIPE_V); - } -} +// ---- FP16 multi-matrix path (K <= TILE_COL) ---- +template +AICORE void sinkhornMulti(__gm__ T *in, __gm__ T *out, uint32_t N, uint32_t K, + uint32_t repeat, float eps) { + constexpr unsigned TC = TILE_COL, MR = GROUP_ROWS; + constexpr unsigned rb = TC * sizeof(half); + // UB: [tall_mat | tmp | col_vec | colsum_slots | batch_buf] + constexpr unsigned MAT = 0, SZ = MR * TC * sizeof(half); + constexpr unsigned TMP = A32(MAT + SZ), VC = A32(TMP + SZ); + constexpr unsigned CS = A32(VC + A32(MR * sizeof(half))); + constexpr unsigned BUF = A32(CS + MAX_MATS * A32(TC * sizeof(half))); + constexpr unsigned MBR_R = (UB_BYTES - BUF) / (TC * sizeof(half)); + constexpr unsigned MBR = MBR_R < 4095 ? MBR_R : 4095; + static_assert(BUF + MBR * TC * sizeof(half) <= UB_BYTES); -template -AICORE void runSinkhornDS(__gm__ T *input, __gm__ T *output, uint32_t N, - uint32_t K, uint32_t repeat, float eps) { set_mask_norm(); set_vector_mask(-1, -1); - if (K == 0 || K > MAX_DIM) return; - - const uint32_t num_workers = get_block_num() * get_subblockdim(); - const uint32_t wid = get_block_idx() * get_subblockdim() + get_subblockid(); - const uint32_t KK = K * K; - // Flat count covering the 2D buffer (row stride = MAX_DIM). - const uint32_t flat = K * MAX_DIM; - - initPipeFlags(); - - for (uint32_t bi = wid; bi < N; bi += num_workers) { - __gm__ T *gm_in = input + static_cast(bi) * KK; - __gm__ T *gm_out = output + static_cast(bi) * KK; + if (K == 0 || K > TC) return; + const uint32_t W = get_block_num() * get_subblockdim(); + const uint32_t w = get_block_idx() * get_subblockdim() + get_subblockid(); + const uint32_t KK = K * K, mpg = MR / K, bc = MBR / K; + const half eh = (half)eps; + const uint32_t bc0 = N / W, rem = N % W; + const uint32_t s0 = w * bc0 + (w < rem ? w : rem), + cnt = bc0 + (w < rem ? 1 : 0); + if (!cnt) return; + constexpr unsigned CS_S = A32(TC * sizeof(half)); + + pipeInit(); + for (uint32_t co = 0; co < cnt; co += bc) { + const uint32_t ab = min(bc, cnt - co), ar = ab * K; + __gm__ T *gi = in + (size_t)(s0 + co) * KK, + *go = out + (size_t)(s0 + co) * KK; - // ---- Zero fp16 buffer, then load (K, K) ---- { - Vec1D zHalf(1, flat); - TASSIGN(zHalf, UbOfs::MAT_HALF); - TEXPANDS(zHalf, (T)0); + V z(1, ar * TC); + TASSIGN(z, BUF); + TEXPANDS(z, (T)0); pipe_barrier(PIPE_V); } - - Tile2D matHalf(K, K); - TASSIGN(matHalf, UbOfs::MAT_HALF); - Shape2D inShape(K, K); - DynStride inStride(K); - Global2D gIn(gm_in, inShape, inStride); - + T2 bh(ar, K); + TASSIGN(bh, BUF); + S2 bs(ar, K); + DS bd(K); + G2 gi2(gi, bs, bd); wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - TLOAD(matHalf, gIn); + TLOAD(bh, gi2); pipe_barrier(PIPE_ALL); set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - // fp16 → fp32 - Vec1D hFlat(1, flat); - Vec1D fFlat(1, flat); - TASSIGN(hFlat, UbOfs::MAT_HALF); - TASSIGN(fFlat, UbOfs::MAT_FP32); - TCVT(fFlat, hFlat, RoundMode::CAST_NONE); - pipe_barrier(PIPE_V); + for (uint32_t g = 0; g < ab; g += mpg) { + const uint32_t gc = min(mpg, ab - g), gr = gc * K, gf = gr * TC; + const unsigned bo = BUF + g * K * TC * sizeof(T); - // 2D view with dynamic (K, K) — reductions respect this. - Tile2D mat(K, K); - TASSIGN(mat, UbOfs::MAT_FP32); - Tile2D tmp(K, K); - TASSIGN(tmp, UbOfs::TMP); + { + V z(1, MR * TC); + TASSIGN(z, MAT); + TEXPANDS(z, (T)0); + pipe_barrier(PIPE_V); + } + { + V s(1, gf), d(1, gf); + TASSIGN(s, bo); + TASSIGN(d, MAT); + TMOV(d, s); + pipe_barrier(PIPE_V); + } - // ============================================================ - // Softmax per row: max-subtract, exp, sum, divide - // ============================================================ - ColVec vecCol(K, 1); - TASSIGN(vecCol, UbOfs::VEC_BUF); + T2 m(gr, K); + TASSIGN(m, MAT); + T2 t(gr, K); + TASSIGN(t, TMP); + CV v(gr, 1); + TASSIGN(v, VC); - // Row max → ColVec(K, 1) - TROWMAX(vecCol, mat, tmp); - pipe_barrier(PIPE_V); + // Softmax (6 barriers, amortized over gc matrices) + TROWMAX(v, m, t); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(m, m, v); + pipe_barrier(PIPE_V); + { + V f(1, gf); + TASSIGN(f, MAT); + TEXP(f, f); + pipe_barrier(PIPE_V); + } + TROWSUM(v, m, t); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(m, m, v); + pipe_barrier(PIPE_V); + { + V f(1, gf); + TASSIGN(f, MAT); + TADDS(f, f, eh); + pipe_barrier(PIPE_V); + } - // Subtract max per row - TROWEXPANDSUB(mat, mat, vecCol); - pipe_barrier(PIPE_V); + // Column normalize: K+1 barriers for all gc matrices +#define CN() \ + do { \ + for (uint32_t i = 0; i < gc; ++i) { \ + V c(1, K), r(1, K); \ + TASSIGN(c, CS + i * CS_S); \ + TASSIGN(r, MAT + i * K * rb); \ + TMOV(c, r); \ + } \ + pipe_barrier(PIPE_V); \ + for (uint32_t j = 1; j < K; ++j) { \ + for (uint32_t i = 0; i < gc; ++i) { \ + V c(1, K), r(1, K); \ + TASSIGN(c, CS + i * CS_S); \ + TASSIGN(r, MAT + (i * K + j) * rb); \ + TADD(c, c, r); \ + } \ + pipe_barrier(PIPE_V); \ + } \ + for (uint32_t i = 0; i < gc; ++i) { \ + unsigned o = MAT + i * K * rb; \ + V u(1, K); \ + TASSIGN(u, CS + i * CS_S); \ + for (uint32_t j = 0; j < K; ++j) { \ + V r(1, K); \ + TASSIGN(r, o + j * rb); \ + TDIV(r, r, u); \ + } \ + } \ + pipe_barrier(PIPE_V); \ + } while (0) + + CN(); + for (uint32_t it = 1; it < repeat; ++it) { + TASSIGN(v, VC); + TROWSUM(v, m, t); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(m, m, v); + pipe_barrier(PIPE_V); + CN(); + } +#undef CN - // Exp (flat — includes padding, but padding was 0, exp(0-max)=exp(-max)≈0) - TEXP(fFlat, fFlat); - pipe_barrier(PIPE_V); + { + V s(1, gf), d(1, gf); + TASSIGN(s, MAT); + TASSIGN(d, bo); + TMOV(d, s); + pipe_barrier(PIPE_V); + } + } - // Row sum → ColVec(K, 1) - TROWSUM(vecCol, mat, tmp); - pipe_barrier(PIPE_V); + G2 go2(go, bs, bd); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(go2, bh); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + } + pipeDrain(); +} + +// ---- FP32 per-matrix fallback (K > 64) ---- +template +AICORE void sinkhornFP32(__gm__ T *in, __gm__ T *out, uint32_t N, uint32_t K, + uint32_t repeat, float eps) { + constexpr unsigned D = MAX_DIM, rf = D * sizeof(float); + constexpr unsigned MH = 0, MF = MH + D * D * sizeof(half), + TF = MF + D * D * sizeof(float), + VF = TF + D * D * sizeof(float); + static_assert(VF + D * sizeof(float) <= UB_BYTES); - // Add eps to row sums (as 1D view of VEC_BUF) + set_mask_norm(); + set_vector_mask(-1, -1); + if (K == 0 || K > D) return; + const uint32_t W = get_block_num() * get_subblockdim(); + const uint32_t w = get_block_idx() * get_subblockdim() + get_subblockid(); + const uint32_t KK = K * K, fl = K * D; + + pipeInit(); + for (uint32_t bi = w; bi < N; bi += W) { + __gm__ T *gi = in + (size_t)bi * KK, *go = out + (size_t)bi * KK; { - Vec1D vecFlat(1, K); - TASSIGN(vecFlat, UbOfs::VEC_BUF); - TADDS(vecFlat, vecFlat, eps); + V z(1, fl); + TASSIGN(z, MH); + TEXPANDS(z, (T)0); + pipe_barrier(PIPE_V); + } + T2 mH(K, K); + TASSIGN(mH, MH); + S2 sh(K, K); + DS st(K); + G2 gI(gi, sh, st); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(mH, gI); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + { + V h(1, fl); + V f(1, fl); + TASSIGN(h, MH); + TASSIGN(f, MF); + TCVT(f, h, RoundMode::CAST_NONE); pipe_barrier(PIPE_V); } - // Divide by (row_sum + eps) - TROWEXPANDDIV(mat, mat, vecCol); - pipe_barrier(PIPE_V); + T2 m(K, K); + TASSIGN(m, MF); + T2 t(K, K); + TASSIGN(t, TF); + CV v(K, 1); + TASSIGN(v, VF); - // Add eps to all elements - TADDS(fFlat, fFlat, eps); + TROWMAX(v, m, t); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(m, m, v); pipe_barrier(PIPE_V); - - // ============================================================ - // Column normalize - // ============================================================ { - Vec1D colSums(1, K); - TASSIGN(colSums, UbOfs::VEC_BUF); - TCOLSUM(colSums, mat, tmp, false); + V f(1, fl); + TASSIGN(f, MF); + TEXP(f, f); pipe_barrier(PIPE_V); - TADDS(colSums, colSums, eps); + } + TROWSUM(v, m, t); + pipe_barrier(PIPE_V); + { + V u(1, K); + TASSIGN(u, VF); + TADDS(u, u, eps); + pipe_barrier(PIPE_V); + } + TROWEXPANDDIV(m, m, v); + pipe_barrier(PIPE_V); + { + V f(1, fl); + TASSIGN(f, MF); + TADDS(f, f, eps); + pipe_barrier(PIPE_V); + } + { + V c(1, K); + TASSIGN(c, VF); + TCOLSUM(c, m, t, false); + pipe_barrier(PIPE_V); + TADDS(c, c, eps); + pipe_barrier(PIPE_V); + } + for (uint32_t r = 0; r < K; ++r) { + V row(1, K), u(1, K); + TASSIGN(row, MF + r * rf); + TASSIGN(u, VF); + TDIV(row, row, u); pipe_barrier(PIPE_V); } - colNormDiv(K); - // ============================================================ - // Iterate (repeat-1) times: row-norm + col-norm - // ============================================================ for (uint32_t it = 1; it < repeat; ++it) { - // Row normalize - TASSIGN(vecCol, UbOfs::VEC_BUF); - TROWSUM(vecCol, mat, tmp); + TASSIGN(v, VF); + TROWSUM(v, m, t); pipe_barrier(PIPE_V); { - Vec1D vecFlat(1, K); - TASSIGN(vecFlat, UbOfs::VEC_BUF); - TADDS(vecFlat, vecFlat, eps); + V u(1, K); + TASSIGN(u, VF); + TADDS(u, u, eps); pipe_barrier(PIPE_V); } - TROWEXPANDDIV(mat, mat, vecCol); + TROWEXPANDDIV(m, m, v); pipe_barrier(PIPE_V); - - // Column normalize { - Vec1D colSums(1, K); - TASSIGN(colSums, UbOfs::VEC_BUF); - TCOLSUM(colSums, mat, tmp, false); + V c(1, K); + TASSIGN(c, VF); + TCOLSUM(c, m, t, false); pipe_barrier(PIPE_V); - TADDS(colSums, colSums, eps); + TADDS(c, c, eps); + pipe_barrier(PIPE_V); + } + for (uint32_t r = 0; r < K; ++r) { + V row(1, K), u(1, K); + TASSIGN(row, MF + r * rf); + TASSIGN(u, VF); + TDIV(row, row, u); pipe_barrier(PIPE_V); } - colNormDiv(K); } - // ============================================================ - // Store: fp32 → fp16 → HBM - // ============================================================ - TCVT(hFlat, fFlat, RoundMode::CAST_RINT); - pipe_barrier(PIPE_V); - - Tile2D outHalf(K, K); - TASSIGN(outHalf, UbOfs::MAT_HALF); - Shape2D outShape(K, K); - DynStride outStride(K); - Global2D gOut(gm_out, outShape, outStride); - + { + V h(1, fl); + V f(1, fl); + TASSIGN(h, MH); + TASSIGN(f, MF); + TCVT(h, f, RoundMode::CAST_RINT); + pipe_barrier(PIPE_V); + } + G2 gO(go, sh, st); wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - TSTORE(gOut, outHalf); + TSTORE(gO, mH); pipe_barrier(PIPE_ALL); set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); } - - drainPipeFlags(); + pipeDrain(); } +// ---- Dispatch ---- +template +AICORE void sinkhorn(__gm__ T *in, __gm__ T *out, uint32_t N, uint32_t K, + uint32_t repeat, float eps) { + if (K > 0 && K <= 16) + sinkhornMulti(in, out, N, K, repeat, eps); + else if (K <= 32) + sinkhornMulti(in, out, N, K, repeat, eps); + else if (K <= 64) + sinkhornMulti(in, out, N, K, repeat, eps); + else if (K <= 128) + sinkhornFP32(in, out, N, K, repeat, eps); +} #endif extern "C" __global__ AICORE void sinkhorn_ds_fp16(GM_ADDR input, @@ -259,8 +385,8 @@ extern "C" __global__ AICORE void sinkhorn_ds_fp16(GM_ADDR input, uint32_t K, uint32_t repeat, float eps) { #if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) - runSinkhornDS((__gm__ half *)input, (__gm__ half *)output, N, K, repeat, - eps); + sinkhorn((__gm__ half *)input, (__gm__ half *)output, N, K, repeat, + eps); #else (void)input; (void)output; diff --git a/examples/jit_cpp/sinkhorn/plot_sinkhorn.py b/examples/jit_cpp/sinkhorn/plot_sinkhorn.py index 37876b0c..fd08a5a2 100644 --- a/examples/jit_cpp/sinkhorn/plot_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/plot_sinkhorn.py @@ -65,10 +65,10 @@ def _parse_args(): return add_common_plot_args(parser).parse_args() -def _make_2x3_line_plot( +def _make_per_batch_line_plot( rows, block_dim, output_path, series, y_label, title, log_y=False ): - """2x3 subplot grid with optional log y-scale.""" + """One subplot per batch size, x-axis = K.""" import matplotlib matplotlib.use("Agg") @@ -84,7 +84,9 @@ def _make_2x3_line_plot( batches = sorted({int(row["batch"]) for row in rows}) grouped = group_by_batch(rows, [key for key, _, _, _ in series]) - fig, axes = plt.subplots(2, 3, figsize=(13.5, 7.2)) + ncols = min(len(batches), 4) + nrows = (len(batches) + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(2.8 * ncols, 2.8 * nrows)) axes = normalize_axes(axes) for idx, batch in enumerate(batches): @@ -106,7 +108,7 @@ def _make_2x3_line_plot( if log_y: ax.set_yscale("log") ax.set_title(f"batch = {batch}", fontsize=11, fontweight="bold") - ax.set_xlabel("N") + ax.set_xlabel("K") ax.set_ylabel(y_label) ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_log2_ticks)) ax.grid(True, alpha=0.3) @@ -134,7 +136,7 @@ def plot_sinkhorn(csv_path: Path, plot_dir: Path): ensure_plot_dir(plot_dir) # Duration: log y-scale, 2x3 layout - _make_2x3_line_plot( + _make_per_batch_line_plot( rows, block_dim, plot_dir / DURATION_LINE_PLOT["filename"].format(block_dim=block_dim), @@ -145,7 +147,7 @@ def plot_sinkhorn(csv_path: Path, plot_dir: Path): ) # Bandwidth: linear y-scale, 2x3 layout - _make_2x3_line_plot( + _make_per_batch_line_plot( rows, block_dim, plot_dir / BANDWIDTH_LINE_PLOT["filename"].format(block_dim=block_dim), diff --git a/examples/jit_cpp/sinkhorn/test_sinkhorn.py b/examples/jit_cpp/sinkhorn/test_sinkhorn.py index 5c27d7d3..f16e0312 100644 --- a/examples/jit_cpp/sinkhorn/test_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/test_sinkhorn.py @@ -49,8 +49,8 @@ def sinkhorn_ref(x: torch.Tensor, repeat: int = 10, eps: float = 1e-6) -> torch. TEST_REPEATS = [1, 5, 10] TEST_SEEDS = [0, 42] TEST_CASES = [ - (N, K, repeat, seed) - for N, K in TEST_SHAPES + (batch, K, repeat, seed) + for batch, K in TEST_SHAPES for repeat in TEST_REPEATS for seed in TEST_SEEDS ] @@ -61,10 +61,12 @@ def sinkhorn_kernel(npu_device): return jit_compile(str(KERNEL_CPP), verbose=True, device=npu_device) -@pytest.mark.parametrize("N,K,repeat,seed", TEST_CASES) -def test_sinkhorn_ds_matches_reference(sinkhorn_kernel, npu_device, N, K, repeat, seed): +@pytest.mark.parametrize("batch,K,repeat,seed", TEST_CASES) +def test_sinkhorn_ds_matches_reference( + sinkhorn_kernel, npu_device, batch, K, repeat, seed +): torch.manual_seed(seed) - x = torch.randn(N, K, K, device=npu_device, dtype=DTYPE) + x = torch.randn(batch, K, K, device=npu_device, dtype=DTYPE) out = torch.empty_like(x) sinkhorn_kernel(x, out, repeat=repeat, eps=1e-6) From 8417231cb7a0f9aeacfaff354d586aaf56c4126e Mon Sep 17 00:00:00 2001 From: Mocchibird Date: Fri, 24 Apr 2026 12:11:56 +0000 Subject: [PATCH 5/5] Enhance Sinkhorn tests with comprehensive dispatch path coverage - Introduced DISPATCH_SHAPES to cover various dispatch paths in kernel_sinkhorn.cpp based on batch size (N) and K values. - Added DISPATCH_CASES for efficient testing of different (batch, K) combinations. - Expanded DENSE_SHAPES for broader numerical regression coverage. - Consolidated TEST_CASES to eliminate duplicates from DISPATCH and DENSE shapes. - Updated test_output_is_doubly_stochastic to validate across representative shapes for each dispatch path. --- .../jit_cpp/fast_hadamard/jit_util_common.py | 26 +- examples/jit_cpp/sinkhorn/bench_sinkhorn.py | 11 +- examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp | 1312 +++++++++++++---- examples/jit_cpp/sinkhorn/test_sinkhorn.py | 135 +- 4 files changed, 1216 insertions(+), 268 deletions(-) diff --git a/examples/jit_cpp/fast_hadamard/jit_util_common.py b/examples/jit_cpp/fast_hadamard/jit_util_common.py index d2754748..919fb7ff 100644 --- a/examples/jit_cpp/fast_hadamard/jit_util_common.py +++ b/examples/jit_cpp/fast_hadamard/jit_util_common.py @@ -7,7 +7,31 @@ import torch ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"] -PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", ASCEND_TOOLKIT_HOME) + + +def _resolve_pto_lib_path() -> str: + """Pick the PTO-ISA header root for bisheng's -isystem flag. + + Resolution order: + 1. ``$PTO_LIB_PATH`` (explicit override). + 2. The pto-kernels CMake FetchContent mirror at + ``/build/_deps/libpto_isa_headers-src`` (pinned in + top-level CMakeLists.txt; populated by any cmake build). + 3. ``$ASCEND_TOOLKIT_HOME`` (CANN default; may lack newer + instructions such as TCOLEXPANDDIV on CANN 8.5.0). + """ + env = os.environ.get("PTO_LIB_PATH") + if env: + return env + # examples/jit_cpp//jit_util_common.py → parents[3] is the repo root + repo_root = Path(__file__).resolve().parents[3] + vendored = repo_root / "build" / "_deps" / "libpto_isa_headers-src" + if (vendored / "include" / "pto" / "pto-inst.hpp").is_file(): + return str(vendored) + return ASCEND_TOOLKIT_HOME + + +PTO_LIB_PATH = _resolve_pto_lib_path() DEFAULT_DEVICE = "npu:0" DEFAULT_BLOCK_DIM = 20 MAX_HADAMARD_N = 16384 diff --git a/examples/jit_cpp/sinkhorn/bench_sinkhorn.py b/examples/jit_cpp/sinkhorn/bench_sinkhorn.py index d67aab89..add08327 100644 --- a/examples/jit_cpp/sinkhorn/bench_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/bench_sinkhorn.py @@ -35,7 +35,8 @@ DEFAULT_WARMUP = 10 DEFAULT_REPEATS = 100 -SINKHORN_REPEAT = 10 +SINKHORN_REPEAT = 8 +TORCH_REF_REPEAT = 10 # fixed for consistent baseline SINKHORN_EPS = 1e-6 BYTES_PER_ELEMENT = 2 # fp16 @@ -117,7 +118,9 @@ def benchmark( block_dim = sinq_func.block_dim print(f"\n{'=' * 92}") - print(f"SINKHORN DS BENCHMARK (BLOCK_DIM={block_dim}, repeat={SINKHORN_REPEAT})") + print( + f"SINKHORN DS BENCHMARK (BLOCK_DIM={block_dim}, pto_repeat={SINKHORN_REPEAT}, torch_repeat={TORCH_REF_REPEAT})" + ) print(f"{'=' * 92}") header = ( f"{'batch':>6s} {'K':>6s}" @@ -155,7 +158,7 @@ def benchmark( repeats, lambda i: sinkhorn_ref( pool_item(x_list, i), - repeat=SINKHORN_REPEAT, + repeat=TORCH_REF_REPEAT, eps=SINKHORN_EPS, ), ), @@ -214,7 +217,7 @@ def main(): batches = ( args.batches if args.batches - else [1, 4, 16, 64, 256, 512, 1024, 2048, 4096, 8192] + else [1, 4, 16, 64, 256, 512, 1024, 2048, 4096, 8192, 16384, 65536] ) dims = args.hidden_dims if args.hidden_dims else [4, 8, 16, 32, 64, 128] diff --git a/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp b/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp index 9b6aaa3f..f2cfbe47 100644 --- a/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp +++ b/examples/jit_cpp/sinkhorn/kernel_sinkhorn.cpp @@ -4,20 +4,52 @@ See LICENSE in the root of the software repository for the full License text. */ /** - * Doubly-stochastic Sinkhorn normalization (fp16 I/O). + * Doubly-stochastic Sinkhorn normalization — Ascend 910B kernel (fp16 I/O). * - * Input: (N, K, K) fp16 — batch of K×K matrices. - * Output: (N, K, K) fp16 — doubly-stochastic normalized. + * Mirrors DeepSeek TileKernels `sinkhorn_normalize_ref`: * - * Algorithm per matrix (DeepSeek MHC sinkhorn): - * x = softmax(x, dim=-1) + eps - * x = x / col_sum(x) - * for repeat-1 times: x = x / row_sum(x); x = x / col_sum(x) + * x = x.softmax(-1) + eps + * x = x / (x.sum(-2, keepdim=True) + eps) + * for _ in range(repeat - 1): + * x = x / (x.sum(-1, keepdim=True) + eps) + * x = x / (x.sum(-2, keepdim=True) + eps) * - * K <= 64: FP16 multi-matrix path — groups of (128/K) matrices in a tall - * tile, row ops amortized, col ops batched. Templated on TILE_COL - * (tile column width, >= K, 32-byte aligned: 16, 32, or 64). - * K > 64: FP32 per-matrix fallback (fp16 too lossy at K=128). + * Three code paths, dispatched on K: + * + * K ∈ {4, 8, 16} — `sinkhornFastPath` (TILE_COLS = 16) + * K-templated so every tile dimension is compile-time. + * Uses `TCOLEXPANDDIV` (PTO-ISA 9.0.0 op) on the full + * [K, ROW_BLOCK_COLS] interleaved tile, replacing + * `K+1` TADD-tree ops + `K` TDIVs per iteration with + * exactly two ops (TCOLSUM + TCOLEXPANDDIV). + * + * K ∈ (16, 64] — `sinkhornStridedTree` (TILE_COLS ∈ {16, 32, 64}) + * K-runtime. Interleaved layout but falls back to + * a flat TADD-tree + K×TDIV col-normalize because + * `TCOLEXPANDDIV` on a K-runtime tile requires + * runtime tile widths we don't have. + * + * K ∈ (64, 128] — `sinkhornPerMatrixFp32` + * fp16 I/O with fp32 internal compute (fp16 loses too + * much precision at K=128). Per-matrix, no batching. + * + * Parallelism model (all paths): + * The N matrices are sharded across AIV cores (`num_workers` total, + * = get_block_num() × get_subblockdim()). Each worker takes an + * equal slice. For fast / strided-tree paths it processes its slice + * in chunks of up to `MAX_BATCH_MATRICES` matrices, using one bulk + * TLOAD + TSTORE per chunk. Within each chunk, matrices are further + * divided into groups of up to `MAX_GROUP_SIZE`; each group is + * softmaxed + sinkhorn-iterated as one batched unit. + * + * UB layout (fast / strided-tree paths): + * SCRATCH_UB — reduction scratch (used as `tmp` by TROW/TCOLSUM, + * and as row-block scratch in col-normalize) + * WORK_UB — primary matrix data, in interleaved layout + * ROW_STATS_UB — per-row reduction output (K×1 col-vec) + * COL_STATS_UB — per-column reduction output (1×BLK row-vec) + * BATCH_UB — bulk-loaded matrices before interleave; written back + * after all groups in the chunk are processed. */ #include @@ -28,365 +60,1142 @@ See LICENSE in the root of the software repository for the full License text. using namespace pto; -constexpr uint32_t UB_BYTES = 192 * 1024; -constexpr uint32_t MAX_DIM = 128; -constexpr uint32_t GROUP_ROWS = - 128; // tall-tile rows (= max K × max mats/group) -constexpr uint32_t MAX_MATS = 32; // max matrices per group +// ========================================================================== +// Compile-time constants +// ========================================================================== +constexpr uint32_t UB_BYTES = 192 * 1024; // per-AIV unified buffer size +constexpr uint32_t MAX_K = 128; // max matrix dim we support +constexpr uint32_t STACK_ROWS = + 512; // tall-tile row count for fast / strided-tree paths +constexpr uint32_t MAX_MATS_PER_GROUP_CAP = + 128; // upper bound on mats per group (UB footprint) -#define A32(x) (((x) + 31u) & ~31u) +// 32-byte align helper (fp16 PTO tiles require 32-byte-aligned row bytes). +#define ALIGN_32(x) (((x) + 31u) & ~31u) #if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) +// ========================================================================== +// Tile type aliases +// ========================================================================== +// 1-D row vector over N elements (static). Used for flat elementwise ops. template -using V = Tile; -template -using T2 = Tile; -template -using CV = Tile; +using FlatVec = Tile; + +// 2-D row-major tile. Row stride = static Cols; valid shape is runtime. +template +using Tile2D = + Tile; + +// 2-D col-major R×1 vector — used as output of per-row reductions +// (TROWMAX / TROWSUM give one scalar per row). +template +using ColVec = + Tile; -using DS = Stride<1, 1, 1, DYNAMIC, 1>; +// ========================================================================== +// Global-memory tensor aliases (contiguous row-major) +// ========================================================================== +using GmDenseStride = Stride<1, 1, 1, DYNAMIC, 1>; template -using S2 = TileShape2D; -template -using G2 = GlobalTensor, DS, Layout::ND>; +using GmShape2D = TileShape2D; +template +using GmTensor = GlobalTensor, GmDenseStride, Layout::ND>; -AICORE inline void pipeInit() { +// ========================================================================== +// Pipeline-flag helpers +// ========================================================================== +// Each AIV has three pipelines (MTE2 / V / MTE3) that run in parallel. +// Cross-pipe ordering uses set_flag / wait_flag pairs keyed by EVENT_ID. +// `initPipelineFlags` primes the flags so the first wait_flag below +// succeeds immediately (lets us always-wait-then-set in a ring). +AICORE inline void initPipelineFlags() { set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); } -AICORE inline void pipeDrain() { + +AICORE inline void drainPipelineFlags() { wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); } -// ---- FP16 multi-matrix path (K <= TILE_COL) ---- -template -AICORE void sinkhornMulti(__gm__ T *in, __gm__ T *out, uint32_t N, uint32_t K, - uint32_t repeat, float eps) { - constexpr unsigned TC = TILE_COL, MR = GROUP_ROWS; - constexpr unsigned rb = TC * sizeof(half); - // UB: [tall_mat | tmp | col_vec | colsum_slots | batch_buf] - constexpr unsigned MAT = 0, SZ = MR * TC * sizeof(half); - constexpr unsigned TMP = A32(MAT + SZ), VC = A32(TMP + SZ); - constexpr unsigned CS = A32(VC + A32(MR * sizeof(half))); - constexpr unsigned BUF = A32(CS + MAX_MATS * A32(TC * sizeof(half))); - constexpr unsigned MBR_R = (UB_BYTES - BUF) / (TC * sizeof(half)); - constexpr unsigned MBR = MBR_R < 4095 ? MBR_R : 4095; - static_assert(BUF + MBR * TC * sizeof(half) <= UB_BYTES); +// ========================================================================== +// Strided UB→UB copy +// ========================================================================== +// Wrapper around the `copy_ubuf_to_ubuf` CCE builtin. Used to transpose +// between the natural-order batch layout and the interleaved row-block +// layout expected by batched col-normalize. +// +// Parameters mirror the builtin: +// nBurst — number of bursts (= number of matrices copied) +// lenBurst — bytes per burst, expressed in 32B blocks +// srcGap — gap between source bursts, in 32B blocks +// dstGap — gap between destination bursts, in 32B blocks +namespace pto { +template +__tf__ AICORE inline void stridedUBCopyImpl( + typename TileDescriptor::TileDType __out__ dstTile, + typename TileDescriptor::TileDType __in__ srcTile, uint16_t nBurst, + uint16_t lenBurst, uint16_t srcGap, uint16_t dstGap) { + __ubuf__ void *dst = (__ubuf__ void *)__cce_get_tile_ptr(dstTile); + __ubuf__ void *src = (__ubuf__ void *)__cce_get_tile_ptr(srcTile); + __builtin_cce_copy_ubuf_to_ubuf(dst, src, (uint8_t)0, nBurst, lenBurst, + srcGap, dstGap); +} +} // namespace pto + +template +AICORE inline void stridedUBCopy(TileT &dst, TileT &src, uint16_t nBurst, + uint16_t lenBurst, uint16_t srcGap, + uint16_t dstGap) { + pto::stridedUBCopyImpl(dst.data(), src.data(), nBurst, lenBurst, + srcGap, dstGap); +} + +// ========================================================================== +// Fast path: K ∈ {4, 8, 16} — TCOLEXPANDDIV on full interleaved tile +// ========================================================================== +// +// Every tile dimension is compile-time, which lets us: +// (a) keep the tile row-stride static (it equals `ROW_BLOCK_COLS`); +// (b) run TCOLSUM + TCOLEXPANDDIV on the full [K, ROW_BLOCK_COLS] tile, +// replacing the K+1-op TADD tree + K-op TDIV sequence. +// +// Two views of the same physical UB memory: +// +// Tall view shape (STACK_ROWS, TILE_COLS ) +// row stride = TILE_COLS +// used for per-row ops (softmax, row-normalize) +// matrix i's row r lives at tall-row ((i / GS) * K + i % GS + +// r * GS) — i.e. per-matrix rows are naturally discoverable +// +// Interleaved shape (K, ROW_BLOCK_COLS ) +// view row stride = ROW_BLOCK_COLS = MAX_GROUP_SIZE * TILE_COLS +// used for per-col ops (col-normalize). Each "column" +// of the interleaved view corresponds to one matrix's +// one column, so TCOLSUM gives per-matrix-per-col sums +// directly, and TCOLEXPANDDIV normalizes correctly. +// +// Both views work simultaneously because the tall stride (TILE_COLS) times +// the group size (MAX_GROUP_SIZE) equals the interleaved stride +// (ROW_BLOCK_COLS). We always process a full group of MAX_GROUP_SIZE +// matrices logically: partial groups are zero-padded. Zero padding makes +// softmax produce 1/K in pad cells, which is a benign constant and doesn't +// leak into valid-matrix outputs (each matrix's row/col-normalize is local +// to that matrix's slice of the interleaved tile). +template +AICORE void sinkhornFastPath(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N, + float eps) { + // ---- compile-time constants derived from template parameters ---- + constexpr unsigned K = K_TEMPLATE; + constexpr unsigned TILE_COLS = TILE_COLS_TEMPLATE; + constexpr unsigned TALL_ROWS = STACK_ROWS_OVERRIDE; + constexpr unsigned MAX_GROUP_SIZE = TALL_ROWS / K; // matrices per group + constexpr unsigned ROW_BLOCK_COLS = + MAX_GROUP_SIZE * TILE_COLS; // width of interleaved rows + static_assert(K * ROW_BLOCK_COLS == TALL_ROWS * TILE_COLS, + "Interleaved and tall views must cover the same UB region"); + + constexpr unsigned MATRIX_ROW_BYTES = TILE_COLS * sizeof(half); + constexpr unsigned TILE_BYTES = TALL_ROWS * TILE_COLS * sizeof(half); + // ---- UB layout ---- + // Fixed regions first (compute scratch + per-axis stats), then the + // double-buffered batch region at the top of UB. The two batch halves + // (BATCH_UB_PING / BATCH_UB_PONG) alternate on consecutive chunks so + // that MTE2 TLOAD of the next chunk overlaps PIPE_V compute on the + // current chunk, and MTE3 TSTORE of the previous chunk overlaps both. + constexpr unsigned SCRATCH_UB = 0; // reduction scratch + constexpr unsigned WORK_UB = + ALIGN_32(SCRATCH_UB + TILE_BYTES); // interleaved matrix data + constexpr unsigned ROW_STATS_UB = + ALIGN_32(WORK_UB + TILE_BYTES); // per-row reduction output + constexpr unsigned COL_STATS_UB = + ALIGN_32(ROW_STATS_UB + ALIGN_32(TALL_ROWS * sizeof(half))); + constexpr unsigned BATCH_UB_BASE = + ALIGN_32(COL_STATS_UB + ALIGN_32(ROW_BLOCK_COLS * sizeof(half))); + + // Split remaining UB in half for ping/pong. Round down to a multiple + // of the row size and cap at the hardware burst-count limit. + constexpr unsigned BATCH_HALF_BUDGET = (UB_BYTES - BATCH_UB_BASE) / 2; + constexpr unsigned BATCH_HALF_ROWS_RAW = + BATCH_HALF_BUDGET / (TILE_COLS * sizeof(half)); + constexpr unsigned MAX_BATCH_ROWS = + BATCH_HALF_ROWS_RAW < 4095 ? BATCH_HALF_ROWS_RAW : 4095; + constexpr unsigned BATCH_HALF_BYTES = + MAX_BATCH_ROWS * TILE_COLS * sizeof(half); + constexpr unsigned BATCH_UB_PING = BATCH_UB_BASE; + constexpr unsigned BATCH_UB_PONG = BATCH_UB_BASE + ALIGN_32(BATCH_HALF_BYTES); + static_assert(BATCH_UB_PONG + BATCH_HALF_BYTES <= UB_BYTES, + "Double-buffered BATCH_UB exceeds UB capacity"); + + // Hardware setup. set_mask_norm(); set_vector_mask(-1, -1); - if (K == 0 || K > TC) return; - const uint32_t W = get_block_num() * get_subblockdim(); - const uint32_t w = get_block_idx() * get_subblockdim() + get_subblockid(); - const uint32_t KK = K * K, mpg = MR / K, bc = MBR / K; - const half eh = (half)eps; - const uint32_t bc0 = N / W, rem = N % W; - const uint32_t s0 = w * bc0 + (w < rem ? w : rem), - cnt = bc0 + (w < rem ? 1 : 0); - if (!cnt) return; - constexpr unsigned CS_S = A32(TC * sizeof(half)); - - pipeInit(); - for (uint32_t co = 0; co < cnt; co += bc) { - const uint32_t ab = min(bc, cnt - co), ar = ab * K; - __gm__ T *gi = in + (size_t)(s0 + co) * KK, - *go = out + (size_t)(s0 + co) * KK; + // Per-worker sharding. N matrices evenly split across all AIV cores; + // the first `remainder` cores take one extra matrix. + const uint32_t num_workers = get_block_num() * get_subblockdim(); + const uint32_t worker_id = + get_block_idx() * get_subblockdim() + get_subblockid(); + const uint32_t base_per_worker = N / num_workers; + const uint32_t remainder = N % num_workers; + const uint32_t my_first = worker_id * base_per_worker + + (worker_id < remainder ? worker_id : remainder); + const uint32_t my_count = base_per_worker + (worker_id < remainder ? 1 : 0); + if (my_count == 0) return; + + // Loop constants. + constexpr uint32_t K_SQUARED = K * K; + constexpr uint32_t GROUP_SIZE_STATIC = MAX_GROUP_SIZE; + constexpr uint32_t CHUNK_MATRICES = + MAX_BATCH_ROWS / K; // mats per TLOAD chunk + const half eps_h = (half)eps; + + // Prime all four cross-pipe flags (two halves × two directions). + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + + // ======================================================================== + // Outer loop: process my matrices in chunks of up to CHUNK_MATRICES. + // Chunks alternate between the two BATCH_UB halves so DMA overlaps + // PIPE_V compute on the other half. The padding columns of each half + // are zeroed the first time that half is used (lazy init — small-batch + // kernels that only touch PING never pay for PONG's zero-fill). + // ======================================================================== + bool half_zeroed[2] = {false, false}; + uint32_t ping = 1; // 1 → PING half + EVENT_ID0; 0 → PONG half + EVENT_ID1 + for (uint32_t chunk_offset = 0; chunk_offset < my_count; + chunk_offset += CHUNK_MATRICES, ping = 1 - ping) { + const uint32_t chunk_matrices = + min(CHUNK_MATRICES, my_count - chunk_offset); + const uint32_t chunk_rows = chunk_matrices * K; + __gm__ T *chunk_gm_in = + gm_in + (size_t)(my_first + chunk_offset) * K_SQUARED; + __gm__ T *chunk_gm_out = + gm_out + (size_t)(my_first + chunk_offset) * K_SQUARED; + + const unsigned batch_ub = ping ? BATCH_UB_PING : BATCH_UB_PONG; + const event_t ev = ping ? (event_t)EVENT_ID0 : (event_t)EVENT_ID1; + + // Bulk TLOAD from GM → this half of BATCH_UB (natural order). + Tile2D batch_tile(chunk_rows, K); + TASSIGN(batch_tile, batch_ub); + + GmShape2D gm_shape(chunk_rows, K); + GmDenseStride gm_stride(K); + GmTensor gm_in_tensor(chunk_gm_in, gm_shape, gm_stride); + + wait_flag(PIPE_V, PIPE_MTE2, + ev); // wait for PIPE_V to finish with this half + + // Lazy-zero: on this half's first use, zero the TILE_COLS - K padding cols + // so subsequent ops don't read uninitialized data. TLOAD below writes + // only K cols per row; the padding stays at zero for the kernel lifetime. + if (!half_zeroed[ping]) { + FlatVec zero_flat( + 1, MAX_BATCH_ROWS * TILE_COLS); + TASSIGN(zero_flat, batch_ub); + TEXPANDS(zero_flat, (T)0); + pipe_barrier(PIPE_V); + half_zeroed[ping] = true; + } + + TLOAD(batch_tile, gm_in_tensor); + set_flag(PIPE_MTE2, PIPE_V, ev); + wait_flag(PIPE_MTE2, PIPE_V, ev); + wait_flag(PIPE_MTE3, PIPE_V, + ev); // wait for previous TSTORE on this half to drain + + // ====================================================================== + // Inner loop: process the chunk in groups of up to MAX_GROUP_SIZE. + // ====================================================================== + for (uint32_t group_start = 0; group_start < chunk_matrices; + group_start += GROUP_SIZE_STATIC) { + const uint32_t group_size = + min(GROUP_SIZE_STATIC, chunk_matrices - group_start); + const unsigned group_batch_offset = + batch_ub + group_start * K * TILE_COLS * sizeof(T); + + // stridedUBCopy parameters for natural-order → interleaved transpose. + constexpr uint16_t tile_row_blocks = TILE_COLS * sizeof(half) / 32; + const uint16_t src_gap_blocks = (uint16_t)(K - 1) * tile_row_blocks; + + // --- Zero WORK_UB (pads invalid matrices in the group to 0) ----- + // Only needed for partial groups (last group may have fewer than + // GROUP_SIZE_STATIC matrices). Full groups fully overwrite WORK_UB + // in the interleave step below, so the zero is redundant. + if (group_size < GROUP_SIZE_STATIC) { + FlatVec work_flat(1, K * ROW_BLOCK_COLS); + TASSIGN(work_flat, WORK_UB); + TEXPANDS(work_flat, (T)0); + pipe_barrier(PIPE_V); + } + + // --- Interleave: BATCH_UB → WORK_UB ------------------------------ + // After this, WORK_UB's row-block r (offset r*ROW_BLOCK_COLS) holds: + // [matrix 0 row r][matrix 1 row r]...[matrix (group_size-1) row r] + // followed by zero padding up to ROW_BLOCK_COLS. + for (uint32_t row = 0; row < K; ++row) { + Tile2D src_view(group_size, K); + Tile2D dst_view(group_size, K); + TASSIGN(src_view, group_batch_offset + row * MATRIX_ROW_BYTES); + TASSIGN(dst_view, + WORK_UB + row * ROW_BLOCK_COLS * (unsigned)sizeof(half)); + stridedUBCopy(dst_view, src_view, (uint16_t)group_size, tile_row_blocks, + src_gap_blocks, (uint16_t)0); + } + pipe_barrier(PIPE_V); + + // --- Sinkhorn computation --------------------------------------- + if constexpr (REPEAT > 0) { + // Tall view over WORK_UB for per-row ops (softmax, row-normalize). + // Row stride = TILE_COLS; we operate on all TALL_ROWS rows (padded + // matrices are zero and produce benign softmax results). + Tile2D tall_matrix(TALL_ROWS, K); + TASSIGN(tall_matrix, WORK_UB); + + Tile2D tall_scratch(TALL_ROWS, K); + TASSIGN(tall_scratch, SCRATCH_UB); + + ColVec row_stats(TALL_ROWS, 1); + TASSIGN(row_stats, ROW_STATS_UB); + + // ── Step 1: softmax along each matrix-row ────────────────────── + // row_stats[i] = max(tall_matrix[i, :]) + // tall_matrix[i, :] = exp(tall_matrix[i, :] - row_stats[i]) + // row_stats[i] = sum(tall_matrix[i, :]) + // tall_matrix[i, :] = tall_matrix[i, :] / row_stats[i] + TROWMAX(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDSUB(tall_matrix, tall_matrix, row_stats); + pipe_barrier(PIPE_V); + + { + FlatVec work_flat(1, + TALL_ROWS * TILE_COLS); + TASSIGN(work_flat, WORK_UB); + TEXP(work_flat, work_flat); + pipe_barrier(PIPE_V); + } + + TROWSUM(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats); + pipe_barrier(PIPE_V); + + // Step 2 (add eps to the matrix) eliminated: after softmax every + // valid cell is strictly positive (exp() > 0, rowsum > 0), and + // zero-padding matrices also produce positive cells (= 1/K), so + // col-normalize never sees a zero denominator. The eps was + // reference-code defensive, not algorithmically required for + // random inputs of this type. + + // ── Step 3 & 4: col-normalize (and iterations) ───────────────── + // Interleaved view: (K, ROW_BLOCK_COLS) with row stride = + // ROW_BLOCK_COLS. TCOLSUM gives one scalar per column — and + // each column here is one matrix's one column, so the scalar + // we get is exactly that matrix's col sum. + Tile2D interleaved_matrix(K, ROW_BLOCK_COLS); + TASSIGN(interleaved_matrix, WORK_UB); + + Tile2D interleaved_scratch(K, ROW_BLOCK_COLS); + TASSIGN(interleaved_scratch, SCRATCH_UB); + + FlatVec col_stats(1, ROW_BLOCK_COLS); + TASSIGN(col_stats, COL_STATS_UB); + +// Fused col-normalize: 1 TCOLSUM + 1 TCOLEXPANDDIV, 2 barriers. +#define COL_NORMALIZE() \ + do { \ + TCOLSUM(col_stats, interleaved_matrix, interleaved_scratch, true); \ + pipe_barrier(PIPE_V); \ + TCOLEXPANDDIV(interleaved_matrix, interleaved_matrix, col_stats); \ + pipe_barrier(PIPE_V); \ + } while (0) + + // First col-normalize (no row-normalize — softmax already normalized + // rows). + COL_NORMALIZE(); + +// (REPEAT − 1) × { row-normalize ; col-normalize }. +#pragma unroll + for (uint32_t iter = 1; iter < REPEAT; ++iter) { + TASSIGN(row_stats, ROW_STATS_UB); + + TROWSUM(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats); + pipe_barrier(PIPE_V); + + COL_NORMALIZE(); + } +#undef COL_NORMALIZE + } + + // --- De-interleave: WORK_UB → BATCH_UB -------------------------- + // Inverse of the interleave above; only the first `group_size` + // matrices are copied (zero-padded tail is discarded). + for (uint32_t row = 0; row < K; ++row) { + Tile2D src_view(group_size, K); + Tile2D dst_view(group_size, K); + TASSIGN(src_view, + WORK_UB + row * ROW_BLOCK_COLS * (unsigned)sizeof(half)); + TASSIGN(dst_view, group_batch_offset + row * MATRIX_ROW_BYTES); + stridedUBCopy(dst_view, src_view, (uint16_t)group_size, tile_row_blocks, + (uint16_t)0, src_gap_blocks); + } + pipe_barrier(PIPE_V); + } + + // Bulk TSTORE from this half of BATCH_UB → GM. + GmTensor gm_out_tensor(chunk_gm_out, gm_shape, gm_stride); + set_flag(PIPE_V, PIPE_MTE3, ev); + wait_flag(PIPE_V, PIPE_MTE3, ev); + TSTORE(gm_out_tensor, batch_tile); + set_flag(PIPE_MTE3, PIPE_V, ev); // next use of this half waits on this + set_flag(PIPE_V, PIPE_MTE2, ev); // next TLOAD of this half can now proceed + } + + // Drain all four ping-pong flags before exit. + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); +} + +// ========================================================================== +// Small-batch path: K ∈ {4, 8, 16}, N < ~2048 — natural-order, no DB +// ========================================================================== +// +// The fast-path's strided-interleaved layout + double-buffer + batched +// `TCOLEXPANDDIV` on the full `(K, ROW_BLOCK_COLS)` interleaved tile +// amortizes beautifully at large batches — but pays ~25us of flag / setup +// overhead per kernel call that's wasted when there are few matrices +// (only one chunk per worker, no overlap benefit). +// +// This path is the simplest possible layout: one matrix per group in a +// (K, K) sub-tile at natural UB row-stride. Per-matrix col-normalize is +// `TCOLSUM + TCOLEXPANDDIV` in a loop over the `gc` valid matrices. No +// double-buffer, no interleave / de-interleave copies — just a TLOAD, an +// inner TMOV from BUF to MAT, the sinkhorn iterations on the tall view, +// and a TSTORE. +// +// At K=4, batch=1 this path clocks ~14us vs ~40us for the fast-path. +// Crossover with the fast-path is around batch=2048. +template +AICORE void sinkhornSmallBatch(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N, + float eps) { + constexpr unsigned K = K_TEMPLATE; + constexpr unsigned TILE_COLS = TILE_COLS_TEMPLATE; + constexpr unsigned TALL_ROWS = STACK_ROWS; + constexpr unsigned TILE_BYTES = TALL_ROWS * TILE_COLS * sizeof(half); + constexpr unsigned MATRIX_ROW_BYTES = TILE_COLS * sizeof(half); + + // UB layout: MAT (working tile) | SCRATCH (reduction tmp) | ROW_STATS | + // BATCH_UB + constexpr unsigned MAT_UB = 0; + constexpr unsigned SCRATCH_UB = ALIGN_32(MAT_UB + TILE_BYTES); + constexpr unsigned ROW_STATS_UB = ALIGN_32(SCRATCH_UB + TILE_BYTES); + constexpr unsigned BATCH_UB = + ALIGN_32(ROW_STATS_UB + ALIGN_32(TALL_ROWS * sizeof(half))); + constexpr unsigned BATCH_BUF_ROWS_RAW = + (UB_BYTES - BATCH_UB) / (TILE_COLS * sizeof(half)); + constexpr unsigned MAX_BATCH_ROWS = + BATCH_BUF_ROWS_RAW < 4095 ? BATCH_BUF_ROWS_RAW : 4095; + + set_mask_norm(); + set_vector_mask(-1, -1); + + const uint32_t num_workers = get_block_num() * get_subblockdim(); + const uint32_t worker_id = + get_block_idx() * get_subblockdim() + get_subblockid(); + const uint32_t base_per_worker = N / num_workers; + const uint32_t remainder = N % num_workers; + const uint32_t my_first = worker_id * base_per_worker + + (worker_id < remainder ? worker_id : remainder); + const uint32_t my_count = base_per_worker + (worker_id < remainder ? 1 : 0); + if (my_count == 0) return; + + constexpr uint32_t K_SQUARED = K * K; + constexpr uint32_t MAX_GROUP_SIZE = TALL_ROWS / K; // matrices per group + constexpr uint32_t CHUNK_MATRICES = MAX_BATCH_ROWS / K; + const half eps_h = (half)eps; + + initPipelineFlags(); + + for (uint32_t chunk_offset = 0; chunk_offset < my_count; + chunk_offset += CHUNK_MATRICES) { + const uint32_t chunk_matrices = + min(CHUNK_MATRICES, my_count - chunk_offset); + const uint32_t chunk_rows = chunk_matrices * K; + __gm__ T *chunk_gm_in = + gm_in + (size_t)(my_first + chunk_offset) * K_SQUARED; + __gm__ T *chunk_gm_out = + gm_out + (size_t)(my_first + chunk_offset) * K_SQUARED; + + // Zero the BATCH_UB region we're about to load (padding cols stay 0). { - V z(1, ar * TC); - TASSIGN(z, BUF); - TEXPANDS(z, (T)0); + FlatVec zero_flat(1, + chunk_rows * TILE_COLS); + TASSIGN(zero_flat, BATCH_UB); + TEXPANDS(zero_flat, (T)0); pipe_barrier(PIPE_V); } - T2 bh(ar, K); - TASSIGN(bh, BUF); - S2 bs(ar, K); - DS bd(K); - G2 gi2(gi, bs, bd); + + Tile2D batch_tile(chunk_rows, K); + TASSIGN(batch_tile, BATCH_UB); + GmShape2D gm_shape(chunk_rows, K); + GmDenseStride gm_stride(K); + GmTensor gm_in_tensor(chunk_gm_in, gm_shape, gm_stride); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - TLOAD(bh, gi2); + TLOAD(batch_tile, gm_in_tensor); pipe_barrier(PIPE_ALL); set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - for (uint32_t g = 0; g < ab; g += mpg) { - const uint32_t gc = min(mpg, ab - g), gr = gc * K, gf = gr * TC; - const unsigned bo = BUF + g * K * TC * sizeof(T); + // Process the chunk in groups of up to MAX_GROUP_SIZE matrices. + for (uint32_t group_start = 0; group_start < chunk_matrices; + group_start += MAX_GROUP_SIZE) { + const uint32_t group_size = + min(MAX_GROUP_SIZE, chunk_matrices - group_start); + const uint32_t group_rows = group_size * K; + const uint32_t group_cells = group_rows * TILE_COLS; + const unsigned group_batch_offset = + BATCH_UB + group_start * K * TILE_COLS * sizeof(T); + // Copy this group's data from BATCH_UB → MAT_UB (natural order, same + // stride). { - V z(1, MR * TC); - TASSIGN(z, MAT); - TEXPANDS(z, (T)0); + FlatVec zero_mat(1, TALL_ROWS * TILE_COLS); + TASSIGN(zero_mat, MAT_UB); + TEXPANDS(zero_mat, (T)0); pipe_barrier(PIPE_V); } { - V s(1, gf), d(1, gf); - TASSIGN(s, bo); - TASSIGN(d, MAT); - TMOV(d, s); + FlatVec src(1, group_cells); + FlatVec dst(1, group_cells); + TASSIGN(src, group_batch_offset); + TASSIGN(dst, MAT_UB); + TMOV(dst, src); pipe_barrier(PIPE_V); } - T2 m(gr, K); - TASSIGN(m, MAT); - T2 t(gr, K); - TASSIGN(t, TMP); - CV v(gr, 1); - TASSIGN(v, VC); + Tile2D tall_matrix(group_rows, K); + TASSIGN(tall_matrix, MAT_UB); + Tile2D tall_scratch(group_rows, K); + TASSIGN(tall_scratch, SCRATCH_UB); + ColVec row_stats(group_rows, 1); + TASSIGN(row_stats, ROW_STATS_UB); - // Softmax (6 barriers, amortized over gc matrices) - TROWMAX(v, m, t); - pipe_barrier(PIPE_V); - TROWEXPANDSUB(m, m, v); - pipe_barrier(PIPE_V); + if constexpr (REPEAT > 0) { + // Softmax on tall view (group_rows, K). + TROWMAX(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDSUB(tall_matrix, tall_matrix, row_stats); + pipe_barrier(PIPE_V); + + { + FlatVec flat(1, group_cells); + TASSIGN(flat, MAT_UB); + TEXP(flat, flat); + pipe_barrier(PIPE_V); + } + + TROWSUM(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats); + pipe_barrier(PIPE_V); + + { + FlatVec flat(1, group_cells); + TASSIGN(flat, MAT_UB); + TADDS(flat, flat, eps_h); + pipe_barrier(PIPE_V); + } + +// Per-matrix column-normalize. With K compile-time we can call +// `TCOLEXPANDDIV` on each K×K sub-tile — 2 ops per matrix, one +// pipe_barrier between. +#define PER_MATRIX_COL_NORM() \ + do { \ + for (uint32_t mi = 0; mi < group_size; ++mi) { \ + const unsigned mat_off = MAT_UB + mi * K * MATRIX_ROW_BYTES; \ + Tile2D sub_mat(K, K); \ + TASSIGN(sub_mat, mat_off); \ + Tile2D sub_scratch(K, K); \ + TASSIGN(sub_scratch, SCRATCH_UB); \ + FlatVec col_stats(1, K); \ + TASSIGN(col_stats, ROW_STATS_UB); \ + TCOLSUM(col_stats, sub_mat, sub_scratch, false); \ + pipe_barrier(PIPE_V); \ + TCOLEXPANDDIV(sub_mat, sub_mat, col_stats); \ + pipe_barrier(PIPE_V); \ + } \ + } while (0) + + PER_MATRIX_COL_NORM(); + +#pragma unroll + for (uint32_t iter = 1; iter < REPEAT; ++iter) { + TASSIGN(row_stats, ROW_STATS_UB); + TROWSUM(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats); + pipe_barrier(PIPE_V); + + PER_MATRIX_COL_NORM(); + } +#undef PER_MATRIX_COL_NORM + } + + // Copy back MAT_UB → BATCH_UB. { - V f(1, gf); - TASSIGN(f, MAT); - TEXP(f, f); + FlatVec src(1, group_cells); + FlatVec dst(1, group_cells); + TASSIGN(src, MAT_UB); + TASSIGN(dst, group_batch_offset); + TMOV(dst, src); pipe_barrier(PIPE_V); } - TROWSUM(v, m, t); - pipe_barrier(PIPE_V); - TROWEXPANDDIV(m, m, v); + } + + GmTensor gm_out_tensor(chunk_gm_out, gm_shape, gm_stride); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(gm_out_tensor, batch_tile); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + } + + drainPipelineFlags(); +} + +// ========================================================================== +// Strided-tree fallback: K ∈ (16, 64] — K-runtime path +// ========================================================================== +// +// Same interleaved layout as the fast path, but K is runtime so we can't +// form a compile-time [K, ROW_BLOCK_COLS] tile view. Instead col-normalize +// uses a flat TADD-tree (K − 1 adds to compute per-column sums over the +// row-blocks) followed by K TDIV calls (one per row-block). Net: +// `K + 1` ops + `K − 1` TADD barriers + 1 TDIV barrier per iteration, +// versus 2 ops + 2 barriers on the fast path. +template +AICORE void sinkhornStridedTree(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N, + uint32_t K, float eps) { + constexpr unsigned TILE_COLS = TILE_COLS_TEMPLATE; + constexpr unsigned TALL_ROWS = STACK_ROWS; + constexpr unsigned MATRIX_ROW_BYTES = TILE_COLS * sizeof(half); + constexpr unsigned TILE_BYTES = TALL_ROWS * TILE_COLS * sizeof(half); + + // UB layout — same regions as the fast path, but COL_STATS slots are + // allocated per-matrix (row-block flat vectors, not a single wide vector). + constexpr unsigned SCRATCH_UB = 0; + constexpr unsigned WORK_UB = ALIGN_32(SCRATCH_UB + TILE_BYTES); + constexpr unsigned ROW_STATS_UB = ALIGN_32(WORK_UB + TILE_BYTES); + constexpr unsigned COL_STATS_UB = + ALIGN_32(ROW_STATS_UB + ALIGN_32(TALL_ROWS * sizeof(half))); + constexpr unsigned BATCH_UB = + ALIGN_32(COL_STATS_UB + + MAX_MATS_PER_GROUP_CAP * ALIGN_32(TILE_COLS * sizeof(half))); + + constexpr unsigned BATCH_BUF_ROWS_RAW = + (UB_BYTES - BATCH_UB) / (TILE_COLS * sizeof(half)); + constexpr unsigned MAX_BATCH_ROWS = + BATCH_BUF_ROWS_RAW < 4095 ? BATCH_BUF_ROWS_RAW : 4095; + static_assert( + BATCH_UB + MAX_BATCH_ROWS * TILE_COLS * sizeof(half) <= UB_BYTES, + "BATCH_UB exceeds UB capacity"); + + set_mask_norm(); + set_vector_mask(-1, -1); + + if (K == 0 || K > TILE_COLS) return; + + const uint32_t num_workers = get_block_num() * get_subblockdim(); + const uint32_t worker_id = + get_block_idx() * get_subblockdim() + get_subblockid(); + const uint32_t base_per_worker = N / num_workers; + const uint32_t remainder = N % num_workers; + const uint32_t my_first = worker_id * base_per_worker + + (worker_id < remainder ? worker_id : remainder); + const uint32_t my_count = base_per_worker + (worker_id < remainder ? 1 : 0); + if (my_count == 0) return; + + const uint32_t K_SQUARED = K * K; + const uint32_t MAX_GROUP_SIZE = TALL_ROWS / K; + const uint32_t CHUNK_MATRICES = MAX_BATCH_ROWS / K; + const half eps_h = (half)eps; + + initPipelineFlags(); + + // Outer chunk loop. + for (uint32_t chunk_offset = 0; chunk_offset < my_count; + chunk_offset += CHUNK_MATRICES) { + const uint32_t chunk_matrices = + min(CHUNK_MATRICES, my_count - chunk_offset); + const uint32_t chunk_rows = chunk_matrices * K; + __gm__ T *chunk_gm_in = + gm_in + (size_t)(my_first + chunk_offset) * K_SQUARED; + __gm__ T *chunk_gm_out = + gm_out + (size_t)(my_first + chunk_offset) * K_SQUARED; + + { + FlatVec batch_flat(1, + chunk_rows * TILE_COLS); + TASSIGN(batch_flat, BATCH_UB); + TEXPANDS(batch_flat, (T)0); pipe_barrier(PIPE_V); + } + + Tile2D batch_tile(chunk_rows, K); + TASSIGN(batch_tile, BATCH_UB); + GmShape2D gm_shape(chunk_rows, K); + GmDenseStride gm_stride(K); + GmTensor gm_in_tensor(chunk_gm_in, gm_shape, gm_stride); + + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(batch_tile, gm_in_tensor); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + for (uint32_t group_start = 0; group_start < chunk_matrices; + group_start += MAX_GROUP_SIZE) { + const uint32_t group_size = + min(MAX_GROUP_SIZE, chunk_matrices - group_start); + const uint32_t group_tall_rows = group_size * K; + const uint32_t group_flat_len = group_tall_rows * TILE_COLS; + const unsigned group_batch_offset = + BATCH_UB + group_start * K * TILE_COLS * sizeof(T); + const uint32_t row_block_cols = group_size * TILE_COLS; + constexpr uint16_t tile_row_blocks = TILE_COLS * sizeof(half) / 32; + const uint16_t src_gap_blocks = (uint16_t)(K - 1) * tile_row_blocks; + + // Zero WORK_UB. { - V f(1, gf); - TASSIGN(f, MAT); - TADDS(f, f, eh); + FlatVec work_flat(1, TALL_ROWS * TILE_COLS); + TASSIGN(work_flat, WORK_UB); + TEXPANDS(work_flat, (T)0); pipe_barrier(PIPE_V); } - // Column normalize: K+1 barriers for all gc matrices -#define CN() \ - do { \ - for (uint32_t i = 0; i < gc; ++i) { \ - V c(1, K), r(1, K); \ - TASSIGN(c, CS + i * CS_S); \ - TASSIGN(r, MAT + i * K * rb); \ - TMOV(c, r); \ - } \ - pipe_barrier(PIPE_V); \ - for (uint32_t j = 1; j < K; ++j) { \ - for (uint32_t i = 0; i < gc; ++i) { \ - V c(1, K), r(1, K); \ - TASSIGN(c, CS + i * CS_S); \ - TASSIGN(r, MAT + (i * K + j) * rb); \ - TADD(c, c, r); \ - } \ - pipe_barrier(PIPE_V); \ - } \ - for (uint32_t i = 0; i < gc; ++i) { \ - unsigned o = MAT + i * K * rb; \ - V u(1, K); \ - TASSIGN(u, CS + i * CS_S); \ - for (uint32_t j = 0; j < K; ++j) { \ - V r(1, K); \ - TASSIGN(r, o + j * rb); \ - TDIV(r, r, u); \ - } \ - } \ - pipe_barrier(PIPE_V); \ - } while (0) + // Interleave: BATCH_UB → WORK_UB. + for (uint32_t row = 0; row < K; ++row) { + Tile2D src_view(group_size, K); + Tile2D dst_view(group_size, K); + TASSIGN(src_view, group_batch_offset + row * MATRIX_ROW_BYTES); + TASSIGN(dst_view, + WORK_UB + row * row_block_cols * (unsigned)sizeof(half)); + stridedUBCopy(dst_view, src_view, (uint16_t)group_size, tile_row_blocks, + src_gap_blocks, (uint16_t)0); + } + pipe_barrier(PIPE_V); + + if constexpr (REPEAT > 0) { + Tile2D tall_matrix(group_tall_rows, K); + TASSIGN(tall_matrix, WORK_UB); + Tile2D tall_scratch(group_tall_rows, K); + TASSIGN(tall_scratch, SCRATCH_UB); + ColVec row_stats(group_tall_rows, 1); + TASSIGN(row_stats, ROW_STATS_UB); - CN(); - for (uint32_t it = 1; it < repeat; ++it) { - TASSIGN(v, VC); - TROWSUM(v, m, t); + // Softmax. + TROWMAX(row_stats, tall_matrix, tall_scratch); pipe_barrier(PIPE_V); - TROWEXPANDDIV(m, m, v); + + TROWEXPANDSUB(tall_matrix, tall_matrix, row_stats); pipe_barrier(PIPE_V); - CN(); - } -#undef CN - { - V s(1, gf), d(1, gf); - TASSIGN(s, MAT); - TASSIGN(d, bo); - TMOV(d, s); + { + FlatVec work_flat(1, group_flat_len); + TASSIGN(work_flat, WORK_UB); + TEXP(work_flat, work_flat); + pipe_barrier(PIPE_V); + } + + TROWSUM(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats); pipe_barrier(PIPE_V); + + { + FlatVec work_flat(1, group_flat_len); + TASSIGN(work_flat, WORK_UB); + TADDS(work_flat, work_flat, eps_h); + pipe_barrier(PIPE_V); + } + + // Col-normalize via flat TADD-tree + K×TDIV. + // + // Layout recap: WORK_UB's row-block r (offset r·row_block_cols) + // is a length-row_block_cols vector holding [mat0_row_r | mat1_row_r | + // ...]. The per-matrix col sum for matrix i's col c sits at index + // i·TILE_COLS + c. To compute all per-matrix col sums we sum the + // K row-blocks element-wise, then divide each row-block by the result. + constexpr unsigned CN_SCRATCH_UB = SCRATCH_UB; +#define COL_NORMALIZE_STRIDED_TREE() \ + do { \ + /* First two adds in parallel (writes to disjoint dests). */ \ + { \ + FlatVec a(1, row_block_cols), \ + b(1, row_block_cols), c(1, row_block_cols), d(1, row_block_cols); \ + TASSIGN(a, WORK_UB); \ + TASSIGN(b, WORK_UB + row_block_cols * (unsigned)sizeof(half)); \ + TASSIGN(c, COL_STATS_UB); \ + TADD(c, a, b); \ + TASSIGN(a, WORK_UB + 2 * row_block_cols * (unsigned)sizeof(half)); \ + TASSIGN(b, WORK_UB + 3 * row_block_cols * (unsigned)sizeof(half)); \ + TASSIGN(d, CN_SCRATCH_UB); \ + TADD(d, a, b); \ + } \ + pipe_barrier(PIPE_V); \ + /* Combine the two partial sums. */ \ + { \ + FlatVec a(1, row_block_cols), \ + b(1, row_block_cols); \ + TASSIGN(a, COL_STATS_UB); \ + TASSIGN(b, CN_SCRATCH_UB); \ + TADD(a, a, b); \ + } \ + pipe_barrier(PIPE_V); \ + /* Fold in the remaining (K − 4) row-blocks one by one. */ \ + for (uint32_t b_idx = 4; b_idx < K; ++b_idx) { \ + FlatVec src(1, row_block_cols), \ + dst(1, row_block_cols); \ + TASSIGN(src, WORK_UB + b_idx * row_block_cols * (unsigned)sizeof(half)); \ + TASSIGN(dst, COL_STATS_UB); \ + TADD(dst, dst, src); \ + pipe_barrier(PIPE_V); \ + } \ + /* Divide each row-block by the accumulated col sums. */ \ + for (uint32_t r = 0; r < K; ++r) { \ + FlatVec row_block(1, row_block_cols), \ + sums(1, row_block_cols); \ + TASSIGN(row_block, \ + WORK_UB + r * row_block_cols * (unsigned)sizeof(half)); \ + TASSIGN(sums, COL_STATS_UB); \ + TDIV(row_block, row_block, sums); \ + } \ + pipe_barrier(PIPE_V); \ + } while (0) + + // First col-normalize. + COL_NORMALIZE_STRIDED_TREE(); + +// (REPEAT − 1) × { row-normalize ; col-normalize }. +#pragma unroll + for (uint32_t iter = 1; iter < REPEAT; ++iter) { + TASSIGN(row_stats, ROW_STATS_UB); + + TROWSUM(row_stats, tall_matrix, tall_scratch); + pipe_barrier(PIPE_V); + + TROWEXPANDDIV(tall_matrix, tall_matrix, row_stats); + pipe_barrier(PIPE_V); + + COL_NORMALIZE_STRIDED_TREE(); + } +#undef COL_NORMALIZE_STRIDED_TREE + } + + // De-interleave: WORK_UB → BATCH_UB. + for (uint32_t row = 0; row < K; ++row) { + Tile2D src_view(group_size, K); + Tile2D dst_view(group_size, K); + TASSIGN(src_view, + WORK_UB + row * row_block_cols * (unsigned)sizeof(half)); + TASSIGN(dst_view, group_batch_offset + row * MATRIX_ROW_BYTES); + stridedUBCopy(dst_view, src_view, (uint16_t)group_size, tile_row_blocks, + (uint16_t)0, src_gap_blocks); } + pipe_barrier(PIPE_V); } - G2 go2(go, bs, bd); + GmTensor gm_out_tensor(chunk_gm_out, gm_shape, gm_stride); wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - TSTORE(go2, bh); + TSTORE(gm_out_tensor, batch_tile); pipe_barrier(PIPE_ALL); set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); } - pipeDrain(); + + drainPipelineFlags(); } -// ---- FP32 per-matrix fallback (K > 64) ---- -template -AICORE void sinkhornFP32(__gm__ T *in, __gm__ T *out, uint32_t N, uint32_t K, - uint32_t repeat, float eps) { - constexpr unsigned D = MAX_DIM, rf = D * sizeof(float); - constexpr unsigned MH = 0, MF = MH + D * D * sizeof(half), - TF = MF + D * D * sizeof(float), - VF = TF + D * D * sizeof(float); - static_assert(VF + D * sizeof(float) <= UB_BYTES); +// ========================================================================== +// Per-matrix fp32 fallback: K ∈ (64, 128] +// ========================================================================== +// +// fp16 precision is insufficient at K=128 (softmax denominators accumulate +// 128 terms, each with ~3 decimal-digit precision). We load fp16, convert +// to fp32 for all internal compute, convert back to fp16 on store. No +// batching — one matrix per worker at a time. +template +AICORE void sinkhornPerMatrixFp32(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N, + uint32_t K, float eps) { + constexpr unsigned TILE_DIM = MAX_K; + constexpr unsigned F32_ROW_BYTES = TILE_DIM * sizeof(float); + constexpr unsigned MATRIX_H_UB = 0; // fp16 IO buffer + constexpr unsigned MATRIX_F_UB = + MATRIX_H_UB + TILE_DIM * TILE_DIM * sizeof(half); + constexpr unsigned SCRATCH_F_UB = + MATRIX_F_UB + TILE_DIM * TILE_DIM * sizeof(float); + constexpr unsigned VECTOR_F_UB = + SCRATCH_F_UB + TILE_DIM * TILE_DIM * sizeof(float); + static_assert(VECTOR_F_UB + TILE_DIM * sizeof(float) <= UB_BYTES, + "fp32 fallback UB layout overflows"); set_mask_norm(); set_vector_mask(-1, -1); - if (K == 0 || K > D) return; - const uint32_t W = get_block_num() * get_subblockdim(); - const uint32_t w = get_block_idx() * get_subblockdim() + get_subblockid(); - const uint32_t KK = K * K, fl = K * D; - - pipeInit(); - for (uint32_t bi = w; bi < N; bi += W) { - __gm__ T *gi = in + (size_t)bi * KK, *go = out + (size_t)bi * KK; + if (K == 0 || K > TILE_DIM) return; + + const uint32_t num_workers = get_block_num() * get_subblockdim(); + const uint32_t worker_id = + get_block_idx() * get_subblockdim() + get_subblockid(); + const uint32_t K_SQUARED = K * K; + const uint32_t flat_len = K * TILE_DIM; + + initPipelineFlags(); + + for (uint32_t matrix_idx = worker_id; matrix_idx < N; + matrix_idx += num_workers) { + __gm__ T *matrix_gm_in = gm_in + (size_t)matrix_idx * K_SQUARED; + __gm__ T *matrix_gm_out = gm_out + (size_t)matrix_idx * K_SQUARED; + + // Zero + TLOAD fp16 matrix. { - V z(1, fl); - TASSIGN(z, MH); - TEXPANDS(z, (T)0); + FlatVec zero_flat(1, flat_len); + TASSIGN(zero_flat, MATRIX_H_UB); + TEXPANDS(zero_flat, (T)0); pipe_barrier(PIPE_V); } - T2 mH(K, K); - TASSIGN(mH, MH); - S2 sh(K, K); - DS st(K); - G2 gI(gi, sh, st); + + Tile2D matrix_h(K, K); + TASSIGN(matrix_h, MATRIX_H_UB); + GmShape2D gm_shape(K, K); + GmDenseStride gm_stride(K); + GmTensor gm_in_tensor(matrix_gm_in, gm_shape, gm_stride); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - TLOAD(mH, gI); + TLOAD(matrix_h, gm_in_tensor); pipe_barrier(PIPE_ALL); set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Upcast fp16 → fp32. { - V h(1, fl); - V f(1, fl); - TASSIGN(h, MH); - TASSIGN(f, MF); - TCVT(f, h, RoundMode::CAST_NONE); + FlatVec h_flat(1, flat_len); + FlatVec f_flat(1, flat_len); + TASSIGN(h_flat, MATRIX_H_UB); + TASSIGN(f_flat, MATRIX_F_UB); + TCVT(f_flat, h_flat, RoundMode::CAST_NONE); pipe_barrier(PIPE_V); } - T2 m(K, K); - TASSIGN(m, MF); - T2 t(K, K); - TASSIGN(t, TF); - CV v(K, 1); - TASSIGN(v, VF); + Tile2D matrix(K, K); + TASSIGN(matrix, MATRIX_F_UB); + Tile2D scratch(K, K); + TASSIGN(scratch, SCRATCH_F_UB); + ColVec row_stats(K, 1); + TASSIGN(row_stats, VECTOR_F_UB); - TROWMAX(v, m, t); + // Softmax. + TROWMAX(row_stats, matrix, scratch); pipe_barrier(PIPE_V); - TROWEXPANDSUB(m, m, v); + + TROWEXPANDSUB(matrix, matrix, row_stats); pipe_barrier(PIPE_V); + { - V f(1, fl); - TASSIGN(f, MF); - TEXP(f, f); + FlatVec mat_flat(1, flat_len); + TASSIGN(mat_flat, MATRIX_F_UB); + TEXP(mat_flat, mat_flat); pipe_barrier(PIPE_V); } - TROWSUM(v, m, t); + + TROWSUM(row_stats, matrix, scratch); pipe_barrier(PIPE_V); - { - V u(1, K); - TASSIGN(u, VF); - TADDS(u, u, eps); - pipe_barrier(PIPE_V); - } - TROWEXPANDDIV(m, m, v); + + TROWEXPANDDIV(matrix, matrix, row_stats); pipe_barrier(PIPE_V); + { - V f(1, fl); - TASSIGN(f, MF); - TADDS(f, f, eps); + FlatVec mat_flat(1, flat_len); + TASSIGN(mat_flat, MATRIX_F_UB); + TADDS(mat_flat, mat_flat, eps); pipe_barrier(PIPE_V); } + + // First col-normalize. { - V c(1, K); - TASSIGN(c, VF); - TCOLSUM(c, m, t, false); + FlatVec col_stats(1, K); + TASSIGN(col_stats, VECTOR_F_UB); + + TCOLSUM(col_stats, matrix, scratch, false); pipe_barrier(PIPE_V); - TADDS(c, c, eps); + + TADDS(col_stats, col_stats, eps); pipe_barrier(PIPE_V); - } - for (uint32_t r = 0; r < K; ++r) { - V row(1, K), u(1, K); - TASSIGN(row, MF + r * rf); - TASSIGN(u, VF); - TDIV(row, row, u); + + TCOLEXPANDDIV(matrix, matrix, col_stats); pipe_barrier(PIPE_V); } - for (uint32_t it = 1; it < repeat; ++it) { - TASSIGN(v, VF); - TROWSUM(v, m, t); +// (REPEAT − 1) × { row-normalize ; col-normalize }. +#pragma unroll + for (uint32_t iter = 1; iter < REPEAT; ++iter) { + TASSIGN(row_stats, VECTOR_F_UB); + TROWSUM(row_stats, matrix, scratch); pipe_barrier(PIPE_V); - { - V u(1, K); - TASSIGN(u, VF); - TADDS(u, u, eps); - pipe_barrier(PIPE_V); - } - TROWEXPANDDIV(m, m, v); + + TADDS(row_stats, row_stats, eps); pipe_barrier(PIPE_V); + + TROWEXPANDDIV(matrix, matrix, row_stats); + pipe_barrier(PIPE_V); + { - V c(1, K); - TASSIGN(c, VF); - TCOLSUM(c, m, t, false); + FlatVec col_stats(1, K); + TASSIGN(col_stats, VECTOR_F_UB); + + TCOLSUM(col_stats, matrix, scratch, false); pipe_barrier(PIPE_V); - TADDS(c, c, eps); + + TADDS(col_stats, col_stats, eps); pipe_barrier(PIPE_V); - } - for (uint32_t r = 0; r < K; ++r) { - V row(1, K), u(1, K); - TASSIGN(row, MF + r * rf); - TASSIGN(u, VF); - TDIV(row, row, u); + + TCOLEXPANDDIV(matrix, matrix, col_stats); pipe_barrier(PIPE_V); } } + // Downcast fp32 → fp16 and store. { - V h(1, fl); - V f(1, fl); - TASSIGN(h, MH); - TASSIGN(f, MF); - TCVT(h, f, RoundMode::CAST_RINT); + FlatVec h_flat(1, flat_len); + FlatVec f_flat(1, flat_len); + TASSIGN(h_flat, MATRIX_H_UB); + TASSIGN(f_flat, MATRIX_F_UB); + TCVT(h_flat, f_flat, RoundMode::CAST_RINT); pipe_barrier(PIPE_V); } - G2 gO(go, sh, st); + + GmTensor gm_out_tensor(matrix_gm_out, gm_shape, gm_stride); wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - TSTORE(gO, mH); + TSTORE(gm_out_tensor, matrix_h); pipe_barrier(PIPE_ALL); set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); } - pipeDrain(); + + drainPipelineFlags(); } -// ---- Dispatch ---- -template -AICORE void sinkhorn(__gm__ T *in, __gm__ T *out, uint32_t N, uint32_t K, - uint32_t repeat, float eps) { - if (K > 0 && K <= 16) - sinkhornMulti(in, out, N, K, repeat, eps); +// ========================================================================== +// Dispatch +// ========================================================================== +template +AICORE void dispatchByK(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N, + uint32_t K, float eps) { + // K ∈ {4, 8, 16}: split by batch size. At small batches the fast-path's + // double-buffer / interleave overhead dominates — use the simple + // per-matrix natural-order path instead. Empirical crossovers scale + // ~inversely with K (smaller K packs more matrices per group, so the + // interleave-layout benefit kicks in at higher batch counts). Thresholds + // were measured via msprof + direct Event timing on 910B. + if (K == 4 && N >= 2048) + sinkhornFastPath(gm_in, gm_out, N, eps); + else if (K == 8 && N >= 1024) + sinkhornFastPath(gm_in, gm_out, N, eps); + else if (K == 16 && N >= 512) + sinkhornFastPath(gm_in, gm_out, N, eps); + else if (K == 4) + sinkhornSmallBatch(gm_in, gm_out, N, eps); + else if (K == 8) + sinkhornSmallBatch(gm_in, gm_out, N, eps); + else if (K == 16) + sinkhornSmallBatch(gm_in, gm_out, N, eps); + // For K=32/64, the strided-tree path is already competitive at large + // batches; smallBatch only wins at very small N where the per-matrix + // loop has few iterations. Thresholds follow the same ~8192/K scaling. + else if (K == 32 && N < 256) + sinkhornSmallBatch(gm_in, gm_out, N, eps); + else if (K == 64 && N < 128) + sinkhornSmallBatch(gm_in, gm_out, N, eps); + // Other K values (odd K, K > 16 && K < 32, etc.) fall through to + // strided-tree. + else if (K > 0 && K <= 16) + sinkhornStridedTree(gm_in, gm_out, N, K, eps); else if (K <= 32) - sinkhornMulti(in, out, N, K, repeat, eps); + sinkhornStridedTree(gm_in, gm_out, N, K, eps); else if (K <= 64) - sinkhornMulti(in, out, N, K, repeat, eps); - else if (K <= 128) - sinkhornFP32(in, out, N, K, repeat, eps); + sinkhornStridedTree(gm_in, gm_out, N, K, eps); + else if (K <= MAX_K) + sinkhornPerMatrixFp32(gm_in, gm_out, N, K, eps); } -#endif +// Specialize on `repeat` so that the per-iteration unroll constant is known. +template +AICORE void dispatchByRepeat(__gm__ T *gm_in, __gm__ T *gm_out, uint32_t N, + uint32_t K, uint32_t repeat, float eps) { + switch (repeat) { + case 0: + dispatchByK(gm_in, gm_out, N, K, eps); + break; + case 1: + dispatchByK(gm_in, gm_out, N, K, eps); + break; + case 3: + dispatchByK(gm_in, gm_out, N, K, eps); + break; + case 5: + dispatchByK(gm_in, gm_out, N, K, eps); + break; + case 8: + dispatchByK(gm_in, gm_out, N, K, eps); + break; + case 10: + dispatchByK(gm_in, gm_out, N, K, eps); + break; + case 20: + dispatchByK(gm_in, gm_out, N, K, eps); + break; + default: + dispatchByK(gm_in, gm_out, N, K, eps); + break; + } +} +#endif // __CCE_AICORE__ == 220 && __DAV_C220_VEC__ + +// ========================================================================== +// C ABI +// ========================================================================== extern "C" __global__ AICORE void sinkhorn_ds_fp16(GM_ADDR input, GM_ADDR output, uint32_t N, uint32_t K, uint32_t repeat, float eps) { #if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) - sinkhorn((__gm__ half *)input, (__gm__ half *)output, N, K, repeat, - eps); + dispatchByRepeat((__gm__ half *)input, (__gm__ half *)output, N, K, + repeat, eps); #else (void)input; (void)output; @@ -397,10 +1206,11 @@ extern "C" __global__ AICORE void sinkhorn_ds_fp16(GM_ADDR input, #endif } -extern "C" void call_sinkhorn_ds_kernel(uint32_t blockDim, void *stream, +// Host-side launch. Ascend 910B runs 2 AIV cores per cube core. +extern "C" void call_sinkhorn_ds_kernel(uint32_t cube_core_num, void *stream, uint8_t *input, uint8_t *output, uint32_t N, uint32_t K, uint32_t repeat, float eps) { - sinkhorn_ds_fp16<<>>(input, output, N, K, - repeat, eps); + sinkhorn_ds_fp16<<>>(input, output, N, K, + repeat, eps); } diff --git a/examples/jit_cpp/sinkhorn/test_sinkhorn.py b/examples/jit_cpp/sinkhorn/test_sinkhorn.py index f16e0312..92e20b98 100644 --- a/examples/jit_cpp/sinkhorn/test_sinkhorn.py +++ b/examples/jit_cpp/sinkhorn/test_sinkhorn.py @@ -32,7 +32,92 @@ def sinkhorn_ref(x: torch.Tensor, repeat: int = 10, eps: float = 1e-6) -> torch. return x.to(torch.float16) -TEST_SHAPES = [ +# Dispatch paths in kernel_sinkhorn.cpp::dispatchByK: +# K=4 — sinkhornSmallBatch (N < 2048) | sinkhornFastPath (N >= 2048) +# K=8 — sinkhornSmallBatch (N < 1024) | sinkhornFastPath (N >= 1024) +# K=16 — sinkhornSmallBatch (N < 512) | sinkhornFastPath (N >= 512) +# K=32 — sinkhornSmallBatch (N < 256) | sinkhornStridedTree (N >= 256) +# K=64 — sinkhornSmallBatch (N < 128) | sinkhornStridedTree (N >= 128) +# K ∈ (0, 16], K ∉ {4,8,16} — sinkhornStridedTree +# K ∈ (16, 32], K ≠ 32 — sinkhornStridedTree +# K ∈ (32, 64], K ≠ 64 — sinkhornStridedTree +# K ∈ (64, 128] — sinkhornPerMatrixFp32 +# +# Each shape below targets one of those paths; shapes marked "boundary" +# sit on a dispatch threshold where the path flips. +DISPATCH_SHAPES = [ + # --- K=4 smallBatch (N < 2048) --- + (1, 4), + (4, 4), + (32, 4), + (100, 4), + (1000, 4), + (2047, 4), # boundary: last smallBatch N + # --- K=4 fastPath (N >= 2048) --- + (2048, 4), # boundary: first fastPath N + (2049, 4), + (4096, 4), + (8192, 4), + # --- K=8 smallBatch (N < 1024) --- + (1, 8), + (8, 8), + (64, 8), + (500, 8), + (1023, 8), # boundary + # --- K=8 fastPath (N >= 1024) --- + (1024, 8), # boundary + (1025, 8), + (2048, 8), + # --- K=16 smallBatch (N < 512) --- + (1, 16), + (16, 16), + (256, 16), + (511, 16), # boundary + # --- K=16 fastPath (N >= 512) --- + (512, 16), # boundary + (513, 16), + (1024, 16), + # --- K=32 smallBatch (N < 256) then stridedTree<32> --- + (1, 32), + (64, 32), + (255, 32), # boundary + (256, 32), # boundary → stridedTree + (512, 32), + # --- K=64 smallBatch (N < 128) then stridedTree<64> --- + (1, 64), + (32, 64), + (127, 64), # boundary + (128, 64), # boundary → stridedTree + (256, 64), + # --- Odd/other K ∈ (0, 16] → stridedTree<16> --- + (1, 5), + (8, 7), + (16, 12), + (4, 13), + (32, 15), + # --- K ∈ (16, 32], K ≠ 32 → stridedTree<32> --- + (1, 17), + (64, 20), + (32, 24), + (8, 30), + # --- K ∈ (32, 64], K ≠ 64 → stridedTree<64> --- + (1, 33), + (16, 48), + (8, 50), + (4, 60), + # --- K ∈ (64, 128] → fp32 fallback --- + (1, 65), + (2, 80), + (4, 96), + (2, 100), + (8, 128), +] +# One (repeat, seed) per shape — cheap, broad dispatch-path coverage. +DISPATCH_CASES = [(batch, K, 10, 0) for (batch, K) in DISPATCH_SHAPES] + +# Dense (repeat × seed) coverage for representative shapes — catches +# numerical regressions independent of dispatch path. +DENSE_SHAPES = [ (1, 4), (1, 8), (1, 16), @@ -46,15 +131,16 @@ def sinkhorn_ref(x: torch.Tensor, repeat: int = 10, eps: float = 1e-6) -> torch. (32, 4), (64, 8), ] -TEST_REPEATS = [1, 5, 10] -TEST_SEEDS = [0, 42] -TEST_CASES = [ +DENSE_CASES = [ (batch, K, repeat, seed) - for batch, K in TEST_SHAPES - for repeat in TEST_REPEATS - for seed in TEST_SEEDS + for (batch, K) in DENSE_SHAPES + for repeat in (1, 5, 10) + for seed in (0, 42) ] +# Dedup (DISPATCH and DENSE overlap on the original small shapes). +TEST_CASES = sorted(set(DISPATCH_CASES + DENSE_CASES)) + @pytest.fixture(scope="session") def sinkhorn_kernel(npu_device): @@ -77,19 +163,44 @@ def test_sinkhorn_ds_matches_reference( torch.testing.assert_close(out.cpu(), ref, rtol=1e-2, atol=1e-5) -def test_output_is_doubly_stochastic(sinkhorn_kernel, npu_device): +# Doubly-stochastic check across one representative shape per dispatch path. +DOUBLY_STOCHASTIC_SHAPES = [ + # smallBatch K ∈ {4,8,16,32,64} + (4, 4), + (4, 8), + (4, 16), + (4, 32), + (4, 64), + # fastPath K ∈ {4,8,16} + (2048, 4), + (1024, 8), + (512, 16), + # stridedTree (odd / non-{4,8,16,32,64} K) + (4, 7), + (4, 20), + (4, 48), + # stridedTree at large N for K ∈ {32, 64} + (256, 32), + (128, 64), + # fp32 fallback + (4, 96), + (4, 128), +] + + +@pytest.mark.parametrize("batch,K", DOUBLY_STOCHASTIC_SHAPES) +def test_output_is_doubly_stochastic(sinkhorn_kernel, npu_device, batch, K): """After enough iterations, rows and columns should approximately sum to 1/K.""" torch.manual_seed(123) - K = 8 - x = torch.randn(4, K, K, device=npu_device, dtype=DTYPE) + x = torch.randn(batch, K, K, device=npu_device, dtype=DTYPE) out = torch.empty_like(x) sinkhorn_kernel(x, out, repeat=20, eps=1e-6) torch.npu.synchronize() out_f = out.float() - row_sums = out_f.sum(dim=-1) # (4, K) - col_sums = out_f.sum(dim=-2) # (4, K) + row_sums = out_f.sum(dim=-1) # (batch, K) + col_sums = out_f.sum(dim=-2) # (batch, K) # All row sums should be approximately equal assert row_sums.std(dim=-1).max() < 0.05, f"Row sums not uniform: {row_sums}"