Skip to content

Commit 4a5afee

Browse files
committed
Resolve latest MaxText Transformer Engine wheel from ROCm MaxText releases
1 parent 2483b46 commit 4a5afee

3 files changed

Lines changed: 32 additions & 12 deletions

File tree

.github/workflows/benchmark_rocm.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# CI - Benchmark ROCm
22
#
3-
# This workflow runs ROCm benchmarks in ROCm team's GHCR containers.
3+
# This workflow runs the ROCm benchmarks in ROCm team's GHCR containers.
44
# It can be triggered manually via workflow_dispatch or called by other workflows
55
# via workflow_call.
66
#
77
# It consists of the following job:
8-
# run-benchmark:
8+
# run-benchmarks:
99
# - Runs in ROCm team's container (ghcr.io/rocm/jax-base-ubu24.<rocm-tag>:latest)
1010
# - Downloads the JAX and jaxlib wheels from GCS, and ROCm plugins from S3.
11-
# - Executes the target benchmark runner script at `targets/<target>/run.sh`.
11+
# - Executes the target benchmark script at `ci/benchmark_targets/<target>/run_<target>_rocm.sh`.
1212
name: CI - Benchmark ROCm
1313
on:
1414
workflow_dispatch:
@@ -135,7 +135,7 @@ env:
135135
UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"
136136

137137
jobs:
138-
run-benchmark:
138+
run-benchmarks:
139139
defaults:
140140
run:
141141
# Set the shell to bash as GitHub actions run with /bin/sh by default
@@ -144,7 +144,7 @@ jobs:
144144
continue-on-error: true
145145
# Run in ROCm team's GHCR container with GPU access
146146
container:
147-
image: ghcr.io/rocm/jax-base-ubu24.rocm722:2d65281b00de2bcafc811247d563b6e5e7c887af
147+
image: ghcr.io/rocm/jax-base-ubu24.${{ inputs.rocm-tag }}:latest # zizmor: ignore[unpinned-images]
148148
credentials:
149149
username: ${{ github.actor }}
150150
password: ${{ secrets.GITHUB_TOKEN }}
@@ -187,7 +187,7 @@ jobs:
187187
TARGET: ${{ inputs.target }}
188188
WORKLOAD: ${{ inputs.workload }}
189189
timeout-minutes: 120
190-
run: ./ci/benchmark_targets/${TARGET}/run_${TARGET}_rocm.sh" --workload "${WORKLOAD}"
190+
run: ./ci/benchmark_targets/${TARGET}/run_${TARGET}_rocm.sh --workload "${WORKLOAD}"
191191
- name: Upload GitHub artifacts
192192
if: always()
193193
continue-on-error: true

.github/workflows/wheel_benchmarks_nightly_release.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# CI - Wheel Benchmarks (Nightly/Release)
2+
#
3+
# This workflow is used to benchmark the JAX wheels that were built by internal CI jobs.
4+
#
5+
# 1. run-benchmark-rocm: Calls the `benchmark_rocm.yml` workflow which downloads the JAX and jaxlib
6+
# wheels from GCS and the ROCm plugins from S3, and runs the ROCm benchmarks.
17
name: CI - Wheel Benchmarks (Nightly/Release)
28

39
on:
@@ -55,6 +61,4 @@ jobs:
5561
jaxlib-version: "head"
5662
skip-download-jaxlib-and-plugins-from-gcs: ${{inputs.skip-download-jaxlib-and-plugins-from-gcs}}
5763
gcs_download_uri: ${{ inputs.gcs_download_uri }}
58-
s3_download_uri: ${{ inputs.s3_download_uri }}
5964
use-te: "1"
60-
te-wheel-url: "https://github.com/ROCm/maxtext/releases/download/te-rocm-wheels-2026-05-04-86438dc3d04e/transformer_engine-2.12.0.dev0+86438dc3-1.mi355-cp312-cp312-linux_x86_64.whl"

ci/benchmark_targets/maxtext_rocm/run_maxtext_rocm.sh

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,27 @@ source ./ci/utilities/install_wheels_locally.sh
4343

4444
[[ -f "${ENV_FILE}" ]] && source "${ENV_FILE}"
4545

46-
[[ "${USE_TE:-0}" == "1" ]] && {
47-
: "${TE_WHEEL_URL:?TE_WHEEL_URL must be set when USE_TE=1}"
48-
echo "[setup] installing Transformer Engine"
46+
# Optionally install the ROCm Transformer Engine (TE) wheel.
# Runs only when USE_TE is "1" (treated as "0" when unset).
if [[ "${USE_TE:-0}" == "1" ]]; then
  echo "[setup] resolving latest Transformer Engine wheel"

  # Build the CPython tag used in wheel filenames (e.g. "3.12" -> "cp312")
  # from JAXCI_HERMETIC_PYTHON_VERSION, falling back to 3.12 when unset.
  PY_TAG="cp$(echo "${JAXCI_HERMETIC_PYTHON_VERSION:-3.12}" | tr -d '.')"

  # Take the first asset URL in the ROCm/maxtext releases listing that belongs
  # to a "te-rocm-wheels-" release, matches PY_TAG and is a linux_x86_64 wheel.
  # NOTE(review): "latest" relies on the API listing newest releases first —
  # confirm against the GitHub REST API docs.
  #
  # The API emits JSON lines of the form
  #   "browser_download_url": "https://...",
  # i.e. a closing quote sits between the key and the colon, so the key is
  # matched as '"browser_download_url"' (a pattern of `browser_download_url:`
  # never matches). Splitting on '"' then makes the URL the 4th field.
  TE_WHEEL_URL="$(
    curl -fsSL https://api.github.com/repos/ROCm/maxtext/releases \
      | grep '"browser_download_url"' \
      | grep "te-rocm-wheels-" \
      | grep "${PY_TAG}" \
      | grep linux_x86_64.whl \
      | head -n1 \
      | cut -d '"' -f4
  )"

  # Fail loudly if no asset matched instead of handing pip an empty argument.
  [[ -n "${TE_WHEEL_URL}" ]] || {
    echo "failed to resolve Transformer Engine wheel for ${PY_TAG}" >&2
    exit 1
  }

  echo "[setup] installing Transformer Engine from ${TE_WHEEL_URL}"
  # --no-deps: install only the TE wheel itself, without resolving its
  # dependencies.
  "${PYTHON_BIN}" -m pip install --no-deps "${TE_WHEEL_URL}"
fi
5167

5268
export PY_COLORS=1
5369
export NCCL_DEBUG=WARN

0 commit comments

Comments
 (0)