diff --git a/core/lib/metrics.py b/core/lib/metrics.py
index 49907118..b9300778 100644
--- a/core/lib/metrics.py
+++ b/core/lib/metrics.py
@@ -31,6 +31,7 @@ def _generate_next_value_(name, start, count, last_values):
   CONFUSION_MATRIX = enum.auto()
   INSTRUCTION_POINTER = enum.auto()
   LOCALIZATION_ACCURACY = enum.auto()
+  SIZE_OOD = enum.auto()
 
 
 def all_metric_names() -> Tuple[str]:
diff --git a/core/lib/models.py b/core/lib/models.py
index 1ea70548..c591ad42 100644
--- a/core/lib/models.py
+++ b/core/lib/models.py
@@ -1,6 +1,7 @@
 """Models library."""
 
 from core.models import ipagnn
 from core.models import mlp
+from core.models import rnn
 from core.models import transformer
 from core.modules import transformer_config_lib
@@ -30,5 +31,11 @@ def make_model(config, info, deterministic):
         info=info,
         transformer_config=transformer_config,
     )
+  elif model_class == 'LSTM':
+    return rnn.LSTM(
+        config=config,
+        info=info,
+        transformer_config=transformer_config,
+    )
   else:
     raise ValueError('Unexpected model_class.')
diff --git a/core/models/rnn.py b/core/models/rnn.py
new file mode 100644
index 00000000..49d90846
--- /dev/null
+++ b/core/models/rnn.py
@@ -0,0 +1,59 @@
+from typing import Any
+
+import jax
+from flax import linen as nn
+
+from core.modules.ipagnn import spans
+from core.modules.rnn import encoder
+from third_party.flax_examples import transformer_modules
+
+
+class LSTM(nn.Module):
+  """LSTM baseline model over the token sequence."""
+
+  config: Any
+  info: Any
+  transformer_config: transformer_modules.TransformerConfig
+
+  def setup(self):
+    config = self.config
+    vocab_size = self.info.vocab_size
+    max_tokens = config.max_tokens
+    max_num_nodes = config.max_num_nodes
+    # NOTE(rgoel): The token embedder matches the one used by the Transformer
+    # models, for uniformity.
+    self.token_embedder = spans.NodeAwareTokenEmbedder(
+        transformer_config=self.transformer_config,
+        num_embeddings=vocab_size,
+        features=config.hidden_size,
+        max_tokens=max_tokens,
+        max_num_nodes=max_num_nodes,
+    )
+    self.encoder = encoder.LSTMEncoder(
+        num_layers=config.rnn_layers,
+        hidden_dim=config.hidden_size)
+    # All submodules are created in setup(); no method here uses nn.compact.
+    self.output_layer = nn.Dense(features=self.info.num_classes)
+
+  def __call__(self, x):
+    tokens = x['tokens']
+    # tokens.shape: batch_size, max_tokens
+    encoded_inputs = self.token_embedder(
+        tokens, x['node_token_span_starts'], x['node_token_span_ends'],
+        x['num_nodes'])
+    # encoded_inputs.shape: batch_size, max_tokens, hidden_size
+    encoded_inputs = self.encoder(encoded_inputs)
+    # encoded_inputs.shape: batch_size, max_tokens, hidden_size
+
+    # NOTE(rgoel): Using only the last state. We could change this to
+    # pooling across time steps.
+    def get_last_state(inputs, last_token):
+      # inputs.shape: max_tokens, hidden_size
+      return inputs[last_token - 1]
+
+    get_last_state_batch = jax.vmap(get_last_state)
+    x = get_last_state_batch(encoded_inputs, x['num_tokens'])
+    # x.shape: batch_size, hidden_size
+    x = self.output_layer(x)
+    # x.shape: batch_size, num_classes
+    return x, None
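For reference, the last-state readout in `core/models/rnn.py` relies on `jax.vmap` turning a per-example dynamic index into a batched gather. Below is a minimal, self-contained sketch of the same pattern; the array sizes and the `extract_last_state` name are made up for illustration and are not part of the change.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for the sketch.
batch_size, max_tokens, hidden_size = 2, 5, 3

# Stand-in for the per-token LSTM outputs, padded out to max_tokens.
encoded = jnp.arange(batch_size * max_tokens * hidden_size, dtype=jnp.float32)
encoded = encoded.reshape(batch_size, max_tokens, hidden_size)
# Number of real (non-padding) tokens in each example.
num_tokens = jnp.array([5, 3])


def extract_last_state(inputs, last_token):
  # inputs.shape: max_tokens, hidden_size
  # Dynamic integer indexing works under vmap; a dynamic slice would not.
  return inputs[last_token - 1]


# Map over the leading (batch) axis of both arguments.
last_states = jax.vmap(extract_last_state)(encoded, num_tokens)
print(last_states.shape)  # (2, 3) == (batch_size, hidden_size)
assert jnp.array_equal(last_states[0], encoded[0, 4])
assert jnp.array_equal(last_states[1], encoded[1, 2])
```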
diff --git a/core/models/test_models.py b/core/models/test_models.py
index e156e91f..d201f708 100644
--- a/core/models/test_models.py
+++ b/core/models/test_models.py
@@ -64,3 +64,9 @@ def test_transformer(self):
     config = config_lib.get_test_config()
     config.model_class = 'Transformer'
     validate_forward_pass(config, info)
+
+  def test_lstm(self):
+    info = info_lib.get_test_info()
+    config = config_lib.get_test_config()
+    config.model_class = 'LSTM'
+    validate_forward_pass(config, info)
diff --git a/core/modules/ipagnn/encoder.py b/core/modules/ipagnn/encoder.py
index 375db70f..6469149a 100644
--- a/core/modules/ipagnn/encoder.py
+++ b/core/modules/ipagnn/encoder.py
@@ -42,4 +42,4 @@ def __call__(self,
 
     encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
 
-    return encoded
+    return encoded
\ No newline at end of file
diff --git a/core/modules/rnn/encoder.py b/core/modules/rnn/encoder.py
new file mode 100644
index 00000000..10861c45
--- /dev/null
+++ b/core/modules/rnn/encoder.py
@@ -0,0 +1,48 @@
+from typing import Any
+
+import jax.numpy as jnp
+from flax import linen as nn
+
+from core.modules.rnn import lstm
+
+
+class LSTMEncoder(nn.Module):
+  """LSTM encoder for pre-embedded token sequences.
+
+  This encoder does not embed the input tokens itself. It assumes the tokens
+  have already been embedded and that any desired positional embeddings have
+  already been added.
+  """
+
+  num_layers: int
+  hidden_dim: int
+  deterministic: bool = True
+  dropout_rate: float = 0.1
+  dtype: Any = jnp.float32
+
+  @nn.compact
+  def __call__(self, encoded_inputs):
+    """Applies the stacked LSTM layers to the pre-embedded inputs.
+
+    Args:
+      encoded_inputs: Pre-embedded input data.
+
+    Returns:
+      Output of the LSTM encoder.
+    """
+    batch_size, max_tokens, _ = encoded_inputs.shape
+
+    x = encoded_inputs
+    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
+    x = x.astype(self.dtype)
+
+    # Input encoder.
+    for layer_num in range(self.num_layers):
+      # TODO(rgoel): Add layer norm to each layer.
+      initial_state = lstm.SimpleLSTM.initialize_carry(
+          (batch_size,), self.hidden_dim)
+      _, x = lstm.SimpleLSTM(name=f'lstm_{layer_num}')(initial_state, x)
+
+    encoded = nn.LayerNorm(dtype=self.dtype, name='encoder_norm')(x)
+
+    return encoded
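A rough usage sketch for the new `LSTMEncoder`, not part of the change: the toy shapes are arbitrary, and it assumes the repo is importable and running against the older Flax linen API this module targets (where `OptimizedLSTMCell.initialize_carry` is a static method).

```python
import jax
import jax.numpy as jnp

from core.modules.rnn import encoder

# Arbitrary toy shapes for the sketch.
batch_size, max_tokens, hidden_size = 4, 10, 16

# Stand-in for pre-embedded token sequences from the token embedder.
dummy_inputs = jnp.ones((batch_size, max_tokens, hidden_size))

lstm_encoder = encoder.LSTMEncoder(num_layers=2, hidden_dim=hidden_size)

# deterministic=True by default, so no dropout RNG is needed here.
variables = lstm_encoder.init(jax.random.PRNGKey(0), dummy_inputs)
encoded = lstm_encoder.apply(variables, dummy_inputs)
print(encoded.shape)  # (4, 10, 16): one hidden state per time step.
```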
+ +"""LSTM-based machine translation model.""" + +# pylint: disable=attribute-defined-outside-init,g-bare-generic +# See issue #620. +# pytype: disable=wrong-arg-count +# pytype: disable=wrong-keyword-args +# pytype: disable=attribute-error +# Note(rgoel): This module is a mirror of the `transformer_modules` + +from typing import Callable, Any, Optional + +import functools +from flax import linen as nn +from flax import struct +from jax import lax +import jax +import jax.numpy as jnp +import numpy as np + + +class SimpleLSTM(nn.Module): + """A simple unidirectional LSTM.""" + + @functools.partial( + nn.transforms.scan, + variable_broadcast='params', + in_axes=1, out_axes=1, + split_rngs={'params': False}) + @nn.compact + def __call__(self, carry, x): + return nn.OptimizedLSTMCell()(carry, x) + + @staticmethod + def initialize_carry(batch_dims, hidden_size): + # Use fixed random key since default state init fn is just zeros. + return nn.OptimizedLSTMCell.initialize_carry( + jax.random.PRNGKey(0), batch_dims, hidden_size)