
Adding Support for Apple Silicon Mac GPU #565

Open
NMZ0429 opened this issue Dec 1, 2024 · 1 comment
Labels
enhancement New feature or request

NMZ0429 commented Dec 1, 2024

Feature request

  • I would like to make this library compatible with Apple Silicon GPUs (PyTorch's MPS backend). This requires modifying only two lines of code.

What is the expected behavior?

  • Currently, training on an Apple GPU almost works when device_name is set to "mps". However, at the end of training, when the TabModel.explain method is called, it raises an error.

  • Specifically, if I start training with the model initialized as shown below, line 354 (in TabModel.explain) raises TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

    import torch
    from pytorch_tabnet.tab_model import TabNetRegressor

    regressor = TabNetRegressor(
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=2e-2),
        device_name="mps",
        mask_type="entmax",
    )

    The error is raised in the following loop. Each batch is float64, and data.to(self.device) moves it to the MPS device before .float() casts it to float32. Because MPS does not support float64, the transfer itself fails.

    for batch_nb, data in enumerate(dataloader):
        data = data.to(self.device).float()
        M_explain, masks = self.network.forward_masks(data)
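The ordering problem above can be reproduced in isolation. The sketch below is an alternative to the two-line guard suggested in this issue, not the maintainers' fix: it casts to float32 before the transfer, so MPS never receives a float64 tensor. move_batch is a hypothetical helper name, and the CPU fallback through torch.backends.mps.is_available() exists only so the sketch also runs on machines without an Apple GPU.

```python
import torch


def move_batch(data: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: cast to float32 first, then transfer to device.

    The order used in TabModel.explain, data.to(device).float(), transfers
    the float64 tensor before casting it, which is the step MPS rejects.
    """
    return data.float().to(device)


# Use MPS when available; fall back to CPU so the sketch runs anywhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

batch = torch.rand(8, 5, dtype=torch.float64)  # stand-in for one dataloader batch
moved = move_batch(batch, device)
print(moved.dtype, moved.device)
```

Swapping the order needs no device check at all, since casting on the CPU first is valid for every backend.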

What is motivation or use case for adding/changing the behavior?

  • I believe that using the GPU for training would benefit users of Apple Silicon Macs.

How should this be implemented in your opinion?

  • I confirmed that adding the two lines below to the method solves the issue.
     for batch_nb, data in enumerate(dataloader):
    +    if self.device == torch.device("mps"):
    +        data = data.to(torch.float32)
         data = data.to(self.device).float()

         M_explain, masks = self.network.forward_masks(data)
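One possible caveat with this guard: torch.device equality compares both the device type and the index, so self.device == torch.device("mps") would evaluate to False if the device had been created from "mps:0". A minimal sketch, assuming the same loop structure, that keys the check on device.type instead; prepare_batch is a hypothetical helper, not part of pytorch-tabnet.

```python
import torch


def prepare_batch(data: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Hypothetical helper mirroring the proposed patch, but checking the
    # device *type*, so "mps" and "mps:0" are handled the same way.
    if device.type == "mps":
        # Cast before the transfer (the batch is assumed to be on the CPU).
        data = data.to(torch.float32)
    return data.to(device).float()


# torch.device equality compares type and index, so these two differ.
print(torch.device("mps") == torch.device("mps:0"))
print(torch.device("mps:0").type == "mps")

# On a CPU device the guard is skipped and the original path is used.
out = prepare_batch(torch.ones(2, 3, dtype=torch.float64), torch.device("cpu"))
print(out.dtype)
```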

Are you willing to work on this yourself?
yes

NMZ0429 added the enhancement (New feature or request) label on Dec 1, 2024
Optimox (Collaborator) commented Dec 14, 2024

Hello @NMZ0429,

Thanks for this proposal, feel free to open a PR!


3 participants