-
Notifications
You must be signed in to change notification settings - Fork 1.4k
[tmva][sofie] Add a new SOFIE tutorial to show full pipeline from ONNX #21015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
lmoneta
merged 2 commits into
root-project:master
from
lmoneta:tmva_sofie_new_onnx_tutorial
Jan 27, 2026
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| ## \file | ||
| ## \ingroup tutorial_ml | ||
| ## \notebook -nodraw | ||
| ## This macro provides a simple example for: | ||
| ## - creating a model with Pytorch and export to ONNX | ||
| ## - parsing the ONNX file with SOFIE and generate C++ code | ||
| ## - compiling the model using ROOT Cling | ||
| ## - run the code and optionally compare with ONNXRuntime | ||
| ## | ||
| ## | ||
| ## \macro_code | ||
| ## \macro_output | ||
| ## \author Lorenzo Moneta | ||
|
|
||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| import ROOT | ||
| import numpy as np | ||
| import inspect | ||
|
|
||
| def CreateAndTrainModel(modelName): | ||
|
|
||
| model = nn.Sequential( | ||
| nn.Linear(32,16), | ||
| nn.ReLU(), | ||
| nn.Linear(16,8), | ||
| nn.ReLU(), | ||
| nn.Linear(8,2), | ||
| nn.Softmax(dim=1) | ||
| ) | ||
|
|
||
| criterion = nn.MSELoss() | ||
| optimizer = torch.optim.SGD(model.parameters(),lr=0.01) | ||
|
|
||
|
|
||
| #train model with the random data | ||
| for i in range(500): | ||
| x=torch.randn(2,32) | ||
| y=torch.randn(2,2) | ||
| y_pred = model(x) | ||
| loss = criterion(y_pred,y) | ||
| optimizer.zero_grad() | ||
| loss.backward() | ||
| optimizer.step() | ||
|
|
||
| #******************************************************* | ||
| ## EXPORT to ONNX | ||
| # | ||
| # need to evaluate the model before exporting to ONNX | ||
| # and to provide a dummy input tensor to set the input model shape | ||
| model.eval() | ||
|
|
||
| modelFile = modelName + ".onnx" | ||
| dummy_x = torch.randn(1,32) | ||
| model(dummy_x) | ||
|
|
||
| #check for torch.onnx.export parameters | ||
| def filtered_kwargs(func, **candidate_kwargs): | ||
| sig = inspect.signature(func) | ||
| return { | ||
| k: v for k, v in candidate_kwargs.items() | ||
| if k in sig.parameters | ||
| } | ||
| kwargs = filtered_kwargs( | ||
| torch.onnx.export, | ||
| input_names=["input"], | ||
| output_names=["output"], | ||
| external_data=False, # may not exist | ||
| dynamo=True # may not exist | ||
| ) | ||
| print("calling torch.onnx.export with parameters",kwargs) | ||
|
|
||
| try: | ||
| torch.onnx.export(model, dummy_x, modelFile, **kwargs) | ||
| print("model exported to ONNX as",modelFile) | ||
| return modelFile | ||
| except TypeError: | ||
| print("Cannot export model from pytorch to ONNX - with version ",torch.__version__) | ||
| print("Skip tutorial execution") | ||
| exit() | ||
|
|
||
|
|
||
| def ParseModel(modelFile, verbose=False): | ||
|
|
||
| parser = ROOT.TMVA.Experimental.SOFIE.RModelParser_ONNX() | ||
| model = parser.Parse(modelFile,verbose) | ||
| # | ||
| #print model weights | ||
| if (verbose): | ||
| model.PrintInitializedTensors() | ||
| data = model.GetTensorData['float']('0weight') | ||
| print("0weight",data) | ||
| data = model.GetTensorData['float']('2weight') | ||
| print("2weight",data) | ||
|
|
||
| # Generating inference code | ||
| model.Generate(); | ||
| #generate header file (and .dat file) with modelName+.hxx | ||
| model.OutputGenerated(); | ||
| if (verbose) : | ||
| model.PrintGenerated() | ||
|
|
||
| modelCode = modelFile.replace(".onnx",".hxx") | ||
| print("Generated model header file ",modelCode) | ||
| return modelCode | ||
|
|
||
| ################################################################### | ||
| ## Step 1 : Create and Train model | ||
| ################################################################### | ||
|
|
||
| #use an arbitrary modelName | ||
| modelName = "LinearModel" | ||
| modelFile = CreateAndTrainModel(modelName) | ||
|
|
||
|
|
||
| ################################################################### | ||
| ## Step 2 : Parse model and generate inference code with SOFIE | ||
| ################################################################### | ||
|
|
||
| modelCode = ParseModel(modelFile, False) | ||
|
|
||
| ################################################################### | ||
| ## Step 3 : Compile the generated C++ model code | ||
| ################################################################### | ||
|
|
||
| ROOT.gInterpreter.Declare('#include "' + modelCode + '"') | ||
|
|
||
| ################################################################### | ||
| ## Step 4: Evaluate the model | ||
| ################################################################### | ||
|
|
||
| #get first the SOFIE session namespace | ||
| sofie = getattr(ROOT, 'TMVA_SOFIE_' + modelName) | ||
| session = sofie.Session() | ||
|
|
||
| x = np.random.normal(0,1,(1,32)).astype(np.float32) | ||
| print("\n************************************************************") | ||
| print("Running inference with SOFIE ") | ||
| print("\ninput to model is ",x) | ||
| y = session.infer(x) | ||
| # output shape is (1,2) | ||
| y_sofie = np.asarray(y.data()) | ||
| print("-> output using SOFIE = ", y_sofie) | ||
|
|
||
| #check inference with onnx | ||
| try: | ||
| import onnxruntime as ort | ||
| # Load model | ||
| print("Running inference with ONNXRuntime ") | ||
| ort_session = ort.InferenceSession(modelFile) | ||
|
|
||
| # Run inference | ||
| outputs = ort_session.run(None, {"input": x}) | ||
| y_ort = outputs[0] | ||
| print("-> output using ORT =", y_ort) | ||
|
|
||
| testFailed = abs(y_sofie-y_ort) > 0.01 | ||
| if (np.any(testFailed)): | ||
| raiseError('Result is different between SOFIE and ONNXRT') | ||
| else : | ||
| print("OK") | ||
|
|
||
| except ImportError: | ||
| print("Missing ONNXRuntime: skipping comparison test") | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.