Skip to content

Commit

Permalink
Use pluto in atlas where possible instead of hic calls
Browse files Browse the repository at this point in the history
  • Loading branch information
wdeconinck committed Feb 27, 2025
1 parent 77a9e3f commit 19343bd
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 223 deletions.
1 change: 1 addition & 0 deletions src/atlas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,7 @@ ecbuild_add_library( TARGET atlas
atlas_io
hic
hicsparse
pluto
$<${atlas_HAVE_EIGEN}:Eigen3::Eigen>
$<${atlas_HAVE_OMP_CXX}:OpenMP::OpenMP_CXX>
$<${atlas_HAVE_GRIDTOOLS_STORAGE}:GridTools::gridtools>
Expand Down
165 changes: 52 additions & 113 deletions src/atlas/array/native/NativeDataStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include <limits> // std::numeric_limits<T>::signaling_NaN
#include <sstream>

#include "pluto/pluto.h"

#include "atlas/array/ArrayDataStore.h"
#include "atlas/library/Library.h"
#include "atlas/library/config.h"
Expand All @@ -24,9 +26,6 @@
#include "atlas/runtime/Log.h"
#include "eckit/log/Bytes.h"

#include "hic/hic.h"


#define ATLAS_ACC_DEBUG 0

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -94,26 +93,15 @@ template <typename Value>
void initialise(Value[], size_t) {}
#endif

static int devices() {
static int devices_ = [](){
int n = 0;
auto err = hicGetDeviceCount(&n);
if (err != hicSuccess) {
n = 0;
static_cast<void>(hicGetLastError());
}
return n;
}();
return devices_;
}

template <typename Value>
class DataStore : public ArrayDataStore {
public:
DataStore(size_t size): size_(size) {
DataStore(size_t size): size_(size),
host_allocator_{pluto::new_delete_resource()},
device_allocator_{pluto::device_resource()} {
allocateHost();
initialise(host_data_, size_);
if (ATLAS_HAVE_GPU && devices()) {
if (ATLAS_HAVE_GPU && pluto::devices()) {
device_updated_ = false;
}
else {
Expand All @@ -127,25 +115,19 @@ class DataStore : public ArrayDataStore {
}

void updateDevice() const override {
if (ATLAS_HAVE_GPU && devices()) {
if (ATLAS_HAVE_GPU && pluto::devices()) {
if (not device_allocated_) {
allocateDevice();
}
hicError_t err = hicMemcpy(device_data_, host_data_, size_*sizeof(Value), hicMemcpyHostToDevice);
if (err != hicSuccess) {
throw_AssertionFailed("Failed to updateDevice: "+std::string(hicGetErrorString(err)), Here());
}
pluto::copy_host_to_device(device_data_, host_data_, size_);
device_updated_ = true;
}
}

void updateHost() const override {
if constexpr (ATLAS_HAVE_GPU) {
if (device_allocated_) {
hicError_t err = hicMemcpy(host_data_, device_data_, size_*sizeof(Value), hicMemcpyDeviceToHost);
if (err != hicSuccess) {
throw_AssertionFailed("Failed to updateHost: "+std::string(hicGetErrorString(err)), Here());
}
pluto::copy_device_to_host(host_data_, device_data_, size_);
host_updated_ = true;
}
}
Expand Down Expand Up @@ -174,32 +156,24 @@ class DataStore : public ArrayDataStore {
bool deviceAllocated() const override { return device_allocated_; }

void allocateDevice() const override {
if (ATLAS_HAVE_GPU && devices()) {
if (ATLAS_HAVE_GPU && pluto::devices()) {
if (device_allocated_) {
return;
}
if (size_) {
hicError_t err = hicMalloc((void**)&device_data_, sizeof(Value)*size_);
if (err != hicSuccess) {
throw_AssertionFailed("Failed to allocate GPU memory: " + std::string(hicGetErrorString(err)), Here());
}
device_data_ = device_allocator_.allocate(size_);
device_allocated_ = true;
accMap();
}
}
}

void deallocateDevice() const override {
if constexpr (ATLAS_HAVE_GPU) {
if (device_allocated_) {
accUnmap();
hicError_t err = hicFree(device_data_);
if (err != hicSuccess) {
throw_AssertionFailed("Failed to deallocate GPU memory: " + std::string(hicGetErrorString(err)), Here());
}
device_data_ = nullptr;
device_allocated_ = false;
}
if (device_allocated_) {
accUnmap();
device_allocator_.deallocate(device_data_,size_);
device_data_ = nullptr;
device_allocated_ = false;
}
}

Expand Down Expand Up @@ -259,36 +233,22 @@ class DataStore : public ArrayDataStore {
throw_Exception(ss.str(), loc);
}

void alloc_aligned(Value*& ptr, size_t n) {
if (n > 0) {
const size_t alignment = 64 * sizeof(Value);
size_t bytes = sizeof(Value) * n;
MemoryHighWatermark::instance() += bytes;

int err = posix_memalign((void**)&ptr, alignment, bytes);
if (err) {
throw_AllocationFailed(bytes, Here());
}
}
else {
ptr = nullptr;
}
}

void free_aligned(Value*& ptr) {
if (ptr) {
free(ptr);
ptr = nullptr;
MemoryHighWatermark::instance() -= footprint();
}
}

void allocateHost() {
alloc_aligned(host_data_, size_);
if (size_ > 0) {
MemoryHighWatermark::instance() += footprint();
host_data_ = host_allocator_.allocate(size_);
}
else {
host_data_ = nullptr;
}
}

void deallocateHost() {
free_aligned(host_data_);
if (host_data_) {
host_allocator_.deallocate(host_data_, size_);
host_data_ = nullptr;
MemoryHighWatermark::instance() -= footprint();
}
}

size_t footprint() const { return sizeof(Value) * size_; }
Expand All @@ -302,6 +262,8 @@ class DataStore : public ArrayDataStore {
mutable bool device_allocated_{false};
mutable bool acc_mapped_{false};

pluto::allocator<Value> host_allocator_;
mutable pluto::allocator<Value> device_allocator_;
};

//------------------------------------------------------------------------------
Expand All @@ -311,22 +273,23 @@ class WrappedDataStore : public ArrayDataStore {
public:

void init_device() {
if (ATLAS_HAVE_GPU && devices()) {
if (ATLAS_HAVE_GPU && pluto::devices()) {
device_updated_ = false;
}
else {
device_data_ = host_data_;
}
}

WrappedDataStore(Value* host_data, size_t size): host_data_(host_data), size_(size) {
WrappedDataStore(Value* host_data, size_t size): host_data_(host_data), size_(size),
device_allocator_{pluto::device_resource()} {
init_device();
}

WrappedDataStore(Value* host_data, const ArraySpec& spec):
host_data_(host_data),
size_(spec.size())
{
size_(spec.size()),
device_allocator_{pluto::device_resource()} {
init_device();
contiguous_ = spec.contiguous();
if (! contiguous_) {
Expand Down Expand Up @@ -363,25 +326,17 @@ class WrappedDataStore : public ArrayDataStore {
}

void updateDevice() const override {
if (ATLAS_HAVE_GPU && devices()) {
if (ATLAS_HAVE_GPU && pluto::devices()) {
if (not device_allocated_) {
allocateDevice();
}
if (contiguous_) {
hicError_t err = hicMemcpy(device_data_, host_data_, size_*sizeof(Value), hicMemcpyHostToDevice);
if (err != hicSuccess) {
throw_AssertionFailed("Failed to updateDevice: "+std::string(hicGetErrorString(err)), Here());
}
pluto::copy_host_to_device(device_data_, host_data_, size_);
}
else {
hicError_t err = hicMemcpy2D(
device_data_, memcpy_h2d_pitch_ * sizeof(Value),
pluto::memcpy_host_to_device_2D(device_data_, memcpy_h2d_pitch_ * sizeof(Value),
host_data_, memcpy_d2h_pitch_ * sizeof(Value),
memcpy_width_ * sizeof(Value), memcpy_height_,
hicMemcpyHostToDevice);
if (err != hicSuccess) {
throw_AssertionFailed("Failed to updateDevice: "+std::string(hicGetErrorString(err)), Here());
}
memcpy_width_ * sizeof(Value), memcpy_height_);
}
device_updated_ = true;
}
Expand All @@ -391,20 +346,12 @@ class WrappedDataStore : public ArrayDataStore {
if constexpr (ATLAS_HAVE_GPU) {
if (device_allocated_) {
if (contiguous_) {
hicError_t err = hicMemcpy(host_data_, device_data_, size_*sizeof(Value), hicMemcpyDeviceToHost);
if (err != hicSuccess) {
throw_AssertionFailed("Failed to updateHost: "+std::string(hicGetErrorString(err)), Here());
}
pluto::copy_device_to_host(host_data_, device_data_, size_);
}
else {
hicError_t err = hicMemcpy2D(
host_data_, memcpy_d2h_pitch_ * sizeof(Value),
pluto::memcpy_device_to_host_2D(host_data_, memcpy_d2h_pitch_ * sizeof(Value),
device_data_, memcpy_h2d_pitch_ * sizeof(Value),
memcpy_width_ * sizeof(Value), memcpy_height_,
hicMemcpyDeviceToHost);
if (err != hicSuccess) {
throw_AssertionFailed("Failed to updateHost: "+std::string(hicGetErrorString(err)), Here());
}
memcpy_width_ * sizeof(Value), memcpy_height_);
}
host_updated_ = true;
}
Expand Down Expand Up @@ -435,15 +382,12 @@ class WrappedDataStore : public ArrayDataStore {
bool deviceAllocated() const override { return device_allocated_; }

void allocateDevice() const override {
if (ATLAS_HAVE_GPU && devices()) {
if (ATLAS_HAVE_GPU && pluto::devices()) {
if (device_allocated_) {
return;
}
if (size_) {
hicError_t err = hicMalloc((void**)&device_data_, sizeof(Value)*size_);
if (err != hicSuccess) {
throw_AssertionFailed("Failed to allocate GPU memory: " + std::string(hicGetErrorString(err)), Here());
}
device_data_ = device_allocator_.allocate(size_);
device_allocated_ = true;
if (contiguous_) {
accMap();
Expand All @@ -453,18 +397,13 @@ class WrappedDataStore : public ArrayDataStore {
}

void deallocateDevice() const override {
if constexpr (ATLAS_HAVE_GPU) {
if (device_allocated_) {
if (contiguous_) {
accUnmap();
}
hicError_t err = hicFree(device_data_);
if (err != hicSuccess) {
throw_AssertionFailed("Failed to deallocate GPU memory: " + std::string(hicGetErrorString(err)), Here());
}
device_data_ = nullptr;
device_allocated_ = false;
if (device_allocated_) {
if (contiguous_) {
accUnmap();
}
device_allocator_.deallocate(device_data_, size_);
device_data_ = nullptr;
device_allocated_ = false;
}
}

Expand Down Expand Up @@ -505,7 +444,6 @@ class WrappedDataStore : public ArrayDataStore {
}

void accUnmap() const override {
#if ATLAS_HAVE_ACC
if (acc_mapped_) {
ATLAS_ASSERT(atlas::acc::is_present(host_data_, size_ * sizeof(Value)));
if constexpr(ATLAS_ACC_DEBUG) {
Expand All @@ -514,7 +452,6 @@ class WrappedDataStore : public ArrayDataStore {
atlas::acc::unmap(host_data_);
acc_mapped_ = false;
}
#endif
}

private:
Expand All @@ -532,6 +469,8 @@ class WrappedDataStore : public ArrayDataStore {
mutable bool device_updated_{true};
mutable bool device_allocated_{false};
mutable bool acc_mapped_{false};

mutable pluto::allocator<Value> device_allocator_;
};

} // namespace native
Expand Down
22 changes: 2 additions & 20 deletions src/atlas/library/Library.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ static bool feature_MKL() {
} // namespace
#endif

#include "hic/hic.h"
#include "pluto/pluto.h"

#include "atlas_io/Trace.h"

Expand Down Expand Up @@ -128,24 +128,6 @@ static void init_data_paths(std::vector<std::string>& data_paths) {
add_tokens(data_paths, "~atlas/share", ":");
}

static std::size_t devices() {
if constexpr (ATLAS_HAVE_GPU) {
static std::size_t _devices = []() -> std::size_t {
int num_devices = 0;
auto err = hicGetDeviceCount(&num_devices);
if (err) {
num_devices = 0;
}
return static_cast<std::size_t>(num_devices);
}();
return _devices;
}
else {
return 0;
}
}


} // namespace

//----------------------------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -363,7 +345,7 @@ void Library::initialise(const eckit::Parametrisation& config) {
out << " OMP\n";
out << " max_threads [" << atlas_omp_get_max_threads() << "] \n";
out << " GPU\n";
out << " devices [" << devices() << "] \n";
out << " devices [" << pluto::devices() << "] \n";
out << " OpenACC [" << acc::devices() << "] \n";
out << " \n";
out << " log.info [" << str(info_) << "] \n";
Expand Down
Loading

0 comments on commit 19343bd

Please sign in to comment.