Skip to content

Commit aa250a7

Browse files
authored
Merge branch 'main' into cagra_optimize
2 parents cf86064 + 3da8994 commit aa250a7

7 files changed

Lines changed: 380 additions & 85 deletions

File tree

c/src/preprocessing/quantize/pq.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,10 @@ void* _build(cuvsResources_t res,
6262
auto dataset = dataset_tensor->dl_tensor;
6363

6464
auto res_ptr = reinterpret_cast<raft::resources*>(res);
65-
66-
auto quantizer_params = cuvs::preprocessing::quantize::pq::params{
67-
.pq_bits = params->pq_bits,
68-
.pq_dim = params->pq_dim,
69-
.use_subspaces = params->use_subspaces,
70-
.use_vq = params->use_vq,
71-
.vq_n_centers = params->vq_n_centers,
72-
.kmeans_n_iters = params->kmeans_n_iters,
73-
.pq_kmeans_type = static_cast<cuvs::cluster::kmeans::kmeans_type>(params->pq_kmeans_type),
74-
.max_train_points_per_pq_code = params->max_train_points_per_pq_code,
75-
.max_train_points_per_vq_cluster = params->max_train_points_per_vq_cluster
76-
};
65+
cuvs::preprocessing::quantize::pq::params quantizer_params(
66+
params->pq_bits, params->pq_dim, params->use_subspaces, params->use_vq, params->vq_n_centers,
67+
params->kmeans_n_iters, static_cast<cuvs::cluster::kmeans::kmeans_type>(params->pq_kmeans_type), params->max_train_points_per_pq_code,
68+
params->max_train_points_per_vq_cluster);
7769
cuvs::preprocessing::quantize::pq::quantizer<T>* ret = nullptr;
7870

7971
if (cuvs::core::is_dlpack_device_compatible(dataset)) {

cpp/include/cuvs/preprocessing/quantize/pq.hpp

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,78 @@
55

66
#pragma once
77

8+
#include <cuvs/cluster/kmeans.hpp>
89
#include <cuvs/neighbors/common.hpp>
910
#include <raft/core/device_mdspan.hpp>
1011
#include <raft/core/handle.hpp>
1112
#include <raft/core/host_mdspan.hpp>
1213

14+
#include <variant>
15+
1316
namespace cuvs::preprocessing::quantize::pq {
1417

1518
/**
1619
* @defgroup pq Product Quantizer utilities
1720
* @{
1821
*/
1922

23+
/** Alias for the variant holding either balanced or regular k-means parameters. */
24+
using kmeans_params_variant =
25+
std::variant<cuvs::cluster::kmeans::balanced_params, cuvs::cluster::kmeans::params>;
26+
2027
/**
2128
* @brief Product Quantizer parameters.
2229
*/
2330
struct params {
31+
/**
32+
* Simplified constructor that will build an appropriate kmeans params object.
33+
*/
34+
params(uint32_t pq_bits,
35+
uint32_t pq_dim,
36+
bool use_subspaces,
37+
bool use_vq,
38+
uint32_t vq_n_centers,
39+
uint32_t kmeans_n_iters,
40+
cuvs::cluster::kmeans::kmeans_type pq_kmeans_type =
41+
cuvs::cluster::kmeans::kmeans_type::KMeansBalanced,
42+
uint32_t max_train_points_per_pq_code = 256,
43+
uint32_t max_train_points_per_vq_cluster = 1024)
44+
: pq_bits(pq_bits),
45+
pq_dim(pq_dim),
46+
use_subspaces(use_subspaces),
47+
use_vq(use_vq),
48+
vq_n_centers(vq_n_centers),
49+
kmeans_params(
50+
pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced
51+
? kmeans_params_variant{cuvs::cluster::kmeans::balanced_params{.n_iters = kmeans_n_iters}}
52+
: kmeans_params_variant{cuvs::cluster::kmeans::params{
53+
.n_clusters = 1 << pq_bits, .max_iter = static_cast<int>(kmeans_n_iters)}}),
54+
max_train_points_per_pq_code(max_train_points_per_pq_code),
55+
max_train_points_per_vq_cluster(max_train_points_per_vq_cluster)
56+
{
57+
}
58+
59+
params(uint32_t pq_bits,
60+
uint32_t pq_dim,
61+
bool use_subspaces,
62+
bool use_vq,
63+
uint32_t vq_n_centers,
64+
kmeans_params_variant kmeans_params,
65+
uint32_t max_train_points_per_pq_code = 256,
66+
uint32_t max_train_points_per_vq_cluster = 1024)
67+
: pq_bits(pq_bits),
68+
pq_dim(pq_dim),
69+
use_subspaces(use_subspaces),
70+
use_vq(use_vq),
71+
vq_n_centers(vq_n_centers),
72+
kmeans_params(kmeans_params),
73+
max_train_points_per_pq_code(max_train_points_per_pq_code),
74+
max_train_points_per_vq_cluster(max_train_points_per_vq_cluster)
75+
{
76+
}
77+
78+
params() = default;
79+
2480
/**
2581
* The bit length of the vector element after compression by PQ.
2682
*
@@ -32,7 +88,7 @@ struct params {
3288
uint32_t pq_bits = 8;
3389
/**
3490
* The dimensionality of the vector after compression by PQ.
35-
* When zero, an optimal value is selected using a heuristic.
91+
* When zero, dim / 4 is used as default.
3692
*
3793
* TODO: at the moment `dim` must be a multiple `pq_dim`.
3894
*/
@@ -50,19 +106,19 @@ struct params {
50106
bool use_vq = false;
51107
/**
52108
* Vector Quantization (VQ) codebook size - number of "coarse cluster centers".
53-
* When zero, an optimal value is selected using a heuristic.
109+
* When zero, an optimal value is selected using a heuristic. (sqrt(n_rows))
54110
*/
55111
uint32_t vq_n_centers = 0;
56-
/** The number of iterations searching for kmeans centers (both VQ & PQ phases). */
57-
uint32_t kmeans_n_iters = 25;
58112
/**
59-
* Type of k-means algorithm for PQ training.
60-
* Balanced k-means tends to be faster than regular k-means for PQ training, for
61-
* problem sets where the number of points per cluster are approximately equal.
62-
* Regular k-means may be better for skewed cluster distributions.
113+
* K-means parameters for PQ codebook training.
114+
*
115+
* Set to cuvs::cluster::kmeans::balanced_params for balanced k-means (default),
116+
* or cuvs::cluster::kmeans::params for regular k-means.
117+
* The active variant type selects the algorithm; balanced k-means tends to be faster
118+
* for PQ training where cluster sizes are approximately equal.
119+
* Only L2Expanded metric is supported. The number of clusters is always set to 1 << pq_bits.
63120
*/
64-
cuvs::cluster::kmeans::kmeans_type pq_kmeans_type =
65-
cuvs::cluster::kmeans::kmeans_type::KMeansBalanced;
121+
kmeans_params_variant kmeans_params = cuvs::cluster::kmeans::balanced_params{};
66122
/**
67123
* The max number of data points to use per PQ code during PQ codebook training. Using more data
68124
* points per PQ code may increase the quality of PQ codebook but may also increase the build

cpp/src/neighbors/detail/vpq_dataset.cuh

Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#pragma once
66

77
#include <cuvs/neighbors/common.hpp>
8+
#include <cuvs/preprocessing/quantize/pq.hpp>
89

910
#include "../../cluster/kmeans_balanced.cuh"
1011
#include "../../preprocessing/quantize/detail/pq_codepacking.cuh" // pq_bits-bitfield
@@ -74,50 +75,49 @@ namespace cuvs::neighbors::detail {
7475
template <typename MathT, typename IdxT>
7576
void train_pq_centers(
7677
const raft::resources& res,
77-
const cuvs::neighbors::vpq_params& params,
78+
const cuvs::preprocessing::quantize::pq::kmeans_params_variant& kmeans_params,
7879
const raft::device_matrix_view<const MathT, IdxT, raft::row_major> pq_trainset_view,
7980
const raft::device_matrix_view<MathT, uint32_t, raft::row_major> pq_centers_view,
8081
raft::device_vector_view<uint32_t, IdxT> sub_labels_view,
8182
raft::device_vector_view<uint32_t, IdxT> pq_cluster_sizes_view)
8283
{
83-
if (params.pq_kmeans_type == cuvs::cluster::kmeans::kmeans_type::KMeansBalanced) {
84-
cuvs::cluster::kmeans::balanced_params kmeans_params;
85-
kmeans_params.n_iters = params.kmeans_n_iters;
86-
kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded;
87-
88-
cuvs::cluster::kmeans_balanced::helpers::build_clusters<
89-
MathT,
90-
MathT,
91-
IdxT,
92-
uint32_t,
93-
uint32_t,
94-
cuvs::spatial::knn::detail::utils::mapping<MathT>>(
95-
res,
96-
kmeans_params,
97-
pq_trainset_view,
98-
pq_centers_view,
99-
sub_labels_view,
100-
pq_cluster_sizes_view,
101-
cuvs::spatial::knn::detail::utils::mapping<MathT>{});
102-
} else {
103-
const auto pq_n_centers = pq_centers_view.extent(0);
104-
cuvs::cluster::kmeans::params kmeans_params;
105-
kmeans_params.n_clusters = pq_n_centers;
106-
kmeans_params.max_iter = params.kmeans_n_iters;
107-
kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded;
108-
kmeans_params.init = cuvs::cluster::kmeans::params::InitMethod::Random;
109-
110-
std::optional<raft::device_vector_view<const MathT, IdxT>> sample_weight = std::nullopt;
111-
MathT inertia;
112-
IdxT n_iter;
113-
cuvs::cluster::kmeans::fit(res,
114-
kmeans_params,
115-
pq_trainset_view,
116-
sample_weight,
117-
pq_centers_view,
118-
raft::make_host_scalar_view<MathT>(&inertia),
119-
raft::make_host_scalar_view<IdxT>(&n_iter));
120-
}
84+
std::visit(
85+
[&](auto const& base_kmeans_params) {
86+
using KP = std::decay_t<decltype(base_kmeans_params)>;
87+
if constexpr (std::is_same_v<KP, cuvs::cluster::kmeans::balanced_params>) {
88+
auto bal_params = base_kmeans_params;
89+
bal_params.metric = cuvs::distance::DistanceType::L2Expanded;
90+
cuvs::cluster::kmeans_balanced::helpers::build_clusters<
91+
MathT,
92+
MathT,
93+
IdxT,
94+
uint32_t,
95+
uint32_t,
96+
cuvs::spatial::knn::detail::utils::mapping<MathT>>(
97+
res,
98+
bal_params,
99+
pq_trainset_view,
100+
pq_centers_view,
101+
sub_labels_view,
102+
pq_cluster_sizes_view,
103+
cuvs::spatial::knn::detail::utils::mapping<MathT>{});
104+
} else {
105+
auto classic_params = base_kmeans_params;
106+
classic_params.n_clusters = pq_centers_view.extent(0);
107+
classic_params.metric = cuvs::distance::DistanceType::L2Expanded;
108+
std::optional<raft::device_vector_view<const MathT, IdxT>> sample_weight = std::nullopt;
109+
MathT inertia;
110+
IdxT n_iter;
111+
cuvs::cluster::kmeans::fit(res,
112+
classic_params,
113+
pq_trainset_view,
114+
sample_weight,
115+
pq_centers_view,
116+
raft::make_host_scalar_view<MathT>(&inertia),
117+
raft::make_host_scalar_view<IdxT>(&n_iter));
118+
}
119+
},
120+
kmeans_params);
121121
}
122122

123123
template <typename DatasetT>
@@ -219,7 +219,7 @@ auto predict_vq(const raft::resources& res,
219219

220220
template <typename MathT, typename DatasetT>
221221
auto train_pq(const raft::resources& res,
222-
const vpq_params& params,
222+
const cuvs::preprocessing::quantize::pq::params& params,
223223
const DatasetT& dataset,
224224
const raft::device_matrix_view<const MathT, uint32_t, raft::row_major> vq_centers)
225225
-> raft::device_matrix<MathT, uint32_t, raft::row_major>
@@ -230,8 +230,8 @@ auto train_pq(const raft::resources& res,
230230
const ix_t pq_bits = params.pq_bits;
231231
const ix_t pq_n_centers = ix_t{1} << pq_bits;
232232
const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim);
233-
const ix_t n_rows_train = std::min((ix_t)(n_rows * params.pq_kmeans_trainset_fraction),
234-
params.max_train_points_per_pq_code * pq_n_centers);
233+
const ix_t n_rows_train =
234+
std::min<ix_t>(n_rows, params.max_train_points_per_pq_code * pq_n_centers);
235235
RAFT_EXPECTS(
236236
n_rows_train >= pq_n_centers,
237237
"The number of training samples must be greater than or equal to the number of PQ centers");
@@ -261,8 +261,12 @@ auto train_pq(const raft::resources& res,
261261
pq_trainset.data_handle(), n_rows_train * pq_dim, pq_len);
262262
auto sub_labels = raft::make_device_vector<uint32_t, ix_t>(res, pq_trainset_view.extent(0));
263263
auto pq_cluster_sizes = raft::make_device_vector<uint32_t, ix_t>(res, pq_centers.extent(0));
264-
train_pq_centers<MathT, ix_t>(
265-
res, params, pq_trainset_view, pq_centers.view(), sub_labels.view(), pq_cluster_sizes.view());
264+
train_pq_centers<MathT, ix_t>(res,
265+
params.kmeans_params,
266+
pq_trainset_view,
267+
pq_centers.view(),
268+
sub_labels.view(),
269+
pq_cluster_sizes.view());
266270

267271
return pq_centers;
268272
}

cpp/src/neighbors/scann/detail/scann_build.cuh

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,16 @@ index<T, IdxT> build(
160160
int dim_per_subspace = params.pq_dim;
161161
int num_clusters = 1 << params.pq_bits;
162162

163-
cuvs::preprocessing::quantize::pq::params pq_build_params;
164-
pq_build_params.pq_bits = params.pq_bits;
165-
pq_build_params.pq_dim = num_subspaces;
166-
pq_build_params.use_subspaces = true;
167-
pq_build_params.use_vq = false; // We already computed residuals
168-
pq_build_params.kmeans_n_iters = params.pq_train_iters;
169-
pq_build_params.max_train_points_per_pq_code = pq_n_rows_train / num_clusters;
170-
pq_build_params.pq_kmeans_type = cuvs::cluster::kmeans::kmeans_type::KMeansBalanced;
163+
cuvs::preprocessing::quantize::pq::params pq_build_params(
164+
params.pq_bits,
165+
num_subspaces,
166+
true,
167+
false,
168+
0,
169+
params.pq_train_iters,
170+
cuvs::cluster::kmeans::kmeans_type::KMeansBalanced,
171+
pq_n_rows_train / num_clusters,
172+
1024);
171173

172174
auto pq_quantizer = cuvs::preprocessing::quantize::pq::build(
173175
res, pq_build_params, raft::make_const_mdspan(trainset_residuals.view()));

0 commit comments

Comments
 (0)