Sinkhorn dynamic multicore example #117
Open
MirkoDeVita98 wants to merge 3 commits into
added 2 commits
April 21, 2026 21:57
…he reference.cpp produced so file
Comment on lines +238 to +242
# ============================================================
# Per-batch loop — workers split N across all vector cores.
# ============================================================
for bi in pto.range(wid, N, num_workers):
    # Init mu1, mu2, invMu1 to all-ones via muls(.,0)+adds(.,1).
Collaborator
The C++ reference will be modified to load multiple samples to get higher memory bandwidth utilization, since each sample is too small: huawei-csl/pto-kernels#134 (comment)
Collaborator
ptoas 0.29 image works
Adds a PTODSL implementation of the fp16 Sinkhorn normalization kernel under
`examples/aot/sinkhorn_dynamic_multicore/`, plus a JIT helper, a 66-case correctness test, and a benchmark that compares it against PyTorch fp16 and the hand-tuned C++ reference from (`reference.cpp`).
Differences vs `reference.cpp`
The reference is hand-tuned to squeeze the last cycle out of the hardware.
The PTODSL version trades a little of that for clarity:
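| | `reference.cpp` | PTODSL |
| --- | --- | --- |
| `L` specialisation | `runSinkhornImpl<T, TileL>` switch | `MAX_DIM = 256` column stride |
| `inv_mu1` broadcast | `[ROW_CHUNK, L]` UB buffer | `tile.col_expand_mul` |
| `pow(x, lr)` | approx `Ln` + `TEXP` | `tile.log` / `tile.exp` |
| Synchronisation | `set_flag` / `wait_flag` | `ptoas --enable-insert-sync` (auto) |
| Buffer allocation | | `pto.alloc_tile` per logical buffer |

Correctness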
Both kernels pass the same 66-case suite. Tightened tolerances measured
on every case:
Default tolerance in `run_sinkhorn.py` is now `rtol=2e-3, atol=1e-3` (the PTODSL kernel passes comfortably). For the reference, pass `--rtol 5e-2 --atol 1e-2` (matching the upstream torch_npu test). The extra precision comes from using the hardware `tile.log` instead of the 2-term Padé approximation the reference relies on.
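To make the two tolerance settings concrete, here is a minimal sketch of an `rtol`/`atol` check applied to an approximated logarithm. Two things are assumptions: the check uses the `|actual - expected| <= atol + rtol * |expected|` rule (the form used by `numpy.allclose`), which may differ from what `run_sinkhorn.py` does internally, and `pade_log` is the textbook [1/1] Padé approximant of ln(x) around x = 1, chosen only for illustration and not taken from `reference.cpp`.

```python
import math


def within_tolerance(actual, expected, rtol, atol):
    """True if every element satisfies |a - e| <= atol + rtol * |e|."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


def pade_log(x):
    """[1/1] Pade approximant of ln(x) around x = 1 (illustrative assumption)."""
    u = x - 1.0
    return 2.0 * u / (2.0 + u)


# Hypothetical sample points near 1, where the approximant is most accurate.
xs = [0.75, 1.0, 1.25, 1.5]
exact = [math.log(x) for x in xs]
approx = [pade_log(x) for x in xs]

# Tight setting (the PTODSL default) vs the looser setting used for the reference.
tight_ok = within_tolerance(approx, exact, rtol=2e-3, atol=1e-3)
loose_ok = within_tolerance(approx, exact, rtol=5e-2, atol=1e-2)
print(f"tight: {tight_ok}, loose: {loose_ok}")
```

In this toy setup the approximated values only pass the looser setting, which is the kind of gap the two tolerance pairs above are meant to absorb.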
Performance
Ascend 910B2, fp16, `order = 8`, 5 warmup + 20 timed runs.
Single-matrix latency (head-shapes grid)
PTODSL is 10–55× faster than torch fp16 and 1.40–1.78× faster than
the hand-tuned reference C++ across every shape:
PTODSL vs torch fp16 (speedup):

PTODSL vs reference C++ (speedup):

Effective bandwidth and compute throughput:


Batched vs serial (K = L = 128)
Per-matrix latency is flat from N = 1 to N = 32 (perfect multicore fill,
24 cores × 2 vector units), then linear. Batched PTODSL stays
~1.55–1.63× ahead of the reference at every batch size:
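The flat region is what the strided per-batch loop (`for bi in pto.range(wid, N, num_workers)`) would predict. The sketch below models that split in plain Python; `batches_for_worker` and `waves` are hypothetical helper names, and the worker count of 48 is simply the 24 cores × 2 vector units quoted above.

```python
import math

NUM_WORKERS = 24 * 2  # 24 cores x 2 vector units, as quoted above


def batches_for_worker(wid, n, num_workers):
    """Batch indices handled by worker `wid` under a strided split."""
    return list(range(wid, n, num_workers))


def waves(n, num_workers):
    """Sequential rounds needed: the largest number of batches any worker gets."""
    return math.ceil(n / num_workers)


for n in (1, 32, 48, 96):
    per_worker = [batches_for_worker(w, n, NUM_WORKERS) for w in range(NUM_WORKERS)]
    busiest = max(len(b) for b in per_worker)
    print(f"N={n:3d}: busiest worker handles {busiest} batch(es)")
```

As long as N does not exceed the worker count, every worker handles at most one matrix and the batch finishes in a single wave; beyond that, the number of waves grows with N, which is consistent with the flat-then-linear shape.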
Why PTODSL is faster
The reference comment estimates "8 barriers per phase, not 14". With
`ptoas --enable-insert-sync` the compiler picks the minimum based on actual SSA dataflow, while a human writing manual flags has to be conservative and over-syncs. On top of that, native `tile.log` / `tile.exp` remove the Padé scratch tiles and their barriers, and `tile.col_expand_mul` removes the `INV_MU1_TILED` scratch (the largest single buffer in the reference's UB layout).
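The scratch-buffer saving can be illustrated with a plain-Python model. This is only a sketch of the idea, not the PTO ops themselves: it assumes `inv_mu1` holds one value per column and is broadcast down the rows of a `[ROW_CHUNK, L]` tile, and both function names are hypothetical.

```python
def tiled_multiply(tile, inv_mu1):
    """Materialise inv_mu1 into a full [rows, L] scratch, then multiply element-wise."""
    scratch = [list(inv_mu1) for _ in tile]  # extra rows * L elements of storage
    return [[t * s for t, s in zip(t_row, s_row)] for t_row, s_row in zip(tile, scratch)]


def broadcast_multiply(tile, inv_mu1):
    """Multiply each row by inv_mu1 directly; no scratch buffer is allocated."""
    return [[t * s for t, s in zip(row, inv_mu1)] for row in tile]


# Hypothetical 2 x 3 tile and per-column scale factors.
tile = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0]]
inv_mu1 = [0.5, 0.25, 2.0]

print(tiled_multiply(tile, inv_mu1))
print(broadcast_multiply(tile, inv_mu1))
```

Both paths produce the same values; the second simply never allocates the `[ROW_CHUNK, L]` copy.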
Workarounds
Two limitations of the current stack forced a small amount of extra
work in our builder; both are pure boilerplate and could be removed
by PTOAS fixes:
1. `pto.subview` narrows the valid shape but reuses the parent's storage `Numel`, so a downstream `tile.reshape` fails the bisheng `TRESHAPE` byte-size `static_assert`. Worked around by copying the 8 mu2 elements into a static `[1, ROW_CHUNK]` tile before reshaping. Suggested fix: have `pto.subview` rewrite the result tile-buf storage shape when the slice sizes are static.
2. `tile.min` / `tile.adds` / `tile.reshape` require `GetValidRow/Col` even on a conceptually 1×1 tile, and there is no row-major scalar type accepted by `tile.row_expand_div` as a broadcast source. Worked around with a `[8, 1]` / `[1, 8]` tile holding `valid_shape = [1, 1]`. Suggested fix: lower these ops over a fully-static 1×1 tile via the immediate-form intrinsic.
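The first workaround is easier to follow with a toy model of the size check. Everything below is hypothetical: `ToyTile` and the `toy_*` helpers are not part of PTODSL or PTOAS, the `[1, 256]` parent shape is made up, and the check merely imitates a byte-size comparison on storage shapes as described above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTile:
    storage_shape: tuple  # what the byte-size check looks at
    valid_shape: tuple    # the logically valid region
    elem_bytes: int = 2   # fp16

    def storage_bytes(self):
        rows, cols = self.storage_shape
        return rows * cols * self.elem_bytes


def toy_subview(parent, valid_shape):
    """Narrow the valid shape but keep the parent's storage shape."""
    return ToyTile(parent.storage_shape, valid_shape, parent.elem_bytes)


def toy_reshape(tile, new_shape):
    """Reshape only if source and destination storage have the same byte size."""
    target = ToyTile(new_shape, new_shape, tile.elem_bytes)
    if target.storage_bytes() != tile.storage_bytes():
        raise ValueError("reshape byte-size mismatch")
    return target


def toy_copy_to_static(tile, shape):
    """Copy the valid elements into a fresh tile whose storage matches `shape`."""
    return ToyTile(shape, shape, tile.elem_bytes)


parent = ToyTile((1, 256), (1, 256))
view = toy_subview(parent, (1, 8))       # valid 1 x 8, storage still 1 x 256

try:
    toy_reshape(view, (8, 1))            # 16 bytes vs 512 bytes
    direct_ok = True
except ValueError:
    direct_ok = False

static = toy_copy_to_static(view, (1, 8))
reshaped = toy_reshape(static, (8, 1))   # 16 bytes on both sides
print(direct_ok, reshaped.valid_shape)
```

In this model the suggested fix corresponds to `toy_subview` returning a tile whose storage shape already equals the slice shape, which would make the copy unnecessary.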
A third minor item: K-indexed quantities are forced into row-major
because none of `TMul` / `TSub` / `TMin` / `TLog` / `TExp` / `TSqrt` / `TAddS` / `TRowMin` / `TExpandDiv` accepts a layout-override attribute. Adding one would let the builder keep them column-major and drop the per-chunk reshape entirely.
How to reproduce