`lightning_toolbox.DataModule` aims to capture the generic task of setting up DataModules in PyTorch Lightning (which are a way of decoupling data from the model implementation, see the docs). It can be very useful when your datasets satisfy the following assumption:
Creating a `torch.utils.data.DataLoader` from your dataset is as simple as calling `torch.utils.data.DataLoader(dataset, batch_size=...)`, which is the case for most datasets in deep learning.
In this case, you can use `lightning_toolbox.DataModule` to create a `LightningDataModule` from your dataset(s). You can either use it directly to turn your dataset(s) into a datamodule, or create your own datamodule by inheriting from it.
You can use `lightning_toolbox.DataModule` either by instantiating it or by extending its base functionality through class inheritance. Our `DataModule`'s functionality is twofold:
- Dealing with the data: the data can be provided to the datamodule in either of the following two settings:
  - Classes (or string import-paths of) datasets to be instantiated, along with their instantiation arguments.
  - Dataset instances to be used directly.

  The instantiation process takes place in the `setup` method, ensuring the correct behaviour in multi-device training scenarios. Depending on the provided configuration, the datamodule might split a given `dataset` into training and validation splits based on a `val_size` argument.
- Creating the dataloaders: the `train_dataloader`, `val_dataloader`, and `test_dataloader` methods are implemented to return dataloaders based on the `train_data`, `val_data`, and `test_data` attributes. The configuration of each dataloader can be modified individually through the `DataModule` constructor.
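For instance, a minimal sketch of the second setting, handing an already-instantiated dataset (a toy `TensorDataset` here) to the datamodule; the constructor arguments used are the ones described below:

```python
import torch
from torch.utils.data import TensorDataset

from lightning_toolbox import DataModule

# a toy dataset instance (a stand-in for any torch.utils.data.Dataset)
dataset = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))

# hand the instance over directly; it is split into training and
# validation splits inside `setup` according to `val_size`
dm = DataModule(dataset=dataset, val_size=0.1, batch_size=32)
```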
You can specify your datasets by passing them as arguments to the constructor, or by setting them as attributes of the class. There are four sets of dataset-specific arguments in the constructor:
- `dataset` and `dataset_args`: the dataset class and its arguments, which are used to instantiate the training and validation datasets based on `val_size`. `val_size` is either an int specifying how many samples of the dataset should be used as the validation split, or a float between 0.0 and 1.0 specifying the portion of the data to be used for validation. If you want to use a different dataset for validation, you can specify it through the `val_dataset` and `val_dataset_args` arguments. `dataset` should be either a dataset instance or a dataset class/import-path; if a dataset instance is provided as `dataset`, then `dataset_args` is ignored.
- `train_dataset` and `train_dataset_args`: similar to `dataset` and `dataset_args`, but only used for the training split (overrides `dataset` and `dataset_args`).
- `val_dataset` and `val_dataset_args`: similar to `train_dataset` and `train_dataset_args`, but only used for the validation split.
- `test_dataset` and `test_dataset_args`: similar to `train_dataset` and `train_dataset_args`, but only used for the test split, with the exception that they do not override the `dataset*` arguments. This means that if you want a datamodule with a test-split dataset and dataloader, you have to explicitly provide `test_dataset` (and, if needed, `test_dataset_args`).
As a clarification, the `dataset*` arguments are only used for the training and validation splits. If you want a test split, you have to specify it explicitly through `test_dataset`.
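For example, a sketch of a datamodule that does include a test split (using the standard torchvision MNIST arguments, where `train=False` selects the test set):

```python
from lightning_toolbox import DataModule

dm = DataModule(
    dataset='torchvision.datasets.MNIST',
    dataset_args={'root': 'data/', 'train': True, 'download': True},
    val_size=0.1,
    # the test split has to be provided explicitly;
    # `dataset`/`dataset_args` are never reused for it
    test_dataset='torchvision.datasets.MNIST',
    test_dataset_args={'root': 'data/', 'train': False, 'download': True},
)
```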
You can tell `DataModule` to import your intended dataset classes by passing the import path as a string to any of the `*dataset` arguments. For example, if you want to use the `MNIST` dataset from `torchvision.datasets`, you can pass `dataset='torchvision.datasets.MNIST'` to the constructor. The `*dataset_args` arguments should be a dictionary of arguments to be passed to the dataset class. For example, if you want to use the `MNIST` dataset with `root` set to `data/`, you can pass `dataset_args={'root': 'data/'}` to the constructor.
As mentioned before, the `val_size` argument can be either an int specifying how many samples of the dataset should be used as the validation split, or a float between 0.0 and 1.0 specifying the portion of the data to be used for validation. The `seed` argument is used for the random split.
- If you want to use a different dataset for validation, you can specify it through the `val_dataset` and `val_dataset_args` arguments, where `val_dataset` should be either a dataset instance or a dataset class/import-path. If a dataset instance is provided as `val_dataset`, then `val_dataset_args` is ignored.
- If you don't want a validation split, you can set `val_size=0` and `val_batch_size=0` (a sketch of this follows the list).
- When using `val_dataset`, `val_size` is automatically regarded as zero, meaning that if a `dataset` is also provided, it is used entirely as the training split.
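A minimal sketch of the no-validation configuration mentioned above:

```python
from lightning_toolbox import DataModule

dm = DataModule(
    dataset='torchvision.datasets.MNIST',
    dataset_args={'root': 'data/', 'download': True},
    val_size=0,        # keep the whole dataset as the training split
    val_batch_size=0,  # do not build a validation dataloader
)
```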
The following example creates a datamodule for the `MNIST` dataset and uses 20% of it as its validation split.

```python
from lightning_toolbox import DataModule

dm = DataModule(
    dataset='torchvision.datasets.MNIST',  # you could have also specified the class itself (torchvision.datasets.MNIST)
    dataset_args={'root': 'data/', 'download': True},  # all arguments to be passed to the dataset class constructor
    val_size=0.2,  # 20% of the data will be used for validation; val_size=200 would have used 200 samples for validation
    seed=42,  # the seed to be used for the random split
)  # this datamodule does not have a test split
```
The following example creates a datamodule for the `MNIST` dataset and uses a different dataset for validation.

```python
from lightning_toolbox import DataModule

dm = DataModule(
    dataset='torchvision.datasets.MNIST',
    dataset_args={'root': 'data/', 'download': True},
    val_dataset='torchvision.datasets.FashionMNIST',
    val_dataset_args={'root': 'data/', 'download': True},
)  # val_size is automatically regarded as zero and the dataset is considered completely as the `train_dataset`
```
Similar to the dataset arguments, there are four arguments that control the default behaviour of the dataloaders: (1) `batch_size`, (2) `pin_memory`, (3) `shuffle`, and (4) `num_workers`. You can either set them directly or use a `train_`/`val_`/`test_` prefix to modify an individual dataloader's behaviour. There are a couple of notes to point out:
- `batch_size` is set to 16 by default.
- `pin_memory` is set to `True` by default.
- `train_shuffle`, `val_shuffle`, and `test_shuffle` are set to `True`, `False`, and `False` by default, respectively. Given the sensitivity of shuffling in training, there is no base `shuffle` argument; if you want to change the shuffling configuration of any of the dataloaders, you have to specify it individually.
- `num_workers` is set to zero by default. This means that data loading runs on the same process/thread by default, which could lead to performance bottlenecks. Make sure you set these arguments to suit your hardware and data (these articles could be helpful: 1, 2, and 3).
- Needless to say, if a data split is not present in the datamodule (e.g. `test_dataset` is not provided), the corresponding dataloader will not be available and its arguments are ignored.
The following datamodule uses a batch size of 32 for the training split and 64 for the validation split, and uses 4 workers for the training split and 8 workers for the validation split.

```python
from lightning_toolbox import DataModule

dm = DataModule(
    ...,  # dataset descriptions
    batch_size=32,      # this will be used for every split (unless overridden)
    val_batch_size=64,  # this will override batch_size for the validation split
    num_workers=4,      # this will be used for all the splits (unless overridden)
    val_num_workers=8,  # this will override num_workers for the validation split
)
```
Transformations are functions applied to your datapoints before they are fed to the model. You can specify transformations for each of the splits by passing the `train_transforms`, `val_transforms`, and `test_transforms` arguments. These arguments should be a list of transformations to be applied to the data. You can also pass a `transforms` argument, which is used for all the splits. If you pass both `transforms` and `train_transforms`, for example, `train_transforms` is used for the training split and `transforms` is used for the validation and test splits.
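For instance, a minimal sketch of mixing a shared `transforms` with a split-specific override (MNIST and the flip augmentation are just illustrative choices):

```python
from lightning_toolbox import DataModule
from torchvision import transforms as T

dm = DataModule(
    dataset='torchvision.datasets.MNIST',
    dataset_args={'root': 'data/', 'download': True},
    val_size=0.1,
    # applied to the validation and test splits
    transforms=[T.ToTensor()],
    # overrides `transforms` for the training split only
    train_transforms=[T.RandomHorizontalFlip(), T.ToTensor()],
)
```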
The following are possible ways of specifying a transformation object:
- A callable function or a function descriptor, which could be a path to a function or a piece of code that evaluates to a callable value (e.g. `torchvision.transforms.functional_tensor` or `torchvision.transforms.ToTensor()`), or directly a function, whether as a string or as Python code. To learn more about possible ways of specifying a custom function, check out dypy's function specification documentation.

  For instance, the following are all valid ways of specifying a transformation:

  ```python
  from lightning_toolbox import DataModule
  from torchvision.transforms import ToTensor

  DataModule(
      ...,  # dataset descriptions
      # a callable value
      transforms=lambda x: x,
      # ----
      # a string definition of an anonymous lambda function
      transforms='lambda x: x',
      # ----
      # path to a callable value
      transforms='torchvision.transforms.functional_tensor.to_tensor',
      # ----
      # piece of code that evaluates to a callable value
      transforms='torchvision.transforms.ToTensor()',
      # ----
      transforms=dict(
          # an actual piece of code that evaluates to a module with
          # callable functions (from which we get the function of interest)
          code="""
  import torchvision.transforms as T

  def factor(x):
      return 1

  def transform(x):
      return T.ToTensor()(x) * factor(x)
  """,
          function_of_interest='transform'
      )
  )
  ```
- A transformation class (and its arguments), which could be a string with the path to a class or the class variable itself (e.g. `"torchvision.transforms.ToTensor"` or `torchvision.transforms.RandomCrop`), or directly a class, whether as a string or as Python code.

  For instance, the following are all valid ways of specifying a transformation class:

  ```python
  import torchvision

  from lightning_toolbox import DataModule

  DataModule(
      ...,  # dataset descriptions
      # a class variable
      transforms=torchvision.transforms.ToTensor,
      # ----
      # a string definition of a class variable
      # when a class is provided as a string, it is assumed that the class is in the dypy context
      # and that it needs no arguments to be instantiated
      transforms='torchvision.transforms.ToTensor',
      # ----
      # a class path and its arguments
      transforms=dict(
          class_path='torchvision.transforms.RandomCrop',
          init_args=dict(size=32)
      ),
  )
  ```
It might be the case that you want to apply transformations to parts of your data separately (like when dealing with (input, output) pairs). `lightning_toolbox` provides utility transformations that can be used to handle such cases. These transformations are:

- `lightning_toolbox.data.transforms.TupleDataTransform`: this transformation class takes a set of transformations for each element of the tuple data. Everything is similar to the way you would specify transformations for the datamodule, except that you specify the transformations for each element separately. For instance, the following is a valid way of specifying a `TupleDataTransform` for a vision task of predicting the parity of the input digit image:

  ```python
  from lightning_toolbox import DataModule
  from lightning_toolbox.data.transforms import TupleDataTransform

  DataModule(
      dataset='torchvision.datasets.MNIST',
      transforms=TupleDataTransform(
          transforms=['torchvision.transforms.ToTensor()', 'lambda x: x % 2']
      )
  )

  # or equivalently (using class_path and init_args)
  DataModule(
      dataset='torchvision.datasets.MNIST',
      transforms=dict(
          class_path='lightning_toolbox.data.transforms.TupleDataTransform',
          init_args={
              0: 'torchvision.transforms.ToTensor()',
              1: 'lambda x: x % 2'
          }
      )
  )
  ```
You may wish to extend the base functionality of `lightning_toolbox.DataModule`, perhaps to only use the dataloader configurations, or to provide your own `prepare_data` method. The only thing to keep in mind is that, if you want to use the dataloader functionality, your datamodule should have `train_data`, `val_data`, and `test_data` specified by the time Lightning wants to call the dataloader methods.
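A minimal sketch of such a subclass, assuming torchvision's MNIST as a stand-in dataset and that the constructor can be called with dataloader-related arguments alone:

```python
from lightning_toolbox import DataModule
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor


class MyDataModule(DataModule):
    def prepare_data(self):
        # download once (runs on a single process)
        MNIST('data/', train=True, download=True)
        MNIST('data/', train=False, download=True)

    def setup(self, stage=None):
        # populate the attributes the inherited dataloader methods rely on
        self.train_data = MNIST('data/', train=True, transform=ToTensor())
        self.val_data = MNIST('data/', train=False, transform=ToTensor())
        self.test_data = MNIST('data/', train=False, transform=ToTensor())


# assumption: the constructor accepts dataloader arguments without a dataset,
# since the data itself is provided in `setup` above
dm = MyDataModule(batch_size=64, num_workers=2)
```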
If you wish to use the dataset lookup functionality, the lookup process is done through `dypy.eval` and is made available through the static method `get_dataset`. Higher-level functionality such as dataset splitting and dataloader configuration is also available to subclasses: if you wish to set up your own dataset and then split it into training and validation splits, call `self.split_dataset()`, which by default takes place inside the fit stage of `setup`. You just need to set `self.train_data`, `self.val_data`, and `self.dataset` accordingly.
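A minimal sketch of this splitting pattern, again with MNIST as a hypothetical stand-in:

```python
from lightning_toolbox import DataModule
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor


class MySplitDataModule(DataModule):
    def setup(self, stage=None):
        if stage in (None, 'fit'):
            # set up the full dataset yourself ...
            self.dataset = MNIST('data/', train=True, download=True, transform=ToTensor())
            # ... and let the datamodule carve out the training/validation splits,
            # presumably honouring the `val_size`/`seed` configuration
            self.split_dataset()


# assumption: the constructor accepts the split configuration without a dataset
dm = MySplitDataModule(val_size=0.2, seed=42, batch_size=32)
```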