-
Notifications
You must be signed in to change notification settings - Fork 24
WIP: Add fused Butina clustering with Triton similarity kernels #125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
moradza
wants to merge
8
commits into
NVIDIA-Digital-Bio:main
Choose a base branch
from
moradza:amoradzadeh/butina_clustering
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
01e7fd1
WIP: Add fused Butina clustering with Triton similarity kernels
moradza d1d8314
Support variable-length fingerprints in Triton similarity kernels
moradza 146dfdd
Refactor Triton similarity kernels: unify add/subtract, add cosine me…
moradza c43b2e0
rename butina files
moradza 0c8a033
Add fused_butina test suite and wire up CUDA stream support
moradza 4a5e4e3
Fix device mismatch in fused_butina for multi-GPU setups
moradza 7149a85
Add noqa for wildcard import and remove stale TODO
moradza 58a0df6
Move fused_butina benchmark from clustering.py to benchmarks/
moradza File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| .cursor | ||
| .vscode | ||
| /cmake-build-* | ||
| .idea | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,346 @@ | ||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
# Tile shape (rows x columns) processed by one program of the 2-D
# similarity kernels below; also used by the hosts to size launch grids.
TILE_X = 32
TILE_Y = 32
# TODO: L2 Cache Optimizations
@triton.jit
def _popcount32(x):
    # SWAR bit count fallback for Triton builds without tl.popcount.
    # Classic branch-free parallel popcount (see "Bit Twiddling Hacks"):
    # sum bits into 2-bit fields, then 4-bit fields, then bytes, and use
    # a multiply to fold the four byte counts into the top byte.
    x = x.to(tl.uint32)  # unsigned so the shifts below are logical
    x = x - ((x >> 1) & 0x55555555)  # 2-bit field counts
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333)  # 4-bit field counts
    x = (x + (x >> 4)) & 0x0F0F0F0F  # per-byte counts (0..8 each)
    x = x * 0x01010101  # horizontal sum of the 4 bytes -> top byte
    return (x >> 24).to(tl.int32)  # extract the total (0..32)
|
|
||
|
|
||
def _check_fp_tensor(name: str, x: torch.Tensor) -> None:
    """Validate that *x* is a 2-D int32 CUDA tensor.

    Raises TypeError when *x* is not a tensor at all, ValueError when it
    is on the wrong device, has the wrong dtype, or is not 2-D.

    NOTE(review): despite "fp", the required dtype is int32 — presumably
    "fp" means *fingerprint* (bit-packed into int32 words), not floating
    point; a clarifying rename or comment would help. TODO confirm.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"{name} must be a torch.Tensor")
    # Checked in order: device, dtype, rank — first failure wins.
    requirements = (
        (x.is_cuda, f"{name} must be a CUDA tensor"),
        (x.dtype == torch.int32, f"{name} must have dtype int32"),
        (x.ndim == 2, f"{name} must be 2D, got shape={tuple(x.shape)}"),
    )
    for satisfied, message in requirements:
        if not satisfied:
            raise ValueError(message)
|
|
||
|
|
||
def _check_vec_tensor(
    name: str,
    x: torch.Tensor,
    expected_len: int,
    *,
    allow_larger: bool = False,
) -> None:
    """Validate that *x* is a 1-D int32 CUDA tensor of the right length.

    With ``allow_larger=True`` the tensor may be longer than
    *expected_len*; otherwise the length must match exactly. Raises
    TypeError for non-tensors and ValueError for every other violation.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"{name} must be a torch.Tensor")
    if not x.is_cuda:
        raise ValueError(f"{name} must be a CUDA tensor")
    if x.dtype != torch.int32:
        raise ValueError(f"{name} must have dtype int32")
    if x.ndim != 1:
        raise ValueError(f"{name} must be 1D, got shape={tuple(x.shape)}")
    length = x.numel()
    # Flat guard clauses instead of nested branches: one for the
    # "at least" mode, one for the exact-length mode.
    if allow_larger and length < expected_len:
        raise ValueError(
            f"{name} must have length >= {expected_len}, got {length}"
        )
    if not allow_larger and length != expected_len:
        raise ValueError(f"{name} must have length {expected_len}, got {length}")
