Batch normalization in BaseDataModule #1555
Replies: 1 comment
These are just common default values meant to be overridden by the user. Most images in ML datasets (including the majority of remote sensing datasets) are stored in uint8 (0–255), so dividing by 255 puts them in the 0–1 range expected by most models.
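To make the arithmetic concrete, here is a minimal sketch of what that default amounts to, using Kornia's `K.Normalize` directly (the image tensor is fabricated for illustration):

```python
import torch
import kornia.augmentation as K

# Fake uint8-range image batch with shape (B, C, H, W).
images = torch.randint(0, 256, (1, 3, 4, 4)).float()

# The default: (x - mean) / std with mean=0, std=255,
# which maps 0-255 pixel values into the 0-1 range.
normalize = K.Normalize(mean=torch.tensor(0.0), std=torch.tensor(255.0))
out = normalize(images)
print(out.min().item() >= 0.0, out.max().item() <= 1.0)  # True True
```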
Batch size and number of workers are user-specific instance parameters. Users can and should modify these to work best for them, like any other hyperparameter. Mean and std dev are usually dataset-specific, and users rarely tune them like other hyperparameters. They should be overridden in a subclass like so:

```python
class MySpecialDatasetDataModule(GeoDataModule):
    mean = torch.tensor([123, 321, 521, 351])
    std = torch.tensor([523, 654, 613, 634])
    ...
```

You'll see that mean/std dev can either be a single float (the same value for all channels) or an array with one value per channel. Hope this makes sense, let me know if you have any other questions!
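If you need per-channel values for your own imagery, here is a rough sketch (with fabricated data; `images` stands in for a real stack of your images) of computing them by reducing over the batch and spatial dimensions:

```python
import torch

# Hypothetical stack of 4-band images, shape (N, C, H, W).
images = torch.rand(8, 4, 64, 64) * 10_000  # fake data in a uint16-like range

# One statistic per channel: reduce over batch, height, and width.
mean = images.mean(dim=(0, 2, 3))  # shape: (4,)
std = images.std(dim=(0, 2, 3))    # shape: (4,)
```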
---
Hi,
I've been using GeoDataModule and realised that all images are divided by 255. That doesn't make much sense for my data, so I looked into the BaseDataModule class and saw that:
```python
mean = torch.tensor(0)
std = torch.tensor(255)
...
self.aug: Transform = AugmentationSequential(
    K.Normalize(mean=self.mean, std=self.std),
    data_keys=["image"],
)
```
Is there any particular reason why the mean and std are set this way? Wouldn't it be better to expose those two as class parameters (like batch size or number of workers)? Some users (like me) might use data that is not RGB, so they would need different values for normalisation.
(Note: I'm new to TorchGeo and machine learning in general, so I might be missing something obvious here!)