         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
+        # handle value residual
+
+        orig_v = v
+
+        if exists(self.learned_value_residual_mix):
+            mix = self.learned_value_residual_mix(x)
+            v = v.lerp(value_residual, mix)
+
         # add a null key / value
         # to protect against an entirely masked out sequence
         # as well as giving attention ability to attend to nothing
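For reference, `Tensor.lerp(end, weight)` computes `self + weight * (end - self)`, so the mixed values are a convex combination of this layer's values and the cached first-layer values, gated by `mix`. A minimal standalone sketch of just that mixing step (the shapes and the per-token, per-head gate are illustrative assumptions; the definition of `self.learned_value_residual_mix` is not part of this hunk):

import torch

# toy shapes: batch 2, heads 4, tokens 8, head dim 16 (illustrative)
v              = torch.randn(2, 4, 8, 16)   # this layer's values
value_residual = torch.randn(2, 4, 8, 16)   # values cached from the first layer
mix            = torch.rand(2, 4, 8, 1)     # assumed gate in [0, 1], broadcast over head dim

mixed = v.lerp(value_residual, mix)
assert torch.allclose(mixed, v + mix * (value_residual - v), atol = 1e-6)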
@@ -191,7 +216,7 @@ def forward(
 
         out = rearrange(out, 'b h n d -> b n (h d)')
 
-        return self.to_out(out), next_xl_memories
+        return self.to_out(out), next_xl_memories, orig_v
 
 # transformer
 
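Returning `orig_v` (the values before mixing) lets the caller cache the first layer's values and pass them back in as `value_residual` to every later layer. The calling loop is not shown in this section; a hedged toy sketch of that wiring, with a stand-in attention module (`ToyAttention` and `layer_values` are illustrative names, not from this diff), might look like:

import torch
from torch import nn

class ToyAttention(nn.Module):
    # stand-in for the real Attention: returns (out, xl_memories, original values)
    def __init__(self, dim):
        super().__init__()
        self.to_v = nn.Linear(dim, dim)

    def forward(self, x, value_residual = None):
        v = orig_v = self.to_v(x)

        if value_residual is not None:
            v = v.lerp(value_residual, 0.5)  # fixed mix, just for the toy

        return v, None, orig_v

layers = nn.ModuleList([ToyAttention(32) for _ in range(3)])

x = torch.randn(2, 8, 32)
value_residual = None

for attn in layers:
    x, _, layer_values = attn(x, value_residual = value_residual)

    # only the first layer's values are kept as the residual target
    value_residual = layer_values if value_residual is None else value_residual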
@@ -269,14 +294,17 @@ def __init__(
 
         self.layers = ModuleList([])
 
-        for _ in range(depth):
+        for layer_index in range(depth):
+            is_first = layer_index == 0
+
             self.layers.append(ModuleList([
                 init_hyper_conn(dim = dim, branch = Attention(
                     dim = dim,
                     dim_head = dim_head,
                     causal = causal,
                     heads = heads,
                     use_flash_attn = use_flash_attn,
+                    accept_value_residual = not is_first,
                     use_custom_causal_attn_mask = memory_not_causal,
                     dropout = attn_dropout
                 )),
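`accept_value_residual = not is_first` means only layers after the first get a learned mix back toward the first layer's values; the first layer has nothing to mix with and instead supplies `orig_v`. The corresponding `__init__` logic is outside this hunk, but a hypothetical sketch consistent with the `exists(self.learned_value_residual_mix)` check above could be (`make_value_residual_mix` is an assumed helper, not the repo's code):

from torch import nn
from einops.layers.torch import Rearrange

# hypothetical: build a per-token, per-head mixing gate, or None when the
# layer does not accept a value residual (i.e. the first layer)
def make_value_residual_mix(dim, heads, accept_value_residual):
    if not accept_value_residual:
        return None

    return nn.Sequential(
        nn.Linear(dim, heads),
        Rearrange('b n h -> b h n 1'),  # gate per head and token, broadcast over head dim of v
        nn.Sigmoid()
    )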
@@ -435,15 +463,20 @@ def forward(
         if has_xl_memories and self.enhanced_xl_recurrence and len(xl_memories) > 1: # simply shift all the xl memories down by one, so lower layer gets access to representations from layer above
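The hunk is truncated before the body of this branch. As the comment describes, the per-layer xl memories are shifted down by one layer; one possible implementation of that shift (an assumption, not necessarily the commit's exact code) is:

# rotate the per-layer xl memories so layer i reads what layer i + 1 wrote
# on the previous segment (the first layer's memory wraps to the last slot)
def shift_xl_memories_down(xl_memories):
    return [*xl_memories[1:], xl_memories[0]]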