Commit 5c4b687

nsys-jax: re-work to be more pip-install-able (#1165)
The overarching goal of this PR is to get closer to a world where the `nsys-jax` tooling is straightforwardly `pip install`-able. While the diff looks scary, it's mostly re-organisation.

Substantive changes:
- `nsys-jax` no longer bundles Python code in the output archives; the `install.sh` script provided for users to run on local machines becomes, loosely, `pip install 'nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@COMMIT#subdirectory=.github/container/nsys_jax'`, where `COMMIT` corresponds to the `nsys-jax` command that produced the archive. For the `ghcr.io/nvidia/jax` containers, this is the commit of JAX-Toolbox that triggered the container build. Changes included:
  - Introduce `/opt/pip-tools-post-install.d`, whose contents `pip-finalize.sh` will execute *after* installing the `pip`-managed world.
    - Migrate `install-protoc` to use this, so `pip-finalize.sh` can forget about that detail.
    - Install https://github.com/brendangregg/FlameGraph/blob/master/flamegraph.pl via this.
    - Patch the `nvtx_gpu_proj_trace` Python code in Nsight Systems 2024.5 and 2024.6 via this.
  - Move `nsys-jax` installation (specifically for the containers) into `install-nsys-jax.sh` and thereby clean up `install-nsight.sh`. The new script has to be told the git commit hash of JAX-Toolbox that is being built, because `nsys-jax` bakes this into an installation script in its output `.zip` archives to ensure the local environment matches the profile-collection environment.
- The CLI tools like `nsys-jax`, `nsys-jax-combine` and `install-protoc` are now handled via `[project.scripts]` in `pyproject.toml` instead of being standalone Python scripts. This is "more standard", and also makes it easier to share code between `nsys-jax` and `nsys-jax-combine`.
- The Python library is renamed from `jax_nsys` to `nsys_jax` for consistency.
- It's now possible to set the default data loading path via the `NSYS_JAX_DEFAULT_PREFIX` environment variable; previously the default was the current working directory, which can be inconvenient to control in Jupyter environments.
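As a rough illustration of the two user-facing changes above, the sketch below composes the `pip install` requirement for a given commit and sets `NSYS_JAX_DEFAULT_PREFIX`. The `nsys_jax_requirement` helper is hypothetical, the commit value is simply this commit's short hash used as a stand-in, and the `/tmp/nsys-jax-profiles` path is an assumption. The command is printed rather than executed.

```shell
# Stand-in value: in practice COMMIT is the JAX-Toolbox commit that corresponds
# to the nsys-jax command that produced the archive.
COMMIT="5c4b687"

# Hypothetical helper: compose the requirement string quoted in the commit message.
nsys_jax_requirement() {
  printf 'nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@%s#subdirectory=.github/container/nsys_jax' "$1"
}

# Print (rather than run) the install command, so nothing is installed here.
echo "pip install '$(nsys_jax_requirement "${COMMIT}")'"

# Assumed example path: point the default data loading path somewhere other
# than the current working directory, e.g. before starting Jupyter.
export NSYS_JAX_DEFAULT_PREFIX="/tmp/nsys-jax-profiles"
```

Exporting the variable before launching Jupyter avoids depending on whichever working directory the notebook server happens to start in.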
1 parent de72dd8 commit 5c4b687

36 files changed, +1497 -1283 lines changed

Diff for: .github/container/Dockerfile.base

+9 -20
@@ -3,6 +3,7 @@ ARG BASE_IMAGE=nvidia/cuda:12.6.2-devel-ubuntu22.04
 ARG GIT_USER_NAME="JAX Toolbox"
 
 ARG CLANG_VERSION=18
+ARG JAX_TOOLBOX_REF
 
 ###############################################################################
 ## Obtain GCP's NCCL TCPx plugin
@@ -30,6 +31,7 @@ ARG BASE_IMAGE
 ARG GIT_USER_EMAIL
 ARG GIT_USER_NAME
 ARG CLANG_VERSION
+ARG JAX_TOOLBOX_REF
 ENV CUDA_BASE_IMAGE=${BASE_IMAGE}
 
 ###############################################################################
@@ -110,7 +112,7 @@ RUN <<"EOF" bash -ex
 git config --global user.name "${GIT_USER_NAME}"
 git config --global user.email "${GIT_USER_EMAIL}"
 EOF
-RUN mkdir -p /opt/pip-tools.d
+RUN mkdir -p /opt/pip-tools.d /opt/pip-tools-post-install.d
 ADD --chmod=777 \
 git-clone.sh \
 pip-finalize.sh \
@@ -141,7 +143,6 @@ COPY --from=tcpx-installer /var/lib/tcpx/lib64 ${TCPX_LIBRARY_PATH}
 ###############################################################################
 
 ADD install-nsight.sh /usr/local/bin
-ADD nsys-2024.5-tid-export.patch /opt/nvidia
 RUN install-nsight.sh
 
 ###############################################################################
@@ -183,7 +184,7 @@ ENV PATH=/opt/amazon/efa/bin:${PATH}
 ADD install-nccl-sanity-check.sh /usr/local/bin
 ADD nccl-sanity-check.cu /opt
 RUN install-nccl-sanity-check.sh
-ADD jax-nccl-test parallel-launch /usr/local/bin
+ADD jax-nccl-test parallel-launch /usr/local/bin/
 
 ###############################################################################
 ## Add the systemcheck to the entrypoint.
@@ -199,23 +200,11 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/
 # COPY gcp-autoconfig.sh /opt/nvidia/entrypoint.d/
 
 ###############################################################################
-## Add helper scripts for profiling with Nsight Systems
-##
-## The scripts saved to /opt/jax_nsys are embedded in the output archives
-## written by nsys-jax, while the nsys-jax and nsys-jax-combine scripts are
-## only used inside the containers.
-###############################################################################
-ADD nsys-jax nsys-jax-combine /usr/local/bin/
-ADD jax_nsys/ /opt/jax_nsys
-# The jax_nsys package should be installed inside the containers, so nsys-jax
-# can eagerly execute analysis recipes (--nsys-jax-analysis) in the container
-# environment, without an extra layer of virtual environment indirection.
-RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in
-# This should be embedded in output archives and be runnable inside containers
-RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/
-# Should be available for execution inside the containers, should not be
-# embedded in the output archives.
-ADD jax_nsys_tests/ /opt/jax_nsys_tests
+## Install the nsys-jax JAX/XLA-aware profiling scripts, patch Nsight Systems
+###############################################################################
+
+ADD install-nsys-jax.sh /usr/local/bin
+RUN install-nsys-jax.sh ${JAX_TOOLBOX_REF}
 
 ###############################################################################
 ## Copy manifest file to the container
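
Since `JAX_TOOLBOX_REF` is declared without a default, it has to be supplied at build time. A minimal sketch of how that might look is given below; the fallback hash, the build context path and the exact `docker build` invocation are assumptions for illustration, not the project's actual CI configuration. The command is only printed, not run.

```shell
# Assumed fallback: use this commit's short hash if no ref is provided.
REF="${JAX_TOOLBOX_REF:-5c4b687}"

# Hypothetical invocation: the build context is assumed to be .github/container,
# because the Dockerfile ADDs install-nsys-jax.sh by a relative path.
build_cmd="docker build --build-arg JAX_TOOLBOX_REF=${REF} -f .github/container/Dockerfile.base .github/container"

# Print the command instead of executing it.
echo "${build_cmd}"
```

If the argument is omitted, `install-nsys-jax.sh` receives an empty first argument and exits with status 1, so the build fails early rather than producing an image with an unknown ref.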

Diff for: .github/container/install-nsight.sh

-11
@@ -16,14 +16,3 @@ apt-get install -y nsight-compute nsight-systems-cli-2024.6.1
 apt-get clean
 
 rm -rf /var/lib/apt/lists/*
-
-for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do
-if [[ -d "${NSYS}" ]]; then
-# * can match at least sbsa-armv8 and x86
-(cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
-fi
-done
-
-# Install extra dependencies needed for `nsys recipe ...` commands. These are
-# used by the nsys-jax wrapper script.
-ln -s $(dirname $(realpath $(command -v nsys)))/python/packages/nsys_recipe/requirements/common.txt /opt/pip-tools.d/requirements-nsys-recipe.in

Diff for: .github/container/install-nsys-jax.sh

+32
@@ -0,0 +1,32 @@
+#!/bin/bash
+set -exo pipefail
+
+REF="$1"
+if [[ -z "${REF}" ]]; then
+echo "$0: <git ref of JAX-Toolbox>"
+exit 1
+fi
+
+# Install extra dependencies needed for `nsys recipe ...` commands. These are
+# used by the nsys-jax wrapper script.
+NSYS_DIR=$(dirname $(realpath $(command -v nsys)))
+ln -s ${NSYS_DIR}/python/packages/nsys_recipe/requirements/common.txt /opt/pip-tools.d/requirements-nsys-recipe.in
+
+# Install the nsys-jax package, which includes nsys-jax, nsys-jax-combine,
+# install-protoc (called from pip-finalize.sh), and nsys-jax-patch-nsys as well as the
+# nsys_jax Python library.
+URL="git+https://github.com/NVIDIA/JAX-Toolbox.git@${REF}#subdirectory=.github/container/nsys_jax&egg=nsys-jax"
+echo "-e '${URL}'" > /opt/pip-tools.d/requirements-nsys-jax.in
+
+# protobuf will be installed at least as a dependency of nsys_jax in the base
+# image, but the installed version is likely to be influenced by other packages.
+echo "install-protoc /usr/local" > /opt/pip-tools-post-install.d/protoc
+chmod 755 /opt/pip-tools-post-install.d/protoc
+
+# Make sure flamegraph.pl is available
+echo "install-flamegraph /usr/local" > /opt/pip-tools-post-install.d/flamegraph
+chmod 755 /opt/pip-tools-post-install.d/flamegraph
+
+# Make sure Nsight Systems Python patches are installed if needed
+echo "nsys-jax-patch-nsys" > /opt/pip-tools-post-install.d/patch-nsys
+chmod 755 /opt/pip-tools-post-install.d/patch-nsys
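
The hook files written above are one-line scripts that `pip-finalize.sh` is expected to execute after the `pip`-managed packages are installed. The `pip-finalize.sh` change is not shown in this excerpt, so the following is only a sketch of how such a hook directory could be processed; the `run_post_install_hooks` function and the temporary demo directory are hypothetical.

```shell
# Hypothetical helper: run every executable regular file in a hook directory.
run_post_install_hooks() {
  for hook in "$1"/*; do
    # Skip anything that is not an executable regular file; this also covers
    # the unexpanded glob when the directory is empty.
    [ -f "${hook}" ] && [ -x "${hook}" ] || continue
    echo "Running post-install hook: ${hook}"
    # The hooks written above have no shebang line, so pass them to a shell
    # explicitly and stop at the first failing hook.
    sh -e "${hook}" || return 1
  done
}

# Demo in a throwaway directory standing in for /opt/pip-tools-post-install.d.
demo_dir="$(mktemp -d)"
echo 'echo "example hook ran"' > "${demo_dir}/example"
chmod 755 "${demo_dir}/example"
run_post_install_hooks "${demo_dir}"
rm -rf "${demo_dir}"
```

Keeping each post-install step in its own file means new steps (such as the `flamegraph` and `patch-nsys` hooks) can be added without editing the script that runs them.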

Diff for: .github/container/jax_nsys/install-protoc

-65
This file was deleted.

Diff for: .github/container/jax_nsys/install.sh

-38
This file was deleted.

Diff for: .github/container/jax_nsys/python/jax_nsys/pyproject.toml

-17
This file was deleted.
