## Problem
On AMD Instinct MI350X (gfx950), `import jax; jax.devices()` silently hangs
for >5 minutes (no log, no error) on first GPU initialization. Setting
`AMD_COMGR_NAMESPACE=1` makes init return immediately. This is not documented in
the JAX-ROCm install notes, so a first-time user just sees an apparent deadlock.
## Reproduction
```bash
# jax 0.8.2 (ROCm build), ROCm 7.2, gfx950
python -c "import jax; print(jax.devices())"
# -> hangs >5 min during XLA/COMGR backend init

# Workaround:
AMD_COMGR_NAMESPACE=1 python -c "import jax; print(jax.devices())"
# -> returns immediately, lists the gfx950 device(s)
```
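In notebooks or scripts where the shell environment is hard to change, the same workaround can probably be applied from Python, provided the variable is set before `import jax` (the hang occurs during backend init, so setting it first is the safe ordering). This is a minimal sketch that assumes COMGR reads `AMD_COMGR_NAMESPACE` from the process environment when the backend initializes. That assumption is only confirmed for the shell form shown above. `ensure_comgr_namespace` is a hypothetical helper name:

```python
import os


def ensure_comgr_namespace(default: str = "1") -> str:
    """Hypothetical helper: set AMD_COMGR_NAMESPACE unless already exported.

    Assumption: COMGR picks the variable up from the process environment
    at backend init, so this must run before `import jax`.
    """
    # setdefault keeps any value the user exported explicitly
    return os.environ.setdefault("AMD_COMGR_NAMESPACE", default)


if __name__ == "__main__":
    value = ensure_comgr_namespace()
    print(f"AMD_COMGR_NAMESPACE={value}")
    # import jax            # only after the variable is in place
    # print(jax.devices())
```

Using `setdefault` rather than plain assignment means an explicit `AMD_COMGR_NAMESPACE=0` from the shell still wins, which keeps the workaround opt-out.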
## Root Cause (suspected)
Likely COMGR namespace handling during XLA backend compilation on gfx950. Setting `AMD_COMGR_NAMESPACE=1` sidesteps it, so this looks like a known AMD/COMGR interaction rather than a JAX logic bug; however, it surfaces only through JAX init.
## Suggested Fix
Either document `AMD_COMGR_NAMESPACE=1` as a required or recommended env var for gfx950 in the ROCm-JAX README, or set a sane default for it, or at least emit an early warning when init takes abnormally long.
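The early-warning option can be sketched on the user side with a timer thread that prints a hint if a call is still running after a threshold, without aborting it. This is a sketch under stated assumptions, not JAX's API: `call_with_slow_warning`, `HANG_HINT`, and the threshold are hypothetical, and an in-tree fix would more likely live in JAX/XLA backend initialization itself.

```python
import sys
import threading
from typing import Callable, Optional, TypeVar

T = TypeVar("T")

# Hypothetical hint text; the env var name is the one from this report.
HANG_HINT = (
    "JAX backend initialization is taking unusually long. On gfx950 (MI350X), "
    "try setting AMD_COMGR_NAMESPACE=1 before importing jax."
)


def _print_hint() -> None:
    # Resolve sys.stderr at call time so redirection still applies.
    print(HANG_HINT, file=sys.stderr, flush=True)


def call_with_slow_warning(
    fn: Callable[[], T],
    threshold_s: float,
    on_slow: Optional[Callable[[], None]] = None,
) -> T:
    """Run fn(); if it is still running after threshold_s seconds, call on_slow().

    The hint fires from a daemon timer thread while fn() keeps running, so the
    call is never interrupted; the user just stops facing a silent hang.
    """
    timer = threading.Timer(threshold_s, on_slow or _print_hint)
    timer.daemon = True
    timer.start()
    try:
        return fn()
    finally:
        timer.cancel()  # no effect if the hint already fired
```

Usage, assuming backend init happens on the first `jax.devices()` call (the 60-second threshold is an arbitrary illustration): `devices = call_with_slow_warning(jax.devices, threshold_s=60.0)`.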