Skip to content

feat: Add test for solid harmonic moment translation (#159) #276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 193 additions & 1 deletion src/grid/tests/test_molgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_almost_equal
import scipy.special
from grid.utils import solid_harmonics, convert_cart_to_sph

from grid.atomgrid import AtomGrid, _get_rgrid_size
from grid.basegrid import LocalGrid
Expand All @@ -35,6 +37,73 @@
# Ignore angular/Lebedev grid warnings where the weights are negative:
pytestmark = pytest.mark.filterwarnings("ignore:Lebedev weights are negative which can*")

# Helper functions for index mapping and binomial coefficient calculation
def horton_index(l, m):
"""Convert (l, m) to flat Horton order index."""
if not isinstance(l, int) or not isinstance(m, int):
raise TypeError("l and m must be integers")
if abs(m) > l:
raise ValueError("abs(m) must be <= l")
if m == 0:
return l * l
elif m > 0:
return l * l + 2 * m - 1
else: # m < 0
return l * l + 2 * abs(m)

def get_lm_from_horton_index(k):
"""Convert flat Horton order index k back to (l, m)."""
if not isinstance(k, (int, np.integer)) or k < 0:
raise ValueError("k must be a non-negative integer")
l = int(np.floor(np.sqrt(k)))
# Check if k corresponds to a valid index for this l
if k >= (l + 1)**2:
# This happens if k is not a perfect square and floor(sqrt(k)) was rounded down incorrectly
# e.g. k=8, l=floor(sqrt(8))=2, (l+1)**2 = 9. k=8 is valid.
# e.g. k=9, l=floor(sqrt(9))=3, (l+1)**2 = 16. k=9 is valid (l=3, m=0)
# Need a robust way to get l
l = int(np.ceil(np.sqrt(k+1)) - 1)

m_abs_or_zero = (k - l*l + 1) // 2 # Corresponds to abs(m) for m != 0, or m=0
if k == l*l:
m = 0
elif (k - l*l) % 2 == 1: # Positive m
m = m_abs_or_zero
else: # Negative m
m = -m_abs_or_zero

# Final check
if abs(m) > l:
# If k was invalid, recalculate l based on the fact that k < (l_true+1)^2
corrected_l = 0
while (corrected_l + 1)**2 <= k:
corrected_l += 1
l = corrected_l
# Recalculate m based on the corrected l
if k == l*l:
m = 0
elif (k - l*l) % 2 == 1:
m = (k - l*l + 1) // 2
else:
m = -(k - l*l) // 2

return l, m
def binomial_safe(n, k):
"""Safely compute binomial coefficient C(n, k), returning 0 if k < 0 or k > n."""
if k < 0 or k > n:
return 0
# Use float conversion for safety with potentially large numbers before sqrt
# and handle potential precision issues with exact=True for large ints
try:
# Attempt exact calculation first for smaller numbers if possible
if n < 1000: # Heuristic threshold
return float(scipy.special.comb(n, k, exact=True))
else:
# Use floating point for potentially large numbers
return scipy.special.comb(n, k, exact=False)
except OverflowError:
# Fallback to floating point if exact calculation overflows
return scipy.special.comb(n, k, exact=False)

class TestMolGrid(TestCase):
"""MolGrid test class."""
Expand Down Expand Up @@ -810,9 +879,132 @@ def test_integrate_hirshfeld_weights_pair_1s(self):
occupation = mg.integrate(fn)
assert_almost_equal(occupation, 2.5, decimal=5)

def test_multipole_translation(self):
"""Test that pure solid harmonic moments translate correctly."""
l_max = 3 # Maximum multipole order to test
alpha = 2.0 # Gaussian exponent
center_orig = np.array([0.0, 0.0, 0.0])
shift_vec = np.array([0.1, -0.2, 0.3]) # A small shift vector
center_shifted = center_orig + shift_vec

# --- Grid Setup ---
# Use a reasonable grid settings for accuracy
# Use GaussLaguerre directly as it produces points on (0, inf)
rgrid = GaussLaguerre(75) # More points for radial grid
# tf = ExpRTransform(alpha=1.0, r0=1e-7) # Use ExpRTransform which maps (0, inf)
# rgrid = tf.transform_1d_grid(pts) <-- Remove transformation step
numbers = np.array([1]) # Single Hydrogen atom (simplest case)
# Place atom at origin for simplicity now, grid should extend enough
atom_coord = np.array([0.0, 0.0, 0.0])
# Use a fine preset for better angular/pruning accuracy
grid_preset = "fine" # "fine" or "veryfine" recommended
atg1 = AtomGrid.from_preset(
atnum=numbers[0], preset=grid_preset, center=atom_coord, rgrid=rgrid
)
becke = BeckeWeights(order=3) # Becke weights
mg = MolGrid(numbers, [atg1], becke)
# ------------------

# --- Define test function (Gaussian centered at origin) ---
def gaussian_density(points, center, exp):
dist_sq = np.sum((points - center)**2, axis=1)
norm = (exp / np.pi)**(1.5) # Normalization for gaussian exp(-exp*r^2)
return norm * np.exp(-exp * dist_sq)

# Evaluate the function centered at the global origin
func_vals = gaussian_density(mg.points, center=np.array([0.0, 0.0, 0.0]), exp=alpha)
# ---------------------------------------------------------

# Calculate moments directly at original and shifted centers
moments_orig = mg.moments(func_vals=func_vals, orders=l_max, centers=center_orig.reshape(1,3), type_mom="pure")
moments_shifted_direct = mg.moments(func_vals=func_vals, orders=l_max, centers=center_shifted.reshape(1,3), type_mom="pure")

# Squeeze the output from (L, 1) to (L,) to match analytical calculation shape
moments_orig = moments_orig.squeeze()
moments_shifted_direct = moments_shifted_direct.squeeze()

# Calculate solid harmonics of the shift vector
shift_vec_cart = shift_vec.reshape(1, 3)
# Need r, theta, phi for solid_harmonics
r_shift = np.linalg.norm(shift_vec)
if r_shift == 0:
# If shift is zero, R_lm is non-zero only for l=m=0, handle separately or ensure shift > 0
R_lm_a = np.zeros(((l_max + 1)**2,))
R_lm_a[0] = 1.0 # R_00 = 1/sqrt(4pi), but solid_harmonics includes the sqrt(4pi/(2l+1)) factor
# Let's re-check the solid_harmonics definition R_lm = sqrt(4pi/(2l+1)) r^l Y_lm
# For l=0, m=0: R_00 = sqrt(4pi/1) * r^0 * Y_00 = sqrt(4pi) * 1 * (1/sqrt(4pi)) = 1.0
else:
shift_vec_sph = convert_cart_to_sph(shift_vec_cart) # Returns (1, 3) array [r, theta, phi]
R_lm_a = solid_harmonics(l_max, shift_vec_sph).flatten() # Get R_lm(a) in flat Horton order

# Calculate analytically shifted moments
num_moments = (l_max + 1)**2
moments_shifted_analytical = np.zeros_like(moments_orig, dtype=np.float64)

for k_lm in range(num_moments):
l, m = get_lm_from_horton_index(k_lm)
term_sum = 0.0
for k_lambda_mu in range(num_moments):
lam, mu = get_lm_from_horton_index(k_lambda_mu)
if lam > l:
continue

l_minus_lambda = l - lam
m_minus_mu = m - mu

if abs(m_minus_mu) > l_minus_lambda:
continue

# Binomial term: sqrt(C(l+m, lam+mu) * C(l-m, lam-mu))
bc1 = binomial_safe(l + m, lam + mu)
bc2 = binomial_safe(l - m, lam - mu)
binom_prod = bc1 * bc2

if binom_prod < 0: # Should not happen with correct binomial_safe
binom_term = 0.0
else:
binom_term = np.sqrt(binom_prod)

if np.isnan(binom_term): # Handle NaN just in case
binom_term = 0.0

# Get R_{l-lambda, m-mu}(a)
k_R = horton_index(l_minus_lambda, m_minus_mu)
# Ensure index is valid before accessing
if k_R >= len(R_lm_a):
R_term = 0.0 # Index out of bounds implies this term is zero
else:
R_term = R_lm_a[k_R]

# Get M_{lambda, mu}(0)
M_term = moments_orig[k_lambda_mu]

# Accumulate sum: (-1)^(l-lambda) * binom_term * R_term * M_term
term_sum += ((-1)**(l - lam)) * binom_term * R_term * M_term

moments_shifted_analytical[k_lm] = term_sum

# Compare results with appropriate tolerance
# Tolerance might need adjustment based on grid quality and l_max
tolerance_kwargs = {'rtol': 1e-5, 'atol': 1e-7}
if l_max > 2: # Potentially lower tolerance for higher moments
tolerance_kwargs = {'rtol': 1e-4, 'atol': 1e-6}

# Provide helpful output on failure
# print(f"Lmax={l_max}, Shift={shift_vec}")
# print(f"Moments Original (k=0..{num_moments-1}): {moments_orig}")
# print(f"Moments Shifted Direct: {moments_shifted_direct}")
# print(f"Moments Shifted Analytical: {moments_shifted_analytical}")
# diff = moments_shifted_direct - moments_shifted_analytical
# print(f"Difference (Direct - Analytical): {diff}")
# print(f"Max Abs Diff: {np.max(np.abs(diff))}")
# print(f"Indices of large diff: {np.where(np.abs(diff) > tolerance_kwargs['atol'] + tolerance_kwargs['rtol'] * np.abs(moments_shifted_analytical))}")

assert_allclose(moments_shifted_direct, moments_shifted_analytical, **tolerance_kwargs)


def test_interpolation_with_gaussian_center():
r"""Test interpolation with molecular grid of sum of two Gaussian examples."""
"""Test if get_multipole works with a specific center."""
coordinates = np.array([[0.0, 0.0, -1.5], [0.0, 0.0, 1.5]])

pts = Trapezoidal(400)
Expand Down