diff --git a/src/python/friendzone/utils/unwrap_inputs.py b/src/python/friendzone/utils/unwrap_inputs.py index f79278e..71d3a25 100644 --- a/src/python/friendzone/utils/unwrap_inputs.py +++ b/src/python/friendzone/utils/unwrap_inputs.py @@ -13,11 +13,43 @@ # limitations under the License. from simde import TotalEnergy, EnergyNuclearGradientStdVectorD +import numpy as np -def _compare_mol_and_point(mol, points): - """This function is essentially a work around for the comparisons not being - exposed to Python. +def _compare_mol_and_point(mol, points, atol=1e-12, rtol=0.0): + """ + Compare the 3D nuclear coordinates of a Molecule and a PointSet. + + This function is intended to ensure that gradient calculations + are being performed at the correct molecular geometry. + + Parameters + ---------- + mol : chemist.Molecule + The molecule whose nuclear coordinates are being compared. + + points : chemist.PointSet + The point set to compare against (usually created from the molecule's nuclei). + + atol : float, optional + Absolute tolerance used by np.isclose. Default is 1e-12. + This catches numerical noise due to Python/C++ floating-point boundary. + + rtol : float, optional + Relative tolerance used by np.isclose. Default is 0.0 (disabled), + which ensures no scaling with magnitude — appropriate for comparing coordinates. + + Returns + ------- + bool + True if all coordinate components match within the given tolerances. + False otherwise. + + Notes + ----- + Exact floating-point equality is not used because of small differences + (e.g., ~1e-314) introduced by Python/C++ interoperability. These are + not chemically meaningful and should be tolerated with a small `atol`. """ if mol.size() != points.size(): return False @@ -27,13 +59,22 @@ def _compare_mol_and_point(mol, points): point_i = points.at(i) for j in range(3): - if atom_i.coord(j) != point_i.coord(j): - return False + a = atom_i.coord(j) + b = point_i.coord(j) + + if rtol == 0.0 and atol == 0.0: + # Strict comparison: values must be bitwise identical + if a != b: + return False + else: + # Use np.isclose to allow for floating-point tolerance + if not np.isclose(a, b, atol=atol, rtol=rtol): + return False return True -def unwrap_inputs(pt, inputs): +def unwrap_inputs(pt, inputs, atol=1e-12, rtol=0.0): """ Code factorization for unwrapping a module's inputs. Many of our friends expose interfaces which are analogous to high-level