RuntimeError: CUDA driver error: operation not supported #1446

Open
paolovic opened this issue Dec 18, 2024 · 5 comments
@paolovic

Hi,

Unfortunately, I get this error when loading the model:

ipdb> n
==((====))==  Unsloth 2024.12.4: Fast Llama patching. Transformers:4.47.0.
   \\   /|    GPU: NVIDIA L40S-48C. Max memory: 47.712 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.4.0+cu121. CUDA: 8.9. CUDA Toolkit: 12.1. Triton: 3.0.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.27.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
RuntimeError: CUDA driver error: operation not supported
> /projects/fine-tuning/unsloth_template.py(70)<module>()
     69
---> 70 model, tokenizer = FastLanguageModel.from_pretrained(
     71     model_name = "/models/Llama-3.3-70B-Instruct-bnb-4bit", # or choose "unsloth/Llama-3.2-1B-Instruct"

ipdb> c
Traceback (most recent call last):
  File "/projects/fine-tuning/unsloth_template.py", line 70, in <module>
    model, tokenizer = FastLanguageModel.from_pretrained(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environments/unsloth_env/lib64/python3.11/site-packages/unsloth/models/loader.py", line 256, in from_pretrained
    model, tokenizer = dispatch_model.from_pretrained(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environments/unsloth_env/lib64/python3.11/site-packages/unsloth/models/llama.py", line 1663, in from_pretrained
    model = AutoModelForCausalLM.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environments/unsloth_env/lib64/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environments/unsloth_env/lib64/python3.11/site-packages/transformers/modeling_utils.py", line 4130, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environments/unsloth_env/lib64/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1083, in __init__
    self.model = LlamaModel(config)
                 ^^^^^^^^^^^^^^^^^^
  File "/environments/unsloth_env/lib64/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 812, in __init__
    self.rotary_emb = LlamaRotaryEmbedding(config=config)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environments/unsloth_env/lib64/python3.11/site-packages/unsloth/models/llama.py", line 1149, in __init__
    self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
  File "/environments/unsloth_env/lib64/python3.11/site-packages/unsloth/models/llama.py", line 1164, in _set_cos_sin_cache
    self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA driver error: operation not supported
@cool9203

cool9203 commented Dec 19, 2024

I have a similar problem. The setup may not be identical, but I get the same error message.

I'm using a vGPU VM with an A6000; the VM is allocated 24 GB of VRAM (the full card has 48 GB).

Full script: https://github.com/cool9203/unsloth-train/blob/d1c1ab702707ae5bdf69c0d303006c5726a61b23/unsloth_train/train_vision.py

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2024.12.4: Fast Mllama vision patching. Transformers: 4.46.2.
   \\   /|    GPU: NVIDIA RTXA6000-24Q. Max memory: 23.784 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post2. FA2 = True]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Loading checkpoint shards:   0%|                                                                                                                                                             | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/iii/yoga/unsloth-train/unsloth_train/__main__.py", line 173, in <module>
    train_model(**parameters)
  File "/usr/lib/python3.10/unittest/mock.py", line 1379, in patched
    return func(*newargs, **newkeywargs)
  File "/home/iii/yoga/unsloth-train/unsloth_train/train_vision.py", line 47, in train_model
    model, tokenizer = FastVisionModel.from_pretrained(
  File "/home/iii/yoga/unsloth-train/.venv/lib/python3.10/site-packages/unsloth/models/loader.py", line 492, in from_pretrained
    model, tokenizer = FastBaseVisionModel.from_pretrained(
  File "/home/iii/yoga/unsloth-train/.venv/lib/python3.10/site-packages/unsloth/models/vision.py", line 145, in from_pretrained
    model = AutoModelForVision2Seq.from_pretrained(
  File "/home/iii/yoga/unsloth-train/.venv/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "/home/iii/yoga/unsloth-train/.venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4225, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/iii/yoga/unsloth-train/.venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4728, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/home/iii/yoga/unsloth-train/.venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 993, in _load_state_dict_into_meta_model
    set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
  File "/home/iii/yoga/unsloth-train/.venv/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 329, in set_module_tensor_to_device
    new_value = value.to(device)
RuntimeError: CUDA driver error: operation not supported

nvidia-smi:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTXA6000-24Q            On  |   00000000:00:10.0 Off |                    0 |
| N/A   N/A    P8             N/A /  N/A  |       1MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Installed packages:

accelerate               1.2.1
aiohappyeyeballs         2.4.4
aiohttp                  3.11.11
aiosignal                1.3.2
async-timeout            5.0.1
attrs                    24.3.0
bitsandbytes             0.45.0
certifi                  2024.12.14
charset-normalizer       3.4.0
cut-cross-entropy        24.12.2
datasets                 3.2.0
dill                     0.3.8
docstring-parser         0.16
editables                0.5
einops                   0.8.0
et-xmlfile               2.0.0
filelock                 3.13.1
flash-attn               2.7.0.post2
frozenlist               1.5.0
fsspec                   2024.2.0
gensim                   4.3.3
hatchling                1.27.0
hf-transfer              0.1.8
huggingface-hub          0.27.0
idna                     3.10
jieba                    0.42.1
jinja2                   3.1.3
markdown-it-py           3.0.0
markupsafe               2.1.5
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.1.0
multiprocess             0.70.16
networkx                 3.2.1
ninja                    1.11.1.3
numpy                    1.26.4
nvidia-cublas-cu12       12.4.5.8
nvidia-cuda-cupti-cu12   12.4.127
nvidia-cuda-nvrtc-cu12   12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.2.1.3
nvidia-curand-cu12       10.3.5.147
nvidia-cusolver-cu12     11.6.1.9
nvidia-cusparse-cu12     12.3.1.170
nvidia-nccl-cu12         2.21.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.4.127
openpyxl                 3.1.5
packaging                24.2
pandas                   2.2.3
pathspec                 0.12.1
peft                     0.14.0
pillow                   10.2.0
pip                      24.3.1
pluggy                   1.5.0
propcache                0.2.1
protobuf                 3.20.3
psutil                   6.1.0
pyarrow                  18.1.0
pygments                 2.18.0
python-dateutil          2.9.0.post0
pytz                     2024.2
pyyaml                   6.0.2
regex                    2024.11.6
requests                 2.32.3
rich                     13.9.4
safetensors              0.4.5
scipy                    1.13.1
sentencepiece            0.2.0
setuptools               75.6.0
shtab                    1.7.1
six                      1.17.0
smart-open               7.1.0
sympy                    1.13.1
tokenizers               0.21.0
tomli                    2.2.1
torch                    2.5.0
torchvision              0.20.1+cu124
tqdm                     4.67.1
transformers             4.47.0
triton                   3.1.0
trl                      0.13.0
trove-classifiers        2024.10.21.16
typeguard                4.2.0
typing-extensions        4.9.0
tyro                     0.9.3
tzdata                   2024.2
unsloth                  2024.12.4
unsloth-zoo              2024.12.1
urllib3                  2.2.3
wheel                    0.45.1
wrapt                    1.17.0
xformers                 0.0.28.post2
xxhash                   3.5.0
yarl                     1.18.3

@joey00072

Check whether torch is installed properly with CUDA:

import torch
print(f"{torch.cuda.is_available()=}")
print(f"{torch.rand(6,9).to(torch.device('cuda'))=}")

@cool9203

cool9203 commented Dec 24, 2024

I don't think this is a torch problem. Maybe it's the vGPU, the newest CUDA driver, or flash attention? That is only a guess.

I can run my script on another computer (Windows + WSL2 + container + A6000 GPU) using the same Docker image, but on the vGPU computer I always get the error.

root@097ce03c393c:/app# python
Python 3.12.8 (main, Dec  6 2024, 19:59:28) [Clang 18.1.8 ] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(f"{torch.cuda.is_available()=}")
torch.cuda.is_available()=True
>>> print(f"{torch.rand(6,9).to(torch.device('cuda'))=}")
torch.rand(6,9).to(torch.device('cuda'))=tensor([[0.9446, 0.0999, 0.1625, 0.9882, 0.0199, 0.6384, 0.4474, 0.0070, 0.8371],
        [0.8284, 0.0053, 0.4615, 0.0505, 0.2884, 0.8938, 0.3250, 0.2470, 0.2900],
        [0.9626, 0.8180, 0.8720, 0.4847, 0.0161, 0.0646, 0.9458, 0.0677, 0.3900],
        [0.6896, 0.4652, 0.3111, 0.5769, 0.7169, 0.2230, 0.5424, 0.2265, 0.0369],
        [0.7812, 0.7536, 0.8573, 0.7497, 0.3073, 0.2007, 0.1929, 0.0783, 0.6394],
        [0.3464, 0.0429, 0.9927, 0.0218, 0.7993, 0.1402, 0.8592, 0.6124, 0.6949]],
       device='cuda:0')
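
The check above succeeds with no allocator settings in the environment, so it may not exercise the configuration the failing scripts run under. Below is a minimal sketch of a probe that repeats a small CUDA allocation in a fresh interpreter with a chosen `PYTORCH_CUDA_ALLOC_CONF` value. The helper names (`build_probe_env`, `run_probe`) are hypothetical, and the `expandable_segments:True` value is taken from the setting identified further down in this thread; whether it reproduces the error on a given machine is an assumption to verify, not a guarantee.

```python
import os
import subprocess
import sys

# Allocator setting reported further down in this thread as the trigger.
SUSPECT_CONF = "expandable_segments:True"

# Child program: force one real CUDA allocation, then report success.
PROBE_CODE = "import torch; torch.zeros(1, device='cuda'); print('ok')"


def build_probe_env(alloc_conf, base_env=None):
    """Copy the environment, setting or clearing PYTORCH_CUDA_ALLOC_CONF."""
    env = dict(os.environ if base_env is None else base_env)
    if alloc_conf:
        env["PYTORCH_CUDA_ALLOC_CONF"] = alloc_conf
    else:
        env.pop("PYTORCH_CUDA_ALLOC_CONF", None)
    return env


def run_probe(alloc_conf):
    """Run the allocation in a new interpreter.

    Returns (succeeded, last line of stderr), so a failing run shows
    the final line of the child's traceback.
    """
    result = subprocess.run(
        [sys.executable, "-c", PROBE_CODE],
        env=build_probe_env(alloc_conf),
        capture_output=True,
        text=True,
    )
    lines = result.stderr.strip().splitlines()
    return result.returncode == 0, (lines[-1] if lines else "")
```

Calling `run_probe(None)` and then `run_probe(SUSPECT_CONF)` on the affected machine and comparing the two results should show whether the allocator setting alone is enough to trigger the error. A subprocess is used because the allocator configuration is read once per process, so changing the variable after CUDA has been initialised would not be a fair test.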

I can also run the PyTorch quickstart.

PyTorch quickstart output:

root@097ce03c393c:/app# python test.py
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26.4M/26.4M [00:07<00:00, 3.77MB/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29.5k/29.5k [00:00<00:00, 109kB/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.42M/4.42M [00:02<00:00, 1.75MB/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5.15k/5.15k [00:00<00:00, 83.0MB/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
Epoch 1
-------------------------------
loss: 2.304316  [   64/60000]
loss: 2.289975  [ 6464/60000]
loss: 2.270400  [12864/60000]
loss: 2.259455  [19264/60000]
loss: 2.245398  [25664/60000]
loss: 2.209586  [32064/60000]
loss: 2.219506  [38464/60000]
loss: 2.183300  [44864/60000]
loss: 2.181254  [51264/60000]
loss: 2.146851  [57664/60000]
Test Error:
 Accuracy: 41.1%, Avg loss: 2.141327

Epoch 2
-------------------------------
loss: 2.155709  [   64/60000]
loss: 2.144184  [ 6464/60000]
loss: 2.084640  [12864/60000]
loss: 2.103154  [19264/60000]
loss: 2.049047  [25664/60000]
loss: 1.985874  [32064/60000]
loss: 2.025819  [38464/60000]
loss: 1.936248  [44864/60000]
loss: 1.945789  [51264/60000]
loss: 1.884334  [57664/60000]
Test Error:
 Accuracy: 52.2%, Avg loss: 1.869702

I also tried training a text model and got the same problem:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/app/unsloth_train/__main__.py", line 173, in <module>
    train_model(**parameters)
  File "/root/.local/share/uv/python/cpython-3.12.8-linux-x86_64-gnu/lib/python3.12/unittest/mock.py", line 1396, in patched
    return func(*newargs, **newkeywargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/unsloth_train/train.py", line 52, in train_model
    model, tokenizer = FastLanguageModel.from_pretrained(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/.venv/lib/python3.12/site-packages/unsloth/models/loader.py", line 256, in from_pretrained
    model, tokenizer = dispatch_model.from_pretrained(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/.venv/lib/python3.12/site-packages/unsloth/models/llama.py", line 1663, in from_pretrained
    model = AutoModelForCausalLM.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/.venv/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/.venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4130, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/.venv/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1083, in __init__
    self.model = LlamaModel(config)
                 ^^^^^^^^^^^^^^^^^^
  File "/app/.venv/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 812, in __init__
    self.rotary_emb = LlamaRotaryEmbedding(config=config)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/.venv/lib/python3.12/site-packages/unsloth/models/llama.py", line 1149, in __init__
    self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
  File "/app/.venv/lib/python3.12/site-packages/unsloth/models/llama.py", line 1164, in _set_cos_sin_cache
    self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA driver error: operation not supported

@cool9203

Hello, my coworker @treeaaa tested it. The error comes from this line:

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:[64:128,256:64,>:32]"

Deleting this line makes it work.

Our guess is that this is a low-level issue: the vGPU or the CUDA driver does not support what this setting requires, hence the error.
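
If the line cannot simply be deleted (for example, when the variable is set by something outside your own script), one option is to strip the option from the environment before CUDA is initialised. The sketch below assumes that removing only `expandable_segments` and keeping the other options is sufficient; the thread only confirms that deleting the whole line works. The function names are hypothetical. The parser splits on commas outside square brackets, because the `roundup_power2_divisions` value itself contains commas.

```python
import os


def split_alloc_conf(conf):
    """Split an allocator config string on commas outside [...] brackets."""
    parts, current, depth = [], [], 0
    for ch in conf:
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current))
            current = []
        else:
            current.append(ch)
    if current:
        parts.append("".join(current))
    return parts


def drop_alloc_option(conf, name="expandable_segments"):
    """Return conf without the option called `name`; other options are kept."""
    kept = [
        part for part in split_alloc_conf(conf)
        if part.split(":", 1)[0].strip() != name
    ]
    return ",".join(kept)


def disable_expandable_segments(environ=os.environ):
    """Remove expandable_segments from PYTORCH_CUDA_ALLOC_CONF in place."""
    value = environ.get("PYTORCH_CUDA_ALLOC_CONF")
    if value is None:
        return
    cleaned = drop_alloc_option(value)
    if cleaned:
        environ["PYTORCH_CUDA_ALLOC_CONF"] = cleaned
    else:
        # Nothing left: remove the variable rather than leave it empty.
        del environ["PYTORCH_CUDA_ALLOC_CONF"]
```

`disable_expandable_segments()` would need to run before the first CUDA allocation, and most safely before `torch` is imported, since the allocator reads this variable once per process. If a library sets the variable when it is imported, the call has to come after that import but still before any CUDA work.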

@paolovic
Author

Thank you @cool9203, this workaround works.
