Skip to content

Commit

Permalink
rotation correlation function
Browse files Browse the repository at this point in the history
  • Loading branch information
ljwoods2 committed Sep 13, 2024
1 parent 81c1eaa commit 49f811d
Showing 1 changed file with 189 additions and 0 deletions.
189 changes: 189 additions & 0 deletions imdclient/vdos.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,192 @@ def _conclude(self):
tmp1[: self.nCorr] = self.results.trCorr[:]
tmp1[self.nCorr :] = tmp1[1 : self.nCorr][::-1]
self.results.trVDoS = rfft(tmp1, axis=0)


class roVDoS(StreamFriendlyAnalysisBase):

def __init__(
self, trajectory, selection, nCorr=100, verbose=False, **kwargs
):
super().__init__(trajectory, verbose=False, **kwargs)
self.sel = selection
self.nCorr = nCorr
self.nRes = selection.residues.n_residues

def _prepare(self):

self.AngMomentumBuffer = np.zeros(
(self.nCorr, self.nRes, 3), dtype=np.float64
)
self.COMposBuffer = np.zeros(
(self.nCorr, self.nRes, 3), dtype=np.float64
)
self.COMvelBuffer = np.zeros(
(self.nCorr, self.nRes, 3), dtype=np.float64
)
self.MOIBuffer = np.zeros((self.nCorr, self.nRes, 3))

# time and frequency axes for correlation functions and VDoS
self.results["tau"] = np.zeros(self.nCorr, dtype=np.float64)
self.results["wavenumber"] = np.zeros(self.nCorr, dtype=np.float64)

self.results["roCorr"] = np.zeros(
(self.nCorr, self.nRes), dtype=np.float64
)
self.results["roVDoS"] = np.zeros(
(self.nCorr, self.nRes), dtype=np.float64
)

self.corrCnt = 0

# If residues can have different num atoms, use list
self.atMassLists = []
for res in self.sel.residues:
self.atMassLists.append(
res.atoms.masses[:, np.newaxis].astype("float64")
)

self.prev_evecs = None

def _single_frame(self):
idx = self._ts.frame % self.nCorr

if self._ts.frame < self.nCorr:
self.results.tau[idx] = self._ts.time

atMassLists = self.atMassLists
sel = self.sel
residue_masses = sel.residues.masses
COMvelBuffer = self.COMvelBuffer
COMposBuffer = self.COMposBuffer
AngMomentumBuffer = self.AngMomentumBuffer
MOIBuffer = self.MOIBuffer
prev_evecs = self.prev_evecs

calc_inertia_tensor = self._calc_inertia_tensor

for i in range(self.nRes):
atoms = sel.residues[i].atoms
atom_masses_col = atMassLists[i]
atom_masses_row = atoms.masses
residue_mass = residue_masses[i]

## Compute COM pos
pos = atoms.positions.astype("float64")
com_pos = np.sum(atom_masses_col * pos, axis=0) / residue_mass
COMposBuffer[idx, i] = com_pos

# valid_com = sel.residues[i].atoms.center_of_mass()
# if not np.allclose(com_pos, valid_com):
# print(f"Center of mass is not valid {com_pos} != {valid_com}")

## Compute COM pos
vel = atoms.velocities.astype("float64")
com_vel = np.sum(atom_masses_col * vel, axis=0) / residue_mass
COMvelBuffer[idx, i] = com_vel

## Calcular angular momentum
ang = np.sum(
atom_masses_col * np.cross((pos - com_pos), (vel - com_vel)),
axis=0,
)

## Compute MOI, Angular Momentum along principal axis
inertia_tensor = calc_inertia_tensor(com_pos, pos, atom_masses_row)
# valid_inertia_tensor = sel.residues[i].atoms.moment_of_inertia()
# if not np.allclose(inertia_tensor, valid_inertia_tensor):
# print(
# f"Inertia tensor is not valid {inertia_tensor} != {valid_inertia_tensor}"
# )

values, evecs = np.linalg.eigh(inertia_tensor)

## Ensure that the eigenvectors don't flip
if self.prev_evecs is not None:
for j in range(3):
if np.dot(evecs[:, j], prev_evecs[:, j]) < 0:
evecs[:, j] *= -1
prev_evecs = evecs

indices = np.argsort(values)
rot_axis = evecs[:, indices]

# valid_principal_axis = sel.residues[i].atoms.principal_axes().T
# if not np.allclose(rot_axis, valid_principal_axis):
# print(
# f"Principal axis is not valid {rot_axis} != {valid_principal_axis}"
# )

MOIBuffer[idx, i] = values

AngMomentumBuffer[idx, i] = np.dot(ang, rot_axis)
# np.sum(ang * rot_axis, axis=0)

# if sufficient data is available in buffers, compute correlation functions
if self._ts.frame >= self.nCorr - 1:
self._calcCorr(self._ts.frame + 1)

def _calc_inertia_tensor(self, com, pos, masses):
pos_com_frame = pos - com
tens = np.zeros((3, 3), dtype=np.float64)
# xx
tens[0][0] = (
masses * (pos_com_frame[:, 1] ** 2 + pos_com_frame[:, 2] ** 2)
).sum()
# xy & yx
tens[0][1] = tens[1][0] = -(
masses * pos_com_frame[:, 0] * pos_com_frame[:, 1]
).sum()
# xz & zx
tens[0][2] = tens[2][0] = -(
masses * pos_com_frame[:, 0] * pos_com_frame[:, 2]
).sum()
# yy
tens[1][1] = (
masses * (pos_com_frame[:, 0] ** 2 + pos_com_frame[:, 2] ** 2)
).sum()
# yz + zy
tens[1][2] = tens[2][1] = -(
masses * pos_com_frame[:, 1] * pos_com_frame[:, 2]
).sum()
# zz
tens[2][2] = (
masses * (pos_com_frame[:, 0] ** 2 + pos_com_frame[:, 1] ** 2)
).sum()
return tens

def _calcCorr(self, start):
"""compute correlation functions for all data in buffers"""
# compute time correlation function for COM translation (for each residue)
# can be parallelized
roCorr = self.results.roCorr
MOIBuffer = self.MOIBuffer
nCorr = self.nCorr
AngMomentumBuffer = self.AngMomentumBuffer

for i in range(nCorr):
j = start % nCorr
k = (j + i) % nCorr
roCorr[i] += np.sum(
(AngMomentumBuffer[j] / np.sqrt(MOIBuffer[j]))
* (AngMomentumBuffer[k] / np.sqrt(MOIBuffer[k])),
axis=1,
)
self.corrCnt += 1

def _conclude(self):
"""normalize correlation functions by number of data points
and calculate compute vibrational density of states from time correlation functions
"""
# Normalization
self.results.roCorr[:] /= self.corrCnt
## Calculate VDoS
period = (self.results.tau[1] - self.results.tau[0]) * (
2 * self.nCorr - 1
)
wn0 = (1.0 / period) * 33.35641
self.results.wavenumber = np.arange(0, self.nCorr) * wn0
tmp1 = np.zeros((2 * self.nCorr - 1, self.nRes), dtype=np.float64)
tmp1[: self.nCorr] = self.results.roCorr[:]
tmp1[self.nCorr :] = tmp1[1 : self.nCorr][::-1]
self.results.roVDoS = rfft(tmp1, axis=0)

0 comments on commit 49f811d

Please sign in to comment.