Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion ada_verona/database/machine_learning_model/onnx_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import onnx
import torch
from onnx2torch import convert
from onnxsim import simplify

from ada_verona.database.machine_learning_model.network import Network
from ada_verona.database.machine_learning_model.torch_model_wrapper import TorchModelWrapper
Expand Down Expand Up @@ -102,7 +103,21 @@ def load_pytorch_model(self) -> torch.nn.Module:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_model_wrapper = self.torch_model_wrapper
if torch_model_wrapper is None:
torch_model = convert(self.path).to(device)
onnx_model = self.load_onnx_model()
# Simplify model
try:
model_simp, check = simplify(onnx_model)
if not check:
print(f"ONNX-simplifier validation failed for {self.name}, using original.")
model_to_convert = onnx_model
else:
model_to_convert = model_simp
except Exception as e:
print(f"Simplification failed ({e}). Attempting to convert original model.")
model_to_convert = onnx_model

torch_model = convert(model_to_convert).to(device)

torch_model_wrapper = TorchModelWrapper(torch_model, self.get_input_shape())
self.torch_model_wrapper = torch_model_wrapper

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ numpy>=1.24.3
onnx>=1.14.0
onnxruntime>=1.14.1
onnx2torch>=1.5.14
onnxsim>=0.4.0
pandas>=2.0.1
PyYAML>=6.0.1
result>=0.9.0
Expand Down
Loading