Skip to content

Commit

Permalink
Updated PR #844 changes to replace st[:,2] instead of new column
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Jan 15, 2025
1 parent b82c0f3 commit 169d3fb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
12 changes: 6 additions & 6 deletions kilosort/run_kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
ops : dict
Dictionary storing settings and results for all algorithmic steps.
st : np.ndarray
4-column array of peak time (in samples), template, amplitude, and
threshold amplitude for each spike.
3-column array of peak time (in samples), template, and thresold
amplitude for each spike.
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
same shape as `st[:,0]`.
Expand Down Expand Up @@ -620,7 +620,7 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
Returns
-------
st : np.ndarray
4-column array of peak time (in samples), template, amplitude, and threshold
3-column array of peak time (in samples), template, and thresold
amplitude for each spike.
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
Expand Down Expand Up @@ -690,7 +690,7 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
Parameters
----------
st : np.ndarray
4-column array of peak time (in samples), template, amplitud, and threshold
3-column array of peak time (in samples), template, and thresold
amplitude for each spike.
tF : torch.Tensor
PC features for each spike, with shape
Expand Down Expand Up @@ -760,7 +760,7 @@ def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,
results_dir : pathlib.Path
Directory where results should be saved.
st : np.ndarray
4-column array of peak time (in samples), template, amplitude, and thresold
3-column array of peak time (in samples), template, and thresold
amplitude for each spike.
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
Expand Down Expand Up @@ -887,7 +887,7 @@ def load_sorting(results_dir, device=None, load_extra_vars=False):
(n_clusters, n_channels, n_pcs).
full_st : np.ndarray.
Only returned if `load_extra_vars` is True.
3-column array of peak time (in samples), template, and amplitude for
3-column array of peak time (in samples), template, and threshold amplitude for
each spike.
Includes spikes removed by `kilosort.postprocessing.remove_duplicate_spikes`.
full_clu : np.ndarray.
Expand Down
9 changes: 4 additions & 5 deletions kilosort/template_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):

tiwave = torch.arange(-(nt//2), nt//2+1, device=device)
ctc = prepare_matching(ops, U)
st = np.zeros((10**6, 4), 'float64')
st = np.zeros((10**6, 3), 'float64')
tF = torch.zeros((10**6, nC , ops['settings']['n_pcs']))
k = 0
prog = tqdm(
Expand Down Expand Up @@ -102,8 +102,7 @@ def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):
stt = stt.double()
st[k:k+nsp,0] = ((stt[:,0]-nt) + ibatch * (ops['batch_size'])).cpu().numpy() - nt//2 + ops['nt0min']
st[k:k+nsp,1] = stt[:,1].cpu().numpy()
st[k:k+nsp,2] = amps.cpu().numpy().squeeze()
st[k:k+nsp,3] = th_amps.cpu().numpy().squeeze()
st[k:k+nsp,2] = th_amps.cpu().numpy().squeeze()

tF[k:k+nsp] = xfeat.transpose(0,1).cpu()

Expand Down Expand Up @@ -200,7 +199,7 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')):
Cf[:, -nt:] = 0

Cfmax, imax = torch.max(Cf, 0)
Cmax = max_pool1d(Cfmax.unsqueeze(0).unsqueeze(0), (2*nt+1), stride = 1, padding = (nt))
Cmax = max_pool1d(Cfmax.unsqueeze(0).unsqueeze(0), (2*nt+1), stride=1, padding=(nt))

#print(Cfmax.shape)
#import pdb; pdb.set_trace()
Expand All @@ -223,7 +222,7 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')):
st[k:k+nsp, 1] = iY[:,0]
amps[k:k+nsp] = B[iY,iX] / nm[iY]
amp = amps[k:k+nsp]
th_amps[k:k+nsp] = Cmax[0,0,iX[:,0],None]**.5
th_amps[k:k+nsp] = Cmax[0, 0, iX[:,0], None]**.5

k+= nsp

Expand Down

0 comments on commit 169d3fb

Please sign in to comment.