Skip to content

Commit 8e4c095

Browse files
authored
feat: add support for the remaining wrapper types (#369)
* feat: add materialize_traced_array for all other wrappers
* refactor: use scatter for generating diagm
* refactor: directly generate the region for simple_scatter_op
* feat: generalize diagm
* feat: efficient non-contiguous setindex
* fix: non-contiguous indexing is now supported
* feat: implement set_mlir_data for the remaining types
* refactor: use `Ops.gather_getindex` to implement diag
* fix: noinline ops
* fix: incorrect rebase
* fix: dispatches
* fix: diagm for repeated indices and initial tests
* fix: higher dimensional indexing + tests
* fix: matrix multiplication of wrapper types
* fix: de-specialize 3 arg mul!
1 parent d4e7c76 commit 8e4c095

File tree

10 files changed

+565
-142
lines changed

10 files changed

+565
-142
lines changed

Diff for: src/Compiler.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ function create_result(tocopy::T, path, result_stores) where {T}
4141
elems = Union{Symbol,Expr}[]
4242

4343
for i in 1:fieldcount(T)
44+
# If the field is undefined we don't set it. A common example for this is `du2`
45+
# for Tridiagonal
46+
isdefined(tocopy, i) || continue
4447
ev = create_result(getfield(tocopy, i), append_path(path, i), result_stores)
4548
push!(elems, ev)
4649
end
@@ -102,7 +105,7 @@ function create_result(tocopy::D, path, result_stores) where {K,V,D<:AbstractDic
102105
end
103106

104107
function create_result(
105-
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol},
108+
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char},
106109
path,
107110
result_stores,
108111
)

Diff for: src/Ops.jl

+110
Original file line numberDiff line numberDiff line change
@@ -1418,4 +1418,114 @@ julia> Reactant.@jit(
14181418
end
14191419
end
14201420

"""
    scatter_setindex(dest, scatter_indices, updates)

Build a new array that equals `dest` except at the positions listed in `scatter_indices`,
which receive the corresponding entries of `updates`. The operation is lowered through
[`MLIR.Dialects.stablehlo.scatter`](@ref); `dest` itself is not modified.

`scatter_indices` must have one row per entry of `updates` and `ndims(dest)` columns, so
that row `i` addresses the element that receives `updates[i]`. The index data is passed to
the scatter op unchanged (presumably zero-based, as StableHLO expects -- confirm at the
call site).

When the indices form a contiguous range,
[`MLIR.Dialects.stablehlo.dynamic_update_slice`](@ref) is the recommended alternative.
"""
@noinline function scatter_setindex(
    dest::TracedRArray{T,N},
    scatter_indices::TracedRArray{Int64,2},
    updates::TracedRArray{T,1},
) where {T,N}
    @assert length(updates) == size(scatter_indices, 1)
    @assert size(scatter_indices, 2) == N

    # Combiner region: a single block taking two scalars and yielding the second one, which
    # is how the stored value ends up being replaced by the update.
    combiner_region = MLIR.IR.Region()
    scalar_type = mlir_type(TracedRNumber{T})
    combiner_block = MLIR.IR.Block(
        [scalar_type, scalar_type], [MLIR.IR.Location(), MLIR.IR.Location()]
    )
    yield_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(combiner_block, 2)])
    # Detach the freshly created op from wherever it was inserted before moving it into
    # the combiner block.
    MLIR.IR.rmfromparent!(yield_op)
    push!(combiner_block, yield_op)
    pushfirst!(combiner_region, combiner_block)

    #! format: off
    no_dims       = Int64[]
    inserted_dims = collect(Int64, 0:(N - 1))
    operand_dims  = collect(Int64, 0:(N - 1))

    dim_numbers = MLIR.API.stablehloScatterDimensionNumbersGet(
        MLIR.IR.context(),
        length(no_dims), no_dims,              # update_window_dims
        length(inserted_dims), inserted_dims,  # inserted_window_dims
        length(no_dims), no_dims,              # input_batching_dims
        length(no_dims), no_dims,              # scatter_indices_batching_dims
        length(operand_dims), operand_dims,    # scatter_dims_to_operand_dims
        Int64(1),                              # index_vector_dim
    )
    #! format: on

    scatter_op = MLIR.Dialects.stablehlo.scatter(
        [dest.mlir_data],
        scatter_indices.mlir_data,
        [updates.mlir_data];
        result_0=[mlir_type(TracedRArray{T,N}, size(dest))],
        update_computation=combiner_region,
        scatter_dimension_numbers=dim_numbers,
    )

    return TracedRArray{T,N}((), MLIR.IR.result(scatter_op, 1), size(dest))
