1717import torch .nn as nn
1818import ROOT
1919import numpy as np
20+ import inspect
2021
2122def CreateAndTrainModel (modelName ):
2223
@@ -43,19 +44,41 @@ def CreateAndTrainModel(modelName):
4344 loss .backward ()
4445 optimizer .step ()
4546
47+ #*******************************************************
48+ ## EXPORT to ONNX
49+ #
50+ # need to evaluate the model before exporting to ONNX
51+ # and to provide a dummy input tensor to set the input model shape
4652 model .eval ()
47- #export the model to ONNX
53+
4854 modelFile = modelName + ".onnx"
4955 dummy_x = torch .randn (1 ,32 )
5056 model (dummy_x )
51- torch .onnx .export (model , dummy_x , modelFile , export_params = True ,
52- dynamo = True , # this is for new PyTorch exporter from version 2.5
53- external_data = False , # this important to avoid weights saved in a different onnx.data file
54- input_names = ["input" ],
55- output_names = ["output" ])
56- print ("model exported to ONNX as" ,modelFile )
57- return modelFile
5857
58+ #check for torch.onnx.export parameters
59+ def filtered_kwargs (func , ** candidate_kwargs ):
60+ sig = inspect .signature (func )
61+ return {
62+ k : v for k , v in candidate_kwargs .items ()
63+ if k in sig .parameters
64+ }
65+ kwargs = filtered_kwargs (
66+ torch .onnx .export ,
67+ input_names = ["input" ],
68+ output_names = ["output" ],
69+ external_data = False , # may not exist
70+ dynamo = True # may not exist
71+ )
72+ print ("calling torch.onnx.export with parameters" ,kwargs )
73+
74+ try :
75+ torch .onnx .export (model , dummy_x , modelFile , ** kwargs )
76+ print ("model exported to ONNX as" ,modelFile )
77+ return modelFile
78+ except TypeError :
79+ print ("Cannot export model from pytorch to ONNX - with version " ,torch .__version__ )
80+ print ("Skip tutorial execution" )
81+ exit ()
5982
6083
6184def ParseModel (modelFile , verbose = False ):
0 commit comments