[Feat] Implement doubly-stochastic Sinkhorn normalization kernel #134
Mocchibird wants to merge 5 commits into
…associated benchmarks, tests, and documentation
| TASSIGN(hFlat, UbOfs::MAT_HALF); | ||
| TASSIGN(fFlat, UbOfs::MAT_FP32); | ||
| TCVT(fFlat, hFlat, RoundMode::CAST_NONE); |
Can the whole computation be done in FP16, without upcasting to FP32?
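One way to probe this question off-device is to run the same alternating row/column normalization in both precisions on the host and compare the results. The sketch below uses NumPy; the function name `sinkhorn_np`, the `eps` value, the iteration count, and the random inputs are all assumptions for illustration, and it says nothing about how the NPU vector units actually round.

```python
import numpy as np

def sinkhorn_np(x, repeat=20, eps=1e-6):
    """Alternate row and column normalization, staying in x's own dtype."""
    eps = x.dtype.type(eps)  # keep eps in the same precision as the data
    for _ in range(repeat):
        x = x / (x.sum(axis=-1, keepdims=True) + eps)  # normalize rows
        x = x / (x.sum(axis=-2, keepdims=True) + eps)  # normalize columns
    return x

# Hypothetical inputs: 8 strictly positive 4x4 matrices.
rng = np.random.default_rng(0)
base = rng.uniform(0.1, 1.0, size=(8, 4, 4)).astype(np.float32)

out32 = sinkhorn_np(base)                      # FP32 path
out16 = sinkhorn_np(base.astype(np.float16))   # pure FP16 path

# Largest elementwise gap between the two precisions.
max_abs_diff = np.abs(out32 - out16.astype(np.float32)).max()
print(out16.dtype, max_abs_diff)
```

If the gap stays within the test tolerance for the shapes that matter (K = 4 in particular), that would be some evidence that the FP32 round trip can be dropped; it would still need confirming on the device.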
| Shape2D<T> outShape(K, K); | ||
| DynStride outStride(K); | ||
| Global2D<T, MAX_DIM> gOut(gm_out, outShape, outStride); | ||
|
|
||
| wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); | ||
| set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); | ||
| wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); | ||
| TSTORE(gOut, outHalf); |
A single TSTORE here writes only one [K, K] matrix. In the mHC paper n = 4, so the matrix is only 4x4, which is just 32 bytes in FP16. The DMA needs transfers of >= 16 KiB to reach high bandwidth utilization, so the kernel needs to process matrices in batches.
(You can also confirm in the tilelang example_mhc_pre.py that hc_mult = 4.)
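To make the size argument concrete, the arithmetic can be sketched as follows. The helper name `matrices_per_transfer` is hypothetical, and the 16 KiB default is simply the threshold quoted in the comment above, not a measured hardware figure.

```python
def matrices_per_transfer(k, elem_bytes, min_transfer_bytes=16 * 1024):
    """Smallest number of K x K matrices whose total size reaches min_transfer_bytes."""
    matrix_bytes = k * k * elem_bytes
    return -(-min_transfer_bytes // matrix_bytes)  # ceiling division

# 4x4 FP16: 16 elements * 2 bytes = 32 bytes per matrix.
print(matrices_per_transfer(4, 2))  # 512
```

Under these assumptions, roughly 512 of the 4x4 FP16 matrices would have to be grouped into one store before the transfer reaches the quoted threshold.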
| # Override default grids for sinkhorn: batch=N (matrices), K=dim | ||
| batches = args.batches if args.batches else [1, 4, 8, 16, 32, 64] | ||
| dims = args.hidden_dims if args.hidden_dims else [4, 8, 16, 32, 64, 128] |
To emulate the mHC setting, the dim is as small as 4, while the batch can be very large. Check the input shape of:
res_mix = mixes[:, 2 * hc_mult :].view(-1, hc_mult, hc_mult)
res_mix = sinkhorn_normalize_ref(res_mix, repeat=sinkhorn_repeat, eps=hc_sinkhorn_eps)
from the tilelang example.
Also check tilelang-gpu's bandwidth utilization ratio for those shapes, to get a reasonable expectation for NPU perf.
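The shape implied by that slicing can be checked with a small NumPy sketch. The column count of `mixes` (2 * hc_mult + hc_mult * hc_mult) is inferred from the slice and the view, and `num_tokens = 8192` is a made-up value; neither is taken from the tilelang example itself.

```python
import numpy as np

hc_mult = 4
num_tokens = 8192  # hypothetical; the real value depends on the workload

# Assumed layout: 2 * hc_mult leading columns, then one flattened
# hc_mult x hc_mult matrix per row.
mixes = np.zeros((num_tokens, 2 * hc_mult + hc_mult * hc_mult), dtype=np.float16)

# NumPy equivalent of mixes[:, 2 * hc_mult :].view(-1, hc_mult, hc_mult)
res_mix = mixes[:, 2 * hc_mult:].reshape(-1, hc_mult, hc_mult)

print(res_mix.shape)   # (8192, 4, 4)
print(res_mix.nbytes)  # 262144
```

So a benchmark grid meant to resemble this setting would pair K = 4 with batch sizes in the thousands, rather than topping out at 64.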
…batch handling and visualization
| AICORE void sinkhorn(__gm__ T *in, __gm__ T *out, uint32_t N, uint32_t K, | ||
| uint32_t repeat, float eps) { |
repeat can be a compile-time template parameter or constant, just as in the hadamard kernel.
| for (uint32_t it = 1; it < repeat; ++it) { | ||
| TASSIGN(v, VC); | ||
| TROWSUM(v, m, t); | ||
| pipe_barrier(PIPE_V); | ||
| TROWEXPANDDIV(m, m, v); | ||
| pipe_barrier(PIPE_V); | ||
| CN(); | ||
| } |
This loop can be statically unrolled to save some scalar computation.
- Introduced DISPATCH_SHAPES to cover various dispatch paths in kernel_sinkhorn.cpp based on batch size (N) and K values.
- Added DISPATCH_CASES for efficient testing of different (batch, K) combinations.
- Expanded DENSE_SHAPES for broader numerical regression coverage.
- Consolidated TEST_CASES to eliminate duplicates from DISPATCH and DENSE shapes.
- Updated test_output_is_doubly_stochastic to validate across representative shapes for each dispatch path.
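The consolidation and the doubly-stochastic check described in this commit could look roughly like the sketch below. The list names come from the commit message, but the (batch, K) values, the helper `is_doubly_stochastic`, and the tolerance are hypothetical; the PR's actual test file may be organized differently.

```python
import numpy as np

# Hypothetical (batch, K) values; the real lists live in the PR's test file.
DISPATCH_CASES = [(1, 4), (64, 4), (1, 128)]
DENSE_SHAPES = [(1, 4), (8, 16), (64, 4)]

# dict.fromkeys keeps first-seen order while dropping duplicates.
TEST_CASES = list(dict.fromkeys(DISPATCH_CASES + DENSE_SHAPES))

def is_doubly_stochastic(x, atol=1e-3):
    """True if every row and every column of each matrix sums to 1 within atol."""
    rows_ok = np.allclose(x.sum(axis=-1), 1.0, atol=atol)
    cols_ok = np.allclose(x.sum(axis=-2), 1.0, atol=atol)
    return rows_ok and cols_ok

print(TEST_CASES)
```

A parametrized test could then loop over TEST_CASES, run the kernel on a random (batch, K, K) input, and assert `is_doubly_stochastic` on the output.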
Sinkhorn pto-isa vs torch