Skip to content

Commit 02bdfb1

Browse files
committed
docs: setup batching tutorial
1 parent 6e684c3 commit 02bdfb1

File tree

9 files changed

+48
-9
lines changed

9 files changed

+48
-9
lines changed

docs/make.jl

+10-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@ examples = [
2424
pages = [
2525
"Reactant.jl" => "index.md",
2626
"Introduction" => ["Getting Started" => "introduction/index.md"],
27-
"Tutorials" =>
28-
["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"],
27+
"Tutorials" => [
28+
"Overview" => "tutorials/index.md",
29+
"Profiling" => "tutorials/profiling.md",
30+
"Batching Functions with `Reactant.Ops.batch`" => "tutorials/batching.md",
31+
],
2932
"API Reference" => [
3033
"Reactant API" => "api/api.md",
3134
"Ops" => "api/ops.md",
@@ -38,6 +41,11 @@ pages = [
3841
"Func" => "api/func.md",
3942
"StableHLO" => "api/stablehlo.md",
4043
"VHLO" => "api/vhlo.md",
44+
"GPU" => "api/gpu.md",
45+
"LLVM" => "api/llvm.md",
46+
"NVVM" => "api/nvvm.md",
47+
"TPU" => "api/tpu.md",
48+
"Triton" => "api/triton.md",
4149
],
4250
"MLIR API" => "api/mlirc.md",
4351
"XLA" => "api/xla.md",

docs/src/.vitepress/config.mts

+9-1
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,12 @@ export default defineConfig({
5656
{
5757
text: "Tutorials",
5858
items: [
59-
{text: "Overview", link: "/tutorials/"},
59+
{ text: "Overview", link: "/tutorials/" },
6060
{text: "Profiling", link: "/tutorials/profiling"},
61+
{
62+
text: "Batching Functions with `Reactant.Ops.batch`",
63+
link: "/tutorials/batching"
64+
},
6165
],
6266
},
6367
{
@@ -112,6 +116,10 @@ export default defineConfig({
112116
items: [
113117
{ text: "Overview", link: "/tutorials/" },
114118
{ text: "Profiling", link: "/tutorials/profiling" },
119+
{
120+
text: "Batching Functions with `Reactant.Ops.batch`",
121+
link: "/tutorials/batching",
122+
},
115123
],
116124
},
117125
"/api/": {

docs/src/tutorials/batching.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# [Batching Functions with [`Reactant.Ops.batch`](@ref)](@id batching-tutorial)
2+
3+

docs/src/tutorials/index.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Tutorials
22

33
- [Profiling](@ref profiling).
4+
- [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial)
45

56
We are currently working on adding more tutorials to Reactant!! Please check back soon!

src/Compiler.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ function codegen_unflatten!(
776776
paths = (
777777
(
778778
p for p in Reactant.TracedUtils.get_paths(result) if
779-
length(p) 1 && (p[1] == :result || p[1] == :resargs)
779+
length(p) > 0 && (p[1] == :result || p[1] == :resargs)
780780
)...,
781781
)
782782
for path in paths
@@ -846,7 +846,7 @@ function codegen_unflatten!(
846846
paths = (
847847
(
848848
p for p in Reactant.TracedUtils.get_paths(result) if
849-
length(p) 1 && (p[1] == :result || p[1] == :resargs || p[1] == :args)
849+
length(p) > 0 && (p[1] == :result || p[1] == :resargs || p[1] == :args)
850850
)...,
851851
)
852852

src/Ops.jl

+20-2
Original file line numberDiff line numberDiff line change
@@ -2013,8 +2013,24 @@ end
20132013
# This function assumes that the last dimension of each element is the batch dimension by
20142014
# default. This is the standard Julia ordering for batching. We permutedims the ordering to
20152015
# make sure the first dimension is the batch dimension when calling `batch_internal` below.
2016-
# XXX: Mutation inside a batched function is not supported yet (need to set the results
2017-
# correctly)
2016+
"""
2017+
batch(f, args...; batch_dims=nothing, result_dims=nothing)
2018+
2019+
Map `f` over the arguments `args` along the batch dimensions `batch_dims` and return the results with the corresponding batch dimensions specified by `result_dims`. (For users
2020+
familiar with `jax`, this operation corresponds to `jax.vmap`.)
2021+
2022+
If `batch_dims` is `nothing`, we assume that the last dimension of each leaf of `args` is the batch dimension. If `result_dims` is `nothing`, we assume that the last dimension of each leaf of the returned values is the batch dimension.
2023+
2024+
To avoid batching a specific leaf, pass `nothing` for the corresponding `batch_dims`.
2025+
2026+
## Examples
2027+
2028+
For usage examples, see the [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) tutorial.
2029+
2030+
!!! danger
2031+
2032+
Mutation inside a batched function is not supported yet and will lead to unexpected results.
2033+
"""
20182034
@noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing)
20192035
batch_sizes = Int64[]
20202036
batching_dims = if batch_dims === nothing
@@ -2060,6 +2076,8 @@ end
20602076
end
20612077

20622078
return fmap(results, result_dims) do result, dim
2079+
@assert dim !== nothing "Result batch dimension cannot be `nothing`"
2080+
20632081
order = collect(Int64, 1:ndims(result))
20642082
order[dim] = 1
20652083
order[1] = dim

src/Reactant.jl

-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ using Functors: @leaf
99
using Adapt: Adapt, WrappedArray
1010
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`
1111

12-
using Functors: @leaf
13-
1412
export @allowscalar # re-exported from GPUArraysCore
1513

1614
# auxiliary types and functions

test/batching.jl

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
using Reactant, Test
2+

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5757
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
5858
@safetestset "Control Flow" include("control_flow.jl")
5959
@safetestset "Sorting" include("sorting.jl")
60+
@safetestset "Batching" include("batching.jl")
6061
end
6162

6263
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"

0 commit comments

Comments
 (0)