Skip to content

Commit 31182c3

Browse files
committed
fix gitignore
1 parent 3d30e86 commit 31182c3

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ run/paint/*
55
run/write/*
66
run/vae/*
77

8-
data
8+
/data
99
*.pth
1010
*.iml
1111
.idea

utils/data/FaceDataset.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torchvision
2+
import os
3+
import PIL
4+
import numpy as np
5+
import torch as t
6+
from PIL import Image
7+
from torch.utils.data import Dataset
8+
9+
class FaceDataset(Dataset):
10+
def __init__(self, faces_path):
11+
items = []
12+
labels = []
13+
for img in os.listdir(faces_path):
14+
item = os.path.join(faces_path, img)
15+
items.append(item)
16+
labels.append(img)
17+
self.items = items
18+
self.labels = labels
19+
20+
def __len__(self):
21+
return len(self.items)
22+
23+
def _get_image_(self, idx):
24+
img = self.items[idx]
25+
img = PIL.Image.open(str(img)).convert('RGB')
26+
img = torchvision.transforms.Resize([128, 128])(img)
27+
a = np.asarray(img)
28+
a = np.transpose(a, (1, 0, 2))
29+
a = np.transpose(a, (2, 1, 0))
30+
return t.from_numpy(a.astype(np.float32, copy=False)).div(255)
31+
32+
def __getitem__(self, idx):
33+
return self._get_image_(idx), self.labels[idx]
34+
35+
def get_item_by_jpg(self, jpg):
36+
return self._get_image_(self.labels.index(jpg))

0 commit comments

Comments
 (0)