Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Shared pytest fixtures and utilities for pyamtrack tests."""

import pytest


@pytest.fixture
def electron_energy_MeV_low():
"""Fixture providing a low electron energy in MeV for tests."""
return 100.0


@pytest.fixture
def electron_energy_MeV_high():
"""Fixture providing a high electron energy in MeV for tests."""
return 1000.0


def assert_list_of_type(result, expected_type, expected_first_value=None):
"""
Helper function to validate a list returned by getter functions.

Args:
result: The list to validate
expected_type: The expected type of list elements (e.g., int, str)
expected_first_value: Optional expected value of the first element
"""
assert isinstance(result, list)
assert len(result) > 0
assert all(isinstance(item, expected_type) for item in result)
if expected_first_value is not None:
assert result[0] == expected_first_value
37 changes: 22 additions & 15 deletions tests/test_cartesian_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,25 @@
Shape = Union[int, tuple[int, ...]]


def generate_test_args(size1: Shape, size2: Shape, size3: Shape) -> tuple:
"""
Generate test arguments for cartesian product tests.

Args:
size1: Size specification for first argument (energy)
size2: Size specification for second argument (material)
size3: Size specification for third argument (model)

Returns:
Tuple of (energy, material, model) arrays with specified shapes
"""
return (
np.random.uniform(900, 1100, size1),
np.random.randint(2, 8, size2),
np.random.randint(2, 8, size3),
)


def check_correct_shape(output: np.ndarray, size1: Shape, size2: Shape, size3: Shape) -> bool:
"""
Check whether a NumPy array has the expected shape based on three size specifications.
Expand Down Expand Up @@ -87,11 +106,7 @@ def compare_coords(output: np.ndarray, inputs: tuple, coords: list[tuple[int, ..
def test_cartesian_product(size1, size2, size3) -> None:
"""Simple tests for cartesian product"""

args = (
np.random.uniform(900, 1100, size1),
np.random.randint(2, 8, size2),
np.random.randint(2, 8, size3),
)
args = generate_test_args(size1, size2, size3)

output = electron_range(
*args,
Expand All @@ -117,11 +132,7 @@ def test_cartesian_product(size1, size2, size3) -> None:
)
def test_with_list_input(size1, size2, size3, arg_to_list):
"""Tests with one of the arguments being a list"""
args = [
np.random.uniform(900, 1100, size1),
np.random.randint(2, 8, size2),
np.random.randint(2, 8, size3),
]
args = list(generate_test_args(size1, size2, size3))

args[arg_to_list] = args[arg_to_list].tolist()

Expand Down Expand Up @@ -150,11 +161,7 @@ def test_with_list_input(size1, size2, size3, arg_to_list):
def test_with_scalar_input(size1, size2, size3):
"""Tests with one of the arguments being a scalar"""

args = [
np.random.uniform(900, 1100, size1),
np.random.randint(2, 8, size2),
np.random.randint(2, 8, size3),
]
args = list(generate_test_args(size1, size2, size3))

for i, size in zip(range(3), [size1, size2, size3]):
if size == 1:
Expand Down
16 changes: 4 additions & 12 deletions tests/test_materials.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

import pyamtrack
from tests.conftest import assert_list_of_type


def test_material_initialization_by_id():
Expand Down Expand Up @@ -29,26 +30,17 @@ def test_material_initialization_by_invalid_name():

def test_get_ids():
ids = pyamtrack.materials.get_ids()
assert isinstance(ids, list)
assert len(ids) > 0
assert all(isinstance(id, int) for id in ids)
assert ids[0] == 1 # Check the first ID
assert_list_of_type(ids, int, expected_first_value=1)


def test_get_long_names():
names = pyamtrack.materials.get_long_names()
assert isinstance(names, list)
assert len(names) > 0
assert all(isinstance(name, str) for name in names)
assert names[0] == "Water, Liquid" # Check the first name
assert_list_of_type(names, str, expected_first_value="Water, Liquid")


def test_get_short_names():
names = pyamtrack.materials.get_names()
assert isinstance(names, list)
assert len(names) > 0
assert all(isinstance(name, str) for name in names)
assert names[0] == "water_liquid" # Check the first name
assert_list_of_type(names, str, expected_first_value="water_liquid")


def test_via_object():
Expand Down
11 changes: 3 additions & 8 deletions tests/test_particles.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

import pyamtrack
from tests.conftest import assert_list_of_type


def test_particle_initialization_by_id():
Expand Down Expand Up @@ -81,18 +82,12 @@ def test_particle_from_string_invalid():

def test_get_names():
names = pyamtrack.particles.get_names()
assert isinstance(names, list)
assert len(names) > 0
assert all(isinstance(name, str) for name in names)
assert names[0] == "Hydrogen" # Check the first name
assert_list_of_type(names, str, expected_first_value="Hydrogen")


def test_get_acronyms():
names = pyamtrack.particles.get_acronyms()
assert isinstance(names, list)
assert len(names) > 0
assert all(isinstance(name, str) for name in names)
assert names[0] == "H" # Check the first name
assert_list_of_type(names, str, expected_first_value="H")


def test_via_name_object():
Expand Down
44 changes: 19 additions & 25 deletions tests/test_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,37 @@
import pyamtrack.stopping


@pytest.fixture
def electron_energy_MeV():
"""Fixture providing the electron energy in MeV for tests."""
return 1000.0


def test_electron_range(electron_energy_MeV):
def test_electron_range(electron_energy_MeV_high):
"""Test the electron_range function for various inputs."""
range_m = pyamtrack.stopping.electron_range(electron_energy_MeV)
range_m = pyamtrack.stopping.electron_range(electron_energy_MeV_high)
assert range_m > 0.01, "Expected positive range for positive energy."


def test_electron_range_air_vs_water(electron_energy_MeV):
def test_electron_range_air_vs_water(electron_energy_MeV_high):
"""Test the electron_range function for air and water."""
range_air = pyamtrack.stopping.electron_range(electron_energy_MeV, pyamtrack.materials.air)
range_water = pyamtrack.stopping.electron_range(electron_energy_MeV, pyamtrack.materials.water_liquid)
range_air = pyamtrack.stopping.electron_range(electron_energy_MeV_high, pyamtrack.materials.air)
range_water = pyamtrack.stopping.electron_range(electron_energy_MeV_high, pyamtrack.materials.water_liquid)
assert range_air > range_water, "Expected range in air to be larger than in water."


def test_material_assignment(electron_energy_MeV):
def test_material_assignment(electron_energy_MeV_high):
"""Test the material assignment by ID."""
range_default = pyamtrack.stopping.electron_range(electron_energy_MeV)
range_material_name = pyamtrack.stopping.electron_range(electron_energy_MeV, pyamtrack.materials.water_liquid)
range_default = pyamtrack.stopping.electron_range(electron_energy_MeV_high)
range_material_name = pyamtrack.stopping.electron_range(electron_energy_MeV_high, pyamtrack.materials.water_liquid)
assert range_default == range_material_name, "Expected range to be the same for default and material name."
range_material_id = pyamtrack.stopping.electron_range(electron_energy_MeV, 1)
range_material_id = pyamtrack.stopping.electron_range(electron_energy_MeV_high, 1)
assert range_default == range_material_id, "Expected range to be the same for default and material ID."


def test_mixed_parameter_types(electron_energy_MeV):
def test_mixed_parameter_types(electron_energy_MeV_high):
"""Test passing each parameter as list or numpy.ndarray, check for output type and shape"""
range_material_in_array = pyamtrack.stopping.electron_range(energy_MeV=[electron_energy_MeV])
range_many_materials = pyamtrack.stopping.electron_range(electron_energy_MeV, [1, 2], 3)
range_material_in_array = pyamtrack.stopping.electron_range(energy_MeV=[electron_energy_MeV_high])
range_many_materials = pyamtrack.stopping.electron_range(electron_energy_MeV_high, [1, 2], 3)
range_many_methods = pyamtrack.stopping.electron_range(
electron_energy_MeV, pyamtrack.materials.water_liquid, [1, 2]
electron_energy_MeV_high, pyamtrack.materials.water_liquid, [1, 2]
)
range_many_materials_and_methods = pyamtrack.stopping.electron_range(
electron_energy_MeV, [0, 1, pyamtrack.materials.water_liquid], [3, 4, "tabata"]
electron_energy_MeV_high, [0, 1, pyamtrack.materials.water_liquid], [3, 4, "tabata"]
)
assert isinstance(range_material_in_array, np.ndarray) and range_material_in_array.shape == (1,)
assert isinstance(range_many_materials, np.ndarray) and range_many_materials.shape == (2,)
Expand Down Expand Up @@ -84,22 +78,22 @@ def test_arrays_with_mixed_dtypes(dtype1: type, dtype2: type):
assert isinstance(range_numpy_arrays, np.ndarray) and range_numpy_arrays.shape == (3,)


def test_material_assignment_invalid(electron_energy_MeV):
def test_material_assignment_invalid(electron_energy_MeV_high):
"""Test the material assignment with an invalid ID."""
with pytest.raises(
RuntimeError,
match="Material argument must be an integer or a pyamtrack.materials.Material object",
):
pyamtrack.stopping.electron_range(electron_energy_MeV, "aaa") # Invalid ID
pyamtrack.stopping.electron_range(electron_energy_MeV_high, "aaa") # Invalid ID
with pytest.raises(
RuntimeError,
match="Material argument must be an integer or a pyamtrack.materials.Material object",
):
pyamtrack.stopping.electron_range(electron_energy_MeV, pyamtrack.materials.get_ids)
pyamtrack.stopping.electron_range(electron_energy_MeV_high, pyamtrack.materials.get_ids)


@pytest.mark.skip
def test_invalid_id(electron_energy_MeV):
def test_invalid_id(electron_energy_MeV_high):
"""Test the electron_range function with an invalid ID."""
with pytest.raises(ValueError, match="Invalid material ID"):
pyamtrack.stopping.electron_range(electron_energy_MeV, 1000000)
pyamtrack.stopping.electron_range(electron_energy_MeV_high, 1000000)
30 changes: 12 additions & 18 deletions tests/test_stopping_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@
import pyamtrack.stopping as stopping


@pytest.fixture
def electron_energy_MeV():
"""Fixture providing the electron energy in MeV for tests."""
return 100.0


@pytest.fixture
def models():
"""Fixture providing all available models."""
Expand Down Expand Up @@ -38,38 +32,38 @@ def test_model_id_mapping():
assert stopping.model("scholz_new") == 8


def test_invalid_model(electron_energy_MeV):
def test_invalid_model(electron_energy_MeV_low):
"""Test handling of invalid model names."""
with pytest.raises(ValueError, match="Unknown model name: invalid_model"):
stopping.electron_range(electron_energy_MeV, model="invalid_model")
stopping.electron_range(electron_energy_MeV_low, model="invalid_model")
with pytest.raises(TypeError):
stopping.electron_range(electron_energy_MeV, model=None)
stopping.electron_range(electron_energy_MeV_low, model=None)


@pytest.mark.parametrize(
"model_name",
["butts_katz", "waligorski", "geiss", "scholz", "edmund", "tabata", "scholz_new"],
)
def test_model_output_validity(electron_energy_MeV, model_name):
def test_model_output_validity(electron_energy_MeV_low, model_name):
"""Test that each model produces physically meaningful results."""
range_m = stopping.electron_range(electron_energy_MeV, model=model_name)
range_m = stopping.electron_range(electron_energy_MeV_low, model=model_name)
assert range_m > 0, f"{model_name} model returned negative range"
assert range_m < 1000, f"{model_name} model returned unreasonably large range"


def test_model_consistency(electron_energy_MeV):
def test_model_consistency(electron_energy_MeV_low):
"""Test that models can be specified by both name and ID."""
range_by_name = stopping.electron_range(electron_energy_MeV, model="tabata")
range_by_id = stopping.electron_range(electron_energy_MeV, model=7)
range_by_name = stopping.electron_range(electron_energy_MeV_low, model="tabata")
range_by_id = stopping.electron_range(electron_energy_MeV_low, model=7)
assert range_by_name == range_by_id


def test_model_relative_ranges(electron_energy_MeV, models):
def test_model_relative_ranges(electron_energy_MeV_low, models):
"""Test relative behavior of different models.

While models may give different results, they should all be within
reasonable physical bounds of each other for the same input."""
ranges = [stopping.electron_range(electron_energy_MeV, model=m) for m in models]
ranges = [stopping.electron_range(electron_energy_MeV_low, model=m) for m in models]
max_range = max(ranges)
min_range = min(ranges)

Expand All @@ -86,10 +80,10 @@ def test_energy_scaling():
assert ranges[1] > ranges[0], f"{model} model doesn't show expected energy scaling"


def test_material_independence(electron_energy_MeV, models):
def test_material_independence(electron_energy_MeV_low, models):
"""Test that models work with different materials."""
materials = [1, pyamtrack.materials.water_liquid, pyamtrack.materials.air]

for model in models:
ranges = [stopping.electron_range(electron_energy_MeV, material=m, model=model) for m in materials]
ranges = [stopping.electron_range(electron_energy_MeV_low, material=m, model=model) for m in materials]
assert all(r > 0 for r in ranges), f"{model} failed with some material"
Loading