Skip to content

Commit 945408c

Browse files
committed
feat: support passing in a nothing in dims
1 parent c2b8e46 commit 945408c

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

src/Ops.jl

+14-3
Original file line numberDiff line numberDiff line change
@@ -1557,19 +1557,30 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
15571557
)
15581558
end
15591559

1560-
# XXX: Support linearization and de-linearization
1561-
# XXX: some of the args are not batched (use nothing)
15621560
function batch(
1563-
f, args::Vector{<:TracedRArray}, batch_dims::Vector{Int}, result_dims::Vector{Int}
1561+
f,
1562+
args::Vector{<:TracedRArray},
1563+
batch_dims::Vector{Union{Int,Nothing}},
1564+
result_dims::Union{Vector{Int},Nothing}=nothing,
15641565
)
15651566
@assert length(batch_dims) == length(args)
1567+
1568+
batch_sizes = [dim === nothing ? 1 : size(x, dim) for (x, dim) in zip(args, batch_dims)]
1569+
filter!(x -> x != 1, batch_sizes)
1570+
@assert allequal(batch_sizes) "batching dimensions must be equal"
1571+
B = length(batch_sizes) == 0 ? 1 : first(batch_sizes)
1572+
15661573
args = map(zip(args, batch_dims)) do (arg, dim)
1574+
if dim === nothing
1575+
return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...])
1576+
end
15671577
order = collect(1:ndims(arg))
15681578
order[dim] = 1
15691579
order[1] = dim
15701580
return permutedims(arg, order)
15711581
end
15721582
results = batch(f, args)
1583+
result_dims === nothing && (result_dims = ones(Int64, length(results)))
15731584
@assert length(results) == length(result_dims)
15741585
return map(zip(results, result_dims)) do (result, dim)
15751586
order = collect(1:ndims(result))

0 commit comments

Comments
 (0)