Commit 520a357

take the free lunch
1 parent 554802c commit 520a357

File tree

3 files changed: +48 -6 lines changed

3 files changed

+48
-6
lines changed

README.md

Lines changed: 9 additions & 0 deletions

@@ -249,3 +249,12 @@ loss = model(seq, memory_replay_backprop = True) # memory efficient training fro
     url = {https://api.semanticscholar.org/CorpusID:272987528}
 }
 ```
+
+```bibtex
+@inproceedings{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```

recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py

Lines changed: 38 additions & 5 deletions

@@ -11,6 +11,7 @@
 from torch import nn, einsum, Tensor
 
 from einops import rearrange, repeat, pack, unpack
+from einops.layers.torch import Rearrange
 
 from recurrent_memory_transformer_pytorch.attend import Attend
 
@@ -120,6 +121,7 @@ def __init__(
         dim_head = 64,
         heads = 8,
         dropout = 0.,
+        accept_value_residual = False,
         use_flash_attn = False,
         use_custom_causal_attn_mask = False
     ):
@@ -141,22 +143,45 @@ def __init__(
         self.to_kv = Linear(dim, dim_inner * 2)
         self.to_out = Linear(dim_inner, dim)
 
+        # learned value residual mixing
+
+        self.learned_value_residual_mix = None
+
+        if accept_value_residual:
+            self.learned_value_residual_mix = nn.Sequential(
+                Linear(dim, heads),
+                Rearrange('b n h -> b h n 1'),
+                nn.Sigmoid()
+            )
+
     def forward(
         self,
         x,
         rotary_emb: tuple[Tensor, Tensor] | None = None,
         mask = None,
-        xl_memories = None
+        xl_memories = None,
+        value_residual = None
     ):
+        assert not (exists(value_residual) ^ exists(self.learned_value_residual_mix))
+
         h = self.heads
         x = self.norm(x)
 
-
         q = self.to_q(x)
         k, v = self.to_kv(x).chunk(2, dim = -1)
 
+        # split heads
+
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
+        # handle value residual
+
+        orig_v = v
+
+        if exists(self.learned_value_residual_mix):
+            mix = self.learned_value_residual_mix(x)
+            v = v.lerp(value_residual, mix)
+
         # add a null key / value
         # to protect against an entirely masked out sequence
         # as well as giving attention ability to attend to nothing
@@ -191,7 +216,7 @@ def forward(
 
         out = rearrange(out, 'b h n d -> b n (h d)')
 
-        return self.to_out(out), next_xl_memories
+        return self.to_out(out), next_xl_memories, orig_v
 
 # transformer
 
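For reference, here is a minimal, self-contained sketch of the learned value-residual mixing the hunks above add to `Attention`: a per-head sigmoid gate is predicted from the token embeddings and used to interpolate the current layer's values toward the first layer's values via `Tensor.lerp`. The class and variable names below are illustrative stand-ins, not the library's exact module.

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

class GatedValueResidual(nn.Module):
    """Illustrative sketch of the learned value residual mix (not the library's exact module)."""

    def __init__(self, dim, heads):
        super().__init__()
        self.to_mix = nn.Sequential(
            nn.Linear(dim, heads),          # one gate per head, per token
            Rearrange('b n h -> b h n 1'),  # line the gate up with (b, h, n, d) values
            nn.Sigmoid()                    # keep the mix in [0, 1]
        )

    def forward(self, x, v, value_residual):
        # x: (b, n, dim) token embeddings
        # v, value_residual: (b, h, n, d) current-layer and first-layer values
        mix = self.to_mix(x)                # (b, h, n, 1)
        # v * (1 - mix) + value_residual * mix
        return v.lerp(value_residual, mix)

# toy usage
b, n, dim, heads, dim_head = 2, 16, 64, 8, 8
x = torch.randn(b, n, dim)
v = torch.randn(b, heads, n, dim_head)
first_layer_values = torch.randn(b, heads, n, dim_head)

mixer = GatedValueResidual(dim, heads)
mixed = mixer(x, v, first_layer_values)
assert mixed.shape == v.shape
```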
@@ -269,14 +294,17 @@ def __init__(
 
         self.layers = ModuleList([])
 
-        for _ in range(depth):
+        for layer_index in range(depth):
+            is_first = layer_index == 0
+
             self.layers.append(ModuleList([
                 init_hyper_conn(dim = dim, branch = Attention(
                     dim = dim,
                     dim_head = dim_head,
                     causal = causal,
                     heads = heads,
                     use_flash_attn = use_flash_attn,
+                    accept_value_residual = not is_first,
                     use_custom_causal_attn_mask = memory_not_causal,
                     dropout = attn_dropout
                 )),
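The `accept_value_residual = not is_first` wiring above means only the first layer skips the gate, since it has no earlier values to mix in. A small hypothetical sketch of that construction pattern, using a stripped-down stand-in for the library's `Attention`:

```python
from torch import nn

class ToyAttention(nn.Module):
    # stand-in for the library's Attention, reduced to the flag this hunk touches
    def __init__(self, dim, heads = 8, accept_value_residual = False):
        super().__init__()
        self.to_qkv = nn.Linear(dim, dim * 3)
        # only layers that accept a value residual build the learned gate
        self.learned_value_residual_mix = nn.Sequential(
            nn.Linear(dim, heads), nn.Sigmoid()
        ) if accept_value_residual else None

depth, dim = 4, 64
layers = nn.ModuleList([
    ToyAttention(dim, accept_value_residual = layer_index > 0)  # first layer: False
    for layer_index in range(depth)
])

assert layers[0].learned_value_residual_mix is None
assert all(layer.learned_value_residual_mix is not None for layer in layers[1:])
```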
@@ -435,15 +463,20 @@ def forward(
         if has_xl_memories and self.enhanced_xl_recurrence and len(xl_memories) > 1: # simply shift all the xl memories down by one, so lower layer gets access to representations from layer above
             xl_memories = [*xl_memories[1:], xl_memories[0]]
 
+        # value residual
+
+        value_residual = None
+
         # expand streams for hyper connections
 
         x = self.expand_streams(x)
 
         # attention and feedforward
 
         for attn, ff in self.layers:
-            x, xl_memories = attn(x, mask = mask, xl_memories = next(xl_memories_iter, None), rotary_emb = rotary_emb)
+            x, xl_memories, attn_values = attn(x, mask = mask, xl_memories = next(xl_memories_iter, None), rotary_emb = rotary_emb, value_residual = value_residual)
 
+            value_residual = default(value_residual, attn_values)
             new_xl_memories.append(xl_memories)
 
             x = ff(x)

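Below is a self-contained sketch of what the forward-pass change accomplishes: `value_residual` starts as `None`, the first layer's values fill it via `default`, and every later call to `default` is a no-op, so deeper layers always interpolate toward the first layer's values. The helpers and the toy "attention" here are stand-ins rather than the library's code, and a fixed weight replaces the learned per-head gate.

```python
import torch

def exists(val):
    return val is not None

def default(val, d):
    # keep the first non-None value; later assignments change nothing
    return val if exists(val) else d

def toy_attention_values(x, proj, heads = 8):
    # stand-in for a layer's value projection -> (b, h, n, d)
    b, n, _ = x.shape
    return proj(x).reshape(b, n, heads, -1).transpose(1, 2)

depth, dim = 4, 64
x = torch.randn(2, 16, dim)
projs = [torch.nn.Linear(dim, dim) for _ in range(depth)]

value_residual = None

for layer_index, proj in enumerate(projs):
    attn_values = toy_attention_values(x, proj)

    if exists(value_residual):
        # layers after the first mix toward the first layer's values;
        # a fixed 0.5 stands in for the learned per-head gate
        attn_values = attn_values.lerp(value_residual, 0.5)

    # no-op after the first iteration: value_residual stays pinned
    # to the first layer's original values
    value_residual = default(value_residual, attn_values)
```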
setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.6.0',
+  version = '0.7.0',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
