Skip to content

Commit 96a5d24

Browse files
committed
Additional API support (#30)
get_output_index support added. Co-authored-by: Siva <[email protected]>
1 parent 567eeed commit 96a5d24

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

python/tvm/contrib/graph_executor.py

+35
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ def __init__(self, module):
173173
self._get_input = module["get_input"]
174174
self._get_num_outputs = module["get_num_outputs"]
175175
self._get_input_index = module["get_input_index"]
176+
self._get_output_index = module["get_output_index"]
176177
self._get_input_info = module["get_input_info"]
178+
self._get_output_info = module["get_output_info"]
177179
self._get_num_inputs = module["get_num_inputs"]
178180
self._load_params = module["load_params"]
179181
self._share_params = module["share_params"]
@@ -315,6 +317,21 @@ def get_input_index(self, name):
315317
"""
316318
return self._get_input_index(name)
317319

320+
def get_output_index(self, name):
321+
"""Get outputs index via output name.
322+
323+
Parameters
324+
----------
325+
name : str
326+
The output key name
327+
328+
Returns
329+
-------
330+
index: int
331+
The output index. -1 will be returned if the given output name is not found.
332+
"""
333+
return self._get_output_index(name)
334+
318335
def get_input_info(self):
319336
"""Return the 'shape' and 'dtype' dictionaries of the graph.
320337
@@ -341,6 +358,24 @@ def get_input_info(self):
341358

342359
return shape_dict, dtype_dict
343360

361+
def get_output_info(self):
362+
"""Return the 'shape' and 'dtype' dictionaries of the graph.
363+
364+
Returns
365+
-------
366+
shape_dict : Map
367+
Shape dictionary - {output_name: tuple}.
368+
dtype_dict : Map
369+
dtype dictionary - {output_name: dtype}.
370+
"""
371+
output_info = self._get_output_info()
372+
assert "shape" in output_info
373+
shape_dict = output_info["shape"]
374+
assert "dtype" in output_info
375+
dtype_dict = output_info["dtype"]
376+
377+
return shape_dict, dtype_dict
378+
344379
def get_output(self, index, out=None):
345380
"""Get index-th output to out
346381

src/runtime/graph_executor/graph_executor.cc

+12
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,18 @@ PackedFunc GraphExecutor::GetFunction(const String& name, const ObjectPtr<Object
745745
CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
746746
*rv = this->GetInputIndex(args[0].operator String());
747747
});
748+
} else if (name == "get_output_index") {
749+
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
750+
CHECK(String::CanConvertFrom(args[0])) << "Output key is not a string";
751+
int out_idx = -1;
752+
for (size_t i = 0; i < outputs_.size(); i++) {
753+
std::string& name = nodes_[outputs_[i].node_id].name;
754+
if (args[0].operator String() == name) {
755+
out_idx = i;
756+
}
757+
}
758+
*rv = out_idx;
759+
});
748760
} else if (name == "get_input_info") {
749761
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
750762
auto [shape_info, dtype_info] = this->GetInputInfo();

tests/python/relay/test_backend_graph_executor.py

+6
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,12 @@ def test_graph_executor_api():
467467
assert isinstance(dtype_dict[name], tvm.runtime.container.String)
468468
assert dtype_dict[name] == ty.dtype
469469

470+
shape_dict, dtype_dict = mod.get_output_info()
471+
assert isinstance(shape_dict, tvm.container.Map)
472+
assert isinstance(dtype_dict, tvm.container.Map)
473+
for i, key in enumerate(shape_dict):
474+
assert mod.get_output_index(key) == i
475+
470476

471477
@tvm.testing.requires_llvm
472478
def test_benchmark():

0 commit comments

Comments
 (0)