Skip to content

Commit 9af9ba4

Browse files
committed
Accept partially known shapes in boolean mask/updates
1 parent 5cf1568 commit 9af9ba4

File tree

4 files changed

+82
-2
lines changed

4 files changed

+82
-2
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public static <T extends TType> Operand<T> create(
7878
if (maskShape.numDimensions() == 0) {
7979
throw new IllegalArgumentException("Mask cannot be a scalar.");
8080
}
81-
if (maskShape.hasUnknownDimension()) {
81+
if (maskShape.isUnknown()) {
8282
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
8383
}
8484

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public static <T extends TType> Operand<T> create(
8686
if (maskShape.numDimensions() == 0) {
8787
throw new IllegalArgumentException("Mask cannot be a scalar.");
8888
}
89-
if (maskShape.hasUnknownDimension()) {
89+
if (maskShape.isUnknown()) {
9090
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
9191
}
9292

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java

+38
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.tensorflow.op.core;
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertThrows;
2021

2122
import org.junit.jupiter.api.Test;
2223
import org.tensorflow.Graph;
@@ -66,4 +67,41 @@ public void testBooleanMask() {
6667
}
6768
}
6869
}
70+
71+
@Test
72+
public void testBooleanMaskWithPartiallyUnknownShape() {
73+
try (Graph g = new Graph();
74+
Session sess = new Session(g)) {
75+
Scope scope = new OpScope(g);
76+
77+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
78+
Placeholder<TBool> inputMask =
79+
Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE)));
80+
81+
Operand<TInt32> output = BooleanMask.create(scope, input, inputMask);
82+
83+
try (TBool mask = TBool.vectorOf(true, false, false, true);
84+
TInt32 result =
85+
(TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) {
86+
// expected shape from Python tensorflow
87+
assertEquals(Shape.of(2), result.shape());
88+
assertEquals(1, result.getInt(0));
89+
assertEquals(4, result.getInt(1));
90+
}
91+
}
92+
}
93+
94+
@Test
95+
public void testBooleanMaskWithUnknownShape() {
96+
try (Graph g = new Graph()) {
97+
Scope scope = new OpScope(g);
98+
99+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
100+
Placeholder<TBool> inputMask = Placeholder.create(scope, TBool.class);
101+
102+
assertThrows(
103+
IllegalArgumentException.class,
104+
() -> BooleanMask.create(scope, input, inputMask));
105+
}
106+
}
69107
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java

+42
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.tensorflow.op.core;
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertThrows;
2021

2122
import org.junit.jupiter.api.Test;
2223
import org.tensorflow.Graph;
@@ -151,4 +152,45 @@ public void testBooleanMaskUpdateAxis() {
151152
}
152153
}
153154
}
155+
156+
@Test
157+
public void testBooleanMaskUpdateWithPartiallyUnknownShape() {
158+
try (Graph g = new Graph();
159+
Session sess = new Session(g)) {
160+
Scope scope = new OpScope(g);
161+
162+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
163+
Operand<TInt32> updates = Constant.arrayOf(scope, -1, 2);
164+
Placeholder<TBool> inputMask =
165+
Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE)));
166+
167+
Operand<TInt32> output = BooleanMaskUpdate.create(scope, input, inputMask, updates);
168+
169+
try (TBool mask = TBool.vectorOf(false, true, false, true);
170+
TInt32 result =
171+
(TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) {
172+
// expected shape from Python tensorflow
173+
assertEquals(Shape.of(4), result.shape());
174+
assertEquals(1, result.getInt(0));
175+
assertEquals(-1, result.getInt(1));
176+
assertEquals(3, result.getInt(2));
177+
assertEquals(2, result.getInt(3));
178+
}
179+
}
180+
}
181+
182+
@Test
183+
public void testBooleanMaskUpdateWithUnknownShape() {
184+
try (Graph g = new Graph()) {
185+
Scope scope = new OpScope(g);
186+
187+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
188+
Operand<TInt32> updates = Constant.arrayOf(scope, -1, 2);
189+
Placeholder<TBool> inputMask = Placeholder.create(scope, TBool.class);
190+
191+
assertThrows(
192+
IllegalArgumentException.class,
193+
() -> BooleanMaskUpdate.create(scope, input, inputMask, updates));
194+
}
195+
}
154196
}

0 commit comments

Comments
 (0)