diff --git a/README.md b/README.md index 6aba357..292c378 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ An example notebook demoing how to transform 2D keypoints to a new coordinate sy The data is a short clip of tracked seabirds flying around a boat. The notebook shows how to express the data in a coordinate system aligned with the boat. -## Usage +## Installation Create a conda environment and install the latest version of `movement`: @@ -23,7 +23,7 @@ Install additional dependencies for Jupyter Notebooks: pip install jupyter jupyterlab ipympl ``` -If you wish to use the `movement`GUI, which additionally requires [napari](napari:), +If you wish to use the `movement` GUI, which additionally requires [napari](napari:), you should replace the first command with: ```sh conda create -n movement-env -c conda-forge movement napari pyqt @@ -44,6 +44,41 @@ jupyter lab notebook_seabirds.ipynb Alternatively, you can run the notebook in VS Code using the [Jupyter extension](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter). -## Next steps -- Scale distances using boat width -- Define ROIs \ No newline at end of file +## Data + +## Postprocessing steps + +The DLC-predicted trajectories are expressed in the camera drone's coordinate system. Before transforming them to the boat's coordinate system, we applied the following postprocessing steps to the raw DLC data: + +1. **Manual correction of ID errors** using a custom napari GUI. + +2. **Trajectory splitting based on gaps and jumps.** +DLC assigns one of five predefined IDs to the birds visible in each frame. However, if a bird leaves and later re-enters the frame, DLC may reuse the same ID for a different bird. To prevent this, we reset an individual's ID whenever there is a sufficiently large temporal gap in its trajectory and the position after the gap is far from the last valid point before it. + +3. 
**Filtering out low-confidence points and linear interpolation.** +We discard points with a DLC confidence score below 0.6, then linearly interpolate across the resulting gaps. + +These steps are collected in the notebook [`demo_notebook.ipynb`](https://github.com/anna-teruel/seabirds-coord-transform/blob/main/demo_notebook.ipynb). + + +## Boat coordinate system + +The notebook [`notebook_boat_coord_system.py`](notebook_boat_coord_system.py) transforms the cleaned DLC bird trajectories from the image coordinate system (ICS) to a boat coordinate system (BCS). The steps are: + +1. **Separate bird and boat data** from the input DLC file into two `movement` datasets. + +2. **Filter and interpolate boat keypoints**: low-confidence detections (below 0.5) are set to NaN and linearly interpolated over time. + +3. **Define the boat coordinate system (BCS) per frame**: + - Origin: mean of all boat keypoints per frame. + - Y-axis: unit vector from the boat centroid to the boat tip. + - X-axis: perpendicular to the y-axis, pointing to the right side of the boat. + - Z-axis: cross product of the x and y axes; it points out of the image plane. + +4. **Express bird and boat trajectories in BCS** using the per-frame rotation matrix derived from the BCS axes definition. + +5. **Scale trajectories to meters** using the boat length per frame as reference (vessel length: 8.55 m). + +6. **Interpolate and smooth bird centroid trajectories** using monotone cubic interpolation (PCHIP) followed by a rolling median filter (window = 15 frames). + +7. **Export** the transformed data as `.nc` files (loadable in napari), `.csv` files, and the plot of the raw (unsmoothed) data as a plotly `.html` file. diff --git a/notebook_boat_coord_system.py b/notebook_boat_coord_system.py index 0d10081..612d6f7 100644 --- a/notebook_boat_coord_system.py +++ b/notebook_boat_coord_system.py @@ -1,26 +1,32 @@ -"""A notebook to express DLC trajectories from birds in a boat coordinate system. 
+"""A notebook to express cleaned DLC bird trajectories in a boat coordinate system. Requirements: following installation instructions for `movement` https://movement.neuroinformatics.dev/latest/user_guide/installation.html +Also install: plotly + Then run this notebook in that conda environment. """ # %% -import glob from pathlib import Path import matplotlib.pyplot as plt import numpy as np import pandas as pd +import plotly.graph_objects as go import xarray as xr - -from movement.filtering import filter_by_confidence, interpolate_over_time +from movement.filtering import ( + filter_by_confidence, + interpolate_over_time, + rolling_filter, + savgol_filter, +) from movement.io import load_poses, save_poses from movement.kinematics import compute_pairwise_distances -from movement.utils.vector import compute_norm - +from movement.plots import plot_occupancy +from movement.utils.vector import compute_norm, convert_to_unit from scipy.spatial.transform import Rotation as R # Hide attributes globally @@ -37,10 +43,10 @@ data_dir = notebook_path.parent / "data" filepath = ( data_dir - / "second-iter" - / "FILE00009_sDLC_DekrW32_seabirdNov6shuffle1_snapshot_170_el_filtered.h5" + / "trayectorias_AT" + / "FILE00009_sDLC_DekrW32_seabirdNov6shuffle1_snapshot_170_el_filtered_split_interpolated.h5" ) -output_dir = notebook_path.parent / "output" +output_dir = notebook_path.parent / "output" output_dir.mkdir(parents=True, exist_ok=True) # Vessel size: 8.55 x 2.95 m @@ -50,6 +56,7 @@ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% # Helper functions + def get_data_for_load_from_numpy(df): """Get array from dataframe to use "from numpy" function""" list_individuals = sorted(df.columns.get_level_values("individuals").unique()) @@ -85,17 +92,6 @@ def get_data_for_load_from_numpy(df): return position_array, confidence_array, list_individuals, list_keypoints -def compute_rotation_to_align_y_axis(vec): - """Compute rotation to align y-axis""" - rrot, _rssd = R.align_vectors( - np.array([[0, 
1, 0]]), # Vector components observed in initial frame A - vec, # Vector components observed in another frame B - return_sensitivity=False, - ) - - return rrot - - def add_z_coord_to_position_array(position_array): """Add z coordinate to position array""" return xr.concat( @@ -110,6 +106,43 @@ def add_z_coord_to_position_array(position_array): ) + +def export_dataarray_as_csv(da_position, output_path): + """Export as a tidy CSV file with x and y in separate columns.""" + df = da_position.to_dataframe().reset_index() + + # drop rows with NaN positions + df = df.dropna(subset=["position"]) + + # Pivot space to get x and y as separate columns + columns_to_keep = [idx for idx in df.columns if idx not in ["space", "position"]] + df_wide = df.pivot( + index=columns_to_keep, + columns="space", + values="position", + ).reset_index() + + # Flatten column names + df_wide.columns.name = None + + # Export to CSV + df_wide.to_csv(output_path, index=False) + + return output_path + + +def export_as_ds(da_position, da_confidence, output_path): + """Export a poses dataset with the given position and confidence arrays.""" + ds = xr.Dataset( + { + "position": da_position, + "confidence": da_confidence, + # xr.full_like(da_position.isel(space=0, drop=True), np.nan), + } + ) + ds.attrs["ds_type"] = "poses" + ds.to_netcdf(output_path) + + # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% # Read input data as pandas dataframe df = pd.read_hdf(filepath) @@ -120,7 +153,9 @@ def add_z_coord_to_position_array(position_array): if (filepath.parent / (filepath.stem + "_birds.h5")).exists(): ds_birds = load_poses.from_dlc_file(filepath.parent / (filepath.stem + "_birds.h5")) else: - columns_to_drop = [col for col in df.columns if "single" in col] + columns_to_drop = [ + col for col in df.columns if col[-2] in ["boatBL", "boatBR", "boatTip"] + ] df_birds = df.drop(columns=columns_to_drop) position_array, confidence_array, list_individuals, list_keypoints = ( @@ -145,8 +180,10 @@ def 
add_z_coord_to_position_array(position_array): if (filepath.parent / (filepath.stem + "_boat.h5")).exists(): ds_boat = load_poses.from_dlc_file(filepath.parent / (filepath.stem + "_boat.h5")) else: - columns_to_drop = [col for col in df.columns if "bird" in col[1]] - df_boat = df.drop(columns=columns_to_drop) + columns_to_keep = [ + col for col in df.columns if col[-2] in ["boatBL", "boatBR", "boatTip"] + ] + df_boat = df.loc[:, columns_to_keep] position_array, confidence_array, list_individuals, list_keypoints = ( get_data_for_load_from_numpy(df_boat) @@ -160,26 +197,24 @@ def add_z_coord_to_position_array(position_array): # fps=30, ) + # Rename individual name for boat + # (it is set as "bird24") + ds_boat["individuals"] = ["boat"] + # export for importable in napari save_poses.to_dlc_file( ds_boat, filepath.parent / (filepath.stem + "_boat.h5"), split_individuals=False ) # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Filter low-confidence values +# Filter low-confidence values in boat keypoint trajectories and interpolate # (values below the threshold are set to nan) confidence_threshold = 0.5 - boat_position = filter_by_confidence( ds_boat.position, ds_boat.confidence, threshold=confidence_threshold ) -birds_position = filter_by_confidence( - ds_birds.position, ds_birds.confidence, threshold=confidence_threshold -) -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Linearly interpolate boat points -# (gaps with nan are linearly inteprolated) +# Linearly interpolate gaps in boat trajectory boat_position_interp = interpolate_over_time( boat_position, method="linear", @@ -188,55 +223,70 @@ def add_z_coord_to_position_array(position_array): # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Compute rotation to BCS (boat coordinate system) +# Compute axes of BCS (boat coordinate system) # - origin : centroid of all boat keypoints per frame -# - y-axis: vector from boat centroid to tip keypoint -# - x-axis: perpendicular to y-axis, points to left side of the boat -# 
(it is a rotation from the image coordinate system (ICS)) - -# Note: we need to flip the x-coord to match the "classic plot" -# coordinate system (x-axis from left to right, y-axis from bottom to top). -# We cannot rotate the ICS into the "classic plot", it needs a flip of -# the x-axis. - +# - y-axis: vector from boat centroid to boat tip keypoint +# - x-axis: perpendicular to y-axis, points to right side of the boat # compute origin boat_position_3d = add_z_coord_to_position_array(boat_position_interp) boat_centroid_3d = boat_position_3d.mean("keypoints") +boat_centroid_3d = boat_centroid_3d.drop_vars("individuals").squeeze() -# compute y-axis +# compute BCS y-axis unit vector in image coordinate system (ICS) boat_y_axis_3d = ( - boat_position_3d.sel(keypoints="boatTip") - boat_centroid_3d -).drop_vars(["keypoints"]) -boat_centroid_3d = boat_centroid_3d.drop_vars("individuals").squeeze() -boat_y_axis_3d = boat_y_axis_3d.drop_vars("individuals").squeeze() + convert_to_unit(boat_position_3d.sel(keypoints="boatTip") - boat_centroid_3d) + .drop_vars(["keypoints"]) + .drop_vars("individuals") + .squeeze() +) + +# compute BCS z-axis in image coordinate system (ICS) +# (negative of ICS z-axis, which is positive going into the paper) +boat_z_axis_3d = xr.DataArray(data=[0, 0, -1], coords={"space": ["x", "y", "z"]}) + +# compute boat x-axis in image coordinate system (ICS) +boat_x_axis_3d = xr.cross(boat_y_axis_3d, boat_z_axis_3d, dim="space") + + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Compute rotation matrix from BCS axes to ICS axes == change of basis +# matrix from ICS coordinates to BCS coordinates +# R.apply(x_BCS) = x_ICS + +# The rotation is approximately a 180 deg rotation +# about the x=y diagonal (axis x=1, y=1, z=0). 
It essentially +# swaps x and y and flips z -# compute rotation from ICS y-axis to BCS y-axis rotation2boat = xr.apply_ufunc( - lambda v: compute_rotation_to_align_y_axis(v), + lambda xv, yv, zv: R.from_matrix(np.array([xv, yv, zv])), + boat_x_axis_3d, boat_y_axis_3d, - input_core_dims=[["space"]], + boat_z_axis_3d, + input_core_dims=[["space"], ["space"], ["space"]], vectorize=True, ) - # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% # Compute bird keypoints in BCS (translated and rotated) -birds_position_3d = add_z_coord_to_position_array(birds_position) +birds_position_3d = add_z_coord_to_position_array(ds_birds.position) birds_position_3d_BCS = xr.apply_ufunc( lambda rot, trans, vec: rot.apply(vec - trans), - rotation2boat, # rot - boat_centroid_3d, # trans - birds_position_3d, # vec + rotation2boat, # rotation to BCS + boat_centroid_3d, # translation to BCS + birds_position_3d, # trajectories in ICS input_core_dims=[[], ["space"], ["space"]], output_core_dims=[["space"]], vectorize=True, ) -# drop z coordinate +# drop z coordinate for clarity birds_position_BCS = birds_position_3d_BCS.drop_sel(space="z") +# reorder coordinates (space is moved last after apply_ufunc) +birds_position_BCS = birds_position_BCS.transpose("time", "space", "keypoints", "individuals") + # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% # Apply same transform to boat points boat_position_3d_BCS = xr.apply_ufunc( @@ -249,26 +299,33 @@ def add_z_coord_to_position_array(position_array): vectorize=True, ) -# drop z coordinate +# drop z coordinate for clarity boat_position_BCS = boat_position_3d_BCS.drop_sel(space="z") + +# reorder coordinates (space is moved last after apply_ufunc) +boat_position_BCS = boat_position_BCS.transpose("time", "space", "keypoints", "individuals") + # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% # Apply scaling -# Compute boat width and length per frame in pixels +# Compute boat width per frame in pixels boat_width = compute_pairwise_distances( - boat_position_BCS, dim="keypoints", 
pairs={"boatBL": "boatBR"} + boat_position_BCS, + dim="keypoints", + pairs={"boatBL": "boatBR"}, ) -# boat_width.name = "position" -boat_midpoint_BL_BR = boat_position_BCS.sel(keypoints=["boatBL", "boatBR"]).mean( - dim="keypoints" -) +# Compute boat length per frame in pixels +boat_midpoint_BL_BR = boat_position_BCS.sel( + keypoints=["boatBL", "boatBR"], +).mean(dim="keypoints") + boat_length = compute_norm( boat_position_BCS.sel(keypoints="boatTip") - boat_midpoint_BL_BR ).squeeze() -# check with plot +# check width, length variation with time plt.figure() boat_width.plot(label="width") boat_length.plot(label="length") @@ -279,75 +336,215 @@ def add_z_coord_to_position_array(position_array): # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% # Express spatial coordinates in meters -# We use boat length to scale the data -scale_factor = ( - boat_max_length_in_m / boat_length -) # (boat_max_width_in_m / boat_width) - looks nosier +# We use boat length to scale the data per frame +# (boat_width looks a bit noisier) +scale_factor = boat_max_length_in_m / boat_length + +# Express boat and bird coords in meters boat_position_BCS_in_m = boat_position_BCS * scale_factor birds_position_BCS_in_m = birds_position_BCS * scale_factor +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Interpolate and smooth bird centroid trajectories + +# Interpolation +# - Simplest: linear +# - For continuous 1st and 2nd derivative (speed): cubic spline -- but oscillations occur +# - For continuous 1st derivative and less poly wiggle: monotone cubic interpolants +# +# both akima and pchip are monotone cubic interpolants: these are constructed to be only once +# continuously differentiable, and attempt to preserve the local shape +# implied by the data. 
+# https://docs.scipy.org/doc/scipy/tutorial/interpolate/1D.html#monotone-interpolants + +# interpolate +bird_centroid_BCS_in_m = birds_position_BCS_in_m.mean(dim='keypoints') +birds_centroid_BCS_in_m_interp = interpolate_over_time( + bird_centroid_BCS_in_m, method="pchip" +) -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Plot data in BCS +# smooth with rolling median filter +birds_centroid_BCS_in_m_interp_smooth = rolling_filter( + birds_centroid_BCS_in_m_interp, + window=15, # frames (video is 30fps) +) -# Select a time slice for clarity (frames 0 to 654) -time_slice = slice(0, 9000) +# alternatively: smooth with SG filter +# https://en.wikipedia.org/wiki/Savitzky%E2%80%93Golay_filter +# birds_position_BCS_in_m_smooth = savgol_filter( +# birds_position_BCS_in_m_interp, +# window=15, # frames (video is 30fps) +# polyorder=1, +# ) -fig, ax = plt.subplots(1, 1) -# plot bird data and color by individual -cmap = plt.get_cmap("tab10") -color_array = cmap(np.arange(len(birds_position_BCS_in_m.individuals))) +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Plot centroid bird trajectories in BCS (plotly with WebGL) + +data_options = { + "": birds_position_BCS_in_m, + "_interp_smooth": birds_centroid_BCS_in_m_interp_smooth, +} + +# Select bird data to plot +tag = "" # "" for raw data, "_interp_smooth" for interpolated+smooth +position_da = data_options[tag] + +# Select a time slice for clarity +max_frame = position_da.time.max().values.item() +time_slice = slice(0, max_frame) + +# prepare colors by individual +colors = np.vstack([plt.get_cmap("tab20").colors, plt.get_cmap("tab20b").colors]) +color_array = colors[np.arange(len(position_da.individuals)) % len(colors)] + +# plot bird data +fig_plotly = go.Figure() +for i, ind in enumerate(position_da.individuals): + # compute centroid x,y coordinates + x = position_da.sel(time=time_slice, individuals=ind, space="x") + y = position_da.sel(time=time_slice, individuals=ind, space="y") + if "keypoints" in x.dims: + x = 
x.mean("keypoints") + y = y.mean("keypoints") + x, y = x.values, y.values + + rgb = color_array[i] + color_str = f"rgb({int(rgb[0] * 255)},{int(rgb[1] * 255)},{int(rgb[2] * 255)})" + fig_plotly.add_trace( + go.Scattergl( + x=x, + y=y, + mode="markers", + marker=dict(size=3, color=color_str), + name=ind.item(), + ) + ) -for i, ind in enumerate(birds_position_BCS_in_m.individuals): - # bird centroids - ax.scatter( - birds_position_BCS_in_m.sel(time=time_slice, individuals=ind, space="x").mean( - "keypoints" - ), - birds_position_BCS_in_m.sel(time=time_slice, individuals=ind, space="y").mean( - "keypoints" +# plot boat centroid, color by frame +# squeeze individuals dim (boat has a single "boat" individual) +boat_plot = boat_position_BCS_in_m.sel(time=time_slice).squeeze("individuals") +frame_idx = np.arange(time_slice.stop - time_slice.start + 1) +fig_plotly.add_trace( + go.Scattergl( + x=boat_plot.sel(space="x").mean("keypoints").values, + y=boat_plot.sel(space="y").mean("keypoints").values, + mode="markers", + marker=dict( + size=4, + color=frame_idx, + colorscale="Plasma", + symbol="star", + colorbar=dict(title="frames", x=-0.15), ), - 5, - color=color_array[i], - label=ind.item(), + name="boat centroid", ) - -ax.legend(loc="upper right", bbox_to_anchor=(1.02, 1)) - -# plot boat centroid -sc = ax.scatter( - boat_position_BCS_in_m.sel(time=time_slice, space="x").mean("keypoints"), - boat_position_BCS_in_m.sel(time=time_slice, space="y").mean("keypoints"), - 10, - c=np.arange((time_slice.stop - time_slice.start)), - cmap="plasma", - marker="*", ) -# plot boat keypoints in time +# plot boat keypoints, color by frame for boat_keypoint in ["boatTip", "boatBL", "boatBR"]: - ax.scatter( - boat_position_BCS_in_m.sel(time=time_slice, keypoints=boat_keypoint, space="x"), - boat_position_BCS_in_m.sel(time=time_slice, keypoints=boat_keypoint, space="y"), - 10, - c=np.arange((time_slice.stop - time_slice.start)), - cmap="plasma", + fig_plotly.add_trace( + go.Scattergl( + 
x=boat_plot.sel(keypoints=boat_keypoint, space="x").values, + y=boat_plot.sel(keypoints=boat_keypoint, space="y").values, + mode="markers", + marker=dict(size=6, color=frame_idx, colorscale="Plasma", showscale=False), + name=boat_keypoint, + showlegend=False, + ) ) +# axes +fig_plotly.update_layout( + xaxis_title="x_BCS (m)", + yaxis_title="y_BCS (m)", + yaxis_scaleanchor="x", + yaxis_scaleratio=1, + legend=dict(x=1.15, y=1, xanchor="left"), + template="plotly_white", +) + + +fig_plotly.show() + +fig_plotly.write_html(output_dir / f"bird_trajectories_BCS_centroid{tag}.html") +# %%%%%%%%%%%%%%%%%%% +# Plot heatmap with movement + +# Set extension of the data +# birds_position_BCS_in_m.sel(space="x").min().values +# birds_position_BCS_in_m.sel(space="x").max().values +# birds_position_BCS_in_m.sel(space="y").min().values +# birds_position_BCS_in_m.sel(space="y").max().values +xmin, xmax = -70, 70 +ymin, ymax = -115, 80 + +bin_edges_x = np.arange(xmin, xmax + 1, 1) # 1m wide +bin_edges_y = np.arange(ymin, ymax + 1, 1) # 1m wide + +# excluding bird1 and 20 +birds_to_exclude = ["bird1", "bird20"] + +fig, ax, hist = plot_occupancy( + birds_position_BCS_in_m, + range=[[xmin, xmax], [ymin, ymax]], + bins=[bin_edges_x, bin_edges_y], + individuals=[ + ind.item() + for ind in birds_position_BCS_in_m.individuals.values + if ind not in birds_to_exclude + ], +) + +# plot mean position of boat keypoints in red +ax.scatter( + boat_position_BCS_in_m.sel(space='x').mean(dim='time').values, + boat_position_BCS_in_m.sel(space='y').mean(dim='time').values, + 10, + c='r' +) + +ax.set_aspect("equal") ax.set_xlabel("x_BCS (m)") ax.set_ylabel("y_BCS (m)") -ax.set_aspect("equal") -ax.invert_xaxis() +fig.axes[-1].set_ylabel("counts") -# add colorbar -cbar = fig.colorbar(sc, ax=ax) -cbar.set_label("frames") +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Save movement datasets as .nc files loadable in napari -# put legend top left -ax.legend(loc="upper left") +# Export bird data +export_as_ds( + 
birds_position_BCS_in_m, ds_birds.confidence, output_dir / "ds_birds_BCS_in_m.nc" +) -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Save movement datasets -birds_position_BCS_in_m.to_netcdf(output_dir / "birds_position_BCS_in_m.nc") -boat_position_BCS_in_m.to_netcdf(output_dir / "boat_position_BCS_in_m.nc") +# Export boat data +export_as_ds( + boat_position_BCS_in_m, ds_boat.confidence, output_dir / "ds_boat_BCS_in_m.nc" +) + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Export csvs + +# Save bird trajectories in BCS (all keypoints) +export_dataarray_as_csv( + birds_position_BCS_in_m, output_dir / "birds_position_BCS_in_m.csv" +) + +# Save centroid (mean of all keypoints per frame) +export_dataarray_as_csv( + birds_position_BCS_in_m.mean(dim="keypoints"), + output_dir / "birds_position_BCS_in_m_centroid.csv", +) + +# Save centroid interpolated and smoothed +export_dataarray_as_csv( + birds_centroid_BCS_in_m_interp_smooth, + output_dir / "birds_position_BCS_in_m_centroid_interp_smooth.csv", +) + +# Save boat keypoints per frame in BCS +export_dataarray_as_csv( + boat_position_BCS_in_m, output_dir / "boat_position_BCS_in_m.csv" +) + +# %% diff --git a/notebook_postprocessing.py b/notebook_postprocessing.py deleted file mode 100644 index c3b515a..0000000 --- a/notebook_postprocessing.py +++ /dev/null @@ -1,407 +0,0 @@ -"""A notebook to postprocess DLC trajectories expressed in BCS. - -Requirements: following installation instructions for `movement` -https://movement.neuroinformatics.dev/latest/user_guide/installation.html - -Then run this notebook in that conda environment. 
- -""" - -# %% - -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import xarray as xr - - -from movement.plots import plot_centroid_trajectory -from movement.utils.vector import compute_norm -from movement.filtering import interpolate_over_time, savgol_filter -from movement.io import save_poses -# Hide attributes globally -xr.set_options(display_expand_attrs=False) - -# %%%%%%%%%%%%%%%%%%%%%%% -# %matplotlib widget - -# %%%%%%%%%%%%%%%%%% -# Input data -notebook_path = Path(__file__).resolve() -input_dir = notebook_path.parent / "output" - -boat_netcdf = "boat_position_BCS_in_m.nc" -birds_netcdf = "birds_position_BCS_in_m.nc" - - -# Postprocessing parameters -fps = 30 # frames per second (video) -min_gap_size = 15 # in frames, for splitting IDs -min_n_frames_with_data = fps * 1 # per ID, for filtering out short trajectories - -# for defining reference smooth trajectory -savgol_window_size = 30 # fps=30 -savgol_poly_order = 1 -interp_method_reference = "akima" -max_distance_to_smoothed = 3 # in m - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Helper functions - - -def add_segment_ids(df, min_gap_size=1): - """ - Add segment IDs based on NaN gaps in position data. - - Parameters - ---------- - df : DataFrame - The trajectory data - min_gap_size : int - Minimum number of consecutive NaN frames to trigger a split. 
- - min_gap_size=1: split on any NaN (default, strictest) - - min_gap_size=5: only split if gap is 5+ frames - - min_gap_size=10: tolerate gaps up to 9 frames - """ - - segments = [] - - segment_id_delta = 0 - for _individual, group in df.groupby(["individuals"]): - # Pivot to get x and y side by side - pivoted = group.pivot( - index="time", columns=["keypoints", "space"], values="position" - ) - - # If any x/y coord of a keypoint is not nan, observation is valid - is_valid = pivoted.notna().any(axis=1) - - segment_id = get_significant_gaps(is_valid, min_gap_size) - - # Apply global offset to make IDs unique across individuals - segment_id += segment_id_delta - segment_id_delta = segment_id.max() + 1 - - # Map segment IDs back to original rows - group = group.copy() - group["segment"] = group["time"].map(segment_id) - - # Optionally: filter out the NaN rows - # group = group[group["position"].notna()] - - segments.append(group) - - return pd.concat(segments, ignore_index=True) - - -def get_significant_gaps(is_valid, min_gap_size): - """ - Identify where significant gaps (>= min_gap_size consecutive NaNs) occur. - Returns a Series of segment IDs. - """ - # Identify consecutive runs of the same value - # .ne() --> True where a transition occurs - # .cumsum() ---> runnning ID (Since True = 1 and False = 0, - # this increments by 1 each time there's a transition.) 
- runs = is_valid.ne(is_valid.shift()).cumsum() - - # Get the length of each run - run_lengths = is_valid.groupby(runs).transform("size") - - # A "significant gap" is an invalid run that's long enough - is_big_gap = (~is_valid) & (run_lengths >= min_gap_size) - - # Segment ID increments each time we EXIT a significant gap - # (i.e., when we go from big_gap=True to big_gap=False) - # restarts after a big gap - # True only where previous was in a gap AND current is not - # .cumsum() --> running count of exits - segment_id = (is_big_gap.shift(fill_value=False) & ~is_big_gap).cumsum() - - return segment_id - - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Load movement dataset -birds_position_BCS_in_m = xr.load_dataarray(input_dir / birds_netcdf) -boat_position_BCS_in_m = xr.load_dataarray(input_dir / boat_netcdf) - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Split IDs if gap between DLC IDs is sufficiently large - -# Convert to dataframe first -df_birds_position = birds_position_BCS_in_m.to_dataframe().reset_index() - -# Split IDs -df_with_segments = add_segment_ids(df_birds_position, min_gap_size=min_gap_size) - -# Redefine ID based on "segment" -df_with_segments["individuals"] = df_with_segments["individuals"].str[ - :-1 -] + df_with_segments["segment"].astype(str).str.zfill(3) - -# Convert to xarray data array -birds_position_BCS_m_split = ( - df_with_segments.loc[:, ["time", "space", "keypoints", "individuals", "position"]] - .set_index(["time", "space", "keypoints", "individuals"])["position"] - .to_xarray() -) - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# plot before filtering - -# Select a slice of time for clarity if desired -time_slice = slice(0, 9000) - -fig, ax = plt.subplots(1, 1) - -# plot bird data and color by individual -cmap = plt.get_cmap("tab20") -n_individuals = len(birds_position_BCS_m_split.individuals) -color_array = cmap(np.arange(n_individuals) % cmap.N) - -for i, ind in enumerate(birds_position_BCS_m_split.individuals): - # Get the data for this 
individual - x_data = birds_position_BCS_m_split.sel( - time=time_slice, individuals=ind, space="x" - ).mean("keypoints") - y_data = birds_position_BCS_m_split.sel( - time=time_slice, individuals=ind, space="y" - ).mean("keypoints") - - # Check if there's any non-NaN data - has_data = (~np.isnan(x_data)).any() and (~np.isnan(y_data)).any() - - # bird centroids - ax.scatter( - x_data, - y_data, - 5, - color=color_array[i], - label=ind.item() if has_data else None, # Only label if has data - ) - -ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), markerscale=2) -ax.set_xlabel("x_BCS (m)") -ax.set_ylabel("y_BCS (m)") -ax.set_aspect("equal") - - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Filter out short trajectories - -# Compute number of frames with at least one keypoint per id -valid_frames_per_id = ( - birds_position_BCS_m_split.notnull() - .all(dim="space") - .any(dim="keypoints") - .sum(dim="time") -) - -# filter -birds_position_BCS_m_split = birds_position_BCS_m_split.sel( - individuals=valid_frames_per_id >= min_n_frames_with_data -) - -# %%%%% -# plot after filtering -# Select a slice of time for clarity if desired -time_slice = slice(0, 9000) - -fig, ax = plt.subplots(1, 1) - -# plot bird data and color by individual -cmap = plt.get_cmap("tab20") -n_individuals = len(birds_position_BCS_m_split.individuals) -color_array = cmap(np.arange(n_individuals) % cmap.N) - -for i, ind in enumerate(birds_position_BCS_m_split.individuals): - # Get the data for this individual - x_data = birds_position_BCS_m_split.sel( - time=time_slice, individuals=ind, space="x" - ).mean("keypoints") - y_data = birds_position_BCS_m_split.sel( - time=time_slice, individuals=ind, space="y" - ).mean("keypoints") - - # Check if there's any non-NaN data - has_data = (~np.isnan(x_data)).any() and (~np.isnan(y_data)).any() - - # bird centroids - ax.scatter( - x_data, - y_data, - 5, - color=color_array[i], - label=ind.item() if has_data else None, # Only label if has data - ) - 
-ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), markerscale=2) -ax.set_xlabel("x_BCS (m)") -ax.set_ylabel("y_BCS (m)") -ax.set_aspect("equal") - - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Compute a smoothed reference trajectory to filter out jumps -smoothed_position = savgol_filter( - birds_position_BCS_m_split, savgol_window_size, polyorder=savgol_poly_order -) -smoothed_position_interp = interpolate_over_time( - smoothed_position, method=interp_method_reference -) - -# if distance between birds_position_BCS_m_split and smoothed trajectory -# is above threshold, set datapoints to nan -distance_to_smoothed = compute_norm( - birds_position_BCS_m_split - smoothed_position_interp -) - -birds_position_BCS_m_split_post = birds_position_BCS_m_split.where( - distance_to_smoothed <= max_distance_to_smoothed -) - - -# %%%%%%%%%%%%%%%%%%%%%%% -# Drop IDs with all nans -# Check which individuals have at least one non-NaN value -has_valid_data = birds_position_BCS_m_split_post.notnull().any( - dim=["time", "space", "keypoints"] -) - -# Keep only individuals with valid data -birds_position_BCS_m_split_post = birds_position_BCS_m_split_post.sel( - individuals=has_valid_data -) - -# %%%%%%%%%%%%%%%% -# Save postprocessed *data array* as netcdf -birds_position_BCS_m_split_post.to_netcdf( - input_dir / "birds_position_BCS_m_postprocessed.nc" -) - - -# %%%%%%%%%%%%%%% -# Export as a tidy dataframe with x,y separate columns - -df_birds_post = birds_position_BCS_m_split_post.to_dataframe().reset_index() - -# Optional: drop rows with NaN positions if you only want valid data -df_birds_post = df_birds_post.dropna(subset=["position"]) - -# Pivot to get x and y as separate columns -df_wide = df_birds_post.pivot( - index=["time", "keypoints", "individuals"], columns="space", values="position" -).reset_index() - -# Flatten column names -df_wide.columns.name = None - -# Export to CSV -df_wide.to_csv(input_dir / "birds_position_BCS_m_postprocessed.csv", index=False) - -# 
%%%%%%%%%%%%%%%%%%%%% -# Save postprocessed trajectories as a **dataset** to load in napari -ds_export = xr.Dataset( - { - "position": birds_position_BCS_m_split_post, - "confidence": xr.full_like( - birds_position_BCS_m_split_post.isel(space=0, drop=True), np.nan - ), - } -) -ds_export.attrs["ds_type"] = "poses" # add dataset-level attributes -ds_export.to_netcdf(input_dir / "birds_BCS_m_postprocessed.nc") - - -# %% -# Save as DLC .h5 -save_poses.to_dlc_file(ds_export, input_dir / "birds_BCS_m_postprocessed.h5") - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Plot data -# Select a time slice for clarity -time_slice = slice(0, 9000) - -fig, ax = plt.subplots(1, 1) - -# plot bird data and color by individual -cmap = plt.get_cmap("tab20") -n_individuals = len(birds_position_BCS_m_split_post.individuals) -color_array = cmap(np.arange(n_individuals) % cmap.N) - -for i, ind in enumerate(birds_position_BCS_m_split_post.individuals): - # Get the data for this individual - x_data = birds_position_BCS_m_split_post.sel( - time=time_slice, individuals=ind, space="x" - ).mean("keypoints") - y_data = birds_position_BCS_m_split_post.sel( - time=time_slice, individuals=ind, space="y" - ).mean("keypoints") - - # Check if there's any non-NaN data - has_data = (~np.isnan(x_data)).any() and (~np.isnan(y_data)).any() - - # bird centroids - ax.scatter( - x_data, - y_data, - 5, - color=color_array[i], - label=ind.item() if has_data else None, # Only label if has data - ) - - -# plot boat keypoints in time -for boat_keypoint in ["boatTip", "boatBL", "boatBR"]: - ax.scatter( - boat_position_BCS_in_m.sel(time=time_slice, keypoints=boat_keypoint, space="x"), - boat_position_BCS_in_m.sel(time=time_slice, keypoints=boat_keypoint, space="y"), - 10, - c=np.arange((time_slice.stop - time_slice.start)), - cmap="plasma", - ) - -ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), markerscale=2) -ax.set_xlabel("x_BCS (m)") -ax.set_ylabel("y_BCS (m)") -ax.set_aspect("equal") -ax.invert_xaxis() # x is 
positive on the left side of the boat - -# %%%%%%%%%%%%%%%%%%%%%%%%% -# Plot an individual bird over time before filtering out jumps and -# with reference smoothed trajectory -fig, ax = plt.subplots() -plot_centroid_trajectory( - birds_position_BCS_m_split.sel(time=time_slice), - individual="bird015", - ax=ax, - label="pre", -) -plot_centroid_trajectory( - smoothed_position_interp.sel(time=time_slice), - individual="bird015", - c="r", - ax=ax, - label="reference", -) -ax.set_xlabel("x (m)") -ax.set_ylabel("y (m)") -ax.set_title("before removing data with 'jumps'") -ax.legend() - - -# %% -# Plot after removing jumps -fig, ax = plt.subplots() -plot_centroid_trajectory( - birds_position_BCS_m_split_post.sel(time=time_slice), - individual="bird015", - ax=ax, -) -ax.set_xlabel("x (m)") -ax.set_ylabel("y (m)") -ax.set_title("after removing data with 'jumps'") -# %% diff --git a/notebook_postprocessing_from_raw.py b/notebook_postprocessing_from_raw.py deleted file mode 100644 index 8c55384..0000000 --- a/notebook_postprocessing_from_raw.py +++ /dev/null @@ -1,407 +0,0 @@ -"""A notebook to postprocess DLC trajectories expressed in BCS. - -Requirements: following installation instructions for `movement` -https://movement.neuroinformatics.dev/latest/user_guide/installation.html - -Then run this notebook in that conda environment. 
- -""" - -# %% - -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import xarray as xr -from movement.filtering import interpolate_over_time, savgol_filter -from movement.io import load_poses, save_poses -from movement.plots import plot_centroid_trajectory -from movement.utils.vector import compute_norm - -# Hide attributes globally -xr.set_options(display_expand_attrs=False) - -# %%%%%%%%%%%%%%%%%%%%%%% -# %matplotlib widget # pip install ipympl for interactive pltos - -# %%%%%%%%%%%%%%%%%% -# Input data -notebook_path = Path(__file__).resolve() -input_dir = notebook_path.parent / "output" - - -# Postprocessing parameters -fps = 30 # frames per second (video) -min_gap_size = 15 # in frames, for splitting IDs -min_n_frames_with_data = fps * 1 # per ID, for filtering out short trajectories - -# for defining reference smooth trajectory -savgol_window_size = 30 # fps=30 -savgol_poly_order = 1 -interp_method_reference = "akima" -max_distance_to_smoothed = 25 # in pixels - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Helper functions - - -def add_segment_ids(df, min_gap_size=1): - """ - Add segment IDs based on NaN gaps in position data. - - Parameters - ---------- - df : DataFrame - The trajectory data - min_gap_size : int - Minimum number of consecutive NaN frames to trigger a split. 
- - min_gap_size=1: split on any NaN (default, strictest) - - min_gap_size=5: only split if gap is 5+ frames - - min_gap_size=10: tolerate gaps up to 9 frames - """ - - segments = [] - - segment_id_delta = 0 - for _individual, group in df.groupby(["individuals"]): - # Pivot to get x and y side by side - pivoted = group.pivot( - index="time", columns=["keypoints", "space"], values="position" - ) - - # If any x/y coord of a keypoint is not nan, observation is valid - is_valid = pivoted.notna().any(axis=1) - - segment_id = get_significant_gaps(is_valid, min_gap_size) - - # Apply global offset to make IDs unique across individuals - segment_id += segment_id_delta - segment_id_delta = segment_id.max() + 1 - - # Map segment IDs back to original rows - group = group.copy() - group["segment"] = group["time"].map(segment_id) - - # Optionally: filter out the NaN rows - # group = group[group["position"].notna()] - - segments.append(group) - - return pd.concat(segments, ignore_index=True) - - -def get_significant_gaps(is_valid, min_gap_size): - """ - Identify where significant gaps (>= min_gap_size consecutive NaNs) occur. - Returns a Series of segment IDs. - """ - # Identify consecutive runs of the same value - # .ne() --> True where a transition occurs - # .cumsum() ---> runnning ID (Since True = 1 and False = 0, - # this increments by 1 each time there's a transition.) 
- runs = is_valid.ne(is_valid.shift()).cumsum() - - # Get the length of each run - run_lengths = is_valid.groupby(runs).transform("size") - - # A "significant gap" is an invalid run that's long enough - is_big_gap = (~is_valid) & (run_lengths >= min_gap_size) - - # Segment ID increments each time we EXIT a significant gap - # (i.e., when we go from big_gap=True to big_gap=False) - # restarts after a big gap - # True only where previous was in a gap AND current is not - # .cumsum() --> running count of exits - segment_id = (is_big_gap.shift(fill_value=False) & ~is_big_gap).cumsum() - - return segment_id - - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Load movement dataset -bird_ds = load_poses.from_dlc_file("/Users/sofia/swc/project_seabirds/data/second-iter/FILE00009_sDLC_DekrW32_seabirdNov6shuffle1_snapshot_170_el_filtered_birds.h5") -boat_ds = load_poses.from_dlc_file("/Users/sofia/swc/project_seabirds/data/second-iter/FILE00009_sDLC_DekrW32_seabirdNov6shuffle1_snapshot_170_el_filtered_boat.h5") - -bird_position = bird_ds.position -boat_position = boat_ds.position - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Split IDs if gap between DLC IDs is sufficiently large - -# Convert to dataframe first -df_birds_position = bird_position.to_dataframe().reset_index() - -# Split IDs -df_with_segments = add_segment_ids(df_birds_position, min_gap_size=min_gap_size) - -# Redefine ID based on "segment" -df_with_segments["individuals"] = df_with_segments["individuals"].str[ - :-1 -] + df_with_segments["segment"].astype(str).str.zfill(3) - -# Convert to xarray data array -birds_position_split = ( - df_with_segments.loc[:, ["time", "space", "keypoints", "individuals", "position"]] - .set_index(["time", "space", "keypoints", "individuals"])["position"] - .to_xarray() -) - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# plot before filtering - -# Select a slice of time for clarity if desired -time_slice = slice(0, 9000) - -fig, ax = plt.subplots(1, 1) - -# plot bird data and color by 
individual -cmap = plt.get_cmap("tab20") -n_individuals = len(birds_position_split.individuals) -color_array = cmap(np.arange(n_individuals) % cmap.N) - -for i, ind in enumerate(birds_position_split.individuals): - # Get the data for this individual - x_data = birds_position_split.sel( - time=time_slice, individuals=ind, space="x" - ).mean("keypoints") - y_data = birds_position_split.sel( - time=time_slice, individuals=ind, space="y" - ).mean("keypoints") - - # Check if there's any non-NaN data - has_data = (~np.isnan(x_data)).any() and (~np.isnan(y_data)).any() - - # bird centroids - ax.scatter( - x_data, - y_data, - 5, - color=color_array[i], - label=ind.item() if has_data else None, # Only label if has data - ) - -ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), markerscale=2) -ax.set_xlabel("x_ICS (pixels)") -ax.set_ylabel("y_ICS (pixels)") -ax.set_aspect("equal") - - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Filter out short trajectories - -# Compute number of frames with at least one keypoint per id -valid_frames_per_id = ( - birds_position_split.notnull() - .all(dim="space") - .any(dim="keypoints") - .sum(dim="time") -) - -# filter -birds_position_split = birds_position_split.sel( - individuals=valid_frames_per_id >= min_n_frames_with_data -) - -# %%%%% -# plot after filtering -# Select a slice of time for clarity if desired -time_slice = slice(0, 9000) - -fig, ax = plt.subplots(1, 1) - -# plot bird data and color by individual -cmap = plt.get_cmap("tab20") -n_individuals = len(birds_position_split.individuals) -color_array = cmap(np.arange(n_individuals) % cmap.N) - -for i, ind in enumerate(birds_position_split.individuals): - # Get the data for this individual - x_data = birds_position_split.sel( - time=time_slice, individuals=ind, space="x" - ).mean("keypoints") - y_data = birds_position_split.sel( - time=time_slice, individuals=ind, space="y" - ).mean("keypoints") - - # Check if there's any non-NaN data - has_data = (~np.isnan(x_data)).any() and 
(~np.isnan(y_data)).any() - - # bird centroids - ax.scatter( - x_data, - y_data, - 5, - color=color_array[i], - label=ind.item() if has_data else None, # Only label if has data - ) - -ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), markerscale=2) -ax.set_xlabel("x_BCS (m)") -ax.set_ylabel("y_BCS (m)") -ax.set_aspect("equal") - - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Compute a smoothed reference trajectory to filter out jumps -smoothed_position = savgol_filter( - birds_position_split, savgol_window_size, polyorder=savgol_poly_order -) -smoothed_position_interp = interpolate_over_time( - smoothed_position, method=interp_method_reference -) - -# if distance between birds_position_BCS_m_split and smoothed trajectory -# is above threshold, set datapoints to nan -distance_to_smoothed = compute_norm( - birds_position_split - smoothed_position_interp -) - -birds_position_split_post = birds_position_split.where( - distance_to_smoothed <= max_distance_to_smoothed -) - - -# %%%%%%%%%%%%%%%%%%%%%%% -# Drop IDs with all nans -# Check which individuals have at least one non-NaN value -has_valid_data = birds_position_split_post.notnull().any( - dim=["time", "space", "keypoints"] -) - -# Keep only individuals with valid data -birds_position_split_post = birds_position_split_post.sel( - individuals=has_valid_data -) - -# %%%%%%%%%%%%%%%% -# Save postprocessed *data array* as netcdf -birds_position_split_post.to_netcdf( - input_dir / "birds_position_ICS_postprocessed.nc" -) - - -# %%%%%%%%%%%%%%% -# Export as a tidy dataframe with x,y separate columns - -df_birds_post = birds_position_split_post.to_dataframe().reset_index() - -# Optional: drop rows with NaN positions if you only want valid data -df_birds_post = df_birds_post.dropna(subset=["position"]) - -# Pivot to get x and y as separate columns -df_wide = df_birds_post.pivot( - index=["time", "keypoints", "individuals"], columns="space", values="position" -).reset_index() - -# Flatten column names -df_wide.columns.name = 
None - -# Export to CSV -df_wide.to_csv(input_dir / "birds_position_ICS_postprocessed.csv", index=False) - -# %%%%%%%%%%%%%%%%%%%%% -# Save postprocessed trajectories as a **dataset** to load in napari -ds_export = xr.Dataset( - { - "position": birds_position_split_post, - "confidence": xr.full_like( - birds_position_split_post.isel(space=0, drop=True), np.nan - ), - } -) -ds_export.attrs["ds_type"] = "poses" # add dataset-level attributes -ds_export.to_netcdf(input_dir / "birds_ICS_postprocessed.nc") - - -# %% -# Save as DLC .h5 -save_poses.to_dlc_file(ds_export, input_dir / "birds_ICS_postprocessed.h5") - -save_poses.to_dlc_file(boat_ds, input_dir / "boat_ICS.h5") - - -# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -# Plot data -# Select a time slice for clarity -time_slice = slice(0, 9000) - -fig, ax = plt.subplots(1, 1) - -# plot bird data and color by individual -cmap = plt.get_cmap("tab20") -n_individuals = len(birds_position_split_post.individuals) -color_array = cmap(np.arange(n_individuals) % cmap.N) - -for i, ind in enumerate(birds_position_split_post.individuals): - # Get the data for this individual - x_data = birds_position_split_post.sel( - time=time_slice, individuals=ind, space="x" - ).mean("keypoints") - y_data = birds_position_split_post.sel( - time=time_slice, individuals=ind, space="y" - ).mean("keypoints") - - # Check if there's any non-NaN data - has_data = (~np.isnan(x_data)).any() and (~np.isnan(y_data)).any() - - # bird centroids - ax.scatter( - x_data, - y_data, - 5, - color=color_array[i], - label=ind.item() if has_data else None, # Only label if has data - ) - - -# plot boat keypoints in time -for boat_keypoint in ["boatTip", "boatBL", "boatBR"]: - ax.scatter( - boat_position.sel(time=time_slice, keypoints=boat_keypoint, space="x"), - boat_position.sel(time=time_slice, keypoints=boat_keypoint, space="y"), - 10, - c=np.arange((time_slice.stop - time_slice.start)), - cmap="plasma", - ) - -ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), 
markerscale=2) -ax.set_xlabel("x_ICS (m)") -ax.set_ylabel("y_ICS (m)") -ax.set_aspect("equal") -ax.invert_yaxis() # to match image coordinate system -# %%%%%%%%%%%%%%%%%%%%%%%%% -# Plot an individual bird over time before filtering out jumps and -# with reference smoothed trajectory -fig, ax = plt.subplots() -plot_centroid_trajectory( - birds_position_split.sel(time=time_slice), - individual="bird015", - ax=ax, - label="pre", -) -plot_centroid_trajectory( - smoothed_position_interp.sel(time=time_slice), - individual="bird015", - c="r", - ax=ax, - label="reference", -) -ax.set_xlabel("x (m)") -ax.set_ylabel("y (m)") -ax.set_title("before removing data with 'jumps'") -ax.legend() - - -# %% -# Plot after removing jumps -fig, ax = plt.subplots() -plot_centroid_trajectory( - birds_position_split_post.sel(time=time_slice), - individual="bird015", - ax=ax, -) -ax.set_xlabel("x (m)") -ax.set_ylabel("y (m)") -ax.set_title("after removing data with 'jumps'")