fixing the API for some functions to be less silly and fixing the RNN
karpathy committed Nov 28, 2014
1 parent 2c99eac commit f399811
Showing 5 changed files with 72 additions and 33 deletions.
3 changes: 2 additions & 1 deletion driver.py
@@ -264,9 +264,10 @@ def costfun(batch, model):
parser.add_argument('--generator', dest='generator', type=str, default='lstm', help='generator to use')
parser.add_argument('-c', '--regc', dest='regc', type=float, default=1e-8, help='regularization strength')
parser.add_argument('--tanhC_version', dest='tanhC_version', type=int, default=0, help='use tanh version of LSTM?')
parser.add_argument('--rnn_relu_encoders', dest='rnn_relu_encoders', type=int, default=0, help='relu encoders before going to RNN?')

# optimization parameters
parser.add_argument('-m', '--max_epochs', dest='max_epochs', type=int, default=20, help='number of epochs to train for')
parser.add_argument('-m', '--max_epochs', dest='max_epochs', type=int, default=50, help='number of epochs to train for')
parser.add_argument('--solver', dest='solver', type=str, default='rmsprop', help='solver type: vanilla/adagrad/adadelta/rmsprop')
parser.add_argument('--momentum', dest='momentum', type=float, default=0.0, help='momentum for vanilla sgd')
parser.add_argument('--decay_rate', dest='decay_rate', type=float, default=0.999, help='decay rate for adadelta/rmsprop')
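For context on how the new flag is consumed downstream: the driver is assumed to convert the parsed namespace into a plain dict (e.g. with vars()), and that dict is the params the RNN generator later queries with params.get('rnn_relu_encoders', 0). A minimal sketch of that flow, not taken from the repo:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--rnn_relu_encoders', dest='rnn_relu_encoders', type=int, default=0)
parser.add_argument('-m', '--max_epochs', dest='max_epochs', type=int, default=50)
args = parser.parse_args([])               # [] -> take the defaults, for illustration
params = vars(args)                        # plain dict handed around the codebase
assert params.get('rnn_relu_encoders', 0) == 0 and params['max_epochs'] == 50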
6 changes: 2 additions & 4 deletions eval_sentence_predictions.py
@@ -68,10 +68,8 @@ def main(params):
n+=1
print 'image %d/%d:' % (n, max_images)
references = [x['tokens'] for x in img['sentences']] # as list of lists of tokens
kwparams = { 'tanhC_version' : checkpoint_params.get('tanhC_version', 0) ,\
'beam_size' : params['beam_size'],\
'generator' : checkpoint_params['generator']}
Ys = BatchGenerator.predict([{'image':img}], model, **kwparams)
kwparams = { 'beam_size' : params['beam_size'] }
Ys = BatchGenerator.predict([{'image':img}], model, checkpoint_params, **kwparams)

img_blob = {} # we will build this up
img_blob['img_path'] = img['local_file_path']
14 changes: 6 additions & 8 deletions imagernn/generic_batch_generator.py
@@ -31,7 +31,8 @@ def init(params, misc):
output_size = len(misc['ixtoword']) # these should match though
image_size = 4096 # size of CNN vectors hardcoded here

assert image_encoding_size == word_encoding_size, 'for now these must match. later with other models these could be different.'
if generator == 'lstm':
assert image_encoding_size == word_encoding_size, 'this implementation does not support different sizes for these parameters'

# initialize the encoder models
model = {}
@@ -84,10 +85,7 @@ def forward(batch, model, params, misc, predict_mode = False):
Xi = Xe[i,:]

# forward prop through the RNN
kwparams = { 'drop_prob_encoder' : params.get('drop_prob_encoder',0.0), \
'drop_prob_decoder' : params.get('drop_prob_decoder',0.0), \
'tanhC_version' : params.get('tanhC_version', 0) }
gen_Y, gen_cache = Generator.forward(Xi, Xs, model, predict_mode = predict_mode, **kwparams)
gen_Y, gen_cache = Generator.forward(Xi, Xs, model, params, predict_mode = predict_mode)
gen_caches.append((ix, gen_cache))
Ys.append(gen_Y)

@@ -138,17 +136,17 @@ def backward(dY, cache):
return grads

@staticmethod
def predict(batch, model, **kwparams):
def predict(batch, model, params, **kwparams):
""" some code duplication here with forward pass, but I think we want the freedom in future """
F = np.row_stack(x['image']['feat'] for x in batch)
We = model['We']
be = model['be']
Xe = F.dot(We) + be # Xe becomes N x image_encoding_size
generator_str = kwparams['generator']
generator_str = params['generator']
Generator = decodeGenerator(generator_str)
Ys = []
for i,x in enumerate(batch):
gen_Y = Generator.predict(Xe[i, :], model, model['Ws'], **kwparams)
gen_Y = Generator.predict(Xe[i, :], model, model['Ws'], params, **kwparams)
Ys.append(gen_Y)
return Ys

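The visible pattern of the API change: checkpoint-level options travel inside the params dict end to end, so intermediaries can pass it through untouched, while per-call choices such as beam_size stay keyword arguments. A tiny self-contained mock of that convention (class and function names here are illustrative, not the repo's):

class ToyGenerator(object):
    @staticmethod
    def predict(Xe_i, model, Ws, params, **kwargs):
        beam_size = kwargs.get('beam_size', 1)          # per-call option
        tanhC_version = params.get('tanhC_version', 0)  # checkpoint-level option
        return {'beam_size': beam_size, 'tanhC_version': tanhC_version}

def batch_predict(batch, model, params, **kwargs):
    # mirrors the shape of the batch generator's predict above:
    # params is threaded through untouched, kwargs carries only beam_size
    return [ToyGenerator.predict(x, model, model.get('Ws'), params, **kwargs)
            for x in batch]

out = batch_predict([None], {'Ws': None}, {'tanhC_version': 1}, beam_size=5)
assert out[0] == {'beam_size': 5, 'tanhC_version': 1}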
12 changes: 6 additions & 6 deletions imagernn/lstm_generator.py
@@ -24,7 +24,7 @@ def init(input_size, hidden_size, output_size):
return { 'model' : model, 'update' : update, 'regularize' : regularize }

@staticmethod
def forward(Xi, Xs, model, **kwargs):
def forward(Xi, Xs, model, params, **kwargs):
"""
Xi is 1-d array of size D (containing the image representation)
Xs is N x D (N time steps, rows are data containing word representations), and
@@ -39,9 +39,9 @@ def forward(Xi, Xs, model, **kwargs):
# options
# use the version of LSTM with tanh? Otherwise don't use tanh (Google style)
# following http://arxiv.org/abs/1409.3215
tanhC_version = kwargs.get('tanhC_version', 0)
drop_prob_encoder = kwargs.get('drop_prob_encoder', 0.0)
drop_prob_decoder = kwargs.get('drop_prob_decoder', 0.0)
tanhC_version = params.get('tanhC_version', 0)
drop_prob_encoder = params.get('drop_prob_encoder', 0.0)
drop_prob_decoder = params.get('drop_prob_decoder', 0.0)

if drop_prob_encoder > 0: # if we want dropout on the encoder
# inverted version of dropout here. Suppose the drop_prob is 0.5, then during training
@@ -191,7 +191,7 @@ def backward(dY, cache):
return { 'WLSTM': dWLSTM, 'Wd': dWd, 'bd': dbd, 'dXi': dX[0,:], 'dXs': dX[1:,:] }

@staticmethod
def predict(Xi, model, Ws, **kwargs):
def predict(Xi, model, Ws, params, **kwargs):
"""
Run in prediction mode with beam search. The input is the vector Xi, which
should be a 1-D array that contains the encoded image vector. We go from there.
@@ -200,7 +200,7 @@ def predict(Xi, model, Ws, **kwargs):
this because we may not want it to be exactly model['Ws']. For example it could be
fixed word vectors from somewhere else.
"""
tanhC_version = kwargs['tanhC_version']
tanhC_version = params['tanhC_version']
beam_size = kwargs.get('beam_size', 1)

WLSTM = model['WLSTM']
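Among the options the LSTM generator now reads from params, tanhC_version is the one that changes the math: it toggles whether the hidden state squashes the cell with tanh (h = o * tanh(c)) or uses the cell directly (h = o * c, the style of http://arxiv.org/abs/1409.3215). A minimal single-step sketch of that toggle; the gate layout and sizes are illustrative, not the repo's exact WLSTM format:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, WLSTM, tanhC_version=0):
    d = h_prev.size
    z = np.concatenate(([1.0], x, h_prev)).dot(WLSTM)   # [bias, input, recurrent]
    i = sigmoid(z[:d])          # input gate
    f = sigmoid(z[d:2*d])       # forget gate
    o = sigmoid(z[2*d:3*d])     # output gate
    g = np.tanh(z[3*d:])        # candidate cell
    c = f * c_prev + i * g
    h = o * np.tanh(c) if tanhC_version else o * c
    return h, c

np.random.seed(3)
D, d = 4, 5
WLSTM = 0.1 * np.random.randn(1 + D + d, 4 * d)
h, c = lstm_step(np.random.randn(D), np.zeros(d), np.zeros(d), WLSTM, tanhC_version=1)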
70 changes: 56 additions & 14 deletions imagernn/rnn_generator.py
@@ -14,19 +14,22 @@ class RNNGenerator:
def init(input_size, hidden_size, output_size):

model = {}
# connections to x_t
model['Wxh'] = initw(input_size, hidden_size)
model['bxh'] = np.zeros((1, hidden_size))
# connections to h_{t-1}
model['Whh'] = initw(hidden_size, hidden_size)
model['bhh'] = np.zeros((1, hidden_size))
# Decoder weights (e.g. mapping to vocabulary)
model['Wd'] = initw(hidden_size, output_size) * 0.01 # decoder
model['Wd'] = initw(hidden_size, output_size) * 0.1 # decoder
model['bd'] = np.zeros((1, output_size))

update = ['Whh', 'bhh', 'Wd', 'bd']
regularize = ['Whh', 'Wd']
update = ['Whh', 'bhh', 'Wxh', 'bxh', 'Wd', 'bd']
regularize = ['Whh', 'Wxh', 'Wd']
return { 'model' : model, 'update' : update, 'regularize' : regularize }

@staticmethod
def forward(Xi, Xs, model, **kwargs):
def forward(Xi, Xs, model, params, **kwargs):
"""
Xi is 1-d array of size D1 (containing the image representation)
Xs is N x D2 (N time steps, rows are data containing word representations), and
@@ -36,8 +39,9 @@ def forward(Xi, Xs, model, **kwargs):
predict_mode = kwargs.get('predict_mode', False)

# options
drop_prob_encoder = kwargs.get('drop_prob_encoder', 0.0)
drop_prob_decoder = kwargs.get('drop_prob_decoder', 0.0)
drop_prob_encoder = params.get('drop_prob_encoder', 0.0)
drop_prob_decoder = params.get('drop_prob_decoder', 0.0)
relu_encoders = params.get('rnn_relu_encoders', 0)

if drop_prob_encoder > 0: # if we want dropout on the encoder
# inverted version of dropout here. Suppose the drop_prob is 0.5, then during training
Expand All @@ -53,6 +57,15 @@ def forward(Xi, Xs, model, **kwargs):
Ui = (np.random.rand(*(Xi.shape)) < (1 - drop_prob_encoder)) * scale
Xi *= Ui # drop!

# encode input vectors
Wxh = model['Wxh']
bxh = model['bxh']
Xsh = Xs.dot(Wxh) + bxh

if relu_encoders:
Xsh = np.maximum(Xsh, 0)
Xi = np.maximum(Xi, 0)

# recurrence iteration for the Multimodal RNN similar to one described in Karpathy et al.
d = model['Wd'].shape[0] # size of hidden layer
n = Xs.shape[0]
@@ -62,8 +75,7 @@
for t in xrange(n):

prev = np.zeros(d) if t == 0 else H[t-1]
ht = Xi + Xs[t] + prev.dot(Whh) + bhh
H[t] = np.maximum(ht, 0) # ReLU nonlinearity
H[t] = np.maximum(Xi + Xsh[t] + prev.dot(Whh) + bhh, 0) # also ReLU

if drop_prob_decoder > 0: # if we want dropout on the decoder
if not predict_mode: # and we are in training mode
@@ -83,6 +95,10 @@ def forward(Xi, Xs, model, **kwargs):
cache['H'] = H
cache['Wd'] = Wd
cache['Xs'] = Xs
cache['Xsh'] = Xsh
cache['Wxh'] = Wxh
cache['Xi'] = Xi
cache['relu_encoders'] = relu_encoders
cache['drop_prob_encoder'] = drop_prob_encoder
cache['drop_prob_decoder'] = drop_prob_decoder
if drop_prob_encoder > 0:
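Putting the forward fix together: word vectors are first mapped into the hidden space through Wxh/bxh (optionally rectified when rnn_relu_encoders is set), and only then enter the ReLU recurrence alongside the image encoding. A self-contained numpy sketch of that path, with made-up sizes, dropout omitted, and the image encoding Xi assumed to already be hidden-sized; note that the word dimension D2 deliberately differs from the hidden size d here, which is why the old image/word encoding-size assert is now needed only for the LSTM:

import numpy as np

np.random.seed(0)
D2, d, V, n = 4, 5, 7, 3                  # word dim, hidden size, vocab size, time steps
Xi  = 0.1 * np.random.randn(d)            # encoded image (already hidden-sized here)
Xs  = 0.1 * np.random.randn(n, D2)        # word vectors for n time steps
Wxh = 0.1 * np.random.randn(D2, d); bxh = np.zeros(d)
Whh = 0.1 * np.random.randn(d, d);  bhh = np.zeros(d)
Wd  = 0.1 * np.random.randn(d, V);  bd  = np.zeros(V)

Xsh = Xs.dot(Wxh) + bxh                   # new step: project words into hidden space
H = np.zeros((n, d))
for t in range(n):
    prev = np.zeros(d) if t == 0 else H[t-1]
    H[t] = np.maximum(Xi + Xsh[t] + prev.dot(Whh) + bhh, 0)   # ReLU recurrence
Y = H.dot(Wd) + bd                        # unnormalized scores over the vocabulary
print(Y.shape)                            # (3, 7)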
@@ -98,9 +114,13 @@ def backward(dY, cache):
Wd = cache['Wd']
H = cache['H']
Xs = cache['Xs']
Xsh = cache['Xsh']
Whh = cache['Whh']
Wxh = cache['Wxh']
Xi = cache['Xi']
drop_prob_encoder = cache['drop_prob_encoder']
drop_prob_decoder = cache['drop_prob_decoder']
relu_encoders = cache['relu_encoders']
n,d = H.shape

# backprop the decoder
@@ -113,36 +133,52 @@ def backward(dY, cache):
dH *= cache['U2']

# backprop the recurrent connections
dXs = np.zeros(Xs.shape)
dXsh = np.zeros(Xsh.shape)
dXi = np.zeros(d)
dWhh = np.zeros(Whh.shape)
dbhh = np.zeros((1,d))
for t in reversed(xrange(n)):
dht = (H[t] > 0) * dH[t] # backprop ReLU
dXi += dht # backprop to Xi
dXs[t] += dht # backprop to word encodings
dXsh[t] += dht # backprop to word encodings
dbhh[0] += dht # backprop to bias

if t > 0:
dH[t-1] += dht.dot(Whh.transpose())
dWhh += np.outer(H[t-1], dht)

if relu_encoders:
# backprop relu
dXsh[Xsh <= 0] = 0
dXi[Xi <= 0] = 0

# backprop the word encoder
dWxh = Xs.transpose().dot(dXsh)
dbxh = np.sum(dXsh, axis=0, keepdims = True)
dXs = dXsh.dot(Wxh.transpose())

if drop_prob_encoder > 0: # backprop encoder dropout
dXi *= cache['Ui']
dXs *= cache['Us']

return { 'Whh': dWhh, 'bhh': dbhh, 'Wd': dWd, 'bd': dbd, 'dXs' : dXs, 'dXi': dXi }
return { 'Whh': dWhh, 'bhh': dbhh, 'Wd': dWd, 'bd': dbd, 'Wxh':dWxh, 'bxh':dbxh, 'dXs' : dXs, 'dXi': dXi }
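The three lines that backprop the new word encoder are plain linear-layer gradients: dWxh = Xs^T dXsh, dbxh sums dXsh over time, and dXs pushes dXsh back through Wxh. A quick self-contained finite-difference check of that rule, using a made-up upstream gradient R in place of the real dXsh:

import numpy as np

np.random.seed(1)
n, D2, d = 3, 4, 5
Xs  = np.random.randn(n, D2)
Wxh = np.random.randn(D2, d)
bxh = np.random.randn(1, d)
R   = np.random.randn(n, d)                # stand-in for the upstream gradient dXsh

loss = lambda Xs_, Wxh_, bxh_: np.sum((Xs_.dot(Wxh_) + bxh_) * R)

# analytic gradients, same form as in backward() above
dWxh = Xs.T.dot(R)
dbxh = np.sum(R, axis=0, keepdims=True)    # the same check works for dbxh and dXs
dXs  = R.dot(Wxh.T)

# central differences on Wxh
eps = 1e-6
num = np.zeros_like(Wxh)
for i in range(D2):
    for j in range(d):
        Wp = Wxh.copy(); Wp[i, j] += eps
        Wm = Wxh.copy(); Wm[i, j] -= eps
        num[i, j] = (loss(Xs, Wp, bxh) - loss(Xs, Wm, bxh)) / (2 * eps)
assert np.allclose(num, dWxh, atol=1e-5)   # agrees up to floating point noise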

@staticmethod
def predict(Xi, model, Ws, **kwargs):
def predict(Xi, model, Ws, params, **kwargs):

beam_size = kwargs.get('beam_size', 1)
relu_encoders = params.get('rnn_relu_encoders', 0)

d = model['Wd'].shape[0] # size of hidden layer
Whh = model['Whh']
bhh = model['bhh']
Wd = model['Wd']
bd = model['bd']
Wxh = model['Wxh']
bxh = model['bxh']

if relu_encoders:
Xi = np.maximum(Xi, 0)

if beam_size > 1:
# perform beam search
@@ -160,7 +196,10 @@ def predict(Xi, model, Ws, **kwargs):
beam_candidates.append(b)
continue
# tick the RNN for this beam
h1 = np.maximum(Xi + Ws[ixprev] + b[2].dot(Whh) + bhh, 0)
Xsh = Ws[ixprev].dot(Wxh) + bxh
if relu_encoders:
Xsh = np.maximum(Xsh, 0)
h1 = np.maximum(Xi + Xsh + b[2].dot(Whh) + bhh, 0)
y1 = h1.dot(Wd) + bd

# compute new candidates that expand out from this beam
@@ -190,7 +229,10 @@ def predict(Xi, model, Ws, **kwargs):
hprev = np.zeros((1, d)) # hidden layer representation
xsprev = Ws[0] # start token
while True:
ht = np.maximum(Xi + Ws[ixprev] + hprev.dot(Whh) + bhh, 0)
Xsh = Ws[ixprev].dot(Wxh) + bxh
if relu_encoders:
Xsh = np.maximum(Xsh, 0)
ht = np.maximum(Xi + Xsh + hprev.dot(Whh) + bhh, 0)
Y = ht.dot(Wd) + bd
hprev = ht

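With the fix, prediction ticks the RNN the same way training does: the previous word vector is pushed through Wxh/bxh (and optionally a ReLU) before entering the recurrence. A toy greedy decode sketching the beam_size == 1 path with random weights; the sizes, the fake vocabulary, and the stop-on-token-0 convention are assumptions for the sketch, not the repo's data:

import numpy as np

np.random.seed(2)
d, D2, V = 6, 4, 5
Ws  = 0.1 * np.random.randn(V, D2)        # word vector table handed to predict()
Wxh = 0.1 * np.random.randn(D2, d); bxh = np.zeros(d)
Whh = 0.1 * np.random.randn(d, d);  bhh = np.zeros(d)
Wd  = 0.1 * np.random.randn(d, V);  bd  = np.zeros(V)
Xi  = 0.1 * np.random.randn(d)            # encoded image

ixprev, hprev, sent = 0, np.zeros(d), []   # start token 0, empty hidden state
for _ in range(20):                        # hard cap on caption length
    Xsh = Ws[ixprev].dot(Wxh) + bxh        # encode the previous word (the fix)
    ht  = np.maximum(Xi + Xsh + hprev.dot(Whh) + bhh, 0)
    Y   = ht.dot(Wd) + bd
    ixprev, hprev = int(np.argmax(Y)), ht  # greedy pick
    if ixprev == 0:                        # stop token (assumed convention)
        break
    sent.append(ixprev)
print(sent)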
