Remove dtype parameter, use previously existing "precision" instead #208

Merged
RaulPPelaez merged 9 commits into torchmd:main on Aug 8, 2023

Conversation

RaulPPelaez
Collaborator

The dtype parameter was causing some issues (see #205).
I realized that there is already a "precision" parameter that can be used for the same thing.
This PR removes the dtype argument and uses precision instead, which can be 16, 32 or 64.
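For reference, a minimal sketch of the idea (the names below are illustrative, not necessarily the ones used in torchmd-net): the single precision value is mapped to a torch dtype once, instead of carrying a separate dtype argument around.

import torch

# Hypothetical helper: map the user-facing "precision" option (16/32/64) to a torch dtype.
PRECISION_TO_DTYPE = {16: torch.float16, 32: torch.float32, 64: torch.float64}

def dtype_from_precision(precision: int) -> torch.dtype:
    # e.g. dtype_from_precision(64) -> torch.float64
    return PRECISION_TO_DTYPE[precision]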

@AntonioMirarchi could you check that the issues you were seeing are gone with this PR?

@AntonioMirarchi
Contributor

Yes, let me run a training and I will let you know. Looks good to me!

@RaulPPelaez
Collaborator Author

I worked with @AntonioMirarchi here to include the possibility of training in double precision.
Lightning complains when the module is set to double but the DataModule provides single precision inputs.

Ideally, every dataset class would process its corresponding files in either a user-provided dtype or simply float64, and have its "get" method cast to whatever type is needed.
However, this is a huge undertaking given the large number of datasets currently available (most of which just read/write files in float32, and all of which return float32 from their get method).

Instead I opted for writing a dataset wrapper that is used by the DataModule:

import torch
from torch_geometric.data import Dataset


class FloatCastDatasetWrapper(Dataset):
    def __init__(self, dataset, dtype=torch.float64):
        super(FloatCastDatasetWrapper, self).__init__(
            dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
        )
        self.dataset = dataset
        self.dtype = dtype

    def len(self):
        return len(self.dataset)

    def get(self, idx):
        data = self.dataset.get(idx)
        # Cast every floating-point tensor in the sample to the requested dtype
        for key, value in data:
            if torch.is_tensor(value) and torch.is_floating_point(value):
                setattr(data, key, value.to(self.dtype))
        return data

    def __getattr__(self, name):
        # Check if the attribute exists in the underlying dataset and delegate to it
        if hasattr(self.dataset, name):
            return getattr(self.dataset, name)
        raise AttributeError(
            f"'{type(self).__name__}' and its underlying dataset have no attribute '{name}'"
        )

This simply intercepts the get method and casts every floating-point tensor in the data to the correct type.
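As a usage sketch (the dummy dataset below is purely illustrative and not part of torchmd-net):

import torch
from torch_geometric.data import Data, Dataset


class DummyFloat32Dataset(Dataset):
    # Minimal stand-in for any dataset whose get() returns float32 tensors.
    def len(self):
        return 4

    def get(self, idx):
        return Data(pos=torch.rand(3, 3), y=torch.rand(1))


wrapped = FloatCastDatasetWrapper(DummyFloat32Dataset(), dtype=torch.float64)
print(wrapped.get(0).pos.dtype)  # torch.float64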

Bonus points: should the need arise, it would now be easy to enable training/inference in other floating-point types, such as bfloat16, or even lower-precision formats like NF4, as sketched below.
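For instance (a speculative sketch; whether the rest of the training stack handles it is a separate question, and NF4 would need external tooling rather than a plain torch dtype):

wrapped_bf16 = FloatCastDatasetWrapper(DummyFloat32Dataset(), dtype=torch.bfloat16)
print(wrapped_bf16.get(0).pos.dtype)  # torch.bfloat16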

@AntonioMirarchi
Contributor

LGTM! I trained without any problems or errors using both single and double precision.

@RaulPPelaez
Collaborator Author

@guillemsimeon @raimis please review

@RaulPPelaez merged commit dca6679 into torchmd:main on Aug 8, 2023