Skip to content

Commit 037fd11

Browse files
committed
feat: implement Ops.batch
1 parent 93f9f07 commit 037fd11

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed

src/Ops.jl

+112
Original file line numberDiff line numberDiff line change
@@ -1557,4 +1557,116 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
15571557
)
15581558
end
15591559

1560+
"""
    batch(f, args, batch_dims, result_dims=nothing)

Apply `f` across one batching axis of each entry of `args`.

`batch_dims[i]` is the axis of `args[i]` to batch over, or `nothing` when that
argument is not batched. Non-batched arguments are expanded with `broadcast_in_dim`
to carry a leading axis of length `B`; batched arguments have axis `batch_dims[i]`
exchanged with axis 1 via `permutedims`. The prepared arguments are forwarded to
`batch(f, args)`, which batches over the leading axis.

`result_dims[i]` is the axis at which the batch axis of the `i`-th result is placed
(again by exchanging it with axis 1). When `result_dims` is `nothing`, every result
keeps the batch axis at position 1.

All checks use `@assert`, so violations raise `AssertionError`:
- `batch_dims` must have one entry per argument;
- every non-`nothing` batch dimension must lie in `1:ndims(arg)`;
- all batched arguments must have the same size along their batch dimension;
- `result_dims` must have one entry per result.

Returns the collection produced by mapping over the results, one permuted array per
result of `batch(f, args)`.
"""
@noinline function batch(
    f,
    args::Vector{<:TracedRArray},
    batch_dims::Vector{Union{Int,Nothing}},
    result_dims::Union{Vector{Int},Nothing}=nothing,
)
    @assert length(batch_dims) == length(args)

    # Only arguments that are really batched contribute a batch size. A batched axis
    # of length 1 is kept on purpose: dropping it would let a 1-vs-N mismatch slip
    # past this check and only fail after the arguments had already been rewritten.
    batch_sizes = Int64[]
    for (x, dim) in zip(args, batch_dims)
        dim === nothing && continue
        @assert 1 <= dim <= ndims(x) "batch dimension is out of range for its argument"
        push!(batch_sizes, size(x, dim))
    end
    @assert allequal(batch_sizes) "batching dimensions must be equal"
    # With no batched argument at all, fall back to a batch of size 1.
    B = isempty(batch_sizes) ? 1 : first(batch_sizes)

    args = map(zip(args, batch_dims)) do (arg, dim)
        if dim === nothing
            # Existing axes move one position to the right; the new axis 1 has length B.
            return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...])
        end
        # NOTE(review): this exchanges axis 1 and axis `dim` instead of moving `dim`
        # to the front, so for `dim > 2` the remaining axes reach `f` in a different
        # relative order than in `arg` — confirm this is intended.
        order = collect(Int64, 1:ndims(arg))
        order[dim] = 1
        order[1] = dim
        return permutedims(arg, order)
    end
    results = batch(f, args)
    result_dims === nothing && (result_dims = ones(Int64, length(results)))
    @assert length(results) == length(result_dims)
    return map(zip(results, result_dims)) do (result, dim)
        # Exchange the leading batch axis with the requested output axis.
        order = collect(Int64, 1:ndims(result))
        order[dim] = 1
        order[1] = dim
        return permutedims(result, order)
    end
