Skip to content

Commit 95f6074

Browse files
authored
fix: improve generated mlir for wrapped arrays (#732)
* fix: improve generated mlir for wrapped arrays * test: add test for no gather * fix: handle scalar index correctly
1 parent 904b789 commit 95f6074

File tree

3 files changed

+60
-5
lines changed

3 files changed

+60
-5
lines changed

src/TracedRArray.jl

+26-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module TracedRArrayOverrides
22

3+
using Adapt: WrappedReshapedArray
34
using Base.Broadcast
45
using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
56

@@ -87,7 +88,7 @@ end
8788
function generate_index_list(i1, is...)
8889
list = reshape(i1, :, 1) .- 1
8990
for i in is
90-
i = reshape(i, :, 1)
91+
i = TracedUtils.broadcast_to_size(i, (length(i), 1))
9192
lorig = size(list, 1)
9293
list = repeat(list, size(i, 1), 1)
9394
i = repeat(i; inner=(lorig, 1)) .- 1
@@ -196,8 +197,12 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
196197
if any(i -> unwrapped_eltype(i) <: Bool, indices)
197198
error("Boolean indexing with TracedRArrays isn't fully supported yet.")
198199
end
199-
indices, integer_indices, result_size, _ = TracedUtils.traced_indices(indices...)
200-
res = Ops.gather_getindex(a, generate_index_list(indices...))
200+
indices, integer_indices, result_size, preddim_result_size, _ = TracedUtils.traced_indices(
201+
indices...
202+
)
203+
res = Ops.reshape(
204+
Ops.gather_getindex(a, generate_index_list(indices...)), preddim_result_size
205+
)
201206
isempty(integer_indices) ||
202207
(res = materialize_traced_array(dropdims(res; dims=integer_indices)))
203208
return Ops.reshape(res, result_size)
@@ -228,6 +233,24 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
228233
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...)
229234
end
230235

236+
## Specialize certain dispatches for better codegen
237+
for aType in (
238+
WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}} where {T,N,M},
239+
PermutedDimsArray{
240+
TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}
241+
} where {T,N,perm,iperm},
242+
)
243+
@eval begin
244+
function Base.getindex(a::$aType, indices::Union{Int,TracedRNumber{Int}}...)
245+
return getindex(materialize_traced_array(a), indices...)
246+
end
247+
248+
function Base.getindex(a::$aType, indices...)
249+
return getindex(materialize_traced_array(a), indices...)
250+
end
251+
end
252+
end
253+
231254
function maybe_assert_scalar_setindexing(
232255
::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N}
233256
) where {T,N}

src/TracedUtils.jl

+13-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ function get_ancestor_indices(
6767
@assert length(indices) == N "Expected $N indices, got $(length(indices))"
6868
indices = normalize_indices(x, indices...)
6969
if any(is_traced, indices)
70-
indices, integer_indices, result_size, flattened_size = traced_indices(indices...)
70+
indices, integer_indices, result_size, _, flattened_size = traced_indices(
71+
indices...
72+
)
7173
linear_indices = mapreduce(+, enumerate(indices)) do (i, idx)
7274
bcasted_idxs = Ops.broadcast_in_dim(
7375
idx, ndims(idx) == 0 ? Int64[] : Int64[i], flattened_size
@@ -704,18 +706,27 @@ end
704706
function traced_indices(indices...)
705707
integer_indices = Int64[]
706708
result_size = Int64[]
709+
preddim_result_size = Int64[]
707710
flattened_size = Int64[length(idx) for idx in indices]
708711
new_indices = map(enumerate(indices)) do (i, idx)
709712
if idx isa Number
713+
push!(preddim_result_size, 1)
710714
push!(integer_indices, i)
711715
idx isa TracedRNumber && return idx
712716
return promote_to(TracedRNumber{Int}, idx)
713717
end
718+
append!(preddim_result_size, [size(idx)...])
714719
append!(result_size, [size(idx)...])
715720
idx isa TracedRArray && return materialize_traced_array(vec(idx))
716721
return promote_to(TracedRArray{Int,1}, vec(idx))
717722
end
718-
return new_indices, Tuple(integer_indices), result_size, flattened_size
723+
return (
724+
new_indices,
725+
Tuple(integer_indices),
726+
result_size,
727+
preddim_result_size,
728+
flattened_size,
729+
)
719730
end
720731

721732
end

test/wrapped_arrays.jl

+21
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,24 @@ end
243243
x_ra = Reactant.to_rarray(rand(3, 4, 3))
244244
@test @jit(fn(x_ra)) == fn(Array(x_ra))
245245
end
246+
247+
function reshape_getindex(x)
248+
x = reshape(x, 2, 4, 3)
249+
return x[1, :, :]
250+
end
251+
252+
function permutedims_getindex(x)
253+
x = PermutedDimsArray(x, (2, 1))
254+
return x[1, :]
255+
end
256+
257+
@testset "no gather getindex" begin
258+
x = ones(8, 3)
259+
x_ra = Reactant.to_rarray(x)
260+
261+
hlo = repr(@code_hlo(reshape_getindex(x_ra)))
262+
@test !occursin("stablehlo.gather", hlo)
263+
264+
hlo = repr(@code_hlo(permutedims_getindex(x_ra)))
265+
@test !occursin("stablehlo.gather", hlo)
266+
end

0 commit comments

Comments
 (0)