|
| 1 | +# Testing manual IFRT buffer creation + compilation + execution |
| 2 | +using Reactant, Test |
| 3 | +using Reactant: XLA |
| 4 | +using Reactant.XLA: IFRT |
| 5 | + |
| 6 | +fn_test1(x, y) = x .+ y |
| 7 | +fn_test2(x, y) = x .* y |
| 8 | +fn_test3(x, y) = x .+ y' .- x |
| 9 | + |
| 10 | +@testset "IFRT Low-level API" begin |
| 11 | + x = reshape(collect(Float32, 1:64), 8, 8) |
| 12 | + y = collect((x .+ 64)') |
| 13 | + |
| 14 | + pjrt_client = Reactant.XLA.default_backend() |
| 15 | + platform_name = lowercase(XLA.platform_name(pjrt_client)) |
| 16 | + |
| 17 | + ifrt_client = if platform_name == "cpu" |
| 18 | + IFRT.CPUClient(; checkcount=false) |
| 19 | + elseif platform_name == "gpu" |
| 20 | + IFRT.GPUClient(; checkcount=false) |
| 21 | + elseif platform_name == "tpu" |
| 22 | + IFRT.TPUClient(; checkcount=false) |
| 23 | + else |
| 24 | + error("Unsupported platform: $(platform_name)") |
| 25 | + end |
| 26 | + |
| 27 | + pjrt_x = ConcreteRArray(x) # XXX: Rename to ConcretePJRTArray |
| 28 | + pjrt_y = ConcreteRArray(y) # XXX: Rename to ConcretePJRTArray |
| 29 | + |
| 30 | + ifrt_x = IFRT.Array(ifrt_client, x) # XXX: Use ConcreteIFRTArray once ready |
| 31 | + ifrt_y = IFRT.Array(ifrt_client, y) # XXX: Use ConcreteIFRTArray once ready |
| 32 | + |
| 33 | + @testset for fn in (fn_test1, fn_test2, fn_test3) |
| 34 | + pjrt_result = @jit fn(pjrt_x, pjrt_y) |
| 35 | + |
| 36 | + mlir_mod, mlir_fn_res = Reactant.Compiler.compile_mlir(fn, (pjrt_x, pjrt_y)) |
| 37 | + |
| 38 | + ifrt_loaded_executable = XLA.compile( |
| 39 | + ifrt_client, |
| 40 | + XLA.default_device(ifrt_client), |
| 41 | + mlir_mod; |
| 42 | + num_outputs=length(mlir_fn_res.linear_results), |
| 43 | + num_parameters=length(mlir_fn_res.linear_args), |
| 44 | + mlir_fn_res.is_sharded, |
| 45 | + global_device_ids=Int64[], |
| 46 | + num_replicas=1, |
| 47 | + num_partitions=1, |
| 48 | + ) |
| 49 | + |
| 50 | + ifrt_result = XLA.execute( |
| 51 | + ifrt_loaded_executable, (ifrt_x.buffer, ifrt_y.buffer), UInt8.((0, 0)), Val(1) |
| 52 | + ) |
| 53 | + |
| 54 | + @test convert(Array, only(ifrt_result)) ≈ Array(pjrt_result) |
| 55 | + end |
| 56 | +end |
0 commit comments