Skip to content

Commit

Permalink
modified scripts, removed on-the-fly sorting from gibbs
Browse files Browse the repository at this point in the history
  • Loading branch information
rsexton2 committed Jan 23, 2024
1 parent eefa6ee commit 6ae2581
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 43 deletions.
205 changes: 163 additions & 42 deletions basicrta/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import gc
from scipy.optimize import linear_sum_assignment as lsa
import bz2
from scipy import stats
gc.enable()
mpl.rcParams['pdf.fonttype'] = 42
rng = default_rng()
Expand All @@ -31,6 +32,39 @@
'make_surv', 'norm_exp', 'get_dec'
]


def resort(r):
    """Relabel the mixture components stored in copies of ``r.mcweights``/``r.mcrates``.

    Works on copies of the two chains, so ``r`` itself is not modified here.
    First every iteration's components are ordered by descending weight. Then,
    for ten rounds, a reference responsibility matrix ``Z`` is averaged over
    iterations 2000-2999 and each of those iterations is re-permuted with the
    linear-sum-assignment solver (``lsa``) on a cost matrix built from
    ``z*log(z/Z)`` sums.

    NOTE(review): nothing is returned, so the re-sorted copies are discarded
    when the function exits -- confirm whether a return value was intended.
    NOTE(review): ``indicator``, ``indicator_bak``, ``L``, ``x`` and ``ncomp``
    are not defined in this function; they would have to exist as module-level
    globals for this to run -- TODO confirm.
    """
    mcweights, mcrates = r.mcweights.copy(), r.mcrates.copy()
    # Restores the contents of `indicator` from `indicator_bak` (both are free names here).
    indicator[:] = indicator_bak
    # NOTE(review): `Ls` is never read again inside this function.
    Ls, niter = [L], 0
    # Initial pass: order each iteration's components by descending weight and
    # apply the same permutation to the rates.
    for j in tqdm(range(r.niter)):
        sorts = mcweights[j].argsort()[::-1]
        mcweights[j] = mcweights[j][sorts]
        mcrates[j] = mcrates[j][sorts]

    # Ten rounds of "recompute reference, then reassign labels".
    while niter<10:
        # NOTE(review): `z` is assigned further down, which makes it a local
        # name, so this read on the first round happens before any assignment
        # (UnboundLocalError) -- confirm the intended shape source for Z.
        Z = np.zeros_like(z)
        for j in tqdm(range(2000, 3000), desc='recomputing Q'):
            # weight*rate*exp(-rate*x) for every component, then each row is
            # normalised to sum to one (assumes x is 1-D -- TODO confirm).
            tmp = mcweights[j]*mcrates[j]*np.exp(np.outer(-mcrates[j],x)).T
            z = (tmp.T/tmp.sum(axis=1)).T
            Z += z
        # Average over the 1000 iterations in range(2000, 3000).
        Z = Z/1000

        for j in tqdm(range(2000, 3000), desc='resorting'):
            tmp = mcweights[j]*mcrates[j]*np.exp(np.outer(-mcrates[j],x)).T
            z = (tmp.T/tmp.sum(axis=1)).T

            # Cost matrix: row k compares column k of z against every column of Z.
            tmpsum = np.ones((ncomp,ncomp), dtype=np.float64)
            for k in range(ncomp):
                tmpsum[k] = np.sum(z[:,k]*np.log(z[:,k]/Z.T), axis=1)

            # NaN is the only value that differs from itself; give those
            # entries a very large cost instead.
            tmpsum[tmpsum!=tmpsum] = 1e20
            # Column indices of the minimum-cost assignment become the new order.
            sorts = lsa(tmpsum)[1]
            mcweights[j] = mcweights[j][sorts]
            mcrates[j] = mcrates[j][sorts]
        niter += 1


def tm(Prot, i):
    """Return ``[Prot['tm<i>'], width]`` for the entry keyed ``'tm<i>'``.

    ``width`` is the second element of that entry minus its first element.
    """
    segment = Prot[f'tm{i}']
    width = segment[1] - segment[0]
    return [segment, width]
