Skip to content

Commit 93c8534

Browse files
committed
feat: implement Ops.batch
1 parent 4fd0492 commit 93c8534

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

src/Ops.jl

+115
Original file line numberDiff line numberDiff line change
@@ -1557,4 +1557,119 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
15571557
)
15581558
end
15591559

1560+
function batch(
1561+
f,
1562+
args::Vector{<:TracedRArray},
1563+
batch_dims::Vector{Union{Int,Nothing}},
1564+
result_dims::Union{Vector{Int},Nothing}=nothing,
1565+
)
1566+
@assert length(batch_dims) == length(args)
1567+
1568+
batch_sizes = [dim === nothing ? 1 : size(x, dim) for (x, dim) in zip(args, batch_dims)]
1569+
filter!(x -> x != 1, batch_sizes)
1570+
@assert allequal(batch_sizes) "batching dimensions must be equal"
1571+
B = length(batch_sizes) == 0 ? 1 : first(batch_sizes)
1572+
1573+
args = map(zip(args, batch_dims)) do (arg, dim)
1574+
if dim === nothing
1575+
return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...])
1576+
end
1577+
order = collect(1:ndims(arg))
1578+
order[dim] = 1
1579+
order[1] = dim
1580+
return permutedims(arg, order)
1581+
end
1582+
results = batch(f, args)
1583+
result_dims === nothing && (result_dims = ones(Int64, length(results)))
1584+
@assert length(results) == length(result_dims)
1585+
return map(zip(results, result_dims)) do (result, dim)
1586+
order = collect(1:ndims(result))
1587+
order[dim] = 1
1588+
order[1] = dim
1589+
return permutedims(result, order)
1590+
end
1591+
end
1592+
1593+
function batch(f, args::Vector{<:TracedRArray})
1594+
batch_sizes = [size(x, 1) for x in args]
1595+
@assert allequal(batch_sizes) "batching dimensions must be equal"
1596+
B = first(batch_sizes)
1597+
1598+
in_tys = [
1599+
MLIR.IR.TensorType(size(arg)[2:end], MLIR.IR.Type(Reactant.unwrapped_eltype(arg)))
1600+
for arg in args
1601+
]
1602+
1603+
sym_visibility = MLIR.IR.Attribute("private")
1604+
1605+
mod = MLIR.IR.mmodule()
1606+
func = MLIR.IR.block!(MLIR.IR.body(mod)) do
1607+
return MLIR.Dialects.func.func_(;
1608+
sym_name=string(f) * "_batch_tmp",
1609+
function_type=MLIR.IR.FunctionType(in_tys, []),
1610+
body=MLIR.IR.Region(),
1611+
sym_visibility,
1612+
)
1613+
end
1614+
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in args])
1615+
push!(MLIR.IR.region(func, 1), fnbody)
1616+
1617+
linear_args = [
1618+
TracedRArray{Reactant.unwrapped_eltype(arg),ndims(arg) - 1}(
1619+
(), nothing, size(arg)[2:end]
1620+
) for arg in args
1621+
]
1622+
1623+
MLIR.IR.activate!(fnbody)
1624+
result = try
1625+
for (i, arg) in enumerate(linear_args)
1626+
raw_arg = MLIR.IR.argument(fnbody, i)
1627+
Reactant.TracedUtils.set_mlir_data!(arg, raw_arg)
1628+
end
1629+
# XXX: call_with_reactant is not working here?
1630+
# ERROR: type Nothing has no field stmts
1631+
# res = Reactant.call_with_reactant(f, linear_args...)
1632+
res = f(linear_args...)
1633+
(res isa TracedRArray || res isa TracedRNumber) && (res = [res])
1634+
MLIR.Dialects.func.return_([r.mlir_data for r in res])
1635+
res
1636+
finally
1637+
MLIR.IR.deactivate!(fnbody)
1638+
end
1639+
1640+
comp_func = MLIR.IR.block!(MLIR.IR.body(mod)) do
1641+
return MLIR.Dialects.func.func_(;
1642+
sym_name=string(f) * "_batch",
1643+
function_type=MLIR.IR.FunctionType(in_tys, [mlir_type(r) for r in result]),
1644+
body=MLIR.IR.Region(),
1645+
sym_visibility,
1646+
)
1647+
end
1648+
MLIR.API.mlirRegionTakeBody(MLIR.IR.region(comp_func, 1), MLIR.IR.region(func, 1))
1649+
MLIR.API.mlirOperationDestroy(func.operation)
1650+
func.operation = MLIR.API.MlirOperation(C_NULL)
1651+
1652+
fname = Reactant.TracedUtils.get_attribute_by_name(comp_func, "sym_name")
1653+
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
1654+
1655+
batch_inputs = [x.mlir_data for x in args]
1656+
out_tys = [
1657+
MLIR.IR.TensorType((B, size(r)...), MLIR.IR.Type(Reactant.unwrapped_eltype(r))) for
1658+
r in result
1659+
]
1660+
1661+
op = MLIR.Dialects.enzyme.batch(
1662+
batch_inputs;
1663+
outputs=out_tys,
1664+
fn=fname,
1665+
batch_shape=MLIR.IR.DenseArrayAttribute(Int64[B]),
1666+
)
1667+
1668+
return [
1669+
TracedRArray{Reactant.unwrapped_eltype(r),ndims(r) + 1}(
1670+
(), MLIR.IR.result(op, i), (B, size(r)...)
1671+
) for (i, r) in enumerate(result)
1672+
]
1673+
end
1674+
15601675
end # module Ops

0 commit comments

Comments
 (0)