|
1 | 1 | module TracedRArrayOverrides
|
2 | 2 |
|
| 3 | +using Adapt: WrappedReshapedArray |
3 | 4 | using Base.Broadcast
|
4 | 5 | using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
|
5 | 6 |
|
|
87 | 88 | function generate_index_list(i1, is...)
|
88 | 89 | list = reshape(i1, :, 1) .- 1
|
89 | 90 | for i in is
|
90 |
| - i = reshape(i, :, 1) |
| 91 | + i = TracedUtils.broadcast_to_size(i, (length(i), 1)) |
91 | 92 | lorig = size(list, 1)
|
92 | 93 | list = repeat(list, size(i, 1), 1)
|
93 | 94 | i = repeat(i; inner=(lorig, 1)) .- 1
|
@@ -196,8 +197,12 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
|
196 | 197 | if any(i -> unwrapped_eltype(i) <: Bool, indices)
|
197 | 198 | error("Boolean indexing with TracedRArrays isn't fully supported yet.")
|
198 | 199 | end
|
199 |
| - indices, integer_indices, result_size, _ = TracedUtils.traced_indices(indices...) |
200 |
| - res = Ops.gather_getindex(a, generate_index_list(indices...)) |
| 200 | + indices, integer_indices, result_size, preddim_result_size, _ = TracedUtils.traced_indices( |
| 201 | + indices... |
| 202 | + ) |
| 203 | + res = Ops.reshape( |
| 204 | + Ops.gather_getindex(a, generate_index_list(indices...)), preddim_result_size |
| 205 | + ) |
201 | 206 | isempty(integer_indices) ||
|
202 | 207 | (res = materialize_traced_array(dropdims(res; dims=integer_indices)))
|
203 | 208 | return Ops.reshape(res, result_size)
|
@@ -228,6 +233,24 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
|
228 | 233 | return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...)
|
229 | 234 | end
|
230 | 235 |
|
| 236 | +## Specialize certain dispatches for better codegen |
| 237 | +for aType in ( |
| 238 | + WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}} where {T,N,M}, |
| 239 | + PermutedDimsArray{ |
| 240 | + TracedRNumber{T},N,perm,iperm,TracedRArray{T,N} |
| 241 | + } where {T,N,perm,iperm}, |
| 242 | +) |
| 243 | + @eval begin |
| 244 | + function Base.getindex(a::$aType, indices::Union{Int,TracedRNumber{Int}}...) |
| 245 | + return getindex(materialize_traced_array(a), indices...) |
| 246 | + end |
| 247 | + |
| 248 | + function Base.getindex(a::$aType, indices...) |
| 249 | + return getindex(materialize_traced_array(a), indices...) |
| 250 | + end |
| 251 | + end |
| 252 | +end |
| 253 | + |
231 | 254 | function maybe_assert_scalar_setindexing(
|
232 | 255 | ::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N}
|
233 | 256 | ) where {T,N}
|
|
0 commit comments