|
1 | 1 | from basicrta.functions import simulate_hn
|
2 | 2 | from basicrta.functions import newgibbs
|
3 | 3 | import numpy as np
|
| 4 | +from scipy.optimize import linear_sum_assignment as lsa |
4 | 5 |
|
5 | 6 | def test_parametric():
|
6 | 7 | wts = np.array([0.89, 0.098, 0.008, 0.002, 0.00056])
|
7 | 8 | wts = wts/wts.sum()
|
8 |
| - x = simulate_hn(5e4, wts, [4.7, 0.8, 0.2, 0.02, 0.003]) |
| 9 | + rts = [4.7, 0.8, 0.2, 0.02, 0.003] |
| 10 | + x = simulate_hn(1e5, wts, rts) |
9 | 11 | G = newgibbs(x, 'X1', 0, 0.1, ncomp=5, niter=10000, sort=False)
|
10 | 12 | G.run()
|
11 | 13 |
|
12 |
| - for i,rts in enumerate(G.results.mcrates): |
13 |
| - sorts = rts.argsort()[::-1] |
14 |
| - G.results.mcweights[i] = G.results.mcweights[i][sorts] |
15 |
| - G.results.mcrates[i] = G.results.mcrates[i][sorts] |
| 14 | + for i in range(len(G.results.mcrates)): |
| 15 | + tmpsum = np.ones((5, 5), dtype=np.float64) |
| 16 | + for ii in range(5): |
| 17 | + for jj in range(5): |
| 18 | + tmpsum[ii,jj] = abs(G.results.mcrates[i][ii]-rts[jj]) |
| 19 | + |
| 20 | + # Hungarian algorithm for minimum cost |
| 21 | + sortinds = lsa(tmpsum)[1] |
| 22 | + |
| 23 | + # Relabel states |
| 24 | + G.results.mcweights[i] = G.results.mcweights[i][sortinds] |
| 25 | + G.results.mcrates[i] = G.results.mcrates[i][sortinds] |
16 | 26 |
|
17 | 27 | tmp = np.array([np.sort(G.results.weights[:,i]) for i in range(G.results.ncomp)])
|
18 | 28 | tmp2 = (tmp.cumsum(axis=1).T/tmp.cumsum(axis=1).T[-1])
|
19 | 29 | tmp3 = tmp.T[[np.where((tmp2[:,i]>0.025)&(tmp2[:,i]<0.975))[0] for i in range(G.results.ncomp)][0]]
|
20 | 30 | descsort = G.results.mcweights.mean(axis=0).argsort()[::-1]
|
21 | 31 | ci = np.array([[line[0],line[-1]] for line in tmp3.T])
|
22 | 32 |
|
23 |
| - |
24 | 33 | Bools = np.array([(wts[i]>ci[descsort][i,0])&(wts[i]<ci[descsort][i,1]) for i in descsort])
|
25 | 34 |
|
26 | 35 | assert Bools.all() == True
|
27 | 36 |
|
28 | 37 | if __name__=="__main__":
|
29 | 38 | test_parametric()
|
| 39 | + |
0 commit comments