Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add config loader and store method #23

Merged
merged 8 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions Example/test/configs/save_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from Example.test.configs.l2_3 import *
from Example.test.configs.l4 import *
from Example.test.configs.l5 import *
from Example.test.configs.sens_l4 import *
from Example.test.configs.l2_3_l5 import *
from Example.test.configs.l4_l2_3 import *
from Example.test.configs.l5_l2_3 import *

from conex.nn.Config.base_config import BaseConfig

if __name__ == '__main__':
config_type = 'yml' # json
l2_3().save(file_name=f"config-snn.{config_type}", hard_refresh=True)
l4().save(file_name=f"config-snn.{config_type}")
l5().save(file_name=f"config-snn.{config_type}")
sens_l4().save(file_name=f"config-snn.{config_type}")
l2_3_l5().save(file_name=f"config-snn.{config_type}")
l4_l2_3().save(file_name=f"config-snn.{config_type}")
l5_l2_3().save(file_name=f"config-snn.{config_type}")

l5_l2_3_instance = l5_l2_3()
l5_l2_3_instance.update_file(file_name=f"config-snn.{config_type}")
l5_l2_3_instance.exc_exc_structure_params['current_coef'] = 6
l5_l2_3_instance.save(file_name=f"new-config.{config_type}")

loaded_instances = BaseConfig.load(file_name=f"config-snn.{config_type}")
224 changes: 222 additions & 2 deletions conex/nn/Config/base_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Tuple, Union, Callable
import json
import os
import collections.abc

from pymonntorch import *

import collections.abc
YML_EXT = ".yml"
JSON_EXT = ".json"


class BaseConfig:
Expand Down Expand Up @@ -30,3 +34,219 @@ def update_make(self, **kwargs):

def __call__(self, *args, **kwargs):
self.make(*args, **kwargs)

def _get_members(self, sort_by):
public_members = [
"save",
"load",
"update_file",
"make",
"update",
"update_make",
"deep_update",
]
members = [
attr
for attr in dir(self)
if (attr not in public_members and not attr.startswith("_"))
]
members.sort(key=sort_by)
return members

def _store_content(self,
file_name,
scope_key=None,
configs_dir=".",
sort_by=None,
hard_refresh=False,
io_store_function=None,
io_store_function_params={}
):
"""
Args:
file_name: file name where configs are going to be saved in.
scope_key: scope where class config will be placed inside the yaml file, default is class name
configs_dir: where configs are going to be saved
sort_by: sorting algorithm used for ordering in the yaml configuration file
hard_refresh: Pass True if you want to create new file under the same file_name
io_store_function: Function used to write content on io yaml.dump or json.dump
io_store_function_params: Params that will be passed to the io_store_function

Returns: None
"""
members = self._get_members(sort_by)

scope_key = scope_key or self.__class__.__name__
data_content = {
scope_key: {
"parameters": {attr: getattr(self, attr) for attr in members},
"class": {self.__class__.__name__: self.__class__},
}
}

file_path = os.path.join(configs_dir, file_name)

if os.path.isfile(file_path) and hard_refresh:
os.remove(file_path)
print(
f"The file {file_name} has been deleted. A brand new config is going to be created!"
)

with open(file_path, "a+") as output_file:
io_store_function(data_content, output_file, **io_store_function_params)

def _update_content(
self,
file_name,
scope_key=None,
configs_dir=".",
force_update=False,
io_load_function=None,
io_load_function_kwargs={}
):
file_path = os.path.join(configs_dir, file_name)
with open(file_path, "r") as input_file:
data_content = io_load_function(input_file, **io_load_function_kwargs)

scope_key = scope_key or self.__class__.__name__
contents = data_content[scope_key]

if not force_update:
assert (
self.__class__.__name__ in contents["class"]
), "Config should have been dumped from same class."

self.update(contents["parameters"])

@staticmethod
def _make_config_instance(config):
config_class = list(config["class"].values())[0]
instance = config_class()
instance.update(config["parameters"])
return instance

