-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathnotebook_postprocessing.py
More file actions
407 lines (318 loc) · 11.8 KB
/
notebook_postprocessing.py
File metadata and controls
407 lines (318 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
"""A notebook to postprocess DeepLabCut (DLC) trajectories expressed in BCS.

Requirements: follow the installation instructions for `movement`
(https://movement.neuroinformatics.dev/latest/user_guide/installation.html),
then run this notebook in that conda environment.
"""
# %%
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from movement.plots import plot_centroid_trajectory
from movement.utils.vector import compute_norm
from movement.filtering import interpolate_over_time, savgol_filter
from movement.io import save_poses
# Hide attributes globally
xr.set_options(display_expand_attrs=False)
# %%%%%%%%%%%%%%%%%%%%%%%
# %matplotlib widget
# %%%%%%%%%%%%%%%%%%
# Input data
# `__file__` is not defined when these cells are executed in a Jupyter
# kernel (e.g. after converting the script to .ipynb), so fall back to the
# current working directory in that case.
try:
    notebook_path = Path(__file__).resolve()
except NameError:
    notebook_path = Path.cwd().resolve() / "notebook_postprocessing.py"
input_dir = notebook_path.parent / "output"
boat_netcdf = "boat_position_BCS_in_m.nc"
birds_netcdf = "birds_position_BCS_in_m.nc"
# Postprocessing parameters
fps = 30  # frames per second (video)
min_gap_size = 15  # in frames, for splitting IDs
# per ID, for filtering out short trajectories (1 s worth of frames)
min_n_frames_with_data = fps * 1
# for defining reference smooth trajectory
savgol_window_size = fps  # in frames, i.e. 1 s of video (30 at fps=30)
savgol_poly_order = 1
interp_method_reference = "akima"
max_distance_to_smoothed = 3  # in m
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Helper functions
def add_segment_ids(df, min_gap_size=1):
    """
    Add segment IDs based on NaN gaps in position data.

    Each individual's trajectory is split wherever it contains a run of at
    least `min_gap_size` consecutive frames in which no keypoint has a valid
    coordinate. Segment IDs are made unique across individuals by offsetting
    each individual's IDs past the largest ID assigned so far.

    Parameters
    ----------
    df : DataFrame
        The trajectory data in long format, with the columns "individuals",
        "time", "keypoints", "space" and "position".
    min_gap_size : int
        Minimum number of consecutive NaN frames to trigger a split.
        - min_gap_size=1: split on any NaN (default, strictest)
        - min_gap_size=5: only split if gap is 5+ frames
        - min_gap_size=10: tolerate gaps up to 9 frames

    Returns
    -------
    DataFrame
        The rows of `df` (grouped by individual, with a fresh RangeIndex)
        plus an integer "segment" column. If `df` yields no groups, an empty
        DataFrame with the same columns plus "segment" is returned.
    """
    segments = []
    segment_id_delta = 0
    # Group by the column name itself (not a 1-element list) so the group
    # key is the individual's label rather than a 1-tuple.
    for _individual, group in df.groupby("individuals"):
        # Pivot to one row per time point, one column per (keypoint, space)
        pivoted = group.pivot(
            index="time", columns=["keypoints", "space"], values="position"
        )
        # If any x/y coord of a keypoint is not nan, observation is valid
        is_valid = pivoted.notna().any(axis=1)
        segment_id = get_significant_gaps(is_valid, min_gap_size)
        # Apply global offset to make IDs unique across individuals
        segment_id += segment_id_delta
        segment_id_delta = segment_id.max() + 1
        # Map segment IDs (indexed by time) back to the original rows
        group = group.copy()
        group["segment"] = group["time"].map(segment_id)
        # Optionally: filter out the NaN rows
        # group = group[group["position"].notna()]
        segments.append(group)
    if not segments:
        # pd.concat([]) raises ValueError ("No objects to concatenate");
        # return an empty frame carrying the extra column instead.
        return (
            df.iloc[:0]
            .assign(segment=pd.Series(dtype="int64"))
            .reset_index(drop=True)
        )
    return pd.concat(segments, ignore_index=True)
def get_significant_gaps(is_valid, min_gap_size):
    """
    Label frames with segment IDs, starting a new segment after each
    significant gap.

    A significant gap is a run of at least `min_gap_size` consecutive
    invalid (False) entries in `is_valid`.

    Parameters
    ----------
    is_valid : pandas.Series of bool
        Per-frame validity flags, in temporal order.
    min_gap_size : int
        Minimum length of a run of invalid frames for it to count as a gap.

    Returns
    -------
    pandas.Series
        Integer segment IDs sharing the index of `is_valid`. IDs start at 0
        and grow by 1 on the first frame after each significant gap, so the
        frames of a gap belong to the segment that precedes them.
    """
    # A new run starts wherever the flag differs from the previous frame
    # (the first frame always does, since it is compared against NaN).
    starts_new_run = is_valid != is_valid.shift()
    run_labels = starts_new_run.cumsum()
    # Length of the run that each frame belongs to
    run_length = is_valid.groupby(run_labels).transform("size")
    # Frames inside an invalid run that is long enough
    in_big_gap = (run_length >= min_gap_size) & ~is_valid
    # First frame after a significant gap: previous frame in gap, this not
    leaves_big_gap = ~in_big_gap & in_big_gap.shift(fill_value=False)
    # Running count of gap exits is the segment ID
    return leaves_big_gap.cumsum()
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Load movement dataset
birds_position_BCS_in_m = xr.load_dataarray(input_dir / birds_netcdf)
boat_position_BCS_in_m = xr.load_dataarray(input_dir / boat_netcdf)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Split IDs if gap between DLC IDs is sufficiently large
# Convert to dataframe first
df_birds_position = birds_position_BCS_in_m.to_dataframe().reset_index()
# Split IDs
df_with_segments = add_segment_ids(df_birds_position, min_gap_size=min_gap_size)
# Redefine ID based on "segment". Strip the whole trailing number from the DLC
# name (e.g. "bird1" or "bird12" -> "bird") and append the zero-padded,
# globally unique segment ID (e.g. "bird015").
# Slicing off only the last character would leave digits behind for
# multi-digit suffixes ("bird12" -> "bird1" + "015"), which is ambiguous.
# It would also drop a letter from names that have no numeric suffix.
df_with_segments["individuals"] = df_with_segments["individuals"].str.replace(
    r"\d+$", "", regex=True
) + df_with_segments["segment"].astype(str).str.zfill(3)
# Convert to xarray data array
birds_position_BCS_m_split = (
    df_with_segments.loc[:, ["time", "space", "keypoints", "individuals", "position"]]
    .set_index(["time", "space", "keypoints", "individuals"])["position"]
    .to_xarray()
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# plot before filtering
# Select a slice of time for clarity if desired
time_slice = slice(0, 9000)
fig, ax = plt.subplots(1, 1)
# One "tab20" colour per individual, cycling after 20 individuals
cmap = plt.get_cmap("tab20")
n_individuals = len(birds_position_BCS_m_split.individuals)
color_array = cmap(np.arange(n_individuals) % cmap.N)
# Centroid (mean over keypoints) of every individual inside the time window
centroids_before = birds_position_BCS_m_split.sel(time=time_slice).mean("keypoints")
for i, ind in enumerate(centroids_before.individuals):
    x_data = centroids_before.sel(individuals=ind, space="x")
    y_data = centroids_before.sel(individuals=ind, space="y")
    # Only add a legend entry when both x and y have some non-NaN data
    has_data = bool(x_data.notnull().any()) and bool(y_data.notnull().any())
    legend_label = ind.item() if has_data else None
    # bird centroids
    ax.scatter(x_data, y_data, 5, color=color_array[i], label=legend_label)
ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), markerscale=2)
ax.set(xlabel="x_BCS (m)", ylabel="y_BCS (m)")
ax.set_aspect("equal")
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Filter out short trajectories
# A frame counts for an ID when at least one keypoint has all of its space
# coordinates defined; count such frames per ID over the whole recording.
frame_has_keypoint = (
    birds_position_BCS_m_split.notnull().all(dim="space").any(dim="keypoints")
)
valid_frames_per_id = frame_has_keypoint.sum(dim="time")
# keep only IDs with enough frames of data
is_long_enough = valid_frames_per_id >= min_n_frames_with_data
birds_position_BCS_m_split = birds_position_BCS_m_split.sel(individuals=is_long_enough)
# %%%%%
# plot after filtering
# Select a slice of time for clarity if desired
time_slice = slice(0, 9000)
fig, ax = plt.subplots(1, 1)
# Colour per individual from "tab20", wrapping around after 20 individuals
cmap = plt.get_cmap("tab20")
n_individuals = len(birds_position_BCS_m_split.individuals)
color_array = cmap(np.arange(n_individuals) % cmap.N)
# Per-individual centroid (mean over keypoints) within the time window
centroids_after = birds_position_BCS_m_split.sel(time=time_slice).mean("keypoints")
for color, ind in zip(color_array, centroids_after.individuals):
    x_data = centroids_after.sel(individuals=ind, space="x")
    y_data = centroids_after.sel(individuals=ind, space="y")
    # Label only individuals that have non-NaN values in both x and y
    if bool(x_data.notnull().any()) and bool(y_data.notnull().any()):
        legend_label = ind.item()
    else:
        legend_label = None
    # bird centroids
    ax.scatter(x_data, y_data, 5, color=color, label=legend_label)
ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), markerscale=2)
ax.set(xlabel="x_BCS (m)", ylabel="y_BCS (m)")
ax.set_aspect("equal")
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Compute a smoothed reference trajectory to filter out jumps
# 1. Savitzky-Golay smoothing of the split trajectories
smoothed_position = savgol_filter(
    birds_position_BCS_m_split,
    savgol_window_size,
    polyorder=savgol_poly_order,
)
# 2. Fill gaps in the smoothed trajectories by interpolating over time
smoothed_position_interp = interpolate_over_time(
    smoothed_position,
    method=interp_method_reference,
)
# 3. Points further than `max_distance_to_smoothed` from the reference are
# replaced by NaN.
# NOTE(review): wherever the reference is NaN, the distance is NaN and
# `NaN <= threshold` is False, so those points are set to NaN as well.
residual = birds_position_BCS_m_split - smoothed_position_interp
distance_to_smoothed = compute_norm(residual)
is_close_to_reference = distance_to_smoothed <= max_distance_to_smoothed
birds_position_BCS_m_split_post = birds_position_BCS_m_split.where(is_close_to_reference)
# %%%%%%%%%%%%%%%%%%%%%%%
# Drop IDs with all nans
# An individual has valid data unless every value over time, space and
# keypoints is NaN
has_valid_data = ~birds_position_BCS_m_split_post.isnull().all(
    dim=["time", "space", "keypoints"]
)
# Keep only individuals with valid data
birds_position_BCS_m_split_post = birds_position_BCS_m_split_post.sel(
    individuals=has_valid_data
)
# %%%%%%%%%%%%%%%%
# Save postprocessed *data array* as netcdf
postprocessed_nc_path = input_dir / "birds_position_BCS_m_postprocessed.nc"
birds_position_BCS_m_split_post.to_netcdf(postprocessed_nc_path)
# %%%%%%%%%%%%%%%
# Export as a tidy dataframe with x,y separate columns
df_birds_post = (
    birds_position_BCS_m_split_post.to_dataframe()
    .reset_index()
    # Optional: drop rows with NaN positions if you only want valid data
    .dropna(subset=["position"])
)
# One row per (time, keypoint, individual) with "x" and "y" as columns; the
# leftover "space" name on the column axis is cleared before export.
df_wide = (
    df_birds_post.pivot(
        index=["time", "keypoints", "individuals"], columns="space", values="position"
    )
    .reset_index()
    .rename_axis(columns=None)
)
# Export to CSV
df_wide.to_csv(input_dir / "birds_position_BCS_m_postprocessed.csv", index=False)
# %%%%%%%%%%%%%%%%%%%%%
# Save postprocessed trajectories as a **dataset** to load in napari
# "confidence" has the dims of "position" minus "space" and is all NaN:
# no confidence values are carried through this postprocessing.
confidence_placeholder = xr.full_like(
    birds_position_BCS_m_split_post.isel(space=0, drop=True), np.nan
)
ds_export = xr.Dataset(
    data_vars={
        "position": birds_position_BCS_m_split_post,
        "confidence": confidence_placeholder,
    },
    attrs={"ds_type": "poses"},  # dataset-level attributes
)
ds_export.to_netcdf(input_dir / "birds_BCS_m_postprocessed.nc")
# %%
# Save as DLC .h5
save_poses.to_dlc_file(ds_export, input_dir / "birds_BCS_m_postprocessed.h5")
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Plot data
# Select a time slice for clarity
time_slice = slice(0, 9000)
fig, ax = plt.subplots(1, 1)
# plot bird data and color by individual
cmap = plt.get_cmap("tab20")
n_individuals = len(birds_position_BCS_m_split_post.individuals)
color_array = cmap(np.arange(n_individuals) % cmap.N)
for i, ind in enumerate(birds_position_BCS_m_split_post.individuals):
    # Get the data for this individual (centroid = mean over keypoints)
    x_data = birds_position_BCS_m_split_post.sel(
        time=time_slice, individuals=ind, space="x"
    ).mean("keypoints")
    y_data = birds_position_BCS_m_split_post.sel(
        time=time_slice, individuals=ind, space="y"
    ).mean("keypoints")
    # Check if there's any non-NaN data
    has_data = (~np.isnan(x_data)).any() and (~np.isnan(y_data)).any()
    # bird centroids
    ax.scatter(
        x_data,
        y_data,
        5,
        color=color_array[i],
        label=ind.item() if has_data else None,  # Only label if has data
    )
