Skip to content

Commit 403cc54

Browse files
authored
Multi-GPU Batched KMeans (#2017)
Closes #1989. Adds multi-GPU support to KMeans fit for host-resident data, with two modes: - **OpenMP (cuVS SNMG)**: A single process drives all local GPUs via OMP threads and raw NCCL. Activated automatically when the handle is a `device_resources_snmg`. - **RAFT comms (Ray / Dask / MPI)**: Each rank is a separate process that calls fit with its own data shard and an initialized RAFT communicator. Coordination uses the RAFT comms. Both modes share the same core Lloyd's loop, batched streaming of host data, NCCL/comms allreduce of centroid sums and counts, and synchronized convergence. Supports sample weights, n_init best-of-N restarts, KMeansPlusPlus initialization, and float/double. Falls back to single-GPU when neither multi-GPU resources nor comms are present. Authors: - Victor Lafargue (https://github.com/viclafargue) - Tarang Jain (https://github.com/tarang-jain) Approvers: - Tarang Jain (https://github.com/tarang-jain) - Micka (https://github.com/lowener) - Dante Gama Dessavre (https://github.com/dantegd) URL: #2017
1 parent 547a413 commit 403cc54

15 files changed

Lines changed: 1667 additions & 92 deletions

cpp/include/cuvs/cluster/kmeans.hpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ struct params : base_params {
125125
* When set to 0 (default) with host data uses `min(3 * n_clusters, n_samples)`
126126
* as a default.
127127
*
128+
* In batched multi-GPU host-data fits, the effective KMeansPlusPlus initialization
129+
* sample is materialized on device on every rank. Every rank must have enough
130+
* GPU memory for this sample, and rank 0 must also have enough GPU memory for
131+
* the seeding workspace.
132+
*
128133
* Default: 0.
129134
*/
130135
int64_t init_size = 0;
@@ -134,6 +139,9 @@ struct params : base_params {
134139
* When set to 0, defaults to n_samples (process all at once).
135140
* Only used by the batched (host-data) code path and ignored by device-data
136141
* overloads.
142+
*
143+
* In multi-GPU mode, this is a per-rank batch size. Each rank processes up to
144+
* this many local samples per batch, clamped to that rank's local sample count.
137145
* Default: 0 (process all data at once).
138146
*/
139147
int64_t streaming_batch_size = 0;
@@ -177,7 +185,20 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
177185
*
178186
* This overload supports out-of-core computation where the dataset resides
179187
* on the host. Data is processed in GPU-sized batches, streaming from host to device.
180-
* The batch size is controlled by params.streaming_batch_size.
188+
* The batch size is controlled by params.streaming_batch_size. In multi-GPU mode,
189+
* this is a per-rank batch size.
190+
*
191+
* Multi-GPU dispatch is selected automatically based on the handle state:
192+
* - If `raft::resource::is_multi_gpu(handle)` (cuVS SNMG): the full dataset X
193+
* is split across GPUs internally with an OpenMP parallel region and NCCL.
194+
* - If `raft::resource::comms_initialized(handle)` (Dask/Ray/MPI): X is treated as
195+
* this worker's partition, and RAFT communicators are used for collectives.
196+
* - Otherwise: single-GPU batched k-means.
197+
*
198+
* With `params.init == InitMethod::KMeansPlusPlus` in multi-GPU mode, the
199+
* effective initialization sample must fit in GPU memory on every rank because
200+
* it is materialized on every device. Rank 0 must also have enough GPU memory
201+
* for the seeding workspace before centroids are broadcast.
181202
*
182203
* @code{.cpp}
183204
* #include <raft/core/resources.hpp>
@@ -208,7 +229,8 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
208229
* raft::make_host_scalar_view(&n_iter));
209230
* @endcode
210231
*
211-
* @param[in] handle The raft handle.
232+
* @param[in] handle The raft handle. When a multi-GPU resource is attached or
233+
* RAFT comms are initialized, multi-GPU dispatch is used automatically.
212234
* @param[in] params Parameters for KMeans model. Batch size is read from
213235
* params.streaming_batch_size.
214236
* @param[in] X Training instances on HOST memory. The data must

cpp/src/cluster/detail/kmeans.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,7 @@ void kmeans_fit(
969969

970970
if (iter_inertia < inertia[0]) {
971971
inertia[0] = iter_inertia;
972-
n_iter[0] = n_current_iter;
972+
n_iter[0] = std::min(n_current_iter, static_cast<IndexT>(iter_params.max_iter));
973973
raft::copy(centroids.data_handle(), cur_centroids_ptr, centroid_buf_size, stream);
974974
}
975975
RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter - %d",

cpp/src/cluster/detail/kmeans_common.cuh

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -622,26 +622,26 @@ void compute_centroid_shift(raft::resources const& handle,
622622
* @brief Evaluate convergence criteria entirely on device.
623623
*
624624
* Checks the cost-ratio and centroid-shift stopping conditions and writes
625-
* a boolean result (0 or 1) into @p done_flag. Also advances
626-
* @p prior_clustering_cost to the current cost for the next iteration.
625+
* 0 or 1 into @p done_flag, and advances @p prior_clustering_cost.
626+
* The template parameter `FlagT` is deduced from @p done_flag (default `int`).
627627
*/
628-
template <typename DataT>
628+
template <typename DataT, typename FlagT = int>
629629
__device__ void check_convergence(raft::device_scalar_view<const DataT> clustering_cost,
630630
raft::device_scalar_view<DataT> prior_clustering_cost,
631631
raft::device_scalar_view<const DataT> sqrd_norm_error,
632632
DataT tol,
633633
int n_iter,
634-
raft::device_scalar_view<int> done_flag)
634+
raft::device_scalar_view<FlagT> done_flag)
635635
{
636636
DataT cur_cost = *clustering_cost.data_handle();
637637
DataT norm_err = *sqrd_norm_error.data_handle();
638-
int done = 0;
638+
FlagT done = FlagT{0};
639639

640640
if (cur_cost != DataT{0} && n_iter > 1) {
641641
DataT delta = cur_cost / *prior_clustering_cost.data_handle();
642-
if (delta > DataT{1} - tol) done = 1;
642+
if (delta > DataT{1} - tol) done = FlagT{1};
643643
}
644-
if (norm_err < tol) done = 1;
644+
if (norm_err < tol) done = FlagT{1};
645645

646646
*prior_clustering_cost.data_handle() = cur_cost;
647647
*done_flag.data_handle() = done;

cpp/src/cluster/detail/kmeans_mg.cuh

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -463,35 +463,31 @@ void checkWeights(const raft::resources& handle,
463463
raft::device_vector_view<DataT, IndexT> weight)
464464
{
465465
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
466-
rmm::device_scalar<DataT> wt_aggr(stream);
466+
auto d_wt_sum = raft::make_device_scalar<DataT>(handle, DataT{0});
467467

468468
const auto& comm = raft::resource::get_comms(handle);
469469

470470
auto n_samples = weight.extent(0);
471471
raft::linalg::mapThenSumReduce(
472-
wt_aggr.data(), n_samples, raft::identity_op{}, stream, weight.data_handle());
472+
d_wt_sum.data_handle(), n_samples, raft::identity_op{}, stream, weight.data_handle());
473473

474-
comm.allreduce<DataT>(wt_aggr.data(), // sendbuff
475-
wt_aggr.data(), // recvbuff
476-
1, // count
474+
comm.allreduce<DataT>(d_wt_sum.data_handle(), // sendbuff
475+
d_wt_sum.data_handle(), // recvbuff
476+
1, // count
477477
raft::comms::op_t::SUM,
478478
stream);
479-
DataT wt_sum = wt_aggr.value(stream);
480-
raft::resource::sync_stream(handle, stream);
481-
RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)");
482479

483-
if (wt_sum != n_samples) {
484-
CUVS_LOG_KMEANS(handle,
485-
"[Warning!] KMeans: normalizing the user provided sample weights to "
486-
"sum up to %d samples",
487-
n_samples);
488-
489-
raft::linalg::map(handle,
490-
weight,
491-
raft::compose_op(raft::mul_const_op<DataT>{static_cast<DataT>(n_samples)},
492-
raft::div_const_op<DataT>{wt_sum}),
493-
raft::make_const_mdspan(weight));
494-
}
480+
// Scale each weight by n_samples / (allreduced sum of weights). Reading the sum from
481+
// a device pointer avoids a host copy / stream sync. When the sum already
482+
// equals n_samples this is a numerical no-op (matches single-GPU behavior).
483+
const DataT* d_wt_sum_ptr = d_wt_sum.data_handle();
484+
raft::linalg::map(
485+
handle,
486+
weight,
487+
[n_samples, d_wt_sum_ptr] __device__(DataT w) {
488+
return w * static_cast<DataT>(n_samples) / *d_wt_sum_ptr;
489+
},
490+
raft::make_const_mdspan(weight));
495491
}
496492

497493
template <typename DataT, typename IndexT>
@@ -750,6 +746,7 @@ void fit(const raft::resources& handle,
750746
break;
751747
}
752748
}
749+
n_iter[0] = std::min(n_iter[0], static_cast<IndexT>(params.max_iter));
753750
}
754751

755752
}; // namespace cuvs::cluster::kmeans::mg::detail

0 commit comments

Comments
 (0)