Skip to content

Commit 7e45011

Browse files
committed
feat: sharding annotations across nodes now working
1 parent 044e878 commit 7e45011

File tree

4 files changed

+45
-12
lines changed

4 files changed

+45
-12
lines changed

deps/ReactantExtra/API.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -1891,6 +1891,10 @@ hlo_sharding_to_string(const xla::HloSharding *hlo_sharding) {
18911891
return cstr_from_string(hlo_sharding->ToString(true));
18921892
}
18931893

1894+
extern "C" ifrt::MemoryKind *ifrt_memory_kind_from_string(const char *c_str) {
1895+
return new ifrt::MemoryKind(std::string(c_str));
1896+
}
1897+
18941898
extern "C" ifrt::HloSharding *ifrt_hlo_sharding_from_xla_hlo_sharding(
18951899
ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices,
18961900
ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {

src/xla/IFRT/Device.jl

+2-8
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,13 @@ function XLA.memories(device::Device)
4848
device.device::Ptr{Cvoid}, memories_size::Ptr{Int32}
4949
)::Ptr{Ptr{Cvoid}}
5050
end
51-
memories = Vector{Memory}(undef, memories_size[])
52-
for i in 1:memories_size[]
53-
memories[i] = Memory(unsafe_load(ptr, i))
54-
end
55-
return memories
51+
return [Memory(unsafe_load(ptr, i)) for i in 1:memories_size[]]
5652
end
5753

5854
function XLA.default_memory(device_list::AbstractVector{Device})
5955
default_memories = XLA.default_memory.(device_list)
6056
default_memory_kinds = convert.(MemoryKind, default_memories)
61-
if !allequal(default_memory_kinds)
62-
error("All devices must have the same default memory")
63-
end
57+
@assert allequal(default_memory_kinds) "All devices must have the same default memory"
6458
return first(default_memories)
6559
end
6660

src/xla/IFRT/Memory.jl

+9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ mutable struct MemoryKind <: XLA.AbstractMemoryKind
1414
ptr::Ptr{Cvoid}
1515
end
1616

17+
function MemoryKind(str::AbstractString)
18+
str = string(str)
19+
GC.@preserve str begin
20+
return MemoryKind(
21+
@ccall MLIR.API.mlir_c.ifrt_memory_kind_from_string(str::Cstring)::Ptr{Cvoid}
22+
)
23+
end
24+
end
25+
1726
function Base.convert(::Type{MemoryKind}, memory::Memory)
1827
GC.@preserve memory begin
1928
return MemoryKind(

src/xla/IFRT/Sharding.jl

+30-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ mutable struct HloSharding
44

55
function HloSharding(ptr::Ptr{Cvoid})
66
@assert ptr != C_NULL
7-
return finalizer(free_hlo_sharding, new(ptr))
7+
# return finalizer(free_hlo_sharding, new(ptr))
8+
return new(ptr)
89
end
910
end
1011

@@ -16,14 +17,30 @@ function HloSharding(
1617
device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding
1718
)
1819
default_memory_kind = convert(MemoryKind, XLA.default_memory(device_list))
20+
return HloSharding(device_list, xla_hlo_sharding, default_memory_kind)
21+
end
22+
23+
function HloSharding(
24+
device_list::AbstractVector{<:Device},
25+
xla_hlo_sharding::XLA.HloSharding,
26+
memoy_kind::AbstractString,
27+
)
28+
return HloSharding(device_list, xla_hlo_sharding, MemoryKind(memoy_kind))
29+
end
30+
31+
function HloSharding(
32+
device_list::AbstractVector{<:Device},
33+
xla_hlo_sharding::XLA.HloSharding,
34+
memory_kind::MemoryKind,
35+
)
1936
client = XLA.client(device_list)
20-
GC.@preserve device_list default_memory_kind xla_hlo_sharding client begin
37+
GC.@preserve device_list memory_kind xla_hlo_sharding client begin
2138
return HloSharding(
2239
@ccall MLIR.API.mlir_c.ifrt_hlo_sharding_from_xla_hlo_sharding(
2340
client.client::Ptr{Cvoid},
2441
[d.device for d in device_list]::Ptr{Ptr{Cvoid}},
2542
length(device_list)::Int32,
26-
default_memory_kind.ptr::Ptr{Cvoid},
43+
memory_kind.ptr::Ptr{Cvoid},
2744
xla_hlo_sharding.ptr::Ptr{Cvoid},
2845
)::Ptr{Cvoid}
2946
)
@@ -51,14 +68,23 @@ mutable struct Sharding
5168

5269
function Sharding(ptr::Ptr{Cvoid})
5370
@assert ptr != C_NULL
54-
return finalizer(free_sharding, new(ptr))
71+
# return finalizer(free_sharding, new(ptr))
72+
return new(ptr)
5573
end
5674
end
5775

5876
function Sharding(device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding)
5977
return convert(Sharding, HloSharding(device_list, xla_hlo_sharding))
6078
end
6179

80+
function Sharding(
81+
device_list::AbstractVector{<:Device},
82+
xla_hlo_sharding::XLA.HloSharding,
83+
memoy_kind::Union{AbstractString,MemoryKind},
84+
)
85+
return convert(Sharding, HloSharding(device_list, xla_hlo_sharding, memoy_kind))
86+
end
87+
6288
function free_sharding(sharding::Sharding)
6389
@ccall MLIR.API.mlir_c.free_ifrt_sharding(sharding.ptr::Ptr{Cvoid})::Cvoid
6490
end

0 commit comments

Comments
 (0)