Skip to content

Commit

Permalink
Add GPU data wrapping with a unit test atlas_fctest_field_wrap_device
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrdar committed Apr 3, 2024
1 parent d378897 commit 8543241
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 17 deletions.
9 changes: 7 additions & 2 deletions src/atlas/array/native/NativeArray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,16 @@ Array* Array::create(const ArrayShape& shape, const ArrayLayout& layout) {
}
template <typename Value>
Array* Array::wrap(Value* data, const ArrayShape& shape) {
return new ArrayT<Value>(new native::WrappedDataStore<Value>(data), shape);
// Element count of the wrapped buffer: the product of all extents in `shape`
// (an empty shape leaves the count at 1).
size_t size = 1;
// Index with the container's own size type so the loop condition does not
// compare a signed index against the value returned by shape.size().
for (decltype(shape.size()) i = 0; i < shape.size(); ++i) {
    size *= shape[i];
}
// The count is passed to the wrapped store together with the caller's pointer.
return new ArrayT<Value>(new native::WrappedDataStore<Value>(data, size), shape);
}
template <typename Value>
Array* Array::wrap(Value* data, const ArraySpec& spec) {
return new ArrayT<Value>(new native::WrappedDataStore<Value>(data), spec);
// The element count comes straight from the spec and is passed, with the
// caller's pointer, to the wrapped store.
return new ArrayT<Value>(new native::WrappedDataStore<Value>(data, spec.size()), spec);
}

Array::~Array() = default;
Expand Down
75 changes: 60 additions & 15 deletions src/atlas/array/native/NativeDataStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,38 +223,83 @@ class DataStore : public ArrayDataStore {
template <typename Value>
class WrappedDataStore : public ArrayDataStore {
public:
WrappedDataStore(Value* data_store): data_store_(data_store) {}
// Wraps an externally provided host buffer of `size` elements of type Value.
//
// data_store : pointer to the caller's host buffer
// size       : number of Value elements in that buffer
//
// The device pointer and the allocation flag start out in a defined
// "no device buffer" state; they are read by updateDevice(), allocateDevice(),
// deviceAllocated() and voidDeviceData() before anything else assigns them.
WrappedDataStore(Value* data_store, size_t size):
    data_store_(data_store), size_(size), data_store_dev_(nullptr), device_allocated_(false) {
    // Host side is marked as not needing an update.
    setHostNeedsUpdate(false);
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
    // With the CUDA backend the device side is marked as needing an update.
    setDeviceNeedsUpdate(true);
#else
    setDeviceNeedsUpdate(false);
#endif
}

virtual void updateHost() const override {}
// Copies size_ elements from the device buffer into the wrapped host buffer.
// Compiles to a no-op unless ATLAS_NATIVE_STORAGE_BACKEND_CUDA is set.
void updateHost() const override {
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
    // Without a device allocation there is nothing valid to copy from.
    if (not device_allocated_) {
        return;
    }
    // NOTE(review): the status returned by cudaMemcpy is not checked.
    cudaMemcpy(data_store_, data_store_dev_, size_ * sizeof(Value), cudaMemcpyDeviceToHost);
#endif
}

virtual void updateDevice() const override {}
// Copies size_ elements from the wrapped host buffer into the device buffer,
// allocating the device buffer first when it does not exist yet.
// Compiles to a no-op unless ATLAS_NATIVE_STORAGE_BACKEND_CUDA is set.
void updateDevice() const override {
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
    if (!device_allocated_) {
        allocateDevice();
    }
    const size_t bytes = size_ * sizeof(Value);
    cudaMemcpy(data_store_dev_, data_store_, bytes, cudaMemcpyHostToDevice);
#endif
}

virtual bool valid() const override { return true; }
// Always reports the store as valid.
bool valid() const override {
    return true;
}

virtual void syncHostDevice() const override {}
// True when the host side is not flagged as updated.
bool hostNeedsUpdate() const override {
    return !host_updated_;
}

virtual bool deviceAllocated() const override { return false; }
// True when the device side is not flagged as updated.
bool deviceNeedsUpdate() const override {
    return !device_updated_;
}

virtual void allocateDevice() const override {}
// Records whether the host side needs an update; the flag is stored inverted
// as "host is updated".
void setHostNeedsUpdate(bool v) const override {
    host_updated_ = !v;
}

virtual void deallocateDevice() const override {}
// Records whether the device side needs an update; the flag is stored inverted
// as "device is updated".
void setDeviceNeedsUpdate(bool v) const override {
    device_updated_ = !v;
}

virtual bool hostNeedsUpdate() const override { return true; }
// Type-erased pointer to the wrapped host buffer.
void* voidDataStore() override {
    void* host_ptr = data_store_;
    return host_ptr;
}

virtual bool deviceNeedsUpdate() const override { return false; }
// Type-erased pointer to the wrapped host buffer (same pointer as
// voidDataStore()).
void* voidHostData() override {
    void* host_ptr = data_store_;
    return host_ptr;
}

virtual void setHostNeedsUpdate(bool) const override {}
// Type-erased value of the device pointer member, whatever it currently holds.
void* voidDeviceData() override {
    void* device_ptr = data_store_dev_;
    return device_ptr;
}

virtual void setDeviceNeedsUpdate(bool) const override {}

virtual void* voidDataStore() override { return static_cast<void*>(data_store_); }
// Brings host and device copies in line with each other, based on the
// host_updated_ / device_updated_ flags:
//   - both flagged as updated: nothing to do (previously this performed a full
//     host->device copy immediately followed by a full device->host copy);
//   - only host updated:   push host data to the device;
//   - only device updated: pull device data to the host.
// NOTE(review): the flags are not changed by the copies in this class —
// presumably the caller maintains them; confirm.
void syncHostDevice() const override {
    if (host_updated_ and device_updated_) {
        return;
    }
    if (host_updated_) {
        updateDevice();
    }
    if (device_updated_) {
        updateHost();
    }
}

virtual void* voidHostData() override { return static_cast<void*>(data_store_); }
// Reports the current value of the device-allocation flag.
bool deviceAllocated() const override {
    return device_allocated_;
}

virtual void* voidDeviceData() override { return static_cast<void*>(data_store_); }
// Allocates a device buffer of size_ elements if none is flagged as allocated
// yet, and, when ATLAS_HAVE_ACC is set, registers the host/device pointer pair
// through atlas_acc_map_data.
// Compiles to a no-op unless ATLAS_NATIVE_STORAGE_BACKEND_CUDA is set.
// NOTE(review): the status returned by cudaMalloc is not checked.
void allocateDevice() const override {
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
    if (!device_allocated_) {
        // The member is written from a const method, hence the const_cast.
        cudaMalloc(reinterpret_cast<void**>(const_cast<Value**>(&data_store_dev_)), sizeof(Value) * size_);
        device_allocated_ = true;
#if ATLAS_HAVE_ACC
        atlas_acc_map_data(data_store_, data_store_dev_, sizeof(Value) * size_);
#endif
    }
#endif
}

// Releases the device buffer, if one is flagged as allocated.
// Compiles to a no-op unless ATLAS_NATIVE_STORAGE_BACKEND_CUDA is set.
void deallocateDevice() const override {
#if ATLAS_NATIVE_STORAGE_BACKEND_CUDA
    // Never hand an unallocated (or already freed) pointer to cudaFree.
    if (not device_allocated_) {
        return;
    }
#if ATLAS_HAVE_ACC
    // Remove the OpenACC host<->device association while the device memory
    // still exists, i.e. before it is freed.
    atlas_acc_unmap_data(data_store_);
#endif
    cudaFree(data_store_dev_);
    device_allocated_ = false;
#endif
}

private:
    // Caller-owned host buffer and its length in elements.
    Value* data_store_;
    size_t size_;
    // Device buffer pointer; assigned from the const allocateDevice(), so it is
    // mutable like the rest of the device-side state.
    mutable Value* data_store_dev_;
    // Inverted "needs update" flags, see set{Host,Device}NeedsUpdate().
    mutable bool host_updated_;
    mutable bool device_updated_;
    // Whether data_store_dev_ currently refers to an allocated device buffer.
    mutable bool device_allocated_;
    // NOTE(review): no destructor is visible in this class; confirm who is
    // responsible for calling deallocateDevice() before the store is destroyed.
};

} // namespace native
Expand Down
15 changes: 15 additions & 0 deletions src/tests/field/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,20 @@ if( HAVE_FCTEST )
set_tests_properties ( atlas_fctest_field_device PROPERTIES LABELS "gpu;acc")
endif()

