Skip to content
Merged
56 changes: 32 additions & 24 deletions cpp/src/neighbors/all_neighbors/all_neighbors.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,60 @@
#include <cuvs/neighbors/all_neighbors.hpp>
#include <raft/matrix/shift.cuh>
#include <raft/util/cudart_utils.hpp>
#include <unordered_set>

namespace cuvs::neighbors::all_neighbors::detail {
using namespace cuvs::neighbors;

GRAPH_BUILD_ALGO check_params_validity(const all_neighbors_params& params,
bool do_mutual_reachability_dist)
{
using DT = cuvs::distance::DistanceType;

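// InnerProduct is not supported for mutual reachability distance: mutual reachability takes the
// "max" of the core distances and the pairwise distance, which is only meaningful for metrics
// where a smaller value means closer, whereas InnerProduct is a similarity (larger means closer).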
static const std::unordered_set<DT> mrd_allowed_metrics = {
DT::L2Expanded, DT::L2SqrtExpanded, DT::CosineExpanded};

static const std::unordered_set<DT> bf_allowed_metrics = {DT::L2Expanded,
DT::L2SqrtExpanded,
DT::CosineExpanded,
DT::L1,
DT::L2Unexpanded,
DT::L2SqrtUnexpanded,
DT::InnerProduct,
DT::Linf,
DT::Canberra,
DT::LpUnexpanded,
DT::CorrelationExpanded,
DT::JensenShannon};
Comment thread
jinsolp marked this conversation as resolved.

static const std::unordered_set<DT> nnd_allowed_metrics = {
DT::L2Expanded, DT::L2SqrtExpanded, DT::CosineExpanded, DT::InnerProduct};

if (std::holds_alternative<graph_build_params::brute_force_params>(params.graph_build_params)) {
if (do_mutual_reachability_dist) {
// InnerProduct is not supported for mutual reachability distance, because mutual reachability
// distance takes "max" of core distances and pairwise distance.
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded;
RAFT_EXPECTS(
allowed_metrics,
mrd_allowed_metrics.count(params.metric),
"Distance metric for all-neighbors build with brute force for computing mutual "
"reachability distance should be L2Expanded, L2SqrtExpanded, or CosineExpanded.");
} else {
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded ||
params.metric == cuvs::distance::DistanceType::InnerProduct;
RAFT_EXPECTS(allowed_metrics,
"Distance metric for all-neighbors build with brute force should be L2Expanded, "
"L2SqrtExpanded, CosineExpanded, or InnerProduct.");
RAFT_EXPECTS(
bf_allowed_metrics.count(params.metric),
"Distance metric for all-neighbors build with brute force should be L2Expanded, "
"L2SqrtExpanded, CosineExpanded, L1, L2Unexpanded, L2SqrtUnexpanded, InnerProduct, Linf, "
"Canberra, LpUnexpanded, CorrelationExpanded, or JensenShannon.");
}
return GRAPH_BUILD_ALGO::BRUTE_FORCE;
} else if (std::holds_alternative<graph_build_params::nn_descent_params>(
params.graph_build_params)) {
if (do_mutual_reachability_dist) {
// InnerProduct is not supported for mutual reachability distance, because mutual reachability
// distance takes "max" of core distances and pairwise distance.
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded;
RAFT_EXPECTS(
allowed_metrics,
mrd_allowed_metrics.count(params.metric),
"Distance metric for all-neighbors build with NN Descent for computing mutual reachability "
"distance should be L2Expanded, L2SqrtExpanded, or CosineExpanded.");
} else {
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded ||
params.metric == cuvs::distance::DistanceType::InnerProduct;
RAFT_EXPECTS(allowed_metrics,
RAFT_EXPECTS(nnd_allowed_metrics.count(params.metric),
"Distance metric for all-neighbors build with NN Descent should be L2Expanded, "
"L2SqrtExpanded, CosineExpanded, or InnerProduct.");
}
Expand Down
16 changes: 9 additions & 7 deletions python/cuvs/cuvs/neighbors/all_neighbors/all_neighbors.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#
# cython: language_level=3
Expand Down Expand Up @@ -111,11 +111,12 @@ cdef class AllNeighborsParams:
)

# Check metric consistency
ivf_pq_metric = ivf_pq_params.metric
if ivf_pq_metric != metric:
metric_type = DISTANCE_TYPES[metric]
ivf_pq_metric_type = DISTANCE_TYPES[ivf_pq_params.metric]
if ivf_pq_metric_type != metric_type:
raise ValueError(
f"Metric conflict: AllNeighborsParams metric '{metric}' "
f"does not match IVF-PQ metric '{ivf_pq_metric}'. Please "
f"does not match IVF-PQ metric '{ivf_pq_params.metric}'. Please "
f"ensure both use the same metric."
)

Expand All @@ -127,11 +128,12 @@ cdef class AllNeighborsParams:
)

# Check metric consistency
nn_descent_metric = nn_descent_params.metric
if nn_descent_metric != metric:
metric_type = DISTANCE_TYPES[metric]
nn_descent_metric_type = DISTANCE_TYPES[nn_descent_params.metric]
if nn_descent_metric_type != metric_type:
raise ValueError(
f"Metric conflict: AllNeighborsParams metric '{metric}' "
f"does not match NN-Descent metric '{nn_descent_metric}'. "
f"does not match NN-Descent metric '{nn_descent_params.metric}'. "
f"Please ensure both use the same metric."
)

Expand Down
48 changes: 44 additions & 4 deletions python/cuvs/cuvs/tests/test_all_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#

Expand Down Expand Up @@ -36,15 +36,40 @@ def make_cosine(

@pytest.mark.parametrize("algo", ["nn_descent", "brute_force", "ivf_pq"])
@pytest.mark.parametrize("cluster", ["single_cluster", "multi_cluster"])
@pytest.mark.parametrize("metric", ["sqeuclidean", "cosine"])
@pytest.mark.parametrize(
"metric",
[
"sqeuclidean",
"l2",
"cosine",
"l1",
"inner_product",
"chebyshev",
"canberra",
"minkowski",
"correlation",
"jensenshannon",
],
)
def test_all_neighbors_device_build_quality(algo, cluster, metric):
"""Test device build with quality validation against brute force ground
truth.
"""
n_rows, n_cols, k = 7151, 64, 16

if algo == "ivf_pq" and metric == "cosine":
pytest.skip("Skipping IVF-PQ with cosine distance")
if algo == "ivf_pq" and metric != "sqeuclidean":
pytest.skip(
"Skipping IVF-PQ for distance metrics other than sqeuclidean"
)
elif algo == "nn_descent" and metric not in [
"sqeuclidean",
"l2",
"cosine",
"inner_product",
]:
pytest.skip(
"Skipping NN-Descent for distance metrics other than sqeuclidean, l2, cosine, or inner_product"
)
Comment thread
jinsolp marked this conversation as resolved.
Outdated

if cluster == "single_cluster":
overlap_factor = 0
Expand All @@ -57,6 +82,21 @@ def test_all_neighbors_device_build_quality(algo, cluster, metric):
X, _ = make_cosine(
n_samples=n_rows, n_features=n_cols, random_state=42
)
elif metric == "jensenshannon":
# Jensen-Shannon requires non-negative values representing probability distributions
X, _ = make_blobs(
n_samples=n_rows,
n_features=n_cols,
centers=10,
cluster_std=1.0,
center_box=(0.0, 10.0), # Non-negative values only
random_state=42,
)
# Normalize each row to sum to 1 (probability distribution)
X = np.abs(X) # Ensure non-negative
row_sums = X.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1 # Avoid division by zero
X = X / row_sums
Comment thread
tarang-jain marked this conversation as resolved.
else:
X, _ = make_blobs(
n_samples=n_rows,
Expand Down
Loading