Skip to content

Commit 26e8dc3

Browse files
authored
Expose input and output types from Signature (#182)
1 parent 3a0489e commit 26e8dc3

File tree

2 files changed

+67
-3
lines changed

2 files changed

+67
-3
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java

+38
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
*/
1616
package org.tensorflow;
1717

18+
import java.util.HashMap;
1819
import java.util.Map;
1920
import java.util.Set;
2021
import org.tensorflow.ndarray.Shape;
22+
import org.tensorflow.proto.framework.DataType;
2123
import org.tensorflow.proto.framework.SignatureDef;
2224
import org.tensorflow.proto.framework.TensorInfo;
2325
import org.tensorflow.proto.framework.TensorShapeProto;
@@ -32,6 +34,16 @@ public class Signature {
3234
/** The default signature key, when not provided */
3335
public static final String DEFAULT_KEY = "serving_default";
3436

37+
public static class TensorDescription {
38+
public final DataType dataType;
39+
public final Shape shape;
40+
41+
public TensorDescription(DataType dataType, Shape shape) {
42+
this.dataType = dataType;
43+
this.shape = shape;
44+
}
45+
}
46+
3547
/**
3648
* Builds a new function signature.
3749
*/
@@ -174,6 +186,32 @@ public String toString() {
174186
return strBuilder.toString();
175187
}
176188

189+
private Map<String, TensorDescription> buildTensorDescriptionMap(Map<String, TensorInfo> dataMapIn) {
190+
Map<String, TensorDescription> dataTypeMap = new HashMap<>();
191+
dataMapIn.forEach((a, b) -> {
192+
long[] tensorDims = b.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray();
193+
Shape tensorShape = Shape.of(tensorDims);
194+
dataTypeMap.put(a, new TensorDescription(b.getDtype(),
195+
tensorShape));
196+
});
197+
return dataTypeMap;
198+
}
199+
200+
/**
201+
* Returns the names of the inputs in this signature mapped to their expected data type and shape
202+
* @return
203+
*/
204+
public Map<String, TensorDescription> getInputs() {
205+
return buildTensorDescriptionMap(signatureDef.getInputsMap());
206+
}
207+
208+
/**
209+
* Returns the names of the outputs in this signature mapped to their expected data type and shape
210+
*/
211+
public Map<String, TensorDescription> getOutputs() {
212+
return buildTensorDescriptionMap(signatureDef.getOutputsMap());
213+
}
214+
177215
Signature(String key, SignatureDef signatureDef) {
178216
this.key = key;
179217
this.signatureDef = signatureDef;

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java

+29-3
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
==============================================================================*/
1515
package org.tensorflow;
1616

17-
import static org.junit.jupiter.api.Assertions.assertNull;
18-
import static org.junit.jupiter.api.Assertions.assertThrows;
19-
2017
import org.junit.jupiter.api.Test;
18+
import org.tensorflow.Signature.TensorDescription;
2119
import org.tensorflow.op.Ops;
20+
import org.tensorflow.proto.framework.DataType;
21+
22+
import java.util.Map;
23+
24+
import static org.junit.jupiter.api.Assertions.*;
2225

2326
public class SignatureTest {
2427

@@ -43,6 +46,29 @@ public void cannotDuplicateInputOutputNames() {
4346
}
4447
}
4548

49+
@Test
50+
public void getInputsAndOutputs() {
51+
Ops tf = Ops.create();
52+
Signature builder = Signature.builder()
53+
.input("x", tf.constant(10.0f))
54+
.output("y", tf.constant(new float[][] {{10.0f, 30.0f}}))
55+
.output("z", tf.constant(20.0f)).build();
56+
57+
Map<String, TensorDescription> inputs = builder.getInputs();
58+
assertEquals(inputs.size(), 1);
59+
60+
Map<String, TensorDescription> outputs = builder.getOutputs();
61+
assertEquals(outputs.size(), 2);
62+
63+
assertEquals(outputs.get("y").dataType, DataType.DT_FLOAT);
64+
assertEquals(outputs.get("z").dataType, DataType.DT_FLOAT);
65+
assertArrayEquals(outputs.get("y").shape.asArray(), new long [] {1,2});
66+
assertArrayEquals(outputs.get("z").shape.asArray(), new long [] {});
67+
68+
Signature emptySignature = Signature.builder().build();
69+
assertEquals(emptySignature.getInputs().size(), 0);
70+
}
71+
4672
@Test
4773
public void emptyMethodNameConvertedToNull() {
4874
Signature signature = Signature.builder().key("f").build();

0 commit comments

Comments
 (0)