diff --git a/moldesign/_tests/test_gaussian_math.py b/moldesign/_tests/test_gaussian_math.py index d568f22..db80898 100644 --- a/moldesign/_tests/test_gaussian_math.py +++ b/moldesign/_tests/test_gaussian_math.py @@ -312,7 +312,7 @@ def test_vectorized_gaussian_function_evaluations(objkey, request): _assert_almost_equal(vector_results, expected, decimal=8) -@pytest.mark.parametrize('objkey', registered_types['basis_fn']) +@pytest.mark.parametrize('objkey', registered_types['basis_fn'] + ['linear_combination']) def test_gaussian_str_and_repr_works(objkey, request): g1 = request.getfixturevalue(objkey) str(g1) @@ -348,6 +348,23 @@ def test_normalization(objkey, request): assert abs(g1.norm - 1.0) < 1e-12 +def test_linear_combination_normalization(linear_combination): + g1 = linear_combination + oldnorm = g1.norm + + prefactor = (random.random() - 0.5) * 428.23 + for prim in g1: + prim.coeff *= prefactor + + try: + assert g1.norm != oldnorm + except u.DimensionalityError: + pass # this is a reasonable thing to happen too + + g1.normalize() + assert abs(g1.norm - 1.0) < 1e-12 + + def _gfuncval(g, coord): r = g.center - coord if len(coord.shape) > 1: diff --git a/moldesign/_tests/test_wfn.py b/moldesign/_tests/test_wfn.py index f94d1e3..7239a0c 100644 --- a/moldesign/_tests/test_wfn.py +++ b/moldesign/_tests/test_wfn.py @@ -64,6 +64,7 @@ def test_basis_function_3d_grids_same_in_pyscf_and_mdt(molkey, request): @pytest.mark.parametrize('molkey', ['h2_rhf_augccpvdz', 'h2_rhf_sto3g']) +@pytest.mark.screening def test_pyscf_basis_function_space_integral_normalized(molkey, request): mol = request.getfixturevalue(molkey) grid = mdt.mathutils.padded_grid(mol.positions, 8.0 * u.angstrom, npoints=150) diff --git a/moldesign/orbitals/primitives.py b/moldesign/orbitals/primitives.py index 72eff7c..34eceb4 100644 --- a/moldesign/orbitals/primitives.py +++ b/moldesign/orbitals/primitives.py @@ -135,7 +135,7 @@ def normalize(self): """ prefactor = 1.0 / self.norm for primitive in self.primitives: - primitive *= prefactor + primitive.coeff *= prefactor def overlap(self, other): """ Calculate orbital overlap with another object