-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdataset.py
40 lines (32 loc) · 1.26 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import cv2
import os
from torch.utils.data import Dataset
class FaceDataset(Dataset):
    """Dataset yielding an RGB image and, optionally, its mask.

    Images are read with OpenCV from ``images_dir``. When ``target_dir`` is
    given, a mask is read from it for every image, using the image file name
    with ``'.jpg'`` replaced by ``'.png'``.

    Args:
        images_dir: Directory containing the image files.
        images_name: Sequence of image file names relative to ``images_dir``.
        target_dir: Optional directory containing the mask files. When falsy,
            no mask is loaded and an empty list is returned in its place.
        transforms: Optional callable invoked as ``transforms(image=...)`` or
            ``transforms(image=..., mask=...)``; its result is indexed with
            ``'image'`` (and ``'mask'``). Presumably an albumentations-style
            pipeline -- confirm against the caller.
    """

    def __init__(self, images_dir, images_name, target_dir=None,
                 transforms=None):
        self.images_dir = images_dir
        self.target_dir = target_dir
        self.images_name = images_name
        self.transforms = transforms
        print('{} images'.format(len(self.images_name)))

    def __len__(self):
        """Return the number of image file names in the dataset."""
        return len(self.images_name)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A tuple ``(img, mask)``. ``img`` is the image converted from
            OpenCV's BGR order to RGB (after ``transforms``, if any).
            ``mask`` is the mask read with flag ``0`` (after ``transforms``,
            if any), or ``[]`` when no ``target_dir`` was configured.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        name = self.images_name[idx]
        img_filename = os.path.join(self.images_dir, name)
        img = cv2.imread(img_filename)
        if img is None:
            # cv2.imread signals failure by returning None; fail here with
            # the offending path instead of an opaque cvtColor assertion.
            raise FileNotFoundError(
                'Could not read image: {}'.format(img_filename))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Decide once whether this sample carries a mask. Comparing the
        # loaded array against [] (as in `mask != []`) is not a valid
        # emptiness test for NumPy arrays: the shapes cannot broadcast.
        has_mask = bool(self.target_dir)
        if has_mask:
            mask_filename = os.path.join(
                self.target_dir, name.replace('.jpg', '.png'))
            # NOTE(review): a missing mask file yields None here and is
            # passed through unchanged -- confirm whether that should raise.
            mask = cv2.imread(mask_filename, 0)
        else:
            # Placeholder kept for backward compatibility with callers that
            # expect a two-element tuple even without masks.
            mask = []

        if self.transforms:
            if has_mask:
                augmented = self.transforms(image=img, mask=mask)
                img = augmented['image']
                mask = augmented['mask']
            else:
                augmented = self.transforms(image=img)
                img = augmented['image']
        return img, mask