Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.cursor
.vscode
/cmake-build-*
.idea
Expand Down
316 changes: 316 additions & 0 deletions nvmolkit/_fused_Butina.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import triton
import triton.language as tl

Comment on lines +16 to +19
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing SPDX license header

Every other Python source file in the project starts with the standard NVIDIA Apache-2.0 SPDX block. This new file omits it entirely.

TILE_X = 32
TILE_Y = 64

@triton.jit
def _popcount32(x):
x = x.to(tl.uint32)
return tl.inline_asm_elementwise(
asm="popc.b32 $0, $1;",
constraints="=r,r",
args=[x],
dtype=tl.uint32,
is_pure=True,
pack=1,
).to(tl.int32)


def _check_fingerprint_matrix(name: str, x: torch.Tensor) -> None:
if not isinstance(x, torch.Tensor):
raise TypeError(f"{name} must be a torch.Tensor")
if not x.is_cuda:
raise ValueError(f"{name} must be a CUDA tensor")
if x.dtype != torch.int32:
raise ValueError(f"{name} must have dtype int32")
if x.ndim != 2:
raise ValueError(f"{name} must be 2D, got shape={tuple(x.shape)}")


def _check_int32_vector(
name: str,
x: torch.Tensor,
expected_len: int,
*,
allow_larger: bool = False,
) -> None:
if not isinstance(x, torch.Tensor):
raise TypeError(f"{name} must be a torch.Tensor")
if not x.is_cuda:
raise ValueError(f"{name} must be a CUDA tensor")
if x.dtype != torch.int32:
raise ValueError(f"{name} must have dtype int32")
if x.ndim != 1:
raise ValueError(f"{name} must be 1D, got shape={tuple(x.shape)}")
if allow_larger:
if x.numel() < expected_len:
raise ValueError(f"{name} must have length >= {expected_len}, got {x.numel()}")
else:
if x.numel() != expected_len:
raise ValueError(f"{name} must have length {expected_len}, got {x.numel()}")

# pyright: reportUnreachable=false
# TODO: L2 Cache Optimizations
@triton.jit
def _update_neighbor_count_kernel(
x_ptr,
y_ptr,
neighbors_ptr,
n,
m,
K,
x_stride_n,
x_stride_k,
y_stride_n,
y_stride_k,
threshold,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SUBTRACT: tl.constexpr,
METRIC: tl.constexpr,
):
"""Compute pairwise similarity between blocks of x and y using bit-packed fingerprints.

Atomically adds (SUBTRACT=False) or subtracts (SUBTRACT=True) the per-row
neighbor counts into ``neighbors_ptr``.
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)

offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

mask_m = offs_m < n
mask_n = offs_n < m

norm_x = tl.zeros((BLOCK_M,), dtype=tl.int32)
norm_y = tl.zeros((BLOCK_N,), dtype=tl.int32)
dots = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

for k_block in range(0, tl.cdiv(K, BLOCK_K)):
k_offset = k_block * BLOCK_K
for kk in tl.static_range(0, BLOCK_K):
k_idx = k_offset + kk
k_mask = k_idx < K
xk = tl.load(
x_ptr + offs_m * x_stride_n + k_idx * x_stride_k,
mask=mask_m & k_mask,
other=0,
)
yk = tl.load(
y_ptr + offs_n * y_stride_n + k_idx * y_stride_k,
mask=mask_n & k_mask,
other=0,
)
norm_x += _popcount32(xk)
norm_y += _popcount32(yk)
dots += _popcount32(xk[:, None] & yk[None, :])

if METRIC == "tanimoto":
denom = norm_x[:, None] + norm_y[None, :] - dots
elif METRIC == "cosine":
denom = tl.sqrt(norm_x[:, None].to(tl.float32) * norm_y[None, :].to(tl.float32))
else:
raise ValueError(f"Invalid metric: {METRIC}")

valid = mask_m[:, None] & mask_n[None, :] & (denom > 0)

similarity = tl.where(valid, dots.to(tl.float32) / denom.to(tl.float32), 0.0)
is_neighbor = valid & (similarity >= threshold)

row_counts = tl.sum(is_neighbor.to(tl.int32), axis=1)
if SUBTRACT:
tl.atomic_add(neighbors_ptr + offs_m, -row_counts, mask=mask_m)
else:
tl.atomic_add(neighbors_ptr + offs_m, row_counts, mask=mask_m)


@triton.jit
def _extract_cluster_singleton_kernel(
x_ptr,
center_id,
is_free_ptr,
neighbors_ptr,
cluster_count_ptr,
cluster_indices_ptr,
threshold,
indices_ptr,
n,
K,
x_stride_n,
x_stride_k,
BLOCK_K: tl.constexpr,
METRIC: tl.constexpr,
):
"""For each free row, compute similarity to the cluster center.

