-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #14 from kumardeepakamd/main
support onnx operator tests
- Loading branch information
Showing
11 changed files
with
230 additions
and
207 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,27 @@ | ||
import numpy as np | ||
import numpy | ||
import onnxruntime | ||
|
||
|
||
# the generated or checked in onnx file must always be called | ||
# model.onnx | ||
# This file forms middle of the runmodel.py for this test | ||
# which is generated by run.py script in root of e2eshark | ||
# test dir | ||
# <e2eshark>/tools/stubs/onnxstartmodel.py | ||
# this model.py | ||
# <e2eshark>/tools/stubs/onnxendmodel.py | ||
# are concatenated in that order to form | ||
# <test-run dir>/onnx/<test category>/<test name>/runmodel.py | ||
# which is run to run the model to generate output | ||
# Leave the above comment in the file | ||
# The generated or checked in onnx file must always be called model.onnx | ||
# the tools/stubs/onnxmodel.py is appended to model.py | ||
# to form runmodel.py in the rundirectory which is then taken | ||
# through flow | ||
|
||
|
||
# insert here any onnx API call to generate onnx file if | ||
# not using a checked in onnx model | ||
|
||
|
||
# to locally test, can uncomment the line below | ||
# comment it back for launching from run.py as this will be set | ||
# with full path to onnx to allow running in a separate run dir | ||
# session = onnxruntime.InferenceSession("model.onnx", None) | ||
|
||
# start an onnxrt session | ||
session = onnxruntime.InferenceSession("model.onnx", None) | ||
|
||
# fill the lines that set test_input and onnx_output | ||
# these two are special names and should not be changed | ||
test_input = np.random.rand(1, 3, 224, 224).astype(np.float32) | ||
test_input = numpy.random.rand(1, 3, 224, 224).astype(numpy.float32) | ||
print("Input:", test_input) | ||
|
||
# Get the name of the input of the model | ||
input_name = session.get_inputs()[0].name | ||
|
||
# call inference session | ||
onnx_output = [session.run([], {input_name: test_input})[0]] | ||
test_output = [session.run([], {input_name: test_input})[0]] | ||
print("Onput:", test_output) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,3 +18,4 @@ def name(self): | |
|
||
model = op_linear() | ||
test_input = torch.randn(8, 3) | ||
test_output = model(test_input) |
Oops, something went wrong.