diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 50f9564ca5..c562ca9e00 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -67,6 +67,14 @@ using namespace cuvs::spatial::knn::detail; // NOLINT using internal_extents_t = int64_t; // The default mdspan extent type used internally. +inline cuvs::distance::DistanceType coarse_clustering_metric( + cuvs::distance::DistanceType metric) noexcept +{ + return metric == cuvs::distance::DistanceType::InnerProduct + ? cuvs::distance::DistanceType::L2Expanded + : metric; +} + /** * @brief Compute residual vectors from the source dataset given by selected indices. * @@ -1112,7 +1120,7 @@ void extend(raft::resources const& handle, auto centers_view = raft::make_device_matrix_view( cluster_centers.data(), n_clusters, index->dim()); cuvs::cluster::kmeans::balanced_params kmeans_params; - kmeans_params.metric = index->metric(); + kmeans_params.metric = coarse_clustering_metric(index->metric()); cuvs::cluster::kmeans::predict( handle, kmeans_params, batch_data_view, centers_view, batch_labels_view); vec_batches.prefetch_next_batch(); @@ -1323,7 +1331,7 @@ auto build(raft::resources const& handle, cluster_centers, impl->n_lists(), impl->dim()); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = static_cast((int)impl->metric()); + kmeans_params.metric = coarse_clustering_metric(impl->metric()); if (impl->metric() == distance::DistanceType::CosineExpanded) { raft::linalg::row_normalize( diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh index b5ea8afab4..e0032a4ac5 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh @@ -55,7 +55,7 @@ void transform_batch(raft::resources const& res, // Compute the labels for each vector cuvs::cluster::kmeans::balanced_params kmeans_params; - kmeans_params.metric = index.metric(); + kmeans_params.metric = coarse_clustering_metric(index.metric()); cuvs::cluster::kmeans_balanced::predict( res, kmeans_params, dataset, cluster_centers, output_labels, utils::mapping{});