Skip to content

Commit 6e8f4b0

Browse files
avik-palwsmoses
authored andcommitted
docs: setup batching tutorial
1 parent ebe5727 commit 6e8f4b0

File tree

9 files changed

+53
-8
lines changed

9 files changed

+53
-8
lines changed

docs/make.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ examples = [
2424
pages = [
2525
"Reactant.jl" => "index.md",
2626
"Introduction" => ["Getting Started" => "introduction/index.md"],
27-
"Tutorials" => ["Overview" => "tutorials/index.md"],
27+
"Tutorials" => [
28+
"Overview" => "tutorials/index.md",
29+
"Batching Functions with `Reactant.Ops.batch`" => "tutorials/batching.md",
30+
],
2831
"API Reference" => [
2932
"Reactant API" => "api/api.md",
3033
"Ops" => "api/ops.md",
@@ -37,6 +40,11 @@ pages = [
3740
"Func" => "api/func.md",
3841
"StableHLO" => "api/stablehlo.md",
3942
"VHLO" => "api/vhlo.md",
43+
"GPU" => "api/gpu.md",
44+
"LLVM" => "api/llvm.md",
45+
"NVVM" => "api/nvvm.md",
46+
"TPU" => "api/tpu.md",
47+
"Triton" => "api/triton.md",
4048
],
4149
"MLIR API" => "api/mlirc.md",
4250
"XLA" => "api/xla.md",

docs/src/.vitepress/config.mts

+14-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,16 @@ export default defineConfig({
5353
{ text: "Home", link: "/" },
5454
{ text: "Getting Started", link: "/introduction" },
5555
{ text: "Benchmarks", link: "https://enzymead.github.io/Reactant.jl/benchmarks/" },
56-
{ text: "Tutorials", link: "/tutorials/" },
56+
{
57+
text: "Tutorials",
58+
items: [
59+
{ text: "Overview", link: "/tutorials/" },
60+
{
61+
text: "Batching Functions with `Reactant.Ops.batch`",
62+
link: "/tutorials/batching"
63+
},
64+
],
65+
},
5766
{
5867
text: "API",
5968
items: [
@@ -105,6 +114,10 @@ export default defineConfig({
105114
collapsed: false,
106115
items: [
107116
{ text: "Overview", link: "/tutorials/" },
117+
{
118+
text: "Batching Functions with `Reactant.Ops.batch`",
119+
link: "/tutorials/batching",
120+
},
108121
],
109122
},
110123
"/api/": {

docs/src/tutorials/batching.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# [Batching Functions with [`Reactant.Ops.batch`](@ref)](@id batching-tutorial)
2+
3+

docs/src/tutorials/index.md

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
# Tutorials
22

33
We are currently working on adding tutorials to Reactant!! Please check back soon!
4+
5+
- [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial)

src/Compiler.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ function codegen_unflatten!(
776776
paths = (
777777
(
778778
p for p in Reactant.TracedUtils.get_paths(result) if
779-
length(p) 1 && (p[1] == :result || p[1] == :resargs)
779+
length(p) > 0 && (p[1] == :result || p[1] == :resargs)
780780
)...,
781781
)
782782
for path in paths
@@ -846,7 +846,7 @@ function codegen_unflatten!(
846846
paths = (
847847
(
848848
p for p in Reactant.TracedUtils.get_paths(result) if
849-
length(p) 1 && (p[1] == :result || p[1] == :resargs || p[1] == :args)
849+
length(p) > 0 && (p[1] == :result || p[1] == :resargs || p[1] == :args)
850850
)...,
851851
)
852852

src/Ops.jl

+20-2
Original file line numberDiff line numberDiff line change
@@ -2013,8 +2013,24 @@ end
20132013
# This function assumes that the last dimension of each element is the batch dimension by
20142014
# default. This is the standard Julia ordering for batching. We permutedims the ordering to
20152015
# make sure the first dimension is the batch dimension when calling `batch_internal` below.
2016-
# XXX: Mutation inside a batched function is not supported yet (need to set the results
2017-
# correctly)
2016+
"""
2017+
batch(f, args...; batch_dims=nothing, result_dims=nothing)
2018+
2019+
Map `f` over the arguments `args` along the batch dimensions `batch_dims` and return the results with the corresponding batch dimensions specified by `result_dims`. (For users
2020+
familiar with `jax`, this operation corresponds to `jax.vmap`.)
2021+
2022+
If `batch_dims` is `nothing`, we assume that the last dimension of each leaf of `args` is the batch dimension. If `result_dims` is `nothing`, we assume that the last dimension of each leaf of the returned values is the batch dimension.
2023+
2024+
To avoid batching a specific leaf, pass `nothing` for the corresponding `batch_dims`.
2025+
2026+
## Examples
2027+
2028+
For usage examples, see the [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) tutorial.
2029+
2030+
!!! danger
2031+
2032+
Mutation inside a batched function is not supported yet and will lead to unexpected results.
2033+
"""
20182034
@noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing)
20192035
batch_sizes = Int64[]
20202036
batching_dims = if batch_dims === nothing
@@ -2060,6 +2076,8 @@ end
20602076
end
20612077

20622078
return fmap(results, result_dims) do result, dim
2079+
@assert dim !== nothing "Result batch dimension cannot be `nothing`"
2080+
20632081
order = collect(Int64, 1:ndims(result))
20642082
order[dim] = 1
20652083
order[1] = dim

src/Reactant.jl

-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ using Functors: @leaf
99
using Adapt: Adapt, WrappedArray
1010
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`
1111

12-
using Functors: @leaf
13-
1412
export @allowscalar # re-exported from GPUArraysCore
1513

1614
# auxiliary types and functions

test/batching.jl

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
using Reactant, Test
2+

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5757
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
5858
@safetestset "Control Flow" include("control_flow.jl")
5959
@safetestset "Sorting" include("sorting.jl")
60+
@safetestset "Batching" include("batching.jl")
6061
end
6162

6263
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"

0 commit comments

Comments
 (0)