Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion examples/jit_cpp/fast_hadamard/jit_util_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,31 @@
import torch

# Ascend toolkit install root; a missing variable raises KeyError at import time.
ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"]
# NOTE(review): PTO_LIB_PATH is assigned a second time further down from
# _resolve_pto_lib_path(), which overrides this value. This line looks like the
# pre-change side of the diff; confirm it is removed in the final file.
PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", ASCEND_TOOLKIT_HOME)


def _resolve_pto_lib_path() -> str:
    """Pick the PTO-ISA header root for bisheng's -isystem flag.

    Resolution order:
      1. ``$PTO_LIB_PATH`` (explicit override; an empty value counts as unset).
      2. The pto-kernels CMake FetchContent mirror at
         ``<repo>/build/_deps/libpto_isa_headers-src`` (pinned in
         top-level CMakeLists.txt; populated by any cmake build).
      3. ``$ASCEND_TOOLKIT_HOME`` (CANN default; may lack newer
         instructions such as TCOLEXPANDDIV on CANN 8.5.0).

    Returns:
        The selected directory as a string.
    """
    override = os.environ.get("PTO_LIB_PATH")
    if override:
        return override

    # examples/jit_cpp/<example>/jit_util_common.py → parents[3] is the repo root
    ancestors = Path(__file__).resolve().parents
    # A copy of this file placed in a shallow directory has fewer than four
    # ancestors; skip the vendored lookup instead of raising IndexError.
    if len(ancestors) > 3:
        vendored = ancestors[3] / "build" / "_deps" / "libpto_isa_headers-src"
        # Only trust the mirror when the main PTO header is actually present.
        if (vendored / "include" / "pto" / "pto-inst.hpp").is_file():
            return str(vendored)
    return ASCEND_TOOLKIT_HOME


# PTO-ISA header root, resolved once at import time (see _resolve_pto_lib_path).
PTO_LIB_PATH = _resolve_pto_lib_path()
# Defaults shared by the JIT examples.
# NOTE(review): what the block dim counts and why N is capped at 16384 are not
# visible in this hunk; confirm before changing either value.
DEFAULT_DEVICE = "npu:0"
DEFAULT_BLOCK_DIM = 20
MAX_HADAMARD_N = 16384
Expand Down
3 changes: 3 additions & 0 deletions examples/jit_cpp/sinkhorn/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.png
*.csv
*.so
60 changes: 60 additions & 0 deletions examples/jit_cpp/sinkhorn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Doubly-Stochastic Sinkhorn Normalization

PTO-ISA kernel for doubly-stochastic Sinkhorn normalization on Ascend NPU.
Implements the DeepSeek MHC pre-processing algorithm: softmax per row,
then alternating row/column normalization until the matrix is approximately
doubly-stochastic.

## Algorithm

```python
def sinkhorn_normalize(x, repeat=10, eps=1e-6):
    x = x.softmax(-1) + eps
    x = x / (x.sum(-2, keepdim=True) + eps)
    for _ in range(repeat - 1):
        x = x / (x.sum(-1, keepdim=True) + eps)
        x = x / (x.sum(-2, keepdim=True) + eps)
    return x
```

Input shape: `(N, K, K)` fp16 — N square matrices of dimension K.
K must be <= 128 (the full K×K matrix lives in UB as fp32).

## Files

| File | Description |
|---|---|
| `kernel_sinkhorn.cpp` | PTO-ISA C++ kernel (JIT-compiled via bisheng) |
| `jit_util_sinkhorn.py` | Python wrapper — compiles & exposes the kernel |
| `conftest.py` | pytest NPU device fixture |
| `test_sinkhorn.py` | Correctness tests vs PyTorch reference |
| `bench_sinkhorn.py` | Benchmark PTO vs PyTorch |
| `plot_sinkhorn.py` | Plot benchmark CSVs |

## Usage

```bash
# compile + test
python -m pytest test_sinkhorn.py -v --npu=npu:0

# benchmark
python bench_sinkhorn.py

# re-plot from saved CSV
python plot_sinkhorn.py
```

## Reproducing

```bash
cd examples/jit_cpp/sinkhorn

# tests (73 cases: shapes × repeats × seeds)
python -m pytest test_sinkhorn.py -v --npu=npu:0

# benchmark (batch × K grid, eps=1e-6; bench_sinkhorn.py runs the PTO kernel with repeat=8 and the PyTorch reference with repeat=10)
python bench_sinkhorn.py --batches 1 4 8 16 32 64 --hidden-dims 4 8 16 32 64 128

# plots
python plot_sinkhorn.py
```
238 changes: 238 additions & 0 deletions examples/jit_cpp/sinkhorn/bench_sinkhorn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
# pylint: disable=wrong-import-position
"""
Benchmark PTO doubly-stochastic Sinkhorn against PyTorch reference.

Writes:
outputs/csv/sinkhorn_compare_bd{block_dim}.csv
outputs/plots/ (via plot_sinkhorn.py)
"""
import argparse
import sys
from pathlib import Path