@staticmethod
def _load_content(
file_name,
configs_dir=".",
io_load_function=None,
io_load_function_kwargs={}):

file_path = os.path.join(configs_dir, file_name)
with open(file_path, "r") as input_file:
data_content = io_load_function(input_file, **io_load_function_kwargs)

configs = {
scope: BaseConfig._make_config_instance(content)
for scope, content in data_content.items()
}

return configs

@staticmethod
def _has_yaml_module():
# NOTE: importlib.util.find_spec didn't work as expected!
try:
import yaml
except ImportError as e:
raise ImportError("For using yaml file, you must have pyyaml==6.0 installed your environment!")

def _save_as_yaml(
self,
file_name,
**store_content_kwargs,
):
self._has_yaml_module()
import yaml

self._store_content(
file_name,
**store_content_kwargs,
io_store_function=yaml.dump,
io_store_function_params={"default_flow_style": False}
)

def _save_as_json(self, file_name, **store_content_kwargs):

def default_loader(o):
try:
return o.__dict__
except Exception as e:
print(
f'Object {getattr(o, "__module__", o.__class__.__name__)} '
'is not json serializable, consider using yaml file or override the o.__dict__ method')
return lambda: None

def io_store_function(data_content, output_file, **kwargs):
output_file.seek(0)
file_content = output_file.read()
loaded_data = json.loads(file_content) if file_content else {}
loaded_data = {**loaded_data, **data_content}
output_file.truncate(0)
json.dump(loaded_data, output_file, **kwargs)

self._store_content(
file_name,
**store_content_kwargs,
io_store_function=io_store_function,
io_store_function_params={"indent": 2, "default": default_loader}
)

def _update_from_json(
self, file_name, **load_content_kwargs

):
self._update_content(file_name, **load_content_kwargs, io_load_function=json.load)

def _update_from_yaml(
self, file_name, **load_content_kwargs
):
self._has_yaml_module()
import yaml
from yaml import Loader

self._update_content(file_name,
**load_content_kwargs,
io_load_function=yaml.load,
io_load_function_kwargs={'Loader': Loader})

@staticmethod
def _load_from_yaml(file_name, **load_content_kwargs):
BaseConfig._has_yaml_module()
import yaml
from yaml import Loader

return BaseConfig._load_content(
file_name,
**load_content_kwargs,
io_load_function=yaml.load,
io_load_function_kwargs={'Loader': Loader})

@staticmethod
def _load_from_json(file_name, **load_content_kwargs):
return BaseConfig._load_content(file_name, **load_content_kwargs, io_load_function=json.load)

def save(self, file_name, **store_content_kwargs):
if file_name.endswith(YML_EXT):
self._save_as_yaml(file_name, **store_content_kwargs)
elif file_name.endswith(JSON_EXT):
self._save_as_json(file_name, **store_content_kwargs)
else:
raise TypeError(f'{file_name} must end with .json or .yml')

def update_file(self, file_name, **update_content_kwargs):
if file_name.endswith(YML_EXT):
self._update_from_yaml(file_name, **update_content_kwargs)
elif file_name.endswith(JSON_EXT):
self._update_from_json(file_name, **update_content_kwargs)
else:
raise TypeError(f'{file_name} must end with .json or .yaml')

