@@ -380,19 +380,7 @@ def _get_spike_timestamps(self, block_index, seg_index, spike_channel_index, t_s
380380
381381 spike_timestamps , unit_ids , waveforms = self ._spike_channel_cache [channel_name ]
382382
383- if t_start is not None or t_stop is not None :
384- # restrict spikes to given limits (in seconds)
385- timestamp_frequency = self .pl2reader .pl2_file_info .m_TimestampFrequency
386- lim0 = int (t_start * timestamp_frequency )
387- lim1 = int (t_stop * self .pl2reader .pl2_file_info .m_TimestampFrequency )
388-
389- # limits are with respect to segment t_start and not to time 0
390- lim0 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
391- lim1 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
392-
393- time_mask = (spike_timestamps >= lim0 ) & (spike_timestamps <= lim1 )
394- else :
395- time_mask = slice (None , None )
383+ time_mask = self ._get_timestamp_time_mask (t_start , t_stop , spike_timestamps )
396384
397385 unit_mask = unit_ids [time_mask ] == channel_unit_id
398386 spike_timestamps = spike_timestamps [time_mask ][unit_mask ]
@@ -425,25 +413,33 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index,
425413
426414 spike_timestamps , unit_ids , waveforms = self ._spike_channel_cache [channel_name ]
427415
416+ time_mask = self ._get_timestamp_time_mask (t_start , t_stop , spike_timestamps )
417+
418+ unit_mask = unit_ids [time_mask ] == int (channel_unit_id )
419+ waveforms = waveforms [time_mask ][unit_mask ]
420+
421+ # add tetrode dimension
422+ waveforms = np .expand_dims (waveforms , axis = 1 )
423+ return waveforms
424+
425+ def _get_timestamp_time_mask (self , t_start , t_stop , timestamps ):
426+
428427 if t_start is not None or t_stop is not None :
429428 # restrict spikes to given limits (in seconds)
430429 timestamp_frequency = self .pl2reader .pl2_file_info .m_TimestampFrequency
431430 lim0 = int (t_start * timestamp_frequency )
432431 lim1 = int (t_stop * self .pl2reader .pl2_file_info .m_TimestampFrequency )
433- time_mask = (spike_timestamps >= lim0 ) & (spike_timestamps <= lim1 )
434432
435433 # limits are with respect to segment t_start and not to time 0
436434 lim0 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
437435 lim1 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
436+
437+ time_mask = (timestamps >= lim0 ) & (timestamps <= lim1 )
438+
438439 else :
439440 time_mask = slice (None , None )
440441
441- unit_mask = unit_ids [time_mask ] == int (channel_unit_id )
442- waveforms = waveforms [time_mask ][unit_mask ]
443-
444- # add tetrode dimension
445- waveforms = np .expand_dims (waveforms , axis = 1 )
446- return waveforms
442+ return time_mask
447443
448444 def _event_count (self , block_index , seg_index , event_channel_index ):
449445
@@ -474,18 +470,7 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
474470 event_times , labels = self ._event_channel_cache [channel_name ]
475471 labels = np .asarray (labels , dtype = 'U' )
476472
477- if t_start is not None or t_stop is not None :
478- # restrict events to given limits (in seconds)
479- timestamp_frequency = self .pl2reader .pl2_file_info .m_TimestampFrequency
480- lim0 = int (t_start * timestamp_frequency )
481- lim1 = int (t_stop * self .pl2reader .pl2_file_info .m_TimestampFrequency )
482- time_mask = (event_times >= lim0 ) & (event_times <= lim1 )
483-
484- # limits are with respect to segment t_start and not to time 0
485- lim0 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
486- lim1 -= self .pl2reader .pl2_file_info .m_StartRecordingTime
487- else :
488- time_mask = np .ones_like (event_times )
473+ time_mask = self ._get_timestamp_time_mask (t_start , t_stop , event_times )
489474
490475 # events don't have a duration. Epochs are not supported
491476 durations = None
0 commit comments