diff --git a/mess/hamiltonian.py b/mess/hamiltonian.py index 512101d..04b7a2e 100644 --- a/mess/hamiltonian.py +++ b/mess/hamiltonian.py @@ -13,7 +13,7 @@ from mess.basis import Basis, renorm from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis from mess.interop import to_pyscf -from mess.mesh import Mesh, density, density_and_grad, xcmesh_from_pyscf +from mess.mesh import Mesh, density, density_and_grad, sg1_mesh from mess.orthnorm import symmetric from mess.structure import nuclear_energy from mess.types import FloatNxN, OrthNormTransform @@ -119,7 +119,7 @@ class LDA(eqx.Module): def __init__(self, basis: Basis): self.basis = basis - self.mesh = xcmesh_from_pyscf(basis.structure) + self.mesh = sg1_mesh(basis.structure) def __call__(self, P: FloatNxN) -> ScalarLike: rho = density(self.basis, self.mesh, P) @@ -134,7 +134,7 @@ class PBE(eqx.Module): def __init__(self, basis: Basis): self.basis = basis - self.mesh = xcmesh_from_pyscf(basis.structure) + self.mesh = sg1_mesh(basis.structure) def __call__(self, P: FloatNxN) -> ScalarLike: rho, grad_rho = density_and_grad(self.basis, self.mesh, P) @@ -150,7 +150,7 @@ class PBE0(eqx.Module): def __init__(self, basis: Basis, two_electron: TwoElectron): self.basis = basis - self.mesh = xcmesh_from_pyscf(basis.structure) + self.mesh = sg1_mesh(basis.structure) self.hfx = HartreeFockExchange(two_electron) def __call__(self, P: FloatNxN) -> ScalarLike: @@ -167,7 +167,7 @@ class B3LYP(eqx.Module): def __init__(self, basis: Basis, two_electron: TwoElectron): self.basis = basis - self.mesh = xcmesh_from_pyscf(basis.structure) + self.mesh = sg1_mesh(basis.structure) self.hfx = HartreeFockExchange(two_electron) def __call__(self, P: FloatNxN) -> ScalarLike: diff --git a/test/test_hamiltonian.py b/test/test_hamiltonian.py index 63f1d74..34fd7c6 100644 --- a/test/test_hamiltonian.py +++ b/test/test_hamiltonian.py @@ -1,5 +1,7 @@ import numpy as np import pytest +import equinox as eqx +import jax from jax.experimental import enable_x64 from numpy.testing import assert_allclose from pyscf import dft @@ -40,3 +42,29 @@ def test_energy(inputs, basis_name, mol): actual = H(P) + nuclear_energy(mol) expect = s.energy_tot() assert_allclose(actual, expect, atol=1e-6) + + +def test_autograd_wrt_positions(): + mol = molecule("h2") + scfmol = to_pyscf(mol, basis_name="def2-SVP") + s = dft.RKS(scfmol, xc=cases["lda"]) + s.kernel() + P = np.asarray(s.make_rdm1()) + g = s.Gradients() + scf_grad = g.kernel() + + @jax.jit + def f(pos, rest, basis): + structure = eqx.combine(pos, rest) + basis = eqx.tree_at(lambda x: x.structure, basis, structure) + pcenter = structure.position[basis.primitives.atom_index] + basis = eqx.tree_at(lambda x: x.primitives.center, basis, pcenter) + H = Hamiltonian(basis=basis, xc_method="lda", backend="mess") + + return H(P) + nuclear_energy(structure) + + mol = jax.device_put(mol) + basis = basisset(mol, "def2-SVP") + pos, rest = eqx.partition(mol, lambda x: id(x) == id(mol.position)) + grad_E = jax.grad(f)(pos, rest, basis) + assert_allclose(-grad_E.position, scf_grad, atol=1e-1)