Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 290 additions & 0 deletions .github/workflows/build_linux_jax_wheels.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
name: Build Portable Linux JAX Wheels

on:
workflow_call:
inputs:
amdgpu_family:
required: true
type: string
python_version:
required: true
type: string
release_type:
description: The type of release to build ("dev", "nightly", or "prerelease"). All developer-triggered jobs should use "dev"!
required: true
type: string
s3_subdir:
description: S3 subdirectory, not including the GPU-family
required: true
type: string
s3_staging_subdir:
description: S3 staging subdirectory, not including the GPU-family
required: true
type: string
rocm_version:
description: ROCm version to install
type: string
tar_url:
description: URL to TheRock tarball to build against
type: string
cloudfront_url:
description: CloudFront URL pointing to Python index
required: true
type: string
cloudfront_staging_url:
description: CloudFront base URL pointing to staging Python index
required: true
type: string
repository:
description: "Repository to checkout. Defaults to `ROCm/TheRock`."
type: string
default: "ROCm/TheRock"
ref:
description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow."
type: string
workflow_dispatch:
inputs:
amdgpu_family:
type: choice
options:
- gfx101X-dgpu
- gfx103X-dgpu
- gfx110X-all
- gfx1150
- gfx1151
- gfx120X-all
- gfx90X-dcgpu
- gfx94X-dcgpu
- gfx950-dcgpu
default: gfx94X-dcgpu
python_version:
required: true
type: string
default: "3.12"
release_type:
type: choice
description: Type of release to create. All developer-triggered jobs should use "dev"!
options:
- dev
- nightly
- prerelease
default: dev
s3_subdir:
description: S3 subdirectory, not including the GPU-family
type: string
default: "v2"
s3_staging_subdir:
description: S3 staging subdirectory, not including the GPU-family
type: string
default: "v2-staging"
rocm_version:
description: ROCm version to install
type: string
tar_url:
description: URL to TheRock tarball to build against
type: string
cloudfront_url:
description: CloudFront base URL pointing to Python index
type: string
default: "https://rocm.devreleases.amd.com/v2"
cloudfront_staging_url:
description: CloudFront base URL pointing to staging Python index
type: string
default: "https://rocm.devreleases.amd.com/v2-staging"
jax_ref:
description: rocm-jax repository ref/branch to check out
type: string
default: rocm-jaxlib-v0.8.0

permissions:
id-token: write
contents: read

run-name: Build Linux JAX Wheels (${{ inputs.amdgpu_family }}, ${{ inputs.python_version }}, ${{ inputs.release_type }})

jobs:
build_jax_wheels:
strategy:
matrix:
jax_ref: [rocm-jaxlib-v0.8.0]
name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }}
runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }}
env:
PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse
S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
outputs:
cp_version: ${{ env.cp_version }}
jax_version: ${{ steps.extract_jax_version.outputs.jax_version }}
steps:
- name: Checkout TheRock
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1

- name: Checkout JAX
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
path: jax
repository: rocm/rocm-jax
ref: ${{ matrix.jax_ref }}

- name: Configure Git Identity
run: |
git config --global user.name "therockbot"
git config --global user.email "[email protected]"

- name: "Setting up Python"
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
with:
python-version: ${{ inputs.python_version }}

- name: Select Python version
run: |
python build_tools/github_actions/python_to_cp_version.py \
--python-version ${{ inputs.python_version }}

- name: Build JAX Wheels
env:
ROCM_VERSION: ${{ inputs.rocm_version }}
run: |
ls -lah
pushd jax
python3 build/ci_build \
--compiler=clang \
--python-versions="${{ inputs.python_version }}" \
--rocm-version="${ROCM_VERSION}" \
--therock-path="${{ inputs.tar_url }}" \
dist_wheels

- name: Extract JAX version
id: extract_jax_version
run: |
# Extract JAX version from requirements.txt (e.g., "jax==0.8.0")
# Remove all whitespace from requirements.txt to simplify parsing
# Search for lines starting with "jax==" or "jaxlib==" followed by version (excluding comments)
# Extract the version number by splitting on '=' and taking the 3rd field
# [^#]+ matches one or more characters that are NOT '#', ensuring we stop before any inline comments
JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \
| grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3)
echo "jax_version=$JAX_VERSION" >> "$GITHUB_OUTPUT"

- name: Install AWS CLI
if: always()
run: bash ./dockerfiles/install_awscli.sh

- name: Configure AWS Credentials
if: always()
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 # v5.1.1
with:
aws-region: us-east-2
role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}-releases

- name: Upload wheels to S3
if: ${{ github.repository_owner == 'ROCm' }}
run: |
aws s3 cp ${{ env.PACKAGE_DIST_DIR }}/ s3://${{ env.S3_BUCKET_PY }}/${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}/ \
--recursive --exclude "*" --include "*.whl"