end
1592+
1593+
"""
    batch(f, args::Vector{<:TracedRArray})

Batch `f` over the leading axis of every array in `args` using a single
`enzyme.batch` operation.

All arguments must have the same size `B` along axis 1 (checked with `@assert`).
`f` is traced once through `Reactant.call_with_reactant` on placeholder arrays that
have the shape of each argument with its first axis removed. The traced body ends up
in a private `func.func` named `string(f) * "_batch"` inside the module returned by
`MLIR.IR.mmodule()`, and that function is referenced by the emitted `enzyme.batch`
op with `batch_shape = [B]`.

Returns a `Vector` holding one `TracedRArray` per result of `f`, each of size
`(B, size(r)...)` where `r` is the corresponding traced result.
"""
@noinline function batch(f, args::Vector{<:TracedRArray})
    leading_sizes = map(x -> size(x, 1), args)
    @assert allequal(leading_sizes) "batching dimensions must be equal"
    B = first(leading_sizes)

    # Tensor type of a single sample of each argument (leading axis removed).
    slice_types = map(args) do arg
        slice_shape = size(arg)[2:end]
        return MLIR.IR.TensorType(
            slice_shape, MLIR.IR.Type(Reactant.unwrapped_eltype(arg))
        )
    end

    visibility = MLIR.IR.Attribute("private")
    base_name = string(f)

    mod = MLIR.IR.mmodule()

    # Temporary function without result types; it only serves as a home for the block
    # that `f` is traced into. Presumably needed because the result types are known
    # only after tracing.
    tmp_func = MLIR.IR.block!(MLIR.IR.body(mod)) do
        return MLIR.Dialects.func.func_(;
            sym_name=base_name * "_batch_tmp",
            function_type=MLIR.IR.FunctionType(slice_types, []),
            body=MLIR.IR.Region(),
            sym_visibility=visibility,
        )
    end
    entry = MLIR.IR.Block(slice_types, map(_ -> MLIR.IR.Location(), args))
    push!(MLIR.IR.region(tmp_func, 1), entry)

    # Placeholder traced arrays, bound to the block arguments while tracing below.
    sample_args = map(args) do arg
        return TracedRArray{Reactant.unwrapped_eltype(arg),ndims(arg) - 1}(
            (), nothing, size(arg)[2:end]
        )
    end

    MLIR.IR.activate!(entry)
    result = try
        for (idx, sample) in enumerate(sample_args)
            Reactant.TracedUtils.set_mlir_data!(sample, MLIR.IR.argument(entry, idx))
        end
        traced = Reactant.call_with_reactant(f, sample_args...)
        # A single traced value is wrapped so the code below can always iterate.
        if traced isa TracedRArray || traced isa TracedRNumber
            traced = [traced]
        end
        MLIR.Dialects.func.return_([r.mlir_data for r in traced])
        traced
    finally
        # Always deactivate the block, even when tracing `f` throws.
        MLIR.IR.deactivate!(entry)
    end

    # Final function carrying the real result types; it takes over the traced body.
    batched_func = MLIR.IR.block!(MLIR.IR.body(mod)) do
        result_types = [mlir_type(r) for r in result]
        return MLIR.Dialects.func.func_(;
            sym_name=base_name * "_batch",
            function_type=MLIR.IR.FunctionType(slice_types, result_types),
            body=MLIR.IR.Region(),
            sym_visibility=visibility,
        )
    end
    MLIR.API.mlirRegionTakeBody(
        MLIR.IR.region(batched_func, 1), MLIR.IR.region(tmp_func, 1)
    )
    # The temporary function is destroyed and its handle nulled out afterwards.
    MLIR.API.mlirOperationDestroy(tmp_func.operation)
    tmp_func.operation = MLIR.API.MlirOperation(C_NULL)

    sym_attr = Reactant.TracedUtils.get_attribute_by_name(batched_func, "sym_name")
    callee = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym_attr))

    operands = map(x -> x.mlir_data, args)
    # Each output gains a leading axis of length B in front of the traced result shape.
    batched_types = [
        MLIR.IR.TensorType((B, size(r)...), MLIR.IR.Type(Reactant.unwrapped_eltype(r)))
        for r in result
    ]

    batch_op = MLIR.Dialects.enzyme.batch(
        operands;
        outputs=batched_types,
        fn=callee,
        batch_shape=MLIR.IR.DenseArrayAttribute(Int64[B]),
    )

    return [
        TracedRArray{Reactant.unwrapped_eltype(r),ndims(r) + 1}(
            (), MLIR.IR.result(batch_op, i), (B, size(r)...)
        ) for (i, r) in enumerate(result)
    ]
end
1671+
15601672
end # module Ops

0 commit comments

Comments
 (0)