Remove dtype parameter, use previously existing "precision" instead #208
Conversation
Yes, let me run a training and I will let you know. Looks good to me!
I worked with @AntonioMirarchi here to include the possibility of training in double precision. Ideally every dataset class should process its corresponding files in either a user-provided dtype or just float64, and have the "get" method cast to whatever type is needed. Instead I opted for writing a dataset wrapper that is used by the DataModule:

```python
import torch
from torch_geometric.data import Dataset


class FloatCastDatasetWrapper(Dataset):
    def __init__(self, dataset, dtype=torch.float64):
        super(FloatCastDatasetWrapper, self).__init__(
            dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
        )
        self.dataset = dataset
        self.dtype = dtype

    def len(self):
        return len(self.dataset)

    def get(self, idx):
        data = self.dataset.get(idx)
        # Cast every floating-point tensor in the sample to the requested dtype
        for key, value in data:
            if torch.is_tensor(value) and torch.is_floating_point(value):
                setattr(data, key, value.to(self.dtype))
        return data

    def __getattr__(self, name):
        # Delegate any other attribute access to the underlying dataset
        if hasattr(self.dataset, name):
            return getattr(self.dataset, name)
        raise AttributeError(
            f"'{type(self).__name__}' and its underlying dataset have no attribute '{name}'"
        )
```

This simply intercepts the get method and casts every floating-point tensor in the data to the correct type. Bonus points: should the need arise, it would now be easy to enable training/inference in other floating-point types, such as bfloat16 or even lower-precision formats like nf4.
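A minimal usage sketch of the wrapper, assuming `my_dataset` is an already-constructed torch_geometric dataset (the variable name is a placeholder, not part of the actual DataModule code):

```python
import torch

# Wrap an existing dataset so that every floating-point tensor returned by
# get() is cast to double precision.
wrapped = FloatCastDatasetWrapper(my_dataset, dtype=torch.float64)
sample = wrapped.get(0)  # floating-point tensors in `sample` are now float64
```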
LGTM! I trained without any problems or errors using single and double precision.
@guillemsimeon @raimis please review
The dtype parameter was causing some issues (see #205).
I realized that there is already a "precision" parameter that can be used for the same thing.
This PR removes the dtype argument and uses precision instead, which can be 16, 32 or 64.
@AntonioMirarchi could you check that the issues you were seeing are gone with this PR?
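For reference, a minimal sketch of how a precision value of 16, 32 or 64 could be translated into a torch dtype for the wrapper above (the mapping and function name are illustrative assumptions, not the actual DataModule code):

```python
import torch

# Illustrative mapping from the precision setting to a torch dtype.
PRECISION_TO_DTYPE = {16: torch.float16, 32: torch.float32, 64: torch.float64}

def wrap_for_precision(dataset, precision=32):
    # Reuses the FloatCastDatasetWrapper defined earlier in this conversation.
    return FloatCastDatasetWrapper(dataset, dtype=PRECISION_TO_DTYPE[precision])
```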