Skip to content

Commit 66b573b

Browse files
committed
add reinmax
1 parent daf30d0 commit 66b573b

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

README.md

+10
Original file line numberDiff line numberDiff line change
@@ -751,4 +751,14 @@ $ python generate.py --chinese --text '追老鼠的猫'
751751
}
752752
```
753753

754+
```bibtex
755+
@article{Liu2023BridgingDA,
756+
title = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
757+
author = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
758+
journal = {ArXiv},
759+
year = {2023},
760+
volume = {abs/2304.08612}
761+
}
762+
```
763+
754764
*Those who do not want to imitate anything, produce nothing.* - Dali

dalle_pytorch/dalle_pytorch.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(
111111
smooth_l1_loss = False,
112112
temperature = 0.9,
113113
straight_through = False,
114+
reinmax = False,
114115
kl_div_loss_weight = 0.,
115116
normalization = ((*((0.5,) * 3), 0), (*((0.5,) * 3), 1))
116117
):
@@ -125,6 +126,8 @@ def __init__(
125126
self.num_layers = num_layers
126127
self.temperature = temperature
127128
self.straight_through = straight_through
129+
self.reinmax = reinmax
130+
128131
self.codebook = nn.Embedding(num_tokens, codebook_dim)
129132

130133
hdim = hidden_dim
@@ -227,8 +230,20 @@ def forward(
227230
return logits # return logits for getting hard image indices for DALL-E training
228231

229232
temp = default(temp, self.temperature)
230-
soft_one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through)
231-
sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
233+
234+
one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through)
235+
236+
if self.straight_through and self.reinmax:
237+
# use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
238+
# algorithm 2
239+
one_hot = one_hot.detach()
240+
π0 = logits.softmax(dim = 1)
241+
π1 = (one_hot + (logits / temp).softmax(dim = 1)) / 2
242+
π1 = ((π1.log() - logits).detach() + logits).softmax(dim = 1)
243+
π2 = 2 * π1 - 0.5 * π0
244+
one_hot = π2 - π2.detach() + one_hot
245+
246+
sampled = einsum('b n h w, n d -> b d h w', one_hot, self.codebook.weight)
232247
out = self.decoder(sampled)
233248

234249
if not return_loss:

dalle_pytorch/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.6.4'
1+
__version__ = '1.6.5'

0 commit comments

Comments (0)