Skip to content

Commit 7fc621e

Browse files
committed
test: low level IFRT tests
1 parent 2a98366 commit 7fc621e

File tree

3 files changed

+60
-5
lines changed

3 files changed

+60
-5
lines changed

test/ifrt/low_level.jl

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Testing manual IFRT buffer creation + compilation + execution
2+
using Reactant, Test
3+
using Reactant: XLA
4+
using Reactant.XLA: IFRT
5+
6+
fn_test1(x, y) = x .+ y
7+
fn_test2(x, y) = x .* y
8+
fn_test3(x, y) = x .+ y' .- x
9+
10+
@testset "IFRT Low-level API" begin
11+
x = reshape(collect(Float32, 1:64), 8, 8)
12+
y = collect((x .+ 64)')
13+
14+
pjrt_client = Reactant.XLA.default_backend()
15+
platform_name = lowercase(XLA.platform_name(pjrt_client))
16+
17+
ifrt_client = if platform_name == "cpu"
18+
IFRT.CPUClient(; checkcount=false)
19+
elseif platform_name == "gpu"
20+
IFRT.GPUClient(; checkcount=false)
21+
elseif platform_name == "tpu"
22+
IFRT.TPUClient(; checkcount=false)
23+
else
24+
error("Unsupported platform: $(platform_name)")
25+
end
26+
27+
pjrt_x = ConcreteRArray(x) # XXX: Rename to ConcretePJRTArray
28+
pjrt_y = ConcreteRArray(y) # XXX: Rename to ConcretePJRTArray
29+
30+
ifrt_x = IFRT.Array(ifrt_client, x) # XXX: Use ConcreteIFRTArray once ready
31+
ifrt_y = IFRT.Array(ifrt_client, y) # XXX: Use ConcreteIFRTArray once ready
32+
33+
@testset for fn in (fn_test1, fn_test2, fn_test3)
34+
pjrt_result = @jit fn(pjrt_x, pjrt_y)
35+
36+
mlir_mod, mlir_fn_res = Reactant.Compiler.compile_mlir(fn, (pjrt_x, pjrt_y))
37+
38+
ifrt_loaded_executable = XLA.compile(
39+
ifrt_client,
40+
XLA.default_device(ifrt_client),
41+
mlir_mod;
42+
num_outputs=length(mlir_fn_res.linear_results),
43+
num_parameters=length(mlir_fn_res.linear_args),
44+
mlir_fn_res.is_sharded,
45+
global_device_ids=Int64[],
46+
num_replicas=1,
47+
num_partitions=1,
48+
)
49+
50+
ifrt_result = XLA.execute(
51+
ifrt_loaded_executable, (ifrt_x.buffer, ifrt_y.buffer), UInt8.((0, 0)), Val(1)
52+
)
53+
54+
@test convert(Array, only(ifrt_result)) Array(pjrt_result)
55+
end
56+
end

test/ifrt_manual.jl

-4
This file was deleted.

test/runtests.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
6262
@safetestset "Custom Number Types" include("custom_number_types.jl")
6363
end
6464
@safetestset "Sharding" include("sharding.jl")
65-
@safetestset "IFRT Low-Level API" include("ifrt_manual.jl")
65+
66+
@testset "IFRT" begin
67+
@safetestset "IFRT Low-Level API" include("ifrt/low_level.jl")
68+
end
6669
end
6770

6871
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"

0 commit comments

Comments
 (0)