@@ -4,7 +4,8 @@ mutable struct HloSharding
4
4
5
5
function HloSharding (ptr:: Ptr{Cvoid} )
6
6
@assert ptr != C_NULL
7
- return finalizer (free_hlo_sharding, new (ptr))
7
+ # return finalizer(free_hlo_sharding, new(ptr))
8
+ return new (ptr)
8
9
end
9
10
end
10
11
@@ -16,14 +17,30 @@ function HloSharding(
16
17
device_list:: AbstractVector{<:Device} , xla_hlo_sharding:: XLA.HloSharding
17
18
)
18
19
default_memory_kind = convert (MemoryKind, XLA. default_memory (device_list))
20
+ return HloSharding (device_list, xla_hlo_sharding, default_memory_kind)
21
+ end
22
+
23
+ function HloSharding (
24
+ device_list:: AbstractVector{<:Device} ,
25
+ xla_hlo_sharding:: XLA.HloSharding ,
26
+ memoy_kind:: AbstractString ,
27
+ )
28
+ return HloSharding (device_list, xla_hlo_sharding, MemoryKind (memoy_kind))
29
+ end
30
+
31
+ function HloSharding (
32
+ device_list:: AbstractVector{<:Device} ,
33
+ xla_hlo_sharding:: XLA.HloSharding ,
34
+ memory_kind:: MemoryKind ,
35
+ )
19
36
client = XLA. client (device_list)
20
- GC. @preserve device_list default_memory_kind xla_hlo_sharding client begin
37
+ GC. @preserve device_list memory_kind xla_hlo_sharding client begin
21
38
return HloSharding (
22
39
@ccall MLIR. API. mlir_c. ifrt_hlo_sharding_from_xla_hlo_sharding (
23
40
client. client:: Ptr{Cvoid} ,
24
41
[d. device for d in device_list]:: Ptr{Ptr{Cvoid}} ,
25
42
length (device_list):: Int32 ,
26
- default_memory_kind . ptr:: Ptr{Cvoid} ,
43
+ memory_kind . ptr:: Ptr{Cvoid} ,
27
44
xla_hlo_sharding. ptr:: Ptr{Cvoid} ,
28
45
):: Ptr{Cvoid}
29
46
)
@@ -51,14 +68,23 @@ mutable struct Sharding
51
68
52
69
function Sharding (ptr:: Ptr{Cvoid} )
53
70
@assert ptr != C_NULL
54
- return finalizer (free_sharding, new (ptr))
71
+ # return finalizer(free_sharding, new(ptr))
72
+ return new (ptr)
55
73
end
56
74
end
57
75
58
76
function Sharding (device_list:: AbstractVector{<:Device} , xla_hlo_sharding:: XLA.HloSharding )
59
77
return convert (Sharding, HloSharding (device_list, xla_hlo_sharding))
60
78
end
61
79
80
+ function Sharding (
81
+ device_list:: AbstractVector{<:Device} ,
82
+ xla_hlo_sharding:: XLA.HloSharding ,
83
+ memoy_kind:: Union{AbstractString,MemoryKind} ,
84
+ )
85
+ return convert (Sharding, HloSharding (device_list, xla_hlo_sharding, memoy_kind))
86
+ end
87
+
62
88
function free_sharding (sharding:: Sharding )
63
89
@ccall MLIR. API. mlir_c. free_ifrt_sharding (sharding. ptr:: Ptr{Cvoid} ):: Cvoid
64
90
end
0 commit comments