-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_datasets.py
48 lines (41 loc) · 1.88 KB
/
custom_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os, re
from torchvision.datasets import ImageFolder
from utils import path2age
class ImageFolderWithAges(ImageFolder):
    """Custom dataset that also returns each face image's age.

    Extends ``torchvision.datasets.ImageFolder``. Every item is the tuple
    ``ImageFolder`` would return, ``(sample, target)``, with the age appended:
    ``(sample, target, age)``. The age is extracted from the image's file path
    by ``utils.path2age(path, pat, pos)``.

    Args:
        pat: Forwarded to ``path2age`` to extract the age from a file path
            (presumably a regex pattern -- TODO confirm against ``path2age``).
        pos: Forwarded to ``path2age`` alongside ``pat`` (presumably selects
            which match/group holds the age -- TODO confirm).
        *args, **kwargs: Forwarded unchanged to ``ImageFolder.__init__``
            (e.g. ``root``, ``transform``).
    """

    def __init__(self, pat, pos, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pat = pat
        self.pos = pos

    # Override __getitem__: this is the method a DataLoader calls per index.
    def __getitem__(self, index):
        """Return ``(sample, target, age)`` for the image at ``index``."""
        # What ImageFolder normally returns: (sample, target).
        sample_and_target = super().__getitem__(index)
        # self.imgs holds ImageFolder's (image path, class_index) pairs.
        image_path = self.imgs[index][0]
        age = path2age(image_path, self.pat, self.pos)
        # Append the age (not the path) to the original tuple.
        return sample_and_target + (age,)
class ImageFolderWithAgeGroup(ImageFolder):
    """Custom dataset that also returns each face image's age-group index.

    Extends ``torchvision.datasets.ImageFolder``. Every item is the tuple
    ``ImageFolder`` would return, ``(sample, target)``, with the age-group
    index appended: ``(sample, target, group)``. The age is extracted from
    the image's file path by ``utils.path2age(path, pat, pos)`` and then
    mapped to a group by :meth:`find_group` using ``cutoffs``.

    The groups are defined entirely by ``cutoffs`` (inclusive upper bounds).
    For example, ``cutoffs=[12, 18, 25, 35, 45, 55, 65]`` yields, for integer
    ages, the eight groups 0-12, 13-18, 19-25, 26-35, 36-45, 46-55, 56-65 and
    >= 66 (indices 0 through 7).

    Args:
        pat: Forwarded to ``path2age`` to extract the age from a file path
            (presumably a regex pattern -- TODO confirm against ``path2age``).
        pos: Forwarded to ``path2age`` alongside ``pat`` (presumably selects
            which match/group holds the age -- TODO confirm).
        cutoffs: Sequence of inclusive upper bounds, one per group except the
            last. Expected to be in ascending order; see :meth:`find_group`.
        *args, **kwargs: Forwarded unchanged to ``ImageFolder.__init__``
            (e.g. ``root``, ``transform``).
    """

    def __init__(self, pat, pos, cutoffs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pat = pat
        self.pos = pos
        self.cutoffs = cutoffs

    # Override __getitem__: this is the method a DataLoader calls per index.
    def __getitem__(self, index):
        """Return ``(sample, target, group)`` for the image at ``index``."""
        # What ImageFolder normally returns: (sample, target).
        sample_and_target = super().__getitem__(index)
        # self.imgs holds ImageFolder's (image path, class_index) pairs.
        image_path = self.imgs[index][0]
        age = path2age(image_path, self.pat, self.pos)
        # Append the age-group index (not the path) to the original tuple.
        return sample_and_target + (self.find_group(age),)

    def find_group(self, age):
        """Return the index of the age group that ``age`` falls into.

        The result is the index of the first cutoff with ``age <= cutoff``.
        If ``age`` exceeds every cutoff, the result is ``len(self.cutoffs)``
        (the open-ended last group). The result is therefore in
        ``range(len(self.cutoffs) + 1)``.

        Note: this only yields contiguous, ordered groups when
        ``self.cutoffs`` is sorted ascending; that ordering is not checked.
        """
        for group_index, upper_bound in enumerate(self.cutoffs):
            if age <= upper_bound:
                return group_index
        # Older than every cutoff: falls in the final, open-ended group.
        return len(self.cutoffs)