diff --git a/moldesign/_notebooks/Example 6. Align Two Molecules.ipynb b/moldesign/_notebooks/Example 6. Align Two Molecules.ipynb
new file mode 100644
index 0000000..8e840ee
--- /dev/null
+++ b/moldesign/_notebooks/Example 6. Align Two Molecules.ipynb
@@ -0,0 +1,264 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "About Forum Issues Tutorials Documentation\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "
Example 6: Align Two Molecules
\n",
+ "---\n",
+ "\n",
+ "\n",
+ "This notebook reads in two structures and calculates the transformation matrix aligning the C-alpha atoms from a subset of residues one structure to the other.\n",
+ "\n",
+ " - _Author_: [Dave Parker](https://github.com/ktbolt), Autodesk Research
\n",
+ " - _Created on_: August 23, 2017\n",
+ " - _Tags_: align, fit\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 70,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "import moldesign as mdt\n",
+ "from moldesign import units as u\n",
+ "from moldesign.molecules.atomcollections import AtomList\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Read in and display the source molecule."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2395475165f342f5b5d94da1027b4e26"
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "src_mol = mdt.from_pdb(\"1yu8\")\n",
+ "src_mol.draw()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Extract the source molecule C-alpha atoms for residues 64-72."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 89,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "src_residues = src_mol.chains[\"X\"].residues\n",
+ "src_residue_ids = set(range(64,73))\n",
+ "src_atoms = [res.atoms[\"CA\"] for res in src_residues if res.pdbindex in src_residue_ids]\n",
+ "src_atoms_list = AtomList(src_atoms)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Read in and display the destination molecule."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 76,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING: Residue GLU502 (index 271, chain A) is missing expected atoms. Attempting to infer chain end\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4948f6a1dc234e02b35d375f52ffa115"
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "dst_mol = mdt.from_pdb(\"3ac2\")\n",
+ "dst_mol.draw()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Extract the destination molecule C-alpha atoms for residues 446-454."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 77,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "dst_residues = dst_mol.chains[\"A\"].residues\n",
+ "dst_residue_ids = set(range(446,455))\n",
+ "dst_atoms = [res.atoms[\"CA\"] for res in residues if res.pdbindex in residue_ids]\n",
+ "dst_atoms_list = AtomList(dst_atoms)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Calculate the transformation aligning the source atoms to destination atoms."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 88,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "rmsd = 0.412441\n",
+ "transformation = \n",
+ "[[ 1.00000000e+00 4.56737751e-16 2.21965480e-16 -7.10542736e-15]\n",
+ " [ -3.52269708e-16 1.00000000e+00 -6.83377612e-17 8.88178420e-15]\n",
+ " [ -4.55009138e-16 5.86301682e-16 1.00000000e+00 8.88178420e-15]\n",
+ " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "rmsd, xform = src_atoms_list.align(dst_atoms_list)\n",
+ "print(\"rmsd = %f\" % rmsd)\n",
+ "print(\"transformation = \\n\" + np.array_str(xform))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Transform all of the source atoms."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 79,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "src_atoms = src_mol.atoms\n",
+ "coords = np.array([0.0, 0.0, 0.0, 1.0],dtype=float)\n",
+ "\n",
+ "for i in range(0,len(src_atoms)):\n",
+ " coords[0:3] = src_atoms[i].position.value_in(u.angstrom)\n",
+ " xcoords = xform.dot(coords)\n",
+ " src_atoms[i].x = xcoords[0] * u.angstrom\n",
+ " src_atoms[i].y = xcoords[1] * u.angstrom\n",
+ " src_atoms[i].z = xcoords[2] * u.angstrom\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Merge the source and destination molecules into s single molecule and display it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 80,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING: atom indices modified due to name clashes\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "888fa1f3e38b498ea3c11fc37be561f7"
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "merged_mol = mdt.Molecule([src_mol, dst_mol])\n",
+ "merged_mol.draw()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/moldesign/_tests/test_alignments.py b/moldesign/_tests/test_alignments.py
index 8fd072a..cffb6c8 100644
--- a/moldesign/_tests/test_alignments.py
+++ b/moldesign/_tests/test_alignments.py
@@ -6,10 +6,11 @@
import moldesign as mdt
from moldesign import geom
from moldesign.mathutils import normalized
+from moldesign.molecules.atomcollections import AtomList
from moldesign import units as u
from .molecule_fixtures import *
-from .helpers import assert_almost_equal
+from .helpers import assert_almost_equal, get_data_path
__PYTEST_MARK__ = 'internal' # mark all tests in this module with this label (see ./conftest.py)
@@ -146,3 +147,34 @@ def test_pmi_reorientation_on_benzene(benzene):
np.testing.assert_allclose(newdistmat.defunits_value(),
original_distmat)
+# Test the alignmnet of the atoms from two molecules.
+def test_align_molecules():
+
+ # Get the CA atoms for chain X residues 64-72.
+ mol1 = mdt.read(get_data_path('1yu8.pdb'))
+ residues = mol1.chains["X"].residues
+ residue_ids = set(range(64,73))
+ src_atoms = [res.atoms["CA"] for res in residues if res.pdbindex in residue_ids]
+
+ # Get the CA atoms for chain A residues 446-454.
+ mol2 = mdt.read(get_data_path('3ac2.pdb'))
+ residues = mol2.chains["A"].residues
+ residue_ids = set(range(446,455))
+ dst_atoms = [res.atoms["CA"] for res in residues if res.pdbindex in residue_ids]
+
+ # Align the two atom lists.
+ src_atoms_list = AtomList(src_atoms)
+ dst_atoms_list = AtomList(dst_atoms)
+ rmsd_error, xform = src_atoms_list.align(dst_atoms_list)
+
+ # Compare results to expected values.
+ rtol = 1e-5
+ expected_rmsd_error = 0.412441
+ expected_xform = np.array([(4.42719613e-01, 8.84833988e-01, -1.45148745e-01, 1.13863353e+01), \
+ (1.30739446e-02, -1.68229930e-01, -9.85661079e-01, 2.42427092e+01), \
+ (-8.96564786e-01, 4.34473825e-01, -8.60469602e-02, 2.24877465e+01), \
+ (0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00)])
+ np.testing.assert_allclose(rmsd_error, expected_rmsd_error, rtol)
+ np.testing.assert_allclose(xform, expected_xform, rtol)
+
+
diff --git a/moldesign/mathutils/__init__.py b/moldesign/mathutils/__init__.py
index fb97fc3..36d8c68 100644
--- a/moldesign/mathutils/__init__.py
+++ b/moldesign/mathutils/__init__.py
@@ -1,3 +1,4 @@
from .vectormath import *
from .eigen import *
-from .grids import *
\ No newline at end of file
+from .grids import *
+from .align import *
diff --git a/moldesign/mathutils/align.py b/moldesign/mathutils/align.py
new file mode 100644
index 0000000..35253aa
--- /dev/null
+++ b/moldesign/mathutils/align.py
@@ -0,0 +1,147 @@
+from __future__ import print_function, absolute_import, division
+from future.builtins import *
+from future import standard_library
+standard_library.install_aliases()
+
+# Copyright 2017 Autodesk Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from math import sqrt
+import numpy as np
+import moldesign as mdt
+from ..utils import exports
+
+@exports
+def rmsd_align(src_points, dst_points):
+ """ Calculate the 4x4 transformation matrix that aligns a set of source points
+ to a set of destination of points.
+
+ This function implements the method from Kabsch (Acta Cryst. (1978) A34, 827-828) that
+ calculate the optimal rotation matrix that minimizes the root mean squared deviation between
+ two paired sets of points.
+
+ Args:
+ src_points (numpy.ndarray): An Nx3 array of the 3D source points.
+ dst_points (numpy.ndarray): An Nx3 array of the 3D destination points.
+
+ Returns:
+ units.Scalar[length]: The root mean square deviation measuring the alignment error.
+ numpy.ndarray: The 4x4 array that transforms the source points to the destination points.
+ """
+
+ if (src_points.shape[0] == 0) or (dst_points.shape[0] == 0):
+ n = "Source" if src_points.shape[0] == 0 else "Destination"
+ print('WARNING: %s points array for RMSD aligning has zero length.' % n)
+ return 0.0, np.identity(4, dtype=float)
+
+ if src_points.shape[0] != dst_points.shape[0]:
+ raise ValueError(
+ 'The number of points for calculating the RMSD between the first array %d \n'
+ 'does not equal the number of points in the second array %d' % (num_points1, num_points2))
+
+ if mdt.units.get_units(src_points) != mdt.units.get_units(dst_points):
+ raise ValueError("The source points units '%s' don't match the destination points units '%s'" %
+ (mdt.units.get_units(src_points), mdt.units.get_units(dst_points)))
+
+ # Numpy matrix/vector multiplication results in a vector with dimensionless units
+ # even if the vector has units. To retain the vector's units multiple the result by
+ # the points units.
+ if isinstance(src_points, mdt.units.MdtQuantity):
+ units = mdt.units.get_units(src_points)
+ else:
+ units = 1.0
+
+ # Calculate point centers.
+ num_points = src_points.shape[0]
+ src_center = src_points.sum(axis=0) / num_points
+ dst_center = dst_points.sum(axis=0) / num_points
+
+ # Calculate correlation matrix.
+ corr_mat = np.dot(np.transpose(dst_points[:]-dst_center), src_points[:]-src_center)
+
+ # Compute singular value decomposition.
+ u, s, v = np.linalg.svd(corr_mat)
+ det = np.linalg.det(v) * np.linalg.det(u)
+
+ # Make sure the rotation preserves orientation (det = 1).
+ if det < 0.0:
+ u[:, -1] = -u[:, -1]
+
+ # Calculate matrix rotating src to dst.
+ rot_mat = np.dot(u, v)
+
+ # Calculate the 4x4 matrix transforming src to dst.
+ tsrc = np.identity(4, dtype=float)
+ tsrc[0:3,3] = src_center
+ tcenter = np.identity(4, dtype=float)
+ tcenter[0:3,3] = dst_center - src_center
+ rn = np.identity(4, dtype=float)
+ rn[:3,:3] = rot_mat
+ m1 = np.dot(rn, np.linalg.inv(tsrc))
+ m2 = np.dot(tcenter, tsrc)
+ xform = np.dot(m2, m1)
+
+ # Calculate rmsd error.
+ rmsd_error = 0.0
+ for i in range(num_points):
+ cdiff = rot_mat.dot(src_points[i]-src_center)*units - (dst_points[i]-dst_center)
+ rmsd_error += cdiff.dot(cdiff)
+ #__for i in range(num_points)
+ rmsd_error = np.sqrt(rmsd_error / num_points)
+ return rmsd_error, xform
+
+@exports
+def calculate_rmsd(points1, points2, centered=False):
+ """ Calculate the root mean square deviation (RMSD) between two sets of 3D points.
+
+ Args:
+ points1 (numpy.ndarray): An Nx3 array of the 3D points.
+ points2 (numpy.ndarray): An Nx3 array of the 3D points.
+ centered (bool): If true then the points are centered around (0,0,0).
+
+ Returns:
+ units.Scalar[length]: The root mean square deviation.
+ """
+ if (points1.shape[0] == 0) or (points2.shape[0] == 0):
+ n = 1 if points1.shape[0] == 0 else 2
+ print('WARNING: Points%d array for RMSD calculation has zero length.' % n)
+ return 0.0
+
+ if points1.shape[0] != points2.shape[0]:
+ raise ValueError(
+ 'The number of points for calculating the RMSD between the first array %d \n'
+ 'does not equal the number of points in the second array %d' % (num_points1, num_points2))
+
+ if mdt.units.get_units(points1) != mdt.units.get_units(points2):
+ raise ValueError("Points1 units '%s' don't match the points2 units '%s'" %
+ (mdt.units.get_units(points1), mdt.units.get_units(points2)))
+
+ # Calculate point centers.
+ num_points = points1.shape[0]
+ if centered:
+ center1 = np.array([0.0, 0.0, 0.0],dtype=float)
+ center2 = np.array([0.0, 0.0, 0.0],dtype=float)
+ else:
+ center1 = points1.sum(axis=0) / num_points
+ center2 = points2.sum(axis=0) / num_points
+
+ # Calculate RMSD.
+ rmsd = 0.0
+ center = center2 - center1
+ for i in range(num_points):
+ cdiff = points1[i] - points2[i] + center
+ rmsd += cdiff.dot(cdiff)
+
+ return np.sqrt(rmsd / num_points)
+
diff --git a/moldesign/molecules/atomcollections.py b/moldesign/molecules/atomcollections.py
index 1aa8ff1..f5476cc 100644
--- a/moldesign/molecules/atomcollections.py
+++ b/moldesign/molecules/atomcollections.py
@@ -25,6 +25,7 @@
import numpy as np
import moldesign as mdt
+from .. mathutils.align import rmsd_align, calculate_rmsd
from .. import units as u
from .. import utils, external, mathutils, widgets
from . import toplevel
@@ -433,6 +434,53 @@ def bonds_to(self, other):
bonds.append(bond)
return bonds
+ def align(self, other):
+ """ Aligns this list of atoms to another list of atoms.
+
+ Args:
+ other (AtomContainer): The list of atoms to align to.
+
+ Returns:
+ units.Scalar[length]: The root mean square deviation measuring the alignment error.
+ numpy.ndarray: The 4x4 matrix that transforms this list of atoms to the input list of atoms.
+ """
+ if len(self.atoms) != len(other.atoms):
+ raise ValueError(
+ 'The number of atoms for fitting this atom list %d does not equal \n'
+ 'the number of atoms in the input list %d' % (len(self.atoms), len(other.atoms)))
+
+ # Create atom coordinate arrays.
+ num_atoms = len(self.atoms)
+ coord_units = mdt.units.get_units(self.atoms[0].position)
+ self_coords = np.zeros((num_atoms, 3), dtype=float) * coord_units
+ other_coords = np.zeros((num_atoms, 3), dtype=float) * coord_units
+ for i in range(0,num_atoms):
+ self_coords[i] = self.atoms[i].position
+ other_coords[i] = other.atoms[i].position
+
+ # Calculate the 4x4 matrix transforming self_coords into other_coords.
+ rmsd_error, xform = rmsd_align(self_coords, other_coords)
+ return rmsd_error, xform
+
+ def rmsd(self, other):
+ """ Calculates the root mean square deviation (RMSD) between this list of atoms
+ and another list of atoms.
+
+ Args:
+ other (AtomContainer): The list of atoms to calculate the RMSD with.
+
+ Returns:
+ u.Scalar[length]: The root mean square deviation measuring the alignment error.
+ """
+ if len(self.atoms) != len(other.atoms):
+ raise ValueError('The number of atoms for calculating the RMSD between this atom list %d does not equal the number of atoms in the input list %d' % (len(self.atoms), len(other.atoms)))
+ num_atoms = len(self.atoms)
+ self_coords = np.zeros((num_atoms, 3), dtype=float)
+ other_coords = np.zeros((num_atoms, 3), dtype=float)
+ for i in range(0,num_atoms):
+ self_coords[i] = self.atoms[i].position.value_in(u.angstrom)
+ other_coords[i] = other.atoms[i].position.value_in(u.angstrom)
+ return rmsd(self_coords, other_coords) * u.default.length
@toplevel
class AtomList(AtomContainer, list): # order is important, list will override methods otherwise