@staticmethod
def load(file_name, **load_content_kwargs):
if file_name.endswith(YML_EXT):
BaseConfig._load_from_yaml(file_name, **load_content_kwargs)
elif file_name.endswith(JSON_EXT):
BaseConfig._load_from_json(file_name, **load_content_kwargs)
else:
raise TypeError(f'{file_name} must end with .json or .yaml')
34 changes: 14 additions & 20 deletions conex/nn/Modules/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(
self.network = net
net.input_layers.append(self)

sensory_tag = "Sensory" if sensory_tag is None else "Sensory," + sensory_tag
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sure that no errors or issues will be encountered in this format. using underline rather than comma is safer in my opinion.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nein. (add_tag process this tag(made of comma-separated tags))


if have_sensory:
self.sensory_pop = self.__get_ng(
net,
Expand All @@ -62,13 +64,12 @@ def __init__(
sensory_user_defined,
)

if sensory_tag is None:
self.sensory_pop.tags.insert(0, "Sensory")
else:
self.sensory_pop.add_tag("Sensory")

self.sensory_pop.layer = self

location_tag = (
"Location" if location_tag is None else "Location," + location_tag
)

if have_location:
self.location_pop = self.__get_ng(
net,
Expand All @@ -79,11 +80,6 @@ def __init__(
location_user_defined,
)

if location_tag is None:
self.location_pop.tags.insert(0, "Location")
else:
self.location_pop.add_tag("Location")

self.location_pop.layer = self

def connect(
Expand Down Expand Up @@ -158,6 +154,12 @@ def __init__(
self.network = net
net.output_layers.append(self)

representation_tag = (
"Representation"
if representation_tag is None
else "Representation," + representation_tag
)

if representation_size is not None:
self.representation_pop = self.__get_ng(
net,
Expand All @@ -168,13 +170,10 @@ def __init__(
representation_user_defined,
)

if representation_tag is None:
self.representation_pop.tags.insert(0, "Representation")
else:
self.representation_pop.add_tag("Representation")

self.representation_pop.layer = self

motor_tag = "Motor" if motor_tag is None else "Motor," + motor_tag

if motor_size is not None:
self.motor_pop = self.__get_ng(
net,
Expand All @@ -185,11 +184,6 @@ def __init__(
motor_user_defined,
)

if motor_tag is None:
self.motor_pop.tags.insert(0, "Motor")
else:
self.motor_pop.add_tag("Motor")

self.motor_pop.layer = self

self.add_tag(self.__class__.__name__)
Expand Down
18 changes: 7 additions & 11 deletions conex/nn/Structure/CorticalColumn.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,6 @@ def _add_synaptic_connection(cls, src, dst, config):

synapses = {}
for key in config:
src_tag = src.tags[0]
if pop_tag := config[key].get("src_pop"):
src_tag = src_tag + "_" + pop_tag.removesuffix("_pop")
dst_tag = dst.tags[0]
if pop_tag := config[key].get("dst_pop"):
dst_tag = dst_tag + "_" + pop_tag.removesuffix("_pop")
tag = f"{src_tag} => {dst_tag}"

if isinstance(config[key], dict):
if isinstance(src, NeuronGroup):
src_pop = src
Expand All @@ -129,7 +121,7 @@ def _add_synaptic_connection(cls, src, dst, config):
synapses[key] = StructuredSynapseGroup(
src=src_pop, dst=dst_pop, net=net, **config[key]
)
synapses[key].tags.insert(0, tag)

synapses[key].add_tag(key)

if not (
Expand All @@ -140,9 +132,13 @@ def _add_synaptic_connection(cls, src, dst, config):
for connection in ["Proximal", "Distal", "Apical"]
)
):
if hasattr(dst, "cortical_column"):
if hasattr(src, "cortical_column") and hasattr(
dst, "cortical_column"
):
if src.cortical_column == dst.cortical_column:
if "L4" in src_tag and "L2_3" in dst_tag:
if any("L4" in tag for tag in src.tags) and any(
"L2_3" in tag for tag in dst.tags
):
synapses[key].add_tag("Proximal")
else:
synapses[key].add_tag("Distal")
Expand Down
1 change: 0 additions & 1 deletion conex/nn/Structure/Layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def _create_neural_population(net, config, tag):
def _create_synaptic_connection(src, dst, net, config):
if isinstance(config, dict):
syn = StructuredSynapseGroup(src, dst, net, **config)
syn.tags.insert(0, f"{src.tags[0]} => {dst.tags[0]}")
syn.add_tag("Proximal")
return syn
else:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ click==8.0.4
pymonntorch==0.1.0
pytest==7.1.2
setuptools==65.5.0
torch==1.13.1
torch==1.13.1