
[TLE] Add tle.cumsum & Optimize TopK-Selector #489

Draft
sunnycase wants to merge 15 commits into triton_v3.6.x from feature/tle_topk_3.6

Conversation


@sunnycase sunnycase commented Apr 2, 2026

Summary

This PR builds on PR #395 (initial distributed APIs) and completes the 3.6 migration for TLE TopK/remote-pointer paths with root-cause fixes in lowering and backend codegen.

What’s Added

1. New TLE primitive: tle.cumsum(...)

  • Added frontend API for exclusive cumsum + total sum return:
    • exclusive, total = tle.cumsum(input, axis=..., reverse=..., dtype=...)
  • Added new IR op:
    • tle.exclusive_cumsum
  • Added dedicated lowering + layout optimization pipeline for cumsum.
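
The exclusive-scan-plus-total semantics described above can be modeled outside the kernel with plain NumPy. This is an illustrative sketch of what the `tle.cumsum` return values mean, not the TLE implementation (which runs inside a Triton kernel); the helper name is hypothetical:

```python
import numpy as np

def exclusive_cumsum_with_total(x, axis=-1, reverse=False):
    """Illustrative model of the described tle.cumsum semantics:
    the exclusive prefix sum along `axis`, plus the total sum."""
    x = np.flip(x, axis=axis) if reverse else x
    inclusive = np.cumsum(x, axis=axis)
    # Exclusive scan = inclusive scan shifted right by one (seeded with 0),
    # which is the same as subtracting each element from its inclusive sum.
    exclusive = inclusive - x
    total = np.take(inclusive, -1, axis=axis)
    if reverse:
        exclusive = np.flip(exclusive, axis=axis)
    return exclusive, total

ex, tot = exclusive_cumsum_with_total(np.array([3, 1, 4, 1, 5]))
# ex == [0, 3, 4, 8, 9], tot == 14
```

With `reverse=True` the same input yields `[11, 10, 6, 5, 0]`: each position holds the sum of the elements after it.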

2. Extended pointer primitives

  • tle.gpu.local_ptr(buffer, indices=None) now supports full-view pointer materialization (no explicit indices needed).
  • tle.remote(...) now supports direct pointer tensors (tl.tensor pointers), not only buffered tensors.
  • Remote pointer semantics were migrated to cluster-shared addrspace=7 for consistent downstream lowering/codegen.

3. Compiler/lowering pipeline updates

  • Added/renamed TLE passes:
    • triton-tle-select-encodings
    • triton-tle-insert-local-pointer-barriers
    • triton-tle-optimize-local-pointer-loads
    • triton-tle-optimize-local-pointer-stores
    • triton-tle-optimize-exclusive-cumsum-layouts
    • triton-tle-lower-exclusive-cumsum
  • NVIDIA backend load/store/atomic lowering updated to treat shared-family address spaces consistently (3 and 7).
  • Remote pointer lowering now uses dedicated remote-pointer conversion flow instead of metadata-carrier fallback paths.

Performance

Environment

  • GPU: NVIDIA H800
  • Date: 2026-04-02
  • All numbers below are from local benchmark runs on this branch.

A) TLE TopK Selector (batch = 64)

Command:

```shell
conda run -n flagtree python python/tutorials/tle/deepseek_v32/01-topk_selector.py \
  --skip_correctness \
  --providers tle-trt,tle-trt-1024threads,tle-cluster,triton \
  --bench_x_vals 64x4096x128,64x8192x256,64x32768x1024,64x32768x2048 \
  --warmup 10 --rep 20
```
| batch | seq_len | topk | TLE-TRT (ms) | TLE-TRT-1024T (ms) | TLE-Cluster (ms) | Triton (ms) | Best TLE (ms) | Speedup vs Triton (Triton / Best TLE) |
|---|---|---|---|---|---|---|---|---|
| 64 | 4096 | 128 | 0.009920 | 0.009504 | 0.063488 | 0.049696 | 0.009504 | 5.23x |
| 64 | 8192 | 256 | 0.012384 | 0.011904 | 0.067072 | 0.064704 | 0.011904 | 5.44x |
| 64 | 32768 | 1024 | 0.028480 | 0.024448 | 0.090240 | 0.151968 | 0.024448 | 6.22x |
| 64 | 32768 | 2048 | 0.029632 | 0.025280 | 0.093120 | 0.150784 | 0.025280 | 5.96x |
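
As the column header states, the speedup is simply the Triton time divided by the best TLE time. A quick sanity check over the rows above (values copied from the table):

```python
# Reproduce the "Speedup vs Triton" column for the batch=64 table
# from its raw Triton and Best-TLE timings (in ms).
triton = [0.049696, 0.064704, 0.151968, 0.150784]
best_tle = [0.009504, 0.011904, 0.024448, 0.025280]
speedups = [round(t / b, 2) for t, b in zip(triton, best_tle)]
# speedups == [5.23, 5.44, 6.22, 5.96]
```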

B) TLE TopK Selector (batch = 1)

Command:

```shell
conda run -n flagtree python python/tutorials/tle/deepseek_v32/01-topk_selector.py \
  --skip_correctness \
  --providers tle-trt,tle-trt-1024threads,tle-cluster,triton \
  --bench_x_vals 1x131072x2048,1x262144x2048,1x524288x2048 \
  --warmup 10 --rep 20
```
| batch | seq_len | topk | TLE-TRT (ms) | TLE-TRT-1024T (ms) | TLE-Cluster (ms) | Triton (ms) | Best TLE (ms) | Speedup vs Triton (Triton / Best TLE) |
|---|---|---|---|---|---|---|---|---|
| 1 | 131072 | 2048 | 0.075456 | 0.053152 | 0.029664 | 0.512480 | 0.029664 | 17.28x |
| 1 | 262144 | 2048 | 0.134784 | 0.090496 | 0.038336 | 1.007152 | 0.038336 | 26.27x |
| 1 | 524288 | 2048 | 0.251584 | 0.166000 | 0.055424 | 1.925536 | 0.055424 | 34.74x |

C) TopK Kernel Microbenchmark (radix/streaming/torch)

Command:

```shell
conda run -n flagtree python python/tutorials/tle/03-topk.py --skip_correctness
```
| M | N | K | Triton-RadixSelect (ms) | Triton-TopK (ms) | Torch-TopK (ms) | Best Triton (ms) | Speedup vs Torch (Torch / Best Triton) |
|---|---|---|---|---|---|---|---|
| 64 | 128 | 8 | 0.007616 | 0.006304 | 0.009856 | 0.006304 | 1.56x |
| 64 | 1024 | 32 | 0.009536 | 0.008192 | 0.014864 | 0.008192 | 1.81x |
| 64 | 8192 | 128 | 0.025808 | 0.034304 | 0.052144 | 0.025808 | 2.02x |
| 128 | 32768 | 256 | 0.080768 | 0.152544 | 0.090016 | 0.080768 | 1.11x |

Validation

  • Rebuilt backend + Python extension (./build-nvidia.sh) after rebase to ensure updated TLE passes/lowering were active.
  • Benchmarks above were rerun on rebuilt artifacts from this branch.