Expand Down Expand Up @@ -81,8 +115,8 @@ def run(self):
#mcweights = np.memmap(f'{residue}/.mcweights.npy', shape=(self.niter + 1, ncomp), mode='w+')
#mcrates = np.memmap(f'{residue}/.mcrates.npy', shape=(self.niter + 1, ncomp), mode='w+')
#Ns = np.memmap(f'{residue}/.Ns.npy', shape=(self.niter, ncomp), mode='w+')
Indicator = np.memmap(f'{residue}/.indicator.npy', shape=(self.niter, x.shape[0]),
mode='w+', dtype=np.uint8)
#Indicator = np.memmap(f'{residue}/.indicator_{self.niter}.npy', shape=(self.niter, x.shape[0]),
# mode='w+', dtype=np.uint8)
mcweights = np.zeros((self.niter + 1, ncomp))
mcrates = np.zeros((self.niter + 1, ncomp))
Ns = np.zeros((self.niter, ncomp))
Expand All @@ -91,6 +125,14 @@ def run(self):
tmpw = 9*10**(-np.arange(1, ncomp+1, dtype=float))
mcweights[0], mcrates[0] = tmpw/tmpw.sum(), inrates[::-1]
whypers, rhypers = np.ones(ncomp)/[ncomp], np.ones((ncomp, 2))*[1, 3] # guess hyperparameters
#tmp = np.logspace(0,-2, ncomp)
#whypers = tmp/tmp.sum()
#tmpw = np.sort(np.outer(np.array([6, 3]), 10**np.arange(-ncomp//2, 0, dtype=float)).flatten())[::-1]
#whypers = (9*10**np.arange(-ncomp, 0, dtype=float))[::-1]
#whypers = tmpw/tmpw.sum()
#whypers = (np.arange(1,ncomp+1)[::-1]/np.arange(1,ncomp+1).sum())
#mult = 3
#whypers = (np.linspace(1,mult*(ncomp+1), ncomp)[::-1]/(np.linspace(1,mult*(ncomp+1), ncomp).sum()))
weights, rates = [], []
g, burnin = 0, 0

Expand All @@ -99,37 +141,111 @@ def run(self):
values = [mcweights, mcrates, ncomp, self.niter, _s, t, residue, Ns,
lnp, int(g), int(burnin)]
for j in tqdm(range(self.niter), desc=f'{residue}-K{ncomp}', position=self.loc, leave=False):
#if j<3000:
# whypers, rhypers = np.ones(ncomp)/[10], np.ones((ncomp, 2))*[1, 10] # guess hyperparameters
#elif (j>=3000)&(j<6000):
# whypers, rhypers = np.ones(ncomp)/[100], np.ones((ncomp, 2))*[1, 10] # guess hyperparameters
#else:
# whypers, rhypers = np.ones(ncomp)/[1000], np.ones((ncomp, 2))*[1, 10] # guess hyperparameters


#tmp = mcrates[j]*np.exp(np.outer(-mcrates[j],x)).T
tmp = mcweights[j]*mcrates[j]*np.exp(np.outer(-mcrates[j],x)).T
z = (tmp.T/tmp.sum(axis=1)).T

c = z.cumsum(axis=1)
uu = np.random.rand(len(c), 1)
s = np.array((uu < c).argmax(axis=1))
Indicator[j] = s
#Indicator[j] = s
np.put_along_axis(indicator, s[:,None], np.take_along_axis(indicator, s[:,None], axis=1)+1, axis=1)

uniqs = np.unique(s)
inds = [np.where(s==i)[0] for i in range(ncomp)]

# Compute log posterior
#lnp[j] = np.log(tmp.take(s)).sum()+np.log(z.take(s)).sum()+(Ns[j]*np.log(mcweights[j])).sum()+sum([sum(-mcrates[j,i]*x[inds[i]]*np.log(x[inds[i]])) for i in range(ncomp)])
lnp[j] = np.log(tmp.take(s)).sum()+np.log(z.take(s)).sum()+np.log(mcweights[j][uniqs]).sum()+np.log(mcrates[j][uniqs]).sum()-mcrates[j][uniqs].sum()
lnp[j] = np.log(tmp.take(s)).sum()+\
np.log(mcweights[j][uniqs]).sum()-\
(mcrates[j][uniqs]*rhypers[uniqs, 1]).sum()+\
np.log(mcweights[j]**(whypers-1)).sum()

