Skip to content

Commit

Permalink
fix: restrict kwta (#45)
Browse files Browse the repository at this point in the history
* fix: restrict kwta

* feat: weight learning tag
  • Loading branch information
saeedark authored Jan 2, 2024
1 parent 141bf2c commit 3a291d3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
14 changes: 7 additions & 7 deletions conex/behaviors/neurons/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,22 @@ def initialize(self, neurons):

def forward(self, neurons):
will_spike = neurons.v >= neurons.threshold
will_spike_v = will_spike * (neurons.v - neurons.threshold)
v_values = neurons.v

dim = 0
if self.dimension is not None:
will_spike_v = will_spike_v.view(self.shape)
v_values = v_values.view(self.shape)
will_spike = will_spike.view(self.shape)
dim = self.dimension

if (will_spike.sum(axis=dim) <= self.k).all():
return

k_values, k_winners_indices = torch.topk(
will_spike_v, self.k, dim=dim, sorted=False
_, k_winners_indices = torch.topk(
v_values, self.k, dim=dim, sorted=False
)
min_values = k_values.min(dim=0).values
winners = will_spike_v >= min_values.expand(will_spike_v.size())
ignored = will_spike * (~winners)

ignored = will_spike
ignored.scatter_(dim, k_winners_indices, False)

neurons.v[ignored.view((-1,))] = neurons.v_reset
4 changes: 4 additions & 0 deletions conex/behaviors/synapses/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def no_bound(w, w_min, w_max):


class BaseLearning(Behavior):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.add_tag("weight_learning")

def get_spike_and_trace(self, synapse):
src_spike = synapse.src.axon.get_spike(synapse.src, synapse.src_delay)
dst_spike = synapse.dst.axon.get_spike(synapse.dst, synapse.dst_delay)
Expand Down

0 comments on commit 3a291d3

Please sign in to comment.