# Build ROCm TransformerEngine wheel (weekly) #33
name: Build ROCm TransformerEngine wheel (weekly)

# Triggers: manual dispatch plus a weekly cron schedule.
on:
  workflow_dispatch:
  schedule:
    # Weekly at night (02:00 UTC every Monday), 2 hours ahead of scheduled tests.
    - cron: "0 2 * * 1"

# Presumably needed by the release upload/prune steps, which authenticate
# with the workflow token.
permissions:
  contents: write

jobs:
  build_upload_prune:
    # AMD GPU runner (GitHub-hosted large runner label).
    runs-on: linux-x86-64-4gpu-amd
    container:
      image: ghcr.io/rocm/jax-base-ubu24.rocm720:latest
      # Folded scalar: the lines below are joined into one option string.
      # It passes the /dev/kfd and /dev/dri device nodes into the container.
      options: >-
        --device=/dev/kfd --device=/dev/dri --group-add video
        --ipc=host --shm-size 64g
        --cap-add=SYS_PTRACE --security-opt seccomp=unconfined
        --privileged
      # Variables defined on the job container itself.
      env:
        XLA_PYTHON_CLIENT_MEM_FRACTION: "0.9"
        NVTE_FUSED_ATTN_AOTRITON: "0"
    # Job-level variables shared by all steps.
    env:
      # Retention window in days, passed as --keep-days to the prune steps.
      TE_WHEELS_KEEP_DAYS: "21"
    steps:
# Check out this repository; later steps reference files by repo-relative path
# (.github/workflows/utils/..., src/dependencies/requirements/...).
- name: Checkout
  uses: actions/checkout@v5
# Install OS build prerequisites and create the ./.venv used by later steps.
- name: Setup build environment (deps + venv)
  shell: bash
  run: |
    set -euo pipefail
    # Never let a package's debconf prompt block the unattended CI run.
    export DEBIAN_FRONTEND=noninteractive
    apt-get update
    apt-get install -y --no-install-recommends git build-essential python3-dev
    # Bootstrap uv with the system interpreter, then create ./.venv.
    # --seed asks uv to pre-install seed packages (such as pip) into it.
    python3 -m pip install -U uv
    python3 -m uv venv --seed
    source .venv/bin/activate
    # Python-side build tooling for the wheel build.
    uv pip install -U pip setuptools wheel pybind11 cmake
# Install the ROCm JAX/JAXlib stack listed in the repository's requirements file.
- name: Install ROCm JAX/JAXlib wheels (build against CI stack)
  shell: bash
  run: |
    set -euo pipefail
    source .venv/bin/activate
    uv pip install \
      -r src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt
- name: Install PyTorch ROCm (build-time dep for aiter JIT)
  shell: bash
  run: |
    set -euo pipefail
    source .venv/bin/activate
    # Resolve torch from the PyTorch ROCm 7.2 wheel index.
    uv pip install torch \
      --index-url https://download.pytorch.org/whl/rocm7.2
# Export ROCM_NUM, PYTAG and REL_SCRIPT to later steps via GITHUB_ENV.
- name: Detect ROCm version and Python tag
  shell: bash
  run: |
    set -euo pipefail
    source .venv/bin/activate
    # Read the ROCm version from the first line of the ROCm install's version
    # file. Detection is best-effort: a missing, unreadable, or blank file
    # yields "unknown" instead of failing the job or exporting an empty value.
    ROCM_INFO_FILE=/opt/rocm/.info/version
    ROCM_NUM=""
    if [ -f "${ROCM_INFO_FILE}" ]; then
      ROCM_NUM="$(head -n1 "${ROCM_INFO_FILE}" 2>/dev/null | tr -d ' \t\r' || true)"
    fi
    ROCM_NUM="${ROCM_NUM:-unknown}"
    echo "Detected ROCm version: ${ROCM_NUM}"
    echo "ROCM_NUM=${ROCM_NUM}" >> "${GITHUB_ENV}"
    # CPython tag of the venv interpreter (e.g. cp312); the job only accepts cp312.
    PYTAG="cp$(python3 -c 'import sys; print(f"{sys.version_info.major}{sys.version_info.minor}")')"
    if [ "${PYTAG}" != "cp312" ]; then
      echo "Expected Python 3.12 (cp312) for ROCm CI wheels, got ${PYTAG}."
      exit 1
    fi
    echo "PYTAG=${PYTAG}" >> "${GITHUB_ENV}"
    # Path of the release helper script invoked by the upload/prune steps.
    echo "REL_SCRIPT=.github/workflows/utils/te_wheels_release.py" >> "${GITHUB_ENV}"
