
Commit 0b17e18

T2T Team authored and Copybara-Service committed
Fix TransformerMemory model
Fixes a bug with training and enables relative attention. Absolute attention works poorly when the timing signal is reset at the start of each chunk.

PiperOrigin-RevId: 239686861
1 parent a6f8a00 commit 0b17e18
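
For intuition only (toy numbers of my own choosing, not code from this commit): with the targets split into fixed-length chunks, an absolute timing signal restarts at position 0 in every chunk, so tokens in different chunks receive identical position encodings, while relative attention encodes only query-key offsets, which stay meaningful across the boundary between the remembered chunk and the current one.

# Toy sketch of the position bookkeeping described in the commit message.
# A chunk length of 4 is assumed purely for illustration.
chunk_length = 4

# Absolute positions reset at the start of each chunk, so every chunk
# sees the same timing signal.
absolute_positions = [list(range(chunk_length)) for _ in range(3)]
print(absolute_positions)  # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]

# Relative attention instead uses query-key offsets.  For the third token
# of the current chunk attending over the remembered chunk plus the current
# one, the offsets are distinct and consistent across the boundary.
query_position = 2
key_positions = range(-chunk_length, chunk_length)
offsets = [query_position - k for k in key_positions]
print(offsets)  # [6, 5, 4, 3, 2, 1, 0, -1]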

3 files changed: +49 −15 lines

tensor2tensor/layers/common_attention.py

Lines changed: 26 additions & 0 deletions
@@ -1667,6 +1667,20 @@ def dot_product_attention_relative(q,
     return _relative_attention_inner(weights, v, relations_values, False)
 
 
+def dot_product_attention_relative_memory(q, k, v, bias, *args, **kwargs):
+  """Wrapper of dot_product_attention_relative to use with recurrent memory."""
+
+  q_len = tf.shape(q)[2]
+  k_len = tf.shape(k)[2]
+  num_memory_items = k_len - q_len
+
+  q = tf.pad(q, [[0, 0], [0, 0], [num_memory_items, 0], [0, 0]])
+  bias = tf.pad(bias, [[0, 0], [0, 0], [num_memory_items, 0], [0, 0]])
+  output = dot_product_attention_relative(q, k, v, bias, *args, **kwargs)
+
+  return output[:, :, num_memory_items:, :]
+
+
 def _relative_position_to_absolute_position_masked(x):
   """Helper to dot_product_self_attention_relative_v2.
 
@@ -4152,6 +4166,18 @@ def multihead_attention(query_antecedent,
           save_weights_to=save_weights_to,
           make_image_summary=make_image_summary,
           cache=cache is not None)
+    elif attention_type == "dot_product_relative_memory":
+      x = dot_product_attention_relative_memory(
+          q,
+          k,
+          v,
+          bias,
+          max_relative_position,
+          dropout_rate,
+          image_shapes,
+          save_weights_to=save_weights_to,
+          make_image_summary=make_image_summary,
+          cache=cache is not None)
     elif attention_type == "dot_product_unmasked_relative_v2":
       x = dot_product_unmasked_self_attention_relative_v2(
           q,
tensor2tensor/layers/transformer_memory.py

Lines changed: 14 additions & 10 deletions
@@ -90,17 +90,21 @@ def pre_attention(self, segment, query_antecedent, memory_antecedent, bias):
         [tf.stop_gradient(previous_vals), query_antecedent], 1)
     new_bias = tf.concat([previous_bias, bias], -1)
 
-    cancel_update = tf.equal(self.previous_segment, segment[0])
     remember_segment = segment[0]
-    remember_vals = tf.cond(
-        cancel_update,
-        lambda: self.previous_vals,
-        lambda: tf.pad(query_antecedent, [[0, amount_to_pad], [0, 0], [0, 0]]))
-    remember_bias = tf.cond(
-        cancel_update,
-        lambda: self.previous_bias,
-        lambda: tf.zeros_like(bias) + tf.reduce_max(bias, -1, keep_dims=True))
-
+    # TODO(kitaev): The code assumes that we always either increment the chunk
+    # number or reset it to zero, which is checked by the assertion. This
+    # assumption will not hold if we re-run the model for each token, e.g. for
+    # autoregressive greedy/beam/sampling decode.
+    with tf.control_dependencies(
+        [tf.Assert(tf.math.logical_or(
+            tf.equal(remember_segment, 0),
+            tf.equal(remember_segment, self.previous_segment + 1)),
+            [self.previous_segment, remember_segment])]):
+      remember_segment = tf.identity(remember_segment)
+    remember_vals = tf.pad(query_antecedent,
+                           [[0, amount_to_pad], [0, 0], [0, 0]])
+    remember_bias = tf.zeros_like(bias) + tf.reduce_max(
+        bias, -1, keep_dims=True)
     token = (remember_segment, remember_vals, remember_bias)
 
     return token, query_antecedent, new_memory_antecedent, new_bias
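
The assertion pattern added above can be exercised on its own. The sketch below is a hypothetical standalone helper, written for eager TensorFlow 2 so it runs outside a graph, that mirrors the check: the chunk number must either reset to zero or advance by exactly one.

import tensorflow as tf


def check_segment_progression(previous_segment, segment):
  """Hypothetical helper mirroring the assertion in pre_attention."""
  is_valid = tf.math.logical_or(
      tf.equal(segment, 0),
      tf.equal(segment, previous_segment + 1))
  # In graph mode the control dependency forces the assert to run before the
  # segment value is used; in eager mode the assert simply fires immediately.
  with tf.control_dependencies(
      [tf.Assert(is_valid, [previous_segment, segment])]):
    return tf.identity(segment)


print(check_segment_progression(tf.constant(3), tf.constant(4)))  # ok: 4
print(check_segment_progression(tf.constant(3), tf.constant(0)))  # ok: 0 (reset)
# check_segment_progression(tf.constant(3), tf.constant(5))  # raises
# InvalidArgumentError -- the case hit when re-running the model per token.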

tensor2tensor/models/transformer.py

Lines changed: 9 additions & 5 deletions
@@ -2624,17 +2624,21 @@ def transformer_wikitext103_l4k_v0():
 
 
 @registry.register_hparams
-def transformer_wikitext103_l4k_memory():
+def transformer_wikitext103_l4k_memory_v0():
   """HParams for training languagemodel_wikitext103_l4k with memory."""
   hparams = transformer_wikitext103_l4k_v0()
 
-  hparams.split_targets_chunk_length = 8
-  hparams.split_targets_max_chunks = 512
+  hparams.split_targets_chunk_length = 64
+  hparams.split_targets_max_chunks = 64
 
   # The hparams specify batch size *before* chunking, but we want to have a
   # consistent 4K batch size *after* chunking to fully utilize the hardware.
   target_tokens_per_batch = 4096
-  hparams.batch_size = target_tokens_per_batch * (
-      hparams.max_length / hparams.split_targets_chunk_length)  # 2097152
+  hparams.batch_size = int(target_tokens_per_batch * (
+      hparams.max_length / hparams.split_targets_chunk_length))  # 262144
+
+  hparams.pos = None
+  hparams.self_attention_type = "dot_product_relative_memory"
+  hparams.max_relative_position = 2 * hparams.split_targets_chunk_length
 
   return hparams
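
To make the batch-size comment above concrete, here is the arithmetic spelled out in plain Python. The value max_length = 4096 is an assumption inferred from the inline comments (4096 * 4096 / 64 = 262144, and the old chunk length of 8 gave 4096 * 4096 / 8 = 2097152).

# Assumed value for the wikitext103_l4k setup, not read from the hparams.
target_tokens_per_batch = 4096
max_length = 4096
split_targets_chunk_length = 64

# Pre-chunking batch size chosen so that the post-chunking batch stays at 4K
# tokens; int() avoids the float batch size produced by the old code.
batch_size = int(target_tokens_per_batch *
                 (max_length / split_targets_chunk_length))
print(batch_size)  # 262144

# Relative attention window of twice the chunk length, presumably so offsets
# can span the current chunk plus the equally long remembered chunk.
max_relative_position = 2 * split_targets_chunk_length
print(max_relative_position)  # 128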
