Please feel free to close this issue if this is expected behavior.
If it is expected, it would be great if there were an easy way to fix it from our (users') side other than installing a newer CUDA version.
Summary
If CUDA 12.1 is installed on the system and its ptxas is already on the system PATH:
ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Feb__7_19:30:12_PST_2023
Cuda compilation tools, release 12.1, V12.1.66
Build cuda_12.1.r12.1/compiler.32415258_0
After installing jax through pip, the system-installed ptxas binary (instead of the pip-installed one) is used, and jax throws an error:
import jax
jax.numpy.zeros(3)
Full Error Log
E0103 16:41:38.489316 2217885 ptx_compiler_helpers.cc:87] *** WARNING *** Invoking ptxas with version 12.1.66, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x.y up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 6149, in zeros
return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1752, in full
return broadcast(fill_value, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1244, in broadcast
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1278, in broadcast_in_dim
return broadcast_in_dim_p.bind(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/core.py", line 463, in bind
return self.bind_with_trace(prev_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/core.py", line 468, in bind_with_trace
return trace.process_primitive(self, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/core.py", line 941, in process_primitive
return primitive.impl(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
outs = fun(*args)
^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 337, in cache_miss
pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 195, in _python_pjit_helper
out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1672, in _pjit_call_impl_python
).compile()
^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2415, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2923, in from_hlo
xla_executable = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2729, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 452, in compile_or_get_cached
return _compile_and_write_cache(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 653, in _compile_and_write_cache
executable = backend_compile(
^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 309, in backend_compile
raise e
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 303, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /var/tmp/tempfile-###-8aa5641830c7ce6-2217885-62acff3e14c12, line 5; fatal : Unsupported .version 8.3; current version is '8.1'
ptxas fatal : Ptx assembly aborted due to errors
I set LD_LIBRARY_PATH to be empty, but still encountered this error.
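LD_LIBRARY_PATH only changes where the dynamic linker looks for shared libraries, while ptxas is a standalone executable that XLA launches as a subprocess (hence "ptxas exited with non-zero error code" above), so clearing it would not be expected to change which ptxas gets run. A small Python check can show which ptxas is first on PATH and whether its version falls in the range flagged by the XLA warning above. This is a sketch: the regex assumes the "release X.Y, VX.Y.Z" line format shown in the ptxas --version output above, the 12.6.3 threshold is copied from the warning message, and the function names are hypothetical.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Threshold taken from the XLA warning: CUDA <= 12.6.2 is affected.
FIRST_FIXED = (12, 6, 3)


def parse_ptxas_version(output: str) -> Optional[Tuple[int, ...]]:
    """Extract (major, minor, patch) from `ptxas --version` output, or None."""
    match = re.search(r"release \d+\.\d+, V(\d+)\.(\d+)\.(\d+)", output)
    if match is None:
        return None
    return tuple(int(group) for group in match.groups())


def is_affected(version: Tuple[int, ...]) -> bool:
    """True if the version is older than the first release the warning accepts."""
    return version < FIRST_FIXED


def inspect_ptxas_on_path() -> None:
    """Print which ptxas is first on PATH and whether its version is affected."""
    path = shutil.which("ptxas")
    if path is None:
        print("no ptxas found on PATH")
        return
    result = subprocess.run([path, "--version"], capture_output=True, text=True)
    version = parse_ptxas_version(result.stdout)
    if version is None:
        status = "unrecognized version output"
    elif is_affected(version):
        status = f"affected (older than {FIRST_FIXED})"
    else:
        status = "not affected"
    print(path, version, status)


if __name__ == "__main__":
    inspect_ptxas_on_path()
```

With CUDA 12.1 on PATH, as in this report, the parsed version would be (12, 1, 66), which sorts below (12, 6, 3) and is reported as affected.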
Related Issues
#25344: About the same error, but in my case I don't have triton installed, so I think this is a separate issue.
#18578: On the ptxas binary priority issue.
Workaround
We can manually prepend the pip-installed ptxas binary path to PATH to avoid this error:
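A minimal Python sketch of this workaround, assuming the pip wheel (nvidia-cuda-nvcc-cu12) places ptxas at <site-packages>/nvidia/cuda_nvcc/bin/ptxas. That location is an assumption, so verify it in your own environment. The helper names find_pip_ptxas_dir and prepend_to_path are hypothetical. PATH has to be changed before jax compiles anything, so this would run before importing jax (or the equivalent PATH would be exported in the shell).

```python
import os
import sys
from pathlib import Path
from typing import Iterable, MutableMapping, Optional


def find_pip_ptxas_dir(search_paths: Optional[Iterable[str]] = None) -> Optional[Path]:
    """Return the bin/ directory holding a pip-installed ptxas, or None.

    Assumption: the wheel installs it as <site-packages>/nvidia/cuda_nvcc/bin/ptxas.
    """
    for entry in search_paths if search_paths is not None else sys.path:
        candidate = Path(entry) / "nvidia" / "cuda_nvcc" / "bin"
        if (candidate / "ptxas").is_file():
            return candidate
    return None


def prepend_to_path(directory, env: Optional[MutableMapping[str, str]] = None) -> str:
    """Move `directory` to the front of PATH (dropping any existing copy) and return PATH."""
    env = os.environ if env is None else env
    target = str(directory)
    parts = [p for p in env.get("PATH", "").split(os.pathsep) if p and p != target]
    env["PATH"] = os.pathsep.join([target] + parts)
    return env["PATH"]


if __name__ == "__main__":
    ptxas_dir = find_pip_ptxas_dir()
    if ptxas_dir is None:
        print("no pip-installed ptxas found on sys.path")
    else:
        prepend_to_path(ptxas_dir)
        print("PATH now starts with", ptxas_dir)
    # import jax only after this point, so the adjusted PATH is in effect
```

Moving an existing entry to the front, rather than appending a duplicate, keeps PATH from growing if the snippet runs more than once.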
I would also suggest installing ptxas from conda with conda install cuda -c nvidia, since ptxas is part of the CUDA toolkit. Installing two copies of ptxas is not ideal, but it will work as a temporary solution for now.
Thanks @takkyu2 for the detailed comment.
System info (python version, jaxlib version, accelerator, etc.)