-
Notifications
You must be signed in to change notification settings - Fork 55
Description
Hi, I tried to run the code with TensorFlow 1.15, but I get an error in the "Construct model and training ops" section.
code:
`tf.reset_default_graph()
data = sample_mog(params['batch_size'])
noise = ds.Normal(tf.zeros(params['z_dim']),
tf.ones(params['z_dim'])).sample(params['batch_size'])
with slim.arg_scope([slim.fully_connected], weights_initializer=tf.orthogonal_initializer(gain=1.4)):
samples = generator(noise, output_dim=params['x_dim'])
real_score = discriminator(data)
fake_score = discriminator(samples, reuse=True)
loss = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=real_score, labels=tf.ones_like(real_score)) +
tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_score, labels=tf.zeros_like(fake_score)))
gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")
d_opt = Adam(lr=params['disc_learning_rate'], beta_1=params['beta1'], epsilon=params['epsilon'])
updates = d_opt.get_updates(disc_vars, [], loss)
d_train_op = tf.group(*updates, name="d_train_op")
if params['unrolling_steps'] > 0:
update_dict = extract_update_dict(updates)
cur_update_dict = update_dict
for i in xrange(params['unrolling_steps'] - 1):
cur_update_dict = graph_replace(update_dict, cur_update_dict)
unrolled_loss = graph_replace(loss, cur_update_dict)
else:
unrolled_loss = loss
g_train_opt = tf.train.AdamOptimizer(params['gen_learning_rate'], beta1=params['beta1'], epsilon=params['epsilon'])
g_train_op = g_train_opt.minimize(-unrolled_loss, var_list=gen_vars)`
Error details:
NameError Traceback (most recent call last)
in ()
27 if params['unrolling_steps'] > 0:
28 # Get dictionary mapping from variables to their update value after one optimization step
---> 29 update_dict = extract_update_dict(updates)
30 cur_update_dict = update_dict
31 for i in xrange(params['unrolling_steps'] - 1):
in extract_update_dict(update_ops)
19 updates[var.value()] = var + value
20 else:
---> 21 raise ValueError("Update op type (%s) must be of type Assign or AssignAdd"%update_op.op.type)
22 return updates
NameError: global name 'update_op' is not defined
How can I solve this error?