Skip to content

Commit

Permalink
Fix broadcastMask/Update
Browse files Browse the repository at this point in the history
Accept partially unknown shaped mask
  • Loading branch information
karllessard committed May 31, 2024
1 parent 5cf1568 commit dcb9e11
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public static <T extends TType> Operand<T> create(
if (maskShape.numDimensions() == 0) {
throw new IllegalArgumentException("Mask cannot be a scalar.");
}
if (maskShape.hasUnknownDimension()) {
if (maskShape.isUnknown()) {
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public static <T extends TType> Operand<T> create(
if (maskShape.numDimensions() == 0) {
throw new IllegalArgumentException("Mask cannot be a scalar.");
}
if (maskShape.hasUnknownDimension()) {
if (maskShape.isUnknown()) {
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.tensorflow.op.core;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;
import org.tensorflow.Graph;
Expand Down Expand Up @@ -66,4 +67,39 @@ public void testBooleanMask() {
}
}
}

@Test
public void testBooleanMaskWithPartiallyUnknownShape() {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new OpScope(g);

Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
Placeholder<TBool> inputMask =
Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE)));

Operand<TInt32> output = BooleanMask.create(scope, input, inputMask);

try (TBool mask = TBool.vectorOf(true, false, false, true);
TInt32 result = (TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) {
// expected shape from Python tensorflow
assertEquals(Shape.of(2), result.shape());
assertEquals(1, result.getInt(0));
assertEquals(4, result.getInt(1));
}
}
}

@Test
public void testBooleanMaskWithUnknownShape() {
try (Graph g = new Graph()) {
Scope scope = new OpScope(g);

Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
Placeholder<TBool> inputMask = Placeholder.create(scope, TBool.class);

assertThrows(
IllegalArgumentException.class, () -> BooleanMask.create(scope, input, inputMask));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.tensorflow.op.core;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;
import org.tensorflow.Graph;
Expand Down Expand Up @@ -151,4 +152,44 @@ public void testBooleanMaskUpdateAxis() {
}
}
}

@Test
public void testBooleanMaskUpdateWithPartiallyUnknownShape() {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new OpScope(g);

Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
Operand<TInt32> updates = Constant.arrayOf(scope, -1, 2);
Placeholder<TBool> inputMask =
Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE)));

Operand<TInt32> output = BooleanMaskUpdate.create(scope, input, inputMask, updates);

try (TBool mask = TBool.vectorOf(false, true, false, true);
TInt32 result = (TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) {
// expected shape from Python tensorflow
assertEquals(Shape.of(4), result.shape());
assertEquals(1, result.getInt(0));
assertEquals(-1, result.getInt(1));
assertEquals(3, result.getInt(2));
assertEquals(2, result.getInt(3));
}
}
}

@Test
public void testBooleanMaskUpdateWithUnknownShape() {
try (Graph g = new Graph()) {
Scope scope = new OpScope(g);

Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
Operand<TInt32> updates = Constant.arrayOf(scope, -1, 2);
Placeholder<TBool> inputMask = Placeholder.create(scope, TBool.class);

assertThrows(
IllegalArgumentException.class,
() -> BooleanMaskUpdate.create(scope, input, inputMask, updates));
}
}
}

0 comments on commit dcb9e11

Please sign in to comment.