diff --git a/models/eprop_iaf.cpp b/models/eprop_iaf.cpp
index fb336a40f4..a081adec26 100644
--- a/models/eprop_iaf.cpp
+++ b/models/eprop_iaf.cpp
@@ -85,6 +85,9 @@ eprop_iaf::Parameters_::Parameters_()
, kappa_( 0.97 )
, kappa_reg_( 0.97 )
, eprop_isi_trace_cutoff_( 1000.0 )
+ , delay_rec_out_( 1 )
+ , delay_out_rec_( 1 )
+ , delay_total_( 1 )
{
}
@@ -131,6 +134,8 @@ eprop_iaf::Parameters_::get( DictionaryDatum& d ) const
def< double >( d, names::kappa, kappa_ );
def< double >( d, names::kappa_reg, kappa_reg_ );
def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ );
+ def< double >( d, names::delay_rec_out, Time( Time::step( delay_rec_out_ ) ).get_ms() );
+ def< double >( d, names::delay_out_rec, Time( Time::step( delay_out_rec_ ) ).get_ms() );
}
double
@@ -169,6 +174,14 @@ eprop_iaf::Parameters_::set( const DictionaryDatum& d, Node* node )
updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node );
updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node );
+ double delay_rec_out_ms = Time( Time::step( delay_rec_out_ ) ).get_ms();
+ updateValueParam< double >( d, names::delay_rec_out, delay_rec_out_ms, node );
+ delay_rec_out_ = Time( Time::ms( delay_rec_out_ms ) ).get_steps();
+
+ double delay_out_rec_ms = Time( Time::step( delay_out_rec_ ) ).get_ms();
+ updateValueParam< double >( d, names::delay_out_rec, delay_out_rec_ms, node );
+ delay_out_rec_ = Time( Time::ms( delay_out_rec_ms ) ).get_steps();
+
if ( C_m_ <= 0 )
{
throw BadProperty( "Membrane capacitance C_m > 0 required." );
@@ -214,6 +227,18 @@ eprop_iaf::Parameters_::set( const DictionaryDatum& d, Node* node )
throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." );
}
+ if ( delay_rec_out_ < 1 )
+ {
+ throw BadProperty( "Connection delay from recurrent to readout neuron ≥ 1 required." );
+ }
+
+ if ( delay_out_rec_ < 1 )
+ {
+ throw BadProperty( "Connection delay from readout to recurrent neuron ≥ 1 required." );
+ }
+
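+ // Used to pre-pad the eprop history in pre_run_hook() and to offset the eprop history
+ // iterator in compute_gradient().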
+ delay_total_ = delay_rec_out_ + ( delay_out_rec_ - 1 );
+
return delta_EL;
}
@@ -278,6 +303,14 @@ eprop_iaf::pre_run_hook()
V_.P_v_m_ = std::exp( -dt / P_.tau_m_ );
V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ );
+
+ if ( eprop_history_.empty() )
+ {
+ for ( long t = -P_.delay_total_; t < 0; ++t )
+ {
+ append_new_eprop_history_entry( t );
+ }
+ }
}
@@ -373,7 +406,8 @@ eprop_iaf::handle( DataLoggingRequest& e )
void
eprop_iaf::compute_gradient( const long t_spike,
const long t_spike_previous,
- double& z_previous_buffer,
+ std::queue< double >& z_previous_buffer,
+ double& z_previous,
double& z_bar,
double& e_bar,
double& e_bar_reg,
@@ -382,26 +416,31 @@ eprop_iaf::compute_gradient( const long t_spike,
const CommonSynapseProperties& cp,
WeightOptimizer* optimizer )
{
- double e = 0.0; // eligibility trace
- double z = 0.0; // spiking variable
- double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration
- double psi = 0.0; // surrogate gradient
- double L = 0.0; // learning signal
- double firing_rate_reg = 0.0; // firing rate regularization
- double grad = 0.0; // gradient
+ double e = 0.0; // eligibility trace
+ double z = 0.0; // spiking variable
+ double z_current = 1.0; // spike state that triggered the current integration
+ double psi = 0.0; // surrogate gradient
+ double L = 0.0; // learning signal
+ double firing_rate_reg = 0.0; // firing rate regularization
+ double grad = 0.0; // gradient
const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp );
const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_;
- auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 );
+ auto eprop_hist_it = get_eprop_history( t_spike_previous - P_.delay_total_ );
const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike );
for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it )
{
- z = z_previous_buffer;
- z_previous_buffer = z_current_buffer;
- z_current_buffer = 0.0;
+ if ( P_.delay_total_ > 1 )
+ {
+ update_pre_syn_buffer_multiple_entries( z, z_current, z_previous, z_previous_buffer, t_spike, t );
+ }
+ else
+ {
+ update_pre_syn_buffer_one_entry( z, z_current, z_previous, z_previous_buffer, t_spike, t );
+ }
psi = eprop_hist_it->surrogate_gradient_;
L = eprop_hist_it->learning_signal_;
diff --git a/models/eprop_iaf.h b/models/eprop_iaf.h
index 225de65333..22e04dc820 100644
--- a/models/eprop_iaf.h
+++ b/models/eprop_iaf.h
@@ -390,6 +390,7 @@ class eprop_iaf : public EpropArchivingNodeRecurrent< false >
void compute_gradient( const long,
const long,
+ std::queue< double >&,
double&,
double&,
double&,
@@ -402,6 +403,9 @@ class eprop_iaf : public EpropArchivingNodeRecurrent< false >
long get_shift() const override;
bool is_eprop_recurrent_node() const override;
long get_eprop_isi_trace_cutoff() const override;
+ long get_delay_total() const override;
+ long get_delay_recurrent_to_readout() const override;
+ long get_delay_readout_to_recurrent() const override;
//! Map for storing a static set of recordables.
friend class RecordablesMap< eprop_iaf >;
@@ -458,6 +462,15 @@ class eprop_iaf : public EpropArchivingNodeRecurrent< false >
//! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms).
double eprop_isi_trace_cutoff_;
+ //! Connection delay from recurrent to readout neuron.
+ long delay_rec_out_;
+
+ //! Connection delay from readout to recurrent neuron.
+ long delay_out_rec_;
+
+ //! Sum of connection delays from recurrent to readout neuron and from readout to recurrent neuron.
+ long delay_total_;
+
//! Default constructor.
Parameters_();
@@ -594,10 +607,35 @@ eprop_iaf::get_eprop_isi_trace_cutoff() const
return V_.eprop_isi_trace_cutoff_steps_;
}
+inline long
+eprop_iaf::get_delay_total() const
+{
+ return P_.delay_total_;
+}
+
+inline long
+eprop_iaf::get_delay_recurrent_to_readout() const
+{
+ return P_.delay_rec_out_;
+}
+
+inline long
+eprop_iaf::get_delay_readout_to_recurrent() const
+{
+ return P_.delay_out_rec_;
+}
+
inline size_t
eprop_iaf::send_test_event( Node& target, size_t receptor_type, synindex, bool )
{
SpikeEvent e;
+
+ // To perform a consistency check on the delay parameter delay_rec_out between recurrent
+ // and readout neurons, the recurrent neuron sends a test event with a delay specified
+ // by its delay_rec_out. Upon receiving the test event from the recurrent neuron, the
+ // readout neuron checks whether the delay with which the event was received matches its
+ // own delay_rec_out parameter.
+ e.set_delay_steps( P_.delay_rec_out_ );
e.set_sender( *this );
return target.handles_test_event( e, receptor_type );
}
diff --git a/models/eprop_iaf_adapt.cpp b/models/eprop_iaf_adapt.cpp
index 10cb9ff224..eccfdaa454 100644
--- a/models/eprop_iaf_adapt.cpp
+++ b/models/eprop_iaf_adapt.cpp
@@ -89,6 +89,9 @@ eprop_iaf_adapt::Parameters_::Parameters_()
, kappa_( 0.97 )
, kappa_reg_( 0.97 )
, eprop_isi_trace_cutoff_( 1000.0 )
+ , delay_rec_out_( 1 )
+ , delay_out_rec_( 1 )
+ , delay_total_( 1 )
{
}
@@ -139,6 +142,8 @@ eprop_iaf_adapt::Parameters_::get( DictionaryDatum& d ) const
def< double >( d, names::kappa, kappa_ );
def< double >( d, names::kappa_reg, kappa_reg_ );
def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ );
+ def< double >( d, names::delay_rec_out, Time( Time::step( delay_rec_out_ ) ).get_ms() );
+ def< double >( d, names::delay_out_rec, Time( Time::step( delay_out_rec_ ) ).get_ms() );
}
double
@@ -179,6 +184,12 @@ eprop_iaf_adapt::Parameters_::set( const DictionaryDatum& d, Node* node )
updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node );
updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node );
+ double delay_rec_out_ms = Time( Time::step( delay_rec_out_ ) ).get_ms();
+ updateValueParam< double >( d, names::delay_rec_out, delay_rec_out_ms, node );
+ delay_rec_out_ = Time( Time::ms( delay_rec_out_ms ) ).get_steps();
+
+ double delay_out_rec_ms = Time( Time::step( delay_out_rec_ ) ).get_ms();
+ updateValueParam< double >( d, names::delay_out_rec, delay_out_rec_ms, node );
+ delay_out_rec_ = Time( Time::ms( delay_out_rec_ms ) ).get_steps();
+
if ( adapt_beta_ < 0 )
{
throw BadProperty( "Threshold adaptation prefactor adapt_beta ≥ 0 required." );
@@ -234,6 +245,18 @@ eprop_iaf_adapt::Parameters_::set( const DictionaryDatum& d, Node* node )
throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." );
}
+ if ( delay_rec_out_ < 1 )
+ {
+ throw BadProperty( "Connection delay from recurrent to readout neuron ≥ 1 required." );
+ }
+
+ if ( delay_out_rec_ < 1 )
+ {
+ throw BadProperty( "Connection delay from readout to recurrent neuron ≥ 1 required." );
+ }
+
+ delay_total_ = delay_rec_out_ + ( delay_out_rec_ - 1 );
+
return delta_EL;
}
@@ -313,6 +336,14 @@ eprop_iaf_adapt::pre_run_hook()
V_.P_v_m_ = std::exp( -dt / P_.tau_m_ );
V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ );
V_.P_adapt_ = std::exp( -dt / P_.adapt_tau_ );
+
+ if ( eprop_history_.empty() )
+ {
+ for ( long t = -P_.delay_total_; t < 0; ++t )
+ {
+ append_new_eprop_history_entry( t );
+ }
+ }
}
@@ -412,7 +443,8 @@ eprop_iaf_adapt::handle( DataLoggingRequest& e )
void
eprop_iaf_adapt::compute_gradient( const long t_spike,
const long t_spike_previous,
- double& z_previous_buffer,
+ std::queue< double >& z_previous_buffer,
+ double& z_previous,
double& z_bar,
double& e_bar,
double& e_bar_reg,
@@ -421,26 +453,31 @@ eprop_iaf_adapt::compute_gradient( const long t_spike,
const CommonSynapseProperties& cp,
WeightOptimizer* optimizer )
{
- double e = 0.0; // eligibility trace
- double z = 0.0; // spiking variable
- double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration
- double psi = 0.0; // surrogate gradient
- double L = 0.0; // learning signal
- double firing_rate_reg = 0.0; // firing rate regularization
- double grad = 0.0; // gradient
+ double e = 0.0; // eligibility trace
+ double z = 0.0; // spiking variable
+ double z_current = 1.0; // spike state that triggered the current integration
+ double psi = 0.0; // surrogate gradient
+ double L = 0.0; // learning signal
+ double firing_rate_reg = 0.0; // firing rate regularization
+ double grad = 0.0; // gradient
const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp );
const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_;
- auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 );
+ auto eprop_hist_it = get_eprop_history( t_spike_previous - P_.delay_total_ );
const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike );
for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it )
{
- z = z_previous_buffer;
- z_previous_buffer = z_current_buffer;
- z_current_buffer = 0.0;
+ if ( P_.delay_total_ > 1 )
+ {
+ update_pre_syn_buffer_multiple_entries( z, z_current, z_previous, z_previous_buffer, t_spike, t );
+ }
+ else
+ {
+ update_pre_syn_buffer_one_entry( z, z_current, z_previous, z_previous_buffer, t_spike, t );
+ }
psi = eprop_hist_it->surrogate_gradient_;
L = eprop_hist_it->learning_signal_;
diff --git a/models/eprop_iaf_adapt.h b/models/eprop_iaf_adapt.h
index d404dcbd69..fe9689e39f 100644
--- a/models/eprop_iaf_adapt.h
+++ b/models/eprop_iaf_adapt.h
@@ -358,6 +358,7 @@ class eprop_iaf_adapt : public EpropArchivingNodeRecurrent< false >
void compute_gradient( const long,
const long,
+ std::queue< double >&,
double&,
double&,
double&,
@@ -370,6 +371,9 @@ class eprop_iaf_adapt : public EpropArchivingNodeRecurrent< false >
long get_shift() const override;
bool is_eprop_recurrent_node() const override;
long get_eprop_isi_trace_cutoff() const override;
+ long get_delay_total() const override;
+ long get_delay_recurrent_to_readout() const override;
+ long get_delay_readout_to_recurrent() const override;
//! Map for storing a static set of recordables.
friend class RecordablesMap< eprop_iaf_adapt >;
@@ -432,6 +436,15 @@ class eprop_iaf_adapt : public EpropArchivingNodeRecurrent< false >
//! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms).
double eprop_isi_trace_cutoff_;
+ //! Connection delay from recurrent to readout neuron.
+ long delay_rec_out_;
+
+ //! Connection delay from readout to recurrent neuron.
+ long delay_out_rec_;
+
+ //! Sum of connection delays from recurrent to readout neuron and readout to recurrent neuron.
+ long delay_total_;
+
//! Default constructor.
Parameters_();
@@ -591,10 +604,35 @@ eprop_iaf_adapt::get_eprop_isi_trace_cutoff() const
return V_.eprop_isi_trace_cutoff_steps_;
}
+inline long
+eprop_iaf_adapt::get_delay_total() const
+{
+ return P_.delay_total_;
+}
+
+inline long
+eprop_iaf_adapt::get_delay_recurrent_to_readout() const
+{
+ return P_.delay_rec_out_;
+}
+
+inline long
+eprop_iaf_adapt::get_delay_readout_to_recurrent() const
+{
+ return P_.delay_out_rec_;
+}
+
inline size_t
eprop_iaf_adapt::send_test_event( Node& target, size_t receptor_type, synindex, bool )
{
SpikeEvent e;
+
+ // To perform a consistency check on the delay parameter delay_rec_out between recurrent
+ // and readout neurons, the recurrent neuron sends a test event with a delay specified
+ // by its delay_rec_out. Upon receiving the test event from the recurrent neuron, the
+ // readout neuron checks whether the delay with which the event was received matches its
+ // own delay_rec_out parameter.
+ e.set_delay_steps( P_.delay_rec_out_ );
e.set_sender( *this );
return target.handles_test_event( e, receptor_type );
}
diff --git a/models/eprop_iaf_psc_delta.cpp b/models/eprop_iaf_psc_delta.cpp
index a8b13ac4e4..0bbfe38f49 100644
--- a/models/eprop_iaf_psc_delta.cpp
+++ b/models/eprop_iaf_psc_delta.cpp
@@ -87,6 +87,9 @@ eprop_iaf_psc_delta::Parameters_::Parameters_()
, kappa_( 0.97 )
, kappa_reg_( 0.97 )
, eprop_isi_trace_cutoff_( 1000.0 )
+ , delay_rec_out_( 1 )
+ , delay_out_rec_( 1 )
+ , delay_total_( 1 )
{
}
@@ -134,6 +137,8 @@ eprop_iaf_psc_delta::Parameters_::get( DictionaryDatum& d ) const
def< double >( d, names::kappa, kappa_ );
def< double >( d, names::kappa_reg, kappa_reg_ );
def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ );
+ def< double >( d, names::delay_rec_out, Time( Time::step( delay_rec_out_ ) ).get_ms() );
+ def< double >( d, names::delay_out_rec, Time( Time::step( delay_out_rec_ ) ).get_ms() );
}
double
@@ -162,23 +167,16 @@ eprop_iaf_psc_delta::Parameters_::set( const DictionaryDatum& d, Node* node )
updateValueParam< double >( d, names::beta, beta_, node );
updateValueParam< double >( d, names::gamma, gamma_, node );
-
- if ( updateValueParam< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_, node ) )
- {
- eprop_iaf_psc_delta* nrn = dynamic_cast< eprop_iaf_psc_delta* >( node );
- assert( nrn );
- auto compute_surrogate_gradient = nrn->find_surrogate_gradient( surrogate_gradient_function_ );
- nrn->compute_surrogate_gradient_ = compute_surrogate_gradient;
- }
-
+ updateValueParam< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_, node );
updateValueParam< double >( d, names::kappa, kappa_, node );
updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node );
updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node );
- if ( V_th_ < V_min_ )
- {
- throw BadProperty( "Spike threshold voltage V_th ≥ minimal voltage V_min required." );
- }
+ double delay_rec_out_ms = Time( Time::step( delay_rec_out_ ) ).get_ms();
+ updateValueParam< double >( d, names::delay_rec_out, delay_rec_out_ms, node );
+ delay_rec_out_ = Time( Time::ms( delay_rec_out_ms ) ).get_steps();
+
+ double delay_out_rec_ms = Time( Time::step( delay_out_rec_ ) ).get_ms();
+ updateValueParam< double >( d, names::delay_out_rec, delay_out_rec_ms, node );
+ delay_out_rec_ = Time( Time::ms( delay_out_rec_ms ) ).get_steps();
+
if ( V_reset_ >= V_th_ )
{
@@ -230,6 +228,19 @@ eprop_iaf_psc_delta::Parameters_::set( const DictionaryDatum& d, Node* node )
throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." );
}
+ if ( delay_rec_out_ < 1 )
+ {
+ throw BadProperty( "Connection delay from recurrent to readout neuron ≥ 1 required." );
+ }
+
+ if ( delay_out_rec_ < 1 )
+ {
+ throw BadProperty( "Connection delay from readout to recurrent neuron ≥ 1 required." );
+ }
+
+ delay_total_ = delay_rec_out_ + ( delay_out_rec_ - 1 );
+
return delta_EL;
}
@@ -294,8 +305,27 @@ eprop_iaf_psc_delta::pre_run_hook()
V_.P_v_m_ = std::exp( -dt / P_.tau_m_ );
V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ );
+
+ if ( eprop_history_.empty() )
+ {
+ for ( long t = -P_.delay_total_; t < 0; ++t )
+ {
+ append_new_eprop_history_entry( t );
+ }
+ }
}
+long
+eprop_iaf_psc_delta::get_shift() const
+{
+ return offset_gen_ + delay_in_rec_;
+}
+
+bool
+eprop_iaf_psc_delta::is_eprop_recurrent_node() const
+{
+ return true;
+}
/* ----------------------------------------------------------------
* Update function
@@ -406,7 +436,8 @@ eprop_iaf_psc_delta::handle( DataLoggingRequest& e )
void
eprop_iaf_psc_delta::compute_gradient( const long t_spike,
const long t_spike_previous,
- double& z_previous_buffer,
+ std::queue< double >& z_previous_buffer,
+ double& z_previous,
double& z_bar,
double& e_bar,
double& e_bar_reg,
@@ -415,26 +446,31 @@ eprop_iaf_psc_delta::compute_gradient( const long t_spike,
const CommonSynapseProperties& cp,
WeightOptimizer* optimizer )
{
- double e = 0.0; // eligibility trace
- double z = 0.0; // spiking variable
- double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration
- double psi = 0.0; // surrogate gradient
- double L = 0.0; // learning signal
- double firing_rate_reg = 0.0; // firing rate regularization
- double grad = 0.0; // gradient
+ double e = 0.0; // eligibility trace
+ double z = 0.0; // spiking variable
+ double z_current = 1.0; // spike state that triggered the current integration
+ double psi = 0.0; // surrogate gradient
+ double L = 0.0; // learning signal
+ double firing_rate_reg = 0.0; // firing rate regularization
+ double grad = 0.0; // gradient
const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp );
const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_;
- auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 );
+ auto eprop_hist_it = get_eprop_history( t_spike_previous - P_.delay_total_ );
const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike );
for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it )
{
- z = z_previous_buffer;
- z_previous_buffer = z_current_buffer;
- z_current_buffer = 0.0;
+ if ( P_.delay_total_ > 1 )
+ {
+ update_pre_syn_buffer_multiple_entries( z, z_current, z_previous, z_previous_buffer, t_spike, t );
+ }
+ else
+ {
+ update_pre_syn_buffer_one_entry( z, z_current, z_previous, z_previous_buffer, t_spike, t );
+ }
psi = eprop_hist_it->surrogate_gradient_;
L = eprop_hist_it->learning_signal_;
diff --git a/models/eprop_iaf_psc_delta.h b/models/eprop_iaf_psc_delta.h
index c36066a984..269ee7783f 100644
--- a/models/eprop_iaf_psc_delta.h
+++ b/models/eprop_iaf_psc_delta.h
@@ -307,7 +307,7 @@ References
https://doi.org/10.1038/s41467-020-17236-y
.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE,
- Dahmen D, Bolten M, Van Albada SJ, Diesmann M. Event-based
+ Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based
implementation of eligibility propagation (in preparation)
.. [3] Neftci EO, Mostafa H, Zenke F (2019). Surrogate Gradient Learning in
@@ -402,6 +402,7 @@ class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent< false >
void compute_gradient( const long,
const long,
+ std::queue< double >&,
double&,
double&,
double&,
@@ -414,6 +415,9 @@ class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent< false >
long get_shift() const override;
bool is_eprop_recurrent_node() const override;
long get_eprop_isi_trace_cutoff() const override;
+ long get_delay_total() const override;
+ long get_delay_recurrent_to_readout() const override;
+ long get_delay_readout_to_recurrent() const override;
//! Map for storing a static set of recordables.
friend class RecordablesMap< eprop_iaf_psc_delta >;
@@ -476,6 +480,15 @@ class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent< false >
//! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms).
double eprop_isi_trace_cutoff_;
+ //! Connection delay from recurrent to readout neuron.
+ long delay_rec_out_;
+
+ //! Connection delay from readout to recurrent neuron.
+ long delay_out_rec_;
+
+ //! Sum of connection delays from recurrent to readout neuron and readout to recurrent neuron.
+ long delay_total_;
+
//! Default constructor.
Parameters_();
@@ -592,27 +605,40 @@ class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent< false >
};
inline long
-eprop_iaf_psc_delta::get_shift() const
+eprop_iaf_psc_delta::get_eprop_isi_trace_cutoff() const
{
- return offset_gen_ + delay_in_rec_;
+ return V_.eprop_isi_trace_cutoff_steps_;
}
-inline bool
-eprop_iaf_psc_delta::is_eprop_recurrent_node() const
+inline long
+eprop_iaf_psc_delta::get_delay_total() const
{
- return true;
+ return P_.delay_total_;
}
inline long
-eprop_iaf_psc_delta::get_eprop_isi_trace_cutoff() const
+eprop_iaf_psc_delta::get_delay_recurrent_to_readout() const
{
- return V_.eprop_isi_trace_cutoff_steps_;
+ return P_.delay_rec_out_;
+}
+
+inline long
+eprop_iaf_psc_delta::get_delay_readout_to_recurrent() const
+{
+ return P_.delay_out_rec_;
}
inline size_t
eprop_iaf_psc_delta::send_test_event( Node& target, size_t receptor_type, synindex, bool )
{
SpikeEvent e;
+
+ // To perform a consistency check on the delay parameter delay_rec_out between recurrent
+ // and readout neurons, the recurrent neuron sends a test event with a delay specified
+ // by its delay_rec_out. Upon receiving the test event from the recurrent neuron, the
+ // readout neuron checks whether the delay with which the event was received matches its
+ // own delay_rec_out parameter.
+ e.set_delay_steps( P_.delay_rec_out_ );
e.set_sender( *this );
return target.handles_test_event( e, receptor_type );
}
diff --git a/models/eprop_iaf_psc_delta_adapt.cpp b/models/eprop_iaf_psc_delta_adapt.cpp
index 9ffe2a8b69..74b9dc752b 100644
--- a/models/eprop_iaf_psc_delta_adapt.cpp
+++ b/models/eprop_iaf_psc_delta_adapt.cpp
@@ -445,7 +445,8 @@ eprop_iaf_psc_delta_adapt::handle( DataLoggingRequest& e )
void
eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike,
const long t_spike_previous,
- double& z_previous_buffer,
+ std::queue< double >& z_previous_buffer,
+ double& z_previous,
double& z_bar,
double& e_bar,
double& e_bar_reg,
@@ -454,13 +455,13 @@ eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike,
const CommonSynapseProperties& cp,
WeightOptimizer* optimizer )
{
- double e = 0.0; // eligibility trace
- double z = 0.0; // spiking variable
- double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration
- double psi = 0.0; // surrogate gradient
- double L = 0.0; // learning signal
- double firing_rate_reg = 0.0; // firing rate regularization
- double grad = 0.0; // gradient
+ double e = 0.0; // eligibility trace
+ double z = 0.0; // spiking variable
+ double z_current = 1.0; // spike state that triggered the current integration
+ double psi = 0.0; // surrogate gradient
+ double L = 0.0; // learning signal
+ double firing_rate_reg = 0.0; // firing rate regularization
+ double grad = 0.0; // gradient
const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp );
const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_;
@@ -471,9 +472,14 @@ eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike,
for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it )
{
- z = z_previous_buffer;
- z_previous_buffer = z_current_buffer;
- z_current_buffer = 0.0;
+ if ( P_.delay_total_ > 1 )
+ {
+ update_pre_syn_buffer_multiple_entries( z, z_current, z_previous, z_previous_buffer, t_spike, t );
+ }
+ else
+ {
+ update_pre_syn_buffer_one_entry( z, z_current, z_previous, z_previous_buffer, t_spike, t );
+ }
psi = eprop_hist_it->surrogate_gradient_;
L = eprop_hist_it->learning_signal_;
diff --git a/models/eprop_iaf_psc_delta_adapt.h b/models/eprop_iaf_psc_delta_adapt.h
index 3c0949de3a..e417b65567 100644
--- a/models/eprop_iaf_psc_delta_adapt.h
+++ b/models/eprop_iaf_psc_delta_adapt.h
@@ -418,6 +418,7 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent< false >
void compute_gradient( const long,
const long,
+ std::queue< double >&,
double&,
double&,
double&,
@@ -430,6 +431,9 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent< false >
long get_shift() const override;
bool is_eprop_recurrent_node() const override;
long get_eprop_isi_trace_cutoff() const override;
+ long get_delay_total() const override;
+ long get_delay_recurrent_to_readout() const override;
+ long get_delay_readout_to_recurrent() const override;
//! Map for storing a static set of recordables.
friend class RecordablesMap< eprop_iaf_psc_delta_adapt >;
@@ -498,6 +502,15 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent< false >
//! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms).
double eprop_isi_trace_cutoff_;
+ //! Connection delay from recurrent to readout neuron.
+ long delay_rec_out_;
+
+ //! Connection delay from readout to recurrent neuron.
+ long delay_out_rec_;
+
+ //! Sum of connection delays from recurrent to readout neuron and readout to recurrent neuron.
+ long delay_total_;
+
//! Default constructor.
Parameters_();
@@ -657,10 +670,35 @@ eprop_iaf_psc_delta_adapt::get_eprop_isi_trace_cutoff() const
return V_.eprop_isi_trace_cutoff_steps_;
}
+inline long
+eprop_iaf_psc_delta_adapt::get_delay_total() const
+{
+ return P_.delay_total_;
+}
+
+inline long
+eprop_iaf_psc_delta_adapt::get_delay_recurrent_to_readout() const
+{
+ return P_.delay_rec_out_;
+}
+
+inline long
+eprop_iaf_psc_delta_adapt::get_delay_readout_to_recurrent() const
+{
+ return P_.delay_out_rec_;
+}
+
inline size_t
eprop_iaf_psc_delta_adapt::send_test_event( Node& target, size_t receptor_type, synindex, bool )
{
SpikeEvent e;
+
+ // To perform a consistency check on the delay parameter delay_rec_out between recurrent
+ // and readout neurons, the recurrent neuron sends a test event with a delay specified
+ // by its delay_rec_out. Upon receiving the test event from the recurrent neuron, the
+ // readout neuron checks whether the delay with which the event was received matches its
+ // own delay_rec_out parameter.
+ e.set_delay_steps( P_.delay_rec_out_ );
e.set_sender( *this );
return target.handles_test_event( e, receptor_type );
}
diff --git a/models/eprop_learning_signal_connection.h b/models/eprop_learning_signal_connection.h
index cc013a028b..9c84671488 100644
--- a/models/eprop_learning_signal_connection.h
+++ b/models/eprop_learning_signal_connection.h
@@ -163,6 +163,12 @@ class eprop_learning_signal_connection : public Connection< targetidentifierT >
{
LearningSignalConnectionEvent ge;
+ const long delay_out_rec = t.get_delay_readout_to_recurrent();
+ if ( delay_out_rec != get_delay_steps() )
+ {
+ throw IllegalConnection( "delay == delay_rec_out of target neuron required." );
+ }
+
s.sends_secondary_event( ge );
ge.set_sender( s );
Connection< targetidentifierT >::target_.set_rport( t.handles_test_event( ge, receptor_type ) );
diff --git a/models/eprop_readout.cpp b/models/eprop_readout.cpp
index b2f740d70e..7bfeb0aaa3 100644
--- a/models/eprop_readout.cpp
+++ b/models/eprop_readout.cpp
@@ -76,6 +76,8 @@ eprop_readout::Parameters_::Parameters_()
, tau_m_( 10.0 )
, V_min_( -std::numeric_limits< double >::max() )
, eprop_isi_trace_cutoff_( 1000.0 )
+ , delay_rec_out_( 1 )
+ , delay_out_rec_( 1 )
{
}
@@ -113,6 +115,8 @@ eprop_readout::Parameters_::get( DictionaryDatum& d ) const
def< double >( d, names::tau_m, tau_m_ );
def< double >( d, names::V_min, V_min_ + E_L_ );
def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ );
+ def< double >( d, names::delay_rec_out, Time( Time::step( delay_rec_out_ ) ).get_ms() );
+ def< double >( d, names::delay_out_rec, Time( Time::step( delay_out_rec_ ) ).get_ms() );
}
double
@@ -130,6 +134,12 @@ eprop_readout::Parameters_::set( const DictionaryDatum& d, Node* node )
updateValueParam< double >( d, names::tau_m, tau_m_, node );
updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node );
+ double delay_rec_out_ms = Time( Time::step( delay_rec_out_ ) ).get_ms();
+ updateValueParam< double >( d, names::delay_rec_out, delay_rec_out_ms, node );
+ delay_rec_out_ = Time( Time::ms( delay_rec_out_ms ) ).get_steps();
+
+ double delay_out_rec_ms = Time( Time::step( delay_out_rec_ ) ).get_ms();
+ updateValueParam< double >( d, names::delay_out_rec, delay_out_rec_ms, node );
+ delay_out_rec_ = Time( Time::ms( delay_out_rec_ms ) ).get_steps();
+
if ( C_m_ <= 0 )
{
throw BadProperty( "Membrane capacitance C_m > 0 required." );
@@ -145,6 +155,16 @@ eprop_readout::Parameters_::set( const DictionaryDatum& d, Node* node )
throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." );
}
+ if ( delay_rec_out_ < 1 )
+ {
+ throw BadProperty( "Connection delay from recurrent to output neuron ≥ 1 required." );
+ }
+
+ if ( delay_out_rec_ < 1 )
+ {
+ throw BadProperty( "Connection delay from readout to recurrent neuron ≥ 1 required." );
+ }
+
return delta_EL;
}
@@ -207,6 +227,19 @@ eprop_readout::pre_run_hook()
V_.P_v_m_ = std::exp( -dt / P_.tau_m_ );
V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ );
+
+ if ( eprop_history_.empty() )
+ {
+ for ( long t = -P_.delay_rec_out_; t < 0; ++t )
+ {
+ append_new_eprop_history_entry( t );
+ }
+
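+ // Pre-fill the error signal queue with delay_out_rec - 1 zeros so that the error signal
+ // written to the buffer in update() lags the current signal by delay_out_rec - 1 steps.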
+ for ( long i = 0; i < P_.delay_out_rec_ - 1; i++ )
+ {
+ S_.error_signal_deque_.push_back( 0.0 );
+ }
+ }
}
@@ -237,7 +270,9 @@ eprop_readout::update( Time const& origin, const long from, const long to )
S_.readout_signal_ *= S_.learning_window_signal_;
S_.error_signal_ *= S_.learning_window_signal_;
- error_signal_buffer[ lag ] = S_.error_signal_;
+ S_.error_signal_deque_.push_back( S_.error_signal_ );
+ error_signal_buffer[ lag ] = S_.error_signal_deque_.front(); // get delay_out_rec-th value
+ S_.error_signal_deque_.pop_front();
append_new_eprop_history_entry( t );
write_error_signal_to_history( t, S_.error_signal_ );
@@ -307,7 +342,8 @@ eprop_readout::handle( DataLoggingRequest& e )
void
eprop_readout::compute_gradient( const long t_spike,
const long t_spike_previous,
- double& z_previous_buffer,
+ std::queue< double >& z_previous_buffer,
+ double& z_previous,
double& z_bar,
double& e_bar,
double& e_bar_reg,
@@ -316,10 +352,10 @@ eprop_readout::compute_gradient( const long t_spike,
const CommonSynapseProperties& cp,
WeightOptimizer* optimizer )
{
- double z = 0.0; // spiking variable
- double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration
- double L = 0.0; // error signal
- double grad = 0.0; // gradient
+ double z = 0.0; // spiking variable
+ double z_current = 1.0; // spike state that triggered the current integration
+ double L = 0.0; // error signal
+ double grad = 0.0; // gradient
const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp );
const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_;
@@ -330,9 +366,15 @@ eprop_readout::compute_gradient( const long t_spike,
for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it )
{
- z = z_previous_buffer;
- z_previous_buffer = z_current_buffer;
- z_current_buffer = 0.0;
+ if ( P_.delay_rec_out_ > 1 )
+ {
+ z = z_previous_buffer.front();
+ update_pre_syn_buffer_multiple_entries( z, z_current, z_previous, z_previous_buffer, t_spike, t );
+ }
+ else
+ {
+ update_pre_syn_buffer_one_entry( z, z_current, z_previous, z_previous_buffer, t_spike, t );
+ }
L = eprop_hist_it->error_signal_;
diff --git a/models/eprop_readout.h b/models/eprop_readout.h
index b96c007542..0a4e2af6ac 100644
--- a/models/eprop_readout.h
+++ b/models/eprop_readout.h
@@ -301,6 +301,7 @@ class eprop_readout : public EpropArchivingNodeReadout< false >
void compute_gradient( const long,
const long,
+ std::queue< double >&,
double&,
double&,
double&,
@@ -313,6 +314,7 @@ class eprop_readout : public EpropArchivingNodeReadout< false >
long get_shift() const override;
bool is_eprop_recurrent_node() const override;
long get_eprop_isi_trace_cutoff() const override;
+ long get_delay_total() const override;
//! Map for storing a static set of recordables.
friend class RecordablesMap< eprop_readout >;
@@ -341,6 +343,12 @@ class eprop_readout : public EpropArchivingNodeReadout< false >
//! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms).
double eprop_isi_trace_cutoff_;
+ //! Connection delay from recurrent to readout neuron.
+ long delay_rec_out_;
+
+ //! Connection delay from readout to recurrent neuron.
+ long delay_out_rec_;
+
//! Default constructor.
Parameters_();
@@ -383,6 +391,9 @@ class eprop_readout : public EpropArchivingNodeReadout< false >
//! Set the state variables.
void set( const DictionaryDatum&, const Parameters_&, double, Node* );
+
+ //! Queue to hold last delay_out_rec error signals.
+ std::deque< double > error_signal_deque_;
};
//! Structure of buffers.
@@ -493,14 +504,35 @@ eprop_readout::get_eprop_isi_trace_cutoff() const
return V_.eprop_isi_trace_cutoff_steps_;
}
+inline long
+eprop_readout::get_delay_total() const
+{
+ return P_.delay_rec_out_;
+}
+
inline size_t
-eprop_readout::handles_test_event( SpikeEvent&, size_t receptor_type )
+eprop_readout::handles_test_event( SpikeEvent& e, size_t receptor_type )
{
if ( receptor_type != 0 )
{
throw UnknownReceptorType( receptor_type, get_name() );
}
+ // To perform a consistency check on the delay parameter delay_rec_out between recurrent
+ // and readout neurons, the recurrent neuron sends a test event with a delay specified
+ // by its delay_rec_out. Upon receiving the test event from the recurrent neuron, the
+ // readout neuron checks whether the delay with which the event was received matches its
+ // own delay_rec_out parameter.
+
+ // ensure that the spike event was not sent by a proxy node
+ if ( e.get_sender().get_node_id() != 0 )
+ {
+ if ( e.get_delay_steps() != P_.delay_rec_out_ )
+ {
+ throw IllegalConnection(
+ "delay_rec_out from recurrent to output neuron equal to delay_rec_out from output to recurrent neuron required." );
+ }
+ }
return 0;
}
diff --git a/models/eprop_synapse.h b/models/eprop_synapse.h
index 9b4a5679c2..bab2484232 100644
--- a/models/eprop_synapse.h
+++ b/models/eprop_synapse.h
@@ -271,6 +271,9 @@ class eprop_synapse : public Connection< targetidentifierT >
//! Update values in parameter dictionary.
void set_status( const DictionaryDatum& d, ConnectorModel& cm );
+ //! Initialize the presynaptic buffer.
+ void initialize_z_previous_buffer( const long delay_total );
+
//! Send the spike event.
bool send( Event& e, size_t thread, const EpropSynapseCommonProperties& cp );
@@ -333,7 +336,13 @@ class eprop_synapse : public Connection< targetidentifierT >
double epsilon_ = 0.0;
//! Value of spiking variable one time step before t_previous_spike_.
- double z_previous_buffer_ = 0.0;
+ double z_previous_ = 0.0;
+
+ //! Queue of length delay_total_ to hold previous spiking variables.
+ std::queue< double > z_previous_buffer_;
+
+ //! Sum of connection delays from recurrent to readout neuron and readout to recurrent neuron.
+ long delay_total_ = 0;
/**
* Optimizer
@@ -404,7 +413,7 @@ eprop_synapse< targetidentifierT >::operator=( const eprop_synapse& es )
e_bar_ = es.e_bar_;
e_bar_reg_ = es.e_bar_reg_;
epsilon_ = es.epsilon_;
- z_previous_buffer_ = es.z_previous_buffer_;
+ z_previous_ = es.z_previous_;
optimizer_ = es.optimizer_;
return *this;
@@ -445,7 +454,7 @@ eprop_synapse< targetidentifierT >::operator=( eprop_synapse&& es )
e_bar_ = es.e_bar_;
e_bar_reg_ = es.e_bar_reg_;
epsilon_ = es.epsilon_;
- z_previous_buffer_ = es.z_previous_buffer_;
+ z_previous_ = es.z_previous_;
optimizer_ = es.optimizer_;
@@ -463,11 +472,22 @@ eprop_synapse< targetidentifierT >::check_connection( Node& s,
const CommonPropertiesType& cp )
{
// When we get here, delay has been set so we can check it.
- if ( get_delay_steps() != 1 )
+ if ( get_delay_steps() < 1 )
{
throw IllegalConnection( "eprop synapses currently require a delay of one simulation step" );
}
+ const bool is_recurrent_node = t.is_eprop_recurrent_node();
+
+ if ( not is_recurrent_node )
+ {
+ const long delay_rec_out = t.get_delay_total();
+ if ( delay_rec_out != get_delay_steps() )
+ {
+ throw IllegalConnection( "delay == delay_rec_out of target neuron required." );
+ }
+ }
+
ConnTestDummyNode dummy_target;
ConnectionBase::check_connection_( dummy_target, s, t, receptor_type );
@@ -484,6 +504,17 @@ eprop_synapse< targetidentifierT >::delete_optimizer()
// do not set to nullptr to allow detection of double deletion
}
+template < typename targetidentifierT >
+void
+eprop_synapse< targetidentifierT >::initialize_z_previous_buffer( const long delay_total )
+{
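+ // Fill the queue with delay_total zeros followed by a single 1.0 representing the spike that
+ // triggered the current integration (the role formerly played by z_current_buffer = 1.0).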
+ for ( long i = 0; i < delay_total; i++ )
+ {
+ z_previous_buffer_.push( 0.0 );
+ }
+ z_previous_buffer_.push( 1.0 );
+}
+
template < typename targetidentifierT >
bool
eprop_synapse< targetidentifierT >::send( Event& e, size_t thread, const EpropSynapseCommonProperties& cp )
@@ -492,11 +523,28 @@ eprop_synapse< targetidentifierT >::send( Event& e, size_t thread, const EpropSy
assert( target );
const long t_spike = e.get_stamp().get_steps();
+ const long delay_total = target->get_delay_total();
if ( t_spike_previous_ != 0 )
{
- target->compute_gradient(
- t_spike, t_spike_previous_, z_previous_buffer_, z_bar_, e_bar_, e_bar_reg_, epsilon_, weight_, cp, optimizer_ );
+ target->compute_gradient( t_spike,
+ t_spike_previous_,
+ z_previous_buffer_,
+ z_previous_,
+ z_bar_,
+ e_bar_,
+ e_bar_reg_,
+ epsilon_,
+ weight_,
+ cp,
+ optimizer_ );
+ }
+ else
+ {
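+ // First spike through this synapse: nothing to integrate yet; only set up the presynaptic
+ // spike queue, which is needed when the total delay exceeds one step.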
+ if ( delay_total > 1 )
+ {
+ initialize_z_previous_buffer( delay_total );
+ }
}
const long eprop_isi_trace_cutoff = target->get_eprop_isi_trace_cutoff();
diff --git a/nestkernel/eprop_archiving_node.cpp b/nestkernel/eprop_archiving_node.cpp
new file mode 100644
index 0000000000..324d4407f9
--- /dev/null
+++ b/nestkernel/eprop_archiving_node.cpp
@@ -0,0 +1,308 @@
+/*
+ * eprop_archiving_node.cpp
+ *
+ * This file is part of NEST.
+ *
+ * Copyright (C) 2004 The NEST Initiative
+ *
+ * NEST is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * NEST is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with NEST. If not, see <https://www.gnu.org/licenses/>.
+ *
+ */
+
+// nestkernel
+#include "eprop_archiving_node.h"
+#include "eprop_archiving_node_impl.h"
+#include "kernel_manager.h"
+
+// sli
+#include "dictutils.h"
+
+namespace nest
+{
+
+EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent()
+ : EpropArchivingNode()
+ , firing_rate_reg_( 0.0 )
+ , f_av_( 0.0 )
+ , n_spikes_( 0 )
+{
+}
+
+EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& n )
+ : EpropArchivingNode( n )
+ , firing_rate_reg_( n.firing_rate_reg_ )
+ , f_av_( n.f_av_ )
+ , n_spikes_( n.n_spikes_ )
+{
+}
+
+double
+EpropArchivingNodeRecurrent::compute_piecewise_linear_surrogate_gradient( const double r,
+ const double v_m,
+ const double v_th_adapt,
+ const double V_th,
+ const double beta,
+ const double gamma )
+{
+ if ( r > 0 )
+ {
+ return 0.0;
+ }
+
+ if ( std::fabs( V_th ) < 1e-6 )
+ {
+ throw BadProperty(
+ "Relative threshold voltage V_th-E_L ≠ 0 required if surrogate_gradient_function is \"piecewise_linear\"." );
+ }
+
+ return gamma * std::max( 0.0, 1.0 - beta * std::fabs( ( v_m - v_th_adapt ) / V_th ) ) / V_th;
+}
+
+double
+EpropArchivingNodeRecurrent::compute_exponential_surrogate_gradient( const double r,
+ const double v_m,
+ const double v_th_adapt,
+ const double V_th,
+ const double beta,
+ const double gamma )
+{
+ if ( r > 0 )
+ {
+ return 0.0;
+ }
+
+ return gamma * std::exp( -beta * std::fabs( v_m - v_th_adapt ) );
+}
+
+double
+EpropArchivingNodeRecurrent::compute_fast_sigmoid_derivative_surrogate_gradient( const double r,
+ const double v_m,
+ const double v_th_adapt,
+ const double V_th,
+ const double beta,
+ const double gamma )
+{
+ if ( r > 0 )
+ {
+ return 0.0;
+ }
+
+ return gamma * std::pow( 1.0 + beta * std::fabs( v_m - v_th_adapt ), -2 );
+}
+
+double
+EpropArchivingNodeRecurrent::compute_arctan_surrogate_gradient( const double r,
+ const double v_m,
+ const double v_th_adapt,
+ const double V_th,
+ const double beta,
+ const double gamma )
+{
+ if ( r > 0 )
+ {
+ return 0.0;
+ }
+
+ return gamma / M_PI * ( 1.0 / ( 1.0 + std::pow( beta * M_PI * ( v_m - v_th_adapt ), 2 ) ) );
+}
+
+void
+EpropArchivingNodeRecurrent::emplace_new_eprop_history_entry( const long time_step )
+{
+ if ( eprop_indegree_ == 0 )
+ {
+ return;
+ }
+
+ eprop_history_.emplace_back( time_step, 0.0, 0.0 );
+}
+
+void
+EpropArchivingNodeRecurrent::write_surrogate_gradient_to_history( const long time_step,
+ const double surrogate_gradient )
+{
+ if ( eprop_indegree_ == 0 )
+ {
+ return;
+ }
+
+ auto it_hist = get_eprop_history( time_step );
+ it_hist->surrogate_gradient_ = surrogate_gradient;
+}
+
+void
+EpropArchivingNodeRecurrent::write_learning_signal_to_history( const long time_step,
+ const double learning_signal,
+ const bool has_norm_step )
+{
+ if ( eprop_indegree_ == 0 )
+ {
+ return;
+ }
+
+ long shift = delay_out_rec_;
+
+ if ( has_norm_step )
+ {
+ shift += delay_rec_out_ + delay_out_norm_;
+ }
+ else
+ {
+ shift += get_delay_total();
+ }
+
+ auto it_hist = get_eprop_history( time_step - shift );
+ const auto it_hist_end = get_eprop_history( time_step - shift + delay_out_rec_ );
+
+ for ( ; it_hist != it_hist_end; ++it_hist )
+ {
+ it_hist->learning_signal_ += learning_signal;
+ }
+}
+
+void
+EpropArchivingNodeRecurrent::write_firing_rate_reg_to_history( const long t_current_update,
+ const double f_target,
+ const double c_reg )
+{
+ if ( eprop_indegree_ == 0 )
+ {
+ return;
+ }
+
+ const double update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps();
+ const double dt = Time::get_resolution().get_ms();
+ const long shift = Time::get_resolution().get_steps();
+
+ const double f_av = n_spikes_ / update_interval;
+ const double f_target_ = f_target * dt; // convert from spikes/ms to spikes/step
+ const double firing_rate_reg = c_reg * ( f_av - f_target_ ) / update_interval;
+
+ firing_rate_reg_history_.emplace_back( t_current_update + shift, firing_rate_reg );
+}
+
+void
+EpropArchivingNodeRecurrent::write_firing_rate_reg_to_history( const long t,
+ const double z,
+ const double f_target,
+ const double kappa,
+ const double c_reg )
+{
+ if ( eprop_indegree_ == 0 )
+ {
+ return;
+ }
+
+ const double dt = Time::get_resolution().get_ms();
+
+ const double f_target_ = f_target * dt; // convert from spikes/ms to spikes/step
+
+ f_av_ = kappa * f_av_ + ( 1.0 - kappa ) * z / dt;
+
+ firing_rate_reg_ = c_reg * ( f_av_ - f_target_ );
+
+ auto it_hist = get_eprop_history( t );
+ it_hist->learning_signal_ += firing_rate_reg_;
+}
+
+std::vector< HistEntryEpropFiringRateReg >::iterator
+EpropArchivingNodeRecurrent::get_firing_rate_reg_history( const long time_step )
+{
+ const auto it_hist = std::lower_bound( firing_rate_reg_history_.begin(), firing_rate_reg_history_.end(), time_step );
+ assert( it_hist != firing_rate_reg_history_.end() );
+
+ return it_hist;
+}
+
+double
+EpropArchivingNodeRecurrent::get_learning_signal_from_history( const long time_step, const bool has_norm_step )
+{
+ long shift = delay_rec_out_ + delay_out_rec_;
+
+ if ( has_norm_step )
+ {
+ shift += delay_out_norm_;
+ }
+
+ const auto it = get_eprop_history( time_step - shift );
+ if ( it == eprop_history_.end() )
+ {
+ return 0;
+ }
+
+ return it->learning_signal_;
+}
+
+void
+EpropArchivingNodeRecurrent::erase_used_firing_rate_reg_history()
+{
+ auto it_update_hist = update_history_.begin();
+ auto it_reg_hist = firing_rate_reg_history_.begin();
+
+ while ( it_update_hist != update_history_.end() and it_reg_hist != firing_rate_reg_history_.end() )
+ {
+ if ( it_update_hist->access_counter_ == 0 )
+ {
+ it_reg_hist = firing_rate_reg_history_.erase( it_reg_hist );
+ }
+ else
+ {
+ ++it_reg_hist;
+ }
+ ++it_update_hist;
+ }
+}
+
+EpropArchivingNodeReadout::EpropArchivingNodeReadout()
+ : EpropArchivingNode()
+{
+}
+
+EpropArchivingNodeReadout::EpropArchivingNodeReadout( const EpropArchivingNodeReadout& n )
+ : EpropArchivingNode( n )
+{
+}
+
+void
+EpropArchivingNodeReadout::emplace_new_eprop_history_entry( const long time_step, const bool has_norm_step )
+{
+ if ( eprop_indegree_ == 0 )
+ {
+ return;
+ }
+
+ const long shift = has_norm_step ? delay_out_norm_ : 0;
+
+ eprop_history_.emplace_back( time_step - shift, 0.0 );
+}
+
+void
+EpropArchivingNodeReadout::write_error_signal_to_history( const long time_step,
+ const double error_signal,
+ const bool has_norm_step )
+{
+ if ( eprop_indegree_ == 0 )
+ {
+ return;
+ }
+
+ const long shift = has_norm_step ? delay_out_norm_ : 0;
+
+ auto it_hist = get_eprop_history( time_step - shift );
+ it_hist->error_signal_ = error_signal;
+}
+
+
+} // namespace nest
diff --git a/nestkernel/eprop_archiving_node.h b/nestkernel/eprop_archiving_node.h
index 04cfc2d3ba..a213fb5af3 100644
--- a/nestkernel/eprop_archiving_node.h
+++ b/nestkernel/eprop_archiving_node.h
@@ -111,8 +111,33 @@ class EpropArchivingNode : public Node
*
* Retrieves the size of the eprop history buffer.
*/
double get_eprop_history_duration() const;
+
+ /**
+ * @brief Updates multiple entries in the presynaptic buffer.
+ *
+ * Used when the total synaptic delay is greater than one.
+ */
+ void update_pre_syn_buffer_multiple_entries( double& z,
+ double& z_current,
+ double& z_previous,
+ std::queue< double >& z_previous_buffer,
+ double t_spike,
+ double t );
+
+ /**
+ * @brief Updates one entry in the presynaptic buffer.
+ *
+ * Used when the total synaptic delay equals one.
+ */
+ void update_pre_syn_buffer_one_entry( double& z,
+ double& z_current,
+ double& z_previous,
+ std::queue< double >& z_previous_buffer,
+ double t_spike,
+ double t );
+
protected:
//! Returns correct shift for history depending on whether it is a normal or a bsshslm_2020 model.
virtual long model_dependent_history_shift_() const = 0;
diff --git a/nestkernel/eprop_archiving_node_impl.h b/nestkernel/eprop_archiving_node_impl.h
index 7c4c60a52e..697d496aa8 100644
--- a/nestkernel/eprop_archiving_node_impl.h
+++ b/nestkernel/eprop_archiving_node_impl.h
@@ -189,6 +189,47 @@ EpropArchivingNode< HistEntryT >::get_eprop_history_duration() const
return Time::get_resolution().get_ms() * eprop_history_.size();
}
+
+template < typename HistEntryT >
+void
+EpropArchivingNode< HistEntryT >::update_pre_syn_buffer_multiple_entries( double& z,
+ double& z_current,
+ double& z_previous,
+ std::queue< double >& z_previous_buffer,
+ double t_spike,
+ double t )
+{
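+ // Pop the oldest presynaptic spike state from the queue and push the state for the current step:
+ // 1.0 only in the step directly preceding the current spike, 0.0 otherwise.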
+ if ( not z_previous_buffer.empty() )
+ {
+ z = z_previous_buffer.front();
+ z_previous_buffer.pop();
+ }
+
+ if ( t_spike - t > 1 )
+ {
+ z_previous_buffer.push( 0.0 );
+ }
+ else
+ {
+ z_previous_buffer.push( 1.0 );
+ }
+}
+
+template < typename HistEntryT >
+void
+EpropArchivingNode< HistEntryT >::update_pre_syn_buffer_one_entry( double& z,
+ double& z_current,
+ double& z_previous,
+ std::queue< double >& pre_syn_buffer,
+ double t_spike,
+ double t )
+{
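+ // Two-step shift register (the former inline buffer rotation): z takes the value of z_previous,
+ // z_previous takes the value of z_current, and z_current is cleared.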
+ z = z_previous;
+ z_previous = z_current;
+ z_current = 0.0;
+}
+
+
} // namespace nest
#endif // EPROP_ARCHIVING_NODE_IMPL_H
diff --git a/nestkernel/eprop_archiving_node_readout.h b/nestkernel/eprop_archiving_node_readout.h
index 97fd67ca27..ec20a52433 100644
--- a/nestkernel/eprop_archiving_node_readout.h
+++ b/nestkernel/eprop_archiving_node_readout.h
@@ -129,7 +129,7 @@ EpropArchivingNodeReadout< hist_shift_required >::model_dependent_history_shift_
}
else
{
- return -delay_rec_out_;
+ return -get_delay_total();
}
}
diff --git a/nestkernel/eprop_archiving_node_recurrent_impl.h b/nestkernel/eprop_archiving_node_recurrent_impl.h
index 58dde7a87e..12a4ebba4a 100644
--- a/nestkernel/eprop_archiving_node_recurrent_impl.h
+++ b/nestkernel/eprop_archiving_node_recurrent_impl.h
@@ -184,11 +184,15 @@ EpropArchivingNodeRecurrent< hist_shift_required >::write_learning_signal_to_his
return;
}
- long shift = delay_rec_out_ + delay_out_rec_;
+ long shift = delay_out_rec_;
if constexpr ( hist_shift_required )
{
- shift += delay_out_norm_;
+ shift += delay_rec_out_ + delay_out_norm_;
+ }
+ else
+ {
+ shift += get_delay_total();
}
diff --git a/nestkernel/nest_names.cpp b/nestkernel/nest_names.cpp
index 2b1d98435e..5d887da3e8 100644
--- a/nestkernel/nest_names.cpp
+++ b/nestkernel/nest_names.cpp
@@ -127,6 +127,8 @@ const Name dead_time( "dead_time" );
const Name dead_time_random( "dead_time_random" );
const Name dead_time_shape( "dead_time_shape" );
const Name delay( "delay" );
+const Name delay_out_rec( "delay_out_rec" );
+const Name delay_rec_out( "delay_rec_out" );
const Name delay_u_bars( "delay_u_bars" );
const Name deliver_interval( "deliver_interval" );
const Name delta( "delta" );
diff --git a/nestkernel/nest_names.h b/nestkernel/nest_names.h
index 44a75f62d0..3877274a33 100644
--- a/nestkernel/nest_names.h
+++ b/nestkernel/nest_names.h
@@ -154,6 +154,8 @@ extern const Name dead_time;
extern const Name dead_time_random;
extern const Name dead_time_shape;
extern const Name delay;
+extern const Name delay_out_rec;
+extern const Name delay_rec_out;
extern const Name delay_u_bars;
extern const Name deliver_interval;
extern const Name delta;
diff --git a/nestkernel/node.cpp b/nestkernel/node.cpp
index 6f54cc1075..ce4781dc79 100644
--- a/nestkernel/node.cpp
+++ b/nestkernel/node.cpp
@@ -230,6 +230,24 @@ Node::get_shift() const
throw IllegalConnection( "The target node is not an e-prop neuron." );
}
+long
+Node::get_delay_total() const
+{
+ throw IllegalConnection( "The target node is not an e-prop neuron." );
+}
+
+long
+Node::get_delay_recurrent_to_readout() const
+{
+ throw IllegalConnection( "The target node is not an e-prop neuron." );
+}
+
+long
+Node::get_delay_readout_to_recurrent() const
+{
+ throw IllegalConnection( "The target node is not an e-prop neuron." );
+}
+
void
Node::write_update_to_history( const long, const long, const long )
{
@@ -552,6 +570,7 @@ nest::Node::get_tau_syn_in( int )
void
nest::Node::compute_gradient( const long,
const long,
+ std::queue< double >&,
double&,
double&,
double&,
diff --git a/nestkernel/node.h b/nestkernel/node.h
index 5c41113f43..c10bc35c7a 100644
--- a/nestkernel/node.h
+++ b/nestkernel/node.h
@@ -26,6 +26,7 @@
// C++ includes:
#include
#include
+#include <queue>
#include
#include
#include
@@ -532,6 +533,30 @@ class Node
*/
virtual long get_eprop_isi_trace_cutoff() const;
+ /**
+ * Get sum of connection delays from recurrent to output neuron and output to recurrent neuron.
+ *
+ * @throws IllegalConnection
+ */
+ virtual long get_delay_total() const;
+
+ /**
+ * Get connection delay from recurrent to output neuron.
+ *
+ * @throws IllegalConnection
+ */
+ virtual long get_delay_recurrent_to_readout() const;
+
+ /**
+ * Get connection delay from output to recurrent neuron.
+ *
+ * @throws IllegalConnection
+ */
+ virtual long get_delay_readout_to_recurrent() const;
+
/**
* Checks if the node is part of the recurrent network and thus not a readout neuron.
*
@@ -836,6 +861,7 @@ class Node
* @param t_spike [in] Time of the current spike.
* @param t_spike_previous [in] Time of the previous spike.
- * @param z_previous_buffer [in, out] Value of presynaptic spiking variable from previous time step.
+ * @param z_previous_buffer [in, out] Queue of presynaptic spiking variables covering the total connection delay.
+ * @param z_previous [in, out] Value of the presynaptic spiking variable one time step before the previous spike.
* @param z_bar [in, out] Filtered presynaptic spiking variable.
* @param e_bar [in, out] Filtered eligibility trace.
* @param e_bar_reg [in, out] Filtered eligibility trace for firing rate regularization.
@@ -847,7 +873,8 @@ class Node
*/
virtual void compute_gradient( const long t_spike,
const long t_spike_previous,
- double& z_previous_buffer,
+ std::queue< double >& z_previous_buffer,
+ double& z_previous,
double& z_bar,
double& e_bar,
double& e_bar_reg,
diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py
new file mode 100644
index 0000000000..8232b4d97f
--- /dev/null
+++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py
@@ -0,0 +1,852 @@
+# -*- coding: utf-8 -*-
+#
+# eprop_supervised_classification_evidence-accumulation.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see <https://www.gnu.org/licenses/>.
+
+r"""
+Tutorial on learning to accumulate evidence with e-prop
+-------------------------------------------------------
+
+Training a classification model using supervised e-prop plasticity to accumulate evidence.
+
+Description
+~~~~~~~~~~~
+
+This script demonstrates supervised learning of a classification task with the eligibility propagation (e-prop)
+plasticity mechanism by Bellec et al. [1]_.
+
+This type of learning is demonstrated with the proof-of-concept task in [1]_. We based this script on their
+TensorFlow script given in [2]_.
+
+The task, a so-called evidence accumulation task, is inspired by behavioral experiments in which a lab animal
+(e.g., a mouse) runs along a track, receives cues on the left and right, and at the end of the track has to
+decide between taking a left or a right turn, of which only one is correct. After a number of iterations, the
+animal is able to infer the underlying rationale of the task. Here, the solution is to turn to the side on
+which more cues were presented.
+
+.. image:: ../../../../pynest/examples/eprop_plasticity/eprop_supervised_classification_schematic_evidence-accumulation.png
+ :width: 70 %
+ :alt: Schematic of network architecture. Same as Figure 1 in the code.
+ :align: center
+
+Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity.
+This plasticity rule requires a specific network architecture depicted in Figure 1. The neural network model
+consists of a recurrent network that receives input from Poisson generators and projects onto two readout
+neurons - one for the left and one for the right turn at the end. The input neuron population consists of four
+groups: one group providing background noise of a specific rate for some base activity throughout the
+experiment, one group providing the input spikes of the left cues and one group providing them for the right
+cues, and a last group defining the recall window, in which the network has to decide. The readout neuron
+compares the network signal :math:`\pi_k` with the teacher target signal :math:`\pi_k^*`, which it receives from
+a rate generator. Since the decision is at the end and all the cues are relevant, the network has to keep the
+cues in memory. Additional adaptive neurons in the network enable this memory. The network's training error is
+assessed by employing a mean squared error loss.
+
+Details on the event-based NEST implementation of e-prop can be found in [3]_.
+
+References
+~~~~~~~~~~
+
+.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, Maass W (2020). A solution to the
+ learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625.
+ https://doi.org/10.1038/s41467-020-17236-y
+
+.. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_evidence_accumulation_with_alif.py
+
+.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, van Albada SJ, Plesser HE, Bolten M, Diesmann M.
+ Event-based implementation of eligibility propagation (in preparation)
+""" # pylint: disable=line-too-long # noqa: E501
+
+# %% ###########################################################################################################
+# Import libraries
+# ~~~~~~~~~~~~~~~~
+# We begin by importing all libraries required for the simulation, analysis, and visualization.
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import nest
+import numpy as np
+from cycler import cycler
+from IPython.display import Image
+
+# %% ###########################################################################################################
+# Schematic of network architecture
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# This figure, identical to the one in the description, shows the required network architecture in the center,
+# the input and output of the evidence accumulation task above, and lists of the required NEST device, neuron,
+# and synapse models below. The connections that must be established are numbered 1 to 7.
+
+try:
+ Image(filename="./eprop_supervised_classification_schematic_evidence-accumulation.png")
+except Exception:
+ pass
+
+# %% ###########################################################################################################
+# Setup
+# ~~~~~
+
+# %% ###########################################################################################################
+# Initialize random generator
+# ...........................
+# We seed the numpy random generator, which will generate random initial weights as well as random input and
+# output.
+
+rng_seed = 1 # numpy random seed
+np.random.seed(rng_seed) # fix numpy random seed
+
+# %% ###########################################################################################################
+# Define timing of task
+# .....................
+# The task's temporal structure is then defined, once as time steps and once as durations in milliseconds.
+# Even though each sample is processed independently during training, we aggregate predictions and true
+# labels across a group of samples during the evaluation phase. The number of samples in this group is
+# determined by the `group_size` parameter. This data is then used to assess the neural network's
+# performance metrics, such as average accuracy and mean error. Increasing the number of iterations enhances
+# learning performance up to the point where overfitting occurs.
+
+group_size = 32 # number of instances over which to evaluate the learning performance
+n_iter = 50 # number of iterations
+
+n_input_symbols = 4 # number of input populations, e.g. 4 = left, right, recall, noise
+n_cues = 7 # number of cues given before decision
+prob_group = 0.3 # probability with which one input group is present
+
+steps = {
+ "cue": 100, # time steps in one cue presentation
+ "spacing": 50, # time steps of break between two cues
+ "bg_noise": 1050, # time steps of background noise
+ "recall": 150, # time steps of recall
+ "delay_rec_out": 1, # time steps of connection delay from recurrent to output neurons
+ "delay_out_rec": 1, # time steps of broadcast delay of learning signals
+}
+
+steps["cues"] = n_cues * (steps["cue"] + steps["spacing"]) # time steps of all cues
+steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] # time steps of one full sequence
+steps["learning_window"] = steps["recall"] # time steps of window with non-zero learning signals
+steps["task"] = n_iter * group_size * steps["sequence"] # time steps of task
+
+steps.update(
+ {
+ "offset_gen": 1, # offset since generator signals start from time step 1
+ "delay_in_rec": 1, # connection delay between input and recurrent neurons
+ "extension_sim": 3, # extra time step to close right-open simulation time interval in Simulate()
+ }
+)
+
+steps["delays"] = steps["delay_in_rec"] # time steps of delays
+
+steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset
+
+steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] # time steps of simulation
+
+duration = {"step": 1.0} # ms, temporal resolution of the simulation
+
+duration.update({key: value * duration["step"] for key, value in steps.items()}) # ms, durations
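+
+# With the values above, one sequence spans 7 * (100 + 50) + 1050 + 150 = 2250 time steps, i.e. 2250 ms at the
+# 1 ms resolution, and the task comprises n_iter * group_size = 50 * 32 = 1600 such sequences.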
+
+# %% ###########################################################################################################
+# Set up simulation
+# .................
+# As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and
+# objects and set some NEST kernel parameters.
+
+params_setup = {
+ "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell
+ "resolution": duration["step"],
+ "total_num_virtual_procs": 1, # number of virtual processes, set in case of distributed computing
+}
+
+####################
+
+nest.ResetKernel()
+nest.set(**params_setup)
+
+# %% ###########################################################################################################
+# Create neurons
+# ~~~~~~~~~~~~~~
+# We proceed by creating a certain number of input, recurrent, and readout neurons and setting their parameters.
+# Additionally, we already create an input spike generator and an output target rate generator, which we will
+# configure later. Within the recurrent network, alongside a population of regular neurons, we introduce a
+# population of adaptive neurons to enhance the network's memory retention.
+
+n_in = 40 # number of input neurons
+n_ad = 50 # number of adaptive neurons
+n_reg = 50 # number of regular neurons
+n_rec = n_ad + n_reg # number of recurrent neurons
+n_out = 2 # number of readout neurons
+
+params_nrn_out = {
+ "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case)
+ "E_L": 0.0, # mV, leak / resting membrane potential
+ "eprop_isi_trace_cutoff": 100, # cutoff of integration of eprop trace between spikes
+ "I_e": 0.0, # pA, external current input
+ "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning
+ "tau_m": 20.0, # ms, membrane time constant
+ "V_m": 0.0, # mV, initial value of the membrane voltage
+ "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals
+ "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons
+}
+
+params_nrn_reg = {
+ "beta": 1.0, # width scaling of the pseudo-derivative
+ "C_m": 1.0,
+ "c_reg": 300.0 / duration["sequence"] * duration["learning_window"], # firing rate regularization scaling
+ "E_L": 0.0,
+ "eprop_isi_trace_cutoff": 100,
+ "f_target": 10.0, # spikes/s, target firing rate for firing rate regularization
+ "gamma": 0.3, # height scaling of the pseudo-derivative
+ "I_e": 0.0,
+ "regular_spike_arrival": True,
+ "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function
+ "t_ref": 5.0, # ms, duration of refractory period
+ "tau_m": 20.0,
+ "V_m": 0.0,
+ "V_th": 0.6, # mV, spike threshold membrane voltage
+ "kappa": 0.97, # low-pass filter of the eligibility trace
+ "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals
+ "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons
+}
+
+params_nrn_ad = {
+ "beta": 1.0,
+ "adapt_tau": 2000.0, # ms, time constant of adaptive threshold
+ "adaptation": 0.0, # initial value of the spike threshold adaptation
+ "C_m": 1.0,
+ "c_reg": 300.0 / duration["sequence"] * duration["learning_window"],
+ "E_L": 0.0,
+ "eprop_isi_trace_cutoff": 100, # cutoff of integration of eprop trace between spikes
+ "f_target": 10.0,
+ "gamma": 0.3,
+ "I_e": 0.0,
+ "regular_spike_arrival": True,
+ "surrogate_gradient_function": "piecewise_linear",
+ "t_ref": 5.0,
+ "tau_m": 20.0,
+ "V_m": 0.0,
+ "V_th": 0.6,
+ "kappa": 0.97,
+ "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals
+ "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons
+}
+
+params_nrn_ad["adapt_beta"] = 1.7 * (
+ (1.0 - np.exp(-duration["step"] / params_nrn_ad["adapt_tau"]))
+ / (1.0 - np.exp(-duration["step"] / params_nrn_ad["tau_m"]))
+) # prefactor of adaptive threshold
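+# With duration["step"] = 1 ms, adapt_tau = 2000 ms, and tau_m = 20 ms, this prefactor evaluates to roughly
+# 1.7 * 0.0005 / 0.0488, i.e. adapt_beta is approximately 0.0174.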
+
+####################
+
+# Intermediate parrot neurons required between input spike generators and recurrent neurons,
+# since devices cannot establish plastic synapses for technical reasons
+
+gen_spk_in = nest.Create("spike_generator", n_in)
+nrns_in = nest.Create("parrot_neuron", n_in)
+
+nrns_reg = nest.Create("eprop_iaf", n_reg, params_nrn_reg)
+nrns_ad = nest.Create("eprop_iaf_adapt", n_ad, params_nrn_ad)
+nrns_out = nest.Create("eprop_readout", n_out, params_nrn_out)
+gen_rate_target = nest.Create("step_rate_generator", n_out)
+gen_learning_window = nest.Create("step_rate_generator")
+
+nrns_rec = nrns_reg + nrns_ad
+
+# %% ###########################################################################################################
+# Create recorders
+# ~~~~~~~~~~~~~~~~
+# We also create recorders, which, while not required for the training, will allow us to track various dynamic
+# variables of the neurons, spikes, and changes in synaptic weights. To save computing time and memory, the
+# recorders, the recorded variables, neurons, and synapses can be limited to the ones relevant to the
+# experiment, and the recording interval can be increased (see the documentation on the specific recorders). By
+# default, recordings are stored in memory but can also be written to file.
+
+n_record = 1 # number of neurons per type to record dynamic variables from - this script requires n_record >= 1
+n_record_w = 5 # number of senders and targets to record weights from - this script requires n_record_w >=1
+
+if n_record == 0 or n_record_w == 0:
+ raise ValueError("n_record and n_record_w >= 1 required")
+
+params_mm_reg = {
+ "interval": duration["step"], # interval between two recorded time points
+ "record_from": ["V_m", "surrogate_gradient", "learning_signal"], # dynamic variables to record
+ "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording
+ "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording
+ "label": "multimeter_reg",
+}
+
+params_mm_ad = {
+ "interval": duration["step"],
+ "record_from": params_mm_reg["record_from"] + ["V_th_adapt", "adaptation"],
+ "start": duration["offset_gen"] + duration["delay_in_rec"],
+ "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"],
+ "label": "multimeter_ad",
+}
+
+params_mm_out = {
+ "interval": duration["step"],
+ "record_from": ["V_m", "readout_signal", "target_signal", "error_signal"],
+ "start": duration["total_offset"],
+ "stop": duration["total_offset"] + duration["task"],
+ "label": "multimeter_out",
+}
+
+params_wr = {
+ "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], # limit senders to subsample weights to record
+ "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from
+ "start": duration["total_offset"],
+ "stop": duration["total_offset"] + duration["task"],
+ "label": "weight_recorder",
+}
+
+params_sr_in = {
+ "start": duration["offset_gen"],
+ "stop": duration["total_offset"] + duration["task"],
+ "label": "spike_recorder_in",
+}
+
+params_sr_reg = {
+ "start": duration["offset_gen"],
+ "stop": duration["total_offset"] + duration["task"],
+ "label": "spike_recorder_reg",
+}
+
+params_sr_ad = {
+ "start": duration["offset_gen"],
+ "stop": duration["total_offset"] + duration["task"],
+ "label": "spike_recorder_ad",
+}
+
+####################
+
+mm_reg = nest.Create("multimeter", params_mm_reg)
+mm_ad = nest.Create("multimeter", params_mm_ad)
+mm_out = nest.Create("multimeter", params_mm_out)
+sr_in = nest.Create("spike_recorder", params_sr_in)
+sr_reg = nest.Create("spike_recorder", params_sr_reg)
+sr_ad = nest.Create("spike_recorder", params_sr_ad)
+wr = nest.Create("weight_recorder", params_wr)
+
+nrns_reg_record = nrns_reg[:n_record]
+nrns_ad_record = nrns_ad[:n_record]
+
+# %% ###########################################################################################################
+# Create connections
+# ~~~~~~~~~~~~~~~~~~
+# Now, we define the connectivity and set up the synaptic parameters, with the synaptic weights drawn from
+# normal distributions, except for the recurrent-to-readout weights, which follow a Glorot uniform
+# distribution. After these preparations, we establish the enumerated connections of the core network,
+# as well as additional connections to the recorders.
+
+params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False}
+params_conn_one_to_one = {"rule": "one_to_one"}
+
+
+def calculate_glorot_dist(fan_in, fan_out):
+ glorot_scale = 1.0 / max(1.0, (fan_in + fan_out) / 2.0)
+ glorot_limit = np.sqrt(3.0 * glorot_scale)
+ glorot_distribution = np.random.uniform(low=-glorot_limit, high=glorot_limit, size=(fan_in, fan_out))
+ return glorot_distribution
+
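+# For the recurrent-to-readout weights below (fan_in = n_rec = 100, fan_out = n_out = 2), the Glorot limit
+# evaluates to sqrt(3 / 51) ≈ 0.24, so these weights are drawn uniformly from approximately [-0.24, 0.24].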
+
+dtype_weights = np.float32 # data type of weights - for reproducing TF results set to np.float32
+weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights)
+weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights)
+np.fill_diagonal(weights_rec_rec, 0.0) # since no autapses set corresponding weights to zero
+weights_rec_out = np.array(calculate_glorot_dist(n_rec, n_out).T, dtype=dtype_weights)
+weights_out_rec = np.array(np.random.randn(n_rec, n_out), dtype=dtype_weights)
+
+params_common_syn_eprop = {
+ "optimizer": {
+ "type": "adam", # algorithm to optimize the weights
+ "batch_size": 1,
+ "beta_1": 0.9, # exponential decay rate for 1st moment estimate of Adam optimizer
+ "beta_2": 0.999, # exponential decay rate for 2nd moment raw estimate of Adam optimizer
+ "epsilon": 1e-8, # small numerical stabilization constant of Adam optimizer
+ "eta": 5e-3 / duration["learning_window"], # learning rate
+ "optimize_each_step": True, # call optimizer every time step (True) or once per spike (False); only
+ # True implements original Adam algorithm, False offers speed-up; choice can affect learning performance
+ "Wmin": -100.0, # pA, minimal limit of the synaptic weights
+ "Wmax": 100.0, # pA, maximal limit of the synaptic weights
+ },
+ "weight_recorder": wr,
+}
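+
+# With duration["learning_window"] = 150 ms, the effective learning rate is eta = 5e-3 / 150, i.e. roughly 3.3e-5.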
+
+params_syn_base = {
+ "synapse_model": "eprop_synapse",
+ "delay": duration["step"], # ms, dendritic delay
+}
+
+params_syn_in = params_syn_base.copy()
+params_syn_in["weight"] = weights_in_rec # pA, initial values for the synaptic weights
+
+params_syn_rec = params_syn_base.copy()
+params_syn_rec["weight"] = weights_rec_rec
+
+params_syn_out = params_syn_base.copy()
+params_syn_out["weight"] = weights_rec_out
+params_syn_out["delay"] = duration["delay_rec_out"]
+
+params_syn_feedback = {
+ "synapse_model": "eprop_learning_signal_connection",
+ "delay": duration["delay_out_rec"],
+ "weight": weights_out_rec,
+}
+
+params_syn_learning_window = {
+ "synapse_model": "rate_connection_delayed",
+ "delay": duration["step"],
+ "receptor_type": 1, # receptor type over which readout neuron receives learning window signal
+}
+
+params_syn_rate_target = {
+ "synapse_model": "rate_connection_delayed",
+ "delay": duration["step"],
+ "receptor_type": 2, # receptor type over which readout neuron receives target signal
+}
+
+params_syn_static = {
+ "synapse_model": "static_synapse",
+ "delay": duration["step"],
+}
+
+params_init_optimizer = {
+ "optimizer": {
+ "m": 0.0, # initial 1st moment estimate m of Adam optimizer
+ "v": 0.0, # initial 2nd moment raw estimate v of Adam optimizer
+ }
+}
+
+####################
+
+nest.SetDefaults("eprop_synapse", params_common_syn_eprop)
+
+nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1
+nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) # connection 2
+nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) # connection 3
+nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) # connection 4
+nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5
+nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6
+nest.Connect(gen_learning_window, nrns_out, params_conn_all_to_all, params_syn_learning_window) # connection 7
+
+nest.Connect(nrns_in, sr_in, params_conn_all_to_all, params_syn_static)
+nest.Connect(nrns_reg, sr_reg, params_conn_all_to_all, params_syn_static)
+nest.Connect(nrns_ad, sr_ad, params_conn_all_to_all, params_syn_static)
+
+nest.Connect(mm_reg, nrns_reg_record, params_conn_all_to_all, params_syn_static)
+nest.Connect(mm_ad, nrns_ad_record, params_conn_all_to_all, params_syn_static)
+nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static)
+
+# After creating the connections, we can individually initialize the optimizer's
+# dynamic variables for single synapses (here exemplarily for two connections).
+
+nest.GetConnections(nrns_rec[0], nrns_rec[1:3]).set([params_init_optimizer] * 2)
+
+# %% ###########################################################################################################
+# Create input and output
+# ~~~~~~~~~~~~~~~~~~~~~~~
+# We generate the input as four neuron populations: two producing the left and right cues, respectively, one
+# producing the recall signal, and one providing background input throughout the task. The sequence of cues is
+# drawn with a probability that favors one side. For each sequence, the favored side - the solution or target -
+# is assigned randomly to the left or right.
+
+
+def generate_evidence_accumulation_input_output(
+ batch_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps
+):
+ n_pop_nrn = n_in // n_input_symbols
+
+ prob_choices = np.array([prob_group, 1 - prob_group], dtype=np.float32)
+ idx = np.random.choice([0, 1], batch_size)
+ probs = np.zeros((batch_size, 2), dtype=np.float32)
+ probs[:, 0] = prob_choices[idx]
+ probs[:, 1] = prob_choices[1 - idx]
+
+ batched_cues = np.zeros((batch_size, n_cues), dtype=int)
+ for b_idx in range(batch_size):
+ batched_cues[b_idx, :] = np.random.choice([0, 1], n_cues, p=probs[b_idx])
+
+ input_spike_probs = np.zeros((batch_size, steps["sequence"], n_in))
+
+ for b_idx in range(batch_size):
+ for c_idx in range(n_cues):
+ cue = batched_cues[b_idx, c_idx]
+
+ step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"]
+ step_stop = step_start + steps["cue"]
+
+ pop_nrn_start = cue * n_pop_nrn
+ pop_nrn_stop = pop_nrn_start + n_pop_nrn
+
+ input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input_spike_prob
+
+ input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input_spike_prob
+ input_spike_probs[:, :, 3 * n_pop_nrn :] = input_spike_prob / 4.0
+ input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape)
+ input_spike_bools[:, 0, :] = 0 # remove spikes in 0th time step of every sequence for technical reasons
+
+ target_cues = np.zeros(batch_size, dtype=int)
+ target_cues[:] = np.sum(batched_cues, axis=1) > int(n_cues / 2)
+
+ return input_spike_bools, target_cues
+
+
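+# As a quick sanity check (a sketch, not executed here; the argument values mirror those used below), the
+# generator returns one boolean spike array and one target label per sequence in the group:
+#
+#     spikes, targets = generate_evidence_accumulation_input_output(
+#         group_size, n_in, prob_group, 0.04, n_cues, n_input_symbols, steps
+#     )
+#     assert spikes.shape == (group_size, steps["sequence"], n_in)
+#     assert targets.shape == (group_size,)
+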
+input_spike_prob = 0.04 # spike probability of frozen input noise
+dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32
+
+input_spike_bools_list = []
+target_cues_list = []
+
+for _ in range(n_iter):
+ input_spike_bools, target_cues = generate_evidence_accumulation_input_output(
+ group_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps
+ )
+ input_spike_bools_list.append(input_spike_bools)
+ target_cues_list.extend(target_cues)
+
+input_spike_bools_arr = np.array(input_spike_bools_list).reshape(steps["task"], n_in)
+timeline_task = np.arange(0.0, duration["task"], duration["step"]) + duration["offset_gen"]
+
+params_gen_spk_in = [
+ {"spike_times": timeline_task[input_spike_bools_arr[:, nrn_in_idx]].astype(dtype_in_spks)}
+ for nrn_in_idx in range(n_in)
+]
+
+target_rate_changes = np.zeros((n_out, group_size * n_iter))
+target_rate_changes[np.array(target_cues_list), np.arange(group_size * n_iter)] = 1
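+# target_rate_changes is a one-hot matrix of shape (n_out, group_size * n_iter): entry [k, s] equals 1 if
+# readout neuron k encodes the correct side for sequence s, and 0 otherwise.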
+
+params_gen_rate_target = [
+ {
+ "amplitude_times": np.arange(0.0, duration["task"], duration["sequence"]) + duration["total_offset"],
+ "amplitude_values": target_rate_changes[nrn_out_idx],
+ }
+ for nrn_out_idx in range(n_out)
+]
+
+####################
+
+nest.SetStatus(gen_spk_in, params_gen_spk_in)
+nest.SetStatus(gen_rate_target, params_gen_rate_target)
+
+# %% ###########################################################################################################
+# Create learning window
+# ~~~~~~~~~~~~~~~~~~~~~~
+# Custom learning windows, in which the network learns, can be defined with an additional signal. The error
+# signal is internally multiplied by this learning window signal. Passing a learning window signal of value 1
+# opens the learning window, while passing a value of 0 closes it.
+
+amplitude_times = np.hstack(
+ [
+ np.array([0.0, duration["sequence"] - duration["learning_window"]])
+ + duration["total_offset"]
+ + i * duration["sequence"]
+ for i in range(group_size * n_iter)
+ ]
+)
+
+amplitude_values = np.array([0.0, 1.0] * group_size * n_iter)
+
+params_gen_learning_window = {
+ "amplitude_times": amplitude_times,
+ "amplitude_values": amplitude_values,
+}
+
+####################
+
+nest.SetStatus(gen_learning_window, params_gen_learning_window)
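+
+# With the timing defined above (duration["sequence"] = 2250 ms, duration["learning_window"] = 150 ms, and
+# duration["total_offset"] = 2 ms), the first window signal is set to 0 at t = 2 ms and to 1 at t = 2102 ms, so
+# the learning window covers the final 150 ms of each sequence, i.e. the recall period.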
+
+# %% ###########################################################################################################
+# Force final update
+# ~~~~~~~~~~~~~~~~~~
+# Synapses only become active, that is, the correct weight update is calculated and applied, when they transmit
+# a spike. To still be able to read out the correct weights at the end of the simulation, we force spiking of the
+# presynaptic neuron and thus an update of all synapses, including those that have not transmitted a spike in
+# the last update interval, by sending a strong spike to all neurons that form the presynaptic side of an eprop
+# synapse. This step is required purely for technical reasons.
+
+gen_spk_final_update = nest.Create("spike_generator", 1, {"spike_times": [duration["task"] + duration["delays"]]})
+
+nest.Connect(gen_spk_final_update, nrns_in + nrns_rec, "all_to_all", {"weight": 1000.0})
+
+# %% ###########################################################################################################
+# Read out pre-training weights
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Before we begin training, we read out the initial weight matrices so that we can eventually compare them to
+# the optimized weights.
+
+
+def get_weights(pop_pre, pop_post):
+ conns = nest.GetConnections(pop_pre, pop_post).get(["source", "target", "weight"])
+ conns["senders"] = np.array(conns["source"]) - np.min(conns["source"])
+ conns["targets"] = np.array(conns["target"]) - np.min(conns["target"])
+
+ conns["weight_matrix"] = np.zeros((len(pop_post), len(pop_pre)))
+ conns["weight_matrix"][conns["targets"], conns["senders"]] = conns["weight"]
+ return conns
+
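+# get_weights returns the per-connection source, target, and weight entries together with a dense weight matrix
+# of shape (len(pop_post), len(pop_pre)); for example (a sketch, not executed here):
+#
+#     w = get_weights(nrns_in, nrns_rec)
+#     assert w["weight_matrix"].shape == (n_rec, n_in)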
+
+weights_pre_train = {
+ "in_rec": get_weights(nrns_in, nrns_rec),
+ "rec_rec": get_weights(nrns_rec, nrns_rec),
+ "rec_out": get_weights(nrns_rec, nrns_out),
+}
+
+# %% ###########################################################################################################
+# Simulate
+# ~~~~~~~~
+# We train the network by simulating for a set simulation time, determined by the number of iterations, the
+# group size, and the length of one sequence.
+
+nest.Simulate(duration["sim"])
+
+# %% ###########################################################################################################
+# Read out post-training weights
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# After the training, we can read out the optimized final weights.
+
+weights_post_train = {
+ "in_rec": get_weights(nrns_in, nrns_rec),
+ "rec_rec": get_weights(nrns_rec, nrns_rec),
+ "rec_out": get_weights(nrns_rec, nrns_out),
+}
+
+# %% ###########################################################################################################
+# Read out recorders
+# ~~~~~~~~~~~~~~~~~~
+# We can also retrieve the recorded history of the dynamic variables and weights, as well as detected spikes.
+
+events_mm_reg = mm_reg.get("events")
+events_mm_ad = mm_ad.get("events")
+events_mm_out = mm_out.get("events")
+events_sr_in = sr_in.get("events")
+events_sr_reg = sr_reg.get("events")
+events_sr_ad = sr_ad.get("events")
+events_wr = wr.get("events")
+
+# %% ###########################################################################################################
+# Evaluate training error
+# ~~~~~~~~~~~~~~~~~~~~~~~
+# We evaluate the network's training error by calculating a loss - in this case, the mean squared error between
+# the integrated recurrent network activity and the target rate.
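+# Concretely, the loss below sums the squared readout error over the time steps of the learning window and then
+# averages over the two readout neurons and over the group of sequences; the predicted class of each sequence is
+# the readout neuron with the larger mean signal within that window.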
+
+readout_signal = events_mm_out["readout_signal"]
+target_signal = events_mm_out["target_signal"]
+senders = events_mm_out["senders"]
+
+readout_signal = np.array([readout_signal[senders == i] for i in set(senders)])
+target_signal = np.array([target_signal[senders == i] for i in set(senders)])
+
+readout_signal = readout_signal.reshape((n_out, n_iter, group_size, steps["sequence"]))
+target_signal = target_signal.reshape((n_out, n_iter, group_size, steps["sequence"]))
+
+readout_signal = readout_signal[:, :, :, -steps["learning_window"] :]
+target_signal = target_signal[:, :, :, -steps["learning_window"] :]
+
+loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2))
+
+y_prediction = np.argmax(np.mean(readout_signal, axis=3), axis=0)
+y_target = np.argmax(np.mean(target_signal, axis=3), axis=0)
+accuracy = np.mean((y_target == y_prediction), axis=1)
+recall_errors = 1.0 - accuracy
+
+# %% ###########################################################################################################
+# Plot results
+# ~~~~~~~~~~~~
+# Then, we plot a series of plots.
+
+do_plotting = True # if True, plot the results
+
+if not do_plotting:
+ exit()
+
+colors = {
+ "blue": "#2854c5ff",
+ "red": "#e04b40ff",
+ "white": "#ffffffff",
+}
+
+plt.rcParams.update(
+ {
+ "font.sans-serif": "Arial",
+ "axes.spines.right": False,
+ "axes.spines.top": False,
+ "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]),
+ }
+)
+
+# %% ###########################################################################################################
+# Plot training error
+# ...................
+# We begin with two plots visualizing the training error of the network: the loss and the recall error, both
+# plotted against the iterations.
+
+fig, axs = plt.subplots(2, 1, sharex=True)
+fig.suptitle("Training error")
+
+axs[0].plot(range(1, n_iter + 1), loss)
+axs[0].set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$")
+
+axs[1].plot(range(1, n_iter + 1), recall_errors)
+axs[1].set_ylabel("recall errors")
+
+axs[-1].set_xlabel("training iteration")
+axs[-1].set_xlim(1, n_iter)
+axs[-1].xaxis.get_major_locator().set_params(integer=True)
+
+fig.tight_layout()
+
+# %% ###########################################################################################################
+# Plot spikes and dynamic variables
+# .................................
+# This plotting routine shows how to plot all of the recorded dynamic variables and spikes across time. We take
+# one snapshot in the first iteration and one snapshot at the end.
+
+
+def plot_recordable(ax, events, recordable, ylabel, xlims):
+ for sender in set(events["senders"]):
+ idc_sender = events["senders"] == sender
+ idc_times = (events["times"][idc_sender] > xlims[0]) & (events["times"][idc_sender] < xlims[1])
+ ax.plot(events["times"][idc_sender][idc_times], events[recordable][idc_sender][idc_times], lw=0.5)
+ ax.set_ylabel(ylabel)
+ margin = np.abs(np.max(events[recordable]) - np.min(events[recordable])) * 0.1
+ ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin)
+
+
+def plot_spikes(ax, events, ylabel, xlims):
+ idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1])
+ senders_subset = events["senders"][idc_times]
+ times_subset = events["times"][idc_times]
+
+ ax.scatter(times_subset, senders_subset, s=0.1)
+ ax.set_ylabel(ylabel)
+ margin = np.abs(np.max(senders_subset) - np.min(senders_subset)) * 0.1
+ ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin)
+
+
+for title, xlims in zip(
+ ["Dynamic variables before training", "Dynamic variables after training"],
+ [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])],
+):
+ fig, axs = plt.subplots(14, 1, sharex=True, figsize=(8, 14), gridspec_kw={"hspace": 0.4, "left": 0.2})
+ fig.suptitle(title)
+
+ plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims)
+ plot_spikes(axs[1], events_sr_reg, r"$z_j$" + "\n", xlims)
+
+ plot_recordable(axs[2], events_mm_reg, "V_m", r"$v_j$" + "\n(mV)", xlims)
+ plot_recordable(axs[3], events_mm_reg, "surrogate_gradient", r"$\psi_j$" + "\n", xlims)
+ plot_recordable(axs[4], events_mm_reg, "learning_signal", r"$L_j$" + "\n(pA)", xlims)
+
+ plot_spikes(axs[5], events_sr_ad, r"$z_j$" + "\n", xlims)
+
+ plot_recordable(axs[6], events_mm_ad, "V_m", r"$v_j$" + "\n(mV)", xlims)
+ plot_recordable(axs[7], events_mm_ad, "surrogate_gradient", r"$\psi_j$" + "\n", xlims)
+ plot_recordable(axs[8], events_mm_ad, "V_th_adapt", r"$A_j$" + "\n(mV)", xlims)
+ plot_recordable(axs[9], events_mm_ad, "learning_signal", r"$L_j$" + "\n(pA)", xlims)
+
+ plot_recordable(axs[10], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims)
+ plot_recordable(axs[11], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims)
+ plot_recordable(axs[12], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims)
+ plot_recordable(axs[13], events_mm_out, "error_signal", r"$y_k-y^*_k$" + "\n", xlims)
+
+ axs[-1].set_xlabel(r"$t$ (ms)")
+ axs[-1].set_xlim(*xlims)
+
+ fig.align_ylabels()
+
+# %% ###########################################################################################################
+# Plot weight time courses
+# ........................
+# Similarly, we can plot the weight histories. Note that the weight recorder, attached to the synapses, works
+# differently from the other recorders. Since synapses only become active when they transmit a spike, the weight
+# recorder only records the weight at those moments. That is why the first weight recordings do not start at the
+# first time step, and we add the initial weights manually.
+
+
+def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel):
+ for sender in nrns_senders.tolist():
+ for target in nrns_targets.tolist():
+ idc_syn = (events["senders"] == sender) & (events["targets"] == target)
+ idc_syn_pre = (weights_pre_train[label]["source"] == sender) & (
+ weights_pre_train[label]["target"] == target
+ )
+
+ times = [0.0] + events["times"][idc_syn].tolist()
+ weights = [weights_pre_train[label]["weight"][idc_syn_pre]] + events["weights"][idc_syn].tolist()
+
+ ax.step(times, weights, c=colors["blue"])
+ ax.set_ylabel(ylabel)
+ ax.set_ylim(-1.5, 1.5)
+
+
+fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
+fig.suptitle("Weight time courses")
+
+plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)")
+plot_weight_time_course(
+ axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)"
+)
+plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)")
+
+axs[-1].set_xlabel(r"$t$ (ms)")
+axs[-1].set_xlim(0, steps["task"])
+
+fig.align_ylabels()
+fig.tight_layout()
+
+# %% ###########################################################################################################
+# Plot weight matrices
+# ....................
+# If one is not interested in the time course of the weights, it is possible to read out only the initial and
+# final weights, which requires less computing time and memory than the weight recorder approach. Here, we plot
+# the corresponding weight matrices before and after the optimization.
+
+cmap = mpl.colors.LinearSegmentedColormap.from_list(
+ "cmap", ((0.0, colors["blue"]), (0.5, colors["white"]), (1.0, colors["red"]))
+)
+
+fig, axs = plt.subplots(3, 2, sharex="col", sharey="row")
+fig.suptitle("Weight matrices")
+
+all_w_extrema = []
+
+for k in weights_pre_train.keys():
+ w_pre = weights_pre_train[k]["weight"]
+ w_post = weights_post_train[k]["weight"]
+ all_w_extrema.append([np.min(w_pre), np.max(w_pre), np.min(w_post), np.max(w_post)])
+
+args = {"cmap": cmap, "vmin": np.min(all_w_extrema), "vmax": np.max(all_w_extrema)}
+
+for i, weights in zip([0, 1], [weights_pre_train, weights_post_train]):
+ axs[0, i].pcolormesh(weights["in_rec"]["weight_matrix"].T, **args)
+ axs[1, i].pcolormesh(weights["rec_rec"]["weight_matrix"], **args)
+ cmesh = axs[2, i].pcolormesh(weights["rec_out"]["weight_matrix"], **args)
+
+ axs[2, i].set_xlabel("recurrent\nneurons")
+
+axs[0, 0].set_ylabel("input\nneurons")
+axs[1, 0].set_ylabel("recurrent\nneurons")
+axs[2, 0].set_ylabel("readout\nneurons")
+fig.align_ylabels(axs[:, 0])
+
+axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center")
+axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center")
+
+axs[2, 0].yaxis.get_major_locator().set_params(integer=True)
+
+cbar = plt.colorbar(cmesh, cax=axs[1, 1].inset_axes([1.1, 0.2, 0.05, 0.8]), label="weight (pA)")
+
+fig.tight_layout()
+
+plt.show()
diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py
index 84a0b990bf..592f079e35 100644
--- a/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py
+++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py
@@ -134,6 +134,8 @@
steps = {
"sequence": 300, # time steps of one full sequence
"learning_window": 10, # time steps of window with non-zero learning signals
+ "delay_rec_out": 1, # time steps of connection delay from recurrent to output neurons
+ "delay_out_rec": 1, # time steps of broadcast delay of learning signals
}
steps.update(
@@ -207,6 +209,8 @@
"I_e": 0.0, # pA, external current input
"tau_m": 100.0, # ms, membrane time constant
"V_m": 0.0, # mV, initial value of the membrane voltage
+ "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals
+ "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons
}
params_nrn_rec = {
@@ -225,6 +229,8 @@
"tau_m": 30.0,
"V_m": 0.0,
"V_th": 0.6, # mV, spike threshold membrane voltage
+ "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals
+ "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons
}
scale_factor = 1.0 - params_nrn_rec["kappa"] # factor for rescaling due to removal of irregular spike arrival
@@ -389,10 +395,11 @@ def get_weight_recorder_senders_targets(weights, sender_pop, target_pop):
params_syn_in = params_syn_base.copy()
params_syn_rec = params_syn_base.copy()
params_syn_out = params_syn_base.copy()
+params_syn_out["delay"] = duration["delay_rec_out"]
params_syn_feedback = {
"synapse_model": "eprop_learning_signal_connection",
- "delay": duration["step"],
+ "delay": duration["delay_out_rec"],
"weight": weights_out_rec,
}
@@ -592,6 +599,7 @@ def get_params_task_input_output(n_iter_interval, loader):
+ iteration_offset
+ group_element * duration["sequence"]
+ duration["total_offset"]
+ + duration["delay_rec_out"] - 1.0
for group_element in range(group_size)
]
),
diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py
index bf0a66e992..4413afb9d6 100644
--- a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py
+++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py
@@ -121,6 +121,8 @@
steps = {
"sequence": 1000, # time steps of one full sequence
+ "delay_rec_out": 1, # time steps of connection delay from recurrent to output neurons
+ "delay_out_rec": 1, # time steps of broadcast delay of learning signals
}
steps["learning_window"] = steps["sequence"] # time steps of window with non-zero learning signals
@@ -181,6 +183,8 @@
"I_e": 0.0, # pA, external current input
"tau_m": 30.0, # ms, membrane time constant
"V_m": 0.0, # mV, initial value of the membrane voltage
+ "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals
+ "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons
}
params_nrn_rec = {
@@ -199,6 +203,8 @@
"tau_m": 30.0,
"V_m": 0.0,
"V_th": 0.03, # mV, spike threshold membrane voltage
+ "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals
+ "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons
}
scale_factor = 1.0 - params_nrn_rec["kappa"] # factor for rescaling due to removal of irregular spike arrival
@@ -335,10 +341,11 @@
params_syn_out = params_syn_base.copy()
params_syn_out["weight"] = weights_rec_out
+params_syn_out["delay"] = duration["delay_rec_out"]
params_syn_feedback = {
"synapse_model": "eprop_learning_signal_connection",
- "delay": duration["step"],
+ "delay": duration["delay_out_rec"],
"weight": weights_out_rec,
}
@@ -428,7 +435,7 @@ def generate_superimposed_sines(steps_sequence, periods):
target_signal = generate_superimposed_sines(steps["sequence"], [1000, 500, 333, 200]) # periods in steps
params_gen_rate_target = {
- "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"],
+ "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"] + duration["delay_rec_out"] - 1.0,
"amplitude_values": np.tile(target_signal, n_iter * group_size),
}