end
1482+
"""
    gather_getindex(src, gather_indices)

Read individual elements of `src` through [`MLIR.Dialects.stablehlo.gather`](@ref) and
return them as a one-dimensional array with `size(gather_indices, 1)` entries.

`gather_indices` must have `ndims(src)` columns; every row addresses one element of `src`.
The index data is passed to the gather op unchanged (presumably zero-based, as StableHLO
expects -- confirm at the call site).

When the indices form a contiguous range,
[`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) is the recommended alternative.
"""
@noinline function gather_getindex(
    src::TracedRArray{T,N}, gather_indices::TracedRArray{Int64,2}
) where {T,N}
    @assert size(gather_indices, 2) == N

    #! format: off
    offsets     = Int64[1]
    collapsed   = collect(Int64, 0:(N - 2))
    no_batching = Int64[]
    index_map   = collect(Int64, 0:(N - 1))

    gather_dims = MLIR.API.stablehloGatherDimensionNumbersGet(
        MLIR.IR.context(),
        Int64(length(offsets)), offsets,          # offset_dims
        Int64(length(collapsed)), collapsed,      # collapsed_slice_dims
        Int64(length(no_batching)), no_batching,  # operand_batching_dims
        Int64(length(no_batching)), no_batching,  # start_indices_batching_dims
        Int64(length(index_map)), index_map,      # start_index_map
        Int64(1),                                 # index_vector_dim
    )
    #! format: on

    # Each gathered slice has extent 1 along every dimension of `src`.
    gather_op = MLIR.Dialects.stablehlo.gather(
        src.mlir_data,
        gather_indices.mlir_data;
        dimension_numbers=gather_dims,
        slice_sizes=fill(Int64(1), N),
        indices_are_sorted=false,
    )
    gathered = TracedRArray{T}(MLIR.IR.result(gather_op, 1))

    # Flatten to a vector with one entry per row of `gather_indices`.
    return reshape(gathered, size(gather_indices, 1))
end
1530+
14211531
end # module Ops

Diff for: src/Overlay.jl

+27
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,30 @@ for randfun in (:rand, :randn, :randexp)
115115
# end
116116
end
117117
end
118+
# LinearAlgebra.jl overloads
## `_mul!` goes through too many layers of abstractions and we aren't able to overload
## without specializing on every possible combination of types
for (cT, aT, bT) in (
    (:AbstractVector, :AbstractMatrix, :AbstractVector),
    (:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat),
)
    @eval begin
        # 5-arg `mul!`: take the traced code path as soon as the ancestor of any operand
        # is a `TracedRArray`; otherwise defer to `LinearAlgebra._mul!`. Always returns `C`.
        @reactant_overlay @noinline function LinearAlgebra.mul!(
            C::$cT, A::$aT, B::$bT, α::Number, β::Number
        )
            # `∘` composes the two callables: first unwrap with `ancestor`, then test
            # whether the unwrapped array is a `TracedRArray`.
            if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B))
                TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
            else
                LinearAlgebra._mul!(C, A, B, α, β)
            end
            return C
        end

        # Needed mostly for 1.10 where 3-arg mul is often specialized
        # Forwards to the 5-arg method with α = true and β = false.
        @reactant_overlay @noinline function LinearAlgebra.mul!(C::$cT, A::$aT, B::$bT)
            call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false)
            return C
        end
    end
end

Diff for: src/Reactant.jl

+14-6
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
105105
) where {T,N}
106106
shape = Tuple(shape)
107107
if !isnothing(mlir_data)
108-
@assert size(MLIR.IR.type(mlir_data)) == shape
108+
@assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
109109
end
110110
return new{T,N}(paths, mlir_data, shape)
111111
end
@@ -119,15 +119,23 @@ const WrappedTracedRArray{T,N} = WrappedArray{
119119
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
120120
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
121121
const AnyTracedRMatrix{T} = Union{
122-
AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}}
122+
AnyTracedRArray{T,2},
123+
LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}},
124+
LinearAlgebra.Tridiagonal{TracedRNumber{T},TracedRArray{T,1}},
123125
}
124126
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
125127

# Wrap an MLIR value as a `TracedRArray` with element type `T`. The rank and shape are read
# from the value's MLIR type. When `T` differs from the element type stored in the value,
# the value is wrapped with its own element type first and then passed to `Ops.convert`.
function TracedRArray{T}(data::MLIR.IR.Value) where {T}
    value_type = MLIR.IR.type(data)
    rank = ndims(value_type)
    stored_eltype = eltype(MLIR.IR.julia_type(value_type))

    T == stored_eltype || return Ops.convert(TracedRArray{T,rank}, TracedRArray(data))
    return TracedRArray{T,rank}((), data, size(value_type))