Ns[j][:] = np.array([len(s[s==i]) for i in range(ncomp)])
Ns[j][:] = np.array([len(inds[i]) for i in range(ncomp)])
Ts = np.array([x[inds[i]].sum() for i in range(ncomp)])

#ms = np.zeros(ncomp)
#ss = np.ones(ncomp)
##midTs = np.zeros(ncomp)
#for aval in uniqs:
# #midTs[aval] = np.array([np.median(x[inds[aval]])*Nj/len(x) for Nj in Ns[j]])
# ms[aval] = np.array([np.mean(x[inds[aval]])*Nj/len(x) for Nj in Ns[j]]).sum()
# ss[aval] = np.array([np.std(x[inds[aval]])*Nj/len(x) for Nj in Ns[j] if Nj>1]).sum()
#ss[ss==0] = 1

# Sample posteriors
mcweights[j+1] = rng.dirichlet(whypers+Ns[j])
mcrates[j+1] = rng.gamma(rhypers[:,0]+Ns[j], 1/(rhypers[:,1]+Ts))

## Compute cost matrix for occupied states (initial fix (eqn. ))
#tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64)
#for ii,val in enumerate(uniqs):
# for jj,val2 in enumerate(uniqs):
# tmpsum[ii,jj] =
# mcrates[j+1][val]*Ts[val2]-Ns[j][val2]*np.log(mcrates[j+1][val])

# # Compute cost matrix for occupied states (initial fix (eqn. ))
#if j>500:
# tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64)
# for ii,val in enumerate(uniqs):
# for jj,val2 in enumerate(uniqs):
# #tmpsum[ii, jj] = Ts[val2]/midTs[val]+\
# # Ns[j][val2]*np.log(midTs[val])
# tmpsum[ii, jj] = Ns[j][val]*(((x[inds[val]]-ms[val2])/ss[val2])**2).sum()
#

# sorts = lsa(tmpsum)[1]
# emptys = np.array([i for i in np.arange(ncomp) if i not in uniqs])
# if len(emptys)>0:
# sortinds = np.concatenate([uniqs[sorts], emptys])
# else:
# sortinds = uniqs[sorts]

# mcweights[j+1] = mcweights[j+1][sortinds]
# mcrates[j+1] = mcrates[j+1][sortinds]
# Ns[j], Ts = Ns[j][sortinds], Ts[sortinds]
# #midTs = midTs[sortinds]

#if j>500:
# uniqs = np.arange(len(uniqs))
# tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64)
# for ii,val in enumerate(uniqs):
# for jj,val2 in enumerate(uniqs):
# #tmpsum[ii,jj] = mcrates[j+1][val]*Ts[val2]-Ns[j][val2]*\
# # np.log(mcweights[j+1][val]*mcrates[j+1][val])
# #tmpsum[ii, jj] = Ts[val2]/midTs[val]+\
# # Ns[j][val2]*np.log(len(x)*midTs[val]/Ns[j][val])
# #tmpsum[ii, jj] = abs(Ts[val2]*(1/midTs[val]-mcrates[j+1][val])+\
# # Ns[j][val2]*np.log(len(x)*midTs[val]/(Ns[j][val]*\
# # mcweights[j+1][val]*mcrates[j+1][val])))
# #tmpsum[ii, jj] = np.exp(1-mcweights[j+1][val2]*Ns[j][val]/(len(x)))
# #tmpsum[ii, jj] = abs(mcweights[j+1][val2]-Ns[j][val]/len(x))
# #abs(mcrates[j+1][val2]-1/Ts[val])
# #np.exp(abs(mcrates[j+1][val2]-1/midTs[val]))
# #tmpsum[ii, jj] = abs(mcweights[j+1][val2]-mcweights[j][val])*\
# # abs(mcrates[j+1][val2]-mcrates[j][val])
# # #np.exp(abs(mcrates[j+1][val2]-1/midTs[val]))
# #tmpsum[ii, jj] = \
# #Ns[j][val]*(((x[inds[val]]-ms[val2])/ss[val2])**2).sum()
# #Ns[j][val](((x[inds[val]]-ms[val2])/ss[val2])**2)
# #*np.exp(abs(1/midTs[val]-mcrates[j+1][val2]))
# #tmpsum[ii, jj] = Ts[val2]/midTs[val]+Ns[j][val2]*np.log(midTs[val])
# #tmpsum[ii, jj] = abs(pinvals[val2] - mcweights[j+1][val])
# #tmpsum[ii, jj] = abs(pinvals[val2] - mcrates[j+1][val])
# tmpsum[ii, jj] = Ts[val]/midTs[val2]+Ns[j][val]*np.log(midTs[val2])

