Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f09b902

Browse files
authoredMar 18, 2025··
Merge branch 'main' into nounroll
2 parents 3a8317e + 514e506 commit f09b902

19 files changed

+260
-11
lines changed
 

‎Project.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
4-
version = "0.2.44"
4+
version = "0.2.46"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -85,8 +85,8 @@ Preferences = "1.4"
8585
PythonCall = "0.9"
8686
Random = "1.10"
8787
Random123 = "1.7"
88-
ReactantCore = "0.1.5"
89-
Reactant_jll = "0.0.92"
88+
ReactantCore = "0.1.6"
89+
Reactant_jll = "0.0.93"
9090
Scratch = "1.2"
9191
Sockets = "1.10"
9292
SpecialFunctions = "2.4"

‎deps/ReactantExtra/WORKSPACE

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "d39a1fa814d329293645ff9771e68c9e9be8ceae"
12+
ENZYMEXLA_COMMIT = "9fe4c5e3f2ec044db7860ead3d42838589a08104"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(

‎docs/make.jl

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ pages = [
3636
"Overview" => "tutorials/index.md",
3737
"Profiling" => "tutorials/profiling.md",
3838
"Distributed" => "tutorials/multihost.md",
39+
"Local build" => "tutorials/local-build.md",
3940
],
4041
"API Reference" => [
4142
"Reactant API" => "api/api.md",

‎docs/src/.vitepress/config.mts

+2
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ export default defineConfig({
6464
{text: "Overview", link: "/tutorials/"},
6565
{text: "Profiling", link: "/tutorials/profiling"},
6666
{text: "Distributed", link: "/tutorials/multihost"},
67+
{text: "Local build", link: "/tutorials/local-build"},
6768
],
6869
},
6970
{
@@ -124,6 +125,7 @@ export default defineConfig({
124125
{ text: "Overview", link: "/tutorials/" },
125126
{ text: "Profiling", link: "/tutorials/profiling" },
126127
{ text: "Distributed", link: "/tutorials/multihost" },
128+
{ text: "Local build", link: "/tutorials/local-build" },
127129
],
128130
},
129131
"/api/": {

‎docs/src/tutorials/index.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22

33
- [Profiling](@ref profiling).
44
- [Multi-Host Environments](@ref distributed).
5+
- [Local build of ReactantExtra](@ref local-build).
56

67
We are currently working on adding more tutorials to Reactant!! Please check back soon!

‎docs/src/tutorials/local-build.md

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# [Local build of ReactantExtra](@id local-build)
2+
3+
In the `deps/` subdirectory of the Reactant repository there is a script to do local builds of ReactantExtra, including debug builds.
4+
5+
## Requirements
6+
7+
* Julia. If you don't have it already, you can obtain it from the [official Julia website](https://julialang.org/downloads/)
8+
* A reasonably recent C/C++ compiler, ideally GCC 12+.
9+
Older compilers may not work.
10+
* Bazel. If you don't have it already, you can download a build for your platform from [the latest `bazelbuild/bazelisk` release](https://github.com/bazelbuild/bazelisk/releases/latest) and put the `bazel` executable in `PATH`
11+
* not necessary in general, but for debug builds with CUDA support, you'll need a fast linker, like `lld` or `mold`
12+
Binutils `ld` won't work, don't even try using it.
13+
You can obtain `mold` for your platform from the [latest `rui314/mold` release](https://github.com/rui314/mold/releases/latest) and put the `mold` executable in `PATH`
14+
15+
## Building
16+
17+
At a high level, after you `cd` to the `deps/` directory you can run the commands
18+
19+
```bash
20+
julia --project -e 'using Pkg; Pkg.instantiate()' # needed only the first time to install dependencies for this script
21+
julia -O0 --color=yes --project build_local.jl
22+
```
23+
24+
There are a few options you may want to use to tweak the build.
25+
For more information run the command (what's shown below may not be up to date, run the command locally to see the options available to you):
26+
27+
```console
28+
% julia -O0 --project build_local.jl --help
29+
usage: build_local.jl [--debug] [--backend BACKEND]
30+
[--gcc_host_compiler_path GCC_HOST_COMPILER_PATH]
31+
[--cc CC]
32+
[--hermetic_python_version HERMETIC_PYTHON_VERSION]
33+
[--jobs JOBS] [--copt COPT] [--cxxopt CXXOPT]
34+
[--extraopt EXTRAOPT] [--color COLOR] [-h]
35+
36+
optional arguments:
37+
--debug Build with debug mode (-c dbg).
38+
--backend BACKEND Build with the specified backend (auto, cpu,
39+
cuda). (default: "auto")
40+
--gcc_host_compiler_path GCC_HOST_COMPILER_PATH
41+
Path to the gcc host compiler. (default:
42+
"/usr/bin/gcc")
43+
--cc CC (default: "/usr/bin/cc")
44+
--hermetic_python_version HERMETIC_PYTHON_VERSION
45+
Hermetic Python version. (default: "3.10")
46+
--jobs JOBS Number of parallel jobs. (type: Int64,
47+
default: <MAXIMUM NUMBER OF CPUs>)
48+
--copt COPT Options to be passed to the C compiler. Can
49+
be used multiple times.
50+
--cxxopt CXXOPT Options to be passed to the C++ compiler. Can
51+
be used multiple times.
52+
--extraopt EXTRAOPT Extra options to be passed to Bazel. Can be
53+
used multiple times.
54+
--color COLOR Set to `yes` to enable color output, or `no`
55+
to disable it. Defaults to same color setting
56+
as the Julia process. (default: "no")
57+
-h, --help show this help message and exit
58+
```
59+
60+
### Doing a build on a system with restrictions on memory or number of processes
61+
62+
If you try to do the build on certain systems where restrictions are in place on the number of processes or the amount of memory that your user can use (for example the login nodes of clusters), you may have to limit the number of parallel jobs used by Bazel.
63+
By default Bazel will try to use the maximum number of CPUs available on the system; if you need to reduce that, pass the `--jobs JOBS` option.
64+
The Bazel server may be terminated abruptly if using too much memory (e.g. if concurrent compiler processes are cumulatively using a large amount of memory), also in this case reducing the number of parallel jobs may be beneficial.
65+
66+
### CUDA debug build
67+
68+
A CUDA debug build (`--debug --backend=cuda`) requires a recent GCC compiler (at least v12) and also a fast linker (see requirements above).
69+
You can tell GCC to use either `lld` or `mold` with `--extraopt '--linkopt=-fuse-ld=lld'` or `--extraopt '--linkopt=-fuse-ld=mold'` respectively.
70+
NOTE: the option `-fuse-ld=mold` was added in GCC 12; if you're trying to use an older version you may have some luck by making a symlink named `ld` pointing to `mold` in `PATH`, with higher precedence than Binutils `ld`.
71+
72+
### Using ccache
73+
74+
If you want to use `ccache` as your compiler, you may have to add the flag `--extraopt "--sandbox_writable_path=/path/to/ccache/directory"` to let `ccache` write to its own directory.

‎lib/ReactantCore/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ReactantCore"
22
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"

‎lib/ReactantCore/src/ReactantCore.jl

+12-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,18 @@ using MacroTools: MacroTools
66
export @trace, within_compile, MissingTracedValue
77

88
# Traits
9-
is_traced(x) = false
9+
"""
    is_traced(x, seen=Base.IdSet()) -> Bool

Return `true` if `x`, or any value reachable through its fields, reports itself as traced.

This generic fallback never returns `true` on its own: it only walks the fields of `x`
recursively and forwards a `true` produced by a more specific `is_traced` method for a
nested value. `seen` collects (by identity) the field values that were already visited,
so shared and self-referential values are inspected only once.

# Arguments
- `x`: the value to inspect.
- `seen`: identity set of already-visited field values; callers normally omit it.
"""
function is_traced((@nospecialize x::T), seen=Base.IdSet()) where {T}
    # The primitive check has to look at the type `T`: `isprimitivetype` is `false` for
    # any non-type value. Type objects are skipped as well, so the walk does not descend
    # into `DataType`/`UnionAll` internals.
    (x isa Type || isprimitivetype(T)) && return false
    for fn in fieldnames(T)
        # Unassigned fields carry no value; `getfield` on them throws `UndefRefError`.
        isdefined(x, fn) || continue
        f = getfield(x, fn)
        f in seen && continue
        push!(seen, f)
        is_traced(f, seen) && return true
    end
    return false
end
1021

1122
# New Type signifying that a value is missing
1223
mutable struct MissingTracedValue

‎src/Compiler.jl

+11
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,17 @@ function compile_mlir!(
10551055
MLIR.API.mlirOperationDestroy(compiled_f.operation)
10561056
compiled_f.operation = MLIR.API.MlirOperation(C_NULL)
10571057

1058+
# Add a `donated` attr to the function arguments. This doesn't affect XLA, but lets us
1059+
# check which arguments were donated.
1060+
preserved_args_idx = last.(preserved_args)
1061+
for (i, arg) in enumerate(linear_args)
1062+
if i ∉ preserved_args_idx
1063+
MLIR.API.mlirFuncSetArgAttr(
1064+
func3, i - 1, "reactant.donated", MLIR.IR.UnitAttribute()
1065+
)
1066+
end
1067+
end
1068+
10581069
return Reactant.TracedUtils.CompiledMlirFnResult(
10591070
fnwrapped,
10601071
func3,

‎src/ConcreteRArray.jl

+2
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ end
341341

342342
# TODO replace this copy for `setindex!` maybe? how to copy data to already existing buffer? (i.e. `copyto!`)
343343
function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcretePJRTArray}})
344+
bc = Broadcast.flatten(bc)
344345
for x in bc.args
345346
x isa ConcretePJRTArray && wait(x)
346347
end
@@ -370,6 +371,7 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteP
370371
end
371372

372373
# Materialize a broadcast over `ConcreteIFRTArray`s: the (possibly nested) broadcast
# tree is flattened into a single function plus a flat argument tuple, that function is
# compiled as a `BroadcastFunction`, and the compiled function is applied to the
# flattened arguments.
function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteIFRTArray}})
    flat = Broadcast.flatten(bc)
    compiled = compile(Broadcast.BroadcastFunction(flat.f), (flat.args...,))
    return compiled(flat.args...)
end

‎src/Ops.jl

+17-3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ end
3636

3737
const DEBUG_MODE::Ref{Bool} = Ref(false)
3838
const LARGE_CONSTANT_THRESHOLD = Ref(100 << 20) # 100 MiB
39+
const LARGE_CONSTANT_RAISE_ERROR = Ref(true)
3940

4041
function with_debug(f)
4142
old = DEBUG_MODE[]
@@ -89,8 +90,18 @@ end
8990
@noinline function constant(
9091
x::DenseArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
9192
) where {T,N}
92-
sizeof(x) > LARGE_CONSTANT_THRESHOLD[] &&
93-
error("Generating a constant larger than $(LARGE_CONSTANT_THRESHOLD[]) bytes.")
93+
if sizeof(x) > LARGE_CONSTANT_THRESHOLD[]
94+
if LARGE_CONSTANT_RAISE_ERROR[]
95+
error(
96+
"Generating a constant of $(sizeof(x)) bytes, which larger than the $(LARGE_CONSTANT_THRESHOLD[]) bytes threshold",
97+
)
98+
else
99+
location = with_debug() do
100+
mlir_stacktrace("constant", @__FILE__, @__LINE__)
101+
end
102+
end
103+
end
104+
94105
value = MLIR.IR.DenseElementsAttribute(x)
95106
constants = constant_context()[2]
96107
if haskey(constants, value)
@@ -1713,7 +1724,10 @@ end
17131724
traced_args = Vector{Any}(undef, N)
17141725
for i in 1:N
17151726
@inbounds traced_args[i] = Reactant.make_tracer(
1716-
seen_args, args[i], (), Reactant.NoStopTracedTrack; track_numbers=Number
1727+
seen_args,
1728+
args[i],
1729+
(),
1730+
Reactant.NoStopTracedTrack, #; track_numbers=Number
17171731
)
17181732
end
17191733

‎src/TracedRArray.jl

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_tra
2222
using ReactantCore: ReactantCore
2323
using GPUArraysCore: GPUArraysCore, @allowscalar
2424

25+
# A `TracedRArray` is always traced; `seen` is ignored and only accepted so that this
# method can be reached from the recursive two-argument `is_traced` walk.
function ReactantCore.is_traced(::TracedRArray, seen)
    return true
end
2526
# A `TracedRArray` is always traced.
function ReactantCore.is_traced(::TracedRArray)
    return true
end
2627

2728
Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...)

‎src/TracedRNumber.jl

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ..Reactant:
44
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
55
using ReactantCore
66

7+
# A `TracedRNumber` is always traced; `seen` is ignored and only accepted so that this
# method can be reached from the recursive two-argument `is_traced` walk.
function ReactantCore.is_traced(::TracedRNumber, seen)
    return true
end
78
# A `TracedRNumber` is always traced.
function ReactantCore.is_traced(::TracedRNumber)
    return true
end
89

910
Base.getindex(a::TracedRNumber{T}) where {T} = a

‎src/TracedUtils.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
629629
end
630630

631631
function broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize)
632+
collect(size(arg)) == collect(rsize) && return arg
632633
if Reactant.ancestor(arg) isa TracedRArray
633634
return broadcast_to_size(materialize_traced_array(arg), rsize)
634635
end
@@ -658,7 +659,7 @@ function broadcast_to_size(arg::AbstractIrrational, rsize)
658659
end
659660

