Skip to content

Commit

Permalink
TEST/GTEST: Added cuda gpu switching testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
rakhmets committed Dec 17, 2024
1 parent 4f0e6c1 commit d31bb61
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 4 deletions.
6 changes: 3 additions & 3 deletions test/gtest/common/mem_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ bool mem_buffer::is_mem_type_supported(ucs_memory_type_t mem_type)
mem_types.end();
}

void mem_buffer::set_device_context()
void mem_buffer::set_device_context(int device)
{
static __thread bool device_set = false;

Expand All @@ -179,7 +179,7 @@ void mem_buffer::set_device_context()

#if HAVE_CUDA
if (is_cuda_supported()) {
cudaSetDevice(0);
cudaSetDevice(device);
/* need to call free as the context may be lazily initialized when calling
* cudaSetDevice(device), but calling cudaFree(0) should guarantee context
* creation upon return */
Expand All @@ -189,7 +189,7 @@ void mem_buffer::set_device_context()

#if HAVE_ROCM
if (is_rocm_supported()) {
hipSetDevice(0);
hipSetDevice(device);
}
#endif

Expand Down
2 changes: 1 addition & 1 deletion test/gtest/common/mem_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class mem_buffer {
static bool is_gpu_supported();

/* set device context if compiled with GPU support */
static void set_device_context();
static void set_device_context(int device = 0);

/* returns whether ROCM device supports managed memory */
static bool is_rocm_managed_supported();
Expand Down
53 changes: 53 additions & 0 deletions test/gtest/ucp/test_ucp_mmap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ extern "C" {
#include <ucs/type/float8.h>
}

#if HAVE_CUDA
#include <cuda_runtime.h>
#endif

#include <cmath>
#include <list>

Expand Down Expand Up @@ -1248,3 +1252,52 @@ UCS_TEST_P(test_ucp_mmap_export, export_import) {
}

UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(test_ucp_mmap_export)

#if HAVE_CUDA
/* Fixture for the multi-GPU memory mapping test; it adds no members of its
 * own on top of ucs::test */
class test_ucp_mmap_mgpu : public ucs::test {
};

/*
 * Checks that a CUDA buffer created while another GPU was current can be
 * registered with UCP after switching back to the original GPU.
 *
 * Flow: create a UCP context, make the next CUDA device current, create a
 * small CUDA mem_buffer, make the original device current again, then expect
 * ucp_mem_map() and ucp_mem_unmap() to return UCS_OK for that buffer.
 *
 * The test is skipped when CUDA memory is not supported or when fewer than
 * two CUDA devices are reported.
 */
UCS_TEST_F(test_ucp_mmap_mgpu, switch_gpu)
{
    /* Makes the saved device current again when the enclosing scope is left,
     * so an early exit from the test body (failed fatal assertion or an
     * exception) does not leave this thread on the other GPU for the tests
     * that run after it */
    struct cuda_device_restorer {
        explicit cuda_device_restorer(int device) : m_device(device)
        {
        }

        ~cuda_device_restorer()
        {
            /* best effort: nothing useful can be done on failure here */
            (void)cudaSetDevice(m_device);
        }

        const int m_device;
    };

    if (!mem_buffer::is_mem_type_supported(UCS_MEMORY_TYPE_CUDA)) {
        UCS_TEST_SKIP_R("cuda is not supported");
    }

    int num_devices = 0;
    ASSERT_EQ(cudaGetDeviceCount(&num_devices), cudaSuccess);

    if (num_devices < 2) {
        UCS_TEST_SKIP_R("less than two cuda devices available");
    }

    ucs::handle<ucp_config_t*> config;
    UCS_TEST_CREATE_HANDLE(ucp_config_t*, config, ucp_config_release,
                           ucp_config_read, NULL, NULL);

    /* Only the features field is filled in, as declared by field_mask */
    ucs::handle<ucp_context_h> context;
    ucp_params_t params;
    params.field_mask = UCP_PARAM_FIELD_FEATURES;
    params.features   = UCP_FEATURE_TAG;
    UCS_TEST_CREATE_HANDLE(ucp_context_h, context, ucp_cleanup, ucp_init,
                           &params, config.get());

    int device = 0;
    ASSERT_EQ(cudaGetDevice(&device), cudaSuccess);

    /* Declared before the switch and before the buffer: it is destroyed
     * after the buffer, so the buffer is still released with the original
     * device current on the normal path */
    cuda_device_restorer restore_device(device);

    /* Switch to the next device (wrapping around) for the allocation */
    ASSERT_EQ(cudaSetDevice((device + 1) % num_devices), cudaSuccess);

    const size_t size = 16;
    mem_buffer buffer(size, UCS_MEMORY_TYPE_CUDA);

    /* Switch back explicitly so the mapping below runs with the original
     * device current, and so a failure to switch is reported */
    ASSERT_EQ(cudaSetDevice(device), cudaSuccess);

    ucp_mem_map_params_t mem_map_params;
    mem_map_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
                                UCP_MEM_MAP_PARAM_FIELD_LENGTH;
    mem_map_params.address    = buffer.ptr();
    mem_map_params.length     = size;

    ucp_mem_h ucp_mem = NULL;
    ASSERT_EQ(ucp_mem_map(context.get(), &mem_map_params, &ucp_mem), UCS_OK);
    EXPECT_EQ(ucp_mem_unmap(context.get(), ucp_mem), UCS_OK);
}
#endif

0 comments on commit d31bb61

Please sign in to comment.