Skip to content

Commit f79c865

Browse files
committed
Dependencies
1 parent 74ff9bb commit f79c865

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

Diff for: python/tvm/runtime/executor/aot_executor.py

+32
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def __init__(self, module):
6868
self._get_input_index = module["get_input_index"]
6969
self._get_num_inputs = module["get_num_inputs"]
7070
self._get_input_name = module["get_input_name"]
71+
self._get_output_index = module["get_output_index"]
72+
self._get_output_info = module["get_output_info"]
7173

7274
def set_input(self, key=None, value=None, **params):
7375
"""Set inputs to the module via kwargs
@@ -199,3 +201,33 @@ def get_input_info(self):
199201
dtype_dict[input_name] = input_tensor.dtype
200202

201203
return shape_dict, dtype_dict
204+
205+
def get_output_info(self):
206+
"""Return the 'shape' and 'dtype' dictionaries of the graph.
207+
Returns
208+
-------
209+
shape_dict : Map
210+
Shape dictionary - {output_name: tuple}.
211+
dtype_dict : Map
212+
dtype dictionary - {output_name: dtype}.
213+
"""
214+
output_info = self._get_output_info()
215+
assert "shape" in output_info
216+
shape_dict = output_info["shape"]
217+
assert "dtype" in output_info
218+
dtype_dict = output_info["dtype"]
219+
220+
return shape_dict, dtype_dict
221+
222+
def get_output_index(self, name):
223+
"""Get outputs index via output name.
224+
Parameters
225+
----------
226+
name : str
227+
The output key name
228+
Returns
229+
-------
230+
index: int
231+
The output index. -1 will be returned if the given output name is not found.
232+
"""
233+
return self._get_output_index(name)

Diff for: src/runtime/crt/aot_executor_module/aot_executor_module.c

+5-1
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ static const TVMBackendPackedCFunc aot_executor_registry_funcs[] = {
210210
&TVMAotExecutorModule_NotImplemented, // set_input (implemented via python wrapper)
211211
&TVMAotExecutorModule_NotImplemented, // share_params (do not implement)
212212
&TVMAotExecutorModule_GetInputName, // get_input_name
213+
&TVMAotExecutorModule_NotImplemented, // get_output_index
214+
&TVMAotExecutorModule_NotImplemented, // get_output_info
213215
};
214216

215217
static const TVMFuncRegistry aot_executor_registry = {
@@ -223,7 +225,9 @@ static const TVMFuncRegistry aot_executor_registry = {
223225
"run\0"
224226
"set_input\0"
225227
"share_params\0"
226-
"get_input_name\0",
228+
"get_input_name\0"
229+
"get_output_index\0"
230+
"get_output_info\0",
227231
aot_executor_registry_funcs};
228232

229233
tvm_crt_error_t TVMAotExecutorModule_Register() {

Diff for: src/runtime/crt/graph_executor_module/graph_executor_module.c

+5-1
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ static const TVMBackendPackedCFunc graph_executor_registry_funcs[] = {
235235
&TVMGraphExecutorModule_Run,
236236
&TVMGraphExecutorModule_SetInput,
237237
&TVMGraphExecutorModule_NotImplemented, // share_params
238+
&TVMGraphExecutorModule_NotImplemented, // get_output_index
239+
&TVMGraphExecutorModule_NotImplemented, // get_output_info
238240
};
239241

240242
static const TVMFuncRegistry graph_executor_registry = {
@@ -247,7 +249,9 @@ static const TVMFuncRegistry graph_executor_registry = {
247249
"load_params\0"
248250
"run\0"
249251
"set_input\0"
250-
"share_params\0",
252+
"share_params\0"
253+
"get_output_index\0"
254+
"get_output_info\0",
251255
graph_executor_registry_funcs};
252256

253257
tvm_crt_error_t TVMGraphExecutorModule_Register() {

0 commit comments

Comments
 (0)