Skip to content

Commit a8f5ba7

Browse files
committed
support symm memory on XPU devices
1 parent 86e450c commit a8f5ba7

File tree

9 files changed

+1506
-0
lines changed

9 files changed

+1506
-0
lines changed

src/xccl/IpcExchange.hpp

Lines changed: 400 additions & 0 deletions
Large diffs are not rendered by default.

src/xccl/XPUSymmetricMemory.cpp

Lines changed: 460 additions & 0 deletions
Large diffs are not rendered by default.

src/xccl/XPUSymmetricMemory.hpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include <sycl/sycl.hpp>
5+
#include <torch/csrc/distributed/c10d/Store.hpp>
6+
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
7+
#include <xccl/XPUSymmetricMemoryTypes.hpp>
8+
9+
namespace c10d::symmetric_memory {
10+
11+
// Resource wrapper that owns a (vaddr, allocation handle) pair. Upon
12+
// destruction, it unmaps the vaddr and releases the allocation handle.
13+
struct AllocationRef : public c10::intrusive_ptr_target {
14+
void* ptr;
15+
HandleType handle;
16+
size_t block_size;
17+
int device_idx;
18+
bool local_allocation;
19+
20+
AllocationRef(
21+
void* ptr,
22+
HandleType handle,
23+
size_t block_size,
24+
int device_idx,
25+
bool local_allocation);
26+
27+
~AllocationRef();
28+
};
29+
30+
class XPUSymmetricMemory : public SymmetricMemory {
31+
public:
32+
XPUSymmetricMemory(
33+
std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs,
34+
std::vector<void*> buffers,
35+
std::vector<void*> signal_pads,
36+
HandleType mc_handle,
37+
void* mc_addr,
38+
size_t buffer_size,
39+
int local_device_idx,
40+
int rank,
41+
int world_size);
42+
43+
~XPUSymmetricMemory() override{};
44+
45+
std::vector<void*> get_buffer_ptrs() override;
46+
std::vector<void*> get_signal_pad_ptrs() override;
47+
void** get_buffer_ptrs_dev() override;
48+
void** get_signal_pad_ptrs_dev() override;
49+
size_t get_buffer_size() override;
50+
size_t get_signal_pad_size() override;
51+
52+
bool has_multicast_support() override;
53+
void* get_multicast_ptr() override;
54+
55+
at::Tensor get_buffer(
56+
int rank,
57+
c10::IntArrayRef sizes,
58+
c10::ScalarType dtype,
59+
int64_t storage_offset);
60+
61+
void barrier(int channel, size_t timeout_ms) override;
62+
void put_signal(int dst_rank, int channel, size_t timeout_ms) override;
63+
void wait_signal(int src_rank, int channel, size_t timeout_ms) override;
64+
65+
int get_rank() override;
66+
int get_world_size() override;
67+
c10::Device get_device() override;
68+
69+
void set_group_name(const std::string& group_name) {
70+
group_name_ = group_name;
71+
}
72+
73+
private:
74+
std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs_;
75+
std::vector<void*> buffers_;
76+
std::vector<void*> signal_pads_;
77+
HandleType mc_handle_;
78+
void* mc_addr_;
79+
size_t buffer_size_;
80+
int local_device_idx_;
81+
int rank_;
82+
int world_size_;
83+
void** buffers_dev_;
84+
void** signal_pads_dev_;
85+
std::string group_name_;
86+
};
87+
88+
struct Block : public c10::intrusive_ptr_target {
89+
c10::intrusive_ptr<AllocationRef> alloc_ref;
90+
int device_idx;
91+
size_t block_size;
92+
size_t buffer_size;
93+
size_t signal_pad_offset;
94+
std::optional<std::string> default_group_name;
95+
std::map<std::string, c10::intrusive_ptr<XPUSymmetricMemory>> symm_mems;
96+
97+
Block(
98+
c10::intrusive_ptr<AllocationRef> alloc_ref,
99+
int device_idx,
100+
size_t block_size,
101+
size_t buffer_size,
102+
size_t signal_pad_offset,
103+
const std::optional<std::string>& group_name);
104+
};
105+
106+
class XPUSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
107+
public:
108+
void* alloc(
109+
size_t size,
110+
int device_idx,
111+
const std::optional<std::string>& group_name) override;
112+
113+
void free(void* ptr) override;
114+
size_t get_alloc_size(void* ptr) override;
115+
c10::intrusive_ptr<SymmetricMemory> rendezvous(
116+
void* ptr,
117+
const std::optional<std::string>& group_name) override;
118+
bool has_multicast_support(int device_idx) override;
119+
// void exchange_peer_ipc_mem(sycl::queue& queue, void* ptr);
120+
c10::DeviceType supported_device_type() override;
121+
std::string name() override;
122+
123+
private:
124+
c10::intrusive_ptr<Block> find_block(void* ptr);
125+
126+
std::shared_mutex mutex_;
127+
std::unordered_map<void*, c10::intrusive_ptr<Block>> ptr_to_block_;
128+
};
129+
130+
} // namespace c10d::symmetric_memory
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#pragma once
2+
3+
namespace c10d::symmetric_memory {
4+
5+
// Size of a signal pad (presumably in bytes -- TODO confirm against users).
constexpr size_t signal_pad_size{2048};

// Opaque handle type used for allocation handles throughout this backend.
using HandleType = void*;
7+
8+
} // namespace c10d::symmetric_memory
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#include <sys/socket.h>
2+
#include <sys/syscall.h>
3+
#include <sys/un.h>
4+
#include <unistd.h>
5+
6+
#include <c10/util/error.h>
7+
8+
#include <c10/xpu/XPUCachingAllocator.h>
9+
#include <level_zero/ze_api.h>
10+
#include <sycl/sycl.hpp>
11+
#include <torch/csrc/distributed/c10d/Store.hpp>
12+
#include <xccl/XPUSymmetricMemoryUtils.hpp>
13+
14+
namespace c10d::symmetric_memory {
15+
16+
std::string getSymmMemBackendXPU() {
17+
static auto val = c10::utils::get_env("TORCH_SYMMMEM");
18+
if (val.has_value()) {
19+
TORCH_CHECK(
20+
val.value() == "XPU",
21+
"TORCH_SYMMMEM environment variable must be 'XPU'.");
22+
return val.value();
23+
}
24+
return "XPU";
25+
}
26+
27+
// Reports whether the given device supports multicast. This implementation
// always answers "no"; the device index is not consulted.
bool device_has_multicast_support(int /*device_idx*/) {
  constexpr bool kMulticastSupported = false;
  return kMulticastSupported;
}
30+
31+
// Reports whether overlapping devices are allowed. This implementation
// always answers "no".
bool allow_overlapping_devices() {
  constexpr bool kAllowOverlap = false;
  return kAllowOverlap;
}
34+
35+
void map_block(
36+
void** ptr,
37+
ze_physical_mem_handle_t handle,
38+
size_t size,
39+
int device_idx) {
40+
sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue();
41+
sycl::context sycl_ctx = current_queue.get_context();
42+
ze_context_handle_t ze_context =
43+
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_ctx);
44+
// 1. Reserve virtual address space
45+
void* virtual_ptr = nullptr;
46+
ze_result_t status = zeVirtualMemReserve(
47+
ze_context, // context
48+
nullptr, // let L0 pick virtual address
49+
size, // size
50+
&virtual_ptr // out: reserved address
51+
);
52+
TORCH_CHECK(status == ZE_RESULT_SUCCESS, "zeVirtualMemReserve failed");
53+
54+
// 2. Map physical memory to virtual address
55+
status = zeVirtualMemMap(
56+
ze_context,
57+
virtual_ptr, // virtual memory to map to
58+
size,
59+
handle, // physical memory handle
60+
0, // flags
61+
ZE_MEMORY_ACCESS_ATTRIBUTE_READWRITE // ze_memory_access_attribute_t
62+
);
63+
TORCH_CHECK(status == ZE_RESULT_SUCCESS, "zeVirtualMemMap failed");
64+
65+
// 3. Set access attributes
66+
ze_memory_access_attribute_t access = ZE_MEMORY_ACCESS_ATTRIBUTE_READWRITE;
67+
status =
68+
zeVirtualMemSetAccessAttribute(ze_context, virtual_ptr, size, access);
69+
TORCH_CHECK(
70+
status == ZE_RESULT_SUCCESS, "zeVirtualMemSetAccessAttribute failed");
71+
72+
// 4. Return pointer
73+
*ptr = virtual_ptr;
74+
}
75+
76+
} // namespace c10d::symmetric_memory
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#pragma once
2+
#include <torch/csrc/distributed/c10d/Store.hpp>
3+
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
4+
#include <xccl/XPUSymmetricMemoryTypes.hpp>
5+
6+
namespace c10d {
7+
namespace symmetric_memory {
8+
9+
// Name of the symmetric-memory backend used for XPU.
std::string getSymmMemBackendXPU();

// Whether the device with index `device_idx` supports multicast.
bool device_has_multicast_support(int device_idx);

// Whether overlapping devices are allowed.
bool allow_overlapping_devices();
14+
15+
// A set of store-based exchange methods with a preset prefix typically type of
16+
// the SymmetricMemory. Most used as static instances at respective
17+
// SymmetricMemory implementation files.
18+
class StoreExchange {
19+
public:
20+
StoreExchange(const std::string& store_prefix)
21+
: store_prefix_(store_prefix) {}
22+
23+
// Put template function in header file so that compiler can easily access it.
24+
template <typename T>
25+
std::vector<T> all_gather(
26+
const c10::intrusive_ptr<c10d::Store>& store,
27+
int rank,
28+
int world_size,
29+
T val) {
30+
static_assert(std::is_trivially_copyable_v<T>);
31+
32+
std::vector<std::string> peer_keys;
33+
peer_keys.reserve(world_size);
34+
for (int r = 0; r < world_size; ++r) {
35+
std::ostringstream oss;
36+
oss << store_prefix_ << "/" << seq_id_ << "/" << r;
37+
peer_keys.push_back(oss.str());
38+
}
39+
++seq_id_;
40+
41+
{
42+
std::vector<uint8_t> payload(
43+
reinterpret_cast<uint8_t*>(&val),
44+
reinterpret_cast<uint8_t*>(&val) + sizeof(T));
45+
store->set(peer_keys[rank], payload);
46+
}
47+
48+
std::vector<T> peer_vals;
49+
peer_vals.reserve(world_size);
50+
for (int r = 0; r < world_size; ++r) {
51+
if (r == rank) {
52+
peer_vals.push_back(val);
53+
continue;
54+
}
55+
store->wait({peer_keys[r]});
56+
auto payload = store->get(peer_keys[r]);
57+
TORCH_CHECK(payload.size() == sizeof(T));
58+
T peer_val{};
59+
std::memcpy(&peer_val, payload.data(), sizeof(T));
60+
peer_vals.push_back(peer_val);
61+
}
62+
return peer_vals;
63+
}
64+
65+
void barrier(
66+
const c10::intrusive_ptr<c10d::Store>& store,
67+
int rank,
68+
int world_size) {
69+
// TODO: implement an efficient one?
70+
all_gather(store, rank, world_size, 0);
71+
}
72+
73+
private:
74+
const std::string store_prefix_;
75+
size_t seq_id_ = 0;
76+
};
77+
78+
// Returns a pointer of virtual address that is mapped to the physical memory
79+
// held by the handle.
80+
// todo: will follow such physical memory handle map with virtual address,
81+
// when L0 provides physical handle exchange API and we have multicast support.
82+
void map_block(
83+
void** ptr,
84+
ze_physical_mem_handle_t handle,
85+
size_t size,
86+
int device_idx);
87+
88+
} // namespace symmetric_memory
89+
} // namespace c10d

0 commit comments

Comments
 (0)