Thunk Change #10

Closed
wants to merge 343 commits
Commits (343)
532fd73
Simplify process to builds docs (#554)
giordano Jan 17, 2025
53644c9
fix: inconsistent return dims (#558)
avik-pal Jan 17, 2025
bf43a65
Update WORKSPACE
wsmoses Jan 17, 2025
6310f83
fix: define getindexing into sub reshaped array (#556)
avik-pal Jan 17, 2025
fc4e53d
Format code (#562)
github-actions[bot] Jan 18, 2025
11510a2
Regenerate MLIR Bindings (#561)
github-actions[bot] Jan 18, 2025
2c9c03c
[CI] Format generated files twice to work around JuliaFormatter bug (…
giordano Jan 18, 2025
913bf3a
CUDA: fix nv intrinsic errs (#564)
wsmoses Jan 18, 2025
30da571
Update Project.toml
wsmoses Jan 18, 2025
044b670
[GHA] Add `paths` settings for workflow triggers (#563)
giordano Jan 18, 2025
01a5646
[GHA] Fix syntax of regenerate MLIR bindings workflow (#566)
giordano Jan 18, 2025
3481d1d
respect scopping rules in for (#310)
Pangoraw Jan 18, 2025
380c747
Format code of branch "main" (#568)
github-actions[bot] Jan 19, 2025
6592cea
feat: support arbitrary structures in control flow (#565)
avik-pal Jan 19, 2025
937510b
fix: reduction of integers (#573)
avik-pal Jan 19, 2025
32762fb
[CI] Remove useless call to `Pkg.instantiate` (#576)
giordano Jan 19, 2025
32c9226
fix: specialize / on integer types (#577)
avik-pal Jan 19, 2025
cf345d5
profiler: Add option to generate perfetto url (#575)
Pangoraw Jan 19, 2025
f917bfe
More jll/cuda stuff (#567)
wsmoses Jan 19, 2025
8ddb738
[ReactantExtra] Use XLA commit for building with CUDA 12.1 (#579)
giordano Jan 20, 2025
a328d27
Update Project.toml
wsmoses Jan 20, 2025
682d6d2
Regenerate MLIR Bindings (#580)
github-actions[bot] Jan 20, 2025
2be46df
Profiler annotations & tutorial (#582)
Pangoraw Jan 20, 2025
23e0f74
Update Compiler.jl
wsmoses Jan 20, 2025
cca721d
Fix for unknown cuda drivers (#586)
wsmoses Jan 20, 2025
63d407a
PTX fma and other flags (#585)
wsmoses Jan 20, 2025
29627db
[CI] Move tests on aarch64 linux to GitHub Actions (#543)
giordano Jan 21, 2025
66910bf
Fix condition to skip CUDA tests on aarch64 (#592)
giordano Jan 21, 2025
678b90d
feat: expose gpu memory allocation options (#589)
avik-pal Jan 21, 2025
1c42a58
feat: add the new optimization passes (#595)
avik-pal Jan 22, 2025
1b31dd7
Update ReactantCUDAExt.jl (#597)
wsmoses Jan 23, 2025
077e591
Add convert (#598)
wsmoses Jan 23, 2025
d5010bb
Update Project.toml
wsmoses Jan 23, 2025
635f35c
feat: support dynamic indexing for reshaped arrays (#601)
avik-pal Jan 24, 2025
534bea3
feat: overload LinearAlgebra.kron (#607)
avik-pal Jan 24, 2025
2118ee2
feat: more indexing support (#608)
avik-pal Jan 24, 2025
7dfdbb9
[tests] Always skip CUDA tests on non-CUDA machines (#615)
giordano Jan 25, 2025
24c3351
Add hermetic cuda getter (#612)
wsmoses Jan 25, 2025
2f13d4e
feat: forward more base ops to chlo (#611)
avik-pal Jan 25, 2025
67575fb
Fix dense elements attribute in `Enzyme.autodiff` #593 (#604)
mofeing Jan 25, 2025
4849c6b
feat: support lowering custom fp types (#596)
avik-pal Jan 26, 2025
51a1f46
feat: multi GPU support (#587)
avik-pal Jan 26, 2025
29b0eac
Regenerate MLIR Bindings (#621)
github-actions[bot] Jan 26, 2025
6e4c6a8
feat: build the shardy dialect (#622)
avik-pal Jan 26, 2025
bfc7b58
feat: support more set indexing (#625)
avik-pal Jan 26, 2025
272da5e
Typed rounding (#619)
wsmoses Jan 26, 2025
c499c7b
Add bound optimizations (#626)
wsmoses Jan 26, 2025
9ff575f
Update Project.toml
wsmoses Jan 26, 2025
8623343
[CI] Add workflow to clean up docs previews (#628)
giordano Jan 26, 2025
59b1856
fix: build error with shardy (#629)
avik-pal Jan 26, 2025
dea0350
Format code of branch "main" (#634)
github-actions[bot] Jan 27, 2025
8450be3
Regenerate MLIR Bindings (#627)
github-actions[bot] Jan 27, 2025
8d39727
[tests] Replace random custom type numbers with fixed set of numbers …
giordano Jan 27, 2025
7710603
fix cuda abi setting (#633)
wsmoses Jan 27, 2025
8db24b1
[ReactantExtra] Improvements to BUILD file to compile CUDA for aarch6…
giordano Jan 27, 2025
23a57df
[ReactantExtra] Bump XLA version (#640)
giordano Jan 27, 2025
fd60aad
Update WORKSPACE
wsmoses Jan 27, 2025
db2aa15
feat: add dispatch for KA get_backend (#645)
avik-pal Jan 28, 2025
27330e0
Applehw (#643)
wsmoses Jan 28, 2025
6951708
Use `xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl` …
giordano Jan 28, 2025
e38e8ca
Regenerate MLIR Bindings (#644)
github-actions[bot] Jan 28, 2025
f1129f0
docs: add shardy to docs (#648)
avik-pal Jan 29, 2025
f7f009c
TPU profiler (#642)
Pangoraw Jan 29, 2025
b5f9ecd
chore: generate shardy c wrappers (#650)
avik-pal Jan 29, 2025
5c67d0a
feat: the big jll PR (#653)
avik-pal Jan 30, 2025
cde5935
[CI] Fix path of previews direcotyr in PreviewCleanup workflow (#656)
giordano Jan 30, 2025
1087479
Regenerate MLIR Bindings (#651)
github-actions[bot] Jan 30, 2025
1925f70
CPU backend (#647)
wsmoses Jan 30, 2025
0feae66
Detect TPU using PCI devices (#659)
Pangoraw Jan 30, 2025
0d3c8df
Add IR dumping (#638)
wsmoses Jan 30, 2025
a2da11b
Replace `trim` -> `strip` (#661)
giordano Jan 30, 2025
4883486
Silence various warnings in tests (#662)
giordano Jan 30, 2025
696e176
Format code (#665)
github-actions[bot] Jan 31, 2025
9a1179d
Regenerate MLIR Bindings (#666)
github-actions[bot] Jan 31, 2025
52b60ad
Feature: allow colon indexing of traced **vectors** (#664)
floffy-f Jan 31, 2025
d3abcd4
[docs] Add information about configuration on GPU and TPU systems (#668)
giordano Jan 31, 2025
35bd5d2
Update index.md
wsmoses Jan 31, 2025
09da62f
Fix ntuple traced type issue on unionall (#669)
wsmoses Feb 1, 2025
e46eedf
KA ext (#667)
wsmoses Feb 1, 2025
e9471bd
Update Project.toml
wsmoses Feb 1, 2025
c2786dd
KA without cuda backend (#670)
wsmoses Feb 1, 2025
1e6037f
Update Project.toml
wsmoses Feb 1, 2025
603225a
[ReactantCUDAExt] Skip precompile load on Julia v1.11.3 (#675)
giordano Feb 1, 2025
bf2a020
Update enzyme-jax
wsmoses Feb 2, 2025
26317dc
Generate MLIR MPI dialect bindings
mofeing Feb 2, 2025
fe8ba34
Update WORKSPACE
wsmoses Feb 2, 2025
40c7708
Regenerate MLIR Bindings (#680)
github-actions[bot] Feb 2, 2025
e16a8df
Update WORKSPACE
wsmoses Feb 3, 2025
85e08c3
[ReactantExtra] Add argument to `ClientCompile` to pass CUDA data dir…
giordano Feb 3, 2025
b9da5c1
Update XLA.jl
wsmoses Feb 3, 2025
da047f6
Update BUILD
wsmoses Feb 3, 2025
18400b6
Use `LLVMOpenMP_jll` to call OpenMP functions (#673)
giordano Feb 3, 2025
0ddeab2
CUDA: fix gc issues (#685)
wsmoses Feb 3, 2025
e9b7a72
make `similar` return empty tensors. (#632)
jumerckx Feb 3, 2025
9339756
Update Project.toml
wsmoses Feb 3, 2025
bf0fb61
Regenerate MLIR Bindings (#686)
github-actions[bot] Feb 4, 2025
ba4405d
Misc fixes (#687)
wsmoses Feb 4, 2025
cac6f49
chore: missing upstream optimization passes (#624)
avik-pal Feb 4, 2025
234d168
`@trace` function calls (#366)
jumerckx Feb 4, 2025
21f7660
dict value fix (#688)
wsmoses Feb 4, 2025
e35fa0c
Update Project.toml
wsmoses Feb 5, 2025
6d998bc
Update Project.toml
wsmoses Feb 5, 2025
d15980a
Update pipeline.yml
wsmoses Feb 5, 2025
b506bfc
[deps] Some improvements to the `build_local.jl` script (#689)
giordano Feb 5, 2025
d842b33
Multiple device error (#690)
wsmoses Feb 5, 2025
c325683
feat: API changes for multi-device execution [ReactantExtra JLL chang…
avik-pal Feb 6, 2025
af24aa8
Ref ptr fix (#698)
wsmoses Feb 6, 2025
7194111
Add missing underscore in BUILD
giordano Feb 6, 2025
b6942cf
[CI] Use debugging version of CompatHelper
giordano Feb 7, 2025
77ada56
Add GPUCompiler and LLVM as deps to CUDA extension and run CUDA tests…
giordano Feb 7, 2025
4378d3e
Update WORKSPACE
wsmoses Feb 7, 2025
a0c0eb3
use current xla
wsmoses Feb 7, 2025
7f82be6
Update WORKSPACE
wsmoses Feb 7, 2025
f971fec
fix
wsmoses Feb 7, 2025
5f2a86d
builds
wsmoses Feb 7, 2025
170e48d
bump commit
wsmoses Feb 7, 2025
3eac3f8
feat: shardy and multi device execution (#637)
avik-pal Feb 7, 2025
6f2cc88
[ReactantExtra] Stop removing references to `hardware_interference_si…
giordano Feb 7, 2025
0794edf
vendor optimize (#703)
wsmoses Feb 8, 2025
24b4d31
Format code of branch "main" (#709)
github-actions[bot] Feb 8, 2025
611b800
fix: don't trace val (#710)
avik-pal Feb 8, 2025
c20b142
JLL related fixups (#706)
wsmoses Feb 8, 2025
a1fa03f
Update Project.toml (#705)
wsmoses Feb 8, 2025
8d6636d
Regenerate MLIR Bindings (#708)
github-actions[bot] Feb 8, 2025
1680698
Update Reactant.jl
wsmoses Feb 8, 2025
283400d
Format code (#711)
github-actions[bot] Feb 8, 2025
b4dc79f
feat: overload ifelse for more types (#712)
avik-pal Feb 8, 2025
a3949c7
fix build
wsmoses Feb 9, 2025
8794992
fix illegal span
wsmoses Feb 9, 2025
4c1a33b
refactor: split XLA.jl into multiple files (#716)
avik-pal Feb 9, 2025
b964042
feat: enable async on CPU (#717)
avik-pal Feb 9, 2025
15e8843
[ReactantExtra] feat: OpSharding bindings for Julia (#721)
avik-pal Feb 10, 2025
5f5f81b
[ReactantExtra] fix: build on mac (#722)
avik-pal Feb 10, 2025
447ff4f
[ReactantExtra] IFRT bindings (round 4) (#718)
mofeing Feb 10, 2025
1e28427
Update WORKSPACE (#723)
avik-pal Feb 10, 2025
e9f9788
Update WORKSPACE
Pangoraw Feb 10, 2025
130c0cd
Update WORKSPACE
avik-pal Feb 10, 2025
0e119dc
Update XLA.jl
wsmoses Feb 11, 2025
a1b534b
Fix jll (#724)
wsmoses Feb 11, 2025
a31b2e0
Update WORKSPACE
wsmoses Feb 11, 2025
90b0d1d
fix: multi-device execution and sharding [take III] (#713)
avik-pal Feb 11, 2025
01d2904
Update Project.toml
wsmoses Feb 11, 2025
b70d614
feat: add sign dispatches (#727)
avik-pal Feb 11, 2025
0760f4b
fix: correct dims handling in mapreducedim! (#728)
avik-pal Feb 11, 2025
9fdcbaa
chore: bump version for release
avik-pal Feb 11, 2025
4ca5147
Format code (#729)
github-actions[bot] Feb 12, 2025
e82c420
fix: prevent method ambiguity for CartesianIndex{1} (#730)
avik-pal Feb 12, 2025
afe1afb
[GHA] Some improvement to CI setup (#731)
giordano Feb 12, 2025
904b789
fix `Type(value)` instead of `type(value)` (#733)
jumerckx Feb 12, 2025
95f6074
fix: improve generated mlir for wrapped arrays (#732)
avik-pal Feb 12, 2025
467559d
fix: don't expand all ranges by default (#737)
avik-pal Feb 12, 2025
ce0c590
ci: add cpp format check (#739)
avik-pal Feb 13, 2025
3d98bd7
fix: unqualified Sharding access (#741)
avik-pal Feb 13, 2025
22ec225
feat: sharding via IFRT (#740)
avik-pal Feb 13, 2025
6e8ef9f
Force tracing of type to act as noop (#747)
wsmoses Feb 14, 2025
d1a6a24
Support for dicts (#748)
wsmoses Feb 14, 2025
08968ee
Update Project.toml
wsmoses Feb 14, 2025
5a98e91
Bump enzymexla
wsmoses Feb 15, 2025
000a250
Update WORKSPACE
wsmoses Feb 15, 2025
0a47c81
feat: JLL changes to expose HloModule (#749)
avik-pal Feb 15, 2025
cebf9eb
[IFRT] add ifrt-proxy server and client bindings (#750)
mofeing Feb 15, 2025
950e476
fix: ordering of arguments need to be according to device (#753)
avik-pal Feb 16, 2025
5d032c4
Support tracing of `rem` with only one operand being a `ConcreteRNumb…
giordano Feb 16, 2025
ae6e60c
Fix for ocean (#756)
wsmoses Feb 16, 2025
f42a67f
[ReactantCUDAExt] Remove extra method (#760)
giordano Feb 16, 2025
20f7a3c
feat: use parameter shardings from XLA (#743)
avik-pal Feb 16, 2025
7b79953
Bump to 0.2.30 (#757)
glwagner Feb 16, 2025
9204e39
Fix implementation of `mod` (#758)
giordano Feb 16, 2025
93bc64b
bump enzymexla
wsmoses Feb 16, 2025
d51867d
Bump actions/checkout from 3 to 4 (#762)
dependabot[bot] Feb 17, 2025
4ece29d
Update Project.toml
wsmoses Feb 17, 2025
c817f8a
[GHA] Run x86_64 macOS jobs on macOS-13 runners (#765)
giordano Feb 17, 2025
acaabc4
[IFRT] add c-bindings for "Held" PjRt classes (#751)
mofeing Feb 19, 2025
871790f
feat: JLL changes for IFRT Shardings (#770)
avik-pal Feb 19, 2025
a1d2b8b
track number in traced_type for mode == TracedSetPath (#772)
jumerckx Feb 19, 2025
afa90a4
Don't trace VersionNumber (#773)
milankl Feb 19, 2025
0a37253
refactor: move PJRT into a specific module (#771)
avik-pal Feb 19, 2025
83a2c1d
ci: run tests if julia file changes (#742)
avik-pal Feb 20, 2025
66b035e
[build_local] Add argument to set `--color` option for Bazel (#776)
giordano Feb 20, 2025
5d370a5
Update Raising passes
wsmoses Feb 21, 2025
1416a0a
Bump enzymexla
wsmoses Feb 21, 2025
df0075b
Update API.cpp
wsmoses Feb 21, 2025
64c866c
fix: code_xla (#782)
avik-pal Feb 21, 2025
ea9c9e9
Bump xla (#783)
wsmoses Feb 21, 2025
9ea612f
feat: JLL changes for #780
avik-pal Feb 21, 2025
b9cff51
Further enzymexla bump (#785)
wsmoses Feb 22, 2025
fdf21dc
Format code (#786)
github-actions[bot] Feb 22, 2025
447bdc8
feat: JLL changes for #788 (#789)
avik-pal Feb 22, 2025
230d77e
Temporarily use fork of XLA to point to fork of LLVM (#787)
giordano Feb 22, 2025
d18fd40
feat: more JLL changes for #788 (#790)
avik-pal Feb 22, 2025
7793387
Bump enzymejax (#795)
wsmoses Feb 22, 2025
0a7e078
feat: JLL changes for IFRT integration (#796)
avik-pal Feb 23, 2025
4aaae3c
Regenerate MLIR Bindings (#779)
github-actions[bot] Feb 23, 2025
b1fc493
feat: support kwargs in macros (#791)
avik-pal Feb 23, 2025
dbc0fe1
refactor: rename R* to PJRT* (#775)
avik-pal Feb 23, 2025
6b760ba
feat: fast path to directly generate HloSharding from shardy (#788)
avik-pal Feb 23, 2025
9a150f7
feat: use a global state to setup pjrt distributed runtime (#780)
avik-pal Feb 23, 2025
bcb0282
feat: JLL changes for sdy.sharding_constraint (#799)
avik-pal Feb 23, 2025
2a5711e
feat: initial IFRT integration (#764)
avik-pal Feb 23, 2025
192bef3
traced control flow fixes (#794)
jumerckx Feb 23, 2025
4611142
feat: add `Ops.sharding_constraint` (#798)
avik-pal Feb 24, 2025
c9f4586
Update WORKSPACE
wsmoses Feb 24, 2025
4cbad40
cf: extract condition to a name (#801)
Pangoraw Feb 24, 2025
aad100b
build(deps): bump bazel-contrib/setup-bazel from 0.13.0 to 0.14.0 (#802)
dependabot[bot] Feb 24, 2025
779f66c
Bump JLL (#803)
wsmoses Feb 24, 2025
5260061
restore anyconcreterarray (#804)
wsmoses Feb 24, 2025
3c67dd6
Fix raising (#805)
wsmoses Feb 25, 2025
2e138de
Bm (#807)
wsmoses Feb 25, 2025
9af8d87
[Compiler] Make `raise` a keyword argument (#797)
giordano Feb 25, 2025
87aa468
fix: simplify Mesh implementation (#806)
avik-pal Feb 25, 2025
07594f5
Update Project.toml
wsmoses Feb 25, 2025
94f8d34
Update Project.toml
wsmoses Feb 25, 2025
5c510f0
Bump enzymejax (#812)
wsmoses Feb 26, 2025
24fed87
Bump enzymexla passes (#814)
wsmoses Feb 27, 2025
e995366
feat: more sharding utilities (#809)
avik-pal Feb 27, 2025
2d4da3f
More enzymexla bump (#816)
wsmoses Feb 27, 2025
db1a123
Bump JLL (#817)
wsmoses Feb 27, 2025
d92e1e5
feat: adding axpy! and axpby! to linear algebra (#813)
tharittk Feb 27, 2025
f4216e8
fix: allow one arg of overloaded_mul to be a regular array (#821)
avik-pal Feb 27, 2025
c78f585
Regenerate MLIR Bindings (#815)
github-actions[bot] Feb 28, 2025
3991fd0
More enzymexla bump (#823)
wsmoses Feb 28, 2025
85bdbee
fix: return T if mode=TracedSetPath (#810)
glwagner Feb 28, 2025
a39b055
feat: support 2 levels of wrapping (#824)
avik-pal Feb 28, 2025
f33c023
feat: add isinf dispatches (#826)
avik-pal Mar 1, 2025
da85c90
Bump enzymexla (#829)
wsmoses Mar 2, 2025
5815733
Update WORKSPACE
wsmoses Mar 2, 2025
4b6b818
Update WORKSPACE
wsmoses Mar 2, 2025
a0659e1
update jll (#830)
wsmoses Mar 2, 2025
dc3f17c
Tag v0.2.34 (#831)
giordano Mar 2, 2025
1f800c4
Restore raising to executing (#832)
wsmoses Mar 2, 2025
100c9f7
Format code (#833)
github-actions[bot] Mar 3, 2025
28b4eda
Bump Reactant_jll and version number (#838)
giordano Mar 3, 2025
d1be533
feat: sharding with non-divisible dimensions [alternate approach] (#825)
avik-pal Mar 4, 2025
e522ced
Generate MemRef dialect bindings (#836)
mofeing Mar 4, 2025
07efa86
Update `make-bindings.jl` to generate MemRef dialect
mofeing Mar 4, 2025
451fc40
Quick fix path to MemRefOps.td
mofeing Mar 4, 2025
e7eec18
Regenerate MLIR Bindings (#842)
github-actions[bot] Mar 4, 2025
b0e5f11
fix: IFRT PJRT Client construction (#841)
avik-pal Mar 4, 2025
f09a335
ReactantExtra changes to fix MPI PR (#844)
mofeing Mar 5, 2025
f56f806
Bump enzymexla
wsmoses Mar 5, 2025
c7864b7
[CI] Set timeout of run-tests job to 60 minutes (#843)
giordano Mar 5, 2025
d2ff1aa
test: disable Lux gradient test for now (#845)
avik-pal Mar 5, 2025
3ea2ce9
docs: housekeeping + memref dialect (#846)
avik-pal Mar 5, 2025
ff32540
docs: fix links in nav bar (#848)
avik-pal Mar 5, 2025
ede4493
feat: more dispatches for any/all (#834)
avik-pal Mar 6, 2025
302274f
fix: number tracing (#849)
avik-pal Mar 6, 2025
0cfcef7
chore: bump version for release
avik-pal Mar 6, 2025
fdbbfeb
Additional enzymexlabump (#850)
wsmoses Mar 6, 2025
e191da4
Further enzymexla bump (#851)
wsmoses Mar 7, 2025
acc23bc
change Thunk
glou-nes Mar 7, 2025
Files changed
38 changes: 1 addition & 37 deletions .buildkite/pipeline.yml
@@ -32,43 +32,7 @@ steps:
cuda: "*"
env:
REACTANT_TEST_GROUP: "{{matrix.group}}"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 120

- label: ":julia: :linux: aarch64 - Julia v{{matrix.version}} -- {{matrix.group}}"
matrix:
setup:
version:
- "1.10"
- "1.11"
group:
- core
- neural_networks
- integration
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.version}}"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
- lib/ReactantCore/src
commands: |
julia --project=. -e 'println("--- :julia: Instantiating project")
using Pkg
Pkg.develop([PackageSpec(path="lib/ReactantCore")])'

julia --project=. -e 'println("--- :julia: Run Tests")
using Pkg
Pkg.test(; coverage="user")'
agents:
queue: "juliaecosystem"
os: "linux"
sandbox_capable: "true"
arch: "aarch64"
env:
REACTANT_TEST_GROUP: "{{matrix.group}}"
CUDA_VISIBLE_DEVICES: 0
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 120

1 change: 1 addition & 0 deletions .clang-format
@@ -0,0 +1 @@
BasedOnStyle: LLVM
93 changes: 93 additions & 0 deletions .github/workflows/CI-localjll.yml
@@ -0,0 +1,93 @@
name: CI with local libReactant

on:
pull_request:
paths:
- '.github/workflows/CI-localjll.yml'
- 'deps/**'
push:
branches:
- main
- release-*
tags: '*'
paths:
- '.github/workflows/CI-localjll.yml'
- 'deps/**'

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - local libReactant - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- '1.10'
- '1.11'
os:
- ubuntu-24.04
- macOS-latest
exclude:
- os: macOS-latest
version: '1.10'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
- uses: julia-actions/cache@v2
- uses: bazel-contrib/setup-bazel@0.14.0
name: Set up Bazel
with:
# Avoid downloading Bazel every time.
bazelisk-cache: true
# Store build cache per workflow.
disk-cache: ${{ github.workflow }}-${{ matrix.os }}-${{ matrix.version }}
# Share repository cache between workflows.
repository-cache: true
bazelisk-version: 1.x
- name: Prepare build on macOS
if: ${{ startsWith(matrix.os, 'macOS-') }}
run: |
echo "SDKROOT=$(xcrun --show-sdk-path)" >> "${GITHUB_ENV}"
- name: Build libReactant
run: |
python -m pip install numpy
julia --color=yes --project=deps -e 'using Pkg; Pkg.instantiate()'
julia --color=yes --project=deps deps/build_local.jl
cp LocalPreferences.toml test/
- name: "Install Dependencies"
run: |
import Pkg
Pkg.Registry.update()
# Install packages present in subdirectories
dev_pks = Pkg.PackageSpec[]
for path in ("lib/ReactantCore",)
push!(dev_pks, Pkg.PackageSpec(; path))
end
Pkg.develop(dev_pks)
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
# Only in Julia v1.10 we need to install `ReactantCore` manually.
if: ${{ matrix.version == '1.10' }}
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
- name: "Run Tests"
run: |
import Pkg
Pkg.Registry.update()
Pkg.test(; coverage="user")
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
id: run_tests
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v5
with:
files: lcov.info
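The `XLA_FLAGS` value in the test environment above is what makes sharding and multi-device tests runnable on CPU-only CI machines: it asks the XLA CPU backend to expose several virtual devices. A minimal sketch of the same setting in a local session (illustrative, not part of this diff):

```julia
# Must be set before Reactant (and therefore XLA) initializes its CPU client.
ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

using Reactant  # the CPU backend now reports 8 virtual devices
```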
116 changes: 40 additions & 76 deletions .github/workflows/CI.yml
@@ -1,11 +1,26 @@
name: CI

on:
pull_request:
paths:
- '.github/workflows/CI.yml'
- 'ext/**'
- 'lib/**'
- 'src/**'
- 'test/**'
- 'Project.toml'
push:
branches:
- main
- release-*
tags: '*'
paths:
- '.github/workflows/CI.yml'
- 'ext/**'
- 'lib/**'
- 'src/**'
- 'test/**'
- 'Project.toml'

concurrency:
# Skip intermediate builds: always.
@@ -15,7 +30,8 @@ concurrency:

jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.libReactant }} libReactant - assertions=${{ matrix.assertions }} - ${{ github.event_name }}
timeout-minutes: 90
name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - assertions=${{ matrix.assertions }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
@@ -25,63 +41,49 @@ jobs:
- '1.11'
# - 'nightly'
os:
- ubuntu-20.04
- ubuntu-24.04
# `ubuntu-22.04-arm` is considered more stable than `ubuntu-24.04-arm`:
# <https://github.com/orgs/community/discussions/148648#discussioncomment-12099554>.
- ubuntu-22.04-arm
- macOS-13
- macOS-latest
test_group:
- core
- neural_networks
- integration
arch:
- x64
- aarch64
assertions:
- false
libReactant: [packaged]
include:
- os: ubuntu-20.04
arch: x64
libReactant: packaged
- os: ubuntu-24.04
version: '1.10'
assertions: true
test_group: core
- os: ubuntu-20.04
arch: x64
libReactant: packaged
- os: ubuntu-24.04
version: '1.10'
assertions: true
test_group: neural_networks
- os: ubuntu-20.04
arch: x64
libReactant: packaged
- os: ubuntu-24.04
version: '1.10'
assertions: true
test_group: integration
# - os: ubuntu-20.04
# arch: x86
# - os: ubuntu-24.04
# libReactant: packaged
# version: '1.10'
# test_group: core
# - os: ubuntu-20.04
# arch: x86
# - os: ubuntu-24.04
# libReactant: packaged
# version: '1.10'
# test_group: neural_networks
# - os: ubuntu-20.04
# arch: x86
# - os: ubuntu-24.04
# libReactant: packaged
# version: '1.10'
# test_group: integration
exclude:
# these are run on Buildkite
- os: ubuntu-20.04
arch: aarch64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
if: ${{ ! matrix.assertions }}
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v2
- uses: actions/checkout@v4
if: ${{ matrix.assertions }}
@@ -95,23 +97,7 @@ jobs:
sed -i.bak 's/exit 2/exit 0/g' julia/deps/tools/jlchecksum
make -C julia -j $(nproc) FORCE_ASSERTIONS=1 LLVM_ASSERTIONS=1 JULIA_PRECOMPILE=0
echo $PWD/julia/usr/bin >> $GITHUB_PATH
- name: Build libReactant
if: ${{ matrix.libReactant == 'local' && matrix.os != 'macOS-latest'}}
id: build_libreactant
run: |
python -m pip install numpy
julia --color=yes --project=deps -e 'using Pkg; Pkg.instantiate()'
julia --color=yes --project=deps deps/build_local.jl
cp LocalPreferences.toml test/
- name: Build libReactant MacOS
if: ${{ matrix.libReactant == 'local' && matrix.os == 'macOS-latest'}}
id: build_libreactant_mac
run: |
python -m pip install numpy
julia --color=yes --project=deps -e 'using Pkg; Pkg.instantiate()'
SDKROOT=`xcrun --show-sdk-path` julia --color=yes --project=deps deps/build_local.jl
cp LocalPreferences.toml test/
- name: "Install Dependencies and Run Tests"
- name: "Install Dependencies"
run: |
import Pkg
Pkg.Registry.update()
@@ -121,46 +107,24 @@ jobs:
push!(dev_pks, Pkg.PackageSpec(; path))
end
Pkg.develop(dev_pks)
Pkg.instantiate()
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
# Only in Julia v1.10 we need to install `ReactantCore` manually.
if: ${{ matrix.version == '1.10' }}
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
- name: "Run Tests"
timeout-minutes: 60
run: |
import Pkg
Pkg.Registry.update()
Pkg.test(; coverage="user")
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
id: run_tests
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
REACTANT_TEST_GROUP: ${{ matrix.test_group }}
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
- uses: julia-actions/julia-processcoverage@v1
if: steps.run_tests.outcome == 'success'
- uses: codecov/codecov-action@v5
if: steps.run_tests.outcome == 'success'
with:
files: lcov.info
docs:
name: Documentation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
- uses: julia-actions/cache@v2
- run: |
julia --color=yes --project=docs -e '
using Pkg
Pkg.develop([
PackageSpec(path=pwd()),
PackageSpec("Reactant_jll"),
PackageSpec(path="lib/ReactantCore")
])
Pkg.instantiate()'
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
- run: |
julia --color=yes --project=docs -e '
using Documenter: DocMeta, doctest
using Reactant
DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true)
doctest(Reactant)'
- run: julia --color=yes --project=docs docs/make.jl
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
5 changes: 4 additions & 1 deletion .github/workflows/CompatHelper.yml
@@ -33,7 +33,10 @@ jobs:
name = "CompatHelper"
uuid = "aa819f21-2bde-4658-8897-bab36330d9b7"
version = "3"
Pkg.add(; name, uuid, version)
# Temporarily use debugging version
url = "https://github.com/JuliaRegistries/CompatHelper.jl.git"
rev = "f408ea193f9573c68a68d72932bcd56268c60340"
Pkg.add(; url, rev)
shell: julia --color=yes {0}
- name: "Run CompatHelper"
run: |
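For reference, the `Pkg.add(; url, rev)` form introduced above is standard Pkg API for installing a package from a specific git revision instead of the registry; a minimal standalone sketch using the same commit pinned in the workflow:

```julia
using Pkg

# Install CompatHelper from a pinned git revision rather than the registry.
Pkg.add(;
    url="https://github.com/JuliaRegistries/CompatHelper.jl.git",
    rev="f408ea193f9573c68a68d72932bcd56268c60340",
)
```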
60 changes: 60 additions & 0 deletions .github/workflows/Documenter.yaml
@@ -0,0 +1,60 @@
name: Documentation

on:
pull_request:
paths:
- '.github/workflows/Documenter.yaml'
- 'docs/**'
- 'lib/**'
- 'src/**'
push:
branches:
- main
tags: '*'
paths:
- '.github/workflows/Documenter.yaml'
- 'docs/**'
- 'lib/**'
- 'src/**'

concurrency:
# Same group concurrency as the `PreviewCleanup.yml` workflow, because they both
# git-push to the same branch, so we want to avoid clashes. NOTE: this is
# different from the concurrency group below, which is to cancel successive
# jobs from within the PR.
group: docs-pushing

jobs:
docs:
name: Documentation
runs-on: ubuntu-latest
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
- uses: julia-actions/cache@v2
- name: Instantiate docs environment
run: |
julia --color=yes --project=docs -e '
using Pkg
Pkg.instantiate()'
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
- name: Run doctests
run: |
julia --color=yes --project=docs -e '
using Documenter: DocMeta, doctest
using Reactant
DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true)
doctest(Reactant)'
- name: Build documentation
run: julia --color=yes --project=docs docs/make.jl
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
30 changes: 30 additions & 0 deletions .github/workflows/PreviewCleanup.yml
@@ -0,0 +1,30 @@
name: Doc Preview Cleanup

on:
pull_request:
types: [closed]

concurrency:
# Same group concurrency as the `docs.yml` workflow, because they both
# git-push to the same branch, so we want to avoid clashes.
group: docs-pushing

jobs:
doc-preview-cleanup:
runs-on: ubuntu-latest
steps:
- name: Checkout gh-pages branch
uses: actions/checkout@v4
with:
ref: gh-pages
- name: Delete preview and history + push changes
run: |
preview_directory=previews/PR${{ github.event.number }}
if [[ -d "${preview_directory}" ]]; then
git config user.name "${{github.actor}}"
git config user.email "${{github.actor_id}}+${{github.actor}}@users.noreply.github.com"
git rm -rf "${preview_directory}"
git commit -m 'Cleanup docs for PR #${{ github.event.number }}'
git branch gh-pages-new $(echo "Delete history" | git commit-tree HEAD^{tree})
git push --force origin gh-pages-new:gh-pages
fi
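The non-obvious step in this workflow is `git commit-tree`: it creates a parentless commit holding the current tree, so force-pushing it resets `gh-pages` to a single commit and discards the accumulated preview history. A hedged sketch of the same two steps driven from Julia (assumes the working directory is a gh-pages checkout):

```julia
# Create a parentless commit containing the current tree; the echo supplies
# the commit message on stdin, mirroring the workflow above.
tree_commit = readchomp(pipeline(`echo "Delete history"`, `git commit-tree HEAD^{tree}`))

# Force-push it so gh-pages keeps only that single commit.
run(`git push --force origin $(tree_commit):gh-pages`)
```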
16 changes: 14 additions & 2 deletions .github/workflows/benchmark_aggregate.yml
@@ -1,4 +1,5 @@
name: Benchmarks

permissions:
contents: write # contents permission to update benchmark contents in gh-pages branch
statuses: read
@@ -7,14 +8,25 @@ permissions:

on:
pull_request:

paths:
- '.github/workflows/benchmark_aggregate.yml'
- 'ext/**'
- 'lib/**'
- 'src/**'
- 'Project.toml'
push:
branches:
- main
paths:
- '.github/workflows/benchmark_aggregate.yml'
- 'ext/**'
- 'lib/**'
- 'src/**'
- 'Project.toml'

jobs:
benchmark:
if: ${{ !contains(github.event.head_commit.message, '[skip benchmarks]') }}
if: ${{ !contains(github.event.head_commit.message, '[skip benchmarks]') && ! github.event.pull_request.head.repo.fork }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
19 changes: 16 additions & 3 deletions .github/workflows/downgrade.yml
@@ -1,11 +1,24 @@
name: Downgrade

on:
pull_request:
branches:
- main
paths:
- '.github/workflows/downgrade.yml'
- 'ext/**'
- 'lib/**'
- 'src/**'
- 'Project.toml'
push:
branches:
- main
paths:
- '.github/workflows/downgrade.yml'
- 'ext/**'
- 'lib/**'
- 'src/**'
- 'Project.toml'

concurrency:
# Skip intermediate builds: always.
@@ -16,6 +29,7 @@ concurrency:
jobs:
downgrade:
# if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }}
timeout-minutes: 90
runs-on: ubuntu-latest
strategy:
fail-fast: false
@@ -29,6 +43,7 @@ jobs:
- uses: julia-actions/setup-julia@v2
with:
version: "1.10"
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-downgrade-compat@v1
with:
skip: "ReactantCore"
@@ -42,16 +57,14 @@ jobs:
push!(dev_pks, Pkg.PackageSpec(; path))
end
Pkg.develop(dev_pks)
Pkg.instantiate()
Pkg.test(; coverage="user")
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
id: run_tests
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
REACTANT_TEST_GROUP: ${{ matrix.test_group }}
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
- uses: julia-actions/julia-processcoverage@v1
if: steps.run_tests.outcome == 'success'
- uses: codecov/codecov-action@v5
if: steps.run_tests.outcome == 'success'
with:
files: lcov.info
31 changes: 31 additions & 0 deletions .github/workflows/format-check-cpp.yml
@@ -0,0 +1,31 @@
name: Format Suggestions

on:
push:
branches:
- main
tags: '*'
paths:
- '.github/workflows/format-check-cpp.yml'
- '**/*.cpp'
- '**/*.h'
pull_request:
paths:
- '.github/workflows/format-check-cpp.yml'
- '**/*.cpp'
- '**/*.h'

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: always.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
code-style-cpp:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: DoozyX/clang-format-lint-action@v0.18.2
with:
source: 'deps'
9 changes: 8 additions & 1 deletion .github/workflows/format-check.yml
@@ -1,10 +1,17 @@
name: Format Suggestions

on:
push:
branches:
- main
tags: '*'
paths:
- '.github/workflows/format-check.yml'
- '**/*.jl'
pull_request:
paths:
- '.github/workflows/format-check.yml'
- '**/*.jl'

concurrency:
# Skip intermediate builds: always.
@@ -13,7 +20,7 @@ concurrency:
cancel-in-progress: true

jobs:
code-style:
code-style-julia:
runs-on: ubuntu-latest
steps:
- uses: julia-actions/julia-format@v3
2 changes: 2 additions & 0 deletions .github/workflows/format-pr.yml
@@ -1,4 +1,5 @@
name: Format 'main'

on:
schedule:
- cron: '0 0 * * *'
@@ -38,6 +39,7 @@ jobs:
branch: format-main
delete-branch: true
labels: format
author: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
- name: Check outputs
run: |
echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
11 changes: 9 additions & 2 deletions .github/workflows/regenerate-mlir-bindings.yml
@@ -1,10 +1,12 @@
name: Regenerate MLIR Bindings

on:
schedule:
- cron: '0 0 * * *'
workflow_dispatch:

jobs:
make:
mlir-bindings:
runs-on: ubuntu-latest
permissions:
contents: write
@@ -38,7 +40,8 @@ jobs:
working-directory: ./deps/ReactantExtra
env:
JULIA_DEPOT_PATH: ${{ runner.temp }}/julia_depot
- run: |
- name: Make generated files writable
run: |
chmod -R u+rw ./src/mlir/Dialects/
chmod u+rw ./src/mlir/libMLIR_h.jl
git config core.fileMode false
@@ -48,6 +51,9 @@ jobs:
using JuliaFormatter
format("./src/mlir/Dialects/")
format("./src/mlir/libMLIR_h.jl")
# Format twice to work around <https://github.com/domluna/JuliaFormatter.jl/issues/897>.
format("./src/mlir/Dialects/")
format("./src/mlir/libMLIR_h.jl")
- name: Create Pull Request
id: cpr
uses: peter-evans/create-pull-request@v7
@@ -57,6 +63,7 @@ jobs:
title: 'Regenerate MLIR Bindings'
branch: regenerate-mlir-bindings
delete-branch: true
author: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
- name: Check outputs
run: |
echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
1 change: 1 addition & 0 deletions .gitignore
@@ -252,6 +252,7 @@ docs/site/
# environment.
Manifest.toml
Manifest-v*.toml
.CondaPkg

.vscode/*
.vscode/settings.json
4 changes: 2 additions & 2 deletions CondaPkg.toml
@@ -1,2 +1,2 @@
[deps]
jax = ""
[pip.deps]
jax = ">=0.4"
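Moving `jax` from `[deps]` to `[pip.deps]` switches it from a conda package to a PyPI (pip) install with a version bound. A hedged sketch of the equivalent programmatic route with CondaPkg.jl:

```julia
using CondaPkg

# Record jax under [pip.deps] in CondaPkg.toml and resolve the environment;
# this mirrors the TOML edit above.
CondaPkg.add_pip("jax"; version=">=0.4")
CondaPkg.resolve()
```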
44 changes: 35 additions & 9 deletions Project.toml
@@ -1,44 +1,59 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
version = "0.2.11"
version = "0.2.36"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"

[weakdeps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[sources.ReactantCore]
path = "lib/ReactantCore"
[sources]
ReactantCore = {path = "lib/ReactantCore"}

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
ReactantKernelAbstractionsExt = "KernelAbstractions"
ReactantMPIExt = "MPI"
ReactantNNlibExt = "NNlib"
ReactantOffsetArraysExt = "OffsetArrays"
ReactantPythonCallExt = "PythonCall"
ReactantRandom123Ext = "Random123"
ReactantSpecialFunctionsExt = "SpecialFunctions"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"

@@ -47,21 +62,32 @@ AbstractFFTs = "1.5"
Adapt = "4.1"
ArrayInterface = "7.17.1"
CEnum = "0.5"
CUDA = "5.5"
CUDA = "5.6"
Downloads = "1.6"
Enzyme = "0.13.22"
EnumX = "1"
Enzyme = "0.13.28"
EnzymeCore = "0.8.8"
GPUArraysCore = "0.1.6, 0.2"
Functors = "0.5"
GPUArraysCore = "0.2"
GPUCompiler = "1.1.1"
KernelAbstractions = "0.9.30"
LLVM = "9.1"
LLVMOpenMP_jll = "18.1.7"
LinearAlgebra = "1.10"
MPI = "0.20"
NNlib = "0.9.26"
OffsetArrays = "1"
OrderedCollections = "1"
PrecompileTools = "1.2"
Preferences = "1.4"
PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.32"
ReactantCore = "0.1.5"
Reactant_jll = "0.0.80"
Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
Statistics = "1.10"
YaoBlocks = "0.13"
julia = "1.10"
13 changes: 0 additions & 13 deletions README.md
@@ -60,16 +60,3 @@ Reactant.set_default_backend("gpu")

# ones favorite code will now all be executed on GPU, no CUDA.jl dependency even required!
```

## Installing Reactant on GPU Servers without Internet

If you want to use Reactant on GPU Servers where all packages must be installed on the login nodes and the compute nodes don't have access to internet,
add the following to the Project.toml and precompile the package:

```toml
[extras]
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"

[preferences.Reactant_jll]
gpu = "cuda"
```
6 changes: 3 additions & 3 deletions benchmark/setup.jl
@@ -55,20 +55,20 @@ function setup_simple_benchmark!(suite::BenchmarkGroup, backend)
suite["(Basics) 2D sum (2 x 10)"]["forward (compilation)"][backend][tag] = @benchmarkable begin
@compile optimize = $(opt_pass) sum(x)
end setup = begin
x = Reactant.ConcreteRArray(ones(2, 10))
x = Reactant.to_rarray(ones(2, 10))
end

suite["(Basics) sum(cos, x) (2 x 10)"]["forward (compilation)"][backend][tag] = @benchmarkable begin
@compile optimize = $(opt_pass) sumcos(x)
end setup = begin
x = Reactant.ConcreteRArray(ones(2, 10))
x = Reactant.to_rarray(ones(2, 10))
end
end

suite["Basics ∇sumcos (2 x 10)"]["forward (compilation)"][backend]["Reactant"] = @benchmarkable begin
@compile optimize = :all ∇sumcos(x)
end setup = begin
x = Reactant.ConcreteRArray(ones(2, 10))
x = Reactant.to_rarray(ones(2, 10))
end

return nothing
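The switch from `ConcreteRArray` to `to_rarray` throughout the benchmark reflects the recommended entry point: `Reactant.to_rarray` recursively converts host data (including arrays nested inside structs) into backend-resident arrays. A minimal sketch of the same pattern outside the benchmark harness:

```julia
using Reactant

x = Reactant.to_rarray(ones(2, 10))  # move host data onto the active backend
f = @compile sum(x)                  # trace `sum` and compile it through XLA
f(x)                                 # run the compiled executable
```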
9 changes: 1 addition & 8 deletions deps/Project.toml
@@ -1,10 +1,3 @@
[deps]
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
BinaryBuilderBase = "7f725544-6523-48cd-82d1-3fa08ff4056e"

[compat]
Clang = "0.18"
5 changes: 2 additions & 3 deletions deps/ReactantExtra/.bazelrc
@@ -18,14 +18,13 @@ build -c opt
build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NVCC_CLANG=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.6.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.4.0"
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
build:cuda --crosstool_top="@local_config_cuda//crosstool:toolchain"
build:cuda --@local_config_cuda//:enable_cuda
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
# Default hermetic CUDA and CUDNN versions.
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true
build:cuda --@local_config_cuda//:cuda_compiler=nvcc
2,244 changes: 1,425 additions & 819 deletions deps/ReactantExtra/API.cpp

Large diffs are not rendered by default.

316 changes: 278 additions & 38 deletions deps/ReactantExtra/BUILD

Large diffs are not rendered by default.

61 changes: 51 additions & 10 deletions deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "b6d6563aa3a3050474a4250bf18322f7ebf0b486"
ENZYMEXLA_COMMIT = "52f12204b764d0e61da249083ae1a3273da171b7"
ENZYMEXLA_SHA256 = ""

http_archive(
@@ -51,15 +51,6 @@ load("@enzyme_ad//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT",

XLA_PATCHES = XLA_PATCHES + [
"""
sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/backends/cpu/runtime/thunk_executor.h
""",
"""
sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc
""",
"""
sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/tsl/concurrency/async_value_ref.h
""",
"""
sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_LINK_H=1\\/HAVE_LINK_H=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl
""",
"""
@@ -94,6 +85,41 @@ LLVM_TARGETS = select({
"//conditions:default": ["AMDGPU", "NVPTX"],
}) + ["AArch64", "X86", "ARM"]

# Uncomment these lines to use a custom LLVM commit
# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3"
# LLVM_SHA256 = ""
# http_archive(
# name = "llvm-raw",
# build_file_content = "# empty",
# sha256 = LLVM_SHA256,
# strip_prefix = "llvm-project-" + LLVM_COMMIT,
# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
# )
#
#
# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
# maybe(
# http_archive,
# name = "llvm_zlib",
# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
# strip_prefix = "zlib-ng-2.0.7",
# urls = [
# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
# ],
# )
#
# maybe(
# http_archive,
# name = "llvm_zstd",
# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
# strip_prefix = "zstd-1.5.2",
# urls = [
# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
# ],
# )

http_archive(
name = "jax",
sha256 = JAX_SHA256,
@@ -201,6 +227,21 @@ xla_workspace0()
load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()

load("@jax//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository")
jax_python_wheel_repository(
name = "jax_wheel",
version_key = "_version",
version_source = "@jax//jax:version.py",
)

load(
"@tsl//third_party/py:python_wheel.bzl",
"python_wheel_version_suffix_repository",
)
python_wheel_version_suffix_repository(
name = "jax_wheel_version_suffix",
)

load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
19 changes: 18 additions & 1 deletion deps/ReactantExtra/make-bindings.jl
@@ -1,8 +1,16 @@
const bazel_cmd = if !isnothing(Sys.which("bazelisk"))
"bazelisk"
elseif !isnothing(Sys.which("bazel"))
"bazel"
else
error("Could not find `bazel` or `bazelisk` in PATH!")
end

function build_file(output_path)
file = basename(output_path)
run(
Cmd(
`bazel build --action_env=JULIA=$(Base.julia_cmd().exec[1]) --action_env=JULIA_DEPOT_PATH=$(Base.DEPOT_PATH) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file`;
`$(bazel_cmd) build --action_env=JULIA=$(Base.julia_cmd().exec[1]) --action_env=JULIA_DEPOT_PATH=$(Base.DEPOT_PATH) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file`;
dir=@__DIR__,
),
)
@@ -23,6 +31,15 @@ for file in [
"StableHLO.jl",
"CHLO.jl",
"VHLO.jl",
"Llvm.jl",
"Nvvm.jl",
"Gpu.jl",
"Affine.jl",
"TPU.jl",
"Triton.jl",
"Shardy.jl",
"MPI.jl",
"MemRef.jl",
]
build_file(joinpath(src_dir, "mlir", "Dialects", file))
end
9 changes: 7 additions & 2 deletions deps/ReactantExtra/make.jl
@@ -18,9 +18,11 @@ let options = deepcopy(options)

genarg = first(eachsplit(ARGS[3], " "))

gen_include_dir = joinpath(splitpath(genarg)[1:(end - 3)]...)
gen_include_dir = joinpath(splitpath(genarg)[1:(end - 4)]...)

hlo_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...)
hlo_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...)

sdy_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...)

append!(
args,
@@ -33,6 +35,8 @@ let options = deepcopy(options)
gen_include_dir,
"-I",
hlo_include_dir,
"-I",
sdy_include_dir,
"-x",
"c++",
],
@@ -41,6 +45,7 @@ let options = deepcopy(options)
headers = [
detect_headers(include_dir, args, Dict(), endswith("Python/Interop.h"))...,
detect_headers(hlo_include_dir, args, Dict())...,
detect_headers(sdy_include_dir, args, Dict())...,
]

ctx = create_context(headers, args, options)
398 changes: 196 additions & 202 deletions deps/ReactantExtra/tblgen/jl-generators.cc

Large diffs are not rendered by default.

32 changes: 18 additions & 14 deletions deps/ReactantExtra/tblgen/mlir-jl-tblgen.cc
@@ -26,30 +26,33 @@

using namespace llvm;

using generator_function = bool(const llvm::RecordKeeper& recordKeeper,
llvm::raw_ostream& os);
using generator_function = bool(const llvm::RecordKeeper &recordKeeper,
llvm::raw_ostream &os);

struct GeneratorInfo {
const char* name;
generator_function* generator;
const char *name;
generator_function *generator;
};

extern generator_function emitOpTableDefs;
extern generator_function emitTestTableDefs;

static std::array<GeneratorInfo, 1> generators {{
{"jl-op-defs", emitOpTableDefs},
static std::array<GeneratorInfo, 1> generators{{
{"jl-op-defs", emitOpTableDefs},
}};

generator_function* generator;
generator_function *generator;
bool disableModuleWrap;

int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::opt<std::string> generatorOpt("generator", llvm::cl::desc("Generator to run"), cl::Required);
llvm::cl::opt<bool> disableModuleWrapOpt("disable-module-wrap", llvm::cl::desc("Disable module wrap"), cl::init(false));
llvm::cl::opt<std::string> generatorOpt(
"generator", llvm::cl::desc("Generator to run"), cl::Required);
llvm::cl::opt<bool> disableModuleWrapOpt(
"disable-module-wrap", llvm::cl::desc("Disable module wrap"),
cl::init(false));
cl::ParseCommandLineOptions(argc, argv);
for (const auto& spec : generators) {
for (const auto &spec : generators) {
if (generatorOpt == spec.name) {
generator = spec.generator;
break;
@@ -61,7 +64,8 @@ int main(int argc, char **argv) {
}
disableModuleWrap = disableModuleWrapOpt;

return TableGenMain(argv[0], [](raw_ostream& os, const RecordKeeper &records) {
return generator(records, os);
});
}
return TableGenMain(argv[0],
[](raw_ostream &os, const RecordKeeper &records) {
return generator(records, os);
});
}
207 changes: 140 additions & 67 deletions deps/build_local.jl
@@ -1,30 +1,64 @@
# Invoke with
# `julia --project=deps deps/build_local.jl [dbg/opt] [auto/cpu/cuda]`
# `julia --project=deps deps/build_local.jl [--debug] [--backend=auto/cpu/cuda]`

# the pre-built ReactantExtra_jll might not be loadable on this platform
Reactant_jll = Base.UUID("0192cb87-2b54-54ad-80e0-3be72ad8a3c0")

using Pkg, Scratch, Preferences, Libdl
using ArgParse

s = ArgParseSettings()
#! format: off
@add_arg_table! s begin
"--debug"
help = "Build with debug mode (-c dbg)."
action = :store_true
"--backend"
help = "Build with the specified backend (auto, cpu, cuda)."
default = "auto"
arg_type = String
"--gcc_host_compiler_path"
help = "Path to the gcc host compiler."
default = "/usr/bin/gcc"
arg_type = String
"--cc"
default = "/usr/bin/cc"
arg_type = String
"--hermetic_python_version"
help = "Hermetic Python version."
default = "3.10"
arg_type = String
"--jobs"
help = "Number of parallel jobs."
default = Sys.CPU_THREADS
arg_type = Int
"--copt"
help = "Options to be passed to the C compiler. Can be used multiple times."
action = :append_arg
arg_type = String
"--cxxopt"
help = "Options to be passed to the C++ compiler. Can be used multiple times."
action = :append_arg
arg_type = String
"--extraopt"
help = "Extra options to be passed to Bazel. Can be used multiple times."
action = :append_arg
arg_type = String
"--color"
help = "Set to `yes` to enable color output, or `no` to disable it. Defaults to same color setting as the Julia process."
default = something(Base.have_color, false) ? "yes" : "no"
arg_type = String
end
#! format: on
parsed_args = parse_args(ARGS, s)

# 1. Get a scratch directory
scratch_dir = get_scratch!(Reactant_jll, "build")
isdir(scratch_dir) && rm(scratch_dir; recursive=true)
println("Parsed args:")
for (k, v) in parsed_args
println(" $k = $v")
end
println()

source_dir = joinpath(@__DIR__, "ReactantExtra")

# 2. Ensure that an appropriate LLVM_full_jll is installed
Pkg.activate(; temp=true)

# Build!
@info "Building" source_dir scratch_dir
run(`mkdir -p $(scratch_dir)`)
run(
Cmd(
`$(Base.julia_cmd().exec[1]) --project=. -e "using Pkg; Pkg.instantiate()"`;
dir=source_dir,
),
)

#--repo_env TF_NEED_ROCM=1
#--define=using_rocm=true --define=using_rocm_hipcc=true
#--action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"
@@ -41,27 +75,10 @@ run(
# --@local_config_cuda//:cuda_compiler=nvcc
# --crosstool_top="@local_config_cuda//crosstool:toolchain"

build_kind = if length(ARGS) ≥ 1
kind = ARGS[1]
if kind ∉ ("dbg", "opt")
error("Invalid build kind $(kind). Valid options are 'dbg' and 'opt'")
end
kind
else
"dbg"
end

@info "Building JLL with -c $(build_kind)"
build_kind = parsed_args["debug"] ? "dbg" : "opt"

build_backend = if length(ARGS) ≥ 2
backend = ARGS[2]
if backend ∉ ("auto", "cpu", "cuda")
error("Invalid build backend $(backend). Valid options are 'auto', 'cpu', and 'cuda'")
end
backend
else
"auto"
end
build_backend = parsed_args["backend"]
@assert build_backend in ("auto", "cpu", "cuda")

if build_backend == "auto"
build_backend = try
@@ -78,46 +95,102 @@ elseif build_backend == "cpu"
""
end

@info "Building JLL with backend $(build_backend)"

if isempty(arg)
run(
Cmd(
`bazel build -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
--repo_env HERMETIC_PYTHON_VERSION="3.10"
--check_visibility=false --verbose_failures :libReactantExtra.so`;
dir=source_dir,
),
)
bazel_cmd = if !isnothing(Sys.which("bazelisk"))
"bazelisk"
elseif !isnothing(Sys.which("bazel"))
"bazel"
else
run(
Cmd(
`bazel build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
--repo_env HERMETIC_PYTHON_VERSION="3.10"
--check_visibility=false --verbose_failures :libReactantExtra.so`;
dir=source_dir,
),
error("Could not find `bazel` or `bazelisk` in PATH!")
end

@info "Building JLL with $(bazel_cmd)"

gcc_host_compiler_path = parsed_args["gcc_host_compiler_path"]
cc = parsed_args["cc"]
hermetic_python_version = parsed_args["hermetic_python_version"]

# Try to guess if `cc` is GCC and get its version number.
cc_is_gcc, gcc_version = let
io = IOBuffer()
run(pipeline(ignorestatus(`$(cc) --version`); stdout=io))
version_string = String(take!(io))
# Detecting GCC is hard, the name "gcc" may not appear anywhere in the
# version string, but on the second line there should be FSF.
m = match(
r"\([^)]+\) (\d+\.\d+\.\d+).*\n.*Free Software Foundation, Inc\.",
version_string,
)
if !isnothing(m)
true, VersionNumber(m[1])
else
false, v"0"
end
end
# env=Dict("HOME"=>ENV["HOME"], "PATH"=>joinpath(source_dir, "..")*":"*ENV["PATH"])))

run(Cmd(`rm -f libReactantExtra.dylib`; dir=joinpath(source_dir, "bazel-bin")))
run(
Cmd(
`ln -s libReactantExtra.so libReactantExtra.dylib`;
dir=joinpath(source_dir, "bazel-bin"),
),
)

build_cmd_list = [bazel_cmd, "build"]
!isempty(arg) && push!(build_cmd_list, arg)
append!(build_cmd_list, ["-c", "$(build_kind)"])
push!(build_cmd_list, "--action_env=JULIA=$(Base.julia_cmd().exec[1])")
push!(build_cmd_list, "--repo_env=HERMETIC_PYTHON_VERSION=$(hermetic_python_version)")
push!(build_cmd_list, "--repo_env=GCC_HOST_COMPILER_PATH=$(gcc_host_compiler_path)")
push!(build_cmd_list, "--repo_env=CC=$(cc)")
push!(build_cmd_list, "--check_visibility=false")
push!(build_cmd_list, "--verbose_failures")
push!(build_cmd_list, "--jobs=$(parsed_args["jobs"])")
for opt in parsed_args["copt"]
push!(build_cmd_list, "--copt=$(opt)")
end
for opt in parsed_args["cxxopt"]
push!(build_cmd_list, "--cxxopt=$(opt)")
end
for opt in parsed_args["extraopt"]
push!(build_cmd_list, opt)
end
# Some versions of GCC can't deal with some components of XLA, disable them if necessary.
if cc_is_gcc && build_backend == "cuda"
arch = Base.BinaryPlatforms.arch(Base.BinaryPlatforms.HostPlatform())
if arch == "x86_64"
if gcc_version < v"13"
push!(build_cmd_list, "--define=xnn_enable_avxvnniint8=false")
end
if gcc_version < v"12"
push!(build_cmd_list, "--define=xnn_enable_avx512fp16=false")
end
end
end
push!(build_cmd_list, "--color=$(parsed_args["color"])")
push!(build_cmd_list, ":libReactantExtra.so")

run(Cmd(Cmd(build_cmd_list); dir=source_dir))

# Discover built libraries
built_libs = filter(readdir(joinpath(source_dir, "bazel-bin"))) do file
endswith(file, "Extra.$(Libdl.dlext)") && startswith(file, "lib")
endswith(file, "Extra.so") && startswith(file, "lib")
end

lib_path = joinpath(source_dir, "bazel-bin", only(built_libs))
isfile(lib_path) || error("Could not find library $lib_path in build directory")

# Tell ReactReactantExtra_jllant_jll to load our library instead of the default artifact one
if build_backend == "cuda"
if !Base.Filesystem.ispath(joinpath(source_dir, "bazel-bin", "cuda", "bin", "ptxas"))
Base.Filesystem.mkpath(joinpath(source_dir, "bazel-bin", "cuda", "bin"))
Base.Filesystem.symlink(
joinpath(
source_dir,
"bazel-bin",
"libReactantExtra.so.runfiles",
"cuda_nvcc",
"bin",
"ptxas",
),
joinpath(source_dir, "bazel-bin", "cuda", "bin", "ptxas"),
)
end
end

# Tell ReactantExtra_jll to load our library instead of the default artifact one
using Preferences

set_preferences!(
joinpath(dirname(@__DIR__), "LocalPreferences.toml"),
"Reactant_jll",
2 changes: 0 additions & 2 deletions deps/clang

This file was deleted.

2 changes: 0 additions & 2 deletions deps/clang++

This file was deleted.

3 changes: 0 additions & 3 deletions deps/gcc

This file was deleted.

6 changes: 6 additions & 0 deletions docs/Project.toml
@@ -2,7 +2,13 @@
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Documenter = "1.4.1"

[sources]
Reactant = {path = ".."}
ReactantCore = {path = "../lib/ReactantCore"}
50 changes: 27 additions & 23 deletions docs/make.jl
@@ -1,11 +1,13 @@
pushfirst!(LOAD_PATH, joinpath(@__DIR__, ".."))
pushfirst!(LOAD_PATH, joinpath(@__DIR__, "../lib/ReactantCore/"))

using Reactant, ReactantCore
using Documenter, DocumenterVitepress

DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true)

# Helper functions
function first_letter_uppercase(str)
return uppercase(str[1]) * str[2:end]
end

# Generate examples

using Literate
@@ -26,21 +28,24 @@ examples = [

pages = [
"Reactant.jl" => "index.md",
"Introduction" => ["Getting Started" => "introduction/index.md"],
"Tutorials" => ["Overview" => "tutorials/index.md"],
"Introduction" => [
"Getting Started" => "introduction/index.md",
"Configuration" => "introduction/configuration.md",
],
"Tutorials" =>
["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"],
"API Reference" => [
"Reactant API" => "api/api.md",
"Ops" => "api/ops.md",
"Dialects" => [
"ArithOps" => "api/arith.md",
"Affine" => "api/affine.md",
"Builtin" => "api/builtin.md",
"Chlo" => "api/chlo.md",
"Enzyme" => "api/enzyme.md",
"Func" => "api/func.md",
"StableHLO" => "api/stablehlo.md",
"VHLO" => "api/vhlo.md",
],
"Dialects" => sort!(
[
first_letter_uppercase(first(splitext(basename(file)))) =>
joinpath("api/dialects", file) for
file in readdir(joinpath(@__DIR__, "src/api/dialects")) if
splitext(file)[2] == ".md"
];
by=first,
),
"MLIR API" => "api/mlirc.md",
"XLA" => "api/xla.md",
"Internal API" => "api/internal.md",
@@ -55,14 +60,13 @@ makedocs(;
Reactant.MLIR,
Reactant.MLIR.API,
Reactant.MLIR.IR,
Reactant.MLIR.Dialects.chlo,
Reactant.MLIR.Dialects.vhlo,
Reactant.MLIR.Dialects.stablehlo,
Reactant.MLIR.Dialects.enzyme,
Reactant.MLIR.Dialects.arith,
Reactant.MLIR.Dialects.func,
Reactant.MLIR.Dialects.affine,
Reactant.MLIR.Dialects.builtin,
filter(
Base.Fix2(isa, Module),
[
getproperty(Reactant.MLIR.Dialects, x) for
x in names(Reactant.MLIR.Dialects; all=true) if x != :Dialects
],
)...,
],
authors="William Moses <wsmoses@illinois.edu>, Valentin Churavy <vchuravy@mit.edu>",
sitename="Reactant.jl",
70 changes: 51 additions & 19 deletions docs/src/.vitepress/config.mts
@@ -51,25 +51,46 @@ export default defineConfig({
},
nav: [
{ text: "Home", link: "/" },
{ text: "Getting Started", link: "/introduction" },
{ text: "Getting Started",
items: [
{ text: "Introduction", link: "/introduction" },
{ text: "Configuration", link: "/introduction/configuration" },
],
},
{ text: "Benchmarks", link: "https://enzymead.github.io/Reactant.jl/benchmarks/" },
{ text: "Tutorials", link: "/tutorials/" },
{
text: "Tutorials",
items: [
{text: "Overview", link: "/tutorials/"},
{text: "Profiling", link: "/tutorials/profiling"},
],
},
{
text: "API",
items: [
{ text: "Core Reactant API", link: "/api/api" },
{ text: "Sharding", link: "/api/sharding" },
{ text: "Ops", link: "/api/ops" },
{
text: "MLIR Dialects",
items: [
{ text: "ArithOps", link: "/api/arith" },
{ text: "Affine", link: "/api/affine" },
{ text: "Builtin", link: "/api/builtin" },
{ text: "Chlo", link: "/api/chlo" },
{ text: "Enzyme", link: "/api/enzyme" },
{ text: "Func", link: "/api/func" },
{ text: "StableHLO", link: "/api/stablehlo" },
{ text: "VHLO", link: "/api/vhlo" },
{ text: "ArithOps", link: "/api/dialects/arith" },
{ text: "Affine", link: "/api/dialects/affine" },
{ text: "Builtin", link: "/api/dialects/builtin" },
{ text: "Chlo", link: "/api/dialects/chlo" },
{ text: "Enzyme", link: "/api/dialects/enzyme" },
{ text: "EnzymeXLA", link: "/api/dialects/enzymexla" },
{ text: "Func", link: "/api/dialects/func" },
{ text: "GPU", link: "/api/dialects/gpu" },
{ text: "LLVM", link: "/api/dialects/llvm" },
{ text: "MPI", link: "/api/dialects/mpi" },
{ text: "MemRef", link: "/api/dialects/memref" },
{ text: "NVVM", link: "/api/dialects/nvvm" },
{ text: "Shardy", link: "/api/dialects/shardy" },
{ text: "StableHLO", link: "/api/dialects/stablehlo" },
{ text: "Triton", link: "/api/dialects/triton" },
{ text: "TPU", link: "/api/dialects/tpu" },
{ text: "VHLO", link: "/api/dialects/vhlo" },
],
},
{
@@ -88,18 +109,19 @@ export default defineConfig({
],
sidebar: {
"/introduction/": {
// @ts-ignore
text: "Getting Started",
collapsed: false,
items: [
{ text: "Introduction", link: "/introduction" },
{ text: "Configuration", link: "/introduction/configuration" },
],
},
"/tutorials/": {
text: "Tutorials",
collapsed: false,
items: [
{ text: "Overview", link: "/tutorials/" },
{ text: "Profiling", link: "/tutorials/profiling" },
],
},
"/api/": {
@@ -110,19 +132,29 @@ export default defineConfig({
text: "Reactant API",
link: "/api/api",
},
{ text: "Sharding", link: "/api/sharding" },
{ text: "Ops", link: "/api/ops" },
{
text: "MLIR Dialects",
collapsed: false,
items: [
{ text: "ArithOps", link: "/api/arith" },
{ text: "Affine", link: "/api/affine" },
{ text: "Builtin", link: "/api/builtin" },
{ text: "Chlo", link: "/api/chlo" },
{ text: "Enzyme", link: "/api/enzyme" },
{ text: "Func", link: "/api/func" },
{ text: "StableHLO", link: "/api/stablehlo" },
{ text: "VHLO", link: "/api/vhlo" },
{ text: "ArithOps", link: "/api/dialects/arith" },
{ text: "Affine", link: "/api/dialects/affine" },
{ text: "Builtin", link: "/api/dialects/builtin" },
{ text: "Chlo", link: "/api/dialects/chlo" },
{ text: "Enzyme", link: "/api/dialects/enzyme" },
{ text: "EnzymeXLA", link: "/api/dialects/enzymexla" },
{ text: "Func", link: "/api/dialects/func" },
{ text: "GPU", link: "/api/dialects/gpu" },
{ text: "LLVM", link: "/api/dialects/llvm" },
{ text: "MPI", link: "/api/dialects/mpi" },
{ text: "MemRef", link: "/api/dialects/memref" },
{ text: "NVVM", link: "/api/dialects/nvvm" },
{ text: "Shardy", link: "/api/dialects/shardy" },
{ text: "StableHLO", link: "/api/dialects/stablehlo" },
{ text: "Triton", link: "/api/dialects/triton" },
{ text: "TPU", link: "/api/dialects/tpu" },
{ text: "VHLO", link: "/api/dialects/vhlo" },
],
},
{
26 changes: 17 additions & 9 deletions docs/src/api/api.md
@@ -13,6 +13,10 @@ Reactant.@jit

## ReactantCore API

```@docs
within_compile
```

```@docs
@trace
```
@@ -21,20 +25,24 @@ Reactant.@jit

```@docs
@code_hlo
@code_mhlo
@code_xla
```

```@raw html
<br>
```
## Profile XLA

# Internal Functionality
Reactant can hook into XLA's profiler to generate compilation and execution traces.
See the [profiling tutorial](@ref profiling) for more details.

!!! danger "Private"
```@docs
Reactant.Profiler.with_profiler
Reactant.Profiler.annotate
Reactant.Profiler.@annotate
```

These functions are not part of the public API and are subject to change at any time.
## Devices

```@docs
Reactant.Compiler.codegen_unflatten!
Reactant.Compiler.codegen_flatten!
Reactant.Compiler.codegen_xla_call
Reactant.devices
Reactant.addressable_devices
```
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
9 changes: 9 additions & 0 deletions docs/src/api/dialects/enzymexla.md
@@ -0,0 +1,9 @@
```@meta
CollapsedDocStrings = true
```

# EnzymeXLA Dialect

```@autodocs
Modules = [Reactant.MLIR.Dialects.enzymexla]
```
File renamed without changes.
12 changes: 12 additions & 0 deletions docs/src/api/dialects/gpu.md
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# GPU Dialect

Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/GPU/) for
more details.

```@autodocs
Modules = [Reactant.MLIR.Dialects.gpu]
```
12 changes: 12 additions & 0 deletions docs/src/api/dialects/llvm.md
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# LLVM Dialect

Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/LLVM/) for
more details.

```@autodocs
Modules = [Reactant.MLIR.Dialects.llvm]
```
12 changes: 12 additions & 0 deletions docs/src/api/dialects/memref.md
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# MemRef Dialect

Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/MemRef/) for more
details.

```@autodocs
Modules = [Reactant.MLIR.Dialects.memref]
```
12 changes: 12 additions & 0 deletions docs/src/api/dialects/mpi.md
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# MPI Dialect

Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/MPI/) for
more details.

```@autodocs
Modules = [Reactant.MLIR.Dialects.mpi]
```
12 changes: 12 additions & 0 deletions docs/src/api/dialects/nvvm.md
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# NVVM Dialect

Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/NVVMDialect/) for
more details.

```@autodocs
Modules = [Reactant.MLIR.Dialects.nvvm]
```
11 changes: 11 additions & 0 deletions docs/src/api/dialects/shardy.md
@@ -0,0 +1,11 @@
```@meta
CollapsedDocStrings = true
```

# Shardy Dialect

Refer to the [official documentation](https://openxla.org/shardy) for more details.

```@autodocs
Modules = [Reactant.MLIR.Dialects.sdy]
```
File renamed without changes.
12 changes: 12 additions & 0 deletions docs/src/api/dialects/tpu.md
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# TPU Dialect

Refer to the [official documentation](https://github.com/jax-ml/jax/blob/main/jaxlib/mosaic/dialect/tpu/tpu.td) for
more details.

```@autodocs
Modules = [Reactant.MLIR.Dialects.tpu]
```
12 changes: 12 additions & 0 deletions docs/src/api/dialects/triton.md
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# Triton Dialect

Refer to the [official documentation](https://triton-lang.org/main/dialects/TritonDialect.html) for
more details.

```@autodocs
Modules = [Reactant.MLIR.Dialects.tt]
```
File renamed without changes.
8 changes: 6 additions & 2 deletions docs/src/api/internal.md
@@ -4,9 +4,13 @@ CollapsedDocStrings = true

# Internal API

These functions are not part of the public API and are subject to change at any time.
!!! danger "Private"

These functions are not part of the public API and are subject to change at any time.

```@docs
Reactant.REDUB_ARGUMENTS_NAME
Reactant.within_reactant_interpreter
Reactant.Compiler.codegen_unflatten!
Reactant.Compiler.codegen_flatten!
Reactant.Compiler.codegen_xla_call
```
14 changes: 14 additions & 0 deletions docs/src/api/sharding.md
@@ -0,0 +1,14 @@
```@meta
CollapsedDocStrings = true
```

# Sharding API

The `Reactant.Sharding` module provides a high-level API to construct MLIR operations with
support for sharding.

Not all functions in `Reactant.Sharding` have been documented yet.

```@autodocs
Modules = [Reactant.Sharding]
```
161 changes: 161 additions & 0 deletions docs/src/introduction/configuration.md
@@ -0,0 +1,161 @@
# Configuration

When you [install](@ref Installation) `Reactant.jl`, the build of the library powering the package which is compatible with your system will be installed automatically for you.
Below is some information about making sure that you are using the right configuration of Reactant for your machine.

## Reactant with CPU

At the moment Reactant supports only Linux and macOS, on x86-64 and aarch64 architectures in both cases.
If you are using Julia on any of these systems, then Reactant should always support the CPU backend.
In the same environment where you installed Reactant you can verify it by running the following commands:

```julia-repl
julia> import Pkg
julia> Pkg.add("Reactant_jll")
[...]
julia> import Reactant_jll
julia> Reactant_jll.is_available()
true
```

If the last command returns `true`, you are good to go. If you get `false` but you think your system is one of the supported ones listed above, [open an issue](https://github.com/EnzymeAD/Reactant.jl/issues/new/choose).

## Reactant with GPU

At the moment Reactant supports only Nvidia GPUs.

### Nvidia GPU

Reactant can accelerate your code using Nvidia GPUs on Linux, with CUDA Driver 12.1+ on x86-64, and CUDA Driver 12.3+ on aarch64.
You can check if Reactant detected the GPU on your system by running the following commands in the environment where you installed Reactant:

```julia-repl
julia> import Pkg
julia> Pkg.add("Reactant_jll")
[...]
julia> import Reactant_jll
julia> Reactant_jll.is_available()
true
julia> Reactant_jll.host_platform
Linux x86_64 {cuda_version=12.1, cxxstring_abi=cxx11, gpu=cuda, julia_version=1.11.3, libc=glibc, libgfortran_version=5.0.0, libstdcxx_version=3.4.30, mode=opt}
```

As in the CPU section above, we ran `Reactant_jll.is_available()` to make sure Reactant is available at all; the `Reactant_jll.host_platform` variable then gives us more information about the detected platform.
In particular, if you have an Nvidia GPU you should expect to see `gpu=cuda` and `cuda_version=X.Y`, where `X.Y` is a version less than or equal to the version of the CUDA Driver present on your system (don't worry if this is not exactly the same version as your CUDA Driver; that is expected).

#### Debugging installation with Nvidia GPUs

In some cases you may want to get more verbose information from Reactant during its installation process, to see how it detected CUDA.
To do that, you can force re-installation of `Reactant_jll` with increased verbosity by running the following commands:

```julia-repl
julia> rm(joinpath(Base.DEPOT_PATH[1], "compiled", "v$(VERSION.major).$(VERSION.minor)", "Reactant_jll"); recursive=true, force=true)
julia> ENV["JULIA_DEBUG"] = "Reactant_jll";
julia> import Pkg
julia> Pkg.add("Reactant_jll")
[...]
1 dependency had output during precompilation:
┌ Reactant_jll
│ ┌ Debug: Detected CUDA Driver version 12.2.0
│ └ @ Reactant_jll ~/.julia/packages/Reactant_jll/daenT/.pkg/platform_augmentation.jl:60
│ ┌ Debug: Adding include dependency on /lib/x86_64-linux-gnu/libcuda.so.1
│ └ @ Reactant_jll ~/.julia/packages/Reactant_jll/daenT/.pkg/platform_augmentation.jl:108
```

Here you can see that on this system Reactant found the CUDA Driver at `/lib/x86_64-linux-gnu/libcuda.so.1` with version 12.2.0.

#### Installing Reactant on GPU Servers without Internet

If you want to use Reactant on GPU servers where all packages must be installed on the login nodes and the compute nodes don't have access to the internet, add the following to the `Project.toml` and precompile the package:

```toml
[extras]
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"

[preferences.Reactant_jll]
gpu = "cuda"
```
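
For example, a minimal sketch of that follow-up precompilation step on the login node (assuming the active project is the environment containing Reactant):

```julia-repl
julia> import Pkg
julia> Pkg.precompile()
```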

#### Disabling CUDA support

Reactant looks for the CUDA Driver library `libcuda` to determine whether the current system supports Nvidia GPUs.
However, in some cases this library may be present on the machine even though no GPU is actually attached to it, which would trick Reactant's installation process into believing a GPU is available.
Normally this is not a problem: Reactant will detect that, in spite of the CUDA Driver being present, there are no GPUs, and will default to the CPU backend.
If you do experience issues due to a GPU being detected erroneously, you can forcibly disable GPU support by creating a file called `LocalPreferences.toml` in the environment where you installed Reactant with the following content:

```toml
[Reactant_jll]
gpu = "none"
```

Then install the package `Reactant_jll`:

```julia
import Pkg
Pkg.add("Reactant_jll")
```

When you then restart Julia, you should see

```julia-repl
julia> import Reactant_jll
julia> Reactant_jll.is_available()
true
julia> Reactant_jll.host_platform
Linux x86_64 {cuda_version=none, cxxstring_abi=cxx11, gpu=none, julia_version=1.11.3, libc=glibc, libgfortran_version=5.0.0, libstdcxx_version=3.4.30, mode=opt}
```

Reactant is still available for your system, but this time GPU support is disabled.

## Reactant with TPU

Reactant should automatically detect when you are running on a machine with a TPU and dynamically load the necessary modules.
You can verify that a TPU was found correctly with the following commands:

```julia-repl
julia> import Reactant
julia> Reactant.has_tpu()
true
```

### Memory errors on Google Cloud Platform

If you are running Julia on Google Cloud Platform, you may frequently get scary-looking memory-related error messages like:

```
double free or corruption (out)
```

or

```
free(): invalid pointer
```

This happens because in this environment the `LD_PRELOAD` environment variable forces a memory allocator that is incompatible with Julia.
Starting Julia with

```sh
LD_PRELOAD='' julia
```

or unsetting the variable

```sh
unset LD_PRELOAD
```

should solve this issue.
80 changes: 80 additions & 0 deletions docs/src/introduction/index.md
@@ -53,3 +53,83 @@ f = @compile sinsum_add(input1,input2)
# one can now run the program
f(input1, input2)
```


## Tips

### Empty Cache

When you encounter OOM (Out of Memory) errors, you can try to free memory by calling Julia's built-in `GC.gc()` between memory-intensive operations.

!!! note
This will only free memory which is not currently live. If the result of a compiled function is stored in a vector, it is still live and `GC.gc()` won't free it.

```julia
using Reactant
n = 500_000_000
input1 = Reactant.ConcreteRArray(ones(n))
input2 = Reactant.ConcreteRArray(ones(n))

function sin_add(x, y)
return sin.(x) .+ y
end

f = @compile sin_add(input1,input2)

for i = 1:10
GC.gc()
@info "gc... $i"
f(input1, input2) # May cause OOM here for a 24GB GPU if GC is not used
end
```

If you **don't** use `GC.gc()` here, this may cause an OOM error:

```bash
[ Info: gc... 1
[ Info: gc... 2
[ Info: gc... 3
...
E0105 09:48:28.755177 110350 pjrt_stream_executor_client.cc:3088] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4000000000 bytes.
ERROR: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4000000000 bytes.

Stacktrace:
[1] reactant_err(msg::Cstring)
@ Reactant.XLA ~/.julia/packages/Reactant/7m11i/src/XLA.jl:104
[2] macro expansion
@ ~/.julia/packages/Reactant/7m11i/src/XLA.jl:357 [inlined]
[3] ExecutableCall
@ ~/.julia/packages/Reactant/7m11i/src/XLA.jl:334 [inlined]
[4] macro expansion
@ ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:798 [inlined]
[5] (::Reactant.Compiler.Thunk{…})(::ConcreteRArray{…}, ::ConcreteRArray{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:909
[6] top-level scope
@ ./REPL[7]:4
Some type information was truncated. Use `show(err)` to see complete types.
```
After using Julia's built-in `GC.gc()`:
```bash
[ Info: gc... 1
[ Info: gc... 2
[ Info: gc... 3
[ Info: gc... 4
[ Info: gc... 5
[ Info: gc... 6
[ Info: gc... 7
[ Info: gc... 8
[ Info: gc... 9
[ Info: gc... 10
```
Binary file added docs/src/tutorials/images/perfetto.png
Binary file added docs/src/tutorials/images/tensorboard.png
4 changes: 3 additions & 1 deletion docs/src/tutorials/index.md
@@ -1,3 +1,5 @@
# Tutorials

We are currently working on adding tutorials to Reactant!! Please check back soon!
- [Profiling](@ref profiling).

We are currently working on adding more tutorials to Reactant!! Please check back soon!
84 changes: 84 additions & 0 deletions docs/src/tutorials/profiling.md
@@ -0,0 +1,84 @@
# [Profiling](@id profiling)

## Capturing traces

When running Reactant, it is possible to capture traces using the [XLA profiler](https://jax.readthedocs.io/en/latest/profiling.html).
These traces can provide information about where the XLA-specific parts of the program spend time during compilation or execution. Note that tracing and compilation happen on the CPU even when the final execution is meant to run on another device such as a GPU or TPU. Therefore, capturing tracing and compilation in a trace will create annotations on the CPU.

Let's set up a simple function which we can then profile:

```@example profiling
using Reactant
x = Reactant.to_rarray(randn(Float32, 100, 2))
W = Reactant.to_rarray(randn(Float32, 10, 100))
b = Reactant.to_rarray(randn(Float32, 10))
linear(x, W, b) = (W * x) .+ b
```

The profiler can be accessed using the [`Reactant.with_profiler`](@ref Reactant.Profiler.with_profiler) function.

```@example profiling
Reactant.with_profiler("./") do
mylinear = Reactant.@compile linear(x, W, b)
mylinear(x, W, b)
end
```

Running this code should create a folder called `plugins`, containing the trace files, inside the folder provided to `Reactant.with_profiler`.
The traces can then be visualized in different ways.

!!! note
For more insights about the current state of Reactant, it is possible to fetch device information about allocations using the [`Reactant.XLA.allocatorstats`](@ref) function.

## Perfetto UI

![The perfetto interface](images/perfetto.png)

The first and easiest way to visualize a captured trace is to use the online [`perfetto.dev`](https://ui.perfetto.dev/) tool.
[`Reactant.with_profiler`](@ref Reactant.Profiler.with_profiler) has a keyword parameter called `create_perfetto_link` which will create a usable perfetto URL for the generated trace.
The function blocks execution until the URL has been opened and the trace visualized; note that the URL works only once.

```julia
Reactant.with_profiler("./"; create_perfetto_link=true) do
mylinear = Reactant.@compile linear(x, W, b)
mylinear(x, W, b)
end
```

!!! note
It is recommended to use the Chrome browser to open the perfetto URL.

## Tensorboard

![The tensorboard interface](images/tensorboard.png)

Another option for visualizing the generated trace files is the [tensorboard profiler plugin](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras).
The tensorboard viewer can offer more detail than the timeline view, such as visualizations of compute graphs.

First install tensorboard and its profiler plugin:

```bash
pip install tensorboard tensorboard-plugin-profile
```

And then run the following in the folder where the `plugins` folder was generated:

```bash
tensorboard --logdir ./
```

## Adding Custom Annotations

By default, the traces contain only information captured from within XLA.
The [`Reactant.Profiler.annotate`](@ref) function can be used to annotate traces for Julia code evaluated *during tracing*.

```julia
Reactant.Profiler.annotate("my_annotation") do
# Do things...
end
```

The added annotations will be captured in the traces and can be seen in the different viewers along with the default XLA annotations.
When the profiler is not active, the custom annotations have no effect, so they can safely be left in the code at all times.
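
For reference, the API also exposes a [`Reactant.Profiler.@annotate`](@ref) macro. As a hedged sketch (assuming the macro wraps a function definition with a trace label; check the docstring for the exact signature):

```julia
Reactant.Profiler.@annotate "my_function" function my_function(x)
    # Do things...
end
```

With this, every call to `my_function` made during tracing should show up as a `"my_function"` span in the captured traces.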
20 changes: 11 additions & 9 deletions ext/ReactantArrayInterfaceExt.jl
@@ -2,19 +2,21 @@ module ReactantArrayInterfaceExt

using ArrayInterface: ArrayInterface
using Reactant:
Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, Ops
Reactant,
RArray,
ConcretePJRTArray,
ConcretePJRTNumber,
TracedRNumber,
TracedRArray,
AnyTracedRArray,
Ops

ArrayInterface.can_setindex(::Type{<:RArray}) = false
ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false

function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T}
x_c = ConcreteRArray(zeros(T, size(x)))
x_c .= x
return x_c
end

function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T}
return Ops.reshape(vcat(x...), size(x)...)
for aType in
(AbstractArray{<:ConcretePJRTNumber}, AbstractArray{<:TracedRNumber}, AnyTracedRArray)
@eval ArrayInterface.aos_to_soa(x::$aType) = Reactant.aos_to_soa(x)
end

end
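
A hedged usage sketch of the consolidated `aos_to_soa` path above (the `ConcretePJRTNumber` constructor is assumed to work like the former `ConcreteRNumber` one):

```julia
using ArrayInterface, Reactant

# Array-of-structs: a plain Julia array holding concrete Reactant numbers.
xs = [Reactant.ConcretePJRTNumber(Float64(i)) for i in 1:4]

# Struct-of-arrays: a single concrete Reactant array with the same values.
ArrayInterface.aos_to_soa(xs)
```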
890 changes: 812 additions & 78 deletions ext/ReactantCUDAExt.jl

Large diffs are not rendered by default.

96 changes: 96 additions & 0 deletions ext/ReactantKernelAbstractionsExt.jl
@@ -0,0 +1,96 @@
module ReactantKernelAbstractionsExt

using Reactant

import KernelAbstractions as KA

using Adapt: Adapt

## back-end

export ReactantBackend

struct ReactantBackend <: KA.GPU end

function Base.getproperty(x::ReactantBackend, sym::Symbol)
if sym === :always_inline
return true
elseif sym === :prefer_blocks
return false
else
return Base.getfield(x, sym)
end
end

KA.allocate(b::ReactantBackend, ::Type{T}, dims::Tuple) where {T} = KA.zeros(b, T, dims)
function KA.zeros(::ReactantBackend, ::Type{T}, dims::Tuple) where {T}
return ConcretePJRTArray(zeros(T, dims))
end
function KA.ones(::ReactantBackend, ::Type{T}, dims::Tuple) where {T}
return ConcretePJRTArray(ones(T, dims))
end

KA.get_backend(::Reactant.AnyTracedRArray) = ReactantBackend()
KA.get_backend(::Reactant.AnyConcretePJRTArray) = ReactantBackend()
function KA.synchronize(::ReactantBackend) end

Adapt.adapt_storage(::ReactantBackend, a::Array) = a
Adapt.adapt_storage(::ReactantBackend, a::Reactant.AnyTracedRArray) = a
Adapt.adapt_storage(::ReactantBackend, a::Reactant.AnyConcretePJRTArray) = a
Adapt.adapt_storage(::KA.CPU, a::Reactant.AnyConcretePJRTArray) = convert(Array, a)

## memory operations

function KA.copyto!(::ReactantBackend, A, B)
Base.copyto!(A, B)
return A
end

## kernel launch

function KA.mkcontext(kernel::KA.Kernel{ReactantBackend}, _ndrange, iterspace)
return KA.CompilerMetadata{KA.ndrange(kernel),KA.DynamicCheck}(_ndrange, iterspace)
end

function KA.launch_config(kernel::KA.Kernel{ReactantBackend}, ndrange, workgroupsize)
if ndrange isa Integer
ndrange = (ndrange,)
end
if workgroupsize isa Integer
workgroupsize = (workgroupsize,)
end

# partition checked that the ndrange's agreed
if KA.ndrange(kernel) <: KA.StaticSize
ndrange = nothing
end

iterspace, dynamic =
if KA.workgroupsize(kernel) <: KA.DynamicSize && workgroupsize === nothing
# use ndrange as preliminary workgroupsize for autotuning
KA.partition(kernel, ndrange, ndrange)
else
KA.partition(kernel, ndrange, workgroupsize)
end

return ndrange, workgroupsize, iterspace, dynamic
end

KA.argconvert(k::KA.Kernel{ReactantBackend}, arg) = arg

function KA.priority!(::ReactantBackend, prio::Symbol)
if !(prio in (:high, :normal, :low))
error("priority must be one of :high, :normal, :low")
end
return nothing
end

function tokw(ndrange, workgroupsize, obj, args...)
@inline obj(args...; ndrange, workgroupsize)
end

function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing)
@jit tokw(ndrange, workgroupsize, obj, args...)
end

end
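
A hedged sketch of how this backend is meant to be driven from user code, using standard KernelAbstractions kernel syntax (kernel and array names are illustrative):

```julia
using KernelAbstractions, Reactant

# A textbook KA kernel: y .= a .* x .+ y, one work-item per element.
@kernel function axpy!(y, a, @Const(x))
    i = @index(Global)
    @inbounds y[i] = a * x[i] + y[i]
end

x = Reactant.to_rarray(rand(Float32, 1024))
y = Reactant.to_rarray(zeros(Float32, 1024))

backend = KernelAbstractions.get_backend(x)  # ReactantBackend()
axpy!(backend)(y, 2.0f0, x; ndrange=length(x))
```

The call on the instantiated kernel funnels through the `@jit tokw(...)` path defined above, so the launch is traced and compiled like any other Reactant computation.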
36 changes: 36 additions & 0 deletions ext/ReactantMPIExt.jl
@@ -0,0 +1,36 @@
module ReactantMPIExt

using Reactant: Reactant, Distributed
using MPI: MPI

# https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized()

function Distributed.get_coordinator_address(
::Distributed.MPIEnvDetector, timeout_in_seconds::Integer
)
if MPI.Comm_rank(MPI.COMM_WORLD) == 0
hostname = gethostname()
port_id = hash(hostname) % 2^12 + (65535 - 2^12 + 1)
hostname = "$(hostname):$(port_id)"
else
hostname = nothing
end

return MPI.bcast(hostname, MPI.COMM_WORLD; root=0)
end

function Distributed.get_process_count(::Distributed.MPIEnvDetector)
return Int(MPI.Comm_size(MPI.COMM_WORLD))
end

function Distributed.get_process_id(::Distributed.MPIEnvDetector)
return Int(MPI.Comm_rank(MPI.COMM_WORLD))
end

function Distributed.get_local_process_id(::Distributed.MPIEnvDetector)
new_comm = MPI.Comm_split_type(MPI.COMM_WORLD, MPI.COMM_TYPE_SHARED, 0)
return Int(MPI.Comm_rank(new_comm))
end

end
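
A small worked sketch of the coordinator-address logic above: rank 0 derives a port deterministically from its hostname hash, which always falls in the top `2^12` ports, and broadcasts the resulting `host:port` string to all ranks.

```julia
# Illustrative hostname; on rank 0 this would come from gethostname().
hostname = "node001"
port_id = hash(hostname) % 2^12 + (65535 - 2^12 + 1)
@assert 61440 <= port_id <= 65535  # always one of the top 4096 ports
coordinator = "$(hostname):$(port_id)"
```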
2 changes: 1 addition & 1 deletion ext/ReactantNNlibExt.jl
@@ -323,7 +323,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
This case is not optimized and will be slow." maxlog = 1
dims = NNlib.scatter_dims(src, dst, idxs)
colons = ntuple(Returns(Colon()), dims)
start_sizes = ntuple(i -> size(src, i), dims)
start_sizes = ntuple(Base.Fix1(size, src), dims)
results = map(CartesianIndices(idxs)) do k
res = @allowscalar src[colons..., Tuple(idxs[k])...]
res isa TracedRNumber && (res = TracedUtils.broadcast_to_size(res, (1,)))
20 changes: 20 additions & 0 deletions ext/ReactantOffsetArraysExt.jl
@@ -0,0 +1,20 @@
module ReactantOffsetArraysExt

using OffsetArrays
using OffsetArrays: OffsetArray
using Reactant: Reactant, MLIR, Ops, TracedRArray

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(OA::Type{<:OffsetArray}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type),
@nospecialize(sharding),
)
N = ndims(OA)
T = OffsetArrays.parenttype(OA)
T2 = Reactant.traced_type_inner(T, seen, mode, track_numbers, sharding)
return OffsetArray{eltype(T2),N,T2}
end

end
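
A hedged sketch of what this mapping is for: under tracing, an `OffsetArray` keeps its offset axes while its parent is replaced by the corresponding traced/concrete array type (behavior assumed from the method above):

```julia
using OffsetArrays, Reactant

x = OffsetArray(rand(Float32, 4), -1)  # indices 0:3
xr = Reactant.to_rarray(x)             # expected: an OffsetArray wrapping a Reactant array
axes(xr)                               # expected to still be (0:3,)
```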
45 changes: 20 additions & 25 deletions ext/ReactantPythonCallExt.jl
@@ -8,22 +8,22 @@ using PythonCall

const jaxptr = Ref{Py}()

const NUMPY_SIMPLE_TYPES = (
("bool_", Bool),
("int8", Int8),
("int16", Int16),
("int32", Int32),
("int64", Int64),
("uint8", UInt8),
("uint16", UInt16),
("uint32", UInt32),
("uint64", UInt64),
("float16", Float16),
("float32", Float32),
("float64", Float64),
("complex32", ComplexF16),
("complex64", ComplexF32),
("complex128", ComplexF64),
const NUMPY_SIMPLE_TYPES = Dict(
Bool => :bool_,
Int8 => :int8,
Int16 => :int16,
Int32 => :int32,
Int64 => :int64,
UInt8 => :uint8,
UInt16 => :uint16,
UInt32 => :uint32,
UInt64 => :uint64,
Float16 => :float16,
Float32 => :float32,
Float64 => :float64,
ComplexF16 => :complex32,
ComplexF32 => :complex64,
ComplexF64 => :complex128,
)

function PythonCall.pycall(
@@ -32,15 +32,10 @@ function PythonCall.pycall(
jax = jaxptr[]
numpy = jax.numpy
inputs = map((arg0, argNs...)) do arg
JT = eltype(arg)
PT = nothing
for (CPT, CJT) in NUMPY_SIMPLE_TYPES
if JT == CJT
PT = CPT
break
end
end
numpy.zeros(size(arg); dtype=getproperty(numpy, Symbol(PT)))
numpy.zeros(
size(arg);
dtype=getproperty(numpy, NUMPY_SIMPLE_TYPES[Reactant.unwrapped_eltype(arg)]),
)
end
lowered = jax.jit(f).lower(inputs...)
txt = pyconvert(String, lowered.as_text())
118 changes: 118 additions & 0 deletions ext/ReactantSpecialFunctionsExt.jl
@@ -0,0 +1,118 @@
module ReactantSpecialFunctionsExt
using SpecialFunctions
using Reactant: Ops, Reactant, TracedRNumber, ReactantFloat, ReactantInt, ReactantFloatInt
using Reactant.TracedRNumberOverrides: float

for fn in [:digamma, :erf, :erfc, (:loggamma, :lgamma)]
(fns, fno) = fn isa Tuple ? fn : (fn, fn)
@eval(function SpecialFunctions.$fns(x::TracedRNumber{<:ReactantFloatInt})
return Ops.$fno(float(x))
end)
end

function SpecialFunctions.gamma(x::TracedRNumber{<:ReactantFloat})
return exp(Ops.lgamma(float(x)))
end

function SpecialFunctions.gamma(n::TracedRNumber{<:ReactantInt})
return round(gamma(float(n)))
end

function SpecialFunctions.loggamma1p(x::TracedRNumber{<:ReactantFloat})
return loggamma(1 + x)
end

function SpecialFunctions.logfactorial(x::TracedRNumber{<:ReactantInt})
return loggamma(1 + x)
end

# SpecialFunctions.invdigamma

function SpecialFunctions.trigamma(x::TracedRNumber{<:ReactantFloatInt})
return Ops.polygamma(Ops.constant(Float64(1)), float(x))#TODO: change Ops definition
end

function SpecialFunctions.polygamma(
n::TracedRNumber{<:ReactantFloatInt}, x::TracedRNumber{<:ReactantFloatInt}
)
return Ops.polygamma(float(n), float(x))
end

# SpecialFunctions.gamma_inc

# SpecialFunctions.gamma_inc_inv

function SpecialFunctions.loggammadiv(
a::TracedRNumber{T}, b::TracedRNumber{T}
) where {T<:ReactantFloat}
return log(gamma(b) / gamma(a + b))
end

#SpecialFunctions.gamma ...

function SpecialFunctions.beta(
x::TracedRNumber{T}, y::TracedRNumber{T}
) where {T<:ReactantFloatInt}
return gamma(x) * gamma(y) / gamma(x + y)
end

function SpecialFunctions.logbeta(
x::TracedRNumber{T}, y::TracedRNumber{T}
) where {T<:ReactantFloatInt}
return log(abs(beta(x, y)))
end

#TODO: sign function
#SpecialFunctions.logabsbeta
#SpecialFunctions.logabsbinomial

#SpecialFunctions.beta...

#utilities...

function SpecialFunctions.erf(
x::TracedRNumber{T}, y::TracedRNumber{T}
) where {T<:ReactantFloatInt}
return erf(y) - erf(x)
end

#SpecialFunctions.erfcinv

function SpecialFunctions.logerf(
x::TracedRNumber{T}, y::TracedRNumber{T}
) where {T<:ReactantFloatInt}
return log(erf(x, y))
end

function SpecialFunctions.erfcx(x::TracedRNumber{<:ReactantFloatInt})
return exp(float(x^2)) * erfc(x)
end

function SpecialFunctions.logerfc(x::TracedRNumber{<:ReactantFloatInt})
return log(erfc(x))
end

function SpecialFunctions.logerfcx(x::TracedRNumber{<:ReactantFloatInt})
return log(erfcx(x))
end

#Unsupported complex
#SpecialFunctions.erfi

#SpecialFunctions.erfinv
#SpecialFunctions.dawson
#SpecialFunctions.faddeeva

#Airy and Related Functions

#Bessel ...

#Elliptic Integrals

function SpecialFunctions.zeta(
z::TracedRNumber{T}, s::TracedRNumber{T}
) where {T<:ReactantFloatInt}
return Ops.zeta(z, s)
end

end # module ReactantSpecialFunctionsExt
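
A hedged usage sketch of the overloads above (the scalar wrapper follows the `ConcretePJRTNumber` naming introduced in this PR; adjust if your version still uses `ConcreteRNumber`):

```julia
using Reactant, SpecialFunctions

gamma_and_digamma(x) = (gamma(x), digamma(x))

x = Reactant.ConcretePJRTNumber(4.5)
@jit gamma_and_digamma(x)  # gamma is computed as exp(lgamma(x)) per the overload above
```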
18 changes: 11 additions & 7 deletions ext/ReactantStatisticsExt.jl
@@ -4,18 +4,22 @@ using Reactant: AnyTracedRArray
using Reactant.TracedUtils: materialize_traced_array
using Statistics: Statistics

function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N}
A = materialize_traced_array(A)
function Statistics._mean(f::F, A::AnyTracedRArray{T,N}, dims) where {F,T,N}
denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
return mapreduce(identity, +, A; dims) / denom
return mapreduce(f, +, A; dims) / denom
end

function Statistics.var(
A::AnyTracedRArray{T,N}; dims=:, mean=nothing, corrected=true
function Statistics._var(
A::AnyTracedRArray{T,N}, corrected::Bool, mean, ::Colon
) where {T,N}
A = materialize_traced_array(A)
mean === nothing && (mean = Statistics.mean(A))
denom = length(A) - corrected
return mapreduce(abs2, +, A .- mean; dims=:) / denom
end

function Statistics._var(A::AnyTracedRArray{T,N}, corrected::Bool, mean, dims) where {T,N}
mean === nothing && (mean = Statistics.mean(A; dims))
denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
denom = prod(Base.Fix1(size, A), dims) - corrected
return mapreduce(abs2, +, A .- mean; dims) / denom
end
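
The public API is unchanged by this refactor; a minimal usage sketch (values are illustrative):

```julia
using Reactant, Statistics

x = Reactant.to_rarray(randn(Float32, 32, 4))
@jit mean(abs2, x)   # routed through the new _mean(f, A, dims) overload
@jit var(x; dims=1)  # routed through the dims-aware _var overload
```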

2 changes: 1 addition & 1 deletion lib/ReactantCore/Project.toml
@@ -1,7 +1,7 @@
name = "ReactantCore"
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
version = "0.1.3"
version = "0.1.5"

[deps]
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
116 changes: 98 additions & 18 deletions lib/ReactantCore/src/ReactantCore.jl
@@ -3,7 +3,7 @@ module ReactantCore
using ExpressionExplorer: ExpressionExplorer
using MacroTools: MacroTools

export @trace, MissingTracedValue
export @trace, within_compile, MissingTracedValue

# Traits
is_traced(x) = false
@@ -15,10 +15,19 @@ end

MissingTracedValue() = MissingTracedValue(())

Base.zero(::MissingTracedValue) = MissingTracedValue()

const SPECIAL_SYMBOLS = [
:(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core
]

"""
within_compile()
Returns true if this function is executed in a Reactant compilation context, otherwise false.
"""
@inline within_compile() = false # behavior is overwritten in Interpreter.jl

# Code generation
"""
@trace <expr>
@@ -115,6 +124,13 @@ macro trace(expr)
return esc(trace_if_with_returns(__module__, expr))
end
end
Meta.isexpr(expr, :call) && return esc(trace_call(__module__, expr))
if Meta.isexpr(expr, :(.), 2) && Meta.isexpr(expr.args[2], :tuple)
fname = :($(Base.Broadcast.BroadcastFunction)($(expr.args[1])))
args = only(expr.args[2:end]).args
call = Expr(:call, fname, args...)
return esc(trace_call(__module__, call))
end
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr))
Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr)))
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
@@ -158,8 +174,16 @@ function trace_for(mod, expr)
external_syms...,
)

cond_val(s) = :(@isdefined($s) ? $s : nothing)

while_defined = gensym(:while_defined)
locals = Expr[
[Expr(:(=), s, cond_val(s)) for s in external_syms]..., :(args = $(args_init))
]

var_syms = all_syms.args[(begin + 1):end]
reactant_code_block = quote
let args = $(args_init)
let $(locals...)
cond_fn =
$(all_syms) -> begin
local num_iters = div($limit - $start, $step, RoundDown)
@@ -170,19 +194,25 @@ function trace_for(mod, expr)
end
body_fn =
$(all_syms) -> begin
local isdefined_before = isnothing.(Any[$(var_syms...)])
local step_ = $step
local start_ = $start
local $induction = start_ + $counter * step_
$body
($counter + 1, $(all_syms.args[(begin + 1):end]...))
local results_ = Any[
s for (d, s) in zip(isdefined_before, Any[$(var_syms...)]) if !d
]
($counter + 1, results_...)
end

$(ReactantCore).traced_while(cond_fn, body_fn, args)
end
end

return quote
if any($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
if $(within_compile)() && $(any)(
$(is_traced), $(Expr(:tuple, cond_val.(all_syms.args[(begin + 1):end])...))
)
$(reactant_code_block)
else
$(expr)
@@ -195,8 +225,12 @@ function trace_if_with_returns(mod, expr)
new_expr, _, all_check_vars = trace_if(
mod, expr.args[2]; store_last_line=expr.args[1], depth=1
)
cond_name = first(all_check_vars)
original_cond = expr.args[2].args[1]
expr.args[2].args[1] = cond_name
return quote
if any($(is_traced), ($(all_check_vars...),))
$(cond_name) = $(original_cond)
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
$(new_expr)
else
$(expr)
@@ -292,7 +326,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
non_existant_true_branch_vars = setdiff(all_output_vars, all_true_branch_vars)
true_branch_extras = Expr(
:block,
[:($(var) = $(MissingTracedValue())) for var in non_existant_true_branch_vars]...,
[:($(var) = $(MissingTracedValue)()) for var in non_existant_true_branch_vars]...,
)

true_branch_fn = :(($(all_input_vars...),) -> begin
@@ -310,7 +344,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
)
false_branch_extras = Expr(
:block,
[:($(var) = $(MissingTracedValue())) for var in non_existant_false_branch_vars]...,
[:($(var) = $(MissingTracedValue)()) for var in non_existant_false_branch_vars]...,
)

false_branch_fn = :(($(all_input_vars...),) -> begin
@@ -323,29 +357,69 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
)
false_branch_fn = :($(false_branch_fn_name) = $(false_branch_fn))

cond_name = gensym(:cond)

reactant_code_block = quote
$(true_branch_fn)
$(false_branch_fn)
($(all_output_vars...),) = $(traced_if)(
$(cond_expr),
$(cond_name),
$(true_branch_fn_name),
$(false_branch_fn_name),
($(all_input_vars...),),
)
end

all_check_vars = [all_input_vars..., condition_vars...]
non_reactant_code_block = Expr(:if, cond_name, original_expr.args[2])
if length(original_expr.args) > 2 # has else block
append!(non_reactant_code_block.args, original_expr.args[3:end])
end

all_check_vars = [cond_name, all_input_vars..., condition_vars...]
unique!(all_check_vars)

depth > 0 && return (
reactant_code_block, (true_branch_fn_name, false_branch_fn_name), all_check_vars
quote
$(cond_name) = $(cond_expr)
$(reactant_code_block)
end,
(true_branch_fn_name, false_branch_fn_name),
all_check_vars,
)

return quote
if any($(is_traced), ($(all_check_vars...),))
$(cond_name) = $(cond_expr)
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
$(reactant_code_block)
else
$(original_expr)
$(non_reactant_code_block)
end
end
end

function correct_maybe_bcast_call(fname)
startswith(string(fname), '.') || return false, fname, fname
return true, Symbol(string(fname)[2:end]), fname
end

function trace_call(mod, call)
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
f = if bcast
quote
if isdefined(mod, $(Meta.quot(fname_full)))
$(fname_full)
else
Base.Broadcast.BroadcastFunction($(fname))
end
end
else
:($(fname))
end
return quote
if $(within_compile)()
$(traced_call)($f, $(call.args[2:end]...))
else
$(call)
end
end
end
@@ -366,15 +440,21 @@ function traced_if(cond, true_fn, false_fn, args)
return cond ? true_fn(args) : false_fn(args)
end

function traced_while(cond_fn, body_fn, args)
while cond_fn(args...)
args = body_fn(args...)
end
return args
end
function traced_while end # defined inside Reactant.jl

traced_call(f, args...; kwargs...) = f(args...; kwargs...)

function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
return MacroTools.postwalk(expr) do x
if Meta.isexpr(x, :kw) # undo lhs rewriting
if startswith(string(x.args[1]), string(prepend))
return Expr(
:kw,
Symbol(string(x.args[1])[(length(string(prepend)) + 1):end]),
x.args[2],
)
end
end
if x isa Symbol && x ∈ all_vars
return Symbol(prepend, x)
end
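
Putting the pieces together, a hedged end-to-end sketch: under compilation `within_compile()` is inferred as `true` and the traced branch machinery runs, while in plain Julia the original `if` executes untouched (function and variable names are illustrative):

```julia
using Reactant, ReactantCore

function clipped_sum(x)
    s = sum(x)
    @trace if s > 0
        out = s
    else
        out = zero(s)
    end
    return out
end

xr = Reactant.to_rarray(randn(Float32, 16))
@jit clipped_sum(xr)             # both branches traced into a single conditional op
clipped_sum(randn(Float32, 16))  # ordinary Julia: within_compile() returns false
```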
1,562 changes: 1,216 additions & 346 deletions src/Compiler.jl

Large diffs are not rendered by default.

339 changes: 155 additions & 184 deletions src/ConcreteRArray.jl

Large diffs are not rendered by default.

156 changes: 5 additions & 151 deletions src/ControlFlow.jl
@@ -1,159 +1,13 @@
function ReactantCore.traced_if(
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args
) where {TFn,FFn}
(_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.TracedUtils.make_mlir_fn(
true_fn,
args,
(),
string(gensym("true_branch")),
false;
return_dialect=:stablehlo,
no_args_in_result=true,
construct_function_without_args=true,
)

(_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.TracedUtils.make_mlir_fn(
false_fn,
args,
(),
string(gensym("false_branch")),
false;
return_dialect=:stablehlo,
no_args_in_result=true,
construct_function_without_args=true,
)

@assert length(true_branch_results) == length(false_branch_results) "true branch returned $(length(true_branch_results)) results, false branch returned $(length(false_branch_results)). This shouldn't happen."

result_types = MLIR.IR.Type[]
linear_results = []
true_block_insertions = []
false_block_insertions = []
for (i, (tr, fr)) in enumerate(zip(true_branch_results, false_branch_results))
if typeof(tr) != typeof(fr)
if !(tr isa MissingTracedValue) && !(fr isa MissingTracedValue)
error("Result #$(i) for the branches have different types: true branch \
returned `$(typeof(tr))`, false branch returned `$(typeof(fr))`.")
elseif tr isa MissingTracedValue
push!(result_types, MLIR.IR.type(fr.mlir_data))
push!(linear_results, TracedUtils.new_traced_value(false_linear_results[i]))
push!(true_block_insertions, (i => linear_results[end]))
else
push!(result_types, MLIR.IR.type(tr.mlir_data))
push!(linear_results, TracedUtils.new_traced_value(true_linear_results[i]))
push!(false_block_insertions, (i => linear_results[end]))
end
else
push!(result_types, MLIR.IR.type(tr.mlir_data))
push!(linear_results, TracedUtils.new_traced_value(tr))
end
end

# Replace all uses of missing values with the correct values
true_branch_region = get_region_removing_missing_values(
true_branch_compiled, true_block_insertions
)

false_branch_region = get_region_removing_missing_values(
false_branch_compiled, false_block_insertions
)

MLIR.IR.rmfromparent!(true_branch_compiled)
MLIR.IR.rmfromparent!(false_branch_compiled)

if_compiled = MLIR.Dialects.stablehlo.if_(
cond.mlir_data;
true_branch=true_branch_region,
false_branch=false_branch_region,
result_0=result_types,
)

return map(enumerate(linear_results)) do (i, res)
res.mlir_data = MLIR.IR.result(if_compiled, i)
return res
end
end

function ReactantCore.traced_while(
cond_fn::CFn, body_fn::BFn, args
) where {CFn<:Function,BFn<:Function}
# TODO: detect and prevent mutation within the condition

# We promote all incoming args (is there a better way to do this?)
traced_args = [
if v isa Number && !(v isa TracedType)
Reactant.TracedUtils.promote_to(TracedRNumber{typeof(v)}, v)
else
v
end for v in args
]

(_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn(
cond_fn,
traced_args,
(),
string(gensym("cond_fn")),
false;
no_args_in_result=true,
return_dialect=:stablehlo,
do_transpose=false,
)

(_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn(
body_fn,
traced_args,
(),
string(gensym("body_fn")),
false;
no_args_in_result=true,
return_dialect=:stablehlo,
do_transpose=false,
)

cond_reg = take_region(cond_fn_compiled)
body_reg = take_region(body_fn_compiled)

MLIR.IR.rmfromparent!(cond_fn_compiled)
MLIR.IR.rmfromparent!(body_fn_compiled)

result_0 = in_tys

operands = MLIR.IR.Value[v.mlir_data for v in traced_args]

while_compiled = MLIR.Dialects.stablehlo.while_(
operands; result_0, cond=cond_reg, body=body_reg
)

return map(enumerate(traced_args)) do (i, res)
res.mlir_data = MLIR.IR.result(while_compiled, i)
return res
end
return Ops.if_condition(cond, true_fn, false_fn, args...)
end

function take_region(compiled_fn)
region = MLIR.IR.Region()
MLIR.API.mlirRegionTakeBody(region, MLIR.API.mlirOperationGetRegion(compiled_fn, 0))
return region
function ReactantCore.traced_call(f::Function, args...)
return Ops.call(f, args...)
end

function get_region_removing_missing_values(compiled_fn, insertions)
region = take_region(compiled_fn)
block = MLIR.IR.Block(MLIR.API.mlirRegionGetFirstBlock(region), false)
return_op = MLIR.IR.terminator(block)
for (i, rt) in insertions
if rt isa TracedRNumber
attr = MLIR.IR.DenseElementsAttribute(Array{eltype(rt)}(undef, ()))
op = MLIR.Dialects.stablehlo.constant(; value=attr)
elseif rt isa TracedRArray
attr = MLIR.IR.DenseElementsAttribute(Array{eltype(rt)}(undef, size(rt)))
op = MLIR.Dialects.stablehlo.constant(; value=attr)
else
error("Unknown type $(typeof(rt))")
end
MLIR.IR.rmfromparent!(op)
insert!(block, 1, op)
val = MLIR.IR.result(op, 1)
MLIR.API.mlirValueReplaceAllUsesOfWith(MLIR.IR.operand(return_op, i), val)
end
return region
function ReactantCore.traced_while(cond_fn::CFn, body_fn::BFn, args) where {CFn,BFn}
return Ops.while_loop(cond_fn, body_fn, args...)
end
60 changes: 60 additions & 0 deletions src/Devices.jl
@@ -0,0 +1,60 @@
"""
devices(backend::String)
devices(backend::XLA.AbstractClient = XLA.default_backend())

Return a list of devices available for the given client.
"""
devices(backend::String) = devices(XLA.client(backend))

devices(client::XLA.AbstractClient=XLA.default_backend()) = XLA.devices(client)

"""
addressable_devices(backend::String)
addressable_devices(backend::XLA.AbstractClient = XLA.default_backend())

Return a list of addressable devices available for the given client.
"""
addressable_devices(backend::String) = addressable_devices(XLA.client(backend))

function addressable_devices(client::XLA.AbstractClient=XLA.default_backend())
return XLA.addressable_devices(client)
end

# https://github.com/jax-ml/jax/blob/152099ee0ef31119f16f4c2dac50d84fcb1575ef/jax/_src/hardware_utils.py#L19-L55
const _GOOGLE_PCI_VENDOR_ID = "0x1ae0"
const _TPU_PCI_DEVICE_IDS = (
# TPU v2, v3
"0x0027",
# No public name (plc)
"0x0056",
# TPU v4
"0x005e",
# TPU v5p
"0x0062",
# TPU v5e
"0x0063",
# TPU v6e
"0x006f",
)

function has_tpu()
Sys.islinux() || return false

devices_dir = "/sys/bus/pci/devices/"
isdir(devices_dir) || return false

try
for path in readdir(devices_dir; join=true, sort=false)
if strip(read(joinpath(path, "vendor"), String)) == _GOOGLE_PCI_VENDOR_ID &&
strip(read(joinpath(path, "device"), String)) in _TPU_PCI_DEVICE_IDS
return true
end
end
catch ex
@warn "failed to query PCI device information" maxlog = 1 exception = (
ex, catch_backtrace()
)
end

return false
end
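
A usage sketch of the new device queries (backend names follow the usual XLA convention, e.g. `"cpu"`; the output depends on the machine):

```julia
using Reactant

Reactant.devices()               # all devices of the default backend
Reactant.devices("cpu")          # devices of a specific backend, selected by name
Reactant.addressable_devices()   # only the devices addressable by this process
Reactant.has_tpu()               # true only on Linux hosts exposing a TPU PCI device
```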
162 changes: 162 additions & 0 deletions src/Distributed.jl
@@ -0,0 +1,162 @@
module Distributed

using ..Reactant: Reactant

const initialized = Ref(false)

function initialize(;
coordinator_address::Union{Nothing,String}=nothing,
num_processes::Union{Nothing,Integer}=nothing,
process_id::Union{Nothing,Integer}=nothing,
local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing,
initialization_timeout_in_seconds::Integer=300,
kwargs...,
)
@assert !initialized[] "`Distributed.initialize` has already been called"

(coordinator_address, num_processes, process_id, local_gpu_device_ids) = auto_detect_unset_distributed_params(;
coordinator_address,
num_processes,
process_id,
local_gpu_device_ids,
initialization_timeout_in_seconds,
)

@debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_gpu_device_ids

Reactant.XLA.update_global_state!(;
coordinator_address, num_processes, process_id, local_gpu_device_ids, kwargs...
)

@debug "New Global State" Reactant.XLA.global_state

initialized[] = true
return nothing
end

abstract type AbstractClusterEnvDetector end

abstract type AbstractOMPIClusterEnvDetector <: AbstractClusterEnvDetector end

struct OpenMPIORTEEnvDetector <: AbstractOMPIClusterEnvDetector end
struct OpenMPIPMIXEnvDetector <: AbstractOMPIClusterEnvDetector end

struct MPIEnvDetector <: AbstractClusterEnvDetector end

# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/cluster.py

is_env_present(::AbstractClusterEnvDetector) = false

function get_coordinator_address end
function get_process_count end
function get_process_id end
function get_local_process_id end

function auto_detect_unset_distributed_params(;
detector_list=[OpenMPIORTEEnvDetector(), OpenMPIPMIXEnvDetector(), MPIEnvDetector()],
coordinator_address::Union{Nothing,String}=nothing,
num_processes::Union{Nothing,Integer}=nothing,
process_id::Union{Nothing,Integer}=nothing,
local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing,
initialization_timeout_in_seconds::Integer=300,
)
if all(
Base.Fix2(!==, nothing),
(coordinator_address, num_processes, process_id, local_gpu_device_ids),
)
return coordinator_address, num_processes, process_id, local_gpu_device_ids
end

idx = findfirst(is_env_present, detector_list)
if idx === nothing
error("Couldn't find a functional cluster environment detector. Attempted to use: \
$(detector_list)")
end

detector = detector_list[idx]

@debug "Detected cluster environment" detector

if coordinator_address === nothing
coordinator_address = get_coordinator_address(
detector, initialization_timeout_in_seconds
)
end

if num_processes === nothing
num_processes = get_process_count(detector)
end

if process_id === nothing
process_id = get_process_id(detector)
end

if local_gpu_device_ids === nothing
local_gpu_device_ids = [get_local_process_id(detector)]
end

return coordinator_address, num_processes, process_id, local_gpu_device_ids
end

# OpenMPIORTEEnvDetector & OpenMPIPMIXEnvDetector
# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/ompi_cluster.py and adapted for latest OpenMPI versions
const _ORTE_URI = "OMPI_MCA_orte_hnp_uri"
const _PMIX_SERVER_URI = (
"PMIX_SERVER_URI2",
"PMIX_SERVER_URI3",
"PMIX_SERVER_URI4",
"PMIX_SERVER_URI41",
"PMIX_SERVER_URI21",
)
const _OMPI_PROCESS_COUNT = "OMPI_COMM_WORLD_SIZE"
const _OMPI_PROCESS_ID = "OMPI_COMM_WORLD_RANK"
const _OMPI_LOCAL_PROCESS_ID = "OMPI_COMM_WORLD_LOCAL_RANK"

is_env_present(::OpenMPIORTEEnvDetector) = haskey(ENV, _ORTE_URI)
is_env_present(::OpenMPIPMIXEnvDetector) = any(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)

function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer)
orte_uri = ENV[_ORTE_URI]

job_id = parse(Int, split(orte_uri, '.'; limit=2)[1])
port = job_id % 2^12 + (65535 - 2^12 + 1)

launcher_ip_match = match(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)

@assert launcher_ip_match !== nothing "Could not parse coordinator IP address from \
Open MPI environment."

launcher_ip = launcher_ip_match.captures[findfirst(
!isnothing, launcher_ip_match.captures
)]
return "$(launcher_ip):$(port)"
end

function get_coordinator_address(::OpenMPIPMIXEnvDetector, ::Integer)
varname = findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
pmix_uri = ENV[_PMIX_SERVER_URI[varname]]

job_id = parse(Int, split(split(pmix_uri, '-'; limit=3)[3], "@"; limit=2)[1])
port = job_id % 2^12 + (65535 - 2^12 + 1)

launcher_ip_match = match(r"tcp4://(.+?):|tcp6://\[(.+?)\]", pmix_uri)

@assert launcher_ip_match !== nothing "Could not parse coordinator IP address from \
Open MPI environment."

launcher_ip = launcher_ip_match.captures[findfirst(
!isnothing, launcher_ip_match.captures
)]

return "$(launcher_ip):$(port)"
end

get_process_count(::AbstractOMPIClusterEnvDetector) = parse(Int, ENV[_OMPI_PROCESS_COUNT])

get_process_id(::AbstractOMPIClusterEnvDetector) = parse(Int, ENV[_OMPI_PROCESS_ID])

function get_local_process_id(::AbstractOMPIClusterEnvDetector)
return parse(Int, ENV[_OMPI_LOCAL_PROCESS_ID])
end

end
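
A hedged initialization sketch: under a detected cluster environment (Open MPI ORTE/PMIx or MPI) everything is auto-detected, while on other setups the parameters can be passed explicitly (the address and counts below are illustrative):

```julia
using Reactant

# Fully automatic when a cluster environment detector matches:
# Reactant.Distributed.initialize()

# Explicit parameters when no detector applies:
Reactant.Distributed.initialize(;
    coordinator_address="10.0.0.1:61440",  # host:port of process 0
    num_processes=4,
    process_id=0,                          # unique in 0:(num_processes - 1)
    local_gpu_device_ids=[0],
)
```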
7 changes: 7 additions & 0 deletions src/Enzyme.jl
@@ -0,0 +1,7 @@
# TODO: move the overload_autodiff here as well

# The default `onehot` will lead to scalar indexing
function Enzyme.onehot(x::TracedRArray{T,N}) where {T,N}
x_arr = zeros(T, size(x))
return map(Base.Fix1(TracedUtils.promote_to, TracedRArray{T,N}), Enzyme.onehot(x_arr))
end
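
For context, a small sketch on plain arrays of the standard-basis seeds that `Enzyme.onehot` produces, which the traced overload above mirrors without scalar writes:

```julia
using Enzyme

Enzyme.onehot(zeros(3))
# ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0])
```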
70 changes: 54 additions & 16 deletions src/Interpreter.jl
@@ -39,6 +39,25 @@ function set_reactant_abi(
)
(; fargs, argtypes) = arginfo

if f === ReactantCore.within_compile
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
else
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
end
end
@static if VERSION < v"1.11.0-"
return CallMeta(
Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure()
)
else
return CallMeta(
Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure()
)
end
end

# Improve inference by considering call_with_reactant as having the same results as
# the original call
if f === Reactant.call_with_reactant
@@ -64,8 +83,9 @@ end
ReactantCacheToken(),
REACTANT_METHOD_TABLE,
world,
true, #=forward_rules=#
true, #=reverse_rules=#
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
)
@@ -80,8 +100,9 @@ else
REACTANT_CACHE,
REACTANT_METHOD_TABLE,
world,
true, #=forward_rules=#
true, #=reverse_rules=#
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
)
@@ -167,7 +188,7 @@ end
function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse)
TracedUtils.push_val!(ad_inputs, x.val, path)
if !reverse
ET = eltype(x.val)
ET = unwrapped_eltype(x.val)
predims = size(x.val)
cval = MLIR.IR.result(
MLIR.Dialects.stablehlo.concatenate(
@@ -182,7 +203,7 @@ end
function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse)
TracedUtils.push_val!(ad_inputs, x.val, path)
if !reverse
ET = eltype(x.val)
ET = unwrapped_eltype(x.val)
predims = size(x.val)
cval = MLIR.IR.result(
MLIR.Dialects.stablehlo.concatenate(
@@ -206,14 +227,13 @@ function set_act!(inp, path, reverse, tostore; emptypath=false)
end

#if inp isa Enzyme.Active || !reverse
x.mlir_data = tostore
TracedUtils.set_mlir_data!(x, tostore)
#else
# x.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(x.mlir_data, tostore), 1)
#end

if emptypath
x.paths = ()
end
emptypath && TracedUtils.set_paths!(x, ())
return nothing
end

function overload_autodiff(
@@ -235,9 +255,12 @@ function overload_autodiff(
primf = f.val
primargs = ((v.val for v in args)...,)

fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn(
mlir_fn_res = TracedUtils.make_mlir_fn(
primf, primargs, (), string(f) * "_autodiff", false
)
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
fnwrap = mlir_fn_res.fnwrapped
func2 = mlir_fn_res.f

activity = Int32[]
ad_inputs = MLIR.IR.Value[]
@@ -264,22 +287,35 @@ function overload_autodiff(
for a in linear_results
if TracedUtils.has_residx(a)
if needs_primal(CMode)
push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)))
push!(
outtys,
TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))),
)
end
if CMode <: Enzyme.ForwardMode && !(A <: Enzyme.Const)
if width == 1
push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)))
push!(
outtys,
TracedUtils.transpose_ty(
MLIR.IR.type(TracedUtils.get_mlir_data(a))
),
)
else
push!(
outtys,
TracedUtils.batch_ty(
width, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))
width,
TracedUtils.transpose_ty(
MLIR.IR.type(TracedUtils.get_mlir_data(a))
),
),
)
end
end
else
push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)))
push!(
outtys, TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a)))
)
end
end
for (i, act) in enumerate(activity)
@@ -298,7 +334,9 @@ function overload_autodiff(
act = act_from_type(A, reverse, needs_primal(CMode))
push!(ret_activity, act)
if act == enzyme_out || act == enzyme_outnoneed
attr = fill(MLIR.IR.Attribute(eltype(a)(1)), Ops.mlir_type(a))
attr = MLIR.IR.DenseElementsAttribute(
fill(one(unwrapped_eltype(a)), size(a))
)
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
push!(ad_inputs, cst)
end