Skip to content

Commit 2eceb84

Browse files
authored
Merge pull request #280 from robvanvolt/loader-for-webdataset-included
Added support for webdataset
2 parents d6107cc + 122bc51 commit 2eceb84

File tree

4 files changed

+146
-24
lines changed

4 files changed

+146
-24
lines changed

Diff for: .gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ outputs/
33
*.pt
44
taming/
55
wandb/
6+
dalle-ds-cp/
67

78
# Byte-compiled / optimized / DLL files
89
__pycache__/
@@ -90,6 +91,9 @@ ipython_config.py
9091
# pyenv
9192
.python-version
9293

94+
# Visual Studio Code
95+
.vscode
96+
9397
# pipenv
9498
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
9599
# However, in case of collaboration, if having platform-specific dependencies or dependencies

Diff for: README.md

+35-1
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,9 @@ Weights and Biases will allow you to monitor the temperature annealing, image re
332332

333333
Once you have trained a decent VAE to your satisfaction, you can move on to the next step with your model weights at `./vae.pt`.
334334

335-
### DALL-E
335+
### DALL-E Training
336+
337+
## Training using an Image-Text-Folder
336338

337339
Now you just have to invoke the `./train_dalle.py` script, indicating which VAE model you would like to use, as well as the path to your folder of images and text.
338340

