NCU Profile: GqaFwdWsPersistentKernel bottleneck analysis (llama8b-4k, H200) #11

@superAngGao

Overview

NCU full-set profiling of GqaFwdWsPersistentKernel from tile-ai/TileOPs#871, commit beab3ca.

Shape: B=1 S=4096 H=32 Hkv=8 D=128 causal=True fp16 (llama8b-4k)
Hardware: NVIDIA H200 (132 SMs, 228 KB smem/SM, 65536 regs/SM), CUDA 12.8
Kernel duration: ~310 µs → 66.7% FA3

The kernel is not bandwidth-bound on this shape. The bottleneck is tensor core pipeline utilization: the tensor pipe is active for only 57.7% of elapsed cycles (60.9% of SM-active cycles, i.e. about 39% idle), because non-wgmma work (softmax, barriers, smem reads) is too slow to fully overlap with wgmma execution.


1. Launch Configuration

| Parameter | Value |
| --- | --- |
| Grid | (132, 1, 1): 1 persistent CTA per SM |
| Block | (384, 1, 1): 3 warpgroups × 128 threads |
| Registers/thread | 168 |
| Total regs/block | 168 × 384 = 64,512 (of 65,536/SM) |
| Shared memory/block (allocated) | 165.9 KB (of 228 KB configurable) |
| Cluster size | 0 (no cluster) |
| Waves/SM | 1.0 |
| Stack size | 1024 B |

2. Occupancy

| Metric | Value |
| --- | --- |
| Theoretical occupancy | 18.75% (12 warps / 64 max) |
| Achieved occupancy | 18.68% |
| Achieved active warps/SM | 11.96 |
| Block limit (registers) | 1 (bottleneck) |
| Block limit (shared mem) | 1 (bottleneck) |
| Block limit (warps) | 5 |
| Block limit (barriers) | 12 |
| Block limit (SM) | 32 |

Both registers and shared memory independently limit the SM to 1 block. The register file is 98.4% full (64,512 / 65,536).
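The block limits and theoretical occupancy above can be reproduced with a few lines of arithmetic. This is a minimal sketch using only the per-SM limits quoted in this report; it deliberately ignores register and shared-memory allocation granularity, so it is not a substitute for the NCU occupancy calculator. The function names are illustrative, not part of any tool.

```python
# Per-SM limits as quoted in this report for the H200.
REGS_PER_SM = 65536
SMEM_PER_SM_KB = 228.0
MAX_WARPS_PER_SM = 64
WARP_SIZE = 32


def block_limits(threads_per_block, regs_per_thread, smem_per_block_kb):
    """Resident-blocks-per-SM limit imposed by each resource.

    Simplification: allocation granularity and driver-reserved smem
    are ignored.
    """
    warps_per_block = threads_per_block // WARP_SIZE
    return {
        "registers": REGS_PER_SM // (regs_per_thread * threads_per_block),
        "shared_mem": int(SMEM_PER_SM_KB // smem_per_block_kb),
        "warps": MAX_WARPS_PER_SM // warps_per_block,
    }


def theoretical_occupancy(threads_per_block, regs_per_thread, smem_per_block_kb):
    """Resident warps divided by the maximum warps per SM."""
    limits = block_limits(threads_per_block, regs_per_thread, smem_per_block_kb)
    resident_blocks = min(limits.values())
    resident_warps = resident_blocks * (threads_per_block // WARP_SIZE)
    return resident_warps / MAX_WARPS_PER_SM


limits = block_limits(384, 168, 165.9)
occupancy = theoretical_occupancy(384, 168, 165.9)
print(limits)
print(f"theoretical occupancy: {occupancy:.2%}")
```

With the launch configuration from section 1, both the register and shared-memory limits come out at one block per SM, which is why relaxing only one of them would not raise occupancy.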

3. Throughput (Speed of Light)

| Metric | Value |
| --- | --- |
| SM frequency | 1.46 GHz |
| DRAM frequency | 3.20 GHz |
| Compute (SM) throughput | 57.7% of peak |
| Tensor pipe active (of elapsed) | 57.7% |
| Tensor pipe active (of SM active) | 60.9% |
| Memory throughput | 38.8% |
| L1/TEX throughput | 40.9% |
| L2 throughput | 29.5% |
| DRAM throughput | 5.17% |
| SM busy | 60.9% |
| SM active cycles (avg) | 430,587 |
| SM elapsed cycles (avg) | 455,541 |

Key observation

DRAM throughput is only 5.17%: the kernel's working set is almost entirely L2-resident on this shape. Total DRAM traffic is 57.7 MB read + 21.0 MB write = 78.7 MB. The L2 hit rate is 90.7%, consistent with the K/V data (B=1 × S=4096 × Hkv=8 × D=128 × 2 B = 8 MB each for K and V) fitting comfortably in H200's 60 MB L2 cache.

The bottleneck is tensor core utilization (57.7%), not memory bandwidth.
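The footprint and bandwidth figures above follow from the shape and the cycle counts. The sketch below recomputes them, assuming that NCU's "MB" means 10^6 bytes and that kernel duration equals average SM elapsed cycles divided by the reported SM frequency; both are assumptions of this example rather than statements from the profile.

```python
# Shape from the overview (llama8b-4k, fp16).
B, S, H, HKV, D = 1, 4096, 32, 8, 128
BYTES_FP16 = 2

# Tensor footprints in bytes.
kv_bytes = B * S * HKV * D * BYTES_FP16   # one of K or V
q_bytes = B * S * H * D * BYTES_FP16      # Q (O has the same size)
kv_mib = kv_bytes / 2**20
q_mib = q_bytes / 2**20

# Kernel duration derived from section 3 (assumed: cycles / frequency).
SM_ELAPSED_CYCLES = 455_541
SM_CLOCK_HZ = 1.46e9
elapsed_s = SM_ELAPSED_CYCLES / SM_CLOCK_HZ
elapsed_us = elapsed_s * 1e6

# DRAM bandwidth implied by the reported traffic (MB taken as 1e6 bytes).
dram_read_gbs = 57.7e6 / elapsed_s / 1e9
dram_write_gbs = 21.0e6 / elapsed_s / 1e9

print(f"K or V: {kv_mib:.0f} MiB, Q or O: {q_mib:.0f} MiB")
print(f"duration: {elapsed_us:.0f} us")
print(f"DRAM read {dram_read_gbs:.1f} GB/s, write {dram_write_gbs:.1f} GB/s")
```

The derived duration and bandwidths land close to the ~310 µs, 184.8 GB/s and 67.2 GB/s reported elsewhere in this profile, which suggests the tables are mutually consistent.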

4. Warp Stall Breakdown

Measured via smsp__average_warps_issue_stalled_*_per_issue_active.ratio:

| Stall Reason | Ratio | Interpretation |
| --- | --- | --- |
| long_scoreboard | 3.30 | Waiting for long-latency memory ops (smem→reg via L1/TEX, TMA completion) |
| wait (DEPBAR) | 1.59 | Waiting for wgmma to drain (wait_wgmma(0) / wait_wgmma(1)) |
| barrier (mbarrier) | 1.13 | Waiting on mbarrier arrive/wait (producer↔consumer sync, ping-pong scheduler) |
| selected | 1.00 | Warp was selected and issued (not a stall) |
| no_instruction | 0.42 | I-cache miss or warp not yet ready |
| short_scoreboard | 0.34 | Waiting for short-latency ops (register deps) |
| dispatch_stall | 0.27 | Dispatch unit stall |
| not_selected | 0.26 | Warp eligible but not selected |
| branch_resolving | 0.07 | Branch resolution |
| mio_throttle | 0.05 | Memory I/O throttle |
| math_pipe_throttle | 0.02 | Tensor core back-pressure (negligible) |
| gmma | 0.00 | GMMA-specific stall |
| membar | 0.00 | Memory barrier |
| sleeping | 0.00 | |
| lg_throttle | 0.00 | |
| tex_throttle | 0.00 | |
| drain | 0.00 | |
| imc_miss | 0.00 | |
| misc | 0.00 | |

Stall analysis

long_scoreboard (3.30) is the dominant stall. In this kernel, the long-latency operations are:

  • Shared memory reads via L1/TEX: consumer warpgroups reading k_smem_* / v_smem_* / q_shared_* to feed wgmma. The swizzled smem layout goes through L1TEX with a 72.2% hit rate.
  • TMA async copy completion: the producer waiting for TMA descriptor-based copies to land in smem.

This is NOT HBM latency (DRAM is only 5% utilized and the L2 hit rate is 90.7%). The stall is on the smem → register file path through L1TEX.

wait / DEPBAR (1.59) is the second-largest stall. Consumer warps stall on wait_wgmma(0) or wait_wgmma(1) (WARPGROUP.DEPBAR.LE). With IntraWGOverlap (wait_group<1>), the current softmax/rescale overlaps with the previous wgmma, but the 1.59 ratio indicates the overlap is incomplete: within a warpgroup, the non-wgmma work finishes before the in-flight wgmma does, so the consumer must idle-wait for it to drain.

barrier (1.13) is the third-largest stall. The kernel uses 8 mbarriers:

  • k_full, k_empty — K pipeline (producer → consumer data ready / consumer → producer buffer free)
  • v_full, v_empty — V pipeline
  • wg_sched_12, wg_sched_21 — WG1↔WG2 ping-pong scheduler
  • q_full_1, q_full_2 — per-sub-tile Q TMA loads

Each K-loop iteration triggers multiple arrive/wait pairs. The PR body notes ~2% overhead from using the mbarrier-based scheduler instead of PTX bar.arrive named barriers (which need an upstream TileLang patch).

5. Instruction Statistics

| Metric | Value |
| --- | --- |
| Total instructions executed | ~80.3M |
| Executed IPC (active) | 1.41 |
| Executed IPC (elapsed) | 1.34 |
| Issue slots busy | 35.3% |
| Issued IPC (active) | 1.41 |
| Warp cycles per issued instruction | 8.46 |
| Avg active threads per warp | 31.45 |
| Avg not predicated-off threads per warp | 31.06 |
| Branch instructions | ~2.28M |
| Branch efficiency | 98.24% |
| Avg divergent branches | 87.5 |
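The stall ratios in section 4 and the issue statistics above can be cross-checked against each other. The sketch below sums the per-reason ratios, compares the total with "warp cycles per issued instruction", and derives the issue rate from the active warp count. It assumes four warp schedulers (SMSPs) per SM; the variable names are illustrative.

```python
# Stall ratios from section 4 (zero-valued reasons omitted).
stall_ratios = {
    "long_scoreboard": 3.30, "wait": 1.59, "barrier": 1.13, "selected": 1.00,
    "no_instruction": 0.42, "short_scoreboard": 0.34, "dispatch_stall": 0.27,
    "not_selected": 0.26, "branch_resolving": 0.07, "mio_throttle": 0.05,
    "math_pipe_throttle": 0.02,
}

SMSP_PER_SM = 4               # assumption: four warp schedulers per SM
ACTIVE_WARPS_PER_SM = 11.96   # from section 2
WARP_CYCLES_PER_ISSUE = 8.46  # from section 5

# The ratios should add up to roughly the warp cycles per issued instruction.
total_ratio = sum(stall_ratios.values())

# Fraction of warp cycles attributed to each reason.
shares = {name: ratio / total_ratio for name, ratio in stall_ratios.items()}
top_three_share = sum(shares[n] for n in ("long_scoreboard", "wait", "barrier"))

# Each scheduler holds ~3 warps, each issuing once every ~8.46 cycles.
warps_per_scheduler = ACTIVE_WARPS_PER_SM / SMSP_PER_SM
issue_rate = warps_per_scheduler / WARP_CYCLES_PER_ISSUE  # per scheduler
ipc = issue_rate * SMSP_PER_SM                            # per SM

print(f"sum of ratios: {total_ratio:.2f}")
print(f"top three stalls: {top_three_share:.1%} of warp cycles")
print(f"issue slots busy: {issue_rate:.1%}, IPC: {ipc:.2f}")
```

The derived issue rate and IPC agree with the 35.3% and 1.41 in the table, and the three largest stall reasons together account for about 71% of warp cycles.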

6. Memory Workload

| Metric | Value |
| --- | --- |
| Memory throughput | 253.9 GB/s |
| Mem busy | 38.8% |
| Max bandwidth | 38.1% |
| L1/TEX hit rate | 72.2% |
| L2 hit rate | 90.7% |
| L2 compression | 0% (disabled) |
| Mem pipes busy | 31.2% |
| DRAM read | 57.7 MB |
| DRAM write | 21.0 MB |
| DRAM read BW | 184.8 GB/s |
| DRAM write BW | 67.2 GB/s |

Shared memory (GMMA) instructions

| Metric | Value |
| --- | --- |
| GMMA instructions executed (smsp) | 2,162,688 |
| GMMA instructions (% of peak elapsed) | 3.6% |

7. Tensor Core Detail

| Metric | Value |
| --- | --- |
| sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed | 57.68% |
| sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_active | 60.92% |
| sm__pipe_tensor_cycles_active.min.pct_of_peak_sustained_elapsed | 44.61% |
| sm__pipe_tensor_cycles_active.max.pct_of_peak_sustained_elapsed | 59.48% |
| sm__inst_executed_pipe_tensor_op_gmma.avg.pct_of_peak_sustained_active | 3.81% |
| sm__inst_executed_pipe_tensor_op_hmma.avg.pct_of_peak_sustained_active | 1.90% |
| sm__pipe_tensor_cycles_active_realtime.avg.pct_of_peak_sustained_elapsed | 27.42% |
| Achieved FLOPS (hgmma fp16, no sparsity) | 57.68% of peak |

The min/max spread (44.6%–59.5%) across SMs reflects causal tile pairing imbalance — even with (k, M-1-k) pairing, the two sub-tiles within a pair have different K-loop lengths, so some CTAs finish earlier than others.
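As a sanity check on the 57.68% figure, the tensor utilization can be estimated from first principles. The sketch below counts the 128×128 score blocks a causal kernel has to visit and divides the resulting FLOPs by the peak available in the elapsed cycles. It assumes 4096 dense fp16 FLOP per clock per SM on Hopper and that every visited block, including those on the diagonal, is computed in full; neither assumption comes from the profile itself.

```python
B, S, H, D = 1, 4096, 32, 128
BLOCK_M = BLOCK_N = 128
NUM_SMS = 132
SM_ELAPSED_CYCLES = 455_541   # from section 3
FLOP_PER_CLK_PER_SM = 4096    # assumption: dense fp16 tensor peak per SM

# Under a causal mask, Q tile i visits KV blocks 0..i.
m_tiles = S // BLOCK_M
kv_blocks_per_head = sum(i + 1 for i in range(m_tiles))

# Each block costs two GEMMs (QK^T and PV), 2 FLOP per multiply-add.
flop_per_block = 4 * BLOCK_M * BLOCK_N * D
total_flop = B * H * kv_blocks_per_head * flop_per_block

# Peak FLOP available over the kernel's elapsed cycles (clock cancels out).
peak_flop = NUM_SMS * FLOP_PER_CLK_PER_SM * SM_ELAPSED_CYCLES
utilization = total_flop / peak_flop

print(f"KV blocks per head: {kv_blocks_per_head}")
print(f"total work: {total_flop / 1e9:.1f} GFLOP")
print(f"estimated tensor utilization: {utilization:.1%}")
```

Under these assumptions the estimate comes out within about half a percentage point of the measured value, which suggests that tensor-pipe active time is spent almost entirely on the attention GEMMs themselves.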

8. Workload Distribution

| Metric | Value |
| --- | --- |
| Avg DRAM active cycles | 51,325 |
| Total DRAM elapsed cycles | 47.7M |
| Avg SM active cycles | 430,586 |
| Total SM elapsed cycles | 59.9M |
| Avg L2 active cycles | 509,492 |
| Total L2 elapsed cycles | 48.9M |
| Avg SMSP active cycles | 430,099 |
| Total SMSP elapsed cycles | 239.4M |
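The causal (k, M-1-k) pairing mentioned in section 7 can be illustrated with a small model. This is a sketch under several assumptions that the profile does not confirm: M = S / block_m Q tiles per head, one work item per pair, and pairs dealt to the 132 persistent CTAs as evenly as possible. The helper `k_loop_len` is hypothetical.

```python
B, S, H = 1, 4096, 32
BLOCK_M = BLOCK_N = 128
NUM_CTAS = 132

M = S // BLOCK_M  # Q tiles per head


def k_loop_len(m_tile):
    """KV blocks visited by Q tile m_tile under a causal mask (block_m == block_n)."""
    return m_tile + 1


# Pair a short tile with a long one: (0, M-1), (1, M-2), ...
pairs = [(k, M - 1 - k) for k in range(M // 2)]
pair_lengths = [(k_loop_len(a), k_loop_len(b)) for a, b in pairs]
pair_totals = {a + b for a, b in pair_lengths}

# Distribute all pairs over the persistent CTAs as evenly as possible.
total_pairs = B * H * len(pairs)
base, extra = divmod(total_pairs, NUM_CTAS)
pairs_per_cta = [base + (1 if cta < extra else 0) for cta in range(NUM_CTAS)]

print(f"first pair K-loop lengths: {pair_lengths[0]}, last: {pair_lengths[-1]}")
print(f"distinct per-pair totals: {pair_totals}")
print(f"pairs per CTA: min {min(pairs_per_cta)}, max {max(pairs_per_cta)}")
```

In this model every pair totals M + 1 = 33 KV blocks, but its two members range from (1, 32) to (16, 17), and 512 pairs over 132 CTAs leave each CTA with either 3 or 4 pairs. The 3:4 ratio is close to the observed min/max ratio of tensor-pipe activity (44.61% / 59.48%), though this sketch does not establish that as the cause.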

9. Shared Memory Budget Breakdown

Current smem allocation (block_m=128, block_n=128, dim=128, fp16):

q_shared_1:  half_m × dim × 2B  =  64 × 128 × 2  =  16 KB
q_shared_2:  half_m × dim × 2B  =  64 × 128 × 2  =  16 KB
k_smem_0:    block_n × dim × 2B = 128 × 128 × 2  =  32 KB
k_smem_1:    block_n × dim × 2B = 128 × 128 × 2  =  32 KB
v_smem_0:    block_n × dim × 2B = 128 × 128 × 2  =  32 KB
v_smem_1:    block_n × dim × 2B = 128 × 128 × 2  =  32 KB
barriers (8) + driver smem + padding            ≈   6 KB
─────────────────────────────────────────────────────────
Total:                                            ~166 KB
Allocated (NCU reported):                         165.9 KB
Remaining (of 228 KB max):                        ~62 KB

A 3-stage buffer would add k_smem_2 + v_smem_2 = +64 KB, for ~230 KB in total, which exceeds the 228 KB limit.

Even if it did fit, adding pipeline depth only helps when the bottleneck is producer-side TMA fill latency. On this shape, DRAM is 5% utilized and L2 hit rate is 91% — the producer is not the bottleneck.
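The budget above can be parameterized to test other configurations quickly. The function below is a minimal sketch that mirrors the breakdown in this section; the ~6 KB overhead term is the approximate figure from the table, and the function name and parameters are illustrative.

```python
SMEM_LIMIT_KB = 228


def smem_budget_kb(block_m=128, block_n=128, dim=128, stages=2,
                   elem_bytes=2, overhead_kb=6):
    """Approximate smem per block: two Q sub-tiles plus K and V stage buffers.

    overhead_kb stands in for barriers, driver smem and padding (approximate).
    """
    half_m = block_m // 2
    q_kb = 2 * half_m * dim * elem_bytes / 1024   # q_shared_1 + q_shared_2
    kv_tile_kb = block_n * dim * elem_bytes / 1024  # one K or V stage
    return q_kb + 2 * stages * kv_tile_kb + overhead_kb


two_stage = smem_budget_kb(stages=2)
three_stage = smem_budget_kb(stages=3)
wide_tile = smem_budget_kb(block_n=176, stages=2)

for name, kb in [("2-stage", two_stage), ("3-stage", three_stage),
                 ("2-stage, block_n=176", wide_tile)]:
    verdict = "fits" if kb <= SMEM_LIMIT_KB else "does not fit"
    print(f"{name}: ~{kb:.0f} KB ({verdict})")
```

By the same arithmetic, a 2-stage configuration with block_n=176 would need roughly 214 KB, so it would still fit under the 228 KB limit, assuming the overhead term stays about the same.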

10. Why 3-Stage Buffering Cannot Help

  1. Doesn't fit: 166 KB + 64 KB = 230 KB > 228 KB smem limit per SM.
  2. Wrong bottleneck: The producer (TMA loads) is not the bottleneck. DRAM throughput is 5.17%, and L2 hit rate is 90.7%. More prefetch depth doesn't help when data is already L2-resident.
  3. Consumer-side stall: The top stalls (long_scoreboard 3.30, wait 1.59, barrier 1.13) are all on the consumer side — reading smem, waiting for wgmma, and synchronizing between warpgroups. More pipeline stages don't address these.

11. Bottleneck Model

Tensor core 60.9% active (of SM active) → 39.1% idle
  ├── [3.30] smem→register read latency (long_scoreboard)
  │     └── Consumer warps stall reading k_smem/v_smem through L1TEX
  │         L1 hit rate 72.2%, throughput 40.9% — swizzled layout contention?
  ├── [1.59] wgmma pipeline drain (wait / DEPBAR)
  │     └── Non-wgmma work (softmax, rescale, copy) finishes before wgmma
  │         → consumer idle-waits on DEPBAR
  │         → IntraWGOverlap (wait_group<1>) helps but doesn't fully hide
  └── [1.13] mbarrier synchronization (barrier)
        └── 8 mbarriers × multiple arrive/wait per K-loop iteration
            ~2% overhead from mbarrier scheduler vs named barriers (PR #872)

The two consumers ping-pong on the tensor core port. Utilization would ideally reach 100% if each consumer's non-wgmma work (softmax + rescale + barriers + copies) fit entirely within the wgmma time of the other consumer. Individual warps still spend time waiting on DEPBAR (section 4), but the 60.9% utilization means the tensor pipe sits idle for 39.1% of active cycles: summed over the K-loop, the non-wgmma chain, including synchronization and smem waits, takes longer than the wgmma work it is meant to hide behind, so the next wgmma is issued late.
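A toy steady-state model makes the ping-pong argument concrete. It assumes two identical consumers that each alternate `t_wgmma` time on a shared tensor pipe with `t_other` time off it, where `t_other` lumps together softmax, rescale, barriers and smem waits. It ignores the producer, intra-warpgroup overlap and causal imbalance, so the numbers are indicative only.

```python
def pingpong_tensor_utilization(t_wgmma, t_other):
    """Steady-state tensor-pipe utilization for two consumers sharing one pipe.

    Each consumer repeats: t_wgmma on the pipe, then t_other off the pipe.
    The pipe serializes the two consumers' wgmma phases, so one round takes
    at least 2 * t_wgmma, and at least t_wgmma + t_other per consumer.
    """
    period = max(t_wgmma + t_other, 2 * t_wgmma)
    return 2 * t_wgmma / period


def implied_other_to_wgmma_ratio(utilization):
    """Invert the model for t_other / t_wgmma (valid when utilization < 1)."""
    return 2 / utilization - 1


balanced = pingpong_tensor_utilization(1.0, 1.0)
short_other = pingpong_tensor_utilization(1.0, 0.5)
ratio = implied_other_to_wgmma_ratio(0.609)

print(f"t_other == t_wgmma:       {balanced:.1%}")
print(f"t_other == 0.5 * t_wgmma: {short_other:.1%}")
print(f"60.9% utilization implies t_other ~ {ratio:.2f} x t_wgmma")
```

In this toy model utilization only drops below 100% once `t_other` exceeds `t_wgmma`, and the measured 60.9% corresponds to off-pipe time of roughly 2.3 times the wgmma time per consumer.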

12. Optimization Directions

Based on this profile, ranked by expected impact:

| Direction | Mechanism | Target stall | Expected impact |
| --- | --- | --- | --- |
| TMA multicast | Broadcast K/V via cluster to multiple Q-head CTAs | long_scoreboard (reduces L1/L2 pressure) | High (for high group-ratio shapes) |
| Thread block cluster | Prerequisite for TMA multicast; also enables distributed smem | long_scoreboard, barrier | High |
| Named barriers | Replace mbarrier scheduler with PTX bar.arrive | barrier (1.13) | ~2-5% |
| Reduce softmax path length | Fewer instructions in non-wgmma chain → less idle time | wait (1.59) | Medium |
| block_n=176 | Larger wgmma tile → higher arithmetic intensity → wgmma takes longer → better overlap with non-wgmma work | wait (1.59) | Medium (needs TileLang wgmma_macro fix) |
| Split-K | Parallelize K loop across CTAs for long sequences | N/A for this shape, helps 32k/128k | Medium (long-seq only) |
| Tile ordering / L2 locality | KV-head-major persistent iteration order | long_scoreboard (L2 miss reduction) | Low-medium |
| Register pressure reduction | Reduce from 168 → ~160 regs/thread | Unlocks headroom for larger blocks or 3-stage | Low (indirect) |

Not useful for this shape:

  • 3-stage buffering (doesn't fit; wrong bottleneck)
  • Increasing DRAM bandwidth utilization (already L2-resident)
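To illustrate the "Tile ordering / L2 locality" direction, the sketch below compares a KV-head-major iteration order with a hypothetical tile-major baseline and counts how many distinct KV heads the first 132 work items touch. The kernel's current iteration order is not stated in this profile, so both orderings and the one-item-per-CTA wave model are assumptions made for illustration.

```python
from itertools import islice

H, HKV, M, NUM_CTAS = 32, 8, 32, 132
GROUP = H // HKV  # Q heads sharing one KV head


def kv_head_major():
    """All Q heads and tiles of one KV head before moving to the next."""
    for kv in range(HKV):
        for g in range(GROUP):
            for m in range(M):
                yield (kv * GROUP + g, m)


def tile_major():
    """Hypothetical baseline: sweep every Q head for each tile index."""
    for m in range(M):
        for h in range(H):
            yield (h, m)


def kv_heads_in_first_wave(order):
    """Distinct KV heads touched by the first NUM_CTAS work items."""
    return {h // GROUP for h, _ in islice(order, NUM_CTAS)}


major_heads = kv_heads_in_first_wave(kv_head_major())
baseline_heads = kv_heads_in_first_wave(tile_major())

print(f"KV-head-major first wave touches {len(major_heads)} KV heads")
print(f"tile-major first wave touches {len(baseline_heads)} KV heads")
```

Fewer distinct KV heads in flight means concurrently running CTAs reuse the same K/V tiles, which is the mechanism by which this direction would reduce L2 misses. How much it helps on this shape, where the L2 hit rate is already 90.7%, is not something the sketch can answer.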

Reproduction

# Profiling script at: ncu_profile_gqa.py in the PR branch
cd /tmp/tileops-pr871
TMPDIR=/home/ga/tmp TILELANG_CLEANUP_TEMP_FILES=1 CUDA_VISIBLE_DEVICES=7 \
  ncu --set full --target-processes all -f \
  -o /home/ga/tmp/gqa_ws_ncu \
  python ncu_profile_gqa.py

# Summary
TMPDIR=/home/ga/tmp ncu -i /home/ga/tmp/gqa_ws_ncu.ncu-rep --print-summary per-kernel

NCU report file: /home/ga/tmp/gqa_ws_ncu.ncu-rep


Profiled on 2026-04-10, branch feat/flash-attn/sm90-gqa-ws-persistent @ beab3ca, H200 GPU 7.
