Skip to content

Commit edd245b

Browse files
committed
modified parametric test
1 parent cd1ccd5 commit edd245b

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

basicrta/tests/test_parametric.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,39 @@
11
from basicrta.functions import simulate_hn
22
from basicrta.functions import newgibbs
33
import numpy as np
4+
from scipy.optimize import linear_sum_assignment as lsa
45

56
def test_parametric():
67
wts = np.array([0.89, 0.098, 0.008, 0.002, 0.00056])
78
wts = wts/wts.sum()
8-
x = simulate_hn(5e4, wts, [4.7, 0.8, 0.2, 0.02, 0.003])
9+
rts = [4.7, 0.8, 0.2, 0.02, 0.003]
10+
x = simulate_hn(1e5, wts, rts)
911
G = newgibbs(x, 'X1', 0, 0.1, ncomp=5, niter=10000, sort=False)
1012
G.run()
1113

12-
for i,rts in enumerate(G.results.mcrates):
13-
sorts = rts.argsort()[::-1]
14-
G.results.mcweights[i] = G.results.mcweights[i][sorts]
15-
G.results.mcrates[i] = G.results.mcrates[i][sorts]
14+
for i in range(len(G.results.mcrates)):
15+
tmpsum = np.ones((5, 5), dtype=np.float64)
16+
for ii in range(5):
17+
for jj in range(5):
18+
tmpsum[ii,jj] = abs(G.results.mcrates[i][ii]-rts[jj])
19+
20+
# Hungarian algorithm for minimum cost
21+
sortinds = lsa(tmpsum)[1]
22+
23+
# Relabel states
24+
G.results.mcweights[i] = G.results.mcweights[i][sortinds]
25+
G.results.mcrates[i] = G.results.mcrates[i][sortinds]
1626

1727
tmp = np.array([np.sort(G.results.weights[:,i]) for i in range(G.results.ncomp)])
1828
tmp2 = (tmp.cumsum(axis=1).T/tmp.cumsum(axis=1).T[-1])
1929
tmp3 = tmp.T[[np.where((tmp2[:,i]>0.025)&(tmp2[:,i]<0.975))[0] for i in range(G.results.ncomp)][0]]
2030
descsort = G.results.mcweights.mean(axis=0).argsort()[::-1]
2131
ci = np.array([[line[0],line[-1]] for line in tmp3.T])
2232

23-
2433
Bools = np.array([(wts[i]>ci[descsort][i,0])&(wts[i]<ci[descsort][i,1]) for i in descsort])
2534

2635
assert Bools.all() == True
2736

2837
if __name__=="__main__":
2938
test_parametric()
39+

0 commit comments

Comments
 (0)