I want to make this library compatible with Apple's GPU. Doing so requires adding two lines of code.
What is the expected behavior?
Currently, training on Apple's GPU almost works when device_name is set to "mps". However, at the end of training, when the TabModel.explain method is called, it raises an error.
Specifically, after training with the model initialized this way, line 354 of TabModel.explain raises TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
What is motivation or use case for adding/changing the behavior?
I believe GPU-accelerated training would benefit users of Apple computers.
How should this be implemented in your opinion?
I confirmed that adding the two lines below to the method solves the issue.
  for batch_nb, data in enumerate(dataloader):
+     if self.device == torch.device("mps"):
+         data = data.to(torch.float32)
      data = data.to(self.device).float()
      M_explain, masks = self.network.forward_masks(data)
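For illustration, here is a minimal standalone sketch of the same idea outside the library: cast to float32 while the tensor is still on the CPU, then move it to the device, so a float64 tensor never reaches MPS. The helper names pick_device and to_float32_on_device are hypothetical and not part of pytorch_tabnet. The sketch falls back to CPU when MPS is unavailable, and I have not checked it against the library's own data loaders.

```python
import torch


def pick_device() -> torch.device:
    """Return the MPS device when available, otherwise fall back to CPU."""
    # Hypothetical helper; older PyTorch builds may lack torch.backends.mps.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def to_float32_on_device(data: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Cast to float32 first, then move, so float64 is never sent to the device."""
    # Hypothetical helper: the cast happens on the source device (CPU here),
    # and only the float32 result is transferred.
    return data.float().to(device)


device = pick_device()
batch = torch.zeros(4, 3, dtype=torch.float64)  # stand-in for a float64 batch
moved = to_float32_on_device(batch, device)
print(moved.dtype, moved.device.type)
```

Compared with the device check in the diff above, casting unconditionally avoids a branch. Since the existing line already calls .float() after the move, the resulting dtype should be the same on every device.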
Are you willing to work on this yourself?
yes
The error then comes from the lines referenced below. This is because data is in float64, while Apple's GPU supports only float32.
tabnet/pytorch_tabnet/abstract_model.py
Lines 353 to 356 in 2c0c4eb