15
15
*/
16
16
package org .tensorflow ;
17
17
18
+ import java .util .HashMap ;
18
19
import java .util .Map ;
19
20
import java .util .Set ;
20
21
import org .tensorflow .ndarray .Shape ;
22
+ import org .tensorflow .proto .framework .DataType ;
21
23
import org .tensorflow .proto .framework .SignatureDef ;
22
24
import org .tensorflow .proto .framework .TensorInfo ;
23
25
import org .tensorflow .proto .framework .TensorShapeProto ;
@@ -32,6 +34,16 @@ public class Signature {
32
34
/** The default signature key, when not provided */
33
35
public static final String DEFAULT_KEY = "serving_default" ;
34
36
37
+ public static class TensorDescription {
38
+ public final DataType dataType ;
39
+ public final Shape shape ;
40
+
41
+ public TensorDescription (DataType dataType , Shape shape ) {
42
+ this .dataType = dataType ;
43
+ this .shape = shape ;
44
+ }
45
+ }
46
+
35
47
/**
36
48
* Builds a new function signature.
37
49
*/
@@ -174,6 +186,32 @@ public String toString() {
174
186
return strBuilder .toString ();
175
187
}
176
188
189
+ private Map <String , TensorDescription > buildTensorDescriptionMap (Map <String , TensorInfo > dataMapIn ) {
190
+ Map <String , TensorDescription > dataTypeMap = new HashMap <>();
191
+ dataMapIn .forEach ((a , b ) -> {
192
+ long [] tensorDims = b .getTensorShape ().getDimList ().stream ().mapToLong (d -> d .getSize ()).toArray ();
193
+ Shape tensorShape = Shape .of (tensorDims );
194
+ dataTypeMap .put (a , new TensorDescription (b .getDtype (),
195
+ tensorShape ));
196
+ });
197
+ return dataTypeMap ;
198
+ }
199
+
200
+ /**
201
+ * Returns the names of the inputs in this signature mapped to their expected data type and shape
202
+ * @return
203
+ */
204
+ public Map <String , TensorDescription > getInputs () {
205
+ return buildTensorDescriptionMap (signatureDef .getInputsMap ());
206
+ }
207
+
208
+ /**
209
+ * Returns the names of the outputs in this signature mapped to their expected data type and shape
210
+ */
211
+ public Map <String , TensorDescription > getOutputs () {
212
+ return buildTensorDescriptionMap (signatureDef .getOutputsMap ());
213
+ }
214
+
177
215
Signature (String key , SignatureDef signatureDef ) {
178
216
this .key = key ;
179
217
this .signatureDef = signatureDef ;
0 commit comments