diff --git a/ada_verona/database/machine_learning_model/pytorch_network.py b/ada_verona/database/machine_learning_model/pytorch_network.py index fc31434..0957c09 100644 --- a/ada_verona/database/machine_learning_model/pytorch_network.py +++ b/ada_verona/database/machine_learning_model/pytorch_network.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - from pathlib import Path import numpy as np @@ -33,7 +32,7 @@ class PyTorchNetwork(Network): input_shape (tuple[int]): Input shape of the model. """ - def __init__(self, model: torch.nn.Module, input_shape: tuple[int], name: str) -> None: + def __init__(self, model: torch.nn.Module, input_shape: tuple[int], name: str, path: Path | None = None) -> None: """ Initialize the PyTorchNetwork with architecture and weights paths. @@ -46,6 +45,7 @@ def __init__(self, model: torch.nn.Module, input_shape: tuple[int], name: str) - self.model = model self.input_shape = input_shape self._name = name + self._path = path self.torch_model_wrapper = None @property @@ -57,7 +57,7 @@ def name(self) -> str: str: The name of the network. """ return self._name - + @property def path(self) -> Path: """ @@ -66,8 +66,8 @@ def path(self) -> Path: Returns: Path: The path of the network. """ - return None - + return self._path + def get_input_shape(self) -> np.ndarray: """ Get the input shape of the PyTorch model. @@ -89,14 +89,13 @@ def load_pytorch_model(self) -> torch.nn.Module: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = self.model.to(device) model.eval() - + self.torch_model_wrapper = TorchModelWrapper(model, self.get_input_shape()) - + return self.torch_model_wrapper - def to_dict(self): raise NotImplementedError("PytorchNetwork does not support to_dict() function currently.") - + def from_dict(cls, data: dict): - raise NotImplementedError("PytorchNetwork does not support from_dict() function currently.") \ No newline at end of file + raise NotImplementedError("PytorchNetwork does not support from_dict() function currently.")