@@ -1179,6 +1179,15 @@ def add_to_group(self, group_name: str):
1179
1179
np .concatenate ([self .base .groups [group_name ], self ._nodes_in_view ])
1180
1180
)
1181
1181
1182
+ def _get_state_names (self ) -> Tuple [List , List ]:
1183
+ """Collect all recordable / clampable states in the membrane and synapses.
1184
+
1185
+ Returns states seperated by comps and edges."""
1186
+ channel_states = [name for c in self .channels for name in c .channel_states ]
1187
+ synapse_states = [name for s in self .synapses for name in s .synapse_states ]
1188
+ membrane_states = ["v" , "i" ] + self .membrane_current_names
1189
+ return channel_states + membrane_states , synapse_states
1190
+
1182
1191
def get_parameters (self ) -> List [Dict [str , jnp .ndarray ]]:
1183
1192
"""Get all trainable parameters.
1184
1193
@@ -1447,10 +1456,10 @@ def _init_morph_for_debugging(self):
1447
1456
self .base .debug_states ["par_inds" ] = self .base .par_inds
1448
1457
1449
1458
def record (self , state : str = "v" , verbose = True ):
1450
- in_view = None
1451
- in_view = self . _edges_in_view if state in self . edges . columns else in_view
1452
- in_view = self . _nodes_in_view if state in self . nodes . columns else in_view
1453
- assert in_view is not None , "State not found in nodes or edges."
1459
+ comp_states , edge_states = self . _get_state_names ()
1460
+ if state not in comp_states + edge_states :
1461
+ raise KeyError ( f" { state } is not a recognized state in this module." )
1462
+ in_view = self . _nodes_in_view if state in comp_states else self . _edges_in_view
1454
1463
1455
1464
new_recs = pd .DataFrame (in_view , columns = ["rec_index" ])
1456
1465
new_recs ["state" ] = state
@@ -1514,8 +1523,6 @@ def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True)
1514
1523
1515
1524
This function sets external states for the compartments.
1516
1525
"""
1517
- if state_name not in self .nodes .columns :
1518
- raise KeyError (f"{ state_name } is not a recognized state in this module." )
1519
1526
self ._external_input (state_name , state_array , verbose = verbose )
1520
1527
1521
1528
def _external_input (
@@ -1524,15 +1531,16 @@ def _external_input(
1524
1531
values : Optional [jnp .ndarray ],
1525
1532
verbose : bool = True ,
1526
1533
):
1534
+ comp_states , edge_states = self ._get_state_names ()
1535
+ if key not in comp_states + edge_states :
1536
+ raise KeyError (f"{ key } is not a recognized state in this module." )
1527
1537
values = values if values .ndim == 2 else jnp .expand_dims (values , axis = 0 )
1528
1538
batch_size = values .shape [0 ]
1529
- num_inserted = len (self ._nodes_in_view )
1530
- is_multiple = num_inserted == batch_size
1531
- values = (
1532
- values
1533
- if is_multiple
1534
- else jnp .repeat (values , len (self ._nodes_in_view ), axis = 0 )
1539
+ num_inserted = (
1540
+ len (self ._nodes_in_view ) if key in comp_states else len (self ._edges_in_view )
1535
1541
)
1542
+ is_multiple = num_inserted == batch_size
1543
+ values = values if is_multiple else jnp .repeat (values , num_inserted , axis = 0 )
1536
1544
assert batch_size in [
1537
1545
1 ,
1538
1546
num_inserted ,
@@ -1546,9 +1554,12 @@ def _external_input(
1546
1554
[self .base .external_inds [key ], self ._nodes_in_view ]
1547
1555
)
1548
1556
else :
1549
- self .base .externals [key ] = values
1550
- self .base .external_inds [key ] = self ._nodes_in_view
1551
-
1557
+ if key in comp_states :
1558
+ self .base .externals [key ] = values
1559
+ self .base .external_inds [key ] = self ._nodes_in_view
1560
+ else :
1561
+ self .base .externals [key ] = values
1562
+ self .base .external_inds [key ] = self ._edges_in_view
1552
1563
if verbose :
1553
1564
print (
1554
1565
f"Added { num_inserted } external_states. See `.externals` for details."
@@ -1588,8 +1599,12 @@ def data_clamp(
1588
1599
verbose: Whether or not to print the number of inserted clamps. `False`
1589
1600
by default because this method is meant to be jitted.
1590
1601
"""
1602
+ comp_states , edge_states = self ._get_state_names ()
1603
+ if state_name not in comp_states + edge_states :
1604
+ raise KeyError (f"{ state_name } is not a recognized state in this module." )
1605
+ data = self .nodes if state_name in comp_states else self .edges
1591
1606
return self ._data_external_input (
1592
- state_name , state_array , data_clamps , self . nodes , verbose = verbose
1607
+ state_name , state_array , data_clamps , data , verbose = verbose
1593
1608
)
1594
1609
1595
1610
def _data_external_input (
@@ -1600,16 +1615,23 @@ def _data_external_input(
1600
1615
view : pd .DataFrame ,
1601
1616
verbose : bool = False ,
1602
1617
):
1618
+ comp_states , edge_states = self ._get_state_names ()
1603
1619
state_array = (
1604
1620
state_array
1605
1621
if state_array .ndim == 2
1606
1622
else jnp .expand_dims (state_array , axis = 0 )
1607
1623
)
1608
1624
batch_size = state_array .shape [0 ]
1609
- num_inserted = len (self ._nodes_in_view )
1625
+ num_inserted = (
1626
+ len (self ._nodes_in_view )
1627
+ if state_name in comp_states
1628
+ else len (self ._edges_in_view )
1629
+ )
1610
1630
is_multiple = num_inserted == batch_size
1611
1631
state_array = (
1612
- state_array if is_multiple else jnp .repeat (state_array , len (view ), axis = 0 )
1632
+ state_array
1633
+ if is_multiple
1634
+ else jnp .repeat (state_array , num_inserted , axis = 0 )
1613
1635
)
1614
1636
assert batch_size in [
1615
1637
1 ,
@@ -1638,23 +1660,28 @@ def delete_stimuli(self):
1638
1660
"""Removes all stimuli from the module."""
1639
1661
self .delete_clamps ("i" )
1640
1662
1641
- def delete_clamps (self , state_name : str ):
1663
+ def delete_clamps (self , state_name : Optional [ str ] = None ):
1642
1664
"""Removes all clamps of the given state from the module."""
1643
- if state_name in self .externals :
1644
- keep_inds = ~ np .isin (
1645
- self .base .external_inds [state_name ], self ._nodes_in_view
1646
- )
1647
- base_exts = self .base .externals
1648
- base_exts_inds = self .base .external_inds
1649
- if np .all (~ keep_inds ):
1650
- base_exts .pop (state_name , None )
1651
- base_exts_inds .pop (state_name , None )
1665
+ all_externals = list (self .externals .keys ())
1666
+ if "i" in all_externals :
1667
+ all_externals .remove ("i" )
1668
+ state_names = all_externals if state_name is None else [state_name ]
1669
+ for state_name in state_names :
1670
+ if state_name in self .externals :
1671
+ keep_inds = ~ np .isin (
1672
+ self .base .external_inds [state_name ], self ._nodes_in_view
1673
+ )
1674
+ base_exts = self .base .externals
1675
+ base_exts_inds = self .base .external_inds
1676
+ if np .all (~ keep_inds ):
1677
+ base_exts .pop (state_name , None )
1678
+ base_exts_inds .pop (state_name , None )
1679
+ else :
1680
+ base_exts [state_name ] = base_exts [state_name ][keep_inds ]
1681
+ base_exts_inds [state_name ] = base_exts_inds [state_name ][keep_inds ]
1682
+ self ._update_view ()
1652
1683
else :
1653
- base_exts [state_name ] = base_exts [state_name ][keep_inds ]
1654
- base_exts_inds [state_name ] = base_exts_inds [state_name ][keep_inds ]
1655
- self ._update_view ()
1656
- else :
1657
- pass # does not have to be deleted if not in externals
1684
+ pass # does not have to be deleted if not in externals
1658
1685
1659
1686
def insert (self , channel : Channel ):
1660
1687
"""Insert a channel into the module.
0 commit comments