From dd218927d92abdcc3e891f100da8cd51c9c2c532 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 21 Apr 2026 21:16:31 +0000 Subject: [PATCH 1/3] working version of synkhorn --- .../aot/sinkhorn_dynamic_multicore/.gitignore | 6 + .../aot/sinkhorn_dynamic_multicore/README.md | 88 ++++ .../aot/sinkhorn_dynamic_multicore/caller.cpp | 42 ++ .../aot/sinkhorn_dynamic_multicore/compile.sh | 24 + .../jit_util_sinkhorn.py | 303 ++++++++++++ .../run_sinkhorn.py | 218 +++++++++ .../sinkhorn_builder.py | 437 ++++++++++++++++++ ptodsl/api/pto.py | 4 + ptodsl/api/pto_general.py | 22 + ptodsl/api/scalar.py | 27 ++ ptodsl/api/tile.py | 57 +++ 11 files changed, 1228 insertions(+) create mode 100644 examples/aot/sinkhorn_dynamic_multicore/.gitignore create mode 100644 examples/aot/sinkhorn_dynamic_multicore/README.md create mode 100644 examples/aot/sinkhorn_dynamic_multicore/caller.cpp create mode 100755 examples/aot/sinkhorn_dynamic_multicore/compile.sh create mode 100644 examples/aot/sinkhorn_dynamic_multicore/jit_util_sinkhorn.py create mode 100644 examples/aot/sinkhorn_dynamic_multicore/run_sinkhorn.py create mode 100644 examples/aot/sinkhorn_dynamic_multicore/sinkhorn_builder.py diff --git a/examples/aot/sinkhorn_dynamic_multicore/.gitignore b/examples/aot/sinkhorn_dynamic_multicore/.gitignore new file mode 100644 index 00000000..a8a3e1ef --- /dev/null +++ b/examples/aot/sinkhorn_dynamic_multicore/.gitignore @@ -0,0 +1,6 @@ +*.pto +*.cpp +!caller.cpp +*.so +*_artifacts/ +__pycache__/ diff --git a/examples/aot/sinkhorn_dynamic_multicore/README.md b/examples/aot/sinkhorn_dynamic_multicore/README.md new file mode 100644 index 00000000..f68bce0c --- /dev/null +++ b/examples/aot/sinkhorn_dynamic_multicore/README.md @@ -0,0 +1,88 @@ +# Sinkhorn normalization (dynamic-batch, multicore) + +PTODSL implementation of the Sinkhorn-style row/column normalization kernel +defined in [reference.cpp](reference.cpp). The kernel iteratively rescales +two diagonal vectors `mu1[L]` and `mu2[K]` so that the row and column +standard deviations of `matrix_in / (mu2[:, None] * mu1[None, :])` converge +to a common target value. + +## Algorithm + +For each `(K, L)` matrix in a batch of `N`: + +1. Initialise `mu1 = mu2 = invMu1 = 1`. +2. For `phase = 0 .. order`: + - Compute row & column **unbiased** standard deviations of + `cm = matrix_in / (mu2 * mu1)` in chunks of `ROW_CHUNK = 8` rows. + - `phase == 0`: set `tgt = min(rStd.min(), cStd.min()) + eps`. + - `phase > 0` : `mu2 *= (rStd / tgt) ** lr` and `mu1 *= (cStd / tgt) ** lr`, + then refresh `invMu1 = 1 / mu1`. +3. Write `matrix_out = matrix_in / (mu2 * mu1)`, plus `mu1_out`, `mu2_out`. + +## Design choices vs the hand-tuned reference + +The reference C++ exists primarily to squeeze every last cycle out of the +hardware (templated `TileL`, manual UB layout, hand-pipelined `set_flag` / +`wait_flag`, 2-term Padé `approxLn`). The PTODSL version trades a small +amount of throughput for clarity: + +| Concern | Reference | PTODSL builder | +| --------------------- | ---------------------------------- | ---------------------------------------- | +| Per-`L` specialisation| `runSinkhornImpl` switch | Single `MAX_DIM = 256` column stride | +| `inv_mu1` broadcast | Pre-tiled to `[ROW_CHUNK, L]` buf | `tile.col_expand_mul` | +| `pow(x, lr)` | 2-term Padé `approxLn` + `TEXP` | Native `tile.log` / `tile.exp` | +| Pipe synchronisation | Manual `set_flag` / `wait_flag` | `ptoas --enable-insert-sync` | + +## Constraints + +- `1 <= K, L <= 256` (`MAX_DIM`). +- `K % 8 == 0` (`ROW_CHUNK`). Tail handling for non-aligned `K` is left to + a future revision; the reference handles it via dynamic `cr`. +- Inputs are `fp16`; internal compute is `fp32`. + +## PTODSL / PTOAS workarounds + +Two limitations of the current stack forced extra plumbing compared to the +reference. Both are pure boilerplate and could be removed by toolchain +fixes. + +| # | Workaround in [sinkhorn_builder.py](sinkhorn_builder.py) | Root cause | Suggested fix | +| - | -------------------------------------------------------- | ---------- | ------------- | +| 1 | `mu2` is held as `RowMajor [1, MAX_DIM]` then per-chunk **copied** into a static `[1, ROW_CHUNK]` tile (`mu2RowStatic`) before being reshaped to the col-major `[ROW_CHUNK, 1]` sibling fed to `tile.row_expand_div`. | `pto.subview` narrows the *valid* shape but reuses the parent's *storage* `Numel`, so a downstream `tile.reshape` fails the bisheng `TRESHAPE` byte-size `static_assert` (parent `Numel = 256` ≠ `8`). | Have `pto.subview` rewrite the result tile-buf type's storage `shape` to the slice sizes when those sizes are static, so subview→reshape round-trips through a tile whose `Numel` matches what TRESHAPE expects. Alternatively expose a typed view-cast op (the new `pto.bitcast`/`pto.set_validshape` cover dtype/valid-shape but not storage-shape narrowing). | +| 2 | A static `[1, 1]` "scalar" tile is allocated as `[8, 1]`/`[1, 8]` with dynamic `valid_shape=[-1, -1]` so that `tile.min` / `tile.adds` / `tile.reshape` find the runtime `GetValidRow/Col` they require even though the value is conceptually 1×1. | The verifier+codegen for `TMin`/`TAddS`/`TRESHAPE` requires dynamic-valid metadata even on degenerate 1×1 tiles, and there's no row-major scalar type accepted by `tile.row_expand_div` as a broadcast source. | Lower `TMin`/`TAddS` over a fully-static `1×1` tile by emitting the immediate-form intrinsic directly, and let `tile.row_expand_div` accept a row-major `[1, 1]` scalar source (broadcast over both axes). | + +A third minor item: every K-indexed quantity (`mu2`, `rowSum`, `rowSqsum`) +is forced into `RowMajor [1, MAX_DIM]` instead of the natural `ColMajor +[MAX_DIM, 1]` because none of the elementwise ops +(`TMul/TSub/TMin/TLog/TExp/TSqrt/TAddS/TRowMin/T*ExpandDiv`) accept a +layout-override attribute. Adding such an attribute would let the builder +keep K-vectors col-major and drop the per-chunk reshape entirely. + +## Test coverage + +[run_sinkhorn.py](run_sinkhorn.py) runs the same matrix the upstream +torch_npu suite uses: `11 shapes × {order ∈ {1, 5, 10}} × {seed ∈ {0, +42}} = 66 cases`, including non-square `(1, 16, 256)` and `(1, 256, 16)`, +batched `(8, 128, 128)`, and the boundary `(1, 256, 256)`. Tolerances +match upstream (`rtol=5e-2`, `atol=1e-2`). All 66 cases pass against the +PyTorch reference. + +## Files + +| File | Purpose | +| --------------------- | --------------------------------------------------------- | +| `sinkhorn_builder.py` | PTODSL kernel — emits MLIR via stdout | +| `caller.cpp` | Thin C wrapper, exports `call_sinkhorn_kernel` | +| `compile.sh` | `python builder > .pto` → `ptoas` → `bisheng` shared lib | +| `run_sinkhorn.py` | Numerical correctness vs PyTorch reference | +| `reference.cpp` | Hand-tuned baseline (kept for documentation) | + +## Usage + +```bash +# 1. Generate MLIR + compile shared library (inside the NPU container). +./compile.sh + +# 2. Run correctness check. +python ./run_sinkhorn.py --lib ./sinkhorn_lib.so +``` diff --git a/examples/aot/sinkhorn_dynamic_multicore/caller.cpp b/examples/aot/sinkhorn_dynamic_multicore/caller.cpp new file mode 100644 index 00000000..a9418331 --- /dev/null +++ b/examples/aot/sinkhorn_dynamic_multicore/caller.cpp @@ -0,0 +1,42 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "sinkhorn.cpp" +#endif +#include KERNEL_CPP + +extern "C" void call_sinkhorn_kernel( + uint32_t blockDim, + void *stream, + uint8_t *matrix_in, + uint8_t *matrix_out, + uint8_t *mu1_out, + uint8_t *mu2_out, + uint32_t N, + uint32_t K, + uint32_t L, + uint32_t order, + float lr, + float eps, + float invK, + float invL, + float invK1, + float invL1) +{ + // Reference fires `blockDim * 2` because each AIC has 2 AIVs and the + // reference is vector-only. The PTODSL builder targets vector cores too + // (vector_section), so spawn 2x logical workers per supplied blockDim. + _kernel<<>>( + reinterpret_cast(matrix_in), + reinterpret_cast(matrix_out), + reinterpret_cast(mu1_out), + reinterpret_cast(mu2_out), + static_cast(N), + static_cast(K), + static_cast(L), + static_cast(order), + lr, + eps, + invK, + invL, + invK1, + invL1); +} diff --git a/examples/aot/sinkhorn_dynamic_multicore/compile.sh b/examples/aot/sinkhorn_dynamic_multicore/compile.sh new file mode 100755 index 00000000..90927bd1 --- /dev/null +++ b/examples/aot/sinkhorn_dynamic_multicore/compile.sh @@ -0,0 +1,24 @@ +set -e + +rm -f sinkhorn.pto sinkhorn.cpp sinkhorn_lib.so + +python ./sinkhorn_builder.py > ./sinkhorn.pto +ptoas --enable-insert-sync ./sinkhorn.pto -o ./sinkhorn.cpp + +# CANN 8.5 headers don't have CompactMode; need latest pto-isa source. +PTO_LIB_PATH=${PTO_LIB_PATH:-/sources/pto-isa} +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + -DKERNEL_CPP="\"sinkhorn.cpp\"" \ + ./caller.cpp \ + -o ./sinkhorn_lib.so diff --git a/examples/aot/sinkhorn_dynamic_multicore/jit_util_sinkhorn.py b/examples/aot/sinkhorn_dynamic_multicore/jit_util_sinkhorn.py new file mode 100644 index 00000000..60cf765d --- /dev/null +++ b/examples/aot/sinkhorn_dynamic_multicore/jit_util_sinkhorn.py @@ -0,0 +1,303 @@ +"""JIT utilities for the Sinkhorn example. + +Provides three entry points: + +* :func:`compile_pto_lib` — Runs ``sinkhorn_builder.py`` → ``ptoas`` → ``bisheng`` + to produce ``sinkhorn_lib.so`` (the PTODSL kernel). +* :func:`compile_reference_lib` — Compiles ``reference.cpp`` directly with + ``bisheng`` to produce ``reference_lib.so`` (the hand-tuned baseline). +* :func:`load_lib` — ``ctypes`` wrapper around ``call_sinkhorn_kernel``, + shared by both libraries (identical C ABI). + +Caches build artefacts under ``outputs/so/`` next to this file; rebuilds only +when the source file is newer than the cached ``.so``. +""" + +from __future__ import annotations + +import ctypes +import hashlib +import os +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Callable, Optional + +import torch +import torch_npu # noqa: F401 (registers torch.npu) + +from ptodsl.npu_info import get_num_cube_cores + +THIS_DIR = Path(__file__).resolve().parent +DEFAULT_SO_DIR = THIS_DIR / "outputs" / "so" +PTO_LIB_PATH = Path(os.environ.get("PTO_LIB_PATH", "/sources/pto-isa")) + +MAX_DIM = 256 +ROW_CHUNK = 8 +BLOCK_DIM = get_num_cube_cores() + +SINKHORN_ARGTYPES = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # matrix_in + ctypes.c_void_p, # matrix_out + ctypes.c_void_p, # mu1_out + ctypes.c_void_p, # mu2_out + ctypes.c_uint32, # N + ctypes.c_uint32, # K + ctypes.c_uint32, # L + ctypes.c_uint32, # order + ctypes.c_float, # lr + ctypes.c_float, # eps + ctypes.c_float, # invK + ctypes.c_float, # invL + ctypes.c_float, # invK1 + ctypes.c_float, # invL1 +] + +BISHENG_FLAGS = [ + "-fPIC", + "-shared", + "-D_FORTIFY_SOURCE=2", + "-O2", + "-std=c++17", + "-Wno-macro-redefined", + "-Wno-ignored-attributes", + "-fstack-protector-strong", + "-xcce", + "-Xhost-start", + "-Xhost-end", + "-mllvm", + "-cce-aicore-stack-size=0x8000", + "-mllvm", + "-cce-aicore-function-stack-size=0x8000", + "-mllvm", + "-cce-aicore-record-overflow=true", + "-mllvm", + "-cce-aicore-addr-transform", + "-mllvm", + "-cce-aicore-dcci-insert-for-scalar=false", + "--npu-arch=dav-2201", + "-DMEMORY_BASE", + "-std=gnu++17", +] + + +# --------------------------------------------------------------------------- +# Build helpers +# --------------------------------------------------------------------------- + + +def _file_digest(*paths: Path) -> str: + h = hashlib.sha256() + for p in paths: + if p.exists(): + h.update(p.read_bytes()) + return h.hexdigest()[:12] + + +def _run(cmd, *, cwd=None, verbose=False): + if verbose: + print(f"$ {' '.join(map(str, cmd))}") + subprocess.run(list(map(str, cmd)), check=True, cwd=cwd) + + +def _bisheng_compile(srcs, out_so: Path, *, defines=None, verbose=False): + out_so.parent.mkdir(parents=True, exist_ok=True) + cmd = ["bisheng", f"-I{PTO_LIB_PATH}/include", *BISHENG_FLAGS] + for k, v in (defines or {}).items(): + cmd.append(f"-D{k}={v}") + cmd += [*map(str, srcs), "-o", str(out_so)] + _run(cmd, verbose=verbose) + + +def compile_pto_lib( + builder_py: Path | str = THIS_DIR / "sinkhorn_builder.py", + *, + so_dir: Path | str = DEFAULT_SO_DIR, + verbose: bool = False, + force: bool = False, +) -> Path: + """Build the PTODSL Sinkhorn kernel into ``/sinkhorn_lib.so``.""" + builder_py = Path(builder_py).resolve() + caller_cpp = THIS_DIR / "caller.cpp" + so_dir = Path(so_dir) + so_dir.mkdir(parents=True, exist_ok=True) + + digest = _file_digest(builder_py, caller_cpp) + out_so = so_dir / f"sinkhorn_lib.{digest}.so" + if out_so.exists() and not force: + return out_so + + pto_path = so_dir / f"sinkhorn.{digest}.pto" + cpp_path = so_dir / f"sinkhorn.{digest}.cpp" + + # 1) builder -> .pto + if verbose: + print(f"$ python {builder_py} > {pto_path}") + with open(pto_path, "w") as f: + subprocess.run([sys.executable, str(builder_py)], check=True, stdout=f) + + # 2) ptoas -> .cpp + _run( + ["ptoas", "--enable-insert-sync", str(pto_path), "-o", str(cpp_path)], + verbose=verbose, + ) + + # 3) bisheng caller.cpp (which #includes the generated cpp) + _bisheng_compile( + [caller_cpp], + out_so, + defines={"KERNEL_CPP": f'\\"{cpp_path}\\"'}, + verbose=verbose, + ) + return out_so + + +def compile_reference_lib( + reference_cpp: Path | str = THIS_DIR / "reference.cpp", + *, + so_dir: Path | str = DEFAULT_SO_DIR, + verbose: bool = False, + force: bool = False, +) -> Path: + """Build ``reference.cpp`` into ``/reference_lib.so``.""" + reference_cpp = Path(reference_cpp).resolve() + so_dir = Path(so_dir) + so_dir.mkdir(parents=True, exist_ok=True) + + digest = _file_digest(reference_cpp) + out_so = so_dir / f"reference_lib.{digest}.so" + if out_so.exists() and not force: + return out_so + + _bisheng_compile([reference_cpp], out_so, verbose=verbose) + return out_so + + +# --------------------------------------------------------------------------- +# Loader +# --------------------------------------------------------------------------- + + +def _torch_to_ctypes(t: torch.Tensor) -> ctypes.c_void_p: + return ctypes.c_void_p(t.data_ptr()) + + +def _validate_io(matrix_in, matrix_out, mu1_out, mu2_out, K, L): + if matrix_in.dim() != 3: + raise ValueError("matrix_in must be a 3D tensor (N, K, L).") + N = matrix_in.shape[0] + if matrix_in.shape[1] != K or matrix_in.shape[2] != L: + raise ValueError(f"matrix_in must have shape (N, {K}, {L}).") + if matrix_out.shape != matrix_in.shape: + raise ValueError("matrix_out must have the same shape as matrix_in.") + if mu1_out.shape != (N, L): + raise ValueError(f"mu1_out must have shape ({N}, {L}).") + if mu2_out.shape != (N, K): + raise ValueError(f"mu2_out must have shape ({N}, {K}).") + for name, t in [ + ("matrix_in", matrix_in), + ("matrix_out", matrix_out), + ("mu1_out", mu1_out), + ("mu2_out", mu2_out), + ]: + if t.dtype != torch.float16: + raise TypeError(f"{name} must use torch.float16.") + if not t.is_contiguous(): + raise ValueError(f"{name} must be contiguous.") + if not (matrix_in.device == matrix_out.device == mu1_out.device == mu2_out.device): + raise ValueError("All tensors must be on the same device.") + if K > MAX_DIM or L > MAX_DIM: + raise ValueError(f"K and L must be <= {MAX_DIM}.") + if K == 0 or L == 0: + raise ValueError("K and L must be positive.") + + +def load_lib(lib_path: Path | str, *, block_dim: int = BLOCK_DIM) -> Callable: + """Open ``lib_path`` and return a ``sinkhorn(...)`` callable.""" + lib = ctypes.CDLL(str(lib_path)) + lib.call_sinkhorn_kernel.argtypes = SINKHORN_ARGTYPES + lib.call_sinkhorn_kernel.restype = None + block_dim = max(1, int(block_dim)) + + def sinkhorn( + matrix_in, + matrix_out, + mu1_out, + mu2_out, + *, + order: int = 10, + lr: float = 0.5, + eps: float = 1e-3, + block_dim: int = block_dim, + stream_ptr: Optional[int] = None, + ): + N, K, L = matrix_in.shape + _validate_io(matrix_in, matrix_out, mu1_out, mu2_out, K, L) + + inv_k = 1.0 / K + inv_l = 1.0 / L + inv_k1 = 1.0 / (K - 1) if K > 1 else 1.0 + inv_l1 = 1.0 / (L - 1) if L > 1 else 1.0 + + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + + lib.call_sinkhorn_kernel( + block_dim, + stream_ptr, + _torch_to_ctypes(matrix_in), + _torch_to_ctypes(matrix_out), + _torch_to_ctypes(mu1_out), + _torch_to_ctypes(mu2_out), + int(N), + int(K), + int(L), + int(order), + float(lr), + float(eps), + float(inv_k), + float(inv_l), + float(inv_k1), + float(inv_l1), + ) + + sinkhorn.block_dim = block_dim + return sinkhorn + + +def jit_compile_pto( + builder_py: Path | str = THIS_DIR / "sinkhorn_builder.py", + *, + verbose: bool = True, + so_dir: Path | str = DEFAULT_SO_DIR, + block_dim: int = BLOCK_DIM, + force: bool = False, +) -> Callable: + """One-shot: build PTODSL kernel + return loaded callable.""" + so = compile_pto_lib(builder_py, so_dir=so_dir, verbose=verbose, force=force) + return load_lib(so, block_dim=block_dim) + + +def jit_compile_reference( + reference_cpp: Path | str = THIS_DIR / "reference.cpp", + *, + verbose: bool = True, + so_dir: Path | str = DEFAULT_SO_DIR, + block_dim: int = BLOCK_DIM, + force: bool = False, +) -> Callable: + """One-shot: build reference.cpp + return loaded callable.""" + so = compile_reference_lib( + reference_cpp, so_dir=so_dir, verbose=verbose, force=force + ) + return load_lib(so, block_dim=block_dim) + + +def clean_cache(so_dir: Path | str = DEFAULT_SO_DIR) -> None: + """Remove all cached build artefacts.""" + so_dir = Path(so_dir) + if so_dir.exists(): + shutil.rmtree(so_dir) diff --git a/examples/aot/sinkhorn_dynamic_multicore/run_sinkhorn.py b/examples/aot/sinkhorn_dynamic_multicore/run_sinkhorn.py new file mode 100644 index 00000000..12be17d6 --- /dev/null +++ b/examples/aot/sinkhorn_dynamic_multicore/run_sinkhorn.py @@ -0,0 +1,218 @@ +"""Sinkhorn normalization: PTO kernel vs PyTorch reference.""" + +import argparse +import ctypes + +import torch +import torch_npu # noqa: F401 + +from ptodsl.npu_info import get_num_cube_cores, get_test_device + +_DEFAULT_NUM_CORES = get_num_cube_cores() +ROW_CHUNK = 8 + + +def torch_to_ctypes(t): + return ctypes.c_void_p(t.data_ptr()) + + +def load_lib(lib_path, block_dim=_DEFAULT_NUM_CORES): + lib = ctypes.CDLL(lib_path) + lib.call_sinkhorn_kernel.argtypes = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # matrix_in + ctypes.c_void_p, # matrix_out + ctypes.c_void_p, # mu1_out + ctypes.c_void_p, # mu2_out + ctypes.c_uint32, # N + ctypes.c_uint32, # K + ctypes.c_uint32, # L + ctypes.c_uint32, # order + ctypes.c_float, # lr + ctypes.c_float, # eps + ctypes.c_float, # invK + ctypes.c_float, # invL + ctypes.c_float, # invK1 + ctypes.c_float, # invL1 + ] + lib.call_sinkhorn_kernel.restype = None + + def sinkhorn( + mat_in, + mat_out, + mu1_out, + mu2_out, + N, + K, + L, + order, + lr, + eps, + block_dim=block_dim, + stream_ptr=None, + ): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + invK = 1.0 / K + invL = 1.0 / L + invK1 = 1.0 / max(K - 1, 1) + invL1 = 1.0 / max(L - 1, 1) + lib.call_sinkhorn_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(mat_in), + torch_to_ctypes(mat_out), + torch_to_ctypes(mu1_out), + torch_to_ctypes(mu2_out), + N, + K, + L, + order, + lr, + eps, + invK, + invL, + invK1, + invL1, + ) + + return sinkhorn + + +def sinkhorn_ref(matrix_in, order, lr, eps): + """PyTorch reference Sinkhorn normalization (matches reference.cpp). + + Per (K, L) matrix: + mu1[L] = mu2[K] = 1 + For each phase 0..order: + cm = matrix_in / (mu2[:, None] * mu1[None, :]) + rStd[k] = unbiased std of cm[k, :] + cStd[l] = unbiased std of cm[:, l] + if phase == 0: tgt = min(rStd.min(), cStd.min()) + eps + else: mu2 *= (rStd / tgt) ** lr + mu1 *= (cStd / tgt) ** lr + out = matrix_in / (mu2[:, None] * mu1[None, :]) + """ + N, K, L = matrix_in.shape + cm_in = matrix_in.float() + out = torch.empty_like(cm_in) + mu1_all = torch.empty(N, L, device=matrix_in.device, dtype=torch.float32) + mu2_all = torch.empty(N, K, device=matrix_in.device, dtype=torch.float32) + + for bi in range(N): + cm0 = cm_in[bi] + mu1 = torch.ones(L, device=matrix_in.device, dtype=torch.float32) + mu2 = torch.ones(K, device=matrix_in.device, dtype=torch.float32) + tgt = None + for phase in range(order + 1): + cm = cm0 / (mu2[:, None] * mu1[None, :]) + rStd = cm.std(dim=1, unbiased=True) + cStd = cm.std(dim=0, unbiased=True) + if phase == 0: + tgt = torch.minimum(rStd.min(), cStd.min()) + eps + else: + mu2 = mu2 * torch.clamp(rStd / tgt, min=1e-12).pow(lr) + mu1 = mu1 * torch.clamp(cStd / tgt, min=1e-12).pow(lr) + out[bi] = cm0 / (mu2[:, None] * mu1[None, :]) + mu1_all[bi] = mu1 + mu2_all[bi] = mu2 + return ( + out.to(matrix_in.dtype), + mu1_all.to(matrix_in.dtype), + mu2_all.to(matrix_in.dtype), + ) + + +def test_sinkhorn(lib_path, block_dim=_DEFAULT_NUM_CORES): + device = get_test_device() + torch.npu.set_device(device) + + sinkhorn = load_lib(lib_path=lib_path, block_dim=block_dim) + + torch.manual_seed(0) + dtype = torch.float16 + # Mirrors the upstream torch_npu test suite (shapes x orders x seeds); + # cases where K is not a multiple of ROW_CHUNK=8 are skipped because the + # current builder only supports K-aligned chunking. + SHAPES = [ + (1, 16, 16), + (1, 32, 32), + (1, 64, 64), + (1, 128, 128), + (1, 256, 256), + (2, 64, 64), + (4, 32, 64), + (4, 64, 32), + (8, 128, 128), + (1, 16, 256), + (1, 256, 16), + ] + ORDERS = [1, 5, 10] + SEEDS = [0, 42] + LR, EPS = 0.5, 1e-3 + cases = [ + (N, K, L, order, LR, EPS, seed) + for (N, K, L) in SHAPES + for order in ORDERS + for seed in SEEDS + ] + + results = [] + for N, K, L, order, lr, eps, seed in cases: + if K % ROW_CHUNK != 0: + print( + f"[skip ] N={N} K={K} L={L} order={order} seed={seed} " + f"(K not multiple of {ROW_CHUNK})" + ) + results.append((N, K, L, order, seed, "skip")) + continue + torch.manual_seed(seed) + # Positive entries: keep within fp16 range. + mat_in = torch.rand(N, K, L, device=device, dtype=dtype) + 0.1 + mat_out = torch.empty_like(mat_in) + mu1_out = torch.empty(N, L, device=device, dtype=dtype) + mu2_out = torch.empty(N, K, device=device, dtype=dtype) + + ref_out, ref_mu1, ref_mu2 = sinkhorn_ref(mat_in, order, lr, eps) + + sinkhorn(mat_in, mat_out, mu1_out, mu2_out, N, K, L, order, lr, eps) + torch.npu.synchronize() + + ok = True + details = [] + for name, got, want in [ + ("matrix_out", mat_out, ref_out), + ("mu1_out", mu1_out, ref_mu1), + ("mu2_out", mu2_out, ref_mu2), + ]: + try: + torch.testing.assert_close(got, want, rtol=5e-2, atol=1e-2) + except AssertionError as err: + ok = False + details.append(f" {name}: {str(err).strip()[:200]}") + + status = "match" if ok else "mismatch" + print( + f"[{status}] N={N} K={K} L={L} order={order} seed={seed} " + f"lr={lr} eps={eps}" + ) + for d in details: + print(d) + results.append((N, K, L, order, seed, status)) + + print("\nsummary:") + counts = {"match": 0, "mismatch": 0, "skip": 0} + for r in results: + counts[r[-1]] = counts.get(r[-1], 0) + 1 + print(" ", r) + print(f"\n totals: {counts}") + return results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--lib", default="./sinkhorn_lib.so") + parser.add_argument("--block-dim", type=int, default=_DEFAULT_NUM_CORES) + args = parser.parse_args() + test_sinkhorn(args.lib, block_dim=args.block_dim) diff --git a/examples/aot/sinkhorn_dynamic_multicore/sinkhorn_builder.py b/examples/aot/sinkhorn_dynamic_multicore/sinkhorn_builder.py new file mode 100644 index 00000000..45a6fd7f --- /dev/null +++ b/examples/aot/sinkhorn_dynamic_multicore/sinkhorn_builder.py @@ -0,0 +1,437 @@ +""" +Sinkhorn normalization kernel — PTODSL builder (fp16 I/O, fp32 internal). + +Algorithm (matches reference.cpp): + For each (K, L) matrix in the batch of N: + 1. Initialise mu1[L] = mu2[K] = invMu1[L] = 1.0. + 2. For phase = 0..order: + - Compute row & col standard deviations (unbiased) of cm/(mu1*mu2) + in chunks of ROW_CHUNK rows. + - phase == 0: tgt = min(min(rStd), min(cStd)) + eps [stored in tile] + - phase > 0: mu2 *= (rStd / tgt)^lr ; mu1 *= (cStd / tgt)^lr ; + invMu1 = 1 / mu1 + 3. Write matrix_out = cm / (mu1 * mu2) ; write mu1_out, mu2_out. + +Design choices vs the hand-tuned reference: + - No template specialisation on TileL: a single MAX_DIM=256 column stride + is used for every L. Generated MLIR + ptoas auto-sync replaces the + reference's manual flag/pipe management. + - Native tile.log / tile.exp instead of the 2-term Pade approxLn. + - col_expand_mul replaces the pre-tiled inv_mu1 buffer (one elementwise + broadcast op instead of an explicit row-tile copy + flat TMUL). + - Constraint: K must be a multiple of ROW_CHUNK = 8 (kernel returns + early otherwise — same MAX_DIM upper-bound as reference). +""" + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +MAX_DIM = 256 +ROW_CHUNK = 8 + + +def meta_data(): + fp16 = pto.float16 + fp32 = pto.float32 + i32 = pto.int32 + + ptr_fp16 = pto.PtrType(fp16) + + tensor2_fp16 = pto.TensorType(rank=2, dtype=fp16) + + chunk_sub_fp16 = pto.SubTensorType(shape=[ROW_CHUNK, MAX_DIM], dtype=fp16) + row_sub_fp16 = pto.SubTensorType(shape=[1, MAX_DIM], dtype=fp16) + + # ---- VEC tile types ---- + row_vec_cfg = pto.TileBufConfig() # default RowMajor + col_vec_cfg = pto.TileBufConfig(blayout="ColMajor") + + # Row vector [1, MAX_DIM] RowMajor fp32 — for L-indexed quantities + # (mu1, invMu1, colSum, colSqsum, scratchL, zeroL). + row_vec_fp32 = pto.TileBufType( + shape=[1, MAX_DIM], + valid_shape=[1, -1], + dtype=fp32, + memory_space="VEC", + config=row_vec_cfg, + ) + row_vec_fp16 = pto.TileBufType( + shape=[1, MAX_DIM], + valid_shape=[1, -1], + dtype=fp16, + memory_space="VEC", + config=row_vec_cfg, + ) + + # Per-chunk static col-major tile [ROW_CHUNK, 1] — used as TROWSUM dst + # and TROWEXPANDDIV rhs scratch. Both shape and valid are static, so any + # tile.reshape between this and its row-major sibling is fully static + # and exercises the working codegen path. + chunk_col_fp32_st = pto.TileBufType( + shape=[ROW_CHUNK, 1], + dtype=fp32, + memory_space="VEC", + config=col_vec_cfg, + ) + chunk_row_fp32_st = pto.TileBufType( + shape=[1, ROW_CHUNK], + dtype=fp32, + memory_space="VEC", + config=row_vec_cfg, + ) + + # 2D chunk tiles [ROW_CHUNK, MAX_DIM] + chunk_fp16 = pto.TileBufType( + shape=[ROW_CHUNK, MAX_DIM], + valid_shape=[ROW_CHUNK, -1], + dtype=fp16, + memory_space="VEC", + config=row_vec_cfg, + ) + chunk_fp32 = pto.TileBufType( + shape=[ROW_CHUNK, MAX_DIM], + valid_shape=[ROW_CHUNK, -1], + dtype=fp32, + memory_space="VEC", + config=row_vec_cfg, + ) + + # Scalar [1, 1] ColMajor — broadcast-target for row_expand_div / col_expand_div. + # Use dynamic valid_shape so the tile carries GetValidRow/Col, which are + # required by TADDS / TMIN / TRESHAPE on the static 1x1 corner case. + scalar_col_fp32 = pto.TileBufType( + shape=[8, 1], + valid_shape=[-1, -1], + dtype=fp32, + memory_space="VEC", + config=col_vec_cfg, + ) + # Scalar [1, 1] RowMajor alias used for elementwise ops (TMin, TAddS,...) + # which require row-major layout. + scalar_row_fp32 = pto.TileBufType( + shape=[1, 8], + valid_shape=[-1, -1], + dtype=fp32, + memory_space="VEC", + config=row_vec_cfg, + ) + + return locals() + + +def build_sinkhorn(fn_name="sinkhorn_fp16"): + @to_ir_module(meta_data=meta_data) + def _kernel( + matrix_in_ptr: "ptr_fp16", + matrix_out_ptr: "ptr_fp16", + mu1_out_ptr: "ptr_fp16", + mu2_out_ptr: "ptr_fp16", + N_i32: "i32", + K_i32: "i32", + L_i32: "i32", + order_i32: "i32", + lr: "fp32", + eps: "fp32", + invK: "fp32", + invL: "fp32", + invK1: "fp32", + invL1: "fp32", + ) -> None: + c0 = const(0) + c1 = const(1) + cMAX_DIM = const(MAX_DIM) + cROW_CHUNK = const(ROW_CHUNK) + f0 = const(0.0, s.float32) + f1 = const(1.0, s.float32) + + N = s.index_cast(N_i32) + K = s.index_cast(K_i32) + L = s.index_cast(L_i32) + order = s.index_cast(order_i32) + + with pto.vector_section(): + # Bounds: 0 < K, L <= MAX_DIM ; K must be a multiple of ROW_CHUNK. + ok = ( + (K > c0) + & (L > c0) + & (cMAX_DIM >= K) + & (cMAX_DIM >= L) + & s.eq(K % cROW_CHUNK, c0) + ) + with pto.if_context(ok): + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + wid = s.index_cast(cid * sub_bnum + sub_bid) + num_workers = s.index_cast(num_blocks * sub_bnum) + + # ---- Allocate UB tiles (per worker, reused across batches) ---- + # K-indexed quantities are RowMajor [1, MAX_DIM] valid_col=K so + # all subsequent elementwise ops (mul/sub/maxs/sqrt/log/exp/min) + # work natively without going through a dynamic-valid reshape. + mu1 = pto.alloc_tile(row_vec_fp32, valid_col=L) + mu2 = pto.alloc_tile(row_vec_fp32, valid_col=K) + invMu1 = pto.alloc_tile(row_vec_fp32, valid_col=L) + + colSum = pto.alloc_tile(row_vec_fp32, valid_col=L) + colSqsum = pto.alloc_tile(row_vec_fp32, valid_col=L) + rowSum = pto.alloc_tile(row_vec_fp32, valid_col=K) + rowSqsum = pto.alloc_tile(row_vec_fp32, valid_col=K) + + scratchL = pto.alloc_tile(row_vec_fp32, valid_col=L) + scratchK = pto.alloc_tile(row_vec_fp32, valid_col=K) + + chunkH = pto.alloc_tile(chunk_fp16, valid_col=L) + chunkF = pto.alloc_tile(chunk_fp32, valid_col=L) + chunkTmp = pto.alloc_tile(chunk_fp32, valid_col=L) + + # Per-chunk static col-major scratch (TROWSUM dst, TROWEXPANDDIV + # rhs). [ROW_CHUNK, 1] both shape and valid fully static. + rsumScratch = pto.alloc_tile(chunk_col_fp32_st) + rsqScratch = pto.alloc_tile(chunk_col_fp32_st) + + # Static [1, ROW_CHUNK] row-major staging tile used to copy + # the dynamic mu2[jg : jg + ROW_CHUNK] subview into a tile + # whose storage Numel matches the [ROW_CHUNK, 1] col-major + # sibling, so the subsequent tile.reshape passes the + # codegen TRESHAPE byte-size static_assert. + mu2RowStatic = pto.alloc_tile(chunk_row_fp32_st) + + tgtScalar = pto.alloc_tile(scalar_col_fp32, valid_row=c1, valid_col=c1) + rMinTile = pto.alloc_tile(scalar_col_fp32, valid_row=c1, valid_col=c1) + cMinTile = pto.alloc_tile(scalar_col_fp32, valid_row=c1, valid_col=c1) + + # Output staging tiles (fp16) + mu1H = pto.alloc_tile(row_vec_fp16, valid_col=L) + mu2H = pto.alloc_tile(row_vec_fp16, valid_col=K) + + # ---- Tensor views (rank-2 for matrix_in/out, rank-1 for mu*) ---- + NK = N * K + tv_in = pto.as_tensor( + tensor2_fp16, + ptr=matrix_in_ptr, + shape=[NK, L], + strides=[L, c1], + ) + tv_out = pto.as_tensor( + tensor2_fp16, + ptr=matrix_out_ptr, + shape=[NK, L], + strides=[L, c1], + ) + tv_mu1 = pto.as_tensor( + tensor2_fp16, + ptr=mu1_out_ptr, + shape=[N, L], + strides=[L, c1], + ) + tv_mu2 = pto.as_tensor( + tensor2_fp16, + ptr=mu2_out_ptr, + shape=[N, K], + strides=[K, c1], + ) + + # ============================================================ + # Per-batch loop — workers split N across all vector cores. + # ============================================================ + for bi in pto.range(wid, N, num_workers): + # Init mu1, mu2, invMu1 to all-ones via muls(.,0)+adds(.,1). + tile.muls(mu1, f0, mu1) + tile.adds(mu1, f1, mu1) + tile.muls(mu2, f0, mu2) + tile.adds(mu2, f1, mu2) + tile.muls(invMu1, f0, invMu1) + tile.adds(invMu1, f1, invMu1) + + bi_row_off = bi * K # row offset of this batch in tv_in/out + + # ---------------------------------------------------------- + # Phase loop: phase 0 sets tgt; phases 1..order update mu. + # ---------------------------------------------------------- + for phase in pto.range(c0, order + c1, c1): + # Reset col accumulators. + tile.muls(colSum, f0, colSum) + tile.muls(colSqsum, f0, colSqsum) + + # Stream matrix in ROW_CHUNK-row chunks. + for jg in pto.range(c0, K, cROW_CHUNK): + # Load chunk fp16 [ROW_CHUNK, L] from GM. + chunk_view = pto.slice_view( + chunk_sub_fp16, + source=tv_in, + offsets=[bi_row_off + jg, c0], + sizes=[cROW_CHUNK, L], + ) + pto.load(chunk_view, chunkH) + + # fp16 -> fp32 + tile.cvt(chunkH, chunkF) + + # Build a col-major [ROW_CHUNK, 1] static view of + # mu2[jg : jg + ROW_CHUNK]. Subviewing mu2 keeps + # the parent's storage Numel (=MAX_DIM), so a + # direct reshape to [ROW_CHUNK, 1] would fail the + # codegen TRESHAPE byte-size static_assert. Copy + # the 8 elements into a static [1, ROW_CHUNK] + # tile (storage Numel=8) first via tile.muls + # with multiplier 1.0, then reshape that to the + # col-major sibling (storage Numel matches). + mu2_row_chunk = pto.subview( + mu2, offsets=[c0, jg], sizes=[1, ROW_CHUNK] + ) + tile.muls(mu2_row_chunk, f1, mu2RowStatic) + mu2_col_chunk = tile.reshape( + chunk_col_fp32_st, mu2RowStatic + ) + tile.row_expand_div(chunkF, mu2_col_chunk, chunkF) + + # Multiply each col by invMu1[c] (broadcast row-vec). + tile.col_expand_mul(chunkF, invMu1, chunkF) + + # Row-sum into per-chunk col scratch, then scatter + # into row-major rowSum[jg : jg + ROW_CHUNK]. + tile.row_sum(chunkF, chunkTmp, rsumScratch) + rsum_row_view = tile.reshape(chunk_row_fp32_st, rsumScratch) + rowSum_chunk = pto.subview( + rowSum, offsets=[c0, jg], sizes=[1, ROW_CHUNK] + ) + tile.muls(rsum_row_view, f1, rowSum_chunk) + + # Col-sum: accumulate across chunks. + tile.col_sum(chunkF, chunkTmp, scratchL, is_binary=True) + tile.add(colSum, scratchL, colSum) + + # Square chunk for sq-sum stats. + tile.mul(chunkF, chunkF, chunkF) + + tile.row_sum(chunkF, chunkTmp, rsqScratch) + rsq_row_view = tile.reshape(chunk_row_fp32_st, rsqScratch) + rowSq_chunk = pto.subview( + rowSqsum, offsets=[c0, jg], sizes=[1, ROW_CHUNK] + ) + tile.muls(rsq_row_view, f1, rowSq_chunk) + + tile.col_sum(chunkF, chunkTmp, scratchL, is_binary=True) + tile.add(colSqsum, scratchL, colSqsum) + + # ---- Finalise row std (unbiased): rStd = sqrt(max(0, + # (rSqsum - rSum^2 * invL) * invL1)) ---- + tile.mul(rowSum, rowSum, scratchK) + tile.muls(scratchK, invL, scratchK) + tile.sub(rowSqsum, scratchK, rowSqsum) + tile.muls(rowSqsum, invL1, rowSqsum) + tile.maxs(rowSqsum, f0, rowSqsum) + tile.sqrt(rowSqsum, rowSqsum) + + # ---- Finalise col std (unbiased) ---- + tile.mul(colSum, colSum, scratchL) + tile.muls(scratchL, invK, scratchL) + tile.sub(colSqsum, scratchL, colSqsum) + tile.muls(colSqsum, invK1, colSqsum) + tile.maxs(colSqsum, f0, colSqsum) + tile.sqrt(colSqsum, colSqsum) + + with pto.if_context(s.eq(phase, c0), has_else=True) as br: + # ---- Phase 0: tgt = min(rStd_min, cStd_min) + eps ---- + tile.row_min(rowSqsum, scratchK, rMinTile) + tile.row_min(colSqsum, scratchL, cMinTile) + # TMin / TAddS need row-major: alias the col scalars + # via tile.reshape (static [1,1] -> [1,1] is a no-op + # codegen-wise but the source has dynamic valid 1x1 + # set explicitly at alloc). + rMin_r = tile.reshape(scalar_row_fp32, rMinTile) + cMin_r = tile.reshape(scalar_row_fp32, cMinTile) + tgt_r = tile.reshape(scalar_row_fp32, tgtScalar) + tile.min(rMin_r, cMin_r, tgt_r) + tile.adds(tgt_r, eps, tgt_r) + with br.else_context(): + # ---- Phase >0: mu2 *= (rStd/tgt)^lr ---- + # rowSqsum is RowMajor [1, K], tgtScalar is ColMajor + # [8, 1] valid=1x1 — TROWEXPANDDIV broadcasts the + # 1-element col-vec across the K columns of the + # 1-row dst (= scalar division). + tile.row_expand_div(rowSqsum, tgtScalar, rowSqsum) + tile.maxs(rowSqsum, const(1e-12, s.float32), rowSqsum) + tile.log(rowSqsum, rowSqsum) + tile.muls(rowSqsum, lr, rowSqsum) + tile.exp(rowSqsum, rowSqsum) + tile.mul(mu2, rowSqsum, mu2) + + # ---- mu1 *= (cStd/tgt)^lr ---- + tile.row_expand_div(colSqsum, tgtScalar, colSqsum) + tile.maxs(colSqsum, const(1e-12, s.float32), colSqsum) + tile.log(colSqsum, colSqsum) + tile.muls(colSqsum, lr, colSqsum) + tile.exp(colSqsum, colSqsum) + tile.mul(mu1, colSqsum, mu1) + + # invMu1 = 1 / mu1 + tile.reciprocal(mu1, invMu1) + + # ============================================================ + # Write matrix_out = cm / (mu1 * mu2) + # ============================================================ + for jg in pto.range(c0, K, cROW_CHUNK): + chunk_view = pto.slice_view( + chunk_sub_fp16, + source=tv_in, + offsets=[bi_row_off + jg, c0], + sizes=[cROW_CHUNK, L], + ) + pto.load(chunk_view, chunkH) + tile.cvt(chunkH, chunkF) + + mu2_row_chunk = pto.subview( + mu2, offsets=[c0, jg], sizes=[1, ROW_CHUNK] + ) + tile.muls(mu2_row_chunk, f1, mu2RowStatic) + mu2_col_chunk = tile.reshape(chunk_col_fp32_st, mu2RowStatic) + tile.row_expand_div(chunkF, mu2_col_chunk, chunkF) + tile.col_expand_mul(chunkF, invMu1, chunkF) + + tile.cvt(chunkF, chunkH, rmode="cast_rint") + + out_view = pto.slice_view( + chunk_sub_fp16, + source=tv_out, + offsets=[bi_row_off + jg, c0], + sizes=[cROW_CHUNK, L], + ) + pto.store(chunkH, out_view) + + # ---- Write mu1_out (length L per batch) ---- + tile.cvt(mu1, mu1H, rmode="cast_rint") + mu1_view = pto.slice_view( + row_sub_fp16, + source=tv_mu1, + offsets=[bi, c0], + sizes=[c1, L], + ) + pto.store(mu1H, mu1_view) + + # ---- Write mu2_out (length K per batch) ---- + # mu2 is now RowMajor [1, MAX_DIM] valid_col=K — direct cvt. + tile.cvt(mu2, mu2H, rmode="cast_rint") + mu2_view = pto.slice_view( + row_sub_fp16, + source=tv_mu2, + offsets=[bi, c0], + sizes=[c1, K], + ) + pto.store(mu2H, mu2_view) + + _ = fn_name + return _kernel + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--fn-name", default="sinkhorn_fp16") + args = parser.parse_args() + print(build_sinkhorn(fn_name=args.fn_name)) diff --git a/ptodsl/api/pto.py b/ptodsl/api/pto.py index 963f99fb..11cf128c 100644 --- a/ptodsl/api/pto.py +++ b/ptodsl/api/pto.py @@ -19,6 +19,8 @@ load, reserve_buffer, slice_view, + subview, + subset, store, tfree_from_aic, tfree_from_aiv, @@ -66,6 +68,8 @@ "addptr", "as_tensor", "slice_view", + "subview", + "subset", "vector_section", "cube_section", "range", diff --git a/ptodsl/api/pto_general.py b/ptodsl/api/pto_general.py index 7a8e7a05..064bd2fa 100644 --- a/ptodsl/api/pto_general.py +++ b/ptodsl/api/pto_general.py @@ -84,6 +84,26 @@ def slice_view(subtensor_type, *, source, offsets, sizes): ).result +def subview(source, *, offsets, sizes): + """Create a strided view of a parent tile buffer (`pto.subview`). + + - ``offsets`` are runtime values (one per source-rank dim). + - ``sizes`` are compile-time integers (static shape of the result). + + The result is a ``!pto.tile_buf`` view that aliases ``source`` storage with + inherited strides and a dynamic offset. Useful for slicing a vector tile + into per-subblock row ranges (vector sub-block parallelism) or splitting an + ACC tile along K for sub-tile matmul accumulation. + """ + offset_vals = [_unwrap(v) for v in offsets] + return _pto.SubViewOp(source, offset_vals, list(sizes)).result + + +# Legacy alias — ptoas ≤0.27 named this op `pto.subset`. Kept so existing +# kernels keep importing. +subset = subview + + @contextmanager def vector_section(): section = _pto.SectionVectorOp() @@ -255,6 +275,8 @@ def print(format, scalar): "addptr", "as_tensor", "slice_view", + "subview", + "subset", "vector_section", "cube_section", "alloc_tile", diff --git a/ptodsl/api/scalar.py b/ptodsl/api/scalar.py index 3e574c91..5ddc4970 100644 --- a/ptodsl/api/scalar.py +++ b/ptodsl/api/scalar.py @@ -50,6 +50,33 @@ def __mod__(self, other): def __rmod__(self, other): return Value(arith.RemSIOp(_unwrap(other), _unwrap(self)).result) + def __and__(self, other): + return Value(arith.AndIOp(_unwrap(self), _unwrap(other)).result) + + def __rand__(self, other): + return Value(arith.AndIOp(_unwrap(other), _unwrap(self)).result) + + def __or__(self, other): + return Value(arith.OrIOp(_unwrap(self), _unwrap(other)).result) + + def __ror__(self, other): + return Value(arith.OrIOp(_unwrap(other), _unwrap(self)).result) + + def __xor__(self, other): + return Value(arith.XOrIOp(_unwrap(self), _unwrap(other)).result) + + def __rxor__(self, other): + return Value(arith.XOrIOp(_unwrap(other), _unwrap(self)).result) + + def __invert__(self): + # Bitwise NOT via xor with all-ones of the same integer type. + # Implemented as: self XOR (-1) using arith constant of matching type. + from mlir.ir import IntegerAttr + + ty = _unwrap(self).type + neg_one = arith.ConstantOp(ty, IntegerAttr.get(ty, -1)).result + return Value(arith.XOrIOp(_unwrap(self), neg_one).result) + @staticmethod def _cmp(lhs, rhs, predicate): return Value(arith.CmpIOp(predicate, _unwrap(lhs), _unwrap(rhs)).result) diff --git a/ptodsl/api/tile.py b/ptodsl/api/tile.py index 56817aa9..8771159e 100644 --- a/ptodsl/api/tile.py +++ b/ptodsl/api/tile.py @@ -221,6 +221,20 @@ def sort32(src, dst, idx): } +def reshape(dst_type, src): + """Reinterpret a tile's layout without moving data (zero-cost cast). + + Typical use: convert a column-vector [N, 1] ColMajor reduction tile into + a row-vector [1, N] RowMajor tile (or vice-versa) so that element-wise + ops like TMAX, TSUB, TMUL, TEXP can operate on it in the required layout. + + dst_type: a TileBufType describing the target shape / layout. + src: the source tile value. + Returns: a new tile SSA value with the reinterpreted type. + """ + return _pto.TReshapeOp(dst_type, src).result + + def muls(src, scalar, dst): """Multiply every element of a tile by a scalar value (tile * scalar → tile).""" _pto.tmuls(src, _unwrap(scalar), dst) @@ -231,6 +245,41 @@ def adds(src, scalar, dst): _pto.tadds(src, _unwrap(scalar), dst) +def subs(src, scalar, dst): + """Subtract a scalar from every element of a tile (tile - scalar → tile).""" + _pto.tsubs(src, _unwrap(scalar), dst) + + +def divs(src, scalar, dst): + """Divide every element of a tile by a scalar value (tile / scalar → tile).""" + _pto.tdivs(src, _unwrap(scalar), dst) + + +def maxs(src, scalar, dst): + """Element-wise max of a tile and a scalar (max(tile, scalar) → tile).""" + _pto.tmaxs(src, _unwrap(scalar), dst) + + +def mins(src, scalar, dst): + """Element-wise min of a tile and a scalar (min(tile, scalar) → tile).""" + _pto.tmins(src, _unwrap(scalar), dst) + + +def ands(src, scalar, dst): + """Bitwise AND of every element of a tile with a scalar value.""" + _pto.tands(src, _unwrap(scalar), dst) + + +def ors(src, scalar, dst): + """Bitwise OR of every element of a tile with a scalar value.""" + _pto.tors(src, _unwrap(scalar), dst) + + +def xors(src, scalar, dst): + """Bitwise XOR of every element of a tile with a scalar value.""" + _pto.txors(src, _unwrap(scalar), dst) + + def cvt(src, dst, *, rmode=None): """Convert tile element type (e.g. float32 → float16, float16 → float32). @@ -322,7 +371,15 @@ def print(source): "sort32", "muls", "adds", + "subs", + "divs", + "maxs", + "mins", + "ands", + "ors", + "xors", "cvt", "quant", "subset", + "reshape", ] From 1e5a7f7eed59c99bed802ccaa503de1b0e95e8c8 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 21 Apr 2026 21:57:49 +0000 Subject: [PATCH 2/3] accept tol parameters for correctness script to be able to run with the reference.cpp produced so file --- .../aot/sinkhorn_dynamic_multicore/README.md | 42 +- .../bench_sinkhorn.py | 455 ++++++++++++++++++ .../jit_util_sinkhorn.py | 2 +- .../run_sinkhorn.py | 17 +- 4 files changed, 504 insertions(+), 12 deletions(-) create mode 100644 examples/aot/sinkhorn_dynamic_multicore/bench_sinkhorn.py diff --git a/examples/aot/sinkhorn_dynamic_multicore/README.md b/examples/aot/sinkhorn_dynamic_multicore/README.md index f68bce0c..858bf538 100644 --- a/examples/aot/sinkhorn_dynamic_multicore/README.md +++ b/examples/aot/sinkhorn_dynamic_multicore/README.md @@ -69,13 +69,15 @@ PyTorch reference. ## Files -| File | Purpose | -| --------------------- | --------------------------------------------------------- | -| `sinkhorn_builder.py` | PTODSL kernel — emits MLIR via stdout | -| `caller.cpp` | Thin C wrapper, exports `call_sinkhorn_kernel` | -| `compile.sh` | `python builder > .pto` → `ptoas` → `bisheng` shared lib | -| `run_sinkhorn.py` | Numerical correctness vs PyTorch reference | -| `reference.cpp` | Hand-tuned baseline (kept for documentation) | +| File | Purpose | +| ----------------------- | ---------------------------------------------------------- | +| `sinkhorn_builder.py` | PTODSL kernel — emits MLIR via stdout | +| `caller.cpp` | Thin C wrapper, exports `call_sinkhorn_kernel` | +| `compile.sh` | `python builder > .pto` → `ptoas` → `bisheng` shared lib | +| `run_sinkhorn.py` | Numerical correctness vs PyTorch reference | +| `reference.cpp` | Hand-tuned baseline (`call_sinkhorn_kernel` self-contained)| +| `jit_util_sinkhorn.py` | Cached JIT compile + `ctypes` loader for both kernels | +| `bench_sinkhorn.py` | Throughput benchmark (torch / PTODSL / reference) | ## Usage @@ -83,6 +85,30 @@ PyTorch reference. # 1. Generate MLIR + compile shared library (inside the NPU container). ./compile.sh -# 2. Run correctness check. +# 2. Run correctness check (66 cases mirroring upstream torch_npu suite). python ./run_sinkhorn.py --lib ./sinkhorn_lib.so + +# 3. JIT-compile both kernels and benchmark. +python ./bench_sinkhorn.py +# Outputs: +# outputs/csv/{head_shapes_bench,batched_vs_serial}.csv +# outputs/plots/head_shapes_*.png, batched_vs_serial_log.png ``` + +## Throughput (Atlas 800I A2, fp16, order=8, lr=0.9, eps=1e-6) + +Single-matrix latency over the transformer-head grid (K ∈ {64, 128, 256}, +L ∈ {32, 64, 128, 256}), 5 warmup + 20 timed runs: + +| K | L | torch fp16 (µs) | PTODSL (µs) | reference C++ (µs) | PTODSL / torch | PTODSL / ref | +| --: | --: | --------------: | ----------: | -----------------: | -------------: | -----------: | +| 64 | 32 | 2170 | 39 | 65 | **55.6×** | **1.66×** | +| 64 | 256 | 2017 | 57 | 81 | 35.1× | 1.41× | +| 128 | 128 | 1983 | 79 | 124 | 25.1× | 1.56× | +| 256 | 256 | 2012 | 200 | 282 | 10.1× | 1.41× | + +Across all 12 shapes the PTODSL kernel is **10–55× faster than torch +fp16** and **1.40–1.78× faster than the hand-tuned reference C++**. The +batched-vs-serial sweep (K = L = 128) shows PTODSL holds the same ~80 µs +from N = 1 to N = 32 (perfect multicore scaling), consistently +**~1.55–1.63×** ahead of the reference at every batch size. diff --git a/examples/aot/sinkhorn_dynamic_multicore/bench_sinkhorn.py b/examples/aot/sinkhorn_dynamic_multicore/bench_sinkhorn.py new file mode 100644 index 00000000..fe598fb6 --- /dev/null +++ b/examples/aot/sinkhorn_dynamic_multicore/bench_sinkhorn.py @@ -0,0 +1,455 @@ +"""Benchmark fp16 Sinkhorn — torch vs PTODSL kernel vs hand-tuned reference. + +Shapes: + K (head_dim) : 64, 128, 256 + L (n_tokens) : 32, 64, 128, 256 + Batch : 1 (one (K, L) matrix per call) + +Plus a batched-vs-serial sweep at K=L=128. + +Writes: + outputs/csv/head_shapes_bench.csv + outputs/csv/batched_vs_serial.csv + outputs/plots/head_shapes_*.png + outputs/plots/batched_vs_serial_log.png +""" + +# pylint: disable=wrong-import-position +import argparse +import csv +from pathlib import Path + +import torch +import torch_npu # noqa: F401 + +from jit_util_sinkhorn import ( + BLOCK_DIM, + jit_compile_pto, + jit_compile_reference, +) +from ptodsl.npu_info import get_test_device + +THIS_DIR = Path(__file__).resolve().parent + +# --- Sinkhorn hyperparameters ---------------------------------------------- +SINKHORN_ORDER = 8 +SINKHORN_LR = 0.9 +SINKHORN_EPS = 1e-6 + +# --- Benchmark grids ------------------------------------------------------- +HEAD_DIMS = [64, 128, 256] +N_TOKENS = [32, 64, 128, 256] + +BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256] +BATCH_K = 128 +BATCH_L = 128 + +KERNEL_WARMUP = 10 +KERNEL_REPEATS = 50 + + +# --- torch reference ------------------------------------------------------- + + +def sinq_torch_fp16(matrix, sinkhorn_order=8, sinkhorn_lr=0.9, sinkhorn_eps=1e-6): + """Vectorised torch SINQ on (N, K, L). Stays in fp16.""" + K, L = matrix.shape[-2], matrix.shape[-1] + m = matrix + mu1 = torch.ones(*matrix.shape[:-2], L, dtype=m.dtype, device=m.device) + mu2 = torch.ones(*matrix.shape[:-2], K, 1, dtype=m.dtype, device=m.device) + tgt = ( + torch.minimum( + m.std(dim=-1).amin(dim=-1, keepdim=True), + m.std(dim=-2).amin(dim=-1, keepdim=True), + ).unsqueeze(-1) + + sinkhorn_eps + ) + for _ in range(sinkhorn_order): + cur = m / mu1.unsqueeze(-2) / mu2 + mu1 = mu1 * (cur.std(dim=-2) / tgt.squeeze(-1)) ** sinkhorn_lr + mu2 = mu2 * ((cur.std(dim=-1) / tgt.squeeze(-1)) ** sinkhorn_lr).unsqueeze(-1) + return m / mu1.unsqueeze(-2) / mu2, mu1, mu2.squeeze(-1) + + +# --- timing / metric helpers ---------------------------------------------- + + +def time_npu(fn, warmup=None, repeats=None): + warmup = KERNEL_WARMUP if warmup is None else warmup + repeats = KERNEL_REPEATS if repeats is None else repeats + for _ in range(warmup): + fn() + torch.npu.synchronize() + start = torch.npu.Event(enable_timing=True) + end = torch.npu.Event(enable_timing=True) + start.record() + for _ in range(repeats): + fn() + end.record() + torch.npu.synchronize() + return start.elapsed_time(end) * 1000.0 / repeats # us + + +def bytes_per_call(K, L, dtype_bytes): + return (2 * K * L + L + K) * dtype_bytes + + +def flops_per_call(K, L, order): + return K * L * (6 * (order + 1) + 2) + + +# --- head-shapes bench ----------------------------------------------------- + + +def _call_kernel(fn, mat, out, mu1, mu2, stream_ptr): + return fn( + mat, + out, + mu1, + mu2, + order=SINKHORN_ORDER, + lr=SINKHORN_LR, + eps=SINKHORN_EPS, + stream_ptr=stream_ptr, + ) + + +def run_head_shapes(pto_func, ref_func, stream_ptr, device): + rows = [] + header = ( + f"{'K':>4} {'L':>4} | " + f"{'torch_us':>9} {'pto_us':>8} {'ref_us':>8} | " + f"{'pto/torch':>9} {'pto/ref':>8}" + ) + print(header) + print("-" * len(header)) + + for K in HEAD_DIMS: + for L in N_TOKENS: + torch.random.manual_seed(42) + mat = torch.rand(1, K, L, dtype=torch.float16, device=device) + 0.1 + out = torch.empty_like(mat) + mu1 = torch.empty(1, L, dtype=torch.float16, device=device) + mu2 = torch.empty(1, K, dtype=torch.float16, device=device) + + t_us = time_npu( + lambda: sinq_torch_fp16( + mat, + sinkhorn_order=SINKHORN_ORDER, + sinkhorn_lr=SINKHORN_LR, + sinkhorn_eps=SINKHORN_EPS, + ) + ) + p_us = time_npu( + lambda: _call_kernel(pto_func, mat, out, mu1, mu2, stream_ptr) + ) + r_us = time_npu( + lambda: _call_kernel(ref_func, mat, out, mu1, mu2, stream_ptr) + ) + + B = bytes_per_call(K, L, 2) + F = flops_per_call(K, L, SINKHORN_ORDER) + sp_torch = t_us / p_us + sp_ref = r_us / p_us + print( + f"{K:>4d} {L:>4d} | " + f"{t_us:>9.2f} {p_us:>8.2f} {r_us:>8.2f} | " + f"{sp_torch:>9.2f}x {sp_ref:>7.2f}x" + ) + rows.append( + { + "K": K, + "L": L, + "torch_us": t_us, + "pto_us": p_us, + "ref_us": r_us, + "torch_GB_s": B / (t_us * 1e3), + "pto_GB_s": B / (p_us * 1e3), + "ref_GB_s": B / (r_us * 1e3), + "torch_GFLOPS": F / (t_us * 1e3), + "pto_GFLOPS": F / (p_us * 1e3), + "ref_GFLOPS": F / (r_us * 1e3), + "speedup_pto_vs_torch": sp_torch, + "speedup_pto_vs_ref": sp_ref, + } + ) + return rows + + +# --- batched-vs-serial bench ----------------------------------------------- + + +def run_batched_vs_serial(pto_func, ref_func, stream_ptr, device): + print(f"\nK={BATCH_K}, L={BATCH_L}, order={SINKHORN_ORDER}") + print( + f"{'N':>5} {'pto bat us':>11} {'pto ser us':>11} " + f"{'ref bat us':>11} {'pto/ref bat':>12}" + ) + rows = [] + for N in BATCH_SIZES: + mat = torch.rand(N, BATCH_K, BATCH_L, dtype=torch.float16, device=device) + 0.1 + out = torch.empty_like(mat) + mu1 = torch.empty(N, BATCH_L, dtype=torch.float16, device=device) + mu2 = torch.empty(N, BATCH_K, dtype=torch.float16, device=device) + + p_bat = time_npu(lambda: _call_kernel(pto_func, mat, out, mu1, mu2, stream_ptr)) + r_bat = time_npu(lambda: _call_kernel(ref_func, mat, out, mu1, mu2, stream_ptr)) + + mats_1 = [ + ( + torch.rand(1, BATCH_K, BATCH_L, dtype=torch.float16, device=device) + + 0.1, + torch.empty(1, BATCH_K, BATCH_L, dtype=torch.float16, device=device), + torch.empty(1, BATCH_L, dtype=torch.float16, device=device), + torch.empty(1, BATCH_K, dtype=torch.float16, device=device), + ) + for _ in range(N) + ] + + def _serial(fn): + for m, o, m1, m2 in mats_1: + _call_kernel(fn, m, o, m1, m2, stream_ptr) + + p_ser = time_npu(lambda: _serial(pto_func)) + sp_pto_ref = r_bat / p_bat if p_bat > 0 else float("nan") + + print( + f"{N:>5d} {p_bat:>11.2f} {p_ser:>11.2f} " + f"{r_bat:>11.2f} {sp_pto_ref:>11.2f}x" + ) + rows.append( + { + "N": N, + "pto_batched_us": p_bat, + "pto_serial_us": p_ser, + "pto_batched_per_mat_us": p_bat / N, + "pto_serial_per_mat_us": p_ser / N, + "ref_batched_us": r_bat, + "ref_batched_per_mat_us": r_bat / N, + "speedup_pto_vs_ref": sp_pto_ref, + "speedup_batched_vs_serial": ( + p_ser / p_bat if p_bat > 0 else float("nan") + ), + } + ) + return rows + + +# --- plots ---------------------------------------------------------------- + + +def _shape_labels(rows): + return [f"{r['K']}x{r['L']}" for r in rows] + + +def plot_speedup_grid(rows, key, title, path): + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import numpy as np + + grid = np.full((len(HEAD_DIMS), len(N_TOKENS)), np.nan) + for r in rows: + i = HEAD_DIMS.index(r["K"]) + j = N_TOKENS.index(r["L"]) + grid[i, j] = r[key] + + fig, ax = plt.subplots(figsize=(7.5, 4.5)) + vmax = max(np.nanmax(grid), 1.0) + im = ax.imshow( + grid, aspect="auto", cmap="viridis", vmin=min(np.nanmin(grid), 1.0), vmax=vmax + ) + ax.set_xticks(range(len(N_TOKENS)), [str(l) for l in N_TOKENS]) + ax.set_yticks(range(len(HEAD_DIMS)), [str(k) for k in HEAD_DIMS]) + ax.set_xlabel("n_tokens (L)") + ax.set_ylabel("head_dim (K)") + ax.set_title(title) + for i in range(grid.shape[0]): + for j in range(grid.shape[1]): + ax.text( + j, + i, + f"{grid[i, j]:.2f}x", + ha="center", + va="center", + color="white" if grid[i, j] < vmax * 0.6 else "black", + fontsize=10, + ) + fig.colorbar(im, ax=ax, label="speedup (x)") + fig.tight_layout() + fig.savefig(path, dpi=130) + plt.close(fig) + print(f"saved -> {path}") + + +def _grouped_bar(rows, keys, colors, labels, ylabel, title, path): + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import numpy as np + + xlabels = _shape_labels(rows) + x = np.arange(len(xlabels)) + w = 0.8 / len(keys) + fig, ax = plt.subplots(figsize=(11, 4.5)) + for i, (key, color, label) in enumerate(zip(keys, colors, labels)): + vals = [r[key] for r in rows] + ax.bar(x + (i - (len(keys) - 1) / 2) * w, vals, w, label=label, color=color) + ax.set_xticks(x, xlabels, rotation=45, ha="right") + ax.set_xlabel("shape (head_dim x n_tokens)") + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.grid(True, axis="y", alpha=0.3) + ax.legend() + fig.tight_layout() + fig.savefig(path, dpi=130) + plt.close(fig) + print(f"saved -> {path}") + + +def plot_batched(rows, path): + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + Ns = [r["N"] for r in rows] + p_bat = [r["pto_batched_per_mat_us"] for r in rows] + p_ser = [r["pto_serial_per_mat_us"] for r in rows] + r_bat = [r["ref_batched_per_mat_us"] for r in rows] + sp = [r["speedup_pto_vs_ref"] for r in rows] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4.5)) + ax1.plot(Ns, p_bat, "o-", color="#dc2626", label="PTODSL batched") + ax1.plot(Ns, p_ser, "s--", color="#94a3b8", label="PTODSL serial") + ax1.plot(Ns, r_bat, "^-", color="#0369a1", label="reference batched") + ax1.set_xscale("log", base=2) + ax1.set_yscale("log") + ax1.set_xticks(Ns, [str(n) for n in Ns]) + ax1.set_xlabel("batch size N") + ax1.set_ylabel("per-matrix latency (us, log)") + ax1.set_title(f"Per-matrix cost @ K=L={BATCH_K}") + ax1.grid(True, which="both", alpha=0.3) + ax1.legend() + + ax2.plot( + Ns, [r["pto_batched_us"] for r in rows], "o-", color="#dc2626", label="PTODSL" + ) + ax2.plot( + Ns, + [r["ref_batched_us"] for r in rows], + "^-", + color="#0369a1", + label="reference", + ) + ax2.set_xscale("log", base=2) + ax2.set_yscale("log") + ax2.set_xticks(Ns, [str(n) for n in Ns]) + ax2.set_xlabel("batch size N") + ax2.set_ylabel("total batched latency (us, log)") + ax2.set_title("Total wall time + ref/PTO speedup") + ax2.grid(True, which="both", alpha=0.3) + ax2.legend(loc="upper left") + ax2_r = ax2.twinx() + ax2_r.plot(Ns, sp, "v:", color="#059669", label="ref/PTO speedup") + ax2_r.set_ylabel("speedup (ref / PTO, x)", color="#059669") + ax2_r.tick_params(axis="y", labelcolor="#059669") + + fig.suptitle("Sinkhorn — PTODSL vs reference (batched & serial)") + fig.tight_layout() + fig.savefig(path, dpi=130) + plt.close(fig) + print(f"saved -> {path}") + + +# --- main ----------------------------------------------------------------- + + +def _parse_args(): + p = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + p.add_argument("--warmup", type=int, default=KERNEL_WARMUP) + p.add_argument("--repeats", type=int, default=KERNEL_REPEATS) + p.add_argument("--skip-batched", action="store_true") + p.add_argument("--force-rebuild", action="store_true") + return p.parse_args() + + +def _write_csv(path, rows): + if not rows: + return + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=list(rows[0].keys())) + w.writeheader() + w.writerows(rows) + print(f"saved -> {path}") + + +def main(): + global KERNEL_WARMUP, KERNEL_REPEATS + args = _parse_args() + KERNEL_WARMUP = args.warmup + KERNEL_REPEATS = args.repeats + + device = get_test_device() + torch.npu.set_device(device) + print(f"Using device: {device}, block_dim={BLOCK_DIM}") + + print("Compiling PTODSL kernel ...") + pto_func = jit_compile_pto(verbose=True, force=args.force_rebuild) + print("Compiling reference.cpp ...") + ref_func = jit_compile_reference(verbose=True, force=args.force_rebuild) + + stream_ptr = torch.npu.current_stream()._as_parameter_ + + csv_dir = THIS_DIR / "outputs" / "csv" + plot_dir = THIS_DIR / "outputs" / "plots" + csv_dir.mkdir(parents=True, exist_ok=True) + plot_dir.mkdir(parents=True, exist_ok=True) + + # --- head shapes --- + hs_rows = run_head_shapes(pto_func, ref_func, stream_ptr, device) + _write_csv(csv_dir / "head_shapes_bench.csv", hs_rows) + + plot_speedup_grid( + hs_rows, + "speedup_pto_vs_torch", + f"PTODSL vs torch fp16 — speedup (x), order={SINKHORN_ORDER}", + plot_dir / "head_shapes_speedup_pto_vs_torch.png", + ) + plot_speedup_grid( + hs_rows, + "speedup_pto_vs_ref", + f"PTODSL vs reference C++ — speedup (x), order={SINKHORN_ORDER}", + plot_dir / "head_shapes_speedup_pto_vs_ref.png", + ) + _grouped_bar( + hs_rows, + keys=["torch_GB_s", "pto_GB_s", "ref_GB_s"], + colors=["#94a3b8", "#dc2626", "#0369a1"], + labels=["torch fp16", "PTODSL", "reference C++"], + ylabel="effective bandwidth (GB/s)", + title=f"Sinkhorn fp16 bandwidth — order={SINKHORN_ORDER}, batch=1", + path=plot_dir / "head_shapes_bandwidth.png", + ) + _grouped_bar( + hs_rows, + keys=["torch_GFLOPS", "pto_GFLOPS", "ref_GFLOPS"], + colors=["#94a3b8", "#dc2626", "#0369a1"], + labels=["torch fp16", "PTODSL", "reference C++"], + ylabel="effective GFLOPS", + title=f"Sinkhorn fp16 compute — order={SINKHORN_ORDER}, batch=1", + path=plot_dir / "head_shapes_flops.png", + ) + + # --- batched vs serial --- + if not args.skip_batched: + bs_rows = run_batched_vs_serial(pto_func, ref_func, stream_ptr, device) + _write_csv(csv_dir / "batched_vs_serial.csv", bs_rows) + plot_batched(bs_rows, plot_dir / "batched_vs_serial_log.png") + + +if __name__ == "__main__": + main() diff --git a/examples/aot/sinkhorn_dynamic_multicore/jit_util_sinkhorn.py b/examples/aot/sinkhorn_dynamic_multicore/jit_util_sinkhorn.py index 60cf765d..6898ed11 100644 --- a/examples/aot/sinkhorn_dynamic_multicore/jit_util_sinkhorn.py +++ b/examples/aot/sinkhorn_dynamic_multicore/jit_util_sinkhorn.py @@ -149,7 +149,7 @@ def compile_pto_lib( _bisheng_compile( [caller_cpp], out_so, - defines={"KERNEL_CPP": f'\\"{cpp_path}\\"'}, + defines={"KERNEL_CPP": f'"{cpp_path}"'}, verbose=verbose, ) return out_so diff --git a/examples/aot/sinkhorn_dynamic_multicore/run_sinkhorn.py b/examples/aot/sinkhorn_dynamic_multicore/run_sinkhorn.py index 12be17d6..c47b38bb 100644 --- a/examples/aot/sinkhorn_dynamic_multicore/run_sinkhorn.py +++ b/examples/aot/sinkhorn_dynamic_multicore/run_sinkhorn.py @@ -124,7 +124,7 @@ def sinkhorn_ref(matrix_in, order, lr, eps): ) -def test_sinkhorn(lib_path, block_dim=_DEFAULT_NUM_CORES): +def test_sinkhorn(lib_path, block_dim=_DEFAULT_NUM_CORES, rtol=2e-3, atol=1e-3): device = get_test_device() torch.npu.set_device(device) @@ -187,7 +187,7 @@ def test_sinkhorn(lib_path, block_dim=_DEFAULT_NUM_CORES): ("mu2_out", mu2_out, ref_mu2), ]: try: - torch.testing.assert_close(got, want, rtol=5e-2, atol=1e-2) + torch.testing.assert_close(got, want, rtol=rtol, atol=atol) except AssertionError as err: ok = False details.append(f" {name}: {str(err).strip()[:200]}") @@ -202,6 +202,7 @@ def test_sinkhorn(lib_path, block_dim=_DEFAULT_NUM_CORES): results.append((N, K, L, order, seed, status)) print("\nsummary:") + print(f" tolerances: rtol={rtol}, atol={atol}") counts = {"match": 0, "mismatch": 0, "skip": 0} for r in results: counts[r[-1]] = counts.get(r[-1], 0) + 1 @@ -214,5 +215,15 @@ def test_sinkhorn(lib_path, block_dim=_DEFAULT_NUM_CORES): parser = argparse.ArgumentParser() parser.add_argument("--lib", default="./sinkhorn_lib.so") parser.add_argument("--block-dim", type=int, default=_DEFAULT_NUM_CORES) + parser.add_argument( + "--rtol", + type=float, + default=2e-3, + help="Relative tolerance (default 2e-3 — the PTODSL kernel passes; " + "use 5e-2 to also accept the hand-tuned reference.cpp).", + ) + parser.add_argument( + "--atol", type=float, default=1e-3, help="Absolute tolerance (default 1e-3)." + ) args = parser.parse_args() - test_sinkhorn(args.lib, block_dim=args.block_dim) + test_sinkhorn(args.lib, block_dim=args.block_dim, rtol=args.rtol, atol=args.atol) From 9cd6a2bc5cf19c6d3de065d434a218c9ee9d03be Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 21 Apr 2026 22:01:34 +0000 Subject: [PATCH 3/3] added synkhorn reference that was previously excluded by gitignore --- .../aot/sinkhorn_dynamic_multicore/.gitignore | 1 + .../sinkhorn_dynamic_multicore/reference.cpp | 609 ++++++++++++++++++ 2 files changed, 610 insertions(+) create mode 100644 examples/aot/sinkhorn_dynamic_multicore/reference.cpp diff --git a/examples/aot/sinkhorn_dynamic_multicore/.gitignore b/examples/aot/sinkhorn_dynamic_multicore/.gitignore index a8a3e1ef..80054a96 100644 --- a/examples/aot/sinkhorn_dynamic_multicore/.gitignore +++ b/examples/aot/sinkhorn_dynamic_multicore/.gitignore @@ -1,6 +1,7 @@ *.pto *.cpp !caller.cpp +!reference.cpp *.so *_artifacts/ __pycache__/ diff --git a/examples/aot/sinkhorn_dynamic_multicore/reference.cpp b/examples/aot/sinkhorn_dynamic_multicore/reference.cpp new file mode 100644 index 00000000..a7ee3ac7 --- /dev/null +++ b/examples/aot/sinkhorn_dynamic_multicore/reference.cpp @@ -0,0 +1,609 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +/** + * Sinkhorn normalization kernel for Ascend NPU (fp16 I/O, fp32 internal). + * + * Algorithm: + * For each (K, L) matrix in the batch: + * 1. Compute row/col standard deviations of cm / (mu1 * mu2). + * 2. Target = min(all stds) + eps. + * 3. Iterate: mu *= pow(std / tgt, lr) for each row/col. + * 4. Output: matrix_out = cm / (mu1 * mu2). + * + * Performance design: + * - Templated on TileL (column width) so the 2D row stride matches the + * data. flat TCVT / TMUL operate on cr*TileL elements instead of + * cr*MAX_DIM — up to 8x less work for small L. + * - inv_mu1 pre-tiled into a 2D flat buffer once per phase; a single + * flat TMUL per chunk replaces 8 row-by-row TDIVs. + * - pow(x,lr) via 2-term Pade approxLn + TEXP (8 barriers, not 14). + */ + +#include + +// clang-format off +#ifndef GM_ADDR +#define GM_ADDR __gm__ uint8_t* +#endif +// clang-format on + +using namespace pto; + +#define DIV_ROUNDUP(x, y) (((x) + (y) - 1) / (y)) +#define ALIGN_UP(x, y) (DIV_ROUNDUP((x), (y)) * (y)) + +constexpr uint32_t UB_USABLE_BYTES = 192 * 1024; +constexpr uint32_t MAX_DIM = 256; +constexpr uint32_t ROW_CHUNK = 8; +constexpr uint32_t TILE_ALIGN = 16; + +// ---------- UB layout (sized for worst case MAX_DIM) ---------- +namespace UbOfs { +constexpr unsigned MU1 = 0x00000; +constexpr unsigned MU2 = MU1 + MAX_DIM * sizeof(float); +constexpr unsigned INV_MU1 = MU2 + MAX_DIM * sizeof(float); +constexpr unsigned ROW_SUM = INV_MU1 + MAX_DIM * sizeof(float); +constexpr unsigned ROW_SQSUM = ROW_SUM + MAX_DIM * sizeof(float); +constexpr unsigned COL_SUM = ROW_SQSUM + MAX_DIM * sizeof(float); +constexpr unsigned COL_SQSUM = COL_SUM + MAX_DIM * sizeof(float); +constexpr unsigned CHUNK_HALF = COL_SQSUM + MAX_DIM * sizeof(float); +constexpr unsigned CHUNK_FP32 = CHUNK_HALF + ROW_CHUNK * MAX_DIM * sizeof(half); +constexpr unsigned CHUNK_TMP = CHUNK_FP32 + ROW_CHUNK * MAX_DIM * sizeof(float); +constexpr unsigned SCRATCH = CHUNK_TMP + ROW_CHUNK * MAX_DIM * sizeof(float); +constexpr unsigned SCALAR_A = SCRATCH + MAX_DIM * sizeof(float); +constexpr unsigned SCALAR_B = SCALAR_A + 32; +constexpr unsigned ZERO_VEC = SCALAR_B + 32; +constexpr unsigned LN_TMP1 = ZERO_VEC + MAX_DIM * sizeof(float); +constexpr unsigned LN_TMP2 = LN_TMP1 + MAX_DIM * sizeof(float); +constexpr unsigned INV_MU1_TILED = LN_TMP2 + MAX_DIM * sizeof(float); +constexpr unsigned TOTAL = INV_MU1_TILED + ROW_CHUNK * MAX_DIM * sizeof(float); +} // namespace UbOfs + +static_assert(UbOfs::TOTAL <= UB_USABLE_BYTES, + "Sinkhorn UB layout exceeds 192 KB."); + +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + +// ---------- Tile type aliases ---------- +using StrideDim5 = pto::Stride<1, 1, 1, 1, 1>; + +template +using Vec1D = Tile; + +template +using Global1D = GlobalTensor, StrideDim5>; + +using DynStride = Stride<1, 1, 1, DYNAMIC, 1>; +template +using Shape2D = TileShape2D; +template +using Tile2D = + Tile; +template +using Global2D = GlobalTensor, DynStride, Layout::ND>; + +using ScalarCol = Tile; + +template +using ColVec = + Tile; + +// ---------- Pipe helpers ---------- +AICORE inline void initPipeFlags() { + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); +} + +AICORE inline void drainPipeFlags() { + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); +} + +// ---------- 2-term Pade approxLn(x) for x > 0 ---------- +template +AICORE void approxLn(uint32_t N, unsigned dataOfs) { + Vec1D v(1, N); + Vec1D t1(1, N); + Vec1D r(1, N); + TASSIGN(v, dataOfs); + TASSIGN(t1, UbOfs::LN_TMP1); + TASSIGN(r, UbOfs::SCRATCH); + + TADDS(t1, v, -1.0f); + pipe_barrier(PIPE_V); + TADDS(v, v, 1.0f); + pipe_barrier(PIPE_V); + TDIV(r, t1, v); + pipe_barrier(PIPE_V); + TMUL(t1, r, r); + pipe_barrier(PIPE_V); + TMULS(v, t1, 1.0f / 3.0f); + pipe_barrier(PIPE_V); + TADDS(v, v, 1.0f); + pipe_barrier(PIPE_V); + TMUL(v, v, r); + pipe_barrier(PIPE_V); + TMULS(v, v, 2.0f); + pipe_barrier(PIPE_V); +} + +// ---------- Tile inv_mu1 into 2D flat buffer ---------- +template +AICORE void tileInvMu1(uint32_t La) { + constexpr unsigned rowBytes = TileL * sizeof(float); + Vec1D src(1, La); + TASSIGN(src, UbOfs::INV_MU1); + for (uint32_t r = 0; r < ROW_CHUNK; ++r) { + Vec1D dst(1, La); + TASSIGN(dst, UbOfs::INV_MU1_TILED + r * rowBytes); + TMULS(dst, src, 1.0f); + pipe_barrier(PIPE_V); + } +} + +// ---------- Main kernel, templated on tile column width ---------- +template +AICORE void runSinkhornImpl(__gm__ T *matrix_in, __gm__ T *matrix_out, + __gm__ T *mu1_out, __gm__ T *mu2_out, uint32_t N, + uint32_t K, uint32_t L, uint32_t La, uint32_t Ka, + uint32_t order, float lr, float eps, float invK, + float invL, float invK1, float invL1) { + const uint32_t num_workers = get_block_num() * get_subblockdim(); + const uint32_t wid = get_block_idx() * get_subblockdim() + get_subblockid(); + const uint32_t KL = K * L; + + initPipeFlags(); + + for (uint32_t bi = wid; bi < N; bi += num_workers) { + __gm__ T *cm = matrix_in + static_cast(bi) * KL; + + // ---- init ---- + Vec1D mu1(1, La); + Vec1D mu2(1, Ka); + Vec1D invMu1(1, La); + TASSIGN(mu1, UbOfs::MU1); + TASSIGN(mu2, UbOfs::MU2); + TASSIGN(invMu1, UbOfs::INV_MU1); + TEXPANDS(mu1, 1.0f); + pipe_barrier(PIPE_V); + TEXPANDS(mu2, 1.0f); + pipe_barrier(PIPE_V); + TEXPANDS(invMu1, 1.0f); + pipe_barrier(PIPE_V); + + { + uint32_t zLen = Ka > La ? Ka : La; + Vec1D zeroVec(1, zLen); + TASSIGN(zeroVec, UbOfs::ZERO_VEC); + TEXPANDS(zeroVec, 0.0f); + pipe_barrier(PIPE_V); + } + + tileInvMu1(La); + + // ============================================================ + // Phase loop + // ============================================================ + for (uint32_t phase = 0; phase <= order; ++phase) { + Vec1D colSum(1, La); + Vec1D colSqsum(1, La); + TASSIGN(colSum, UbOfs::COL_SUM); + TASSIGN(colSqsum, UbOfs::COL_SQSUM); + TEXPANDS(colSum, 0.0f); + pipe_barrier(PIPE_V); + TEXPANDS(colSqsum, 0.0f); + pipe_barrier(PIPE_V); + + // ---- stream matrix in ROW_CHUNK-row chunks ---- + for (uint32_t jg = 0; jg < K; jg += ROW_CHUNK) { + const uint32_t cr = (jg + ROW_CHUNK <= K) ? ROW_CHUNK : (K - jg); + const uint32_t flat = cr * TileL; // tight: stride = TileL + + // Load + Tile2D chunkHalf(cr, La); + TASSIGN(chunkHalf, UbOfs::CHUNK_HALF); + Shape2D chunkShape(cr, L); + DynStride chunkStride(L); + Global2D chunkGlobal(cm + jg * L, chunkShape, chunkStride); + + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(chunkHalf, chunkGlobal); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // fp16 -> fp32 + Vec1D halfFlat(1, flat); + Vec1D fp32Flat(1, flat); + TASSIGN(halfFlat, UbOfs::CHUNK_HALF); + TASSIGN(fp32Flat, UbOfs::CHUNK_FP32); + TCVT(fp32Flat, halfFlat, RoundMode::CAST_NONE); + pipe_barrier(PIPE_V); + + Tile2D chunk(cr, La); + TASSIGN(chunk, UbOfs::CHUNK_FP32); + + // Divide by mu2 + ColVec mu2Sub(cr, 1); + TASSIGN(mu2Sub, UbOfs::MU2 + jg * sizeof(float)); + TROWEXPANDDIV(chunk, chunk, mu2Sub); + pipe_barrier(PIPE_V); + + // Multiply by tiled inv_mu1 (1 flat TMUL) + { + Vec1D cFlat(1, flat); + Vec1D iFlat(1, flat); + TASSIGN(cFlat, UbOfs::CHUNK_FP32); + TASSIGN(iFlat, UbOfs::INV_MU1_TILED); + TMUL(cFlat, cFlat, iFlat); + pipe_barrier(PIPE_V); + } + + // Row stats + Tile2D tmp(cr, La); + TASSIGN(tmp, UbOfs::CHUNK_TMP); + + ColVec rowSumPart(cr, 1); + TASSIGN(rowSumPart, UbOfs::ROW_SUM + jg * sizeof(float)); + TROWSUM(rowSumPart, chunk, tmp); + pipe_barrier(PIPE_V); + + Vec1D partCol(1, La); + TASSIGN(partCol, UbOfs::SCRATCH); + TCOLSUM(partCol, chunk, tmp, false); + pipe_barrier(PIPE_V); + if (jg == 0) { + TMULS(colSum, partCol, 1.0f); + } else { + TADD(colSum, colSum, partCol); + } + pipe_barrier(PIPE_V); + + TMUL(chunk, chunk, chunk); + pipe_barrier(PIPE_V); + + ColVec rowSqPart(cr, 1); + TASSIGN(rowSqPart, UbOfs::ROW_SQSUM + jg * sizeof(float)); + TROWSUM(rowSqPart, chunk, tmp); + pipe_barrier(PIPE_V); + + TCOLSUM(partCol, chunk, tmp, false); + pipe_barrier(PIPE_V); + if (jg == 0) { + TMULS(colSqsum, partCol, 1.0f); + } else { + TADD(colSqsum, colSqsum, partCol); + } + pipe_barrier(PIPE_V); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + } // chunk loop + + // ---- Finalise row_std ---- + Vec1D rSum(1, Ka); + Vec1D rStd(1, Ka); + Vec1D scrK(1, Ka); + Vec1D zeroK(1, Ka); + TASSIGN(rSum, UbOfs::ROW_SUM); + TASSIGN(rStd, UbOfs::ROW_SQSUM); + TASSIGN(scrK, UbOfs::SCRATCH); + TASSIGN(zeroK, UbOfs::ZERO_VEC); + + TMUL(scrK, rSum, rSum); + pipe_barrier(PIPE_V); + TMULS(scrK, scrK, invL); + pipe_barrier(PIPE_V); + TSUB(rStd, rStd, scrK); + pipe_barrier(PIPE_V); + TMULS(rStd, rStd, invL1); + pipe_barrier(PIPE_V); + TMAX(rStd, rStd, zeroK); + pipe_barrier(PIPE_V); + TSQRT(rStd, rStd); + pipe_barrier(PIPE_V); + + // ---- Finalise col_std ---- + Vec1D cSum(1, La); + Vec1D cStd(1, La); + Vec1D scrL(1, La); + Vec1D zeroL(1, La); + TASSIGN(cSum, UbOfs::COL_SUM); + TASSIGN(cStd, UbOfs::COL_SQSUM); + TASSIGN(scrL, UbOfs::SCRATCH); + TASSIGN(zeroL, UbOfs::ZERO_VEC); + + TMUL(scrL, cSum, cSum); + pipe_barrier(PIPE_V); + TMULS(scrL, scrL, invK); + pipe_barrier(PIPE_V); + TSUB(cStd, cStd, scrL); + pipe_barrier(PIPE_V); + TMULS(cStd, cStd, invK1); + pipe_barrier(PIPE_V); + TMAX(cStd, cStd, zeroL); + pipe_barrier(PIPE_V); + TSQRT(cStd, cStd); + pipe_barrier(PIPE_V); + + if (phase == 0) { + Vec1D rStd1D(1, Ka); + Vec1D rMinTmp(1, Ka); + ScalarCol rMinS(1, 1); + TASSIGN(rStd1D, UbOfs::ROW_SQSUM); + TASSIGN(rMinTmp, UbOfs::SCRATCH); + TASSIGN(rMinS, UbOfs::SCALAR_A); + TROWMIN(rMinS, rStd1D, rMinTmp); + pipe_barrier(PIPE_V); + + Vec1D cStd1D(1, La); + Vec1D cMinTmp(1, La); + ScalarCol cMinS(1, 1); + TASSIGN(cStd1D, UbOfs::COL_SQSUM); + TASSIGN(cMinTmp, UbOfs::SCRATCH); + TASSIGN(cMinS, UbOfs::SCALAR_B); + TROWMIN(cMinS, cStd1D, cMinTmp); + pipe_barrier(PIPE_V); + + Vec1D sA(1, 1); + Vec1D sB(1, 1); + TASSIGN(sA, UbOfs::SCALAR_A); + TASSIGN(sB, UbOfs::SCALAR_B); + TMIN(sA, sA, sB); + pipe_barrier(PIPE_V); + TADDS(sA, sA, eps); + pipe_barrier(PIPE_V); + } else { + // ---- mu update ---- + ScalarCol tgtCol(1, 1); + TASSIGN(tgtCol, UbOfs::SCALAR_A); + + // mu2 *= pow(row_std / tgt, lr) + Vec1D rStdUpd(1, Ka); + TASSIGN(rStdUpd, UbOfs::ROW_SQSUM); + TROWEXPANDDIV(rStdUpd, rStdUpd, tgtCol); + pipe_barrier(PIPE_V); + { + Vec1D epsVec(1, Ka); + TASSIGN(epsVec, UbOfs::LN_TMP1); + TEXPANDS(epsVec, 1e-12f); + pipe_barrier(PIPE_V); + TMAX(rStdUpd, rStdUpd, epsVec); + pipe_barrier(PIPE_V); + } + approxLn(Ka, UbOfs::ROW_SQSUM); + TASSIGN(rStdUpd, UbOfs::ROW_SQSUM); + TMULS(rStdUpd, rStdUpd, lr); + pipe_barrier(PIPE_V); + TEXP(rStdUpd, rStdUpd); + pipe_barrier(PIPE_V); + Vec1D mu2Upd(1, Ka); + TASSIGN(mu2Upd, UbOfs::MU2); + TMUL(mu2Upd, mu2Upd, rStdUpd); + pipe_barrier(PIPE_V); + + // mu1 *= pow(col_std / tgt, lr) + Vec1D cStdUpd(1, La); + TASSIGN(cStdUpd, UbOfs::COL_SQSUM); + TROWEXPANDDIV(cStdUpd, cStdUpd, tgtCol); + pipe_barrier(PIPE_V); + { + Vec1D epsVec(1, La); + TASSIGN(epsVec, UbOfs::LN_TMP1); + TEXPANDS(epsVec, 1e-12f); + pipe_barrier(PIPE_V); + TMAX(cStdUpd, cStdUpd, epsVec); + pipe_barrier(PIPE_V); + } + approxLn(La, UbOfs::COL_SQSUM); + TASSIGN(cStdUpd, UbOfs::COL_SQSUM); + TMULS(cStdUpd, cStdUpd, lr); + pipe_barrier(PIPE_V); + TEXP(cStdUpd, cStdUpd); + pipe_barrier(PIPE_V); + Vec1D mu1Upd(1, La); + TASSIGN(mu1Upd, UbOfs::MU1); + TMUL(mu1Upd, mu1Upd, cStdUpd); + pipe_barrier(PIPE_V); + + // Refresh inv_mu1 and re-tile + Vec1D ones(1, La); + TASSIGN(ones, UbOfs::LN_TMP1); + TEXPANDS(ones, 1.0f); + pipe_barrier(PIPE_V); + Vec1D newInv(1, La); + TASSIGN(newInv, UbOfs::INV_MU1); + TASSIGN(mu1Upd, UbOfs::MU1); + TDIV(newInv, ones, mu1Upd); + pipe_barrier(PIPE_V); + tileInvMu1(La); + } + } // phase loop + + // ============================================================ + // Write output + // ============================================================ + __gm__ T *out = matrix_out + static_cast(bi) * KL; + + for (uint32_t jg = 0; jg < K; jg += ROW_CHUNK) { + const uint32_t cr = (jg + ROW_CHUNK <= K) ? ROW_CHUNK : (K - jg); + const uint32_t flat = cr * TileL; + + Tile2D chunkHalf(cr, La); + TASSIGN(chunkHalf, UbOfs::CHUNK_HALF); + Shape2D inShape(cr, L); + DynStride inStride(L); + Global2D inGlobal(cm + jg * L, inShape, inStride); + + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(chunkHalf, inGlobal); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + Vec1D hFlat(1, flat); + Vec1D fFlat(1, flat); + TASSIGN(hFlat, UbOfs::CHUNK_HALF); + TASSIGN(fFlat, UbOfs::CHUNK_FP32); + TCVT(fFlat, hFlat, RoundMode::CAST_NONE); + pipe_barrier(PIPE_V); + + Tile2D chunk(cr, La); + TASSIGN(chunk, UbOfs::CHUNK_FP32); + + ColVec mu2Sub(cr, 1); + TASSIGN(mu2Sub, UbOfs::MU2 + jg * sizeof(float)); + TROWEXPANDDIV(chunk, chunk, mu2Sub); + pipe_barrier(PIPE_V); + + { + Vec1D cFlat(1, flat); + Vec1D iFlat(1, flat); + TASSIGN(cFlat, UbOfs::CHUNK_FP32); + TASSIGN(iFlat, UbOfs::INV_MU1_TILED); + TMUL(cFlat, cFlat, iFlat); + pipe_barrier(PIPE_V); + } + + TCVT(hFlat, fFlat, RoundMode::CAST_RINT); + pipe_barrier(PIPE_V); + + Shape2D outShape(cr, L); + DynStride outStride(L); + Global2D outGlobal(out + jg * L, outShape, outStride); + + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(outGlobal, chunkHalf); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + } + + // ---- Write mu1_out ---- + { + Vec1D mu1F(1, La); + Vec1D mu1H(1, La); + TASSIGN(mu1F, UbOfs::MU1); + TASSIGN(mu1H, UbOfs::CHUNK_HALF); + TCVT(mu1H, mu1F, RoundMode::CAST_RINT); + pipe_barrier(PIPE_V); + + Global1D mu1G(mu1_out + static_cast(bi) * L); + TASSIGN(mu1G, (mu1_out + static_cast(bi) * L)); + + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(mu1G, mu1H); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + + // ---- Write mu2_out ---- + { + Vec1D mu2F(1, Ka); + Vec1D mu2H(1, Ka); + TASSIGN(mu2F, UbOfs::MU2); + TASSIGN(mu2H, UbOfs::CHUNK_HALF); + TCVT(mu2H, mu2F, RoundMode::CAST_RINT); + pipe_barrier(PIPE_V); + + Global1D mu2G(mu2_out + static_cast(bi) * K); + TASSIGN(mu2G, (mu2_out + static_cast(bi) * K)); + + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(mu2G, mu2H); + pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + } // bi loop + + drainPipeFlags(); +} + +// ---------- Dispatch to TileL-specialised impl ---------- +template +AICORE void runSinkhorn(__gm__ T *matrix_in, __gm__ T *matrix_out, + __gm__ T *mu1_out, __gm__ T *mu2_out, uint32_t N, + uint32_t K, uint32_t L, uint32_t order, float lr, + float eps, float invK, float invL, float invK1, + float invL1) { + set_mask_norm(); + set_vector_mask(-1, -1); + if (K == 0 || L == 0 || K > MAX_DIM || L > MAX_DIM) return; + + const uint32_t La = ALIGN_UP(L, TILE_ALIGN); + const uint32_t Ka = ALIGN_UP(K, TILE_ALIGN); + + // Dispatch to tight-stride specialisation. + // For La <= 32, the flat vectors are too short — barrier overhead dominates, + // so wider stride (MAX_DIM) amortises better. Specialise from La >= 64. + switch (La) { + case 64: + runSinkhornImpl(matrix_in, matrix_out, mu1_out, mu2_out, N, K, L, + La, Ka, order, lr, eps, invK, invL, invK1, invL1); + break; + case 128: + runSinkhornImpl(matrix_in, matrix_out, mu1_out, mu2_out, N, K, L, + La, Ka, order, lr, eps, invK, invL, invK1, invL1); + break; + default: + // La <= 32 or La >= 192: use MAX_DIM stride (long flat vectors) + runSinkhornImpl(matrix_in, matrix_out, mu1_out, mu2_out, N, K, + L, La, Ka, order, lr, eps, invK, invL, invK1, + invL1); + break; + } +} + +#endif // __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + +// ---------- Entry points ---------- + +extern "C" __global__ AICORE void sinkhorn_fp16( + GM_ADDR matrix_in, GM_ADDR matrix_out, GM_ADDR mu1_out, GM_ADDR mu2_out, + uint32_t N, uint32_t K, uint32_t L, uint32_t order, float lr, float eps, + float invK, float invL, float invK1, float invL1) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + runSinkhorn((__gm__ half *)matrix_in, (__gm__ half *)matrix_out, + (__gm__ half *)mu1_out, (__gm__ half *)mu2_out, N, K, L, + order, lr, eps, invK, invL, invK1, invL1); +#else + (void)matrix_in; + (void)matrix_out; + (void)mu1_out; + (void)mu2_out; + (void)N; + (void)K; + (void)L; + (void)order; + (void)lr; + (void)eps; + (void)invK; + (void)invL; + (void)invK1; + (void)invL1; +#endif +} + +extern "C" void call_sinkhorn_kernel(uint32_t blockDim, void *stream, + uint8_t *matrix_in, uint8_t *matrix_out, + uint8_t *mu1_out, uint8_t *mu2_out, + uint32_t N, uint32_t K, uint32_t L, + uint32_t order, float lr, float eps, + float invK, float invL, float invK1, + float invL1) { + sinkhorn_fp16<<>>( + matrix_in, matrix_out, mu1_out, mu2_out, N, K, L, order, lr, eps, invK, + invL, invK1, invL1); +}