Skip to content

Commit d183020

Browse files
jnsbckmichaeldeistler
authored andcommitted
Fix synapse clamping (#492)
* add: wip version of tutorial on views * fix: added fixes for clamping and stimulating of synapses * fix: allow all states * fix: add current to states * fix: add tests for current and synapse clamping * fix: fix current clamp test * rm: rm accidently commited stash from different branch * enh: allow to remove multiple clamp * fix: make tests pass
1 parent 967c236 commit d183020

File tree

3 files changed

+115
-37
lines changed

3 files changed

+115
-37
lines changed

jaxley/integrate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,11 @@ def add_stimuli(
118118
if data_stimuli is not None:
119119
externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]])
120120
external_inds["i"] = jnp.concatenate(
121-
[external_inds["i"], data_stimuli[2].global_comp_index.to_numpy()]
121+
[external_inds["i"], data_stimuli[2].index.to_numpy()]
122122
)
123123
else:
124124
externals["i"] = data_stimuli[1]
125-
external_inds["i"] = data_stimuli[2].global_comp_index.to_numpy()
125+
external_inds["i"] = data_stimuli[2].index.to_numpy()
126126

127127
return externals, external_inds
128128

@@ -148,11 +148,11 @@ def add_clamps(
148148
if state_name in externals.keys():
149149
externals[state_name] = jnp.concatenate([externals[state_name], clamps])
150150
external_inds[state_name] = jnp.concatenate(
151-
[external_inds[state_name], inds.global_comp_index.to_numpy()]
151+
[external_inds[state_name], inds.index.to_numpy()]
152152
)
153153
else:
154154
externals[state_name] = clamps
155-
external_inds[state_name] = inds.global_comp_index.to_numpy()
155+
external_inds[state_name] = inds.index.to_numpy()
156156

157157
return externals, external_inds
158158

jaxley/modules/base.py

Lines changed: 60 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,15 @@ def add_to_group(self, group_name: str):
11791179
np.concatenate([self.base.groups[group_name], self._nodes_in_view])
11801180
)
11811181

1182+
def _get_state_names(self) -> Tuple[List, List]:
1183+
"""Collect all recordable / clampable states in the membrane and synapses.
1184+
1185+
Returns states seperated by comps and edges."""
1186+
channel_states = [name for c in self.channels for name in c.channel_states]
1187+
synapse_states = [name for s in self.synapses for name in s.synapse_states]
1188+
membrane_states = ["v", "i"] + self.membrane_current_names
1189+
return channel_states + membrane_states, synapse_states
1190+
11821191
def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
11831192
"""Get all trainable parameters.
11841193
@@ -1447,10 +1456,10 @@ def _init_morph_for_debugging(self):
14471456
self.base.debug_states["par_inds"] = self.base.par_inds
14481457

14491458
def record(self, state: str = "v", verbose=True):
1450-
in_view = None
1451-
in_view = self._edges_in_view if state in self.edges.columns else in_view
1452-
in_view = self._nodes_in_view if state in self.nodes.columns else in_view
1453-
assert in_view is not None, "State not found in nodes or edges."
1459+
comp_states, edge_states = self._get_state_names()
1460+
if state not in comp_states + edge_states:
1461+
raise KeyError(f"{state} is not a recognized state in this module.")
1462+
in_view = self._nodes_in_view if state in comp_states else self._edges_in_view
14541463

14551464
new_recs = pd.DataFrame(in_view, columns=["rec_index"])
14561465
new_recs["state"] = state
@@ -1514,8 +1523,6 @@ def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True)
15141523
15151524
This function sets external states for the compartments.
15161525
"""
1517-
if state_name not in self.nodes.columns:
1518-
raise KeyError(f"{state_name} is not a recognized state in this module.")
15191526
self._external_input(state_name, state_array, verbose=verbose)
15201527

