Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,8 @@ ENV/
/site

# mypy
.mypy_cache/
.mypy_cache/

# vscode and its extensions
.vscode/*
.history/*
Empty file added dnc/__init__.py
Empty file.
6 changes: 3 additions & 3 deletions access.py → dnc/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import sonnet as snt
import tensorflow as tf

import addressing
import util
from dnc import addressing
from dnc import util

AccessState = collections.namedtuple('AccessState', (
'memory', 'read_weights', 'write_weights', 'linkage', 'usage'))
Expand Down Expand Up @@ -53,7 +53,7 @@ def _erase_and_write(memory, address, reset_weights, values):
expand_address = tf.expand_dims(address, 3)
reset_weights = tf.expand_dims(reset_weights, 2)
weighted_resets = expand_address * reset_weights
reset_gate = tf.reduce_prod(1 - weighted_resets, [1])
reset_gate = util.reduce_prod(1 - weighted_resets, 1)
memory *= reset_gate

with tf.name_scope('additive_write', values=[memory, address, values]):
Expand Down
4 changes: 2 additions & 2 deletions access_test.py → dnc/access_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import tensorflow as tf
from tensorflow.python.ops import rnn

import access
import util
from dnc import access
from dnc import util

BATCH_SIZE = 2
MEMORY_SIZE = 20
Expand Down
12 changes: 6 additions & 6 deletions addressing.py → dnc/addressing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import sonnet as snt
import tensorflow as tf

import util
from dnc import util

# Ensure values are greater than epsilon to avoid numerical instability.
_EPSILON = 1e-6
Expand All @@ -32,7 +32,7 @@


def _vector_norms(m):
squared_norms = tf.reduce_sum(m * m, axis=2, keep_dims=True)
squared_norms = tf.reduce_sum(m * m, axis=2, keepdims=True)
return tf.sqrt(squared_norms + _EPSILON)


Expand Down Expand Up @@ -202,7 +202,7 @@ def _link(self, prev_link, prev_precedence_weights, write_weights):
containing the new link graphs for each write head.
"""
with tf.name_scope('link'):
batch_size = prev_link.get_shape()[0].value
batch_size = tf.shape(prev_link)[0]
write_weights_i = tf.expand_dims(write_weights, 3)
write_weights_j = tf.expand_dims(write_weights, 2)
prev_precedence_weights_j = tf.expand_dims(prev_precedence_weights, 2)
Expand Down Expand Up @@ -236,7 +236,7 @@ def _precedence_weights(self, prev_precedence_weights, write_weights):
new precedence weights.
"""
with tf.name_scope('precedence_weights'):
write_sum = tf.reduce_sum(write_weights, 2, keep_dims=True)
write_sum = tf.reduce_sum(write_weights, 2, keepdims=True)
return (1 - write_sum) * prev_precedence_weights + write_weights

@property
Expand Down Expand Up @@ -351,7 +351,7 @@ def _usage_after_write(self, prev_usage, write_weights):
"""
with tf.name_scope('usage_after_write'):
# Calculate the aggregated effect of all write heads
write_weights = 1 - tf.reduce_prod(1 - write_weights, [1])
write_weights = 1 - util.reduce_prod(1 - write_weights, 1)
return prev_usage + (1 - prev_usage) * write_weights

def _usage_after_read(self, prev_usage, free_gate, read_weights):
Expand All @@ -370,7 +370,7 @@ def _usage_after_read(self, prev_usage, free_gate, read_weights):
with tf.name_scope('usage_after_read'):
free_gate = tf.expand_dims(free_gate, -1)
free_read_weights = free_gate * read_weights
phi = tf.reduce_prod(1 - free_read_weights, [1], name='phi')
phi = util.reduce_prod(1 - free_read_weights, 1, name='phi')
return prev_usage * phi

def _allocation(self, usage):
Expand Down
4 changes: 2 additions & 2 deletions addressing_test.py → dnc/addressing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import sonnet as snt
import tensorflow as tf

import addressing
import util
from dnc import addressing
from dnc import util


class WeightedSoftmaxTest(tf.test.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions dnc.py → dnc/dnc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import sonnet as snt
import tensorflow as tf

import access
from dnc import access

DNCState = collections.namedtuple('DNCState', ('access_output', 'access_state',
'controller_state'))
Expand Down Expand Up @@ -110,7 +110,7 @@ def _build(self, inputs, prev_state):
controller_input, prev_controller_state)

controller_output = self._clip_if_enabled(controller_output)
controller_state = snt.nest.map(self._clip_if_enabled, controller_state)
controller_state = tf.contrib.framework.nest.map_structure(self._clip_if_enabled, controller_state)

access_output, access_state = self._access(controller_output,
prev_access_state)
Expand Down
File renamed without changes
File renamed without changes.
39 changes: 33 additions & 6 deletions util.py → dnc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,48 @@
def batch_invert_permutation(permutations):
"""Returns batched `tf.invert_permutation` for every row in `permutations`."""
with tf.name_scope('batch_invert_permutation', values=[permutations]):
unpacked = tf.unstack(permutations)
inverses = [tf.invert_permutation(permutation) for permutation in unpacked]
return tf.stack(inverses)
perm = tf.cast(permutations, tf.float32)
dim = int(perm.get_shape()[-1])
size = tf.cast(tf.shape(perm)[0], tf.float32)
delta = tf.cast(tf.shape(perm)[-1], tf.float32)
rg = tf.range(0, size * delta, delta, dtype=tf.float32)
rg = tf.expand_dims(rg, 1)
rg = tf.tile(rg, [1, dim])
perm = tf.add(perm, rg)
flat = tf.reshape(perm, [-1])
perm = tf.invert_permutation(tf.cast(flat, tf.int32))
perm = tf.reshape(perm, [-1, dim])
return tf.subtract(perm, tf.cast(rg, tf.int32))


def batch_gather(values, indices):
"""Returns batched `tf.gather` for every row in the input."""
with tf.name_scope('batch_gather', values=[values, indices]):
unpacked = zip(tf.unstack(values), tf.unstack(indices))
result = [tf.gather(value, index) for value, index in unpacked]
return tf.stack(result)
idx = tf.expand_dims(indices, -1)
size = tf.shape(indices)[0]
rg = tf.range(size, dtype=tf.int32)
rg = tf.expand_dims(rg, -1)
rg = tf.tile(rg, [1, int(indices.get_shape()[-1])])
rg = tf.expand_dims(rg, -1)
gidx = tf.concat([rg, idx], -1)
return tf.gather_nd(values, gidx)


def one_hot(length, index):
"""Return an nd array of given `length` filled with 0s and a 1 at `index`."""
result = np.zeros(length)
result[index] = 1
return result

def reduce_prod(x, axis, name=None):
"""Efficient reduce product over axis.

Uses tf.cumprod and tf.gather_nd as a workaround to the poor performance of calculating tf.reduce_prod's gradient on CPU.
"""
with tf.name_scope(name, 'util_reduce_prod', values=[x]):
cp = tf.cumprod(x, axis, reverse=True)
size = tf.shape(cp)[0]
idx1 = tf.range(tf.cast(size, tf.float32), dtype=tf.float32)
idx2 = tf.zeros([size], tf.float32)
indices = tf.stack([idx1, idx2], 1)
return tf.gather_nd(cp, tf.cast(indices, tf.int32))
2 changes: 1 addition & 1 deletion util_test.py → dnc/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
import tensorflow as tf

import util
from dnc import util


class BatchInvertPermutation(tf.test.TestCase):
Expand Down
12 changes: 12 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from setuptools import setup

setup(
name='dnc',
version='0.0.2',
description='This package provides an implementation of the Differentiable Neural Computer, as published in Nature.',
license='Apache Software License 2.0',
packages=['dnc'],
author='DeepMind',
keywords=['tensorflow', 'differentiable neural computer', 'dnc', 'deepmind', 'deep mind', 'sonnet', 'dm-sonnet', 'machine learning'],
url='https://github.com/deepmind/dnc'
)
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import tensorflow as tf
import sonnet as snt

import dnc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you want this import in, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I just missed this one. I'm short of budgets and all machines(only 2 actually) at my disposal are busy running something, so I didn't bother to run the script...

Now it seems able to start and run smoothly, but I don't have the computer power to run it through.

If there's any other issues, please don't hesitate to point them out.

import repeat_copy
from dnc import dnc
from dnc import repeat_copy

FLAGS = tf.flags.FLAGS

Expand Down