diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0bd9cb5 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,31 @@ +"""Shared pytest fixtures and utilities for pyamtrack tests.""" + +import pytest + + +@pytest.fixture +def electron_energy_MeV_low(): + """Fixture providing a low electron energy in MeV for tests.""" + return 100.0 + + +@pytest.fixture +def electron_energy_MeV_high(): + """Fixture providing a high electron energy in MeV for tests.""" + return 1000.0 + + +def assert_list_of_type(result, expected_type, expected_first_value=None): + """ + Helper function to validate a list returned by getter functions. + + Args: + result: The list to validate + expected_type: The expected type of list elements (e.g., int, str) + expected_first_value: Optional expected value of the first element + """ + assert isinstance(result, list) + assert len(result) > 0 + assert all(isinstance(item, expected_type) for item in result) + if expected_first_value is not None: + assert result[0] == expected_first_value diff --git a/tests/test_cartesian_product.py b/tests/test_cartesian_product.py index 385d6c0..8048cb5 100644 --- a/tests/test_cartesian_product.py +++ b/tests/test_cartesian_product.py @@ -8,6 +8,25 @@ Shape = Union[int, tuple[int, ...]] +def generate_test_args(size1: Shape, size2: Shape, size3: Shape) -> tuple: + """ + Generate test arguments for cartesian product tests. + + Args: + size1: Size specification for first argument (energy) + size2: Size specification for second argument (material) + size3: Size specification for third argument (model) + + Returns: + Tuple of (energy, material, model) arrays with specified shapes + """ + return ( + np.random.uniform(900, 1100, size1), + np.random.randint(2, 8, size2), + np.random.randint(2, 8, size3), + ) + + def check_correct_shape(output: np.ndarray, size1: Shape, size2: Shape, size3: Shape) -> bool: """ Check whether a NumPy array has the expected shape based on three size specifications. @@ -87,11 +106,7 @@ def compare_coords(output: np.ndarray, inputs: tuple, coords: list[tuple[int, .. def test_cartesian_product(size1, size2, size3) -> None: """Simple tests for cartesian product""" - args = ( - np.random.uniform(900, 1100, size1), - np.random.randint(2, 8, size2), - np.random.randint(2, 8, size3), - ) + args = generate_test_args(size1, size2, size3) output = electron_range( *args, @@ -117,11 +132,7 @@ def test_cartesian_product(size1, size2, size3) -> None: ) def test_with_list_input(size1, size2, size3, arg_to_list): """Tests with one of the arguments being a list""" - args = [ - np.random.uniform(900, 1100, size1), - np.random.randint(2, 8, size2), - np.random.randint(2, 8, size3), - ] + args = list(generate_test_args(size1, size2, size3)) args[arg_to_list] = args[arg_to_list].tolist() @@ -150,11 +161,7 @@ def test_with_list_input(size1, size2, size3, arg_to_list): def test_with_scalar_input(size1, size2, size3): """Tests with one of the arguments being a scalar""" - args = [ - np.random.uniform(900, 1100, size1), - np.random.randint(2, 8, size2), - np.random.randint(2, 8, size3), - ] + args = list(generate_test_args(size1, size2, size3)) for i, size in zip(range(3), [size1, size2, size3]): if size == 1: diff --git a/tests/test_materials.py b/tests/test_materials.py index 9168b68..4dd8a90 100644 --- a/tests/test_materials.py +++ b/tests/test_materials.py @@ -1,6 +1,7 @@ import pytest import pyamtrack +from tests.conftest import assert_list_of_type def test_material_initialization_by_id(): @@ -29,26 +30,17 @@ def test_material_initialization_by_invalid_name(): def test_get_ids(): ids = pyamtrack.materials.get_ids() - assert isinstance(ids, list) - assert len(ids) > 0 - assert all(isinstance(id, int) for id in ids) - assert ids[0] == 1 # Check the first ID + assert_list_of_type(ids, int, expected_first_value=1) def test_get_long_names(): names = pyamtrack.materials.get_long_names() - assert isinstance(names, list) - assert len(names) > 0 - assert all(isinstance(name, str) for name in names) - assert names[0] == "Water, Liquid" # Check the first name + assert_list_of_type(names, str, expected_first_value="Water, Liquid") def test_get_short_names(): names = pyamtrack.materials.get_names() - assert isinstance(names, list) - assert len(names) > 0 - assert all(isinstance(name, str) for name in names) - assert names[0] == "water_liquid" # Check the first name + assert_list_of_type(names, str, expected_first_value="water_liquid") def test_via_object(): diff --git a/tests/test_particles.py b/tests/test_particles.py index b9e21ca..6246df2 100644 --- a/tests/test_particles.py +++ b/tests/test_particles.py @@ -1,6 +1,7 @@ import pytest import pyamtrack +from tests.conftest import assert_list_of_type def test_particle_initialization_by_id(): @@ -81,18 +82,12 @@ def test_particle_from_string_invalid(): def test_get_names(): names = pyamtrack.particles.get_names() - assert isinstance(names, list) - assert len(names) > 0 - assert all(isinstance(name, str) for name in names) - assert names[0] == "Hydrogen" # Check the first name + assert_list_of_type(names, str, expected_first_value="Hydrogen") def test_get_acronyms(): names = pyamtrack.particles.get_acronyms() - assert isinstance(names, list) - assert len(names) > 0 - assert all(isinstance(name, str) for name in names) - assert names[0] == "H" # Check the first name + assert_list_of_type(names, str, expected_first_value="H") def test_via_name_object(): diff --git a/tests/test_stopping.py b/tests/test_stopping.py index 68bf026..5b3b126 100644 --- a/tests/test_stopping.py +++ b/tests/test_stopping.py @@ -4,43 +4,37 @@ import pyamtrack.stopping -@pytest.fixture -def electron_energy_MeV(): - """Fixture providing the electron energy in MeV for tests.""" - return 1000.0 - - -def test_electron_range(electron_energy_MeV): +def test_electron_range(electron_energy_MeV_high): """Test the electron_range function for various inputs.""" - range_m = pyamtrack.stopping.electron_range(electron_energy_MeV) + range_m = pyamtrack.stopping.electron_range(electron_energy_MeV_high) assert range_m > 0.01, "Expected positive range for positive energy." -def test_electron_range_air_vs_water(electron_energy_MeV): +def test_electron_range_air_vs_water(electron_energy_MeV_high): """Test the electron_range function for air and water.""" - range_air = pyamtrack.stopping.electron_range(electron_energy_MeV, pyamtrack.materials.air) - range_water = pyamtrack.stopping.electron_range(electron_energy_MeV, pyamtrack.materials.water_liquid) + range_air = pyamtrack.stopping.electron_range(electron_energy_MeV_high, pyamtrack.materials.air) + range_water = pyamtrack.stopping.electron_range(electron_energy_MeV_high, pyamtrack.materials.water_liquid) assert range_air > range_water, "Expected range in air to be larger than in water." -def test_material_assignment(electron_energy_MeV): +def test_material_assignment(electron_energy_MeV_high): """Test the material assignment by ID.""" - range_default = pyamtrack.stopping.electron_range(electron_energy_MeV) - range_material_name = pyamtrack.stopping.electron_range(electron_energy_MeV, pyamtrack.materials.water_liquid) + range_default = pyamtrack.stopping.electron_range(electron_energy_MeV_high) + range_material_name = pyamtrack.stopping.electron_range(electron_energy_MeV_high, pyamtrack.materials.water_liquid) assert range_default == range_material_name, "Expected range to be the same for default and material name." - range_material_id = pyamtrack.stopping.electron_range(electron_energy_MeV, 1) + range_material_id = pyamtrack.stopping.electron_range(electron_energy_MeV_high, 1) assert range_default == range_material_id, "Expected range to be the same for default and material ID." -def test_mixed_parameter_types(electron_energy_MeV): +def test_mixed_parameter_types(electron_energy_MeV_high): """Test passing each parameter as list or numpy.ndarray, check for output type and shape""" - range_material_in_array = pyamtrack.stopping.electron_range(energy_MeV=[electron_energy_MeV]) - range_many_materials = pyamtrack.stopping.electron_range(electron_energy_MeV, [1, 2], 3) + range_material_in_array = pyamtrack.stopping.electron_range(energy_MeV=[electron_energy_MeV_high]) + range_many_materials = pyamtrack.stopping.electron_range(electron_energy_MeV_high, [1, 2], 3) range_many_methods = pyamtrack.stopping.electron_range( - electron_energy_MeV, pyamtrack.materials.water_liquid, [1, 2] + electron_energy_MeV_high, pyamtrack.materials.water_liquid, [1, 2] ) range_many_materials_and_methods = pyamtrack.stopping.electron_range( - electron_energy_MeV, [0, 1, pyamtrack.materials.water_liquid], [3, 4, "tabata"] + electron_energy_MeV_high, [0, 1, pyamtrack.materials.water_liquid], [3, 4, "tabata"] ) assert isinstance(range_material_in_array, np.ndarray) and range_material_in_array.shape == (1,) assert isinstance(range_many_materials, np.ndarray) and range_many_materials.shape == (2,) @@ -84,22 +78,22 @@ def test_arrays_with_mixed_dtypes(dtype1: type, dtype2: type): assert isinstance(range_numpy_arrays, np.ndarray) and range_numpy_arrays.shape == (3,) -def test_material_assignment_invalid(electron_energy_MeV): +def test_material_assignment_invalid(electron_energy_MeV_high): """Test the material assignment with an invalid ID.""" with pytest.raises( RuntimeError, match="Material argument must be an integer or a pyamtrack.materials.Material object", ): - pyamtrack.stopping.electron_range(electron_energy_MeV, "aaa") # Invalid ID + pyamtrack.stopping.electron_range(electron_energy_MeV_high, "aaa") # Invalid ID with pytest.raises( RuntimeError, match="Material argument must be an integer or a pyamtrack.materials.Material object", ): - pyamtrack.stopping.electron_range(electron_energy_MeV, pyamtrack.materials.get_ids) + pyamtrack.stopping.electron_range(electron_energy_MeV_high, pyamtrack.materials.get_ids) @pytest.mark.skip -def test_invalid_id(electron_energy_MeV): +def test_invalid_id(electron_energy_MeV_high): """Test the electron_range function with an invalid ID.""" with pytest.raises(ValueError, match="Invalid material ID"): - pyamtrack.stopping.electron_range(electron_energy_MeV, 1000000) + pyamtrack.stopping.electron_range(electron_energy_MeV_high, 1000000) diff --git a/tests/test_stopping_models.py b/tests/test_stopping_models.py index c5c1e34..118822a 100644 --- a/tests/test_stopping_models.py +++ b/tests/test_stopping_models.py @@ -5,12 +5,6 @@ import pyamtrack.stopping as stopping -@pytest.fixture -def electron_energy_MeV(): - """Fixture providing the electron energy in MeV for tests.""" - return 100.0 - - @pytest.fixture def models(): """Fixture providing all available models.""" @@ -38,38 +32,38 @@ def test_model_id_mapping(): assert stopping.model("scholz_new") == 8 -def test_invalid_model(electron_energy_MeV): +def test_invalid_model(electron_energy_MeV_low): """Test handling of invalid model names.""" with pytest.raises(ValueError, match="Unknown model name: invalid_model"): - stopping.electron_range(electron_energy_MeV, model="invalid_model") + stopping.electron_range(electron_energy_MeV_low, model="invalid_model") with pytest.raises(TypeError): - stopping.electron_range(electron_energy_MeV, model=None) + stopping.electron_range(electron_energy_MeV_low, model=None) @pytest.mark.parametrize( "model_name", ["butts_katz", "waligorski", "geiss", "scholz", "edmund", "tabata", "scholz_new"], ) -def test_model_output_validity(electron_energy_MeV, model_name): +def test_model_output_validity(electron_energy_MeV_low, model_name): """Test that each model produces physically meaningful results.""" - range_m = stopping.electron_range(electron_energy_MeV, model=model_name) + range_m = stopping.electron_range(electron_energy_MeV_low, model=model_name) assert range_m > 0, f"{model_name} model returned negative range" assert range_m < 1000, f"{model_name} model returned unreasonably large range" -def test_model_consistency(electron_energy_MeV): +def test_model_consistency(electron_energy_MeV_low): """Test that models can be specified by both name and ID.""" - range_by_name = stopping.electron_range(electron_energy_MeV, model="tabata") - range_by_id = stopping.electron_range(electron_energy_MeV, model=7) + range_by_name = stopping.electron_range(electron_energy_MeV_low, model="tabata") + range_by_id = stopping.electron_range(electron_energy_MeV_low, model=7) assert range_by_name == range_by_id -def test_model_relative_ranges(electron_energy_MeV, models): +def test_model_relative_ranges(electron_energy_MeV_low, models): """Test relative behavior of different models. While models may give different results, they should all be within reasonable physical bounds of each other for the same input.""" - ranges = [stopping.electron_range(electron_energy_MeV, model=m) for m in models] + ranges = [stopping.electron_range(electron_energy_MeV_low, model=m) for m in models] max_range = max(ranges) min_range = min(ranges) @@ -86,10 +80,10 @@ def test_energy_scaling(): assert ranges[1] > ranges[0], f"{model} model doesn't show expected energy scaling" -def test_material_independence(electron_energy_MeV, models): +def test_material_independence(electron_energy_MeV_low, models): """Test that models work with different materials.""" materials = [1, pyamtrack.materials.water_liquid, pyamtrack.materials.air] for model in models: - ranges = [stopping.electron_range(electron_energy_MeV, material=m, model=model) for m in materials] + ranges = [stopping.electron_range(electron_energy_MeV_low, material=m, model=model) for m in materials] assert all(r > 0 for r in ranges), f"{model} failed with some material"