Skip to content
This repository was archived by the owner on Jan 22, 2024. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/lib/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _generate_next_value_(name, start, count, last_values):
CONFUSION_MATRIX = enum.auto()
INSTRUCTION_POINTER = enum.auto()
LOCALIZATION_ACCURACY = enum.auto()
SIZE_OOD = enum.auto()


def all_metric_names() -> Tuple[str]:
Expand Down
7 changes: 7 additions & 0 deletions core/lib/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Models library."""

from core.models import ipagnn
from core.models import rnn
from core.models import mlp
from core.models import transformer
from core.modules import transformer_config_lib
Expand Down Expand Up @@ -30,5 +31,11 @@ def make_model(config, info, deterministic):
info=info,
transformer_config=transformer_config,
)
elif model_class == 'LSTM':
return rnn.LSTM(
config=config,
info=info,
transformer_config=transformer_config,
)
else:
raise ValueError('Unexpected model_class.')
62 changes: 62 additions & 0 deletions core/models/rnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Any

import jax
from flax import linen as nn
import jax.numpy as jnp

from core.modules.rnn import encoder
from core.modules.ipagnn import spans
from third_party.flax_examples import transformer_modules


class LSTM(nn.Module):

config: Any
info: Any
transformer_config: transformer_modules.TransformerConfig

def setup(self):
config = self.config
vocab_size = self.info.vocab_size
max_tokens = config.max_tokens
max_num_nodes = config.max_num_nodes
max_num_edges = config.max_num_edges
self.token_embedder = spans.NodeAwareTokenEmbedder(
transformer_config=self.transformer_config,
num_embeddings=vocab_size,
features=config.hidden_size,
max_tokens=max_tokens,
max_num_nodes=max_num_nodes,
)
self.encoder = encoder.LSTMEncoder(num_layers=config.rnn_layers,
hidden_dim=config.hidden_size)

@nn.compact
def __call__(self, x):
tokens = x['tokens']
# tokens.shape: batch_size, max_tokens
tokens_mask = tokens > 0
# tokens_mask.shape: batch_size, max_tokens
encoder_mask = nn.make_attention_mask(tokens_mask, tokens_mask, dtype=jnp.float32)
# encoder_mask.shape: batch_size, 1, max_tokens, max_tokens
# NOTE(rgoel): Ensuring the token encoder is still a Transformer to ensure uniformity.
encoded_inputs = self.token_embedder(
tokens, x['node_token_span_starts'], x['node_token_span_ends'],
x['num_nodes'])
# encoded_inputs.shape: batch_size, max_tokens, hidden_size
encoded_inputs = self.encoder(encoded_inputs)
# encoded_inputs.shape: batch_size, max_tokens, hidden_size

# NOTE(rgoel): Using only the last state. We can change this to
# pooling across time stamps.
def get_last_state(inputs, last_token):
return inputs[last_token-1]

get_last_state_batch = jax.vmap(get_last_state)
x = get_last_state_batch(encoded_inputs, x['num_tokens'])
# x.shape: batch_size, 1, hidden_size
x = jnp.squeeze(x, axis=1)
# x.shape: batch_size, hidden_size
x = nn.Dense(features=self.info.num_classes)(x)
# x.shape: batch_size, num_classes
return x, None
6 changes: 6 additions & 0 deletions core/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,9 @@ def test_transformer(self):
config = config_lib.get_test_config()
config.model_class = 'Transformer'
validate_forward_pass(config, info)

def test_lstm(self):
info = info_lib.get_test_info()
config = config_lib.get_test_config()
config.model_class = 'LSTM'
validate_forward_pass(config, info)
2 changes: 1 addition & 1 deletion core/modules/ipagnn/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ def __call__(self,

encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

return encoded
return encoded
47 changes: 47 additions & 0 deletions core/modules/rnn/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Callable, Any, Optional

import jax.numpy as jnp
from flax import linen as nn

from core.modules.rnn import lstm


class LSTMEncoder(nn.Module):
"""LSTM Model Encoder for sequence to sequence translation.

This Encoder does not encode the input
tokens itself. It assumes the tokens have already been encoded, and any
desired positional embeddings have already been aded.
"""

num_layers: int
hidden_dim: int
deterministic: bool = True
dropout_rate: float = 0.1
dtype: Any = jnp.float32

@nn.compact
def __call__(self,
encoded_inputs,):
"""Applies Transformer model on the encoded inputs.

Args:
encoded_inputs: pre-encoded input data.

Returns:
output of a lstm encoder.
"""
batch_size, max_tokens, _ = encoded_inputs.shape

x = encoded_inputs
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
x = x.astype(self.dtype)

# Input Encoder
for layer_num in range(self.num_layers):
initial_state = lstm.SimpleLSTM.initialize_carry((batch_size,), self.hidden_dim)
_, x = lstm.SimpleLSTM(name=f"lstm_{layer_num}")(initial_state, x) # TODO(rgoel): Add layer norm to each layer

encoded = nn.LayerNorm(dtype=self.dtype, name='encoder_norm')(x)

return encoded
51 changes: 51 additions & 0 deletions core/modules/rnn/lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""LSTM-based machine translation model."""

# pylint: disable=attribute-defined-outside-init,g-bare-generic
# See issue #620.
# pytype: disable=wrong-arg-count
# pytype: disable=wrong-keyword-args
# pytype: disable=attribute-error
# Note(rgoel): This module is a mirror of the `transformer_modules`

from typing import Callable, Any, Optional

import functools
from flax import linen as nn
from flax import struct
from jax import lax
import jax
import jax.numpy as jnp
import numpy as np


class SimpleLSTM(nn.Module):
"""A simple unidirectional LSTM."""

@functools.partial(
nn.transforms.scan,
variable_broadcast='params',
in_axes=1, out_axes=1,
split_rngs={'params': False})
@nn.compact
def __call__(self, carry, x):
return nn.OptimizedLSTMCell()(carry, x)

@staticmethod
def initialize_carry(batch_dims, hidden_size):
# Use fixed random key since default state init fn is just zeros.
return nn.OptimizedLSTMCell.initialize_carry(
jax.random.PRNGKey(0), batch_dims, hidden_size)