diff --git a/data_util.py b/data_util.py new file mode 100644 index 0000000..8d7bbb6 --- /dev/null +++ b/data_util.py @@ -0,0 +1,141 @@ +from matplotlib.colors import ListedColormap +import matplotlib.pyplot as plt +import numpy as np +from numpy.core.fromnumeric import mean +from tqdm.auto import trange +from pynwb import NWBHDF5IO +import seaborn as sns +import wandb +import os + +def load_dataset(indices=None,load_frames=True,num_pcs=10,datapath='moseq_data/saline_example',datapath2='amph_data/amphetamine_example',mix_data=False): + ''' + loads mouse data and returns train/test datasets as lists of dicts + ''' + + if indices is None: + indices = np.arange(24) + + if not mix_data: + train_dataset = [] + test_dataset = [] + for t in trange(len(indices)): + i = indices[t] + nwb_path = datapath + "_{}.nwb".format(i) + with NWBHDF5IO(nwb_path, mode='r') as io: + f = io.read() + num_frames = len(f.processing['MoSeq']['PCs']['pcs_clean'].data) + train_slc = slice(0, int(0.8 * num_frames)) + test_slc = slice(int(0.8 * num_frames) + 1, -1) + + train_data, test_data = dict(), dict() + for slc, data in zip([train_slc, test_slc], [train_data, test_data]): + data["raw_pcs"] = f.processing['MoSeq']['PCs']['pcs_clean'].data[slc][:, :num_pcs] + data["times"] = f.processing['MoSeq']['PCs']['pcs_clean'].timestamps[slc][:] + data["centroid_x_px"] = f.processing['MoSeq']['Scalars']['centroid_x_px'].data[slc][:] + data["centroid_y_px"] = f.processing['MoSeq']['Scalars']['centroid_y_px'].data[slc][:] + data["angles"] = f.processing['MoSeq']['Scalars']['angle'].data[slc][:] + data["labels"] = f.processing['MoSeq']['Labels']['labels_clean'].data[slc][:] + data["velocity_3d_px"] = f.processing['MoSeq']['Scalars']['velocity_3d_px'].data[slc][:] + data["height_ave_mm"] = f.processing['MoSeq']['Scalars']['height_ave_mm'].data[slc][:] + + # only load the frames on the test data + test_data["frames"] = f.processing['MoSeq']['Images']['frames'].data[test_slc] + + train_dataset.append(train_data) + test_dataset.append(test_data) + + else: + train_dataset = [] + test_dataset = [] + # ind_1 = np.random.randint(0,len(indices),len(indices)//2) + # ind_2 = np.random.randint(0,len(indices),len(indices)//2) + + for t in trange(len(indices)): + i = indices[t] + nwb_path = datapath + "_{}.nwb".format(i) + with NWBHDF5IO(nwb_path, mode='r') as io: + f = io.read() + num_frames = len(f.processing['MoSeq']['PCs']['pcs_clean'].data) + train_slc = slice(0, int(0.8 * num_frames)) + test_slc = slice(int(0.8 * num_frames) + 1, -1) + + train_data, test_data = dict(), dict() + for slc, data in zip([train_slc, test_slc], [train_data, test_data]): + data["raw_pcs"] = f.processing['MoSeq']['PCs']['pcs_clean'].data[slc][:, :num_pcs] + data["times"] = f.processing['MoSeq']['PCs']['pcs_clean'].timestamps[slc][:] + data["centroid_x_px"] = f.processing['MoSeq']['Scalars']['centroid_x_px'].data[slc][:] + data["centroid_y_px"] = f.processing['MoSeq']['Scalars']['centroid_y_px'].data[slc][:] + data["angles"] = f.processing['MoSeq']['Scalars']['angle'].data[slc][:] + data["labels"] = f.processing['MoSeq']['Labels']['labels_clean'].data[slc][:] + + # only load the frames on the test data + test_data["frames"] = f.processing['MoSeq']['Images']['frames'].data[test_slc] + + train_dataset.append(train_data) + test_dataset.append(test_data) + + for t in trange(len(indices)): + i = indices[t] + nwb_path = datapath2 + "_{}.nwb".format(i) + with NWBHDF5IO(nwb_path, mode='r') as io: + f = io.read() + num_frames =
len(f.processing['MoSeq']['PCs']['pcs_clean'].data) + train_slc = slice(0, int(0.8 * num_frames)) + test_slc = slice(int(0.8 * num_frames) + 1, -1) + + train_data, test_data = dict(), dict() + for slc, data in zip([train_slc, test_slc], [train_data, test_data]): + data["raw_pcs"] = f.processing['MoSeq']['PCs']['pcs_clean'].data[slc][:, :num_pcs] + data["times"] = f.processing['MoSeq']['PCs']['pcs_clean'].timestamps[slc][:] + data["centroid_x_px"] = f.processing['MoSeq']['Scalars']['centroid_x_px'].data[slc][:] + data["centroid_y_px"] = f.processing['MoSeq']['Scalars']['centroid_y_px'].data[slc][:] + data["angles"] = f.processing['MoSeq']['Scalars']['angle'].data[slc][:] + data["labels"] = f.processing['MoSeq']['Labels']['labels_clean'].data[slc][:] + + # only load the frames on the test data + test_data["frames"] = f.processing['MoSeq']['Images']['frames'].data[test_slc] + + train_dataset.append(train_data) + test_dataset.append(test_data) + + return train_dataset, test_dataset + +def standardize_pcs(dataset, mean=None, std=None): + ''' + adds new keyword 'data' corresponding with standardized PCs + ''' + + if mean is None and std is None: + all_pcs = np.vstack([data['raw_pcs'] for data in dataset]) + mean = all_pcs.mean(axis=0) + std = all_pcs.std(axis=0) + + for data in dataset: + data['data'] = (data['raw_pcs'] - mean) / std + return dataset, mean, std + +def precompute_ar_covariates(dataset, + num_lags=1, + fit_intercept=False): + ''' + add the desired covariates to the data dictionary + ''' + for data in dataset: + x = data['data'] + data_dim = x.shape[1] + phis = [] + for lag in range(1, num_lags+1): + phis.append(np.row_stack([np.zeros((lag, data_dim)), x[:-lag]])) + if fit_intercept: + phis.append(np.ones(len(x))) + data['covariates'] = np.column_stack(phis) + +def log_wandb_model(model, name, type): + trained_model_artifact = wandb.Artifact(name,type=type) + if not os.path.isdir('models'): os.mkdir('models') + subdirectory = wandb.run.name + filepath = os.path.join('models', subdirectory) + model.save(filepath) + trained_model_artifact.add_dir(filepath) + wandb.log_artifact(trained_model_artifact) \ No newline at end of file diff --git a/kernels.py b/kernels.py new file mode 100644 index 0000000..69014c7 --- /dev/null +++ b/kernels.py @@ -0,0 +1,28 @@ +import torch +from torch import nn + +device = torch.device('cpu') +dtype = torch.float64 +to_t = lambda array: torch.tensor(array, device=device, dtype=dtype) +from_t = lambda tensor: tensor.to("cpu").detach().numpy() + + +class RBF(nn.Module): + def __init__(self, num_discrete_states, lengthscales_Init=1.0): + super().__init__() + self.output_scale = nn.Parameter(torch.ones((num_discrete_states),device=device, dtype=dtype)) # one for each discrete state + self.lengthscales = nn.Parameter(lengthscales_Init*torch.ones((num_discrete_states),device=device, dtype=dtype)) # one for each discrete state + """ + Exponentiated Quadratic kernel class. + forward call evaluates Kernel Gram matrix at input arguments. 
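+        Each discrete state k has its own output scale and lengthscale, i.e. k_k(x, x') = output_scale[k]**2 * exp(-0.5 * ((x - x') / lengthscales[k])**2).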
+ The output is num_discete_states x num_tau x num_tau + """ + + def forward(self, x_grid): + """ + classic kernel function + """ + + diffsq = (torch.div((x_grid.view(1,-1,1) - x_grid.view(1,1,-1)), self.lengthscales.view(-1,1,1)))**2 + + return self.output_scale.view(-1,1,1)**2 * torch.exp(-0.5 * diffsq) diff --git a/plotting_util.py b/plotting_util.py new file mode 100644 index 0000000..e828dc1 --- /dev/null +++ b/plotting_util.py @@ -0,0 +1,1082 @@ +# functions helpful for plotting discrete/continuous states over time +import wandb +from matplotlib.colors import ListedColormap +import matplotlib.pyplot as plt +import numpy as np +from twarhmm import Posterior +from tqdm.auto import trange +import datetime + +import seaborn as sns + +sns.set_style("white") +sns.set_context("talk") + +color_names = ["windows blue", + "red", + "amber", + "faded green", + "dusty purple", + "orange", + "clay", + "pink", + "greyish", + "mint", + "cyan", + "steel blue", + "forest green", + "pastel purple", + "salmon", + "dark brown"] + +colors = sns.xkcd_palette(color_names) +cmap = ListedColormap(colors) + +def plot_discrete_latent_states(states_z, K): + plt.figure(figsize=(10,2)) + cmap_limited = ListedColormap(colors[0:K]) + plt.imshow(states_z[None,:], aspect="auto", cmap=cmap_limited) + plt.title("Simulated Discrete Latent States") + plt.yticks([]) + plt.xlabel("Time") + plt.show() + +def remove_frame(ax_array,all_off=False): + for ax in np.ravel(ax_array): + if not all_off: + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + ax.yaxis.set_ticks_position('left') + ax.xaxis.set_ticks_position('bottom') + if all_off: + ax.set_axis_off() + +def plot_discrete_and_continuous_latent_states(states_z, states_x, K,colors = None, fig=None, ax=None): + + latent_dim = states_x.shape[1] + T = states_x.shape[0] + lim = abs(states_x).max() + + if ax == None: + fig, ax =plt.subplots(1,1,figsize=(10,2)) + ax.imshow(states_z[None,:], aspect="auto", cmap=colors, extent=[0, T, -lim, lim*latent_dim], alpha=.7) + ax.set_yticks([]) + + + # Plot the continuous latent states + + + for d in range(latent_dim): + ax.plot(states_x[:, d] + lim * d, 'k', lw=1) + ax.set_yticks(np.arange(latent_dim) * lim, ["$x_{}$".format(d+1) for d in range(latent_dim)]) + ax.set_xticks([]) + ax.set_xlim(0, T) + + + return fig + +def plot_continuous_states_and_emissions(states_x, emissions): + plt.figure(figsize=(10, 6)) + emissions_dim = emissions.shape[1] + latent_dim = states_x.shape[1] + T = states_x.shape[0] + gs = plt.GridSpec(2, 1, height_ratios=(1, emissions_dim/latent_dim)) + + # Plot the continuous latent states + lim = abs(states_x).max() + plt.subplot(gs[0]) + for d in range(latent_dim): + plt.plot(states_x[:, d] + lim * d, '-k') + plt.yticks(np.arange(latent_dim) * lim, ["$x_{}$".format(d+1) for d in range(latent_dim)]) + plt.xticks([]) + plt.xlim(0, T) + plt.title("Simulated Latent States") + + lim = abs(emissions).max() + plt.subplot(gs[1]) + for n in range(emissions_dim): + plt.plot(emissions[:, n] - lim * n, '-') + plt.yticks(-np.arange(emissions_dim) * lim, ["$y_{{ {} }}$".format(n+1) for n in range(emissions_dim)]) + plt.xlabel("time") + plt.xlim(0, T) + + plt.title("Simulated emissions") + plt.tight_layout() + +def plot_vector_field(*args, color='black'): + num_plots = len(args) + fig, ax = plt.subplots(1,num_plots,figsize=(4*num_plots, 4)) + ax = np.atleast_1d(ax) + xlims = [-2, 2] + ylims = [-2, 2] + X1, X2 = np.meshgrid(np.linspace(xlims[0], xlims[1], 10), np.linspace(ylims[0], ylims[1], 10)) + points = 
np.stack((X1, X2)) + for i, A in enumerate(args): + AX = np.einsum('ij,jkl->ikl', A, points) +# Q = ax[i].quiver(X1, X2, AX[0] - X1, AX[1] - X2, units='width', color=plt.cm.viridis_r(i/(len(args)-1)),scale = 2,scale_units = 'xy') + Q = ax[i].quiver(X1, X2, AX[0] - X1, AX[1] - X2, units='width', color=color,scale = 2,scale_units = 'xy') + ax[i].set_xlim(xlims) + ax[i].set_ylim(ylims) +# ax[i].scatter(0,0,color=plt.cm.viridis_r(i/(len(args)-1)),s = 10) + ax[i].scatter(0,0,color=color,s = 10) + + return ax + +def plot_continuous_latent_states(states_x, var_x=None, title="", spacing=1): + plt.figure(figsize=(10, 6)) + T = states_x.shape[0] + latent_dim = states_x.shape[1] + gs = plt.GridSpec(.1, 1) + # Plot the continuous latent states + if var_x is not None: + lim = abs(states_x + np.sqrt(var_x)).max() + else: + lim = abs(states_x).max() + lim *= spacing + plt.subplot(gs[0]) + for d in range(latent_dim): + x = states_x[:, d] + lim * d + plt.plot(x, '-k') + if var_x is not None: + plt.fill_between(np.arange(x.shape[0]), x+np.sqrt(var_x[:, d]), x-np.sqrt(var_x[:, d])) + plt.yticks(np.arange(latent_dim) * lim, ["$x_{}$".format(d+1) for d in range(latent_dim)]) + plt.xticks([]) + plt.xlim(0, T) + plt.title(title) + +def wnb_histogram_plot(posteriors, tau_duration=False,duration_plot=False, state_usage_plot=False, ordered_state_usage=False, state_switch=False): + if tau_duration + state_usage_plot + ordered_state_usage + state_switch == 0: + print('no histogram selected!') + return + + data_dim = posteriors[0].model.data_dim + total_states = posteriors[0].model.observations.num_states * len(posteriors[0].model.observations.taus) + + states = np.concatenate([posterior.get_states() for posterior in posteriors]) + state_usage = np.bincount(states, minlength=total_states) + + num_taus = len(posteriors[0].model.taus) + + if tau_duration: + durations = Posterior.state_durations(states, total_states) + duration_mean = np.mean(np.concatenate(durations)* 1000 / 30) + wandb.log({'duration_mean': duration_mean}) + print('duration mean: ', duration_mean) + duration_means = np.zeros(total_states) + duration_covs = np.zeros(total_states) + for k in range(len(durations)): + duration_means[k] = durations[k].sum() + duration_covs[k] = np.std(durations[k]) + + fig, ax_array = plt.subplots(4, int(np.ceil(posteriors[0].model.observations.num_states/4)), sharex=True, sharey=True) + for k in range(posteriors[0].model.observations.num_states): + ax = np.ravel(ax_array)[k] + ax.bar(np.arange(num_taus), duration_means[k * num_taus:k * num_taus + num_taus] * 1000 / 30) + # ax.set_xticks([0, 1, 2, 3, 4]) + # ax.set_xticklabels([1, 2, 3, 4, 5]) + ax.tick_params(axis='x', which='major', labelsize=12) + ax.tick_params(axis='y', which='major', labelsize=12) + ax.set_title(k, fontdict={'fontsize': 12}) + + fig.add_subplot(111, frameon=False) + # hide tick and tick label of the big axis + plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False) + plt.xlabel("Tau", labelpad=5) + plt.ylabel("Amount of time in each state (ms)", labelpad=15) + plt.tight_layout() + wandb.log({'tau_distribution_plot':wandb.Image(fig)}, commit=True) + + if duration_plot: + fig, ax = plt.subplots() + durations = Posterior.state_durations(states, total_states) + durations = np.concatenate(durations) * 1000 / 30 + plt.hist(durations, bins=50) + plt.xlabel("Duration (ms)") + plt.ylabel("Count") + plt.tight_layout() + wandb.log({'duration_plot': wandb.Image(fig)}, commit=True) + + if state_usage_plot: + fig, ax = plt.subplots() + 
plt.bar(np.arange(total_states), state_usage) + plt.xlabel("state index") + plt.ylabel("num frames") + plt.title("histogram of inferred state usage") + plt.tight_layout() + wandb.log({'state_usage':wandb.Image(fig)}, commit=True) + + if ordered_state_usage: + fig, ax = plt.subplots() + order = np.argsort(state_usage / state_usage.sum())[::-1] + plt.bar(np.arange(total_states), (state_usage / state_usage.sum())[order]) + plt.xlabel("state index [ordered]") + plt.ylabel("frequency") + plt.title("ordered histogram of inferred state usage") + plt.tight_layout() + wandb.log({'ordered_state_usage':wandb.Image(fig)}, commit=True) + + if state_switch: + fig, ax = plt.subplots() + changepoint_states = np.concatenate([posterior.state_switch() for posterior in posteriors]) + changepoint_usage = np.bincount(changepoint_states, minlength=total_states) + plt.bar(np.arange(total_states), changepoint_usage) + plt.xlabel("state index") + plt.ylabel("number of switches") + plt.title("histogram of state switches") + plt.tight_layout() + wandb.log({'state_switch': wandb.Image(fig)}, commit=True) + + +def wnb_histogram_plot_cont(posteriors, duration_plot=False, state_usage_plot=False, ordered_state_usage=False, state_switch=False): + if duration_plot + state_usage_plot + ordered_state_usage + state_switch == 0: + print('no histogram selected!') + return + + data_dim = posteriors[0].model.data_dim + total_states = posteriors[0].model.observations.num_states + + states = np.concatenate([posterior.get_states() for posterior in posteriors]) + state_usage = np.bincount(states, minlength=total_states) + + durations = Posterior.state_durations(states, total_states) + wandb.log({'duration_mean': np.mean(np.concatenate(durations) * 1000 / 30)}) + + if duration_plot: + fig, ax = plt.subplots() + durations = Posterior.state_durations(states, total_states) + durations = np.concatenate(durations) * 1000 / 30 + plt.hist(durations, bins=50) + plt.xlabel("Duration (ms)") + plt.ylabel("Count") + plt.tight_layout() + wandb.log({'duration_plot': wandb.Image(fig)}, commit=True) + + if state_usage_plot: + fig, ax = plt.subplots() + plt.bar(np.arange(total_states), state_usage) + plt.xlabel("state index") + plt.ylabel("num frames") + plt.title("histogram of inferred state usage") + plt.tight_layout() + wandb.log({'state_usage':wandb.Image(fig)}, commit=True) + + if ordered_state_usage: + fig, ax = plt.subplots() + order = np.argsort(state_usage / state_usage.sum())[::-1] + plt.bar(np.arange(total_states), (state_usage / state_usage.sum())[order]) + plt.xlabel("state index [ordered]") + plt.ylabel("frequency") + plt.title("ordered histogram of inferred state usage") + plt.tight_layout() + wandb.log({'ordered_state_usage':wandb.Image(fig)}, commit=True) + + if state_switch: + fig, ax = plt.subplots() + changepoint_states = np.concatenate([posterior.state_switch() for posterior in posteriors]) + changepoint_usage = np.bincount(changepoint_states, minlength=total_states) + plt.bar(np.arange(total_states), changepoint_usage) + plt.xlabel("state index") + plt.ylabel("number of switches") + plt.title("histogram of state switches") + plt.tight_layout() + wandb.log({'state_switch': wandb.Image(fig)}, commit=True) + +# for plotting PCs associated with each state and tau +def plot_multiple_average_pcs(state_list, + data_dim, + posteriors, + spc=4, + pad=30): + ''' + ''' + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + cmap = plt.get_cmap('viridis') + colorlist = [i / len(state_list) for i in range(len(state_list))] + # print(colorlist) + + 
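+    # For each state in state_list: gather the padded PC trajectories around every instance of that state, average them, and plot the absolute frame-to-frame change of the mean trace, one color per state.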
for color, state_idx in enumerate(state_list): + + # Find slices for this state + slices = extract_syllable_slices(state_idx, posteriors, num_instances=20000) + # Find maximum duration + durs = [] + num_slices = 0 + for these_slices in slices: + for slc in these_slices: + durs.append(slc.stop - slc.start) + num_slices += 1 + if num_slices == 0: + print("no valid syllables found for state", state_idx) + return + #TODO: fix this for when not all taus are used + max_dur = np.max(durs) + # Initialize timestamps + times = np.arange(-pad, max_dur + pad) / 30 + exs = np.nan * np.ones((num_slices, 2 * pad + max_dur, data_dim)) + counter = 0 + # Make figure + + for these_slices, posterior in zip(slices, posteriors): + data = posterior.data + for slc in these_slices: + lpad = min(pad, slc.start) + rpad = min(pad, len(data['data']) - slc.stop) + dur = slc.stop - slc.start + padded_slc = slice(slc.start - lpad, slc.stop + rpad) + x = data['data'][padded_slc][:, :data_dim] + exs[counter][(pad - lpad):(pad - lpad + len(x))] = x + counter += 1 + # Plot single example + # ax.plot(times[(pad - lpad):(pad - lpad + len(x))], + # x - spc * np.arange(data_dim), + # ls='-', lw=.5, color='k') + # take the mean and standard deviation + ex_mean = np.nanmean(exs, axis=0) + ex_std = np.nanstd(exs, axis=0) + for d in range(data_dim): + # ax.fill_between(times, + # ex_mean[:, d] - 2 * ex_std[:, d] - spc * d, + # ex_mean[:, d] + 2 * ex_std[:, d] - spc * d, + # color='k', alpha=0.25) + if d == data_dim-1: ax.plot(times[:-1], np.abs(np.diff(ex_mean[:, d])) - spc * d, c=cmap(color / len(state_list)), lw=2,label=color) + else: ax.plot(times[:-1], np.abs(np.diff(ex_mean[:, d])) - spc * d, c=cmap(color / len(state_list)), lw=2) + ax.plot([0, 0], [-spc * data_dim, spc], '-r', lw=2, ls='--') + ax.set_yticks(-spc * np.arange(data_dim)) + ax.set_yticklabels(np.arange(data_dim) + 1) + ax.set_ylim(-spc * data_dim, spc) + ax.set_ylabel("principal component") + # ax.set_xlim(times[0], times[-1]) + ax.set_xlim(-.25, .5) + ax.set_xlabel("$\Delta t$ [s]") + num_taus = len(posterior.model.taus) + ax.set_title("Average PCs for State {}".format(((state_idx - num_taus + 1) / num_taus))) + + +def centroid_velocity_plot(posteriors): + speeds = [] + state_list = [] + for posterior in posteriors: + centroid_x = posterior.data['centroid_x_px'] + centroid_y = posterior.data['centroid_y_px'] + position = np.vstack((centroid_x,centroid_y)).T + speed = np.linalg.norm(np.diff(position,axis=0),axis=1) + speeds.append(np.hstack((0,speed))) + + states = posterior.get_states() + state_list.append(states) + + speeds = np.concatenate(speeds) + state_list = np.concatenate(state_list) + + num_taus = len(posteriors[0].model.taus) + total_states = posteriors[0].model.num_discrete_states * num_taus + fig, ax_array = plt.subplots(4, int(np.ceil(posteriors[0].model.num_discrete_states / 4)), sharex=True, sharey=True) + for k in range(posteriors[0].model.num_discrete_states): + avg_state_speeds = np.zeros(num_taus) + avg_state_vars = np.zeros(num_taus) + state_tau_list = np.arange(k*num_taus, (k+1)*num_taus) + for i, state in enumerate(state_tau_list): + inds = state_list == state + state_speeds = speeds[inds] + avg_state_speeds[i] = np.mean(state_speeds) + avg_state_vars[i] = np.std(state_speeds) + ax = np.ravel(ax_array)[k] + ax.bar(np.arange(num_taus), avg_state_speeds) + ax.set_xticks([i for i in range(num_taus)]) + ax.set_xticklabels([i+1 for i in range(num_taus)]) + ax.tick_params(axis='x', which='major', labelsize=12) + ax.tick_params(axis='y', 
which='major', labelsize=12) + ax.set_title(k, fontdict={'fontsize': 12}) + + plt.tight_layout() + wandb.log({'centroid_velocity': wandb.Image(fig)}, commit=True) + +def ave_height_plot(posteriors): + heights = [] + state_list = [] + for posterior in posteriors: + height_ave = posterior.data['height_ave_mm'] + heights.append(height_ave) + + states = posterior.get_states() + state_list.append(states) + + heights = np.concatenate(heights) + state_list = np.concatenate(state_list) + + num_taus = len(posteriors[0].model.taus) + total_states = posteriors[0].model.num_discrete_states * num_taus + fig, ax_array = plt.subplots(4, int(np.ceil(posteriors[0].model.num_discrete_states / 4)), sharex=True, sharey=True) + for k in range(posteriors[0].model.num_discrete_states): + max_state_heights = np.zeros(num_taus) + state_tau_list = np.arange(k*num_taus, (k+1)*num_taus) + for i, state in enumerate(state_tau_list): + inds = state_list == state + state_heights = heights[inds] + max_state_heights[i] = np.mean(state_heights) + ax = np.ravel(ax_array)[k] + ax.bar(np.arange(num_taus), max_state_heights) + ax.set_xticks([i for i in range(num_taus)]) + ax.set_xticklabels([i+1 for i in range(num_taus)]) + ax.tick_params(axis='x', which='major', labelsize=12) + ax.tick_params(axis='y', which='major', labelsize=12) + ax.set_title(k, fontdict={'fontsize': 12}) + + plt.tight_layout() + #plt.savefig('average_height_plot_different-sweep') + plt.show() + + #wandb.log({'avg_max_height': wandb.Image(fig)}, commit=True) + +def save_videos_wandb(posteriors): + for i in trange(posteriors[0].model.num_discrete_states): + try: + filename = "crowd{}_grouped.mp4".format(i) + # video = make_crowd_movie(i*posteriors[0].num_taus + posteriors[0].num_taus//2, posteriors) + video = make_crowd_movie(i, posteriors) + video = video.transpose([0,3,1,2]) + wandb.log( + {filename: wandb.Video(video, fps=30, format="mp4")}) + except Exception: + print("failed to create a movie for state", i) + +#------------------------------------------------------------------------------------------------------------------------ +# functions helpful for animating crowd movies for mouse dataset + +import numpy as np +import numpy.random as npr +import matplotlib.pyplot as plt +import torch +# Specify that we want our tensors on the CPU and in float64 +device = torch.device('cpu') +dtype = torch.float64 + +# Helper function to convert between numpy arrays and tensors +to_t = lambda array: torch.tensor(array, device=device, dtype=dtype) +from_t = lambda tensor: tensor.to("cpu").detach().numpy() + +import cv2 +from matplotlib import animation +from IPython.display import HTML +from tempfile import NamedTemporaryFile +import base64 +import seaborn as sns + +sns.set_context("notebook") + +# initialize a color palette for plotting +palette = sns.xkcd_palette(["windows blue", + "red", + "medium green", + "dusty purple", + "greyish", + "orange", + "amber", + "clay", + "pink"]) + + +def sum_tuples(a, b): + assert a or b + if a is None: + return b + elif b is None: + return a + else: + return tuple(ai + bi for ai, bi in zip(a, b)) + + +_VIDEO_TAG = """<video controls><source src="data:video/x-m4v;base64,{0}" type="video/mp4">Your browser does not support the video tag.</video>""" + + +def _anim_to_html(anim, fps=20): + # TODO: document + if not hasattr(anim, '_encoded_video'): + with NamedTemporaryFile(suffix='.mp4') as f: + anim.save(f.name, fps=fps, extra_args=['-vcodec', 'libx264']) + video = open(f.name, "rb").read() + anim._encoded_video = base64.b64encode(video) + + return _VIDEO_TAG.format(anim._encoded_video.decode('ascii')) + + +def _display_animation(anim, fps=30, start=0, stop=None):
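+    # Close the source figure and return an HTML5 video element embedding the base64-encoded animation.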
+ plt.close(anim._fig) + return HTML(_anim_to_html(anim, fps=fps)) + + +def play(movie, fps=30, speedup=1, fig_height=6, + filename=None, show_time=False, show=True): + # First set up the figure, the axis, and the plot element we want to animate + T, Py, Px = movie.shape[:3] + fig, ax = plt.subplots(1, 1, figsize=(fig_height * Px / Py, fig_height)) + im = plt.imshow(movie[0], interpolation='None', cmap=plt.cm.gray) + + if show_time: + tx = plt.text(0.75, 0.05, 't={:.3f}s'.format(0), + color='white', + fontdict=dict(size=12), + horizontalalignment='left', + verticalalignment='center', + transform=ax.transAxes) + plt.axis('off') + + def animate(i): + im.set_data(movie[i * speedup]) + if show_time: + tx.set_text("t={:.3f}s".format(i * speedup / fps)) + return im, + + # call the animator. blit=True means only re-draw the parts that have changed. + + anim = animation.FuncAnimation(fig, animate, + frames=T // speedup, + interval=1, + blit=True) + plt.close(anim._fig) + + # save to mp4 if filename specified + if filename is not None: + with open(filename, "wb") as f: + anim.save(f.name, fps=fps, extra_args=['-vcodec', 'libx264']) + + # return an HTML video snippet + if show: + print("Preparing animation. This may take a minute...") + return HTML(_anim_to_html(anim, fps=30)) + + +def plot_data_and_states(data, states, + spc=4, slc=slice(0, 900), + title=None): + times = data["times"][slc] + labels = data["labels"][slc] + x = data["data"][slc] + num_timesteps, data_dim = x.shape + + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.imshow(states[None, slc], + cmap="cubehelix", aspect="auto", + extent=(0, times[-1] - times[0], -data_dim * spc, spc)) + + ax.plot(times - times[0], + x - spc * np.arange(data_dim), + ls='-', lw=3, color='w') + ax.plot(times - times[0], + x - spc * np.arange(data_dim), + ls='-', lw=2, color=palette[0]) + + ax.set_yticks(-spc * np.arange(data_dim)) + ax.set_yticklabels(np.arange(data_dim)) + ax.set_ylabel("principal component") + ax.set_xlim(0, times[-1] - times[0]) + ax.set_xlabel("time [ms]") + + if title is None: + ax.set_title("data and discrete states") + else: + ax.set_title(title) + + +def extract_syllable_slices(state_idx, + posteriors, + pad=30, + num_instances=50, + min_duration=5, + max_duration=45, + seed=0): + # Find all the start indices and durations of specified state + all_mouse_inds = [] + all_starts = [] + all_durations = [] + for mouse, posterior in enumerate(posteriors): + states = np.argmax(posterior.expected_states(), axis=1)#//posteriors[0].num_taus + states = np.concatenate([[-1], states, [-1]]) + starts = np.where((states[1:] == state_idx) \ + & (states[:-1] != state_idx))[0] + stops = np.where((states[:-1] == state_idx) \ + & (states[1:] != state_idx))[0] + durations = stops - starts + assert np.all(durations >= 1) + all_mouse_inds.append(mouse * np.ones(len(starts), dtype=int)) + all_starts.append(starts) + all_durations.append(durations) + + all_mouse_inds = np.concatenate(all_mouse_inds) + all_starts = np.concatenate(all_starts) + all_durations = np.concatenate(all_durations) + + # Throw away ones that are too short or too close to start. 
+ # TODO: also throw away ones close to the end + valid = (all_durations >= min_duration) \ + & (all_durations < max_duration) \ + & (all_starts > pad) + + num_valid = np.sum(valid) + all_mouse_inds = all_mouse_inds[valid] + all_starts = all_starts[valid] + all_durations = all_durations[valid] + + # Choose a random subset to show + rng = npr.RandomState(seed) + subset = rng.choice(num_valid, + size=min(num_valid, num_instances), + replace=False) + + all_mouse_inds = all_mouse_inds[subset] + all_starts = all_starts[subset] + all_durations = all_durations[subset] + + # Extract slices for each mouse + slices = [] + for mouse in range(len(posteriors)): + is_mouse = (all_mouse_inds == mouse) + slices.append([slice(start, start + dur) for start, dur in + zip(all_starts[is_mouse], all_durations[is_mouse])]) + + return slices + +def extract_syllable_slices_indiv(state_idx, + posteriors, + pad=30, + num_instances=50, + min_duration=5, + max_duration=45, + seed=0): + # Find all the start indices and durations of specified state + all_mouse_inds = [] + all_starts = [] + all_durations = [] + for mouse, posterior in enumerate(posteriors): + states = np.argmax(posterior.expected_states(), axis=1) + states = np.concatenate([[-1], states, [-1]]) + starts = np.where((states[1:] == state_idx) \ + & (states[:-1] != state_idx))[0] + stops = np.where((states[:-1] == state_idx) \ + & (states[1:] != state_idx))[0] + durations = stops - starts + assert np.all(durations >= 1) + all_mouse_inds.append(mouse * np.ones(len(starts), dtype=int)) + all_starts.append(starts) + all_durations.append(durations) + + all_mouse_inds = np.concatenate(all_mouse_inds) + all_starts = np.concatenate(all_starts) + all_durations = np.concatenate(all_durations) + + # Throw away ones that are too short or too close to start. + # TODO: also throw away ones close to the end + valid = (all_durations >= min_duration) \ + & (all_durations < max_duration) \ + & (all_starts > pad) + + num_valid = np.sum(valid) + all_mouse_inds = all_mouse_inds[valid] + all_starts = all_starts[valid] + all_durations = all_durations[valid] + + # Choose a random subset to show + rng = npr.RandomState(seed) + subset = rng.choice(num_valid, + size=min(num_valid, num_instances), + replace=False) + + all_mouse_inds = all_mouse_inds[subset] + all_starts = all_starts[subset] + all_durations = all_durations[subset] + + # Extract slices for each mouse + slices = [] + for mouse in range(len(posteriors)): + is_mouse = (all_mouse_inds == mouse) + slices.append([slice(start, start + dur) for start, dur in + zip(all_starts[is_mouse], all_durations[is_mouse])]) + + return slices + +def make_crowd_movie(state_idx, + posteriors, + pad=30, + raw_size=(512, 424), + crop_size=(80, 80), + offset=(50, 50), + scale=.5, + min_height=10, + **kwargs): + ''' + Adapted from https://github.com/dattalab/moseq2-viz/blob/release/moseq2_viz/viz.py + + Creates crowd movie video numpy array. + Parameters + ---------- + dataset (list of dicts): list of dictionaries containing data + slices (np.ndarray): video slices of specific syllable label + pad (int): number of frame padding in video + raw_size (tuple): video dimensions. + frame_path (str): variable to access frames in h5 file + crop_size (tuple): mouse crop size + offset (tuple): centroid offsets from cropped videos + scale (int): mouse size scaling factor. + min_height (int): minimum max height from floor to use. + kwargs (dict): extra keyword arguments + Returns + ------- + crowd_movie (np.ndarray): crowd movie for a specific syllable. 
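+        The movie array has shape (max_duration + 2 * pad, raw_size[1], raw_size[0], 3) with dtype uint8; None is returned if no valid syllable instances are found.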
+ ''' + slices = extract_syllable_slices(state_idx, posteriors) + + xc0, yc0 = crop_size[1] // 2, crop_size[0] // 2 + xc = np.arange(-xc0, xc0 + 1, dtype='int16') + yc = np.arange(-yc0, yc0 + 1, dtype='int16') + + durs = [] + for these_slices in slices: + for slc in these_slices: + durs.append(slc.stop - slc.start) + + if len(durs) == 0: + print("no valid syllables found for state", state_idx) + return + max_dur = np.max(durs) + + # Initialize the crowd movie + crowd_movie = np.zeros((max_dur + pad * 2, raw_size[1], raw_size[0], 3), + dtype='uint8') + + for these_slices, posterior in zip(slices, posteriors): + data = posterior.data + for slc in these_slices: + lpad = min(pad, slc.start) + rpad = min(pad, len(data['frames']) - slc.stop) + dur = slc.stop - slc.start + padded_slc = slice(slc.start - lpad, slc.stop + rpad) + centroid_x = data['centroid_x_px'][padded_slc] + offset[0] + centroid_y = data['centroid_y_px'][padded_slc] + offset[1] + angles = np.rad2deg(data['angles'][padded_slc]) + frames = (data['frames'][padded_slc] / scale).astype('uint8') + flips = np.zeros(angles.shape, dtype='bool') + + for i in range(lpad + dur + rpad): + if np.any(np.isnan([centroid_x[i], centroid_y[i]])): + continue + + rr = (yc + centroid_y[i]).astype('int16') + cc = (xc + centroid_x[i]).astype('int16') + + if (np.any(rr < 1) + or np.any(cc < 1) + or np.any(rr >= raw_size[1]) + or np.any(cc >= raw_size[0]) + or (rr[-1] - rr[0] != crop_size[0]) + or (cc[-1] - cc[0] != crop_size[1])): + continue + + # rotate and clip the current frame + new_frame_clip = frames[i][:, :, None] * np.ones((1, 1, 3)) + rot_mat = cv2.getRotationMatrix2D((xc0, yc0), angles[i], 1) + new_frame_clip = cv2.warpAffine(new_frame_clip.astype('float32'), + rot_mat, crop_size).astype(frames.dtype) + + # overlay a circle on the mouse + if i >= lpad and i <= pad + dur: + cv2.circle(new_frame_clip, (xc0, yc0), 3, + (255, 0, 0), -1) + + # superimpose the clipped mouse + old_frame = crowd_movie[i] + new_frame = np.zeros_like(old_frame) + new_frame[rr[0]:rr[-1], cc[0]:cc[-1]] = new_frame_clip + + # zero out based on min_height before taking the non-zeros + new_frame[new_frame < min_height] = 0 + old_frame[old_frame < min_height] = 0 + + new_frame_nz = new_frame > 0 + old_frame_nz = old_frame > 0 + + blend_coords = np.logical_and(new_frame_nz, old_frame_nz) + overwrite_coords = np.logical_and(new_frame_nz, ~old_frame_nz) + + old_frame[blend_coords] = .5 * old_frame[blend_coords] \ + + .5 * new_frame[blend_coords] + old_frame[overwrite_coords] = new_frame[overwrite_coords] + + crowd_movie[i] = old_frame + + return crowd_movie + +def make_crowd_movie_grouped(state_idx, + num_taus, + posteriors, + pad=30, + raw_size=(512, 424), + crop_size=(80, 80), + offset=(50, 50), + scale=.5, + min_height=10, + **kwargs): + ''' + Adapted from https://github.com/dattalab/moseq2-viz/blob/release/moseq2_viz/viz.py + + Creates crowd movie video numpy array. + Parameters + ---------- + dataset (list of dicts): list of dictionaries containing data + slices (np.ndarray): video slices of specific syllable label + pad (int): number of frame padding in video + raw_size (tuple): video dimensions. + frame_path (str): variable to access frames in h5 file + crop_size (tuple): mouse crop size + offset (tuple): centroid offsets from cropped videos + scale (int): mouse size scaling factor. + min_height (int): minimum max height from floor to use. + kwargs (dict): extra keyword arguments + Returns + ------- + crowd_movie (np.ndarray): crowd movie for a specific syllable. 
+ ''' + slices = extract_syllable_slices(state_idx, posteriors) + + xc0, yc0 = crop_size[1] // 2, crop_size[0] // 2 + xc = np.arange(-xc0, xc0 + 1, dtype='int16') + yc = np.arange(-yc0, yc0 + 1, dtype='int16') + + durs = [] + for these_slices in slices: + for slc in these_slices: + durs.append(slc.stop - slc.start) + + if len(durs) == 0: + print("no valid syllables found for state", state_idx) + return + max_dur = np.max(durs) + + # Initialize the crowd movie + crowd_movie = np.zeros((max_dur + pad * 2, raw_size[1], raw_size[0], 3), + dtype='uint8') + + for these_slices, posterior in zip(slices, posteriors): + data = posterior.data + for slc in these_slices: + lpad = min(pad, slc.start) + rpad = min(pad, len(data['frames']) - slc.stop) + dur = slc.stop - slc.start + padded_slc = slice(slc.start - lpad, slc.stop + rpad) + centroid_x = data['centroid_x_px'][padded_slc] + offset[0] + centroid_y = data['centroid_y_px'][padded_slc] + offset[1] + angles = np.rad2deg(data['angles'][padded_slc]) + frames = (data['frames'][padded_slc] / scale).astype('uint8') + flips = np.zeros(angles.shape, dtype='bool') + + for i in range(lpad + dur + rpad): + if np.any(np.isnan([centroid_x[i], centroid_y[i]])): + continue + + rr = (yc + centroid_y[i]).astype('int16') + cc = (xc + centroid_x[i]).astype('int16') + + if (np.any(rr < 1) + or np.any(cc < 1) + or np.any(rr >= raw_size[1]) + or np.any(cc >= raw_size[0]) + or (rr[-1] - rr[0] != crop_size[0]) + or (cc[-1] - cc[0] != crop_size[1])): + continue + + # rotate and clip the current frame + new_frame_clip = frames[i][:, :, None] * np.ones((1, 1, 3)) + rot_mat = cv2.getRotationMatrix2D((xc0, yc0), angles[i], 1) + new_frame_clip = cv2.warpAffine(new_frame_clip.astype('float32'), + rot_mat, crop_size).astype(frames.dtype) + + # overlay a circle on the mouse + if i >= lpad and i <= pad + dur: + cv2.circle(new_frame_clip, (xc0, yc0), 3, + (255, 0, 0), -1) + + # superimpose the clipped mouse + old_frame = crowd_movie[i] + new_frame = np.zeros_like(old_frame) + new_frame[rr[0]:rr[-1], cc[0]:cc[-1]] = new_frame_clip + + # zero out based on min_height before taking the non-zeros + new_frame[new_frame < min_height] = 0 + old_frame[old_frame < min_height] = 0 + + new_frame_nz = new_frame > 0 + old_frame_nz = old_frame > 0 + + blend_coords = np.logical_and(new_frame_nz, old_frame_nz) + overwrite_coords = np.logical_and(new_frame_nz, ~old_frame_nz) + + old_frame[blend_coords] = .5 * old_frame[blend_coords] \ + + .5 * new_frame[blend_coords] + old_frame[overwrite_coords] = new_frame[overwrite_coords] + + crowd_movie[i] = old_frame + + return crowd_movie + +def make_crowd_movie_indiv(state_idx, + posteriors, + pad=30, + raw_size=(512, 424), + crop_size=(80, 80), + offset=(50, 50), + scale=.5, + min_height=10, + **kwargs): + ''' + Adapted from https://github.com/dattalab/moseq2-viz/blob/release/moseq2_viz/viz.py + + Creates crowd movie video numpy array. + Parameters + ---------- + dataset (list of dicts): list of dictionaries containing data + slices (np.ndarray): video slices of specific syllable label + pad (int): number of frame padding in video + raw_size (tuple): video dimensions. + frame_path (str): variable to access frames in h5 file + crop_size (tuple): mouse crop size + offset (tuple): centroid offsets from cropped videos + scale (int): mouse size scaling factor. + min_height (int): minimum max height from floor to use. + kwargs (dict): extra keyword arguments + Returns + ------- + crowd_movie (np.ndarray): crowd movie for a specific syllable. 
+ ''' + slices = extract_syllable_slices_indiv(state_idx, posteriors) + + xc0, yc0 = crop_size[1] // 2, crop_size[0] // 2 + xc = np.arange(-xc0, xc0 + 1, dtype='int16') + yc = np.arange(-yc0, yc0 + 1, dtype='int16') + + durs = [] + for these_slices in slices: + for slc in these_slices: + durs.append(slc.stop - slc.start) + + if len(durs) == 0: + print("no valid syllables found for state", state_idx) + return + max_dur = np.max(durs) + + # Initialize the crowd movie + crowd_movie = np.zeros((max_dur + pad * 2, raw_size[1], raw_size[0], 3), + dtype='uint8') + + for these_slices, posterior in zip(slices, posteriors): + data = posterior.data + for slc in these_slices: + lpad = min(pad, slc.start) + rpad = min(pad, len(data['frames']) - slc.stop) + dur = slc.stop - slc.start + padded_slc = slice(slc.start - lpad, slc.stop + rpad) + centroid_x = data['centroid_x_px'][padded_slc] + offset[0] + centroid_y = data['centroid_y_px'][padded_slc] + offset[1] + angles = np.rad2deg(data['angles'][padded_slc]) + frames = (data['frames'][padded_slc] / scale).astype('uint8') + flips = np.zeros(angles.shape, dtype='bool') + + for i in range(lpad + dur + rpad): + if np.any(np.isnan([centroid_x[i], centroid_y[i]])): + continue + + rr = (yc + centroid_y[i]).astype('int16') + cc = (xc + centroid_x[i]).astype('int16') + + if (np.any(rr < 1) + or np.any(cc < 1) + or np.any(rr >= raw_size[1]) + or np.any(cc >= raw_size[0]) + or (rr[-1] - rr[0] != crop_size[0]) + or (cc[-1] - cc[0] != crop_size[1])): + continue + + # rotate and clip the current frame + new_frame_clip = frames[i][:, :, None] * np.ones((1, 1, 3)) + rot_mat = cv2.getRotationMatrix2D((xc0, yc0), angles[i], 1) + new_frame_clip = cv2.warpAffine(new_frame_clip.astype('float32'), + rot_mat, crop_size).astype(frames.dtype) + + # overlay a circle on the mouse + if i >= lpad and i <= pad + dur: + cv2.circle(new_frame_clip, (xc0, yc0), 3, + (255, 0, 0), -1) + + # superimpose the clipped mouse + old_frame = crowd_movie[i] + new_frame = np.zeros_like(old_frame) + new_frame[rr[0]:rr[-1], cc[0]:cc[-1]] = new_frame_clip + + # zero out based on min_height before taking the non-zeros + new_frame[new_frame < min_height] = 0 + old_frame[old_frame < min_height] = 0 + + new_frame_nz = new_frame > 0 + old_frame_nz = old_frame > 0 + + blend_coords = np.logical_and(new_frame_nz, old_frame_nz) + overwrite_coords = np.logical_and(new_frame_nz, ~old_frame_nz) + + old_frame[blend_coords] = .5 * old_frame[blend_coords] \ + + .5 * new_frame[blend_coords] + old_frame[overwrite_coords] = new_frame[overwrite_coords] + + crowd_movie[i] = old_frame + + return crowd_movie + +# def make_indiv_movie(state, posteriors, pad=30): +# raw_size = posteriors[0].data['frames'][0].shape +# num_taus = len(posteriors[0].model.taus) +# +# state_slices = [[] for tau in range(num_taus)] +# state_frames = [[] for tau in range(num_taus)] +# for posterior in posteriors: +# data = posterior.data +# model = posterior.model +# num_discrete_states = model.num_discrete_states +# for tau in range(num_taus): +# state_idx = state*num_taus + tau +# slices = extract_syllable_slices(state_idx, [posterior]) +# if len(slices[0]) == 0: +# print("no valid syllables found for state", state_idx) +# elif len(slices[0]) != 0 and state_slices[tau] == []: +# state_slices[tau] = [slices[0][0]] +# slc = slices[0][0] +# lpad = min(pad, slc.start) +# rpad = min(pad, len(data['frames']) - slc.stop) +# dur = slc.stop - slc.start +# padded_slc = slice(slc.start - lpad, slc.stop + rpad) +# state_frames[tau] = 
(data['frames'][padded_slc]).astype('uint8') +# +# +# durs = [] +# for slc in state_slices: +# if len(slc) is not 0: +# slc = slc[0] +# durs.append(slc.stop - slc.start) +# +# max_dur = np.max(durs) +# +# # Initialize the crowd movie +# crowd_movie = np.zeros((max_dur + pad * 2, raw_size[1]*num_taus, raw_size[0]), +# dtype='uint8') +# +# for it, frame, slc in zip(range(num_taus), state_frames,state_slices): +# slc = slc[0] +# dur = slc.stop - slc.start +# crowd_movie[:dur+pad*2,raw_size[1]*it:raw_size[1]*(it+1),:] = frame +# +# crowd_movie = crowd_movie[:,:, :, None] * np.ones((1, 1, 3)) +# +# crowd_movie = cv2.normalize(crowd_movie, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) +# crowd_movie = crowd_movie.transpose([0, 3, 1, 2]) +# wandb.log( +# {'temp': wandb.Video(crowd_movie, fps=30, format="mp4")}) +# return crowd_movie \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..f67abae --- /dev/null +++ b/train.py @@ -0,0 +1,55 @@ +import wandb +import os +from twarhmm import TWARHMM, LinearRegressionObservations +from plotting_util import wnb_histogram_plot, save_videos_wandb, centroid_velocity_plot +from data_util import load_dataset, standardize_pcs, precompute_ar_covariates, log_wandb_model +import datetime + +data_dim = 10 +num_lags = 1 + +hyperparameter_defaults = dict( + num_discrete_states=5, + data_dim=data_dim, + covariates_dim=11, + tau_scale=0.6, + num_taus=5, + kappa=10000, + alpha=5, + covariance_reg=1e-4) + +train_dataset, test_dataset = load_dataset(num_pcs=data_dim) +train_dataset, mean, std = standardize_pcs(train_dataset) +test_dataset, _, _ = standardize_pcs(test_dataset, mean, std) + +print("data loaded") +# First compute the autoregression covariates +precompute_ar_covariates(train_dataset, num_lags=num_lags, fit_intercept=True) +precompute_ar_covariates(test_dataset, num_lags=num_lags, fit_intercept=True) + +# Then precompute the sufficient statistics +LinearRegressionObservations.precompute_suff_stats(train_dataset) +LinearRegressionObservations.precompute_suff_stats(test_dataset) + +covariates_dim = train_dataset[0]['covariates'].shape[1] + +projectname = "twarhmm" +wandb.init(config=hyperparameter_defaults, entity="twss", project=projectname) +config = wandb.config + +twarhmm = TWARHMM(config, None) + +train_lls, test_lls, train_posteriors, test_posteriors, = \ + twarhmm.fit_stoch(train_dataset, + test_dataset, + num_epochs=50, compute_posteriors=True, fit_transitions=True) + +e = datetime.datetime.now() + +log_wandb_model(twarhmm, "twarhmm_K{}_T{}".format(twarhmm.num_discrete_states,len(twarhmm.taus)),type="model") +if test_posteriors is not None: + wnb_histogram_plot(test_posteriors, tau_duration=True, duration_plot=True, state_usage_plot=True, ordered_state_usage=True, state_switch=True) + centroid_velocity_plot(test_posteriors) +#save_videos_wandb(test_posteriors) + +wandb.finish() diff --git a/train_gp.py b/train_gp.py new file mode 100644 index 0000000..97a0298 --- /dev/null +++ b/train_gp.py @@ -0,0 +1,60 @@ +import wandb +import numpy as np +import os +from warhmm_gp import TWARHMM_GP, LinearRegressionObservations_GP +from data_util import load_dataset, standardize_pcs, precompute_ar_covariates, log_wandb_model +import datetime +from kernels import RBF +import matplotlib.pyplot as plt + +data_dim = 10 +num_lags = 1 + +hyperparameter_defaults = dict( + num_discrete_states=20, + data_dim=data_dim, + covariates_dim=11, + tau_scale=1, + num_taus=31, + kappa=10000, + alpha=5, + covariance_reg=1e-4, + lengthscale=1 
+) + +train_dataset, test_dataset = load_dataset(num_pcs=data_dim) +train_dataset, mean, std = standardize_pcs(train_dataset) +test_dataset, _, _ = standardize_pcs(test_dataset, mean, std) + +print("data loaded") +# First compute the autoregression covariates +precompute_ar_covariates(train_dataset, num_lags=num_lags, fit_intercept=True) +precompute_ar_covariates(test_dataset, num_lags=num_lags, fit_intercept=True) + +# Then precompute the sufficient statistics +LinearRegressionObservations_GP.precompute_suff_stats(train_dataset) +LinearRegressionObservations_GP.precompute_suff_stats(test_dataset) + +covariates_dim = train_dataset[0]['covariates'].shape[1] + +projectname = "twarhmm_gp" +wandb.init(config=hyperparameter_defaults, entity="twss", project=projectname) +config = wandb.config + +taus = np.linspace(-config['tau_scale'], config['tau_scale'], config['num_taus']) +twarhmm_gp = TWARHMM_GP(config, taus, kernel=RBF(config['num_discrete_states'], config['lengthscale'])) + +train_lls, test_lls, train_posteriors, test_posteriors, = \ + twarhmm_gp.fit_stoch(train_dataset, + test_dataset, + num_epochs=50, fit_transitions=True, fit_tau=False, fit_kernel_params=False, wandb_log=True) +#plt.plot(test_lls) +# e = datetime.datetime.now() +# +log_wandb_model(twarhmm_gp, "twarhmm_gp_K{}_T{}".format(twarhmm_gp.num_discrete_states,len(twarhmm_gp.taus)),type="model") +# if test_posteriors is not None: +# wnb_histogram_plot(test_posteriors, tau_duration=True, duration_plot=True, state_usage_plot=True, ordered_state_usage=True, state_switch=True) +# centroid_velocity_plot(test_posteriors) +# #save_videos_wandb(test_posteriors) +# +wandb.finish() \ No newline at end of file diff --git a/twarhmm.py b/twarhmm.py new file mode 100644 index 0000000..e3d6926 --- /dev/null +++ b/twarhmm.py @@ -0,0 +1,709 @@ +import numpy as np +import numpy.random as npr +import scipy.stats +import torch +from tqdm.auto import trange +from torch.distributions import MultivariateNormal +import pickle +import os +from util import random_rotation, sum_tuples +import wandb +import time +from numba import njit, prange + +device = torch.device('cpu') +dtype = torch.float64 +to_t = lambda array: torch.tensor(array, device=device, dtype=dtype) +from_t = lambda tensor: tensor.to("cpu").detach().numpy() + +class TWARHMM(object): + + def __init__(self, config, taus=None): #config is a dictionary containing parameters + self.config = dict(config) + self.num_discrete_states = config["num_discrete_states"] + self.data_dim = config["data_dim"] + self.covariates_dim = config["covariates_dim"] + if np.any(taus == None): self.taus = np.logspace(-config["tau_scale"],config["tau_scale"],config["num_taus"],base=2) + else: self.taus = taus + if config["num_taus"] == 1: + self.taus = np.array([1.]) + self.kappa = config["kappa"] + self.alpha = config["alpha"] + self.transitions = Transitions(self.num_discrete_states, len(self.taus), self.alpha, self.kappa, random_init=False) + + self.observations = LinearRegressionObservations(self.num_discrete_states, self.data_dim, + self.covariates_dim, self.taus, config["covariance_reg"]) + + def fit(self, train_dataset, test_dataset, seed=0, num_epochs=50, fit_observations=True, fit_transitions=False, fit_tau_trans=False): + # Fit using full batch EM + num_train = sum([len(data["data"]) for data in train_dataset]) + num_test = sum([len(data["data"]) for data in test_dataset]) + # Initialize with a random posterior + #posteriors = initialize_posteriors(train_dataset, self.num_discrete_states * 
self.taus.shape[0], seed=seed) + total_states = self.num_discrete_states*len(self.taus) + posteriors = [Posterior(self, data_dict, total_states) for data_dict in train_dataset] + for posterior in posteriors: + posterior.update() + continuous_expectations, discrete_expectations = self.compute_expected_suff_stats(train_dataset, posteriors, self.taus, fit_observations, fit_transitions) + train_lls = [] + test_lls = [] + + # Main loop + for itr in trange(num_epochs): + print(itr) + self.M_step(continuous_expectations, discrete_expectations, fit_observations, fit_transitions, fit_tau_trans) + + for posterior in posteriors: + posterior.update() + + # Compute the expected sufficient statistics under the new posteriors + continuous_expectations, discrete_expectations = self.compute_expected_suff_stats(train_dataset, posteriors, self.taus, fit_observations, fit_transitions) + + # Store the average train likelihood + avg_train_ll = sum([p.marginal_likelihood() for p in posteriors]) / num_train + train_lls.append(avg_train_ll) + + # Compute the posteriors for the test dataset too + test_posteriors = [Posterior(self,data_dict,total_states) for data_dict in test_dataset] + + for posterior in test_posteriors: + posterior.update() + + # Store the average test likelihood + avg_test_ll = sum([p.marginal_likelihood() for p in test_posteriors]) / num_test + test_lls.append(avg_test_ll) + + # convert lls to arrays + train_lls = np.array(train_lls) + test_lls = np.array(test_lls) + return train_lls, test_lls, posteriors, test_posteriors + + def fit_stoch(self, train_dataset, test_dataset, forgetting_rate=-0.5, seed=0, num_epochs=5, fit_observations=True, + fit_transitions=True, fit_tau_trans = True, compute_posteriors=True, wandb_log=True): + # Get some constants + num_batches = len(train_dataset) + taus = np.array(self.taus) + num_test = sum([len(data["data"]) for data in test_dataset]) + total_states = self.num_discrete_states * len(self.taus) + num_train = sum([len(data["data"]) for data in train_dataset]) + + # Initialize the step size schedule + schedule = np.arange(1, 1 + num_batches * num_epochs) ** (forgetting_rate) + + # Initialize progress bars + outer_pbar = trange(num_epochs) + inner_pbar = trange(num_batches) + outer_pbar.set_description("Epoch") + inner_pbar.set_description("Batch") + + # Main loop + rng = npr.RandomState(seed) + train_lls = [] + test_lls = [] + + it_times = np.zeros((num_epochs,num_batches)) + + for epoch in range(num_epochs): + perm = rng.permutation(num_batches) + + inner_pbar.reset() + for itr in range(num_batches): + t = time.time() + minibatch = [train_dataset[perm[itr]]] + this_num_train = len(minibatch[0]["data"]) + + posteriors = [Posterior(self, data, total_states) for data in minibatch] + + # E step: on this minibatch + for posterior in posteriors: + posterior.update() + + if itr == 0 and epoch == 0: continuous_expectations, discrete_expectations = self.compute_expected_suff_stats( + minibatch, posteriors, taus, fit_observations, fit_transitions) + # M step: using current stats + self.M_step(continuous_expectations, discrete_expectations, fit_observations, fit_transitions, fit_tau=fit_tau_trans) + + these_continuous_expectations, these_discrete_expectations = self.compute_expected_suff_stats(minibatch, + posteriors, + taus, fit_observations, + fit_transitions) + rescale = lambda x: num_train / this_num_train * x + + # Rescale the statistics as if they came from the whole dataset + rescaled_cont_stats = tuple(rescale(st) for st in these_continuous_expectations) + 
rescaled_disc_stats = tuple(rescale(st) for st in these_discrete_expectations) + + # Take a convex combination of the statistics using current step sz + stepsize = schedule[epoch * num_batches + itr] + continuous_expectations = tuple( + sum(x) for x in zip(tuple(st * (1 - stepsize) for st in continuous_expectations), + tuple(st * (stepsize) for st in rescaled_cont_stats))) + discrete_expectations = tuple( + sum(x) for x in zip(tuple(st * (1 - stepsize) for st in discrete_expectations), + tuple(st * (stepsize) for st in rescaled_disc_stats))) + + # Store the normalized log likelihood for this minibatch + avg_mll = sum([p.marginal_likelihood() for p in posteriors]) / this_num_train + train_lls.append(avg_mll) + + elapsed = time.time()-t + #print(elapsed) + it_times[epoch,itr] = elapsed + inner_pbar.set_description("Batch LL: {:.3f}".format(avg_mll)) + inner_pbar.update() + if wandb_log: wandb.log({'batch_ll': avg_mll}) + + # Evaluate the likelihood and posteriors on the test dataset + if compute_posteriors: + test_posteriors = [Posterior(self, test_data, total_states, seed) for test_data in test_dataset] + for posterior in test_posteriors: + posterior.update() + avg_test_mll = sum([p.marginal_likelihood() for p in test_posteriors]) / num_test + else: + mlls = [] + for test_data in test_dataset: + posterior = Posterior(self, test_data, total_states, seed) + posterior.update() + mlls.append(posterior.marginal_likelihood()) + avg_test_mll = np.sum(mlls)/ num_test + test_posteriors = None + test_lls.append(avg_test_mll) + outer_pbar.set_description("Test LL: {:.3f}".format(avg_test_mll)) + outer_pbar.update() + if wandb_log: wandb.log({'test_ll': avg_test_mll}) + + + # convert lls to arrays + train_lls = np.array(train_lls) + test_lls = np.array(test_lls) + + print('average iteration time: ', it_times.mean()) + return train_lls, test_lls, posteriors, test_posteriors + + def save(self, filepath): + # TODO: add optional artifact saving + os.mkdir(filepath) + obs_outfile = open(os.path.join(filepath, "model"), 'wb') + pickle.dump(self, obs_outfile) + obs_outfile.close() + + @staticmethod + def load(dir): + model_infile = open(os.path.join(dir, "model"), 'rb') + model = pickle.load(model_infile) + model_infile.close() + return model + + @staticmethod + def load_wnb(artifact_filepath): + artifact = wandb.use_artifact(artifact_filepath, type="model") + artifact_dir = artifact.download() + return TWARHMM.load(artifact_dir) + + + def E_step(self,initial_dist, transition_matrix, log_likes, compute_joints=True): + (Pz,Pt) = transition_matrix + + max_factor = np.max(log_likes, axis=1, keepdims=True) + alphas, marginal_ll = self.nb_forward_pass(initial_dist, transition_matrix, log_likes,max_factor) + + betas = self.nb_backward_pass(transition_matrix, log_likes, max_factor) + + likes_tilde = np.exp(log_likes - np.max(log_likes, axis=1)[:, None]) + hadamard_prod = alphas * likes_tilde * betas + expected_states = hadamard_prod / np.sum(hadamard_prod, axis=1)[:, None] + + alphas = alphas.reshape((alphas.shape[0],self.num_discrete_states,len(self.taus))) + betas = betas.reshape((betas.shape[0], self.num_discrete_states, len(self.taus))) + log_likes = log_likes.reshape((log_likes.shape[0],self.num_discrete_states, len(self.taus))) + + if compute_joints: #TODO: split into 2 matrices + alphas_z = alphas.sum(axis=2) + alphas_t = alphas.sum(axis=1) + betas_z = betas.sum(axis=2) + betas_t = betas.sum(axis=1) + log_likes_z = log_likes.sum(axis=2) + log_likes_t = log_likes.sum(axis=1) + likes_tilde_z = 
np.exp(log_likes_z - np.max(log_likes_z, axis=1)[:, None]) + likes_tilde_t = np.exp(log_likes_t - np.max(log_likes_t, axis=1)[:, None]) + + hadamard_2_z = alphas_z[:-1, :, None] * likes_tilde_z[:-1, :, None] * likes_tilde_z[1:, None,:] * Pz[None, :, :] * betas_z[1:,None,:] + expected_joints_z = hadamard_2_z / np.sum(hadamard_2_z, axis=(1, 2), keepdims=True) + + hadamard_2_t = alphas_t[:-1, :, None] * likes_tilde_t[:-1, :, None] * likes_tilde_t[1:, None, :] * Pt[None,:,:] * betas_t[1:,None,:] + expected_joints_t = hadamard_2_t / np.sum(hadamard_2_t, axis=(1, 2), keepdims=True) + + expected_joints = (expected_joints_z,expected_joints_t) + else: + expected_joints = (None, None) + + # Package the results into a dictionary summarizing the posterior + posterior = dict(expected_states=expected_states, + expected_joints=expected_joints, + marginal_ll=marginal_ll) + return posterior + + def M_step(self, continuous_expectations, discrete_expectations, fit_observations, fit_transitions, fit_tau): + if fit_transitions: self.transitions.M_step(discrete_expectations, fit_tau=fit_tau) + if fit_observations: self.observations.M_step(continuous_expectations) + + def forward_pass(self, initial_dist, transition_matrix, log_likes): + (Pz,Pt) = transition_matrix + alphas = np.zeros_like(log_likes) + marginal_ll = 0 + T = log_likes.shape[0] + max_factor = np.max(log_likes, axis=1, keepdims=True) + likes_tilde = np.exp(log_likes - max_factor) + + alphas[0] = np.squeeze(initial_dist) + + for t in range(1, T): + A_t_minus_1 = np.sum(alphas[t - 1] * likes_tilde[t - 1], axis=-1) + # alphas[t] = (1 / A_t_minus_1) * \ + # transition_matrix.T @ (alphas[t - 1] * likes_tilde[t - 1]) + alphas[t] = (1 / A_t_minus_1) * \ + np.einsum('ab,bc,cd->ad',Pz.T,np.reshape(alphas[t - 1] * likes_tilde[t - 1],(Pz.shape[0],Pt.shape[0])),Pt).ravel() + if A_t_minus_1 > 0 and not np.any(np.isnan(A_t_minus_1)): + marginal_ll += np.sum(np.log(A_t_minus_1) + max_factor[t - 1]) + else: + print("yikes") + + A_t = np.sum(alphas[t] * likes_tilde[t], axis=-1) + marginal_ll += np.sum(np.log(A_t) + max_factor[t]) + + return alphas, marginal_ll + + def backward_pass(self, transition_matrix, log_likes): + (Pz,Pt) = transition_matrix + betas = np.zeros_like(log_likes) + T, K = log_likes.shape + max_factor = np.max(log_likes, axis=1, keepdims=True) + likes_tilde = np.exp(log_likes - max_factor) + + betas[T - 1] = 1 / K + + for t in range(T - 2, -1, -1): # iterate from T-2 ==> 0 + #betas[t] = transition_matrix @ (betas[t + 1] * likes_tilde[t + 1]) + betas[t] = np.einsum('ab,bc,cd->ad',Pz,np.reshape(betas[t + 1] * likes_tilde[t + 1],(Pz.shape[0],Pt.shape[0])),Pt.T).ravel() + betas[t] /= np.sum(betas[t]) # normalize before the next step + + return betas + + def compute_expected_suff_stats(self, dataset, posteriors, taus, fit_observations, fit_transitions): + assert isinstance(dataset, list) + assert isinstance(posteriors, list) + + # Helper function to compute expected counts and sufficient statistics + # for a single time series and corresponding posterior. 
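+        # Returns two tuples: per-state expected regression statistics (weighted by the posterior over z and by E[tau] or E[1/tau]) and expected transition statistics (expected z- and tau-transition counts plus the time-0 state distribution).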
+ def _compute_expected_suff_stats(data, posterior, taus, fit_observations, fit_transitions): + Dx = data["data"].shape[1] + D = data["covariates"].shape[1] + q = posterior.expected_states() + (fancy_e_z, fancy_e_t) = posterior.expected_transitions() #TODO: change to return two matrices + q += 1e-16 + q = q / q.sum(axis=1, keepdims=True) # basically Laplace smoothing + L = taus.shape[0] + K = q.shape[1] / L + q = q.reshape((q.shape[0], int(K), L)) # dim TxKxL + K = q.shape[1] + dxxT_Etau = np.zeros((K, Dx, D)) + xxT = np.zeros((K, D, D)) + dxdxT_Etau2 = np.zeros((K, Dx, Dx)) + T = np.zeros(K) + fancy_e_z_over_T = np.zeros((self.num_discrete_states, self.num_discrete_states)) + fancy_e_t_over_T = np.zeros((len(self.taus), len(self.taus))) + q_one = np.zeros(self.num_discrete_states * len(self.taus)) + for k in range(K): + qzt = q[:, k, :].sum(axis=-1) + + if fit_observations: + #TODO: rewrite with descriptive variable names + q_taugivenz = q[:, k, :] / np.sum(q[:, k, :], axis=-1, keepdims=True) + E_tau_given_k = np.einsum('tl,l -> t', q_taugivenz, taus) # TxL and L -> T + E_tauinv_given_k = np.einsum('tl,l -> t', q_taugivenz, (1/taus)) # TxL and L -> T + + # sufficient stats for A + dxxT_Etau[k, :, :] = np.einsum('t,tij->ij', qzt, data['suff_stats'][2]) + xxT[k, :, :] = np.einsum('t,t,tij->ij', qzt, E_tauinv_given_k, data['suff_stats'][3]) + + # sufficient stats for Q + dxdxT_Etau2[k, :, :] = np.einsum('t,t,tij->ij', qzt, E_tau_given_k, data['suff_stats'][1]) + + + T[k] = np.dot(qzt, data['suff_stats'][0]) + + if fit_transitions: + fancy_e_z_over_T = np.einsum('tij->ij', fancy_e_z) + fancy_e_t_over_T = np.einsum('tij->ij', fancy_e_t) + + q_one = posterior.expected_states()[0] + + stats = (tuple((dxxT_Etau, xxT, dxdxT_Etau2, T)), + tuple((fancy_e_z_over_T, fancy_e_t_over_T, q_one))) + return stats + + # Sum the expected stats over the whole dataset + stats = (None,None) + for data, posterior in zip(dataset, posteriors): + these_stats = _compute_expected_suff_stats(data, posterior, taus, fit_observations, fit_transitions) + stats_cont = sum_tuples(stats[0], these_stats[0]) + stats_disc = sum_tuples(stats[1], these_stats[1]) + stats = (stats_cont, stats_disc) + return stats + + def sample(self, T, bias=False): #TODO: might only work for relatively low total states + observations = self.observations + initial_dist = self.transitions.initial_dist + (Pz,Pt) = self.transitions.transition_matrix + transition_matrix = np.kron(Pz,Pt) + taus = self.taus + if bias: + x = np.hstack((np.zeros((T, observations.data_dim)),np.ones((T,1)))) + else: + x = np.zeros((T, observations.data_dim)) + z = np.zeros((T), dtype=np.int) + num_states = initial_dist.shape[0] + z[0] = np.random.choice(range(initial_dist.shape[0]), p=initial_dist) + + timescaled_weights, timescaled_covs = self.observations.timescale_weights_covs(observations.weights, observations.covs, taus) + if bias: + x[0,:-1] = MultivariateNormal(to_t(np.zeros(observations.data_dim)), to_t(timescaled_covs[z[0], :, :])).sample() + else: + x[0] = MultivariateNormal(to_t(np.zeros(observations.data_dim)), to_t(timescaled_covs[z[0], :, :])).sample() + for i in range(1, T): + z[i] = np.random.choice(range(num_states), p=transition_matrix[z[i - 1], :]) + # mu = timescaled_weights[z[i], :, :-1]@x[i-1] + timescaled_weights[z[i], :, -1] #changed to account for no bias + mu = timescaled_weights[z[i], :, :] @ x[i - 1] + cov = timescaled_covs[z[i], :, :] + if bias: + x[i,:-1] = MultivariateNormal(to_t(mu), to_t(cov)).sample() + else: + x[i] = 
MultivariateNormal(to_t(mu), to_t(cov)).sample() + if bias: + x = x[:,:-1] + return z, x + + @staticmethod + @njit() + def nb_forward_pass(initial_dist, transition_matrix, log_likes, max_factor): + (Pz,Pt) = transition_matrix + alphas = np.zeros_like(log_likes) + marginal_ll = 0 + T = log_likes.shape[0] + likes_tilde = np.exp(log_likes - max_factor) + + alphas[0] = initial_dist + + for t in range(1, T): + A_t_minus_1 = np.sum(alphas[t - 1] * likes_tilde[t - 1]) + alphas[t] = (1 / A_t_minus_1) * \ + (Pz.T @ (np.reshape(alphas[t - 1] * likes_tilde[t - 1],(Pz.shape[0],Pt.shape[0]))) @ Pt).ravel() + # alphas[t] = (1 / A_t_minus_1) * \ + # np.einsum('ab,bc,cd->ad',Pz.T,np.reshape(alphas[t - 1] * likes_tilde[t - 1],(Pz.shape[0],Pt.shape[0])),Pt).ravel() + # if A_t_minus_1 > 0 and not np.any(np.isnan(A_t_minus_1)): + marginal_ll += np.sum(np.log(A_t_minus_1) + max_factor[t - 1]) + # else: + # print("yikes") + + A_t = np.sum(alphas[t] * likes_tilde[t]) + marginal_ll += np.sum(np.log(A_t) + max_factor[t]) + + return alphas, marginal_ll + + @staticmethod + @njit() + def nb_backward_pass(transition_matrix, log_likes, max_factor): + (Pz,Pt) = transition_matrix + betas = np.zeros_like(log_likes) + T, K = log_likes.shape + likes_tilde = np.exp(log_likes - max_factor) + + betas[T - 1] = 1 / K + + for t in range(T - 2, -1, -1): # iterate from T-2 ==> 0 + betas[t] = (Pz @ (np.reshape(betas[t + 1] * likes_tilde[t + 1],(Pz.shape[0],Pt.shape[0]))) @ Pt.T).ravel() + #betas[t] = np.einsum('ab,bc,cd->ad',Pz,np.reshape(betas[t + 1] * likes_tilde[t + 1],(Pz.shape[0],Pt.shape[0])),Pt.T).ravel() + betas[t] /= np.sum(betas[t]) # normalize before the next step + + return betas + +class LinearRegressionObservations(object): + """ + Wrapper for a collection of Gaussian observation parameters. + """ + + def __init__(self, num_states, data_dim, covariate_dim, taus, covariance_reg, random_weights=True): + """ + Initialize a collection of observation parameters for an HMM whose + observation distributions are linear regressions. The HMM has + `num_states` (i.e. K) discrete states, `data_dim` (i.e. D) + dimensional observations, and `covariate_dim` covariates. + In an ARHMM, the covariates will be functions of the past data. + + Note: self.weights is always the continuous time operator. + """ + self.num_states = num_states + self.data_dim = data_dim + self.covariate_dim = covariate_dim + self.taus = taus + self.covariance_reg = covariance_reg + + # Initialize the model parameters + if random_weights: + self.weights = np.zeros((num_states, data_dim, covariate_dim)) + for i in range(num_states): + self.weights[i,:,:data_dim] = scipy.linalg.logm(random_rotation(data_dim,theta= np.pi/20)) + else: + self.weights = np.zeros((num_states, data_dim, covariate_dim)) + #TODO: do we need this scaling? + self.covs = .05*np.tile(np.eye(data_dim), (num_states, 1, 1)) + + @staticmethod + def precompute_suff_stats(dataset): + """ + Compute the sufficient statistics of the linear regression for each + data dictionary in the dataset. This modifies the dataset in place. + + Parameters + ---------- + dataset: a list of data dictionaries. + + Returns + ------- + Nothing, but the dataset is updated in place to have a new `suff_stats` + key, which contains a tuple of sufficient statistics. 
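+        Illustrative usage (the variable name `train_dataset` is hypothetical;
+        each data dict is assumed to already contain "data" and "covariates"):
+
+            LinearRegressionObservations.precompute_suff_stats(train_dataset)
+            ones_T, dxdxT, dxxT, xxT = train_dataset[0]['suff_stats']
+            # shapes: (T,), (T, Dx, Dx), (T, Dx, D), (T, D, D)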
+        """
+        for data in dataset:
+            x = data['data']
+            phi = data['covariates']
+            #TODO: update to generalize for lags >1
+            if x.shape[1] == phi.shape[1]: #no bias
+                dx = x - phi
+            else:
+                dx = x - phi[:,:-1]
+            data['suff_stats'] = (np.ones(len(x)),
+                                  np.einsum('ti,tj->tij', dx, dx),   # dxn dxn.T
+                                  np.einsum('ti,tj->tij', dx, phi),  # dxn xn-1.T
+                                  np.einsum('ti,tj->tij', phi, phi)) # xn-1 xn-1.T
+
+    def log_likelihoods(self, data):
+        """
+        Compute the matrix of log likelihoods of data for each state.
+        (torch.distributions is used for this, which requires converting
+        back and forth between numpy arrays and pytorch tensors.)
+
+        Parameters
+        ----------
+        data: a dictionary with multiple keys, including "data", the TxD array
+            of observations for this mouse.
+
+        Returns
+        -------
+        log_likes: a TxK array of log likelihoods for each datapoint and
+            discrete state (here K counts each (discrete state, tau) pair,
+            since the weights are tiled over taus).
+        """
+        y = to_t(data["data"])
+        x = data["covariates"]
+        taus = self.taus
+
+        timescaled_weights, timescaled_covs = self.timescale_weights_covs(self.weights, self.covs, taus)
+        means = to_t(timescaled_weights @ x.T)
+        covs = to_t(timescaled_covs)
+
+        K, _, _ = means.shape
+        T, _ = x.shape
+        log_likes = np.zeros((T, K))
+        for k in range(K):
+            dist = torch.distributions.MultivariateNormal(means[k].T, covs[k], validate_args=False)
+            log_likes[:, k] = dist.log_prob(y)
+        return log_likes
+
+    def M_step(self, continuous_expectations):
+        """
+        Compute the linear regression parameters given the expected
+        sufficient statistics.
+
+        Note: a small multiple of the identity (covariance_reg * I) is added to
+        each covariance matrix to ensure that the result is positive definite.
+
+        Parameters
+        ----------
+        continuous_expectations: a tuple of expected sufficient statistics
+
+        Returns
+        -------
+        Nothing, but self.weights and self.covs are updated in place.
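+        For each state k, the update below solves the expected-tau-weighted
+        normal equations: the continuous-time operator satisfies
+        A_k^T = solve(xxT[k], dxxT_Etau[k]^T), and covs[k] is the residual
+        covariance of that regression divided by T[k], plus covariance_reg * I
+        for numerical stability.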
+ """ + # stats = tuple((dxxT_over_Etau,xxT_over_Etau)) + dxxT_Etau, xxT, dxdxT_Etau2, T = continuous_expectations + ### + for k in range(self.num_states): + AstarT = np.linalg.solve(xxT[k], dxxT_Etau[k].T) + self.weights[k] = AstarT.T #continuous time operator (unscaled) + self.covs[k] = self.covariance_reg* np.eye(self.data_dim) + \ + (dxdxT_Etau2[k] - dxxT_Etau[k] @ AstarT - AstarT.T @ dxxT_Etau[k].T + AstarT.T @ xxT[k] @ AstarT) / T[k] + + @classmethod + def timescale_weights_covs(cls, weights,covs,taus): + ''' + scale continuous time operator + ''' + tiled_weights = np.repeat(weights,len(taus),axis=0) + tiled_taus = np.tile(taus,weights.shape[0]) + if weights.shape[1] == weights.shape[2]: + timescaled_weights = np.eye(weights.shape[1]) + tiled_weights/tiled_taus[:,None,None] + else: + timescaled_weights = np.hstack((np.eye(weights.shape[1]),np.zeros((weights.shape[1],1)))) + tiled_weights / tiled_taus[:, None, None] + tiled_covs = np.repeat(covs, len(taus), axis=0) + timescaled_covs = tiled_covs/tiled_taus[:,None,None] + return timescaled_weights, timescaled_covs + +class Transitions(object): + def __init__(self, num_discrete_states, num_taus, alpha, kappa, random_init=True): + self.num_discrete_states = num_discrete_states + self.num_taus = num_taus + self.initial_dist = np.ones(self.num_discrete_states*self.num_taus) / (self.num_discrete_states*self.num_taus) + if random_init: + Pz = .99 * np.eye(self.num_discrete_states) + .01 * npr.rand(self.num_discrete_states, + self.num_discrete_states) + Pz /= Pz.sum(axis=1, keepdims=True) + Pt = .95 * np.eye(self.num_taus) + .05 * npr.rand(self.num_taus, self.num_taus) + Pt /= Pt.sum(axis=1, keepdims=True) + else: + if self.num_discrete_states != 1: + Pz = .99 * np.eye(self.num_discrete_states) + .01/(self.num_discrete_states-1) * (np.ones((self.num_discrete_states, + self.num_discrete_states))-np.eye(self.num_discrete_states)) + else: Pz = np.array([[1.]]) + if self.num_taus != 1: + Pt = .95 * np.eye(self.num_taus) + .025 * (np.diag(np.ones(self.num_taus-1), 1) + np.diag(np.ones(self.num_taus-1), -1)) + Pt /= Pt.sum(axis=1, keepdims=True) + else: Pt = np.array([[1.]]) + self.transition_matrix = (Pz,Pt) + self.alpha = alpha + self.kappa = kappa + + def M_step(self, discrete_expectations, fit_z = True, fit_tau = True): #TODO: kron first pass is done + expected_joints_z, expected_joints_t, q_zero = discrete_expectations + if fit_z: + expected_joints_z += self.kappa * np.eye(self.num_discrete_states) + (self.alpha-1) * np.ones((self.num_discrete_states, self.num_discrete_states)) + expected_joints_z += 1e-16 + Pz = np.nan_to_num(expected_joints_z / expected_joints_z.sum(axis=1, keepdims=True)) + else: Pz = self.transition_matrix[0] + + if fit_tau: + expected_joints_t += self.kappa * np.eye(self.num_taus) + (self.alpha - 1) * np.ones((self.num_taus, self.num_taus)) + expected_joints_t += 1e-16 + Pt = np.nan_to_num(expected_joints_t / expected_joints_t.sum(axis=1, keepdims=True)) + else: + Pt = self.transition_matrix[1] + self.transition_matrix = (Pz,Pt) + self.initial_dist = q_zero / np.sum(q_zero, keepdims=True) + +class Posterior(object): + + def __init__(self, model, data, num_states, seed=0): + self.model = model + self.data = data + self.num_states = num_states + self.num_taus = len(self.model.taus) + self.num_discrete_states = self.model.num_discrete_states + self._posterior = self._initialize_posteriors(data, num_states, seed) + + def _initialize_posteriors(self, dataset, num_states, seed=0): + # rng = npr.RandomState(seed) + # 
expected_states = rng.rand(len(dataset["data"]), num_states) + # expected_states /= expected_states.sum(axis=1, keepdims=True) + expected_taus = np.ones( + (len(dataset["data"]), num_states, 2)) # mu, sigma for each time step and each discrete state + # expected_joints = rng.rand(len(dataset["data"]) - 1, num_states, num_states) + # expected_joints /= expected_joints.sum(axis=(1, 2), keepdims=True) + expected_states = np.zeros((len(dataset["data"]), num_states)) + # expected_joints = (np.zeros((len(dataset["data"]) - 1, self.model.num_discrete_states, self.model.num_discrete_states)), + # np.zeros((len(dataset["data"]) - 1, len(self.model.taus), + # len(self.model.taus)))) + expected_joints = (np.zeros((len(dataset["data"]), self.num_discrete_states,self.num_discrete_states)), + np.zeros((len(dataset["data"]),self.num_taus,self.num_taus))) + return dict(expected_states=expected_states, + expected_joints=expected_joints, + marginal_ll=-np.inf) + + def update(self): + """ + Run the exact message passing algorithm to infer the posterior distribution. + """ + + log_likes = self.model.observations.log_likelihoods(self.data) + #should throw error if compute_joints is False while trying to update transitions + #TODO: better way to handle compute_joints argument + new_posterior = self.model.E_step(self.model.transitions.initial_dist, self.model.transitions.transition_matrix, log_likes, compute_joints=True) + self._posterior = new_posterior + return self + + def get_states(self): + # assumes posterior is already updated + # TODO: replace with Viterbi + # currently: for every z_t, find max q(z_t| x_1:T) + # goal: max z_1:T q(z_1:T| x_1:T) + return self._posterior['expected_states'].argmax(1) + + def marginal_likelihood(self): + """Compute the marginal likelihood of the data under the model. + Returns: + ``\log p(x_{1:T})`` the marginal likelihood of the data + summing over discrete latent state sequences. + """ + if self._posterior is None: + self.update() + return self._posterior["marginal_ll"] + + def expected_states(self): + """Compute the expected values of the latent states under the + posterior distribution. + Returns: + ``E[z_t | x_{1:T}]`` the expected value of the latent state + at time ``t`` given the sequence of data. + """ + if self._posterior is None: + self.update() + return self._posterior["expected_states"] + + def expected_transitions(self): + """Compute the expected transitions of the latent states under the + posterior distribution. + Returns: + ``E[z_t z_{t+1} | x_{1:T}]`` the expected value of + adjacent latent states given the sequence of data. 
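+        Because the transition model factors into (Pz, Pt), the result is a
+        tuple (E_joints_z, E_joints_t): pairwise expectations over adjacent
+        time steps for the discrete states and for the time constants,
+        respectively.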
+ """ + if self._posterior is None: + self.update() + return self._posterior["expected_joints"] + + @staticmethod + def state_durations(states, total_states): + changepoints = states != np.hstack((states[1:], -1)) # 1 where state change occurs + changepoint_frame = np.where(changepoints)[0] # timestamps of changepoints + changepoint_states = states[changepoints] # state label of changepoint + state_durations = np.diff(np.hstack((0, changepoint_frame))) # duration before each change + state_durations[0] += 1 + durations = [] + for k in range(total_states): + changepoint_indices = changepoint_states == k + durations.append(state_durations[changepoint_indices]) + return durations + + def state_usage(self): + states = self.get_states() + return np.bincount(states, minlength=self.num_states) + + def state_switch(self): + states = self.get_states() + changepoints = states != np.hstack((states[1:], -1)) # 1 where state change occurs + changepoint_states = states[changepoints] # state label of changepoint + return changepoint_states diff --git a/util.py b/util.py new file mode 100644 index 0000000..494c000 --- /dev/null +++ b/util.py @@ -0,0 +1,32 @@ +import numpy as np +import numpy.random as npr + +def random_rotation(n, theta=None): # n: data dimension, theta: angle of rotation + if theta is None: + # Sample a random, slow rotation + theta = 0.5 * np.pi * np.random.rand() + if n == 1: + return np.random.rand() * np.eye(1) + rot = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + out = np.eye(n) + out[:2, :2] = rot + q = np.linalg.qr(np.random.randn(n, n))[0] + return q.dot(out).dot(q.T) + +def sum_tuples(a, b): + assert a or b + if a is None: + return b + elif b is None: + return a + else: + return tuple(ai + bi for ai, bi in zip(a, b)) + +def kron_A_N(A, N): # Simulates np.kron(A, np.eye(N)) + m,n = A.shape + out = np.zeros((m,N,n,N),dtype=A.dtype) + r = np.arange(N) + out[:,r,:,r] = A + out.shape = (m*N,n*N) + return out \ No newline at end of file diff --git a/warhmm_gp.py b/warhmm_gp.py new file mode 100644 index 0000000..7bab9c1 --- /dev/null +++ b/warhmm_gp.py @@ -0,0 +1,428 @@ +import numpy as np +import numpy.random as npr +import scipy.stats +from scipy import linalg as sclin +import torch +from tqdm.auto import trange +from torch.distributions import MultivariateNormal +import pickle +import os +from util import random_rotation, sum_tuples, kron_A_N +import wandb +import time +from numba import njit, prange +from twarhmm import TWARHMM, LinearRegressionObservations, Posterior + +device = torch.device('cpu') +dtype = torch.float64 +to_t = lambda array: torch.tensor(array, device=device, dtype=dtype) +from_t = lambda tensor: tensor.to("cpu").detach().numpy() +kernel_ridge = 1e-4 + + +class TWARHMM_GP(TWARHMM): + def __init__(self, config, taus, kernel): + super().__init__(config, taus) # config is a dictionary containing parameters + self.taus = taus + self.observations = LinearRegressionObservations_GP(self.num_discrete_states, self.data_dim, + self.covariates_dim, self.taus, kernel, + config["covariance_reg"], random_weights=False) + + def fit(self, train_dataset, test_dataset, seed=0, num_epochs=50, fit_observations=True, fit_transitions=True, fit_tau=False, fit_kernel_params=True): + # Fit using full batch EM + num_train = sum([len(data["data"]) for data in train_dataset]) + num_test = sum([len(data["data"]) for data in test_dataset]) + # Initialize with a random posterior + total_states = self.num_discrete_states*self.observations.num_taus + posteriors 
= [Posterior(self, data_dict, total_states) for data_dict in train_dataset] + for posterior in posteriors: + posterior.update() + continuous_expectations, discrete_expectations = self.compute_expected_suff_stats(train_dataset, posteriors, self.taus, fit_observations, fit_transitions) + train_lls = [] + test_lls = [] + + # Main loop + for itr in trange(num_epochs): + #print(itr) + self.M_step(continuous_expectations, discrete_expectations, fit_observations, fit_transitions, fit_tau, fit_kernel_params) + + for posterior in posteriors: + posterior.update() + + # Compute the expected sufficient statistics under the new posteriors + continuous_expectations, discrete_expectations = self.compute_expected_suff_stats(train_dataset, posteriors, self.taus, fit_observations, fit_transitions) + + # Store the average train likelihood + avg_train_ll = (sum([p.marginal_likelihood() for p in posteriors]) + self.observations.log_prior_likelihood().detach().numpy())/ num_train + train_lls.append(avg_train_ll) # TO DO: need to add prior log likelihood to overall objective function + + # Compute the posteriors for the test dataset too + test_posteriors = [Posterior(self,data_dict,total_states) for data_dict in test_dataset] + + for posterior in test_posteriors: + posterior.update() + + # Store the average test likelihood + avg_test_ll = (sum([p.marginal_likelihood() for p in test_posteriors]) ) / num_test + test_lls.append(avg_test_ll) + + # convert lls to arrays + train_lls = np.array(train_lls) + test_lls = np.array(test_lls) + return train_lls, test_lls, posteriors, test_posteriors + + def fit_stoch(self, train_dataset, test_dataset, forgetting_rate=-0.5, seed=0, num_epochs=5, fit_observations=True, + fit_transitions=True, fit_tau = True, compute_posteriors=True, fit_kernel_params=True, wandb_log=False): + # Get some constants + num_batches = len(train_dataset) + taus = np.array(self.taus) + num_test = sum([len(data["data"]) for data in test_dataset]) + total_states = self.num_discrete_states * len(self.taus) + num_train = sum([len(data["data"]) for data in train_dataset]) + + # Initialize the step size schedule + schedule = np.arange(1, 1 + num_batches * num_epochs) ** (forgetting_rate) + + # Initialize progress bars + outer_pbar = trange(num_epochs) + inner_pbar = trange(num_batches) + outer_pbar.set_description("Epoch") + inner_pbar.set_description("Batch") + + # Main loop + rng = npr.RandomState(seed) + train_lls = [] + test_lls = [] + + it_times = np.zeros((num_epochs,num_batches)) + + for epoch in range(num_epochs): + perm = rng.permutation(num_batches) + + inner_pbar.reset() + for itr in range(num_batches): + t = time.time() + minibatch = [train_dataset[perm[itr]]] + this_num_train = len(minibatch[0]["data"]) + + posteriors = [Posterior(self, data, total_states) for data in minibatch] + + # E step: on this minibatch + for posterior in posteriors: + posterior.update() + + if itr == 0 and epoch == 0: continuous_expectations, discrete_expectations = self.compute_expected_suff_stats( + minibatch, posteriors, taus, fit_observations, fit_transitions) + # M step: using current stats + self.M_step(continuous_expectations, discrete_expectations, fit_observations, fit_transitions, fit_tau, fit_kernel_params) + + these_continuous_expectations, these_discrete_expectations = self.compute_expected_suff_stats(minibatch, + posteriors, + taus, fit_observations, + fit_transitions) + rescale = lambda x: num_train / this_num_train * x + + # Rescale the statistics as if they came from the whole dataset + 
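+                # Stochastic EM: the running expectations are then blended as
+                # (1 - rho) * old + rho * rescaled, with rho = stepsize taken from
+                # `schedule`, i.e. decaying like t**forgetting_rate.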
rescaled_cont_stats = tuple(rescale(st) for st in these_continuous_expectations) + rescaled_disc_stats = tuple(rescale(st) for st in these_discrete_expectations) + + # Take a convex combination of the statistics using current step sz + stepsize = schedule[epoch * num_batches + itr] + continuous_expectations = tuple( + sum(x) for x in zip(tuple(st * (1 - stepsize) for st in continuous_expectations), + tuple(st * (stepsize) for st in rescaled_cont_stats))) + discrete_expectations = tuple( + sum(x) for x in zip(tuple(st * (1 - stepsize) for st in discrete_expectations), + tuple(st * (stepsize) for st in rescaled_disc_stats))) + + # Store the normalized log likelihood for this minibatch + avg_mll = (sum([p.marginal_likelihood() for p in posteriors])+ self.observations.log_prior_likelihood().detach().numpy()) / this_num_train + train_lls.append(avg_mll) + + elapsed = time.time()-t + #print(elapsed) + it_times[epoch,itr] = elapsed + inner_pbar.set_description("Batch LL: {:.3f}".format(avg_mll)) + inner_pbar.update() + if wandb_log: wandb.log({'batch_ll': avg_mll}) + + # Evaluate the likelihood and posteriors on the test dataset + if compute_posteriors: + test_posteriors = [Posterior(self, test_data, total_states, seed) for test_data in test_dataset] + for posterior in test_posteriors: + posterior.update() + avg_test_mll = (sum([p.marginal_likelihood() for p in test_posteriors])) / num_test + else: + mlls = [] + for test_data in test_dataset: + posterior = Posterior(self, test_data, total_states, seed) + posterior.update() + mlls.append(posterior.marginal_likelihood()) + avg_test_mll = np.sum(mlls)/ num_test + test_posteriors = None + test_lls.append(avg_test_mll) + outer_pbar.set_description("Test LL: {:.3f}".format(avg_test_mll)) + outer_pbar.update() + if wandb_log: wandb.log({'test_ll': avg_test_mll}) + + + # convert lls to arrays + train_lls = np.array(train_lls) + test_lls = np.array(test_lls) + + print('average iteration time: ', it_times.mean()) + return train_lls, test_lls, posteriors, test_posteriors + + def M_step(self, continuous_expectations, discrete_expectations, fit_observations, fit_transitions, fit_tau, + fit_kernel_params, hyper_M_iter=100): + if fit_transitions: self.transitions.M_step(discrete_expectations, fit_tau=fit_tau) + if fit_observations: self.observations.M_step(continuous_expectations) + if fit_kernel_params: self.observations.hyper_M_step(niter=hyper_M_iter, learning_rate=1e-6) + + def compute_expected_suff_stats(self, dataset, posteriors, taus, fit_observations=True, fit_transitions=False): + assert isinstance(dataset, list) + assert isinstance(posteriors, list) + + # Helper function to compute expected counts and sufficient statistics + # for a single time series and corresponding posterior. 
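+        # The helper returns two tuples. The continuous stats are weighted
+        # outer-product tensors with an extra tau axis:
+        #   xxTw [K,D,D,M], dxxTw [K,Dx,D,M], dxdxTw [K,Dx,Dx,M], wk [K]
+        # (K discrete states, M time constants, Dx data dim, D covariate dim);
+        # the discrete stats mirror the non-GP version.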
+ def _compute_expected_suff_stats(data, posterior, taus, fit_transitions): + dxdxT, dxxT, xxT = data['suff_stats_gp'] + (fancy_e_z, fancy_e_t) = posterior.expected_transitions() + + T,D,_ = xxT.shape + _,Dx,_ = dxdxT.shape + M = self.observations.num_taus + K = self.num_discrete_states + + w = posterior.expected_states().reshape((T,K,M)) + + # initializing, in case fit_observations or fit_transitions is false + fancy_e_z_over_T = np.zeros((self.num_discrete_states, self.num_discrete_states)) + fancy_e_t_over_T = np.zeros((len(self.taus), len(self.taus))) + q_one = np.zeros((self.num_discrete_states, len(self.taus))) + + xxTw = np.zeros((self.num_discrete_states, D, D, len(self.taus))) + dxxTw = np.zeros((self.num_discrete_states, Dx, D, len(self.taus))) + dxdxTw = np.zeros((self.num_discrete_states, Dx, Dx, len(self.taus))) + + if fit_observations: + + xxTw = np.einsum('tij, tkm -> kijm', xxT, w, optimize='optimal') # K x D x D x M + dxxTw = np.einsum('tij, tkm -> kijm', dxxT, w, optimize='optimal') + dxdxTw = np.einsum('tij, tkm -> kijm', dxdxT, w, optimize='optimal') + + wk = w.sum(axis=(0,2)) + + if fit_transitions: + fancy_e_z_over_T = np.einsum('tij->ij', fancy_e_z, optimize='optimal') + fancy_e_t_over_T = np.einsum('tij->ij', fancy_e_t, optimize='optimal') + + q_one = posterior.expected_states()[0] + + stats = (tuple((xxTw, dxxTw, dxdxTw, wk)), + tuple((fancy_e_z_over_T, fancy_e_t_over_T, q_one))) + + return stats + + # Sum the expected stats over the whole dataset + stats = (None,None) + for data, posterior in zip(dataset, posteriors): + these_stats = _compute_expected_suff_stats(data, posterior, taus, fit_transitions) + stats_cont = sum_tuples(stats[0], these_stats[0]) + stats_disc = sum_tuples(stats[1], these_stats[1]) + stats = (stats_cont, stats_disc) + return stats + + +class LinearRegressionObservations_GP(LinearRegressionObservations): + """ + Wrapper for a collection of Gaussian observation parameters. + """ + + def __init__(self, num_states, data_dim, covariate_dim, taus, kernel, covariance_reg, random_weights=True): + super().__init__(num_states, data_dim, covariate_dim, taus, covariance_reg, random_weights=True) + # self.priorCov = kernel(to_t(taus)) # covariance matrix num_discrete_states x ndim(tau) x ndim(tau) + self.num_taus = len(taus) + self.weight_gp_kernel = kernel + + # changing shape of weights to match KxDxDxM + if random_weights: + self.weights = np.zeros((num_states, data_dim, covariate_dim, self.num_taus)) + for k in range(num_states): + for m in range(self.num_taus): + self.weights[k, :, :data_dim, m] = scipy.linalg.logm(random_rotation(data_dim, theta=np.pi / 20)) + else: + self.weights = np.zeros((num_states, data_dim, covariate_dim, self.num_taus)) + + # adding in covs here to adjust initialization more easily + self.covs = .15 * np.tile(np.eye(data_dim), (num_states, 1, 1)) + + @staticmethod + def precompute_suff_stats(dataset): + """ + Compute the sufficient statistics of the linear regression for each + data dictionary in the dataset. This modifies the dataset in place. + + Parameters + ---------- + dataset: a list of data dictionaries. + + Returns + ------- + Nothing, but the dataset is updated in place to have a new `suff_stats_gp` + key, which contains a tuple of sufficient statistics. + """ + # TODO: diff or dx??? 
leaning towards diff based on scott's derivation + for data in dataset: + x = data['data'] # t = 2 : T + # diff = np.diff(x, axis=0) + phi = data['covariates'] # t = 1:T-1 + diff = x[1:] - x[:-1] # easier to read for now + # TODO: update to generalize for lags >1 + if x.shape[1] == phi.shape[1]: # no bias + dx = x - phi + else: + dx = x - phi[:, :-1] + data['suff_stats_gp'] = (np.einsum('ti,tj->tij', dx, dx), # dxn dxn.T + np.einsum('ti,tj->tij', dx, phi), # dxn xn-1.T + np.einsum('ti,tj->tij', phi, phi)) + + def M_step(self, continuous_expectations): + """ + Compute the linear regression parameters given the expected + sufficient statistics. + + Note: add a little bit (1e-4 * I) to the diagonal of each covariance + matrix to ensure that the result is positive definite. + + + Parameters + ---------- + stats: a tuple of expected sufficient statistics + + Returns + ------- + Nothing, but self.weights and self.covs are updated in place. + """ + # stats = tuple((dxxT_over_Etau,xxT_over_Etau)) + # H,wxxT = continuous_expectations # KxDxDxM, KxMxDxD + # w,xxT = continuous_expectations + + xxTw,dxxTw, dxdxTw, wk = continuous_expectations # LD: modified this to try tensor update + + D = self.covariate_dim + Dx = self.data_dim + K = self.num_states + M = len(self.taus) + Qinv = np.linalg.inv(self.covs) + + Ker = self.weight_gp_kernel(to_t(self.taus)).detach().numpy() + kernel_ridge*np.eye(self.num_taus)[None,:,:] # add small ridge for stability + Kinv = np.linalg.inv(Ker) + + # tensor version ... maybe slower but for debugging purposes + Ahat = np.zeros((K, Dx, D, M)) + Qhat = np.zeros((K, Dx, Dx)) + + + for k in range(K): + J1t = kron_A_N(Kinv[k,:,:], Dx*D) + J2t = sclin.block_diag(*np.kron(xxTw[k].transpose(2,0,1), Qinv[k])) + + Sigma = J1t + J2t + QinvdxxTw = np.einsum('pj, jlm -> plm', Qinv[k], dxxTw[k], optimize='optimal') # D x D x M + + mu = np.linalg.inv(Sigma) @ QinvdxxTw.flatten(order='F') # linear solve might be faster here + Ahat[k,:,:,:] = mu.reshape(Dx,D,M,order='F') # C style reordering would be faster but using column vector convention for now + + # update covariance + AxxTwAT = np.einsum('ijm, jlm, plm -> ip', Ahat[k], xxTw[k], Ahat[k], optimize='optimal') + dxxTATw = np.einsum('jlm, plm -> jp', dxxTw[k], Ahat[k], optimize='optimal') + + Qhat[k,:,:] = (AxxTwAT + dxdxTw[k].sum(axis=2) - dxxTATw - dxxTATw.T) / wk[k] + + # update stored parameters + self.weights = Ahat + self.covs = Qhat + + def log_likelihoods(self, data): + """ + Compute the matrix of log likelihoods of data for each state. + (I like to use torch.distributions for this, though it requires + converting back and forth between numpy arrays and pytorch tensors.) + Parameters + ---------- + data: a dictionary with multiple keys, including "data", the TxD array + of observations for this mouse. + Returns + ------- + log_likes: a TxK array of log likelihoods for each datapoint and + discrete state. 
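+        For this GP variant the likelihoods are evaluated per (discrete state k,
+        time constant m) pair and returned flattened to shape T x (K*M),
+        matching the product state space used by the message passing.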
+ """ + y = to_t(data["data"]) + x = data["covariates"] + + # T,_ = x.shape + # + # K,Dx,D,M = self.weights.shape + # + # means = np.zeros((T,K,M,Dx)) + # + # if self.weights.shape[1] == self.weights.shape[2]: + # eye_weights = self.weights + np.eye(self.weights.shape[1])[None, :, :, None] + # else: + # eye_weights = self.weights + np.column_stack( + # (np.eye(self.weights.shape[1]), np.zeros((self.weights.shape[1], 1))))[None, :, :, None] + # + # for k in range(K): + # for m in range(M): + # means[:,k,m,:] = np.einsum('ij, tj -> ti', eye_weights[k,:,:,m], x) + # + # means = to_t(means) + if self.weights.shape[1] == self.weights.shape[2]: + means = to_t(np.einsum('kijm, tj -> tkmi', self.weights + np.eye(self.weights.shape[1])[None,:,:,None], x, optimize='optimal')) + else: + means = to_t( + np.einsum('kijm, tj -> tkmi', self.weights + np.column_stack((np.eye(self.weights.shape[1]), np.zeros((self.weights.shape[1], 1))))[None,:,:,None], x, optimize='optimal')) + covs = to_t(self.covs) + + log_likes = torch.distributions.MultivariateNormal(means, covs[None, :, None, :, :], + validate_args=False).log_prob(y[:, None, None, :]) # gives TxKxM log likelihoods + T,K,M = log_likes.shape + return log_likes.reshape((T,K*M)).numpy() + + def log_prior_likelihood(self): + tau_grid_torch = to_t(self.taus) + + Kcov = self.weight_gp_kernel(tau_grid_torch) + kernel_ridge*torch.eye(self.num_taus)[None,:,:] # add small ridge for stability + Kinv = torch.inverse(Kcov) + + A_tensor = to_t(self.weights) # num_z_states x data_dim x data_dim x num_tau_states + # pdb.set_trace() + + # \sum_{ijk} -0.5 * a_ijk ' * inv(K_k) * a_ijk + Kia = torch.matmul(Kinv, A_tensor.permute(0,3,1,2).flatten(2,3)) # now assuming A is K x D x D x M + + quad_term = -0.5 * torch.sum(Kia * A_tensor.permute(0,3,1,2).flatten(2,3)) + + # \sum_k -0.5 * D^2 * log|K_k| + + log_det_term = -0.5 * self.data_dim**2 * torch.sum(Kcov.logdet()) + + return quad_term + log_det_term + + + def hyper_M_step(self, niter=100, learning_rate=1e-3): + # function to optimize hyperparameters inside kernel object + optimizer = torch.optim.SGD(self.weight_gp_kernel.parameters(), lr=learning_rate) + # optimizer = torch.optim.LBFGS(self.weight_gp_kernel.parameters(), lr=learning_rate, history_size=100, line_search_fn=None) + + def closure(): + optimizer.zero_grad() + loss = -self.log_prior_likelihood() + loss.backward() + return loss + + for i in range(niter): + optimizer.step(closure) + +
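+
+
+# Illustrative sketch (not part of the model code): how the factored (Pz, Pt)
+# transition matrix and the timescaled dynamics combine. All values below are
+# made up for demonstration; see TWARHMM.sample and
+# LinearRegressionObservations.timescale_weights_covs for the real usage.
+if __name__ == "__main__":
+    taus_demo = np.linspace(0.5, 2.0, 4)                 # M = 4 time constants
+    Pz_demo = np.array([[0.95, 0.05], [0.05, 0.95]])     # K = 2 discrete states
+    Pt_demo = 0.9 * np.eye(4) + 0.025 * np.ones((4, 4))  # sticky tau transitions
+
+    # Joint transition matrix over the K*M product state space, as in TWARHMM.sample
+    P_joint = np.kron(Pz_demo, Pt_demo)
+    assert np.allclose(P_joint.sum(axis=1), 1.0)
+
+    # Continuous-time weights scaled by each tau: K*M operators and covariances
+    weights_demo = np.zeros((2, 3, 3))                   # K x data_dim x covariate_dim
+    covs_demo = 0.05 * np.tile(np.eye(3), (2, 1, 1))
+    W, Q = LinearRegressionObservations.timescale_weights_covs(weights_demo, covs_demo, taus_demo)
+    print(P_joint.shape, W.shape, Q.shape)               # (8, 8) (8, 3, 3) (8, 3, 3)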