Skip to content

Commit

Permalink
added tutorials, contributing.md, codeowner, data folder, results folder
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuahwu committed Jul 26, 2023
1 parent 9a25d96 commit 2c977b3
Show file tree
Hide file tree
Showing 15 changed files with 554 additions and 187 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# * @joshuahwu
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ Desktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/

*.mat
*.sh
*.out
*.p
Expand All @@ -22,22 +21,28 @@ $RECYCLE.BIN/
*.txt
*.mat

*.h5
*.mp4
*.yaml
*.png
*.pyc
*./dappy/__pycache__/

!/data/demo_meta.csv
!/tutorials/
/tutorials/*
!/tutorials/tutorial.ipynb
!/tutorials/automated_pipeline.py
/configs/*
/configs/param_configs/
/configs/param_configs/*
!/configs/tutorial.yaml

/dappy/wandb/
/dappy/models/
/dappy/artifacts/

/results/*
*.egg-info

# Windows shortcuts
Expand Down
Empty file added CONTRIBUTING.md
Empty file.
26 changes: 0 additions & 26 deletions configs/param_configs/fitsne.yaml

This file was deleted.

20 changes: 5 additions & 15 deletions configs/tutorial.yaml
Original file line number Diff line number Diff line change
@@ -1,36 +1,26 @@
# Folder path of data location
data_path: '/home/exx/Desktop/GitHub/CAPTURE_data/ensemble_healthy/'

# File path of predictions file
pose_path: '/home/exx/Desktop/GitHub/CAPTURE_data/ensemble_healthy/predictions.mat'

# File path of metadata file
meta_path: '/home/exx/Desktop/GitHub/CAPTURE_data/ensemble_healthy/metadata.csv'
data_path: '../data/'

# Output folder of all plots
out_path: '/home/exx/Desktop/GitHub/results/ensemble_healthy/'

# Path of list of behavior heuristics
heuristics_path: '/home/exx/Desktop/GitHub/dappy/src/dappy/behavior_heuristics.py'
out_path: '../results/tutorial/'

# File path of skeletal specifications
skeleton_path: '/home/exx/Desktop/GitHub/dappy/src/dappy/skeletons.py'
skeleton_path: '../src/dappy/skeletons.py'

# Key of skeleton in the skeletons file
skeleton_name: 'mouse20_notail'

label: 'fitsne'

analysis: 'embed'

downsample: 10

# Parameters for t-SNE embedding
single_embed:
method: 'fitsne'
perplexity: 50
lr: 'auto'
sigma: 15

# Parameters for embedding new data into an existing t-SNE embedding
transform_embed:
method: 'knn'
k: 5
Expand Down
3 changes: 3 additions & 0 deletions data/demo_meta.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
AnimalID,Sex,Strain,Condition
A0,Male,Adora2a-Cre,Baseline
A1,Female,Adora2a-Cre,Baseline
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ channels:
dependencies:
- python=3.8
- numpy
- faiss-gpu=1.6.5
- faiss-gpu=1.7.1
- matplotlib
- seaborn
- hdf5storage
Expand Down
5 changes: 3 additions & 2 deletions src/dappy/DataStruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def __init__(
joint_names: List[str],
colors: Union[np.ndarray, List[Tuple[float, float, float, float]]],
links: Union[np.ndarray, List[Tuple[int, int]]],
angles: Union[np.ndarray, List[Tuple[int, int, int]]],
angles: Optional[Union[np.ndarray, List[Tuple[int, int, int]]]] = None,
):
"""Initializes instance of Connectivity class
Expand All @@ -183,7 +183,8 @@ def __init__(
self.joint_names = joint_names
self.colors = self._check_type(colors, np.float32)
self.links = self._check_type(links, np.uint16)
self.angles = self._check_type(angles, np.uint16)
if angles is not None:
self.angles = self._check_type(angles, np.uint16)

def _check_type(
self,
Expand Down
121 changes: 99 additions & 22 deletions src/dappy/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,92 @@
from sklearn.ensemble import RandomForestRegressor
import seaborn as sns
from dappy.embed import Watershed

import faiss
import time
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import dijkstra, minimum_spanning_tree
from scipy.spatial import distance

def get_nn_graph(X: np.ndarray, k: int = 5, weighted: bool = True):
    """Build a sparse k-nearest-neighbour graph over the rows of ``X``.

    An exact (flat) faiss L2 index is built from ``X`` and queried with ``X``
    itself. ``k + 1`` hits are requested per row and the first hit of every
    row is discarded (it is expected to be the query point itself).

    Args:
        X: 2-D array with one sample per row; converted to a contiguous
            float32 array before indexing.
        k: Number of neighbours kept per sample.
        weighted: If True, each edge stores the distance reported by faiss;
            otherwise every edge stores 1.

    Returns:
        A ``csr_matrix`` of shape ``(n_samples, n_samples)`` whose row ``i``
        holds the edges from sample ``i`` to its ``k`` neighbours.
    """
    points = np.ascontiguousarray(X, dtype=np.float32)

    print("Building NN Graph")
    start_time = time.time()

    searcher = faiss.IndexFlatL2(points.shape[1])
    searcher.add(points)
    hit_dists, hit_ids = searcher.search(points, k=k + 1)

    # Drop the first column: the nearest hit of each query is the query itself.
    # NOTE(review): faiss documents IndexFlatL2 as returning *squared* L2
    # distances, so weighted edges are presumably squared -- confirm intended.
    neighbor_dists = hit_dists[:, 1:].flatten()
    neighbor_ids = hit_ids[:, 1:].flatten()

    # Source index of every edge: each sample id repeated once per neighbour.
    sources = np.repeat(np.arange(points.shape[0]), k)

    if weighted:
        edge_values = neighbor_dists
    else:
        edge_values = np.ones(neighbor_dists.shape)

    nn_graph = csr_matrix(
        (edge_values, (sources, neighbor_ids)),
        shape=(points.shape[0], points.shape[0]),
    )

    print("NN Time: " + str(time.time() - start_time))

    return nn_graph


def get_pose_geodesic(
    pose: np.ndarray,
    graph: csr_matrix,
    START_FRAME: int,
    END_FRAME: int,
):
    """Trace the shortest graph path from ``START_FRAME`` towards ``END_FRAME``.

    Dijkstra is run once with ``END_FRAME`` as the source on the undirected
    ``graph``; the predecessor array is then followed from ``START_FRAME``
    until a frame without predecessor is reached.

    Args:
        pose: Array indexed by frame along its first axis.
        graph: Sparse adjacency matrix over frames (edge weights = distances).
        START_FRAME: Frame index at which the walk begins.
        END_FRAME: Frame index used as the Dijkstra source (walk target).

    Returns:
        Tuple ``(geodesic_pose, geodesic_indices)``: the poses of the visited
        frames stacked along axis 0 and the list of their frame indices, in
        walk order. ``END_FRAME`` itself is not included. If no frame is
        visited (e.g. ``START_FRAME == END_FRAME``), an empty slice of
        ``pose`` and an empty list are returned.

    Prints "Broken graph" when the walk stops somewhere other than
    ``END_FRAME`` (i.e. ``START_FRAME`` cannot reach it).
    """
    print("Calculating Dijkstra")
    _, predecessors = dijkstra(
        csgraph=graph, directed=False, indices=END_FRAME, return_predecessors=True
    )

    print("Finding pose geodesic")
    geodesic_pose, geodesic_indices = [], []
    curr_frame = START_FRAME

    # SciPy marks "no predecessor" with a negative sentinel (-9999). Frame 0 is
    # a perfectly valid predecessor, so the test must be ``>= 0``, not ``> 0``.
    while predecessors[curr_frame] >= 0:
        geodesic_pose.append(pose[curr_frame : curr_frame + 1, ...])
        geodesic_indices.append(curr_frame)
        curr_frame = int(predecessors[curr_frame])

    if curr_frame != END_FRAME:
        print("Broken graph")

    if not geodesic_pose:
        # np.concatenate rejects an empty list; return an empty slice that
        # keeps the trailing dimensions of ``pose`` instead.
        return pose[:0, ...], geodesic_indices

    return np.concatenate(geodesic_pose, axis=0), geodesic_indices

def cluster_freq_from_data(data: np.ndarray, watershed: Watershed):
"""
Expand Down Expand Up @@ -63,12 +146,6 @@ def lstsq(freq: np.ndarray, y: np.ndarray, filepath: str):
m = np.linalg.lstsq(np.delete(freq, i, axis=0), np.delete(y, i))[0]
pred_y[i] = freq[i, :] @ m

plt.scatter(y, pred_y)
plt.xlabel("Real Fluorescence")
plt.ylabel("Predicted Fluorescence")
plt.savefig("".join([filepath, "lstsq.png"]))
plt.close()

print("R2 Score " + str(r2_score(y, pred_y)))
return pred_y

Expand All @@ -85,16 +162,16 @@ def elastic_net(freq: np.ndarray, y: np.ndarray, filepath: str):
regr.fit(scaler.transform(temp_lesion), np.log2(np.delete(y, i)))
pred_y[i] = regr.predict(scaler.transform(freq[i, :][None, :]))

sns.set(rc={'figure.figsize':(6,5)})
f = plt.figure()
# import pdb; pdb.set_trace()
plt.plot(np.linspace(y.min(), y.max(), 100), np.linspace(y.min(),y.max(),100), markersize=0, color='k', label="y = x")
plt.legend(loc="upper center")
plt.scatter(y, 2**pred_y, s=30)
plt.xlabel("Real Fluorescence")
plt.ylabel("Predicted Fluorescence")
plt.savefig("".join([filepath, "elastic.png"]))
plt.close()
# sns.set(rc={'figure.figsize':(6,5)})
# f = plt.figure()
# # import pdb; pdb.set_trace()
# plt.plot(np.linspace(y.min(), y.max(), 100), np.linspace(y.min(),y.max(),100), markersize=0, color='k', label="y = x")
# plt.legend(loc="upper center")
# plt.scatter(y, 2**pred_y, s=30)
# plt.xlabel("Real Fluorescence")
# plt.ylabel("Predicted Fluorescence")
# plt.savefig("".join([filepath, "elastic.png"]))
# plt.close()

print("R2 Score " + str(r2_score(y, 2**pred_y)))
return pred_y,
Expand Down Expand Up @@ -148,11 +225,11 @@ def random_forest(freq: np.ndarray, y: np.ndarray, filepath: str):
rf_regr.fit(np.delete(freq, i, axis=0), np.delete(y, i))
pred_y[i] = rf_regr.predict(freq[i, :][None, :])

plt.scatter(y, pred_y)
plt.xlabel("Real Fluorescence")
plt.ylabel("Predicted Fluorescence")
plt.savefig("".join([filepath, "rforest.png"]))
plt.close()
# plt.scatter(y, pred_y)
# plt.xlabel("Real Fluorescence")
# plt.ylabel("Predicted Fluorescence")
# plt.savefig("".join([filepath, "rforest.png"]))
# plt.close()
print("R2 Score " + str(r2_score(y, pred_y)))
return

Expand Down
3 changes: 1 addition & 2 deletions src/dappy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,7 @@ def anipose_med_filt(

for _, i in enumerate(tqdm(np.unique(exp_id))):
pose_exp = pose[exp_id == i, :, :]
# dxyz = get_frame_diff(pose_exp, time=1, idx_center=False)
# vel =

pose_error = pose_exp - scp_ndi.median_filter(
pose_exp, (filter_len, 1, 1)
) # Median filter 5 frames repeat the ends of video
Expand Down
15 changes: 14 additions & 1 deletion src/dappy/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def pose_mat(
path: str,
connectivity: Connectivity,
dtype: Optional[Type[Union[np.float64, np.float32]]] = np.float32,
):
) -> np.ndarray: ## TODO: Use output docstrings
"""Reads 3D pose data from .mat files.
Expand Down Expand Up @@ -183,6 +183,19 @@ def connectivity(path: str, skeleton_name: str):

return connectivity

def connectivity_config(path: str):
    """Build a ``Connectivity`` from the skeleton config found at ``path``.

    The file is loaded with ``config(path)`` and its ``"LABELS"``,
    ``"COLORS"`` and ``"SEGMENTS"`` entries are passed on as joint names,
    colors and links respectively. No ``angles`` are supplied.

    Args:
        path: Location of the skeleton config file.

    Returns:
        The constructed ``Connectivity`` instance.
    """
    skeleton = config(path)

    return Connectivity(
        joint_names=skeleton["LABELS"],
        colors=skeleton["COLORS"],
        links=skeleton["SEGMENTS"],
    )


def features_h5(
path, dtype: Optional[Type[Union[np.float64, np.float32]]] = np.float32
Expand Down
12 changes: 8 additions & 4 deletions src/dappy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
import numpy as np
from typing import Union, List


def by_id(func):
    """Decorator that applies ``func`` separately to each id group of ``pose``.

    The wrapped function is called once per unique value in ``ids`` with the
    frames of ``pose`` belonging to that id (plus any keyword arguments), and
    its result is written back into those same frames.

    Note:
        ``pose`` is modified in place and also returned.
    """

    @functools.wraps(func)
    def wrapper(pose: np.ndarray, ids: Union[np.ndarray, List], **kwargs):
        # A plain list compared with ``==`` yields a single bool rather than a
        # per-frame mask, so normalise to an array first.
        ids = np.asarray(ids)
        for i in tqdm(np.unique(ids)):
            mask = ids == i
            pose[mask, :, :] = func(pose[mask, :, :], **kwargs)
        return pose

    return wrapper


def rolling_window(data:np.ndarray, window:int):
"""
Returns a view of data windowed (data.shape, window)
Expand All @@ -36,6 +39,7 @@ def rolling_window(data:np.ndarray, window:int):
np.lib.stride_tricks.as_strided(d_pad, shape=shape, strides=strides), 0, 1
)


def get_frame_diff(x: np.ndarray, time: int, idx_center: bool = True):
"""
IN:
Expand Down Expand Up @@ -81,4 +85,4 @@ def standard_scale(features, clip=None):
features = np.clip(features / feat_std[feat_std != 0], -clip, clip)
labels = [label for i, label in enumerate(labels) if feat_std[i] != 0]

return features, labels
return features, labels
Loading

0 comments on commit 2c977b3

Please sign in to comment.