Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions ada_verona/database/machine_learning_model/pytorch_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from pathlib import Path

import numpy as np
Expand All @@ -33,7 +32,7 @@ class PyTorchNetwork(Network):
input_shape (tuple[int]): Input shape of the model.
"""

def __init__(self, model: torch.nn.Module, input_shape: tuple[int], name: str) -> None:
def __init__(self, model: torch.nn.Module, input_shape: tuple[int], name: str, path: Path | None = None) -> None:
"""
Initialize the PyTorchNetwork with architecture and weights paths.

Expand All @@ -46,6 +45,7 @@ def __init__(self, model: torch.nn.Module, input_shape: tuple[int], name: str) -
self.model = model
self.input_shape = input_shape
self._name = name
self._path = path
self.torch_model_wrapper = None

@property
Expand All @@ -57,7 +57,7 @@ def name(self) -> str:
str: The name of the network.
"""
return self._name

@property
def path(self) -> Path:
"""
Expand All @@ -66,8 +66,8 @@ def path(self) -> Path:
Returns:
Path: The path of the network.
"""
return None
return self._path

def get_input_shape(self) -> np.ndarray:
"""
Get the input shape of the PyTorch model.
Expand All @@ -89,14 +89,13 @@ def load_pytorch_model(self) -> torch.nn.Module:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = self.model.to(device)
model.eval()

self.torch_model_wrapper = TorchModelWrapper(model, self.get_input_shape())

return self.torch_model_wrapper


def to_dict(self):
raise NotImplementedError("PytorchNetwork does not support to_dict() function currently.")

def from_dict(cls, data: dict):
raise NotImplementedError("PytorchNetwork does not support from_dict() function currently.")
raise NotImplementedError("PytorchNetwork does not support from_dict() function currently.")