-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add rosetta-maxtext Dockerfile
- Loading branch information
Showing
1 changed file
with
77 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# syntax=docker/dockerfile:1-labs | ||
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:upstream-pax | ||
ARG [email protected] | ||
ARG GIT_USER_NAME=NVIDIA | ||
# If set to "true", then will pull new local patches, the manifest.yaml and create-distribution.sh (in case it was updated). | ||
# This is useful for development if you run `./bump.sh -i manifest.yaml` manually and do not want to trigger a full rebuild all | ||
# the way up to the jax build. | ||
ARG UPDATE_PATCHES=false | ||
# It is common for TE developers to test a different TE against the LLM application. This is a knob to override what's in the manifest | ||
# Accepts git-ref's from NVIDIA/TransformerEngine or pull requests (pull/$number/head) | ||
ARG UPDATED_TE_REF="" | ||
|
||
# Rosetta and optionally patches are pulled from this | ||
FROM scratch AS jax-toolbox | ||
|
||
############################################################################### | ||
### Download source and add auxiliary scripts | ||
################################################################################ | ||
|
||
FROM ${BASE_IMAGE} AS mealkit | ||
ARG GIT_USER_EMAIL | ||
ARG GIT_USER_NAME | ||
ARG UPDATE_PATCHES | ||
ARG UPDATED_TE_REF | ||
|
||
ENV ENABLE_TE=1 | ||
|
||
RUN --mount=target=/mnt/jax-toolbox,from=jax-toolbox <<"EOF" bash -exu | ||
MANIFEST_DIR=$(dirname ${MANIFEST_FILE}) | ||
if [[ "${UPDATE_PATCHES}" != "true" && "${UPDATE_PATCHES}" != "false" ]]; then | ||
echo "UPDATE_PATCHES can only be true or false" | ||
exit 1 | ||
fi | ||
if [[ "${UPDATE_PATCHES}" == "true" ]]; then | ||
cp -r /mnt/jax-toolbox/.github/container/patches ${MANIFEST_DIR}/ | ||
cp /mnt/jax-toolbox/.github/container/manifest.yaml ${MANIFEST_DIR}/manifest.yaml | ||
cp /mnt/jax-toolbox/.github/container/create-distribution.sh ${MANIFEST_DIR}/create-distribution.sh | ||
# TODO: remove | ||
cp /mnt/jax-toolbox/.github/container/pip-finalize.sh /usr/local/bin/ | ||
fi | ||
cp -r /mnt/jax-toolbox/rosetta /opt/rosetta | ||
|
||
if [[ -n "${UPDATED_TE_REF}" ]]; then | ||
TE_INSTALL_DIR=/opt/transformer-engine | ||
yq e ".transformer-engine.latest_verified_commit = \"${UPDATED_TE_REF}\"" -i $MANIFEST_FILE | ||
# Install from source instead of pre-built wheel | ||
sed -i -E 's@( file:///opt/transformer-engine)/dist/[^ ]*@\1@' /opt/pip-tools.d/requirements-te.in | ||
git -C $TE_INSTALL_DIR fetch -a | ||
if [[ "${UPDATED_TE_REF}" =~ ^pull/ ]]; then | ||
PR_ID=$(cut -d/ -f2 <<<"${UPDATED_TE_REF}") | ||
git -C $TE_INSTALL_DIR fetch origin ${UPDATED_TE_REF}:PR-${PR_ID} | ||
git -C $TE_INSTALL_DIR checkout PR-${PR_ID} | ||
else | ||
git -C $TE_INSTALL_DIR checkout ${UPDATED_TE_REF} | ||
fi | ||
fi | ||
|
||
# Setting the username/email is required to author commits from patches | ||
git config --global user.email "${GIT_USER_EMAIL}" | ||
git config --global user.name "${GIT_USER_NAME}" | ||
|
||
bash ${MANIFEST_DIR}/create-distribution.sh \ | ||
--manifest ${MANIFEST_FILE} \ | ||
--package maxtext | ||
# Remove .gitconfig to avoid end-user authoring commits as the "build user" | ||
rm -f ~/.gitconfig | ||
EOF | ||
|
||
WORKDIR /opt/rosetta | ||
|
||
############################################################################### | ||
### Install accumulated packages from the base image and the previous stage | ||
################################################################################ | ||
|
||
FROM mealkit as final | ||
|
||
RUN pip-finalize.sh |