
If CUDA 12.1 is installed, pip-installed ptxas binary is not used and jax throws an error #25718

takkyu2 opened this issue Jan 3, 2025 · 1 comment
takkyu2 commented Jan 3, 2025

Description

Please feel free to close this issue if this is expected behavior.
If it is expected, it would be great to have an easy way to fix it on the users' side other than installing a newer CUDA version.

Summary

If CUDA version 12.1 is installed on the system and its ptxas is already on the system PATH:

ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Feb__7_19:30:12_PST_2023
Cuda compilation tools, release 12.1, V12.1.66
Build cuda_12.1.r12.1/compiler.32415258_0

After installing jax through pip,

python -m venv venv
source venv/bin/activate
pip install -U "jax[cuda12]"

The system-installed ptxas binary is used instead of the pip-installed one, and jax throws an error:

import jax
jax.numpy.zeros(3)
Full Error Log
E0103 16:41:38.489316 2217885 ptx_compiler_helpers.cc:87] *** WARNING *** Invoking ptxas with version 12.1.66, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x.y up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 6149, in zeros
  return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1752, in full
  return broadcast(fill_value, shape)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1244, in broadcast
  return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims,
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1278, in broadcast_in_dim
  return broadcast_in_dim_p.bind(
         ^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/core.py", line 463, in bind
  return self.bind_with_trace(prev_trace, args, params)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/core.py", line 468, in bind_with_trace
  return trace.process_primitive(self, args, params)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/core.py", line 941, in process_primitive
  return primitive.impl(*args, **params)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
  outs = fun(*args)
         ^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  return fun(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 337, in cache_miss
  pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 195, in _python_pjit_helper
  out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1672, in _pjit_call_impl_python
  ).compile()
    ^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2415, in compile
  executable = UnloadedMeshExecutable.from_hlo(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2923, in from_hlo
  xla_executable = _cached_compilation(
                   ^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2729, in _cached_compilation
  xla_executable = compiler.compile_or_get_cached(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 452, in compile_or_get_cached
  return _compile_and_write_cache(
         ^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 653, in _compile_and_write_cache
  executable = backend_compile(
               ^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
  return func(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 309, in backend_compile
  raise e
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 303, in backend_compile
  return backend.compile(built_c, compile_options=options)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /var/tmp/tempfile-###-8aa5641830c7ce6-2217885-62acff3e14c12, line 5; fatal   : Unsupported .version 8.3; current version is '8.1'
ptxas fatal   : Ptx assembly aborted due to errors

I set LD_LIBRARY_PATH to be empty, but still encountered this error. That seems consistent with the failure coming from which ptxas executable is found, not from shared libraries: the PTX emitted by jaxlib targets .version 8.3, which the CUDA 12.1 ptxas (PTX ISA 8.1) cannot assemble.
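
(To see which binary shadows which, a quick check of plain PATH resolution; XLA's own lookup may consider additional locations, so treat this only as an illustration:)

import shutil, subprocess

# Shows which ptxas a plain PATH lookup resolves first; with the system CUDA 12.1
# directory ahead of the venv's nvidia/cuda_nvcc/bin, this prints the old 12.1 binary.
found = shutil.which("ptxas")
print("first ptxas on PATH:", found)
if found:
    print(subprocess.run([found, "--version"], capture_output=True, text=True).stdout)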

Related Issues

#25344: Reports the same error, but in my case Triton is not installed, so I think this is a separate issue.
#18578: About the ptxas binary priority issue.

Workaround

We can manually prepend the directory containing the pip-installed ptxas binary to PATH to avoid this error:

export PATH=$(python -c "import site; print(site.getsitepackages()[0] + '/nvidia/cuda_nvcc/bin')"):$PATH
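
(The same thing can be done from inside Python before jax triggers any compilation; this is just an in-process sketch with the same site-packages layout assumption as the export above:)

import os, site

# Prepend the wheel-provided ptxas directory so it wins over the system CUDA 12.1 one.
nvcc_bin = os.path.join(site.getsitepackages()[0], "nvidia", "cuda_nvcc", "bin")
os.environ["PATH"] = nvcc_bin + os.pathsep + os.environ.get("PATH", "")

import jax
print(jax.numpy.zeros(3))  # should now compile with the newer pip-installed ptxas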

System info (python version, jaxlib version, accelerator, etc.)

CUDA version: 12.1

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.2.1
python: 3.12.5 (main, Aug 19 2024, 18:21:17) [GCC 9.4.0]
device info: NVIDIA H100 80GB HBM3-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='###', release='###', version='#18-Ubuntu SMP Fri Jul 26 14:21:24 UTC 2024', machine='x86_64')
@MuhammadHakami

I would also suggest installing ptxas from conda with conda install cuda -c nvidia, since ptxas is part of the CUDA toolkit. Installing two copies of ptxas is not ideal, but it will work as a temporary solution for now.
Thanks @takkyu2 for the detailed comment.