@@ -370,6 +372,38 @@ You likely will not finish DALL-E training as quickly as you did your Discrete V
370372
$ python train_dalle.py --dalle_path ./dalle.pt --image_text_folder /path/to/data
371373
```
372374

375+
## Training using WebDataset
376+
377+
WebDataset files are regular .tar(.gz) files which can be streamed and used for DALLE-pytorch training.
378+
You just need to provide the image (first comma-separated argument) and caption (second comma-separated argument)
379+
column key after the --wds argument. The --image_text_folder argument points to your .tar(.gz) file instead of the data folder.
380+
381+
```python
382+
$ python train_dalle.py --wds img,cap --image_text_folder /path/to/data.tar(.gz)
383+
```
384+
385+
Distributed training with deepspeed works the same way, e.g.:
386+
387+
```python
388+
$ deepspeed train_dalle.py --wds img,cap --image_text_folder /path/to/data.tar(.gz) --fp16 --deepspeed
389+
```
390+
391+
If you have a folder containing shards (a dataset split into several .tar(.gz) files), this is also supported:
392+
393+
```python
394+
$ deepspeed train_dalle.py --wds img,cap --image_text_folder /path/to/shardfolder --fp16 --deepspeed
395+
```
396+
397+
You can stream the data from an HTTP server or Google Cloud Storage like this:
398+
399+
```python
400+
$ deepspeed train_dalle.py --image_text_folder "http://storage.googleapis.com/nvdata-openimages/openimages-train-{000000..000554}.tar" --wds jpg,json --taming --truncate_captions --random_resize_crop_lower_ratio=0.8 --attn_types=full --epochs=2 --fp16 --deepspeed
401+
```
402+
403+
In order to convert your image-text-folder to WebDataset format, you can make use of one of several methods.
404+
(four examples are given in https://www.youtube.com/watch?v=v_PacO-3OGQ, or you can use a little helper script which also supports splitting your dataset
405+
into shards of .tar.gz files: https://github.com/robvanvolt/DALLE-datasets/blob/main/wds_create_shards.py)
406+
373407
### DALL-E with OpenAI's VAE
374408

375409
You can now also train DALL-E without having to train the Discrete VAE at all, courtesy of their open-sourcing their model. You simply have to invoke the `train_dalle.py` script without specifying the `--vae_path`

Diff for: setup.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'dalle-pytorch',
55
packages = find_packages(),
66
include_package_data = True,
7-
version = '0.12.5',
7+
version = '0.13.0',
88
license='MIT',
99
description = 'DALL-E - Pytorch',
1010
author = 'Phil Wang',
@@ -30,7 +30,8 @@
3030
'torchvision',
3131
'transformers',
3232
'tqdm',
33-
'youtokentome'
33+
'youtokentome',
34+
'WebDataset'
3435
],
3536
classifiers=[
3637
'Development Status :: 4 - Beta',

Diff for: train_dalle.py

+104-21
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
from dalle_pytorch.loader import TextImageDataset
1818
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer
1919

20+
# libraries needed for webdataset support
21+
import webdataset as wds
22+
from torchvision import transforms as T
23+
from PIL import Image
24+
from io import BytesIO
25+
26+
2027
# argument parsing
2128

2229
parser = argparse.ArgumentParser()
@@ -38,6 +45,13 @@
3845
parser.add_argument('--image_text_folder', type=str, required=True,
3946
help='path to your folder of images and text for learning the DALL-E')
4047

48+
parser.add_argument(
49+
'--wds',
50+
type = str,
51+
default='',
52+
help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.'
53+
)
54+
4155
parser.add_argument('--truncate_captions', dest='truncate_captions', action='store_true',
4256
help='Captions passed in which exceed the max token length will be truncated if this is set.')
4357

@@ -92,13 +106,13 @@
92106

93107
model_group.add_argument('--dim', default = 512, type = int, help = 'Model dimension')
94108

95-
model_group.add_argument('--text_seq_len', default = 256, type = int, help = 'Text sequence length')
109+
model_group.add_argument('--text_seq_len', default = 128, type = int, help = 'Text sequence length')
96110

97111
model_group.add_argument('--depth', default = 2, type = int, help = 'Model depth')
98112

99-
model_group.add_argument('--heads', default = 8, type = int, help = 'Model number of heads')
113+
model_group.add_argument('--heads', default = 4, type = int, help = 'Model number of heads')
100114

101-
model_group.add_argument('--dim_head', default = 64, type = int, help = 'Model head dimension')
115+
model_group.add_argument('--dim_head', default = 16, type = int, help = 'Model head dimension')
102116

103117
train_group.add_argument('--ff_dropout', default = 0.0, type = float, help = 'Feed forward dropout.')
104118

@@ -112,10 +126,6 @@
112126

113127
args = parser.parse_args()
114128

115-
# quit early if you used the wrong folder name
116-
117-
assert Path(args.image_text_folder).exists(), f'The path {args.image_text_folder} was not found.'
118-
119129
# helpers
120130

121131
def exists(val):
@@ -137,6 +147,8 @@ def cp_path_to_dir(cp_path, tag):
137147
return cp_dir
138148

139149
# constants
150+
WEBDATASET_IMAGE_TEXT_COLUMNS = tuple(args.wds.split(','))
151+
ENABLE_WEBDATASET = True if len(WEBDATASET_IMAGE_TEXT_COLUMNS) == 2 else False
140152

141153
DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name + ".pt"
142154

@@ -169,6 +181,27 @@ def cp_path_to_dir(cp_path, tag):
169181

170182
DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'
171183

184+
if not ENABLE_WEBDATASET:
185+
# quit early if you used the wrong folder name
186+
assert Path(args.image_text_folder).exists(), f'The path {args.image_text_folder} was not found.'
187+
else:
188+
# quit early if no tar files were found
189+
if Path(args.image_text_folder).is_dir():
190+
DATASET = [str(p) for p in Path(args.image_text_folder).glob("**/*") if ".tar" in str(p).lower()] # .name
191+
assert len(DATASET) > 0, 'The directory ({}) does not contain any WebDataset/.tar files.'.format(args.image_text_folder)
192+
print('Found {} WebDataset .tar(.gz) file(s) under given path {}!'.format(len(DATASET), args.image_text_folder))
193+
elif ('http://' in args.image_text_folder.lower()) | ('https://' in args.image_text_folder.lower()):
194+
DATASET = f"pipe:curl -L -s {args.image_text_folder} || true"
195+
print('Found {} http(s) link under given path!'.format(len(DATASET), args.image_text_folder))
196+
elif 'gs://' in args.image_text_folder.lower():
197+
DATASET = f"pipe:gsutil cat {args.image_text_folder} || true"
198+
print('Found {} GCS link under given path!'.format(len(DATASET), args.image_text_folder))
199+
elif '.tar' in args.image_text_folder:
200+
DATASET = args.image_text_folder
201+
print('Found WebDataset .tar(.gz) file under given path {}!'.format(args.image_text_folder))
202+
else:
203+
raise Exception('No folder, no .tar(.gz) and no url pointing to tar files provided under {}.'.format(args.image_text_folder))
204+
172205
# initialize distributed backend
173206

174207
distr_backend = distributed_utils.set_backend_from_args(args)
@@ -283,19 +316,61 @@ def group_weight(model):
283316

284317
is_shuffle = not distributed_utils.using_backend(distributed_utils.HorovodBackend)
285318

286-
ds = TextImageDataset(
287-
args.image_text_folder,
288-
text_len=TEXT_SEQ_LEN,
289-
image_size=IMAGE_SIZE,
290-
resize_ratio=args.resize_ratio,
291-
truncate_captions=args.truncate_captions,
292-
tokenizer=tokenizer,
293-
shuffle=is_shuffle,
294-
)
319+
imagepreproc = T.Compose([
320+
T.Lambda(lambda img: img.convert('RGB')
321+
if img.mode != 'RGB' else img),
322+
T.RandomResizedCrop(IMAGE_SIZE,
323+
scale=(args.resize_ratio, 1.),
324+
ratio=(1., 1.)),
325+
T.ToTensor(),
326+
])
327+
328+
def imagetransform(b):
329+
return Image.open(BytesIO(b))
330+
331+
def tokenize(s):
332+
return tokenizer.tokenize(
333+
s.decode('utf-8'),
334+
TEXT_SEQ_LEN,
335+
truncate_text=args.truncate_captions).squeeze(0)
336+
337+
if ENABLE_WEBDATASET:
338+
DATASET_SIZE = int(1e9) # You need to set a nominal length for the Dataset in order to avoid warnings from DataLoader
339+
340+
myimg, mycap = WEBDATASET_IMAGE_TEXT_COLUMNS
341+
image_text_mapping = {
342+
myimg: imagetransform,
343+
mycap: tokenize
344+
}
345+
image_mapping = {
346+
myimg: imagepreproc
347+
}
348+
349+
num_batches = DATASET_SIZE // BATCH_SIZE
350+
351+
ds = (
352+
wds.WebDataset(DATASET, length=num_batches)
353+
# .shuffle(is_shuffle) # Commented out for WebDataset as the behaviour cannot be predicted yet
354+
.map_dict(**image_text_mapping)
355+
.map_dict(**image_mapping)
356+
.to_tuple(mycap, myimg)
357+
.batched(BATCH_SIZE, partial=False) # It is good to avoid partial batches when using Distributed training
358+
)
359+
else:
360+
ds = TextImageDataset(
361+
args.image_text_folder,
362+
text_len=TEXT_SEQ_LEN,
363+
image_size=IMAGE_SIZE,
364+
resize_ratio=args.resize_ratio,
365+
truncate_captions=args.truncate_captions,
366+
tokenizer=tokenizer,
367+
shuffle=is_shuffle,
368+
)
295369

296370
assert len(ds) > 0, 'dataset is empty'
297371
if distr_backend.is_root_worker():
298-
print(f'{len(ds)} image-text pairs found for training')
372+
if not ENABLE_WEBDATASET:
373+
print(f'{len(ds)} image-text pairs found for training')
299374

300375
if not is_shuffle:
301376
data_sampler = torch.utils.data.distributed.DistributedSampler(
@@ -306,10 +381,18 @@ def group_weight(model):
306381
else:
307382
data_sampler = None
308383

309-
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)
384+
if ENABLE_WEBDATASET:
385+
# WebLoader for WebDataset and DeepSpeed compatibility
386+
dl = wds.WebLoader(ds, batch_size=None, shuffle=False) # optionally add num_workers=2 (n) argument
387+
number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend.get_world_size())
388+
dl = dl.repeat(2).slice(number_of_batches)
389+
dl.length = number_of_batches
390+
else:
391+
# Regular DataLoader for image-text-folder datasets
392+
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)
310393

311-
# initialize DALL-E
312394

395+
# initialize DALL-E
313396

314397
dalle = DALLE(vae=vae, **dalle_params)
315398
if not using_deepspeed:
@@ -454,13 +537,13 @@ def save_model(path, epoch=0):
454537

455538
# training
456539

457-
# Saves a checkpoint before training begins to fail early when mis-configured.
540+
# Saves a checkpoint before training begins to fail early when mis-configured.
458541
# See https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
459542
save_model(DALLE_OUTPUT_FILE_NAME, epoch=resume_epoch)
460543
for epoch in range(resume_epoch, EPOCHS):
461544
if data_sampler:
462545
data_sampler.set_epoch(epoch)
463-
for i, (text, images) in enumerate(distr_dl):
546+
for i, (text, images) in enumerate((dl if ENABLE_WEBDATASET else distr_dl)):
464547
if i % 10 == 0 and distr_backend.is_root_worker():
465548
t = time.time()
466549
if args.fp16:

0 commit comments

Comments
 (0)