Skip to content

Commit 57f342a

Browse files
committed
feat: add Ops.batch
1 parent dafa186 commit 57f342a

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

src/Ops.jl

+86
Original file line numberDiff line numberDiff line change
@@ -1557,4 +1557,90 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
15571557
)
15581558
end
15591559

1560+
# XXX: kwargs
1561+
# XXX: some of the args are not batched
1562+
# XXX: Arbitrary dimensions for batching
1563+
# XXX: Out-axis
1564+
# XXX: Multiple arg return
1565+
function batch(f, args::Vector{<:TracedRArray})
1566+
batch_sizes = [size(x, 1) for x in args]
1567+
@assert allequal(batch_sizes) "batching dimensions must be equal"
1568+
B = first(batch_sizes)
1569+
1570+
in_tys = [
1571+
MLIR.IR.TensorType(size(arg)[2:end], MLIR.IR.Type(Reactant.unwrapped_eltype(arg)))
1572+
for arg in args
1573+
]
1574+
1575+
sym_visibility = MLIR.IR.Attribute("private")
1576+
1577+
mod = MLIR.IR.mmodule()
1578+
func = MLIR.IR.block!(MLIR.IR.body(mod)) do
1579+
return MLIR.Dialects.func.func_(;
1580+
sym_name=string(f) * "_batch_tmp",
1581+
function_type=MLIR.IR.FunctionType(in_tys, []),
1582+
body=MLIR.IR.Region(),
1583+
sym_visibility,
1584+
)
1585+
end
1586+
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in args])
1587+
push!(MLIR.IR.region(func, 1), fnbody)
1588+
1589+
linear_args = [
1590+
TracedRArray{Reactant.unwrapped_eltype(arg),ndims(arg) - 1}(
1591+
(), nothing, size(arg)[2:end]
1592+
) for arg in args
1593+
]
1594+
1595+
MLIR.IR.activate!(fnbody)
1596+
result = try
1597+
for (i, arg) in enumerate(linear_args)
1598+
raw_arg = MLIR.IR.argument(fnbody, i)
1599+
Reactant.TracedUtils.set_mlir_data!(arg, raw_arg)
1600+
end
1601+
# XXX: call_with_reactant is not working here?
1602+
# ERROR: type Nothing has no field stmts
1603+
# res = Reactant.call_with_reactant(f, linear_args...)
1604+
res = f(linear_args...)
1605+
@assert res isa TracedRArray
1606+
MLIR.Dialects.func.return_([res.mlir_data])
1607+
res
1608+
finally
1609+
MLIR.IR.deactivate!(fnbody)
1610+
end
1611+
1612+
comp_func = MLIR.IR.block!(MLIR.IR.body(mod)) do
1613+
return MLIR.Dialects.func.func_(;
1614+
sym_name=string(f) * "_batch",
1615+
function_type=MLIR.IR.FunctionType(in_tys, [mlir_type(result)]),
1616+
body=MLIR.IR.Region(),
1617+
sym_visibility,
1618+
)
1619+
end
1620+
MLIR.API.mlirRegionTakeBody(MLIR.IR.region(comp_func, 1), MLIR.IR.region(func, 1))
1621+
MLIR.API.mlirOperationDestroy(func.operation)
1622+
func.operation = MLIR.API.MlirOperation(C_NULL)
1623+
1624+
fname = Reactant.TracedUtils.get_attribute_by_name(comp_func, "sym_name")
1625+
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
1626+
1627+
batch_inputs = [x.mlir_data for x in args]
1628+
output_shape = (B, size(result)...)
1629+
out_tys = [
1630+
MLIR.IR.TensorType(output_shape, MLIR.IR.Type(Reactant.unwrapped_eltype(result)))
1631+
]
1632+
1633+
res = MLIR.Dialects.enzyme.batch(
1634+
batch_inputs;
1635+
outputs=out_tys,
1636+
fn=fname,
1637+
batch_shape=MLIR.IR.DenseArrayAttribute(Int64[B]),
1638+
)
1639+
1640+
res = MLIR.IR.result(res, 1)
1641+
return TracedRArray{Reactant.unwrapped_eltype(result),ndims(result) + 1}(
1642+
(), res, output_shape
1643+
)
1644+
end
1645+
15601646
end # module Ops

0 commit comments

Comments
 (0)