From f44cdc3a4d7a81be82fda0b707213633cc228a60 Mon Sep 17 00:00:00 2001 From: Kevin Boyd Date: Tue, 2 Jun 2026 14:19:28 -0400 Subject: [PATCH] Fix grid dimension overflow in fused Butina neighbor count kernel --- nvmolkit/_fusedButina.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/nvmolkit/_fusedButina.py b/nvmolkit/_fusedButina.py index 793b5a29..4199573c 100644 --- a/nvmolkit/_fusedButina.py +++ b/nvmolkit/_fusedButina.py @@ -17,6 +17,10 @@ import triton import triton.language as tl +# CUDA caps gridDim.y and gridDim.z at 65535, while gridDim.x allows up to 2**31-1. +# Tiles are linearized across a 2D launch grid with the bounded axis kept here. +MAX_GRID_DIM = 65535 + def get_cuda_autotune_config(): return [ @@ -121,9 +125,11 @@ def _update_neighbor_count_kernel( Atomically adds (SUBTRACT=False) or subtracts (SUBTRACT=True) the per-row neighbor counts into ``neighbors_ptr``. """ - pid = tl.program_id(axis=0) + pid = tl.program_id(axis=0).to(tl.int64) + tl.program_id(axis=1).to(tl.int64) * tl.num_programs(axis=0).to(tl.int64) num_pid_m = tl.cdiv(M, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) + if pid >= num_pid_m.to(tl.int64) * num_pid_n.to(tl.int64): + return num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M @@ -271,7 +277,12 @@ def update_neighbor_counts( M = x.shape[0] N = y.shape[0] K = x.shape[1] - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + def grid(meta): + num_tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]) + grid_y = min(num_tiles, MAX_GRID_DIM) + grid_x = triton.cdiv(num_tiles, grid_y) + return (grid_x, grid_y) + _update_neighbor_count_kernel[grid]( x, y,