660661
function broadcast_to_size(arg::ReactantPrimitive, rsize)
661-
return Ops.constant(Base.fill(arg, Tuple(rsize)))
662+
return Ops.fill(arg, rsize)
662663
end
663664

664665
function broadcast_to_size(arg::TracedRNumber{T}, rsize) where {T}

‎src/Tracing.jl

+89
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,20 @@ Base.@nospecializeinfer function traced_type_inner(
551551
end
552552
end
553553

554+
# Traced counterpart of a `SubArray` type. The parent type `P` and the index type `I`
# are traced recursively (parent first, then indices), and the element type of the
# result is taken from the traced parent rather than from the original `T`. The
# dimensionality `N` and the `L` parameter are carried over unchanged.
Base.@nospecializeinfer function Reactant.traced_type_inner(
    @nospecialize(OA::Type{SubArray{T,N,P,I,L}}),
    seen,
    mode::Reactant.TraceMode,
    @nospecialize(track_numbers::Type),
    @nospecialize(sharding),
    @nospecialize(runtime)
) where {T,N,P,I,L}
    traced_parent = Reactant.traced_type_inner(
        P, seen, mode, track_numbers, sharding, runtime
    )
    traced_indices = Reactant.traced_type_inner(
        I, seen, mode, track_numbers, sharding, runtime
    )
    traced_eltype = eltype(traced_parent)
    return SubArray{traced_eltype,N,traced_parent,traced_indices,L}
end
567+
554568
for P in (Ptr, Core.LLVMPtr, Base.RefValue)
555569
@eval Base.@nospecializeinfer function traced_type_inner(
556570
@nospecialize(PT::Type{$P}),
@@ -918,6 +932,81 @@ function make_tracer(
918932
end
919933
# Return a new tuple made of the elements of `path` followed by `i`; `path` itself is
# not modified.
function append_path(@nospecialize(path), i)
    return (path..., i)
end
920934

935+
# Trace `prev` field by field and, when at least one field changed, rebuild the value by
# calling the constructor of its traced type `TT` with the traced field values.
#
# Arguments
# - `seen`: dictionary-like cache keyed by the original object (queried with `haskey`,
#   read and written by indexing).
# - `prev`: the value to trace; its type must be concrete and its traced type immutable
#   (both are asserted).
# - `path`: in `TracedToTypes` mode a collection that entries are `push!`ed onto;
#   otherwise it is extended per field with `append_path`.
# - `mode`: tracing mode, forwarded to `traced_type` and to the per-field `make_tracer`.
# - `track_numbers`, `sharding`, `runtime`, `kwargs...`: forwarded to the per-field
#   `make_tracer` call; field `i` receives `Base.getproperty(sharding, i)` as sharding.
#
# Returns
# - `nothing` in `TracedToTypes` mode (only `path` and `seen` are updated);
# - the cached `seen[prev]` when `prev` was seen before and `mode != NoStopTracedTrack`;
# - `prev` itself when it has no fields or when no field changed (`!==`) after tracing;
# - otherwise a new `TT` instance built from the traced fields, also stored in `seen`.
function make_tracer_via_immutable_constructor(
    seen,
    @nospecialize(prev),
    @nospecialize(path),
    mode;
    @nospecialize(track_numbers::Type = Union{}),
    @nospecialize(sharding = Sharding.NoSharding()),
    @nospecialize(runtime = nothing),
    kwargs...,
)
    RT = Core.Typeof(prev)
    if haskey(seen, prev)
        if mode == TracedToTypes
            id = seen[prev]
            push!(path, id)
            return nothing
        elseif mode != NoStopTracedTrack
            # `haskey(seen, prev)` is already known to hold in this branch.
            return seen[prev]
        end
    elseif mode == TracedToTypes
        push!(path, RT)
        seen[prev] = VisitedObject(length(seen) + 1)
    end
    TT = traced_type(RT, Val(mode), track_numbers, sharding, runtime)
    @assert !Base.isabstracttype(RT)
    @assert Base.isconcretetype(RT)
    nf = fieldcount(RT)

    @assert !ismutabletype(TT)

    if nf == 0
        if mode == TracedToTypes
            push!(path, prev)
            return nothing
        end
        return prev
    end

    flds = Vector{Any}(undef, nf)
    changed = false
    for i in 1:nf
        if isdefined(prev, i)
            newpath = mode == TracedToTypes ? path : append_path(path, i)
            xi = Base.getfield(prev, i)
            xi2 = make_tracer(
                seen,
                xi,
                newpath,
                mode;
                track_numbers,
                sharding=Base.getproperty(sharding, i),
                runtime,
                kwargs...,
            )
            if xi !== xi2
                changed = true
            end
            flds[i] = xi2
        else
            nf = i - 1 # rest of tail must be undefined values
            break
        end
    end
    if mode == TracedToTypes
        return nothing
    end
    if !changed
        seen[prev] = prev
        return prev
    end
    # Only the first `nf` slots of `flds` were assigned; splatting the whole vector
    # would throw `UndefRefError` when `prev` has an undefined tail.
    y = TT(flds[1:nf]...)
    seen[prev] = y
    return y
end
1009+
9211010
function make_tracer(
9221011
seen,
9231012
@nospecialize(prev),

‎test/basic.jl

+6
Original file line numberDiff line numberDiff line change
@@ -1014,3 +1014,9 @@ end
10141014
@test Array(x_ra) ==
10151015
[0.0 0.0 1.0 1.0; 0.0 0.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0]
10161016
end
1017+
1018+
# Copying a nested `Broadcasted` tree (a broadcast whose argument is itself a broadcast)
# built on a Reactant array must give the same values as the equivalent broadcast on a
# plain `Array`.
@testset "copy(::Broadcast.Broadcasted{ArrayStyle{ConcreteRArray}})" begin
    x_ra = Reactant.to_rarray(ones(4, 4))
    res = copy(Broadcast.broadcasted(-, Broadcast.broadcasted(+, x_ra, 1)))
    @test res ≈ -(Array(x_ra) .+ 1)
end

0 commit comments

Comments
 (0)
Please sign in to comment.