Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions dev/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
log_to_file(fname+'.klg.'+str(shank), 'debug')
#log_suppress_hierarchy('klustakwik', inclusive=False)

if os.path.exists(fname+'.pickle'):
# if False:
# if os.path.exists(fname+'.pickle'):
if False:
start_time = time.time()
data = pickle.load(open(fname+'.pickle', 'rb'))
print('load from pickle:', time.time()-start_time)
else:
start_time = time.time()
raw_data = load_fet_fmask_to_raw(fname, shank, drop_last_n_features=1)
raw_data = load_fet_fmask_to_raw(fname, shank, drop_last_n_features=1,
use_fmask=False,
)
print('load_fet_fmask_to_raw:', time.time()-start_time)
data = raw_data.to_sparse_data()
pickle.dump(data, open(fname+'.pickle', 'wb'), -1)
Expand All @@ -41,7 +43,8 @@
# fast_split=True,
# max_split_iterations=10,
consider_cluster_deletion=True,
#num_cpus=1,
num_cpus=1,
num_starting_clusters=50,
)
# kk.register_callback(SaveCluEvery(fname, shank, every=1))
kk.register_callback(MonitoringServer())
Expand Down Expand Up @@ -73,7 +76,7 @@ def printclu_after(kk):
else:
print('Generating clusters from scratch')
#kk.cluster_with_subset_schedule(100, [0.99, 1.0])
kk.cluster_mask_starts()
kk.cluster_mask_or_random_starts()

# clusters = loadtxt('../temp/testsmallish.start.clu', skiprows=1, dtype=int)
# # dump_covariance_matrices(kk)
Expand Down
14 changes: 14 additions & 0 deletions klustakwik2/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,20 @@ def cluster_mask_starts(self):
clusters = mask_starts(self.data, self.num_starting_clusters, self.num_special_clusters)
self.cluster_from(clusters)

# Start clustering from a uniformly random assignment: each spike gets a
# cluster id drawn from [num_special_clusters,
# num_special_clusters + num_starting_clusters).  randint is presumably
# numpy.random.randint (exclusive upper bound) -- TODO confirm against the
# module's imports.  Special clusters (noise/MUA) are left unpopulated.
def cluster_random_starts(self):
clusters = randint(self.num_special_clusters,
self.num_special_clusters+self.num_starting_clusters,
size=self.data.num_spikes)
self.cluster_from(clusters)

# Choose an initialisation strategy: mask-based starts need at least
# num_starting_clusters distinct masks to be meaningful, so when the data
# has fewer masks than that (e.g. when masks were disabled and every spike
# shares one all-ones mask), fall back to random starts instead.
def cluster_mask_or_random_starts(self):
if self.data.num_masks<self.num_starting_clusters:
self.log('info', 'Using random starts')
self.cluster_random_starts()
else:
self.log('info', 'Using mask starts')
self.cluster_mask_starts()

def cluster_from(self, clusters, recurse=True, score_target=-inf):
self.log('info', 'Clustering data set of %d points, %d features' % (self.data.num_spikes,
self.data.num_features))
Expand Down
53 changes: 40 additions & 13 deletions klustakwik2/input_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,46 @@
__all__ = ['load_fet_fmask_to_raw', 'save_clu', 'load_clu', 'SaveCluEvery']


