You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
nsys-jax: re-work to be more pip-install-able (#1165)
The overarching goal of this PR is to get closer to a world where the
`nsys-jax` tooling is straightforwardly `pip install`-able. While the
diff looks scary, it's mostly re-organisation.
Substantive changes:
- `nsys-jax` no longer bundles Python code in the output archives, the
`install.sh` script provided for users to run on local machines becomes,
loosely, `install 'pip nsys-jax[jupyter] @
git+https://github.com/NVIDIA/JAX-Toolbox.git@COMMIT#subdirectory=.github/container/nsys_jax'`,
where `COMMIT` corresponds to the `nsys-jax` command that produced the
archive. For the `ghcr.io/nvidia/jax` containers, this is the commit of
JAX-Toolbox that triggered the container build.
Changes included:
- Introduce `/opt/pip-tools-post-install.d`, which `pip-finalize.sh`
will execute the contents of *after* installing the `pip`-managed world
- Migrate `install-protoc` to use this, so `pip-finalize.sh` can forget
about that detail.
- Install
https://github.com/brendangregg/FlameGraph/blob/master/flamegraph.pl via
this.
- Patch the `nvtx_gpu_proj_trace` Python code in Nsight Systems 2024.5
and 2024.6 via this.
- Move `nsys-jax` installation (specifically for the containers) into
`install-nsys-jax.sh` and thereby clean up `install-nsight.sh`. The new
script has to be told the git commit hash of JAX-Toolbox that is being
built, because `nsys-jax` bakes this into an installation script in its
output `.zip` archives to ensure the local environment matches the
profile-collection environment.
- The CLI tools like `nsys-jax`, `nsys-jax-combine` and `install-protoc`
are now handled via `[project.scripts]` in `pyproject.toml` instead of
being standalone Python scripts. This is "more standard", and also makes
it easier to share code between `nsys-jax` and `nsys-jax-combine`.
- The Python library is renamed from `jax_nsys` to `nsys_jax` for
consistency.
- It's now possible to set the default data loading path via the
`NSYS_JAX_DEFAULT_PREFIX` environment variable; previously the default
was the current working directory, but that can be inconvenient to steer
in Jupyter environments.
0 commit comments