@@ -1558,18 +1558,30 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
1558
1558
end
1559
1559
1560
1560
# XXX : Support linearization and de-linearization
1561
- # XXX : some of the args are not batched (use nothing)
1562
1561
function batch (
1563
- f, args:: Vector{<:TracedRArray} , batch_dims:: Vector{Int} , result_dims:: Vector{Int}
1562
+ f,
1563
+ args:: Vector{<:TracedRArray} ,
1564
+ batch_dims:: Vector{Union{Int,Nothing}} ,
1565
+ result_dims:: Union{Vector{Int},Nothing} = nothing ,
1564
1566
)
1565
1567
@assert length (batch_dims) == length (args)
1568
+
1569
+ batch_sizes = [dim === nothing ? 1 : size (x, dim) for (x, dim) in zip (args, batch_dims)]
1570
+ filter! (x -> x != 1 , batch_sizes)
1571
+ @assert allequal (batch_sizes) " batching dimensions must be equal"
1572
+ B = length (batch_sizes) == 0 ? 1 : first (batch_sizes)
1573
+
1566
1574
args = map (zip (args, batch_dims)) do (arg, dim)
1575
+ if dim === nothing
1576
+ return broadcast_in_dim (arg, collect (1 : ndims (arg)) .+ 1 , Int64[B, size (arg)... ])
1577
+ end
1567
1578
order = collect (1 : ndims (arg))
1568
1579
order[dim] = 1
1569
1580
order[1 ] = dim
1570
1581
return permutedims (arg, order)
1571
1582
end
1572
1583
results = batch (f, args)
1584
+ result_dims === nothing && (result_dims = ones (Int64, length (results)))
1573
1585
@assert length (results) == length (result_dims)
1574
1586
return map (zip (results, result_dims)) do (result, dim)
1575
1587
order = collect (1 : ndims (result))
0 commit comments