diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..cfa3a61
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+ignore = E203, E266, E501, W503, F403, F401
+max-line-length = 88
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..1028092
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,11 @@
+repos:
+- repo: https://github.com/ambv/black
+  rev: 21.6b0
+  hooks:
+  - id: black
+    language_version: python3.9
+- repo: https://gitlab.com/pycqa/flake8
+  rev: 3.9.2
+  hooks:
+  - id: flake8
+
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..35deb69
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,24 @@
+all: install
+
+install: venv
+	: # Activate venv and install requirements
+	mkdir tmp
+	source venv/bin/activate && TMPDIR=tmp pip install -r requirements.txt
+	rm -r tmp/
+	pre-commit install
+
+venv:
+	: # Create venv if it doesn't exist
+	: # test -d venv || virtualenv -p python3 --no-site-packages venv
+	test -d venv || python -m venv venv
+
+test: venv
+	source venv/bin/activate && python -m pytest
+
+clean:
+	rm -rf venv/
+	find -iname "*.pyc" -delete
+	rm -rf logs/
+	rm -rf .pytest_cache
+	rm -rf tmp/
+
diff --git a/README.md b/README.md
index f604c3c..d3e0a55 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,23 @@ architecture.
 
 ![DNC architecture](images/dnc_model.png)
 
+## Installation
+```shell
+make install
+```
+
+The above command will create a virtual environment and install the dependencies and pre-commit hooks.
+
+Run `source venv/bin/activate` in the root directory of this repository to activate the installed virtual env.
+
+## Testing
+```shell
+make test
+```
+
+Run unit tests in `tests/` using pytest.
+
+
 ## Train
 The `DNC` requires an installation of [TensorFlow](https://www.tensorflow.org/) and
 [Sonnet](https://github.com/deepmind/sonnet). An example training script is
@@ -59,13 +76,26 @@ $ ipython train.py -- --memory_size=64 --num_bits=8 --max_length=3
 Periodically saving, or 'checkpointing', the model is disabled by default. To
 enable, use the `checkpoint_interval` flag. E.g. `--checkpoint_interval=10000`
 will ensure a checkpoint is created every `10,000` steps. The model will be
-checkpointed to `/tmp/tf/dnc/` by default. From there training can be resumed.
-To specify an alternate checkpoint directory, use the `checkpoint_dir` flag.
-Note: ensure that `/tmp/tf/dnc/` is deleted before training is resumed with
+checkpointed to `./logs/repeat_copy/checkpoint` by default. From there training can be resumed.
+To specify an alternate checkpoint directory, use the `log_dir` flag.
+Note: ensure that existing checkpoints are deleted or moved before training is resumed with
 different model parameters, to avoid shape inconsistency errors.
 
 More generally, the `DNC` class found within `dnc.py` can be used as a standard
 TensorFlow rnn core and unrolled with TensorFlow rnn ops, such as
-`tf.nn.dynamic_rnn` on any sequential task.
+`keras.layers.RNN` on any sequential task.
+
+## Model Inspection
+```shell
+jupyter notebook interactive.ipynb
+```
+
+Jupyter notebook that loads a trained model from checkpoints. It provides helper functions for evaluating arbitrary input bit sequences and visualizing output and intermediate read/write states.
+
+```shell
+tensorboard --logdir logs/repeat_copy/
+```
+
+Tensorboard visualization of test/train loss and TensorFlow Graph. Test/Train loss is emitted based on `report_interval`.
Disclaimer: This is not an official Google product diff --git a/dnc/access.py b/dnc/access.py index 211d454..87fcd89 100644 --- a/dnc/access.py +++ b/dnc/access.py @@ -18,301 +18,350 @@ from __future__ import division from __future__ import print_function -import collections +import numpy as np import sonnet as snt import tensorflow as tf -from dnc import addressing -from dnc import util +from dnc import addressing, util -AccessState = collections.namedtuple('AccessState', ( - 'memory', 'read_weights', 'write_weights', 'linkage', 'usage')) +# For indexing directly into MemoryAccess state +MEMORY = 0 +READ_WEIGHTS = 1 +WRITE_WEIGHTS = 2 +LINKAGE = 3 +USAGE = 4 def _erase_and_write(memory, address, reset_weights, values): - """Module to erase and write in the external memory. + """Module to erase and write in the external memory. - Erase operation: - M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t) + Erase operation: + M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t) - Add operation: - M_t(i) = M_t'(i) + w_t(i) * a_t + Add operation: + M_t(i) = M_t'(i) + w_t(i) * a_t - where e are the reset_weights, w the write weights and a the values. + where e are the reset_weights, w the write weights and a the values. - Args: - memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`. - address: 3-D tensor `[batch_size, num_writes, memory_size]`. - reset_weights: 3-D tensor `[batch_size, num_writes, word_size]`. - values: 3-D tensor `[batch_size, num_writes, word_size]`. + Args: + memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`. + address: 3-D tensor `[batch_size, num_writes, memory_size]`. + reset_weights: 3-D tensor `[batch_size, num_writes, word_size]`. + values: 3-D tensor `[batch_size, num_writes, word_size]`. - Returns: - 3-D tensor of shape `[batch_size, num_writes, word_size]`. - """ - with tf.name_scope('erase_memory', values=[memory, address, reset_weights]): + Returns: + 3-D tensor of shape `[batch_size, num_writes, word_size]`. + """ expand_address = tf.expand_dims(address, 3) reset_weights = tf.expand_dims(reset_weights, 2) weighted_resets = expand_address * reset_weights reset_gate = util.reduce_prod(1 - weighted_resets, 1) memory *= reset_gate - with tf.name_scope('additive_write', values=[memory, address, values]): add_matrix = tf.matmul(address, values, adjoint_a=True) memory += add_matrix - return memory + return memory class MemoryAccess(snt.RNNCore): - """Access module of the Differentiable Neural Computer. + """Access module of the Differentiable Neural Computer. - This memory module supports multiple read and write heads. It makes use of: + This memory module supports multiple read and write heads. It makes use of: - * `addressing.TemporalLinkage` to track the temporal ordering of writes in - memory for each write head. - * `addressing.FreenessAllocator` for keeping track of memory usage, where - usage increase when a memory location is written to, and decreases when - memory is read from that the controller says can be freed. + * `addressing.TemporalLinkage` to track the temporal ordering of writes in + memory for each write head. + * `addressing.Freeness` for keeping track of memory usage, where + usage increase when a memory location is written to, and decreases when + memory is read from that the controller says can be freed. - Write-address selection is done by an interpolation between content-based - lookup and using unused memory. + Write-address selection is done by an interpolation between content-based + lookup and using unused memory. 
- Read-address selection is done by an interpolation of content-based lookup - and following the link graph in the forward or backwards read direction. - """ - - def __init__(self, - memory_size=128, - word_size=20, - num_reads=1, - num_writes=1, - name='memory_access'): - """Creates a MemoryAccess module. - - Args: - memory_size: The number of memory slots (N in the DNC paper). - word_size: The width of each memory slot (W in the DNC paper) - num_reads: The number of read heads (R in the DNC paper). - num_writes: The number of write heads (fixed at 1 in the paper). - name: The name of the module. + Read-address selection is done by an interpolation of content-based lookup + and following the link graph in the forward or backwards read direction. """ - super(MemoryAccess, self).__init__(name=name) - self._memory_size = memory_size - self._word_size = word_size - self._num_reads = num_reads - self._num_writes = num_writes - self._write_content_weights_mod = addressing.CosineWeights( - num_writes, word_size, name='write_content_weights') - self._read_content_weights_mod = addressing.CosineWeights( - num_reads, word_size, name='read_content_weights') - - self._linkage = addressing.TemporalLinkage(memory_size, num_writes) - self._freeness = addressing.Freeness(memory_size) - - def _build(self, inputs, prev_state): - """Connects the MemoryAccess module into the graph. - - Args: - inputs: tensor of shape `[batch_size, input_size]`. This is used to - control this access module. - prev_state: Instance of `AccessState` containing the previous state. - - Returns: - A tuple `(output, next_state)`, where `output` is a tensor of shape - `[batch_size, num_reads, word_size]`, and `next_state` is the new - `AccessState` named tuple at the current time t. - """ - inputs = self._read_inputs(inputs) - - # Update usage using inputs['free_gate'] and previous read & write weights. - usage = self._freeness( - write_weights=prev_state.write_weights, - free_gate=inputs['free_gate'], - read_weights=prev_state.read_weights, - prev_usage=prev_state.usage) - - # Write to memory. - write_weights = self._write_weights(inputs, prev_state.memory, usage) - memory = _erase_and_write( - prev_state.memory, - address=write_weights, - reset_weights=inputs['erase_vectors'], - values=inputs['write_vectors']) - - linkage_state = self._linkage(write_weights, prev_state.linkage) - - # Read from memory. - read_weights = self._read_weights( - inputs, - memory=memory, - prev_read_weights=prev_state.read_weights, - link=linkage_state.link) - read_words = tf.matmul(read_weights, memory) - - return (read_words, AccessState( - memory=memory, - read_weights=read_weights, - write_weights=write_weights, - linkage=linkage_state, - usage=usage)) - - def _read_inputs(self, inputs): - """Applies transformations to `inputs` to get control for this module.""" - - def _linear(first_dim, second_dim, name, activation=None): - """Returns a linear transformation of `inputs`, followed by a reshape.""" - linear = snt.Linear(first_dim * second_dim, name=name)(inputs) - if activation is not None: - linear = activation(linear, name=name + '_activation') - return tf.reshape(linear, [-1, first_dim, second_dim]) - - # v_t^i - The vectors to write to memory, for each write head `i`. - write_vectors = _linear(self._num_writes, self._word_size, 'write_vectors') - - # e_t^i - Amount to erase the memory by before writing, for each write head. 
- erase_vectors = _linear(self._num_writes, self._word_size, 'erase_vectors', - tf.sigmoid) - - # f_t^j - Amount that the memory at the locations read from at the previous - # time step can be declared unused, for each read head `j`. - free_gate = tf.sigmoid( - snt.Linear(self._num_reads, name='free_gate')(inputs)) - - # g_t^{a, i} - Interpolation between writing to unallocated memory and - # content-based lookup, for each write head `i`. Note: `a` is simply used to - # identify this gate with allocation vs writing (as defined below). - allocation_gate = tf.sigmoid( - snt.Linear(self._num_writes, name='allocation_gate')(inputs)) - - # g_t^{w, i} - Overall gating of write amount for each write head. - write_gate = tf.sigmoid( - snt.Linear(self._num_writes, name='write_gate')(inputs)) - - # \pi_t^j - Mixing between "backwards" and "forwards" positions (for - # each write head), and content-based lookup, for each read head. - num_read_modes = 1 + 2 * self._num_writes - read_mode = snt.BatchApply(tf.nn.softmax)( - _linear(self._num_reads, num_read_modes, name='read_mode')) - - # Parameters for the (read / write) "weights by content matching" modules. - write_keys = _linear(self._num_writes, self._word_size, 'write_keys') - write_strengths = snt.Linear(self._num_writes, name='write_strengths')( - inputs) - - read_keys = _linear(self._num_reads, self._word_size, 'read_keys') - read_strengths = snt.Linear(self._num_reads, name='read_strengths')(inputs) - - result = { - 'read_content_keys': read_keys, - 'read_content_strengths': read_strengths, - 'write_content_keys': write_keys, - 'write_content_strengths': write_strengths, - 'write_vectors': write_vectors, - 'erase_vectors': erase_vectors, - 'free_gate': free_gate, - 'allocation_gate': allocation_gate, - 'write_gate': write_gate, - 'read_mode': read_mode, - } - return result - - def _write_weights(self, inputs, memory, usage): - """Calculates the memory locations to write to. - - This uses a combination of content-based lookup and finding an unused - location in memory, for each write head. - - Args: - inputs: Collection of inputs to the access module, including controls for - how to chose memory writing, such as the content to look-up and the - weighting between content-based and allocation-based addressing. - memory: A tensor of shape `[batch_size, memory_size, word_size]` - containing the current memory contents. - usage: Current memory usage, which is a tensor of shape `[batch_size, - memory_size]`, used for allocation-based addressing. - - Returns: - tensor of shape `[batch_size, num_writes, memory_size]` indicating where - to write to (if anywhere) for each write head. - """ - with tf.name_scope('write_weights', values=[inputs, memory, usage]): - # c_t^{w, i} - The content-based weights for each write head. - write_content_weights = self._write_content_weights_mod( - memory, inputs['write_content_keys'], - inputs['write_content_strengths']) - - # a_t^i - The allocation weights for each write head. - write_allocation_weights = self._freeness.write_allocation_weights( - usage=usage, - write_gates=(inputs['allocation_gate'] * inputs['write_gate']), - num_writes=self._num_writes) - - # Expands gates over memory locations. - allocation_gate = tf.expand_dims(inputs['allocation_gate'], -1) - write_gate = tf.expand_dims(inputs['write_gate'], -1) - - # w_t^{w, i} - The write weightings for each write head. 
- return write_gate * (allocation_gate * write_allocation_weights + - (1 - allocation_gate) * write_content_weights) - - def _read_weights(self, inputs, memory, prev_read_weights, link): - """Calculates read weights for each read head. - - The read weights are a combination of following the link graphs in the - forward or backward directions from the previous read position, and doing - content-based lookup. The interpolation between these different modes is - done by `inputs['read_mode']`. - - Args: - inputs: Controls for this access module. This contains the content-based - keys to lookup, and the weightings for the different read modes. - memory: A tensor of shape `[batch_size, memory_size, word_size]` - containing the current memory contents to do content-based lookup. - prev_read_weights: A tensor of shape `[batch_size, num_reads, - memory_size]` containing the previous read locations. - link: A tensor of shape `[batch_size, num_writes, memory_size, - memory_size]` containing the temporal write transition graphs. - - Returns: - A tensor of shape `[batch_size, num_reads, memory_size]` containing the - read weights for each read head. - """ - with tf.name_scope( - 'read_weights', values=[inputs, memory, prev_read_weights, link]): - # c_t^{r, i} - The content weightings for each read head. - content_weights = self._read_content_weights_mod( - memory, inputs['read_content_keys'], inputs['read_content_strengths']) - - # Calculates f_t^i and b_t^i. - forward_weights = self._linkage.directional_read_weights( - link, prev_read_weights, forward=True) - backward_weights = self._linkage.directional_read_weights( - link, prev_read_weights, forward=False) - - backward_mode = inputs['read_mode'][:, :, :self._num_writes] - forward_mode = ( - inputs['read_mode'][:, :, self._num_writes:2 * self._num_writes]) - content_mode = inputs['read_mode'][:, :, 2 * self._num_writes] - - read_weights = ( - tf.expand_dims(content_mode, 2) * content_weights + tf.reduce_sum( - tf.expand_dims(forward_mode, 3) * forward_weights, 2) + - tf.reduce_sum(tf.expand_dims(backward_mode, 3) * backward_weights, 2)) - - return read_weights - - @property - def state_size(self): - """Returns a tuple of the shape of the state tensors.""" - return AccessState( - memory=tf.TensorShape([self._memory_size, self._word_size]), - read_weights=tf.TensorShape([self._num_reads, self._memory_size]), - write_weights=tf.TensorShape([self._num_writes, self._memory_size]), - linkage=self._linkage.state_size, - usage=self._freeness.state_size) - - @property - def output_size(self): - """Returns the output shape.""" - return tf.TensorShape([self._num_reads, self._word_size]) + def __init__( + self, + memory_size=128, + word_size=20, + num_reads=1, + num_writes=1, + name="memory_access", + dtype=tf.float32, + ): + """Creates a MemoryAccess module. + + Args: + memory_size: The number of memory slots (N in the DNC paper). + word_size: The width of each memory slot (W in the DNC paper) + num_reads: The number of read heads (R in the DNC paper). + num_writes: The number of write heads (fixed at 1 in the paper). + name: The name of the module. 
+ """ + super(MemoryAccess, self).__init__(name=name) + self._memory_size = memory_size + self._word_size = word_size + self._num_reads = num_reads + self._num_writes = num_writes + + self._dtype = dtype + + self._write_content_weights_mod = addressing.CosineWeights( + num_writes, word_size, name="write_content_weights" + ) + self._read_content_weights_mod = addressing.CosineWeights( + num_reads, word_size, name="read_content_weights" + ) + + self._linkage = addressing.TemporalLinkage(memory_size, num_writes, dtype=dtype) + self._freeness = addressing.Freeness(memory_size, dtype=dtype) + + self._linear_layers = {} + + # keras.layers.RNN abstract method + def call(self, inputs, prev_state): + return self.__call__(inputs, prev_state) + + # sonnet.RNNCore abstract method + def __call__(self, inputs, prev_state): + """Connects the MemoryAccess module into the graph. + + Args: + inputs: tensor of shape `[batch_size, input_size]`. This is used to + control this access module. + prev_state: nested list of tensors containing the previous state. + + Returns: + A tuple `(output, next_state)`, where `output` is a tensor of shape + `[batch_size, num_reads, word_size]`, and `next_state` is the new + nested list of tensors at the current time t. + """ + ( + prev_memory, + prev_read_weights, + prev_write_weights, + prev_linkage, + prev_usage, + ) = prev_state + + inputs = self._read_inputs(inputs) + + # Update usage using inputs['free_gate'] and previous read & write weights. + usage = self._freeness( + write_weights=prev_write_weights, + free_gate=inputs["free_gate"], + read_weights=prev_read_weights, + prev_usage=prev_usage, + ) + + # Write to memory. + write_weights = self._write_weights(inputs, prev_memory, usage) + memory = _erase_and_write( + prev_memory, + address=write_weights, + reset_weights=inputs["erase_vectors"], + values=inputs["write_vectors"], + ) + + [link, precedence_weights] = self._linkage(write_weights, prev_linkage) + + # Read from memory. + read_weights = self._read_weights( + inputs, memory=memory, prev_read_weights=prev_read_weights, link=link + ) + read_words = tf.matmul(read_weights, memory) + + return ( + read_words, + [memory, read_weights, write_weights, [link, precedence_weights], usage], + ) + + def _read_inputs(self, inputs): + """Applies transformations to `inputs` to get control for this module.""" + + def _linear(dims, name, activation=None): + """Returns a linear transformation of `inputs`, followed by a reshape.""" + linear = self._linear_layers.get(name) + if not linear: + linear = snt.Linear(np.prod(dims), name=name) + self._linear_layers[name] = linear + + linear = linear(inputs) + if activation is not None: + linear = activation(linear, name=name + "_activation") + return tf.reshape(linear, [-1, *dims]) + + # v_t^i - The vectors to write to memory, for each write head `i`. + write_vectors = _linear([self._num_writes, self._word_size], "write_vectors") + + # e_t^i - Amount to erase the memory by before writing, for each write head. + erase_vectors = _linear( + [self._num_writes, self._word_size], "erase_vectors", tf.sigmoid + ) + + # f_t^j - Amount that the memory at the locations read from at the previous + # time step can be declared unused, for each read head `j`. + free_gate = _linear([self._num_reads], "free_gate", tf.sigmoid) + + # g_t^{a, i} - Interpolation between writing to unallocated memory and + # content-based lookup, for each write head `i`. Note: `a` is simply used to + # identify this gate with allocation vs writing (as defined below). 
+ allocation_gate = _linear([self._num_writes], "allocation_gate", tf.sigmoid) + + # g_t^{w, i} - Overall gating of write amount for each write head. + write_gate = _linear([self._num_writes], "write_gate", tf.sigmoid) + + # \pi_t^j - Mixing between "backwards" and "forwards" positions (for + # each write head), and content-based lookup, for each read head. + num_read_modes = 1 + 2 * self._num_writes + read_mode = snt.BatchApply(tf.nn.softmax)( + _linear([self._num_reads, num_read_modes], name="read_mode") + ) + + # Parameters for the (read / write) "weights by content matching" modules. + write_keys = _linear([self._num_writes, self._word_size], "write_keys") + write_strengths = _linear([self._num_writes], name="write_strengths") + + read_keys = _linear([self._num_reads, self._word_size], "read_keys") + read_strengths = _linear([self._num_reads], name="read_strengths") + + result = { + "read_content_keys": read_keys, + "read_content_strengths": read_strengths, + "write_content_keys": write_keys, + "write_content_strengths": write_strengths, + "write_vectors": write_vectors, + "erase_vectors": erase_vectors, + "free_gate": free_gate, + "allocation_gate": allocation_gate, + "write_gate": write_gate, + "read_mode": read_mode, + } + return result + + def _write_weights(self, inputs, memory, usage): + """Calculates the memory locations to write to. + + This uses a combination of content-based lookup and finding an unused + location in memory, for each write head. + + Args: + inputs: Collection of inputs to the access module, including controls for + how to chose memory writing, such as the content to look-up and the + weighting between content-based and allocation-based addressing. + memory: A tensor of shape `[batch_size, memory_size, word_size]` + containing the current memory contents. + usage: Current memory usage, which is a tensor of shape `[batch_size, + memory_size]`, used for allocation-based addressing. + + Returns: + tensor of shape `[batch_size, num_writes, memory_size]` indicating where + to write to (if anywhere) for each write head. + """ + # c_t^{w, i} - The content-based weights for each write head. + write_content_weights = self._write_content_weights_mod( + memory, inputs["write_content_keys"], inputs["write_content_strengths"] + ) + + # a_t^i - The allocation weights for each write head. + write_allocation_weights = self._freeness.write_allocation_weights( + usage=usage, + write_gates=(inputs["allocation_gate"] * inputs["write_gate"]), + num_writes=self._num_writes, + ) + + # Expands gates over memory locations. + allocation_gate = tf.expand_dims(inputs["allocation_gate"], -1) + write_gate = tf.expand_dims(inputs["write_gate"], -1) + + # w_t^{w, i} - The write weightings for each write head. + return write_gate * ( + allocation_gate * write_allocation_weights + + (1 - allocation_gate) * write_content_weights + ) + + def _read_weights(self, inputs, memory, prev_read_weights, link): + """Calculates read weights for each read head. + + The read weights are a combination of following the link graphs in the + forward or backward directions from the previous read position, and doing + content-based lookup. The interpolation between these different modes is + done by `inputs['read_mode']`. + + Args: + inputs: Controls for this access module. This contains the content-based + keys to lookup, and the weightings for the different read modes. + memory: A tensor of shape `[batch_size, memory_size, word_size]` + containing the current memory contents to do content-based lookup. 
+ prev_read_weights: A tensor of shape `[batch_size, num_reads, + memory_size]` containing the previous read locations. + link: A tensor of shape `[batch_size, num_writes, memory_size, + memory_size]` containing the temporal write transition graphs. + + Returns: + A tensor of shape `[batch_size, num_reads, memory_size]` containing the + read weights for each read head. + """ + # c_t^{r, i} - The content weightings for each read head. + content_weights = self._read_content_weights_mod( + memory, inputs["read_content_keys"], inputs["read_content_strengths"] + ) + + # Calculates f_t^i and b_t^i. + forward_weights = self._linkage.directional_read_weights( + link, prev_read_weights, forward=True + ) + backward_weights = self._linkage.directional_read_weights( + link, prev_read_weights, forward=False + ) + + backward_mode = inputs["read_mode"][:, :, : self._num_writes] + forward_mode = inputs["read_mode"][ + :, :, self._num_writes : 2 * self._num_writes + ] + content_mode = inputs["read_mode"][:, :, 2 * self._num_writes] + + read_weights = ( + tf.expand_dims(content_mode, 2) * content_weights + + tf.reduce_sum( + input_tensor=tf.expand_dims(forward_mode, 3) * forward_weights, axis=2 + ) + + tf.reduce_sum( + input_tensor=tf.expand_dims(backward_mode, 3) * backward_weights, axis=2 + ) + ) + + return read_weights + + # keras uses get_initial_state + def get_initial_state(self, batch_size=None, inputs=None, dtype=None): + return util.initial_state_from_state_size( + self.state_size, batch_size, self._dtype + ) + + # snt.RNNCore uses initial_state + def initial_state(self, batch_size): + return self.get_initial_state(batch_size=batch_size) + + @property + def state_size(self): + """Returns a list of the shape of the state tensors.""" + return [ + # memory + tf.TensorShape([self._memory_size, self._word_size]), + # read_weights + tf.TensorShape([self._num_reads, self._memory_size]), + # write_weights + tf.TensorShape([self._num_writes, self._memory_size]), + # linkage + self._linkage.state_size, + # usage + self._freeness.state_size, + ] + + @property + def output_size(self): + """Returns the output shape.""" + return tf.TensorShape([self._num_reads, self._word_size]) diff --git a/dnc/access_test.py b/dnc/access_test.py deleted file mode 100644 index 20fe7c2..0000000 --- a/dnc/access_test.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2017 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for memory access.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -from tensorflow.python.ops import rnn - -from dnc import access -from dnc import util - -BATCH_SIZE = 2 -MEMORY_SIZE = 20 -WORD_SIZE = 6 -NUM_READS = 2 -NUM_WRITES = 3 -TIME_STEPS = 4 -INPUT_SIZE = 10 - - -class MemoryAccessTest(tf.test.TestCase): - - def setUp(self): - self.module = access.MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, - NUM_WRITES) - self.initial_state = self.module.initial_state(BATCH_SIZE) - - def testBuildAndTrain(self): - inputs = tf.random_normal([TIME_STEPS, BATCH_SIZE, INPUT_SIZE]) - - output, _ = rnn.dynamic_rnn( - cell=self.module, - inputs=inputs, - initial_state=self.initial_state, - time_major=True) - - targets = np.random.rand(TIME_STEPS, BATCH_SIZE, NUM_READS, WORD_SIZE) - loss = tf.reduce_mean(tf.square(output - targets)) - train_op = tf.train.GradientDescentOptimizer(1).minimize(loss) - init = tf.global_variables_initializer() - - with self.test_session(): - init.run() - train_op.run() - - def testValidReadMode(self): - inputs = self.module._read_inputs( - tf.random_normal([BATCH_SIZE, INPUT_SIZE])) - init = tf.global_variables_initializer() - - with self.test_session() as sess: - init.run() - inputs = sess.run(inputs) - - # Check that the read modes for each read head constitute a probability - # distribution. - self.assertAllClose(inputs['read_mode'].sum(2), - np.ones([BATCH_SIZE, NUM_READS])) - self.assertGreaterEqual(inputs['read_mode'].min(), 0) - - def testWriteWeights(self): - memory = 10 * (np.random.rand(BATCH_SIZE, MEMORY_SIZE, WORD_SIZE) - 0.5) - usage = np.random.rand(BATCH_SIZE, MEMORY_SIZE) - - allocation_gate = np.random.rand(BATCH_SIZE, NUM_WRITES) - write_gate = np.random.rand(BATCH_SIZE, NUM_WRITES) - write_content_keys = np.random.rand(BATCH_SIZE, NUM_WRITES, WORD_SIZE) - write_content_strengths = np.random.rand(BATCH_SIZE, NUM_WRITES) - - # Check that turning on allocation gate fully brings the write gate to - # the allocation weighting (which we will control by controlling the usage). - usage[:, 3] = 0 - allocation_gate[:, 0] = 1 - write_gate[:, 0] = 1 - - inputs = { - 'allocation_gate': tf.constant(allocation_gate), - 'write_gate': tf.constant(write_gate), - 'write_content_keys': tf.constant(write_content_keys), - 'write_content_strengths': tf.constant(write_content_strengths) - } - - weights = self.module._write_weights(inputs, - tf.constant(memory), - tf.constant(usage)) - - with self.test_session(): - weights = weights.eval() - - # Check the weights sum to their target gating. - self.assertAllClose(np.sum(weights, axis=2), write_gate, atol=5e-2) - - # Check that we fully allocated to the third row. 
- weights_0_0_target = util.one_hot(MEMORY_SIZE, 3) - self.assertAllClose(weights[0, 0], weights_0_0_target, atol=1e-3) - - def testReadWeights(self): - memory = 10 * (np.random.rand(BATCH_SIZE, MEMORY_SIZE, WORD_SIZE) - 0.5) - prev_read_weights = np.random.rand(BATCH_SIZE, NUM_READS, MEMORY_SIZE) - prev_read_weights /= prev_read_weights.sum(2, keepdims=True) + 1 - - link = np.random.rand(BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE) - # Row and column sums should be at most 1: - link /= np.maximum(link.sum(2, keepdims=True), 1) - link /= np.maximum(link.sum(3, keepdims=True), 1) - - # We query the memory on the third location in memory, and select a large - # strength on the query. Then we select a content-based read-mode. - read_content_keys = np.random.rand(BATCH_SIZE, NUM_READS, WORD_SIZE) - read_content_keys[0, 0] = memory[0, 3] - read_content_strengths = tf.constant( - 100., shape=[BATCH_SIZE, NUM_READS], dtype=tf.float64) - read_mode = np.random.rand(BATCH_SIZE, NUM_READS, 1 + 2 * NUM_WRITES) - read_mode[0, 0, :] = util.one_hot(1 + 2 * NUM_WRITES, 2 * NUM_WRITES) - inputs = { - 'read_content_keys': tf.constant(read_content_keys), - 'read_content_strengths': read_content_strengths, - 'read_mode': tf.constant(read_mode), - } - read_weights = self.module._read_weights(inputs, memory, prev_read_weights, - link) - with self.test_session(): - read_weights = read_weights.eval() - - # read_weights for batch 0, read head 0 should be memory location 3 - self.assertAllClose( - read_weights[0, 0, :], util.one_hot(MEMORY_SIZE, 3), atol=1e-3) - - def testGradients(self): - inputs = tf.constant(np.random.randn(BATCH_SIZE, INPUT_SIZE), tf.float32) - output, _ = self.module(inputs, self.initial_state) - loss = tf.reduce_sum(output) - - tensors_to_check = [ - inputs, self.initial_state.memory, self.initial_state.read_weights, - self.initial_state.linkage.precedence_weights, - self.initial_state.linkage.link - ] - shapes = [x.get_shape().as_list() for x in tensors_to_check] - with self.test_session() as sess: - sess.run(tf.global_variables_initializer()) - err = tf.test.compute_gradient_error(tensors_to_check, shapes, loss, [1]) - self.assertLess(err, 0.1) - - -if __name__ == '__main__': - tf.test.main() diff --git a/dnc/addressing.py b/dnc/addressing.py index 97365b1..6c4832c 100644 --- a/dnc/addressing.py +++ b/dnc/addressing.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function -import collections import sonnet as snt import tensorflow as tf @@ -27,384 +26,391 @@ # Ensure values are greater than epsilon to avoid numerical instability. _EPSILON = 1e-6 -TemporalLinkageState = collections.namedtuple('TemporalLinkageState', - ('link', 'precedence_weights')) +# For indexing directly into TemporalLinkage state +LINK = 0 +PRECEDENCE_WEIGHTS = 1 def _vector_norms(m): - squared_norms = tf.reduce_sum(m * m, axis=2, keepdims=True) - return tf.sqrt(squared_norms + _EPSILON) + squared_norms = tf.compat.v1.reduce_sum(input_tensor=m * m, axis=2, keepdims=True) + return tf.sqrt(squared_norms + _EPSILON) def weighted_softmax(activations, strengths, strengths_op): - """Returns softmax over activations multiplied by positive strengths. - - Args: - activations: A tensor of shape `[batch_size, num_heads, memory_size]`, of - activations to be transformed. Softmax is taken over the last dimension. - strengths: A tensor of shape `[batch_size, num_heads]` containing strengths to - multiply by the activations prior to the softmax. 
- strengths_op: An operation to transform strengths before softmax. - - Returns: - A tensor of same shape as `activations` with weighted softmax applied. - """ - transformed_strengths = tf.expand_dims(strengths_op(strengths), -1) - sharp_activations = activations * transformed_strengths - softmax = snt.BatchApply(module_or_op=tf.nn.softmax) - return softmax(sharp_activations) - - -class CosineWeights(snt.AbstractModule): - """Cosine-weighted attention. - - Calculates the cosine similarity between a query and each word in memory, then - applies a weighted softmax to return a sharp distribution. - """ - - def __init__(self, - num_heads, - word_size, - strength_op=tf.nn.softplus, - name='cosine_weights'): - """Initializes the CosineWeights module. + """Returns softmax over activations multiplied by positive strengths. Args: - num_heads: number of memory heads. - word_size: memory word size. - strength_op: operation to apply to strengths (default is tf.nn.softplus). - name: module name (default 'cosine_weights') - """ - super(CosineWeights, self).__init__(name=name) - self._num_heads = num_heads - self._word_size = word_size - self._strength_op = strength_op - - def _build(self, memory, keys, strengths): - """Connects the CosineWeights module into the graph. - - Args: - memory: A 3-D tensor of shape `[batch_size, memory_size, word_size]`. - keys: A 3-D tensor of shape `[batch_size, num_heads, word_size]`. - strengths: A 2-D tensor of shape `[batch_size, num_heads]`. + activations: A tensor of shape `[batch_size, num_heads, memory_size]`, of + activations to be transformed. Softmax is taken over the last dimension. + strengths: A tensor of shape `[batch_size, num_heads]` containing strengths to + multiply by the activations prior to the softmax. + strengths_op: An operation to transform strengths before softmax. Returns: - Weights tensor of shape `[batch_size, num_heads, memory_size]`. - """ - # Calculates the inner product between the query vector and words in memory. - dot = tf.matmul(keys, memory, adjoint_b=True) - - # Outer product to compute denominator (euclidean norm of query and memory). - memory_norms = _vector_norms(memory) - key_norms = _vector_norms(keys) - norm = tf.matmul(key_norms, memory_norms, adjoint_b=True) - - # Calculates cosine similarity between the query vector and words in memory. - similarity = dot / (norm + _EPSILON) - - return weighted_softmax(similarity, strengths, self._strength_op) - - -class TemporalLinkage(snt.RNNCore): - """Keeps track of write order for forward and backward addressing. - - This is a pseudo-RNNCore module, whose state is a pair `(link, - precedence_weights)`, where `link` is a (collection of) graphs for (possibly - multiple) write heads (represented by a tensor with values in the range - [0, 1]), and `precedence_weights` records the "previous write locations" used - to build the link graphs. - - The function `directional_read_weights` computes addresses following the - forward and backward directions in the link graphs. - """ - - def __init__(self, memory_size, num_writes, name='temporal_linkage'): - """Construct a TemporalLinkage module. - - Args: - memory_size: The number of memory slots. - num_writes: The number of write heads. - name: Name of the module. + A tensor of same shape as `activations` with weighted softmax applied. 
""" - super(TemporalLinkage, self).__init__(name=name) - self._memory_size = memory_size - self._num_writes = num_writes - - def _build(self, write_weights, prev_state): - """Calculate the updated linkage state given the write weights. - - Args: - write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` - containing the memory addresses of the different write heads. - prev_state: `TemporalLinkageState` tuple containg a tensor `link` of - shape `[batch_size, num_writes, memory_size, memory_size]`, and a - tensor `precedence_weights` of shape `[batch_size, num_writes, - memory_size]` containing the aggregated history of recent writes. - - Returns: - A `TemporalLinkageState` tuple `next_state`, which contains the updated - link and precedence weights. - """ - link = self._link(prev_state.link, prev_state.precedence_weights, - write_weights) - precedence_weights = self._precedence_weights(prev_state.precedence_weights, - write_weights) - return TemporalLinkageState( - link=link, precedence_weights=precedence_weights) - - def directional_read_weights(self, link, prev_read_weights, forward): - """Calculates the forward or the backward read weights. - - For each read head (at a given address), there are `num_writes` link graphs - to follow. Thus this function computes a read address for each of the - `num_reads * num_writes` pairs of read and write heads. - - Args: - link: tensor of shape `[batch_size, num_writes, memory_size, - memory_size]` representing the link graphs L_t. - prev_read_weights: tensor of shape `[batch_size, num_reads, - memory_size]` containing the previous read weights w_{t-1}^r. - forward: Boolean indicating whether to follow the "future" direction in - the link graph (True) or the "past" direction (False). - - Returns: - tensor of shape `[batch_size, num_reads, num_writes, memory_size]` - """ - with tf.name_scope('directional_read_weights'): - # We calculate the forward and backward directions for each pair of - # read and write heads; hence we need to tile the read weights and do a - # sort of "outer product" to get this. - expanded_read_weights = tf.stack([prev_read_weights] * self._num_writes, - 1) - result = tf.matmul(expanded_read_weights, link, adjoint_b=forward) - # Swap dimensions 1, 2 so order is [batch, reads, writes, memory]: - return tf.transpose(result, perm=[0, 2, 1, 3]) - - def _link(self, prev_link, prev_precedence_weights, write_weights): - """Calculates the new link graphs. - - For each write head, the link is a directed graph (represented by a matrix - with entries in range [0, 1]) whose vertices are the memory locations, and - an edge indicates temporal ordering of writes. - - Args: - prev_link: A tensor of shape `[batch_size, num_writes, memory_size, - memory_size]` representing the previous link graphs for each write - head. - prev_precedence_weights: A tensor of shape `[batch_size, num_writes, - memory_size]` which is the previous "aggregated" write weights for - each write head. - write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` - containing the new locations in memory written to. + transformed_strengths = tf.expand_dims(strengths_op(strengths), -1) + sharp_activations = activations * transformed_strengths + softmax = snt.BatchApply(module=tf.nn.softmax) + return softmax(sharp_activations) - Returns: - A tensor of shape `[batch_size, num_writes, memory_size, memory_size]` - containing the new link graphs for each write head. 
- """ - with tf.name_scope('link'): - batch_size = tf.shape(prev_link)[0] - write_weights_i = tf.expand_dims(write_weights, 3) - write_weights_j = tf.expand_dims(write_weights, 2) - prev_precedence_weights_j = tf.expand_dims(prev_precedence_weights, 2) - prev_link_scale = 1 - write_weights_i - write_weights_j - new_link = write_weights_i * prev_precedence_weights_j - link = prev_link_scale * prev_link + new_link - # Return the link with the diagonal set to zero, to remove self-looping - # edges. - return tf.matrix_set_diag( - link, - tf.zeros( - [batch_size, self._num_writes, self._memory_size], - dtype=link.dtype)) - - def _precedence_weights(self, prev_precedence_weights, write_weights): - """Calculates the new precedence weights given the current write weights. - - The precedence weights are the "aggregated write weights" for each write - head, where write weights with sum close to zero will leave the precedence - weights unchanged, but with sum close to one will replace the precedence - weights. - Args: - prev_precedence_weights: A tensor of shape `[batch_size, num_writes, - memory_size]` containing the previous precedence weights. - write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` - containing the new write weights. +class CosineWeights(snt.Module): + """Cosine-weighted attention. - Returns: - A tensor of shape `[batch_size, num_writes, memory_size]` containing the - new precedence weights. + Calculates the cosine similarity between a query and each word in memory, then + applies a weighted softmax to return a sharp distribution. """ - with tf.name_scope('precedence_weights'): - write_sum = tf.reduce_sum(write_weights, 2, keepdims=True) - return (1 - write_sum) * prev_precedence_weights + write_weights - @property - def state_size(self): - """Returns a `TemporalLinkageState` tuple of the state tensors' shapes.""" - return TemporalLinkageState( - link=tf.TensorShape( - [self._num_writes, self._memory_size, self._memory_size]), - precedence_weights=tf.TensorShape([self._num_writes, - self._memory_size]),) + def __init__( + self, num_heads, word_size, strength_op=tf.nn.softplus, name="cosine_weights" + ): + """Initializes the CosineWeights module. + Args: + num_heads: number of memory heads. + word_size: memory word size. + strength_op: operation to apply to strengths (default is tf.nn.softplus). + name: module name (default 'cosine_weights') + """ + super(CosineWeights, self).__init__(name=name) + self._num_heads = num_heads + self._word_size = word_size + self._strength_op = strength_op -class Freeness(snt.RNNCore): - """Memory usage that is increased by writing and decreased by reading. - - This module is a pseudo-RNNCore whose state is a tensor with values in - the range [0, 1] indicating the usage of each of `memory_size` memory slots. + def __call__(self, memory, keys, strengths): + """Connects the CosineWeights module into the graph. - The usage is: + Args: + memory: A 3-D tensor of shape `[batch_size, memory_size, word_size]`. + keys: A 3-D tensor of shape `[batch_size, num_heads, word_size]`. + strengths: A 2-D tensor of shape `[batch_size, num_heads]`. - * Increased by writing, where usage is increased towards 1 at the write - addresses. - * Decreased by reading, where usage is decreased after reading from a - location when free_gate is close to 1. + Returns: + Weights tensor of shape `[batch_size, num_heads, memory_size]`. + """ + # Calculates the inner product between the query vector and words in memory. 
+ dot = tf.matmul(keys, memory, adjoint_b=True) - The function `write_allocation_weights` can be invoked to get free locations - to write to for a number of write heads. - """ + # Outer product to compute denominator (euclidean norm of query and memory). + memory_norms = _vector_norms(memory) + key_norms = _vector_norms(keys) + norm = tf.matmul(key_norms, memory_norms, adjoint_b=True) - def __init__(self, memory_size, name='freeness'): - """Creates a Freeness module. + # Calculates cosine similarity between the query vector and words in memory. + similarity = dot / (norm + _EPSILON) - Args: - memory_size: Number of memory slots. - name: Name of the module. - """ - super(Freeness, self).__init__(name=name) - self._memory_size = memory_size + return weighted_softmax(similarity, strengths, self._strength_op) - def _build(self, write_weights, free_gate, read_weights, prev_usage): - """Calculates the new memory usage u_t. - Memory that was written to in the previous time step will have its usage - increased; memory that was read from and the controller says can be "freed" - will have its usage decreased. +class TemporalLinkage(snt.RNNCore): + """Keeps track of write order for forward and backward addressing. - Args: - write_weights: tensor of shape `[batch_size, num_writes, - memory_size]` giving write weights at previous time step. - free_gate: tensor of shape `[batch_size, num_reads]` which indicates - which read heads read memory that can now be freed. - read_weights: tensor of shape `[batch_size, num_reads, - memory_size]` giving read weights at previous time step. - prev_usage: tensor of shape `[batch_size, memory_size]` giving - usage u_{t - 1} at the previous time step, with entries in range - [0, 1]. + This is a pseudo-RNNCore module, whose state is a pair `(link, + precedence_weights)`, where `link` is a (collection of) graphs for (possibly + multiple) write heads (represented by a tensor with values in the range + [0, 1]), and `precedence_weights` records the "previous write locations" used + to build the link graphs. - Returns: - tensor of shape `[batch_size, memory_size]` representing updated memory - usage. + The function `directional_read_weights` computes addresses following the + forward and backward directions in the link graphs. """ - # Calculation of usage is not differentiable with respect to write weights. - write_weights = tf.stop_gradient(write_weights) - usage = self._usage_after_write(prev_usage, write_weights) - usage = self._usage_after_read(usage, free_gate, read_weights) - return usage - - def write_allocation_weights(self, usage, write_gates, num_writes): - """Calculates freeness-based locations for writing to. - This finds unused memory by ranking the memory locations by usage, for each - write head. (For more than one write head, we use a "simulated new usage" - which takes into account the fact that the previous write head will increase - the usage in that area of the memory.) + def __init__( + self, memory_size, num_writes, name="temporal_linkage", dtype=tf.float32 + ): + """Construct a TemporalLinkage module. + + Args: + memory_size: The number of memory slots. + num_writes: The number of write heads. + name: Name of the module. + """ + super(TemporalLinkage, self).__init__(name=name) + self._memory_size = memory_size + self._num_writes = num_writes + self._dtype = dtype + + def __call__(self, write_weights, prev_state): + """Calculate the updated linkage state given the write weights. 
+ + Args: + write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` + containing the memory addresses of the different write heads. + prev_state: list of tensors containg a tensor `link` of + shape `[batch_size, num_writes, memory_size, memory_size]`, and a + tensor `precedence_weights` of shape `[batch_size, num_writes, + memory_size]` containing the aggregated history of recent writes. + + Returns: + A list of tensors `next_state`, which contains the updated + link and precedence weights. + """ + prev_link, prev_precedence_weights = prev_state + + return [ + self._link(prev_link, prev_precedence_weights, write_weights), + self._precedence_weights(prev_precedence_weights, write_weights), + ] + + def directional_read_weights(self, link, prev_read_weights, forward): + """Calculates the forward or the backward read weights. + + For each read head (at a given address), there are `num_writes` link graphs + to follow. Thus this function computes a read address for each of the + `num_reads * num_writes` pairs of read and write heads. + + Args: + link: tensor of shape `[batch_size, num_writes, memory_size, + memory_size]` representing the link graphs L_t. + prev_read_weights: tensor of shape `[batch_size, num_reads, + memory_size]` containing the previous read weights w_{t-1}^r. + forward: Boolean indicating whether to follow the "future" direction in + the link graph (True) or the "past" direction (False). + + Returns: + tensor of shape `[batch_size, num_reads, num_writes, memory_size]` + """ + # We calculate the forward and backward directions for each pair of + # read and write heads; hence we need to tile the read weights and do a + # sort of "outer product" to get this. + expanded_read_weights = tf.stack([prev_read_weights] * self._num_writes, 1) + result = tf.matmul(expanded_read_weights, link, adjoint_b=forward) + # Swap dimensions 1, 2 so order is [batch, reads, writes, memory]: + return tf.transpose(a=result, perm=[0, 2, 1, 3]) + + def _link(self, prev_link, prev_precedence_weights, write_weights): + """Calculates the new link graphs. + + For each write head, the link is a directed graph (represented by a matrix + with entries in range [0, 1]) whose vertices are the memory locations, and + an edge indicates temporal ordering of writes. + + Args: + prev_link: A tensor of shape `[batch_size, num_writes, memory_size, + memory_size]` representing the previous link graphs for each write + head. + prev_precedence_weights: A tensor of shape `[batch_size, num_writes, + memory_size]` which is the previous "aggregated" write weights for + each write head. + write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` + containing the new locations in memory written to. + + Returns: + A tensor of shape `[batch_size, num_writes, memory_size, memory_size]` + containing the new link graphs for each write head. + """ + batch_size = tf.shape(input=prev_link)[0] + write_weights_i = tf.expand_dims(write_weights, 3) + write_weights_j = tf.expand_dims(write_weights, 2) + prev_precedence_weights_j = tf.expand_dims(prev_precedence_weights, 2) + prev_link_scale = 1 - write_weights_i - write_weights_j + new_link = write_weights_i * prev_precedence_weights_j + link = prev_link_scale * prev_link + new_link + # Return the link with the diagonal set to zero, to remove self-looping + # edges. 
+ return tf.linalg.set_diag( + link, + tf.zeros( + [batch_size, self._num_writes, self._memory_size], dtype=link.dtype + ), + ) + + def _precedence_weights(self, prev_precedence_weights, write_weights): + """Calculates the new precedence weights given the current write weights. + + The precedence weights are the "aggregated write weights" for each write + head, where write weights with sum close to zero will leave the precedence + weights unchanged, but with sum close to one will replace the precedence + weights. + + Args: + prev_precedence_weights: A tensor of shape `[batch_size, num_writes, + memory_size]` containing the previous precedence weights. + write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` + containing the new write weights. + + Returns: + A tensor of shape `[batch_size, num_writes, memory_size]` containing the + new precedence weights. + """ + write_sum = tf.reduce_sum(input_tensor=write_weights, axis=2, keepdims=True) + return (1 - write_sum) * prev_precedence_weights + write_weights + + def initial_state(self, batch_size): + return util.initial_state_from_state_size( + self.state_size, batch_size, self._dtype + ) + + @property + def state_size(self): + """Returns a list of the state tensors' shapes.""" + return [ + # link + tf.TensorShape([self._num_writes, self._memory_size, self._memory_size]), + # precedence_weights + tf.TensorShape([self._num_writes, self._memory_size]), + ] - Args: - usage: A tensor of shape `[batch_size, memory_size]` representing - current memory usage. - write_gates: A tensor of shape `[batch_size, num_writes]` with values in - the range [0, 1] indicating how much each write head does writing - based on the address returned here (and hence how much usage - increases). - num_writes: The number of write heads to calculate write weights for. - - Returns: - tensor of shape `[batch_size, num_writes, memory_size]` containing the - freeness-based write locations. Note that this isn't scaled by - `write_gate`; this scaling must be applied externally. - """ - with tf.name_scope('write_allocation_weights'): - # expand gatings over memory locations - write_gates = tf.expand_dims(write_gates, -1) - allocation_weights = [] - for i in range(num_writes): - allocation_weights.append(self._allocation(usage)) - # update usage to take into account writing to this new allocation - usage += ((1 - usage) * write_gates[:, i, :] * allocation_weights[i]) +class Freeness(snt.RNNCore): + """Memory usage that is increased by writing and decreased by reading. - # Pack the allocation weights for the write heads into one tensor. - return tf.stack(allocation_weights, axis=1) + This module is a pseudo-RNNCore whose state is a tensor with values in + the range [0, 1] indicating the usage of each of `memory_size` memory slots. - def _usage_after_write(self, prev_usage, write_weights): - """Calcualtes the new usage after writing to memory. + The usage is: - Args: - prev_usage: tensor of shape `[batch_size, memory_size]`. - write_weights: tensor of shape `[batch_size, num_writes, memory_size]`. + * Increased by writing, where usage is increased towards 1 at the write + addresses. + * Decreased by reading, where usage is decreased after reading from a + location when free_gate is close to 1. - Returns: - New usage, a tensor of shape `[batch_size, memory_size]`. + The function `write_allocation_weights` can be invoked to get free locations + to write to for a number of write heads. 
""" - with tf.name_scope('usage_after_write'): - # Calculate the aggregated effect of all write heads - write_weights = 1 - util.reduce_prod(1 - write_weights, 1) - return prev_usage + (1 - prev_usage) * write_weights - - def _usage_after_read(self, prev_usage, free_gate, read_weights): - """Calcualtes the new usage after reading and freeing from memory. - Args: - prev_usage: tensor of shape `[batch_size, memory_size]`. - free_gate: tensor of shape `[batch_size, num_reads]` with entries in the - range [0, 1] indicating the amount that locations read from can be - freed. - read_weights: tensor of shape `[batch_size, num_reads, memory_size]`. - - Returns: - New usage, a tensor of shape `[batch_size, memory_size]`. - """ - with tf.name_scope('usage_after_read'): - free_gate = tf.expand_dims(free_gate, -1) - free_read_weights = free_gate * read_weights - phi = util.reduce_prod(1 - free_read_weights, 1, name='phi') - return prev_usage * phi - - def _allocation(self, usage): - r"""Computes allocation by sorting `usage`. - - This corresponds to the value a = a_t[\phi_t[j]] in the paper. - - Args: - usage: tensor of shape `[batch_size, memory_size]` indicating current - memory usage. This is equal to u_t in the paper when we only have one - write head, but for multiple write heads, one should update the usage - while iterating through the write heads to take into account the - allocation returned by this function. - - Returns: - Tensor of shape `[batch_size, memory_size]` corresponding to allocation. - """ - with tf.name_scope('allocation'): - # Ensure values are not too small prior to cumprod. - usage = _EPSILON + (1 - _EPSILON) * usage - - nonusage = 1 - usage - sorted_nonusage, indices = tf.nn.top_k( - nonusage, k=self._memory_size, name='sort') - sorted_usage = 1 - sorted_nonusage - prod_sorted_usage = tf.cumprod(sorted_usage, axis=1, exclusive=True) - sorted_allocation = sorted_nonusage * prod_sorted_usage - inverse_indices = util.batch_invert_permutation(indices) - - # This final line "unsorts" sorted_allocation, so that the indexing - # corresponds to the original indexing of `usage`. - return util.batch_gather(sorted_allocation, inverse_indices) - - @property - def state_size(self): - """Returns the shape of the state tensor.""" - return tf.TensorShape([self._memory_size]) + def __init__(self, memory_size, name="freeness", dtype=tf.float32): + """Creates a Freeness module. + + Args: + memory_size: Number of memory slots. + name: Name of the module. + """ + super(Freeness, self).__init__(name=name) + self._memory_size = memory_size + self._dtype = dtype + + def __call__(self, write_weights, free_gate, read_weights, prev_usage): + """Calculates the new memory usage u_t. + + Memory that was written to in the previous time step will have its usage + increased; memory that was read from and the controller says can be "freed" + will have its usage decreased. + + Args: + write_weights: tensor of shape `[batch_size, num_writes, + memory_size]` giving write weights at previous time step. + free_gate: tensor of shape `[batch_size, num_reads]` which indicates + which read heads read memory that can now be freed. + read_weights: tensor of shape `[batch_size, num_reads, + memory_size]` giving read weights at previous time step. + prev_usage: tensor of shape `[batch_size, memory_size]` giving + usage u_{t - 1} at the previous time step, with entries in range + [0, 1]. + + Returns: + tensor of shape `[batch_size, memory_size]` representing updated memory + usage. 
+ """ + # Calculation of usage is not differentiable with respect to write weights. + write_weights = tf.stop_gradient(write_weights) + usage = self._usage_after_write(prev_usage, write_weights) + usage = self._usage_after_read(usage, free_gate, read_weights) + return usage + + def write_allocation_weights(self, usage, write_gates, num_writes): + """Calculates freeness-based locations for writing to. + + This finds unused memory by ranking the memory locations by usage, for each + write head. (For more than one write head, we use a "simulated new usage" + which takes into account the fact that the previous write head will increase + the usage in that area of the memory.) + + Args: + usage: A tensor of shape `[batch_size, memory_size]` representing + current memory usage. + write_gates: A tensor of shape `[batch_size, num_writes]` with values in + the range [0, 1] indicating how much each write head does writing + based on the address returned here (and hence how much usage + increases). + num_writes: The number of write heads to calculate write weights for. + + Returns: + tensor of shape `[batch_size, num_writes, memory_size]` containing the + freeness-based write locations. Note that this isn't scaled by + `write_gate`; this scaling must be applied externally. + """ + # expand gatings over memory locations + write_gates = tf.expand_dims(write_gates, -1) + + allocation_weights = [] + for i in range(num_writes): + allocation_weights.append(self._allocation(usage)) + # update usage to take into account writing to this new allocation + usage += (1 - usage) * write_gates[:, i, :] * allocation_weights[i] + + # Pack the allocation weights for the write heads into one tensor. + return tf.stack(allocation_weights, axis=1) + + def _usage_after_write(self, prev_usage, write_weights): + """Calculates the new usage after writing to memory. + + Args: + prev_usage: tensor of shape `[batch_size, memory_size]`. + write_weights: tensor of shape `[batch_size, num_writes, memory_size]`. + + Returns: + New usage, a tensor of shape `[batch_size, memory_size]`. + """ + # Calculate the aggregated effect of all write heads + write_weights = 1 - util.reduce_prod(1 - write_weights, 1) + return prev_usage + (1 - prev_usage) * write_weights + + def _usage_after_read(self, prev_usage, free_gate, read_weights): + """Calculates the new usage after reading and freeing from memory. + + Args: + prev_usage: tensor of shape `[batch_size, memory_size]`. + free_gate: tensor of shape `[batch_size, num_reads]` with entries in the + range [0, 1] indicating the amount that locations read from can be + freed. + read_weights: tensor of shape `[batch_size, num_reads, memory_size]`. + + Returns: + New usage, a tensor of shape `[batch_size, memory_size]`. + """ + free_gate = tf.expand_dims(free_gate, -1) + free_read_weights = free_gate * read_weights + phi = util.reduce_prod(1 - free_read_weights, 1, name="phi") + return prev_usage * phi + + def _allocation(self, usage): + r"""Computes allocation by sorting `usage`. + + This corresponds to the value a = a_t[\phi_t[j]] in the paper. + + Args: + usage: tensor of shape `[batch_size, memory_size]` indicating current + memory usage. This is equal to u_t in the paper when we only have one + write head, but for multiple write heads, one should update the usage + while iterating through the write heads to take into account the + allocation returned by this function. + + Returns: + Tensor of shape `[batch_size, memory_size]` corresponding to allocation. 
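+
+            As a sketch, writing `phi` for the memory indices sorted by
+            ascending usage, the quantity computed below is
+
+                a[phi[j]] = (1 - u[phi[j]]) * prod_{i < j} u[phi[i]]
+
+            so the least-used location receives the largest allocation weight.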
+ """ + # Ensure values are not too small prior to cumprod. + usage = _EPSILON + (1 - _EPSILON) * usage + + nonusage = 1 - usage + sorted_nonusage, indices = tf.nn.top_k( + nonusage, k=self._memory_size, name="sort" + ) + sorted_usage = 1 - sorted_nonusage + prod_sorted_usage = tf.math.cumprod(sorted_usage, axis=1, exclusive=True) + sorted_allocation = sorted_nonusage * prod_sorted_usage + inverse_indices = tf.cast(util.batch_invert_permutation(indices), tf.int32) + + # This final line "unsorts" sorted_allocation, so that the indexing + # corresponds to the original indexing of `usage`. + return util.batch_gather(sorted_allocation, inverse_indices) + + # freeness size is independent of batch size + def initial_state(self, batch_size): + return tf.zeros([self._memory_size], dtype=self._dtype) + + @property + def state_size(self): + """Returns the shape of the state tensor.""" + return tf.TensorShape([self._memory_size]) diff --git a/dnc/addressing_test.py b/dnc/addressing_test.py deleted file mode 100644 index a8a8ac4..0000000 --- a/dnc/addressing_test.py +++ /dev/null @@ -1,420 +0,0 @@ -# Copyright 2017 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for memory addressing.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import sonnet as snt -import tensorflow as tf - -from dnc import addressing -from dnc import util - - -class WeightedSoftmaxTest(tf.test.TestCase): - - def testValues(self): - batch_size = 5 - num_heads = 3 - memory_size = 7 - - activations_data = np.random.randn(batch_size, num_heads, memory_size) - weights_data = np.ones((batch_size, num_heads)) - - activations = tf.placeholder(tf.float32, - [batch_size, num_heads, memory_size]) - weights = tf.placeholder(tf.float32, [batch_size, num_heads]) - # Run weighted softmax with identity placed on weights. Output should be - # equal to a standalone softmax. 
- observed = addressing.weighted_softmax(activations, weights, tf.identity) - expected = snt.BatchApply( - module_or_op=tf.nn.softmax, name='BatchSoftmax')(activations) - with self.test_session() as sess: - observed = sess.run( - observed, - feed_dict={activations: activations_data, - weights: weights_data}) - expected = sess.run(expected, feed_dict={activations: activations_data}) - self.assertAllClose(observed, expected) - - -class CosineWeightsTest(tf.test.TestCase): - - def testShape(self): - batch_size = 5 - num_heads = 3 - memory_size = 7 - word_size = 2 - - module = addressing.CosineWeights(num_heads, word_size) - mem = tf.placeholder(tf.float32, [batch_size, memory_size, word_size]) - keys = tf.placeholder(tf.float32, [batch_size, num_heads, word_size]) - strengths = tf.placeholder(tf.float32, [batch_size, num_heads]) - weights = module(mem, keys, strengths) - self.assertTrue(weights.get_shape().is_compatible_with( - [batch_size, num_heads, memory_size])) - - def testValues(self): - batch_size = 5 - num_heads = 4 - memory_size = 10 - word_size = 2 - - mem_data = np.random.randn(batch_size, memory_size, word_size) - np.copyto(mem_data[0, 0], [1, 2]) - np.copyto(mem_data[0, 1], [3, 4]) - np.copyto(mem_data[0, 2], [5, 6]) - - keys_data = np.random.randn(batch_size, num_heads, word_size) - np.copyto(keys_data[0, 0], [5, 6]) - np.copyto(keys_data[0, 1], [1, 2]) - np.copyto(keys_data[0, 2], [5, 6]) - np.copyto(keys_data[0, 3], [3, 4]) - strengths_data = np.random.randn(batch_size, num_heads) - - module = addressing.CosineWeights(num_heads, word_size) - mem = tf.placeholder(tf.float32, [batch_size, memory_size, word_size]) - keys = tf.placeholder(tf.float32, [batch_size, num_heads, word_size]) - strengths = tf.placeholder(tf.float32, [batch_size, num_heads]) - weights = module(mem, keys, strengths) - - with self.test_session() as sess: - result = sess.run( - weights, - feed_dict={mem: mem_data, - keys: keys_data, - strengths: strengths_data}) - - # Manually checks results. - strengths_softplus = np.log(1 + np.exp(strengths_data)) - similarity = np.zeros((memory_size)) - - for b in xrange(batch_size): - for h in xrange(num_heads): - key = keys_data[b, h] - key_norm = np.linalg.norm(key) - - for m in xrange(memory_size): - row = mem_data[b, m] - similarity[m] = np.dot(key, row) / (key_norm * np.linalg.norm(row)) - - similarity = np.exp(similarity * strengths_softplus[b, h]) - similarity /= similarity.sum() - self.assertAllClose(result[b, h], similarity, atol=1e-4, rtol=1e-4) - - def testDivideByZero(self): - batch_size = 5 - num_heads = 4 - memory_size = 10 - word_size = 2 - - module = addressing.CosineWeights(num_heads, word_size) - keys = tf.random_normal([batch_size, num_heads, word_size]) - strengths = tf.random_normal([batch_size, num_heads]) - - # First row of memory is non-zero to concentrate attention on this location. - # Remaining rows are all zero. 
- first_row_ones = tf.ones([batch_size, 1, word_size], dtype=tf.float32) - remaining_zeros = tf.zeros( - [batch_size, memory_size - 1, word_size], dtype=tf.float32) - mem = tf.concat((first_row_ones, remaining_zeros), 1) - - output = module(mem, keys, strengths) - gradients = tf.gradients(output, [mem, keys, strengths]) - - with self.test_session() as sess: - output, gradients = sess.run([output, gradients]) - self.assertFalse(np.any(np.isnan(output))) - self.assertFalse(np.any(np.isnan(gradients[0]))) - self.assertFalse(np.any(np.isnan(gradients[1]))) - self.assertFalse(np.any(np.isnan(gradients[2]))) - - -class TemporalLinkageTest(tf.test.TestCase): - - def testModule(self): - batch_size = 7 - memory_size = 4 - num_reads = 11 - num_writes = 5 - module = addressing.TemporalLinkage( - memory_size=memory_size, num_writes=num_writes) - - prev_link_in = tf.placeholder( - tf.float32, (batch_size, num_writes, memory_size, memory_size)) - prev_precedence_weights_in = tf.placeholder( - tf.float32, (batch_size, num_writes, memory_size)) - write_weights_in = tf.placeholder(tf.float32, - (batch_size, num_writes, memory_size)) - - state = addressing.TemporalLinkageState( - link=np.zeros([batch_size, num_writes, memory_size, memory_size]), - precedence_weights=np.zeros([batch_size, num_writes, memory_size])) - - calc_state = module(write_weights_in, - addressing.TemporalLinkageState( - link=prev_link_in, - precedence_weights=prev_precedence_weights_in)) - - with self.test_session() as sess: - num_steps = 5 - for i in xrange(num_steps): - write_weights = np.random.rand(batch_size, num_writes, memory_size) - write_weights /= write_weights.sum(2, keepdims=True) + 1 - - # Simulate (in final steps) link 0-->1 in head 0 and 3-->2 in head 1 - if i == num_steps - 2: - write_weights[0, 0, :] = util.one_hot(memory_size, 0) - write_weights[0, 1, :] = util.one_hot(memory_size, 3) - elif i == num_steps - 1: - write_weights[0, 0, :] = util.one_hot(memory_size, 1) - write_weights[0, 1, :] = util.one_hot(memory_size, 2) - - state = sess.run( - calc_state, - feed_dict={ - prev_link_in: state.link, - prev_precedence_weights_in: state.precedence_weights, - write_weights_in: write_weights - }) - - # link should be bounded in range [0, 1] - self.assertGreaterEqual(state.link.min(), 0) - self.assertLessEqual(state.link.max(), 1) - - # link diagonal should be zero - self.assertAllEqual( - state.link[:, :, range(memory_size), range(memory_size)], - np.zeros([batch_size, num_writes, memory_size])) - - # link rows and columns should sum to at most 1 - self.assertLessEqual(state.link.sum(2).max(), 1) - self.assertLessEqual(state.link.sum(3).max(), 1) - - # records our transitions in batch 0: head 0: 0->1, and head 1: 3->2 - self.assertAllEqual(state.link[0, 0, :, 0], util.one_hot(memory_size, 1)) - self.assertAllEqual(state.link[0, 1, :, 3], util.one_hot(memory_size, 2)) - - # Now test calculation of forward and backward read weights - prev_read_weights = np.random.rand(batch_size, num_reads, memory_size) - prev_read_weights[0, 5, :] = util.one_hot(memory_size, 0) # read 5, posn 0 - prev_read_weights[0, 6, :] = util.one_hot(memory_size, 2) # read 6, posn 2 - forward_read_weights = module.directional_read_weights( - tf.constant(state.link), - tf.constant(prev_read_weights, dtype=tf.float32), - forward=True) - backward_read_weights = module.directional_read_weights( - tf.constant(state.link), - tf.constant(prev_read_weights, dtype=tf.float32), - forward=False) - - with self.test_session(): - forward_read_weights = 
forward_read_weights.eval() - backward_read_weights = backward_read_weights.eval() - - # Check directional weights calculated correctly. - self.assertAllEqual( - forward_read_weights[0, 5, 0, :], # read=5, write=0 - util.one_hot(memory_size, 1)) - self.assertAllEqual( - backward_read_weights[0, 6, 1, :], # read=6, write=1 - util.one_hot(memory_size, 3)) - - def testPrecedenceWeights(self): - batch_size = 7 - memory_size = 3 - num_writes = 5 - module = addressing.TemporalLinkage( - memory_size=memory_size, num_writes=num_writes) - - prev_precedence_weights = np.random.rand(batch_size, num_writes, - memory_size) - write_weights = np.random.rand(batch_size, num_writes, memory_size) - - # These should sum to at most 1 for each write head in each batch. - write_weights /= write_weights.sum(2, keepdims=True) + 1 - prev_precedence_weights /= prev_precedence_weights.sum(2, keepdims=True) + 1 - - write_weights[0, 1, :] = 0 # batch 0 head 1: no writing - write_weights[1, 2, :] /= write_weights[1, 2, :].sum() # b1 h2: all writing - - precedence_weights = module._precedence_weights( - prev_precedence_weights=tf.constant(prev_precedence_weights), - write_weights=tf.constant(write_weights)) - - with self.test_session(): - precedence_weights = precedence_weights.eval() - - # precedence weights should be bounded in range [0, 1] - self.assertGreaterEqual(precedence_weights.min(), 0) - self.assertLessEqual(precedence_weights.max(), 1) - - # no writing in batch 0, head 1 - self.assertAllClose(precedence_weights[0, 1, :], - prev_precedence_weights[0, 1, :]) - - # all writing in batch 1, head 2 - self.assertAllClose(precedence_weights[1, 2, :], write_weights[1, 2, :]) - - -class FreenessTest(tf.test.TestCase): - - def testModule(self): - batch_size = 5 - memory_size = 11 - num_reads = 3 - num_writes = 7 - module = addressing.Freeness(memory_size) - - free_gate = np.random.rand(batch_size, num_reads) - - # Produce read weights that sum to 1 for each batch and head. - prev_read_weights = np.random.rand(batch_size, num_reads, memory_size) - prev_read_weights[1, :, 3] = 0 # no read at batch 1, position 3; see below - prev_read_weights /= prev_read_weights.sum(2, keepdims=True) - prev_write_weights = np.random.rand(batch_size, num_writes, memory_size) - prev_write_weights /= prev_write_weights.sum(2, keepdims=True) - prev_usage = np.random.rand(batch_size, memory_size) - - # Add some special values that allows us to test the behaviour: - prev_write_weights[1, 2, 3] = 1 # full write in batch 1, head 2, position 3 - prev_read_weights[2, 0, 4] = 1 # full read at batch 2, head 0, position 4 - free_gate[2, 0] = 1 # can free up all locations for batch 2, read head 0 - - usage = module( - tf.constant(prev_write_weights), - tf.constant(free_gate), - tf.constant(prev_read_weights), tf.constant(prev_usage)) - with self.test_session(): - usage = usage.eval() - - # Check all usages are between 0 and 1. - self.assertGreaterEqual(usage.min(), 0) - self.assertLessEqual(usage.max(), 1) - - # Check that the full write at batch 1, position 3 makes it fully used. - self.assertEqual(usage[1][3], 1) - - # Check that the full free at batch 2, position 4 makes it fully free. - self.assertEqual(usage[2][4], 0) - - def testWriteAllocationWeights(self): - batch_size = 7 - memory_size = 23 - num_writes = 5 - module = addressing.Freeness(memory_size) - - usage = np.random.rand(batch_size, memory_size) - write_gates = np.random.rand(batch_size, num_writes) - - # Turn off gates for heads 1 and 3 in batch 0. 
This doesn't scaling down the - # weighting, but it means that the usage doesn't change, so we should get - # the same allocation weightings for: (1, 2) and (3, 4) (but all others - # being different). - write_gates[0, 1] = 0 - write_gates[0, 3] = 0 - # and turn heads 0 and 2 on for full effect. - write_gates[0, 0] = 1 - write_gates[0, 2] = 1 - - # In batch 1, make one of the usages 0 and another almost 0, so that these - # entries get most of the allocation weights for the first and second heads. - usage[1] = usage[1] * 0.9 + 0.1 # make sure all entries are in [0.1, 1] - usage[1][4] = 0 # write head 0 should get allocated to position 4 - usage[1][3] = 1e-4 # write head 1 should get allocated to position 3 - write_gates[1, 0] = 1 # write head 0 fully on - write_gates[1, 1] = 1 # write head 1 fully on - - weights = module.write_allocation_weights( - usage=tf.constant(usage), - write_gates=tf.constant(write_gates), - num_writes=num_writes) - - with self.test_session(): - weights = weights.eval() - - # Check that all weights are between 0 and 1 - self.assertGreaterEqual(weights.min(), 0) - self.assertLessEqual(weights.max(), 1) - - # Check that weights sum to close to 1 - self.assertAllClose( - np.sum(weights, axis=2), np.ones([batch_size, num_writes]), atol=1e-3) - - # Check the same / different allocation weight pairs as described above. - self.assertGreater(np.abs(weights[0, 0, :] - weights[0, 1, :]).max(), 0.1) - self.assertAllEqual(weights[0, 1, :], weights[0, 2, :]) - self.assertGreater(np.abs(weights[0, 2, :] - weights[0, 3, :]).max(), 0.1) - self.assertAllEqual(weights[0, 3, :], weights[0, 4, :]) - - self.assertAllClose(weights[1][0], util.one_hot(memory_size, 4), atol=1e-3) - self.assertAllClose(weights[1][1], util.one_hot(memory_size, 3), atol=1e-3) - - def testWriteAllocationWeightsGradient(self): - batch_size = 7 - memory_size = 5 - num_writes = 3 - module = addressing.Freeness(memory_size) - - usage = tf.constant(np.random.rand(batch_size, memory_size)) - write_gates = tf.constant(np.random.rand(batch_size, num_writes)) - weights = module.write_allocation_weights(usage, write_gates, num_writes) - - with self.test_session(): - err = tf.test.compute_gradient_error( - [usage, write_gates], - [usage.get_shape().as_list(), write_gates.get_shape().as_list()], - weights, - weights.get_shape().as_list(), - delta=1e-5) - self.assertLess(err, 0.01) - - def testAllocation(self): - batch_size = 7 - memory_size = 13 - usage = np.random.rand(batch_size, memory_size) - module = addressing.Freeness(memory_size) - allocation = module._allocation(tf.constant(usage)) - with self.test_session(): - allocation = allocation.eval() - - # 1. Test that max allocation goes to min usage, and vice versa. - self.assertAllEqual(np.argmin(usage, axis=1), np.argmax(allocation, axis=1)) - self.assertAllEqual(np.argmax(usage, axis=1), np.argmin(allocation, axis=1)) - - # 2. Test that allocations sum to almost 1. 
- self.assertAllClose(np.sum(allocation, axis=1), np.ones(batch_size), 0.01) - - def testAllocationGradient(self): - batch_size = 1 - memory_size = 5 - usage = tf.constant(np.random.rand(batch_size, memory_size)) - module = addressing.Freeness(memory_size) - allocation = module._allocation(usage) - with self.test_session(): - err = tf.test.compute_gradient_error( - usage, - usage.get_shape().as_list(), - allocation, - allocation.get_shape().as_list(), - delta=1e-5) - self.assertLess(err, 0.01) - - -if __name__ == '__main__': - tf.test.main() diff --git a/dnc/dnc.py b/dnc/dnc.py index db14b2a..4a35675 100644 --- a/dnc/dnc.py +++ b/dnc/dnc.py @@ -22,121 +22,141 @@ from __future__ import division from __future__ import print_function -import collections -import numpy as np import sonnet as snt import tensorflow as tf -from dnc import access +from dnc import access, util -DNCState = collections.namedtuple('DNCState', ('access_output', 'access_state', - 'controller_state')) +# For directly indexing into DNC state +ACCESS_OUTPUT = 0 +ACCESS_STATE = 1 +CONTROLLER_STATE = 2 class DNC(snt.RNNCore): - """DNC core module. - - Contains controller and memory access module. - """ - - def __init__(self, - access_config, - controller_config, - output_size, - clip_value=None, - name='dnc'): - """Initializes the DNC core. - - Args: - access_config: dictionary of access module configurations. - controller_config: dictionary of controller (LSTM) module configurations. - output_size: output dimension size of core. - clip_value: clips controller and core output values to between - `[-clip_value, clip_value]` if specified. - name: module name (default 'dnc'). - - Raises: - TypeError: if direct_input_size is not None for any access module other - than KeyValueMemory. - """ - super(DNC, self).__init__(name=name) - - with self._enter_variable_scope(): - self._controller = snt.LSTM(**controller_config) - self._access = access.MemoryAccess(**access_config) - - self._access_output_size = np.prod(self._access.output_size.as_list()) - self._output_size = output_size - self._clip_value = clip_value or 0 - - self._output_size = tf.TensorShape([output_size]) - self._state_size = DNCState( - access_output=self._access_output_size, - access_state=self._access.state_size, - controller_state=self._controller.state_size) - - def _clip_if_enabled(self, x): - if self._clip_value > 0: - return tf.clip_by_value(x, -self._clip_value, self._clip_value) - else: - return x - - def _build(self, inputs, prev_state): - """Connects the DNC core into the graph. - - Args: - inputs: Tensor input. - prev_state: A `DNCState` tuple containing the fields `access_output`, - `access_state` and `controller_state`. `access_state` is a 3-D Tensor - of shape `[batch_size, num_reads, word_size]` containing read words. - `access_state` is a tuple of the access module's state, and - `controller_state` is a tuple of controller module's state. - - Returns: - A tuple `(output, next_state)` where `output` is a tensor and `next_state` - is a `DNCState` tuple containing the fields `access_output`, - `access_state`, and `controller_state`. 
- """ - - prev_access_output = prev_state.access_output - prev_access_state = prev_state.access_state - prev_controller_state = prev_state.controller_state - - batch_flatten = snt.BatchFlatten() - controller_input = tf.concat( - [batch_flatten(inputs), batch_flatten(prev_access_output)], 1) - - controller_output, controller_state = self._controller( - controller_input, prev_controller_state) - - controller_output = self._clip_if_enabled(controller_output) - controller_state = tf.contrib.framework.nest.map_structure(self._clip_if_enabled, controller_state) + """DNC core module. - access_output, access_state = self._access(controller_output, - prev_access_state) - - output = tf.concat([controller_output, batch_flatten(access_output)], 1) - output = snt.Linear( - output_size=self._output_size.as_list()[0], - name='output_linear')(output) - output = self._clip_if_enabled(output) - - return output, DNCState( - access_output=access_output, - access_state=access_state, - controller_state=controller_state) - - def initial_state(self, batch_size, dtype=tf.float32): - return DNCState( - controller_state=self._controller.initial_state(batch_size, dtype), - access_state=self._access.initial_state(batch_size, dtype), - access_output=tf.zeros( - [batch_size] + self._access.output_size.as_list(), dtype)) - - @property - def state_size(self): - return self._state_size + Contains controller and memory access module. + """ - @property - def output_size(self): - return self._output_size + def __init__( + self, + access_config, + controller_config, + output_size, + batch_size, + clip_value=None, + name="dnc", + dtype=tf.float32, + ): + """Initializes the DNC core. + + Args: + access_config: dictionary of access module configurations. + controller_config: dictionary of controller (LSTM) module configurations. + output_size: output dimension size of core. + clip_value: clips controller and core output values to between + `[-clip_value, clip_value]` if specified. + name: module name (default 'dnc'). + + Raises: + TypeError: if direct_input_size is not None for any access module other + than KeyValueMemory. + """ + super(DNC, self).__init__(name=name) + + self._dtype = dtype + # dm-sonnet=2.0.0 LSTM is not integrated with TF2 tracing. + # Use keras to allow for Tensorboard visualization + # self._controller = snt.LSTM(**controller_config, dtype=tf.float64) + self._controller = tf.keras.layers.LSTMCell(**controller_config, dtype=dtype) + self._access = access.MemoryAccess(**access_config, dtype=dtype) + + self._output_size = output_size + self._batch_size = batch_size + self._clip_value = clip_value or 0 + + self._output_linear = snt.Linear(output_size=output_size, name="output_linear") + + def _clip_if_enabled(self, x): + if self._clip_value > 0: + return tf.clip_by_value(x, -self._clip_value, self._clip_value) + else: + return x + + # keras.layers.RNN abstract method + def call(self, inputs, prev_state): + return self.__call__(inputs, prev_state) + + # sonnet.RNNCore abstract method + def __call__(self, inputs, prev_state): + """Connects the DNC core into the graph. + + Args: + inputs: Tensor input. + prev_state: A `DNCState` tuple containing the fields `access_output`, + `access_state` and `controller_state`. `access_state` is a 3-D Tensor + of shape `[batch_size, num_reads, word_size]` containing read words. + `access_state` is a tuple of the access module's state, and + `controller_state` is a tuple of controller module's state. 
+ + Returns: + A tuple `(output, next_state)` where `output` is a tensor and `next_state` + is a nested list of tensors representing the dnc state: `access_output`, + `access_state`, and `controller_state`. + """ + [prev_access_output, prev_access_state, prev_controller_state] = prev_state + + batch_flatten = tf.keras.layers.Flatten() + controller_input = tf.concat( + [batch_flatten(inputs), batch_flatten(prev_access_output)], 1 + ) + + controller_output, controller_state = self._controller( + controller_input, prev_controller_state + ) + + controller_output = self._clip_if_enabled(controller_output) + controller_state = tf.nest.map_structure( + self._clip_if_enabled, controller_state + ) + + access_output, access_state = self._access(controller_output, prev_access_state) + + output = tf.concat([controller_output, batch_flatten(access_output)], 1) + output = self._output_linear(output) + output = self._clip_if_enabled(output) + + return ( + output, + [ + access_output, + access_state, + controller_state, + ], + ) + + # keras.layers.RNN uses get_initial_state + def get_initial_state(self, batch_size=None, inputs=None, dtype=None): + return util.initial_state_from_state_size( + self.state_size, batch_size, self._dtype + ) + + # sonnet.RNNCore uses initial_state + def initial_state(self, batch_size=None): + return self.get_initial_state(batch_size=batch_size) + + @property + def state_size(self): + return [ + # access_output + self._access.output_size, + # access_state + self._access.state_size, + # controller_state + self._controller.state_size, + ] + + @property + def output_size(self): + return tf.TensorShape([self._output_size]) diff --git a/dnc/repeat_copy.py b/dnc/repeat_copy.py index ad52579..ea2c8ac 100644 --- a/dnc/repeat_copy.py +++ b/dnc/repeat_copy.py @@ -22,371 +22,440 @@ import sonnet as snt import tensorflow as tf -DatasetTensors = collections.namedtuple('DatasetTensors', ('observations', - 'target', 'mask')) +DatasetTensors = collections.namedtuple( + "DatasetTensors", ("observations", "target", "mask") +) -def masked_sigmoid_cross_entropy(logits, - target, - mask, - time_average=False, - log_prob_in_bits=False): - """Adds ops to graph which compute the (scalar) NLL of the target sequence. +def masked_sigmoid_cross_entropy( + logits, target, mask, time_average=False, log_prob_in_bits=False +): + """Adds ops to graph which compute the (scalar) NLL of the target sequence. - The logits parametrize independent bernoulli distributions per time-step and - per batch element, and irrelevant time/batch elements are masked out by the - mask tensor. + The logits parametrize independent bernoulli distributions per time-step and + per batch element, and irrelevant time/batch elements are masked out by the + mask tensor. - Args: - logits: `Tensor` of activations for which sigmoid(`logits`) gives the - bernoulli parameter. - target: time-major `Tensor` of target. - mask: time-major `Tensor` to be multiplied elementwise with cost T x B cost - masking out irrelevant time-steps. - time_average: optionally average over the time dimension (sum by default). - log_prob_in_bits: iff True express log-probabilities in bits (default nats). - - Returns: - A `Tensor` representing the log-probability of the target. 
- """ - xent = tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=logits) - loss_time_batch = tf.reduce_sum(xent, axis=2) - loss_batch = tf.reduce_sum(loss_time_batch * mask, axis=0) + Args: + logits: `Tensor` of activations for which sigmoid(`logits`) gives the + bernoulli parameter. + target: time-major `Tensor` of target. + mask: time-major `Tensor` to be multiplied elementwise with cost T x B cost + masking out irrelevant time-steps. + time_average: optionally average over the time dimension (sum by default). + log_prob_in_bits: iff True express log-probabilities in bits (default nats). + + Returns: + A `Tensor` representing the log-probability of the target. + """ + xent = tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=logits) + loss_time_batch = tf.reduce_sum(input_tensor=xent, axis=2) + loss_batch = tf.reduce_sum(input_tensor=loss_time_batch * mask, axis=0) - batch_size = tf.cast(tf.shape(logits)[1], dtype=loss_time_batch.dtype) + batch_size = tf.cast(tf.shape(input=logits)[1], dtype=loss_time_batch.dtype) - if time_average: - mask_count = tf.reduce_sum(mask, axis=0) - loss_batch /= (mask_count + np.finfo(np.float32).eps) + if time_average: + mask_count = tf.reduce_sum(input_tensor=mask, axis=0) + loss_batch /= mask_count + np.finfo(np.float32).eps - loss = tf.reduce_sum(loss_batch) / batch_size - if log_prob_in_bits: - loss /= tf.log(2.) + loss = tf.reduce_sum(input_tensor=loss_batch) / batch_size + if log_prob_in_bits: + loss /= tf.math.log(2.0) - return loss + return loss def bitstring_readable(data, batch_size, model_output=None, whole_batch=False): - """Produce a human readable representation of the sequences in data. - - Args: - data: data to be visualised - batch_size: size of batch - model_output: optional model output tensor to visualize alongside data. - whole_batch: whether to visualise the whole batch. Only the first sample - will be visualized if False - - Returns: - A string used to visualise the data batch - """ - - def _readable(datum): - return '+' + ' '.join(['-' if x == 0 else '%d' % x for x in datum]) + '+' - - obs_batch = data.observations - targ_batch = data.target - - iterate_over = range(batch_size) if whole_batch else range(1) - - batch_strings = [] - for batch_index in iterate_over: - obs = obs_batch[:, batch_index, :] - targ = targ_batch[:, batch_index, :] - - obs_channels = range(obs.shape[1]) - targ_channels = range(targ.shape[1]) - obs_channel_strings = [_readable(obs[:, i]) for i in obs_channels] - targ_channel_strings = [_readable(targ[:, i]) for i in targ_channels] - - readable_obs = 'Observations:\n' + '\n'.join(obs_channel_strings) - readable_targ = 'Targets:\n' + '\n'.join(targ_channel_strings) - strings = [readable_obs, readable_targ] - - if model_output is not None: - output = model_output[:, batch_index, :] - output_strings = [_readable(output[:, i]) for i in targ_channels] - strings.append('Model Output:\n' + '\n'.join(output_strings)) - - batch_strings.append('\n\n'.join(strings)) - - return '\n' + '\n\n\n\n'.join(batch_strings) - - -class RepeatCopy(snt.AbstractModule): - """Sequence data generator for the task of repeating a random binary pattern. - - When called, an instance of this class will return a tuple of tensorflow ops - (obs, targ, mask), representing an input sequence, target sequence, and - binary mask. Each of these ops produces tensors whose first two dimensions - represent sequence position and batch index respectively. 
The value in - mask[t, b] is equal to 1 iff a prediction about targ[t, b, :] should be - penalized and 0 otherwise. - - For each realisation from this generator, the observation sequence is - comprised of I.I.D. uniform-random binary vectors (and some flags). - - The target sequence is comprised of this binary pattern repeated - some number of times (and some flags). Before explaining in more detail, - let's examine the setup pictorially for a single batch element: - - ```none - Note: blank space represents 0. - - time ------------------------------------------> - - +-------------------------------+ - mask: |0000000001111111111111111111111| - +-------------------------------+ - - +-------------------------------+ - target: | 1| 'end-marker' channel. - | 101100110110011011001 | - | 010101001010100101010 | - +-------------------------------+ - - +-------------------------------+ - observation: | 1011001 | - | 0101010 | - |1 | 'start-marker' channel - | 3 | 'num-repeats' channel. - +-------------------------------+ - ``` - - The length of the random pattern and the number of times it is repeated - in the target are both discrete random variables distributed according to - uniform distributions whose parameters are configured at construction time. - - The obs sequence has two extra channels (components in the trailing dimension) - which are used for flags. One channel is marked with a 1 at the first time - step and is otherwise equal to 0. The other extra channel is zero until the - binary pattern to be repeated ends. At this point, it contains an encoding of - the number of times the observation pattern should be repeated. Rather than - simply providing this integer number directly, it is normalised so that - a neural network may have an easier time representing the number of - repetitions internally. To allow a network to be readily evaluated on - instances of this task with greater numbers of repetitions, the range with - respect to which this encoding is normalised is also configurable by the user. - - As in the diagram, the target sequence is offset to begin directly after the - observation sequence; both sequences are padded with zeros to accomplish this, - resulting in their lengths being equal. Additional padding is done at the end - so that all sequences in a minibatch represent tensors with the same shape. - """ - - def __init__( - self, - num_bits=6, - batch_size=1, - min_length=1, - max_length=1, - min_repeats=1, - max_repeats=2, - norm_max=10, - log_prob_in_bits=False, - time_average_cost=False, - name='repeat_copy',): - """Creates an instance of RepeatCopy task. + """Produce a human readable representation of the sequences in data. Args: - name: A name for the generator instance (for name scope purposes). - num_bits: The dimensionality of each random binary vector. - batch_size: Minibatch size per realization. - min_length: Lower limit on number of random binary vectors in the - observation pattern. - max_length: Upper limit on number of random binary vectors in the - observation pattern. - min_repeats: Lower limit on number of times the obervation pattern - is repeated in targ. - max_repeats: Upper limit on number of times the observation pattern - is repeated in targ. - norm_max: Upper limit on uniform distribution w.r.t which the encoding - of the number of repetitions presented in the observation sequence - is normalised. - log_prob_in_bits: By default, log probabilities are expressed in units of - nats. If true, express log probabilities in bits. 
- time_average_cost: If true, the cost at each time step will be - divided by the `true`, sequence length, the number of non-masked time - steps, in each sequence before any subsequent reduction over the time - and batch dimensions. + data: data to be visualised + batch_size: size of batch + model_output: optional model output tensor to visualize alongside data. + whole_batch: whether to visualise the whole batch. Only the first sample + will be visualized if False + + Returns: + A string used to visualise the data batch """ - super(RepeatCopy, self).__init__(name=name) - - self._batch_size = batch_size - self._num_bits = num_bits - self._min_length = min_length - self._max_length = max_length - self._min_repeats = min_repeats - self._max_repeats = max_repeats - self._norm_max = norm_max - self._log_prob_in_bits = log_prob_in_bits - self._time_average_cost = time_average_cost - - def _normalise(self, val): - return val / self._norm_max - - def _unnormalise(self, val): - return val * self._norm_max - - @property - def time_average_cost(self): - return self._time_average_cost - - @property - def log_prob_in_bits(self): - return self._log_prob_in_bits - - @property - def num_bits(self): - """The dimensionality of each random binary vector in a pattern.""" - return self._num_bits - - @property - def target_size(self): - """The dimensionality of the target tensor.""" - return self._num_bits + 1 - - @property - def batch_size(self): - return self._batch_size - - def _build(self): - """Implements build method which adds ops to graph.""" - - # short-hand for private fields. - min_length, max_length = self._min_length, self._max_length - min_reps, max_reps = self._min_repeats, self._max_repeats - num_bits = self.num_bits - batch_size = self.batch_size - - # We reserve one dimension for the num-repeats and one for the start-marker. - full_obs_size = num_bits + 2 - # We reserve one target dimension for the end-marker. - full_targ_size = num_bits + 1 - start_end_flag_idx = full_obs_size - 2 - num_repeats_channel_idx = full_obs_size - 1 - - # Samples each batch index's sequence length and the number of repeats. - sub_seq_length_batch = tf.random_uniform( - [batch_size], minval=min_length, maxval=max_length + 1, dtype=tf.int32) - num_repeats_batch = tf.random_uniform( - [batch_size], minval=min_reps, maxval=max_reps + 1, dtype=tf.int32) - - # Pads all the batches to have the same total sequence length. - total_length_batch = sub_seq_length_batch * (num_repeats_batch + 1) + 3 - max_length_batch = tf.reduce_max(total_length_batch) - residual_length_batch = max_length_batch - total_length_batch - - obs_batch_shape = [max_length_batch, batch_size, full_obs_size] - targ_batch_shape = [max_length_batch, batch_size, full_targ_size] - mask_batch_trans_shape = [batch_size, max_length_batch] - - obs_tensors = [] - targ_tensors = [] - mask_tensors = [] - - # Generates patterns for each batch element independently. - for batch_index in range(batch_size): - sub_seq_len = sub_seq_length_batch[batch_index] - num_reps = num_repeats_batch[batch_index] - - # The observation pattern is a sequence of random binary vectors. - obs_pattern_shape = [sub_seq_len, num_bits] - obs_pattern = tf.cast( - tf.random_uniform( - obs_pattern_shape, minval=0, maxval=2, dtype=tf.int32), - tf.float32) - - # The target pattern is the observation pattern repeated n times. - # Some reshaping is required to accomplish the tiling. 
- targ_pattern_shape = [sub_seq_len * num_reps, num_bits] - flat_obs_pattern = tf.reshape(obs_pattern, [-1]) - flat_targ_pattern = tf.tile(flat_obs_pattern, tf.stack([num_reps])) - targ_pattern = tf.reshape(flat_targ_pattern, targ_pattern_shape) - - # Expand the obs_pattern to have two extra channels for flags. - # Concatenate start flag and num_reps flag to the sequence. - obs_flag_channel_pad = tf.zeros([sub_seq_len, 2]) - obs_start_flag = tf.one_hot( - [start_end_flag_idx], full_obs_size, on_value=1., off_value=0.) - num_reps_flag = tf.one_hot( - [num_repeats_channel_idx], - full_obs_size, - on_value=self._normalise(tf.cast(num_reps, tf.float32)), - off_value=0.) - - # note the concatenation dimensions. - obs = tf.concat([obs_pattern, obs_flag_channel_pad], 1) - obs = tf.concat([obs_start_flag, obs], 0) - obs = tf.concat([obs, num_reps_flag], 0) - - # Now do the same for the targ_pattern (it only has one extra channel). - targ_flag_channel_pad = tf.zeros([sub_seq_len * num_reps, 1]) - targ_end_flag = tf.one_hot( - [start_end_flag_idx], full_targ_size, on_value=1., off_value=0.) - targ = tf.concat([targ_pattern, targ_flag_channel_pad], 1) - targ = tf.concat([targ, targ_end_flag], 0) - - # Concatenate zeros at end of obs and begining of targ. - # This aligns them s.t. the target begins as soon as the obs ends. - obs_end_pad = tf.zeros([sub_seq_len * num_reps + 1, full_obs_size]) - targ_start_pad = tf.zeros([sub_seq_len + 2, full_targ_size]) - - # The mask is zero during the obs and one during the targ. - mask_off = tf.zeros([sub_seq_len + 2]) - mask_on = tf.ones([sub_seq_len * num_reps + 1]) - - obs = tf.concat([obs, obs_end_pad], 0) - targ = tf.concat([targ_start_pad, targ], 0) - mask = tf.concat([mask_off, mask_on], 0) - - obs_tensors.append(obs) - targ_tensors.append(targ) - mask_tensors.append(mask) - - # End the loop over batch index. - # Compute how much zero padding is needed to make tensors sequences - # the same length for all batch elements. - residual_obs_pad = [ - tf.zeros([residual_length_batch[i], full_obs_size]) - for i in range(batch_size) - ] - residual_targ_pad = [ - tf.zeros([residual_length_batch[i], full_targ_size]) - for i in range(batch_size) - ] - residual_mask_pad = [ - tf.zeros([residual_length_batch[i]]) for i in range(batch_size) - ] - - # Concatenate the pad to each batch element. - obs_tensors = [ - tf.concat([o, p], 0) for o, p in zip(obs_tensors, residual_obs_pad) - ] - targ_tensors = [ - tf.concat([t, p], 0) for t, p in zip(targ_tensors, residual_targ_pad) - ] - mask_tensors = [ - tf.concat([m, p], 0) for m, p in zip(mask_tensors, residual_mask_pad) - ] - - # Concatenate each batch element into a single tensor. 
- obs = tf.reshape(tf.concat(obs_tensors, 1), obs_batch_shape) - targ = tf.reshape(tf.concat(targ_tensors, 1), targ_batch_shape) - mask = tf.transpose( - tf.reshape(tf.concat(mask_tensors, 0), mask_batch_trans_shape)) - return DatasetTensors(obs, targ, mask) - - def cost(self, logits, targ, mask): - return masked_sigmoid_cross_entropy( - logits, - targ, - mask, - time_average=self.time_average_cost, - log_prob_in_bits=self.log_prob_in_bits) - - def to_human_readable(self, data, model_output=None, whole_batch=False): - obs = data.observations - unnormalised_num_reps_flag = self._unnormalise(obs[:,:,-1:]).round() - obs = np.concatenate([obs[:,:,:-1], unnormalised_num_reps_flag], axis=2) - data = data._replace(observations=obs) - return bitstring_readable(data, self.batch_size, model_output, whole_batch) + + def _readable(datum): + return "+" + " ".join(["-" if x == 0 else "%d" % x for x in datum]) + "+" + + obs_batch = data.observations + targ_batch = data.target + + iterate_over = range(batch_size) if whole_batch else range(1) + + batch_strings = [] + for batch_index in iterate_over: + obs = obs_batch[:, batch_index, :] + targ = targ_batch[:, batch_index, :] + + obs_channels = range(obs.shape[1]) + targ_channels = range(targ.shape[1]) + obs_channel_strings = [_readable(obs[:, i]) for i in obs_channels] + targ_channel_strings = [_readable(targ[:, i]) for i in targ_channels] + + readable_obs = "Observations:\n" + "\n".join(obs_channel_strings) + readable_targ = "Targets:\n" + "\n".join(targ_channel_strings) + strings = [readable_obs, readable_targ] + + if model_output is not None: + output = model_output[:, batch_index, :] + output_strings = [_readable(output[:, i]) for i in targ_channels] + strings.append("Model Output:\n" + "\n".join(output_strings)) + + batch_strings.append("\n\n".join(strings)) + + return "\n" + "\n\n\n\n".join(batch_strings) + + +class RepeatCopy(snt.Module): + """Sequence data generator for the task of repeating a random binary pattern. + + When called, an instance of this class will return a tuple of tensorflow ops + (obs, targ, mask), representing an input sequence, target sequence, and + binary mask. Each of these ops produces tensors whose first two dimensions + represent sequence position and batch index respectively. The value in + mask[t, b] is equal to 1 iff a prediction about targ[t, b, :] should be + penalized and 0 otherwise. + + For each realisation from this generator, the observation sequence is + comprised of I.I.D. uniform-random binary vectors (and some flags). + + The target sequence is comprised of this binary pattern repeated + some number of times (and some flags). Before explaining in more detail, + let's examine the setup pictorially for a single batch element: + + ```none + Note: blank space represents 0. + + time ------------------------------------------> + + +-------------------------------+ + mask: |0000000001111111111111111111111| + +-------------------------------+ + + +-------------------------------+ + target: | 1| 'end-marker' channel. + | 101100110110011011001 | + | 010101001010100101010 | + +-------------------------------+ + + +-------------------------------+ + observation: | 1011001 | + | 0101010 | + |1 | 'start-marker' channel + | 3 | 'num-repeats' channel. + +-------------------------------+ + ``` + + The length of the random pattern and the number of times it is repeated + in the target are both discrete random variables distributed according to + uniform distributions whose parameters are configured at construction time. 
+
+    The obs sequence has two extra channels (components in the trailing dimension)
+    which are used for flags. One channel is marked with a 1 at the first time
+    step and is otherwise equal to 0. The other extra channel is zero until the
+    binary pattern to be repeated ends. At this point, it contains an encoding of
+    the number of times the observation pattern should be repeated. Rather than
+    simply providing this integer number directly, it is normalised so that
+    a neural network may have an easier time representing the number of
+    repetitions internally. To allow a network to be readily evaluated on
+    instances of this task with greater numbers of repetitions, the range with
+    respect to which this encoding is normalised is also configurable by the user.
+
+    As in the diagram, the target sequence is offset to begin directly after the
+    observation sequence; both sequences are padded with zeros to accomplish this,
+    resulting in their lengths being equal. Additional padding is done at the end
+    so that all sequences in a minibatch represent tensors with the same shape.
+    """
+
+    def __init__(
+        self,
+        num_bits=6,
+        batch_size=1,
+        min_length=1,
+        max_length=1,
+        min_repeats=1,
+        max_repeats=2,
+        norm_max=10,
+        log_prob_in_bits=False,
+        time_average_cost=False,
+        name="repeat_copy",
+        dtype=tf.float32,
+    ):
+        """Creates an instance of RepeatCopy task.
+
+        Args:
+            name: A name for the generator instance (for name scope purposes).
+            num_bits: The dimensionality of each random binary vector.
+            batch_size: Minibatch size per realization.
+            min_length: Lower limit on number of random binary vectors in the
+                observation pattern.
+            max_length: Upper limit on number of random binary vectors in the
+                observation pattern.
+            min_repeats: Lower limit on number of times the observation pattern
+                is repeated in targ.
+            max_repeats: Upper limit on number of times the observation pattern
+                is repeated in targ.
+            norm_max: Upper limit on uniform distribution w.r.t. which the encoding
+                of the number of repetitions presented in the observation sequence
+                is normalised.
+            log_prob_in_bits: By default, log probabilities are expressed in units of
+                nats. If true, express log probabilities in bits.
+            time_average_cost: If true, the cost at each time step will be
+                divided by the `true` sequence length (the number of non-masked time
+                steps) in each sequence before any subsequent reduction over the time
+                and batch dimensions.
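+
+        A minimal usage sketch (the parameter values and `model_logits` below
+        are illustrative, not defaults of this module):
+
+            dataset = RepeatCopy(num_bits=4, batch_size=16,
+                                 min_length=1, max_length=2)
+            observations, target, mask = dataset()
+            loss = dataset.cost(model_logits, target, mask)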
+ """ + super(RepeatCopy, self).__init__(name=name) + + self._batch_size = batch_size + self._num_bits = num_bits + self._min_length = min_length + self._max_length = max_length + self._min_repeats = min_repeats + self._max_repeats = max_repeats + self._norm_max = norm_max + self._log_prob_in_bits = log_prob_in_bits + self._time_average_cost = time_average_cost + self._dtype = dtype + + @classmethod + def _normalise(cls, val, normalise_factor): + return val / normalise_factor + + @classmethod + def _unnormalise(cls, val, normalise_factor): + return val * normalise_factor + + @property + def time_average_cost(self): + return self._time_average_cost + + @property + def log_prob_in_bits(self): + return self._log_prob_in_bits + + @property + def num_bits(self): + """The dimensionality of each random binary vector in a pattern.""" + return self._num_bits + + @property + def target_size(self): + """The dimensionality of the target tensor.""" + return self._num_bits + 1 + + @property + def batch_size(self): + return self._batch_size + + def __call__(self): + return self._build() + # return self.datasettensor + + def _build(self): + """Implements build method which returns a new labelled data set every invocation.""" + + # short-hand for private fields. + min_length, max_length = self._min_length, self._max_length + min_reps, max_reps = self._min_repeats, self._max_repeats + num_bits = self.num_bits + batch_size = self.batch_size + + # We reserve one dimension for the num-repeats and one for the start-marker. + full_obs_size = num_bits + 2 + # We reserve one target dimension for the end-marker. + full_targ_size = num_bits + 1 + + # Samples each batch index's sequence length and the number of repeats. + sub_seq_length_batch = tf.random.uniform( + [batch_size], minval=min_length, maxval=max_length + 1, dtype=tf.int32 + ) + num_repeats_batch = tf.random.uniform( + [batch_size], minval=min_reps, maxval=max_reps + 1, dtype=tf.int32 + ) + + # Pads all the batches to have the same total sequence length. + total_length_batch = sub_seq_length_batch * (num_repeats_batch + 1) + 3 + max_length_batch = tf.reduce_max(input_tensor=total_length_batch) + residual_length_batch = max_length_batch - total_length_batch + + obs_batch_shape = [max_length_batch, batch_size, full_obs_size] + targ_batch_shape = [max_length_batch, batch_size, full_targ_size] + mask_batch_trans_shape = [batch_size, max_length_batch] + + obs_tensors = [] + targ_tensors = [] + mask_tensors = [] + + # Generates patterns for each batch element independently. + for batch_index in range(batch_size): + sub_seq_len = sub_seq_length_batch[batch_index] + num_reps = num_repeats_batch[batch_index] + + # The observation pattern is a sequence of random binary vectors. + obs_pattern_shape = [sub_seq_len, num_bits] + obs_pattern = tf.cast( + tf.random.uniform( + obs_pattern_shape, minval=0, maxval=2, dtype=tf.int32 + ), + tf.float32, + ) + + (obs, targ, mask) = self.derive_data_from_inputs( + obs_pattern, num_reps, self._norm_max + ) + + obs_tensors.append(obs) + targ_tensors.append(targ) + mask_tensors.append(mask) + + # End the loop over batch index. + # Compute how much zero padding is needed to make tensors sequences + # the same length for all batch elements. 
+ residual_obs_pad = [ + tf.zeros([residual_length_batch[i], full_obs_size]) + for i in range(batch_size) + ] + residual_targ_pad = [ + tf.zeros([residual_length_batch[i], full_targ_size]) + for i in range(batch_size) + ] + residual_mask_pad = [ + tf.zeros([residual_length_batch[i]]) for i in range(batch_size) + ] + + # Concatenate the pad to each batch element. + obs_tensors = [ + tf.concat([o, p], 0) for o, p in zip(obs_tensors, residual_obs_pad) + ] + targ_tensors = [ + tf.concat([t, p], 0) for t, p in zip(targ_tensors, residual_targ_pad) + ] + mask_tensors = [ + tf.concat([m, p], 0) for m, p in zip(mask_tensors, residual_mask_pad) + ] + + # Concatenate each batch element into a single tensor. + obs = tf.cast( + tf.reshape(tf.concat(obs_tensors, 1), obs_batch_shape), dtype=self._dtype + ) + targ = tf.cast( + tf.reshape(tf.concat(targ_tensors, 1), targ_batch_shape), dtype=self._dtype + ) + mask = tf.cast( + tf.transpose( + a=tf.reshape(tf.concat(mask_tensors, 0), mask_batch_trans_shape) + ), + dtype=self._dtype, + ) + return DatasetTensors(obs, targ, mask) + + @classmethod + def derive_data_from_inputs( + cls, obs_pattern, num_reps, num_rep_normalise_factor=10 + ): + """Derive observation, target, and mask patterns from input observation tensor. + + Extracted from _build so it can be used for manual inspection of user defined sequences. + + Args: + cls: The RepeatCopy class + obs_pattern: Tensor representing the bit sequences to copy. + Of shape (sub_seq_len, num_bits). + num_reps: Int, number of times to repeat obs_pattern. + num_rep_normalise_factor: Double, normalisation factor for repeat parameter. + + Returns: + obs: Input tensor, obs_pattern with appropriate start sequence flag, num_reps flag + and zero padding after pattern. + targ: Target tensor, obs_pattern repeated num_reps times with appropriate stop flag + and zero padding before pattern. + mask: Mask tensor, 0s for input phase and 1s for model output phase to be used for + determining what timesteps should be considered for calculating loss + """ + + sub_seq_len, num_bits = obs_pattern.shape + + full_obs_size = num_bits + 2 + # We reserve one target dimension for the end-marker. + full_targ_size = num_bits + 1 + start_end_flag_idx = full_obs_size - 2 + num_repeats_channel_idx = full_obs_size - 1 + + # The target pattern is the observation pattern repeated n times. + # Some reshaping is required to accomplish the tiling. + targ_pattern_shape = [sub_seq_len * num_reps, num_bits] + flat_obs_pattern = tf.reshape(obs_pattern, [-1]) + flat_targ_pattern = tf.tile(flat_obs_pattern, tf.stack([num_reps])) + targ_pattern = tf.reshape(flat_targ_pattern, targ_pattern_shape) + + # Expand the obs_pattern to have two extra channels for flags. + # Concatenate start flag and num_reps flag to the sequence. + obs_flag_channel_pad = tf.zeros([sub_seq_len, 2]) + obs_start_flag = tf.one_hot( + [start_end_flag_idx], full_obs_size, on_value=1.0, off_value=0.0 + ) + num_reps_flag = tf.one_hot( + [num_repeats_channel_idx], + full_obs_size, + on_value=cls._normalise( + tf.cast(num_reps, tf.float32), num_rep_normalise_factor + ), + off_value=0.0, + ) + + # note the concatenation dimensions. + obs = tf.concat([obs_pattern, obs_flag_channel_pad], 1) + obs = tf.concat([obs_start_flag, obs], 0) + obs = tf.concat([obs, num_reps_flag], 0) + + # Now do the same for the targ_pattern (it only has one extra channel). 
+ targ_flag_channel_pad = tf.zeros([sub_seq_len * num_reps, 1]) + targ_end_flag = tf.one_hot( + [start_end_flag_idx], full_targ_size, on_value=1.0, off_value=0.0 + ) + targ = tf.concat([targ_pattern, targ_flag_channel_pad], 1) + targ = tf.concat([targ, targ_end_flag], 0) + + # Concatenate zeros at end of obs and begining of targ. + # This aligns them s.t. the target begins as soon as the obs ends. + obs_end_pad = tf.zeros([sub_seq_len * num_reps + 1, full_obs_size]) + targ_start_pad = tf.zeros([sub_seq_len + 2, full_targ_size]) + + # The mask is zero during the obs and one during the targ. + mask_off = tf.zeros([sub_seq_len + 2]) + mask_on = tf.ones([sub_seq_len * num_reps + 1]) + + obs = tf.concat([obs, obs_end_pad], 0) + targ = tf.concat([targ_start_pad, targ], 0) + mask = tf.concat([mask_off, mask_on], 0) + + return (obs, targ, mask) + + def cost(self, logits, targ, mask): + return masked_sigmoid_cross_entropy( + logits, + targ, + mask, + time_average=self.time_average_cost, + log_prob_in_bits=self.log_prob_in_bits, + ) + + def to_human_readable(self, data, model_output=None, whole_batch=False): + data = DatasetTensors( + observations=data.observations.numpy(), + target=data.target.numpy(), + mask=data.mask.numpy(), + ) + obs = data.observations + unnormalised_num_reps_flag = self._unnormalise( + obs[:, :, -1:], self._norm_max + ).round() + obs = np.concatenate([obs[:, :, :-1], unnormalised_num_reps_flag], axis=2) + data = data._replace(observations=obs) + return bitstring_readable(data, self.batch_size, model_output, whole_batch) diff --git a/dnc/util.py b/dnc/util.py index 5009c77..2b6d29d 100644 --- a/dnc/util.py +++ b/dnc/util.py @@ -23,28 +23,26 @@ def batch_invert_permutation(permutations): - """Returns batched `tf.invert_permutation` for every row in `permutations`.""" - with tf.name_scope('batch_invert_permutation', values=[permutations]): + """Returns batched `tf.invert_permutation` for every row in `permutations`.""" perm = tf.cast(permutations, tf.float32) dim = int(perm.get_shape()[-1]) - size = tf.cast(tf.shape(perm)[0], tf.float32) - delta = tf.cast(tf.shape(perm)[-1], tf.float32) + size = tf.cast(tf.shape(input=perm)[0], tf.float32) + delta = tf.cast(tf.shape(input=perm)[-1], tf.float32) rg = tf.range(0, size * delta, delta, dtype=tf.float32) rg = tf.expand_dims(rg, 1) rg = tf.tile(rg, [1, dim]) perm = tf.add(perm, rg) flat = tf.reshape(perm, [-1]) - perm = tf.invert_permutation(tf.cast(flat, tf.int32)) + perm = tf.math.invert_permutation(tf.cast(flat, tf.int32)) perm = tf.reshape(perm, [-1, dim]) return tf.subtract(perm, tf.cast(rg, tf.int32)) def batch_gather(values, indices): - """Returns batched `tf.gather` for every row in the input.""" - with tf.name_scope('batch_gather', values=[values, indices]): - idx = tf.expand_dims(indices, -1) - size = tf.shape(indices)[0] - rg = tf.range(size, dtype=tf.int32) + """Returns batched `tf.gather` for every row in the input.""" + idx = tf.expand_dims(tf.cast(indices, tf.int32), -1) + size = tf.shape(input=indices)[0] + rg = tf.range(tf.cast(size, tf.int32), dtype=tf.int32) rg = tf.expand_dims(rg, -1) rg = tf.tile(rg, [1, int(indices.get_shape()[-1])]) rg = tf.expand_dims(rg, -1) @@ -53,20 +51,42 @@ def batch_gather(values, indices): def one_hot(length, index): - """Return an nd array of given `length` filled with 0s and a 1 at `index`.""" - result = np.zeros(length) - result[index] = 1 - return result + """Return an nd array of given `length` filled with 0s and a 1 at `index`.""" + result = np.zeros(length) + result[index] = 1 + 
return result
+
def reduce_prod(x, axis, name=None):
- """Efficient reduce product over axis.
-
- Uses tf.cumprod and tf.gather_nd as a workaround to the poor performance of calculating tf.reduce_prod's gradient on CPU.
- """
- with tf.name_scope(name, 'util_reduce_prod', values=[x]):
- cp = tf.cumprod(x, axis, reverse=True)
- size = tf.shape(cp)[0]
- idx1 = tf.range(tf.cast(size, tf.float32), dtype=tf.float32)
- idx2 = tf.zeros([size], tf.float32)
- indices = tf.stack([idx1, idx2], 1)
- return tf.gather_nd(cp, tf.cast(indices, tf.int32))
+ """Efficient reduce product over axis.
+
+ Previously used tf.cumprod and tf.gather_nd as a workaround to the poor performance of calculating tf.reduce_prod's gradient on CPU.
+
+ As of TF2, reduce_prod seems not to be a culprit of increased timings:
+ https://github.com/tensorflow/tensorflow/issues/40748
+
+ Workaround code for future reference:
+
+ with tf.compat.v1.name_scope(name, 'util_reduce_prod', values=[x]):
+ cp = tf.math.cumprod(x, axis, reverse=True)
+ size = tf.shape(input=cp)[0]
+ idx1 = tf.range(tf.cast(size, tf.float32), dtype=tf.float32)
+ idx2 = tf.zeros([size], tf.float32)
+ indices = tf.stack([idx1, idx2], 1)
+ return tf.gather_nd(cp, tf.cast(indices, tf.int32))
+ """
+ return tf.math.reduce_prod(x, axis=axis, name=name)
+
+
+# Utility function to convert a nested state_size to a compatible zero initial_state.
+def initial_state_from_state_size(state_size, batch_size, dtype):
+ if isinstance(state_size, int):
+ return tf.zeros([batch_size, state_size], dtype=dtype)
+ if isinstance(state_size, tf.TensorShape):
+ return tf.zeros([batch_size] + state_size.as_list(), dtype=dtype)
+ elif isinstance(state_size, list):
+ return [initial_state_from_state_size(s, batch_size, dtype) for s in state_size]
+
+ raise NotImplementedError(
+ f"Cannot parse initial_state from state_size of type {type(state_size)}: {state_size}"
+ )
diff --git a/dnc/util_test.py b/dnc/util_test.py
deleted file mode 100644
index 55e3f25..0000000
--- a/dnc/util_test.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright 2017 Google Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from dnc import util
-
-
-class BatchInvertPermutation(tf.test.TestCase):
-
- def test(self):
- # Tests that the _batch_invert_permutation function correctly inverts a
- # batch of permutations.
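The `initial_state_from_state_size` helper added to `dnc/util.py` above builds a zero state that mirrors an arbitrarily nested `state_size`, as keras RNN cells expose it. A minimal usage sketch (not part of the patch; the nested `state_size` value is made up for illustration, and the snippet assumes the repository root is on `PYTHONPATH`):

```python
import tensorflow as tf

from dnc import util

# A nested state_size of the kind a keras RNN cell exposes (illustrative values).
state_size = [4, [tf.TensorShape([2, 3]), 8]]

# Zero initial state for a batch of 5, mirroring the nesting of state_size.
state = util.initial_state_from_state_size(state_size, batch_size=5, dtype=tf.float32)

print(tf.nest.map_structure(lambda t: t.shape.as_list(), state))
# [[5, 4], [[5, 2, 3], [5, 8]]]
```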
- batch_size = 5 - length = 7 - - permutations = np.empty([batch_size, length], dtype=int) - for i in xrange(batch_size): - permutations[i] = np.random.permutation(length) - - inverse = util.batch_invert_permutation(tf.constant(permutations, tf.int32)) - with self.test_session(): - inverse = inverse.eval() - - for i in xrange(batch_size): - for j in xrange(length): - self.assertEqual(permutations[i][inverse[i][j]], j) - - -class BatchGather(tf.test.TestCase): - - def test(self): - values = np.array([[3, 1, 4, 1], [5, 9, 2, 6], [5, 3, 5, 7]]) - indexs = np.array([[1, 2, 0, 3], [3, 0, 1, 2], [0, 2, 1, 3]]) - target = np.array([[1, 4, 3, 1], [6, 5, 9, 2], [5, 5, 3, 7]]) - result = util.batch_gather(tf.constant(values), tf.constant(indexs)) - with self.test_session(): - result = result.eval() - self.assertAllEqual(target, result) - - -if __name__ == '__main__': - tf.test.main() diff --git a/interactive.ipynb b/interactive.ipynb new file mode 100644 index 0000000..13dc6fb --- /dev/null +++ b/interactive.ipynb @@ -0,0 +1,323 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "474c9cfa", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2017 Google Inc.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "# ==============================================================================\n", + "\"\"\"Example notebook for inspecting the DNC model trained on the repeat copy task.\"\"\"\n", + "\n", + "from __future__ import absolute_import\n", + "from __future__ import division\n", + "from __future__ import print_function\n", + "\n", + "import argparse\n", + "import datetime\n", + "import tensorflow as tf\n", + "\n", + "from dnc import dnc, access\n", + "from dnc import repeat_copy\n", + "\n", + "from collections import namedtuple\n", + "\n", + "# Update hyper parameters based on trained model\n", + "flags_dict = {\n", + " # Model parameters\n", + " \"hidden_size\": 64, # Size of LSTM hidden layer.\n", + " \"memory_size\": 16, # The number of memory slots.\n", + " \"word_size\": 16, #\"The width of each memory slot.\"\n", + " \"num_write_heads\": 1, #\"Number of memory write heads.\"\n", + " \"num_read_heads\": 4, #\"Number of memory read heads.\"\n", + " \"clip_value\": 20, #\"Maximum absolute value of controller and dnc outputs.\"\n", + "\n", + " # Optimizer parameters.\n", + " \"max_grad_norm\": 50, #\"Gradient clipping norm limit.\"\n", + " \"learning_rate\": 1e-4, #\"Optimizer learning rate.\"\n", + " \"optimizer_epsilon\": 1e-10, #\"Epsilon used for RMSProp optimizer.\"\n", + "\n", + " # Task parameters\n", + " \"batch_size\": 1, #\"Batch size for training.\"\n", + " \"num_bits\": 8, #\"Dimensionality of each vector to copy\"\n", + " \"min_length\": 1,#\"Lower limit on number of vectors in the observation pattern to copy\"\n", + " \"max_length\": 3,#\"Upper limit on number of vectors in the observation pattern to copy\"\n", + " \"min_repeats\": 1,#\"Lower limit on number of copy 
repeats.\"\n", + " \"max_repeats\": 3, #\"Upper limit on number of copy repeats.\"\n", + "\n", + " \"checkpoint_dir\": \"./logs/repeat_copy/checkpoint\", #\"Checkpointing directory.\"\n", + "}\n", + "\n", + "flags_schema = namedtuple('flags_schema', list(flags_dict.keys()))\n", + "FLAGS = flags_schema(**flags_dict)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3112d2e0", + "metadata": {}, + "outputs": [], + "source": [ + "def load_model():\n", + " \"\"\"Load dnc core model from checkpoint directory\"\"\"\n", + " access_config = {\n", + " \"memory_size\": FLAGS.memory_size,\n", + " \"word_size\": FLAGS.word_size,\n", + " \"num_reads\": FLAGS.num_read_heads,\n", + " \"num_writes\": FLAGS.num_write_heads,\n", + " }\n", + " controller_config = {\n", + " #\"hidden_size\": FLAGS.hidden_size,\n", + " \"units\": FLAGS.hidden_size,\n", + " }\n", + " clip_value = FLAGS.clip_value\n", + "\n", + " dnc_cell = dnc.DNC(\n", + " access_config, controller_config, FLAGS.num_bits + 1, FLAGS.batch_size, clip_value)\n", + " dnc_core = tf.keras.layers.RNN(\n", + " cell=dnc_cell,\n", + " time_major=True,\n", + " return_sequences=True,\n", + " return_state=True,\n", + " )\n", + " optimizer = tf.compat.v1.train.RMSPropOptimizer(\n", + " FLAGS.learning_rate, epsilon=FLAGS.optimizer_epsilon)\n", + "\n", + " # Set up model checkpointing\n", + " checkpoint = tf.train.Checkpoint(model=dnc_core, optimizer=optimizer)\n", + " manager = tf.train.CheckpointManager(checkpoint, FLAGS.checkpoint_dir, max_to_keep=10)\n", + "\n", + " checkpoint.restore(manager.latest_checkpoint)\n", + " if manager.latest_checkpoint:\n", + " print(\"Restored from {}\".format(manager.latest_checkpoint))\n", + " else:\n", + " print(\"Initializing from scratch.\")\n", + " return dnc_core\n", + "\n", + "\n", + "dnc_core = load_model()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8daa62a5", + "metadata": {}, + "outputs": [], + "source": [ + "def read_weights_from_dnc_state(dnc_state):\n", + " return dnc_state[dnc.ACCESS_STATE][access.READ_WEIGHTS]\n", + "def write_weights_from_dnc_state(dnc_state):\n", + " return dnc_state[dnc.ACCESS_STATE][access.WRITE_WEIGHTS]\n", + "def memory_from_dnc_state(dnc_state):\n", + " return dnc_state[dnc.ACCESS_STATE][access.MEMORY]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6500f979", + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_model(\n", + " x,\n", + " mask,\n", + " rnn_model,\n", + "):\n", + " \"\"\"Obtain output sequence and intermediate states when evaluating x.\n", + " \n", + " Args:\n", + " x: input tensor\n", + " mask: Mask tensor, currently unused\n", + " rnn_model: keras.layers.RNN instance\n", + " \n", + " Returns:\n", + " output_sequence: List of tensors representing the model output\n", + " sequence for each time step\n", + " output_states: List of rnn states (may be nested list of tensors)\n", + " output for each time step\n", + " \"\"\"\n", + " output_sequence = []\n", + " output_states = []\n", + " input_state = rnn_model.get_initial_state(inputs=x)\n", + " \n", + " for input_seq in x:\n", + " output = rnn_model(\n", + " inputs=tf.expand_dims(input_seq, axis=0),\n", + " initial_state=input_state,\n", + " )\n", + " \n", + " output_seq = output[0]\n", + " output_state = output[1:]\n", + " \n", + " output_sequence.append(tf.round(tf.sigmoid(output_seq)))\n", + " #output_sequence.append(output_seq)\n", + " output_states.append(output_state)\n", + "\n", + " input_state = output_state\n", + " \n", + " 
return output_sequence, output_states" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b29aad3a", + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "def visualize_results(obs, targ, pred, mask):\n", + " obs = tf.transpose(obs)\n", + " targ = tf.transpose(targ)\n", + " pred = tf.transpose(tf.squeeze(pred))\n", + " \n", + " seaborn.set(rc = {'figure.figsize':(\n", + " 15.0 / 64 * obs.shape[1], # time, x-axis\n", + " 15.0 / 64 * obs.shape[0], # biz position, y-axis\n", + " )})\n", + " \n", + " seaborn.heatmap(obs)\n", + " plt.title('RepeatCopy Task Inputs')\n", + " plt.xlabel('time step')\n", + " plt.ylabel('bit position')\n", + " plt.show()\n", + " \n", + " seaborn.heatmap(targ)\n", + " plt.title('RepeatCopy Task Target')\n", + " plt.xlabel('time step')\n", + " plt.ylabel('bit position')\n", + " plt.show()\n", + " \n", + " seaborn.heatmap(pred)\n", + " plt.title('RepeatCopy Task Model Outputs')\n", + " plt.xlabel('time step')\n", + " plt.ylabel('bit position')\n", + " plt.show()\n", + "\n", + "def visualize_states(states):\n", + " #memory = [memory_from_dnc_state(state)[0] for state in states]\n", + " read_weights = [read_weights_from_dnc_state(state)[0] for state in states]\n", + " read_weights = tf.transpose(tf.stack(read_weights), [1,2,0])\n", + " \n", + " write_weights = [write_weights_from_dnc_state(state)[0] for state in states]\n", + " write_weights = tf.transpose(tf.stack(write_weights), [1,2,0])\n", + " \n", + " \"\"\"memory_color_range = {\n", + " 'vmin': np.min(memory),\n", + " 'vmax': np.max(memory)\n", + " }\"\"\"\n", + " read_weights_color_range = {\n", + " 'vmin': np.min(read_weights),\n", + " 'vmax': np.max(read_weights),\n", + " }\n", + " write_weights_color_range = {\n", + " 'vmin': np.min(write_weights),\n", + " 'vmax': np.max(write_weights),\n", + " }\n", + " \n", + " \n", + " seaborn.set(rc = {'figure.figsize':(\n", + " 15.0 / 64 * write_weights.shape[2], # time, x-axis\n", + " 15.0 / 64 * write_weights.shape[1], # memory, y-axis\n", + " )})\n", + " \n", + " # Visualize write weights over time\n", + " for i, write_head in enumerate(write_weights):\n", + " seaborn.heatmap(write_head, **write_weights_color_range)\n", + " plt.title(f'Write Weights for Write Head {i}')\n", + " plt.xlabel('time step')\n", + " plt.ylabel('memory slot')\n", + " plt.show()\n", + " \n", + " # Visualize read weights over time\n", + " for i, read_head in enumerate(read_weights):\n", + " seaborn.heatmap(read_head, **read_weights_color_range)\n", + " plt.title(f'Read Weights for Read Head {i}')\n", + " plt.xlabel('time step')\n", + " plt.ylabel('memory slot')\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89115c0e", + "metadata": {}, + "outputs": [], + "source": [ + "def debug_model(x, num_repeats):\n", + " x = tf.convert_to_tensor(x, dtype=tf.float32)\n", + " obs, targ, mask = repeat_copy.RepeatCopy.derive_data_from_inputs(\n", + " x, \n", + " num_repeats, \n", + " 10 # repeat_copy._norm_max, default value of 10, modify if using different norm\n", + " )\n", + " \n", + " output_sequence, states = evaluate_model(tf.expand_dims(obs, [1]), None, dnc_core)\n", + " \n", + " visualize_results(obs, targ, tf.stack(output_sequence), mask)\n", + " visualize_states(states)\n", + " return output_sequence, states" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eeb76634", + "metadata": {}, + "outputs": [], + "source": [ + "a = 
debug_model([[1]*8, [0]*8], 20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b73f6272", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..75560f7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,65 @@ +absl-py==0.12.0 +appdirs==1.4.4 +astunparse==1.6.3 +attrs==21.2.0 +black==21.6b0 +cachetools==4.2.2 +certifi==2020.12.5 +cfgv==3.3.0 +chardet==4.0.0 +click==8.0.1 +distlib==0.3.2 +dm-sonnet==2.0.0 +dm-tree==0.1.6 +filelock==3.0.12 +flake8==3.9.2 +flatbuffers==1.12 +gast==0.4.0 +google-auth==1.30.0 +google-auth-oauthlib==0.4.4 +google-pasta==0.2.0 +grpcio==1.34.1 +h5py==3.1.0 +identify==2.2.10 +idna==2.10 +iniconfig==1.1.1 +keras-nightly==2.5.0.dev2021032900 +Keras-Preprocessing==1.1.2 +Markdown==3.3.4 +mccabe==0.6.1 +mypy-extensions==0.4.3 +nodeenv==1.6.0 +numpy==1.19.5 +oauthlib==3.1.0 +opt-einsum==3.3.0 +packaging==20.9 +pathspec==0.8.1 +pluggy==0.13.1 +pre-commit==2.13.0 +protobuf==3.17.0 +py==1.10.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.7.0 +pyflakes==2.3.1 +pyparsing==2.4.7 +pytest==6.2.4 +PyYAML==5.4.1 +regex==2021.4.4 +requests==2.25.1 +requests-oauthlib==1.3.0 +rsa==4.7.2 +six==1.15.0 +tabulate==0.8.9 +tensorboard==2.5.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorflow==2.5.0 +tensorflow-estimator==2.5.0 +termcolor==1.1.0 +toml==0.10.2 +typing-extensions==3.7.4.3 +urllib3==1.26.4 +virtualenv==20.4.7 +Werkzeug==2.0.0 +wrapt==1.12.1 diff --git a/setup.py b/setup.py index a5c3ee4..2a6bb1f 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,21 @@ from setuptools import setup setup( - name='dnc', - version='0.0.2', - description='This package provides an implementation of the Differentiable Neural Computer, as published in Nature.', - license='Apache Software License 2.0', - packages=['dnc'], - author='DeepMind', - keywords=['tensorflow', 'differentiable neural computer', 'dnc', 'deepmind', 'deep mind', 'sonnet', 'dm-sonnet', 'machine learning'], - url='https://github.com/deepmind/dnc' + name="dnc", + version="0.0.2", + description="This package provides an implementation of the Differentiable Neural Computer, as published in Nature.", + license="Apache Software License 2.0", + packages=["dnc"], + author="DeepMind", + keywords=[ + "tensorflow", + "differentiable neural computer", + "dnc", + "deepmind", + "deep mind", + "sonnet", + "dm-sonnet", + "machine learning", + ], + url="https://github.com/deepmind/dnc", ) diff --git a/tests/access_test.py b/tests/access_test.py new file mode 100644 index 0000000..678accc --- /dev/null +++ b/tests/access_test.py @@ -0,0 +1,187 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for memory access.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.python.framework import random_seed + +from dnc import access, addressing, util + +BATCH_SIZE = 2 +MEMORY_SIZE = 20 +WORD_SIZE = 6 +NUM_READS = 2 +NUM_WRITES = 3 +TIME_STEPS = 4 +INPUT_SIZE = 10 + +DTYPE = tf.float32 + +# set seeds for determinism +np.random.seed(42) +random_seed.set_seed(42) + + +class MemoryAccessTest(tf.test.TestCase): + def setUp(self): + self.cell = access.MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES) + self.module = tf.keras.layers.RNN( + cell=self.cell, + time_major=True, + return_sequences=True, + ) + + def testBuildAndTrain(self): + inputs = tf.random.normal([TIME_STEPS, BATCH_SIZE, INPUT_SIZE], dtype=DTYPE) + targets = np.random.rand(TIME_STEPS, BATCH_SIZE, NUM_READS, WORD_SIZE) + + def loss(outputs, targets): + return tf.reduce_mean(input_tensor=tf.square(outputs - targets)) + + with tf.GradientTape() as tape: + outputs = self.module( + inputs=inputs, + initial_state=self.module.get_initial_state(inputs), + ) + loss_value = loss(outputs, targets) + gradients = tape.gradient(loss_value, self.module.trainable_variables) + + optimizer = tf.keras.optimizers.SGD(learning_rate=0.1) + optimizer.apply_gradients(zip(gradients, self.module.trainable_variables)) + + def testValidReadMode(self): + inputs = self.cell._read_inputs( + tf.random.normal([BATCH_SIZE, INPUT_SIZE], dtype=DTYPE) + ) + + # Check that the read modes for each read head constitute a probability + # distribution. + self.assertAllClose( + inputs["read_mode"].numpy().sum(2), np.ones([BATCH_SIZE, NUM_READS]) + ) + self.assertGreaterEqual(inputs["read_mode"].numpy().min(), 0) + + def testWriteWeights(self): + memory = 10 * (np.random.rand(BATCH_SIZE, MEMORY_SIZE, WORD_SIZE) - 0.5) + usage = np.random.rand(BATCH_SIZE, MEMORY_SIZE) + + allocation_gate = np.random.rand(BATCH_SIZE, NUM_WRITES) + write_gate = np.random.rand(BATCH_SIZE, NUM_WRITES) + write_content_keys = np.random.rand(BATCH_SIZE, NUM_WRITES, WORD_SIZE) + write_content_strengths = np.random.rand(BATCH_SIZE, NUM_WRITES) + + # Check that turning on allocation gate fully brings the write gate to + # the allocation weighting (which we will control by controlling the usage). + usage[:, 3] = 0 + allocation_gate[:, 0] = 1 + write_gate[:, 0] = 1 + + inputs = { + "allocation_gate": tf.constant(allocation_gate, dtype=DTYPE), + "write_gate": tf.constant(write_gate, dtype=DTYPE), + "write_content_keys": tf.constant(write_content_keys, dtype=DTYPE), + "write_content_strengths": tf.constant( + write_content_strengths, dtype=DTYPE + ), + } + + weights = self.cell._write_weights( + inputs, tf.constant(memory, dtype=DTYPE), tf.constant(usage, dtype=DTYPE) + ) + + weights = weights.numpy() + + # Check the weights sum to their target gating. 
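+ # The content weighting is a softmax over memory and sums to 1, and the
+ # allocation weighting sums to 1 - prod(usage), which is close to 1 for the
+ # random usage above, so the interpolated write weights sum to roughly
+ # write_gate; hence the loose tolerance below.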
+ self.assertAllClose(np.sum(weights, axis=2), write_gate, atol=5e-2) + + # Check that we fully allocated to the third row. + weights_0_0_target = util.one_hot(MEMORY_SIZE, 3) + self.assertAllClose(weights[0, 0], weights_0_0_target, atol=1e-3) + + def testReadWeights(self): + memory = 10 * (np.random.rand(BATCH_SIZE, MEMORY_SIZE, WORD_SIZE) - 0.5) + prev_read_weights = np.random.rand(BATCH_SIZE, NUM_READS, MEMORY_SIZE) + prev_read_weights /= prev_read_weights.sum(2, keepdims=True) + 1 + + link = np.random.rand(BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE) + # Row and column sums should be at most 1: + link /= np.maximum(link.sum(2, keepdims=True), 1) + link /= np.maximum(link.sum(3, keepdims=True), 1) + + # We query the memory on the third location in memory, and select a large + # strength on the query. Then we select a content-based read-mode. + read_content_keys = np.random.rand(BATCH_SIZE, NUM_READS, WORD_SIZE) + read_content_keys[0, 0] = memory[0, 3] + read_content_strengths = tf.constant( + 100.0, shape=[BATCH_SIZE, NUM_READS], dtype=DTYPE + ) + read_mode = np.random.rand(BATCH_SIZE, NUM_READS, 1 + 2 * NUM_WRITES) + read_mode[0, 0, :] = util.one_hot(1 + 2 * NUM_WRITES, 2 * NUM_WRITES) + inputs = { + "read_content_keys": tf.constant(read_content_keys, dtype=DTYPE), + "read_content_strengths": read_content_strengths, + "read_mode": tf.constant(read_mode, dtype=DTYPE), + } + read_weights = self.cell._read_weights( + inputs, + tf.cast(memory, dtype=DTYPE), + tf.cast(prev_read_weights, dtype=DTYPE), + tf.cast(link, dtype=DTYPE), + ) + read_weights = read_weights.numpy() + + # read_weights for batch 0, read head 0 should be memory location 3 + self.assertAllClose( + read_weights[0, 0, :], util.one_hot(MEMORY_SIZE, 3), atol=1e-3 + ) + + def testGradients(self): + inputs = tf.constant(np.random.randn(1, BATCH_SIZE, INPUT_SIZE), dtype=DTYPE) + initial_state = self.module.get_initial_state(inputs=inputs) + + def evaluate_module(inputs, memory, read_weights, precedence_weights, link): + # construct initial state with tensors to check + init_state = [ + memory, + read_weights, + initial_state[access.WRITE_WEIGHTS], + [link, precedence_weights], + initial_state[access.USAGE], + ] + output = self.module(inputs, init_state) + loss = tf.reduce_sum(input_tensor=output) + return loss + + tensors_to_check = [ + inputs, + initial_state[access.MEMORY], + initial_state[access.READ_WEIGHTS], + initial_state[access.LINKAGE][addressing.PRECEDENCE_WEIGHTS], + initial_state[access.LINKAGE][addressing.LINK], + ] + + theoretical, numerical = tf.test.compute_gradient( + evaluate_module, tensors_to_check, delta=1e-5 + ) + self.assertLess( + sum([tf.norm(numerical[i] - theoretical[i]) for i in range(2)]), + 0.02, + tensors_to_check, + ) diff --git a/tests/addressing_test.py b/tests/addressing_test.py new file mode 100644 index 0000000..9907ee7 --- /dev/null +++ b/tests/addressing_test.py @@ -0,0 +1,404 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for memory addressing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import sonnet as snt +import tensorflow as tf +from tensorflow.python.framework import random_seed + +from dnc import addressing, util + +# set seeds for determinism +np.random.seed(42) +random_seed.set_seed(42) + + +class WeightedSoftmaxTest(tf.test.TestCase): + def testValues(self): + batch_size = 5 + num_heads = 3 + memory_size = 7 + + activations = np.random.randn(batch_size, num_heads, memory_size) + weights = np.ones((batch_size, num_heads)) + + # Run weighted softmax with identity placed on weights. Output should be + # equal to a standalone softmax. + observed = addressing.weighted_softmax(activations, weights, tf.identity) + expected = snt.BatchApply(tf.nn.softmax, num_dims=1)((activations)) + self.assertAllClose(observed, expected) + + +class CosineWeightsTest(tf.test.TestCase): + def testShape(self): + batch_size = 5 + num_heads = 3 + memory_size = 7 + word_size = 2 + + module = addressing.CosineWeights(num_heads, word_size) + mem = np.random.randn(batch_size, memory_size, word_size) + keys = np.random.randn(batch_size, num_heads, word_size) + strengths = np.random.randn(batch_size, num_heads) + weights = module(mem, keys, strengths) + self.assertTrue( + weights.get_shape().is_compatible_with([batch_size, num_heads, memory_size]) + ) + + def testValues(self): + batch_size = 5 + num_heads = 4 + memory_size = 10 + word_size = 2 + + mem = np.random.randn(batch_size, memory_size, word_size) + np.copyto(mem[0, 0], [1, 2]) + np.copyto(mem[0, 1], [3, 4]) + np.copyto(mem[0, 2], [5, 6]) + + keys = np.random.randn(batch_size, num_heads, word_size) + np.copyto(keys[0, 0], [5, 6]) + np.copyto(keys[0, 1], [1, 2]) + np.copyto(keys[0, 2], [5, 6]) + np.copyto(keys[0, 3], [3, 4]) + strengths = np.random.randn(batch_size, num_heads) + + module = addressing.CosineWeights(num_heads, word_size) + weights = module(mem, keys, strengths) + + # Manually checks results. + strengths_softplus = np.log(1 + np.exp(strengths)) + similarity = np.zeros((memory_size)) + + for b in range(batch_size): + for h in range(num_heads): + key = keys[b, h] + key_norm = np.linalg.norm(key) + + for m in range(memory_size): + row = mem[b, m] + similarity[m] = np.dot(key, row) / (key_norm * np.linalg.norm(row)) + + similarity = np.exp(similarity * strengths_softplus[b, h]) + similarity /= similarity.sum() + self.assertAllClose(weights[b, h], similarity, atol=1e-4, rtol=1e-4) + + def testDivideByZero(self): + batch_size = 5 + num_heads = 4 + memory_size = 10 + word_size = 2 + + module = addressing.CosineWeights(num_heads, word_size) + keys = tf.Variable( + tf.random.normal([batch_size, num_heads, word_size], dtype=tf.float64) + ) + strengths = tf.Variable( + tf.random.normal([batch_size, num_heads], dtype=tf.float64) + ) + + # First row of memory is non-zero to concentrate attention on this location. + # Remaining rows are all zero. 
+ first_row_ones = tf.ones([batch_size, 1, word_size], dtype=tf.float64) + remaining_zeros = tf.zeros( + [batch_size, memory_size - 1, word_size], dtype=tf.float64 + ) + mem = tf.Variable(tf.concat((first_row_ones, remaining_zeros), 1)) + + with tf.GradientTape() as gtape: + output = module(mem, keys, strengths) + gradients = gtape.gradient(target=output, sources=[mem, keys, strengths]) + + self.assertFalse(np.any(np.isnan(output))) + self.assertFalse(np.any(np.isnan(gradients[0]))) + self.assertFalse(np.any(np.isnan(gradients[1]))) + self.assertFalse(np.any(np.isnan(gradients[2]))) + + +class TemporalLinkageTest(tf.test.TestCase): + def testModule(self): + batch_size = 7 + memory_size = 4 + num_reads = 11 + num_writes = 5 + module = addressing.TemporalLinkage( + memory_size=memory_size, num_writes=num_writes + ) + + state = [ + # link + np.zeros([batch_size, num_writes, memory_size, memory_size]), + # precedence_weights + np.zeros([batch_size, num_writes, memory_size]), + ] + + num_steps = 5 + for i in range(num_steps): + write_weights = np.random.rand(batch_size, num_writes, memory_size) + write_weights /= write_weights.sum(2, keepdims=True) + 1 + + # Simulate (in final steps) link 0-->1 in head 0 and 3-->2 in head 1 + if i == num_steps - 2: + write_weights[0, 0, :] = util.one_hot(memory_size, 0) + write_weights[0, 1, :] = util.one_hot(memory_size, 3) + elif i == num_steps - 1: + write_weights[0, 0, :] = util.one_hot(memory_size, 1) + write_weights[0, 1, :] = util.one_hot(memory_size, 2) + + prev_link_in = state[addressing.LINK] + prev_precedence_weights_in = state[addressing.PRECEDENCE_WEIGHTS] + write_weights_in = write_weights + + state = module( + write_weights_in, + [ + # link + prev_link_in, + # precedence_weights + prev_precedence_weights_in, + ], + ) + + result_link = state[addressing.LINK] + + # link should be bounded in range [0, 1] + self.assertGreaterEqual(tf.math.reduce_min(result_link), 0) + self.assertLessEqual(tf.math.reduce_max(result_link), 1) + + # link diagonal should be zero + self.assertAllEqual( + tf.linalg.diag_part(result_link), + np.zeros([batch_size, num_writes, memory_size]), + ) + + # link rows and columns should sum to at most 1 + self.assertLessEqual( + tf.math.reduce_max(tf.math.reduce_sum(result_link, axis=2)), 1 + ) + self.assertLessEqual( + tf.math.reduce_max(tf.math.reduce_sum(result_link, axis=3)), 1 + ) + + # records our transitions in batch 0: head 0: 0->1, and head 1: 3->2 + self.assertAllEqual(result_link[0, 0, :, 0], util.one_hot(memory_size, 1)) + self.assertAllEqual(result_link[0, 1, :, 3], util.one_hot(memory_size, 2)) + + # Now test calculation of forward and backward read weights + prev_read_weights = np.random.rand(batch_size, num_reads, memory_size) + prev_read_weights[0, 5, :] = util.one_hot(memory_size, 0) # read 5, posn 0 + prev_read_weights[0, 6, :] = util.one_hot(memory_size, 2) # read 6, posn 2 + forward_read_weights = module.directional_read_weights( + tf.constant(result_link), + tf.constant(prev_read_weights, dtype=tf.float64), + forward=True, + ) + backward_read_weights = module.directional_read_weights( + tf.constant(result_link), + tf.constant(prev_read_weights, dtype=tf.float64), + forward=False, + ) + + # Check directional weights calculated correctly. 
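+ # Forward weights follow the link in the order of writing: head 0 wrote slot 0
+ # and then slot 1, so a read previously at slot 0 moves forward to slot 1.
+ # Backward weights reverse it: head 1 wrote slot 3 and then slot 2, so a read
+ # previously at slot 2 moves backward to slot 3.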
+ self.assertAllEqual( + forward_read_weights[0, 5, 0, :], # read=5, write=0 + util.one_hot(memory_size, 1), + ) + self.assertAllEqual( + backward_read_weights[0, 6, 1, :], # read=6, write=1 + util.one_hot(memory_size, 3), + ) + + def testPrecedenceWeights(self): + batch_size = 7 + memory_size = 3 + num_writes = 5 + module = addressing.TemporalLinkage( + memory_size=memory_size, num_writes=num_writes + ) + + prev_precedence_weights = np.random.rand(batch_size, num_writes, memory_size) + write_weights = np.random.rand(batch_size, num_writes, memory_size) + + # These should sum to at most 1 for each write head in each batch. + write_weights /= write_weights.sum(2, keepdims=True) + 1 + prev_precedence_weights /= prev_precedence_weights.sum(2, keepdims=True) + 1 + + write_weights[0, 1, :] = 0 # batch 0 head 1: no writing + write_weights[1, 2, :] /= write_weights[1, 2, :].sum() # b1 h2: all writing + + precedence_weights = module._precedence_weights( + prev_precedence_weights=tf.constant(prev_precedence_weights), + write_weights=tf.constant(write_weights), + ) + + # precedence weights should be bounded in range [0, 1] + self.assertGreaterEqual(tf.math.reduce_min(precedence_weights), 0) + self.assertLessEqual(tf.math.reduce_max(precedence_weights), 1) + + # no writing in batch 0, head 1 + self.assertAllClose( + precedence_weights[0, 1, :], prev_precedence_weights[0, 1, :] + ) + + # all writing in batch 1, head 2 + self.assertAllClose(precedence_weights[1, 2, :], write_weights[1, 2, :]) + + +class FreenessTest(tf.test.TestCase): + def testModule(self): + batch_size = 5 + memory_size = 11 + num_reads = 3 + num_writes = 7 + module = addressing.Freeness(memory_size) + + free_gate = np.random.rand(batch_size, num_reads) + + # Produce read weights that sum to 1 for each batch and head. + prev_read_weights = np.random.rand(batch_size, num_reads, memory_size) + prev_read_weights[1, :, 3] = 0 # no read at batch 1, position 3; see below + prev_read_weights /= prev_read_weights.sum(2, keepdims=True) + prev_write_weights = np.random.rand(batch_size, num_writes, memory_size) + prev_write_weights /= prev_write_weights.sum(2, keepdims=True) + prev_usage = np.random.rand(batch_size, memory_size) + + # Add some special values that allows us to test the behaviour: + prev_write_weights[1, 2, 3] = 1 # full write in batch 1, head 2, position 3 + prev_read_weights[2, 0, 4] = 1 # full read at batch 2, head 0, position 4 + free_gate[2, 0] = 1 # can free up all locations for batch 2, read head 0 + + usage = module( + tf.constant(prev_write_weights), + tf.constant(free_gate), + tf.constant(prev_read_weights), + tf.constant(prev_usage), + ) + + usage = usage.numpy() + + # Check all usages are between 0 and 1. + self.assertGreaterEqual(usage.min(), 0) + self.assertLessEqual(usage.max(), 1) + + # Check that the full write at batch 1, position 3 makes it fully used. + self.assertEqual(usage[1][3], 1) + + # Check that the full free at batch 2, position 4 makes it fully free. + self.assertEqual(usage[2][4], 0) + + def testWriteAllocationWeights(self): + batch_size = 7 + memory_size = 23 + num_writes = 5 + module = addressing.Freeness(memory_size) + + usage = np.random.rand(batch_size, memory_size) + write_gates = np.random.rand(batch_size, num_writes) + + # Turn off gates for heads 1 and 3 in batch 0. This doesn't scaling down the + # weighting, but it means that the usage doesn't change, so we should get + # the same allocation weightings for: (1, 2) and (3, 4) (but all others + # being different). 
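+ # write_allocation_weights handles the heads sequentially, increasing the
+ # simulated usage by each head's gated allocation before computing the next
+ # head's allocation, so a zero write gate leaves usage unchanged for the head
+ # that follows.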
+ write_gates[0, 1] = 0 + write_gates[0, 3] = 0 + # and turn heads 0 and 2 on for full effect. + write_gates[0, 0] = 1 + write_gates[0, 2] = 1 + + # In batch 1, make one of the usages 0 and another almost 0, so that these + # entries get most of the allocation weights for the first and second heads. + usage[1] = usage[1] * 0.9 + 0.1 # make sure all entries are in [0.1, 1] + usage[1][4] = 0 # write head 0 should get allocated to position 4 + usage[1][3] = 1e-4 # write head 1 should get allocated to position 3 + write_gates[1, 0] = 1 # write head 0 fully on + write_gates[1, 1] = 1 # write head 1 fully on + + weights = module.write_allocation_weights( + usage=tf.constant(usage), + write_gates=tf.constant(write_gates), + num_writes=num_writes, + ) + + weights = weights.numpy() + + # Check that all weights are between 0 and 1 + self.assertGreaterEqual(weights.min(), 0) + self.assertLessEqual(weights.max(), 1) + + # Check that weights sum to close to 1 + self.assertAllClose( + np.sum(weights, axis=2), np.ones([batch_size, num_writes]), atol=1e-3 + ) + + # Check the same / different allocation weight pairs as described above. + self.assertGreater(np.abs(weights[0, 0, :] - weights[0, 1, :]).max(), 0.1) + self.assertAllEqual(weights[0, 1, :], weights[0, 2, :]) + self.assertGreater(np.abs(weights[0, 2, :] - weights[0, 3, :]).max(), 0.1) + self.assertAllEqual(weights[0, 3, :], weights[0, 4, :]) + + self.assertAllClose(weights[1][0], util.one_hot(memory_size, 4), atol=1e-3) + self.assertAllClose(weights[1][1], util.one_hot(memory_size, 3), atol=1e-3) + + def testWriteAllocationWeightsGradient(self): + batch_size = 7 + memory_size = 5 + num_writes = 3 + module = addressing.Freeness(memory_size) + + usage = tf.constant(np.random.rand(batch_size, memory_size)) + write_gates = tf.constant(np.random.rand(batch_size, num_writes)) + # weights = module.write_allocation_weights(usage, write_gates, num_writes) + + theoretical, numerical = tf.test.compute_gradient( + lambda usage, write_gates: module.write_allocation_weights( + usage, write_gates, num_writes + ), + [usage, write_gates], + delta=1e-5, + ) + self.assertLess( + sum([tf.norm(numerical[i] - theoretical[i]) for i in range(2)]), 0.01 + ) + + def testAllocation(self): + batch_size = 7 + memory_size = 13 + usage = np.random.rand(batch_size, memory_size) + module = addressing.Freeness(memory_size) + allocation = module._allocation(tf.constant(usage)) + + # 1. Test that max allocation goes to min usage, and vice versa. + self.assertAllEqual(np.argmin(usage, axis=1), np.argmax(allocation, axis=1)) + self.assertAllEqual(np.argmax(usage, axis=1), np.argmin(allocation, axis=1)) + + # 2. Test that allocations sum to almost 1. + self.assertAllClose(np.sum(allocation, axis=1), np.ones(batch_size), 0.01) + + def testAllocationGradient(self): + batch_size = 1 + memory_size = 5 + usage = tf.constant(np.random.rand(batch_size, memory_size)) + module = addressing.Freeness(memory_size) + theoretical, numerical = tf.test.compute_gradient( + module._allocation, [usage], delta=1e-5 + ) + self.assertLess( + sum([tf.norm(numerical[i] - theoretical[i]) for i in range(1)]), 0.01 + ) diff --git a/tests/dnc_test.py b/tests/dnc_test.py new file mode 100644 index 0000000..8a5765a --- /dev/null +++ b/tests/dnc_test.py @@ -0,0 +1,104 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DNCCore"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.framework import random_seed
+
+from dnc import dnc, access, addressing
+from dnc import repeat_copy
+
+# set seeds for determinism
+np.random.seed(42)
+random_seed.set_seed(42)
+
+DTYPE = tf.float32
+
+# Model parameters
+HIDDEN_SIZE = 64
+MEMORY_SIZE = 16
+WORD_SIZE = 16
+NUM_WRITE_HEADS = 1
+NUM_READ_HEADS = 4
+CLIP_VALUE = 20
+
+# Optimizer parameters.
+MAX_GRAD_NORM = 50
+LEARNING_RATE = 1e-4
+OPTIMIZER_EPSILON = 1e-10
+
+# Task parameters
+BATCH_SIZE = 16
+TIME_STEPS = 4
+INPUT_SIZE = 4
+OUTPUT_SIZE = 4
+
+
+class DNCCoreTest(tf.test.TestCase):
+ def setUp(self):
+ access_config = {
+ "memory_size": MEMORY_SIZE,
+ "word_size": WORD_SIZE,
+ "num_reads": NUM_READ_HEADS,
+ "num_writes": NUM_WRITE_HEADS,
+ }
+ controller_config = {
+ # "hidden_size": FLAGS.hidden_size,
+ "units": HIDDEN_SIZE,
+ }
+
+ self.module = dnc.DNC(
+ access_config,
+ controller_config,
+ OUTPUT_SIZE,
+ BATCH_SIZE,
+ CLIP_VALUE,
+ name="dnc_test",
+ dtype=DTYPE,
+ )
+ self.initial_state = self.module.get_initial_state(batch_size=BATCH_SIZE)
+
+ def testBuildAndTrain(self):
+ inputs = tf.random.normal([TIME_STEPS, BATCH_SIZE, INPUT_SIZE], dtype=DTYPE)
+ targets = np.random.rand(TIME_STEPS, BATCH_SIZE, OUTPUT_SIZE)
+
+ def loss(outputs, targets):
+ return tf.reduce_mean(input_tensor=tf.square(outputs - targets))
+
+ optimizer = tf.compat.v1.train.RMSPropOptimizer(
+ LEARNING_RATE, epsilon=OPTIMIZER_EPSILON
+ )
+
+ with tf.GradientTape() as tape:
+ # outputs, _ = tf.compat.v1.nn.dynamic_rnn(
+ outputs = tf.keras.layers.RNN(
+ cell=self.module,
+ time_major=True,
+ return_sequences=True,
+ )(
+ inputs=inputs,
+ initial_state=self.initial_state,
+ )
+ loss_value = loss(outputs, targets)
+ gradients = tape.gradient(loss_value, self.module.trainable_variables)
+
+ grads, _ = tf.clip_by_global_norm(gradients, MAX_GRAD_NORM)
+ optimizer.apply_gradients(zip(grads, self.module.trainable_variables))
diff --git a/tests/util_test.py b/tests/util_test.py
new file mode 100644
index 0000000..d281722
--- /dev/null
+++ b/tests/util_test.py
@@ -0,0 +1,86 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Tests for utility functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +import pytest + +from dnc import util + +# set seeds for determinism +np.random.seed(42) + + +class BatchInvertPermutation(tf.test.TestCase): + def test(self): + # Tests that the _batch_invert_permutation function correctly inverts a + # batch of permutations. + batch_size = 5 + length = 7 + + permutations = np.empty([batch_size, length], dtype=int) + for i in range(batch_size): + permutations[i] = np.random.permutation(length) + + inverse = util.batch_invert_permutation(tf.constant(permutations, tf.int32)) + inverse = inverse.numpy() + + for i in range(batch_size): + for j in range(length): + self.assertEqual(permutations[i][inverse[i][j]], j) + + +class BatchGather(tf.test.TestCase): + def test(self): + values = np.array([[3, 1, 4, 1], [5, 9, 2, 6], [5, 3, 5, 7]]) + indexs = np.array([[1, 2, 0, 3], [3, 0, 1, 2], [0, 2, 1, 3]]) + target = np.array([[1, 4, 3, 1], [6, 5, 9, 2], [5, 5, 3, 7]]) + result = util.batch_gather(tf.constant(values), tf.constant(indexs)) + result = result.numpy() + self.assertAllEqual(target, result) + + +@pytest.mark.parametrize( + "batch_size, state_size, initial_state", + [ + (2, [], []), + (2, 2, tf.zeros([2, 2], dtype=tf.float32)), + ( + 2, + [tf.TensorShape([1, 3]), 2], + [tf.zeros([2, 1, 3], dtype=tf.float32), tf.zeros([2, 2], dtype=tf.float32)], + ), + ( + 2, + [2, [2, [tf.TensorShape([1, 3])]]], + [ + tf.zeros([2, 2], dtype=tf.float32), + [ + tf.zeros([2, 2], dtype=tf.float32), + [tf.zeros([2, 1, 3], dtype=tf.float32)], + ], + ], + ), + ], +) +def test_initial_state_from_state_size(batch_size, state_size, initial_state): + assert str(initial_state) == str( + util.initial_state_from_state_size(state_size, batch_size, tf.float32) + ) diff --git a/train.py b/train.py index 036daef..e98a5c4 100644 --- a/train.py +++ b/train.py @@ -18,150 +18,302 @@ from __future__ import division from __future__ import print_function +import argparse +import datetime import tensorflow as tf -import sonnet as snt from dnc import dnc from dnc import repeat_copy -FLAGS = tf.flags.FLAGS +parser = argparse.ArgumentParser(description="Train DNC for repeat copy task.") # Model parameters -tf.flags.DEFINE_integer("hidden_size", 64, "Size of LSTM hidden layer.") -tf.flags.DEFINE_integer("memory_size", 16, "The number of memory slots.") -tf.flags.DEFINE_integer("word_size", 16, "The width of each memory slot.") -tf.flags.DEFINE_integer("num_write_heads", 1, "Number of memory write heads.") -tf.flags.DEFINE_integer("num_read_heads", 4, "Number of memory read heads.") -tf.flags.DEFINE_integer("clip_value", 20, - "Maximum absolute value of controller and dnc outputs.") +parser.add_argument( + "--hidden_size", default=64, type=int, help="Size of LSTM hidden layer." +) +parser.add_argument( + "--memory_size", default=16, type=int, help="The number of memory slots." +) +parser.add_argument( + "--word_size", default=16, type=int, help="The width of each memory slot." +) +parser.add_argument( + "--num_write_heads", default=1, type=int, help="Number of memory write heads." +) +parser.add_argument( + "--num_read_heads", default=4, type=int, help="Number of memory read heads." 
+) +parser.add_argument( + "--clip_value", + default=20, + type=int, + help="Maximum absolute value of controller and dnc outputs.", +) # Optimizer parameters. -tf.flags.DEFINE_float("max_grad_norm", 50, "Gradient clipping norm limit.") -tf.flags.DEFINE_float("learning_rate", 1e-4, "Optimizer learning rate.") -tf.flags.DEFINE_float("optimizer_epsilon", 1e-10, - "Epsilon used for RMSProp optimizer.") +parser.add_argument( + "--max_grad_norm", default=50, type=float, help="Gradient clipping norm limit." +) +parser.add_argument( + "--learning_rate", default=1e-4, type=float, help="Optimizer learning rate." +) +parser.add_argument( + "--optimizer_epsilon", + default=1e-10, + type=float, + help="Epsilon used for RMSProp optimizer.", +) # Task parameters -tf.flags.DEFINE_integer("batch_size", 16, "Batch size for training.") -tf.flags.DEFINE_integer("num_bits", 4, "Dimensionality of each vector to copy") -tf.flags.DEFINE_integer( - "min_length", 1, - "Lower limit on number of vectors in the observation pattern to copy") -tf.flags.DEFINE_integer( - "max_length", 2, - "Upper limit on number of vectors in the observation pattern to copy") -tf.flags.DEFINE_integer("min_repeats", 1, - "Lower limit on number of copy repeats.") -tf.flags.DEFINE_integer("max_repeats", 2, - "Upper limit on number of copy repeats.") +parser.add_argument( + "--batch_size", default=16, type=int, help="Batch size for training." +) +parser.add_argument( + "--num_bits", default=8, type=int, help="Dimensionality of each vector to copy" +) +parser.add_argument( + "--min_length", + default=1, + type=int, + help="Lower limit on number of vectors in the observation pattern to copy", +) +parser.add_argument( + "--max_length", + default=3, + type=int, + help="Upper limit on number of vectors in the observation pattern to copy", +) +parser.add_argument( + "--min_repeats", default=1, type=int, help="Lower limit on number of copy repeats." +) +parser.add_argument( + "--max_repeats", default=3, type=int, help="Upper limit on number of copy repeats." +) # Training options. -tf.flags.DEFINE_integer("num_training_iterations", 100000, - "Number of iterations to train for.") -tf.flags.DEFINE_integer("report_interval", 100, - "Iterations between reports (samples, valid loss).") -tf.flags.DEFINE_string("checkpoint_dir", "/tmp/tf/dnc", - "Checkpointing directory.") -tf.flags.DEFINE_integer("checkpoint_interval", -1, - "Checkpointing step interval.") - - -def run_model(input_sequence, output_size): - """Runs model on input sequence.""" - - access_config = { - "memory_size": FLAGS.memory_size, - "word_size": FLAGS.word_size, - "num_reads": FLAGS.num_read_heads, - "num_writes": FLAGS.num_write_heads, - } - controller_config = { - "hidden_size": FLAGS.hidden_size, - } - clip_value = FLAGS.clip_value - - dnc_core = dnc.DNC(access_config, controller_config, output_size, clip_value) - initial_state = dnc_core.initial_state(FLAGS.batch_size) - output_sequence, _ = tf.nn.dynamic_rnn( - cell=dnc_core, - inputs=input_sequence, - time_major=True, - initial_state=initial_state) - - return output_sequence +parser.add_argument( + "--epochs", default=100000, type=int, help="Number of epochs to train for." +) +parser.add_argument( + "--log_dir", default="./logs/repeat_copy", type=str, help="Logging directory." +) +parser.add_argument( + "--report_interval", + default=500, + type=int, + help="Epochs between reports (samples, valid loss).", +) +parser.add_argument( + "--checkpoint_interval", default=-1, type=int, help="Checkpointing step interval." 
+) +parser.add_argument( + "--test_set_size", + default=100, + type=int, + help="Number of datapoints in the test/validation data set.", +) + +FLAGS = parser.parse_args() + + +def train_step(dataset_tensors, rnn_model, optimizer, loss_fn): + return train_step_graphed( + dataset_tensors.observations, + dataset_tensors.target, + dataset_tensors.mask, + rnn_model, + optimizer, + loss_fn, + ) + + +@tf.function +def train_step_graphed( + x, + y, + mask, + rnn_model, + optimizer, + loss_fn, +): + """Runs model on input sequence.""" + initial_state = rnn_model.get_initial_state(x) + with tf.GradientTape() as tape: + output_sequence = rnn_model( + inputs=x, + initial_state=initial_state, + ) + loss_value = loss_fn(output_sequence, y, mask) + grads = tape.gradient(loss_value, rnn_model.trainable_variables) + grads, _ = tf.clip_by_global_norm(grads, FLAGS.max_grad_norm) + optimizer.apply_gradients(zip(grads, rnn_model.trainable_variables)) + return loss_value + + +def test_step(dataset_tensors, rnn_model, optimizer, loss_fn): + return test_step_graphed( + dataset_tensors.observations, + dataset_tensors.target, + dataset_tensors.mask, + rnn_model, + loss_fn, + ) + + +@tf.function +def test_step_graphed( + x, + y, + mask, + rnn_model, + loss_fn, +): + initial_state = rnn_model.get_initial_state(x) + output_sequence = rnn_model( + inputs=x, + initial_state=initial_state, + ) + loss_value = loss_fn(output_sequence, y, mask) + # Used for visualization. + output = tf.round(tf.expand_dims(mask, -1) * tf.sigmoid(output_sequence)) + return loss_value, output def train(num_training_iterations, report_interval): - """Trains the DNC and periodically reports the loss.""" - - dataset = repeat_copy.RepeatCopy(FLAGS.num_bits, FLAGS.batch_size, - FLAGS.min_length, FLAGS.max_length, - FLAGS.min_repeats, FLAGS.max_repeats) - dataset_tensors = dataset() - - output_logits = run_model(dataset_tensors.observations, dataset.target_size) - # Used for visualization. - output = tf.round( - tf.expand_dims(dataset_tensors.mask, -1) * tf.sigmoid(output_logits)) - - train_loss = dataset.cost(output_logits, dataset_tensors.target, - dataset_tensors.mask) - - # Set up optimizer with global norm clipping. - trainable_variables = tf.trainable_variables() - grads, _ = tf.clip_by_global_norm( - tf.gradients(train_loss, trainable_variables), FLAGS.max_grad_norm) - - global_step = tf.get_variable( - name="global_step", - shape=[], - dtype=tf.int64, - initializer=tf.zeros_initializer(), - trainable=False, - collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP]) - - optimizer = tf.train.RMSPropOptimizer( - FLAGS.learning_rate, epsilon=FLAGS.optimizer_epsilon) - train_step = optimizer.apply_gradients( - zip(grads, trainable_variables), global_step=global_step) - - saver = tf.train.Saver() - - if FLAGS.checkpoint_interval > 0: - hooks = [ - tf.train.CheckpointSaverHook( - checkpoint_dir=FLAGS.checkpoint_dir, - save_steps=FLAGS.checkpoint_interval, - saver=saver) - ] - else: - hooks = [] - - # Train. 
- with tf.train.SingularMonitoredSession( - hooks=hooks, checkpoint_dir=FLAGS.checkpoint_dir) as sess: - - start_iteration = sess.run(global_step) - total_loss = 0 - - for train_iteration in range(start_iteration, num_training_iterations): - _, loss = sess.run([train_step, train_loss]) - total_loss += loss - - if (train_iteration + 1) % report_interval == 0: - dataset_tensors_np, output_np = sess.run([dataset_tensors, output]) - dataset_string = dataset.to_human_readable(dataset_tensors_np, - output_np) - tf.logging.info("%d: Avg training loss %f.\n%s", - train_iteration, total_loss / report_interval, - dataset_string) - total_loss = 0 + """Trains the DNC and periodically reports the loss.""" + + train_dataset = repeat_copy.RepeatCopy( + FLAGS.num_bits, + FLAGS.batch_size, + FLAGS.min_length, + FLAGS.max_length, + FLAGS.min_repeats, + FLAGS.max_repeats, + dtype=tf.float32, + ) + # Generate test data with double maximum repeat length + test_dataset = repeat_copy.RepeatCopy( + FLAGS.num_bits, + FLAGS.test_set_size, # FLAGS.batch_size, + FLAGS.min_length, + FLAGS.max_length, + FLAGS.max_repeats * 2, + FLAGS.max_repeats * 2, + dtype=tf.float32, + ) + + dataset_tensor = train_dataset() + test_dataset_tensor = test_dataset() + + access_config = { + "memory_size": FLAGS.memory_size, + "word_size": FLAGS.word_size, + "num_reads": FLAGS.num_read_heads, + "num_writes": FLAGS.num_write_heads, + } + controller_config = { + # snt.LSTM takes hidden_size as parameter + # "hidden_size": FLAGS.hidden_size, + # keras.layers.LSTM takes units as parameter + "units": FLAGS.hidden_size, + } + clip_value = FLAGS.clip_value + + dnc_cell = dnc.DNC( + access_config, + controller_config, + train_dataset.target_size, + FLAGS.batch_size, + clip_value, + ) + dnc_core = tf.keras.layers.RNN( + cell=dnc_cell, + time_major=True, + return_sequences=True, + ) + optimizer = tf.compat.v1.train.RMSPropOptimizer( + FLAGS.learning_rate, epsilon=FLAGS.optimizer_epsilon + ) + loss_fn = train_dataset.cost + + # Set up logging and metrics + train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32) + test_loss = tf.keras.metrics.Mean("test_loss", dtype=tf.float32) + + # current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + train_log_dir = FLAGS.log_dir + "/train" + test_log_dir = FLAGS.log_dir + "/test" + train_summary_writer = tf.summary.create_file_writer(train_log_dir) + test_summary_writer = tf.summary.create_file_writer(test_log_dir) + + # Test once to initialize + graph_log_dir = FLAGS.log_dir + "/graph" + graph_writer = tf.summary.create_file_writer(graph_log_dir) + with graph_writer.as_default(): + tf.summary.trace_on(graph=True, profiler=True) + test_step(dataset_tensor, dnc_core, optimizer, loss_fn) + tf.summary.trace_export(name="dnc_trace", step=0, profiler_outdir=graph_log_dir) + + # Set up model checkpointing + checkpoint = tf.train.Checkpoint(model=dnc_core, optimizer=optimizer) + manager = tf.train.CheckpointManager( + checkpoint, FLAGS.log_dir + "/checkpoint", max_to_keep=10 + ) + + checkpoint.restore(manager.latest_checkpoint) + if manager.latest_checkpoint: + print("Restored from {}".format(manager.latest_checkpoint)) + else: + print("Initializing from scratch.") + + # Train. 
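+ # Each epoch samples a fresh batch from train_dataset. Train/test losses are
+ # written to TensorBoard and printed every report_interval epochs, and a
+ # checkpoint is saved every checkpoint_interval epochs when
+ # --checkpoint_interval > 0.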
+ for epoch in range(num_training_iterations): + dataset_tensor = train_dataset() + train_loss_value = train_step(dataset_tensor, dnc_core, optimizer, loss_fn) + train_loss(train_loss_value) + + # report metrics + if (epoch) % report_interval == 0: + test_loss_value, output = test_step( + test_dataset_tensor, dnc_core, optimizer, test_dataset.cost + ) + test_loss(test_loss_value) + with test_summary_writer.as_default(): + tf.summary.scalar("loss", test_loss.result(), step=epoch) + with train_summary_writer.as_default(): + tf.summary.scalar("loss", train_loss.result(), step=epoch) + + template = "Epoch {}, Loss: {}, Test Loss: {}" + print( + template.format( + epoch, + train_loss.result(), + test_loss.result(), + ) + ) + + dataset_string = test_dataset.to_human_readable( + test_dataset_tensor, output.numpy() + ) + print(dataset_string) + + # reset metrics every report_interval + train_loss.reset_states() + test_loss.reset_states() + + # save model at defined intervals after training begins if enabled + if ( + FLAGS.checkpoint_interval > 0 + and epoch + and epoch % FLAGS.checkpoint_interval == 0 + ): + manager.save() def main(unused_argv): - tf.logging.set_verbosity(3) # Print INFO log messages. - train(FLAGS.num_training_iterations, FLAGS.report_interval) + tf.compat.v1.logging.set_verbosity(3) # Print INFO log messages. + train(FLAGS.epochs, FLAGS.report_interval) if __name__ == "__main__": - tf.app.run() + tf.compat.v1.app.run()
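As a lighter-weight companion to `interactive.ipynb`, the pieces above also compose into a short script for spot-checking a trained model on a hand-written pattern. The sketch below is illustrative rather than part of the patch: it assumes a checkpoint produced by `train.py` with the default flags exists under `./logs/repeat_copy/checkpoint`, and the hyperparameters must match the ones the checkpoint was trained with.

```python
import tensorflow as tf

from dnc import dnc, repeat_copy

# Hyperparameters must match the trained model (defaults from train.py above).
access_config = {"memory_size": 16, "word_size": 16, "num_reads": 4, "num_writes": 1}
controller_config = {"units": 64}
num_bits = 8

cell = dnc.DNC(access_config, controller_config, num_bits + 1, 1, 20)
model = tf.keras.layers.RNN(cell, time_major=True, return_sequences=True)

# Restore the latest checkpoint written by train.py; variables are matched lazily
# on the first call, and expect_partial() ignores the checkpointed optimizer state.
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(
    tf.train.latest_checkpoint("./logs/repeat_copy/checkpoint")
).expect_partial()

# Build obs/targ/mask for a hand-written two-step pattern repeated twice.
pattern = tf.constant([[1.0] * num_bits, [0.0] * num_bits])
obs, targ, mask = repeat_copy.RepeatCopy.derive_data_from_inputs(pattern, num_reps=2)
obs = tf.expand_dims(obs, 1)  # [time, batch=1, num_bits + 2]

logits = model(obs, initial_state=model.get_initial_state(obs))
prediction = tf.round(tf.reshape(mask, [-1, 1, 1]) * tf.sigmoid(logits))
print(prediction.numpy().squeeze())
```

Once training has converged, the masked, rounded outputs should reproduce the input pattern repeated `num_reps` times, matching what the notebook's `debug_model` helper visualizes.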