Skip to content

Commit

Permalink
Adding casts to the if test so it passes on GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed Mar 8, 2024
1 parent 3f89f60 commit f46cdb6
Showing 1 changed file with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.tensorflow.Session;
import org.tensorflow.Signature;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

public class IfTest {
Expand All @@ -37,15 +38,19 @@ private static Operand<TInt32> basicIf(Ops tf, Operand<TInt32> a, Operand<TInt32
(ops) -> {
Operand<TInt32> a1 = ops.placeholder(TInt32.class);
Operand<TInt32> b1 = ops.placeholder(TInt32.class);
return Signature.builder().input("a", a1).input("b", b1).output("y", a1).build();
Operand<TInt32> y = ops.identity(a1);
return Signature.builder().input("a", a1).input("b", b1).output("y", y).build();
});

ConcreteFunction elseBranch =
ConcreteFunction.create(
(ops) -> {
Operand<TInt32> a1 = ops.placeholder(TInt32.class);
Operand<TInt32> b1 = ops.placeholder(TInt32.class);
Operand<TInt32> y = ops.math.neg(b1);
// Casts around the math.neg operator as it's not implemented correctly for int32 in
// GPUs at some point between TF 2.10 and TF 2.15.
Operand<TInt32> y =
ops.dtypes.cast(ops.math.neg(ops.dtypes.cast(a1, TFloat32.class)), TInt32.class);
return Signature.builder().input("a", a1).input("b", b1).output("y", y).build();
});

Expand Down

0 comments on commit f46cdb6

Please sign in to comment.