77#include " kmeans_test_blobs.cuh"
88
99#include < cuvs/cluster/kmeans.hpp>
10+ #include < raft/common/nccl_macros.hpp>
11+ #include < raft/comms/std_comms.hpp>
12+ #include < raft/core/device_resources.hpp>
1013#include < raft/core/device_resources_snmg.hpp>
1114#include < raft/core/operators.hpp>
1215#include < raft/core/resource/cuda_stream.hpp>
1316#include < raft/core/resource/multi_gpu.hpp>
14- #include < raft/core/resource/nccl_comm.hpp>
1517#include < raft/core/resources.hpp>
1618#include < raft/stats/adjusted_rand_index.cuh>
1719#include < raft/util/cuda_utils.cuh>
2123
2224#include < cuda_runtime.h>
2325#include < gtest/gtest.h>
26+ #include < nccl.h>
2427#include < omp.h>
2528
2629#include < algorithm>
2730#include < cmath>
2831#include < cstdint>
32+ #include < memory>
2933#include < optional>
3034#include < vector>
3135
@@ -36,7 +40,7 @@ namespace {
3640constexpr int kMaxRanksForNcclTest = 4 ;
3741
3842template <typename T>
39- int run_mg_fit_omp (raft::device_resources_snmg& clique ,
43+ int run_mg_fit_omp (const std::vector< int >& device_ids ,
4044 const cuvs::cluster::kmeans::params& kp,
4145 const T* h_X,
4246 const std::vector<T>* h_w,
@@ -50,8 +54,38 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
5054 T& out_inertia,
5155 int64_t & out_n_iter)
5256{
53- const int num_ranks = raft::resource::get_num_ranks (clique);
54- raft::resource::get_nccl_comms (clique);
57+ const int num_ranks = static_cast <int >(device_ids.size ());
58+
59+ int current_device = 0 ;
60+ RAFT_CUDA_TRY (cudaGetDevice (¤t_device));
61+
62+ std::vector<std::unique_ptr<raft::device_resources>> rank_resources;
63+ rank_resources.reserve (static_cast <size_t >(num_ranks));
64+ for (int r = 0 ; r < num_ranks; ++r) {
65+ RAFT_CUDA_TRY (cudaSetDevice (device_ids[static_cast <size_t >(r)]));
66+ rank_resources.push_back (std::make_unique<raft::device_resources>());
67+ }
68+
69+ std::vector<ncclComm_t> nccl_comms (static_cast <size_t >(num_ranks), nullptr );
70+ ncclUniqueId nccl_id;
71+ RAFT_NCCL_TRY (ncclGetUniqueId (&nccl_id));
72+ RAFT_NCCL_TRY (ncclGroupStart ());
73+ for (int r = 0 ; r < num_ranks; ++r) {
74+ RAFT_CUDA_TRY (cudaSetDevice (device_ids[static_cast <size_t >(r)]));
75+ RAFT_NCCL_TRY (ncclCommInitRank (&nccl_comms[static_cast <size_t >(r)], num_ranks, nccl_id, r));
76+ }
77+ RAFT_NCCL_TRY (ncclGroupEnd ());
78+
79+ for (int r = 0 ; r < num_ranks; ++r) {
80+ RAFT_CUDA_TRY (cudaSetDevice (device_ids[static_cast <size_t >(r)]));
81+ raft::comms::build_comms_nccl_only (rank_resources[static_cast <size_t >(r)].get (),
82+ nccl_comms[static_cast <size_t >(r)],
83+ num_ranks,
84+ r);
85+ }
86+
87+ RAFT_CUDA_TRY (cudaSetDevice (current_device));
88+
5589 partitions_per_rank = std::max (1 , partitions_per_rank);
5690 out_h_centroids.assign (static_cast <size_t >(n_clusters) * n_features, T{0 });
5791 T inertia = T{0 };
@@ -65,8 +99,9 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
6599 actual_threads = omp_get_num_threads ();
66100 }
67101 if (actual_threads == num_ranks) {
68- const int r = omp_get_thread_num ();
69- auto const & rank_res = raft::resource::set_current_device_to_rank (clique, r);
102+ const int r = omp_get_thread_num ();
103+ RAFT_CUDA_TRY (cudaSetDevice (device_ids[static_cast <size_t >(r)]));
104+ auto const & rank_res = *rank_resources[static_cast <size_t >(r)];
70105 auto rank_stream = raft::resource::get_cuda_stream (rank_res);
71106
72107 const int base = n_samples / num_ranks;
@@ -137,7 +172,7 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
137172 }
138173 }
139174
140- cuvs::cluster::kmeans::mg::fit (clique ,
175+ cuvs::cluster::kmeans::mg::fit (rank_res ,
141176 kp,
142177 X_parts,
143178 sw_parts,
@@ -188,7 +223,7 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
188223 }
189224 }
190225
191- cuvs::cluster::kmeans::mg::fit (clique ,
226+ cuvs::cluster::kmeans::mg::fit (rank_res ,
192227 kp,
193228 X_parts,
194229 sw_parts,
@@ -197,8 +232,11 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
197232 raft::make_host_scalar_view (&local_n_iter));
198233 }
199234
235+ // Ensure all ranks have completed the fit before writing outputs.
236+ raft::resource::sync_stream (rank_res);
237+ #pragma omp barrier
200238 if (r == 0 ) {
201- // mnmg_fit writes outputs only on rank 0 .
239+ // Copy rank 0's outputs for comparison .
202240 raft::update_host (
203241 out_h_centroids.data (), d_rank_centroids.data (), out_h_centroids.size (), rank_stream);
204242 raft::resource::sync_stream (rank_res);
@@ -208,6 +246,21 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
208246 }
209247 }
210248
249+ for (int r = 0 ; r < num_ranks; ++r) {
250+ RAFT_CUDA_TRY (cudaSetDevice (device_ids[static_cast <size_t >(r)]));
251+ rank_resources[static_cast <size_t >(r)].reset ();
252+ }
253+ rank_resources.clear ();
254+
255+ RAFT_NCCL_TRY (ncclGroupStart ());
256+ for (int r = 0 ; r < num_ranks; ++r) {
257+ RAFT_CUDA_TRY (cudaSetDevice (device_ids[static_cast <size_t >(r)]));
258+ auto comm = nccl_comms[static_cast <size_t >(r)];
259+ if (comm != nullptr ) { RAFT_NCCL_TRY (ncclCommDestroy (comm)); }
260+ }
261+ RAFT_NCCL_TRY (ncclGroupEnd ());
262+ RAFT_CUDA_TRY (cudaSetDevice (current_device));
263+
211264 out_inertia = inertia;
212265 out_n_iter = n_iter;
213266 return actual_threads;
@@ -568,9 +621,9 @@ struct KmeansMGNcclInputs {
568621template <typename T>
569622class KmeansMGNcclTest : public ::testing::TestWithParam<KmeansMGNcclInputs<T>> {
570623 protected:
571- KmeansMGNcclTest () : clique_(make_clique_device_ids ()) { clique_. set_memory_pool ( 50 ); }
624+ KmeansMGNcclTest () : device_ids_(make_nccl_test_device_ids ()) {}
572625
573- static std::vector<int > make_clique_device_ids ()
626+ static std::vector<int > make_nccl_test_device_ids ()
574627 {
575628 int num_devices = 0 ;
576629 RAFT_CUDA_TRY (cudaGetDeviceCount (&num_devices));
@@ -586,11 +639,9 @@ class KmeansMGNcclTest : public ::testing::TestWithParam<KmeansMGNcclInputs<T>>
586639 {
587640 testparams_ = ::testing::TestWithParam<KmeansMGNcclInputs<T>>::GetParam ();
588641
589- const int num_ranks = raft::resource::get_num_ranks (clique_ );
642+ const int num_ranks = static_cast < int >(device_ids_. size () );
590643 if (num_ranks < 1 ) { GTEST_SKIP () << " No CUDA devices available." ; }
591644
592- raft::resource::get_nccl_comms (clique_);
593-
594645 const int n_samples = testparams_.n_row ;
595646 const int n_features = testparams_.n_col ;
596647 const int n_clusters = testparams_.n_clusters ;
@@ -653,7 +704,7 @@ class KmeansMGNcclTest : public ::testing::TestWithParam<KmeansMGNcclInputs<T>>
653704 const std::vector<T>* h_w_ptr = has_weights ? &h_w : nullptr ;
654705 const std::vector<T>* h_init_ptr =
655706 testparams_.init == cuvs::cluster::kmeans::params::Array ? &h_initial_centroids : nullptr ;
656- const int actual_threads = run_mg_fit_omp<T>(clique_ ,
707+ const int actual_threads = run_mg_fit_omp<T>(device_ids_ ,
657708 kp,
658709 h_X,
659710 h_w_ptr,
@@ -786,7 +837,7 @@ class KmeansMGNcclTest : public ::testing::TestWithParam<KmeansMGNcclInputs<T>>
786837 }
787838 }
788839
789- raft::device_resources_snmg clique_ ;
840+ std::vector< int > device_ids_ ;
790841 KmeansMGNcclInputs<T> testparams_;
791842 double ari_vs_ref_ = 0 ;
792843 double ari_vs_sg_ = 0 ;
@@ -928,9 +979,9 @@ INSTANTIATE_TEST_SUITE_P(KmeansMGNcclTests,
928979template <typename T>
929980class KmeansMGOversamplingTest : public ::testing::Test {
930981 protected:
931- KmeansMGOversamplingTest () : clique_(make_clique_device_ids ()) { clique_. set_memory_pool ( 50 ); }
982+ KmeansMGOversamplingTest () : device_ids_(make_nccl_test_device_ids ()) {}
932983
933- static std::vector<int > make_clique_device_ids ()
984+ static std::vector<int > make_nccl_test_device_ids ()
934985 {
935986 int num_devices = 0 ;
936987 RAFT_CUDA_TRY (cudaGetDeviceCount (&num_devices));
@@ -944,7 +995,7 @@ class KmeansMGOversamplingTest : public ::testing::Test {
944995
945996 void run_test_body ()
946997 {
947- const int num_ranks = raft::resource::get_num_ranks (clique_ );
998+ const int num_ranks = static_cast < int >(device_ids_. size () );
948999 if (num_ranks < 1 ) { GTEST_SKIP () << " No CUDA devices available." ; }
9491000
9501001 constexpr int n_samples = 2000 ;
@@ -995,7 +1046,7 @@ class KmeansMGOversamplingTest : public ::testing::Test {
9951046 kp.oversampling_factor = oversampling_factor;
9961047
9971048 std::vector<T> h_centroids;
998- const int actual_threads = run_mg_fit_omp<T>(clique_ ,
1049+ const int actual_threads = run_mg_fit_omp<T>(device_ids_ ,
9991050 kp,
10001051 h_X.data (),
10011052 /* h_w=*/ nullptr ,
@@ -1008,10 +1059,10 @@ class KmeansMGOversamplingTest : public ::testing::Test {
10081059 h_centroids,
10091060 inertia,
10101061 n_iter);
1011- ASSERT_EQ (actual_threads, raft::resource::get_num_ranks (clique_ ));
1062+ ASSERT_EQ (actual_threads, static_cast < int >(device_ids_. size () ));
10121063 }
10131064
1014- raft::device_resources_snmg clique_ ;
1065+ std::vector< int > device_ids_ ;
10151066};
10161067
10171068typedef KmeansMGOversamplingTest<float > KmeansMGOversamplingTestF;
0 commit comments