Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions src/python/friendzone/utils/unwrap_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,43 @@
# limitations under the License.

from simde import TotalEnergy, EnergyNuclearGradientStdVectorD
import numpy as np


def _compare_mol_and_point(mol, points):
"""This function is essentially a work around for the comparisons not being
exposed to Python.
def _compare_mol_and_point(mol, points, atol=1e-12, rtol=0.0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is NOT defaulting to the original behavior. The original behavior was exact equality.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure then what was the first comment about having a conditional, the first time I shared the error, that Is why I change the function and add a conditional for equality if needed.

Don't know what to do from here

"""
Compare the 3D nuclear coordinates of a Molecule and a PointSet.

This function is intended to ensure that gradient calculations
are being performed at the correct molecular geometry.

Parameters
----------
mol : chemist.Molecule
The molecule whose nuclear coordinates are being compared.

points : chemist.PointSet
The point set to compare against (usually created from the molecule's nuclei).

atol : float, optional
Absolute tolerance used by np.isclose. Default is 1e-12.
This catches numerical noise due to Python/C++ floating-point boundary.

rtol : float, optional
Relative tolerance used by np.isclose. Default is 0.0 (disabled),
which ensures no scaling with magnitude — appropriate for comparing coordinates.

Returns
-------
bool
True if all coordinate components match within the given tolerances.
False otherwise.

Notes
-----
Exact floating-point equality is not used because of small differences
(e.g., ~1e-314) introduced by Python/C++ interoperability. These are
not chemically meaningful and should be tolerated with a small `atol`.
"""
if mol.size() != points.size():
return False
Expand All @@ -27,13 +59,22 @@ def _compare_mol_and_point(mol, points):
point_i = points.at(i)

for j in range(3):
if atom_i.coord(j) != point_i.coord(j):
return False
a = atom_i.coord(j)
b = point_i.coord(j)

if rtol == 0.0 and atol == 0.0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check can run into problems with floating-point representation of hard-zero. Please use None as the default values for atol and rtol and switch this line to check for None and not 0.0.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as before, not sure if I need to leave it the way it was

# Strict comparison: values must be bitwise identical
if a != b:
return False
else:
# Use np.isclose to allow for floating-point tolerance
if not np.isclose(a, b, atol=atol, rtol=rtol):
return False

return True


def unwrap_inputs(pt, inputs):
def unwrap_inputs(pt, inputs, atol=1e-12, rtol=0.0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't be changing this method. The call to _compare_mol_and_point should always use exact equality. If the user provides even slightly different molecule objects, they are NOT requesting the derivative at the same point.

If this is where your error was originating from (as opposed to you trying to reuse _compare_mol_and_point, which is what I thought the problem was), then I suspect that the problem is that you're not passing the PointSet object from your Molecule object into your module.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as before not sure if I need to leave it the way it was

""" Code factorization for unwrapping a module's inputs.

Many of our friends expose interfaces which are analogous to high-level
Expand Down