@@ -1557,4 +1557,90 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
1557
1557
)
1558
1558
end
1559
1559
1560
+ # XXX : kwargs
1561
+ # XXX : some of the args are not batched
1562
+ # XXX : Arbitrary dimensions for batching
1563
+ # XXX : Out-axis
1564
+ # XXX : Multiple arg return
1565
+ function batch (f, args:: Vector{<:TracedRArray} )
1566
+ batch_sizes = [size (x, 1 ) for x in args]
1567
+ @assert allequal (batch_sizes) " batching dimensions must be equal"
1568
+ B = first (batch_sizes)
1569
+
1570
+ in_tys = [
1571
+ MLIR. IR. TensorType (size (arg)[2 : end ], MLIR. IR. Type (Reactant. unwrapped_eltype (arg)))
1572
+ for arg in args
1573
+ ]
1574
+
1575
+ sym_visibility = MLIR. IR. Attribute (" private" )
1576
+
1577
+ mod = MLIR. IR. mmodule ()
1578
+ func = MLIR. IR. block! (MLIR. IR. body (mod)) do
1579
+ return MLIR. Dialects. func. func_ (;
1580
+ sym_name= string (f) * " _batch_tmp" ,
1581
+ function_type= MLIR. IR. FunctionType (in_tys, []),
1582
+ body= MLIR. IR. Region (),
1583
+ sym_visibility,
1584
+ )
1585
+ end
1586
+ fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location () for arg in args])
1587
+ push! (MLIR. IR. region (func, 1 ), fnbody)
1588
+
1589
+ linear_args = [
1590
+ TracedRArray {Reactant.unwrapped_eltype(arg),ndims(arg) - 1} (
1591
+ (), nothing , size (arg)[2 : end ]
1592
+ ) for arg in args
1593
+ ]
1594
+
1595
+ MLIR. IR. activate! (fnbody)
1596
+ result = try
1597
+ for (i, arg) in enumerate (linear_args)
1598
+ raw_arg = MLIR. IR. argument (fnbody, i)
1599
+ Reactant. TracedUtils. set_mlir_data! (arg, raw_arg)
1600
+ end
1601
+ # XXX : call_with_reactant is not working here?
1602
+ # ERROR: type Nothing has no field stmts
1603
+ # res = Reactant.call_with_reactant(f, linear_args...)
1604
+ res = f (linear_args... )
1605
+ @assert res isa TracedRArray
1606
+ MLIR. Dialects. func. return_ ([res. mlir_data])
1607
+ res
1608
+ finally
1609
+ MLIR. IR. deactivate! (fnbody)
1610
+ end
1611
+
1612
+ comp_func = MLIR. IR. block! (MLIR. IR. body (mod)) do
1613
+ return MLIR. Dialects. func. func_ (;
1614
+ sym_name= string (f) * " _batch" ,
1615
+ function_type= MLIR. IR. FunctionType (in_tys, [mlir_type (result)]),
1616
+ body= MLIR. IR. Region (),
1617
+ sym_visibility,
1618
+ )
1619
+ end
1620
+ MLIR. API. mlirRegionTakeBody (MLIR. IR. region (comp_func, 1 ), MLIR. IR. region (func, 1 ))
1621
+ MLIR. API. mlirOperationDestroy (func. operation)
1622
+ func. operation = MLIR. API. MlirOperation (C_NULL )
1623
+
1624
+ fname = Reactant. TracedUtils. get_attribute_by_name (comp_func, " sym_name" )
1625
+ fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
1626
+
1627
+ batch_inputs = [x. mlir_data for x in args]
1628
+ output_shape = (B, size (result)... )
1629
+ out_tys = [
1630
+ MLIR. IR. TensorType (output_shape, MLIR. IR. Type (Reactant. unwrapped_eltype (result)))
1631
+ ]
1632
+
1633
+ res = MLIR. Dialects. enzyme. batch (
1634
+ batch_inputs;
1635
+ outputs= out_tys,
1636
+ fn= fname,
1637
+ batch_shape= MLIR. IR. DenseArrayAttribute (Int64[B]),
1638
+ )
1639
+
1640
+ res = MLIR. IR. result (res, 1 )
1641
+ return TracedRArray {Reactant.unwrapped_eltype(result),ndims(result) + 1} (
1642
+ (), res, output_shape
1643
+ )
1644
+ end
1645
+
1560
1646
end # module Ops
0 commit comments