Skip to content

Commit 538ef60

Browse files
committed
docs: setup batching tutorial
1 parent 8c4e3d8 commit 538ef60

File tree

7 files changed

+51
-4
lines changed

7 files changed

+51
-4
lines changed

docs/make.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ examples = [
2424
pages = [
2525
"Reactant.jl" => "index.md",
2626
"Introduction" => ["Getting Started" => "introduction/index.md"],
27-
"Tutorials" => ["Overview" => "tutorials/index.md"],
27+
"Tutorials" => [
28+
"Overview" => "tutorials/index.md",
29+
"Batching Functions with `Reactant.Ops.batch`" => "tutorials/batching.md",
30+
],
2831
"API Reference" => [
2932
"Reactant API" => "api/api.md",
3033
"Ops" => "api/ops.md",
@@ -37,6 +40,11 @@ pages = [
3740
"Func" => "api/func.md",
3841
"StableHLO" => "api/stablehlo.md",
3942
"VHLO" => "api/vhlo.md",
43+
"GPU" => "api/gpu.md",
44+
"LLVM" => "api/llvm.md",
45+
"NVVM" => "api/nvvm.md",
46+
"TPU" => "api/tpu.md",
47+
"Triton" => "api/triton.md",
4048
],
4149
"MLIR API" => "api/mlirc.md",
4250
"XLA" => "api/xla.md",

docs/src/.vitepress/config.mts

+14-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,16 @@ export default defineConfig({
5353
{ text: "Home", link: "/" },
5454
{ text: "Getting Started", link: "/introduction" },
5555
{ text: "Benchmarks", link: "https://enzymead.github.io/Reactant.jl/benchmarks/" },
56-
{ text: "Tutorials", link: "/tutorials/" },
56+
{
57+
text: "Tutorials",
58+
items: [
59+
{ text: "Overview", link: "/tutorials/" },
60+
{
61+
text: "Batching Functions with `Reactant.Ops.batch`",
62+
link: "/tutorials/batching"
63+
},
64+
],
65+
},
5766
{
5867
text: "API",
5968
items: [
@@ -105,6 +114,10 @@ export default defineConfig({
105114
collapsed: false,
106115
items: [
107116
{ text: "Overview", link: "/tutorials/" },
117+
{
118+
text: "Batching Functions with `Reactant.Ops.batch`",
119+
link: "/tutorials/batching",
120+
},
108121
],
109122
},
110123
"/api/": {

docs/src/tutorials/batching.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# [Batching Functions with [`Reactant.Ops.batch`](@ref)](@id batching-tutorial)
2+
3+

docs/src/tutorials/index.md

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
# Tutorials
22

33
We are currently working on adding tutorials to Reactant!! Please check back soon!
4+
5+
- [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial)

src/Ops.jl

+20-2
Original file line numberDiff line numberDiff line change
@@ -1604,8 +1604,24 @@ end
16041604
# This function assumes that the last dimension of each element is the batch dimension by
16051605
# default. This is the standard Julia ordering for batching. We permutedims the ordering to
16061606
# make sure the first dimension is the batch dimension when calling `batch_internal` below.
1607-
# XXX: Mutation inside a batched function is not supported yet (need to set the results
1608-
# correctly)
1607+
"""
1608+
batch(f, args...; batch_dims=nothing, result_dims=nothing)
1609+
1610+
Map `f` over the arguments `args` along the batch dimensions `batch_dims` and return the results with the corresponding batch dimensions specified by `result_dims`. (For users
1611+
familiar with `jax`, this operation corresponds to `jax.vmap`.)
1612+
1613+
If `batch_dims` is `nothing`, we assume that the last dimension of each leaf of `args` is the batch dimension. If `result_dims` is `nothing`, we assume that the last dimension of each leaf of the returned values is the batch dimension.
1614+
1615+
To avoid batching a specific leaf, pass `nothing` for the corresponding `batch_dims`.
1616+
1617+
## Examples
1618+
1619+
For usage examples, see the [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) tutorial.
1620+
1621+
!!! danger
1622+
1623+
Mutation inside a batched function is not supported yet and will lead to unexpected results.
1624+
"""
16091625
@noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing)
16101626
batch_sizes = Int64[]
16111627
batching_dims = if batch_dims === nothing
@@ -1651,6 +1667,8 @@ end
16511667
end
16521668

16531669
return Functors.fmap(results, result_dims) do result, dim
1670+
@assert dim !== nothing "Result batch dimension cannot be `nothing`"
1671+
16541672
order = collect(Int64, 1:ndims(result))
16551673
order[dim] = 1
16561674
order[1] = dim

test/batching.jl

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
using Reactant, Test
2+

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5757
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
5858
@safetestset "Control Flow" include("control_flow.jl")
5959
@safetestset "Sorting" include("sorting.jl")
60+
@safetestset "Batching" include("batching.jl")
6061
end
6162

6263
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"

0 commit comments

Comments
 (0)