import torch
import torch_npu # noqa

# Put the sibling fast_hadamard example directory on sys.path so the helper
# modules imported right after this block can be found there.
# NOTE(review): presumably both bench_common and jit_util_common live in
# fast_hadamard; confirm against the repository layout.
THIS_DIR = Path(__file__).resolve().parent
FAST_HADAMARD_DIR = THIS_DIR.parent / "fast_hadamard"
if str(FAST_HADAMARD_DIR) not in sys.path:
    # Insert at the front so these helpers win over same-named modules elsewhere.
    sys.path.insert(0, str(FAST_HADAMARD_DIR))

from bench_common import ( # noqa: E402
add_common_benchmark_args,
benchmark_npu_us,
benchmark_trials_us,
ensure_output_dir,
make_buffer_pool,
pool_item,
resolve_dir_arg,
validate_benchmark_args,
write_csv_records,
)

from jit_util_common import get_current_stream_ptr # noqa: E402
from jit_util_sinkhorn import jit_compile # noqa: E402

# Default warmup / timed iteration counts handed to add_common_benchmark_args.
DEFAULT_WARMUP = 10
DEFAULT_REPEATS = 100
# Sinkhorn iteration counts for the PTO kernel and the PyTorch reference.
# NOTE(review): the PTO kernel runs 8 iterations while the PyTorch baseline runs
# 10, so the reported speedup compares unequal amounts of work. Confirm this is
# intentional, or align the two values.
SINKHORN_REPEAT = 8
TORCH_REF_REPEAT = 10 # fixed for consistent baseline
SINKHORN_EPS = 1e-6
BYTES_PER_ELEMENT = 2 # fp16

# Header row for the CSV written by benchmark(); the column order must match the
# record f-string built there.
# NOTE(review): the second column is labelled "N" but is filled with the matrix
# dimension K. Presumably kept for the plotting script; verify before renaming.
CSV_HEADER = (
"batch,N,pto_duration_us,torch_duration_us,"
"pto_bandwidth_gbs,torch_bandwidth_gbs,pto_speedup_vs_torch,"
"trials,pto_duration_mean_us,pto_duration_std_us,pto_duration_min_us,"
"pto_duration_max_us,pto_duration_cv_pct,torch_duration_mean_us,"
"torch_duration_std_us,torch_duration_min_us,"
"torch_duration_max_us,torch_duration_cv_pct\n"
)


def sinkhorn_ref(x, repeat=10, eps=1e-6):
    """PyTorch reference for doubly-stochastic Sinkhorn normalization.

    The computation runs in float32 on whatever device ``x`` lives on and is
    cast back to float16 at the end.

    Args:
        x: Tensor whose last two dimensions form the matrices to normalize.
        repeat: Total number of column normalizations. The first follows the
            softmax directly; each of the remaining ``repeat - 1`` rounds does
            a row normalization followed by a column normalization.
        eps: Small constant added after the softmax and to every denominator.

    Returns:
        A float16 tensor with the same shape as ``x``.
    """
    x = x.float()
    # Row-wise softmax makes every entry positive before the alternating scaling.
    x = x.softmax(-1) + eps
    x = x / (x.sum(-2, keepdim=True) + eps)
    for _ in range(repeat - 1):
        x = x / (x.sum(-1, keepdim=True) + eps)
        x = x / (x.sum(-2, keepdim=True) + eps)
    return x.half()


