
Commit 4b9660a

Craigacp authored and karllessard committed
Adding a control dependency on the gradients to the gradient optimizers.
This improves determinism and makes the gradients be computed correctly for unclear reasons. (#520)

Co-authored-by: Nicolas Feybesse ([email protected])
1 parent a7bb135 commit 4b9660a
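
The mechanism, visible in the Optimizer.java hunk below, is to wrap the graph's `Ops` with `tf.withControlDependencies(gradients)` and to build every variable update through that wrapped instance, so no update op can be scheduled before its gradients exist. A minimal sketch of the pattern, with illustrative names (`weight`, `gradient`, `deps`) and a constant standing in for a real gradient tensor, might look like this:

```java
import java.util.Arrays;
import java.util.List;

import org.tensorflow.Graph;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.TFloat32;

/** Sketch of the control-dependency pattern used by this commit; names are illustrative. */
public class ControlDependencySketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);

      Variable<TFloat32> weight = tf.variable(Shape.scalar(), TFloat32.class);
      // Stand-in for a gradient; the real optimizers receive theirs from
      // Optimizer.computeGradients, which this commit does not change.
      Constant<TFloat32> gradient = tf.constant(0.25f);

      // Every op built through `deps` carries a control dependency on the
      // gradient op, so the update cannot run before the gradient is computed.
      List<Op> controls = Arrays.asList(gradient);
      Ops deps = tf.withControlDependencies(controls);
      Op update =
          deps.assign(weight, deps.math.sub(weight, deps.math.mul(deps.constant(0.1f), gradient)));
    }
  }
}
```

The per-optimizer diffs below are mechanical: each `applyDense` gains an `Ops deps` parameter and builds its ops through `deps` instead of the optimizer's plain `tf` field.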

File tree

12 files changed: +90 −82 lines changed

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java

+6 −5

@@ -20,6 +20,7 @@
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.op.train.ApplyAdadelta;
 import org.tensorflow.types.family.TType;
@@ -150,16 +151,16 @@ private <T extends TType> void createAdaDeltaSlot(Output<T> v) {

   /** {@inheritDoc} */
   @Override
-  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+  protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
     Variable<T> accumSlot = getSlot(variable, ACCUMULATOR).get();
     Variable<T> accumUpdateSlot = getSlot(variable, ACCUMULATOR_UPDATE).get();
-    return tf.train.applyAdadelta(
+    return deps.train.applyAdadelta(
         variable,
         accumSlot,
         accumUpdateSlot,
-        tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
-        tf.dtypes.cast(tf.constant(rho), gradient.type()),
-        tf.dtypes.cast(tf.constant(epsilon), gradient.type()),
+        deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
+        deps.dtypes.cast(deps.constant(rho), gradient.type()),
+        deps.dtypes.cast(deps.constant(epsilon), gradient.type()),
         gradient,
         ApplyAdadelta.useLocking(true));
   }

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java

+4 −3

@@ -20,6 +20,7 @@
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.op.train.ApplyAdagrad;
 import org.tensorflow.types.family.TType;
@@ -140,10 +141,10 @@ private <T extends TType> void createAdaGradSlot(Output<T> v) {

   /** {@inheritDoc} */
   @Override
-  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+  protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
     Variable<T> slot = getSlot(variable, ACCUMULATOR).get();
-    return tf.train.applyAdagrad(
-        variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient, opts);
+    return deps.train.applyAdagrad(
+        variable, slot, deps.dtypes.cast(deps.constant(learningRate), gradient.type()), gradient, opts);
   }

   /** {@inheritDoc} */

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java

+6 −5

@@ -22,6 +22,7 @@
 import org.tensorflow.Output;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.op.train.ApplyAdagradDa;
 import org.tensorflow.types.TInt64;
@@ -209,17 +210,17 @@ private <T extends TType> void createAdaGradDASlot(Output<T> v) {

   /** {@inheritDoc} */
   @Override
-  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+  protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
     Variable<T> gradSlot = getSlot(variable, ACCUMULATOR).get();
     Variable<T> gradSquaredSlot = getSlot(variable, SQUARED_ACCUMULATOR).get();
-    return tf.train.applyAdagradDa(
+    return deps.train.applyAdagradDa(
        variable,
        gradSlot,
        gradSquaredSlot,
        gradient,
-        tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
-        tf.dtypes.cast(tf.constant(l1Strength), gradient.type()),
-        tf.dtypes.cast(tf.constant(l2Strength), gradient.type()),
+        deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
+        deps.dtypes.cast(deps.constant(l1Strength), gradient.type()),
+        deps.dtypes.cast(deps.constant(l2Strength), gradient.type()),
        globalStep,
        ApplyAdagradDa.useLocking(true));
   }

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java

+9 −8

@@ -22,6 +22,7 @@
 import org.tensorflow.Output;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
 import org.tensorflow.op.Scope;
 import org.tensorflow.op.annotation.Endpoint;
 import org.tensorflow.op.annotation.Operator;
@@ -223,19 +224,19 @@ private <T extends TType> void createAdamSlot(Output<T> v) {

   /** {@inheritDoc} */
   @Override
-  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+  protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
     Variable<T> firstMomentSlot = getSlot(variable, FIRST_MOMENT).get();
     Variable<T> secondMomentSlot = getSlot(variable, SECOND_MOMENT).get();
-    return tf.train.applyAdam(
+    return deps.train.applyAdam(
        variable,
        firstMomentSlot,
        secondMomentSlot,
-        tf.dtypes.cast(betaOnePower, gradient.type()),
-        tf.dtypes.cast(betaTwoPower, gradient.type()),
-        tf.dtypes.cast(learningRateConst, gradient.type()),
-        tf.dtypes.cast(betaOneConst, gradient.type()),
-        tf.dtypes.cast(betaTwoConst, gradient.type()),
-        tf.dtypes.cast(epsilonConst, gradient.type()),
+        deps.dtypes.cast(betaOnePower, gradient.type()),
+        deps.dtypes.cast(betaTwoPower, gradient.type()),
+        deps.dtypes.cast(learningRateConst, gradient.type()),
+        deps.dtypes.cast(betaOneConst, gradient.type()),
+        deps.dtypes.cast(betaTwoConst, gradient.type()),
+        deps.dtypes.cast(epsilonConst, gradient.type()),
        gradient,
        ApplyAdam.useLocking(true));
   }

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java

+8 −7

@@ -7,6 +7,7 @@
 import org.tensorflow.Output;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Constant;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.op.train.ApplyAdaMax;
@@ -155,19 +156,19 @@ private <T extends TType> void createAdamaxSlot(Output<T> v) {

   /** {@inheritDoc} */
   @Override
-  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+  protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
     Variable<T> firstMomentSlot = getSlot(variable, FIRST_MOMENT).get();
     Variable<T> secondMomentSlot = getSlot(variable, SECOND_MOMENT).get();
     return ApplyAdaMax.create(
-        this.tf.scope(),
+        deps.scope(),
        variable,
        firstMomentSlot,
        secondMomentSlot,
-        tf.dtypes.cast(betaOnePower, gradient.type()),
-        tf.dtypes.cast(learningRateConst, gradient.type()),
-        tf.dtypes.cast(betaOneConst, gradient.type()),
-        tf.dtypes.cast(betaTwoConst, gradient.type()),
-        tf.dtypes.cast(epsilonConst, gradient.type()),
+        deps.dtypes.cast(betaOnePower, gradient.type()),
+        deps.dtypes.cast(learningRateConst, gradient.type()),
+        deps.dtypes.cast(betaOneConst, gradient.type()),
+        deps.dtypes.cast(betaTwoConst, gradient.type()),
+        deps.dtypes.cast(epsilonConst, gradient.type()),
        gradient,
        ApplyAdaMax.useLocking(true));
   }

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java

+9 −8

@@ -5,6 +5,7 @@
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.op.train.ApplyFtrl;
 import org.tensorflow.types.family.TType;
@@ -238,21 +239,21 @@ private <T extends TType> void createFtrlSlot(Output<T> v) {

   /** {@inheritDoc} */
   @Override
-  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+  protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
     Variable<T> accumSlot = getSlot(variable, ACCUMULATOR).get();
     Variable<T> linearSlot = getSlot(variable, LINEAR_ACCUMULATOR).get();
     ApplyFtrl.Options options = ApplyFtrl.useLocking(true);
-    return this.tf.train.applyFtrl(
+    return deps.train.applyFtrl(
        variable,
        accumSlot, // accum
        linearSlot, // linear
        gradient, // gradient
-        tf.dtypes.cast(tf.constant(learningRate), gradient.type()), // lr
-        tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.type()), // l1
-        tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.type()), // l2
-        tf.dtypes.cast(
-            tf.constant(l2ShrinkageRegularizationStrength), gradient.type()), // l2Shrinkage
-        tf.dtypes.cast(tf.constant(learningRatePower), gradient.type()), // lrPower
+        deps.dtypes.cast(deps.constant(learningRate), gradient.type()), // lr
+        deps.dtypes.cast(deps.constant(l1RegularizationStrength), gradient.type()), // l1
+        deps.dtypes.cast(deps.constant(l2RegularizationStrength), gradient.type()), // l2
+        deps.dtypes.cast(
+            deps.constant(l2ShrinkageRegularizationStrength), gradient.type()), // l2Shrinkage
+        deps.dtypes.cast(deps.constant(learningRatePower), gradient.type()), // lrPower
        options);
   }

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java

+4 −3

@@ -18,6 +18,7 @@
 import org.tensorflow.Graph;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
 import org.tensorflow.op.train.ApplyGradientDescent;
 import org.tensorflow.types.family.TType;

@@ -65,10 +66,10 @@ public GradientDescent(Graph graph, String name, float learningRate) {

   /** {@inheritDoc} */
   @Override
-  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
-    return tf.train.applyGradientDescent(
+  protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
+    return deps.train.applyGradientDescent(
        variable,
-        tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
+        deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
        gradient,
        ApplyGradientDescent.useLocking(true));
   }

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java

+5 −4

@@ -20,6 +20,7 @@
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.op.train.ApplyMomentum;
 import org.tensorflow.types.family.TType;
@@ -130,14 +131,14 @@ private <T extends TType> void createMomentumSlot(Output<T> v) {

   /** {@inheritDoc} */
   @Override
-  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+  protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
     Variable<T> slot = getSlot(variable, MOMENTUM).get();
-    return tf.train.applyMomentum(
+    return deps.train.applyMomentum(
        variable,
        slot,
-        tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
+        deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
        gradient,
-        tf.dtypes.cast(tf.constant(momentum), gradient.type()),
+        deps.dtypes.cast(deps.constant(momentum), gradient.type()),
        ApplyMomentum.useNesterov(useNesterov),
        ApplyMomentum.useLocking(true));
   }

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java

+21 −20

@@ -7,6 +7,7 @@
 import org.tensorflow.Output;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Assign;
 import org.tensorflow.op.core.Constant;
 import org.tensorflow.op.core.Variable;
@@ -224,53 +225,53 @@ protected Optional<Op> prepare(String scopeName) {

   /** {@inheritDoc} */
   @Override
-  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+  protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
     Class<T> type = gradient.type();
     Variable<T> m = getSlot(variable, FIRST_MOMENT).get(); // first Moment
     Variable<T> v = getSlot(variable, SECOND_MOMENT).get(); // Second Moment

     // gPrime = grad / coefficients['oneMinusMScheduleNew']
-    Operand<T> gPrime = tf.math.div(gradient, tf.dtypes.cast(oneMinusMScheduleNew, type));
+    Operand<T> gPrime = deps.math.div(gradient, deps.dtypes.cast(oneMinusMScheduleNew, type));
     // mT = (coefficients['beta_1_t'] * m + coefficients['one_minus_beta_1_t'] * grad)
     Operand<T> mT =
-        tf.math.add(
-            tf.math.mul(tf.dtypes.cast(betaOneConst, type), m),
-            tf.math.mul(tf.dtypes.cast(oneMinusBeta1, type), gradient));
+        deps.math.add(
+            deps.math.mul(deps.dtypes.cast(betaOneConst, type), m),
+            deps.math.mul(deps.dtypes.cast(oneMinusBeta1, type), gradient));
     // mT = state_ops.assign(m, mT, use_locking=self._use_locking)
     // update m
-    mT = tf.assign(m, mT, Assign.useLocking(true));
+    mT = deps.assign(m, mT, Assign.useLocking(true));

     // mTPrime = mT / coefficients['oneMinusMScheduleNext']
-    Operand<T> mTPrime = tf.math.div(mT, tf.dtypes.cast(oneMinusMScheduleNext, type));
+    Operand<T> mTPrime = deps.math.div(mT, deps.dtypes.cast(oneMinusMScheduleNext, type));

     // vT = (coefficients['beta_2_t'] * v + coefficients['one_minus_beta_2_t'] *
     // math_ops.square(grad))
     Operand<T> vT =
-        tf.math.add(
-            tf.math.mul(tf.dtypes.cast(betaTwoConst, type), v),
-            tf.math.mul(tf.dtypes.cast(oneMinusBeta2, type), tf.math.square(gradient)));
+        deps.math.add(
+            deps.math.mul(deps.dtypes.cast(betaTwoConst, type), v),
+            deps.math.mul(deps.dtypes.cast(oneMinusBeta2, type), deps.math.square(gradient)));
     // vT = state_ops.assign(v, vT, use_locking=self._use_locking)
     // update v
-    vT = tf.assign(v, vT, Assign.useLocking(true));
+    vT = deps.assign(v, vT, Assign.useLocking(true));

     // vTPrime = vT / coefficients['vTPrimeDenominator']
-    Operand<T> vTPrime = tf.math.div(vT, tf.dtypes.cast(vTPrimeDenominator, type));
+    Operand<T> vTPrime = deps.math.div(vT, deps.dtypes.cast(vTPrimeDenominator, type));

     // m_t_bar = (coefficients['oneMinusMT'] * gPrime + coefficients['mT1'] * mTPrime)
     Operand<T> m_t_bar =
-        tf.math.add(
-            tf.math.mul(tf.dtypes.cast(oneMinusMT, type), gPrime),
-            tf.math.mul(tf.dtypes.cast(mT1, type), mTPrime));
+        deps.math.add(
+            deps.math.mul(deps.dtypes.cast(oneMinusMT, type), gPrime),
+            deps.math.mul(deps.dtypes.cast(mT1, type), mTPrime));
     // varT = var - coefficients['lr_t'] * m_t_bar / (math_ops.sqrt(vTPrime) +
     // coefficients['epsilon'])
     Operand<T> varT =
-        tf.math.sub(
+        deps.math.sub(
            variable,
-            tf.math.div(
-                tf.math.mul(tf.dtypes.cast(learningRateConst, type), m_t_bar),
-                tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, type))));
+            deps.math.div(
+                deps.math.mul(deps.dtypes.cast(learningRateConst, type), m_t_bar),
+                deps.math.add(deps.math.sqrt(vTPrime), deps.dtypes.cast(epsilonConst, type))));

-    return tf.assign(variable, varT, Assign.useLocking(true));
+    return deps.assign(variable, varT, Assign.useLocking(true));
   }

   /**

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java

+6 −4

@@ -168,14 +168,16 @@ public Op applyGradients(List<GradAndVar<? extends TType>> gradsAndVars, String
         gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList());

     createSlots(variables);
+    List<Op> gradients = gradsAndVars.stream().map(GradAndVar::getGradient).filter(g -> !g.isClosed()).collect(Collectors.toList());
+    Ops tfOpsGrads = tf.withControlDependencies(gradients);

     Optional<Op> prepOp = prepare(name + "/prepare");

     List<Op> updateOps = new ArrayList<>();
     prepOp.ifPresent(updateOps::add);
     for (GradAndVar<? extends TType> pair : gradsAndVars) {
       if (!pair.gradient.isClosed()) {
-        updateOps.add(applyDense(pair));
+        updateOps.add(applyDense(tfOpsGrads, pair));
       }
     }

@@ -261,8 +263,8 @@ protected void createSlots(List<Output<? extends TType>> variables) {}
    * @param <T> the datatype of the gradients and variables.
    * @return An operand which applies the desired optimizer update to the variable.
    */
-  private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) {
-    return applyDense(gradVarPair.getGradient(), gradVarPair.getVariable());
+  private <T extends TType> Op applyDense(Ops opDependencies, GradAndVar<T> gradVarPair) {
+    return applyDense(opDependencies, gradVarPair.getGradient(), gradVarPair.getVariable());
   }

   /**
@@ -273,7 +275,7 @@ private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) {
    * @param <T> The type of the variable.
    * @return An operand which applies the desired optimizer update to the variable.
    */
-  protected abstract <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable);
+  protected abstract <T extends TType> Op applyDense(Ops opDependencies, Output<T> gradient, Output<T> variable);

   /**
    * Gathers up the update operations into a single op that can be used as a run target.
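
Since `applyDense` is protected, the signature change stays internal to the framework: user code keeps calling `Optimizer.minimize` (or `applyGradients`), which now wires in the control dependency before delegating to `applyDense(Ops, ...)`. A hedged usage sketch, with illustrative names and values:

```java
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.TFloat32;

/** Illustrative caller; unaffected by the protected applyDense signature change. */
public class MinimizeSketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      Variable<TFloat32> w = tf.variable(Shape.scalar(), TFloat32.class);
      Op init = tf.assign(w, tf.constant(2.0f));
      Operand<TFloat32> loss = tf.math.square(w);

      GradientDescent optimizer = new GradientDescent(graph, 0.01f);
      // minimize() computes the gradients and calls applyGradients(), which now
      // adds the control dependency before building the variable updates.
      Op minimize = optimizer.minimize(loss);

      try (Session session = new Session(graph)) {
        session.runner().addTarget(init).run();
        session.runner().addTarget(minimize).run();
      }
    }
  }
}
```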

0 commit comments

Comments
 (0)