|
7 | 7 | import org.tensorflow.Output;
|
8 | 8 | import org.tensorflow.ndarray.Shape;
|
9 | 9 | import org.tensorflow.op.Op;
|
| 10 | +import org.tensorflow.op.Ops; |
10 | 11 | import org.tensorflow.op.core.Assign;
|
11 | 12 | import org.tensorflow.op.core.Constant;
|
12 | 13 | import org.tensorflow.op.core.Variable;
|
@@ -224,53 +225,53 @@ protected Optional<Op> prepare(String scopeName) {
|
224 | 225 |
|
225 | 226 | /** {@inheritDoc} */
|
226 | 227 | @Override
|
227 |
| - protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { |
| 228 | + protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) { |
228 | 229 | Class<T> type = gradient.type();
|
229 | 230 | Variable<T> m = getSlot(variable, FIRST_MOMENT).get(); // first Moment
|
230 | 231 | Variable<T> v = getSlot(variable, SECOND_MOMENT).get(); // Second Moment
|
231 | 232 |
|
232 | 233 | // gPrime = grad / coefficients['oneMinusMScheduleNew']
|
233 |
| - Operand<T> gPrime = tf.math.div(gradient, tf.dtypes.cast(oneMinusMScheduleNew, type)); |
| 234 | + Operand<T> gPrime = deps.math.div(gradient, deps.dtypes.cast(oneMinusMScheduleNew, type)); |
234 | 235 | // mT = (coefficients['beta_1_t'] * m + coefficients['one_minus_beta_1_t'] * grad)
|
235 | 236 | Operand<T> mT =
|
236 |
| - tf.math.add( |
237 |
| - tf.math.mul(tf.dtypes.cast(betaOneConst, type), m), |
238 |
| - tf.math.mul(tf.dtypes.cast(oneMinusBeta1, type), gradient)); |
| 237 | + deps.math.add( |
| 238 | + deps.math.mul(deps.dtypes.cast(betaOneConst, type), m), |
| 239 | + deps.math.mul(deps.dtypes.cast(oneMinusBeta1, type), gradient)); |
239 | 240 | // mT = state_ops.assign(m, mT, use_locking=self._use_locking)
|
240 | 241 | // update m
|
241 |
| - mT = tf.assign(m, mT, Assign.useLocking(true)); |
| 242 | + mT = deps.assign(m, mT, Assign.useLocking(true)); |
242 | 243 |
|
243 | 244 | // mTPrime = mT / coefficients['oneMinusMScheduleNext']
|
244 |
| - Operand<T> mTPrime = tf.math.div(mT, tf.dtypes.cast(oneMinusMScheduleNext, type)); |
| 245 | + Operand<T> mTPrime = deps.math.div(mT, deps.dtypes.cast(oneMinusMScheduleNext, type)); |
245 | 246 |
|
246 | 247 | // vT = (coefficients['beta_2_t'] * v + coefficients['one_minus_beta_2_t'] *
|
247 | 248 | // math_ops.square(grad))
|
248 | 249 | Operand<T> vT =
|
249 |
| - tf.math.add( |
250 |
| - tf.math.mul(tf.dtypes.cast(betaTwoConst, type), v), |
251 |
| - tf.math.mul(tf.dtypes.cast(oneMinusBeta2, type), tf.math.square(gradient))); |
| 250 | + deps.math.add( |
| 251 | + deps.math.mul(deps.dtypes.cast(betaTwoConst, type), v), |
| 252 | + deps.math.mul(deps.dtypes.cast(oneMinusBeta2, type), deps.math.square(gradient))); |
252 | 253 | // vT = state_ops.assign(v, vT, use_locking=self._use_locking)
|
253 | 254 | // update v
|
254 |
| - vT = tf.assign(v, vT, Assign.useLocking(true)); |
| 255 | + vT = deps.assign(v, vT, Assign.useLocking(true)); |
255 | 256 |
|
256 | 257 | // vTPrime = vT / coefficients['vTPrimeDenominator']
|
257 |
| - Operand<T> vTPrime = tf.math.div(vT, tf.dtypes.cast(vTPrimeDenominator, type)); |
| 258 | + Operand<T> vTPrime = deps.math.div(vT, deps.dtypes.cast(vTPrimeDenominator, type)); |
258 | 259 |
|
259 | 260 | // m_t_bar = (coefficients['oneMinusMT'] * gPrime + coefficients['mT1'] * mTPrime)
|
260 | 261 | Operand<T> m_t_bar =
|
261 |
| - tf.math.add( |
262 |
| - tf.math.mul(tf.dtypes.cast(oneMinusMT, type), gPrime), |
263 |
| - tf.math.mul(tf.dtypes.cast(mT1, type), mTPrime)); |
| 262 | + deps.math.add( |
| 263 | + deps.math.mul(deps.dtypes.cast(oneMinusMT, type), gPrime), |
| 264 | + deps.math.mul(deps.dtypes.cast(mT1, type), mTPrime)); |
264 | 265 | // varT = var - coefficients['lr_t'] * m_t_bar / (math_ops.sqrt(vTPrime) +
|
265 | 266 | // coefficients['epsilon'])
|
266 | 267 | Operand<T> varT =
|
267 |
| - tf.math.sub( |
| 268 | + deps.math.sub( |
268 | 269 | variable,
|
269 |
| - tf.math.div( |
270 |
| - tf.math.mul(tf.dtypes.cast(learningRateConst, type), m_t_bar), |
271 |
| - tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, type)))); |
| 270 | + deps.math.div( |
| 271 | + deps.math.mul(deps.dtypes.cast(learningRateConst, type), m_t_bar), |
| 272 | + deps.math.add(deps.math.sqrt(vTPrime), deps.dtypes.cast(epsilonConst, type)))); |
272 | 273 |
|
273 |
| - return tf.assign(variable, varT, Assign.useLocking(true)); |
| 274 | + return deps.assign(variable, varT, Assign.useLocking(true)); |
274 | 275 | }
|
275 | 276 |
|
276 | 277 | /**
|
|
0 commit comments