def _parse_args():
    """Build the command-line parser for this benchmark and parse ``sys.argv``.

    Adds the script-specific ``--no-cache-stream`` switch, then lets
    ``add_common_benchmark_args`` register the shared benchmark options with
    this script's warmup and repeat defaults.

    Returns:
        The parsed argument namespace; ``cache_stream`` is True unless
        ``--no-cache-stream`` was given.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark PTO Sinkhorn (doubly-stochastic) against PyTorch reference."
    )
    parser.add_argument(
        "--no-cache-stream",
        dest="cache_stream",
        action="store_false",
        help="Disable cached stream pointer reuse for PTO launches.",
    )
    parser.set_defaults(cache_stream=True)
    return add_common_benchmark_args(
        parser,
        default_warmup=DEFAULT_WARMUP,
        default_repeats=DEFAULT_REPEATS,
    ).parse_args()


def _effective_bandwidth_gbs(batch, K, duration_us):
    """Return the effective memory bandwidth in GB/s for one benchmark run.

    Traffic is modelled as one read plus one write of each K×K matrix, at
    ``BYTES_PER_ELEMENT`` bytes per element, for ``batch`` matrices.

    Args:
        batch: Number of matrices processed.
        K: Matrix dimension.
        duration_us: Measured duration in microseconds.

    Returns:
        Bandwidth in GB/s (1 GB = 1e9 bytes), or 0.0 when ``duration_us`` is
        not positive.
    """
    if duration_us <= 0:
        # Guard against division by zero on a degenerate timing result.
        return 0.0
    # read K*K + write K*K
    data_bytes = batch * 2 * K * K * BYTES_PER_ELEMENT
    return (data_bytes / 1e9) / (duration_us / 1e6)


def _make_shape_pools(batch, K, warmup, repeats, device):
    """Create the input and output buffer pools for one (batch, K) shape.

    Args:
        batch: Number of matrices per tensor.
        K: Matrix dimension; every tensor has shape ``(batch, K, K)``.
        warmup: Warmup iteration count, forwarded to ``make_buffer_pool``.
        repeats: Timed iteration count, forwarded to ``make_buffer_pool``.
        device: Device on which the tensors are allocated.

    Returns:
        A dict with two pools: ``"x"`` built from random float16 inputs and
        ``"y"`` built from uninitialized float16 output buffers.
    """
    return {
        "x": make_buffer_pool(
            warmup,
            repeats,
            lambda: torch.randn(batch, K, K, device=device, dtype=torch.float16),
        ),
        "y": make_buffer_pool(
            warmup,
            repeats,
            lambda: torch.empty(batch, K, K, device=device, dtype=torch.float16),
        ),
    }


def benchmark(
    sinq_func,
    *,
    warmup,
    repeats,
    trials,
    output_dir,
    device,
    batches,
    hidden_dims,
    stream_ptr=None,
):
    """Time the PTO kernel against the PyTorch reference over a shape grid.

    For every ``(batch, K)`` pair the function builds buffer pools, times the
    PTO kernel and ``sinkhorn_ref``, prints one table row, and collects one CSV
    record. All records are written at the end to
    ``output_dir / "sinkhorn_compare_bd{block_dim}.csv"``.

    Args:
        sinq_func: Compiled kernel callable. It must expose ``block_dim`` and
            accept ``(x, y, repeat=..., eps=..., stream_ptr=...)``.
        warmup: Warmup iterations per measurement.
        repeats: Timed iterations per measurement.
        trials: Number of measurements aggregated per shape.
        output_dir: Directory (path object) that receives the CSV file.
        device: Device on which the benchmark tensors are allocated.
        batches: Iterable of batch sizes.
        hidden_dims: Iterable of matrix dimensions K.
        stream_ptr: Optional stream pointer forwarded to every kernel launch.
    """
    ensure_output_dir(output_dir)
    block_dim = sinq_func.block_dim

    print(f"\n{'=' * 92}")
    print(
        f"SINKHORN DS BENCHMARK (BLOCK_DIM={block_dim}, pto_repeat={SINKHORN_REPEAT}, torch_repeat={TORCH_REF_REPEAT})"
    )
    print(f"{'=' * 92}")
    header = (
        f"{'batch':>6s} {'K':>6s}"
        f" {'pto_us':>10s} {'torch_us':>10s}"
        f" {'pto_bw(GB/s)':>12s} {'torch_bw(GB/s)':>14s} {'pto_speedup':>11s}"
    )
    print(header)
    print("-" * len(header))

    records = []
    for batch in batches:
        for K in hidden_dims:
            pools = _make_shape_pools(batch, K, warmup, repeats, device)
            x_list = pools["x"]
            y_list = pools["y"]

            # The pools are bound as lambda defaults so each closure keeps the
            # buffers of this iteration rather than whatever the loop holds later.
            pto_stats = benchmark_trials_us(
                trials,
                lambda x_list=x_list, y_list=y_list: benchmark_npu_us(
                    warmup,
                    repeats,
                    lambda i: sinq_func(
                        pool_item(x_list, i),
                        pool_item(y_list, i),
                        repeat=SINKHORN_REPEAT,
                        eps=SINKHORN_EPS,
                        stream_ptr=stream_ptr,
                    ),
                ),
            )
            torch_stats = benchmark_trials_us(
                trials,
                lambda x_list=x_list: benchmark_npu_us(
                    warmup,
                    repeats,
                    lambda i: sinkhorn_ref(
                        pool_item(x_list, i),
                        repeat=TORCH_REF_REPEAT,
                        eps=SINKHORN_EPS,
                    ),
                ),
            )

            # Headline numbers use the median across trials.
            pto_us = pto_stats["median_us"]
            torch_us = torch_stats["median_us"]
            pto_bw = _effective_bandwidth_gbs(batch, K, pto_us)
            torch_bw = _effective_bandwidth_gbs(batch, K, torch_us)
            pto_speedup = torch_us / pto_us if pto_us > 0 else 0.0

            print(
                f"{batch:>6d} {K:>6d}"
                f" {pto_us:>10.2f} {torch_us:>10.2f}"
                f" {pto_bw:>12.4f} {torch_bw:>14.4f}"
                f" {pto_speedup:>11.3f}"
            )

            # Field order must match CSV_HEADER.
            records.append(
                f"{batch},{K},{pto_us:.4f},{torch_us:.4f},"
                f"{pto_bw:.6f},{torch_bw:.6f},"
                f"{pto_speedup:.4f},"
                f"{trials},{pto_stats['mean_us']:.4f},{pto_stats['std_us']:.4f},"
                f"{pto_stats['min_us']:.4f},{pto_stats['max_us']:.4f},"
                f"{pto_stats['cv_pct']:.4f},{torch_stats['mean_us']:.4f},"
                f"{torch_stats['std_us']:.4f},"
                f"{torch_stats['min_us']:.4f},"
                f"{torch_stats['max_us']:.4f},"
                f"{torch_stats['cv_pct']:.4f}"
            )

    csv_path = output_dir / f"sinkhorn_compare_bd{block_dim}.csv"
    write_csv_records(csv_path, CSV_HEADER, records)
    print(f"\nSaved to {csv_path}")


def main():
    """Parse arguments, JIT-compile the kernel, and run the benchmark sweep."""
    args = _parse_args()
    validate_benchmark_args(args)

    torch.npu.set_device(args.npu)
    base = THIS_DIR
    kernel_path = base / "kernel_sinkhorn.cpp"
    csv_dir = resolve_dir_arg(base, args.csv_dir)

    print(f"Using device: {args.npu}")
    print("Compiling kernel_sinkhorn.cpp ...")
    sinq_func = jit_compile(str(kernel_path), verbose=True, device=args.npu)
    # Fetch the stream pointer once and reuse it for every launch unless the
    # user passed --no-cache-stream.
    stream_ptr = get_current_stream_ptr() if args.cache_stream else None
    if stream_ptr is not None:
        print("Using cached NPU stream pointer for PTO launches.")

    # Defaults sweep both axes: the batch (num_tokens) sizes below and K from
    # 4 to 128. K=4 is the mHC use case (hc_mult=4): in DeepSeek MHC, sinkhorn
    # always runs on (num_tokens, 4, 4) matrices. The larger K values exercise
    # the general fallback path. Pass --batches / --hidden-dims to override.
    batches = args.batches or [
        1, 4, 16, 64, 256, 512, 1024, 2048, 4096, 8192, 16384, 65536
    ]
    dims = args.hidden_dims or [4, 8, 16, 32, 64, 128]

    benchmark(
        sinq_func,
        warmup=args.warmup,
        repeats=args.repeats,
        trials=args.trials,
        output_dir=csv_dir,
        device=args.npu,
        batches=batches,
        hidden_dims=dims,
        stream_ptr=stream_ptr,
    )


if __name__ == "__main__":
    main()
40 changes: 40 additions & 0 deletions examples/jit_cpp/sinkhorn/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
import torch


def normalize_npu_device(device: str | int) -> str:
text = str(device).strip().strip('"').strip("'")
if text.lower().startswith("npu:"):
index = text.split(":", 1)[1].strip()
else:
index = text

if not index.isdigit():
raise ValueError(
f"Invalid NPU device '{device}'. Expected values like 0 or npu:0."
)
return f"npu:{int(index)}"


def pytest_addoption(parser):
    """Register the ``--npu`` command-line option with pytest.

    Args:
        parser: The pytest option parser passed to this hook.

    Raises:
        ValueError: Re-raised when registration fails for a reason that does
            not mention ``--npu``.
    """
    try:
        parser.addoption(
            "--npu",
            action="store",
            default="npu:0",
            help="NPU device (examples: 0, npu:0, '0', 'npu:0').",
        )
    except ValueError as exc:
        # Swallow only the error that names --npu — presumably raised when
        # another conftest has already registered the option; verify.
        if "--npu" not in str(exc):
            raise


@pytest.fixture(scope="session")
def npu_device(request):
    """Session fixture returning the ``--npu`` option as ``"npu:<index>"``."""
    raw = request.config.getoption("--npu")
    return normalize_npu_device(raw)


@pytest.fixture(scope="session", autouse=True)
def setup_npu_device(npu_device):
    """Autouse session fixture that selects the requested NPU once per test session."""
    torch.npu.set_device(npu_device)
Loading
Loading