# Fetch the dev branch of ROCm/TransformerEngine and export its short commit id.
- name: Clone ROCm/TransformerEngine (dev)
  shell: bash
  run: |
    set -euo pipefail
    # Remove any previous checkout so the clone starts from an empty directory.
    rm -rf TransformerEngine
    git clone --recursive --branch dev https://github.com/ROCm/TransformerEngine.git
    # Explicit submodule init/update on top of the recursive clone.
    git -C TransformerEngine submodule update --init --recursive
    # 12-character commit id, reused later in the dated release tag and title.
    TE_SHA="$(git -C TransformerEngine rev-parse --short=12 HEAD)"
    echo "TE_SHA=${TE_SHA}" >> "${GITHUB_ENV}"
# Export TE_WHEEL_ARCH, SELECTOR and GFX_ARCH for the build and release steps.
- name: Select TE wheel arch for runner (mi300/mi355)
  shell: bash
  run: |
    set -euo pipefail
    source .venv/bin/activate
    TE_WHEEL_ARCH="$(python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch)"
    echo "Resolved TE wheel arch for runner: ${TE_WHEEL_ARCH}"
    echo "TE_WHEEL_ARCH=${TE_WHEEL_ARCH}" >> "${GITHUB_ENV}"
    # Build ONLY for the ROCm arch present on this CI runner (mi300 or mi355).
    # mi300 is the default; anything other than mi300/mi355 still builds as
    # mi300 but is surfaced as a warning annotation instead of passing silently.
    SELECTOR="mi300"
    GFX_ARCH="gfx942;gfx941"
    case "${TE_WHEEL_ARCH}" in
      mi355)
        SELECTOR="mi355"
        GFX_ARCH="gfx950"
        ;;
      mi300)
        ;;
      *)
        echo "::warning::Unexpected TE wheel arch '${TE_WHEEL_ARCH}'; falling back to mi300."
        ;;
    esac
    echo "SELECTOR=${SELECTOR}" >> "${GITHUB_ENV}"
    echo "GFX_ARCH=${GFX_ARCH}" >> "${GITHUB_ENV}"
# Build the wheel inside ./TransformerEngine, copy it to the workspace root
# under a selector-tagged file name, and export that name as TE_WHEEL_FILE.
- name: Build TE wheel
  shell: bash
  run: |
    set -euo pipefail
    source .venv/bin/activate
    # Best-effort: a failure to set the executable bit is ignored.
    chmod +x "${REL_SCRIPT}" || true
    # Build configuration exported to the TransformerEngine build.
    export USE_ROCM=1
    export HIP_PATH=/opt/rocm
    export NVTE_FRAMEWORK=jax
    export NVTE_USE_ROCM=1
    export NVTE_FUSED_ATTN_AOTRITON=0
    # Parallelism defaults; set either variable in the job or container env
    # to override it.
    export CMAKE_BUILD_PARALLEL_LEVEL="${CMAKE_BUILD_PARALLEL_LEVEL:-64}"
    export NVTE_BUILD_MAX_JOBS="${NVTE_BUILD_MAX_JOBS:-180}"
    echo "=== Building TE wheel for ${SELECTOR} (gfx=${GFX_ARCH}) ==="
    pushd TransformerEngine >/dev/null
    # Clear earlier build output so dist/ only holds the wheel built below.
    rm -rf build dist
    # Put the bundled hipify_torch ahead of any existing PYTHONPATH entries.
    export PYTHONPATH="$(pwd)/3rdparty/hipify_torch${PYTHONPATH:+:${PYTHONPATH}}"
    # Restrict the build to the GPU targets chosen for this runner.
    export PYTORCH_ROCM_ARCH="${GFX_ARCH}"
    export NVTE_ROCM_ARCH="${GFX_ARCH}"
    python3 setup.py bdist_wheel
    # First matching wheel in sorted order, or an empty string if none was produced.
    wheel_path="$(
      python3 -c 'import glob; m=sorted(glob.glob("dist/transformer_engine-*.whl")); print(m[0] if m else "")'
    )"
    if [ -z "${wheel_path}" ]; then
      echo "No wheel produced in dist/ (selector=${SELECTOR})."
      exit 1
    fi
    wheel_base="$(basename "${wheel_path}")"
    # The asset name must carry "-1.<selector>" immediately before the Python
    # tags. Reuse the name if it already does; otherwise insert it, and fail
    # if the file name does not end in the expected tag suffix.
    if [[ "${wheel_base}" == *"-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl" ]]; then
      asset_name="${wheel_base}"
    else
      asset_name="${wheel_base/-${PYTAG}-${PYTAG}-linux_x86_64.whl/-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl}"
      if [ "${asset_name}" = "${wheel_base}" ]; then
        echo "Failed to rename wheel for selector=${SELECTOR}: ${wheel_base}"
        exit 1
      fi
    fi
    # Copy to the parent directory (the directory the step started in).
    cp -f "${wheel_path}" "../${asset_name}"
    popd >/dev/null
    ls -lh "${asset_name}"
    echo "TE_WHEEL_FILE=${asset_name}" >> "${GITHUB_ENV}"
