[gfx950] JAX GPU init silently hangs (>5min) unless AMD_COMGR_NAMESPACE=1 #800

Description

@ZJLi2013

## Problem

On AMD Instinct MI350X (gfx950), `import jax; jax.devices()` hangs silently
for more than 5 minutes (no log output, no error) on first GPU initialization.
Setting `AMD_COMGR_NAMESPACE=1` makes init return immediately. This is not
documented in the JAX-ROCm install notes, so a first-time user sees only an
apparent deadlock.

## Reproduction

# jax 0.8.2 (ROCm build), ROCm 7.2, gfx950
python -c "import jax; print(jax.devices())"
# -> hangs >5 min during XLA/COMGR backend init

# Workaround:
AMD_COMGR_NAMESPACE=1 python -c "import jax; print(jax.devices())"
# -> returns immediately, lists the gfx950 device(s)
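
To turn the reproduction above into a check that cannot itself hang, the one-liner can be run in a child interpreter with a timeout. This is a minimal sketch, not part of JAX or ROCm: the `probe` helper, its return strings, and the 60-second default are all assumptions chosen for illustration.

```python
import os
import subprocess
import sys

# The one-liner from the reproduction above.
JAX_PROBE = "import jax; print(jax.devices())"


def probe(code=JAX_PROBE, extra_env=None, timeout_s=60.0):
    """Run `code` in a fresh interpreter and classify the outcome.

    Hypothetical helper: returns "ok" (exit code 0), "error" (non-zero
    exit code), or "timeout" (child killed after `timeout_s` seconds).
    """
    env = dict(os.environ)
    env.update(extra_env or {})
    try:
        result = subprocess.run(
            [sys.executable, "-c", code],
            env=env,
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child before raising.
        return "timeout"
    return "ok" if result.returncode == 0 else "error"


# Example calls on an affected machine (not executed here):
#   probe()                                        # baseline
#   probe(extra_env={"AMD_COMGR_NAMESPACE": "1"})  # with the workaround
```

Running the probe in a subprocess keeps the calling shell or CI job responsive even when backend init stalls, and comparing the two results gives a quick yes/no signal for whether a machine is affected.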

## Root Cause (suspected)
The suspected cause is COMGR namespace handling during XLA backend compilation on gfx950. Setting `AMD_COMGR_NAMESPACE=1` sidesteps it. This looks like a known AMD/COMGR interaction rather than a JAX logic bug, but it surfaces only through JAX init.

## Suggested Fix
Either document `AMD_COMGR_NAMESPACE=1` as a required or recommended env var for gfx950 in the ROCm-JAX README, or set a sane default or emit an early warning when init takes abnormally long.
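
The early-warning option could look roughly like the user-side sketch below: a watchdog timer that prints a hint if a block takes longer than a threshold. Everything here is an assumption for illustration (the `warn_if_slow` name, the 30-second threshold, the message text); it is not an existing JAX API. It also assumes the blocked call releases the Python GIL so the timer thread can run, which has not been verified for this particular hang.

```python
import sys
import threading
from contextlib import contextmanager

# Hypothetical hint text; wording is an assumption.
DEFAULT_HINT = (
    "GPU backend init is taking unusually long. On gfx950, setting "
    "AMD_COMGR_NAMESPACE=1 has been reported to avoid this."
)


@contextmanager
def warn_if_slow(threshold_s=30.0, message=DEFAULT_HINT, stream=None):
    """Print `message` if the wrapped block runs longer than `threshold_s`."""
    out = stream if stream is not None else sys.stderr
    timer = threading.Timer(
        threshold_s, lambda: print(message, file=out, flush=True)
    )
    timer.daemon = True  # never keep the process alive just for the warning
    timer.start()
    try:
        yield
    finally:
        # No-op if the timer already fired; otherwise suppresses the warning.
        timer.cancel()


# Example use (not executed here):
#   with warn_if_slow():
#       import jax
#       jax.devices()
```

The warning does not fix the hang; it only replaces a silent stall with a pointer to the workaround, which is the gap this issue describes.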