# sorts = lsa(tmpsum)[1]
# emptys = np.array([i for i in np.arange(ncomp) if i not in uniqs])
# if len(emptys)>0:
# sortinds = np.concatenate([uniqs[sorts], emptys])
# else:
# sortinds = uniqs[sorts]

# mcweights[j+1] = mcweights[j+1][sortinds]
# mcrates[j+1] = mcrates[j+1][sortinds]
# Ns[j], Ts = Ns[j][sortinds], Ts[sortinds]
# midTs = midTs[sortinds]
#for i in range(ncomp):
# Indicator[j][inds[i]] = sortinds[i]
# test cost matrix
#tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64)
#for ii,val in enumerate(uniqs):
Expand All @@ -144,42 +260,47 @@ def run(self):
# print(sort, uniqs)

## Relabel states
mcweights[j+1][uniqs] = mcweights[j+1][:len(uniqs)]
mcrates[j+1][uniqs] = mcrates[j+1][:len(uniqs)]
# mcweights[j+1][uniqs] = mcweights[j+1][:len(uniqs)]
# mcrates[j+1][uniqs] = mcrates[j+1][:len(uniqs)]

#mcweights[j+1][uniqs] = mcweights[j+1][:len(uniqs)]
#mcrates[j+1][uniqs] = mcrates[j+1][:len(uniqs)]


if self.sort:
if j==600:
pinvals = np.median(mcweights[200:600], axis=0)
if j>600:
#avgs = ((j-100)*avgs+mcweights[j+1])/(j+1-100)
#mids = np.median(mcweights[:j+1], axis=0)

## test cost matrix
tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64)
for ii,val in enumerate(uniqs):
for jj,val2 in enumerate(uniqs):
tmpsum[ii,jj] = abs(mcweights[j+1][val]-pinvals[val2])

# Hungarian algorithm for minimum cost
sortinds = lsa(tmpsum)[1]

# Relabel states
mcweights[j+1][uniqs], mcrates[j+1][uniqs] = mcweights[j+1][sortinds], mcrates[j+1][sortinds]

gc.collect()
# ####################CHECK IF Ns Ts NEED TO BE SORTED AS WELL!!!!!
# if self.sort:
# if j==600:
# pinvals = np.median(mcweights[200:600], axis=0)
# if j>600:
# #avgs = ((j-100)*avgs+mcweights[j+1])/(j+1-100)
# #mids = np.median(mcweights[:j+1], axis=0)
#
# ## test cost matrix
# tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64)
# for ii,val in enumerate(uniqs):
# for jj,val2 in enumerate(uniqs):
# tmpsum[ii,jj] = abs(mcweights[j+1][val]-pinvals[val2])
#
# # Hungarian algorithm for minimum cost
# sortinds = lsa(tmpsum)[1]
#
# # Relabel states
# mcweights[j+1][uniqs], mcrates[j+1][uniqs] = mcweights[j+1][sortinds], mcrates[j+1][sortinds]
#
# gc.collect()
# # ####################CHECK IF Ns Ts NEED TO BE SORTED AS WELL!!!!!