# Attach the freshly built wheel to the fixed "te-rocm-wheels" release tag.
- name: Upload wheel to rolling release tag (te-rocm-wheels)
  shell: bash
  env:
    GITHUB_TOKEN: ${{ github.token }}
  run: |
    set -euo pipefail
    source .venv/bin/activate
    python3 "${REL_SCRIPT}" upload \
      --no-prerelease \
      --tag "te-rocm-wheels" \
      --title "ROCm TransformerEngine wheels (latest)" \
      --body "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." \
      --file "${TE_WHEEL_FILE}"
# Ask the release helper to prune the rolling tag's assets, using the
# TE_WHEELS_KEEP_DAYS retention window.
- name: Prune old assets from rolling tag
  shell: bash
  env:
    GITHUB_TOKEN: ${{ github.token }}
  run: |
    set -euo pipefail
    source .venv/bin/activate
    echo "Pruning rolling-tag assets older than ${TE_WHEELS_KEEP_DAYS} days"
    python3 "${REL_SCRIPT}" prune-assets \
      --tag "te-rocm-wheels" \
      --keep-days "${TE_WHEELS_KEEP_DAYS}"
# Publish the same wheel under a per-build tag: te-rocm-wheels-<UTC date>-<TE sha>.
- name: Publish wheel to dated weekly release tag
  shell: bash
  env:
    GITHUB_TOKEN: ${{ github.token }}
  run: |
    set -euo pipefail
    source .venv/bin/activate
    DATE_UTC="$(date -u +%Y-%m-%d)"
    WEEKLY_TAG="te-rocm-wheels-${DATE_UTC}-${TE_SHA}"
    WEEKLY_TITLE="ROCm TransformerEngine wheels ${DATE_UTC} (TE ${TE_SHA})"
    # Keep this YAML-safe (no unindented heredocs inside `run: |`).
    # The release notes are a summary line followed by ROCm/Python/arch details.
    WEEKLY_BODY="$(
      printf '%s\n\nROCm: %s\nPython: %s\nArch: %s (gfx=%s)\n' \
        "Built from ROCm/TransformerEngine dev @ ${TE_SHA} on ${DATE_UTC}." \
        "${ROCM_NUM}" "${PYTAG}" "${SELECTOR}" "${GFX_ARCH}"
    )"
    python3 "${REL_SCRIPT}" upload \
      --no-prerelease \
      --tag "${WEEKLY_TAG}" \
      --title "${WEEKLY_TITLE}" \
      --body "${WEEKLY_BODY}" \
      --file "${TE_WHEEL_FILE}"
# Ask the release helper to prune releases under the "te-rocm-wheels-" prefix.
# NOTE(review): the trailing hyphen presumably keeps the rolling
# "te-rocm-wheels" tag out of scope - confirm against the helper's matching.
- name: Prune old weekly releases
  shell: bash
  env:
    GITHUB_TOKEN: ${{ github.token }}
  run: |
    set -euo pipefail
    source .venv/bin/activate
    echo "Pruning weekly release pages older than ${TE_WHEELS_KEEP_DAYS} days"
    python3 "${REL_SCRIPT}" prune-releases \
      --prefix "te-rocm-wheels-" \
      --keep-days "${TE_WHEELS_KEEP_DAYS}"