Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add config loader and store method #23

Merged
merged 8 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions Example/test/configs/save_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from Example.test.configs.l2_3 import *
from Example.test.configs.l4 import *
from Example.test.configs.l5 import *
from Example.test.configs.sens_l4 import *
from Example.test.configs.l2_3_l5 import *
from Example.test.configs.l4_l2_3 import *
from Example.test.configs.l5_l2_3 import *

from conex.nn.Config.base_config import BaseConfig

if __name__ == '__main__':
l2_3().save_as_yaml(file_name="config-snn", hard_refresh=True)
l4().save_as_yaml(file_name="config-snn")
l5().save_as_yaml(file_name="config-snn")
sens_l4().save_as_yaml(file_name="config-snn")
l2_3_l5().save_as_yaml(file_name="config-snn")
l4_l2_3().save_as_yaml(file_name="config-snn")
l5_l2_3().save_as_yaml(file_name="config-snn")

l5_l2_3_instance = l5_l2_3()
l5_l2_3_instance.update_from_yaml(file_name="config-snn")
l5_l2_3_instance.exc_exc_structure_params['current_coef'] = 6
l5_l2_3_instance.save_as_yaml(file_name="new-config")

loaded_instances = BaseConfig.load_from_yaml(file_name="config-snn")

112 changes: 110 additions & 2 deletions conex/nn/Config/base_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import os
import yaml
import collections.abc

from yaml import Loader
from typing import Tuple, Union, Callable
from pymonntorch import *

import collections.abc
from pymonntorch import *


class BaseConfig:
Expand Down Expand Up @@ -30,3 +34,107 @@ def update_make(self, **kwargs):

def __call__(self, *args, **kwargs):
self.make(*args, **kwargs)

def _get_members(self, sort_by):
members = [
attr
for attr in dir(self)
if attr
not in [
"update_from_yaml",
"load_from_yaml",
"save_as_yaml",
"make",
"update",
"update_make",
"deep_update",
]
and not attr.startswith("_")
]
members.sort(key=sort_by)
return members

def save_as_yaml(
self,
file_name,
scope_key=None,
configs_dir=".",
sort_by=None,
hard_refresh=False,
):
"""
Args:
file_name: file name where configs are going to be saved in.
scope_key: scope where class config will be placed inside the yaml file, default is class name
configs_dir: where configs are going to be saved
sort_by: sorting algorithm used for ordering in the yaml configuration file
hard_refresh: Pass True if you want to create new file under the same file_name

Returns: None
"""
members = self._get_members(sort_by)

yaml_attributes_content = {attr: getattr(self, attr) for attr in members}

scope_key = scope_key or self.__class__.__name__
yaml_content = {
scope_key: {
"parameters": yaml_attributes_content,
"class": {self.__class__.__name__: self.__class__},
}
}

if not file_name.endswith(".yml"):
file_name += ".yml"

file_path = os.path.join(configs_dir, file_name)

if os.path.isfile(file_path) and hard_refresh:
os.remove(file_path)
print(
f"The file {file_name} has been deleted. A brand new config is going to be created!"
)

with open(file_path, "a") as yaml_file:
yaml.dump(yaml_content, yaml_file, default_flow_style=False)

def update_from_yaml(
self, file_name, scope_key=None, configs_dir=".", force_update=False
):
if not file_name.endswith(".yml"):
file_name += ".yml"
file_path = os.path.join(configs_dir, file_name)
with open(file_path, "r") as yaml_file:
yaml_content = yaml.load(yaml_file, Loader=Loader)

scope_key = scope_key or self.__class__.__name__
contents = yaml_content[scope_key]

if not force_update:
assert (
self.__class__.__name__ in contents["class"]
), "YAML config should have been dumped from same class."

self.update(contents["parameters"])

@staticmethod
def _make_config_instance(config):
config_class = list(config["class"].values())[0]
instance = config_class()
instance.update(config["parameters"])
return instance

@staticmethod
def load_from_yaml(file_name, configs_dir="."):
if not file_name.endswith(".yml"):
file_name += ".yml"
file_path = os.path.join(configs_dir, file_name)
with open(file_path, "r") as yaml_file:
yaml_content = yaml.load(yaml_file, Loader=Loader)

configs = {
scope: BaseConfig._make_config_instance(content)
for scope, content in yaml_content.items()
}

return configs
3 changes: 3 additions & 0 deletions conex/nn/Structure/CorticalColumn.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def _add_synaptic_connection(cls, src, dst, config):
synapses[key] = StructuredSynapseGroup(
src=src_pop, dst=dst_pop, net=net, **config[key]
)

if tag in synapses[key].tags:
synapses[key].tags.remove(tag)
synapses[key].tags.insert(0, tag)
synapses[key].add_tag(key)

Expand Down
5 changes: 4 additions & 1 deletion conex/nn/Structure/Layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def _create_neural_population(net, config, tag):
def _create_synaptic_connection(src, dst, net, config):
if isinstance(config, dict):
syn = StructuredSynapseGroup(src, dst, net, **config)
syn.tags.insert(0, f"{src.tags[0]} => {dst.tags[0]}")
connection_tag = f"{src.tags[0]} => {dst.tags[0]}"
if connection_tag in syn.tags:
syn.tags.remove(connection_tag)
syn.tags.insert(0, connection_tag)
syn.add_tag("Proximal")
return syn
else:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pymonntorch==0.1.0
pytest==7.1.2
setuptools==65.5.0
torch==1.13.1
pyyaml==6.0