Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interface should be public for external usage #522

Merged
merged 3 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class LossesHelper {
* @param tf the TensorFlow Ops
* @param predictions Predicted values, an <code>Operand</code> of arbitrary dimensions.
* @param labels Optional label <code>Operand</code> whose dimensions match <code>prediction
* </code>.
* </code> .
* @param <T> the data type for the labels, predictions and result
* @return LossTuple of <code>prediction</code>, <code>label</code>, <code>sampleWeight</code> will
* be null. Each of them possibly has the last dimension squeezed, <code>sampleWeight</code>
Expand All @@ -77,7 +77,7 @@ public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(
* @param tf the TensorFlow Ops
* @param predictions Predicted values, an <code>Operand</code> of arbitrary dimensions.
* @param labels Optional label <code>Operand</code> whose dimensions match <code>prediction
* </code>.
* </code> .
* @param sampleWeights Optional sample weight(s) <code>Operand</code> whose dimensions match
* <code>
* prediction</code>.
Expand Down Expand Up @@ -179,7 +179,7 @@ private static <T extends TNumber> Operand<T> maybeExpandWeights(
*
* @param tf the TensorFlowOps
* @param labels Label values, a <code>Tensor</code> whose dimensions match <code>predictions
* </code>.
* </code> .
* @param predictions Predicted values, a <code>Tensor</code> of arbitrary dimensions.
* @param <T> the data type for the labels, predictions and result
* @return <code>labels</code> and <code>predictions</code>, possibly with last dim squeezed.
Expand All @@ -194,7 +194,7 @@ public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
*
* @param tf the TensorFlowOps
* @param labels Label values, an <code>Operand</code> whose dimensions match <code>predictions
* </code>.
* </code> .
* @param predictions Predicted values, a <code>Tensor</code> of arbitrary dimensions.
* @param expectedRankDiff Expected result of <code>rank(predictions) - rank(labels)</code>.
* @param <T> the data type for the labels, predictions and result
Expand Down Expand Up @@ -222,11 +222,13 @@ public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
// Use dynamic rank.

// TODO: hold for lazy select feature,
// Operand<TInt32> rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels));
// Operand<TInt32> rankDiff = tf.math.sub(tf.rank(predictions),
// tf.rank(labels));
if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) {
/*
* TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze
* predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
* TODO, if we ever get a select that does lazy evaluation, but for now do the
* tf.squeeze predictions = tf.select(
* tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
* tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); *
*/
predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L)));
Expand Down Expand Up @@ -282,11 +284,12 @@ private static <T extends TNumber> Operand<T> reduceWeightedLoss(
if (reduction == Reduction.NONE) {
loss = weightedLoss;
} else {
loss =
tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) {
loss = safeMean(tf, loss, weightedLoss.shape().size());
}
loss = safeMean(tf, weightedLoss);
} else
loss =
tf.reduceSum(
weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
}
return loss;
}
Expand All @@ -301,10 +304,10 @@ private static <T extends TNumber> Operand<T> reduceWeightedLoss(
* @return A scalar representing the mean of <code>losses</code>. If <code>numElements</code> is
* zero, then zero is returned.
*/
public static <T extends TNumber> Operand<T> safeMean(
Ops tf, Operand<T> losses, long numElements) {
Operand<T> totalLoss = tf.reduceSum(losses, allAxes(tf, losses));
return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type()));
public static <T extends TNumber> Operand<T> safeMean(Ops tf, Operand<T> losses) {
Operand<T> totalLoss =
tf.reduceSum(losses, allAxes(tf, losses), ReduceSum.keepDims(Boolean.FALSE));
return tf.math.divNoNan(totalLoss, cast(tf, tf.shape.size(tf.shape(losses)), losses.type()));
}

/**
Expand Down Expand Up @@ -348,7 +351,8 @@ public static <T extends TNumber> Operand<T> rangeCheck(
tf.math.logicalAnd(
tf.reduceAll(tf.math.greaterEqual(values, minValue), allDims),
tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims));
// Graph and Eager mode need to be handled differently, control dependencies are not allowed in
// Graph and Eager mode need to be handled differently, control dependencies are
// not allowed in
// Eager mode
if (tf.scope().env().isGraph()) {
AssertThat assertThat =
Expand Down Expand Up @@ -398,7 +402,8 @@ public static <T extends TNumber> Operand<T> valueCheck(
} else return values;
} else { // use dynamic shape
Operand<TBool> cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0));
// Graph and Eager mode need to be handled differently, control dependencies are not allowed
// Graph and Eager mode need to be handled differently, control dependencies are
// not allowed
// in Eager mode
if (tf.scope().env().isGraph()) {
AssertThat assertThat =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.tensorflow.types.family.TNumber;

/** Interface for metrics */
interface Metric {
public interface Metric {

/**
* Creates a List of Operations to update the metric state based on input values.
Expand Down
Loading