Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 36 additions & 42 deletions mess/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,49 @@

class Basis(eqx.Module):
orbitals: Tuple[Orbital]
num_orbitals: int = eqx.field(static=True)
structure: Structure
occupancy: FloatN
primitives: Primitive
num_primitives: int = eqx.field(static=True)
coefficients: FloatN
orbital_index: IntN
basis_name: str = eqx.field(static=True)
max_L: int = eqx.field(static=True)
spherical: bool = eqx.field(static=True)

@property
def num_orbitals(self) -> int:
return len(self.orbitals)

@property
def num_primitives(self) -> int:
return sum(ao.num_primitives for ao in self.orbitals)

@property
def occupancy(self) -> FloatN:
# Assumes uncharged systems in restricted Kohn-Sham
occ = jnp.full(self.num_orbitals, 2.0)
mask = occ.cumsum() > self.structure.num_electrons
occ = jnp.where(mask, 0.0, occ)
return occ
def __init__(self, structure: Structure, basis_name: str, spherical: bool = True):
self.structure = structure
self.basis_name = basis_name
self.spherical = spherical

orbitals = []
atom_index = []

for atom_id in range(structure.num_atoms):
element = int(structure.atomic_number[atom_id])
out = _bse_to_orbitals(basis_name, element, spherical)
atom_index.extend([atom_id] * sum(len(ao.primitives) for ao in out))
orbitals += out

primitives, coefficients, orbital_index = batch_orbitals(orbitals)
primitives = eqx.tree_at(
lambda p: p.atom_index, primitives, jnp.array(atom_index)
)
center = structure.position[primitives.atom_index, :]
primitives = eqx.tree_at(lambda p: p.center, primitives, center)

self.orbitals = orbitals
self.num_orbitals = len(orbitals)
self.num_primitives = sum(ao.num_primitives for ao in orbitals)
self.primitives = primitives
self.coefficients = coefficients
self.orbital_index = orbital_index
self.max_L = int(np.max(primitives.lmn))

occupancy = np.zeros(self.num_orbitals)
occupancy[: structure.num_electrons // 2] = 2.0
self.occupancy = occupancy

def to_dataframe(self) -> pd.DataFrame:
def fixer(x):
Expand Down Expand Up @@ -90,9 +110,6 @@ def _repr_html_(self) -> str | None:
df = self.to_dataframe()
return df._repr_html_()

def __hash__(self) -> int:
return hash(self.primitives)


def basisset(
structure: Structure, basis_name: str = "sto-3g", spherical: bool = True
Expand All @@ -110,30 +127,7 @@ def basisset(
Returns:
Basis constructed from inputs
"""
orbitals = []
atom_index = []

for atom_id in range(structure.num_atoms):
element = int(structure.atomic_number[atom_id])
out = _bse_to_orbitals(basis_name, element, spherical)
atom_index.extend([atom_id] * sum(len(ao.primitives) for ao in out))
orbitals += out

primitives, coefficients, orbital_index = batch_orbitals(orbitals)
primitives = eqx.tree_at(lambda p: p.atom_index, primitives, jnp.array(atom_index))
center = structure.position[primitives.atom_index, :]
primitives = eqx.tree_at(lambda p: p.center, primitives, center)

basis = Basis(
orbitals=orbitals,
structure=structure,
primitives=primitives,
coefficients=coefficients,
orbital_index=orbital_index,
basis_name=basis_name,
max_L=int(np.max(primitives.lmn)),
spherical=spherical,
)
basis = Basis(structure, basis_name, spherical)

# TODO(hh): this introduces some performance overhead into basis construction that
# could be pushed down into the cached orbitals.
Expand Down
9 changes: 5 additions & 4 deletions mess/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from more_itertools import batched
from typing import Callable

import equinox as eqx
import jax.numpy as jnp
import numpy as np
from jax import jit, tree, vmap
Expand All @@ -44,7 +45,7 @@
BinaryPrimitiveOp = Callable[[Primitive, Primitive], float]


@partial(jit, static_argnums=(0, 1))
@partial(jit, static_argnums=1)
def integrate_dense(basis: Basis, primitive_op: BinaryPrimitiveOp) -> FloatNxN:
(ii, cl, lhs), (jj, cr, rhs) = basis_iter(basis)
aij = cl * cr * vmap(primitive_op)(lhs, rhs)
Expand All @@ -57,7 +58,7 @@ def integrate_dense(basis: Basis, primitive_op: BinaryPrimitiveOp) -> FloatNxN:
return out


@partial(jit, static_argnums=(0, 1))
@partial(jit, static_argnums=1)
def integrate_sparse(basis: Basis, primitive_op: BinaryPrimitiveOp) -> FloatNxN:
offset = [0] + [o.num_primitives for o in basis.orbitals]
offset = np.cumsum(offset)
Expand Down Expand Up @@ -193,7 +194,7 @@ def g_term(l1, l2, pa, pb, cp):
nuclear_primitives = jit(_nuclear_primitives)


@partial(jit, static_argnums=0)
@jit
def nuclear_basis(basis: Basis):
def n(atomic_number, position):
def op(pi, pj):
Expand Down Expand Up @@ -294,7 +295,7 @@ def gen_ijkl(n: int):
yield idx, jdx, kdx, ldx


@partial(jit, static_argnums=0)
@eqx.filter_jit
def eri_basis_sparse(b: Basis):
indices = []
batch = []
Expand Down
10 changes: 0 additions & 10 deletions mess/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,6 @@ def angular_momentum(self) -> int:
def __call__(self, pos: FloatNx3) -> FloatN:
return eval_primitive(self, pos)

def __hash__(self) -> int:
values = []
for k, v in vars(self).items():
if k.startswith("__") or v is None:
continue

values.append(v.tobytes())

return hash(b"".join(values))


@jit
def normalize(lmn: Int3, alpha: float) -> float:
Expand Down
23 changes: 9 additions & 14 deletions mess/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,18 @@
class Structure(eqx.Module):
atomic_number: IntN
position: FloatNx3
atomic_symbol: List[str] = eqx.field(static=True)
num_atoms: int = eqx.field(static=True)
num_electrons: int = eqx.field(static=True)

def __post_init__(self):
# single atom case
self.atomic_number = np.atleast_1d(self.atomic_number)
self.position = np.atleast_2d(self.position)
def __init__(self, atomic_number: IntN, position: FloatNx3):
self.atomic_number = np.atleast_1d(atomic_number)
self.position = np.atleast_2d(position)

@property
def num_atoms(self) -> int:
return len(self.atomic_number)
self.num_atoms = len(self.atomic_number)
self.num_electrons = int(np.sum(self.atomic_number))

@property
def atomic_symbol(self) -> List[str]:
return [elements[z].symbol for z in self.atomic_number]

@property
def num_electrons(self) -> int:
return np.sum(self.atomic_number)
self.atomic_symbol = [elements[z].symbol for z in self.atomic_number]

def _repr_html_(self):
import py3Dmol
Expand Down
2 changes: 1 addition & 1 deletion mess/zeropad_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mess.types import FloatNxN


@partial(jit, static_argnums=0)
@jit
def overlap_basis_zeropad(basis: Basis) -> FloatNxN:
def op(a, b):
return _overlap_primitives_zeropad(a, b, basis.max_L)
Expand Down