@@ -1557,19 +1557,30 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
1557
1557
)
1558
1558
end
1559
1559
1560
- # XXX : Support linearization and de-linearization
1561
- # XXX : some of the args are not batched (use nothing)
1562
1560
function batch (
1563
- f, args:: Vector{<:TracedRArray} , batch_dims:: Vector{Int} , result_dims:: Vector{Int}
1561
+ f,
1562
+ args:: Vector{<:TracedRArray} ,
1563
+ batch_dims:: Vector{Union{Int,Nothing}} ,
1564
+ result_dims:: Union{Vector{Int},Nothing} = nothing ,
1564
1565
)
1565
1566
@assert length (batch_dims) == length (args)
1567
+
1568
+ batch_sizes = [dim === nothing ? 1 : size (x, dim) for (x, dim) in zip (args, batch_dims)]
1569
+ filter! (x -> x != 1 , batch_sizes)
1570
+ @assert allequal (batch_sizes) " batching dimensions must be equal"
1571
+ B = length (batch_sizes) == 0 ? 1 : first (batch_sizes)
1572
+
1566
1573
args = map (zip (args, batch_dims)) do (arg, dim)
1574
+ if dim === nothing
1575
+ return broadcast_in_dim (arg, collect (1 : ndims (arg)) .+ 1 , Int64[B, size (arg)... ])
1576
+ end
1567
1577
order = collect (1 : ndims (arg))
1568
1578
order[dim] = 1
1569
1579
order[1 ] = dim
1570
1580
return permutedims (arg, order)
1571
1581
end
1572
1582
results = batch (f, args)
1583
+ result_dims === nothing && (result_dims = ones (Int64, length (results)))
1573
1584
@assert length (results) == length (result_dims)
1574
1585
return map (zip (results, result_dims)) do (result, dim)
1575
1586
order = collect (1 : ndims (result))
0 commit comments