first pass of code files (to be commented more)
1 parent 66bec45 · commit 01cb9d6
Showing 8 changed files with 2,535 additions and 0 deletions.
@@ -0,0 +1,141 @@
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import trange
from pynwb import NWBHDF5IO
import seaborn as sns
import wandb
import os

def load_dataset(indices=None, load_frames=True, num_pcs=10,
                 datapath='moseq_data/saline_example',
                 datapath2='amph_data/amphetamine_example',
                 mix_data=False):
    '''
    Load mouse data and return train/test datasets as lists of dicts,
    one dict per session. Each session is split 80/20 into train/test
    along the time axis. If mix_data is True, sessions from datapath
    and datapath2 are loaded and concatenated into the same lists.
    '''
    if indices is None:
        indices = np.arange(24)

    if not mix_data:
        train_dataset = []
        test_dataset = []
        for t in trange(len(indices)):
            i = indices[t]
            nwb_path = datapath + "_{}.nwb".format(i)
            with NWBHDF5IO(nwb_path, mode='r') as io:
                f = io.read()
                num_frames = len(f.processing['MoSeq']['PCs']['pcs_clean'].data)
                train_slc = slice(0, int(0.8 * num_frames))
                test_slc = slice(int(0.8 * num_frames), num_frames)

                train_data, test_data = dict(), dict()
                for slc, data in zip([train_slc, test_slc], [train_data, test_data]):
                    data["raw_pcs"] = f.processing['MoSeq']['PCs']['pcs_clean'].data[slc][:, :num_pcs]
                    data["times"] = f.processing['MoSeq']['PCs']['pcs_clean'].timestamps[slc][:]
                    data["centroid_x_px"] = f.processing['MoSeq']['Scalars']['centroid_x_px'].data[slc][:]
                    data["centroid_y_px"] = f.processing['MoSeq']['Scalars']['centroid_y_px'].data[slc][:]
                    data["angles"] = f.processing['MoSeq']['Scalars']['angle'].data[slc][:]
                    data["labels"] = f.processing['MoSeq']['Labels']['labels_clean'].data[slc][:]
                    data["velocity_3d_px"] = f.processing['MoSeq']['Scalars']['velocity_3d_px'].data[slc][:]
                    data["height_ave_mm"] = f.processing['MoSeq']['Scalars']['height_ave_mm'].data[slc][:]

                # only load the (memory-heavy) depth frames for the test data
                if load_frames:
                    test_data["frames"] = f.processing['MoSeq']['Images']['frames'].data[test_slc]

            train_dataset.append(train_data)
            test_dataset.append(test_data)

    else:
        train_dataset = []
        test_dataset = []
        # ind_1 = np.random.randint(0, len(indices), len(indices) // 2)
        # ind_2 = np.random.randint(0, len(indices), len(indices) // 2)

        for t in trange(len(indices)):
            i = indices[t]
            nwb_path = datapath + "_{}.nwb".format(i)
            with NWBHDF5IO(nwb_path, mode='r') as io:
                f = io.read()
                num_frames = len(f.processing['MoSeq']['PCs']['pcs_clean'].data)
                train_slc = slice(0, int(0.8 * num_frames))
                test_slc = slice(int(0.8 * num_frames), num_frames)

                train_data, test_data = dict(), dict()
                for slc, data in zip([train_slc, test_slc], [train_data, test_data]):
                    data["raw_pcs"] = f.processing['MoSeq']['PCs']['pcs_clean'].data[slc][:, :num_pcs]
                    data["times"] = f.processing['MoSeq']['PCs']['pcs_clean'].timestamps[slc][:]
                    data["centroid_x_px"] = f.processing['MoSeq']['Scalars']['centroid_x_px'].data[slc][:]
                    data["centroid_y_px"] = f.processing['MoSeq']['Scalars']['centroid_y_px'].data[slc][:]
                    data["angles"] = f.processing['MoSeq']['Scalars']['angle'].data[slc][:]
                    data["labels"] = f.processing['MoSeq']['Labels']['labels_clean'].data[slc][:]

                # only load the (memory-heavy) depth frames for the test data
                if load_frames:
                    test_data["frames"] = f.processing['MoSeq']['Images']['frames'].data[test_slc]

            train_dataset.append(train_data)
            test_dataset.append(test_data)

        for t in trange(len(indices)):
            i = indices[t]
            nwb_path = datapath2 + "_{}.nwb".format(i)
            with NWBHDF5IO(nwb_path, mode='r') as io:
                f = io.read()
                num_frames = len(f.processing['MoSeq']['PCs']['pcs_clean'].data)
                train_slc = slice(0, int(0.8 * num_frames))
                test_slc = slice(int(0.8 * num_frames), num_frames)

                train_data, test_data = dict(), dict()
                for slc, data in zip([train_slc, test_slc], [train_data, test_data]):
                    data["raw_pcs"] = f.processing['MoSeq']['PCs']['pcs_clean'].data[slc][:, :num_pcs]
                    data["times"] = f.processing['MoSeq']['PCs']['pcs_clean'].timestamps[slc][:]
                    data["centroid_x_px"] = f.processing['MoSeq']['Scalars']['centroid_x_px'].data[slc][:]
                    data["centroid_y_px"] = f.processing['MoSeq']['Scalars']['centroid_y_px'].data[slc][:]
                    data["angles"] = f.processing['MoSeq']['Scalars']['angle'].data[slc][:]
                    data["labels"] = f.processing['MoSeq']['Labels']['labels_clean'].data[slc][:]

                # only load the (memory-heavy) depth frames for the test data
                if load_frames:
                    test_data["frames"] = f.processing['MoSeq']['Images']['frames'].data[test_slc]

            train_dataset.append(train_data)
            test_dataset.append(test_data)

    return train_dataset, test_dataset

def standardize_pcs(dataset, mean=None, std=None):
    '''
    Add a new key 'data' to each data dict, holding the standardized
    (z-scored) PCs. If mean and std are not given, they are computed
    across all sessions in the dataset and returned for reuse.
    '''
    if mean is None and std is None:
        all_pcs = np.vstack([data['raw_pcs'] for data in dataset])
        mean = all_pcs.mean(axis=0)
        std = all_pcs.std(axis=0)

    for data in dataset:
        data['data'] = (data['raw_pcs'] - mean) / std
    return dataset, mean, std
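
# A minimal usage sketch (variable names are illustrative): compute the
# statistics on the training set, then reuse them on held-out data so
# both splits share the same normalization:
#
#   train_dataset, mean, std = standardize_pcs(train_dataset)
#   test_dataset, _, _ = standardize_pcs(test_dataset, mean, std)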

def precompute_ar_covariates(dataset,
                             num_lags=1,
                             fit_intercept=False):
    '''
    Add the desired autoregressive covariates (lag-shifted copies of the
    data, and optionally a constant column) to each data dictionary.
    '''
    for data in dataset:
        x = data['data']
        data_dim = x.shape[1]
        phis = []
        for lag in range(1, num_lags + 1):
            # lag-shifted copy of x, zero-padded in the first `lag` rows
            phis.append(np.vstack([np.zeros((lag, data_dim)), x[:-lag]]))
        if fit_intercept:
            phis.append(np.ones(len(x)))
        data['covariates'] = np.column_stack(phis)
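
# For example, with num_lags=2 and x of shape (T, D), 'covariates' has
# shape (T, 2 * D): row t holds [x[t-1], x[t-2]], zero-padded where t < lag,
# plus a trailing column of ones if fit_intercept=True.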

def log_wandb_model(model, name, type):
    '''
    Save the model under models/<run name> and log that directory to
    Weights & Biases as an artifact.
    '''
    trained_model_artifact = wandb.Artifact(name, type=type)
    if not os.path.isdir('models'):
        os.mkdir('models')
    subdirectory = wandb.run.name
    filepath = os.path.join('models', subdirectory)
    model.save(filepath)
    trained_model_artifact.add_dir(filepath)
    wandb.log_artifact(trained_model_artifact)
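
Taken together, these helpers form a small preprocessing pipeline. A minimal
sketch of how they might be chained (assuming the NWB files exist at the
default datapath; the indices and num_lags values are illustrative):

    train_dataset, test_dataset = load_dataset(indices=np.arange(2))
    # standardize with statistics computed on the training set only
    train_dataset, mean, std = standardize_pcs(train_dataset)
    test_dataset, _, _ = standardize_pcs(test_dataset, mean, std)
    # add lagged covariates for downstream autoregressive models
    precompute_ar_covariates(train_dataset, num_lags=2)
    precompute_ar_covariates(test_dataset, num_lags=2)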
@@ -0,0 +1,28 @@
import torch
from torch import nn

device = torch.device('cpu')
dtype = torch.float64
to_t = lambda array: torch.tensor(array, device=device, dtype=dtype)
from_t = lambda tensor: tensor.to("cpu").detach().numpy()


class RBF(nn.Module):
    """
    Exponentiated quadratic (RBF) kernel.
    The forward call evaluates the kernel Gram matrix at the input
    arguments; the output is num_discrete_states x num_tau x num_tau.
    """

    def __init__(self, num_discrete_states, lengthscales_init=1.0):
        super().__init__()
        # one output scale and one lengthscale per discrete state
        self.output_scale = nn.Parameter(
            torch.ones(num_discrete_states, device=device, dtype=dtype))
        self.lengthscales = nn.Parameter(
            lengthscales_init * torch.ones(num_discrete_states, device=device, dtype=dtype))

    def forward(self, x_grid):
        """
        Evaluate the kernel on all pairs of points in x_grid.
        """
        # per-state scaled squared differences: (num_states, num_tau, num_tau)
        diffsq = ((x_grid.view(1, -1, 1) - x_grid.view(1, 1, -1))
                  / self.lengthscales.view(-1, 1, 1)) ** 2
        return self.output_scale.view(-1, 1, 1) ** 2 * torch.exp(-0.5 * diffsq)
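
A quick shape check for the kernel (a sketch; the grid here is arbitrary):

    kernel = RBF(num_discrete_states=3)
    x_grid = torch.linspace(0.0, 1.0, 50, device=device, dtype=dtype)
    K = kernel(x_grid)
    print(K.shape)  # torch.Size([3, 50, 50])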