Skip to content

Commit 231e064

Browse files
committed
feat: support passing in a nothing in dims
1 parent c2b8e46 commit 231e064

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

src/Ops.jl

+14-2
Original file line numberDiff line numberDiff line change
@@ -1558,18 +1558,30 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
15581558
end
15591559

15601560
# XXX: Support linearization and de-linearization
1561-
# XXX: some of the args are not batched (use nothing)
15621561
function batch(
1563-
f, args::Vector{<:TracedRArray}, batch_dims::Vector{Int}, result_dims::Vector{Int}
1562+
f,
1563+
args::Vector{<:TracedRArray},
1564+
batch_dims::Vector{Union{Int,Nothing}},
1565+
result_dims::Union{Vector{Int},Nothing}=nothing,
15641566
)
15651567
@assert length(batch_dims) == length(args)
1568+
1569+
batch_sizes = [dim === nothing ? 1 : size(x, dim) for (x, dim) in zip(args, batch_dims)]
1570+
filter!(x -> x != 1, batch_sizes)
1571+
@assert allequal(batch_sizes) "batching dimensions must be equal"
1572+
B = length(batch_sizes) == 0 ? 1 : first(batch_sizes)
1573+
15661574
args = map(zip(args, batch_dims)) do (arg, dim)
1575+
if dim === nothing
1576+
return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...])
1577+
end
15671578
order = collect(1:ndims(arg))
15681579
order[dim] = 1
15691580
order[1] = dim
15701581
return permutedims(arg, order)
15711582
end
15721583
results = batch(f, args)
1584+
result_dims === nothing && (result_dims = ones(Int64, length(results)))
15731585
@assert length(results) == length(result_dims)
15741586
return map(zip(results, result_dims)) do (result, dim)
15751587
order = collect(1:ndims(result))

0 commit comments

Comments
 (0)