@@ -1557,4 +1557,116 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
1557
1557
)
1558
1558
end
1559
1559
1560
+ @noinline function batch (
1561
+ f,
1562
+ args:: Vector{<:TracedRArray} ,
1563
+ batch_dims:: Vector{Union{Int,Nothing}} ,
1564
+ result_dims:: Union{Vector{Int},Nothing} = nothing ,
1565
+ )
1566
+ @assert length (batch_dims) == length (args)
1567
+
1568
+ batch_sizes = [dim === nothing ? 1 : size (x, dim) for (x, dim) in zip (args, batch_dims)]
1569
+ filter! (x -> x != 1 , batch_sizes)
1570
+ @assert allequal (batch_sizes) " batching dimensions must be equal"
1571
+ B = length (batch_sizes) == 0 ? 1 : first (batch_sizes)
1572
+
1573
+ args = map (zip (args, batch_dims)) do (arg, dim)
1574
+ if dim === nothing
1575
+ return broadcast_in_dim (arg, collect (1 : ndims (arg)) .+ 1 , Int64[B, size (arg)... ])
1576
+ end
1577
+ order = collect (1 : ndims (arg))
1578
+ order[dim] = 1
1579
+ order[1 ] = dim
1580
+ return permutedims (arg, order)
1581
+ end
1582
+ results = batch (f, args)
1583
+ result_dims === nothing && (result_dims = ones (Int64, length (results)))
1584
+ @assert length (results) == length (result_dims)
1585
+ return map (zip (results, result_dims)) do (result, dim)
1586
+ order = collect (1 : ndims (result))
1587
+ order[dim] = 1
1588
+ order[1 ] = dim
1589
+ return permutedims (result, order)
1590
+ end
1591
+ end
1592
+
1593
+ @noinline function batch (f, args:: Vector{<:TracedRArray} )
1594
+ batch_sizes = [size (x, 1 ) for x in args]
1595
+ @assert allequal (batch_sizes) " batching dimensions must be equal"
1596
+ B = first (batch_sizes)
1597
+
1598
+ in_tys = [
1599
+ MLIR. IR. TensorType (size (arg)[2 : end ], MLIR. IR. Type (Reactant. unwrapped_eltype (arg)))
1600
+ for arg in args
1601
+ ]
1602
+
1603
+ sym_visibility = MLIR. IR. Attribute (" private" )
1604
+
1605
+ mod = MLIR. IR. mmodule ()
1606
+ func = MLIR. IR. block! (MLIR. IR. body (mod)) do
1607
+ return MLIR. Dialects. func. func_ (;
1608
+ sym_name= string (f) * " _batch_tmp" ,
1609
+ function_type= MLIR. IR. FunctionType (in_tys, []),
1610
+ body= MLIR. IR. Region (),
1611
+ sym_visibility,
1612
+ )
1613
+ end
1614
+ fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location () for arg in args])
1615
+ push! (MLIR. IR. region (func, 1 ), fnbody)
1616
+
1617
+ linear_args = [
1618
+ TracedRArray {Reactant.unwrapped_eltype(arg),ndims(arg) - 1} (
1619
+ (), nothing , size (arg)[2 : end ]
1620
+ ) for arg in args
1621
+ ]
1622
+
1623
+ MLIR. IR. activate! (fnbody)
1624
+ result = try
1625
+ for (i, arg) in enumerate (linear_args)
1626
+ raw_arg = MLIR. IR. argument (fnbody, i)
1627
+ Reactant. TracedUtils. set_mlir_data! (arg, raw_arg)
1628
+ end
1629
+ res = Reactant. call_with_reactant (f, linear_args... )
1630
+ (res isa TracedRArray || res isa TracedRNumber) && (res = [res])
1631
+ MLIR. Dialects. func. return_ ([r. mlir_data for r in res])
1632
+ res
1633
+ finally
1634
+ MLIR. IR. deactivate! (fnbody)
1635
+ end
1636
+
1637
+ comp_func = MLIR. IR. block! (MLIR. IR. body (mod)) do
1638
+ return MLIR. Dialects. func. func_ (;
1639
+ sym_name= string (f) * " _batch" ,
1640
+ function_type= MLIR. IR. FunctionType (in_tys, [mlir_type (r) for r in result]),
1641
+ body= MLIR. IR. Region (),
1642
+ sym_visibility,
1643
+ )
1644
+ end
1645
+ MLIR. API. mlirRegionTakeBody (MLIR. IR. region (comp_func, 1 ), MLIR. IR. region (func, 1 ))
1646
+ MLIR. API. mlirOperationDestroy (func. operation)
1647
+ func. operation = MLIR. API. MlirOperation (C_NULL )
1648
+
1649
+ fname = Reactant. TracedUtils. get_attribute_by_name (comp_func, " sym_name" )
1650
+ fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
1651
+
1652
+ batch_inputs = [x. mlir_data for x in args]
1653
+ out_tys = [
1654
+ MLIR. IR. TensorType ((B, size (r)... ), MLIR. IR. Type (Reactant. unwrapped_eltype (r))) for
1655
+ r in result
1656
+ ]
1657
+
1658
+ op = MLIR. Dialects. enzyme. batch (
1659
+ batch_inputs;
1660
+ outputs= out_tys,
1661
+ fn= fname,
1662
+ batch_shape= MLIR. IR. DenseArrayAttribute (Int64[B]),
1663
+ )
1664
+
1665
+ return [
1666
+ TracedRArray {Reactant.unwrapped_eltype(r),ndims(r) + 1} (
1667
+ (), MLIR. IR. result (op, i), (B, size (r)... )
1668
+ ) for (i, r) in enumerate (result)
1669
+ ]
1670
+ end
1671
+
1560
1672
end # module Ops
0 commit comments