Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 85 additions & 34 deletions .github/scripts/Dockerfile.ci.deps
Original file line number Diff line number Diff line change
@@ -1,49 +1,100 @@
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
#
# TE CI deps image: Ubuntu 24.04 + TheRock ROCm tarball
#
# docker build -f .github/scripts/Dockerfile.ci.deps \
# --build-arg GPU_ARCH=gfx942 \
# -t te-ci-deps:rocm-7.12.0-ubuntu24.04-py312-gfx942 .
#
# ROCm installer: https://raw.githubusercontent.com/ROCm/TheRock/main/dockerfiles/install_rocm_tarball.sh

## TE CI Dockerfile
ARG BASE_DOCKER=registry-sc-harbor.amd.com/framework/compute-rocm-rel-7.2:57_ubuntu22.04_py3.11_pytorch_release-2.8_08d38866
FROM $BASE_DOCKER
WORKDIR /
FROM ubuntu:24.04

# Updated git via git-core PPA
RUN apt-get update && apt-get install -y --no-install-recommends software-properties-common \
&& add-apt-repository ppa:git-core/ppa -y \
&& apt-get update \
&& apt-get install -y --no-install-recommends git vim \
&& rm -rf /var/lib/apt/lists/*
ARG DEBIAN_FRONTEND=noninteractive
SHELL ["/bin/bash", "-euo", "pipefail", "-c"]

# Build arguments
ARG ROCM_VERSION=7.12.0
ARG GPU_ARCH=gfx942
ARG PYTHON_VERSION=3.12
ARG PYTHON_ABI=cp312
ARG TORCH_VERSION=2.10.0
ARG TORCHVISION_VERSION=0.25.0
ARG TORCHAUDIO_VERSION=2.10.0
ARG TRITON_VERSION=3.6.0
ARG JAX_VERSION=0.8.2
ARG FA_VERSION=v2.8.1
ARG ROCM_VERSION=7.2
ARG JAX_VERSION=0.8.0
ARG PYTHON_VERSION=311
# AITER - Required for MXFP4 FP4 GEMM kernels.
ARG AITER_COMMIT=77455e3ecf4f0d28756afc452e914940c45b944b
ARG INSTALL_ROCM_TARBALL_SH_URL=https://raw.githubusercontent.com/ROCm/TheRock/main/dockerfiles/install_rocm_tarball.sh

# Translate GPU_ARCH into the AMD GPU family name and record it (without a
# trailing newline) in /etc/amd_gpu_family so later RUN steps can read it back.
# Any architecture other than gfx942 / gfx950 aborts the build.
RUN case "${GPU_ARCH}" in \
        gfx942) amd_family=gfx94X-dcgpu ;; \
        gfx950) amd_family=gfx950-dcgpu ;; \
        *) echo "GPU_ARCH must be gfx942 or gfx950 (got: ${GPU_ARCH})" >&2; exit 1 ;; \
    esac \
    && printf '%s' "${amd_family}" > /etc/amd_gpu_family

# OS-level prerequisites, one package per line in alphabetical order:
# download tooling (ca-certificates, curl), git/vim, the native build
# toolchain, and the selected Python interpreter with venv, headers and pip.
# The apt package lists are removed in the same layer.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        cmake \
        curl \
        git \
        liblzma-dev \
        ninja-build \
        pkg-config \
        python${PYTHON_VERSION} \
        python${PYTHON_VERSION}-dev \
        python${PYTHON_VERSION}-venv \
        python3-pip \
        vim \
    && rm -rf /var/lib/apt/lists/*

# Native ROCm tarball → /opt/rocm
# Downloads the installer script from INSTALL_ROCM_TARBALL_SH_URL, runs it with
# the ROCm version, the GPU family recorded in /etc/amd_gpu_family and the
# literal argument "stable", then deletes the script in the same layer.
# NOTE(review): the default URL tracks the `main` branch and the script is
# executed without a checksum — consider pinning it to a commit so rebuilds
# are reproducible.
RUN rocm_family="$(< /etc/amd_gpu_family)" \
    && installer=/tmp/install_rocm_tarball.sh \
    && curl -fsSL -o "${installer}" "${INSTALL_ROCM_TARBALL_SH_URL}" \
    && chmod +x "${installer}" \
    && "${installer}" "${ROCM_VERSION}" "${rocm_family}" stable \
    && rm -f "${installer}"

# Isolated Python env for pip packages: a virtual environment at /opt/venv
# created with the interpreter selected by PYTHON_VERSION.
RUN python${PYTHON_VERSION} -m venv /opt/venv

# Default container env: venv on PATH, ROCm toolchain + runtime libs, GPU arch for builds.
#   GPU_ARCH        - copies the build ARG into the image environment.
#   ROCM_PATH       - points at /opt/rocm.
#   VIRTUAL_ENV     - points at the venv created above.
#   PATH            - /opt/venv/bin comes first, then /opt/rocm/bin, then the system
#                     directories, so `python` / `pip` resolve to the venv first.
#   LD_LIBRARY_PATH - /opt/rocm/lib/rocm_sysdeps/lib, then /opt/rocm/lib.
#                     NOTE(review): rocm_sysdeps presumably holds libraries bundled
#                     with the tarball — confirm.
ENV GPU_ARCH=${GPU_ARCH} \
ROCM_PATH=/opt/rocm \
VIRTUAL_ENV=/opt/venv \
PATH=/opt/venv/bin:/opt/rocm/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin \
LD_LIBRARY_PATH=/opt/rocm/lib/rocm_sysdeps/lib:/opt/rocm/lib

# Refresh the packaging toolchain (pip, setuptools, wheel) and install the
# Python utilities used by CI. Both commands pass --no-cache-dir so that pip's
# download cache is not stored in the image layer.
RUN python -m pip install --no-cache-dir --upgrade pip setuptools wheel \
    && pip install --no-cache-dir \
        ipython \
        pytest \
        fire \
        pydantic \
        pybind11 \
        ninja \
        pandas \
        expecttest \
        onnxscript

RUN pip install setuptools wheel
RUN pip install ipython pytest fire pydantic pybind11 ninja pandas
# Python ROCm SDK + torch / jax from https://repo.amd.com/rocm/whl/<AMD_GPU_FAMILY>/
# Shell variables used by this step:
#   AMD_GPU_FAMILY - family name read back from /etc/amd_gpu_family.
#   W              - per-family wheel index URL; passed as --extra-index-url and also
#                    used as the prefix for the directly-addressed files below.
#   LIBS_PKG       - "rocm-sdk-libraries-<family>" with the family name lower-cased.
#   ROCM_WHEEL_TAG - "rocm<ROCM_VERSION>"; it follows "%2B" (a URL-encoded "+") in each
#                    wheel file name.
# Every version and the Python ABI tag come from build ARGs.
# NOTE(review): overriding ROCM_VERSION or PYTHON_ABI presumably only works when
# matching files exist under the index — confirm when bumping versions.
RUN AMD_GPU_FAMILY=$(cat /etc/amd_gpu_family) \
&& W="https://repo.amd.com/rocm/whl/${AMD_GPU_FAMILY}" \
&& LIBS_PKG="rocm-sdk-libraries-$(echo "${AMD_GPU_FAMILY}" | tr '[:upper:]' '[:lower:]')" \
&& ROCM_WHEEL_TAG="rocm${ROCM_VERSION}" \
&& pip install --no-cache-dir \
--extra-index-url "${W}" \
"rocm-sdk-core==${ROCM_VERSION}" \
"rocm-sdk-devel==${ROCM_VERSION}" \
"${LIBS_PKG}==${ROCM_VERSION}" \
"${W}/rocm-${ROCM_VERSION}.tar.gz" \
"${W}/torch-${TORCH_VERSION}%2B${ROCM_WHEEL_TAG}-${PYTHON_ABI}-${PYTHON_ABI}-linux_x86_64.whl" \
"${W}/torchvision-${TORCHVISION_VERSION}%2B${ROCM_WHEEL_TAG}-${PYTHON_ABI}-${PYTHON_ABI}-linux_x86_64.whl" \
"${W}/torchaudio-${TORCHAUDIO_VERSION}%2B${ROCM_WHEEL_TAG}-${PYTHON_ABI}-${PYTHON_ABI}-linux_x86_64.whl" \
"${W}/triton-${TRITON_VERSION}%2B${ROCM_WHEEL_TAG}-${PYTHON_ABI}-${PYTHON_ABI}-linux_x86_64.whl" \
"${W}/jax_rocm7_pjrt-${JAX_VERSION}%2B${ROCM_WHEEL_TAG}-py3-none-manylinux_2_28_x86_64.whl" \
"${W}/jax_rocm7_plugin-${JAX_VERSION}%2B${ROCM_WHEEL_TAG}-${PYTHON_ABI}-${PYTHON_ABI}-manylinux_2_28_x86_64.whl" \
"${W}/jaxlib-${JAX_VERSION}%2B${ROCM_WHEEL_TAG}-${PYTHON_ABI}-${PYTHON_ABI}-manylinux_2_27_x86_64.whl" \
"jax==${JAX_VERSION}"

# Install flash-attention
RUN git clone --branch ${FA_VERSION} --depth 1 https://github.com/Dao-AILab/flash-attention.git \
&& cd flash-attention \
&& GPU_ARCHS="gfx950;gfx942" FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE FLASH_ATTENTION_SKIP_CK_BUILD=FALSE python setup.py install \
&& cd ..
# Shallow-clone flash-attention at the FA_VERSION ref into /tmp/flash-attention,
# run its setup.py install with GPU_ARCHS set to the single GPU_ARCH build arg,
# and delete the checkout in the same layer.
# NOTE(review): FLASH_ATTENTION_TRITON_AMD_ENABLE and FLASH_ATTENTION_SKIP_CK_BUILD
# are presumably read by flash-attention's setup.py — confirm against that ref.
RUN git clone --branch "${FA_VERSION}" --depth 1 https://github.com/Dao-AILab/flash-attention.git /tmp/flash-attention \
&& cd /tmp/flash-attention \
&& GPU_ARCHS="${GPU_ARCH}" FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE FLASH_ATTENTION_SKIP_CK_BUILD=FALSE \
python setup.py install \
&& rm -rf /tmp/flash-attention

# Install AITER
RUN git clone --no-checkout https://github.com/ROCm/aiter.git \
&& cd aiter \
&& git checkout ${AITER_COMMIT} \
RUN git clone --no-checkout https://github.com/ROCm/aiter.git /tmp/aiter \
&& cd /tmp/aiter \
&& git checkout "${AITER_COMMIT}" \
&& git submodule update --init --recursive \
&& pip install .

# Install JAX
RUN ROCM_MAJOR=$(echo "${ROCM_VERSION}" | cut -d. -f1) && pip install \
https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jax_rocm${ROCM_MAJOR}_pjrt-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-py3-none-manylinux_2_28_x86_64.whl \
https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jax_rocm${ROCM_MAJOR}_plugin-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux_2_28_x86_64.whl \
jax==${JAX_VERSION} \
https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jaxlib-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux_2_27_x86_64.whl
&& GPU_ARCHS="${GPU_ARCH}" pip install --no-build-isolation --no-cache-dir . \
&& rm -rf /tmp/aiter

WORKDIR /workspace/
CMD ["/bin/bash"]
51 changes: 43 additions & 8 deletions .github/workflows/ci-deps-docker-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
#
# See LICENSE for license information.
#
# Build .github/scripts/Dockerfile.ci.deps and push to Artifactory.
# Build .github/scripts/Dockerfile.ci.deps and push to Harbor.
#
# Tag convention (when image_tag is left empty):
# rocm-<rocm_version>-ubuntu24.04-py312-<gpu_arch>
# e.g. rocm-7.12.0-ubuntu24.04-py312-gfx942
#
# Local builds:
# docker build -f .github/scripts/Dockerfile.ci.deps \
# --build-arg GPU_ARCH=gfx942 \
# -t registry-sc-harbor.amd.com/framework/te-ci:rocm-7.12.0-ubuntu24.04-py312-gfx942 .
#
# Required repository secrets:
# ARTIFACTORY_DOCKER_USERNAME / ARTIFACTORY_DOCKER_PASSWORD — registry basic auth
Expand All @@ -16,10 +25,23 @@ name: Publish CI deps Docker image
on:
workflow_dispatch:
inputs:
image_tag:
description: "Image tag pushed to Artifactory (required)"
gpu_arch:
description: "GPU architecture for the CI deps image (Dockerfile GPU_ARCH build-arg)"
required: true
type: choice
options:
- gfx942
- gfx950
rocm_version:
description: "ROCm version string (must match Dockerfile wheel pins when changed)"
required: false
type: string
default: "7.12.0"
image_tag:
description: "Tag to push; leave empty for rocm-<ver>-ubuntu24.04-py312-<gpu_arch>"
required: false
type: string
default: ""

jobs:
build-and-push:
Expand All @@ -37,16 +59,24 @@ jobs:
env:
REGISTRY: ${{ vars.ARTIFACTORY_DOCKER_REGISTRY }}
REPOSITORY: ${{ vars.ARTIFACTORY_CI_DEPS_REPOSITORY }}
IMAGE_TAG: ${{ inputs.image_tag }}
IMAGE_TAG_INPUT: ${{ inputs.image_tag }}
GPU_ARCH: ${{ inputs.gpu_arch }}
ROCM_VERSION: ${{ inputs.rocm_version }}
run: |
set -euo pipefail
if [ -z "${REGISTRY}" ] || [ -z "${REPOSITORY}" ]; then
echo "Set repository variables ARTIFACTORY_DOCKER_REGISTRY and ARTIFACTORY_CI_DEPS_REPOSITORY." >&2
exit 1
fi
if [ -z "${IMAGE_TAG}" ]; then
echo "image_tag must be non-empty." >&2
exit 1
case "${GPU_ARCH}" in
gfx942|gfx950) ;;
*) echo "gpu_arch must be gfx942 or gfx950" >&2; exit 1 ;;
esac
ROCM_VER="${ROCM_VERSION:-7.12.0}"
if [ -n "${IMAGE_TAG_INPUT}" ]; then
echo "IMAGE_TAG=${IMAGE_TAG_INPUT}" >> "${GITHUB_ENV}"
else
echo "IMAGE_TAG=rocm-${ROCM_VER}-ubuntu24.04-py312-${GPU_ARCH}" >> "${GITHUB_ENV}"
fi

- name: Log in to container registry
Expand All @@ -60,12 +90,17 @@ jobs:
env:
REGISTRY: ${{ vars.ARTIFACTORY_DOCKER_REGISTRY }}
REPOSITORY: ${{ vars.ARTIFACTORY_CI_DEPS_REPOSITORY }}
IMAGE_TAG: ${{ inputs.image_tag }}
GPU_ARCH: ${{ inputs.gpu_arch }}
ROCM_VERSION_INPUT: ${{ inputs.rocm_version }}
run: |
set -euo pipefail
FULL_IMAGE="${REGISTRY}/${REPOSITORY}"
: "${IMAGE_TAG:?IMAGE_TAG must be set by the validate step}"
ROCM_VER="${ROCM_VERSION_INPUT:-7.12.0}"
docker build \
-f .github/scripts/Dockerfile.ci.deps \
--build-arg "ROCM_VERSION=${ROCM_VER}" \
--build-arg "GPU_ARCH=${GPU_ARCH}" \
-t "${FULL_IMAGE}:${IMAGE_TAG}" \
.
docker push "${FULL_IMAGE}:${IMAGE_TAG}"
Expand Down
28 changes: 19 additions & 9 deletions .github/workflows/rocm-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 10
outputs:
image-tag: ${{ steps.select-image.outputs.image-tag }}
image-tag-mi30x: ${{ steps.select-image.outputs.image-tag-mi30x }}
image-tag-mi35x: ${{ steps.select-image.outputs.image-tag-mi35x }}
steps:
- name: Checkout repository
uses: actions/checkout@v6
Expand Down Expand Up @@ -88,16 +89,25 @@ jobs:
fi

echo "Selected config key: $JSON_KEY"
IMAGE_TO_USE=$(jq -r --arg key "$JSON_KEY" '.docker_images[$key]' ci/ci_config.json)
CONFIG_ENTRY=$(jq -c --arg key "$JSON_KEY" '.docker_images[$key]' ci/ci_config.json)

MANUAL_OVERRIDE="${{ inputs.docker_image_override }}"
if [[ -n "$MANUAL_OVERRIDE" ]]; then
echo "::notice::Manual override detected: $MANUAL_OVERRIDE"
IMAGE_TO_USE="$MANUAL_OVERRIDE"
IMAGE_MI30X="$MANUAL_OVERRIDE"
IMAGE_MI35X="$MANUAL_OVERRIDE"
elif jq -e '.mi30x and .mi35x' <<< "$CONFIG_ENTRY" > /dev/null; then
IMAGE_MI30X=$(jq -r '.mi30x' <<< "$CONFIG_ENTRY")
IMAGE_MI35X=$(jq -r '.mi35x' <<< "$CONFIG_ENTRY")
else
IMAGE_MI30X=$(jq -r '.' <<< "$CONFIG_ENTRY")
IMAGE_MI35X="$IMAGE_MI30X"
fi

echo "Selected image: $IMAGE_TO_USE"
echo "image-tag=$IMAGE_TO_USE" >> $GITHUB_OUTPUT
echo "Selected mi30x (gfx942) image: $IMAGE_MI30X"
echo "Selected mi35x (gfx950) image: $IMAGE_MI35X"
echo "image-tag-mi30x=$IMAGE_MI30X" >> $GITHUB_OUTPUT
echo "image-tag-mi35x=$IMAGE_MI35X" >> $GITHUB_OUTPUT

build:
# Delegate wheel building to the reusable workflow on dev. It produces a core .whl plus framework .tar.gz sdists under artifact name `te-rocm-wheels`.
Expand Down Expand Up @@ -140,7 +150,7 @@ jobs:

- name: Pull Docker Image
run: |
docker pull ${{ needs.select_image.outputs.image-tag }}
docker pull ${{ matrix.arch_label == 'mi30x' && needs.select_image.outputs.image-tag-mi30x || needs.select_image.outputs.image-tag-mi35x }}

- name: Run Container
run: |
Expand All @@ -155,7 +165,7 @@ jobs:
--group-add $(getent group video | cut -d: -f3) \
-v "${{ github.workspace }}:/workspace" \
-w /workspace \
${{ needs.select_image.outputs.image-tag }}
${{ matrix.arch_label == 'mi30x' && needs.select_image.outputs.image-tag-mi30x || needs.select_image.outputs.image-tag-mi35x }}

- name: Install packages
run: |
Expand Down Expand Up @@ -337,7 +347,7 @@ jobs:

- name: Pull Docker Image
run: |
docker pull ${{ needs.select_image.outputs.image-tag }}
docker pull ${{ matrix.arch_label == 'mi30x' && needs.select_image.outputs.image-tag-mi30x || needs.select_image.outputs.image-tag-mi35x }}

- name: Run Container
run: |
Expand All @@ -352,7 +362,7 @@ jobs:
--group-add $(getent group video | cut -d: -f3) \
-v "${{ github.workspace }}:/workspace" \
-w /workspace \
${{ needs.select_image.outputs.image-tag }}
${{ matrix.arch_label == 'mi30x' && needs.select_image.outputs.image-tag-mi30x || needs.select_image.outputs.image-tag-mi35x }}

- name: Install packages
env:
Expand Down
Loading