Skip to content

Commit 27079dd

Browse files
committed
Add the first training-stability measure from the CogView paper, hidden behind a feature flag
1 parent 093b9ef commit 27079dd

File tree

4 files changed

+28
-4
lines changed

4 files changed

+28
-4
lines changed

README.md

+11
Original file line numberDiff line numberDiff line change
@@ -544,4 +544,15 @@ $ python generate.py --chinese --text '追老鼠的猫'
544544
}
545545
```
546546

547+
```bibtex
548+
@misc{ding2021cogview,
549+
title = {CogView: Mastering Text-to-Image Generation via Transformers},
550+
author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
551+
year = {2021},
552+
eprint = {2105.13290},
553+
archivePrefix = {arXiv},
554+
primaryClass = {cs.CV}
555+
}
556+
```
557+
547558
*Those who do not want to imitate anything, produce nothing.* - Dali

dalle_pytorch/dalle_pytorch.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dalle_pytorch import distributed_utils
1010
from dalle_pytorch.vae import OpenAIDiscreteVAE
1111
from dalle_pytorch.vae import VQGanVAE1024
12-
from dalle_pytorch.transformer import Transformer
12+
from dalle_pytorch.transformer import Transformer, DivideMax
1313

1414
# helpers
1515

@@ -322,6 +322,7 @@ def __init__(
322322
sparse_attn = False,
323323
attn_types = None,
324324
loss_img_weight = 7,
325+
stable = False
325326
):
326327
super().__init__()
327328
assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024)), 'vae must be an instance of DiscreteVAE'
@@ -365,10 +366,12 @@ def __init__(
365366
ff_dropout = ff_dropout,
366367
attn_types = attn_types,
367368
image_fmap_size = image_fmap_size,
368-
sparse_attn = sparse_attn
369+
sparse_attn = sparse_attn,
370+
stable = stable
369371
)
370372

371373
self.to_logits = nn.Sequential(
374+
DivideMax(dim = -1) if stable else nn.Identity(),
372375
nn.LayerNorm(dim),
373376
nn.Linear(dim, self.total_tokens),
374377
)

dalle_pytorch/transformer.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ def cast_tuple(val, depth = 1):
2626

2727
# classes
2828

29+
class DivideMax(nn.Module):
30+
def __init__(self, dim):
31+
super().__init__()
32+
self.dim = dim
33+
34+
def forward(self, x):
35+
maxes = x.amax(dim = self.dim, keepdim = True)
36+
return x / maxes
37+
2938
# https://arxiv.org/abs/2103.17239
3039
class LayerScale(nn.Module):
3140
def __init__(self, dim, depth, fn):
@@ -86,7 +95,8 @@ def __init__(
8695
ff_dropout = 0.,
8796
attn_types = None,
8897
image_fmap_size = None,
89-
sparse_attn = False
98+
sparse_attn = False,
99+
stable = False
90100
):
91101
super().__init__()
92102
layers = nn.ModuleList([])

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'dalle-pytorch',
55
packages = find_packages(),
66
include_package_data = True,
7-
version = '0.12.0',
7+
version = '0.12.1',
88
license='MIT',
99
description = 'DALL-E - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments (0)