Skip to content

Commit 8ce1ade

Browse files
committed
fix mnmg tests
1 parent 6ada5f6 commit 8ce1ade

3 files changed

Lines changed: 91 additions & 35 deletions

File tree

cpp/src/cluster/detail/kmeans_mg_distributed_init.cuh

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -234,7 +234,7 @@ void initKMeansPlusPlus_distributed(
234234
auto potentialCentroids =
235235
raft::make_device_matrix_view<DataT, IndexT>(centroidsBuf.data(), IndexT{1}, n_features);
236236
DataT* initialCentroid = centroidsBuf.data();
237-
if (rank == rp) {
237+
if (rank == rp) {
238238
RAFT_EXPECTS(n_local > 0,
239239
"selected source rank %d has no local rows; cannot pick an initial centroid",
240240
rp);
@@ -441,15 +441,13 @@ void initKMeansPlusPlus_distributed(
441441
recluster_params.n_init = 1;
442442

443443
auto weight_opt = std::make_optional(raft::make_const_mdspan(weight.view()));
444-
cuvs::cluster::kmeans::detail::kmeans_fit<DataT, IndexT>(
445-
handle,
446-
recluster_params,
447-
raft::make_const_mdspan(potentialCentroids),
448-
weight_opt,
449-
centroidsRawData,
450-
inertia_out.view(),
451-
n_iter_out.view(),
452-
std::ref(workspace));
444+
cuvs::cluster::kmeans::fit(handle,
445+
recluster_params,
446+
raft::make_const_mdspan(potentialCentroids),
447+
weight_opt,
448+
centroidsRawData,
449+
inertia_out.view(),
450+
n_iter_out.view());
453451

454452
} else if (static_cast<IndexT>(potentialCentroids.extent(0)) < n_clusters) {
455453
const IndexT n_random = n_clusters - static_cast<IndexT>(potentialCentroids.extent(0));

cpp/tests/CMakeLists.txt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ rapids_test_init()
1515
function(ConfigureTest)
1616

1717
set(options NOCUDA FETCH_CODEBOOKS)
18-
set(oneValueArgs NAME GPUS PERCENT ADDITIONAL_DEP)
19-
set(multiValueArgs PATH)
18+
set(oneValueArgs NAME GPUS PERCENT)
19+
set(multiValueArgs PATH ADDITIONAL_DEP)
2020

2121
cmake_parse_arguments(_CUVS_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
2222
if(NOT DEFINED _CUVS_TEST_GPUS AND NOT DEFINED _CUVS_TEST_PERCENT)
@@ -332,6 +332,9 @@ if(BUILD_CAGRA_HNSWLIB)
332332
endif()
333333

334334
if(BUILD_MG_ALGOS)
335+
find_package(ucx REQUIRED)
336+
find_package(ucxx REQUIRED)
337+
335338
ConfigureTest(
336339
NAME NEIGHBORS_MG_TEST
337340
PATH neighbors/mg/test_float.cu
@@ -345,7 +348,11 @@ if(BUILD_MG_ALGOS)
345348
PATH cluster/kmeans_mg.cu
346349
GPUS 2
347350
PERCENT 100
348-
ADDITIONAL_DEP NCCL::NCCL
351+
ADDITIONAL_DEP
352+
NCCL::NCCL
353+
ucx::ucp
354+
ucx::ucs
355+
ucxx::ucxx
349356
)
350357
endif()
351358

cpp/tests/cluster/kmeans_mg.cu

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
#include "kmeans_test_blobs.cuh"
88

99
#include <cuvs/cluster/kmeans.hpp>
10+
#include <raft/common/nccl_macros.hpp>
11+
#include <raft/comms/std_comms.hpp>
12+
#include <raft/core/device_resources.hpp>
1013
#include <raft/core/device_resources_snmg.hpp>
1114
#include <raft/core/operators.hpp>
1215
#include <raft/core/resource/cuda_stream.hpp>
1316
#include <raft/core/resource/multi_gpu.hpp>
14-
#include <raft/core/resource/nccl_comm.hpp>
1517
#include <raft/core/resources.hpp>
1618
#include <raft/stats/adjusted_rand_index.cuh>
1719
#include <raft/util/cuda_utils.cuh>
@@ -21,11 +23,13 @@
2123

2224
#include <cuda_runtime.h>
2325
#include <gtest/gtest.h>
26+
#include <nccl.h>
2427
#include <omp.h>
2528

2629
#include <algorithm>
2730
#include <cmath>
2831
#include <cstdint>
32+
#include <memory>
2933
#include <optional>
3034
#include <vector>
3135

@@ -36,7 +40,7 @@ namespace {
3640
constexpr int kMaxRanksForNcclTest = 4;
3741

3842
template <typename T>
39-
int run_mg_fit_omp(raft::device_resources_snmg& clique,
43+
int run_mg_fit_omp(const std::vector<int>& device_ids,
4044
const cuvs::cluster::kmeans::params& kp,
4145
const T* h_X,
4246
const std::vector<T>* h_w,
@@ -50,8 +54,38 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
5054
T& out_inertia,
5155
int64_t& out_n_iter)
5256
{
53-
const int num_ranks = raft::resource::get_num_ranks(clique);
54-
raft::resource::get_nccl_comms(clique);
57+
const int num_ranks = static_cast<int>(device_ids.size());
58+
59+
int current_device = 0;
60+
RAFT_CUDA_TRY(cudaGetDevice(&current_device));
61+
62+
std::vector<std::unique_ptr<raft::device_resources>> rank_resources;
63+
rank_resources.reserve(static_cast<size_t>(num_ranks));
64+
for (int r = 0; r < num_ranks; ++r) {
65+
RAFT_CUDA_TRY(cudaSetDevice(device_ids[static_cast<size_t>(r)]));
66+
rank_resources.push_back(std::make_unique<raft::device_resources>());
67+
}
68+
69+
std::vector<ncclComm_t> nccl_comms(static_cast<size_t>(num_ranks), nullptr);
70+
ncclUniqueId nccl_id;
71+
RAFT_NCCL_TRY(ncclGetUniqueId(&nccl_id));
72+
RAFT_NCCL_TRY(ncclGroupStart());
73+
for (int r = 0; r < num_ranks; ++r) {
74+
RAFT_CUDA_TRY(cudaSetDevice(device_ids[static_cast<size_t>(r)]));
75+
RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comms[static_cast<size_t>(r)], num_ranks, nccl_id, r));
76+
}
77+
RAFT_NCCL_TRY(ncclGroupEnd());
78+
79+
for (int r = 0; r < num_ranks; ++r) {
80+
RAFT_CUDA_TRY(cudaSetDevice(device_ids[static_cast<size_t>(r)]));
81+
raft::comms::build_comms_nccl_only(rank_resources[static_cast<size_t>(r)].get(),
82+
nccl_comms[static_cast<size_t>(r)],
83+
num_ranks,
84+
r);
85+
}
86+
87+
RAFT_CUDA_TRY(cudaSetDevice(current_device));
88+
5589
partitions_per_rank = std::max(1, partitions_per_rank);
5690
out_h_centroids.assign(static_cast<size_t>(n_clusters) * n_features, T{0});
5791
T inertia = T{0};
@@ -65,8 +99,9 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
6599
actual_threads = omp_get_num_threads();
66100
}
67101
if (actual_threads == num_ranks) {
68-
const int r = omp_get_thread_num();
69-
auto const& rank_res = raft::resource::set_current_device_to_rank(clique, r);
102+
const int r = omp_get_thread_num();
103+
RAFT_CUDA_TRY(cudaSetDevice(device_ids[static_cast<size_t>(r)]));
104+
auto const& rank_res = *rank_resources[static_cast<size_t>(r)];
70105
auto rank_stream = raft::resource::get_cuda_stream(rank_res);
71106

72107
const int base = n_samples / num_ranks;
@@ -137,7 +172,7 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
137172
}
138173
}
139174

140-
cuvs::cluster::kmeans::mg::fit(clique,
175+
cuvs::cluster::kmeans::mg::fit(rank_res,
141176
kp,
142177
X_parts,
143178
sw_parts,
@@ -188,7 +223,7 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
188223
}
189224
}
190225

191-
cuvs::cluster::kmeans::mg::fit(clique,
226+
cuvs::cluster::kmeans::mg::fit(rank_res,
192227
kp,
193228
X_parts,
194229
sw_parts,
@@ -197,8 +232,11 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
197232
raft::make_host_scalar_view(&local_n_iter));
198233
}
199234

235+
// Ensure all ranks have completed the fit before writing outputs.
236+
raft::resource::sync_stream(rank_res);
237+
#pragma omp barrier
200238
if (r == 0) {
201-
// mnmg_fit writes outputs only on rank 0.
239+
// Copy rank 0's outputs for comparison.
202240
raft::update_host(
203241
out_h_centroids.data(), d_rank_centroids.data(), out_h_centroids.size(), rank_stream);
204242
raft::resource::sync_stream(rank_res);
@@ -208,6 +246,21 @@ int run_mg_fit_omp(raft::device_resources_snmg& clique,
208246
}
209247
}
210248

249+
for (int r = 0; r < num_ranks; ++r) {
250+
RAFT_CUDA_TRY(cudaSetDevice(device_ids[static_cast<size_t>(r)]));
251+
rank_resources[static_cast<size_t>(r)].reset();
252+
}
253+
rank_resources.clear();
254+
255+
RAFT_NCCL_TRY(ncclGroupStart());
256+
for (int r = 0; r < num_ranks; ++r) {
257+
RAFT_CUDA_TRY(cudaSetDevice(device_ids[static_cast<size_t>(r)]));
258+
auto comm = nccl_comms[static_cast<size_t>(r)];
259+
if (comm != nullptr) { RAFT_NCCL_TRY(ncclCommDestroy(comm)); }
260+
}
261+
RAFT_NCCL_TRY(ncclGroupEnd());
262+
RAFT_CUDA_TRY(cudaSetDevice(current_device));
263+
211264
out_inertia = inertia;
212265
out_n_iter = n_iter;
213266
return actual_threads;
@@ -568,9 +621,9 @@ struct KmeansMGNcclInputs {
568621
template <typename T>
569622
class KmeansMGNcclTest : public ::testing::TestWithParam<KmeansMGNcclInputs<T>> {
570623
protected:
571-
KmeansMGNcclTest() : clique_(make_clique_device_ids()) { clique_.set_memory_pool(50); }
624+
KmeansMGNcclTest() : device_ids_(make_nccl_test_device_ids()) {}
572625

573-
static std::vector<int> make_clique_device_ids()
626+
static std::vector<int> make_nccl_test_device_ids()
574627
{
575628
int num_devices = 0;
576629
RAFT_CUDA_TRY(cudaGetDeviceCount(&num_devices));
@@ -586,11 +639,9 @@ class KmeansMGNcclTest : public ::testing::TestWithParam<KmeansMGNcclInputs<T>>
586639
{
587640
testparams_ = ::testing::TestWithParam<KmeansMGNcclInputs<T>>::GetParam();
588641

589-
const int num_ranks = raft::resource::get_num_ranks(clique_);
642+
const int num_ranks = static_cast<int>(device_ids_.size());
590643
if (num_ranks < 1) { GTEST_SKIP() << "No CUDA devices available."; }
591644

592-
raft::resource::get_nccl_comms(clique_);
593-
594645
const int n_samples = testparams_.n_row;
595646
const int n_features = testparams_.n_col;
596647
const int n_clusters = testparams_.n_clusters;
@@ -653,7 +704,7 @@ class KmeansMGNcclTest : public ::testing::TestWithParam<KmeansMGNcclInputs<T>>
653704
const std::vector<T>* h_w_ptr = has_weights ? &h_w : nullptr;
654705
const std::vector<T>* h_init_ptr =
655706
testparams_.init == cuvs::cluster::kmeans::params::Array ? &h_initial_centroids : nullptr;
656-
const int actual_threads = run_mg_fit_omp<T>(clique_,
707+
const int actual_threads = run_mg_fit_omp<T>(device_ids_,
657708
kp,
658709
h_X,
659710
h_w_ptr,
@@ -786,7 +837,7 @@ class KmeansMGNcclTest : public ::testing::TestWithParam<KmeansMGNcclInputs<T>>
786837
}
787838
}
788839

789-
raft::device_resources_snmg clique_;
840+
std::vector<int> device_ids_;
790841
KmeansMGNcclInputs<T> testparams_;
791842
double ari_vs_ref_ = 0;
792843
double ari_vs_sg_ = 0;
@@ -928,9 +979,9 @@ INSTANTIATE_TEST_SUITE_P(KmeansMGNcclTests,
928979
template <typename T>
929980
class KmeansMGOversamplingTest : public ::testing::Test {
930981
protected:
931-
KmeansMGOversamplingTest() : clique_(make_clique_device_ids()) { clique_.set_memory_pool(50); }
982+
KmeansMGOversamplingTest() : device_ids_(make_nccl_test_device_ids()) {}
932983

933-
static std::vector<int> make_clique_device_ids()
984+
static std::vector<int> make_nccl_test_device_ids()
934985
{
935986
int num_devices = 0;
936987
RAFT_CUDA_TRY(cudaGetDeviceCount(&num_devices));
@@ -944,7 +995,7 @@ class KmeansMGOversamplingTest : public ::testing::Test {
944995

945996
void run_test_body()
946997
{
947-
const int num_ranks = raft::resource::get_num_ranks(clique_);
998+
const int num_ranks = static_cast<int>(device_ids_.size());
948999
if (num_ranks < 1) { GTEST_SKIP() << "No CUDA devices available."; }
9491000

9501001
constexpr int n_samples = 2000;
@@ -995,7 +1046,7 @@ class KmeansMGOversamplingTest : public ::testing::Test {
9951046
kp.oversampling_factor = oversampling_factor;
9961047

9971048
std::vector<T> h_centroids;
998-
const int actual_threads = run_mg_fit_omp<T>(clique_,
1049+
const int actual_threads = run_mg_fit_omp<T>(device_ids_,
9991050
kp,
10001051
h_X.data(),
10011052
/*h_w=*/nullptr,
@@ -1008,10 +1059,10 @@ class KmeansMGOversamplingTest : public ::testing::Test {
10081059
h_centroids,
10091060
inertia,
10101061
n_iter);
1011-
ASSERT_EQ(actual_threads, raft::resource::get_num_ranks(clique_));
1062+
ASSERT_EQ(actual_threads, static_cast<int>(device_ids_.size()));
10121063
}
10131064

1014-
raft::device_resources_snmg clique_;
1065+
std::vector<int> device_ids_;
10151066
};
10161067

10171068
typedef KmeansMGOversamplingTest<float> KmeansMGOversamplingTestF;

0 commit comments

Comments
 (0)