from dalle_pytorch.loader import TextImageDataset
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer

+ # libraries needed for webdataset support
+ import webdataset as wds
+ from torchvision import transforms as T
+ from PIL import Image
+ from io import BytesIO
+
+
# argument parsing

parser = argparse.ArgumentParser()

parser.add_argument('--image_text_folder', type=str, required=True,
                    help='path to your folder of images and text for learning the DALL-E')

+ parser.add_argument(
+     '--wds',
+     type=str,
+     default='',
+     help='Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.'
+ )
+
parser.add_argument('--truncate_captions', dest='truncate_captions', action='store_true',
                    help='Captions passed in which exceed the max token length will be truncated if this is set.')

model_group.add_argument('--dim', default=512, type=int, help='Model dimension')

- model_group.add_argument('--text_seq_len', default=256, type=int, help='Text sequence length')
+ model_group.add_argument('--text_seq_len', default=128, type=int, help='Text sequence length')

model_group.add_argument('--depth', default=2, type=int, help='Model depth')

- model_group.add_argument('--heads', default=8, type=int, help='Model number of heads')
+ model_group.add_argument('--heads', default=4, type=int, help='Model number of heads')

- model_group.add_argument('--dim_head', default=64, type=int, help='Model head dimension')
+ model_group.add_argument('--dim_head', default=16, type=int, help='Model head dimension')

train_group.add_argument('--ff_dropout', default=0.0, type=float, help='Feed forward dropout.')

args = parser.parse_args()

- # quit early if you used the wrong folder name
-
- assert Path(args.image_text_folder).exists(), f'The path {args.image_text_folder} was not found.'
-

# helpers

def exists(val):
@@ -137,6 +147,8 @@ def cp_path_to_dir(cp_path, tag):
    return cp_dir

# constants
+ WEBDATASET_IMAGE_TEXT_COLUMNS = tuple(args.wds.split(','))
+ ENABLE_WEBDATASET = True if len(WEBDATASET_IMAGE_TEXT_COLUMNS) == 2 else False
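The two constants above are what turn the new --wds flag into WebDataset column keys. A minimal sketch of that parsing, using the img,cap names from the flag's help text (with the default --wds of '', the split yields a single empty string, so WebDataset mode stays off):

    columns = tuple('img,cap'.split(','))   # -> ('img', 'cap')
    enable_webdataset = len(columns) == 2   # any other count falls back to the folder-based TextImageDataset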

DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name + ".pt"

@@ -169,6 +181,27 @@ def cp_path_to_dir(cp_path, tag):

DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

+ if not ENABLE_WEBDATASET:
+     # quit early if you used the wrong folder name
+     assert Path(args.image_text_folder).exists(), f'The path {args.image_text_folder} was not found.'
+ else:
+     # quit early if no tar files were found
+     if Path(args.image_text_folder).is_dir():
+         DATASET = [str(p) for p in Path(args.image_text_folder).glob("**/*") if ".tar" in str(p).lower()]  # .name
+         assert len(DATASET) > 0, 'The directory ({}) does not contain any WebDataset/.tar files.'.format(args.image_text_folder)
+         print('Found {} WebDataset .tar(.gz) file(s) under given path {}!'.format(len(DATASET), args.image_text_folder))
+     elif ('http://' in args.image_text_folder.lower()) | ('https://' in args.image_text_folder.lower()):
+         DATASET = f"pipe:curl -L -s {args.image_text_folder} || true"
+         print('Found http(s) link under given path: {}'.format(args.image_text_folder))
+     elif 'gs://' in args.image_text_folder.lower():
+         DATASET = f"pipe:gsutil cat {args.image_text_folder} || true"
+         print('Found GCS link under given path: {}'.format(args.image_text_folder))
+     elif '.tar' in args.image_text_folder:
+         DATASET = args.image_text_folder
+         print('Found WebDataset .tar(.gz) file under given path {}!'.format(args.image_text_folder))
+     else:
+         raise Exception('No folder, .tar(.gz) file, or URL pointing to tar files was provided under {}.'.format(args.image_text_folder))
+
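In other words, DATASET ends up in one of a few shapes that wds.WebDataset can consume directly. A rough sketch with hypothetical paths and URLs (exact streaming behaviour depends on the installed webdataset version):

    DATASET = ['data/shard-000.tar', 'data/shard-001.tar']              # directory of local shards
    DATASET = 'pipe:curl -L -s https://example.com/shards.tar || true'  # streamed over http(s)
    DATASET = 'pipe:gsutil cat gs://some-bucket/shards.tar || true'     # streamed from GCS
    DATASET = 'path/to/shards.tar'                                      # a single local tar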
# initialize distributed backend

distr_backend = distributed_utils.set_backend_from_args(args)
@@ -283,19 +316,61 @@ def group_weight(model):

is_shuffle = not distributed_utils.using_backend(distributed_utils.HorovodBackend)

- ds = TextImageDataset(
-     args.image_text_folder,
-     text_len=TEXT_SEQ_LEN,
-     image_size=IMAGE_SIZE,
-     resize_ratio=args.resize_ratio,
-     truncate_captions=args.truncate_captions,
-     tokenizer=tokenizer,
-     shuffle=is_shuffle,
- )
+ imagepreproc = T.Compose([
+     T.Lambda(lambda img: img.convert('RGB')
+              if img.mode != 'RGB' else img),
+     T.RandomResizedCrop(IMAGE_SIZE,
+                         scale=(args.resize_ratio, 1.),
+                         ratio=(1., 1.)),
+     T.ToTensor(),
+ ])
+
+ def imagetransform(b):
+     return Image.open(BytesIO(b))
+
+ def tokenize(s):
+     return tokenizer.tokenize(
+         s.decode('utf-8'),
+         TEXT_SEQ_LEN,
+         truncate_text=args.truncate_captions).squeeze(0)
+
+ if ENABLE_WEBDATASET:
+     DATASET_SIZE = int(1e9)  # You need to set a nominal length for the Dataset in order to avoid warnings from DataLoader
+
+     myimg, mycap = WEBDATASET_IMAGE_TEXT_COLUMNS
+     image_text_mapping = {
+         myimg: imagetransform,
+         mycap: tokenize
+     }
+     image_mapping = {
+         myimg: imagepreproc
+     }
+
+     num_batches = DATASET_SIZE // BATCH_SIZE
+
+     ds = (
+         wds.WebDataset(DATASET, length=num_batches)
+         # .shuffle(is_shuffle) # Commented out for WebDataset as the behaviour cannot be predicted yet
+         .map_dict(**image_text_mapping)
+         .map_dict(**image_mapping)
+         .to_tuple(mycap, myimg)
+         .batched(BATCH_SIZE, partial=False)  # It is good to avoid partial batches when using Distributed training
+     )
+ else:
+     ds = TextImageDataset(
+         args.image_text_folder,
+         text_len=TEXT_SEQ_LEN,
+         image_size=IMAGE_SIZE,
+         resize_ratio=args.resize_ratio,
+         truncate_captions=args.truncate_captions,
+         tokenizer=tokenizer,
+         shuffle=is_shuffle,
+     )

assert len(ds) > 0, 'dataset is empty'
if distr_backend.is_root_worker():
-     print(f'{len(ds)} image-text pairs found for training')
+     if not ENABLE_WEBDATASET:
+         print(f'{len(ds)} image-text pairs found for training')

if not is_shuffle:
    data_sampler = torch.utils.data.distributed.DistributedSampler(
@@ -306,10 +381,18 @@ def group_weight(model):
else:
    data_sampler = None

- dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)
+ if ENABLE_WEBDATASET:
+     # WebLoader for WebDataset and DeepSpeed compatibility
+     dl = wds.WebLoader(ds, batch_size=None, shuffle=False)  # optionally add num_workers=2 (n) argument
+     number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend.get_world_size())
+     dl = dl.repeat(2).slice(number_of_batches)
+     dl.length = number_of_batches
+ else:
+     # Regular DataLoader for image-text-folder datasets
+     dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)
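Since a streaming WebDataset has no real length, the loader in the branch above fakes one: the nominal DATASET_SIZE is divided across ranks, the stream is repeated so it cannot run dry before the slice is filled, and .slice() caps every rank at the same number of batches. A small sketch of that arithmetic with assumed values (BATCH_SIZE and the world size here are illustrative, not taken from this diff):

    DATASET_SIZE = int(1e9)   # nominal length, as set above
    BATCH_SIZE = 16           # assumed for illustration
    world_size = 8            # assumed number of distributed workers
    number_of_batches = DATASET_SIZE // (BATCH_SIZE * world_size)
    print(number_of_batches)  # every rank steps through exactly this many batches per epoch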

- # initialize DALL-E

+ # initialize DALL-E

dalle = DALLE(vae=vae, **dalle_params)

if not using_deepspeed:
@@ -454,13 +537,13 @@ def save_model(path, epoch=0):

# training

- # Saves a checkpoint before training begins to fail early when mis-configured.
+ # Saves a checkpoint before training begins to fail early when mis-configured.
# See https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
save_model(DALLE_OUTPUT_FILE_NAME, epoch=resume_epoch)
for epoch in range(resume_epoch, EPOCHS):
    if data_sampler:
        data_sampler.set_epoch(epoch)
-     for i, (text, images) in enumerate(distr_dl):
+     for i, (text, images) in enumerate((dl if ENABLE_WEBDATASET else distr_dl)):
        if i % 10 == 0 and distr_backend.is_root_worker():
            t = time.time()
        if args.fp16: