From f46cdb6c412a73f45f6324f35d438d9d2a2e7e85 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 8 Mar 2024 17:14:16 -0500 Subject: [PATCH] Adding casts to the if test so it passes on GPU. --- .../src/test/java/org/tensorflow/op/core/IfTest.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java index 57bc0bc9ffb..1fc1b7ed46b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java @@ -27,6 +27,7 @@ import org.tensorflow.Session; import org.tensorflow.Signature; import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; public class IfTest { @@ -37,7 +38,8 @@ private static Operand basicIf(Ops tf, Operand a, Operand { Operand a1 = ops.placeholder(TInt32.class); Operand b1 = ops.placeholder(TInt32.class); - return Signature.builder().input("a", a1).input("b", b1).output("y", a1).build(); + Operand y = ops.identity(a1); + return Signature.builder().input("a", a1).input("b", b1).output("y", y).build(); }); ConcreteFunction elseBranch = @@ -45,7 +47,10 @@ private static Operand basicIf(Ops tf, Operand a, Operand { Operand a1 = ops.placeholder(TInt32.class); Operand b1 = ops.placeholder(TInt32.class); - Operand y = ops.math.neg(b1); + // Casts around the math.neg operator as it's not implemented correctly for int32 in + // GPUs at some point between TF 2.10 and TF 2.15. + Operand y = + ops.dtypes.cast(ops.math.neg(ops.dtypes.cast(a1, TFloat32.class)), TInt32.class); return Signature.builder().input("a", a1).input("b", b1).output("y", y).build(); });