From fd2b3c5a2d8cb27930053dfcce15b926cf88f5c6 Mon Sep 17 00:00:00 2001
From: RishabGoel
Date: Mon, 18 Oct 2021 20:30:43 -0400
Subject: [PATCH 1/8] added the rnn baseline

---
 core/lib/models.py | 7 +++
 core/models/rnn.py | 67 ++++++++++++++++++++++
 core/modules/ipagnn/encoder.py | 41 +++++++++++++-
 third_party/flax_examples/lstm_modules.py | 68 +++++++++++++++++++++++
 4 files changed, 182 insertions(+), 1 deletion(-)
 create mode 100644 core/models/rnn.py
 create mode 100644 third_party/flax_examples/lstm_modules.py

diff --git a/core/lib/models.py b/core/lib/models.py
index 1ea70548..c591ad42 100644
--- a/core/lib/models.py
+++ b/core/lib/models.py
@@ -1,6 +1,7 @@
 """Models library."""
 
 from core.models import ipagnn
+from core.models import rnn
 from core.models import mlp
 from core.models import transformer
 from core.modules import transformer_config_lib
@@ -30,5 +31,11 @@ def make_model(config, info, deterministic):
         info=info,
         transformer_config=transformer_config,
     )
+  elif model_class == 'LSTM':
+    return rnn.LSTM(
+        config=config,
+        info=info,
+        transformer_config=transformer_config,
+    )
   else:
     raise ValueError('Unexpected model_class.')
diff --git a/core/models/rnn.py b/core/models/rnn.py
new file mode 100644
index 00000000..a50c271b
--- /dev/null
+++ b/core/models/rnn.py
@@ -0,0 +1,67 @@
+from typing import Any
+
+import jax
+from flax import linen as nn
+import jax.numpy as jnp
+
+from core.modules.ipagnn import encoder
+from core.modules.ipagnn import spans
+from third_party.flax_examples import transformer_modules, lstm_modules
+
+
+class LSTM(nn.Module):
+
+  config: Any
+  info: Any
+  transformer_config: transformer_modules.TransformerConfig
+
+  def setup(self):
+    config = self.config
+    vocab_size = self.info.vocab_size
+    max_tokens = config.max_tokens
+    max_num_nodes = config.max_num_nodes
+    max_num_edges = config.max_num_edges
+    lstm_config = lstm_modules.LSTMConfig(
+        vocab_size=vocab_size,
+        num_layers=config.rnn_layers,
+        hidden_dim=config.hidden_size,)
+    self.token_embedder = spans.NodeAwareTokenEmbedder(
+        transformer_config=self.transformer_config,
+        num_embeddings=vocab_size,
+        features=config.hidden_size,
+        max_tokens=max_tokens,
+        max_num_nodes=max_num_nodes,
+    )
+    self.encoder = encoder.LSTMEncoder(lstm_config)
+
+
+  @nn.compact
+  def __call__(self, x):
+    tokens = x['tokens']
+    # tokens.shape: batch_size, max_tokens
+    tokens_mask = tokens > 0
+    # tokens_mask.shape: batch_size, max_tokens
+    encoder_mask = nn.make_attention_mask(tokens_mask, tokens_mask, dtype=jnp.float32)
+    # encoder_mask.shape: batch_size, 1, max_tokens, max_tokens
+    # NOTE(rgoel): Ensuring the token encoder is still a Transformer to maintain uniformity.
+    encoded_inputs = self.token_embedder(
+        tokens, x['node_token_span_starts'], x['node_token_span_ends'],
+        x['num_nodes'])
+    # encoded_inputs.shape: batch_size, max_tokens, hidden_size
+    encoded_inputs = self.encoder(encoded_inputs)
+    # encoded_inputs.shape: batch_size, max_tokens, hidden_size
+
+    # NOTE(rgoel): Using only the last state. We can change this to
+    # pooling across time steps.
+    def get_last_state(inputs, last_token):
+      return inputs[last_token-1]
+
+    get_last_state_batch = jax.vmap(get_last_state)
+    # encoded_inputs.shape: batch_size, max_tokens, hidden_size
+    x = get_last_state_batch(encoded_inputs, x['num_tokens'])
+    # x.shape: batch_size, 1, hidden_size
+    x = jnp.squeeze(x, 1)
+    # x.shape: batch_size, hidden_size
+    x = nn.Dense(features=self.info.num_classes)(x)
+    # x.shape: batch_size, num_classes
+    return x, None
diff --git a/core/modules/ipagnn/encoder.py b/core/modules/ipagnn/encoder.py
index 375db70f..3ab5051b 100644
--- a/core/modules/ipagnn/encoder.py
+++ b/core/modules/ipagnn/encoder.py
@@ -1,6 +1,6 @@
 from flax import linen as nn
 
-from third_party.flax_examples import transformer_modules
+from third_party.flax_examples import transformer_modules, lstm_modules
 
 
 Encoder1DBlock = transformer_modules.Encoder1DBlock
@@ -43,3 +43,42 @@ def __call__(self,
     encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
 
     return encoded
+
+
+class LSTMEncoder(nn.Module):
+  """LSTM Model Encoder for sequence to sequence translation.
+
+  This Encoder does not encode the input
+  tokens itself. It assumes the tokens have already been encoded, and any
+  desired positional embeddings have already been added.
+
+  Attributes:
+    config: LSTMConfig dataclass containing hyperparameters.
+  """
+  config: lstm_modules.LSTMConfig
+
+  @nn.compact
+  def __call__(self,
+               encoded_inputs,):
+    """Applies the LSTM encoder to the encoded inputs.
+
+    Args:
+      encoded_inputs: pre-encoded input data.
+
+    Returns:
+      output of an LSTM encoder.
+    """
+    cfg = self.config
+    batch_size, max_tokens, _ = encoded_inputs.shape
+    x = encoded_inputs
+    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
+    x = x.astype(cfg.dtype)
+
+    # Input Encoder
+    for lyr in range(cfg.num_layers):
+      initial_state = lstm_modules.SimpleLSTM().initialize_carry((batch_size,), cfg.hidden_dim)
+      _, x = lstm_modules.SimpleLSTM()(initial_state, x)  # TODO(rgoel): Add layer norm to each layer
+
+    encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
+
+    return encoded
diff --git a/third_party/flax_examples/lstm_modules.py b/third_party/flax_examples/lstm_modules.py
new file mode 100644
index 00000000..52c3226a
--- /dev/null
+++ b/third_party/flax_examples/lstm_modules.py
@@ -0,0 +1,68 @@
+# Copyright 2021 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""LSTM-based machine translation model."""
+
+# pylint: disable=attribute-defined-outside-init,g-bare-generic
+# See issue #620.
+# pytype: disable=wrong-arg-count
+# pytype: disable=wrong-keyword-args
+# pytype: disable=attribute-error
+# Note(rgoel): This module is a mirror of the `transformer_modules` module.
+
+from typing import Callable, Any, Optional
+
+import functools
+from flax import linen as nn
+from flax import struct
+from jax import lax
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+@struct.dataclass
+class LSTMConfig:
+  """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+  vocab_size: int
+  share_embeddings: bool = False
+  logits_via_embedding: bool = False
+  dtype: Any = jnp.float32
+  num_layers: int = 1
+  hidden_dim: int = 2048
+  max_len: int = 2048
+  dropout_rate: float = 0.1
+  deterministic: bool = False
+  decode: bool = False
+  kernel_init: Callable = nn.initializers.xavier_uniform()
+  bias_init: Callable = nn.initializers.normal(stddev=1e-6)
+  posemb_init: Optional[Callable] = None
+
+
+class SimpleLSTM(nn.Module):
+  """A simple unidirectional LSTM."""
+
+  @functools.partial(
+      nn.transforms.scan,
+      variable_broadcast='params',
+      in_axes=1, out_axes=1,
+      split_rngs={'params': False})
+  @nn.compact
+  def __call__(self, carry, x):
+    return nn.OptimizedLSTMCell()(carry, x)
+
+  @staticmethod
+  def initialize_carry(batch_dims, hidden_size):
+    # Use fixed random key since default state init fn is just zeros.
+    return nn.OptimizedLSTMCell.initialize_carry(
+        jax.random.PRNGKey(0), batch_dims, hidden_size)
\ No newline at end of file
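[Editor's note] The sketch below is not part of the patch series. It is a minimal example of driving the SimpleLSTM module introduced in PATCH 1/8 directly, mirroring the per-layer loop inside LSTMEncoder. It assumes the Flax Linen API contemporary with these patches (nn.OptimizedLSTMCell.initialize_carry(rng, batch_dims, size) and nn.transforms.scan); the batch/sequence/feature sizes are illustrative only.

    import jax
    import jax.numpy as jnp

    from third_party.flax_examples import lstm_modules  # path as of PATCH 1/8

    batch_size, max_tokens, hidden_dim = 2, 7, 16
    encoded_inputs = jnp.ones((batch_size, max_tokens, hidden_dim))

    model = lstm_modules.SimpleLSTM()
    # The LSTM hidden size is inferred from the carry (hidden_dim here).
    initial_state = lstm_modules.SimpleLSTM.initialize_carry((batch_size,), hidden_dim)
    params = model.init(jax.random.PRNGKey(0), initial_state, encoded_inputs)
    final_state, outputs = model.apply(params, initial_state, encoded_inputs)
    # outputs.shape: batch_size, max_tokens, hidden_dim (scanned over the time axis)
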
From e45c018ddf8d392ae628664e92b0d018a3dfc8e1 Mon Sep 17 00:00:00 2001
From: RishabGoel
Date: Mon, 18 Oct 2021 20:53:26 -0400
Subject: [PATCH 2/8] added lstm test

---
 core/models/test_models.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/core/models/test_models.py b/core/models/test_models.py
index e156e91f..d201f708 100644
--- a/core/models/test_models.py
+++ b/core/models/test_models.py
@@ -64,3 +64,9 @@ def test_transformer(self):
     config = config_lib.get_test_config()
     config.model_class = 'Transformer'
     validate_forward_pass(config, info)
+
+  def test_lstm(self):
+    info = info_lib.get_test_info()
+    config = config_lib.get_test_config()
+    config.model_class = 'LSTM'
+    validate_forward_pass(config, info)

From 467cec18289c5dbfb3d0dccc74e2a3ef9110f5b4 Mon Sep 17 00:00:00 2001
From: RishabGoel
Date: Mon, 18 Oct 2021 21:48:13 -0400
Subject: [PATCH 3/8] corrected formatting

---
 core/models/rnn.py | 2 --
 core/modules/ipagnn/encoder.py | 1 +
 third_party/flax_examples/lstm_modules.py | 3 ++-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/core/models/rnn.py b/core/models/rnn.py
index a50c271b..fa25733b 100644
--- a/core/models/rnn.py
+++ b/core/models/rnn.py
@@ -34,7 +34,6 @@ def setup(self):
     )
     self.encoder = encoder.LSTMEncoder(lstm_config)
 
-
   @nn.compact
   def __call__(self, x):
     tokens = x['tokens']
@@ -57,7 +56,6 @@ def get_last_state(inputs, last_token):
       return inputs[last_token-1]
 
     get_last_state_batch = jax.vmap(get_last_state)
-    # encoded_inputs.shape: batch_size, max_tokens, hidden_size
     x = get_last_state_batch(encoded_inputs, x['num_tokens'])
     # x.shape: batch_size, 1, hidden_size
     x = jnp.squeeze(x, 1)
     # x.shape: batch_size, hidden_size
diff --git a/core/modules/ipagnn/encoder.py b/core/modules/ipagnn/encoder.py
index 3ab5051b..fbf52d3b 100644
--- a/core/modules/ipagnn/encoder.py
+++ b/core/modules/ipagnn/encoder.py
@@ -70,6 +70,7 @@
     """
     cfg = self.config
     batch_size, max_tokens, _ = encoded_inputs.shape
+
     x = encoded_inputs
     x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
     x = x.astype(cfg.dtype)
diff --git a/third_party/flax_examples/lstm_modules.py b/third_party/flax_examples/lstm_modules.py
index 52c3226a..05d8bc9c 100644
--- a/third_party/flax_examples/lstm_modules.py
+++ b/third_party/flax_examples/lstm_modules.py
@@ -31,6 +31,7 @@
 import jax.numpy as jnp
 import numpy as np
+
 
 @struct.dataclass
 class LSTMConfig:
   """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
@@ -65,4 +66,4 @@ def __call__(self, carry, x):
   def initialize_carry(batch_dims, hidden_size):
     # Use fixed random key since default state init fn is just zeros.
     return nn.OptimizedLSTMCell.initialize_carry(
-        jax.random.PRNGKey(0), batch_dims, hidden_size)
\ No newline at end of file
+        jax.random.PRNGKey(0), batch_dims, hidden_size)

From 70602f0f3527881ce189c1d1bbbda35fac685ad0 Mon Sep 17 00:00:00 2001
From: RishabGoel
Date: Tue, 26 Oct 2021 21:28:33 -0400
Subject: [PATCH 4/8] incorporated pr reviews

---
 core/models/rnn.py | 14 +++---
 core/modules/ipagnn/encoder.py | 44 +----------------
 core/modules/rnn/encoder.py | 47 +++++++++++++++++++
 .../modules/rnn/lstm.py | 18 -------
 4 files changed, 55 insertions(+), 68 deletions(-)
 create mode 100644 core/modules/rnn/encoder.py
 rename third_party/flax_examples/lstm_modules.py => core/modules/rnn/lstm.py (75%)

diff --git a/core/models/rnn.py b/core/models/rnn.py
index fa25733b..d76fd061 100644
--- a/core/models/rnn.py
+++ b/core/models/rnn.py
@@ -4,9 +4,9 @@
 from flax import linen as nn
 import jax.numpy as jnp
 
-from core.modules.ipagnn import encoder
+from core.modules.rnn import encoder
 from core.modules.ipagnn import spans
-from third_party.flax_examples import transformer_modules, lstm_modules
+from third_party.flax_examples import transformer_modules
 
 
 class LSTM(nn.Module):
@@ -21,10 +21,6 @@ def setup(self):
     max_tokens = config.max_tokens
     max_num_nodes = config.max_num_nodes
     max_num_edges = config.max_num_edges
-    lstm_config = lstm_modules.LSTMConfig(
-        vocab_size=vocab_size,
-        num_layers=config.rnn_layers,
-        hidden_dim=config.hidden_size,)
     self.token_embedder = spans.NodeAwareTokenEmbedder(
         transformer_config=self.transformer_config,
         num_embeddings=vocab_size,
@@ -32,7 +28,9 @@ def setup(self):
         max_tokens=max_tokens,
         max_num_nodes=max_num_nodes,
     )
-    self.encoder = encoder.LSTMEncoder(lstm_config)
+    self.encoder = encoder.LSTMEncoder(vocab_size=vocab_size,
+                                       num_layers=config.rnn_layers,
+                                       hidden_dim=config.hidden_size)
 
   @nn.compact
   def __call__(self, x):
@@ -58,7 +56,7 @@ def get_last_state(inputs, last_token):
     get_last_state_batch = jax.vmap(get_last_state)
     x = get_last_state_batch(encoded_inputs, x['num_tokens'])
     # x.shape: batch_size, 1, hidden_size
-    x = jnp.squeeze(x, 1)
+    x = jnp.squeeze(x, axis=1)
     # x.shape: batch_size, hidden_size
     x = nn.Dense(features=self.info.num_classes)(x)
     # x.shape: batch_size, num_classes
diff --git a/core/modules/ipagnn/encoder.py b/core/modules/ipagnn/encoder.py
index fbf52d3b..6469149a 100644
--- a/core/modules/ipagnn/encoder.py
+++ b/core/modules/ipagnn/encoder.py
@@ -1,6 +1,6 @@
 from flax import linen as nn
 
-from third_party.flax_examples import transformer_modules, lstm_modules
+from third_party.flax_examples import transformer_modules
 
 
 Encoder1DBlock = transformer_modules.Encoder1DBlock
@@ -42,44 +42,4 @@ def __call__(self,
 
     encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
 
-    return encoded
-
-
-class LSTMEncoder(nn.Module):
-  """LSTM Model Encoder for sequence to sequence translation.
-
-  This Encoder does not encode the input
-  tokens itself. It assumes the tokens have already been encoded, and any
-  desired positional embeddings have already been added.
-
-  Attributes:
-    config: LSTMConfig dataclass containing hyperparameters.
-  """
-  config: lstm_modules.LSTMConfig
-
-  @nn.compact
-  def __call__(self,
-               encoded_inputs,):
-    """Applies the LSTM encoder to the encoded inputs.
-
-    Args:
-      encoded_inputs: pre-encoded input data.
-
-    Returns:
-      output of an LSTM encoder.
-    """
-    cfg = self.config
-    batch_size, max_tokens, _ = encoded_inputs.shape
-
-    x = encoded_inputs
-    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
-    x = x.astype(cfg.dtype)
-
-    # Input Encoder
-    for lyr in range(cfg.num_layers):
-      initial_state = lstm_modules.SimpleLSTM().initialize_carry((batch_size,), cfg.hidden_dim)
-      _, x = lstm_modules.SimpleLSTM()(initial_state, x)  # TODO(rgoel): Add layer norm to each layer
-
-    encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
-
-    return encoded
+    return encoded
\ No newline at end of file
diff --git a/core/modules/rnn/encoder.py b/core/modules/rnn/encoder.py
new file mode 100644
index 00000000..300f3c01
--- /dev/null
+++ b/core/modules/rnn/encoder.py
@@ -0,0 +1,47 @@
+from typing import Callable, Any, Optional
+
+import jax.numpy as jnp
+from flax import linen as nn
+
+from core.modules.rnn import lstm
+
+
+class LSTMEncoder(nn.Module):
+  """LSTM Model Encoder for sequence to sequence translation.
+
+  This Encoder does not encode the input
+  tokens itself. It assumes the tokens have already been encoded, and any
+  desired positional embeddings have already been added.
+  """
+
+  dropout_rate: float
+  num_layers: int
+  hidden_dim: int
+  deterministic: bool = True
+  dtype: Any = jnp.float32
+
+  @nn.compact
+  def __call__(self,
+               encoded_inputs,):
+    """Applies the LSTM encoder to the encoded inputs.
+
+    Args:
+      encoded_inputs: pre-encoded input data.
+
+    Returns:
+      output of an LSTM encoder.
+    """
+    batch_size, max_tokens, _ = encoded_inputs.shape
+
+    x = encoded_inputs
+    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
+    x = x.astype(cfg.dtype)
+
+    # Input Encoder
+    for layer_num in range(self.num_layers):
+      initial_state = lstm.SimpleLSTM.initialize_carry((batch_size,), self.hidden_dim)
+      _, x = lstm.SimpleLSTM(name=f"lstm_{layer_num}")(initial_state, x)  # TODO(rgoel): Add layer norm to each layer
+
+    encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
+
+    return encoded
diff --git a/third_party/flax_examples/lstm_modules.py b/core/modules/rnn/lstm.py
similarity index 75%
rename from third_party/flax_examples/lstm_modules.py
rename to core/modules/rnn/lstm.py
index 05d8bc9c..ec681733 100644
--- a/third_party/flax_examples/lstm_modules.py
+++ b/core/modules/rnn/lstm.py
@@ -32,24 +32,6 @@
 import numpy as np
-
-@struct.dataclass
-class LSTMConfig:
-  """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
-  vocab_size: int
-  share_embeddings: bool = False
-  logits_via_embedding: bool = False
-  dtype: Any = jnp.float32
-  num_layers: int = 1
-  hidden_dim: int = 2048
-  max_len: int = 2048
-  dropout_rate: float = 0.1
-  deterministic: bool = False
-  decode: bool = False
-  kernel_init: Callable = nn.initializers.xavier_uniform()
-  bias_init: Callable = nn.initializers.normal(stddev=1e-6)
-  posemb_init: Optional[Callable] = None
-
-
 class SimpleLSTM(nn.Module):
   """A simple unidirectional LSTM."""

From 7ef978e7754e3e230c6e796b72d18ebad21cd18a Mon Sep 17 00:00:00 2001
From: RishabGoel
Date: Tue, 26 Oct 2021 21:35:13 -0400
Subject: [PATCH 5/8] handled lint error

---
 core/modules/rnn/encoder.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/core/modules/rnn/encoder.py b/core/modules/rnn/encoder.py
index 300f3c01..140156a3 100644
--- a/core/modules/rnn/encoder.py
+++ b/core/modules/rnn/encoder.py
@@ -35,13 +35,13 @@ def __call__(self,
 
     x = encoded_inputs
     x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
-    x = x.astype(cfg.dtype)
+    x = x.astype(self.dtype)
 
     # Input Encoder
     for layer_num in range(self.num_layers):
       initial_state = lstm.SimpleLSTM.initialize_carry((batch_size,), self.hidden_dim)
       _, x = lstm.SimpleLSTM(name=f"lstm_{layer_num}")(initial_state, x)  # TODO(rgoel): Add layer norm to each layer
 
-    encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
+    encoded = nn.LayerNorm(dtype=self.dtype, name='encoder_norm')(x)
 
     return encoded

From a24914f540ab5d37b63bb9d24634408881fba218 Mon Sep 17 00:00:00 2001
From: RishabGoel
Date: Wed, 27 Oct 2021 00:01:17 -0400
Subject: [PATCH 6/8] debugged the test error

---
 core/models/rnn.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/core/models/rnn.py b/core/models/rnn.py
index d76fd061..49d90846 100644
--- a/core/models/rnn.py
+++ b/core/models/rnn.py
@@ -28,8 +28,7 @@ def setup(self):
         max_tokens=max_tokens,
         max_num_nodes=max_num_nodes,
     )
-    self.encoder = encoder.LSTMEncoder(vocab_size=vocab_size,
-                                       num_layers=config.rnn_layers,
+    self.encoder = encoder.LSTMEncoder(num_layers=config.rnn_layers,
                                        hidden_dim=config.hidden_size)
 
   @nn.compact

From 00c21f76f7f3e8afc40bedd702437ae0deb436b7 Mon Sep 17 00:00:00 2001
From: RishabGoel
Date: Wed, 27 Oct 2021 09:49:59 -0400
Subject: [PATCH 7/8] resolved failed pytest

---
 core/modules/rnn/encoder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/modules/rnn/encoder.py b/core/modules/rnn/encoder.py
index 140156a3..10861c45 100644
--- a/core/modules/rnn/encoder.py
+++ b/core/modules/rnn/encoder.py
@@ -14,10 +14,10 @@ class LSTMEncoder(nn.Module):
   desired positional embeddings have already been added.
   """
 
-  dropout_rate: float
   num_layers: int
   hidden_dim: int
   deterministic: bool = True
+  dropout_rate: float = 0.1
   dtype: Any = jnp.float32
 
   @nn.compact

From eac3f4945298c6418e7f8a6ff2212c036fd74155 Mon Sep 17 00:00:00 2001
From: RishabGoel
Date: Wed, 27 Oct 2021 09:56:57 -0400
Subject: [PATCH 8/8] First commit

---
 core/lib/metrics.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/core/lib/metrics.py b/core/lib/metrics.py
index 49907118..b9300778 100644
--- a/core/lib/metrics.py
+++ b/core/lib/metrics.py
@@ -31,6 +31,7 @@ def _generate_next_value_(name, start, count, last_values):
   CONFUSION_MATRIX = enum.auto()
   INSTRUCTION_POINTER = enum.auto()
   LOCALIZATION_ACCURACY = enum.auto()
+  SIZE_OOD = enum.auto()
 
 
 def all_metric_names() -> Tuple[str]:
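
[Editor's note] A second sketch, also not part of the patch series: it illustrates why core/models/rnn.py squeezes axis 1 after gathering the last LSTM state per example. It assumes x['num_tokens'] carries a trailing singleton dimension (shape batch_size, 1), as implied by the shape comments in PATCH 1/8; the array values are dummies.

    import jax
    import jax.numpy as jnp

    batch_size, max_tokens, hidden_size = 2, 5, 3
    encoded_inputs = jnp.arange(batch_size * max_tokens * hidden_size,
                                dtype=jnp.float32).reshape((batch_size, max_tokens, hidden_size))
    num_tokens = jnp.array([[3], [5]])  # assumed shape: batch_size, 1

    def get_last_state(inputs, last_token):
      # inputs.shape: max_tokens, hidden_size; last_token.shape: (1,)
      return inputs[last_token - 1]

    get_last_state_batch = jax.vmap(get_last_state)
    x = get_last_state_batch(encoded_inputs, num_tokens)
    # x.shape: batch_size, 1, hidden_size -- the singleton axis comes from indexing
    # with a length-1 array, hence the jnp.squeeze(x, axis=1) in core/models/rnn.py.
    x = jnp.squeeze(x, axis=1)
    # x.shape: batch_size, hidden_size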