Got "Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR" with TF 2.18 and jax with cuda_local #25658

Open
davidshen84 opened this issue Dec 22, 2024 · 0 comments
Labels
bug Something isn't working

Comments


Description

Hi,

This is the code I used to create the docker image.

  • With TF 2.16, the docker image works with the example below.
  • With TF 2.18 and jax[cuda12_local], I get the error message shown below. With TF 2.18 and jax[cuda12], the docker image works with the example below, but the image size doubles.

According to:

TF 2.18 and JAX use the same CUDA version, so they should be able to share the CUDA installation.

Example

from typing import Any, Optional, Tuple, Type, Union

import jax

rng_root = jax.random.PRNGKey(0)
rng_keys = ["noise", "dropout"]
(rng,) = jax.random.split(rng_root, 1)

print(rng)

Error

E1222 23:00:20.145007      21 cuda_dnn.cc:534] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E1222 23:00:20.145107      21 cuda_dnn.cc:538] Memory usage: 7438598144 bytes free, 8589410304 bytes total.
E1222 23:00:20.145305      21 cuda_dnn.cc:534] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E1222 23:00:20.145386      21 cuda_dnn.cc:538] Memory usage: 7438598144 bytes free, 8589410304 bytes total.
E1222 23:00:20.162389      21 cuda_dnn.cc:534] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E1222 23:00:20.162519      21 cuda_dnn.cc:538] Memory usage: 7438598144 bytes free, 8589410304 bytes total.
E1222 23:00:20.162916      21 cuda_dnn.cc:534] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E1222 23:00:20.163038      21 cuda_dnn.cc:538] Memory usage: 7438598144 bytes free, 8589410304 bytes total.
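
As a reading aid for the log above, here is a minimal sketch that extracts the free and total byte counts from one of the "Memory usage" lines. The helper name `parse_memory_usage` and the regular expression are illustrative assumptions, not part of JAX or TensorFlow. With the values from this log, about 87% of device memory was free when the handle creation failed, which suggests (but does not prove) that this is not a plain out-of-memory condition.

```python
import re
from typing import Optional, Tuple

# Pattern for the "Memory usage" lines printed by cuda_dnn.cc in the log above.
# The pattern is an assumption based only on the lines shown in this issue.
MEMORY_LINE = re.compile(r"Memory usage: (\d+) bytes free, (\d+) bytes total")


def parse_memory_usage(line: str) -> Optional[Tuple[int, int]]:
    """Return (free_bytes, total_bytes) from a log line, or None if absent."""
    match = MEMORY_LINE.search(line)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


# One of the lines from the error log in this issue.
log_line = (
    "E1222 23:00:20.145107      21 cuda_dnn.cc:538] "
    "Memory usage: 7438598144 bytes free, 8589410304 bytes total."
)

free_bytes, total_bytes = parse_memory_usage(log_line)
free_fraction = free_bytes / total_bytes
print(
    f"{free_bytes / 2**30:.2f} GiB free of {total_bytes / 2**30:.2f} GiB "
    f"({free_fraction:.1%} free)"
)
```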

System info (python version, jaxlib version, accelerator, etc.)

I used tensorflow/tensorflow:2.18.0-gpu-jupyter as the base image and jax[cuda12_local] for JAX. More details are in the linked repository.
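
To make the system info more concrete, a small sketch like the following could be run inside the container to list the pip-installed versions of the relevant packages. The package list is an assumption (in particular, `nvidia-cudnn-cu12` may not be installed via pip when cuDNN comes from the base image), and `installed_versions` is a hypothetical helper, not a JAX or TensorFlow API.

```python
from importlib import metadata

# Distribution names are assumptions; adjust them to what the image installs.
PACKAGES = ["jax", "jaxlib", "tensorflow", "nvidia-cudnn-cu12"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed through pip; it may still exist as a system library.
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name}: {version or 'not installed via pip'}")
```

This only reports Python distributions. A cuDNN that ships with the base image as a system library would show up as not installed here, so the output is a starting point rather than a full picture of the CUDA stack.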

davidshen84 added the bug label on Dec 22, 2024