From 6cde304a5c86b20819e8110c1ce729692c178435 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 5 Jun 2025 10:21:25 +0200 Subject: [PATCH 1/2] support optional `data_path` args --- apax/nodes/model.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/apax/nodes/model.py b/apax/nodes/model.py index 1f069d00..efde46e2 100644 --- a/apax/nodes/model.py +++ b/apax/nodes/model.py @@ -47,9 +47,12 @@ class Apax(ApaxBase): verbosity of logging during training """ - data: list[ase.Atoms] = zntrack.deps() + data: list[ase.Atoms]|None = zntrack.deps() + data_path: str|pathlib.Path|None = zntrack.deps_path(None) + config: str = zntrack.params_path() - validation_data: list[ase.Atoms] = zntrack.deps() + validation_data: list[ase.Atoms]|None = zntrack.deps() + validation_data_path: str|pathlib.Path|None = zntrack.deps_path(None) model: t.Optional[ApaxBase] = zntrack.deps(None) nl_skin: float = zntrack.params(0.5) log_level: str = zntrack.params("info") @@ -61,6 +64,18 @@ class Apax(ApaxBase): metrics: dict = zntrack.metrics() + def __post_init__(self): + super().__post_init__() + + if self.data is not None and self.data_path is not None: + raise ValueError( + "You can either provide `data` or `data_path`, not both." + ) + if self.validation_data is not None and self.validation_data_path is not None: + raise ValueError( + "You can either provide `validation_data` or `validation_data_path`, not both." + ) + @property def parameter(self) -> dict: parameter = yaml.safe_load(self.state.fs.read_text(self.config)) @@ -68,8 +83,8 @@ def parameter(self) -> dict: custom_parameters = { "directory": self.model_directory.as_posix(), "experiment": "", - "train_data_path": self.train_data_file.as_posix(), - "val_data_path": self.validation_data_file.as_posix(), + "train_data_path": self.train_data_file.as_posix() if self.data is None else self.data_path, + "val_data_path": self.validation_data_file.as_posix() if self.validation_data is None else self.validation_data_path, } if self.model is not None: @@ -97,12 +112,18 @@ def get_metrics(self): def run(self): """Primary method to run which executes all steps of the model training""" - if not self.state.restarted: - train_db = znh5md.IO(self.train_data_file.as_posix()) - train_db.extend(self.data) - val_db = znh5md.IO(self.validation_data_file.as_posix()) - val_db.extend(self.validation_data) + if not self.state.restarted: + if self.data is not None: + train_db = znh5md.IO(self.train_data_file.as_posix()) + train_db.extend(self.data) + else: + self.train_data_file.write_text(f"Using {self.data_path} instead") + if self.validation_data is not None: + val_db = znh5md.IO(self.validation_data_file.as_posix()) + val_db.extend(self.validation_data) + else: + self.validation_data_file.write_text(f"Using {self.validation_data_path} instead") csv_path = self.model_directory / "log.csv" if self.state.restarted and csv_path.is_file(): From 82d9b2e7eaf5501bcbe8c0909eaadb62c0a92007 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Jun 2025 10:54:37 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/nodes/model.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/apax/nodes/model.py b/apax/nodes/model.py index efde46e2..8fb1646b 100644 --- a/apax/nodes/model.py +++ b/apax/nodes/model.py @@ -47,12 +47,12 @@ class Apax(ApaxBase): verbosity of logging during training """ - data: list[ase.Atoms]|None = zntrack.deps() - data_path: str|pathlib.Path|None = zntrack.deps_path(None) + data: list[ase.Atoms] | None = zntrack.deps() + data_path: str | pathlib.Path | None = zntrack.deps_path(None) config: str = zntrack.params_path() - validation_data: list[ase.Atoms]|None = zntrack.deps() - validation_data_path: str|pathlib.Path|None = zntrack.deps_path(None) + validation_data: list[ase.Atoms] | None = zntrack.deps() + validation_data_path: str | pathlib.Path | None = zntrack.deps_path(None) model: t.Optional[ApaxBase] = zntrack.deps(None) nl_skin: float = zntrack.params(0.5) log_level: str = zntrack.params("info") @@ -68,9 +68,7 @@ def __post_init__(self): super().__post_init__() if self.data is not None and self.data_path is not None: - raise ValueError( - "You can either provide `data` or `data_path`, not both." - ) + raise ValueError("You can either provide `data` or `data_path`, not both.") if self.validation_data is not None and self.validation_data_path is not None: raise ValueError( "You can either provide `validation_data` or `validation_data_path`, not both." @@ -83,8 +81,12 @@ def parameter(self) -> dict: custom_parameters = { "directory": self.model_directory.as_posix(), "experiment": "", - "train_data_path": self.train_data_file.as_posix() if self.data is None else self.data_path, - "val_data_path": self.validation_data_file.as_posix() if self.validation_data is None else self.validation_data_path, + "train_data_path": self.train_data_file.as_posix() + if self.data is None + else self.data_path, + "val_data_path": self.validation_data_file.as_posix() + if self.validation_data is None + else self.validation_data_path, } if self.model is not None: @@ -112,7 +114,6 @@ def get_metrics(self): def run(self): """Primary method to run which executes all steps of the model training""" - if not self.state.restarted: if self.data is not None: train_db = znh5md.IO(self.train_data_file.as_posix()) @@ -123,7 +124,9 @@ def run(self): val_db = znh5md.IO(self.validation_data_file.as_posix()) val_db.extend(self.validation_data) else: - self.validation_data_file.write_text(f"Using {self.validation_data_path} instead") + self.validation_data_file.write_text( + f"Using {self.validation_data_path} instead" + ) csv_path = self.model_directory / "log.csv" if self.state.restarted and csv_path.is_file():