@@ -173,7 +173,9 @@ def __init__(self, module):
173
173
self ._get_input = module ["get_input" ]
174
174
self ._get_num_outputs = module ["get_num_outputs" ]
175
175
self ._get_input_index = module ["get_input_index" ]
176
+ self ._get_output_index = module ["get_output_index" ]
176
177
self ._get_input_info = module ["get_input_info" ]
178
+ self ._get_output_info = module ["get_output_info" ]
177
179
self ._get_num_inputs = module ["get_num_inputs" ]
178
180
self ._load_params = module ["load_params" ]
179
181
self ._share_params = module ["share_params" ]
@@ -315,6 +317,21 @@ def get_input_index(self, name):
315
317
"""
316
318
return self ._get_input_index (name )
317
319
320
+ def get_output_index (self , name ):
321
+ """Get outputs index via output name.
322
+
323
+ Parameters
324
+ ----------
325
+ name : str
326
+ The output key name
327
+
328
+ Returns
329
+ -------
330
+ index: int
331
+ The output index. -1 will be returned if the given output name is not found.
332
+ """
333
+ return self ._get_output_index (name )
334
+
318
335
def get_input_info (self ):
319
336
"""Return the 'shape' and 'dtype' dictionaries of the graph.
320
337
@@ -341,6 +358,24 @@ def get_input_info(self):
341
358
342
359
return shape_dict , dtype_dict
343
360
361
+ def get_output_info (self ):
362
+ """Return the 'shape' and 'dtype' dictionaries of the graph.
363
+
364
+ Returns
365
+ -------
366
+ shape_dict : Map
367
+ Shape dictionary - {output_name: tuple}.
368
+ dtype_dict : Map
369
+ dtype dictionary - {output_name: dtype}.
370
+ """
371
+ output_info = self ._get_output_info ()
372
+ assert "shape" in output_info
373
+ shape_dict = output_info ["shape" ]
374
+ assert "dtype" in output_info
375
+ dtype_dict = output_info ["dtype" ]
376
+
377
+ return shape_dict , dtype_dict
378
+
344
379
def get_output (self , index , out = None ):
345
380
"""Get index-th output to out
346
381
0 commit comments