Skip to content
Merged
20 changes: 15 additions & 5 deletions cpp/src/neighbors/all_neighbors/all_neighbors.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -32,10 +32,20 @@ GRAPH_BUILD_ALGO check_params_validity(const all_neighbors_params& params,
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded ||
params.metric == cuvs::distance::DistanceType::InnerProduct;
RAFT_EXPECTS(allowed_metrics,
"Distance metric for all-neighbors build with brute force should be L2Expanded, "
"L2SqrtExpanded, CosineExpanded, or InnerProduct.");
params.metric == cuvs::distance::DistanceType::L1 ||
params.metric == cuvs::distance::DistanceType::L2Unexpanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtUnexpanded ||
params.metric == cuvs::distance::DistanceType::InnerProduct ||
params.metric == cuvs::distance::DistanceType::Linf ||
params.metric == cuvs::distance::DistanceType::Canberra ||
params.metric == cuvs::distance::DistanceType::LpUnexpanded ||
params.metric == cuvs::distance::DistanceType::CorrelationExpanded ||
params.metric == cuvs::distance::DistanceType::JensenShannon;
Comment thread
jinsolp marked this conversation as resolved.
Outdated
RAFT_EXPECTS(
allowed_metrics,
"Distance metric for all-neighbors build with brute force should be L2Expanded, "
"L2SqrtExpanded, CosineExpanded, L1, L2Unexpanded, L2SqrtUnexpanded, InnerProduct, Linf, "
"Canberra, LpUnexpanded, CorrelationExpanded, or JensenShannon.");
}
return GRAPH_BUILD_ALGO::BRUTE_FORCE;
} else if (std::holds_alternative<graph_build_params::nn_descent_params>(
Expand Down
16 changes: 9 additions & 7 deletions python/cuvs/cuvs/neighbors/all_neighbors/all_neighbors.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#
# cython: language_level=3
Expand Down Expand Up @@ -111,11 +111,12 @@ cdef class AllNeighborsParams:
)

# Check metric consistency
ivf_pq_metric = ivf_pq_params.metric
if ivf_pq_metric != metric:
metric_type = DISTANCE_TYPES[metric]
ivf_pq_metric_type = DISTANCE_TYPES[ivf_pq_params.metric]
if ivf_pq_metric_type != metric_type:
raise ValueError(
f"Metric conflict: AllNeighborsParams metric '{metric}' "
f"does not match IVF-PQ metric '{ivf_pq_metric}'. Please "
f"does not match IVF-PQ metric '{ivf_pq_params.metric}'. Please "
f"ensure both use the same metric."
)

Expand All @@ -127,11 +128,12 @@ cdef class AllNeighborsParams:
)

# Check metric consistency
nn_descent_metric = nn_descent_params.metric
if nn_descent_metric != metric:
metric_type = DISTANCE_TYPES[metric]
nn_descent_metric_type = DISTANCE_TYPES[nn_descent_params.metric]
if nn_descent_metric_type != metric_type:
raise ValueError(
f"Metric conflict: AllNeighborsParams metric '{metric}' "
f"does not match NN-Descent metric '{nn_descent_metric}'. "
f"does not match NN-Descent metric '{nn_descent_params.metric}'. "
f"Please ensure both use the same metric."
)

Expand Down
48 changes: 44 additions & 4 deletions python/cuvs/cuvs/tests/test_all_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#

Expand Down Expand Up @@ -36,15 +36,40 @@ def make_cosine(

@pytest.mark.parametrize("algo", ["nn_descent", "brute_force", "ivf_pq"])
@pytest.mark.parametrize("cluster", ["single_cluster", "multi_cluster"])
@pytest.mark.parametrize("metric", ["sqeuclidean", "cosine"])
@pytest.mark.parametrize(
"metric",
[
"sqeuclidean",
"l2",
"cosine",
"l1",
"inner_product",
"chebyshev",
"canberra",
"minkowski",
"correlation",
"jensenshannon",
],
)
def test_all_neighbors_device_build_quality(algo, cluster, metric):
"""Test device build with quality validation against brute force ground
truth.
"""
n_rows, n_cols, k = 7151, 64, 16

if algo == "ivf_pq" and metric == "cosine":
pytest.skip("Skipping IVF-PQ with cosine distance")
if algo == "ivf_pq" and metric != "sqeuclidean":
pytest.skip(
"Skipping IVF-PQ for distance metrics other than sqeuclidean"
)
elif algo == "nn_descent" and metric not in [
"sqeuclidean",
"l2",
"cosine",
"inner_product",
]:
pytest.skip(
"Skipping NN-Descent for distance metrics other than sqeuclidean, l2, cosine, or inner_product"
)
Comment thread
jinsolp marked this conversation as resolved.
Outdated

if cluster == "single_cluster":
overlap_factor = 0
Expand All @@ -57,6 +82,21 @@ def test_all_neighbors_device_build_quality(algo, cluster, metric):
X, _ = make_cosine(
n_samples=n_rows, n_features=n_cols, random_state=42
)
elif metric == "jensenshannon":
# Jensen-Shannon requires non-negative values representing probability distributions
X, _ = make_blobs(
n_samples=n_rows,
n_features=n_cols,
centers=10,
cluster_std=1.0,
center_box=(0.0, 10.0), # Non-negative values only
random_state=42,
)
# Normalize each row to sum to 1 (probability distribution)
X = np.abs(X) # Ensure non-negative
row_sums = X.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1 # Avoid division by zero
X = X / row_sums
Comment thread
tarang-jain marked this conversation as resolved.
else:
X, _ = make_blobs(
n_samples=n_rows,
Expand Down
Loading