def load_fet_fmask_to_raw(fname, shank, use_features=None, drop_last_n_features=0):
def load_fet_fmask_to_raw(fname, shank, use_features=None, drop_last_n_features=0,
use_fmask=True):
if use_features is None and drop_last_n_features>0:
use_features = slice(None, -drop_last_n_features)
else:
use_features = slice(None)
fet_fname = fname+'.fet.'+str(shank)
fmask_fname = fname+'.fmask.'+str(shank)
if use_fmask:
fmask_fname = fname+'.fmask.'+str(shank)
# read files
fet_file = open(fet_fname, 'r')
fmask_file = open(fmask_fname, 'r')
# read first line of fmask file
num_features = int(fmask_file.readline())
if use_fmask:
fmask_file = open(fmask_fname, 'r')
# read first line of fmask file
fmask_file.readline()
num_features = int(fet_file.readline())
features_to_use = arange(num_features)[use_features]
num_features = len(features_to_use)
# Stage 1: read min/max of fet values for normalisation
# and count total number of unmasked features
fet_file.readline() # skip first line (num channels)
#fet_file.readline() # skip first line (num channels)
# we normalise channel-by-channel
vmin = ones(num_features)*inf
vmax = ones(num_features)*-inf
total_unmasked_features = 0
num_spikes = 0
for fetline, fmaskline in zip(fet_file, fmask_file):
if use_fmask:
lines = zip(fet_file, fmask_file)
else:
lines = fet_file
for line in lines:
if use_fmask:
fetline, fmaskline = line
else:
fetline = line
vals = fromstring(fetline, dtype=float, sep=' ')[use_features]
fmaskvals = fromstring(fmaskline, dtype=float, sep=' ')[use_features]
if use_fmask:
fmaskvals = fromstring(fmaskline, dtype=float, sep=' ')[use_features]
else:
fmaskvals = ones_like(vals)
inds, = (fmaskvals>0).nonzero()
total_unmasked_features += len(inds)
vmin = minimum(vals, vmin)
Expand All @@ -42,9 +57,10 @@ def load_fet_fmask_to_raw(fname, shank, use_features=None, drop_last_n_features=
fet_file.close()
fet_file = open(fet_fname, 'r')
fet_file.readline()
fmask_file.close()
fmask_file = open(fmask_fname, 'r')
fmask_file.readline()
if use_fmask:
fmask_file.close()
fmask_file = open(fmask_fname, 'r')
fmask_file.readline()
vdiff = vmax-vmin
vdiff[vdiff==0] = 1
fetsum = zeros(num_features)
Expand All @@ -55,9 +71,20 @@ def load_fet_fmask_to_raw(fname, shank, use_features=None, drop_last_n_features=
all_unmasked = zeros(total_unmasked_features, dtype=int)
offsets = zeros(num_spikes+1, dtype=int)
curoff = 0
for i, (fetline, fmaskline) in enumerate(zip(fet_file, fmask_file)):
if use_fmask:
lines = zip(fet_file, fmask_file)
else:
lines = fet_file
for i, line in enumerate(lines):
if use_fmask:
fetline, fmaskline = line
else:
fetline = line
fetvals = (fromstring(fetline, dtype=float, sep=' ')[use_features]-vmin)/vdiff
fmaskvals = fromstring(fmaskline, dtype=float, sep=' ')[use_features]
if use_fmask:
fmaskvals = fromstring(fmaskline, dtype=float, sep=' ')[use_features]
else:
fmaskvals = ones_like(fetvals)
inds, = (fmaskvals>0).nonzero()
masked_inds, = (fmaskvals==0).nonzero()
all_features[curoff:curoff+len(inds)] = fetvals[inds]
Expand Down
7 changes: 5 additions & 2 deletions klustakwik2/scripts/kk2_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def main():
use_noise_cluster=True,
use_mua_cluster=True,
subset_schedule=None,
use_fmask=True,
)
(fname, shank), params = parse_args(2, script_params, __doc__.strip()+'\n',
string_args=set(['start_from_clu']))
Expand All @@ -39,6 +40,7 @@ def main():
use_noise_cluster = params.pop('use_noise_cluster')
use_mua_cluster = params.pop('use_mua_cluster')
subset_schedule = params.pop('subset_schedule')
use_fmask = params.pop('use_fmask')

if subset_schedule is not None:
if save_clu_every is not None:
Expand All @@ -54,7 +56,8 @@ def main():
log_suppress_hierarchy('klustakwik', inclusive=False)

start_time = time.time()
raw_data = load_fet_fmask_to_raw(fname, shank, drop_last_n_features=drop_last_n_features)
raw_data = load_fet_fmask_to_raw(fname, shank, drop_last_n_features=drop_last_n_features,
use_fmask=use_fmask)
log_message('debug', 'Loading data from .fet and .fmask file took %.2f s' % (time.time()-start_time))
data = raw_data.to_sparse_data()

Expand All @@ -70,7 +73,7 @@ def main():

if start_from_clu is None:
if subset_schedule is None:
kk.cluster_mask_starts()
kk.cluster_mask_or_random_starts()
else:
kk.cluster_with_subset_schedule(num_starting_clusters, subset_schedule)
else:
Expand Down