Skip to content

Commit 27fcbb6

Browse files
committed
ENH BatchSimulate for JSON path handling
Signed-off-by: samadpls <[email protected]>
1 parent 100baad commit 27fcbb6

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

hnn_core/batch_simulate.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# Ryan Thorpe <[email protected]>
66
# Mainak Jas <[email protected]>
77

8+
import json
89
import numpy as np
910
import os
1011
from joblib import Parallel, delayed, parallel_config
@@ -13,6 +14,7 @@
1314
from .externals.mne import _validate_type, _check_option
1415
from .dipole import simulate_dipole
1516
from .network_models import jones_2009_model
17+
from .hnn_io import dict_to_network
1618

1719

1820
class BatchSimulate(object):
@@ -24,7 +26,7 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
2426
save_dpl=True, save_spiking=False,
2527
save_lfp=False, save_voltages=False,
2628
save_currents=False, save_calcium=False,
27-
clear_cache=False):
29+
clear_cache=False, net_json=None):
2830
"""Initialize the BatchSimulate class.
2931
3032
Parameters
@@ -100,6 +102,9 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
100102
clear_cache : bool, optional
101103
Whether to clear the results cache after saving each batch.
102104
Default is False.
105+
net_json : str, optional
106+
The path to a JSON file to create the network model. If provided,
107+
this will override the `net` parameter. Default is None.
103108
Notes
104109
-----
105110
When `save_output=True`, the saved files will appear as
@@ -127,6 +132,8 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
127132
_validate_type(save_currents, types=(bool,), item_name='save_currents')
128133
_validate_type(save_calcium, types=(bool,), item_name='save_calcium')
129134
_validate_type(clear_cache, types=(bool,), item_name='clear_cache')
135+
_validate_type(net_json, types=('path-like', None),
136+
item_name='net_json')
130137

131138
if set_params is not None and not callable(set_params):
132139
raise TypeError("set_params must be a callable function")
@@ -154,6 +161,7 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
154161
self.save_currents = save_currents
155162
self.save_calcium = save_calcium
156163
self.clear_cache = clear_cache
164+
self.net_json = net_json
157165

158166
def run(self, param_grid, return_output=True,
159167
combinations=True, n_jobs=1, backend='loky',
@@ -295,7 +303,14 @@ def _run_single_sim(self, param_values):
295303
- `param_values`: The parameter values used for the simulation.
296304
"""
297305

298-
net = self.net.copy()
306+
if isinstance(self.net_json, str):
307+
with open(self.net_json, 'r') as file:
308+
net_data = json.load(file)
309+
net = dict_to_network(net_data)
310+
else:
311+
net = self.net
312+
net = net.copy()
313+
299314
self.set_params(param_values, net)
300315

301316
results = {'net': net, 'param_values': param_values}

hnn_core/tests/test_batch_simulate.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
# Ryan Thorpe <[email protected]>
44
# Mainak Jas <[email protected]>
55

6+
from pathlib import Path
67
import pytest
78
import numpy as np
89
import os
910

1011
from hnn_core.batch_simulate import BatchSimulate
1112
from hnn_core import jones_2009_model
1213

14+
hnn_core_root = Path(__file__).parents[1]
15+
assets_path = Path(hnn_core_root, 'tests', 'assets')
16+
1317

1418
@pytest.fixture
1519
def batch_simulate_instance(tmp_path):
@@ -33,9 +37,9 @@ def set_params(param_values, net):
3337
weights_ampa=weights_ampa,
3438
synaptic_delays=synaptic_delays)
3539

36-
net = jones_2009_model()
40+
net = jones_2009_model(mesh_shape=(3, 3))
3741
return BatchSimulate(net=net, set_params=set_params,
38-
tstop=1.,
42+
tstop=10,
3943
save_folder=tmp_path,
4044
batch_size=3)
4145

@@ -74,6 +78,9 @@ def test_parameter_validation():
7478

7579
with pytest.raises(TypeError, match="net must be"):
7680
BatchSimulate(net="invalid_network", set_params=lambda x: x)
81+
82+
with pytest.raises(TypeError, match="net_json must be"):
83+
BatchSimulate(net_json=123, set_params=lambda x: x)
7784

7885

7986
def test_generate_param_combinations(batch_simulate_instance, param_grid):
@@ -104,6 +111,21 @@ def test_run_single_sim(batch_simulate_instance):
104111
assert isinstance(result['net'], type(batch_simulate_instance.net))
105112

106113

114+
def test_net_json_loading(param_grid):
115+
"""Test loading the network from a JSON file."""
116+
json_path = assets_path / 'jones2009_3x3_drives.json'
117+
118+
batch_simulate = BatchSimulate(net_json=str(json_path),
119+
set_params=lambda x, y: x,
120+
tstop=70)
121+
122+
result = batch_simulate._run_single_sim(param_grid)
123+
assert isinstance(result, dict)
124+
assert 'net' in result
125+
assert 'param_values' in result
126+
assert 'dpl' in result
127+
128+
107129
def test_simulate_batch(batch_simulate_instance, param_grid):
108130
"""Test simulating a batch of parameter sets."""
109131
param_combinations = batch_simulate_instance._generate_param_combinations(

0 commit comments

Comments
 (0)