Skip to content

Commit d7b0912

Browse files
committed
[tmva][sofie] Check arguments of torch.onnx.export function
Also in case of failure calling export , exit tutorials without giving an error Improve also comments following Sanjiban suggested review
1 parent caca586 commit d7b0912

1 file changed

Lines changed: 31 additions & 8 deletions

File tree

tutorials/machine_learning/TMVA_SOFIE_ONNX.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch.nn as nn
1818
import ROOT
1919
import numpy as np
20+
import inspect
2021

2122
def CreateAndTrainModel(modelName):
2223

@@ -43,19 +44,41 @@ def CreateAndTrainModel(modelName):
4344
loss.backward()
4445
optimizer.step()
4546

47+
#*******************************************************
48+
## EXPORT to ONNX
49+
#
50+
# need to evaluate the model before exporting to ONNX
51+
# and to provide a dummy input tensor to set the input model shape
4652
model.eval()
47-
#export the model to ONNX
53+
4854
modelFile = modelName + ".onnx"
4955
dummy_x = torch.randn(1,32)
5056
model(dummy_x)
51-
torch.onnx.export(model, dummy_x, modelFile, export_params=True,
52-
dynamo = True, # this is for new PyTorch exporter from version 2.5
53-
external_data=False, # this important to avoid weights saved in a different onnx.data file
54-
input_names=["input"],
55-
output_names=["output"])
56-
print("model exported to ONNX as",modelFile)
57-
return modelFile
5857

58+
#check for torch.onnx.export parameters
59+
def filtered_kwargs(func, **candidate_kwargs):
60+
sig = inspect.signature(func)
61+
return {
62+
k: v for k, v in candidate_kwargs.items()
63+
if k in sig.parameters
64+
}
65+
kwargs = filtered_kwargs(
66+
torch.onnx.export,
67+
input_names=["input"],
68+
output_names=["output"],
69+
external_data=False, # may not exist
70+
dynamo=True # may not exist
71+
)
72+
print("calling torch.onnx.export with parameters",kwargs)
73+
74+
try:
75+
torch.onnx.export(model, dummy_x, modelFile, **kwargs)
76+
print("model exported to ONNX as",modelFile)
77+
return modelFile
78+
except TypeError:
79+
print("Cannot export model from pytorch to ONNX - with version ",torch.__version__)
80+
print("Skip tutorial execution")
81+
exit()
5982

6083

6184
def ParseModel(modelFile, verbose=False):

0 commit comments

Comments
 (0)