end

# Wrap an MLIR value as a `TracedRArray`, taking the element type from the value itself.
function TracedRArray(data::MLIR.IR.Value)
    stored_eltype = eltype(MLIR.IR.julia_type(MLIR.IR.type(data)))
    return TracedRArray{stored_eltype}(data)
end
132140

133141
struct XLAArray{T,N} <: RArray{T,N} end

Diff for: src/TracedRArray.jl

+54-20
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ using ..Reactant:
1414
MLIR,
1515
ancestor,
1616
unwrapped_eltype
17+
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array
18+
1719
using ReactantCore: ReactantCore
18-
using ..TracedUtils: TracedUtils, materialize_traced_array
1920
using GPUArraysCore: GPUArraysCore
2021

2122
ReactantCore.is_traced(::TracedRArray) = true
@@ -55,25 +56,37 @@ function Base.getindex(
5556
return TracedRNumber{T}((), res2)
5657
end
5758

# Indexing a zero-dimensional traced array with no indices yields a `TracedRNumber` that
# wraps the same MLIR data.
function Base.getindex(a::TracedRArray{T,0}) where {T}
    return TracedRNumber{T}((), a.mlir_data)
end
6160

62-
# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
6361
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
6462
indices = map(enumerate(indices)) do (idx, i)
6563
i isa Colon && return 1:size(a, idx)
6664
i isa CartesianIndex && return Tuple(i)
6765
return i
6866
end
6967

70-
foreach(indices) do idxs
71-
idxs isa Number && return nothing
68+
non_contiguous_getindex = false
69+
for idxs in indices
70+
idxs isa Number && continue
7271
contiguous = all(isone, diff(idxs))
7372
# XXX: We want to throw error even for dynamic indexing
74-
if typeof(a) <: Bool
75-
contiguous || error("non-contiguous indexing is not supported")
73+
if typeof(contiguous) <: Bool && !contiguous
74+
non_contiguous_getindex = true
75+
break
76+
end
77+
end
78+
79+
if non_contiguous_getindex
80+
indices_tuples = collect(Iterators.product(indices...))
81+
indices = Matrix{Int}(
82+
undef, (length(indices_tuples), length(first(indices_tuples)))
83+
)
84+
for (i, idx) in enumerate(indices_tuples)
85+
indices[i, :] .= idx .- 1
7686
end
87+
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
88+
res = Ops.gather_getindex(a, indices)
89+
return Ops.reshape(res, size(indices_tuples)...)
7790
end
7891

7992
start_indices = map(indices) do i
@@ -99,16 +112,41 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
99112
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...)
100113
end
101114

102-
function Base.setindex!(
103-
a::TracedRArray{T,N},
104-
v,
105-
indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
106-
) where {T,N}
115+
function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
107116
indices = map(enumerate(indices)) do (idx, i)
108-
i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i)
117+
i isa Colon && return 1:size(a, idx)
118+
i isa CartesianIndex && return Tuple(i)
119+
return i
120+
end
121+
122+
non_contiguous_setindex = false
123+
for idxs in indices
124+
idxs isa Number && continue
125+
contiguous = all(isone, diff(idxs))
126+
# XXX: We want to throw error even for dynamic indexing
127+
if typeof(contiguous) <: Bool && !contiguous
128+
non_contiguous_setindex = true
129+
break
130+
end
131+
end
132+
133+
if non_contiguous_setindex
134+
indices_tuples = collect(Iterators.product(indices...))
135+
indices = Matrix{Int}(
136+
undef, (length(indices_tuples), length(first(indices_tuples)))
137+
)
138+
for (i, idx) in enumerate(indices_tuples)
139+
indices[i, :] .= idx .- 1
140+
end
141+
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
142+
res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v)))
143+
a.mlir_data = res.mlir_data
144+
return v
109145
end
146+
110147
v = TracedUtils.broadcast_to_size(v, length.(indices))
111148
v = TracedUtils.promote_to(TracedRArray{T,N}, v)
149+
112150
indices = [
113151
(
114152
TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1
@@ -124,11 +162,7 @@ function Base.setindex!(
124162
return v
125163
end
126164

127-
function Base.setindex!(
128-
a::AnyTracedRArray{T,N},
129-
v,
130-
indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
131-
) where {T,N}
165+
function Base.setindex!(a::AnyTracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
132166
ancestor_indices = TracedUtils.get_ancestor_indices(a, indices...)
133167
setindex!(ancestor(a), v, ancestor_indices...)
134168
return a

0 commit comments

Comments
 (0)