- name: (Re-)Generate Python package release index
if: ${{ github.repository_owner == 'ROCm' }}
run: |
python3 -m venv .venv
source .venv/bin/activate
pip3 install boto3 packaging
python3 ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}

generate_target_to_run:
name: Generate target_to_run
runs-on: ubuntu-24.04
outputs:
test_runs_on: ${{ steps.configure.outputs.test-runs-on }}
bypass_tests_for_releases: ${{ steps.configure.outputs.bypass_tests_for_releases }}
steps:
- name: Checking out repository
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
repository: ${{ inputs.repository || github.repository }}
ref: ${{ inputs.ref || '' }}

- name: Generating target to run
id: configure
env:
TARGET: ${{ inputs.amdgpu_family }}
PLATFORM: "linux"
# Variable comes from ROCm organization variable 'ROCM_THEROCK_TEST_RUNNERS'
ROCM_THEROCK_TEST_RUNNERS: ${{ vars.ROCM_THEROCK_TEST_RUNNERS }}
LOAD_TEST_RUNNERS_FROM_VAR: false
run: python ./build_tools/github_actions/configure_target_run.py

test_jax_wheels:
name: Test JAX wheels | ${{ inputs.amdgpu_family }} | ${{ needs.generate_target_to_run.outputs.test_runs_on }}
needs: [build_jax_wheels, generate_target_to_run]
permissions:
contents: read
packages: read
uses: ./.github/workflows/test_linux_jax_wheels.yml
with:
amdgpu_family: ${{ inputs.amdgpu_family }}
release_type: ${{ inputs.release_type }}
s3_subdir: ${{ inputs.s3_subdir }}
package_index_url: ${{ inputs.cloudfront_staging_url }}
rocm_version: ${{ inputs.rocm_version }}
tar_url: ${{ inputs.tar_url }}
python_version: ${{ inputs.python_version }}
repository: ${{ inputs.repository || github.repository }}
ref: ${{ inputs.ref || '' }}
jax_ref: ${{ inputs.jax_ref }}
test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }}

upload_jax_wheels:
name: Release JAX Wheels to S3
needs: [build_jax_wheels, generate_target_to_run, test_jax_wheels]
if: ${{ !cancelled() }}
runs-on: ubuntu-24.04
env:
S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
JAX_VERSION: "${{ needs.build_jax_wheels.outputs.jax_version }}"
ROCM_VERSION: "${{ inputs.rocm_version }}"
CP_VERSION: "${{ needs.build_jax_wheels.outputs.cp_version }}"

steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1
with:
repository: ${{ inputs.repository || github.repository }}
ref: ${{ inputs.ref || '' }}

- name: Configure AWS Credentials
if: always()
uses: aws-actions/configure-aws-credentials@00943011d9042930efac3dcd3a170e4273319bc8 # v5.1.0
with:
aws-region: us-east-2
role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}-releases

- name: Determine upload flag
env:
BUILD_RESULT: ${{ needs.build_jax_wheels.result }}
TEST_RESULT: ${{ needs.test_jax_wheels.result }}
TEST_RUNS_ON: ${{ needs.generate_target_to_run.outputs.test_runs_on }}
BYPASS_TESTS_FOR_RELEASES: ${{ needs.generate_target_to_run.outputs.bypass_tests_for_releases }}
run: python ./build_tools/github_actions/promote_wheels_based_on_policy.py

- name: Copy JAX wheels from staging to release S3
if: ${{ env.upload == 'true' }}
run: |
echo "Copying exact tested wheels to release S3 bucket..."
aws s3 cp \
s3://${S3_BUCKET_PY}/${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}/ \
s3://${S3_BUCKET_PY}/${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}/ \
--recursive \
--exclude "*" \
--include "jaxlib-${JAX_VERSION}+rocm${ROCM_VERSION}-${CP_VERSION}-manylinux_2_27_x86_64.whl" \
--include "jax_rocm7_plugin-${JAX_VERSION}+rocm${ROCM_VERSION}-${CP_VERSION}-manylinux_2_28_x86_64.whl" \
--include "jax_rocm7_pjrt-${JAX_VERSION}+rocm${ROCM_VERSION}-py3-none-manylinux_2_28_x86_64.whl"

- name: (Re-)Generate Python package release index
if: ${{ env.upload == 'true' }}
env:
# Environment variables to be set for `manage.py`
CUSTOM_PREFIX: "${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}"
run: |
pip install boto3 packaging
python ./build_tools/third_party/s3_management/manage.py ${{ env.CUSTOM_PREFIX }}
Loading
Loading