-
Notifications
You must be signed in to change notification settings - Fork 4
/
celeb_race.py
49 lines (34 loc) · 1.38 KB
/
celeb_race.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import sys
import os
import torch
import torchvision
import numpy as np
from torchvision.datasets import CelebA
from torch.utils.data import Subset
# Per-image race scores, loaded from ~/post_hoc_debiasing/celebrace at import time.
# CelebRace.__getitem__ looks entries up by (image number - 1) and compares them
# against .501, so presumably each array holds one score per CelebA image in
# image-number order -- TODO confirm against the producer of these files.
white, black, asian = (
    np.load(os.path.expanduser('~/post_hoc_debiasing/celebrace/white_full.npy')),
    np.load(os.path.expanduser('~/post_hoc_debiasing/celebrace/black_full.npy')),
    np.load(os.path.expanduser('~/post_hoc_debiasing/celebrace/asian_full.npy')),
)
class CelebRace(CelebA):
    """CelebA dataset whose target tensor is extended with race-derived labels.

    Each item is ``(X, target)``. ``target`` is the parent class's target
    tensor with five extra ``long`` entries appended, in this order:

    ``[white > .501, black > .501, asian > .501, image_number, 1 - target[20]]``

    ``image_number`` is the integer stem of the image's filename. The
    module-level ``white`` / ``black`` / ``asian`` arrays are looked up at
    ``image_number - 1``.
    """

    def __getitem__(self, index):
        image, attrs = super().__getitem__(index)

        # Parse the filename stem as an integer (e.g. "000123.jpg" -> 123);
        # the race arrays are read one position lower than that number.
        stem = self.filename[index].split('.')[0]
        image_number = int(stem)
        row = image_number - 1

        # NOTE(review): attrs[20] is presumably CelebA's 'Male' attribute
        # (target_type='attr'), making the last entry a "not male" flag -- confirm.
        extra = torch.tensor(
            [
                white[row] > .501,
                black[row] > .501,
                asian[row] > .501,
                image_number,
                1 - attrs[20],
            ],
            dtype=torch.long,
        )
        return image, torch.cat((attrs, extra))
# Location of each CelebA split inside the full race-score arrays, as
# (offset, length). The official partition orders images as
# train = 000001-162770, valid = 162771-182637, test = 182638-202599.
# The race arrays are indexed by (image number - 1), as in
# CelebRace.__getitem__, so split-local index i maps to array row offset + i.
_SPLIT_BOUNDS = {
    'train': (0, 162770),
    'valid': (162770, 19867),
    'test': (182637, 19962),
    'all': (0, 202599),
}


def _split_bounds(split):
    """Return ``(offset, length)`` of *split* within the full race arrays.

    Unrecognised split names fall back to the test split. This matches the
    previous ``else`` branch, which used the test-split length for anything
    that was not 'train'.
    """
    return _SPLIT_BOUNDS.get(split, _SPLIT_BOUNDS['test'])


def _confident_indices(split, thresh, score_arrays):
    """Return split-local indices whose score in ANY of *score_arrays* exceeds *thresh*.

    NOTE(review): assumes each array is 1-D with one entry per CelebA image
    (the "_full" files); an array too short for the split raises.
    """
    offset, length = _split_bounds(split)
    rows = slice(offset, offset + length)
    hits = np.zeros(length, dtype=bool)
    for scores in score_arrays:
        # Shift by the split offset so each split-local index reads its own
        # image's score rather than the score of train image i.
        hits |= np.asarray(scores[rows]) > thresh
    # tolist() gives plain Python ints, the same element type as the former
    # list comprehension.
    return np.flatnonzero(hits).tolist()


def unambiguous(dataset, split='train', thresh=.7):
    """Restrict *dataset* to images with a confidently predicted race.

    An image is kept when its white, black or asian score is strictly greater
    than *thresh*.

    Args:
        dataset: Indexable dataset for *split*, in CelebA's partition order.
        split: 'train', 'valid', 'test' or 'all'; any other value is treated
            as 'test'.
        thresh: Score threshold; the comparison is strict (``>``).

    Returns:
        ``torch.utils.data.Subset`` of *dataset* holding the kept indices.
    """
    return Subset(dataset, _confident_indices(split, thresh, (white, black, asian)))


def split_check(dataset, split='train', thresh=.7):
    """Restrict *dataset* to images whose asian score is strictly above *thresh*.

    Takes the same arguments as ``unambiguous`` but only considers the
    ``asian`` scores.

    Returns:
        ``torch.utils.data.Subset`` of *dataset* holding the kept indices.
    """
    return Subset(dataset, _confident_indices(split, thresh, (asian,)))