Neighbors (similarity >= threshold) are assigned to the cluster from the
front of ``cluster_indices_ptr``; remaining rows whose neighbor degree is 1
are collected as singletons from the back.
"""
row = tl.program_id(axis=0)
row_mask = row < n

pa = tl.zeros((), dtype=tl.int32)
pb = tl.zeros((), dtype=tl.int32)
dot = tl.zeros((), dtype=tl.int32)

for k_block in range(0, tl.cdiv(K, BLOCK_K)):
k_offset = k_block * BLOCK_K
for kk in tl.static_range(0, BLOCK_K):
k_idx = k_offset + kk
k_mask = k_idx < K
center_k = tl.load(x_ptr + center_id * x_stride_n + k_idx * x_stride_k, mask=k_mask, other=0)
row_k = tl.load(
x_ptr + row * x_stride_n + k_idx * x_stride_k,
mask=row_mask & k_mask,
other=0,
)
pa += _popcount32(center_k)
pb += _popcount32(row_k)
dot += _popcount32(row_k & center_k)
Comment on lines +172 to +189
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Center molecule's popcount (pa) recomputed redundantly by every thread

In _remove_largest_cluster_kernel each of the n thread-blocks independently reloads and computes pa (the popcount of center_id's fingerprint) via a full K-dimensional scan. Since every thread reads the same center_id row and produces the same scalar, this is O(n × K/32) redundant work. Pre-computing pa on the host and passing it as a scalar parameter would eliminate this waste.


if METRIC == "tanimoto":
union = pa + pb - dot
elif METRIC == "cosine":
union = tl.sqrt(pa.to(tl.float32) * pb.to(tl.float32))
else:
raise ValueError(f"Invalid metric: {METRIC}")

row_is_free = tl.load(is_free_ptr + row, mask=row_mask, other=0)
valid = row_mask & (row_is_free != 0) & (union > 0)
similarity = tl.where(valid, dot.to(tl.float32) / union.to(tl.float32), 0.0)
is_neighbor = valid & (similarity >= threshold)

orig_idx = tl.load(indices_ptr + row, mask=row_mask, other=0)
neighbor_slot = tl.atomic_add(cluster_count_ptr + 0, 1, mask=is_neighbor)
tl.store(cluster_indices_ptr + neighbor_slot, orig_idx, mask=is_neighbor)
tl.store(is_free_ptr + row, 0, mask=is_neighbor)

degree = tl.load(neighbors_ptr + row, mask=row_mask, other=0)
is_singleton = row_mask & (~is_neighbor) & (degree == 1)
singleton_slot = tl.atomic_add(cluster_count_ptr + 1, -1, mask=is_singleton)
tl.store(cluster_indices_ptr + singleton_slot, orig_idx, mask=is_singleton)
tl.store(is_free_ptr + row, 0, mask=is_singleton)


def update_neighbor_counts(
x: torch.Tensor,
y: torch.Tensor,
neighbors: torch.Tensor,
threshold: float,
subtract: bool = False,
metric: str = "tanimoto",
) -> None:
"""Update per-row neighbor counts for fingerprints in ``x`` against ``y``.

For each row *i* in ``x``, counts how many rows in ``y`` have similarity
>= ``threshold`` and atomically adds (or subtracts when ``subtract=True``)
that count into ``neighbors[i]``.
"""
_check_fingerprint_matrix("x", x)
_check_fingerprint_matrix("y", y)
_check_int32_vector("neighbors", neighbors, x.shape[0])
if x.device != y.device or x.device != neighbors.device:
raise ValueError("x, y, and neighbors must be on the same CUDA device")
if x.shape[1] != y.shape[1]:
raise ValueError("x and y must have the same feature dimension")

n = x.shape[0]
m = y.shape[0]
K = x.shape[1]
grid = (triton.cdiv(n, TILE_X), triton.cdiv(m, TILE_Y))
_update_neighbor_count_kernel[grid](
x,
y,
neighbors,
n,
m,
K,
x.stride(0),
x.stride(1),
y.stride(0),
y.stride(1),
float(threshold),
BLOCK_M=TILE_X,
BLOCK_N=TILE_Y,
BLOCK_K=32,
num_warps=8,
SUBTRACT=subtract,
METRIC=metric,
)


def extract_cluster_and_singletons(
x: torch.Tensor,
id: int,
is_free: torch.Tensor,
neighbors: torch.Tensor,
cluster_count: torch.Tensor,
cluster_indices: torch.Tensor,
threshold: float,
indices: torch.Tensor,
metric: str = "tanimoto",
) -> None:
"""Extract the cluster around center ``id`` and collect singletons.

Every free row similar to the center (>= ``threshold``) is written into
``cluster_indices`` from the front; free rows that are not neighbors but
have a neighbor degree of 1 are collected as singletons from the back.
Both groups are marked as non-free in ``is_free``.
"""
_check_fingerprint_matrix("x", x)
n = x.shape[0]
K = x.shape[1]
_check_int32_vector("is_free", is_free, n)
_check_int32_vector("neighbors", neighbors, n)
_check_int32_vector("cluster_indices", cluster_indices, n, allow_larger=True)
_check_int32_vector("indices", indices, n)
_check_int32_vector("cluster_count", cluster_count, 2)
if not (0 <= id < n):
raise ValueError(f"id must be in [0, {n}), got {id}")
if (
x.device != is_free.device
or x.device != neighbors.device
or x.device != cluster_count.device
or x.device != cluster_indices.device
or x.device != indices.device
):
raise ValueError("all tensors must be on the same CUDA device")

grid = (n,)
_extract_cluster_singleton_kernel[grid](
x,
id,
is_free,
neighbors,
cluster_count,
cluster_indices,
float(threshold),
indices,
n,
K,
x.stride(0),
x.stride(1),
BLOCK_K=32,
num_warps=1,
METRIC=metric,
)
Loading