File tree 2 files changed +37
-1
lines changed
2 files changed +37
-1
lines changed Original file line number Diff line number Diff line change @@ -5,7 +5,7 @@ run/paint/*
5
5
run /write /*
6
6
run /vae /*
7
7
8
- data
8
+ / data
9
9
* .pth
10
10
* .iml
11
11
.idea
Original file line number Diff line number Diff line change
1
+ import torchvision
2
+ import os
3
+ import PIL
4
+ import numpy as np
5
+ import torch as t
6
+ from PIL import Image
7
+ from torch .utils .data import Dataset
8
+
9
+ class FaceDataset (Dataset ):
10
+ def __init__ (self , faces_path ):
11
+ items = []
12
+ labels = []
13
+ for img in os .listdir (faces_path ):
14
+ item = os .path .join (faces_path , img )
15
+ items .append (item )
16
+ labels .append (img )
17
+ self .items = items
18
+ self .labels = labels
19
+
20
+ def __len__ (self ):
21
+ return len (self .items )
22
+
23
+ def _get_image_ (self , idx ):
24
+ img = self .items [idx ]
25
+ img = PIL .Image .open (str (img )).convert ('RGB' )
26
+ img = torchvision .transforms .Resize ([128 , 128 ])(img )
27
+ a = np .asarray (img )
28
+ a = np .transpose (a , (1 , 0 , 2 ))
29
+ a = np .transpose (a , (2 , 1 , 0 ))
30
+ return t .from_numpy (a .astype (np .float32 , copy = False )).div (255 )
31
+
32
+ def __getitem__ (self , idx ):
33
+ return self ._get_image_ (idx ), self .labels [idx ]
34
+
35
+ def get_item_by_jpg (self , jpg ):
36
+ return self ._get_image_ (self .labels .index (jpg ))
You can’t perform that action at this time.
0 commit comments