15211528
def _external_input(
@@ -1524,15 +1531,16 @@ def _external_input(
15241531
values: Optional[jnp.ndarray],
15251532
verbose: bool = True,
15261533
):
1534+
comp_states, edge_states = self._get_state_names()
1535+
if key not in comp_states + edge_states:
1536+
raise KeyError(f"{key} is not a recognized state in this module.")
15271537
values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0)
15281538
batch_size = values.shape[0]
1529-
num_inserted = len(self._nodes_in_view)
1530-
is_multiple = num_inserted == batch_size
1531-
values = (
1532-
values
1533-
if is_multiple
1534-
else jnp.repeat(values, len(self._nodes_in_view), axis=0)
1539+
num_inserted = (
1540+
len(self._nodes_in_view) if key in comp_states else len(self._edges_in_view)
15351541
)
1542+
is_multiple = num_inserted == batch_size
1543+
values = values if is_multiple else jnp.repeat(values, num_inserted, axis=0)
15361544
assert batch_size in [
15371545
1,
15381546
num_inserted,
@@ -1546,9 +1554,12 @@ def _external_input(
15461554
[self.base.external_inds[key], self._nodes_in_view]
15471555
)
15481556
else:
1549-
self.base.externals[key] = values
1550-
self.base.external_inds[key] = self._nodes_in_view
1551-
1557+
if key in comp_states:
1558+
self.base.externals[key] = values
1559+
self.base.external_inds[key] = self._nodes_in_view
1560+
else:
1561+
self.base.externals[key] = values
1562+
self.base.external_inds[key] = self._edges_in_view
15521563
if verbose:
15531564
print(
15541565
f"Added {num_inserted} external_states. See `.externals` for details."
@@ -1588,8 +1599,12 @@ def data_clamp(
15881599
verbose: Whether or not to print the number of inserted clamps. `False`
15891600
by default because this method is meant to be jitted.
15901601
"""
1602+
comp_states, edge_states = self._get_state_names()
1603+
if state_name not in comp_states + edge_states:
1604+
raise KeyError(f"{state_name} is not a recognized state in this module.")
1605+
data = self.nodes if state_name in comp_states else self.edges
15911606
return self._data_external_input(
1592-
state_name, state_array, data_clamps, self.nodes, verbose=verbose
1607+
state_name, state_array, data_clamps, data, verbose=verbose
15931608
)
15941609

15951610
def _data_external_input(
@@ -1600,16 +1615,23 @@ def _data_external_input(
16001615
view: pd.DataFrame,
16011616
verbose: bool = False,
16021617
):
1618+
comp_states, edge_states = self._get_state_names()
16031619
state_array = (
16041620
state_array
16051621
if state_array.ndim == 2
16061622
else jnp.expand_dims(state_array, axis=0)
16071623
)
16081624
batch_size = state_array.shape[0]
1609-
num_inserted = len(self._nodes_in_view)
1625+
num_inserted = (
1626+
len(self._nodes_in_view)
1627+
if state_name in comp_states
1628+
else len(self._edges_in_view)
1629+
)
16101630
is_multiple = num_inserted == batch_size
16111631
state_array = (
1612-
state_array if is_multiple else jnp.repeat(state_array, len(view), axis=0)
1632+
state_array
1633+
if is_multiple
1634+
else jnp.repeat(state_array, num_inserted, axis=0)
16131635
)
16141636
assert batch_size in [
16151637
1,
@@ -1638,23 +1660,28 @@ def delete_stimuli(self):
16381660
"""Removes all stimuli from the module."""
16391661
self.delete_clamps("i")
16401662

1641-
def delete_clamps(self, state_name: str):
1663+
def delete_clamps(self, state_name: Optional[str] = None):
16421664
"""Removes all clamps of the given state from the module."""
1643-
if state_name in self.externals:
1644-
keep_inds = ~np.isin(
1645-
self.base.external_inds[state_name], self._nodes_in_view
1646-
)
1647-
base_exts = self.base.externals
1648-
base_exts_inds = self.base.external_inds
1649-
if np.all(~keep_inds):
1650-
base_exts.pop(state_name, None)
1651-
base_exts_inds.pop(state_name, None)
1665+
all_externals = list(self.externals.keys())
1666+
if "i" in all_externals:
1667+
all_externals.remove("i")
1668+
state_names = all_externals if state_name is None else [state_name]
1669+
for state_name in state_names:
1670+
if state_name in self.externals:
1671+
keep_inds = ~np.isin(
1672+
self.base.external_inds[state_name], self._nodes_in_view
1673+
)
1674+
base_exts = self.base.externals
1675+
base_exts_inds = self.base.external_inds
1676+
if np.all(~keep_inds):
1677+
base_exts.pop(state_name, None)
1678+
base_exts_inds.pop(state_name, None)
1679+
else:
1680+
base_exts[state_name] = base_exts[state_name][keep_inds]
1681+
base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]
1682+
self._update_view()
16521683
else:
1653-
base_exts[state_name] = base_exts[state_name][keep_inds]
1654-
base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]
1655-
self._update_view()
1656-
else:
1657-
pass # does not have to be deleted if not in externals
1684+
pass # does not have to be deleted if not in externals
16581685

16591686
def insert(self, channel: Channel):
16601687
"""Insert a channel into the module.

tests/test_clamp.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import jax
55

6+
from jaxley.connect import connect
7+
from jaxley.synapses.ionotropic import IonotropicSynapse
8+
69
jax.config.update("jax_enable_x64", True)
710
jax.config.update("jax_platform_name", "cpu")
811
from typing import Optional
@@ -25,6 +28,54 @@ def test_clamp_pointneuron():
2528
assert np.all(v[:, 1:] == -50.0)
2629

2730

31+
def test_clamp_currents():
32+
comp = jx.Compartment()
33+
comp.insert(HH())
34+
comp.record("i_HH")
35+
36+
# test clamp
37+
comp.clamp("i_HH", 1.0 * jnp.ones((1000,)))
38+
i1 = jx.integrate(comp, t_max=1.0)
39+
assert np.all(i1[:, 1:] == 1.0)
40+
41+
# test data clamp
42+
data_clamps = None
43+
ipts = 1.0 * jnp.ones((1000,))
44+
data_clamps = comp.data_clamp("i_HH", ipts, data_clamps=data_clamps)
45+
46+
i2 = jx.integrate(comp, data_clamps=data_clamps, t_max=1.0)
47+
assert np.all(i2[:, 1:] == 1.0)
48+
49+
assert np.all(np.isclose(i1, i2))
50+
51+
52+
def test_clamp_synapse():
53+
comp = jx.Compartment()
54+
branch = jx.Branch(comp, 1)
55+
cell1 = jx.Cell(branch, [-1])
56+
cell2 = jx.Cell(branch, [-1])
57+
net = jx.Network([cell1, cell2])
58+
connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse())
59+
net.record("IonotropicSynapse_s")
60+
61+
# test clamp
62+
net.clamp("IonotropicSynapse_s", 1.0 * jnp.ones((1000,)))
63+
s1 = jx.integrate(net, t_max=1.0)
64+
assert np.all(s1[:, 1:] == 1.0)
65+
66+
net.delete_clamps()
67+
68+
# test data clamp
69+
data_clamps = None
70+
ipts = 1.0 * jnp.ones((1000,))
71+
data_clamps = net.data_clamp("IonotropicSynapse_s", ipts, data_clamps=data_clamps)
72+
73+
s2 = jx.integrate(net, data_clamps=data_clamps, t_max=1.0)
74+
assert np.all(s2[:, 1:] == 1.0)
75+
76+
assert np.all(np.isclose(s1, s2))
77+
78+
2879
def test_clamp_multicompartment():
2980
comp = jx.Compartment()
3081
branch = jx.Branch(comp, 4)

0 commit comments

Comments
 (0)