diff --git a/mess/basis.py b/mess/basis.py index a8bd119..7e140af 100644 --- a/mess/basis.py +++ b/mess/basis.py @@ -26,29 +26,49 @@ class Basis(eqx.Module): orbitals: Tuple[Orbital] + num_orbitals: int = eqx.field(static=True) structure: Structure + occupancy: FloatN primitives: Primitive + num_primitives: int = eqx.field(static=True) coefficients: FloatN orbital_index: IntN basis_name: str = eqx.field(static=True) max_L: int = eqx.field(static=True) spherical: bool = eqx.field(static=True) - @property - def num_orbitals(self) -> int: - return len(self.orbitals) - - @property - def num_primitives(self) -> int: - return sum(ao.num_primitives for ao in self.orbitals) - - @property - def occupancy(self) -> FloatN: - # Assumes uncharged systems in restricted Kohn-Sham - occ = jnp.full(self.num_orbitals, 2.0) - mask = occ.cumsum() > self.structure.num_electrons - occ = jnp.where(mask, 0.0, occ) - return occ + def __init__(self, structure: Structure, basis_name: str, spherical: bool = True): + self.structure = structure + self.basis_name = basis_name + self.spherical = spherical + + orbitals = [] + atom_index = [] + + for atom_id in range(structure.num_atoms): + element = int(structure.atomic_number[atom_id]) + out = _bse_to_orbitals(basis_name, element, spherical) + atom_index.extend([atom_id] * sum(len(ao.primitives) for ao in out)) + orbitals += out + + primitives, coefficients, orbital_index = batch_orbitals(orbitals) + primitives = eqx.tree_at( + lambda p: p.atom_index, primitives, jnp.array(atom_index) + ) + center = structure.position[primitives.atom_index, :] + primitives = eqx.tree_at(lambda p: p.center, primitives, center) + + self.orbitals = orbitals + self.num_orbitals = len(orbitals) + self.num_primitives = sum(ao.num_primitives for ao in orbitals) + self.primitives = primitives + self.coefficients = coefficients + self.orbital_index = orbital_index + self.max_L = int(np.max(primitives.lmn)) + + occupancy = np.zeros(self.num_orbitals) + occupancy[: structure.num_electrons // 2] = 2.0 + self.occupancy = occupancy def to_dataframe(self) -> pd.DataFrame: def fixer(x): @@ -90,9 +110,6 @@ def _repr_html_(self) -> str | None: df = self.to_dataframe() return df._repr_html_() - def __hash__(self) -> int: - return hash(self.primitives) - def basisset( structure: Structure, basis_name: str = "sto-3g", spherical: bool = True @@ -110,30 +127,7 @@ def basisset( Returns: Basis constructed from inputs """ - orbitals = [] - atom_index = [] - - for atom_id in range(structure.num_atoms): - element = int(structure.atomic_number[atom_id]) - out = _bse_to_orbitals(basis_name, element, spherical) - atom_index.extend([atom_id] * sum(len(ao.primitives) for ao in out)) - orbitals += out - - primitives, coefficients, orbital_index = batch_orbitals(orbitals) - primitives = eqx.tree_at(lambda p: p.atom_index, primitives, jnp.array(atom_index)) - center = structure.position[primitives.atom_index, :] - primitives = eqx.tree_at(lambda p: p.center, primitives, center) - - basis = Basis( - orbitals=orbitals, - structure=structure, - primitives=primitives, - coefficients=coefficients, - orbital_index=orbital_index, - basis_name=basis_name, - max_L=int(np.max(primitives.lmn)), - spherical=spherical, - ) + basis = Basis(structure, basis_name, spherical) # TODO(hh): this introduces some performance overhead into basis construction that # could be pushed down into the cached orbitals. diff --git a/mess/integrals.py b/mess/integrals.py index ef05a8d..95a5634 100644 --- a/mess/integrals.py +++ b/mess/integrals.py @@ -23,6 +23,7 @@ from more_itertools import batched from typing import Callable +import equinox as eqx import jax.numpy as jnp import numpy as np from jax import jit, tree, vmap @@ -44,7 +45,7 @@ BinaryPrimitiveOp = Callable[[Primitive, Primitive], float] -@partial(jit, static_argnums=(0, 1)) +@partial(jit, static_argnums=1) def integrate_dense(basis: Basis, primitive_op: BinaryPrimitiveOp) -> FloatNxN: (ii, cl, lhs), (jj, cr, rhs) = basis_iter(basis) aij = cl * cr * vmap(primitive_op)(lhs, rhs) @@ -57,7 +58,7 @@ def integrate_dense(basis: Basis, primitive_op: BinaryPrimitiveOp) -> FloatNxN: return out -@partial(jit, static_argnums=(0, 1)) +@partial(jit, static_argnums=1) def integrate_sparse(basis: Basis, primitive_op: BinaryPrimitiveOp) -> FloatNxN: offset = [0] + [o.num_primitives for o in basis.orbitals] offset = np.cumsum(offset) @@ -193,7 +194,7 @@ def g_term(l1, l2, pa, pb, cp): nuclear_primitives = jit(_nuclear_primitives) -@partial(jit, static_argnums=0) +@jit def nuclear_basis(basis: Basis): def n(atomic_number, position): def op(pi, pj): @@ -294,7 +295,7 @@ def gen_ijkl(n: int): yield idx, jdx, kdx, ldx -@partial(jit, static_argnums=0) +@eqx.filter_jit def eri_basis_sparse(b: Basis): indices = [] batch = [] diff --git a/mess/primitive.py b/mess/primitive.py index 3fc2cd8..aba1054 100644 --- a/mess/primitive.py +++ b/mess/primitive.py @@ -43,16 +43,6 @@ def angular_momentum(self) -> int: def __call__(self, pos: FloatNx3) -> FloatN: return eval_primitive(self, pos) - def __hash__(self) -> int: - values = [] - for k, v in vars(self).items(): - if k.startswith("__") or v is None: - continue - - values.append(v.tobytes()) - - return hash(b"".join(values)) - @jit def normalize(lmn: Int3, alpha: float) -> float: diff --git a/mess/structure.py b/mess/structure.py index d9f0bcb..68151c7 100644 --- a/mess/structure.py +++ b/mess/structure.py @@ -15,23 +15,18 @@ class Structure(eqx.Module): atomic_number: IntN position: FloatNx3 + atomic_symbol: List[str] = eqx.field(static=True) + num_atoms: int = eqx.field(static=True) + num_electrons: int = eqx.field(static=True) - def __post_init__(self): - # single atom case - self.atomic_number = np.atleast_1d(self.atomic_number) - self.position = np.atleast_2d(self.position) + def __init__(self, atomic_number: IntN, position: FloatNx3): + self.atomic_number = np.atleast_1d(atomic_number) + self.position = np.atleast_2d(position) - @property - def num_atoms(self) -> int: - return len(self.atomic_number) + self.num_atoms = len(self.atomic_number) + self.num_electrons = int(np.sum(self.atomic_number)) - @property - def atomic_symbol(self) -> List[str]: - return [elements[z].symbol for z in self.atomic_number] - - @property - def num_electrons(self) -> int: - return np.sum(self.atomic_number) + self.atomic_symbol = [elements[z].symbol for z in self.atomic_number] def _repr_html_(self): import py3Dmol diff --git a/mess/zeropad_integrals.py b/mess/zeropad_integrals.py index 1121c60..a6184ea 100644 --- a/mess/zeropad_integrals.py +++ b/mess/zeropad_integrals.py @@ -11,7 +11,7 @@ from mess.types import FloatNxN -@partial(jit, static_argnums=0) +@jit def overlap_basis_zeropad(basis: Basis) -> FloatNxN: def op(a, b): return _overlap_primitives_zeropad(a, b, basis.max_L)