From 49f811de42649961d0e6757a220178cf818590ef Mon Sep 17 00:00:00 2001 From: ljwoods2 Date: Thu, 12 Sep 2024 21:26:08 -0700 Subject: [PATCH] rotation correlation function --- imdclient/vdos.py | 189 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) diff --git a/imdclient/vdos.py b/imdclient/vdos.py index adc4f30..5f7511f 100644 --- a/imdclient/vdos.py +++ b/imdclient/vdos.py @@ -228,3 +228,192 @@ def _conclude(self): tmp1[: self.nCorr] = self.results.trCorr[:] tmp1[self.nCorr :] = tmp1[1 : self.nCorr][::-1] self.results.trVDoS = rfft(tmp1, axis=0) + + +class roVDoS(StreamFriendlyAnalysisBase): + + def __init__( + self, trajectory, selection, nCorr=100, verbose=False, **kwargs + ): + super().__init__(trajectory, verbose=False, **kwargs) + self.sel = selection + self.nCorr = nCorr + self.nRes = selection.residues.n_residues + + def _prepare(self): + + self.AngMomentumBuffer = np.zeros( + (self.nCorr, self.nRes, 3), dtype=np.float64 + ) + self.COMposBuffer = np.zeros( + (self.nCorr, self.nRes, 3), dtype=np.float64 + ) + self.COMvelBuffer = np.zeros( + (self.nCorr, self.nRes, 3), dtype=np.float64 + ) + self.MOIBuffer = np.zeros((self.nCorr, self.nRes, 3)) + + # time and frequency axes for correlation functions and VDoS + self.results["tau"] = np.zeros(self.nCorr, dtype=np.float64) + self.results["wavenumber"] = np.zeros(self.nCorr, dtype=np.float64) + + self.results["roCorr"] = np.zeros( + (self.nCorr, self.nRes), dtype=np.float64 + ) + self.results["roVDoS"] = np.zeros( + (self.nCorr, self.nRes), dtype=np.float64 + ) + + self.corrCnt = 0 + + # If residues can have different num atoms, use list + self.atMassLists = [] + for res in self.sel.residues: + self.atMassLists.append( + res.atoms.masses[:, np.newaxis].astype("float64") + ) + + self.prev_evecs = None + + def _single_frame(self): + idx = self._ts.frame % self.nCorr + + if self._ts.frame < self.nCorr: + self.results.tau[idx] = self._ts.time + + atMassLists = self.atMassLists + sel = self.sel + residue_masses = sel.residues.masses + COMvelBuffer = self.COMvelBuffer + COMposBuffer = self.COMposBuffer + AngMomentumBuffer = self.AngMomentumBuffer + MOIBuffer = self.MOIBuffer + prev_evecs = self.prev_evecs + + calc_inertia_tensor = self._calc_inertia_tensor + + for i in range(self.nRes): + atoms = sel.residues[i].atoms + atom_masses_col = atMassLists[i] + atom_masses_row = atoms.masses + residue_mass = residue_masses[i] + + ## Compute COM pos + pos = atoms.positions.astype("float64") + com_pos = np.sum(atom_masses_col * pos, axis=0) / residue_mass + COMposBuffer[idx, i] = com_pos + + # valid_com = sel.residues[i].atoms.center_of_mass() + # if not np.allclose(com_pos, valid_com): + # print(f"Center of mass is not valid {com_pos} != {valid_com}") + + ## Compute COM pos + vel = atoms.velocities.astype("float64") + com_vel = np.sum(atom_masses_col * vel, axis=0) / residue_mass + COMvelBuffer[idx, i] = com_vel + + ## Calcular angular momentum + ang = np.sum( + atom_masses_col * np.cross((pos - com_pos), (vel - com_vel)), + axis=0, + ) + + ## Compute MOI, Angular Momentum along principal axis + inertia_tensor = calc_inertia_tensor(com_pos, pos, atom_masses_row) + # valid_inertia_tensor = sel.residues[i].atoms.moment_of_inertia() + # if not np.allclose(inertia_tensor, valid_inertia_tensor): + # print( + # f"Inertia tensor is not valid {inertia_tensor} != {valid_inertia_tensor}" + # ) + + values, evecs = np.linalg.eigh(inertia_tensor) + + ## Ensure that the eigenvectors don't flip + if self.prev_evecs is not None: + for j in range(3): + if np.dot(evecs[:, j], prev_evecs[:, j]) < 0: + evecs[:, j] *= -1 + prev_evecs = evecs + + indices = np.argsort(values) + rot_axis = evecs[:, indices] + + # valid_principal_axis = sel.residues[i].atoms.principal_axes().T + # if not np.allclose(rot_axis, valid_principal_axis): + # print( + # f"Principal axis is not valid {rot_axis} != {valid_principal_axis}" + # ) + + MOIBuffer[idx, i] = values + + AngMomentumBuffer[idx, i] = np.dot(ang, rot_axis) + # np.sum(ang * rot_axis, axis=0) + + # if sufficient data is available in buffers, compute correlation functions + if self._ts.frame >= self.nCorr - 1: + self._calcCorr(self._ts.frame + 1) + + def _calc_inertia_tensor(self, com, pos, masses): + pos_com_frame = pos - com + tens = np.zeros((3, 3), dtype=np.float64) + # xx + tens[0][0] = ( + masses * (pos_com_frame[:, 1] ** 2 + pos_com_frame[:, 2] ** 2) + ).sum() + # xy & yx + tens[0][1] = tens[1][0] = -( + masses * pos_com_frame[:, 0] * pos_com_frame[:, 1] + ).sum() + # xz & zx + tens[0][2] = tens[2][0] = -( + masses * pos_com_frame[:, 0] * pos_com_frame[:, 2] + ).sum() + # yy + tens[1][1] = ( + masses * (pos_com_frame[:, 0] ** 2 + pos_com_frame[:, 2] ** 2) + ).sum() + # yz + zy + tens[1][2] = tens[2][1] = -( + masses * pos_com_frame[:, 1] * pos_com_frame[:, 2] + ).sum() + # zz + tens[2][2] = ( + masses * (pos_com_frame[:, 0] ** 2 + pos_com_frame[:, 1] ** 2) + ).sum() + return tens + + def _calcCorr(self, start): + """compute correlation functions for all data in buffers""" + # compute time correlation function for COM translation (for each residue) + # can be parallelized + roCorr = self.results.roCorr + MOIBuffer = self.MOIBuffer + nCorr = self.nCorr + AngMomentumBuffer = self.AngMomentumBuffer + + for i in range(nCorr): + j = start % nCorr + k = (j + i) % nCorr + roCorr[i] += np.sum( + (AngMomentumBuffer[j] / np.sqrt(MOIBuffer[j])) + * (AngMomentumBuffer[k] / np.sqrt(MOIBuffer[k])), + axis=1, + ) + self.corrCnt += 1 + + def _conclude(self): + """normalize correlation functions by number of data points + and calculate compute vibrational density of states from time correlation functions + """ + # Normalization + self.results.roCorr[:] /= self.corrCnt + ## Calculate VDoS + period = (self.results.tau[1] - self.results.tau[0]) * ( + 2 * self.nCorr - 1 + ) + wn0 = (1.0 / period) * 33.35641 + self.results.wavenumber = np.arange(0, self.nCorr) * wn0 + tmp1 = np.zeros((2 * self.nCorr - 1, self.nRes), dtype=np.float64) + tmp1[: self.nCorr] = self.results.roCorr[:] + tmp1[self.nCorr :] = tmp1[1 : self.nCorr][::-1] + self.results.roVDoS = rfft(tmp1, axis=0)