From 6ae25810e7b8acf6c0e9ae77ea7783ba3f992f4b Mon Sep 17 00:00:00 2001 From: Rick Sexton Date: Mon, 22 Jan 2024 19:35:40 -0700 Subject: [PATCH] modified scripts, removed on-the-fly sorting from gibbs --- basicrta/functions.py | 205 ++++++++++++++++++++++++++++++++-------- scripts/gibbs_serial.py | 3 +- scripts/gibbs_synth.py | 36 +++++++ 3 files changed, 201 insertions(+), 43 deletions(-) create mode 100644 scripts/gibbs_synth.py diff --git a/basicrta/functions.py b/basicrta/functions.py index d7a95af..9ade366 100644 --- a/basicrta/functions.py +++ b/basicrta/functions.py @@ -19,6 +19,7 @@ import gc from scipy.optimize import linear_sum_assignment as lsa import bz2 +from scipy import stats gc.enable() mpl.rcParams['pdf.fonttype'] = 42 rng = default_rng() @@ -31,6 +32,39 @@ 'make_surv', 'norm_exp', 'get_dec' ] + +def resort(r): + mcweights, mcrates = r.mcweights.copy(), r.mcrates.copy() + indicator[:] = indicator_bak + Ls, niter = [L], 0 + for j in tqdm(range(r.niter)): + sorts = mcweights[j].argsort()[::-1] + mcweights[j] = mcweights[j][sorts] + mcrates[j] = mcrates[j][sorts] + + while niter<10: + Z = np.zeros_like(z) + for j in tqdm(range(2000, 3000), desc='recomputing Q'): + tmp = mcweights[j]*mcrates[j]*np.exp(np.outer(-mcrates[j],x)).T + z = (tmp.T/tmp.sum(axis=1)).T + Z += z + Z = Z/1000 + + for j in tqdm(range(2000, 3000), desc='resorting'): + tmp = mcweights[j]*mcrates[j]*np.exp(np.outer(-mcrates[j],x)).T + z = (tmp.T/tmp.sum(axis=1)).T + + tmpsum = np.ones((ncomp,ncomp), dtype=np.float64) + for k in range(ncomp): + tmpsum[k] = np.sum(z[:,k]*np.log(z[:,k]/Z.T), axis=1) + + tmpsum[tmpsum!=tmpsum] = 1e20 + sorts = lsa(tmpsum)[1] + mcweights[j] = mcweights[j][sorts] + mcrates[j] = mcrates[j][sorts] + niter += 1 + + def tm(Prot,i): dif = Prot['tm{0}'.format(i)][1]-Prot['tm{0}'.format(i)][0] return [Prot['tm{0}'.format(i)],dif] @@ -81,8 +115,8 @@ def run(self): #mcweights = np.memmap(f'{residue}/.mcweights.npy', shape=(self.niter + 1, ncomp), mode='w+') #mcrates = np.memmap(f'{residue}/.mcrates.npy', shape=(self.niter + 1, ncomp), mode='w+') #Ns = np.memmap(f'{residue}/.Ns.npy', shape=(self.niter, ncomp), mode='w+') - Indicator = np.memmap(f'{residue}/.indicator.npy', shape=(self.niter, x.shape[0]), - mode='w+', dtype=np.uint8) + #Indicator = np.memmap(f'{residue}/.indicator_{self.niter}.npy', shape=(self.niter, x.shape[0]), + # mode='w+', dtype=np.uint8) mcweights = np.zeros((self.niter + 1, ncomp)) mcrates = np.zeros((self.niter + 1, ncomp)) Ns = np.zeros((self.niter, ncomp)) @@ -91,6 +125,14 @@ def run(self): tmpw = 9*10**(-np.arange(1, ncomp+1, dtype=float)) mcweights[0], mcrates[0] = tmpw/tmpw.sum(), inrates[::-1] whypers, rhypers = np.ones(ncomp)/[ncomp], np.ones((ncomp, 2))*[1, 3] # guess hyperparameters + #tmp = np.logspace(0,-2, ncomp) + #whypers = tmp/tmp.sum() + #tmpw = np.sort(np.outer(np.array([6, 3]), 10**np.arange(-ncomp//2, 0, dtype=float)).flatten())[::-1] + #whypers = (9*10**np.arange(-ncomp, 0, dtype=float))[::-1] + #whypers = tmpw/tmpw.sum() + #whypers = (np.arange(1,ncomp+1)[::-1]/np.arange(1,ncomp+1).sum()) + #mult = 3 + #whypers = (np.linspace(1,mult*(ncomp+1), ncomp)[::-1]/(np.linspace(1,mult*(ncomp+1), ncomp).sum())) weights, rates = [], [] g, burnin = 0, 0 @@ -99,6 +141,14 @@ def run(self): values = [mcweights, mcrates, ncomp, self.niter, _s, t, residue, Ns, lnp, int(g), int(burnin)] for j in tqdm(range(self.niter), desc=f'{residue}-K{ncomp}', position=self.loc, leave=False): + #if j<3000: + # whypers, rhypers = np.ones(ncomp)/[10], np.ones((ncomp, 2))*[1, 10] # guess hyperparameters + #elif (j>=3000)&(j<6000): + # whypers, rhypers = np.ones(ncomp)/[100], np.ones((ncomp, 2))*[1, 10] # guess hyperparameters + #else: + # whypers, rhypers = np.ones(ncomp)/[1000], np.ones((ncomp, 2))*[1, 10] # guess hyperparameters + + #tmp = mcrates[j]*np.exp(np.outer(-mcrates[j],x)).T tmp = mcweights[j]*mcrates[j]*np.exp(np.outer(-mcrates[j],x)).T z = (tmp.T/tmp.sum(axis=1)).T @@ -106,30 +156,96 @@ def run(self): c = z.cumsum(axis=1) uu = np.random.rand(len(c), 1) s = np.array((uu < c).argmax(axis=1)) - Indicator[j] = s + #Indicator[j] = s np.put_along_axis(indicator, s[:,None], np.take_along_axis(indicator, s[:,None], axis=1)+1, axis=1) uniqs = np.unique(s) inds = [np.where(s==i)[0] for i in range(ncomp)] # Compute log posterior - #lnp[j] = np.log(tmp.take(s)).sum()+np.log(z.take(s)).sum()+(Ns[j]*np.log(mcweights[j])).sum()+sum([sum(-mcrates[j,i]*x[inds[i]]*np.log(x[inds[i]])) for i in range(ncomp)]) - lnp[j] = np.log(tmp.take(s)).sum()+np.log(z.take(s)).sum()+np.log(mcweights[j][uniqs]).sum()+np.log(mcrates[j][uniqs]).sum()-mcrates[j][uniqs].sum() + lnp[j] = np.log(tmp.take(s)).sum()+\ + np.log(mcweights[j][uniqs]).sum()-\ + (mcrates[j][uniqs]*rhypers[uniqs, 1]).sum()+\ + np.log(mcweights[j]**(whypers-1)).sum() - Ns[j][:] = np.array([len(s[s==i]) for i in range(ncomp)]) + Ns[j][:] = np.array([len(inds[i]) for i in range(ncomp)]) Ts = np.array([x[inds[i]].sum() for i in range(ncomp)]) - + #ms = np.zeros(ncomp) + #ss = np.ones(ncomp) + ##midTs = np.zeros(ncomp) + #for aval in uniqs: + # #midTs[aval] = np.array([np.median(x[inds[aval]])*Nj/len(x) for Nj in Ns[j]]) + # ms[aval] = np.array([np.mean(x[inds[aval]])*Nj/len(x) for Nj in Ns[j]]).sum() + # ss[aval] = np.array([np.std(x[inds[aval]])*Nj/len(x) for Nj in Ns[j] if Nj>1]).sum() + #ss[ss==0] = 1 + # Sample posteriors mcweights[j+1] = rng.dirichlet(whypers+Ns[j]) mcrates[j+1] = rng.gamma(rhypers[:,0]+Ns[j], 1/(rhypers[:,1]+Ts)) - ## Compute cost matrix for occupied states (initial fix (eqn. )) - #tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64) - #for ii,val in enumerate(uniqs): - # for jj,val2 in enumerate(uniqs): - # tmpsum[ii,jj] = - # mcrates[j+1][val]*Ts[val2]-Ns[j][val2]*np.log(mcrates[j+1][val]) - +# # Compute cost matrix for occupied states (initial fix (eqn. )) + #if j>500: + # tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64) + # for ii,val in enumerate(uniqs): + # for jj,val2 in enumerate(uniqs): + # #tmpsum[ii, jj] = Ts[val2]/midTs[val]+\ + # # Ns[j][val2]*np.log(midTs[val]) + # tmpsum[ii, jj] = Ns[j][val]*(((x[inds[val]]-ms[val2])/ss[val2])**2).sum() + # + + # sorts = lsa(tmpsum)[1] + # emptys = np.array([i for i in np.arange(ncomp) if i not in uniqs]) + # if len(emptys)>0: + # sortinds = np.concatenate([uniqs[sorts], emptys]) + # else: + # sortinds = uniqs[sorts] + + # mcweights[j+1] = mcweights[j+1][sortinds] + # mcrates[j+1] = mcrates[j+1][sortinds] + # Ns[j], Ts = Ns[j][sortinds], Ts[sortinds] + # #midTs = midTs[sortinds] + + #if j>500: + # uniqs = np.arange(len(uniqs)) + # tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64) + # for ii,val in enumerate(uniqs): + # for jj,val2 in enumerate(uniqs): + # #tmpsum[ii,jj] = mcrates[j+1][val]*Ts[val2]-Ns[j][val2]*\ + # # np.log(mcweights[j+1][val]*mcrates[j+1][val]) + # #tmpsum[ii, jj] = Ts[val2]/midTs[val]+\ + # # Ns[j][val2]*np.log(len(x)*midTs[val]/Ns[j][val]) + # #tmpsum[ii, jj] = abs(Ts[val2]*(1/midTs[val]-mcrates[j+1][val])+\ + # # Ns[j][val2]*np.log(len(x)*midTs[val]/(Ns[j][val]*\ + # # mcweights[j+1][val]*mcrates[j+1][val]))) + # #tmpsum[ii, jj] = np.exp(1-mcweights[j+1][val2]*Ns[j][val]/(len(x))) + # #tmpsum[ii, jj] = abs(mcweights[j+1][val2]-Ns[j][val]/len(x)) + # #abs(mcrates[j+1][val2]-1/Ts[val]) + # #np.exp(abs(mcrates[j+1][val2]-1/midTs[val])) + # #tmpsum[ii, jj] = abs(mcweights[j+1][val2]-mcweights[j][val])*\ + # # abs(mcrates[j+1][val2]-mcrates[j][val]) + # # #np.exp(abs(mcrates[j+1][val2]-1/midTs[val])) + # #tmpsum[ii, jj] = \ + # #Ns[j][val]*(((x[inds[val]]-ms[val2])/ss[val2])**2).sum() + # #Ns[j][val](((x[inds[val]]-ms[val2])/ss[val2])**2) + # #*np.exp(abs(1/midTs[val]-mcrates[j+1][val2])) + # #tmpsum[ii, jj] = Ts[val2]/midTs[val]+Ns[j][val2]*np.log(midTs[val]) + # #tmpsum[ii, jj] = abs(pinvals[val2] - mcweights[j+1][val]) + # #tmpsum[ii, jj] = abs(pinvals[val2] - mcrates[j+1][val]) + # tmpsum[ii, jj] = Ts[val]/midTs[val2]+Ns[j][val]*np.log(midTs[val2]) + + # sorts = lsa(tmpsum)[1] + # emptys = np.array([i for i in np.arange(ncomp) if i not in uniqs]) + # if len(emptys)>0: + # sortinds = np.concatenate([uniqs[sorts], emptys]) + # else: + # sortinds = uniqs[sorts] + + # mcweights[j+1] = mcweights[j+1][sortinds] + # mcrates[j+1] = mcrates[j+1][sortinds] + # Ns[j], Ts = Ns[j][sortinds], Ts[sortinds] + # midTs = midTs[sortinds] + #for i in range(ncomp): + # Indicator[j][inds[i]] = sortinds[i] # test cost matrix #tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64) #for ii,val in enumerate(uniqs): @@ -144,42 +260,47 @@ def run(self): # print(sort, uniqs) ## Relabel states - mcweights[j+1][uniqs] = mcweights[j+1][:len(uniqs)] - mcrates[j+1][uniqs] = mcrates[j+1][:len(uniqs)] +# mcweights[j+1][uniqs] = mcweights[j+1][:len(uniqs)] +# mcrates[j+1][uniqs] = mcrates[j+1][:len(uniqs)] #mcweights[j+1][uniqs] = mcweights[j+1][:len(uniqs)] #mcrates[j+1][uniqs] = mcrates[j+1][:len(uniqs)] - if self.sort: - if j==600: - pinvals = np.median(mcweights[200:600], axis=0) - if j>600: - #avgs = ((j-100)*avgs+mcweights[j+1])/(j+1-100) - #mids = np.median(mcweights[:j+1], axis=0) - - ## test cost matrix - tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64) - for ii,val in enumerate(uniqs): - for jj,val2 in enumerate(uniqs): - tmpsum[ii,jj] = abs(mcweights[j+1][val]-pinvals[val2]) - - # Hungarian algorithm for minimum cost - sortinds = lsa(tmpsum)[1] - - # Relabel states - mcweights[j+1][uniqs], mcrates[j+1][uniqs] = mcweights[j+1][sortinds], mcrates[j+1][sortinds] - - gc.collect() - # ####################CHECK IF Ns Ts NEED TO BE SORTED AS WELL!!!!! +# if self.sort: +# if j==600: +# pinvals = np.median(mcweights[200:600], axis=0) +# if j>600: +# #avgs = ((j-100)*avgs+mcweights[j+1])/(j+1-100) +# #mids = np.median(mcweights[:j+1], axis=0) +# +# ## test cost matrix +# tmpsum = np.ones((len(uniqs),len(uniqs)), dtype=np.float64) +# for ii,val in enumerate(uniqs): +# for jj,val2 in enumerate(uniqs): +# tmpsum[ii,jj] = abs(mcweights[j+1][val]-pinvals[val2]) +# +# # Hungarian algorithm for minimum cost +# sortinds = lsa(tmpsum)[1] +# +# # Relabel states +# mcweights[j+1][uniqs], mcrates[j+1][uniqs] = mcweights[j+1][sortinds], mcrates[j+1][sortinds] +# +# gc.collect() +# # ####################CHECK IF Ns Ts NEED TO BE SORTED AS WELL!!!!! - naninds = np.where(lnp!=lnp)[0] - lnp, Ns = np.delete(lnp, naninds), np.delete(Ns, naninds) - mcrates = np.delete(mcrates, naninds, axis=0) - mcweights = np.delete(mcweights, naninds, axis=0) +# naninds = np.where(lnp!=lnp)[0] +# lnp, Ns = np.delete(lnp, naninds), np.delete(Ns, naninds) +# mcrates = np.delete(mcrates, naninds, axis=0) +# mcweights = np.delete(mcweights, naninds, axis=0) - burnin, g, nsample = pmts.detect_equilibration(lnp, fast=False) + if self.niter>50000: + Fast=True + else: + Fast=False + + burnin, g, nsample = pmts.detect_equilibration(lnp, fast=Fast) g = np.ceil(g) plt.close('all') diff --git a/scripts/gibbs_serial.py b/scripts/gibbs_serial.py index 9449d65..a4c93be 100644 --- a/scripts/gibbs_serial.py +++ b/scripts/gibbs_serial.py @@ -19,10 +19,11 @@ parser.add_argument('--resids', nargs='?') parser.add_argument('--niter', nargs='?', default=10000) parser.add_argument('--sort', nargs='?', default=True) + parser.add_argument('--ncomp', nargs='?', default=10, type=int) args = parser.parse_args() a = np.load(args.contacts) - ts, ncomp = 0.1, 10 + ts, ncomp = 0.1, args.ncomp cutoff = float(args.contacts.split('.npy')[0].split('_')[-1]) nproc, prot = 1, args.protname if args.niter: diff --git a/scripts/gibbs_synth.py b/scripts/gibbs_synth.py new file mode 100644 index 0000000..b282966 --- /dev/null +++ b/scripts/gibbs_synth.py @@ -0,0 +1,36 @@ +from multiprocessing import shared_memory +from basicrta import * +from basicrta.functions import simulate_hn +from multiprocessing import Pool, Lock +from basicrta import istarmap +import numpy as np +import MDAnalysis as mda +import os +from tqdm import tqdm +import gc + +if __name__ == "__main__": + # Parts of code taken from Shep (Centrifuge3.py, SuperMCMC.py) + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--niter', nargs='?', default=10000) + parser.add_argument('-N', nargs='?', default=100000) + parser.add_argument('--ncomp', nargs='?', default=10, type=int) + args = parser.parse_args() + + ts, ncomp, N = 0.1, args.ncomp, args.N + if args.niter: + niter = int(args.niter) + + residue, ts, nproc = 'X1', 0.1, 1 + times = simulate_hn(N, [0.901, 0.09, 0.009], [5, 0.1, 0.001]) + + if not os.path.exists(f'X1'): + os.mkdir(f'X1') + os.chdir(f'X1') + + input_list = np.array([residue, times, ts, ncomp, niter], dtype=object) + with Pool(nproc, initializer=tqdm.set_lock, initargs=(Lock(),)) as p: + for _ in tqdm(p.istarmap(run_residue, input_list), total=len(residue), position=0, desc='overall progress'): + pass