diff --git a/models/eprop_iaf.cpp b/models/eprop_iaf.cpp index fb336a40f4..a081adec26 100644 --- a/models/eprop_iaf.cpp +++ b/models/eprop_iaf.cpp @@ -85,6 +85,9 @@ eprop_iaf::Parameters_::Parameters_() , kappa_( 0.97 ) , kappa_reg_( 0.97 ) , eprop_isi_trace_cutoff_( 1000.0 ) + , delay_rec_out_( 1 ) + , delay_out_rec_( 1 ) + , delay_total_( 1 ) { } @@ -131,6 +134,8 @@ eprop_iaf::Parameters_::get( DictionaryDatum& d ) const def< double >( d, names::kappa, kappa_ ); def< double >( d, names::kappa_reg, kappa_reg_ ); def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ ); + def< double >( d, names::delay_rec_out, Time( Time::step( delay_rec_out_ ) ).get_ms() ); + def< double >( d, names::delay_out_rec, Time( Time::step( delay_out_rec_ ) ).get_ms() ); } double @@ -169,6 +174,14 @@ eprop_iaf::Parameters_::set( const DictionaryDatum& d, Node* node ) updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node ); updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node ); + double delay_rec_out_ms = Time( Time::step( delay_rec_out_ ) ).get_ms(); + updateValueParam< double >( d, names::delay_rec_out, delay_rec_out_ms, node ); + delay_rec_out_ = Time( Time::ms( delay_rec_out_ms ) ).get_steps(); + + double delay_out_rec_ms = Time( Time::step( delay_out_rec_ ) ).get_ms(); + updateValueParam< double >( d, names::delay_out_rec, delay_out_rec_ms, node ); + delay_out_rec_ = Time( Time::ms( delay_out_rec_ms ) ).get_steps(); + if ( C_m_ <= 0 ) { throw BadProperty( "Membrane capacitance C_m > 0 required." ); @@ -214,6 +227,18 @@ eprop_iaf::Parameters_::set( const DictionaryDatum& d, Node* node ) throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." ); } + if ( delay_rec_out_ < 1 ) + { + throw BadProperty( "Connection delay from recurrent to readout neuron ≥ 1 required." ); + } + + if ( delay_out_rec_ < 1 ) + { + throw BadProperty( "Connection delay from readout to recurrent neuron ≥ 1 required." 
); + } + + delay_total_ = delay_rec_out_ + ( delay_out_rec_ - 1 ); + return delta_EL; } @@ -278,6 +303,14 @@ eprop_iaf::pre_run_hook() V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); + + if ( eprop_history_.empty() ) + { + for ( long t = -P_.delay_total_; t < 0; ++t ) + { + append_new_eprop_history_entry( t ); + } + } } @@ -373,7 +406,8 @@ eprop_iaf::handle( DataLoggingRequest& e ) void eprop_iaf::compute_gradient( const long t_spike, const long t_spike_previous, - double& z_previous_buffer, + std::queue< double >& z_previous_buffer, + double& z_previous, double& z_bar, double& e_bar, double& e_bar_reg, @@ -382,26 +416,31 @@ eprop_iaf::compute_gradient( const long t_spike, const CommonSynapseProperties& cp, WeightOptimizer* optimizer ) { - double e = 0.0; // eligibility trace - double z = 0.0; // spiking variable - double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration - double psi = 0.0; // surrogate gradient - double L = 0.0; // learning signal - double firing_rate_reg = 0.0; // firing rate regularization - double grad = 0.0; // gradient + double e = 0.0; // eligibility trace + double z = 0.0; // spiking variable + double z_current = 1.0; // spike state that triggered the current integration + double psi = 0.0; // surrogate gradient + double L = 0.0; // learning signal + double firing_rate_reg = 0.0; // firing rate regularization + double grad = 0.0; // gradient const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp ); const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_; - auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 ); + auto eprop_hist_it = get_eprop_history( t_spike_previous - P_.delay_total_ ); const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike ); for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it ) { - z = z_previous_buffer; - z_previous_buffer = z_current_buffer; - z_current_buffer = 0.0; + if ( P_.delay_total_ > 1 ) + { + update_pre_syn_buffer_multiple_entries( z, z_current, z_previous, z_previous_buffer, t_spike, t ); + } + else + { + update_pre_syn_buffer_one_entry( z, z_current, z_previous, z_previous_buffer, t_spike, t ); + } psi = eprop_hist_it->surrogate_gradient_; L = eprop_hist_it->learning_signal_; diff --git a/models/eprop_iaf.h b/models/eprop_iaf.h index 225de65333..22e04dc820 100644 --- a/models/eprop_iaf.h +++ b/models/eprop_iaf.h @@ -390,6 +390,7 @@ class eprop_iaf : public EpropArchivingNodeRecurrent< false > void compute_gradient( const long, const long, + std::queue< double >&, double&, double&, double&, @@ -402,6 +403,9 @@ class eprop_iaf : public EpropArchivingNodeRecurrent< false > long get_shift() const override; bool is_eprop_recurrent_node() const override; long get_eprop_isi_trace_cutoff() const override; + long get_delay_total() const override; + long get_delay_recurrent_to_readout() const override; + long get_delay_readout_to_recurrent() const override; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_iaf >; @@ -458,6 +462,15 @@ class eprop_iaf : public EpropArchivingNodeRecurrent< false > //! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms). double eprop_isi_trace_cutoff_; + //! Connection delay from recurrent to readout neuron. + long delay_rec_out_; + + //! Connection delay from readout to recurrent neuron. 
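+ //! The delay (in steps) of incoming eprop_learning_signal_connection synapses must equal this value.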
+ long delay_out_rec_; + + //! Sum of connection delays from recurrent to readout neuron and from readout to recurrent neuron. + long delay_total_; + //! Default constructor. Parameters_(); @@ -594,10 +607,35 @@ eprop_iaf::get_eprop_isi_trace_cutoff() const return V_.eprop_isi_trace_cutoff_steps_; } +inline long +eprop_iaf::get_delay_total() const +{ + return P_.delay_total_; +} + +inline long +eprop_iaf::get_delay_recurrent_to_readout() const +{ + return P_.delay_rec_out_; +} + +inline long +eprop_iaf::get_delay_readout_to_recurrent() const +{ + return P_.delay_out_rec_; +} + inline size_t eprop_iaf::send_test_event( Node& target, size_t receptor_type, synindex, bool ) { SpikeEvent e; + + // To perform a consistency check on the delay parameter d_out_rec between recurrent + // neurons and readout neurons, the recurrent neurons send a test event with a delay + // specified by d_rec_out. Upon receiving the test event from the recurrent neuron, + // the readout neuron checks if the delay with which the event was received matches + // its own specified delay parameter d_rec_out. + e.set_delay_steps( P_.delay_rec_out_ ); e.set_sender( *this ); return target.handles_test_event( e, receptor_type ); } diff --git a/models/eprop_iaf_adapt.cpp b/models/eprop_iaf_adapt.cpp index 10cb9ff224..eccfdaa454 100644 --- a/models/eprop_iaf_adapt.cpp +++ b/models/eprop_iaf_adapt.cpp @@ -89,6 +89,9 @@ eprop_iaf_adapt::Parameters_::Parameters_() , kappa_( 0.97 ) , kappa_reg_( 0.97 ) , eprop_isi_trace_cutoff_( 1000.0 ) + , delay_rec_out_( 1 ) + , delay_out_rec_( 1 ) + , delay_total_( 1 ) { } @@ -139,6 +142,8 @@ eprop_iaf_adapt::Parameters_::get( DictionaryDatum& d ) const def< double >( d, names::kappa, kappa_ ); def< double >( d, names::kappa_reg, kappa_reg_ ); def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ ); + def< double >( d, names::delay_rec_out, Time( Time::step( delay_rec_out_ ) ).get_ms() ); + def< double >( d, names::delay_out_rec, Time( Time::step( delay_out_rec_ ) ).get_ms() ); } double @@ -179,6 +184,12 @@ eprop_iaf_adapt::Parameters_::set( const DictionaryDatum& d, Node* node ) updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node ); updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node ); + double delay_rec_out_ms = Time( Time::step( delay_rec_out_ ) ).get_ms(); + updateValueParam< double >( d, names::delay_rec_out, delay_rec_out_ms, node ); + delay_rec_out_ = Time( Time::ms( delay_rec_out_ms ) ).get_steps(); + + double delay_out_rec_ms = Time( Time::step( delay_out_rec_ ) ).get_ms(); + updateValueParam< double >( d, names::delay_out_rec, delay_out_rec_ms, node ); + delay_out_rec_ = Time( Time::ms( delay_out_rec_ms ) ).get_steps(); + if ( adapt_beta_ < 0 ) { throw BadProperty( "Threshold adaptation prefactor adapt_beta ≥ 0 required." ); @@ -234,6 +245,18 @@ eprop_iaf_adapt::Parameters_::set( const DictionaryDatum& d, Node* node ) throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." ); } + if ( delay_rec_out_ < 1 ) + { + throw BadProperty( "Connection delay from recurrent to readout neuron ≥ 1 required." ); + } + + if ( delay_out_rec_ < 1 ) + { + throw BadProperty( "Connection delay from readout to recurrent neuron ≥ 1 required."
); + } + + delay_total_ = delay_rec_out_ + ( delay_out_rec_ - 1 ); + return delta_EL; } @@ -313,6 +336,14 @@ eprop_iaf_adapt::pre_run_hook() V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); V_.P_adapt_ = std::exp( -dt / P_.adapt_tau_ ); + + if ( eprop_history_.empty() ) + { + for ( long t = -P_.delay_total_; t < 0; ++t ) + { + append_new_eprop_history_entry( t ); + } + } } @@ -412,7 +443,8 @@ eprop_iaf_adapt::handle( DataLoggingRequest& e ) void eprop_iaf_adapt::compute_gradient( const long t_spike, const long t_spike_previous, - double& z_previous_buffer, + std::queue< double >& z_previous_buffer, + double& z_previous, double& z_bar, double& e_bar, double& e_bar_reg, @@ -421,26 +453,31 @@ eprop_iaf_adapt::compute_gradient( const long t_spike, const CommonSynapseProperties& cp, WeightOptimizer* optimizer ) { - double e = 0.0; // eligibility trace - double z = 0.0; // spiking variable - double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration - double psi = 0.0; // surrogate gradient - double L = 0.0; // learning signal - double firing_rate_reg = 0.0; // firing rate regularization - double grad = 0.0; // gradient + double e = 0.0; // eligibility trace + double z = 0.0; // spiking variable + double z_current = 1.0; // spike state that triggered the current integration + double psi = 0.0; // surrogate gradient + double L = 0.0; // learning signal + double firing_rate_reg = 0.0; // firing rate regularization + double grad = 0.0; // gradient const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp ); const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_; - auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 ); + auto eprop_hist_it = get_eprop_history( t_spike_previous - P_.delay_total_ ); const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike ); for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it ) { - z = z_previous_buffer; - z_previous_buffer = z_current_buffer; - z_current_buffer = 0.0; + if ( P_.delay_total_ > 1 ) + { + update_pre_syn_buffer_multiple_entries( z, z_current, z_previous, z_previous_buffer, t_spike, t ); + } + else + { + update_pre_syn_buffer_one_entry( z, z_current, z_previous, z_previous_buffer, t_spike, t ); + } psi = eprop_hist_it->surrogate_gradient_; L = eprop_hist_it->learning_signal_; diff --git a/models/eprop_iaf_adapt.h b/models/eprop_iaf_adapt.h index d404dcbd69..fe9689e39f 100644 --- a/models/eprop_iaf_adapt.h +++ b/models/eprop_iaf_adapt.h @@ -358,6 +358,7 @@ class eprop_iaf_adapt : public EpropArchivingNodeRecurrent< false > void compute_gradient( const long, const long, + std::queue< double >&, double&, double&, double&, @@ -370,6 +371,9 @@ class eprop_iaf_adapt : public EpropArchivingNodeRecurrent< false > long get_shift() const override; bool is_eprop_recurrent_node() const override; long get_eprop_isi_trace_cutoff() const override; + long get_delay_total() const override; + long get_delay_recurrent_to_readout() const override; + long get_delay_readout_to_recurrent() const override; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_iaf_adapt >; @@ -432,6 +436,15 @@ class eprop_iaf_adapt : public EpropArchivingNodeRecurrent< false > //! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms). double eprop_isi_trace_cutoff_; + //! 
Connection delay from recurrent to readout neuron. + long delay_rec_out_; + + //! Connection delay from readout to recurrent neuron. + long delay_out_rec_; + + //! Sum of connection delays from recurrent to readout neuron and readout to recurrent neuron. + long delay_total_; + //! Default constructor. Parameters_(); @@ -591,10 +604,35 @@ eprop_iaf_adapt::get_eprop_isi_trace_cutoff() const return V_.eprop_isi_trace_cutoff_steps_; } +inline long +eprop_iaf_adapt::get_delay_total() const +{ + return P_.delay_total_; +} + +inline long +eprop_iaf_adapt::get_delay_recurrent_to_readout() const +{ + return P_.delay_rec_out_; +} + +inline long +eprop_iaf_adapt::get_delay_readout_to_recurrent() const +{ + return P_.delay_out_rec_; +} + inline size_t eprop_iaf_adapt::send_test_event( Node& target, size_t receptor_type, synindex, bool ) { SpikeEvent e; + + // To perform a consistency check on the delay parameter d_out_rec between recurrent + // neurons and readout neurons, the recurrent neurons send a test event with a delay + // specified by d_rec_out. Upon receiving the test event from the recurrent neuron, + // the readout neuron checks if the delay with which the event was received matches + // its own specified delay parameter d_rec_out. + e.set_delay_steps( P_.delay_rec_out_ ); e.set_sender( *this ); return target.handles_test_event( e, receptor_type ); } diff --git a/models/eprop_iaf_psc_delta.cpp b/models/eprop_iaf_psc_delta.cpp index a8b13ac4e4..0bbfe38f49 100644 --- a/models/eprop_iaf_psc_delta.cpp +++ b/models/eprop_iaf_psc_delta.cpp @@ -87,6 +87,9 @@ eprop_iaf_psc_delta::Parameters_::Parameters_() , kappa_( 0.97 ) , kappa_reg_( 0.97 ) , eprop_isi_trace_cutoff_( 1000.0 ) + , delay_rec_out_( 1 ) + , delay_out_rec_( 1 ) + , delay_total_( 1 ) { } @@ -134,6 +137,8 @@ eprop_iaf_psc_delta::Parameters_::get( DictionaryDatum& d ) const def< double >( d, names::kappa, kappa_ ); def< double >( d, names::kappa_reg, kappa_reg_ ); def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ ); + def< double >( d, names::delay_rec_out, Time( Time::step( delay_rec_out_ ) ).get_ms() ); + def< double >( d, names::delay_out_rec, Time( Time::step( delay_out_rec_ ) ).get_ms() ); } double @@ -162,23 +167,16 @@ eprop_iaf_psc_delta::Parameters_::set( const DictionaryDatum& d, Node* node ) updateValueParam< double >( d, names::beta, beta_, node ); updateValueParam< double >( d, names::gamma, gamma_, node ); - - if ( updateValueParam< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_, node ) ) - { - eprop_iaf_psc_delta* nrn = dynamic_cast< eprop_iaf_psc_delta* >( node ); - assert( nrn ); - auto compute_surrogate_gradient = nrn->find_surrogate_gradient( surrogate_gradient_function_ ); - nrn->compute_surrogate_gradient_ = compute_surrogate_gradient; - } - + updateValueParam< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_, node ); updateValueParam< double >( d, names::kappa, kappa_, node ); updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node ); updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node ); - if ( V_th_ < V_min_ ) - { - throw BadProperty( "Spike threshold voltage V_th ≥ minimal voltage V_min required."
); - } + const double delay_rec_out_ = Time::step( delay_rec_out_ ); + updateValueParam< double >( d, names::delay_rec_out, Time( delay_rec_out_).get_ms(), node ); + + const double delay_out_rec_ = Time::step( delay_out_rec_ ); + updateValueParam< double >( d, names::delay_out_rec, Time( delay_out_rec_).get_ms(), node ); if ( V_reset_ >= V_th_ ) { @@ -230,6 +228,19 @@ eprop_iaf_psc_delta::Parameters_::set( const DictionaryDatum& d, Node* node ) throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." ); } + if ( delay_rec_out_ < 1 ) + { + throw BadProperty( "Connection delay from recurrent to readout neuron ≥ 1 required." ); + } + + if ( delay_out_rec_ < 1 ) + { + throw BadProperty( "Connection delay from readout to recurrent neuron ≥ 1 required." ); + } + + delay_total_ = delay_rec_out_ + ( delay_out_rec_ - 1 ); + + return delta_EL; } @@ -294,8 +305,27 @@ eprop_iaf_psc_delta::pre_run_hook() V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); + + if ( eprop_history_.empty() ) + { + for ( long t = -P_.delay_total_; t < 0; ++t ) + { + append_new_eprop_history_entry( t ); + } + } } +long +eprop_iaf_psc_delta::get_shift() const +{ + return offset_gen_ + delay_in_rec_; +} + +bool +eprop_iaf_psc_delta::is_eprop_recurrent_node() const +{ + return true; +} /* ---------------------------------------------------------------- * Update function @@ -406,7 +436,8 @@ eprop_iaf_psc_delta::handle( DataLoggingRequest& e ) void eprop_iaf_psc_delta::compute_gradient( const long t_spike, const long t_spike_previous, - double& z_previous_buffer, + std::queue< double >& z_previous_buffer, + double& z_previous, double& z_bar, double& e_bar, double& e_bar_reg, @@ -415,26 +446,31 @@ eprop_iaf_psc_delta::compute_gradient( const long t_spike, const CommonSynapseProperties& cp, WeightOptimizer* optimizer ) { - double e = 0.0; // eligibility trace - double z = 0.0; // spiking variable - double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration - double psi = 0.0; // surrogate gradient - double L = 0.0; // learning signal - double firing_rate_reg = 0.0; // firing rate regularization - double grad = 0.0; // gradient + double e = 0.0; // eligibility trace + double z = 0.0; // spiking variable + double z_current = 1.0; // spike state that triggered the current integration + double psi = 0.0; // surrogate gradient + double L = 0.0; // learning signal + double firing_rate_reg = 0.0; // firing rate regularization + double grad = 0.0; // gradient const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp ); const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_; - auto eprop_hist_it = get_eprop_history( t_spike_previous - 1 ); + auto eprop_hist_it = get_eprop_history( t_spike_previous - P_.delay_total_ ); const long t_compute_until = std::min( t_spike_previous + V_.eprop_isi_trace_cutoff_steps_, t_spike ); for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it ) { - z = z_previous_buffer; - z_previous_buffer = z_current_buffer; - z_current_buffer = 0.0; + if ( P_.delay_total_ > 1 ) + { + update_pre_syn_buffer_multiple_entries( z, z_current, z_previous, z_previous_buffer, t_spike, t ); + } + else + { + update_pre_syn_buffer_one_entry( z, z_current, z_previous, z_previous_buffer, t_spike, t ); + } psi = eprop_hist_it->surrogate_gradient_; L = eprop_hist_it->learning_signal_; diff --git 
a/models/eprop_iaf_psc_delta.h b/models/eprop_iaf_psc_delta.h index c36066a984..269ee7783f 100644 --- a/models/eprop_iaf_psc_delta.h +++ b/models/eprop_iaf_psc_delta.h @@ -307,7 +307,7 @@ References https://doi.org/10.1038/s41467-020-17236-y .. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Plesser HE, - Dahmen D, Bolten M, Van Albada SJ, Diesmann M. Event-based + Dahmen D, Bolten M, Van Albada SJ*, Diesmann M*. Event-based implementation of eligibility propagation (in preparation) .. [3] Neftci EO, Mostafa H, Zenke F (2019). Surrogate Gradient Learning in @@ -402,6 +402,7 @@ class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent< false > void compute_gradient( const long, const long, + std::queue< double >&, double&, double&, double&, @@ -414,6 +415,9 @@ class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent< false > long get_shift() const override; bool is_eprop_recurrent_node() const override; long get_eprop_isi_trace_cutoff() const override; + long get_delay_total() const override; + long get_delay_recurrent_to_readout() const override; + long get_delay_readout_to_recurrent() const override; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_iaf_psc_delta >; @@ -476,6 +480,15 @@ class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent< false > //! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms). double eprop_isi_trace_cutoff_; + //! Connection delay from recurrent to readout neuron. + long delay_rec_out_; + + //! Connection delay from readout to recurrent neuron. + long delay_out_rec_; + + //! Sum of connection delays from recurrent to readout neuron and readout to recurrent neuron. + long delay_total_; + //! Default constructor. Parameters_(); @@ -592,27 +605,40 @@ class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent< false > }; inline long -eprop_iaf_psc_delta::get_shift() const +eprop_iaf_psc_delta::get_eprop_isi_trace_cutoff() const { - return offset_gen_ + delay_in_rec_; + return V_.eprop_isi_trace_cutoff_steps_; } -inline bool -eprop_iaf_psc_delta::is_eprop_recurrent_node() const +inline long +eprop_iaf_psc_delta::get_delay_total() const { - return true; + return P_.delay_total_; } inline long -eprop_iaf_psc_delta::get_eprop_isi_trace_cutoff() const +eprop_iaf_psc_delta::get_delay_recurrent_to_readout() const { - return V_.eprop_isi_trace_cutoff_steps_; + return P_.delay_rec_out_; +} + +inline long +eprop_iaf_psc_delta::get_delay_readout_to_recurrent() const +{ + return P_.delay_out_rec_; } inline size_t eprop_iaf_psc_delta::send_test_event( Node& target, size_t receptor_type, synindex, bool ) { SpikeEvent e; + + // To perform a consistency check on the delay parameter d_out_rec between recurrent + // neurons and readout neurons, the recurrent neurons send a test event with a delay + // specified by d_rec_out. Upon receiving the test event from the recurrent neuron, + // the readout neuron checks if the delay with which the event was received matches + // its own specified delay parameter d_rec_out. 
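+ //
+ // Illustrative example (assuming a resolution of 1.0 ms): if this neuron and the readout
+ // neuron are both created with delay_rec_out = 2.0 ms, the test event carries a delay of
+ // 2 steps and the check in eprop_readout::handles_test_event() passes; any mismatch makes
+ // the readout neuron throw IllegalConnection.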
+ e.set_delay_steps( P_.delay_rec_out_ ); e.set_sender( *this ); return target.handles_test_event( e, receptor_type ); } diff --git a/models/eprop_iaf_psc_delta_adapt.cpp b/models/eprop_iaf_psc_delta_adapt.cpp index 9ffe2a8b69..74b9dc752b 100644 --- a/models/eprop_iaf_psc_delta_adapt.cpp +++ b/models/eprop_iaf_psc_delta_adapt.cpp @@ -445,7 +445,8 @@ eprop_iaf_psc_delta_adapt::handle( DataLoggingRequest& e ) void eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike, const long t_spike_previous, - double& z_previous_buffer, + std::queue< double >& z_previous_buffer, + double& z_previous, double& z_bar, double& e_bar, double& e_bar_reg, @@ -454,13 +455,13 @@ eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike, const CommonSynapseProperties& cp, WeightOptimizer* optimizer ) { - double e = 0.0; // eligibility trace - double z = 0.0; // spiking variable - double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration - double psi = 0.0; // surrogate gradient - double L = 0.0; // learning signal - double firing_rate_reg = 0.0; // firing rate regularization - double grad = 0.0; // gradient + double e = 0.0; // eligibility trace + double z = 0.0; // spiking variable + double z_current = 1.0; // spike state that triggered the current integration + double psi = 0.0; // surrogate gradient + double L = 0.0; // learning signal + double firing_rate_reg = 0.0; // firing rate regularization + double grad = 0.0; // gradient const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp ); const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_; @@ -471,9 +472,14 @@ eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike, for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it ) { - z = z_previous_buffer; - z_previous_buffer = z_current_buffer; - z_current_buffer = 0.0; + if ( P_.delay_total_ > 1 ) + { + update_pre_syn_buffer_multiple_entries( z, z_current, z_previous, z_previous_buffer, t_spike, t ); + } + else + { + update_pre_syn_buffer_one_entry( z, z_current, z_previous, z_previous_buffer, t_spike, t ); + } psi = eprop_hist_it->surrogate_gradient_; L = eprop_hist_it->learning_signal_; diff --git a/models/eprop_iaf_psc_delta_adapt.h b/models/eprop_iaf_psc_delta_adapt.h index 3c0949de3a..e417b65567 100644 --- a/models/eprop_iaf_psc_delta_adapt.h +++ b/models/eprop_iaf_psc_delta_adapt.h @@ -418,6 +418,7 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent< false > void compute_gradient( const long, const long, + std::queue< double >&, double&, double&, double&, @@ -430,6 +431,9 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent< false > long get_shift() const override; bool is_eprop_recurrent_node() const override; long get_eprop_isi_trace_cutoff() const override; + long get_delay_total() const override; + long get_delay_recurrent_to_readout() const override; + long get_delay_readout_to_recurrent() const override; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_iaf_psc_delta_adapt >; @@ -498,6 +502,15 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent< false > //! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms). double eprop_isi_trace_cutoff_; + //! Connection delay from recurrent to readout neuron. + long delay_rec_out_; + + //! Connection delay from readout to recurrent neuron. + long delay_out_rec_; + + //! 
Sum of connection delays from recurrent to readout neuron and readout to recurrent neuron. + long delay_total_; + //! Default constructor. Parameters_(); @@ -657,10 +670,35 @@ eprop_iaf_psc_delta_adapt::get_eprop_isi_trace_cutoff() const return V_.eprop_isi_trace_cutoff_steps_; } +inline long +eprop_iaf_psc_delta_adapt::get_delay_total() const +{ + return P_.delay_total_; +} + +inline long +eprop_iaf_psc_delta_adapt::get_delay_recurrent_to_readout() const +{ + return P_.delay_rec_out_; +} + +inline long +eprop_iaf_psc_delta_adapt::get_delay_readout_to_recurrent() const +{ + return P_.delay_out_rec_; +} + inline size_t eprop_iaf_psc_delta_adapt::send_test_event( Node& target, size_t receptor_type, synindex, bool ) { SpikeEvent e; + + // To perform a consistency check on the delay parameter d_out_rec between recurrent + // neurons and readout neurons, the recurrent neurons send a test event with a delay + // specified by d_rec_out. Upon receiving the test event from the recurrent neuron, + // the readout neuron checks if the delay with which the event was received matches + // its own specified delay parameter d_rec_out. + e.set_delay_steps( P_.delay_rec_out_ ); e.set_sender( *this ); return target.handles_test_event( e, receptor_type ); } diff --git a/models/eprop_learning_signal_connection.h b/models/eprop_learning_signal_connection.h index cc013a028b..9c84671488 100644 --- a/models/eprop_learning_signal_connection.h +++ b/models/eprop_learning_signal_connection.h @@ -163,6 +163,12 @@ class eprop_learning_signal_connection : public Connection< targetidentifierT > { LearningSignalConnectionEvent ge; + const long delay_out_rec = t.get_delay_readout_to_recurrent(); + if ( delay_out_rec != get_delay_steps() ) + { + throw IllegalConnection( "delay == delay_out_rec of target neuron required." ); + } + s.sends_secondary_event( ge ); ge.set_sender( s ); Connection< targetidentifierT >::target_.set_rport( t.handles_test_event( ge, receptor_type ) ); diff --git a/models/eprop_readout.cpp b/models/eprop_readout.cpp index b2f740d70e..7bfeb0aaa3 100644 --- a/models/eprop_readout.cpp +++ b/models/eprop_readout.cpp @@ -76,6 +76,8 @@ eprop_readout::Parameters_::Parameters_() , tau_m_( 10.0 ) , V_min_( -std::numeric_limits< double >::max() ) , eprop_isi_trace_cutoff_( 1000.0 ) + , delay_rec_out_( 1 ) + , delay_out_rec_( 1 ) { } @@ -113,6 +115,8 @@ eprop_readout::Parameters_::get( DictionaryDatum& d ) const def< double >( d, names::tau_m, tau_m_ ); def< double >( d, names::V_min, V_min_ + E_L_ ); def< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_ ); + def< double >( d, names::delay_rec_out, Time( Time::step( delay_rec_out_ ) ).get_ms() ); + def< double >( d, names::delay_out_rec, Time( Time::step( delay_out_rec_ ) ).get_ms() ); } double @@ -130,6 +134,12 @@ eprop_readout::Parameters_::set( const DictionaryDatum& d, Node* node ) updateValueParam< double >( d, names::tau_m, tau_m_, node ); updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node ); + double delay_rec_out_ms = Time( Time::step( delay_rec_out_ ) ).get_ms(); + updateValueParam< double >( d, names::delay_rec_out, delay_rec_out_ms, node ); + delay_rec_out_ = Time( Time::ms( delay_rec_out_ms ) ).get_steps(); + + double delay_out_rec_ms = Time( Time::step( delay_out_rec_ ) ).get_ms(); + updateValueParam< double >( d, names::delay_out_rec, delay_out_rec_ms, node ); + delay_out_rec_ = Time( Time::ms( delay_out_rec_ms ) ).get_steps(); + if ( C_m_ <= 0 ) { throw BadProperty( "Membrane capacitance C_m > 0 required."
); @@ -145,6 +155,16 @@ eprop_readout::Parameters_::set( const DictionaryDatum& d, Node* node ) throw BadProperty( "Cutoff of integration of eprop trace between spikes eprop_isi_trace_cutoff ≥ 0 required." ); } + if ( delay_rec_out_ < 1 ) + { + throw BadProperty( "Connection delay from recurrent to output neuron ≥ 1 required." ); + } + + if ( delay_out_rec_ < 1 ) + { + throw BadProperty( "Connection delay from readout to recurrent neuron ≥ 1 required." ); + } + return delta_EL; } @@ -207,6 +227,19 @@ eprop_readout::pre_run_hook() V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); + + if ( eprop_history_.empty() ) + { + for ( long t = -P_.delay_rec_out_; t < 0; ++t ) + { + append_new_eprop_history_entry( t ); + } + + for ( long i = 0; i < P_.delay_out_rec_ - 1; i++ ) + { + S_.error_signal_deque_.push_back( 0.0 ); + } + } } @@ -237,7 +270,9 @@ eprop_readout::update( Time const& origin, const long from, const long to ) S_.readout_signal_ *= S_.learning_window_signal_; S_.error_signal_ *= S_.learning_window_signal_; - error_signal_buffer[ lag ] = S_.error_signal_; + S_.error_signal_deque_.push_back( S_.error_signal_ ); + error_signal_buffer[ lag ] = S_.error_signal_deque_.front(); // get delay_out_rec-th value + S_.error_signal_deque_.pop_front(); append_new_eprop_history_entry( t ); write_error_signal_to_history( t, S_.error_signal_ ); @@ -307,7 +342,8 @@ eprop_readout::handle( DataLoggingRequest& e ) void eprop_readout::compute_gradient( const long t_spike, const long t_spike_previous, - double& z_previous_buffer, + std::queue< double >& z_previous_buffer, + double& z_previous, double& z_bar, double& e_bar, double& e_bar_reg, @@ -316,10 +352,10 @@ eprop_readout::compute_gradient( const long t_spike, const CommonSynapseProperties& cp, WeightOptimizer* optimizer ) { - double z = 0.0; // spiking variable - double z_current_buffer = 1.0; // buffer containing the spike that triggered the current integration - double L = 0.0; // error signal - double grad = 0.0; // gradient + double z = 0.0; // spiking variable + double z_current = 1.0; // spike state that triggered the current integration + double L = 0.0; // error signal + double grad = 0.0; // gradient const EpropSynapseCommonProperties& ecp = static_cast< const EpropSynapseCommonProperties& >( cp ); const auto optimize_each_step = ( *ecp.optimizer_cp_ ).optimize_each_step_; @@ -330,9 +366,15 @@ eprop_readout::compute_gradient( const long t_spike, for ( long t = t_spike_previous; t < t_compute_until; ++t, ++eprop_hist_it ) { - z = z_previous_buffer; - z_previous_buffer = z_current_buffer; - z_current_buffer = 0.0; + if ( P_.delay_rec_out_ > 1 ) + { + z = z_previous_buffer.front(); + update_pre_syn_buffer_multiple_entries( z, z_current, z_previous, z_previous_buffer, t_spike, t ); + } + else + { + update_pre_syn_buffer_one_entry( z, z_current, z_previous, z_previous_buffer, t_spike, t ); + } L = eprop_hist_it->error_signal_; diff --git a/models/eprop_readout.h b/models/eprop_readout.h index b96c007542..0a4e2af6ac 100644 --- a/models/eprop_readout.h +++ b/models/eprop_readout.h @@ -301,6 +301,7 @@ class eprop_readout : public EpropArchivingNodeReadout< false > void compute_gradient( const long, const long, + std::queue< double >&, double&, double&, double&, @@ -313,6 +314,7 @@ class eprop_readout : public EpropArchivingNodeReadout< false > long get_shift() const override; bool is_eprop_recurrent_node() const override; long get_eprop_isi_trace_cutoff() const override; + long get_delay_total() const 
override; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_readout >; @@ -341,6 +343,12 @@ class eprop_readout : public EpropArchivingNodeReadout< false > //! Time interval from the previous spike until the cutoff of e-prop update integration between two spikes (ms). double eprop_isi_trace_cutoff_; + //! Connection delay from recurrent to readout neuron. + long delay_rec_out_; + + //! Connection delay from readout to recurrent neuron. + long delay_out_rec_; + //! Default constructor. Parameters_(); @@ -383,6 +391,9 @@ class eprop_readout : public EpropArchivingNodeReadout< false > //! Set the state variables. void set( const DictionaryDatum&, const Parameters_&, double, Node* ); + + //! Queue to hold last delay_out_rec error signals. + std::deque< double > error_signal_deque_; }; //! Structure of buffers. @@ -493,14 +504,35 @@ eprop_readout::get_eprop_isi_trace_cutoff() const return V_.eprop_isi_trace_cutoff_steps_; } +inline long +eprop_readout::get_delay_total() const +{ + return P_.delay_rec_out_; +} + inline size_t -eprop_readout::handles_test_event( SpikeEvent&, size_t receptor_type ) +eprop_readout::handles_test_event( SpikeEvent& e, size_t receptor_type ) { if ( receptor_type != 0 ) { throw UnknownReceptorType( receptor_type, get_name() ); } + // To perform a consistency check on the delay parameter d_out_rec between recurrent + // neurons and readout neurons, the recurrent neurons send a test event with a delay + // specified by d_rec_out. Upon receiving the test event from the recurrent neuron, + // the readout neuron checks if the delay with which the event was received matches + // its own specified delay parameter d_rec_out. + + // ensure that the spike event was not sent by a proxy node + if ( e.get_sender().get_node_id() != 0 ) + { + if ( e.get_delay_steps() != P_.delay_rec_out_ ) + { + throw IllegalConnection( + "delay_rec_out from recurrent to output neuron equal to delay_rec_out from output to recurrent neuron required." ); + } + } return 0; } diff --git a/models/eprop_synapse.h b/models/eprop_synapse.h index 9b4a5679c2..bab2484232 100644 --- a/models/eprop_synapse.h +++ b/models/eprop_synapse.h @@ -271,6 +271,9 @@ class eprop_synapse : public Connection< targetidentifierT > //! Update values in parameter dictionary. void set_status( const DictionaryDatum& d, ConnectorModel& cm ); + //! Initialize the presynaptic buffer. + void initialize_z_previous_buffer( const long delay_total ); + //! Send the spike event. bool send( Event& e, size_t thread, const EpropSynapseCommonProperties& cp ); @@ -333,7 +336,13 @@ class eprop_synapse : public Connection< targetidentifierT > double epsilon_ = 0.0; //! Value of spiking variable one time step before t_previous_spike_. - double z_previous_buffer_ = 0.0; + double z_previous_ = 0.0; + + //! Queue of length delay_total_ to hold previous spiking variables. + std::queue< double > z_previous_buffer_; + + //! Sum of connection delays from recurrent to readout neuron and readout to recurrent neuron. 
+ long delay_total_ = 0; /** * Optimizer @@ -404,7 +413,7 @@ eprop_synapse< targetidentifierT >::operator=( const eprop_synapse& es ) e_bar_ = es.e_bar_; e_bar_reg_ = es.e_bar_reg_; epsilon_ = es.epsilon_; - z_previous_buffer_ = es.z_previous_buffer_; + z_previous_ = es.z_previous_; optimizer_ = es.optimizer_; return *this; @@ -445,7 +454,7 @@ eprop_synapse< targetidentifierT >::operator=( eprop_synapse&& es ) e_bar_ = es.e_bar_; e_bar_reg_ = es.e_bar_reg_; epsilon_ = es.epsilon_; - z_previous_buffer_ = es.z_previous_buffer_; + z_previous_ = es.z_previous_; optimizer_ = es.optimizer_; @@ -463,11 +472,22 @@ eprop_synapse< targetidentifierT >::check_connection( Node& s, const CommonPropertiesType& cp ) { // When we get here, delay has been set so we can check it. - if ( get_delay_steps() != 1 ) + if ( get_delay_steps() < 1 ) { throw IllegalConnection( "eprop synapses require a delay of at least one simulation step" ); } + const bool is_recurrent_node = t.is_eprop_recurrent_node(); + + if ( not is_recurrent_node ) + { + const long delay_rec_out = t.get_delay_total(); + if ( delay_rec_out != get_delay_steps() ) + { + throw IllegalConnection( "delay == delay_rec_out of target neuron required." ); + } + } + ConnTestDummyNode dummy_target; ConnectionBase::check_connection_( dummy_target, s, t, receptor_type ); @@ -484,6 +504,17 @@ eprop_synapse< targetidentifierT >::delete_optimizer() // do not set to nullptr to allow detection of double deletion } +template < typename targetidentifierT > +void +eprop_synapse< targetidentifierT >::initialize_z_previous_buffer( const long delay_total ) +{ + for ( long i = 0; i < delay_total; i++ ) + { + z_previous_buffer_.push( 0.0 ); + } + z_previous_buffer_.push( 1.0 ); +} + template < typename targetidentifierT > bool eprop_synapse< targetidentifierT >::send( Event& e, size_t thread, const EpropSynapseCommonProperties& cp ) @@ -492,11 +523,28 @@ eprop_synapse< targetidentifierT >::send( Event& e, size_t thread, const EpropSy assert( target ); const long t_spike = e.get_stamp().get_steps(); + const long delay_total = target->get_delay_total(); if ( t_spike_previous_ != 0 ) { - target->compute_gradient( - t_spike, t_spike_previous_, z_previous_buffer_, z_bar_, e_bar_, e_bar_reg_, epsilon_, weight_, cp, optimizer_ ); + target->compute_gradient( t_spike, + t_spike_previous_, + z_previous_buffer_, + z_previous_, + z_bar_, + e_bar_, + e_bar_reg_, + epsilon_, + weight_, + cp, + optimizer_ ); + } + else + { + if ( delay_total > 1 ) + { + initialize_z_previous_buffer( delay_total ); + } } const long eprop_isi_trace_cutoff = target->get_eprop_isi_trace_cutoff(); diff --git a/nestkernel/eprop_archiving_node.cpp b/nestkernel/eprop_archiving_node.cpp new file mode 100644 index 0000000000..324d4407f9 --- /dev/null +++ b/nestkernel/eprop_archiving_node.cpp @@ -0,0 +1,308 @@ +/* + * eprop_archiving_node.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST.
If not, see . + * + */ + +// nestkernel +#include "eprop_archiving_node.h" +#include "eprop_archiving_node_impl.h" +#include "kernel_manager.h" + +// sli +#include "dictutils.h" + +namespace nest +{ + +EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent() + : EpropArchivingNode() + , firing_rate_reg_( 0.0 ) + , f_av_( 0.0 ) + , n_spikes_( 0 ) +{ +} + +EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& n ) + : EpropArchivingNode( n ) + , firing_rate_reg_( n.firing_rate_reg_ ) + , f_av_( n.f_av_ ) + , n_spikes_( n.n_spikes_ ) +{ +} + +double +EpropArchivingNodeRecurrent::compute_piecewise_linear_surrogate_gradient( const double r, + const double v_m, + const double v_th_adapt, + const double V_th, + const double beta, + const double gamma ) +{ + if ( r > 0 ) + { + return 0.0; + } + + return gamma * std::max( 0.0, 1.0 - beta * std::fabs( ( v_m - v_th_adapt ) / V_th ) ) / V_th; +} + +double +EpropArchivingNodeRecurrent::compute_exponential_surrogate_gradient( const double r, + const double v_m, + const double v_th_adapt, + const double V_th, + const double beta, + const double gamma ) +{ + if ( r > 0 ) + { + return 0.0; + } + + if ( fabs( V_th ) < 1e-6 ) + { + throw BadProperty( + "Relative threshold voltage V_th-E_L ≠ 0 required if surrogate_gradient_function is \"piecewise_linear\"." ); + } + + return gamma * std::exp( -beta * std::fabs( v_m - v_th_adapt ) ); +} + +double +EpropArchivingNodeRecurrent::compute_fast_sigmoid_derivative_surrogate_gradient( const double r, + const double v_m, + const double v_th_adapt, + const double V_th, + const double beta, + const double gamma ) +{ + if ( r > 0 ) + { + return 0.0; + } + + return gamma * std::pow( 1.0 + beta * std::fabs( v_m - v_th_adapt ), -2 ); +} + +double +EpropArchivingNodeRecurrent::compute_arctan_surrogate_gradient( const double r, + const double v_m, + const double v_th_adapt, + const double V_th, + const double beta, + const double gamma ) +{ + if ( r > 0 ) + { + return 0.0; + } + + return gamma / M_PI * ( 1.0 / ( 1.0 + std::pow( beta * M_PI * ( v_m - v_th_adapt ), 2 ) ) ); +} + +void +EpropArchivingNodeRecurrent::emplace_new_eprop_history_entry( const long time_step ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + eprop_history_.emplace_back( time_step, 0.0, 0.0 ); +} + +void +EpropArchivingNodeRecurrent::write_surrogate_gradient_to_history( const long time_step, + const double surrogate_gradient ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + auto it_hist = get_eprop_history( time_step ); + it_hist->surrogate_gradient_ = surrogate_gradient; +} + +void +EpropArchivingNodeRecurrent::write_learning_signal_to_history( const long time_step, + const double learning_signal, + const bool has_norm_step ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + long shift = delay_out_rec_; + + if ( has_norm_step ) + { + shift += delay_rec_out_ + delay_out_norm_; + } + else + { + shift += get_delay_total(); + } + + + auto it_hist = get_eprop_history( time_step - shift ); + const auto it_hist_end = get_eprop_history( time_step - shift + delay_out_rec_ ); + + for ( ; it_hist != it_hist_end; ++it_hist ) + { + it_hist->learning_signal_ += learning_signal; + } +} + +void +EpropArchivingNodeRecurrent::write_firing_rate_reg_to_history( const long t_current_update, + const double f_target, + const double c_reg ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + const double update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); + const double dt = 
Time::get_resolution().get_ms(); + const long shift = Time::get_resolution().get_steps(); + + const double f_av = n_spikes_ / update_interval; + const double f_target_ = f_target * dt; // convert from spikes/ms to spikes/step + const double firing_rate_reg = c_reg * ( f_av - f_target_ ) / update_interval; + + firing_rate_reg_history_.emplace_back( t_current_update + shift, firing_rate_reg ); +} + +void +EpropArchivingNodeRecurrent::write_firing_rate_reg_to_history( const long t, + const double z, + const double f_target, + const double kappa, + const double c_reg ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + const double dt = Time::get_resolution().get_ms(); + + const double f_target_ = f_target * dt; // convert from spikes/ms to spikes/step + + f_av_ = kappa * f_av_ + ( 1.0 - kappa ) * z / dt; + + firing_rate_reg_ = c_reg * ( f_av_ - f_target_ ); + + auto it_hist = get_eprop_history( t ); + it_hist->learning_signal_ += firing_rate_reg_; +} + +std::vector< HistEntryEpropFiringRateReg >::iterator +EpropArchivingNodeRecurrent::get_firing_rate_reg_history( const long time_step ) +{ + const auto it_hist = std::lower_bound( firing_rate_reg_history_.begin(), firing_rate_reg_history_.end(), time_step ); + assert( it_hist != firing_rate_reg_history_.end() ); + + return it_hist; +} + +double +EpropArchivingNodeRecurrent::get_learning_signal_from_history( const long time_step, const bool has_norm_step ) +{ + long shift = delay_rec_out_ + delay_out_rec_; + + if ( has_norm_step ) + { + shift += delay_out_norm_; + } + + const auto it = get_eprop_history( time_step - shift ); + if ( it == eprop_history_.end() ) + { + return 0; + } + + return it->learning_signal_; +} + +void +EpropArchivingNodeRecurrent::erase_used_firing_rate_reg_history() +{ + auto it_update_hist = update_history_.begin(); + auto it_reg_hist = firing_rate_reg_history_.begin(); + + while ( it_update_hist != update_history_.end() and it_reg_hist != firing_rate_reg_history_.end() ) + { + if ( it_update_hist->access_counter_ == 0 ) + { + it_reg_hist = firing_rate_reg_history_.erase( it_reg_hist ); + } + else + { + ++it_reg_hist; + } + ++it_update_hist; + } +} + +EpropArchivingNodeReadout::EpropArchivingNodeReadout() + : EpropArchivingNode() +{ +} + +EpropArchivingNodeReadout::EpropArchivingNodeReadout( const EpropArchivingNodeReadout& n ) + : EpropArchivingNode( n ) +{ +} + +void +EpropArchivingNodeReadout::emplace_new_eprop_history_entry( const long time_step, const bool has_norm_step ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + const long shift = has_norm_step ? delay_out_norm_ : 0; + + eprop_history_.emplace_back( time_step - shift, 0.0 ); +} + +void +EpropArchivingNodeReadout::write_error_signal_to_history( const long time_step, + const double error_signal, + const bool has_norm_step ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + const long shift = has_norm_step ? delay_out_norm_ : 0; + + auto it_hist = get_eprop_history( time_step - shift ); + it_hist->error_signal_ = error_signal; +} + + +} // namespace nest diff --git a/nestkernel/eprop_archiving_node.h b/nestkernel/eprop_archiving_node.h index 04cfc2d3ba..a213fb5af3 100644 --- a/nestkernel/eprop_archiving_node.h +++ b/nestkernel/eprop_archiving_node.h @@ -111,8 +111,33 @@ class EpropArchivingNode : public Node * * Retrieves the size of the eprop history buffer. */ + double get_eprop_history_duration() const; + /** + * @brief Updates multiple entries in the presynaptic buffer. + * + * Used when the total synaptic delay is greater than one. 
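+ *
+ * Sketch of the intended behaviour (illustrative): the queue acts as a first-in, first-out
+ * delay line. At each step of the gradient integration the oldest entry is popped into the
+ * spiking variable z, and a new entry is pushed that is 1.0 if the current step immediately
+ * precedes the presynaptic spike being processed (t_spike - t == 1) and 0.0 otherwise.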
+ */ + void update_pre_syn_buffer_multiple_entries( double& z, + double& z_current, + double& z_previous, + std::queue< double >& z_previous_buffer, + double t_spike, + double t ); + + /** + * @brief Updates one entry in the presynaptic buffer. + * + * Used when the total synaptic delay equals one. + */ + void update_pre_syn_buffer_one_entry( double& z, + double& z_current, + double& z_previous, + std::queue< double >& z_previous_buffer, + double t_spike, + double t ); + protected: //! Returns correct shift for history depending on whether it is a normal or a bsshslm_2020 model. virtual long model_dependent_history_shift_() const = 0; diff --git a/nestkernel/eprop_archiving_node_impl.h b/nestkernel/eprop_archiving_node_impl.h index 7c4c60a52e..697d496aa8 100644 --- a/nestkernel/eprop_archiving_node_impl.h +++ b/nestkernel/eprop_archiving_node_impl.h @@ -189,6 +189,47 @@ EpropArchivingNode< HistEntryT >::get_eprop_history_duration() const return Time::get_resolution().get_ms() * eprop_history_.size(); } + +template < typename HistEntryT > +void +EpropArchivingNode< HistEntryT >::update_pre_syn_buffer_multiple_entries( double& z, + double& z_current, + double& z_previous, + std::queue< double >& z_previous_buffer, + double t_spike, + double t ) +{ + if ( not z_previous_buffer.empty() ) + { + z = z_previous_buffer.front(); + z_previous_buffer.pop(); + } + + if ( t_spike - t > 1 ) + { + z_previous_buffer.push( 0.0 ); + } + else + { + z_previous_buffer.push( 1.0 ); + } +} + +template < typename HistEntryT > +void +EpropArchivingNode< HistEntryT >::update_pre_syn_buffer_one_entry( double& z, + double& z_current, + double& z_previous, + std::queue< double >& pre_syn_buffer, + double t_spike, + double t ) +{ + z = z_previous; + z_previous = z_current; + z_current = 0.0; +} + + } // namespace nest #endif // EPROP_ARCHIVING_NODE_IMPL_H diff --git a/nestkernel/eprop_archiving_node_readout.h b/nestkernel/eprop_archiving_node_readout.h index 97fd67ca27..ec20a52433 100644 --- a/nestkernel/eprop_archiving_node_readout.h +++ b/nestkernel/eprop_archiving_node_readout.h @@ -129,7 +129,7 @@ EpropArchivingNodeReadout< hist_shift_required >::model_dependent_history_shift_ } else { - return -delay_rec_out_; + return -get_delay_total(); } } diff --git a/nestkernel/eprop_archiving_node_recurrent_impl.h b/nestkernel/eprop_archiving_node_recurrent_impl.h index 58dde7a87e..12a4ebba4a 100644 --- a/nestkernel/eprop_archiving_node_recurrent_impl.h +++ b/nestkernel/eprop_archiving_node_recurrent_impl.h @@ -184,11 +184,15 @@ EpropArchivingNodeRecurrent< hist_shift_required >::write_learning_signal_to_his return; } - long shift = delay_rec_out_ + delay_out_rec_; + long shift = delay_out_rec_; if constexpr ( hist_shift_required ) { - shift += delay_out_norm_; + shift += delay_rec_out_ + delay_out_norm_; + } + else + { + shift += get_delay_total(); } diff --git a/nestkernel/nest_names.cpp b/nestkernel/nest_names.cpp index 2b1d98435e..5d887da3e8 100644 --- a/nestkernel/nest_names.cpp +++ b/nestkernel/nest_names.cpp @@ -127,6 +127,8 @@ const Name dead_time( "dead_time" ); const Name dead_time_random( "dead_time_random" ); const Name dead_time_shape( "dead_time_shape" ); const Name delay( "delay" ); +const Name delay_out_rec( "delay_out_rec" ); +const Name delay_rec_out( "delay_rec_out" ); const Name delay_u_bars( "delay_u_bars" ); const Name deliver_interval( "deliver_interval" ); const Name delta( "delta" ); diff --git a/nestkernel/nest_names.h b/nestkernel/nest_names.h index 44a75f62d0..3877274a33 100644 --- 
a/nestkernel/nest_names.h +++ b/nestkernel/nest_names.h @@ -154,6 +154,8 @@ extern const Name dead_time; extern const Name dead_time_random; extern const Name dead_time_shape; extern const Name delay; +extern const Name delay_out_rec; +extern const Name delay_rec_out; extern const Name delay_u_bars; extern const Name deliver_interval; extern const Name delta; diff --git a/nestkernel/node.cpp b/nestkernel/node.cpp index 6f54cc1075..ce4781dc79 100644 --- a/nestkernel/node.cpp +++ b/nestkernel/node.cpp @@ -230,6 +230,24 @@ Node::get_shift() const throw IllegalConnection( "The target node is not an e-prop neuron." ); } +long +Node::get_delay_total() const +{ + throw IllegalConnection( "The target node is not an e-prop neuron." ); +} + +long +Node::get_delay_recurrent_to_readout() const +{ + throw IllegalConnection( "The target node is not an e-prop neuron." ); +} + +long +Node::get_delay_readout_to_recurrent() const +{ + throw IllegalConnection( "The target node is not an e-prop neuron." ); +} + void Node::write_update_to_history( const long, const long, const long ) { @@ -552,6 +570,7 @@ nest::Node::get_tau_syn_in( int ) void nest::Node::compute_gradient( const long, const long, + std::queue< double >&, double&, double&, double&, diff --git a/nestkernel/node.h b/nestkernel/node.h index 5c41113f43..c10bc35c7a 100644 --- a/nestkernel/node.h +++ b/nestkernel/node.h @@ -26,6 +26,7 @@ // C++ includes: #include #include +#include #include #include #include @@ -532,6 +533,30 @@ class Node */ virtual long get_eprop_isi_trace_cutoff() const; + /** + * Get sum of connection delays from recurrent to output neuron and output to recurrent neuron. + * + * @throws IllegalConnection + */ + + virtual long get_delay_total() const; + + /** + * Get connection delay from recurrent to output neuron. + * + * @throws IllegalConnection + */ + + virtual long get_delay_recurrent_to_readout() const; + + /** + * Get connection delay from output to recurrent neuron. + * + * @throws IllegalConnection + */ + + virtual long get_delay_readout_to_recurrent() const; + /** * Checks if the node is part of the recurrent network and thus not a readout neuron. * @@ -836,6 +861,7 @@ class Node * @param t_spike [in] Time of the current spike. * @param t_spike_previous [in] Time of the previous spike. * @param z_previous_buffer [in, out] Value of presynaptic spiking variable from previous time step. + * @param z_previous * @param z_bar [in, out] Filtered presynaptic spiking variable. * @param e_bar [in, out] Filtered eligibility trace. * @param e_bar_reg [in, out] Filtered eligibility trace for firing rate regularization. @@ -847,7 +873,8 @@ class Node */ virtual void compute_gradient( const long t_spike, const long t_spike_previous, - double& z_previous_buffer, + std::queue< double >& z_previous_buffer, + double& z_previous, double& z_bar, double& e_bar, double& e_bar_reg, diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py new file mode 100644 index 0000000000..8232b4d97f --- /dev/null +++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py @@ -0,0 +1,852 @@ +# -*- coding: utf-8 -*- +# +# eprop_supervised_classification_evidence-accumulation.py +# +# This file is part of NEST. 
+# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +r""" +Tutorial on learning to accumulate evidence with e-prop +------------------------------------------------------- + +Training a classification model using supervised e-prop plasticity to accumulate evidence. + +Description +~~~~~~~~~~~ + +This script demonstrates supervised learning of a classification task with the eligibility propagation (e-prop) +plasticity mechanism by Bellec et al. [1]_. + +This type of learning is demonstrated at the proof-of-concept task in [1]_. We based this script on their +TensorFlow script given in [2]_. + +The task, a so-called evidence accumulation task, is inspired by behavioral tasks, where a lab animal (e.g., a +mouse) runs along a track, gets cues on the left and right, and has to decide at the end of the track between +taking a left and a right turn of which one is correct. After a number of iterations, the animal is able to +infer the underlying rationale of the task. Here, the solution is to turn to the side in which more cues were +presented. + +.. image:: ../../../../pynest/examples/eprop_plasticity/eprop_supervised_classification_schematic_evidence-accumulation.png + :width: 70 % + :alt: Schematic of network architecture. Same as Figure 1 in the code. + :align: center + +Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity. +This plasticity rule requires a specific network architecture depicted in Figure 1. The neural network model +consists of a recurrent network that receives input from Poisson generators and projects onto two readout +neurons - one for the left and one for the right turn at the end. The input neuron population consists of four +groups: one group providing background noise of a specific rate for some base activity throughout the +experiment, one group providing the input spikes of the left cues and one group providing them for the right +cues, and a last group defining the recall window, in which the network has to decide. The readout neuron +compares the network signal :math:`\pi_k` with the teacher target signal :math:`\pi_k^*`, which it receives from +a rate generator. Since the decision is at the end and all the cues are relevant, the network has to keep the +cues in memory. Additional adaptive neurons in the network enable this memory. The network's training error is +assessed by employing a mean squared error loss. + +Details on the event-based NEST implementation of e-prop can be found in [3]_. + +References +~~~~~~~~~~ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, Maass W (2020). A solution to the + learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_evidence_accumulation_with_alif.py + +.. 
[3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, van Albada SJ, Plesser HE, Bolten M, Diesmann M. + Event-based implementation of eligibility propagation (in preparation) +""" # pylint: disable=line-too-long # noqa: E501 + +# %% ########################################################################################################### +# Import libraries +# ~~~~~~~~~~~~~~~~ +# We begin by importing all libraries required for the simulation, analysis, and visualization. + +import matplotlib as mpl +import matplotlib.pyplot as plt +import nest +import numpy as np +from cycler import cycler +from IPython.display import Image + +# %% ########################################################################################################### +# Schematic of network architecture +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# This figure, identical to the one in the description, shows the required network architecture in the center, +# the input and output of the evidence accumulation task above, and lists of the required NEST device, neuron, +# and synapse models below. The connections that must be established are numbered 1 to 7. + +try: + Image(filename="./eprop_supervised_classification_schematic_evidence-accumulation.png") +except Exception: + pass + +# %% ########################################################################################################### +# Setup +# ~~~~~ + +# %% ########################################################################################################### +# Initialize random generator +# ........................... +# We seed the numpy random generator, which will generate random initial weights as well as random input and +# output. + +rng_seed = 1 # numpy random seed +np.random.seed(rng_seed) # fix numpy random seed + +# %% ########################################################################################################### +# Define timing of task +# ..................... +# The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. +# Even though each sample is processed independently during training, we aggregate predictions and true +# labels across a group of samples during the evaluation phase. The number of samples in this group is +# determined by the `group_size` parameter. This data is then used to assess the neural network's +# performance metrics, such as average accuracy and mean error. Increasing the number of iterations enhances +# learning performance up to the point where overfitting occurs. + +group_size = 32 # number of instances over which to evaluate the learning performance +n_iter = 50 # number of iterations + +n_input_symbols = 4 # number of input populations, e.g. 
4 = left, right, recall, noise +n_cues = 7 # number of cues given before decision +prob_group = 0.3 # probability with which one input group is present + +steps = { + "cue": 100, # time steps in one cue presentation + "spacing": 50, # time steps of break between two cues + "bg_noise": 1050, # time steps of background noise + "recall": 150, # time steps of recall + "delay_rec_out": 1, # time steps of connection delay from recurrent to output neurons + "delay_out_rec": 1, # time steps of broadcast delay of learning signals +} + +steps["cues"] = n_cues * (steps["cue"] + steps["spacing"]) # time steps of all cues +steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] # time steps of one full sequence +steps["learning_window"] = steps["recall"] # time steps of window with non-zero learning signals +steps["task"] = n_iter * group_size * steps["sequence"] # time steps of task + +steps.update( + { + "offset_gen": 1, # offset since generator signals start from time step 1 + "delay_in_rec": 1, # connection delay between input and recurrent neurons + "extension_sim": 3, # extra time step to close right-open simulation time interval in Simulate() + } +) + +steps["delays"] = steps["delay_in_rec"] # time steps of delays + +steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset + +steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] # time steps of simulation + +duration = {"step": 1.0} # ms, temporal resolution of the simulation + +duration.update({key: value * duration["step"] for key, value in steps.items()}) # ms, durations + +# %% ########################################################################################################### +# Set up simulation +# ................. +# As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and +# objects and set some NEST kernel parameters. + +params_setup = { + "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell + "resolution": duration["step"], + "total_num_virtual_procs": 1, # number of virtual processes, set in case of distributed computing +} + +#################### + +nest.ResetKernel() +nest.set(**params_setup) + +# %% ########################################################################################################### +# Create neurons +# ~~~~~~~~~~~~~~ +# We proceed by creating a certain number of input, recurrent, and readout neurons and setting their parameters. +# Additionally, we already create an input spike generator and an output target rate generator, which we will +# configure later. Within the recurrent network, alongside a population of regular neurons, we introduce a +# population of adaptive neurons, to enhance the network's memory retention. 
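+
+# As a purely illustrative check that is not required for the training, we can query the built-in defaults of
+# the e-prop neuron models used below for the new connection-delay parameters delay_rec_out and delay_out_rec
+# (returned in ms) before overriding them with our own values.
+
+for model in ["eprop_iaf", "eprop_iaf_adapt", "eprop_readout"]:
+    delay_rec_out_default, delay_out_rec_default = nest.GetDefaults(model, ["delay_rec_out", "delay_out_rec"])
+    print(f"{model}: delay_rec_out = {delay_rec_out_default} ms, delay_out_rec = {delay_out_rec_default} ms")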
+ +n_in = 40 # number of input neurons +n_ad = 50 # number of adaptive neurons +n_reg = 50 # number of regular neurons +n_rec = n_ad + n_reg # number of recurrent neurons +n_out = 2 # number of readout neurons + +params_nrn_out = { + "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) + "E_L": 0.0, # mV, leak / resting membrane potential + "eprop_isi_trace_cutoff": 100, # cutoff of integration of eprop trace between spikes + "I_e": 0.0, # pA, external current input + "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "tau_m": 20.0, # ms, membrane time constant + "V_m": 0.0, # mV, initial value of the membrane voltage + "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals + "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons +} + +params_nrn_reg = { + "beta": 1.0, # width scaling of the pseudo-derivative + "C_m": 1.0, + "c_reg": 300.0 / duration["sequence"] * duration["learning_window"], # firing rate regularization scaling + "E_L": 0.0, + "eprop_isi_trace_cutoff": 100, + "f_target": 10.0, # spikes/s, target firing rate for firing rate regularization + "gamma": 0.3, # height scaling of the pseudo-derivative + "I_e": 0.0, + "regular_spike_arrival": True, + "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function + "t_ref": 5.0, # ms, duration of refractory period + "tau_m": 20.0, + "V_m": 0.0, + "V_th": 0.6, # mV, spike threshold membrane voltage + "kappa": 0.97, # low-pass filter of the eligibility trace + "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals + "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons +} + +params_nrn_ad = { + "beta": 1.0, + "adapt_tau": 2000.0, # ms, time constant of adaptive threshold + "adaptation": 0.0, # initial value of the spike threshold adaptation + "C_m": 1.0, + "c_reg": 300.0 / duration["sequence"] * duration["learning_window"], + "E_L": 0.0, + "eprop_isi_trace_cutoff": 100, # cutoff of integration of eprop trace between spikes + "f_target": 10.0, + "gamma": 0.3, + "I_e": 0.0, + "regular_spike_arrival": True, + "surrogate_gradient_function": "piecewise_linear", + "t_ref": 5.0, + "tau_m": 20.0, + "V_m": 0.0, + "V_th": 0.6, + "kappa": 0.97, + "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals + "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons +} + +params_nrn_ad["adapt_beta"] = 1.7 * ( + (1.0 - np.exp(-duration["step"] / params_nrn_ad["adapt_tau"])) + / (1.0 - np.exp(-duration["step"] / params_nrn_ad["tau_m"])) +) # prefactor of adaptive threshold + +#################### + +# Intermediate parrot neurons required between input spike generators and recurrent neurons, +# since devices cannot establish plastic synapses for technical reasons + +gen_spk_in = nest.Create("spike_generator", n_in) +nrns_in = nest.Create("parrot_neuron", n_in) + +nrns_reg = nest.Create("eprop_iaf", n_reg, params_nrn_reg) +nrns_ad = nest.Create("eprop_iaf_adapt", n_ad, params_nrn_ad) +nrns_out = nest.Create("eprop_readout", n_out, params_nrn_out) +gen_rate_target = nest.Create("step_rate_generator", n_out) +gen_learning_window = nest.Create("step_rate_generator") + +nrns_rec = nrns_reg + nrns_ad + +# %% 
########################################################################################################### +# Create recorders +# ~~~~~~~~~~~~~~~~ +# We also create recorders, which, while not required for the training, will allow us to track various dynamic +# variables of the neurons, spikes, and changes in synaptic weights. To save computing time and memory, the +# recorders, the recorded variables, neurons, and synapses can be limited to the ones relevant to the +# experiment, and the recording interval can be increased (see the documentation on the specific recorders). By +# default, recordings are stored in memory but can also be written to file. + +n_record = 1 # number of neurons per type to record dynamic variables from - this script requires n_record >= 1 +n_record_w = 5 # number of senders and targets to record weights from - this script requires n_record_w >=1 + +if n_record == 0 or n_record_w == 0: + raise ValueError("n_record and n_record_w >= 1 required") + +params_mm_reg = { + "interval": duration["step"], # interval between two recorded time points + "record_from": ["V_m", "surrogate_gradient", "learning_signal"], # dynamic variables to record + "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording + "label": "multimeter_reg", +} + +params_mm_ad = { + "interval": duration["step"], + "record_from": params_mm_reg["record_from"] + ["V_th_adapt", "adaptation"], + "start": duration["offset_gen"] + duration["delay_in_rec"], + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], + "label": "multimeter_ad", +} + +params_mm_out = { + "interval": duration["step"], + "record_from": ["V_m", "readout_signal", "target_signal", "error_signal"], + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], + "label": "multimeter_out", +} + +params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], # limit senders to subsample weights to record + "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], + "label": "weight_recorder", +} + +params_sr_in = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_in", +} + +params_sr_reg = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_reg", +} + +params_sr_ad = { + "start": duration["offset_gen"], + "stop": duration["total_offset"] + duration["task"], + "label": "spike_recorder_ad", +} + +#################### + +mm_reg = nest.Create("multimeter", params_mm_reg) +mm_ad = nest.Create("multimeter", params_mm_ad) +mm_out = nest.Create("multimeter", params_mm_out) +sr_in = nest.Create("spike_recorder", params_sr_in) +sr_reg = nest.Create("spike_recorder", params_sr_reg) +sr_ad = nest.Create("spike_recorder", params_sr_ad) +wr = nest.Create("weight_recorder", params_wr) + +nrns_reg_record = nrns_reg[:n_record] +nrns_ad_record = nrns_ad[:n_record] + +# %% ########################################################################################################### +# Create connections +# ~~~~~~~~~~~~~~~~~~ +# Now, we define the connectivity and set up the synaptic parameters, with the synaptic weights drawn from +# normal distributions. 
After these preparations, we establish the enumerated connections of the core network, +# as well as additional connections to the recorders. + +params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} +params_conn_one_to_one = {"rule": "one_to_one"} + + +def calculate_glorot_dist(fan_in, fan_out): + glorot_scale = 1.0 / max(1.0, (fan_in + fan_out) / 2.0) + glorot_limit = np.sqrt(3.0 * glorot_scale) + glorot_distribution = np.random.uniform(low=-glorot_limit, high=glorot_limit, size=(fan_in, fan_out)) + return glorot_distribution + + +dtype_weights = np.float32 # data type of weights - for reproducing TF results set to np.float32 +weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) +weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) +np.fill_diagonal(weights_rec_rec, 0.0) # since no autapses set corresponding weights to zero +weights_rec_out = np.array(calculate_glorot_dist(n_rec, n_out).T, dtype=dtype_weights) +weights_out_rec = np.array(np.random.randn(n_rec, n_out), dtype=dtype_weights) + +params_common_syn_eprop = { + "optimizer": { + "type": "adam", # algorithm to optimize the weights + "batch_size": 1, + "beta_1": 0.9, # exponential decay rate for 1st moment estimate of Adam optimizer + "beta_2": 0.999, # exponential decay rate for 2nd moment raw estimate of Adam optimizer + "epsilon": 1e-8, # small numerical stabilization constant of Adam optimizer + "eta": 5e-3 / duration["learning_window"], # learning rate + "optimize_each_step": True, # call optimizer every time step (True) or once per spike (False); only + # True implements original Adam algorithm, False offers speed-up; choice can affect learning performance + "Wmin": -100.0, # pA, minimal limit of the synaptic weights + "Wmax": 100.0, # pA, maximal limit of the synaptic weights + }, + "weight_recorder": wr, +} + +params_syn_base = { + "synapse_model": "eprop_synapse", + "delay": duration["step"], # ms, dendritic delay +} + +params_syn_in = params_syn_base.copy() +params_syn_in["weight"] = weights_in_rec # pA, initial values for the synaptic weights + +params_syn_rec = params_syn_base.copy() +params_syn_rec["weight"] = weights_rec_rec + +params_syn_out = params_syn_base.copy() +params_syn_out["weight"] = weights_rec_out +params_syn_out["delay"] = duration["delay_rec_out"] + +params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection", + "delay": duration["delay_out_rec"], + "weight": weights_out_rec, +} + +params_syn_learning_window = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 1, # receptor type over which readout neuron receives learning window signal +} + +params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, # receptor type over which readout neuron receives target signal +} + +params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], +} + +params_init_optimizer = { + "optimizer": { + "m": 0.0, # initial 1st moment estimate m of Adam optimizer + "v": 0.0, # initial 2nd moment raw estimate v of Adam optimizer + } +} + +#################### + +nest.SetDefaults("eprop_synapse", params_common_syn_eprop) + +nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1 +nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) # connection 2 +nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) # 
connection 3 +nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) # connection 4 +nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 +nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 +nest.Connect(gen_learning_window, nrns_out, params_conn_all_to_all, params_syn_learning_window) # connection 7 + +nest.Connect(nrns_in, sr_in, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_reg, sr_reg, params_conn_all_to_all, params_syn_static) +nest.Connect(nrns_ad, sr_ad, params_conn_all_to_all, params_syn_static) + +nest.Connect(mm_reg, nrns_reg_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_ad, nrns_ad_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + +# After creating the connections, we can individually initialize the optimizer's +# dynamic variables for single synapses (here exemplarily for two connections). + +nest.GetConnections(nrns_rec[0], nrns_rec[1:3]).set([params_init_optimizer] * 2) + +# %% ########################################################################################################### +# Create input and output +# ~~~~~~~~~~~~~~~~~~~~~~~ +# We generate the input as four neuron populations, two producing the left and right cues, respectively, one the +# recall signal and one the background input throughout the task. The sequence of cues is drawn with a +# probability that favors one side. For each such sequence, the favored side, the solution or target, is +# assigned randomly to the left or right. + + +def generate_evidence_accumulation_input_output( + batch_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps +): + n_pop_nrn = n_in // n_input_symbols + + prob_choices = np.array([prob_group, 1 - prob_group], dtype=np.float32) + idx = np.random.choice([0, 1], batch_size) + probs = np.zeros((batch_size, 2), dtype=np.float32) + probs[:, 0] = prob_choices[idx] + probs[:, 1] = prob_choices[1 - idx] + + batched_cues = np.zeros((batch_size, n_cues), dtype=int) + for b_idx in range(batch_size): + batched_cues[b_idx, :] = np.random.choice([0, 1], n_cues, p=probs[b_idx]) + + input_spike_probs = np.zeros((batch_size, steps["sequence"], n_in)) + + for b_idx in range(batch_size): + for c_idx in range(n_cues): + cue = batched_cues[b_idx, c_idx] + + step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"] + step_stop = step_start + steps["cue"] + + pop_nrn_start = cue * n_pop_nrn + pop_nrn_stop = pop_nrn_start + n_pop_nrn + + input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input_spike_prob + + input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input_spike_prob + input_spike_probs[:, :, 3 * n_pop_nrn :] = input_spike_prob / 4.0 + input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape) + input_spike_bools[:, 0, :] = 0 # remove spikes in 0th time step of every sequence for technical reasons + + target_cues = np.zeros(batch_size, dtype=int) + target_cues[:] = np.sum(batched_cues, axis=1) > int(n_cues / 2) + + return input_spike_bools, target_cues + + +input_spike_prob = 0.04 # spike probability of frozen input noise +dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32 + +input_spike_bools_list = [] +target_cues_list = [] + +for _ in range(n_iter): + input_spike_bools, target_cues = 
generate_evidence_accumulation_input_output( + group_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps + ) + input_spike_bools_list.append(input_spike_bools) + target_cues_list.extend(target_cues) + +input_spike_bools_arr = np.array(input_spike_bools_list).reshape(steps["task"], n_in) +timeline_task = np.arange(0.0, duration["task"], duration["step"]) + duration["offset_gen"] + +params_gen_spk_in = [ + {"spike_times": timeline_task[input_spike_bools_arr[:, nrn_in_idx]].astype(dtype_in_spks)} + for nrn_in_idx in range(n_in) +] + +target_rate_changes = np.zeros((n_out, group_size * n_iter)) +target_rate_changes[np.array(target_cues_list), np.arange(group_size * n_iter)] = 1 + +params_gen_rate_target = [ + { + "amplitude_times": np.arange(0.0, duration["task"], duration["sequence"]) + duration["total_offset"], + "amplitude_values": target_rate_changes[nrn_out_idx], + } + for nrn_out_idx in range(n_out) +] + +#################### + +nest.SetStatus(gen_spk_in, params_gen_spk_in) +nest.SetStatus(gen_rate_target, params_gen_rate_target) + +# %% ########################################################################################################### +# Create learning window +# ~~~~~~~~~~~~~~~~~~~~~~ +# Custom learning windows, in which the network learns, can be defined with an additional signal. The error +# signal is internally multiplied with this learning window signal. Passing a learning window signal of value 1 +# opens the learning window while passing a value of 0 closes it. + +amplitude_times = np.hstack( + [ + np.array([0.0, duration["sequence"] - duration["learning_window"]]) + + duration["total_offset"] + + i * duration["sequence"] + for i in range(group_size * n_iter) + ] +) + +amplitude_values = np.array([0.0, 1.0] * group_size * n_iter) + +params_gen_learning_window = { + "amplitude_times": amplitude_times, + "amplitude_values": amplitude_values, +} + +#################### + +nest.SetStatus(gen_learning_window, params_gen_learning_window) + +# %% ########################################################################################################### +# Force final update +# ~~~~~~~~~~~~~~~~~~ +# Synapses only get active, that is, the correct weight update calculated and applied, when they transmit a +# spike. To still be able to read out the correct weights at the end of the simulation, we force spiking of the +# presynaptic neuron and thus an update of all synapses, including those that have not transmitted a spike in +# the last update interval, by sending a strong spike to all neurons that form the presynaptic side of an eprop +# synapse. This step is required purely for technical reasons. + +gen_spk_final_update = nest.Create("spike_generator", 1, {"spike_times": [duration["task"] + duration["delays"]]}) + +nest.Connect(gen_spk_final_update, nrns_in + nrns_rec, "all_to_all", {"weight": 1000.0}) + +# %% ########################################################################################################### +# Read out pre-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Before we begin training, we read out the initial weight matrices so that we can eventually compare them to +# the optimized weights. 
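+
+# A single weight can also be read directly from a connection object. This minimal illustration only assumes
+# the all-to-all input-to-recurrent connections created above; the helper below generalizes the query to full
+# weight matrices.
+
+conn_example = nest.GetConnections(nrns_in[0], nrns_rec[0])
+print("initial weight of one input-to-recurrent connection (pA):", conn_example.get("weight"))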
+ + +def get_weights(pop_pre, pop_post): + conns = nest.GetConnections(pop_pre, pop_post).get(["source", "target", "weight"]) + conns["senders"] = np.array(conns["source"]) - np.min(conns["source"]) + conns["targets"] = np.array(conns["target"]) - np.min(conns["target"]) + + conns["weight_matrix"] = np.zeros((len(pop_post), len(pop_pre))) + conns["weight_matrix"][conns["targets"], conns["senders"]] = conns["weight"] + return conns + + +weights_pre_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Simulate +# ~~~~~~~~ +# We train the network by simulating for a set simulation time, determined by the number of iterations and the +# batch size and the length of one sequence. + +nest.Simulate(duration["sim"]) + +# %% ########################################################################################################### +# Read out post-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# After the training, we can read out the optimized final weights. + +weights_post_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Read out recorders +# ~~~~~~~~~~~~~~~~~~ +# We can also retrieve the recorded history of the dynamic variables and weights, as well as detected spikes. + +events_mm_reg = mm_reg.get("events") +events_mm_ad = mm_ad.get("events") +events_mm_out = mm_out.get("events") +events_sr_in = sr_in.get("events") +events_sr_reg = sr_reg.get("events") +events_sr_ad = sr_ad.get("events") +events_wr = wr.get("events") + +# %% ########################################################################################################### +# Evaluate training error +# ~~~~~~~~~~~~~~~~~~~~~~~ +# We evaluate the network's training error by calculating a loss - in this case, the mean squared error between +# the integrated recurrent network activity and the target rate. + +readout_signal = events_mm_out["readout_signal"] +target_signal = events_mm_out["target_signal"] +senders = events_mm_out["senders"] + +readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) +target_signal = np.array([target_signal[senders == i] for i in set(senders)]) + +readout_signal = readout_signal.reshape((n_out, n_iter, group_size, steps["sequence"])) +target_signal = target_signal.reshape((n_out, n_iter, group_size, steps["sequence"])) + +readout_signal = readout_signal[:, :, :, -steps["learning_window"] :] +target_signal = target_signal[:, :, :, -steps["learning_window"] :] + +loss = 0.5 * np.mean(np.sum((readout_signal - target_signal) ** 2, axis=3), axis=(0, 2)) + +y_prediction = np.argmax(np.mean(readout_signal, axis=3), axis=0) +y_target = np.argmax(np.mean(target_signal, axis=3), axis=0) +accuracy = np.mean((y_target == y_prediction), axis=1) +recall_errors = 1.0 - accuracy + +# %% ########################################################################################################### +# Plot results +# ~~~~~~~~~~~~ +# Then, we plot a series of plots. 
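+
+# Before plotting, we can optionally print a compact summary of the training run. This sketch only reuses
+# quantities computed above: the per-iteration loss and recall errors and the pre- and post-training weights.
+
+print(f"loss:         first iteration = {loss[0]:.3f}, last iteration = {loss[-1]:.3f}")
+print(f"recall error: first iteration = {recall_errors[0]:.3f}, last iteration = {recall_errors[-1]:.3f}")
+
+for k in weights_pre_train.keys():
+    w_change = weights_post_train[k]["weight_matrix"] - weights_pre_train[k]["weight_matrix"]
+    print(f"{k}: mean absolute weight change = {np.mean(np.abs(w_change)):.4f} pA")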
+ +do_plotting = True # if True, plot the results + +if not do_plotting: + exit() + +colors = { + "blue": "#2854c5ff", + "red": "#e04b40ff", + "white": "#ffffffff", +} + +plt.rcParams.update( + { + "font.sans-serif": "Arial", + "axes.spines.right": False, + "axes.spines.top": False, + "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), + } +) + +# %% ########################################################################################################### +# Plot training error +# ................... +# We begin with two plots visualizing the training error of the network: the loss and the recall error, both +# plotted against the iterations. + +fig, axs = plt.subplots(2, 1, sharex=True) +fig.suptitle("Training error") + +axs[0].plot(range(1, n_iter + 1), loss) +axs[0].set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$") + +axs[1].plot(range(1, n_iter + 1), recall_errors) +axs[1].set_ylabel("recall errors") + +axs[-1].set_xlabel("training iteration") +axs[-1].set_xlim(1, n_iter) +axs[-1].xaxis.get_major_locator().set_params(integer=True) + +fig.tight_layout() + +# %% ########################################################################################################### +# Plot spikes and dynamic variables +# ................................. +# This plotting routine shows how to plot all of the recorded dynamic variables and spikes across time. We take +# one snapshot in the first iteration and one snapshot at the end. + + +def plot_recordable(ax, events, recordable, ylabel, xlims): + for sender in set(events["senders"]): + idc_sender = events["senders"] == sender + idc_times = (events["times"][idc_sender] > xlims[0]) & (events["times"][idc_sender] < xlims[1]) + ax.plot(events["times"][idc_sender][idc_times], events[recordable][idc_sender][idc_times], lw=0.5) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(events[recordable]) - np.min(events[recordable])) * 0.1 + ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) + + +def plot_spikes(ax, events, ylabel, xlims): + idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) + senders_subset = events["senders"][idc_times] + times_subset = events["times"][idc_times] + + ax.scatter(times_subset, senders_subset, s=0.1) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(senders_subset) - np.min(senders_subset)) * 0.1 + ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) + + +for title, xlims in zip( + ["Dynamic variables before training", "Dynamic variables after training"], + [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])], +): + fig, axs = plt.subplots(14, 1, sharex=True, figsize=(8, 14), gridspec_kw={"hspace": 0.4, "left": 0.2}) + fig.suptitle(title) + + plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr_reg, r"$z_j$" + "\n", xlims) + + plot_recordable(axs[2], events_mm_reg, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[3], events_mm_reg, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[4], events_mm_reg, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + + plot_spikes(axs[5], events_sr_ad, r"$z_j$" + "\n", xlims) + + plot_recordable(axs[6], events_mm_ad, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[7], events_mm_ad, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[8], events_mm_ad, "V_th_adapt", r"$A_j$" + "\n(mV)", xlims) + plot_recordable(axs[9], events_mm_ad, "learning_signal", r"$L_j$" + 
"\n(pA)", xlims) + + plot_recordable(axs[10], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) + plot_recordable(axs[11], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims) + plot_recordable(axs[12], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims) + plot_recordable(axs[13], events_mm_out, "error_signal", r"$y_k-y^*_k$" + "\n", xlims) + + axs[-1].set_xlabel(r"$t$ (ms)") + axs[-1].set_xlim(*xlims) + + fig.align_ylabels() + +# %% ########################################################################################################### +# Plot weight time courses +# ........................ +# Similarly, we can plot the weight histories. Note that the weight recorder, attached to the synapses, works +# differently than the other recorders. Since synapses only get activated when they transmit a spike, the weight +# recorder only records the weight in those moments. That is why the first weight registrations do not start in +# the first time step and we add the initial weights manually. + + +def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): + for sender in nrns_senders.tolist(): + for target in nrns_targets.tolist(): + idc_syn = (events["senders"] == sender) & (events["targets"] == target) + idc_syn_pre = (weights_pre_train[label]["source"] == sender) & ( + weights_pre_train[label]["target"] == target + ) + + times = [0.0] + events["times"][idc_syn].tolist() + weights = [weights_pre_train[label]["weight"][idc_syn_pre]] + events["weights"][idc_syn].tolist() + + ax.step(times, weights, c=colors["blue"]) + ax.set_ylabel(ylabel) + ax.set_ylim(-1.5, 1.5) + + +fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) +fig.suptitle("Weight time courses") + +plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course( + axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" +) +plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") + +axs[-1].set_xlabel(r"$t$ (ms)") +axs[-1].set_xlim(0, steps["task"]) + +fig.align_ylabels() +fig.tight_layout() + +# %% ########################################################################################################### +# Plot weight matrices +# .................... +# If one is not interested in the time course of the weights, it is possible to read out only the initial and +# final weights, which requires less computing time and memory than the weight recorder approach. Here, we plot +# the corresponding weight matrices before and after the optimization. 
+ +cmap = mpl.colors.LinearSegmentedColormap.from_list( + "cmap", ((0.0, colors["blue"]), (0.5, colors["white"]), (1.0, colors["red"])) +) + +fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") +fig.suptitle("Weight matrices") + +all_w_extrema = [] + +for k in weights_pre_train.keys(): + w_pre = weights_pre_train[k]["weight"] + w_post = weights_post_train[k]["weight"] + all_w_extrema.append([np.min(w_pre), np.max(w_pre), np.min(w_post), np.max(w_post)]) + +args = {"cmap": cmap, "vmin": np.min(all_w_extrema), "vmax": np.max(all_w_extrema)} + +for i, weights in zip([0, 1], [weights_pre_train, weights_post_train]): + axs[0, i].pcolormesh(weights["in_rec"]["weight_matrix"].T, **args) + axs[1, i].pcolormesh(weights["rec_rec"]["weight_matrix"], **args) + cmesh = axs[2, i].pcolormesh(weights["rec_out"]["weight_matrix"], **args) + + axs[2, i].set_xlabel("recurrent\nneurons") + +axs[0, 0].set_ylabel("input\nneurons") +axs[1, 0].set_ylabel("recurrent\nneurons") +axs[2, 0].set_ylabel("readout\nneurons") +fig.align_ylabels(axs[:, 0]) + +axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center") + +axs[2, 0].yaxis.get_major_locator().set_params(integer=True) + +cbar = plt.colorbar(cmesh, cax=axs[1, 1].inset_axes([1.1, 0.2, 0.05, 0.8]), label="weight (pA)") + +fig.tight_layout() + +plt.show() diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py index 84a0b990bf..592f079e35 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py @@ -134,6 +134,8 @@ steps = { "sequence": 300, # time steps of one full sequence "learning_window": 10, # time steps of window with non-zero learning signals + "delay_rec_out": 1, # time steps of connection delay from recurrent to output neurons + "delay_out_rec": 1, # time steps of broadcast delay of learning signals } steps.update( @@ -207,6 +209,8 @@ "I_e": 0.0, # pA, external current input "tau_m": 100.0, # ms, membrane time constant "V_m": 0.0, # mV, initial value of the membrane voltage + "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals + "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons } params_nrn_rec = { @@ -225,6 +229,8 @@ "tau_m": 30.0, "V_m": 0.0, "V_th": 0.6, # mV, spike threshold membrane voltage + "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals + "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons } scale_factor = 1.0 - params_nrn_rec["kappa"] # factor for rescaling due to removal of irregular spike arrival @@ -389,10 +395,11 @@ def get_weight_recorder_senders_targets(weights, sender_pop, target_pop): params_syn_in = params_syn_base.copy() params_syn_rec = params_syn_base.copy() params_syn_out = params_syn_base.copy() +params_syn_out["delay"] = duration["delay_rec_out"] params_syn_feedback = { "synapse_model": "eprop_learning_signal_connection", - "delay": duration["step"], + "delay": duration["delay_out_rec"], "weight": weights_out_rec, } @@ -592,6 +599,7 @@ def get_params_task_input_output(n_iter_interval, loader): + iteration_offset + group_element * duration["sequence"] + duration["total_offset"] + + 
duration["delay_rec_out"] - 1.0 for group_element in range(group_size) ] ), diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py index bf0a66e992..4413afb9d6 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py @@ -121,6 +121,8 @@ steps = { "sequence": 1000, # time steps of one full sequence + "delay_rec_out": 1, # time steps of connection delay from recurrent to output neurons + "delay_out_rec": 1, # time steps of broadcast delay of learning signals } steps["learning_window"] = steps["sequence"] # time steps of window with non-zero learning signals @@ -181,6 +183,8 @@ "I_e": 0.0, # pA, external current input "tau_m": 30.0, # ms, membrane time constant "V_m": 0.0, # mV, initial value of the membrane voltage + "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals + "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons } params_nrn_rec = { @@ -199,6 +203,8 @@ "tau_m": 30.0, "V_m": 0.0, "V_th": 0.03, # mV, spike threshold membrane voltage + "delay_out_rec": duration["delay_out_rec"], # ms, broadcast delay of learning signals + "delay_rec_out": duration["delay_rec_out"], # ms, connection delay from recurrent to output neurons } scale_factor = 1.0 - params_nrn_rec["kappa"] # factor for rescaling due to removal of irregular spike arrival @@ -335,10 +341,11 @@ params_syn_out = params_syn_base.copy() params_syn_out["weight"] = weights_rec_out +params_syn_out["delay"] = duration["delay_rec_out"] params_syn_feedback = { "synapse_model": "eprop_learning_signal_connection", - "delay": duration["step"], + "delay": duration["delay_out_rec"], "weight": weights_out_rec, } @@ -428,7 +435,7 @@ def generate_superimposed_sines(steps_sequence, periods): target_signal = generate_superimposed_sines(steps["sequence"], [1000, 500, 333, 200]) # periods in steps params_gen_rate_target = { - "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"], + "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"] + duration["delay_rec_out"] - 1.0, "amplitude_values": np.tile(target_signal, n_iter * group_size), }
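+
+# The target amplitude times are shifted by duration["delay_rec_out"] - 1.0 ms so that the target signal stays
+# aligned with the recurrent activity, which now reaches the readout neurons only after the configurable
+# delay_rec_out instead of after one fixed time step; for the default delay_rec_out of 1 ms the shift is zero.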