diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 9a9014b649..d78295e985 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -32,43 +32,7 @@ steps: cuda: "*" env: REACTANT_TEST_GROUP: "{{matrix.group}}" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 120 - - - label: ":julia: :linux: aarch64 - Julia v{{matrix.version}} -- {{matrix.group}}" - matrix: - setup: - version: - - "1.10" - - "1.11" - group: - - core - - neural_networks - - integration - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.version}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - - lib/ReactantCore/src - commands: | - julia --project=. -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.develop([PackageSpec(path="lib/ReactantCore")])' - - julia --project=. -e 'println("--- :julia: Run Tests") - using Pkg - Pkg.test(; coverage="user")' - agents: - queue: "juliaecosystem" - os: "linux" - sandbox_capable: "true" - arch: "aarch64" - env: - REACTANT_TEST_GROUP: "{{matrix.group}}" + CUDA_VISIBLE_DEVICES: 0 if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 120 diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000..9b3aa8b721 --- /dev/null +++ b/.clang-format @@ -0,0 +1 @@ +BasedOnStyle: LLVM diff --git a/.github/workflows/CI-localjll.yml b/.github/workflows/CI-localjll.yml new file mode 100644 index 0000000000..ddb9c0dbe8 --- /dev/null +++ b/.github/workflows/CI-localjll.yml @@ -0,0 +1,93 @@ +name: CI with local libReactant + +on: + pull_request: + paths: + - '.github/workflows/CI-localjll.yml' + - 'deps/**' + push: + branches: + - main + - release-* + tags: '*' + paths: + - '.github/workflows/CI-localjll.yml' + - 'deps/**' + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - local libReactant - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - '1.10' + - '1.11' + os: + - ubuntu-24.04 + - macOS-latest + exclude: + - os: macOS-latest + version: '1.10' + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/cache@v2 + - uses: bazel-contrib/setup-bazel@0.14.0 + name: Set up Bazel + with: + # Avoid downloading Bazel every time. + bazelisk-cache: true + # Store build cache per workflow. + disk-cache: ${{ github.workflow }}-${{ matrix.os }}-${{ matrix.version }} + # Share repository cache between workflows. + repository-cache: true + bazelisk-version: 1.x + - name: Prepare build on macOS + if: ${{ startsWith(matrix.os, 'macOS-') }} + run: | + echo "SDKROOT=$(xcrun --show-sdk-path)" >> "${GITHUB_ENV}" + - name: Build libReactant + run: | + python -m pip install numpy + julia --color=yes --project=deps -e 'using Pkg; Pkg.instantiate()' + julia --color=yes --project=deps deps/build_local.jl + cp LocalPreferences.toml test/ + - name: "Install Dependencies" + run: | + import Pkg + Pkg.Registry.update() + # Install packages present in subdirectories + dev_pks = Pkg.PackageSpec[] + for path in ("lib/ReactantCore",) + push!(dev_pks, Pkg.PackageSpec(; path)) + end + Pkg.develop(dev_pks) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} + # Only in Julia v1.10 we need to install `ReactantCore` manually. + if: ${{ matrix.version == '1.10' }} + env: + JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager + - name: "Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} + id: run_tests + env: + JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager + XLA_FLAGS: "--xla_force_host_platform_device_count=8" + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v5 + with: + files: lcov.info diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 52014c48ef..0d84430243 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,11 +1,26 @@ name: CI + on: pull_request: + paths: + - '.github/workflows/CI.yml' + - 'ext/**' + - 'lib/**' + - 'src/**' + - 'test/**' + - 'Project.toml' push: branches: - main - release-* tags: '*' + paths: + - '.github/workflows/CI.yml' + - 'ext/**' + - 'lib/**' + - 'src/**' + - 'test/**' + - 'Project.toml' concurrency: # Skip intermediate builds: always. @@ -15,7 +30,8 @@ concurrency: jobs: test: - name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.libReactant }} libReactant - assertions=${{ matrix.assertions }} - ${{ github.event_name }} + timeout-minutes: 90 + name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - assertions=${{ matrix.assertions }} - ${{ github.event_name }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -25,63 +41,49 @@ jobs: - '1.11' # - 'nightly' os: - - ubuntu-20.04 + - ubuntu-24.04 + # `ubuntu-22.04-arm` is considered more stable than `ubuntu-24.04-arm`: + # . + - ubuntu-22.04-arm + - macOS-13 - macOS-latest test_group: - core - neural_networks - integration - arch: - - x64 - - aarch64 assertions: - false - libReactant: [packaged] include: - - os: ubuntu-20.04 - arch: x64 - libReactant: packaged + - os: ubuntu-24.04 version: '1.10' assertions: true test_group: core - - os: ubuntu-20.04 - arch: x64 - libReactant: packaged + - os: ubuntu-24.04 version: '1.10' assertions: true test_group: neural_networks - - os: ubuntu-20.04 - arch: x64 - libReactant: packaged + - os: ubuntu-24.04 version: '1.10' assertions: true test_group: integration - # - os: ubuntu-20.04 - # arch: x86 + # - os: ubuntu-24.04 # libReactant: packaged # version: '1.10' # test_group: core - # - os: ubuntu-20.04 - # arch: x86 + # - os: ubuntu-24.04 # libReactant: packaged # version: '1.10' # test_group: neural_networks - # - os: ubuntu-20.04 - # arch: x86 + # - os: ubuntu-24.04 # libReactant: packaged # version: '1.10' # test_group: integration - exclude: - # these are run on Buildkite - - os: ubuntu-20.04 - arch: aarch64 steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 if: ${{ ! matrix.assertions }} with: version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - uses: julia-actions/cache@v2 - uses: actions/checkout@v4 if: ${{ matrix.assertions }} @@ -95,23 +97,7 @@ jobs: sed -i.bak 's/exit 2/exit 0/g' julia/deps/tools/jlchecksum make -C julia -j $(nproc) FORCE_ASSERTIONS=1 LLVM_ASSERTIONS=1 JULIA_PRECOMPILE=0 echo $PWD/julia/usr/bin >> $GITHUB_PATH - - name: Build libReactant - if: ${{ matrix.libReactant == 'local' && matrix.os != 'macOS-latest'}} - id: build_libreactant - run: | - python -m pip install numpy - julia --color=yes --project=deps -e 'using Pkg; Pkg.instantiate()' - julia --color=yes --project=deps deps/build_local.jl - cp LocalPreferences.toml test/ - - name: Build libReactant MacOS - if: ${{ matrix.libReactant == 'local' && matrix.os == 'macOS-latest'}} - id: build_libreactant_mac - run: | - python -m pip install numpy - julia --color=yes --project=deps -e 'using Pkg; Pkg.instantiate()' - SDKROOT=`xcrun --show-sdk-path` julia --color=yes --project=deps deps/build_local.jl - cp LocalPreferences.toml test/ - - name: "Install Dependencies and Run Tests" + - name: "Install Dependencies" run: | import Pkg Pkg.Registry.update() @@ -121,46 +107,24 @@ jobs: push!(dev_pks, Pkg.PackageSpec(; path)) end Pkg.develop(dev_pks) - Pkg.instantiate() + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} + # Only in Julia v1.10 we need to install `ReactantCore` manually. + if: ${{ matrix.version == '1.10' }} + env: + JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager + - name: "Run Tests" + timeout-minutes: 60 + run: | + import Pkg + Pkg.Registry.update() Pkg.test(; coverage="user") shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} id: run_tests env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager REACTANT_TEST_GROUP: ${{ matrix.test_group }} + XLA_FLAGS: "--xla_force_host_platform_device_count=8" - uses: julia-actions/julia-processcoverage@v1 - if: steps.run_tests.outcome == 'success' - uses: codecov/codecov-action@v5 - if: steps.run_tests.outcome == 'success' with: files: lcov.info - docs: - name: Documentation - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: '1' - - uses: julia-actions/cache@v2 - - run: | - julia --color=yes --project=docs -e ' - using Pkg - Pkg.develop([ - PackageSpec(path=pwd()), - PackageSpec("Reactant_jll"), - PackageSpec(path="lib/ReactantCore") - ]) - Pkg.instantiate()' - env: - JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - - run: | - julia --color=yes --project=docs -e ' - using Documenter: DocMeta, doctest - using Reactant - DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true) - doctest(Reactant)' - - run: julia --color=yes --project=docs docs/make.jl - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index b90370aa1a..6e8da06dca 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -33,7 +33,10 @@ jobs: name = "CompatHelper" uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" version = "3" - Pkg.add(; name, uuid, version) + # Temporarily use debugging version + url = "https://github.com/JuliaRegistries/CompatHelper.jl.git" + rev = "f408ea193f9573c68a68d72932bcd56268c60340" + Pkg.add(; url, rev) shell: julia --color=yes {0} - name: "Run CompatHelper" run: | diff --git a/.github/workflows/Documenter.yaml b/.github/workflows/Documenter.yaml new file mode 100644 index 0000000000..b7e51cc450 --- /dev/null +++ b/.github/workflows/Documenter.yaml @@ -0,0 +1,60 @@ +name: Documentation + +on: + pull_request: + paths: + - '.github/workflows/Documenter.yaml' + - 'docs/**' + - 'lib/**' + - 'src/**' + push: + branches: + - main + tags: '*' + paths: + - '.github/workflows/Documenter.yaml' + - 'docs/**' + - 'lib/**' + - 'src/**' + +concurrency: + # Same group concurrency as the `PreviewCleanup.yml` workflow, because they both + # git-push to the same branch, so we want to avoid clashes. NOTE: this is + # different from the concurrency group below, which is to cancel successive + # jobs from within the PR. + group: docs-pushing + +jobs: + docs: + name: Documentation + runs-on: ubuntu-latest + concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: '1' + - uses: julia-actions/cache@v2 + - name: Instantiate docs environment + run: | + julia --color=yes --project=docs -e ' + using Pkg + Pkg.instantiate()' + env: + JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager + - name: Run doctests + run: | + julia --color=yes --project=docs -e ' + using Documenter: DocMeta, doctest + using Reactant + DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true) + doctest(Reactant)' + - name: Build documentation + run: julia --color=yes --project=docs docs/make.jl + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} diff --git a/.github/workflows/PreviewCleanup.yml b/.github/workflows/PreviewCleanup.yml new file mode 100644 index 0000000000..8ecaab80f9 --- /dev/null +++ b/.github/workflows/PreviewCleanup.yml @@ -0,0 +1,30 @@ +name: Doc Preview Cleanup + +on: + pull_request: + types: [closed] + +concurrency: + # Same group concurrency as the `docs.yml` workflow, because they both + # git-push to the same branch, so we want to avoid clashes. + group: docs-pushing + +jobs: + doc-preview-cleanup: + runs-on: ubuntu-latest + steps: + - name: Checkout gh-pages branch + uses: actions/checkout@v4 + with: + ref: gh-pages + - name: Delete preview and history + push changes + run: | + preview_directory=previews/PR${{ github.event.number }} + if [[ -d "${preview_directory}" ]]; then + git config user.name "${{github.actor}}" + git config user.email "${{github.actor_id}}+${{github.actor}}@users.noreply.github.com" + git rm -rf "${preview_directory}" + git commit -m 'Cleanup docs for PR #${{ github.event.number }}' + git branch gh-pages-new $(echo "Delete history" | git commit-tree HEAD^{tree}) + git push --force origin gh-pages-new:gh-pages + fi diff --git a/.github/workflows/benchmark_aggregate.yml b/.github/workflows/benchmark_aggregate.yml index 6f78ae3ae4..23e53efad7 100644 --- a/.github/workflows/benchmark_aggregate.yml +++ b/.github/workflows/benchmark_aggregate.yml @@ -1,4 +1,5 @@ name: Benchmarks + permissions: contents: write # contents permission to update benchmark contents in gh-pages branch statuses: read @@ -7,14 +8,25 @@ permissions: on: pull_request: - + paths: + - '.github/workflows/benchmark_aggregate.yml' + - 'ext/**' + - 'lib/**' + - 'src/**' + - 'Project.toml' push: branches: - main + paths: + - '.github/workflows/benchmark_aggregate.yml' + - 'ext/**' + - 'lib/**' + - 'src/**' + - 'Project.toml' jobs: benchmark: - if: ${{ !contains(github.event.head_commit.message, '[skip benchmarks]') }} + if: ${{ !contains(github.event.head_commit.message, '[skip benchmarks]') && ! github.event.pull_request.head.repo.fork }} runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/downgrade.yml b/.github/workflows/downgrade.yml index 3f03f8d0d0..48bb36db2d 100644 --- a/.github/workflows/downgrade.yml +++ b/.github/workflows/downgrade.yml @@ -1,11 +1,24 @@ name: Downgrade + on: pull_request: branches: - main + paths: + - '.github/workflows/downgrade.yml' + - 'ext/**' + - 'lib/**' + - 'src/**' + - 'Project.toml' push: branches: - main + paths: + - '.github/workflows/downgrade.yml' + - 'ext/**' + - 'lib/**' + - 'src/**' + - 'Project.toml' concurrency: # Skip intermediate builds: always. @@ -16,6 +29,7 @@ concurrency: jobs: downgrade: # if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + timeout-minutes: 90 runs-on: ubuntu-latest strategy: fail-fast: false @@ -29,6 +43,7 @@ jobs: - uses: julia-actions/setup-julia@v2 with: version: "1.10" + - uses: julia-actions/cache@v2 - uses: julia-actions/julia-downgrade-compat@v1 with: skip: "ReactantCore" @@ -42,16 +57,14 @@ jobs: push!(dev_pks, Pkg.PackageSpec(; path)) end Pkg.develop(dev_pks) - Pkg.instantiate() Pkg.test(; coverage="user") shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} id: run_tests env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager REACTANT_TEST_GROUP: ${{ matrix.test_group }} + XLA_FLAGS: "--xla_force_host_platform_device_count=8" - uses: julia-actions/julia-processcoverage@v1 - if: steps.run_tests.outcome == 'success' - uses: codecov/codecov-action@v5 - if: steps.run_tests.outcome == 'success' with: files: lcov.info diff --git a/.github/workflows/format-check-cpp.yml b/.github/workflows/format-check-cpp.yml new file mode 100644 index 0000000000..ab39b03d3d --- /dev/null +++ b/.github/workflows/format-check-cpp.yml @@ -0,0 +1,31 @@ +name: Format Suggestions + +on: + push: + branches: + - main + tags: '*' + paths: + - '.github/workflows/format-check-cpp.yml' + - '**/*.cpp' + - '**/*.h' + pull_request: + paths: + - '.github/workflows/format-check-cpp.yml' + - '**/*.cpp' + - '**/*.h' + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: always. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + code-style-cpp: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: DoozyX/clang-format-lint-action@v0.18.2 + with: + source: 'deps' diff --git a/.github/workflows/format-check.yml b/.github/workflows/format-check.yml index 844401ccd5..c3e642e00e 100644 --- a/.github/workflows/format-check.yml +++ b/.github/workflows/format-check.yml @@ -1,10 +1,17 @@ name: Format Suggestions + on: push: branches: - main tags: '*' + paths: + - '.github/workflows/format-check.yml' + - '**/*.jl' pull_request: + paths: + - '.github/workflows/format-check.yml' + - '**/*.jl' concurrency: # Skip intermediate builds: always. @@ -13,7 +20,7 @@ concurrency: cancel-in-progress: true jobs: - code-style: + code-style-julia: runs-on: ubuntu-latest steps: - uses: julia-actions/julia-format@v3 diff --git a/.github/workflows/format-pr.yml b/.github/workflows/format-pr.yml index baeb3caa0f..9c17a24d30 100644 --- a/.github/workflows/format-pr.yml +++ b/.github/workflows/format-pr.yml @@ -1,4 +1,5 @@ name: Format 'main' + on: schedule: - cron: '0 0 * * *' @@ -38,6 +39,7 @@ jobs: branch: format-main delete-branch: true labels: format + author: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com> - name: Check outputs run: | echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" diff --git a/.github/workflows/regenerate-mlir-bindings.yml b/.github/workflows/regenerate-mlir-bindings.yml index 2bbb19c120..602a4f0081 100644 --- a/.github/workflows/regenerate-mlir-bindings.yml +++ b/.github/workflows/regenerate-mlir-bindings.yml @@ -1,10 +1,12 @@ name: Regenerate MLIR Bindings + on: schedule: - cron: '0 0 * * *' workflow_dispatch: + jobs: - make: + mlir-bindings: runs-on: ubuntu-latest permissions: contents: write @@ -38,7 +40,8 @@ jobs: working-directory: ./deps/ReactantExtra env: JULIA_DEPOT_PATH: ${{ runner.temp }}/julia_depot - - run: | + - name: Make generated files writable + run: | chmod -R u+rw ./src/mlir/Dialects/ chmod u+rw ./src/mlir/libMLIR_h.jl git config core.fileMode false @@ -48,6 +51,9 @@ jobs: using JuliaFormatter format("./src/mlir/Dialects/") format("./src/mlir/libMLIR_h.jl") + # Format twice to work around . + format("./src/mlir/Dialects/") + format("./src/mlir/libMLIR_h.jl") - name: Create Pull Request id: cpr uses: peter-evans/create-pull-request@v7 @@ -57,6 +63,7 @@ jobs: title: 'Regenerate MLIR Bindings' branch: regenerate-mlir-bindings delete-branch: true + author: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com> - name: Check outputs run: | echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" diff --git a/.gitignore b/.gitignore index f8e0c6a23e..1845a6d848 100644 --- a/.gitignore +++ b/.gitignore @@ -252,6 +252,7 @@ docs/site/ # environment. Manifest.toml Manifest-v*.toml +.CondaPkg .vscode/* .vscode/settings.json diff --git a/CondaPkg.toml b/CondaPkg.toml index 39a82d7737..93e9945128 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -1,2 +1,2 @@ -[deps] -jax = "" +[pip.deps] +jax = ">=0.4" diff --git a/Project.toml b/Project.toml index c99a7a2e44..37ae16a5f9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,44 +1,59 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal "] -version = "0.2.11" +version = "0.2.36" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" Scratch = "6c6a2e73-6563-6170-7368-637461726353" +Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" [weakdeps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random123 = "74087812-796a-5b5d-8853-05524746bad3" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" -[sources.ReactantCore] -path = "lib/ReactantCore" +[sources] +ReactantCore = {path = "lib/ReactantCore"} [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" -ReactantCUDAExt = "CUDA" +ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"] +ReactantKernelAbstractionsExt = "KernelAbstractions" +ReactantMPIExt = "MPI" ReactantNNlibExt = "NNlib" +ReactantOffsetArraysExt = "OffsetArrays" ReactantPythonCallExt = "PythonCall" ReactantRandom123Ext = "Random123" +ReactantSpecialFunctionsExt = "SpecialFunctions" ReactantStatisticsExt = "Statistics" ReactantYaoBlocksExt = "YaoBlocks" @@ -47,21 +62,32 @@ AbstractFFTs = "1.5" Adapt = "4.1" ArrayInterface = "7.17.1" CEnum = "0.5" -CUDA = "5.5" +CUDA = "5.6" Downloads = "1.6" -Enzyme = "0.13.22" +EnumX = "1" +Enzyme = "0.13.28" EnzymeCore = "0.8.8" -GPUArraysCore = "0.1.6, 0.2" +Functors = "0.5" +GPUArraysCore = "0.2" +GPUCompiler = "1.1.1" +KernelAbstractions = "0.9.30" +LLVM = "9.1" +LLVMOpenMP_jll = "18.1.7" LinearAlgebra = "1.10" +MPI = "0.20" NNlib = "0.9.26" +OffsetArrays = "1" OrderedCollections = "1" +PrecompileTools = "1.2" Preferences = "1.4" PythonCall = "0.9" Random = "1.10" Random123 = "1.7" -ReactantCore = "0.1.3" -Reactant_jll = "0.0.32" +ReactantCore = "0.1.5" +Reactant_jll = "0.0.80" Scratch = "1.2" +Sockets = "1.10" +SpecialFunctions = "2.4" Statistics = "1.10" YaoBlocks = "0.13" julia = "1.10" diff --git a/README.md b/README.md index afa0158668..487185d621 100644 --- a/README.md +++ b/README.md @@ -60,16 +60,3 @@ Reactant.set_default_backend("gpu") # ones favorite code will now all be executed on GPU, no CUDA.jl dependency even required! ``` - -## Installing Reactant on GPU Servers without Internet - -If you want to use Reactant on GPU Servers where all packages must be installed on the login nodes and the compute nodes don't have access to internet, -add the following to the Project.toml and precompile the package: - -```toml -[extras] -Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" - -[preferences.Reactant_jll] -gpu = "cuda" -``` diff --git a/benchmark/setup.jl b/benchmark/setup.jl index 1d03f55c1d..f56e710f31 100644 --- a/benchmark/setup.jl +++ b/benchmark/setup.jl @@ -55,20 +55,20 @@ function setup_simple_benchmark!(suite::BenchmarkGroup, backend) suite["(Basics) 2D sum (2 x 10)"]["forward (compilation)"][backend][tag] = @benchmarkable begin @compile optimize = $(opt_pass) sum(x) end setup = begin - x = Reactant.ConcreteRArray(ones(2, 10)) + x = Reactant.to_rarray(ones(2, 10)) end suite["(Basics) sum(cos, x) (2 x 10)"]["forward (compilation)"][backend][tag] = @benchmarkable begin @compile optimize = $(opt_pass) sumcos(x) end setup = begin - x = Reactant.ConcreteRArray(ones(2, 10)) + x = Reactant.to_rarray(ones(2, 10)) end end suite["Basics ∇sumcos (2 x 10)"]["forward (compilation)"][backend]["Reactant"] = @benchmarkable begin @compile optimize = :all ∇sumcos(x) end setup = begin - x = Reactant.ConcreteRArray(ones(2, 10)) + x = Reactant.to_rarray(ones(2, 10)) end return nothing diff --git a/deps/Project.toml b/deps/Project.toml index 1eee64aec5..2b8460e285 100644 --- a/deps/Project.toml +++ b/deps/Project.toml @@ -1,10 +1,3 @@ [deps] -Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" Preferences = "21216c6a-2e73-6563-6e65-726566657250" -Scratch = "6c6a2e73-6563-6170-7368-637461726353" -Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31" -BinaryBuilderBase = "7f725544-6523-48cd-82d1-3fa08ff4056e" - -[compat] -Clang = "0.18" diff --git a/deps/ReactantExtra/.bazelrc b/deps/ReactantExtra/.bazelrc index a8b84e0f0d..4f13bbd7f1 100644 --- a/deps/ReactantExtra/.bazelrc +++ b/deps/ReactantExtra/.bazelrc @@ -18,14 +18,13 @@ build -c opt build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NVCC_CLANG=1 build:cuda --repo_env TF_NCCL_USE_STUB=1 -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.6.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.4.0" # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" build:cuda --crosstool_top="@local_config_cuda//crosstool:toolchain" build:cuda --@local_config_cuda//:enable_cuda -build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true # Default hermetic CUDA and CUDNN versions. build:cuda --@local_config_cuda//cuda:include_cuda_libs=true build:cuda --@local_config_cuda//:cuda_compiler=nvcc diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 3292f38800..a27e9c16e8 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -9,6 +9,8 @@ #include "Enzyme/MLIR/Dialect/Ops.h" #include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" #include "Enzyme/MLIR/Passes/Passes.h" + +#include "mlir/CAPI/Support.h" #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -31,6 +33,7 @@ #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/Passes.h" +#include "mlir/Parser/Parser.h" #include "src/enzyme_ad/jax/Dialect/Dialect.h" #include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" #include "src/enzyme_ad/jax/Passes/Passes.h" @@ -45,12 +48,13 @@ #include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/pjrt/cpu/cpu_client.h" -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" -#include "xla/pjrt/pjrt_api.h" -#include "xla/pjrt/pjrt_c_api_client.h" -#include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/status_casters.h" + +#include "tsl/platform/init_main.h" +#include "tsl/profiler/lib/profiler_session.h" +#include "tsl/profiler/lib/traceme.h" +#include "xla/python/profiler_utils.h" +#include "xla/tsl/profiler/rpc/client/capture_profile.h" +#include "xla/tsl/profiler/rpc/profiler_server.h" #include "xla/python/ifrt/hlo/hlo_program.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" @@ -60,17 +64,50 @@ #include "llvm-c/TargetMachine.h" +// PJRT +#include "xla/pjrt/cpu/cpu_client.h" +#include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/service.h" +#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/pjrt_executable.h" + +// CPU collectives +#include "xla/backends/cpu/collectives/mpi_collectives.h" +#if defined(__linux__) +#include "gloo/transport/tcp/attr.h" +#include "gloo/transport/tcp/device.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" +#endif // defined(__linux__) + +// shardy +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/integrations/c/attributes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.h" +#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h" + // IFRT #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/ifrt/host_callback.h" #include "xla/python/ifrt/index.h" #include "xla/python/ifrt/index_domain.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" @@ -89,6 +126,17 @@ #include "xla/python/pjrt_ifrt/pjrt_memory.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/pjrt_tuple.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" + +// IFRT - Proxy (RPC) +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/python/ifrt_proxy/server/grpc_server.h" + +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "llvm/Support/ExtensibleRTTI.h" using namespace mlir; using namespace llvm; @@ -101,18 +149,64 @@ void registerGenerateApplyPatternsPass(); } // namespace enzyme } // namespace mlir +namespace reactant { + +template struct unwrap_type { + typedef T type; +}; +template struct unwrap_type> { + typedef T type; +}; +template struct unwrap_type> { + typedef T type; +}; + +template using unwrap_type_t = typename unwrap_type::type; + +template struct HeldValue { +public: + HeldValue(T &obj) : holded(obj) {} + ~HeldValue() = default; + + unwrap_type_t *ptr() const { return holded.get(); } + + T obj() const { return holded; } + + T value() const { return holded; } + + unwrap_type_t *operator->() const { return ptr(); } + +private: + T holded; +}; + +template HeldValue *capture(T obj) { + return new HeldValue(obj); +} + +} // namespace reactant + +using reactant::HeldValue; +using HeldPjRtClient = HeldValue>; +using HeldPjRtBuffer = HeldValue>; +using HeldIfrtArray = HeldValue>; + extern "C" void (*ReactantThrowError)(const char *) = nullptr; // Utilities for `StatusOr`. template T MyValueOrThrow(absl::StatusOr v) { - if (ReactantThrowError) { - if (!v.ok()) { - ReactantThrowError(v.status().ToString().c_str()); - throw xla::XlaRuntimeError(v.status().ToString().c_str()); + if (!v.ok()) { + ReactantThrowError(v.status().ToString().c_str()); + } + return std::move(v).value(); +} + +extern "C" void ReactantHandleCuResult(uint32_t curesult) { + if (curesult != 0) { + std::string err = "Bad Cuda Result = " + std::to_string(curesult); + if (ReactantThrowError) { + ReactantThrowError(err.c_str()); } - return std::move(v).value(); - } else { - return xla::ValueOrThrow(std::move(v)); } } @@ -133,9 +227,32 @@ extern "C" MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc, unwrap(loc), cast(unwrap(type)), real, imag)); } +extern "C" MlirOperation mlirOperationParse(MlirContext ctx, + MlirStringRef code) { + ParserConfig config(unwrap(ctx)); + OwningOpRef owning_op = parseSourceString(unwrap(code), config); + if (!owning_op) + return MlirOperation{nullptr}; + return MlirOperation{owning_op.release()}; +} + // TODO mlirComplexAttrGetnValue // TODO extern "C" MlirTypeID mlirComplexAttrGetTypeID(void) { return // wrap(complex::NumberAttr::getTypeID()); } + +extern "C" void ReactantFuncSetResultAttr(MlirOperation op, intptr_t pos, + MlirStringRef name, + MlirAttribute attr) { + llvm::cast(unwrap(op)) + .setResultAttr(pos, unwrap(name), unwrap(attr)); +} + +extern "C" void ReactantFuncSetArgAttr(MlirOperation op, intptr_t pos, + MlirStringRef name, MlirAttribute attr) { + llvm::cast(unwrap(op)) + .setArgAttr(pos, unwrap(name), unwrap(attr)); +} + #pragma endregion // auxiliar functions @@ -165,7 +282,11 @@ T *unwrap_absl_statusor(absl::StatusOr status, char **error_msg) { // int xla::_LayoutProto_default_instance_; extern "C" void InitializeLogs() { - absl::InitializeLog(); + const char *binary = "julia"; + int argc = 1; + char *argv[] = {(char *)binary}; + char **argv2 = &argv[0]; + tsl::port::InitMain(binary, &argc, &argv2); LLVMInitializeX86Target(); LLVMInitializeX86TargetInfo(); LLVMInitializeX86TargetMC(); @@ -198,27 +319,93 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) { (mlir::enzyme::Activity)val)); } -extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id, - int num_nodes) { +// Create profiler session and start profiling +extern "C" tsl::ProfilerSession * +CreateProfilerSession(uint32_t device_tracer_level, + uint32_t host_tracer_level) { + tensorflow::ProfileOptions options = tsl::ProfilerSession::DefaultOptions(); + options.set_device_tracer_level(device_tracer_level); + options.set_host_tracer_level(host_tracer_level); + auto sess = tsl::ProfilerSession::Create(options); + return sess.release(); +} + +extern "C" void ProfilerSessionCollectData(tsl::ProfilerSession *session, + const char *path) { + tensorflow::profiler::XSpace xspace; + auto status = session->CollectData(&xspace); + if (!status.ok()) + ReactantThrowError("cannot collect data for profiler"); + tsl::profiler::ExportToTensorBoard(xspace, path, + /*also_export_trace_json*/ true); +} + +extern "C" void ProfilerSessionDelete(tsl::ProfilerSession *session) { + delete session; +} + +extern "C" int64_t ProfilerActivityStart(const char *name, int level) { + return tsl::profiler::TraceMe::ActivityStart(name, level); +} + +extern "C" void ProfilerActivityEnd(int64_t id) { + tsl::profiler::TraceMe::ActivityEnd(id); +} + +extern "C" tsl::profiler::ProfilerServer *ProfilerServerStart(int32_t port) { + auto server = new tsl::profiler::ProfilerServer(); + server->StartProfilerServer(port); + return server; +} + +extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) { + delete server; +} + +PjRtClient *MakeCPUClientInternal( + uint8_t asynchronous, int node_id, + std::optional> collectives) { CpuClientOptions options; - // options.kv_store = "etcd"; + options.process_id = node_id; - // options.num_nodes = num_nodes; - // options.collectives = num_nodes; options.asynchronous = asynchronous != 0; + + if (collectives.has_value()) + options.collectives = collectives.value(); + auto client = MyValueOrThrow(GetTfrtCpuClient(options)); return client.release(); } +extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) { + std::optional> collectives; + return MakeCPUClientInternal(asynchronous, node_id, collectives); +} + // xla/python/xla.cc 390 -extern "C" PjRtClient *MakeGPUClient(int node_id, int num_nodes, - int *allowed_devices, - int num_allowed_devices, - const char *platform_name, - const char **error) { +extern "C" PjRtClient * +MakeGPUClient(int node_id, int num_nodes, int *allowed_devices, + int num_allowed_devices, double memory_fraction, bool preallocate, + const char *platform_name, const char **error, + void *distributed_runtime_client) { GpuClientOptions options; - // options.kv_store = "etcd"; + + if (num_nodes > 1) { + if (distributed_runtime_client == nullptr) { + *error = + "`distributed_runtime_client` must be non-null if `num_nodes` > 1"; + return nullptr; + } + auto typed_distributed_runtime_client = static_cast< + HeldValue> *>( + distributed_runtime_client); + options.kv_store = GetDistributedKeyValueStore( + typed_distributed_runtime_client->obj(), /*key_prefix=*/""); + } + // options.allocator_config = + options.allocator_config.preallocate = preallocate; + options.allocator_config.memory_fraction = memory_fraction; options.node_id = node_id; options.num_nodes = num_nodes; options.allowed_devices = @@ -292,11 +479,11 @@ extern "C" PjRtClient *MakeTPUClient(const char *tpu_path, const char **error) { LoadPjrtPlugin("tpu", tpu_library_path.c_str(), error); if (pluginLoad == nullptr) return nullptr; - auto tpu_status = InitializePjrtPlugin("tpu", error); if (tpu_status) return nullptr; + RegisterProfiler(pluginLoad); return GetCApiClient("TPU"); } @@ -322,6 +509,63 @@ extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client, client->LookupAddressableDevice(PjRtLocalDeviceId(device_id))); } +extern "C" const char *ClientGetPlatformName(PjRtClient *client) { + return cstr_from_string(client->platform_name()); +} + +extern "C" const char *DeviceGetKind(PjRtDevice *device) { + return cstr_from_string(device->device_kind()); +} + +extern "C" void ClientGetDevices(PjRtClient *client, PjRtDevice **out_devices) { + auto span = client->devices(); + for (int i = 0; i < span.size(); i++) { + out_devices[i] = span[i]; + } +} + +extern "C" void ClientGetAddressableDevices(PjRtClient *client, + PjRtDevice **out_devices) { + auto span = client->addressable_devices(); + for (int i = 0; i < span.size(); i++) { + out_devices[i] = span[i]; + } +} + +// To keep in sync with JLAllocatorStats in src/XLA.jl +struct JLAllocatorStats { + int64_t num_allocs; + int64_t bytes_in_use; + int64_t peak_bytes_in_use; + int64_t largest_alloc_size; + int64_t bytes_limit; + int64_t bytes_reserved; + int64_t peak_bytes_reserved; + int64_t bytes_reservable_limit; + int64_t largest_free_block_bytes; + int64_t pool_bytes; + int64_t peak_pool_bytes; +}; + +extern "C" void PjRtDeviceGetAllocatorStats(PjRtDevice *device, + JLAllocatorStats *jlstats) { + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + int64_t optnull = std::numeric_limits::min(); + + jlstats->num_allocs = stats.num_allocs; + jlstats->bytes_in_use = stats.bytes_in_use; + jlstats->peak_bytes_in_use = stats.peak_bytes_in_use; + jlstats->largest_alloc_size = stats.largest_alloc_size; + jlstats->bytes_limit = stats.bytes_limit.value_or(optnull); + jlstats->bytes_reserved = stats.bytes_reserved; + jlstats->peak_bytes_reserved = stats.peak_bytes_reserved; + jlstats->bytes_reservable_limit = + stats.bytes_reservable_limit.value_or(optnull); + jlstats->largest_free_block_bytes = stats.largest_free_block_bytes; + jlstats->pool_bytes = stats.pool_bytes.value_or(optnull); + jlstats->peak_pool_bytes = stats.peak_pool_bytes.value_or(optnull); +} + extern "C" void ExecutableFree(xla::PjRtLoadedExecutable *exec) { delete exec; } extern "C" PjRtDevice *BufferToDevice(PjRtBuffer *Buffer) { @@ -332,11 +576,28 @@ extern "C" PjRtClient *BufferToClient(PjRtBuffer *Buffer) { return Buffer->client(); } +extern "C" const int64_t *BufferShape(PjRtBuffer *Buffer) { + return Buffer->dimensions().data(); +} + +extern "C" int64_t BufferNDimensions(PjRtBuffer *Buffer) { + return Buffer->dimensions().length(); +} + +extern "C" xla::PrimitiveType BufferPrimitiveType(PjRtBuffer *Buffer) { + return Buffer->element_type(); +} + +extern "C" void PjRtBufferFree(PjRtBuffer *Buffer) { delete Buffer; } + extern "C" PjRtClient *DeviceToClient(PjRtDevice *Device) { return Device->client(); } -extern "C" void PjRtBufferFree(PjRtBuffer *Buffer) { delete Buffer; } +extern "C" PjRtClient * +PjRtLoadedExecutableGetClient(PjRtLoadedExecutable *exec) { + return exec->client(); +} // https://openxla.org/xla/shapes // This minor-to-major dimension order of 0 up to N-1 is akin to column-major @@ -351,6 +612,13 @@ std::vector col_major(int64_t dim) { return minor_to_major; } +extern "C" void ReactantLLVMParseCommandLineOptions(int argc, + const char *const *argv, + const char *Overview) { + llvm::cl::ParseCommandLineOptions(argc, argv, StringRef(Overview), + &llvm::nulls()); +} + std::vector row_major(int64_t dim) { std::vector minor_to_major; for (int i = 0; i < dim; i++) { @@ -360,6 +628,19 @@ std::vector row_major(int64_t dim) { } static void noop() {} +#ifdef REACTANT_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +extern "C" int32_t ReactantCudaDriverGetVersion() { + int32_t data; + ReactantHandleCuResult(cuDriverGetVersion(&data)); + return data; +} +extern "C" int32_t ReactantHermeticCudaGetVersion() { return CUDA_VERSION; } +#else +extern "C" int32_t ReactantCudaDriverGetVersion() { return 0; } +extern "C" int32_t ReactantHermeticCudaGetVersion() { return 0; } +#endif + extern "C" void *UnsafeBufferPointer(PjRtBuffer *buffer) { auto unsafe = MyValueOrThrow(buffer->client()->UnsafeBufferPointer(buffer)); return (void *)unsafe; @@ -377,9 +658,10 @@ extern "C" PjRtBuffer *ArrayFromHostBuffer(PjRtClient *client, void *data, // auto buffer = xla::MyValueOrThrow(client->BufferFromHostBuffer(data, // primtype, shape, /*byte_strides*/{}, semantics, /*ondone*/{}, device, // &layout)); - auto buffer = MyValueOrThrow( - client->BufferFromHostBuffer(data, primtype, shape, /*byte_strides*/ {}, - semantics, /*ondone*/ {}, device)); + const xla::Layout *layout = nullptr; + auto buffer = MyValueOrThrow(client->BufferFromHostBuffer( + data, primtype, shape, /*byte_strides*/ {}, semantics, /*ondone*/ {}, + *device->default_memory_space(), layout)); auto bres = buffer.release(); return bres; } @@ -388,7 +670,8 @@ extern "C" uint8_t BufferOnCPU(PjRtBuffer *buffer) { return buffer->IsOnCpu(); } extern "C" PjRtBuffer *CopyBufferToDevice(PjRtBuffer *buffer, PjRtDevice *dst_device) { - auto res = MyValueOrThrow(buffer->CopyToDevice(dst_device)); + auto res = MyValueOrThrow( + buffer->CopyToMemorySpace(*dst_device->default_memory_space())); return res.release(); } @@ -407,6 +690,18 @@ extern "C" void BufferToHost(PjRtBuffer *buffer, void *data) { extern "C" void FreeClient(PjRtClient *client) { delete client; } +extern "C" int64_t PjRtDeviceGetLocalDeviceId(PjRtDevice *device) { + return device->local_device_id().value(); +} + +extern "C" int64_t PjRtDeviceGetGlobalDeviceId(PjRtDevice *device) { + return device->global_device_id().value(); +} + +extern "C" int64_t PjRtDeviceGetLocalHardwareId(PjRtDevice *device) { + return device->local_hardware_id().value(); +} + #include "xla/service/custom_call_target_registry.h" extern "C" void RegisterCustomCallTarget(const char *name, void *address, const char *platform) { @@ -432,58 +727,244 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) { SMDiagnostic Err; auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context); + if (!llvmModule) { + std::string err_str; + llvm::raw_string_ostream err_stream(err_str); + Err.print(/*ProgName=*/"LLVMToMLIR", err_stream); + err_stream.flush(); + if (ReactantThrowError) { + llvm::errs() << lmod << "\n"; + ReactantThrowError(err_str.c_str()); + return wrap((mlir::ModuleOp) nullptr); + } + } mlir::MLIRContext &context = *unwrap(cctx); auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context, /*emitExpensiveWarnings*/ false, /*dropDICompositeElements*/ false) .release(); + if (!res) { + llvm::errs() << lmod << "\n"; + ReactantThrowError("Could not translate LLVM IR to MLIR Module"); + } return wrap(res); } -/* Note that this */ -extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client, - MlirModule cmod) { - auto program = - std::make_unique(cast(*unwrap(cmod))); - - CompileOptions options; - // options.argument_layouts; - // options.executable_build_options.set_device_ordinal(); - // options.executable_build_options.set_result_layout(); - - auto addressable_devices = client->addressable_devices(); - if (!addressable_devices.empty()) { - int device_ordinal = options.executable_build_options.device_ordinal(); - if (device_ordinal < 0) { - device_ordinal = 0; +typedef PjRtFuture<> FutureType; +extern "C" void FreeFuture(FutureType *Future) { delete Future; } + +extern "C" uint8_t FutureIsReady(FutureType *Future) { + return Future->IsReady(); +} + +extern "C" void FutureAwait(FutureType *Future) { Future->Await(); } + +xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded, + const int64_t *mesh_ids, + int64_t num_mesh_ids, + const char *xla_gpu_cuda_data_dir) { + xla::CompileOptions options; + options.executable_build_options.mutable_debug_options() + ->set_xla_gpu_cuda_data_dir(xla_gpu_cuda_data_dir); + + if (is_sharded) { + assert(device_id < 0); + + options.executable_build_options.set_num_replicas(1); + options.executable_build_options.set_num_partitions(num_mesh_ids); + + options.executable_build_options.set_use_spmd_partitioning(true); + options.executable_build_options.set_use_shardy_partitioner(true); + + // auto partitioning for GPUs is not available in open source version of XLA + // options.executable_build_options.set_use_auto_spmd_partitioning(true); + // std::vector mesh_shape_vec(mesh_shape, mesh_shape + + // num_mesh_shape); + // options.executable_build_options.set_auto_spmd_partitioning_mesh_shape(mesh_shape_vec); + // std::vector mesh_ids_vec(mesh_ids, mesh_ids + num_mesh_ids); + // options.executable_build_options.set_auto_spmd_partitioning_mesh_ids(mesh_ids_vec); + + xla::DeviceAssignment device_assignment(1, num_mesh_ids); + for (int64_t i = 0; i < num_mesh_ids; ++i) { + int64_t mesh_id = mesh_ids[i]; + assert(mesh_id >= 0); + device_assignment(0, i) = mesh_id; } - assert(device_ordinal < addressable_devices.size()); - auto stats = addressable_devices[device_ordinal]->GetAllocatorStats(); - if (stats.ok() && stats->bytes_limit) { - options.executable_build_options.set_device_memory_size( - *stats->bytes_limit); + options.executable_build_options.set_device_assignment(device_assignment); + + options.executable_build_options + .set_allow_spmd_sharding_propagation_to_parameters({false}); + options.executable_build_options + .set_allow_spmd_sharding_propagation_to_output({false}); + } else { + assert(device_id >= 0); + + options.executable_build_options.set_num_replicas(1); + options.executable_build_options.set_num_partitions(1); + options.executable_build_options.set_device_ordinal(device_id); + + xla::DeviceAssignment device_assignment(1, 1); + device_assignment(0, 0) = device_id; + options.executable_build_options.set_device_assignment(device_assignment); + } + + return options; +} + +extern "C" xla::PjRtLoadedExecutable * +ClientCompile(PjRtClient *client, MlirModule cmod, int64_t device_id, + bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids, + const char *xla_gpu_cuda_data_dir) { + CompileOptions options = GenerateCompileOptions( + device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir); + + mlir::ModuleOp cmod_op = cast(*unwrap(cmod)); + if (is_sharded) { + // https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460 + auto status = xla::ExportShardyForHloRoundTrip(cmod_op); + if (!status.ok()) { + ReactantThrowError(status.ToString().c_str()); } } - auto exec = - MyValueOrThrow(client->Compile(cast(*unwrap(cmod)), options)); + + auto exec = MyValueOrThrow(client->Compile(cmod_op, options)); return exec.release(); } -typedef PjRtFuture<> FutureType; -extern "C" void FreeFuture(FutureType *Future) { delete Future; } +extern "C" void +PjRtLoadedExecutableGetOuputShardings(xla::PjRtLoadedExecutable *exec, + xla::OpSharding **op_shardings, + int32_t num_op_shardings) { + std::optional> shardings = exec->GetOutputShardings(); + if (!shardings.has_value()) { + ReactantThrowError( + "No sharding found for the output of the loaded executable"); + } -extern "C" uint8_t FutureIsReady(FutureType *Future) { - return Future->IsReady(); + std::vector hlo_op_shardings = shardings.value(); + if (num_op_shardings != hlo_op_shardings.size()) { + ReactantThrowError(("Expected " + std::to_string(num_op_shardings) + + " shardings, got " + + std::to_string(hlo_op_shardings.size())) + .c_str()); + } + + for (int32_t i = 0; i < num_op_shardings; i++) { + op_shardings[i] = new xla::OpSharding(hlo_op_shardings[i]); + } } -extern "C" void FutureAwait(FutureType *Future) { Future->Await(); } +extern "C" void +PjRtLoadedExecutableGetParameterShardings(xla::PjRtLoadedExecutable *exec, + xla::OpSharding **op_shardings, + int32_t num_op_shardings) { + std::optional> shardings = + exec->GetParameterShardings(); + if (!shardings.has_value()) { + ReactantThrowError( + "No sharding found for the output of the loaded executable"); + } + + std::vector hlo_op_shardings = shardings.value(); + if (num_op_shardings != hlo_op_shardings.size()) { + ReactantThrowError(("Expected " + std::to_string(num_op_shardings) + + " shardings, got " + + std::to_string(hlo_op_shardings.size())) + .c_str()); + } + + for (int32_t i = 0; i < num_op_shardings; i++) { + op_shardings[i] = new xla::OpSharding(hlo_op_shardings[i]); + } +} + +extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args, + PjRtBuffer **op_args, PjRtDevice *device, + uint8_t *is_arg_donatable, int num_results, + PjRtBuffer **op_results, uint8_t *futures, + FutureType **future_results) { + // Create a vector of PjRtBuffer* from the input array. + std::vector argument_handles(op_args, op_args + num_args); + + // Set up execution options. + ExecuteOptions options; + for (size_t i = 0; i < num_args; i++) { + if (!is_arg_donatable[i]) { + options.non_donatable_input_indices.insert(static_cast(i)); + } + } + options.untuple_result = true; + + // Optional future to hold asynchronous execution results. + std::optional> returned_future; + + auto results = MyValueOrThrow(exec->ExecuteSharded(argument_handles, device, + options, returned_future, + /*fill_future=*/true)); + + // Validate the number of results. + if (results.size() != num_results) { + ReactantThrowError( + ("Error: results.size()=" + std::to_string(results.size()) + + " does not match num_results=" + std::to_string(num_results) + "\n") + .c_str()); + } -extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args, + // Handle futures if they are returned. + auto future_val = returned_future.has_value(); + *futures = future_val; + if (future_val) { + for (size_t i = 0; i < num_results; i++) { + future_results[i] = new FutureType(*returned_future); + } + } + + // Release the results into the output array. + for (size_t i = 0; i < num_results; i++) { + op_results[i] = results[i].release(); + } +} + +// This isn't exposed to julia, but leaving it here since it is very useful for +// debugging sharding (and generally for the execute workflow) +void PrintPjRtBuffer(PjRtBuffer *buffer) { + if (buffer) { + xla::Shape shape = MyValueOrThrow(buffer->HostShape()); + auto dims = shape.dimensions(); + auto nelems = std::accumulate(dims.begin(), dims.end(), 1, + std::multiplies()); + std::vector host_data(nelems); + BufferToHost(buffer, host_data.data()); + + for (int i = 0; i < nelems; ++i) { + std::cout << host_data[i] << " "; + } + std::cout << std::endl; + } else { + std::cout << " Buffer is nullptr" << std::endl; + } + return; +} + +extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, PjRtBuffer **op_args, uint8_t *is_arg_donatable, int num_results, PjRtBuffer **op_results, uint8_t *futures, FutureType **future_results) { - std::vector> argument_handles; - argument_handles.emplace_back(op_args, op_args + num_args); + xla::DeviceAssignment device_assignment = exec->device_assignment(); + int num_devices = device_assignment.computation_count(); + + // Ensure argument_handles is structured as num_devices x num_args + std::vector> argument_handles(num_devices); + int num_args = op_args_len / num_devices; + + // Distribute arguments across devices + for (int device_idx = 0; device_idx < num_devices; ++device_idx) { + argument_handles[device_idx].reserve(num_args); + for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) { + argument_handles[device_idx].push_back( + op_args[device_idx * num_args + arg_idx]); + } + } ExecuteOptions options; @@ -492,34 +973,65 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args, options.non_donatable_input_indices.insert((int)i); } options.untuple_result = true; - std::optional> returned_futures; - auto results = MyValueOrThrow( - exec->Execute(static_cast>>( - argument_handles), - options, returned_futures)); - assert(results.size() == 1); + std::optional> returned_futures = + std::vector(); + std::vector>> results = + MyValueOrThrow(exec->Execute( + static_cast>>( + argument_handles), + options, returned_futures)); + + if (results.size() != num_devices) { + ReactantThrowError((" results.size()=" + std::to_string(results.size()) + + " num_devices=" + std::to_string(num_devices) + "\n") + .c_str()); + } - if (results[0].size() != num_results) { - llvm::errs() << " results.size()=" << results.size() - << " num_results=" << num_results << "\n"; + for (int device_idx = 0; device_idx < num_devices; ++device_idx) { + // Remove mesh_id lookup since we're using device_idx ordering + if (results[device_idx].size() != num_results) { + ReactantThrowError( + (" results[" + std::to_string(device_idx) + + "].size()=" + std::to_string(results[device_idx].size()) + + " num_results=" + std::to_string(num_results) + "\n") + .c_str()); + } } - assert(results[0].size() == num_results); - if (returned_futures) { - *futures = true; - assert(returned_futures->size() == num_results); - for (size_t i = 0; i < num_results; i++) { - future_results[i] = new FutureType((*returned_futures)[i]); + + // Handle returned futures + auto future_val = returned_futures.has_value(); + *futures = future_val; + if (future_val) { + if (returned_futures->size() != num_devices) { + ReactantThrowError((" returned_futures->size()=" + + std::to_string(returned_futures->size()) + + " num_devices=" + std::to_string(num_devices) + "\n") + .c_str()); } - } else { - *futures = false; } - for (size_t i = 0; i < num_results; i++) { - op_results[i] = results[0][i].release(); + // Copy results into the output buffers + for (int device_idx = 0; device_idx < num_devices; ++device_idx) { + for (int result_idx = 0; result_idx < num_results; ++result_idx) { + int flat_index = device_idx * num_results + result_idx; + op_results[flat_index] = results[device_idx][result_idx].release(); + if (future_val) { + future_results[flat_index] = + new FutureType((*returned_futures)[device_idx]); + } + } } } +extern "C" int PjRtLoadedExecutableNumReplicas(PjRtLoadedExecutable *exec) { + return exec->num_replicas(); +} + +extern "C" int PjRtLoadedExecutableNumPartitions(PjRtLoadedExecutable *exec) { + return exec->num_partitions(); +} + void prepareRegistry(mlir::DialectRegistry ®istry); extern "C" void RegisterDialects(MlirContext cctx) { @@ -530,26 +1042,29 @@ extern "C" void RegisterDialects(MlirContext cctx) { context.loadDialect(); context.loadDialect(); context.loadDialect(); + context.loadDialect(); + context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); + context.loadDialect(); + context.loadDialect(); } #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" -extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { - mlir::DialectRegistry ®istry = *unwrap(creg); - prepareRegistry(registry); +#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" +extern "C" void InitializePasses(MlirDialectRegistry creg) { mlir::registerenzymePasses(); - regsiterenzymeXLAPasses(); + enzyme::registerenzymexlaPasses(); // Register the standard passes we want. mlir::registerCSEPass(); - mlir::registerConvertAffineToStandardPass(); + mlir::registerLowerAffinePass(); mlir::registerSCCPPass(); mlir::registerInlinerPass(); mlir::registerCanonicalizerPass(); @@ -557,11 +1072,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::registerLoopInvariantCodeMotionPass(); mlir::registerConvertSCFToOpenMPPass(); mlir::affine::registerAffinePasses(); - mlir::registerReconcileUnrealizedCasts(); - - mlir::registerLLVMDialectImport(registry); - mlir::registerNVVMDialectImport(registry); - mlir::LLVM::registerInlinerInterface(registry); + mlir::registerReconcileUnrealizedCastsPass(); /* registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { @@ -583,6 +1094,19 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::transform::registerInterpreterPass(); mlir::enzyme::registerGenerateApplyPatternsPass(); mlir::enzyme::registerRemoveTransformPass(); + + // xla + shardy specific passes + xla::sdy::registerSdyRoundTripExportPipeline(); + xla::sdy::registerSdyRoundTripImportPipeline(); +} + +extern "C" void InitializeRegistry(MlirDialectRegistry creg) { + mlir::DialectRegistry ®istry = *unwrap(creg); + prepareRegistry(registry); + + mlir::registerLLVMDialectImport(registry); + mlir::registerNVVMDialectImport(registry); + mlir::LLVM::registerInlinerInterface(registry); } /// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric @@ -609,7 +1133,8 @@ static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName, static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp source, mlir::ModuleOp target, - unsigned &lastUsedID) { + unsigned &lastUsedID, + bool &shouldRemove) { using namespace llvm; using namespace mlir; @@ -619,6 +1144,13 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op, return success(); } + if (auto func = dyn_cast(op.getOperation())) { + if (func.isExternal()) { + shouldRemove = true; + return success(); + } + } + StringAttr newSymName = renameSymbol(opName, lastUsedID, source, target); if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source))) @@ -638,7 +1170,7 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, unsigned lastUsedID = 0; - for (auto &op : *newMod.getBody()) { + for (auto &op : make_early_inc_range(*newMod.getBody())) { auto symbolOp = dyn_cast(op); if (!symbolOp) continue; @@ -649,10 +1181,15 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, entryFn = &op; } - if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID))) { + bool shouldRemove = false; + if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, + shouldRemove))) { assert(0 && "failed to update all uses"); } - SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private); + if (shouldRemove) + op.erase(); + else + SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private); } prevMod.getBody()->getOperations().splice( prevMod.getBody()->getOperations().end(), @@ -660,1005 +1197,1074 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, return wrap(entryFn); } -#pragma region xla::ifrt +extern "C" void pjrt_client_dtor(HeldPjRtClient *client) { delete client; } -#pragma region xla::ifrt::Value -extern "C" ifrt::Client *ifrt_value_client(ifrt::Value *value) { - return value->client(); +extern "C" int pjrt_client_num_devices(HeldPjRtClient *client) { + return client->ptr()->device_count(); } -extern "C" ifrt::Future<> ifrt_value_get_ready_future(ifrt::Value *value) { - return value->GetReadyFuture(); +extern "C" int pjrt_client_num_addressable_devices(HeldPjRtClient *client) { + return client->ptr()->addressable_device_count(); } -extern "C" ifrt::Future<> ifrt_value_delete(ifrt::Value *value) { - return value->Delete(); +extern "C" int pjrt_client_pid(HeldPjRtClient *client) { + return client->ptr()->process_index(); } -extern "C" bool ifrt_value_is_deleted(ifrt::Value *value) { - return value->IsDeleted(); +extern "C" PjRtDevice *pjrt_client_get_device(HeldPjRtClient *client, + int device_id) { + return ClientGetDevice(client->ptr(), device_id); } -extern "C" const char *ifrt_value_debug_string(ifrt::Value *value) { - return cstr_from_string(value->DebugString()); +extern "C" PjRtDevice * +pjrt_client_get_addressable_device(HeldPjRtClient *client, int device_id) { + return ClientGetAddressableDevice(client->ptr(), device_id); } -#pragma endregion - -#pragma region xla::ifrt::Tuple -extern "C" int ifrt_tuple_arity(ifrt::Tuple *tuple) { return tuple->Arity(); } - -// TODO ifrt::Tuple::Unpack -#pragma endregion -#pragma region xla::ifrt::PjRtTuple -extern "C" ifrt::PjRtTuple * -ifrt_pjrt_tuple_ctor(ifrt::PjRtCompatibleClient *client, ifrt::Value *values, - int nvalues) { - auto values_ptr = new tsl::RCReference[nvalues]; - for (int i = 0; i < nvalues; i++) { - values_ptr[i] = tsl::RCReference(); - values_ptr[i].reset(&values[i]); - } - auto span = absl::Span>(values_ptr, nvalues); - return MyValueOrThrow(ifrt::PjRtTuple::Create(client, span)).release(); +extern "C" const char *pjrt_client_platform_name(HeldPjRtClient *client) { + return ClientGetPlatformName(client->ptr()); } -extern "C" void ifrt_pjrt_tuple_free(ifrt::PjRtTuple *tuple) { delete tuple; } -#pragma endregion +// deprecated +// extern "C" HeldValue> * +// reactant_hold_pjrtbuffer(xla::PjRtBuffer *buffer) { +// return reactant::capture(std::shared_ptr(buffer)); +// } -#pragma region xla::ifrt::DType -extern "C" ifrt::DType *ifrt_dtype_ctor(ifrt::DType::Kind kind) { - return new ifrt::DType(kind); +extern "C" HeldPjRtBuffer *pjrt_buffer_from_host(HeldPjRtClient *client, + void *data, uint64_t ptype, + size_t dim, int64_t *cshape, + PjRtDevice *device) { + PjRtBuffer *buffer = + ArrayFromHostBuffer(client->ptr(), data, ptype, dim, cshape, device); + return reactant::capture(std::shared_ptr(buffer)); } -extern "C" void ifrt_dtype_free(ifrt::DType *dtype) { delete dtype; } +extern "C" void pjrt_buffer_dtor(HeldPjRtBuffer *buffer) { delete buffer; } -extern "C" ifrt::DType::Kind ifrt_dtype_kind(ifrt::DType *dtype) { - return dtype->kind(); +extern "C" void *pjrt_buffer_unsafe_buffer_pointer(HeldPjRtBuffer *buffer) { + return UnsafeBufferPointer(buffer->ptr()); } -extern "C" bool ifrt_dtype_eq(ifrt::DType *dtype1, ifrt::DType *dtype2) { - return *dtype1 == *dtype2; +extern "C" bool pjrt_buffer_is_on_cpu(HeldPjRtBuffer *buffer) { + return buffer->ptr()->IsOnCpu(); } -extern "C" bool ifrt_dtype_ne(ifrt::DType *dtype1, ifrt::DType *dtype2) { - return *dtype1 != *dtype2; +extern "C" HeldPjRtBuffer *pjrt_buffer_copy_to_device(HeldPjRtBuffer *buffer, + PjRtDevice *dst_device) { + PjRtBuffer *ret = CopyBufferToDevice(buffer->ptr(), dst_device); + return reactant::capture(std::shared_ptr(ret)); } -// Returns -1 if not aligned to a byte boundary or there is no fixed size -extern "C" int ifrt_dtype_byte_size(ifrt::DType *dtype) { - auto byte_size = dtype->byte_size(); - if (byte_size.has_value()) { - return byte_size.value(); - } - return -1; +extern "C" void pjrt_buffer_to_host(HeldPjRtBuffer *buffer, void *data) { + BufferToHost(buffer->ptr(), data); } -// Returns -1 if there is no fixed size -extern "C" int ifrt_dtype_bit_size(ifrt::DType *dtype) { - auto bit_size = dtype->bit_size(); - if (bit_size.has_value()) { - return bit_size.value(); - } - return -1; +extern "C" void pjrt_buffer_print(HeldPjRtBuffer *buffer) { + PrintPjRtBuffer(buffer->ptr()); } -extern "C" const char *ifrt_dtype_debug_string(ifrt::DType *dtype) { - return cstr_from_string(dtype->DebugString()); +extern "C" PjRtDevice *pjrt_buffer_get_device(HeldPjRtBuffer *buffer) { + return buffer->ptr()->device(); } -// xla::PrimitiveType is a enum, so we use int to represent it on Julia side -extern "C" xla::PrimitiveType ifrt_to_primitive_type(ifrt::DType *dtype) { - return MyValueOrThrow(ifrt::ToPrimitiveType(*dtype)); +extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) { + return reactant::capture( + std::shared_ptr(buffer->ptr()->client())); } -// xla::PrimitiveType is a enum, so we use int to represent it on Julia side -extern "C" ifrt::DType *ifrt_to_dtype(xla::PrimitiveType primitive_type) { - auto dtype = MyValueOrThrow(ifrt::ToDType(primitive_type)); - return new ifrt::DType(dtype.kind()); -} -#pragma endregion +extern "C" void ifrt_client_dtor(ifrt::Client *client) { delete client; } -#pragma region xla::ifrt::Shape -extern "C" ifrt::Shape *ifrt_shape_ctor(const int64_t *dims, size_t dims_size) { - return new ifrt::Shape(absl::Span(dims, dims_size)); +// generic version, but IFRT-PjRt backend only supports SingleDeviceSharding +// and FullyReplicated. use `ifrt_pjrt_array_create` if using IFRT-PjRt. +extern "C" HeldIfrtArray *ifrt_client_make_array_from_host_buffer( + ifrt::Client *client, void *data, + int dtype_kind, // int + int ndims, const int64_t *c_shape, + HeldValue> *sharding, + int c_semantics) { + auto dtype = ifrt::DType(static_cast(dtype_kind)); + auto shape = ifrt::Shape(absl::Span(c_shape, ndims)); + return reactant::capture(MyValueOrThrow(client->MakeArrayFromHostBuffer( + data, dtype, shape, + std::nullopt, // byte_strides + sharding->obj(), + static_cast(c_semantics), + [] {} // on_done_with_host_buffer + ))); } -extern "C" void ifrt_shape_free(ifrt::Shape *shape) { delete shape; } - -extern "C" const int64_t *ifrt_shape_dims(ifrt::Shape *shape) { - return shape->dims().data(); +extern "C" HeldIfrtArray *ifrt_client_make_single_shard_array_from_host_buffer( + ifrt::Client *client, void *data, + int dtype_kind, // int + int ndims, const int64_t *c_shape, int c_semantics, ifrt::Device *device, + const char *mem_kind) { + auto memory_kind = ifrt::MemoryKind(std::string(mem_kind)); + auto sharding = reactant::capture(std::shared_ptr( + ifrt::SingleDeviceSharding::Create(device, memory_kind).release())); + return ifrt_client_make_array_from_host_buffer( + client, data, dtype_kind, ndims, c_shape, sharding, c_semantics); } -extern "C" int64_t ifrt_shape_dims_num_elements(ifrt::Shape *shape) { - return shape->num_elements(); -} - -extern "C" const char *ifrt_shape_debug_string(ifrt::Shape *shape) { - return cstr_from_string(shape->DebugString()); -} -#pragma endregion +// all arrays are assumed to have same DType +// each process only provides arrays for its own addressable devices +extern "C" HeldIfrtArray *ifrt_client_assemble_array_from_single_shards( + ifrt::Client *client, int32_t ndims, const int64_t *c_shape, + HeldValue> *sharding, int32_t narrays, + HeldIfrtArray **c_arrays, int32_t c_semantics) { + ifrt::Shape shape = ifrt::Shape(absl::Span(c_shape, ndims)); + std::vector> arrays; + for (int i = 0; i < narrays; i++) { + arrays.emplace_back(c_arrays[i]->obj()); + } + return reactant::capture( + MyValueOrThrow(client->AssembleArrayFromSingleDeviceArrays( + shape, sharding->obj(), absl::MakeSpan(arrays), + static_cast(c_semantics), + ifrt::SingleDeviceShardSemantics::kAddressableShards))); +} + +// we should deprecate this because is IFRT-PjRt specific +// try use `ifrt_client_make_single_shard_array_from_host_buffer` instead +extern "C" HeldIfrtArray * +ifrt_pjrt_array_create(ifrt::PjRtClient *client, + HeldValue> *buffer) { + return reactant::capture(tsl::RCReference( + MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj())))); +} + +// we might me interested in the `Compiler::Compile` method variant that accepts +// `Topology` +extern "C" xla::ifrt::LoadedExecutable * +ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id, + bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids, + const char *xla_gpu_cuda_data_dir) { + xla::CompileOptions compile_options = GenerateCompileOptions( + device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir); + auto options = std::make_unique( + xla::ifrt::XlaCompileOptions(compile_options)); + + mlir::ModuleOp cmod_op = cast(*unwrap(cmod)); + if (is_sharded) { + // https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460 + auto status = xla::ExportShardyForHloRoundTrip(cmod_op); + if (!status.ok()) { + ReactantThrowError(status.ToString().c_str()); + } + } -#pragma region xla::ifrt::DynamicShape -extern "C" ifrt::DynamicShape * -ifrt_dynamicshape_ctor(ifrt::Shape *shape, const bool *dynamic_dims_mask) { - auto tag = ifrt::BoundedDynamicShapeTag( - absl::Span(dynamic_dims_mask, shape->dims().size())); - auto dynshape = MyValueOrThrow(ifrt::DynamicShape::Create(*shape, tag)); - return new ifrt::DynamicShape(dynshape); -} + auto program = + std::make_unique(xla::ifrt::HloProgram(cmod_op)); + auto compiler = client->GetDefaultCompiler(); -extern "C" void ifrt_dynamicshape_free(ifrt::DynamicShape *shape) { - delete shape; + return MyValueOrThrow( + compiler->Compile(std::move(program), std::move(options))) + .release(); } -// TODO ifrt::DynamicShape::GetTag - -extern "C" bool ifrt_dynamicshape_eq(ifrt::DynamicShape *shape1, - ifrt::DynamicShape *shape2) { - return *shape1 == *shape2; +extern "C" void +ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) { + delete exec; } -extern "C" bool ifrt_dynamicshape_ne(ifrt::DynamicShape *shape1, - ifrt::DynamicShape *shape2) { - return *shape1 != *shape2; -} +extern "C" void ifrt_array_dtor(HeldIfrtArray *array) { delete array; } -extern "C" ifrt::Shape * -ifrt_dynamicshape_get_padded_shape(ifrt::DynamicShape *shape) { - auto padshape = MyValueOrThrow(shape->GetPaddedShape()); - return new ifrt::Shape(padshape); +// in principle, use ArrayCopySemantics::kAlwaysCopy (=0) +extern "C" FutureType * +ifrt_CopyArrayToHostBuffer(HeldIfrtArray *array, void *data, + ifrt::ArrayCopySemantics semantics) { + return new FutureType( + (*array)->CopyToHostBuffer(data, std::nullopt, semantics)); } -extern "C" bool ifrt_dynamicshape_is_dynamic_dim(ifrt::DynamicShape *shape, - int dimension) { - return shape->IsDynamicDim(dimension); +extern "C" void +PjRtLoadedExecutableGetHloModules(xla::PjRtLoadedExecutable *exec, + void **hlo_modules, int32_t *nmodules) { + auto hlo_modules_vec = MyValueOrThrow(exec->GetHloModules()); + *nmodules = hlo_modules_vec.size(); + for (int i = 0; i < *nmodules; i++) { + hlo_modules[i] = reactant::capture(hlo_modules_vec[i]); + } } extern "C" const char * -ifrt_dynamicshape_debug_string(ifrt::DynamicShape *shape) { - return cstr_from_string(shape->DebugString()); +HloModuleToString(HeldValue> *hlo_module) { + return cstr_from_string(hlo_module->obj()->ToString()); } -#pragma endregion -#pragma region xla::ifrt::Index -extern "C" ifrt::Index *ifrt_index_ctor(const int64_t *elements, - size_t elements_size) { - return new ifrt::Index(absl::Span(elements, elements_size)); -} +extern "C" void +FreeHloModule(HeldValue> *hlo_module) { + delete hlo_module; +} + +#pragma region IfRtClient + +// XXX: Bring back with the correct API +// extern "C" ifrt::proxy::GrpcServer * +// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu( +// const char *c_address, uint8_t asynchronous, int node_id) { +// std::string address = c_address; + +// return MyValueOrThrow( +// ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( +// address, +// [asynchronous, +// node_id]() -> absl::StatusOr> +// { +// auto pjrt_client = std::shared_ptr( +// MakeCPUClient(asynchronous, node_id)); +// return std::shared_ptr( +// xla::ifrt::PjRtClient::Create(pjrt_client).release()); +// })) +// .release(); +// } -extern "C" ifrt::Index *ifrt_index_zeros(int num_elements) { - return new ifrt::Index(ifrt::Index::Zeros(num_elements)); -} +// extern "C" ifrt::proxy::GrpcServer * +// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu( +// int node_id, int num_nodes, int *allowed_devices, int +// num_allowed_devices, double memory_fraction, bool preallocate, const char +// *platform_name, const char **error) { +// return MyValueOrThrow( +// ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( +// std::string(), +// [node_id, num_nodes, allowed_devices, num_allowed_devices, +// memory_fraction, preallocate, platform_name, +// error]() -> absl::StatusOr> { +// auto pjrt_client = +// std::shared_ptr(MakeGPUClient( +// node_id, num_nodes, allowed_devices, +// num_allowed_devices, memory_fraction, preallocate, +// platform_name, error)); +// return std::shared_ptr( +// xla::ifrt::PjRtClient::Create(pjrt_client).release()); +// })) +// .release(); +// } -extern "C" void ifrt_index_free(ifrt::Index *index) { delete index; } +// extern "C" ifrt::proxy::GrpcServer * +// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu( +// const char *c_address, const char *tpu_path, const char **error) { +// std::string address = c_address; +// +// return MyValueOrThrow( +// xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( +// address, +// [](xla::ifrt::AttributeMap initialization_data) -> +// absl::StatusOr> { +// auto pjrt_client = +// std::shared_ptr(GetCApiClient("TPU")); +// return +// xla::ifrt::PjRtClient::Create(std::move(pjrt_client)); +// })) +// .release(); +// } -extern "C" const int64_t *ifrt_index_elements(ifrt::Index *index) { - return index->elements().data(); +extern "C" void ifrt_proxy_grpc_server_dtor(ifrt::proxy::GrpcServer *server) { + delete server; } -extern "C" int ifrt_index_count(ifrt::Index *index) { - return index->elements().size(); +extern "C" const char * +ifrt_proxy_grpc_server_address(ifrt::proxy::GrpcServer *server) { + return cstr_from_string(server->address()); } -extern "C" bool ifrt_index_eq(ifrt::Index *index1, ifrt::Index *index2) { - return *index1 == *index2; +extern "C" void ifrt_proxy_grpc_server_wait(ifrt::proxy::GrpcServer *server) { + server->Wait(); } -extern "C" bool ifrt_index_ne(ifrt::Index *index1, ifrt::Index *index2) { - return *index1 != *index2; +// `c_proxy_server_address` must be of the form +// `:`; e.g. "grpc:localhost" +// NOTE not sure if we must pass the port, but probably yes +// by default, set `connection_timeout_in_minutes` to 2 +extern "C" ifrt::Client * +ifrt_proxy_create_client(const char *c_proxy_server_address, + int connection_timeout_in_minutes) { + std::string proxy_server_address = c_proxy_server_address; + ifrt::proxy::ClientConnectionOptions options = { + absl::Minutes(connection_timeout_in_minutes), + nullptr, // callback `on_disconnect` + nullptr, // callback `on_connection_update` + }; + return MyValueOrThrow( + ifrt::proxy::CreateClient(proxy_server_address, options)) + .release(); } -extern "C" ifrt::Index *ifrt_index_add(ifrt::Index *index, - ifrt::Index *offset) { - return new ifrt::Index(*index + *offset); -} +extern "C" ifrt::Client *ifrt_pjrt_make_client( + PjRtClient *pjrt_client, int node_id, int num_nodes, + void *distributed_runtime_client, const char **error, + std::string key_prefix, + std::optional> kv_store) { + ifrt::PjRtClient::CreateOptions options; + options.pjrt_client = std::shared_ptr(pjrt_client); -extern "C" ifrt::Index *ifrt_index_sub(ifrt::Index *index, - ifrt::Index *offset) { - return new ifrt::Index(*index - *offset); -} + if (num_nodes > 1) { + if (distributed_runtime_client == nullptr) { + *error = + "`distributed_runtime_client` must be non-null if `num_nodes` > 1"; + return nullptr; + } + if (kv_store.has_value()) { + options.kv_store = kv_store.value(); + } else { + auto typed_distributed_runtime_client = static_cast< + HeldValue> *>( + distributed_runtime_client); + options.kv_store = GetDistributedKeyValueStore( + typed_distributed_runtime_client->obj(), key_prefix); + } + } -// WARN we're not checking if the multiplier has the same size as the index -extern "C" ifrt::Index *ifrt_index_mul(ifrt::Index *index, - const int64_t *multiplier) { - return new ifrt::Index( - *index * absl::Span(multiplier, ifrt_index_count(index))); -} + options.process_id = node_id; + options.num_processes = num_nodes; -extern "C" void ifrt_index_add_inplace(ifrt::Index *index, - ifrt::Index *offset) { - *index += *offset; + return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release(); } -extern "C" void ifrt_index_sub_inplace(ifrt::Index *index, - ifrt::Index *offset) { - *index -= *offset; -} +const char *const kMpiTrampolineLibEnv = "MPITRAMPOLINE_LIB"; -extern "C" void ifrt_index_mul_inplace(ifrt::Index *index, - const int64_t *multiplier) { - *index *= absl::Span(multiplier, ifrt_index_count(index)); -} +extern "C" ifrt::Client * +ifrt_make_pjrt_cpu_client(uint8_t asynchronous, int node_id, int num_nodes, + void *distributed_runtime_client, + const char **error) { + std::optional> collectives; + std::optional> kv_store; + + if (distributed_runtime_client != nullptr) { + auto mpi_trampoline_path = llvm::sys::Process::GetEnv(kMpiTrampolineLibEnv); + if (mpi_trampoline_path) { + // Use MPI + // TODO: How do we Finalize?? + auto mpi_collectives = std::make_shared(); + collectives = mpi_collectives; + static_cast(mpi_collectives.get())->Init(); + } else { + // Use Gloo + auto typed_distributed_runtime_client = static_cast< + HeldValue> *>( + distributed_runtime_client); + kv_store = + GetDistributedKeyValueStore(typed_distributed_runtime_client->obj(), + /*key_prefix=*/"cpu:"); + auto gloo_kv_store = + std::make_unique(kv_store.value()); +#if defined(__linux__) + auto tcp_attrs = gloo::transport::tcp::attr(); + auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); + collectives = std::make_shared( + std::move(gloo_kv_store), std::move(tcp_device)); +#elif defined(__APPLE__) + auto uv_attrs = gloo::transport::uv::attr(); + auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs); + collectives = std::make_shared( + std::move(gloo_kv_store), std::move(uv_device)); +#else + ReactantThrowError( + "Gloo TCP Collectives only implemented for linux and macos"); +#endif + } + } -extern "C" const char *ifrt_index_debug_string(ifrt::Index *index) { - return cstr_from_string(index->DebugString()); + PjRtClient *pjrt_client = + MakeCPUClientInternal(asynchronous, node_id, collectives); + if (pjrt_client == nullptr) + return nullptr; + return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes, + distributed_runtime_client, error, "cpu", + kv_store); +} + +extern "C" ifrt::Client *ifrt_make_pjrt_gpu_client( + int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices, + double memory_fraction, bool preallocate, const char *platform_name, + const char **error, void *distributed_runtime_client) { + PjRtClient *pjrt_client = MakeGPUClient( + node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction, + preallocate, platform_name, error, distributed_runtime_client); + if (pjrt_client == nullptr) + return nullptr; + std::optional> kv_store; + return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes, + distributed_runtime_client, error, "gpu", + kv_store); } -#pragma endregion -#pragma region xla::ifrt::IndexDomain -extern "C" ifrt::IndexDomain *ifrt_indexdomain_ctor(ifrt::Shape *shape) { - return new ifrt::IndexDomain(*shape); +extern "C" ifrt::Client * +ifrt_make_pjrt_tpu_client(const char *tpu_path, const char **error, int node_id, + int num_nodes, void *distributed_runtime_client) { + PjRtClient *pjrt_client = MakeTPUClient(tpu_path, error); + if (pjrt_client == nullptr) + return nullptr; + std::optional> kv_store; + return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes, + distributed_runtime_client, error, "tpu", + kv_store); } -extern "C" ifrt::IndexDomain * -ifrt_indexdomain_ctor_with_origin(ifrt::Index *origin, ifrt::Shape *shape) { - return new ifrt::IndexDomain(*origin, *shape); -} +extern "C" void ifrt_FreeClient(ifrt::Client *client) { delete client; } -extern "C" void ifrt_indexdomain_free(ifrt::IndexDomain *index_domain) { - delete index_domain; +extern "C" int ifrt_client_device_count(ifrt::Client *client) { + return client->device_count(); } -extern "C" const ifrt::Index * -ifrt_indexdomain_origin(ifrt::IndexDomain *index_domain) { - return &index_domain->origin(); +extern "C" int ifrt_client_addressable_device_count(ifrt::Client *client) { + return client->addressable_device_count(); } -extern "C" const ifrt::Shape * -ifrt_indexdomain_shape(ifrt::IndexDomain *index_domain) { - return &index_domain->shape(); +extern "C" void ifrt_client_devices(ifrt::Client *client, + ifrt::Device **out_devices) { + auto span = client->devices(); + for (int i = 0; i < span.size(); i++) { + out_devices[i] = span[i]; + } } -extern "C" bool ifrt_indexdomain_eq(ifrt::IndexDomain *index_domain1, - ifrt::IndexDomain *index_domain2) { - return *index_domain1 == *index_domain2; +extern "C" void ifrt_client_addressable_devices(ifrt::Client *client, + ifrt::Device **out_devices) { + auto span = client->addressable_devices(); + for (int i = 0; i < span.size(); i++) { + out_devices[i] = span[i]; + } } -extern "C" bool ifrt_indexdomain_ne(ifrt::IndexDomain *index_domain1, - ifrt::IndexDomain *index_domain2) { - return *index_domain1 != *index_domain2; +extern "C" void ifrt_client_all_devices(ifrt::Client *client, + ifrt::Device **out_devices) { + auto span = client->GetAllDevices(); + for (int i = 0; i < span.size(); i++) { + out_devices[i] = span[i]; + } } -extern "C" ifrt::IndexDomain * -ifrt_indexdomain_add(ifrt::IndexDomain *index_domain, ifrt::Index *offset) { - return new ifrt::IndexDomain(*index_domain + *offset); +extern "C" ifrt::Device *ifrt_client_lookup_device(ifrt::Client *client, + int dev_id) { + return MyValueOrThrow( + client->LookupDevice(static_cast(dev_id))); } -extern "C" ifrt::IndexDomain * -ifrt_indexdomain_sub(ifrt::IndexDomain *index_domain, ifrt::Index *offset) { - return new ifrt::IndexDomain(*index_domain - *offset); +extern "C" ifrt::Device * +ifrt_client_lookup_addressable_device(ifrt::Client *client, int local_hw_id) { + return MyValueOrThrow(client->LookupAddressableDevice(local_hw_id)); } -extern "C" void ifrt_indexdomain_add_inplace(ifrt::IndexDomain *index_domain, - ifrt::Index *offset) { - *index_domain += *offset; +extern "C" int ifrt_ClientProcessIndex(ifrt::Client *client) { + return client->process_index(); } -extern "C" void ifrt_indexdomain_sub_inplace(ifrt::IndexDomain *index_domain, - ifrt::Index *offset) { - *index_domain -= *offset; +extern "C" const char *ifrt_ClientGetPlatformName(ifrt::Client *client) { + return cstr_from_string(client->platform_name()); } -extern "C" const char * -ifrt_indexdomain_debug_string(ifrt::IndexDomain *index_domain) { - return cstr_from_string(index_domain->DebugString()); +extern "C" ifrt::Device *ifrt_ClientGetDevice(ifrt::Client *client, int idx) { + return MyValueOrThrow(client->LookupDevice(ifrt::DeviceId(idx))); } -#pragma endregion -#pragma region xla::ifrt::MemoryKind -// Pass a nullptr to create a `MemoryKind` with no memory chosen. -extern "C" ifrt::MemoryKind *ifrt_memorykind_ctor(const char *memory_kind) { - if (memory_kind == nullptr) - return new ifrt::MemoryKind(); - return new ifrt::MemoryKind(std::string(memory_kind)); +extern "C" ifrt::Device *ifrt_ClientGetAddressableDevice(ifrt::Client *client, + int idx) { + return MyValueOrThrow(client->LookupAddressableDevice(idx)); } -extern "C" void ifrt_memorykind_free(ifrt::MemoryKind *memory_kind) { - delete memory_kind; -} +#pragma endregion -extern "C" bool ifrt_memorykind_eq(ifrt::MemoryKind *mk1, - ifrt::MemoryKind *mk2) { - return *mk1 == *mk2; -} +#pragma region IfRtDevice -extern "C" bool ifrt_memorykind_ne(ifrt::MemoryKind *mk1, - ifrt::MemoryKind *mk2) { - return *mk1 != *mk2; +extern "C" int64_t ifrt_DeviceGetGlobalDeviceId(ifrt::Device *device) { + return device->Id().value(); } -extern "C" const char *ifrt_memorykind_string(ifrt::MemoryKind *memory_kind) { - if (memory_kind->memory_kind().has_value()) - return cstr_from_string(memory_kind->memory_kind().value()); - else - return nullptr; +extern "C" const char *ifrt_DeviceGetKind(ifrt::Device *device) { + return cstr_from_string(device->Kind()); } -extern "C" ifrt::MemoryKind * -ifrt_memorykind_canonicalize(ifrt::MemoryKind *memory_kind, - ifrt::Device *device) { - return new ifrt::MemoryKind(CanonicalizeMemoryKind(*memory_kind, device)); +extern "C" ifrt::Client *ifrt_DeviceToClient(ifrt::Device *device) { + return device->client(); } -#pragma endregion -#pragma region xla::ifrt::Memory -// MemoryId is a struct with a single int32_t field --> check out -// xla/python/ifrt/memory.h -extern "C" ifrt::MemoryId ifrt_memory_id(ifrt::Memory *memory) { - return memory->Id(); +extern "C" bool ifrt_DeviceIsAddressable(ifrt::Device *device) { + return device->IsAddressable(); } -extern "C" const ifrt::MemoryKind *ifrt_memory_kind(ifrt::Memory *memory) { - return &(memory->Kind()); +tsl::RCReference ifrt_CreateDeviceListFromDevices( + ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices) { + absl::Span devices(device_list, num_devices); + return client->MakeDeviceList(devices); } -extern "C" const char *ifrt_memory_to_string(ifrt::Memory *memory) { - return cstr_from_string(memory->ToString()); +extern "C" ifrt::Memory *ifrt_DeviceGetDefaultMemory(ifrt::Device *device) { + return MyValueOrThrow(device->DefaultMemory()); } -extern "C" const char *ifrt_memory_debug_string(ifrt::Memory *memory) { - return cstr_from_string(memory->DebugString()); +extern "C" ifrt::Memory **ifrt_DeviceGetMemories(ifrt::Device *device, + int32_t *size) { + auto memory_list = device->Memories(); + *size = memory_list.size(); + return const_cast(memory_list.data()); } -extern "C" std::tuple -ifrt_memory_devices(ifrt::Memory *memory) { - auto devices = memory->Devices(); - return std::make_tuple(devices.size(), - devices.data()); +extern "C" ifrt::MemoryKind *ifrt_MemoryGetMemoryKind(ifrt::Memory *memory) { + ifrt::MemoryKind *memory_kind = new ifrt::MemoryKind(memory->Kind()); + return memory_kind; } -#pragma endregion -#pragma region xla::ifrt::PjRtMemory -extern "C" ifrt::PjRtMemory * -ifrt_pjrt_memory_ctor(ifrt::PjRtClient *client, - xla::PjRtMemorySpace *memory_space) { - return new ifrt::PjRtMemory(client, memory_space); +extern "C" const char *ifrt_MemoryToString(ifrt::Memory *memory) { + return cstr_from_string(memory->ToString()); } -extern "C" void ifrt_pjrt_memory_free(ifrt::PjRtMemory *memory) { - delete memory; +extern "C" const char *ifrt_MemoryKindToString(ifrt::MemoryKind *memory_kind) { + auto memkind = memory_kind->memory_kind(); + if (!memkind.has_value()) + return ""; + return cstr_from_string(memkind.value()); } -extern "C" ifrt::PjRtClient *ifrt_pjrt_memory_client(ifrt::PjRtMemory *memory) { - return memory->client(); +extern "C" bool ifrt_MemoryKindsAreEqual(ifrt::MemoryKind *a, + ifrt::MemoryKind *b) { + return *a == *b; } -extern "C" xla::PjRtMemorySpace * -ifrt_pjrt_memory_space(ifrt::PjRtMemory *memory) { - return memory->pjrt_memory(); -} #pragma endregion -#pragma region xla::ifrt::Device -extern "C" ifrt::Client *ifrt_device_client(ifrt::Device *device) { - return device->client(); -} +#pragma region OpSharding -// DeviceId is a struct with a single int32_t field --> check out -// xla/pjrt/pjrt_common.h -extern "C" ifrt::DeviceId ifrt_device_id(ifrt::Device *device) { - return device->Id(); +extern "C" void free_op_sharding(xla::OpSharding *op_sharding) { + delete op_sharding; } -// TODO ifrt_device_attributes - -extern "C" const char *ifrt_device_kind(ifrt::Device *device) { - return cstr_from_string(device->Kind()); +extern "C" int32_t +op_sharding_to_op_sharding_type(xla::OpSharding *op_sharding) { + return static_cast(op_sharding->type()); } -extern "C" const char *ifrt_device_to_string(ifrt::Device *device) { - return cstr_from_string(device->ToString()); +extern "C" int32_t +op_sharding_to_shard_group_type(xla::OpSharding *op_sharding) { + return static_cast(op_sharding->shard_group_type()); } -extern "C" const char *ifrt_device_debug_string(ifrt::Device *device) { - return cstr_from_string(device->DebugString()); +extern "C" int32_t op_sharding_to_shard_group_id(xla::OpSharding *op_sharding) { + return static_cast(op_sharding->shard_group_id()); } -extern "C" ifrt::Memory *ifrt_device_default_memory(ifrt::Device *device) { - return MyValueOrThrow(device->DefaultMemory()); +extern "C" bool op_sharding_is_shard_group(xla::OpSharding *op_sharding) { + return op_sharding->is_shard_group(); } -// TODO ifrt_device_memories - -extern "C" bool ifrt_device_is_addressable(ifrt::Device *device) { - return device->IsAddressable(); +extern "C" bool +op_sharding_replicate_on_last_tile_dim(xla::OpSharding *op_sharding) { + return op_sharding->replicate_on_last_tile_dim(); } -extern "C" int ifrt_device_process_index(ifrt::Device *device) { - return device->ProcessIndex(); +extern "C" bool op_sharding_has_last_tile_dims(xla::OpSharding *op_sharding) { + return op_sharding->last_tile_dims_size() > 0; } -#pragma endregion -#pragma region xla::ifrt::PjRtDevice -// DeviceId is a struct with a single int32_t field --> check out -// xla/pjrt/pjrt_common.h -// TODO support `attributes` parameter -extern "C" ifrt::PjRtDevice * -ifrt_pjrt_device_ctor(ifrt::PjRtClient *client, ifrt::DeviceId device_id, - const char *kind, const char *to_string, - const char *debug_string, int process_index, - xla::PjRtDevice *pjrt_device) { - return new ifrt::PjRtDevice( - client, device_id, kind, to_string, debug_string, process_index, - absl::flat_hash_map(), pjrt_device); +extern "C" int32_t +op_sharding_last_tile_dims_size(xla::OpSharding *op_sharding) { + return static_cast(op_sharding->last_tile_dims_size()); } -extern "C" void ifrt_pjrt_device_free(ifrt::PjRtDevice *device) { - delete device; +extern "C" void op_sharding_last_tile_dims(xla::OpSharding *op_sharding, + int32_t *last_tile_dims) { + std::vector last_tile_dims_vec(op_sharding->last_tile_dims().begin(), + op_sharding->last_tile_dims().end()); + std::copy(last_tile_dims_vec.begin(), last_tile_dims_vec.end(), + last_tile_dims); + return; } -extern "C" xla::PjRtDevice * -ifrt_pjrt_device_pjrt_device(ifrt::PjRtDevice *device) { - return device->pjrt_device(); +extern "C" bool +op_sharding_has_iota_reshape_dims(xla::OpSharding *op_sharding) { + return op_sharding->iota_reshape_dims_size() > 0; } -#pragma endregion - -#pragma region xla::ifrt::Sharding -// TODO ifrt_sharding_devices -// TODO ifrt_sharding_memory_kind - -// extern "C" void ifrt_sharding_disassemble(ifrt::Sharding* sharding, -// ifrt::Shape* shape, char** error) { -// auto status = sharding->Disassemble(*shape); -// if (!status.ok()) { -// auto str = status.message(); -// char* err = (char*)malloc(str.size()+1); -// memcpy(err, str.data(), str.size()+1); -// *error = err; -// } -// } -// TODO ifrt_sharding_disassemble_dynamic_shape -// TODO ifrt_sharding_index_domains - -extern "C" const char *ifrt_sharding_debug_string(ifrt::Sharding *sharding) { - return cstr_from_string(sharding->DebugString()); +extern "C" int32_t +op_sharding_iota_reshape_dims_size(xla::OpSharding *op_sharding) { + return static_cast(op_sharding->iota_reshape_dims_size()); } -#pragma endregion -#pragma region xla::ifrt::Array -extern "C" ifrt::DType *ifrt_array_dtype(ifrt::Array *array) { - return new ifrt::DType(array->dtype()); +extern "C" void op_sharding_iota_reshape_dims(xla::OpSharding *op_sharding, + int32_t *iota_reshape_dims) { + std::vector iota_reshape_dims_vec( + op_sharding->iota_reshape_dims().begin(), + op_sharding->iota_reshape_dims().end()); + std::copy(iota_reshape_dims_vec.begin(), iota_reshape_dims_vec.end(), + iota_reshape_dims); + return; } -extern "C" const ifrt::Shape *ifrt_array_shape(ifrt::Array *array) { - return &(array->shape()); +extern "C" bool +op_sharding_has_iota_transpose_perm(xla::OpSharding *op_sharding) { + return op_sharding->iota_transpose_perm_size() > 0; } -extern "C" const ifrt::Sharding *ifrt_array_sharding(ifrt::Array *array) { - return &(array->sharding()); +extern "C" int32_t +op_sharding_iota_transpose_perm_size(xla::OpSharding *op_sharding) { + return static_cast(op_sharding->iota_transpose_perm_size()); } -extern "C" PjRtLayout *ifrt_array_layout(ifrt::Array *array) { - return MyValueOrThrow(array->layout()).release(); +extern "C" void op_sharding_iota_transpose_perm(xla::OpSharding *op_sharding, + int32_t *iota_transpose_perm) { + std::vector iota_transpose_perm_vec( + op_sharding->iota_transpose_perm().begin(), + op_sharding->iota_transpose_perm().end()); + std::copy(iota_transpose_perm_vec.begin(), iota_transpose_perm_vec.end(), + iota_transpose_perm); + return; } -// TODO xla::ifrt::Array::DisassembleIntoSingleDeviceArrays -// TODO xla::ifrt::Array::FullyReplicatedShard - -extern "C" ifrt::Future<> -ifrt_array_copy_to_host_buffer(ifrt::Array *array, void *data, - const int64_t *byte_strides, int semantics) { - return array->CopyToHostBuffer( - data, - absl::Span(byte_strides, array->shape().num_elements()), - ifrt::ArrayCopySemantics(semantics)); -} -#pragma endregion - -#pragma region xla::ifrt::PjRtArray -// TODO constructors / `Create` - -extern "C" std::tuple -ifrt_pjrt_array_pjrt_buffers(ifrt::PjRtArray *array) { - auto buffers = array->pjrt_buffers(); - auto buffers_ptr = new xla::PjRtBuffer *[buffers.size()]; - for (int i = 0; i < buffers.size(); i++) { - buffers_ptr[i] = buffers[i].get(); - } - return std::make_tuple(buffers.size(), buffers_ptr); +extern "C" bool +op_sharding_has_tile_assignment_dimensions(xla::OpSharding *op_sharding) { + return op_sharding->tile_assignment_dimensions_size() > 0; } -#pragma endregion -#pragma region xla::ifrt::Topology -extern "C" const char *ifrt_topology_platform_name(ifrt::Topology *topology) { - return cstr_from_string(topology->platform_name()); +extern "C" int32_t +op_sharding_tile_assignment_dimensions_size(xla::OpSharding *op_sharding) { + return static_cast(op_sharding->tile_assignment_dimensions_size()); } -extern "C" const char * -ifrt_topology_platform_version(ifrt::Topology *topology) { - return cstr_from_string(topology->platform_version()); +extern "C" void +op_sharding_tile_assignment_dimensions(xla::OpSharding *op_sharding, + int32_t *tile_assignment_dimensions) { + std::vector tile_assignment_dimensions_vec( + op_sharding->tile_assignment_dimensions().begin(), + op_sharding->tile_assignment_dimensions().end()); + std::copy(tile_assignment_dimensions_vec.begin(), + tile_assignment_dimensions_vec.end(), tile_assignment_dimensions); + return; } -// returns PjRtPlatformId which is a type alias for uint64_t -extern "C" uint64_t ifrt_topology_platform_id(ifrt::Topology *topology) { - return topology->platform_id(); +extern "C" bool +op_sharding_has_tile_assignment_devices(xla::OpSharding *op_sharding) { + return op_sharding->tile_assignment_devices_size() > 0; } -extern "C" std::tuple -ifrt_topology_device_descriptions(ifrt::Topology *topology) { - auto descriptions = topology->DeviceDescriptions(); - auto descriptions_ptr = - new const xla::PjRtDeviceDescription *[descriptions.size()]; - for (int i = 0; i < descriptions.size(); i++) { - descriptions_ptr[i] = descriptions[i].release(); - } - return std::make_tuple(descriptions.size(), descriptions_ptr); +extern "C" int32_t +op_sharding_tile_assignment_devices_size(xla::OpSharding *op_sharding) { + return static_cast(op_sharding->tile_assignment_devices_size()); } -// TODO xla::ifrt::Topology::GetDefaultLayout - -extern "C" const char *ifrt_topology_serialize(ifrt::Topology *topology) { - return cstr_from_string(MyValueOrThrow(topology->Serialize())); +extern "C" void +op_sharding_tile_assignment_devices(xla::OpSharding *op_sharding, + int32_t *tile_assignment_devices) { + std::vector tile_assignment_devices_vec( + op_sharding->tile_assignment_devices().begin(), + op_sharding->tile_assignment_devices().end()); + std::copy(tile_assignment_devices_vec.begin(), + tile_assignment_devices_vec.end(), tile_assignment_devices); + return; } -// TODO xla::ifrt::Topology::Attributes - #pragma endregion -#pragma region xla::ifrt::PjRtTopology -extern "C" ifrt::PjRtTopology * -ifrt_pjrt_topology_ctor(const xla::PjRtTopologyDescription *description) { - return new ifrt::PjRtTopology( - std::shared_ptr{description}); -} +#pragma region HloSharding -extern "C" const xla::PjRtTopologyDescription * -ifrt_pjrt_topology_description(ifrt::PjRtTopology *topology) { - return topology->description().get(); +extern "C" void free_hlo_sharding(xla::HloSharding *hlo_sharding) { + delete hlo_sharding; } -#pragma endregion -#pragma region xla::ifrt::Client -extern "C" int ifrt_client_device_count(ifrt::Client *client) { - return client->device_count(); +extern "C" void free_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) { + delete hlo_sharding; } -extern "C" int ifrt_client_addressable_device_count(ifrt::Client *client) { - return client->addressable_device_count(); +extern "C" xla::HloSharding * +hlo_sharding_from_op_sharding(xla::OpSharding *op_sharding) { + xla::HloSharding *hlo_sharding = new xla::HloSharding( + MyValueOrThrow(xla::HloSharding::FromProto(*op_sharding))); + return hlo_sharding; } -extern "C" ifrt::Device *const *ifrt_client_devices(ifrt::Client *client) { - return client->devices().data(); +extern "C" xla::OpSharding * +hlo_sharding_to_op_sharding(xla::HloSharding *hlo_sharding) { + xla::OpSharding *op_sharding = new xla::OpSharding(hlo_sharding->ToProto()); + return op_sharding; } -extern "C" ifrt::Device *const * -ifrt_client_addressable_devices(ifrt::Client *client) { - return client->addressable_devices().data(); +extern "C" const char * +hlo_sharding_to_string(const xla::HloSharding *hlo_sharding) { + return cstr_from_string(hlo_sharding->ToString(true)); } -extern "C" int ifrt_client_process_index(ifrt::Client *client) { - return client->process_index(); +extern "C" ifrt::MemoryKind *ifrt_memory_kind_from_string(const char *c_str) { + return new ifrt::MemoryKind(std::string(c_str)); } -// TODO xla::ifrt::Client::GetDefaultDeviceAssignment - -extern "C" ifrt::Device *ifrt_client_lookup_device(ifrt::Client *client, - int device_id) { - return MyValueOrThrow(client->LookupDevice(ifrt::DeviceId(device_id))); +extern "C" ifrt::HloSharding *ifrt_hlo_sharding_from_xla_hlo_sharding( + ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices, + ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) { + return ifrt::HloSharding::Create( + ifrt_CreateDeviceListFromDevices(client, device_list, num_devices), + *memory_kind, *xla_hlo_sharding) + .release(); } -extern "C" ifrt::Device * -ifrt_client_lookup_addressable_device(ifrt::Client *client, int device_id) { - return MyValueOrThrow(client->LookupAddressableDevice(device_id)); +extern "C" xla::HloSharding * +ifrt_hlo_sharding_to_xla_hlo_sharding(ifrt::HloSharding *hlo_sharding) { + xla::HloSharding *xla_hlo_sharding = + new xla::HloSharding(hlo_sharding->xla_hlo_sharding()); + return xla_hlo_sharding; } -extern "C" ifrt::Compiler *ifrt_client_default_compiler(ifrt::Client *client) { - return client->GetDefaultCompiler(); +extern "C" const char * +ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) { + return cstr_from_string(hlo_sharding->DebugString()); } -// TODO ifrt_client_topology_for_devices -// TODO ifrt_client_default_layout_for_device -#pragma endregion - -#pragma region xla::ifrt::PjRtClient -// TODO support more parameters of `PjRtClient::CreateOptions` -extern "C" ifrt::PjRtClient * -ifrt_pjrt_client_ctor(xla::PjRtClient *pjrt_client) { - return MyValueOrThrow( - ifrt::PjRtClient::Create(ifrt::PjRtClient::CreateOptions{ - std::shared_ptr{pjrt_client}})) - .release(); +extern "C" ifrt::HloSharding *ifrt_sharding_to_ifrt_hlo_sharding( + HeldValue> *sharding) { + const ifrt::Sharding *val = sharding->obj().get(); + if (!llvm::isa(val)) + ReactantThrowError("Expected a HloSharding"); + return new ifrt::HloSharding(*llvm::dyn_cast(val)); } -extern "C" void ifrt_pjrt_client_free(ifrt::PjRtClient *client) { - delete client; +extern "C" void +free_ifrt_sharding(HeldValue> *sharding) { + delete sharding; } -extern "C" xla::PjRtClient * -ifrt_pjrt_client_pjrt_client(ifrt::PjRtClient *client) { - return client->pjrt_client(); +extern "C" HeldValue> * +ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) { + return reactant::capture(std::shared_ptr(hlo_sharding)); } -// TODO there are problems with using `make_shared -// extern "C" ifrt::PjRtCompatibleArray* -// ifrt_pjrt_client_create_pjrt_array(ifrt::PjRtClient* client, xla::PjRtBuffer* -// pjrt_buffer) { -// auto buffer_ptr = std::make_shared(*pjrt_buffer); -// return MyValueOrThrow(client->CreatePjRtArray(buffer_ptr)).release(); -// } - -// TODO extern "C" ifrt::PjRtCompatibleArray* -// ifrt_pjrt_client_create_pjrt_array_from_buffers(ifrt::Shape* shape, -// ifrt::PjRtBuffer** pjrt_buffers, int num_buffers) {} +extern "C" HeldValue> * +ifrt_sharding_from_hlo_sharding(ifrt::Client *client, + ifrt::Device **device_list, int32_t num_devices, + ifrt::MemoryKind *memory_kind, + xla::HloSharding *xla_hlo_sharding) { + return ifrt_sharding_from_ifrt_hlo_sharding( + ifrt_hlo_sharding_from_xla_hlo_sharding(client, device_list, num_devices, + memory_kind, xla_hlo_sharding)); +} -extern "C" ifrt::PjRtCompatibleDevice * -ifrt_pjrt_client_lookup_pjrt_device(ifrt::PjRtClient *client, - xla::PjRtDevice *pjrt_device) { - return MyValueOrThrow(client->LookupPjRtDevice(pjrt_device)); +extern "C" bool ifrt_sharding_is_single_device_sharding( + HeldValue> *sharding) { + return llvm::isa(sharding->obj().get()); } -extern "C" ifrt::PjRtCompatibleMemory * -ifrt_pjrt_client_lookup_pjrt_memory(ifrt::PjRtClient *client, - xla::PjRtMemorySpace *pjrt_memory_space) { - return MyValueOrThrow(client->LookupPjRtMemory(pjrt_memory_space)); +extern "C" bool ifrt_sharding_is_fully_replicated( + HeldValue> *sharding) { + return sharding->obj()->IsFullyReplicated(); } -#pragma endregion -#pragma region xla::ifrt::HostCallback extern "C" const char * -ifrt_hostcallback_serialize(ifrt::HostCallback *host_callback) { - return cstr_from_string(host_callback->Serialize()); +ifrt_sharding_to_string(HeldValue> *sharding) { + return cstr_from_string(sharding->obj()->DebugString()); } -#pragma endregion -#pragma region xla::ifrt::LoadedHostCallback -extern "C" ifrt::Client * -ifrt_loadedhostcallback_client(ifrt::LoadedHostCallback *host_callback) { - return host_callback->client(); +extern "C" int32_t ifrt_sharding_devices_size( + HeldValue> *sharding) { + return sharding->obj()->devices()->size(); } -extern "C" const char * -ifrt_loadedhostcallback_serialize(ifrt::LoadedHostCallback *host_callback) { - // auto msg = ; - return cstr_from_string(MyValueOrThrow(host_callback->Serialize())); +extern "C" void ifrt_sharding_to_device_list( + HeldValue> *sharding, + ifrt::Device **devices) { + auto device_list = sharding->obj()->devices()->devices(); + for (int i = 0; i < device_list.size(); i++) { + devices[i] = device_list[i]; + } } + #pragma endregion -#pragma region xla::ifrt::PjRtHostSendAndRecvLoadedHostCallback -extern "C" ifrt::PjRtHostSendAndRecvLoadedHostCallback * -ifrt_pjrt_hostsendandrecv_loadhostcallback_ctor( - ifrt::PjRtClient *client, xla::HostCallback *host_callback) { - auto xla_callback_ptr = std::make_unique(*host_callback); - return new ifrt::PjRtHostSendAndRecvLoadedHostCallback( - client, std::move(xla_callback_ptr)); -} +typedef ifrt::Future<> IfRtFutureType; -extern "C" void ifrt_pjrt_hostsendandrecv_loadhostcallback_free( - ifrt::PjRtHostSendAndRecvLoadedHostCallback *host_callback) { - delete host_callback; -} +extern "C" void ifrt_free_future(IfRtFutureType *Future) { delete Future; } -extern "C" xla::HostCallback * -ifrt_pjrt_hostsendandrecv_loadhostcallback_host_callback( - ifrt::PjRtHostSendAndRecvLoadedHostCallback *host_callback) { - return new xla::HostCallback(host_callback->host_callback()); +extern "C" uint8_t ifrt_future_is_ready(IfRtFutureType *Future) { + return Future->IsReady(); } -#pragma endregion -#pragma region xla::ifrt::Executable -extern "C" const char *ifrt_executable_name(ifrt::Executable *executable) { - return cstr_from_string(executable->name()); -} +extern "C" void ifrt_future_await(IfRtFutureType *Future) { Future->Await(); } -extern "C" const char * -ifrt_executable_fingerprint(ifrt::Executable *executable) { - auto result = MyValueOrThrow(executable->Fingerprint()); - if (!result.has_value()) - return ""; - return cstr_from_string(result.value()); -} +#pragma region IfRtArray -extern "C" const char *ifrt_executable_serialize(ifrt::Executable *executable) { - return cstr_from_string(MyValueOrThrow(executable->Serialize())); -} +extern "C" void ifrt_free_array(HeldIfrtArray *array) { delete array; } -extern "C" int ifrt_executable_num_devices(ifrt::Executable *executable) { - return executable->num_devices(); +extern "C" int64_t *ifrt_array_shape(HeldIfrtArray *array) { + auto dims = + static_cast>(array->obj()->shape().dims()); + int64_t *dims_ptr = new int64_t[dims.size()]; + std::copy(dims.begin(), dims.end(), dims_ptr); + return dims_ptr; } -extern "C" int64_t ifrt_executable_size(ifrt::Executable *executable) { - return executable->SizeOfGeneratedCodeInBytes(); +extern "C" int64_t ifrt_array_ndims(HeldIfrtArray *array) { + return array->obj()->shape().dims().size(); } -// TODO xla::ifrt::Executable::GetCompiledMemoryStats - -extern "C" std::tuple -ifrt_executable_parameter_shardings(ifrt::Executable *executable) { - auto shardings = executable->GetParameterShardings(); - if (!shardings.has_value()) - return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); +extern "C" ifrt::DType ifrt_array_eltype(HeldIfrtArray *array) { + return array->obj()->dtype(); } -extern "C" std::tuple -ifrt_executable_output_shardings(ifrt::Executable *executable) { - auto shardings = executable->GetOutputShardings(); - if (!shardings.has_value()) - return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); +extern "C" ifrt::Client *ifrt_array_to_client(HeldIfrtArray *array) { + return array->obj()->client(); } -extern "C" std::tuple -ifrt_executable_parameter_layouts(ifrt::Executable *executable) { - auto layouts = MyValueOrThrow(executable->GetParameterLayouts()); - auto layouts_ptr = new xla::PjRtLayout *[layouts.size()]; - for (int i = 0; i < layouts.size(); i++) { - layouts_ptr[i] = layouts[i].release(); - } - return std::make_tuple(layouts.size(), layouts_ptr); +extern "C" HeldValue> * +ifrt_array_to_sharding(HeldIfrtArray *array) { + return reactant::capture(array->obj()->shared_ptr_sharding()); } -extern "C" std::tuple -ifrt_executable_output_layouts(ifrt::Executable *executable) { - auto layouts = MyValueOrThrow(executable->GetOutputLayouts()); - auto layouts_ptr = new xla::PjRtLayout *[layouts.size()]; - for (int i = 0; i < layouts.size(); i++) { - layouts_ptr[i] = layouts[i].release(); - } - return std::make_tuple(layouts.size(), layouts_ptr); +extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array, + void *data) { + std::optional> byte_strides; + auto future = array->obj()->CopyToHostBuffer( + data, byte_strides, static_cast(0)); + future.Await(); + return; } -extern "C" std::tuple -ifrt_executable_hlo_modules(ifrt::Executable *executable) { - auto modules = MyValueOrThrow(executable->GetHloModules()); - auto modules_ptr = new xla::HloModule *[modules.size()]; - for (int i = 0; i < modules.size(); i++) { - modules_ptr[i] = modules[i].get(); +extern "C" HeldIfrtArray **ifrt_array_disassemble_into_single_device_arrays( + HeldIfrtArray *array, int32_t c_semantics, + int32_t c_single_device_shard_semantics, int32_t *narrays) { + std::vector> single_device_arrays = + MyValueOrThrow(array->obj()->DisassembleIntoSingleDeviceArrays( + static_cast(c_semantics), + static_cast( + c_single_device_shard_semantics))); + + *narrays = single_device_arrays.size(); + HeldIfrtArray **arrays = new HeldIfrtArray *[single_device_arrays.size()]; + for (int i = 0; i < single_device_arrays.size(); i++) { + arrays[i] = reactant::capture(std::move(single_device_arrays[i])); } - return std::make_tuple(modules.size(), modules_ptr); + return arrays; } -// TODO xla::ifrt::Executable::GetCostAnalysis #pragma endregion -#pragma region xla::ifrt::PjRtExecutable -// TODO there are problems with using `make_shared -// extern "C" ifrt::Executable* ifrt_pjrt_executable_ctor(xla::PjRtExecutable* -// pjrt_executable, ifrt::XlaCompileOptions* compile_options) { -// auto pjrt_executable_shared = -// std::make_shared(*pjrt_executable); auto options = -// std::make_unique(*compile_options); return -// MyValueOrThrow(ifrt::PjRtExecutable::Create(pjrt_executable_shared, -// std::move(options))).release(); -// } +#pragma region xla::Distributed -extern "C" void ifrt_pjrt_executable_free(ifrt::PjRtExecutable *executable) { - delete executable; -} +extern "C" HeldValue> * +GetDistributedRuntimeClient(char *c_address, int32_t node_id, + int32_t rpc_timeout_in_seconds, + // int32_t init_timeout, + int32_t shutdown_timeout_in_minutes, + int32_t heartbeat_interval_in_seconds, + int max_missing_heartbeats, bool use_compression) { + xla::DistributedRuntimeClient::Options options; + options.node_id = node_id; + options.rpc_timeout = absl::Seconds(rpc_timeout_in_seconds); + // options.init_timeout = absl::Seconds(init_timeout); + options.shutdown_timeout = absl::Minutes(shutdown_timeout_in_minutes); + options.heartbeat_interval = absl::Seconds(heartbeat_interval_in_seconds); + options.max_missing_heartbeats = max_missing_heartbeats; -extern "C" xla::PjRtExecutable * -ifrt_pjrt_executable_pjrt_executable(ifrt::PjRtExecutable *executable) { - return executable->pjrt_executable(); -} -#pragma endregion + std::string address = c_address; -#pragma region xla::ifrt::LoadedExecutable -extern "C" ifrt::Client * -ifrt_loadedexecutable_client(ifrt::LoadedExecutable *executable) { - return executable->client(); + return reactant::capture( + xla::GetDistributedRuntimeClient(address, options, use_compression)); } -extern "C" const char * -ifrt_loadedexecutable_name(ifrt::LoadedExecutable *executable) { - return cstr_from_string(executable->name()); +extern "C" void free_distributed_runtime_client( + HeldValue> *client) { + delete client; } -extern "C" const char * -ifrt_loadedexecutable_fingerprint(ifrt::LoadedExecutable *executable) { - auto result = MyValueOrThrow(executable->Fingerprint()); - if (!result.has_value()) - return ""; - return cstr_from_string(result.value()); +extern "C" void distributed_runtime_client_connect( + HeldValue> *client) { + auto status = client->obj()->Connect(); + if (!status.ok()) + ReactantThrowError(status.ToString().c_str()); } -extern "C" const char * -ifrt_loadedexecutable_serialize(ifrt::LoadedExecutable *executable) { - return cstr_from_string(MyValueOrThrow(executable->Serialize())); +extern "C" void distributed_runtime_client_shutdown( + HeldValue> *client) { + auto status = client->obj()->Shutdown(); + if (!status.ok()) + ReactantThrowError(status.ToString().c_str()); } -extern "C" ifrt::Future<> -ifrt_loadedexecutable_get_ready_future(ifrt::LoadedExecutable *executable) { - return executable->GetReadyFuture(); -} +extern "C" xla::DistributedRuntimeService *GetDistributedRuntimeService( + char *c_address, int num_nodes, int32_t heartbeat_interval_in_seconds, + int max_missing_heartbeats, int32_t cluster_register_timeout_in_minutes, + int32_t shutdown_timeout_in_minutes) { + xla::CoordinationServiceImpl::Options options; + options.num_nodes = num_nodes; + options.heartbeat_interval = absl::Seconds(heartbeat_interval_in_seconds); + options.max_missing_heartbeats = max_missing_heartbeats; + options.cluster_register_timeout = + absl::Minutes(cluster_register_timeout_in_minutes); + options.shutdown_timeout = absl::Minutes(shutdown_timeout_in_minutes); -extern "C" int -ifrt_loadedexecutable_num_devices(ifrt::LoadedExecutable *executable) { - return executable->num_devices(); -} + std::string address = c_address; -extern "C" int64_t -ifrt_loadedexecutable_size(ifrt::LoadedExecutable *executable) { - return executable->SizeOfGeneratedCodeInBytes(); + return MyValueOrThrow(xla::GetDistributedRuntimeService(address, options)) + .release(); } -// TODO xla::ifrt::GetCompiledMemoryStats - -extern "C" std::tuple -ifrt_loadedexecutable_parameter_shardings(ifrt::LoadedExecutable *executable) { - auto shardings = executable->GetParameterShardings(); - if (!shardings.has_value()) - return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); +extern "C" void free_distributed_runtime_service( + HeldValue> *service) { + delete service; } -extern "C" std::tuple -ifrt_loadedexecutable_output_shardings(ifrt::LoadedExecutable *executable) { - auto shardings = executable->GetOutputShardings(); - if (!shardings.has_value()) - return std::make_tuple(0, nullptr); - return std::make_tuple(shardings.value().size(), shardings.value().data()); +extern "C" void distributed_runtime_service_shutdown( + HeldValue> *service) { + service->obj()->Shutdown(); } -extern "C" std::tuple -ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable *executable) { - auto layouts = MyValueOrThrow(executable->GetParameterLayouts()); - auto layouts_ptr = new xla::PjRtLayout *[layouts.size()]; - for (int i = 0; i < layouts.size(); i++) { - layouts_ptr[i] = layouts[i].release(); - } - return std::make_tuple(layouts.size(), layouts_ptr); -} +#pragma endregion -extern "C" std::tuple -ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable *executable) { - auto layouts = MyValueOrThrow(executable->GetOutputLayouts()); - auto layouts_ptr = new xla::PjRtLayout *[layouts.size()]; - for (int i = 0; i < layouts.size(); i++) { - layouts_ptr[i] = layouts[i].release(); +#pragma region Shardy + +extern "C" xla::HloSharding * +hloShardingFromTensorShardingAttr(mlir::sdy::TensorShardingAttr attr, + mlir::sdy::MeshAttr meshAttr) { + mlir::ArrayRef manual_axes = {}; + std::function + get_mesh_attr = [meshAttr](mlir::sdy::TensorShardingAttr local_attr) { + return meshAttr; + }; + + return new xla::HloSharding( + xla::sdy::convertToHloSharding(attr, get_mesh_attr, manual_axes)); +} + +// XXX: This is incorrect for multiple meshes. We need to use the current mesh +// to generate this instead of the global mesh Currently we are storing only a +// single mesh, so we can just use this. +extern "C" mlir::sdy::TensorShardingAttr hloShardingToTensorShardingAttr( + mlir::MLIRContext *context, const xla::HloSharding *hloSharding, + mlir::StringAttr meshName, mlir::sdy::MeshAttr meshAttr, int64_t rank, + const bool *isClosed, const int64_t *priority) { + const SmallDenseMap deviceIdToMaximalMeshName = + SmallDenseMap(); + mlir::sdy::TensorShardingAttr tensorShardingAttr = + xla::sdy::convertToSdySharding(*hloSharding, meshAttr, + deviceIdToMaximalMeshName, rank, + /*openDims=*/true); + + for (int64_t i = 0; i < rank; i++) { + auto oldDimSharding = tensorShardingAttr.getDimSharding(i); + + std::optional dimPriority; + if (priority[i] > 0) + dimPriority = priority[i]; + + tensorShardingAttr = tensorShardingAttr.replaceDimSharding( + i, mlir::sdy::DimensionShardingAttr::get(oldDimSharding.getContext(), + oldDimSharding.getAxes(), + isClosed[i], dimPriority)); } - return std::make_tuple(layouts.size(), layouts_ptr); -} -extern "C" std::tuple -ifrt_loadedexecutable_hlo_modules(ifrt::LoadedExecutable *executable) { - auto modules = MyValueOrThrow(executable->GetHloModules()); - auto modules_ptr = new xla::HloModule *[modules.size()]; - for (int i = 0; i < modules.size(); i++) { - modules_ptr[i] = modules[i].get(); - } - return std::make_tuple(modules.size(), modules_ptr); + return mlir::sdy::TensorShardingAttr::get( + context, meshName, tensorShardingAttr.getDimShardings(), + tensorShardingAttr.getReplicatedAxes()); } -// TODO xla::ifrt::LoadedExecutable::GetOutputMemoryKinds -// TODO xla::ifrt::LoadedExecutable::GetCostAnalysis +#pragma endregion -// extern "C" ifrt::LoadedExecutable::ExecuteResult* -// ifrt_loadedexecutable_execute(ifrt::LoadedExecutable* executable, -// ifrt::Array** args, size_t args_size, ifrt::Array** results, size_t -// results_size, ifrt::Future<*>** futures, size_t futures_size) { -// std::vector arguments(args, args + args_size); -// std::vector result(results, results + results_size); -// std::vector*> future(futures, futures + futures_size); -// return MyValueOrThrow(executable->Execute(arguments, result, future)); -// } +#pragma region ifrt::LoadedExecutable -extern "C" ifrt::Future<> -ifrt_loadedexecutable_delete(ifrt::LoadedExecutable *executable) { - return executable->Delete(); +extern "C" void ifrt_loaded_executable_dtor(ifrt::LoadedExecutable *exec) { + delete exec; } -extern "C" bool -ifrt_loadedexecutable_is_deleted(ifrt::LoadedExecutable *executable) { - return executable->IsDeleted(); -} +extern "C" void ifrt_loaded_executable_execute( + ifrt::LoadedExecutable *exec, int num_args, + HeldValue> **op_args, + uint8_t *is_arg_donatable, int num_results, + HeldValue> **op_results, uint8_t *futures, + FutureType **status) { + std::vector> args; + for (int i = 0; i < num_args; i++) { + args.emplace_back(op_args[i]->obj()); + } -extern "C" std::tuple -ifrt_loadedexecutable_addressable_devices(ifrt::LoadedExecutable *executable) { - auto devices = executable->addressable_devices(); - return std::make_tuple(devices.size(), devices.data()); -} + ifrt::ExecuteOptions options; + for (size_t i = 0; i < num_args; i++) { + if (!is_arg_donatable[i]) { + options.non_donatable_input_indices.insert(static_cast(i)); + } + } + options.fill_status = true; -// TODO auxiliary functions for xla::ifrt::LoadedExecutable::ExecuteResult -#pragma endregion + auto result = MyValueOrThrow(exec->Execute( + static_cast>>(args), + options, /* devices */ std::nullopt)); -#pragma region xla::ifrt::PjRtLoadedExecutable -// TODO add support for LoadedHostCallback -// TODO there are problems with using `make_shared -// extern "C" ifrt::LoadedExecutable* -// ifrt_pjrt_loadedexecutable_ctor(ifrt::PjRtCompatibleClient* client, -// xla::PjRtLoadedExecutable* pjrt_loaded_executable) { -// auto pjrt_loaded_executable_ptr = -// std::make_shared(*pjrt_loaded_executable); -// return MyValueOrThrow(ifrt::PjRtLoadedExecutable::Create(client, -// pjrt_loaded_executable_ptr, -// std::vector>())).release(); -// } + if (result.outputs.size() != num_results) { + llvm::errs() << "Error: results.size()=" << result.outputs.size() + << " does not match num_results=" << num_results << "\n"; + std::abort(); // Terminate if the number of results is incorrect. + } -// TODO add support for LoadedHostCallback -extern "C" ifrt::LoadedExecutable * -ifrt_pjrt_loadedexecutable_ctor_from_mlir_module( - ifrt::PjRtCompatibleClient *client, mlir::ModuleOp *module, - xla::CompileOptions *compile_options) { - return MyValueOrThrow( - ifrt::PjRtLoadedExecutable::Create( - client, *module, *compile_options, - std::vector>())) - .release(); -} + // there is only 1 status and is valid because we set `options.fill_status = + // true` + *futures = true; + *status = new FutureType(result.status); -extern "C" void -ifrt_pjrt_loadedexecutable_free(ifrt::PjRtLoadedExecutable *executable) { - delete executable; + for (int i = 0; i < num_results; i++) { + op_results[i] = reactant::capture(result.outputs[i]); + } } -extern "C" xla::PjRtLoadedExecutable * -ifrt_pjrt_loadedexecutable_pjrt_loadedexecutable( - ifrt::PjRtLoadedExecutable *executable) { - return executable->pjrt_loaded_executable(); +extern "C" ifrt::Client * +ifrt_loaded_executable_client(ifrt::LoadedExecutable *exec) { + return exec->client(); } -#pragma endregion -#pragma region xla::ifrt::CustomCallProgram -#pragma endregion +extern "C" void +ifrt_loaded_executable_get_parameter_shardings(ifrt::LoadedExecutable *exec, + xla::OpSharding **op_shardings, + int32_t num_op_shardings) { + std::optional> shardings = + exec->GetParameterShardings(); + if (!shardings.has_value()) { + ReactantThrowError( + "No sharding found for the output of the loaded executable"); + } -#pragma region xla::ifrt::HloProgram -extern "C" ifrt::HloProgram *ifrt_hloprogram_ctor() { - return new ifrt::HloProgram(); -} + std::vector hlo_op_shardings = shardings.value(); + if (num_op_shardings != hlo_op_shardings.size()) { + ReactantThrowError(("Expected " + std::to_string(num_op_shardings) + + " shardings, got " + + std::to_string(hlo_op_shardings.size())) + .c_str()); + } -extern "C" ifrt::HloProgram * -ifrt_hloprogram_ctor_with_module(mlir::ModuleOp *module) { - return new ifrt::HloProgram(*module); + for (int32_t i = 0; i < num_op_shardings; i++) { + op_shardings[i] = new xla::OpSharding(hlo_op_shardings[i]); + } } -// extern "C" ifrt::HloProgram* -// ifrt_hloprogram_ctor_with_context_and_module(mlir::MLIRContext* context, -// mlir::ModuleOp* module) { -// auto context_ptr = std::make_unique(*context); -// return new ifrt::HloProgram(std::move(context_ptr), *module); -// } -#pragma endregion +extern "C" void +ifrt_loaded_executable_get_output_shardings(ifrt::LoadedExecutable *exec, + xla::OpSharding **op_shardings, + int32_t num_op_shardings) { + std::optional> shardings = + exec->GetOutputShardings(); + if (!shardings.has_value()) { + ReactantThrowError( + "No sharding found for the output of the loaded executable"); + } -#pragma region xla::ifrt::Compiler -extern "C" ifrt::LoadedExecutable * -ifrt_compiler_compile(ifrt::Compiler *compiler, ifrt::Program *program) { - // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and - // set directly to the default - auto program_ptr = std::make_unique(*program); - auto options = std::make_unique(); - return MyValueOrThrow( - compiler->Compile(std::move(program_ptr), std::move(options))) - .release(); -} + std::vector hlo_op_shardings = shardings.value(); + if (num_op_shardings != hlo_op_shardings.size()) { + ReactantThrowError(("Expected " + std::to_string(num_op_shardings) + + " shardings, got " + + std::to_string(hlo_op_shardings.size())) + .c_str()); + } -extern "C" ifrt::Executable * -ifrt_compiler_compile_with_topology(ifrt::Compiler *compiler, - ifrt::Program *program, - const ifrt::Topology *topology) { - // apparently ifrt::CompileOptions is a legacy artifact so we don't use it and - // set directly to the default - auto options = std::make_unique(); - auto program_ptr = std::make_unique(*program); - auto exec_ptr = - MyValueOrThrow(compiler->Compile(std::move(program_ptr), *topology, - std::move(options))) - .release(); - return exec_ptr; -} - -extern "C" ifrt::LoadedExecutable * -ifrt_compiler_deserialize_loadedexecutable(ifrt::Compiler *compiler, - const char *data) { - // apparently ifrt::DeserializeExecutableOptions is a legacy artifact so we - // don't use it and set directly to the default - auto options = std::make_unique(); - return MyValueOrThrow(compiler->DeserializeLoadedExecutable( - std::string(data), std::move(options))) - .release(); + for (int32_t i = 0; i < num_op_shardings; i++) { + op_shardings[i] = new xla::OpSharding(hlo_op_shardings[i]); + } } -#pragma endregion -#pragma region xla::ifrt::PjRtCompiler -extern "C" ifrt::PjRtCompiler * -ifrt_pjrt_compiler_ctor(ifrt::PjRtClient *client) { - return new ifrt::PjRtCompiler(client); +extern "C" void +ifrt_loaded_executable_get_hlo_modules(ifrt::LoadedExecutable *exec, + void **hlo_modules, int32_t *nmodules) { + auto hlo_modules_vec = MyValueOrThrow(exec->GetHloModules()); + *nmodules = hlo_modules_vec.size(); + for (int32_t i = 0; i < *nmodules; i++) { + hlo_modules[i] = reactant::capture(hlo_modules_vec[i]); + } } -extern "C" void ifrt_pjrt_compiler_free(ifrt::PjRtCompiler *compiler) { - delete compiler; +extern "C" int32_t +ifrt_loaded_executable_num_devices(ifrt::LoadedExecutable *exec) { + return static_cast(exec->num_devices()); } -#pragma endregion #pragma endregion diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 0a512067a9..fdcb11e673 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -54,9 +54,9 @@ cc_toolchain_config( coverage_link_flags = ["--coverage"], cpu = "k8", cxx_builtin_include_directories = [ - "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0", - "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/x86_64-linux-musl", - "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/backward", + "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/GCC_VERSION", + "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/GCC_VERSION/x86_64-linux-musl", + "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/GCC_VERSION/backward", "/opt/x86_64-linux-musl/x86_64-linux-musl/include", "/opt/x86_64-linux-musl/bin/../include/x86_64-unknown-linux-musl/c++/v1", "/opt/x86_64-linux-musl/bin/../include/c++/v1", @@ -149,14 +149,14 @@ cc_toolchain_config( abi_libc_version = "local", abi_version = "local", cxx_builtin_include_directories = [ - "/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include", - "/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed", + "/opt/BB_TARGET/lib/gcc/BB_TARGET/GCC_VERSION/include", + "/opt/BB_TARGET/lib/gcc/BB_TARGET/GCC_VERSION/include-fixed", "/opt/BB_TARGET/BB_TARGET/include", "/opt/BB_TARGET/BB_TARGET/sys-root/usr/include", - "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0", - "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET", - "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward", - "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel" + "/opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION", + "/opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/BB_TARGET", + "/opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/backward", + "/opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/parallel", ], tool_paths = { "ar": "/opt/bin/BB_FULL_TARGET/ar", @@ -193,14 +193,14 @@ cc_toolchain_config( "-Wno-free-nonheap-object", "-fno-omit-frame-pointer", # TODO cxx_builtin_include_directories doesn't seem to be working, so we add the INCLUDE_PATHs manually - "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include", - "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed", + "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/GCC_VERSION/include", + "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/GCC_VERSION/include-fixed", "-isystem /opt/BB_TARGET/BB_TARGET/include", "-isystem /opt/BB_TARGET/BB_TARGET/sys-root/usr/include", - "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0", - "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET", - "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward", - "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel", + "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION", + "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/BB_TARGET", + "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/backward", + "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/parallel", ], opt_compile_flags = [ "-g0", @@ -360,11 +360,13 @@ cc_library( ], ) + [ - "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp", + "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp", + "@enzyme_ad//src/enzyme_ad/jax:gpu.cc", + "@enzyme_ad//src/enzyme_ad/jax:cpu.cc", # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc", # "@xla//xla:xla.pb.cc", "@xla//xla:xla_data.pb.cc", - "@xla//xla/stream_executor:device_description.pb.cc", + # "@xla//xla/stream_executor:device_description.pb.cc", "@xla//xla/service:hlo.pb.cc", # # "@tsl//tsl/protobuf:dnn.pb.cc", #"@tsl//tsl/protobuf:histogram.pb.cc", @@ -383,7 +385,12 @@ cc_library( "-Werror=return-type", "-Werror=unused-result", "-Wno-error=stringop-truncation" + ] + select({ + "@xla//xla/tsl:is_cuda_enabled_and_oss":[ + "-DREACTANT_CUDA=1", ], + "//conditions:default": [], + }), alwayslink = True, linkstatic = True, linkopts = select({ @@ -391,6 +398,8 @@ cc_library( "@bazel_tools//src/conditions:darwin": [ "-Wl,-exported_symbol,_stablehlo*", "-Wl,-exported_symbol,_mlir*", +"-Wl,-exported_symbol,_sdy*", +"-Wl,-exported_symbol,_EnzymeJaXMapSymbol", "-Wl,-exported_symbol,_InitializeLogs", "-Wl,-exported_symbol,_SetLogLevel", "-Wl,-exported_symbol,_SetModuleLogLevel", @@ -407,6 +416,7 @@ cc_library( "-Wl,-exported_symbol,_ClientProcessIndex", "-Wl,-exported_symbol,_ClientGetDevice", "-Wl,-exported_symbol,_ClientGetAddressableDevice", +"-Wl,-exported_symbol,_PjRtDeviceGetAllocatorStats", "-Wl,-exported_symbol,_ExecutableFree", "-Wl,-exported_symbol,_BufferToDevice", "-Wl,-exported_symbol,_BufferToClient", @@ -419,18 +429,70 @@ cc_library( "-Wl,-exported_symbol,_BufferToHost", "-Wl,-exported_symbol,_FreeClient", "-Wl,-exported_symbol,_ClientCompile", +"-Wl,-exported_symbol,_ConvertLLVMStrToMLIR", "-Wl,-exported_symbol,_LinkInModule", "-Wl,-exported_symbol,_FreeFuture", "-Wl,-exported_symbol,_FutureIsReady", "-Wl,-exported_symbol,_FutureAwait", "-Wl,-exported_symbol,_XLAExecute", "-Wl,-exported_symbol,_RegisterDialects", -"-Wl,-exported_symbol,_InitializeRegistryAndPasses", -"-Wl,-exported_symbol,_ifrt_*", +"-Wl,-exported_symbol,_InitializeRegistry", +"-Wl,-exported_symbol,_InitializePasses", "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", -"-Wl,-exported_symbol,_EnzymeGPUCustomCall", +"-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler", "-Wl,-exported_symbol,_ReactantThrowError", +"-Wl,-exported_symbol,_ReactantHandleCuResult", +"-Wl,-exported_symbol,_CreateProfilerSession", +"-Wl,-exported_symbol,_ProfilerSessionCollectData", +"-Wl,-exported_symbol,_ProfilerSessionDelete", +"-Wl,-exported_symbol,_ProfilerServerStart", +"-Wl,-exported_symbol,_ProfilerServerStop", +"-Wl,-exported_symbol,_ProfilerActivityStart", +"-Wl,-exported_symbol,_ProfilerActivityEnd", +"-Wl,-exported_symbol,_ReactantFuncSetArgAttr", +"-Wl,-exported_symbol,_ReactantHermeticCudaGetVersion", +"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion", +"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions", +"-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId", +"-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId", +"-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId", +"-Wl,-exported_symbol,_XLAExecuteSharded", +"-Wl,-exported_symbol,_ClientGetPlatformName", +"-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler", +"-Wl,-exported_symbol,_PjRtLoadedExecutableGetClient", +"-Wl,-exported_symbol,_ReactantFuncSetResultAttr", +"-Wl,-exported_symbol,_BufferShape", +"-Wl,-exported_symbol,_BufferNDimensions", +"-Wl,-exported_symbol,_BufferPrimitiveType", +"-Wl,-exported_symbol,_PjRtLoadedExecutableGetOuputShardings", +"-Wl,-exported_symbol,_PjRtLoadedExecutableGetParameterShardings", +"-Wl,-exported_symbol,_PjRtLoadedExecutableGetHloModules", +"-Wl,-exported_symbol,_HloModuleToString", +"-Wl,-exported_symbol,_FreeHloModule", +"-Wl,-exported_symbol,_PjRtLoadedExecutableNumReplicas", +"-Wl,-exported_symbol,_PjRtLoadedExecutableNumPartitions", +"-Wl,-exported_symbol,_ifrt_*", +"-Wl,-exported_symbol,_reactant_*", +"-Wl,-exported_symbol,_free_op_sharding", +"-Wl,-exported_symbol,_free_hlo_sharding", +"-Wl,-exported_symbol,_free_ifrt_hlo_sharding", +"-Wl,-exported_symbol,_hlo_sharding_from_op_sharding", +"-Wl,-exported_symbol,_hlo_sharding_to_op_sharding", +"-Wl,-exported_symbol,_hlo_sharding_to_string", +"-Wl,-exported_symbol,_DeviceGetKind", +"-Wl,-exported_symbol,_GetDistributedRuntimeClient", +"-Wl,-exported_symbol,_free_distributed_runtime_client", +"-Wl,-exported_symbol,_distributed_runtime_client_connect", +"-Wl,-exported_symbol,_distributed_runtime_client_shutdown", +"-Wl,-exported_symbol,_GetDistributedRuntimeService", +"-Wl,-exported_symbol,_free_distributed_runtime_service", +"-Wl,-exported_symbol,_distributed_runtime_service_shutdown", +"-Wl,-exported_symbol,_ClientGetDevices", +"-Wl,-exported_symbol,_ClientGetAddressableDevices", +"-Wl,-exported_symbol,_hloShardingFromTensorShardingAttr", +"-Wl,-exported_symbol,_op_sharding_*", +"-Wl,-exported_symbol,_hloShardingToTensorShardingAttr", ]}), deps = [ "@enzyme//:EnzymeMLIR", @@ -456,11 +518,11 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:Transforms", - + "@llvm-project//mlir:LLVMIRToLLVMTranslation", "@llvm-project//mlir:LLVMIRToNVVMTranslation", "@llvm-project//mlir:LLVMIRTransforms", - + "@llvm-project//llvm:IRReader", "@llvm-project//llvm:Support", "@llvm-project//llvm:AArch64AsmParser", @@ -469,59 +531,122 @@ cc_library( "@llvm-project//llvm:X86CodeGen", "@enzyme_ad//src/enzyme_ad/jax:TransformOps", "@enzyme_ad//src/enzyme_ad/jax:XLADerivatives", + # "@enzyme_ad//src/enzyme_ad/jax:gpu", + "@xla//xla/ffi/api:ffi", + "@xla//xla/ffi:ffi_api", "@stablehlo//:chlo_ops", "@xla//xla/pjrt:pjrt_api", "@xla//xla/pjrt:pjrt_c_api_client", "@xla//xla/pjrt/cpu:cpu_client", - + "@xla//xla/pjrt/distributed:distributed", + "@xla//xla/pjrt/distributed:client", + "@xla//xla/pjrt/distributed:service", + "@xla//xla/service/spmd/shardy/stablehlo_round_trip:export_shardings", + "@xla//xla/service/spmd/shardy/stablehlo_round_trip:stablehlo_import", + "@xla//xla:xla_proto_cc", "@xla//xla:xla_proto_cc_impl", + "@xla//xla/stream_executor:device_description_proto_cc_impl", + + "@xla//xla/tsl/platform/default:platform_port", "@xla//xla/service:metrics_proto_cc", "@xla//xla/service:metrics_proto_cc_impl", + "@xla//xla/service:custom_call_target_registry", "@xla//xla/service/cpu:cpu_compiler", "@xla//xla/stream_executor/tpu:tpu_on_demand_compiler", "@xla//xla/stream_executor/tpu:tpu_executor", "@xla//xla/stream_executor/tpu:tpu_transfer_manager", - + "@xla//xla/service/cpu:cpu_transfer_manager", "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", - - "@xla//xla/tsl/protobuf:protos_all_cc_impl", + + "@xla//xla/tsl/protobuf:protos_all_cc_impl", "@xla//xla/tsl/framework:allocator_registry_impl", "@xla//xla/pjrt:status_casters", "@xla//xla/python/ifrt:ifrt", "@xla//xla/python/pjrt_ifrt:pjrt_ifrt", + "@xla//xla/python/ifrt_proxy/server:grpc_server", + "@xla//xla/python/ifrt_proxy/client:grpc_client", + "@xla//xla/python/ifrt_proxy/client:registry", + # "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + # "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "@xla//xla/python/ifrt/hlo:hlo_program", + "@xla//xla/python/ifrt/ir:ifrt_ir_program", "@xla//xla/ffi:call_frame", "@com_google_protobuf//:protobuf", - "@xla//xla/tsl/profiler/backends/cpu:annotation_stack_impl", + + "@tsl//tsl/profiler/lib:profiler_session_impl", + "@tsl//tsl/profiler/lib:profiler_factory_impl", + "@tsl//tsl/profiler/lib:profiler_controller", + "@tsl//tsl/profiler/lib:traceme", + "@xla//xla/tsl/profiler/rpc:profiler_server_impl", + "@xla//xla/tsl/profiler/rpc/client:capture_profile", + "@xla//xla/tsl/profiler/rpc/client:profiler_client", + "@xla//xla/tsl/profiler/backends/cpu:annotation_stack_impl", "@xla//xla/tsl/profiler/backends/cpu:traceme_recorder_impl", "@xla//xla/tsl/profiler/utils:time_utils_impl", + "@tsl//tsl/profiler/protobuf:profiler_service_monitor_result_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:profiler_service_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:profile_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:xplane_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:trace_events_proto_cc_impl", + + "@xla//xla/backends/profiler/cpu:host_tracer", + "@xla//xla/backends/profiler/cpu:host_tracer_impl", + "@xla//xla/backends/profiler/cpu:metadata_collector", + "@xla//xla/backends/profiler/cpu:metadata_utils", + "@xla//xla/backends/profiler/tpu:tpu_tracer", + "@xla//xla/python:profiler_utils", + "@xla//xla/backends/cpu/collectives:mpi_collectives", + "@tsl//tsl/platform:env_impl", "@xla//xla/stream_executor:stream_executor_impl", "@xla//xla/mlir/utils:type_util", "@stablehlo//:stablehlo_capi_objects", "@stablehlo//:chlo_capi_objects", + "@shardy//shardy/integrations/c:sdy_capi_objects", "@com_google_absl//absl/hash:hash", "@com_google_absl//absl/log:initialize", "@com_google_absl//absl/log:globals", "@llvm-project//mlir:CAPIIRObjects", + "@llvm-project//mlir:CAPILLVMObjects", + "@jax//jaxlib/mosaic:tpu_dialect_capi_objects", + "@jax//jaxlib/triton:triton_dialect_capi_objects", + "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl", ] + select({ - "@xla//xla/tsl:is_cuda_enabled_and_oss":[ - "@xla//xla/stream_executor/cuda:all_runtime", - "@xla//xla/service/gpu/model:hlo_op_profiles", - "@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl", - "@xla//xla/service/gpu:nvptx_compiler", - "@xla//xla/service/gpu:gpu_transfer_manager", - "@xla//xla/stream_executor:kernel", - ], - "//conditions:default": [], + "@xla//xla/tsl:is_cuda_enabled_and_oss":[ + "@xla//xla/stream_executor/cuda:all_runtime", + "@xla//xla/service/gpu/model:hlo_op_profiles", + "@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl", + "@xla//xla/service/gpu:nvptx_compiler", + "@xla//xla/service/gpu:gpu_transfer_manager", + "@xla//xla/stream_executor:kernel", + "@xla//xla/backends/profiler/gpu:device_tracer", + ], + "//conditions:default": [ + ], }) + if_rocm([ "@xla//xla/service/gpu:amdgpu_compiler", - ]), + "@xla//xla/backends/profiler/gpu:device_tracer", + ]) + select({ + # gloo tcp transport only builds on linux + "@xla//xla/tsl:macos": [ + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + "@gloo//:transport_uv", + ], + "@xla//xla/tsl:windows": [], + "//conditions:default": [ + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + "@gloo//:transport_tcp", + ], + }), ) # cc_shared_library( @@ -603,6 +728,76 @@ gentbl_cc_library( tblgen = "//:mlir-jl-tblgen", ) +gentbl_cc_library( + name = "LlvmJLIncGen", + tbl_outs = [( + ["--generator=jl-op-defs", "--disable-module-wrap=0"], + "Llvm.jl" + ) + ], + td_file = "@llvm-project//mlir:include/mlir/Dialect/LLVMIR/LLVMOps.td", + deps = [ + "@llvm-project//mlir:LLVMOpsTdFiles", + ], + tblgen = "//:mlir-jl-tblgen", +) + +gentbl_cc_library( + name = "MemRefJLIncGen", + tbl_outs = [( + ["--generator=jl-op-defs", "--disable-module-wrap=0"], + "MemRef.jl" + ) + ], + td_file = "@llvm-project//mlir:include/mlir/Dialect/MemRef/IR/MemRefOps.td", + deps = [ + "@llvm-project//mlir:MemRefOpsTdFiles", + ], + tblgen = "//:mlir-jl-tblgen", +) + +gentbl_cc_library( + name = "NvvmIncJLGen", + tbl_outs = [( + ["--generator=jl-op-defs", "--disable-module-wrap=0"], + "Nvvm.jl" + ) + ], + td_file = "@llvm-project//mlir:include/mlir/Dialect/LLVMIR/NVVMOps.td", + deps = [ + "@llvm-project//mlir:NVVMOpsTdFiles", + ], + tblgen = "//:mlir-jl-tblgen", +) + +gentbl_cc_library( + name = "GpuIncJLGen", + tbl_outs = [( + ["--generator=jl-op-defs", "--disable-module-wrap=0"], + "Gpu.jl" + ) + ], + td_file = "@llvm-project//mlir:include/mlir/Dialect/GPU/IR/GPUOps.td", + deps = [ + "@llvm-project//mlir:GPUOpsTdFiles", + ], + tblgen = "//:mlir-jl-tblgen", +) + +gentbl_cc_library( + name = "MPIIncJLGen", + tbl_outs = [( + ["--generator=jl-op-defs", "--disable-module-wrap=0"], + "MPI.jl" + ) + ], + td_file = "@llvm-project//mlir:include/mlir/Dialect/MPI/IR/MPIOps.td", + deps = [ + "@llvm-project//mlir:MPITdFiles", + ], + tblgen = "//:mlir-jl-tblgen", +) + gentbl_cc_library( name = "EnzymeJLIncGen", tbl_outs = [( @@ -632,6 +827,35 @@ gentbl_cc_library( tblgen = "//:mlir-jl-tblgen", ) +gentbl_cc_library( + name = "TPUJLIncGen", + tbl_outs = [( + ["--generator=jl-op-defs", "--disable-module-wrap=0"], + "TPU.jl" + ) + ], + td_file = "@jax//jaxlib/mosaic:dialect/tpu/tpu.td", + deps = [ + "@jax//jaxlib/mosaic:tpu_td_files", + ], + tblgen = "//:mlir-jl-tblgen", +) + +gentbl_cc_library( + name = "TritonJLIncGen", + tbl_outs = [( + ["--generator=jl-op-defs", "--disable-module-wrap=0"], + "Triton.jl" + ) + ], + td_file = "@jax//jaxlib/triton:triton.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@triton//:td_files", + ], + tblgen = "//:mlir-jl-tblgen", +) + gentbl_cc_library( name = "StableHLOJLIncGen", tbl_outs = [( @@ -674,6 +898,21 @@ gentbl_cc_library( tblgen = "//:mlir-jl-tblgen", ) +gentbl_cc_library( + name = "ShardyJLIncGen", + tbl_outs = [( + ["--generator=jl-op-defs", "--disable-module-wrap=0"], + "Shardy.jl" + ) + ], + td_file = "@shardy//shardy/dialect/sdy/ir:ops.td", + deps = [ + "@shardy//shardy/dialect/sdy/ir:sdy_td_files", + ], + tblgen = "//:mlir-jl-tblgen", + includes = ["external/shardy"], +) + genrule( name = "libMLIR_h.jl", tags = [ @@ -693,6 +932,7 @@ genrule( "@llvm-project//mlir:AsyncPassIncGen_filegroup", "@llvm-project//mlir:GPUPassIncGen_filegroup", "@stablehlo//:stablehlo/integrations/c/StablehloAttributes.h", + "@shardy//shardy/integrations/c:attributes.h", "//:Project.toml", "//:Manifest.toml", "//:wrap.toml", @@ -700,5 +940,5 @@ genrule( "//:make.jl" ], outs = ["libMLIR_h.jl"], - cmd = "$$JULIA \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$@\"", + cmd = "$$JULIA \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$@\"", ) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index dc72ecba92..077d017022 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "b6d6563aa3a3050474a4250bf18322f7ebf0b486" +ENZYMEXLA_COMMIT = "52f12204b764d0e61da249083ae1a3273da171b7" ENZYMEXLA_SHA256 = "" http_archive( @@ -51,15 +51,6 @@ load("@enzyme_ad//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT", XLA_PATCHES = XLA_PATCHES + [ """ -sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/backends/cpu/runtime/thunk_executor.h -""", -""" -sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc -""", -""" -sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/tsl/concurrency/async_value_ref.h -""", -""" sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_LINK_H=1\\/HAVE_LINK_H=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl """, """ @@ -94,6 +85,41 @@ LLVM_TARGETS = select({ "//conditions:default": ["AMDGPU", "NVPTX"], }) + ["AArch64", "X86", "ARM"] +# Uncomment these lines to use a custom LLVM commit +# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3" +# LLVM_SHA256 = "" +# http_archive( +# name = "llvm-raw", +# build_file_content = "# empty", +# sha256 = LLVM_SHA256, +# strip_prefix = "llvm-project-" + LLVM_COMMIT, +# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], +# ) +# +# +# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") +# maybe( +# http_archive, +# name = "llvm_zlib", +# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", +# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", +# strip_prefix = "zlib-ng-2.0.7", +# urls = [ +# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", +# ], +# ) +# +# maybe( +# http_archive, +# name = "llvm_zstd", +# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", +# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", +# strip_prefix = "zstd-1.5.2", +# urls = [ +# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" +# ], +# ) + http_archive( name = "jax", sha256 = JAX_SHA256, @@ -201,6 +227,21 @@ xla_workspace0() load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() +load("@jax//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository") +jax_python_wheel_repository( + name = "jax_wheel", + version_key = "_version", + version_source = "@jax//jax:version.py", +) + +load( + "@tsl//third_party/py:python_wheel.bzl", + "python_wheel_version_suffix_repository", +) +python_wheel_version_suffix_repository( + name = "jax_wheel_version_suffix", +) + load( "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index f5eddb6c2e..43589c6f2e 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -1,8 +1,16 @@ +const bazel_cmd = if !isnothing(Sys.which("bazelisk")) + "bazelisk" +elseif !isnothing(Sys.which("bazel")) + "bazel" +else + error("Could not find `bazel` or `bazelisk` in PATH!") +end + function build_file(output_path) file = basename(output_path) run( Cmd( - `bazel build --action_env=JULIA=$(Base.julia_cmd().exec[1]) --action_env=JULIA_DEPOT_PATH=$(Base.DEPOT_PATH) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file`; + `$(bazel_cmd) build --action_env=JULIA=$(Base.julia_cmd().exec[1]) --action_env=JULIA_DEPOT_PATH=$(Base.DEPOT_PATH) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file`; dir=@__DIR__, ), ) @@ -23,6 +31,15 @@ for file in [ "StableHLO.jl", "CHLO.jl", "VHLO.jl", + "Llvm.jl", + "Nvvm.jl", + "Gpu.jl", + "Affine.jl", + "TPU.jl", + "Triton.jl", + "Shardy.jl", + "MPI.jl", + "MemRef.jl", ] build_file(joinpath(src_dir, "mlir", "Dialects", file)) end diff --git a/deps/ReactantExtra/make.jl b/deps/ReactantExtra/make.jl index 17ef9c4f44..b3661f5e01 100644 --- a/deps/ReactantExtra/make.jl +++ b/deps/ReactantExtra/make.jl @@ -18,9 +18,11 @@ let options = deepcopy(options) genarg = first(eachsplit(ARGS[3], " ")) - gen_include_dir = joinpath(splitpath(genarg)[1:(end - 3)]...) + gen_include_dir = joinpath(splitpath(genarg)[1:(end - 4)]...) - hlo_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...) + hlo_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...) + + sdy_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...) append!( args, @@ -33,6 +35,8 @@ let options = deepcopy(options) gen_include_dir, "-I", hlo_include_dir, + "-I", + sdy_include_dir, "-x", "c++", ], @@ -41,6 +45,7 @@ let options = deepcopy(options) headers = [ detect_headers(include_dir, args, Dict(), endswith("Python/Interop.h"))..., detect_headers(hlo_include_dir, args, Dict())..., + detect_headers(sdy_include_dir, args, Dict())..., ] ctx = create_context(headers, args, options) diff --git a/deps/ReactantExtra/tblgen/jl-generators.cc b/deps/ReactantExtra/tblgen/jl-generators.cc index ba2069eed7..4853a3dd1a 100644 --- a/deps/ReactantExtra/tblgen/jl-generators.cc +++ b/deps/ReactantExtra/tblgen/jl-generators.cc @@ -14,10 +14,19 @@ // limitations under the License. #include -#include -#include #include +#include +#include +#include "mlir/TableGen/Argument.h" +#include "mlir/TableGen/Class.h" +#include "mlir/TableGen/CodeGenHelpers.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/Interfaces.h" +#include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/Region.h" +#include "mlir/TableGen/SideEffects.h" +#include "mlir/TableGen/Trait.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" @@ -35,148 +44,140 @@ #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include "mlir/TableGen/Argument.h" -#include "mlir/TableGen/Class.h" -#include "mlir/TableGen/CodeGenHelpers.h" -#include "mlir/TableGen/Format.h" -#include "mlir/TableGen/Interfaces.h" -#include "mlir/TableGen/Operator.h" -#include "mlir/TableGen/Region.h" -#include "mlir/TableGen/SideEffects.h" -#include "mlir/TableGen/Trait.h" - -namespace -{ - - llvm::cl::opt ExplainMissing( - "explain-missing", - llvm::cl::desc("Print the reason for skipping operations from output")); - llvm::cl::opt DialectName( - "dialect-name", llvm::cl::desc("Override the inferred dialect name, used as the name for the generated Julia module."), - llvm::cl::value_desc("dialect")); - - using namespace mlir; - using namespace mlir::tblgen; - - /// Returns true if the SameArgumentAndResultTypes trait can be used to infer - /// result types of the given operation. - static bool hasSameArgumentAndResultTypes(const Operator &op) - { - return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && - op.getNumVariableLengthResults() == 0; - } - /// Returns true if the FirstAttrDerivedResultType trait can be used to infer - /// result types of the given operation. - static bool hasFirstAttrDerivedResultTypes(const Operator &op) - { - return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && - op.getNumVariableLengthResults() == 0; - } +namespace { + +llvm::cl::opt ExplainMissing( + "explain-missing", + llvm::cl::desc("Print the reason for skipping operations from output")); +llvm::cl::opt + DialectName("dialect-name", + llvm::cl::desc("Override the inferred dialect name, used as " + "the name for the generated Julia module."), + llvm::cl::value_desc("dialect")); + +using namespace mlir; +using namespace mlir::tblgen; + +/// Returns true if the SameArgumentAndResultTypes trait can be used to infer +/// result types of the given operation. +static bool hasSameArgumentAndResultTypes(const Operator &op) { + return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && + op.getNumVariableLengthResults() == 0; +} - /// Returns true if the InferTypeOpInterface can be used to infer result types - /// of the given operation. - static bool hasInferTypeInterface(const Operator &op) - { - return op.getTrait("::mlir::InferTypeOpInterface::Trait") && - op.getNumRegions() == 0; - } +/// Returns true if the FirstAttrDerivedResultType trait can be used to infer +/// result types of the given operation. +static bool hasFirstAttrDerivedResultTypes(const Operator &op) { + return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && + op.getNumVariableLengthResults() == 0; +} - /// Returns true if there is a trait or interface that can be used to infer - /// result types of the given operation. - static bool canInferType(const Operator &op) - { - return hasSameArgumentAndResultTypes(op) || - hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); - } +/// Returns true if the InferTypeOpInterface can be used to infer result types +/// of the given operation. +static bool hasInferTypeInterface(const Operator &op) { + return op.getTrait("::mlir::InferTypeOpInterface::Trait") && + op.getNumRegions() == 0; +} - std::string formatDescription(mlir::tblgen::Operator op) - { - std::string description; - description = op.getDescription().str(); - size_t pos = 0; - while (description[pos] == '\n') - ++pos; - size_t leading_spaces = 0; - while (description[pos++] == ' ') - ++leading_spaces; - if (leading_spaces) - { - std::string leading_spaces_str; - for (size_t i = 0; i < leading_spaces; ++i) - leading_spaces_str += "[ ]"; - description = std::regex_replace(description, std::regex("\n" + leading_spaces_str), "\n"); - } - description = std::regex_replace(description, std::regex(R"(\\)"), R"(\\)"); - description = std::regex_replace(description, std::regex("(['\"$])"), "\\$1"); - description = std::regex_replace(description, std::regex("(^|\n)(Example|Syntax):"), "$1# $2"); +/// Returns true if there is a trait or interface that can be used to infer +/// result types of the given operation. +static bool canInferType(const Operator &op) { + return hasSameArgumentAndResultTypes(op) || + hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); +} - // remove trailing whitespaces and newlines - while (std::isspace(description.back())) { - description.pop_back(); - } - return description; +std::string formatDescription(mlir::tblgen::Operator op) { + std::string description; + description = op.getDescription().str(); + size_t pos = 0; + while (description[pos] == '\n') + ++pos; + size_t leading_spaces = 0; + while (description[pos++] == ' ') + ++leading_spaces; + if (leading_spaces) { + std::string leading_spaces_str; + for (size_t i = 0; i < leading_spaces; ++i) + leading_spaces_str += "[ ]"; + description = std::regex_replace( + description, std::regex("\n" + leading_spaces_str), "\n"); } + description = std::regex_replace(description, std::regex(R"(\\)"), R"(\\)"); + description = std::regex_replace(description, std::regex("(['\"$])"), "\\$1"); + description = std::regex_replace( + description, std::regex("(^|\n)(Example|Syntax):"), "$1# $2"); + + // remove trailing whitespaces and newlines + while (std::isspace(description.back())) { + description.pop_back(); + } + return description; +} - std::string getDialectName(llvm::ArrayRef op_defs) { - mlir::tblgen::Operator any_op(op_defs.front()); - assert( - std::all_of(op_defs.begin(), op_defs.end(), [&any_op](llvm::Record* op) { - return mlir::tblgen::Operator(op).getDialectName() == - any_op.getDialectName(); - })); - std::string dialect_name; - if (DialectName.empty()) { - dialect_name = any_op.getDialectName().str(); - } else { - dialect_name = DialectName; - } - return dialect_name; +std::string getDialectName(llvm::ArrayRef op_defs) { + mlir::tblgen::Operator any_op(op_defs.front()); + assert( + std::all_of(op_defs.begin(), op_defs.end(), [&any_op](llvm::Record *op) { + return mlir::tblgen::Operator(op).getDialectName() == + any_op.getDialectName(); + })); + std::string dialect_name; + if (DialectName.empty()) { + dialect_name = any_op.getDialectName().str(); + } else { + dialect_name = DialectName; } + return dialect_name; +} - std::string sanitizeName(std::string name, std::optional modulename = std::nullopt) { - // check if name starts with digit: - if (std::isdigit(name[0])) - { - name = "_" + name; - } - // check if name colides with Julia keywords, generated module name, or "location": - // https://docs.julialang.org/en/v1/base/base/#Keywords - std::vector reservedKeywords = {"include", "location", "baremodule", "begin", "break", "catch", "const", "continue", "do", "else", "elseif", "end", "export", "false", "finally", "for", "function", "global", "if", "import", "let", "local", "macro", "module", "public", "quote", "return", "struct", "true", "try", "using", "while"}; - if (modulename.has_value()) { - reservedKeywords.push_back(modulename.value()); - } - if (std::find(reservedKeywords.begin(), reservedKeywords.end(), name) != reservedKeywords.end()) - { - name = name + "_"; - } - // replace all .'s with _'s - std::replace(name.begin(), name.end(), '.', '_'); - std::replace(name.begin(), name.end(), '-', '_'); - return name; +std::string sanitizeName(std::string name, + std::optional modulename = std::nullopt) { + // check if name starts with digit: + if (std::isdigit(name[0])) { + name = "_" + name; + } + // check if name colides with Julia keywords, generated module name, or + // "location": https://docs.julialang.org/en/v1/base/base/#Keywords + std::vector reservedKeywords = { + "include", "location", "baremodule", "begin", "break", "catch", + "const", "continue", "do", "else", "elseif", "end", + "export", "false", "finally", "for", "function", "global", + "if", "import", "let", "local", "macro", "module", + "public", "quote", "return", "struct", "true", "try", + "using", "while"}; + if (modulename.has_value()) { + reservedKeywords.push_back(modulename.value()); + } + if (std::find(reservedKeywords.begin(), reservedKeywords.end(), name) != + reservedKeywords.end()) { + name = name + "_"; } + // replace all .'s with _'s + std::replace(name.begin(), name.end(), '.', '_'); + std::replace(name.begin(), name.end(), '-', '_'); + return name; +} } // namespace extern bool disableModuleWrap; bool emitOpTableDefs(const llvm::RecordKeeper &recordKeeper, - llvm::raw_ostream &os) -{ - llvm::ArrayRef opdefs = recordKeeper.getAllDerivedDefinitionsIfDefined("Op"); + llvm::raw_ostream &os) { + llvm::ArrayRef opdefs = + recordKeeper.getAllDerivedDefinitionsIfDefined("Op"); const char *moduleTemplate; - if (disableModuleWrap) - { - moduleTemplate = R"(import ...IR: IR, NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType + if (disableModuleWrap) { + moduleTemplate = + R"(import ...IR: IR, NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType import ..Dialects: namedattribute, operandsegmentsizes import ...API {0} )"; - } - else - { + } else { moduleTemplate = R"(module {0} using ...IR import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType @@ -205,20 +206,20 @@ end operands, owned_regions, successors, attributes, results={7}, result_inference={8} - ))"; // 0: results, 1: operands, 2: owned_regions, 3: successors, 4: attributes, 5: optionals, 6: opname, 7: results expression, 8: result_inference + ))"; // 0: results, 1: operands, 2: owned_regions, 3: successors, 4: + // attributes, 5: optionals, 6: opname, 7: results expression, 8: + // result_inference std::string modulecontents = ""; std::string modulename; - if (!DialectName.empty()) - { + if (!DialectName.empty()) { modulename = DialectName; } else { modulename = getDialectName(opdefs); } - for (const auto *def : opdefs) - { + for (const auto *def : opdefs) { mlir::tblgen::Operator op(*def); std::string operandarguments = ""; @@ -226,24 +227,26 @@ end std::string optionals = ""; auto opname = op.getOperationName(); - auto functionname = opname.substr(op.getDialectName().str().length() + 1); // get rid of "dialect." prefix. + auto functionname = opname.substr(op.getDialectName().str().length() + + 1); // get rid of "dialect." prefix. functionname = sanitizeName(functionname, modulename); std::string description = ""; - if (op.hasDescription()) - { - description = "\"\"\"\n`"+functionname+"`\n"+formatDescription(op)+"\n\"\"\""; + if (op.hasDescription()) { + description = "\"\"\"\n`" + functionname + "`\n" + formatDescription(op) + + "\n\"\"\""; } bool inferrable = canInferType(op); - bool alreadykeyword = false; // set to true when first optional argument is encountered. This is used to insert a single semicolon (;) instead of a comma (,) as separator between positional and keyword arguments. - for (int i = 0; i < op.getNumOperands(); i++) - { + bool alreadykeyword = + false; // set to true when first optional argument is encountered. This + // is used to insert a single semicolon (;) instead of a comma + // (,) as separator between positional and keyword arguments. + for (int i = 0; i < op.getNumOperands(); i++) { const auto &named_operand = op.getOperand(i); std::string defaultvalue = ""; std::string operandname = named_operand.name.str(); - if (operandname.empty()) - { + if (operandname.empty()) { operandname = "operand_" + std::to_string(i); } operandname = sanitizeName(operandname); @@ -253,14 +256,12 @@ end bool optional = named_operand.isOptional(); bool variadic = named_operand.isVariadic(); - if (variadic) - { + if (variadic) { type = "Vector{" + type + "}"; } std::string separator = ", "; - if (optional) - { + if (optional) { optionals += llvm::formatv(R"(!isnothing({0}) && push!(operands, {0}{1}) )", operandname, (variadic ? "..." : "")); @@ -270,12 +271,11 @@ end if (!alreadykeyword) { alreadykeyword = true; separator = "; "; - } - } - else - { + } + } else { operandcontainer += operandname + (variadic ? "..." : "") + ", "; - separator = (!alreadykeyword && i == op.getNumOperands() - 1) ? "; " : ", "; + separator = + (!alreadykeyword && i == op.getNumOperands() - 1) ? "; " : ", "; } operandarguments += operandname + defaultvalue + "::" + type + separator; @@ -284,38 +284,35 @@ end operandarguments = "; "; } - if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) - { + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { std::string operandsegmentsizes = ""; - for (int i = 0; i < op.getNumOperands(); i++) - { + for (int i = 0; i < op.getNumOperands(); i++) { const auto &named_operand = op.getOperand(i); std::string operandname = named_operand.name.str(); - if (operandname.empty()) - { + if (operandname.empty()) { operandname = "operand_" + std::to_string(i); } - if (named_operand.isOptional()) - { + if (named_operand.isOptional()) { operandsegmentsizes += "(" + operandname + "==nothing) ? 0 : 1"; continue; } - operandsegmentsizes += named_operand.isVariadic() ? "length(" + operandname + "), " : "1, "; + operandsegmentsizes += named_operand.isVariadic() + ? "length(" + operandname + "), " + : "1, "; } - optionals += llvm::formatv(R"(push!(attributes, operandsegmentsizes([{0}])) + optionals += + llvm::formatv(R"(push!(attributes, operandsegmentsizes([{0}])) )", - operandsegmentsizes); + operandsegmentsizes); } std::string resultarguments = ""; std::string resultcontainer = ""; - for (int i = 0; i < op.getNumResults(); i++) - { + for (int i = 0; i < op.getNumResults(); i++) { const auto &named_result = op.getResult(i); std::string defaultvalue = ""; std::string resultname = named_result.name.str(); - if (resultname.empty()) - { + if (resultname.empty()) { resultname = "result_" + std::to_string(i); } resultname = sanitizeName(resultname); @@ -324,33 +321,32 @@ end bool optional = named_result.isOptional() || inferrable; bool variadic = named_result.isVariadic(); - if (variadic) - { + if (variadic) { type = "Vector{" + type + "}"; } - if (optional) - { - optionals += llvm::formatv(R"(!isnothing({0}) && push!(op_ty_results, {0}{1}) + if (optional) { + optionals += + llvm::formatv(R"(!isnothing({0}) && push!(op_ty_results, {0}{1}) )", - resultname, (variadic ? "..." : "")); + resultname, (variadic ? "..." : "")); type = "Union{Nothing, " + type + "}"; defaultvalue = "=nothing"; - } - else - { + } else { resultcontainer += resultname + (variadic ? "..." : "") + ", "; } resultarguments += resultname + defaultvalue + "::" + type + ", "; } - std::string resultsexpression = (inferrable ? "(length(op_ty_results) == 0 ? nothing : op_ty_results)" : "op_ty_results"); - std::string resultinference = (inferrable ? "(length(op_ty_results) == 0 ? true : false)" : "false"); + std::string resultsexpression = + (inferrable ? "(length(op_ty_results) == 0 ? nothing : op_ty_results)" + : "op_ty_results"); + std::string resultinference = + (inferrable ? "(length(op_ty_results) == 0 ? true : false)" : "false"); std::string attributearguments = ""; std::string attributecontainer = ""; - for (int i = 0; i < op.getNumAttributes(); i++) - { + for (int i = 0; i < op.getNumAttributes(); i++) { const auto &named_attr = op.getAttribute(i); // Derived attributes are never materialized and don't have to be @@ -360,34 +356,33 @@ end std::string defaultvalue = ""; std::string attributename = named_attr.name.str(); - assert(!attributename.empty() && "expected NamedAttribute to have a name"); + assert(!attributename.empty() && + "expected NamedAttribute to have a name"); std::string sanitizedname = sanitizeName(attributename); - bool optional = named_attr.attr.isOptional() || named_attr.attr.hasDefaultValue(); + bool optional = + named_attr.attr.isOptional() || named_attr.attr.hasDefaultValue(); - if (optional) - { - optionals += llvm::formatv(R"(!isnothing({0}) && push!(attributes, namedattribute("{0}", {1})) + if (optional) { + optionals += llvm::formatv( + R"(!isnothing({0}) && push!(attributes, namedattribute("{0}", {1})) )", - attributename, sanitizedname); + attributename, sanitizedname); defaultvalue = "=nothing"; - } - else - { - attributecontainer += "namedattribute(\"" + attributename + "\", " + sanitizedname + "), "; + } else { + attributecontainer += "namedattribute(\"" + attributename + "\", " + + sanitizedname + "), "; } attributearguments += sanitizedname + defaultvalue + ", "; } std::string regionarguments = ""; std::string regioncontainer = ""; - for (size_t i = 0; i < op.getNumRegions(); i++) - { + for (size_t i = 0; i < op.getNumRegions(); i++) { const auto &named_region = op.getRegion(i); std::string defaultvalue = ""; std::string regionname = named_region.name.str(); - if (regionname.empty()) - { + if (regionname.empty()) { regionname = "region_" + std::to_string(i); } regionname = sanitizeName(regionname); @@ -395,8 +390,7 @@ end bool variadic = named_region.isVariadic(); - if (variadic) - { + if (variadic) { type = "Vector{" + type + "}"; } @@ -406,21 +400,18 @@ end std::string successorarguments = ""; std::string successorcontainer = ""; - for (size_t i = 0; i < op.getNumSuccessors(); i++) - { + for (size_t i = 0; i < op.getNumSuccessors(); i++) { const auto &named_successor = op.getSuccessor(i); std::string defaultvalue = ""; std::string successorname = named_successor.name.str(); - if (successorname.empty()) - { + if (successorname.empty()) { successorname = "successor_" + std::to_string(i); } successorname = sanitizeName(successorname); std::string type = "Block"; bool variadic = named_successor.isVariadic(); - if (variadic) - { + if (variadic) { type = "Vector{" + type + "}"; } @@ -428,18 +419,21 @@ end successorarguments += successorname + defaultvalue + "::" + type + ", "; } - std::string arguments = operandarguments + resultarguments + attributearguments + regionarguments + successorarguments; - std::string functionbody = llvm::formatv(functionbodytemplate, resultcontainer, operandcontainer, regioncontainer, successorcontainer, attributecontainer, optionals, opname, resultsexpression, resultinference); + std::string arguments = operandarguments + resultarguments + + attributearguments + regionarguments + + successorarguments; + std::string functionbody = + llvm::formatv(functionbodytemplate, resultcontainer, operandcontainer, + regioncontainer, successorcontainer, attributecontainer, + optionals, opname, resultsexpression, resultinference); - modulecontents += llvm::formatv(functiontemplate, functionname, arguments, functionbody, description); + modulecontents += llvm::formatv(functiontemplate, functionname, arguments, + functionbody, description); } - if (disableModuleWrap) - { + if (disableModuleWrap) { os << llvm::formatv(moduleTemplate, modulecontents); - } - else - { + } else { os << llvm::formatv(moduleTemplate, modulename, modulecontents); } diff --git a/deps/ReactantExtra/tblgen/mlir-jl-tblgen.cc b/deps/ReactantExtra/tblgen/mlir-jl-tblgen.cc index 7736b63e1b..6909350d9b 100644 --- a/deps/ReactantExtra/tblgen/mlir-jl-tblgen.cc +++ b/deps/ReactantExtra/tblgen/mlir-jl-tblgen.cc @@ -26,30 +26,33 @@ using namespace llvm; -using generator_function = bool(const llvm::RecordKeeper& recordKeeper, - llvm::raw_ostream& os); +using generator_function = bool(const llvm::RecordKeeper &recordKeeper, + llvm::raw_ostream &os); struct GeneratorInfo { - const char* name; - generator_function* generator; + const char *name; + generator_function *generator; }; extern generator_function emitOpTableDefs; extern generator_function emitTestTableDefs; -static std::array generators {{ - {"jl-op-defs", emitOpTableDefs}, +static std::array generators{{ + {"jl-op-defs", emitOpTableDefs}, }}; -generator_function* generator; +generator_function *generator; bool disableModuleWrap; int main(int argc, char **argv) { llvm::InitLLVM y(argc, argv); - llvm::cl::opt generatorOpt("generator", llvm::cl::desc("Generator to run"), cl::Required); - llvm::cl::opt disableModuleWrapOpt("disable-module-wrap", llvm::cl::desc("Disable module wrap"), cl::init(false)); + llvm::cl::opt generatorOpt( + "generator", llvm::cl::desc("Generator to run"), cl::Required); + llvm::cl::opt disableModuleWrapOpt( + "disable-module-wrap", llvm::cl::desc("Disable module wrap"), + cl::init(false)); cl::ParseCommandLineOptions(argc, argv); - for (const auto& spec : generators) { + for (const auto &spec : generators) { if (generatorOpt == spec.name) { generator = spec.generator; break; @@ -61,7 +64,8 @@ int main(int argc, char **argv) { } disableModuleWrap = disableModuleWrapOpt; - return TableGenMain(argv[0], [](raw_ostream& os, const RecordKeeper &records) { - return generator(records, os); - }); -} \ No newline at end of file + return TableGenMain(argv[0], + [](raw_ostream &os, const RecordKeeper &records) { + return generator(records, os); + }); +} diff --git a/deps/ReactantExtra/workspace.bzl b/deps/ReactantExtra/workspace.bzl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/deps/build_local.jl b/deps/build_local.jl index 4138d2b6c6..4614ec97ea 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -1,30 +1,64 @@ # Invoke with -# `julia --project=deps deps/build_local.jl [dbg/opt] [auto/cpu/cuda]` +# `julia --project=deps deps/build_local.jl [--debug] [--backend=auto/cpu/cuda]` # the pre-built ReactantExtra_jll might not be loadable on this platform Reactant_jll = Base.UUID("0192cb87-2b54-54ad-80e0-3be72ad8a3c0") -using Pkg, Scratch, Preferences, Libdl +using ArgParse + +s = ArgParseSettings() +#! format: off +@add_arg_table! s begin + "--debug" + help = "Build with debug mode (-c dbg)." + action = :store_true + "--backend" + help = "Build with the specified backend (auto, cpu, cuda)." + default = "auto" + arg_type = String + "--gcc_host_compiler_path" + help = "Path to the gcc host compiler." + default = "/usr/bin/gcc" + arg_type = String + "--cc" + default = "/usr/bin/cc" + arg_type = String + "--hermetic_python_version" + help = "Hermetic Python version." + default = "3.10" + arg_type = String + "--jobs" + help = "Number of parallel jobs." + default = Sys.CPU_THREADS + arg_type = Int + "--copt" + help = "Options to be passed to the C compiler. Can be used multiple times." + action = :append_arg + arg_type = String + "--cxxopt" + help = "Options to be passed to the C++ compiler. Can be used multiple times." + action = :append_arg + arg_type = String + "--extraopt" + help = "Extra options to be passed to Bazel. Can be used multiple times." + action = :append_arg + arg_type = String + "--color" + help = "Set to `yes` to enable color output, or `no` to disable it. Defaults to same color setting as the Julia process." + default = something(Base.have_color, false) ? "yes" : "no" + arg_type = String +end +#! format: on +parsed_args = parse_args(ARGS, s) -# 1. Get a scratch directory -scratch_dir = get_scratch!(Reactant_jll, "build") -isdir(scratch_dir) && rm(scratch_dir; recursive=true) +println("Parsed args:") +for (k, v) in parsed_args + println(" $k = $v") +end +println() source_dir = joinpath(@__DIR__, "ReactantExtra") -# 2. Ensure that an appropriate LLVM_full_jll is installed -Pkg.activate(; temp=true) - -# Build! -@info "Building" source_dir scratch_dir -run(`mkdir -p $(scratch_dir)`) -run( - Cmd( - `$(Base.julia_cmd().exec[1]) --project=. -e "using Pkg; Pkg.instantiate()"`; - dir=source_dir, - ), -) - #--repo_env TF_NEED_ROCM=1 #--define=using_rocm=true --define=using_rocm_hipcc=true #--action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030" @@ -41,27 +75,10 @@ run( # --@local_config_cuda//:cuda_compiler=nvcc # --crosstool_top="@local_config_cuda//crosstool:toolchain" -build_kind = if length(ARGS) ≥ 1 - kind = ARGS[1] - if kind ∉ ("dbg", "opt") - error("Invalid build kind $(kind). Valid options are 'dbg' and 'opt'") - end - kind -else - "dbg" -end - -@info "Building JLL with -c $(build_kind)" +build_kind = parsed_args["debug"] ? "dbg" : "opt" -build_backend = if length(ARGS) ≥ 2 - backend = ARGS[2] - if backend ∉ ("auto", "cpu", "cuda") - error("Invalid build backend $(backend). Valid options are 'auto', 'cpu', and 'cuda'") - end - backend -else - "auto" -end +build_backend = parsed_args["backend"] +@assert build_backend in ("auto", "cpu", "cuda") if build_backend == "auto" build_backend = try @@ -78,46 +95,102 @@ elseif build_backend == "cpu" "" end -@info "Building JLL with backend $(build_backend)" - -if isempty(arg) - run( - Cmd( - `bazel build -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1]) - --repo_env HERMETIC_PYTHON_VERSION="3.10" - --check_visibility=false --verbose_failures :libReactantExtra.so`; - dir=source_dir, - ), - ) +bazel_cmd = if !isnothing(Sys.which("bazelisk")) + "bazelisk" +elseif !isnothing(Sys.which("bazel")) + "bazel" else - run( - Cmd( - `bazel build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1]) - --repo_env HERMETIC_PYTHON_VERSION="3.10" - --check_visibility=false --verbose_failures :libReactantExtra.so`; - dir=source_dir, - ), + error("Could not find `bazel` or `bazelisk` in PATH!") +end + +@info "Building JLL with $(bazel_cmd)" + +gcc_host_compiler_path = parsed_args["gcc_host_compiler_path"] +cc = parsed_args["cc"] +hermetic_python_version = parsed_args["hermetic_python_version"] + +# Try to guess if `cc` is GCC and get its version number. +cc_is_gcc, gcc_version = let + io = IOBuffer() + run(pipeline(ignorestatus(`$(cc) --version`); stdout=io)) + version_string = String(take!(io)) + # Detecing GCC is hard, the name "gcc" may not appear anywhere in the + # version string, but on the second line there should be FSF. + m = match( + r"\([^)]+\) (\d+\.\d+\.\d+).*\n.*Free Software Foundation, Inc\.", + version_string, ) + if !isnothing(m) + true, VersionNumber(m[1]) + else + false, v"0" + end end -# env=Dict("HOME"=>ENV["HOME"], "PATH"=>joinpath(source_dir, "..")*":"*ENV["PATH"]))) - -run(Cmd(`rm -f libReactantExtra.dylib`; dir=joinpath(source_dir, "bazel-bin"))) -run( - Cmd( - `ln -s libReactantExtra.so libReactantExtra.dylib`; - dir=joinpath(source_dir, "bazel-bin"), - ), -) + +build_cmd_list = [bazel_cmd, "build"] +!isempty(arg) && push!(build_cmd_list, arg) +append!(build_cmd_list, ["-c", "$(build_kind)"]) +push!(build_cmd_list, "--action_env=JULIA=$(Base.julia_cmd().exec[1])") +push!(build_cmd_list, "--repo_env=HERMETIC_PYTHON_VERSION=$(hermetic_python_version)") +push!(build_cmd_list, "--repo_env=GCC_HOST_COMPILER_PATH=$(gcc_host_compiler_path)") +push!(build_cmd_list, "--repo_env=CC=$(cc)") +push!(build_cmd_list, "--check_visibility=false") +push!(build_cmd_list, "--verbose_failures") +push!(build_cmd_list, "--jobs=$(parsed_args["jobs"])") +for opt in parsed_args["copt"] + push!(build_cmd_list, "--copt=$(opt)") +end +for opt in parsed_args["cxxopt"] + push!(build_cmd_list, "--cxxopt=$(opt)") +end +for opt in parsed_args["extraopt"] + push!(build_cmd_list, opt) +end +# Some versions of GCC can't deal with some components of XLA, disable them if necessary. +if cc_is_gcc && build_backend == "cuda" + arch = Base.BinaryPlatforms.arch(Base.BinaryPlatforms.HostPlatform()) + if arch == "x86_64" + if gcc_version < v"13" + push!(build_cmd_list, "--define=xnn_enable_avxvnniint8=false") + end + if gcc_version < v"12" + push!(build_cmd_list, "--define=xnn_enable_avx512fp16=false") + end + end +end +push!(build_cmd_list, "--color=$(parsed_args["color"])") +push!(build_cmd_list, ":libReactantExtra.so") + +run(Cmd(Cmd(build_cmd_list); dir=source_dir)) # Discover built libraries built_libs = filter(readdir(joinpath(source_dir, "bazel-bin"))) do file - endswith(file, "Extra.$(Libdl.dlext)") && startswith(file, "lib") + endswith(file, "Extra.so") && startswith(file, "lib") end lib_path = joinpath(source_dir, "bazel-bin", only(built_libs)) isfile(lib_path) || error("Could not find library $lib_path in build directory") -# Tell ReactReactantExtra_jllant_jll to load our library instead of the default artifact one +if build_backend == "cuda" + if !Base.Filesystem.ispath(joinpath(source_dir, "bazel-bin", "cuda", "bin", "ptxas")) + Base.Filesystem.mkpath(joinpath(source_dir, "bazel-bin", "cuda", "bin")) + Base.Filesystem.symlink( + joinpath( + source_dir, + "bazel-bin", + "libReactantExtra.so.runfiles", + "cuda_nvcc", + "bin", + "ptxas", + ), + joinpath(source_dir, "bazel-bin", "cuda", "bin", "ptxas"), + ) + end +end + +# Tell ReactantExtra_jll to load our library instead of the default artifact one +using Preferences + set_preferences!( joinpath(dirname(@__DIR__), "LocalPreferences.toml"), "Reactant_jll", diff --git a/deps/clang b/deps/clang deleted file mode 100755 index 77df2a34c4..0000000000 --- a/deps/clang +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -/home/wmoses/llvms/llvm16/install/bin/clang -I/usr/include/x86_64-linux-gnu/c++/11 -L/home/wmoses/llvms/llvm16/build/lib/x86_64-unknown-linux-gnu -stdlib=libc++ -v "$@" diff --git a/deps/clang++ b/deps/clang++ deleted file mode 100755 index 25b16f719d..0000000000 --- a/deps/clang++ +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -/home/wmoses/llvms/llvm16/build/bin/clang++ -I/usr/include/x86_64-linux-gnu/c++/11 -I/usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -L/usr/lib/x86_64-linux-gnu "$@" diff --git a/deps/gcc b/deps/gcc deleted file mode 100755 index d92c9c10dc..0000000000 --- a/deps/gcc +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -# /usr/local/cuda/bin/nvcc "$@" -/home/wmoses/llvms/llvm16/install/bin/clang -Xclang -fcuda-allow-variadic-functions -I/usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -Wno-unused-command-line-argument -L/usr/lib/gcc/x86_64-linux-gnu/11 -static-libstdc++ "$@" || /home/wmoses/llvms/llvm16/install/bin/clang -Xclang -fcuda-allow-variadic-functions -I/usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -Wno-unused-command-line-argument -L/usr/lib/gcc/x86_64-linux-gnu/11 -static-libstdc++ -g0 "$@" -g0 diff --git a/docs/Project.toml b/docs/Project.toml index c7e696f379..9d0c7ba803 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,7 +2,13 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Documenter = "1.4.1" + +[sources] +Reactant = {path = ".."} +ReactantCore = {path = "../lib/ReactantCore"} diff --git a/docs/make.jl b/docs/make.jl index fcbaca60ef..fc1ac82e0f 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,11 +1,13 @@ -pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) -pushfirst!(LOAD_PATH, joinpath(@__DIR__, "../lib/ReactantCore/")) - using Reactant, ReactantCore using Documenter, DocumenterVitepress DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true) +# Helper functions +function first_letter_uppercase(str) + return uppercase(str[1]) * str[2:end] +end + # Generate examples using Literate @@ -26,21 +28,24 @@ examples = [ pages = [ "Reactant.jl" => "index.md", - "Introduction" => ["Getting Started" => "introduction/index.md"], - "Tutorials" => ["Overview" => "tutorials/index.md"], + "Introduction" => [ + "Getting Started" => "introduction/index.md", + "Configuration" => "introduction/configuration.md", + ], + "Tutorials" => + ["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"], "API Reference" => [ "Reactant API" => "api/api.md", "Ops" => "api/ops.md", - "Dialects" => [ - "ArithOps" => "api/arith.md", - "Affine" => "api/affine.md", - "Builtin" => "api/builtin.md", - "Chlo" => "api/chlo.md", - "Enzyme" => "api/enzyme.md", - "Func" => "api/func.md", - "StableHLO" => "api/stablehlo.md", - "VHLO" => "api/vhlo.md", - ], + "Dialects" => sort!( + [ + first_letter_uppercase(first(splitext(basename(file)))) => + joinpath("api/dialects", file) for + file in readdir(joinpath(@__DIR__, "src/api/dialects")) if + splitext(file)[2] == ".md" + ]; + by=first, + ), "MLIR API" => "api/mlirc.md", "XLA" => "api/xla.md", "Internal API" => "api/internal.md", @@ -55,14 +60,13 @@ makedocs(; Reactant.MLIR, Reactant.MLIR.API, Reactant.MLIR.IR, - Reactant.MLIR.Dialects.chlo, - Reactant.MLIR.Dialects.vhlo, - Reactant.MLIR.Dialects.stablehlo, - Reactant.MLIR.Dialects.enzyme, - Reactant.MLIR.Dialects.arith, - Reactant.MLIR.Dialects.func, - Reactant.MLIR.Dialects.affine, - Reactant.MLIR.Dialects.builtin, + filter( + Base.Fix2(isa, Module), + [ + getproperty(Reactant.MLIR.Dialects, x) for + x in names(Reactant.MLIR.Dialects; all=true) if x != :Dialects + ], + )..., ], authors="William Moses , Valentin Churavy ", sitename="Reactant.jl", diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 1dc25f2ad5..74e6d4ec59 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -51,25 +51,46 @@ export default defineConfig({ }, nav: [ { text: "Home", link: "/" }, - { text: "Getting Started", link: "/introduction" }, + { text: "Getting Started", + items: [ + { text: "Introduction", link: "/introduction" }, + { text: "Configuration", link: "/introduction/configuration" }, + ], + }, { text: "Benchmarks", link: "https://enzymead.github.io/Reactant.jl/benchmarks/" }, - { text: "Tutorials", link: "/tutorials/" }, + { + text: "Tutorials", + items: [ + {text: "Overview", link: "/tutorials/"}, + {text: "Profiling", link: "/tutorials/profiling"}, + ], + }, { text: "API", items: [ { text: "Core Reactant API", link: "/api/api" }, + { text: "Sharding", link: "/api/sharding" }, { text: "Ops", link: "/api/ops" }, { text: "MLIR Dialects", items: [ - { text: "ArithOps", link: "/api/arith" }, - { text: "Affine", link: "/api/affine" }, - { text: "Builtin", link: "/api/builtin" }, - { text: "Chlo", link: "/api/chlo" }, - { text: "Enzyme", link: "/api/enzyme" }, - { text: "Func", link: "/api/func" }, - { text: "StableHLO", link: "/api/stablehlo" }, - { text: "VHLO", link: "/api/vhlo" }, + { text: "ArithOps", link: "/api/dialects/arith" }, + { text: "Affine", link: "/api/dialects/affine" }, + { text: "Builtin", link: "/api/dialects/builtin" }, + { text: "Chlo", link: "/api/dialects/chlo" }, + { text: "Enzyme", link: "/api/dialects/enzyme" }, + { text: "EnzymeXLA", link: "/api/dialects/enzymexla" }, + { text: "Func", link: "/api/dialects/func" }, + { text: "GPU", link: "/api/dialects/gpu" }, + { text: "LLVM", link: "/api/dialects/llvm" }, + { text: "MPI", link: "/api/dialects/mpi" }, + { text: "MemRef", link: "/api/dialects/memref" }, + { text: "NVVM", link: "/api/dialects/nvvm" }, + { text: "Shardy", link: "/api/dialects/shardy" }, + { text: "StableHLO", link: "/api/dialects/stablehlo" }, + { text: "Triton", link: "/api/dialects/triton" }, + { text: "TPU", link: "/api/dialects/tpu" }, + { text: "VHLO", link: "/api/dialects/vhlo" }, ], }, { @@ -88,11 +109,11 @@ export default defineConfig({ ], sidebar: { "/introduction/": { - // @ts-ignore text: "Getting Started", collapsed: false, items: [ { text: "Introduction", link: "/introduction" }, + { text: "Configuration", link: "/introduction/configuration" }, ], }, "/tutorials/": { @@ -100,6 +121,7 @@ export default defineConfig({ collapsed: false, items: [ { text: "Overview", link: "/tutorials/" }, + { text: "Profiling", link: "/tutorials/profiling" }, ], }, "/api/": { @@ -110,19 +132,29 @@ export default defineConfig({ text: "Reactant API", link: "/api/api", }, + { text: "Sharding", link: "/api/sharding" }, { text: "Ops", link: "/api/ops" }, { text: "MLIR Dialects", collapsed: false, items: [ - { text: "ArithOps", link: "/api/arith" }, - { text: "Affine", link: "/api/affine" }, - { text: "Builtin", link: "/api/builtin" }, - { text: "Chlo", link: "/api/chlo" }, - { text: "Enzyme", link: "/api/enzyme" }, - { text: "Func", link: "/api/func" }, - { text: "StableHLO", link: "/api/stablehlo" }, - { text: "VHLO", link: "/api/vhlo" }, + { text: "ArithOps", link: "/api/dialects/arith" }, + { text: "Affine", link: "/api/dialects/affine" }, + { text: "Builtin", link: "/api/dialects/builtin" }, + { text: "Chlo", link: "/api/dialects/chlo" }, + { text: "Enzyme", link: "/api/dialects/enzyme" }, + { text: "EnzymeXLA", link: "/api/dialects/enzymexla" }, + { text: "Func", link: "/api/dialects/func" }, + { text: "GPU", link: "/api/dialects/gpu" }, + { text: "LLVM", link: "/api/dialects/llvm" }, + { text: "MPI", link: "/api/dialects/mpi" }, + { text: "MemRef", link: "/api/dialects/memref" }, + { text: "NVVM", link: "/api/dialects/nvvm" }, + { text: "Shardy", link: "/api/dialects/shardy" }, + { text: "StableHLO", link: "/api/dialects/stablehlo" }, + { text: "Triton", link: "/api/dialects/triton" }, + { text: "TPU", link: "/api/dialects/tpu" }, + { text: "VHLO", link: "/api/dialects/vhlo" }, ], }, { diff --git a/docs/src/api/api.md b/docs/src/api/api.md index 93c37b1793..d8746bddde 100644 --- a/docs/src/api/api.md +++ b/docs/src/api/api.md @@ -13,6 +13,10 @@ Reactant.@jit ## ReactantCore API +```@docs +within_compile +``` + ```@docs @trace ``` @@ -21,20 +25,24 @@ Reactant.@jit ```@docs @code_hlo +@code_mhlo +@code_xla ``` -```@raw html -
-``` +## Profile XLA -# Internal Functionality +Reactant can hook into XLA's profiler to generate compilation and execution traces. +See the [profiling tutorial](@ref profiling) for more details. -!!! danger "Private" +```@docs +Reactant.Profiler.with_profiler +Reactant.Profiler.annotate +Reactant.Profiler.@annotate +``` - These functions are not part of the public API and are subject to change at any time. +## Devices ```@docs -Reactant.Compiler.codegen_unflatten! -Reactant.Compiler.codegen_flatten! -Reactant.Compiler.codegen_xla_call +Reactant.devices +Reactant.addressable_devices ``` diff --git a/docs/src/api/affine.md b/docs/src/api/dialects/affine.md similarity index 100% rename from docs/src/api/affine.md rename to docs/src/api/dialects/affine.md diff --git a/docs/src/api/arith.md b/docs/src/api/dialects/arith.md similarity index 100% rename from docs/src/api/arith.md rename to docs/src/api/dialects/arith.md diff --git a/docs/src/api/builtin.md b/docs/src/api/dialects/builtin.md similarity index 100% rename from docs/src/api/builtin.md rename to docs/src/api/dialects/builtin.md diff --git a/docs/src/api/chlo.md b/docs/src/api/dialects/chlo.md similarity index 100% rename from docs/src/api/chlo.md rename to docs/src/api/dialects/chlo.md diff --git a/docs/src/api/enzyme.md b/docs/src/api/dialects/enzyme.md similarity index 100% rename from docs/src/api/enzyme.md rename to docs/src/api/dialects/enzyme.md diff --git a/docs/src/api/dialects/enzymexla.md b/docs/src/api/dialects/enzymexla.md new file mode 100644 index 0000000000..ed79c4fdee --- /dev/null +++ b/docs/src/api/dialects/enzymexla.md @@ -0,0 +1,9 @@ +```@meta +CollapsedDocStrings = true +``` + +# EnzymeXLA Dialect + +```@autodocs +Modules = [Reactant.MLIR.Dialects.enzymexla] +``` diff --git a/docs/src/api/func.md b/docs/src/api/dialects/func.md similarity index 100% rename from docs/src/api/func.md rename to docs/src/api/dialects/func.md diff --git a/docs/src/api/dialects/gpu.md b/docs/src/api/dialects/gpu.md new file mode 100644 index 0000000000..9cdf91aac6 --- /dev/null +++ b/docs/src/api/dialects/gpu.md @@ -0,0 +1,12 @@ +```@meta +CollapsedDocStrings = true +``` + +# GPU Dialect + +Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/GPU/) for +more details. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.gpu] +``` diff --git a/docs/src/api/dialects/llvm.md b/docs/src/api/dialects/llvm.md new file mode 100644 index 0000000000..48a715429b --- /dev/null +++ b/docs/src/api/dialects/llvm.md @@ -0,0 +1,12 @@ +```@meta +CollapsedDocStrings = true +``` + +# LLVM Dialect + +Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/LLVM/) for +more details. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.llvm] +``` diff --git a/docs/src/api/dialects/memref.md b/docs/src/api/dialects/memref.md new file mode 100644 index 0000000000..5d9e28c412 --- /dev/null +++ b/docs/src/api/dialects/memref.md @@ -0,0 +1,12 @@ +```@meta +CollapsedDocStrings = true +``` + +# MemRef Dialect + +Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/MemRef/) for more +details. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.memref] +``` diff --git a/docs/src/api/dialects/mpi.md b/docs/src/api/dialects/mpi.md new file mode 100644 index 0000000000..5b0570714e --- /dev/null +++ b/docs/src/api/dialects/mpi.md @@ -0,0 +1,12 @@ +```@meta +CollapsedDocStrings = true +``` + +# MPI Dialect + +Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/MPI/) for +more details. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.mpi] +``` diff --git a/docs/src/api/dialects/nvvm.md b/docs/src/api/dialects/nvvm.md new file mode 100644 index 0000000000..28169dc7a8 --- /dev/null +++ b/docs/src/api/dialects/nvvm.md @@ -0,0 +1,12 @@ +```@meta +CollapsedDocStrings = true +``` + +# NVVM Dialect + +Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/NVVMDialect/) for +more details. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.nvvm] +``` diff --git a/docs/src/api/dialects/shardy.md b/docs/src/api/dialects/shardy.md new file mode 100644 index 0000000000..8e0192c5ea --- /dev/null +++ b/docs/src/api/dialects/shardy.md @@ -0,0 +1,11 @@ +```@meta +CollapsedDocStrings = true +``` + +# Shardy Dialect + +Refer to the [official documentation](https://openxla.org/shardy) for more details. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.sdy] +``` diff --git a/docs/src/api/stablehlo.md b/docs/src/api/dialects/stablehlo.md similarity index 100% rename from docs/src/api/stablehlo.md rename to docs/src/api/dialects/stablehlo.md diff --git a/docs/src/api/dialects/tpu.md b/docs/src/api/dialects/tpu.md new file mode 100644 index 0000000000..9494cd9655 --- /dev/null +++ b/docs/src/api/dialects/tpu.md @@ -0,0 +1,12 @@ +```@meta +CollapsedDocStrings = true +``` + +# TPU Dialect + +Refer to the [official documentation](https://github.com/jax-ml/jax/blob/main/jaxlib/mosaic/dialect/tpu/tpu.td) for +more details. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.tpu] +``` diff --git a/docs/src/api/dialects/triton.md b/docs/src/api/dialects/triton.md new file mode 100644 index 0000000000..fdfb9654ae --- /dev/null +++ b/docs/src/api/dialects/triton.md @@ -0,0 +1,12 @@ +```@meta +CollapsedDocStrings = true +``` + +# Triton Dialect + +Refer to the [official documentation](https://triton-lang.org/main/dialects/TritonDialect.html) for +more details. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.tt] +``` diff --git a/docs/src/api/vhlo.md b/docs/src/api/dialects/vhlo.md similarity index 100% rename from docs/src/api/vhlo.md rename to docs/src/api/dialects/vhlo.md diff --git a/docs/src/api/internal.md b/docs/src/api/internal.md index a8788e5fb9..26b6d09922 100644 --- a/docs/src/api/internal.md +++ b/docs/src/api/internal.md @@ -4,9 +4,13 @@ CollapsedDocStrings = true # Internal API -These functions are not part of the public API and are subject to change at any time. +!!! danger "Private" + + These functions are not part of the public API and are subject to change at any time. ```@docs Reactant.REDUB_ARGUMENTS_NAME -Reactant.within_reactant_interpreter +Reactant.Compiler.codegen_unflatten! +Reactant.Compiler.codegen_flatten! +Reactant.Compiler.codegen_xla_call ``` diff --git a/docs/src/api/sharding.md b/docs/src/api/sharding.md new file mode 100644 index 0000000000..037f8dbac0 --- /dev/null +++ b/docs/src/api/sharding.md @@ -0,0 +1,14 @@ +```@meta +CollapsedDocStrings = true +``` + +# Sharding API + +`Reactant.Sharding` module provides a high-level API to construct MLIR operations with +support for sharding. + +Currently we haven't documented all the functions in `Reactant.Sharding`. + +```@autodocs +Modules = [Reactant.Sharding] +``` diff --git a/docs/src/introduction/configuration.md b/docs/src/introduction/configuration.md new file mode 100644 index 0000000000..f35aa1c6d3 --- /dev/null +++ b/docs/src/introduction/configuration.md @@ -0,0 +1,161 @@ +# Configuration + +When you [install](@ref Installation) `Reactant.jl`, the library powering the package compatible with your system will be automatically installed for you. +Below are some information about making sure that you are using the right configuration of Reactant for your machine. + +## Reactant with CPU + +At the moment Reactant supports only Linux (x86-64 and aarch64 architectures) and macOS (x86-64 and aarch64 architectures). +If you are using Julia on any of these systems, then Reactant should always support the CPU backend. +In the same environment where you installed Reactant you can verify it by running the following commands: + +```julia-repl +julia> import Pkg + +julia> Pkg.add("Reactant_jll") + [...] + +julia> import Reactant_jll + +julia> Reactant_jll.is_available() +true +``` + +If the last command returns `true`, you are good to go, if you get `false` but you think your system is one of the supported ones listed above, [open an issue](https://github.com/EnzymeAD/Reactant.jl/issues/new/choose). + +## Reactant with GPU + +At the moment Reactant supports only Nvidia GPUs. + +### Nvidia GPU + +Reactant can accelerate your code using Nvidia GPUs on Linux, with CUDA Driver 12.1+ on x86-64, and CUDA Driver 12.3+ on aarch64. +You can check if Reactant detected the GPU on your system by running the following commands in the environment where you installed Reactant: + +```julia-repl +julia> import Pkg + +julia> Pkg.add("Reactant_jll") + [...] + +julia> import Reactant_jll + +julia> Reactant_jll.is_available() +true + +julia> Reactant_jll.host_platform +Linux x86_64 {cuda_version=12.1, cxxstring_abi=cxx11, gpu=cuda, julia_version=1.11.3, libc=glibc, libgfortran_version=5.0.0, libstdcxx_version=3.4.30, mode=opt} +``` + +Like in the CPU section above, we ran `Reactant_jll.is_available()` to make sure Reactant is available at all, the `Reactant_jll.host_platform` variable then gives us more information about the detected platform. +In particular, if you have an Nvidia GPU you should expect to see `gpu=cuda` and `cuda_version=X.Y`, where `X.Y` should be a version less than or equal to the version of the CUDA Driver present in your system (don't worry if you don't see here exactly the same version as your CUDA Driver, that is expected). + +#### Debugging installation with Nvidia GPUs + +In some cases you may want to get more verbose information from Reactant during its installation process, to see how it detected CUDA. +To do that, you can force re-installation of `Reactant_jll` with increased verbosity with the commands + +```julia-repl +julia> rm(joinpath(Base.DEPOT_PATH[1], "compiled", "v$(VERSION.major).$(VERSION.minor)", "Reactant_jll"); recursive=true, force=true) + +julia> ENV["JULIA_DEBUG"] = "Reactant_jll"; + +julia> import Pkg + +julia> Pkg.add("Reactant_jll") + [...] + 1 dependency had output during precompilation: +┌ Reactant_jll +│ ┌ Debug: Detected CUDA Driver version 12.2.0 +│ └ @ Reactant_jll ~/.julia/packages/Reactant_jll/daenT/.pkg/platform_augmentation.jl:60 +│ ┌ Debug: Adding include dependency on /lib/x86_64-linux-gnu/libcuda.so.1 +│ └ @ Reactant_jll ~/.julia/packages/Reactant_jll/daenT/.pkg/platform_augmentation.jl:108 +``` + +Here you can see that on this system Reactant found the CUDA Driver at `/lib/x86_64-linux-gnu/libcuda.so.1` with version 12.2.0. + +#### Installing Reactant on GPU Servers without Internet + +If you want to use Reactant on GPU Servers where all packages must be installed on the login nodes and the compute nodes don't have access to internet, add the following to the `Project.toml` and precompile the package: + +```toml +[extras] +Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" + +[preferences.Reactant_jll] +gpu = "cuda" +``` + +#### Disabling CUDA support + +Reactant looks for the CUDA Driver library `libcuda` to determine whether the current system supports Nvidia GPUs. +However in some cases this library may be actually present on the machine even though no GPU is actually attached to it, which would trick Reactant's installation process into believing a GPU is available. +Normally this is not a problem as Reactant will detect that in spite of the CUDA Driver being present there are no GPUs and will default to the CPU backend. +If you do experience issues due to a GPU being detected erroneously, you can force disabling GPU support by creating a file called `LocalPreferences.toml` in the environment where you installed Reactant with the following content: + +```toml +[Reactant_jll] +gpu = "none" +``` + +install the package `Reactant_jll`: + +```julia +import Pkg +Pkg.add("Reactant_jll") +``` + +and then when you restart Julia you should see + +```julia-repl +julia> import Reactant_jll + +julia> Reactant_jll.is_available() +true + +julia> Reactant_jll.host_platform +Linux x86_64 {cuda_version=none, cxxstring_abi=cxx11, gpu=none, julia_version=1.11.3, libc=glibc, libgfortran_version=5.0.0, libstdcxx_version=3.4.30, mode=opt} +``` + +Reactant is still available for your system, but this time GPU support is disabled. + +## Reactant with TPU + +Reactant should detect automatically when you are running on a machine with a TPU, and load dynamically the necessary modules. +You can verify a TPU was found correctly with the following commands: + +```julia-repl +julia> import Reactant + +julia> Reactant.has_tpu() +true +``` + +### Memory errors on Google Cloud Platform + +If you are running Julia on Google Cloud Platform, you may frequently get scary-looking memory-related error messages like: + +``` +double free or corruption (out) +``` + +or + +``` +free(): invalid pointer +``` + +This is due to the fact that in this environment a memory allocator incompatible with Julia is forced via the `LD_PRELOAD` environment variable. +Starting Julia with + +```sh +LD_PRELOAD='' julia +``` + +or unsetting the variable + +```sh +unset LD_PRELOAD +``` + +should solve this issue. diff --git a/docs/src/introduction/index.md b/docs/src/introduction/index.md index d732bb0a35..5f0cd5002c 100644 --- a/docs/src/introduction/index.md +++ b/docs/src/introduction/index.md @@ -53,3 +53,83 @@ f = @compile sinsum_add(input1,input2) # one can now run the program f(input1, input2) ``` + + +## Tips + +### Empty Cache + +When you encounter OOM (Out of Memory) errors, you can try to clear the cache by using Julia's builtin `GC.gc()` between memory-intensive operations. + +!!! note + This will only free memory which is not currently live. If the result of compiled function was stored in a vector, it will still be alive and `GC.gc()` won't free it. + +```julia +using Reactant +n = 500_000_000 +input1 = Reactant.ConcreteRArray(ones(n)) +input2 = Reactant.ConcreteRArray(ones(n)) + +function sin_add(x, y) + return sin.(x) .+ y +end + +f = @compile sin_add(input1,input2) + +for i = 1:10 + GC.gc() + @info "gc... $i" + f(input1, input2) # May cause OOM here for a 24GB GPU if GC is not used +end +``` + +If you **don't** use `GC.gc()` here, this may cause an OOM: + + + +```bash +[ Info: gc... 1 +[ Info: gc... 2 +[ Info: gc... 3 +... +E0105 09:48:28.755177 110350 pjrt_stream_executor_client.cc:3088] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4000000000 bytes. +ERROR: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4000000000 bytes. + +Stacktrace: + [1] reactant_err(msg::Cstring) + @ Reactant.XLA ~/.julia/packages/Reactant/7m11i/src/XLA.jl:104 + [2] macro expansion + @ ~/.julia/packages/Reactant/7m11i/src/XLA.jl:357 [inlined] + [3] ExecutableCall + @ ~/.julia/packages/Reactant/7m11i/src/XLA.jl:334 [inlined] + [4] macro expansion + @ ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:798 [inlined] + [5] (::Reactant.Compiler.Thunk{…})(::ConcreteRArray{…}, ::ConcreteRArray{…}) + @ Reactant.Compiler ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:909 + [6] top-level scope + @ ./REPL[7]:4 +Some type information was truncated. Use `show(err)` to see complete types. +``` + + +After using Julia's built-in `GC.gc()`: + + + +```bash +[ Info: gc... 1 +[ Info: gc... 2 +[ Info: gc... 3 +[ Info: gc... 4 +[ Info: gc... 5 +[ Info: gc... 6 +[ Info: gc... 7 +[ Info: gc... 8 +[ Info: gc... 9 +[ Info: gc... 10 +``` + + + + + diff --git a/docs/src/tutorials/images/perfetto.png b/docs/src/tutorials/images/perfetto.png new file mode 100644 index 0000000000..89cf4f72fb Binary files /dev/null and b/docs/src/tutorials/images/perfetto.png differ diff --git a/docs/src/tutorials/images/tensorboard.png b/docs/src/tutorials/images/tensorboard.png new file mode 100644 index 0000000000..5758e1dd65 Binary files /dev/null and b/docs/src/tutorials/images/tensorboard.png differ diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index eb2beb1f1f..87c2c8ddd3 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -1,3 +1,5 @@ # Tutorials -We are currently working on adding tutorials to Reactant!! Please check back soon! + - [Profiling](@ref profiling). + +We are currently working on adding more tutorials to Reactant!! Please check back soon! diff --git a/docs/src/tutorials/profiling.md b/docs/src/tutorials/profiling.md new file mode 100644 index 0000000000..ae7c752b35 --- /dev/null +++ b/docs/src/tutorials/profiling.md @@ -0,0 +1,84 @@ +# [Profiling](@id profiling) + +## Capturing traces + +When running Reactant, it is possible to capture traces using the [XLA profiler](https://jax.readthedocs.io/en/latest/profiling.html). +These traces can provide information about where the XLA specific parts of program spend time during compilation or execution. Note that tracing and compilation happen on the CPU even though the final execution is aimed to run on another device such as GPU or TPU. Therefore, including tracing and compilation in a trace will create annotations on the CPU. + +Let's setup a simple function which we can then profile + +```@example profiling +using Reactant + +x = Reactant.to_rarray(randn(Float32, 100, 2)) +W = Reactant.to_rarray(randn(Float32, 10, 100)) +b = Reactant.to_rarray(randn(Float32, 10)) + +linear(x, W, b) = (W * x) .+ b +``` + +The profiler can be accessed using the [`Reactant.with_profiler`](@ref Reactant.Profiler.with_profiler) function. + +```@example profiling +Reactant.with_profiler("./") do + mylinear = Reactant.@compile linear(x, W, b) + mylinear(x, W, b) +end +``` + +Running this function should create a folder called `plugins` in the folder provided to `Reactant.with_profiler` which will +contain the trace files. The traces can then be visualized in different ways. + +!!! note + For more insights about the current state of Reactant, it is possible to fetch device information about allocations using the [`Reactant.XLA.allocatorstats`](@ref) function. + +## Perfetto UI + +![The perfetto interface](images/perfetto.png) + +The first and easiest way to visualize a captured trace is to use the online [`perfetto.dev`](https://ui.perfetto.dev/) tool. +[`Reactant.with_profiler`](@ref Reactant.Profiler.with_profiler) has a keyword parameter called `create_perfetto_link` which will create a usable perfetto URL for the generated trace. +The function will block execution until the URL has been clicked and the trace is visualized. The URL only works once. + +```julia +Reactant.with_profiler("./"; create_perfetto_link=true) do + mylinear = Reactant.@compile linear(x, W, b) + mylinear(x, W, b) +end +``` + +!!! note + It is recommended to use the Chrome browser to open the perfetto URL. + +## Tensorboard + +![The tensorboard interface](images/tensorboard.png) + +Another option to visualize the generated trace files is to use the [tensorboard profiler plugin](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras). +The tensorboard viewer can offer more details than the timeline view such as visualization for compute graphs. + +First install tensorboard and its profiler plugin: + +```bash +pip install tensorboard tensorboard-plugin-profile +``` + +And then run the following in the folder where the `plugins` folder was generated: + +```bash +tensorboard --logdir ./ +``` + +## Adding Custom Annotations + +By default, the traces contain only information captured from within XLA. +The [`Reactant.Profiler.annotate`](@ref) function can be used to annotate traces for Julia code evaluated *during tracing*. + +```julia +Reactant.Profiler.annotate("my_annotation") do + # Do things... +end +``` + +The added annotations will be captured in the traces and can be seen in the different viewers along with the default XLA annotations. +When the profiler is not activated, then the custom annotations have no effect and can therefore always be activated. diff --git a/ext/ReactantArrayInterfaceExt.jl b/ext/ReactantArrayInterfaceExt.jl index ffd7ddc60e..651e0ad292 100644 --- a/ext/ReactantArrayInterfaceExt.jl +++ b/ext/ReactantArrayInterfaceExt.jl @@ -2,19 +2,21 @@ module ReactantArrayInterfaceExt using ArrayInterface: ArrayInterface using Reactant: - Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, Ops + Reactant, + RArray, + ConcretePJRTArray, + ConcretePJRTNumber, + TracedRNumber, + TracedRArray, + AnyTracedRArray, + Ops ArrayInterface.can_setindex(::Type{<:RArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false -function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T} - x_c = ConcreteRArray(zeros(T, size(x))) - x_c .= x - return x_c -end - -function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T} - return Ops.reshape(vcat(x...), size(x)...) +for aType in + (AbstractArray{<:ConcretePJRTNumber}, AbstractArray{<:TracedRNumber}, AnyTracedRArray) + @eval ArrayInterface.aos_to_soa(x::$aType) = Reactant.aos_to_soa(x) end end diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index ba0765af5f..1741252356 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -1,14 +1,32 @@ module ReactantCUDAExt using CUDA -using Reactant: Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber +using Reactant: + Reactant, TracedRArray, AnyTracedRArray, AnyConcretePJRTArray, MLIR, TracedRNumber +using Reactant.Compiler: raising using ReactantCore: @trace +using GPUCompiler: GPUCompiler +using KernelAbstractions: KernelAbstractions +import KernelAbstractions as KA +using LLVM: LLVM using Libdl +const ReactantKernelAbstractionsExt = Base.get_extension( + Reactant, :ReactantKernelAbstractionsExt +) +const ReactantBackend = ReactantKernelAbstractionsExt.ReactantBackend using Adapt struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N} ptr::Core.LLVMPtr{T,A} + + function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size} + gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()] + push!(gc_vec, xs) + @assert gc_vec[end] === xs + ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)) + return new(ptr) + end end function Base.show(io::IO, a::AT) where {AT<:CuTracedArray} @@ -35,6 +53,19 @@ function Base.unsafe_convert( return x.ptr end +# TODO: arrays as allocated by the CUDA APIs are 256-byte aligned. we should keep track of +# this information, because it enables optimizations like Load Store Vectorization +# (cfr. shared memory and its wider-than-datatype alignment) + +@generated function alignment(::CuTracedArray{T}) where {T} + if Base.isbitsunion(T) + _, sz, al = Base.uniontype_layout(T) + al + else + Base.datatype_alignment(T) + end +end + ## indexing intrinsics CUDA.@device_function @inline function arrayref( @@ -49,7 +80,8 @@ CUDA.@device_function @inline function arrayref( end @inline function arrayref_bits(A::CuTracedArray{T}, index::Integer) where {T} - return unsafe_load(pointer(A), index) + align = alignment(A) + return unsafe_load(pointer(A), index, Val(align)) end @inline @generated function arrayref_union( @@ -94,7 +126,8 @@ CUDA.@device_function @inline function arrayset( end @inline function arrayset_bits(A::CuTracedArray{T}, x::T, index::Integer) where {T} - return unsafe_store!(pointer(A), x, index) + align = alignment(A) + return unsafe_store!(pointer(A), x, index, Val(align)) end @inline @generated function arrayset_union( @@ -107,9 +140,10 @@ end selector_ptr = typetagdata(A, index) unsafe_store!(selector_ptr, $(UInt8(sel - 1))) + align = alignment(A) data_ptr = pointer(A, index) - unsafe_store!(reinterpret(Core.LLVMPtr{$x,AS}, data_ptr), x, 1) + unsafe_store!(reinterpret(Core.LLVMPtr{$x,AS}, data_ptr), x, 1, Val(align)) return nothing end end @@ -118,7 +152,8 @@ CUDA.@device_function @inline function const_arrayref( A::CuTracedArray{T}, index::Integer ) where {T} @boundscheck checkbounds(A, index) - return unsafe_cached_load(pointer(A), index) + align = alignment(A) + return unsafe_cached_load(pointer(A), index, Val(align)) end ## indexing @@ -211,22 +246,143 @@ function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M return _derived_array(a, T, dims) end -function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} - res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}( - Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)) +struct ReactantKernelAdaptor end + +function Adapt.adapt_storage(to::ReactantKernelAdaptor, p::CUDA.CuPtr) + return error("Cannot convert CuPtr argument of Reactant Kernel") +end +function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::DenseCuArray) + return Adapt.adapt_storage(ka, Array(xs)) +end +function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::Array) + return Adapt.adapt_storage(ka, Reactant.Ops.constant(xs)) +end +function Adapt.adapt_structure( + to::ReactantKernelAdaptor, bc::Broadcast.Broadcasted{Style,<:Any,Type{T}} +) where {Style,T} + return Broadcast.Broadcasted{Style}( + (x...) -> T(x...), Adapt.adapt(to, bc.args), bc.axes ) - return res end -const _kernel_instances = Dict{Any,Any}() +function threads_to_workgroupsize(threads, ndrange) + total = 1 + return map(ndrange) do n + x = min(div(threads, total), n) + total *= x + return x + end +end + +function ka_with_reactant(ndrange, workgroupsize, obj, args...) + backend = KA.backend(obj) + + ndrange, workgroupsize, iterspace, dynamic = KA.launch_config( + obj, ndrange, workgroupsize + ) + # this might not be the final context, since we may tune the workgroupsize + ctx = KA.mkcontext(obj, ndrange, iterspace) + + # If the kernel is statically sized we can tell the compiler about that + if KA.workgroupsize(obj) <: KA.StaticSize + maxthreads = prod(KA.get(KA.workgroupsize(obj))) + else + maxthreads = nothing + end + + kernel = CUDA.@cuda launch = false always_inline = backend.always_inline maxthreads = + maxthreads obj.f(ctx, args...) + + # figure out the optimal workgroupsize automatically + if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing + if !Reactant.Compiler.PartitionKA[] || raising() + threads = prod(ndrange) + else + config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange)) + if backend.prefer_blocks + # Prefer blocks over threads + threads = min(prod(ndrange), config.threads) + # XXX: Some kernels performs much better with all blocks active + cu_blocks = max(cld(prod(ndrange), threads), config.blocks) + threads = cld(prod(ndrange), cu_blocks) + else + threads = config.threads + end + workgroupsize = threads_to_workgroupsize(threads, ndrange) + iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize) + end + ctx = KA.mkcontext(obj, ndrange, iterspace) + end + + blocks = length(KA.blocks(iterspace)) + threads = length(KA.workitems(iterspace)) + if blocks == 0 + return nothing + end + + # Launch kernel + kernel(ctx, args...; threads, blocks) + + return nothing +end + +Reactant.@reactant_overlay @noinline function (obj::KA.Kernel{ReactantBackend})( + args...; ndrange=nothing, workgroupsize=nothing +) + return Reactant.call_with_reactant( + ka_with_reactant, ndrange, workgroupsize, obj, args... + ) +end + +Adapt.adapt_storage(to::KA.ConstAdaptor, a::CuTracedArray) = Base.Experimental.Const(a) + +struct ReactantRefValue{T} <: Ref{T} + val::T +end +Base.getindex(r::ReactantRefValue{T}) where {T} = r.val +function Adapt.adapt_structure(to::ReactantKernelAdaptor, ref::Base.RefValue) + return ReactantRefValue(adapt(to, ref[])) +end + +function recudaconvert(arg) + return adapt(ReactantKernelAdaptor(), arg) +end +Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg) + return recudaconvert(arg) +end + +function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N} + res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs) + return res +end + +# Since we cache these objects we cannot cache data containing MLIR operations (e.g. the entry must be a string +# and not the operation itself). struct LLVMFunc{F,tt} f::Union{F,Nothing} - entry::MLIR.IR.Operation + entry::String end -const GPUCompiler = CUDA.GPUCompiler -const LLVM = GPUCompiler.LLVM +function Base.getproperty(f::LLVMFunc{F,tt}, sym::Symbol) where {F,tt} + if sym === :fun + f + else + Base.getfield(f, sym) + end +end + +# TODO in the future we may want to avoid doing a second cufunction compilation +# for computing the thread/block count (or potentially do it ourselves). +@noinline function CUDA.launch_configuration( + f::LLVMFunc{F,tt}; shmem::Union{Integer,Base.Callable}=0, max_threads::Integer=0 +) where {F,tt} + return CUDA.launch_configuration( + Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun; + shmem, + max_threads, + ) +end function GPULowerCPUFeaturesPass() return LLVM.NewPMModulePass("GPULowerCPUFeatures", GPUCompiler.cpu_features!) @@ -261,18 +417,260 @@ AddKernelStatePass() = LLVM.NewPMModulePass("AddKernelStatePass", kern_pass) LowerKernelStatePass() = LLVM.NewPMFunctionPass("LowerKernelStatePass", noop_pass) CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_pass) +# From https://github.com/JuliaGPU/GPUCompiler.jl/blob/7b9322faa34685026c4601a5084eecf5a5d7f3fe/src/ptx.jl#L149 +function vendored_optimize_module!( + @nospecialize(job), mod::LLVM.Module, instcombine::Bool=false +) + tm = GPUCompiler.llvm_machine(job.config.target) + # TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450) + LLVM.@dispose pb = LLVM.NewPMPassBuilder() begin + LLVM.register!(pb, GPUCompiler.NVVMReflectPass()) + + LLVM.add!(pb, LLVM.NewPMFunctionPassManager()) do fpm + # TODO: need to run this earlier; optimize_module! is called after addOptimizationPasses! + LLVM.add!(fpm, GPUCompiler.NVVMReflectPass()) + + # needed by GemmKernels.jl-like code + LLVM.add!(fpm, LLVM.SpeculativeExecutionPass()) + + # NVPTX's target machine info enables runtime unrolling, + # but Julia's pass sequence only invokes the simple unroller. + LLVM.add!(fpm, LLVM.LoopUnrollPass(; job.config.opt_level)) + if instcombine + LLVM.add!(fpm, LLVM.InstCombinePass()) # clean-up redundancy + else + LLVM.add!(fpm, LLVM.InstSimplifyPass()) # clean-up redundancy + end + LLVM.add!(fpm, LLVM.NewPMLoopPassManager(; use_memory_ssa=true)) do lpm + LLVM.add!(lpm, LLVM.LICMPass()) # the inner runtime check might be outer loop invariant + end + + # the above loop unroll pass might have unrolled regular, non-runtime nested loops. + # that code still needs to be optimized (arguably, multiple unroll passes should be + # scheduled by the Julia optimizer). do so here, instead of re-optimizing entirely. + if job.config.opt_level == 2 + LLVM.add!(fpm, LLVM.GVNPass()) + elseif job.config.opt_level == 1 + LLVM.add!(fpm, LLVM.EarlyCSEPass()) + end + LLVM.add!(fpm, LLVM.DSEPass()) + + LLVM.add!(fpm, LLVM.SimplifyCFGPass()) + end + + # get rid of the internalized functions; now possible unused + LLVM.add!(pb, LLVM.GlobalDCEPass()) + + LLVM.run!(pb, mod, tm) + end +end + +function vendored_buildEarlyOptimizerPipeline( + mpm, @nospecialize(job), opt_level; instcombine=false +) + LLVM.add!(mpm, LLVM.NewPMCGSCCPassManager()) do cgpm + # TODO invokeCGSCCCallbacks + LLVM.add!(cgpm, LLVM.NewPMFunctionPassManager()) do fpm + LLVM.add!(fpm, LLVM.Interop.AllocOptPass()) + LLVM.add!(fpm, LLVM.Float2IntPass()) + LLVM.add!(fpm, LLVM.LowerConstantIntrinsicsPass()) + end + end + LLVM.add!(mpm, GPULowerCPUFeaturesPass()) + if opt_level >= 1 + LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm + if opt_level >= 2 + LLVM.add!(fpm, LLVM.SROAPass()) + if instcombine + LLVM.add!(fpm, LLVM.InstCombinePass()) + else + LLVM.add!(fpm, LLVM.InstSimplifyPass()) + end + LLVM.add!(fpm, LLVM.JumpThreadingPass()) + LLVM.add!(fpm, LLVM.CorrelatedValuePropagationPass()) + LLVM.add!(fpm, LLVM.ReassociatePass()) + LLVM.add!(fpm, LLVM.EarlyCSEPass()) + LLVM.add!(fpm, LLVM.Interop.AllocOptPass()) + else + if instcombine + LLVM.add!(fpm, LLVM.InstCombinePass()) + else + LLVM.add!(fpm, LLVM.InstSimplifyPass()) + end + LLVM.add!(fpm, LLVM.EarlyCSEPass()) + end + end + # TODO invokePeepholeCallbacks + end +end + +function vendored_buildIntrinsicLoweringPipeline( + mpm, @nospecialize(job), opt_level; instcombine::Bool=false +) + GPUCompiler.add!(mpm, LLVM.Interop.RemoveNIPass()) + + # lower GC intrinsics + if !GPUCompiler.uses_julia_runtime(job) + LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm + LLVM.add!(fpm, GPULowerGCFramePass()) + end + end + + # lower kernel state intrinsics + # NOTE: we can only do so here, as GC lowering can introduce calls to the runtime, + # and thus additional uses of the kernel state intrinsics. + if job.config.kernel + # TODO: now that all kernel state-related passes are being run here, merge some? + LLVM.add!(mpm, AddKernelStatePass()) + LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm + LLVM.add!(fpm, LowerKernelStatePass()) + end + LLVM.add!(mpm, CleanupKernelStatePass()) + end + + if !GPUCompiler.uses_julia_runtime(job) + # remove dead uses of ptls + LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm + LLVM.add!(fpm, LLVM.ADCEPass()) + end + LLVM.add!(mpm, GPULowerPTLSPass()) + end + + LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm + # lower exception handling + if GPUCompiler.uses_julia_runtime(job) + LLVM.add!(fpm, LLVM.Interop.LowerExcHandlersPass()) + end + LLVM.add!(fpm, GPUCompiler.GCInvariantVerifierPass()) + LLVM.add!(fpm, LLVM.Interop.LateLowerGCPass()) + if GPUCompiler.uses_julia_runtime(job) && VERSION >= v"1.11.0-DEV.208" + LLVM.add!(fpm, LLVM.Interop.FinalLowerGCPass()) + end + end + if GPUCompiler.uses_julia_runtime(job) && VERSION < v"1.11.0-DEV.208" + LLVM.add!(mpm, LLVM.Interop.FinalLowerGCPass()) + end + + if opt_level >= 2 + LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm + LLVM.add!(fpm, LLVM.GVNPass()) + LLVM.add!(fpm, LLVM.SCCPPass()) + LLVM.add!(fpm, LLVM.DCEPass()) + end + end + + # lower PTLS intrinsics + if GPUCompiler.uses_julia_runtime(job) + LLVM.add!(mpm, LLVM.Interop.LowerPTLSPass()) + end + + if opt_level >= 1 + LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm + if instcombine + LLVM.add!(fpm, LLVM.InstCombinePass()) + else + LLVM.add!(fpm, LLVM.InstSimplifyPass()) + end + LLVM.add!( + fpm, LLVM.SimplifyCFGPass(; GPUCompiler.AggressiveSimplifyCFGOptions...) + ) + end + end + + # remove Julia address spaces + LLVM.add!(mpm, LLVM.Interop.RemoveJuliaAddrspacesPass()) + + # Julia's operand bundles confuse the inliner, so repeat here now they are gone. + # FIXME: we should fix the inliner so that inlined code gets optimized early-on + return LLVM.add!(mpm, LLVM.AlwaysInlinerPass()) +end + +function vendored_buildScalarOptimizerPipeline( + fpm, @nospecialize(job), opt_level; instcombine::Bool=false +) + if opt_level >= 2 + LLVM.add!(fpm, LLVM.Interop.AllocOptPass()) + LLVM.add!(fpm, LLVM.SROAPass()) + LLVM.add!(fpm, LLVM.InstSimplifyPass()) + LLVM.add!(fpm, LLVM.GVNPass()) + LLVM.add!(fpm, LLVM.MemCpyOptPass()) + LLVM.add!(fpm, LLVM.SCCPPass()) + LLVM.add!(fpm, LLVM.CorrelatedValuePropagationPass()) + LLVM.add!(fpm, LLVM.DCEPass()) + LLVM.add!(fpm, LLVM.IRCEPass()) + if instcombine + LLVM.add!(fpm, LLVM.InstCombinePass()) + else + LLVM.add!(fpm, LLVM.InstSimplifyPass()) + end + LLVM.add!(fpm, LLVM.JumpThreadingPass()) + end + if opt_level >= 3 + LLVM.add!(fpm, LLVM.GVNPass()) + end + if opt_level >= 2 + LLVM.add!(fpm, LLVM.DSEPass()) + # TODO invokePeepholeCallbacks + LLVM.add!(fpm, LLVM.SimplifyCFGPass(; GPUCompiler.AggressiveSimplifyCFGOptions...)) + LLVM.add!(fpm, LLVM.Interop.AllocOptPass()) + LLVM.add!(fpm, LLVM.NewPMLoopPassManager()) do lpm + LLVM.add!(lpm, LLVM.LoopDeletionPass()) + LLVM.add!(lpm, LLVM.LoopInstSimplifyPass()) + end + LLVM.add!(fpm, LLVM.LoopDistributePass()) + end + # TODO invokeScalarOptimizerCallbacks +end + +function vendored_buildNewPMPipeline!(mpm, @nospecialize(job), opt_level) + # Doesn't call instcombine + GPUCompiler.buildEarlySimplificationPipeline(mpm, job, opt_level) + LLVM.add!(mpm, LLVM.AlwaysInlinerPass()) + vendored_buildEarlyOptimizerPipeline(mpm, job, opt_level) + LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm + # Doesn't call instcombine + GPUCompiler.buildLoopOptimizerPipeline(fpm, job, opt_level) + vendored_buildScalarOptimizerPipeline(fpm, job, opt_level) + if GPUCompiler.uses_julia_runtime(job) && opt_level >= 2 + # XXX: we disable vectorization, as this generally isn't useful for GPU targets + # and actually causes issues with some back-end compilers (like Metal). + # TODO: Make this not dependent on `uses_julia_runtime` (likely CPU), but it's own control + # Doesn't call instcombine + GPUCompiler.buildVectorPipeline(fpm, job, opt_level) + end + # if isdebug(:optim) + # add!(fpm, WarnMissedTransformationsPass()) + # end + end + vendored_buildIntrinsicLoweringPipeline(mpm, job, opt_level) + return GPUCompiler.buildCleanupPipeline(mpm, job, opt_level) +end + # compile to executable machine code function compile(job) # lower to PTX # TODO: on 1.9, this actually creates a context. cache those. entry = GPUCompiler.JuliaContext() do ctx mod, meta = GPUCompiler.compile( - :llvm, job; optimize=false, cleanup=false, validate=false + # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true + :llvm, + job; + optimize=false, + cleanup=false, + validate=false, + libraries=false, + # :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false + # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false ) + if !Reactant.precompiling() + GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job)) + end entryname = LLVM.name(meta.entry) - GPUCompiler.optimize_module!(job, mod) + if Reactant.Compiler.DUMP_LLVMIR[] + println("cuda.jl immediate IR\n", string(mod)) + end opt_level = 2 tm = GPUCompiler.llvm_machine(job.config.target) LLVM.@dispose pb = LLVM.NewPMPassBuilder() begin @@ -284,11 +682,17 @@ function compile(job) LLVM.register!(pb, CleanupKernelStatePass()) LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm - GPUCompiler.buildNewPMPipeline!(mpm, job, opt_level) + vendored_buildNewPMPipeline!(mpm, job, opt_level) end LLVM.run!(pb, mod, tm) end - GPUCompiler.optimize_module!(job, mod) + if Reactant.Compiler.DUMP_LLVMIR[] + println("cuda.jl pre vendor IR\n", string(mod)) + end + vendored_optimize_module!(job, mod) + if Reactant.Compiler.DUMP_LLVMIR[] + println("cuda.jl post vendor IR\n", string(mod)) + end LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm) for fname in ("gpu_report_exception", "gpu_signal_exception") @@ -305,17 +709,33 @@ function compile(job) end end - LLVM.strip_debuginfo!(mod) + errors = GPUCompiler.check_ir!(job, GPUCompiler.IRError[], mod) + unique!(errors) + filter!(errors) do err + (kind, bt, meta) = err + if meta !== nothing + if kind == GPUCompiler.UNKNOWN_FUNCTION && startswith(meta, "__nv") + return false + end + end + return true + end + if Reactant.Compiler.DUMP_LLVMIR[] + println("cuda.jl postopt IR\n", string(mod)) + end + if !isempty(errors) + throw(GPUCompiler.InvalidIRError(job, errors)) + end + # LLVM.strip_debuginfo!(mod) modstr = string(mod) - # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version # it is probably safer to reparse a string using the right llvm module api, so we will do that. - mmod = MLIR.IR.Module( @ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR( modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext )::MLIR.API.MlirModule ) + @assert mmod != C_NULL linkRes = @ccall MLIR.API.mlir_c.LinkInModule( MLIR.IR.mmodule()::MLIR.API.MlirModule, @@ -323,10 +743,9 @@ function compile(job) entryname::Cstring, )::MLIR.API.MlirOperation - entry = MLIR.IR.Operation(linkRes) - - entry + String(Reactant.TracedUtils.get_attribute_by_name(linkRes, "sym_name")) end + return LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, entry) end @@ -336,9 +755,71 @@ function link(job, compiled) return compiled end +function abi_sizeof(@nospecialize(x)) + return sizeof(typeof(x)) +end +function abi_sizeof(@nospecialize(x::CuTracedArray)) + return sizeof(Ptr) +end +function abi_sizeof(@nospecialize(x::CUDA.CuDeviceArray)) + return sizeof(Ptr) +end + +function to_bytes(x) + sz = abi_sizeof(x) + ref = Ref(x) + GC.@preserve ref begin + ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref)) + vec = Vector{UInt8}(undef, sz) + for i in 1:sz + @inbounds vec[i] = Base.unsafe_load(ptr, i) + end + vec + end +end + +function Reactant.make_tracer( + seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs... +) + x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr)) + x = x::TracedRArray + Reactant.make_tracer(seen, x, path, mode; kwargs...) + return prev +end + +function get_field_offset(T::Type, path) + offset = 0 + current_type = T + + for field in path + # Get the field index + field_idx = if field isa Integer + field + else + @assert field isa Symbol + findfirst(==(field), fieldnames(current_type)) + end + if field_idx === nothing + error( + "Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path", + ) + end + + # Add the offset of this field + toffset = fieldoffset(current_type, field_idx) + tcurrent_type = fieldtype(current_type, field_idx) + offset += toffset + + # Update current_type to the field's type for next iteration + current_type = tcurrent_type + end + + return offset +end + Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( args...; - convert=Val(false), + convert=Val(true), blocks::CuDim=1, threads::CuDim=1, cooperative::Bool=false, @@ -348,73 +829,223 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( blockdim = CUDA.CuDim3(blocks) threaddim = CUDA.CuDim3(threads) + if convert == Val(true) + args = recudaconvert.(args) + end + mlir_args = MLIR.IR.Value[] restys = MLIR.IR.Type[] aliases = MLIR.IR.Attribute[] rarrays = TracedRArray[] - for (i, a) in enumerate(args) - @assert a isa CuTracedArray - ta = - Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray - push!(rarrays, ta) - arg = ta.mlir_data + + fname = func.entry + + wrapper_tys = MLIR.IR.Type[] + ctx = MLIR.IR.context() + cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1)) + + # linearize kernel arguments + seen = Reactant.OrderedIdDict() + kernelargsym = gensym("kernelarg") + for (i, prev) in enumerate(Any[func.f, args...]) + Reactant.make_tracer(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack) + end + wrapper_tys = MLIR.IR.Type[] + for arg in values(seen) + if !(arg isa TracedRArray || arg isa TracedRNumber) + continue + end + push!(wrapper_tys, cullvm_ty) + end + + sym_name = String(gensym("call_$fname")) + mod = MLIR.IR.mmodule() + CConv = MLIR.IR.Attribute( + MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel) + ) + voidty = MLIR.IR.Type(MLIR.API.mlirLLVMVoidTypeGet(ctx)) + wrapftype = MLIR.IR.Type( + MLIR.API.mlirLLVMFunctionTypeGet(voidty, length(wrapper_tys), wrapper_tys, false) + ) + wrapfunc = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.llvm.func(; + sym_name, + sym_visibility=MLIR.IR.Attribute("private"), + function_type=wrapftype, + body=MLIR.IR.Region(), + CConv, + ) + end + wrapbody = MLIR.IR.Block(wrapper_tys, [MLIR.IR.Location() for _ in wrapper_tys]) + push!(MLIR.IR.region(wrapfunc, 1), wrapbody) + for i in 1:length(wrapper_tys) + @ccall MLIR.API.mlir_c.ReactantFuncSetArgAttr( + wrapfunc::MLIR.API.MlirOperation, + (i - 1)::Csize_t, + "llvm.noalias"::MLIR.API.MlirStringRef, + MLIR.IR.UnitAttribute()::MLIR.API.MlirAttribute, + )::Cvoid + end + + wrapargs = MLIR.IR.Value[] + argidx = 1 + + symtab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)) + gpufunc = MLIR.IR.lookup(symtab, fname) + MLIR.IR.attr!( + gpufunc, + "CConv", + MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC)), + ) + gpu_function_type = MLIR.IR.Type( + Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type") + ) + + trueidx = 1 + allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type},Nothing}[] + + llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0)) + i8 = MLIR.IR.Type(UInt8) + allargs = [func.f, args...] + for a in allargs + if sizeof(a) == 0 + push!(allocs, nothing) + continue + end + + # TODO check for only integer and explicitly non cutraced types + MLIR.IR.block!(wrapbody) do + argty = MLIR.IR.Type( + MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx - 1) + ) + trueidx += 1 + c1 = MLIR.IR.result( + MLIR.Dialects.llvm.mlir_constant(; + res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1) + ), + 1, + ) + alloc = MLIR.IR.result( + MLIR.Dialects.llvm.alloca( + c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr + ), + 1, + ) + push!(allocs, (alloc, argty)) + + sz = abi_sizeof(a) + array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz)) + cdata = MLIR.IR.result( + MLIR.Dialects.llvm.mlir_constant(; + res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a)) + ), + 1, + ) + MLIR.Dialects.llvm.store(cdata, alloc) + end + end + + argidx = 1 + for arg in values(seen) + if !(arg isa TracedRArray || arg isa TracedRNumber) + continue + end + + paths = Reactant.TracedUtils.get_paths(arg) + + arg = arg.mlir_data arg = Reactant.TracedUtils.transpose_val(arg) push!(restys, MLIR.IR.type(arg)) push!(mlir_args, arg) + push!( aliases, MLIR.IR.Attribute( MLIR.API.stablehloOutputOperandAliasGet( MLIR.IR.context(), - length(args) == 1 ? 0 : 1, - length(args) == 1 ? C_NULL : Ref{Int64}(i - 1), - i - 1, + length(wrapper_tys) == 1 ? 0 : 1, + length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1), + argidx - 1, 0, C_NULL, ), ), ) + + for p in paths + if p[1] !== kernelargsym + continue + end + # Get the allocation corresponding to which arg we're doing + alloc = allocs[p[2]][1] + + # we need to now compute the offset in bytes of the path + julia_arg = allargs[p[2]] + + offset = get_field_offset(typeof(julia_arg), p[3:end]) + MLIR.IR.block!(wrapbody) do + ptr = MLIR.IR.result( + MLIR.Dialects.llvm.getelementptr( + alloc, + MLIR.IR.Value[]; + res=llvmptr, + elem_type=i8, + rawConstantIndices=MLIR.IR.Attribute([Int32(offset)]), + ), + 1, + ) + MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr) + end + end + argidx += 1 end - output_operand_aliases = MLIR.IR.Attribute(aliases) + MLIR.IR.block!(wrapbody) do + for arg in allocs + if arg === nothing + continue + end + alloc, argty = arg + argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1) + push!(wrapargs, argres) + end + MLIR.Dialects.llvm.call( + wrapargs, + MLIR.IR.Value[]; + callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), + op_bundle_sizes=MLIR.IR.Attribute(Int32[]), + ) + MLIR.Dialects.llvm.return_(nothing) + end - fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name") - # Force public for now while we don't have real users - # MLIR.IR.rmattr!(func.entry, "sym_visibility") + output_operand_aliases = MLIR.IR.Attribute(aliases) - operands = MLIR.IR.Value[] + blk_operands = MLIR.IR.Value[] for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem) push!( - operands, + blk_operands, Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, idx).mlir_data, ) end - for arg in mlir_args - push!(operands, arg) - end - owned_regions = MLIR.IR.Region[] - successors = MLIR.IR.Block[] - attributes = MLIR.IR.NamedAttribute[ - MLIR.IR.NamedAttribute("fn", MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))), - MLIR.IR.NamedAttribute( - "output_operand_aliases", MLIR.IR.Attribute(output_operand_aliases) - ), - ] location = MLIR.IR.Location() - call = MLIR.IR.create_operation( - "enzymexla.kernel_call", - location; - operands, - owned_regions, - successors, - attributes, - results=restys, - result_inference=false, + @assert length(restys) == length(aliases) + call = MLIR.Dialects.enzymexla.kernel_call( + blk_operands..., + mlir_args; + result_0=restys, + fn=MLIR.IR.FlatSymbolRefAttribute(sym_name), + output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases), ) - for (i, res) in enumerate(rarrays) - res.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, i)) + + argidx = 1 + for arg in values(seen) + if !(arg isa TracedRArray || arg isa TracedRNumber) + continue + end + arg.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, argidx)) + argidx += 1 end end @@ -456,30 +1087,133 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( ) CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) end - return res + return Core.Typeof(res)(f, res.entry) end -function __init__() - handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false) - if handle === nothing - handle = C_NULL +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(A::Type{<:CuTracedArray}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type) +) + return A +end + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(A::Type{<:CUDA.CuArray}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + T = eltype(A) + N = ndims(A) + if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive + if !Reactant.Sharding.is_sharded(sharding) + return Reactant.ConcretePJRTArray{T,N,1,Reactant.Sharding.NoShardInfo} + else + error("TODO: implement sharding") + end + else + TT = Reactant.traced_type_inner(T, seen, mode, track_numbers, sharding) + if TT === T + return A + else + return Array{ + Reactant.traced_type_inner( + T, seen, mode, track_numbers, Base.getproperty(sharding, 1) + ), + N, + } + end end - ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false) - if ptr1 === nothing - ptr1 = C_NULL +end + +function Reactant.make_tracer( + seen, + @nospecialize(prev::CUDA.CuArray), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Reactant.Sharding.NoSharding()), + kwargs..., +) + RT = Core.Typeof(prev) + # XXX: If someone wants to shard the same array with different shardings, we need to + # somehow handle this correctly... Right now we just use the first sharding. + if haskey(seen, prev) + return seen[prev] end - ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false) - if ptr2 === nothing - ptr2 = C_NULL + if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive + return seen[prev] = Reactant.ConcretePJRTArray(Array(prev); sharding) end - ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false) - if ptr3 === nothing - ptr3 = C_NULL + TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers, sharding) + if TT === eltype(RT) + return prev + end + newa = Array{TT,ndims(RT)}(undef, size(prev)) + seen[prev] = newa + same = true + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + nv = Reactant.make_tracer( + seen, + pv, + append_path(path, I), + mode; + track_numbers, + sharding=Base.getproperty(sharding, I), + kwargs..., + ) + if pv !== nv + same = false + end + @inbounds newa[I] = nv + end + end + if same + seen[prev] = prev + return prev + end + return newa +end + +function __init__() + if CUDA.functional() + target = CUDA._compiler_config(CUDA.device()).target + Reactant.Compiler.cubinChip[] = "sm_$(target.cap.major)$(target.cap.minor)" end - Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) - Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) - Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) return nothing end +# In Julia v1.11.3 precompiling this module caches bad code: +# . +@static if !Sys.isapple() + Reactant.PrecompileTools.@setup_workload begin + Reactant.initialize_dialect() + client = Reactant.XLA.PJRT.CPUClient(; checkcount=false) + Reactant.PrecompileTools.@compile_workload begin + @static if Reactant.precompilation_supported() && VERSION != v"1.11.3" + function square_kernel!(x) + i = CUDA.threadIdx().x + x[i] *= x[i] + return nothing + end + + function square!(x) + CUDA.@cuda blocks = 1 threads = length(x) square_kernel!(x) + return nothing + end + y = Reactant.ConcretePJRTArray([2.0]; client) + Reactant.Compiler.compile_mlir(square!, (y,); optimize=false) + end + end + Reactant.XLA.free_client(client) + client.client = C_NULL + Reactant.deinitialize_dialect() + Reactant.clear_oc_cache() + end +end + end # module ReactantCUDAExt diff --git a/ext/ReactantKernelAbstractionsExt.jl b/ext/ReactantKernelAbstractionsExt.jl new file mode 100644 index 0000000000..358827ef27 --- /dev/null +++ b/ext/ReactantKernelAbstractionsExt.jl @@ -0,0 +1,96 @@ +module ReactantKernelAbstractionsExt + +using Reactant + +import KernelAbstractions as KA + +using Adapt: Adapt + +## back-end + +export ReactantBackend + +struct ReactantBackend <: KA.GPU end + +function Base.getproperty(x::ReactantBackend, sym::Symbol) + if sym === :always_inline + return true + elseif sym === :prefer_blocks + return false + else + return Base.getfield(x, sym) + end +end + +KA.allocate(n::ReactantBackend, ::Type{T}, dims::Tuple) where {T} = KA.zeros(b, T, dims) +function KA.zeros(::ReactantBackend, ::Type{T}, dims::Tuple) where {T} + return ConcretePJRTArray(zeros(T, dims)) +end +function KA.ones(::ReactantBackend, ::Type{T}, dims::Tuple) where {T} + return ConcretePJRTArray(ones(T, dims)) +end + +KA.get_backend(::Reactant.AnyTracedRArray) = ReactantBackend() +KA.get_backend(::Reactant.AnyConcretePJRTArray) = ReactantBackend() +function KA.synchronize(::ReactantBackend) end + +Adapt.adapt_storage(::ReactantBackend, a::Array) = a +Adapt.adapt_storage(::ReactantBackend, a::Reactant.AnyTracedRArray) = a +Adapt.adapt_storage(::ReactantBackend, a::Reactant.AnyConcretePJRTArray) = a +Adapt.adapt_storage(::KA.CPU, a::Reactant.AnyConcretePJRTArray) = convert(Array, a) + +## memory operations + +function KA.copyto!(::ReactantBackend, A, B) + Base.copyto!(A, B) + return A +end + +## kernel launch + +function KA.mkcontext(kernel::KA.Kernel{ReactantBackend}, _ndrange, iterspace) + return KA.CompilerMetadata{KA.ndrange(kernel),KA.DynamicCheck}(_ndrange, iterspace) +end + +function KA.launch_config(kernel::KA.Kernel{ReactantBackend}, ndrange, workgroupsize) + if ndrange isa Integer + ndrange = (ndrange,) + end + if workgroupsize isa Integer + workgroupsize = (workgroupsize,) + end + + # partition checked that the ndrange's agreed + if KA.ndrange(kernel) <: KA.StaticSize + ndrange = nothing + end + + iterspace, dynamic = + if KA.workgroupsize(kernel) <: KA.DynamicSize && workgroupsize === nothing + # use ndrange as preliminary workgroupsize for autotuning + KA.partition(kernel, ndrange, ndrange) + else + KA.partition(kernel, ndrange, workgroupsize) + end + + return ndrange, workgroupsize, iterspace, dynamic +end + +KA.argconvert(k::KA.Kernel{ReactantBackend}, arg) = arg + +function KA.priority!(::ReactantBackend, prio::Symbol) + if !(prio in (:high, :normal, :low)) + error("priority must be one of :high, :normal, :low") + end + return nothing +end + +function tokw(ndrange, workgroupsize, obj, args...) + @inline obj(args...; ndrange, workgroupsize) +end + +function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing) + @jit tokw(ndrange, workgroupsize, obj, args...) +end + +end diff --git a/ext/ReactantMPIExt.jl b/ext/ReactantMPIExt.jl new file mode 100644 index 0000000000..5ede919c0d --- /dev/null +++ b/ext/ReactantMPIExt.jl @@ -0,0 +1,36 @@ +module ReactantMPIExt + +using Reactant: Reactant, Distributed +using MPI: MPI + +# https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py +Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized() + +function Distributed.get_coordinator_address( + ::Distributed.MPIEnvDetector, timeout_in_seconds::Integer +) + if MPI.Comm_rank(MPI.COMM_WORLD) == 0 + hostname = gethostname() + port_id = hash(hostname) % 2^12 + (65535 - 2^12 + 1) + hostname = "$(hostname):$(port_id)" + else + hostname = nothing + end + + return MPI.bcast(hostname, MPI.COMM_WORLD; root=0) +end + +function Distributed.get_process_count(::Distributed.MPIEnvDetector) + return Int(MPI.Comm_size(MPI.COMM_WORLD)) +end + +function Distributed.get_process_id(::Distributed.MPIEnvDetector) + return Int(MPI.Comm_rank(MPI.COMM_WORLD)) +end + +function Distributed.get_local_process_id(::Distributed.MPIEnvDetector) + new_comm = MPI.Comm_split_type(MPI.COMM_WORLD, MPI.COMM_TYPE_SHARED, 0) + return Int(MPI.Comm_rank(new_comm)) +end + +end diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index e90c000763..ee00463e2e 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -323,7 +323,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr This case is not optimized and will be slow." maxlog = 1 dims = NNlib.scatter_dims(src, dst, idxs) colons = ntuple(Returns(Colon()), dims) - start_sizes = ntuple(i -> size(src, i), dims) + start_sizes = ntuple(Base.Fix1(size, src), dims) results = map(CartesianIndices(idxs)) do k res = @allowscalar src[colons..., Tuple(idxs[k])...] res isa TracedRNumber && (res = TracedUtils.broadcast_to_size(res, (1,))) diff --git a/ext/ReactantOffsetArraysExt.jl b/ext/ReactantOffsetArraysExt.jl new file mode 100644 index 0000000000..5844d53c79 --- /dev/null +++ b/ext/ReactantOffsetArraysExt.jl @@ -0,0 +1,20 @@ +module ReactantOffsetArraysExt + +using OffsetArrays +using OffsetArrays: OffsetArray +using Reactant: Reactant, MLIR, Ops, TracedRArray + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(OA::Type{<:OffsetArray}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding), +) + N = ndims(OA) + T = OffsetArrays.parenttype(OA) + T2 = Reactant.traced_type_inner(T, seen, mode, track_numbers, sharding) + return OffsetArray{eltype(T2),N,T2} +end + +end diff --git a/ext/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt.jl index be5b61fdd3..3a2c0f4a7e 100644 --- a/ext/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt.jl @@ -8,22 +8,22 @@ using PythonCall const jaxptr = Ref{Py}() -const NUMPY_SIMPLE_TYPES = ( - ("bool_", Bool), - ("int8", Int8), - ("int16", Int16), - ("int32", Int32), - ("int64", Int64), - ("uint8", UInt8), - ("uint16", UInt16), - ("uint32", UInt32), - ("uint64", UInt64), - ("float16", Float16), - ("float32", Float32), - ("float64", Float64), - ("complex32", ComplexF16), - ("complex64", ComplexF32), - ("complex128", ComplexF64), +const NUMPY_SIMPLE_TYPES = Dict( + Bool => :bool_, + Int8 => :int8, + Int16 => :int16, + Int32 => :int32, + Int64 => :int64, + UInt8 => :uint8, + UInt16 => :uint16, + UInt32 => :uint32, + UInt64 => :uint64, + Float16 => :float16, + Float32 => :float32, + Float64 => :float64, + ComplexF16 => :complex32, + ComplexF32 => :complex64, + ComplexF64 => :complex128, ) function PythonCall.pycall( @@ -32,15 +32,10 @@ function PythonCall.pycall( jax = jaxptr[] numpy = jax.numpy inputs = map((arg0, argNs...)) do arg - JT = eltype(arg) - PT = nothing - for (CPT, CJT) in NUMPY_SIMPLE_TYPES - if JT == CJT - PT = CPT - break - end - end - numpy.zeros(size(arg); dtype=getproperty(numpy, Symbol(PT))) + numpy.zeros( + size(arg); + dtype=getproperty(numpy, NUMPY_SIMPLE_TYPES[Reactant.unwrapped_eltype(arg)]), + ) end lowered = jax.jit(f).lower(inputs...) txt = pyconvert(String, lowered.as_text()) diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl new file mode 100644 index 0000000000..9ed006e386 --- /dev/null +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -0,0 +1,118 @@ +module ReactantSpecialFunctionsExt +using SpecialFunctions +using Reactant: Ops, Reactant, TracedRNumber, ReactantFloat, ReactantInt, ReactantFloatInt +using Reactant.TracedRNumberOverrides: float + +for fn in [:digamma, :erf, :erfc, (:loggamma, :lgamma)] + (fns, fno) = fn isa Tuple ? fn : (fn, fn) + @eval(function SpecialFunctions.$fns(x::TracedRNumber{<:ReactantFloatInt}) + return Ops.$fno(float(x)) + end) +end + +function SpecialFunctions.gamma(x::TracedRNumber{<:ReactantFloat}) + return exp(Ops.lgamma(float(x))) +end + +function SpecialFunctions.gamma(n::TracedRNumber{<:ReactantInt}) + return round(gamma(float(n))) +end + +function SpecialFunctions.loggamma1p(x::TracedRNumber{<:ReactantFloat}) + return loggamma(1 + x) +end + +function SpecialFunctions.logfactorial(x::TracedRNumber{<:ReactantInt}) + return loggamma(1 + x) +end + +# SpecialFunctions.invdigamma + +function SpecialFunctions.trigamma(x::TracedRNumber{<:ReactantFloatInt}) + return Ops.polygamma(Ops.constant(Float64(1)), float(x))#TODO: change Ops definition +end + +function SpecialFunctions.polygamma( + n::TracedRNumber{<:ReactantFloatInt}, x::TracedRNumber{<:ReactantFloatInt} +) + return Ops.polygamma(float(n), float(x)) +end + +# SpecialFunctions.gamma_inc + +# SpecialFunctions.gamma_inc_inv + +function SpecialFunctions.loggammadiv( + a::TracedRNumber{T}, b::TracedRNumber{T} +) where {T<:ReactantFloat} + return log(gamma(b) / gamma(a + b)) +end + +#SpecialFunctions.gamma ... + +function SpecialFunctions.beta( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} + return gamma(x) * gamma(y) / gamma(x + y) +end + +function SpecialFunctions.logbeta( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} + return log(abs(beta(x, y))) +end + +#TODO: sign function +#SpecialFunctions.logabsbeta +#SpecialFunctions.logabsbinomial + +#SpecialFunctions.beta... + +#utilities... + +function SpecialFunctions.erf( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} + return erf(y) - erf(x) +end + +#SpecialFunctions.erfcinv + +function SpecialFunctions.logerf( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} + return log(erf(x, y)) +end + +function SpecialFunctions.erfcx(x::TracedRNumber{<:ReactantFloatInt}) + return exp(float(x^2)) * erfc(x) +end + +function SpecialFunctions.logerfc(x::TracedRNumber{<:ReactantFloatInt}) + return log(erfc(x)) +end + +function SpecialFunctions.logerfcx(x::TracedRNumber{<:ReactantFloatInt}) + return log(erfcx(x)) +end + +#Unsupported complex +#SpecialFunctions.erfi + +#SpecialFunctions.erfinv +#SpecialFunctions.dawson +#SpecialFunctions.faddeeva + +#Airy and Related Functions + +#Bessel ... + +#Elliptic Integrals + +function SpecialFunctions.zeta( + z::TracedRNumber{T}, s::TracedRNumber{T} +) where {T<:ReactantFloatInt} + return Ops.zeta(z, s) +end + +end # module ReactantSpecialFunctionsExt diff --git a/ext/ReactantStatisticsExt.jl b/ext/ReactantStatisticsExt.jl index 40db81a8ed..1c9e3f9f33 100644 --- a/ext/ReactantStatisticsExt.jl +++ b/ext/ReactantStatisticsExt.jl @@ -4,18 +4,22 @@ using Reactant: AnyTracedRArray using Reactant.TracedUtils: materialize_traced_array using Statistics: Statistics -function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N} - A = materialize_traced_array(A) +function Statistics._mean(f::F, A::AnyTracedRArray{T,N}, dims) where {F,T,N} denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims) - return mapreduce(identity, +, A; dims) / denom + return mapreduce(f, +, A; dims) / denom end -function Statistics.var( - A::AnyTracedRArray{T,N}; dims=:, mean=nothing, corrected=true +function Statistics._var( + A::AnyTracedRArray{T,N}, corrected::Bool, mean, ::Colon ) where {T,N} - A = materialize_traced_array(A) + mean === nothing && (mean = Statistics.mean(A)) + denom = length(A) - corrected + return mapreduce(abs2, +, A .- mean; dims=:) / denom +end + +function Statistics._var(A::AnyTracedRArray{T,N}, corrected::Bool, mean, dims) where {T,N} mean === nothing && (mean = Statistics.mean(A; dims)) - denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected + denom = prod(Base.Fix1(size, A), dims) - corrected return mapreduce(abs2, +, A .- mean; dims) / denom end diff --git a/lib/ReactantCore/Project.toml b/lib/ReactantCore/Project.toml index bec50b45e1..211816d6ac 100644 --- a/lib/ReactantCore/Project.toml +++ b/lib/ReactantCore/Project.toml @@ -1,7 +1,7 @@ name = "ReactantCore" uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal "] -version = "0.1.3" +version = "0.1.5" [deps] ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43" diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index f99d6cab90..acf8862a8a 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -3,7 +3,7 @@ module ReactantCore using ExpressionExplorer: ExpressionExplorer using MacroTools: MacroTools -export @trace, MissingTracedValue +export @trace, within_compile, MissingTracedValue # Traits is_traced(x) = false @@ -15,10 +15,19 @@ end MissingTracedValue() = MissingTracedValue(()) +Base.zero(::MissingTracedValue) = MissingTracedValue() + const SPECIAL_SYMBOLS = [ :(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core ] +""" + within_compile() + +Returns true if this function is executed in a Reactant compilation context, otherwise false. +""" +@inline within_compile() = false # behavior is overwritten in Interpreter.jl + # Code generation """ @trace @@ -115,6 +124,13 @@ macro trace(expr) return esc(trace_if_with_returns(__module__, expr)) end end + Meta.isexpr(expr, :call) && return esc(trace_call(__module__, expr)) + if Meta.isexpr(expr, :(.), 2) && Meta.isexpr(expr.args[2], :tuple) + fname = :($(Base.Broadcast.BroadcastFunction)($(expr.args[1]))) + args = only(expr.args[2:end]).args + call = Expr(:call, fname, args...) + return esc(trace_call(__module__, call)) + end Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr)) Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr))) return error("Only `if-elseif-else` blocks are currently supported by `@trace`") @@ -158,8 +174,16 @@ function trace_for(mod, expr) external_syms..., ) + cond_val(s) = :(@isdefined($s) ? $s : nothing) + + while_defined = gensym(:while_defined) + locals = Expr[ + [Expr(:(=), s, cond_val(s)) for s in external_syms]..., :(args = $(args_init)) + ] + + var_syms = all_syms.args[(begin + 1):end] reactant_code_block = quote - let args = $(args_init) + let $(locals...) cond_fn = $(all_syms) -> begin local num_iters = div($limit - $start, $step, RoundDown) @@ -170,11 +194,15 @@ function trace_for(mod, expr) end body_fn = $(all_syms) -> begin + local isdefined_before = isnothing.(Any[$(var_syms...)]) local step_ = $step local start_ = $start local $induction = start_ + $counter * step_ $body - ($counter + 1, $(all_syms.args[(begin + 1):end]...)) + local results_ = Any[ + s for (d, s) in zip(isdefined_before, Any[$(var_syms...)]) if !d + ] + ($counter + 1, results_...) end $(ReactantCore).traced_while(cond_fn, body_fn, args) @@ -182,7 +210,9 @@ function trace_for(mod, expr) end return quote - if any($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...))) + if $(within_compile)() && $(any)( + $(is_traced), $(Expr(:tuple, cond_val.(all_syms.args[(begin + 1):end])...)) + ) $(reactant_code_block) else $(expr) @@ -195,8 +225,12 @@ function trace_if_with_returns(mod, expr) new_expr, _, all_check_vars = trace_if( mod, expr.args[2]; store_last_line=expr.args[1], depth=1 ) + cond_name = first(all_check_vars) + original_cond = expr.args[2].args[1] + expr.args[2].args[1] = cond_name return quote - if any($(is_traced), ($(all_check_vars...),)) + $(cond_name) = $(original_cond) + if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),)) $(new_expr) else $(expr) @@ -292,7 +326,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0) non_existant_true_branch_vars = setdiff(all_output_vars, all_true_branch_vars) true_branch_extras = Expr( :block, - [:($(var) = $(MissingTracedValue())) for var in non_existant_true_branch_vars]..., + [:($(var) = $(MissingTracedValue)()) for var in non_existant_true_branch_vars]..., ) true_branch_fn = :(($(all_input_vars...),) -> begin @@ -310,7 +344,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0) ) false_branch_extras = Expr( :block, - [:($(var) = $(MissingTracedValue())) for var in non_existant_false_branch_vars]..., + [:($(var) = $(MissingTracedValue)()) for var in non_existant_false_branch_vars]..., ) false_branch_fn = :(($(all_input_vars...),) -> begin @@ -323,29 +357,69 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0) ) false_branch_fn = :($(false_branch_fn_name) = $(false_branch_fn)) + cond_name = gensym(:cond) + reactant_code_block = quote $(true_branch_fn) $(false_branch_fn) ($(all_output_vars...),) = $(traced_if)( - $(cond_expr), + $(cond_name), $(true_branch_fn_name), $(false_branch_fn_name), ($(all_input_vars...),), ) end - all_check_vars = [all_input_vars..., condition_vars...] + non_reactant_code_block = Expr(:if, cond_name, original_expr.args[2]) + if length(original_expr.args) > 2 # has else block + append!(non_reactant_code_block.args, original_expr.args[3:end]) + end + + all_check_vars = [cond_name, all_input_vars..., condition_vars...] unique!(all_check_vars) depth > 0 && return ( - reactant_code_block, (true_branch_fn_name, false_branch_fn_name), all_check_vars + quote + $(cond_name) = $(cond_expr) + $(reactant_code_block) + end, + (true_branch_fn_name, false_branch_fn_name), + all_check_vars, ) return quote - if any($(is_traced), ($(all_check_vars...),)) + $(cond_name) = $(cond_expr) + if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),)) $(reactant_code_block) else - $(original_expr) + $(non_reactant_code_block) + end + end +end + +function correct_maybe_bcast_call(fname) + startswith(string(fname), '.') || return false, fname, fname + return true, Symbol(string(fname)[2:end]), fname +end + +function trace_call(mod, call) + bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1]) + f = if bcast + quote + if isdefined(mod, $(Meta.quot(fname_full))) + $(fname_full) + else + Base.Broadcast.BroadcastFunction($(fname)) + end + end + else + :($(fname)) + end + return quote + if $(within_compile)() + $(traced_call)($f, $(call.args[2:end]...)) + else + $(call) end end end @@ -366,15 +440,21 @@ function traced_if(cond, true_fn, false_fn, args) return cond ? true_fn(args) : false_fn(args) end -function traced_while(cond_fn, body_fn, args) - while cond_fn(args...) - args = body_fn(args...) - end - return args -end +function traced_while end # defined inside Reactant.jl + +traced_call(f, args...; kwargs...) = f(args...; kwargs...) function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars) return MacroTools.postwalk(expr) do x + if Meta.isexpr(x, :kw) # undo lhs rewriting + if startswith(string(x.args[1]), string(prepend)) + return Expr( + :kw, + Symbol(string(x.args[1])[(length(string(prepend)) + 1):end]), + x.args[2], + ) + end + end if x isa Symbol && x ∈ all_vars return Symbol(prepend, x) end diff --git a/src/Compiler.jl b/src/Compiler.jl index bb912d0663..bd384a747b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1,13 +1,14 @@ module Compiler using Reactant_jll +using Libdl: dlsym import ..Reactant: Reactant, MLIR, XLA, - ConcreteRArray, - ConcreteRNumber, + ConcretePJRTArray, + ConcretePJRTNumber, TracedRArray, TracedRNumber, RArray, @@ -16,11 +17,21 @@ import ..Reactant: make_tracer, TracedToConcrete, append_path, + ancestor, TracedType -@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field) +import ..ReactantCore: correct_maybe_bcast_call + +@inline function traced_getfield(@nospecialize(obj::Dict), field) + return Base.getindex(obj, field) +end + +@inline function traced_getfield(@nospecialize(obj), field) + return Base.getfield(obj, field) +end + @inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T} - (isbitstype(T) || obj isa RArray) && return Base.getfield(obj, field) + (isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field) return Base.getindex(obj, field) end @@ -28,11 +39,18 @@ end @inline function traced_setfield!( @nospecialize(obj::AbstractArray{T}), field, val ) where {T} - (isbitstype(T) || obj isa RArray) && return Base.setfield!(obj, field, val) + ancestor_obj = ancestor(obj) + (isbitstype(T) || ancestor_obj isa RArray) && return Base.setfield!(obj, field, val) return Base.setindex!(obj, val, field) end -function create_result(tocopy::T, path, result_stores) where {T} +@inline function traced_setfield!(@nospecialize(obj::Dict), field, val) + return Base.setindex!(obj, field, val) +end + +function create_result( + tocopy::T, path, result_stores, path_to_shard_info, sharding_mesh +) where {T} if !isstructtype(typeof(tocopy)) error("cannot copy $tocopy of type $(Core.Typeof(tocopy))") end @@ -40,237 +58,411 @@ function create_result(tocopy::T, path, result_stores) where {T} elems = Union{Symbol,Expr}[] for i in 1:fieldcount(T) - ev = create_result(getfield(tocopy, i), append_path(path, i), result_stores) + # If the field is undefined we don't set it. A common example for this is `du2` + # for Tridiagonal + isdefined(tocopy, i) || continue + ev = create_result( + getfield(tocopy, i), + append_path(path, i), + result_stores, + path_to_shard_info, + sharding_mesh, + ) push!(elems, ev) end return Expr(:new, T, elems...) end -function create_result(tocopy::ConcreteRNumber{T}, path, result_stores) where {T} +function __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh, N::Integer) + device_to_array_slices, hlo_sharding = path_to_shard_info[path] + delete!(path_to_shard_info, path) + sharding = Reactant.Sharding.HloSharding( + hlo_sharding, sharding_mesh, ntuple(Returns(true), N), ntuple(Returns(-1), N) + ) + return Reactant.Sharding.ShardInfo(sharding, device_to_array_slices) +end + +function create_result( + tocopy::ConcretePJRTNumber{T,D,S}, + path, + result_stores, + path_to_shard_info, + sharding_mesh, +) where {T,D,S} if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - return :(ConcreteRNumber{$T}($restore)) + if path_to_shard_info !== nothing # restore sharding + sharding = __reconstruct_shardinfo( + path, path_to_shard_info, sharding_mesh, ndims(tocopy) + ) + return :(ConcretePJRTNumber{$T,length($(restore)),$(typeof(sharding))}( + ($(restore)...,), $sharding + )) + else + return :(ConcretePJRTNumber{$T}($restore)) + end + end + + if path_to_shard_info !== nothing # restore sharding + sharding = __reconstruct_shardinfo( + path, path_to_shard_info, sharding_mesh, ndims(tocopy) + ) + return :(ConcretePJRTNumber{$T,length($(tocopy.data)),$(typeof(sharding))}( + ($(tocopy.data...,)), $sharding + )) end # We will set the data for this later - return :(ConcreteRNumber{$T}($(tocopy.data))) + return :(ConcretePJRTNumber{$T}($(tocopy.data))) end -function create_result(tocopy::ConcreteRArray{T,N}, path, result_stores) where {T,N} +function create_result( + tocopy::ConcretePJRTArray{T,N,D,S}, + path, + result_stores, + path_to_shard_info, + sharding_mesh, +) where {T,N,D,S} if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - return :(ConcreteRArray{$T,$N}($restore, $(tocopy.shape))) + if path_to_shard_info !== nothing # restore sharding + sharding = __reconstruct_shardinfo( + path, path_to_shard_info, sharding_mesh, ndims(tocopy) + ) + return :(ConcretePJRTArray{$T,$N,length($(restore)),$(typeof(sharding))}( + ($(restore)...,), $(tocopy.shape), $sharding + )) + else + return :(ConcretePJRTArray{$T,$N}($restore, $(tocopy.shape))) + end + end + + if path_to_shard_info !== nothing # restore sharding + sharding = __reconstruct_shardinfo( + path, path_to_shard_info, sharding_mesh, ndims(tocopy) + ) + return :(ConcretePJRTArray{$T,$N,length($(tocopy.data)),$(typeof(sharding))}( + ($(tocopy.data)...,), $(tocopy.shape), $sharding + )) end # We will set the data for this later - return :(ConcreteRArray{$T,$N}($(tocopy.data), $(tocopy.shape))) + return :(ConcretePJRTArray{$T,$N,$D,$S}( + $(tocopy.data), $(tocopy.shape), $(tocopy.sharding) + )) end -function create_result(tocopy::Array{T,N}, path, result_stores) where {T,N} +function create_result( + tocopy::Array{T,N}, path, result_stores, path_to_shard_info, sharding_mesh +) where {T,N} elems = Expr[] for (i, v) in enumerate(tocopy) - push!(elems, create_result(v, append_path(path, i), result_stores)) + push!( + elems, + create_result( + v, append_path(path, i), result_stores, path_to_shard_info, sharding_mesh + ), + ) end # TODO is there a way to not call `reshape` here? what expr is used for array literals? return :(reshape($T[$(elems...)], $(size(tocopy))...)) end -function create_result(tocopy::Tuple, path, result_stores) +function create_result( + tocopy::Tuple, path, result_stores, path_to_shard_info, sharding_mesh +) elems = Union{Symbol,Expr}[] for (k, v) in pairs(tocopy) - push!(elems, create_result(v, append_path(path, k), result_stores)) + push!( + elems, + create_result( + v, append_path(path, k), result_stores, path_to_shard_info, sharding_mesh + ), + ) end return :(($(elems...),)) end -function create_result(tocopy::NamedTuple{K,T}, path, result_stores) where {K,T} +function create_result( + tocopy::NamedTuple{K,T}, path, result_stores, path_to_shard_info, sharding_mesh +) where {K,T} elems = Union{Symbol,Expr}[] for (i, (k, v)) in enumerate(pairs(tocopy)) - push!(elems, create_result(v, append_path(path, i), result_stores)) + push!( + elems, + create_result( + v, append_path(path, i), result_stores, path_to_shard_info, sharding_mesh + ), + ) end return :(NamedTuple{$K}(($(elems...),))) end -function create_result(tocopy::D, path, result_stores) where {K,V,D<:AbstractDict{K,V}} +function create_result( + tocopy::D, path, result_stores, path_to_shard_info, sharding_mesh +) where {K,V,D<:AbstractDict{K,V}} elems = Expr[] for (i, p) in enumerate(pairs(tocopy)) - push!(elems, create_result(p, append_path(path, i), result_stores)) + push!( + elems, + create_result( + p, append_path(path, i), result_stores, path_to_shard_info, sharding_mesh + ), + ) end return :($D([$(elems...)])) end function create_result( - tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol}, + tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char}, path, result_stores, + path_to_shard_info, + sharding_mesh, ) return Meta.quot(tocopy) end -const opt_passes::String = join( - [ - "inline{default-pipeline=canonicalize max-iterations=4}", - "canonicalize,cse", - "canonicalize", - "enzyme-hlo-generate-td{" * - join( - [ - "patterns=compare_op_canon<16>", - "transpose_transpose<16>", - "broadcast_in_dim_op_canon<16>", - "convert_op_canon<16>", - "dynamic_broadcast_in_dim_op_not_actually_dynamic<16>", - "chained_dynamic_broadcast_in_dim_canonicalization<16>", - "dynamic_broadcast_in_dim_all_dims_non_expanding<16>", - "noop_reduce_op_canon<16>", - "empty_reduce_op_canon<16>", - "dynamic_reshape_op_canon<16>", - "get_tuple_element_op_canon<16>", - "real_op_canon<16>", - "imag_op_canon<16>", - "conj_complex_negate<16>", - "get_dimension_size_op_canon<16>", - "gather_op_canon<16>", - "reshape_op_canon<16>", - "merge_consecutive_reshapes<16>", - "transpose_is_reshape<16>", - "zero_extent_tensor_canon<16>", - "reorder_elementwise_and_shape_op<16>", - "cse_broadcast_in_dim<16>", - "cse_slice<16>", - "cse_transpose<16>", - "cse_convert<16>", - "cse_pad<16>", - "cse_dot_general<16>", - "cse_reshape<16>", - "cse_mul<16>", - "cse_div<16>", - "cse_add<16>", - "cse_subtract<16>", - "cse_min<16>", - "cse_max<16>", - "cse_neg<16>", - "cse_concatenate<16>", - "concatenate_op_canon<16>(1024)", - "select_op_canon<16>(1024)", - "add_simplify<16>", - "sub_simplify<16>", - "and_simplify<16>", - "max_simplify<16>", - "min_simplify<16>", - "or_simplify<16>", - "negate_simplify<16>", - "mul_simplify<16>", - "div_simplify<16>", - "rem_simplify<16>", - "pow_simplify<16>", - "sqrt_simplify<16>", - "cos_simplify<16>", - "sin_simplify<16>", - "noop_slice<16>", - "noop_reverse<16>", - "const_prop_through_barrier<16>", - "slice_slice<16>", - "shift_right_logical_simplify<16>", - "pad_simplify<16>", - "negative_pad_to_slice<16>", - "tanh_simplify<16>", - "exp_simplify<16>", - "slice_simplify<16>", - "convert_simplify<16>", - "dynamic_slice_to_static<16>", - "dynamic_update_slice_elim<16>", - "concat_to_broadcast<16>", - "reduce_to_reshape<16>", - "broadcast_to_reshape<16>", - "gather_simplify<16>", - "iota_simplify<16>(1024)", - "broadcast_in_dim_simplify<16>(1024)", - "convert_concat<1>", - "dynamic_update_to_concat<1>", - "slice_of_dynamic_update<1>", - "slice_elementwise<1>", - "slice_pad<1>", - "dot_reshape_dot<1>", - "concat_const_prop<1>", - "concat_fuse<1>", - "pad_reshape_pad<1>", - "pad_pad<1>", - "concat_push_binop_add<1>", - "concat_push_binop_mul<1>", - "scatter_to_dynamic_update_slice<1>", - "reduce_concat<1>", - "slice_concat<1>", - "concat_slice<1>", - "bin_broadcast_splat_add<1>", - "bin_broadcast_splat_subtract<1>", - "bin_broadcast_splat_div<1>", - "bin_broadcast_splat_mul<1>", - "reshape_iota<16>", - "slice_reshape_slice<1>", - "dot_general_simplify<16>", - "transpose_simplify<16>", - "reshape_empty_broadcast<1>", - "add_pad_pad_to_concat<1>", - "broadcast_reshape<1>", - "slice_reshape_concat<1>", - "slice_reshape_elementwise<1>", - "slice_reshape_transpose<1>", - "slice_reshape_dot_general<1>", - "concat_pad<1>", - "reduce_pad<1>", - "broadcast_pad<1>", - "zero_product_reshape_pad<1>", - "mul_zero_pad<1>", - "div_zero_pad<1>", - "binop_const_reshape_pad<1>", - "binop_const_pad_add<1>", - "binop_const_pad_subtract<1>", - "binop_const_pad_mul<1>", - "binop_const_pad_div<1>", - "slice_reshape_pad<1>", - "binop_binop_pad_pad_add<1>", - "binop_binop_pad_pad_mul<1>", - "binop_pad_pad_add<1>", - "binop_pad_pad_subtract<1>", - "binop_pad_pad_mul<1>", - "binop_pad_pad_div<1>", - "binop_pad_pad_min<1>", - "binop_pad_pad_max<1>", - "unary_pad_push_convert<1>", - "unary_pad_push_tanh<1>", - "unary_pad_push_exp<1>", - "transpose_pad<1>", - "transpose_dot_reorder<1>", - "dot_transpose<1>", - "transpose_einsum<1>", - "einsum_transpose<1>", - "transpose_convolution<1>", - "convolution_transpose<1>", - "convert_convert_float<1>", - "concat_to_pad<1>", - "concat_appending_reshape<1>", - "reshape_iota<1>", - "broadcast_reduce<1>", - "slice_dot_general<1>", - "dot_reshape_pad<1>", - "pad_dot_general<1>(0)", - "dot_reshape_pad<1>", - "pad_dot_general<1>(1)", - "if_inline<1>", - "if_to_select<1>", - "dynamic_update_slice_const_prop", - "dynamic_gather_op_is_not_dynamic<16>", - ], - ';', - ) * - "}", - "transform-interpreter", - "enzyme-hlo-remove-transform", - ], - ',', -) +# Optimization passes via transform dialect +function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Bool=true) + transform_passes_list = [ + "patterns=compare_op_canon<16>", + "transpose_transpose<16>", + "broadcast_in_dim_op_canon<16>", + "convert_op_canon<16>", + "dynamic_broadcast_in_dim_op_not_actually_dynamic<16>", + "chained_dynamic_broadcast_in_dim_canonicalization<16>", + "dynamic_broadcast_in_dim_all_dims_non_expanding<16>", + "noop_reduce_op_canon<16>", + "empty_reduce_op_canon<16>", + "dynamic_reshape_op_canon<16>", + "get_tuple_element_op_canon<16>", + "real_op_canon<16>", + "imag_op_canon<16>", + "conj_complex_negate<16>", + "get_dimension_size_op_canon<16>", + "gather_op_canon<16>", + "reshape_op_canon<16>", + "merge_consecutive_reshapes<16>", + "transpose_is_reshape<16>", + "zero_extent_tensor_canon<16>", + "reorder_elementwise_and_shape_op<16>", + "chlo_inf_const_prop<16>", + "gamma_const_prop<16>", + "cse_broadcast_in_dim<16>", + "cse_slice<16>", + "cse_transpose<16>", + "cse_convert<16>", + "cse_pad<16>", + "cse_dot_general<16>", + "cse_reshape<16>", + "cse_mul<16>", + "cse_div<16>", + "cse_add<16>", + "cse_subtract<16>", + "cse_min<16>", + "cse_max<16>", + "cse_neg<16>", + "cse_concatenate<16>", + "concatenate_op_canon<16>(1024)", + "select_op_canon<16>(1024)", + "add_simplify<16>", + "sub_simplify<16>", + "and_simplify<16>", + "max_simplify<16>", + "min_simplify<16>", + "or_simplify<16>", + "negate_simplify<16>", + "mul_simplify<16>", + "div_simplify<16>", + "rem_simplify<16>", + "pow_simplify<16>", + "sqrt_simplify<16>", + "cos_simplify<16>", + "sin_simplify<16>", + "noop_slice<16>", + "noop_reverse<16>", + "const_prop_through_barrier<16>", + "slice_slice<16>", + "shift_right_logical_simplify<16>", + "pad_simplify<16>", + "negative_pad_to_slice<16>", + "tanh_simplify<16>", + "exp_simplify<16>", + "slice_simplify<16>", + "convert_simplify<16>", + "dynamic_slice_to_static<16>", + "dynamic_update_slice_elim<16>", + "concat_to_broadcast<16>", + "reduce_to_reshape<16>", + "broadcast_to_reshape<16>", + "gather_simplify<16>", + "iota_simplify<16>(1024)", + "broadcast_in_dim_simplify<16>(1024)", + "convert_concat<1>", + "dynamic_update_to_concat<1>", + "slice_of_dynamic_update<1>", + "slice_elementwise<1>", + "slice_pad<1>", + "dot_reshape_dot<1>", + "concat_const_prop<1>", + "concat_fuse<1>", + "pad_reshape_pad<1>", + "pad_pad<1>", + "concat_push_binop_add<1>", + "concat_push_binop_mul<1>", + "scatter_to_dynamic_update_slice<1>", + "reduce_concat<1>", + "slice_concat<1>", + "concat_slice<1>", + "select_op_used_within_if<1>", + "bin_broadcast_splat_add<1>", + "bin_broadcast_splat_subtract<1>", + "bin_broadcast_splat_div<1>", + "bin_broadcast_splat_mul<1>", + "reshape_iota<16>", + "slice_reshape_slice<1>", + "dot_general_simplify<16>", + "transpose_simplify<16>", + "reshape_empty_broadcast<1>", + "add_pad_pad_to_concat<1>", + "broadcast_reshape<1>", + "slice_reshape_concat<1>", + "slice_reshape_elementwise<1>", + "slice_reshape_transpose<1>", + "slice_reshape_dot_general<1>", + "concat_pad<1>", + "reduce_pad<1>", + "broadcast_pad<1>", + "zero_product_reshape_pad<1>", + "mul_zero_pad<1>", + "div_zero_pad<1>", + "binop_const_reshape_pad<1>", + "binop_const_pad_add<1>", + "binop_const_pad_subtract<1>", + "binop_const_pad_mul<1>", + "binop_const_pad_div<1>", + "slice_reshape_pad<1>", + "binop_binop_pad_pad_add<1>", + "binop_binop_pad_pad_mul<1>", + "binop_pad_pad_add<1>", + "binop_pad_pad_subtract<1>", + "binop_pad_pad_mul<1>", + "binop_pad_pad_div<1>", + "binop_pad_pad_min<1>", + "binop_pad_pad_max<1>", + "unary_pad_push_convert<1>", + "unary_pad_push_tanh<1>", + "unary_pad_push_exp<1>", + "transpose_pad<1>", + "transpose_dot_reorder<1>", + "dot_transpose<1>", + "transpose_einsum<1>", + "einsum_transpose<1>", + "transpose_convolution<1>", + "convolution_transpose<1>", + "convert_convert_float<1>", + "concat_to_pad<1>", + "concat_appending_reshape<1>", + "reshape_iota<1>", + "broadcast_reduce<1>", + "slice_dot_general<1>", + "dot_reshape_pad<1>", + "pad_dot_general<1>(0)", + "dot_reshape_pad<1>", + "pad_dot_general<1>(1)", + "if_inline<1>", + "if_to_select<1>", + "dynamic_update_slice_const_prop", + "dynamic_gather_op_is_not_dynamic<16>", + "divide_sqrt_to_multiply_rsqrt<16>", + "binary_op_transpose_simplify_add", + "binary_op_transpose_simplify_sub", + "binary_op_transpose_simplify_mul", + "binary_op_transpose_simplify_div", + "binary_op_transpose_simplify_min", + "binary_op_transpose_simplify_max", + "binary_op_transpose_simplify_pow", + "binary_op_transpose_simplify_rem", + "binary_op_transpose_simplify_or", + "binary_op_transpose_simplify_and", + "binary_op_transpose_simplify_xor", + "associative_binary_op_reordering<1>", + "transpose_unary_transpose_abs", + "transpose_unary_transpose_neg", + "transpose_unary_transpose_sqrt", + "transpose_unary_transpose_rsqrt", + "transpose_unary_transpose_ceil", + "transpose_unary_transpose_convert", + "transpose_unary_transpose_cosine", + "transpose_unary_transpose_exp", + "transpose_unary_transpose_expm1", + "transpose_unary_transpose_log", + "transpose_unary_transpose_log1p", + "transpose_unary_transpose_sign", + "transpose_unary_transpose_sine", + "transpose_unary_transpose_tanh", + "transpose_broadcast_in_dim_to_broadcast_in_dim<16>", + "scatter_indices_are_unique", + "transpose_reduce_simplify", + "replace_neg_add_with_subtract", + "log_const_prop<1>", + "log_plus_one_const_prop<1>", + "binop_const_simplify", + "transpose_broadcast_in_dim_to_broadcast_in_dim", + "not_select_simplify", + "scatter_update_computation_const_prop", + "common_compare_expression_rewrite", + "compare_select_simplify", + "while_simplify<1>", + "scatter_update_computation_const_prop", + "if_remove_unused", + ] + if no_nan + append!( + transform_passes_list, + ["no_nan", "no_nan_self_sub_simplify", "no_nan_add_sub_simplify(1)"], + ) + else + push!(transform_passes_list, "no_nan_add_sub_simplify(0)") + end + transform_passes = join( + [ + "enzyme-hlo-generate-td{" * join(transform_passes_list, ';') * "}", + "transform-interpreter", + "enzyme-hlo-remove-transform", + ], + ",", + ) + func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") + passes = String[] + if inline + push!(passes, "inline{default-pipeline=canonicalize max-iterations=4}") + end + if sroa + push!(passes, "propagate-constant-bounds") + if DUMP_LLVMIR[] + push!( + passes, + "sroa-wrappers{dump_prellvm=true dump_postllvm=true instcombine=false instsimplify=true}", + ) + else + push!(passes, "sroa-wrappers{instcombine=false instsimplify=true}") + end + push!(passes, "canonicalize") + push!(passes, "sroa-wrappers{instcombine=false instsimplify=true}") + push!(passes, "libdevice-funcs-raise") + push!(passes, "canonicalize") + push!(passes, "remove-duplicate-func-def") + end + push!(passes, func_passes) + return join(passes, ',') +end + +# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate +# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. +const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true) pm = MLIR.IR.PassManager() @@ -281,40 +473,197 @@ function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true) return mod end +const context_gc_vector = Dict{MLIR.IR.Context,Vector{TracedRArray}}() + # helper for debug purposes: String -> Text function run_pass_pipeline_on_source(source, pass_pipeline; enable_verifier=true) ctx = MLIR.IR.Context(Reactant.registry[], false) + context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - MLIR.IR.context!(ctx) do + result = MLIR.IR.context!(ctx) do mod = parse(MLIR.IR.Module, source) run_pass_pipeline!(mod, pass_pipeline; enable_verifier) MLIR.IR.verifyall(MLIR.IR.Operation(mod); debug=true) Text(repr(mod)) end + Base.delete!(context_gc_vector, ctx) + return result end -function compile_mlir(f, args; kwargs...) +function compile_mlir(f, args; client=nothing, kwargs...) ctx = MLIR.IR.Context(Reactant.registry[], false) + context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - MLIR.IR.context!(ctx) do + + backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend()) + + if backend == "CUDA" + backend = "GPU" + elseif backend == "CPU" + backend = "cpu" + end + + results = MLIR.IR.context!(ctx) do mod = MLIR.IR.Module(MLIR.IR.Location()) - evalinfo = compile_mlir!(mod, f, args; kwargs...) - return mod, evalinfo... + + mlir_fn_res = compile_mlir!(mod, f, args; backend, kwargs...) + + # Attach a name, and partitioning attributes to the module + __add_mhlo_attributes_and_name!( + mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas + ) + + return mod, mlir_fn_res + end + Base.delete!(context_gc_vector, ctx) + + return results +end + +const PartitionKA = Ref{Bool}(true) + +const cubinChip = Ref{String}("sm_60") +const cubinFormat = Ref{String}("bin") +const cuindexBitWidth = Ref{Int}(32) +const cuOptLevel = Ref{Int}(2) +# Wgatever the relevant highest version from our LLVM is within NVPTX.td +# Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684 +# We see relevant ptx version is CUDA 12.6 -> 85 +# 12.2 -> 82 +# 11.8 -> 78 +function cubinFeatures() + ver = @ccall MLIR.API.mlir_c.ReactantCudaDriverGetVersion()::UInt32 + # No cuda available + if ver == 0 + return "+ptx86" end + ver2 = @ccall MLIR.API.mlir_c.ReactantHermeticCudaGetVersion()::UInt32 + ver = min(ver, ver2) + major, ver = divrem(ver, 1000) + minor, patch = divrem(ver, 10) + version = VersionNumber(major, minor, patch) + # From https://github.com/llvm/llvm-project/blob/106c483a102e1328f11e2b1d9398f4ad2826b59f/clang/lib/Driver/ToolChains/Cuda.cpp#L685 + cuver_map = Dict([ + (126, 85), + (125, 85), + (124, 84), + (123, 83), + (122, 82), + (121, 81), + (120, 80), + (118, 78), + (117, 77), + (116, 76), + (115, 75), + (114, 74), + (113, 73), + (112, 72), + (111, 71), + (110, 70), + (102, 65), + (101, 64), + (100, 63), + (92, 61), + (91, 61), + (90, 60), + ]) + mver = major * 10 + minor + if mver > 126 + return 86 + end + ptx = cuver_map[mver] + return "+ptx$ptx" end -const cuLaunch = Ref{UInt}(0) -const cuFunc = Ref{UInt}(0) -const cuModule = Ref{UInt}(0) +const DEBUG_KERNEL = Ref{Bool}(false) +const DUMP_LLVMIR = Ref{Bool}(false) -function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) - fnwrapped, - func2, traced_result, result, seen_args, ret, linear_args, in_tys, - linear_results = MLIR.IR.mmodule!(mod) do - MLIR.IR.block!(MLIR.IR.body(mod)) do - return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) - end +function activate_raising!(is_raising::Bool) + stack = get!(task_local_storage(), :reactant_is_raising) do + Bool[] end + push!(stack, is_raising) + return nothing +end + +function deactivate_raising!(is_raising::Bool) + key = :reactant_is_raising + is_raising === last(task_local_storage(key)) || + error("Deactivating wrong Reactant raising context") + return pop!(task_local_storage(key)) +end + +function raising(; throw_error::Bool=true) + key = :reactant_is_raising + if !(haskey(task_local_storage(), key) && !Base.isempty(task_local_storage(key))) + throw_error && error("No Reactant raising context") + end + return last(task_local_storage(key)::Vector{Bool}) +end + +function raising!(f, is_raising::Bool) + activate_raising!(is_raising) + try + return f() + finally + deactivate_raising!(is_raising) + end +end + +function compile_mlir!( + mod, + f, + args, + callcache=Dict{ + Vector, + @NamedTuple{ + f_name::String, + mlir_result_types::Vector{MLIR.IR.Type}, + traced_result::Any, + mutated_args::Vector{Int}, + } + }(), + sdycache=IdDict{ + Reactant.Sharding.Mesh, + @NamedTuple{ + sym_name::MLIR.IR.Attribute, + mesh_attr::MLIR.IR.Attribute, + mesh_op::MLIR.IR.Operation, + } + }(); + optimize::Union{Bool,Symbol}=true, + no_nan::Bool=false, + backend="gpu", + fn_kwargs=(), + raise::Union{Bool,String}=false, + input_shardings=nothing, +) + # Explicitly don't use block! to avoid creating a closure, which creates + # both compile-time and relocatability issues + + MLIR.IR.activate!(mod) + MLIR.IR.activate!(MLIR.IR.body(mod)) + activate_callcache!(callcache) + activate_sdycache!(sdycache) + + # Save in the TLS whether we are raising. We identify that condition by + # checking whether the user set an explicit list of passes, or chose + # `raise=true` to use the default passes. + is_raising = raise isa String || raise + activate_raising!(is_raising) + + mlir_fn_res = try + Reactant.TracedUtils.make_mlir_fn(f, args, fn_kwargs, "main", true; input_shardings) + finally + deactivate_raising!(is_raising) + deactivate_sdycache!(sdycache) + deactivate_callcache!(callcache) + MLIR.IR.deactivate!(MLIR.IR.body(mod)) + MLIR.IR.deactivate!(mod) + end + (; fnwrapped, traced_result, seen_args, ret, linear_args, in_tys, linear_results) = + mlir_fn_res + compiled_f = mlir_fn_res.f concrete_seen = OrderedIdDict() @@ -328,10 +677,50 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) if isdefined(Reactant_jll, :ptxas_path) toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] end - kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}" + + if backend == "cpu" + kern = "lower-kernel{backend=cpu},canonicalize" + jit = "lower-jit{openmp=true backend=cpu},symbol-dce" + elseif DEBUG_KERNEL[] + curesulthandler = dlsym( + Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" + ) + @assert curesulthandler !== nothing + curesulthandler = Base.reinterpret(UInt, curesulthandler) + kern = if is_raising + "lower-kernel{backend=cpu},symbol-dce,canonicalize" + else + "lower-kernel,canonicalize" + end + jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" + else + kern = if is_raising + "lower-kernel{backend=cpu},symbol-dce,canonicalize" + else + "lower-kernel,canonicalize" + end + jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" + end + + opt_passes = optimization_passes(; no_nan, sroa=true) + opt_passes2 = optimization_passes(; no_nan, sroa=false) + + raise_passes = if raise isa String + # Raising passes were specified + raise + elseif raise + # Raise enabled but use default passes + "canonicalize,llvm-to-memref-access,canonicalize,convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,canonicalize,func.func(canonicalize-loops),canonicalize-scf-for,canonicalize,affine-cfg,canonicalize,func.func(canonicalize-loops),canonicalize,llvm-to-affine-access,canonicalize,delinearize-indexing,canonicalize,simplify-affine-exprs,affine-cfg,canonicalize,raise-affine-to-stablehlo,arith-raise{stablehlo=true}," * + opt_passes2 + else + "canonicalize" + end + if optimize === :all - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -339,15 +728,19 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, kern, + raise_passes, + jit, ], ',', ), ) elseif optimize === :before_kernel - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -355,13 +748,50 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, + ], + ',', + ), + ) + elseif optimize === :before_jit + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) + run_pass_pipeline!( + mod, + join( + [ + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + kern, + raise_passes, + ], + ',', + ), + ) + elseif optimize === :before_raise + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) + run_pass_pipeline!( + mod, + join( + [ + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + kern, ], ',', ), ) elseif optimize === :no_enzyme - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( mod, @@ -370,14 +800,16 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, ], ',', ), ) elseif optimize === :only_enzyme run_pass_pipeline!(mod, "enzyme-batch") - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -387,7 +819,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :after_enzyme run_pass_pipeline!(mod, "enzyme-batch") - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -395,18 +829,35 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, kern, + raise_passes, + jit, ], ',', ), ) elseif optimize === :before_enzyme - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( - mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern + mod, + join( + [ + "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math", + kern, + raise_passes, + jit, + ], + ',', + ), ) + elseif optimize === :canonicalize + run_pass_pipeline!(mod, "canonicalize") + elseif optimize === :just_batch + run_pass_pipeline!(mod, "enzyme-batch") elseif optimize !== :none error("Invalid optimize option: $(Meta.quot(optimize))") end @@ -415,14 +866,17 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)] nresults = MLIR.IR.Value[] linear_results2 = TracedType[] + results_mask = falses(length(results)) for (i, op) in enumerate(results) if !MLIR.IR.is_block_arg(op) push!(nresults, op) push!(linear_results2, linear_results[i]) + results_mask[i] = true continue end push!(preserved_args, (linear_results[i], MLIR.IR.block_arg_num(op))) end + fnbody = MLIR.IR.block(ret) MLIR.API.mlirOperationDestroy(ret.operation) ret.operation = MLIR.API.MlirOperation(C_NULL) @@ -432,50 +886,147 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) out_tys2 = [MLIR.IR.type(a) for a in nresults] + res_attrs = MLIR.IR.attr(compiled_f, "res_attrs") + if res_attrs isa MLIR.IR.Attribute + res_attrs = MLIR.IR.Attribute[ + res_attrs[i - 1] for (i, present) in enumerate(results_mask) if present + ] + end + func3 = MLIR.Dialects.func.func_(; sym_name="main", function_type=MLIR.IR.FunctionType(in_tys, out_tys2), + arg_attrs=MLIR.IR.attr(compiled_f, "arg_attrs"), + res_attrs, + no_inline=MLIR.IR.attr(compiled_f, "no_inline"), body=MLIR.IR.Region(), ) - MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func3, 1), MLIR.IR.region(func2, 1)) + MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func3, 1), MLIR.IR.region(compiled_f, 1)) push!(MLIR.IR.body(mod), func3) - MLIR.API.mlirOperationDestroy(func2.operation) - func2.operation = MLIR.API.MlirOperation(C_NULL) + MLIR.API.mlirOperationDestroy(compiled_f.operation) + compiled_f.operation = MLIR.API.MlirOperation(C_NULL) - return linear_args, - linear_results2, preserved_args, seen_args, concrete_result, - fnwrapped + return Reactant.TracedUtils.CompiledMlirFnResult( + fnwrapped, + func3, + traced_result, + mlir_fn_res.result, + seen_args, + ret, + linear_args, + in_tys, + linear_results2, + mlir_fn_res.num_partitions, + mlir_fn_res.num_replicas, + mlir_fn_res.is_sharded, + preserved_args, + concrete_result, + mlir_fn_res.sharding_mesh, + mlir_fn_res.mutated_args, + ) end """ - @code_hlo [optimize = ...] f(args...) + @code_hlo [optimize = ...] [no_nan = ] f(args...) + +See also [`@code_xla`](@ref), [`@code_mhlo`](@ref). """ macro code_hlo(args...) - default_options = Dict{Symbol,Any}(:optimize => true) + default_options = Dict{Symbol,Any}( + :optimize => true, :no_nan => false, :client => nothing, :raise => false + ) compile_expr, (; compiled) = compile_call_expr( __module__, compile_mlir, default_options, args... ) - return esc(:($(compile_expr); - $(first)($(compiled)))) + #! format: off + return esc( + :( + $(compile_expr); + $(first)($(compiled)) + ) + ) + #! format: on +end + +""" + @code_mhlo [optimize = ...] [no_nan = ] f(args...) + +Similar to `@code_hlo`, but prints the module after running the XLA compiler. + +See also [`@code_xla`](@ref), [`@code_hlo`](@ref). +""" +macro code_mhlo(args...) + default_options = Dict{Symbol,Any}( + :optimize => true, :no_nan => false, :client => nothing, :raise => false + ) + compile_expr, (; compiled) = compile_call_expr( + __module__, compile_xla, default_options, args... + ) + #! format: off + return esc( + :( + $(compile_expr); + $(first)($(compiled)) + ) + ) + #! format: on end """ - @compile f(args...) + @code_xla [optimize = ...] [no_nan = ] f(args...) + +Similar to `@code_hlo`, but prints the HLO module. + +See also [`@code_mhlo`](@ref), [`@code_hlo`](@ref). +""" +macro code_xla(args...) + default_options = Dict{Symbol,Any}( + :optimize => true, :no_nan => false, :client => nothing, :raise => false + ) + compile_expr, (; compiled) = compile_call_expr( + __module__, compile_xla, default_options, args... + ) + #! format: off + return esc( + :( + $(compile_expr); + exec = $(compiled)[2]; + hlo_modules = $(XLA.get_hlo_modules)(exec); + length(hlo_modules) == 1 ? only(hlo_modules) : hlo_modules + ) + ) + #! format: on +end + +""" + @compile [optimize = ...] [no_nan = ] [sync = ] f(args...) """ macro compile(args...) - default_options = Dict{Symbol,Any}(:optimize => true, :sync => false) + default_options = Dict{Symbol,Any}( + :optimize => true, + :sync => false, + :no_nan => false, + :client => nothing, + :raise => false, + ) return esc(first(compile_call_expr(__module__, compile, default_options, args...))) end """ - @jit f(args...) + @jit [optimize = ...] [no_nan = ] [sync = ] f(args...) - Run @compile f(args..) then immediately execute it +Run @compile f(args..) then immediately execute it """ macro jit(args...) - default_options = Dict{Symbol,Any}(:optimize => true, :sync => false) + default_options = Dict{Symbol,Any}( + :optimize => true, + :sync => false, + :no_nan => false, + :client => nothing, + :raise => false, + ) compile_expr, (; compiled, args) = compile_call_expr( __module__, compile, default_options, args... ) @@ -489,20 +1040,21 @@ macro jit(args...) #! format: on end -function compile_call_expr(mod, compiler, options, args...) +function compile_call_expr(mod, compiler, options::Dict, args...) while length(args) > 1 option, args = args[1], args[2:end] if !Meta.isexpr(option, :(=)) error("Invalid option $(option)") else option_name = option.args[1] - @assert haskey(options, option_name) "Invalid option $(option_name)" + @assert haskey(options, option_name) "Invalid option name '$(option_name)'. Valid options are $(join(keys(options), ", "))" options[option_name] = option.args[2] end end call = only(args) f_symbol = gensym(:f) args_symbol = gensym(:args) + kwargs_symbol = gensym(:kwargs) compiled_symbol = gensym(:compiled) if Meta.isexpr(call, :call) @@ -518,10 +1070,24 @@ function compile_call_expr(mod, compiler, options, args...) else :($(fname)) end - args_rhs = Expr(:tuple, call.args[2:end]...) + args_rhs = call.args[2:end] + + # if (;) is used, we need to extract the kwargs + if length(args_rhs) ≥ 1 && Meta.isexpr(args_rhs[1], :parameters) + kwargs_rhs = args_rhs[1].args + args_rhs = args_rhs[2:end] + else + kwargs_rhs = () + end + kw_idxs = findall(Base.Fix2(Meta.isexpr, :kw), args_rhs) + arg_idxs = setdiff(1:length(args_rhs), kw_idxs) + + kwargs_rhs = (kwargs_rhs..., args_rhs[kw_idxs]...) + args_rhs = Expr(:tuple, args_rhs[arg_idxs]...) elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple) fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1]))) args_rhs = only(call.args[2:end]) + kwargs_rhs = () else error("Invalid function call: $(call)") end @@ -529,18 +1095,17 @@ function compile_call_expr(mod, compiler, options, args...) return quote $(f_symbol) = $(fname) $(args_symbol) = $(args_rhs) + $(kwargs_symbol) = (; $(kwargs_rhs...)) $(compiled_symbol) = $(compiler)( - $(f_symbol), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...) + $(f_symbol), + $(args_symbol); + fn_kwargs=$(kwargs_symbol), + $(Expr.(:kw, keys(options), values(options))...), ) end, (; compiled=compiled_symbol, args=args_symbol) end -function correct_maybe_bcast_call(fname) - startswith(string(fname), '.') || return false, fname, fname - return true, Symbol(string(fname)[2:end]), fname -end - """ codegen_flatten! @@ -560,60 +1125,101 @@ The name is due to its similarity to the `flatten` function in `jax.tree_util.re The _linearized arguments_ do not directly refer to the are the arguments that have been flattened into a single list. """ -function codegen_flatten!(linear_args, result_stores) +function codegen_flatten!( + linear_args, + seen_args, + result_stores, + is_sharded::Bool, + mesh, + linear_parameter_shardings, + client, +) flatten_names = Symbol[] flatten_code = Expr[] - # resarg_code = Expr[] + + if is_sharded + inv_seen_args = Reactant.OrderedIdDict() + for (k, v) in seen_args + inv_seen_args[v] = k + end + end for (i, arg) in enumerate(linear_args) - paths = ((p for p in arg.paths if p[1] == :args)...,) + paths = ( + ( + p for + p in Reactant.TracedUtils.get_paths(arg) if length(p) > 0 && p[1] == :args + )..., + ) path = if length(paths) == 1 paths[1] else - throw("Invalid path duplication $(arg.paths) into $(paths)") + throw( + "Invalid path duplication $(Reactant.TracedUtils.get_paths(arg)) into $(paths)", + ) end usbuf = Symbol(:usbuf_, i) - sbuf = Symbol(:sbuf_, i) - push!(flatten_names, sbuf) flatcode = :(getindex(args, $(path[2]))) for p in path[3:end] flatcode = :(traced_getfield($flatcode, $(Meta.quot(p)))) end - push!(flatten_code, :($usbuf = $flatcode.data)) - push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf))) - - # TODO: unused for the time being - # respaths = ((p for p in arg.paths if p[1] == :result || p[1] == :resargs)...,) - - # resarg = false - # for respath in respaths - # if respath[1] == :result - # flatcode = :result - # respath = respath[2:end] - # result_stores[respath] = usbuf - # resarg = true - # else - # @assert respath[1] == :resargs - # if respath[2] != path[2] - # continue - # end - # # flatcode = :(args[$(respath[2])]) - # path = path[3:end] - # end - # # for p in path - # # flatcode = :(traced_getfield($flatcode, $(Meta.quot(p)))) - # # end - # # resarg = true - # # flatcode = :($flatcode.data = $usbuf) - # # @show flatcode - # # push!(flatten_code, res) - # end - # if resarg - # push!(resarg_code, :($usbuf = $flatcode.data)) - # end + + if is_sharded + carg = inv_seen_args[arg] + condensed_op_sharding = convert( + Reactant.Sharding.XLA.CondensedOpSharding, linear_parameter_shardings[i] + ) + if Reactant.Sharding.is_sharded(carg) + arg_condensed_op_sharding = convert( + Reactant.Sharding.XLA.CondensedOpSharding, + carg.sharding.sharding.hlo_sharding, + ) + # Check if the sharding provided is same as the one we have + @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE." + + push!(flatten_code, :($usbuf = $flatcode.data)) + for j in 1:length(mesh) + sbuf = Symbol(:sbuf_, i, "_", mesh.logical_device_ids[j]) + push!(flatten_names, sbuf) + push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j)))) + end + else + push!(flatten_code, :($usbuf = $flatcode)) + device_to_array_slices, _ = XLA.sharding_to_concrete_array_indices( + condensed_op_sharding, size(carg), mesh.logical_device_ids + ) + for j in 1:length(mesh) + device_id = mesh.logical_device_ids[j] + buf = Symbol(:buf_, i, :_, device_id) + slice = device_to_array_slices[j] + push!( + flatten_code, + :($buf = XLA.synced_buffer(only($usbuf[$(slice)...].data))), + ) + sbuf = Symbol(:s, buf) + device = XLA.get_device(client, device_id) + push!(flatten_names, sbuf) + push!(flatten_code, :($sbuf = XLA.copy_buffer_to_device($buf, $device))) + end + end + else + push!(flatten_code, :($usbuf = $flatcode.data)) + sbuf = Symbol(:sbuf_, i) + push!(flatten_names, sbuf) + if arg isa TracedRArray || arg isa TracedRNumber + push!(flatten_code, :($sbuf = only(XLA.synced_buffer($usbuf)))) + else + error("Unsupported type $(typeof(arg))") + end + end end + + # We reorder how the buffers are passed to the XLA call + is_sharded && + (flatten_names = vcat(eachrow(reshape(flatten_names, length(mesh), :))...)) + return flatten_names, flatten_code end @@ -630,22 +1236,31 @@ function codegen_unflatten!( linear_results, concrete_result, result_stores, + path_to_shard_info, + linear_result_shard_info, + sharding_mesh, ) cache_dict = gensym("cache_dict") - unflatten_code = Expr[:( - $cache_dict = $(IdDict{ - Union{TracedRArray,TracedRNumber},Union{ConcreteRArray,ConcreteRNumber} - }()) - ),] + has_cache_dict = false + unflatten_code = Expr[] # mutate the result stores to point to the correct concrete results - for (concrete_res_name, result) in zip(concretized_res_names, linear_results) - paths = ((p for p in result.paths if p[1] == :result || p[1] == :resargs)...,) + for (concrete_res_name, result, shard_info) in + zip(concretized_res_names, linear_results, linear_result_shard_info) + paths = ( + ( + p for p in Reactant.TracedUtils.get_paths(result) if + length(p) > 0 && (p[1] == :result || p[1] == :resargs) + )..., + ) for path in paths if path[1] == :result unflatcode = :result path = path[2:end] result_stores[path] = concrete_res_name + if path_to_shard_info !== nothing + path_to_shard_info[path] = shard_info + end continue else @assert path[1] == :resargs @@ -659,14 +1274,27 @@ function codegen_unflatten!( if length(path) > 0 final_val = gensym("final_val") clocal = gensym("clocal") + if !has_cache_dict + has_cache_dict = true + push!( + unflatten_code, + :( + $cache_dict = $(IdDict{ + Union{TracedRArray,TracedRNumber}, + Union{ConcretePJRTArray,ConcretePJRTNumber}, + }()) + ), + ) + end unflatcode = quote $final_val = traced_getfield($unflatcode, $(Meta.quot(path[end]))) if $final_val isa TracedRArray $clocal = if haskey($cache_dict, $final_val) $cache_dict[$final_val] else - $cache_dict[$final_val] = ConcreteRArray{ - eltype($final_val),ndims($final_val) + $cache_dict[$final_val] = ConcretePJRTArray{ + $(Reactant.unwrapped_eltype)($final_val), + ndims($final_val), }( $concrete_res_name, size($final_val) ) @@ -677,7 +1305,9 @@ function codegen_unflatten!( $clocal = if haskey($cache_dict, $final_val) $cache_dict[$final_val] else - $cache_dict[$final_val] = ConcreteRNumber{eltype($final_val)}( + $cache_dict[$final_val] = ConcretePJRTNumber{ + $(Reactant.unwrapped_eltype)($final_val) + }( $concrete_res_name ) $cache_dict[$final_val] @@ -688,7 +1318,7 @@ function codegen_unflatten!( end end else - unflatcode = :($unflatcode.data = $concrete_res_name) + unflatcode = :(traced_setfield!($unflatcode, :data, $concrete_res_name)) end push!(unflatten_code, unflatcode) end @@ -696,15 +1326,27 @@ function codegen_unflatten!( end prevkeys = collect(keys(result_stores)) - result_code = create_result(concrete_result, (), result_stores) + result_code = create_result( + concrete_result, (), result_stores, path_to_shard_info, sharding_mesh + ) postkeys = collect(keys(result_stores)) used = [t for t in prevkeys if !in(t, postkeys)] # if some argument is mutated, change them to point to the correct concrete results for (result, arg_idx) in preserved_args - for path in result.paths + paths = ( + ( + p for p in Reactant.TracedUtils.get_paths(result) if + length(p) > 0 && (p[1] == :result || p[1] == :resargs || p[1] == :args) + )..., + ) + + for path in paths arg = linear_args[arg_idx + 1] - argpath = only((p for p in arg.paths if p[1] == :args)) + argpath = only(( + p for + p in Reactant.TracedUtils.get_paths(arg) if length(p) > 0 && p[1] == :args + )) if path[1] == :result res = :result @@ -713,7 +1355,7 @@ function codegen_unflatten!( continue end else - @assert path[1] == :resargs || path[1] == :args + @assert path[1] == :resargs || path[1] == :args "Expected :resargs or :args, got $(path[1])" # We can optimize cases where we set the arg to itself if path[2:end] == argpath[2:end] continue @@ -754,10 +1396,19 @@ Generate Julia code to call the XLA executable. - `donated_args_mask`: A list of `UInt8`s representing whether the argument is donated. - `nresults`: The number of results to expect. """ -function codegen_xla_call(exec, flatten_names, donated_args_mask, nresults) +function codegen_xla_call( + exec, + device, + flatten_names, + donated_args_mask, + nresults, + is_sharded::Bool, + ndevices::Int, +) flatten_buffer_refs = map(n -> :($n.buffer), flatten_names) - concretized_res_names = Symbol[Symbol(:concrete_res_, i) for i in 1:nresults] + base_symbol_name = is_sharded ? Symbol(:result_buffer_m, ndevices, :_) : :result_buffer_ + concretized_res_names = Symbol[Symbol(base_symbol_name, i) for i in 1:nresults] concretized_res_code = map(enumerate(concretized_res_names)) do (i, varname) :($varname = linearized_results[$i]) end @@ -765,75 +1416,210 @@ function codegen_xla_call(exec, flatten_names, donated_args_mask, nresults) xla_call_code = if nresults == 0 :() else - quote - GC.@preserve $(flatten_names...) begin - linearized_results = XLA.ExecutableCall( - $exec, - ($(flatten_buffer_refs...),), - $(Tuple(donated_args_mask)), - Val($nresults), - ) + if is_sharded + quote + GC.@preserve $(flatten_names...) begin + linearized_results = XLA.execute( + $exec, + ($(flatten_buffer_refs...),), + $(Tuple(donated_args_mask)), + Val($nresults), + Val($ndevices), + ) + end + $(concretized_res_code...) + end + else + quote + GC.@preserve $(flatten_names...) begin + linearized_results = XLA.execute_sharded( + $exec, + $(device), + ($(flatten_buffer_refs...),), + $(Tuple(donated_args_mask)), + Val($nresults), + ) + end + $(concretized_res_code...) end - $(concretized_res_code...) end end return concretized_res_names, xla_call_code end -function compile_xla(f, args; client=nothing, optimize=true) +function __add_mhlo_attributes_and_name!(mod::MLIR.IR.Module, f; kwargs...) + fname = string(f) + length(fname) > 10 && (fname = fname[1:7] * "...") + __add_mhlo_attributes_and_name!(mod, fname; kwargs...) + return nothing +end + +function __add_mhlo_attributes_and_name!( + mod::MLIR.IR.Module, fname::String; num_partitions=1, num_replicas=1 +) + moduleop = MLIR.IR.Operation(mod) + module_name = Reactant.TracedUtils.__lookup_unique_name_in_module( + mod, "reactant_" * fname + ) + module_name = MLIR.IR.Attribute(module_name) + MLIR.IR.attr!(moduleop, "mhlo.num_partitions", MLIR.IR.Attribute(num_partitions)) + MLIR.IR.attr!(moduleop, "mhlo.num_replicas", MLIR.IR.Attribute(num_replicas)) + MLIR.IR.attr!( + moduleop, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()), module_name + ) + return nothing +end + +function __resolve_device_and_client(client, seen_args, linear_args, is_sharded) + if is_sharded + client === nothing && (client = XLA.default_backend()) + return client, nothing + end + + device = nothing + if length(linear_args) > 0 + devices_list = [ + XLA.device(only(k.data)) for (k, v) in seen_args if v isa TracedRArray + ] + if !isempty(devices_list) + if !allequal(devices_list) + msg = "Expected all arguments to be on the same device, got:\n" + for (i, device) in enumerate(devices_list) + msg *= " Device $(i): $(string(device))\n" + end + throw(ArgumentError(msg)) + end + @assert allequal(devices_list) "All arguments must be on the same device: $(devices_list)" + device = first(devices_list) + end + end + + if client === nothing + if device !== nothing + client = XLA.client(device) + else + client = XLA.default_backend() + device = XLA.default_device(client) + end + else + if device !== nothing + @assert client == XLA.client(device) "client ($(client)) and XLA.client(device) ($(XLA.client(device))) must be the same" + else + device = XLA.default_device(client) + end + end + + return (client, device) +end + +function compile_xla(f, args; client=nothing, kwargs...) # register MLIR dialects ctx = MLIR.IR.Context(Reactant.registry[], false) + context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - return MLIR.IR.context!(ctx) do + backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend()) + + if backend == "CUDA" + backend = "GPU" + elseif backend == "CPU" + backend = "cpu" + end + + MLIR.IR.activate!(ctx) + results = try # compile function to MLIR module mod = MLIR.IR.Module(MLIR.IR.Location()) - linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!( - mod, f, args; optimize + mlir_fn_res = compile_mlir!(mod, f, args; backend, kwargs...) + + # Resolve client and device + client, device = __resolve_device_and_client( + client, + mlir_fn_res.seen_args, + mlir_fn_res.linear_args, + mlir_fn_res.is_sharded, ) - if isnothing(client) - if length(linear_args) > 0 - for (k, _) in Iterators.filter(((_, v),) -> v isa TracedRArray, seen_args) - client = XLA.client(k.data) - end - end - if isnothing(client) - client = XLA.default_backend[] - end - end + # Attach a name, and partitioning attributes to the module + __add_mhlo_attributes_and_name!( + mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas + ) # compile MLIR module to XLA executable - exec = XLA.Compile(client, mod) - return exec, - linear_args, linear_results, preserved_args, seen_args, concrete_result, - isclosure + global_device_ids = if mlir_fn_res.is_sharded + vec(mlir_fn_res.sharding_mesh.device_ids) + else + Int64[] + end + mlir_fn_res.is_sharded && (device = nothing) + + exec = XLA.compile( + client, + device, + mod; + num_outputs=length(mlir_fn_res.linear_results), + num_parameters=length(mlir_fn_res.linear_args), + mlir_fn_res.is_sharded, + global_device_ids, + mlir_fn_res.num_replicas, + mlir_fn_res.num_partitions, + ) + + return mod, exec, mlir_fn_res, device, client + finally + MLIR.IR.deactivate!(ctx) end + + Base.delete!(context_gc_vector, ctx) + return results end -function compile(f, args; client=nothing, optimize=true, sync=false) - exec, linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_xla( - f, args; client, optimize - ) +function compile(f, args; sync=false, kwargs...) + _, exec, mlir_fn_res, device, client = compile_xla(f, args; kwargs...) + (; linear_args, seen_args, linear_results, preserved_args, concrete_result) = + mlir_fn_res preserved_args_idx = last.(preserved_args) donated_args_mask = map(1:length(linear_args)) do i UInt8(i ∉ preserved_args_idx) end - fnwrap = isclosure ? f : nothing - closure_ty = typeof(fnwrap) - result_stores = Dict{Tuple,Symbol}() + path_to_shard_info = mlir_fn_res.is_sharded ? Dict{Tuple,Tuple}() : nothing # generate Julia `Thunk` code - flatten_arg_names, flatten_code = codegen_flatten!(linear_args, result_stores) + flatten_arg_names, flatten_code = codegen_flatten!( + linear_args, + seen_args, + result_stores, + mlir_fn_res.is_sharded, + mlir_fn_res.sharding_mesh, + XLA.get_parameter_shardings(exec), + client, + ) concretized_res_names, xla_call_code = codegen_xla_call( - exec, flatten_arg_names, donated_args_mask, length(linear_results) + exec, + device, + flatten_arg_names, + donated_args_mask, + length(linear_results), + mlir_fn_res.is_sharded, + mlir_fn_res.is_sharded ? length(mlir_fn_res.sharding_mesh) : 1, ) + linear_result_shard_info = if mlir_fn_res.is_sharded + output_shardings = XLA.get_output_shardings(exec) + XLA.compute_array_indices_and_hlo_sharding.( + output_shardings, + size.(mlir_fn_res.linear_results), + (mlir_fn_res.sharding_mesh.logical_device_ids,), + ) + else + ntuple(Returns(nothing), length(linear_results)) + end + unflatten_code = codegen_unflatten!( linear_args, preserved_args, @@ -841,6 +1627,9 @@ function compile(f, args; client=nothing, optimize=true, sync=false) linear_results, concrete_result, result_stores, + path_to_shard_info, + linear_result_shard_info, + mlir_fn_res.sharding_mesh, ) sync_call = if sync @@ -854,37 +1643,118 @@ function compile(f, args; client=nothing, optimize=true, sync=false) end fname = gensym(Symbol(Symbol(f), :_reactant)) - expr = :(function $(fname)(args...) - $( - # if `f` is a closure, then prepend the closure into `args` - # the closure fields will be correctly extracted from it as the tracer has already passed through it - if !(closure_ty <: Nothing) - :(args = ($fnwrap, args...)) - end - ) + + body = quote $(flatten_code...) $(xla_call_code) $(sync_call) $(unflatten_code...) return result - end) + end - body = expr.args[2] - return register_thunk(fname, body) + if mlir_fn_res.fnwrapped + body = quote + args = ($f, args...) + $body + end + end + + return register_thunk( + fname, Tuple{map(Core.Typeof, args)...}, body, f, mlir_fn_res.fnwrapped + ) end # inspired by RuntimeGeneratedFunction.jl const __thunk_body_cache = Dict{Symbol,Expr}() -struct Thunk{tag} end +struct Thunk{FTy,CTy,tag,IsClosure,ArgTypes} + f::FTy + closure::CTy +end + +struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end + +function Base.showerror( + io::IO, ece::MisMatchedThunkTypeError{Thunk{FTy,CTy,tag,ArgTypes,IsClosure},FoundTypes} +) where {FTy,CTy,tag,ArgTypes,FoundTypes,IsClosure} + print( + io, + "\nThe Reactant-compiled function `$(Thunk{FTy,CTy, tag, ArgTypes, IsClosure})` exists, but no method is defined for this combination of argument types.", + ) + print( + io, + "\nYou passed in arguments with types\n\t(" * + join(FoundTypes.parameters, ", ") * + ")", + ) + return print( + io, + "\nHowever the method you are calling was compiled for arguments with types\n\t(" * + join(ArgTypes.parameters, ", ") * + ")", + ) +end -@generated function (thunk::Thunk{tag})(args...) where {tag} - return __thunk_body_cache[tag] +@generated function (thunk::Thunk{FTy,CTy,tag,ArgTypes,IsClosure})( + args... +) where {FTy,CTy,tag,ArgTypes,IsClosure} + FoundTypes = Tuple{args...} + if ArgTypes != FoundTypes + return quote + throw( + $(MisMatchedThunkTypeError{Thunk{FTy,CTy,tag,ArgTypes,IsClosure},FoundTypes}()) + ) + end + end + return quote + Base.invokelatest(thunk.closure,args...) + end end -function register_thunk(tag, body) +function register_thunk( + tag::Symbol, @nospecialize(argtys::Type), body::Expr, @nospecialize(f), isclosure::Bool +) __thunk_body_cache[tag] = body - return Thunk{tag}() + body = :((args...)->begin $body end) + closure = @eval $body + t = Thunk{Core.Typeof(f),Core.Typeof(closure),tag,argtys,isclosure}(f, closure) + @error t + return t +end + +for cache_type in (:callcache, :sdycache) + activate_fn = Symbol(:activate_, cache_type, :!) + deactivate_fn = Symbol(:deactivate_, cache_type, :!) + has_fn = Symbol(:_has_, cache_type) + + @eval begin + function $(activate_fn)(cache) + stack = get!(task_local_storage(), $(Meta.quot(cache_type))) do + return [] + end + push!(stack, cache) + return nothing + end + + function $(deactivate_fn)(cache) + cache === last(task_local_storage($(Meta.quot(cache_type)))) || + error("Deactivating wrong cache") + return pop!(task_local_storage($(Meta.quot(cache_type)))) + end + + function $(has_fn)() + return haskey(task_local_storage(), $(Meta.quot(cache_type))) && + !Base.isempty(task_local_storage($(Meta.quot(cache_type)))) + end + + function $(cache_type)(; throw_error::Bool=true) + if !$(has_fn)() + throw_error && error("No cache is active") + return nothing + end + return last(task_local_storage($(Meta.quot(cache_type)))) + end + end end end diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index d42ca08b6f..6e4d03a159 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -1,137 +1,99 @@ -struct XLAArray{T,N} <: RArray{T,N} - # size::NTuple{N,Int} +function get_buffer( + x::Union{ConcretePJRTArray,ConcretePJRTNumber}; no_error_for_scalar=false +) + if Sharding.is_sharded(x.sharding) + # For scalars this is mostly replicated + no_error_for_scalar && return first(x.data).buffer + error("`x` is sharded, so `get_buffer` is not defined") + end + return only(x.data).buffer end -mutable struct ConcreteRArray{T,N} <: RArray{T,N} - data::XLA.AsyncBuffer - # data::XLAArray{T, N} - shape::NTuple{N,Int} +function Base.collect(x::ConcretePJRTNumber{T}) where {T} + return collect(ConcretePJRTArray{T,0}(copy(x).data, ())) end -const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}} -const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}} - -mutable struct ConcreteRNumber{T} <: RNumber{T} - data::XLA.AsyncBuffer +Base.size(::AbstractConcreteNumber) = () +Base.real(x::AbstractConcreteNumber{<:Real}) = x +function Base.rtoldefault(T::Type{<:AbstractConcreteNumber}) + return T(Base.rtoldefault(unwrapped_eltype(T))) end -function ConcreteRNumber{T}( - data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing -) where {T<:Number,T2<:Number} - data = convert(T, data) - crarray = ConcreteRArray(fill(data); client, idx, device) - return ConcreteRNumber{T}(crarray.data) -end -function ConcreteRNumber( - data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing -) where {T<:Number} - crarray = ConcreteRArray(fill(data); client, idx, device) - return ConcreteRNumber{T}(crarray.data) -end - -Base.collect(x::ConcreteRNumber{T}) where {T} = ConcreteRArray{T,0}(copy(x).data, ()) - -Base.size(::ConcreteRNumber) = () -Base.real(x::ConcreteRNumber{<:Real}) = x -function Base.rtoldefault(::Type{ConcreteRNumber{T}}) where {T} - return ConcreteRNumber(Base.rtoldefault(T)) -end +Base.strides(x::AbstractConcreteArray) = Base.size_to_strides(1, size(x)...) # Ensure the device and client are the same as the input -function Base.float(x::ConcreteRNumber{T}) where {T} - client = XLA.client(x.data) - device = XLA.device(x.data) - return ConcreteRNumber(float(T)(to_number(x)); client, device) +function Base.float(x::ConcretePJRTNumber{T}) where {T} + return ConcretePJRTNumber( + float(T)(to_number(x)); client=XLA.client(x), device=XLA.device(x), x.sharding + ) end # written like this to avoid ambiguity errors for T in Base.uniontypes(ReactantPrimitive) - @eval (::Type{$(T)})(x::ConcreteRNumber) = convert($T, x) + @eval (::Type{$(T)})(x::AbstractConcreteNumber) = convert($T, x) end -Base.convert(::Type{T}, x::ConcreteRNumber) where {T<:Number} = convert(T, to_number(x)) - -function ConcreteRArray( - data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing -) where {T<:Number} - Base.depwarn( - "ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead", - :ConcreteRArray, - ) - return ConcreteRArray(fill(data); client, idx, device) +function Base.convert(::Type{T}, x::AbstractConcreteNumber) where {T<:Number} + return convert(T, to_number(x)) end -const ConcreteRScalar{T} = Union{ConcreteRArray{T,0},ConcreteRNumber{T}} +Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:AbstractConcreteArray} = T(x) -Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x) +Base.size(x::AbstractConcreteArray) = x.shape -function ConcreteRArray( - data::Array{T,N}; - client=XLA.default_backend[], - idx=XLA.default_device_idx[], - device=nothing, -) where {T,N} - device = device === nothing ? XLA.ClientGetDevice(client, idx) : device - return ConcreteRArray{T,N}( - XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, data, device), nothing), size(data) - ) -end +Base.isempty(x::Union{AbstractConcreteArray,AbstractConcreteNumber}) = any(isempty, x.data) -ConcreteRArray(x::AnyConcreteRArray) = ConcreteRArray{eltype(x),ndims(x)}(x) -ConcreteRArray{T}(x::AnyConcreteRArray) where {T} = ConcreteRArray{T,ndims(x)}(x) -ConcreteRArray{T,N}(x::ConcreteRArray{T,N}) where {T,N} = x -function ConcreteRArray{T,N}(x::AnyConcreteRArray) where {T,N} - return ConcreteRArray(convert(Array{T,N}, x)) -end +Base.isempty(x::WrappedConcretePJRTArray) = isempty(ancestor(x)) -Base.size(x::ConcreteRArray) = x.shape +function Base.convert(::Type{<:Array}, X::ConcretePJRTArray{T,N}) where {T,N} + if Sharding.is_sharded(X) + data = Array{T,N}(undef, size(X)...) -function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,ElType,N} - data = Array{ElType,N}(undef, size(X)...) # TODO replace for `similar`? - XLA.await(X.data) - buf = X.data.buffer - GC.@preserve data buf begin - XLA.BufferToHost(buf, pointer(data)) + completed = Set{eltype(X.sharding.device_to_array_slices)}() + for idx in 1:length(X.data) + slice = X.sharding.device_to_array_slices[idx] + if slice ∉ completed + push!(completed, slice) + else + continue + end + data[slice...] = convert(Array{T}, X.data[idx]) + end + + return data + else + buf = XLA.synced_buffer(only(X.data)) + GC.@preserve buf begin + return convert(Array{T}, buf) + end end - return data - # XLA.from_row_major(data) end -function Base.convert( - ::Type{T}, X::WrappedConcreteRArray{ElType,N} -) where {T<:Array,ElType,N} +function Base.convert(::Type{<:Array}, X::WrappedConcretePJRTArray) fn = compile(TracedUtils.materialize_traced_array, (X,)) return convert(Array, fn(X)) end -Base.Array(x::AnyConcreteRArray) = convert(Array, x) +Base.Array(x::AnyConcretePJRTArray) = convert(Array, x) -function synchronize(x::Union{ConcreteRArray,ConcreteRNumber}) - XLA.synced_buffer(x.data) +function synchronize(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) + foreach(XLA.synced_buffer, x.data) return nothing end -# function Base.similar(x::ConcreteRArray{T,N}, ::Type{T2}) where {T,N,T2} -# return ConcreteRArray{T,N}(x.data) -# end -# function Base.convert(::Type{ConcreteRArray{T2,N}}, x::ConcreteRArray{T,N}) where {T,N,T2} -# return ConcreteRArray{T,N}(x.data) -# end - -function to_number(X::ConcreteRScalar{T}) where {T} +to_number(x::Number) = x +function to_number(X::ConcretePJRTScalar{T}) where {T} data = Ref{T}() - XLA.await(X.data) - buf = X.data.buffer + XLA.await(X) + buf = get_buffer(X; no_error_for_scalar=true) GC.@preserve data buf begin - XLA.BufferToHost(buf, data) + XLA.to_host(buf, data) end return data[] end -Base.convert(::Type{T}, x::ConcreteRScalar{T}) where {T} = to_number(x) +Base.convert(::Type{T}, x::ConcretePJRTScalar{T}) where {T<:Number} = to_number(x) -for jlop in (:(Base.abs),), T in (ConcreteRNumber,) - @eval begin - $(jlop)(x::$(T)) = $(jlop)(to_number(x)) - end +for jlop in (:(Base.abs),), T in (AbstractConcreteNumber,) + @eval $(jlop)(x::$(T)) = $(jlop)(to_number(x)) end for jlop in ( @@ -143,7 +105,7 @@ for jlop in ( :(Base.:^), :(Base.:(==)), ), - T in (ConcreteRNumber, ConcreteRArray{<:Any,0}) + T in (AbstractConcreteNumber, AbstractConcreteArray{<:Any,0}) @eval begin $(jlop)(x::$(T), y::$(T)) = $(jlop)(to_number(x), to_number(y)) @@ -152,41 +114,43 @@ for jlop in ( end end -for T in (ConcreteRNumber, ConcreteRArray{<:Any,0}) - @eval begin - function Base.isapprox(x::$(T), y::Number; kwargs...) - return Base.isapprox(to_number(x), y; kwargs...) - end +for jlop in (:(Base.isnan), :(Base.isfinite)), + T in (AbstractConcreteNumber, AbstractConcreteArray{<:Any,0}) - function Base.isapprox(x::Number, y::$(T); kwargs...) - return Base.isapprox(x, to_number(y); kwargs...) - end + @eval $(jlop)(x::$(T)) = $(jlop)(to_number(x)) +end - function Base.isapprox(x::$(T), y::$(T); kwargs...) - return Base.isapprox(to_number(x), to_number(y); kwargs...) +for T in (AbstractConcreteNumber, AbstractConcreteArray{<:Any,0}) + for (T1, T2) in ((T, Number), (Number, T), (T, T)) + @eval begin + function Base.isapprox(x::$(T1), y::$(T2); kwargs...) + return Base.isapprox(to_number(x), to_number(y); kwargs...) + end + function Base.isapprox( + x::AbstractArray{<:$(T1)}, y::AbstractArray{<:$(T2)}; kwargs... + ) + return Base.isapprox(to_number.(x), to_number.(y); kwargs...) + end end end end -function Base.isapprox(x::AnyConcreteRArray, y::AbstractArray; kwargs...) - return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) -end -function Base.isapprox(x::AbstractArray, y::AnyConcreteRArray; kwargs...) - return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) -end -function Base.isapprox(x::AnyConcreteRArray, y::AnyConcreteRArray; kwargs...) - return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) -end - -Base.:(==)(x::AnyConcreteRArray, y::AbstractArray) = convert(Array, x) == convert(Array, y) -Base.:(==)(x::AbstractArray, y::AnyConcreteRArray) = convert(Array, x) == convert(Array, y) -function Base.:(==)(x::AnyConcreteRArray, y::AnyConcreteRArray) - return convert(Array, x) == convert(Array, y) +for (T1, T2) in ( + (AnyConcretePJRTArray, AbstractArray), + (AbstractArray, AnyConcretePJRTArray), + (AnyConcretePJRTArray, AnyConcretePJRTArray), +) + @eval begin + function Base.isapprox(x::$(T1), y::$(T2); kwargs...) + return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) + end + Base.:(==)(x::$(T1), y::$(T2)) = convert(Array, x) == convert(Array, y) + end end -function Base.show(io::IO, X::ConcreteRScalar{T}) where {T} - if X.data == XLA.AsyncEmptyBuffer - println(io, "") +function Base.show(io::IO, X::ConcretePJRTScalar{T}) where {T} + if isempty(X) + print(io, "") return nothing end print(io, "$(typeof(X))(") @@ -195,19 +159,24 @@ function Base.show(io::IO, X::ConcreteRScalar{T}) where {T} return nothing end -function Base.print_array(io::IO, X::AnyConcreteRArray) - data = ancestor(X).data - if data == XLA.AsyncEmptyBuffer - println(io, "") +function Base.print_array(io::IO, X::AnyConcretePJRTArray) + if isempty(X) + print(io, "") return nothing end return Base.print_array(io, convert(Array, X)) end -function Base.show(io::IO, X::AnyConcreteRArray) - data = ancestor(X).data - if data == XLA.AsyncEmptyBuffer - println(io, "") +function Base.showarg(io::IO, a::ConcretePJRTArray{T,N}, toplevel) where {T,N} + toplevel || print(io, "::") + print(io, "ConcretePJRTArray{$T,$N}") + Sharding.is_sharded(a) && print(io, " with sharding $(typeof(a.sharding.sharding))") + return nothing +end + +function Base.show(io::IO, X::AnyConcretePJRTArray) + if isempty(X) + print(io, "") return nothing end print(io, "$(typeof(X))(") @@ -216,29 +185,25 @@ function Base.show(io::IO, X::AnyConcreteRArray) return nothing end -function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N} - if a.data == XLA.AsyncEmptyBuffer - throw("Cannot getindex from empty buffer") - end +function Base.getindex(a::ConcretePJRTArray{T}, args::Vararg{Int,N}) where {T,N} + isempty(a) && throw("Cannot getindex from empty buffer") - XLA.await(a.data) - if buffer_on_cpu(a) - buf = a.data.buffer + XLA.await(a) + if buffer_on_cpu(a) && !Sharding.is_sharded(a) + buf = get_buffer(a) GC.@preserve buf begin - ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf)) + ptr = Base.unsafe_convert(Ptr{T}, XLA.unsafe_buffer_pointer(buf)) start = 0 for i in 1:N start *= size(a, N - i + 1) start += (args[N - i + 1] - 1) - # start *= size(a, i) - # start += (args[i]-1) end start += 1 return unsafe_load(ptr, start) end end - GPUArraysCore.assertscalar("getindex(::ConcreteRArray, ::Vararg{Int, N})") + GPUArraysCore.assertscalar("getindex(::ConcretePJRTArray, ::Vararg{Int, N})") return convert(Array, a)[args...] end @@ -247,22 +212,18 @@ function mysetindex!(a, v, args::Vararg{Any,N}) where {N} return nothing end -function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N} - if a.data == XLA.AsyncEmptyBuffer - throw("Cannot setindex! to empty buffer") - end +function Base.setindex!(a::ConcretePJRTArray{T}, v, args::Vararg{Int,N}) where {T,N} + isempty(a) && throw("Cannot setindex! to empty buffer") - XLA.await(a.data) - if buffer_on_cpu(a) - buf = a.data.buffer + XLA.await(a) + if buffer_on_cpu(a) && !Sharding.is_sharded(a) + buf = get_buffer(a) GC.@preserve buf begin - ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf)) + ptr = Base.unsafe_convert(Ptr{T}, XLA.unsafe_buffer_pointer(buf)) start = 0 for i in 1:N start *= size(a, N - i + 1) start += (args[N - i + 1] - 1) - # start *= size(a, i) - # start += (args[i]-1) end start += 1 unsafe_store!(ptr, v, start) @@ -270,63 +231,71 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N return a end - GPUArraysCore.assertscalar("setindex!(::ConcreteRArray, ::Any, ::Vararg{Int, N})") + GPUArraysCore.assertscalar("setindex!(::ConcretePJRTArray, ::Any, ::Vararg{Int, N})") fn = compile(mysetindex!, (a, v, args...)) fn(a, v, args...) return a end # TODO is there any way to allocate an uninitialized buffer in XLA? -function Base.similar(a::ConcreteRArray{T}, ::Type{S}=T, dims::Dims=size(a)) where {T,S} - return ConcreteRArray(Array{S}(undef, dims)) +function Base.similar(a::ConcretePJRTArray{T}, ::Type{S}=T, dims::Dims=size(a)) where {T,S} + return ConcretePJRTArray( + Array{S}(undef, dims); client=XLA.client(a), device=XLA.device(a), a.sharding + ) end -Base.similar(a::ConcreteRArray, dims::Dims) = similar(a, eltype(a), dims) - -function Base.similar(::Type{ConcreteRArray{T}}, dims) where {T} - return ConcreteRArray(similar(Array{T}, dims)) +Base.similar(a::ConcretePJRTArray, dims::Dims) = similar(a, eltype(a), dims) +function Base.similar(::Type{ConcretePJRTArray{T}}, dims) where {T} + return ConcretePJRTArray(similar(Array{T}, dims)) end # Broadcasting interface -Base.BroadcastStyle(::Type{<:ConcreteRArray}) = Broadcast.ArrayStyle{ConcreteRArray}() +Base.BroadcastStyle(::Type{<:ConcretePJRTArray}) = Broadcast.ArrayStyle{ConcretePJRTArray}() function Base.similar( - bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteRArray}}, ::Type{T} + bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcretePJRTArray}}, ::Type{T} ) where {T} - return ConcreteRArray(similar(Array{T}, axes(bc))) + # XXX: correct device + sharding? + return ConcretePJRTArray(similar(Array{T}, axes(bc))) end # TODO replace this copy for `setindex!` maybe? how to copy data to already existing buffer? (i.e. `copyto!`) -function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteRArray}}) +function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcretePJRTArray}}) for x in bc.args - x isa ConcreteRArray && XLA.await(x.data) + x isa ConcretePJRTArray && XLA.await(x) end - all_on_cpu = all(buffer_on_cpu, bc.args) - if all_on_cpu + if all(buffer_on_cpu, bc.args) && all( + x -> + !(x isa ConcretePJRTArray) || + (x isa ConcretePJRTArray && !Sharding.is_sharded(x)), + bc.args, + ) ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args) if !Base.isconcretetype(ElType) throw( ErrorException( - "`copy` on `ConcreteRArray` for non-concrete eltype is not implemented" + "`copy` on `ConcretePJRTArray` for non-concrete eltype is not implemented", ), ) end aux = copyto!(similar(Array{ElType}, axes(bc)), bc) - return ConcreteRArray(aux) + return ConcretePJRTArray(aux) # XXX: result should be on correct device? end fn = compile(Broadcast.BroadcastFunction(bc.f), (bc.args...,)) return fn(bc.args...) end -function Base.copyto!(dest::ConcreteRArray, src::ConcreteRArray) +function Base.copyto!(dest::AbstractConcreteArray, src::AbstractConcreteArray) dest.data = src.data return dest end +Base.collect(x::AbstractConcreteArray) = convert(Array, x) + function Base.mapreduce( @nospecialize(f), @nospecialize(op), - @nospecialize(A::ConcreteRArray{T,N}); + @nospecialize(A::AbstractConcreteArray{T,N}); dims=:, init=nothing, ) where {T,N} @@ -344,28 +313,30 @@ end (f::CallMapReduce)(A) = Base.mapreduce(f.f, f.op, A; f.dims, f.init) buffer_on_cpu(::Any) = true -buffer_on_cpu(x::ConcreteRArray) = XLA.BufferOnCPU(x.data.buffer) +buffer_on_cpu(x::ConcretePJRTArray) = all(XLA.buffer_on_cpu, x.data) -function Ops.constant(x::ConcreteRArray; kwargs...) +function Ops.constant(x::AbstractConcreteArray; kwargs...) return Ops.constant(Base.convert(Array, x); kwargs...) end -function Ops.constant(x::ConcreteRNumber{T}; kwargs...) where {T} +function Ops.constant(x::AbstractConcreteNumber{T}; kwargs...) where {T} return Ops.constant(Base.convert(T, x); kwargs...) end -Base.zero(x::ConcreteRArray{T,N}) where {T,N} = ConcreteRArray(zeros(T, size(x)...)) +function Base.zero(x::ConcretePJRTArray{T,N}) where {T,N} + return ConcretePJRTArray( + zeros(T, size(x)...); client=XLA.client(x), device=XLA.device(x), x.sharding + ) +end -function Base.fill!(a::ConcreteRArray{T,N}, val) where {T,N} - if a.data == XLA.AsyncEmptyBuffer - throw("Cannot setindex! to empty buffer") - end +function Base.fill!(a::ConcretePJRTArray{T,N}, val) where {T,N} + isempty(a) && throw("Cannot setindex! to empty buffer") - XLA.await(a.data) - if buffer_on_cpu(a) - buf = a.data.buffer + XLA.await(a) + if buffer_on_cpu(a) && !Sharding.is_sharded(a) + buf = get_buffer(a) GC.@preserve buf begin - ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf)) + ptr = Base.unsafe_convert(Ptr{T}, XLA.unsafe_buffer_pointer(buf)) for start in 1:length(a) unsafe_store!(ptr, val, start) end diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 0e0c001956..d3271de59e 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -1,159 +1,13 @@ function ReactantCore.traced_if( cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args ) where {TFn,FFn} - (_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.TracedUtils.make_mlir_fn( - true_fn, - args, - (), - string(gensym("true_branch")), - false; - return_dialect=:stablehlo, - no_args_in_result=true, - construct_function_without_args=true, - ) - - (_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.TracedUtils.make_mlir_fn( - false_fn, - args, - (), - string(gensym("false_branch")), - false; - return_dialect=:stablehlo, - no_args_in_result=true, - construct_function_without_args=true, - ) - - @assert length(true_branch_results) == length(false_branch_results) "true branch returned $(length(true_branch_results)) results, false branch returned $(length(false_branch_results)). This shouldn't happen." - - result_types = MLIR.IR.Type[] - linear_results = [] - true_block_insertions = [] - false_block_insertions = [] - for (i, (tr, fr)) in enumerate(zip(true_branch_results, false_branch_results)) - if typeof(tr) != typeof(fr) - if !(tr isa MissingTracedValue) && !(fr isa MissingTracedValue) - error("Result #$(i) for the branches have different types: true branch \ - returned `$(typeof(tr))`, false branch returned `$(typeof(fr))`.") - elseif tr isa MissingTracedValue - push!(result_types, MLIR.IR.type(fr.mlir_data)) - push!(linear_results, TracedUtils.new_traced_value(false_linear_results[i])) - push!(true_block_insertions, (i => linear_results[end])) - else - push!(result_types, MLIR.IR.type(tr.mlir_data)) - push!(linear_results, TracedUtils.new_traced_value(true_linear_results[i])) - push!(false_block_insertions, (i => linear_results[end])) - end - else - push!(result_types, MLIR.IR.type(tr.mlir_data)) - push!(linear_results, TracedUtils.new_traced_value(tr)) - end - end - - # Replace all uses of missing values with the correct values - true_branch_region = get_region_removing_missing_values( - true_branch_compiled, true_block_insertions - ) - - false_branch_region = get_region_removing_missing_values( - false_branch_compiled, false_block_insertions - ) - - MLIR.IR.rmfromparent!(true_branch_compiled) - MLIR.IR.rmfromparent!(false_branch_compiled) - - if_compiled = MLIR.Dialects.stablehlo.if_( - cond.mlir_data; - true_branch=true_branch_region, - false_branch=false_branch_region, - result_0=result_types, - ) - - return map(enumerate(linear_results)) do (i, res) - res.mlir_data = MLIR.IR.result(if_compiled, i) - return res - end -end - -function ReactantCore.traced_while( - cond_fn::CFn, body_fn::BFn, args -) where {CFn<:Function,BFn<:Function} - # TODO: detect and prevent mutation within the condition - - # We promote all incoming args (is there a better way to do this?) - traced_args = [ - if v isa Number && !(v isa TracedType) - Reactant.TracedUtils.promote_to(TracedRNumber{typeof(v)}, v) - else - v - end for v in args - ] - - (_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn( - cond_fn, - traced_args, - (), - string(gensym("cond_fn")), - false; - no_args_in_result=true, - return_dialect=:stablehlo, - do_transpose=false, - ) - - (_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn( - body_fn, - traced_args, - (), - string(gensym("body_fn")), - false; - no_args_in_result=true, - return_dialect=:stablehlo, - do_transpose=false, - ) - - cond_reg = take_region(cond_fn_compiled) - body_reg = take_region(body_fn_compiled) - - MLIR.IR.rmfromparent!(cond_fn_compiled) - MLIR.IR.rmfromparent!(body_fn_compiled) - - result_0 = in_tys - - operands = MLIR.IR.Value[v.mlir_data for v in traced_args] - - while_compiled = MLIR.Dialects.stablehlo.while_( - operands; result_0, cond=cond_reg, body=body_reg - ) - - return map(enumerate(traced_args)) do (i, res) - res.mlir_data = MLIR.IR.result(while_compiled, i) - return res - end + return Ops.if_condition(cond, true_fn, false_fn, args...) end -function take_region(compiled_fn) - region = MLIR.IR.Region() - MLIR.API.mlirRegionTakeBody(region, MLIR.API.mlirOperationGetRegion(compiled_fn, 0)) - return region +function ReactantCore.traced_call(f::Function, args...) + return Ops.call(f, args...) end -function get_region_removing_missing_values(compiled_fn, insertions) - region = take_region(compiled_fn) - block = MLIR.IR.Block(MLIR.API.mlirRegionGetFirstBlock(region), false) - return_op = MLIR.IR.terminator(block) - for (i, rt) in insertions - if rt isa TracedRNumber - attr = MLIR.IR.DenseElementsAttribute(Array{eltype(rt)}(undef, ())) - op = MLIR.Dialects.stablehlo.constant(; value=attr) - elseif rt isa TracedRArray - attr = MLIR.IR.DenseElementsAttribute(Array{eltype(rt)}(undef, size(rt))) - op = MLIR.Dialects.stablehlo.constant(; value=attr) - else - error("Unknown type $(typeof(rt))") - end - MLIR.IR.rmfromparent!(op) - insert!(block, 1, op) - val = MLIR.IR.result(op, 1) - MLIR.API.mlirValueReplaceAllUsesOfWith(MLIR.IR.operand(return_op, i), val) - end - return region +function ReactantCore.traced_while(cond_fn::CFn, body_fn::BFn, args) where {CFn,BFn} + return Ops.while_loop(cond_fn, body_fn, args...) end diff --git a/src/Devices.jl b/src/Devices.jl new file mode 100644 index 0000000000..57f95647bd --- /dev/null +++ b/src/Devices.jl @@ -0,0 +1,60 @@ +""" + devices(backend::String) + devices(backend::XLA.AbstractClient = XLA.default_backend()) + +Return a list of devices available for the given client. +""" +devices(backend::String) = devices(XLA.client(backend)) + +devices(client::XLA.AbstractClient=XLA.default_backend()) = XLA.devices(client) + +""" + addressable_devices(backend::String) + addressable_devices(backend::XLA.AbstractClient = XLA.default_backend()) + +Return a list of addressable devices available for the given client. +""" +addressable_devices(backend::String) = addressable_devices(XLA.client(backend)) + +function addressable_devices(client::XLA.AbstractClient=XLA.default_backend()) + return XLA.addressable_devices(client) +end + +# https://github.com/jax-ml/jax/blob/152099ee0ef31119f16f4c2dac50d84fcb1575ef/jax/_src/hardware_utils.py#L19-L55 +const _GOOGLE_PCI_VENDOR_ID = "0x1ae0" +const _TPU_PCI_DEVICE_IDS = ( + # TPU v2, v3 + "0x0027", + # No public name (plc) + "0x0056", + # TPU v4 + "0x005e", + # TPU v5p + "0x0062", + # TPU v5e + "0x0063", + # TPU v6e + "0x006f", +) + +function has_tpu() + Sys.islinux() || return false + + devices_dir = "/sys/bus/pci/devices/" + isdir(devices_dir) || return false + + try + for path in readdir(devices_dir; join=true, sort=false) + if strip(read(joinpath(path, "vendor"), String)) == _GOOGLE_PCI_VENDOR_ID && + strip(read(joinpath(path, "device"), String)) in _TPU_PCI_DEVICE_IDS + return true + end + end + catch ex + @warn "failed to query PCI device information" maxlog = 1 exception = ( + ex, catch_backtrace() + ) + end + + return false +end diff --git a/src/Distributed.jl b/src/Distributed.jl new file mode 100644 index 0000000000..5e7c498777 --- /dev/null +++ b/src/Distributed.jl @@ -0,0 +1,162 @@ +module Distributed + +using ..Reactant: Reactant + +const initialized = Ref(false) + +function initialize(; + coordinator_address::Union{Nothing,String}=nothing, + num_processes::Union{Nothing,Integer}=nothing, + process_id::Union{Nothing,Integer}=nothing, + local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing, + initialization_timeout_in_seconds::Integer=300, + kwargs..., +) + @assert !initialized[] "`Distributed.initialize` has already been called" + + (coordinator_address, num_processes, process_id, local_gpu_device_ids) = auto_detect_unset_distributed_params(; + coordinator_address, + num_processes, + process_id, + local_gpu_device_ids, + initialization_timeout_in_seconds, + ) + + @debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_gpu_device_ids + + Reactant.XLA.update_global_state!(; + coordinator_address, num_processes, process_id, local_gpu_device_ids, kwargs... + ) + + @debug "New Global State" Reactant.XLA.global_state + + initialized[] = true + return nothing +end + +abstract type AbstractClusterEnvDetector end + +abstract type AbstractOMPIClusterEnvDetector <: AbstractClusterEnvDetector end + +struct OpenMPIORTEEnvDetector <: AbstractOMPIClusterEnvDetector end +struct OpenMPIPMIXEnvDetector <: AbstractOMPIClusterEnvDetector end + +struct MPIEnvDetector <: AbstractClusterEnvDetector end + +# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/cluster.py + +is_env_present(::AbstractClusterEnvDetector) = false + +function get_coordinator_address end +function get_process_count end +function get_process_id end +function get_local_process_id end + +function auto_detect_unset_distributed_params(; + detector_list=[OpenMPIORTEEnvDetector(), OpenMPIPMIXEnvDetector(), MPIEnvDetector()], + coordinator_address::Union{Nothing,String}=nothing, + num_processes::Union{Nothing,Integer}=nothing, + process_id::Union{Nothing,Integer}=nothing, + local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing, + initialization_timeout_in_seconds::Integer=300, +) + if all( + Base.Fix2(!==, nothing), + (coordinator_address, num_processes, process_id, local_gpu_device_ids), + ) + return coordinator_address, num_processes, process_id, local_gpu_device_ids + end + + idx = findfirst(is_env_present, detector_list) + if idx === nothing + error("Couldn't find a functional cluster environment detector. Attempted to use: \ + $(detector_list)") + end + + detector = detector_list[idx] + + @debug "Detected cluster environment" detector + + if coordinator_address === nothing + coordinator_address = get_coordinator_address( + detector, initialization_timeout_in_seconds + ) + end + + if num_processes === nothing + num_processes = get_process_count(detector) + end + + if process_id === nothing + process_id = get_process_id(detector) + end + + if local_gpu_device_ids === nothing + local_gpu_device_ids = [get_local_process_id(detector)] + end + + return coordinator_address, num_processes, process_id, local_gpu_device_ids +end + +# OpenMPIORTEEnvDetector & OpenMPIPMIXEnvDetector +# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/ompi_cluster.py and adapted for latest OpenMPI versions +const _ORTE_URI = "OMPI_MCA_orte_hnp_uri" +const _PMIX_SERVER_URI = ( + "PMIX_SERVER_URI2", + "PMIX_SERVER_URI3", + "PMIX_SERVER_URI4", + "PMIX_SERVER_URI41", + "PMIX_SERVER_URI21", +) +const _OMPI_PROCESS_COUNT = "OMPI_COMM_WORLD_SIZE" +const _OMPI_PROCESS_ID = "OMPI_COMM_WORLD_RANK" +const _OMPI_LOCAL_PROCESS_ID = "OMPI_COMM_WORLD_LOCAL_RANK" + +is_env_present(::OpenMPIORTEEnvDetector) = haskey(ENV, _ORTE_URI) +is_env_present(::OpenMPIPMIXEnvDetector) = any(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI) + +function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer) + orte_uri = ENV[_ORTE_URI] + + job_id = parse(Int, split(orte_uri, '.'; limit=2)[1]) + port = job_id % 2^12 + (65535 - 2^12 + 1) + + launcher_ip_match = match(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri) + + @assert launcher_ip_match !== nothing "Could not parse coordinator IP address from \ + Open MPI environment." + + launcher_ip = launcher_ip_match.captures[findfirst( + !isnothing, launcher_ip_match.captures + )] + return "$(launcher_ip):$(port)" +end + +function get_coordinator_address(::OpenMPIPMIXEnvDetector, ::Integer) + varname = findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI) + pmix_uri = ENV[_PMIX_SERVER_URI[varname]] + + job_id = parse(Int, split(split(pmix_uri, '-'; limit=3)[3], "@"; limit=2)[1]) + port = job_id % 2^12 + (65535 - 2^12 + 1) + + launcher_ip_match = match(r"tcp4://(.+?):|tcp6://\[(.+?)\]", pmix_uri) + + @assert launcher_ip_match !== nothing "Could not parse coordinator IP address from \ + Open MPI environment." + + launcher_ip = launcher_ip_match.captures[findfirst( + !isnothing, launcher_ip_match.captures + )] + + return "$(launcher_ip):$(port)" +end + +get_process_count(::AbstractOMPIClusterEnvDetector) = parse(Int, ENV[_OMPI_PROCESS_COUNT]) + +get_process_id(::AbstractOMPIClusterEnvDetector) = parse(Int, ENV[_OMPI_PROCESS_ID]) + +function get_local_process_id(::AbstractOMPIClusterEnvDetector) + return parse(Int, ENV[_OMPI_LOCAL_PROCESS_ID]) +end + +end diff --git a/src/Enzyme.jl b/src/Enzyme.jl new file mode 100644 index 0000000000..366352b5da --- /dev/null +++ b/src/Enzyme.jl @@ -0,0 +1,7 @@ +# TODO: move the overload_autodiff here as well + +# The default `onehot` will lead to scalar indexing +function Enzyme.onehot(x::TracedRArray{T,N}) where {T,N} + x_arr = zeros(T, size(x)) + return map(Base.Fix1(TracedUtils.promote_to, TracedRArray{T,N}), Enzyme.onehot(x_arr)) +end diff --git a/src/Interpreter.jl b/src/Interpreter.jl index f75e57cfa4..f3267ed353 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -39,6 +39,25 @@ function set_reactant_abi( ) (; fargs, argtypes) = arginfo + if f === ReactantCore.within_compile + if length(argtypes) != 1 + @static if VERSION < v"1.11.0-" + return CallMeta(Union{}, Effects(), NoCallInfo()) + else + return CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) + end + end + @static if VERSION < v"1.11.0-" + return CallMeta( + Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure() + ) + else + return CallMeta( + Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure() + ) + end + end + # Improve inference by considering call_with_reactant as having the same results as # the original call if f === Reactant.call_with_reactant @@ -64,8 +83,9 @@ end ReactantCacheToken(), REACTANT_METHOD_TABLE, world, - true, #=forward_rules=# - true, #=reverse_rules=# + false, #=forward_rules=# + false, #=reverse_rules=# + false, #=inactive_rules=# false, #=broadcast_rewrite=# set_reactant_abi, ) @@ -80,8 +100,9 @@ else REACTANT_CACHE, REACTANT_METHOD_TABLE, world, - true, #=forward_rules=# - true, #=forward_rules=# + false, #=forward_rules=# + false, #=reverse_rules=# + false, #=inactive_rules=# false, #=broadcast_rewrite=# set_reactant_abi, ) @@ -167,7 +188,7 @@ end function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse) TracedUtils.push_val!(ad_inputs, x.val, path) if !reverse - ET = eltype(x.val) + ET = unwrapped_eltype(x.val) predims = size(x.val) cval = MLIR.IR.result( MLIR.Dialects.stablehlo.concatenate( @@ -182,7 +203,7 @@ end function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse) TracedUtils.push_val!(ad_inputs, x.val, path) if !reverse - ET = eltype(x.val) + ET = unwrapped_eltype(x.val) predims = size(x.val) cval = MLIR.IR.result( MLIR.Dialects.stablehlo.concatenate( @@ -206,14 +227,13 @@ function set_act!(inp, path, reverse, tostore; emptypath=false) end #if inp isa Enzyme.Active || !reverse - x.mlir_data = tostore + TracedUtils.set_mlir_data!(x, tostore) #else # x.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(x.mlir_data, tostore), 1) #end - if emptypath - x.paths = () - end + emptypath && TracedUtils.set_paths!(x, ()) + return nothing end function overload_autodiff( @@ -235,9 +255,12 @@ function overload_autodiff( primf = f.val primargs = ((v.val for v in args)...,) - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn( + mlir_fn_res = TracedUtils.make_mlir_fn( primf, primargs, (), string(f) * "_autodiff", false ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f activity = Int32[] ad_inputs = MLIR.IR.Value[] @@ -264,22 +287,35 @@ function overload_autodiff( for a in linear_results if TracedUtils.has_residx(a) if needs_primal(CMode) - push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) + push!( + outtys, + TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))), + ) end if CMode <: Enzyme.ForwardMode && !(A <: Enzyme.Const) if width == 1 - push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) + push!( + outtys, + TracedUtils.transpose_ty( + MLIR.IR.type(TracedUtils.get_mlir_data(a)) + ), + ) else push!( outtys, TracedUtils.batch_ty( - width, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)) + width, + TracedUtils.transpose_ty( + MLIR.IR.type(TracedUtils.get_mlir_data(a)) + ), ), ) end end else - push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) + push!( + outtys, TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))) + ) end end for (i, act) in enumerate(activity) @@ -298,7 +334,9 @@ function overload_autodiff( act = act_from_type(A, reverse, needs_primal(CMode)) push!(ret_activity, act) if act == enzyme_out || act == enzyme_outnoneed - attr = fill(MLIR.IR.Attribute(eltype(a)(1)), Ops.mlir_type(a)) + attr = MLIR.IR.DenseElementsAttribute( + fill(one(unwrapped_eltype(a)), size(a)) + ) cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) push!(ad_inputs, cst) end diff --git a/src/Ops.jl b/src/Ops.jl index 18ab2d7d4b..b49ca654c0 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -4,23 +4,30 @@ module Ops using ..MLIR: MLIR using ..MLIR.Dialects: stablehlo, chlo, enzyme -using ..Reactant: Reactant, TracedRArray, TracedRNumber, RArray, RNumber, MissingTracedValue - -function mlir_type(x::RArray{T,N}) where {T,N} - return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T)) +using ..Reactant: + Reactant, + TracedRArray, + TracedRNumber, + RArray, + RNumber, + MissingTracedValue, + unwrapped_eltype +using ReactantCore: ReactantCore +using Functors: fmap + +function mlir_type(x::Union{RNumber,RArray}) + return MLIR.IR.TensorType(size(x), MLIR.IR.Type(unwrapped_eltype(x))) end -mlir_type(::RNumber{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T)) - mlir_type(::MissingTracedValue) = MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) -function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N} +function mlir_type(RT::Type{<:RArray{T,N}}, shape) where {T,N} @assert length(shape) == N - return MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) + return MLIR.IR.TensorType(shape, MLIR.IR.Type(unwrapped_eltype(RT))) end -function mlir_type(::Type{<:RNumber{T}}) where {T} - return MLIR.IR.TensorType((), MLIR.IR.Type(T)) +function mlir_type(RT::Type{<:RNumber}) + return MLIR.IR.TensorType((), MLIR.IR.Type(unwrapped_eltype(RT))) end function mlir_type(::Type{<:MissingTracedValue}) @@ -73,10 +80,82 @@ end @noinline function constant( x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T<:Number} - res = constant(fill(x); location) + x isa TracedRNumber && return x + res = fill(x; location) return TracedRNumber{T}((), res.mlir_data) end +function fill( + v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__) +) + return fill(v, dims; location) +end +function fill( + v, + dims::NTuple{N,Union{Integer,Base.OneTo}}; + location=mlir_stacktrace("fill", @__FILE__, @__LINE__), +) where {N} + return fill(v, map(Base.to_dim, dims); location) +end +function fill( + v, dims::NTuple{N,Integer}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__) +) where {N} + return fill(v, collect(dims); location) +end +function fill(v, ::Tuple{}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)) + return fill(v, Int[]; location) +end + +function fill(number::TracedRNumber{T}, shape::Vector{Int}; location) where {T} + return Base.fill(number, Tuple(shape)) +end + +for (T, mlir_func) in ( + (Bool, :mlirDenseElementsAttrBoolSplatGet), + (UInt8, :mlirDenseElementsAttrUInt8SplatGet), + (Int8, :mlirDenseElementsAttrInt8SplatGet), + (UInt32, :mlirDenseElementsAttrUInt32SplatGet), + (Int32, :mlirDenseElementsAttrInt32SplatGet), + (UInt64, :mlirDenseElementsAttrUInt64SplatGet), + (Int64, :mlirDenseElementsAttrInt64SplatGet), + (Float32, :mlirDenseElementsAttrFloatSplatGet), + (Float64, :mlirDenseElementsAttrDoubleSplatGet), +) + @eval begin + @noinline function fill( + number::$T, + shape::Vector{Int}; + location=mlir_stacktrace("fill", @__FILE__, @__LINE__), + ) + tt = MLIR.IR.TensorType(shape, MLIR.IR.Type($T); location=location) + + splatattr = MLIR.API.$mlir_func(tt, number) + cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location) + cst = MLIR.IR.result(cst_op) + ta = TracedRArray{$T,length(shape)}((), cst, shape) + return ta + end + end +end + +_fill_element_attr(x) = MLIR.IR.Attribute(x) +function _fill_element_attr(x::Complex) + return MLIR.IR.Attribute([ + MLIR.IR.Attribute(Base.real(x)), MLIR.IR.Attribute(Base.imag(x)) + ]) +end + +@noinline function fill( + element::T, shape::Vector{Int}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__) +) where {T} + tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) + splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element)) + cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location) + cst = MLIR.IR.result(cst_op) + ta = TracedRArray{T,length(shape)}((), cst, shape) + return ta +end + # unary elementwise ops for (dialect, op) in [ (:stablehlo, :abs), @@ -201,12 +280,9 @@ for (dialect, op) in [ end # is* checks -for (dialect, op) in [ - #(:stablehlo, :is_finite), - (:chlo, :is_inf), - (:chlo, :is_neg_inf), - (:chlo, :is_pos_inf), -] +for (dialect, op) in + [(:stablehlo, :is_finite), (:chlo, :is_inf), (:chlo, :is_neg_inf), (:chlo, :is_pos_inf)] + result = dialect == :stablehlo ? :y : :result @eval begin @noinline function $op( x::TracedRArray{T,N}; @@ -214,7 +290,9 @@ for (dialect, op) in [ ) where {T,N} res = MLIR.IR.result( $(:($dialect.$op))( - x.mlir_data; result=mlir_type(TracedRArray{Bool,N}, size(x)), location + x.mlir_data; + $(result)=mlir_type(TracedRArray{Bool,N}, size(x)), + location, ), ) return TracedRArray{Bool,N}((), res, size(x)) @@ -226,7 +304,7 @@ for (dialect, op) in [ ) where {T} res = MLIR.IR.result( $(:($dialect.$op))( - x.mlir_data; result=mlir_type(TracedRArray{Bool,0}, ()), location + x.mlir_data; $(result)=mlir_type(TracedRArray{Bool,0}, ()), location ), ) return TracedRNumber{Bool}((), res) @@ -234,26 +312,6 @@ for (dialect, op) in [ end end -@noinline function is_finite( - x::TracedRArray{T,N}; location=mlir_stacktrace("is_finite", @__FILE__, @__LINE__) -) where {T,N} - res = MLIR.IR.result( - stablehlo.is_finite( - x.mlir_data; y=mlir_type(TracedRArray{Bool,N}, size(x)), location - ), - ) - return TracedRArray{Bool,N}((), res, size(x)) -end - -@noinline function is_finite( - x::TracedRNumber{T}; location=mlir_stacktrace("is_finite", @__FILE__, @__LINE__) -) where {T} - res = MLIR.IR.result( - stablehlo.is_finite(x.mlir_data; y=mlir_type(TracedRArray{Bool,0}, ()), location) - ) - return TracedRNumber{Bool}((), res) -end - # fixes to default automated implementations @noinline function abs( x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("abs", @__FILE__, @__LINE__) @@ -343,9 +401,9 @@ end @noinline function pad( x::TracedRArray{T,N}, padding_value::TracedRNumber{T}; - low=fill(0, N), - high=fill(0, N), - interior=fill(0, N), + low=Base.fill(0, N), + high=Base.fill(0, N), + interior=Base.fill(0, N), location=mlir_stacktrace("pad", @__FILE__, @__LINE__), ) where {T,N} rsize = size(x) .+ low .+ high .+ max.(size(x) .- 1, 0) .* interior @@ -470,6 +528,18 @@ end # ) # return TracedRArray{T,N}((), res, size(x)) # end +@noinline function bitcast_convert( + ::Type{U}, + x::TracedRNumber{T}; + location=mlir_stacktrace("bitcast_convert", @__FILE__, @__LINE__), +) where {T,U} + res = MLIR.IR.result( + stablehlo.bitcast_convert( + x.mlir_data; result_0=mlir_type(TracedRArray{U,0}, ()), location + ), + ) + return TracedRNumber{U}((), res) +end @noinline function fft( x::TracedRArray{T,N}; @@ -931,59 +1001,136 @@ end end # broadcast ops -# function broadcast_in_dim( -# x::TracedRArray{T,N}, -# dims::Vector{Int}; -# location=mlir_stacktrace( -# "broadcast_in_dim", @__FILE__, @__LINE__ -# ), -# ) where {T,N} -# rsize = restype = MLIR.IR.TensorType([...], mlir_type(T)) # mlir_type(TracedRArray{T,N}, size(x)) -# res = MLIR.IR.result( -# stablehlo.broadcast_in_dim( -# x.mlir_data; -# result_0=restype, -# broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), -# location, -# ), -# ) -# return TracedRArray{T,N}((), res, size(x)) -# end +function broadcast_in_dim( + x::TracedRArray{T,N}, + dims::Vector{Int}, + result_size::Vector{Int}; + location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__), +) where {T,N} + @assert length(dims) == N -# sorting ops -# TODO need to trace over `comparator` -# function sort( -# x::TracedRArray{T,N}; -# comparator, -# dimension=-1, -# is_stable=false, -# location=mlir_stacktrace("sort", @__FILE__, @__LINE__), -# ) where {T,N} -# dimension = MLIR.IR.Attribute(dimension) -# is_stable = MLIR.IR.Attribute(is_stable) -# res = MLIR.IR.result( -# stablehlo.sort( -# x.mlir_data; -# result=mlir_type(TracedRArray{T,N}, size(x)), -# dimension, -# is_stable, -# location, -# ), -# ) -# return TracedRArray{T,N}((), res, size(x)) -# end + res = MLIR.IR.result( + stablehlo.broadcast_in_dim( + x.mlir_data; + result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)), + broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1), + location, + ), + ) + return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size)) +end + +function broadcast_in_dim( + x::TracedRNumber{T}, + dims::Vector{Int}, + result_size::Vector{Int}; + location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__), +) where {T} + @assert length(dims) == 0 + + res = MLIR.IR.result( + stablehlo.broadcast_in_dim( + x.mlir_data; + result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)), + broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1), + location, + ), + ) + return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size)) +end + +@noinline function sort( + xs::TracedRArray...; + comparator, + dimension=1, + is_stable=false, + location=mlir_stacktrace("sort", @__FILE__, @__LINE__), +) + #C4: + for x in xs + @assert 0 < dimension <= ndims(x) "$x invalid dimension" + end + + sample_inputs = Vector{Reactant.ConcretePJRTNumber}(undef, length(xs) * 2) + for i in eachindex(xs) + T = Reactant.unwrapped_eltype(xs[i]) + sample_inputs[2i - 1] = Reactant.ConcretePJRTNumber(T(0)) + sample_inputs[2i] = Reactant.ConcretePJRTNumber(T(0)) + end + func = + Reactant.TracedUtils.make_mlir_fn( + comparator, + (sample_inputs...,), + (), + "comparator"; + args_in_result=:none, + return_dialect=:stablehlo, + ).f + @assert MLIR.IR.nregions(func) == 1 + fn_name = String( + MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())) + ) + #C5: + @assert fn_name == "comparator" "$comparator: no function generated" + ftype_attr = MLIR.IR.attr(func, "function_type") + ftype = MLIR.IR.Type(ftype_attr) + @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) error( + "$comparator return type is not tensor" + ) + + comparator = MLIR.IR.Region() + MLIR.API.mlirRegionTakeBody(comparator, MLIR.IR.region(func, 1)) + MLIR.IR.rmfromparent!(func) + + dimension = MLIR.IR.Attribute(dimension - 1) + is_stable = MLIR.IR.Attribute(is_stable) + + op = stablehlo.sort( + [x.mlir_data for x in xs]; + result_0=[mlir_type(typeof(x), size(x)) for x in xs], + dimension, + is_stable, + comparator, + location, + ) + return [ + TracedRArray{Reactant.unwrapped_eltype(xs[i]),ndims(xs[i])}( + (), MLIR.IR.result(op, i), size(xs[i]) + ) for i in eachindex(xs) + ] +end @noinline function top_k( - x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__) + x::TracedRArray{T,N}, + k; + dimension::Integer=N, + location=mlir_stacktrace("top_k", @__FILE__, @__LINE__), ) where {T,N} + @assert 1 <= dimension <= N + if dimension != N # chlo.top_k performs the operation along the last dimension + pdims = collect(Int64, 1:N) + pdims[dimension] = N + pdims[N] = dimension + x = permutedims(x, pdims) + end + rsize = [size(x)[1:(end - 1)]..., k] values = mlir_type(TracedRArray{T,N}, rsize) indices = mlir_type(TracedRArray{Int32,N}, rsize) op = chlo.top_k(x.mlir_data; values, indices, k, location) - return (; - values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize), - indices=TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize), - ) + indices = add( + TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize), + fill(Int32(1), Tuple(rsize)), + ) # return the 1-indexed index + indices = convert(TracedRArray{Int64,N}, indices) # julia indexes with Int64 generally + values = TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize) + + if dimension != N + values = permutedims(values, invperm(pdims)) + indices = permutedims(indices, invperm(pdims)) + end + + return (; values, indices) end @noinline function iota( @@ -1077,7 +1224,7 @@ end (; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location) output = divide( convert(TracedRArray{T,ndims(output)}, output), - constant(fill(T(typemax(uT)), Tuple(shape)); location), + fill(T(typemax(uT)), Tuple(shape); location), ) return (; output_state, output) end @@ -1117,11 +1264,11 @@ fields: rand_uniform = res.output seed = res.output_state scaled_uniform = subtract( - multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))), - constant(fill(T(1), size(rand_uniform))), + multiply(rand_uniform, fill(T(2), size(rand_uniform))), + fill(T(1), size(rand_uniform)), ) probit = erf_inv(scaled_uniform) - rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform)))) + rand_normal = multiply(probit, fill(Base.sqrt(T(2)), size(rand_uniform))) return (; output_state=seed, output=rand_normal) end @@ -1288,7 +1435,7 @@ julia> Reactant.@jit( Reactant.to_rarray(Float32[1, 2, 3]), ) ) -(ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),) +(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),) ``` """ @noinline function hlo_call( @@ -1385,4 +1532,785 @@ julia> Reactant.@jit( end end +""" + scatter_setindex(dest, scatter_indices, updates) + +Uses [`MLIR.Dialects.stablehlo.scatter`](@ref) to set the values of `dest` at the indices +specified by `scatter_indices` to the values in `updates`. If the indices are contiguous it +is recommended to directly use [`MLIR.Dialects.stablehlo.dynamic_update_slice`](@ref) +instead. +""" +@noinline function scatter_setindex( + dest::TracedRArray{T,N}, + scatter_indices::TracedRArray{Int64,2}, + updates::TracedRArray{T2,1}, +) where {T,N,T2} + @assert length(updates) == size(scatter_indices, 1) + @assert size(scatter_indices, 2) == N + + updates = convert(TracedRArray{T,1}, updates) + + update_computation = MLIR.IR.Region() + block = MLIR.IR.Block( + [mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})], + [MLIR.IR.Location(), MLIR.IR.Location()], + ) + return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)]) + MLIR.IR.rmfromparent!(return_op) + push!(block, return_op) + pushfirst!(update_computation, block) + + #! format: off + update_window_dims = Int64[] + inserted_window_dims = collect(Int64, 0:(N - 1)) + input_batching_dims = Int64[] + scatter_indices_batching_dims = Int64[] + scatter_dims_to_operand_dims = collect(Int64, 0:(N - 1)) + index_vector_dim = Int64(1) + + scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet( + MLIR.IR.context(), + length(update_window_dims), update_window_dims, + length(inserted_window_dims), inserted_window_dims, + length(input_batching_dims), input_batching_dims, + length(scatter_indices_batching_dims), scatter_indices_batching_dims, + length(scatter_dims_to_operand_dims), scatter_dims_to_operand_dims, + index_vector_dim, + ) + #! format: on + + return TracedRArray{T,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.scatter( + [dest.mlir_data], + scatter_indices.mlir_data, + [updates.mlir_data]; + result_0=[mlir_type(TracedRArray{T,N}, size(dest))], + update_computation, + scatter_dimension_numbers, + ), + 1, + ), + size(dest), + ) +end + +""" + gather_getindex(src, gather_indices) + +Uses [`MLIR.Dialects.stablehlo.gather`](@ref) to get the values of `src` at the indices +specified by `gather_indices`. If the indices are contiguous it is recommended to directly +use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead. +""" +@noinline function gather_getindex( + src::TracedRArray{T,N}, gather_indices::TracedRArray{Int64,2} +) where {T,N} + @assert size(gather_indices, 2) == N + + #! format: off + offset_dims = Int64[1] + collapsed_slice_dims = collect(Int64, 0:(N - 2)) + operand_batching_dims = Int64[] + start_indices_batching_dims = Int64[] + start_index_map = collect(Int64, 0:(N - 1)) + index_vector_dim = Int64(1) + + dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( + MLIR.IR.context(), + Int64(length(offset_dims)), offset_dims, + Int64(length(collapsed_slice_dims)), collapsed_slice_dims, + Int64(length(operand_batching_dims)), operand_batching_dims, + Int64(length(start_indices_batching_dims)), start_indices_batching_dims, + Int64(length(start_index_map)), start_index_map, + Int64(index_vector_dim), + ) + #! format: on + + return reshape( + TracedRArray{T}( + MLIR.IR.result( + MLIR.Dialects.stablehlo.gather( + src.mlir_data, + gather_indices.mlir_data; + dimension_numbers, + slice_sizes=Base.fill(Int64(1), N), + indices_are_sorted=false, + ), + 1, + ), + ), + size(gather_indices, 1), + ) +end + +@noinline function while_loop(cond_fn::CFn, body_fn::BFn, args...) where {CFn,BFn} + # TODO: detect and prevent mutation within the condition + + # Make all the args traced or concrete + N = length(args) + seen_args = Reactant.OrderedIdDict() + traced_args = Vector{Any}(undef, N) + for i in 1:N + @inbounds traced_args[i] = Reactant.make_tracer( + seen_args, args[i], (), Reactant.NoStopTracedTrack; track_numbers=Number + ) + end + + linear_args = Reactant.TracedType[] + for (k, v) in seen_args + v isa Reactant.TracedType || continue + push!(linear_args, v) + end + + input_types = [mlir_type(arg) for arg in linear_args] + + cond_fn_compiled = + Reactant.TracedUtils.make_mlir_fn( + cond_fn, + traced_args, + (), + string(gensym("cond_fn")), + false; + return_dialect=:stablehlo, + args_in_result=:none, + do_transpose=false, + ).f + + body_fn_compiled = + Reactant.TracedUtils.make_mlir_fn( + body_fn, + traced_args, + (), + string(gensym("body_fn")), + false; + return_dialect=:stablehlo, + args_in_result=:none, + do_transpose=false, + ).f + + cond_reg = Reactant.TracedUtils.__take_region(cond_fn_compiled) + body_reg = Reactant.TracedUtils.__take_region(body_fn_compiled) + + MLIR.IR.rmfromparent!(cond_fn_compiled) + MLIR.IR.rmfromparent!(body_fn_compiled) + + while_op = MLIR.Dialects.stablehlo.while_( + MLIR.IR.Value[Reactant.TracedUtils.get_mlir_data(arg) for arg in linear_args]; + result_0=input_types, + cond=cond_reg, + body=body_reg, + ) + + return map(enumerate(linear_args)) do (i, arg) + Reactant.TracedUtils.set_mlir_data!(arg, MLIR.IR.result(while_op, i)) + end +end + +@noinline function if_condition( + cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args... +) where {TFn,FFn} + true_fn_names = (gensym(:true_fn_args), gensym(:true_result), gensym(:true_fn_resargs)) + false_fn_names = ( + gensym(:false_fn_args), gensym(:false_result), gensym(:false_fn_resargs) + ) + + # Make all the args traced or concrete + N = length(args) + tb_seen_args = Reactant.OrderedIdDict() + fb_seen_args = Reactant.OrderedIdDict() + tb_traced_args = Vector{Any}(undef, N) + fb_traced_args = Vector{Any}(undef, N) + for i in 1:N + @inbounds tb_traced_args[i] = Reactant.make_tracer( + tb_seen_args, + args[i], + (true_fn_names[1], i), + Reactant.TracedSetPath; + track_numbers=Number, + ) + @inbounds fb_traced_args[i] = Reactant.make_tracer( + fb_seen_args, + args[i], + (false_fn_names[1], i), + Reactant.TracedSetPath; + track_numbers=Number, + ) + end + + tb_linear_args = Reactant.TracedType[ + v for v in values(tb_seen_args) if v isa Reactant.TracedType + ] + fb_linear_args = Reactant.TracedType[ + v for v in values(fb_seen_args) if v isa Reactant.TracedType + ] + + input_types = [mlir_type(arg) for arg in tb_linear_args] + sym_visibility = MLIR.IR.Attribute("private") + + # compile the true branch without any returns first + true_fn_mod = MLIR.IR.mmodule() + true_func_tmp = MLIR.IR.block!(MLIR.IR.body(true_fn_mod)) do + return MLIR.Dialects.func.func_(; + sym_name=string(true_fn) * "_tb_tmp", + function_type=MLIR.IR.FunctionType(input_types, []), + body=MLIR.IR.Region(), + sym_visibility, + ) + end + true_fn_body = MLIR.IR.Block() + push!(MLIR.IR.region(true_func_tmp, 1), true_fn_body) + + true_fn_args = true_fn_names[1] + + MLIR.IR.activate!(true_fn_body) + tb_result = try + for (i, arg) in enumerate(tb_linear_args) + # find the right path to index the traced arg. + path = nothing + for p in Reactant.TracedUtils.get_paths(arg) + if length(p) > 0 && p[1] == true_fn_args + path = p[2:end] + end + end + if isnothing(path) + error("if_condition: could not find path for linear arg $i") + end + Reactant.TracedUtils.set_mlir_data!( + arg, + only( + Reactant.TracedUtils.push_val!( + [], tb_traced_args[path[1]], path[2:end] + ), + ), + ) + end + Reactant.call_with_reactant(true_fn, tb_traced_args...) + finally + MLIR.IR.deactivate!(true_fn_body) + end + + seen_true_results = Reactant.OrderedIdDict() + traced_true_results = Reactant.make_tracer( + seen_true_results, + tb_result, + (true_fn_names[2],), + Reactant.NoStopTracedTrack; + track_numbers=Number, + ) + for i in eachindex(tb_linear_args) + Reactant.make_tracer( + seen_true_results, + tb_linear_args[i], + (true_fn_names[3], i), + Reactant.NoStopTracedTrack; + track_numbers=Number, + ) + end + + tb_linear_results = Reactant.TracedType[ + v for v in values(seen_true_results) if v isa Reactant.TracedType + ] + + # compile the false branch without any returns similar to the true branch + false_fn_mod = MLIR.IR.mmodule() + false_func_tmp = MLIR.IR.block!(MLIR.IR.body(false_fn_mod)) do + return MLIR.Dialects.func.func_(; + sym_name=string(false_fn) * "_fb_tmp", + function_type=MLIR.IR.FunctionType(input_types, []), + body=MLIR.IR.Region(), + sym_visibility, + ) + end + false_fn_body = MLIR.IR.Block() + push!(MLIR.IR.region(false_func_tmp, 1), false_fn_body) + + false_fn_args = false_fn_names[1] + MLIR.IR.activate!(false_fn_body) + fb_result = try + for (i, arg) in enumerate(fb_linear_args) + # find the right path to index the traced arg. + path = nothing + for p in Reactant.TracedUtils.get_paths(arg) + if length(p) > 0 && p[1] == false_fn_args + path = p[2:end] + end + end + if isnothing(path) + error("if_condition: could not find path for linear arg $i") + end + Reactant.TracedUtils.set_mlir_data!( + arg, + only( + Reactant.TracedUtils.push_val!( + [], fb_traced_args[path[1]], path[2:end] + ), + ), + ) + end + Reactant.call_with_reactant(false_fn, fb_traced_args...) + finally + MLIR.IR.deactivate!(false_fn_body) + end + + seen_false_results = Reactant.OrderedIdDict() + traced_false_results = Reactant.make_tracer( + seen_false_results, + fb_result, + (false_fn_names[2],), + Reactant.NoStopTracedTrack; + track_numbers=Number, + ) + for i in eachindex(fb_linear_args) + Reactant.make_tracer( + seen_false_results, + fb_linear_args[i], + (false_fn_names[3], i), + Reactant.NoStopTracedTrack; + track_numbers=Number, + ) + end + + fb_linear_results = Reactant.TracedType[ + v for v in values(seen_false_results) if v isa Reactant.TracedType + ] + + tb_results_dict = Dict{Tuple,Reactant.TracedType}() + for tr in tb_linear_results + for path in Reactant.TracedUtils.get_paths(tr) + if length(path) > 0 && + (path[1] == true_fn_names[2] || path[1] == true_fn_names[3]) + tb_results_dict[path] = tr + end + end + end + + fb_results_dict = Dict{Tuple,Reactant.TracedType}() + for fr in fb_linear_results + for path in Reactant.TracedUtils.get_paths(fr) + if length(path) > 0 && + (path[1] == false_fn_names[2] || path[1] == false_fn_names[3]) + fb_results_dict[path] = fr + end + end + end + + all_paths = [] + for (path, tr) in tb_results_dict + if path[1] == true_fn_names[2] + push!(all_paths, (:result, path[2:end]...)) + elseif path[1] == true_fn_names[3] + push!(all_paths, (:resarg, path[2:end]...)) + end + end + for (path, fr) in fb_results_dict + if path[1] == false_fn_names[2] + push!(all_paths, (:result, path[2:end]...)) + elseif path[1] == false_fn_names[3] + push!(all_paths, (:resarg, path[2:end]...)) + end + end + all_paths = sort!(unique!(all_paths)) + tb_paths = [ + if path[1] == :result + (true_fn_names[2], path[2:end]...) + else + (true_fn_names[3], path[2:end]...) + end for path in all_paths + ] + fb_paths = [ + if path[1] == :result + (false_fn_names[2], path[2:end]...) + else + (false_fn_names[3], path[2:end]...) + end for path in all_paths + ] + + # finalize the true branch by adding the missing values + MLIR.IR.activate!(true_fn_body) + tb_corrected_linear_results = Reactant.TracedType[] + try + for (i, path) in enumerate(tb_paths) + if haskey(tb_results_dict, tb_paths[i]) + push!(tb_corrected_linear_results, tb_results_dict[tb_paths[i]]) + else + push!(tb_corrected_linear_results, zero(fb_results_dict[fb_paths[i]])) + end + end + finally + MLIR.IR.deactivate!(true_fn_body) + end + + # finalize the false branch by adding the missing values + MLIR.IR.activate!(false_fn_body) + fb_corrected_linear_results = Reactant.TracedType[] + try + for (i, path) in enumerate(fb_paths) + if haskey(fb_results_dict, fb_paths[i]) + push!(fb_corrected_linear_results, fb_results_dict[fb_paths[i]]) + else + push!(fb_corrected_linear_results, zero(tb_results_dict[tb_paths[i]])) + end + end + finally + MLIR.IR.deactivate!(false_fn_body) + end + + # All MissingTracedValues must be replaced with zeroes + @assert length(tb_corrected_linear_results) == length(fb_corrected_linear_results) + + result_types = MLIR.IR.Type[] + for (i, (tr, fr)) in + enumerate(zip(tb_corrected_linear_results, fb_corrected_linear_results)) + if tr isa MissingTracedValue && fr isa MissingTracedValue + continue # Don't insert into IR + end + res = if tr isa MissingTracedValue + @assert !(fr isa MissingTracedValue) + MLIR.IR.activate!(true_fn_body) + try + tb_corrected_linear_results[i] = zero(fr) + finally + MLIR.IR.deactivate!(true_fn_body) + end + fr + elseif fr isa MissingTracedValue + @assert !(tr isa MissingTracedValue) + MLIR.IR.activate!(false_fn_body) + try + fb_corrected_linear_results[i] = zero(tr) + finally + MLIR.IR.deactivate!(false_fn_body) + end + tr + else + if typeof(tr) != typeof(fr) + @assert typeof(tr) == typeof(fr) "$(typeof(tr)) vs $(typeof(fr))" + end + tr + end + push!(result_types, mlir_type(res)) + end + + MLIR.IR.activate!(true_fn_body) + try + vals = MLIR.IR.Value[ + Reactant.TracedUtils.get_mlir_data(res) for + res in tb_corrected_linear_results if !(res isa MissingTracedValue) + ] + MLIR.Dialects.stablehlo.return_(vals) + finally + MLIR.IR.deactivate!(true_fn_body) + end + + MLIR.IR.activate!(false_fn_body) + try + vals = MLIR.IR.Value[ + Reactant.TracedUtils.get_mlir_data(res) for + res in fb_corrected_linear_results if !(res isa MissingTracedValue) + ] + MLIR.Dialects.stablehlo.return_(vals) + finally + MLIR.IR.deactivate!(false_fn_body) + end + + # With the corrected results, we can compile the true and false branches + tb_out_types = [mlir_type(tr) for tr in tb_corrected_linear_results] + + true_fn_compiled = MLIR.IR.block!(MLIR.IR.body(true_fn_mod)) do + return MLIR.Dialects.func.func_(; + sym_name=Reactant.TracedUtils.__lookup_unique_name_in_module( + true_fn_mod, string(true_fn) * "_tb" + ), + function_type=MLIR.IR.FunctionType(input_types, tb_out_types), + body=MLIR.IR.Region(), + sym_visibility, + ) + end + MLIR.API.mlirRegionTakeBody( + MLIR.IR.region(true_fn_compiled, 1), MLIR.IR.region(true_func_tmp, 1) + ) + MLIR.API.mlirOperationDestroy(true_func_tmp.operation) + true_func_tmp.operation = MLIR.API.MlirOperation(C_NULL) + + fb_out_types = [mlir_type(fr) for fr in fb_corrected_linear_results] + + false_fn_compiled = MLIR.IR.block!(MLIR.IR.body(false_fn_mod)) do + return MLIR.Dialects.func.func_(; + sym_name=Reactant.TracedUtils.__lookup_unique_name_in_module( + false_fn_mod, string(false_fn) * "_fb" + ), + function_type=MLIR.IR.FunctionType(input_types, fb_out_types), + body=MLIR.IR.Region(), + sym_visibility, + ) + end + MLIR.API.mlirRegionTakeBody( + MLIR.IR.region(false_fn_compiled, 1), MLIR.IR.region(false_func_tmp, 1) + ) + MLIR.API.mlirOperationDestroy(false_func_tmp.operation) + false_func_tmp.operation = MLIR.API.MlirOperation(C_NULL) + + tb_region = Reactant.TracedUtils.__take_region(true_fn_compiled) + fb_region = Reactant.TracedUtils.__take_region(false_fn_compiled) + + MLIR.IR.rmfromparent!(true_fn_compiled) + MLIR.IR.rmfromparent!(false_fn_compiled) + + if_compiled = MLIR.Dialects.stablehlo.if_( + cond.mlir_data; true_branch=tb_region, false_branch=fb_region, result_0=result_types + ) + + corrected_traced_results = fmap(traced_false_results, traced_true_results) do fr, tr + if fr isa MissingTracedValue && tr isa MissingTracedValue + error("Both false and true branches are missing") + elseif fr isa MissingTracedValue + return tr + else + return fr + end + end + + for (residx, path) in enumerate(all_paths) + if path[1] == :result + Reactant.TracedUtils.set!( + corrected_traced_results, path[2:end], MLIR.IR.result(if_compiled, residx) + ) + elseif path[1] == :resarg + # The resarg path is with respect to the linear args, not the traced args. + # We find the path into traced args by searching for it in the linear args. + # Concretely, we look into tb_linear_args, but we could also look into fb_linear_args, they contain the same arg path. + @assert length(path) == 2 + argpath = nothing + for p in Reactant.TracedUtils.get_paths(tb_linear_args[path[2]]) + if length(p) > 0 && p[1] == true_fn_names[1] + argpath = p[2:end] + end + end + if isnothing(argpath) + error("if_condition: could not find path for resarg $path") + end + Reactant.TracedUtils.set!(args, argpath, MLIR.IR.result(if_compiled, residx)) + end + end + + return corrected_traced_results +end + +@noinline function call(f, args...) + seen_cache = Reactant.OrderedIdDict() + Reactant.make_tracer( + seen_cache, + args, + (), # we have to insert something here, but we remove it immediately below. + Reactant.TracedTrack; + toscalar=false, + ) + linear_args = [] + mlir_caller_args = Reactant.MLIR.IR.Value[] + for (k, v) in seen_cache + v isa Reactant.TracedType || continue + push!(linear_args, v) + push!(mlir_caller_args, v.mlir_data) + # make tracer inserted `()` into the path, here we remove it: + v.paths = v.paths[1:(end - 1)] + end + + seen = Dict() + cache_key = [] + Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes) + cache = Reactant.Compiler.callcache() + if haskey(cache, cache_key) + # cache lookup: + (; f_name, mlir_result_types, traced_result, mutated_args) = cache[cache_key] + else + f_name = String(gensym(Symbol(f))) + temp = Reactant.TracedUtils.make_mlir_fn( + f, args, (), f_name, false; args_in_result=:mutated, do_transpose=false + ) + (; traced_result, ret, mutated_args) = temp + mlir_result_types = [ + MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret) + ] + cache[cache_key] = (; f_name, mlir_result_types, traced_result, mutated_args) + end + + call_op = MLIR.Dialects.func.call( + mlir_caller_args; + result_0=mlir_result_types, + callee=MLIR.IR.FlatSymbolRefAttribute(f_name), + ) + + seen_results = Reactant.OrderedIdDict() + traced_result = Reactant.make_tracer( + seen_results, + traced_result, + (), # we have to insert something here, but we remove it immediately below. + Reactant.TracedSetPath; + toscalar=false, + ) + i = 1 + for (k, v) in seen_results + v isa Reactant.TracedType || continue + # this mutates `traced_result`, which is what we want: + v.mlir_data = MLIR.IR.result(call_op, i) + # make tracer inserted `()` into the path, here we remove it: + v.paths = v.paths[1:(end - 1)] + i += 1 + end + nres = MLIR.IR.nresults(call_op) + # mutated args are included as the last ones in the call op results + for (result_i, arg_i) in zip((nres - length(mutated_args)):nres, mutated_args) + Reactant.TracedUtils.set_mlir_data!( + linear_args[arg_i], MLIR.IR.result(call_op, result_i + 1) + ) + end + return traced_result +end + +# Shardy Ops +""" + mesh( + mesh::Reactant.Sharding.Mesh; mod::MLIR.IR.Module=MLIR.IR.mmodule(), + sym_name::String="mesh", + location=mlir_stacktrace("mesh", @__FILE__, @__LINE__) + ) + mesh( + mesh_axes::Vector{<:Pair{<:Union{String,Symbol},Int64}}, + device_ids::Vector{Int64}; + sym_name::String="mesh", + mod::MLIR.IR.Module=MLIR.IR.mmodule(), + location=mlir_stacktrace("mesh", @__FILE__, @__LINE__) + ) + +Produces a [`Reactant.MLIR.Dialects.sdy.mesh`](@ref) operation with the given `mesh` and +`device_ids`. + +Based on the provided `sym_name``, we generate a unique name for the mesh in the module's +`SymbolTable`. Note that users shouldn't use this sym_name directly, instead they should +use the returned `sym_name` to refer to the mesh in the module. + +!!! warning + + The `device_ids` argument are the logical device ids, not the physical device ids. + For example, if the physical device ids are `[2, 4, 123, 293]`, the corresponding + logical device ids are `[0, 1, 2, 3]`. + +## Returned Value + +We return a NamedTuple with the following fields: + +- `sym_name`: The unique name of the mesh in the module's `SymbolTable`. +- `mesh_attr`: `sdy::mlir::MeshAttr` representing the mesh. +- `mesh_op`: The `sdy.mesh` operation. +""" +@noinline function mesh( + m::Reactant.Sharding.Mesh; + mod::MLIR.IR.Module=MLIR.IR.mmodule(), + sym_name::String="mesh", + location=mlir_stacktrace("mesh", @__FILE__, @__LINE__), +) + cache = Reactant.Compiler.sdycache(; throw_error=ReactantCore.within_compile()) + cache !== nothing && haskey(cache, m) && return cache[m] + result = mesh( + [k => Int64(v) for (k, v) in zip(m.axis_names, size(m))], + m.logical_device_ids; + mod, + sym_name, + location, + ) + cache !== nothing && (cache[m] = result) + return result +end + +@noinline function mesh( + mesh_axes::Vector{<:Pair{<:Union{String,Symbol},Int64}}, + device_ids::AbstractVector{Int64}; + mod::MLIR.IR.Module=MLIR.IR.mmodule(), + sym_name::String="mesh", + location=mlir_stacktrace("mesh", @__FILE__, @__LINE__), +) + # See https://github.com/openxla/shardy/blob/f9d83e779a58b811b848c4edfaf68e88b636787d/shardy/dialect/sdy/ir/verifiers.cc#L647-L699 for the checks + ndevices = prod(last, mesh_axes) + + @assert allunique(first, mesh_axes) "mesh_axes must be unique" + @assert ndevices == length(device_ids) "length(device_ids) should be same as \ + prod(last, mesh_axes)" + @assert all(Base.Fix2(≥, 0), device_ids) "device_ids must be non-negative" + @assert Base.sort(device_ids) == 0:(ndevices - 1) "sorted device_ids must be the same \ + as iota(product(axes)), got \ + $(Base.sort(device_ids))" + + # error: if the ordered device ids are the same as iota(product(axes)), no need to + # specify them for simplicity + issorted(device_ids) && (device_ids = Int64[]) + + ctx = MLIR.IR.context() + mesh_axis_attrs = [ + MLIR.API.sdyMeshAxisAttrGet(ctx, String(name), size) for (name, size) in mesh_axes + ] + mesh_attr = MLIR.API.sdyMeshAttrGet( + ctx, + Int64(length(mesh_axis_attrs)), + mesh_axis_attrs, + Int64(length(device_ids)), + collect(Int64, device_ids), + ) + + sym_name = Reactant.TracedUtils.__lookup_unique_name_in_module(mod, sym_name) + + mesh_op = MLIR.IR.mmodule!(mod) do + return MLIR.Dialects.sdy.mesh(; sym_name, mesh=mesh_attr, location) + end + + # mesh_op needs to be moved to the beginning of the module + mesh_op = MLIR.IR.rmfromparent!(mesh_op) + mod_body = MLIR.IR.body(mod) + pushfirst!(mod_body, mesh_op) + + # We return the name of the mesh, since the operation is a Symbol op + return (; + sym_name=MLIR.IR.FlatSymbolRefAttribute(sym_name; context=ctx), + mesh_attr=MLIR.IR.Attribute(mesh_attr), + mesh_op=mesh_op, + ) +end + +""" + sharding_constraint( + input::Union{TracedRArray,TracedRNumber}, + sharding::Reactant.Sharding.AbstractSharding; + location=mlir_stacktrace("sharding_constraint", @__FILE__, @__LINE__) + ) + +Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation with the given +`input` and `sharding`. +""" +@noinline function sharding_constraint( + input::Union{AbstractArray,Number}, + sharding::Reactant.Sharding.AbstractSharding; + location=mlir_stacktrace("sharding_constraint", @__FILE__, @__LINE__), +) + !(input isa TracedRNumber || input isa TracedRArray) && + (input = constant(input; location)) + + cache = Reactant.Compiler.sdycache() + haskey(cache, sharding.mesh) || Ops.mesh(sharding.mesh; location) + (; sym_name, mesh_attr) = cache[sharding.mesh] + tensor_sharding_attr = Reactant.Sharding.get_shardy_tensor_sharding_attribute( + sharding, MLIR.IR.context(), sym_name, mesh_attr; do_transpose=true + ) + resharded_value = MLIR.IR.result( + MLIR.Dialects.sdy.sharding_constraint( + input.mlir_data; sharding=tensor_sharding_attr, location + ), + 1, + ) + if input isa TracedRNumber + return TracedRNumber{unwrapped_eltype(input)}(resharded_value) + else + return TracedRArray{unwrapped_eltype(input)}(resharded_value) + end +end + end # module Ops diff --git a/src/Overlay.jl b/src/Overlay.jl index b9785b7fa3..5d9b85c838 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -3,15 +3,6 @@ # correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved # we should move all the reactant_overrides to relevant files. -# Helper Function to determine if we are inside the ReactantInterpreter -""" - within_reactant_interpreter() - -Returns `true` if we are currently inside the ReactantInterpreter. -""" -@noinline within_reactant_interpreter() = false -@reactant_overlay @noinline within_reactant_interpreter() = true - # Compiling within a compile should return simply the original function @reactant_overlay function Compiler.compile( f, args; client=nothing, optimize=true, sync=false @@ -19,6 +10,14 @@ Returns `true` if we are currently inside the ReactantInterpreter. return f end +@reactant_overlay @noinline function Base.setindex!( + a::AnyTracedRArray{T,N}, v, indices::Vararg{Any,N} +) where {T,N} + ancestor_indices = TracedUtils.get_ancestor_indices(a, indices...) + (Base.inferencebarrier(setindex!))(Reactant.ancestor(a), v, ancestor_indices...) + return a +end + # Enzyme.jl overlays @reactant_overlay @noinline function Enzyme.autodiff_deferred( rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} @@ -37,6 +36,12 @@ end return call_with_reactant(TracedRandom.default_rng) end +@reactant_overlay @noinline function TracedRandom.default_rng() + return TracedRNG( + TracedUtils.promote_to(TracedRArray{UInt64,1}, TracedRandom.make_seed()), "DEFAULT" + ) +end + ## Only problematic edge case here is the direct `(rng, A::AbstractArray)` call ## We can't directly overlay that call without breaking the semantics of inplace update for randfun in (:rand, :randn, :randexp) @@ -51,13 +56,9 @@ for randfun in (:rand, :randn, :randexp) if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T, dims) end - return error( - "Reactant doesn't support sampling of $(T) with the current interpreter." - ) - # XXX: The following will lead to illegal instruction - # @warn "Reactant doesn't support sampling of $(T) with the current \ - # interpreter. Falling back to native interpreter." maxlog = 1 - # return Random.$(randfun)(rng, T, dims) + @warn "Reactant doesn't support sampling of $(T) with the current \ + interpreter. Falling back to native interpreter." maxlog = 1 + return Base.inferencebarrier(Random.$(randfun))(rng, T, dims) end @reactant_overlay @noinline function Random.$(randfun)( @@ -72,13 +73,9 @@ for randfun in (:rand, :randn, :randexp) if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) end - return error( - "Reactant doesn't support sampling of $(T) with the current interpreter." - ) - # XXX: The following will lead to illegal instruction - # @warn "Reactant doesn't support sampling of $(T) with the current \ - # interpreter. Falling back to native interpreter." maxlog = 1 - # return Random.$(randfun)(rng, T, dim1, dims...) + @warn "Reactant doesn't support sampling of $(T) with the current \ + interpreter. Falling back to native interpreter." maxlog = 1 + return Base.inferencebarrier(Random.$(randfun))(rng, T, dim1, dims...) end # scalars @@ -88,13 +85,9 @@ for randfun in (:rand, :randn, :randexp) if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T) end - return error( - "Reactant doesn't support sampling of $(T) with the current interpreter." - ) - # XXX: The following will lead to illegal instruction - # @warn "Reactant doesn't support sampling of $(T) with the current \ - # interpreter. Falling back to native interpreter." maxlog = 1 - # return Random.$(randfun)(rng, T) + @warn "Reactant doesn't support sampling of $(T) with the current \ + interpreter. Falling back to native interpreter." maxlog = 1 + return Base.inferencebarrier(Random.$(randfun))(rng, T) end # inplace @@ -103,15 +96,66 @@ for randfun in (:rand, :randn, :randexp) ) return TracedRandom.$(overload_randfun!)(rng, A) end + end +end + +# LinearAlgebra.jl overloads +## `mul!` goes through too many layers of abstractions and we aren't able to overload +## without specializing on every possible combination of types +for (cT, aT, bT) in ( + (:AbstractVector, :AbstractMatrix, :AbstractVector), + (:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat), +) + @eval begin + @reactant_overlay @noinline function LinearAlgebra.mul!( + C::$cT, A::$aT, B::$bT, α::Number, β::Number + ) + A, B = aos_to_soa(A), aos_to_soa(B) + C2 = aos_to_soa(C) + if use_overlayed_version((C2, A, B)) + TracedLinearAlgebra.overloaded_mul!(C2, A, B, α, β) + if C2 !== C + C .= C2 + end + else + # Inference barrier is required when calling function recursively within overload + # This is required since otherwise type inference will think this is a recursive edge + # rather than a call to the base method + Base.inferencebarrier(LinearAlgebra.mul!)(C, A, B, α, β) + end + return C + end + + # Needed mostly for 1.10 where 3-arg mul is often specialized + @reactant_overlay @noinline function LinearAlgebra.mul!(C::$cT, A::$aT, B::$bT) + call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false) + return C + end + end +end + +# Base overloads +@reactant_overlay @noinline function Base._stack(dims::Union{Integer,Colon}, iter) + if use_overlayed_version(iter) + return TracedRArrayOverrides.overloaded_stack(dims, iter) + else + iter2 = collect(iter) + if any(use_overlayed_version, iter2) + return TracedRArrayOverrides.overloaded_stack(dims, iter2) + else + # Inference barrier is required when calling function recursively within overload + # This is required since otherwise type inference will think this is a recursive edge + # rather than a call to the base method + return Base.inferencebarrier(Base._stack)(dims, iter2) + end + end +end - # XXX: Uncomment once AbsInt issues with recursive calls are resolved - # @reactant_overlay @noinline function Random.$(randfun!)( - # rng::AbstractRNG, A::AbstractArray - # ) - # @warn "Directly writing to an array using Random.jl functions inside \ - # ReactantInterpreter will generate a constant array in the IR. Use with \ - # caution." maxlog = 1 - # return Random.$(randfun!)(rng, A) - # end +## fixes #493 +@reactant_overlay @noinline function Base._unique_dims(A::AbstractArray, dims::Colon) + if use_overlayed_version(A) + error("Reactant doesn't have a `Base._unique_dims` with the current interpreter.") + else + Base.inferencebarrier(Base._unique_dims)(A, dims) end end diff --git a/src/Precompile.jl b/src/Precompile.jl new file mode 100644 index 0000000000..0860657e35 --- /dev/null +++ b/src/Precompile.jl @@ -0,0 +1,76 @@ +using PrecompileTools +using PrecompileTools: @setup_workload, @compile_workload + +function infer_sig(sig) + interp = ReactantInterpreter() + + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + + lookup_result = Reactant.lookup_world( + sig, interp.world, Core.Compiler.method_table(interp), min_world, max_world + ) + match = lookup_result::Core.MethodMatch + # look up the method and code instance + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, + ) + + @static if VERSION < v"1.11" + # For older Julia versions, we vendor in some of the code to prevent + # having to build the MethodInstance twice. + result = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) + frame = CC.InferenceState(result, :no, interp) + @assert !isnothing(frame) + CC.typeinf(interp, frame) + ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing) + rt = CC.widenconst(CC.ignorelimited(result.result)) + else + ir, rt = CC.typeinf_ircode(interp, mi, nothing) + end +end + +function clear_oc_cache() + # Opaque closures capture the worldage of their compilation and thus are not relocatable + # Therefore we explicitly purge all OC's we have created here + for v in oc_capture_vec + if v isa Base.RefValue + p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) + Base.atomic_pointerset(p, C_NULL, :monotonic) + else + empty!(v) + end + end +end + +# Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947 +function precompilation_supported() + return VERSION >= v"1.11" || VERSION >= v"1.10.8" +end + +function precompiling() + return (@ccall jl_generating_output()::Cint) == 1 +end + +@setup_workload begin + initialize_dialect() + client = XLA.PJRT.CPUClient(; checkcount=false) + @compile_workload begin + @static if precompilation_supported() + x = ConcretePJRTNumber(2.0; client) + Reactant.compile(sin, (x,); client, optimize=:all) + + y = ConcretePJRTArray([2.0]; client) + Reactant.compile(Base.sum, (y,); client, optimize=:all) + end + end + XLA.free_client(client) + client.client = C_NULL + deinitialize_dialect() + clear_oc_cache() +end diff --git a/src/PrimitiveTypes.jl b/src/PrimitiveTypes.jl new file mode 100644 index 0000000000..6d73d381c3 --- /dev/null +++ b/src/PrimitiveTypes.jl @@ -0,0 +1,80 @@ +# The types listed in this file are the ones present in StableHLO specification. + +# These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of +# types, users can use these to convert their types to the primitive types supported by +# Reactant. +for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ) + @eval begin + primitive type $(T) <: AbstractFloat 8 end + + Base.promote_rule(::Type{$(T)}, ::Type{Float16}) = Float16 + Base.promote_rule(::Type{Float16}, ::Type{$(T)}) = Float16 + + Base.promote_rule(::Type{$(T)}, ::Type{Float32}) = Float32 + Base.promote_rule(::Type{Float32}, ::Type{$(T)}) = Float32 + + Base.promote_rule(::Type{$(T)}, ::Type{Float64}) = Float64 + Base.promote_rule(::Type{Float64}, ::Type{$(T)}) = Float64 + + Base.promote_rule(::Type{$(T)}, ::Type{<:Integer}) = $(T) + Base.promote_rule(::Type{<:Integer}, ::Type{$(T)}) = $(T) + + @static if isdefined(Core, :BFloat16) + Base.promote_rule(::Type{$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16 + Base.promote_rule(::Type{Core.BFloat16}, ::Type{$(T)}) = Core.BFloat16 + end + + # For type conversion we simply rely on XLA + (::Type{inT})(x::$(T)) where {inT<:Number} = convert(inT, x) + (::Type{$(T)})(x::inT) where {inT<:Number} = convert($(T), x) + + function Base.convert(::Type{inT}, x::$(T)) where {inT<:Number} + @assert MLIR.IR._has_context() "currently only supported inside compiled functions" + x isa TracedRNumber || (x = Ops.constant(x)) + return Ops.convert(TracedRNumber{inT}, x) + end + + function Base.convert(::Type{$(T)}, x::inT) where {inT<:Number} + @assert MLIR.IR._has_context() "currently only supported inside compiled functions" + x isa TracedRNumber || (x = Ops.constant(x)) + return Ops.convert(TracedRNumber{unwrapped_eltype($(T))}, x) + end + end +end + +const ReactantFloat8 = Union{F8E5M2,F8E4M3FN,F8E4M3B11FNUZ,F8E5M2FNUZ,F8E4M3FNUZ} + +# TODO: Quantized types + +@static if isdefined(Core, :BFloat16) + const ReactantFloat = Union{ + Float16,Core.BFloat16,Float32,Float64,Base.uniontypes(ReactantFloat8)... + } +else + const ReactantFloat = Union{Float16,Float32,Float64,Base.uniontypes(ReactantFloat8)...} +end + +const ReactantComplexFloat = Union{Complex{Float32},Complex{Float64}} + +const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64} + +const ReactantComplexInt = Union{[Complex{T} for T in Base.uniontypes(ReactantInt)]...} + +const ReactantFloatInt = Union{ + Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... +} + +const ReactantPrimitive = Union{ + Bool, + Base.uniontypes(ReactantFloatInt)..., + Base.uniontypes(ReactantComplexInt)..., + Base.uniontypes(ReactantComplexFloat)..., +} + +@inline to_reactant_primitive(v::T) where {T} = reinterpret(reactant_primitive(T), v) +@inline reactant_primitive(::Type{T}) where {T} = nothing + +for T in Base.uniontypes(ReactantPrimitive) + @eval @inline to_reactant_primitive(val::$(T)) = val + @eval @inline reactant_primitive(::Type{$(T)}) = $(T) +end diff --git a/src/Profiler.jl b/src/Profiler.jl new file mode 100644 index 0000000000..6120bae1fb --- /dev/null +++ b/src/Profiler.jl @@ -0,0 +1,189 @@ +module Profiler + +import ..Reactant +using Sockets: Sockets + +""" + with_profiler(f, trace_output_dir::String; trace_device=true, trace_host=true, create_perfetto_link=false) + +Runs the provided function under a profiler for XLA (similar to [JAX's profiler](https://jax.readthedocs.io/en/latest/profiling.html)). +The traces will be exported in the provided folder and can be seen +using tools like [perfetto.dev](https://ui.perfetto.dev). It will return the return values +from the function. The `create_perfetto_link` parameter can be used +to automatically generate a perfetto url to visualize the trace. + +```julia +with_profiler("./traces/") do + compiled_func = @compile myfunc(x, y, z) + compiled_func(x, y, z) +end +``` +""" +function with_profiler( + f, + trace_output_dir::String; + trace_device=true, + trace_host=true, + create_perfetto_link=false, +) + device_tracer_level = UInt32(trace_device ? 1 : 0) + host_tracer_level = UInt32(trace_host ? 2 : 0) + profiler = @ccall Reactant.MLIR.API.mlir_c.CreateProfilerSession( + device_tracer_level::UInt32, host_tracer_level::UInt32 + )::Ptr{Cvoid} + + results = try + f() + finally + @ccall Reactant.MLIR.API.mlir_c.ProfilerSessionCollectData( + profiler::Ptr{Cvoid}, trace_output_dir::Cstring + )::Cvoid + @ccall Reactant.MLIR.API.mlir_c.ProfilerSessionDelete(profiler::Ptr{Cvoid})::Cvoid + end + + if create_perfetto_link + traces_path = joinpath(trace_output_dir, "plugins", "profile") + date = maximum(readdir(traces_path)) + traces_path = joinpath(traces_path, date) + + filename = first(f for f in readdir(traces_path) if endswith(f, ".trace.json.gz")) + serve_to_perfetto(joinpath(traces_path, filename)) + end + + return results +end + +# https://github.com/google/tsl/blob/ffeadbc9111309a845ab07df3ff41d59cb005afb/tsl/profiler/lib/traceme.h#L49-L53 +const TRACE_ME_LEVEL_CRITICAL = Cint(1) +const TRACE_ME_LEVEL_INFO = Cint(2) +const TRACE_ME_LEVEL_VERBOSE = Cint(3) + +""" + annotate(f, name, [level=TRACE_ME_LEVEL_CRITICAL]) + +Generate an annotation in the current trace. +""" +function annotate(f, name, level=TRACE_ME_LEVEL_CRITICAL) + id = @ccall Reactant.MLIR.API.mlir_c.ProfilerActivityStart( + name::Cstring, level::Cint + )::Int64 + try + f() + finally + @ccall Reactant.MLIR.API.mlir_c.ProfilerActivityEnd(id::Int64)::Cvoid + end +end + +""" + @annotate [name] function foo(a, b, c) + ... + end + +The created function will generate an annotation in the captured XLA profiles. +""" +macro annotate(name, func_def=nothing) + noname = isnothing(func_def) + func_def = something(func_def, name) + + if !Meta.isexpr(func_def, :function) + error("not a function definition: $func_def") + end + + name = noname ? string(func_def.args[1].args[1]) : name + code = func_def.args[2] + + code = quote + annotate(() -> $(esc(code)), $(esc(name))) + end + + return Expr(:function, esc(func_def.args[1]), code) +end + +export with_profiler, annotate, @annotate + +function serve_to_perfetto(path_to_trace_file) + port_hint = 9001 + port, server = Sockets.listenany(port_hint) + + try + url = "https://ui.perfetto.dev/#!/?url=http://127.0.0.1:$(port)/$(basename(path_to_trace_file))" + @info "Open $url" + # open_in_default_browser(url) + + while true + isopen(server) || break + + io = Sockets.accept(server) + @debug "Got connection" + msg = String(readuntil(io, UInt8['\r', '\n', '\r', '\n'])) + @debug "Got request" msg + if startswith(msg, "OPTIONS") + isopen(io) || continue + write( + io, + """ + HTTP/1.1 501 + Server: Reactant.jl + Access-Control-Allow-Origin: * + Content-Length: 0 + + """, + ) + close(io) + continue + end + if startswith(msg, "POST") + isopen(io) || continue + write( + io, + """ + HTTP/1.1 404 + Server: Reactant.jl + Access-Control-Allow-Origin: * + Content-Length: 0 + + """, + ) + close(io) + continue + end + + file = read(path_to_trace_file) + file_size = length(file) + + isopen(io) || continue + write( + io, + """ + HTTP/1.1 200 + Server: Reactant.jl + Access-Control-Allow-Origin: * + Content-Length: $(file_size) + Content-Type: application/gzip + + """, + ) + + write(io, file) + break + end + finally + isopen(server) && close(server) + end +end + +@inline function free_profiler(exec) + @ccall MLIR.API.mlir_c.ProfilerServerStop(exec.exec::Ptr{Cvoid})::Cvoid +end + +mutable struct ProfileServer + exec::Ptr{Cvoid} + + function ProfileServer(port) + exec = @ccall Reactant.MLIR.API.mlir_c.ProfilerServerStart(port::Int32)::Ptr{Cvoid} + @assert exec != C_NULL + return finalizer(free_profiler, new(exec)) + end +end + +end # module Profiler diff --git a/src/Reactant.jl b/src/Reactant.jl index bea0150744..2b54ccf649 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -1,9 +1,10 @@ module Reactant -using ReactantCore: ReactantCore, @trace, MissingTracedValue +using ReactantCore: ReactantCore, @trace, within_compile, MissingTracedValue using LinearAlgebra: LinearAlgebra using Random: Random, AbstractRNG +using Functors: @leaf using Adapt: Adapt, WrappedArray using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)` @@ -17,47 +18,7 @@ using Enzyme struct ReactantABI <: Enzyme.EnzymeCore.ABI end -@static if isdefined(Core, :BFloat16) - const ReactantPrimitive = Union{ - Bool, - Int8, - UInt8, - Int16, - UInt16, - Int32, - UInt32, - Int64, - UInt64, - Float16, - Core.BFloat16, - Float32, - Float64, - Complex{Float32}, - Complex{Float64}, - } -else - const ReactantPrimitive = Union{ - Bool, - Int8, - UInt8, - Int16, - UInt16, - Int32, - UInt32, - Int64, - UInt64, - Float16, - Float32, - Float64, - Complex{Float32}, - Complex{Float64}, - } -end - -abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end -abstract type RNumber{T<:ReactantPrimitive} <: Number end - -Base.collect(A::RArray) = copy(A) +include("PrimitiveTypes.jl") function ancestor(x::AbstractArray) p_x = parent(x) @@ -65,55 +26,75 @@ function ancestor(x::AbstractArray) return ancestor(p_x) end +function ancestor(T::Type{<:AbstractArray}) + if applicable(Adapt.parent_type, T) + p_T = Adapt.parent_type(T) + p_T == T && return T + return ancestor(p_T) + end + @warn "`Adapt.parent_type` is not implemented for $(T). Assuming $T isn't a wrapped \ + array." maxlog = 1 + return T +end + include("mlir/MLIR.jl") -include("XLA.jl") +include("xla/XLA.jl") +include("Sharding.jl") +include("Devices.jl") include("Interpreter.jl") +include("Profiler.jl") +include("Types.jl") +include("Distributed.jl") + +const with_profiler = Profiler.with_profiler + +export Sharding include("utils.jl") -mutable struct TracedRArray{T,N} <: RArray{T,N} - paths::Tuple - mlir_data::Union{Nothing,MLIR.IR.Value} - shape::NTuple{N,Int} - - function TracedRArray{T,N}( - paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape - ) where {T,N} - shape = Tuple(shape) - if !isnothing(mlir_data) - @assert size(MLIR.IR.type(mlir_data)) == shape - end - return new{T,N}(paths, mlir_data, shape) +function TracedRArray{T}(data::MLIR.IR.Value) where {T} + data_type = MLIR.IR.type(data) + if T == eltype(MLIR.IR.julia_type(data_type)) + return TracedRArray{T,ndims(data_type)}((), data, size(data_type)) end + tdata = TracedRArray(data) + return Ops.convert(TracedRArray{T,ndims(data_type)}, tdata) end -const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} -const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} -const AnyTracedRVector{T} = AnyTracedRArray{T,1} -const AnyTracedRMatrix{T} = Union{ - AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}} -} -const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} - function TracedRArray(data::MLIR.IR.Value) - data_type = MLIR.IR.type(data) - return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( - (), data, size(data_type) - ) + return TracedRArray{eltype(MLIR.IR.julia_type(MLIR.IR.type(data)))}(data) end -mutable struct TracedRNumber{T} <: RNumber{T} - paths::Tuple - mlir_data::Union{Nothing,MLIR.IR.Value} +unwrapped_eltype(::Type{T}) where {T<:Number} = T +unwrapped_eltype(::Type{<:RNumber{T}}) where {T} = T +unwrapped_eltype(::Type{TracedRNumber{T}}) where {T} = T - function TracedRNumber{T}( - paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} - ) where {T} - if !isnothing(mlir_data) - @assert size(MLIR.IR.type(mlir_data)) == () +unwrapped_eltype(::T) where {T<:Number} = T +unwrapped_eltype(::RNumber{T}) where {T} = T +unwrapped_eltype(::TracedRNumber{T}) where {T} = T + +unwrapped_eltype(::Type{<:RArray{T,N}}) where {T,N} = T +unwrapped_eltype(::Type{<:AbstractArray{T,N}}) where {T,N} = unwrapped_eltype(T) +unwrapped_eltype(::Type{<:AnyTracedRArray{T,N}}) where {T,N} = T + +unwrapped_eltype(::RArray{T,N}) where {T,N} = T +unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T) +unwrapped_eltype(::AnyTracedRArray{T,N}) where {T,N} = T + +aos_to_soa(x::AbstractArray) = x +aos_to_soa(x::AnyTracedRArray) = x +function aos_to_soa(x::AbstractArray{<:ConcretePJRTNumber{T}}) where {T} + x_c = ConcretePJRTArray(zeros(T, size(x))) + x_c .= x + return x_c +end +function aos_to_soa(x::AbstractArray{TracedRNumber{T}}) where {T} + for i in eachindex(x) + if !isassigned(x, i) + x[i] = TracedUtils.promote_to(TracedRNumber{T}, 0) end - return new{T}(paths, mlir_data) end + return Ops.reshape(vcat(x...), size(x)...) end include("Ops.jl") @@ -124,14 +105,27 @@ include("TracedRArray.jl") include("ConcreteRArray.jl") -mutable struct TracedRNG <: Random.AbstractRNG - seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} - const algorithm::String +use_overlayed_version(iter) = any(use_overlayed_version, iter) + +use_overlayed_version(::TracedRArray) = true +use_overlayed_version(::TracedRNumber) = true +use_overlayed_version(::Number) = false +use_overlayed_version(::MissingTracedValue) = true +use_overlayed_version(::TracedRNG) = true + +function use_overlayed_version(x::AbstractArray) + a = ancestor(x) + a === x && return false + return use_overlayed_version(a) end # StdLib Overloads include("stdlibs/LinearAlgebra.jl") include("stdlibs/Random.jl") +include("stdlibs/Base.jl") + +# Other Integrations +include("Enzyme.jl") const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} @@ -143,7 +137,7 @@ include("Overlay.jl") function Enzyme.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) -)::RT where {copy_if_inactive,RT<:RArray} +)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}} if haskey(seen, prev) return seen[prev] end @@ -155,23 +149,87 @@ function Enzyme.make_zero( return res end -using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile -export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace - -const registry = Ref{MLIR.IR.DialectRegistry}() -function __init__() +using .Compiler: + @compile, + @code_hlo, + @code_mhlo, + @jit, + @code_xla, + traced_getfield, + create_result, + compile +export ConcreteRArray, + ConcreteRNumber, + ConcretePJRTArray, + ConcretePJRTNumber, + @compile, + @code_hlo, + @code_mhlo, + @code_xla, + @jit, + @trace, + within_compile + +const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}() + +const passes_initialized = Ref(false) +function initialize_dialect() registry[] = MLIR.IR.DialectRegistry() - @ccall MLIR.API.mlir_c.InitializeRegistryAndPasses( + @ccall MLIR.API.mlir_c.InitializeRegistry( registry[]::MLIR.API.MlirDialectRegistry )::Cvoid + if !passes_initialized[] + @ccall MLIR.API.mlir_c.InitializePasses( + registry[]::MLIR.API.MlirDialectRegistry + )::Cvoid + passes_initialized[] = true + end + return nothing +end + +function deinitialize_dialect() + passes_initialized[] = false + return registry[] = nothing end -function set_default_backend(backend::XLA.Client) - return XLA.default_backend[] = backend +using Libdl +using Reactant_jll +using LLVMOpenMP_jll +function initialize_ptrs() + for name in ( + "__kmpc_barrier", + "__kmpc_global_thread_num", + "__kmpc_for_static_fini", + "__kmpc_for_static_init_8u", + "__kmpc_fork_call", + ) + sym = Libdl.dlsym(LLVMOpenMP_jll.libomp_handle, name) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, sym::Ptr{Cvoid})::Cvoid + end + if (@ccall MLIR.API.mlir_c.ReactantHermeticCudaGetVersion()::UInt32) != 0 + for name in ( + "cuLaunchKernel", + "cuModuleLoadData", + "cuModuleGetFunction", + "cuStreamSynchronize", + ) + sym = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, name) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, sym::Ptr{Cvoid})::Cvoid + end + end end -function set_default_backend(backend::String) - return set_default_backend(XLA.backends[backend]) +function __init__() + initialize_ptrs() + initialize_dialect() + return nothing +end + +function set_default_backend(backend::Union{String,XLA.AbstractClient}) + XLA.set_default_backend(backend) + return nothing end +include("Precompile.jl") + end # module diff --git a/src/Sharding.jl b/src/Sharding.jl new file mode 100644 index 0000000000..0c431c8dc1 --- /dev/null +++ b/src/Sharding.jl @@ -0,0 +1,514 @@ +module Sharding + +using ..Reactant: Reactant, XLA, MLIR + +""" + Mesh(devices::AbstractArray{XLA.AbstractDevice}, axis_names) + +Construct a `Mesh` from an array of devices and a tuple of axis names. The size of the i-th +axis is given by `size(devices, i)`. All the axis names must be unique, and cannot be +nothing. + +## Examples + +Assuming that we have a total of 8 devices, we can construct a mesh with the following: + +```julia-repl +julia> devices = Reactant.devices(); + +julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z)); + +julia> mesh = Mesh(reshape(devices, 4, 2), (:x, :y)); +``` +""" +struct Mesh{D} + device_ids::Array{Int64,D} + logical_device_ids::UnitRange{Int} + axis_names::NTuple{D,Symbol} + + function Mesh(devices::AbstractArray{<:XLA.AbstractDevice}, axis_names) + return Mesh(XLA.device_ordinal.(devices), axis_names) + end + + function Mesh( + device_ids::AbstractArray{<:Integer,D}, axis_names::NTuple{D,Union{String,Symbol}} + ) where {D} + return new{D}(device_ids, 0:(length(device_ids) - 1), Symbol.(axis_names)) + end + + # XXX (Deprecated): remove in v0.3 + function Mesh( + devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names + ) where {D} + Base.depwarn( + "Mesh(devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names) is \ + deprecated, use Mesh(reshape(collect(XLA.device_ordinal.(devices)), shape), \ + axis_names) instead", + :Mesh, + ) + global_ids = reshape(collect(XLA.device_ordinal.(devices)), shape) + return Mesh(global_ids, axis_names) + end + + # XXX (Deprecated): remove in v0.3 + function Mesh( + device_ids::Dims{D1}, shape::Dims{D}, axis_names::NTuple{D,Union{String,Symbol}} + ) where {D,D1} + Base.depwarn( + "Mesh(device_ids::Dims{D1}, shape::Dims{D}, \ + axis_names::NTuple{D,Union{String,Symbol}}) is deprecated, use \ + Mesh(reshape(collect(Int64, device_ids), shape), axis_names) instead", + :Mesh, + ) + return Mesh(reshape(collect(Int64, device_ids), shape), axis_names) + end +end + +Base.length(m::Mesh) = length(m.device_ids) +Base.ndims(::Mesh{D}) where {D} = D + +Base.size(mesh::Mesh) = size(mesh.device_ids) +Base.size(mesh::Mesh, axis::Int) = size(mesh.device_ids, axis) +function Base.size(mesh::Mesh, axis::Union{String,Symbol}) + return size(mesh, findfirst(==(Symbol(axis)), mesh.axis_names)) +end +Base.size(mesh::Mesh, ::Nothing) = 1 + +Base.in(axis::Union{String,Symbol}, mesh::Mesh) = Symbol(axis) ∈ mesh.axis_names + +abstract type AbstractSharding end + +function (T::AbstractSharding)(::XLA.AbstractClient, device, ::Union{AbstractArray,Number}) + return error( + "(::$(T))(::XLA.AbstractClient, device, ::Union{AbstractArray,Number}) is \ + not implemented" + ) +end + +function get_shardy_tensor_sharding_attribute end + +""" + NoSharding() + +Sharding annotation that indicates that the array is not sharded. + +See also: [`Sharding.NamedSharding`](@ref) +""" +struct NoSharding <: AbstractSharding end + +@inline ndevices(::NoSharding) = 1 + +@inline shard_type(::Type{NoSharding}, _) = ShardInfo{NoSharding,Nothing} + +# This allows us to mark entire branches as NoSharding +Base.getproperty(::NoSharding, x) = NoSharding() +Base.getproperty(::NoSharding, x::Symbol) = NoSharding() + +function (::NoSharding)(client::XLA.PJRT.Client, device, x::Union{AbstractArray,Number}) + device === nothing && (device = XLA.default_device(client)) + buffer = XLA.PJRT.AsyncBuffer(client, x, device) + return (buffer,), ShardInfo(NoSharding(), nothing) +end + +""" + NamedSharding( + mesh::Mesh, partition_spec::Tuple; + is_closed::NTuple{N,Bool}=ntuple(Returns(true), length(partition_spec)), + priority::NTuple{N,Int}=ntuple(i -> -1, length(partition_spec)), + ) + +Sharding annotation that indicates that the array is sharded along the given `partition_spec`. For details on the sharding representation see the +[Shardy documentation](https://openxla.org/shardy/sharding_representation). + +## Arguments + + - `mesh`: [`Sharding.Mesh`](@ref) that describes the mesh of the devices. + - `partition_spec`: Must be equal to the ndims of the array being sharded. Each element + can be: + 1. `nothing`: indicating the corresponding dimension is replicated along the axis. + 2. A tuple of axis names indicating the axis names that the corresponding dimension + is sharded along. + 3. A single axis name indicating the axis name that the corresponding dimension is + sharded along. + +## Keyword Arguments + + - `is_closed`: A tuple of booleans indicating whether the corresponding dimension is + closed along the axis. Defaults to `true` for all dimensions. + - `priority`: A tuple of integers indicating the priority of the corresponding dimension. + Defaults to `-1` for all dimensions. A negative priority means that the priority is + not considered by shardy. + +## Examples + +```julia-repl +julia> devices = Reactant.devices(); + +julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z)); + +julia> sharding = NamedSharding(mesh, (:x, :y, nothing)); # 3D Array sharded along x and y on dim 1 and 2 respectively, while dim 3 is replicated + +julia> sharding = NamedSharding(mesh, ((:x, :y), nothing, nothing)); # 3D Array sharded along x and y on dim 1, 2 and 3 are replicated + +julia> sharding = NamedSharding(mesh, (nothing, nothing)); # fully replicated Matrix +``` + +See also: [`Sharding.NoSharding`](@ref) +""" +struct NamedSharding{D1,D2,P<:Tuple} <: AbstractSharding + mesh::Mesh{D1} + partition_spec::P + is_closed::NTuple{D2,Bool} + priority::NTuple{D2,Int} + + function NamedSharding( + mesh::Mesh{D1}, + partition_spec::P; + is_closed::NTuple{D2,Bool}=ntuple(Returns(true), length(partition_spec)), + priority::NTuple{D2,Int}=ntuple(i -> -1, length(partition_spec)), + ) where {D1,P<:Tuple,D2} + axis_names = Symbol[] + pspec = () + for p in partition_spec + if p === nothing + pspec = (pspec..., nothing) + elseif p isa Tuple + @assert all(x -> x isa Symbol || x isa String, p) + sym_names = Symbol.(p) + append!(axis_names, sym_names) + pspec = (pspec..., sym_names) + elseif p isa Symbol || p isa String + push!(axis_names, Symbol(p)) + pspec = (pspec..., Symbol(p)) + else + error("Unexpected partition spec $(partition_spec) [$(p)]") + end + end + @assert allunique(axis_names) "Duplicate axis names!" + + return new{D1,D2,typeof(pspec)}(mesh, pspec, is_closed, priority) + end +end + +@inline ndevices(sharding::NamedSharding) = length(sharding.mesh.device_ids) + +@inline function shard_type(::Type{NamedSharding{D1,D2,P}}, N) where {D1,D2,P} + return shard_type(HloSharding{D1,D2}, N) +end + +function (sharding::NamedSharding)( + client::XLA.PJRT.Client, device::Nothing, x::Union{AbstractArray,Number} +) + @assert length(sharding.partition_spec) == ndims(x) + return HloSharding(sharding, client, device, x) +end + +function get_shardy_tensor_sharding_attribute( + sharding::NamedSharding, ctx, mesh_name, mesh_attr; do_transpose=true +) + dimension_sharding_attrs = Vector{MLIR.API.MlirAttribute}( + undef, length(sharding.partition_spec) + ) + for (j, name) in enumerate(sharding.partition_spec) + if name === nothing + axes = MLIR.IR.Attribute[] + else + names = name isa Symbol ? (name,) : name + axes = [ + MLIR.API.sdyAxisRefAttrGet( + ctx, String(name), MLIR.API.MlirAttribute(C_NULL) + ) for name in names + ] + end + dimension_sharding_attrs[j] = MLIR.API.sdyDimensionShardingAttrGet( + ctx, length(axes), axes, sharding.is_closed[j], sharding.priority[j] + ) + end + + return MLIR.IR.Attribute( + MLIR.API.sdyTensorShardingAttrGet( + ctx, + mesh_name, + length(dimension_sharding_attrs), + do_transpose ? reverse(dimension_sharding_attrs) : dimension_sharding_attrs, + 0, + MLIR.API.MlirAttribute[], + ), + ) +end + +# TODO: Something like NamedDims.jl will allow us to support NamedDimsSharding similar to +# `levanter` + +""" + DimsSharding( + mesh::Mesh{M}, + dims::NTuple{D,Int}, + partition_spec; + is_closed::NTuple{D,Bool}=ntuple(Returns(true), D), + priority::NTuple{D,Int}=ntuple(i -> -1, D), + ) + +Similar to [`NamedSharding`](@ref) but works for a arbitrary dimensional array. Dimensions +not specified in `dims` are replicated. If any dimension in `dims` is greater than the total +number of dimensions in the array, the corresponding `partition_spec`, `is_closed` and +`priority` are ignored. Additionally for any negative dimensions in `dims`, the true +dims are calculated as `ndims(x) - dim + 1`. A dims value of `0` will throw an error. +""" +struct DimsSharding{M,D,P} <: AbstractSharding + mesh::Mesh{M} + dims::NTuple{D,Int} + partition_spec::P + is_closed::NTuple{D,Bool} + priority::NTuple{D,Int} + + function DimsSharding( + mesh::Mesh{M}, + dims::NTuple{D,Int}, + partition_spec; + is_closed::NTuple{D,Bool}=ntuple(Returns(true), length(partition_spec)), + priority::NTuple{D,Int}=ntuple(i -> -1, length(partition_spec)), + ) where {M,D} + @assert length(partition_spec) == length(dims) + # Validity checks on the inputs are deferred to NamedSharding + return new{M,D,typeof(partition_spec)}( + mesh, dims, partition_spec, is_closed, priority + ) + end +end + +@inline ndevices(sharding::DimsSharding) = length(sharding.mesh.device_ids) + +@inline function shard_type(::Type{DimsSharding{M,D,P}}, N) where {M,D,P} + return shard_type(HloSharding{M,N}, N) +end + +function standardize_sharding(sharding::DimsSharding, x::Union{AbstractArray,Number}) + final_dims = map(sharding.dims) do d + @assert !iszero(d) "dims cannot contain 0" + return ifelse(d < 0, ndims(x) + d + 1, d) + end + + dim_indices = ntuple(i -> findfirst(==(i), final_dims), ndims(x)) + partition_spec = ntuple(ndims(x)) do i + dim_index = dim_indices[i] + dim_index === nothing && return nothing # replicated dimension + return sharding.partition_spec[dim_index] + end + is_closed = ntuple(ndims(x)) do i + dim_index = dim_indices[i] + dim_index === nothing && return true # replicated dimension + return sharding.is_closed[dim_index] + end + priority = ntuple(ndims(x)) do i + dim_index = dim_indices[i] + dim_index === nothing && return -1 # replicated dimension + return sharding.priority[dim_index] + end + + return NamedSharding(sharding.mesh, partition_spec; is_closed, priority) +end + +function (sharding::DimsSharding)( + client::XLA.PJRT.Client, device::Nothing, x::Union{AbstractArray,Number} +) + return (standardize_sharding(sharding, x))(client, device, x) +end + +# HloSharding +# This stores the sharding information in the form of XLA.HloSharding, and provides a +# central type for the final storage. It also potentially saves us the pain of not having +# to regenerate the partition spec from the HloSharding. +struct HloSharding{D1,D2} <: AbstractSharding + hlo_sharding::XLA.HloSharding + mesh::Mesh{D1} + is_closed::NTuple{D2,Bool} + priority::NTuple{D2,Int} + + function HloSharding( + hlo_sharding::XLA.HloSharding, mesh::Mesh{D1}, is_closed, priority + ) where {D1} + @assert length(is_closed) == length(priority) + return new{D1,length(is_closed)}(hlo_sharding, mesh, is_closed, priority) + end +end + +@inline ndevices(sharding::HloSharding) = length(sharding.mesh.device_ids) + +@inline function shard_type(::Type{HloSharding{D1,D2}}, N) where {D1,D2} + return ShardInfo{HloSharding{D1,D2},Vector{NTuple{N,UnitRange{Int64}}}} +end + +# This doesn't account for the size of the input so in-presence of padding this will be +# incorrect. Hence always use the HloSharding constructor. +function generate_hlo_sharding_from_tensor_attribute(sharding::NamedSharding) + if MLIR.IR._has_context() + ctx = MLIR.IR.context() + else + ctx = MLIR.IR.Context(Reactant.registry[], false) + @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid + end + + MLIR.IR.context!(ctx) do + mesh_op = Reactant.Ops.mesh( + sharding.mesh; mod=MLIR.IR.Module(MLIR.IR.Location(; context=ctx)) + ) + + tensor_sharding_attr = get_shardy_tensor_sharding_attribute( + sharding, ctx, mesh_op.sym_name, mesh_op.mesh_attr; do_transpose=true + ) + + return HloSharding( + XLA.HloSharding( + @ccall MLIR.API.mlir_c.hloShardingFromTensorShardingAttr( + tensor_sharding_attr.attribute::MLIR.API.MlirAttribute, + mesh_op.mesh_attr.attribute::MLIR.API.MlirAttribute, + )::Ptr{Cvoid} + ), + sharding.mesh, + sharding.is_closed, + sharding.priority, + ) + end +end + +function HloSharding(sharding::NamedSharding, client::XLA.PJRT.Client, _, x) + hlo_sharding = generate_hlo_sharding_from_tensor_attribute(sharding) + + # Check if the input needs to be padded. If so this sharding is not valid and we + # need to request the tensor sharding from XLA + condensed_op_sharding = convert(XLA.CondensedOpSharding, hlo_sharding.hlo_sharding) + device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices( + condensed_op_sharding, size(x), hlo_sharding.mesh.logical_device_ids + ) + + if needs_padding + # Compile a dummy function to get the tensor sharding + tmp = if x isa Number + Reactant.ConcretePJRTNumber(zero(eltype(x))) + else + Reactant.ConcretePJRTArray(ones(eltype(x), size(x)...)) + end + _, exec, _, _, _ = Reactant.Compiler.compile_xla( + Reactant.Ops.negate, (tmp,); input_shardings=IdDict(tmp => sharding) + ) + xla_hlo_sharding = convert( + Reactant.XLA.HloSharding, only(Reactant.XLA.get_parameter_shardings(exec)) + ) + hlo_sharding = HloSharding( + xla_hlo_sharding, + hlo_sharding.mesh, + hlo_sharding.is_closed, + hlo_sharding.priority, + ) + + condensed_op_sharding = convert(XLA.CondensedOpSharding, hlo_sharding.hlo_sharding) + device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices( + condensed_op_sharding, size(x), hlo_sharding.mesh.logical_device_ids + ) + end + + data = ntuple(length(hlo_sharding.mesh)) do i + XLA.PJRT.AsyncBuffer( + client, + x[device_to_array_slices[i]...], + XLA.get_device(client, hlo_sharding.mesh.device_ids[i]), + ) + end + + return data, ShardInfo(hlo_sharding, device_to_array_slices) +end + +function (sharding::HloSharding)( + client::XLA.PJRT.Client, ::Nothing, x::Union{AbstractArray,Number} +) + condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding) + + device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices( + condensed_op_sharding, size(x), sharding.mesh.logical_device_ids + ) + @assert !needs_padding "This shouldn't happen. Open an issue on Reactant.jl" + + data = ntuple(length(sharding.mesh)) do i + XLA.PJRT.AsyncBuffer( + client, + x[device_to_array_slices[i]...], + XLA.get_device(client, sharding.mesh.device_ids[i]), + ) + end + + return data, ShardInfo(sharding, device_to_array_slices) +end + +function get_shardy_tensor_sharding_attribute( + sharding::HloSharding, ctx, mesh_name, mesh_attr; kwargs... +) + string_mesh_name = MLIR.IR.Attribute(MLIR.IR.flatsymbol(mesh_name); context=ctx) + GC.@preserve sharding begin + return MLIR.IR.Attribute( + @ccall MLIR.API.mlir_c.hloShardingToTensorShardingAttr( + ctx::MLIR.API.MlirContext, + sharding.hlo_sharding.ptr::Ptr{Cvoid}, + string_mesh_name.attribute::MLIR.API.MlirAttribute, + mesh_attr.attribute::MLIR.API.MlirAttribute, + Int64(length(sharding.is_closed))::Int64, + Bool[sharding.is_closed...]::Ptr{Bool}, + Int64[sharding.priority...]::Ptr{Int64}, + )::MLIR.API.MlirAttribute + ) + end +end + +# Given Sharding + Array --> ShardInfo +# This is the structure that is stored in the `sharding` field of `ConcreteRArray` +struct ShardInfo{S,D} <: AbstractSharding + sharding::S + device_to_array_slices::D +end + +@inline ndevices(sharding::ShardInfo) = length(sharding.mesh) + +@inline shard_type(::Type{ShardInfo{S,D}}, N) where {S,D} = shard_type(S, N) + +function Base.getproperty(sharding::ShardInfo, name::Symbol) + name ∈ (:sharding, :device_to_array_slices) && return getfield(sharding, name) + return getproperty(sharding.sharding, name) +end + +function get_shardy_tensor_sharding_attribute(sharding::ShardInfo, args...; kwargs...) + return get_shardy_tensor_sharding_attribute(sharding.sharding, args...; kwargs...) +end + +function (sharding::ShardInfo)( + client::XLA.AbstractClient, device, x::Union{AbstractArray,Number} +) + return (sharding.sharding)(client, device, x) +end + +const NoShardInfo = ShardInfo{NoSharding,Nothing} + +ShardInfo{NoSharding,Nothing}() = ShardInfo(NoSharding(), nothing) + +""" + is_sharded(sharding) + is_sharded(x::AbstractArray) + +Checks whether the given sharding refers to no sharding. +""" +is_sharded(::NoSharding) = false +is_sharded(::NamedSharding) = true +is_sharded(::DimsSharding) = true +is_sharded(::HloSharding) = true +is_sharded(s::ShardInfo) = is_sharded(s.sharding) + +function is_sharded(x::AbstractArray) + ancestor_x = Reactant.ancestor(x) + hasfield(typeof(ancestor_x), :sharding) && return is_sharded(ancestor_x.sharding) + return false +end +function is_sharded(x::Number) + hasfield(typeof(x), :sharding) && return is_sharded(x.sharding) + return false +end + +end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 90135e3206..43d0421418 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1,23 +1,47 @@ module TracedRArrayOverrides +using Adapt: WrappedReshapedArray, WrappedArray using Base.Broadcast using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate -import ..TracedRArray -import ..TracedRNumber -import ..ReactantPrimitive -import ..WrappedTracedRArray -import ..AnyTracedRArray -using ..TracedUtils -import ..Ops -import ..MLIR -import ..ancestor +using ..Reactant: + Reactant, + TracedRArray, + TracedRNumber, + WrappedTracedRArray, + AnyTracedRArray, + AnyTracedRVector, + Ops, + MLIR, + ancestor, + allowscalar, + aos_to_soa, + unwrapped_eltype +using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array + using ReactantCore: ReactantCore -import ..TracedUtils: materialize_traced_array -using GPUArraysCore: GPUArraysCore +using GPUArraysCore: GPUArraysCore, @allowscalar ReactantCore.is_traced(::TracedRArray) = true +Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...) + +Base.IndexStyle(::Type{<:TracedRArray}) = Base.IndexLinear() + +# This is required otherwise we will copy a tracedrarray each time +# we use it +function Base.convert(::Type{TracedRArray}, x::TracedRArray) + return x +end + +function Base.convert(::Type{TracedRArray}, x::AnyTracedRArray) + return Base.convert(TracedRArray{unwrapped_eltype(x),ndims(x)}, x) +end + +function Base.convert(::Type{TracedRArray}, x::AbstractArray) + return Base.convert(TracedRArray{eltype(x),ndims(x)}, x) +end + function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N} @assert ndims(x) == N if x isa TracedRArray @@ -26,6 +50,9 @@ function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N} end x isa WrappedTracedRArray && return convert(TracedRArray{T,N}, materialize_traced_array(x)) + if eltype(x) <: TracedRNumber + return convert(TracedRArray{T,N}, aos_to_soa(x)) + end return convert(TracedRArray{T,N}, Ops.constant(collect(x))) end @@ -53,25 +80,132 @@ function Base.getindex( return TracedRNumber{T}((), res2) end -function Base.getindex(a::TracedRArray{T,0}) where {T} +Base.getindex(a::TracedRArray{T,0}) where {T} = TracedRNumber{T}((), a.mlir_data) +function Base.getindex(a::TracedRArray{T,0}, ::CartesianIndex{0}) where {T} return TracedRNumber{T}((), a.mlir_data) end -# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually +function generate_index_list(i1, is...) + list = reshape(i1, :, 1) .- 1 + for i in is + i = TracedUtils.broadcast_to_size(i, (length(i), 1)) + lorig = size(list, 1) + list = repeat(list, size(i, 1), 1) + i = repeat(i; inner=(lorig, 1)) .- 1 + list = hcat(list, i) + end + return list +end + +function scalar_index_to_cartesian(idx::AbstractVector{T}, sz::NTuple{N,Int}) where {T,N} + idx = idx .- 1 + idxs = materialize_traced_array(reshape(idx .% T(sz[1]), :, 1)) + idx = idx .÷ T(sz[1]) + for i in 2:N + idxs = hcat(idxs, idx .% T(sz[i])) + idx = idx .÷ T(sz[i]) + end + return idxs +end + +function scalar_index_to_cartesian(idx::T, sz::NTuple{N,Int}) where {T<:Number,N} + idx = idx - 1 + idxs = (idx % T(sz[1]),) + idx = idx ÷ T(sz[1]) + for i in 2:N + idxs = (idxs..., idx % T(sz[i])) + idx = idx ÷ T(sz[i]) + end + return idxs +end + +function Base.getindex( + a::TracedRArray{T,N}, indices::Union{Int,TracedRNumber{Int}} +) where {T,N} + if indices isa Int + indices = TracedUtils.promote_to(TracedRNumber{Int}, indices) + end + indices = TracedUtils.broadcast_to_size(indices, (1,)) + return Ops.gather_getindex(a, scalar_index_to_cartesian(indices, size(a)))[1] +end + +function Base.getindex(a::TracedRArray{T,N}, indices) where {T,N} + if !(indices isa TracedRArray) + indices = collect(indices) + eltype(indices) <: CartesianIndex && (indices = LinearIndices(size(a))[indices]) + indices = TracedUtils.promote_to(TracedRArray{Int,ndims(indices)}, indices) + end + return materialize_traced_array( + reshape( + Ops.gather_getindex(a, scalar_index_to_cartesian(vec(indices), size(a))), + size(indices), + ), + ) +end + +Base.getindex(a::TracedRArray{T,N}, ::Colon) where {T,N} = materialize_traced_array(vec(a)) + +function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where {T,N} + indices = + materialize_traced_array( + reshape( + TracedUtils.promote_to( + TracedRArray{Int,1}, collect(Int64, vcat(Tuple(indices)...)) + ), + 1, + N, + ), + ) .- 1 + return Ops.gather_getindex(a, indices)[1] +end + +# Needed to prevent method ambiguity +function Base.getindex(a::TracedRArray{T,1}, indices::CartesianIndex{1}) where {T} + indices = + materialize_traced_array( + reshape( + TracedUtils.promote_to( + TracedRArray{Int,1}, collect(Int64, vcat(Tuple(indices)...)) + ), + 1, + 1, + ), + ) .- 1 + return Ops.gather_getindex(a, indices)[1] +end + function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} - indices = map(enumerate(indices)) do (idx, i) - i isa Colon && return 1:size(a, idx) - i isa CartesianIndex && return Tuple(i) - return i + indices = TracedUtils.normalize_indices(a, indices...) + + use_gather_getindex = false + for idxs in indices + idxs isa Number && continue + if idxs isa Reactant.TracedType + use_gather_getindex = true + break + end + contiguous = all(isone, diff(vec(idxs))) + if typeof(contiguous) <: Bool && !contiguous + use_gather_getindex = true + break + end end - foreach(indices) do idxs - idxs isa Number && return nothing - contiguous = all(isone, diff(idxs)) - # XXX: We want to throw error even for dynamic indexing - if typeof(a) <: Bool - contiguous || error("non-contiguous indexing is not supported") + if use_gather_getindex + # TODO: This will create a dynamically sized tensor and we need to implement + # `findall` for it. + if any(i -> unwrapped_eltype(i) <: Bool, indices) + error("Boolean indexing with TracedRArrays isn't fully supported yet.") end + indices, integer_indices, result_size, preddim_result_size, _ = TracedUtils.traced_indices( + indices... + ) + res = Ops.reshape( + Ops.gather_getindex(a, generate_index_list(indices...)), preddim_result_size + ) + isempty(integer_indices) || + (res = materialize_traced_array(dropdims(res; dims=integer_indices))) + return Ops.reshape(res, result_size) end start_indices = map(indices) do i @@ -83,8 +217,10 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} ) x = TracedRArray{T,N}((), res, Tuple(length.(indices))) - ddims = findall(Base.Fix2(isa, Integer), indices) - isempty(ddims) || return dropdims(x; dims=Tuple(ddims)) + ddims = findall(indices) do idx + return idx isa Integer || idx isa TracedRNumber{<:Integer} + end + isempty(ddims) || return materialize_traced_array(dropdims(x; dims=Tuple(ddims))) return x end @@ -97,16 +233,161 @@ function Base.getindex(a::WrappedTracedRArray, indices...) return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...) end +## Specialize certain dispatches for better codegen +for aType in ( + WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}} where {T,N,M}, + PermutedDimsArray{ + TracedRNumber{T},N,perm,iperm,TracedRArray{T,N} + } where {T,N,perm,iperm}, +) + @eval begin + function Base.getindex(a::$aType, indices::Union{Int,TracedRNumber{Int}}...) + return getindex(materialize_traced_array(a), indices...) + end + + function Base.getindex(a::$aType, indices...) + return getindex(materialize_traced_array(a), indices...) + end + end +end + +function maybe_assert_scalar_setindexing( + ::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N} +) where {T,N} + GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})") + return nothing +end + +maybe_assert_scalar_setindexing(args...) = nothing + function Base.setindex!( - a::TracedRArray{T,N}, - v, - indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N}, + a::TracedRArray{T,N}, v, indices::Union{Int,TracedRNumber{Int}} ) where {T,N} - indices = map(enumerate(indices)) do (idx, i) - i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i) + GPUArraysCore.assertscalar( + "setindex!(::TracedRArray, v, ::Union{Int, TracedRNumber{Int}})" + ) + if indices isa Int + indices = TracedUtils.promote_to(TracedRNumber{Int}, indices) + end + indices = scalar_index_to_cartesian( + TracedUtils.broadcast_to_size(indices, (1,)), size(a) + ) + v = v isa Number ? v : vec(v) + res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,))) + set_mlir_data!(a, get_mlir_data(res)) + return a +end + +# Avoid ambiguity +function Base.setindex!( + a::TracedRArray{T,1}, v, indices::Union{Int,TracedRNumber{Int}} +) where {T} + GPUArraysCore.assertscalar( + "setindex!(::TracedRArray, v, ::Union{Int, TracedRNumber{Int}})" + ) + if indices isa Int + indices = TracedUtils.promote_to(TracedRNumber{Int}, indices) + end + indices = scalar_index_to_cartesian( + TracedUtils.broadcast_to_size(indices, (1,)), size(a) + ) + v = v isa Number ? v : vec(v) + res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,))) + set_mlir_data!(a, get_mlir_data(res)) + return a +end + +function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N} + if !(indices isa TracedRArray) + indices = collect(indices) + eltype(indices) <: CartesianIndex && (indices = LinearIndices(size(a))[indices]) + indices = TracedUtils.promote_to(TracedRArray{Int,ndims(indices)}, indices) end - v = TracedUtils.broadcast_to_size(v, length.(indices)) - v = TracedUtils.promote_to(TracedRArray{T,N}, v) + res = Ops.scatter_setindex( + a, + scalar_index_to_cartesian(vec(indices), size(a)), + materialize_traced_array(vec(v)), + ) + set_mlir_data!(a, get_mlir_data(res)) + return a +end + +function Base.setindex!(a::TracedRArray{T,N}, v, indices::CartesianIndex{N}) where {T,N} + GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::CartesianIndex{N})") + indices = + materialize_traced_array( + reshape( + TracedUtils.promote_to( + TracedRArray{Int,1}, collect(Int64, vcat(Tuple(indices)...)) + ), + 1, + N, + ), + ) .- 1 + v = v isa Number ? v : vec(v) + res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,))) + set_mlir_data!(a, get_mlir_data(res)) + return a +end + +function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N} + if (N == 1) && (indices isa Colon) + # Remove ambiguity from the previous + # ```julia + # Base.setindex!(a::TracedRArray{T,N}, v, ::Colon) where {T,N} + # ``` + # signature, which would be confused with this one for N=1. + v = TracedUtils.broadcast_to_size(v, size(a)) + set_mlir_data!(a, get_mlir_data(v)) + return a + end + maybe_assert_scalar_setindexing(a, indices...) + + indices = TracedUtils.normalize_indices(a, indices...) + + use_scatter_setindex = false + for idxs in indices + idxs isa Number && continue + if idxs isa Reactant.TracedType + use_scatter_setindex = true + break + end + contiguous = all(isone, diff(idxs)) + if typeof(contiguous) <: Bool && !contiguous + use_scatter_setindex = true + break + end + end + + if use_scatter_setindex + # TODO: This will create a dynamically sized tensor and we need to implement + # `findall` for it. + if any(i -> unwrapped_eltype(i) <: Bool, indices) + error("Boolean indexing with TracedRArrays isn't fully supported yet.") + end + indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices) + indices_list = generate_index_list(indices_list...) + res = Ops.scatter_setindex(a, indices_list, Ops.reshape(v, length(v))) + set_mlir_data!(a, get_mlir_data(res)) + return v + end + + if v isa Number + v = TracedUtils.broadcast_to_size(v, length.(indices)) + v = TracedUtils.promote_to(TracedRArray{T,N}, v) + else + v = TracedUtils.promote_to(TracedRArray{T,ndims(v)}, v) + non_integer_indices = [!(idx isa Integer) for idx in indices] + broadcast_dims = findall(non_integer_indices) + if length(broadcast_dims) == N + v = TracedUtils.broadcast_to_size(v, length.(indices)) + else + v = Ops.broadcast_in_dim( + materialize_traced_array(v), broadcast_dims, Int64.(length.(indices)) + ) + end + end + indices = [ ( TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1 @@ -118,29 +399,20 @@ function Base.setindex!( ), 1, ) - a.mlir_data = res + set_mlir_data!(a, res) return v end -function Base.setindex!( - a::AnyTracedRArray{T,N}, - v, - indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N}, -) where {T,N} - ancestor_indices = TracedUtils.get_ancestor_indices(a, indices...) - setindex!(ancestor(a), v, ancestor_indices...) - return a -end - Base.Tuple(x::TracedRArray) = ntuple(Base.Fix1(Base.getindex, x), length(x)) Base.size(x::TracedRArray) = x.shape +Base.collect(x::TracedRArray) = copy(x) # XXX: Is this correct? + Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A)) -# TODO is there a way to create an unitialized `tensor`? does it show an advantage? maybe `fill`? function Base.similar(::TracedRArray, ::Type{T}, dims::Dims{N}) where {T,N} - return Ops.constant(zeros(T, dims)) + return Ops.fill(zero(unwrapped_eltype(T)), dims) end function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOContext}} @@ -173,6 +445,11 @@ TracedUtils.promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArra function TracedUtils.promote_to(::TracedRArray{T,N}, rhs) where {T,N} return TracedUtils.promote_to(TracedRArray{T,N}, rhs) end +function TracedUtils.promote_to( + ::Type{TracedRArray{T,0}}, rhs::TracedRNumber{T2} +) where {T,T2} + return TracedRArray{T,0}((), Ops.convert(TracedRNumber{T}, rhs).mlir_data, ()) +end for (jlop, hloop, hlocomp, merge) in ((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any)) @@ -207,8 +484,11 @@ function Base.mapreduce( else init = Base.reduce_empty(Base.BottomRF(op), op_in_T) end - else - init = init::T + + if typeof(init) != op_in_T + op_in_T = typeof(init) + A = typeof(init).(A) + end end init = [TracedUtils.broadcast_to_size(init, ()).mlir_data] @@ -235,8 +515,8 @@ function Base.mapreduce( fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()]) args = ( - TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, 1)), - TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, 2)), + TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)), + TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)), ) resty = MLIR.IR.block!(fnbody) do @@ -287,24 +567,28 @@ function Base.mapreducedim!( @nospecialize(R::TracedRArray), A::Base.AbstractArrayOrBroadcasted, ) - tmp = TracedUtils.broadcast_to_size( - Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...) - ) - R.mlir_data = broadcast(op, R, tmp).mlir_data + @assert length(size(R)) == length(size(A)) + dims = map(enumerate(zip(size(R), size(A)))) do (i, (sR, sA)) + sR == sA && return nothing + @assert sR == 1 + return i + end + tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims)) + set_mlir_data!(R, get_mlir_data(tmp)) return R end -function Base.fill!(A::TracedRArray{T,N}, x) where {T,N} +function Base.fill!(A::AnyTracedRArray{T,N}, x) where {T,N} bcast = TracedUtils.broadcast_to_size(T(x), size(A)) - A.mlir_data = bcast.mlir_data + set_mlir_data!(A, get_mlir_data(bcast)) return A end -function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2} +function Base.fill!(A::AnyTracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2} bcast = TracedUtils.broadcast_to_size( TracedUtils.promote_to(TracedRNumber{T}, x), size(A) ) - A.mlir_data = bcast.mlir_data + set_mlir_data!(A, get_mlir_data(bcast)) return A end @@ -319,16 +603,16 @@ end function Base.similar( ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims -) where {T<:ReactantPrimitive,N} +) where {T<:Reactant.ReactantPrimitive,N} @assert N isa Int return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end function Base.similar( - bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{<:TracedRNumber{T}}, dims -) where {T<:ReactantPrimitive,N} + ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims +) where {T<:Reactant.ReactantPrimitive,N} @assert N isa Int - return TracedRArray{T,N}((), nothing, map(length, dims)) + return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle{0}}) @@ -339,30 +623,24 @@ end Base.eltype(::Broadcast.Extruded{T}) where {T} = eltype(T) +function first_scalar(x) + Reactant.@allowscalar first(x) +end + # we need to override the outer copy method to make sure we never fall back to scalar # iteration (see, e.g., CUDA.jl#145) function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) - ElType = Broadcast.combine_eltypes(bc.f, bc.args) - if ElType == Any - a1 = bc.args[1] - @show a1 - b1 = a1.args[1] - @show b1 - @show typeof(b1) - @show eltype(b1) - @show Broadcast._broadcast_getindex_eltype(a1.args[1]) - @show Broadcast.eltypes(a1.args) - @show Broadcast._broadcast_getindex_eltype(a1) - @show typeof(bc.args) - argT = Broadcast.eltypes(bc.args) - @show argT - RT = Base._return_type(bc.f, argT) - @show RT - T = Base.promote_typejoin_union(RT) - @show T - @show bc.f, bc.args - end - @assert ElType != Any + fn = if bc.f isa Type && bc.f <: Reactant.ReactantPrimitive + TracedUtils.TypeCast{bc.f}() + else + bc.f + end + ElType = Broadcast.combine_eltypes(fn, bc.args) + # Special case a union{} return so we can see the better error message + if ElType === Union{} + fn(map(first_scalar, bc.args)...) + end + @assert ElType != Any && ElType != Union{} sim = similar(bc, ElType) return copyto!(sim, bc) end @@ -380,6 +658,10 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T, return dest end +function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T2,N}) where {T,T2,N} + return copyto!(dest, Ops.convert(TracedRArray{T,N}, src)) +end + function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest @@ -388,11 +670,29 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) - res = TracedUtils.elem_apply(bc.f, args...) + res = TracedUtils.promote_to( + TracedRArray{unwrapped_eltype(dest),ndims(dest)}, + TracedUtils.elem_apply(bc.f, args...), + ) TracedUtils.set_mlir_data!(dest, res.mlir_data) return dest end +function _copyto!(dest::AbstractArray{<:TracedRNumber}, bc::Broadcasted) + axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) + isempty(dest) && return dest + + bc = Broadcast.preprocess(dest, bc) + + args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) + + res = TracedUtils.elem_apply(bc.f, args...) + for I in 1:length(dest) + dest[I] = Reactant.@allowscalar res[I] + end + return dest +end + dispatch_val(x) = x dispatch_val(::Val{D}) where {D} = D @@ -459,7 +759,7 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} catdims = Base.dims2cat(dims) shape = Base.cat_size_shape(catdims, X...) - RT = Base.promote_eltype(T, X...) + RT = unwrapped_eltype(Base.promote_eltype(T, X...)) # convert to the target eltype X = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{RT,length(shape)}), X) @@ -487,19 +787,22 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRN end end -Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x) -Base.any(f::Function, x::AnyTracedRArray) = mapreduce(f, |, x) +Base._all(f, x::AnyTracedRArray, dims) = mapreduce(f, &, x; dims) +Base._all(f, x::AnyTracedRArray, dims::Colon) = mapreduce(f, &, x; dims) +Base._any(f, x::AnyTracedRArray, dims) = mapreduce(f, |, x; dims) +Base._any(f, x::AnyTracedRArray, dims::Colon) = mapreduce(f, |, x; dims) # outer repeat -# Overridden because we don't need to further recur into the definitions here -function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,M} +function Base._RepeatInnerOuter.repeat_outer( + x::AnyTracedRArray{T,N}, counts::NTuple{M,Int} +) where {T,N,M} P = max(N, M) # potentially padded # (d1, d2, ..., dP) -> (d1, 1, d2, 1, ..., dP, 1) interleaved_size = ones(Int, 2P) interleaved_size[1:2:(2N)] .= size(x) - x_interleaved = reshape(x, interleaved_size...) + x_interleaved = reshape(materialize_traced_array(x), interleaved_size...) # (d1, 1, d2, 1, ..., dP, 1) -> (d1, r1, d2, r2, ..., dP, rP) broadcast_target_size = interleaved_size @@ -510,9 +813,274 @@ function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N, # (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP) final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1)) - x_final = reshape(x_broadcasted, final_size...) + return materialize_traced_array(reshape(x_broadcasted, final_size...)) +end + +# inner repeat +function Base._RepeatInnerOuter.repeat_inner( + x::AnyTracedRArray{T,N}, counts::NTuple{M,Int} +) where {T,N,M} + P = max(N, M) # potentially padded + + # (d1, d2, ..., dP) -> (1, d1, 1, d2, 1, ..., 1, dP) + interleaved_size = ones(Int, 2P) + interleaved_size[2:2:(2N)] .= size(x) + + x_interleaved = reshape(materialize_traced_array(x), interleaved_size...) + + # (1, d1, 1, d2, 1, ..., 1, dP) -> (r1, d1, r2, d2, ..., rP, dP) + broadcast_target_size = interleaved_size + broadcast_target_size[1:2:(2N)] .= counts + + x_broadcasted = TracedUtils.broadcast_to_size(x_interleaved, broadcast_target_size) + + # (r1, d1, r2, d2, ..., rP, dP) -> (d1*r1, d2*r2, ..., dP*rP) + final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1)) + + return materialize_traced_array(reshape(x_broadcasted, final_size...)) +end + +# stack +function overloaded_stack(dims::Union{Integer,Colon}, xs) + @assert allequal(ndims.(xs)) "All arrays must have the same number of dimensions..." + dims = dims isa Colon ? ndims(first(xs)) + 1 : dims + res = map(xs) do x + new_shape = ntuple( + i -> i == dims ? 1 : (i < dims ? size(x, i) : size(x, i - 1)), ndims(x) + 1 + ) + return materialize_traced_array(reshape(x, new_shape)) + end + return cat(res...; dims) +end + +# sort +function Base.sort(x::AnyTracedRArray; alg=missing, order=missing, kwargs...) + return sort!(copy(x); alg, order, kwargs...) +end +function Base.sort(x::AnyTracedRVector; alg=missing, order=missing, kwargs...) + return sort!(copy(x); alg, order, dims=1, kwargs...) +end + +function Base.sort!( + x::AnyTracedRArray; + dims::Union{Integer,Nothing}=nothing, + lt=isless, + by=identity, + rev::Bool=false, + alg=missing, + order=missing, +) + if dims === nothing + @assert ndims(x) == 1 + dims = 1 + end + + @assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`" + @assert order === missing "Reactant doesn't support `order` kwarg for `sort!`" + + comparator = rev ? (a, b) -> !lt(by(a), by(b)) : (a, b) -> lt(by(a), by(b)) + res = only(Ops.sort(materialize_traced_array(x); dimension=dims, comparator)) + set_mlir_data!(x, get_mlir_data(res)) + return x +end + +function Base.sortperm(x::AnyTracedRArray; alg=missing, order=missing, kwargs...) + return sortperm!(similar(x, Int), x; alg, order, kwargs...) +end +function Base.sortperm(x::AnyTracedRVector; alg=missing, order=missing, kwargs...) + return sortperm!(similar(x, Int), x; alg, order, dims=1, kwargs...) +end + +function Base.sortperm!( + ix::AnyTracedRArray{Int,N}, + x::AnyTracedRArray{<:Any,N}; + dims::Union{Integer,Nothing}=nothing, + lt=isless, + by=identity, + rev::Bool=false, + alg=missing, + order=missing, +) where {N} + if dims === nothing + @assert ndims(x) == 1 + dims = 1 + end + + @assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`" + @assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`" + + comparator = + rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b)) + idxs = Ops.constant(collect(LinearIndices(x))) + _, res = Ops.sort(materialize_traced_array(x), idxs; dimension=dims, comparator) + set_mlir_data!(ix, get_mlir_data(res)) + return ix +end + +function Base.partialsort(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...) + values, _ = overloaded_partialsort(x, k; kwargs...) + k = k .- minimum(k) .+ 1 + k isa Integer && return @allowscalar(values[k]) + return view(values, k) +end + +function Base.partialsort!(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...) + values, _ = overloaded_partialsort(x, k; kwargs...) + kget = k .- minimum(k) .+ 1 + val = @allowscalar(values[kget]) + @allowscalar setindex!(x, val, k) + k isa Integer && return val + return view(x, k) +end + +function Base.partialsortperm( + x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs... +) + idxs = overloaded_partialsort(x, k; kwargs...)[2] + k = k .- minimum(k) .+ 1 + k isa Integer && return @allowscalar(idxs[k]) + return view(idxs, k) +end + +function Base.partialsortperm!( + ix::AnyTracedRVector{Int}, + x::AnyTracedRVector, + k::Union{Integer,OrdinalRange}; + kwargs..., +) + _, idxs = overloaded_partialsort(x, k; kwargs...) + kget = k .- minimum(k) .+ 1 + val = @allowscalar(idxs[kget]) + @allowscalar setindex!(ix, val, k) + k isa Integer && return val + return view(ix, k) +end + +function overloaded_partialsort( + x::AnyTracedRVector, + k::Union{Integer,OrdinalRange}; + by=identity, + rev::Bool=false, + lt=isless, +) + if lt !== isless || by !== identity + comparator = + rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b)) + idxs = Ops.constant(collect(LinearIndices(x))) + sorted_x, sorted_idxs = Ops.sort( + materialize_traced_array(x), idxs; dimension=1, comparator + ) + return sorted_x[1:maximum(k)], sorted_idxs[1:maximum(k)] + end + + # XXX: If `maxk` is beyond a threshold should we emit a sort directly? + !rev && (k = length(x) .- k .+ 1) + !(k isa Integer) && (k = maximum(k)) + (; values, indices) = Ops.top_k(materialize_traced_array(x), k) + if !rev + values = Ops.reverse(values; dimensions=[1]) + indices = Ops.reverse(indices; dimensions=[1]) + end + return values, indices +end + +# arg* functions +function Base.argmin(f::F, x::AnyTracedRArray) where {F} + idx = scalar_index_to_cartesian(argmin(f.(x)), size(x)) .+ 1 + return @allowscalar x[idx...] +end + +function Base.argmax(f::F, x::AnyTracedRArray) where {F} + idx = scalar_index_to_cartesian(argmax(f.(x)), size(x)) .+ 1 + return @allowscalar x[idx...] +end + +Base.argmin(x::AnyTracedRArray; kwargs...) = findmin(identity, x; kwargs...)[2] +Base.argmax(x::AnyTracedRArray; kwargs...) = findmax(identity, x; kwargs...)[2] + +# find* functions +Base.findfirst(x::AnyTracedRArray) = findfirst(identity, x) +Base.findlast(x::AnyTracedRArray) = findlast(identity, x) + +function Base.findfirst(f::Function, x::AnyTracedRArray) + fA = materialize_traced_array(vec(f.(x))) + (; indices) = Ops.top_k(fA, 1) + return @allowscalar indices[1] +end + +function Base.findlast(f::Function, x::AnyTracedRArray) + fA = Ops.reverse(materialize_traced_array(vec(f.(x))); dimensions=[1]) + (; indices) = Ops.top_k(fA, 1) + return length(x) - @allowscalar(indices[1]) + 1 +end + +Base.findmin(x::AnyTracedRVector) = findmin(identity, x; dims=1) +function Base.findmin(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + return findmin(identity, x; dims) +end + +Base.findmax(x::AnyTracedRVector) = findmax(identity, x; dims=1) +function Base.findmax(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + return findmax(identity, x; dims) +end + +## To avoid scalar indexing and constructing an array of tuples, we return the linear index +## instead of the cartesian index +function Base.findmin(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + if dims === nothing + if ndims(x) == 1 + dims = 1 + else + return findmin(f, vec(x); dims=1) + end + end + + fx = Ops.negate(materialize_traced_array(f.(x))) + (; values, indices) = Ops.top_k(fx, 1; dimension=dims) + + # Compute linear indices + strds = strides(x) + iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)] + iotas[dims] = Ops.subtract(indices, Ops.fill(Int64(1), size(indices))) + linear_indices = Ops.fill(Int64(1), size(indices)) + for d in eachindex(iotas) + linear_indices = Ops.add( + linear_indices, + Ops.multiply(iotas[d], Ops.fill(Int64(strds[d]), size(iotas[d]))), + ) + end + + values = Ops.negate(values) + ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1]) + return (values, linear_indices) +end + +function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + if dims === nothing + if ndims(x) == 1 + dims = 1 + else + return findmax(f, vec(x); dims=1) + end + end + + fx = materialize_traced_array(f.(x)) + (; values, indices) = Ops.top_k(fx, 1; dimension=dims) + + # Compute linear indices + strds = strides(x) + iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)] + iotas[dims] = Ops.subtract(indices, Ops.fill(Int64(1), size(indices))) + linear_indices = Ops.fill(Int64(1), size(indices)) + for d in eachindex(iotas) + linear_indices = Ops.add( + linear_indices, + Ops.multiply(iotas[d], Ops.fill(Int64(strds[d]), size(iotas[d]))), + ) + end - return x_final + ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1]) + return (values, linear_indices) end end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index df664031ea..d798b9b312 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -1,17 +1,11 @@ module TracedRNumberOverrides -import ..TracedRNumber -import ..TracedRArray -import ..ReactantPrimitive -using ..TracedUtils -import ..Ops -import ..MLIR +using ..Reactant: + Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype using ReactantCore ReactantCore.is_traced(::TracedRNumber) = true -Base.eltype(::Type{TracedRNumber{T}}) where {T} = T - Base.getindex(a::TracedRNumber{T}) where {T} = a Base.zero(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, zero(T)) @@ -22,10 +16,22 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T} return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) end -function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T} - return TracedUtils.promote_to(TracedRNumber{T}, T(x)) +function Base.isfinite(x::TracedRNumber{<:Complex}) + return isfinite(real(x)) & isfinite(imag(x)) +end +Base.isfinite(x::TracedRNumber{<:AbstractFloat}) = Ops.is_finite(x) + +function Base.isnan(x::TracedRNumber{<:Complex}) + return isnan(real(x)) | isnan(imag(x)) +end +function Base.isnan(x::TracedRNumber{T}) where {T<:AbstractFloat} + return !isfinite(x) & (x != typemax(T)) & (x != typemin(T)) end +Base.isinf(x::TracedRNumber{<:Complex}) = isinf(real(x)) | isinf(imag(x)) +Base.isinf(x::TracedRNumber{<:AbstractFloat}) = Ops.is_inf(x) +Base.isinf(::TracedRNumber{<:Integer}) = false + function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}} return print(io, "TracedRNumber{", T, "}(", X.paths, ")") end @@ -49,14 +55,18 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} return TracedRNumber{Base.promote_type(T, S)} end -function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T} - return TracedUtils.promote_to(TracedRNumber{T}, x) +function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{S}) where {T,S} + return TracedRNumber{Base.promote_type(T, S)} end +# NOTE: This is inconsistent with the behavior of `convert` but we do it since it is a very +# common usecase TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x - +function TracedRNumber{T}(x::TracedRNumber) where {T} + return TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(T)}, x) +end function TracedRNumber{T}(x::Number) where {T} - return TracedUtils.promote_to(TracedRNumber{T}, x) + return TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(T)}, x) end function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T} @@ -66,11 +76,11 @@ function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T} end if rhs isa TracedRArray{<:Any,0} return TracedUtils.promote_to( - TracedRNumber{T}, TracedRNumber{eltype(rhs)}((), rhs.mlir_data) + TracedRNumber{T}, + TracedRNumber{Reactant.unwrapped_eltype(rhs)}((), rhs.mlir_data), ) end - rhs isa Number && - return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(fill(T(rhs)))) + rhs isa Number && return TracedUtils.promote_to(TracedRNumber{T}, Ops.fill(T(rhs))) return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(collect(rhs))) end @@ -86,6 +96,7 @@ for (jlop, hloop) in ( (:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^), :power), + (:(Base.rem), :remainder), ) @eval function $(jlop)( @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) @@ -94,12 +105,51 @@ for (jlop, hloop) in ( end end +function Base.rem( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number) +) where {T} + return Ops.remainder(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) +end +function Base.rem( + @nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T}) +) where {T} + return Ops.remainder(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs) +end + +# Based on https://github.com/JuliaLang/julia/blob/39255d47db7657950ff1c82137ecec5a70bae622/base/float.jl#L608-L617 +function Base.mod( + @nospecialize(x::Reactant.TracedRNumber{T}), @nospecialize(y::Reactant.TracedRNumber{T}) +) where {T} + r = rem(x, y) + return ifelse(r == 0, copysign(r, y), ifelse((r > 0) ⊻ (y > 0), r + y, r)) +end +function Base.mod( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number) +) where {T} + return mod(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) +end +function Base.mod( + @nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T}) +) where {T} + return mod(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs) +end + +function Base.div(@nospecialize(lhs::TracedRNumber{T}), rhs) where {T<:Integer} + return Ops.divide(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) +end + function Base.div( @nospecialize(lhs::TracedRNumber{T}), rhs, ::typeof(RoundDown) ) where {T<:Integer} return Ops.divide(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) end +function Base.:/( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) +) where {T<:Integer} + return float(lhs) / float(rhs) +end + for (jlop, hloop, hlocomp) in ( (:(Base.:(==)), :compare, "EQ"), (:(Base.:(!=)), :compare, "NE"), @@ -107,6 +157,7 @@ for (jlop, hloop, hlocomp) in ( (:(Base.:(>)), :compare, "GT"), (:(Base.:(<=)), :compare, "LE"), (:(Base.:(<)), :compare, "LT"), + (:(Base.isless), :compare, "LT"), ) @eval begin function $(jlop)( @@ -144,6 +195,14 @@ for (jlop, hloop, hlocomp) in ( end end +function Base.ifelse(@nospecialize(pred::TracedRNumber{Bool}), x::Number, y::Number) + return ifelse( + pred, + TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(x)}, x), + TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(y)}, y), + ) +end + function Base.ifelse( @nospecialize(pred::TracedRNumber{Bool}), @nospecialize(x::TracedRNumber{T1}), @@ -183,6 +242,12 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) TracedUtils.promote_to(TracedRNumber{$(T)}, y), ) end + function Base.xor(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) + return Ops.xor( + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), + ) + end Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x) end end @@ -198,6 +263,7 @@ for (jlop, hloop) in ( (:(Base.:-), :negate), (:(Base.sin), :sine), (:(Base.cos), :cosine), + (:(Base.tan), :tan), (:(Base.tanh), :tanh), (:(Base.FastMath.tanh_fast), :tanh), (:(Base.exp), :exponential), @@ -206,21 +272,46 @@ for (jlop, hloop) in ( (:(Base.log), :log), (:(Base.log1p), :log_plus_one), (:(Base.sqrt), :sqrt), - (:(Base.ceil), :ceil), - (:(Base.floor), :floor), + (:(Base.acos), :acos), + (:(Base.acosh), :acosh), + (:(Base.asin), :asin), + (:(Base.asinh), :asinh), + (:(Base.atan), :atan), + (:(Base.atanh), :atanh), + (:(Base.sign), :sign), ) @eval $(jlop)(@nospecialize(lhs::TracedRNumber)) = Ops.$(hloop)(lhs) end +for (jlop, hloop) in + ((:(Base.sinpi), :sine), (:(Base.cospi), :cosine), (:(Base.tanpi), :tan)) + @eval $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T} = Ops.$(hloop)(T(π) * lhs) +end + +Base.sincospi(x::TracedRNumber{T}) where {T} = Ops.sine(T(π) * x), Ops.cosine(T(π) * x) + Base.conj(x::TracedRNumber) = x Base.conj(x::TracedRNumber{<:Complex}) = Ops.conj(x) Base.real(x::TracedRNumber) = x Base.real(x::TracedRNumber{<:Complex}) = Ops.real(x) +Base.isreal(::TracedRNumber) = false +Base.isreal(::TracedRNumber{<:Real}) = true + Base.imag(x::TracedRNumber) = zero(x) Base.imag(x::TracedRNumber{<:Complex}) = Ops.imag(x) +Base.iseven(x::TracedRNumber) = iseven(real(x)) +function Base.iseven(x::TracedRNumber{<:Real}) + return iszero( + rem( + TracedUtils.promote_to(TracedRNumber{Int}, x), + TracedUtils.promote_to(TracedRNumber{Int}, 2), + ), + ) +end + for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) @eval Base.clamp(x::TracedRNumber, min::$(minT), max::$(maxT)) = Ops.clamp(min, x, max) end @@ -228,11 +319,71 @@ end function Base.fill(x::TracedRNumber, dims::NTuple{N,Integer}) where {N} return TracedUtils.broadcast_to_size(x, dims) end +function Base.fill(x::TracedRNumber, ::Tuple{}) + return TracedUtils.broadcast_to_size(x, ()) +end function Base.float(x::TracedRNumber{T}) where {T} return TracedUtils.promote_to(TracedRNumber{float(T)}, x) end +using Reactant: ReactantFloat + +Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A) +Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A) +Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A) + +function Base.unsafe_trunc( + T::Type{<:Reactant.ReactantInt}, x::TracedRNumber{<:Reactant.ReactantFloat} +) + return Ops.convert(TracedRNumber{T}, x) +end + +for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128) + for Tf in (Float16, Float32, Float64) + if Ti <: Unsigned || sizeof(Ti) < sizeof(Tf) + # Here `Tf(typemin(Ti))-1` is exact, so we can compare the lower-bound + # directly. `Tf(typemax(Ti))+1` is either always exactly representable, or + # rounded to `Inf` (e.g. when `Ti==UInt128 && Tf==Float32`). + @eval begin + function Base.trunc(::Type{$Ti}, x::TracedRNumber{$Tf}) + # TODO throw error within traced + # if $(Tf(typemin(Ti))-one(Tf)) < x < $(Tf(typemax(Ti))+one(Tf)) + return Base.unsafe_trunc($Ti, x) + # else + # throw(Base.InexactError(:trunc, $Ti, x)) + # end + end + end + else + # Here `eps(Tf(typemin(Ti))) > 1`, so the only value which can be truncated to + # `Tf(typemin(Ti)` is itself. Similarly, `Tf(typemax(Ti))` is inexact and will + # be rounded up. This assumes that `Tf(typemin(Ti)) > -Inf`, which is true for + # these types, but not for `Float16` or larger integer types. + @eval begin + function Base.trunc(::Type{$Ti}, x::TracedRNumber{$Tf}) + # TODO throw error within traced + # if $(Tf(typemin(Ti))) <= x < $(Tf(typemax(Ti))) + return Base.unsafe_trunc($Ti, x) + # else + # throw(Base.InexactError(:trunc, $Ti, x)) + # end + end + end + end + end +end + +function Base.round(::Type{T}, x::TracedRNumber{<:AbstractFloat}) where {T<:Integer} + return trunc(T, Base.round(x)) +end +function Base.floor(::Type{T}, x::TracedRNumber{<:AbstractFloat}) where {T<:Integer} + return trunc(T, Base.floor(x)) +end +function Base.ceil(::Type{T}, x::TracedRNumber{<:AbstractFloat}) where {T<:Integer} + return trunc(T, Base.ceil(x)) +end + # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T} @@ -264,4 +415,20 @@ function Base.typed_hvncat( return Base.typed_hvncat(T, dims, row_first, xs...) end +for (Ti, Tf) in ((Int16, Float16), (Int32, Float32), (Int64, Float64)) + @eval begin + Base.signbit(x::TracedRNumber{$(Ti)}) = x < 0 + Base.signbit(x::TracedRNumber{$(Tf)}) = signbit(Ops.bitcast_convert($(Ti), x)) + end end +Base.signbit(::TracedRNumber{<:Unsigned}) = ConcretePJRTNumber(false) + +Base.copysign(x::TracedRNumber, y::TracedRNumber) = ifelse(signbit(y), -1, 1) * abs(x) +function Base.copysign(x::TracedRNumber{T}, y::S) where {T,S<:Number} + return copysign(x, TracedUtils.promote_to(TracedRNumber{S}, y)) +end +function Base.copysign(x::S, y::TracedRNumber{T}) where {S<:Number,T} + return copysign(TracedUtils.promote_to(TracedRNumber{S}, x), y) +end + +end # module TracedRNumberOverrides diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 665c6732d3..b13ebcc7ad 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -3,95 +3,110 @@ # within compilation. However, it means these functions are a _lot_ faster to compile. module TracedUtils -using LinearAlgebra: LinearAlgebra -using Adapt: Adapt +using Adapt: Adapt, WrappedReshapedArray using ..Reactant: - RArray, + Reactant, + MLIR, RNumber, TracedRArray, TracedRNumber, WrappedTracedRArray, AnyTracedRArray, MissingTracedValue, - OrderedIdDict -import ..Reactant -import ..Reactant.MLIR -import ..ReactantPrimitive -import ..Ops + OrderedIdDict, + ReactantPrimitive, + Ops +using ReactantCore: MissingTracedValue, is_traced +using Functors: Functors materialize_traced_array(x::TracedRArray) = x + materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] + function materialize_traced_array( - x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray} -) where {T,N} + x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}} +) where {T,N,M} return Ops.reshape(materialize_traced_array(parent(x)), size(x)...) end + function materialize_traced_array( - x::LinearAlgebra.Transpose{T,TracedRArray{T,N}} -) where {T,N} - px = parent(x) - A = ndims(px) == 1 ? reshape(px, :, 1) : px - return permutedims(A, (2, 1)) -end -function materialize_traced_array(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}) where {T,N} - return conj(materialize_traced_array(transpose(parent(x)))) -end -function materialize_traced_array( - x::PermutedDimsArray{T,N,perm,iperm,<:TracedRArray{T,N}} + x::PermutedDimsArray{TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}} ) where {T,N,perm,iperm} return permutedims(parent(x), perm) end -function materialize_traced_array(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}) where {T} - return LinearAlgebra.diagm(parent(x)) -end get_mlir_data(x::TracedRNumber) = x.mlir_data set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x) +get_paths(x::TracedRNumber) = x.paths +set_paths!(x::TracedRNumber, paths) = (x.paths = paths; return x) get_mlir_data(x::TracedRArray) = x.mlir_data get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) +get_paths(x::TracedRArray) = x.paths +set_paths!(x::TracedRArray, paths) = (x.paths = paths; return x) + +get_paths(x::MissingTracedValue) = x.paths +set_paths!(x::MissingTracedValue, paths) = (x.paths = paths; return x) function set_mlir_data!(x::TracedRArray, data) x.mlir_data = data return x end -function set_mlir_data!(x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray}, data) where {T,N} - res_mlir_data = Ops.reshape(TracedRArray(data), size(parent(x))...).mlir_data + +function set_mlir_data!( + x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, data +) where {T,N,M} + res_mlir_data = Ops.reshape(TracedRArray{T}(data), size(parent(x))...).mlir_data set_mlir_data!(parent(x), res_mlir_data) return x end -function set_mlir_data!(x::LinearAlgebra.Transpose{T,TracedRArray{T,N}}, data) where {T,N} - tdata = TracedRArray(data) - px = parent(x) - px.mlir_data = ( - if ndims(px) == 1 - Ops.reshape(tdata, length(tdata)) - else - Ops.transpose(tdata, [2, 1]) + +function get_ancestor_indices( + x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, indices... +) where {T,N,M} + @assert length(indices) == N "Expected $N indices, got $(length(indices))" + indices = normalize_indices(x, indices...) + if any(is_traced, indices) + indices, integer_indices, result_size, _, flattened_size = traced_indices( + indices... + ) + linear_indices = mapreduce(+, enumerate(indices)) do (i, idx) + bcasted_idxs = Ops.broadcast_in_dim( + idx, ndims(idx) == 0 ? Int64[] : Int64[i], flattened_size + ) + Base.stride(x, i) .* (bcasted_idxs .- 1) end - ).mlir_data - return x -end -function set_mlir_data!(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}, data) where {T,N} - tdata = TracedRArray(data) - px = parent(x) - transposed_data = - ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) - px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data - return x + linear_indices = linear_indices .+ 1 + parent_linear_indices_all = collect(LinearIndices(size(parent(x)))) + parent_linear_indices = promote_to( + TracedRArray{Int64,ndims(parent_linear_indices_all)}, parent_linear_indices_all + )[linear_indices] + isempty(integer_indices) || ( + parent_linear_indices = materialize_traced_array( + dropdims(parent_linear_indices; dims=integer_indices) + ) + ) + parent_linear_indices = Ops.reshape(parent_linear_indices, result_size) + return (parent_linear_indices,) + else + # Have this as a separate code-path since we can generate non-dynamic indexing + cartesian_indices = CartesianIndex.(Iterators.product(indices...)) + linear_indices = LinearIndices(size(x))[cartesian_indices] + parent_linear_indices = LinearIndices(size(parent(x)))[linear_indices] + return (parent_linear_indices,) + end end + function set_mlir_data!( - x::PermutedDimsArray{T,N,perm,iperm,TracedRArray{T,N}}, data + x::PermutedDimsArray{TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}}, data ) where {T,N,perm,iperm} - parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data - return x -end -function set_mlir_data!(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}, data) where {T} - parent(x).mlir_data = LinearAlgebra.diag(TracedRArray(data)).mlir_data + parent(x).mlir_data = permutedims(TracedRArray{T}(data), iperm).mlir_data return x end -function set_mlir_data!(x::AnyTracedRArray, data) - setindex!(x, TracedRArray(data), axes(x)...) + +function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T} + ancestor_indices = get_ancestor_indices(x, axes(x)...) + setindex!(Reactant.ancestor(x), TracedRArray{T}(data), ancestor_indices...) return x end @@ -114,6 +129,27 @@ function transpose_val(val) return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) end +mutable struct CompiledMlirFnResult{ + F,TR,Re,Rt,LA,LR,PA,CR,M<:Union{Nothing,Reactant.Sharding.Mesh},MA +} + fnwrapped::Bool + f::F + traced_result::TR + result::Re + seen_args::OrderedIdDict + ret::Rt + linear_args::Vector{LA} + in_tys::Vector{MLIR.IR.Type} + linear_results::Vector{LR} + num_partitions::Int + num_replicas::Int + is_sharded::Bool + preserved_args::PA + concrete_result::CR + sharding_mesh::M + mutated_args::MA +end + function make_mlir_fn( f, args, @@ -122,38 +158,40 @@ function make_mlir_fn( concretein=true; toscalar=false, return_dialect=:func, - no_args_in_result::Bool=false, + args_in_result::Symbol=:all, construct_function_without_args::Bool=false, do_transpose=true, + input_shardings=nothing, # This is not meant to be used by the user. ) if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction - return ( - true, - make_mlir_fn( - Reactant.apply, - (f, args...), - kwargs, - name, - concretein; - toscalar, - return_dialect, - no_args_in_result, - construct_function_without_args, - do_transpose, - )[2:end]..., + mlir_fn_res = make_mlir_fn( + Reactant.apply, + (f, args...), + kwargs, + name, + concretein; + toscalar, + return_dialect, + do_transpose, + args_in_result, + input_shardings, ) + mlir_fn_res.fnwrapped = true + return mlir_fn_res end + num_partitions, num_replicas = 1, 1 + N = length(args) seen_args = OrderedIdDict() - traced_args = ntuple(N) do i - return Reactant.make_tracer( + traced_args = Vector{Any}(undef, N) + for i in 1:N + @inbounds traced_args[i] = Reactant.make_tracer( seen_args, args[i], (:args, i), concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath; toscalar, - track_numbers=construct_function_without_args ? (Number,) : (), ) end @@ -164,7 +202,10 @@ function make_mlir_fn( end in_tys = if toscalar - [MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args] + [ + MLIR.IR.TensorType((), MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for + arg in linear_args + ] elseif do_transpose [transpose_ty(Ops.mlir_type(arg)) for arg in linear_args] else @@ -176,7 +217,23 @@ function make_mlir_fn( sym_visibility = MLIR.IR.Attribute("private") end + ctx = MLIR.IR.context() mod = MLIR.IR.mmodule() + + # Insert meshes for the sharded arguments + traced_args_to_shardings = OrderedIdDict() + for (k, v) in seen_args + if (k isa Reactant.ConcretePJRTArray || k isa Reactant.ConcretePJRTNumber) + if Reactant.Sharding.is_sharded(k) + Reactant.Ops.mesh(k.sharding.mesh) + traced_args_to_shardings[v] = k.sharding + elseif input_shardings !== nothing && haskey(input_shardings, k) + Reactant.Ops.mesh(input_shardings[k].mesh) + traced_args_to_shardings[v] = input_shardings[k] + end + end + end + func = MLIR.IR.block!(MLIR.IR.body(mod)) do return MLIR.Dialects.func.func_(; sym_name=name * "_tmp", @@ -185,32 +242,40 @@ function make_mlir_fn( ) end - if construct_function_without_args - fnbody = MLIR.IR.Block() - else - fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) - end + fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) push!(MLIR.IR.region(func, 1), fnbody) @assert MLIR.IR._has_block() - result = MLIR.IR.block!(fnbody) do + # Explicitly don't use block! to avoid creating a closure, which creates + # both compile-time and relocatability issues + MLIR.IR.activate!(fnbody) + + result = try for (i, arg) in enumerate(linear_args) - if construct_function_without_args - arg.mlir_data = args[i].mlir_data - else - raw_arg = MLIR.IR.argument(fnbody, i) - row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg - arg.mlir_data = row_maj_arg - end + raw_arg = MLIR.IR.argument(fnbody, i) + row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg + set_mlir_data!(arg, row_maj_arg) + end + + if isempty(kwargs) + Reactant.call_with_reactant(f, traced_args...) + else + Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...) end + finally + MLIR.IR.deactivate!(fnbody) + end - # TODO fix it for kwargs - #if concretein - Reactant.call_with_reactant(f, traced_args...) - #else - # f(traced_args...) - #end + # check which arguments have been mutated + mutated_args = Int[] + if !construct_function_without_args + for (i, arg) in enumerate(linear_args) + if get_mlir_data(arg) != MLIR.IR.argument(fnbody, i) + # mutation occured! + push!(mutated_args, i) + end + end end seen_results = OrderedIdDict() @@ -219,8 +284,7 @@ function make_mlir_fn( seen_results, result, (:result,), - concretein ? Reactant.TracedTrack : Reactant.TracedSetPath; - track_numbers=construct_function_without_args ? (Number,) : (), + concretein ? Reactant.TracedTrack : Reactant.TracedSetPath, ) # marks buffers to be donated @@ -234,60 +298,107 @@ function make_mlir_fn( end linear_results = Reactant.TracedType[] - for (k, v) in seen_results v isa Reactant.TracedType || continue - (no_args_in_result && length(v.paths) > 0 && v.paths[1][1] == :args) && continue + (args_in_result != :all && has_argidx(v)) && continue push!(linear_results, v) end + if args_in_result == :mutated + append!(linear_results, linear_args[mutated_args]) + end - out_tys = [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results] + out_tys = if do_transpose + [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results] + else + [Ops.mlir_type(arg) for arg in linear_results] + end - ret = MLIR.IR.block!(fnbody) do + MLIR.IR.activate!(fnbody) + ret = try vals = MLIR.IR.Value[] for res in linear_results col_maj = if res isa MissingTracedValue - broadcast_to_size(false, ()).mlir_data - elseif construct_function_without_args || !do_transpose - res.mlir_data + get_mlir_data(broadcast_to_size(false, ())) + elseif !do_transpose + get_mlir_data(res) elseif do_transpose - transpose_val(res.mlir_data) + transpose_val(get_mlir_data(res)) end push!(vals, col_maj) end - !no_args_in_result && @assert length(vals) == length(linear_results) + args_in_result == :all && @assert length(vals) == length(linear_results) dialect = getfield(MLIR.Dialects, return_dialect) - return dialect.return_(vals) - end - - name2 = name - - tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)) - for i in 0:10000 - name2 = if i == 0 - name - else - name * string(i) - end - if MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, name2)) - break - end + dialect.return_(vals) + finally + MLIR.IR.deactivate!(fnbody) end func2 = MLIR.IR.block!(MLIR.IR.body(mod)) do return MLIR.Dialects.func.func_(; - sym_name=name2, + sym_name=__lookup_unique_name_in_module(mod, name), function_type=MLIR.IR.FunctionType(in_tys, out_tys), body=MLIR.IR.Region(), + arg_attrs=MLIR.IR.attr(func, "arg_attrs"), + res_attrs=MLIR.IR.attr(func, "res_attrs"), + no_inline=MLIR.IR.attr(func, "no_inline"), sym_visibility, ) end MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func2, 1), MLIR.IR.region(func, 1)) + mesh_cache = Reactant.Compiler.sdycache() + is_sharded = !isempty(mesh_cache) + + if is_sharded + unique_meshes = keys(mesh_cache) + + # TODO: support multiple meshes + if length(unique_meshes) > 1 + error("Currently we support using a single mesh") + sorted_devices = [sort(vec(m.device_ids)) for m in unique_meshes] + @assert allequal(sorted_devices) "All meshes must have the same device ids" + end + sharding_mesh = first(unique_meshes) + num_partitions = length(sharding_mesh) + + linear_arg_shardings = Vector{MLIR.IR.Attribute}(undef, length(linear_args)) + + # Attach `sdy.sharding` attribute to the argument + for (i, arg) in enumerate(linear_args) + if haskey(traced_args_to_shardings, arg) + sharding = traced_args_to_shardings[arg] + (; sym_name, mesh_attr) = mesh_cache[sharding.mesh] + linear_arg_shardings[i] = Reactant.Sharding.get_shardy_tensor_sharding_attribute( + sharding, ctx, sym_name, mesh_attr + ) + MLIR.API.mlirFuncSetArgAttr( + func2, i - 1, "sdy.sharding", linear_arg_shardings[i] + ) + end + end + + # Ensure the sharding of the mutated arguments is propagated to the results + result_not_replicated = falses(length(linear_results)) + for i in mutated_args + arg = linear_args[i] + if has_residx(arg) && haskey(traced_args_to_shardings, arg) + residx = findfirst(Base.Fix1(===, arg), linear_results) + @assert residx !== nothing + result_not_replicated[residx] = true + MLIR.API.mlirFuncSetResultAttr( + func2, residx - 1, "sdy.sharding", linear_arg_shardings[i] + ) + end + end + else + sharding_mesh = nothing + end + MLIR.API.mlirOperationDestroy(func.operation) func.operation = MLIR.API.MlirOperation(C_NULL) - return ( + + return CompiledMlirFnResult( false, func2, traced_result, @@ -297,21 +408,42 @@ function make_mlir_fn( linear_args, in_tys, linear_results, + num_partitions, + num_replicas, + is_sharded, + nothing, + nothing, + sharding_mesh, + mutated_args, ) end -elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x +function __lookup_unique_name_in_module(mod, name) + new_name = name + tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)) + for i in 0:10000 + new_name = i == 0 ? name : name * "_" * string(i) + MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, new_name)) && return new_name + end + modstr = string(mod) + return error("Mod\n$modstr\nCould not find unique name for $name") +end + +function __take_region(compiled_fn) + region = MLIR.IR.Region() + MLIR.API.mlirRegionTakeBody(region, MLIR.API.mlirOperationGetRegion(compiled_fn, 0)) + return region +end + +elem_apply(::Type{T}, x::TracedRArray{T}) where {T} = x -struct TypeCast{T<:ReactantPrimitive} <: Function end +struct TypeCast{T<:Reactant.ReactantPrimitive} <: Function end function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} - return TracedUtils.promote_to(TracedRNumber{T}, x) + return promote_to(TracedRNumber{T}, x) end -function elem_apply( - ::Type{T}, x::TracedRArray{T2} -) where {T<:ReactantPrimitive,T2<:ReactantPrimitive} - # Special Path to prevent going down a despecialized path +function elem_apply(::Type{T}, x::TracedRArray) where {T<:Reactant.ReactantPrimitive} return elem_apply(TypeCast{T}(), x) end @@ -325,12 +457,12 @@ function push_val!(ad_inputs, x, path) for p in path x = Reactant.Compiler.traced_getfield(x, p) end - x = x.mlir_data + x = get_mlir_data(x) return push!(ad_inputs, x) end function get_argidx(x) - for path in x.paths + for path in get_paths(x) if length(path) == 0 continue end @@ -342,7 +474,7 @@ function get_argidx(x) end function has_argidx(x) - for path in x.paths + for path in get_paths(x) if length(path) == 0 continue end @@ -358,15 +490,13 @@ function set!(x, path, tostore; emptypath=false) x = Reactant.Compiler.traced_getfield(x, p) end - x.mlir_data = tostore + set_mlir_data!(x, tostore) - if emptypath - x.paths = () - end + return emptypath && set_paths!(x, ()) end function get_residx(x) - for path in x.paths + for path in get_paths(x) if length(path) == 0 continue end @@ -378,7 +508,7 @@ function get_residx(x) end function has_residx(x) - for path in x.paths + for path in get_paths(x) if length(path) == 0 continue end @@ -392,14 +522,17 @@ end function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} if all(iszero ∘ ndims, args) scalar_args = map(args) do arg - return promote_to(TracedRNumber{eltype(arg)}, arg) + return promote_to(TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg) end return f(scalar_args...) end - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( + mlir_fn_res = make_mlir_fn( f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true ) + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + (; result, seen_args, linear_args, linear_results) = mlir_fn_res invmap = IdDict() for (k, v) in seen_args @@ -413,10 +546,9 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} OutShape = isempty(seen_args) ? nothing : first(input_shapes) @assert !isnothing(OutShape) - in_tys2 = [Ops.mlir_type(invmap[arg]) for arg in linear_args] - out_tys2 = [ - MLIR.IR.TensorType(OutShape, MLIR.IR.Type(eltype(arg))) for arg in linear_results + MLIR.IR.TensorType(OutShape, MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for + arg in linear_results ] fname = get_attribute_by_name(func2, "sym_name") @@ -425,7 +557,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} batch_inputs = MLIR.IR.Value[] for a in linear_args - idx, path = TracedUtils.get_argidx(a) + idx, path = get_argidx(a) if idx == 1 && fnwrap push_val!(batch_inputs, f, path[3:end]) else @@ -446,20 +578,20 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} residx = 1 for a in linear_results - if TracedUtils.has_residx(a) - path = TracedUtils.get_residx(a) - TracedUtils.set!(result, path[2:end], MLIR.IR.result(res, residx)) + if has_residx(a) + path = get_residx(a) + set!(result, path[2:end], MLIR.IR.result(res, residx)) residx += 1 else - idx, path = TracedUtils.get_argidx(a) + idx, path = get_argidx(a) if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], MLIR.IR.result(res, residx)) + set!(f, path[3:end], MLIR.IR.result(res, residx)) residx += 1 else if fnwrap idx -= 1 end - TracedUtils.set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) + set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) residx += 1 end end @@ -475,11 +607,26 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} return traced2_result end -new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) -new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) - +function broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize) + if Reactant.ancestor(arg) isa TracedRArray + return broadcast_to_size(materialize_traced_array(arg), rsize) + end + return broadcast_to_size(reshape(Ops.vcat(arg...), size(arg)...), rsize) +end broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize) +broadcast_to_size(arg::AbstractRange, rsize) = broadcast_to_size(collect(arg), rsize) +function broadcast_to_size(arg::UnitRange, rsize) + # For small inputs this will be automatically optimized away, and for large ranges + # helps reduce the IR size + x = Ops.add( + Ops.iota(eltype(arg), [length(arg)]; iota_dimension=1), + Ops.fill(first(arg), [length(arg)]), + ) + return broadcast_to_size(x, rsize) +end +broadcast_to_size(arg::Base.OneTo, rsize) = broadcast_to_size(1:last(arg), rsize) + function broadcast_to_size(arg::Base.RefValue, rsize) # XXX: don't we want to expand here to rsize? return arg @@ -487,16 +634,14 @@ end broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize))) -function broadcast_to_size(arg::TracedRNumber, rsize) +function broadcast_to_size(arg::TracedRNumber{T}, rsize) where {T} length(rsize) == 0 && return arg - return broadcast_to_size_internal( - TracedRArray{eltype(arg),0}((), arg.mlir_data, ()), rsize - ) + return broadcast_to_size_internal(TracedRArray{T,0}((), get_mlir_data(arg), ()), rsize) end function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T} arg = materialize_traced_array(arg) - return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize) + return broadcast_to_size(TracedRNumber{T}((), get_mlir_data(arg)), rsize) end function broadcast_to_size(arg::AnyTracedRArray, rsize) @@ -512,30 +657,42 @@ function broadcast_to_size(arg::Broadcast.Extruded, rsize) return broadcast_to_size_internal(x, rsize) end -@noinline function broadcast_to_size_internal(x::TracedRArray, rsize) - dims = collect(Int64, 0:(length(size(x)) - 1)) +@noinline function broadcast_to_size_internal(x::TracedRArray{T}, rsize) where {T} + return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize)) +end + +function normalize_indices(a::AbstractArray, indices...) + return map(enumerate(indices)) do (i, idx) + idx isa Colon && return collect(Int64, 1:size(a, i)) + idx isa CartesianIndex && return Tuple(idx) + idx isa AbstractArray{Bool} && return findall(idx) + return idx + end +end - if length(size(MLIR.IR.type(x.mlir_data))) != length(dims) - @show x - @show arg - @show rsize - @show rsize2 - @show dims +function traced_indices(indices...) + integer_indices = Int64[] + result_size = Int64[] + preddim_result_size = Int64[] + flattened_size = Int64[length(idx) for idx in indices] + new_indices = map(enumerate(indices)) do (i, idx) + if idx isa Number + push!(preddim_result_size, 1) + push!(integer_indices, i) + idx isa TracedRNumber && return idx + return promote_to(TracedRNumber{Int}, idx) + end + append!(preddim_result_size, [size(idx)...]) + append!(result_size, [size(idx)...]) + idx isa TracedRArray && return materialize_traced_array(vec(idx)) + return promote_to(TracedRArray{Int,1}, vec(idx)) end - @assert length(size(MLIR.IR.type(x.mlir_data))) == length(dims) - mlirty = MLIR.IR.type(x.mlir_data) - - return TracedRArray{eltype(x),Int(length(rsize))}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.broadcast_in_dim( - x.mlir_data; - result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), - broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), - ), - 1, - ), - collect(rsize), + return ( + new_indices, + Tuple(integer_indices), + result_size, + preddim_result_size, + flattened_size, ) end diff --git a/src/Tracing.jl b/src/Tracing.jl index 5b1f1e72cb..d0668913c9 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -4,36 +4,89 @@ TracedToConcrete = 3 ArrayToConcrete = 4 TracedSetPath = 5 + TracedToTypes = 6 + NoStopTracedTrack = 7 end -for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RArray, RNumber) - @eval function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:$T} +struct VisitedObject + id::Int +end + +function traced_type_inner end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{Union{}}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + return T +end + +for T in ( + DataType, + Module, + Nothing, + Symbol, + AbstractChar, + AbstractString, + AbstractFloat, + Integer, + RNumber, + Val, + VersionNumber, +) + @eval Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:$T}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) + ) return T end end -function traced_type( - ::Type{T}, seen, mode::Val{Mode}, track_numbers -) where {T<:Union{AbstractFloat,Integer},Mode} - if Mode == ArrayToConcrete && any(Base.Fix1(<:, T), track_numbers) - return ConcreteRNumber{T} +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:ReactantPrimitive}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + if mode == ArrayToConcrete && T <: track_numbers + return ConcretePJRTNumber{ + T,Sharding.ndevices(sharding),Sharding.shard_type(typeof(sharding), 0) + } + elseif (mode == NoStopTracedTrack || mode == TracedTrack || mode == TracedSetPath) && + T <: track_numbers + return TracedRNumber{T} end return T end -function traced_type( - ::Type{C}, seen::ST, mode::Val{Mode}, track_numbers::TN -) where {T,C<:Complex{T},ST,Mode,TN} +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(C::Type{<:Complex}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) if !(C isa UnionAll) - return Complex{traced_type(T, seen, mode, track_numbers)} + return Complex{traced_type_inner(C.parameters[1], seen, mode, track_numbers)} else - return @invoke traced_type( - C::Type{Any}, seen::ST, mode::Val{Mode}, track_numbers::TN - ) + return C end end -function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:Function} +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:Function}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) # functions are directly returned if sizeof(T) == 0 return T @@ -42,10 +95,13 @@ function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:Function} # in closures, enclosured variables need to be traced N = fieldcount(T) changed = false - traced_fieldtypes = ntuple(Val(N)) do i - next = traced_type(fieldtype(T, i), seen, mode, track_numbers) + traced_fieldtypes = Type[] + for i in 1:N + next = traced_type_inner( + fieldtype(T, i), seen, mode, track_numbers, getproperty(sharding, i) + ) changed |= next != fieldtype(T, i) - next + push!(traced_fieldtypes, next) end if !changed @@ -56,37 +112,375 @@ function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:Function} return Core.apply_type(T.name.wrapper, traced_fieldtypes...) end -@inline is_concrete_tuple(x::T2) where {T2} = - (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) - -function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:Tuple} - if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll +Base.@nospecializeinfer function traced_tuple_type_inner( + @nospecialize(T::Type{<:Tuple}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + if T === Tuple + return T + end + if T isa UnionAll + if T.var.lb === Union{} && T.var.ub === Any + return UnionAll( + T.var, traced_type_inner(T.body, seen, mode, track_numbers, sharding) + ) + end throw(AssertionError("Type $T is not concrete type or concrete tuple")) - elseif is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) - # Tuple{((T2 isa Core.TypeofVararg ? Any : T2) for T2 in T.parameters)...} - throw(AssertionError("Type tuple of vararg $T is not supported")) - end - TT = [ - traced_type(T.parameters[i], seen, mode, track_numbers) for - i in 1:length(T.parameters) - ] + end + TT = Union{Type,Core.TypeofVararg}[] + for i in 1:length(T.parameters) + st = traced_type_inner(T.parameters[i], seen, mode, track_numbers, sharding) + push!(TT, st) + end return Tuple{TT...} end -function traced_type(::Type{T}, seen, mode, track_numbers) where {N,V,T<:NamedTuple{N,V}} - return NamedTuple{N,traced_type(V, seen, mode, track_numbers)} +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Core.TypeofVararg), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + return Vararg{traced_type_inner(T.T, seen, mode, track_numbers, sharding),T.N} +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::TypeVar), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + if T.lb === Union{} && T.ub === Any + return T + end + throw(AssertionError("Unsupported Typevar $T lb=$(T.lb) ub=$(T.ub)")) +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:Tuple}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + return traced_tuple_type_inner(T, seen, mode, track_numbers, sharding) +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:NamedTuple}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + N = T.parameters[1] + V = T.parameters[2] + return NamedTuple{N,traced_type_inner(V, seen, mode, track_numbers, sharding)} +end + +Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict}) = nothing +Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict{K}}) where {K} = K +Base.@nospecializeinfer @inline dict_value(::Type{<:AbstractDict}) = nothing +Base.@nospecializeinfer @inline function dict_value( + T::Type{<:(AbstractDict{K,V} where {K})} +) where {V} + if @isdefined(V) + V + elseif T <: UnionAll + dict_value(T.body) + elseif T <: Dict && length(T.parameters) >= 2 + T.parameters[2] + else + error("Could not get element type of $T") + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:AbstractDict}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + V = dict_value(T) + if V === nothing + return T + else + K = dict_key(T) + V2 = traced_type_inner(V, seen, mode, track_numbers, sharding) + if V == V2 + return T + end + dictty = if T isa UnionAll + T.body.name.wrapper + else + T.name.wrapper + end + if K !== nothing + return dictty{K,V2} + else + return (dictty{KT,V2} where {KT}) + end + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T0::Type{<:ConcretePJRTNumber}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + T = T0.parameters[1] + if mode == ConcreteToTraced + return TracedRNumber{T} + elseif mode == TracedToConcrete + return T0 + else + throw("Abstract RNumber cannot be made concrete") + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:ConcretePJRTArray}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + if mode == ConcreteToTraced + return TracedRArray{T.parameters[1],T.parameters[2]} + elseif mode == TracedToConcrete + return T + else + throw("Abstract RArray cannot be made concrete") + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:ConcreteRNG}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + if mode == ConcreteToTraced + return TracedRNG + elseif mode == TracedToConcrete + return T + else + throw("Unsupported mode: $mode") + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:MissingTracedValue}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + return error("This should not happen") +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:TracedRArray}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + if mode == ConcreteToTraced + throw("TracedRArray cannot be traced") + elseif mode == TracedToConcrete + return ConcretePJRTArray{ + T.parameters[1], + T.parameters[2], + Sharding.ndevices(sharding), + Sharding.shard_type(typeof(sharding), T.parameters[2]), + } + elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath + return T + else + throw("Abstract RArray cannot be made concrete in mode $mode") + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:TracedRNumber}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + if mode == ConcreteToTraced + throw("TracedRNumber cannot be traced") + elseif mode == TracedToConcrete + if T isa UnionAll + return UnionAll( + T.var, + ConcretePJRTNumber{ + T.var, + Sharding.ndevices(sharding), + Sharding.shard_type(typeof(sharding), 0), + }, + ) + end + return ConcretePJRTNumber{ + T.parameters[1], + Sharding.ndevices(sharding), + Sharding.shard_type(typeof(sharding), 0), + } + elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath + return T + else + throw("Abstract RNumber cannot be made concrete in mode $mode") + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:TracedRNG}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + if mode == ConcreteToTraced + throw("TracedRNG cannot be traced") + elseif mode == TracedToConcrete + return ConcreteRNG{ + traced_type_inner(TracedRArray{UInt64,1}, seen, mode, track_numbers, sharding) + } + elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath + return T + else + throw("Unsupported mode: $mode") + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(A::Type{AbstractArray}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) + return A end -function traced_type(::Type{T}, seen, mode, track_numbers) where {K,V,T<:AbstractDict{K,V}} - dictty = T.name.wrapper - return dictty{K,traced_type(V, seen, mode, track_numbers)} +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(A::Type{AbstractArray{T}}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) where {T} + if mode == ConcreteToTraced + return AbstractArray{TracedRNumber{T}} + else + return A + end end -@inline getmap(::Val{T}) where {T} = nothing -@inline getmap(::Val{T}, a, b, args...) where {T} = getmap(Val(T), args...) -@inline getmap(::Val{T}, ::Val{T}, ::Val{T2}, args...) where {T,T2} = T2 +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(A::Type{AbstractArray{T,N}}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) where {T,N} + if mode == ConcreteToTraced + return AbstractArray{TracedRNumber{T},N} + else + return A + end +end -function traced_type(::Type{T}, seen, mode, track_numbers) where {T} +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(A::Type{<:Array}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + T = eltype(A) + if A isa UnionAll + if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive + return ConcretePJRTArray{T} + else + return Array{ + traced_type_inner(T, seen, mode, track_numbers, getproperty(sharding, 1)) + } + end + else + N = ndims(A) + if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive + return ConcretePJRTArray{ + T,N,Sharding.ndevices(sharding),Sharding.shard_type(typeof(sharding), N) + } + else + return Array{ + traced_type_inner(T, seen, mode, track_numbers, getproperty(sharding, 1)),N + } + end + end +end + +for P in (Ptr, Core.LLVMPtr, Base.RefValue) + @eval Base.@nospecializeinfer function traced_type_inner( + @nospecialize(PT::Type{$P}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) + ) + return $P + end +end +for P in (Ptr, Base.RefValue) + @eval Base.@nospecializeinfer function traced_type_inner( + @nospecialize(PT::Type{$P{T}}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) + ) where {T} + return $P{traced_type_inner(PT.parameters[1], seen, mode, track_numbers, sharding)} + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(PT::Type{Core.LLVMPtr{T}}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) where {T} + return Core.LLVMPtr{ + traced_type_inner(PT.body.parameters[1], seen, mode, track_numbers, sharding) + } +end +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(PT::Type{Core.LLVMPtr{T,A}}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) where {T,A} + return Core.LLVMPtr{ + traced_type_inner(PT.parameters[1], seen, mode, track_numbers, sharding),A + } +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) if T === Any return T end @@ -104,59 +498,91 @@ function traced_type(::Type{T}, seen, mode, track_numbers) where {T} end # unknown number of fields - if T isa UnionAll + if Base.inferencebarrier(T) isa UnionAll + if T.var.lb === Union{} && T.var.ub === Any || T <: Type + return UnionAll( + T.var, traced_type_inner(T.body, seen, mode, track_numbers, sharding) + ) + end aT = Base.argument_datatype(T) if isnothing(aT) throw(TracedTypeError("Unhandled type $T")) end if isnothing(Base.datatype_fieldcount(aT)) - throw(TracedTypeError("Unhandled type $T")) + throw(TracedTypeError("Unhandled type $T, aT=$aT")) end + return T end if T isa Union return Union{ - traced_type(T.a, seen, mode, track_numbers), - traced_type(T.b, seen, mode, track_numbers), + traced_type_inner(T.a, seen, mode, track_numbers, sharding), + traced_type_inner(T.b, seen, mode, track_numbers, sharding), } end # if abstract it must be by reference if Base.isabstracttype(T) + if !(T isa UnionAll) && length(T.parameters) == 0 || T <: Type + return T + end throw(TracedTypeError("Unhandled abstract type $T")) end - if !(Base.isconcretetype(T) || T isa UnionAll) - throw(AssertionError("Type $T is not concrete type or concrete tuple")) + if T <: Tuple + return traced_tuple_type_inner(T, seen, mode, track_numbers, sharding) end - nextTy = getmap(Val(T), seen...) - if !isnothing(nextTy) - return nextTy + if haskey(seen, T) + return seen[T] end - seen2 = (Val(T), Val(T), seen...) + seen2 = copy(seen) + seen2[T] = T changed = false - subTys = Type[] + subTys = Union{Type,TypeVar}[] for f in 1:fieldcount(T) subT = fieldtype(T, f) - subTT = traced_type(subT, seen2, mode, track_numbers) + subTT = traced_type_inner(subT, seen2, mode, track_numbers, sharding) changed |= subT != subTT push!(subTys, subTT) end if !changed + for (k, v) in seen2 + seen[k] = v + end return T end + wrapped_carray = T <: AbstractArray && ancestor(T) <: ConcretePJRTArray + wrapped_tracedarray = T <: AbstractArray && ancestor(T) <: TracedRArray + subParms = [] - for SST in T.parameters - if SST isa Type - TrT = traced_type(SST, seen, mode, track_numbers) + for (i, SST) in enumerate(T.parameters) + if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive + # XXX: Sharding??? + TrT = traced_type_inner( + ConcretePJRTNumber{SST,1,Sharding.ShardInfo}, + seen, + mode, + track_numbers, + sharding, + ) + push!(subParms, TrT) + elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber + TrT = traced_type_inner( + unwrapped_eltype(SST), seen, mode, track_numbers, sharding + ) push!(subParms, TrT) else - push!(subParms, SST) + if SST isa Type + TrT = traced_type_inner(SST, seen, mode, track_numbers, sharding) + push!(subParms, TrT) + else + push!(subParms, SST) + end end end @@ -165,98 +591,140 @@ function traced_type(::Type{T}, seen, mode, track_numbers) where {T} else TT2 = T end - seen3 = (Val(T), Val(TT2), seen...) + seen3 = copy(seen) + seen3[T] = TT2 if fieldcount(T) == fieldcount(TT2) legal = true for f in 1:fieldcount(T) subT = fieldtype(T, f) subT2 = fieldtype(TT2, f) - subTT = traced_type(subT, seen3, mode, track_numbers) + subTT = traced_type_inner(subT, seen3, mode, track_numbers, sharding) if subT2 != subTT legal = false break end end if legal + for (k, v) in seen3 + seen[k] = v + end return TT2 end end name = Symbol[] - throw(NoFieldMatchError(T, TT2)) + throw(NoFieldMatchError(T, TT2, subTys)) end -function traced_type( - ::Type{<:ConcreteRNumber{T}}, seen, ::Val{mode}, track_numbers -) where {T,mode} - if mode == ConcreteToTraced - return TracedRNumber{T} - elseif mode == TracedToConcrete - return ConcreteRNumber{T} - else - throw("Abstract RNumber cannot be made concrete") - end -end +const traced_type_cache = Dict{Tuple{TraceMode,Type,Any},Dict{Type,Type}}() -function traced_type( - ::Type{T}, seen, ::Val{mode}, track_numbers -) where {T<:ConcreteRArray,mode} - if mode == ConcreteToTraced - @inline base_typet(TV::TT) where {TT<:UnionAll} = - UnionAll(TV.var, base_typet(TV.body)) - @inline base_typet(TV::TT) where {TT<:DataType} = TracedRArray{TV.parameters...} - return base_typet(T) - elseif mode == TracedToConcrete - return T - else - throw("Abstract RArray cannot be made concrete") - end -end +# function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Type{<:Val}), @nospecialize(track_numbers::Type)) +# @nospecialize +# T = T.parameters[1] +# mode = mode.parameters[1]::TraceMode +# track_numbers = track_numbers.parameters[1] +# +# +# min_world = Ref{UInt}(typemin(UInt)) +# max_world = Ref{UInt}(typemax(UInt)) +# +# sig = Tuple{typeof(traced_type_inner), Type{T}, Dict{Type, Type}, TraceMode, Type{track_numbers}} +# +# lookup_result = lookup_world( +# sig, world, nothing, min_world, max_world +# ) +# if lookup_result === nothing +# stub = Core.GeneratedFunctionStub(identity, Core.svec(:traced_type, :T, :mode, :track_numbers), Core.svec()) +# return stub(world, source, method_error) +# end +# match = lookup_result::Core.MethodMatch +# +# mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, +# (Any, Any, Any), match.method, match.spec_types, match.sparams)::Core.MethodInstance +# +# ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo +# +# cache = nothing +# cache_key = (mode, track_numbers) +# if haskey(traced_type_cache, cache_key) +# cache = traced_type_cache[cache_key] +# else +# cache = Dict{Type, Type}() +# traced_type_cache[cache_key] = cache +# end +# +# +# # prepare a new code info +# new_ci = copy(ci) +# empty!(new_ci.code) +# @static if isdefined(Core, :DebugInfo) +# new_ci.debuginfo = Core.DebugInfo(:none) +# else +# empty!(new_ci.codelocs) +# resize!(new_ci.linetable, 1) # see note below +# end +# empty!(new_ci.ssaflags) +# new_ci.ssavaluetypes = 0 +# new_ci.min_world = min_world[] +# new_ci.max_world = max_world[] +# edges = Any[mi] +# gensig = Tuple{typeof(traced_type_inner), Type, Dict{Type, Type}, TraceMode, Type{track_numbers}} +# push!(edges, ccall(:jl_method_table_for, Any, (Any,), gensig)) +# push!(edges, gensig) +# +# new_ci.edges = edges +# +# # XXX: setting this edge does not give us proper method invalidation, see +# # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. +# # invoking `code_llvm` also does the necessary codegen, as does calling the +# # underlying C methods -- which GPUCompiler does, so everything Just Works. +# +# # prepare the slots +# new_ci.slotnames = Symbol[Symbol("#self#"), :T, :mode, :track_numbers] +# new_ci.slotflags = UInt8[0x00 for i = 1:4] +# +# # return the codegen world age +# res1 = call_with_reactant(traced_type_inner, T, cache, mode, track_numbers) +# +# res0 = Base.invoke_in_world(world, traced_type_inner, T, cache, mode, track_numbers) +# res = Base.invokelatest(traced_type_inner, T, cache, mode, track_numbers) +# push!(new_ci.code, Core.Compiler.ReturnNode(res)) +# push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` +# @static if isdefined(Core, :DebugInfo) +# else +# push!(new_ci.codelocs, 1) # see note below +# end +# new_ci.ssavaluetypes += 1 +# +# # NOTE: we keep the first entry of the original linetable, and use it for location info +# # on the call to check_cache. we can't not have a codeloc (using 0 causes +# # corruption of the back trace), and reusing the target function's info +# # has as advantage that we see the name of the kernel in the backtraces. +# +# return new_ci +# end +# +# @eval Base.@assume_effects :removable :foldable :nothrow @inline function traced_type_old(T::Type, mode::Val, track_numbers::Type) +# $(Expr(:meta, :generated_only)) +# $(Expr(:meta, :generated, traced_type_generator)) +# end -function traced_type( - ::Type{T}, seen::ST, ::Val{mode}, track_numbers -) where {ST,T<:TracedType,mode} - T <: MissingTracedValue && error("TODO") - if mode == ConcreteToTraced - throw("TracedRArray $T cannot be traced") - elseif mode == TracedToConcrete - @inline base_typec(TV::TT) where {TT<:UnionAll} = - UnionAll(TV.var, base_typec(TV.body)) - @inline base_typec(TV::TT) where {TT<:DataType} = - (T <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...} - return base_typec(T) - elseif mode == TracedTrack || mode == TracedSetPath +Base.@assume_effects :total @inline function traced_type( + T::Type, ::Val{mode}, track_numbers::Type, sharding +) where {mode} + if mode == TracedSetPath || mode == TracedTrack return T - else - throw("Abstract RArray $T cannot be made concrete in mode $mode") end -end - -function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:XLAArray} - throw("XLA $T array cannot be traced") -end -function traced_type( - ::Type{A}, seen::ST, ::Val{mode}, track_numbers -) where {T,N,A<:Array{T,N},ST,mode} - if mode == ArrayToConcrete && T <: ReactantPrimitive - return ConcreteRArray{T,N} + cache = nothing + cache_key = (mode, track_numbers, sharding) + if haskey(traced_type_cache, cache_key) + cache = traced_type_cache[cache_key] else - return Array{traced_type(T, seen, Val(mode), track_numbers),N} - end -end - -for P in (Ptr, Core.LLVMPtr, Base.RefValue) - @eval function traced_type(::Type{P}, seen, mode, track_numbers) where {T,P<:$P{T}} - return $P{traced_type(T, seen, mode, track_numbers)} - end -end - -function traced_type(::Type{Val{T}}, seen, mode, track_numbers) where {T} - if traced_type(typeof(T), seen, mode, track_numbers) == typeof(T) - return Val{T} + cache = Dict{Type,Type}() + traced_type_cache[cache_key] = cache end - throw("Val type $(Val{T}) cannot be traced") + return traced_type_inner(T, cache, mode, track_numbers, sharding) end abstract type TracedTypeException <: Exception end @@ -272,36 +740,63 @@ end struct NoFieldMatchError <: TracedTypeException origty besteffort + subTys end function Base.showerror(io::IO, err::NoFieldMatchError) - print(io, "NoFieldMatchError: ") - return print( + println(io, "NoFieldMatchError: ") + println( io, "Cannot convert type $(err.origty), best attempt $(err.besteffort) failed.\nThis could be because the type does not capture the fieldtypes that should be converted in its type parameters.", ) + for (i, subty) in zip(1:fieldcount(err.origty), err.subTys) + origty = fieldtype(err.origty, i) + println(io, "idx=", i, " Derived: ", subty, " Existing: ", origty) + end end -append_path(path, i) = (path..., i) +function make_tracer( + seen, + @nospecialize(prev::Union{Base.ExceptionStack,Core.MethodInstance}), + @nospecialize(path), + mode; + kwargs..., +) + return prev +end +append_path(@nospecialize(path), i) = (path..., i) function make_tracer( seen, - @nospecialize(prev::RT), + @nospecialize(prev), @nospecialize(path), mode; - toscalar=false, - tobatch=nothing, - track_numbers=(), + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Sharding.NoSharding()), kwargs..., -) where {RT} +) + RT = Core.Typeof(prev) if haskey(seen, prev) - return seen[prev] + if mode == TracedToTypes + id = seen[prev] + push!(path, id) + return nothing + elseif mode != NoStopTracedTrack && haskey(seen, prev) + return seen[prev] + end + elseif mode == TracedToTypes + push!(path, RT) + seen[prev] = VisitedObject(length(seen) + 1) end - TT = traced_type(RT, (), Val(mode), track_numbers) + TT = traced_type(RT, Val(mode), track_numbers, sharding) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) if TT === Module || TT === String + if mode == TracedToTypes + push!(path, prev) + return nothing + end return prev end @@ -311,15 +806,15 @@ function make_tracer( changed = false for i in 1:nf if isdefined(prev, i) + newpath = mode == TracedToTypes ? path : append_path(path, i) xi = Base.getfield(prev, i) xi2 = make_tracer( seen, xi, - append_path(path, i), + newpath, mode; - toscalar, - tobatch, track_numbers, + sharding=Base.getproperty(sharding, i), kwargs..., ) if xi !== xi2 @@ -336,6 +831,10 @@ function make_tracer( end if nf == 0 + if mode == TracedToTypes + push!(path, prev) + return nothing + end return prev end @@ -343,15 +842,15 @@ function make_tracer( changed = false for i in 1:nf if isdefined(prev, i) + newpath = mode == TracedToTypes ? path : append_path(path, i) xi = Base.getfield(prev, i) xi2 = make_tracer( seen, xi, - append_path(path, i), + newpath, mode; - toscalar, - tobatch, track_numbers, + sharding=Base.getproperty(sharding, i), kwargs..., ) if xi !== xi2 @@ -363,6 +862,9 @@ function make_tracer( break end end + if mode == TracedToTypes + return nothing + end if !changed seen[prev] = prev return prev @@ -373,10 +875,23 @@ function make_tracer( end function make_tracer( - seen, @nospecialize(prev::ConcreteRArray{T,N}), @nospecialize(path), mode; kwargs... + seen, + @nospecialize(prev::ConcretePJRTArray{T,N}), + @nospecialize(path), + mode; + @nospecialize(sharding = Sharding.NoSharding()), + kwargs..., ) where {T,N} + if mode == TracedToTypes + throw("Cannot have ConcretePJRTArray as function call argument.") + end if mode == ArrayToConcrete - return prev + if prev.sharding isa Sharding.ShardInfo{typeof(sharding)} + return prev + end + error( + "Mismatched sharding. Input has sharding $(prev.sharding), but requested sharding is $(typeof(sharding))", + ) end if mode != ConcreteToTraced throw("Cannot trace concrete") @@ -390,9 +905,23 @@ function make_tracer( return res end -function make_tracer(seen, prev::ConcreteRNumber{T}, path, mode; kwargs...) where {T} +function make_tracer( + seen, + prev::ConcretePJRTNumber{T}, + @nospecialize(path), + mode; + @nospecialize(sharding = Sharding.NoSharding()), + kwargs..., +) where {T} + if mode == TracedToTypes + throw("Cannot have ConcretePJRTNumber as function call argument.") + end if mode == ArrayToConcrete - return prev + if !Sharding.is_sharded(sharding) + return prev + else + return ConcretePJRTNumber(prev; sharding) + end end if mode != ConcreteToTraced throw("Cannot trace existing trace type") @@ -412,18 +941,30 @@ function make_tracer( mode; toscalar=false, tobatch=nothing, + @nospecialize(sharding = Sharding.NoSharding()), kwargs..., ) where {T,N} if mode == ConcreteToTraced throw("Cannot trace existing trace type") end + if mode == TracedToTypes + push!(path, MLIR.IR.type(prev.mlir_data)) + return nothing + end if mode == TracedTrack - prev.paths = (prev.paths..., path) + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) if !haskey(seen, prev) return seen[prev] = prev end return prev end + if mode == NoStopTracedTrack + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) + if !haskey(seen, prev) + seen[prev] = prev # don't return! + end + return prev + end if mode == TracedSetPath if haskey(seen, prev) return seen[prev] @@ -441,9 +982,15 @@ function make_tracer( if mode == TracedToConcrete if haskey(seen, prev) - return seen[prev]::ConcreteRArray{T,N} + return seen[prev]::ConcretePJRTArray{T,N} + end + if !Sharding.is_sharded(sharding) + res = ConcretePJRTArray{T,N,1,Sharding.NoShardInfo}( + (XLA.PJRT.AsyncEmptyBuffer,), size(prev), Sharding.NoShardInfo() + ) + else + error("TODO: implement sharding") end - res = ConcreteRArray{T,N}(XLA.AsyncEmptyBuffer, size(prev)) seen[prev] = res return res end @@ -458,18 +1005,30 @@ function make_tracer( mode; tobatch=nothing, toscalar=false, + @nospecialize(sharding = Sharding.NoSharding()), kwargs..., ) where {T} if mode == ConcreteToTraced throw("Cannot trace existing trace type") end + if mode == TracedToTypes + push!(path, MLIR.IR.type(prev.mlir_data)) + return nothing + end if mode == TracedTrack - prev.paths = (prev.paths..., path) + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) if !haskey(seen, prev) return seen[prev] = prev end return prev end + if mode == NoStopTracedTrack + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) + if !haskey(seen, prev) + seen[prev] = prev # don't return! + end + return prev + end if mode == TracedSetPath if haskey(seen, prev) return seen[prev] @@ -487,9 +1046,15 @@ function make_tracer( if mode == TracedToConcrete if haskey(seen, prev) - return seen[prev]::ConcreteRNumber{T} + return seen[prev]::ConcretePJRTNumber{T} + end + if !Sharding.is_sharded(sharding) + res = ConcretePJRTNumber{T,1,Sharding.NoShardInfo}( + (XLA.PJRT.AsyncEmptyBuffer,), Sharding.NoShardInfo() + ) + else + error("TODO: implement sharding") end - res = ConcreteRNumber{T}(XLA.AsyncEmptyBuffer) seen[prev] = res return res end @@ -503,13 +1068,23 @@ function make_tracer( if mode == ConcreteToTraced throw("Cannot trace existing trace type") end + if mode == TracedToTypes + throw("Cannot have MissingTracedValue as function call argument.") + end if mode == TracedTrack - prev.paths = (prev.paths..., path) + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) if !haskey(seen, prev) return seen[prev] = prev end return prev end + if mode == NoStopTracedTrack + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) + if !haskey(seen, prev) + seen[prev] = prev # don't return! + end + return prev + end if mode == TracedSetPath haskey(seen, prev) && return seen[prev] res = MissingTracedValue((path,)) @@ -524,15 +1099,24 @@ function make_tracer( end function make_tracer( - seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs... -) where {RT<:Number} - length(track_numbers) == 0 && return prev - should_convert = any(Base.Fix1(<:, RT), track_numbers) - if should_convert + seen, + @nospecialize(prev::Number), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Sharding.NoSharding()), + kwargs..., +) + if mode == TracedToTypes + push!(path, prev) + return nothing + end + RT = Core.Typeof(prev) + if RT <: track_numbers && mode != TracedSetPath && mode != TracedTrack if mode == ArrayToConcrete - return ConcreteRNumber(prev) + return ConcretePJRTNumber(prev; sharding) else - if mode == TracedTrack + if mode == TracedTrack || mode == NoStopTracedTrack res = TracedRNumber{RT}( (path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data ) @@ -555,45 +1139,96 @@ function make_tracer( return prev end -make_tracer(seen, prev::Type, @nospecialize(path), mode; kwargs...) = prev -make_tracer(seen, prev::Symbol, @nospecialize(path), mode; kwargs...) = prev +function make_tracer(seen, @nospecialize(prev::Type), @nospecialize(path), mode; kwargs...) + if mode == TracedToTypes + push!(path, prev) + return nothing + end + return prev +end +function make_tracer(seen, prev::Symbol, @nospecialize(path), mode; kwargs...) + if mode == TracedToTypes + push!(path, prev) + return nothing + end + return prev +end function make_tracer( seen, - @nospecialize(prev::Complex{RT}), + @nospecialize(prev::Complex), @nospecialize(path), mode; - toscalar=false, - tobatch=nothing, + @nospecialize(sharding = Sharding.NoSharding()), kwargs..., -) where {RT} +) + Sharding.is_sharded(sharding) && error("Cannot specify sharding for Complex") + if mode == TracedToTypes + push!(path, Core.Typeof(prev)) + make_tracer(seen, prev.re, path, mode; kwargs...) + make_tracer(seen, prev.im, path, mode; kwargs...) + return nothing + end return Complex( - make_tracer( - seen, prev.re, append_path(path, :re), mode; toscalar, tobatch, kwargs... - ), - make_tracer( - seen, prev.im, append_path(path, :im), mode; toscalar, tobatch, kwargs... - ), + make_tracer(seen, prev.re, append_path(path, :re), mode; kwargs...), + make_tracer(seen, prev.im, append_path(path, :im), mode; kwargs...), ) end function make_tracer( - seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs... -) where {RT<:Array} - if haskey(seen, prev) + seen, + @nospecialize(prev::Array), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Sharding.NoSharding()), + kwargs..., +) + RT = Core.Typeof(prev) + # XXX: If someone wants to shard the same array with different shardings, we need to + # somehow handle this correctly... Right now we just use the first sharding. + if mode != NoStopTracedTrack && haskey(seen, prev) + if mode == TracedToTypes + visited = seen[prev] + push!(path, visited) + return nothing + end return seen[prev] end - if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive - return seen[prev] = ConcreteRArray(prev) + if eltype(RT) <: ReactantPrimitive + if mode == ArrayToConcrete && return seen[prev] = ConcretePJRTArray(prev; sharding) + elseif mode == TracedToTypes + # Original array can get mutated so we store a copy: + push!(path, copy(prev)) + seen[prev] = VisitedObject(length(seen) + 1) + return nothing + end + elseif mode == TracedToTypes + push!(path, RT) + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + make_tracer(seen, pv, path, mode; track_numbers, sharding, kwargs...) + end + end + return nothing end - TT = traced_type(eltype(RT), (), Val(mode), track_numbers) + TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding) newa = Array{TT,ndims(RT)}(undef, size(prev)) seen[prev] = newa same = true for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] - nv = make_tracer(seen, pv, append_path(path, I), mode; track_numbers, kwargs...) + nv = make_tracer( + seen, + pv, + append_path(path, I), + mode; + track_numbers, + sharding=Base.getproperty(sharding, I), + kwargs..., + ) if pv !== nv same = false end @@ -608,31 +1243,129 @@ function make_tracer( end function make_tracer( - seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs... -) where {RT<:Tuple} + seen, + @nospecialize(prev::Dict{Key,Value}), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Sharding.NoSharding()), + kwargs..., +) where {Key,Value} + RT = Core.Typeof(prev) + # XXX: If someone wants to shard the same array with different shardings, we need to + # somehow handle this correctly... Right now we just use the first sharding. + if mode != NoStopTracedTrack && haskey(seen, prev) + if mode == TracedToTypes + visited = seen[prev] + push!(path, visited) + return nothing + end + return seen[prev] + end + if eltype(RT) <: ReactantPrimitive + if mode == ArrayToConcrete && return seen[prev] = ConcretePJRTArray(prev; sharding) + elseif mode == TracedToTypes + # Original array can get mutated so we store a copy: + push!(path, copy(prev)) + seen[prev] = VisitedObject(length(seen) + 1) + return nothing + end + elseif mode == TracedToTypes + push!(path, RT) + for (k, v) in prev + make_tracer(seen, k, path, mode; track_numbers, sharding, kwargs...) + make_tracer(seen, v, path, mode; track_numbers, sharding, kwargs...) + end + return nothing + end + Value2 = traced_type(Value, Val(mode), track_numbers, sharding) + newa = Dict{Key,Value2}() + seen[prev] = newa + same = true + for (k, v) in prev + nv = make_tracer( + seen, + v, + append_path(path, k), + mode; + track_numbers, + sharding=Base.getproperty(sharding, k), + kwargs..., + ) + if v !== nv + same = false + end + newa[k] = nv + end + if same + seen[prev] = prev + return prev + end + return newa +end + +function make_tracer( + seen, + @nospecialize(prev::Tuple), + @nospecialize(path), + mode; + @nospecialize(sharding = Sharding.NoSharding()), + kwargs..., +) + RT = Core.Typeof(prev) + if mode == TracedToTypes + push!(path, RT) + for (i, v) in enumerate(prev) + make_tracer( + seen, v, path, mode; sharding=Base.getproperty(sharding, i), kwargs... + ) + end + return nothing + end return ( ( - make_tracer(seen, v, append_path(path, i), mode; kwargs...) for - (i, v) in enumerate(prev) + make_tracer( + seen, + v, + append_path(path, i), + mode; + sharding=Base.getproperty(sharding, i), + kwargs..., + ) for (i, v) in enumerate(prev) )..., ) end function make_tracer( seen, - @nospecialize(prev::NamedTuple{A,RT}), + @nospecialize(prev::NamedTuple), @nospecialize(path), mode; - track_numbers=(), + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Sharding.NoSharding()), kwargs..., -) where {A,RT} - return NamedTuple{A,traced_type(RT, (), Val(mode), track_numbers)}(( +) + NT = Core.Typeof(prev) + A = NT.parameters[1] + RT = NT.parameters[2] + + if mode == TracedToTypes + push!(path, NT) + for i in 1:length(A) + make_tracer( + seen, Base.getfield(prev, i), path, mode; track_numbers, sharding, kwargs... + ) + end + return nothing + end + return NamedTuple{A,traced_type(RT, Val(mode), track_numbers, sharding)}(( ( make_tracer( seen, Base.getfield(prev, i), append_path(path, i), mode; + sharding=Base.getproperty(sharding, i), track_numbers, kwargs..., ) for i in 1:length(A) @@ -640,12 +1373,30 @@ function make_tracer( )) end -function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...) - if haskey(seen, prev) +function make_tracer( + seen, + prev::Core.Box, + @nospecialize(path), + mode; + @nospecialize(sharding = Sharding.NoSharding()), + kwargs..., +) + if mode == TracedToTypes + push!(path, Core.Box) + return make_tracer(seen, prev.contents, path, mode; sharding, kwargs...) + end + if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end prev2 = prev.contents - tr = make_tracer(seen, prev2, append_path(path, :contents), mode; kwargs...) + tr = make_tracer( + seen, + prev2, + append_path(path, :contents), + mode; + sharding=Base.getproperty(sharding, :contents), + kwargs..., + ) if tr === prev2 seen[prev] = prev return prev @@ -655,31 +1406,90 @@ function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...) return res end -@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=()) - track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ()) - return to_rarray_internal(x, track_numbers) +@inline function to_rarray( + @nospecialize(x); + track_numbers::Union{Bool,Type}=false, + sharding=Sharding.Sharding.NoSharding(), +) + track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{}) + return to_rarray_internal(x, track_numbers, sharding) end -@inline function to_rarray_internal(@nospecialize(x), track_numbers::Tuple) - return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers) +@inline function to_rarray_internal( + @nospecialize(x), @nospecialize(track_numbers::Type), @nospecialize(sharding) +) + return make_tracer( + OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers, sharding + ) end -function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple) - return error("Cannot convert TracedRArray to ConcreteRArray") +# fast paths avoiding make_tracer +function to_rarray_internal( + @nospecialize(::TracedRArray), + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + return error("Cannot convert TracedRArray to ConcretePJRTArray") end -@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x + @inline function to_rarray_internal( - @nospecialize(x::AbstractArray{<:ReactantPrimitive}), ::Tuple + @nospecialize(x::ConcretePJRTArray), + @nospecialize(track_numbers::Type), + @nospecialize(sharding) ) - return ConcreteRArray(x) + if x.sharding isa Sharding.ShardInfo{typeof(sharding)} + return x + end + return error( + "Mismatched sharding. Input has sharding $(x.sharding), but requested sharding is $(typeof(sharding))", + ) end -@inline to_rarray_internal(@nospecialize(x::ConcreteRNumber), ::Tuple) = x @inline function to_rarray_internal( - @nospecialize(x::ReactantPrimitive), track_numbers::Tuple + @nospecialize(x::Array{<:ReactantPrimitive}), + @nospecialize(track_numbers::Type), + @nospecialize(sharding) ) - for T in track_numbers - typeof(x) <: T && return ConcreteRNumber(x) + return ConcretePJRTArray(x; sharding) +end + +@inline function to_rarray_internal( + @nospecialize(x::Array{T}), @nospecialize(track_numbers::Type), @nospecialize(sharding) +) where {T<:Number} + if reactant_primitive(T) !== nothing + return ConcretePJRTArray(to_reactant_primitive.(x); sharding) end + return @invoke to_rarray_internal(x::Any, track_numbers::Type, sharding) +end + +@inline function to_rarray_internal( + @nospecialize(x::ConcretePJRTNumber), + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + if x.sharding isa Sharding.ShardInfo{typeof(sharding)} + return x + end + return error( + "Mismatched sharding. Input has sharding $(x.sharding), but requested sharding is $(typeof(sharding))", + ) +end + +@inline function to_rarray_internal( + @nospecialize(x::ReactantPrimitive), + @nospecialize(track_numbers::Type), + @nospecialize(sharding) +) + typeof(x) <: track_numbers && return ConcretePJRTNumber(x; sharding) return x end + +@inline function to_rarray_internal( + @nospecialize(x::Number), @nospecialize(track_numbers::Type), @nospecialize(sharding) +) + Sharding.is_sharded(sharding) && error("Cannot specify sharding for Numbers") + if reactant_primitive(typeof(x)) !== nothing + return ConcretePJRTArray(to_reactant_primitive(x)) + end + return @invoke to_rarray_internal(x::Any, track_numbers::Type, sharding) +end diff --git a/src/Types.jl b/src/Types.jl new file mode 100644 index 0000000000..951678fda5 --- /dev/null +++ b/src/Types.jl @@ -0,0 +1,195 @@ +abstract type RNumber{T<:ReactantPrimitive} <: Number end + +abstract type AbstractConcreteNumber{T} <: RNumber{T} end + +abstract type RArray{T,N} <: AbstractArray{T,N} end + +abstract type AbstractConcreteArray{T,N} <: RArray{T,N} end + +# Traced Types + +## MissingTracedValue -- defined in ReactantCore +@leaf MissingTracedValue + +## TracedRNumber +mutable struct TracedRNumber{T} <: RNumber{T} + paths::Tuple + mlir_data::Union{Nothing,MLIR.IR.Value} + + function TracedRNumber{T}( + paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} + ) where {T} + if !isnothing(mlir_data) + @assert size(MLIR.IR.type(mlir_data)) == () + end + return new{T}(paths, mlir_data) + end +end + +@leaf TracedRNumber + +## TracedRArray +mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N} + paths::Tuple + mlir_data::Union{Nothing,MLIR.IR.Value} + shape::NTuple{N,Int} + + function TracedRArray{T,N}( + paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape + ) where {T,N} + shape = Tuple(shape) + if !isnothing(mlir_data) + @assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))" + end + return new{T,N}(paths, mlir_data, shape) + end +end + +@leaf TracedRArray +Adapt.parent_type(::Type{TracedRArray{T,N}}) where {T,N} = TracedRArray{T,N} + +const WrappedTracedRArray{T,N} = WrappedArray{ + TracedRNumber{T},N,TracedRArray,TracedRArray{T,N} +} +const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} +const AnyTracedRVector{T} = AnyTracedRArray{T,1} +const AnyTracedRMatrix{T} = Union{ + AnyTracedRArray{T,2}, + LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}, + LinearAlgebra.Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, +} +const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} + +## TracedRNG +mutable struct TracedRNG <: Random.AbstractRNG + seed::TracedRArray{UInt64,1} + const algorithm::String +end + +# Concrete Types +## ConcretePJRTNumber +mutable struct ConcretePJRTNumber{T,D,S<:Sharding.ShardInfo} <: AbstractConcreteNumber{T} + data::NTuple{D,XLA.PJRT.AsyncBuffer} + sharding::S +end + +ConcretePJRTNumber{T,1,Sharding.NoShardInfo}(x::Number) where {T} = ConcretePJRTNumber{T}(x) + +function ConcretePJRTNumber{T}(data::Tuple{XLA.PJRT.AsyncBuffer}) where {T} + return ConcretePJRTNumber{T,1,Sharding.NoShardInfo}(data, Sharding.NoShardInfo()) +end + +@leaf ConcretePJRTNumber + +function ConcretePJRTNumber{T}(data::T2; kwargs...) where {T<:Number,T2<:Number} + carray = ConcretePJRTArray(fill(convert(T, data)); kwargs...) + if !Sharding.is_sharded(carray.sharding) + return ConcretePJRTNumber{T,1,typeof(carray.sharding)}( + (carray.data[1],), carray.sharding + ) + end + @assert all(isnothing, carray.sharding.partition_spec) "ConcretePJRTNumber cannot be \ + sharded" + return ConcretePJRTNumber{T,length(carray.data),typeof(carray.sharding)}( + carray.data, carray.sharding + ) +end +function ConcretePJRTNumber(data::T; kwargs...) where {T<:Number} + return ConcretePJRTNumber{T}(data; kwargs...) +end + +## ConcretePJRTArray +mutable struct ConcretePJRTArray{T,N,D,S<:Sharding.ShardInfo} <: AbstractConcreteArray{T,N} + data::NTuple{D,XLA.PJRT.AsyncBuffer} + shape::NTuple{N,Int} + sharding::S +end + +@leaf ConcretePJRTArray +Adapt.parent_type(::Type{<:ConcretePJRTArray{T,N}}) where {T,N} = ConcretePJRTArray{T,N} +function Adapt.parent_type(::Type{ConcretePJRTArray{T,N,D,S}}) where {T,N,D,S} + return ConcretePJRTArray{T,N,D,S} +end + +Base.@deprecate ConcretePJRTArray(data::Number; kwargs...) ConcretePJRTNumber( + data; kwargs... +) + +function ConcretePJRTArray{T,N}( + data::Tuple{XLA.PJRT.AsyncBuffer}, shape::NTuple{N,Int} +) where {T,N} + return ConcretePJRTArray{T,N,1,Sharding.NoShardInfo}( + data, shape, Sharding.NoShardInfo() + ) +end + +function ConcretePJRTArray( + data::Array{T,N}; + client::XLA.AbstractClient=XLA.default_backend(), + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.AbstractDevice}=nothing, + sharding::Sharding.AbstractSharding=Sharding.NoSharding(), +) where {T,N} + if !Sharding.is_sharded(sharding) + if device === nothing + if idx === nothing + device = XLA.default_device(client) + else + device = XLA.get_device(client, idx) + end + else + if idx !== nothing + device_from_idx = XLA.get_device(client, idx) + @assert device_from_idx == device "If both `idx` and `device` are \ + specified, `idx` must match `device`" + end + end + sdata, sharding = sharding(client, device, data) + return ConcretePJRTArray{T,N,1,typeof(sharding)}(sdata, size(data), sharding) + end + @assert device === nothing && idx === nothing "If `sharding` is not `NoSharding`, `device` and `idx` cannot be specified!" + sharded_data, sharding = sharding(client, nothing, data) + return ConcretePJRTArray{T,N,length(sharded_data),typeof(sharding)}( + sharded_data, size(data), sharding + ) +end + +XLA.await(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(XLA.await, x.data) +XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) +function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) + x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data)) + return nothing # This is intentional to make constructing ConcretePJRTArrays easier +end + +const ConcretePJRTScalar{T} = Union{ConcretePJRTArray{T,0},ConcretePJRTNumber{T}} +const WrappedConcretePJRTArray{T,N,D,S} = WrappedArray{ + T,N,ConcretePJRTArray,ConcretePJRTArray{T,N,D,S} +} +const AnyConcretePJRTArray{T,N,D,S} = Union{ + ConcretePJRTArray{T,N,D,S},WrappedConcretePJRTArray{T,N,D,S} +} + +const AnyConcreteRArray = AnyConcretePJRTArray + +ConcretePJRTArray(x::AnyConcretePJRTArray) = ConcretePJRTArray{eltype(x),ndims(x)}(x) +ConcretePJRTArray{T}(x::AnyConcretePJRTArray) where {T} = ConcretePJRTArray{T,ndims(x)}(x) +ConcretePJRTArray{T,N}(x::ConcretePJRTArray{T,N}) where {T,N} = x +function ConcretePJRTArray{T,N}(x::AnyConcretePJRTArray) where {T,N} + ancestor_x = ancestor(x) + return ConcretePJRTArray( + convert(Array{T,N}, x); + client=XLA.client(ancestor_x), + device=XLA.device(ancestor_x), + sharding=ancestor_x.sharding, + ) +end + +## ConcreteRNG +mutable struct ConcreteRNG{S<:ConcretePJRTArray} <: Random.AbstractRNG + seed::S + const algorithm::String +end + +## Aliases to prevent breaking changes +const ConcreteRArray = ConcretePJRTArray +const ConcreteRNumber = ConcretePJRTNumber diff --git a/src/XLA.jl b/src/XLA.jl deleted file mode 100644 index 54b45cd00b..0000000000 --- a/src/XLA.jl +++ /dev/null @@ -1,506 +0,0 @@ -module XLA - -import ...MLIR - -mutable struct Client - client::Ptr{Cvoid} - - function Client(client::Ptr{Cvoid}) - return new(client) - #@assert client != C_NULL - #finalizer(new(client)) do client - # @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid - #end - end -end - -function to_row_major(x::Array{T,N}) where {T,N} - return permutedims(x, reverse(Base.OneTo(N))) -end - -function to_row_major(x::Vector{T}) where {T} - return x -end - -function to_row_major(x::Matrix{T}) where {T} - return Matrix{T}(transpose(x)) -end - -function from_row_major(x::Array{T,N}) where {T,N} - return permutedims(x, reverse(Base.OneTo(N))) -end - -function from_row_major(x::Vector{T}) where {T} - return x -end - -function from_row_major(x::Matrix{T}) where {T} - return transpose(x) -end - -SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid - -const cpuclientcount = Ref(0) -# TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing` -function CPUClient(asynchronous=false, node_id=0, num_nodes=1) - global cpuclientcount - @assert cpuclientcount[] == 0 - cpuclientcount[] += 1 - - f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient") - client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes) - #client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid} - return Client(client) -end - -function GPUClient(node_id=0, num_nodes=1, platform="gpu") - #allowed_devices = [-1] - # GC.@preserve allowed_devices begin - f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeGPUClient") - refstr = Ref{Cstring}() - client = ccall( - f, - Ptr{Cvoid}, - (Cint, Cint, Ptr{Cvoid}, Cint, Cstring, Ptr{Cstring}), - node_id, - num_nodes, - C_NULL, - 0, - platform, - refstr, - ) - if client == C_NULL - throw(AssertionError(unsafe_string(refstr[]))) - end - return Client(client) -end - -function TPUClient(tpu_path::String) - f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeTPUClient") - refstr = Ref{Cstring}() - client = ccall(f, Ptr{Cvoid}, (Cstring, Ptr{Cstring}), tpu_path, refstr) - if client == C_NULL - throw(AssertionError(unsafe_string(refstr[]))) - end - return Client(client) -end - -const backends = Dict{String,Client}() -const default_backend = Ref{Client}() -const default_device_idx = Ref{Int}(0) -using Reactant_jll -using Libdl -using Scratch, Downloads - -struct ReactantInternalError <: Base.Exception - msg::String -end - -function Base.showerror(io::IO, ece::ReactantInternalError) - return print(io, ece.msg, '\n') -end - -function reactant_err(msg::Cstring)::Cvoid - throw(ReactantInternalError(Base.unsafe_string(msg))) -end - -function __init__() - initLogs = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "InitializeLogs") - ccall(initLogs, Cvoid, ()) - # Add most log level - # SetLogLevel(0) - cpu = CPUClient() - backends["cpu"] = cpu - default_backend[] = cpu - - @static if !Sys.isapple() - if isfile("/usr/lib/libtpu.so") - dataset_dir = @get_scratch!("libtpu") - if !isfile(dataset_dir * "/libtpu.so") - Downloads.download( - "https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20240829-py3-none-any.whl", - dataset_dir * "/tpu.zip", - ) - run(`unzip -qq $(dataset_dir*"/tpu.zip") -d $(dataset_dir)/tmp`) - run(`mv $(dataset_dir)/tmp/libtpu/libtpu.so $(dataset_dir)/libtpu.so`) - rm(dataset_dir * "/tmp"; recursive=true) - rm(dataset_dir * "/tpu.zip"; recursive=true) - end - try - tpu = TPUClient(dataset_dir * "/libtpu.so") - backends["tpu"] = tpu - default_backend[] = tpu - catch e - println(stdout, e) - end - else - try - gpu = GPUClient() - backends["gpu"] = gpu - default_backend[] = gpu - catch e - println(stdout, e) - end - end - end - - @ccall MLIR.API.mlir_c.RegisterCustomCallTarget( - "enzymexla_gpu"::Cstring, - cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid}, - "CUDA"::Cstring, - )::Cvoid - - # This wasn't properly exported on macos, we'll remove the try once macOS JLL - # has the fix. - errptr = cglobal((:ReactantThrowError, MLIR.API.mlir_c), Ptr{Ptr{Cvoid}}) - unsafe_store!(errptr, @cfunction(reactant_err, Cvoid, (Cstring,))) - return nothing -end - -@inline function free_exec(exec) - @ccall MLIR.API.mlir_c.ExecutableFree(exec.exec::Ptr{Cvoid})::Cvoid -end - -mutable struct LoadedExecutable - exec::Ptr{Cvoid} - - function LoadedExecutable(exec::Ptr{Cvoid}) - @assert exec != C_NULL - return finalizer(free_exec, new(exec)) - end -end - -@inline function free_future(future) - @ccall MLIR.API.mlir_c.FreeFuture(future.future::Ptr{Cvoid})::Cvoid -end - -mutable struct Future - future::Ptr{Cvoid} - - function Future(future::Ptr{Cvoid}) - # @assert future != C_NULL - return finalizer(free_future, new(future)) - end -end - -@inline function free_buffer(buffer) - sbuffer = buffer.buffer - if sbuffer != C_NULL - @ccall MLIR.API.mlir_c.PjRtBufferFree(sbuffer::Ptr{Cvoid})::Cvoid - end -end - -mutable struct Buffer - buffer::Ptr{Cvoid} - function Buffer(buffer::Ptr{Cvoid}) - return finalizer(free_buffer, new(buffer)) - end -end - -struct Device - device::Ptr{Cvoid} -end - -mutable struct AsyncBuffer - buffer::Buffer - future::Union{Future,Nothing} -end - -function device(buffer::Buffer) - GC.@preserve buffer begin - return Device( - @ccall MLIR.API.mlir_c.BufferToDevice(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid} - ) - end -end -function client(buffer::Buffer) - GC.@preserve buffer begin - return Client( - @ccall MLIR.API.mlir_c.BufferToClient(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid} - ) - end -end -function device(buffer::AsyncBuffer) - return device(buffer.buffer) -end -function client(buffer::AsyncBuffer) - return client(buffer.buffer) -end -function client(device::Device) - GC.@preserve device begin - return Client( - @ccall MLIR.API.mlir_c.DeviceToClient(device.device::Ptr{Cvoid})::Ptr{Cvoid} - ) - end -end - -# https://github.com/openxla/xla/blob/4bfb5c82a427151d6fe5acad8ebe12cee403036a/xla/xla_data.proto#L29 -@inline primitive_type(::Type{Bool}) = 1 - -@inline primitive_type(::Type{Int8}) = 2 -@inline primitive_type(::Type{UInt8}) = 6 - -@inline primitive_type(::Type{Int16}) = 3 -@inline primitive_type(::Type{UInt16}) = 7 - -@inline primitive_type(::Type{Int32}) = 4 -@inline primitive_type(::Type{UInt32}) = 8 - -@inline primitive_type(::Type{Int64}) = 5 -@inline primitive_type(::Type{UInt64}) = 9 - -@inline primitive_type(::Type{Float16}) = 10 -@inline primitive_type(::Type{Float32}) = 11 - -@static if isdefined(Core, :BFloat16) - @inline primitive_type(::Type{Core.BFloat16}) = 16 -end - -@inline primitive_type(::Type{Float64}) = 12 - -@inline primitive_type(::Type{Complex{Float32}}) = 15 -@inline primitive_type(::Type{Complex{Float64}}) = 18 - -function ArrayFromHostBuffer(client::Client, array::Array{T,N}, device) where {T,N} - sizear = Int64[s for s in reverse(size(array))] - buffer = GC.@preserve array sizear begin - @ccall MLIR.API.mlir_c.ArrayFromHostBuffer( - client.client::Ptr{Cvoid}, - pointer(array)::Ptr{T}, - primitive_type(T)::UInt64, - N::Csize_t, - pointer(sizear)::Ptr{Int64}, - device.device::Ptr{Cvoid}, - )::Ptr{Cvoid} - end - return Buffer(buffer) -end - -function BufferToHost(buffer::Buffer, data) - GC.@preserve buffer begin - @ccall MLIR.API.mlir_c.BufferToHost( - buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid} - )::Cvoid - end -end - -# TODO users themselves need to gc preserve here -function UnsafeBufferPointer(buffer::Buffer) - @ccall MLIR.API.mlir_c.UnsafeBufferPointer(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid} -end - -function CopyBufferToDevice(buffer::Buffer, device::Device) - GC.@preserve buffer device begin - Buffer( - @ccall MLIR.API.mlir_c.CopyBufferToDevice( - buffer.buffer::Ptr{Cvoid}, device.device::Ptr{Cvoid} - )::Ptr{Cvoid} - ) - end -end - -function BufferOnCPU(buffer::Buffer) - GC.@preserve buffer begin - (@ccall MLIR.API.mlir_c.BufferOnCPU(buffer.buffer::Ptr{Cvoid})::UInt8) != 0 - end -end - -function execute_ir(N, n_outs, fn) - ptr = sizeof(Int) == sizeof(Int64) ? "i64" : "i32" - cint = sizeof(Cint) == sizeof(Int64) ? "i64" : "i32" - args = N > 0 ? ", [$N x $ptr] %inps, [$N x i8] %donated" : "" - stores = N > 0 ? """ -store [$N x $ptr] %inps, [$N x $ptr]* %inpa -store [$N x i8] %donated, [$N x i8]* %dona - """ : "" - - res = """define { [$n_outs x $ptr], [$n_outs x $ptr], i8 } @f($ptr %exec $args) alwaysinline { -entry: - %inpa = alloca [$N x $ptr] - %dona = alloca [$N x i8] - %outa = alloca [$n_outs x $ptr] - %futpa = alloca [$n_outs x $ptr] - $stores - %futa = alloca i8 - call void inttoptr ($ptr $fn to void ($ptr, $cint, [$N x $ptr]*, [$N x i8]*, $cint, [$n_outs x $ptr]*, i8*, [$n_outs x $ptr]*)*)($ptr %exec, $cint $N, [$N x $ptr]* nocapture readonly %inpa, [$N x i8]* nocapture readonly %dona, $cint $n_outs, [$n_outs x $ptr]* nocapture writeonly %outa, i8* nocapture writeonly %futa, [$n_outs x $ptr]* nocapture writeonly %futpa) - %out = load [$n_outs x $ptr], [$n_outs x $ptr]* %outa - %fut = load i8, i8* %futa - %futp = load [$n_outs x $ptr], [$n_outs x $ptr]* %futpa - %fca.0.insert = insertvalue { [$n_outs x $ptr], [$n_outs x $ptr], i8 } undef, [$n_outs x $ptr] %out, 0 - %fca.1.insert = insertvalue { [$n_outs x $ptr], [$n_outs x $ptr], i8 } %fca.0.insert, [$n_outs x $ptr] %futp, 1 - %fca.2.insert = insertvalue { [$n_outs x $ptr], [$n_outs x $ptr], i8 } %fca.1.insert, i8 %fut, 2 - ret { [$n_outs x $ptr], [$n_outs x $ptr], i8 } %fca.2.insert -} -""" - return res -end - -@generated function ExecutableCall( - exec::LoadedExecutable, - inputs::NTuple{N,Ptr{Cvoid}}, - donated_args::NTuple{N,UInt8}, - ::Val{n_outs}, -) where {N,n_outs} - sym0 = dlsym(Reactant_jll.libReactantExtra_handle, "XLAExecute") - xla_execute_fn = reinterpret(UInt, sym0) - ir = execute_ir(N, n_outs, xla_execute_fn) - results = [] - for i in 1:n_outs - push!( - results, - :(AsyncBuffer(Buffer(outputs[$i]), future ? Future(future_res[$i]) : nothing)), - ) - end - - args_type = N > 0 ? (Ptr{Cvoid}, NTuple{N,Ptr{Cvoid}}, NTuple{N,UInt8}) : (Ptr{Cvoid},) - args = N > 0 ? (:inputs, :donated_args) : () - return quote - Base.@_inline_meta - exec = exec.exec - GC.@preserve exec begin - outputs, future_res, future = Base.llvmcall( - ($ir, "f"), - Tuple{NTuple{n_outs,Ptr{Cvoid}},NTuple{n_outs,Ptr{Cvoid}},Bool}, - Tuple{$args_type...}, - exec, - $(args...), - ) - end - return ($(results...),) - end -end - -@inline function ExecutableCall0( - exec::LoadedExecutable, - inputs::NTuple{N,Ptr{Cvoid}}, - donated_args::NTuple{N,UInt8}, - ::Val{n_outs}, -) where {N,n_outs} - outputs = Ref{NTuple{n_outs,Ptr{Cvoid}}}() - future_res = Ref{NTuple{n_outs,Ptr{Cvoid}}}() - futures = Ref{UInt8}(0) - - inputs = Base.RefValue(inputs) - donated_args = Base.RefValue(donated_args) - GC.@preserve inputs donated_args outputs futures future_res begin - @ccall MLIR.API.mlir_c.XLAExecute( - exec.exec::Ptr{Cvoid}, - N::Cint, - inputs::Ptr{Cvoid}, - donated_args::Ptr{UInt8}, - n_outs::Cint, - Base.unsafe_convert(Ptr{Cvoid}, outputs)::Ptr{Cvoid}, - Base.unsafe_convert(Ptr{UInt8}, futures)::Ptr{UInt8}, - Base.unsafe_convert(Ptr{Cvoid}, future_res)::Ptr{Cvoid}, - )::Cvoid - end - - outputs = outputs[] - future_res = future_res[] - future = futures[] != 0 - - return ntuple(Val(n_outs)) do i - Base.@_inline_meta - return AsyncBuffer(Buffer(outputs[i]), future ? Future(future_res[i]) : nothing) - end -end - -function Compile(client::Client, mod::MLIR.IR.Module) - GC.@preserve client mod begin - executable = LoadedExecutable( - @ccall MLIR.API.mlir_c.ClientCompile( - client.client::Ptr{Cvoid}, mod.module_::MLIR.API.MlirModule - )::Ptr{Cvoid} - ) - end -end - -function ClientNumDevices(client::Client) - GC.@preserve client begin - return @ccall MLIR.API.mlir_c.ClientNumDevices(client.client::Ptr{Cvoid})::Cint - end -end - -function ClientNumAddressableDevices(client::Client) - GC.@preserve client begin - return @ccall MLIR.API.mlir_c.ClientNumAddressableDevices( - client.client::Ptr{Cvoid} - )::Cint - end -end - -function ClientProcessIndex(client::Client) - GC.@preserve client begin - return @ccall MLIR.API.mlir_c.ClientProcessIndex(client.client::Ptr{Cvoid})::Cint - end -end - -function ClientGetDevice(client::Client, idx) - GC.@preserve client begin - return Device( - @ccall MLIR.API.mlir_c.ClientGetDevice( - client.client::Ptr{Cvoid}, idx::Cint - )::Ptr{Cvoid} - ) - end -end - -function ClientGetAddressableDevice(client::Client, idx) - GC.@preserve client begin - return Device( - @ccall MLIR.API.mlir_c.ClientGetAddressableDevice( - client.client::Ptr{Cvoid}, idx::Cint - )::Ptr{Cvoid} - ) - end -end - -function is_ready(future::Future) - GC.@preserve future begin - return (@ccall MLIR.API.mlir_c.FutureIsReady(future.future::Ptr{Cvoid})::UInt8) != 0 - end -end - -@inline function await(future::Future)::Nothing - GC.@preserve future begin - @ccall MLIR.API.mlir_c.FutureAwait(future.future::Ptr{Cvoid})::Cvoid - end - return nothing -end - -function is_ready(buffer::AsyncBuffer)::Bool - future = buffer.future - if isnothing(future) - return true - else - return is_ready(future) - end -end - -const AsyncEmptyBuffer = AsyncBuffer(Buffer(C_NULL), nothing) - -@inline function await(buffer::AsyncBuffer)::Nothing - if buffer.future === nothing - return nothing - else - future = buffer.future - buffer.future = nothing - await(future::Future) - end - return nothing -end - -@inline function synced_buffer(buffer::AsyncBuffer) - if buffer.future !== nothing - future = buffer.future - buffer.future = nothing - await(future::Future) - end - return buffer.buffer -end - -@inline function synced_buffer(buffer::Buffer) - return buffer -end - -end diff --git a/src/mlir/Dialects/Affine.jl b/src/mlir/Dialects/Affine.jl index c3774d3853..9ce90aa908 100755 --- a/src/mlir/Dialects/Affine.jl +++ b/src/mlir/Dialects/Affine.jl @@ -83,6 +83,9 @@ In the above example, `%indices:3` conceptually holds the following: %indices_2 = affine.apply #map2()[%linear_index] ``` +In other words, `%0:3 = affine.delinearize_index %x into (B, C)` produces +`%0 = {%x / (B * C), (%x mod (B * C)) / C, %x mod C}`. + The basis may either contain `N` or `N-1` elements, where `N` is the number of results. If there are N basis elements, the first one will not be used during computations, but may be used during analysis and canonicalization to eliminate terms from @@ -98,7 +101,12 @@ That is, the example above could also have been written %0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index ``` -Note that, due to the constraints of affine maps, all the basis elements must +Note that, for symmetry with `getPaddedBasis()`, if `hasOuterBound` is `true` +when one of the `OpFoldResult` builders is called but the first element of the +basis is `nullptr`, that first element is ignored and the builder proceeds as if +there was no outer bound. + +Due to the constraints of affine maps, all the basis elements must be strictly positive. A dynamic basis element being 0 or negative causes undefined behavior. """ @@ -382,6 +390,9 @@ That is, for indices `%idx_0` to `%idx_{N-1}` and basis elements `b_0` sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j ``` +In other words, `%0 = affine.linearize_index [%z, %y, %x] by (Z, Y, X)` +gives `%0 = %x + %y * X + %z * X * Y`, or `%0 = %x + X * (%y + Y * (%z))`. + The basis may either have `N` or `N-1` elements, where `N` is the number of inputs to linearize_index. If `N` inputs are provided, the first one is not used in computation, but may be used during analysis or canonicalization as a bound @@ -390,6 +401,10 @@ on `%idx_0`. If all `N` basis elements are provided, the linearize_index operation is said to \"have an outer bound\". +As a convenience, and for symmetry with `getPaddedBasis()`, ifg the first +element of a set of `OpFoldResult`s passed to the builders of this operation is +`nullptr`, that element is ignored. + If the `disjoint` property is present, this is an optimization hint that, for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index, except that `%idx_0` may be negative to make the index as a whole negative. diff --git a/src/mlir/Dialects/CHLO.jl b/src/mlir/Dialects/CHLO.jl index 7696a65567..119b2457da 100755 --- a/src/mlir/Dialects/CHLO.jl +++ b/src/mlir/Dialects/CHLO.jl @@ -1408,12 +1408,12 @@ the lhs is required to have one ragged dimension, and the rhs may have at most one group dimension. The op has three modes, depending on the kind of the lhs ragged dimension. -In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [g] -> [b,m,n]`. +In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [b,g] -> [b,m,n]`. Here the ragged dimension is an lhs non-contracting dimension (`m`). The dimensions `b` and `k` represent batch and contracting dimensions respectively. The rhs is required to have a group dimension (`g`). -In mode 2, the shape-signature is `[b,m,k], [b,k,n], [g] -> [g,b,m,n]`. +In mode 2, the shape-signature is `[b,m,k], [b,k,n], [b,g] -> [g,b,m,n]`. Here the ragged dimension is an lhs/rhs contracting dimension (`k`). In mode 3, the shape-signature is `[b,m,k], [b,k,n], [g] -> [b,m,n]`. Here diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index 9ebd8211b9..f922304da3 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -92,6 +92,33 @@ function batch( ) end +""" +`broadcast` + +Broadcast the operand by adding extra dimensions with sizes provided by the `shape` attribute to the front. +For scalar operands, ranked tensor is created. + +NOTE: Only works for scalar and *ranked* tensor operands for now. +""" +function broadcast(input::Value; output::IR.Type, shape, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("shape", shape),] + + return create_operation( + "enzyme.broadcast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function fwddiff( inputs::Vector{Value}; outputs::Vector{IR.Type}, diff --git a/src/mlir/Dialects/EnzymeXLA.jl b/src/mlir/Dialects/EnzymeXLA.jl index 0ee73ded4d..a7b1261484 100644 --- a/src/mlir/Dialects/EnzymeXLA.jl +++ b/src/mlir/Dialects/EnzymeXLA.jl @@ -13,6 +13,82 @@ import ...IR: import ..Dialects: namedattribute, operandsegmentsizes import ...API +function scope( + operands::Vector{Value}; results::Vector{IR.Type}, region::Region, location=Location() +) + op_ty_results = IR.Type[results...,] + operands = Value[operands...,] + owned_regions = Region[region,] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzymexla.scope", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function get_stream(; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzymexla.get_stream", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function jit_call( + inputs::Vector{Value}; + result_0::Vector{IR.Type}, + fn, + backend_config=nothing, + operand_layouts=nothing, + result_layouts=nothing, + output_operand_aliases=nothing, + location=Location(), +) + op_ty_results = IR.Type[result_0...,] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(backend_config) && + push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && + push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && + push!(attributes, namedattribute("result_layouts", result_layouts)) + !isnothing(output_operand_aliases) && + push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + + return create_operation( + "enzymexla.jit_call", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function kernel_call( gridx::Value, gridy::Value, @@ -56,4 +132,42 @@ function kernel_call( ) end +function memref2pointer(source::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzymexla.memref2pointer", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function pointer2memref(source::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzymexla.pointer2memref", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + end # enzymexla diff --git a/src/mlir/Dialects/Func.jl b/src/mlir/Dialects/Func.jl index 6eb30523c8..dcadfd219c 100755 --- a/src/mlir/Dialects/Func.jl +++ b/src/mlir/Dialects/Func.jl @@ -34,6 +34,8 @@ function call_indirect( callee::Value, callee_operands::Vector{Value}; results::Vector{IR.Type}, + arg_attrs=nothing, + res_attrs=nothing, location=Location(), ) op_ty_results = IR.Type[results...,] @@ -41,6 +43,8 @@ function call_indirect( owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) return create_operation( "func.call_indirect", @@ -72,6 +76,8 @@ function call( operands::Vector{Value}; result_0::Vector{IR.Type}, callee, + arg_attrs=nothing, + res_attrs=nothing, no_inline=nothing, location=Location(), ) @@ -80,6 +86,8 @@ function call( owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("callee", callee),] + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) !isnothing(no_inline) && push!(attributes, namedattribute("no_inline", no_inline)) return create_operation( @@ -219,7 +227,7 @@ that contains the operation. # Example ```mlir -func.func @foo() : (i32, f8) { +func.func @foo() -> (i32, f8) { ... return %0, %1 : i32, f8 } diff --git a/src/mlir/Dialects/Gpu.jl b/src/mlir/Dialects/Gpu.jl new file mode 100755 index 0000000000..6a7a8615c4 --- /dev/null +++ b/src/mlir/Dialects/Gpu.jl @@ -0,0 +1,3341 @@ +module gpu +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +""" +`all_reduce` + +The `all_reduce` op reduces the value of every work item across a local +workgroup. The result is equal for all work items of a workgroup. + +For example, both + +```mlir +%1 = gpu.all_reduce add %0 {} : (f32) -> (f32) +%2 = gpu.all_reduce %0 { +^bb(%lhs : f32, %rhs : f32): + %sum = arith.addf %lhs, %rhs : f32 + \"gpu.yield\"(%sum) : (f32) -> () +} : (f32) -> (f32) +``` + +compute the sum of each work item\'s %0 value. The first version specifies +the accumulation as operation, whereas the second version specifies the +accumulation as code region. The reduction operation must be one of: +* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`, + `or`, `xor` +* Floating point types: `add`, `mul`, `minnumf`, `maxnumf`, `minimumf`, + `maximumf` + +If `uniform` flag is set either none or all work items of a workgroup +need to execute this op in convergence. +""" +function all_reduce( + value::Value; + result=nothing::Union{Nothing,IR.Type}, + op=nothing, + uniform=nothing, + body::Region, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[value,] + owned_regions = Region[body,] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(op) && push!(attributes, namedattribute("op", op)) + !isnothing(uniform) && push!(attributes, namedattribute("uniform", uniform)) + + return create_operation( + "gpu.all_reduce", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`alloc` + +The `gpu.alloc` operation allocates a region of memory on the GPU. It is +similar to the `memref.alloc` op, but supports asynchronous GPU execution. + +The op does not execute before all async dependencies have finished +executing. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it also returns a !gpu.async.token. + +If the `host_shared` keyword is present, the memory will be allocated in a +memory accessible both on host and on device. + +# Example + +```mlir +%memref, %token = gpu.alloc async [%dep] host_shared (%width) : memref<64x?xf32, 1> +``` +""" +function alloc( + asyncDependencies::Vector{Value}, + dynamicSizes::Vector{Value}, + symbolOperands::Vector{Value}; + memref::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + hostShared=nothing, + location=Location(), +) + op_ty_results = IR.Type[memref,] + operands = Value[asyncDependencies..., dynamicSizes..., symbolOperands...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + push!( + attributes, + operandsegmentsizes([ + length(asyncDependencies), length(dynamicSizes), length(symbolOperands) + ]), + ) + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(hostShared) && push!(attributes, namedattribute("hostShared", hostShared)) + + return create_operation( + "gpu.alloc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`barrier` + +The \"barrier\" op synchronizes all work items of a workgroup. It is used +to coordinate communication between the work items of the workgroup. + +```mlir +gpu.barrier +``` + +waits until all work items in the workgroup have reached this point +and all memory accesses made by these work items prior to the op are +visible to all work items in the workgroup. Data hazards between work items +accessing the same memory can be avoided by synchronizing work items +in-between these accesses. + +Either none or all work items of a workgroup need to execute this op +in convergence. +""" +function barrier(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "gpu.barrier", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`binary` + +GPU binaries provide a semantic mechanism for storing GPU objects, +e.g. the result of compiling a GPU module to an object file. + +This operation has 3 arguments: + - The name of the binary. + - An optional attribute implementing the offloading LLVM translation interface. + - An array of GPU object attributes. + +During translation, the offloading attribute will be called for translating +GPU `binary` and `launch_func` operations. The default offloading handler is: +`#gpu.select_object`, this handler selects the first object from the array +and embeds it as a string. + +Examples: +``` + // Selects the first object. + gpu.binary @myobject [#gpu.object<...>, #gpu.object<...>] + // Uses the `#foo.my_handler` for handling the binary during translation. + gpu.binary @myobject <#foo.my_handler> [#gpu.object<...>, #gpu.object<...>] + // Selects the object with the `#rocdl.target` target attribute. + gpu.binary @myobject <#gpu.select_object<#rocdl.target>> [#gpu.object<...>, #gpu.object<#rocdl.target, ...>] +``` +""" +function binary(; sym_name, offloadingHandler=nothing, objects, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("sym_name", sym_name), namedattribute("objects", objects) + ] + !isnothing(offloadingHandler) && + push!(attributes, namedattribute("offloadingHandler", offloadingHandler)) + + return create_operation( + "gpu.binary", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`block_dim` + +Returns the number of threads in the thread block (aka the block size) along +the x, y, or z `dimension`. + +# Example + +```mlir +%bDimX = gpu.block_dim x +``` + +If `known_block_size` is set on an this operation\'s enclosing `gpu.func`, +or `gpu.known_block_size` is set on an enclosing `FunctionOpInterface` +implementor, or if the enclosing `gpu.launch` specifies a constant size for +`dimension`\'s blocks, these contextual facts may be used to infer that this +operation has a constant value, though such a transformation will not be +performed by canonicalization or the default constant folder. Executions which +cause that constant-value assumption to be false incur undefined behavior. + +If `upper_bound` is set, executions where the bblock size along `dimension` +exceeds `upper_bound` cause undefined behavior. + +There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). +""" +function block_dim(; + result_0=nothing::Union{Nothing,IR.Type}, + dimension, + upper_bound=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.block_dim", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`block_id` + +Returns the block id, i.e. the index of the current block within the grid +along the x, y, or z `dimension`. + +# Example + +```mlir +%bIdY = gpu.block_id y +``` + +If `upper_bound` is set, or if one can be inferred from `known_grid_size`-type +annotations in context, executions where the block index in `dimension` would +be greater than or equal to that bound cause undefined behavior. `upper_bound` +takes priority over bounds inferrable from context. + +There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). +""" +function block_id(; + result_0=nothing::Union{Nothing,IR.Type}, + dimension, + upper_bound=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.block_id", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`cluster_block_id` + +Returns the block id within the cluster along the x, y, or z `dimension`. + +# Example + +```mlir +%cBlockIdY = gpu.cluster_block_id y +``` + +If `upper_bound` is set, then executing (a lowering of) this operation in an +environment where the number of thread blocks per cluster along `dimension` +is greater than `upper_bound` causes undefined behavior. + +There is an implicit upper bound of `kMaxClusterDim` (currently 8). +""" +function cluster_block_id(; + result_0=nothing::Union{Nothing,IR.Type}, + dimension, + upper_bound=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.cluster_block_id", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`cluster_dim_blocks` + +Returns the number of thread blocks in the cluster along +the x, y, or z `dimension`. + +# Example + +```mlir +%cDimBlocksX = gpu.cluster_dim_blocks x +``` + +If `upper_bound` is set, then executing (a lowering of) this operation in an +environment where the thread blocks per cluster is greater than `upper_bound` +causes undefined behavior. + +There is an implicit upper bound of `kMaxClusterDim` (currently 8). +""" +function cluster_dim_blocks(; + result_0=nothing::Union{Nothing,IR.Type}, + dimension, + upper_bound=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.cluster_dim_blocks", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`cluster_dim` + +Returns the number of cluster identifiers per grid along +the x, y, or z `dimension`. + +# Example + +```mlir +%cDimX = gpu.cluster_dim x +``` + +If `upper_bound` is set, then executing (a lowering of) this operation in an +environment where the clusters per grid is greater than `upper_bound` causes +undefined behavior. + +There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). +""" +function cluster_dim(; + result_0=nothing::Union{Nothing,IR.Type}, + dimension, + upper_bound=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.cluster_dim", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`cluster_id` + +Returns the cluster id, i.e. the index of the current cluster within the +grid along the x, y, or z `dimension`. + +# Example + +```mlir +%cIdY = gpu.cluster_id y +``` + +If `upper_bound` is set, then executing (a lowering of) this operation in an +environment where the number of clusters in the grid along `dimension` is +greater than `upper_bound` causes undefined behavior. + +There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). +""" +function cluster_id(; + result_0=nothing::Union{Nothing,IR.Type}, + dimension, + upper_bound=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.cluster_id", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`create_2to4_spmat` + +The `gpu.create_2to4_spmat` operation initializes a sparse matrix in dense +format with 2:4 sparsity. +The buffers must already be copied from the host to the device prior to +using this operation. The operation returns a handle to the sparse +matrix descriptor. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%spmat, %token = gpu.create_2to4_spmat async [%dep] {PRUNE_AND_CHECK} %rows, %cols, %mem: memref +``` +""" +function create_2to4_spmat( + asyncDependencies::Vector{Value}, + rows::Value, + cols::Value, + memref::Value; + spMat::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + pruneFlag=nothing, + location=Location(), +) + op_ty_results = IR.Type[spMat,] + operands = Value[asyncDependencies..., rows, cols, memref] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(pruneFlag) && push!(attributes, namedattribute("pruneFlag", pruneFlag)) + + return create_operation( + "gpu.create_2to4_spmat", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`create_bsr` + +The `gpu.create_bsr` operation initializes a sparse matrix in BSR format +with the given sizes for the matrix and blocks from the given position, +index, and values buffers. The buffers must already be copied from the +host to the device prior to using this operation. The operation returns +a handle to the sparse matrix descriptor. + +The BSR format is similar to CSR, where the column indices represent +two-dimensional blocks instead of a single matrix entry. Note that this +operation (currently) only supports storage with **square** blocks, +i.e., `rBlockSize == cBlockSize`. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%spmat, %token = gpu.create_bsr async [%dep] + %brows, %bcols, %bnnz, %rBlockSize, %cBlockSize, + %bRowPos, %bColIdxs, %values : memref, memref, memref +``` +""" +function create_bsr( + asyncDependencies::Vector{Value}, + brows::Value, + bcols::Value, + bnnz::Value, + rBlockSize::Value, + cBlockSize::Value, + bRowPos::Value, + bColIdxs::Value, + values::Value; + spmat::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[spmat,] + operands = Value[ + asyncDependencies..., + brows, + bcols, + bnnz, + rBlockSize, + cBlockSize, + bRowPos, + bColIdxs, + values, + ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.create_bsr", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`create_coo_aos` + +The `gpu.create_coo_aos` operation initializes a sparse matrix in COO format +with the given sizes from the given index and values buffers. The buffers +must already be copied from the host to the device prior to using this +operation. The operation returns a handle to the sparse matrix descriptor. +Unlike the default `gpu.create_coo` operation, this operation builds the +COO format from a single index buffer in AoS format (note that this +feature has been deprecated in cuSparse 11.2). + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%spmat, %token = gpu.create_coo_aos async [%dep] %rows, %cols, %nnz, %idxs, + %values : memref, memref +``` +""" +function create_coo_aos( + asyncDependencies::Vector{Value}, + rows::Value, + cols::Value, + nnz::Value, + idxs::Value, + values::Value; + spmat::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[spmat,] + operands = Value[asyncDependencies..., rows, cols, nnz, idxs, values] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.create_coo_aos", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`create_coo` + +The `gpu.create_coo` operation initializes a sparse matrix in COO format +with the given sizes from the given index and values buffers. The buffers +must already be copied from the host to the device prior to using this +operation. The operation returns a handle to the sparse matrix descriptor. +Note that this operation builds the COO in SoA format. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%spmat, %token = gpu.create_coo async [%dep] %rows, %cols, %nnz, %rowIdx, + %colIdx, %values : memref, memref, memref +``` +""" +function create_coo( + asyncDependencies::Vector{Value}, + rows::Value, + cols::Value, + nnz::Value, + rowIdxs::Value, + colIdxs::Value, + values::Value; + spmat::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[spmat,] + operands = Value[asyncDependencies..., rows, cols, nnz, rowIdxs, colIdxs, values] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.create_coo", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`create_csc` + +The `gpu.create_csc` operation initializes a sparse matrix in CSC format +with the given sizes from the given position, index, and values buffers. +The buffers must already be copied from the host to the device prior to +using this operation. The operation returns a handle to the sparse +matrix descriptor. + +The CSC format has exactly the same memory layout as its transpose +in CSR format (and vice versa). + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%spmat, %token = gpu.create_csc async [%dep] %rows, %cols, %nnz, %colPos, + %rowIdx, %values : memref, memref, memref +``` +""" +function create_csc( + asyncDependencies::Vector{Value}, + rows::Value, + cols::Value, + nnz::Value, + colPos::Value, + rowIdxs::Value, + values::Value; + spmat::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[spmat,] + operands = Value[asyncDependencies..., rows, cols, nnz, colPos, rowIdxs, values] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.create_csc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`create_csr` + +The `gpu.create_csr` operation initializes a sparse matrix in CSR format +with the given sizes from the given position, index, and values buffers. +The buffers must already be copied from the host to the device prior to +using this operation. The operation returns a handle to the sparse +matrix descriptor. + +The CSR format has exactly the same memory layout as its transpose +in CSC format (and vice versa). + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%spmat, %token = gpu.create_csr async [%dep] %rows, %cols, %nnz, %rowPos, + %colIdx, %values : memref, memref, memref +``` +""" +function create_csr( + asyncDependencies::Vector{Value}, + rows::Value, + cols::Value, + nnz::Value, + rowPos::Value, + colIdxs::Value, + values::Value; + spmat::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[spmat,] + operands = Value[asyncDependencies..., rows, cols, nnz, rowPos, colIdxs, values] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.create_csr", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`create_dn_tensor` + +The `gpu.create_dn_tensor` operation initializes a dense tensor from +the given values buffer and sizes. The buffer must already be copied +from the host to the device prior to using this operation. The +operation returns a handle to the dense tensor descriptor. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%dmat, %token = gpu.create_dn_tensor async [%dep] %mem, %dims : index, index into memref +``` +""" +function create_dn_tensor( + asyncDependencies::Vector{Value}, + memref::Value, + dims::Vector{Value}; + dnTensor::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[dnTensor,] + operands = Value[asyncDependencies..., memref, dims...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + push!(attributes, operandsegmentsizes([length(asyncDependencies), 1, length(dims)])) + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.create_dn_tensor", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`dealloc` + +The `gpu.dealloc` operation frees the region of memory referenced by a +memref which was originally created by the `gpu.alloc` operation. It is +similar to the `memref.dealloc` op, but supports asynchronous GPU execution. + +The op does not execute before all async dependencies have finished +executing. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token. + +# Example + +```mlir +%token = gpu.dealloc async [%dep] %memref : memref<8x64xf32, 1> +``` +""" +function dealloc( + asyncDependencies::Vector{Value}, + memref::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., memref] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.dealloc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`destroy_dn_tensor` + +The `gpu.destroy_dn_tensor` operation releases all resources of a dense +tensor represented by a handle that was previously created by a +`gpu.create_dn_tensor` operation. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%token = gpu.destroy_dn_tensor async [%dep] %dnTensor +``` +""" +function destroy_dn_tensor( + asyncDependencies::Vector{Value}, + dnTensor::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., dnTensor] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.destroy_dn_tensor", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`destroy_sp_mat` + +The `gpu.destroy_sp_mat` operation releases all resources of a sparse +matrix represented by a handle that was previously created by a +one of the sparse matrix creation operations. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%token = gpu.destroy_sp_mat async [%dep] %spmat +``` +""" +function destroy_sp_mat( + asyncDependencies::Vector{Value}, + spmat::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., spmat] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.destroy_sp_mat", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`dynamic_shared_memory` + +This operation provides a memref pointer to the start of dynamic shared +memory, often referred to as workgroup memory. It\'s important to note that +this dynamic shared memory needs to be allocated at kernel launch. One can +conveniently utilize `the dynamic_shared_memory_size` parameter of +`gpu.launch` for this purpose. + +Examples: +```mlir +%0 = gpu.dynamic.shared.memory : memref> +%1 = memref.view %0[%c8192][] : memref> + to memref<32x64xf32, #gpu.address_space> +%2 = memref.view %0[%c16384][] : memref> + to memref<32x64xf32, #gpu.address_space> +``` +""" +function dynamic_shared_memory(; resultMemref::IR.Type, location=Location()) + op_ty_results = IR.Type[resultMemref,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "gpu.dynamic_shared_memory", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`func` + +Defines a function that can be executed on a GPU. This supports memory +attribution and its body has a particular execution model. + +GPU functions are either kernels (as indicated by the `kernel` attribute) or +regular functions. The former can be launched from the host side, while the +latter are device side only. + +The memory attribution defines SSA values that correspond to memory buffers +allocated in the memory hierarchy of the GPU (see below). + +The operation has one attached region that corresponds to the body of the +function. The region arguments consist of the function arguments without +modification, followed by buffers defined in memory annotations. The body of +a GPU function, when launched, is executed by multiple work items. There are +no guarantees on the order in which work items execute, or on the connection +between them. In particular, work items are not necessarily executed in +lock-step. Synchronization ops such as \"gpu.barrier\" should be used to +coordinate work items. Declarations of GPU functions, i.e. not having the +body region, are not supported. + +A function may optionally be annotated with the block and/or grid sizes +that will be used when it is launched using the `known_block_size` and +`known_grid_size` attributes, respectively. If set, these attributes must +be arrays of three 32-bit integers giving the x, y, and z launch dimensions. +Launching a kernel that has these annotations, or that calls a function with +these annotations, using a block size or grid size other than what is specified +is undefined behavior. These attributes may be set on non-`gpu.func` functions +by using `gpu.known_block_size` or `gpu.known_grid_size`, but this carries +the risk that they will de discarded. + +# Syntax + +``` +op ::= `gpu.func` symbol-ref-id `(` argument-list `)` (`->` +function-result-list)? + memory-attribution `kernel`? function-attributes? region + +memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)? + (`private` `(` ssa-id-and-type-list `)`)? +``` + +# Example + +```mlir +gpu.func @foo(%arg0: index) + workgroup(%workgroup: memref<32xf32, 3>) + private(%private: memref<1xf32, 5>) + kernel + attributes {qux: \"quux\"} { + gpu.return +} +``` + +The generic form illustrates the concept + +```mlir +\"gpu.func\"(%arg: index) {sym_name: \"foo\", kernel, qux: \"quux\"} ({ +^bb0(%arg0: index, %workgroup: memref<32xf32, 3>, + %private: memref<1xf32, 5>): + \"gpu.return\"() : () -> () +}) : (index) -> () +``` + +Note the non-default memory spaces used in memref types in memory +attribution. +""" +function func(; + function_type, + arg_attrs=nothing, + res_attrs=nothing, + workgroup_attrib_attrs=nothing, + private_attrib_attrs=nothing, + known_block_size=nothing, + known_grid_size=nothing, + body::Region, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[body,] + successors = Block[] + attributes = NamedAttribute[namedattribute("function_type", function_type),] + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + !isnothing(workgroup_attrib_attrs) && + push!(attributes, namedattribute("workgroup_attrib_attrs", workgroup_attrib_attrs)) + !isnothing(private_attrib_attrs) && + push!(attributes, namedattribute("private_attrib_attrs", private_attrib_attrs)) + !isnothing(known_block_size) && + push!(attributes, namedattribute("known_block_size", known_block_size)) + !isnothing(known_grid_size) && + push!(attributes, namedattribute("known_grid_size", known_grid_size)) + + return create_operation( + "gpu.func", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`module_` + +GPU module contains code that is intended to be run on a GPU. A host device +can launch this code through a gpu.launc_func that creates a fully +qualified symbol through the gpu.module\'s symbol and a gpu.func symbol +contained in the gpu.module. + +The module\'s top-level scope is modeled by a single region with a single +block. GPU modules are required to have a name that is used for symbol +resolution by the gpu.launch_func operation. + +Using an op with a region to define a GPU module enables \"embedding\" GPU +modules with SIMT execution models in other dialects in a clean manner and +allows filtering of code regions to execute passes on only code intended to +or not intended to be run on the separate device. + +Modules can contain zero or more target attributes. These attributes encode +how to transform modules into binary strings and are used by the +`gpu-module-to-binary` pass to transform modules into GPU binaries. + +Modules can contain an optional `OffloadingTranslationAttr` attribute. This +attribute will be used during the `gpu-module-to-binary` pass to specify the +`OffloadingTranslationAttr` used when creating the `gpu.binary` operation. + +``` +gpu.module @symbol_name { + gpu.func {} + ... +} +// Module with offloading handler and target attributes. +gpu.module @symbol_name2 <#gpu.select_object<1>> [ + #nvvm.target, + #rocdl.target] { + gpu.func {} + ... +} +``` +""" +function module_(; + sym_name, + targets=nothing, + offloadingHandler=nothing, + bodyRegion::Region, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[bodyRegion,] + successors = Block[] + attributes = NamedAttribute[namedattribute("sym_name", sym_name),] + !isnothing(targets) && push!(attributes, namedattribute("targets", targets)) + !isnothing(offloadingHandler) && + push!(attributes, namedattribute("offloadingHandler", offloadingHandler)) + + return create_operation( + "gpu.module", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`global_id` + +Returns the unique global workitem/thread id, i.e., the unique index of the +current workitem/thread within all workgroups / grid along the x, y, or z +`dimension`. + +# Example + +```mlir +%gidX = gpu.global_id x +%gidX = gpu.global_id x upper_bound 65536 +``` + +The `upper_bound` attribute defines an upper bound analogously to the ones on +`thread_id` and `block_id`. If one is not set, the bound may be inferred from +a combination of `known_block_size` and `known_grid_size`-type annotations. +""" +function global_id(; + result_0=nothing::Union{Nothing,IR.Type}, + dimension, + upper_bound=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.global_id", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`grid_dim` + +Returns the number of thread blocks in the grid along the x, y, or z +`dimension`. + +# Example + +```mlir +%gDimZ = gpu.grid_dim z +``` + + +If `known_grid_size` is set on an this operation\'s enclosing `gpu.func`, +or `gpu.known_grid_size` is set on an enclosing `FunctionOpInterface` +implementor, or if the enclosing `gpu.launch` specifies a constant size for +`dimension`\'s grid length, these contextual facts may be used to infer that this +operation has a constant value, though such a transformation will not be +performed by canonicalization or the default constant folder. Executions which +cause that constant-value assumption to be false incur undefined behavior. + +If `upper_bound` is set, executions where the grid size in `dimension` would +exceed `upper_bound` cause undefined behavior. + +There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). +""" +function grid_dim(; + result_0=nothing::Union{Nothing,IR.Type}, + dimension, + upper_bound=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.grid_dim", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`host_register` + +This op maps the provided host buffer into the device address space. + +This operation may not be supported in every environment, there is not yet a +way to check at runtime whether this feature is supported. + +Writes from the host are guaranteed to be visible to device kernels that are +launched afterwards. Writes from the device are guaranteed to be visible on +the host after synchronizing with the device kernel completion. +""" +function host_register(value::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[value,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "gpu.host_register", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`host_unregister` + +This op unmaps the provided host buffer from the device address space. + +This operation may not be supported in every environment, there is not yet a + way to check at runtime whether this feature is supported. +""" +function host_unregister(value::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[value,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "gpu.host_unregister", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`lane_id` + +Returns the lane id within the subgroup (warp/wave). + +# Example +```mlir +%laneId = gpu.lane_id +``` + +If `upper_bound` is set, executions with more than `upper_bound` lanes per +subgroup cause undefined behavior. In the abscence of `upper_bound`, +the lane id is still assumed to be non-negative and less than the +target-independent `kMaxSubgroupSize` (currently 128). +""" +function lane_id(; + result=nothing::Union{Nothing,IR.Type}, upper_bound=nothing, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.lane_id", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`launch_func` + +Launch a kernel function on the specified grid of thread blocks. +`gpu.launch` operations are lowered to `gpu.launch_func` operations by +outlining the kernel body into a function in a dedicated module, which +reflects the separate compilation process. The kernel function is required +to have the `gpu.kernel` attribute. The module containing the kernel +function is required to be a gpu.module. And finally, the module containing +the kernel module (which thus cannot be the top-level module) is required +to have the `gpu.container_module` attribute. The `gpu.launch_func` +operation has a symbol attribute named `kernel` to identify the fully +specified kernel function to launch (both the gpu.module and func). + +The `gpu.launch_func` supports async dependencies: the kernel does not start +executing until the ops producing those async dependencies have completed. + +By the default, the host implicitly blocks until kernel execution has +completed. If the `async` keyword is present, the host does not block but +instead a `!gpu.async.token` is returned. Other async GPU ops can take this +token as dependency. + +The operation requires at least the grid and block sizes along the x,y,z +dimensions as arguments. When a lower-dimensional kernel is required, +unused sizes must be explicitly set to `1`. + +The remaining operands are optional. The first optional operand corresponds +to the amount of dynamic shared memory a kernel\'s workgroup should be +allocated; when this operand is not present, a zero size is assumed. + +The remaining operands if present are passed as arguments to the kernel +function. + +The `gpu.launch_func` also supports kernel launching with clusters if +supported by the target architecture. The cluster size can be set by +`clusterSizeX`, `clusterSizeY`, and `clusterSizeZ` arguments. When these +arguments are present, the Op launches a kernel that clusters the given +thread blocks. This feature is exclusive to certain architectures. + +# Example + +```mlir +module attributes {gpu.container_module} { + + // This module creates a separate compilation unit for the GPU compiler. + gpu.module @kernels { + func.func @kernel_1(%arg0 : f32, %arg1 : memref) + attributes { nvvm.kernel = true } { + + // Operations that produce block/thread IDs and dimensions are + // injected when outlining the `gpu.launch` body to a function called + // by `gpu.launch_func`. + %tIdX = gpu.thread_id x + %tIdY = gpu.thread_id y + %tIdZ = gpu.thread_id z + + %bDimX = gpu.block_dim x + %bDimY = gpu.block_dim y + %bDimZ = gpu.block_dim z + + %bIdX = gpu.block_id x + %bIdY = gpu.block_id y + %bIdZ = gpu.block_id z + + %gDimX = gpu.grid_dim x + %gDimY = gpu.grid_dim y + %gDimZ = gpu.grid_dim z + + // (Optional) Cluster size only for support architectures + %cIdX = gpu.cluster_id x + %cIdY = gpu.cluster_id y + %cIdZ = gpu.cluster_id z + + %cDimX = gpu.cluster_dim x + %cDimY = gpu.cluster_dim y + %cDimZ = gpu.cluster_dim z + + \"some_op\"(%bx, %tx) : (index, index) -> () + %42 = load %arg1[%bx] : memref + } + } + + %t0 = gpu.wait async + gpu.launch_func + async // (Optional) Don\'t block host, return token. + [%t0] // (Optional) Execute only after %t0 has completed. + @kernels::@kernel_1 // Kernel function. + clusters in (%cst, %cst, %cst) // (Optional) Cluster size only for support architectures. + blocks in (%cst, %cst, %cst) // Grid size. + threads in (%cst, %cst, %cst) // Block size. + dynamic_shared_memory_size %s // (Optional) Amount of dynamic shared + // memory to allocate for a workgroup. + args(%arg0 : f32, // (Optional) Kernel arguments. + %arg1 : memref) +} +``` +""" +function launch_func( + asyncDependencies::Vector{Value}, + gridSizeX::Value, + gridSizeY::Value, + gridSizeZ::Value, + blockSizeX::Value, + blockSizeY::Value, + blockSizeZ::Value, + clusterSizeX=nothing::Union{Nothing,Value}; + clusterSizeY=nothing::Union{Nothing,Value}, + clusterSizeZ=nothing::Union{Nothing,Value}, + dynamicSharedMemorySize=nothing::Union{Nothing,Value}, + kernelOperands::Vector{Value}, + asyncObject=nothing::Union{Nothing,Value}, + asyncToken=nothing::Union{Nothing,IR.Type}, + kernel, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ + asyncDependencies..., + gridSizeX, + gridSizeY, + gridSizeZ, + blockSizeX, + blockSizeY, + blockSizeZ, + kernelOperands..., + ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("kernel", kernel),] + !isnothing(clusterSizeX) && push!(operands, clusterSizeX) + !isnothing(clusterSizeY) && push!(operands, clusterSizeY) + !isnothing(clusterSizeZ) && push!(operands, clusterSizeZ) + !isnothing(dynamicSharedMemorySize) && push!(operands, dynamicSharedMemorySize) + !isnothing(asyncObject) && push!(operands, asyncObject) + push!( + attributes, + operandsegmentsizes([ + length(asyncDependencies), + 1, + 1, + 1, + 1, + 1, + 1, + if (clusterSizeX == nothing) + 0 + elseif 1(clusterSizeY == nothing) + 0 + elseif 1(clusterSizeZ == nothing) + 0 + elseif 1(dynamicSharedMemorySize == nothing) + 0 + else + 1length(kernelOperands) + end, + (asyncObject == nothing) ? 0 : 1, + ]), + ) + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.launch_func", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`launch` + +Launch a kernel on the specified grid of thread blocks. The body of the +kernel is defined by the single region that this operation contains. The +operation takes an optional list of async dependencies followed by six +operands and an optional operand. + +The `async` keyword indicates the kernel should be launched asynchronously; +the operation returns a new !gpu.async.token when the keyword is specified. +The kernel launched does not start executing until the ops producing its +async dependencies (optional operands) have completed. + +The first three operands (following any async dependencies) are grid sizes +along the x,y,z dimensions and the following three are block sizes along the +x,y,z dimensions. When a lower-dimensional kernel is required, unused sizes +must be explicitly set to `1`. The last operand is optional and corresponds +to the amount of dynamic shared memory a kernel\'s workgroup should be +allocated; when this operand is not present, a zero size is assumed. + +The body region has at least _twelve_ arguments, or _eighteen_ if cluster +dimensions are present, grouped as follows: + +- three optional arguments that contain cluster identifiers along x,y,z + dimensions; +- three arguments that contain block identifiers along x,y,z dimensions; +- three arguments that contain thread identifiers along x,y,z dimensions; +- operands of the `gpu.launch` operation as is (i.e. the operands for + grid and block sizes). +- a variadic number of Workgroup memory attributions. +- a variadic number of Private memory attributions. + +The `kernelFunc` and `kernelModule` attributes are optional and specifies +the kernel name and a module in which the kernel should be outlined. + +# Syntax + +``` +operation ::= `gpu.launch` (`async` (`[` ssa-id-list `]`)? )? + ( `clusters` `(` ssa-id-list `)` `in` ssa-reassignment )? + `blocks` `(` ssa-id-list `)` `in` ssa-reassignment + `threads` `(` ssa-id-list `)` `in` ssa-reassignment + (dynamic_shared_memory_size ssa-use)? + memory-attribution + region attr-dict? +ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` +memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)? + (`private` `(` ssa-id-and-type-list `)`)? +``` + +# Example + +```mlir +gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2) + threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5) { + // Block and thread identifiers, as well as block/grid sizes are + // immediately usable inside body region. + \"some_op\"(%bx, %tx) : (index, index) -> () + // Assuming %val1 is defined outside the gpu.launch region. + %42 = load %val1[%bx] : memref +} + +// Generic syntax explains how the pretty syntax maps to the IR structure. +\"gpu.launch\"(%cst, %cst, %c1, // Grid sizes. + %cst, %c1, %c1) // Block sizes. + + {/*attributes*/} + // All sizes and identifiers have \"index\" size. + : (index, index, index, index, index, index) -> () { +// The operation passes block and thread identifiers, followed by grid and +// block sizes. +^bb0(%bx : index, %by : index, %bz : index, + %tx : index, %ty : index, %tz : index, + %num_bx : index, %num_by : index, %num_bz : index, + %num_tx : index, %num_ty : index, %num_tz : index) + \"some_op\"(%bx, %tx) : (index, index) -> () + %3 = \"memref.load\"(%val1, %bx) : (memref, index) -> f32 +} + +// Launch with memory attributions. +gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2) + threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5) + workgroup(%workgroup: memref<32xf32, 3>) + private(%private: memref<1xf32, 5>) { + // Block and thread identifiers, as well as block/grid sizes are + // immediately usable inside body region. + \"some_op\"(%bx, %tx) : (index, index) -> () + // Assuming %val1 is defined outside the gpu.launch region. + %42 = load %workgroup[%bx] : memref<32xf32, 3> +} + +// Launch with clusters. +gpu.launch clusters(%cx, %cy, %cz) in (%sz_cx = %0, %sz_cy = %1, %sz_cz = %2) + blocks(%bx, %by, %bz) in (%sz_bx = %3, %sz_by = %4, %sz_bz = %5) + threads(%tx, %ty, %tz) in (%sz_tx = %6, %sz_ty = %7, %sz_tz = %8) +{ + // Cluster, block and thread identifiers, as well as cluster/block/grid + // sizes are immediately usable inside body region. + \"some_op\"(%cx, %bx, %tx) : (index, index, index) -> () +} +``` + +Rationale: using operation/block arguments gives analyses a clear way of +understanding that a value has additional semantics (e.g., we will need to +know what value corresponds to threadIdx.x for coalescing). We can recover +these properties by analyzing the operations producing values, but it is +easier just to have that information by construction. +""" +function launch( + asyncDependencies::Vector{Value}, + gridSizeX::Value, + gridSizeY::Value, + gridSizeZ::Value, + blockSizeX::Value, + blockSizeY::Value, + blockSizeZ::Value, + clusterSizeX=nothing::Union{Nothing,Value}; + clusterSizeY=nothing::Union{Nothing,Value}, + clusterSizeZ=nothing::Union{Nothing,Value}, + dynamicSharedMemorySize=nothing::Union{Nothing,Value}, + asyncToken=nothing::Union{Nothing,IR.Type}, + kernelFunc=nothing, + kernelModule=nothing, + body::Region, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ + asyncDependencies..., + gridSizeX, + gridSizeY, + gridSizeZ, + blockSizeX, + blockSizeY, + blockSizeZ, + ] + owned_regions = Region[body,] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(clusterSizeX) && push!(operands, clusterSizeX) + !isnothing(clusterSizeY) && push!(operands, clusterSizeY) + !isnothing(clusterSizeZ) && push!(operands, clusterSizeZ) + !isnothing(dynamicSharedMemorySize) && push!(operands, dynamicSharedMemorySize) + push!( + attributes, + operandsegmentsizes([ + length(asyncDependencies), + 1, + 1, + 1, + 1, + 1, + 1, + if (clusterSizeX == nothing) + 0 + elseif 1(clusterSizeY == nothing) + 0 + elseif 1(clusterSizeZ == nothing) + 0 + elseif 1(dynamicSharedMemorySize == nothing) + 0 + else + 1 + end, + ]), + ) + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(kernelFunc) && push!(attributes, namedattribute("kernelFunc", kernelFunc)) + !isnothing(kernelModule) && + push!(attributes, namedattribute("kernelModule", kernelModule)) + + return create_operation( + "gpu.launch", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`memcpy` + +The `gpu.memcpy` operation copies the content of one memref to another. + +The op does not execute before all async dependencies have finished +executing. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token. + +# Example + +```mlir +%token = gpu.memcpy async [%dep] %dst, %src : memref, memref +``` +""" +function memcpy( + asyncDependencies::Vector{Value}, + dst::Value, + src::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., dst, src] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.memcpy", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`memset` + +The `gpu.memset` operation sets the content of memref to a scalar value. + +The op does not execute before all async dependencies have finished +executing. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token. + +# Example + +```mlir +%token = gpu.memset async [%dep] %dst, %value : memref, f32 +``` +""" +function memset( + asyncDependencies::Vector{Value}, + dst::Value, + value::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., dst, value] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.memset", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`num_subgroups` + +Returns the number of subgroups within a workgroup. + +# Example + +```mlir +%numSg = gpu.num_subgroups : index +``` + +If `upper_bound` is set, executions with more than `upper_bound` subgroups +per workgroup cause undefined behavior. There is a default upper bound of +`kMaxDim` (currently uint32_t::max). +""" +function num_subgroups(; + result=nothing::Union{Nothing,IR.Type}, upper_bound=nothing, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.num_subgroups", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`printf` + +`gpu.printf` takes a literal format string `format` and an arbitrary number of +scalar arguments that should be printed. + +The format string is a C-style printf string, subject to any restrictions +imposed by one\'s target platform. +""" +function printf(args::Vector{Value}; format, location=Location()) + op_ty_results = IR.Type[] + operands = Value[args...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("format", format),] + + return create_operation( + "gpu.printf", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`return_` + +A terminator operation for regions that appear in the body of `gpu.func` +functions. The operands to the `gpu.return` are the result values returned +by an invocation of the `gpu.func`. +""" +function return_(operands::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[operands...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "gpu.return", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`sddmm_buffer_size` + +The `gpu.sddmm_buffer_size` operation returns the buffer size required +to perform the SDDMM operation on the given sparse and dense matrices. +The operation expects handles returned by previous sparse operations +to construct an environment and the operands for SDDMM. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%buffersz, %token = gpu.sddmm_buffer_size async [%dep] %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC into f32 +``` + +The matrix arguments can also be associated with one of the following +operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value +is NON_TRANSPOSE. +""" +function sddmm_buffer_size( + asyncDependencies::Vector{Value}, + dnmatA::Value, + dnmatB::Value, + spmatC::Value; + bufferSz::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + modeA=nothing, + modeB=nothing, + computeType, + location=Location(), +) + op_ty_results = IR.Type[bufferSz,] + operands = Value[asyncDependencies..., dnmatA, dnmatB, spmatC] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("computeType", computeType),] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(modeA) && push!(attributes, namedattribute("modeA", modeA)) + !isnothing(modeB) && push!(attributes, namedattribute("modeB", modeB)) + + return create_operation( + "gpu.sddmm_buffer_size", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`sddmm` + +The `gpu.sddmm` operation performs the SDDMM operation on the given sparse and +dense matrices, and buffer. The operation expects handles returned by previous +sparse operations to construct an environment and the operands for SDDMM. The +buffer must have been allocated on the device. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +# Example + +```mlir +%token = gpu.sddmm async [%dep] %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC, %buffer into f32 +``` + +The matrix arguments can also be associated with one of the following +operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value +is NON_TRANSPOSE. +""" +function sddmm( + asyncDependencies::Vector{Value}, + dnmatA::Value, + dnmatB::Value, + spmatC::Value, + buffer::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + modeA=nothing, + modeB=nothing, + computeType, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., dnmatA, dnmatB, spmatC, buffer] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("computeType", computeType),] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(modeA) && push!(attributes, namedattribute("modeA", modeA)) + !isnothing(modeB) && push!(attributes, namedattribute("modeB", modeB)) + + return create_operation( + "gpu.sddmm", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`set_csr_pointers` + +The `gpu.set_csr_pointers` assigns the given positions, coordinates, +and values buffer that reside on the device directly to the given sparse +matrix descriptor in csr format. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a `!gpu.async.token` in addition to the environment. + +# Example + +```mlir +%token = gpu.set_csr_pointers async [%dep] %positions, %coordinates, %values + : memref, memref, memref +``` +""" +function set_csr_pointers( + asyncDependencies::Vector{Value}, + spmat::Value, + positions::Value, + coordinates::Value, + values::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., spmat, positions, coordinates, values] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.set_csr_pointers", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`set_default_device` + +Operation that sets the current default GPU, using a zero-based index +into the set of GPUs on the system. The default GPU setting may be +thread-local. +""" +function set_default_device(devIndex::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[devIndex,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "gpu.set_default_device", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`shuffle` + +The \"shuffle\" op moves values to a across lanes (a.k.a., invocations, +work items) within the same subgroup. The `width` argument specifies the +number of lanes that participate in the shuffle, and must be uniform +across all lanes. Further, the first `width` lanes of the subgroup must +be active. + +The intepretation of the `offset` arguments depends on the selected +`mode`. + +Returns the `shuffleResult` and `true` if the current lane id is smaller +than `width`, and an unspecified value and `false` otherwise. + +`xor` example: + +```mlir +%1, %2 = gpu.shuffle xor %0, %offset, %width : f32 +``` + +For lane `k`, returns the value `%0` from lane `k ^ offset`. Every lane +trades value with exactly one other lane. + +`down` example: + +```mlir +%cst1 = arith.constant 1 : i32 +%3, %4 = gpu.shuffle down %0, %cst1, %width : f32 +``` + +For lane `k`, returns the value from lane `(k + 1) % width`. + +`up` example: + +```mlir +%cst1 = arith.constant 1 : i32 +%5, %6 = gpu.shuffle up %0, %cst1, %width : f32 +``` + +For lane `k`, returns the value from lane `(k - 1) % width`. + +`idx` example: + +```mlir +%cst0 = arith.constant 0 : i32 +%7, %8 = gpu.shuffle idx %0, %cst0, %width : f32 +``` + +Broadcasts the value from lane 0 to all lanes. +""" +function shuffle( + value::Value, + offset::Value, + width::Value; + shuffleResult=nothing::Union{Nothing,IR.Type}, + valid=nothing::Union{Nothing,IR.Type}, + mode, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[value, offset, width] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("mode", mode),] + !isnothing(shuffleResult) && push!(op_ty_results, shuffleResult) + !isnothing(valid) && push!(op_ty_results, valid) + + return create_operation( + "gpu.shuffle", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`spgemm_copy` + +The `gpu.spgemm_copy` operation copies the sparse matrix result of +a SpGEMM computation. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a `!gpu.async.token` in addition to the environment. + +# Example + +```mlir +gpu.spgemm_copy %spmatA, %spmatB, %spmatC, %spgemmDesc: f32 +``` + +The matrix arguments can also be associated with one of the following +operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value +is NON_TRANSPOSE. +""" +function spgemm_copy( + asyncDependencies::Vector{Value}, + desc::Value, + spmatA::Value, + spmatB::Value, + spmatC::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + modeA=nothing, + modeB=nothing, + computeType, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., desc, spmatA, spmatB, spmatC] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("computeType", computeType),] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(modeA) && push!(attributes, namedattribute("modeA", modeA)) + !isnothing(modeB) && push!(attributes, namedattribute("modeB", modeB)) + + return create_operation( + "gpu.spgemm_copy", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`spgemm_create_descr` + +The `gpu.spgemm_create_descr` creates a descriptor for the SpGEMM operation. +The descriptor describes the SpGEMM operation and stores the internal data +throughout the computation. It needs to be passed as an argument to +spgemm_* operations. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a `!gpu.async.token` in addition to the environment. + +# Example + +```mlir +%desc, %token = gpu.spgemm_create_descr async [%dep] +``` +""" +function spgemm_create_descr( + asyncDependencies::Vector{Value}; + desc::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[desc,] + operands = Value[asyncDependencies...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.spgemm_create_descr", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`spgemm_destroy_descr` + +The `gpu.spgemm_destroy_descr` destroys the SpGEMM operation descriptor. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a `!gpu.async.token` in addition to the environment. + +# Example + +```mlir +%token = gpu.spgemm_destroy_descr async [%dep] %desc +``` +""" +function spgemm_destroy_descr( + asyncDependencies::Vector{Value}, + desc::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., desc] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.spgemm_destroy_descr", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`spgemm_work_estimation_or_compute` + +The `gpu.spgemm_work_estimation_or_compute` is used to call +cusparseSpGEMM_workEstimation or cusparseSpGEMM_compute. Both of them are +for both determining the buffer size and performing the actual computation. +The operation expects handles returned by previous sparse operations to +construct an environment and the operands for SpGEMM. +The buffer must have been allocated on the device. + +C\' = alpha * op(A) * op(B) + beta * C + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a `!gpu.async.token` in addition to the environment. + +# Example + +```mlir +%bufferSz, %token = gpu.spgemm_work_estimation_or_compute async [%dep] {COMPUTE} + %desc, %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE}, + %spmatC, %spgemmDesc, %c0, %alloc: f32 into + memref<0xi8> +``` + +The matrix arguments can also be associated with one of the following +operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value +is NON_TRANSPOSE. +""" +function spgemm_work_estimation_or_compute( + asyncDependencies::Vector{Value}, + desc::Value, + spmatA::Value, + spmatB::Value, + spmatC::Value, + bufferSz::Value, + buffer::Value; + bufferSzNew::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + modeA=nothing, + modeB=nothing, + computeType, + kind, + location=Location(), +) + op_ty_results = IR.Type[bufferSzNew,] + operands = Value[asyncDependencies..., desc, spmatA, spmatB, spmatC, bufferSz, buffer] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("computeType", computeType), namedattribute("kind", kind) + ] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(modeA) && push!(attributes, namedattribute("modeA", modeA)) + !isnothing(modeB) && push!(attributes, namedattribute("modeB", modeB)) + + return create_operation( + "gpu.spgemm_work_estimation_or_compute", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`spmm_buffer_size` + +The `gpu.spmm_buffer_size` operation returns the buffer size required +to perform the SpMM operation on the given sparse and dense matrix. +The operation expects handles returned by previous sparse operations +to construct an environment and the operands for SpMM. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +The matrix arguments can also be associated with one of the following +operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value +is NON_TRANSPOSE. + +# Example + +```mlir +%bufferszs, %token = gpu.spmm_buffer_size async [%dep] %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC : i64 into f32 +``` +""" +function spmm_buffer_size( + asyncDependencies::Vector{Value}, + spmatA::Value, + dnmatB::Value, + dnmatC::Value; + bufferSzs::Vector{IR.Type}, + asyncToken=nothing::Union{Nothing,IR.Type}, + modeA=nothing, + modeB=nothing, + computeType, + location=Location(), +) + op_ty_results = IR.Type[bufferSzs...,] + operands = Value[asyncDependencies..., spmatA, dnmatB, dnmatC] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("computeType", computeType),] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(modeA) && push!(attributes, namedattribute("modeA", modeA)) + !isnothing(modeB) && push!(attributes, namedattribute("modeB", modeB)) + + return create_operation( + "gpu.spmm_buffer_size", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`spmm` + +The `gpu.spmm` operation performs the SpMM operation on the given sparse and +dense matrix, and buffer. The operation expects handles returned by previous +sparse operations to construct an environment and the operands for SpMM. The +buffer must have been allocated on the device. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +The matrix arguments can also be associated with one of the following +operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value +is NON_TRANSPOSE. + +# Example + +```mlir +%token = gpu.spmm async [%dep] %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC, %buffers : type(\$buffers) into f32 +``` +""" +function spmm( + asyncDependencies::Vector{Value}, + spmatA::Value, + dnmatB::Value, + dnmatC::Value, + buffers::Vector{Value}; + asyncToken=nothing::Union{Nothing,IR.Type}, + modeA=nothing, + modeB=nothing, + computeType, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., spmatA, dnmatB, dnmatC, buffers...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("computeType", computeType),] + push!( + attributes, + operandsegmentsizes([length(asyncDependencies), 1, 1, 1, length(buffers)]), + ) + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(modeA) && push!(attributes, namedattribute("modeA", modeA)) + !isnothing(modeB) && push!(attributes, namedattribute("modeB", modeB)) + + return create_operation( + "gpu.spmm", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`spmv_buffer_size` + +The `gpu.spmv_buffer_size` operation returns the buffer size required +to perform the SpMV operation on the given sparse matrix and dense vectors. +The operation expects handles returned by previous sparse operations +to construct an environment and the operands for SpMV. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +The matrix arguments can also be associated with one of the following +operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value +is NON_TRANSPOSE. + +# Example + +```mlir +%buffersz, %token = gpu.spmv_buffer_size async [%dep] %spmatA{TRANSPOSE}, %dnX, %dnY into f32 +``` +""" +function spmv_buffer_size( + asyncDependencies::Vector{Value}, + spmatA::Value, + dnX::Value, + dnY::Value; + bufferSz::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + modeA=nothing, + computeType, + location=Location(), +) + op_ty_results = IR.Type[bufferSz,] + operands = Value[asyncDependencies..., spmatA, dnX, dnY] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("computeType", computeType),] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(modeA) && push!(attributes, namedattribute("modeA", modeA)) + + return create_operation( + "gpu.spmv_buffer_size", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`spmv` + +The `gpu.spmv` operation performs the SpMV operation on the given sparse matrix, +dense vectors, and buffer. The operation expects handles returned by previous +sparse operations to construct an environment and the operands for SpMV. The +buffer must have been allocated on the device. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a !gpu.async.token in addition to the environment. + +The matrix arguments can also be associated with one of the following +operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value +is NON_TRANSPOSE. + +# Example + +```mlir +%token = gpu.spmv async [%dep] %spmatA{TRANSPOSE}, %dnX, %dnY : memref into bf16 +``` +""" +function spmv( + asyncDependencies::Vector{Value}, + spmatA::Value, + dnX::Value, + dnY::Value, + buffer::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + modeA=nothing, + computeType, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies..., spmatA, dnX, dnY, buffer] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("computeType", computeType),] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + !isnothing(modeA) && push!(attributes, namedattribute("modeA", modeA)) + + return create_operation( + "gpu.spmv", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`spmat_get_size` + +The `gpu.spmat_get_size` operation retrieves the number of rows, number of +columns, and number of non-zero elements of a sparse matrix. + +If the `async` keyword is present, the op is executed asynchronously (i.e. +it does not block until the execution has finished on the device). In +that case, it returns a `!gpu.async.token` in addition to the environment. + +# Example + +```mlir +%rows, %cols, %nnz, %token = gpu.spmat_get_size async [%dep] %spmatC +``` +""" +function spmat_get_size( + asyncDependencies::Vector{Value}, + spmat::Value; + rows::IR.Type, + cols::IR.Type, + nnz::IR.Type, + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[rows, cols, nnz] + operands = Value[asyncDependencies..., spmat] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.spmat_get_size", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`subgroup_id` + +Returns the subgroup id, i.e., the index of the current subgroup within the +workgroup. + +# Example + +```mlir +%sgId = gpu.subgroup_id : index +``` + +Executions where there are more than `upper_bound` subgroups per workgroup +cause undefined behavior. There is an implicit upper bound of `kMaxDim` +(currently uint32_t::max). +""" +function subgroup_id(; + result=nothing::Union{Nothing,IR.Type}, upper_bound=nothing, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.subgroup_id", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`subgroup_mma_compute` + +The `gpu.subgroup_mma_compute` operation performs a matrix-multiply accumulate (mma) +operation using all the threads in a subgroup. + +This operation takes three `!gpu.mma_matrix`s as arguments: these hold `A`, +`B` and `C`operands for the mma operation. The operation performed is represented +as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of +the operation held by all threads in a subgroup. `a_transpose` or +`b_transpose` if present, signify that the respective operand was loaded in a +transposed manner. The transpose operands are required to map to correct +underlying intrisics but they currently do not seem to affect correctness +even if they are absent given that the operands were loaded correctly using +the `transpose` attribute in `gpu.subgroup_mma_load_matrix` op. + +For integer types, the `A` and `B` matrices carry their signedness with their +types. The accumulator type is expected to be signless and imply a signed integer +with a greater width than the other two operands. + +This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and +`gpu.subgroup_mma_load_matrix` ops. + +# Example + +```mlir +%D = gpu.subgroup_mma_compute_matrix %A, %B, %C : + !gpu.mma_matrix<16x16xf16, \"AOp\">, !gpu.mma_matrix<16x16xf16, \"BOp\">> + -> !gpu.mma_matrix<16x16xf16, \"COp\"> +``` +""" +function subgroup_mma_compute( + opA::Value, + opB::Value, + opC::Value; + res=nothing::Union{Nothing,IR.Type}, + a_transpose=nothing, + b_transpose=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[opA, opB, opC] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(a_transpose) && push!(attributes, namedattribute("a_transpose", a_transpose)) + !isnothing(b_transpose) && push!(attributes, namedattribute("b_transpose", b_transpose)) + + return create_operation( + "gpu.subgroup_mma_compute", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`subgroup_mma_constant_matrix` + +The `gpu.subgroup_mma_constant_matrix` creates a `!gpu.mma_matrix` with +constant elements. + +The operation takes a scalar input and return a `!gpu.mma_matrix` where +each element of is equal to the operand constant. The destination +mma_matrix type must have elememt type equal to the constant type. Since +the layout of `!gpu.mma_matrix` is opaque this only support setting all the +elements to the same value. + +This op is meant to be used along with `gpu.subgroup_mma_compute`. + +# Example + +```mlir + %0 = gpu.subgroup_mma_constant_matrix %a : + !gpu.mma_matrix<16x16xf16, \"AOp\"> + %1 = gpu.subgroup_mma_constant_matrix %b : + !gpu.mma_matrix<16x16xf32, \"COp\"> +``` +""" +function subgroup_mma_constant_matrix(value::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[value,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "gpu.subgroup_mma_constant_matrix", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`subgroup_mma_elementwise` + +The `gpu.subgroup_mma_elementwise` takes `!gpu.mma_matrix` inputs and +compute a new `!gpu.mma_matrix` by applying an elementwise operation to each +element. + +Since the operation is elementwise and the matrix type must match, the +matrix elements are processed independently of the matrix layout. + +This op is meant to be used along with `gpu.subgroup_mma_compute`. + +# Example + +```mlir + %0 = %A, %B { opType = \"ADD\" } : + (!gpu.mma_matrix<16x16xf16, \"COp\">, !gpu.mma_matrix<16x16xf16, \"COp\">) + -> !gpu.mma_matrix<16x16xf16, \"COp\"> +``` +""" +function subgroup_mma_elementwise( + args::Vector{Value}; res::IR.Type, opType, location=Location() +) + op_ty_results = IR.Type[res,] + operands = Value[args...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("opType", opType),] + + return create_operation( + "gpu.subgroup_mma_elementwise", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`subgroup_mma_load_matrix` + +The `gpu.subgroup_mma_load_matrix` operation loads a matrix collectively +using all the threads in a subgroup. + +This operation takes a memref as its first operand: it is the source matrix +from which data is to be loaded. The op returns a `!gpu.mma_matrix`. The +source memref can be in global memory or shared memory. The load address is +determined using `indices`. The matrix being loaded into is the result. The +`leadDimension` attribute specifies the leading dimension size of the source +matrix which eventually allows the lowering to determine the size of each +row. If the `transpose` attribute is present then the op does a transposed load. + +For integer types, the resulting `!gpu.mma_matrix` type needs to specify the +signedness of the data if the matrix type is an `A` or `B` operand for +`gpu.subgroup_mma_compute`. + +This op is often meant to be used along with `gpu.subgroup_mma_store_matrix` and +`gpu.subgroup_mma_compute`. + +# Example + +```mlir + %0 = gpu.subgroup_mma_load_matrix src[%i,%j] : {leadDimension = 32 : i32} + : memref<32x32xf16, 3>, !gpu.mma_matrix<16x16xf16, \"AOp\"> +``` +""" +function subgroup_mma_load_matrix( + srcMemref::Value, + indices::Vector{Value}; + res::IR.Type, + leadDimension, + transpose=nothing, + location=Location(), +) + op_ty_results = IR.Type[res,] + operands = Value[srcMemref, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("leadDimension", leadDimension),] + !isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose)) + + return create_operation( + "gpu.subgroup_mma_load_matrix", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`subgroup_mma_store_matrix` + +The `gpu.subgroup_mma_store_matrix` operation stores a matrix collectively +using all the threads in a subgroup. + +This operation takes a `!gpu.mma_matrix` and a memref as operands. +`!gpu.mma_matrix` is the source value containing the data to be stored into the +destination memref which can be in global or shared memory. The store address +is determined using the indices provided. The `leadDimension` attribute +specifies the leading dimension of the destination matrix. If the +`transpose` attribute is present then the op does a transposed store. + +This op is often meant to be used along with `gpu.subgroup_mma_load_matrix` and +`gpu.subgroup_mma_compute`. + +# Example + +```mlir +gpu.subgroup_mma_store_matrix %D, %sg[%i,%j] : { leadDimension = 32 : i32} + : !gpu.mma_matrix<16x16xf16, \"COp\">, memref<32x32xf16, 3> +``` +""" +function subgroup_mma_store_matrix( + src::Value, + dstMemref::Value, + indices::Vector{Value}; + leadDimension, + transpose=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[src, dstMemref, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("leadDimension", leadDimension),] + !isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose)) + + return create_operation( + "gpu.subgroup_mma_store_matrix", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`subgroup_reduce` + +The `subgroup_reduce` op reduces the values of lanes (work items) across a +subgroup. + +The subgroup is divided into clusters starting at lane index 0. Within each +cluster, there are `size` lanes, and the lane index advances by `stride`. +A reduction is done for each cluster in parallel: every lane in the cluster +is reduced, and the result is equal for all lanes in the cluster. If `size` +is omitted, there is a single cluster covering the entire subgroup. If +`stride` is omitted, the stride is 1 (the cluster\'s lanes are contiguous). + +When the reduced value is of a vector type, each vector element is reduced +independently. Only 1-d vector types are allowed. + +# Example + +```mlir +%1 = gpu.subgroup_reduce add %a : (f32) -> f32 +%2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> vector<4xf16> +%3 = gpu.subgroup_reduce add %c cluster(size = 4) : (f32) -> f32 +%3 = gpu.subgroup_reduce add %c cluster(size = 4, stride = 2) : (f32) -> f32 +``` + +If `uniform` flag is set either none or all lanes of a subgroup need to execute +this op in convergence. + +The reduction operation must be one of: +* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`, + `or`, `xor` +* Floating point types: `add`, `mul`, `minnumf`, `maxnumf`, `minimumf`, + `maximumf` +""" +function subgroup_reduce( + value::Value; + result=nothing::Union{Nothing,IR.Type}, + op, + uniform=nothing, + cluster_size=nothing, + cluster_stride=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[value,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("op", op),] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(uniform) && push!(attributes, namedattribute("uniform", uniform)) + !isnothing(cluster_size) && + push!(attributes, namedattribute("cluster_size", cluster_size)) + !isnothing(cluster_stride) && + push!(attributes, namedattribute("cluster_stride", cluster_stride)) + + return create_operation( + "gpu.subgroup_reduce", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`subgroup_size` + +Returns the number of threads within a subgroup. + +# Example + +```mlir +%sgSz = gpu.subgroup_size : index +``` + +Executions where the number of threads per subgroup exceed `upper_bound` cause +undefined behavior. When no `upper_bound` is specified, range analyses and +similar machinery assume the default bound of `kMaxSubgroupSize`, currently +128. +""" +function subgroup_size(; + result=nothing::Union{Nothing,IR.Type}, upper_bound=nothing, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.subgroup_size", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`terminator` + +A terminator operation for regions that appear in the body of `gpu.launch` +operation. These regions are not expected to return any value so the +terminator takes no operands. +""" +function terminator(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "gpu.terminator", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`thread_id` + +Returns the thread id, i.e. the index of the current thread within the block +along the x, y, or z `dimension`. + +# Example + +```mlir +%tIdX = gpu.thread_id x +``` + +If `upper_bound` is set, or if one can be inferred from `known_block_size`-type +annotations in context, executions where the thread index would be greater +than or equal to that bound cause undefined behavior. + +There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). +""" +function thread_id(; + result_0=nothing::Union{Nothing,IR.Type}, + dimension, + upper_bound=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) + + return create_operation( + "gpu.thread_id", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`wait` + +This op synchronizes the host or the device with a list of dependent ops. + +If the op contains the `async` keyword, it returns a new async token which +is synchronized with the op arguments. This new token is merely a shortcut +to the argument list, and one could replace the uses of the result with the +arguments for the same effect. The async version of this op is primarily +used to make each async token have a single use during lowering and +thereby make forks in async execution explicit. Example usage: + +```mlir +%t0 = gpu.foo async : !gpu.async.token +%t1 = gpu.bar async : !gpu.async.token +%t2 = gpu.wait async [%t0, %t1] +// gpu.baz doesn\'t run until gpu.foo and gpu.bar have both completed, just +// as if the async dependencies were [%t0, %t1]. +%t3 = gpu.baz async [%t2] +``` + +If the op does not contain the `async` keyword, it does not return a new +async token but blocks until all ops producing the async dependency tokens +finished execution. All dependent memory operations are visible to the host +once this op completes. Example usage: + +```mlir +%t0 = gpu.foo async : !gpu.async.token +%t1 = gpu.bar async : !gpu.async.token +// The gpu.wait op blocks until gpu.foo and gpu.bar have completed. +gpu.wait [%t0, %t1] +``` +""" +function wait( + asyncDependencies::Vector{Value}; + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[asyncDependencies...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(asyncToken) && push!(op_ty_results, asyncToken) + + return create_operation( + "gpu.wait", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`warp_execute_on_lane_0` + +`warp_execute_on_lane_0` is an operation used to bridge the gap between +vector programming and SPMD programming model like GPU SIMT. It allows to +trivially convert a region of vector code meant to run on a multiple threads +into a valid SPMD region and then allows incremental transformation to +distribute vector operations on the threads. + +Any code present in the region would only be executed on first thread/lane +based on the `laneid` operand. The `laneid` operand is an integer ID between +[0, `warp_size`). The `warp_size` attribute indicates the number of lanes in +a warp. + +Operands are vector values distributed on all lanes that may be used by +the single lane execution. The matching region argument is a vector of all +the values of those lanes available to the single active lane. The +distributed dimension is implicit based on the shape of the operand and +argument. the properties of the distribution may be described by extra +attributes (e.g. affine map). + +Return values are distributed on all lanes using laneId as index. The +vector is distributed based on the shape ratio between the vector type of +the yield and the result type. +If the shapes are the same this means the value is broadcasted to all lanes. +In the future the distribution can be made more explicit using affine_maps +and will support having multiple Ids. + +Therefore the `warp_execute_on_lane_0` operations allow to implicitly copy +between lane0 and the lanes of the warp. When distributing a vector +from lane0 to all the lanes, the data are distributed in a block cyclic way. +For example `vector<64xf32>` gets distributed on 32 threads and map to +`vector<2xf32>` where thread 0 contains vector[0] and vector[1]. + +During lowering values passed as operands and return value need to be +visible to different lanes within the warp. This would usually be done by +going through memory. + +The region is *not* isolated from above. For values coming from the parent +region not going through operands only the lane 0 value will be accesible so +it generally only make sense for uniform values. + +# Example +``` +// Execute in parallel on all threads/lanes. +gpu.warp_execute_on_lane_0 (%laneid)[32] { + // Serial code running only on thread/lane 0. + ... +} +// Execute in parallel on all threads/lanes. +``` + +This may be lowered to an scf.if region as below: +``` + // Execute in parallel on all threads/lanes. + %cnd = arith.cmpi eq, %laneid, %c0 : index + scf.if %cnd { + // Serial code running only on thread/lane 0. + ... + } + // Execute in parallel on all threads/lanes. +``` + +When the region has operands and/or return values: +``` +// Execute in parallel on all threads/lanes. +%0 = gpu.warp_execute_on_lane_0(%laneid)[32] +args(%v0 : vector<4xi32>) -> (vector<1xf32>) { +^bb0(%arg0 : vector<128xi32>) : + // Serial code running only on thread/lane 0. + ... + gpu.yield %1 : vector<32xf32> +} +// Execute in parallel on all threads/lanes. +``` + +values at the region boundary would go through memory: +``` +// Execute in parallel on all threads/lanes. +... +// Store the data from each thread into memory and Synchronization. +%tmp0 = memreg.alloc() : memref<128xf32> +%tmp1 = memreg.alloc() : memref<32xf32> +%cnd = arith.cmpi eq, %laneid, %c0 : index +vector.store %v0, %tmp0[%laneid] : memref<128xf32>, vector<4xf32> +some_synchronization_primitive +scf.if %cnd { + // Serialized code running only on thread 0. + // Load the data from all the threads into a register from thread 0. This + // allow threads 0 to access data from all the threads. + %arg0 = vector.load %tmp0[%c0] : memref<128xf32>, vector<128xf32> + ... + // Store the data from thread 0 into memory. + vector.store %1, %tmp1[%c0] : memref<32xf32>, vector<32xf32> +} +// Synchronization and load the data in a block cyclic way so that the +// vector is distributed on all threads. +some_synchronization_primitive +%0 = vector.load %tmp1[%laneid] : memref<32xf32>, vector<32xf32> +// Execute in parallel on all threads/lanes. +``` +""" +function warp_execute_on_lane_0( + laneid::Value, + args::Vector{Value}; + results::Vector{IR.Type}, + warp_size, + warpRegion::Region, + location=Location(), +) + op_ty_results = IR.Type[results...,] + operands = Value[laneid, args...] + owned_regions = Region[warpRegion,] + successors = Block[] + attributes = NamedAttribute[namedattribute("warp_size", warp_size),] + + return create_operation( + "gpu.warp_execute_on_lane_0", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`yield` + +gpu.yield` is a special terminator operation for blocks inside regions +in gpu ops. It returns values to the immediately enclosing gpu op. + +# Example + +```mlir +gpu.yield %f0, %f1 : f32, f32 +``` +""" +function yield(values::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[values...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "gpu.yield", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # gpu diff --git a/src/mlir/Dialects/Llvm.jl b/src/mlir/Dialects/Llvm.jl new file mode 100755 index 0000000000..38ff9f89fa --- /dev/null +++ b/src/mlir/Dialects/Llvm.jl @@ -0,0 +1,2619 @@ +module llvm +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +function ashr( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + isExact=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(isExact) && push!(attributes, namedattribute("isExact", isExact)) + + return create_operation( + "llvm.ashr", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function add( + lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.add", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function addrspacecast(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.addrspacecast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mlir_addressof` + +Creates an SSA value containing a pointer to a global value (function, +variable or alias). The global value can be defined after its first +referenced. If the global value is a constant, storing into it is not +allowed. + +Examples: + +```mlir +func @foo() { + // Get the address of a global variable. + %0 = llvm.mlir.addressof @const : !llvm.ptr + + // Use it as a regular pointer. + %1 = llvm.load %0 : !llvm.ptr -> i32 + + // Get the address of a function. + %2 = llvm.mlir.addressof @foo : !llvm.ptr + + // The function address can be used for indirect calls. + llvm.call %2() : !llvm.ptr, () -> () + + // Get the address of an aliased global. + %3 = llvm.mlir.addressof @const_alias : !llvm.ptr +} + +// Define the global. +llvm.mlir.global @const(42 : i32) : i32 + +// Define an alias. +llvm.mlir.alias @const_alias : i32 { + %0 = llvm.mlir.addressof @const : !llvm.ptr + llvm.return %0 : !llvm.ptr +} +``` +""" +function mlir_addressof(; res::IR.Type, global_name, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("global_name", global_name),] + + return create_operation( + "llvm.mlir.addressof", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mlir_alias` + +`llvm.mlir.alias` is a top level operation that defines a global alias for +global variables and functions. The operation is always initialized by +using a initializer region which could be a direct map to another global +value or contain some address computation on top of it. + +It uses a symbol for its value, which will be uniqued by the module +with respect to other symbols in it. + +Similarly to functions and globals, they can also have a linkage attribute. +This attribute is placed between `llvm.mlir.alias` and the symbol name. If +the attribute is omitted, `external` linkage is assumed by default. + +Examples: + +```mlir +// Global alias use @-identifiers. +llvm.mlir.alias external @foo_alias {addr_space = 0 : i32} : !llvm.ptr { + %0 = llvm.mlir.addressof @some_function : !llvm.ptr + llvm.return %0 : !llvm.ptr +} + +// More complex initialization. +llvm.mlir.alias linkonce_odr hidden @glob +{addr_space = 0 : i32, dso_local} : !llvm.array<32 x i32> { + %0 = llvm.mlir.constant(1234 : i64) : i64 + %1 = llvm.mlir.addressof @glob.private : !llvm.ptr + %2 = llvm.ptrtoint %1 : !llvm.ptr to i64 + %3 = llvm.add %2, %0 : i64 + %4 = llvm.inttoptr %3 : i64 to !llvm.ptr + llvm.return %4 : !llvm.ptr +} +``` +""" +function mlir_alias(; + alias_type, + sym_name, + linkage, + dso_local=nothing, + thread_local_=nothing, + unnamed_addr=nothing, + visibility_=nothing, + initializer::Region, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[initializer,] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("alias_type", alias_type), + namedattribute("sym_name", sym_name), + namedattribute("linkage", linkage), + ] + !isnothing(dso_local) && push!(attributes, namedattribute("dso_local", dso_local)) + !isnothing(thread_local_) && + push!(attributes, namedattribute("thread_local_", thread_local_)) + !isnothing(unnamed_addr) && + push!(attributes, namedattribute("unnamed_addr", unnamed_addr)) + !isnothing(visibility_) && push!(attributes, namedattribute("visibility_", visibility_)) + + return create_operation( + "llvm.mlir.alias", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function alloca( + arraySize::Value; + res::IR.Type, + alignment=nothing, + elem_type, + inalloca=nothing, + location=Location(), +) + op_ty_results = IR.Type[res,] + operands = Value[arraySize,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("elem_type", elem_type),] + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + !isnothing(inalloca) && push!(attributes, namedattribute("inalloca", inalloca)) + + return create_operation( + "llvm.alloca", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function and( + lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.and", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function cmpxchg( + ptr::Value, + cmp::Value, + val::Value; + res=nothing::Union{Nothing,IR.Type}, + success_ordering, + failure_ordering, + syncscope=nothing, + alignment=nothing, + weak=nothing, + volatile_=nothing, + access_groups=nothing, + alias_scopes=nothing, + noalias_scopes=nothing, + tbaa=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ptr, cmp, val] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("success_ordering", success_ordering), + namedattribute("failure_ordering", failure_ordering), + ] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(syncscope) && push!(attributes, namedattribute("syncscope", syncscope)) + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + !isnothing(weak) && push!(attributes, namedattribute("weak", weak)) + !isnothing(volatile_) && push!(attributes, namedattribute("volatile_", volatile_)) + !isnothing(access_groups) && + push!(attributes, namedattribute("access_groups", access_groups)) + !isnothing(alias_scopes) && + push!(attributes, namedattribute("alias_scopes", alias_scopes)) + !isnothing(noalias_scopes) && + push!(attributes, namedattribute("noalias_scopes", noalias_scopes)) + !isnothing(tbaa) && push!(attributes, namedattribute("tbaa", tbaa)) + + return create_operation( + "llvm.cmpxchg", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function atomicrmw( + ptr::Value, + val::Value; + res=nothing::Union{Nothing,IR.Type}, + bin_op, + ordering, + syncscope=nothing, + alignment=nothing, + volatile_=nothing, + access_groups=nothing, + alias_scopes=nothing, + noalias_scopes=nothing, + tbaa=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ptr, val] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("bin_op", bin_op), namedattribute("ordering", ordering) + ] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(syncscope) && push!(attributes, namedattribute("syncscope", syncscope)) + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + !isnothing(volatile_) && push!(attributes, namedattribute("volatile_", volatile_)) + !isnothing(access_groups) && + push!(attributes, namedattribute("access_groups", access_groups)) + !isnothing(alias_scopes) && + push!(attributes, namedattribute("alias_scopes", alias_scopes)) + !isnothing(noalias_scopes) && + push!(attributes, namedattribute("noalias_scopes", noalias_scopes)) + !isnothing(tbaa) && push!(attributes, namedattribute("tbaa", tbaa)) + + return create_operation( + "llvm.atomicrmw", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function bitcast(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.bitcast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function br( + destOperands::Vector{Value}; loop_annotation=nothing, dest::Block, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[destOperands...,] + owned_regions = Region[] + successors = Block[dest,] + attributes = NamedAttribute[] + !isnothing(loop_annotation) && + push!(attributes, namedattribute("loop_annotation", loop_annotation)) + + return create_operation( + "llvm.br", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`call_intrinsic` + +Call the specified llvm intrinsic. If the intrinsic is overloaded, use +the MLIR function type of this op to determine which intrinsic to call. +""" +function call_intrinsic( + args::Vector{Value}, + op_bundle_operands::Vector{Value}; + results=nothing::Union{Nothing,IR.Type}, + intrin, + fastmathFlags=nothing, + op_bundle_sizes, + op_bundle_tags=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[args..., op_bundle_operands...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("intrin", intrin), namedattribute("op_bundle_sizes", op_bundle_sizes) + ] + push!(attributes, operandsegmentsizes([length(args), length(op_bundle_operands)])) + !isnothing(results) && push!(op_ty_results, results) + !isnothing(fastmathFlags) && + push!(attributes, namedattribute("fastmathFlags", fastmathFlags)) + !isnothing(op_bundle_tags) && + push!(attributes, namedattribute("op_bundle_tags", op_bundle_tags)) + + return create_operation( + "llvm.call_intrinsic", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`call` + +In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect +implements this behavior by providing a variadic `call` operation for 0- and +1-result functions. Even though MLIR supports multi-result functions, LLVM +IR dialect disallows them. + +The `call` instruction supports both direct and indirect calls. Direct calls +start with a function name (`@`-prefixed) and indirect calls start with an +SSA value (`%`-prefixed). The direct callee, if present, is stored as a +function attribute `callee`. For indirect calls, the callee is of `!llvm.ptr` type +and is stored as the first value in `callee_operands`. If and only if the +callee is a variadic function, the `var_callee_type` attribute must carry +the variadic LLVM function type. The trailing type list contains the +optional indirect callee type and the MLIR function type, which differs from +the LLVM function type that uses an explicit void type to model functions +that do not return a value. + +Examples: + +```mlir +// Direct call without arguments and with one result. +%0 = llvm.call @foo() : () -> (f32) + +// Direct call with arguments and without a result. +llvm.call @bar(%0) : (f32) -> () + +// Indirect call with an argument and without a result. +%1 = llvm.mlir.addressof @foo : !llvm.ptr +llvm.call %1(%0) : !llvm.ptr, (f32) -> () + +// Direct variadic call. +llvm.call @printf(%0, %1) vararg(!llvm.func) : (!llvm.ptr, i32) -> i32 + +// Indirect variadic call +llvm.call %1(%0) vararg(!llvm.func) : !llvm.ptr, (i32) -> () +``` +""" +function call( + callee_operands::Vector{Value}, + op_bundle_operands::Vector{Value}; + result=nothing::Union{Nothing,IR.Type}, + var_callee_type=nothing, + callee=nothing, + fastmathFlags=nothing, + branch_weights=nothing, + CConv=nothing, + TailCallKind=nothing, + memory_effects=nothing, + convergent=nothing, + no_unwind=nothing, + will_return=nothing, + op_bundle_sizes, + op_bundle_tags=nothing, + arg_attrs=nothing, + res_attrs=nothing, + access_groups=nothing, + alias_scopes=nothing, + noalias_scopes=nothing, + tbaa=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[callee_operands..., op_bundle_operands...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("op_bundle_sizes", op_bundle_sizes),] + push!( + attributes, + operandsegmentsizes([length(callee_operands), length(op_bundle_operands)]), + ) + !isnothing(result) && push!(op_ty_results, result) + !isnothing(var_callee_type) && + push!(attributes, namedattribute("var_callee_type", var_callee_type)) + !isnothing(callee) && push!(attributes, namedattribute("callee", callee)) + !isnothing(fastmathFlags) && + push!(attributes, namedattribute("fastmathFlags", fastmathFlags)) + !isnothing(branch_weights) && + push!(attributes, namedattribute("branch_weights", branch_weights)) + !isnothing(CConv) && push!(attributes, namedattribute("CConv", CConv)) + !isnothing(TailCallKind) && + push!(attributes, namedattribute("TailCallKind", TailCallKind)) + !isnothing(memory_effects) && + push!(attributes, namedattribute("memory_effects", memory_effects)) + !isnothing(convergent) && push!(attributes, namedattribute("convergent", convergent)) + !isnothing(no_unwind) && push!(attributes, namedattribute("no_unwind", no_unwind)) + !isnothing(will_return) && push!(attributes, namedattribute("will_return", will_return)) + !isnothing(op_bundle_tags) && + push!(attributes, namedattribute("op_bundle_tags", op_bundle_tags)) + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + !isnothing(access_groups) && + push!(attributes, namedattribute("access_groups", access_groups)) + !isnothing(alias_scopes) && + push!(attributes, namedattribute("alias_scopes", alias_scopes)) + !isnothing(noalias_scopes) && + push!(attributes, namedattribute("noalias_scopes", noalias_scopes)) + !isnothing(tbaa) && push!(attributes, namedattribute("tbaa", tbaa)) + + return create_operation( + "llvm.call", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`comdat` + +Provides access to object file COMDAT section/group functionality. + +Examples: +```mlir +llvm.comdat @__llvm_comdat { + llvm.comdat_selector @any any +} +llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64 +``` +""" +function comdat(; sym_name, body::Region, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[body,] + successors = Block[] + attributes = NamedAttribute[namedattribute("sym_name", sym_name),] + + return create_operation( + "llvm.comdat", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`comdat_selector` + +Provides access to object file COMDAT section/group functionality. + +Examples: +```mlir +llvm.comdat @__llvm_comdat { + llvm.comdat_selector @any any +} +llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64 +``` +""" +function comdat_selector(; sym_name, comdat, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("sym_name", sym_name), namedattribute("comdat", comdat) + ] + + return create_operation( + "llvm.comdat_selector", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function cond_br( + condition::Value, + trueDestOperands::Vector{Value}, + falseDestOperands::Vector{Value}; + branch_weights=nothing, + loop_annotation=nothing, + trueDest::Block, + falseDest::Block, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[condition, trueDestOperands..., falseDestOperands...] + owned_regions = Region[] + successors = Block[trueDest, falseDest] + attributes = NamedAttribute[] + push!( + attributes, + operandsegmentsizes([1, length(trueDestOperands), length(falseDestOperands)]), + ) + !isnothing(branch_weights) && + push!(attributes, namedattribute("branch_weights", branch_weights)) + !isnothing(loop_annotation) && + push!(attributes, namedattribute("loop_annotation", loop_annotation)) + + return create_operation( + "llvm.cond_br", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mlir_constant` + +Unlike LLVM IR, MLIR does not have first-class constant values. Therefore, +all constants must be created as SSA values before being used in other +operations. `llvm.mlir.constant` creates such values for scalars, vectors, +strings, and structs. It has a mandatory `value` attribute whose type +depends on the type of the constant value. The type of the constant value +must correspond to the attribute type converted to LLVM IR type. + +When creating constant scalars, the `value` attribute must be either an +integer attribute or a floating point attribute. The type of the attribute +may be omitted for `i64` and `f64` types that are implied. + +When creating constant vectors, the `value` attribute must be either an +array attribute, a dense attribute, or a sparse attribute that contains +integers or floats. The number of elements in the result vector must match +the number of elements in the attribute. + +When creating constant strings, the `value` attribute must be a string +attribute. The type of the constant must be an LLVM array of `i8`s, and the +length of the array must match the length of the attribute. + +When creating constant structs, the `value` attribute must be an array +attribute that contains integers or floats. The type of the constant must be +an LLVM struct type. The number of fields in the struct must match the +number of elements in the attribute, and the type of each LLVM struct field +must correspond to the type of the corresponding attribute element converted +to LLVM IR. + +Examples: + +```mlir +// Integer constant, internal i32 is mandatory +%0 = llvm.mlir.constant(42 : i32) : i32 + +// It\'s okay to omit i64. +%1 = llvm.mlir.constant(42) : i64 + +// Floating point constant. +%2 = llvm.mlir.constant(42.0 : f32) : f32 + +// Splat dense vector constant. +%3 = llvm.mlir.constant(dense<1.0> : vector<4xf32>) : vector<4xf32> +``` +""" +function mlir_constant(; res::IR.Type, value, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("value", value),] + + return create_operation( + "llvm.mlir.constant", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function extractelement( + vector::Value, position::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[vector, position] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.extractelement", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function extractvalue(container::Value; res::IR.Type, position, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[container,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("position", position),] + + return create_operation( + "llvm.extractvalue", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function fadd( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + fastmathFlags=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(fastmathFlags) && + push!(attributes, namedattribute("fastmathFlags", fastmathFlags)) + + return create_operation( + "llvm.fadd", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function fcmp( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + predicate, + fastmathFlags=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("predicate", predicate),] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(fastmathFlags) && + push!(attributes, namedattribute("fastmathFlags", fastmathFlags)) + + return create_operation( + "llvm.fcmp", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function fdiv( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + fastmathFlags=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(fastmathFlags) && + push!(attributes, namedattribute("fastmathFlags", fastmathFlags)) + + return create_operation( + "llvm.fdiv", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function fmul( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + fastmathFlags=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(fastmathFlags) && + push!(attributes, namedattribute("fastmathFlags", fastmathFlags)) + + return create_operation( + "llvm.fmul", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function fneg( + operand::Value; + res=nothing::Union{Nothing,IR.Type}, + fastmathFlags=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[operand,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(fastmathFlags) && + push!(attributes, namedattribute("fastmathFlags", fastmathFlags)) + + return create_operation( + "llvm.fneg", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function fpext(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.fpext", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function fptosi(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.fptosi", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function fptoui(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.fptoui", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function fptrunc(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.fptrunc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function frem( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + fastmathFlags=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(fastmathFlags) && + push!(attributes, namedattribute("fastmathFlags", fastmathFlags)) + + return create_operation( + "llvm.frem", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function fsub( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + fastmathFlags=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(fastmathFlags) && + push!(attributes, namedattribute("fastmathFlags", fastmathFlags)) + + return create_operation( + "llvm.fsub", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function fence(; ordering, syncscope=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("ordering", ordering),] + !isnothing(syncscope) && push!(attributes, namedattribute("syncscope", syncscope)) + + return create_operation( + "llvm.fence", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function freeze(val::Value; res=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[val,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.freeze", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`getelementptr` + +This operation mirrors LLVM IRs \'getelementptr\' operation that is used to +perform pointer arithmetic. + +Like in LLVM IR, it is possible to use both constants as well as SSA values +as indices. In the case of indexing within a structure, it is required to +either use constant indices directly, or supply a constant SSA value. + +An optional \'inbounds\' attribute specifies the low-level pointer arithmetic +overflow behavior that LLVM uses after lowering the operation to LLVM IR. + +Examples: + +```mlir +// GEP with an SSA value offset +%0 = llvm.getelementptr %1[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + +// GEP with a constant offset and the inbounds attribute set +%0 = llvm.getelementptr inbounds %1[3] : (!llvm.ptr) -> !llvm.ptr, f32 + +// GEP with constant offsets into a structure +%0 = llvm.getelementptr %1[0, 1] + : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f32)> +``` +""" +function getelementptr( + base::Value, + dynamicIndices::Vector{Value}; + res::IR.Type, + rawConstantIndices, + elem_type, + inbounds=nothing, + location=Location(), +) + op_ty_results = IR.Type[res,] + operands = Value[base, dynamicIndices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("rawConstantIndices", rawConstantIndices), + namedattribute("elem_type", elem_type), + ] + !isnothing(inbounds) && push!(attributes, namedattribute("inbounds", inbounds)) + + return create_operation( + "llvm.getelementptr", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mlir_global_ctors` + +Specifies a list of constructor functions and priorities. The functions +referenced by this array will be called in ascending order of priority (i.e. +lowest first) when the module is loaded. The order of functions with the +same priority is not defined. This operation is translated to LLVM\'s +global_ctors global variable. The initializer functions are run at load +time. The `data` field present in LLVM\'s global_ctors variable is not +modeled here. + +Examples: + +```mlir +llvm.mlir.global_ctors {@ctor} + +llvm.func @ctor() { + ... + llvm.return +} +``` +""" +function mlir_global_ctors(; ctors, priorities, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("ctors", ctors), namedattribute("priorities", priorities) + ] + + return create_operation( + "llvm.mlir.global_ctors", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mlir_global_dtors` + +Specifies a list of destructor functions and priorities. The functions +referenced by this array will be called in descending order of priority (i.e. +highest first) when the module is unloaded. The order of functions with the +same priority is not defined. This operation is translated to LLVM\'s +global_dtors global variable. The `data` field present in LLVM\'s +global_dtors variable is not modeled here. + +Examples: + +```mlir +llvm.func @dtor() { + llvm.return +} +llvm.mlir.global_dtors {@dtor} +``` +""" +function mlir_global_dtors(; dtors, priorities, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("dtors", dtors), namedattribute("priorities", priorities) + ] + + return create_operation( + "llvm.mlir.global_dtors", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mlir_global` + +Since MLIR allows for arbitrary operations to be present at the top level, +global variables are defined using the `llvm.mlir.global` operation. Both +global constants and variables can be defined, and the value may also be +initialized in both cases. + +There are two forms of initialization syntax. Simple constants that can be +represented as MLIR attributes can be given in-line: + +```mlir +llvm.mlir.global @variable(32.0 : f32) : f32 +``` + +This initialization and type syntax is similar to `llvm.mlir.constant` and +may use two types: one for MLIR attribute and another for the LLVM value. +These types must be compatible. + +More complex constants that cannot be represented as MLIR attributes can be +given in an initializer region: + +```mlir +// This global is initialized with the equivalent of: +// i32* getelementptr (i32* @g2, i32 2) +llvm.mlir.global constant @int_gep() : !llvm.ptr { + %0 = llvm.mlir.addressof @g2 : !llvm.ptr + %1 = llvm.mlir.constant(2 : i32) : i32 + %2 = llvm.getelementptr %0[%1] + : (!llvm.ptr, i32) -> !llvm.ptr, i32 + // The initializer region must end with `llvm.return`. + llvm.return %2 : !llvm.ptr +} +``` + +Only one of the initializer attribute or initializer region may be provided. + +`llvm.mlir.global` must appear at top-level of the enclosing module. It uses +an @-identifier for its value, which will be uniqued by the module with +respect to other @-identifiers in it. + +Examples: + +```mlir +// Global values use @-identifiers. +llvm.mlir.global constant @cst(42 : i32) : i32 + +// Non-constant values must also be initialized. +llvm.mlir.global @variable(32.0 : f32) : f32 + +// Strings are expected to be of wrapped LLVM i8 array type and do not +// automatically include the trailing zero. +llvm.mlir.global @string(\"abc\") : !llvm.array<3 x i8> + +// For strings globals, the trailing type may be omitted. +llvm.mlir.global constant @no_trailing_type(\"foo bar\") + +// A complex initializer is constructed with an initializer region. +llvm.mlir.global constant @int_gep() : !llvm.ptr { + %0 = llvm.mlir.addressof @g2 : !llvm.ptr + %1 = llvm.mlir.constant(2 : i32) : i32 + %2 = llvm.getelementptr %0[%1] + : (!llvm.ptr, i32) -> !llvm.ptr, i32 + llvm.return %2 : !llvm.ptr +} +``` + +Similarly to functions, globals have a linkage attribute. In the custom +syntax, this attribute is placed between `llvm.mlir.global` and the optional +`constant` keyword. If the attribute is omitted, `external` linkage is +assumed by default. + +Examples: + +```mlir +// A constant with internal linkage will not participate in linking. +llvm.mlir.global internal constant @cst(42 : i32) : i32 + +// By default, \"external\" linkage is assumed and the global participates in +// symbol resolution at link-time. +llvm.mlir.global @glob(0 : f32) : f32 + +// Alignment is optional +llvm.mlir.global private constant @y(dense<1.0> : tensor<8xf32>) : !llvm.array<8 x f32> +``` + +Like global variables in LLVM IR, globals can have an (optional) +alignment attribute using keyword `alignment`. The integer value of the +alignment must be a positive integer that is a power of 2. + +Examples: + +```mlir +// Alignment is optional +llvm.mlir.global private constant @y(dense<1.0> : tensor<8xf32>) { alignment = 32 : i64 } : !llvm.array<8 x f32> +``` +""" +function mlir_global(; + global_type, + constant=nothing, + sym_name, + linkage, + dso_local=nothing, + thread_local_=nothing, + externally_initialized=nothing, + value=nothing, + alignment=nothing, + addr_space=nothing, + unnamed_addr=nothing, + section=nothing, + comdat=nothing, + dbg_exprs=nothing, + visibility_=nothing, + initializer::Region, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[initializer,] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("global_type", global_type), + namedattribute("sym_name", sym_name), + namedattribute("linkage", linkage), + ] + !isnothing(constant) && push!(attributes, namedattribute("constant", constant)) + !isnothing(dso_local) && push!(attributes, namedattribute("dso_local", dso_local)) + !isnothing(thread_local_) && + push!(attributes, namedattribute("thread_local_", thread_local_)) + !isnothing(externally_initialized) && + push!(attributes, namedattribute("externally_initialized", externally_initialized)) + !isnothing(value) && push!(attributes, namedattribute("value", value)) + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + !isnothing(addr_space) && push!(attributes, namedattribute("addr_space", addr_space)) + !isnothing(unnamed_addr) && + push!(attributes, namedattribute("unnamed_addr", unnamed_addr)) + !isnothing(section) && push!(attributes, namedattribute("section", section)) + !isnothing(comdat) && push!(attributes, namedattribute("comdat", comdat)) + !isnothing(dbg_exprs) && push!(attributes, namedattribute("dbg_exprs", dbg_exprs)) + !isnothing(visibility_) && push!(attributes, namedattribute("visibility_", visibility_)) + + return create_operation( + "llvm.mlir.global", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function icmp( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + predicate, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("predicate", predicate),] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.icmp", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`inline_asm` + +The InlineAsmOp mirrors the underlying LLVM semantics with a notable +exception: the embedded `asm_string` is not allowed to define or reference +any symbol or any global variable: only the operands of the op may be read, +written, or referenced. +Attempting to define or reference any symbol or any global behavior is +considered undefined behavior at this time. +""" +function inline_asm( + operands::Vector{Value}; + res=nothing::Union{Nothing,IR.Type}, + asm_string, + constraints, + has_side_effects=nothing, + is_align_stack=nothing, + asm_dialect=nothing, + operand_attrs=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[operands...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("asm_string", asm_string), namedattribute("constraints", constraints) + ] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(has_side_effects) && + push!(attributes, namedattribute("has_side_effects", has_side_effects)) + !isnothing(is_align_stack) && + push!(attributes, namedattribute("is_align_stack", is_align_stack)) + !isnothing(asm_dialect) && push!(attributes, namedattribute("asm_dialect", asm_dialect)) + !isnothing(operand_attrs) && + push!(attributes, namedattribute("operand_attrs", operand_attrs)) + + return create_operation( + "llvm.inline_asm", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function insertelement( + vector::Value, + value::Value, + position::Value; + res=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[vector, value, position] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.insertelement", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function insertvalue( + container::Value, + value::Value; + res=nothing::Union{Nothing,IR.Type}, + position, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[container, value] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("position", position),] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.insertvalue", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function inttoptr(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.inttoptr", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function invoke( + callee_operands::Vector{Value}, + normalDestOperands::Vector{Value}, + unwindDestOperands::Vector{Value}, + op_bundle_operands::Vector{Value}; + result=nothing::Union{Nothing,IR.Type}, + var_callee_type=nothing, + callee=nothing, + arg_attrs=nothing, + res_attrs=nothing, + branch_weights=nothing, + CConv=nothing, + op_bundle_sizes, + op_bundle_tags=nothing, + normalDest::Block, + unwindDest::Block, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ + callee_operands..., + normalDestOperands..., + unwindDestOperands..., + op_bundle_operands..., + ] + owned_regions = Region[] + successors = Block[normalDest, unwindDest] + attributes = NamedAttribute[namedattribute("op_bundle_sizes", op_bundle_sizes),] + push!( + attributes, + operandsegmentsizes([ + length(callee_operands), + length(normalDestOperands), + length(unwindDestOperands), + length(op_bundle_operands), + ]), + ) + !isnothing(result) && push!(op_ty_results, result) + !isnothing(var_callee_type) && + push!(attributes, namedattribute("var_callee_type", var_callee_type)) + !isnothing(callee) && push!(attributes, namedattribute("callee", callee)) + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + !isnothing(branch_weights) && + push!(attributes, namedattribute("branch_weights", branch_weights)) + !isnothing(CConv) && push!(attributes, namedattribute("CConv", CConv)) + !isnothing(op_bundle_tags) && + push!(attributes, namedattribute("op_bundle_tags", op_bundle_tags)) + + return create_operation( + "llvm.invoke", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`func` + +MLIR functions are defined by an operation that is not built into the IR +itself. The LLVM dialect provides an `llvm.func` operation to define +functions compatible with LLVM IR. These functions have LLVM dialect +function type but use MLIR syntax to express it. They are required to have +exactly one result type. LLVM function operation is intended to capture +additional properties of LLVM functions, such as linkage and calling +convention, that may be modeled differently by the built-in MLIR function. + +```mlir +// The type of @bar is !llvm<\"i64 (i64)\"> +llvm.func @bar(%arg0: i64) -> i64 { + llvm.return %arg0 : i64 +} + +// Type type of @foo is !llvm<\"void (i64)\"> +// !llvm.void type is omitted +llvm.func @foo(%arg0: i64) { + llvm.return +} + +// A function with `internal` linkage. +llvm.func internal @internal_func() { + llvm.return +} +``` +""" +function func(; + sym_name, + sym_visibility=nothing, + function_type, + linkage=nothing, + dso_local=nothing, + CConv=nothing, + comdat=nothing, + convergent=nothing, + personality=nothing, + garbageCollector=nothing, + passthrough=nothing, + arg_attrs=nothing, + res_attrs=nothing, + function_entry_count=nothing, + memory_effects=nothing, + visibility_=nothing, + arm_streaming=nothing, + arm_locally_streaming=nothing, + arm_streaming_compatible=nothing, + arm_new_za=nothing, + arm_in_za=nothing, + arm_out_za=nothing, + arm_inout_za=nothing, + arm_preserves_za=nothing, + section=nothing, + unnamed_addr=nothing, + alignment=nothing, + vscale_range=nothing, + frame_pointer=nothing, + target_cpu=nothing, + tune_cpu=nothing, + target_features=nothing, + unsafe_fp_math=nothing, + no_infs_fp_math=nothing, + no_nans_fp_math=nothing, + approx_func_fp_math=nothing, + no_signed_zeros_fp_math=nothing, + denormal_fp_math=nothing, + denormal_fp_math_f32=nothing, + fp_contract=nothing, + no_inline=nothing, + always_inline=nothing, + no_unwind=nothing, + will_return=nothing, + optimize_none=nothing, + vec_type_hint=nothing, + work_group_size_hint=nothing, + reqd_work_group_size=nothing, + intel_reqd_sub_group_size=nothing, + body::Region, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[body,] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("sym_name", sym_name), namedattribute("function_type", function_type) + ] + !isnothing(sym_visibility) && + push!(attributes, namedattribute("sym_visibility", sym_visibility)) + !isnothing(linkage) && push!(attributes, namedattribute("linkage", linkage)) + !isnothing(dso_local) && push!(attributes, namedattribute("dso_local", dso_local)) + !isnothing(CConv) && push!(attributes, namedattribute("CConv", CConv)) + !isnothing(comdat) && push!(attributes, namedattribute("comdat", comdat)) + !isnothing(convergent) && push!(attributes, namedattribute("convergent", convergent)) + !isnothing(personality) && push!(attributes, namedattribute("personality", personality)) + !isnothing(garbageCollector) && + push!(attributes, namedattribute("garbageCollector", garbageCollector)) + !isnothing(passthrough) && push!(attributes, namedattribute("passthrough", passthrough)) + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + !isnothing(function_entry_count) && + push!(attributes, namedattribute("function_entry_count", function_entry_count)) + !isnothing(memory_effects) && + push!(attributes, namedattribute("memory_effects", memory_effects)) + !isnothing(visibility_) && push!(attributes, namedattribute("visibility_", visibility_)) + !isnothing(arm_streaming) && + push!(attributes, namedattribute("arm_streaming", arm_streaming)) + !isnothing(arm_locally_streaming) && + push!(attributes, namedattribute("arm_locally_streaming", arm_locally_streaming)) + !isnothing(arm_streaming_compatible) && push!( + attributes, namedattribute("arm_streaming_compatible", arm_streaming_compatible) + ) + !isnothing(arm_new_za) && push!(attributes, namedattribute("arm_new_za", arm_new_za)) + !isnothing(arm_in_za) && push!(attributes, namedattribute("arm_in_za", arm_in_za)) + !isnothing(arm_out_za) && push!(attributes, namedattribute("arm_out_za", arm_out_za)) + !isnothing(arm_inout_za) && + push!(attributes, namedattribute("arm_inout_za", arm_inout_za)) + !isnothing(arm_preserves_za) && + push!(attributes, namedattribute("arm_preserves_za", arm_preserves_za)) + !isnothing(section) && push!(attributes, namedattribute("section", section)) + !isnothing(unnamed_addr) && + push!(attributes, namedattribute("unnamed_addr", unnamed_addr)) + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + !isnothing(vscale_range) && + push!(attributes, namedattribute("vscale_range", vscale_range)) + !isnothing(frame_pointer) && + push!(attributes, namedattribute("frame_pointer", frame_pointer)) + !isnothing(target_cpu) && push!(attributes, namedattribute("target_cpu", target_cpu)) + !isnothing(tune_cpu) && push!(attributes, namedattribute("tune_cpu", tune_cpu)) + !isnothing(target_features) && + push!(attributes, namedattribute("target_features", target_features)) + !isnothing(unsafe_fp_math) && + push!(attributes, namedattribute("unsafe_fp_math", unsafe_fp_math)) + !isnothing(no_infs_fp_math) && + push!(attributes, namedattribute("no_infs_fp_math", no_infs_fp_math)) + !isnothing(no_nans_fp_math) && + push!(attributes, namedattribute("no_nans_fp_math", no_nans_fp_math)) + !isnothing(approx_func_fp_math) && + push!(attributes, namedattribute("approx_func_fp_math", approx_func_fp_math)) + !isnothing(no_signed_zeros_fp_math) && push!( + attributes, namedattribute("no_signed_zeros_fp_math", no_signed_zeros_fp_math) + ) + !isnothing(denormal_fp_math) && + push!(attributes, namedattribute("denormal_fp_math", denormal_fp_math)) + !isnothing(denormal_fp_math_f32) && + push!(attributes, namedattribute("denormal_fp_math_f32", denormal_fp_math_f32)) + !isnothing(fp_contract) && push!(attributes, namedattribute("fp_contract", fp_contract)) + !isnothing(no_inline) && push!(attributes, namedattribute("no_inline", no_inline)) + !isnothing(always_inline) && + push!(attributes, namedattribute("always_inline", always_inline)) + !isnothing(no_unwind) && push!(attributes, namedattribute("no_unwind", no_unwind)) + !isnothing(will_return) && push!(attributes, namedattribute("will_return", will_return)) + !isnothing(optimize_none) && + push!(attributes, namedattribute("optimize_none", optimize_none)) + !isnothing(vec_type_hint) && + push!(attributes, namedattribute("vec_type_hint", vec_type_hint)) + !isnothing(work_group_size_hint) && + push!(attributes, namedattribute("work_group_size_hint", work_group_size_hint)) + !isnothing(reqd_work_group_size) && + push!(attributes, namedattribute("reqd_work_group_size", reqd_work_group_size)) + !isnothing(intel_reqd_sub_group_size) && push!( + attributes, + namedattribute("intel_reqd_sub_group_size", intel_reqd_sub_group_size), + ) + + return create_operation( + "llvm.func", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function lshr( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + isExact=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(isExact) && push!(attributes, namedattribute("isExact", isExact)) + + return create_operation( + "llvm.lshr", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function landingpad( + operand_0::Vector{Value}; res::IR.Type, cleanup=nothing, location=Location() +) + op_ty_results = IR.Type[res,] + operands = Value[operand_0...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(cleanup) && push!(attributes, namedattribute("cleanup", cleanup)) + + return create_operation( + "llvm.landingpad", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`linker_options` + +Pass the given options to the linker when the resulting object file is linked. +This is used extensively on Windows to determine the C runtime that the object +files should link against. + +Examples: +```mlir +// Link against the MSVC static threaded CRT. +llvm.linker_options [\"/DEFAULTLIB:\", \"libcmt\"] + +// Link against aarch64 compiler-rt builtins +llvm.linker_options [\"-l\", \"clang_rt.builtins-aarch64\"] +``` +""" +function linker_options(; options, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("options", options),] + + return create_operation( + "llvm.linker_options", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`load` + +The `load` operation is used to read from memory. A load may be marked as +atomic, volatile, and/or nontemporal, and takes a number of optional +attributes that specify aliasing information. + +An atomic load only supports a limited set of pointer, integer, and +floating point types, and requires an explicit alignment. + +Examples: +```mlir +// A volatile load of a float variable. +%0 = llvm.load volatile %ptr : !llvm.ptr -> f32 + +// A nontemporal load of a float variable. +%0 = llvm.load %ptr {nontemporal} : !llvm.ptr -> f32 + +// An atomic load of an integer variable. +%0 = llvm.load %ptr atomic monotonic {alignment = 8 : i64} + : !llvm.ptr -> i64 +``` + +See the following link for more details: +https://llvm.org/docs/LangRef.html#load-instruction +""" +function load( + addr::Value; + res::IR.Type, + alignment=nothing, + volatile_=nothing, + nontemporal=nothing, + invariant=nothing, + invariantGroup=nothing, + ordering=nothing, + syncscope=nothing, + access_groups=nothing, + alias_scopes=nothing, + noalias_scopes=nothing, + tbaa=nothing, + location=Location(), +) + op_ty_results = IR.Type[res,] + operands = Value[addr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + !isnothing(volatile_) && push!(attributes, namedattribute("volatile_", volatile_)) + !isnothing(nontemporal) && push!(attributes, namedattribute("nontemporal", nontemporal)) + !isnothing(invariant) && push!(attributes, namedattribute("invariant", invariant)) + !isnothing(invariantGroup) && + push!(attributes, namedattribute("invariantGroup", invariantGroup)) + !isnothing(ordering) && push!(attributes, namedattribute("ordering", ordering)) + !isnothing(syncscope) && push!(attributes, namedattribute("syncscope", syncscope)) + !isnothing(access_groups) && + push!(attributes, namedattribute("access_groups", access_groups)) + !isnothing(alias_scopes) && + push!(attributes, namedattribute("alias_scopes", alias_scopes)) + !isnothing(noalias_scopes) && + push!(attributes, namedattribute("noalias_scopes", noalias_scopes)) + !isnothing(tbaa) && push!(attributes, namedattribute("tbaa", tbaa)) + + return create_operation( + "llvm.load", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mul( + lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.mul", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`mlir_none` + +Unlike LLVM IR, MLIR does not have first-class token values. They must be +explicitly created as SSA values using `llvm.mlir.none`. This operation has +no operands or attributes, and returns a none token value of a wrapped LLVM IR +pointer type. + +Examples: + +```mlir +%0 = llvm.mlir.none : !llvm.token +``` +""" +function mlir_none(; res=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.mlir.none", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function or( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + isDisjoint=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(isDisjoint) && push!(attributes, namedattribute("isDisjoint", isDisjoint)) + + return create_operation( + "llvm.or", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`mlir_poison` + +Unlike LLVM IR, MLIR does not have first-class poison values. Such values +must be created as SSA values using `llvm.mlir.poison`. This operation has +no operands or attributes. It creates a poison value of the specified LLVM +IR dialect type. + +# Example + +```mlir +// Create a poison value for a structure with a 32-bit integer followed +// by a float. +%0 = llvm.mlir.poison : !llvm.struct<(i32, f32)> +``` +""" +function mlir_poison(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.mlir.poison", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function ptrtoint(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.ptrtoint", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function resume(value::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[value,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.resume", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function return_(arg=nothing::Union{Nothing,Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(arg) && push!(operands, arg) + + return create_operation( + "llvm.return", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function sdiv( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + isExact=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(isExact) && push!(attributes, namedattribute("isExact", isExact)) + + return create_operation( + "llvm.sdiv", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function sext(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.sext", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function sitofp(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.sitofp", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function srem( + lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.srem", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function select( + condition::Value, + trueValue::Value, + falseValue::Value; + res=nothing::Union{Nothing,IR.Type}, + fastmathFlags=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[condition, trueValue, falseValue] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(fastmathFlags) && + push!(attributes, namedattribute("fastmathFlags", fastmathFlags)) + + return create_operation( + "llvm.select", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function shl( + lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.shl", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function shufflevector(v1::Value, v2::Value; res::IR.Type, mask, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[v1, v2] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("mask", mask),] + + return create_operation( + "llvm.shufflevector", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`store` + +The `store` operation is used to write to memory. A store may be marked as +atomic, volatile, and/or nontemporal, and takes a number of optional +attributes that specify aliasing information. + +An atomic store only supports a limited set of pointer, integer, and +floating point types, and requires an explicit alignment. + +Examples: +```mlir +// A volatile store of a float variable. +llvm.store volatile %val, %ptr : f32, !llvm.ptr + +// A nontemporal store of a float variable. +llvm.store %val, %ptr {nontemporal} : f32, !llvm.ptr + +// An atomic store of an integer variable. +llvm.store %val, %ptr atomic monotonic {alignment = 8 : i64} + : i64, !llvm.ptr +``` + +See the following link for more details: +https://llvm.org/docs/LangRef.html#store-instruction +""" +function store( + value::Value, + addr::Value; + alignment=nothing, + volatile_=nothing, + nontemporal=nothing, + invariantGroup=nothing, + ordering=nothing, + syncscope=nothing, + access_groups=nothing, + alias_scopes=nothing, + noalias_scopes=nothing, + tbaa=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[value, addr] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + !isnothing(volatile_) && push!(attributes, namedattribute("volatile_", volatile_)) + !isnothing(nontemporal) && push!(attributes, namedattribute("nontemporal", nontemporal)) + !isnothing(invariantGroup) && + push!(attributes, namedattribute("invariantGroup", invariantGroup)) + !isnothing(ordering) && push!(attributes, namedattribute("ordering", ordering)) + !isnothing(syncscope) && push!(attributes, namedattribute("syncscope", syncscope)) + !isnothing(access_groups) && + push!(attributes, namedattribute("access_groups", access_groups)) + !isnothing(alias_scopes) && + push!(attributes, namedattribute("alias_scopes", alias_scopes)) + !isnothing(noalias_scopes) && + push!(attributes, namedattribute("noalias_scopes", noalias_scopes)) + !isnothing(tbaa) && push!(attributes, namedattribute("tbaa", tbaa)) + + return create_operation( + "llvm.store", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function sub( + lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.sub", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function switch( + value::Value, + defaultOperands::Vector{Value}, + caseOperands::Vector{Value}; + case_values=nothing, + case_operand_segments, + branch_weights=nothing, + defaultDestination::Block, + caseDestinations::Vector{Block}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[value, defaultOperands..., caseOperands...] + owned_regions = Region[] + successors = Block[defaultDestination, caseDestinations...] + attributes = NamedAttribute[namedattribute( + "case_operand_segments", case_operand_segments + ),] + push!( + attributes, operandsegmentsizes([1, length(defaultOperands), length(caseOperands)]) + ) + !isnothing(case_values) && push!(attributes, namedattribute("case_values", case_values)) + !isnothing(branch_weights) && + push!(attributes, namedattribute("branch_weights", branch_weights)) + + return create_operation( + "llvm.switch", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function trunc(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.trunc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function udiv( + lhs::Value, + rhs::Value; + res=nothing::Union{Nothing,IR.Type}, + isExact=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + !isnothing(isExact) && push!(attributes, namedattribute("isExact", isExact)) + + return create_operation( + "llvm.udiv", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function uitofp(arg::Value; res::IR.Type, nonNeg=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(nonNeg) && push!(attributes, namedattribute("nonNeg", nonNeg)) + + return create_operation( + "llvm.uitofp", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function urem( + lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.urem", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`mlir_undef` + +Unlike LLVM IR, MLIR does not have first-class undefined values. Such values +must be created as SSA values using `llvm.mlir.undef`. This operation has no +operands or attributes. It creates an undefined value of the specified LLVM +IR dialect type. + +# Example + +```mlir +// Create a structure with a 32-bit integer followed by a float. +%0 = llvm.mlir.undef : !llvm.struct<(i32, f32)> +``` +""" +function mlir_undef(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.mlir.undef", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function unreachable(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.unreachable", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function va_arg(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.va_arg", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function xor( + lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(res) && push!(op_ty_results, res) + + return create_operation( + "llvm.xor", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function zext(arg::Value; res::IR.Type, nonNeg=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(nonNeg) && push!(attributes, namedattribute("nonNeg", nonNeg)) + + return create_operation( + "llvm.zext", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mlir_zero` + +Unlike LLVM IR, MLIR does not have first-class zero-initialized values. +Such values must be created as SSA values using `llvm.mlir.zero`. This +operation has no operands or attributes. It creates a zero-initialized +value of the specified LLVM IR dialect type. + +# Example + +```mlir +// Create a zero-initialized value for a structure with a 32-bit integer +// followed by a float. +%0 = llvm.mlir.zero : !llvm.struct<(i32, f32)> +``` +""" +function mlir_zero(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "llvm.mlir.zero", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # llvm diff --git a/src/mlir/Dialects/MPI.jl b/src/mlir/Dialects/MPI.jl new file mode 100644 index 0000000000..68eb27f87a --- /dev/null +++ b/src/mlir/Dialects/MPI.jl @@ -0,0 +1,450 @@ +module mpi +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +""" +`allreduce` + +MPI_Allreduce performs a reduction operation on the values in the sendbuf +array and stores the result in the recvbuf array. The operation is +performed across all processes in the communicator. + +The `op` attribute specifies the reduction operation to be performed. +Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are +supported. + +Communicators other than `MPI_COMM_WORLD` are not supported for now. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function allreduce( + sendbuf::Value, + recvbuf::Value; + retval=nothing::Union{Nothing,IR.Type}, + op, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[sendbuf, recvbuf] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("op", op),] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.allreduce", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`barrier` + +MPI_Barrier blocks execution until all processes in the communicator have +reached this routine. + +Communicators other than `MPI_COMM_WORLD` are not supported for now. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function barrier(; retval=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.barrier", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`comm_rank` + +Communicators other than `MPI_COMM_WORLD` are not supported for now. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function comm_rank(; + retval=nothing::Union{Nothing,IR.Type}, rank::IR.Type, location=Location() +) + op_ty_results = IR.Type[rank,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.comm_rank", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`comm_size` + +Communicators other than `MPI_COMM_WORLD` are not supported for now. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function comm_size(; + retval=nothing::Union{Nothing,IR.Type}, size::IR.Type, location=Location() +) + op_ty_results = IR.Type[size,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.comm_size", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`error_class` + +`MPI_Error_class` maps return values from MPI calls to a set of well-known +MPI error classes. +""" +function error_class(val::Value; errclass::IR.Type, location=Location()) + op_ty_results = IR.Type[errclass,] + operands = Value[val,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "mpi.error_class", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`finalize` + +This function cleans up the MPI state. Afterwards, no MPI methods may +be invoked (excpet for MPI_Get_version, MPI_Initialized, and MPI_Finalized). +Notably, MPI_Init cannot be called again in the same program. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function finalize(; retval=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.finalize", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`irecv` + +MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` +from rank `dest`. The `tag` value and communicator enables the library to +determine the matching of multiple sends and receives between the same +ranks. + +Communicators other than `MPI_COMM_WORLD` are not supported for now. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function irecv( + ref::Value, + tag::Value, + rank::Value; + retval=nothing::Union{Nothing,IR.Type}, + req::IR.Type, + location=Location(), +) + op_ty_results = IR.Type[req,] + operands = Value[ref, tag, rank] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.irecv", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`isend` + +MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to +rank `dest`. The `tag` value and communicator enables the library to +determine the matching of multiple sends and receives between the same +ranks. + +Communicators other than `MPI_COMM_WORLD` are not supported for now. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function isend( + ref::Value, + tag::Value, + rank::Value; + retval=nothing::Union{Nothing,IR.Type}, + req::IR.Type, + location=Location(), +) + op_ty_results = IR.Type[req,] + operands = Value[ref, tag, rank] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.isend", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`init` + +This operation must preceed most MPI calls (except for very few exceptions, +please consult with the MPI specification on these). + +Passing &argc, &argv is not supported currently. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function init(; retval=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.init", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`recv` + +MPI_Recv performs a blocking receive of `size` elements of type `dtype` +from rank `source`. The `tag` value and communicator enables the library to +determine the matching of multiple sends and receives between the same +ranks. + +Communicators other than `MPI_COMM_WORLD` are not supported for now. +The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object +is not yet ported to MLIR. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function recv( + ref::Value, + tag::Value, + source::Value; + retval=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ref, tag, source] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.recv", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`retval_check` + +This operation compares MPI status codes to known error class +constants such as `MPI_SUCCESS`, or `MPI_ERR_COMM`. +""" +function retval_check(val::Value; res::IR.Type, errclass, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[val,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("errclass", errclass),] + + return create_operation( + "mpi.retval_check", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`send` + +MPI_Send performs a blocking send of `size` elements of type `dtype` to rank +`dest`. The `tag` value and communicator enables the library to determine +the matching of multiple sends and receives between the same ranks. + +Communicators other than `MPI_COMM_WORLD` are not supported for now. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function send( + ref::Value, + tag::Value, + dest::Value; + retval=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ref, tag, dest] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.send", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`wait` + +MPI_Wait blocks execution until the request has completed. + +The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object +is not yet ported to MLIR. + +This operation can optionally return an `!mpi.retval` value that can be used +to check for errors. +""" +function wait(req::Value; retval=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[req,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(retval) && push!(op_ty_results, retval) + + return create_operation( + "mpi.wait", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # mpi diff --git a/src/mlir/Dialects/MemRef.jl b/src/mlir/Dialects/MemRef.jl new file mode 100644 index 0000000000..1e72af142f --- /dev/null +++ b/src/mlir/Dialects/MemRef.jl @@ -0,0 +1,1762 @@ +module memref +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +""" +`assume_alignment` + +The `assume_alignment` operation takes a memref and an integer of alignment +value, and internally annotates the buffer with the given alignment. If +the buffer isn\'t aligned to the given alignment, the behavior is undefined. + +This operation doesn\'t affect the semantics of a correct program. It\'s for +optimization only, and the optimization is best-effort. +""" +function assume_alignment(memref::Value; alignment, location=Location()) + op_ty_results = IR.Type[] + operands = Value[memref,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("alignment", alignment),] + + return create_operation( + "memref.assume_alignment", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`atomic_rmw` + +The `memref.atomic_rmw` operation provides a way to perform a read-modify-write +sequence that is free from data races. The kind enumeration specifies the +modification to perform. The value operand represents the new value to be +applied during the modification. The memref operand represents the buffer +that the read and write will be performed against, as accessed by the +specified indices. The arity of the indices is the rank of the memref. The +result represents the latest value that was stored. + +# Example + +```mlir +%x = memref.atomic_rmw \"addf\" %value, %I[%i] : (f32, memref<10xf32>) -> f32 +``` +""" +function atomic_rmw( + value::Value, + memref::Value, + indices::Vector{Value}; + result=nothing::Union{Nothing,IR.Type}, + kind, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[value, memref, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("kind", kind),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "memref.atomic_rmw", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`atomic_yield` + +\"memref.atomic_yield\" yields an SSA value from a +GenericAtomicRMWOp region. +""" +function atomic_yield(result::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[result,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.atomic_yield", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`copy` + +Copies the data from the source to the destination memref. + +Usage: + +```mlir +memref.copy %arg0, %arg1 : memref to memref +``` + +Source and destination are expected to have the same element type and shape. +Otherwise, the result is undefined. They may have different layouts. +""" +function copy(source::Value, target::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[source, target] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.copy", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`generic_atomic_rmw` + +The `memref.generic_atomic_rmw` operation provides a way to perform a +read-modify-write sequence that is free from data races. The memref operand +represents the buffer that the read and write will be performed against, as +accessed by the specified indices. The arity of the indices is the rank of +the memref. The result represents the latest value that was stored. The +region contains the code for the modification itself. The entry block has +a single argument that represents the value stored in `memref[indices]` +before the write is performed. No side-effecting ops are allowed in the +body of `GenericAtomicRMWOp`. + +# Example + +```mlir +%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> { + ^bb0(%current_value : f32): + %c1 = arith.constant 1.0 : f32 + %inc = arith.addf %c1, %current_value : f32 + memref.atomic_yield %inc : f32 +} +``` +""" +function generic_atomic_rmw( + memref::Value, + indices::Vector{Value}; + result::IR.Type, + atomic_body::Region, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[memref, indices...] + owned_regions = Region[atomic_body,] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.generic_atomic_rmw", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`load` + +The `load` op reads an element from a memref specified by an index list. The +output of load is a new value with the same type as the elements of the +memref. The arity of indices is the rank of the memref (i.e., if the memref +loaded from is of rank 3, then 3 indices are required for the load following +the memref identifier). + +In an `affine.if` or `affine.for` body, the indices of a load are restricted +to SSA values bound to surrounding loop induction variables, +[symbols](Affine.md/#dimensions-and-symbols), results of a +constant operations, or the result of an +`affine.apply` operation that can in turn take as arguments all of the +aforementioned SSA values or the recursively result of such an +`affine.apply` operation. + +# Example + +```mlir +%1 = affine.apply affine_map<(d0, d1) -> (3*d0)> (%i, %j) +%2 = affine.apply affine_map<(d0, d1) -> (d1+1)> (%i, %j) +%12 = memref.load %A[%1, %2] : memref<8x?xi32, #layout, memspace0> + +// Example of an indirect load (treated as non-affine) +%3 = affine.apply affine_map<(d0) -> (2*d0 + 1)>(%12) +%13 = memref.load %A[%3, %2] : memref<4x?xi32, #layout, memspace0> +``` + +**Context:** The `load` and `store` operations are specifically crafted to +fully resolve a reference to an element of a memref, and (in affine +`affine.if` and `affine.for` operations) the compiler can follow use-def +chains (e.g. through [`affine.apply`](Affine.md/#affineapply-affineapplyop) +operations) to precisely analyze references at compile-time using polyhedral +techniques. This is possible because of the +[restrictions on dimensions and symbols](Affine.md/#restrictions-on-dimensions-and-symbols) +in these contexts. +""" +function load( + memref::Value, + indices::Vector{Value}; + result=nothing::Union{Nothing,IR.Type}, + nontemporal=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[memref, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(nontemporal) && push!(attributes, namedattribute("nontemporal", nontemporal)) + + return create_operation( + "memref.load", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`alloc` + +The `alloc` operation allocates a region of memory, as specified by its +memref type. + +# Example + +```mlir +%0 = memref.alloc() : memref<8x64xf32, 1> +``` + +The optional list of dimension operands are bound to the dynamic dimensions +specified in its memref type. In the example below, the ssa value \'%d\' is +bound to the second dimension of the memref (which is dynamic). + +```mlir +%0 = memref.alloc(%d) : memref<8x?xf32, 1> +``` + +The optional list of symbol operands are bound to the symbols of the +memrefs affine map. In the example below, the ssa value \'%s\' is bound to +the symbol \'s0\' in the affine map specified in the allocs memref type. + +```mlir +%0 = memref.alloc()[%s] : memref<8x64xf32, + affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> +``` + +This operation returns a single ssa value of memref type, which can be used +by subsequent load and store operations. + +The optional `alignment` attribute may be specified to ensure that the +region of memory that will be indexed is aligned at the specified byte +boundary. + +```mlir +%0 = memref.alloc()[%s] {alignment = 8} : + memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> +``` +""" +function alloc( + dynamicSizes::Vector{Value}, + symbolOperands::Vector{Value}; + memref::IR.Type, + alignment=nothing, + location=Location(), +) + op_ty_results = IR.Type[memref,] + operands = Value[dynamicSizes..., symbolOperands...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + push!(attributes, operandsegmentsizes([length(dynamicSizes), length(symbolOperands)])) + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + + return create_operation( + "memref.alloc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`alloca` + +The `alloca` operation allocates memory on the stack, to be automatically +released when control transfers back from the region of its closest +surrounding operation with an +[`AutomaticAllocationScope`](../Traits.md/#automaticallocationscope) trait. +The amount of memory allocated is specified by its memref and additional +operands. For example: + +```mlir +%0 = memref.alloca() : memref<8x64xf32> +``` + +The optional list of dimension operands are bound to the dynamic dimensions +specified in its memref type. In the example below, the SSA value \'%d\' is +bound to the second dimension of the memref (which is dynamic). + +```mlir +%0 = memref.alloca(%d) : memref<8x?xf32> +``` + +The optional list of symbol operands are bound to the symbols of the +memref\'s affine map. In the example below, the SSA value \'%s\' is bound to +the symbol \'s0\' in the affine map specified in the allocs memref type. + +```mlir +%0 = memref.alloca()[%s] : memref<8x64xf32, + affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>> +``` + +This operation returns a single SSA value of memref type, which can be used +by subsequent load and store operations. An optional alignment attribute, if +specified, guarantees alignment at least to that boundary. If not specified, +an alignment on any convenient boundary compatible with the type will be +chosen. +""" +function alloca( + dynamicSizes::Vector{Value}, + symbolOperands::Vector{Value}; + memref::IR.Type, + alignment=nothing, + location=Location(), +) + op_ty_results = IR.Type[memref,] + operands = Value[dynamicSizes..., symbolOperands...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + push!(attributes, operandsegmentsizes([length(dynamicSizes), length(symbolOperands)])) + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + + return create_operation( + "memref.alloca", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`alloca_scope` + +The `memref.alloca_scope` operation represents an explicitly-delimited +scope for the alloca allocations. Any `memref.alloca` operations that are +used within this scope are going to be cleaned up automatically once +the control-flow exits the nested region. For example: + +```mlir +memref.alloca_scope { + %myalloca = memref.alloca(): memref<4x3xf32> + ... +} +``` + +Here, `%myalloca` memref is valid within the explicitly delimited scope +and is automatically deallocated at the end of the given region. Conceptually, +`memref.alloca_scope` is a passthrough operation with +`AutomaticAllocationScope` that spans the body of the region within the operation. + +`memref.alloca_scope` may also return results that are defined in the nested +region. To return a value, one should use `memref.alloca_scope.return` +operation: + +```mlir +%result = memref.alloca_scope { + ... + memref.alloca_scope.return %value +} +``` + +If `memref.alloca_scope` returns no value, the `memref.alloca_scope.return ` can +be left out, and will be inserted implicitly. +""" +function alloca_scope(; results::Vector{IR.Type}, bodyRegion::Region, location=Location()) + op_ty_results = IR.Type[results...,] + operands = Value[] + owned_regions = Region[bodyRegion,] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.alloca_scope", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`alloca_scope_return` + +`memref.alloca_scope.return` operation returns zero or more SSA values +from the region within `memref.alloca_scope`. If no values are returned, +the return operation may be omitted. Otherwise, it has to be present +to indicate which values are going to be returned. For example: + +```mlir +memref.alloca_scope.return %value +``` +""" +function alloca_scope_return(results::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[results...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.alloca_scope.return", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cast` + +The `memref.cast` operation converts a memref from one type to an equivalent +type with a compatible shape. The source and destination types are +compatible if: + +a. Both are ranked memref types with the same element type, address space, +and rank and: + 1. Both have the same layout or both have compatible strided layouts. + 2. The individual sizes (resp. offset and strides in the case of strided + memrefs) may convert constant dimensions to dynamic dimensions and + vice-versa. + +If the cast converts any dimensions from an unknown to a known size, then it +acts as an assertion that fails at runtime if the dynamic dimensions +disagree with resultant destination size. + +# Example + +```mlir +// Assert that the input dynamic shape matches the destination static shape. +%2 = memref.cast %1 : memref to memref<4x4xf32> +// Erase static shape information, replacing it with dynamic information. +%3 = memref.cast %1 : memref<4xf32> to memref + +// The same holds true for offsets and strides. + +// Assert that the input dynamic shape matches the destination static stride. +%4 = memref.cast %1 : memref<12x4xf32, strided<[?, ?], offset: ?>> to + memref<12x4xf32, strided<[4, 1], offset: 5>> +// Erase static offset and stride information, replacing it with +// dynamic information. +%5 = memref.cast %1 : memref<12x4xf32, strided<[4, 1], offset: 5>> to + memref<12x4xf32, strided<[?, ?], offset: ?>> +``` + +b. Either or both memref types are unranked with the same element type, and +address space. + +# Example + +```mlir +Cast to concrete shape. + %4 = memref.cast %1 : memref<*xf32> to memref<4x?xf32> + +Erase rank information. + %5 = memref.cast %1 : memref<4x?xf32> to memref<*xf32> +``` +""" +function cast(source::Value; dest::IR.Type, location=Location()) + op_ty_results = IR.Type[dest,] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.cast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`collapse_shape` + +The `memref.collapse_shape` op produces a new view with a smaller rank +whose sizes are a reassociation of the original `view`. The operation is +limited to such reassociations, where subsequent, contiguous dimensions are +collapsed into a single dimension. Such reassociations never require +additional allocs or copies. + +Collapsing non-contiguous dimensions is undefined behavior. When a group of +dimensions can be statically proven to be non-contiguous, collapses of such +groups are rejected in the verifier on a best-effort basis. In the general +case, collapses of dynamically-sized dims with dynamic strides cannot be +proven to be contiguous or non-contiguous due to limitations in the memref +type. + +A reassociation is defined as a continuous grouping of dimensions and is +represented with an array of DenseI64ArrayAttr attribute. + +Note: Only the dimensions within a reassociation group must be contiguous. +The remaining dimensions may be non-contiguous. + +The result memref type can be zero-ranked if the source memref type is +statically shaped with all dimensions being unit extent. In such a case, the +reassociation indices must be empty. + +Examples: + +```mlir +// Dimension collapse (i, j) -> i\' and k -> k\' +%1 = memref.collapse_shape %0 [[0, 1], [2]] : + memref into memref +``` + +For simplicity, this op may not be used to cast dynamicity of dimension +sizes and/or strides. I.e., a result dimension must be dynamic if and only +if at least one dimension in the corresponding reassociation group is +dynamic. Similarly, the stride of a result dimension must be dynamic if and +only if the corresponding start dimension in the source type is dynamic. + +Note: This op currently assumes that the inner strides are of the +source/result layout map are the faster-varying ones. +""" +function collapse_shape(src::Value; result::IR.Type, reassociation, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("reassociation", reassociation),] + + return create_operation( + "memref.collapse_shape", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`dealloc` + +The `dealloc` operation frees the region of memory referenced by a memref +which was originally created by the `alloc` operation. +The `dealloc` operation should not be called on memrefs which alias an +alloc\'d memref (e.g. memrefs returned by `view` operations). + +# Example + +```mlir +%0 = memref.alloc() : memref<8x64xf32, affine_map<(d0, d1) -> (d0, d1), 1>> +memref.dealloc %0 : memref<8x64xf32, affine_map<(d0, d1) -> (d0, d1), 1>> +``` +""" +function dealloc(memref::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[memref,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.dealloc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`dim` + +The `dim` operation takes a memref and a dimension operand of type `index`. +It returns the size of the requested dimension of the given memref. +If the dimension index is out of bounds the behavior is undefined. + +The specified memref type is that of the first operand. + +# Example + +```mlir +// Always returns 4, can be constant folded: +%c0 = arith.constant 0 : index +%x = memref.dim %A, %c0 : memref<4 x ? x f32> + +// Returns the dynamic dimension of %A. +%c1 = arith.constant 1 : index +%y = memref.dim %A, %c1 : memref<4 x ? x f32> + +// Equivalent generic form: +%x = \"memref.dim\"(%A, %c0) : (memref<4 x ? x f32>, index) -> index +%y = \"memref.dim\"(%A, %c1) : (memref<4 x ? x f32>, index) -> index +``` +""" +function dim( + source::Value, index::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[source, index] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "memref.dim", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`dma_start` + +# Syntax + +``` +operation ::= `memref.dma_start` ssa-use`[`ssa-use-list`]` `,` + ssa-use`[`ssa-use-list`]` `,` ssa-use `,` + ssa-use`[`ssa-use-list`]` (`,` ssa-use `,` ssa-use)? + `:` memref-type `,` memref-type `,` memref-type +``` + +DmaStartOp starts a non-blocking DMA operation that transfers data from a +source memref to a destination memref. The source and destination memref +need not be of the same dimensionality, but need to have the same elemental +type. The operands include the source and destination memref\'s each followed +by its indices, size of the data transfer in terms of the number of elements +(of the elemental type of the memref), a tag memref with its indices, and +optionally at the end, a stride and a number_of_elements_per_stride +arguments. The tag location is used by a DmaWaitOp to check for completion. +The indices of the source memref, destination memref, and the tag memref +have the same restrictions as any load/store. The optional stride arguments +should be of \'index\' type, and specify a stride for the slower memory space +(memory space with a lower memory space id), transferring chunks of +number_of_elements_per_stride every stride until %num_elements are +transferred. Either both or no stride arguments should be specified. If the +source and destination locations overlap the behavior of this operation is +not defined. + +For example, a DmaStartOp operation that transfers 256 elements of a memref +\'%src\' in memory space 0 at indices [%i, %j] to memref \'%dst\' in memory +space 1 at indices [%k, %l], would be specified as follows: + +```mlir +%num_elements = arith.constant 256 +%idx = arith.constant 0 : index +%tag = memref.alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4> +dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] : + memref<40 x 128 x f32>, affine_map<(d0) -> (d0)>, 0>, + memref<2 x 1024 x f32>, affine_map<(d0) -> (d0)>, 1>, + memref<1 x i32>, affine_map<(d0) -> (d0)>, 2> +``` + +If %stride and %num_elt_per_stride are specified, the DMA is expected to +transfer %num_elt_per_stride elements every %stride elements apart from +memory space 0 until %num_elements are transferred. + +```mlir +dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride, + %num_elt_per_stride : +``` + +* TODO: add additional operands to allow source and destination striding, and +multiple stride levels. +* TODO: Consider replacing src/dst memref indices with view memrefs. +""" +function dma_start(operands::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[operands...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.dma_start", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`dma_wait` + +DmaWaitOp blocks until the completion of a DMA operation associated with the +tag element \'%tag[%index]\'. %tag is a memref, and %index has to be an index +with the same restrictions as any load/store index. %num_elements is the +number of elements associated with the DMA operation. + +# Example + +```mlir + dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] : + memref<2048 x f32>, affine_map<(d0) -> (d0)>, 0>, + memref<256 x f32>, affine_map<(d0) -> (d0)>, 1> + memref<1 x i32>, affine_map<(d0) -> (d0)>, 2> + ... + ... + dma_wait %tag[%index], %num_elements : memref<1 x i32, affine_map<(d0) -> (d0)>, 2> + ``` +""" +function dma_wait( + tagMemRef::Value, tagIndices::Vector{Value}, numElements::Value; location=Location() +) + op_ty_results = IR.Type[] + operands = Value[tagMemRef, tagIndices..., numElements] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.dma_wait", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`expand_shape` + +The `memref.expand_shape` op produces a new view with a higher rank whose +sizes are a reassociation of the original `view`. The operation is limited +to such reassociations, where a dimension is expanded into one or multiple +contiguous dimensions. Such reassociations never require additional allocs +or copies. + +A reassociation is defined as a grouping of dimensions and is represented +with an array of DenseI64ArrayAttr attributes. + +# Example + +```mlir +%r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32] + : memref into memref +``` + +If an op can be statically proven to be invalid (e.g, an expansion from +`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If +it cannot statically be proven invalid (e.g., the full example above; it is +unclear whether the first source dimension is divisible by 5), the op is +accepted by the verifier. However, if the op is in fact invalid at runtime, +the behavior is undefined. + +The source memref can be zero-ranked. In that case, the reassociation +indices must be empty and the result shape may only consist of unit +dimensions. + +For simplicity, this op may not be used to cast dynamicity of dimension +sizes and/or strides. I.e., if and only if a source dimension is dynamic, +there must be a dynamic result dimension in the corresponding reassociation +group. Same for strides. + +The representation for the output shape supports a partially-static +specification via attributes specified through the `static_output_shape` +argument. A special sentinel value `ShapedType::kDynamic` encodes that the +corresponding entry has a dynamic value. There must be exactly as many SSA +inputs in `output_shape` as there are `ShapedType::kDynamic` entries in +`static_output_shape`. + +Note: This op currently assumes that the inner strides are of the +source/result layout map are the faster-varying ones. +""" +function expand_shape( + src::Value, + output_shape::Vector{Value}; + result::IR.Type, + reassociation, + static_output_shape, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[src, output_shape...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("reassociation", reassociation), + namedattribute("static_output_shape", static_output_shape), + ] + + return create_operation( + "memref.expand_shape", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`extract_aligned_pointer_as_index` + +Extracts the underlying aligned pointer as an index. + +This operation is useful for lowering to lower-level dialects while still +avoiding the need to define a pointer type in higher-level dialects such as +the memref dialect. + +This operation is intended solely as step during lowering, it has no side +effects. A reverse operation that creates a memref from an index interpreted +as a pointer is explicitly discouraged. + +# Example + +``` + %0 = memref.extract_aligned_pointer_as_index %arg : memref<4x4xf32> -> index + %1 = arith.index_cast %0 : index to i64 + %2 = llvm.inttoptr %1 : i64 to !llvm.ptr + call @foo(%2) : (!llvm.ptr) ->() +``` +""" +function extract_aligned_pointer_as_index( + source::Value; aligned_pointer=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(aligned_pointer) && push!(op_ty_results, aligned_pointer) + + return create_operation( + "memref.extract_aligned_pointer_as_index", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`extract_strided_metadata` + +Extracts a base buffer, offset and strides. This op allows additional layers +of transformations and foldings to be added as lowering progresses from +higher-level dialect to lower-level dialects such as the LLVM dialect. + +The op requires a strided memref source operand. If the source operand is not +a strided memref, then verification fails. + +This operation is also useful for completeness to the existing memref.dim op. +While accessing strides, offsets and the base pointer independently is not +available, this is useful for composing with its natural complement op: +`memref.reinterpret_cast`. + +Intended Use Cases: + +The main use case is to expose the logic for manipulate memref metadata at a +higher level than the LLVM dialect. +This makes lowering more progressive and brings the following benefits: + - not all users of MLIR want to lower to LLVM and the information to e.g. + lower to library calls---like libxsmm---or to SPIR-V was not available. + - foldings and canonicalizations can happen at a higher level in MLIR: + before this op existed, lowering to LLVM would create large amounts of + LLVMIR. Even when LLVM does a good job at folding the low-level IR from + a performance perspective, it is unnecessarily opaque and inefficient to + send unkempt IR to LLVM. + +# Example + +```mlir + %base, %offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %memref : + memref<10x?xf32>, index, index, index, index, index + + // After folding, the type of %m2 can be memref<10x?xf32> and further + // folded to %memref. + %m2 = memref.reinterpret_cast %base to + offset: [%offset], + sizes: [%sizes#0, %sizes#1], + strides: [%strides#0, %strides#1] + : memref to memref +``` +""" +function extract_strided_metadata( + source::Value; + base_buffer=nothing::Union{Nothing,IR.Type}, + offset=nothing::Union{Nothing,IR.Type}, + sizes=nothing::Union{Nothing,Vector{IR.Type}}, + strides=nothing::Union{Nothing,Vector{IR.Type}}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(base_buffer) && push!(op_ty_results, base_buffer) + !isnothing(offset) && push!(op_ty_results, offset) + !isnothing(sizes) && push!(op_ty_results, sizes...) + !isnothing(strides) && push!(op_ty_results, strides...) + + return create_operation( + "memref.extract_strided_metadata", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`get_global` + +The `memref.get_global` operation retrieves the memref pointing to a +named global variable. If the global variable is marked constant, writing +to the result memref (such as through a `memref.store` operation) is +undefined. + +# Example + +```mlir +%x = memref.get_global @foo : memref<2xf32> +``` +""" +function get_global(; result::IR.Type, name, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("name", name),] + + return create_operation( + "memref.get_global", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`global_` + +The `memref.global` operation declares or defines a named global memref +variable. The backing memory for the variable is allocated statically and is +described by the type of the variable (which should be a statically shaped +memref type). The operation is a declaration if no `initial_value` is +specified, else it is a definition. The `initial_value` can either be a unit +attribute to represent a definition of an uninitialized global variable, or +an elements attribute to represent the definition of a global variable with +an initial value. The global variable can also be marked constant using the +`constant` unit attribute. Writing to such constant global variables is +undefined. + +The global variable can be accessed by using the `memref.get_global` to +retrieve the memref for the global variable. Note that the memref +for such global variable itself is immutable (i.e., memref.get_global for a +given global variable will always return the same memref descriptor). + +# Example + +```mlir +// Private variable with an initial value. +memref.global \"private\" @x : memref<2xf32> = dense<0.0,2.0> + +// Private variable with an initial value and an alignment (power of 2). +memref.global \"private\" @x : memref<2xf32> = dense<0.0,2.0> {alignment = 64} + +// Declaration of an external variable. +memref.global \"private\" @y : memref<4xi32> + +// Uninitialized externally visible variable. +memref.global @z : memref<3xf16> = uninitialized + +// Externally visible constant variable. +memref.global constant @c : memref<2xi32> = dense<1, 4> +``` +""" +function global_(; + sym_name, + sym_visibility=nothing, + type, + initial_value=nothing, + constant=nothing, + alignment=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("sym_name", sym_name), namedattribute("type", type) + ] + !isnothing(sym_visibility) && + push!(attributes, namedattribute("sym_visibility", sym_visibility)) + !isnothing(initial_value) && + push!(attributes, namedattribute("initial_value", initial_value)) + !isnothing(constant) && push!(attributes, namedattribute("constant", constant)) + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + + return create_operation( + "memref.global", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`memory_space_cast` + +This operation casts memref values between memory spaces. +The input and result will be memrefs of the same types and shape that alias +the same underlying memory, though, for some casts on some targets, +the underlying values of the pointer stored in the memref may be affected +by the cast. + +The input and result must have the same shape, element type, rank, and layout. + +If the source and target address spaces are the same, this operation is a noop. + +# Example + +```mlir +// Cast a GPU private memory attribution into a generic pointer +%2 = memref.memory_space_cast %1 : memref to memref +// Cast a generic pointer to workgroup-local memory +%4 = memref.memory_space_cast %3 : memref<5x4xi32> to memref<5x34xi32, 3> +// Cast between two non-default memory spaces +%6 = memref.memory_space_cast %5 + : memref<*xmemref, 5> to memref<*xmemref, 3> +``` +""" +function memory_space_cast(source::Value; dest::IR.Type, location=Location()) + op_ty_results = IR.Type[dest,] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.memory_space_cast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`prefetch` + +The \"prefetch\" op prefetches data from a memref location described with +subscript indices similar to memref.load, and with three attributes: a +read/write specifier, a locality hint, and a cache type specifier as shown +below: + +```mlir +memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xi32> +``` + +The read/write specifier is either \'read\' or \'write\', the locality hint +ranges from locality<0> (no locality) to locality<3> (extremely local keep +in cache). The cache type specifier is either \'data\' or \'instr\' +and specifies whether the prefetch is performed on data cache or on +instruction cache. +""" +function prefetch( + memref::Value, + indices::Vector{Value}; + isWrite, + localityHint, + isDataCache, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[memref, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("isWrite", isWrite), + namedattribute("localityHint", localityHint), + namedattribute("isDataCache", isDataCache), + ] + + return create_operation( + "memref.prefetch", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`rank` + +The `memref.rank` operation takes a memref operand and returns its rank. + +# Example + +```mlir +%0 = memref.rank %arg0 : memref<*xf32> +%1 = memref.rank %arg1 : memref +``` +""" +function rank(memref::Value; result_0=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[memref,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + + return create_operation( + "memref.rank", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`realloc` + +The `realloc` operation changes the size of a memory region. The memory +region is specified by a 1D source memref and the size of the new memory +region is specified by a 1D result memref type and an optional dynamic Value +of `Index` type. The source and the result memref must be in the same memory +space and have the same element type. + +The operation may move the memory region to a new location. In this case, +the content of the memory block is preserved up to the lesser of the new +and old sizes. If the new size if larger, the value of the extended memory +is undefined. This is consistent with the ISO C realloc. + +The operation returns an SSA value for the memref. + +# Example + +```mlir +%0 = memref.realloc %src : memref<64xf32> to memref<124xf32> +``` + +The source memref may have a dynamic shape, in which case, the compiler will +generate code to extract its size from the runtime data structure for the +memref. + +```mlir +%1 = memref.realloc %src : memref to memref<124xf32> +``` + +If the result memref has a dynamic shape, a result dimension operand is +needed to spefify its dynamic dimension. In the example below, the ssa value +\'%d\' specifies the unknown dimension of the result memref. + +```mlir +%2 = memref.realloc %src(%d) : memref to memref +``` + +An optional `alignment` attribute may be specified to ensure that the +region of memory that will be indexed is aligned at the specified byte +boundary. This is consistent with the fact that memref.alloc supports such +an optional alignment attribute. Note that in ISO C standard, neither alloc +nor realloc supports alignment, though there is aligned_alloc but not +aligned_realloc. + +```mlir +%3 = memref.realloc %src {alignment = 8} : memref<64xf32> to memref<124xf32> +``` + +Referencing the memref through the old SSA value after realloc is undefined +behavior. + +```mlir +%new = memref.realloc %old : memref<64xf32> to memref<124xf32> +%4 = memref.load %new[%index] // ok +%5 = memref.load %old[%index] // undefined behavior +``` +""" +function realloc( + source::Value, + dynamicResultSize=nothing::Union{Nothing,Value}; + result_0::IR.Type, + alignment=nothing, + location=Location(), +) + op_ty_results = IR.Type[result_0,] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(dynamicResultSize) && push!(operands, dynamicResultSize) + !isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment)) + + return create_operation( + "memref.realloc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`reinterpret_cast` + +Modify offset, sizes and strides of an unranked/ranked memref. + +# Example +```mlir +memref.reinterpret_cast %ranked to + offset: [0], + sizes: [%size0, 10], + strides: [1, %stride1] +: memref to memref> + +memref.reinterpret_cast %unranked to + offset: [%offset], + sizes: [%size0, %size1], + strides: [%stride0, %stride1] +: memref<*xf32> to memref> +``` + +This operation creates a new memref descriptor using the base of the +source and applying the input arguments to the other metadata. +In other words: +```mlir +%dst = memref.reinterpret_cast %src to + offset: [%offset], + sizes: [%sizes], + strides: [%strides] +``` +means that `%dst`\'s descriptor will be: +```mlir +%dst.base = %src.base +%dst.aligned = %src.aligned +%dst.offset = %offset +%dst.sizes = %sizes +%dst.strides = %strides +``` +""" +function reinterpret_cast( + source::Value, + offsets::Vector{Value}, + sizes::Vector{Value}, + strides::Vector{Value}; + result::IR.Type, + static_offsets, + static_sizes, + static_strides, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[source, offsets..., sizes..., strides...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("static_offsets", static_offsets), + namedattribute("static_sizes", static_sizes), + namedattribute("static_strides", static_strides), + ] + push!( + attributes, + operandsegmentsizes([1, length(offsets), length(sizes), length(strides)]), + ) + + return create_operation( + "memref.reinterpret_cast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`reshape` + +The `reshape` operation converts a memref from one type to an +equivalent type with a provided shape. The data is never copied or +modified. The source and destination types are compatible if both have the +same element type, same number of elements, address space and identity +layout map. The following combinations are possible: + +a. Source type is ranked or unranked. Shape argument has static size. +Result type is ranked. + +```mlir +// Reshape statically-shaped memref. +%dst = memref.reshape %src(%shape) + : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32> +%dst0 = memref.reshape %src(%shape0) + : (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32> +// Flatten unranked memref. +%dst = memref.reshape %src(%shape) + : (memref<*xf32>, memref<1xi32>) to memref +``` + +b. Source type is ranked or unranked. Shape argument has dynamic size. +Result type is unranked. + +```mlir +// Reshape dynamically-shaped 1D memref. +%dst = memref.reshape %src(%shape) + : (memref, memref) to memref<*xf32> +// Reshape unranked memref. +%dst = memref.reshape %src(%shape) + : (memref<*xf32>, memref) to memref<*xf32> +``` +""" +function reshape(source::Value, shape::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[source, shape] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.reshape", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`store` + +Store a value to a memref location given by indices. The value stored should +have the same type as the elemental type of the memref. The number of +arguments provided within brackets need to match the rank of the memref. + +In an affine context, the indices of a store are restricted to SSA values +bound to surrounding loop induction variables, +[symbols](Affine.md/#restrictions-on-dimensions-and-symbols), results of a +`constant` operation, or the result of an +[`affine.apply`](Affine.md/#affineapply-affineapplyop) operation that can in +turn take as arguments all of the aforementioned SSA values or the +recursively result of such an `affine.apply` operation. + +# Example + +```mlir +memref.store %100, %A[%1, 1023] : memref<4x?xf32, #layout, memspace0> +``` + +**Context:** The `load` and `store` operations are specifically crafted to +fully resolve a reference to an element of a memref, and (in polyhedral +`affine.if` and `affine.for` operations) the compiler can follow use-def +chains (e.g. through [`affine.apply`](Affine.md/#affineapply-affineapplyop) +operations) to precisely analyze references at compile-time using polyhedral +techniques. This is possible because of the +[restrictions on dimensions and symbols](Affine.md/#restrictions-on-dimensions-and-symbols) +in these contexts. +""" +function store( + value::Value, + memref::Value, + indices::Vector{Value}; + nontemporal=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[value, memref, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(nontemporal) && push!(attributes, namedattribute("nontemporal", nontemporal)) + + return create_operation( + "memref.store", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`transpose` + +The `transpose` op produces a strided memref whose sizes and strides +are a permutation of the original `in` memref. This is purely a metadata +transformation. + +# Example + +```mlir +%1 = memref.transpose %0 (i, j) -> (j, i) : memref to memref (d1 * s0 + d0)>> +``` +""" +function transpose(in::Value; result_0::IR.Type, permutation, location=Location()) + op_ty_results = IR.Type[result_0,] + operands = Value[in,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("permutation", permutation),] + + return create_operation( + "memref.transpose", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`view` + +The \"view\" operation extracts an N-D contiguous memref with empty layout map +with arbitrary element type from a 1-D contiguous memref with empty layout +map of i8 element type. The ViewOp supports the following arguments: + +* A single dynamic byte-shift operand must be specified which represents a + a shift of the base 1-D memref pointer from which to create the resulting + contiguous memref view with identity layout. +* A dynamic size operand that must be specified for each dynamic dimension + in the resulting view memref type. + +The \"view\" operation gives a structured indexing form to a flat 1-D buffer. +Unlike \"subview\" it can perform a type change. The type change behavior +requires the op to have special semantics because, e.g. a byte shift of 3 +cannot be represented as an offset on f64. +For now, a \"view\" op: + +1. Only takes a contiguous source memref with 0 offset and empty layout. +2. Must specify a byte_shift operand (in the future, a special integer + attribute may be added to support the folded case). +3. Returns a contiguous memref with 0 offset and empty layout. + +# Example + +```mlir +// Allocate a flat 1D/i8 memref. +%0 = memref.alloc() : memref<2048xi8> + +// ViewOp with dynamic offset and static sizes. +%1 = memref.view %0[%offset_1024][] : memref<2048xi8> to memref<64x4xf32> + +// ViewOp with dynamic offset and two dynamic size. +%2 = memref.view %0[%offset_1024][%size0, %size1] : + memref<2048xi8> to memref +``` +""" +function view( + source::Value, + byte_shift::Value, + sizes::Vector{Value}; + result_0::IR.Type, + location=Location(), +) + op_ty_results = IR.Type[result_0,] + operands = Value[source, byte_shift, sizes...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "memref.view", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`subview` + +The \"subview\" operation converts a memref type to another memref type +which represents a reduced-size view of the original memref as specified by +the operation\'s offsets, sizes and strides arguments. + +The SubView operation supports the following arguments: + +* source: the \"base\" memref on which to create a \"view\" memref. +* offsets: memref-rank number of offsets into the \"base\" memref at which to + create the \"view\" memref. +* sizes: memref-rank number of sizes which specify the sizes of the result + \"view\" memref type. +* strides: memref-rank number of strides that compose multiplicatively with + the base memref strides in each dimension. + +The representation based on offsets, sizes and strides support a +partially-static specification via attributes specified through the +`static_offsets`, `static_sizes` and `static_strides` arguments. A special +sentinel value ShapedType::kDynamic encodes that the corresponding entry has +a dynamic value. + +A subview operation may additionally reduce the rank of the resulting view +by removing dimensions that are statically known to be of size 1. + +Example 1: + +```mlir +%0 = memref.alloc() : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> + +// Create a sub-view of \"base\" memref \'%0\' with offset arguments \'%c0\', +// dynamic sizes for each dimension, and stride arguments \'%c1\'. +%1 = memref.subview %0[%c0, %c0][%size0, %size1][%c1, %c1] + : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to + memref (d0 * s1 + d1 + s0)>> +``` + +Example 2: + +```mlir +%0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> + +// Create a sub-view of \"base\" memref \'%0\' with dynamic offsets, sizes, +// and strides. +// Note that dynamic offsets are represented by the linearized dynamic +// offset symbol \'s0\' in the subview memref layout map, and that the +// dynamic strides operands, after being applied to the base memref +// strides in each dimension, are represented in the view memref layout +// map as symbols \'s1\', \'s2\' and \'s3\'. +%1 = memref.subview %0[%i, %j, %k][%size0, %size1, %size2][%x, %y, %z] + : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to + memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> +``` + +Example 3: + +```mlir +%0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> + +// Subview with constant offsets, sizes and strides. +%1 = memref.subview %0[0, 2, 0][4, 4, 4][1, 1, 1] + : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to + memref<4x4x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>> +``` + +Example 4: + +```mlir +%0 = memref.alloc(%arg0, %arg1) : memref + +// Subview with constant size, but dynamic offsets and +// strides. The resulting memref has a static shape, but if the +// base memref has an affine map to describe the layout, the result +// memref also uses an affine map to describe the layout. The +// strides of the result memref is computed as follows: +// +// Let #map1 represents the layout of the base memref, and #map2 +// represents the layout of the result memref. A #mapsubview can be +// constructed to map an index from the result memref to the base +// memref (note that the description below uses more convenient +// naming for symbols, while in affine maps, symbols are +// represented as unsigned numbers that identify that symbol in the +// given affine map. +// +// #mapsubview = (d0, d1)[o0, o1, t0, t1] -> (d0 * t0 + o0, d1 * t1 + o1) +// +// where, o0, o1, ... are offsets, and t0, t1, ... are strides. Then, +// +// #map2 = #map1.compose(#mapsubview) +// +// If the layout map is represented as +// +// #map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0) +// +// then, +// +// #map2 = (d0, d1)[s0, s1, s2, o0, o1, t0, t1] -> +// (d0 * s1 * t0 + d1 * s2 * t1 + o0 * s1 + o1 * s2 + s0) +// +// Representing this canonically +// +// #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0) +// +// where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1. +%1 = memref.subview %0[%i, %j][4, 4][%x, %y] : + : memref (d0 * s1 + d1 * s2 + s0)>> to + memref<4x4xf32, affine_map<(d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)>> + +// Note that the subview op does not guarantee that the result +// memref is \"inbounds\" w.r.t to base memref. It is upto the client +// to ensure that the subview is accessed in a manner that is +// in-bounds. +``` + +Example 5: + +```mlir +// Rank-reducing subview. +%1 = memref.subview %0[0, 0, 0][1, 16, 4][1, 1, 1] : + memref<8x16x4xf32> to memref<16x4xf32> + +// Original layout: +// (d0, d1, d2) -> (64 * d0 + 16 * d1 + d2) +// Subviewed layout: +// (d0, d1, d2) -> (64 * (d0 + 3) + 4 * (d1 + 4) + d2 + 2) = (64 * d0 + 4 * d1 + d2 + 210) +// After rank reducing: +// (d0, d1) -> (4 * d0 + d1 + 210) +%3 = memref.subview %2[3, 4, 2][1, 6, 3][1, 1, 1] : + memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> +``` +""" +function subview( + source::Value, + offsets::Vector{Value}, + sizes::Vector{Value}, + strides::Vector{Value}; + result::IR.Type, + static_offsets, + static_sizes, + static_strides, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[source, offsets..., sizes..., strides...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("static_offsets", static_offsets), + namedattribute("static_sizes", static_sizes), + namedattribute("static_strides", static_strides), + ] + push!( + attributes, + operandsegmentsizes([1, length(offsets), length(sizes), length(strides)]), + ) + + return create_operation( + "memref.subview", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # memref diff --git a/src/mlir/Dialects/Nvvm.jl b/src/mlir/Dialects/Nvvm.jl new file mode 100755 index 0000000000..56eba720e9 --- /dev/null +++ b/src/mlir/Dialects/Nvvm.jl @@ -0,0 +1,3682 @@ +module nvvm +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +function barrier0(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.barrier0", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`barrier_arrive` + +Thread that executes this op announces their arrival at the barrier with +given id and continue their execution. + +The default barrier id is 0 that is similar to `nvvm.barrier` Op. When +`barrierId` is not present, the default barrier id is used. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar) +""" +function barrier_arrive( + barrierId=nothing::Union{Nothing,Value}; numberOfThreads::Value, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[numberOfThreads,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(barrierId) && push!(operands, barrierId) + + return create_operation( + "nvvm.barrier.arrive", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function barrier( + barrierId=nothing::Union{Nothing,Value}; + numberOfThreads=nothing::Union{Nothing,Value}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(barrierId) && push!(operands, barrierId) + !isnothing(numberOfThreads) && push!(operands, numberOfThreads) + push!(attributes, operandsegmentsizes([ + if (barrierId == nothing) + 0 + elseif 1(numberOfThreads == nothing) + 0 + else + 1 + end, + ])) + + return create_operation( + "nvvm.barrier", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_ntid_x(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.ntid.x", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_ntid_y(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.ntid.y", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_ntid_z(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.ntid.z", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_ctaid_x(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.ctaid.x", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_ctaid_y(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.ctaid.y", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_ctaid_z(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.ctaid.z", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_cluster_ctaid_x(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.cluster.ctaid.x", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_cluster_ctaid_y(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.cluster.ctaid.y", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_cluster_ctaid_z(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.cluster.ctaid.z", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`breakpoint` + +Breakpoint suspends execution of the program for debugging. +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-brkpt) +""" +function breakpoint(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.breakpoint", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_clock64(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.clock64", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_clock(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.clock", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cluster_arrive` + +The `cluster.arrive` can be used by the threads within the cluster for synchronization and +communication. The `cluster.arrive` instruction marks the warps\' arrival at the barrier +without causing the executing thread to wait for other participating threads. + +The `aligned` attribute, when provided, generates the .aligned version of the PTX instruction. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster) +""" +function cluster_arrive(; aligned=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(aligned) && push!(attributes, namedattribute("aligned", aligned)) + + return create_operation( + "nvvm.cluster.arrive", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cluster_arrive_relaxed` + +The `cluster.arrive` can be used by the threads within the cluster for synchronization and +communication. The `cluster.arrive` instruction marks the warps\' arrival at the barrier +without causing the executing thread to wait for other participating threads. + +The `aligned` attribute, when provided, generates the .aligned version of the PTX instruction. +The .relaxed qualifier on `cluster.arrive` specifies that there are no memory +ordering and visibility guarantees provided for the memory accesses performed prior to +`cluster.arrive`. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster) +""" +function cluster_arrive_relaxed(; aligned=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(aligned) && push!(attributes, namedattribute("aligned", aligned)) + + return create_operation( + "nvvm.cluster.arrive.relaxed", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_cluster_nctarank(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.cluster.nctarank", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_cluster_nctaid_x(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.cluster.nctaid.x", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_cluster_nctaid_y(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.cluster.nctaid.y", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_cluster_nctaid_z(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.cluster.nctaid.z", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_nclusterid_x(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.nclusterid.x", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_nclusterid_y(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.nclusterid.y", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_nclusterid_z(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.nclusterid.z", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_cluster_ctarank(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.cluster.ctarank", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_clusterid_x(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.clusterid.x", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_clusterid_y(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.clusterid.y", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_clusterid_z(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.clusterid.z", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cluster_wait` + +The `cluster.wait` causes the executing thread to wait for all non-exited threads +of the cluster to perform `cluster.arrive`. The `aligned` attribute, when provided, +generates the .aligned version of the PTX instruction. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster) +""" +function cluster_wait(; aligned=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(aligned) && push!(attributes, namedattribute("aligned", aligned)) + + return create_operation( + "nvvm.cluster.wait", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cp_async_bulk_commit_group` + +This Op commits all prior initiated but uncommitted cp.async.bulk +instructions into a cp.async.bulk-group. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group) +""" +function cp_async_bulk_commit_group(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.cp.async.bulk.commit.group", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cp_async_bulk_shared_cluster_global` + +Initiates an asynchronous copy operation from global memory to cluster\'s +shared memory. + +The `multicastMask` operand is optional. When it is present, the Op copies +data from global memory to shared memory of multiple CTAs in the cluster. +Operand `multicastMask` specifies the destination CTAs in the cluster such +that each bit position in the 16-bit `multicastMask` operand corresponds to +the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. + +The `l2CacheHint` operand is optional, and it is used to specify cache +eviction policy that may be used during the memory access. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk) +""" +function cp_async_bulk_shared_cluster_global( + dstMem::Value, + srcMem::Value, + mbar::Value, + size::Value, + multicastMask=nothing::Union{Nothing,Value}; + l2CacheHint=nothing::Union{Nothing,Value}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[dstMem, srcMem, mbar, size] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(multicastMask) && push!(operands, multicastMask) + !isnothing(l2CacheHint) && push!(operands, l2CacheHint) + push!(attributes, operandsegmentsizes([ + 1, + 1, + 1, + 1, + if (multicastMask == nothing) + 0 + elseif 1(l2CacheHint == nothing) + 0 + else + 1 + end, + ])) + + return create_operation( + "nvvm.cp.async.bulk.shared.cluster.global", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cp_async_bulk_global_shared_cta` + +Initiates an asynchronous copy operation from Shared CTA memory to +global memory. + +The `l2CacheHint` operand is optional, and it is used to specify cache +eviction policy that may be used during the memory access. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk) +""" +function cp_async_bulk_global_shared_cta( + dstMem::Value, + srcMem::Value, + size::Value, + l2CacheHint=nothing::Union{Nothing,Value}; + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[dstMem, srcMem, size] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(l2CacheHint) && push!(operands, l2CacheHint) + + return create_operation( + "nvvm.cp.async.bulk.global.shared.cta", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cp_async_bulk_shared_cluster_shared_cta` + +Initiates an asynchronous copy operation from Shared CTA memory to Shared +cluster memory. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk) +""" +function cp_async_bulk_shared_cluster_shared_cta( + dstMem::Value, srcMem::Value, mbar::Value, size::Value; location=Location() +) + op_ty_results = IR.Type[] + operands = Value[dstMem, srcMem, mbar, size] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.cp.async.bulk.shared.cluster.shared.cta", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cp_async_bulk_tensor_shared_cluster_global` + +Initiates an asynchronous copy operation on the tensor data from global +memory to shared memory. + +The Op operates has two load modes: +1) Tiled Mode: It\'s the default mode. The source multi-dimensional tensor +layout is preserved at the destination. + +2) Im2col Mode: This mode is used when `im2colOffsets` operands are present. +the elements in the Bounding Box of the source tensor are rearranged into +columns at the destination. In this mode, the tensor has to be at least +3-dimensional. + +The `multicastMask` operand is optional. When it is present, the Op copies +data from global memory to shared memory of multiple CTAs in the cluster. +Operand `multicastMask` specifies the destination CTAs in the cluster such +that each bit position in the 16-bit `multicastMask` operand corresponds to +the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. + +The `l2CacheHint` operand is optional, and it is used to specify cache +eviction policy that may be used during the memory access. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) +""" +function cp_async_bulk_tensor_shared_cluster_global( + dstMem::Value, + tmaDescriptor::Value, + coordinates::Vector{Value}, + mbar::Value, + im2colOffsets::Vector{Value}, + multicastMask=nothing::Union{Nothing,Value}; + l2CacheHint=nothing::Union{Nothing,Value}, + predicate=nothing::Union{Nothing,Value}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[dstMem, tmaDescriptor, coordinates..., mbar, im2colOffsets...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(multicastMask) && push!(operands, multicastMask) + !isnothing(l2CacheHint) && push!(operands, l2CacheHint) + !isnothing(predicate) && push!(operands, predicate) + push!( + attributes, + operandsegmentsizes([ + 1, + 1, + length(coordinates), + 1, + length(im2colOffsets), + if (multicastMask == nothing) + 0 + elseif 1(l2CacheHint == nothing) + 0 + elseif 1(predicate == nothing) + 0 + else + 1 + end, + ]), + ) + + return create_operation( + "nvvm.cp.async.bulk.tensor.shared.cluster.global", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cp_async_bulk_tensor_prefetch` + +Initiates an asynchronous prefetch operation on the tensor data from global +memory to L2 cache. + +The Op has two modes: +1) Tiled Mode: It\'s the default mode. The source multi-dimensional tensor +layout is preserved at the destination. + +2) Im2col Mode: This mode is used when `im2colOffsets` operands are present. +the elements in the Bounding Box of the source tensor are rearranged into +columns at the destination. In this mode, the tensor has to be at least +3-dimensional. + +The `l2CacheHint` operand is optional, and it is used to specify cache +eviction policy that may be used during the memory access. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor) +""" +function cp_async_bulk_tensor_prefetch( + tmaDescriptor::Value, + coordinates::Vector{Value}, + im2colOffsets::Vector{Value}, + l2CacheHint=nothing::Union{Nothing,Value}; + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[tmaDescriptor, coordinates..., im2colOffsets...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(l2CacheHint) && push!(operands, l2CacheHint) + push!( + attributes, + operandsegmentsizes([ + 1, length(coordinates), length(im2colOffsets), (l2CacheHint == nothing) ? 0 : 1 + ]), + ) + + return create_operation( + "nvvm.cp.async.bulk.tensor.prefetch", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cp_async_bulk_tensor_reduce` + +Initiates an asynchronous reduction operation of tensor data in +global memory with tensor data in shared memory. + +The `mode` attribute indicates whether the copy mode is tile or im2col. +The `redOp` attribute specifies the reduction operations applied. +The supported reduction operations are: +{add, min, max, inc, dec, and, or, xor} + +The `l2CacheHint` operand is optional, and it is used to specify cache +eviction policy that may be used during the memory access. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor) +""" +function cp_async_bulk_tensor_reduce( + tmaDescriptor::Value, + srcMem::Value, + coordinates::Vector{Value}, + l2CacheHint=nothing::Union{Nothing,Value}; + redKind, + mode=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[tmaDescriptor, srcMem, coordinates...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("redKind", redKind),] + !isnothing(l2CacheHint) && push!(operands, l2CacheHint) + push!( + attributes, + operandsegmentsizes([1, 1, length(coordinates), (l2CacheHint == nothing) ? 0 : 1]), + ) + !isnothing(mode) && push!(attributes, namedattribute("mode", mode)) + + return create_operation( + "nvvm.cp.async.bulk.tensor.reduce", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function cp_async_bulk_tensor_global_shared_cta( + tmaDescriptor::Value, + srcMem::Value, + coordinates::Vector{Value}, + predicate=nothing::Union{Nothing,Value}; + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[tmaDescriptor, srcMem, coordinates...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(predicate) && push!(operands, predicate) + push!( + attributes, + operandsegmentsizes([1, 1, length(coordinates), (predicate == nothing) ? 0 : 1]), + ) + + return create_operation( + "nvvm.cp.async.bulk.tensor.global.shared.cta", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cp_async_bulk_wait_group` + +Op waits for completion of the most recent bulk async-groups. + +The `\$group` operand tells waiting has to be done until for \$group or fewer +of the most recent bulk async-groups. If `\$group` is 0, the op wait until +all the most recent bulk async-groups have completed. + +The `\$read` indicates that the waiting has to be done until all the bulk +async operations in the specified bulk async-group have completed reading +from their source locations. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group) +""" +function cp_async_bulk_wait_group(; group, read=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("group", group),] + !isnothing(read) && push!(attributes, namedattribute("read", read)) + + return create_operation( + "nvvm.cp.async.bulk.wait_group", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function cp_async_commit_group(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.cp.async.commit.group", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cp_async_mbarrier_arrive` + +The `cp.async.mbarrier.arrive` Op makes the mbarrier object track +all prior cp.async operations initiated by the executing thread. +The `addr` operand specifies the address of the mbarrier object +in generic address space. The `noinc` attr impacts how the +mbarrier\'s state is updated. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive) +""" +function cp_async_mbarrier_arrive(addr::Value; noinc=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[addr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(noinc) && push!(attributes, namedattribute("noinc", noinc)) + + return create_operation( + "nvvm.cp.async.mbarrier.arrive", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cp_async_mbarrier_arrive_shared` + +The `cp.async.mbarrier.arrive.shared` Op makes the mbarrier object +track all prior cp.async operations initiated by the executing thread. +The `addr` operand specifies the address of the mbarrier object in +shared memory. The `noinc` attr impacts how the mbarrier\'s state +is updated. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive) +""" +function cp_async_mbarrier_arrive_shared(addr::Value; noinc=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[addr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(noinc) && push!(attributes, namedattribute("noinc", noinc)) + + return create_operation( + "nvvm.cp.async.mbarrier.arrive.shared", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function cp_async_shared_global( + dst::Value, + src::Value, + cpSize=nothing::Union{Nothing,Value}; + size, + modifier, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[dst, src] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("size", size), namedattribute("modifier", modifier) + ] + !isnothing(cpSize) && push!(operands, cpSize) + + return create_operation( + "nvvm.cp.async.shared.global", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function cp_async_wait_group(; n, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("n", n),] + + return create_operation( + "nvvm.cp.async.wait.group", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`cvt_float_to_tf32` + +This Op converts the given f32 input to tf32. +The result `res` is represented as an i32 type. +The `relu` attribute, when set, lowers to the \'.relu\' variant of +the cvt instruction. The `rnd` and `sat` attributes specify the +the rounding and saturation modes respectively. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) +""" +function cvt_float_to_tf32( + src::Value; res::IR.Type, rnd=nothing, sat=nothing, relu=nothing, location=Location() +) + op_ty_results = IR.Type[res,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(rnd) && push!(attributes, namedattribute("rnd", rnd)) + !isnothing(sat) && push!(attributes, namedattribute("sat", sat)) + !isnothing(relu) && push!(attributes, namedattribute("relu", relu)) + + return create_operation( + "nvvm.cvt.float.to.tf32", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`elect_sync` + +The `elect.sync` instruction elects one predicated active leader +thread from among a set of threads specified in membermask. +The membermask is set to `0xFFFFFFFF` for the current version +of this Op. The predicate result is set to `True` for the +leader thread, and `False` for all other threads. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync) +""" +function elect_sync(; pred::IR.Type, location=Location()) + op_ty_results = IR.Type[pred,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.elect.sync", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg0(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg0", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg1(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg1", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg2(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg2", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg3(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg3", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg4(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg4", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg5(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg5", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg6(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg6", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg7(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg7", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg8(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg8", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg9(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg9", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg10(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg10", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg11(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg11", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg12(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg12", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg13(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg13", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg14(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg14", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg15(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg15", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg16(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg16", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg17(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg17", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg18(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg18", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg19(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg19", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg20(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg20", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg21(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg21", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg22(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg22", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg23(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg23", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg24(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg24", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg25(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg25", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg26(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg26", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg27(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg27", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg28(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg28", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg29(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg29", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg30(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg30", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_envreg31(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.envreg31", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`exit` + +Ends execution of a thread. +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-exit) +""" +function exit(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.exit", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`fence_mbarrier_init` + +Fence operation that applies on the prior nvvm.mbarrier.init + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) +""" +function fence_mbarrier_init(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.fence.mbarrier.init", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`fence_proxy_acquire` + +`fence.proxy.acquire` is a uni-directional fence used to establish ordering +between a prior memory access performed via the generic proxy and a +subsequent memory access performed via the tensormap proxy + +The address operand `addr` and the operand `size` together specify the +memory range `[addr, addr+size)` on which the ordering guarantees on the +memory accesses across the proxies is to be provided. The only supported +value for the `size` operand is 128 and must be an immediate. Generic Addressing +is used unconditionally, and the address specified by the operand `addr` must +fall within the `.global` state space. Otherwise, the behavior is undefined + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) +""" +function fence_proxy_acquire( + addr::Value, size::Value; scope, fromProxy=nothing, toProxy=nothing, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[addr, size] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("scope", scope),] + !isnothing(fromProxy) && push!(attributes, namedattribute("fromProxy", fromProxy)) + !isnothing(toProxy) && push!(attributes, namedattribute("toProxy", toProxy)) + + return create_operation( + "nvvm.fence.proxy.acquire", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`fence_proxy` + +Fence operation with proxy to establish an ordering between memory accesses +that may happen through different proxies. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) +""" +function fence_proxy(; kind, space=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("kind", kind),] + !isnothing(space) && push!(attributes, namedattribute("space", space)) + + return create_operation( + "nvvm.fence.proxy", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`fence_proxy_release` + +`fence.proxy.release` is a uni-directional fence used to establish ordering +between a prior memory access performed via the generic proxy and a +subsequent memory access performed via the tensormap proxy. `fence.proxy.release` +operation can form a release sequence that synchronizes with an acquire +sequence that contains the fence.proxy.acquire proxy fence operation + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) +""" +function fence_proxy_release(; + scope, fromProxy=nothing, toProxy=nothing, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("scope", scope),] + !isnothing(fromProxy) && push!(attributes, namedattribute("fromProxy", fromProxy)) + !isnothing(toProxy) && push!(attributes, namedattribute("toProxy", toProxy)) + + return create_operation( + "nvvm.fence.proxy.release", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function fence_sc_cluster(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.fence.sc.cluster", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_globaltimer(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.globaltimer", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_nctaid_x(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.nctaid.x", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_nctaid_y(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.nctaid.y", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_nctaid_z(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.nctaid.z", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_gridid(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.gridid", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`griddepcontrol_launch_dependents` + +Signals that specific dependents the runtime system designated to react to +this instruction can be scheduled as soon as all other CTAs in the grid +issue the same instruction or have completed. + + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol) +""" +function griddepcontrol_launch_dependents(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.griddepcontrol.launch.dependents", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`griddepcontrol_wait` + +Causes the executing thread to wait until all prerequisite grids in flight +have completed and all the memory operations from the prerequisite grids +are performed and made visible to the current grid. + + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol) +""" +function griddepcontrol_wait(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.griddepcontrol.wait", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_laneid(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.laneid", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_lanemask_eq(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.lanemask.eq", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_lanemask_ge(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.lanemask.ge", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_lanemask_gt(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.lanemask.gt", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_lanemask_le(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.lanemask.le", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_lanemask_lt(; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.read.ptx.sreg.lanemask.lt", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function ldmatrix(ptr::Value; res::IR.Type, num, layout, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[ptr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("num", num), namedattribute("layout", layout) + ] + + return create_operation( + "nvvm.ldmatrix", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_arrive_expect_tx( + addr::Value, + txcount::Value, + predicate=nothing::Union{Nothing,Value}; + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[addr, txcount] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(predicate) && push!(operands, predicate) + + return create_operation( + "nvvm.mbarrier.arrive.expect_tx", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_arrive_expect_tx_shared( + addr::Value, + txcount::Value, + predicate=nothing::Union{Nothing,Value}; + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[addr, txcount] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(predicate) && push!(operands, predicate) + + return create_operation( + "nvvm.mbarrier.arrive.expect_tx.shared", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_arrive_nocomplete( + addr::Value, count::Value; res::IR.Type, location=Location() +) + op_ty_results = IR.Type[res,] + operands = Value[addr, count] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mbarrier.arrive.nocomplete", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_arrive_nocomplete_shared( + addr::Value, count::Value; res::IR.Type, location=Location() +) + op_ty_results = IR.Type[res,] + operands = Value[addr, count] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mbarrier.arrive.nocomplete.shared", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_arrive(addr::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[addr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mbarrier.arrive", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_arrive_shared(addr::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[addr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mbarrier.arrive.shared", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_init( + addr::Value, count::Value, predicate=nothing::Union{Nothing,Value}; location=Location() +) + op_ty_results = IR.Type[] + operands = Value[addr, count] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(predicate) && push!(operands, predicate) + + return create_operation( + "nvvm.mbarrier.init", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_init_shared( + addr::Value, count::Value, predicate=nothing::Union{Nothing,Value}; location=Location() +) + op_ty_results = IR.Type[] + operands = Value[addr, count] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(predicate) && push!(operands, predicate) + + return create_operation( + "nvvm.mbarrier.init.shared", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_inval(addr::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[addr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mbarrier.inval", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_inval_shared(addr::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[addr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mbarrier.inval.shared", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_test_wait(addr::Value, state::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[addr, state] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mbarrier.test.wait", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_test_wait_shared( + addr::Value, state::Value; res::IR.Type, location=Location() +) + op_ty_results = IR.Type[res,] + operands = Value[addr, state] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mbarrier.test.wait.shared", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_try_wait_parity( + addr::Value, phase::Value, ticks::Value; location=Location() +) + op_ty_results = IR.Type[] + operands = Value[addr, phase, ticks] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mbarrier.try_wait.parity", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mbarrier_try_wait_parity_shared( + addr::Value, phase::Value, ticks::Value; location=Location() +) + op_ty_results = IR.Type[] + operands = Value[addr, phase, ticks] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mbarrier.try_wait.parity.shared", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mapa(a::Value, b::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[a, b] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.mapa", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mma_sync` + +The `nvvm.mma.sync` operation collectively performs the operation +`D = matmul(A, B) + C` using all threads in a warp. + +All the threads in the warp must execute the same `mma.sync` operation. + +For each possible multiplicand PTX data type, there are one or more possible +instruction shapes given as \"mMnNkK\". The below table describes the posssibilities +as well as the types required for the operands. Note that the data type for +C (the accumulator) and D (the result) can vary independently when there are +multiple possibilities in the \"C/D Type\" column. + +When an optional attribute cannot be immediately inferred from the types of +the operands and the result during parsing or validation, an error will be +raised. + +`b1Op` is only relevant when the binary (b1) type is given to +`multiplicandDataType`. It specifies how the multiply-and-acumulate is +performed and is either `xor_popc` or `and_poc`. The default is `xor_popc`. + +`intOverflowBehavior` is only relevant when the `multiplicandType` attribute +is one of `u8, s8, u4, s4`, this attribute describes how overflow is handled +in the accumulator. When the attribute is `satfinite`, the accumulator values +are clamped in the int32 range on overflow. This is the default behavior. +Alternatively, accumulator behavior `wrapped` can also be specified, in +which case overflow wraps from one end of the range to the other. + +`layoutA` and `layoutB` are required and should generally be set to +`#nvvm.mma_layout` and `#nvvm.mma_layout` respectively, but other +combinations are possible for certain layouts according to the table below. + +``` +| A/B Type | Shape | ALayout | BLayout | A Type | B Type | C/D Type | +|----------|-----------|---------|---------|----------|----------|-------------------| +| f64 | .m8n8k4 | row | col | 1x f64 | 1x f64 | 2x f64 | +| f16 | .m8n8k4 | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 | +| | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 | +| | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 | +| bf16 | .m16n8k8 | row | col | 2x i32 | 1x i32 | 4x f32 | +| | .m16n8k16 | row | col | 4x i32 | 2x i32 | 4x f32 | +| tf32 | .m16n8k4 | row | col | 2x i32 | 1x i32 | 4x f32 | +| | .m16n8k8 | row | col | 4x i32 | 2x i32 | 2x f16x2 or 4 f32 | +| u8/s8 | .m8n8k16 | row | col | 1x i32 | 1x i32 | 2x i32 | +| | .m16n8k16 | row | col | 2x i32 | 1x i32 | 4x i32 | +| | .m16n8k32 | row | col | 4x i32 | 2x i32 | 4x i32 | +| u4/s4 | .m8n8k32 | row | col | 1x i32 | 1x i32 | 2x i32 | +| | m16n8k32 | row | col | 2x i32 | 1x i32 | 4x i32 | +| | m16n8k64 | row | col | 4x i32 | 2x i32 | 4x i32 | +| b1 | m8n8k128 | row | col | 1x i32 | 1x i32 | 2x i32 | +| | m16n8k128 | row | col | 2x i32 | 1x i32 | 4x i32 | +``` + + +# Example +```mlir + +%128 = nvvm.mma.sync A[%120, %121, %122, %123] + B[%124, %125] + C[%126, %127] + {layoutA = #nvvm.mma_layout, + layoutB = #nvvm.mma_layout, + shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) + -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> +``` +""" +function mma_sync( + operandA::Vector{Value}, + operandB::Vector{Value}, + operandC::Vector{Value}; + res::IR.Type, + shape, + b1Op=nothing, + intOverflowBehavior=nothing, + layoutA, + layoutB, + multiplicandAPtxType=nothing, + multiplicandBPtxType=nothing, + location=Location(), +) + op_ty_results = IR.Type[res,] + operands = Value[operandA..., operandB..., operandC...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("shape", shape), + namedattribute("layoutA", layoutA), + namedattribute("layoutB", layoutB), + ] + push!( + attributes, + operandsegmentsizes([length(operandA), length(operandB), length(operandC)]), + ) + !isnothing(b1Op) && push!(attributes, namedattribute("b1Op", b1Op)) + !isnothing(intOverflowBehavior) && + push!(attributes, namedattribute("intOverflowBehavior", intOverflowBehavior)) + !isnothing(multiplicandAPtxType) && + push!(attributes, namedattribute("multiplicandAPtxType", multiplicandAPtxType)) + !isnothing(multiplicandBPtxType) && + push!(attributes, namedattribute("multiplicandBPtxType", multiplicandBPtxType)) + + return create_operation( + "nvvm.mma.sync", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function prefetch_tensormap( + tmaDescriptor::Value, predicate=nothing::Union{Nothing,Value}; location=Location() +) + op_ty_results = IR.Type[] + operands = Value[tmaDescriptor,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(predicate) && push!(operands, predicate) + + return create_operation( + "nvvm.prefetch.tensormap", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function rcp_approx_ftz_f(arg::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[arg,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.rcp.approx.ftz.f", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`redux_sync` + +`redux.sync` performs a reduction operation `kind` of the 32 bit source +register across all non-exited threads in the membermask. + +The `abs` and `nan` attributes can be used in the case of f32 input type, +where the `abs` attribute causes the absolute value of the input to be used +in the reduction operation, and the `nan` attribute causes the reduction +operation to return NaN if any of the inputs to participating threads are +NaN. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync) +""" +function redux_sync( + val::Value, + mask_and_clamp::Value; + res::IR.Type, + kind, + abs=nothing, + nan=nothing, + location=Location(), +) + op_ty_results = IR.Type[res,] + operands = Value[val, mask_and_clamp] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("kind", kind),] + !isnothing(abs) && push!(attributes, namedattribute("abs", abs)) + !isnothing(nan) && push!(attributes, namedattribute("nan", nan)) + + return create_operation( + "nvvm.redux.sync", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function setmaxregister(; regCount, action, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("regCount", regCount), namedattribute("action", action) + ] + + return create_operation( + "nvvm.setmaxregister", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`shfl_sync` + +The `shfl.sync` Op implements data shuffle within threads of a warp. +The `thread_mask` denotes the threads participating in the Op where +the bit position corresponds to a particular thread’s laneid. +The `offset` specifies a source lane or source lane offset +(depending on `kind`). The `val` is the input value to be copied from +the source. The `mask_and_clamp` contains two packed values specifying +a mask for logically splitting warps into sub-segments and an upper bound +for clamping the source lane index. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync) +""" +function shfl_sync( + thread_mask::Value, + val::Value, + offset::Value, + mask_and_clamp::Value; + res::IR.Type, + kind, + return_value_and_is_valid=nothing, + location=Location(), +) + op_ty_results = IR.Type[res,] + operands = Value[thread_mask, val, offset, mask_and_clamp] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("kind", kind),] + !isnothing(return_value_and_is_valid) && push!( + attributes, + namedattribute("return_value_and_is_valid", return_value_and_is_valid), + ) + + return create_operation( + "nvvm.shfl.sync", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_nsmid(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.nsmid", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_smid(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.smid", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`stmatrix` + +Collectively store one or more matrices across all threads in a warp to the +location indicated by the address operand \$ptr in shared memory. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) +""" +function stmatrix(ptr::Value, sources::Vector{Value}; layout, location=Location()) + op_ty_results = IR.Type[] + operands = Value[ptr, sources...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("layout", layout),] + + return create_operation( + "nvvm.stmatrix", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function bar_warp_sync(mask::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[mask,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.bar.warp.sync", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`tcgen05_alloc` + +The `tcgen05.alloc` Op allocates tensor core memory for +the amount specified by `nCols` and writes the destination +address to the `addr` argument. The `nCols` operand specifies the +number of columns to be allocated and it must be a power-of-two. +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions) +""" +function tcgen05_alloc(addr::Value, nCols::Value; group=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[addr, nCols] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(group) && push!(attributes, namedattribute("group", group)) + + return create_operation( + "nvvm.tcgen05.alloc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`tcgen05_commit` + +The `tcgen05.commit` makes the mbarrier object, specified by +the operand `addr`, track the completion of all the prior +async-tcgen05 operations initiated by the executing thread. +The multicast variants allow signaling on the mbarrier objects +of multiple CTAs within the cluster. Operand `multicastMask`, +when present, specifies the destination CTAs in the cluster such +that each bit position in the 16-bit `multicastMask` operand +corresponds to the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit) +""" +function tcgen05_commit( + addr::Value, + multicastMask=nothing::Union{Nothing,Value}; + group=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[addr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(multicastMask) && push!(operands, multicastMask) + !isnothing(group) && push!(attributes, namedattribute("group", group)) + + return create_operation( + "nvvm.tcgen05.commit", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`tcgen05_cp` + +Instruction tcgen05.cp initiates an asynchronous copy operation from +shared memory to the location specified by the address operand `taddr` +in the Tensor Memory. The 64-bit register operand `smem_desc` specifies +the matrix descriptor representing the source matrix in the shared memory +that needs to be copied. + +# Example +```mlir + nvvm.tcgen05.cp %taddr, %smem_desc { + group = #nvvm.tcgen05_group, + shape = #nvvm.tcgen05_cp_shape, + multicast = #nvvm.tcgen05_cp_multicast, + srcFormat = #nvvm.tcgen05_cp_src_fmt + } +``` +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-cp) +""" +function tcgen05_cp( + taddr::Value, + smem_desc::Value; + shape, + group=nothing, + multicast=nothing, + srcFormat=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[taddr, smem_desc] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("shape", shape),] + !isnothing(group) && push!(attributes, namedattribute("group", group)) + !isnothing(multicast) && push!(attributes, namedattribute("multicast", multicast)) + !isnothing(srcFormat) && push!(attributes, namedattribute("srcFormat", srcFormat)) + + return create_operation( + "nvvm.tcgen05.cp", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`tcgen05_dealloc` + +The `tcgen05.dealloc` Op de-allocates the tensor core memory +specified by `tmemAddr`, which must be from a previous tensor +memory allocation. The `nCols` operand specifies the number +of columns to be de-allocated, and it must be a power-of-two. +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions) +""" +function tcgen05_dealloc(taddr::Value, nCols::Value; group=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[taddr, nCols] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(group) && push!(attributes, namedattribute("group", group)) + + return create_operation( + "nvvm.tcgen05.dealloc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`tcgen05_fence` + +The `tcgen05.fence` orders all prior async tcgen05 operations +with respect to the subsequent tcgen05 and execution ordering operations. +The `tcgen05.fence` orders all subsequent async tcgen05 operations +with respect to the prior tcgen05 and execution ordering operations. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-fence) +""" +function tcgen05_fence(; kind, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("kind", kind),] + + return create_operation( + "nvvm.tcgen05.fence", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`tcgen05_relinquish_alloc_permit` + +The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA +of the executing thread is relinquishing the right to allocate +Tensor Memory. So, it is illegal for a CTA to perform `tcgen05.alloc` +after any of its constituent threads execute `tcgen05.relinquish_alloc_permit`. +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions) +""" +function tcgen05_relinquish_alloc_permit(; group=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(group) && push!(attributes, namedattribute("group", group)) + + return create_operation( + "nvvm.tcgen05.relinquish_alloc_permit", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`tcgen05_shift` + +The `tcgen05.shift` is an asynchronous instruction which initiates +the shifting of 32-byte elements downwards across all the rows, +except the last, by one row. The operand `taddr` specifies the base +address of the matrix in Tensor Memory whose rows must be down shifted. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-shift) +""" +function tcgen05_shift(taddr::Value; group=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[taddr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(group) && push!(attributes, namedattribute("group", group)) + + return create_operation( + "nvvm.tcgen05.shift", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`tcgen05_wait` + +The `tcgen05.wait` causes the executing thread to block until +all prior `tcgen05.ld` operations issued by the executing thread +have completed. Similarly, the `tcgen05.wait` causes the executing +thread to block until all prior `tcgen05.st` operations issued by the +executing thread have completed. +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-wait) +""" +function tcgen05_wait(; kind, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("kind", kind),] + + return create_operation( + "nvvm.tcgen05.wait", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_tid_x(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.tid.x", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_tid_y(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.tid.y", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_tid_z(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.tid.z", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function vote_ballot_sync(mask::Value, pred::Value; res::IR.Type, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[mask, pred] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.vote.ballot.sync", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function wmma_load( + ptr::Value, + stride::Value; + res::IR.Type, + m, + n, + k, + layout, + eltype, + frag, + location=Location(), +) + op_ty_results = IR.Type[res,] + operands = Value[ptr, stride] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("m", m), + namedattribute("n", n), + namedattribute("k", k), + namedattribute("layout", layout), + namedattribute("eltype", eltype), + namedattribute("frag", frag), + ] + + return create_operation( + "nvvm.wmma.load", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function wmma_mma( + args::Vector{Value}; + res::IR.Type, + m, + n, + k, + layoutA, + layoutB, + eltypeA, + eltypeB, + location=Location(), +) + op_ty_results = IR.Type[res,] + operands = Value[args...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("m", m), + namedattribute("n", n), + namedattribute("k", k), + namedattribute("layoutA", layoutA), + namedattribute("layoutB", layoutB), + namedattribute("eltypeA", eltypeA), + namedattribute("eltypeB", eltypeB), + ] + + return create_operation( + "nvvm.wmma.mma", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function wmma_store( + ptr::Value, + args::Vector{Value}, + stride::Value; + m, + n, + k, + layout, + eltype, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ptr, args..., stride] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("m", m), + namedattribute("n", n), + namedattribute("k", k), + namedattribute("layout", layout), + namedattribute("eltype", eltype), + ] + + return create_operation( + "nvvm.wmma.store", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_nwarpid(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.nwarpid", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_warpid(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.warpid", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function read_ptx_sreg_warpsize(; res::IR.Type, range=nothing, location=Location()) + op_ty_results = IR.Type[res,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(range) && push!(attributes, namedattribute("range", range)) + + return create_operation( + "nvvm.read.ptx.sreg.warpsize", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`wgmma_fence_aligned` + +Enforce an ordering of register accesses between warpgroup level matrix +multiplication and other operations. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence) +""" +function wgmma_fence_aligned(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.wgmma.fence.aligned", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`wgmma_commit_group_sync_aligned` + +Commits all prior uncommitted warpgroup level matrix multiplication operations. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group) +""" +function wgmma_commit_group_sync_aligned(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "nvvm.wgmma.commit.group.sync.aligned", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`wgmma_mma_async` + +The warpgroup (128 threads) level matrix multiply and accumulate operation +has either of the following forms, where matrix D is called accumulator: + D = A * B + D + D = A * B, where the input from accumulator D is disabled. + +Supported shapes: +``` +|--------------|--------------|------------|--------------|---------------| +| | | | |f16+=e4m3*e4m3 | +| | | | |f16+=e5m2*e5m2 | +|f32+=tf32*tf32|f16+=f16 *f16 | s32+=s8*s8 |s32 += b1 * b1|f16+=e5m2*e4m3 | +| |f32+=f16 *f16 | s32+=u8*u8 | |f16+=e4m3*e5m2 | +| |f32+=bf16*bf16| s32+=u8*u8 | |f16+=e4m3*e5m2 | +| |f32+=bf16*bf16| s32+=s8*u8 | |f32+=e4m3*e4m3 | +| | | s32+=u8*s8 | |f32+=e5m2*e5m2 | +| | | | |f32+=e4m3*e5m2 | +| | | | |f32+=e4m3*e5m2 | +|--------------|--------------|------------|--------------|---------------| +| .m64n8k8 | .m64n8k16 | .m64n8k32 | .m64n8k256 | .m64n8k32 | +| .m64n16k8 | .m64n16k16 | .m64n16k32 | .m64n16k256 | .m64n16k32 | +| .m64n24k8 | .m64n24k16 | .m64n24k32 | .m64n24k256 | .m64n24k32 | +| .m64n32k8 | .m64n32k16 | .m64n32k32 | .m64n32k256 | .m64n32k32 | +| .m64n40k8 | .m64n40k16 | .m64n48k32 | .m64n48k256 | .m64n40k32 | +| .m64n48k8 | .m64n48k16 | .m64n64k32 | .m64n64k256 | .m64n48k32 | +| .m64n56k8 | .m64n56k16 | .m64n80k32 | .m64n80k256 | .m64n56k32 | +| .m64n64k8 | .m64n64k16 | .m64n96k32 | .m64n96k256 | .m64n64k32 | +| .m64n72k8 | .m64n72k16 | .m64n112k32| .m64n112k256 | .m64n72k32 | +| .m64n80k8 | .m64n80k16 | .m64n128k32| .m64n128k256 | .m64n80k32 | +| .m64n88k8 | .m64n88k16 | .m64n144k32| .m64n144k256 | .m64n88k32 | +| .m64n96k8 | .m64n96k16 | .m64n160k32| .m64n160k256 | .m64n96k32 | +| .m64n104k8 | .m64n104k16 | .m64n176k32| .m64n176k256 | .m64n104k32 | +| .m64n112k8 | .m64n112k16 | .m64n192k32| .m64n192k256 | .m64n112k32 | +| .m64n120k8 | .m64n120k16 | .m64n208k32| .m64n208k256 | .m64n120k32 | +| .m64n128k8 | .m64n128k16 | .m64n224k32| .m64n224k256 | .m64n128k32 | +| .m64n136k8 | .m64n136k16 | .m64n240k32| .m64n240k256 | .m64n136k32 | +| .m64n144k8 | .m64n144k16 | .m64n256k32| .m64n256k256 | .m64n144k32 | +| .m64n152k8 | .m64n152k16 | | | .m64n152k32 | +| .m64n160k8 | .m64n160k16 | | | .m64n160k32 | +| .m64n168k8 | .m64n168k16 | | | .m64n168k32 | +| .m64n176k8 | .m64n176k16 | | | .m64n176k32 | +| .m64n184k8 | .m64n184k16 | | | .m64n184k32 | +| .m64n192k8 | .m64n192k16 | | | .m64n192k32 | +| .m64n200k8 | .m64n200k16 | | | .m64n200k32 | +| .m64n208k8 | .m64n208k16 | | | .m64n208k32 | +| .m64n216k8 | .m64n216k16 | | | .m64n216k32 | +| .m64n224k8 | .m64n224k16 | | | .m64n224k32 | +| .m64n232k8 | .m64n232k16 | | | .m64n232k32 | +| .m64n240k8 | .m64n240k16 | | | .m64n240k32 | +| .m64n248k8 | .m64n248k16 | | | .m64n248k32 | +| .m64n256k8 | .m64n256k16 | | | .m64n256k32 | +|--------------|--------------|------------|--------------|---------------| +``` + + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) +""" +function wgmma_mma_async( + inouts::Value, + descriptorA::Value, + descriptorB::Value; + results::IR.Type, + shape, + typeA, + typeB, + typeD, + scaleD, + scaleA, + scaleB, + layoutA, + layoutB, + satfinite=nothing, + location=Location(), +) + op_ty_results = IR.Type[results,] + operands = Value[inouts, descriptorA, descriptorB] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("shape", shape), + namedattribute("typeA", typeA), + namedattribute("typeB", typeB), + namedattribute("typeD", typeD), + namedattribute("scaleD", scaleD), + namedattribute("scaleA", scaleA), + namedattribute("scaleB", scaleB), + namedattribute("layoutA", layoutA), + namedattribute("layoutB", layoutB), + ] + !isnothing(satfinite) && push!(attributes, namedattribute("satfinite", satfinite)) + + return create_operation( + "nvvm.wgmma.mma_async", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`wgmma_wait_group_sync_aligned` + +Signal the completion of a preceding warpgroup operation. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group) +""" +function wgmma_wait_group_sync_aligned(; group, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("group", group),] + + return create_operation( + "nvvm.wgmma.wait.group.sync.aligned", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # nvvm diff --git a/src/mlir/Dialects/Shardy.jl b/src/mlir/Dialects/Shardy.jl new file mode 100644 index 0000000000..00b91b354b --- /dev/null +++ b/src/mlir/Dialects/Shardy.jl @@ -0,0 +1,713 @@ +module sdy +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +""" +`all_gather` + +Gathers chunks of a tensor along axes specified in `gathering_axes`. + +The `gathering_axes` is a list of lists of axes. The outer list is over the +dimensions of the tensor. Each inner list specifies the axes along which a +separate gather should be performed on the respective dimension. It will be +applied to the sharding of the operand (`tensor`) to obtain the sharding of +the result (`out_sharding`). + +Note that `out_sharding` is not used to determine the sharding of the +result. Instead, the sharding of the result is determined by the sharding of +the operand and the `gathering_axes`, and `out_sharding` must match this +inferred sharding. + +# Example +```mlir +%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{\"a\", \"b\", \"c\"}, {}, {\"d\"}\\]>]>} : tensor<8x8x8xf32> +%2 = sdy.all_gather [{\"b\", \"c\"}, {}, {\"d\"}\\] %1 out_sharding=<@mesh, [{\"a\"}, {}, {}\\]> : tensor<8x8x8xf32> +``` + +**Constraints:** +- Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`. +- Elements in `gathering_axes` must satisfy the constraints listed in + `AxisRefListAttr`. +- Applying `gathering_axes` to the operand sharding gets `out_sharding`. +""" +function all_gather( + tensor::Value; + result=nothing::Union{Nothing,IR.Type}, + gathering_axes, + out_sharding, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[tensor,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("gathering_axes", gathering_axes), + namedattribute("out_sharding", out_sharding), + ] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "sdy.all_gather", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`all_reduce` + +Reduces chunks of a tensor along axes specified in `reduction_axes`. +The order of `reduction_axes` is not important for the result, but can +affect the order of the corresponding replica groups. + +**Constraints:** +- Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`. +- `reduction_axes` must satisfy the constraints listed in `AxisRefListAttr`; +- `reduction_axes` must not overlap with the operand sharding axes; +""" +function all_reduce( + tensor::Value; + result=nothing::Union{Nothing,IR.Type}, + reduction_axes, + out_sharding, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[tensor,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("reduction_axes", reduction_axes), + namedattribute("out_sharding", out_sharding), + ] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "sdy.all_reduce", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`all_slice` + +Slices chunks of a tensor along axes specified in `slicing_axes`. There is +an algebric duality between `sdy.all_slice` and `sdy.all_gather`. + +The `slicing_axes` is a list of lists of axes. The outer list is over the +dimensions of the tensor. Each inner list specifies the axes along which a +slice should be performed on the respective dimension. It will be applied to +the sharding of the operand (`tensor`) to obtain the sharding of the result +(`out_sharding`). + +Note that `out_sharding` is not used to determine the sharding of the +result. Instead, the sharding of the result is determined by the sharding of +the operand and the `slicing_axes`, and `out_sharding` must match this +inferred sharding. + +# Example +```mlir +%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{\"a\"}, {}, {}\\]>]>} : tensor<8x8x8xf32> +%2 = sdy.all_slice [{\"b\", \"c\"}, {}, {\"d\"}\\] %1 out_sharding=<@mesh, [{\"a\", \"b\", \"c\"}, {}, {\"d\"}\\]> : tensor<8x8x8xf32> +``` + +**Constraints:** +- Elements in `slicing_axes` must satisfy the constraints listed in + `AxisRefListAttr`. +- Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`. +- Applying `slicing_axes` to the operand sharding gets `out_sharding`. +""" +function all_slice( + tensor::Value; + result=nothing::Union{Nothing,IR.Type}, + slicing_axes, + out_sharding, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[tensor,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("slicing_axes", slicing_axes), + namedattribute("out_sharding", out_sharding), + ] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "sdy.all_slice", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`all_to_all` + +Slices chunks of a tensor along dimension `tgt_dim` and axes specified in +`axes`, scatteres those chunks along the axes, and concatenates them along +dimension `src_dim`. + +This operation is essentially a combination of an all-gather along `src_dim` +and `axes`, followed by an all-slice along `tgt_dim` and `axes`, i.e., a +suffix of the axes sharding dimension `src_dim` on the input tensor is +appended to the axes sharding dimension `tgt_dim` on the output tensor. + +The all-to-all will be applied to the sharding of the operand (`tensor`) to +obtain the sharding of the result (`out_sharding`). + +Note that `out_sharding` is not used to determine the sharding of the +result. Instead, the sharding of the result is determined by the sharding of +the operand, `src_dim`, `tgt_dim`, and `axes`, and `out_sharding` must match +this inferred sharding. + +# Example +```mlir +%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{\"a\", \"b\", \"c\"}, {}\\]>]>} : tensor<8x8xf32> +%2 = sdy.all_to_all {\"b\", \"c\"} 0->1 %1 out_sharding=<@mesh, [{\"a\"}, {\"b\", \"c\"}\\]> : tensor<8x8xf32> +``` + +**Constraints:** +- Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`. +- `axes` must satisfy the constraints listed in `AxisRefListAttr`. +- `src_dim` and `tgt_dim` must be valid dimensions (positive and less than + rank of tensor), and different from each other. +- Moving `axes` from `src_dim` to `tgt_dim` in the operand sharding gets + `out_sharding`. +""" +function all_to_all( + tensor::Value; + result=nothing::Union{Nothing,IR.Type}, + src_dim, + tgt_dim, + axes, + out_sharding, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[tensor,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("src_dim", src_dim), + namedattribute("tgt_dim", tgt_dim), + namedattribute("axes", axes), + namedattribute("out_sharding", out_sharding), + ] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "sdy.all_to_all", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`collective_permute` + +Sends a chunk of the input tensor from each device to another to +reorder/replace the axes that shard the tensor. + +A collective permute can transform the input sharding such that each +dimension must be as sharded as it was before, i.e., it must be sharded +along axes whose product of sizes matches that of the axes that previously +sharded the tensor. + +This is useful for reordering axes in a single dimension or across different +dimensions, and swapping sharded axes with replicated ones. + +In the below example, the sharded tensor size is `tensor<1x4x2xf32>`, and +that is preserved by the collective permute. + +# Example +```mlir +sdy.mesh @mesh = <[\"a\"=2, \"b\"=2, \"c\"=4, \"d\"=2, \"e\"=2, \"f\"=2]> +%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{\"a\", \"c\"}, {\"f\"}, {\"d\", \"e\"}\\]>]>} : tensor<8x8x8xf32> +%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{\"c\":(1)2, \"b\", \"f\"}, {\"a\"}, {\"e\", \"d\"}\\]> : tensor<8x8x8xf32> +``` + +**Constraints:** +- Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`. +- If input and output sharding have different meshes, then those meshes must + have exactly the same axes and different order of device ids. +- For each dimension, the product of sharding axis sizes in `out_sharding` + must match that of the corresponding operand dimension sharding. +""" +function collective_permute( + tensor::Value; result=nothing::Union{Nothing,IR.Type}, out_sharding, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[tensor,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("out_sharding", out_sharding),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "sdy.collective_permute", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`constant` + +Produces an `output` tensor from a constant `value`. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant + +NOTE: SDY defines its own constant op that isn\'t ConstantLike and doesn\'t +have a folder, so that we\'ll be able to duplicate constants without any +greedy pattern rewriter folding them back into a single constant. In this +way, constants can be sharded differently for every use, and no propagation +is done between constants (or constant expressions). + +# Example +```mlir +%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> +``` +""" +function constant(; output=nothing::Union{Nothing,IR.Type}, value, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("value", value),] + !isnothing(output) && push!(op_ty_results, output) + + return create_operation( + "sdy.constant", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`data_flow_edge` + +A data flow edge of some op X defines a bridge between a set of sources +(each is either an operand of X or an operand of X\'s block terminator) and +a set of targets (each is either a result of X or a block argument of X), +such that all sources and targets should be sharded in the same way. + +An op can have multiple data flow edges that are orthogonal to one another. + +For example: + +```mlir + y_0, ..., y_n = while (x_0, ..., x_n) + ((pred_arg_0,... , pred_arg_n) { ... }) + ((body_arg_0,..., body_arg_n) { + ... + return return_value_0, ..., return_value_n + }) +``` + +This while op has n data flow edges, the i-th data flow edges is between +sources `x_i`, `return_value_i` and targets `y_i`, `pred_arg_i`, +`body_arg_i`. + +An `sdy.data_flow_edge` takes as input the owner of an edge (can be +any of the targets, but preferably an op result rather than a block +argument), which shouldn\'t have any other uses. This op isn\'t pure because +it can take an input that originally didn\'t have any uses. + +The `sdy.data_flow_edge` also holds an optional sharding for all targets of +the edge, and that sharding should be updated instead of the targets\' +sharding (if can be attached) during propagation. This is useful when an op +has many edges, as it\'s much more efficient to: +- propagate through each edge separately. +- update the sharding of each edge separately instead of all targets at once + (e.g. an op has a single immutable `TensorShardingPerValueAttr` for result + shardings). +- add each edge to the worklist separately when the sharding of a source has + changed. + +Propagation will propagate shardings between all sources and targets of a +`sdy.data_flow_edge` as if it was a regular op with the sources as operands +and targets as results, and an identity `sdy.op_sharding_rule`. That means +that forward propagation is from sources to targets and backwards +propagation is from targets to sources. + +We don\'t allow the input of a `sdy.data_flow_edge` to be defined by an +`SdyDialect` op, so we can assume that it\'s defined by an op that has +unregistered `sdy.sharding` attribute. + +NOTE: it\'s NOT the responsibility of the `sdy.data_flow_edge` to link +between sources and targets, it\'s simply attached to the owner of the edge. +The op that this edge is bound to (while in the example above) is +responsible for providing this information. +""" +function data_flow_edge( + input::Value; + result=nothing::Union{Nothing,IR.Type}, + sharding=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(sharding) && push!(attributes, namedattribute("sharding", sharding)) + + return create_operation( + "sdy.data_flow_edge", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`manual_computation` + +Jump into a region written in terms of per-device local code with explicit +collectives, where logical shapes match local per-device physical buffer +shapes and collectives correspond exactly to physical cross-device +communication. + +The body is local wrt the manual_axes. Propagation will occur through +the body on any free axes - those not in the manual_axes list. + +**Constraints:** +- Elements in `in_shardings` and `out_shardings` must satisfy the constraints listed in `TensorShardingAttr`. +- The number of global and local tensor inputs/outputs of the op region must match. +- The manual axes must come before any free axes in each dim sharding. +- The global and local shapes of the op regions arguments/results must match. +- No manual axes are split. +""" +function manual_computation( + tensors::Vector{Value}; + results::Vector{IR.Type}, + in_shardings, + out_shardings, + manual_axes, + body::Region, + location=Location(), +) + op_ty_results = IR.Type[results...,] + operands = Value[tensors...,] + owned_regions = Region[body,] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("in_shardings", in_shardings), + namedattribute("out_shardings", out_shardings), + namedattribute("manual_axes", manual_axes), + ] + + return create_operation( + "sdy.manual_computation", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mesh` + +Defines a new named mesh. All meshes in a module must have the same number +of devices (except for meshes with a single device_id). +The mesh is a `Symbol` operation that appears in the module\'s +`SymbolTable` and can be referenced by its `name`. +""" +function mesh(; sym_name, mesh, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("sym_name", sym_name), namedattribute("mesh", mesh) + ] + + return create_operation( + "sdy.mesh", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`named_computation` + +Groups a computation, i.e. a block of operations, and gives it a name. +Propagation will flow in/out of the region as if everything was inlined. + +This can be used to handle propagating through call instructions to other +functions. Any users of Shardy should write an import/export pass that +converts their call ops to `sdy.named_computation` ops, duplicating/copying +the body of the called function into the body of the `named_computation`. + +The type of each block arguments and returned values in the region must be +the same as the type of the operands and results type of the op. + +# Example + +```mlir +%1 = sdy.named_computation<\"foo\">(%0) (%arg1: tensor<16x32xf32>) { + sdy.return %arg1 : tensor<16x32xf32> +} : (tensor<16x32xf32>) -> tensor<16x32xf32> +``` +""" +function named_computation( + operands::Vector{Value}; + result_0::Vector{IR.Type}, + name, + in_shardings=nothing, + out_shardings=nothing, + body::Region, + location=Location(), +) + op_ty_results = IR.Type[result_0...,] + operands = Value[operands...,] + owned_regions = Region[body,] + successors = Block[] + attributes = NamedAttribute[namedattribute("name", name),] + !isnothing(in_shardings) && + push!(attributes, namedattribute("in_shardings", in_shardings)) + !isnothing(out_shardings) && + push!(attributes, namedattribute("out_shardings", out_shardings)) + + return create_operation( + "sdy.named_computation", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`propagation_barrier` + +This op operates like an identity op, outputting the same value it took as +input. But in terms of propagation, this will only allow propagation to flow +through it in a certain direction. + +This prevents shardings from being propagated between the uses of the result +of the barrier op and its operand. + +- `FORWARD` means shardings can only flow from the operand to the result. +- `BACKWARD` means shardings can only flow from the result to the operand. +- `NONE` means no sharding can propagate through this op. +- Cannot specify `BOTH`, as this op would be redundant. +""" +function propagation_barrier( + input::Value; + result=nothing::Union{Nothing,IR.Type}, + allowed_direction, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("allowed_direction", allowed_direction),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "sdy.propagation_barrier", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`reshard` + +Reshards the input tensor with the specified sharding, which is different +from the input tensor\'s existing sharding. + +Both ShardingConstraintOp and ReshardOp attach a sharding to a tensor. Their +lifespan is: +1. Before sharding propagation, ShardingConstraintOp is added by users. +2. Sharding propagation consumes ShardingConstraintOp. There is no + ShardingConstraintOp in the results of sharding propagation. Instead, + ReshardOp may be added if needed. +3. A partitioner converts a ReshardOp into a collective op (or an identity + op). There should be no ReshardOp in the results of the partitioner. + + // TODO(b/331680067). Add a canonicalization pattern to remove redundant + // reshard ops. +""" +function reshard( + input::Value; result=nothing::Union{Nothing,IR.Type}, sharding, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("sharding", sharding),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "sdy.reshard", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function return_(results::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[results...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "sdy.return", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`sharding_constraint` + +Attaches a sharding to an intermediate tensor (e.g. the result of a matmul) +to indicate that this is how that tensor, or a subset of its uses, should be +sharded. + +If the sharding has open dimensions and unconstraint axes, it means the +tensor can be further sharded along the open dimensions. + +This op can either: +- Have no uses (dangling) - which means the attached sharding is how the + input tensor itself should be sharded. +- Have uses - which means the attached sharding is how the uses of the + sharding constraint op should be sharded, while other uses of the input + tensor might have a different sharding (if the input tensor has no other + uses then the behavior is the same as the no uses case). +""" +function sharding_constraint( + input::Value; result=nothing::Union{Nothing,IR.Type}, sharding, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("sharding", sharding),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "sdy.sharding_constraint", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`sharding_group` + +This op provides an interface to assign tensors to sharding groups ( +groups of tensors that will be enforced to have identical shardings). +During propagation, as soon as one group element is sharded, all other +members will be sharded in exactly the same way. This operation takes the +argument group ID and returns no result, but instead modifies the internal +sharding group representation to add the input tensor to the group with the +given ID. +""" +function sharding_group(input::Value; group_id, location=Location()) + op_ty_results = IR.Type[] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("group_id", group_id),] + + return create_operation( + "sdy.sharding_group", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +end # sdy diff --git a/src/mlir/Dialects/StableHLO.jl b/src/mlir/Dialects/StableHLO.jl index 32e229f9c8..f7d45ef928 100755 --- a/src/mlir/Dialects/StableHLO.jl +++ b/src/mlir/Dialects/StableHLO.jl @@ -1994,7 +1994,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential ``` """ function exponential( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; + result=nothing::Union{Nothing,IR.Type}, + result_accuracy=nothing, + location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -2002,6 +2005,8 @@ function exponential( successors = Block[] attributes = NamedAttribute[] !isnothing(result) && push!(op_ty_results, result) + !isnothing(result_accuracy) && + push!(attributes, namedattribute("result_accuracy", result_accuracy)) return create_operation( "stablehlo.exponential", diff --git a/src/mlir/Dialects/TPU.jl b/src/mlir/Dialects/TPU.jl new file mode 100644 index 0000000000..46a72c9567 --- /dev/null +++ b/src/mlir/Dialects/TPU.jl @@ -0,0 +1,1364 @@ +module tpu +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +function all_reduce( + input::Value; output=nothing::Union{Nothing,IR.Type}, dim, kind, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dim", dim), namedattribute("kind", kind)] + !isnothing(output) && push!(op_ty_results, output) + + return create_operation( + "tpu.all_reduce", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function sem_alloc(; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.sem_alloc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function assume_layout(input::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.assume_layout", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function assume_multiple( + value::Value; result=nothing::Union{Nothing,IR.Type}, multiple, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[value,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("multiple", multiple),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tpu.assume_multiple", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function bitcast(input::Value; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.bitcast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function bitcast_vreg(input::Value; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.bitcast_vreg", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`broadcast_in_sublanes` + +For each sublane `i`, broadcasts the value in lane `lane + i` along the entire +sublane. If `lane + i` is not in [0, lane_count), then the value in sublane `i` +is not defined (can be anything). +""" +function broadcast_in_sublanes(source::Value; output::IR.Type, lane, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("lane", lane),] + + return create_operation( + "tpu.broadcast_in_sublanes", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function concatenate( + sources::Vector{Value}; output::IR.Type, dimension, location=Location() +) + op_ty_results = IR.Type[output,] + operands = Value[sources...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + + return create_operation( + "tpu.concatenate", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function create_mask( + low::Vector{Value}, high::Vector{Value}; output::IR.Type, location=Location() +) + op_ty_results = IR.Type[output,] + operands = Value[low..., high...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.create_mask", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`create_subelement_mask` + +The \"half-sublanes\", \"quarter-sublanes\", etc. (unit is determined by +the type of `output`) of the mask are masked in the range specified by +`from` and `to`. + +- If `from <= to`, the range `[from, to)` is set and the rest is unset. +- If `to <= from`, the range `[to, from)` is unset and the rest is set. + +All lanes are set identically. + +# Example + +```mlir +%msk = tpu.create_subelement_mask 3, 9 : vector<8x128x2xi1> +``` + +This creates a mask `%msk` where, for all `lane`s, `%msk[*][lane][*]` is: + +``` +[[0, 0], [0, 1], [1, 1], [1, 1], [1, 0], [0, 0], [0, 0], [0, 0]] +``` + +It is currently only supported: +- In TPU v4, for `num_subelems` of 1 and 2. +- In TPU v5, for `num_subelems` of 1, 2, and 4. +""" +function create_subelement_mask(; output::IR.Type, from, to, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("from", from), namedattribute("to", to)] + + return create_operation( + "tpu.create_subelement_mask", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function delay(nanos::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[nanos,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.delay", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function device_id(; result=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tpu.device_id", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function dynamic_gather( + source::Value, indices::Value; output::IR.Type, dimension, location=Location() +) + op_ty_results = IR.Type[output,] + operands = Value[source, indices] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + + return create_operation( + "tpu.dynamic_gather", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function dynamic_rotate( + value::Value, + amount::Value; + result::IR.Type, + dimension, + stride=nothing, + stride_dimension=nothing, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[value, amount] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension),] + !isnothing(stride) && push!(attributes, namedattribute("stride", stride)) + !isnothing(stride_dimension) && + push!(attributes, namedattribute("stride_dimension", stride_dimension)) + + return create_operation( + "tpu.dynamic_rotate", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function enqueue_dma( + source::Value, + source_semaphore=nothing::Union{Nothing,Value}; + target::Value, + target_semaphore::Value, + device_id=nothing::Union{Nothing,Value}, + core_id=nothing::Union{Nothing,Value}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[source, target, target_semaphore] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(source_semaphore) && push!(operands, source_semaphore) + !isnothing(device_id) && push!(operands, device_id) + !isnothing(core_id) && push!(operands, core_id) + push!( + attributes, + operandsegmentsizes([ + 1, + (source_semaphore == nothing) ? 0 : 11, + 1, + if (device_id == nothing) + 0 + elseif 1(core_id == nothing) + 0 + else + 1 + end, + ]), + ) + + return create_operation( + "tpu.enqueue_dma", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function erase_memref_layout(operand::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[operand,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.erase_memref_layout", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function fptosi(input::Value; output::IR.Type, rounding_mode, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("rounding_mode", rounding_mode),] + + return create_operation( + "tpu.fptosi", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function gather(source::Value; output::IR.Type, indices, dimension, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("indices", indices), namedattribute("dimension", dimension) + ] + + return create_operation( + "tpu.gather", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function sem_barrier(; semaphore::IR.Type, location=Location()) + op_ty_results = IR.Type[semaphore,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.sem_barrier", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function internal_scratch(; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.internal_scratch", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function iteration_bound(; result=nothing::Union{Nothing,IR.Type}, dim, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dim", dim),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tpu.iteration_bound", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function iota(; output::IR.Type, dimension=nothing, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(dimension) && push!(attributes, namedattribute("dimension", dimension)) + + return create_operation( + "tpu.iota", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function load( + base::Value, + indices::Vector{Value}; + result::IR.Type, + sublane_mask, + sublane_stride=nothing, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[base, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("sublane_mask", sublane_mask),] + !isnothing(sublane_stride) && + push!(attributes, namedattribute("sublane_stride", sublane_stride)) + + return create_operation( + "tpu.load", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function log_buffer(input::Value; shape, tag, location=Location()) + op_ty_results = IR.Type[] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("shape", shape), namedattribute("tag", tag)] + + return create_operation( + "tpu.log_buffer", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function log(inputs::Vector{Value}; tag, formatted=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("tag", tag),] + !isnothing(formatted) && push!(attributes, namedattribute("formatted", formatted)) + + return create_operation( + "tpu.log", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function mask_cast(input::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.mask_cast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function matmul( + lhs::Value, + rhs::Value, + acc::Value; + result::IR.Type, + transpose_lhs=nothing, + transpose_rhs=nothing, + precision=nothing, + dimension_numbers=nothing, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[lhs, rhs, acc] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(transpose_lhs) && + push!(attributes, namedattribute("transpose_lhs", transpose_lhs)) + !isnothing(transpose_rhs) && + push!(attributes, namedattribute("transpose_rhs", transpose_rhs)) + !isnothing(precision) && push!(attributes, namedattribute("precision", precision)) + !isnothing(dimension_numbers) && + push!(attributes, namedattribute("dimension_numbers", dimension_numbers)) + + return create_operation( + "tpu.matmul", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function memref_bitcast(input::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.memref_bitcast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function memref_reshape(input::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.memref_reshape", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function memref_slice( + mem_ref::Value, + base_idx::Vector{Value}, + dynamic_sizes::Vector{Value}; + result::IR.Type, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[mem_ref, base_idx..., dynamic_sizes...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + push!(attributes, operandsegmentsizes([1, length(base_idx), length(dynamic_sizes)])) + + return create_operation( + "tpu.memref_slice", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function memref_squeeze(input::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.memref_squeeze", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function prng_random_bits(; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.prng_random_bits", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function prng_set_seed_32(seeds::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[seeds...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.prng_set_seed_32", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function pack_vmsk(low::Value, high::Value; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[low, high] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.pack_vmsk", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function pack_subelements( + sources::Vector{Value}; output::IR.Type, positions, pack_format, location=Location() +) + op_ty_results = IR.Type[output,] + operands = Value[sources...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("positions", positions), namedattribute("pack_format", pack_format) + ] + + return create_operation( + "tpu.pack_subelements", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function region(; results::Vector{IR.Type}, region::Region, location=Location()) + op_ty_results = IR.Type[results...,] + operands = Value[] + owned_regions = Region[region,] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.region", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function reinterpret_cast(input::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.reinterpret_cast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function relayout(input::Value; output=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(output) && push!(op_ty_results, output) + + return create_operation( + "tpu.relayout", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function repeat(source::Value; output::IR.Type, dimension, times, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("dimension", dimension), namedattribute("times", times) + ] + + return create_operation( + "tpu.repeat", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function roll_vectors(input::Vector{Value}; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[input...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.roll_vectors", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function rotate( + value::Value; + result=nothing::Union{Nothing,IR.Type}, + amount, + dimension, + stride=nothing, + stride_dimension=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[value,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("amount", amount), namedattribute("dimension", dimension) + ] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(stride) && push!(attributes, namedattribute("stride", stride)) + !isnothing(stride_dimension) && + push!(attributes, namedattribute("stride_dimension", stride_dimension)) + + return create_operation( + "tpu.rotate", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function sem_read( + semaphore::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[semaphore,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tpu.sem_read", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function sem_signal( + semaphore::Value, + amount::Value, + device_id=nothing::Union{Nothing,Value}; + core_id=nothing::Union{Nothing,Value}, + core_type=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[semaphore, amount] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(device_id) && push!(operands, device_id) + !isnothing(core_id) && push!(operands, core_id) + push!(attributes, operandsegmentsizes([ + 1, + 1, + if (device_id == nothing) + 0 + elseif 1(core_id == nothing) + 0 + else + 1 + end, + ])) + !isnothing(core_type) && push!(attributes, namedattribute("core_type", core_type)) + + return create_operation( + "tpu.sem_signal", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function sem_wait(semaphore::Value, amount::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[semaphore, amount] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.sem_wait", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function shuffled_load( + base::Value, + indices::Vector{Value}; + result::IR.Type, + sublane_mask, + sublane_offsets, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[base, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("sublane_mask", sublane_mask), + namedattribute("sublane_offsets", sublane_offsets), + ] + + return create_operation( + "tpu.shuffled_load", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function shuffled_store( + valueToStore::Value, + base::Value, + indices::Vector{Value}; + sublane_mask, + sublane_offsets, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[valueToStore, base, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("sublane_mask", sublane_mask), + namedattribute("sublane_offsets", sublane_offsets), + ] + + return create_operation( + "tpu.shuffled_store", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function store( + valueToStore::Value, + base::Value, + indices::Vector{Value}, + mask=nothing::Union{Nothing,Value}; + sublane_mask, + sublane_stride=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[valueToStore, base, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("sublane_mask", sublane_mask),] + !isnothing(mask) && push!(operands, mask) + push!( + attributes, operandsegmentsizes([1, 1, length(indices), (mask == nothing) ? 0 : 1]) + ) + !isnothing(sublane_stride) && + push!(attributes, namedattribute("sublane_stride", sublane_stride)) + + return create_operation( + "tpu.store", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function strided_load( + base::Value, indices::Vector{Value}; result::IR.Type, strides, location=Location() +) + op_ty_results = IR.Type[result,] + operands = Value[base, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("strides", strides),] + + return create_operation( + "tpu.strided_load", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function strided_store( + valueToStore::Value, base::Value, indices::Vector{Value}; strides, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[valueToStore, base, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("strides", strides),] + + return create_operation( + "tpu.strided_store", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function trace(; + results::Vector{IR.Type}, message, level, region::Region, location=Location() +) + op_ty_results = IR.Type[results...,] + operands = Value[] + owned_regions = Region[region,] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("message", message), namedattribute("level", level) + ] + + return create_operation( + "tpu.trace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function trace_start(; message, level, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("message", message), namedattribute("level", level) + ] + + return create_operation( + "tpu.trace_start", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function trace_stop(; location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.trace_stop", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function unpack_subelements( + source::Value; output::IR.Type, index, pack_format, location=Location() +) + op_ty_results = IR.Type[output,] + operands = Value[source,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("index", index), namedattribute("pack_format", pack_format) + ] + + return create_operation( + "tpu.unpack_subelements", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function unroll_vectors(input::Value; output::Vector{IR.Type}, location=Location()) + op_ty_results = IR.Type[output...,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.unroll_vectors", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function vector_store( + valueToStore::Value, + base::Value, + indices::Vector{Value}, + mask=nothing::Union{Nothing,Value}; + strides, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[valueToStore, base, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("strides", strides),] + !isnothing(mask) && push!(operands, mask) + push!( + attributes, operandsegmentsizes([1, 1, length(indices), (mask == nothing) ? 0 : 1]) + ) + + return create_operation( + "tpu.vector_store", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function wait_dma2(semaphore::Value, src::Value, dst::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[semaphore, src, dst] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.wait_dma2", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function wait_dma(semaphore::Value, ref::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[semaphore, ref] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.wait_dma", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function weird(input::Value; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[input,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.weird", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function yield(results::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[results...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.yield", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # tpu diff --git a/src/mlir/Dialects/Triton.jl b/src/mlir/Dialects/Triton.jl new file mode 100755 index 0000000000..36122eeee4 --- /dev/null +++ b/src/mlir/Dialects/Triton.jl @@ -0,0 +1,1532 @@ +module tt +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +""" +`call` + +The `tt.call` operation represents a direct call to a function that is +within the same symbol scope as the call. The operands and result types of +the call must match the specified function type. The callee is encoded as a +symbol reference attribute named \"callee\". + +# Example + +```mlir +%2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32 +``` +""" +function call( + operands::Vector{Value}; result_0::Vector{IR.Type}, callee, location=Location() +) + op_ty_results = IR.Type[result_0...,] + operands = Value[operands...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("callee", callee),] + + return create_operation( + "tt.call", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`func` + +Operations within the function cannot implicitly capture values defined +outside of the function, i.e. Functions are `IsolatedFromAbove`. All +external references must use function arguments or attributes that establish +a symbolic connection (e.g. symbols referenced by name via a string +attribute like SymbolRefAttr). An external function declaration (used when +referring to a function declared in some other module) has no body. While +the MLIR textual form provides a nice inline syntax for function arguments, +they are internally represented as “block arguments” to the first block in +the region. + +Only dialect attribute names may be specified in the attribute dictionaries +for function arguments, results, or the function itself. + +# Example + +```mlir +// External function definitions. +tt.func @abort() +tt.func @scribble(i32, i64, memref) -> f64 + +// A function that returns its argument twice: +tt.func @count(%x: i64) -> (i64, i64) + attributes {fruit: \"banana\"} { + return %x, %x: i64, i64 +} + +// A function with an argument attribute +tt.func @example_fn_arg(%x: i32 {swift.self = unit}) + +// A function with a result attribute +tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) + +// A function with an attribute +tt.func @example_fn_attr() attributes {dialectName.attrName = false} +``` +""" +function func(; + sym_name, + function_type, + sym_visibility=nothing, + arg_attrs=nothing, + res_attrs=nothing, + body::Region, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[body,] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("sym_name", sym_name), namedattribute("function_type", function_type) + ] + !isnothing(sym_visibility) && + push!(attributes, namedattribute("sym_visibility", sym_visibility)) + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + + return create_operation( + "tt.func", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`reinterpret_tensor_descriptor` + +This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects. +Ideally, we can remove this once the APIs are fully fleshed out. +""" +function reinterpret_tensor_descriptor(rawDesc::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[rawDesc,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.reinterpret_tensor_descriptor", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`return_` + +The `tt.return` operation represents a return operation within a function. +The operation takes variable number of operands and produces no results. +The operand number and types must match the signature of the function +that contains the operation. + +# Example + +```mlir +tt.func @foo() : (i32, f8) { + ... + tt.return %0, %1 : i32, f8 +} +``` +""" +function return_(srcs::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[srcs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.return", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function addptr(ptr::Value, offset::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[ptr, offset] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.addptr", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function advance(ptr::Value, offsets::Vector{Value}; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[ptr, offsets...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.advance", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`assert` + +`tt.assert` takes a condition tensor and a message string. +If the condition is false, the message is printed, and the program is aborted. +""" +function assert(condition::Value; message, location=Location()) + op_ty_results = IR.Type[] + operands = Value[condition,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("message", message),] + + return create_operation( + "tt.assert", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`atomic_cas` + +compare \$cmp with data \$old at location \$ptr, + +if \$old == \$cmp, store \$val to \$ptr, + +else store \$old to \$ptr, + +return \$old +""" +function atomic_cas( + ptr::Value, cmp::Value, val::Value; result::IR.Type, sem, scope, location=Location() +) + op_ty_results = IR.Type[result,] + operands = Value[ptr, cmp, val] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("sem", sem), namedattribute("scope", scope)] + + return create_operation( + "tt.atomic_cas", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`atomic_rmw` + +load data at \$ptr, do \$rmw_op with \$val, and store result to \$ptr. + +return old value at \$ptr +""" +function atomic_rmw( + ptr::Value, + val::Value, + mask=nothing::Union{Nothing,Value}; + result::IR.Type, + atomic_rmw_op, + sem, + scope, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[ptr, val] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("atomic_rmw_op", atomic_rmw_op), + namedattribute("sem", sem), + namedattribute("scope", scope), + ] + !isnothing(mask) && push!(operands, mask) + + return create_operation( + "tt.atomic_rmw", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function bitcast(src::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.bitcast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`broadcast` + +For a given tensor, broadcast changes one or more dimensions with size 1 +to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot +change the size of a non-1 dimension. +""" +function broadcast(src::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.broadcast", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function cat(lhs::Value, rhs::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.cat", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`clampf` + +Clamp operation for floating point types. + +The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max]. +""" +function clampf( + x::Value, + min::Value, + max::Value; + result=nothing::Union{Nothing,IR.Type}, + propagateNan, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[x, min, max] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("propagateNan", propagateNan),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tt.clampf", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`dot` + +\$d = matrix_multiply(\$a, \$b) + \$c. \$inputPrecision describes how to exercise the TC +when the inputs are f32. It can be one of: tf32, tf32x3, ieee. +tf32: use TC with tf32 ops. +tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp +ieee: don\'t use TC, implement dot in software. +If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. +""" +function dot( + a::Value, + b::Value, + c::Value; + d=nothing::Union{Nothing,IR.Type}, + inputPrecision=nothing, + maxNumImpreciseAcc=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[a, b, c] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(d) && push!(op_ty_results, d) + !isnothing(inputPrecision) && + push!(attributes, namedattribute("inputPrecision", inputPrecision)) + !isnothing(maxNumImpreciseAcc) && + push!(attributes, namedattribute("maxNumImpreciseAcc", maxNumImpreciseAcc)) + + return create_operation( + "tt.dot", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`dot_scaled` + +\$d = matrix_multiply(scale(\$lhs, \$lhs_scale), scale(rlhs, \$rhs_scale)) + \$c. +Where scale(x, s) is a function that applies the scale per block following microscaling spec. +""" +function dot_scaled( + lhs::Value, + rhs::Value, + c::Value, + lhs_scale=nothing::Union{Nothing,Value}; + rhs_scale=nothing::Union{Nothing,Value}, + d::IR.Type, + lhs_type, + rhs_type, + fastMath, + location=Location(), +) + op_ty_results = IR.Type[d,] + operands = Value[lhs, rhs, c] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("lhs_type", lhs_type), + namedattribute("rhs_type", rhs_type), + namedattribute("fastMath", fastMath), + ] + !isnothing(lhs_scale) && push!(operands, lhs_scale) + !isnothing(rhs_scale) && push!(operands, rhs_scale) + push!(attributes, operandsegmentsizes([ + 1, + 1, + 1, + if (lhs_scale == nothing) + 0 + elseif 1(rhs_scale == nothing) + 0 + else + 1 + end, + ])) + + return create_operation( + "tt.dot_scaled", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`elementwise_inline_asm` + +Runs an inline asm block to generate one or more tensors. + +The asm block is given `packed_element` elements at a time. Exactly which +elems it receives is unspecified. +""" +function elementwise_inline_asm( + args::Vector{Value}; + result::Vector{IR.Type}, + asm_string, + constraints, + pure, + packed_element, + location=Location(), +) + op_ty_results = IR.Type[result...,] + operands = Value[args...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("asm_string", asm_string), + namedattribute("constraints", constraints), + namedattribute("pure", pure), + namedattribute("packed_element", packed_element), + ] + + return create_operation( + "tt.elementwise_inline_asm", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function expand_dims( + src::Value; result=nothing::Union{Nothing,IR.Type}, axis, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("axis", axis),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tt.expand_dims", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`experimental_descriptor_gather` + +The `tt.experimental_desciptor_gather` op will be lowered to NVIDIA TMA +load operations on targets that support it. + +`desc_ptr` is a pointer to the TMA descriptor allocated in global memory. +The descriptor block must have 1 row and the indices must be a 1D tensor. +Accordingly, the result is a 2D tensor multiple rows. + +This is an escape hatch and is only there for testing/experimenting. This +op will be removed in the future. +""" +function experimental_descriptor_gather( + desc::Value, x_offsets::Value, y_offset::Value; result::IR.Type, location=Location() +) + op_ty_results = IR.Type[result,] + operands = Value[desc, x_offsets, y_offset] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.experimental_descriptor_gather", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`experimental_descriptor_load` + +This operation will be lowered to Nvidia TMA load operation on targets supporting it. +`desc` is a tensor descriptor object. +The destination tensor type and shape must match the descriptor otherwise the result is undefined. + +This is an escape hatch and is only there for testing/experimenting. +This op will be removed in the future. +""" +function experimental_descriptor_load( + desc::Value, + indices::Vector{Value}; + result::IR.Type, + cache=nothing, + evict=nothing, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[desc, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(cache) && push!(attributes, namedattribute("cache", cache)) + !isnothing(evict) && push!(attributes, namedattribute("evict", evict)) + + return create_operation( + "tt.experimental_descriptor_load", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function experimental_descriptor_scatter( + desc::Value, x_offsets::Value, y_offset::Value, src::Value; location=Location() +) + op_ty_results = IR.Type[] + operands = Value[desc, x_offsets, y_offset, src] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.experimental_descriptor_scatter", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`experimental_descriptor_store` + +This operation will be lowered to Nvidia TMA store operation on targets supporting it. +`desc` is a tensor descriptor object. +The shape and types of `src` must match the descriptor otherwise the result is undefined. + +This is an escape hatch and is only there for testing/experimenting. +This op will be removed in the future. +""" +function experimental_descriptor_store( + desc::Value, src::Value, indices::Vector{Value}; location=Location() +) + op_ty_results = IR.Type[] + operands = Value[desc, src, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.experimental_descriptor_store", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function experimental_tensormap_create( + desc_ptr::Value, + global_address::Value, + box_dim::Vector{Value}, + global_dim::Vector{Value}, + global_stride::Vector{Value}, + element_stride::Vector{Value}; + elem_type, + interleave_layout, + swizzle_mode, + fill_mode, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ + desc_ptr, + global_address, + box_dim..., + global_dim..., + global_stride..., + element_stride..., + ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("elem_type", elem_type), + namedattribute("interleave_layout", interleave_layout), + namedattribute("swizzle_mode", swizzle_mode), + namedattribute("fill_mode", fill_mode), + ] + push!( + attributes, + operandsegmentsizes([ + 1, + 1, + length(box_dim), + length(global_dim), + length(global_stride), + length(element_stride), + ]), + ) + + return create_operation( + "tt.experimental_tensormap_create", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function experimental_tensormap_fenceproxy_acquire(desc_ptr::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[desc_ptr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.experimental_tensormap_fenceproxy_acquire", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`extern_elementwise` + +call an external function \$symbol implemented in \$libpath/\$libname with \$args +return \$libpath/\$libname:\$symbol(\$args...) +""" +function extern_elementwise( + srcs::Vector{Value}; + result::IR.Type, + libname, + libpath, + symbol, + pure, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[srcs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("libname", libname), + namedattribute("libpath", libpath), + namedattribute("symbol", symbol), + namedattribute("pure", pure), + ] + + return create_operation( + "tt.extern_elementwise", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`fp_to_fp` + +Floating point casting for custom types (F8), and non-default rounding modes. + +F8 <-> FP16, BF16, FP32, FP64 +""" +function fp_to_fp(src::Value; result::IR.Type, rounding=nothing, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(rounding) && push!(attributes, namedattribute("rounding", rounding)) + + return create_operation( + "tt.fp_to_fp", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`gather` + +Gather elements from the input tensor using the indices tensor along a +single specified axis. The output tensor has the same shape as the indices +tensor. The input and indices tensors must have the same number of +dimension, and each dimension of the indices tensor that is not the gather +dimension cannot be greater than the corresponding dimension in the input +tensor. + +The `efficient_layout` attribute is set when the compiler has determined an +optimized layout for the operation, indicating that it should not be +changed. +""" +function gather( + src::Value, + indices::Value; + result=nothing::Union{Nothing,IR.Type}, + axis, + efficient_layout=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[src, indices] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("axis", axis),] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(efficient_layout) && + push!(attributes, namedattribute("efficient_layout", efficient_layout)) + + return create_operation( + "tt.gather", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function get_num_programs(; + result=nothing::Union{Nothing,IR.Type}, axis, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("axis", axis),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tt.get_num_programs", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function get_program_id(; result=nothing::Union{Nothing,IR.Type}, axis, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("axis", axis),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tt.get_program_id", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`histogram` + +Return the histogram of the input tensor. The number of bins is equal to +the dimension of the output tensor. Each bins has a width of 1 and bins +start at 0. +""" +function histogram(src::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.histogram", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function int_to_ptr(src::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.int_to_ptr", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`join` + +For example, if the two input tensors are 4x8xf32, returns a tensor of +shape 4x8x2xf32. + +Because Triton tensors always have a power-of-two number of elements, +the two input tensors must have the same shape. +""" +function join( + lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tt.join", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function load( + ptr::Value, + mask=nothing::Union{Nothing,Value}; + other=nothing::Union{Nothing,Value}, + result=nothing::Union{Nothing,IR.Type}, + boundaryCheck=nothing, + padding=nothing, + cache=nothing, + evict=nothing, + isVolatile=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ptr,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(mask) && push!(operands, mask) + !isnothing(other) && push!(operands, other) + push!(attributes, operandsegmentsizes([ + 1, + if (mask == nothing) + 0 + elseif 1(other == nothing) + 0 + else + 1 + end, + ])) + !isnothing(result) && push!(op_ty_results, result) + !isnothing(boundaryCheck) && + push!(attributes, namedattribute("boundaryCheck", boundaryCheck)) + !isnothing(padding) && push!(attributes, namedattribute("padding", padding)) + !isnothing(cache) && push!(attributes, namedattribute("cache", cache)) + !isnothing(evict) && push!(attributes, namedattribute("evict", evict)) + !isnothing(isVolatile) && push!(attributes, namedattribute("isVolatile", isVolatile)) + + return create_operation( + "tt.load", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`make_range` + +Returns an 1D int32 tensor. + +Values span from \$start to \$end (exclusive), with step = 1 +""" +function make_range(; result::IR.Type, start, end_, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("start", start), namedattribute("end", end_)] + + return create_operation( + "tt.make_range", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`make_tensor_descriptor` + +`tt.make_tensor_descriptor` takes both meta information of the parent tensor and the block size, +and returns a descriptor object which can be used to load/store from the tensor in global memory. +""" +function make_tensor_descriptor( + base::Value, + shape::Vector{Value}, + strides::Vector{Value}; + result::IR.Type, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[base, shape..., strides...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.make_tensor_descriptor", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`make_tensor_ptr` + +`tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a +pointer to the block tensor, e.g. returns a type of `tt.ptr>`. +""" +function make_tensor_ptr( + base::Value, + shape::Vector{Value}, + strides::Vector{Value}, + offsets::Vector{Value}; + result::IR.Type, + order, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[base, shape..., strides..., offsets...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("order", order),] + + return create_operation( + "tt.make_tensor_ptr", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mulhiui` + +Most significant N bits of the 2N-bit product of two integers. +""" +function mulhiui( + x::Value, y::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[x, y] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tt.mulhiui", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`precise_divf` + +Precise div for floating point types. +""" +function precise_divf( + x::Value, y::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[x, y] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tt.precise_divf", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`precise_sqrt` + +Precise sqrt for floating point types. +""" +function precise_sqrt(x::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[x,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tt.precise_sqrt", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +""" +`print` + +`tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. +format are generated automatically from the arguments. +""" +function print(args::Vector{Value}; prefix, hex, isSigned, location=Location()) + op_ty_results = IR.Type[] + operands = Value[args...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("prefix", prefix), + namedattribute("hex", hex), + namedattribute("isSigned", isSigned), + ] + + return create_operation( + "tt.print", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function ptr_to_int(src::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.ptr_to_int", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function reduce( + srcs::Vector{Value}; + result::Vector{IR.Type}, + axis, + combineOp::Region, + location=Location(), +) + op_ty_results = IR.Type[result...,] + operands = Value[srcs...,] + owned_regions = Region[combineOp,] + successors = Block[] + attributes = NamedAttribute[namedattribute("axis", axis),] + + return create_operation( + "tt.reduce", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function reduce_return(result::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[result...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.reduce.return", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`reshape` + +reinterpret a tensor to a different shape. + +If allow_reorder is set the compiler is free to change the order of +elements to generate more efficient code. + +If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. +The compiler is still free to change it for better performance. +""" +function reshape( + src::Value; + result::IR.Type, + allow_reorder=nothing, + efficient_layout=nothing, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(allow_reorder) && + push!(attributes, namedattribute("allow_reorder", allow_reorder)) + !isnothing(efficient_layout) && + push!(attributes, namedattribute("efficient_layout", efficient_layout)) + + return create_operation( + "tt.reshape", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function scan( + srcs::Vector{Value}; + result::Vector{IR.Type}, + axis, + reverse, + combineOp::Region, + location=Location(), +) + op_ty_results = IR.Type[result...,] + operands = Value[srcs...,] + owned_regions = Region[combineOp,] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("axis", axis), namedattribute("reverse", reverse) + ] + + return create_operation( + "tt.scan", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function scan_return(result::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[result...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.scan.return", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function splat(src::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tt.splat", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`split` + +The input must be a tensor whose last dimension has size 2. Returns two +tensors, src[..., 0] and src[..., 1]. + +For example, if the input shape is 4x8x2xf32, returns two tensors of +shape 4x8xf32. +""" +function split( + src::Value; + outLHS=nothing::Union{Nothing,IR.Type}, + outRHS=nothing::Union{Nothing,IR.Type}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(outLHS) && push!(op_ty_results, outLHS) + !isnothing(outRHS) && push!(op_ty_results, outRHS) + + return create_operation( + "tt.split", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +function store( + ptr::Value, + value::Value, + mask=nothing::Union{Nothing,Value}; + boundaryCheck=nothing, + cache=nothing, + evict=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[ptr, value] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(mask) && push!(operands, mask) + !isnothing(boundaryCheck) && + push!(attributes, namedattribute("boundaryCheck", boundaryCheck)) + !isnothing(cache) && push!(attributes, namedattribute("cache", cache)) + !isnothing(evict) && push!(attributes, namedattribute("evict", evict)) + + return create_operation( + "tt.store", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`trans` + +For example, given a tensor x with shape [1,2,4], transpose(x) with +order=[2,0,1] rearranges the tensor to have shape [4,1,2]. + +Although this op is called \"trans\", it implements both tl.trans() and +tl.permute(). (\"permute\" might be a better name, but it\'s called \"trans\" +because originally it only supported 2D tensors.) + +## Implementation note on encodings: + +In the TritonGPU dialect (and probably others), an encoding is chosen for +this op\'s output so it\'s a nop from the perspective of code generation. + +For example, suppose tensor x has an encoding such that GPU thread [i,j,k] +has a register containing element [i,j,k] of the tensor. Now we transpose +x with order [2,1,0], i.e. we reverse the order of its dimensions. In +TritonGPU, we will choose a layout for the output of the transpose so that +GPU thread [i,j,k] has element [k,j,i] of transpose(x). But this is the +same element it had before! All we\'ve done is \"rename\" the element that +thread [i,j,k] has. + +The \"real\" transpose -- i.e. moving data between GPU threads -- occurs in +convertLayout ops that appear before and/or after the operation. + +We do this so that you can chain multiple data-movement ops (e.g. +transpose+reshape+concat) without going to shared memory after each one. +""" +function trans( + src::Value; result=nothing::Union{Nothing,IR.Type}, order, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("order", order),] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "tt.trans", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + +end # tt diff --git a/src/mlir/Dialects/VHLO.jl b/src/mlir/Dialects/VHLO.jl index dbc7463d6b..1f706fcba8 100755 --- a/src/mlir/Dialects/VHLO.jl +++ b/src/mlir/Dialects/VHLO.jl @@ -1452,6 +1452,27 @@ function exponential_v1(operand::Value; result::IR.Type, location=Location()) ) end +function exponential_v2( + operand::Value; result::IR.Type, result_accuracy, location=Location() +) + op_ty_results = IR.Type[result,] + operands = Value[operand,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("result_accuracy", result_accuracy),] + + return create_operation( + "vhlo.exponential_v2", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function exponential_minus_one_v1(operand::Value; result::IR.Type, location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] diff --git a/src/mlir/IR/Attribute.jl b/src/mlir/IR/Attribute.jl index d37e7c9862..d7aac00830 100644 --- a/src/mlir/IR/Attribute.jl +++ b/src/mlir/IR/Attribute.jl @@ -588,7 +588,16 @@ function DenseElementsAttribute(values::AbstractArray{Float64}) ) end -# TODO mlirDenseElementsAttrBFloat16Get +if isdefined(Core, :BFloat16) + function DenseElementsAttribute(values::AbstractArray{Core.BFloat16}) + shaped_type = TensorType(size(values), Type(Core.BFloat16)) + return Attribute( + API.mlirDenseElementsAttrBFloat16Get( + shaped_type, length(values), to_row_major(values) + ), + ) + end +end function DenseElementsAttribute(values::AbstractArray{Float16}) shaped_type = TensorType(size(values), Type(Float16)) @@ -599,7 +608,7 @@ function DenseElementsAttribute(values::AbstractArray{Float16}) ) end -function DenseElementsAttribute(values::AbstractArray{<:Complex}) +function DenseElementsAttribute(values::AbstractArray) shaped_type = TensorType(size(values), Type(eltype(values))) return Attribute( API.mlirDenseElementsAttrRawBufferGet( @@ -698,6 +707,9 @@ function DenseArrayAttribute end @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Int8}; context::Context=context() ) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values))) +@llvmversioned min = v"16" DenseArrayAttribute( + values::AbstractArray{UInt8}; context::Context=context() +) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values))) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Int16}; context::Context=context() ) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), to_row_major(values))) diff --git a/src/mlir/IR/IR.jl b/src/mlir/IR/IR.jl index 044f27e5af..8da48846fc 100644 --- a/src/mlir/IR/IR.jl +++ b/src/mlir/IR/IR.jl @@ -1,5 +1,6 @@ module IR +using ..Reactant using ..API # do not export `Type`, as it is already defined in Core diff --git a/src/mlir/IR/Type.jl b/src/mlir/IR/Type.jl index c77224d687..3ac576e39c 100644 --- a/src/mlir/IR/Type.jl +++ b/src/mlir/IR/Type.jl @@ -168,6 +168,15 @@ Creates an f16 type in the given context. The type is owned by the context. """ Type(::Core.Type{Float16}; context::Context=context()) = Type(API.mlirF16TypeGet(context)) +if isdefined(Core, :BFloat16) + """ + Type(::Core.Type{Core.BFloat16}; context=context()) + + Creates an bf16 type in the given context. The type is owned by the context. + """ + Type(::Core.Type{Core.BFloat16}; context::Context=context()) = BFloat16Type(; context) +end + """ Type(Core.Type{Float32}; context=context()) @@ -182,6 +191,51 @@ Creates a f64 type in the given context. The type is owned by the context. """ Type(::Core.Type{Float64}; context::Context=context()) = Type(API.mlirF64TypeGet(context)) +""" + Type(::Core.Type{Reactant.F8E5M2}; context=context()) + +Creates a f8e5m2 type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{<:Reactant.F8E5M2}; context::Context=context()) + return Type(API.mlirFloat8E5M2TypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E4M3FN}; context=context()) + +Creates a f8e4m3fn type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{<:Reactant.F8E4M3FN}; context::Context=context()) + return Type(API.mlirFloat8E4M3FNTypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E4M3B11FNUZ}; context=context()) + +Creates a f8e4m3b11fnuz type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{<:Reactant.F8E4M3B11FNUZ}; context::Context=context()) + return Type(API.mlirFloat8E4M3B11FNUZTypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E5M2FNUZ}; context=context()) + +Creates a f8e5m2fnuz type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{<:Reactant.F8E5M2FNUZ}; context::Context=context()) + return Type(API.mlirFloat8E5M2FNUZTypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E4M3FNUZ}; context=context()) + +Creates a f8e4m3fnuz type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{<:Reactant.F8E4M3FNUZ}; context::Context=context()) + return Type(API.mlirFloat8E4M3FNUZTypeGet(context)) +end + """ isf8e5m2(type) @@ -196,6 +250,27 @@ Checks whether the given type is an f8E4M3FN type. """ isf8e4m3fn(type::Type) = API.mlirTypeIsAFloat8E4M3FN(type) +""" + isf8e4m3b11fnuz(type) + +Checks whether the given type is an f8E4M3B11FNUZ type. +""" +isf8e4m3b11fnuz(type::Type) = API.mlirTypeIsAFloat8E4M3B11FNUZ(type) + +""" + isf8e5m2fnuz(type) + +Checks whether the given type is an f8E5M2FNUZ type. +""" +isf8e5m2fnuz(type::Type) = API.mlirTypeIsAFloat8E5M2FNUZ(type) + +""" + isf8e4m3fnuz(type) + +Checks whether the given type is an f8E4M3FNUZ type. +""" +isf8e4m3fnuz(type::Type) = API.mlirTypeIsAFloat8E4M3FNUZ(type) + """ isbf16(type) @@ -721,12 +796,24 @@ function julia_type(type::Type) throw("could not convert unsigned $width-bit integer type to julia") end end + elseif isbf16(type) + Core.BFloat16 elseif isf16(type) Float16 elseif isf32(type) Float32 elseif isf64(type) Float64 + elseif isf8e5m2(type) + Reactant.F8E5M2 + elseif isf8e4m3fn(type) + Reactant.F8E4M3FN + elseif isf8e4m3b11fnuz(type) + Reactant.F8E4M3B11FNUZ + elseif isf8e5m2fnuz(type) + Reactant.F8E5M2FNUZ + elseif isf8e4m3fnuz(type) + Reactant.F8E4M3FNUZ elseif isnone(type) Nothing elseif iscomplex(type) diff --git a/src/mlir/IR/Value.jl b/src/mlir/IR/Value.jl index b3292be12f..a24632d934 100644 --- a/src/mlir/IR/Value.jl +++ b/src/mlir/IR/Value.jl @@ -8,8 +8,8 @@ struct Value end Base.convert(::Core.Type{API.MlirValue}, value::Value) = value.value -Base.size(value::Value) = Base.size(Type(value)) -Base.ndims(value::Value) = Base.ndims(Type(value)) +Base.size(value::Value) = Base.size(type(value)) +Base.ndims(value::Value) = Base.ndims(type(value)) """ ==(value1, value2) diff --git a/src/mlir/MLIR.jl b/src/mlir/MLIR.jl index 6bbf3cad44..aec0ac27fd 100644 --- a/src/mlir/MLIR.jl +++ b/src/mlir/MLIR.jl @@ -1,5 +1,7 @@ module MLIR +using ..Reactant + module API using CEnum using Preferences diff --git a/src/mlir/libMLIR_h.jl b/src/mlir/libMLIR_h.jl index 5f5c0feeb1..20e3c96d89 100644 --- a/src/mlir/libMLIR_h.jl +++ b/src/mlir/libMLIR_h.jl @@ -610,6 +610,24 @@ function mlirLocationFileLineColGet(context, filename, line, col) )::MlirLocation end +""" + mlirLocationFileLineColRangeGet(context, filename, start_line, start_col, end_line, end_col) + +Creates an File/Line/Column range location owned by the given context. +""" +function mlirLocationFileLineColRangeGet( + context, filename, start_line, start_col, end_line, end_col +) + @ccall mlir_c.mlirLocationFileLineColRangeGet( + context::MlirContext, + filename::MlirStringRef, + start_line::Cuint, + start_col::Cuint, + end_line::Cuint, + end_col::Cuint, + )::MlirLocation +end + """ mlirLocationCallSiteGet(callee, caller) @@ -713,6 +731,17 @@ function mlirModuleCreateParse(context, _module) )::MlirModule end +""" + mlirModuleCreateParseFromFile(context, fileName) + +Parses a module from file and transfers ownership to the caller. +""" +function mlirModuleCreateParseFromFile(context, fileName) + @ccall mlir_c.mlirModuleCreateParseFromFile( + context::MlirContext, fileName::MlirStringRef + )::MlirModule +end + """ mlirModuleGetContext(_module) @@ -4243,6 +4272,12 @@ function mlirDenseElementsAttrGetUInt64Value(attr, pos) )::UInt64 end +function mlirDenseElementsAttrGetIndexValue(attr, pos) + @ccall mlir_c.mlirDenseElementsAttrGetIndexValue( + attr::MlirAttribute, pos::intptr_t + )::UInt64 +end + function mlirDenseElementsAttrGetFloatValue(attr, pos) @ccall mlir_c.mlirDenseElementsAttrGetFloatValue( attr::MlirAttribute, pos::intptr_t @@ -5793,1114 +5828,320 @@ function mlirOpaqueTypeGetData(type) @ccall mlir_c.mlirOpaqueTypeGetData(type::MlirType)::MlirStringRef end -struct MlirPass - ptr::Ptr{Cvoid} -end - -struct MlirExternalPass - ptr::Ptr{Cvoid} -end - -struct MlirPassManager - ptr::Ptr{Cvoid} -end +""" + mlirEnableGlobalDebug(enable) -struct MlirOpPassManager - ptr::Ptr{Cvoid} +Sets the global debugging flag. +""" +function mlirEnableGlobalDebug(enable) + @ccall mlir_c.mlirEnableGlobalDebug(enable::Bool)::Cvoid end """ - mlirPassManagerCreate(ctx) + mlirIsGlobalDebugEnabled() -Create a new top-level PassManager with the default anchor. +Retuns `true` if the global debugging flag is set, false otherwise. """ -function mlirPassManagerCreate(ctx) - @ccall mlir_c.mlirPassManagerCreate(ctx::MlirContext)::MlirPassManager +function mlirIsGlobalDebugEnabled() + @ccall mlir_c.mlirIsGlobalDebugEnabled()::Bool end """ - mlirPassManagerCreateOnOperation(ctx, anchorOp) + mlirSetGlobalDebugType(type) -Create a new top-level PassManager anchored on `anchorOp`. +Sets the current debug type, similarly to `-debug-only=type` in the command-line tools. Note that global debug should be enabled for any output to be produced. """ -function mlirPassManagerCreateOnOperation(ctx, anchorOp) - @ccall mlir_c.mlirPassManagerCreateOnOperation( - ctx::MlirContext, anchorOp::MlirStringRef - )::MlirPassManager +function mlirSetGlobalDebugType(type) + @ccall mlir_c.mlirSetGlobalDebugType(type::Cstring)::Cvoid end """ - mlirPassManagerDestroy(passManager) + mlirSetGlobalDebugTypes(types, n) -Destroy the provided PassManager. +Sets multiple current debug types, similarly to `-debug-only=type1,type2" in the command-line tools. Note that global debug should be enabled for any output to be produced. """ -function mlirPassManagerDestroy(passManager) - @ccall mlir_c.mlirPassManagerDestroy(passManager::MlirPassManager)::Cvoid +function mlirSetGlobalDebugTypes(types, n) + @ccall mlir_c.mlirSetGlobalDebugTypes(types::Ptr{Cstring}, n::intptr_t)::Cvoid end """ - mlirPassManagerIsNull(passManager) + mlirIsCurrentDebugType(type) -Checks if a PassManager is null. +Checks if `type` is set as the current debug type. """ -function mlirPassManagerIsNull(passManager) - @ccall mlir_c.mlirPassManagerIsNull(passManager::MlirPassManager)::Bool +function mlirIsCurrentDebugType(type) + @ccall mlir_c.mlirIsCurrentDebugType(type::Cstring)::Bool end """ - mlirPassManagerGetAsOpPassManager(passManager) + MlirDiagnostic -Cast a top-level PassManager to a generic OpPassManager. +An opaque reference to a diagnostic, always owned by the diagnostics engine (context). Must not be stored outside of the diagnostic handler. """ -function mlirPassManagerGetAsOpPassManager(passManager) - @ccall mlir_c.mlirPassManagerGetAsOpPassManager( - passManager::MlirPassManager - )::MlirOpPassManager +struct MlirDiagnostic + ptr::Ptr{Cvoid} end """ - mlirPassManagerRunOnOp(passManager, op) + MlirDiagnosticSeverity -Run the provided `passManager` on the given `op`. +Severity of a diagnostic. """ -function mlirPassManagerRunOnOp(passManager, op) - @ccall mlir_c.mlirPassManagerRunOnOp( - passManager::MlirPassManager, op::MlirOperation - )::MlirLogicalResult +@cenum MlirDiagnosticSeverity::UInt32 begin + MlirDiagnosticError = 0x0000000000000000 + MlirDiagnosticWarning = 0x0000000000000001 + MlirDiagnosticNote = 0x0000000000000002 + MlirDiagnosticRemark = 0x0000000000000003 end """ - mlirPassManagerEnableIRPrinting(passManager, printBeforeAll, printAfterAll, printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, flags, treePrintingPath) +Opaque identifier of a diagnostic handler, useful to detach a handler. +""" +const MlirDiagnosticHandlerID = UInt64 -Enable IR printing. The treePrintingPath argument is an optional path to a directory where the dumps will be produced. If it isn't provided then dumps are produced to stderr. +# typedef MlirLogicalResult ( * MlirDiagnosticHandler ) ( MlirDiagnostic , void * userData ) """ -function mlirPassManagerEnableIRPrinting( - passManager, - printBeforeAll, - printAfterAll, - printModuleScope, - printAfterOnlyOnChange, - printAfterOnlyOnFailure, - flags, - treePrintingPath, -) - @ccall mlir_c.mlirPassManagerEnableIRPrinting( - passManager::MlirPassManager, - printBeforeAll::Bool, - printAfterAll::Bool, - printModuleScope::Bool, - printAfterOnlyOnChange::Bool, - printAfterOnlyOnFailure::Bool, - flags::MlirOpPrintingFlags, - treePrintingPath::MlirStringRef, - )::Cvoid -end +Diagnostic handler type. Accepts a reference to a diagnostic, which is only guaranteed to be live during the call. The handler is passed the `userData` that was provided when the handler was attached to a context. If the handler processed the diagnostic completely, it is expected to return success. Otherwise, it is expected to return failure to indicate that other handlers should attempt to process the diagnostic. +""" +const MlirDiagnosticHandler = Ptr{Cvoid} """ - mlirPassManagerEnableVerifier(passManager, enable) + mlirDiagnosticPrint(diagnostic, callback, userData) -Enable / disable verify-each. +Prints a diagnostic using the provided callback. """ -function mlirPassManagerEnableVerifier(passManager, enable) - @ccall mlir_c.mlirPassManagerEnableVerifier( - passManager::MlirPassManager, enable::Bool +function mlirDiagnosticPrint(diagnostic, callback, userData) + @ccall mlir_c.mlirDiagnosticPrint( + diagnostic::MlirDiagnostic, callback::MlirStringCallback, userData::Ptr{Cvoid} )::Cvoid end """ - mlirPassManagerGetNestedUnder(passManager, operationName) + mlirDiagnosticGetLocation(diagnostic) -Nest an OpPassManager under the top-level PassManager, the nested passmanager will only run on operations matching the provided name. The returned OpPassManager will be destroyed when the parent is destroyed. To further nest more OpPassManager under the newly returned one, see `mlirOpPassManagerNest` below. +Returns the location at which the diagnostic is reported. """ -function mlirPassManagerGetNestedUnder(passManager, operationName) - @ccall mlir_c.mlirPassManagerGetNestedUnder( - passManager::MlirPassManager, operationName::MlirStringRef - )::MlirOpPassManager +function mlirDiagnosticGetLocation(diagnostic) + @ccall mlir_c.mlirDiagnosticGetLocation(diagnostic::MlirDiagnostic)::MlirLocation end """ - mlirOpPassManagerGetNestedUnder(passManager, operationName) + mlirDiagnosticGetSeverity(diagnostic) -Nest an OpPassManager under the provided OpPassManager, the nested passmanager will only run on operations matching the provided name. The returned OpPassManager will be destroyed when the parent is destroyed. +Returns the severity of the diagnostic. """ -function mlirOpPassManagerGetNestedUnder(passManager, operationName) - @ccall mlir_c.mlirOpPassManagerGetNestedUnder( - passManager::MlirOpPassManager, operationName::MlirStringRef - )::MlirOpPassManager +function mlirDiagnosticGetSeverity(diagnostic) + @ccall mlir_c.mlirDiagnosticGetSeverity( + diagnostic::MlirDiagnostic + )::MlirDiagnosticSeverity end """ - mlirPassManagerAddOwnedPass(passManager, pass) + mlirDiagnosticGetNumNotes(diagnostic) -Add a pass and transfer ownership to the provided top-level mlirPassManager. If the pass is not a generic operation pass or a ModulePass, a new OpPassManager is implicitly nested under the provided PassManager. +Returns the number of notes attached to the diagnostic. """ -function mlirPassManagerAddOwnedPass(passManager, pass) - @ccall mlir_c.mlirPassManagerAddOwnedPass( - passManager::MlirPassManager, pass::MlirPass - )::Cvoid +function mlirDiagnosticGetNumNotes(diagnostic) + @ccall mlir_c.mlirDiagnosticGetNumNotes(diagnostic::MlirDiagnostic)::intptr_t end """ - mlirOpPassManagerAddOwnedPass(passManager, pass) + mlirDiagnosticGetNote(diagnostic, pos) -Add a pass and transfer ownership to the provided mlirOpPassManager. If the pass is not a generic operation pass or matching the type of the provided PassManager, a new OpPassManager is implicitly nested under the provided PassManager. +Returns `pos`-th note attached to the diagnostic. Expects `pos` to be a valid zero-based index into the list of notes. """ -function mlirOpPassManagerAddOwnedPass(passManager, pass) - @ccall mlir_c.mlirOpPassManagerAddOwnedPass( - passManager::MlirOpPassManager, pass::MlirPass - )::Cvoid +function mlirDiagnosticGetNote(diagnostic, pos) + @ccall mlir_c.mlirDiagnosticGetNote( + diagnostic::MlirDiagnostic, pos::intptr_t + )::MlirDiagnostic end """ - mlirOpPassManagerAddPipeline(passManager, pipelineElements, callback, userData) + mlirContextAttachDiagnosticHandler(context, handler, userData, deleteUserData) -Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager. If parsing fails an error message is reported using the provided callback. +Attaches the diagnostic handler to the context. Handlers are invoked in the reverse order of attachment until one of them processes the diagnostic completely. When a handler is invoked it is passed the `userData` that was provided when it was attached. If non-NULL, `deleteUserData` is called once the system no longer needs to call the handler (for instance after the handler is detached or the context is destroyed). Returns an identifier that can be used to detach the handler. """ -function mlirOpPassManagerAddPipeline(passManager, pipelineElements, callback, userData) - @ccall mlir_c.mlirOpPassManagerAddPipeline( - passManager::MlirOpPassManager, - pipelineElements::MlirStringRef, - callback::MlirStringCallback, +function mlirContextAttachDiagnosticHandler(context, handler, userData, deleteUserData) + @ccall mlir_c.mlirContextAttachDiagnosticHandler( + context::MlirContext, + handler::MlirDiagnosticHandler, userData::Ptr{Cvoid}, - )::MlirLogicalResult + deleteUserData::Ptr{Cvoid}, + )::MlirDiagnosticHandlerID end """ - mlirPrintPassPipeline(passManager, callback, userData) + mlirContextDetachDiagnosticHandler(context, id) -Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. +Detaches an attached diagnostic handler from the context given its identifier. """ -function mlirPrintPassPipeline(passManager, callback, userData) - @ccall mlir_c.mlirPrintPassPipeline( - passManager::MlirOpPassManager, callback::MlirStringCallback, userData::Ptr{Cvoid} +function mlirContextDetachDiagnosticHandler(context, id) + @ccall mlir_c.mlirContextDetachDiagnosticHandler( + context::MlirContext, id::MlirDiagnosticHandlerID )::Cvoid end """ - mlirParsePassPipeline(passManager, pipeline, callback, userData) + mlirEmitError(location, message) -Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager. If parsing fails an error message is reported using the provided callback. +Emits an error at the given location through the diagnostics engine. Used for testing purposes. """ -function mlirParsePassPipeline(passManager, pipeline, callback, userData) - @ccall mlir_c.mlirParsePassPipeline( - passManager::MlirOpPassManager, - pipeline::MlirStringRef, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - )::MlirLogicalResult +function mlirEmitError(location, message) + @ccall mlir_c.mlirEmitError(location::MlirLocation, message::Cstring)::Cvoid end -""" - MlirExternalPassCallbacks +function mlirGetDialectHandle__amdgpu__() + @ccall mlir_c.mlirGetDialectHandle__amdgpu__()::MlirDialectHandle +end -Structure of external [`MlirPass`](@ref) callbacks. All callbacks are required to be set unless otherwise specified. +function mlirGetDialectHandle__arith__() + @ccall mlir_c.mlirGetDialectHandle__arith__()::MlirDialectHandle +end -| Field | Note | -| :--------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| construct | This callback is called from the pass is created. This is analogous to a C++ pass constructor. | -| destruct | This callback is called when the pass is destroyed This is analogous to a C++ pass destructor. | -| initialize | This callback is optional. The callback is called before the pass is run, allowing a chance to initialize any complex state necessary for running the pass. See Pass::initialize(MLIRContext *). | -| clone | This callback is called when the pass is cloned. See Pass::clonePass(). | -| run | This callback is called when the pass is run. See Pass::runOnOperation(). | -""" -struct MlirExternalPassCallbacks - construct::Ptr{Cvoid} - destruct::Ptr{Cvoid} - initialize::Ptr{Cvoid} - clone::Ptr{Cvoid} - run::Ptr{Cvoid} +function mlirGetDialectHandle__async__() + @ccall mlir_c.mlirGetDialectHandle__async__()::MlirDialectHandle end -""" - mlirCreateExternalPass(passID, name, argument, description, opName, nDependentDialects, dependentDialects, callbacks, userData) +function mlirGetDialectHandle__cf__() + @ccall mlir_c.mlirGetDialectHandle__cf__()::MlirDialectHandle +end -Creates an external [`MlirPass`](@ref) that calls the supplied `callbacks` using the supplied `userData`. If `opName` is empty, the pass is a generic operation pass. Otherwise it is an operation pass specific to the specified pass name. -""" -function mlirCreateExternalPass( - passID, - name, - argument, - description, - opName, - nDependentDialects, - dependentDialects, - callbacks, - userData, -) - @ccall mlir_c.mlirCreateExternalPass( - passID::MlirTypeID, - name::MlirStringRef, - argument::MlirStringRef, - description::MlirStringRef, - opName::MlirStringRef, - nDependentDialects::intptr_t, - dependentDialects::Ptr{MlirDialectHandle}, - callbacks::MlirExternalPassCallbacks, - userData::Ptr{Cvoid}, - )::MlirPass -end - -""" - mlirExternalPassSignalFailure(pass) - -This signals that the pass has failed. This is only valid to call during the `run` callback of [`MlirExternalPassCallbacks`](@ref). See Pass::signalPassFailure(). -""" -function mlirExternalPassSignalFailure(pass) - @ccall mlir_c.mlirExternalPassSignalFailure(pass::MlirExternalPass)::Cvoid -end - -function mlirRegisterConversionPasses() - @ccall mlir_c.mlirRegisterConversionPasses()::Cvoid -end - -function mlirCreateConversionArithToAMDGPUConversionPass() - @ccall mlir_c.mlirCreateConversionArithToAMDGPUConversionPass()::MlirPass -end - -function mlirRegisterConversionArithToAMDGPUConversionPass() - @ccall mlir_c.mlirRegisterConversionArithToAMDGPUConversionPass()::Cvoid -end - -function mlirCreateConversionArithToArmSMEConversionPass() - @ccall mlir_c.mlirCreateConversionArithToArmSMEConversionPass()::MlirPass -end - -function mlirRegisterConversionArithToArmSMEConversionPass() - @ccall mlir_c.mlirRegisterConversionArithToArmSMEConversionPass()::Cvoid -end - -function mlirCreateConversionArithToLLVMConversionPass() - @ccall mlir_c.mlirCreateConversionArithToLLVMConversionPass()::MlirPass -end - -function mlirRegisterConversionArithToLLVMConversionPass() - @ccall mlir_c.mlirRegisterConversionArithToLLVMConversionPass()::Cvoid -end - -function mlirCreateConversionConvertAMDGPUToROCDL() - @ccall mlir_c.mlirCreateConversionConvertAMDGPUToROCDL()::MlirPass -end - -function mlirRegisterConversionConvertAMDGPUToROCDL() - @ccall mlir_c.mlirRegisterConversionConvertAMDGPUToROCDL()::Cvoid -end - -function mlirCreateConversionConvertAffineForToGPU() - @ccall mlir_c.mlirCreateConversionConvertAffineForToGPU()::MlirPass -end - -function mlirRegisterConversionConvertAffineForToGPU() - @ccall mlir_c.mlirRegisterConversionConvertAffineForToGPU()::Cvoid -end - -function mlirCreateConversionConvertAffineToStandard() - @ccall mlir_c.mlirCreateConversionConvertAffineToStandard()::MlirPass -end - -function mlirRegisterConversionConvertAffineToStandard() - @ccall mlir_c.mlirRegisterConversionConvertAffineToStandard()::Cvoid -end - -function mlirCreateConversionConvertArithToEmitC() - @ccall mlir_c.mlirCreateConversionConvertArithToEmitC()::MlirPass -end - -function mlirRegisterConversionConvertArithToEmitC() - @ccall mlir_c.mlirRegisterConversionConvertArithToEmitC()::Cvoid -end - -function mlirCreateConversionConvertArithToSPIRV() - @ccall mlir_c.mlirCreateConversionConvertArithToSPIRV()::MlirPass -end - -function mlirRegisterConversionConvertArithToSPIRV() - @ccall mlir_c.mlirRegisterConversionConvertArithToSPIRV()::Cvoid -end - -function mlirCreateConversionConvertArmNeon2dToIntr() - @ccall mlir_c.mlirCreateConversionConvertArmNeon2dToIntr()::MlirPass -end - -function mlirRegisterConversionConvertArmNeon2dToIntr() - @ccall mlir_c.mlirRegisterConversionConvertArmNeon2dToIntr()::Cvoid -end - -function mlirCreateConversionConvertArmSMEToLLVM() - @ccall mlir_c.mlirCreateConversionConvertArmSMEToLLVM()::MlirPass -end - -function mlirRegisterConversionConvertArmSMEToLLVM() - @ccall mlir_c.mlirRegisterConversionConvertArmSMEToLLVM()::Cvoid -end - -function mlirCreateConversionConvertArmSMEToSCF() - @ccall mlir_c.mlirCreateConversionConvertArmSMEToSCF()::MlirPass -end - -function mlirRegisterConversionConvertArmSMEToSCF() - @ccall mlir_c.mlirRegisterConversionConvertArmSMEToSCF()::Cvoid -end - -function mlirCreateConversionConvertAsyncToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertAsyncToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertAsyncToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertAsyncToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertBufferizationToMemRef() - @ccall mlir_c.mlirCreateConversionConvertBufferizationToMemRef()::MlirPass -end - -function mlirRegisterConversionConvertBufferizationToMemRef() - @ccall mlir_c.mlirRegisterConversionConvertBufferizationToMemRef()::Cvoid -end - -function mlirCreateConversionConvertComplexToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertComplexToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertComplexToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertComplexToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertComplexToLibm() - @ccall mlir_c.mlirCreateConversionConvertComplexToLibm()::MlirPass -end - -function mlirRegisterConversionConvertComplexToLibm() - @ccall mlir_c.mlirRegisterConversionConvertComplexToLibm()::Cvoid -end - -function mlirCreateConversionConvertComplexToSPIRVPass() - @ccall mlir_c.mlirCreateConversionConvertComplexToSPIRVPass()::MlirPass -end - -function mlirRegisterConversionConvertComplexToSPIRVPass() - @ccall mlir_c.mlirRegisterConversionConvertComplexToSPIRVPass()::Cvoid -end - -function mlirCreateConversionConvertComplexToStandard() - @ccall mlir_c.mlirCreateConversionConvertComplexToStandard()::MlirPass -end - -function mlirRegisterConversionConvertComplexToStandard() - @ccall mlir_c.mlirRegisterConversionConvertComplexToStandard()::Cvoid -end - -function mlirCreateConversionConvertControlFlowToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertControlFlowToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertControlFlowToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertControlFlowToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertControlFlowToSPIRV() - @ccall mlir_c.mlirCreateConversionConvertControlFlowToSPIRV()::MlirPass -end - -function mlirRegisterConversionConvertControlFlowToSPIRV() - @ccall mlir_c.mlirRegisterConversionConvertControlFlowToSPIRV()::Cvoid -end - -function mlirCreateConversionConvertFuncToEmitC() - @ccall mlir_c.mlirCreateConversionConvertFuncToEmitC()::MlirPass -end - -function mlirRegisterConversionConvertFuncToEmitC() - @ccall mlir_c.mlirRegisterConversionConvertFuncToEmitC()::Cvoid -end - -function mlirCreateConversionConvertFuncToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertFuncToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertFuncToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertFuncToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertFuncToSPIRV() - @ccall mlir_c.mlirCreateConversionConvertFuncToSPIRV()::MlirPass -end - -function mlirRegisterConversionConvertFuncToSPIRV() - @ccall mlir_c.mlirRegisterConversionConvertFuncToSPIRV()::Cvoid -end - -function mlirCreateConversionConvertGPUToSPIRV() - @ccall mlir_c.mlirCreateConversionConvertGPUToSPIRV()::MlirPass -end - -function mlirRegisterConversionConvertGPUToSPIRV() - @ccall mlir_c.mlirRegisterConversionConvertGPUToSPIRV()::Cvoid -end - -function mlirCreateConversionConvertGpuLaunchFuncToVulkanLaunchFunc() - @ccall mlir_c.mlirCreateConversionConvertGpuLaunchFuncToVulkanLaunchFunc()::MlirPass -end - -function mlirRegisterConversionConvertGpuLaunchFuncToVulkanLaunchFunc() - @ccall mlir_c.mlirRegisterConversionConvertGpuLaunchFuncToVulkanLaunchFunc()::Cvoid -end - -function mlirCreateConversionConvertGpuOpsToLLVMSPVOps() - @ccall mlir_c.mlirCreateConversionConvertGpuOpsToLLVMSPVOps()::MlirPass -end - -function mlirRegisterConversionConvertGpuOpsToLLVMSPVOps() - @ccall mlir_c.mlirRegisterConversionConvertGpuOpsToLLVMSPVOps()::Cvoid -end - -function mlirCreateConversionConvertGpuOpsToNVVMOps() - @ccall mlir_c.mlirCreateConversionConvertGpuOpsToNVVMOps()::MlirPass -end - -function mlirRegisterConversionConvertGpuOpsToNVVMOps() - @ccall mlir_c.mlirRegisterConversionConvertGpuOpsToNVVMOps()::Cvoid -end - -function mlirCreateConversionConvertGpuOpsToROCDLOps() - @ccall mlir_c.mlirCreateConversionConvertGpuOpsToROCDLOps()::MlirPass -end - -function mlirRegisterConversionConvertGpuOpsToROCDLOps() - @ccall mlir_c.mlirRegisterConversionConvertGpuOpsToROCDLOps()::Cvoid -end - -function mlirCreateConversionConvertIndexToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertIndexToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertIndexToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertIndexToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertIndexToSPIRVPass() - @ccall mlir_c.mlirCreateConversionConvertIndexToSPIRVPass()::MlirPass -end - -function mlirRegisterConversionConvertIndexToSPIRVPass() - @ccall mlir_c.mlirRegisterConversionConvertIndexToSPIRVPass()::Cvoid -end - -function mlirCreateConversionConvertLinalgToStandard() - @ccall mlir_c.mlirCreateConversionConvertLinalgToStandard()::MlirPass -end - -function mlirRegisterConversionConvertLinalgToStandard() - @ccall mlir_c.mlirRegisterConversionConvertLinalgToStandard()::Cvoid -end - -function mlirCreateConversionConvertMathToFuncs() - @ccall mlir_c.mlirCreateConversionConvertMathToFuncs()::MlirPass -end - -function mlirRegisterConversionConvertMathToFuncs() - @ccall mlir_c.mlirRegisterConversionConvertMathToFuncs()::Cvoid -end - -function mlirCreateConversionConvertMathToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertMathToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertMathToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertMathToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertMathToLibm() - @ccall mlir_c.mlirCreateConversionConvertMathToLibm()::MlirPass -end - -function mlirRegisterConversionConvertMathToLibm() - @ccall mlir_c.mlirRegisterConversionConvertMathToLibm()::Cvoid -end - -function mlirCreateConversionConvertMathToROCDL() - @ccall mlir_c.mlirCreateConversionConvertMathToROCDL()::MlirPass -end - -function mlirRegisterConversionConvertMathToROCDL() - @ccall mlir_c.mlirRegisterConversionConvertMathToROCDL()::Cvoid -end - -function mlirCreateConversionConvertMathToSPIRV() - @ccall mlir_c.mlirCreateConversionConvertMathToSPIRV()::MlirPass -end - -function mlirRegisterConversionConvertMathToSPIRV() - @ccall mlir_c.mlirRegisterConversionConvertMathToSPIRV()::Cvoid -end - -function mlirCreateConversionConvertMemRefToEmitC() - @ccall mlir_c.mlirCreateConversionConvertMemRefToEmitC()::MlirPass -end - -function mlirRegisterConversionConvertMemRefToEmitC() - @ccall mlir_c.mlirRegisterConversionConvertMemRefToEmitC()::Cvoid -end - -function mlirCreateConversionConvertMemRefToSPIRV() - @ccall mlir_c.mlirCreateConversionConvertMemRefToSPIRV()::MlirPass -end - -function mlirRegisterConversionConvertMemRefToSPIRV() - @ccall mlir_c.mlirRegisterConversionConvertMemRefToSPIRV()::Cvoid -end - -function mlirCreateConversionConvertMeshToMPIPass() - @ccall mlir_c.mlirCreateConversionConvertMeshToMPIPass()::MlirPass -end - -function mlirRegisterConversionConvertMeshToMPIPass() - @ccall mlir_c.mlirRegisterConversionConvertMeshToMPIPass()::Cvoid -end - -function mlirCreateConversionConvertNVGPUToNVVMPass() - @ccall mlir_c.mlirCreateConversionConvertNVGPUToNVVMPass()::MlirPass -end - -function mlirRegisterConversionConvertNVGPUToNVVMPass() - @ccall mlir_c.mlirRegisterConversionConvertNVGPUToNVVMPass()::Cvoid -end - -function mlirCreateConversionConvertNVVMToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertNVVMToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertNVVMToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertNVVMToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertOpenACCToSCF() - @ccall mlir_c.mlirCreateConversionConvertOpenACCToSCF()::MlirPass -end - -function mlirRegisterConversionConvertOpenACCToSCF() - @ccall mlir_c.mlirRegisterConversionConvertOpenACCToSCF()::Cvoid -end - -function mlirCreateConversionConvertOpenMPToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertOpenMPToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertOpenMPToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertOpenMPToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertPDLToPDLInterp() - @ccall mlir_c.mlirCreateConversionConvertPDLToPDLInterp()::MlirPass -end - -function mlirRegisterConversionConvertPDLToPDLInterp() - @ccall mlir_c.mlirRegisterConversionConvertPDLToPDLInterp()::Cvoid -end - -function mlirCreateConversionConvertParallelLoopToGpu() - @ccall mlir_c.mlirCreateConversionConvertParallelLoopToGpu()::MlirPass -end - -function mlirRegisterConversionConvertParallelLoopToGpu() - @ccall mlir_c.mlirRegisterConversionConvertParallelLoopToGpu()::Cvoid -end - -function mlirCreateConversionConvertSCFToOpenMPPass() - @ccall mlir_c.mlirCreateConversionConvertSCFToOpenMPPass()::MlirPass -end - -function mlirRegisterConversionConvertSCFToOpenMPPass() - @ccall mlir_c.mlirRegisterConversionConvertSCFToOpenMPPass()::Cvoid -end - -function mlirCreateConversionConvertSPIRVToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertSPIRVToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertSPIRVToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertSPIRVToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertShapeConstraints() - @ccall mlir_c.mlirCreateConversionConvertShapeConstraints()::MlirPass -end - -function mlirRegisterConversionConvertShapeConstraints() - @ccall mlir_c.mlirRegisterConversionConvertShapeConstraints()::Cvoid -end - -function mlirCreateConversionConvertShapeToStandard() - @ccall mlir_c.mlirCreateConversionConvertShapeToStandard()::MlirPass -end - -function mlirRegisterConversionConvertShapeToStandard() - @ccall mlir_c.mlirRegisterConversionConvertShapeToStandard()::Cvoid -end - -function mlirCreateConversionConvertTensorToLinalg() - @ccall mlir_c.mlirCreateConversionConvertTensorToLinalg()::MlirPass -end - -function mlirRegisterConversionConvertTensorToLinalg() - @ccall mlir_c.mlirRegisterConversionConvertTensorToLinalg()::Cvoid -end - -function mlirCreateConversionConvertTensorToSPIRV() - @ccall mlir_c.mlirCreateConversionConvertTensorToSPIRV()::MlirPass -end - -function mlirRegisterConversionConvertTensorToSPIRV() - @ccall mlir_c.mlirRegisterConversionConvertTensorToSPIRV()::Cvoid -end - -function mlirCreateConversionConvertToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertToSPIRVPass() - @ccall mlir_c.mlirCreateConversionConvertToSPIRVPass()::MlirPass -end - -function mlirRegisterConversionConvertToSPIRVPass() - @ccall mlir_c.mlirRegisterConversionConvertToSPIRVPass()::Cvoid -end - -function mlirCreateConversionConvertVectorToArmSME() - @ccall mlir_c.mlirCreateConversionConvertVectorToArmSME()::MlirPass -end - -function mlirRegisterConversionConvertVectorToArmSME() - @ccall mlir_c.mlirRegisterConversionConvertVectorToArmSME()::Cvoid -end - -function mlirCreateConversionConvertVectorToGPU() - @ccall mlir_c.mlirCreateConversionConvertVectorToGPU()::MlirPass -end - -function mlirRegisterConversionConvertVectorToGPU() - @ccall mlir_c.mlirRegisterConversionConvertVectorToGPU()::Cvoid -end - -function mlirCreateConversionConvertVectorToLLVMPass() - @ccall mlir_c.mlirCreateConversionConvertVectorToLLVMPass()::MlirPass -end - -function mlirRegisterConversionConvertVectorToLLVMPass() - @ccall mlir_c.mlirRegisterConversionConvertVectorToLLVMPass()::Cvoid -end - -function mlirCreateConversionConvertVectorToSCF() - @ccall mlir_c.mlirCreateConversionConvertVectorToSCF()::MlirPass -end - -function mlirRegisterConversionConvertVectorToSCF() - @ccall mlir_c.mlirRegisterConversionConvertVectorToSCF()::Cvoid -end - -function mlirCreateConversionConvertVectorToSPIRV() - @ccall mlir_c.mlirCreateConversionConvertVectorToSPIRV()::MlirPass -end - -function mlirRegisterConversionConvertVectorToSPIRV() - @ccall mlir_c.mlirRegisterConversionConvertVectorToSPIRV()::Cvoid -end - -function mlirCreateConversionConvertVectorToXeGPU() - @ccall mlir_c.mlirCreateConversionConvertVectorToXeGPU()::MlirPass -end - -function mlirRegisterConversionConvertVectorToXeGPU() - @ccall mlir_c.mlirRegisterConversionConvertVectorToXeGPU()::Cvoid -end - -function mlirCreateConversionConvertVulkanLaunchFuncToVulkanCallsPass() - @ccall mlir_c.mlirCreateConversionConvertVulkanLaunchFuncToVulkanCallsPass()::MlirPass -end - -function mlirRegisterConversionConvertVulkanLaunchFuncToVulkanCallsPass() - @ccall mlir_c.mlirRegisterConversionConvertVulkanLaunchFuncToVulkanCallsPass()::Cvoid -end - -function mlirCreateConversionFinalizeMemRefToLLVMConversionPass() - @ccall mlir_c.mlirCreateConversionFinalizeMemRefToLLVMConversionPass()::MlirPass -end - -function mlirRegisterConversionFinalizeMemRefToLLVMConversionPass() - @ccall mlir_c.mlirRegisterConversionFinalizeMemRefToLLVMConversionPass()::Cvoid -end - -function mlirCreateConversionGpuToLLVMConversionPass() - @ccall mlir_c.mlirCreateConversionGpuToLLVMConversionPass()::MlirPass -end - -function mlirRegisterConversionGpuToLLVMConversionPass() - @ccall mlir_c.mlirRegisterConversionGpuToLLVMConversionPass()::Cvoid -end - -function mlirCreateConversionLiftControlFlowToSCFPass() - @ccall mlir_c.mlirCreateConversionLiftControlFlowToSCFPass()::MlirPass -end - -function mlirRegisterConversionLiftControlFlowToSCFPass() - @ccall mlir_c.mlirRegisterConversionLiftControlFlowToSCFPass()::Cvoid -end - -function mlirCreateConversionLowerHostCodeToLLVMPass() - @ccall mlir_c.mlirCreateConversionLowerHostCodeToLLVMPass()::MlirPass -end - -function mlirRegisterConversionLowerHostCodeToLLVMPass() - @ccall mlir_c.mlirRegisterConversionLowerHostCodeToLLVMPass()::Cvoid -end - -function mlirCreateConversionMapMemRefStorageClass() - @ccall mlir_c.mlirCreateConversionMapMemRefStorageClass()::MlirPass -end - -function mlirRegisterConversionMapMemRefStorageClass() - @ccall mlir_c.mlirRegisterConversionMapMemRefStorageClass()::Cvoid -end - -function mlirCreateConversionReconcileUnrealizedCasts() - @ccall mlir_c.mlirCreateConversionReconcileUnrealizedCasts()::MlirPass -end - -function mlirRegisterConversionReconcileUnrealizedCasts() - @ccall mlir_c.mlirRegisterConversionReconcileUnrealizedCasts()::Cvoid -end - -function mlirCreateConversionSCFToControlFlow() - @ccall mlir_c.mlirCreateConversionSCFToControlFlow()::MlirPass -end - -function mlirRegisterConversionSCFToControlFlow() - @ccall mlir_c.mlirRegisterConversionSCFToControlFlow()::Cvoid -end - -function mlirCreateConversionSCFToEmitC() - @ccall mlir_c.mlirCreateConversionSCFToEmitC()::MlirPass -end - -function mlirRegisterConversionSCFToEmitC() - @ccall mlir_c.mlirRegisterConversionSCFToEmitC()::Cvoid -end - -function mlirCreateConversionSCFToSPIRV() - @ccall mlir_c.mlirCreateConversionSCFToSPIRV()::MlirPass -end - -function mlirRegisterConversionSCFToSPIRV() - @ccall mlir_c.mlirRegisterConversionSCFToSPIRV()::Cvoid -end - -function mlirCreateConversionSetLLVMModuleDataLayoutPass() - @ccall mlir_c.mlirCreateConversionSetLLVMModuleDataLayoutPass()::MlirPass -end - -function mlirRegisterConversionSetLLVMModuleDataLayoutPass() - @ccall mlir_c.mlirRegisterConversionSetLLVMModuleDataLayoutPass()::Cvoid -end - -function mlirCreateConversionTosaToArith() - @ccall mlir_c.mlirCreateConversionTosaToArith()::MlirPass -end - -function mlirRegisterConversionTosaToArith() - @ccall mlir_c.mlirRegisterConversionTosaToArith()::Cvoid -end - -function mlirCreateConversionTosaToLinalg() - @ccall mlir_c.mlirCreateConversionTosaToLinalg()::MlirPass -end - -function mlirRegisterConversionTosaToLinalg() - @ccall mlir_c.mlirRegisterConversionTosaToLinalg()::Cvoid -end - -function mlirCreateConversionTosaToLinalgNamed() - @ccall mlir_c.mlirCreateConversionTosaToLinalgNamed()::MlirPass -end - -function mlirRegisterConversionTosaToLinalgNamed() - @ccall mlir_c.mlirRegisterConversionTosaToLinalgNamed()::Cvoid -end - -function mlirCreateConversionTosaToMLProgram() - @ccall mlir_c.mlirCreateConversionTosaToMLProgram()::MlirPass -end - -function mlirRegisterConversionTosaToMLProgram() - @ccall mlir_c.mlirRegisterConversionTosaToMLProgram()::Cvoid -end - -function mlirCreateConversionTosaToSCF() - @ccall mlir_c.mlirCreateConversionTosaToSCF()::MlirPass -end - -function mlirRegisterConversionTosaToSCF() - @ccall mlir_c.mlirRegisterConversionTosaToSCF()::Cvoid -end - -function mlirCreateConversionTosaToTensor() - @ccall mlir_c.mlirCreateConversionTosaToTensor()::MlirPass -end - -function mlirRegisterConversionTosaToTensor() - @ccall mlir_c.mlirRegisterConversionTosaToTensor()::Cvoid -end - -function mlirCreateConversionUBToLLVMConversionPass() - @ccall mlir_c.mlirCreateConversionUBToLLVMConversionPass()::MlirPass -end - -function mlirRegisterConversionUBToLLVMConversionPass() - @ccall mlir_c.mlirRegisterConversionUBToLLVMConversionPass()::Cvoid -end - -function mlirCreateConversionUBToSPIRVConversionPass() - @ccall mlir_c.mlirCreateConversionUBToSPIRVConversionPass()::MlirPass -end - -function mlirRegisterConversionUBToSPIRVConversionPass() - @ccall mlir_c.mlirRegisterConversionUBToSPIRVConversionPass()::Cvoid -end - -""" - mlirEnableGlobalDebug(enable) - -Sets the global debugging flag. -""" -function mlirEnableGlobalDebug(enable) - @ccall mlir_c.mlirEnableGlobalDebug(enable::Bool)::Cvoid -end - -""" - mlirIsGlobalDebugEnabled() - -Retuns `true` if the global debugging flag is set, false otherwise. -""" -function mlirIsGlobalDebugEnabled() - @ccall mlir_c.mlirIsGlobalDebugEnabled()::Bool -end - -""" - mlirSetGlobalDebugType(type) - -Sets the current debug type, similarly to `-debug-only=type` in the command-line tools. Note that global debug should be enabled for any output to be produced. -""" -function mlirSetGlobalDebugType(type) - @ccall mlir_c.mlirSetGlobalDebugType(type::Cstring)::Cvoid -end - -""" - mlirSetGlobalDebugTypes(types, n) - -Sets multiple current debug types, similarly to `-debug-only=type1,type2" in the command-line tools. Note that global debug should be enabled for any output to be produced. -""" -function mlirSetGlobalDebugTypes(types, n) - @ccall mlir_c.mlirSetGlobalDebugTypes(types::Ptr{Cstring}, n::intptr_t)::Cvoid -end - -""" - mlirIsCurrentDebugType(type) - -Checks if `type` is set as the current debug type. -""" -function mlirIsCurrentDebugType(type) - @ccall mlir_c.mlirIsCurrentDebugType(type::Cstring)::Bool -end - -""" - MlirDiagnostic - -An opaque reference to a diagnostic, always owned by the diagnostics engine (context). Must not be stored outside of the diagnostic handler. -""" -struct MlirDiagnostic - ptr::Ptr{Cvoid} -end - -""" - MlirDiagnosticSeverity - -Severity of a diagnostic. -""" -@cenum MlirDiagnosticSeverity::UInt32 begin - MlirDiagnosticError = 0x0000000000000000 - MlirDiagnosticWarning = 0x0000000000000001 - MlirDiagnosticNote = 0x0000000000000002 - MlirDiagnosticRemark = 0x0000000000000003 -end - -""" -Opaque identifier of a diagnostic handler, useful to detach a handler. -""" -const MlirDiagnosticHandlerID = UInt64 - -# typedef MlirLogicalResult ( * MlirDiagnosticHandler ) ( MlirDiagnostic , void * userData ) -""" -Diagnostic handler type. Accepts a reference to a diagnostic, which is only guaranteed to be live during the call. The handler is passed the `userData` that was provided when the handler was attached to a context. If the handler processed the diagnostic completely, it is expected to return success. Otherwise, it is expected to return failure to indicate that other handlers should attempt to process the diagnostic. -""" -const MlirDiagnosticHandler = Ptr{Cvoid} - -""" - mlirDiagnosticPrint(diagnostic, callback, userData) - -Prints a diagnostic using the provided callback. -""" -function mlirDiagnosticPrint(diagnostic, callback, userData) - @ccall mlir_c.mlirDiagnosticPrint( - diagnostic::MlirDiagnostic, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid -end - -""" - mlirDiagnosticGetLocation(diagnostic) - -Returns the location at which the diagnostic is reported. -""" -function mlirDiagnosticGetLocation(diagnostic) - @ccall mlir_c.mlirDiagnosticGetLocation(diagnostic::MlirDiagnostic)::MlirLocation +function mlirGetDialectHandle__emitc__() + @ccall mlir_c.mlirGetDialectHandle__emitc__()::MlirDialectHandle end -""" - mlirDiagnosticGetSeverity(diagnostic) +@cenum MlirEmitCCmpPredicate::UInt64 begin + MLIR_EMITC_CMP_PREDICATE_EQ = 0x0000000000000000 + MLIR_EMITC_CMP_PREDICATE_NE = 0x0000000000000001 + MLIR_EMITC_CMP_PREDICATE_LT = 0x0000000000000002 + MLIR_EMITC_CMP_PREDICATE_LE = 0x0000000000000003 + MLIR_EMITC_CMP_PREDICATE_GT = 0x0000000000000004 + MLIR_EMITC_CMP_PREDICATE_GE = 0x0000000000000005 + MLIR_EMITC_CMP_PREDICATE_THREE_WAY = 0x0000000000000006 +end -Returns the severity of the diagnostic. -""" -function mlirDiagnosticGetSeverity(diagnostic) - @ccall mlir_c.mlirDiagnosticGetSeverity( - diagnostic::MlirDiagnostic - )::MlirDiagnosticSeverity +function mlirTypeIsAEmitCArrayType(type) + @ccall mlir_c.mlirTypeIsAEmitCArrayType(type::MlirType)::Bool end -""" - mlirDiagnosticGetNumNotes(diagnostic) +function mlirEmitCArrayTypeGetTypeID() + @ccall mlir_c.mlirEmitCArrayTypeGetTypeID()::MlirTypeID +end -Returns the number of notes attached to the diagnostic. -""" -function mlirDiagnosticGetNumNotes(diagnostic) - @ccall mlir_c.mlirDiagnosticGetNumNotes(diagnostic::MlirDiagnostic)::intptr_t +function mlirEmitCArrayTypeGet(nDims, shape, elementType) + @ccall mlir_c.mlirEmitCArrayTypeGet( + nDims::intptr_t, shape::Ptr{Int64}, elementType::MlirType + )::MlirType end -""" - mlirDiagnosticGetNote(diagnostic, pos) +function mlirTypeIsAEmitCLValueType(type) + @ccall mlir_c.mlirTypeIsAEmitCLValueType(type::MlirType)::Bool +end -Returns `pos`-th note attached to the diagnostic. Expects `pos` to be a valid zero-based index into the list of notes. -""" -function mlirDiagnosticGetNote(diagnostic, pos) - @ccall mlir_c.mlirDiagnosticGetNote( - diagnostic::MlirDiagnostic, pos::intptr_t - )::MlirDiagnostic +function mlirEmitCLValueTypeGetTypeID() + @ccall mlir_c.mlirEmitCLValueTypeGetTypeID()::MlirTypeID end -""" - mlirContextAttachDiagnosticHandler(context, handler, userData, deleteUserData) +function mlirEmitCLValueTypeGet(valueType) + @ccall mlir_c.mlirEmitCLValueTypeGet(valueType::MlirType)::MlirType +end -Attaches the diagnostic handler to the context. Handlers are invoked in the reverse order of attachment until one of them processes the diagnostic completely. When a handler is invoked it is passed the `userData` that was provided when it was attached. If non-NULL, `deleteUserData` is called once the system no longer needs to call the handler (for instance after the handler is detached or the context is destroyed). Returns an identifier that can be used to detach the handler. -""" -function mlirContextAttachDiagnosticHandler(context, handler, userData, deleteUserData) - @ccall mlir_c.mlirContextAttachDiagnosticHandler( - context::MlirContext, - handler::MlirDiagnosticHandler, - userData::Ptr{Cvoid}, - deleteUserData::Ptr{Cvoid}, - )::MlirDiagnosticHandlerID +function mlirTypeIsAEmitCOpaqueType(type) + @ccall mlir_c.mlirTypeIsAEmitCOpaqueType(type::MlirType)::Bool end -""" - mlirContextDetachDiagnosticHandler(context, id) +function mlirEmitCOpaqueTypeGetTypeID() + @ccall mlir_c.mlirEmitCOpaqueTypeGetTypeID()::MlirTypeID +end -Detaches an attached diagnostic handler from the context given its identifier. -""" -function mlirContextDetachDiagnosticHandler(context, id) - @ccall mlir_c.mlirContextDetachDiagnosticHandler( - context::MlirContext, id::MlirDiagnosticHandlerID - )::Cvoid +function mlirEmitCOpaqueTypeGet(ctx, value) + @ccall mlir_c.mlirEmitCOpaqueTypeGet(ctx::MlirContext, value::MlirStringRef)::MlirType end -""" - mlirEmitError(location, message) +function mlirTypeIsAEmitCPointerType(type) + @ccall mlir_c.mlirTypeIsAEmitCPointerType(type::MlirType)::Bool +end -Emits an error at the given location through the diagnostics engine. Used for testing purposes. -""" -function mlirEmitError(location, message) - @ccall mlir_c.mlirEmitError(location::MlirLocation, message::Cstring)::Cvoid +function mlirEmitCPointerTypeGetTypeID() + @ccall mlir_c.mlirEmitCPointerTypeGetTypeID()::MlirTypeID end -function mlirGetDialectHandle__amdgpu__() - @ccall mlir_c.mlirGetDialectHandle__amdgpu__()::MlirDialectHandle +function mlirEmitCPointerTypeGet(pointee) + @ccall mlir_c.mlirEmitCPointerTypeGet(pointee::MlirType)::MlirType end -function mlirGetDialectHandle__arith__() - @ccall mlir_c.mlirGetDialectHandle__arith__()::MlirDialectHandle +function mlirTypeIsAEmitCPtrDiffTType(type) + @ccall mlir_c.mlirTypeIsAEmitCPtrDiffTType(type::MlirType)::Bool end -function mlirGetDialectHandle__async__() - @ccall mlir_c.mlirGetDialectHandle__async__()::MlirDialectHandle +function mlirEmitCPtrDiffTTypeGetTypeID() + @ccall mlir_c.mlirEmitCPtrDiffTTypeGetTypeID()::MlirTypeID end -function mlirRegisterAsyncPasses() - @ccall mlir_c.mlirRegisterAsyncPasses()::Cvoid +function mlirEmitCPtrDiffTTypeGet(ctx) + @ccall mlir_c.mlirEmitCPtrDiffTTypeGet(ctx::MlirContext)::MlirType end -function mlirCreateAsyncAsyncFuncToAsyncRuntime() - @ccall mlir_c.mlirCreateAsyncAsyncFuncToAsyncRuntime()::MlirPass +function mlirTypeIsAEmitCSignedSizeTType(type) + @ccall mlir_c.mlirTypeIsAEmitCSignedSizeTType(type::MlirType)::Bool end -function mlirRegisterAsyncAsyncFuncToAsyncRuntime() - @ccall mlir_c.mlirRegisterAsyncAsyncFuncToAsyncRuntime()::Cvoid +function mlirEmitCSignedSizeTTypeGetTypeID() + @ccall mlir_c.mlirEmitCSignedSizeTTypeGetTypeID()::MlirTypeID end -function mlirCreateAsyncAsyncParallelFor() - @ccall mlir_c.mlirCreateAsyncAsyncParallelFor()::MlirPass +function mlirEmitCSignedSizeTTypeGet(ctx) + @ccall mlir_c.mlirEmitCSignedSizeTTypeGet(ctx::MlirContext)::MlirType end -function mlirRegisterAsyncAsyncParallelFor() - @ccall mlir_c.mlirRegisterAsyncAsyncParallelFor()::Cvoid +function mlirTypeIsAEmitCSizeTType(type) + @ccall mlir_c.mlirTypeIsAEmitCSizeTType(type::MlirType)::Bool end -function mlirCreateAsyncAsyncRuntimePolicyBasedRefCounting() - @ccall mlir_c.mlirCreateAsyncAsyncRuntimePolicyBasedRefCounting()::MlirPass +function mlirEmitCSizeTTypeGetTypeID() + @ccall mlir_c.mlirEmitCSizeTTypeGetTypeID()::MlirTypeID end -function mlirRegisterAsyncAsyncRuntimePolicyBasedRefCounting() - @ccall mlir_c.mlirRegisterAsyncAsyncRuntimePolicyBasedRefCounting()::Cvoid +function mlirEmitCSizeTTypeGet(ctx) + @ccall mlir_c.mlirEmitCSizeTTypeGet(ctx::MlirContext)::MlirType end -function mlirCreateAsyncAsyncRuntimeRefCounting() - @ccall mlir_c.mlirCreateAsyncAsyncRuntimeRefCounting()::MlirPass +function mlirAttributeIsAEmitCCmpPredicate(attr) + @ccall mlir_c.mlirAttributeIsAEmitCCmpPredicate(attr::MlirAttribute)::Bool end -function mlirRegisterAsyncAsyncRuntimeRefCounting() - @ccall mlir_c.mlirRegisterAsyncAsyncRuntimeRefCounting()::Cvoid +function mlirEmitCCmpPredicateAttrGet(ctx, val) + @ccall mlir_c.mlirEmitCCmpPredicateAttrGet( + ctx::MlirContext, val::MlirEmitCCmpPredicate + )::MlirAttribute end -function mlirCreateAsyncAsyncRuntimeRefCountingOpt() - @ccall mlir_c.mlirCreateAsyncAsyncRuntimeRefCountingOpt()::MlirPass +function mlirEmitCCmpPredicateAttrGetValue(attr) + @ccall mlir_c.mlirEmitCCmpPredicateAttrGetValue( + attr::MlirAttribute + )::MlirEmitCCmpPredicate end -function mlirRegisterAsyncAsyncRuntimeRefCountingOpt() - @ccall mlir_c.mlirRegisterAsyncAsyncRuntimeRefCountingOpt()::Cvoid +function mlirEmitCCmpPredicateAttrGetTypeID() + @ccall mlir_c.mlirEmitCCmpPredicateAttrGetTypeID()::MlirTypeID end -function mlirCreateAsyncAsyncToAsyncRuntime() - @ccall mlir_c.mlirCreateAsyncAsyncToAsyncRuntime()::MlirPass +function mlirAttributeIsAEmitCOpaque(attr) + @ccall mlir_c.mlirAttributeIsAEmitCOpaque(attr::MlirAttribute)::Bool end -function mlirRegisterAsyncAsyncToAsyncRuntime() - @ccall mlir_c.mlirRegisterAsyncAsyncToAsyncRuntime()::Cvoid +function mlirEmitCOpaqueAttrGet(ctx, value) + @ccall mlir_c.mlirEmitCOpaqueAttrGet( + ctx::MlirContext, value::MlirStringRef + )::MlirAttribute end -function mlirGetDialectHandle__cf__() - @ccall mlir_c.mlirGetDialectHandle__cf__()::MlirDialectHandle +function mlirEmitCOpaqueAttrGetValue(attr) + @ccall mlir_c.mlirEmitCOpaqueAttrGetValue(attr::MlirAttribute)::MlirStringRef end -function mlirGetDialectHandle__emitc__() - @ccall mlir_c.mlirGetDialectHandle__emitc__()::MlirDialectHandle +function mlirEmitCOpaqueAttrGetTypeID() + @ccall mlir_c.mlirEmitCOpaqueAttrGetTypeID()::MlirTypeID end function mlirGetDialectHandle__func__() @@ -6918,6 +6159,12 @@ function mlirFuncSetArgAttr(op, pos, name, attr) )::Cvoid end +function mlirFuncSetResultAttr(op, pos, name, attr) + @ccall mlir_c.mlirFuncSetResultAttr( + op::MlirOperation, pos::intptr_t, name::MlirStringRef, attr::MlirAttribute + )::Cvoid +end + function mlirGetDialectHandle__gpu__() @ccall mlir_c.mlirGetDialectHandle__gpu__()::MlirDialectHandle end @@ -6987,90 +6234,6 @@ function mlirGPUObjectAttrGetKernels(mlirObjectAttr) @ccall mlir_c.mlirGPUObjectAttrGetKernels(mlirObjectAttr::MlirAttribute)::MlirAttribute end -function mlirRegisterGPUPasses() - @ccall mlir_c.mlirRegisterGPUPasses()::Cvoid -end - -function mlirCreateGPUGpuAsyncRegionPass() - @ccall mlir_c.mlirCreateGPUGpuAsyncRegionPass()::MlirPass -end - -function mlirRegisterGPUGpuAsyncRegionPass() - @ccall mlir_c.mlirRegisterGPUGpuAsyncRegionPass()::Cvoid -end - -function mlirCreateGPUGpuDecomposeMemrefsPass() - @ccall mlir_c.mlirCreateGPUGpuDecomposeMemrefsPass()::MlirPass -end - -function mlirRegisterGPUGpuDecomposeMemrefsPass() - @ccall mlir_c.mlirRegisterGPUGpuDecomposeMemrefsPass()::Cvoid -end - -function mlirCreateGPUGpuEliminateBarriers() - @ccall mlir_c.mlirCreateGPUGpuEliminateBarriers()::MlirPass -end - -function mlirRegisterGPUGpuEliminateBarriers() - @ccall mlir_c.mlirRegisterGPUGpuEliminateBarriers()::Cvoid -end - -function mlirCreateGPUGpuKernelOutlining() - @ccall mlir_c.mlirCreateGPUGpuKernelOutlining()::MlirPass -end - -function mlirRegisterGPUGpuKernelOutlining() - @ccall mlir_c.mlirRegisterGPUGpuKernelOutlining()::Cvoid -end - -function mlirCreateGPUGpuLaunchSinkIndexComputations() - @ccall mlir_c.mlirCreateGPUGpuLaunchSinkIndexComputations()::MlirPass -end - -function mlirRegisterGPUGpuLaunchSinkIndexComputations() - @ccall mlir_c.mlirRegisterGPUGpuLaunchSinkIndexComputations()::Cvoid -end - -function mlirCreateGPUGpuMapParallelLoopsPass() - @ccall mlir_c.mlirCreateGPUGpuMapParallelLoopsPass()::MlirPass -end - -function mlirRegisterGPUGpuMapParallelLoopsPass() - @ccall mlir_c.mlirRegisterGPUGpuMapParallelLoopsPass()::Cvoid -end - -function mlirCreateGPUGpuModuleToBinaryPass() - @ccall mlir_c.mlirCreateGPUGpuModuleToBinaryPass()::MlirPass -end - -function mlirRegisterGPUGpuModuleToBinaryPass() - @ccall mlir_c.mlirRegisterGPUGpuModuleToBinaryPass()::Cvoid -end - -function mlirCreateGPUGpuNVVMAttachTarget() - @ccall mlir_c.mlirCreateGPUGpuNVVMAttachTarget()::MlirPass -end - -function mlirRegisterGPUGpuNVVMAttachTarget() - @ccall mlir_c.mlirRegisterGPUGpuNVVMAttachTarget()::Cvoid -end - -function mlirCreateGPUGpuROCDLAttachTarget() - @ccall mlir_c.mlirCreateGPUGpuROCDLAttachTarget()::MlirPass -end - -function mlirRegisterGPUGpuROCDLAttachTarget() - @ccall mlir_c.mlirRegisterGPUGpuROCDLAttachTarget()::Cvoid -end - -function mlirCreateGPUGpuSPIRVAttachTarget() - @ccall mlir_c.mlirCreateGPUGpuSPIRVAttachTarget()::MlirPass -end - -function mlirRegisterGPUGpuSPIRVAttachTarget() - @ccall mlir_c.mlirRegisterGPUGpuSPIRVAttachTarget()::Cvoid -end - function mlirGetDialectHandle__irdl__() @ccall mlir_c.mlirGetDialectHandle__irdl__()::MlirDialectHandle end @@ -7084,6 +6247,10 @@ function mlirLoadIRDLDialects(_module) @ccall mlir_c.mlirLoadIRDLDialects(_module::MlirModule)::MlirLogicalResult end +function mlirGetDialectHandle__index__() + @ccall mlir_c.mlirGetDialectHandle__index__()::MlirDialectHandle +end + function mlirGetDialectHandle__llvm__() @ccall mlir_c.mlirGetDialectHandle__llvm__()::MlirDialectHandle end @@ -7156,6 +6323,33 @@ function mlirLLVMFunctionTypeGet(resultType, nArgumentTypes, argumentTypes, isVa )::MlirType end +""" + mlirLLVMFunctionTypeGetNumInputs(type) + +Returns the number of input types. +""" +function mlirLLVMFunctionTypeGetNumInputs(type) + @ccall mlir_c.mlirLLVMFunctionTypeGetNumInputs(type::MlirType)::intptr_t +end + +""" + mlirLLVMFunctionTypeGetInput(type, pos) + +Returns the pos-th input type. +""" +function mlirLLVMFunctionTypeGetInput(type, pos) + @ccall mlir_c.mlirLLVMFunctionTypeGetInput(type::MlirType, pos::intptr_t)::MlirType +end + +""" + mlirLLVMFunctionTypeGetReturnType(type) + +Returns the return type of the function type. +""" +function mlirLLVMFunctionTypeGetReturnType(type) + @ccall mlir_c.mlirLLVMFunctionTypeGetReturnType(type::MlirType)::MlirType +end + """ mlirTypeIsALLVMStructType(type) @@ -7889,106 +7083,6 @@ function mlirGetDialectHandle__linalg__() @ccall mlir_c.mlirGetDialectHandle__linalg__()::MlirDialectHandle end -function mlirRegisterLinalgPasses() - @ccall mlir_c.mlirRegisterLinalgPasses()::Cvoid -end - -function mlirCreateLinalgConvertElementwiseToLinalgPass() - @ccall mlir_c.mlirCreateLinalgConvertElementwiseToLinalgPass()::MlirPass -end - -function mlirRegisterLinalgConvertElementwiseToLinalgPass() - @ccall mlir_c.mlirRegisterLinalgConvertElementwiseToLinalgPass()::Cvoid -end - -function mlirCreateLinalgConvertLinalgToAffineLoopsPass() - @ccall mlir_c.mlirCreateLinalgConvertLinalgToAffineLoopsPass()::MlirPass -end - -function mlirRegisterLinalgConvertLinalgToAffineLoopsPass() - @ccall mlir_c.mlirRegisterLinalgConvertLinalgToAffineLoopsPass()::Cvoid -end - -function mlirCreateLinalgConvertLinalgToLoopsPass() - @ccall mlir_c.mlirCreateLinalgConvertLinalgToLoopsPass()::MlirPass -end - -function mlirRegisterLinalgConvertLinalgToLoopsPass() - @ccall mlir_c.mlirRegisterLinalgConvertLinalgToLoopsPass()::Cvoid -end - -function mlirCreateLinalgConvertLinalgToParallelLoopsPass() - @ccall mlir_c.mlirCreateLinalgConvertLinalgToParallelLoopsPass()::MlirPass -end - -function mlirRegisterLinalgConvertLinalgToParallelLoopsPass() - @ccall mlir_c.mlirRegisterLinalgConvertLinalgToParallelLoopsPass()::Cvoid -end - -function mlirCreateLinalgLinalgBlockPackMatmul() - @ccall mlir_c.mlirCreateLinalgLinalgBlockPackMatmul()::MlirPass -end - -function mlirRegisterLinalgLinalgBlockPackMatmul() - @ccall mlir_c.mlirRegisterLinalgLinalgBlockPackMatmul()::Cvoid -end - -function mlirCreateLinalgLinalgDetensorizePass() - @ccall mlir_c.mlirCreateLinalgLinalgDetensorizePass()::MlirPass -end - -function mlirRegisterLinalgLinalgDetensorizePass() - @ccall mlir_c.mlirRegisterLinalgLinalgDetensorizePass()::Cvoid -end - -function mlirCreateLinalgLinalgElementwiseOpFusionPass() - @ccall mlir_c.mlirCreateLinalgLinalgElementwiseOpFusionPass()::MlirPass -end - -function mlirRegisterLinalgLinalgElementwiseOpFusionPass() - @ccall mlir_c.mlirRegisterLinalgLinalgElementwiseOpFusionPass()::Cvoid -end - -function mlirCreateLinalgLinalgFoldUnitExtentDimsPass() - @ccall mlir_c.mlirCreateLinalgLinalgFoldUnitExtentDimsPass()::MlirPass -end - -function mlirRegisterLinalgLinalgFoldUnitExtentDimsPass() - @ccall mlir_c.mlirRegisterLinalgLinalgFoldUnitExtentDimsPass()::Cvoid -end - -function mlirCreateLinalgLinalgGeneralizeNamedOpsPass() - @ccall mlir_c.mlirCreateLinalgLinalgGeneralizeNamedOpsPass()::MlirPass -end - -function mlirRegisterLinalgLinalgGeneralizeNamedOpsPass() - @ccall mlir_c.mlirRegisterLinalgLinalgGeneralizeNamedOpsPass()::Cvoid -end - -function mlirCreateLinalgLinalgInlineScalarOperandsPass() - @ccall mlir_c.mlirCreateLinalgLinalgInlineScalarOperandsPass()::MlirPass -end - -function mlirRegisterLinalgLinalgInlineScalarOperandsPass() - @ccall mlir_c.mlirRegisterLinalgLinalgInlineScalarOperandsPass()::Cvoid -end - -function mlirCreateLinalgLinalgNamedOpConversionPass() - @ccall mlir_c.mlirCreateLinalgLinalgNamedOpConversionPass()::MlirPass -end - -function mlirRegisterLinalgLinalgNamedOpConversionPass() - @ccall mlir_c.mlirRegisterLinalgLinalgNamedOpConversionPass()::Cvoid -end - -function mlirCreateLinalgLinalgSpecializeGenericOpsPass() - @ccall mlir_c.mlirCreateLinalgLinalgSpecializeGenericOpsPass()::MlirPass -end - -function mlirRegisterLinalgLinalgSpecializeGenericOpsPass() - @ccall mlir_c.mlirRegisterLinalgLinalgSpecializeGenericOpsPass()::Cvoid -end - function mlirGetDialectHandle__ml_program__() @ccall mlir_c.mlirGetDialectHandle__ml_program__()::MlirDialectHandle end @@ -8651,140 +7745,8 @@ function mlirSparseTensorEncodingAttrBuildLvlType(lvlFmt, properties, propSize, properties::Ptr{MlirSparseTensorLevelPropertyNondefault}, propSize::Cuint, n::Cuint, - m::Cuint, - )::MlirSparseTensorLevelType -end - -function mlirRegisterSparseTensorPasses() - @ccall mlir_c.mlirRegisterSparseTensorPasses()::Cvoid -end - -function mlirCreateSparseTensorLowerForeachToSCF() - @ccall mlir_c.mlirCreateSparseTensorLowerForeachToSCF()::MlirPass -end - -function mlirRegisterSparseTensorLowerForeachToSCF() - @ccall mlir_c.mlirRegisterSparseTensorLowerForeachToSCF()::Cvoid -end - -function mlirCreateSparseTensorLowerSparseIterationToSCF() - @ccall mlir_c.mlirCreateSparseTensorLowerSparseIterationToSCF()::MlirPass -end - -function mlirRegisterSparseTensorLowerSparseIterationToSCF() - @ccall mlir_c.mlirRegisterSparseTensorLowerSparseIterationToSCF()::Cvoid -end - -function mlirCreateSparseTensorLowerSparseOpsToForeach() - @ccall mlir_c.mlirCreateSparseTensorLowerSparseOpsToForeach()::MlirPass -end - -function mlirRegisterSparseTensorLowerSparseOpsToForeach() - @ccall mlir_c.mlirRegisterSparseTensorLowerSparseOpsToForeach()::Cvoid -end - -function mlirCreateSparseTensorPreSparsificationRewrite() - @ccall mlir_c.mlirCreateSparseTensorPreSparsificationRewrite()::MlirPass -end - -function mlirRegisterSparseTensorPreSparsificationRewrite() - @ccall mlir_c.mlirRegisterSparseTensorPreSparsificationRewrite()::Cvoid -end - -function mlirCreateSparseTensorSparseAssembler() - @ccall mlir_c.mlirCreateSparseTensorSparseAssembler()::MlirPass -end - -function mlirRegisterSparseTensorSparseAssembler() - @ccall mlir_c.mlirRegisterSparseTensorSparseAssembler()::Cvoid -end - -function mlirCreateSparseTensorSparseBufferRewrite() - @ccall mlir_c.mlirCreateSparseTensorSparseBufferRewrite()::MlirPass -end - -function mlirRegisterSparseTensorSparseBufferRewrite() - @ccall mlir_c.mlirRegisterSparseTensorSparseBufferRewrite()::Cvoid -end - -function mlirCreateSparseTensorSparseGPUCodegen() - @ccall mlir_c.mlirCreateSparseTensorSparseGPUCodegen()::MlirPass -end - -function mlirRegisterSparseTensorSparseGPUCodegen() - @ccall mlir_c.mlirRegisterSparseTensorSparseGPUCodegen()::Cvoid -end - -function mlirCreateSparseTensorSparseReinterpretMap() - @ccall mlir_c.mlirCreateSparseTensorSparseReinterpretMap()::MlirPass -end - -function mlirRegisterSparseTensorSparseReinterpretMap() - @ccall mlir_c.mlirRegisterSparseTensorSparseReinterpretMap()::Cvoid -end - -function mlirCreateSparseTensorSparseSpaceCollapse() - @ccall mlir_c.mlirCreateSparseTensorSparseSpaceCollapse()::MlirPass -end - -function mlirRegisterSparseTensorSparseSpaceCollapse() - @ccall mlir_c.mlirRegisterSparseTensorSparseSpaceCollapse()::Cvoid -end - -function mlirCreateSparseTensorSparseTensorCodegen() - @ccall mlir_c.mlirCreateSparseTensorSparseTensorCodegen()::MlirPass -end - -function mlirRegisterSparseTensorSparseTensorCodegen() - @ccall mlir_c.mlirRegisterSparseTensorSparseTensorCodegen()::Cvoid -end - -function mlirCreateSparseTensorSparseTensorConversionPass() - @ccall mlir_c.mlirCreateSparseTensorSparseTensorConversionPass()::MlirPass -end - -function mlirRegisterSparseTensorSparseTensorConversionPass() - @ccall mlir_c.mlirRegisterSparseTensorSparseTensorConversionPass()::Cvoid -end - -function mlirCreateSparseTensorSparseVectorization() - @ccall mlir_c.mlirCreateSparseTensorSparseVectorization()::MlirPass -end - -function mlirRegisterSparseTensorSparseVectorization() - @ccall mlir_c.mlirRegisterSparseTensorSparseVectorization()::Cvoid -end - -function mlirCreateSparseTensorSparsificationAndBufferization() - @ccall mlir_c.mlirCreateSparseTensorSparsificationAndBufferization()::MlirPass -end - -function mlirRegisterSparseTensorSparsificationAndBufferization() - @ccall mlir_c.mlirRegisterSparseTensorSparsificationAndBufferization()::Cvoid -end - -function mlirCreateSparseTensorSparsificationPass() - @ccall mlir_c.mlirCreateSparseTensorSparsificationPass()::MlirPass -end - -function mlirRegisterSparseTensorSparsificationPass() - @ccall mlir_c.mlirRegisterSparseTensorSparsificationPass()::Cvoid -end - -function mlirCreateSparseTensorStageSparseOperations() - @ccall mlir_c.mlirCreateSparseTensorStageSparseOperations()::MlirPass -end - -function mlirRegisterSparseTensorStageSparseOperations() - @ccall mlir_c.mlirRegisterSparseTensorStageSparseOperations()::Cvoid -end - -function mlirCreateSparseTensorStorageSpecifierToLLVM() - @ccall mlir_c.mlirCreateSparseTensorStorageSpecifierToLLVM()::MlirPass -end - -function mlirRegisterSparseTensorStorageSpecifierToLLVM() - @ccall mlir_c.mlirRegisterSparseTensorStorageSpecifierToLLVM()::Cvoid + m::Cuint, + )::MlirSparseTensorLevelType end function mlirGetDialectHandle__tensor__() @@ -9035,143 +7997,399 @@ function mlirExecutionEngineLookup(jit, name) end """ - mlirExecutionEngineRegisterSymbol(jit, name, sym) + mlirExecutionEngineRegisterSymbol(jit, name, sym) + +Register a symbol with the jit: this symbol will be accessible to the jitted code. +""" +function mlirExecutionEngineRegisterSymbol(jit, name, sym) + @ccall mlir_c.mlirExecutionEngineRegisterSymbol( + jit::MlirExecutionEngine, name::MlirStringRef, sym::Ptr{Cvoid} + )::Cvoid +end + +""" + mlirExecutionEngineDumpToObjectFile(jit, fileName) + +Dump as an object in `fileName`. +""" +function mlirExecutionEngineDumpToObjectFile(jit, fileName) + @ccall mlir_c.mlirExecutionEngineDumpToObjectFile( + jit::MlirExecutionEngine, fileName::MlirStringRef + )::Cvoid +end + +""" + mlirOperationImplementsInterface(operation, interfaceTypeID) + +Returns `true` if the given operation implements an interface identified by its TypeID. +""" +function mlirOperationImplementsInterface(operation, interfaceTypeID) + @ccall mlir_c.mlirOperationImplementsInterface( + operation::MlirOperation, interfaceTypeID::MlirTypeID + )::Bool +end + +""" + mlirOperationImplementsInterfaceStatic(operationName, context, interfaceTypeID) + +Returns `true` if the operation identified by its canonical string name implements the interface identified by its TypeID in the given context. Note that interfaces may be attached to operations in some contexts and not others. +""" +function mlirOperationImplementsInterfaceStatic(operationName, context, interfaceTypeID) + @ccall mlir_c.mlirOperationImplementsInterfaceStatic( + operationName::MlirStringRef, context::MlirContext, interfaceTypeID::MlirTypeID + )::Bool +end + +""" + mlirInferTypeOpInterfaceTypeID() + +Returns the interface TypeID of the InferTypeOpInterface. +""" +function mlirInferTypeOpInterfaceTypeID() + @ccall mlir_c.mlirInferTypeOpInterfaceTypeID()::MlirTypeID +end + +# typedef void ( * MlirTypesCallback ) ( intptr_t , MlirType * , void * ) +""" +These callbacks are used to return multiple types from functions while transferring ownership to the caller. The first argument is the number of consecutive elements pointed to by the second argument. The third argument is an opaque pointer forwarded to the callback by the caller. +""" +const MlirTypesCallback = Ptr{Cvoid} + +""" + mlirInferTypeOpInterfaceInferReturnTypes(opName, context, location, nOperands, operands, attributes, properties, nRegions, regions, callback, userData) + +Infers the return types of the operation identified by its canonical given the arguments that will be supplied to its generic builder. Calls `callback` with the types of inferred arguments, potentially several times, on success. Returns failure otherwise. +""" +function mlirInferTypeOpInterfaceInferReturnTypes( + opName, + context, + location, + nOperands, + operands, + attributes, + properties, + nRegions, + regions, + callback, + userData, +) + @ccall mlir_c.mlirInferTypeOpInterfaceInferReturnTypes( + opName::MlirStringRef, + context::MlirContext, + location::MlirLocation, + nOperands::intptr_t, + operands::Ptr{MlirValue}, + attributes::MlirAttribute, + properties::Ptr{Cvoid}, + nRegions::intptr_t, + regions::Ptr{MlirRegion}, + callback::MlirTypesCallback, + userData::Ptr{Cvoid}, + )::MlirLogicalResult +end + +""" + mlirInferShapedTypeOpInterfaceTypeID() + +Returns the interface TypeID of the InferShapedTypeOpInterface. +""" +function mlirInferShapedTypeOpInterfaceTypeID() + @ccall mlir_c.mlirInferShapedTypeOpInterfaceTypeID()::MlirTypeID +end + +# typedef void ( * MlirShapedTypeComponentsCallback ) ( bool , intptr_t , const int64_t * , MlirType , MlirAttribute , void * ) +""" +These callbacks are used to return multiple shaped type components from functions while transferring ownership to the caller. The first argument is the has rank boolean followed by the the rank and a pointer to the shape (if applicable). The next argument is the element type, then the attribute. The last argument is an opaque pointer forwarded to the callback by the caller. This callback will be called potentially multiple times for each shaped type components. +""" +const MlirShapedTypeComponentsCallback = Ptr{Cvoid} + +""" + mlirInferShapedTypeOpInterfaceInferReturnTypes(opName, context, location, nOperands, operands, attributes, properties, nRegions, regions, callback, userData) + +Infers the return shaped type components of the operation. Calls `callback` with the types of inferred arguments on success. Returns failure otherwise. +""" +function mlirInferShapedTypeOpInterfaceInferReturnTypes( + opName, + context, + location, + nOperands, + operands, + attributes, + properties, + nRegions, + regions, + callback, + userData, +) + @ccall mlir_c.mlirInferShapedTypeOpInterfaceInferReturnTypes( + opName::MlirStringRef, + context::MlirContext, + location::MlirLocation, + nOperands::intptr_t, + operands::Ptr{MlirValue}, + attributes::MlirAttribute, + properties::Ptr{Cvoid}, + nRegions::intptr_t, + regions::Ptr{MlirRegion}, + callback::MlirShapedTypeComponentsCallback, + userData::Ptr{Cvoid}, + )::MlirLogicalResult +end + +struct MlirPass + ptr::Ptr{Cvoid} +end + +struct MlirExternalPass + ptr::Ptr{Cvoid} +end + +struct MlirPassManager + ptr::Ptr{Cvoid} +end + +struct MlirOpPassManager + ptr::Ptr{Cvoid} +end + +""" + mlirPassManagerCreate(ctx) + +Create a new top-level PassManager with the default anchor. +""" +function mlirPassManagerCreate(ctx) + @ccall mlir_c.mlirPassManagerCreate(ctx::MlirContext)::MlirPassManager +end + +""" + mlirPassManagerCreateOnOperation(ctx, anchorOp) + +Create a new top-level PassManager anchored on `anchorOp`. +""" +function mlirPassManagerCreateOnOperation(ctx, anchorOp) + @ccall mlir_c.mlirPassManagerCreateOnOperation( + ctx::MlirContext, anchorOp::MlirStringRef + )::MlirPassManager +end + +""" + mlirPassManagerDestroy(passManager) + +Destroy the provided PassManager. +""" +function mlirPassManagerDestroy(passManager) + @ccall mlir_c.mlirPassManagerDestroy(passManager::MlirPassManager)::Cvoid +end + +""" + mlirPassManagerIsNull(passManager) + +Checks if a PassManager is null. +""" +function mlirPassManagerIsNull(passManager) + @ccall mlir_c.mlirPassManagerIsNull(passManager::MlirPassManager)::Bool +end + +""" + mlirPassManagerGetAsOpPassManager(passManager) + +Cast a top-level PassManager to a generic OpPassManager. +""" +function mlirPassManagerGetAsOpPassManager(passManager) + @ccall mlir_c.mlirPassManagerGetAsOpPassManager( + passManager::MlirPassManager + )::MlirOpPassManager +end + +""" + mlirPassManagerRunOnOp(passManager, op) + +Run the provided `passManager` on the given `op`. +""" +function mlirPassManagerRunOnOp(passManager, op) + @ccall mlir_c.mlirPassManagerRunOnOp( + passManager::MlirPassManager, op::MlirOperation + )::MlirLogicalResult +end + +""" + mlirPassManagerEnableIRPrinting(passManager, printBeforeAll, printAfterAll, printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, flags, treePrintingPath) + +Enable IR printing. The treePrintingPath argument is an optional path to a directory where the dumps will be produced. If it isn't provided then dumps are produced to stderr. +""" +function mlirPassManagerEnableIRPrinting( + passManager, + printBeforeAll, + printAfterAll, + printModuleScope, + printAfterOnlyOnChange, + printAfterOnlyOnFailure, + flags, + treePrintingPath, +) + @ccall mlir_c.mlirPassManagerEnableIRPrinting( + passManager::MlirPassManager, + printBeforeAll::Bool, + printAfterAll::Bool, + printModuleScope::Bool, + printAfterOnlyOnChange::Bool, + printAfterOnlyOnFailure::Bool, + flags::MlirOpPrintingFlags, + treePrintingPath::MlirStringRef, + )::Cvoid +end + +""" + mlirPassManagerEnableVerifier(passManager, enable) + +Enable / disable verify-each. +""" +function mlirPassManagerEnableVerifier(passManager, enable) + @ccall mlir_c.mlirPassManagerEnableVerifier( + passManager::MlirPassManager, enable::Bool + )::Cvoid +end + +""" + mlirPassManagerGetNestedUnder(passManager, operationName) -Register a symbol with the jit: this symbol will be accessible to the jitted code. +Nest an OpPassManager under the top-level PassManager, the nested passmanager will only run on operations matching the provided name. The returned OpPassManager will be destroyed when the parent is destroyed. To further nest more OpPassManager under the newly returned one, see `mlirOpPassManagerNest` below. """ -function mlirExecutionEngineRegisterSymbol(jit, name, sym) - @ccall mlir_c.mlirExecutionEngineRegisterSymbol( - jit::MlirExecutionEngine, name::MlirStringRef, sym::Ptr{Cvoid} - )::Cvoid +function mlirPassManagerGetNestedUnder(passManager, operationName) + @ccall mlir_c.mlirPassManagerGetNestedUnder( + passManager::MlirPassManager, operationName::MlirStringRef + )::MlirOpPassManager end """ - mlirExecutionEngineDumpToObjectFile(jit, fileName) + mlirOpPassManagerGetNestedUnder(passManager, operationName) -Dump as an object in `fileName`. +Nest an OpPassManager under the provided OpPassManager, the nested passmanager will only run on operations matching the provided name. The returned OpPassManager will be destroyed when the parent is destroyed. """ -function mlirExecutionEngineDumpToObjectFile(jit, fileName) - @ccall mlir_c.mlirExecutionEngineDumpToObjectFile( - jit::MlirExecutionEngine, fileName::MlirStringRef - )::Cvoid +function mlirOpPassManagerGetNestedUnder(passManager, operationName) + @ccall mlir_c.mlirOpPassManagerGetNestedUnder( + passManager::MlirOpPassManager, operationName::MlirStringRef + )::MlirOpPassManager end """ - mlirOperationImplementsInterface(operation, interfaceTypeID) + mlirPassManagerAddOwnedPass(passManager, pass) -Returns `true` if the given operation implements an interface identified by its TypeID. +Add a pass and transfer ownership to the provided top-level mlirPassManager. If the pass is not a generic operation pass or a ModulePass, a new OpPassManager is implicitly nested under the provided PassManager. """ -function mlirOperationImplementsInterface(operation, interfaceTypeID) - @ccall mlir_c.mlirOperationImplementsInterface( - operation::MlirOperation, interfaceTypeID::MlirTypeID - )::Bool +function mlirPassManagerAddOwnedPass(passManager, pass) + @ccall mlir_c.mlirPassManagerAddOwnedPass( + passManager::MlirPassManager, pass::MlirPass + )::Cvoid end """ - mlirOperationImplementsInterfaceStatic(operationName, context, interfaceTypeID) + mlirOpPassManagerAddOwnedPass(passManager, pass) -Returns `true` if the operation identified by its canonical string name implements the interface identified by its TypeID in the given context. Note that interfaces may be attached to operations in some contexts and not others. +Add a pass and transfer ownership to the provided mlirOpPassManager. If the pass is not a generic operation pass or matching the type of the provided PassManager, a new OpPassManager is implicitly nested under the provided PassManager. """ -function mlirOperationImplementsInterfaceStatic(operationName, context, interfaceTypeID) - @ccall mlir_c.mlirOperationImplementsInterfaceStatic( - operationName::MlirStringRef, context::MlirContext, interfaceTypeID::MlirTypeID - )::Bool +function mlirOpPassManagerAddOwnedPass(passManager, pass) + @ccall mlir_c.mlirOpPassManagerAddOwnedPass( + passManager::MlirOpPassManager, pass::MlirPass + )::Cvoid end """ - mlirInferTypeOpInterfaceTypeID() + mlirOpPassManagerAddPipeline(passManager, pipelineElements, callback, userData) -Returns the interface TypeID of the InferTypeOpInterface. +Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager. If parsing fails an error message is reported using the provided callback. """ -function mlirInferTypeOpInterfaceTypeID() - @ccall mlir_c.mlirInferTypeOpInterfaceTypeID()::MlirTypeID +function mlirOpPassManagerAddPipeline(passManager, pipelineElements, callback, userData) + @ccall mlir_c.mlirOpPassManagerAddPipeline( + passManager::MlirOpPassManager, + pipelineElements::MlirStringRef, + callback::MlirStringCallback, + userData::Ptr{Cvoid}, + )::MlirLogicalResult end -# typedef void ( * MlirTypesCallback ) ( intptr_t , MlirType * , void * ) """ -These callbacks are used to return multiple types from functions while transferring ownership to the caller. The first argument is the number of consecutive elements pointed to by the second argument. The third argument is an opaque pointer forwarded to the callback by the caller. + mlirPrintPassPipeline(passManager, callback, userData) + +Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ -const MlirTypesCallback = Ptr{Cvoid} +function mlirPrintPassPipeline(passManager, callback, userData) + @ccall mlir_c.mlirPrintPassPipeline( + passManager::MlirOpPassManager, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid +end """ - mlirInferTypeOpInterfaceInferReturnTypes(opName, context, location, nOperands, operands, attributes, properties, nRegions, regions, callback, userData) + mlirParsePassPipeline(passManager, pipeline, callback, userData) -Infers the return types of the operation identified by its canonical given the arguments that will be supplied to its generic builder. Calls `callback` with the types of inferred arguments, potentially several times, on success. Returns failure otherwise. +Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager. If parsing fails an error message is reported using the provided callback. """ -function mlirInferTypeOpInterfaceInferReturnTypes( - opName, - context, - location, - nOperands, - operands, - attributes, - properties, - nRegions, - regions, - callback, - userData, -) - @ccall mlir_c.mlirInferTypeOpInterfaceInferReturnTypes( - opName::MlirStringRef, - context::MlirContext, - location::MlirLocation, - nOperands::intptr_t, - operands::Ptr{MlirValue}, - attributes::MlirAttribute, - properties::Ptr{Cvoid}, - nRegions::intptr_t, - regions::Ptr{MlirRegion}, - callback::MlirTypesCallback, +function mlirParsePassPipeline(passManager, pipeline, callback, userData) + @ccall mlir_c.mlirParsePassPipeline( + passManager::MlirOpPassManager, + pipeline::MlirStringRef, + callback::MlirStringCallback, userData::Ptr{Cvoid}, )::MlirLogicalResult end """ - mlirInferShapedTypeOpInterfaceTypeID() + MlirExternalPassCallbacks -Returns the interface TypeID of the InferShapedTypeOpInterface. -""" -function mlirInferShapedTypeOpInterfaceTypeID() - @ccall mlir_c.mlirInferShapedTypeOpInterfaceTypeID()::MlirTypeID -end +Structure of external [`MlirPass`](@ref) callbacks. All callbacks are required to be set unless otherwise specified. -# typedef void ( * MlirShapedTypeComponentsCallback ) ( bool , intptr_t , const int64_t * , MlirType , MlirAttribute , void * ) -""" -These callbacks are used to return multiple shaped type components from functions while transferring ownership to the caller. The first argument is the has rank boolean followed by the the rank and a pointer to the shape (if applicable). The next argument is the element type, then the attribute. The last argument is an opaque pointer forwarded to the callback by the caller. This callback will be called potentially multiple times for each shaped type components. +| Field | Note | +| :--------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| construct | This callback is called from the pass is created. This is analogous to a C++ pass constructor. | +| destruct | This callback is called when the pass is destroyed This is analogous to a C++ pass destructor. | +| initialize | This callback is optional. The callback is called before the pass is run, allowing a chance to initialize any complex state necessary for running the pass. See Pass::initialize(MLIRContext *). | +| clone | This callback is called when the pass is cloned. See Pass::clonePass(). | +| run | This callback is called when the pass is run. See Pass::runOnOperation(). | """ -const MlirShapedTypeComponentsCallback = Ptr{Cvoid} +struct MlirExternalPassCallbacks + construct::Ptr{Cvoid} + destruct::Ptr{Cvoid} + initialize::Ptr{Cvoid} + clone::Ptr{Cvoid} + run::Ptr{Cvoid} +end """ - mlirInferShapedTypeOpInterfaceInferReturnTypes(opName, context, location, nOperands, operands, attributes, properties, nRegions, regions, callback, userData) + mlirCreateExternalPass(passID, name, argument, description, opName, nDependentDialects, dependentDialects, callbacks, userData) -Infers the return shaped type components of the operation. Calls `callback` with the types of inferred arguments on success. Returns failure otherwise. +Creates an external [`MlirPass`](@ref) that calls the supplied `callbacks` using the supplied `userData`. If `opName` is empty, the pass is a generic operation pass. Otherwise it is an operation pass specific to the specified pass name. """ -function mlirInferShapedTypeOpInterfaceInferReturnTypes( +function mlirCreateExternalPass( + passID, + name, + argument, + description, opName, - context, - location, - nOperands, - operands, - attributes, - properties, - nRegions, - regions, - callback, + nDependentDialects, + dependentDialects, + callbacks, userData, ) - @ccall mlir_c.mlirInferShapedTypeOpInterfaceInferReturnTypes( + @ccall mlir_c.mlirCreateExternalPass( + passID::MlirTypeID, + name::MlirStringRef, + argument::MlirStringRef, + description::MlirStringRef, opName::MlirStringRef, - context::MlirContext, - location::MlirLocation, - nOperands::intptr_t, - operands::Ptr{MlirValue}, - attributes::MlirAttribute, - properties::Ptr{Cvoid}, - nRegions::intptr_t, - regions::Ptr{MlirRegion}, - callback::MlirShapedTypeComponentsCallback, + nDependentDialects::intptr_t, + dependentDialects::Ptr{MlirDialectHandle}, + callbacks::MlirExternalPassCallbacks, userData::Ptr{Cvoid}, - )::MlirLogicalResult + )::MlirPass +end + +""" + mlirExternalPassSignalFailure(pass) + +This signals that the pass has failed. This is only valid to call during the `run` callback of [`MlirExternalPassCallbacks`](@ref). See Pass::signalPassFailure(). +""" +function mlirExternalPassSignalFailure(pass) + @ccall mlir_c.mlirExternalPassSignalFailure(pass::MlirExternalPass)::Cvoid end """ @@ -9879,206 +9097,116 @@ function LLVMParseCommandLineOptions(argc, argv, Overview) )::Cvoid end -""" - LLVMSearchForAddressOfSymbol(symbolName) - -This function will search through all previously loaded dynamic libraries for the symbol `symbolName`. If it is found, the address of that symbol is returned. If not, null is returned. - -# See also -sys::DynamicLibrary::SearchForAddressOfSymbol() -""" -function LLVMSearchForAddressOfSymbol(symbolName) - @ccall mlir_c.LLVMSearchForAddressOfSymbol(symbolName::Cstring)::Ptr{Cvoid} -end - -""" - LLVMAddSymbol(symbolName, symbolValue) - -This functions permanently adds the symbol `symbolName` with the value `symbolValue`. These symbols are searched before any libraries. - -# See also -sys::DynamicLibrary::AddSymbol() -""" -function LLVMAddSymbol(symbolName, symbolValue) - @ccall mlir_c.LLVMAddSymbol(symbolName::Cstring, symbolValue::Ptr{Cvoid})::Cvoid -end - -""" - mlirTranslateModuleToLLVMIR(_module, context) - -Translate operation that satisfies LLVM dialect module requirements into an LLVM IR module living in the given context. This translates operations from any dilalect that has a registered implementation of LLVMTranslationDialectInterface. - -# Returns -the generated LLVM IR Module from the translated MLIR module, it is owned by the caller. -""" -function mlirTranslateModuleToLLVMIR(_module, context) - @ccall mlir_c.mlirTranslateModuleToLLVMIR( - _module::MlirOperation, context::LLVMContextRef - )::LLVMModuleRef -end - -function mlirRegisterTransformsPasses() - @ccall mlir_c.mlirRegisterTransformsPasses()::Cvoid -end - -function mlirCreateTransformsCSE() - @ccall mlir_c.mlirCreateTransformsCSE()::MlirPass -end - -function mlirRegisterTransformsCSE() - @ccall mlir_c.mlirRegisterTransformsCSE()::Cvoid -end - -function mlirCreateTransformsCanonicalizer() - @ccall mlir_c.mlirCreateTransformsCanonicalizer()::MlirPass -end - -function mlirRegisterTransformsCanonicalizer() - @ccall mlir_c.mlirRegisterTransformsCanonicalizer()::Cvoid -end - -function mlirCreateTransformsCompositeFixedPointPass() - @ccall mlir_c.mlirCreateTransformsCompositeFixedPointPass()::MlirPass -end - -function mlirRegisterTransformsCompositeFixedPointPass() - @ccall mlir_c.mlirRegisterTransformsCompositeFixedPointPass()::Cvoid -end - -function mlirCreateTransformsControlFlowSink() - @ccall mlir_c.mlirCreateTransformsControlFlowSink()::MlirPass -end - -function mlirRegisterTransformsControlFlowSink() - @ccall mlir_c.mlirRegisterTransformsControlFlowSink()::Cvoid -end - -function mlirCreateTransformsGenerateRuntimeVerification() - @ccall mlir_c.mlirCreateTransformsGenerateRuntimeVerification()::MlirPass -end - -function mlirRegisterTransformsGenerateRuntimeVerification() - @ccall mlir_c.mlirRegisterTransformsGenerateRuntimeVerification()::Cvoid -end - -function mlirCreateTransformsInliner() - @ccall mlir_c.mlirCreateTransformsInliner()::MlirPass -end - -function mlirRegisterTransformsInliner() - @ccall mlir_c.mlirRegisterTransformsInliner()::Cvoid -end - -function mlirCreateTransformsLocationSnapshot() - @ccall mlir_c.mlirCreateTransformsLocationSnapshot()::MlirPass -end - -function mlirRegisterTransformsLocationSnapshot() - @ccall mlir_c.mlirRegisterTransformsLocationSnapshot()::Cvoid -end - -function mlirCreateTransformsLoopInvariantCodeMotion() - @ccall mlir_c.mlirCreateTransformsLoopInvariantCodeMotion()::MlirPass -end - -function mlirRegisterTransformsLoopInvariantCodeMotion() - @ccall mlir_c.mlirRegisterTransformsLoopInvariantCodeMotion()::Cvoid -end - -function mlirCreateTransformsLoopInvariantSubsetHoisting() - @ccall mlir_c.mlirCreateTransformsLoopInvariantSubsetHoisting()::MlirPass -end - -function mlirRegisterTransformsLoopInvariantSubsetHoisting() - @ccall mlir_c.mlirRegisterTransformsLoopInvariantSubsetHoisting()::Cvoid -end +""" + LLVMSearchForAddressOfSymbol(symbolName) -function mlirCreateTransformsMem2Reg() - @ccall mlir_c.mlirCreateTransformsMem2Reg()::MlirPass -end +This function will search through all previously loaded dynamic libraries for the symbol `symbolName`. If it is found, the address of that symbol is returned. If not, null is returned. -function mlirRegisterTransformsMem2Reg() - @ccall mlir_c.mlirRegisterTransformsMem2Reg()::Cvoid +# See also +sys::DynamicLibrary::SearchForAddressOfSymbol() +""" +function LLVMSearchForAddressOfSymbol(symbolName) + @ccall mlir_c.LLVMSearchForAddressOfSymbol(symbolName::Cstring)::Ptr{Cvoid} end -function mlirCreateTransformsPrintIRPass() - @ccall mlir_c.mlirCreateTransformsPrintIRPass()::MlirPass -end +""" + LLVMAddSymbol(symbolName, symbolValue) -function mlirRegisterTransformsPrintIRPass() - @ccall mlir_c.mlirRegisterTransformsPrintIRPass()::Cvoid -end +This functions permanently adds the symbol `symbolName` with the value `symbolValue`. These symbols are searched before any libraries. -function mlirCreateTransformsPrintOpStats() - @ccall mlir_c.mlirCreateTransformsPrintOpStats()::MlirPass +# See also +sys::DynamicLibrary::AddSymbol() +""" +function LLVMAddSymbol(symbolName, symbolValue) + @ccall mlir_c.LLVMAddSymbol(symbolName::Cstring, symbolValue::Ptr{Cvoid})::Cvoid end -function mlirRegisterTransformsPrintOpStats() - @ccall mlir_c.mlirRegisterTransformsPrintOpStats()::Cvoid -end +""" + mlirTranslateModuleToLLVMIR(_module, context) -function mlirCreateTransformsRemoveDeadValues() - @ccall mlir_c.mlirCreateTransformsRemoveDeadValues()::MlirPass -end +Translate operation that satisfies LLVM dialect module requirements into an LLVM IR module living in the given context. This translates operations from any dilalect that has a registered implementation of LLVMTranslationDialectInterface. -function mlirRegisterTransformsRemoveDeadValues() - @ccall mlir_c.mlirRegisterTransformsRemoveDeadValues()::Cvoid +# Returns +the generated LLVM IR Module from the translated MLIR module, it is owned by the caller. +""" +function mlirTranslateModuleToLLVMIR(_module, context) + @ccall mlir_c.mlirTranslateModuleToLLVMIR( + _module::MlirOperation, context::LLVMContextRef + )::LLVMModuleRef end -function mlirCreateTransformsSCCP() - @ccall mlir_c.mlirCreateTransformsSCCP()::MlirPass +struct MlirTypeFromLLVMIRTranslator + ptr::Ptr{Cvoid} end -function mlirRegisterTransformsSCCP() - @ccall mlir_c.mlirRegisterTransformsSCCP()::Cvoid -end +""" + mlirTypeFromLLVMIRTranslatorCreate(ctx) -function mlirCreateTransformsSROA() - @ccall mlir_c.mlirCreateTransformsSROA()::MlirPass +Create an LLVM::TypeFromLLVMIRTranslator and transfer ownership to the caller. +""" +function mlirTypeFromLLVMIRTranslatorCreate(ctx) + @ccall mlir_c.mlirTypeFromLLVMIRTranslatorCreate( + ctx::MlirContext + )::MlirTypeFromLLVMIRTranslator end -function mlirRegisterTransformsSROA() - @ccall mlir_c.mlirRegisterTransformsSROA()::Cvoid -end +""" + mlirTypeFromLLVMIRTranslatorDestroy(translator) -function mlirCreateTransformsStripDebugInfo() - @ccall mlir_c.mlirCreateTransformsStripDebugInfo()::MlirPass +Takes an LLVM::TypeFromLLVMIRTranslator owned by the caller and destroys it. It is the responsibility of the user to only pass an LLVM::TypeFromLLVMIRTranslator class. +""" +function mlirTypeFromLLVMIRTranslatorDestroy(translator) + @ccall mlir_c.mlirTypeFromLLVMIRTranslatorDestroy( + translator::MlirTypeFromLLVMIRTranslator + )::Cvoid end -function mlirRegisterTransformsStripDebugInfo() - @ccall mlir_c.mlirRegisterTransformsStripDebugInfo()::Cvoid -end +""" + mlirTypeFromLLVMIRTranslatorTranslateType(translator, llvmType) -function mlirCreateTransformsSymbolDCE() - @ccall mlir_c.mlirCreateTransformsSymbolDCE()::MlirPass +Translates the given LLVM IR type to the MLIR LLVM dialect. +""" +function mlirTypeFromLLVMIRTranslatorTranslateType(translator, llvmType) + @ccall mlir_c.mlirTypeFromLLVMIRTranslatorTranslateType( + translator::MlirTypeFromLLVMIRTranslator, llvmType::LLVMTypeRef + )::MlirType end -function mlirRegisterTransformsSymbolDCE() - @ccall mlir_c.mlirRegisterTransformsSymbolDCE()::Cvoid +struct MlirTypeToLLVMIRTranslator + ptr::Ptr{Cvoid} end -function mlirCreateTransformsSymbolPrivatize() - @ccall mlir_c.mlirCreateTransformsSymbolPrivatize()::MlirPass -end +""" + mlirTypeToLLVMIRTranslatorCreate(ctx) -function mlirRegisterTransformsSymbolPrivatize() - @ccall mlir_c.mlirRegisterTransformsSymbolPrivatize()::Cvoid +Create an LLVM::TypeToLLVMIRTranslator and transfer ownership to the caller. +""" +function mlirTypeToLLVMIRTranslatorCreate(ctx) + @ccall mlir_c.mlirTypeToLLVMIRTranslatorCreate( + ctx::LLVMContextRef + )::MlirTypeToLLVMIRTranslator end -function mlirCreateTransformsTopologicalSort() - @ccall mlir_c.mlirCreateTransformsTopologicalSort()::MlirPass -end +""" + mlirTypeToLLVMIRTranslatorDestroy(translator) -function mlirRegisterTransformsTopologicalSort() - @ccall mlir_c.mlirRegisterTransformsTopologicalSort()::Cvoid +Takes an LLVM::TypeToLLVMIRTranslator owned by the caller and destroys it. It is the responsibility of the user to only pass an LLVM::TypeToLLVMIRTranslator class. +""" +function mlirTypeToLLVMIRTranslatorDestroy(translator) + @ccall mlir_c.mlirTypeToLLVMIRTranslatorDestroy( + translator::MlirTypeToLLVMIRTranslator + )::Cvoid end -function mlirCreateTransformsViewOpGraph() - @ccall mlir_c.mlirCreateTransformsViewOpGraph()::MlirPass -end +""" + mlirTypeToLLVMIRTranslatorTranslateType(translator, mlirType) -function mlirRegisterTransformsViewOpGraph() - @ccall mlir_c.mlirRegisterTransformsViewOpGraph()::Cvoid +Translates the given MLIR LLVM dialect to the LLVM IR type. +""" +function mlirTypeToLLVMIRTranslatorTranslateType(translator, mlirType) + @ccall mlir_c.mlirTypeToLLVMIRTranslatorTranslateType( + translator::MlirTypeToLLVMIRTranslator, mlirType::MlirType + )::LLVMTypeRef end function stablehloScatterDimensionNumbersGet( @@ -10707,4 +9835,397 @@ function stablehloTypeExtensionsGetBoundsElem(attr, pos) )::Int64 end +function stablehloResultAccuracyModeAttrGet(ctx, value) + @ccall mlir_c.stablehloResultAccuracyModeAttrGet( + ctx::MlirContext, value::MlirStringRef + )::MlirAttribute +end + +function stablehloAttributeIsAResultAccuracyModeAttr(attr) + @ccall mlir_c.stablehloAttributeIsAResultAccuracyModeAttr(attr::MlirAttribute)::Bool +end + +function stablehloResultAccuracyModeAttrGetValue(attr) + @ccall mlir_c.stablehloResultAccuracyModeAttrGetValue( + attr::MlirAttribute + )::MlirStringRef +end + +function stablehloResultAccuracyAttrGet(ctx, atol, rtol, ulps, value) + @ccall mlir_c.stablehloResultAccuracyAttrGet( + ctx::MlirContext, atol::Cdouble, rtol::Cdouble, ulps::Int64, value::MlirStringRef + )::MlirAttribute +end + +function stablehloAttributeIsAResultAccuracyAttr(attr) + @ccall mlir_c.stablehloAttributeIsAResultAccuracyAttr(attr::MlirAttribute)::Bool +end + +function stablehloResultAccuracyAttrGetAtol(attr) + @ccall mlir_c.stablehloResultAccuracyAttrGetAtol(attr::MlirAttribute)::Cdouble +end + +function stablehloResultAccuracyAttrGetRtol(attr) + @ccall mlir_c.stablehloResultAccuracyAttrGetRtol(attr::MlirAttribute)::Cdouble +end + +function stablehloResultAccuracyAttrGetUlps(attr) + @ccall mlir_c.stablehloResultAccuracyAttrGetUlps(attr::MlirAttribute)::Int64 +end + +function stablehloResultAccuracyAttrGetMode(attr) + @ccall mlir_c.stablehloResultAccuracyAttrGetMode(attr::MlirAttribute)::MlirAttribute +end + +function sdyAttributeIsAMeshAxisAttr(attr) + @ccall mlir_c.sdyAttributeIsAMeshAxisAttr(attr::MlirAttribute)::Bool +end + +function sdyMeshAxisAttrGet(ctx, name, size) + @ccall mlir_c.sdyMeshAxisAttrGet( + ctx::MlirContext, name::MlirStringRef, size::Int64 + )::MlirAttribute +end + +function sdyMeshAxisAttrGetName(attr) + @ccall mlir_c.sdyMeshAxisAttrGetName(attr::MlirAttribute)::MlirStringRef +end + +function sdyMeshAxisAttrGetSize(attr) + @ccall mlir_c.sdyMeshAxisAttrGetSize(attr::MlirAttribute)::Int64 +end + +function sdyAttributeIsAMeshAttr(attr) + @ccall mlir_c.sdyAttributeIsAMeshAttr(attr::MlirAttribute)::Bool +end + +function sdyMeshAttrGet(ctx, nAxes, axes, nDeviceIds, deviceIds) + @ccall mlir_c.sdyMeshAttrGet( + ctx::MlirContext, + nAxes::intptr_t, + axes::Ptr{MlirAttribute}, + nDeviceIds::intptr_t, + deviceIds::Ptr{Int64}, + )::MlirAttribute +end + +function sdyMeshAttrGetDeviceIdsSize(attr) + @ccall mlir_c.sdyMeshAttrGetDeviceIdsSize(attr::MlirAttribute)::Int64 +end + +function sdyMeshAttrGetDeviceIdsElem(attr, pos) + @ccall mlir_c.sdyMeshAttrGetDeviceIdsElem(attr::MlirAttribute, pos::Int64)::Int64 +end + +function sdyMeshAttrGetAxesSize(attr) + @ccall mlir_c.sdyMeshAttrGetAxesSize(attr::MlirAttribute)::intptr_t +end + +function sdyMeshAttrGetAxesElem(attr, pos) + @ccall mlir_c.sdyMeshAttrGetAxesElem(attr::MlirAttribute, pos::intptr_t)::MlirAttribute +end + +function sdyAttributeIsASubAxisInfoAttr(attr) + @ccall mlir_c.sdyAttributeIsASubAxisInfoAttr(attr::MlirAttribute)::Bool +end + +function sdySubAxisInfoAttrGet(ctx, preSize, size) + @ccall mlir_c.sdySubAxisInfoAttrGet( + ctx::MlirContext, preSize::Int64, size::Int64 + )::MlirAttribute +end + +function sdySubAxisInfoAttrGetPreSize(attr) + @ccall mlir_c.sdySubAxisInfoAttrGetPreSize(attr::MlirAttribute)::Int64 +end + +function sdySubAxisInfoAttrGetSize(attr) + @ccall mlir_c.sdySubAxisInfoAttrGetSize(attr::MlirAttribute)::Int64 +end + +function sdyAttributeIsAnAxisRefAttr(attr) + @ccall mlir_c.sdyAttributeIsAnAxisRefAttr(attr::MlirAttribute)::Bool +end + +function sdyAxisRefAttrGet(ctx, name, subAxisInfo) + @ccall mlir_c.sdyAxisRefAttrGet( + ctx::MlirContext, name::MlirStringRef, subAxisInfo::MlirAttribute + )::MlirAttribute +end + +function sdyAxisRefAttrGetName(attr) + @ccall mlir_c.sdyAxisRefAttrGetName(attr::MlirAttribute)::MlirStringRef +end + +function sdyAxisRefAttrGetSubAxisInfo(attr) + @ccall mlir_c.sdyAxisRefAttrGetSubAxisInfo(attr::MlirAttribute)::MlirAttribute +end + +function sdyAttributeIsADimensionShardingAttr(attr) + @ccall mlir_c.sdyAttributeIsADimensionShardingAttr(attr::MlirAttribute)::Bool +end + +function sdyDimensionShardingAttrGet(ctx, nAxes, axes, isClosed, priority) + @ccall mlir_c.sdyDimensionShardingAttrGet( + ctx::MlirContext, + nAxes::intptr_t, + axes::Ptr{MlirAttribute}, + isClosed::Bool, + priority::Int64, + )::MlirAttribute +end + +function sdyDimensionShardingAttrGetAxesSize(attr) + @ccall mlir_c.sdyDimensionShardingAttrGetAxesSize(attr::MlirAttribute)::intptr_t +end + +function sdyDimensionShardingAttrGetAxesElem(attr, pos) + @ccall mlir_c.sdyDimensionShardingAttrGetAxesElem( + attr::MlirAttribute, pos::intptr_t + )::MlirAttribute +end + +function sdyDimensionShardingAttrGetIsClosed(attr) + @ccall mlir_c.sdyDimensionShardingAttrGetIsClosed(attr::MlirAttribute)::Bool +end + +function sdyDimensionShardingAttrGetPriority(attr) + @ccall mlir_c.sdyDimensionShardingAttrGetPriority(attr::MlirAttribute)::Int64 +end + +function sdyAttributeIsATensorShardingAttr(attr) + @ccall mlir_c.sdyAttributeIsATensorShardingAttr(attr::MlirAttribute)::Bool +end + +function sdyTensorShardingAttrGet( + ctx, meshOrRef, nDimShardings, dimShardings, nReplicatedAxes, replicatedAxes +) + @ccall mlir_c.sdyTensorShardingAttrGet( + ctx::MlirContext, + meshOrRef::MlirAttribute, + nDimShardings::intptr_t, + dimShardings::Ptr{MlirAttribute}, + nReplicatedAxes::intptr_t, + replicatedAxes::Ptr{MlirAttribute}, + )::MlirAttribute +end + +function sdyTensorShardingAttrGetMeshOrRef(attr) + @ccall mlir_c.sdyTensorShardingAttrGetMeshOrRef(attr::MlirAttribute)::MlirAttribute +end + +function sdyTensorShardingAttrGetDimShardingsSize(attr) + @ccall mlir_c.sdyTensorShardingAttrGetDimShardingsSize(attr::MlirAttribute)::intptr_t +end + +function sdyTensorShardingAttrGetDimShardingsElem(attr, pos) + @ccall mlir_c.sdyTensorShardingAttrGetDimShardingsElem( + attr::MlirAttribute, pos::intptr_t + )::MlirAttribute +end + +function sdyTensorShardingAttrGetReplicatedAxesSize(attr) + @ccall mlir_c.sdyTensorShardingAttrGetReplicatedAxesSize(attr::MlirAttribute)::intptr_t +end + +function sdyTensorShardingAttrGetReplicatedAxesElem(attr, pos) + @ccall mlir_c.sdyTensorShardingAttrGetReplicatedAxesElem( + attr::MlirAttribute, pos::intptr_t + )::MlirAttribute +end + +function sdyAttributeIsATensorShardingPerValueAttr(attr) + @ccall mlir_c.sdyAttributeIsATensorShardingPerValueAttr(attr::MlirAttribute)::Bool +end + +function sdyTensorShardingPerValueAttrGet(ctx, nShardings, shardings) + @ccall mlir_c.sdyTensorShardingPerValueAttrGet( + ctx::MlirContext, nShardings::intptr_t, shardings::Ptr{MlirAttribute} + )::MlirAttribute +end + +function sdyTensorShardingPerValueAttrGetShardingsSize(attr) + @ccall mlir_c.sdyTensorShardingPerValueAttrGetShardingsSize( + attr::MlirAttribute + )::intptr_t +end + +function sdyTensorShardingPerValueAttrGetShardingsElem(attr, pos) + @ccall mlir_c.sdyTensorShardingPerValueAttrGetShardingsElem( + attr::MlirAttribute, pos::intptr_t + )::MlirAttribute +end + +function sdyAttributeIsADimMappingAttr(attr) + @ccall mlir_c.sdyAttributeIsADimMappingAttr(attr::MlirAttribute)::Bool +end + +function sdyDimMappingAttrGet(ctx, nFactorIndices, factorIndices) + @ccall mlir_c.sdyDimMappingAttrGet( + ctx::MlirContext, nFactorIndices::intptr_t, factorIndices::Ptr{Int64} + )::MlirAttribute +end + +function sdyDimMappingAttrGetFactorIndicesSize(attr) + @ccall mlir_c.sdyDimMappingAttrGetFactorIndicesSize(attr::MlirAttribute)::intptr_t +end + +function sdyDimMappingAttrGetFactorIndicesElem(attr, pos) + @ccall mlir_c.sdyDimMappingAttrGetFactorIndicesElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 +end + +function sdyAttributeIsATensorMappingAttr(attr) + @ccall mlir_c.sdyAttributeIsATensorMappingAttr(attr::MlirAttribute)::Bool +end + +function sdyTensorMappingAttrGet(ctx, nMappings, mappings) + @ccall mlir_c.sdyTensorMappingAttrGet( + ctx::MlirContext, nMappings::intptr_t, mappings::Ptr{MlirAttribute} + )::MlirAttribute +end + +function sdyTensorMappingAttrGetRank(attr) + @ccall mlir_c.sdyTensorMappingAttrGetRank(attr::MlirAttribute)::intptr_t +end + +function sdyTensorMappingAttrGetDimMappingsSize(attr) + @ccall mlir_c.sdyTensorMappingAttrGetDimMappingsSize(attr::MlirAttribute)::intptr_t +end + +function sdyTensorMappingAttrGetDimMappingsElem(attr, pos) + @ccall mlir_c.sdyTensorMappingAttrGetDimMappingsElem( + attr::MlirAttribute, pos::intptr_t + )::MlirAttribute +end + +function sdyAttributeIsAOpShardingRuleAttr(attr) + @ccall mlir_c.sdyAttributeIsAOpShardingRuleAttr(attr::MlirAttribute)::Bool +end + +function sdyOpShardingRuleAttrGet( + ctx, + nFactorSizes, + factorSizes, + nOperandMappings, + operandMappings, + nResultMappings, + resultMappings, + nReductionFactors, + reductionFactors, + nNeedReplicationFactors, + needReplicationFactors, + nPermutationFactors, + permutationFactors, + isCustomRule, +) + @ccall mlir_c.sdyOpShardingRuleAttrGet( + ctx::MlirContext, + nFactorSizes::intptr_t, + factorSizes::Ptr{Int64}, + nOperandMappings::intptr_t, + operandMappings::Ptr{MlirAttribute}, + nResultMappings::intptr_t, + resultMappings::Ptr{MlirAttribute}, + nReductionFactors::intptr_t, + reductionFactors::Ptr{Int64}, + nNeedReplicationFactors::intptr_t, + needReplicationFactors::Ptr{Int64}, + nPermutationFactors::intptr_t, + permutationFactors::Ptr{Int64}, + isCustomRule::Bool, + )::MlirAttribute +end + +function sdyOpShardingRuleAttrGetIsCustom(attr) + @ccall mlir_c.sdyOpShardingRuleAttrGetIsCustom(attr::MlirAttribute)::Bool +end + +function sdyOpShardingRuleAttrGetFactorSizesSize(attr) + @ccall mlir_c.sdyOpShardingRuleAttrGetFactorSizesSize(attr::MlirAttribute)::intptr_t +end + +function sdyOpShardingRuleAttrGetFactorSizesElem(attr, pos) + @ccall mlir_c.sdyOpShardingRuleAttrGetFactorSizesElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 +end + +function sdyOpShardingRuleAttrGetOperandMappingsSize(attr) + @ccall mlir_c.sdyOpShardingRuleAttrGetOperandMappingsSize(attr::MlirAttribute)::intptr_t +end + +function sdyOpShardingRuleAttrGetOperandMappingsElem(attr, pos) + @ccall mlir_c.sdyOpShardingRuleAttrGetOperandMappingsElem( + attr::MlirAttribute, pos::intptr_t + )::MlirAttribute +end + +function sdyOpShardingRuleAttrGetResultMappingsSize(attr) + @ccall mlir_c.sdyOpShardingRuleAttrGetResultMappingsSize(attr::MlirAttribute)::intptr_t +end + +function sdyOpShardingRuleAttrGetResultMappingsElem(attr, pos) + @ccall mlir_c.sdyOpShardingRuleAttrGetResultMappingsElem( + attr::MlirAttribute, pos::intptr_t + )::MlirAttribute +end + +function sdyOpShardingRuleAttrGetReductionFactorsSize(attr) + @ccall mlir_c.sdyOpShardingRuleAttrGetReductionFactorsSize( + attr::MlirAttribute + )::intptr_t +end + +function sdyOpShardingRuleAttrGetReductionFactorsElem(attr, pos) + @ccall mlir_c.sdyOpShardingRuleAttrGetReductionFactorsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 +end + +function sdyOpShardingRuleAttrGetNeedReplicationFactorsSize(attr) + @ccall mlir_c.sdyOpShardingRuleAttrGetNeedReplicationFactorsSize( + attr::MlirAttribute + )::intptr_t +end + +function sdyOpShardingRuleAttrGetNeedReplicationFactorsElem(attr, pos) + @ccall mlir_c.sdyOpShardingRuleAttrGetNeedReplicationFactorsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 +end + +function sdyOpShardingRuleAttrGetPermutationFactorsSize(attr) + @ccall mlir_c.sdyOpShardingRuleAttrGetPermutationFactorsSize( + attr::MlirAttribute + )::intptr_t +end + +function sdyOpShardingRuleAttrGetPermutationFactorsElem(attr, pos) + @ccall mlir_c.sdyOpShardingRuleAttrGetPermutationFactorsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 +end + +function sdyAttributeIsAManualAxesAttr(attr) + @ccall mlir_c.sdyAttributeIsAManualAxesAttr(attr::MlirAttribute)::Bool +end + +function sdyManualAxesAttrGet(ctx, nAxes, axes) + @ccall mlir_c.sdyManualAxesAttrGet( + ctx::MlirContext, nAxes::intptr_t, axes::Ptr{MlirAttribute} + )::MlirAttribute +end + +function sdyManualAxesAttrGetAxesSize(attr) + @ccall mlir_c.sdyManualAxesAttrGetAxesSize(attr::MlirAttribute)::intptr_t +end + +function sdyManualAxesAttrGetAxesElem(attr, pos) + @ccall mlir_c.sdyManualAxesAttrGetAxesElem( + attr::MlirAttribute, pos::intptr_t + )::MlirStringRef +end + const MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL = -1 diff --git a/src/stdlibs/Base.jl b/src/stdlibs/Base.jl new file mode 100644 index 0000000000..39b14538cd --- /dev/null +++ b/src/stdlibs/Base.jl @@ -0,0 +1,4 @@ +@inline Base.vcat(a::Number, b::Union{AnyConcretePJRTArray,AnyTracedRArray}) = + @allowscalar(vcat(fill!(similar(b, typeof(a), (1, size(b)[2:end]...)), a), b)) +@inline Base.hcat(a::Number, b::Union{AnyConcretePJRTArray,AnyTracedRArray}) = + @allowscalar(hcat(fill!(similar(b, typeof(a), (size(b)[1:(end - 1)]..., 1)), a), b)) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index c011f8aec0..007117bcff 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -1,51 +1,226 @@ module TracedLinearAlgebra -using ..Reactant -import ..TracedRArray -import ..TracedRNumber -import ..AnyTracedRArray -import ..AnyTracedRMatrix -import ..AnyTracedRVector - -import ..TracedUtils -using ..TracedUtils: get_mlir_data, materialize_traced_array, set_mlir_data! - -import ..Ops -import ..MLIR +using ..Reactant: + TracedRArray, + TracedRNumber, + AnyTracedRArray, + AnyTracedRMatrix, + AnyTracedRVector, + AnyTracedRVecOrMat, + WrappedTracedRArray, + unwrapped_eltype, + Ops, + MLIR + +using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_mlir_data! + using LinearAlgebra -function LinearAlgebra.mul!( +# Various Wrapper Arrays defined in LinearAlgebra +function TracedUtils.materialize_traced_array( + x::Transpose{TracedRNumber{T},TracedRArray{T,N}} +) where {T,N} + px = parent(x) + A = ndims(px) == 1 ? reshape(px, :, 1) : px + return permutedims(A, (2, 1)) +end + +function TracedUtils.materialize_traced_array( + x::Transpose{TracedRNumber{T},<:WrappedTracedRArray{T,N}} +) where {T,N} + return materialize_traced_array(transpose(materialize_traced_array(parent(x)))) +end + +function TracedUtils.materialize_traced_array( + x::Adjoint{TracedRNumber{T},TracedRArray{T,N}} +) where {T,N} + return conj(materialize_traced_array(transpose(parent(x)))) +end + +function TracedUtils.materialize_traced_array( + x::Adjoint{TracedRNumber{T},<:WrappedTracedRArray{T,N}} +) where {T,N} + return materialize_traced_array(adjoint(materialize_traced_array(parent(x)))) +end + +function TracedUtils.materialize_traced_array( + x::Diagonal{TracedRNumber{T},TracedRArray{T,1}} +) where {T} + return diagm(parent(x)) +end + +function TracedUtils.materialize_traced_array( + x::Diagonal{TracedRNumber{T},WrappedTracedRArray{T,1}} +) where {T} + return diagm(materialize_traced_array(parent(x))) +end + +function TracedUtils.materialize_traced_array( + x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}} +) where {T} + return diagm(-1 => x.dl, 0 => x.d, 1 => x.du) +end + +for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) + uAT = Symbol(:Unit, AT) + @eval begin + function TracedUtils.materialize_traced_array( + x::$(AT){TracedRNumber{T},TracedRArray{T,2}} + ) where {T} + m, n = size(x) + row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) + col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) + indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(comp)) + return Ops.select(indicator, parent(x), zero(parent(x))) + end + + function TracedUtils.materialize_traced_array( + x::$(uAT){TracedRNumber{T},TracedRArray{T,2}} + ) where {T} + m, n = size(x) + row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) + col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) + nondiag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="NE") + x = materialize_traced_array($(AT)(parent(x))) + return Ops.select(nondiag_indicator, x, one.(x)) + end + end +end + +function TracedUtils.materialize_traced_array( + x::Symmetric{TracedRNumber{T},TracedRArray{T,2}} +) where {T} + m, n = size(x) + row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) + col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) + if x.uplo == 'L' + indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="GT") + x_lt = Ops.select(indicator, parent(x), zero(parent(x))) + x_ltd = materialize_traced_array(LowerTriangular(parent(x))) + return Ops.add(x_lt, Ops.transpose(x_ltd, [2, 1])) + else + indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="LT") + x_ut = Ops.select(indicator, parent(x), zero(parent(x))) + x_utd = materialize_traced_array(UpperTriangular(parent(x))) + return Ops.add(Ops.transpose(x_utd, [2, 1]), x_ut) + end +end + +function TracedUtils.set_mlir_data!( + x::Transpose{TracedRNumber{T},TracedRArray{T,N}}, data +) where {T,N} + tdata = TracedRArray{T}(data) + px = parent(x) + px.mlir_data = ( + if ndims(px) == 1 + Ops.reshape(tdata, length(tdata)) + else + Ops.transpose(tdata, [2, 1]) + end + ).mlir_data + return x +end + +function TracedUtils.set_mlir_data!( + x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}, data +) where {T,N} + tdata = TracedRArray{T}(data) + px = parent(x) + transposed_data = + ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) + px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data + return x +end + +function TracedUtils.set_mlir_data!( + x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data +) where {T} + parent(x).mlir_data = diag(TracedRArray{T}(data)).mlir_data + return x +end + +for (AT, dcomp, ocomp) in ( + (:LowerTriangular, "GE", "LT"), + (:UnitLowerTriangular, "GT", "LE"), + (:UpperTriangular, "LE", "GT"), + (:UnitUpperTriangular, "LT", "GE"), +) + @eval function TracedUtils.set_mlir_data!( + x::$(AT){TracedRNumber{T},TracedRArray{T,2}}, data + ) where {T} + tdata = TracedRArray{T}(data) + z = zero(tdata) + m, n = size(x) + row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) + col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) + data_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(dcomp)) + original_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(ocomp)) + res = Ops.add( + Ops.select(data_indicator, tdata, z), Ops.select(original_indicator, x.data, z) + ) + set_mlir_data!(x.data, res.mlir_data) + return x + end +end + +function TracedUtils.set_mlir_data!( + x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}, data +) where {T} + if x.uplo == 'L' + set_mlir_data!(LowerTriangular(parent(x)), data) + else + set_mlir_data!(UpperTriangular(parent(x)), data) + end + return x +end + +function TracedUtils.set_mlir_data!( + x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, data +) where {T} + tdata = TracedRArray{T}(data) + set_mlir_data!(x.dl, diag(tdata, -1).mlir_data) + set_mlir_data!(x.d, diag(tdata, 0).mlir_data) + set_mlir_data!(x.du, diag(tdata, 1).mlir_data) + return x +end + +# Core functions +function overloaded_mul!( @nospecialize(C::TracedRArray{T,1}), - @nospecialize(A::AnyTracedRMatrix), - @nospecialize(B::AnyTracedRVector), + @nospecialize(A::AbstractMatrix), + @nospecialize(B::AbstractVector), α::Number=true, β::Number=false, ) where {T} - # TODO: The reshape operations are not getting optimized, we should directly call dot_general + # TODO: The reshape operations are not getting optimized, we should directly call + # dot_general rC = Ops.reshape(C, length(C), 1) - LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β) + overloaded_mul!(rC, A, reshape(B, :, 1), α, β) C.mlir_data = get_mlir_data(vec(rC)) return C end -function LinearAlgebra.mul!( +function overloaded_mul!( @nospecialize(C::TracedRArray{T,2}), - @nospecialize(A::AnyTracedRMatrix), - @nospecialize(B::AnyTracedRVector), + @nospecialize(A::AbstractMatrix), + @nospecialize(B::AbstractVector), α::Number=true, β::Number=false, ) where {T} - LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β) + overloaded_mul!(C, A, reshape(B, :, 1), α, β) return C end -function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T,2}), - @nospecialize(A::AnyTracedRMatrix), - @nospecialize(B::AnyTracedRMatrix), +function overloaded_mul!( + @nospecialize(C::TracedRArray{T,2} where {T}), + @nospecialize(A::AbstractMatrix), + @nospecialize(B::AbstractMatrix), α::Number=true, β::Number=false, -) where {T} +) + A = TracedUtils.promote_to(TracedRArray{unwrapped_eltype(A),2}, A) + B = TracedUtils.promote_to(TracedRArray{unwrapped_eltype(B),2}, B) + if size(C) != (size(A, 1), size(B, 2)) throw( DimensionMismatch( @@ -57,6 +232,7 @@ function LinearAlgebra.mul!( throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))")) end + T = unwrapped_eltype(C) tmp = Ops.dot_general( T.(materialize_traced_array(A)), T.(materialize_traced_array(B)); @@ -119,50 +295,160 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} # :0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64> length(indices) ≤ 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[]) - idxs = get_mlir_data(TracedUtils.promote_to(TracedRArray{Int,2}, indices)) - - #! format: off - dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( - MLIR.IR.context(), - Int64(0), Int64[], - Int64(2), Int64[0, 1], - Int64(0), Int64[], - Int64(0), Int64[], - Int64(2), Int64[0, 1], - Int64(1) - ) - #! format: on + return Ops.gather_getindex(y, TracedUtils.promote_to(TracedRArray{Int,2}, indices)) +end - slice_sizes = get_mlir_data( - Reactant.TracedUtils.promote_to(TracedRArray{Int,1}, [1, 1]) - ) - res = MLIR.IR.result( - MLIR.Dialects.stablehlo.dynamic_gather( - get_mlir_data(y), idxs, slice_sizes; dimension_numbers - ), - 1, +function LinearAlgebra._diagm( + shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}... +) where {T} + m, n = LinearAlgebra.diagm_size(shape, kv...) + + # For repeated indices we need to aggregate the values + kv_updated = Dict{Integer,AnyTracedRArray{T,1}}() + for (k, v) in kv + if haskey(kv_updated, k) + kv_updated[k] = kv_updated[k] + v + else + kv_updated[k] = v + end + end + + scatter_indices = Matrix{Int64}[] + concat_inputs = MLIR.IR.Value[] + for (k, v) in pairs(kv_updated) + push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :]) + push!(concat_inputs, get_mlir_data(v)) + end + scatter_indices = Ops.constant(reduce(vcat, scatter_indices)) + values = TracedRArray{T,1}( + (), + MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1), + (size(scatter_indices, 1),), ) - return TracedRArray{T,1}((), res, (diag_length,)) + return Ops.scatter_setindex(Ops.fill(zero(T), (m, n)), scatter_indices, values) +end + +# Common Utilities +## The cartesian version doesn't exist in julia 1.10 +function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0) + idx1, idx2 = 1 + max(0, -k), 1 + max(0, k) + L = max(0, k ≤ 0 ? min(m + k, n) : min(m, n - k)) + indices = Matrix{Int}(undef, (L, 2)) + for i in axes(indices, 1) + indices[i, 1] = idx1 + i - 2 + indices[i, 2] = idx2 + i - 2 + end + return indices +end + +function LinearAlgebra.ldiv!( + B::Union{ + AbstractArray{<:TracedRNumber{T},1}, + AbstractArray{<:TracedRNumber{T},2}, + AnyTracedRArray{T,1}, + AnyTracedRArray{T,2}, + }, + D::Diagonal, + A::AbstractVecOrMat, +) where {T} + LinearAlgebra.require_one_based_indexing(A, B) + dd = D.diag + d = length(dd) + m, n = size(A, 1), size(A, 2) + m′, n′ = size(B, 1), size(B, 2) + m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d")) + (m, n) == (m′, n′) || + throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′")) + B .= dd .\ A + # OG implementation below, we don't currently support the conditional throw exception + #j = findfirst(iszero, D.diag) + #isnothing(j) || throw(SingularException(j)) + #@inbounds for j = 1:n, i = 1:m + # B[i, j] = dd[i] \ A[i, j] + #end + return B end -function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T} - return LinearAlgebra.diagm(length(v), length(v), v) +# Kronecker Product +function LinearAlgebra.kron( + x::AnyTracedRVecOrMat{T1}, y::AnyTracedRVecOrMat{T2} +) where {T1,T2} + x = materialize_traced_array(x) + y = materialize_traced_array(y) + z = similar(x, Base.promote_op(*, T1, T2), LinearAlgebra._kronsize(x, y)) + LinearAlgebra.kron!(z, x, y) + return z end -function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T} - m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check - v = materialize_traced_array(v) - D = length(v) - row_idxs = Ops.iota(Int, [D, D]; iota_dimension=1) - col_idxs = Ops.iota(Int, [D, D]; iota_dimension=2) - diag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="EQ") +function LinearAlgebra.kron(x::AnyTracedRVector{T1}, y::AnyTracedRVector{T2}) where {T1,T2} + x = materialize_traced_array(x) + y = materialize_traced_array(y) + z = similar(x, Base.promote_op(*, T1, T2), length(x) * length(y)) + LinearAlgebra.kron!(z, x, y) + return z +end - mat = (v .+ zero(v)') .* diag_indicator - return Ops.pad( - mat, - TracedUtils.promote_to(TracedRNumber{T}, 0); - high=[m - length(v), n - length(v)], +function LinearAlgebra.kron!(C::AnyTracedRVector, A::AnyTracedRVector, B::AnyTracedRVector) + LinearAlgebra.kron!( + reshape(C, length(B), length(A)), reshape(A, 1, length(A)), reshape(B, length(B), 1) ) + return C +end + +function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTracedRMatrix) + A = materialize_traced_array(A) + B = materialize_traced_array(B) + + final_shape = Int64[size(B, 1), size(A, 1), size(B, 2), size(A, 2)] + + A = Ops.broadcast_in_dim(A, Int64[2, 4], final_shape) + B = Ops.broadcast_in_dim(B, Int64[1, 3], final_shape) + + C_tmp = Ops.reshape(Ops.multiply(A, B), size(C)...) + set_mlir_data!(C, get_mlir_data(C_tmp)) + + return C +end + +function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRVector, B::AnyTracedRMatrix) + LinearAlgebra._kron!(C, reshape(A, length(A), 1), B) + return C +end + +function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTracedRVector) + LinearAlgebra._kron!(C, A, reshape(B, length(B), 1)) + return C +end + +function LinearAlgebra.axpy!(α::Number, x::TracedRArray{T}, y::TracedRArray{T}) where {T} + if length(x) != length(y) + throw( + DimensionMismatch( + lazy"x has length $(length(x)), but y has length $(length(y))" + ), + ) + end + ax = Ops.multiply(x, TracedUtils.broadcast_to_size(T(α), size(x))) + + set_mlir_data!(y, get_mlir_data(Ops.add(y, ax))) + return y +end + +function LinearAlgebra.axpby!( + α::Number, x::TracedRArray{T}, β::Number, y::TracedRArray{T} +) where {T} + if length(x) != length(y) + throw( + DimensionMismatch( + lazy"x has length $(length(x)), but y has length $(length(y))" + ), + ) + end + ax = Ops.multiply(x, TracedUtils.broadcast_to_size(T(α), size(x))) + by = Ops.multiply(y, TracedUtils.broadcast_to_size(T(β), size(y))) + + set_mlir_data!(y, get_mlir_data(Ops.add(ax, by))) + return y end end diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 271b78f802..708a0f1e71 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -8,61 +8,72 @@ using ..Reactant: Reactant, TracedRArray, TracedRNumber, + ConcreteRNG, TracedRNG, AnyTracedRArray, Reactant, TracedUtils, Ops, - ConcreteRArray + ConcretePJRTArray, + ConcretePJRTNumber, + unwrapped_eltype using Random: Random, AbstractRNG -@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice()) - # XXX: We should really be able to call this here. But with our AbsInt it leads to a - # segfault. So we'll just call it in the rand! method. - # return rand(rng, UInt64, 2) - seed = Array{UInt64}(undef, 2) - Random.rand!(rng, seed) - return seed -end +@noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) = + Random.rand!(rng, Vector{UInt64}(undef, 2)) -function Random.seed!(rng::TracedRNG, seed::Number) +@noinline function Random.seed!(rng::TracedRNG, seed::Number) if seed isa TracedRNumber error("Passing in `TracedRNumber` as a seed is not supported. Please pass in a \ `TracedRArray` of the appropriate size instead.") end seed = reinterpret(UInt64, Random.hash_seed(seed)) - seed = if Reactant.within_reactant_interpreter() - TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)]) - else - ConcreteRArray(seed[1:length(rng.seed)]) - end - return Random.seed!(rng, seed) + return Random.seed!( + rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)]) + ) end -function Random.seed!(rng::TracedRNG, seed::AbstractArray{<:Integer,1}) +@noinline function Random.seed!(rng::TracedRNG, seed::AbstractVector{<:Integer}) return Random.seed!(rng, UInt64.(seed)) end -function Random.seed!(rng::TracedRNG, seed::AbstractArray{UInt64,1}) +@noinline function Random.seed!(rng::TracedRNG, seed::AbstractVector{UInt64}) return Random.seed!(rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed)) end -function Random.seed!( - rng::TracedRNG, seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} -) +@noinline function Random.seed!(rng::TracedRNG, seed::TracedRArray{UInt64,1}) rng.seed = seed return rng end -@noinline TracedRNG() = TracedRNG(ConcreteRArray(make_seed())) -@noinline TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") +@noinline function Random.seed!(rng::ConcreteRNG, seed::Number) + seed isa ConcretePJRTNumber && (seed = unwrapped_eltype(seed)(seed)) + seed = reinterpret(UInt64, Random.hash_seed(seed)) + return Random.seed!(rng, ConcretePJRTArray(seed)) +end + +@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractVector{<:Integer}) + return Random.seed!(rng, seed) +end -@noinline function default_rng() - Reactant.within_reactant_interpreter() || return TracedRNG() - return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") +@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractVector{UInt64}) + return Random.seed!(rng, ConcretePJRTArray(seed)) end +@noinline function Random.seed!(rng::ConcreteRNG, seed::ConcretePJRTArray{UInt64,1}) + rng.seed = seed + return rng +end + +Base.copy(rng::ConcreteRNG) = ConcreteRNG(copy(rng.seed), rng.algorithm) +Base.copy(rng::TracedRNG) = TracedRNG(copy(rng.seed), rng.algorithm) + +@noinline ConcreteRNG() = ConcreteRNG(ConcretePJRTArray(make_seed())) +@noinline ConcreteRNG(seed::ConcretePJRTArray{UInt64,1}) = ConcreteRNG(seed, "DEFAULT") + +@noinline default_rng() = ConcreteRNG() + @noinline rng_algorithm(rng::TracedRNG) = rng.algorithm @noinline rng_algorithm(::AbstractRNG) = "DEFAULT" diff --git a/src/utils.jl b/src/utils.jl index 56fa7587b4..36c4ca7fac 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,5 @@ -function apply(f, args...; kwargs...) +function apply(f::F, args...; kwargs...) where {F} return f(args...; kwargs...) end @@ -89,19 +89,39 @@ function has_ancestor(query::Module, target::Module) end end -function should_rewrite_ft(@nospecialize(ft)) +function should_rewrite_call(@nospecialize(ft)) # Don't rewrite builtin or intrinsics if ft <: Core.IntrinsicFunction || ft <: Core.Builtin return false end if ft <: Core.Function - mod = ft.name.module - # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions - if has_ancestor(mod, Reactant.Ops) || - has_ancestor(mod, Reactant.TracedUtils) || - has_ancestor(mod, Reactant.MLIR) || - has_ancestor(mod, Reactant.TracedRandom) - return false + if hasfield(typeof(ft), :name) && + hasfield(typeof(ft.name), :name) && + isdefined(ft.name, :name) + namestr = String(ft.name.name) + if startswith(namestr, "##(overlay (. Reactant (inert REACTANT_METHOD_TABLE)") + return false + end + end + + # We need this for closures to work + if hasfield(typeof(ft), :name) && hasfield(typeof(ft.name), :module) + mod = ft.name.module + # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions + if has_ancestor(mod, Reactant.Ops) || + has_ancestor(mod, Reactant.TracedUtils) || + has_ancestor(mod, Reactant.MLIR) || + has_ancestor(mod, Reactant.TracedRandom) + return false + end + if string(mod) == "CUDA" + if ft.name.name == Symbol("#launch_configuration") + return false + end + if ft.name.name == Symbol("cudaconvert") + return false + end + end end end # Don't rewrite Val @@ -131,7 +151,10 @@ function should_rewrite_ft(@nospecialize(ft)) if ft <: Type{<:TracedRArray} || ft <: Type{<:TracedRNumber} || ft === Type{MLIR.IR.Location} || - ft === Type{MLIR.IR.Block} + ft === Type{MLIR.IR.Block} || + # TODO: perhaps problematic calls in `traced_call` + # should be moved to TracedUtils.jl: + ft <: typeof(Reactant.ReactantCore.traced_call) return false end @@ -145,7 +168,12 @@ function should_rewrite_ft(@nospecialize(ft)) ft <: typeof(Base.typemin) || ft <: typeof(Base.getproperty) || ft <: typeof(Base.vect) || - ft <: typeof(Base.eltype) + ft <: typeof(Base.eltype) || + ft <: typeof(Base.argtail) || + ft <: typeof(Base.identity) || + ft <: typeof(Base.print) || + ft <: typeof(Base.println) || + ft <: typeof(Adapt.adapt_structure) return false end @@ -153,6 +181,9 @@ function should_rewrite_ft(@nospecialize(ft)) return true end +# by default, same as `should_rewrite_call` +should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) = should_rewrite_call(ft) + # Avoid recursively interpreting into methods we define explicitly # as overloads, which we assume should handle the entirety of the # translation (and if not they can use call_in_reactant). @@ -165,7 +196,49 @@ function is_reactant_method(mi::Core.MethodInstance) return mt === REACTANT_METHOD_TABLE end -function rewrite_inst(inst, ir, interp) +struct MustThrowError end + +@generated function applyiterate_with_reactant( + iteratefn, applyfn, args::Vararg{Any,N} +) where {N} + if iteratefn != typeof(Base.iterate) + return quote + error("Unhandled apply_iterate with iteratefn=$iteratefn") + end + end + newargs = Vector{Expr}(undef, N) + for i in 1:N + @inbounds newargs[i] = :(args[$i]...) + end + quote + Base.@_inline_meta + call_with_reactant(applyfn, $(newargs...)) + end +end + +@generated function applyiterate_with_reactant( + mt::MustThrowError, iteratefn, applyfn, args::Vararg{Any,N} +) where {N} + @assert iteratefn == typeof(Base.iterate) + newargs = Vector{Expr}(undef, N) + for i in 1:N + @inbounds newargs[i] = :(args[$i]...) + end + quote + Base.@_inline_meta + call_with_reactant(mt, applyfn, $(newargs...)) + end +end + +function certain_error() + throw( + AssertionError( + "The inferred code was guaranteed to throw this error. And yet, it didn't. So here we are...", + ), + ) +end + +function rewrite_inst(inst, ir, interp, RT, guaranteed_error) if Meta.isexpr(inst, :call) # Even if type unstable we do not want (or need) to replace intrinsic # calls or builtins with our version. @@ -173,26 +246,58 @@ function rewrite_inst(inst, ir, interp) if ft == typeof(Core.kwcall) ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) end - if should_rewrite_ft(ft) - rep = Expr(:call, call_with_reactant, inst.args...) - return true, rep + if ft == typeof(Core._apply_iterate) + ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) + if Base.invokelatest(should_rewrite_call, ft) + if RT === Union{} + rep = Expr( + :call, + applyiterate_with_reactant, + MustThrowError(), + inst.args[2:end]..., + ) + return true, rep, Union{} + else + rep = Expr(:call, applyiterate_with_reactant, inst.args[2:end]...) + return true, rep, Any + end + end + elseif Base.invokelatest(should_rewrite_call, ft) + if RT === Union{} + rep = Expr(:call, call_with_reactant, MustThrowError(), inst.args...) + return true, rep, Union{} + else + rep = Expr(:call, call_with_reactant, inst.args...) + return true, rep, Any + end end end if Meta.isexpr(inst, :invoke) omi = inst.args[1]::Core.MethodInstance sig = omi.specTypes ft = sig.parameters[1] + argsig = sig.parameters[2:end] if ft == typeof(Core.kwcall) ft = sig.parameters[3] + argsig = sig.parameters[4:end] end - if should_rewrite_ft(ft) && !is_reactant_method(omi) + argsig = Core.apply_type(Core.Tuple, argsig...) + if Base.invokelatest(should_rewrite_invoke, ft, argsig) && !is_reactant_method(omi) method = omi.def::Core.Method min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) + # RT = Any + if !method.isva || !Base.isvarargtype(sig.parameters[end]) - sig2 = Tuple{typeof(call_with_reactant),sig.parameters...} + if RT === Union{} + sig2 = Tuple{ + typeof(call_with_reactant),MustThrowError,sig.parameters... + } + else + sig2 = Tuple{typeof(call_with_reactant),sig.parameters...} + end else vartup = inst.args[end] ns = Type[] @@ -200,9 +305,18 @@ function rewrite_inst(inst, ir, interp) for i in 1:(length(inst.args) - 1 - (length(sig.parameters) - 1)) push!(ns, eT) end - sig2 = Tuple{ - typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns... - } + if RT === Union{} + sig2 = Tuple{ + typeof(call_with_reactant), + MustThrowError, + sig.parameters[1:(end - 1)]..., + ns..., + } + else + sig2 = Tuple{ + typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns... + } + end end lookup_result = lookup_world( @@ -220,23 +334,60 @@ function rewrite_inst(inst, ir, interp) match.sparams, ) n_method_args = method.nargs - rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...) - return true, rep + if RT === Union{} + rep = Expr( + :invoke, mi, call_with_reactant, MustThrowError(), inst.args[2:end]... + ) + return true, rep, Union{} + else + rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...) + return true, rep, Any + end end end - return false, inst + if isa(inst, Core.ReturnNode) && (!isdefined(inst, :val) || guaranteed_error) + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + + sig2 = Tuple{typeof(certain_error)} + + lookup_result = lookup_world( + sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world + ) + + match = lookup_result::Core.MethodMatch + # look up the method and code instance + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, + ) + rep = Expr(:invoke, mi, certain_error) + return true, rep, Union{} + end + return false, inst, RT end -const oc_captures = Dict{Tuple{Type,Type,Core.CodeInfo,Int,Bool,Any},Core.OpaqueClosure}() +const oc_capture_vec = Vector{Any}() # Caching is both good to reducing compile times and necessary to work around julia bugs # in OpaqueClosure's: https://github.com/JuliaLang/julia/issues/56833 -function make_oc( - sig::Type, rt::Type, src::Core.CodeInfo, nargs::Int, isva::Bool, f::Any -)::Core.OpaqueClosure - key = (sig, rt, src, nargs, isva, f) +function make_oc_dict( + @nospecialize(oc_captures::Dict{FT,Core.OpaqueClosure}), + @nospecialize(sig::Type), + @nospecialize(rt::Type), + @nospecialize(src::Core.CodeInfo), + nargs::Int, + isva::Bool, + @nospecialize(f::FT) +)::Core.OpaqueClosure where {FT} + key = f if haskey(oc_captures, key) - return oc_captures[key] + oc = oc_captures[key] + oc else ores = ccall( :jl_new_opaque_closure_from_code_info, @@ -259,6 +410,39 @@ function make_oc( end end +function make_oc_ref( + oc_captures::Base.RefValue{Core.OpaqueClosure}, + @nospecialize(sig::Type), + @nospecialize(rt::Type), + @nospecialize(src::Core.CodeInfo), + nargs::Int, + isva::Bool, + @nospecialize(f) +)::Core.OpaqueClosure + if Base.isassigned(oc_captures) + return oc_captures[] + else + ores = ccall( + :jl_new_opaque_closure_from_code_info, + Any, + (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), + sig, + rt, + rt, + @__MODULE__, + src, + 0, + nothing, + nargs, + isva, + f, + true, + )::Core.OpaqueClosure + oc_captures[] = ores + return ores + end +end + function safe_print(name, x) return ccall(:jl_, Cvoid, (Any,), name * " " * string(x)) end @@ -270,19 +454,22 @@ const DEBUG_INTERP = Ref(false) # to Any if our interpreter would change the return type of any result. # Also rewrite invoke (type stable call) to be :call, since otherwise apparently # screws up type inference after this (TODO this should be fixed). -function rewrite_insts!(ir, interp) +function rewrite_insts!(ir, interp, guaranteed_error) any_changed = false for (i, inst) in enumerate(ir.stmts) + # Explicitly skip any code which returns Union{} so that we throw the error + # instead of risking a segfault + RT = inst[:type] @static if VERSION < v"1.11" - changed, next = rewrite_inst(inst[:inst], ir, interp) + changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error) Core.Compiler.setindex!(ir.stmts[i], next, :inst) else - changed, next = rewrite_inst(inst[:stmt], ir, interp) + changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error) Core.Compiler.setindex!(ir.stmts[i], next, :stmt) end if changed any_changed = true - Core.Compiler.setindex!(ir.stmts[i], Any, :type) + Core.Compiler.setindex!(ir.stmts[i], RT, :type) end end return ir, any_changed @@ -308,21 +495,35 @@ function call_with_reactant_generator( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() ) + fn = args[1] + sig = Tuple{args...} + + guaranteed_error = false + if fn === MustThrowError + guaranteed_error = true + fn = args[2] + sig = Tuple{args[2:end]...} + end + # look up the method match - builtin_error = :(throw( - AssertionError("Unsupported call_with_reactant of builtin $redub_arguments") - )) + builtin_error = + :(throw(AssertionError("Unsupported call_with_reactant of builtin $fn"))) - if args[1] <: Core.Builtin + if fn <: Core.Builtin return stub(world, source, builtin_error) end - method_error = :(throw( - MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) - )) - interp = ReactantInterpreter(; world) + if guaranteed_error + method_error = :(throw( + MethodError($REDUB_ARGUMENTS_NAME[2], $REDUB_ARGUMENTS_NAME[3:end], $world) + )) + else + method_error = :(throw( + MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) + )) + end - sig = Tuple{args...} + interp = ReactantInterpreter(; world) min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) @@ -364,7 +565,21 @@ function call_with_reactant_generator( ir, rt = CC.typeinf_ircode(interp, mi, nothing) end - ir, any_changed = rewrite_insts!(ir, interp) + if guaranteed_error + if rt !== Union{} + safe_print("Inconsistent guaranteed error IR", ir) + end + rt = Union{} + end + + if DEBUG_INTERP[] + safe_print("ir", ir) + end + + if !is_reactant_method(mi::Core.MethodInstance) || guaranteed_error + ir, any_changed = rewrite_insts!(ir, interp, guaranteed_error) + end + src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) src.slotnames = fill(:none, length(ir.argtypes) + 1) src.slotflags = fill(zero(UInt8), length(ir.argtypes)) @@ -408,6 +623,10 @@ function call_with_reactant_generator( fn_args = Any[] n_method_args = method.nargs n_actual_args = length(redub_arguments) + if guaranteed_error + offset += 1 + n_actual_args -= 1 + end tys = [] @@ -424,7 +643,7 @@ function call_with_reactant_generator( push!(overdubbed_codelocs, code_info.codelocs[1]) offset += 1 push!(fn_args, Core.SSAValue(length(overdubbed_code))) - push!(tys, redub_arguments[i]) + push!(tys, redub_arguments[i + (guaranteed_error ? 1 : 0)]) if DEBUG_INTERP[] push!( @@ -457,7 +676,12 @@ function call_with_reactant_generator( push!(overdubbed_code, trailing_arguments) push!(overdubbed_codelocs, code_info.codelocs[1]) push!(fn_args, Core.SSAValue(length(overdubbed_code))) - push!(tys, Tuple{redub_arguments[n_method_args:n_actual_args]...}) + push!( + tys, + Tuple{ + redub_arguments[(n_method_args:n_actual_args) .+ (guaranteed_error ? 1 : 0)]..., + }, + ) if DEBUG_INTERP[] push!( @@ -487,23 +711,39 @@ function call_with_reactant_generator( # inner code during compilation without special handling (i.e. call_in_world_total). # Opaque closures also require taking the function argument. We can work around the latter # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure - oc = if false && Base.issingletontype(args[1]) + + dict, make_oc = if Base.issingletontype(fn) + Base.Ref{Core.OpaqueClosure}(), make_oc_ref + else + Dict{fn,Core.OpaqueClosure}(), make_oc_dict + end + + push!(oc_capture_vec, dict) + + oc = if false && Base.issingletontype(fn) res = Core._call_in_world_total( - world, make_oc, octup, rt, src, ocnargs, ocva, args[1].instance + world, make_oc, dict, octup, rt, src, ocnargs, ocva, fn.instance )::Core.OpaqueClosure else farg = fn_args[1] - push!(overdubbed_code, Expr(:call, make_oc, octup, rt, src, ocnargs, ocva, farg)) + rep = Expr(:call, make_oc, dict, octup, rt, src, ocnargs, ocva, farg) + push!(overdubbed_code, rep) push!(overdubbed_codelocs, code_info.codelocs[1]) Core.SSAValue(length(overdubbed_code)) end - push!(overdubbed_code, Expr(:(call), oc, fn_args[2:end]...)) - + push!(overdubbed_code, Expr(:call, oc, fn_args[2:end]...)) push!(overdubbed_codelocs, code_info.codelocs[1]) - push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code)))) + ocres = Core.SSAValue(length(overdubbed_code)) + + if DEBUG_INTERP[] + push!(overdubbed_code, Expr(:call, safe_print, "ocres", ocres)) + push!(overdubbed_codelocs, code_info.codelocs[1]) + end + + push!(overdubbed_code, Core.ReturnNode(ocres)) push!(overdubbed_codelocs, code_info.codelocs[1]) #=== set `code_info`/`reflection` fields accordingly ===# diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl new file mode 100644 index 0000000000..091faf7895 --- /dev/null +++ b/src/xla/Buffer.jl @@ -0,0 +1,75 @@ +abstract type AbstractBuffer end + +function synced_buffer end +function buffer_on_cpu end +function to_host end +function unsafe_buffer_pointer end +function copy_buffer_to_device end +function sharding end + +Base.convert(::Type{Array}, buffer::AbstractBuffer) = convert(Array{eltype(buffer)}, buffer) + +function Base.convert(::Type{<:Array{T}}, buffer::AbstractBuffer) where {T} + arr = zeros(T, reverse(size(buffer))...) + XLA.to_host(buffer, arr) + return arr +end + +@inline function client( + buffers::Union{Array{<:AbstractBuffer},NTuple{<:Any,AbstractBuffer}} +) + all_clients = map(client, buffers) + @assert allequal(all_clients) "All buffers must have the same client" + return first(all_clients) +end + +@inline function synced_buffer( + buffers::Union{AbstractArray{<:AbstractBuffer},NTuple{<:Any,<:AbstractBuffer}} +) + return map(synced_buffer, buffers) +end + +function Base.show(io::IO, mime::MIME"text/plain", buffer::B) where {B<:AbstractBuffer} + print(io, "$(B) storing ") + show(io, mime, convert(Array, buffer)) + return nothing +end + +# Async Buffers +abstract type AbstractAsyncBuffer <: AbstractBuffer end + +Base.isempty(buffer::AbstractAsyncBuffer) = buffer.buffer.buffer == C_NULL + +function Base.convert(T::Type{Array}, buffer::AbstractAsyncBuffer) + XLA.await(buffer) + return convert(T, buffer.buffer) +end + +function Base.convert(T::Type{<:Array{T1}}, buffer::AbstractAsyncBuffer) where {T1} + XLA.await(buffer) + return convert(T, buffer.buffer) +end + +for op in (:(Base.ndims), :(Base.size), :(Base.eltype), :device, :client, :sharding) + @eval $op(buffer::AbstractAsyncBuffer) = $op(buffer.buffer) +end + +function XLA.synced_buffer(buffer::AbstractAsyncBuffer) + XLA.await(buffer) + return buffer.buffer +end + +function XLA.await(buffer::AbstractAsyncBuffer) + buffer.future === nothing && return nothing + future = buffer.future + buffer.future = nothing + XLA.await(future) + return nothing +end + +function XLA.is_ready(buffer::AbstractAsyncBuffer) + buffer.future === nothing && return true + return XLA.is_ready(buffer.future) +end + +XLA.buffer_on_cpu(buffer::AbstractAsyncBuffer) = XLA.buffer_on_cpu(buffer.buffer) diff --git a/src/xla/Client.jl b/src/xla/Client.jl new file mode 100644 index 0000000000..ccf715c1ba --- /dev/null +++ b/src/xla/Client.jl @@ -0,0 +1,16 @@ +abstract type AbstractClient end + +Base.:(==)(a::AbstractClient, b::AbstractClient) = a.client == b.client + +function client end +function free_client end +function num_devices end +function num_addressable_devices end +function process_index end +function devices end +function addressable_devices end +function get_device end +function get_addressable_device end +function platform_name end + +default_device(client::AbstractClient) = first(addressable_devices(client)) diff --git a/src/xla/Device.jl b/src/xla/Device.jl new file mode 100644 index 0000000000..f4b27eaa30 --- /dev/null +++ b/src/xla/Device.jl @@ -0,0 +1,26 @@ +abstract type AbstractDevice end + +function Base.show(io::IO, ::MIME"text/plain", device::D) where {D<:AbstractDevice} + print(io, "$(parentmodule(D)).Device($(device.device), \"$(string(device))\")") + return nothing +end + +function device end +function get_local_device_id end +function device_kind end +function default_memory end +function memories end +function is_addressable end + +""" + device_ordinal(device::Device) + +Given the device, return the corresponding global device ordinal in the client. +""" +function device_ordinal end + +function Base.string(device::AbstractDevice) + client = XLA.client(device) + pname = XLA.platform_name(client) + return "$(uppercase(pname)):$(device_ordinal(device)) $(device_kind(device))" +end diff --git a/src/xla/Distributed.jl b/src/xla/Distributed.jl new file mode 100644 index 0000000000..791d3cdc11 --- /dev/null +++ b/src/xla/Distributed.jl @@ -0,0 +1,200 @@ +# Client +mutable struct DistributedRuntimeClient + client::Ptr{Cvoid} + + function DistributedRuntimeClient(client::Ptr{Cvoid}) + @assert client != C_NULL + return finalizer(free_distributed_runtime_client, new(client)) + end +end + +function DistributedRuntimeClient( + coordinator_bind_address::String, + process_id::Integer; + rpc_timeout_in_seconds::Integer=120, + shutdown_timeout_in_minutes::Integer=5, + heartbeat_interval_in_seconds::Integer=10, + max_missing_heartbeats::Integer=10, + use_compression::Bool=true, +) + GC.@preserve coordinator_bind_address begin + client = @ccall MLIR.API.mlir_c.GetDistributedRuntimeClient( + coordinator_bind_address::Cstring, + Int32(process_id)::Int32, + Int32(rpc_timeout_in_seconds)::Int32, + Int32(shutdown_timeout_in_minutes)::Int32, + Int32(heartbeat_interval_in_seconds)::Int32, + Cint(max_missing_heartbeats)::Cint, + use_compression::Bool, + )::Ptr{Cvoid} + end + return DistributedRuntimeClient(client) +end + +function free_distributed_runtime_client(client::DistributedRuntimeClient) + GC.@preserve client begin + @ccall MLIR.API.mlir_c.free_distributed_runtime_client( + client.client::Ptr{Cvoid} + )::Cvoid + end +end + +function connect(client::DistributedRuntimeClient) + GC.@preserve client begin + @ccall MLIR.API.mlir_c.distributed_runtime_client_connect( + client.client::Ptr{Cvoid} + )::Cvoid + end +end + +function shutdown(client::DistributedRuntimeClient) + GC.@preserve client begin + @ccall MLIR.API.mlir_c.distributed_runtime_client_shutdown( + client.client::Ptr{Cvoid} + )::Cvoid + end +end + +# Service +mutable struct DistributedRuntimeService + service::Ptr{Cvoid} + + function DistributedRuntimeService(service::Ptr{Cvoid}) + @assert service != C_NULL + return finalizer(free_distributed_runtime_service, new(service)) + end +end + +function DistributedRuntimeService( + coordinator_bind_address::String, + num_nodes::Integer; + heartbeat_interval_in_seconds::Integer=10, + max_missing_heartbeats::Integer=10, + cluster_register_timeout_in_minutes::Integer=60, + shutdown_timeout_in_minutes::Integer=5, +) + GC.@preserve coordinator_bind_address begin + service = @ccall MLIR.API.mlir_c.GetDistributedRuntimeService( + coordinator_bind_address::Cstring, + Cint(num_nodes)::Cint, + Int32(heartbeat_interval_in_seconds)::Int32, + Cint(max_missing_heartbeats)::Cint, + Int32(cluster_register_timeout_in_minutes)::Int32, + Int32(shutdown_timeout_in_minutes)::Int32, + )::Ptr{Cvoid} + end + return DistributedRuntimeService(service) +end + +function free_distributed_runtime_service(service::DistributedRuntimeService) + GC.@preserve service begin + @ccall MLIR.API.mlir_c.free_distributed_runtime_service( + service.service::Ptr{Cvoid} + )::Cvoid + end +end + +function shutdown(service::DistributedRuntimeService) + GC.@preserve service begin + @ccall MLIR.API.mlir_c.distributed_runtime_service_shutdown( + service.service::Ptr{Cvoid} + )::Cvoid + end +end + +# Global State +@kwdef mutable struct State + process_id::Int = 0 + num_processes::Int = 1 + local_gpu_device_ids::Union{Nothing,Vector{Int}} = nothing + service::Union{Nothing,DistributedRuntimeService} = nothing + client::Union{Nothing,DistributedRuntimeClient} = nothing + coordinator_address::Union{Nothing,String} = nothing + coordinator_bind_address::Union{Nothing,String} = nothing +end + +function shutdown(state::State) + if state.service !== nothing + shutdown(state.service) + state.service = nothing + end + if state.client !== nothing + shutdown(state.client) + state.client = nothing + end +end + +function update!( + state::State; + coordinator_address::String, + num_processes::Int, + process_id::Int, + local_gpu_device_ids::Vector{Int}, + coordinator_bind_address::Union{Nothing,String}=nothing, + cluster_register_timeout_in_minutes::Integer=60, + rpc_timeout_in_seconds::Integer=120, + shutdown_timeout_in_minutes::Integer=5, + heartbeat_interval_in_seconds::Integer=10, + max_missing_heartbeats::Integer=10, + use_compression::Bool=true, +) + @assert 0 ≤ process_id < num_processes + + state.coordinator_address = coordinator_address + state.local_gpu_device_ids = local_gpu_device_ids + state.process_id = process_id + state.num_processes = num_processes + + if coordinator_bind_address === nothing + if haskey(ENV, "REACTANT_COORDINATOR_BIND_ADDRESS") + coordinator_bind_address = ENV["REACTANT_COORDINATOR_BIND_ADDRESS"] + else + coordinator_bind_address = + "[::]:" * rsplit(coordinator_address, ":"; limit=2)[2] + end + end + state.coordinator_bind_address = coordinator_bind_address + + if process_id == 0 + @assert state.service === nothing "`Reactant.Distributed.initialize` should only \ + be called once." + @debug "[PID $(process_id)] Starting Reactant distributed service on \ + $(coordinator_bind_address)" + state.service = DistributedRuntimeService( + coordinator_bind_address, + num_processes; + heartbeat_interval_in_seconds, + max_missing_heartbeats, + cluster_register_timeout_in_minutes, + shutdown_timeout_in_minutes, + ) + end + + # Check for proxy variables that might cause a hang + proxy_vars = filter(Base.Fix1(occursin, "_proxy") ∘ lowercase, keys(ENV)) + if length(proxy_vars) > 0 + vars = join(proxy_vars, ", ") + @warn "Reactant detected proxy variable(s) in the environment as distributed \ + setup: $(vars). On some systems, this may cause a hang of `XLA.update!` and \ + you may need to unset the proxy variables." + end + + @assert state.client === nothing "`Reactant.Distributed.initialize` should only be \ + called once." + state.client = DistributedRuntimeClient( + coordinator_address, + process_id; + rpc_timeout_in_seconds, + shutdown_timeout_in_minutes, + heartbeat_interval_in_seconds, + max_missing_heartbeats, + use_compression, + ) + @debug "[PID $(process_id)] Connecting to Reactant distributed service on \ + $(coordinator_address)" + connect(state.client) + @debug "[PID $(process_id)] Connected to Reactant distributed service on \ + $(coordinator_address)" + + return nothing +end diff --git a/src/xla/Future.jl b/src/xla/Future.jl new file mode 100644 index 0000000000..fe57f26ffc --- /dev/null +++ b/src/xla/Future.jl @@ -0,0 +1,5 @@ +abstract type AbstractFuture end + +function free_future end +function is_ready end +function await end diff --git a/src/xla/HloModule.jl b/src/xla/HloModule.jl new file mode 100644 index 0000000000..71e68e53f3 --- /dev/null +++ b/src/xla/HloModule.jl @@ -0,0 +1,20 @@ +mutable struct HloModule + ptr::Ptr{Cvoid} + + function HloModule(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_hlo_module, new(ptr)) + end +end + +function free_hlo_module(hlo_module) + @ccall MLIR.API.mlir_c.FreeHloModule(hlo_module.ptr::Ptr{Cvoid})::Cvoid +end + +function Base.show(io::IO, hlo_module::HloModule) + GC.@preserve hlo_module begin + str = @ccall MLIR.API.mlir_c.HloModuleToString(hlo_module.ptr::Ptr{Cvoid})::Cstring + end + print(io, unsafe_string_and_free(str)) + return nothing +end diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl new file mode 100644 index 0000000000..139f14ca6f --- /dev/null +++ b/src/xla/IFRT/Array.jl @@ -0,0 +1,212 @@ +mutable struct Array <: XLA.AbstractBuffer + buffer::Ptr{Cvoid} + + function Array(buffer::Ptr{Cvoid}) + # return finalizer(free_ifrt_array, new(buffer)) + return new(buffer) + end +end + +function Array( + client::Client, + array::Base.Array{T,N}, + device::Device=XLA.default_device(client), + memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))), +) where {T,N} + sizear = collect(Int64, reverse(size(array))) + buffer = GC.@preserve array sizear begin + @ccall MLIR.API.mlir_c.ifrt_client_make_single_shard_array_from_host_buffer( + client.client::Ptr{Cvoid}, + array::Ptr{T}, + XLA.primitive_type(T)::UInt64, + N::Csize_t, + sizear::Ptr{Int64}, + 0::Cint, # kAlwaysCopy + device.device::Ptr{Cvoid}, + string(memory_kind)::Cstring, + )::Ptr{Cvoid} + end + return Array(buffer) +end + +function Array(client::Client, array::Base.Array{T,N}, sharding::Sharding) where {T,N} + sizear = collect(Int64, reverse(size(array))) + + if is_single_device_sharding(sharding) || is_fully_replicated(sharding) + buffer = GC.@preserve array sizear begin + @ccall MLIR.API.mlir_c.ifrt_client_make_array_from_host_buffer( + client.client::Ptr{Cvoid}, + array::Ptr{T}, + XLA.primitive_type(T)::Cint, + N::Csize_t, + sizear::Ptr{Int64}, + sharding.ptr::Ptr{Cvoid}, + 0::Cint, # kAlwaysCopy + )::Ptr{Cvoid} + end + return Array(buffer) + end + + all_devices = XLA.devices(sharding) + array_slices, _ = XLA.sharding_to_concrete_array_indices( + convert(XLA.HloSharding, sharding), + size(array), + collect(Int64, 0:(length(all_devices) - 1)), + ) + array_shape = collect(Int64, reverse(size(array))) + arrays_list = [ + Array(client, array[slice...], device).buffer for + (device, slice) in zip(all_devices, array_slices) if XLA.is_addressable(device) + ] + + buffer = GC.@preserve client arrays_list array_shape sharding begin + @ccall MLIR.API.mlir_c.ifrt_client_assemble_array_from_single_shards( + client.client::Ptr{Cvoid}, + Int32(length(array_shape))::Int32, + array_shape::Ptr{Int64}, + sharding.ptr::Ptr{Cvoid}, + Int32(length(arrays_list))::Int32, + arrays_list::Ptr{Ptr{Cvoid}}, + 2::Cint, # kDonateInput + )::Ptr{Cvoid} + end + + return Array(buffer) +end + +function Array(client::Client, array::Base.Array{T,N}, sharding) where {T,N} + @assert sharding isa Reactant.Sharding.AbstractSharding + if !(sharding isa Reactant.Sharding.HloSharding) + sharding = convert(Reactant.Sharding.HloSharding, sharding) + end + + (; hlo_sharding, mesh) = sharding + devices = XLA.get_device.((client,), mesh.device_ids) + ifrt_sharding = Sharding([devices...], hlo_sharding) + + return Array(client, array, ifrt_sharding) +end + +@inline function free_ifrt_array(buffer::Array) + sbuffer = buffer.buffer + if sbuffer != C_NULL + @ccall MLIR.API.mlir_c.ifrt_free_array(sbuffer::Ptr{Cvoid})::Cvoid + end +end + +function Base.ndims(buffer::Array) + GC.@preserve buffer begin + return @ccall MLIR.API.mlir_c.ifrt_array_ndims(buffer.buffer::Ptr{Cvoid})::Int64 + end +end + +function Base.size(buffer::Array) + GC.@preserve buffer begin + sz = @ccall MLIR.API.mlir_c.ifrt_array_shape(buffer.buffer::Ptr{Cvoid})::Ptr{Int64} + end + return Tuple(unsafe_wrap(Base.Array, sz, ndims(buffer))) +end + +function Base.eltype(buffer::Array) + GC.@preserve buffer begin + return XLA.julia_type( + @ccall MLIR.API.mlir_c.ifrt_array_eltype(buffer.buffer::Ptr{Cvoid})::Cint + ) + end +end + +function XLA.device(::Array) + return error("IFRT.Array can be sharded/replicated across multiple devices. Hence, \ + `XLA.device` is not defined.") +end + +function XLA.client(buffer::Array) + GC.@preserve buffer begin + return Client( + @ccall MLIR.API.mlir_c.ifrt_array_to_client( + buffer.buffer::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +XLA.synced_buffer(buffer::Array) = buffer + +function XLA.buffer_on_cpu(::Array) + return error("IFRT.Array does not support `XLA.buffer_on_cpu`") +end + +function XLA.to_host(buffer::Array, data) + sharding = XLA.sharding(buffer) + all_devices = XLA.devices(sharding) + + if length(all_devices) == 1 + GC.@preserve buffer data begin + @ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer( + buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid} + )::Cvoid + end + return nothing + end + + if any(!is_addressable, all_devices) + @warn "Not all devices are addressable. Currently we only fill in the data for \ + addressable devices. Remaining slices of data in `data` are left \ + untouched." + end + + # While some client implementations might support directly copying to host, but we + # avoid the complexity of supporting that for now. + single_device_arrays = disassemble_into_single_device_arrays(buffer, true) + + array_slices, _ = XLA.sharding_to_concrete_array_indices( + convert(XLA.HloSharding, sharding), + size(data), + collect(Int64, 0:(length(all_devices) - 1)), + ) + array_slices = [ + slice for + (device, slice) in zip(all_devices, array_slices) if XLA.is_addressable(device) + ] + + @assert length(array_slices) == length(single_device_arrays) + + for (slice, arr) in zip(array_slices, single_device_arrays) + data_slice = data[slice...] + XLA.to_host(arr, data_slice) + data[slice...] .= data_slice + end + return nothing +end + +function disassemble_into_single_device_arrays(array::Array, only_addressable_devices::Bool) + c_single_device_shard_semantics = Int32(!only_addressable_devices) + narrays = Ref{Int32}(0) + arrays = GC.@preserve array begin + @ccall MLIR.API.mlir_c.ifrt_array_disassemble_into_single_device_arrays( + array.buffer::Ptr{Cvoid}, + Int32(0)::Int32, + c_single_device_shard_semantics::Int32, + narrays::Ptr{Int32}, + )::Ptr{Ptr{Cvoid}} + end + return [Array(unsafe_load(arrays, i)) for i in 1:narrays[]] +end + +function XLA.unsafe_buffer_pointer(::Array) + return error("IFRT.Array does not support `XLA.unsafe_buffer_pointer`") +end + +function XLA.copy_buffer_to_device(::Array, ::Device) + return error("IFRT.Array does not support `XLA.copy_buffer_to_device`") +end + +function XLA.sharding(buffer::Array) + GC.@preserve buffer begin + return Sharding( + @ccall MLIR.API.mlir_c.ifrt_array_to_sharding( + buffer.buffer::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end diff --git a/src/xla/IFRT/AsyncArray.jl b/src/xla/IFRT/AsyncArray.jl new file mode 100644 index 0000000000..b049289b13 --- /dev/null +++ b/src/xla/IFRT/AsyncArray.jl @@ -0,0 +1,8 @@ +mutable struct AsyncArray <: XLA.AbstractAsyncBuffer + buffer::Array + future::Union{Future,Nothing} +end + +const AsyncEmptyArray = AsyncArray(Array(C_NULL), nothing) + +AsyncArray(args...; kwargs...) = AsyncArray(Array(args...; kwargs...), nothing) diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl new file mode 100644 index 0000000000..356a151401 --- /dev/null +++ b/src/xla/IFRT/Client.jl @@ -0,0 +1,192 @@ +mutable struct Client <: XLA.AbstractClient + client::Ptr{Cvoid} + + function Client(client::Ptr{Cvoid}; skip_check::Bool=false) + skip_check || (@assert client != C_NULL) + return new(client) + end +end + +function XLA.free_client(client::Client) + GC.@preserve client begin + @ccall MLIR.API.mlir_c.ifrt_FreeClient(client.client::Ptr{Cvoid})::Cvoid + end +end + +function XLA.num_devices(client::Client) + GC.@preserve client begin + return @ccall MLIR.API.mlir_c.ifrt_client_device_count( + client.client::Ptr{Cvoid} + )::Cint + end +end + +function XLA.num_addressable_devices(client::Client) + GC.@preserve client begin + return @ccall MLIR.API.mlir_c.ifrt_client_addressable_device_count( + client.client::Ptr{Cvoid} + )::Cint + end +end + +function XLA.process_index(client::Client) + GC.@preserve client begin + return @ccall MLIR.API.mlir_c.ifrt_ClientProcessIndex( + client.client::Ptr{Cvoid} + )::Cint + end +end + +function XLA.get_device(client::Client, idx) + GC.@preserve client begin + return Device( + @ccall MLIR.API.mlir_c.ifrt_client_lookup_device( + client.client::Ptr{Cvoid}, idx::Cint + )::Ptr{Cvoid} + ) + end +end + +function XLA.get_addressable_device(client::Client, idx) + GC.@preserve client begin + return Device( + @ccall MLIR.API.mlir_c.ifrt_client_lookup_addressable_device( + client.client::Ptr{Cvoid}, idx::Cint + )::Ptr{Cvoid} + ) + end +end + +function XLA.platform_name(client::Client) + GC.@preserve client begin + str = @ccall MLIR.API.mlir_c.ifrt_ClientGetPlatformName( + client.client::Ptr{Cvoid} + )::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function XLA.devices(client::Client) + ndevices = Int(XLA.num_devices(client)) + devices = Ref{NTuple{ndevices,Ptr{Cvoid}}}() + GC.@preserve client devices begin + @ccall MLIR.API.mlir_c.ifrt_client_devices( + client.client::Ptr{Cvoid}, devices::Ptr{Ptr{Cvoid}} + )::Cvoid + end + return [Device(device) for device in devices[]] +end + +function XLA.addressable_devices(client::Client) + naddressable_devices = Int(XLA.num_addressable_devices(client)) + addressable_devices = Ref{NTuple{naddressable_devices,Ptr{Cvoid}}}() + GC.@preserve client addressable_devices begin + @ccall MLIR.API.mlir_c.ifrt_client_addressable_devices( + client.client::Ptr{Cvoid}, addressable_devices::Ptr{Ptr{Cvoid}} + )::Cvoid + end + return [Device(device) for device in addressable_devices[]] +end + +# Different Backends +const cpu_client_count = Ref(0) +const gpu_client_count = Ref(0) +const tpu_client_count = Ref(0) + +for (backend, counter) in ( + (:CPUClient, :cpu_client_count), + (:GPUClient, :gpu_client_count), + (:TPUClient, :tpu_client_count), +) + main_fn = Symbol(:MakeIFRTPJRT, backend) + @eval function $(backend)(args...; checkcount::Bool=true, kwargs...) + if checkcount + @assert $(counter)[] == 0 + end + client, refstr = $(main_fn)(args...; kwargs...) + client == C_NULL && throw(AssertionError(unsafe_string(refstr[]))) + XLA.LLVMclopts("-nvptx-fma-level=1") + if checkcount + # Only increment the counter if we successfully created a client + $(counter)[] += 1 + end + return Client(client) + end +end + +function MakeIFRTPJRTCPUClient(; + node_id::Integer=0, + num_nodes::Integer=1, + asynchronous::Bool=true, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + refstr = Ref{Cstring}() + distributed_runtime_client = + distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client + + GC.@preserve refstr distributed_runtime_client begin + client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_cpu_client( + asynchronous::UInt8, + node_id::Cint, + num_nodes::Cint, + distributed_runtime_client::Ptr{Cvoid}, + refstr::Ptr{Cstring}, + )::Ptr{Cvoid} + end + + return client, refstr +end + +function MakeIFRTPJRTGPUClient(; + node_id::Integer=0, + num_nodes::Integer=1, + platform::String="gpu", + allowed_devices::Union{Nothing,Vector{Int}}=nothing, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + refstr = Ref{Cstring}() + + num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices) + allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices + distributed_runtime_client = + distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client + + GC.@preserve refstr allowed_devices distributed_runtime_client begin + client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_gpu_client( + node_id::Cint, + num_nodes::Cint, + allowed_devices::Ptr{Cvoid}, + num_allowed_devices::Cint, + XLA.XLA_REACTANT_GPU_MEM_FRACTION[]::Cdouble, + XLA.XLA_REACTANT_GPU_PREALLOCATE[]::Bool, + platform::Cstring, + refstr::Ptr{Cstring}, + distributed_runtime_client::Ptr{Cvoid}, + )::Ptr{Cvoid} + end + + return client, refstr +end + +function MakeIFRTPJRTTPUClient(; + tpu_path::String, + node_id::Integer=0, + num_nodes::Integer=1, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + refstr = Ref{Cstring}() + distributed_runtime_client = + distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client + + GC.@preserve refstr distributed_runtime_client begin + client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_tpu_client( + tpu_path::Cstring, + refstr::Ptr{Cstring}, + node_id::Cint, + num_nodes::Cint, + distributed_runtime_client::Ptr{Cvoid}, + )::Ptr{Cvoid} + end + + return client, refstr +end diff --git a/src/xla/IFRT/Device.jl b/src/xla/IFRT/Device.jl new file mode 100644 index 0000000000..e550fd3013 --- /dev/null +++ b/src/xla/IFRT/Device.jl @@ -0,0 +1,74 @@ +struct Device <: XLA.AbstractDevice + device::Ptr{Cvoid} +end + +function XLA.client(device::Device) + GC.@preserve device begin + return Client( + @ccall MLIR.API.mlir_c.ifrt_DeviceToClient( + device.device::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function XLA.device_ordinal(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.ifrt_DeviceGetGlobalDeviceId( + device.device::Ptr{Cvoid} + )::Int64 + end +end + +function XLA.device_kind(device::Device) + GC.@preserve device begin + str = @ccall MLIR.API.mlir_c.ifrt_DeviceGetKind(device.device::Ptr{Cvoid})::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function XLA.get_local_device_id(::Device) + return error("Not implemented for ifrt devices") +end + +function XLA.default_memory(device::Device) + GC.@preserve device begin + return Memory( + @ccall MLIR.API.mlir_c.ifrt_DeviceGetDefaultMemory( + device.device::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function XLA.memories(device::Device) + memories_size = Ref{Int32}(0) + GC.@preserve device memories_size begin + ptr = @ccall MLIR.API.mlir_c.ifrt_DeviceGetMemories( + device.device::Ptr{Cvoid}, memories_size::Ptr{Int32} + )::Ptr{Ptr{Cvoid}} + end + return [Memory(unsafe_load(ptr, i)) for i in 1:memories_size[]] +end + +# TODO: https://github.com/openxla/xla/blob/ad0814d221883609f784e57dd26914b17f92fbbc/xla/python/ifrt/sharding.cc#L60 +function XLA.default_memory(device_list::AbstractVector{Device}) + default_memories = XLA.default_memory.(device_list) + default_memory_kinds = convert.(MemoryKind, default_memories) + @assert allequal(default_memory_kinds) "All devices must have the same default memory" + return first(default_memories) +end + +function XLA.client(device_list::AbstractVector{Device}) + clients = XLA.client.(device_list) + @assert allequal(clients) "All devices must have the same client" + return first(clients) +end + +function XLA.is_addressable(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.ifrt_DeviceIsAddressable( + device.device::Ptr{Cvoid} + )::Bool + end +end diff --git a/src/xla/IFRT/Future.jl b/src/xla/IFRT/Future.jl new file mode 100644 index 0000000000..bbb8128dbe --- /dev/null +++ b/src/xla/IFRT/Future.jl @@ -0,0 +1,27 @@ +mutable struct Future <: XLA.AbstractFuture + future::Ptr{Cvoid} + + function Future(future::Ptr{Cvoid}) + @assert future != C_NULL + return finalizer(free_future, new(future)) + end +end + +@inline function free_future(future::Future) + @ccall MLIR.API.mlir_c.ifrt_free_future(future.future::Ptr{Cvoid})::Cvoid +end + +function XLA.is_ready(future::Future) + GC.@preserve future begin + return (@ccall MLIR.API.mlir_c.ifrt_future_is_ready( + future.future::Ptr{Cvoid} + )::UInt8) != 0 + end +end + +@inline function XLA.await(future::Future)::Nothing + GC.@preserve future begin + @ccall MLIR.API.mlir_c.ifrt_future_await(future.future::Ptr{Cvoid})::Cvoid + end + return nothing +end diff --git a/src/xla/IFRT/IFRT.jl b/src/xla/IFRT/IFRT.jl new file mode 100644 index 0000000000..cdca7a520c --- /dev/null +++ b/src/xla/IFRT/IFRT.jl @@ -0,0 +1,15 @@ +module IFRT + +using ..Reactant: Reactant, MLIR +using ..XLA: XLA + +include("Client.jl") +include("Device.jl") +include("Memory.jl") +include("Future.jl") +include("Sharding.jl") +include("Array.jl") +include("AsyncArray.jl") +include("LoadedExecutable.jl") + +end diff --git a/src/xla/IFRT/LoadedExecutable.jl b/src/xla/IFRT/LoadedExecutable.jl new file mode 100644 index 0000000000..d126119d70 --- /dev/null +++ b/src/xla/IFRT/LoadedExecutable.jl @@ -0,0 +1,133 @@ +mutable struct LoadedExecutable <: XLA.AbstractLoadedExecutable + exec::Ptr{Cvoid} + num_outputs::Int64 + num_parameters::Int64 + is_sharded::Bool + num_replicas::Int64 + num_partitions::Int64 + + function LoadedExecutable(exec::Ptr{Cvoid}, args...) + @assert exec != C_NULL + return finalizer(free_exec, new(exec, args...)) + end +end + +function free_exec(exec::LoadedExecutable) + GC.@preserve exec begin + @ccall MLIR.API.mlir_c.ifrt_loaded_executable_dtor(exec.exec::Ptr{Cvoid})::Cvoid + end +end + +function XLA.client(exec::LoadedExecutable) + GC.@preserve exec begin + return Client( + @ccall MLIR.API.mlir_c.ifrt_loaded_executable_client( + exec.exec::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +XLA.num_partitions(exec::LoadedExecutable) = exec.num_partitions +XLA.num_replicas(exec::LoadedExecutable) = exec.num_replicas +XLA.num_devices(exec::LoadedExecutable) = XLA.num_replicas(exec) * XLA.num_partitions(exec) + +for (jlop, xlaop, field) in ( + (:get_output_shardings, :ifrt_loaded_executable_get_output_shardings, :num_outputs), + ( + :get_parameter_shardings, + :ifrt_loaded_executable_get_parameter_shardings, + :num_parameters, + ), +) + @eval function XLA.$(jlop)(exec::LoadedExecutable) + exec.is_sharded || return XLA.OpSharding[] + + op_shardings = Ref{NTuple{exec.$(field),Ptr{Cvoid}}}() + + GC.@preserve exec op_shardings begin + @ccall MLIR.API.mlir_c.$(xlaop)( + exec.exec::Ptr{Cvoid}, op_shardings::Ptr{Ptr{Cvoid}}, exec.$(field)::Cint + )::Cvoid + end + + return [XLA.OpSharding(op_sharding) for op_sharding in op_shardings[]] + end +end + +function XLA.get_hlo_modules(exec::LoadedExecutable) + # If we had compiled with MPMD then we would need all the partitions to get hlo_modules + # but if we used SPMD we get only 1 module. To be safe we allocate for all the modules + # and use the ones assigned to by XLA + hlo_modules = Ref{NTuple{Int64(XLA.num_partitions(exec)),Ptr{Cvoid}}}() + nmodules = Ref{Int32}(0) + GC.@preserve exec hlo_modules begin + @ccall MLIR.API.mlir_c.ifrt_loaded_executable_get_hlo_modules( + exec.exec::Ptr{Cvoid}, hlo_modules::Ptr{Ptr{Cvoid}}, nmodules::Ptr{Int32} + )::Cvoid + end + return map(XLA.HloModule, hlo_modules[][1:Int(nmodules[])]) +end + +function XLA.compile( + client::Client, + device::Union{Device,Nothing}, + mod::MLIR.IR.Module; + is_sharded::Bool=false, + global_device_ids::Vector{Int64}=Int64[], + num_outputs::Int64, + num_parameters::Int64, + num_replicas::Int64, + num_partitions::Int64, +) + device_id = is_sharded ? Int64(-1) : Int64(XLA.device_ordinal(device)) + GC.@preserve client mod begin + exec = @ccall MLIR.API.mlir_c.ifrt_compile( + client.client::Ptr{Cvoid}, + mod.module_::MLIR.API.MlirModule, + device_id::Clong, + is_sharded::Bool, + global_device_ids::Ptr{Clong}, + length(global_device_ids)::Clong, + XLA.CUDA_DATA_DIR[]::Cstring, + )::Ptr{Cvoid} + end + return LoadedExecutable( + exec, num_outputs, num_parameters, is_sharded, num_replicas, num_partitions + ) +end + +@inline function XLA.execute( + exec::LoadedExecutable, + inputs::NTuple{N,Ptr{Cvoid}}, + donated_args::NTuple{M,UInt8}, + ::Val{n_outs}, +) where {N,M,n_outs} + outputs = Ref{NTuple{n_outs,Ptr{Cvoid}}}() + future_res = Ref{Ptr{Cvoid}}() + futures = Ref{UInt8}(0) + + inputs = Base.RefValue(inputs) + donated_args = Base.RefValue(donated_args) + GC.@preserve exec outputs future_res futures begin + @ccall MLIR.API.mlir_c.ifrt_loaded_executable_execute( + exec.exec::Ptr{Cvoid}, + N::Cint, + inputs::Ptr{Ptr{Cvoid}}, + donated_args::Ptr{UInt8}, + n_outs::Cint, + Base.unsafe_convert(Ptr{Ptr{Cvoid}}, outputs)::Ptr{Ptr{Cvoid}}, + futures::Ptr{UInt8}, + future_res::Ptr{Ptr{Cvoid}}, + )::Cvoid + end + + outputs = outputs[] + future = futures[] != 0 + future && (future_res[] = future_res[]) + + return ntuple(n_outs) do i + Base.@_inline_meta + AsyncArray(Array(outputs[i]), future ? Future(future_res[]) : nothing) + end +end diff --git a/src/xla/IFRT/Memory.jl b/src/xla/IFRT/Memory.jl new file mode 100644 index 0000000000..b2f9575667 --- /dev/null +++ b/src/xla/IFRT/Memory.jl @@ -0,0 +1,56 @@ +mutable struct Memory <: XLA.AbstractMemory + ptr::Ptr{Cvoid} +end + +function Base.show(io::IO, memory::Memory) + GC.@preserve memory begin + str = @ccall MLIR.API.mlir_c.ifrt_MemoryToString(memory.ptr::Ptr{Cvoid})::Cstring + end + print(io, "XLA.IFRT.Memory(\"", XLA.unsafe_string_and_free(str), "\")") + return nothing +end + +mutable struct MemoryKind <: XLA.AbstractMemoryKind + ptr::Ptr{Cvoid} +end + +function MemoryKind(str::AbstractString) + str = string(str) + GC.@preserve str begin + return MemoryKind( + @ccall MLIR.API.mlir_c.ifrt_memory_kind_from_string(str::Cstring)::Ptr{Cvoid} + ) + end +end + +function Base.convert(::Type{MemoryKind}, memory::Memory) + GC.@preserve memory begin + return MemoryKind( + @ccall MLIR.API.mlir_c.ifrt_MemoryGetMemoryKind( + memory.ptr::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function Base.:(==)(a::MemoryKind, b::MemoryKind) + GC.@preserve a b begin + return @ccall MLIR.API.mlir_c.ifrt_MemoryKindsAreEqual( + a.ptr::Ptr{Cvoid}, b.ptr::Ptr{Cvoid} + )::Bool + end +end + +function Base.string(memory_kind::MemoryKind) + GC.@preserve memory_kind begin + str = @ccall MLIR.API.mlir_c.ifrt_MemoryKindToString( + memory_kind.ptr::Ptr{Cvoid} + )::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function Base.show(io::IO, memory_kind::MemoryKind) + print(io, "XLA.IFRT.MemoryKind(\"", string(memory_kind), "\")") + return nothing +end diff --git a/src/xla/IFRT/Sharding.jl b/src/xla/IFRT/Sharding.jl new file mode 100644 index 0000000000..01d9780ed3 --- /dev/null +++ b/src/xla/IFRT/Sharding.jl @@ -0,0 +1,170 @@ +# xla::ifrt::HloSharding (distinct from xla::HloSharding) +mutable struct HloSharding + ptr::Ptr{Cvoid} + + function HloSharding(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + # return finalizer(free_hlo_sharding, new(ptr)) + return new(ptr) + end +end + +function free_hlo_sharding(hlo_sharding::HloSharding) + @ccall MLIR.API.mlir_c.free_ifrt_hlo_sharding(hlo_sharding.ptr::Ptr{Cvoid})::Cvoid +end + +function Base.convert(::Type{XLA.HloSharding}, sharding::HloSharding) + GC.@preserve sharding begin + return XLA.HloSharding( + @ccall MLIR.API.mlir_c.ifrt_hlo_sharding_to_xla_hlo_sharding( + sharding.ptr::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function HloSharding( + device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding +) + addressable_devices = filter(XLA.is_addressable, device_list) + default_memory_kind = convert(MemoryKind, XLA.default_memory(addressable_devices)) + return HloSharding(device_list, xla_hlo_sharding, default_memory_kind) +end + +function HloSharding( + device_list::AbstractVector{<:Device}, + xla_hlo_sharding::XLA.HloSharding, + memoy_kind::AbstractString, +) + return HloSharding(device_list, xla_hlo_sharding, MemoryKind(memoy_kind)) +end + +function HloSharding( + device_list::AbstractVector{<:Device}, + xla_hlo_sharding::XLA.HloSharding, + memory_kind::MemoryKind, +) + client = XLA.client(device_list) + GC.@preserve device_list memory_kind xla_hlo_sharding client begin + return HloSharding( + @ccall MLIR.API.mlir_c.ifrt_hlo_sharding_from_xla_hlo_sharding( + client.client::Ptr{Cvoid}, + [d.device for d in device_list]::Ptr{Ptr{Cvoid}}, + length(device_list)::Int32, + memory_kind.ptr::Ptr{Cvoid}, + xla_hlo_sharding.ptr::Ptr{Cvoid}, + )::Ptr{Cvoid} + ) + end +end + +function Base.string(hlo_sharding::HloSharding) + GC.@preserve hlo_sharding begin + str = @ccall MLIR.API.mlir_c.ifrt_hlo_sharding_to_string( + hlo_sharding.ptr::Ptr{Cvoid} + )::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function Base.show(io::IO, ::MIME"text/plain", hlo_sharding::HloSharding) + print(io, "XLA.IFRT.HloSharding(\"", string(hlo_sharding), "\")") + return nothing +end + +# HloSharding is more specific than Sharding. But Sharding is a neater way to deal with +# most of the IFRT APIs. +mutable struct Sharding + ptr::Ptr{Cvoid} + + function Sharding(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + # return finalizer(free_sharding, new(ptr)) + return new(ptr) + end +end + +function Sharding(device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding) + return convert(Sharding, HloSharding(device_list, xla_hlo_sharding)) +end + +function Sharding( + device_list::AbstractVector{<:Device}, + xla_hlo_sharding::XLA.HloSharding, + memoy_kind::Union{AbstractString,MemoryKind}, +) + return convert(Sharding, HloSharding(device_list, xla_hlo_sharding, memoy_kind)) +end + +function free_sharding(sharding::Sharding) + @ccall MLIR.API.mlir_c.free_ifrt_sharding(sharding.ptr::Ptr{Cvoid})::Cvoid +end + +function XLA.devices(sharding::Sharding) + GC.@preserve sharding begin + ndevices = @ccall MLIR.API.mlir_c.ifrt_sharding_devices_size( + sharding.ptr::Ptr{Cvoid} + )::Int32 + end + devices = Ref{NTuple{Int64(ndevices),Ptr{Cvoid}}}() + GC.@preserve sharding devices begin + @ccall MLIR.API.mlir_c.ifrt_sharding_to_device_list( + sharding.ptr::Ptr{Cvoid}, devices::Ptr{Ptr{Cvoid}} + )::Cvoid + end + return [Device(device) for device in devices[]] +end + +function Base.convert(::Type{Sharding}, hlo_sharding::HloSharding) + GC.@preserve hlo_sharding begin + return Sharding( + @ccall MLIR.API.mlir_c.ifrt_sharding_from_ifrt_hlo_sharding( + hlo_sharding.ptr::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function Base.convert(::Type{HloSharding}, sharding::Sharding) + GC.@preserve sharding begin + return HloSharding( + @ccall MLIR.API.mlir_c.ifrt_sharding_to_ifrt_hlo_sharding( + sharding.ptr::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function Base.convert(::Type{XLA.HloSharding}, sharding::Sharding) + return convert(XLA.HloSharding, convert(HloSharding, sharding)) +end + +function Base.string(sharding::Sharding) + GC.@preserve sharding begin + str = @ccall MLIR.API.mlir_c.ifrt_sharding_to_string( + sharding.ptr::Ptr{Cvoid} + )::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function is_fully_replicated(sharding::Sharding) + GC.@preserve sharding begin + return @ccall MLIR.API.mlir_c.ifrt_sharding_is_fully_replicated( + sharding.ptr::Ptr{Cvoid} + )::Bool + end +end + +function is_single_device_sharding(sharding::Sharding) + GC.@preserve sharding begin + return @ccall MLIR.API.mlir_c.ifrt_sharding_is_single_device_sharding( + sharding.ptr::Ptr{Cvoid} + )::Bool + end +end + +function Base.show(io::IO, ::MIME"text/plain", sharding::Sharding) + print(io, "XLA.IFRT.Sharding(\"", string(sharding), "\")") + return nothing +end diff --git a/src/xla/LoadedExecutable.jl b/src/xla/LoadedExecutable.jl new file mode 100644 index 0000000000..55efd21b66 --- /dev/null +++ b/src/xla/LoadedExecutable.jl @@ -0,0 +1,12 @@ +abstract type AbstractLoadedExecutable end + +function num_replicas end +function num_partitions end +function num_devices end +function get_hlo_modules end +function get_output_shardings end +function get_parameter_shardings end + +function compile end +function execute end +function execute_sharded end diff --git a/src/xla/Memory.jl b/src/xla/Memory.jl new file mode 100644 index 0000000000..44f72aec75 --- /dev/null +++ b/src/xla/Memory.jl @@ -0,0 +1,3 @@ +abstract type AbstractMemory end + +abstract type AbstractMemoryKind end diff --git a/src/xla/PJRT/AsyncBuffer.jl b/src/xla/PJRT/AsyncBuffer.jl new file mode 100644 index 0000000000..9cec91eba8 --- /dev/null +++ b/src/xla/PJRT/AsyncBuffer.jl @@ -0,0 +1,8 @@ +mutable struct AsyncBuffer <: XLA.AbstractAsyncBuffer + buffer::Buffer + future::Union{Future,Nothing} +end + +const AsyncEmptyBuffer = AsyncBuffer(Buffer(C_NULL), nothing) + +AsyncBuffer(args...; kwargs...) = AsyncBuffer(Buffer(args...; kwargs...), nothing) diff --git a/src/xla/PJRT/Buffer.jl b/src/xla/PJRT/Buffer.jl new file mode 100644 index 0000000000..3a78830c8b --- /dev/null +++ b/src/xla/PJRT/Buffer.jl @@ -0,0 +1,100 @@ +mutable struct Buffer <: XLA.AbstractBuffer + buffer::Ptr{Cvoid} + + function Buffer(buffer::Ptr{Cvoid}) + return finalizer(free_buffer, new(buffer)) + end +end + +function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N} + sizear = collect(Int64, reverse(size(array))) + buffer = GC.@preserve array sizear begin + @ccall MLIR.API.mlir_c.ArrayFromHostBuffer( + client.client::Ptr{Cvoid}, + pointer(array)::Ptr{T}, + XLA.primitive_type(T)::UInt64, + N::Csize_t, + pointer(sizear)::Ptr{Int64}, + device.device::Ptr{Cvoid}, + )::Ptr{Cvoid} + end + return Buffer(buffer) +end + +@inline function free_buffer(buffer::Buffer) + sbuffer = buffer.buffer + if sbuffer != C_NULL + @ccall MLIR.API.mlir_c.PjRtBufferFree(sbuffer::Ptr{Cvoid})::Cvoid + end +end + +function Base.ndims(buffer::Buffer) + GC.@preserve buffer begin + return @ccall MLIR.API.mlir_c.BufferNDimensions(buffer.buffer::Ptr{Cvoid})::Cint + end +end + +function Base.size(buffer::Buffer) + GC.@preserve buffer begin + sz = @ccall MLIR.API.mlir_c.BufferShape(buffer.buffer::Ptr{Cvoid})::Ptr{Int64} + end + return Tuple(unsafe_wrap(Array, sz, ndims(buffer))) +end + +function Base.eltype(buffer::Buffer) + GC.@preserve buffer begin + return XLA.julia_type( + @ccall MLIR.API.mlir_c.BufferPrimitiveType(buffer.buffer::Ptr{Cvoid})::Cint + ) + end +end + +function XLA.device(buffer::Buffer) + GC.@preserve buffer begin + return Device( + @ccall MLIR.API.mlir_c.BufferToDevice(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid} + ) + end +end + +function XLA.client(buffer::Buffer) + GC.@preserve buffer begin + return Client( + @ccall MLIR.API.mlir_c.BufferToClient(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid} + ) + end +end + +XLA.synced_buffer(buffer::Buffer) = buffer + +function XLA.buffer_on_cpu(buffer::Buffer) + GC.@preserve buffer begin + return @ccall MLIR.API.mlir_c.BufferOnCPU(buffer.buffer::Ptr{Cvoid})::Bool + end +end + +function XLA.to_host(buffer::Buffer, data) + GC.@preserve buffer begin + @ccall MLIR.API.mlir_c.BufferToHost( + buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid} + )::Cvoid + end +end + +# TODO: users themselves need to gc preserve here +function XLA.unsafe_buffer_pointer(buffer::Buffer) + @ccall MLIR.API.mlir_c.UnsafeBufferPointer(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid} +end + +function XLA.copy_buffer_to_device(buffer::Buffer, dev::Device) + XLA.device(buffer) == dev && return buffer + GC.@preserve buffer dev begin + Buffer( + @ccall MLIR.API.mlir_c.CopyBufferToDevice( + buffer.buffer::Ptr{Cvoid}, dev.device::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +XLA.sharding(::Buffer) = error("PJRT Buffers are not sharded.") diff --git a/src/xla/PJRT/Client.jl b/src/xla/PJRT/Client.jl new file mode 100644 index 0000000000..8842dc02f0 --- /dev/null +++ b/src/xla/PJRT/Client.jl @@ -0,0 +1,179 @@ +mutable struct Client <: XLA.AbstractClient + client::Ptr{Cvoid} + + function Client(client::Ptr{Cvoid}; skip_check::Bool=false) + skip_check || (@assert client != C_NULL) + return new(client) + end +end + +function XLA.free_client(client::Client) + GC.@preserve client begin + @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid + end +end + +function XLA.num_devices(client::Client) + GC.@preserve client begin + return @ccall MLIR.API.mlir_c.ClientNumDevices(client.client::Ptr{Cvoid})::Cint + end +end + +function XLA.num_addressable_devices(client::Client) + GC.@preserve client begin + return @ccall MLIR.API.mlir_c.ClientNumAddressableDevices( + client.client::Ptr{Cvoid} + )::Cint + end +end + +function XLA.devices(client::Client) + ndevices = Int(XLA.num_devices(client)) + devices = Ref{NTuple{ndevices,Ptr{Cvoid}}}() + GC.@preserve client devices begin + @ccall MLIR.API.mlir_c.ClientGetDevices( + client.client::Ptr{Cvoid}, devices::Ptr{Ptr{Cvoid}} + )::Cvoid + end + return [Device(device) for device in devices[]] +end + +function XLA.addressable_devices(client::Client) + naddressable_devices = Int(XLA.num_addressable_devices(client)) + addressable_devices = Ref{NTuple{naddressable_devices,Ptr{Cvoid}}}() + GC.@preserve client addressable_devices begin + @ccall MLIR.API.mlir_c.ClientGetAddressableDevices( + client.client::Ptr{Cvoid}, addressable_devices::Ptr{Ptr{Cvoid}} + )::Cvoid + end + return [Device(device) for device in addressable_devices[]] +end + +function XLA.process_index(client::Client) + GC.@preserve client begin + return @ccall MLIR.API.mlir_c.ClientProcessIndex(client.client::Ptr{Cvoid})::Cint + end +end + +function XLA.get_device(client::Client, idx) + GC.@preserve client begin + return Device( + @ccall MLIR.API.mlir_c.ClientGetDevice( + client.client::Ptr{Cvoid}, idx::Cint + )::Ptr{Cvoid} + ) + end +end + +function XLA.get_addressable_device(client::Client, idx) + GC.@preserve client begin + return Device( + @ccall MLIR.API.mlir_c.ClientGetAddressableDevice( + client.client::Ptr{Cvoid}, idx::Cint + )::Ptr{Cvoid} + ) + end +end + +function XLA.platform_name(client::Client) + GC.@preserve client begin + str = @ccall MLIR.API.mlir_c.ClientGetPlatformName( + client.client::Ptr{Cvoid} + )::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +# Different Backends +const cpu_client_count = Ref(0) +const gpu_client_count = Ref(0) +const tpu_client_count = Ref(0) + +for (backend, counter) in ( + (:CPUClient, :cpu_client_count), + (:GPUClient, :gpu_client_count), + (:TPUClient, :tpu_client_count), +) + main_fn = Symbol(:Make, backend) + @eval function $(backend)(args...; checkcount::Bool=true, kwargs...) + if checkcount + @assert $(counter)[] == 0 + end + client = Client($(main_fn)(args...; kwargs...)) + XLA.LLVMclopts("-nvptx-fma-level=1") + if checkcount + # Only increment the counter if we successfully created a client + $(counter)[] += 1 + end + return client + end +end + +function MakeCPUClient(; + node_id::Integer=0, + num_nodes::Integer=1, + asynchronous::Bool=true, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + @assert num_nodes == 1 "`PJRT.MakeCPUClient` does not support num_nodes > 1" + @assert distributed_runtime_client === nothing "`PJRT.MakeCPUClient` does not support \ + distributed_runtime_client" + + return @ccall MLIR.API.mlir_c.MakeCPUClient( + asynchronous::UInt8, node_id::Cint + )::Ptr{Cvoid} +end + +function MakeGPUClient(; + node_id::Integer=0, + num_nodes::Integer=1, + platform::String="gpu", + allowed_devices::Union{Nothing,Vector{Int}}=nothing, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + refstr = Ref{Cstring}() + + num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices) + allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices + distributed_runtime_client = + distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client + + GC.@preserve refstr allowed_devices distributed_runtime_client begin + client = @ccall MLIR.API.mlir_c.MakeGPUClient( + node_id::Cint, + num_nodes::Cint, + allowed_devices::Ptr{Cvoid}, + num_allowed_devices::Cint, + XLA.XLA_REACTANT_GPU_MEM_FRACTION[]::Cdouble, + XLA.XLA_REACTANT_GPU_PREALLOCATE[]::Bool, + platform::Cstring, + refstr::Ptr{Cstring}, + distributed_runtime_client::Ptr{Cvoid}, + )::Ptr{Cvoid} + end + + client == C_NULL && throw(AssertionError(unsafe_string(refstr[]))) + return client +end + +function MakeTPUClient(; + tpu_path::String, + node_id::Integer=0, + num_nodes::Integer=1, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + @assert node_id == 0 "`PJRT.MakeTPUClient` does not support node_id" + @assert num_nodes == 1 "`PJRT.MakeTPUClient` does not support num_nodes > 1" + @assert distributed_runtime_client === nothing "`PJRT.MakeTPUClient` does not support \ + distributed_runtime_client" + + refstr = Ref{Cstring}() + GC.@preserve refstr begin + client = @ccall MLIR.API.mlir_c.MakeTPUClient( + tpu_path::Cstring, refstr::Ptr{Cstring} + )::Ptr{Cvoid} + end + + client == C_NULL && throw(AssertionError(unsafe_string(refstr[]))) + return client +end diff --git a/src/xla/PJRT/Device.jl b/src/xla/PJRT/Device.jl new file mode 100644 index 0000000000..c4c4a2caa5 --- /dev/null +++ b/src/xla/PJRT/Device.jl @@ -0,0 +1,34 @@ +struct Device <: XLA.AbstractDevice + device::Ptr{Cvoid} +end + +function XLA.client(device::Device) + GC.@preserve device begin + return Client( + @ccall MLIR.API.mlir_c.DeviceToClient(device.device::Ptr{Cvoid})::Ptr{Cvoid} + ) + end +end + +function XLA.device_ordinal(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalDeviceId( + device.device::Ptr{Cvoid} + )::Int64 + end +end + +function XLA.device_kind(device::Device) + GC.@preserve device begin + str = @ccall MLIR.API.mlir_c.DeviceGetKind(device.device::Ptr{Cvoid})::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function XLA.get_local_device_id(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalDeviceId( + device.device::Ptr{Cvoid} + )::Cint + end +end diff --git a/src/xla/PJRT/Future.jl b/src/xla/PJRT/Future.jl new file mode 100644 index 0000000000..0749735ab4 --- /dev/null +++ b/src/xla/PJRT/Future.jl @@ -0,0 +1,25 @@ +mutable struct Future <: XLA.AbstractFuture + future::Ptr{Cvoid} + + function Future(future::Ptr{Cvoid}) + @assert future != C_NULL + return finalizer(free_future, new(future)) + end +end + +@inline function free_future(future::Future) + @ccall MLIR.API.mlir_c.FreeFuture(future.future::Ptr{Cvoid})::Cvoid +end + +function XLA.is_ready(future::Future) + GC.@preserve future begin + return (@ccall MLIR.API.mlir_c.FutureIsReady(future.future::Ptr{Cvoid})::UInt8) != 0 + end +end + +@inline function XLA.await(future::Future)::Nothing + GC.@preserve future begin + @ccall MLIR.API.mlir_c.FutureAwait(future.future::Ptr{Cvoid})::Cvoid + end + return nothing +end diff --git a/src/xla/PJRT/LoadedExecutable.jl b/src/xla/PJRT/LoadedExecutable.jl new file mode 100644 index 0000000000..15d3dbc0d5 --- /dev/null +++ b/src/xla/PJRT/LoadedExecutable.jl @@ -0,0 +1,271 @@ +mutable struct LoadedExecutable <: XLA.AbstractLoadedExecutable + exec::Ptr{Cvoid} + num_outputs::Int64 + num_parameters::Int64 + is_sharded::Bool + num_replicas::Int64 + num_partitions::Int64 + + function LoadedExecutable(exec::Ptr{Cvoid}, args...) + @assert exec != C_NULL + return finalizer(free_exec, new(exec, args...)) + end +end + +@inline function free_exec(exec::LoadedExecutable) + @ccall MLIR.API.mlir_c.ExecutableFree(exec.exec::Ptr{Cvoid})::Cvoid +end + +function XLA.client(exec::LoadedExecutable) + GC.@preserve exec begin + return Client( + @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetClient( + exec.exec::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +XLA.num_partitions(exec::LoadedExecutable) = exec.num_partitions +XLA.num_replicas(exec::LoadedExecutable) = exec.num_replicas +XLA.num_devices(exec::LoadedExecutable) = XLA.num_replicas(exec) * XLA.num_partitions(exec) + +for (jlop, xlaop, field) in ( + (:get_output_shardings, :PjRtLoadedExecutableGetOuputShardings, :num_outputs), + (:get_parameter_shardings, :PjRtLoadedExecutableGetParameterShardings, :num_parameters), +) + @eval function XLA.$(jlop)(exec::LoadedExecutable) + exec.is_sharded || return XLA.OpSharding[] + + op_shardings = Ref{NTuple{exec.$(field),Ptr{Cvoid}}}() + + GC.@preserve op_shardings begin + @ccall MLIR.API.mlir_c.$(xlaop)( + exec.exec::Ptr{Cvoid}, op_shardings::Ptr{Ptr{Cvoid}}, exec.$(field)::Int32 + )::Cvoid + end + + return [XLA.OpSharding(op_sharding) for op_sharding in op_shardings[]] + end +end + +function XLA.get_hlo_modules(exec::LoadedExecutable) + # If we had compiled with MPMD then we would need all the partitions to get hlo_modules + # but if we used SPMD we get only 1 module. To be safe we allocate for all the modules + # and use the ones assigned to by XLA + hlo_modules = Ref{NTuple{Int64(XLA.num_partitions(exec)),Ptr{Cvoid}}}() + nmodules = Ref{Int32}(0) + GC.@preserve exec hlo_modules begin + @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetHloModules( + exec.exec::Ptr{Cvoid}, hlo_modules::Ptr{Ptr{Cvoid}}, nmodules::Ptr{Cint} + )::Cvoid + end + return map(XLA.HloModule, hlo_modules[][1:Int(nmodules[])]) +end + +function XLA.compile( + client::Client, + device::Union{Device,Nothing}, + mod::MLIR.IR.Module; + is_sharded::Bool=false, + global_device_ids::Vector{Int64}=Int64[], + num_outputs::Int64, + num_parameters::Int64, + num_replicas::Int64, + num_partitions::Int64, +) + device_id = is_sharded ? Int64(-1) : Int64(XLA.device_ordinal(device)) + GC.@preserve client mod begin + exec = @ccall MLIR.API.mlir_c.ClientCompile( + client.client::Ptr{Cvoid}, + mod.module_::MLIR.API.MlirModule, + device_id::Clong, + is_sharded::Bool, + global_device_ids::Ptr{Clong}, + length(global_device_ids)::Clong, + XLA.CUDA_DATA_DIR[]::Cstring, + )::Ptr{Cvoid} + end + return LoadedExecutable( + exec, num_outputs, num_parameters, is_sharded, num_replicas, num_partitions + ) +end + +function execute_ir(N, M, n_outs, fn, with_device::Bool, nmesh_ids::Int64) + ptr = sizeof(Int) == sizeof(Int64) ? "i64" : "i32" + cint = sizeof(Cint) == sizeof(Int64) ? "i64" : "i32" + args = N > 0 ? ", [$N x $ptr] %inps, [$M x i8] %donated" : "" + if with_device + args = "$ptr %dev $args" + else + args = "[$nmesh_ids x $ptr] %mesh_ids $args" + end + + stores = N > 0 ? """ + store [$N x $ptr] %inps, [$N x $ptr]* %inpa + store [$M x i8] %donated, [$M x i8]* %dona + """ : "" + + if !with_device + stores *= """ + store [$nmesh_ids x $ptr] %mesh_ids, [$nmesh_ids x $ptr]* %mesha + """ + end + + extra_str1 = with_device ? "$ptr" : "[$nmesh_ids x $ptr]*, i64" + extra_str2 = if with_device + "$ptr %dev" + else + "[$(nmesh_ids) x $ptr]* nocapture readonly %mesha, i64 $(nmesh_ids)" + end + + res = """define { [$n_outs x $ptr], [$n_outs x $ptr], i8 } @f($ptr %exec, $args) alwaysinline { + entry: + %inpa = alloca [$N x $ptr] + %dona = alloca [$M x i8] + %outa = alloca [$n_outs x $ptr] + %futpa = alloca [$n_outs x $ptr] + %mesha = alloca [$nmesh_ids x $ptr] + $stores + %futa = alloca i8 + call void inttoptr ($ptr $fn to void ($ptr, $cint, [$N x $ptr]*, $extra_str1, [$M x i8]*, $cint, [$n_outs x $ptr]*, i8*, [$n_outs x $ptr]*)*)($ptr %exec, $cint $N, [$N x $ptr]* nocapture readonly %inpa, $extra_str2, [$M x i8]* nocapture readonly %dona, $cint $n_outs, [$n_outs x $ptr]* nocapture writeonly %outa, i8* nocapture writeonly %futa, [$n_outs x $ptr]* nocapture writeonly %futpa) + %out = load [$n_outs x $ptr], [$n_outs x $ptr]* %outa + %fut = load i8, i8* %futa + %futp = load [$n_outs x $ptr], [$n_outs x $ptr]* %futpa + %fca.0.insert = insertvalue { [$n_outs x $ptr], [$n_outs x $ptr], i8 } undef, [$n_outs x $ptr] %out, 0 + %fca.1.insert = insertvalue { [$n_outs x $ptr], [$n_outs x $ptr], i8 } %fca.0.insert, [$n_outs x $ptr] %futp, 1 + %fca.2.insert = insertvalue { [$n_outs x $ptr], [$n_outs x $ptr], i8 } %fca.1.insert, i8 %fut, 2 + ret { [$n_outs x $ptr], [$n_outs x $ptr], i8 } %fca.2.insert + } + """ + return res +end + +@generated function XLA.execute_sharded( + exec::LoadedExecutable, + device::Device, + inputs::NTuple{N,Ptr{Cvoid}}, + donated_args::NTuple{N,UInt8}, + ::Val{n_outs}, +) where {N,n_outs} + sym0 = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "XLAExecuteSharded") + xla_execute_fn = reinterpret(UInt, sym0) + ir = execute_ir(N, N, n_outs, xla_execute_fn, true, 0) + results = [] + for i in 1:n_outs + push!( + results, + :(( + AsyncBuffer(Buffer(outputs[$i]), future ? Future(future_res[$i]) : nothing), + )), + ) + end + + args_type = if N > 0 + (Ptr{Cvoid}, Ptr{Cvoid}, NTuple{N,Ptr{Cvoid}}, NTuple{N,UInt8}) + else + (Ptr{Cvoid}, Ptr{Cvoid}) + end + args = N > 0 ? (:inputs, :donated_args) : () + return quote + Base.@_inline_meta + exec = exec.exec + device = device.device + GC.@preserve exec device begin + outputs, future_res, future = Base.llvmcall( + ($ir, "f"), + Tuple{NTuple{n_outs,Ptr{Cvoid}},NTuple{n_outs,Ptr{Cvoid}},Bool}, + Tuple{$args_type...}, + exec, + device, + $(args...), + ) + end + return ($(results...),) + end +end + +# XXX: Fix this +# @generated function XLA.execute( +# exec::LoadedExecutable, +# mesh_ids::Vector{Int64}, +# inputs::NTuple{N,Ptr{Cvoid}}, +# donated_args::NTuple{M,UInt8}, +# ::Val{n_outs}, +# ::Val{K}, +# ) where {N,M,K,n_outs} +# sym0 = dlsym(Reactant_jll.libReactantExtra_handle, "XLAExecute") +# xla_execute_fn = reinterpret(UInt, sym0) + +# ir = execute_ir(N, M, n_outs * K, xla_execute_fn, false, K) +# results = [Vector{Any}(undef, K) for i in 1:n_outs] +# for i in 1:n_outs, j in 1:K +# idx = (i - 1) * K + j +# results[i][j] = :(AsyncBuffer( +# Buffer(outputs[$idx]), future ? Future(future_res[$idx]) : nothing +# )) +# end + +# args_type = if N > 0 +# (Ptr{Cvoid}, Ptr{Clong}, NTuple{N,Ptr{Cvoid}}, NTuple{M,UInt8}) +# else +# (Ptr{Cvoid}, Ptr{Clong}) +# end +# args = N > 0 ? (:inputs, :donated_args) : () +# return quote +# Base.@_inline_meta +# exec = exec.exec +# GC.@preserve exec begin +# outputs, future_res, future = Base.llvmcall( +# ($ir, "f"), +# Tuple{NTuple{n_outs * K,Ptr{Cvoid}},NTuple{n_outs * K,Ptr{Cvoid}},Bool}, +# Tuple{$args_type...}, +# exec, +# mesh_ids, +# $(args...), +# ) +# end +# return ($(results...),) +# end +# end + +@inline function XLA.execute( + exec::LoadedExecutable, + inputs::NTuple{N,Ptr{Cvoid}}, + donated_args::NTuple{M,UInt8}, + ::Val{n_outs}, + ::Val{K}, +) where {N,M,n_outs,K} + outputs = Ref{NTuple{n_outs * K,Ptr{Cvoid}}}() + future_res = Ref{NTuple{n_outs * K,Ptr{Cvoid}}}() + futures = Ref{UInt8}(0) + + inputs = Base.RefValue(inputs) + donated_args = Base.RefValue(donated_args) + GC.@preserve inputs donated_args outputs futures future_res begin + @ccall MLIR.API.mlir_c.XLAExecute( + exec.exec::Ptr{Cvoid}, + N::Cint, + inputs::Ptr{Cvoid}, + donated_args::Ptr{UInt8}, + n_outs::Cint, + Base.unsafe_convert(Ptr{Cvoid}, outputs)::Ptr{Cvoid}, + Base.unsafe_convert(Ptr{UInt8}, futures)::Ptr{UInt8}, + Base.unsafe_convert(Ptr{Cvoid}, future_res)::Ptr{Cvoid}, + )::Cvoid + end + + outputs = outputs[] + future = futures[] != 0 + future && (future_res = future_res[]) + + return ntuple(Val(n_outs)) do j + ntuple(Val(K)) do i + Base.@_inline_meta + idx = (i - 1) * n_outs + j + return AsyncBuffer( + Buffer(outputs[idx]), future ? Future(future_res[idx]) : nothing + ) + end + end +end diff --git a/src/xla/PJRT/PJRT.jl b/src/xla/PJRT/PJRT.jl new file mode 100644 index 0000000000..1fab24650f --- /dev/null +++ b/src/xla/PJRT/PJRT.jl @@ -0,0 +1,16 @@ +module PJRT + +using ..Reactant: Reactant, MLIR +using ..XLA: XLA +using Reactant_jll: Reactant_jll + +using Libdl: Libdl + +include("Client.jl") +include("Device.jl") +include("Future.jl") +include("Buffer.jl") +include("AsyncBuffer.jl") +include("LoadedExecutable.jl") + +end diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl new file mode 100644 index 0000000000..db4c968924 --- /dev/null +++ b/src/xla/Sharding.jl @@ -0,0 +1,386 @@ +@enumx OpShardingType begin + Replicated + Maximal + Tuple + Other + Manual + Unknown +end + +function Base.convert(::Type{OpShardingType.T}, i::Integer) + i == 0 && return OpShardingType.Replicated + i == 1 && return OpShardingType.Maximal + i == 2 && return OpShardingType.Tuple + i == 3 && return OpShardingType.Other + i == 4 && return OpShardingType.Manual + i == 5 && return OpShardingType.Unknown + return error("Invalid OpShardingType $i") +end + +# xla::OpSharding +mutable struct OpSharding + ptr::Ptr{Cvoid} + + function OpSharding(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_op_sharding, new(ptr)) + end +end + +function free_op_sharding(op_sharding::OpSharding) + @ccall MLIR.API.mlir_c.free_op_sharding(op_sharding.ptr::Ptr{Cvoid})::Cvoid +end + +function replicate_on_last_tile_dim(op_sharding::OpSharding) + GC.@preserve op_sharding begin + return @ccall MLIR.API.mlir_c.op_sharding_replicate_on_last_tile_dim( + op_sharding.ptr::Ptr{Cvoid} + )::Bool + end +end + +function op_sharding_type(op_sharding::OpSharding) + type = GC.@preserve op_sharding begin + @ccall MLIR.API.mlir_c.op_sharding_to_op_sharding_type( + op_sharding.ptr::Ptr{Cvoid} + )::Int32 + end + return convert(OpShardingType.T, type) +end + +function has_iota_reshape_dims(op_sharding::OpSharding) + GC.@preserve op_sharding begin + return @ccall MLIR.API.mlir_c.op_sharding_has_iota_reshape_dims( + op_sharding.ptr::Ptr{Cvoid} + )::Bool + end +end + +function iota_reshape_dims(op_sharding::OpSharding) + GC.@preserve op_sharding begin + ndims = @ccall MLIR.API.mlir_c.op_sharding_iota_reshape_dims_size( + op_sharding.ptr::Ptr{Cvoid} + )::Int32 + end + dimensions = Vector{Int32}(undef, ndims) + GC.@preserve op_sharding dimensions begin + @ccall MLIR.API.mlir_c.op_sharding_iota_reshape_dims( + op_sharding.ptr::Ptr{Cvoid}, dimensions::Ptr{Int32} + )::Cvoid + end + return dimensions +end + +function has_iota_transpose_perm(op_sharding::OpSharding) + GC.@preserve op_sharding begin + return @ccall MLIR.API.mlir_c.op_sharding_has_iota_transpose_perm( + op_sharding.ptr::Ptr{Cvoid} + )::Bool + end +end + +function iota_transpose_perm(op_sharding::OpSharding) + GC.@preserve op_sharding begin + ndims = @ccall MLIR.API.mlir_c.op_sharding_iota_transpose_perm_size( + op_sharding.ptr::Ptr{Cvoid} + )::Int32 + end + dimensions = Vector{Int32}(undef, ndims) + GC.@preserve op_sharding dimensions begin + @ccall MLIR.API.mlir_c.op_sharding_iota_transpose_perm( + op_sharding.ptr::Ptr{Cvoid}, dimensions::Ptr{Int32} + )::Cvoid + end + dimensions .+= 1 + return dimensions +end + +function tile_assignment_dimensions(op_sharding::OpSharding) + GC.@preserve op_sharding begin + ndims = @ccall MLIR.API.mlir_c.op_sharding_tile_assignment_dimensions_size( + op_sharding.ptr::Ptr{Cvoid} + )::Int32 + end + dimensions = Vector{Int32}(undef, ndims) + GC.@preserve op_sharding dimensions begin + @ccall MLIR.API.mlir_c.op_sharding_tile_assignment_dimensions( + op_sharding.ptr::Ptr{Cvoid}, dimensions::Ptr{Int32} + )::Cvoid + end + return dimensions +end + +function tile_assignment_devices(op_sharding::OpSharding) + GC.@preserve op_sharding begin + ndims = @ccall MLIR.API.mlir_c.op_sharding_tile_assignment_devices_size( + op_sharding.ptr::Ptr{Cvoid} + )::Int32 + end + devices = Vector{Int32}(undef, ndims) + GC.@preserve op_sharding devices begin + @ccall MLIR.API.mlir_c.op_sharding_tile_assignment_devices( + op_sharding.ptr::Ptr{Cvoid}, devices::Ptr{Int32} + )::Cvoid + end + return devices +end + +function has_last_tile_dims(op_sharding::OpSharding) + GC.@preserve op_sharding begin + return @ccall MLIR.API.mlir_c.op_sharding_has_last_tile_dims( + op_sharding.ptr::Ptr{Cvoid} + )::Bool + end +end + +# This separation is mostly for testing purposes +function generate_device_list_from_tile_assignment_devices(sharding::OpSharding) + return tile_assignment_devices(sharding) +end + +function generate_device_list_from_iota_tile(sharding::OpSharding) + return generate_device_list_from_iota_tile( + tile_assignment_dimensions(sharding), + iota_reshape_dims(sharding), + iota_transpose_perm(sharding), + ) +end + +function generate_device_list_from_iota_tile( + tile_assignment_dimensions, iota_reshape_dims, iota_transpose_perm +) + # Generate device IDs using iota + num_devices = prod(tile_assignment_dimensions) + ird = Int64.(iota_reshape_dims) + + # Permute the iota array if iota_transpose_perm is provided + # We need to ensure that we account for the col-major ordering in julia. See the + # unit tests for examples. + if !isempty(iota_transpose_perm) + # XXX: Simplify the permutedims + iota_devices = collect(Int64, reshape(0:(num_devices - 1), reverse(ird)...)) + + iota_devices = permutedims(iota_devices, reverse(1:ndims(iota_devices))) + iota_devices = permutedims(iota_devices, iota_transpose_perm) + iota_devices = permutedims(iota_devices, reverse(1:ndims(iota_devices))) + + return vec(iota_devices) + else + @assert num_devices == prod(ird) + return collect(0:(num_devices - 1)) + end +end + +function generate_device_list(sharding::OpSharding) + has_iota_reshape_dims(sharding) && return generate_device_list_from_iota_tile(sharding) + return generate_device_list_from_tile_assignment_devices(sharding) +end + +function get_number_of_ways_dim_sharded(op_sharding::OpSharding) + op_sharding_type(op_sharding) == OpShardingType.Replicated && return Int64[], 1 + td = tile_assignment_dimensions(op_sharding) + replicate_on_last_tile_dim(op_sharding) && return td[1:(end - 1)], td[end] + return td, 1 +end + +function sharding_to_concrete_array_indices(sharding::OpSharding, shape, device_ids) + return sharding_to_concrete_array_indices( + convert(CondensedOpSharding, sharding), shape, device_ids + ) +end + +function compute_array_indices_and_hlo_sharding( + sharding::OpSharding, array_size, device_ids +) + return compute_array_indices_and_hlo_sharding( + convert(CondensedOpSharding, sharding), array_size, device_ids + ) +end + +# This only stores the data that we currently support, and is useful for checking equality +# We would want to extend support to more of the fields at a later time +struct CondensedOpSharding{N} + opsharding::OpSharding + type::OpShardingType.T + replicate_on_last_tile_dim::Bool + tile_assignment::Array{Int64,N} +end + +function Base.:(==)(a::CondensedOpSharding, b::CondensedOpSharding) + return a.type == b.type && + a.replicate_on_last_tile_dim == b.replicate_on_last_tile_dim && + a.tile_assignment == b.tile_assignment +end + +function Base.convert(::Type{CondensedOpSharding}, sharding::OpSharding) + @assert !has_last_tile_dims(sharding) "Last Tile dimensions are not supported \ + yet!" + + type = op_sharding_type(sharding) + + if type == OpShardingType.Replicated || type == OpShardingType.Maximal + tile_assignment = generate_device_list(sharding) + elseif type == OpShardingType.Other + td = tile_assignment_dimensions(sharding) + tile_assignment = permutedims( + reshape(generate_device_list(sharding), Int64.(reverse(td))...), + reverse(1:length(td)), + ) + else + error("Invalid sharding type: $(type)") + end + + return CondensedOpSharding( + sharding, type, replicate_on_last_tile_dim(sharding), Int64.(tile_assignment) + ) +end + +function get_number_of_ways_dim_sharded(op_sharding::CondensedOpSharding{N}) where {N} + op_sharding.type == OpShardingType.Replicated && return Int64[], 1 + + if op_sharding.replicate_on_last_tile_dim + return ( + size(op_sharding.tile_assignment)[1:(N - 1)], + size(op_sharding.tile_assignment, N), + ) + end + return size(op_sharding.tile_assignment), 1 +end + +function sharding_to_concrete_array_indices( + sharding::CondensedOpSharding, shape::Dims{N}, device_ids +) where {N} + if sharding.type == OpShardingType.Replicated || sharding.type == OpShardingType.Maximal + return map(Returns(UnitRange.(1, shape)), device_ids), false + elseif sharding.type == OpShardingType.Other + partitions, num_replicas = get_number_of_ways_dim_sharded(sharding) + @assert length(partitions) == length(shape) + shape = reverse(shape) + + # XLA will automatically pad the inputs that don't match the final shape + partitionable_shape = map(zip(shape, partitions)) do (dim, n_shards) + dim % n_shards == 0 && return dim + res = dim + n_shards ÷ 2 + return res - res % n_shards + end + partitionable_shape = Tuple(partitionable_shape) + + needs_padding = any(partitionable_shape .!= shape) + + # Calculate indices for each dimension + axis_indices = + map(zip(partitionable_shape, shape, partitions)) do (dim_padded, dim, n_shards) + @assert dim > 0 "Invalid dimension: $dim" + @assert n_shards > 0 "Invalid number of shards: $n_shards" + n_shards == 1 && return [1:dim] + shard_size = dim_padded ÷ n_shards + + return [ + (i * shard_size + 1):min((i + 1) * shard_size, dim) for + i in 0:(n_shards - 1) + ] + end + + indices = Dict{Int,NTuple{N,UnitRange{Int}}}() + device_idx = 1 + for _ in 1:num_replicas + for idx_tuple in Iterators.product(axis_indices...) + indices[sharding.tile_assignment[device_idx]] = reverse(idx_tuple) + device_idx += 1 + end + end + + return map(Base.Fix1(getindex, indices), device_ids), needs_padding + else + error("Unsupported sharding type: $(sharding.type)") + end +end + +function compute_array_indices_and_hlo_sharding( + sharding::CondensedOpSharding, array_size, device_ids +) + return ( + first(sharding_to_concrete_array_indices(sharding, array_size, device_ids)), + convert(HloSharding, sharding.opsharding), + ) +end + +# Helper function to get device sequence along a dimension +function __get_device_sequence(arr, dim) + idx = ones(Int, ndims(arr)) + sequence = Int[] + for i in 1:size(arr, dim) + idx[dim] = i + push!(sequence, arr[idx...]) + end + return sequence +end + +# xla::HloSharding +mutable struct HloSharding + ptr::Ptr{Cvoid} + + function HloSharding(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_hlo_sharding, new(ptr)) + end +end + +function free_hlo_sharding(hlo_sharding::HloSharding) + @ccall MLIR.API.mlir_c.free_hlo_sharding(hlo_sharding.ptr::Ptr{Cvoid})::Cvoid +end + +function Base.convert(::Type{CondensedOpSharding}, hlo_sharding::HloSharding) + return convert(CondensedOpSharding, convert(OpSharding, hlo_sharding)) +end + +function Base.convert(::Type{OpSharding}, hlo_sharding::HloSharding) + GC.@preserve hlo_sharding begin + return OpSharding( + @ccall MLIR.API.mlir_c.hlo_sharding_to_op_sharding( + hlo_sharding.ptr::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function Base.convert(::Type{HloSharding}, op_sharding::OpSharding) + GC.@preserve op_sharding begin + return HloSharding( + @ccall MLIR.API.mlir_c.hlo_sharding_from_op_sharding( + op_sharding.ptr::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function Base.string(hlo_sharding::HloSharding) + GC.@preserve hlo_sharding begin + str = @ccall MLIR.API.mlir_c.hlo_sharding_to_string( + hlo_sharding.ptr::Ptr{Cvoid} + )::Cstring + end + return unsafe_string_and_free(str) +end + +function Base.show(io::IO, ::MIME"text/plain", hlo_sharding::HloSharding) + print(io, "XLA.HloSharding(\"", string(hlo_sharding), "\")") + return nothing +end + +function sharding_to_concrete_array_indices(sharding::HloSharding, shape, device_ids) + return sharding_to_concrete_array_indices( + convert(CondensedOpSharding, sharding), shape, device_ids + ) +end + +function compute_array_indices_and_hlo_sharding( + sharding::HloSharding, array_size, device_ids +) + return ( + compute_array_indices_and_hlo_sharding( + convert(CondensedOpSharding, sharding), array_size, device_ids + ), + sharding, + ) +end diff --git a/src/xla/Stats.jl b/src/xla/Stats.jl new file mode 100644 index 0000000000..629b3208a5 --- /dev/null +++ b/src/xla/Stats.jl @@ -0,0 +1,77 @@ +# To keep in sync with JLAllocatorStats in ReactantExtra/API.cpp +struct JLAllocatorStats + num_allocs::Int64 + bytes_in_use::Int64 + peak_bytes_in_use::Int64 + largest_alloc_size::Int64 + bytes_limit::Int64 + bytes_reserved::Int64 + peak_bytes_reserved::Int64 + bytes_reservable_limit::Int64 + largest_free_block_bytes::Int64 + pool_bytes::Int64 + peak_pool_bytes::Int64 +end + +""" + AllocatorStats() + +Contains the following fields: + - `num_allocs` + - `bytes_in_use` + - `peak_bytes_in_use` + - `largest_alloc_size` + - `bytes_limit` + - `bytes_reserved` + - `peak_bytes_reserved` + - `bytes_reservable_limit` + - `largest_free_block_bytes` + - `pool_bytes` + - `peak_pool_bytes` + +It should be constructed using the [`allocatorstats`](@ref) function. +""" +struct AllocatorStats + num_allocs::Int64 + bytes_in_use::Int64 + peak_bytes_in_use::Int64 + largest_alloc_size::Int64 + bytes_limit::Union{Nothing,Int64} + bytes_reserved::Int64 + peak_bytes_reserved::Int64 + bytes_reservable_limit::Union{Nothing,Int64} + largest_free_block_bytes::Int64 + pool_bytes::Union{Nothing,Int64} + peak_pool_bytes::Union{Nothing,Int64} +end + +""" + allocatorstats([device]) + +Return an [`AllocatorStats`](@ref) instance with information about the device specific allocator. + +!!! warning + This method is currently not implemented for the CPU device. +""" +function allocatorstats(device::AbstractDevice=XLA.default_device(XLA.default_backend())) + ref = Ref{JLAllocatorStats}() + @ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats( + device.device::Ptr{Cvoid}, ref::Ptr{Cvoid} + )::Cvoid + stats = ref[] + + nullopt = typemin(Int64) + return AllocatorStats( + stats.num_allocs, + stats.bytes_in_use, + stats.peak_bytes_in_use, + stats.largest_alloc_size, + stats.bytes_limit == nullopt ? nothing : stats.bytes_limit, + stats.bytes_reserved, + stats.peak_bytes_reserved, + stats.bytes_reservable_limit == nullopt ? nothing : stats.bytes_reservable_limit, + stats.largest_free_block_bytes, + stats.pool_bytes == nullopt ? nothing : stats.pool_bytes, + stats.peak_pool_bytes == nullopt ? nothing : stats.peak_pool_bytes, + ) +end diff --git a/src/xla/Utils.jl b/src/xla/Utils.jl new file mode 100644 index 0000000000..1bedcc88bb --- /dev/null +++ b/src/xla/Utils.jl @@ -0,0 +1,55 @@ +SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid + +struct ReactantInternalError <: Base.Exception + msg::String +end + +function Base.showerror(io::IO, ece::ReactantInternalError) + return print(io, ece.msg, '\n') +end + +function reactant_err(msg::Cstring)::Cvoid + throw(ReactantInternalError(Base.unsafe_string(msg))) +end + +# https://github.com/openxla/xla/blob/4bfb5c82a427151d6fe5acad8ebe12cee403036a/xla/xla_data.proto#L29 +primitive_types_list = [ + (1, Bool), + (2, Int8), + (6, UInt8), + (3, Int16), + (7, UInt16), + (4, Int32), + (8, UInt32), + (5, Int64), + (9, UInt64), + (10, Float16), + (11, Float32), + (19, Reactant.F8E5M2), + (20, Reactant.F8E4M3FN), + (23, Reactant.F8E4M3B11FNUZ), + (24, Reactant.F8E5M2FNUZ), + (25, Reactant.F8E4M3FNUZ), + (12, Float64), + (15, Complex{Float32}), + (18, Complex{Float64}), +] + +@static if isdefined(Core, :BFloat16) + push!(primitive_types_list, (16, Core.BFloat16)) +end + +for (int_val, jl_type) in primitive_types_list + @eval begin + @inline primitive_type(::Type{$(jl_type)}) = $(int_val) + @inline julia_type(::Val{$(int_val)}) = $(jl_type) + end +end + +@inline julia_type(@nospecialize(x::Integer)) = julia_type(Val(Int64(x))) + +function unsafe_string_and_free(str::Cstring, args...) + str_jl = unsafe_string(str, args...) + @ccall free(str::Cstring)::Cvoid + return str_jl +end diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl new file mode 100644 index 0000000000..97aacded34 --- /dev/null +++ b/src/xla/XLA.jl @@ -0,0 +1,193 @@ +module XLA + +using ..Reactant: Reactant, MLIR +using Reactant_jll +using Libdl +using Scratch, Downloads +using EnumX: @enumx + +const XLA_REACTANT_GPU_MEM_FRACTION = Ref{Float64}(0.75) +const XLA_REACTANT_GPU_PREALLOCATE = Ref{Bool}(true) + +using Reactant_jll +const CUDA_DATA_DIR = Ref( + isdefined(Reactant_jll, :ptxas_path) ? dirname(dirname(Reactant_jll.ptxas_path)) : "" +) + +function LLVMclopts(opts...) + args = ["", opts...] + @ccall MLIR.API.mlir_c.ReactantLLVMParseCommandLineOptions( + length(args)::Cint, args::Ptr{Cstring}, C_NULL::Ptr{Cvoid} + )::Cvoid +end + +include("Distributed.jl") +include("Client.jl") +include("Device.jl") +include("Sharding.jl") +include("LoadedExecutable.jl") +include("Future.jl") +include("Buffer.jl") +include("Stats.jl") +include("Utils.jl") +include("HloModule.jl") +include("Memory.jl") + +include("PJRT/PJRT.jl") + +include("IFRT/IFRT.jl") + +@kwdef mutable struct PJRTBackendState + initialized::Bool = false + clients::Dict{String,PJRT.Client} = Dict{String,PJRT.Client}() + default_client::PJRT.Client = PJRT.Client(C_NULL; skip_check=true) +end + +function Base.getproperty(bs::PJRTBackendState, sym::Symbol) + (sym === :initialized || bs.initialized) && return getfield(bs, sym) + initialize_default_pjrt_clients!(bs) + return getfield(bs, sym) +end + +function Base.setproperty!(bs::PJRTBackendState, sym::Symbol, val) + (sym === :initialized || bs.initialized) && return setfield!(bs, sym, val) + initialize_default_pjrt_clients!(bs) + return setfield!(bs, sym, val) +end + +const global_backend_state = PJRTBackendState() +const global_state = State() + +client(backend::String) = global_backend_state.clients[backend] +default_backend() = global_backend_state.default_client +process_index() = process_index(default_backend()) + +function set_default_backend(backend::AbstractClient) + global_backend_state.default_client = backend + return nothing +end + +function set_default_backend(backend::String) + global_backend_state.default_client = client(backend) + return nothing +end + +function update_global_state!(args...; kwargs...) + update!(global_state, args...; kwargs...) + # We conditionally initialize for now, since a lot of options that are set are not + # necessarily supported by PJRT. This makes testing for IFRT quite hard. + # Once we move to IFRT completely, we can remove this. + if global_backend_state.initialized + # We need to update the clients based on the new state + initialize_default_pjrt_clients!(global_backend_state) + end + return nothing +end + +function __init__() + # This must be the very first thing initialized (otherwise we can't throw errors) + errptr = cglobal((:ReactantThrowError, MLIR.API.mlir_c), Ptr{Ptr{Cvoid}}) + unsafe_store!(errptr, @cfunction(reactant_err, Cvoid, (Cstring,))) + + initLogs = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "InitializeLogs") + ccall(initLogs, Cvoid, ()) + # Add most log level + # SetLogLevel(0) + + if haskey(ENV, "XLA_REACTANT_GPU_MEM_FRACTION") + XLA_REACTANT_GPU_MEM_FRACTION[] = parse( + Float64, ENV["XLA_REACTANT_GPU_MEM_FRACTION"] + ) + @debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[] + end + + if haskey(ENV, "XLA_REACTANT_GPU_PREALLOCATE") + XLA_REACTANT_GPU_PREALLOCATE[] = parse(Bool, ENV["XLA_REACTANT_GPU_PREALLOCATE"]) + @debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[] + end + + if haskey(ENV, "REACTANT_VISIBLE_GPU_DEVICES") + global_state.local_gpu_device_ids = + parse.(Int, split(ENV["REACTANT_VISIBLE_GPU_DEVICES"], ",")) + @debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_gpu_device_ids + end + + @ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid + @ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid + return nothing +end + +function initialize_default_pjrt_clients!(state::PJRTBackendState) + was_initialized = state.initialized + state.initialized = true + distributed_runtime_client = if global_state.num_processes > 1 + @assert global_state.client !== nothing + global_state.client + else + nothing + end + common_kwargs = (; + node_id=global_state.process_id, + num_nodes=global_state.num_processes, + distributed_runtime_client, + ) + + # CPU + if was_initialized && haskey(state.clients, "cpu") + XLA.free_client(state.clients["cpu"]) + XLA.PJRT.cpu_client_count[] -= 1 + end + cpu = PJRT.CPUClient(; common_kwargs..., asynchronous=true) + state.clients["cpu"] = cpu + state.default_client = cpu + + # Try TPU if possible, then try GPU (CUDA) + @static if !Sys.isapple() + if Reactant.has_tpu() + dataset_dir = @get_scratch!("libtpu") + if !isfile(dataset_dir * "/libtpu.so") + Downloads.download( + "https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20240829-py3-none-any.whl", + dataset_dir * "/tpu.zip", + ) + run(`unzip -qq $(dataset_dir*"/tpu.zip") -d $(dataset_dir)/tmp`) + run(`mv $(dataset_dir)/tmp/libtpu/libtpu.so $(dataset_dir)/libtpu.so`) + rm(dataset_dir * "/tmp"; recursive=true) + rm(dataset_dir * "/tpu.zip"; recursive=true) + end + try + if was_initialized && haskey(state.clients, "tpu") + XLA.free_client(state.clients["tpu"]) + XLA.PJRT.tpu_client_count[] -= 1 + end + tpu = PJRT.TPUClient(; + tpu_path=dataset_dir * "/libtpu.so", common_kwargs... + ) + state.clients["tpu"] = tpu + state.default_client = tpu + catch e + println(stdout, e) + end + else + if !Reactant.precompiling() + try + if was_initialized && haskey(state.clients, "gpu") + XLA.free_client(state.clients["gpu"]) + XLA.PJRT.gpu_client_count[] -= 1 + end + gpu = PJRT.GPUClient(; + common_kwargs..., allowed_devices=global_state.local_gpu_device_ids + ) + state.clients["gpu"] = gpu + state.default_client = gpu + catch e + println(stdout, e) + end + end + end + end + + return nothing +end + +end diff --git a/test/Project.toml b/test/Project.toml index 0b49f6692a..8d46640284 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,25 +1,28 @@ [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +Float8s = "81dfefd7-55b0-40c6-a251-db853704e186" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" -Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -28,29 +31,36 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -ArrayInterface = "7.10" +Adapt = "4.1" +ArrayInterface = "7.17.1" BenchmarkTools = "1.5" -CUDA = "5" +CUDA = "5.5" Distributions = "0.25" -Enzyme = "0.13.21" +Enzyme = "0.13.28" FFTW = "1.8" +Float8s = "0.1" Flux = "0.15, 0.16" Functors = "0.5" HypothesisTests = "0.11" InteractiveUtils = "1.10" +KernelAbstractions = "0.9.30" LinearAlgebra = "1.10" Lux = "1.4.1" LuxLib = "1.3" MLUtils = "0.4.4" NNlib = "0.9.26" +OffsetArrays = "1" OneHotArrays = "0.2.6" Optimisers = "0.4" PythonCall = "0.9" Random = "1.10" -Random123 = "1" +Random123 = "1.7" SafeTestsets = "0.1" SpecialFunctions = "2.4" StableRNGs = "1" Statistics = "1.10" StatsBase = "0.34" Test = "1.10" + +[extras] +Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" diff --git a/test/autodiff.jl b/test/autodiff.jl index 044799bcb1..edf32c32e5 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -13,12 +13,12 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) fwd( Forward, Duplicated, - ConcreteRArray(ones(3, 2)), - ConcreteRArray(3.1 * ones(3, 2)), + Reactant.to_rarray(ones(3, 2)), + Reactant.to_rarray(3.1 * ones(3, 2)), ) ) - @test typeof(res1) == Tuple{ConcreteRArray{Float64,2}} + @test typeof(res1) == Tuple{ConcretePJRTArray{Float64,2,1,Sharding.NoShardInfo}} @test res1[1] ≈ ores1[1] ores1 = fwd(ForwardWithPrimal, Duplicated, ones(3, 2), 3.1 * ones(3, 2)) @@ -29,12 +29,15 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) fwd( set_abi(ForwardWithPrimal, Reactant.ReactantABI), Duplicated, - ConcreteRArray(ones(3, 2)), - ConcreteRArray(3.1 * ones(3, 2)), + Reactant.to_rarray(ones(3, 2)), + Reactant.to_rarray(3.1 * ones(3, 2)), ) ) - @test typeof(res1) == Tuple{ConcreteRArray{Float64,2},ConcreteRArray{Float64,2}} + @test typeof(res1) == Tuple{ + ConcretePJRTArray{Float64,2,1,Sharding.NoShardInfo}, + ConcretePJRTArray{Float64,2,1,Sharding.NoShardInfo}, + } @test res1[1] ≈ ores1[1] @test res1[2] ≈ ores1[2] @@ -42,7 +45,12 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) @test typeof(ores1) == Tuple{} res1 = @jit( - fwd(Forward, Const, ConcreteRArray(ones(3, 2)), ConcreteRArray(3.1 * ones(3, 2))) + fwd( + Forward, + Const, + Reactant.to_rarray(ones(3, 2)), + Reactant.to_rarray(3.1 * ones(3, 2)), + ) ) @test typeof(res1) == Tuple{} @@ -54,12 +62,12 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) fwd( set_abi(ForwardWithPrimal, Reactant.ReactantABI), Const, - ConcreteRArray(ones(3, 2)), - ConcreteRArray(3.1 * ones(3, 2)), + Reactant.to_rarray(ones(3, 2)), + Reactant.to_rarray(3.1 * ones(3, 2)), ) ) - @test typeof(res1) == Tuple{ConcreteRArray{Float64,2}} + @test typeof(res1) == Tuple{ConcretePJRTArray{Float64,2,1,Sharding.NoShardInfo}} @test res1[1] ≈ ores1[1] end @@ -68,11 +76,13 @@ function gw(z) end @testset "Forward Gradient" begin - x = Reactant.ConcreteRArray(3.1 * ones(2, 2)) - res = @jit gw(x) + x = Reactant.Reactant.to_rarray(3.1 * ones(2, 2)) + res = @test_warn r"`Adapt.parent_type` is not implemented for" @jit gw(x) # TODO we should probably override https://github.com/EnzymeAD/Enzyme.jl/blob/5e6a82dd08e74666822b9d7b2b46c36b075668ca/src/Enzyme.jl#L2132 # to make sure this gets merged as a tracedrarray - @test typeof(res) == Tuple{Enzyme.TupleArray{ConcreteRNumber{Float64},(2, 2),4,2}} + @test typeof(res) == Tuple{ + Enzyme.TupleArray{ConcretePJRTNumber{Float64,1,Sharding.NoShardInfo},(2, 2),4,2} + } @test res[1] ≈ ones(2, 2) end @@ -107,16 +117,39 @@ end ret = @jit Enzyme.gradient(Reverse, cached_return, x_ra, Const(stret)) @test @allowscalar all(isone, ret[1]) - @test stret.st isa ConcreteRArray + @test stret.st isa ConcretePJRTArray @test stret.st ≈ x .+ 1 stret = StateReturn1(nothing, nothing) ret = @jit Enzyme.gradient(Reverse, cached_return, x_ra, Const(stret)) @test @allowscalar all(isone, ret[1]) - @test stret.st1 isa ConcreteRArray + @test stret.st1 isa ConcretePJRTArray @test stret.st1 ≈ x .+ 1 - @test stret.st2 isa ConcreteRArray + @test stret.st2 isa ConcretePJRTArray @test stret.st2 ≈ x .+ 1 @test stret.st1 === stret.st2 end + +@testset "Nested AD" begin + x = ConcreteRNumber(3.1) + f(x) = x * x * x * x + df(x) = Enzyme.gradient(Reverse, f, x)[1] + res1 = @jit df(x) + @test res1 ≈ 4 * 3.1^3 + ddf(x) = Enzyme.gradient(Reverse, df, x)[1] + res2 = @jit ddf(x) + @test res2 ≈ 4 * 3 * 3.1^2 +end + +@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin + a = ones(ComplexF64, 2, 2) + b = 2.0 * ones(ComplexF64, 2, 2) + a_re = Reactant.to_rarray(a) + b_re = Reactant.to_rarray(b) + df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y) + res = @jit df(a_re, b_re) # before, this segfaulted + @test res.val ≈ 4ones(2, 2) + @test res.derivs[1] ≈ 4ones(2, 2) + @test res.derivs[2] ≈ 2ones(2, 2) +end diff --git a/test/basic.jl b/test/basic.jl index edb0c0e354..3702318408 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -2,6 +2,8 @@ using Reactant using Test using Enzyme using Statistics +using Random +Random.seed!(123) fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf)) @@ -12,7 +14,7 @@ using InteractiveUtils r_res = sum(x) - a = Reactant.ConcreteRArray(x) + a = Reactant.to_rarray(x) c_res = @allowscalar sum(a) @test c_res ≈ r_res @@ -25,7 +27,7 @@ end r_res = fastmax(x) - a = Reactant.ConcreteRArray(x) + a = Reactant.to_rarray(x) c_res = @allowscalar fastmax(a) @test c_res ≈ r_res @@ -41,7 +43,7 @@ sinexpbc(x) = sinexp.(x) r_res = sinexpbc(x) - a = Reactant.ConcreteRArray(x) + a = Reactant.to_rarray(x) c_res = @allowscalar sinexpbc(a) @test c_res ≈ r_res @@ -55,7 +57,7 @@ sum_compare(x) = sum(x) > 0 @testset "Basic mapreduce" begin x = rand(Float32, 10) - a = Reactant.ConcreteRArray(x) + a = Reactant.to_rarray(x) r_res = sumexp(x) f_res = @jit sumexp(a) @@ -76,7 +78,7 @@ end x = rand(2, 10) r_res = mysoftmax!(x) - a = Reactant.ConcreteRArray(x) + a = Reactant.to_rarray(x) f_res = @jit mysoftmax!(a) @test f_res ≈ r_res @@ -86,7 +88,7 @@ bcast_cos(x) = cos.(x) @testset "Basic cos" begin x = rand(3, 2) - c = Reactant.ConcreteRArray(x) + c = Reactant.to_rarray(x) @test @jit(bcast_cos(c)) ≈ cos.(x) end @@ -101,9 +103,7 @@ f_var(args...) = sum(args) @test @jit(f_var(x, y, z)) ≈ [6.6, 6.6, 6.6] end -function sumcos(x) - return sum(cos.(x)) -end +sumcos(x) = sum(cos.(x)) function grad_ip(x) dx = Enzyme.make_zero(x) @@ -118,7 +118,7 @@ function resgrad_ip(x) end @testset "Basic grad cos" begin - c = Reactant.ConcreteRArray(ones(3, 2)) + c = Reactant.to_rarray(ones(3, 2)) @test @jit(grad_ip(c)) ≈ -sin.(ones(3, 2)) @@ -128,35 +128,32 @@ end @test r ≈ -sin.(ones(3, 2)) end -function mul(A, B) - return A * B -end @testset "matmul" begin - c = Reactant.ConcreteRArray(ones(50, 70)) - d = Reactant.ConcreteRArray(ones(70, 30)) + c = Reactant.to_rarray(ones(50, 70)) + d = Reactant.to_rarray(ones(70, 30)) - @test @jit(mul(c, d)) ≈ mul(ones(50, 70), ones(70, 30)) + @test @jit(*(c, d)) ≈ *(ones(50, 70), ones(70, 30)) end -@testset "ConcreteRArray" begin - c = Reactant.ConcreteRArray(ones(50, 70)) +@testset "similar Reactant.to_rarray" begin + c = Reactant.to_rarray(ones(50, 70)) sim_c = similar(c) @test typeof(sim_c) == typeof(c) && size(sim_c) == size(sim_c) end -@testset "Reactant.@code_hlo" begin - W = Reactant.ConcreteRArray(randn(Float32, 10, 20)) - x = Reactant.ConcreteRArray(randn(Float32, 20, 5)) - res = Reactant.@code_hlo W * x +@testset "@code_hlo" begin + W = Reactant.to_rarray(randn(Float32, 10, 20)) + x = Reactant.to_rarray(randn(Float32, 20, 5)) + res = @code_hlo W * x res_repr = sprint(show, res) @test contains(res_repr, "stablehlo.dot_general") end -@testset "Reactant.@code_hlo broadcasting" begin - x = Reactant.ConcreteRArray(randn(Float32, 2, 2)) - y = Reactant.ConcreteRArray(randn(Float32, 2, 2)) - res = Reactant.@code_hlo (.+)(x, y) +@testset "@code_hlo broadcasting" begin + x = Reactant.to_rarray(randn(Float32, 2, 2)) + y = Reactant.to_rarray(randn(Float32, 2, 2)) + res = @code_hlo (.+)(x, y) res_repr = sprint(show, res) @test contains(res_repr, "stablehlo.add") @@ -164,51 +161,19 @@ end @testset "Statistics: `mean` & `var`" begin x = randn(2, 3, 4) - x_ca = Reactant.ConcreteRArray(x) - - # XXX: @jit doesn't work with `;` - # @test @jit(mean(x_ca)) ≈ mean(x) - # @test @jit(mean(x_ca; dims=1)) ≈ mean(x; dims=1) - # @test @jit(mean(x_ca; dims=(1, 2))) ≈ mean(x; dims=(1, 2)) - # @test @jit(mean(x_ca; dims=(1, 3))) ≈ mean(x; dims=(1, 3)) - - mean_fn1(x) = mean(x) - mean_fn2(x) = mean(x; dims=1) - mean_fn3(x) = mean(x; dims=(1, 2)) - mean_fn4(x) = mean(x; dims=(1, 3)) - - mean_fn1_compiled = @compile mean_fn1(x_ca) - mean_fn2_compiled = @compile mean_fn2(x_ca) - mean_fn3_compiled = @compile mean_fn3(x_ca) - mean_fn4_compiled = @compile mean_fn4(x_ca) - - @test mean_fn1(x) ≈ mean_fn1_compiled(x_ca) - @test mean_fn2(x) ≈ mean_fn2_compiled(x_ca) - @test mean_fn3(x) ≈ mean_fn3_compiled(x_ca) - @test mean_fn4(x) ≈ mean_fn4_compiled(x_ca) - - # XXX: @jit doesn't work with `;` - # @test @jit(var(x_ca)) ≈ var(x) - # @test @jit(var(x_ca; dims=1)) ≈ var(x; dims=1) - # @test @jit(var(x_ca; dims=(1, 2), corrected=false)) ≈ - # var(x; dims=(1, 2), corrected=false) - # @test @jit(var(x_ca; dims=(1, 3), corrected=false)) ≈ - # var(x; dims=(1, 3), corrected=false) - - var_fn1(x) = var(x) - var_fn2(x) = var(x; dims=1) - var_fn3(x) = var(x; dims=(1, 2), corrected=false) - var_fn4(x) = var(x; dims=(1, 3), corrected=false) - - var_fn1_compiled = @compile var_fn1(x_ca) - var_fn2_compiled = @compile var_fn2(x_ca) - var_fn3_compiled = @compile var_fn3(x_ca) - var_fn4_compiled = @compile var_fn4(x_ca) - - @test var_fn1(x) ≈ var_fn1_compiled(x_ca) - @test var_fn2(x) ≈ var_fn2_compiled(x_ca) - @test var_fn3(x) ≈ var_fn3_compiled(x_ca) - @test var_fn4(x) ≈ var_fn4_compiled(x_ca) + x_ca = Reactant.to_rarray(x) + + @test @jit(mean(x_ca)) ≈ mean(x) + @test @jit(mean(x_ca; dims=1)) ≈ mean(x; dims=1) + @test @jit(mean(x_ca; dims=(1, 2))) ≈ mean(x; dims=(1, 2)) + @test @jit(mean(x_ca; dims=(1, 3))) ≈ mean(x; dims=(1, 3)) + + @test @jit(var(x_ca)) ≈ var(x) + @test @jit(var(x_ca, dims=1)) ≈ var(x; dims=1) + @test @jit(var(x_ca, dims=(1, 2); corrected=false)) ≈ + var(x; dims=(1, 2), corrected=false) + @test @jit(var(x_ca; dims=(1, 3), corrected=false)) ≈ + var(x; dims=(1, 3), corrected=false) end @testset "concatenation" begin @@ -362,98 +327,68 @@ end @test y == test_typed_hvncat(x) @test eltype(y) === Int end + + @testset "Number and RArray" for a in [1.0f0, 1.0e0] + typeof_a = typeof(a) + _b = typeof_a.([2.0, 3.0, 4.0]) + _c = typeof_a.([2.0 3.0 4.0]) + b = Reactant.to_rarray(_b) + c = Reactant.to_rarray(_c) + + # vcat test + y = @jit vcat(a, b) + @test y == vcat(a, _b) + @test y isa ConcreteRArray{typeof_a,1} + + ## vcat test - adjoint + y1 = @jit vcat(a, c') + @test y1 == vcat(a, _c') + @test y1 isa ConcreteRArray{typeof_a,2} + + # hcat test + z = @jit hcat(a, c) + @test z == hcat(a, _c) + @test z isa ConcreteRArray{typeof_a,2} + + ## hcat test - adjoint + z1 = @jit hcat(a, b') + @test z1 == hcat(a, _b') + @test z1 isa ConcreteRArray{typeof_a,2} + end end @testset "repeat" begin + fn_inner(x, counts) = repeat(x; inner=counts) + @testset for (size, counts) in Iterators.product( [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)], [(), (1,), (2,), (2, 1), (1, 2), (2, 2), (2, 2, 2), (1, 1, 1, 1, 1)], ) x = rand(size...) - @test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...) - end -end - -function update_on_copy(x) - y = x[1:2, 2:4, :] - y[1:1, 1:1, :] = ones(1, 1, 3) - return y -end - -@testset "view / setindex" begin - x = rand(2, 4, 3) - y = copy(x) - x_concrete = Reactant.to_rarray(x) - y_concrete = Reactant.to_rarray(y) - - y1 = update_on_copy(x) - y2 = @jit update_on_copy(x_concrete) - @test x == y - @test x_concrete == y_concrete - @test y1 == y2 - - # function update_inplace(x) - # y = view(x, 1:2, 1:2, :) - # y[1, 1, :] .= 1 - # return y - # end - - # get_indices(x) = x[1:2, 1:2, :] - # get_view(x) = view(x, 1:2, 1:2, :) - - # get_indices_compiled = @compile get_indices(x_concrete) - # get_view_compiled = @compile get_view(x_concrete) -end - -function masking(x) - y = similar(x) - y[1:2, :] .= 0 - y[3:4, :] .= 1 - return y -end - -function masking!(x) - x[1:2, :] .= 0 - x[3:4, :] .= 1 - return x -end - -@testset "setindex! with views" begin - x = rand(4, 4) .+ 2.0 - x_ra = Reactant.to_rarray(x) - - y = masking(x) - y_ra = @jit(masking(x_ra)) - @test y ≈ y_ra - x_ra_array = Array(x_ra) - @test !(any(iszero, x_ra_array[1, :])) - @test !(any(iszero, x_ra_array[2, :])) - @test !(any(isone, x_ra_array[3, :])) - @test !(any(isone, x_ra_array[4, :])) + @testset "outer repeat" begin + @test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...) + end - y_ra = @jit(masking!(x_ra)) - @test y ≈ y_ra + length(counts) < length(size) && continue - x_ra_array = Array(x_ra) - @test @allowscalar all(iszero, x_ra_array[1, :]) - @test @allowscalar all(iszero, x_ra_array[2, :]) - @test @allowscalar all(isone, x_ra_array[3, :]) - @test @allowscalar all(isone, x_ra_array[4, :]) + @testset "inner repeat" begin + @test (@jit fn_inner(Reactant.to_rarray(x), counts)) == fn_inner(x, counts) + end + end end tuple_byref(x) = (; a=(; b=x)) -tuple_byref2(x) = abs2.(x), tuple_byref2(x) +tuple_byref2(x) = abs2.(x), tuple_byref(x) @testset "Tuple byref" begin x = Reactant.to_rarray([1.0 -2.0; -3.0 4.0]) @test @jit(tuple_byref(x)).a.b.data === x.data - # TODO this seems to hang during compile - # f2 = @compile tuple_byref2(x) - # r2 = f2(x) - # @test r2[2].a.b.data === x.data - # @test r2[1] == abs2.([1.0 -2.0; -3.0 4.0]) + f2 = @compile tuple_byref2(x) + r2 = f2(x) + @test r2[2].a.b.data === x.data + @test r2[1] == abs2.([1.0 -2.0; -3.0 4.0]) end sum_xxᵀ(x) = sum(x .* x') @@ -487,21 +422,21 @@ end f1(x) = x[1] * x[2] - x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + x_ra = Reactant.to_rarray(x; track_numbers=Number) f2 = @compile f1(x_ra) - @test f2(Reactant.to_rarray((5, 5.2); track_numbers=(Number,))) ≈ 5 * 5.2 - @test f2(Reactant.to_rarray((5, 5.2); track_numbers=(Number,))) isa ConcreteRNumber + @test f2(Reactant.to_rarray((5, 5.2); track_numbers=Number)) ≈ 5 * 5.2 + @test f2(Reactant.to_rarray((5, 5.2); track_numbers=Number)) isa ConcretePJRTNumber x_ra = Reactant.to_rarray(x) f3 = @compile f1(x_ra) @test f3(Reactant.to_rarray((5, 5.2))) ≈ f1(x) - @test !(f3(Reactant.to_rarray((5, 5.2))) isa ConcreteRNumber) + @test !(f3(Reactant.to_rarray((5, 5.2))) isa ConcretePJRTNumber) @test f3(Reactant.to_rarray((5, 5.2))) isa Number - x_ra = Reactant.to_rarray(x; track_numbers=(Int,)) + x_ra = Reactant.to_rarray(x; track_numbers=Int) f4 = @compile f1(x_ra) - @test f4(Reactant.to_rarray((5, 5.2); track_numbers=(Int,))) ≈ 5 * 3.14 - @test f4(Reactant.to_rarray((5, 5.2); track_numbers=(Int,))) isa ConcreteRNumber + @test f4(Reactant.to_rarray((5, 5.2); track_numbers=Int)) ≈ 5 * 3.14 + @test f4(Reactant.to_rarray((5, 5.2); track_numbers=Int)) isa ConcretePJRTNumber end @testset "Mixed" begin @@ -509,10 +444,10 @@ end f1(x) = x[1] * x[2] - x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + x_ra = Reactant.to_rarray(x; track_numbers=Number) f2 = @compile f1(x_ra) - res2 = f2(Reactant.to_rarray((5, [3.14]); track_numbers=(Number,))) + res2 = f2(Reactant.to_rarray((5, [3.14]); track_numbers=Number)) @test @allowscalar(only(res2)) ≈ 5 * 3.14 @test res2 isa ConcreteRArray @@ -545,7 +480,7 @@ end @test Float32(x) isa Float32 @test Float64(x) isa Float64 @test Int(x) isa Int - @test float(x) isa ConcreteRNumber{Float64} + @test float(x) isa ConcretePJRTNumber{Float64} end @testset "concrete number with fill" begin @@ -577,24 +512,28 @@ end end end -@testset "dynamic indexing" begin - x = randn(5, 3) - x_ra = Reactant.to_rarray(x) - - idx = [1, 2, 3] - idx_ra = Reactant.to_rarray(idx) - - fn(x, idx) = @allowscalar x[idx, :] +@testset "$op" for op in [:round, :ceil, :floor] + for x in (rand(Float32, (3, 3)), rand(Float64)) + intop = gensym("int_$op") + @eval begin + @test @jit($op.(ConcretePJRTNumber.($x))) == $op.($x) + $intop(x) = $op(Int, x) + @test @jit($intop.(ConcretePJRTNumber.($x))) == $intop.($x) + end + end +end - y = @jit(fn(x_ra, idx_ra)) - @test y ≈ x[idx, :] +@testset "sign" begin + x = collect(Float64, 0:0.01:1) .- 0.5 + x_ra = Reactant.to_rarray(x) + @test Array(@jit(sign.(x_ra))) ≈ sign.(x) end @testset "aos_to_soa" begin using ArrayInterface x_res = collect(reshape(1.0:4.0, 2, 1, 2)) - x_ca = ConcreteRNumber.(x_res) + x_ca = ConcretePJRTNumber.(x_res) y_ca1 = @allowscalar ArrayInterface.aos_to_soa(x_ca) @test y_ca1 ≈ x_res @@ -609,7 +548,7 @@ end x = randn(2, 3) x_ra = Reactant.to_rarray(x) - @testset "ConcreteRArray" begin + @testset "Reactant.to_rarray" begin y = collect(x_ra) @test y == x @test y !== x_ra @@ -624,10 +563,9 @@ end x = 5 x_ra = ConcreteRNumber(x) - @testset "ConcreteRNumber" begin + @testset "ConcretePJRTNumber" begin y = collect(x_ra) - @test y isa ConcreteRArray{Int,0} - @test y == x + @test y isa Array{Int,0} end @testset "TracedRArray" begin @@ -637,10 +575,12 @@ end end end -function f_row_major(x) +function f_row_major(x::AbstractArray{T}) where {T} y = [1 2; 3 4; 5 6] if x isa Reactant.TracedRArray - y = Reactant.TracedUtils.promote_to(Reactant.TracedRArray{eltype(x),2}, y) + y = Reactant.TracedUtils.promote_to( + Reactant.TracedRArray{Reactant.unwrapped_eltype(T),2}, y + ) end return x .+ y end @@ -653,19 +593,27 @@ end end @testset "ifelse" begin - @test 1.0 == - @jit ifelse(ConcreteRNumber(true), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0)) + @test 1.0 == @test_warn r"`ifelse` with different element-types" @jit( + ifelse(ConcreteRNumber(true), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0)) + ) @test @jit( ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0)) - ) isa ConcreteRNumber{Float64} + ) isa ConcretePJRTNumber{Float64} @test 0.0f0 == @jit ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0)) @test @jit( ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0f0), ConcreteRNumber(0.0f0)) - ) isa ConcreteRNumber{Float32} + ) isa ConcretePJRTNumber{Float32} + + cond = ConcreteRNumber(true) + x = ConcreteRNumber(1.0) + @test @jit(ifelse(cond, x, 0.0)) == ConcreteRNumber(1.0) + @test @jit(ifelse(cond, 0.0, x)) == ConcreteRNumber(0.0) + @test @jit(ifelse(cond, 1.0, 0.0)) == ConcreteRNumber(1.0) + @test @jit(ifelse(cond, 0.0, 1.0)) == ConcreteRNumber(0.0) end -@testset "fill! and zero on ConcreteRArray" begin +@testset "fill! and zero on Reactant.to_rarray" begin x_ra = Reactant.to_rarray(rand(3, 4)) z = zero(x_ra) @@ -693,14 +641,301 @@ end @test @allowscalar T[1][1] == 2 ptr_x = Base.unsafe_convert( - Ptr{Float64}, Reactant.XLA.UnsafeBufferPointer(x.data.buffer) + Ptr{Float64}, Reactant.XLA.unsafe_buffer_pointer(x.data[1].buffer) ) ptr_res = Base.unsafe_convert( - Ptr{Float64}, Reactant.XLA.UnsafeBufferPointer(res.data.buffer) + Ptr{Float64}, Reactant.XLA.unsafe_buffer_pointer(res.data[1].buffer) ) ptr_T1 = Base.unsafe_convert( - Ptr{Float64}, Reactant.XLA.UnsafeBufferPointer(T[1].data.buffer) + Ptr{Float64}, Reactant.XLA.unsafe_buffer_pointer(T[1].data[1].buffer) ) @test ptr_x == ptr_res == ptr_T1 end + +@testset "eltype conversion inside interpreter" begin + function test_convert(x::AbstractArray{T}, eta) where {T} + eta = T(eta) + return x .* eta, eta + end + + res = @jit test_convert(Reactant.to_rarray(rand(4, 2)), ConcreteRNumber(3.0f0)) + + @test res[1] isa ConcreteRArray{Float64,2} + @test res[2] isa ConcretePJRTNumber{Float64} +end + +@testset "stack" begin + x = rand(4, 4) + y = rand(4, 4) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + s1(x) = stack((x, x)) + s2(x) = stack((x, x); dims=2) + s3(x, y) = stack((x, y); dims=2) + s4(x, y) = stack((x, y, x); dims=1) + + @test @jit(s1(x_ra)) ≈ s1(x) + @test @jit(s2(x_ra)) ≈ s2(x) + @test @jit(s3(x_ra, y_ra)) ≈ s3(x, y) + @test @jit(s4(x_ra, y_ra)) ≈ s4(x, y) + + # Test that we don't hit illegal instruction; `x` is intentionally not a traced array + @test @jit(s1(x)) isa Any + @test @jit(s2(x)) isa Any + @test @jit(s3(x, y)) isa Any + @test @jit(s4(x, y)) isa Any +end + +@testset "unstable stack" begin + x = rand(4, 4) + y = rand(4, 4) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + function s1(x) + xs = [] + push!(xs, x) + push!(xs, x) + return stack(xs) + end + function s2(x) + xs = [] + push!(xs, x) + push!(xs, x) + return stack(xs; dims=2) + end + function s3(x, y) + xs = [] + push!(xs, x) + push!(xs, y) + return stack(xs; dims=2) + end + function s4(x, y) + xs = [] + push!(xs, x) + push!(xs, y) + push!(xs, x) + return stack(xs; dims=2) + end + + @test @jit(s1(x_ra)) ≈ s1(x) + @test @jit(s2(x_ra)) ≈ s2(x) + @test @jit(s3(x_ra, y_ra)) ≈ s3(x, y) + @test @jit(s4(x_ra, y_ra)) ≈ s4(x, y) + + # Test that we don't hit illegal instruction; `x` is intentionally not a traced array + @test @jit(s1(x)) isa Any + @test @jit(s2(x)) isa Any + @test @jit(s3(x, y)) isa Any + @test @jit(s4(x, y)) isa Any +end + +@testset "duplicate args (#226)" begin + first_arg(x, y) = x + x_ra = Reactant.to_rarray(rand(2, 2)) + res = @jit first_arg(x_ra, x_ra) + @test res ≈ x_ra +end + +@testset "Common Trig Functions" begin + x = rand(Float32, 4, 16) + x_ra = Reactant.to_rarray(x) + + @testset for fn in (sinpi, cospi, tanpi, sin, cos, tan) + @test @jit(fn.(x_ra)) ≈ fn.(x) + @test @jit(fn.(x_ra)) isa ConcreteRArray{Float32,2} + end + + x = 0.235f0 + x_ra = Reactant.to_rarray(x; track_numbers=Number) + + @testset for fn in (sinpi, cospi, tanpi, sin, cos, tan) + @test @jit(fn.(x_ra)) ≈ fn.(x) + @test @jit(fn.(x_ra)) isa ConcretePJRTNumber{Float32} + end + @testset for fn in (sincospi, sincos) + res = @jit fn(x_ra) + @test res[1] ≈ fn(x)[1] + @test res[2] ≈ fn(x)[2] + @test res[1] isa ConcretePJRTNumber{Float32} + @test res[2] isa ConcretePJRTNumber{Float32} + end +end + +@testset "isfinite" begin + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) + @test @jit(isfinite.(x)) == [true, false, false, false, false] + + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) + @test @jit(isfinite.(x)) == [true, false, false, false, false] +end + +@testset "isnan" begin + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) + @test @jit(isnan.(x)) == [false, true, false, false, true] + + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) + @test @jit(isnan.(x)) == [false, true, false, false, true] +end + +@testset "isnan/isfinite" begin + @test isnan(Reactant.to_rarray(NaN; track_numbers=Number)) + @test !isnan(Reactant.to_rarray(0.0; track_numbers=Number)) + @test isfinite(Reactant.to_rarray(0.0; track_numbers=Number)) + @test !isfinite(Reactant.to_rarray(Inf; track_numbers=Number)) +end + +@testset "isinf" begin + @test Bool(@jit(isinf(ConcreteRNumber(Inf)))) + @test Bool(@jit(isinf(ConcreteRNumber(-Inf)))) + @test !Bool(@jit(isinf(ConcreteRNumber(2)))) + @test !Bool(@jit(isinf(ConcreteRNumber(2.0)))) + @test !Bool(@jit(isinf(ConcreteRNumber(true)))) +end + +@testset "mod and rem" begin + a = [-1.1, 7.7, -3.3, 9.9, -5.5] + b = [6.6, -2.2, -8.8, 4.4, -10.1] + + expected_mod = mod.(a, b) + @test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_mod + @test @jit(mod.(a, Reactant.to_rarray(b))) ≈ expected_mod + @test @jit(mod.(Reactant.to_rarray(a), b)) ≈ expected_mod + + expected_rem = rem.(a, b) + @test @jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_rem + @test @jit(rem.(a, Reactant.to_rarray(b))) ≈ expected_rem + @test @jit(rem.(Reactant.to_rarray(a), b)) ≈ expected_rem +end + +@testset "xor" begin + for a in (true, false), b in (true, false) + @test @jit(xor(ConcreteRNumber(a), ConcreteRNumber(b))) == xor(a, b) + end +end + +@testset "signbit" begin + for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0) + @test @jit(signbit(ConcreteRNumber(x))) == signbit(x) + end +end + +@testset "copysign" begin + for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14) + # Make sure also the return type is correct + @test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))) === + copysign(a, b) + end +end + +@testset "reduce integers" begin + x = rand(Bool, 100) + x_ra = Reactant.to_rarray(x) + + @test @jit(sum(x_ra)) == sum(x) + + x = rand(Int16, 100) + x_ra = Reactant.to_rarray(x) + + @test @jit(sum(x_ra)) == sum(x) +end + +@testset "/ on integers" begin + @test @jit(/(ConcreteRNumber(2), ConcreteRNumber(4))) ≈ 0.5 + @test @jit(/(ConcreteRNumber(2), 4)) ≈ 0.5 + @test @jit(/(2, ConcreteRNumber(4))) ≈ 0.5 + @test @jit(/(2, ConcreteRNumber(Int32(4)))) ≈ 0.5 +end + +@testset "Broadcasting with Range" begin + x = Reactant.to_rarray(rand(10)) + fn(x) = x .+ (1:length(x)) + + @test @jit(fn(x)) ≈ fn(Array(x)) +end + +function fntest1(x) + y = similar(x, 1, 1, 8) + sum!(y, x) + return y +end + +function fntest2(x) + y = similar(x, 2, 1, 8) + sum!(y, x) + return y +end + +function fntest3(x) + y = similar(x, 2, 1, 1) + sum!(abs2, y, x) + return y +end + +@testset "mapreducedim!" begin + x = reshape(collect(Float32, 1:64), 2, 4, 8) ./ 64 + x_ra = Reactant.to_rarray(x) + + @test Array(@jit(fntest1(x_ra))) ≈ fntest1(x) + @test Array(@jit(fntest2(x_ra))) ≈ fntest2(x) + @test Array(@jit(fntest3(x_ra))) ≈ fntest3(x) +end + +@testset "don't expand ranges by default" begin + fn(x) = Reactant.TracedUtils.broadcast_to_size(x, (length(x),)) + + hlo = repr(@code_hlo(fn(1:10000))) + @test contains(hlo, "stablehlo.iota") + @test contains(hlo, "stablehlo.add") + @test Array(@jit(fn(1:10000))) ≈ collect(1:10000) + + hlo = repr(@code_hlo(fn(32:10000))) + @test contains(hlo, "stablehlo.iota") + @test contains(hlo, "stablehlo.add") + @test Array(@jit(fn(32:10000))) ≈ collect(32:10000) + + hlo = repr(@code_hlo(fn(0:10000))) + @test contains(hlo, "stablehlo.iota") + @test !contains(hlo, "stablehlo.add") + @test Array(@jit(fn(0:10000))) ≈ collect(0:10000) + + hlo = repr(@code_hlo(fn(Base.OneTo(10000)))) + @test contains(hlo, "stablehlo.iota") + @test contains(hlo, "stablehlo.add") + @test Array(@jit(fn(Base.OneTo(10000)))) ≈ collect(Base.OneTo(10000)) +end + +function dip!(x) + x[:a] = x[:a] .* x[:b] + return nothing +end + +@testset "Dict" begin + x = Dict{Symbol,Vector{Float32}}() + x[:a] = 2.7 * ones(4) + x[:b] = 3.1 * ones(4) + + ra = Reactant.to_rarray(x) + @jit dip!(ra) + ra[:a] ≈ (2.7 * 2) * ones(4) +end + +@testset "@code_xla" begin + x_ra = Reactant.to_rarray(ones(4)) + hlo = repr(@code_xla(sin.(x_ra))) + @test contains(hlo, "HloModule") + @test contains(hlo, "sine") +end + +@testset "Raise keyword" begin + v = randn(Float32, 16) + rv = Reactant.to_rarray(v) + @test sin.(v) ≈ @jit raise = true sin.(rv) + @test cos.(v) ≈ @jit raise = false cos.(rv) + @test exp.(v) ≈ @jit raise = "canonicalize" exp.(rv) + @test_throws Reactant.MLIR.IR.AddPipelineException @jit raise = "this_pass-does_not_ExisT" exp.( + rv + ) +end diff --git a/test/bcast.jl b/test/bcast.jl index fd3dad23f7..26262672ff 100644 --- a/test/bcast.jl +++ b/test/bcast.jl @@ -13,10 +13,7 @@ mutable struct Data v::Reactant.TracedRArray{Float64,1} end @noinline function tmp(a, b, d) - @show d - @show typeof(d) c = d.v - @show typeof(c) return reshape(a, (4,)) ./ sqrt.(b .+ a) end @@ -52,17 +49,17 @@ function test() end end - return println(string(mod)) + return string(mod) end end -test() +@test test() == "module {\n}" -@testset "ConcreteRArray broadcasting" begin +@testset "ConcretePJRTArray broadcasting" begin x = ones(10, 10) y = ones(10, 10) - x_ca = Reactant.ConcreteRArray(x) - y_ca = Reactant.ConcreteRArray(y) + x_ca = Reactant.to_rarray(x) + y_ca = Reactant.to_rarray(y) @testset "Broadcasting" begin @test x .+ y ≈ @jit x_ca .+ y_ca diff --git a/test/buffer_donation.jl b/test/buffer_donation.jl index 8d15de9235..ec363c3519 100644 --- a/test/buffer_donation.jl +++ b/test/buffer_donation.jl @@ -13,25 +13,21 @@ function donate_inplace_mul(x, y) end @testset "buffer_donation" begin - a = Reactant.ConcreteRArray(ones(2, 2)) - b = Reactant.ConcreteRArray(3 * ones(2, 2)) + a = Reactant.to_rarray(ones(2, 2)) + b = Reactant.to_rarray(3 * ones(2, 2)) @jit(donate_fill_x_with_2(a, b)) @test convert(Array, a) == 2 * ones(2, 2) - _, _, _, preserved_args, _, _, _ = Reactant.Compiler.compile_xla( - donate_fill_x_with_2, (a, b) - ) + (; preserved_args) = Reactant.Compiler.compile_xla(donate_fill_x_with_2, (a, b))[3] preserved_args_idx = last.(preserved_args) @test preserved_args_idx == [1] # only `y`(i.e. `b`) is preserved - a = Reactant.ConcreteRArray(2 * ones(2, 2)) - b = Reactant.ConcreteRArray(3 * ones(2, 2)) + a = Reactant.to_rarray(2 * ones(2, 2)) + b = Reactant.to_rarray(3 * ones(2, 2)) @jit(donate_inplace_mul(a, b)) @test convert(Array, a) == 6 * ones(2, 2) - _, _, _, preserved_args, _, _, _ = Reactant.Compiler.compile_xla( - donate_inplace_mul, (a, b) - ) + (; preserved_args) = Reactant.Compiler.compile_xla(donate_inplace_mul, (a, b))[3] preserved_args_idx = last.(preserved_args) @test preserved_args_idx == [1] # only `y`(i.e. `b`) is preserved end diff --git a/test/closure.jl b/test/closure.jl index d6eb350082..65930f58a4 100644 --- a/test/closure.jl +++ b/test/closure.jl @@ -6,8 +6,8 @@ muler(x) = y -> x * y @testset "closure" begin x = ones(2, 2) y = ones(2, 2) - x_ra = Reactant.ConcreteRArray(x) - y_ra = Reactant.ConcreteRArray(y) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) f = muler(x_ra) @test @jit(f(y_ra)) ≈ x * y diff --git a/test/compile.jl b/test/compile.jl index 18821e430a..9b47865baf 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -10,7 +10,7 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a= x2 = Reactant.to_rarray(x) res = @jit sum(x2) - @test res isa @NamedTuple{a::Reactant.ConcreteRNumber{Float64}} + @test res isa @NamedTuple{a::ConcreteRNumber{Float64,1,Sharding.NoShardInfo}} @test isapprox(res.a, sum(x.a)) end @@ -30,8 +30,8 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a= @testset "world-age" begin a = ones(2, 10) b = ones(10, 2) - a_ra = Reactant.ConcreteRArray(a) - b_ra = Reactant.ConcreteRArray(b) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) fworld(x, y) = @jit(x * y) @@ -40,7 +40,7 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a= @testset "type casting & optimized out returns" begin a = ones(2, 10) - a_ra = Reactant.ConcreteRArray(a) + a_ra = Reactant.to_rarray(a) ftype1(x) = Float64.(x) ftype2(x) = Float32.(x) @@ -48,8 +48,8 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a= y1 = @jit ftype1(a_ra) y2 = @jit ftype2(a_ra) - @test y1 isa Reactant.ConcreteRArray{Float64,2} - @test y2 isa Reactant.ConcreteRArray{Float32,2} + @test y1 isa Reactant.ConcretePJRTArray{Float64,2} + @test y2 isa Reactant.ConcretePJRTArray{Float32,2} @test y1 ≈ Float64.(a) @test y2 ≈ Float32.(a) @@ -84,13 +84,13 @@ end hlo_code = @code_hlo f(x_ra) @test !startswith(string(hlo_code), "Module") - @test startswith(string(hlo_code), "module {") + @test startswith(string(hlo_code), "module") end @testset "Bool attributes" begin - x_ra = Reactant.to_rarray(false; track_numbers=(Number,)) + x_ra = Reactant.to_rarray(false; track_numbers=Number) @test @jit(iszero(x_ra)) == true - x_ra = Reactant.to_rarray(true; track_numbers=(Number,)) + x_ra = Reactant.to_rarray(true; track_numbers=Number) @test @jit(iszero(x_ra)) == false end @@ -101,3 +101,54 @@ end @test @allowscalar(x_ra[1]) ≈ x[1] @test @allowscalar(x_ra[1:1]) ≈ x[1:1] end + +@testset "no_nan passes" begin + x_ra = Reactant.to_rarray(rand(Float32, 4, 16)) + y_ra = Reactant.to_rarray(rand(Float32, 4, 16)) + + fn(x) = x .- x + + hlo = @code_hlo fn(x_ra) + @test occursin("subtract", repr(hlo)) + @test !occursin("constant", repr(hlo)) + hlo = @code_hlo no_nan = true fn(x_ra) + @test !occursin("subtract", repr(hlo)) + @test occursin("constant", repr(hlo)) + + fn(x, y) = begin + c = x .+ y + return c .- y + end + + hlo = @code_hlo fn(x_ra, y_ra) + @test occursin("subtract", repr(hlo)) + @test occursin("add", repr(hlo)) + hlo = @code_hlo no_nan = true fn(x_ra, y_ra) + @test !occursin("subtract", repr(hlo)) + @test !occursin("add", repr(hlo)) +end + +# While a bit specific, the following is used to check for a bug in `should_rewrite_call` +function sinusoidal_embedding( + x::AbstractArray{T,4}, min_freq, max_freq, embedding_dims::Int +) where {T} + if size(x)[1:3] != (1, 1, 1) + throw(DimensionMismatch("Input shape must be (1, 1, 1, batch)")) + end + + lower, upper = log(T(min_freq)), log(T(max_freq)) + n = embedding_dims ÷ 2 + x_ = 2 .* x .* exp.(reshape(range(lower, upper; length=n), 1, 1, n, 1)) + return cat(sinpi.(x_), cospi.(x_); dims=Val(3)) +end + +@testset "sinusoidal_embedding" begin + x_ra = Reactant.to_rarray(rand(Float32, 1, 1, 1, 4)) + hlo = @code_hlo sinusoidal_embedding(x_ra, 0.1, 10.0, 4) +end + +# test #493 +@testset "unique(::Vector{Symbol}) (#493)" begin + x = [:a, :b, :a] + @test @jit(unique(x)) == [:a, :b] +end diff --git a/test/complex.jl b/test/complex.jl index 43e3c4f6b3..728cf53ad2 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -89,7 +89,7 @@ end @testset "promote_to Complex" begin x = 1.0 + 2.0im - y = Reactant.ConcreteRNumber(x) + y = ConcreteRNumber(x) f = Reactant.compile((y,)) do z z + Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{ComplexF64}, 1.0 - 3.0im) @@ -100,6 +100,6 @@ end @testset "complex reduction" begin x = randn(ComplexF32, 10, 10) - x_ra = Reactant.ConcreteRArray(x) + x_ra = Reactant.to_rarray(x) @test @jit(sum(abs2, x_ra)) ≈ sum(abs2, x) end diff --git a/test/control_flow.jl b/test/control_flow.jl index 9b4ee9fcf8..63f51f126e 100644 --- a/test/control_flow.jl +++ b/test/control_flow.jl @@ -1,5 +1,6 @@ using Reactant, Test using LinearAlgebra +using Reactant.ReactantCore function condition1(x) y = sum(x) @@ -219,14 +220,14 @@ end @testset "condition6: bareif relu" begin x = 2.0 - x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + x_ra = Reactant.to_rarray(x; track_numbers=Number) res_ra = @jit(condition6_bareif_relu(x_ra)) res = condition6_bareif_relu(x) @test res_ra ≈ res x = -2.0 - x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + x_ra = Reactant.to_rarray(x; track_numbers=Number) res_ra = @jit(condition6_bareif_relu(x_ra)) res = condition6_bareif_relu(x) @@ -246,21 +247,21 @@ end @testset "condition7: bare elseif" begin x = 2.0 - x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + x_ra = Reactant.to_rarray(x; track_numbers=Number) res_ra = @jit(condition7_bare_elseif(x_ra)) res = condition7_bare_elseif(x) @test res_ra ≈ res x = -2.0 - x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + x_ra = Reactant.to_rarray(x; track_numbers=Number) res_ra = @jit(condition7_bare_elseif(x_ra)) res = condition7_bare_elseif(x) @test res_ra ≈ res x = 0.0 - x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + x_ra = Reactant.to_rarray(x; track_numbers=Number) res_ra = @jit(condition7_bare_elseif(x_ra)) res = condition7_bare_elseif(x) @@ -355,7 +356,7 @@ function condition10_condition_with_setindex(x) @trace if sum(x) > 0 x[:, 1] = -1.0 else - x[1, 1] = 1.0 + @allowscalar x[1, 1] = 1.0 end return x end @@ -367,8 +368,8 @@ end res_ra = @jit(condition10_condition_with_setindex(x_ra)) @test @allowscalar(res_ra[1, 1]) == -1.0 @test @allowscalar(res_ra[2, 1]) == -1.0 - @test @allowscalar(x_ra[1, 1]) == -1.0 broken = true - @test @allowscalar(x_ra[2, 1]) == -1.0 broken = true + @test @allowscalar(x_ra[1, 1]) == -1.0 + @test @allowscalar(x_ra[2, 1]) == -1.0 x = -rand(2, 10) x[2, 1] = 0.0 @@ -377,7 +378,7 @@ end res_ra = @jit(condition10_condition_with_setindex(x_ra)) @test @allowscalar(res_ra[1, 1]) == 1.0 @test @allowscalar(res_ra[2, 1]) == 0.0 - @test @allowscalar(x_ra[1, 1]) == 1.0 broken = true + @test @allowscalar(x_ra[1, 1]) == 1.0 @test @allowscalar(x_ra[2, 1]) == 0.0 end @@ -455,9 +456,39 @@ end condition12_compile_test(x, y, z) end +function condition_with_structure(x) + y = x .+ 1 + @trace if sum(y) > 0 + z = (; a=y, b=(y .- 1, y)) + else + z = (; a=-y, b=(y, y .+ 1)) + end + return z +end + +@testset "condition with structure" begin + x = rand(2, 10) + x_ra = Reactant.to_rarray(x) + + res_ra = @jit condition_with_structure(x_ra) + res = condition_with_structure(x) + @test res_ra.a ≈ res.a + @test res_ra.b[1] ≈ res.b[1] + @test res_ra.b[2] ≈ res.b[2] + + x = -rand(2, 10) + x_ra = Reactant.to_rarray(x) + + res_ra = @jit condition_with_structure(x_ra) + res = condition_with_structure(x) + @test res_ra.a ≈ res.a + @test res_ra.b[1] ≈ res.b[1] + @test res_ra.b[2] ≈ res.b[2] +end + function for_with_step(x) @trace for i in 10:3:22 - x[i] = i * i + @allowscalar x[i] = i * i end return x end @@ -539,7 +570,7 @@ function cumsum!(x) v = zero(eltype(x)) @trace for i in 1:length(x) v += @allowscalar x[i] - x[i] = v + @allowscalar x[i] = v end return x end @@ -565,3 +596,180 @@ end @test @jit(for_ref_outer(x_ra)) ≈ for_ref_outer(x) end + +function for_inner_scope(x) + @trace for i in 1:10 + s = sum(x) + x = x / s + end + return x +end + +@testset "for: inner scope" begin + x = randn(Float64, 10) + x_ra = Reactant.to_rarray(x) + + @test @jit(for_inner_scope(x_ra)) ≈ for_inner_scope(x) +end + +function for_with_named_tuple(x) + st = (; x) + res = x + @trace for i in 1:10 + res .= res .+ st.x + end + return res +end + +@testset "for: named tuple" begin + x = randn(Float64, 10) + x_ra = Reactant.to_rarray(x) + + @test @jit(for_with_named_tuple(x_ra)) ≈ for_with_named_tuple(x) +end + +_call1(a, b) = a +function call1(a, b) + x = @trace _call1(a, b) + y = @trace _call1(a, b) + return @trace _call1(x, y) +end + +@testset "call: basic" begin + a = rand(2, 3) + b = rand(2, 3) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit(call1(a_ra, b_ra)) ≈ call1(a, b) + + # check whether the func for _call1 was only generated once: + ir = @code_hlo optimize = false call1(a_ra, b_ra) + ops = [op for op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))] + @test length(ops) == 2 # call1, _call1 + + # With different operand sizes, different functions need to be generated: + c = rand(4, 5) + c_ra = Reactant.to_rarray(c) + + @test @jit(call1(a_ra, c_ra)) ≈ call1(a, c) + ir = @code_hlo optimize = false call1(a_ra, c_ra) + ops = [op for op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))] + @test length(ops) == 3 +end + +_call2(a) = a + a +function call2(a) + return @trace _call2(a) +end + +@testset "call: rnumber" begin + a = 10 + a_rn = ConcreteRNumber(a) + + @test @jit(call2(a_rn)) == call2(a) +end + +function _call3(x::Int, y) + if x > 10 + return y .+ y + else + return y .* y + end +end + +function call3(y) + z = @trace _call3(1, y) + @trace _call3(1, z) # doesn't generate new function because y.shape == z.shape + @trace _call3(11, y) # new function because x changed. +end + +@testset "call: caching for Julia operands" begin + y = rand(3) + y_ra = Reactant.to_rarray(y) + + ir = @code_hlo optimize = false call3(y_ra) + ops = [op for op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))] + @test length(ops) == 5 # call3, .+, .*, _call3 (2X) +end + +struct Foo + x +end +struct Bar + x +end + +_call4(foobar::Union{Foo,Bar}) = foobar.x +function call4(foo, foo2, bar) + @trace _call4(foo) + @trace _call4(foo2) + @trace _call4(bar) +end + +@testset "call: Caching struct arguments" begin + a = rand(10) + b = rand(10) + foo = Foo(Reactant.to_rarray(a)) + foo2 = Foo(Reactant.to_rarray(b)) + bar = Foo(Bar(Reactant.to_rarray(b))) # typeof(foo) == typeof(bar), but these don't match! + ir = @code_hlo optimize = false call4(foo, foo2, bar) + ops = [op for op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))] + @test length(ops) == 3 # call4, _call4 for {foo, foo2}, and _call4 for bar +end + +function _call5!(a, b) + @allowscalar a[1] = zero(eltype(a)) + return b +end + +function call5!(a, b) + @trace _call5!(a, b) + return a +end + +@testset "call: argument mutation" begin + a = ones(3) + b = ones(3) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + @jit call5!(a_ra, b_ra) + call5!(a, b) + @test a_ra == a +end + +mutable struct TestClock{I} + iteration::I +end + +mutable struct TestSimulation{C,I,B} + clock::C + stop_iteration::I + running::B +end + +function step!(sim) + @trace if sim.clock.iteration >= sim.stop_iteration + sim.running = false + else + sim.clock.iteration += 1 # time step + end + return sim +end + +function simulate!(sim) + return ReactantCore.traced_while(sim -> sim.running, step!, (sim,)) +end + +@testset "simulation loop" begin + clock = TestClock(ConcreteRNumber(0)) + simulation = TestSimulation(clock, ConcreteRNumber(3), ConcreteRNumber(true)) + + f! = @compile sync = true simulate!(simulation) + result = f!(simulation) + + @test result == [3, 3, false] + @test simulation.running == false + @test simulation.clock.iteration == 3 + @test simulation.stop_iteration == 3 +end diff --git a/test/cuda.jl b/test/cuda.jl deleted file mode 100644 index 549002e4f1..0000000000 --- a/test/cuda.jl +++ /dev/null @@ -1,36 +0,0 @@ -using Reactant -using Test -using CUDA - -using Reactant_jll -@show Reactant_jll.libReactantExtra_path - -function square_kernel!(x) - #i = threadIdx().x - #x[i] *= x[i] - #@cuprintf("overwrote value of %f was thrown during kernel execution on thread (%d, %d, %d) in block (%d, %d, %d).\n", - # 0.0, threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) - #x[i], threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) - - # sync_threads() - return nothing -end - -# basic squaring on GPU -function square!(x) - @cuda blocks = 1 threads = length(x) square_kernel!(x) - return nothing -end - -@testset "Square Kernel" begin - oA = collect(1:1:64) - A = Reactant.to_rarray(oA) - # @show @code_hlo optimize = false square!(A) - # @show @code_hlo optimize = :before_kernel square!(A) - # @show @code_hlo square!(A) - func! = @compile square!(A) - func!(A) - @show A - @show oA - @test all(Array(A) .≈ (oA .* oA)) -end diff --git a/test/custom_number_types.jl b/test/custom_number_types.jl new file mode 100644 index 0000000000..24d0d4dac5 --- /dev/null +++ b/test/custom_number_types.jl @@ -0,0 +1,38 @@ +using Float8s, Reactant +using Reactant: TracedRNumber + +Reactant.reactant_primitive(::Type{Float8_4}) = Reactant.F8E4M3FN + +x = Float8_4[ + -1.125 -0.21875 1.12 + 1.875 0.4375 1.0 + 0.5625 -1.0 0.937 + -0.375 -0.34375 -0.6875 + 0.46875 0.75 -0.23437 + -0.6875 -0.203125 0.375 + 0.875 -0.8125 2.5 + -0.6875 -0.1171875 -1.625 + 0.75 0.9375 1.0 + 0.5 0.203125 1.75 +] +x_64 = Float64.(x) +x_ra = Reactant.to_rarray(x) + +@testset "Reductions" begin + sumall(x) = TracedRNumber{Float64}(sum(x)) + + @test @jit(sumall(x_ra)) ≈ sum(x_64) atol = 1e-1 rtol = 1e-1 + + sum1(x) = TracedRNumber{Float64}.(sum(x; dims=1)) + sum2(x) = TracedRNumber{Float64}.(sum(x; dims=2)) + sum12(x) = TracedRNumber{Float64}.(sum(x; dims=(1, 2))) + + @test @jit(sum1(x_ra)) ≈ sum(x_64; dims=1) atol = 1e-1 rtol = 1e-1 + @test @jit(sum2(x_ra)) ≈ sum(x_64; dims=2) atol = 1e-1 rtol = 1e-1 + @test @jit(sum12(x_ra)) ≈ sum(x_64; dims=(1, 2)) atol = 1e-1 rtol = 1e-1 +end + +@testset "Broadcasting" begin + fn(x) = TracedRNumber{Float64}.(x .+ 1) + @test @jit(fn(x_ra)) ≈ (x_64 .+ 1) atol = 1e-1 rtol = 1e-1 +end diff --git a/test/ifrt/low_level.jl b/test/ifrt/low_level.jl new file mode 100644 index 0000000000..350754681a --- /dev/null +++ b/test/ifrt/low_level.jl @@ -0,0 +1,56 @@ +# Testing manual IFRT buffer creation + compilation + execution +using Reactant, Test +using Reactant: XLA +using Reactant.XLA: IFRT + +fn_test1(x, y) = x .+ y +fn_test2(x, y) = x .* y +fn_test3(x, y) = x .+ y' .- x + +@testset "IFRT Low-level API" begin + x = reshape(collect(Float32, 1:64), 8, 8) + y = collect((x .+ 64)') + + pjrt_client = Reactant.XLA.default_backend() + platform_name = lowercase(XLA.platform_name(pjrt_client)) + + ifrt_client = if platform_name == "cpu" + IFRT.CPUClient(; checkcount=false) + elseif platform_name == "gpu" || platform_name == "cuda" + IFRT.GPUClient(; checkcount=false) + elseif platform_name == "tpu" + IFRT.TPUClient(; checkcount=false) + else + error("Unsupported platform: $(platform_name)") + end + + pjrt_x = ConcreteRArray(x) # XXX: Rename to ConcretePJRTArray + pjrt_y = ConcreteRArray(y) # XXX: Rename to ConcretePJRTArray + + ifrt_x = IFRT.Array(ifrt_client, x) # XXX: Use ConcreteIFRTArray once ready + ifrt_y = IFRT.Array(ifrt_client, y) # XXX: Use ConcreteIFRTArray once ready + + @testset for fn in (fn_test1, fn_test2, fn_test3) + pjrt_result = @jit fn(pjrt_x, pjrt_y) + + mlir_mod, mlir_fn_res = Reactant.Compiler.compile_mlir(fn, (pjrt_x, pjrt_y)) + + ifrt_loaded_executable = XLA.compile( + ifrt_client, + XLA.default_device(ifrt_client), + mlir_mod; + num_outputs=length(mlir_fn_res.linear_results), + num_parameters=length(mlir_fn_res.linear_args), + mlir_fn_res.is_sharded, + global_device_ids=Int64[], + num_replicas=1, + num_partitions=1, + ) + + ifrt_result = XLA.execute( + ifrt_loaded_executable, (ifrt_x.buffer, ifrt_y.buffer), UInt8.((0, 0)), Val(1) + ) + + @test convert(Array, only(ifrt_result)) ≈ Array(pjrt_result) + end +end diff --git a/test/indexing.jl b/test/indexing.jl new file mode 100644 index 0000000000..5d6aec3412 --- /dev/null +++ b/test/indexing.jl @@ -0,0 +1,339 @@ +using LinearAlgebra, Reactant, Test + +function update_on_copy(x) + y = x[1:2, 2:4, :] + y[1:1, 1:1, :] = ones(1, 1, 3) + return y +end + +@testset "view / setindex" begin + x = rand(2, 4, 3) + y = copy(x) + x_concrete = Reactant.to_rarray(x) + y_concrete = Reactant.to_rarray(y) + + y1 = update_on_copy(x) + y2 = @jit update_on_copy(x_concrete) + @test x == y + @test x_concrete == y_concrete + @test y1 == y2 + + # function update_inplace(x) + # y = view(x, 1:2, 1:2, :) + # y[1, 1, :] .= 1 + # return y + # end + + # get_indices(x) = x[1:2, 1:2, :] + # get_view(x) = view(x, 1:2, 1:2, :) + + # get_indices_compiled = @compile get_indices(x_concrete) + # get_view_compiled = @compile get_view(x_concrete) +end + +function maskset!(y, x) + y[:] = x + return nothing +end + +@testset "setindex! with vectors & colon indexing" begin + x = Reactant.to_rarray([4.0]) + y = Reactant.to_rarray([2.0]) + @jit(maskset!(y, x)) + @test y ≈ x + + x = Reactant.to_rarray(ones(3)) + y = Reactant.to_rarray(2 * ones(3)) + @jit(maskset!(y, x)) + @test y ≈ x +end + +function masking(x) + y = similar(x) + y[1:2, :] .= 0 + y[3:4, :] .= 1 + return y +end + +function masking!(x) + x[1:2, :] .= 0 + x[3:4, :] .= 1 + return x +end + +@testset "setindex! with views" begin + x = rand(4, 4) .+ 2.0 + x_ra = Reactant.to_rarray(x) + + y = masking(x) + y_ra = @jit(masking(x_ra)) + @test y ≈ y_ra + + x_ra_array = Array(x_ra) + @test !(any(iszero, x_ra_array[1, :])) + @test !(any(iszero, x_ra_array[2, :])) + @test !(any(isone, x_ra_array[3, :])) + @test !(any(isone, x_ra_array[4, :])) + + y_ra = @jit(masking!(x_ra)) + @test y ≈ y_ra + + x_ra_array = Array(x_ra) + @test @allowscalar all(iszero, x_ra_array[1, :]) + @test @allowscalar all(iszero, x_ra_array[2, :]) + @test @allowscalar all(isone, x_ra_array[3, :]) + @test @allowscalar all(isone, x_ra_array[4, :]) +end + +function non_contiguous_setindex!(x) + x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0 + return x +end + +@testset "non-contiguous setindex!" begin + x = rand(6, 6) + x_ra = Reactant.to_rarray(x) + + y = @jit(non_contiguous_setindex!(x_ra)) + y = Array(y) + x_ra = Array(x_ra) + @test all(isone, y[1:3, 1:4]) + @test all(isone, x_ra[1:3, 1:4]) + @test !all(isone, y[4:end, :]) + @test !all(isone, x_ra[4:end, :]) + @test !all(isone, y[:, 5:end]) + @test !all(isone, x_ra[:, 5:end]) +end + +@testset "dynamic indexing" begin + x = randn(5, 3) + x_ra = Reactant.to_rarray(x) + + idx = [1, 2, 3] + idx_ra = Reactant.to_rarray(idx) + + fn(x, idx) = @allowscalar x[idx, :] + + y = @jit(fn(x_ra, idx_ra)) + @test y ≈ x[idx, :] +end + +@testset "non-contiguous indexing" begin + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing1(x) = x[[1, 3, 2], :, :] + non_contiguous_indexing2(x) = x[:, [1, 2, 1, 3], [1, 3]] + + @test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x) + @test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x) + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing3(x) = x[[1, 3, 2], :] + non_contiguous_indexing4(x) = x[:, [1, 2, 2]] + + @test @jit(non_contiguous_indexing3(x_ra)) ≈ non_contiguous_indexing3(x) + @test @jit(non_contiguous_indexing4(x_ra)) ≈ non_contiguous_indexing4(x) + + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing1!(x) = x[[1, 3, 2], :, :] .= 2 + non_contiguous_indexing2!(x) = x[:, [1, 2, 1, 3], [1, 3]] .= 2 + + @jit(non_contiguous_indexing1!(x_ra)) + non_contiguous_indexing1!(x) + @test x_ra ≈ x + + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + @jit(non_contiguous_indexing2!(x_ra)) + non_contiguous_indexing2!(x) + @test x_ra ≈ x + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing3!(x) = x[[1, 3, 2], :] .= 2 + non_contiguous_indexing4!(x) = x[:, [1, 2, 2]] .= 2 + + @jit(non_contiguous_indexing3!(x_ra)) + non_contiguous_indexing3!(x) + @test x_ra ≈ x + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + @jit(non_contiguous_indexing4!(x_ra)) + non_contiguous_indexing4!(x) + @test x_ra ≈ x +end + +@testset "indexing with traced arrays" begin + x = rand(4, 4, 3) + idx1 = [1, 3, 2] + idx3 = [1, 2, 1, 3] + + x_ra = Reactant.to_rarray(x) + idx1_ra = Reactant.to_rarray(idx1) + idx3_ra = Reactant.to_rarray(idx3) + + getindex1(x, idx1) = x[idx1, :, :] + getindex2(x, idx1) = x[:, idx1, :] + getindex3(x, idx3) = x[:, :, idx3] + getindex4(x, idx1, idx3) = x[idx1, :, idx3] + + @test @jit(getindex1(x_ra, idx1_ra)) ≈ getindex1(x, idx1) + @test @jit(getindex2(x_ra, idx1_ra)) ≈ getindex2(x, idx1) + @test @jit(getindex3(x_ra, idx3_ra)) ≈ getindex3(x, idx3) + @test @jit(getindex4(x_ra, idx1_ra, idx3_ra)) ≈ getindex4(x, idx1, idx3) +end + +@testset "linear indexing" begin + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + getindex_linear_scalar(x, idx) = @allowscalar x[idx] + + @testset for i in 1:length(x) + @test @jit(getindex_linear_scalar(x_ra, i)) ≈ getindex_linear_scalar(x, i) + @test @jit( + getindex_linear_scalar(x_ra, Reactant.to_rarray(i; track_numbers=Number)) + ) ≈ getindex_linear_scalar(x, i) + end + + idx = rand(1:length(x), 8) + idx_ra = Reactant.to_rarray(idx) + + getindex_linear_vector(x, idx) = x[idx] + + @test @jit(getindex_linear_vector(x_ra, idx_ra)) ≈ getindex_linear_vector(x, idx) + @test @jit(getindex_linear_vector(x_ra, idx)) ≈ getindex_linear_vector(x, idx) +end + +@testset "Boolean Indexing" begin + x_ra = Reactant.to_rarray(rand(Float32, 4, 16)) + idxs_ra = Reactant.to_rarray(rand(Bool, 16)) + + fn(x, idxs) = x[:, idxs] + + @test_throws ErrorException @jit(fn(x_ra, idxs_ra)) + + res = @jit fn(x_ra, Array(idxs_ra)) + @test res ≈ fn(Array(x_ra), Array(idxs_ra)) +end + +@testset "inconsistent indexing" begin + x_ra = Reactant.to_rarray(rand(3, 4, 3)) + idx_ra = Reactant.to_rarray(1; track_numbers=Number) + + fn1(x) = x[:, :, 1] + fn2(x, idx) = x[:, :, idx] + fn3(x, idx) = x[idx, :, 1] + + @test ndims(@jit(fn1(x_ra))) == 2 + @test ndims(@jit(fn2(x_ra, idx_ra))) == 2 + @test ndims(@jit(fn3(x_ra, idx_ra))) == 1 +end + +@testset "High-Dimensional Array Indexing" begin + x_ra = Reactant.to_rarray(rand(5, 4, 3)) + idx1_ra = Reactant.to_rarray(rand(1:5, 2, 2, 3)) + idx2_ra = Reactant.to_rarray(rand(1:4, 2, 2, 3)) + idx3 = rand(1:3, 2, 2, 3) + + fn(x, idx1, idx2, idx3) = x[idx1, idx2, idx3] + + @test @jit(fn(x_ra, idx1_ra, idx2_ra, idx3)) ≈ + fn(Array(x_ra), Array(idx1_ra), Array(idx2_ra), idx3) +end + +function issue_617(outf, fr, pr, I) + tmp = fr .* reshape(pr, size(fr)) + outv = @view outf[I] + vtmp = vec(tmp) + outv .= vtmp + return outf +end + +@testset "issue #617" begin + N, M = 4, 6 + + f = rand(ComplexF64, N, N) + p = rand(ComplexF64, N * N) + I = 1:(N^2) + out = rand(ComplexF64, M, M) + + fr = Reactant.to_rarray(f) + pr = Reactant.to_rarray(p) + outr = Reactant.to_rarray(out) + Ir = Reactant.to_rarray(I) + + @test @jit(issue_617(outr, fr, pr, Ir)) ≈ issue_617(out, f, p, I) +end + +function scalar_setindex(x, idx, val) + @allowscalar x[idx] = val + return x +end + +@testset "scalar setindex" begin + x = zeros(4, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(scalar_setindex(x_ra, 1, 1)) ≈ scalar_setindex(x, 1, 1) + @test @allowscalar x_ra[1] == 1 + + x = zeros(4, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(scalar_setindex(x_ra, ConcreteRNumber(1), 1)) ≈ scalar_setindex(x, 1, 1) + @test @allowscalar x_ra[1] == 1 +end + +function write_with_broadcast1!(x, y) + x[1, :, :] .= reshape(y, 4, 3) + return x +end +function write_with_broadcast2!(x, y) + x[:, 1, :] .= view(y, :, 1:3) + return x +end + +@testset "write_with_broadcast" begin + x_ra = Reactant.to_rarray(zeros(3, 4, 3)) + y_ra = Reactant.to_rarray(rand(3, 4)) + + res = @jit write_with_broadcast1!(x_ra, y_ra) + + @test res.data[1] === x_ra.data[1] + + res = Array(res) + y = Array(y_ra) + @test res[1, :, :] ≈ reshape(y, 4, 3) + + x_ra = Reactant.to_rarray(zeros(3, 4, 3)) + y_ra = Reactant.to_rarray(rand(3, 4)) + + res = @jit write_with_broadcast2!(x_ra, y_ra) + + @test res.data[1] === x_ra.data[1] + + res = Array(res) + y = Array(y_ra) + @test res[:, 1, :] ≈ view(y, :, 1:3) +end + +@testset "getindex ambiguity" begin + x = collect(Float32, 1:8) + x_ra = Reactant.to_rarray(x) + + idx = CartesianIndex(1) + + fn(x, idx) = @allowscalar x[idx] + + @test @jit(fn(x_ra, idx)) ≈ fn(x, idx) +end diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl new file mode 100644 index 0000000000..d69d40a795 --- /dev/null +++ b/test/integration/cuda.jl @@ -0,0 +1,137 @@ + +using Reactant +using Test +using CUDA + +function square_kernel!(x, y) + i = threadIdx().x + x[i] *= y[i] + # We don't yet auto lower this via polygeist + # sync_threads() + return nothing +end + +# basic squaring on GPU +function square!(x, y) + @cuda blocks = 1 threads = length(x) square_kernel!(x, y) + return nothing +end + +@testset "Square Kernel" begin + oA = collect(1:1:64) + A = Reactant.to_rarray(oA) + B = Reactant.to_rarray(100 .* oA) + @jit square!(A, B) + @test all(Array(A) .≈ (oA .* oA .* 100)) + @test all(Array(B) .≈ (oA .* 100)) +end + +function sin_kernel!(x, y) + i = threadIdx().x + x[i] *= sin(y[i]) + return nothing +end + +# basic squaring on GPU +function sin!(x, y) + @cuda blocks = 1 threads = length(x) sin_kernel!(x, y) + return nothing +end + +@testset "Sin Kernel" begin + oA = collect(Float64, 1:1:64) + A = Reactant.to_rarray(oA) + B = Reactant.to_rarray(100 .* oA) + @jit sin!(A, B) + @test all(Array(A) .≈ oA .* sin.(oA .* 100)) + @test all(Array(B) .≈ (oA .* 100)) +end + +function smul_kernel!(x, y) + i = threadIdx().x + x[i] *= y + return nothing +end + +# basic squaring on GPU +function smul!(x) + @cuda blocks = 1 threads = length(x) smul_kernel!(x, 3) + @cuda blocks = 1 threads = length(x) smul_kernel!(x, 5) + return nothing +end + +@testset "Constant Op Kernel" begin + oA = collect(1:1:64) + A = Reactant.to_rarray(oA) + @jit smul!(A) + @test all(Array(A) .≈ oA .* 15) +end + +function tuplef!(tup) + tup[1][] += 2 + return nothing +end + +function tuplef2!(tup) + tup[2][] *= tup[1] + return nothing +end + +tuplef(a) = @cuda threads = 1 tuplef!((a,)) +tuplef2(a) = @cuda threads = 1 tuplef2!((5, a)) + +@testset "Structured Kernel Arguments" begin + A = Reactant.to_rarray(fill(1)) + @jit tuplef(A) + @test all(Array(A) .≈ 3) + + A = Reactant.to_rarray(fill(1)) + @jit tuplef2(A) + @test all(Array(A) .≈ 5) +end +A = Reactant.to_rarray(fill(1)) +@jit tuplef2(A) +@test all(Array(A) .≈ 5) + +# TODO this same code fails if we use a 0-d array...? +# maybe weird cuda things +function aliased!(tup) + x, y = tup + x[1] *= y[1] + return nothing +end + +function aliased(s) + tup = (s, s) + @cuda threads = 1 aliased!(tup) + return nothing +end + +@testset "Aliasing arguments" begin + a = Reactant.to_rarray([3]) + + @jit aliased(a) + @test all(Array(a) .== 9) +end + +using Reactant, CUDA + +function cmul!(a, b) + b[1] *= a[1] + return nothing +end + +function mixed(a, b) + @cuda threads = 1 cmul!(a, b) + return nothing +end + +@testset "Non-traced argument" begin + if CUDA.functional() + a = CuArray([4]) + b = Reactant.to_rarray([3]) + @jit mixed(a, b) + @test all(Array(a) .== 4) + @test all(Array(b) .== 12) + end +end diff --git a/test/integration/fft.jl b/test/integration/fft.jl index d39ac6d256..62e5de97de 100644 --- a/test/integration/fft.jl +++ b/test/integration/fft.jl @@ -2,12 +2,12 @@ using FFTW, Reactant @testset "fft" begin x = rand(ComplexF32, 2, 2, 3, 4) - x_ra = Reactant.ConcreteRArray(x) + x_ra = Reactant.to_rarray(x) @test_throws AssertionError @jit(fft(x_ra)) x = rand(ComplexF32, 2, 3, 4) - x_ra = Reactant.ConcreteRArray(x) + x_ra = Reactant.to_rarray(x) @test @jit(fft(x_ra)) ≈ fft(x) @test @jit(fft(x_ra, (1, 2))) ≈ fft(x, (1, 2)) @@ -24,12 +24,12 @@ end @testset "rfft" begin x = rand(2, 2, 3, 4) - x_ra = Reactant.ConcreteRArray(x) + x_ra = Reactant.to_rarray(x) @test_throws AssertionError @jit(rfft(x_ra)) x = rand(2, 3, 4) - x_ra = Reactant.ConcreteRArray(x) + x_ra = Reactant.to_rarray(x) @test @jit(rfft(x_ra)) ≈ rfft(x) @test @jit(rfft(x_ra, (1, 2))) ≈ rfft(x, (1, 2)) diff --git a/test/integration/kernelabstractions.jl b/test/integration/kernelabstractions.jl new file mode 100644 index 0000000000..7fa3bd6d92 --- /dev/null +++ b/test/integration/kernelabstractions.jl @@ -0,0 +1,70 @@ +using CUDA, KernelAbstractions, Reactant + +# Simple kernel for matrix multiplication +@kernel function matmul_kernel!(output, a) + i, j = @index(Global, NTuple) + # creating a temporary sum variable for matrix multiplication + + tmp_sum = zero(eltype(output)) + for k in 1:size(a)[2] + @inbounds tmp_sum += a[i, k] * a[k, j] + end + + @inbounds output[i, j] = tmp_sum +end + +# Creating a wrapper kernel for launching with error checks +function matmul!(output, a) + backend = KernelAbstractions.get_backend(output) + kernel! = matmul_kernel!(backend) + kernel!(output, a; ndrange=size(output)) + return KernelAbstractions.synchronize(backend) +end + +# https://github.com/EnzymeAD/Reactant.jl/issues/614 +const skip_non_cuda_tests = true + +@static if !Sys.isapple() + @testset "KernelAbstractions Matmul" begin + A = Reactant.to_rarray(ones(100, 100)) + out = Reactant.to_rarray(ones(100, 100)) + if CUDA.functional() + @test all(Array(@jit(matmul!(out, A))) .≈ 100) broken = true + else + @static if skip_non_cuda_tests + @test false broken = true + else + @code_hlo optimize = :before_kernel matmul!(out, A) + end + end + end +end + +# simple square kernel +@kernel function square_kernel!(y, @Const(x)) + i = @index(Global) + @inbounds y[i] = x[i] * x[i] +end + +function square(x) + y = similar(x) + backend = KernelAbstractions.get_backend(x) + kernel! = square_kernel!(backend) + kernel!(y, x; ndrange=length(x)) + return y +end + +@static if !Sys.isapple() + @testset "KernelAbstractions Square" begin + x = Reactant.to_rarray(collect(1:1:64) ./ 64) + if CUDA.functional() + @test all(Array(@jit(square(x))) .≈ Array(x) .* Array(x)) + else + @static if skip_non_cuda_tests + @test false broken = true + else + @code_hlo optimize = :before_kernel square(x) + end + end + end +end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 0c6efc5fdb..a49768310e 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -1,4 +1,4 @@ -using LinearAlgebra, Reactant +using LinearAlgebra, Reactant, Test function muladd2(A, x, b) C = similar(A, promote_type(eltype(A), eltype(b)), size(A, 1), size(x, 2)) @@ -57,6 +57,7 @@ end @test @jit(muladd2(A_ra, x_ra, b_ra)) ≈ muladd2(A, x, b) @test @jit(muladd_5arg(A_ra, x_ra, b_ra)) ≈ muladd2(A, x, b) @test @jit(muladd_5arg2(A_ra, x_ra, b_ra)) ≈ 2 .* A * x .+ b + @test @jit(A_ra * x) ≈ A * x @test @jit(mul_with_view1(A_ra, x_ra)) ≈ mul_with_view1(A, x) @@ -130,15 +131,135 @@ end @test @jit(diagm(4, 5, x_ra)) ≈ diagm(4, 5, x) @test @jit(diagm(6, 6, x_ra)) ≈ diagm(6, 6, x) @test_throws DimensionMismatch @jit(diagm(3, 3, x_ra)) + + x1 = rand(3) + x2 = rand(3) + x3 = rand(2) + x_ra1 = Reactant.to_rarray(x1) + x_ra2 = Reactant.to_rarray(x2) + x_ra3 = Reactant.to_rarray(x3) + + @test @jit(diagm(1 => x_ra1)) ≈ diagm(1 => x1) + @test @jit(diagm(1 => x_ra1, -1 => x_ra3)) ≈ diagm(1 => x1, -1 => x3) + @test @jit(diagm(1 => x_ra1, 1 => x_ra2)) ≈ diagm(1 => x1, 1 => x2) end -# TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be -# optimized +# TODO: Currently (x) * x goes down the generic matmul path but it should +# clearly be optimized mul_diagonal(x) = Diagonal(x) * x +mul_tridiagonal(x) = Tridiagonal(x) * x +mul_unit_lower_triangular(x) = UnitLowerTriangular(x) * x +mul_unit_upper_triangular(x) = UnitUpperTriangular(x) * x +mul_lower_triangular(x) = LowerTriangular(x) * x +mul_upper_triangular(x) = UpperTriangular(x) * x +mul_symmetric(x) = Symmetric(x) * x + +@testset "Wrapper Types Matrix Multiplication" begin + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + + @testset "$(wrapper_type)" for (wrapper_type, fn) in [ + (Diagonal, mul_diagonal), + (Tridiagonal, mul_tridiagonal), + (UnitLowerTriangular, mul_unit_lower_triangular), + (UnitUpperTriangular, mul_unit_upper_triangular), + (LowerTriangular, mul_lower_triangular), + (UpperTriangular, mul_upper_triangular), + (Symmetric, mul_symmetric), + ] + @test @jit(fn(x_ra)) ≈ fn(x) + end +end + +@testset "kron" begin + @testset for T in (Int64, Float64, ComplexF64) + @testset for (x_sz, y_sz) in [ + ((3, 4), (2, 5)), ((3, 4), (2,)), ((3,), (2, 5)), ((3,), (5,)), ((10,), ()) + ] + x = x_sz == () ? rand(T) : rand(T, x_sz) + y = y_sz == () ? rand(T) : rand(T, y_sz) + x_ra = Reactant.to_rarray(x; track_numbers=Number) + y_ra = Reactant.to_rarray(y; track_numbers=Number) + @test @jit(kron(x_ra, y_ra)) ≈ kron(x, y) + end + end +end + +@testset "axpy!" begin + α = 3 + x = rand(Int64, 4) + x_ra = Reactant.to_rarray(x) + y = rand(Int64, 4) + y_ra = Reactant.to_rarray(y) -@testset "mul_diagonal" begin + @jit axpy!(α, x_ra, y_ra) + @test y_ra ≈ axpy!(α, x, y) + + α = 2 x = rand(4) x_ra = Reactant.to_rarray(x) + y = rand(4) + y_ra = Reactant.to_rarray(y) + + @jit axpy!(α, x_ra, y_ra) + @test y_ra ≈ axpy!(α, x, y) + + α = 4.12 + X = rand(3, 5) + Y = rand(3, 5) + X_ra = Reactant.to_rarray(X) + Y_ra = Reactant.to_rarray(Y) + + @jit axpy!(α, X_ra, Y_ra) + @test Y_ra ≈ axpy!(α, X, Y) + + α = 3.2 + 1im + x = rand(Complex{Float32}, 4) + x_ra = Reactant.to_rarray(x) + y = rand(Complex{Float32}, 4) + y_ra = Reactant.to_rarray(y) + + @jit axpy!(α, x_ra, y_ra) + @test y_ra ≈ axpy!(α, x, y) +end + +@testset "axpby!" begin + α = 3 + β = 2 + x = rand(Int64, 4) + x_ra = Reactant.to_rarray(x) + y = rand(Int64, 4) + y_ra = Reactant.to_rarray(y) + + @jit axpby!(α, x_ra, β, y_ra) + @test y_ra ≈ axpby!(α, x, β, y) + + α = 2 + β = 3 + x = rand(4) + x_ra = Reactant.to_rarray(x) + y = rand(4) + y_ra = Reactant.to_rarray(y) + + @jit axpby!(α, x_ra, β, y_ra) + @test y_ra ≈ axpby!(α, x, β, y) + + α = 4.12 + X = rand(3, 5) + Y = rand(3, 5) + X_ra = Reactant.to_rarray(X) + Y_ra = Reactant.to_rarray(Y) + + @jit axpby!(α, X_ra, β, Y_ra) + @test Y_ra ≈ axpby!(α, X, β, Y) + + α = 3.2 + 1im + β = 2.1 - 4.2im + x = rand(Complex{Float32}, 4) + x_ra = Reactant.to_rarray(x) + y = rand(Complex{Float32}, 4) + y_ra = Reactant.to_rarray(y) - @test @jit(mul_diagonal(x_ra)) ≈ mul_diagonal(x) + @jit axpby!(α, x_ra, β, y_ra) + @test y_ra ≈ axpby!(α, x, β, y) end diff --git a/test/integration/offsetarrays.jl b/test/integration/offsetarrays.jl new file mode 100644 index 0000000000..a8bb4f899e --- /dev/null +++ b/test/integration/offsetarrays.jl @@ -0,0 +1,19 @@ +using Reactant +using Test +using OffsetArrays + +function scalar_index(x) + @allowscalar getindex(x, -1, 0) +end +@testset "OffsetArrays" begin + A = Float64.(reshape(1:15, 3, 5)) + OA = OffsetArray(A, -1:1, 0:4) + rOA = Reactant.to_rarray(OA) + + oval = scalar_index(OA) + cval = scalar_index(rOA) + @test cval ≈ oval + + tval = @jit scalar_index(rOA) + @test tval ≈ oval +end diff --git a/test/integration/python.jl b/test/integration/python.jl index 54c2eec16d..f85099693a 100644 --- a/test/integration/python.jl +++ b/test/integration/python.jl @@ -2,12 +2,16 @@ using Reactant using Reactant: Ops using Test -using PythonCall -@testset "PythonCall" begin - jax = pyimport("jax") +# Jax on Github CI dislikes X86 macos +@static if !Sys.isapple() || Sys.ARCH != :x86_64 + using PythonCall - result = @jit jax.numpy.sum(Reactant.to_rarray(Float32[1, 2, 3])) - @test typeof(result) == ConcreteRNumber{Float32} - @test result ≈ 6 + @testset "PythonCall" begin + jax = pyimport("jax") + + result = @jit jax.numpy.sum(Reactant.to_rarray(Float32[1, 2, 3])) + @test typeof(result) == ConcretePJRTNumber{Float32,1,Sharding.NoShardInfo} + @test result ≈ 6 + end end diff --git a/test/integration/random.jl b/test/integration/random.jl index 275e0e2447..f7c41f6667 100644 --- a/test/integration/random.jl +++ b/test/integration/random.jl @@ -66,8 +66,8 @@ end # distributions @testset "Uniform Random" begin @testset "Deterministic Seed" begin - seed1 = ConcreteRArray(UInt64[1, 3]) - seed2 = ConcreteRArray(UInt64[1, 5]) + seed1 = Reactant.to_rarray(UInt64[1, 3]) + seed2 = Reactant.to_rarray(UInt64[1, 5]) fn(seed) = begin rng = Random.default_rng() @@ -110,8 +110,8 @@ end @testset "Normal Distribution" begin @testset "Deterministic Seed" begin - seed1 = ConcreteRArray(UInt64[1, 3]) - seed2 = ConcreteRArray(UInt64[1, 5]) + seed1 = Reactant.to_rarray(UInt64[1, 3]) + seed2 = Reactant.to_rarray(UInt64[1, 5]) fn(seed) = begin rng = Random.default_rng() @@ -147,8 +147,8 @@ end @testset "Exponential Distribution" begin @testset "Deterministic Seed" begin - seed1 = ConcreteRArray(UInt64[1, 3]) - seed2 = ConcreteRArray(UInt64[1, 5]) + seed1 = Reactant.to_rarray(UInt64[1, 3]) + seed2 = Reactant.to_rarray(UInt64[1, 5]) fn(seed) = begin rng = Random.default_rng() diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl new file mode 100644 index 0000000000..c63f80b48d --- /dev/null +++ b/test/integration/special_functions.jl @@ -0,0 +1,102 @@ +using SpecialFunctions, Reactant + +macro ≈(a, b) + return quote + isapprox($a, $b; atol=1e-14) + end +end + +@testset "gamma" begin + @test SpecialFunctions.gamma(0.5) ≈ @jit(SpecialFunctions.gamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.gamma(2) ≈ @jit(SpecialFunctions.gamma(ConcreteRNumber(2))) +end + +@testset "loggamma" begin + @test SpecialFunctions.loggamma(0.5) ≈ + @jit(SpecialFunctions.loggamma(ConcreteRNumber(0.5))) + @test abs(SpecialFunctions.loggamma(2)) < 1e-10 + @test abs(@jit(SpecialFunctions.loggamma(ConcreteRNumber(2)))) < 1e-10 +end + +@testset "digamma" begin + @test SpecialFunctions.digamma(0.5) ≈ + @jit(SpecialFunctions.digamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.digamma(2) ≈ @jit(SpecialFunctions.digamma(ConcreteRNumber(2))) +end + +@testset "trigamma" begin + @test SpecialFunctions.trigamma(0.5) ≈ + @jit(SpecialFunctions.trigamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.trigamma(2) ≈ @jit(SpecialFunctions.trigamma(ConcreteRNumber(2))) +end + +@testset "beta" begin + @test SpecialFunctions.beta(0.5, 0.6) ≈ + @jit(SpecialFunctions.beta(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.beta(2, 4) ≈ + @jit(SpecialFunctions.beta(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "logbeta" begin + @test SpecialFunctions.logbeta(0.5, 0.6) ≈ + @jit(SpecialFunctions.logbeta(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.logbeta(2, 4) ≈ + @jit(SpecialFunctions.logbeta(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "erf" begin + @test SpecialFunctions.erf(0.5) ≈ @jit(SpecialFunctions.erf(ConcreteRNumber(0.5))) + @test SpecialFunctions.erf(2) ≈ @jit(SpecialFunctions.erf(ConcreteRNumber(2))) +end + +@testset "erf with 2 arguments" begin + @test SpecialFunctions.erf(0.5, 0.6) ≈ + @jit(SpecialFunctions.erf(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.erf(2, 4) ≈ + @jit(SpecialFunctions.erf(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "erfc" begin + @test SpecialFunctions.erfc(0.5) ≈ @jit(SpecialFunctions.erfc(ConcreteRNumber(0.5))) + @test SpecialFunctions.erfc(2) ≈ @jit(SpecialFunctions.erfc(ConcreteRNumber(2))) +end + +@testset "logerf" begin + @test SpecialFunctions.logerf(0.5, 0.6) ≈ + @jit(SpecialFunctions.logerf(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.logerf(2, 4) ≈ + @jit(SpecialFunctions.logerf(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "erfcx" begin + @test SpecialFunctions.erfcx(0.5) ≈ @jit(SpecialFunctions.erfcx(ConcreteRNumber(0.5))) + @test SpecialFunctions.erfcx(2) ≈ @jit(SpecialFunctions.erfcx(ConcreteRNumber(2))) +end + +@testset "logerfc" begin + @test SpecialFunctions.logerfc(0.5) ≈ + @jit(SpecialFunctions.logerfc(ConcreteRNumber(0.5))) + @test SpecialFunctions.logerfc(2) ≈ @jit(SpecialFunctions.logerfc(ConcreteRNumber(2))) +end + +@testset "logerfcx" begin + @test SpecialFunctions.logerfcx(0.5) ≈ + @jit(SpecialFunctions.logerfcx(ConcreteRNumber(0.5))) + @test SpecialFunctions.logerfcx(2) ≈ @jit(SpecialFunctions.logerfcx(ConcreteRNumber(2))) +end + +@testset "loggamma1p" begin + @test SpecialFunctions.loggamma1p(0.5) ≈ + @jit SpecialFunctions.loggamma1p(ConcreteRNumber(0.5)) +end + +@testset "loggammadiv" begin + @test SpecialFunctions.loggammadiv(150, 20) ≈ + @jit SpecialFunctions.loggammadiv(ConcreteRNumber(150), ConcreteRNumber(20)) +end + +@testset "zeta" begin + s = Reactant.to_rarray([1.0, 2.0, 50.0]) + z = Reactant.to_rarray([1e-8, 0.001, 2.0]) + @test SpecialFunctions.zeta.(Array(s), Array(z)) ≈ @jit SpecialFunctions.zeta.(s, z) +end diff --git a/test/layout.jl b/test/layout.jl index 91c5d5e709..6b5dbca2a2 100644 --- a/test/layout.jl +++ b/test/layout.jl @@ -4,7 +4,7 @@ using Test @testset "Layout" begin x = reshape([1.0, 2.0, 3.0, 4.0], (2, 2)) - y = Reactant.ConcreteRArray(x) + y = Reactant.to_rarray(x) y2 = convert(Array{Float64,2}, y) diff --git a/test/nn/flux.jl b/test/nn/flux.jl index 02f55f4c08..58b1cfac8f 100644 --- a/test/nn/flux.jl +++ b/test/nn/flux.jl @@ -16,7 +16,7 @@ using Reactant, Flux origout = model(noisy) cmodel = Reactant.to_rarray(model) - cnoisy = Reactant.ConcreteRArray(noisy) + cnoisy = Reactant.to_rarray(noisy) f = Reactant.compile((a, b) -> a(b), (cmodel, cnoisy)) diff --git a/test/nn/lux.jl b/test/nn/lux.jl index 7916ce10fd..46ca0525d7 100644 --- a/test/nn/lux.jl +++ b/test/nn/lux.jl @@ -40,17 +40,17 @@ end cps = Reactant.to_rarray(ps) cst = Reactant.to_rarray(Lux.testmode(st)) cst2 = Reactant.to_rarray(st) - cnoisy = Reactant.ConcreteRArray(noisy) + cnoisy = Reactant.to_rarray(noisy) f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cst)) comp = f(cmodel, cnoisy, cps, cst) - @test comp ≈ origout atol = 1e-5 rtol = 1e-2 + @test comp ≈ origout atol = 1e-3 rtol = 1e-2 target = onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix - ctarget = Reactant.ConcreteRArray(Array{Float32}(target)) + ctarget = Reactant.to_rarray(Array{Float32}(target)) # ctarget = Reactant.to_rarray(target) res, dps = gradient_loss_function(model, noisy, target, ps, st) @@ -61,8 +61,9 @@ end res_reactant, dps_reactant = compiled_gradient(cmodel, cnoisy, ctarget, cps, cst2) - @test res ≈ res_reactant atol = 1e-5 rtol = 1e-2 + @test res ≈ res_reactant atol = 1e-3 rtol = 1e-2 + # See https://github.com/EnzymeAD/Reactant.jl/issues/578 for (dps1, dps2) in zip(fleaves(dps), fleaves(dps_reactant)) - @test dps1 ≈ dps2 atol = 1e-5 rtol = 1e-2 + @test_skip dps1 ≈ dps2 atol = 1e-3 rtol = 1e-2 end end diff --git a/test/nn/luxlib.jl b/test/nn/luxlib.jl index f1bafff210..08f2e76ad8 100644 --- a/test/nn/luxlib.jl +++ b/test/nn/luxlib.jl @@ -26,8 +26,8 @@ using LuxLib, Reactant, Enzyme, NNlib x = randn(Float32, 10, 12) bias = has_bias ? randn(Float32, 9) : nothing - weight_ra = Reactant.ConcreteRArray(weight) - x_ra = Reactant.ConcreteRArray(x) + weight_ra = Reactant.to_rarray(weight) + x_ra = Reactant.to_rarray(x) bias_ra = Reactant.to_rarray(bias) f_compile = Reactant.compile( @@ -93,8 +93,8 @@ end x = randn(Float32, 10, 10) b = randn(Float32, 10) - x_ra = Reactant.ConcreteRArray(x) - b_ra = Reactant.ConcreteRArray(b) + x_ra = Reactant.to_rarray(x) + b_ra = Reactant.to_rarray(b) f_compile = Reactant.compile(biasact, (act, x_ra, b_ra)) f_compile!! = Reactant.compile(biasact!!, (act, x_ra, b_ra)) @@ -145,7 +145,7 @@ end end x_act = randn(Float32, 10, 10) - x_act_ca = Reactant.ConcreteRArray(x_act) + x_act_ca = Reactant.to_rarray(x_act) @testset "Activation: $act" for act in ( identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2 @@ -187,8 +187,8 @@ end x = randn(Float32, 16, 16, 8, 2) bias = has_bias ? randn(Float32, 4) : nothing - weight_reactant = Reactant.ConcreteRArray(weight) - x_reactant = Reactant.ConcreteRArray(x) + weight_reactant = Reactant.to_rarray(weight) + x_reactant = Reactant.to_rarray(x) bias_reactant = Reactant.to_rarray(bias) @testset for stride in ((1, 1), (2, 2), (3, 3)), diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 7359bca97d..d82d3db702 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -10,12 +10,12 @@ using NNlib, Reactant, Enzyme end x_act = randn(Float32, 10, 10) - x_act_ca = Reactant.ConcreteRArray(x_act) + x_act_ca = Reactant.to_rarray(x_act) @testset "Activation: $act" for act in ( identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6 ) - f_compile = Reactant.compile(sumabs2, (act, x_act)) + f_compile = Reactant.compile(sumabs2, (act, x_act_ca)) y_simple = sumabs2(act, x_act) y_compile = f_compile(act, x_act_ca) @@ -35,7 +35,7 @@ end @testset "Pooling" begin @testset for f in (NNlib.meanpool, NNlib.maxpool) x = randn(Float32, 32, 32, 3, 2) - x_reactant = Reactant.ConcreteRArray(x) + x_reactant = Reactant.to_rarray(x) @testset for window in ((2, 2), (3, 3), (4, 4)), stride in ((1, 1), (2, 2)), @@ -70,8 +70,8 @@ end weight = randn(Float32, 4, 4, 8 ÷ groups, 4) x = randn(Float32, 16, 16, 8, 2) - weight_reactant = Reactant.ConcreteRArray(weight) - x_reactant = Reactant.ConcreteRArray(x) + weight_reactant = Reactant.to_rarray(weight) + x_reactant = Reactant.to_rarray(x) @testset for stride in ((1, 1), (2, 2), (3, 3)), padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0), (0, 1), (1, 0)), @@ -113,8 +113,8 @@ end @testset "conv 1d: flip" begin x = [1.0f0; 2.0f0; 3.0f0;;;] W = [1.0f0; 2.0f0; 3.0f0;;;] - xx = Reactant.ConcreteRArray(x) - WW = Reactant.ConcreteRArray(W) + xx = Reactant.to_rarray(x) + WW = Reactant.to_rarray(W) conv_noflip(x, W) = NNlib.conv(x, W; pad=1, flipped=true) conv_flip(x, W) = NNlib.conv(x, W; pad=1, flipped=false) @test Reactant.compile(conv_noflip, (xx, WW))(xx, WW) == @@ -128,31 +128,31 @@ end x = rand(Float32, 4, 3, 5) y = rand(Float32, 3, 2, 5) - x_ra = Reactant.ConcreteRArray(x) - y_ra = Reactant.ConcreteRArray(y) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) @test @jit(batched_mul(x_ra, y_ra)) ≈ batched_mul(x, y) x = rand(Float32, 4, 3, 1) y = rand(Float32, 3, 2, 5) - x_ra = Reactant.ConcreteRArray(x) - y_ra = Reactant.ConcreteRArray(y) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) @test @jit(batched_mul(x_ra, y_ra)) ≈ batched_mul(x, y) x = rand(Float32, 4, 3, 5) y = rand(Float32, 3, 2, 1) - x_ra = Reactant.ConcreteRArray(x) - y_ra = Reactant.ConcreteRArray(y) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) @test @jit(batched_mul(x_ra, y_ra)) ≈ batched_mul(x, y) end @testset "Constant Padding: NNlib.pad_constant" begin x = rand(Float32, 4, 4) - x_ra = Reactant.ConcreteRArray(x) + x_ra = Reactant.to_rarray(x) # Symmetric Padding @test @jit(NNlib.pad_constant(x_ra, (1, 1))) ≈ NNlib.pad_constant(x, (1, 1)) @@ -191,14 +191,14 @@ end @test @jit(∇sumabs2(pad_fn2, x_ra)) ≈ ∇sumabs2(pad_fn2, x) x = rand(ComplexF32, 4, 4) - x_ra = Reactant.ConcreteRArray(x) + x_ra = Reactant.to_rarray(x) @test @jit(NNlib.pad_constant(x_ra, (1, 1))) ≈ NNlib.pad_constant(x, (1, 1)) end @testset "make_causal_mask" begin x = rand(2, 10) - x_ra = Reactant.ConcreteRArray(x) + x_ra = Reactant.to_rarray(x) @test @jit(NNlib.make_causal_mask(x_ra)) ≈ NNlib.make_causal_mask(x) @@ -222,14 +222,16 @@ end 5 7 7 5 ] - y1 = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + y1 = @test_warn r"Using fallback implementation of `gather!`" @jit( + NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)) + ) @test y1 ≈ output - @test y1 isa ConcreteRArray{Float32,2} + @test y1 isa ConcretePJRTArray{Float32,2} @test size(y1) == size(index) y2 = @jit(NNlib.gather(Reactant.to_rarray(src), index)) @test y2 ≈ output - @test y2 isa ConcreteRArray{Float32,2} + @test y2 isa ConcretePJRTArray{Float32,2} @test size(y2) == size(index) dst = Float32.(zero.(index)) @@ -260,12 +262,12 @@ end ][:, :, 1:1] y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) @test y ≈ output - @test y isa ConcreteRArray{Float32,3} + @test y isa ConcretePJRTArray{Float32,3} @test size(y) == size(index) y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) @test y ≈ output - @test y isa ConcreteRArray{Float32,3} + @test y isa ConcretePJRTArray{Float32,3} @test size(y) == size(index) ## 2d src, 2d index of ints -> 3d output @@ -295,12 +297,12 @@ end y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) @test y ≈ output - @test y isa ConcreteRArray{Float32,3} + @test y isa ConcretePJRTArray{Float32,3} @test size(y) == (size(src)[1:(end - 1)]..., size(index)...) y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) @test y ≈ output - @test y isa ConcreteRArray{Float32,3} + @test y isa ConcretePJRTArray{Float32,3} @test size(y) == (size(src)[1:(end - 1)]..., size(index)...) end @@ -316,13 +318,13 @@ end y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) - @test y isa ConcreteRArray{Float32,1} + @test y isa ConcretePJRTArray{Float32,1} @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) @test y ≈ output y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) @test y ≈ output - @test y isa ConcreteRArray{Float32,1} + @test y isa ConcretePJRTArray{Float32,1} @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) @test y ≈ output @@ -334,13 +336,13 @@ end y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) - @test y isa ConcreteRArray{Float32,3} + @test y isa ConcretePJRTArray{Float32,3} @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) - @test y isa ConcreteRArray{Float32,3} + @test y isa ConcretePJRTArray{Float32,3} @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) end @@ -356,13 +358,13 @@ end y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) - @test y isa ConcreteRArray{Float32,1} + @test y isa ConcretePJRTArray{Float32,1} @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) @test y ≈ output y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) @test y ≈ output - @test y isa ConcreteRArray{Float32,1} + @test y isa ConcretePJRTArray{Float32,1} @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) ## 3d src, 2d index of 2-tuples -> 3d output @@ -373,13 +375,13 @@ end y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) - @test y isa ConcreteRArray{Float32,3} + @test y isa ConcretePJRTArray{Float32,3} @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) - @test y isa ConcreteRArray{Float32,3} + @test y isa ConcretePJRTArray{Float32,3} @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) end end @@ -417,7 +419,7 @@ end @testset "Upsampling" begin x = randn(Float32, 4, 4, 3, 2) - x_ra = Reactant.ConcreteRArray(x) + x_ra = Reactant.to_rarray(x) @test @jit(NNlib.upsample_nearest(x_ra, (2, 2))) ≈ NNlib.upsample_nearest(x, (2, 2)) end diff --git a/test/ops.jl b/test/ops.jl index 82ec4cc8b8..928c4d3d4d 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -4,13 +4,13 @@ using LinearAlgebra using SpecialFunctions: SpecialFunctions @testset "abs" begin - x = ConcreteRArray([1, -1]) + x = Reactant.to_rarray([1, -1]) @test [1, 1] ≈ @jit Ops.abs(x) - x = ConcreteRArray([1.0, -1.0]) + x = Reactant.to_rarray([1.0, -1.0]) @test [1.0, 1.0] ≈ @jit Ops.abs(x) - x = ConcreteRArray([ + x = Reactant.to_rarray([ 3.0+4im -3.0+4im 3.0-4im -3.0-4im ]) @@ -21,20 +21,20 @@ using SpecialFunctions: SpecialFunctions end @testset "add" begin - a = ConcreteRArray([false, false, true, true]) - b = ConcreteRArray([false, true, false, true]) + a = Reactant.to_rarray([false, false, true, true]) + b = Reactant.to_rarray([false, true, false, true]) @test [false, true, true, false] ≈ @jit Ops.add(a, b) - a = ConcreteRArray([1, 2, 3, 4]) - b = ConcreteRArray([5, 6, -7, -8]) + a = Reactant.to_rarray([1, 2, 3, 4]) + b = Reactant.to_rarray([5, 6, -7, -8]) @test Array(a) .+ Array(b) ≈ @jit Ops.add(a, b) - a = ConcreteRArray([1.1, 2.2, 3.3, 4.4]) - b = ConcreteRArray([5.5, 6.6, -7.7, -8.8]) + a = Reactant.to_rarray([1.1, 2.2, 3.3, 4.4]) + b = Reactant.to_rarray([5.5, 6.6, -7.7, -8.8]) @test Array(a) .+ Array(b) ≈ @jit Ops.add(a, b) - a = ConcreteRArray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) - b = ConcreteRArray([ + a = Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) + b = Reactant.to_rarray([ 9.9 + 10.10im, 11.11 + 12.12im, -13.13 + -14.14im, -15.15 + -16.16im ]) @test Array(a) .+ Array(b) ≈ @jit Ops.add(a, b) @@ -45,34 +45,34 @@ end end @testset "and" begin - a = ConcreteRArray([false, false, true, true]) - b = ConcreteRArray([false, true, false, true]) + a = Reactant.to_rarray([false, false, true, true]) + b = Reactant.to_rarray([false, true, false, true]) @test [false, false, false, true] ≈ @jit Ops.and(a, b) - a = ConcreteRArray([1, 2, 3, 4]) - b = ConcreteRArray([5, 6, -7, -8]) + a = Reactant.to_rarray([1, 2, 3, 4]) + b = Reactant.to_rarray([5, 6, -7, -8]) @test [1, 2, 1, 0] == @jit Ops.and(a, b) end @testset "atan2" begin - a = ConcreteRArray([1.1, 2.2, 3.3, 4.4]) - b = ConcreteRArray([5.5, 6.6, -7.7, -8.8]) + a = Reactant.to_rarray([1.1, 2.2, 3.3, 4.4]) + b = Reactant.to_rarray([5.5, 6.6, -7.7, -8.8]) @test atan.(Array(a), Array(b)) ≈ @jit Ops.atan2(a, b) # TODO couldn't find the definition of complex atan2 to compare against, but it should be implemented end @testset "cbrt" begin - x = ConcreteRArray([1.0, 8.0, 27.0, 64.0]) + x = Reactant.to_rarray([1.0, 8.0, 27.0, 64.0]) @test [1.0, 2.0, 3.0, 4.0] ≈ @jit Ops.cbrt(x) # TODO currently crashes, reenable when #291 is fixed - # x = ConcreteRArray([1.0 + 2.0im, 8.0 + 16.0im, 27.0 + 54.0im, 64.0 + 128.0im]) + # x = Reactant.to_rarray([1.0 + 2.0im, 8.0 + 16.0im, 27.0 + 54.0im, 64.0 + 128.0im]) # @test Array(x) .^ (1//3) ≈ @jit Ops.cbrt(x) end @testset "ceil" begin - x = ConcreteRArray( + x = Reactant.to_rarray( [ 1.1 2.2 3.3 4.4 5.5 6.6 7.7 8.8 9.9 10.0 -1.1 -2.2 -3.3 -4.4 -5.5 -6.6 -7.7 -8.8 -9.9 -10.0 @@ -87,7 +87,7 @@ end g1(x) = triu(Ops.cholesky(x)) g2(x) = tril(Ops.cholesky(x; lower=true)) - x = ConcreteRArray([ + x = Reactant.to_rarray([ 10.0 2.0 3.0 2.0 5.0 6.0 3.0 6.0 9.0 @@ -95,7 +95,7 @@ end @test cholesky(Array(x)).U ≈ @jit g1(x) @test transpose(cholesky(Array(x)).U) ≈ @jit g2(x) - x = ConcreteRArray( + x = Reactant.to_rarray( [ 10.0+0.0im 2.0-3.0im 3.0-4.0im 2.0+3.0im 5.0+0.0im 3.0-2.0im @@ -110,37 +110,43 @@ end @testset "clamp" begin for (_min, _max) in [ (3, 7), - (ConcreteRNumber(3), ConcreteRNumber(7)), ( - ConcreteRArray([3, 3, 3, 3, 3, 3, 3, 3, 3, 3]), - ConcreteRArray([7, 7, 7, 7, 7, 7, 7, 7, 7, 7]), + Reactant.to_rarray(3; track_numbers=true), + Reactant.to_rarray(7; track_numbers=true), + ), + ( + Reactant.to_rarray([3, 3, 3, 3, 3, 3, 3, 3, 3, 3]), + Reactant.to_rarray([7, 7, 7, 7, 7, 7, 7, 7, 7, 7]), ), ] - x = ConcreteRArray([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + x = Reactant.to_rarray([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) @test [3, 3, 3, 4, 5, 6, 7, 7, 7, 7] == @jit Ops.clamp(_min, x, _max) end for (_min, _max) in [ (3.0, 7.0), - (ConcreteRNumber(3.0), ConcreteRNumber(7.0)), ( - ConcreteRArray([3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0]), - ConcreteRArray([7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]), + Reactant.to_rarray(3.0; track_numbers=true), + Reactant.to_rarray(7.0; track_numbers=true), + ), + ( + Reactant.to_rarray([3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0]), + Reactant.to_rarray([7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]), ), ] - x = ConcreteRArray([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0]) + x = Reactant.to_rarray([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0]) @test [3.0, 3.0, 3.3, 4.4, 5.5, 6.6, 7.0, 7.0, 7.0, 7.0] == @jit Ops.clamp(_min, x, _max) end end @testset "complex" begin - x = ConcreteRNumber(1.1) - y = ConcreteRNumber(2.2) + x = Reactant.to_rarray(1.1; track_numbers=true) + y = Reactant.to_rarray(2.2; track_numbers=true) @test 1.1 + 2.2im ≈ @jit Ops.complex(x, y) - x = ConcreteRArray([1.1, 2.2, 3.3, 4.4]) - y = ConcreteRArray([5.5, 6.6, -7.7, -8.8]) + x = Reactant.to_rarray([1.1, 2.2, 3.3, 4.4]) + y = Reactant.to_rarray([5.5, 6.6, -7.7, -8.8]) @test [1.1 + 5.5im, 2.2 + 6.6im, 3.3 - 7.7im, 4.4 - 8.8im] ≈ @jit Ops.complex(x, y) end @@ -156,13 +162,13 @@ end @testset "cosine" begin # it crashes in apple x86_64 and it's a deprecated platform so we don't need to care a lot... if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray([0, π / 2, π, 3π / 2, 2π]) + x = Reactant.to_rarray([0, π / 2, π, 3π / 2, 2π]) @test [1, 0, -1, 0, 1] ≈ @jit Ops.cosine(x) - x = ConcreteRArray([0.0, π / 2, π, 3π / 2, 2π]) + x = Reactant.to_rarray([0.0, π / 2, π, 3π / 2, 2π]) @test [1.0, 0.0, -1.0, 0.0, 1.0] ≈ @jit Ops.cosine(x) - x = ConcreteRArray([ + x = Reactant.to_rarray([ 0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im ]) @test [1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im, 1.0 + 0.0im] ≈ @@ -171,20 +177,23 @@ end end @testset "count_leading_zeros" begin - x = ConcreteRArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + x = Reactant.to_rarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) @test [64, 63, 62, 62, 61, 61, 61, 61, 60, 60] ≈ @jit Ops.count_leading_zeros(x) end @testset "divide" begin - a = ConcreteRArray([5, 6, -7, -8]) - b = ConcreteRArray([1, 2, 3, 4]) + a = Reactant.to_rarray([5, 6, -7, -8]) + b = Reactant.to_rarray([1, 2, 3, 4]) @test Array(a) .÷ Array(b) ≈ @jit Ops.divide(a, b) for (a, b) in [ - (ConcreteRArray([1.1, 2.2, 3.3, 4.4]), ConcreteRArray([5.5, 6.6, -7.7, -8.8])), ( - ConcreteRArray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]), - ConcreteRArray([ + Reactant.to_rarray([1.1, 2.2, 3.3, 4.4]), + Reactant.to_rarray([5.5, 6.6, -7.7, -8.8]), + ), + ( + Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]), + Reactant.to_rarray([ 9.9 + 10.10im, 11.11 + 12.12im, -13.13 + -14.14im, -15.15 + -16.16im ]), ), @@ -206,11 +215,14 @@ end ) for (a, b) in [ - (ConcreteRArray([1, 2, 3, 4]), ConcreteRArray([5, 6, -7, -8])), - (ConcreteRArray([1.0, 2.0, 3.0, 4.0]), ConcreteRArray([5.0, 6.0, -7.0, -8.0])), + (Reactant.to_rarray([1, 2, 3, 4]), Reactant.to_rarray([5, 6, -7, -8])), ( - ConcreteRArray([1.0, 2.0im, 3.0, 4.0im]), - ConcreteRArray([5.0, 6.0im, -7.0im, -8.0]), + Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]), + Reactant.to_rarray([5.0, 6.0, -7.0, -8.0]), + ), + ( + Reactant.to_rarray([1.0, 2.0im, 3.0, 4.0im]), + Reactant.to_rarray([5.0, 6.0im, -7.0im, -8.0]), ), ] # NOTE `LinearAlgebra.dot` is not equal to `sum(a .* b)` on complex numbers due to conjugation @@ -220,8 +232,8 @@ end @test a .* b ≈ @jit fouter_batch1(a, b) end - a = ConcreteRArray([1 2; 3 4]) - b = ConcreteRArray([5 6; -7 -8]) + a = Reactant.to_rarray([1 2; 3 4]) + b = Reactant.to_rarray([5 6; -7 -8]) @test Array(a)' * Array(b) == @jit f1(a, b) end @@ -232,39 +244,46 @@ end f4(a, b) = Ops.einsum(a, b; equation="ik,kj->ij") for (a, b) in [ - (ConcreteRArray([1, 2, 3, 4]), ConcreteRArray([5, 6, -7, -8])), - (ConcreteRArray([1.0, 2.0, 3.0, 4.0]), ConcreteRArray([5.0, 6.0, -7.0, -8.0])), + (Reactant.to_rarray([1, 2, 3, 4]), Reactant.to_rarray([5, 6, -7, -8])), ( - ConcreteRArray([1.0 + 1im, 2.0 + 2im, 3.0 - 3im, 4.0 - 4im]), - ConcreteRArray([5.0 + 5im, 6.0 + 6im, -7.0 - 7im, -8.0 - 8im]), + Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]), + Reactant.to_rarray([5.0, 6.0, -7.0, -8.0]), + ), + ( + Reactant.to_rarray([1.0 + 1im, 2.0 + 2im, 3.0 - 3im, 4.0 - 4im]), + Reactant.to_rarray([5.0 + 5im, 6.0 + 6im, -7.0 - 7im, -8.0 - 8im]), ), ] - @test a .* b ≈ @jit f1(a, b) - @test reshape(kron(Array(b), Array(a)), 4, 4) ≈ @jit f2(a, b) + @test a .* b ≈ + @test_warn r"`stablehlo.einsum` is on deprecation process" @jit f1(a, b) + @test reshape(kron(Array(b), Array(a)), 4, 4) ≈ + @test_warn r"`stablehlo.einsum` is on deprecation process" @jit f2(a, b) x = ConcreteRArray(reshape(a, (2, 2))) y = ConcreteRArray(reshape(b, (2, 2))) - @test x .* y ≈ @jit f3(x, y) - @test Array(x) * Array(y) ≈ @jit f4(x, y) + @test x .* y ≈ + @test_warn r"`stablehlo.einsum` is on deprecation process" @jit f3(x, y) + @test Array(x) * Array(y) ≈ + @test_warn r"`stablehlo.einsum` is on deprecation process" @jit f4(x, y) end end @testset "exponential" begin - x = ConcreteRArray([1.0, 2.0, 3.0, 4.0]) + x = Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]) @test exp.(Array(x)) ≈ @jit Ops.exponential(x) if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray([1.0 + 2.0im, 3.0 + 4.0im, 5.0 + 6.0im, 7.0 + 8.0im]) + x = Reactant.to_rarray([1.0 + 2.0im, 3.0 + 4.0im, 5.0 + 6.0im, 7.0 + 8.0im]) @test exp.(Array(x)) ≈ @jit Ops.exponential(x) end end @testset "exponential_minus_one" begin - x = ConcreteRArray([1.0, 2.0, 3.0, 4.0]) + x = Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]) @test expm1.(Array(x)) ≈ @jit Ops.exponential_minus_one(x) if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray([1.0 + 2.0im, 3.0 + 4.0im, 5.0 + 6.0im, 7.0 + 8.0im]) + x = Reactant.to_rarray([1.0 + 2.0im, 3.0 + 4.0im, 5.0 + 6.0im, 7.0 + 8.0im]) @test expm1.(Array(x)) ≈ @jit Ops.exponential_minus_one(x) end end @@ -273,29 +292,29 @@ end grfft(x) = Ops.fft(x; type="RFFT", length=[4]) gfft(x) = Ops.fft(x; type="FFT", length=[4]) - x = ConcreteRArray([1.0, 1.0, 1.0, 1.0]) + x = Reactant.to_rarray([1.0, 1.0, 1.0, 1.0]) @test ComplexF64[4.0, 0.0, 0.0] ≈ @jit grfft(x) - x = ConcreteRArray([0.0, 1.0, 0.0, -1.0]) + x = Reactant.to_rarray([0.0, 1.0, 0.0, -1.0]) @test ComplexF64[0.0, -2.0im, 0.0] ≈ @jit grfft(x) - x = ConcreteRArray([1.0, -1.0, 1.0, -1.0]) + x = Reactant.to_rarray([1.0, -1.0, 1.0, -1.0]) @test ComplexF64[0.0, 0.0, 4.0] ≈ @jit grfft(x) - x = ConcreteRArray(ComplexF64[1.0, 1.0, 1.0, 1.0]) + x = Reactant.to_rarray(ComplexF64[1.0, 1.0, 1.0, 1.0]) @test ComplexF64[4.0, 0.0, 0.0, 0.0] ≈ @jit gfft(x) - x = ConcreteRArray(ComplexF64[0.0, 1.0, 0.0, -1.0]) + x = Reactant.to_rarray(ComplexF64[0.0, 1.0, 0.0, -1.0]) @test ComplexF64[0.0, -2.0im, 0.0, 2.0im] ≈ @jit gfft(x) - x = ConcreteRArray(ComplexF64[1.0, -1.0, 1.0, -1.0]) + x = Reactant.to_rarray(ComplexF64[1.0, -1.0, 1.0, -1.0]) @test ComplexF64[0.0, 0.0, 4.0, 0.0] ≈ @jit gfft(x) # TODO test with complex numbers and inverse FFT end @testset "floor" begin - x = ConcreteRArray( + x = Reactant.to_rarray( [ 1.1 2.2 3.3 4.4 5.5 6.6 7.7 8.8 9.9 10.0 -1.1 -2.2 -3.3 -4.4 -5.5 -6.6 -7.7 -8.8 -9.9 -10.0 @@ -305,14 +324,14 @@ end end @testset "get_dimension_size" begin - x = ConcreteRArray(fill(0, (1, 2, 3, 4))) + x = Reactant.to_rarray(fill(0, (1, 2, 3, 4))) for i in 1:4 @test i == @jit Ops.get_dimension_size(x, i) end end @testset "imag" begin - x = ConcreteRArray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) + x = Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) @test [2.2, 4.4, 6.6, 8.8] ≈ @jit Ops.imag(x) end @@ -335,71 +354,74 @@ end end @testset "is_finite" begin - x = ConcreteRArray([-Inf, Inf, NaN, -10.0, -0.0, 0.0, 10.0]) + x = Reactant.to_rarray([-Inf, Inf, NaN, -10.0, -0.0, 0.0, 10.0]) @test [false, false, false, true, true, true, true] ≈ @jit Ops.is_finite(x) end @testset "log" begin - x = ConcreteRArray([1.0, 2.0, 3.0, 4.0]) + x = Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]) @test log.(Array(x)) ≈ @jit Ops.log(x) - x = ConcreteRArray([1.0 + 0.0im, 2.0 + 0.0im, -3.0 + 0.0im, -4.0 + 0.0im]) + x = Reactant.to_rarray([1.0 + 0.0im, 2.0 + 0.0im, -3.0 + 0.0im, -4.0 + 0.0im]) @test log.(Array(x)) ≈ @jit Ops.log(x) end @testset "log_plus_one" begin - x = ConcreteRArray([1.0, 2.0, 3.0, 4.0]) + x = Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]) @test log.(Array(x)) ≈ @jit Ops.log(x) - x = ConcreteRArray([1.0 + 0.0im, 2.0 + 0.0im, -3.0 + 0.0im, -4.0 + 0.0im]) + x = Reactant.to_rarray([1.0 + 0.0im, 2.0 + 0.0im, -3.0 + 0.0im, -4.0 + 0.0im]) @test log.(Array(x)) ≈ @jit Ops.log(x) end @testset "logistic" begin - x = ConcreteRArray([0.0, 1.0, 2.0, 3.0]) + x = Reactant.to_rarray([0.0, 1.0, 2.0, 3.0]) l(x) = 1 / (1 + exp(-x)) @test l.(Array(x)) ≈ @jit Ops.logistic(x) end @testset "maximum" begin - x = ConcreteRArray([false, false, true, true]) - y = ConcreteRArray([false, true, false, true]) + x = Reactant.to_rarray([false, false, true, true]) + y = Reactant.to_rarray([false, true, false, true]) @test [false, true, true, true] == @jit Ops.maximum(x, y) - x = ConcreteRArray([-1, 0, 1, 10]) - y = ConcreteRArray([10, 1, 0, -1]) + x = Reactant.to_rarray([-1, 0, 1, 10]) + y = Reactant.to_rarray([10, 1, 0, -1]) @test [10, 1, 1, 10] == @jit Ops.maximum(x, y) - x = ConcreteRArray([-1.0, 0.0, 1.0, 10.0]) - y = ConcreteRArray([10.0, 1.0, 0.0, -1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0, 10.0]) + y = Reactant.to_rarray([10.0, 1.0, 0.0, -1.0]) @test [10.0, 1.0, 1.0, 10.0] == @jit Ops.maximum(x, y) end @testset "minimum" begin - x = ConcreteRArray([false, false, true, true]) - y = ConcreteRArray([false, true, false, true]) + x = Reactant.to_rarray([false, false, true, true]) + y = Reactant.to_rarray([false, true, false, true]) @test [false, false, false, true] == @jit Ops.minimum(x, y) - x = ConcreteRArray([-1, 0, 1, 10]) - y = ConcreteRArray([10, 1, 0, -1]) + x = Reactant.to_rarray([-1, 0, 1, 10]) + y = Reactant.to_rarray([10, 1, 0, -1]) @test [-1, 0, 0, -1] == @jit Ops.minimum(x, y) - x = ConcreteRArray([-1.0, 0.0, 1.0, 10.0]) - y = ConcreteRArray([10.0, 1.0, 0.0, -1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0, 10.0]) + y = Reactant.to_rarray([10.0, 1.0, 0.0, -1.0]) @test [-1.0, 0.0, 0.0, -1.0] == @jit Ops.minimum(x, y) end @testset "multiply" begin - x = ConcreteRArray([false, false, true, true]) - y = ConcreteRArray([false, true, false, true]) + x = Reactant.to_rarray([false, false, true, true]) + y = Reactant.to_rarray([false, true, false, true]) @test [false, false, false, true] == @jit Ops.multiply(x, y) for (a, b) in [ - (ConcreteRArray([5, 6, -7, -8]), ConcreteRArray([1, 2, 3, 4])), - (ConcreteRArray([1.1, 2.2, 3.3, 4.4]), ConcreteRArray([5.5, 6.6, -7.7, -8.8])), + (Reactant.to_rarray([5, 6, -7, -8]), Reactant.to_rarray([1, 2, 3, 4])), ( - ConcreteRArray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]), - ConcreteRArray([ + Reactant.to_rarray([1.1, 2.2, 3.3, 4.4]), + Reactant.to_rarray([5.5, 6.6, -7.7, -8.8]), + ), + ( + Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]), + Reactant.to_rarray([ 9.9 + 10.10im, 11.11 + 12.12im, -13.13 + -14.14im, -15.15 + -16.16im ]), ), @@ -409,52 +431,52 @@ end end @testset "negate" begin - x = ConcreteRArray([-1, 0, 1, 10]) + x = Reactant.to_rarray([-1, 0, 1, 10]) @test [1, 0, -1, -10] == @jit Ops.negate(x) # on unsigned integers: (1) bitcast, (2) change sign and (3) bitcast - x = ConcreteRArray(UInt[0, 1, 10]) + x = Reactant.to_rarray(UInt[0, 1, 10]) @test reinterpret(UInt, Base.checked_neg.(reinterpret.(Int, Array(x)))) == @jit Ops.negate(x) - x = ConcreteRArray([-1.0, 0.0, 1.0, 10.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0, 10.0]) @test [1.0, 0.0, -1.0, -10.0] ≈ @jit Ops.negate(x) - x = ConcreteRArray([-1.0 + 2im, 0.0 - 3im, 1.0 + 4im, 10.0 - 5im]) + x = Reactant.to_rarray([-1.0 + 2im, 0.0 - 3im, 1.0 + 4im, 10.0 - 5im]) @test [1.0 - 2im, 0.0 + 3im, -1.0 - 4im, -10.0 + 5im] ≈ @jit Ops.negate(x) end @testset "not" begin - x = ConcreteRArray([false, true]) + x = Reactant.to_rarray([false, true]) @test [true, false] == @jit Ops.not(x) - x = ConcreteRArray([1, 0]) + x = Reactant.to_rarray([1, 0]) @test [~1, ~0] == @jit Ops.not(x) end @testset "optimization_barrier" begin # TODO is there a better way to test this? we're only testing for identify # TODO crashing for just 1 argument - x = ConcreteRArray([1, 2, 3, 4]) - y = ConcreteRArray([5, 6, -7, -8]) + x = Reactant.to_rarray([1, 2, 3, 4]) + y = Reactant.to_rarray([5, 6, -7, -8]) @test (x, y) == @jit Ops.optimization_barrier(x, y) end @testset "or" begin - a = ConcreteRArray([false, false, true, true]) - b = ConcreteRArray([false, true, false, true]) + a = Reactant.to_rarray([false, false, true, true]) + b = Reactant.to_rarray([false, true, false, true]) @test [false, true, true, true] ≈ @jit Ops.or(a, b) - a = ConcreteRArray([1, 2, 3, 4]) - b = ConcreteRArray([5, 6, -7, -8]) + a = Reactant.to_rarray([1, 2, 3, 4]) + b = Reactant.to_rarray([5, 6, -7, -8]) @test Array(a) .| Array(b) == @jit Ops.or(a, b) end @testset "outfeed" begin end @testset "pad" begin - x = ConcreteRArray([1, 2, 3, 4]) - v = ConcreteRNumber(0) + x = Reactant.to_rarray([1, 2, 3, 4]) + v = Reactant.to_rarray(0; track_numbers=true) flow(x, v) = Ops.pad(x, v; low=[1]) @test [0, 1, 2, 3, 4] == @jit flow(x, v) @@ -465,7 +487,7 @@ end finterior(x, v) = Ops.pad(x, v; interior=[1]) @test [1, 0, 2, 0, 3, 0, 4] == @jit finterior(x, v) - x = ConcreteRArray([1 2; 3 4]) + x = Reactant.to_rarray([1 2; 3 4]) glow(x, v) = Ops.pad(x, v; low=[1, 2]) @test [0 0 0 0; 0 0 1 2; 0 0 3 4] == @jit glow(x, v) @@ -478,28 +500,28 @@ end end @testset "partition_id" begin - @test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32} + @test @jit(Ops.partition_id()) isa ConcretePJRTNumber{UInt32} end @testset "popcnt" begin - x = ConcreteRArray([0, 1, 2, 127]) + x = Reactant.to_rarray([0, 1, 2, 127]) @test [0, 1, 1, 7] == @jit Ops.popcnt(x) end @testset "power" begin - x = ConcreteRArray([-1, -1, -1, -1]) - p = ConcreteRArray([0, 1, 2, 3]) + x = Reactant.to_rarray([-1, -1, -1, -1]) + p = Reactant.to_rarray([0, 1, 2, 3]) @test Array(x) .^ Array(p) == @jit Ops.power(x, p) if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray([0.0 + 1.0im, 0.0 + 1.0im, 0.0 + 1.0im, 0.0 + 1.0im]) - p = ConcreteRArray([0.0 + 0.0im, 1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 0.0im]) + x = Reactant.to_rarray([0.0 + 1.0im, 0.0 + 1.0im, 0.0 + 1.0im, 0.0 + 1.0im]) + p = Reactant.to_rarray([0.0 + 0.0im, 1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 0.0im]) @test Array(x) .^ Array(p) ≈ @jit Ops.power(x, p) end end @testset "real" begin - x = ConcreteRArray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) + x = Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) @test [1.1, 3.3, 5.5, 7.7] ≈ @jit Ops.real(x) end @@ -507,31 +529,34 @@ end @testset "remainder" begin for (a, b) in [ - (ConcreteRArray([1, 2, 3, 4]), ConcreteRArray([5, 6, -7, -8])), - (ConcreteRArray([1.1, 2.2, 3.3, 4.4]), ConcreteRArray([5.5, 6.6, -7.7, -8.8])), + (Reactant.to_rarray([1, 2, 3, 4]), Reactant.to_rarray([5, 6, -7, -8])), + ( + Reactant.to_rarray([1.1, 2.2, 3.3, 4.4]), + Reactant.to_rarray([5.5, 6.6, -7.7, -8.8]), + ), ] @test Array(a) .% Array(b) ≈ @jit Ops.remainder(a, b) end end @testset "replica_id" begin - @test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32} + @test @jit(Ops.partition_id()) isa ConcretePJRTNumber{UInt32} end @testset "reshape" begin - x = ConcreteRArray([1, 2, 3, 4]) + x = Reactant.to_rarray([1, 2, 3, 4]) @test reshape(Array(x), 2, 2) == @jit Ops.reshape(x, 2, 2) - x = ConcreteRArray(collect(reshape(1:12, 2, 2, 3))) + x = Reactant.to_rarray(collect(reshape(1:12, 2, 2, 3))) @test reshape(Array(x), 3, 1, 4) == @jit Ops.reshape(x, 3, 1, 4) end @testset "reverse" begin - x = ConcreteRArray([1, 2, 3, 4]) + x = Reactant.to_rarray([1, 2, 3, 4]) g1(x) = Ops.reverse(x; dimensions=[1]) @test [4, 3, 2, 1] == @jit g1(x) - x = ConcreteRArray([1 2; 3 4]) + x = Reactant.to_rarray([1 2; 3 4]) g2(x) = Ops.reverse(x; dimensions=[2]) @test [3 4; 1 2] == @jit g1(x) @test [2 1; 4 3] == @jit g2(x) @@ -546,7 +571,7 @@ end @testset for (alg, sz) in [("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)] - seed = ConcreteRArray(zeros(UInt64, sz)) + seed = Reactant.to_rarray(zeros(UInt64, sz)) res = @jit genInt32(seed) @test res.output_state !== seed @@ -585,48 +610,48 @@ end end @testset "round_nearest_afz" begin - x = ConcreteRArray([-2.5, 0.4, 0.5, 0.6, 2.5]) + x = Reactant.to_rarray([-2.5, 0.4, 0.5, 0.6, 2.5]) @test [-3.0, 0.0, 1.0, 1.0, 3.0] ≈ @jit Ops.round_nearest_afz(x) end @testset "round_nearest_even" begin - x = ConcreteRArray([-2.5, 0.4, 0.5, 0.6, 2.5]) + x = Reactant.to_rarray([-2.5, 0.4, 0.5, 0.6, 2.5]) @test [-2.0, 0.0, 0.0, 1.0, 2.0] ≈ @jit Ops.round_nearest_even(x) end @testset "rsqrt" begin - x = ConcreteRArray([1.0 4.0; 9.0 25.0]) + x = Reactant.to_rarray([1.0 4.0; 9.0 25.0]) @test 1 ./ sqrt.(Array(x)) ≈ @jit Ops.rsqrt(x) if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray([1.0+1im 4.0+2im; 9.0+3im 25.0+4im]) + x = Reactant.to_rarray([1.0+1im 4.0+2im; 9.0+3im 25.0+4im]) @test 1 ./ sqrt.(Array(x)) ≈ @jit Ops.rsqrt(x) end end @testset "select" begin - ontrue = ConcreteRArray([1, 2, 3, 4]) - onfalse = ConcreteRArray([5, 6, -7, -8]) + ontrue = Reactant.to_rarray([1, 2, 3, 4]) + onfalse = Reactant.to_rarray([5, 6, -7, -8]) - pred = ConcreteRArray([true, true, false, false]) + pred = Reactant.to_rarray([true, true, false, false]) @test [1, 2, -7, -8] == @jit Ops.select(pred, ontrue, onfalse) - pred = ConcreteRArray([false, false, true, true]) + pred = Reactant.to_rarray([false, false, true, true]) @test [5, 6, 3, 4] == @jit Ops.select(pred, ontrue, onfalse) - pred = ConcreteRNumber(true) + pred = Reactant.to_rarray(true; track_numbers=true) @test ontrue == @jit Ops.select(pred, ontrue, onfalse) - pred = ConcreteRNumber(false) + pred = Reactant.to_rarray(false; track_numbers=true) @test onfalse == @jit Ops.select(pred, ontrue, onfalse) - ontrue = ConcreteRNumber(1) - onfalse = ConcreteRNumber(2) + ontrue = Reactant.to_rarray(1; track_numbers=true) + onfalse = Reactant.to_rarray(2; track_numbers=true) - pred = ConcreteRNumber(true) + pred = Reactant.to_rarray(true; track_numbers=true) @test ontrue == @jit Ops.select(pred, ontrue, onfalse) - pred = ConcreteRNumber(false) + pred = Reactant.to_rarray(false; track_numbers=true) @test onfalse == @jit Ops.select(pred, ontrue, onfalse) end @@ -635,31 +660,31 @@ end @testset "set_dimension_size" begin end @testset "shift_left" begin - a = ConcreteRArray([-1, 0, 1]) - b = ConcreteRArray([1, 2, 3]) + a = Reactant.to_rarray([-1, 0, 1]) + b = Reactant.to_rarray([1, 2, 3]) @test [-2, 0, 8] == @jit Ops.shift_left(a, b) end @testset "shift_right_arithmetic" begin - a = ConcreteRArray([-1, 0, 8]) - b = ConcreteRArray([1, 2, 3]) + a = Reactant.to_rarray([-1, 0, 8]) + b = Reactant.to_rarray([1, 2, 3]) @test [-1, 0, 1] == @jit Ops.shift_right_arithmetic(a, b) end @testset "shift_right_logical" begin - a = ConcreteRArray([-1, 0, 8]) - b = ConcreteRArray([1, 2, 3]) + a = Reactant.to_rarray([-1, 0, 8]) + b = Reactant.to_rarray([1, 2, 3]) @test [9223372036854775807, 0, 1] == @jit Ops.shift_right_logical(a, b) end @testset "sign" begin - x = ConcreteRArray([-1, 0, 1]) + x = Reactant.to_rarray([-1, 0, 1]) @test [-1, 0, 1] == @jit Ops.sign(x) - x = ConcreteRArray([Inf, -Inf, NaN, -NaN, -1.0, -0.0, +0.0, 1.0]) + x = Reactant.to_rarray([Inf, -Inf, NaN, -NaN, -1.0, -0.0, +0.0, 1.0]) @test [1.0, -1.0, NaN, NaN, -1.0, -0.0, 0.0, 1.0] ≈ @jit(Ops.sign(x)) nans = true - x = ConcreteRArray([ + x = Reactant.to_rarray([ NaN + 1.0im, 1.0 + NaN, 0.0 + 0.0im, -1.0 + 2.0im, 0.0 - 3.0im, 1.0 + 4.0im ]) @test [ @@ -674,13 +699,13 @@ end @testset "sine" begin if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray([0, π / 2, π, 3π / 2, 2π]) + x = Reactant.to_rarray([0, π / 2, π, 3π / 2, 2π]) @test [0, 1, 0, -1, 0] ≈ @jit Ops.sine(x) - x = ConcreteRArray([0.0, π / 2, π, 3π / 2, 2π]) + x = Reactant.to_rarray([0.0, π / 2, π, 3π / 2, 2π]) @test [0.0, 1.0, 0.0, -1.0, 0.0] ≈ @jit Ops.sine(x) - x = ConcreteRArray([ + x = Reactant.to_rarray([ 0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im ]) @test [0.0 + 0.0im, 1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im] ≈ @@ -688,29 +713,45 @@ end end end +@testset "sort" begin + basic_sort(x, dimension) = only(Ops.sort(x; comparator=(a, b) -> a < b, dimension)) + @testset for i in 1:3 + t_size = tuple(fill(10, (i,))...) + x = randn(t_size) + xa = Reactant.to_rarray(x) + + @testset for j in 1:i + @test (i == 1 ? sort(x) : sort(x; dims=j)) == @jit basic_sort(xa, j) + end + end +end + @testset "slice" begin - x = ConcreteRArray([1, 2, 3, 4]) + x = Reactant.to_rarray([1, 2, 3, 4]) @test [2, 3] == @jit Ops.slice(x, [2], [3]) @test [1] == @jit Ops.slice(x, [1], [1]) end @testset "sqrt" begin - x = ConcreteRArray([1.0, 4.0, 9.0, 16.0]) + x = Reactant.to_rarray([1.0, 4.0, 9.0, 16.0]) @test [1.0, 2.0, 3.0, 4.0] ≈ @jit Ops.sqrt(x) if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray([1.0 + 0im, 0.0 + 1im]) + x = Reactant.to_rarray([1.0 + 0im, 0.0 + 1im]) @test [1.0 + 0im, 1 / √2 * (1 + im)] ≈ @jit Ops.sqrt(x) end end @testset "subtract" begin for (a, b) in [ - (ConcreteRArray([1, 2, 3, 4]), ConcreteRArray([5, 6, -7, -8])), - (ConcreteRArray([1.1, 2.2, 3.3, 4.4]), ConcreteRArray([5.5, 6.6, -7.7, -8.8])), + (Reactant.to_rarray([1, 2, 3, 4]), Reactant.to_rarray([5, 6, -7, -8])), ( - ConcreteRArray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]), - ConcreteRArray([ + Reactant.to_rarray([1.1, 2.2, 3.3, 4.4]), + Reactant.to_rarray([5.5, 6.6, -7.7, -8.8]), + ), + ( + Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]), + Reactant.to_rarray([ 9.9 + 10.10im, 11.11 + 12.12im, -13.13 + -14.14im, -15.15 + -16.16im ]), ), @@ -722,10 +763,10 @@ end @testset "tan" begin if !(Sys.isapple() && Sys.ARCH === :x86_64) # TODO tan(π/2) is Inf but it returns 1.633123935319537e16 - x = ConcreteRArray([0, π / 4, π / 2, 3π / 4, π]) + x = Reactant.to_rarray([0, π / 4, π / 2, 3π / 4, π]) @test [0.0, 1.0, 1.633123935319537e16, -1.0, 0.0] ≈ @jit Ops.tan(x) - x = ConcreteRArray([ + x = Reactant.to_rarray([ 0.0 + 0.0im, π / 4 + 0.0im, π / 2 + 0.0im, 3π / 4 + 0.0im, π + 0.0im ]) @test ComplexF64[0.0, 1.0, 1.633123935319537e16, -1.0, 0.0] ≈ @jit Ops.tan(x) @@ -733,17 +774,17 @@ end end @testset "tanh" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test [-0.7615941559557649, 0.0, 0.7615941559557649] ≈ @jit Ops.tanh(x) if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray(ComplexF64[-1.0, 0.0, 1.0]) + x = Reactant.to_rarray(ComplexF64[-1.0, 0.0, 1.0]) @test ComplexF64[-0.7615941559557649, 0.0, 0.7615941559557649] ≈ @jit Ops.tanh(x) end end @testset "transpose" begin - x = ConcreteRArray(collect(reshape(1:12, 2, 2, 3))) + x = Reactant.to_rarray(collect(reshape(1:12, 2, 2, 3))) @test [ 1 3; 5 7; 9 11;;; 2 4; 6 8; 10 12 @@ -759,10 +800,10 @@ end # f5(a) = Ops.unary_einsum(a; equation="ij->i") # f6(a) = Ops.unary_einsum(a; equation="ii->i") -# x = ConcreteRArray([1, 2, 3, 4]) +# x = Reactant.to_rarray([1, 2, 3, 4]) # @test sum(Array(x)) ≈ @jit f1(x) -# x = ConcreteRArray([1 2; 3 4]) +# x = Reactant.to_rarray([1 2; 3 4]) # @test sum(Array(x)) ≈ @jit f4(x) # @test Base.transpose(Array(x)) ≈ @jit f3(x) # @test sum(Array(x); dims=1) ≈ @jit f4(x) @@ -771,57 +812,57 @@ end # end @testset "xor" begin - a = ConcreteRArray([false, false, true, true]) - b = ConcreteRArray([false, true, false, true]) + a = Reactant.to_rarray([false, false, true, true]) + b = Reactant.to_rarray([false, true, false, true]) @test [false, true, true, false] ≈ @jit Ops.xor(a, b) - a = ConcreteRArray([1, 2, 3, 4]) - b = ConcreteRArray([5, 6, -7, -8]) + a = Reactant.to_rarray([1, 2, 3, 4]) + b = Reactant.to_rarray([5, 6, -7, -8]) @test Array(a) .⊻ Array(b) == @jit Ops.xor(a, b) end @testset "acos" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test acos.(Array(x)) ≈ @jit Ops.acos(x) end @testset "acosh" begin - x = ConcreteRArray([1.0, 10.0]) + x = Reactant.to_rarray([1.0, 10.0]) @test acosh.(Array(x)) ≈ @jit Ops.acosh(x) end @testset "asin" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test asin.(Array(x)) ≈ @jit Ops.asin(x) end @testset "asinh" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test asinh.(Array(x)) ≈ @jit Ops.asinh(x) end @testset "atan" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test atan.(Array(x)) ≈ @jit Ops.atan(x) end @testset "atanh" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test atanh.(Array(x)) ≈ @jit Ops.atanh(x) end @testset "bessel_i1e" begin - x = ConcreteRArray([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]) @test SpecialFunctions.besselix.(1, Array(x)) ≈ @jit Ops.bessel_i1e(x) end @testset "conj" begin - x = ConcreteRArray([-1.0 + 2im, 0.0 - 1im, 1.0 + 4im]) + x = Reactant.to_rarray([-1.0 + 2im, 0.0 - 1im, 1.0 + 4im]) @test conj(Array(x)) ≈ @jit Ops.conj(x) end @testset "cosh" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test cosh.(Array(x)) ≈ @jit Ops.cosh(x) end @@ -829,51 +870,52 @@ end # small divergence between chlo.digamma and SpecialFunctions.digamma: # on <=0, chlo.digamma returns NaN, SpecialFunctions.digamma returns Inf if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test [NaN, NaN, SpecialFunctions.digamma(1.0)] ≈ @jit(Ops.digamma(x)) nans = true end end @testset "erf_inv" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test SpecialFunctions.erfinv.(Array(x)) ≈ @jit Ops.erf_inv(x) end @testset "erf" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test SpecialFunctions.erf.(Array(x)) ≈ @jit Ops.erf(x) end @testset "erfc" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test SpecialFunctions.erfc.(Array(x)) ≈ @jit Ops.erfc(x) end @testset "is_inf" begin - x = ConcreteRArray([-Inf, Inf, NaN, -10.0, -0.0, 0.0, 10.0]) + x = Reactant.to_rarray([-Inf, Inf, NaN, -10.0, -0.0, 0.0, 10.0]) @test [true, true, false, false, false, false, false] ≈ @jit Ops.is_inf(x) end @testset "is_neg_inf" begin - x = ConcreteRArray([-Inf, Inf, NaN, -10.0, -0.0, 0.0, 10.0]) + x = Reactant.to_rarray([-Inf, Inf, NaN, -10.0, -0.0, 0.0, 10.0]) @test [true, false, false, false, false, false, false] ≈ @jit Ops.is_neg_inf(x) end @testset "is_pos_inf" begin - x = ConcreteRArray([-Inf, Inf, NaN, -10.0, -0.0, 0.0, 10.0]) + x = Reactant.to_rarray([-Inf, Inf, NaN, -10.0, -0.0, 0.0, 10.0]) @test [false, true, false, false, false, false, false] ≈ @jit Ops.is_pos_inf(x) end @testset "lgamma" begin if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray([-1.0, 0.0, 1.0, 2.5]) - @test SpecialFunctions.lgamma.(Array(x)) ≈ @jit Ops.lgamma(x) + x = Reactant.to_rarray([-1.0, 0.0, 1.0, 2.5]) + lgamma(x) = (SpecialFunctions.logabsgamma(x))[1] + @test lgamma.(Array(x)) ≈ @jit Ops.lgamma(x) end end @testset "next_after" begin - x = ConcreteRArray([-1.0, 0.0, 1.0, 1.0, 2.5, 1e18, 1e18, 3e-9, 3e-9]) - y = ConcreteRArray([-2.0, 0.0, 1.0, 2.0, 3.0, 0.0, 1e19, 0, 1]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0, 1.0, 2.5, 1e18, 1e18, 3e-9, 3e-9]) + y = Reactant.to_rarray([-2.0, 0.0, 1.0, 2.0, 3.0, 0.0, 1e19, 0, 1]) @test [ prevfloat(-1.0), 0.0, @@ -889,26 +931,33 @@ end @testset "polygamma" begin if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = ConcreteRArray([-1.0, 0.0, 1.0, 1.0, 2.5]) - m = ConcreteRArray([3.0, 3.0, 2.0, 3.0, 4.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0, 1.0, 2.5]) + m = Reactant.to_rarray([3.0, 3.0, 2.0, 3.0, 4.0]) @test SpecialFunctions.polygamma.(Int.(Array(m)), Array(x)) ≈ @jit Ops.polygamma(m, x) end end @testset "sinh" begin - x = ConcreteRArray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test sinh.(Array(x)) ≈ @jit Ops.sinh(x) end @testset "top_k" begin - x = ConcreteRArray([1, 2, 3, 4]) - @test (; values=[4, 3], indices=[3, 2]) == @jit Ops.top_k(x, 2) + x = Reactant.to_rarray([1, 2, 3, 4]) + @test (; values=[4, 3], indices=[4, 3]) == @jit Ops.top_k(x, 2) + + x = Reactant.to_rarray([NaN, 123, 456, 789, 121]) + res = @jit Ops.top_k(x, 2) + true_res = (; values=[NaN, 789], indices=[1, 4]) + @test res.indices == true_res.indices + @test @allowscalar isnan(res.values[1]) + @test @allowscalar res.values[2] == 789 end @testset "zeta" begin - s = ConcreteRArray([1.0, 2.0, 50.0]) - z = ConcreteRArray([1e-8, 0.001, 2.0]) + s = Reactant.to_rarray([1.0, 2.0, 50.0]) + z = Reactant.to_rarray([1e-8, 0.001, 2.0]) @test SpecialFunctions.zeta.(Array(s), Array(z)) ≈ @jit Ops.zeta(s, z) end diff --git a/test/runtests.jl b/test/runtests.jl index 2b3238d101..1e848d1715 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,11 +56,25 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Shortcuts to MLIR ops" include("ops.jl") @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") + @safetestset "Sorting" include("sorting.jl") + @safetestset "Indexing" include("indexing.jl") + if !Sys.isapple() + @safetestset "Custom Number Types" include("custom_number_types.jl") + end + @safetestset "Sharding" include("sharding.jl") + + @testset "IFRT" begin + @safetestset "IFRT Low-Level API" include("ifrt/low_level.jl") + end end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" + @safetestset "CUDA" include("integration/cuda.jl") + @safetestset "KernelAbstractions" include("integration/kernelabstractions.jl") @safetestset "Linear Algebra" include("integration/linear_algebra.jl") + @safetestset "OffsetArrays" include("integration/offsetarrays.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") + @safetestset "SpecialFunctions" include("integration/special_functions.jl") @safetestset "Random" include("integration/random.jl") @safetestset "Python" include("integration/python.jl") end diff --git a/test/sharding.jl b/test/sharding.jl new file mode 100644 index 0000000000..5acd9c82c9 --- /dev/null +++ b/test/sharding.jl @@ -0,0 +1,265 @@ +# Currently an extremely simple test +using Reactant, Test + +const addressable_devices = Reactant.addressable_devices() + +function fn_test1(x) + y = x .+ x + x .+= 1 + z = x .* y + return y, x, z +end + +@testset "Sharding Across 2 Devices" begin + if length(addressable_devices) ≥ 2 + mesh = Sharding.Mesh([0 1;], ("x", "y")) + + data_sharding = Sharding.NamedSharding(mesh, ("y", nothing, "x")) + data_sharding2 = Sharding.NamedSharding(mesh, (nothing, "x", nothing)) + data_sharding3 = Sharding.NamedSharding(mesh, (nothing, nothing, nothing)) # fully replicated data + + data = reshape(collect(1:(16 * 4 * 12)) ./ (16 * 4 * 12), 16, 4, 12) + + cdata = Reactant.to_rarray(data) + cdata_sharded = Reactant.to_rarray(data; sharding=data_sharding) + cdata_sharded2 = Reactant.to_rarray(data; sharding=data_sharding2) + cdata_sharded3 = Reactant.to_rarray(data; sharding=data_sharding3) + + @test data ≈ + Array(cdata) ≈ + Array(cdata_sharded) ≈ + Array(cdata_sharded2) ≈ + Array(cdata_sharded3) + + @test cdata_sharded.sharding isa Sharding.ShardInfo{<:Sharding.HloSharding} + @test cdata_sharded2.sharding isa Sharding.ShardInfo{<:Sharding.HloSharding} + @test cdata_sharded3.sharding isa Sharding.ShardInfo{<:Sharding.HloSharding} + @test cdata.sharding isa Sharding.NoShardInfo + + true_res_y, true_res_x, true_res_z = fn_test1(data) + + for cd in (cdata, cdata_sharded, cdata_sharded2, cdata_sharded3) + local res_y, res_x, res_z = @jit fn_test1(cd) + @test Array(cd) ≈ Array(res_x) + @test Array(res_y) ≈ true_res_y + @test Array(res_z) ≈ true_res_z + @test Array(res_x) ≈ true_res_x + end + else + @warn "Not enough addressable devices to run sharding tests" + end +end + +predict(samples, w1, w2) = sin.(w2 * (w1 * tanh.(samples))) + +fn_test2(x) = x .+ x' + +fn_test3(x) = sum(x; dims=1) + +@testset "Sharding Across 8 Devices" begin + if length(addressable_devices) ≥ 8 + mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), (4, 2)), ("data", "model")) + + x = reshape(collect(Float32, 1:16), 4, 4) + x_ra = Reactant.to_rarray( + x; sharding=Sharding.NamedSharding(mesh, ("data", "model")) + ) + + @test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x) + + y_ra = @jit fn_test2(x_ra) + @test Array(@jit fn_test2(y_ra)) ≈ fn_test2(fn_test2(x)) + + @test Array(@jit fn_test3(x_ra)) ≈ fn_test3(x) + + samples = reshape(collect(Float32, 1:48), 4, 12) + w1 = reshape(collect(Float32, 1:16), 4, 4) + w2 = reshape(collect(Float32, 1:32), 8, 4) + + for (samples_sharding, w1_sharding, w2_sharding) in zip( + ( + Sharding.NamedSharding(mesh, ("model", "data")), + Sharding.NamedSharding(mesh, ("model", nothing)), + Sharding.NamedSharding(mesh, (nothing, "data")), + Sharding.DimsSharding(mesh, (2,), (:data,)), + ), + ( + Sharding.NamedSharding(mesh, ("model", "data")), + Sharding.NamedSharding(mesh, (nothing, "data")), + Sharding.NoSharding(), + Sharding.DimsSharding(mesh, (-2,), (:model,)), + ), + ( + Sharding.NamedSharding(mesh, ("model", "data")), + Sharding.NoSharding(), + Sharding.NoSharding(), + Sharding.NamedSharding(mesh, ("model", "data")), + ), + ) + samples_ra = Reactant.to_rarray(samples; sharding=samples_sharding) + w1_ra = Reactant.to_rarray(w1; sharding=w1_sharding) + w2_ra = Reactant.to_rarray(w2; sharding=w2_sharding) + + @test Array(@jit(predict(samples_ra, w1_ra, w2_ra))) ≈ predict(samples, w1, w2) + end + end +end + +@testset "Sharding with non-iota mesh" begin + if length(addressable_devices) ≥ 8 + mesh = Sharding.Mesh(reshape([4, 6, 0, 2, 7, 3, 1, 5], 4, 2), ("data", "model")) + x = reshape(collect(Float32, 1:16), 4, 4) + x_ra = Reactant.to_rarray( + x; sharding=Sharding.NamedSharding(mesh, ("data", "model")) + ) + @test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x) + @test Reactant.to_number(@jit sum(x_ra)) ≈ sum(x) + else + @warn "Not enough addressable devices to run sharding tests" + end +end + +@testset "Multiple Axis Partition Spec" begin + if length(addressable_devices) ≥ 8 + mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), 2, 4), ("data", "model")) + x = reshape(collect(Float32, 1:64), 8, 8) + x_ra = Reactant.to_rarray( + x; sharding=Sharding.NamedSharding(mesh, (("data", "model"), nothing)) + ) + @test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x) + @test Reactant.to_number(@jit sum(x_ra)) ≈ sum(x) + else + @warn "Not enough addressable devices to run sharding tests" + end +end + +@testset "Open Axis Partition Spec" begin + if length(addressable_devices) ≥ 8 + mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), 2, 4), ("data", "model")) + x = reshape(collect(Float32, 1:16), 4, 4) + x_ra = Reactant.to_rarray( + x; + sharding=Sharding.NamedSharding( + mesh, ("model", nothing); is_closed=(false, false) + ), + ) + @test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x) + @test Reactant.to_number(@jit sum(x_ra)) ≈ sum(x) + else + @warn "Not enough addressable devices to run sharding tests" + end +end + +fn_test4(x, y) = x .+ sin.(y') + +@testset "Multiple Mesh Sharding" begin + if length(addressable_devices) ≥ 8 + mesh1 = Sharding.Mesh(reshape(collect(Int64, 0:7), (4, 2)), ("m1_x", "m1_y")) + mesh2 = Sharding.Mesh( + reshape([4, 6, 0, 2, 7, 3, 1, 5], 2, 2, 2), ("m2_x", "m2_y", "m2_z") + ) + + x = reshape(collect(Float32, 1:32), 8, 4) + y = reshape(collect(Float32, 1:32), 4, 8) + + x_ra = Reactant.to_rarray( + x; sharding=Sharding.NamedSharding(mesh1, ("m1_y", "m1_x")) + ) + y_ra = Reactant.to_rarray( + y; sharding=Sharding.NamedSharding(mesh2, ("m2_y", nothing)) + ) + + # This is supported in shardy & XLA, but we don't support it yet. + @test_throws ErrorException @jit fn_test4(x_ra, y_ra) + else + @warn "Not enough addressable devices to run sharding tests" + end +end + +@testset "Sharding Constraint" begin + if length(addressable_devices) ≥ 8 + mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), 2, 4), ("data", "model")) + + x = reshape(collect(Float32, 1:16), 4, 4) + x_ra = Reactant.to_rarray( + x; sharding=Sharding.NamedSharding(mesh, ("data", "model")) + ) + + constraint = Sharding.NamedSharding(mesh, ("model", nothing)) + + function fn_with_constraint(x) + y = x .+ x + return Reactant.Ops.sharding_constraint(y, constraint) + end + + hlo = @code_hlo fn_with_constraint(x_ra) + @test contains(repr(hlo), "sharding_constraint") + + z = Reactant.to_rarray(x; sharding=constraint) + res = @jit fn_with_constraint(x_ra) + + @test x .+ x ≈ Array(res) + @test string(z.sharding.sharding.hlo_sharding) == + string(res.sharding.sharding.hlo_sharding) + @test string(res.sharding.sharding.hlo_sharding) != + string(x_ra.sharding.sharding.hlo_sharding) + + # Test we can compile even when there is an intermediate sharding + x_ra_no_sharding = Reactant.to_rarray(x) + + hlo = @code_hlo fn_with_constraint(x_ra_no_sharding) + @test contains(repr(hlo), "sharding_constraint") + + res = @jit fn_with_constraint(x_ra_no_sharding) + @test x .+ x ≈ Array(res) + @test string(z.sharding.sharding.hlo_sharding) == + string(res.sharding.sharding.hlo_sharding) + @test string(res.sharding.sharding.hlo_sharding) != + string(x_ra_no_sharding.sharding.sharding.hlo_sharding) + else + @warn "Not enough addressable devices to run sharding tests" + end +end + +@testset "Sharding with non-divisible axes sizes" begin + if length(Reactant.addressable_devices()) ≥ 8 + mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), 2, 4), ("data", "model")) + x = reshape(collect(Float32, 1:14), 2, 7) + x_ra = Reactant.to_rarray( + x; sharding=Sharding.NamedSharding(mesh, ("data", "model")) + ) + + @test Array(@jit sum(x_ra; dims=2)) ≈ sum(x; dims=2) + + x = reshape(collect(Float32, 1:25), 5, 5) + x_ra = Reactant.to_rarray( + x; sharding=Sharding.NamedSharding(mesh, ("data", "model")) + ) + + @test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x) + else + @warn "Not enough addressable devices to run sharding tests" + end +end + +# Tests from the examples in +# https://github.com/openxla/xla/blob/96d6678053d867099a42be9001c49b2ed7111afd/xla/hlo/ir/tile_assignment.h#L53-L68 +@testset "Device List from Iota Tile" begin + @test Reactant.XLA.generate_device_list_from_iota_tile( + [4, 4, 1], #=tile_assignment_dimensions=# + [4, 2, 2], #=iota_reshape_dims=# + [1, 2, 3], #=iota_transpose_perm=# + ) == collect(0:15) + + @test Reactant.XLA.generate_device_list_from_iota_tile( + [4, 4, 1], #=tile_assignment_dimensions=# + [4, 2, 2], #=iota_reshape_dims=# + [2, 1, 3], #=iota_transpose_perm=# + ) == [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15] + + @test Reactant.XLA.generate_device_list_from_iota_tile( + [2, 4], #=tile_assignment_dimensions=# + [4, 2], #=iota_reshape_dims=# + [2, 1], #=iota_transpose_perm=# + ) == [0, 2, 4, 6, 1, 3, 5, 7] +end diff --git a/test/sorting.jl b/test/sorting.jl new file mode 100644 index 0000000000..d54699fa57 --- /dev/null +++ b/test/sorting.jl @@ -0,0 +1,190 @@ +using Reactant, Test + +@testset "sort & sortperm" begin + x = randn(10) + x_ra = Reactant.to_rarray(x) + + srt_rev(x) = sort(x; rev=true) + srtperm_rev(x) = sortperm(x; rev=true) + srt_by(x) = sort(x; by=abs2) + srtperm_by(x) = sortperm(x; by=abs2) + srt_lt(x) = sort(x; lt=(a, b) -> a > b) + srtperm_lt(x) = sortperm(x; lt=(a, b) -> a > b) + + @test @jit(sort(x_ra)) == sort(x) + @test @jit(srt_rev(x_ra)) == srt_rev(x) + @test @jit(srt_lt(x_ra)) == srt_lt(x) + @test @jit(srt_by(x_ra)) == srt_by(x) + @test @jit(sortperm(x_ra)) == sortperm(x) + @test @jit(srtperm_rev(x_ra)) == srtperm_rev(x) + @test @jit(srtperm_lt(x_ra)) == srtperm_lt(x) + @test @jit(srtperm_by(x_ra)) == srtperm_by(x) + + x = rand(10) + x_ra = Reactant.to_rarray(x) + @jit sort!(x_ra) + @test x_ra == sort(x) + + x = rand(10) + x_ra = Reactant.to_rarray(x) + ix = similar(x_ra, Int) + @jit sortperm!(ix, x_ra) + @test ix == sortperm(x) + + x = rand(10, 4, 3) + x_ra = Reactant.to_rarray(x) + + srt(x, d) = sort(x; dims=d) + srt_rev(x, d) = sort(x; dims=d, rev=true) + srt_by(x, d) = sort(x; dims=d, by=abs2) + srt_lt(x, d) = sort(x; dims=d, lt=(a, b) -> a > b) + srtperm(x, d) = sortperm(x; dims=d) + srtperm_rev(x, d) = sortperm(x; dims=d, rev=true) + srtperm_by(x, d) = sortperm(x; dims=d, by=abs2) + srtperm_lt(x, d) = sortperm(x; dims=d, lt=(a, b) -> a > b) + + @testset for d in 1:ndims(x) + @test @jit(srt(x_ra, d)) == srt(x, d) + @test @jit(srtperm(x_ra, d)) == srtperm(x, d) + @test @jit(srt_rev(x_ra, d)) == srt_rev(x, d) + @test @jit(srtperm_rev(x_ra, d)) == srtperm_rev(x, d) + @test @jit(srt_by(x_ra, d)) == srt_by(x, d) + @test @jit(srtperm_by(x_ra, d)) == srtperm_by(x, d) + @test @jit(srt_lt(x_ra, d)) == srt_lt(x, d) + @test @jit(srtperm_lt(x_ra, d)) == srtperm_lt(x, d) + end +end + +@testset "partialsort & partialsortperm" begin + x = randn(10) + x_ra = Reactant.to_rarray(x) + + @test @jit(partialsort(x_ra, 1:5)) == partialsort(x, 1:5) + @test @jit(partialsortperm(x_ra, 1:5)) == partialsortperm(x, 1:5) + @test @jit(partialsort(x_ra, 4)) == partialsort(x, 4) + @test @jit(partialsortperm(x_ra, 4)) == partialsortperm(x, 4) + + psrt_rev(x, k) = partialsort(x, k; rev=true) + psrtperm_rev(x, k) = partialsortperm(x, k; rev=true) + psrt_by(x, k) = partialsort(x, k; by=abs2) + psrtperm_by(x, k) = partialsortperm(x, k; by=abs2) + psrt_lt(x, k) = partialsort(x, k; lt=(a, b) -> a > b) + psrtperm_lt(x, k) = partialsortperm(x, k; lt=(a, b) -> a > b) + + @test @jit(psrt_rev(x_ra, 1:5)) == psrt_rev(x, 1:5) + @test @jit(psrtperm_rev(x_ra, 1:5)) == psrtperm_rev(x, 1:5) + @test @jit(psrt_by(x_ra, 1:5)) == psrt_by(x, 1:5) + @test @jit(psrtperm_by(x_ra, 1:5)) == psrtperm_by(x, 1:5) + @test @jit(psrt_lt(x_ra, 1:5)) == psrt_lt(x, 1:5) + @test @jit(psrtperm_lt(x_ra, 1:5)) == psrtperm_lt(x, 1:5) + + x = randn(10) + x_ra = Reactant.to_rarray(x) + @jit partialsort!(x_ra, 1:5) + partialsort!(x, 1:5) + @test Array(x_ra)[1:5] == x[1:5] + + x = randn(10) + x_ra = Reactant.to_rarray(x) + @jit partialsort!(x_ra, 3) + partialsort!(x, 3) + @test @allowscalar(x_ra[3]) == x[3] + + x = randn(10) + x_ra = Reactant.to_rarray(x) + + ix = similar(x, Int) + ix_ra = Reactant.to_rarray(ix) + @jit partialsortperm!(ix_ra, x_ra, 1:5) + partialsortperm!(ix, x, 1:5) + @test Array(ix_ra)[1:5] == ix[1:5] + + ix = similar(x, Int) + ix_ra = Reactant.to_rarray(ix) + @jit partialsortperm!(ix_ra, x_ra, 3) + partialsortperm!(ix, x, 3) + @test @allowscalar(ix_ra[3]) == ix[3] +end + +@testset "argmin / argmax" begin + x = rand(2, 3) + x_ra = Reactant.to_rarray(x) + + linargmin(x) = LinearIndices(x)[argmin(x)] + linargmax(x) = LinearIndices(x)[argmax(x)] + + @test linargmin(x) == @jit(argmin(x_ra)) + @test linargmax(x) == @jit(argmax(x_ra)) + + x = rand(2, 3, 4) + x_ra = Reactant.to_rarray(x) + + linargmin(x, d) = LinearIndices(x)[argmin(x; dims=d)] + linargmax(x, d) = LinearIndices(x)[argmax(x; dims=d)] + argmindims(x, d) = argmin(x; dims=d) + argmaxdims(x, d) = argmax(x; dims=d) + + @test linargmin(x, 1) == @jit(argmindims(x_ra, 1)) + @test linargmax(x, 1) == @jit(argmaxdims(x_ra, 1)) + @test linargmin(x, 2) == @jit(argmindims(x_ra, 2)) + @test linargmax(x, 2) == @jit(argmaxdims(x_ra, 2)) + @test linargmin(x, 3) == @jit(argmindims(x_ra, 3)) + @test linargmax(x, 3) == @jit(argmaxdims(x_ra, 3)) + + x = randn(2, 3, 4) + x_ra = Reactant.to_rarray(x) + + @test argmin(abs2, x) == @jit(argmin(abs2, x_ra)) + @test argmax(abs2, x) == @jit(argmax(abs2, x_ra)) +end + +@testset "findmin / findmax" begin + xvec = randn(10) + xvec_ra = Reactant.to_rarray(xvec) + + x = randn(2, 3) + x_ra = Reactant.to_rarray(x) + + function fwithlinindices(g, f, x; kwargs...) + values, indices = g(f, x; kwargs...) + return values, LinearIndices(x)[indices] + end + + @test fwithlinindices(findmin, identity, x) == @jit(findmin(x_ra)) + @test fwithlinindices(findmax, identity, x) == @jit(findmax(x_ra)) + @test fwithlinindices(findmin, identity, xvec) == @jit(findmin(xvec_ra)) + @test fwithlinindices(findmax, identity, xvec) == @jit(findmax(xvec_ra)) + + fmindims(x, d) = findmin(x; dims=d) + fmindims(f, x, d) = findmin(f, x; dims=d) + fmaxdims(x, d) = findmax(x; dims=d) + fmaxdims(f, x, d) = findmax(f, x; dims=d) + + @test fwithlinindices(findmin, identity, x; dims=1) == @jit(fmindims(x_ra, 1)) + @test fwithlinindices(findmax, identity, x; dims=1) == @jit(fmaxdims(x_ra, 1)) + @test fwithlinindices(findmin, identity, x; dims=2) == @jit(fmindims(x_ra, 2)) + @test fwithlinindices(findmax, identity, x; dims=2) == @jit(fmaxdims(x_ra, 2)) + @test fwithlinindices(findmin, abs2, x; dims=1) == @jit(fmindims(abs2, x_ra, 1)) + @test fwithlinindices(findmax, abs2, x; dims=1) == @jit(fmaxdims(abs2, x_ra, 1)) + @test fwithlinindices(findmin, abs2, x; dims=2) == @jit(fmindims(abs2, x_ra, 2)) + @test fwithlinindices(findmax, abs2, x; dims=2) == @jit(fmaxdims(abs2, x_ra, 2)) +end + +@testset "findfirst / findlast" begin + x = rand(Bool, 3, 4) + x_ra = Reactant.to_rarray(x) + + ffirstlinindices(x) = LinearIndices(x)[findfirst(x)] + ffirstlinindices(f, x) = LinearIndices(x)[findfirst(f, x)] + flastlinindices(x) = LinearIndices(x)[findlast(x)] + flastlinindices(f, x) = LinearIndices(x)[findlast(f, x)] + + @test ffirstlinindices(x) == @jit(findfirst(x_ra)) + @test flastlinindices(x) == @jit(findlast(x_ra)) + + x = rand(1:256, 3, 4) + x_ra = Reactant.to_rarray(x) + + @test ffirstlinindices(iseven, x) == @jit(findfirst(iseven, x_ra)) + @test flastlinindices(iseven, x) == @jit(findlast(iseven, x_ra)) +end diff --git a/test/struct.jl b/test/struct.jl index 82263a61fb..a027cbf8e7 100644 --- a/test/struct.jl +++ b/test/struct.jl @@ -1,5 +1,6 @@ using Reactant using Test +using Adapt # from bsc-quantic/Tenet.jl struct MockTensor{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} @@ -11,6 +12,8 @@ MockTensor(data::A, inds) where {T,N,A<:AbstractArray{T,N}} = MockTensor{T,N,A}( Base.parent(t::MockTensor) = t.data Base.size(t::MockTensor) = size(parent(t)) +Adapt.parent_type(::Type{MockTensor{T,N,A}}) where {T,N,A} = A + Base.cos(x::MockTensor) = MockTensor(cos.(parent(x)), x.inds) bcast_cos(x::MockTensor) = cos(x) @@ -19,12 +22,26 @@ mutable struct MutableMockTensor{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N inds::Vector{Symbol} end +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(A::Type{<:MockTensor}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type) +) + T2 = Reactant.traced_type_inner(A.parameters[3], seen, mode, track_numbers) + MT = MockTensor{eltype(T2),ndims(A),T2} + return MT +end + function MutableMockTensor(data::A, inds) where {T,N,A<:AbstractArray{T,N}} return MutableMockTensor{T,N,A}(data, inds) end + Base.parent(t::MutableMockTensor) = t.data Base.size(t::MutableMockTensor) = size(parent(t)) +Adapt.parent_type(::Type{MutableMockTensor{T,N,A}}) where {T,N,A} = A + Base.cos(x::MutableMockTensor) = MutableMockTensor(cos.(parent(x)), x.inds) bcast_cos(x) = cos.(x) @@ -57,11 +74,12 @@ end @testset "MockTensor" begin @testset "immutable" begin x = MockTensor(rand(4, 4), [:i, :j]) - x2 = MockTensor(Reactant.ConcreteRArray(parent(x)), x.inds) + x2 = MockTensor(ConcretePJRTArray(parent(x)), x.inds) y = @jit(bcast_cos(x2)) - @test y isa MockTensor{Float64,2,Reactant.ConcreteRArray{Float64,2}} + @test y isa + MockTensor{Float64,2,ConcretePJRTArray{Float64,2,1,Sharding.NoShardInfo}} @test size(y) == (4, 4) @test isapprox(parent(y), bcast_cos(parent(x))) @test x.inds == [:i, :j] @@ -69,11 +87,13 @@ end @testset "mutable" begin x = MutableMockTensor(rand(4, 4), [:i, :j]) - x2 = MutableMockTensor(Reactant.ConcreteRArray(parent(x)), x.inds) + x2 = MutableMockTensor(ConcretePJRTArray(parent(x)), x.inds) y = @jit(bcast_cos(x2)) - @test y isa MutableMockTensor{Float64,2,Reactant.ConcreteRArray{Float64,2}} + @test y isa MutableMockTensor{ + Float64,2,ConcretePJRTArray{Float64,2,1,Sharding.NoShardInfo} + } @test size(y) == (4, 4) @test isapprox(parent(y), bcast_cos(parent(x))) @test x.inds == [:i, :j] diff --git a/test/tracing.jl b/test/tracing.jl index 3b73212dc7..2990b4f06c 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -1,98 +1,215 @@ using Reactant using Reactant: traced_type, - ConcreteRArray, TracedRArray, + TracedRNumber, ConcreteToTraced, ArrayToConcrete, NoFieldMatchError, - TracedTypeError + TracedTypeError, + ReactantPrimitive using Test +struct Wrapper{A,B} + a::A + b::B +end + +struct Descent{T} + eta::T +end + +struct RMSProp{Teta,Trho,Teps,C<:Bool} + eta::Teta + rho::Trho + epsilon::Teps + centred::C +end + @testset "Tracing" begin @testset "trace_type" begin @testset "mode = ConcreteToTraced" begin - @testset "$origty" for (origty, targetty) in [ - (Any, Any), - (Module, Module), - (DataType, DataType), + @testset "$origty" for (origty, targetty, targettynum) in [ + (Any, Any, Any), + (Real, Real, Real), + (Module, Module, Module), + (DataType, DataType, DataType), # (Union{}, Union{}), # fails - (Nothing, Nothing), - (Symbol, Symbol), - (Char, Char), - (AbstractString, AbstractString), - (String, String), + (Nothing, Nothing, Nothing), + (Symbol, Symbol, Symbol), + (Char, Char, Char), + (AbstractString, AbstractString, AbstractString), + (String, String, String), + (VersionNumber, VersionNumber, VersionNumber), # Numeric types - (AbstractFloat, AbstractFloat), - (Float16, Float16), - (Float32, Float32), - (Float64, Float64), - (Integer, Integer), - (Int8, Int8), - (Int16, Int16), - (Int32, Int32), - (Int64, Int64), - (Int128, Int128), - (UInt8, UInt8), - (UInt16, UInt16), - (UInt32, UInt32), - (UInt64, UInt64), - (UInt128, UInt128), - (Complex{Float16}, Complex{Float16}), - (Complex{Float32}, Complex{Float32}), - (Complex{Float64}, Complex{Float64}), - (Complex{Int8}, Complex{Int8}), - (Complex{Int16}, Complex{Int16}), - (Complex{Int32}, Complex{Int32}), - (Complex{Int64}, Complex{Int64}), - (Complex{Int128}, Complex{Int128}), - (Complex{UInt8}, Complex{UInt8}), - (Complex{UInt16}, Complex{UInt16}), - (Complex{UInt32}, Complex{UInt32}), - (Complex{UInt64}, Complex{UInt64}), - (Complex{UInt128}, Complex{UInt128}), + (AbstractFloat, AbstractFloat, AbstractFloat), + (Float16, Float16, TracedRNumber{Float16}), + (Float32, Float32, TracedRNumber{Float32}), + (Float64, Float64, TracedRNumber{Float64}), + (Integer, Integer, Integer), + (Int8, Int8, TracedRNumber{Int8}), + (Int16, Int16, TracedRNumber{Int16}), + (Int32, Int32, TracedRNumber{Int32}), + (Int64, Int64, TracedRNumber{Int64}), + (UInt8, UInt8, TracedRNumber{UInt8}), + (UInt16, UInt16, TracedRNumber{UInt16}), + (UInt32, UInt32, TracedRNumber{UInt32}), + (UInt64, UInt64, TracedRNumber{UInt64}), + (Complex{Float32}, Complex{Float32}, TracedRNumber{Complex{Float32}}), + (Complex{Float64}, Complex{Float64}, TracedRNumber{Complex{Float64}}), + (Complex{Int8}, Complex{Int8}, TracedRNumber{Complex{Int8}}), + (Complex{Int16}, Complex{Int16}, TracedRNumber{Complex{Int16}}), + (Complex{Int32}, Complex{Int32}, TracedRNumber{Complex{Int32}}), + (Complex{Int64}, Complex{Int64}, TracedRNumber{Complex{Int64}}), + (Complex{UInt8}, Complex{UInt8}, TracedRNumber{Complex{UInt8}}), + (Complex{UInt16}, Complex{UInt16}, TracedRNumber{Complex{UInt16}}), + (Complex{UInt32}, Complex{UInt32}, TracedRNumber{Complex{UInt32}}), + (Complex{UInt64}, Complex{UInt64}, TracedRNumber{Complex{UInt64}}), # RArray types - (ConcreteRArray{Float64,0}, TracedRArray{Float64,0}), - (ConcreteRArray{Float64,1}, TracedRArray{Float64,1}), - (ConcreteRArray{Float64,2}, TracedRArray{Float64,2}), - (ConcreteRArray{Float64,3}, TracedRArray{Float64,3}), + ( + ConcretePJRTArray{Float64,0,1,Sharding.NoShardInfo}, + TracedRArray{Float64,0}, + TracedRArray{Float64,0}, + ), + ( + ConcretePJRTArray{Float64,1,1,Sharding.NoShardInfo}, + TracedRArray{Float64,1}, + TracedRArray{Float64,1}, + ), + ( + ConcretePJRTArray{Float64,2,1,Sharding.NoShardInfo}, + TracedRArray{Float64,2}, + TracedRArray{Float64,2}, + ), + ( + ConcretePJRTArray{Float64,3,1,Sharding.NoShardInfo}, + TracedRArray{Float64,3}, + TracedRArray{Float64,3}, + ), # Array types - (Array{Float64,1}, Array{Float64,1}), - (Array{ConcreteRArray{Float64,2},1}, Array{TracedRArray{Float64,2},1}), + (Array{Float64,1}, Array{Float64,1}, Array{TracedRNumber{Float64},1}), + ( + Array{ConcretePJRTArray{Float64,2,1,Sharding.NoShardInfo},1}, + Array{TracedRArray{Float64,2},1}, + Array{TracedRArray{Float64,2},1}, + ), # Union types - (Union{Nothing,Int}, Union{Nothing,Int}), + (Union{Nothing,Int}, Union{Nothing,Int}, Union{Nothing,TracedRNumber{Int}}), ( - Union{Nothing,ConcreteRArray{Float64,1}}, + Union{Nothing,ConcretePJRTArray{Float64,1,1,Sharding.NoShardInfo}}, + Union{Nothing,TracedRArray{Float64,1}}, Union{Nothing,TracedRArray{Float64,1}}, ), # Ptr types - (Ptr{Float64}, Ptr{Float64}), - (Ptr{ConcreteRArray{Float64,1}}, Ptr{TracedRArray{Float64,1}}), - (Core.LLVMPtr{Float64}, Core.LLVMPtr{Float64}), + (Ptr{Float64}, Ptr{Float64}, Ptr{TracedRNumber{Float64}}), ( - Core.LLVMPtr{ConcreteRArray{Float64,1}}, + Ptr{ConcretePJRTArray{Float64,1,1,Sharding.NoShardInfo}}, + Ptr{TracedRArray{Float64,1}}, + Ptr{TracedRArray{Float64,1}}, + ), + ( + Core.LLVMPtr{Float64}, + Core.LLVMPtr{Float64}, + Core.LLVMPtr{TracedRNumber{Float64}}, + ), + ( + Core.LLVMPtr{ConcretePJRTArray{Float64,1,1,Sharding.NoShardInfo}}, + Core.LLVMPtr{TracedRArray{Float64,1}}, Core.LLVMPtr{TracedRArray{Float64,1}}, ), - (Base.RefValue{Float64}, Base.RefValue{Float64}), ( - Base.RefValue{ConcreteRArray{Float64,1}}, + Base.RefValue{Float64}, + Base.RefValue{Float64}, + Base.RefValue{TracedRNumber{Float64}}, + ), + ( + Base.RefValue{ConcretePJRTArray{Float64,1,1,Sharding.NoShardInfo}}, + Base.RefValue{TracedRArray{Float64,1}}, Base.RefValue{TracedRArray{Float64,1}}, ), # Val types - (Val{0}, Val{0}), - (Val{0.5}, Val{0.5}), - (Val{:x}, Val{:x}), + (Val{0}, Val{0}, Val{0}), + (Val{0.5}, Val{0.5}, Val{0.5}), + (Val{:x}, Val{:x}, Val{:x}), + ( + Dict{Int,ConcretePJRTArray{Float64,0,1,Sharding.NoShardInfo}}, + Dict{Int,TracedRArray{Float64,0}}, + Dict{Int,TracedRArray{Float64,0}}, + ), + (Dict{Int}, Dict{Int}, Dict{Int}), + (Dict, Dict, Dict), + ( + (Dict{A,ConcretePJRTArray{Float64,0,1,Sharding.NoShardInfo}} where {A}), + (Dict{A,TracedRArray{Float64,0}} where {A}), + (Dict{A,TracedRArray{Float64,0}} where {A}), + ), + ( + ( + Dict{ + Symbol,NTuple{nsteps,SpectralVariable3D} + } where {nsteps} where {SpectralVariable3D} + ), + ( + Dict{ + Symbol,NTuple{nsteps,SpectralVariable3D} + } where {nsteps} where {SpectralVariable3D} + ), + ( + Dict{ + Symbol,NTuple{nsteps,SpectralVariable3D} + } where {nsteps} where {SpectralVariable3D} + ), + ), + ( + Base.Pairs{Symbol,Union{}}, + Base.Pairs{Symbol,Union{}}, + Base.Pairs{Symbol,Union{}}, + ), + ( + NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, + NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, + NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, + ), + ( + Base.RefValue{A} where {A}, + Base.RefValue{A} where {A}, + Base.RefValue{A} where {A}, + ), + (Wrapper{Symbol,Symbol}, Wrapper{Symbol,Symbol}, Wrapper{Symbol,Symbol}), + ( + Wrapper{Float64,Vector{Float64}}, + Wrapper{Float64,Vector{Float64}}, + Wrapper{TracedRNumber{Float64},Vector{Float64}}, + ), + ( + Wrapper{Float64,ConcretePJRTArray{Float64,1,1,Sharding.NoShardInfo}}, + Wrapper{Float64,TracedRArray{Float64,1}}, + Wrapper{TracedRNumber{Float64},TracedRArray{Float64,1}}, + ), + (Wrapper{Symbol}, Wrapper{Symbol}, Wrapper{Symbol}), + (Wrapper{Float64}, Wrapper{Float64}, Wrapper{TracedRNumber{Float64}}), + ( + Wrapper{ConcretePJRTArray{Float64,1,1,Sharding.NoShardInfo}}, + Wrapper{TracedRArray{Float64,1}}, + Wrapper{TracedRArray{Float64,1}}, + ), + (Wrapper, Wrapper, Wrapper), ] tracedty = traced_type( - origty, Reactant.OrderedIdDict(), Val(ConcreteToTraced), () + origty, Val(ConcreteToTraced), Union{}, Sharding.NoSharding() ) @test tracedty == targetty + + tracedty2 = traced_type( + origty, Val(ConcreteToTraced), ReactantPrimitive, Sharding.NoSharding() + ) + @test tracedty2 == targetty end @testset "$type" for type in [ @@ -102,36 +219,51 @@ using Test TracedRArray{Float64,3}, ] @test_throws Union{ErrorException,String} traced_type( - type, Reactant.OrderedIdDict(), Val(ConcreteToTraced), () + type, Val(ConcreteToTraced), Union{}, Sharding.NoSharding() ) end end @testset "traced_type exceptions" begin - @test_throws TracedTypeError Reactant.traced_type( - Real, Reactant.OrderedIdDict(), Val(Reactant.ArrayToConcrete), () - ) - struct Node x::Vector{Float64} y::Union{Nothing,Node} end @test_throws NoFieldMatchError traced_type( - Node, Reactant.OrderedIdDict(), Val(ArrayToConcrete), () + Node, Val(ArrayToConcrete), Union{}, Sharding.NoSharding() ) end end @testset "specialized dispatches" begin - @test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray( - 1.0; track_numbers=(Number,) - ) isa ConcreteRNumber + @test @inferred Union{Float64,ConcretePJRTArray{Float64}} Reactant.to_rarray( + 1.0; track_numbers=Number + ) isa ConcretePJRTNumber @test @inferred Reactant.to_rarray(1.0) isa Float64 - @test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray + @test @inferred Reactant.to_rarray(rand(3)) isa ConcretePJRTArray x_ra = Reactant.to_rarray(rand(3)) - @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray + @test @inferred Reactant.to_rarray(x_ra) isa ConcretePJRTArray + + x_ra = Reactant.to_rarray(1.0; track_numbers=Number) + @test @inferred Reactant.to_rarray(x_ra) isa ConcretePJRTNumber + end + + @testset "no trace Val" begin + st = (; a=1, training=Val(true)) + st_traced = Reactant.to_rarray(st; track_numbers=Number) + @test st_traced.training isa Val{true} + end + + @testset "to_rarray(::AbstractRule)" begin + opt = Descent(0.1) + opt_traced = Reactant.to_rarray(opt; track_numbers=AbstractFloat) + @test opt_traced.eta isa ConcreteRNumber{Float64} - x_ra = Reactant.to_rarray(1.0; track_numbers=(Number,)) - @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber + opt = RMSProp(0.1, 0.9, 1e-8, true) + opt_traced = Reactant.to_rarray(opt; track_numbers=AbstractFloat) + @test opt_traced.eta isa ConcreteRNumber{Float64} + @test opt_traced.rho isa ConcreteRNumber{Float64} + @test opt_traced.epsilon isa ConcreteRNumber{Float64} + @test opt_traced.centred isa Bool end end diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index f5418e5c80..ba8a466162 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -148,7 +148,7 @@ end ("Transpose", write_to_transposed_array!), ("Adjoint", write_to_adjoint_array!), ] - x = ConcreteRArray(rand(3, 2)) + x = Reactant.to_rarray(rand(3, 2)) y = @jit fn(x) @test all(isone, Array(y)) end @@ -172,3 +172,119 @@ end @test all(iszero, y_res) end end + +function lower_triangular_write(x) + y = LowerTriangular(copy(x)) + @. y *= 2 + return y +end + +function upper_triangular_write(x) + y = UpperTriangular(copy(x)) + @. y *= 2 + return y +end + +function tridiagonal_write(x) + y = Tridiagonal(copy(x)) + @. y *= 2 + return y +end + +@testset "Broadcasted Multiply and Alloate" begin + @testset "$(aType)" for (aType, fn) in [ + ("LowerTriangular", lower_triangular_write), + ("UpperTriangular", upper_triangular_write), + ("Tridiagonal", tridiagonal_write), + ] + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + @test @jit(fn(x_ra)) ≈ fn(x) + end +end + +function broadcast_reshaped_array(x, idx1, idx2) + y = reshape(x, 20, 2) + return y[idx1, idx2] .+ 1 +end + +function broadcast_reshaped_array(x, idx1, idx2::Number) + y = reshape(x, 20, 2) + return y[idx1, idx2] .+ 1 +end + +function broadcast_reshaped_array(x, idx1) + y = reshape(x, 20, 2) + return y[idx1, :] .+ 1 +end + +@testset "Broadcast reshaped array" begin + x_ra = Reactant.to_rarray(rand(5, 4, 2)) + idx1_ra = Reactant.to_rarray(rand(1:20, 4)) + idx2_ra = Reactant.to_rarray([2, 1]) + + @test broadcast_reshaped_array(Array(x_ra), Array(idx1_ra), Array(idx2_ra)) ≈ + @jit(broadcast_reshaped_array(x_ra, idx1_ra, idx2_ra)) ≈ + @jit(broadcast_reshaped_array(x_ra, Array(idx1_ra), Array(idx2_ra))) + + idx3 = Reactant.to_rarray(2; track_numbers=true) + + @test broadcast_reshaped_array(Array(x_ra), Array(idx1_ra), Int64(idx3)) ≈ + @jit(broadcast_reshaped_array(x_ra, idx1_ra, idx3)) ≈ + @jit(broadcast_reshaped_array(x_ra, Array(idx1_ra), Int64(idx3))) + + @test broadcast_reshaped_array(Array(x_ra), Array(idx1_ra)) ≈ + @jit(broadcast_reshaped_array(x_ra, idx1_ra)) ≈ + @jit(broadcast_reshaped_array(x_ra, Array(idx1_ra))) +end + +@testset "reshaped subarray indexing" begin + fn(x) = view(x, 1:2) .+ 1 + x_ra = Reactant.to_rarray(rand(3, 4, 3)) + @test @jit(fn(x_ra)) == fn(Array(x_ra)) +end + +function reshape_getindex(x) + x = reshape(x, 2, 4, 3) + return x[1, :, :] +end + +function permutedims_getindex(x) + x = PermutedDimsArray(x, (2, 1)) + return x[1, :] +end + +@testset "no gather getindex" begin + x = ones(8, 3) + x_ra = Reactant.to_rarray(x) + + hlo = repr(@code_hlo(reshape_getindex(x_ra))) + @test !occursin("stablehlo.gather", hlo) + + hlo = repr(@code_hlo(permutedims_getindex(x_ra))) + @test !occursin("stablehlo.gather", hlo) +end + +function view_adjoint(x) + y = view(x, 1:2, 1:2) + return adjoint(y) .+ y +end + +function view_transpose(x) + y = view(x, 1:2, 1:2) + return transpose(y) .+ y +end + +function view_diagonal(x) + y = view(x, 1:2, 1:2) + return Diagonal(y) .+ y +end + +@testset "2 levels of wrapping" begin + x = reshape(collect(Float32, 1:8), 2, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(view_adjoint(x_ra)) ≈ view_adjoint(x) + @test @jit(view_transpose(x_ra)) ≈ view_transpose(x) + @test @jit(view_diagonal(x_ra)) ≈ view_diagonal(x) +end