Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 32 additions & 24 deletions cpp/src/neighbors/all_neighbors/all_neighbors.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,60 @@
#include <cuvs/neighbors/all_neighbors.hpp>
#include <raft/matrix/shift.cuh>
#include <raft/util/cudart_utils.hpp>
#include <unordered_set>

namespace cuvs::neighbors::all_neighbors::detail {
using namespace cuvs::neighbors;

GRAPH_BUILD_ALGO check_params_validity(const all_neighbors_params& params,
bool do_mutual_reachability_dist)
{
using DT = cuvs::distance::DistanceType;

// InnerProduct is not supported for mutual reachability distance, because the mutual
// reachability distance is the "max" of the core distances and the pairwise distance.
static const std::unordered_set<DT> mrd_allowed_metrics = {
DT::L2Expanded, DT::L2SqrtExpanded, DT::CosineExpanded};

static const std::unordered_set<DT> bf_allowed_metrics = {DT::L2Expanded,
DT::L2SqrtExpanded,
DT::CosineExpanded,
DT::L1,
DT::L2Unexpanded,
DT::L2SqrtUnexpanded,
DT::InnerProduct,
DT::Linf,
DT::Canberra,
DT::LpUnexpanded,
DT::CorrelationExpanded,
DT::JensenShannon};
Comment thread
jinsolp marked this conversation as resolved.

static const std::unordered_set<DT> nnd_allowed_metrics = {
DT::L2Expanded, DT::L2SqrtExpanded, DT::CosineExpanded, DT::InnerProduct};

if (std::holds_alternative<graph_build_params::brute_force_params>(params.graph_build_params)) {
if (do_mutual_reachability_dist) {
// InnerProduct is not supported for mutual reachability distance, because mutual reachability
// distance takes "max" of core distances and pairwise distance.
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded;
RAFT_EXPECTS(
allowed_metrics,
mrd_allowed_metrics.count(params.metric),
"Distance metric for all-neighbors build with brute force for computing mutual "
"reachability distance should be L2Expanded, L2SqrtExpanded, or CosineExpanded.");
} else {
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded ||
params.metric == cuvs::distance::DistanceType::InnerProduct;
RAFT_EXPECTS(allowed_metrics,
"Distance metric for all-neighbors build with brute force should be L2Expanded, "
"L2SqrtExpanded, CosineExpanded, or InnerProduct.");
RAFT_EXPECTS(
bf_allowed_metrics.count(params.metric),
"Distance metric for all-neighbors build with brute force should be L2Expanded, "
"L2SqrtExpanded, CosineExpanded, L1, L2Unexpanded, L2SqrtUnexpanded, InnerProduct, Linf, "
"Canberra, LpUnexpanded, CorrelationExpanded, or JensenShannon.");
}
return GRAPH_BUILD_ALGO::BRUTE_FORCE;
} else if (std::holds_alternative<graph_build_params::nn_descent_params>(
params.graph_build_params)) {
if (do_mutual_reachability_dist) {
// InnerProduct is not supported for mutual reachability distance, because mutual reachability
// distance takes "max" of core distances and pairwise distance.
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded;
RAFT_EXPECTS(
allowed_metrics,
mrd_allowed_metrics.count(params.metric),
"Distance metric for all-neighbors build with NN Descent for computing mutual reachability "
"distance should be L2Expanded, L2SqrtExpanded, or CosineExpanded.");
} else {
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded ||
params.metric == cuvs::distance::DistanceType::InnerProduct;
RAFT_EXPECTS(allowed_metrics,
RAFT_EXPECTS(nnd_allowed_metrics.count(params.metric),
"Distance metric for all-neighbors build with NN Descent should be L2Expanded, "
"L2SqrtExpanded, CosineExpanded, or InnerProduct.");
}
Expand Down
16 changes: 9 additions & 7 deletions python/cuvs/cuvs/neighbors/all_neighbors/all_neighbors.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#
# cython: language_level=3
Expand Down Expand Up @@ -111,11 +111,12 @@ cdef class AllNeighborsParams:
)

# Check metric consistency
ivf_pq_metric = ivf_pq_params.metric
if ivf_pq_metric != metric:
metric_type = DISTANCE_TYPES[metric]
ivf_pq_metric_type = DISTANCE_TYPES[ivf_pq_params.metric]
if ivf_pq_metric_type != metric_type:
raise ValueError(
f"Metric conflict: AllNeighborsParams metric '{metric}' "
f"does not match IVF-PQ metric '{ivf_pq_metric}'. Please "
f"does not match IVF-PQ metric '{ivf_pq_params.metric}'. Please "
f"ensure both use the same metric."
)

Expand All @@ -127,11 +128,12 @@ cdef class AllNeighborsParams:
)

# Check metric consistency
nn_descent_metric = nn_descent_params.metric
if nn_descent_metric != metric:
metric_type = DISTANCE_TYPES[metric]
nn_descent_metric_type = DISTANCE_TYPES[nn_descent_params.metric]
if nn_descent_metric_type != metric_type:
raise ValueError(
f"Metric conflict: AllNeighborsParams metric '{metric}' "
f"does not match NN-Descent metric '{nn_descent_metric}'. "
f"does not match NN-Descent metric '{nn_descent_params.metric}'. "
f"Please ensure both use the same metric."
)

Expand Down
52 changes: 48 additions & 4 deletions python/cuvs/cuvs/tests/test_all_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#

Expand Down Expand Up @@ -36,15 +36,32 @@ def make_cosine(

@pytest.mark.parametrize("algo", ["nn_descent", "brute_force", "ivf_pq"])
@pytest.mark.parametrize("cluster", ["single_cluster", "multi_cluster"])
@pytest.mark.parametrize("metric", ["sqeuclidean", "cosine"])
@pytest.mark.parametrize(
"metric",
[
"sqeuclidean",
"l2",
"cosine",
"l1",
"inner_product",
"chebyshev",
"canberra",
"minkowski",
"correlation",
"jensenshannon",
],
)
def test_all_neighbors_device_build_quality(algo, cluster, metric):
"""Test device build with quality validation against brute force ground
truth.
"""
n_rows, n_cols, k = 7151, 64, 16

if algo == "ivf_pq" and metric == "cosine":
pytest.skip("Skipping IVF-PQ with cosine distance")
ivf_pq_valid_metrics = {"sqeuclidean"}
nnd_valid_metrics = {"sqeuclidean", "l2", "cosine", "inner_product"}
is_invalid = (algo == "ivf_pq" and metric not in ivf_pq_valid_metrics) or (
algo == "nn_descent" and metric not in nnd_valid_metrics
)

if cluster == "single_cluster":
overlap_factor = 0
Expand All @@ -57,6 +74,21 @@ def test_all_neighbors_device_build_quality(algo, cluster, metric):
X, _ = make_cosine(
n_samples=n_rows, n_features=n_cols, random_state=42
)
elif metric == "jensenshannon":
# Jensen-Shannon requires non-negative values representing probability distributions
X, _ = make_blobs(
n_samples=n_rows,
n_features=n_cols,
centers=10,
cluster_std=1.0,
center_box=(0.0, 10.0), # Keeps cluster centers non-negative; samples may still be negative
random_state=42,
)
# Normalize each row to sum to 1 (probability distribution)
X = np.abs(X) # Ensure non-negative
row_sums = X.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1 # Avoid division by zero
X = X / row_sums
Comment thread
tarang-jain marked this conversation as resolved.
else:
X, _ = make_blobs(
n_samples=n_rows,
Expand Down Expand Up @@ -98,6 +130,18 @@ def test_all_neighbors_device_build_quality(algo, cluster, metric):
)

res = Resources()

if is_invalid:
with pytest.raises(Exception, match="Distance metric"):
all_neighbors.build(
X_device,
k,
params,
distances=cupy.empty((n_rows, k), dtype=cupy.float32),
resources=res,
)
return

indices, distances = all_neighbors.build(
X_device,
k,
Expand Down
Loading