diff --git a/cpp/src/neighbors/all_neighbors/all_neighbors.cuh b/cpp/src/neighbors/all_neighbors/all_neighbors.cuh index b6831d1f3c..02de3ea799 100644 --- a/cpp/src/neighbors/all_neighbors/all_neighbors.cuh +++ b/cpp/src/neighbors/all_neighbors/all_neighbors.cuh @@ -9,6 +9,7 @@ #include #include #include +#include namespace cuvs::neighbors::all_neighbors::detail { using namespace cuvs::neighbors; @@ -16,45 +17,52 @@ using namespace cuvs::neighbors; GRAPH_BUILD_ALGO check_params_validity(const all_neighbors_params& params, bool do_mutual_reachability_dist) { + using DT = cuvs::distance::DistanceType; + + // InnerProduct is not supported for mutual reachability distance, because mutual reachability + // distance takes "max" of core distances and pairwise distance. + static const std::unordered_set
mrd_allowed_metrics = { + DT::L2Expanded, DT::L2SqrtExpanded, DT::CosineExpanded}; + + static const std::unordered_set
bf_allowed_metrics = {DT::L2Expanded, + DT::L2SqrtExpanded, + DT::CosineExpanded, + DT::L1, + DT::L2Unexpanded, + DT::L2SqrtUnexpanded, + DT::InnerProduct, + DT::Linf, + DT::Canberra, + DT::LpUnexpanded, + DT::CorrelationExpanded, + DT::JensenShannon}; + + static const std::unordered_set
nnd_allowed_metrics = { + DT::L2Expanded, DT::L2SqrtExpanded, DT::CosineExpanded, DT::InnerProduct}; + if (std::holds_alternative(params.graph_build_params)) { if (do_mutual_reachability_dist) { - // InnerProduct is not supported for mutual reachability distance, because mutual reachability - // distance takes "max" of core distances and pairwise distance. - auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - params.metric == cuvs::distance::DistanceType::CosineExpanded; RAFT_EXPECTS( - allowed_metrics, + mrd_allowed_metrics.count(params.metric), "Distance metric for all-neighbors build with brute force for computing mutual " "reachability distance should be L2Expanded, L2SqrtExpanded, or CosineExpanded."); } else { - auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - params.metric == cuvs::distance::DistanceType::CosineExpanded || - params.metric == cuvs::distance::DistanceType::InnerProduct; - RAFT_EXPECTS(allowed_metrics, - "Distance metric for all-neighbors build with brute force should be L2Expanded, " - "L2SqrtExpanded, CosineExpanded, or InnerProduct."); + RAFT_EXPECTS( + bf_allowed_metrics.count(params.metric), + "Distance metric for all-neighbors build with brute force should be L2Expanded, " + "L2SqrtExpanded, CosineExpanded, L1, L2Unexpanded, L2SqrtUnexpanded, InnerProduct, Linf, " + "Canberra, LpUnexpanded, CorrelationExpanded, or JensenShannon."); } return GRAPH_BUILD_ALGO::BRUTE_FORCE; } else if (std::holds_alternative( params.graph_build_params)) { if (do_mutual_reachability_dist) { - // InnerProduct is not supported for mutual reachability distance, because mutual reachability - // distance takes "max" of core distances and pairwise distance. - auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - params.metric == cuvs::distance::DistanceType::CosineExpanded; RAFT_EXPECTS( - allowed_metrics, + mrd_allowed_metrics.count(params.metric), "Distance metric for all-neighbors build with NN Descent for computing mutual reachability " "distance should be L2Expanded, L2SqrtExpanded, or CosineExpanded."); } else { - auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - params.metric == cuvs::distance::DistanceType::CosineExpanded || - params.metric == cuvs::distance::DistanceType::InnerProduct; - RAFT_EXPECTS(allowed_metrics, + RAFT_EXPECTS(nnd_allowed_metrics.count(params.metric), "Distance metric for all-neighbors build with NN Descent should be L2Expanded, " "L2SqrtExpanded, CosineExpanded, or InnerProduct."); } diff --git a/python/cuvs/cuvs/neighbors/all_neighbors/all_neighbors.pyx b/python/cuvs/cuvs/neighbors/all_neighbors/all_neighbors.pyx index fb4750d90d..ce920b47c1 100644 --- a/python/cuvs/cuvs/neighbors/all_neighbors/all_neighbors.pyx +++ b/python/cuvs/cuvs/neighbors/all_neighbors/all_neighbors.pyx @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # # cython: language_level=3 @@ -111,11 +111,12 @@ cdef class AllNeighborsParams: ) # Check metric consistency - ivf_pq_metric = ivf_pq_params.metric - if ivf_pq_metric != metric: + metric_type = DISTANCE_TYPES[metric] + ivf_pq_metric_type = DISTANCE_TYPES[ivf_pq_params.metric] + if ivf_pq_metric_type != metric_type: raise ValueError( f"Metric conflict: AllNeighborsParams metric '{metric}' " - f"does not match IVF-PQ metric '{ivf_pq_metric}'. Please " + f"does not match IVF-PQ metric '{ivf_pq_params.metric}'. Please " f"ensure both use the same metric." ) @@ -127,11 +128,12 @@ cdef class AllNeighborsParams: ) # Check metric consistency - nn_descent_metric = nn_descent_params.metric - if nn_descent_metric != metric: + metric_type = DISTANCE_TYPES[metric] + nn_descent_metric_type = DISTANCE_TYPES[nn_descent_params.metric] + if nn_descent_metric_type != metric_type: raise ValueError( f"Metric conflict: AllNeighborsParams metric '{metric}' " - f"does not match NN-Descent metric '{nn_descent_metric}'. " + f"does not match NN-Descent metric '{nn_descent_params.metric}'. " f"Please ensure both use the same metric." ) diff --git a/python/cuvs/cuvs/tests/test_all_neighbors.py b/python/cuvs/cuvs/tests/test_all_neighbors.py index a232a58af8..2de5846054 100644 --- a/python/cuvs/cuvs/tests/test_all_neighbors.py +++ b/python/cuvs/cuvs/tests/test_all_neighbors.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # @@ -36,15 +36,32 @@ def make_cosine( @pytest.mark.parametrize("algo", ["nn_descent", "brute_force", "ivf_pq"]) @pytest.mark.parametrize("cluster", ["single_cluster", "multi_cluster"]) -@pytest.mark.parametrize("metric", ["sqeuclidean", "cosine"]) +@pytest.mark.parametrize( + "metric", + [ + "sqeuclidean", + "l2", + "cosine", + "l1", + "inner_product", + "chebyshev", + "canberra", + "minkowski", + "correlation", + "jensenshannon", + ], +) def test_all_neighbors_device_build_quality(algo, cluster, metric): """Test device build with quality validation against brute force ground truth. """ n_rows, n_cols, k = 7151, 64, 16 - if algo == "ivf_pq" and metric == "cosine": - pytest.skip("Skipping IVF-PQ with cosine distance") + ivf_pq_valid_metrics = {"sqeuclidean"} + nnd_valid_metrics = {"sqeuclidean", "l2", "cosine", "inner_product"} + is_invalid = (algo == "ivf_pq" and metric not in ivf_pq_valid_metrics) or ( + algo == "nn_descent" and metric not in nnd_valid_metrics + ) if cluster == "single_cluster": overlap_factor = 0 @@ -57,6 +74,21 @@ def test_all_neighbors_device_build_quality(algo, cluster, metric): X, _ = make_cosine( n_samples=n_rows, n_features=n_cols, random_state=42 ) + elif metric == "jensenshannon": + # Jensen-Shannon requires non-negative values representing probability distributions + X, _ = make_blobs( + n_samples=n_rows, + n_features=n_cols, + centers=10, + cluster_std=1.0, + center_box=(0.0, 10.0), # Non-negative values only + random_state=42, + ) + # Normalize each row to sum to 1 (probability distribution) + X = np.abs(X) # Ensure non-negative + row_sums = X.sum(axis=1, keepdims=True) + row_sums[row_sums == 0] = 1 # Avoid division by zero + X = X / row_sums else: X, _ = make_blobs( n_samples=n_rows, @@ -98,6 +130,18 @@ def test_all_neighbors_device_build_quality(algo, cluster, metric): ) res = Resources() + + if is_invalid: + with pytest.raises(Exception, match="Distance metric"): + all_neighbors.build( + X_device, + k, + params, + distances=cupy.empty((n_rows, k), dtype=cupy.float32), + resources=res, + ) + return + indices, distances = all_neighbors.build( X_device, k,