Skip to content

Commit

Permalink
Accept partially known shapes in boolean mask/updates
Browse files Browse the repository at this point in the history
  • Loading branch information
karllessard committed May 31, 2024
1 parent 5cf1568 commit 2cb6de8
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.proto.RunMetadata;
import org.tensorflow.types.family.TType;

/**
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s.
Expand Down Expand Up @@ -115,6 +117,31 @@ public Tensor get(int index) {
}
}

/**
* Gets the value from the container at the specified index, casting it to a given tensor type
*
* <p>Throws {@link IllegalStateException} if the container has been closed, and {@link
* IndexOutOfBoundsException} if the index is invalid.
*
* @param index The index to lookup.
* @param type tensor type
* @return The value at the index.
*/
public <T extends TType> T get(int index, Class<T> type) {
if (!closed) {
var tensor = list.get(index);
try {
return type.cast(tensor);
} catch (ClassCastException e) {
var tensorName = map.keySet().stream().collect(Collectors.toList()).get(index);
throw new IllegalArgumentException(
buildInvalidTensorTypeExceptionMessage(tensor, tensorName, type));
}
} else {
throw new IllegalStateException("Result is closed");
}
}

/**
* Gets the value from the container assuming it's not been closed.
*
Expand All @@ -131,6 +158,33 @@ public Optional<Tensor> get(String key) {
}
}

/**
* Gets the value from the container, assuming it's not been closed, casting it to a given tensor
* type.
*
* <p>Throws {@link IllegalStateException} if the container has been closed.
*
* @param key The key to lookup.
* @param type tensor type
* @return Optional.of the value if it exists.
*/
public <T extends TType> Optional<T> get(String key, Class<T> type) {
if (!closed) {
return Optional.ofNullable(map.get(key))
.map(
t -> {
try {
return type.cast(t);
} catch (ClassCastException e) {
throw new IllegalArgumentException(
buildInvalidTensorTypeExceptionMessage(t, key, type));
}
});
} else {
throw new IllegalStateException("Result is closed");
}
}

/**
* Metadata about the run.
*
Expand Down Expand Up @@ -196,4 +250,20 @@ public Optional<RunMetadata> getMetadata() {
private boolean closed;

private static final Logger logger = Logger.getLogger(Result.class.getName());

private String buildInvalidTensorTypeExceptionMessage(
Tensor tensor, String tensorName, Class<? extends TType> requestedType) {
String actualTypeName =
tensor instanceof TType
? ((TType) tensor).type().getSimpleName()
: tensor.getClass().getName();
throw new IllegalStateException(
"Tensor \""
+ tensorName
+ "\" of type \""
+ actualTypeName
+ "\" is not compatible with requested type \""
+ requestedType.getSimpleName()
+ "\"");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public static <T extends TType> Operand<T> create(
if (maskShape.numDimensions() == 0) {
throw new IllegalArgumentException("Mask cannot be a scalar.");
}
if (maskShape.hasUnknownDimension()) {
if (maskShape.isUnknown()) {
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public static <T extends TType> Operand<T> create(
if (maskShape.numDimensions() == 0) {
throw new IllegalArgumentException("Mask cannot be a scalar.");
}
if (maskShape.hasUnknownDimension()) {
if (maskShape.isUnknown()) {
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.tensorflow.op.core;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;
import org.tensorflow.Graph;
Expand Down Expand Up @@ -66,4 +67,39 @@ public void testBooleanMask() {
}
}
}

@Test
public void testBooleanMaskWithPartiallyUnknownShape() {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new OpScope(g);

Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
Placeholder<TBool> inputMask =
Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE)));

Operand<TInt32> output = BooleanMask.create(scope, input, inputMask);

try (TBool mask = TBool.vectorOf(true, false, false, true);
TInt32 result = (TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) {
// expected shape from Python tensorflow
assertEquals(Shape.of(2), result.shape());
assertEquals(1, result.getInt(0));
assertEquals(4, result.getInt(1));
}
}
}

@Test
public void testBooleanMaskWithUnknownShape() {
try (Graph g = new Graph()) {
Scope scope = new OpScope(g);

Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
Placeholder<TBool> inputMask = Placeholder.create(scope, TBool.class);

assertThrows(
IllegalArgumentException.class, () -> BooleanMask.create(scope, input, inputMask));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.tensorflow.op.core;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;
import org.tensorflow.Graph;
Expand Down Expand Up @@ -151,4 +152,44 @@ public void testBooleanMaskUpdateAxis() {
}
}
}

@Test
public void testBooleanMaskUpdateWithPartiallyUnknownShape() {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new OpScope(g);

Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
Operand<TInt32> updates = Constant.arrayOf(scope, -1, 2);
Placeholder<TBool> inputMask =
Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE)));

Operand<TInt32> output = BooleanMaskUpdate.create(scope, input, inputMask, updates);

try (TBool mask = TBool.vectorOf(false, true, false, true);
TInt32 result = (TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) {
// expected shape from Python tensorflow
assertEquals(Shape.of(4), result.shape());
assertEquals(1, result.getInt(0));
assertEquals(-1, result.getInt(1));
assertEquals(3, result.getInt(2));
assertEquals(2, result.getInt(3));
}
}
}

@Test
public void testBooleanMaskUpdateWithUnknownShape() {
try (Graph g = new Graph()) {
Scope scope = new OpScope(g);

Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
Operand<TInt32> updates = Constant.arrayOf(scope, -1, 2);
Placeholder<TBool> inputMask = Placeholder.create(scope, TBool.class);

assertThrows(
IllegalArgumentException.class,
() -> BooleanMaskUpdate.create(scope, input, inputMask, updates));
}
}
}

0 comments on commit 2cb6de8

Please sign in to comment.