diff --git a/.gitignore b/.gitignore index 6448ed3..67c89f5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.cursor .vscode /cmake-build-* .idea diff --git a/benchmarks/fused_butina_clustering_bench.py b/benchmarks/fused_butina_clustering_bench.py new file mode 100644 index 0000000..1b6d353 --- /dev/null +++ b/benchmarks/fused_butina_clustering_bench.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import pandas as pd +import torch +from benchmark_timing import time_it + +from nvmolkit.clustering import fused_butina + +try: + from rdkit import DataStructs + from rdkit.DataStructs import ExplicitBitVect + from rdkit.ML.Cluster import Butina + + HAS_RDKIT = True +except ImportError: + HAS_RDKIT = False + print("RDKit not found. RDKit comparison will be skipped.") + + +def generate_data(n, num_clusters, noise_range=2, seed=42, num_words=64): + """Generate random bit vectors with underlying cluster structure.""" + torch.manual_seed(seed) + base_vectors = torch.randint( + -(2**31 - 1), 2**31 - 1, size=(num_clusters, num_words), dtype=torch.int32, device="cuda" + ) + x = torch.zeros((n, num_words), dtype=torch.int32, device="cuda") + for i in range(n): + x[i] = base_vectors[i % num_clusters] + noise = torch.randint(0, noise_range, size=(num_words,), dtype=torch.int32, device="cuda") + x[i] = x[i] ^ noise + return x + + +def get_rdkit_clusters(bit_tensor, threshold=0.5, metric="tanimoto"): + """Convert int32 tensor to RDKit ExplicitBitVects and run Butina.""" + n = bit_tensor.shape[0] + num_words = bit_tensor.shape[1] + fps = [] + for i in range(n): + bv = ExplicitBitVect(num_words * 32) + bits = bit_tensor[i].cpu().numpy() + for word_idx in range(num_words): + word = int(bits[word_idx]) + for bit_idx in range(32): + if (word >> bit_idx) & 1: + bv.SetBit(word_idx * 32 + bit_idx) + fps.append(bv) + + bulk_sim_fn = DataStructs.BulkTanimotoSimilarity if metric == "tanimoto" else DataStructs.BulkCosineSimilarity + dists = [] + for i in range(n): + dists.extend(bulk_sim_fn(fps[i], fps[:i], returnDistance=True)) + cutoff = 1.0 - threshold + clusters = Butina.ClusterData(dists, n, cutoff, isDistData=True, reordering=True) + return clusters + + +def run_test(n, threshold, num_clusters, noise_range=2, seed=42, num_words=64, metric="tanimoto"): + """Run a single comparison test between fused Butina and RDKit.""" + print(f"\n{'=' * 60}") + print( + f"Test: n={n}, threshold={threshold}, clusters={num_clusters}, noise={noise_range}, words={num_words}, metric={metric}" + ) + print(f"{'=' * 60}") + + x = generate_data(n, num_clusters, noise_range=noise_range, seed=seed, num_words=num_words) + cutoff = 1.0 - threshold + + triton_result = time_it(lambda: fused_butina(x, cutoff=cutoff, metric=metric), gpu_sync=True) + warp_clusters, _ = fused_butina(x, cutoff=cutoff, metric=metric) + torch.cuda.synchronize() + print(f"Triton: {triton_result.median_ms:.2f}ms (median), found {len(warp_clusters)} clusters") + + rdkit_time_ms = 0.0 + passed = True + if HAS_RDKIT: + try: + rdkit_result = time_it(lambda: get_rdkit_clusters(x, threshold=threshold, metric=metric), runs=1) + rdkit_clusters = get_rdkit_clusters(x, threshold=threshold, metric=metric) + rdkit_time_ms = rdkit_result.median_ms + print(f"RDKit: {rdkit_time_ms:.2f}ms (median), found {len(rdkit_clusters)} clusters") + + rdkit_set = {tuple(sorted(c)) for c in rdkit_clusters} + warp_set = {tuple(sorted(c)) for c in warp_clusters} + passed = rdkit_set == warp_set + if passed: + print("SUCCESS: Clusters match exactly!") + else: + print("DIFFERENCE DETECTED!") + print(f" Clusters only in RDKit: {len(rdkit_set - warp_set)}") + print(f" Clusters only in Triton: {len(warp_set - rdkit_set)}") + except Exception as e: + print(f"Error running RDKit: {e}") + + return { + "n": n, + "threshold": threshold, + "num_clusters": num_clusters, + "noise_range": noise_range, + "num_words": num_words, + "metric": metric, + "triton_median_ms": triton_result.median_ms, + "triton_std_ms": triton_result.std_ms, + "rdkit_median_ms": rdkit_time_ms, + "passed": passed, + } + + +if __name__ == "__main__": + metric = sys.argv[1] if len(sys.argv) > 1 else "tanimoto" + if metric not in ("tanimoto", "cosine"): + print("Usage: python fused_butina_clustering_bench.py [tanimoto|cosine]") + sys.exit(1) + + test_configs = [ + # (n, threshold, num_clusters, noise_range, num_words) + (100, 0.3, 20, 2, 32), + (100, 0.5, 20, 2, 32), + (100, 0.7, 20, 2, 64), + (100, 0.9, 20, 2, 64), + (500, 0.4, 50, 2, 32), + (500, 0.6, 50, 2, 32), + (500, 0.8, 50, 2, 64), + (1000, 0.3, 100, 2, 32), + (1000, 0.5, 100, 2, 64), + (1000, 0.7, 100, 2, 64), + (5000, 0.5, 200, 2, 32), + (5000, 0.7, 200, 2, 64), + (10000, 0.5, 500, 2, 32), + (10000, 0.5, 2000, 2, 64), + # Denser clusters (lower noise) with tight threshold + (1000, 0.9, 100, 1, 32), + # Sparser clusters (higher noise) with loose threshold + (1000, 0.3, 100, 4, 64), + # Many small clusters + (2000, 0.5, 1000, 2, 32), + # Few large clusters + (2000, 0.5, 10, 2, 64), + # (100000, 0.7, 1000, 128, 32), + ] + + results = [] + try: + for n, threshold, num_clusters, noise_range, num_words in test_configs: + result = run_test(n, threshold, num_clusters, noise_range=noise_range, num_words=num_words, metric=metric) + results.append(result) + except Exception as e: + print(f"Got exception: {e}, exiting early") + + df = pd.DataFrame(results) + print(f"\n{'=' * 60}") + print("SUMMARY") + print(f"{'=' * 60}") + print(df.to_string(index=False)) + + all_passed = all(r["passed"] for r in results) + n_passed = sum(1 for r in results if r["passed"]) + print(f"\n{n_passed}/{len(results)} tests passed.") + + df.to_csv("fused_butina_results.csv", index=False) + if not all_passed: + sys.exit(1) diff --git a/nvmolkit/_fused_Butina.py b/nvmolkit/_fused_Butina.py new file mode 100644 index 0000000..5cab14d --- /dev/null +++ b/nvmolkit/_fused_Butina.py @@ -0,0 +1,316 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import triton +import triton.language as tl + +TILE_X = 32 +TILE_Y = 64 + +@triton.jit +def _popcount32(x): + x = x.to(tl.uint32) + return tl.inline_asm_elementwise( + asm="popc.b32 $0, $1;", + constraints="=r,r", + args=[x], + dtype=tl.uint32, + is_pure=True, + pack=1, + ).to(tl.int32) + + +def _check_fingerprint_matrix(name: str, x: torch.Tensor) -> None: + if not isinstance(x, torch.Tensor): + raise TypeError(f"{name} must be a torch.Tensor") + if not x.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if x.dtype != torch.int32: + raise ValueError(f"{name} must have dtype int32") + if x.ndim != 2: + raise ValueError(f"{name} must be 2D, got shape={tuple(x.shape)}") + + +def _check_int32_vector( + name: str, + x: torch.Tensor, + expected_len: int, + *, + allow_larger: bool = False, +) -> None: + if not isinstance(x, torch.Tensor): + raise TypeError(f"{name} must be a torch.Tensor") + if not x.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if x.dtype != torch.int32: + raise ValueError(f"{name} must have dtype int32") + if x.ndim != 1: + raise ValueError(f"{name} must be 1D, got shape={tuple(x.shape)}") + if allow_larger: + if x.numel() < expected_len: + raise ValueError(f"{name} must have length >= {expected_len}, got {x.numel()}") + else: + if x.numel() != expected_len: + raise ValueError(f"{name} must have length {expected_len}, got {x.numel()}") + +# pyright: reportUnreachable=false +# TODO: L2 Cache Optimizations +@triton.jit +def _update_neighbor_count_kernel( + x_ptr, + y_ptr, + neighbors_ptr, + n, + m, + K, + x_stride_n, + x_stride_k, + y_stride_n, + y_stride_k, + threshold, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SUBTRACT: tl.constexpr, + METRIC: tl.constexpr, +): + """Compute pairwise similarity between blocks of x and y using bit-packed fingerprints. + + Atomically adds (SUBTRACT=False) or subtracts (SUBTRACT=True) the per-row + neighbor counts into ``neighbors_ptr``. + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + mask_m = offs_m < n + mask_n = offs_n < m + + norm_x = tl.zeros((BLOCK_M,), dtype=tl.int32) + norm_y = tl.zeros((BLOCK_N,), dtype=tl.int32) + dots = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + + for k_block in range(0, tl.cdiv(K, BLOCK_K)): + k_offset = k_block * BLOCK_K + for kk in tl.static_range(0, BLOCK_K): + k_idx = k_offset + kk + k_mask = k_idx < K + xk = tl.load( + x_ptr + offs_m * x_stride_n + k_idx * x_stride_k, + mask=mask_m & k_mask, + other=0, + ) + yk = tl.load( + y_ptr + offs_n * y_stride_n + k_idx * y_stride_k, + mask=mask_n & k_mask, + other=0, + ) + norm_x += _popcount32(xk) + norm_y += _popcount32(yk) + dots += _popcount32(xk[:, None] & yk[None, :]) + + if METRIC == "tanimoto": + denom = norm_x[:, None] + norm_y[None, :] - dots + elif METRIC == "cosine": + denom = tl.sqrt(norm_x[:, None].to(tl.float32) * norm_y[None, :].to(tl.float32)) + else: + raise ValueError(f"Invalid metric: {METRIC}") + + valid = mask_m[:, None] & mask_n[None, :] & (denom > 0) + + similarity = tl.where(valid, dots.to(tl.float32) / denom.to(tl.float32), 0.0) + is_neighbor = valid & (similarity >= threshold) + + row_counts = tl.sum(is_neighbor.to(tl.int32), axis=1) + if SUBTRACT: + tl.atomic_add(neighbors_ptr + offs_m, -row_counts, mask=mask_m) + else: + tl.atomic_add(neighbors_ptr + offs_m, row_counts, mask=mask_m) + + +@triton.jit +def _extract_cluster_singleton_kernel( + x_ptr, + center_id, + is_free_ptr, + neighbors_ptr, + cluster_count_ptr, + cluster_indices_ptr, + threshold, + indices_ptr, + n, + K, + x_stride_n, + x_stride_k, + BLOCK_K: tl.constexpr, + METRIC: tl.constexpr, +): + """For each free row, compute similarity to the cluster center. + + Neighbors (similarity >= threshold) are assigned to the cluster from the + front of ``cluster_indices_ptr``; remaining rows whose neighbor degree is 1 + are collected as singletons from the back. + """ + row = tl.program_id(axis=0) + row_mask = row < n + + pa = tl.zeros((), dtype=tl.int32) + pb = tl.zeros((), dtype=tl.int32) + dot = tl.zeros((), dtype=tl.int32) + + for k_block in range(0, tl.cdiv(K, BLOCK_K)): + k_offset = k_block * BLOCK_K + for kk in tl.static_range(0, BLOCK_K): + k_idx = k_offset + kk + k_mask = k_idx < K + center_k = tl.load(x_ptr + center_id * x_stride_n + k_idx * x_stride_k, mask=k_mask, other=0) + row_k = tl.load( + x_ptr + row * x_stride_n + k_idx * x_stride_k, + mask=row_mask & k_mask, + other=0, + ) + pa += _popcount32(center_k) + pb += _popcount32(row_k) + dot += _popcount32(row_k & center_k) + + if METRIC == "tanimoto": + union = pa + pb - dot + elif METRIC == "cosine": + union = tl.sqrt(pa.to(tl.float32) * pb.to(tl.float32)) + else: + raise ValueError(f"Invalid metric: {METRIC}") + + row_is_free = tl.load(is_free_ptr + row, mask=row_mask, other=0) + valid = row_mask & (row_is_free != 0) & (union > 0) + similarity = tl.where(valid, dot.to(tl.float32) / union.to(tl.float32), 0.0) + is_neighbor = valid & (similarity >= threshold) + + orig_idx = tl.load(indices_ptr + row, mask=row_mask, other=0) + neighbor_slot = tl.atomic_add(cluster_count_ptr + 0, 1, mask=is_neighbor) + tl.store(cluster_indices_ptr + neighbor_slot, orig_idx, mask=is_neighbor) + tl.store(is_free_ptr + row, 0, mask=is_neighbor) + + degree = tl.load(neighbors_ptr + row, mask=row_mask, other=0) + is_singleton = row_mask & (~is_neighbor) & (degree == 1) + singleton_slot = tl.atomic_add(cluster_count_ptr + 1, -1, mask=is_singleton) + tl.store(cluster_indices_ptr + singleton_slot, orig_idx, mask=is_singleton) + tl.store(is_free_ptr + row, 0, mask=is_singleton) + + +def update_neighbor_counts( + x: torch.Tensor, + y: torch.Tensor, + neighbors: torch.Tensor, + threshold: float, + subtract: bool = False, + metric: str = "tanimoto", +) -> None: + """Update per-row neighbor counts for fingerprints in ``x`` against ``y``. + + For each row *i* in ``x``, counts how many rows in ``y`` have similarity + >= ``threshold`` and atomically adds (or subtracts when ``subtract=True``) + that count into ``neighbors[i]``. + """ + _check_fingerprint_matrix("x", x) + _check_fingerprint_matrix("y", y) + _check_int32_vector("neighbors", neighbors, x.shape[0]) + if x.device != y.device or x.device != neighbors.device: + raise ValueError("x, y, and neighbors must be on the same CUDA device") + if x.shape[1] != y.shape[1]: + raise ValueError("x and y must have the same feature dimension") + + n = x.shape[0] + m = y.shape[0] + K = x.shape[1] + grid = (triton.cdiv(n, TILE_X), triton.cdiv(m, TILE_Y)) + _update_neighbor_count_kernel[grid]( + x, + y, + neighbors, + n, + m, + K, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + float(threshold), + BLOCK_M=TILE_X, + BLOCK_N=TILE_Y, + BLOCK_K=32, + num_warps=8, + SUBTRACT=subtract, + METRIC=metric, + ) + + +def extract_cluster_and_singletons( + x: torch.Tensor, + id: int, + is_free: torch.Tensor, + neighbors: torch.Tensor, + cluster_count: torch.Tensor, + cluster_indices: torch.Tensor, + threshold: float, + indices: torch.Tensor, + metric: str = "tanimoto", +) -> None: + """Extract the cluster around center ``id`` and collect singletons. + + Every free row similar to the center (>= ``threshold``) is written into + ``cluster_indices`` from the front; free rows that are not neighbors but + have a neighbor degree of 1 are collected as singletons from the back. + Both groups are marked as non-free in ``is_free``. + """ + _check_fingerprint_matrix("x", x) + n = x.shape[0] + K = x.shape[1] + _check_int32_vector("is_free", is_free, n) + _check_int32_vector("neighbors", neighbors, n) + _check_int32_vector("cluster_indices", cluster_indices, n, allow_larger=True) + _check_int32_vector("indices", indices, n) + _check_int32_vector("cluster_count", cluster_count, 2) + if not (0 <= id < n): + raise ValueError(f"id must be in [0, {n}), got {id}") + if ( + x.device != is_free.device + or x.device != neighbors.device + or x.device != cluster_count.device + or x.device != cluster_indices.device + or x.device != indices.device + ): + raise ValueError("all tensors must be on the same CUDA device") + + grid = (n,) + _extract_cluster_singleton_kernel[grid]( + x, + id, + is_free, + neighbors, + cluster_count, + cluster_indices, + float(threshold), + indices, + n, + K, + x.stride(0), + x.stride(1), + BLOCK_K=32, + num_warps=1, + METRIC=metric, + ) diff --git a/nvmolkit/clustering.py b/nvmolkit/clustering.py index 04b958f..8845d51 100644 --- a/nvmolkit/clustering.py +++ b/nvmolkit/clustering.py @@ -19,6 +19,7 @@ from nvmolkit import _clustering from nvmolkit._arrayHelpers import * # noqa: F403 +from nvmolkit._fused_Butina import extract_cluster_and_singletons, update_neighbor_counts from nvmolkit.types import AsyncGpuResult _VALID_NEIGHBORLIST_SIZES = frozenset({8, 16, 24, 32, 64, 128}) @@ -31,8 +32,7 @@ def butina( return_centroids: bool = False, stream: torch.cuda.Stream | None = None, ) -> AsyncGpuResult | tuple[AsyncGpuResult, AsyncGpuResult]: - """ - Perform Butina clustering on a distance matrix. + """Perform Butina clustering on a distance matrix. The Butina algorithm is a deterministic clustering method that groups items based on distance thresholds. It iteratively: @@ -81,3 +81,88 @@ def butina( clusters, centroids = result return AsyncGpuResult(clusters), AsyncGpuResult(centroids) return AsyncGpuResult(result) + + +def fused_butina( + x: torch.Tensor, + cutoff: float, + return_centroids: bool = False, + stream: torch.cuda.Stream | None = None, + metric: str = "tanimoto", +): + """Perform fused Butina clustering on a set of fingerprints. + + This function uses a fused implementation of Butina clustering that computes + similarities and neighbors on-the-fly, avoiding the need to compute and store + the full distance matrix. This makes it suitable for large datasets. + + Args: + x: Tensor of shape (N, D) containing the fingerprints to cluster. + cutoff: Distance threshold for clustering. Items are neighbors if their + distance is less than this cutoff (i.e. similarity > 1 - cutoff). + return_centroids: Whether to return centroid indices for each cluster. + stream: CUDA stream to use. If None, uses the current stream. + metric: Metric to use for similarity computation. Currently only "tanimoto" + and "cosine" are supported. + + Returns: + A tuple ``(clusters, cluster_sizes)`` where *clusters* is a list of tuples + representing each cluster (with the first element being the centroid), and + *cluster_sizes* is a list of cumulative cluster sizes. + If ``return_centroids`` is True, returns a tuple ``(clusters, cluster_sizes, centroids)`` + where *centroids* is a list of centroid indices. + """ + if metric not in ["tanimoto", "cosine"]: + raise ValueError(f"metric must be one of ['tanimoto', 'cosine'], got {metric}") + if stream is not None and not isinstance(stream, torch.cuda.Stream): + raise TypeError(f"stream must be a torch.cuda.Stream or None, got {type(stream).__name__}") + with torch.cuda.stream(stream): + n_start = x.shape[0] + device = x.device + indices = torch.arange(n_start, dtype=torch.int32, device=device) + cluster_count = torch.zeros(2, dtype=torch.int32, device=device) + cluster_count[1] = n_start - 1 + cluster_indices = torch.zeros(n_start, dtype=torch.int32, device=device) + cluster_sizes = [0] + centroids = [] + is_free = torch.ones(n_start, dtype=torch.int32, device=device) + neigh = torch.zeros(n_start, dtype=torch.int32, device=device) + threshold = float(1 - cutoff) + y = x + first_run = True + while cluster_count[0].item() < cluster_count[1].item(): + update_neighbor_counts(x, y, neigh, threshold, subtract=not first_run, metric=metric) + first_run = False + + max_val = neigh.max().item() + if max_val == 0: + break + id_max = neigh.shape[0] - 1 - neigh.flip(0).contiguous().argmax().item() + centroids.append(indices[id_max].item()) + + extract_cluster_and_singletons( + x, id_max, is_free, neigh, cluster_count, cluster_indices, threshold, indices, metric=metric + ) + cluster_sizes.append(cluster_count[0].item()) + x, y = x[is_free.bool(), :].contiguous(), x[~is_free.bool(), :].contiguous() + indices = indices[is_free.bool()].contiguous() + neigh = neigh[is_free.bool()].contiguous() + is_free = torch.ones(x.shape[0], dtype=torch.int32, device=x.device) + + for i in range(n_start - cluster_sizes[-1]): + item = cluster_sizes[-1] + cluster_sizes.append(cluster_sizes[-1] + 1) + centroids.append(cluster_indices[item].item()) + clusters = [] + indices_cpu = cluster_indices.cpu().numpy() + for i in range(len(cluster_sizes) - 1): + start_idx = cluster_sizes[i] + end_idx = cluster_sizes[i + 1] + cluster_members = indices_cpu[start_idx:end_idx].tolist() + + centroid = centroids[i] + members = [centroid] + [m for m in cluster_members if m != centroid] + clusters.append(tuple(members)) + if return_centroids: + return clusters, cluster_sizes, centroids + return clusters, cluster_sizes diff --git a/nvmolkit/tests/test_clustering.py b/nvmolkit/tests/test_clustering.py index de081bd..56c8fc2 100644 --- a/nvmolkit/tests/test_clustering.py +++ b/nvmolkit/tests/test_clustering.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import pytest import torch -import numpy as np -from nvmolkit.clustering import butina + +from nvmolkit.clustering import butina, fused_butina def check_butina_correctness(hit_mat, clusts): @@ -143,3 +144,152 @@ def test_butina_invalid_neighborlist_max_size(invalid_size): dists = torch.zeros(n, n, dtype=torch.float64) with pytest.raises(ValueError, match="neighborlist_max_size must be one of"): butina(dists, 0.1, neighborlist_max_size=invalid_size) + + +# --------------------------------------------------------------------------- +# Helpers for fused_butina tests +# --------------------------------------------------------------------------- + + +def generate_clustered_fingerprints(n, num_words=32, num_clusters=10, noise_range=2, seed=42): + """Create bit-packed int32 fingerprints with controllable cluster structure.""" + torch.manual_seed(seed) + base_vectors = torch.randint(-(2**31 - 1), 2**31 - 1, size=(num_clusters, num_words), dtype=torch.int32).cuda() + x = torch.zeros((n, num_words), dtype=torch.int32, device="cuda") + for i in range(n): + x[i] = base_vectors[i % num_clusters] + noise = torch.randint(0, noise_range, size=(num_words,), dtype=torch.int32, device="cuda") + x[i] = x[i] ^ noise + return x + + +def compute_pairwise_similarity_cpu(x_np, metric="tanimoto"): + """Compute NxN similarity from (N, D) int32 bit-packed fingerprints on CPU.""" + n, d = x_np.shape + bits = np.unpackbits(x_np.view(np.uint8).reshape(n, d * 4), axis=1, bitorder="little").astype(np.float64) + popcnt = bits.sum(axis=1) + dots = bits @ bits.T + if metric == "tanimoto": + denom = popcnt[:, None] + popcnt[None, :] - dots + sim = np.where(denom > 0, dots / denom, 0.0) + elif metric == "cosine": + denom = np.sqrt(popcnt[:, None] * popcnt[None, :]) + sim = np.where(denom > 0, dots / denom, 0.0) + else: + raise ValueError(f"Unknown metric: {metric}") + return sim + + +def check_fused_butina_basic(clusters, cluster_sizes, n): + """Structural sanity checks on fused_butina output.""" + all_items = [] + for c in clusters: + all_items.extend(c) + assert sorted(all_items) == list(range(n)), "Not all items assigned exactly once" + + assert cluster_sizes[0] == 0 + assert cluster_sizes[-1] == n + assert len(cluster_sizes) == len(clusters) + 1 + for i in range(len(clusters)): + assert cluster_sizes[i + 1] - cluster_sizes[i] == len(clusters[i]) + + sizes = [len(c) for c in clusters] + for i in range(len(sizes) - 1): + assert sizes[i] >= sizes[i + 1], "Clusters not in non-increasing size order" + + +# --------------------------------------------------------------------------- +# fused_butina tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "n,metric,num_words", + [ + (50, "tanimoto", 32), + (100, "tanimoto", 64), + (200, "tanimoto", 32), + (50, "cosine", 32), + (100, "cosine", 64), + (200, "cosine", 32), + ], +) +def test_fused_butina_basic_correctness(n, metric, num_words): + x = generate_clustered_fingerprints(n, num_words=num_words, num_clusters=10) + cutoff = 0.4 + clusters, cluster_sizes = fused_butina(x, cutoff=cutoff, metric=metric) + + check_fused_butina_basic(clusters, cluster_sizes, n) + + sim = compute_pairwise_similarity_cpu(x.cpu().numpy(), metric=metric) + hit_mat = torch.tensor(sim >= (1.0 - cutoff), dtype=torch.bool).cuda() + check_butina_correctness(hit_mat, clusters) + + +def test_fused_butina_single_item(): + x = torch.randint(-(2**31 - 1), 2**31 - 1, (1, 32), dtype=torch.int32).cuda() + clusters, cluster_sizes = fused_butina(x, cutoff=0.5) + assert len(clusters) == 1 + assert clusters[0] == (0,) + assert cluster_sizes == [0, 1] + + +@pytest.mark.parametrize("metric", ["tanimoto", "cosine"]) +def test_fused_butina_all_identical(metric): + n = 50 + base = torch.randint(-(2**31 - 1), 2**31 - 1, (1, 32), dtype=torch.int32).cuda() + x = base.expand(n, -1).contiguous() + clusters, _cluster_sizes = fused_butina(x, cutoff=0.5, metric=metric) + assert len(clusters) == 1 + assert len(clusters[0]) == n + assert set(clusters[0]) == set(range(n)) + + +@pytest.mark.parametrize("metric", ["tanimoto", "cosine"]) +def test_fused_butina_all_singletons(metric): + n = 50 + torch.manual_seed(42) + x = torch.randint(-(2**31 - 1), 2**31 - 1, (n, 32), dtype=torch.int32).cuda() + clusters, _cluster_sizes = fused_butina(x, cutoff=0.001, metric=metric) + assert len(clusters) == n + for c in clusters: + assert len(c) == 1 + + +@pytest.mark.parametrize("n,metric", [(50, "tanimoto"), (50, "cosine"), (200, "tanimoto"), (200, "cosine")]) +def test_fused_butina_return_centroids(n, metric): + cutoff = 0.4 + x = generate_clustered_fingerprints(n, num_words=32, num_clusters=10) + clusters, _cluster_sizes, centroids = fused_butina(x, cutoff=cutoff, return_centroids=True, metric=metric) + + assert len(centroids) == len(clusters) + sim = compute_pairwise_similarity_cpu(x.cpu().numpy(), metric=metric) + threshold = 1.0 - cutoff + + for cluster, centroid in zip(clusters, centroids): + assert cluster[0] == centroid + assert 0 <= centroid < n + for member in cluster: + if member != centroid: + assert sim[centroid, member] >= threshold - 1e-6 + + +def test_fused_butina_on_explicit_stream(): + n = 100 + x = generate_clustered_fingerprints(n, num_words=32, num_clusters=10) + s = torch.cuda.Stream() + clusters, cluster_sizes = fused_butina(x, cutoff=0.4, stream=s) + s.synchronize() + check_fused_butina_basic(clusters, cluster_sizes, n) + + +def test_fused_butina_invalid_metric(): + x = torch.randint(-(2**31 - 1), 2**31 - 1, (10, 32), dtype=torch.int32).cuda() + with pytest.raises(ValueError, match="metric must be one of"): + fused_butina(x, cutoff=0.5, metric="euclidean") + + +def test_fused_butina_invalid_stream_type(): + x = torch.randint(-(2**31 - 1), 2**31 - 1, (10, 32), dtype=torch.int32).cuda() + with pytest.raises(TypeError): + fused_butina(x, cutoff=0.5, stream=42) diff --git a/pyproject.toml b/pyproject.toml index c3c667c..3263759 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ description = "Python bindings for the nvMolKit project" dependencies = [ "numpy", "torch", + "triton", ] [project.optional-dependencies]