|
|
||
|
|
||
# Computes one (BLOCK_M x BLOCK_N) tile of pairwise bit-vector
# Tanimoto/Jaccard similarities between rows of x and rows of y, then
# atomically ADDS each x-row's neighbor count (pairs with
# similarity >= threshold) into neighbors_ptr. A row's total accumulates
# across all pid_n programs via the atomic add; the host is responsible
# for the initial contents of neighbors_ptr.
@triton.jit
def _similarity_neighbor_kernel(
    x_ptr,  # (n, K) int32 bit-packed fingerprint words
    y_ptr,  # (m, K) int32 bit-packed fingerprint words
    neighbors_ptr,  # (n,) int32 per-row neighbor counters (accumulated into)
    n,
    m,
    K,  # number of int32 words per fingerprint row
    x_stride_n,
    x_stride_k,
    y_stride_n,
    y_stride_k,
    threshold,  # float cutoff; comparison below is inclusive (>=)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask_m = offs_m < n
    mask_n = offs_n < m

    # Running popcounts: |x_i| per tile row, |y_j| per tile column, and
    # the pairwise intersection |x_i & y_j|.
    norm_x = tl.zeros((BLOCK_M,), dtype=tl.int32)
    norm_y = tl.zeros((BLOCK_N,), dtype=tl.int32)
    dots = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

    # Walk the K words one at a time, BLOCK_K per outer step (inner loop
    # unrolled via static_range). Out-of-range words load as 0 and so
    # contribute nothing to any popcount.
    for k_block in range(0, tl.cdiv(K, BLOCK_K)):
        k_offset = k_block * BLOCK_K
        for kk in tl.static_range(0, BLOCK_K):
            k_idx = k_offset + kk
            k_mask = k_idx < K
            xk = tl.load(
                x_ptr + offs_m * x_stride_n + k_idx * x_stride_k,
                mask=mask_m & k_mask,
                other=0,
            )
            yk = tl.load(
                y_ptr + offs_n * y_stride_n + k_idx * y_stride_k,
                mask=mask_n & k_mask,
                other=0,
            )
            norm_x += _popcount32(xk)
            norm_y += _popcount32(yk)
            # Outer-product of the two word vectors via broadcasting.
            dots += _popcount32(xk[:, None] & yk[None, :])

    # Tanimoto denominator |x| + |y| - |x & y|, i.e. |x | y|.
    denom = norm_x[:, None] + norm_y[None, :] - dots
    # Out-of-range pairs and pairs of two all-zero fingerprints
    # (denom == 0) are never neighbors.
    valid = mask_m[:, None] & mask_n[None, :] & (denom > 0)
    similarity = tl.where(valid, dots.to(tl.float32) / denom.to(tl.float32), 0.0)
    is_neighbor = valid & (similarity >= threshold)

    # NOTE(review): when x and y are the same matrix, the diagonal pair
    # (i, i) has similarity 1.0 and is counted — presumably intentional
    # for Butina degree counting; confirm with the driver code.
    row_counts = tl.sum(is_neighbor.to(tl.int32), axis=1)
    tl.atomic_add(neighbors_ptr + offs_m, row_counts, mask=mask_m)
|
|
||
|
|
||
# Identical to _similarity_neighbor_kernel except the final sign: each
# x-row's neighbor count within the tile is atomically SUBTRACTED from
# neighbors_ptr (used to decrement degrees after a cluster is removed).
# NOTE(review): this duplicates the add-variant line for line apart from
# the negation — consider unifying the two behind a sign parameter.
@triton.jit
def _subtract_similarity_neighbor_kernel(
    x_ptr,  # (n, K) int32 bit-packed fingerprint words
    y_ptr,  # (m, K) int32 bit-packed fingerprint words
    neighbors_ptr,  # (n,) int32 per-row counters (decremented)
    n,
    m,
    K,  # number of int32 words per fingerprint row
    x_stride_n,
    x_stride_k,
    y_stride_n,
    y_stride_k,
    threshold,  # float cutoff; comparison below is inclusive (>=)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask_m = offs_m < n
    mask_n = offs_n < m

    # Running popcounts: |x_i|, |y_j|, and pairwise |x_i & y_j|.
    norm_x = tl.zeros((BLOCK_M,), dtype=tl.int32)
    norm_y = tl.zeros((BLOCK_N,), dtype=tl.int32)
    dots = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

    # Word-at-a-time scan over K; masked loads yield 0 past the end.
    for k_block in range(0, tl.cdiv(K, BLOCK_K)):
        k_offset = k_block * BLOCK_K
        for kk in tl.static_range(0, BLOCK_K):
            k_idx = k_offset + kk
            k_mask = k_idx < K
            xk = tl.load(
                x_ptr + offs_m * x_stride_n + k_idx * x_stride_k,
                mask=mask_m & k_mask,
                other=0,
            )
            yk = tl.load(
                y_ptr + offs_n * y_stride_n + k_idx * y_stride_k,
                mask=mask_n & k_mask,
                other=0,
            )
            norm_x += _popcount32(xk)
            norm_y += _popcount32(yk)
            dots += _popcount32(xk[:, None] & yk[None, :])

    # Tanimoto: intersection / union, union = |x| + |y| - |x & y|.
    denom = norm_x[:, None] + norm_y[None, :] - dots
    valid = mask_m[:, None] & mask_n[None, :] & (denom > 0)
    similarity = tl.where(valid, dots.to(tl.float32) / denom.to(tl.float32), 0.0)
    is_neighbor = valid & (similarity >= threshold)

    # Negated count: atomic_add of a negative value performs the decrement.
    row_counts = -tl.sum(is_neighbor.to(tl.int32), axis=1)
    tl.atomic_add(neighbors_ptr + offs_m, row_counts, mask=mask_m)
|
|
||
|
|
||
# One program per row. Each row is compared against the chosen cluster
# center; still-"free" rows within the threshold claim a slot at the
# FRONT of cluster_indices (cluster_count[0] counts up) and are marked
# as taken. Rows that are NOT neighbors of the center but have
# degree == 1 are emitted as singletons at the BACK of cluster_indices
# (cluster_count[1] counts down).
@triton.jit
def _remove_largest_cluster_kernel(
    x_ptr,  # (n, K) int32 bit-packed fingerprints
    center_id,  # row index of the cluster center
    is_free_ptr,  # (n,) int32, nonzero while a row is unassigned
    neighbors_ptr,  # (n,) int32 neighbor degrees (from the kernels above)
    cluster_count_ptr,  # (2,) int32: [0] next front slot, [1] next back slot
    cluster_indices_ptr,  # output index buffer, filled from both ends
    threshold,  # float similarity cutoff (inclusive)
    indices_ptr,  # (n,) int32 — presumably the original index of each row; TODO confirm
    n,
    K,
    x_stride_n,
    x_stride_k,
    BLOCK_K: tl.constexpr,
):
    row = tl.program_id(axis=0)
    row_mask = row < n

    # Scalar popcounts: |center|, |row|, |center & row|.
    pa = tl.zeros((), dtype=tl.int32)
    pb = tl.zeros((), dtype=tl.int32)
    dot = tl.zeros((), dtype=tl.int32)

    # Word-at-a-time scan over the K fingerprint words; masked loads
    # yield 0 beyond the end, contributing nothing to the popcounts.
    for k_block in range(0, tl.cdiv(K, BLOCK_K)):
        k_offset = k_block * BLOCK_K
        for kk in tl.static_range(0, BLOCK_K):
            k_idx = k_offset + kk
            k_mask = k_idx < K
            center_k = tl.load(x_ptr + center_id * x_stride_n + k_idx * x_stride_k, mask=k_mask, other=0)
            row_k = tl.load(
                x_ptr + row * x_stride_n + k_idx * x_stride_k,
                mask=row_mask & k_mask,
                other=0,
            )
            pa += _popcount32(center_k)
            pb += _popcount32(row_k)
            dot += _popcount32(row_k & center_k)

    # Tanimoto similarity against the center; only rows that are still
    # free and have a nonzero union can join the cluster.
    union = pa + pb - dot
    row_is_free = tl.load(is_free_ptr + row, mask=row_mask, other=0)
    valid = row_mask & (row_is_free != 0) & (union > 0)
    similarity = tl.where(valid, dot.to(tl.float32) / union.to(tl.float32), 0.0)
    is_neighbor = valid & (similarity >= threshold)

    # Claim a front slot. atomic_add returns the PRE-update value, so
    # slots are handed out in nondeterministic but collision-free order.
    orig_idx = tl.load(indices_ptr + row, mask=row_mask, other=0)
    neighbor_slot = tl.atomic_add(cluster_count_ptr + 0, 1, mask=is_neighbor)
    tl.store(cluster_indices_ptr + neighbor_slot, orig_idx, mask=is_neighbor)
    tl.store(is_free_ptr + row, 0, mask=is_neighbor)

    # Singleton path: degree == 1 presumably means "only itself as a
    # neighbor" (self-pairs are counted by the similarity kernels).
    # Because atomic_add returns the old value, the first singleton is
    # written AT the initial cluster_count[1]; the host must initialize
    # it to the last valid index of cluster_indices — TODO confirm.
    # NOTE(review): this branch does not re-check row_is_free, so an
    # already-assigned row whose stale degree reaches 1 could be emitted
    # twice — verify the host loop makes that impossible.
    degree = tl.load(neighbors_ptr + row, mask=row_mask, other=0)
    is_singleton = row_mask & (~is_neighbor) & (degree == 1)
    singleton_slot = tl.atomic_add(cluster_count_ptr + 1, -1, mask=is_singleton)
    tl.store(cluster_indices_ptr + singleton_slot, orig_idx, mask=is_singleton)
    tl.store(is_free_ptr + row, 0, mask=is_singleton)
|
|
||
|
|
||
def similarity_neighbor(
    x: torch.Tensor,
    y: torch.Tensor,
    neighbors: torch.Tensor,
    threshold: float,
) -> None:
    """Accumulate per-row neighbor counts of ``x`` against ``y``.

    Launches the Triton kernel that computes the bitwise Tanimoto
    (Jaccard) similarity of every ``(x[i], y[j])`` pair and atomically
    ADDS the number of pairs with ``similarity >= threshold`` into
    ``neighbors[i]``. The buffer is not zeroed here, so counts
    accumulate onto whatever ``neighbors`` already holds.

    Args:
        x: ``(n, K)`` int32 CUDA tensor of bit-packed fingerprints.
        y: ``(m, K)`` int32 CUDA tensor of bit-packed fingerprints.
        neighbors: ``(n,)`` int32 CUDA tensor receiving the counts.
        threshold: inclusive similarity cutoff.

    Raises:
        TypeError: if an argument is not a tensor.
        ValueError: on wrong device/dtype/shape, device mismatch, or
            mismatched feature dimensions.
    """
    _check_fp_tensor("x", x)
    _check_fp_tensor("y", y)
    _check_vec_tensor("neighbors", neighbors, x.shape[0])
    if not (x.device == y.device == neighbors.device):
        raise ValueError("x, y, and neighbors must be on the same CUDA device")
    if x.shape[1] != y.shape[1]:
        raise ValueError("x and y must have the same feature dimension")

    n_rows, n_words = x.shape
    m_rows = y.shape[0]
    # One program per (TILE_X x TILE_Y) tile of the similarity matrix.
    launch_grid = (triton.cdiv(n_rows, TILE_X), triton.cdiv(m_rows, TILE_Y))
    _similarity_neighbor_kernel[launch_grid](
        x,
        y,
        neighbors,
        n_rows,
        m_rows,
        n_words,
        x.stride(0),
        x.stride(1),
        y.stride(0),
        y.stride(1),
        float(threshold),
        BLOCK_M=TILE_X,
        BLOCK_N=TILE_Y,
        BLOCK_K=32,
        num_warps=8,
    )
|
|
||
|
|
||
def subtract_similarity_neighbor(
    x: torch.Tensor,
    y: torch.Tensor,
    neighbors: torch.Tensor,
    threshold: float,
) -> None:
    """Subtract per-row neighbor counts of ``x`` against ``y``.

    Mirror image of :func:`similarity_neighbor`: the kernel computes the
    same bitwise Tanimoto similarities but atomically SUBTRACTS each
    row's neighbor count from ``neighbors[i]`` (used to decrement
    degrees after a cluster's members are removed).

    Args:
        x: ``(n, K)`` int32 CUDA tensor of bit-packed fingerprints.
        y: ``(m, K)`` int32 CUDA tensor of bit-packed fingerprints.
        neighbors: ``(n,)`` int32 CUDA tensor that is decremented.
        threshold: inclusive similarity cutoff.

    Raises:
        TypeError: if an argument is not a tensor.
        ValueError: on wrong device/dtype/shape, device mismatch, or
            mismatched feature dimensions.
    """
    _check_fp_tensor("x", x)
    _check_fp_tensor("y", y)
    _check_vec_tensor("neighbors", neighbors, x.shape[0])
    if not (x.device == y.device == neighbors.device):
        raise ValueError("x, y, and neighbors must be on the same CUDA device")
    if x.shape[1] != y.shape[1]:
        raise ValueError("x and y must have the same feature dimension")

    n_rows, n_words = x.shape
    m_rows = y.shape[0]
    # One program per (TILE_X x TILE_Y) tile of the similarity matrix.
    launch_grid = (triton.cdiv(n_rows, TILE_X), triton.cdiv(m_rows, TILE_Y))
    _subtract_similarity_neighbor_kernel[launch_grid](
        x,
        y,
        neighbors,
        n_rows,
        m_rows,
        n_words,
        x.stride(0),
        x.stride(1),
        y.stride(0),
        y.stride(1),
        float(threshold),
        BLOCK_M=TILE_X,
        BLOCK_N=TILE_Y,
        BLOCK_K=32,
        num_warps=8,
    )
|
|
||
|
|
||
def remove_largest_cluster(
    x: torch.Tensor,
    id: int,
    is_free: torch.Tensor,
    neighbors: torch.Tensor,
    cluster_count: torch.Tensor,
    cluster_indices: torch.Tensor,
    threshold: float,
    indices: torch.Tensor,
) -> None:
    """Extract the cluster centered at row ``id`` on the GPU.

    Launches one kernel program per row: free rows within ``threshold``
    of the center are written to the front of ``cluster_indices``
    (advancing ``cluster_count[0]``) and marked non-free; non-neighbor
    rows with degree 1 are written as singletons from the back
    (retreating ``cluster_count[1]``).

    Args:
        x: ``(n, K)`` int32 CUDA tensor of bit-packed fingerprints.
        id: center row index in ``[0, n)``. (Shadows the ``id`` builtin,
            but the name is part of the public signature.)
        is_free: ``(n,)`` int32, nonzero while a row is unassigned.
        neighbors: ``(n,)`` int32 neighbor degrees.
        cluster_count: ``(2,)`` int32 front/back slot counters.
        cluster_indices: ``(>= n,)`` int32 output buffer.
        threshold: inclusive similarity cutoff.
        indices: ``(n,)`` int32 per-row indices copied into the output.

    Raises:
        TypeError: if an argument is not a tensor.
        ValueError: on any shape/dtype/device violation or out-of-range
            ``id``.
    """
    _check_fp_tensor("x", x)
    n_rows, n_words = x.shape
    _check_vec_tensor("is_free", is_free, n_rows)
    _check_vec_tensor("neighbors", neighbors, n_rows)
    _check_vec_tensor("cluster_indices", cluster_indices, n_rows, allow_larger=True)
    _check_vec_tensor("indices", indices, n_rows)
    _check_vec_tensor("cluster_count", cluster_count, 2)
    if not 0 <= id < n_rows:
        raise ValueError(f"id must be in [0, {n_rows}), got {id}")
    device = x.device
    if any(
        t.device != device
        for t in (is_free, neighbors, cluster_count, cluster_indices, indices)
    ):
        raise ValueError("all tensors must be on the same CUDA device")

    # One single-warp program per row of x.
    _remove_largest_cluster_kernel[(n_rows,)](
        x,
        id,
        is_free,
        neighbors,
        cluster_count,
        cluster_indices,
        float(threshold),
        indices,
        n_rows,
        n_words,
        x.stride(0),
        x.stride(1),
        BLOCK_K=32,
        num_warps=1,
    )
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`_check_fp_tensor` is misnamed — it validates int32, not floating-point tensors. The name
`_check_fp_tensor` ("fp" = floating-point) is misleading; it enforces `dtype == torch.int32`. The check itself is correct, but the name creates confusion for future maintainers. Consider renaming it to `_check_int32_tensor` — or, if "fp" is meant as "fingerprint" (the tensors hold bit-packed fingerprints), adding a comment that spells that out — and updating all three call sites.