@@ -111,6 +111,7 @@ def __init__(
111
111
smooth_l1_loss = False ,
112
112
temperature = 0.9 ,
113
113
straight_through = False ,
114
+ reinmax = False ,
114
115
kl_div_loss_weight = 0. ,
115
116
normalization = ((* ((0.5 ,) * 3 ), 0 ), (* ((0.5 ,) * 3 ), 1 ))
116
117
):
@@ -125,6 +126,8 @@ def __init__(
125
126
self .num_layers = num_layers
126
127
self .temperature = temperature
127
128
self .straight_through = straight_through
129
+ self .reinmax = reinmax
130
+
128
131
self .codebook = nn .Embedding (num_tokens , codebook_dim )
129
132
130
133
hdim = hidden_dim
@@ -227,8 +230,20 @@ def forward(
227
230
return logits # return logits for getting hard image indices for DALL-E training
228
231
229
232
temp = default (temp , self .temperature )
230
- soft_one_hot = F .gumbel_softmax (logits , tau = temp , dim = 1 , hard = self .straight_through )
231
- sampled = einsum ('b n h w, n d -> b d h w' , soft_one_hot , self .codebook .weight )
233
+
234
+ one_hot = F .gumbel_softmax (logits , tau = temp , dim = 1 , hard = self .straight_through )
235
+
236
+ if self .straight_through and self .reinmax :
237
+ # use reinmax for a second-order accurate gradient estimate (plain straight-through is only first-order) - https://arxiv.org/abs/2304.08612
238
+ # implements Algorithm 2 (ReinMax) of arXiv:2304.08612
239
+ one_hot = one_hot .detach ()
240
+ π0 = logits .softmax (dim = 1 )
241
+ π1 = (one_hot + (logits / temp ).softmax (dim = 1 )) / 2
242
+ π1 = ((π1 .log () - logits ).detach () + logits ).softmax (dim = 1 )
243
+ π2 = 2 * π1 - 0.5 * π0
244
+ one_hot = π2 - π2 .detach () + one_hot
245
+
246
+ sampled = einsum ('b n h w, n d -> b d h w' , one_hot , self .codebook .weight )
232
247
out = self .decoder (sampled )
233
248
234
249
if not return_loss :
0 commit comments