Skip to content

Commit 216a512

Browse files
authored
Merge branch 'main' into view-pq-quantizer
2 parents f8a5415 + 63ed6ec commit 216a512

23 files changed

Lines changed: 2955 additions & 909 deletions

cpp/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1145,7 +1145,6 @@ if(NOT BUILD_CPU_ONLY)
11451145
CXX_STANDARD_REQUIRED ON
11461146
CUDA_STANDARD 20
11471147
CUDA_STANDARD_REQUIRED ON
1148-
CUDA_RESOLVE_DEVICE_SYMBOLS ON
11491148
INTERFACE_POSITION_INDEPENDENT_CODE ON
11501149
POSITION_INDEPENDENT_CODE ON
11511150
)
@@ -1202,7 +1201,6 @@ SECTIONS
12021201
CXX_STANDARD_REQUIRED ON
12031202
CUDA_STANDARD 20
12041203
CUDA_STANDARD_REQUIRED ON
1205-
CUDA_RESOLVE_DEVICE_SYMBOLS ON
12061204
POSITION_INDEPENDENT_CODE ON
12071205
INTERFACE_POSITION_INDEPENDENT_CODE ON
12081206
EXPORT_NAME cuvs_static

cpp/src/cluster/detail/kmeans.cuh

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -694,14 +694,15 @@ void kmeans_fit(
694694

695695
rmm::device_uvector<char> batch_workspace(streaming_batch_size, stream);
696696

697-
cuvs::spatial::knn::detail::utils::batch_load_iterator<DataT> data_batches(
698-
X.data_handle(), n_samples, n_features, streaming_batch_size, stream);
697+
auto data_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator<DataT>(
698+
handle, X.data_handle(), n_samples, n_features, streaming_batch_size, stream);
699699
// Host-path weight batches: only materialized when weights are provided and
700700
// the data resides on host
701-
std::optional<cuvs::spatial::knn::detail::utils::batch_load_iterator<DataT>> weight_batches;
701+
std::optional<cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn<DataT>> weight_batches;
702702
if constexpr (!data_on_device) {
703703
if (weight_ptr != nullptr) {
704-
weight_batches.emplace(weight_ptr, n_samples, 1, streaming_batch_size, stream);
704+
weight_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator<DataT>(
705+
handle, weight_ptr, n_samples, IndexT{1}, streaming_batch_size, stream);
705706
} else {
706707
raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1});
707708
}
@@ -833,7 +834,7 @@ void kmeans_fit(
833834
raft::make_device_matrix_view<DataT, IndexT>(new_centroids_ptr, n_clusters, n_features);
834835

835836
data_batches.reset();
836-
using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator<DataT>;
837+
using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn<DataT>;
837838
std::optional<wt_iter_t> wt_it;
838839
if (weight_batches.has_value()) {
839840
weight_batches->reset();
@@ -932,7 +933,7 @@ void kmeans_fit(
932933

933934
iter_inertia = DataT{0};
934935
data_batches.reset();
935-
using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator<DataT>;
936+
using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn<DataT>;
936937
std::optional<wt_iter_t> wt_it;
937938
if (weight_batches.has_value()) {
938939
weight_batches->reset();

cpp/src/cluster/detail/kmeans_balanced.cuh

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
#include <raft/core/operators.hpp>
1717
#include <raft/core/resource/cuda_stream.hpp>
1818
#include <raft/core/resource/device_memory_resource.hpp>
19+
#include <raft/core/resource/device_properties.hpp>
1920
#include <raft/core/resource/thrust_policy.hpp>
2021
#include <raft/linalg/add.cuh>
2122
#include <raft/linalg/gemm.cuh>
@@ -171,22 +172,28 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
171172
* @return A suggested minibatch size and the expected memory cost per-row (in bytes)
172173
*/
173174
template <typename MathT, typename IdxT>
174-
constexpr auto calc_minibatch_size(IdxT n_clusters,
175-
IdxT n_rows,
176-
IdxT dim,
177-
cuvs::distance::DistanceType metric,
178-
bool needs_conversion) -> std::tuple<IdxT, size_t>
175+
auto calc_minibatch_size(const raft::resources& handle,
176+
IdxT n_clusters,
177+
IdxT n_rows,
178+
IdxT dim,
179+
cuvs::distance::DistanceType metric,
180+
bool needs_conversion) -> std::tuple<IdxT, size_t>
179181
{
180182
n_clusters = std::max<IdxT>(1, n_clusters);
181183

182184
// Estimate memory needs per row (i.e element of the batch).
183185
size_t mem_per_row = 0;
184186
switch (metric) {
185-
// fusedL2NN needs a mutex and a key-value pair for each row.
186187
case distance::DistanceType::L2Expanded:
187188
case distance::DistanceType::L2SqrtExpanded: {
188-
mem_per_row += sizeof(int);
189-
mem_per_row += sizeof(raft::KeyValuePair<IdxT, MathT>);
189+
if (use_fused<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim)) {
190+
// fusedL2NN needs a mutex and a key-value pair for each row.
191+
mem_per_row += sizeof(int);
192+
mem_per_row += sizeof(raft::KeyValuePair<IdxT, MathT>);
193+
} else {
194+
// unfused path needs a full GEMM output (distance matrix row).
195+
mem_per_row += sizeof(MathT) * n_clusters;
196+
}
190197
} break;
191198
// Other metrics require storing a distance matrix.
192199
default: {
@@ -377,8 +384,8 @@ void predict(const raft::resources& handle,
377384
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope(
378385
"predict(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
379386
auto mem_res = mr.value_or(raft::resource::get_workspace_resource_ref(handle));
380-
auto [max_minibatch_size, _mem_per_row] =
381-
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
387+
auto [max_minibatch_size, _mem_per_row] = calc_minibatch_size<MathT>(
388+
handle, n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
382389
rmm::device_uvector<MathT> cur_dataset(
383390
std::is_same_v<T, MathT> ? 0 : max_minibatch_size * dim, stream, mem_res);
384391
bool need_compute_norm =
@@ -989,8 +996,8 @@ void build_hierarchical(const raft::resources& handle,
989996
// TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf.
990997
rmm::mr::managed_memory_resource managed_memory;
991998
rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource_ref(handle);
992-
auto [max_minibatch_size, mem_per_row] =
993-
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
999+
auto [max_minibatch_size, mem_per_row] = calc_minibatch_size<MathT>(
1000+
handle, n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
9941001

9951002
// Precompute the L2 norm of the dataset if relevant and not yet computed.
9961003
rmm::device_uvector<MathT> dataset_norm_buf(0, stream, device_memory);

cpp/src/cluster/detail/kmeans_common.cuh

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
#include <raft/core/memory_type.hpp>
2020
#include <raft/core/operators.hpp>
2121
#include <raft/core/resource/cuda_stream.hpp>
22+
#include <raft/core/resource/device_properties.hpp>
2223
#include <raft/core/resource/thrust_policy.hpp>
2324
#include <raft/core/resources.hpp>
2425
#include <raft/linalg/map.cuh>
@@ -56,6 +57,31 @@
5657

5758
namespace cuvs::cluster::kmeans::detail {
5859

60+
/**
61+
* @brief Returns true if the fused distance NN implementation should be used, based on the compute-capability major version read from the device properties of `handle`.
62+
* Only `m` and `n` take part in the decision (and only on SM 9.x); `k` and the template parameters `MathT` / `LabelT` are currently not consulted.
63+
* On Ampere or earlier (SM major <= 8) always use fused.
64+
* On Hopper (SM 9.x) use fused when m >= 4096 or n >= 4096, otherwise use unfused.
65+
* On Blackwell onwards (SM major >= 10) use unfused.
66+
*/
67+
template <typename MathT, typename IdxT, typename LabelT>
68+
bool use_fused(const raft::resources& handle, IdxT m, IdxT n, IdxT k)
69+
{
70+
cudaDeviceProp prop;
71+
prop = raft::resource::get_device_properties(handle);
72+
if (prop.major <= 8) {
73+
// Ampere or earlier (major <= 8): always use fused
74+
return true;
75+
} else if (prop.major == 9 && (m >= 4096 || n >= 4096)) {
76+
// On Hopper, use fused if either m or n is at least 4096
77+
return true;
78+
} else if (prop.major >= 10) {
79+
// On Blackwell onwards, use unfused
80+
return false;
81+
}
82+
return false;  // Remaining case: SM 9.x with both m and n below 4096 -> use unfused
83+
}
84+
5985
template <typename DataT, typename IndexT>
6086
struct SamplingOp {
6187
DataT* rnd;

cpp/src/cluster/detail/minClusterDistanceCompute.cu

Lines changed: 45 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
*/
55

66
#include "../../distance/fused_distance_nn.cuh"
7+
#include "../../distance/unfused_distance_nn.cuh"
78
#include "kmeans_common.cuh"
89

910
#include <raft/matrix/init.cuh>
@@ -50,24 +51,50 @@ void minClusterAndDistanceCompute(
5051
raft::KeyValuePair<IndexT, DataT> initial_value(0, std::numeric_limits<DataT>::max());
5152
raft::matrix::fill(handle, minClusterAndDistance, initial_value);
5253

53-
workspace.resize((sizeof(int)) * n_samples, stream);
54-
55-
cuvs::distance::fusedDistanceNNMinReduce<DataT, raft::KeyValuePair<IndexT, DataT>, IndexT>(
56-
minClusterAndDistance.data_handle(),
57-
X.data_handle(),
58-
centroids.data_handle(),
59-
L2NormX.data_handle(),
60-
centroidsNorm.data_handle(),
61-
n_samples,
62-
n_clusters,
63-
n_features,
64-
(void*)workspace.data(),
65-
metric != cuvs::distance::DistanceType::L2Expanded,
66-
false,
67-
true,
68-
metric,
69-
0.0f,
70-
stream);
54+
bool should_use_fused =
55+
use_fused<DataT, IndexT, IndexT>(handle, n_samples, n_clusters, n_features);
56+
57+
if (should_use_fused) {
58+
workspace.resize((sizeof(int)) * n_samples, stream);
59+
60+
cuvs::distance::fusedDistanceNNMinReduce<DataT, raft::KeyValuePair<IndexT, DataT>, IndexT>(
61+
minClusterAndDistance.data_handle(),
62+
X.data_handle(),
63+
centroids.data_handle(),
64+
L2NormX.data_handle(),
65+
centroidsNorm.data_handle(),
66+
n_samples,
67+
n_clusters,
68+
n_features,
69+
(void*)workspace.data(),
70+
metric != cuvs::distance::DistanceType::L2Expanded,
71+
false,
72+
true,
73+
metric,
74+
0.0f,
75+
stream);
76+
} else {
77+
workspace.resize(sizeof(DataT) * n_samples * n_clusters, stream);
78+
79+
cuvs::distance::
80+
unfusedDistanceNNMinReduce<DataT, DataT, raft::KeyValuePair<IndexT, DataT>, IndexT>(
81+
handle,
82+
minClusterAndDistance.data_handle(),
83+
X.data_handle(),
84+
centroids.data_handle(),
85+
L2NormX.data_handle(),
86+
centroidsNorm.data_handle(),
87+
n_samples,
88+
n_clusters,
89+
n_features,
90+
(void*)workspace.data(),
91+
metric != cuvs::distance::DistanceType::L2Expanded,
92+
false,
93+
true,
94+
metric,
95+
0.0f,
96+
stream);
97+
}
7198
} else {
7299
auto dataBatchSize = getDataBatchSize(batch_samples, n_samples);
73100
auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters);

0 commit comments

Comments
 (0)