Apple Silicon does not support float64 #14

@vicha-w

Hi,

I am trying to run the SPANet ttbar example described in [1] on my M4 Pro MacBook Pro, using the following command to check whether it works:

python -m spanet.train -of ./options_files/full_hadronic_ttbar/example.json --time_limit 00:00:01:00

The training script failed with the following traceback:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/top/Documents/ttx_v2_spanet/SPANet/spanet/train.py", line 260, in <module>
    main(**parser.parse_args().__dict__)
  File "/Users/top/Documents/ttx_v2_spanet/SPANet/spanet/train.py", line 190, in main
    trainer.fit(model, ckpt_path=checkpoint)
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 561, in fit
    call._call_and_handle_interrupt(
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 599, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1012, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1054, in _run_stage
    self._run_sanity_check()
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1083, in _run_sanity_check
    val_loop.run()
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/loops/utilities.py", line 179, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 145, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 411, in _evaluation_step
    batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=dataloader_idx)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 328, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 278, in batch_to_device
    return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/core/module.py", line 352, in _apply_batch_transfer_handler
    batch = self._call_batch_hook("transfer_batch_to_device", batch, device, dataloader_idx)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/core/module.py", line 341, in _call_batch_hook
    return trainer_method(trainer, hook_name, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 176, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/pytorch_lightning/core/hooks.py", line 611, in transfer_batch_to_device
    return move_data_to_device(batch, device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/lightning_fabric/utilities/apply_func.py", line 110, in move_data_to_device
    return apply_to_collection(batch, dtype=_TransferableDataType, function=batch_to)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/lightning_utilities/core/apply_func.py", line 74, in apply_to_collection
    return _apply_to_collection_slow(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/lightning_utilities/core/apply_func.py", line 127, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/lightning_utilities/core/apply_func.py", line 127, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/lightning_utilities/core/apply_func.py", line 127, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/lightning_utilities/core/apply_func.py", line 98, in _apply_to_collection_slow
    return function(data, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/top/Documents/virtual-environments/py3p12_spanet/lib/python3.12/site-packages/lightning_fabric/utilities/apply_func.py", line 104, in batch_to
    data_output = data.to(device, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

I checked SPANet's source code but could not find any place that explicitly creates a float64 tensor, which is what would probably cause this error. Setting the accelerator to "cpu" on line 165 of train.py [2] works around the issue, and training then runs without complaints.
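A minimal sketch of how the workaround could be applied only on Apple Silicon instead of hard-coding "cpu". The helper `choose_accelerator` is hypothetical (it is not part of SPANet or Lightning). It assumes that Lightning's "auto" selection is what picks MPS in this run, and it leaves the existing "gpu" branch untouched (on a Mac, Lightning may resolve "gpu" to MPS as well, so that case is not covered here).

```python
def choose_accelerator(num_gpu: int, mps_available: bool) -> str:
    """Pick a Lightning accelerator string (hypothetical helper for train.py).

    - num_gpu > 0: keep the existing "gpu" request unchanged.
    - otherwise, if Apple's MPS backend is available, force "cpu" so that
      float64 tensors in the batch are never moved to an MPS device.
    - otherwise fall back to Lightning's "auto" selection, as before.
    """
    if num_gpu > 0:
        return "gpu"
    if mps_available:
        return "cpu"
    return "auto"


if __name__ == "__main__":
    print(choose_accelerator(0, True))   # cpu
    print(choose_accelerator(0, False))  # auto
```

In train.py the result could then be passed as `accelerator=choose_accelerator(options.num_gpu, torch.backends.mps.is_available())`, with the `devices` and `strategy` arguments left as they are.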

I also tried calling torch.set_default_dtype(torch.float32) at the beginning of the training script, but this did not fix the issue either.
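One plausible explanation, stated as an assumption: `torch.set_default_dtype` only affects tensors created from Python floats and by factory functions such as `torch.zeros`, while a tensor wrapped from an existing NumPy array (for example via `torch.from_numpy`) keeps the array's dtype, and NumPy defaults to float64. If the batches are built from float64 arrays read from the HDF5 file, a sketch like the one below could downcast them before they become tensors. The function `downcast_float64` is hypothetical and assumes batches are nested dicts, lists or tuples of NumPy arrays; where SPANet actually creates its tensors would still need to be checked.

```python
import numpy as np


def downcast_float64(batch):
    """Return a rebuilt nested structure with float64 NumPy arrays cast to float32.

    Dicts, lists, tuples and namedtuples keep their structure; arrays of any
    other dtype and non-array leaves are passed through unchanged.
    """
    if isinstance(batch, np.ndarray):
        # Only float64 needs converting; masks, integer indices, etc. are kept.
        return batch.astype(np.float32) if batch.dtype == np.float64 else batch
    if isinstance(batch, dict):
        return {key: downcast_float64(value) for key, value in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        # namedtuple: rebuild with the same type and field order.
        return type(batch)(*(downcast_float64(value) for value in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(downcast_float64(value) for value in batch)
    return batch


if __name__ == "__main__":
    batch = {"jets": np.zeros((4, 3)), "mask": np.ones(4, dtype=bool)}
    converted = downcast_float64(batch)
    print(converted["jets"].dtype)  # float32
    print(converted["mask"].dtype)  # bool
```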

For your convenience, here is the list of relevant Python packages installed in my virtual environment:

  • h5py==3.14.0
  • lightning-utilities==0.14.3
  • numba==0.61.2
  • numpy==2.2.6
  • pytorch-lightning==2.5.1.post0
  • scikit-learn==1.7.0
  • scipy==1.15.3
  • sympy==1.14.0
  • tensorboard==2.19.0
  • tensorboard-data-server==0.7.2
  • torch==2.7.1
  • torchmetrics==1.7.3

I understand that the code here is designed specifically for CUDA environments, but it would be nice if this network architecture also worked on Apple Silicon. Please let me know if you would like further details, or if you would like me to apply and test any patches on my MacBook Pro.

Thanks!
Vichayanun

[1] https://github.com/Alexanders101/SPANet/blob/master/docs/TTBar.md#training
[2] SPANet/spanet/train.py, lines 164 to 177 in 46c6805:

trainer = pl.Trainer(
    accelerator="gpu" if options.num_gpu > 0 else "auto",
    devices=options.num_gpu if options.num_gpu > 0 else "auto",
    strategy="ddp" if options.num_gpu > 1 else "auto",
    precision="16-mixed" if fp16 else "32-true",
    gradient_clip_val=options.gradient_clip if options.gradient_clip > 0 else None,
    max_epochs=epochs,
    max_time=time_limit,
    logger=logger,
    profiler=profiler,
    callbacks=callbacks
)
