diff --git a/dl/image/classification/cifar10/train.py b/dl/image/classification/cifar10/train.py index 3d522d2..595fe04 100644 --- a/dl/image/classification/cifar10/train.py +++ b/dl/image/classification/cifar10/train.py @@ -1,6 +1,7 @@ from image_transformation import transformations +import torch.nn as nn import torchvision train_set = torchvision.datasets.CIFAR10(root=data_path, train=True, transform=transformations()['train'],