naninds = np.where(lnp!=lnp)[0]
lnp, Ns = np.delete(lnp, naninds), np.delete(Ns, naninds)
mcrates = np.delete(mcrates, naninds, axis=0)
mcweights = np.delete(mcweights, naninds, axis=0)
# naninds = np.where(lnp!=lnp)[0]
# lnp, Ns = np.delete(lnp, naninds), np.delete(Ns, naninds)
# mcrates = np.delete(mcrates, naninds, axis=0)
# mcweights = np.delete(mcweights, naninds, axis=0)

burnin, g, nsample = pmts.detect_equilibration(lnp, fast=False)
if self.niter>50000:
Fast=True
else:
Fast=False

burnin, g, nsample = pmts.detect_equilibration(lnp, fast=Fast)
g = np.ceil(g)

plt.close('all')
Expand Down
3 changes: 2 additions & 1 deletion scripts/gibbs_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
parser.add_argument('--resids', nargs='?')
parser.add_argument('--niter', nargs='?', default=10000)
parser.add_argument('--sort', nargs='?', default=True)
parser.add_argument('--ncomp', nargs='?', default=10, type=int)
args = parser.parse_args()
a = np.load(args.contacts)

ts, ncomp = 0.1, 10
ts, ncomp = 0.1, args.ncomp
cutoff = float(args.contacts.split('.npy')[0].split('_')[-1])
nproc, prot = 1, args.protname
if args.niter:
Expand Down
36 changes: 36 additions & 0 deletions scripts/gibbs_synth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from multiprocessing import shared_memory
from basicrta import *
from basicrta.functions import simulate_hn
from multiprocessing import Pool, Lock
from basicrta import istarmap
import numpy as np
import MDAnalysis as mda
import os
from tqdm import tqdm
import gc

if __name__ == "__main__":
    # Parts of code taken from Shep (Centrifuge3.py, SuperMCMC.py)
    #
    # Draws synthetic data with ``simulate_hn`` and hands it to ``run_residue``
    # through a one-worker pool, working inside a directory named after the
    # synthetic residue label.

    import argparse

    DEFAULT_NITER, DEFAULT_N, DEFAULT_NCOMP = 10000, 100000, 10

    parser = argparse.ArgumentParser()
    # `const` mirrors `default` so that a bare flag (e.g. `--niter` with no
    # value) falls back to the default instead of yielding None; `type=int`
    # makes command-line values arrive as integers rather than strings.
    parser.add_argument('--niter', nargs='?', type=int,
                        default=DEFAULT_NITER, const=DEFAULT_NITER)
    parser.add_argument('-N', nargs='?', type=int,
                        default=DEFAULT_N, const=DEFAULT_N)
    parser.add_argument('--ncomp', nargs='?', type=int,
                        default=DEFAULT_NCOMP, const=DEFAULT_NCOMP)
    args = parser.parse_args()

    niter, ncomp, N = args.niter, args.ncomp, args.N
    residue, ts, nproc = 'X1', 0.1, 1

    # NOTE(review): the two lists are presumably the component weights and
    # rates of the synthetic mixture -- confirm against simulate_hn.
    times = simulate_hn(N, [0.901, 0.09, 0.009], [5, 0.1, 0.001])

    # exist_ok avoids the check-then-create race of exists()+mkdir().
    os.makedirs(residue, exist_ok=True)
    os.chdir(residue)

    # istarmap unpacks each *element* of the iterable as one call's arguments,
    # so the iterable must hold one argument tuple per task. A flat
    # [residue, times, ...] sequence would instead be treated as five tasks.
    # NOTE(review): argument order is kept from the original list -- confirm
    # it matches run_residue's signature.
    input_list = [(residue, times, ts, ncomp, niter)]
    with Pool(nproc, initializer=tqdm.set_lock, initargs=(Lock(),)) as p:
        for _ in tqdm(p.istarmap(run_residue, input_list),
                      total=len(input_list), position=0,
                      desc='overall progress'):
            pass

0 comments on commit 6ae2581

Please sign in to comment.