# Fortran test for wrapping existing data in a field on the GPU.
# Built from fctest_field_wrap_gpu.F90 and external_acc_routine.F90, linked
# against atlas_f, and only enabled when both atlas_HAVE_ACC and
# ATLAS_STORAGE_BACKEND_CUDA are true. The test environment sets
# ATLAS_RUN_NGPUS=1.
add_fctest( TARGET atlas_fctest_field_wrap_device
CONDITION atlas_HAVE_ACC AND ATLAS_STORAGE_BACKEND_CUDA
LINKER_LANGUAGE Fortran
SOURCES fctest_field_wrap_gpu.F90 external_acc_routine.F90
LIBS atlas_f
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} ATLAS_RUN_NGPUS=1
)

# Only when the test target exists: apply ${ACC_Fortran_FLAGS} when compiling
# and linking, and label the test "gpu;acc".
# NOTE(review): the flags are passed to both target_link_libraries and
# target_link_options — presumably one of the two is redundant; confirm against
# the minimum supported CMake version.
if( TARGET atlas_fctest_field_wrap_device )
target_compile_options( atlas_fctest_field_wrap_device PUBLIC ${ACC_Fortran_FLAGS} )
target_link_libraries( atlas_fctest_field_wrap_device ${ACC_Fortran_FLAGS} )
target_link_options( atlas_fctest_field_wrap_device PRIVATE "${ACC_Fortran_FLAGS}")
set_tests_properties ( atlas_fctest_field_wrap_device PROPERTIES LABELS "gpu;acc")
endif()

endif()

0 comments on commit 8543241

Please sign in to comment.