Synkhorn dynamic multicore example #117

Open. MirkoDeVita98 wants to merge 3 commits into main from synkorn

MirkoDeVita98 (Collaborator) commented Apr 21, 2026

Adds a PTODSL implementation of the fp16 Sinkhorn normalization kernel under examples/aot/sinkhorn_dynamic_multicore/, plus a JIT helper, a 66-case correctness test, and a benchmark that compares it against PyTorch fp16 and the hand-tuned C++ reference (reference.cpp).

Differences vs reference.cpp

The reference is hand-tuned to squeeze the last cycle out of the hardware.
The PTODSL version trades a little of that for clarity:

| Concern | reference.cpp | PTODSL builder |
| --- | --- | --- |
| Per-L specialisation | Templated runSinkhornImpl<T, TileL> switch | Single MAX_DIM = 256 column stride |
| inv_mu1 broadcast | Pre-tiled into a [ROW_CHUNK, L] UB buffer | One tile.col_expand_mul |
| pow(x, lr) | 2-term Padé approxLn + TEXP | Native tile.log / tile.exp |
| Synchronisation barriers | Manual set_flag / wait_flag | ptoas --enable-insert-sync (auto) |
| Hand-laid UB layout | Explicit byte offsets per buffer | pto.alloc_tile per logical buffer |
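
The pow(x, lr) row rests on the identity x^lr = exp(lr · ln x) for x > 0. The sketch below shows that identity in plain Python and how a low-order ln approximation feeds error into the result. `approx_ln_2term` is a hypothetical stand-in (two terms of the artanh series for ln), not the formula used in reference.cpp, which this PR does not show.

```python
import math

def pow_via_exp_log(x: float, lr: float) -> float:
    """x**lr for x > 0, computed as exp(lr * ln(x))."""
    if x <= 0.0:
        raise ValueError("x must be positive")
    return math.exp(lr * math.log(x))

def approx_ln_2term(x: float) -> float:
    """Hypothetical low-order ln: first two terms of 2*artanh((x-1)/(x+1))."""
    z = (x - 1.0) / (x + 1.0)
    return 2.0 * (z + z**3 / 3.0)

def pow_via_approx_ln(x: float, lr: float) -> float:
    """Same identity, but with the approximate ln plugged in."""
    return math.exp(lr * approx_ln_2term(x))

# Compare the two paths on a few inputs; the gap grows away from x = 1.
for x in (0.5, 1.0, 2.0, 4.0):
    exact = pow_via_exp_log(x, 0.5)
    approx = pow_via_approx_ln(x, 0.5)
    print(f"x={x}: exact={exact:.6f} approx={approx:.6f} abs_err={abs(exact - approx):.2e}")
```

The approximation is exact at x = 1 and degrades as x moves away from it, which is the general reason a native log can tighten the end-to-end error.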

Correctness

Both kernels pass the same 66-case suite. Worst errors observed across
all cases:

| kernel | worst abs error | worst rel error |
| --- | --- | --- |
| PTODSL | 9.8e-04 | 9.8e-04 |
| reference C++ | 5.9e-03 | 4.1e-03 |

Default tolerance in run_sinkhorn.py is now rtol=2e-3, atol=1e-3 (the
PTODSL kernel passes comfortably). For the reference, pass
--rtol 5e-2 --atol 1e-2 (matching the upstream torch_npu test). The
extra precision comes from using the hardware tile.log instead of the
2-term Padé approximation the reference relies on.
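
To make the tolerance numbers concrete, here is a minimal pure-Python sketch of an allclose-style check and of how worst absolute and relative errors can be collected. It assumes run_sinkhorn.py combines the two tolerances the way `numpy.allclose` and `torch.allclose` do (|a - e| <= atol + rtol * |e|); the helper names are illustrative and not taken from the script.

```python
def worst_errors(actual, expected):
    """Return (worst absolute error, worst relative error) over paired values."""
    worst_abs = 0.0
    worst_rel = 0.0
    for a, e in zip(actual, expected):
        abs_err = abs(a - e)
        worst_abs = max(worst_abs, abs_err)
        if e != 0.0:  # relative error is undefined where the expected value is 0
            worst_rel = max(worst_rel, abs_err / abs(e))
    return worst_abs, worst_rel

def within_tolerance(actual, expected, rtol=2e-3, atol=1e-3):
    """allclose-style check: |a - e| <= atol + rtol * |e| for every element."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))

# An error of about 9e-4 on a value near 1 passes the default tolerance;
# an error of about 5.9e-3 needs the looser reference setting.
print(within_tolerance([1.0009], [1.0]))
print(within_tolerance([1.0059], [1.0]))
print(within_tolerance([1.0059], [1.0], rtol=5e-2, atol=1e-2))
```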

Performance

Ascend 910B2, fp16, order = 8, 5 warmup + 20 timed runs.
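
The "5 warmup + 20 timed runs" protocol can be sketched as below. This is a host-side illustration only: bench_sinkhorn.py is not shown in this PR description, so the function name, the use of `time.perf_counter`, and reporting the median are all assumptions. On an accelerator the device would also need to be synchronised before each clock read, which is omitted here.

```python
import statistics
import time

def time_kernel(fn, warmup=5, repeats=20):
    """Run fn `warmup` times untimed, then `repeats` times timed.

    Returns (median seconds, list of per-run seconds).
    """
    for _ in range(warmup):
        fn()  # warm caches / JIT; results discarded
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples), samples

# Usage with a stand-in workload.
median_s, runs = time_kernel(lambda: sum(range(10_000)))
print(f"median over {len(runs)} runs: {median_s * 1e6:.1f} us")
```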

Single-matrix latency (head-shapes grid)

PTODSL is 10–55× faster than torch fp16 and 1.40–1.78× faster than
the hand-tuned reference C++ across every shape:

PTODSL vs torch fp16 (speedup):
[Figure: head_shapes_speedup_pto_vs_torch]

PTODSL vs reference C++ (speedup):
[Figure: head_shapes_speedup_pto_vs_ref]

Effective bandwidth and compute throughput:
[Figure: head_shapes_bandwidth]
[Figure: head_shapes_flops]

Batched vs serial (K = L = 128)

Per-matrix latency is flat from N = 1 to N = 32 (perfect multicore fill,
24 cores × 2 vector units), then linear. Batched PTODSL stays
~1.55–1.63× ahead of the reference at every batch size:

[Figure: batched_vs_serial_log]
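
The flat-then-linear shape is what a round-robin split such as the builder's `for bi in pto.range(wid, N, num_workers)` loop would produce. A rough model in plain Python, assuming num_workers = 48 (24 cores × 2 vector units, as stated above) and ignoring memory-bandwidth contention:

```python
import math

NUM_WORKERS = 24 * 2  # assumption: one worker per vector unit

def assigned_batches(wid, n, num_workers=NUM_WORKERS):
    """Batch indices handled by worker `wid` under a strided (round-robin) split."""
    return list(range(wid, n, num_workers))

def waves(n, num_workers=NUM_WORKERS):
    """Sequential rounds needed: the busiest worker processes this many matrices."""
    return math.ceil(n / num_workers)

for n in (1, 8, 32, 64, 128):
    print(f"N={n}: waves={waves(n)}, worker 0 handles {assigned_batches(0, n)}")
```

In this model every batch size up to 48 completes in a single wave, and N = 64 is the first power of two that needs a second one.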

Why PTODSL is faster

The reference comment estimates "8 barriers per phase, not 14". With
ptoas --enable-insert-sync the compiler picks the minimum based on
actual SSA dataflow, while a human writing manual flags has to be conservative and over-syncs. On top of that, native tile.log / tile.exp removes the Padé scratch tiles and their barriers, and tile.col_expand_mul
removes the INV_MU1_TILED scratch (the largest single buffer in the
reference's UB layout).
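
As a toy illustration of why dataflow-driven sync insertion can need fewer barriers than a conservative manual scheme, the sketch below counts syncs under two policies for a short op sequence. Both policies, the op list, and the pipe labels are invented for illustration; neither is the algorithm ptoas uses, and the model considers only read-after-write dependencies.

```python
from typing import NamedTuple

class Op(NamedTuple):
    name: str
    pipe: str      # illustrative hardware pipeline label
    reads: tuple   # buffer names read
    writes: tuple  # buffer names written

def conservative_syncs(ops):
    """One sync at every pipe switch between consecutive ops."""
    return sum(1 for prev, cur in zip(ops, ops[1:]) if prev.pipe != cur.pipe)

def dataflow_syncs(ops):
    """One sync per cross-pipe read-after-write producer/consumer pair."""
    last_writer = {}  # buffer name -> index of the op that last wrote it
    pairs = set()
    for i, op in enumerate(ops):
        for buf in op.reads:
            j = last_writer.get(buf)
            if j is not None and ops[j].pipe != op.pipe:
                pairs.add((j, i))
        for buf in op.writes:
            last_writer[buf] = i
    return len(pairs)

program = [
    Op("load_a", "MTE2", (), ("a",)),
    Op("exp_c", "V", ("c",), ("c",)),
    Op("load_b", "MTE2", (), ("b",)),
    Op("log_c", "V", ("c",), ("c",)),
    Op("store_c", "MTE3", ("c",), ()),
    Op("mul_ab", "V", ("a", "b"), ("a",)),
]

print("conservative:", conservative_syncs(program))
print("dataflow:", dataflow_syncs(program))
```

The interleaved loads do not feed the vector ops that sit next to them, so a policy that looks only at adjacency inserts syncs the data dependencies never required.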

Workarounds

Two limitations of the current stack forced a small amount of extra
work in our builder; both are pure boilerplate and could be removed
by PTOAS fixes:

  1. Subview → reshape storage mismatch. pto.subview narrows the
    valid shape but reuses the parent's storage Numel, so a
    downstream tile.reshape fails the bisheng TRESHAPE byte-size
    static_assert. Worked around by copying the 8 mu2 elements into a
    static [1, ROW_CHUNK] tile before reshaping. Suggested fix: have
    pto.subview rewrite the result tile-buf storage shape when the
    slice sizes are static.
  2. 1×1 scalar tiles need dynamic valid-shape metadata. tile.min /
    tile.adds / tile.reshape require GetValidRow/Col even on a
    conceptually 1×1 tile, and there is no row-major scalar type accepted
    by tile.row_expand_div as a broadcast source. Worked around with a
    [8, 1] / [1, 8] tile holding valid_shape = [1, 1]. Suggested
    fix: lower these ops over a fully-static 1×1 tile via the
    immediate-form intrinsic.
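
The first workaround can be modelled in a few lines of plain Python. This is a toy model of the behaviour described in item 1, not PTODSL code: the `Tile` class and helper functions are hypothetical, and the [1, 256] parent shape and ROW_CHUNK = 8 are assumptions chosen to match the "8 mu2 elements" mentioned above.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Tile:
    storage_shape: tuple  # elements physically allocated
    valid_shape: tuple    # elements logically in use

    @property
    def storage_numel(self):
        rows, cols = self.storage_shape
        return rows * cols

def subview(parent, valid_shape):
    """Narrow the valid shape only; the parent's storage shape is reused."""
    return Tile(parent.storage_shape, valid_shape)

def reshape(tile, new_shape):
    """Reshape is only legal when the element count matches the storage."""
    rows, cols = new_shape
    if rows * cols != tile.storage_numel:
        raise ValueError(
            f"reshape to {new_shape} needs {rows * cols} elements, "
            f"storage has {tile.storage_numel}"
        )
    return Tile(new_shape, new_shape)

def copy_to_static(tile, storage_shape):
    """Workaround: copy the valid elements into a tile with its own storage."""
    return Tile(storage_shape, tile.valid_shape)

ROW_CHUNK = 8                                   # assumption
parent = Tile((1, 256), (1, 256))               # assumption: MAX_DIM-wide row
mu2_slice = subview(parent, (1, ROW_CHUNK))

try:
    reshape(mu2_slice, (ROW_CHUNK, 1))          # storage numel is still 256
except ValueError as err:
    print("direct reshape fails:", err)

fixed = reshape(copy_to_static(mu2_slice, (1, ROW_CHUNK)), (ROW_CHUNK, 1))
print("after copy:", fixed)
```

The suggested PTOAS fix corresponds to having `subview` return a tile whose storage shape already equals the slice size when that size is static, which would make the copy unnecessary.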

A third minor item: K-indexed quantities are forced into row-major
because none of TMul/TSub/TMin/TLog/TExp/TSqrt/TAddS/TRowMin/TExpandDiv accepts a layout-override attribute. Adding one would let the builder keep them column-major and drop the per-chunk reshape entirely.

How to reproduce

cd examples/aot/sinkhorn_dynamic_multicore

# Correctness (PTODSL kernel)
./compile.sh
python run_sinkhorn.py --lib ./sinkhorn_lib.so

# Benchmark (JIT-compiles both kernels, writes CSV + plots under outputs/)
python bench_sinkhorn.py

Comment on lines +238 to +242
# ============================================================
# Per-batch loop — workers split N across all vector cores.
# ============================================================
for bi in pto.range(wid, N, num_workers):
    # Init mu1, mu2, invMu1 to all-ones via muls(.,0)+adds(.,1).

The C++ reference will be modified to load multiple samples at once to improve memory-bandwidth utilization, since each sample is too small: huawei-csl/pto-kernels#134 (comment)

learning-chip (Collaborator) commented Apr 23, 2026

With ptoas 0.27 I am getting:

 bash ./compile.sh 
Traceback (most recent call last):
  File "/workdir/pto-dsl/examples/aot/sinkhorn_dynamic_multicore/./sinkhorn_builder.py", line 26, in <module>
    from ptodsl import pto, tile, to_ir_module
  File "/workdir/pto-dsl/ptodsl/__init__.py", line 1, in <module>
    from . import pto, scalar, tile
  File "/workdir/pto-dsl/ptodsl/pto.py", line 1, in <module>
    from .api import pto as _pto
  File "/workdir/pto-dsl/ptodsl/api/__init__.py", line 1, in <module>
    from . import pto, scalar, tile
  File "/workdir/pto-dsl/ptodsl/api/tile.py", line 299, in <module>
    "int8_sym": _pto.QuantType.INT8_SYM,
                ^^^^^^^^^^^^^^
AttributeError: module 'mlir.dialects.pto' has no attribute 'QuantType'

Need 0.28?

ptoas 0.29 image works
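
The traceback above fails deep inside an import because the installed bindings lack `QuantType`. One way to surface this earlier with a clearer message is an explicit attribute check, sketched below. `require_attr` is a hypothetical helper, not part of ptodsl, and the stand-in module is simulated so the example runs without the real bindings.

```python
from types import SimpleNamespace

def require_attr(module, name, hint):
    """Return module.<name>, or raise ImportError with an actionable hint."""
    if not hasattr(module, name):
        mod_name = getattr(module, "__name__", repr(module))
        raise ImportError(f"module {mod_name!r} has no attribute {name!r}; {hint}")
    return getattr(module, name)

# Stand-in for older bindings that predate QuantType.
old_bindings = SimpleNamespace(__name__="mlir.dialects.pto")

try:
    require_attr(old_bindings, "QuantType", "ptoas 0.29 is reported to work")
except ImportError as err:
    print(err)
```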
