-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path representational_similarity_analysis.py
77 lines (60 loc) · 2.42 KB
/
representational_similarity_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Calculate similarity between representational dissimilarity matrices."""
import glob
import logging
import os
import hydra
import matplotlib.pyplot as plt
import rsatoolbox
from omegaconf import DictConfig, OmegaConf
from fingers_rsa import rdm_utils
# Module-level logger, named after this module; main() reports the config and output paths through it.
log = logging.getLogger(__name__)
@hydra.main(config_path="config", config_name="rsa", version_base="1.1")
def main(cfg: DictConfig) -> None:
    """Evaluate model RDMs against data RDMs and save the results and a plot.

    Loads every RDM file matching ``cfg.rdm_files`` and concatenates them.
    It evaluates the models named in ``cfg.rsa.models`` with
    ``rsatoolbox.inference.eval_bootstrap_rdm``, using ``cfg.metrics.similarity``
    as the comparison method. The results are written to ``<prefix>.hdf5`` and
    a model-comparison plot to ``<prefix>.<savefig format>``. The prefix comes
    from :func:`filename`, and both paths are relative to the current working
    directory.

    :param cfg: Hydra config; reads ``rdm_files``, ``rsa.models``,
        ``metrics.similarity`` and, via :func:`filename`, ``array.subject``
    :raises FileNotFoundError: if no file matches ``cfg.rdm_files``
    :raises ValueError: if fewer than two data RDMs are loaded (bootstrap
        confidence intervals need more than one), or if no model RDMs are found
    """
    log.debug("Config args:\n%s", OmegaConf.to_yaml(cfg))

    # glob returns matches in arbitrary, filesystem-dependent order. Sorting makes
    # the order of the concatenated data RDMs deterministic across runs.
    # NOTE(review): with version_base="1.1" Hydra changes into the run directory
    # before calling main, so a *relative* rdm_files pattern is resolved there,
    # unlike the models directory below (to_absolute_path). Confirm this is intended.
    data_rdm_files = sorted(glob.glob(cfg.rdm_files))
    if not data_rdm_files:
        raise FileNotFoundError("No RDM files found at: {}.".format(cfg.rdm_files))
    filename_prefix = filename(cfg)

    # Read in the RDM files and stack them into a single RDMs object.
    data_rdm_list = [rsatoolbox.rdm.load_rdm(rdm_file) for rdm_file in data_rdm_files]
    data_rdms = rsatoolbox.rdm.concat(data_rdm_list)
    # Validate before loading models so a misconfigured run fails early.
    if data_rdms.n_rdm <= 1:
        raise ValueError("Need more than one RDM to calculate RSA confidence intervals")

    # Load in models; these are saved as RDM files, too. Cosine-based methods
    # additionally evaluate the "unstructured" model. Build a local list rather
    # than appending to cfg in place, and don't add the model twice.
    model_names = list(cfg.rsa.models)
    if cfg.metrics.similarity.startswith("cosine") and "unstructured" not in model_names:
        model_names.append("unstructured")
    model_rdms = rdm_utils.load_models(
        model_names, model_dir=hydra.utils.to_absolute_path("models")
    )
    if not model_rdms:
        raise ValueError("No models found with cfg: {}".format(model_names))

    results: rsatoolbox.inference.Result = rsatoolbox.inference.eval_bootstrap_rdm(
        model_rdms,
        data_rdms,
        method=cfg.metrics.similarity,
    )

    # Save results to HDF5, replacing any file left by a previous run.
    results_filename = filename_prefix + ".hdf5"
    results.save(results_filename, file_type="hdf5", overwrite=True)
    log.info("Saved results to file: %s", os.path.abspath(results_filename))

    rsatoolbox.vis.plot_model_comparison(
        results,
        test_pair_comparisons="golan",
    )
    # Spell out the extension matplotlib would otherwise append implicitly
    # (rcParams["savefig.format"]). The logged path then names the real file,
    # and a "." inside the subject id can't be mistaken for a format.
    plot_filename = "{}.{}".format(filename_prefix, plt.rcParams["savefig.format"])
    plt.savefig(plot_filename, bbox_inches="tight")
    log.info("Saved plot to file: %s", os.path.abspath(plot_filename))
    plt.show()
def filename(cfg: DictConfig) -> str:
    """Return the per-subject output name used for RSA results.

    The string has no file extension.

    :param cfg: Hydra config object; only ``cfg.array.subject`` is read
    :return: ``"sub-<subject>_rsa"``
    """
    subject_id = cfg.array.subject
    return f"sub-{subject_id}_rsa"
# Script entry point.
if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator composes ``cfg`` from
    # config/rsa and any command-line overrides.
    main()