From 2c977b3e1b3bd9344cbc69bc62a02acfad00a07b Mon Sep 17 00:00:00 2001 From: joshuahwu Date: Tue, 25 Jul 2023 22:28:32 -0400 Subject: [PATCH] added tutorials, contributing.md, codeowner, data folder, results folder --- .github/CODEOWNERS | 1 + .gitignore | 7 +- CONTRIBUTING.md | 0 configs/param_configs/fitsne.yaml | 26 - configs/tutorial.yaml | 20 +- data/demo_meta.csv | 3 + environment.yml | 2 +- src/dappy/DataStruct.py | 5 +- src/dappy/analysis.py | 121 +++- src/dappy/preprocess.py | 3 +- src/dappy/read.py | 15 +- src/dappy/utils.py | 12 +- src/dappy/visualization.py | 9 +- .../run.py => tutorials/automated_pipeline.py | 2 +- tutorials/tutorial.ipynb | 515 ++++++++++++++---- 15 files changed, 554 insertions(+), 187 deletions(-) create mode 100644 .github/CODEOWNERS create mode 100644 CONTRIBUTING.md delete mode 100644 configs/param_configs/fitsne.yaml create mode 100644 data/demo_meta.csv rename src/dappy/run.py => tutorials/automated_pipeline.py (98%) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..bde2deb --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +# * @joshuahwu diff --git a/.gitignore b/.gitignore index ac1bcfd..999308f 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,6 @@ Desktop.ini # Recycle Bin used on file shares $RECYCLE.BIN/ -*.mat *.sh *.out *.p @@ -22,22 +21,28 @@ $RECYCLE.BIN/ *.txt *.mat +*.h5 *.mp4 *.yaml *.png *.pyc *./dappy/__pycache__/ +!/data/demo_meta.csv !/tutorials/ /tutorials/* !/tutorials/tutorial.ipynb +!/tutorials/automated_pipeline.py /configs/* +/configs/param_configs/ +/configs/param_configs/* !/configs/tutorial.yaml /dappy/wandb/ /dappy/models/ /dappy/artifcats/ +/results/* *.egg-info # Windows shortcuts diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..e69de29 diff --git a/configs/param_configs/fitsne.yaml b/configs/param_configs/fitsne.yaml deleted file mode 100644 index d5753b9..0000000 --- a/configs/param_configs/fitsne.yaml +++ /dev/null @@ 
-1,26 +0,0 @@ -label: 'fitsne' - -analysis: 'embed' - -downsample: 10 - -column: 'Condition' -density_by_column: ['label'] - -filter_still: False - -single_embed: - method: 'fitsne' - perplexity: 50 - lr: 'auto' - sigma: 15 - -transform_embed: - method: 'knn' - k: 5 - sigma: 15 - -skeleton_vids: False - -save_embedder: True -load_embedder: null \ No newline at end of file diff --git a/configs/tutorial.yaml b/configs/tutorial.yaml index 692acd5..7d0b834 100644 --- a/configs/tutorial.yaml +++ b/configs/tutorial.yaml @@ -1,36 +1,26 @@ # Folder path of data location -data_path: '/home/exx/Desktop/GitHub/CAPTURE_data/ensemble_healthy/' - -# File path of predictions file -pose_path: '/home/exx/Desktop/GitHub/CAPTURE_data/ensemble_healthy/predictions.mat' - -# File path of metadata file -meta_path: '/home/exx/Desktop/GitHub/CAPTURE_data/ensemble_healthy/metadata.csv' +data_path: '../data/' # Output folder of all plots -out_path: '/home/exx/Desktop/GitHub/results/ensemble_healthy/' - -# Path of list of behavior heuristics -heuristics_path: '/home/exx/Desktop/GitHub/dappy/src/dappy/behavior_heuristics.py' +out_path: '../results/tutorial/' # File path of skeletal specifications -skeleton_path: '/home/exx/Desktop/GitHub/dappy/src/dappy/skeletons.py' +skeleton_path: '../src/dappy/skeletons.py' # Key of skeleton in the skeletons file skeleton_name: 'mouse20_notail' label: 'fitsne' - -analysis: 'embed' - downsample: 10 +# Parameters for t-SNE embedding single_embed: method: 'fitsne' perplexity: 50 lr: 'auto' sigma: 15 +# Parameters for embedding new data into an existing t-SNE embedding transform_embed: method: 'knn' k: 5 diff --git a/data/demo_meta.csv b/data/demo_meta.csv new file mode 100644 index 0000000..b98a363 --- /dev/null +++ b/data/demo_meta.csv @@ -0,0 +1,3 @@ +AnimalID,Sex,Strain,Condition +A0,Male,Adora2a-Cre,Baseline +A1,Female,Adora2a-Cre,Baseline \ No newline at end of file diff --git a/environment.yml b/environment.yml index a65487e..5efe355 100644 --- 
a/environment.yml +++ b/environment.yml @@ -7,7 +7,7 @@ channels: dependencies: - python=3.8 - numpy - - faiss-gpu=1.6.5 + - faiss-gpu=1.7.1 - matplotlib - seaborn - hdf5storage diff --git a/src/dappy/DataStruct.py b/src/dappy/DataStruct.py index e631b91..21452f3 100644 --- a/src/dappy/DataStruct.py +++ b/src/dappy/DataStruct.py @@ -163,7 +163,7 @@ def __init__( joint_names: List[str], colors: Union[np.ndarray, List[Tuple[float, float, float, float]]], links: Union[np.ndarray, List[Tuple[int, int]]], - angles: Union[np.ndarray, List[Tuple[int, int, int]]], + angles: Optional[Union[np.ndarray, List[Tuple[int, int, int]]]] = None, ): """Initializes instance of Connectivity class @@ -183,7 +183,8 @@ def __init__( self.joint_names = joint_names self.colors = self._check_type(colors, np.float32) self.links = self._check_type(links, np.uint16) - self.angles = self._check_type(angles, np.uint16) + if angles is not None: + self.angles = self._check_type(angles, np.uint16) def _check_type( self, diff --git a/src/dappy/analysis.py b/src/dappy/analysis.py index 8a3b25a..36b5da4 100644 --- a/src/dappy/analysis.py +++ b/src/dappy/analysis.py @@ -12,9 +12,92 @@ from sklearn.ensemble import RandomForestRegressor import seaborn as sns from dappy.embed import Watershed - +import faiss +import time +from scipy.sparse import csr_matrix +from scipy.sparse.csgraph import dijkstra, minimum_spanning_tree from scipy.spatial import distance +def get_nn_graph(X: np.ndarray, k: int = 5, weighted: bool = True): + X = np.ascontiguousarray(X, dtype=np.float32) + + # max_k = 20 + print("Building NN Graph") + start_time = time.time() + index = faiss.IndexFlatL2(X.shape[1]) + index.add(X) + distances, indices = index.search(X, k=k+1) + distances, indices = distances[:, 1:], indices[:, 1:] + row = np.tile(np.arange(X.shape[0])[:, None], k) + + # min_distances, min_indices = distances[:, :k], indices[:,:k] + # min_row = row = np.tile(np.arange(X.shape[0])[:, None], k) + if weighted: + nn_graph = 
csr_matrix( + (distances.flatten(), (row.flatten(), indices.flatten())), + shape=(X.shape[0], X.shape[0]), + ) + + # min_graph = csr_matrix( + # (min_distances.flatten(), (min_row.flatten(), min_indices.flatten())), + # shape=(X.shape[0], X.shape[0]), + # ) + else: + nn_graph = csr_matrix( + (np.ones(distances.flatten().shape), (row.flatten(), indices.flatten())), + shape=(X.shape[0], X.shape[0]), + ) + # min_graph = csr_matrix( + # (np.ones(min_distances.flatten()), (min_row.flatten(), min_indices.flatten())), + # shape=(X.shape[0], X.shape[0]), + # ) + + print("NN Time: " + str(time.time() - start_time)) + + # # Get minimum spanning tree to ensure full connectivity in graph + # start_time = time.time() + # min_span_tree = minimum_spanning_tree(nn_graph) + # min_span_tree.data = min_span_tree.data.astype(X.dtype) + # print("Minimum Spanning Tree Time: " + str(time.time() - start_time)) + + # # Get union between minimum spanning tree and nn graph + # min_span_tree_insert = min_span_tree - nn_graph + # min_span_tree_insert.data = np.where(min_span_tree_insert.data < 0, 1, 0) + # graph = ( + # min_span_tree + # - min_span_tree.multiply(min_span_tree_insert) + # + nn_graph.multiply(min_span_tree_insert) + # ) + + return nn_graph + + +def get_pose_geodesic( + pose: np.ndarray, + graph: csr_matrix, + START_FRAME: int, + END_FRAME: int, +): + print("Calculating Dijkstra") + path_indices = dijkstra( + csgraph=graph, directed=False, indices=END_FRAME, return_predecessors=True + )[1] + + print("Finding pose geodesic") + geodesic_pose, geodesic_indices = [], [] + curr_frame = START_FRAME + + while path_indices[curr_frame] > 0: + geodesic_pose += [pose[curr_frame : curr_frame + 1, ...]] + geodesic_indices += [curr_frame] + curr_frame = path_indices[curr_frame] + + if curr_frame != END_FRAME: + print("Broken graph") + + geodesic_pose = np.concatenate(geodesic_pose, axis=0) + + return geodesic_pose, geodesic_indices def cluster_freq_from_data(data: np.ndarray, watershed: 
Watershed): """ @@ -63,12 +146,6 @@ def lstsq(freq: np.ndarray, y: np.ndarray, filepath: str): m = np.linalg.lstsq(np.delete(freq, i, axis=0), np.delete(y, i))[0] pred_y[i] = freq[i, :] @ m - plt.scatter(y, pred_y) - plt.xlabel("Real Fluorescence") - plt.ylabel("Predicted Fluorescence") - plt.savefig("".join([filepath, "lstsq.png"])) - plt.close() - print("R2 Score " + str(r2_score(y, pred_y))) return pred_y @@ -85,16 +162,16 @@ def elastic_net(freq: np.ndarray, y: np.ndarray, filepath: str): regr.fit(scaler.transform(temp_lesion), np.log2(np.delete(y, i))) pred_y[i] = regr.predict(scaler.transform(freq[i, :][None, :])) - sns.set(rc={'figure.figsize':(6,5)}) - f = plt.figure() - # import pdb; pdb.set_trace() - plt.plot(np.linspace(y.min(), y.max(), 100), np.linspace(y.min(),y.max(),100), markersize=0, color='k', label="y = x") - plt.legend(loc="upper center") - plt.scatter(y, 2**pred_y, s=30) - plt.xlabel("Real Fluorescence") - plt.ylabel("Predicted Fluorescence") - plt.savefig("".join([filepath, "elastic.png"])) - plt.close() + # sns.set(rc={'figure.figsize':(6,5)}) + # f = plt.figure() + # # import pdb; pdb.set_trace() + # plt.plot(np.linspace(y.min(), y.max(), 100), np.linspace(y.min(),y.max(),100), markersize=0, color='k', label="y = x") + # plt.legend(loc="upper center") + # plt.scatter(y, 2**pred_y, s=30) + # plt.xlabel("Real Fluorescence") + # plt.ylabel("Predicted Fluorescence") + # plt.savefig("".join([filepath, "elastic.png"])) + # plt.close() print("R2 Score " + str(r2_score(y, 2**pred_y))) return pred_y, @@ -148,11 +225,11 @@ def random_forest(freq: np.ndarray, y: np.ndarray, filepath: str): rf_regr.fit(np.delete(freq, i, axis=0), np.delete(y, i)) pred_y[i] = rf_regr.predict(freq[i, :][None, :]) - plt.scatter(y, pred_y) - plt.xlabel("Real Fluorescence") - plt.ylabel("Predicted Fluorescence") - plt.savefig("".join([filepath, "rforest.png"])) - plt.close() + # plt.scatter(y, pred_y) + # plt.xlabel("Real Fluorescence") + # plt.ylabel("Predicted 
Fluorescence") + # plt.savefig("".join([filepath, "rforest.png"])) + # plt.close() print("R2 Score " + str(r2_score(y, pred_y))) return diff --git a/src/dappy/preprocess.py b/src/dappy/preprocess.py index 414aa91..fbf0787 100644 --- a/src/dappy/preprocess.py +++ b/src/dappy/preprocess.py @@ -220,8 +220,7 @@ def anipose_med_filt( for _, i in enumerate(tqdm(np.unique(exp_id))): pose_exp = pose[exp_id == i, :, :] - # dxyz = get_frame_diff(pose_exp, time=1, idx_center=False) - # vel = + pose_error = pose_exp - scp_ndi.median_filter( pose_exp, (filter_len, 1, 1) ) # Median filter 5 frames repeat the ends of video diff --git a/src/dappy/read.py b/src/dappy/read.py index bb8ddb3..f744f21 100644 --- a/src/dappy/read.py +++ b/src/dappy/read.py @@ -96,7 +96,7 @@ def pose_mat( path: str, connectivity: Connectivity, dtype: Optional[Type[Union[np.float64, np.float32]]] = np.float32, -): +) -> np.ndarray: ## TODO: Use output docstrings """Reads 3D pose data from .mat files. @@ -183,6 +183,19 @@ def connectivity(path: str, skeleton_name: str): return connectivity +def connectivity_config(path: str): + skeleton_config = config(path) + + joint_names = skeleton_config["LABELS"] + colors = skeleton_config["COLORS"] + links = skeleton_config["SEGMENTS"] + + connectivity = Connectivity( + joint_names=joint_names, colors=colors, links=links + ) + + return connectivity + def features_h5( path, dtype: Optional[Type[Union[np.float64, np.float32]]] = np.float32 diff --git a/src/dappy/utils.py b/src/dappy/utils.py index 76bb1f0..481b597 100644 --- a/src/dappy/utils.py +++ b/src/dappy/utils.py @@ -3,15 +3,18 @@ import numpy as np from typing import Union, List + def by_id(func): @functools.wraps(func) - def wrapper(pose: np.ndarray, ids:Union[np.ndarray, List], **kwargs): + def wrapper(pose: np.ndarray, ids: Union[np.ndarray, List], **kwargs): for _, i in enumerate(tqdm(np.unique(ids))): - pose_exp = pose[ids == i,:,:] - pose[ids == i ,:,:] = func(pose_exp, **kwargs) + pose_exp = pose[ids == 
i, :, :] + pose[ids == i, :, :] = func(pose_exp, **kwargs) return pose + return wrapper + def rolling_window(data:np.ndarray, window:int): """ Returns a view of data windowed (data.shape, window) @@ -36,6 +39,7 @@ def rolling_window(data:np.ndarray, window:int): np.lib.stride_tricks.as_strided(d_pad, shape=shape, strides=strides), 0, 1 ) + def get_frame_diff(x: np.ndarray, time: int, idx_center: bool = True): """ IN: @@ -81,4 +85,4 @@ def standard_scale(features, clip=None): features = np.clip(features / feat_std[feat_std != 0], -clip, clip) labels = [label for i, label in enumerate(labels) if feat_std[i] != 0] - return features, labels \ No newline at end of file + return features, labels diff --git a/src/dappy/visualization.py b/src/dappy/visualization.py index 3b3fe6d..c72cb63 100644 --- a/src/dappy/visualization.py +++ b/src/dappy/visualization.py @@ -910,6 +910,9 @@ def pose3D_arena( VID_NAME: str = "0.mp4", SAVE_ROOT: str = "./test/pose_vids/", ): + if isinstance(frames, int): + frames = [frames] + pose_3d, limits, links, COLORS = _init_vid3D( pose, connectivity, np.array(frames,dtype=int), centered, N_FRAMES, SAVE_ROOT ) @@ -983,6 +986,8 @@ def pose3D_grid( VID_NAME: str = "0.mp4", SAVE_ROOT: str = "./test/pose_vids/", ): + if isinstance(frames, int): + frames = [frames] # Reshape pose and other variables pose_3d, limits, links, COLOR = _init_vid3D( pose, connectivity, np.array(frames,dtype=int), centered, N_FRAMES, SAVE_ROOT @@ -992,7 +997,7 @@ def pose3D_grid( writer = FFMpegWriter(fps=fps) # Set up figure cols = min(4, len(frames)) - rows = int(len(frames) / 4) + rows = int(len(frames) / 4) + 1 figsize = (cols * 5, rows * 5) fig = plt.figure(figsize=figsize) @@ -1041,6 +1046,8 @@ def pose3D_features( VID_NAME: str = "0.mp4", SAVE_ROOT: str = "./test/skeleton_vids/", ): + if isinstance(frames, int): + frames = [frames] # Reshape pose and other variables pose_3d, limits, links_expand, COLOR = _init_vid3D( pose, connectivity, frames, N_FRAMES, SAVE_ROOT diff 
--git a/src/dappy/run.py b/tutorials/automated_pipeline.py similarity index 98% rename from src/dappy/run.py rename to tutorials/automated_pipeline.py index 262a2f7..387a366 100644 --- a/src/dappy/run.py +++ b/tutorials/automated_pipeline.py @@ -5,7 +5,7 @@ from dappy.embed import Watershed, Embed from pathlib import Path - +#TODO: Probably be like a demo/notebook, don't maintain this def standard_features( pose, connectivity, diff --git a/tutorials/tutorial.ipynb b/tutorials/tutorial.ipynb index 98d3ab8..5d1f9ea 100644 --- a/tutorials/tutorial.ipynb +++ b/tutorials/tutorial.ipynb @@ -7,10 +7,10 @@ "source": [ "## Unsupervised Behavioral Phenotyping with 3D Skeletal Pose\n", "Joshua Wu\n", - "\n", + "Duke University Biomedical Engineering\n", "Timothy Dunn Lab\n", "\n", - "14 October 2022" + "25 July, 2023" ] }, { @@ -26,7 +26,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This notebook implements a Python version of [CAPTURE (Marshall, 2020)](https://www.cell.com/neuron/fulltext/S0896-6273(20)30894-1?_returnURL=https%3A%2F%2Flinkinghub.elsevier.com%2Fretrieve%2Fpii%2FS0896627320308941%3Fshowall%3Dtrue), which was based on earlier work [MotionMapper (Berman, 2014)](https://royalsocietypublishing.org/doi/full/10.1098/rsif.2014.0672) for the analysis of behavioral data, which can interface with future frameworks." + "This notebook implements a Python version of [CAPTURE (Marshall, 2020)](https://www.cell.com/neuron/fulltext/S0896-6273(20)30894-1?_returnURL=https%3A%2F%2Flinkinghub.elsevier.com%2Fretrieve%2Fpii%2FS0896627320308941%3Fshowall%3Dtrue), which was based on earlier work [MotionMapper (Berman, 2014)](https://royalsocietypublishing.org/doi/full/10.1098/rsif.2014.0672) for the analysis of behavioral data." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To get this notebook to run, please download the [demo dataset](https://duke.box.com/v/demo-mouse-poses) into the `/dappy/data/` directory." 
] }, { @@ -39,17 +46,17 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ - "from dappy import features, read, write\n", - "import dappy.DataStruct as ds\n", - "import dappy.visualization as vis\n", + "from dappy import read, write\n", + "from dappy import visualization as vis\n", "import numpy as np\n", "import time\n", "from IPython.display import Video\n", - "from dappy.embed import Watershed, Embed\n", + "from pathlib import Path\n", + "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, @@ -58,119 +65,299 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Load pose predictions and keypoint connectivity information" + "Load pose predictions, keypoint connectivity information, and metadata." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "analysis_key = \"tutorial\"\n", - "config = read.config(\"../../configs/\" + analysis_key + \".yaml\")\n", + "config = read.config(\"../configs/\" + analysis_key + \".yaml\")\n", "\n", - "pose, ids = read.pose_h5(config[\"data_path\"] + \"pose_aligned.h5\")\n", + "pose, ids = read.pose_h5(config[\"data_path\"] + \"demo_mouse.h5\")\n", "\n", "connectivity = read.connectivity(\n", " path=config[\"skeleton_path\"], skeleton_name=config[\"skeleton_name\"]\n", ")\n", "\n", - "meta, meta_by_frame = read.meta(config[\"meta_path\"], id=ids)" + "meta, meta_by_frame = read.meta(config[\"data_path\"] + \"demo_meta.csv\", id=ids)\n", + "\n", + "Path(config[\"out_path\"]).mkdir(parents=True, exist_ok=True)" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "Plot some skeletons together" + "`pose` shape (# frames x # keypoints x 3 coordinates)." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pose shape (# frames x # keypoints x 3 coordinates): \n", + "(648000, 18, 3)\n" + ] + } + ], + "source": [ + "print(\"Pose shape (# frames x # keypoints x 3 coordinates): \")\n", + "print(pose.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`meta` contains categorical information on recording sessions in `pose`. Here, we have loaded in two sessions. Each frame of the `pose` has a session id label in `ids`." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " id AnimalID Sex Strain Condition\n", + "0 0 A0 Male Adora2a-Cre Baseline\n", + "1 1 A1 Female Adora2a-Cre Baseline\n", + "\n", + "[0 0 0 ... 1 1 1]\n" + ] + } + ], + "source": [ + "print(meta)\n", + "print(\"\\n\" + str(ids))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`connectivity` contains key information indicating keypoint labels, connectivity, etc." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "keypoint labels\n", + "['Snout', 'EarR', 'EarL', 'SpineF', 'SpineM', 'Tail_base_', 'Forepaw_R', 'Wrist_R', 'ForeLimb_R', 'Forepaw_L', 'Wrist_L', 'Forelimb_L', 'Hindpaw_R', 'Ankel_R', 'Hindlimb_R', 'Hindpaw_L', 'Ankel_L', 'Hindlimb_L']\n", + "\n", + " Keypoint connections\n", + "[[ 0 1]\n", + " [ 1 3]\n", + " [ 0 2]\n", + " [ 2 3]\n", + " [ 2 1]\n", + " [ 0 3]\n", + " [ 4 3]\n", + " [ 5 4]\n", + " [ 6 7]\n", + " [ 7 8]\n", + " [ 8 3]\n", + " [ 9 10]\n", + " [10 11]\n", + " [11 3]\n", + " [12 13]\n", + " [13 14]\n", + " [14 5]\n", + " [15 16]\n", + " [16 17]\n", + " [17 5]]\n" + ] + } + ], + "source": [ + "print(\"keypoint labels\")\n", + "print(connectivity.joint_names)\n", + "print(\"\\n Keypoint connections\")\n", + "print(connectivity.links)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's plot 150 frames from each session." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/200 [00:00 1\u001b[0m vis\u001b[39m.\u001b[39;49mskeleton_vid3D(\n\u001b[1;32m 2\u001b[0m pose,\n\u001b[1;32m 3\u001b[0m connectivity,\n\u001b[1;32m 4\u001b[0m frames\u001b[39m=\u001b[39;49m[\u001b[39m1000\u001b[39;49m, \u001b[39m500000\u001b[39;49m, \u001b[39m200000\u001b[39;49m],\n\u001b[1;32m 5\u001b[0m N_FRAMES\u001b[39m=\u001b[39;49m\u001b[39m200\u001b[39;49m,\n\u001b[1;32m 6\u001b[0m dpi\u001b[39m=\u001b[39;49m\u001b[39m100\u001b[39;49m,\n\u001b[1;32m 7\u001b[0m VID_NAME\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mvid_raw.mp4\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 8\u001b[0m SAVE_ROOT\u001b[39m=\u001b[39;49mconfig[\u001b[39m\"\u001b[39;49m\u001b[39mout_path\u001b[39;49m\u001b[39m\"\u001b[39;49m],\n\u001b[1;32m 9\u001b[0m )\n\u001b[1;32m 11\u001b[0m Video(config[\u001b[39m\"\u001b[39m\u001b[39mout_path\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m+\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mvis_vid_raw.mp4\u001b[39m\u001b[39m\"\u001b[39m, width\u001b[39m=\u001b[39m\u001b[39m600\u001b[39m, height\u001b[39m=\u001b[39m\u001b[39m600\u001b[39m)\n", - "File \u001b[0;32m/hpc/group/tdunn/joshwu/dappy/src/dappy/visualization.py:891\u001b[0m, in \u001b[0;36mskeleton_vid3D\u001b[0;34m(pose, connectivity, frames, N_FRAMES, fps, dpi, VID_NAME, SAVE_ROOT)\u001b[0m\n\u001b[1;32m 886\u001b[0m \u001b[39mfor\u001b[39;00m color, (index_from, index_to) \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(COLOR, links_expand):\n\u001b[1;32m 887\u001b[0m xs, ys, zs \u001b[39m=\u001b[39m [\n\u001b[1;32m 888\u001b[0m np\u001b[39m.\u001b[39marray([kpts_3d[index_from, j], kpts_3d[index_to, j]])\n\u001b[1;32m 889\u001b[0m \u001b[39mfor\u001b[39;00m j \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m3\u001b[39m)\n\u001b[1;32m 890\u001b[0m ]\n\u001b[0;32m--> 891\u001b[0m 
ax_3d\u001b[39m.\u001b[39;49mplot3D(xs, ys, zs, c\u001b[39m=\u001b[39;49mcolor, lw\u001b[39m=\u001b[39;49m\u001b[39m2\u001b[39;49m)\n\u001b[1;32m 893\u001b[0m ax_3d\u001b[39m.\u001b[39mset_xlim(\u001b[39m*\u001b[39mlimits[\u001b[39m0\u001b[39m, :])\n\u001b[1;32m 894\u001b[0m ax_3d\u001b[39m.\u001b[39mset_ylim(\u001b[39m*\u001b[39mlimits[\u001b[39m1\u001b[39m, :])\n", - "File \u001b[0;32m/hpc/group/tdunn/joshwu/miniconda3/envs/capture/lib/python3.8/site-packages/mpl_toolkits/mplot3d/axes3d.py:1497\u001b[0m, in \u001b[0;36mAxes3D.plot\u001b[0;34m(self, xs, ys, zdir, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1494\u001b[0m \u001b[39m# Match length\u001b[39;00m\n\u001b[1;32m 1495\u001b[0m zs \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mbroadcast_to(zs, np\u001b[39m.\u001b[39mshape(xs))\n\u001b[0;32m-> 1497\u001b[0m lines \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49mplot(xs, ys, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1498\u001b[0m \u001b[39mfor\u001b[39;00m line \u001b[39min\u001b[39;00m lines:\n\u001b[1;32m 1499\u001b[0m art3d\u001b[39m.\u001b[39mline_2d_to_3d(line, zs\u001b[39m=\u001b[39mzs, zdir\u001b[39m=\u001b[39mzdir)\n", - "File \u001b[0;32m/hpc/group/tdunn/joshwu/miniconda3/envs/capture/lib/python3.8/site-packages/matplotlib/axes/_axes.py:1632\u001b[0m, in \u001b[0;36mAxes.plot\u001b[0;34m(self, scalex, scaley, data, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1390\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 1391\u001b[0m \u001b[39mPlot y versus x as lines and/or markers.\u001b[39;00m\n\u001b[1;32m 1392\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1629\u001b[0m \u001b[39m(``'green'``) or hex strings (``'#008000'``).\u001b[39;00m\n\u001b[1;32m 1630\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 1631\u001b[0m kwargs \u001b[39m=\u001b[39m cbook\u001b[39m.\u001b[39mnormalize_kwargs(kwargs, mlines\u001b[39m.\u001b[39mLine2D)\n\u001b[0;32m-> 1632\u001b[0m lines 
\u001b[39m=\u001b[39m [\u001b[39m*\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_get_lines(\u001b[39m*\u001b[39margs, data\u001b[39m=\u001b[39mdata, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)]\n\u001b[1;32m 1633\u001b[0m \u001b[39mfor\u001b[39;00m line \u001b[39min\u001b[39;00m lines:\n\u001b[1;32m 1634\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39madd_line(line)\n", - "File \u001b[0;32m/hpc/group/tdunn/joshwu/miniconda3/envs/capture/lib/python3.8/site-packages/matplotlib/axes/_base.py:312\u001b[0m, in \u001b[0;36m_process_plot_var_args.__call__\u001b[0;34m(self, data, *args, **kwargs)\u001b[0m\n\u001b[1;32m 310\u001b[0m this \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m args[\u001b[39m0\u001b[39m],\n\u001b[1;32m 311\u001b[0m args \u001b[39m=\u001b[39m args[\u001b[39m1\u001b[39m:]\n\u001b[0;32m--> 312\u001b[0m \u001b[39myield from\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_plot_args(this, kwargs)\n", - "File \u001b[0;32m/hpc/group/tdunn/joshwu/miniconda3/envs/capture/lib/python3.8/site-packages/matplotlib/axes/_base.py:538\u001b[0m, in \u001b[0;36m_process_plot_var_args._plot_args\u001b[0;34m(self, tup, kwargs, return_kwargs)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mlist\u001b[39m(result)\n\u001b[1;32m 537\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 538\u001b[0m \u001b[39mreturn\u001b[39;00m [l[\u001b[39m0\u001b[39m] \u001b[39mfor\u001b[39;00m l \u001b[39min\u001b[39;00m result]\n", - "File \u001b[0;32m/hpc/group/tdunn/joshwu/miniconda3/envs/capture/lib/python3.8/site-packages/matplotlib/axes/_base.py:538\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mlist\u001b[39m(result)\n\u001b[1;32m 537\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 538\u001b[0m \u001b[39mreturn\u001b[39;00m [l[\u001b[39m0\u001b[39m] \u001b[39mfor\u001b[39;00m l \u001b[39min\u001b[39;00m result]\n", - "File 
\u001b[0;32m/hpc/group/tdunn/joshwu/miniconda3/envs/capture/lib/python3.8/site-packages/matplotlib/axes/_base.py:531\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 528\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 529\u001b[0m labels \u001b[39m=\u001b[39m [label] \u001b[39m*\u001b[39m n_datasets\n\u001b[0;32m--> 531\u001b[0m result \u001b[39m=\u001b[39m (make_artist(x[:, j \u001b[39m%\u001b[39;49m ncx], y[:, j \u001b[39m%\u001b[39;49m ncy], kw,\n\u001b[1;32m 532\u001b[0m {\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs, \u001b[39m'\u001b[39;49m\u001b[39mlabel\u001b[39;49m\u001b[39m'\u001b[39;49m: label})\n\u001b[1;32m 533\u001b[0m \u001b[39mfor\u001b[39;00m j, label \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(labels))\n\u001b[1;32m 535\u001b[0m \u001b[39mif\u001b[39;00m return_kwargs:\n\u001b[1;32m 536\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mlist\u001b[39m(result)\n", - "File \u001b[0;32m/hpc/group/tdunn/joshwu/miniconda3/envs/capture/lib/python3.8/site-packages/matplotlib/axes/_base.py:351\u001b[0m, in \u001b[0;36m_process_plot_var_args._makeline\u001b[0;34m(self, x, y, kw, kwargs)\u001b[0m\n\u001b[1;32m 349\u001b[0m default_dict \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_getdefaults(\u001b[39mset\u001b[39m(), kw)\n\u001b[1;32m 350\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_setdefaults(default_dict, kw)\n\u001b[0;32m--> 351\u001b[0m seg \u001b[39m=\u001b[39m mlines\u001b[39m.\u001b[39;49mLine2D(x, y, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkw)\n\u001b[1;32m 352\u001b[0m \u001b[39mreturn\u001b[39;00m seg, kw\n", - "File \u001b[0;32m/hpc/group/tdunn/joshwu/miniconda3/envs/capture/lib/python3.8/site-packages/matplotlib/lines.py:370\u001b[0m, in \u001b[0;36mLine2D.__init__\u001b[0;34m(self, xdata, ydata, linewidth, linestyle, color, marker, markersize, markeredgewidth, markeredgecolor, markerfacecolor, markerfacecoloralt, fillstyle, antialiased, dash_capstyle, 
solid_capstyle, dash_joinstyle, solid_joinstyle, pickradius, drawstyle, markevery, **kwargs)\u001b[0m\n\u001b[1;32m 367\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mset_drawstyle(drawstyle)\n\u001b[1;32m 369\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_color \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m--> 370\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mset_color(color)\n\u001b[1;32m 371\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_marker \u001b[39m=\u001b[39m MarkerStyle(marker, fillstyle)\n\u001b[1;32m 373\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_markevery \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", - "File \u001b[0;32m/hpc/group/tdunn/joshwu/miniconda3/envs/capture/lib/python3.8/site-packages/matplotlib/lines.py:1030\u001b[0m, in \u001b[0;36mLine2D.set_color\u001b[0;34m(self, color)\u001b[0m\n\u001b[1;32m 1022\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mset_color\u001b[39m(\u001b[39mself\u001b[39m, color):\n\u001b[1;32m 1023\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 1024\u001b[0m \u001b[39m Set the color of the line.\u001b[39;00m\n\u001b[1;32m 1025\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1028\u001b[0m \u001b[39m color : color\u001b[39;00m\n\u001b[1;32m 1029\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1030\u001b[0m mcolors\u001b[39m.\u001b[39;49m_check_color_like(color\u001b[39m=\u001b[39;49mcolor)\n\u001b[1;32m 1031\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_color \u001b[39m=\u001b[39m color\n\u001b[1;32m 1032\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstale \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n", - "File \u001b[0;32m/hpc/group/tdunn/joshwu/miniconda3/envs/capture/lib/python3.8/site-packages/matplotlib/colors.py:130\u001b[0m, in \u001b[0;36m_check_color_like\u001b[0;34m(**kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m 
kwargs\u001b[39m.\u001b[39mitems():\n\u001b[1;32m 129\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m is_color_like(v):\n\u001b[0;32m--> 130\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mv\u001b[39m!r}\u001b[39;00m\u001b[39m is not a valid value for \u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n", - "\u001b[0;31mValueError\u001b[0m: array([3. , 1.5651001, 0. , 1.5 ], dtype=float32) is not a valid value for color" + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 150/150 [00:27<00:00, 5.43it/s]\n" ] }, { "data": { - "image/png": "", + "text/html": [ + "" + ], "text/plain": [ - "
" + "" ] }, + "execution_count": 18, "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], "source": [ - "vis.skeleton_vid3D(\n", + "vis.pose3D_arena(\n", " pose,\n", " connectivity,\n", - " frames=[1000, 500000, 200000],\n", - " N_FRAMES=200,\n", + " frames=[1000, 500000],\n", + " N_FRAMES=150,\n", " dpi=100,\n", - " VID_NAME=\"vid_raw.mp4\",\n", + " VID_NAME=\"raw.mp4\",\n", " SAVE_ROOT=config[\"out_path\"],\n", ")\n", "\n", - "Video(config[\"out_path\"] + \"vis_vid_raw.mp4\", width=600, height=600)" + "Video(config[\"out_path\"] + \"vis_raw.mp4\", width=600, height=600)" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "All other features will be egocentric so we will center and lock the front spine onto the x-z axis by rotation." + "Skeletons across sessions may not be aligned worldviews. The following code will estimate the floor plane for each session, and rotate to the x-y plane." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/2 [00:00\n", + " Your browser does not support the video element.\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "pose = features.rotate_spine(features.center_spine(pose))\n", + "from dappy import preprocess\n", "\n", - "vis.skeleton_vid3D(\n", - " pose,\n", + "pose_aligned = preprocess.align_floor_by_id(pose=pose, ids=ids, foot_id=12, head_id=0)\n", + "\n", + "vis.pose3D_arena(\n", + " pose_aligned,\n", " connectivity,\n", - " frames=[50000],\n", - " N_FRAMES=200,\n", + " frames=[1000, 500000],\n", + " N_FRAMES=150,\n", " dpi=100,\n", - " VID_NAME=\"vid_centered.mp4\",\n", + " VID_NAME=\"aligned.mp4\",\n", " SAVE_ROOT=config[\"out_path\"],\n", - ")" + ")\n", + "\n", + "Video(config[\"out_path\"] + \"vis_aligned.mp4\", 
width=600, height=600)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can use the following code to save the newly aligned poses for easy access later." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "from dappy import write\n", + "\n", + "# write.pose_h5(pose_aligned, ids, config[\"data_path\"] + \"pose_aligned.h5\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this analysis, we would like to prevent divergence of behavioral representations due to global position. Thus, we will generate an egocentric representation of pose for downstream feature calculation. \n", + "\n", + "Here, we center the mid-spine at $(0,0,0)$ and rotate the skeleton so that the front-spine points in the $+x$ direction." ] }, { @@ -179,7 +366,20 @@ "metadata": {}, "outputs": [], "source": [ - "Video(config[\"out_path\"] + \"vis_vid_centered.mp4\", width=600, height=600)" + "# Provide the mid-spine and the mid-spine -> front-spine indices.\n", + "pose = features.rotate_spine(features.center_spine(pose_aligned, keypt_idx=4), keypt_idx=[4, 3])\n", + "\n", + "vis.skeleton_vid3D(\n", + " pose,\n", + " connectivity,\n", + " frames=[50000],\n", + " N_FRAMES=150,\n", + " dpi=100,\n", + " VID_NAME=\"centered.mp4\",\n", + " SAVE_ROOT=config[\"out_path\"],\n", + ")\n", + "\n", + "Video(config[\"out_path\"] + \"vis_centered.mp4\", width=600, height=600)" ] }, { @@ -187,7 +387,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Using this centered and spine-locked pose transformation, we can calculate relative velocities of all keypoints. We leave out the mid spine since it is zeroed." + "In this package, we provide functionality for easily calculating features of interest. \n", + "\n", + "Using this centered and spine-locked pose transformation, we can calculate relative velocities of all keypoints. We leave out the mid spine since it is centered."
] }, { @@ -196,6 +398,8 @@ "metadata": {}, "outputs": [], "source": [ + "from dappy import features\n", + "\n", "# Getting relative velocities\n", "rel_vel, rel_vel_labels = features.get_velocities(\n", " pose,\n", @@ -211,9 +415,53 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Next, we calculate joint angles.\n", - "\n", - "Hopefully, informative joint angles are preselected in `skeletons.py`." + "You can also calculate joint angles of interest as specified in `skeletons.py`." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0 1 3]\n", + " [ 0 2 3]\n", + " [ 0 3 4]\n", + " [ 1 3 4]\n", + " [ 2 3 4]\n", + " [ 3 4 5]\n", + " [ 1 3 8]\n", + " [ 2 3 8]\n", + " [ 0 3 8]\n", + " [ 3 8 7]\n", + " [ 8 7 6]\n", + " [ 1 3 11]\n", + " [ 2 3 11]\n", + " [ 0 3 11]\n", + " [ 3 11 10]\n", + " [11 10 9]\n", + " [ 4 5 14]\n", + " [ 5 14 13]\n", + " [14 13 12]\n", + " [ 4 5 17]\n", + " [ 5 17 16]\n", + " [17 16 15]\n", + " [ 0 3 6]\n", + " [ 0 3 7]\n", + " [ 0 3 9]\n", + " [ 0 3 10]\n", + " [ 4 5 12]\n", + " [ 4 5 13]\n", + " [ 4 5 15]\n", + " [ 4 5 16]]\n" + ] + } + ], + "source": [ + "print(connectivity.angles)" ] }, { @@ -231,9 +479,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Finally, we are going to save the egocentric x, y, z coordinates as its own set of features\n", + "These velocity and angle calculations are just for demonstration, we will not use velocities or angles for the analysis in this tutorial.\n", "\n", - "This code does not calculate anything - it just reshapes the pose and generates labels for each feature." + "We will just rearrange egocentric x, y, z coordinates of each keypoint into its own set of features. This code does not calculate anything - it just reshapes the pose and generates labels for each feature." 
] }, { @@ -243,7 +491,10 @@ "outputs": [], "source": [ "# Reshape pose to get egocentric pose features\n", - "ego_pose, ego_pose_labels = features.get_ego_pose(pose, connectivity.joint_names)" + "features, labels = features.get_ego_pose(pose, connectivity.joint_names)\n", + "\n", + "# Clear some memory\n", + "del angles, rel_vel, angle_labels, rel_vel_labels" ] }, { @@ -251,7 +502,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Merge features together and clear some memory" + "Write features to or read features from an `.h5` file." ] }, { @@ -260,9 +511,19 @@ "metadata": {}, "outputs": [], "source": [ - "# Collect all features together\n", - "features = np.concatenate([ego_pose, angles], axis=1)\n", - "labels = ego_pose_labels + angle_labels" + "# Write\n", + "# write.features_h5(features, labels, path=config[\"out_path\"] + \"postural_feats.h5\")\n", + "\n", + "# Read\n", + "# features, labels = read.features_h5(path=config[\"out_path\"] + \"postural_feats.h5\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's now time for principal component analysis (PCA). PCA is a dimensionality reduction technique which generates orthogonal axes of high variance upon which to project our data. There are many implementations of PCA, but we will use Facebook's Fast Randomized PCA package (`fbpca`), which is significantly faster than most other implementations."
] }, { @@ -271,36 +532,52 @@ "metadata": {}, "outputs": [], "source": [ - "# Save or read kinematic/wavelet features from h5 file\n", - "write.features_h5(\n", - " features, labels, path=\"\".join([config[\"out_path\"], \"postural_feats.h5\"])\n", + "t = time.time()\n", + "pc_feats, pc_labels = features.pca(\n", + " features, labels, categories=[\"ego_euc\"], n_pcs=5, method=\"fbpca\"\n", ")\n", - "# features, labels = read_h5(path = ''.join([pstruct.out_path,'postural_feats.h5']))" + "print(\"PCA time: \" + str(time.time() - t))\n", + "\n", + "del features, labels" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "It's now time for principal component analysis (PCA). PCA is a dimensionality reduction technique which generates orthogonal axes of high variance upon which to project our data. There are many implementations of PCA, but we will use Facebook's Fast Randomized PCA package (`fbpca`), which is significantly faster than most other implementations.\n", + "Although velocities are calculated over rolling windows, the featurization we have so far still lacks the ability to capture complex temporal signals.\n", "\n", - "We calculate PCA separately on each feature category to preserve variance and balance the categories. This is in lieu of z-transforming (mean-centering and unit variance) every feature. ** Discussion" + "To address this, we can leverage the frequency domain through a Morlet wavelet transformation.\n", + "\n", + "Let's see first what a Morlet wavelet looks like." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "t = time.time()\n", - "pc_feats, pc_labels = features.pca(\n", - " features, labels, categories=[\"ego_euc\", \"ang\"], n_pcs=5, method=\"fbpca\"\n", - ")\n", - "print(\"PCA time: \" + str(time.time() - t))\n", - "\n", - "del features, labels" + "from scipy import signal\n", + "M = 100\n", + "w0 = 5\n", + "s = w0*90/(2*np.pi*25)\n", + "morlet_wavelet = signal.morlet2(M, s, w0)\n", + "plt.plot(morlet_wavelet.imag, label='Imaginary')\n", + "plt.plot(morlet_wavelet.real, label='Real')\n", + "plt.legend()\n", + "plt.show()" ] }, { @@ -319,7 +596,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We now use PCA to reduce the dimensions of the new wavelet features, and consolidate with previous PC scores." + "We now use PCA to reduce the dimensions of the new wavelet features, and consolidate with the previous PC scores. Each frame is now associated with a vector of features corresponding to the PC scores of egocentric keypoint coordinates and local frequency information." 
] }, { @@ -328,13 +605,15 @@ "metadata": {}, "outputs": [], "source": [ + "# PCA on wavelet features\n", "pc_wlet, pc_wlet_labels = features.pca(\n", " wlet_feats,\n", " wlet_labels,\n", - " categories=[\"wlet_ego_euc\", \"wlet_ang\"],\n", + " categories=[\"wlet_ego_euc\"],\n", " n_pcs=5,\n", " method=\"fbpca\",\n", ")\n", + "\n", "del wlet_feats, wlet_labels\n", "pc_feats = np.hstack((pc_feats, pc_wlet))\n", "pc_labels += pc_wlet_labels\n", @@ -347,9 +626,17 @@ "metadata": {}, "outputs": [], "source": [ - "write.features_h5(\n", - " pc_feats, pc_labels, path=\"\".join([config[\"out_path\"], \"pca_feats.h5\"])\n", - ")" + "# Optionally save full PC features to file\n", + "# write.features_h5(\n", + "# pc_feats, pc_labels, path=\"\".join([config[\"out_path\"], \"pca_feats.h5\"])\n", + "# )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We encapsulate all relevant data to store in a data object." ] }, { @@ -361,16 +648,24 @@ "data_obj = ds.DataStruct(\n", " pose=pose,\n", " id=ids,\n", - " id_full=ids,\n", " meta=meta,\n", " meta_by_frame=meta_by_frame,\n", " connectivity=connectivity,\n", ")\n", "\n", "data_obj.features = pc_feats\n", + "# Downsampling data, appears to be necessary in order to \n", + "# discover granular structure in embedding\n", "data_obj = data_obj[:: config[\"downsample\"], :]" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using t-SNE, frames are projected onto a 2D embedding for clustering and visualization." 
+ ] + }, { "cell_type": "code", "execution_count": null, @@ -387,16 +682,10 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "# Watershed clustering\n", - "data_obj.ws = Watershed(\n", - " sigma=config[\"single_embed\"][\"sigma\"], max_clip=1, log_out=True, pad_factor=0.05\n", - ")\n", - "data_obj.data.loc[:, \"Cluster\"] = data_obj.ws.fit_predict(data=data_obj.embed_vals)" + "The histogram of the 2D embedding is smoothed with a Gaussian, and segmented by the watershed algorithm to determine cluster assignments." ] }, { @@ -405,14 +694,28 @@ "metadata": {}, "outputs": [], "source": [ + "# Watershed clustering\n", + "data_obj.ws = Watershed(\n", + " sigma=config[\"single_embed\"][\"sigma\"], max_clip=1, log_out=True, pad_factor=0.05\n", + ")\n", + "data_obj.data.loc[:, \"Cluster\"] = data_obj.ws.fit_predict(data=data_obj.embed_vals)\n", + "\n", + "# Plot density\n", "vis.density(\n", " data_obj.ws.density,\n", " data_obj.ws.borders,\n", - " filepath=\"\".join([config[\"out_path\"], config[\"label\"], \"/density.png\"]),\n", + " filepath=config[\"out_path\"] + \"/density.png\",\n", " show=True,\n", ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Within the embedding, we can visualize the density of each animal separately." 
+ ] + }, { "cell_type": "code", "execution_count": null, @@ -421,23 +724,13 @@ "source": [ "vis.density_cat(\n", " data=data_obj,\n", - " column='id',\n", + " column=\"id\",\n", " watershed=data_obj.ws,\n", " n_col=4,\n", - " filepath=\"\".join(\n", - " [config[\"out_path\"], config[\"label\"], \"/density_id.png\"]\n", - " ),\n", - " show=True\n", + " filepath=config[\"out_path\"] + \"/density_id.png\",\n", + " show=True,\n", ")" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Time to run the rest of the analysis" - ] } ], "metadata": { @@ -456,7 +749,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.12" }, "orig_nbformat": 4, "vscode": {