Skip to content

[Feat] Implement doubly-stochastic Sinkhorn normalization kernel#134

Draft
Mocchibird wants to merge 5 commits into
huawei-csl:mainfrom
Mocchibird:feat/sinkhorn
Draft

[Feat] Implement doubly-stochastic Sinkhorn normalization kernel#134
Mocchibird wants to merge 5 commits into
huawei-csl:mainfrom
Mocchibird:feat/sinkhorn

Conversation

@Mocchibird
Copy link
Copy Markdown
Contributor

Sinkhorn pto-isa vs torch

Speedup Bandwidth
speedup_heatmap bandwidth

Comment thread examples/jit_cpp/sinkhorn/test_sinkhorn.py Outdated
Comment on lines +140 to +142
TASSIGN(hFlat, UbOfs::MAT_HALF);
TASSIGN(fFlat, UbOfs::MAT_FP32);
TCVT(fFlat, hFlat, RoundMode::CAST_NONE);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the computation be all done in FP16 without upcasting to FP32?

Comment on lines +239 to +246
Shape2D<T> outShape(K, K);
DynStride outStride(K);
Global2D<T, MAX_DIM> gOut(gm_out, outShape, outStride);

wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
TSTORE(gOut, outHalf);
Copy link
Copy Markdown
Collaborator

@learning-chip learning-chip Apr 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A single TSTORE only works on [K, K] matrix. In mHC paper n = 4, thus the matrix size is only 4x4, just 32 Bytes. The DMA needs >= 16 KiB to get high bandwidth util. Need to process in batches.

(you can also confirm in tilelang example_mhc_pre.py that hc_mult = 4)

Comment on lines +211 to +213
# Override default grids for sinkhorn: batch=N (matrices), K=dim
batches = args.batches if args.batches else [1, 4, 8, 16, 32, 64]
dims = args.hidden_dims if args.hidden_dims else [4, 8, 16, 32, 64, 128]
Copy link
Copy Markdown
Collaborator

@learning-chip learning-chip Apr 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To emulate mHC setting, the dim is as small as 4, while the batch can be very large. Check the input shape of:

    res_mix = mixes[:, 2 * hc_mult :].view(-1, hc_mult, hc_mult)

    res_mix = sinkhorn_normalize_ref(res_mix, repeat=sinkhorn_repeat, eps=hc_sinkhorn_eps)

from tilelang example

Also check tilelang-gpu's bandwidth util ratio for those shapes, to get a reasonable expectation for NPU perf.

Comment on lines +370 to +371
AICORE void sinkhorn(__gm__ T *in, __gm__ T *out, uint32_t N, uint32_t K,
uint32_t repeat, float eps) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

repeat can be a compile-time template parameter / const, just like for the hadamard kernel

Comment on lines +194 to +201
for (uint32_t it = 1; it < repeat; ++it) {
TASSIGN(v, VC);
TROWSUM(v, m, t);
pipe_barrier(PIPE_V);
TROWEXPANDDIV(m, m, v);
pipe_barrier(PIPE_V);
CN();
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop can be statically unrolled to save some scalar computes

- Introduced DISPATCH_SHAPES to cover various dispatch paths in kernel_sinkhorn.cpp based on batch size (N) and K values.
- Added DISPATCH_CASES for efficient testing of different (batch, K) combinations.
- Expanded DENSE_SHAPES for broader numerical regression coverage.
- Consolidated TEST_CASES to eliminate duplicates from DISPATCH and DENSE shapes.
- Updated test_output_is_doubly_stochastic to validate across representative shapes for each dispatch path.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants