55#pragma once
66
77#include < cuvs/neighbors/common.hpp>
8+ #include < cuvs/preprocessing/quantize/pq.hpp>
89
910#include " ../../cluster/kmeans_balanced.cuh"
1011#include " ../../preprocessing/quantize/detail/pq_codepacking.cuh" // pq_bits-bitfield
@@ -74,50 +75,49 @@ namespace cuvs::neighbors::detail {
7475template <typename MathT, typename IdxT>
7576void train_pq_centers (
7677 const raft::resources& res,
77- const cuvs::neighbors::vpq_params& params ,
78+ const cuvs::preprocessing::quantize::pq::kmeans_params_variant& kmeans_params ,
7879 const raft::device_matrix_view<const MathT, IdxT, raft::row_major> pq_trainset_view,
7980 const raft::device_matrix_view<MathT, uint32_t , raft::row_major> pq_centers_view,
8081 raft::device_vector_view<uint32_t , IdxT> sub_labels_view,
8182 raft::device_vector_view<uint32_t , IdxT> pq_cluster_sizes_view)
8283{
83- if (params.pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced) {
84- cuvs::cluster::kmeans::balanced_params kmeans_params;
85- kmeans_params.n_iters = params.kmeans_n_iters ;
86- kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded;
87-
88- cuvs::cluster::kmeans_balanced::helpers::build_clusters<
89- MathT,
90- MathT,
91- IdxT,
92- uint32_t ,
93- uint32_t ,
94- cuvs::spatial::knn::detail::utils::mapping<MathT>>(
95- res,
96- kmeans_params,
97- pq_trainset_view,
98- pq_centers_view,
99- sub_labels_view,
100- pq_cluster_sizes_view,
101- cuvs::spatial::knn::detail::utils::mapping<MathT>{});
102- } else {
103- const auto pq_n_centers = pq_centers_view.extent (0 );
104- cuvs::cluster::kmeans::params kmeans_params;
105- kmeans_params.n_clusters = pq_n_centers;
106- kmeans_params.max_iter = params.kmeans_n_iters ;
107- kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded;
108- kmeans_params.init = cuvs::cluster::kmeans::params::InitMethod::Random;
109-
110- std::optional<raft::device_vector_view<const MathT, IdxT>> sample_weight = std::nullopt ;
111- MathT inertia;
112- IdxT n_iter;
113- cuvs::cluster::kmeans::fit (res,
114- kmeans_params,
115- pq_trainset_view,
116- sample_weight,
117- pq_centers_view,
118- raft::make_host_scalar_view<MathT>(&inertia),
119- raft::make_host_scalar_view<IdxT>(&n_iter));
120- }
84+ std::visit (
85+ [&](auto const & base_kmeans_params) {
86+ using KP = std::decay_t <decltype (base_kmeans_params)>;
87+ if constexpr (std::is_same_v<KP , cuvs::cluster::kmeans::balanced_params>) {
88+ auto bal_params = base_kmeans_params;
89+ bal_params.metric = cuvs::distance::DistanceType::L2Expanded;
90+ cuvs::cluster::kmeans_balanced::helpers::build_clusters<
91+ MathT,
92+ MathT,
93+ IdxT,
94+ uint32_t ,
95+ uint32_t ,
96+ cuvs::spatial::knn::detail::utils::mapping<MathT>>(
97+ res,
98+ bal_params,
99+ pq_trainset_view,
100+ pq_centers_view,
101+ sub_labels_view,
102+ pq_cluster_sizes_view,
103+ cuvs::spatial::knn::detail::utils::mapping<MathT>{});
104+ } else {
105+ auto classic_params = base_kmeans_params;
106+ classic_params.n_clusters = pq_centers_view.extent (0 );
107+ classic_params.metric = cuvs::distance::DistanceType::L2Expanded;
108+ std::optional<raft::device_vector_view<const MathT, IdxT>> sample_weight = std::nullopt ;
109+ MathT inertia;
110+ IdxT n_iter;
111+ cuvs::cluster::kmeans::fit (res,
112+ classic_params,
113+ pq_trainset_view,
114+ sample_weight,
115+ pq_centers_view,
116+ raft::make_host_scalar_view<MathT>(&inertia),
117+ raft::make_host_scalar_view<IdxT>(&n_iter));
118+ }
119+ },
120+ kmeans_params);
121121}
122122
123123template <typename DatasetT>
@@ -219,7 +219,7 @@ auto predict_vq(const raft::resources& res,
219219
220220template <typename MathT, typename DatasetT>
221221auto train_pq (const raft::resources& res,
222- const vpq_params & params,
222+ const cuvs::preprocessing::quantize::pq::params & params,
223223 const DatasetT& dataset,
224224 const raft::device_matrix_view<const MathT, uint32_t , raft::row_major> vq_centers)
225225 -> raft::device_matrix<MathT, uint32_t, raft::row_major>
@@ -230,8 +230,8 @@ auto train_pq(const raft::resources& res,
230230 const ix_t pq_bits = params.pq_bits ;
231231 const ix_t pq_n_centers = ix_t {1 } << pq_bits;
232232 const ix_t pq_len = raft::div_rounding_up_safe (dim, pq_dim);
233- const ix_t n_rows_train = std::min (( ix_t )(n_rows * params. pq_kmeans_trainset_fraction ),
234- params.max_train_points_per_pq_code * pq_n_centers);
233+ const ix_t n_rows_train =
234+ std::min< ix_t >(n_rows, params.max_train_points_per_pq_code * pq_n_centers);
235235 RAFT_EXPECTS (
236236 n_rows_train >= pq_n_centers,
237237 " The number of training samples must be greater than or equal to the number of PQ centers" );
@@ -261,8 +261,12 @@ auto train_pq(const raft::resources& res,
261261 pq_trainset.data_handle (), n_rows_train * pq_dim, pq_len);
262262 auto sub_labels = raft::make_device_vector<uint32_t , ix_t >(res, pq_trainset_view.extent (0 ));
263263 auto pq_cluster_sizes = raft::make_device_vector<uint32_t , ix_t >(res, pq_centers.extent (0 ));
264- train_pq_centers<MathT, ix_t >(
265- res, params, pq_trainset_view, pq_centers.view (), sub_labels.view (), pq_cluster_sizes.view ());
264+ train_pq_centers<MathT, ix_t >(res,
265+ params.kmeans_params ,
266+ pq_trainset_view,
267+ pq_centers.view (),
268+ sub_labels.view (),
269+ pq_cluster_sizes.view ());
266270
267271 return pq_centers;
268272}
0 commit comments