diff --git a/pancax/__init__.py b/pancax/__init__.py index 1c21b75..d111c96 100644 --- a/pancax/__init__.py +++ b/pancax/__init__.py @@ -39,7 +39,6 @@ Tet4Element, \ Tet10Element from .history_writer import EnsembleHistoryWriter, HistoryWriter -from .logging import EnsembleLogger, Logger from .loss_functions import \ DirichletBCLoss, \ NeumannBCLoss, \ @@ -147,9 +146,6 @@ # history writers "EnsembleHistoryWriter", "HistoryWriter", - # loggers - "EnsembleLogger", - "Logger", # loss functions "DirichletBCLoss", "NeumannBCLoss", diff --git a/pancax/logging.py b/pancax/logging.py deleted file mode 100644 index fc213d0..0000000 --- a/pancax/logging.py +++ /dev/null @@ -1,101 +0,0 @@ -from abc import abstractmethod -from pathlib import Path -from typing import List -import equinox as eqx - - -def log_loss(loss, n, log_every): - if n % log_every == 0: - print(f"Epoch {n}:") - print(f"\tLoss = {loss[0].item()}") - for key, val in loss[1].items(): - if key == "props" or key == "dprops": - print(f"\t{key} = {val}") - else: - print(f"\t{key} = {val.item()}") - - -class BaseLogger(eqx.Module): - log_every: int - - @abstractmethod - def flush(self): - pass - - def log_loss(self, loss, epoch): - if epoch % self.log_every == 0: - self.write_epoch_value(epoch) - self.write_loss_value(loss) - self.write_aux_values(loss) - self.flush() - - @abstractmethod - def write_aux_values(self, loss): - pass - - @abstractmethod - def write_epoch_value(self, epoch): - pass - - @abstractmethod - def write_loss_value(self, loss): - pass - - -class Logger(BaseLogger): - log_file: any - - def __init__(self, log_file_in: str, log_every: int) -> None: - super().__init__(log_every) - log_file_in = Path(log_file_in) - self.log_file = open(log_file_in, "w") - - def __exit__(self, exc_type, exc_value, exc_traceback): - print("Closing log file.") - self.log_file.close() - - def flush(self): - self.log_file.flush() - - def write_aux_values(self, loss): - for key, val in loss[1].items(): - if key == "props" or key == "dprops": - self.log_file.write(f" {key} = {val}\n") - else: - self.log_file.write(f" {key} = {val.item()}\n") - - def write_epoch_value(self, epoch): - self.log_file.write(f"Epoch {epoch}:\n") - - def write_loss_value(self, loss): - self.log_file.write(f" Loss = {loss[0].item()}\n") - - -class EnsembleLogger(BaseLogger): - loggers: List[Logger] - - def __init__(self, base_name: str, n_pinns: int, log_every: int) -> None: - super().__init__(log_every) - self.loggers = [ - Logger(f"{base_name}_{n}.log", log_every) for n in range(n_pinns) - ] - - def flush(self): - for logger in self.loggers: - logger.flush() - - def write_aux_values(self, loss): - for key, val in loss[1].items(): - for n, logger in enumerate(self.loggers): - if key == "props" or key == "dprops": - logger.log_file.write(f" {key} = {val[n]}\n") - else: - logger.log_file.write(f" {key} = {val[n].item()}\n") - - def write_epoch_value(self, epoch): - for logger in self.loggers: - logger.write_epoch_value(epoch) - - def write_loss_value(self, loss): - for n, val in enumerate(loss[0]): - self.loggers[n].log_file.write(f" Loss = {val.item()}\n") diff --git a/pancax/utils.py b/pancax/utils.py index 17b6a88..66cfd66 100644 --- a/pancax/utils.py +++ b/pancax/utils.py @@ -15,57 +15,41 @@ def find_data_file(data_file_in: str): call_file = Path(inspect.stack()[1].filename) call_file_dir = call_file.parent + data_file = Path(os.path.join(call_file_dir, data_file_in)) + + if data_file.is_file(): + print(f"Found {data_file_in} in {data_file.parent}") + return data_file + data_file = Path(os.path.join(call_file_dir, "data", data_file_in)) - if not data_file.is_file(): - data_file = Path(os.path.join(call_file_dir, data_file)) - if not data_file.is_file(): - raise DataFileNotFoundException( - f"Could not find data file {data_file_in} in either " - f"{call_file_dir} or {call_file_dir}/data" - ) + if data_file.is_file(): + print(f"Found {data_file_in} in {data_file.parent}") + return data_file - print(f"Found {data_file_in} in {data_file.parent}") - return data_file + raise DataFileNotFoundException( + f"Could not find data file {data_file_in} in either " + f"{call_file_dir} or {call_file_dir}/data" + ) def find_mesh_file(mesh_file_in: str): call_file = Path(inspect.stack()[1].filename) call_file_dir = call_file.parent - mesh_file = Path(os.path.join(call_file_dir, "mesh", mesh_file_in)) - - if not mesh_file.is_file(): - mesh_file = Path(os.path.join(call_file_dir, mesh_file)) - if not mesh_file.is_file(): - raise MeshFileNotFoundException( - f"Could not find data file {mesh_file_in} in either " - f"{call_file_dir} or {call_file_dir}/mesh" - ) - - print(f"Found {mesh_file_in} in {mesh_file.parent}") - return mesh_file - - -def set_checkpoint_file(checkpoint_file_base: str): - call_file = Path(inspect.stack()[3].filename) - call_file_dir = call_file.parent - - checkpoint_dir = Path(os.path.join(call_file_dir, "checkpoint")) + mesh_file = Path(os.path.join(call_file_dir, mesh_file_in)) - if not checkpoint_dir.is_dir(): - os.makedirs(checkpoint_dir) + if mesh_file.is_file(): + print(f"Found {mesh_file_in} in {mesh_file.parent}") + return mesh_file - checkpoint_file = Path(os.path.join(checkpoint_dir, checkpoint_file_base)) + mesh_file = Path(os.path.join(call_file, "mesh", mesh_file_in)) - # if not mesh_file.is_file(): - # mesh_file = Path(os.path.join(call_file_dir, mesh_file)) - # if not mesh_file.is_file(): - # raise MeshFileNotFoundException( - # f'Could not find data file {mesh_file_in} in either ' - # f'{call_file_dir} or {call_file_dir}/mesh' - # ) + if mesh_file.is_file(): + print(f"Found {mesh_file_in} in {mesh_file.parent}") + return mesh_file - # print(f'Found {mesh_file_in} in {mesh_file.parent}') - # return mesh_file - return checkpoint_file + raise MeshFileNotFoundException( + f"Could not find data file {mesh_file_in} in either " + f"{call_file_dir} or {call_file_dir}/mesh" + ) diff --git a/pyproject.toml b/pyproject.toml index d1640d2..41f0d5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = 'pancax' -version = '0.0.14' +version = '0.0.15' authors = [ {name = 'Craig M. Hamel, email = '} ] @@ -15,16 +15,14 @@ dependencies = [ [project.optional-dependencies] cpu = [ - 'chex', - 'equinox==0.13', - 'jax==0.6.2', + 'equinox~=0.13', + 'jax~=0.6.2', 'jaxtyping', 'optax' ] cuda = [ - 'chex', - 'equinox==0.13', - 'jax[cuda12]==0.6.2', + 'equinox~=0.13', + 'jax[cuda12]~=0.6.2', 'jaxtyping', 'optax' ] @@ -41,8 +39,8 @@ dev = [ ] rocm = [ 'chex', - 'equinox==0.12.1', - 'jax[rocm]==0.5.0', + 'equinox~=0.13', + 'jax[rocm]~=0.5', 'jaxtyping', 'optax' ] @@ -57,7 +55,6 @@ build-backend = "hatchling.build" [project.scripts] pancax = "pancax.cli:pancax_main" -# [tool.setuptools] [tool.hatch.build.targets.wheel] packages = [ 'pancax' @@ -71,10 +68,3 @@ exclude_lines = [ "def __repr__", "pass" ] - -#[tool.pytest.ini_options] -#minversion = "6.0" -#addopts = "-ra -q --disable-warnings" -#testpaths = [ -# "test" -#] diff --git a/scripts/rocm-docker.sh b/scripts/rocm-docker.sh index 7d381ae..3ebe6c6 100755 --- a/scripts/rocm-docker.sh +++ b/scripts/rocm-docker.sh @@ -9,7 +9,3 @@ docker run -it \ seccomp=unconfined \ -v $(pwd):/home/temp_user/pancax \ pancax /bin/bash -# --name rocm_jax rocm/jax-community:rocm6.2.3-jax0.4.33-py3.12.6 /bin/bash - -#docker attach rocm_jax - diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..bc6d859 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,26 @@ +def test_find_data_file(): + from pancax import find_data_file + find_data_file("data_global.csv") + + +def test_find_data_file_not_found(): + from pancax import find_data_file + from pancax.utils import DataFileNotFoundException + import pytest + + with pytest.raises(DataFileNotFoundException): + find_data_file("bad_file_name.csv") + + +def test_find_mesh_file(): + from pancax import find_mesh_file + find_mesh_file("mesh.g") + + +def test_find_mesh_file_not_found(): + from pancax import find_mesh_file + from pancax.utils import MeshFileNotFoundException + import pytest + + with pytest.raises(MeshFileNotFoundException): + find_mesh_file("bad_mesh_file.g")