@@ -1557,4 +1557,119 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
1557
1557
)
1558
1558
end
1559
1559
1560
+ function batch (
1561
+ f,
1562
+ args:: Vector{<:TracedRArray} ,
1563
+ batch_dims:: Vector{Union{Int,Nothing}} ,
1564
+ result_dims:: Union{Vector{Int},Nothing} = nothing ,
1565
+ )
1566
+ @assert length (batch_dims) == length (args)
1567
+
1568
+ batch_sizes = [dim === nothing ? 1 : size (x, dim) for (x, dim) in zip (args, batch_dims)]
1569
+ filter! (x -> x != 1 , batch_sizes)
1570
+ @assert allequal (batch_sizes) " batching dimensions must be equal"
1571
+ B = length (batch_sizes) == 0 ? 1 : first (batch_sizes)
1572
+
1573
+ args = map (zip (args, batch_dims)) do (arg, dim)
1574
+ if dim === nothing
1575
+ return broadcast_in_dim (arg, collect (1 : ndims (arg)) .+ 1 , Int64[B, size (arg)... ])
1576
+ end
1577
+ order = collect (1 : ndims (arg))
1578
+ order[dim] = 1
1579
+ order[1 ] = dim
1580
+ return permutedims (arg, order)
1581
+ end
1582
+ results = batch (f, args)
1583
+ result_dims === nothing && (result_dims = ones (Int64, length (results)))
1584
+ @assert length (results) == length (result_dims)
1585
+ return map (zip (results, result_dims)) do (result, dim)
1586
+ order = collect (1 : ndims (result))
1587
+ order[dim] = 1
1588
+ order[1 ] = dim
1589
+ return permutedims (result, order)
1590
+ end
1591
+ end
1592
+
1593
+ function batch (f, args:: Vector{<:TracedRArray} )
1594
+ batch_sizes = [size (x, 1 ) for x in args]
1595
+ @assert allequal (batch_sizes) " batching dimensions must be equal"
1596
+ B = first (batch_sizes)
1597
+
1598
+ in_tys = [
1599
+ MLIR. IR. TensorType (size (arg)[2 : end ], MLIR. IR. Type (Reactant. unwrapped_eltype (arg)))
1600
+ for arg in args
1601
+ ]
1602
+
1603
+ sym_visibility = MLIR. IR. Attribute (" private" )
1604
+
1605
+ mod = MLIR. IR. mmodule ()
1606
+ func = MLIR. IR. block! (MLIR. IR. body (mod)) do
1607
+ return MLIR. Dialects. func. func_ (;
1608
+ sym_name= string (f) * " _batch_tmp" ,
1609
+ function_type= MLIR. IR. FunctionType (in_tys, []),
1610
+ body= MLIR. IR. Region (),
1611
+ sym_visibility,
1612
+ )
1613
+ end
1614
+ fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location () for arg in args])
1615
+ push! (MLIR. IR. region (func, 1 ), fnbody)
1616
+
1617
+ linear_args = [
1618
+ TracedRArray {Reactant.unwrapped_eltype(arg),ndims(arg) - 1} (
1619
+ (), nothing , size (arg)[2 : end ]
1620
+ ) for arg in args
1621
+ ]
1622
+
1623
+ MLIR. IR. activate! (fnbody)
1624
+ result = try
1625
+ for (i, arg) in enumerate (linear_args)
1626
+ raw_arg = MLIR. IR. argument (fnbody, i)
1627
+ Reactant. TracedUtils. set_mlir_data! (arg, raw_arg)
1628
+ end
1629
+ # XXX : call_with_reactant is not working here?
1630
+ # ERROR: type Nothing has no field stmts
1631
+ # res = Reactant.call_with_reactant(f, linear_args...)
1632
+ res = f (linear_args... )
1633
+ (res isa TracedRArray || res isa TracedRNumber) && (res = [res])
1634
+ MLIR. Dialects. func. return_ ([r. mlir_data for r in res])
1635
+ res
1636
+ finally
1637
+ MLIR. IR. deactivate! (fnbody)
1638
+ end
1639
+
1640
+ comp_func = MLIR. IR. block! (MLIR. IR. body (mod)) do
1641
+ return MLIR. Dialects. func. func_ (;
1642
+ sym_name= string (f) * " _batch" ,
1643
+ function_type= MLIR. IR. FunctionType (in_tys, [mlir_type (r) for r in result]),
1644
+ body= MLIR. IR. Region (),
1645
+ sym_visibility,
1646
+ )
1647
+ end
1648
+ MLIR. API. mlirRegionTakeBody (MLIR. IR. region (comp_func, 1 ), MLIR. IR. region (func, 1 ))
1649
+ MLIR. API. mlirOperationDestroy (func. operation)
1650
+ func. operation = MLIR. API. MlirOperation (C_NULL )
1651
+
1652
+ fname = Reactant. TracedUtils. get_attribute_by_name (comp_func, " sym_name" )
1653
+ fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
1654
+
1655
+ batch_inputs = [x. mlir_data for x in args]
1656
+ out_tys = [
1657
+ MLIR. IR. TensorType ((B, size (r)... ), MLIR. IR. Type (Reactant. unwrapped_eltype (r))) for
1658
+ r in result
1659
+ ]
1660
+
1661
+ op = MLIR. Dialects. enzyme. batch (
1662
+ batch_inputs;
1663
+ outputs= out_tys,
1664
+ fn= fname,
1665
+ batch_shape= MLIR. IR. DenseArrayAttribute (Int64[B]),
1666
+ )
1667
+
1668
+ return [
1669
+ TracedRArray {Reactant.unwrapped_eltype(r),ndims(r) + 1} (
1670
+ (), MLIR. IR. result (op, i), (B, size (r)... )
1671
+ ) for (i, r) in enumerate (result)
1672
+ ]
1673
+ end
1674
+
1560
1675
end # module Ops
0 commit comments