diff --git a/ada_verona/database/machine_learning_model/onnx_network.py b/ada_verona/database/machine_learning_model/onnx_network.py index 7238bbd..0cd589a 100644 --- a/ada_verona/database/machine_learning_model/onnx_network.py +++ b/ada_verona/database/machine_learning_model/onnx_network.py @@ -19,6 +19,7 @@ import onnx import torch from onnx2torch import convert +from onnxsim import simplify from ada_verona.database.machine_learning_model.network import Network from ada_verona.database.machine_learning_model.torch_model_wrapper import TorchModelWrapper @@ -102,7 +103,21 @@ def load_pytorch_model(self) -> torch.nn.Module: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch_model_wrapper = self.torch_model_wrapper if torch_model_wrapper is None: - torch_model = convert(self.path).to(device) + onnx_model = self.load_onnx_model() + # Simplify model + try: + model_simp, check = simplify(onnx_model) + if not check: + print(f"ONNX-simplifier validation failed for {self.name}, using original.") + model_to_convert = onnx_model + else: + model_to_convert = model_simp + except Exception as e: + print(f"Simplification failed ({e}). Attempting to convert original model.") + model_to_convert = onnx_model + + torch_model = convert(model_to_convert).to(device) + torch_model_wrapper = TorchModelWrapper(torch_model, self.get_input_shape()) self.torch_model_wrapper = torch_model_wrapper diff --git a/requirements.txt b/requirements.txt index 9fb2e33..6f840fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ numpy>=1.24.3 onnx>=1.14.0 onnxruntime>=1.14.1 onnx2torch>=1.5.14 +onnxsim>=0.4.0 pandas>=2.0.1 PyYAML>=6.0.1 result>=0.9.0