Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pancax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
Tet4Element, \
Tet10Element
from .history_writer import EnsembleHistoryWriter, HistoryWriter
from .logging import EnsembleLogger, Logger
from .loss_functions import \
DirichletBCLoss, \
NeumannBCLoss, \
Expand Down Expand Up @@ -147,9 +146,6 @@
# history writers
"EnsembleHistoryWriter",
"HistoryWriter",
# loggers
"EnsembleLogger",
"Logger",
# loss functions
"DirichletBCLoss",
"NeumannBCLoss",
Expand Down
101 changes: 0 additions & 101 deletions pancax/logging.py

This file was deleted.

66 changes: 25 additions & 41 deletions pancax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,57 +15,41 @@ def find_data_file(data_file_in: str):
call_file = Path(inspect.stack()[1].filename)
call_file_dir = call_file.parent

data_file = Path(os.path.join(call_file_dir, data_file_in))

if data_file.is_file():
print(f"Found {data_file_in} in {data_file.parent}")
return data_file

data_file = Path(os.path.join(call_file_dir, "data", data_file_in))

if not data_file.is_file():
data_file = Path(os.path.join(call_file_dir, data_file))
if not data_file.is_file():
raise DataFileNotFoundException(
f"Could not find data file {data_file_in} in either "
f"{call_file_dir} or {call_file_dir}/data"
)
if data_file.is_file():
print(f"Found {data_file_in} in {data_file.parent}")
return data_file

print(f"Found {data_file_in} in {data_file.parent}")
return data_file
raise DataFileNotFoundException(
f"Could not find data file {data_file_in} in either "
f"{call_file_dir} or {call_file_dir}/data"
)


def find_mesh_file(mesh_file_in: str):
call_file = Path(inspect.stack()[1].filename)
call_file_dir = call_file.parent

mesh_file = Path(os.path.join(call_file_dir, "mesh", mesh_file_in))

if not mesh_file.is_file():
mesh_file = Path(os.path.join(call_file_dir, mesh_file))
if not mesh_file.is_file():
raise MeshFileNotFoundException(
f"Could not find data file {mesh_file_in} in either "
f"{call_file_dir} or {call_file_dir}/mesh"
)

print(f"Found {mesh_file_in} in {mesh_file.parent}")
return mesh_file


def set_checkpoint_file(checkpoint_file_base: str):
call_file = Path(inspect.stack()[3].filename)
call_file_dir = call_file.parent

checkpoint_dir = Path(os.path.join(call_file_dir, "checkpoint"))
mesh_file = Path(os.path.join(call_file_dir, mesh_file_in))

if not checkpoint_dir.is_dir():
os.makedirs(checkpoint_dir)
if mesh_file.is_file():
print(f"Found {mesh_file_in} in {mesh_file.parent}")
return mesh_file

checkpoint_file = Path(os.path.join(checkpoint_dir, checkpoint_file_base))
mesh_file = Path(os.path.join(call_file, "mesh", mesh_file_in))

# if not mesh_file.is_file():
# mesh_file = Path(os.path.join(call_file_dir, mesh_file))
# if not mesh_file.is_file():
# raise MeshFileNotFoundException(
# f'Could not find data file {mesh_file_in} in either '
# f'{call_file_dir} or {call_file_dir}/mesh'
# )
if mesh_file.is_file():
print(f"Found {mesh_file_in} in {mesh_file.parent}")
return mesh_file

# print(f'Found {mesh_file_in} in {mesh_file.parent}')
# return mesh_file
return checkpoint_file
raise MeshFileNotFoundException(
f"Could not find data file {mesh_file_in} in either "
f"{call_file_dir} or {call_file_dir}/mesh"
)
24 changes: 7 additions & 17 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = 'pancax'
version = '0.0.14'
version = '0.0.15'
authors = [
{name = 'Craig M. Hamel, email = <[email protected]>'}
]
Expand All @@ -15,16 +15,14 @@ dependencies = [

[project.optional-dependencies]
cpu = [
'chex',
'equinox==0.13',
'jax==0.6.2',
'equinox~=0.13',
'jax~=0.6.2',
'jaxtyping',
'optax'
]
cuda = [
'chex',
'equinox==0.13',
'jax[cuda12]==0.6.2',
'equinox~=0.13',
'jax[cuda12]~=0.6.2',
'jaxtyping',
'optax'
]
Expand All @@ -41,8 +39,8 @@ dev = [
]
rocm = [
'chex',
'equinox==0.12.1',
'jax[rocm]==0.5.0',
'equinox~=0.13',
'jax[rocm]~=0.5',
'jaxtyping',
'optax'
]
Expand All @@ -57,7 +55,6 @@ build-backend = "hatchling.build"
[project.scripts]
pancax = "pancax.cli:pancax_main"

# [tool.setuptools]
[tool.hatch.build.targets.wheel]
packages = [
'pancax'
Expand All @@ -71,10 +68,3 @@ exclude_lines = [
"def __repr__",
"pass"
]

#[tool.pytest.ini_options]
#minversion = "6.0"
#addopts = "-ra -q --disable-warnings"
#testpaths = [
# "test"
#]
4 changes: 0 additions & 4 deletions scripts/rocm-docker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,3 @@ docker run -it \
seccomp=unconfined \
-v $(pwd):/home/temp_user/pancax \
pancax /bin/bash
# --name rocm_jax rocm/jax-community:rocm6.2.3-jax0.4.33-py3.12.6 /bin/bash

#docker attach rocm_jax

26 changes: 26 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
def test_find_data_file():
from pancax import find_data_file
find_data_file("data_global.csv")


def test_find_data_file_not_found():
from pancax import find_data_file
from pancax.utils import DataFileNotFoundException
import pytest

with pytest.raises(DataFileNotFoundException):
find_data_file("bad_file_name.csv")


def test_find_mesh_file():
from pancax import find_mesh_file
find_mesh_file("mesh.g")


def test_find_mesh_file_not_found():
from pancax import find_mesh_file
from pancax.utils import MeshFileNotFoundException
import pytest

with pytest.raises(MeshFileNotFoundException):
find_mesh_file("bad_mesh_file.g")
Loading