Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions mess/hamiltonian.py

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super interesting! 🚀

Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mess.basis import Basis, renorm
from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis
from mess.interop import to_pyscf
from mess.mesh import Mesh, density, density_and_grad, xcmesh_from_pyscf
from mess.mesh import Mesh, density, density_and_grad, sg1_mesh
from mess.orthnorm import symmetric
from mess.structure import nuclear_energy
from mess.types import FloatNxN, OrthNormTransform
Expand Down Expand Up @@ -119,7 +119,7 @@ class LDA(eqx.Module):

def __init__(self, basis: Basis):
self.basis = basis
self.mesh = xcmesh_from_pyscf(basis.structure)
self.mesh = sg1_mesh(basis.structure)

def __call__(self, P: FloatNxN) -> ScalarLike:
rho = density(self.basis, self.mesh, P)
Expand All @@ -134,7 +134,7 @@ class PBE(eqx.Module):

def __init__(self, basis: Basis):
self.basis = basis
self.mesh = xcmesh_from_pyscf(basis.structure)
self.mesh = sg1_mesh(basis.structure)

def __call__(self, P: FloatNxN) -> ScalarLike:
rho, grad_rho = density_and_grad(self.basis, self.mesh, P)
Expand All @@ -150,7 +150,7 @@ class PBE0(eqx.Module):

def __init__(self, basis: Basis, two_electron: TwoElectron):
self.basis = basis
self.mesh = xcmesh_from_pyscf(basis.structure)
self.mesh = sg1_mesh(basis.structure)
self.hfx = HartreeFockExchange(two_electron)

def __call__(self, P: FloatNxN) -> ScalarLike:
Expand All @@ -167,7 +167,7 @@ class B3LYP(eqx.Module):

def __init__(self, basis: Basis, two_electron: TwoElectron):
self.basis = basis
self.mesh = xcmesh_from_pyscf(basis.structure)
self.mesh = sg1_mesh(basis.structure)
self.hfx = HartreeFockExchange(two_electron)

def __call__(self, P: FloatNxN) -> ScalarLike:
Expand Down
28 changes: 28 additions & 0 deletions test/test_hamiltonian.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
import pytest
import equinox as eqx
import jax
from jax.experimental import enable_x64
from numpy.testing import assert_allclose
from pyscf import dft
Expand Down Expand Up @@ -40,3 +42,29 @@ def test_energy(inputs, basis_name, mol):
actual = H(P) + nuclear_energy(mol)
expect = s.energy_tot()
assert_allclose(actual, expect, atol=1e-6)


def test_autograd_wrt_positions():
mol = molecule("h2")
scfmol = to_pyscf(mol, basis_name="def2-SVP")
s = dft.RKS(scfmol, xc=cases["lda"])
s.kernel()
P = np.asarray(s.make_rdm1())
g = s.Gradients()
scf_grad = g.kernel()

@jax.jit
def f(pos, rest, basis):
structure = eqx.combine(pos, rest)
basis = eqx.tree_at(lambda x: x.structure, basis, structure)
pcenter = structure.position[basis.primitives.atom_index]
basis = eqx.tree_at(lambda x: x.primitives.center, basis, pcenter)
H = Hamiltonian(basis=basis, xc_method="lda", backend="mess")

return H(P) + nuclear_energy(structure)

mol = jax.device_put(mol)
basis = basisset(mol, "def2-SVP")
pos, rest = eqx.partition(mol, lambda x: id(x) == id(mol.position))
grad_E = jax.grad(f)(pos, rest, basis)
assert_allclose(-grad_E.position, scf_grad, atol=1e-1)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have an intuition regarding these accuracies? And what is the current relative error? :D

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current problem is the XC evaluation isn't exactly like-for-like but it should be possible to get them to match much closer. Checkout test/test_autograd_integrals.py which shows that autodiff can match the analytic gradients of the one-electron components. I'm optimistic to have the absolute error at 1e-5 once mess can match the XC mesh generation exactly the same as pyscf's numerical quadrature.

Loading