@@ -332,7 +332,13 @@ void mnmg_fit(const raft::resources& handle,
332332
333333 std::mt19937 gen (params.rng_state .seed );
334334
335- auto d_done = raft::make_device_scalar<int64_t >(dev_res, 0 );
335+ // On-device convergence state, mirroring single-GPU `detail::fit`.
336+ // The flag is `int64_t` for NCCL allreduce compatibility; SUM>0 means
337+ // any rank converged, which guards against FP non-determinism in
338+ // compute_centroid_shift diverging ranks.
339+ auto d_prior_cost = raft::make_device_scalar<T>(dev_res, T{0 });
340+ auto d_done_flag = raft::make_device_scalar<int64_t >(dev_res, 0 );
341+ auto h_done_flag = raft::make_pinned_scalar<int64_t >(dev_res, 0 );
336342
337343 std::optional<cuvs::spatial::knn::detail::utils::batch_load_iterator<T>> data_batches_opt;
338344 if (has_data) {
@@ -367,9 +373,23 @@ void mnmg_fit(const raft::resources& handle,
367373 raft::matrix::fill (dev_res, batch_weights.view (), T{1 });
368374 }
369375
370- T prior_cluster_cost = T{0 };
376+ // Reset per-pass convergence state to avoid leaking it across n_init.
377+ raft::matrix::fill (dev_res, d_prior_cost.view (), T{0 });
378+ *h_done_flag.data_handle () = 0 ;
371379
372380 for (local_n_iter = 1 ; local_n_iter <= iter_params.max_iter ; ++local_n_iter) {
381+ // Consume the previous iteration's allreduced flag from pinned host.
382+ if (local_n_iter > 1 ) {
383+ raft::resource::sync_stream (dev_res);
384+ if (*h_done_flag.data_handle ()) {
385+ --local_n_iter;
386+ RAFT_LOG_DEBUG (" SNMG KMeans: threshold triggered after %d iterations on rank %d" ,
387+ static_cast <int >(local_n_iter),
388+ rank);
389+ break ;
390+ }
391+ }
392+
373393 RAFT_LOG_DEBUG (" SNMG KMeans: iteration %d on rank %d" , local_n_iter, rank);
374394
375395 raft::matrix::fill (dev_res, centroid_sums.view (), T{0 });
@@ -454,48 +474,39 @@ void mnmg_fit(const raft::resources& handle,
454474 rank_centroids_const,
455475 new_centroids.view ());
456476
457- // Phase 4: convergence check (synchronized across ranks)
477+ // Phase 4: device-side convergence evaluation. Compute shift,
478+ // run `check_convergence` via `map_offset`, allreduce the flag,
479+ // shadow into pinned host. Consumed at top of next iteration.
458480 cuvs::cluster::kmeans::detail::compute_centroid_shift<T, IdxT>(
459481 dev_res,
460482 raft::make_const_mdspan (rank_centroids.view ()),
461483 raft::make_const_mdspan (new_centroids.view ()),
462484 sqrd_norm_error_dev.view ());
463- T sqrdNormError = T{0 };
464- raft::copy (&sqrdNormError, sqrd_norm_error_dev.data_handle (), 1 , stream);
465- raft::resource::sync_stream (dev_res);
466485
467486 raft::copy (
468487 rank_centroids.data_handle (), new_centroids.data_handle (), n_clusters * n_features, stream);
469488
470- bool done = false ;
471-
472- raft::copy (&local_inertia, clustering_cost.data_handle (), 1 , stream);
473- raft::resource::sync_stream (dev_res);
489+ auto d_cost_view = raft::make_device_scalar_view<const T>(clustering_cost.data_handle ());
490+ auto d_prior_view = d_prior_cost.view ();
491+ auto d_norm_view = raft::make_device_scalar_view<const T>(sqrd_norm_error_dev.data_handle ());
492+ auto d_done_view = d_done_flag.view ();
493+ T tol = static_cast <T>(params.tol );
494+ int iter = static_cast <int >(local_n_iter);
474495
475- if (local_inertia == T{0 }) {
476- RAFT_LOG_WARN (" Zero clustering cost detected: all points coincide with their centroids." );
477- } else if (local_n_iter > 1 && prior_cluster_cost > T{0 }) {
478- T delta = local_inertia / prior_cluster_cost;
479- if (delta > 1 - params.tol ) { done = true ; }
480- }
481- prior_cluster_cost = local_inertia;
482-
483- if (sqrdNormError < params.tol ) { done = true ; }
484-
485- // Allreduce the convergence flag so all ranks agree (FP non-determinism
486- // in compute_centroid_shift could otherwise diverge ranks and deadlock)
487- int64_t done_val = done ? 1 : 0 ;
488- raft::copy (d_done.data_handle (), &done_val, 1 , stream);
489- SNMG_ALLREDUCE (d_done.data_handle (), d_done.data_handle (), 1 );
490- raft::copy (&done_val, d_done.data_handle (), 1 , stream);
491- raft::resource::sync_stream (dev_res);
492- done = (done_val > 0 );
493-
494- if (done) {
495- RAFT_LOG_DEBUG (
496- " SNMG KMeans: threshold triggered after %d iterations on rank %d" , local_n_iter, rank);
497- break ;
498- }
496+ raft::linalg::map_offset (
497+ dev_res,
498+ raft::make_device_vector_view<int64_t , int >(d_done_flag.data_handle (), 1 ),
499+ [=] __device__ (int ) {
500+ cuvs::cluster::kmeans::detail::check_convergence (
501+ d_cost_view, d_prior_view, d_norm_view, tol, iter, d_done_view);
502+ return *d_done_view.data_handle ();
503+ });
504+
505+ SNMG_ALLREDUCE (d_done_flag.data_handle (), d_done_flag.data_handle (), 1 );
506+
507+ raft::copy (dev_res,
508+ raft::make_pinned_scalar_view (h_done_flag.data_handle ()),
509+ raft::make_device_scalar_view<const int64_t >(d_done_flag.data_handle ()));
499510 }
500511
501512 // Recompute inertia against the converged centroids
0 commit comments