diff --git a/.gitignore b/.gitignore index 031771e..b9a9d94 100644 --- a/.gitignore +++ b/.gitignore @@ -96,4 +96,8 @@ ENV/ /site # mypy -.mypy_cache/ \ No newline at end of file +.mypy_cache/ + +# vscode and its extensions +.vscode/* +.history/* \ No newline at end of file diff --git a/dnc/__init__.py b/dnc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/access.py b/dnc/access.py similarity index 99% rename from access.py rename to dnc/access.py index f4a8433..211d454 100644 --- a/access.py +++ b/dnc/access.py @@ -22,8 +22,8 @@ import sonnet as snt import tensorflow as tf -import addressing -import util +from dnc import addressing +from dnc import util AccessState = collections.namedtuple('AccessState', ( 'memory', 'read_weights', 'write_weights', 'linkage', 'usage')) @@ -53,7 +53,7 @@ def _erase_and_write(memory, address, reset_weights, values): expand_address = tf.expand_dims(address, 3) reset_weights = tf.expand_dims(reset_weights, 2) weighted_resets = expand_address * reset_weights - reset_gate = tf.reduce_prod(1 - weighted_resets, [1]) + reset_gate = util.reduce_prod(1 - weighted_resets, 1) memory *= reset_gate with tf.name_scope('additive_write', values=[memory, address, values]): diff --git a/access_test.py b/dnc/access_test.py similarity index 99% rename from access_test.py rename to dnc/access_test.py index 79df2f4..20fe7c2 100644 --- a/access_test.py +++ b/dnc/access_test.py @@ -22,8 +22,8 @@ import tensorflow as tf from tensorflow.python.ops import rnn -import access -import util +from dnc import access +from dnc import util BATCH_SIZE = 2 MEMORY_SIZE = 20 diff --git a/addressing.py b/dnc/addressing.py similarity index 98% rename from addressing.py rename to dnc/addressing.py index 77a88e8..97365b1 100644 --- a/addressing.py +++ b/dnc/addressing.py @@ -22,7 +22,7 @@ import sonnet as snt import tensorflow as tf -import util +from dnc import util # Ensure values are greater than epsilon to avoid numerical instability. _EPSILON = 1e-6 @@ -32,7 +32,7 @@ def _vector_norms(m): - squared_norms = tf.reduce_sum(m * m, axis=2, keep_dims=True) + squared_norms = tf.reduce_sum(m * m, axis=2, keepdims=True) return tf.sqrt(squared_norms + _EPSILON) @@ -202,7 +202,7 @@ def _link(self, prev_link, prev_precedence_weights, write_weights): containing the new link graphs for each write head. """ with tf.name_scope('link'): - batch_size = prev_link.get_shape()[0].value + batch_size = tf.shape(prev_link)[0] write_weights_i = tf.expand_dims(write_weights, 3) write_weights_j = tf.expand_dims(write_weights, 2) prev_precedence_weights_j = tf.expand_dims(prev_precedence_weights, 2) @@ -236,7 +236,7 @@ def _precedence_weights(self, prev_precedence_weights, write_weights): new precedence weights. """ with tf.name_scope('precedence_weights'): - write_sum = tf.reduce_sum(write_weights, 2, keep_dims=True) + write_sum = tf.reduce_sum(write_weights, 2, keepdims=True) return (1 - write_sum) * prev_precedence_weights + write_weights @property @@ -351,7 +351,7 @@ def _usage_after_write(self, prev_usage, write_weights): """ with tf.name_scope('usage_after_write'): # Calculate the aggregated effect of all write heads - write_weights = 1 - tf.reduce_prod(1 - write_weights, [1]) + write_weights = 1 - util.reduce_prod(1 - write_weights, 1) return prev_usage + (1 - prev_usage) * write_weights def _usage_after_read(self, prev_usage, free_gate, read_weights): @@ -370,7 +370,7 @@ def _usage_after_read(self, prev_usage, free_gate, read_weights): with tf.name_scope('usage_after_read'): free_gate = tf.expand_dims(free_gate, -1) free_read_weights = free_gate * read_weights - phi = tf.reduce_prod(1 - free_read_weights, [1], name='phi') + phi = util.reduce_prod(1 - free_read_weights, 1, name='phi') return prev_usage * phi def _allocation(self, usage): diff --git a/addressing_test.py b/dnc/addressing_test.py similarity index 99% rename from addressing_test.py rename to dnc/addressing_test.py index d8f803d..a8a8ac4 100644 --- a/addressing_test.py +++ b/dnc/addressing_test.py @@ -22,8 +22,8 @@ import sonnet as snt import tensorflow as tf -import addressing -import util +from dnc import addressing +from dnc import util class WeightedSoftmaxTest(tf.test.TestCase): diff --git a/dnc.py b/dnc/dnc.py similarity index 97% rename from dnc.py rename to dnc/dnc.py index 8df92cf..db14b2a 100644 --- a/dnc.py +++ b/dnc/dnc.py @@ -27,7 +27,7 @@ import sonnet as snt import tensorflow as tf -import access +from dnc import access DNCState = collections.namedtuple('DNCState', ('access_output', 'access_state', 'controller_state')) @@ -110,7 +110,7 @@ def _build(self, inputs, prev_state): controller_input, prev_controller_state) controller_output = self._clip_if_enabled(controller_output) - controller_state = snt.nest.map(self._clip_if_enabled, controller_state) + controller_state = tf.contrib.framework.nest.map_structure(self._clip_if_enabled, controller_state) access_output, access_state = self._access(controller_output, prev_access_state) diff --git a/images/dnc_model.png b/dnc/images/dnc_model.png similarity index 100% rename from images/dnc_model.png rename to dnc/images/dnc_model.png diff --git a/repeat_copy.py b/dnc/repeat_copy.py similarity index 100% rename from repeat_copy.py rename to dnc/repeat_copy.py diff --git a/util.py b/dnc/util.py similarity index 50% rename from util.py rename to dnc/util.py index ce35290..5009c77 100644 --- a/util.py +++ b/dnc/util.py @@ -25,17 +25,31 @@ def batch_invert_permutation(permutations): """Returns batched `tf.invert_permutation` for every row in `permutations`.""" with tf.name_scope('batch_invert_permutation', values=[permutations]): - unpacked = tf.unstack(permutations) - inverses = [tf.invert_permutation(permutation) for permutation in unpacked] - return tf.stack(inverses) + perm = tf.cast(permutations, tf.float32) + dim = int(perm.get_shape()[-1]) + size = tf.cast(tf.shape(perm)[0], tf.float32) + delta = tf.cast(tf.shape(perm)[-1], tf.float32) + rg = tf.range(0, size * delta, delta, dtype=tf.float32) + rg = tf.expand_dims(rg, 1) + rg = tf.tile(rg, [1, dim]) + perm = tf.add(perm, rg) + flat = tf.reshape(perm, [-1]) + perm = tf.invert_permutation(tf.cast(flat, tf.int32)) + perm = tf.reshape(perm, [-1, dim]) + return tf.subtract(perm, tf.cast(rg, tf.int32)) def batch_gather(values, indices): """Returns batched `tf.gather` for every row in the input.""" with tf.name_scope('batch_gather', values=[values, indices]): - unpacked = zip(tf.unstack(values), tf.unstack(indices)) - result = [tf.gather(value, index) for value, index in unpacked] - return tf.stack(result) + idx = tf.expand_dims(indices, -1) + size = tf.shape(indices)[0] + rg = tf.range(size, dtype=tf.int32) + rg = tf.expand_dims(rg, -1) + rg = tf.tile(rg, [1, int(indices.get_shape()[-1])]) + rg = tf.expand_dims(rg, -1) + gidx = tf.concat([rg, idx], -1) + return tf.gather_nd(values, gidx) def one_hot(length, index): @@ -43,3 +57,16 @@ def one_hot(length, index): result = np.zeros(length) result[index] = 1 return result + +def reduce_prod(x, axis, name=None): + """Efficient reduce product over axis. + + Uses tf.cumprod and tf.gather_nd as a workaround to the poor performance of calculating tf.reduce_prod's gradient on CPU. + """ + with tf.name_scope(name, 'util_reduce_prod', values=[x]): + cp = tf.cumprod(x, axis, reverse=True) + size = tf.shape(cp)[0] + idx1 = tf.range(tf.cast(size, tf.float32), dtype=tf.float32) + idx2 = tf.zeros([size], tf.float32) + indices = tf.stack([idx1, idx2], 1) + return tf.gather_nd(cp, tf.cast(indices, tf.int32)) diff --git a/util_test.py b/dnc/util_test.py similarity index 98% rename from util_test.py rename to dnc/util_test.py index 8cac46c..55e3f25 100644 --- a/util_test.py +++ b/dnc/util_test.py @@ -21,7 +21,7 @@ import numpy as np import tensorflow as tf -import util +from dnc import util class BatchInvertPermutation(tf.test.TestCase): diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..a5c3ee4 --- /dev/null +++ b/setup.py @@ -0,0 +1,12 @@ +from setuptools import setup + +setup( + name='dnc', + version='0.0.2', + description='This package provides an implementation of the Differentiable Neural Computer, as published in Nature.', + license='Apache Software License 2.0', + packages=['dnc'], + author='DeepMind', + keywords=['tensorflow', 'differentiable neural computer', 'dnc', 'deepmind', 'deep mind', 'sonnet', 'dm-sonnet', 'machine learning'], + url='https://github.com/deepmind/dnc' +) diff --git a/train.py b/train.py index a2ca1ad..036daef 100644 --- a/train.py +++ b/train.py @@ -21,8 +21,8 @@ import tensorflow as tf import sonnet as snt -import dnc -import repeat_copy +from dnc import dnc +from dnc import repeat_copy FLAGS = tf.flags.FLAGS