diff --git a/tests/test_generation.py b/tests/test_generation.py index d039115e9..6253bfa4f 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -78,10 +78,28 @@ def _compare_attributes(obj1, obj2, exclude=None): if isinstance(val1, np.ndarray): assert np.allclose(val1, val2) + elif attr == "prim_interstitial_coords": + _compare_prim_interstitial_coords(val1, val2) + elif attr == "defects" and any(len(i.defect_structure) == 0 for i in val1["vacancies"]): + continue # StructureMatcher comparison breaks for empty structures, which we can have with + # our 1-atom primitive Cu input else: assert val1 == val2 +def _compare_prim_interstitial_coords(result, expected): # check attribute set + if result is None: + assert expected is None + return + + assert len(result) == len(expected), "Lengths do not match" + + for (r_coord, r_num, r_list), (e_coord, e_num, e_list) in zip(result, expected): + assert np.array_equal(r_coord, e_coord), "Coordinates do not match" + assert r_num == e_num, "Number of coordinates do not match" + assert all(np.array_equal(r, e) for r, e in zip(r_list, e_list)), "List of arrays do not match" + + class DefectsGeneratorTest(unittest.TestCase): def setUp(self): # don't run heavy tests on GH Actions, these are run locally (too slow without multiprocessing etc) @@ -1453,18 +1471,7 @@ def test_interstitial_coords(self): (np.array([0.75, 0.75, 0.75]), 1, [np.array([0.75, 0.75, 0.75])]), (np.array([0.5, 0.5, 0.5]), 1, [np.array([0.5, 0.5, 0.5])]), ] - - def _test_prim_interstitial_coords(result, expected): # check attribute set - assert len(result) == len(expected), "Lengths do not match" - - for (r_coord, r_num, r_list), (e_coord, e_num, e_list) in zip(result, expected): - assert np.array_equal(r_coord, e_coord), "Coordinates do not match" - assert r_num == e_num, "Number of coordinates do not match" - assert all( - np.array_equal(r, e) for r, e in zip(r_list, e_list) - ), "List of arrays do not match" - - _test_prim_interstitial_coords(CdTe_defect_gen.prim_interstitial_coords, expected) + _compare_prim_interstitial_coords(CdTe_defect_gen.prim_interstitial_coords, expected) # defect_gen_check changes defect_entries ordering, so save to json first: self._save_defect_gen_jsons(CdTe_defect_gen) @@ -1475,7 +1482,7 @@ def _test_prim_interstitial_coords(result, expected): # check attribute set self.prim_cdte, interstitial_coords=[[0.5, 0.5, 0.5]], extrinsic={"Te": ["Se", "S"]} ) assert CdTe_defect_gen.interstitial_coords == [[0.5, 0.5, 0.5]] # check attribute set - _test_prim_interstitial_coords( + _compare_prim_interstitial_coords( CdTe_defect_gen.prim_interstitial_coords, [(np.array([0.5, 0.5, 0.5]), 1, [np.array([0.5, 0.5, 0.5])])], # check attribute set ) @@ -1516,7 +1523,7 @@ def _test_prim_interstitial_coords(result, expected): # check attribute set self._load_and_test_defect_gen_jsons(ytos_defect_gen) assert ytos_defect_gen.interstitial_coords == ytos_interstitial_coords # check attribute set - _test_prim_interstitial_coords( + _compare_prim_interstitial_coords( ytos_defect_gen.prim_interstitial_coords, [ ( @@ -1563,7 +1570,7 @@ def _test_prim_interstitial_coords(result, expected): # check attribute set assert ( CdTe_defect_gen.interstitial_coords == CdTe_supercell_interstitial_coords ) # check attribute set - _test_prim_interstitial_coords( + _compare_prim_interstitial_coords( CdTe_defect_gen.prim_interstitial_coords, [ (