# plot boat keypoints, coloured by sample order within the time slice
for boat_keypoint in ["boatTip", "boatBL", "boatBR"]:
    boat_keypoint_position = boat_position_BCS_in_m.sel(
        time=time_slice, keypoints=boat_keypoint
    )
    ax.scatter(
        boat_keypoint_position.sel(space="x"),
        boat_keypoint_position.sel(space="y"),
        10,
        # Size the colour array from the selection itself. Label-based
        # `.sel` slicing includes the stop label and is clipped to the data,
        # so the number of samples is generally not `stop - start`. A
        # mismatch makes `scatter` raise a ValueError.
        c=np.arange(boat_keypoint_position.sizes["time"]),
        cmap="plasma",
    )
ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), markerscale=2)
ax.set_xlabel("x_BCS (m)")
ax.set_ylabel("y_BCS (m)")
ax.set_aspect("equal")
ax.invert_xaxis()  # x is positive on the left side of the boat
# %%%%%%%%%%%%%%%%%%%%%%%%%
# Plot an individual bird over time before filtering out jumps and
# with reference smoothed trajectory
fig, ax = plt.subplots()
# (trajectory, extra scatter kwargs): the split data first, then the
# smoothed reference drawn in red on top
for trajectory, style_kwargs in (
    (birds_position_BCS_m_split, {"label": "pre"}),
    (smoothed_position_interp, {"c": "r", "label": "reference"}),
):
    plot_centroid_trajectory(
        trajectory.sel(time=time_slice),
        individual="bird015",
        ax=ax,
        **style_kwargs,
    )
ax.set(xlabel="x (m)", ylabel="y (m)", title="before removing data with 'jumps'")
ax.legend()
# %%
# Plot after removing jumps
fig, ax = plt.subplots()
birds_post_in_window = birds_position_BCS_m_split_post.sel(time=time_slice)
plot_centroid_trajectory(birds_post_in_window, individual="bird015", ax=ax)
ax.set(xlabel="x (m)", ylabel="y (m)", title="after removing data with 'jumps'")
# %%