-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathscript_compare_rudy.py
49 lines (40 loc) · 2.01 KB
/
script_compare_rudy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import json
import numpy as np
from scipy.stats import pearsonr, spearmanr, kendalltau
from typing import Tuple, List, Dict, Any
test_dataset_name_iter = ('superblue19', 900)
LOG_DIR = f'log/superblue19'
logs: Dict[str, Any] = {}
def printout(arr1, arr2, prefix="", log_prefix=""):
pearsonr_rho, pearsonr_pval = pearsonr(arr1, arr2)
spearmanr_rho, spearmanr_pval = spearmanr(arr1, arr2)
kendalltau_rho, kendalltau_pval = kendalltau(arr1, arr2)
mae = np.sum(np.abs(arr1 - arr2)) / len(arr1)
delta = np.abs(arr1 - arr2)
rmse = np.sqrt(np.sum(np.multiply(delta, delta)) / len(arr1))
print(prefix + "pearson", pearsonr_rho, pearsonr_pval)
print(prefix + "spearman", spearmanr_rho, spearmanr_pval)
print(prefix + "kendall", kendalltau_rho, kendalltau_pval)
print(prefix + "MAE", mae)
print(prefix + "RMSE", rmse)
logs.update({
f'{log_prefix}pearson_rho': pearsonr_rho,
f'{log_prefix}pearsonr_pval': pearsonr_pval,
f'{log_prefix}spearmanr_rho': spearmanr_rho,
f'{log_prefix}spearmanr_pval': spearmanr_pval,
f'{log_prefix}kendalltau_rho': kendalltau_rho,
f'{log_prefix}kendalltau_pval': kendalltau_pval,
f'{log_prefix}mae': mae,
f'{log_prefix}rmse': rmse,
})
def load_rudy_nctu(raw_dir_name: str, given_iter: int) -> Tuple[np.ndarray, np.ndarray]:
h_bad_cmap = np.load(f'data/{raw_dir_name}/iter_{given_iter}_bad_cmap_h.npy')
v_bad_cmap = np.load(f'data/{raw_dir_name}/iter_{given_iter}_bad_cmap_v.npy')
h_cmap = np.load(f'data/{raw_dir_name}/iter_{given_iter}_cmap_h.npy')
v_cmap = np.load(f'data/{raw_dir_name}/iter_{given_iter}_cmap_v.npy')
return h_bad_cmap + v_bad_cmap, (h_cmap + v_cmap)[:, :, 0]
rudy_output, nctu_output = load_rudy_nctu(test_dataset_name_iter[0], test_dataset_name_iter[1])
rudy_output, nctu_output = rudy_output.flatten(), nctu_output.flatten()
printout(rudy_output, nctu_output, "\t\tGRID_NO_INDEX: ", f'test_grid_no_index_')
with open(f'{LOG_DIR}/RUDY.json', 'w+') as fp:
json.dump(logs, fp)