-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlaion_dataset.py
69 lines (50 loc) · 2.28 KB
/
laion_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import json
import cv2
import numpy as np
from annotator.util import resize_image, HWC3
from torch.utils.data import Dataset
import os
ROOT = "./"
DATA = "./data/laion-art/"
class LAIONDataset(Dataset):
def __init__(self, canny_low=100, canny_high=200):
self.canny_low = canny_low
self.canny_high = canny_high
self.data = []
# only first 30 directories seem to be correct, others might be corrupted (TODO redownload data)
directories = [f"000{i}" if i > 9 else f"0000{i}" for i in range(1)]
for d in directories:
print(f"Parsing directory {d}...")
directory_path = os.path.join(ROOT, DATA, d)
for filename in os.listdir(directory_path):
name = os.path.splitext(filename)[0]
jpg = f"{name}.jpg"
txt = f"{name}.txt"
jpg_path = os.path.join(directory_path, jpg)
txt_path = os.path.join(directory_path, txt)
if os.path.exists(jpg_path) and os.path.exists(txt_path):
with open(txt_path, 'r') as f:
prompt = f.read()
self.data.append({'target': os.path.join(d, jpg), 'prompt': prompt})
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
# source_filename = item['source'] # TODO pregenerate source (hint) data
target_filename = item['target']
prompt = item['prompt']
#source = cv2.imread(os.path.join(ROOT, DATA, source_filename))
target = cv2.imread(os.path.join(ROOT, DATA, target_filename))
# Do not forget that OpenCV read images in BGR order.
#source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
# Resize source and target
#source = resize_image(source, 256)
target = resize_image(target, 256)
# Create source (hint) image
source = HWC3(cv2.Canny(target, self.canny_low, self.canny_high))
# Normalize source images to [0, 1].
source = source.astype(np.float32) / 255.0
# Normalize target images to [-1, 1].
target = (target.astype(np.float32) / 127.5) - 1.0
return dict(jpg=target, txt=prompt, hint=source)