Skip to content

Commit add9db1

Browse files
committed
optimize convergence check
1 parent bfb5290 commit add9db1

2 files changed

Lines changed: 53 additions & 41 deletions

File tree

cpp/src/cluster/detail/kmeans_common.cuh

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -596,26 +596,27 @@ void compute_centroid_shift(raft::resources const& handle,
596596
* @brief Evaluate convergence criteria entirely on device.
597597
*
598598
* Checks the cost-ratio and centroid-shift stopping conditions and writes
599-
* a boolean result (0 or 1) into @p done_flag. Also advances
600-
* @p prior_clustering_cost to the current cost for the next iteration.
599+
* 0 or 1 into @p done_flag, and advances @p prior_clustering_cost.
600+
* @p FlagT is deduced from @p done_flag (default `int`); MG callers pass
601+
* `int64_t` for NCCL allreduce compatibility.
601602
*/
602-
template <typename DataT>
603+
template <typename DataT, typename FlagT = int>
603604
__device__ void check_convergence(raft::device_scalar_view<const DataT> clustering_cost,
604605
raft::device_scalar_view<DataT> prior_clustering_cost,
605606
raft::device_scalar_view<const DataT> sqrd_norm_error,
606607
DataT tol,
607608
int n_iter,
608-
raft::device_scalar_view<int> done_flag)
609+
raft::device_scalar_view<FlagT> done_flag)
609610
{
610611
DataT cur_cost = *clustering_cost.data_handle();
611612
DataT norm_err = *sqrd_norm_error.data_handle();
612-
int done = 0;
613+
FlagT done = FlagT{0};
613614

614615
if (cur_cost != DataT{0} && n_iter > 1) {
615616
DataT delta = cur_cost / *prior_clustering_cost.data_handle();
616-
if (delta > DataT{1} - tol) done = 1;
617+
if (delta > DataT{1} - tol) done = FlagT{1};
617618
}
618-
if (norm_err < tol) done = 1;
619+
if (norm_err < tol) done = FlagT{1};
619620

620621
*prior_clustering_cost.data_handle() = cur_cost;
621622
*done_flag.data_handle() = done;

cpp/src/cluster/detail/kmeans_mg_batched.cuh

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,13 @@ void mnmg_fit(const raft::resources& handle,
332332

333333
std::mt19937 gen(params.rng_state.seed);
334334

335-
auto d_done = raft::make_device_scalar<int64_t>(dev_res, 0);
335+
// On-device convergence state, mirroring single-GPU `detail::fit`.
336+
// The flag is `int64_t` for NCCL allreduce compatibility; SUM>0 means
337+
// any rank converged, so all ranks agree on when to stop even if FP
338+
// non-determinism in compute_centroid_shift makes their local flags differ.
339+
auto d_prior_cost = raft::make_device_scalar<T>(dev_res, T{0});
340+
auto d_done_flag = raft::make_device_scalar<int64_t>(dev_res, 0);
341+
auto h_done_flag = raft::make_pinned_scalar<int64_t>(dev_res, 0);
336342

337343
std::optional<cuvs::spatial::knn::detail::utils::batch_load_iterator<T>> data_batches_opt;
338344
if (has_data) {
@@ -367,9 +373,23 @@ void mnmg_fit(const raft::resources& handle,
367373
raft::matrix::fill(dev_res, batch_weights.view(), T{1});
368374
}
369375

370-
T prior_cluster_cost = T{0};
376+
// Reset per-pass convergence state to avoid leaking it across n_init.
377+
raft::matrix::fill(dev_res, d_prior_cost.view(), T{0});
378+
*h_done_flag.data_handle() = 0;
371379

372380
for (local_n_iter = 1; local_n_iter <= iter_params.max_iter; ++local_n_iter) {
381+
// Consume the previous iteration's allreduced flag from pinned host.
382+
if (local_n_iter > 1) {
383+
raft::resource::sync_stream(dev_res);
384+
if (*h_done_flag.data_handle()) {
385+
--local_n_iter;
386+
RAFT_LOG_DEBUG("SNMG KMeans: threshold triggered after %d iterations on rank %d",
387+
static_cast<int>(local_n_iter),
388+
rank);
389+
break;
390+
}
391+
}
392+
373393
RAFT_LOG_DEBUG("SNMG KMeans: iteration %d on rank %d", local_n_iter, rank);
374394

375395
raft::matrix::fill(dev_res, centroid_sums.view(), T{0});
@@ -454,48 +474,39 @@ void mnmg_fit(const raft::resources& handle,
454474
rank_centroids_const,
455475
new_centroids.view());
456476

457-
// Phase 4: convergence check (synchronized across ranks)
477+
// Phase 4: device-side convergence evaluation. Compute shift,
478+
// run `check_convergence` via `map_offset`, allreduce the flag,
479+
// copy it to the pinned host scalar. The host reads it at the top of the next iteration.
458480
cuvs::cluster::kmeans::detail::compute_centroid_shift<T, IdxT>(
459481
dev_res,
460482
raft::make_const_mdspan(rank_centroids.view()),
461483
raft::make_const_mdspan(new_centroids.view()),
462484
sqrd_norm_error_dev.view());
463-
T sqrdNormError = T{0};
464-
raft::copy(&sqrdNormError, sqrd_norm_error_dev.data_handle(), 1, stream);
465-
raft::resource::sync_stream(dev_res);
466485

467486
raft::copy(
468487
rank_centroids.data_handle(), new_centroids.data_handle(), n_clusters * n_features, stream);
469488

470-
bool done = false;
471-
472-
raft::copy(&local_inertia, clustering_cost.data_handle(), 1, stream);
473-
raft::resource::sync_stream(dev_res);
489+
auto d_cost_view = raft::make_device_scalar_view<const T>(clustering_cost.data_handle());
490+
auto d_prior_view = d_prior_cost.view();
491+
auto d_norm_view = raft::make_device_scalar_view<const T>(sqrd_norm_error_dev.data_handle());
492+
auto d_done_view = d_done_flag.view();
493+
T tol = static_cast<T>(params.tol);
494+
int iter = static_cast<int>(local_n_iter);
474495

475-
if (local_inertia == T{0}) {
476-
RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids.");
477-
} else if (local_n_iter > 1 && prior_cluster_cost > T{0}) {
478-
T delta = local_inertia / prior_cluster_cost;
479-
if (delta > 1 - params.tol) { done = true; }
480-
}
481-
prior_cluster_cost = local_inertia;
482-
483-
if (sqrdNormError < params.tol) { done = true; }
484-
485-
// Allreduce the convergence flag so all ranks agree (FP non-determinism
486-
// in compute_centroid_shift could otherwise diverge ranks and deadlock)
487-
int64_t done_val = done ? 1 : 0;
488-
raft::copy(d_done.data_handle(), &done_val, 1, stream);
489-
SNMG_ALLREDUCE(d_done.data_handle(), d_done.data_handle(), 1);
490-
raft::copy(&done_val, d_done.data_handle(), 1, stream);
491-
raft::resource::sync_stream(dev_res);
492-
done = (done_val > 0);
493-
494-
if (done) {
495-
RAFT_LOG_DEBUG(
496-
"SNMG KMeans: threshold triggered after %d iterations on rank %d", local_n_iter, rank);
497-
break;
498-
}
496+
raft::linalg::map_offset(
497+
dev_res,
498+
raft::make_device_vector_view<int64_t, int>(d_done_flag.data_handle(), 1),
499+
[=] __device__(int) {
500+
cuvs::cluster::kmeans::detail::check_convergence(
501+
d_cost_view, d_prior_view, d_norm_view, tol, iter, d_done_view);
502+
return *d_done_view.data_handle();
503+
});
504+
505+
SNMG_ALLREDUCE(d_done_flag.data_handle(), d_done_flag.data_handle(), 1);
506+
507+
raft::copy(dev_res,
508+
raft::make_pinned_scalar_view(h_done_flag.data_handle()),
509+
raft::make_device_scalar_view<const int64_t>(d_done_flag.data_handle()));
499510
}
500511

501512
// Recompute inertia against the converged centroids

0 commit comments

Comments
 (0)