diff --git a/examples/compute_polarization_AP_inference.py b/examples/compute_polarization_AP_inference.py new file mode 100644 index 000000000..8328203f9 --- /dev/null +++ b/examples/compute_polarization_AP_inference.py @@ -0,0 +1,3958 @@ +#!/usr/bin/env python3 +"""AP inference demo script: PC1-based body-axis inference vs hand-curated GT. + +Tests the prior-free AP inference pipeline against manually curated ground +truth rankings across 5 multi-animal SLEAP datasets. Operates in three +passes: + Pass 1 — R×M Selection: find best individual per file + Pass 2 — Cross-Individual Ordering Consistency: compare raw PC1-based + orderings against best individual's ordering (pseudo GT) + Pass 3 — Inferred AP Concordance: for each + individual, compare the velocity-inferred AP ordering + (anterior_sign × PC1) of GT nodes against hand-curated GT + +After the passes, all GT pair permutations × all individuals are run +through validate_ap and stored in HDF5. analyze_results then reads back +the H5 to produce GT coverage analysis, suggested pair analysis, and +the data for Figure 2. + +Generates two types of figures: + Figure 1 — Per-file detail (2×2 tile per best individual) + Figure 2 — Cross-dataset comparison (skeletons + coverage + ordering) + +Parallelizes at pair×individual level for max throughput. 
+""" + +import itertools +import multiprocessing as mp +import os +import sys +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed +from datetime import datetime +from pathlib import Path +from urllib.request import urlretrieve + +import h5py +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.legend_handler import HandlerBase +from matplotlib.lines import Line2D + +# Configuration +ROOT_PATH = Path(__file__).parent / "datasets" / "multi-animal" +SLP_DIR = ROOT_PATH / "slp" +MP4_DIR = ROOT_PATH / "mp4" +OUTPUT_DIR = ROOT_PATH / "exports" / "AP-inference-demo" +FIGURES_DIR = OUTPUT_DIR / "figures" +LOGS_DIR = OUTPUT_DIR / "logs" +H5_DIR = OUTPUT_DIR / "h5" +N_WORKERS = mp.cpu_count() + +# Hand-curated ground truth AP rankings (not exhaustive) +# Format: {file_stem: {node_index: rank}} +# Convention: higher rank = more anterior (rank 1 = most posterior in subset) +GROUND_TRUTH = { + # Higher rank = more anterior + "free-moving-2flies-ID-13nodes-1024x1024x1-30_3pxmm": { + 0: 3, # head (most anterior, rank 3) + 1: 2, # thorax + 2: 1, # abdomen (most posterior, rank 1) + }, + "free-moving-2mice-noID-5nodes-1280x1024x1-1_9pxmm": { + 0: 2, # snout (anterior, rank 2) + 3: 1, # tail-base (posterior, rank 1) + }, + "free-moving-4gerbils-ID-14nodes-1024x1280x3-2pxmm": { + 0: 6, # nose (most anterior, rank 6) + 5: 5, # spine1 + 6: 4, # spine2 + 7: 3, # spine3 + 8: 2, # spine4 + 9: 1, # spine5 (most posterior, rank 1) + }, + "free-moving-5mice-noID-11nodes-1280x1024x1-1_97pxmm": { + 0: 3, # nose (most anterior, rank 3) + 1: 2, # neck + 6: 1, # tail_base (most posterior, rank 1) + }, + "freemoving-2bees-noID-21nodes-1535x2048x1-14pxmm": { + 1: 3, # head (most anterior, rank 3) + 0: 2, # thorax + 2: 1, # abdomen (most posterior, rank 1) + }, +} + +# Display labels for files +FILE_LABELS = { + "free-moving-2flies-ID-13nodes-1024x1024x1-30_3pxmm": "2Flies.slp", + 
"free-moving-2mice-noID-5nodes-1280x1024x1-1_9pxmm": "2Mice.slp", + "free-moving-4gerbils-ID-14nodes-1024x1280x3-2pxmm": "4Gerbils.slp", + "free-moving-5mice-noID-11nodes-1280x1024x1-1_97pxmm": "5Mice.slp", + "freemoving-2bees-noID-21nodes-1535x2048x1-14pxmm": "2Bees.slp", +} + +DEMO_DATASETS = { + "free-moving-2flies-ID-13nodes-1024x1024x1-30_3pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/clips/talk_title_slide%4013150-14500.mp4", + "slp": "https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/clips/talk_title_slide%4013150-14500.slp", + }, + "free-moving-5mice-noID-11nodes-1280x1024x1-1_97pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/wang_4mice_john/clips/OFTsocial5mice-0000-00%4015488-18736.mp4", + "slp": "https://storage.googleapis.com/sleap-data/datasets/wang_4mice_john/clips/OFTsocial5mice-0000-00%4015488-18736.slp", + }, + "free-moving-2mice-noID-5nodes-1280x1024x1-1_9pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/eleni_mice/clips/20200111_USVpairs_court1_M1_F1_top-01112020145828-0000%400-2560.mp4", + "slp": "https://storage.googleapis.com/sleap-data/datasets/eleni_mice/clips/20200111_USVpairs_court1_M1_F1_top-01112020145828-0000%400-2560.slp", + }, + "freemoving-2bees-noID-21nodes-1535x2048x1-14pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/yan_bees/clips/bees_demo%4021000-23000.mp4", + "slp": "https://storage.googleapis.com/sleap-data/datasets/yan_bees/clips/bees_demo%4021000-23000.slp", + }, + "free-moving-4gerbils-ID-14nodes-1024x1280x3-2pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/nyu-gerbils/clips/2020-3-10_daytime_5mins_compressedTalmo%403200-5760.mp4", + "slp": "https://storage.googleapis.com/sleap-data/datasets/nyu-gerbils/clips/2020-3-10_daytime_5mins_compressedTalmo%403200-5760.slp", + }, +} + + +class TeeOutput: + """Context manager that duplicates stdout to both console and a file.""" + + def 
__init__(self, filepath): + """Initialize with the target file path.""" + self.filepath = Path(filepath) + self.file = None + self.original_stdout = None + + def __enter__(self): + """Open the file and redirect stdout.""" + self.filepath.parent.mkdir(parents=True, exist_ok=True) + self.file = open(self.filepath, "w") + self.original_stdout = sys.stdout + sys.stdout = self + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Restore stdout and close the file.""" + sys.stdout = self.original_stdout + if self.file: + self.file.close() + return False + + def write(self, text): + """Write text to both stdout and file.""" + self.original_stdout.write(text) + self.file.write(text) + self.file.flush() + + def flush(self): + """Flush both stdout and file.""" + self.original_stdout.flush() + self.file.flush() + + +def _download_file(url, destination): + """Download a single file to the requested destination.""" + destination.parent.mkdir(parents=True, exist_ok=True) + print(f"Downloading {destination.name}...") + urlretrieve(url, destination) + + +def ensure_demo_datasets(): + """Ensure the expected demo SLP/MP4 files exist locally.""" + ROOT_PATH.mkdir(parents=True, exist_ok=True) + SLP_DIR.mkdir(parents=True, exist_ok=True) + MP4_DIR.mkdir(parents=True, exist_ok=True) + + slp_files = sorted(SLP_DIR.glob("*.slp")) + mp4_files = sorted(MP4_DIR.glob("*.mp4")) + + should_bootstrap = len(slp_files) < len(DEMO_DATASETS) or len( + mp4_files + ) < len(DEMO_DATASETS) + if not should_bootstrap: + return + + print( + "Bootstrapping demo datasets " + f"(found {len(slp_files)} .slp and {len(mp4_files)} .mp4 files)..." 
+ ) + + for file_stem, urls in DEMO_DATASETS.items(): + slp_target = SLP_DIR / f"{file_stem}.slp" + mp4_target = MP4_DIR / f"{file_stem}.mp4" + + if not slp_target.exists(): + _download_file(urls["slp"], slp_target) + if not mp4_target.exists(): + _download_file(urls["mp4"], mp4_target) + + +def process_single_validation(args): # noqa: C901 + """Process a single (file, individual, from_kp, to_kp) validation.""" + # Import inside worker to avoid pickling issues + from movement.io import load_poses + from movement.kinematics.body_axis import ValidateAPConfig, validate_ap + + ( + slp_path, + file_stem, + individual, + ind_idx, + from_kp, + to_kp, + from_idx, + to_idx, + n_kp, + n_ind, + n_frames, + ) = args + + # Suppress stdout using context manager + with open(os.devnull, "w") as devnull: + old_stdout = sys.stdout + sys.stdout = devnull + + try: + ds = load_poses.from_sleap_file(Path(slp_path)) + if "individuals" in ds.position.dims: + pos_data = ds.position.sel(individuals=individual) + else: + pos_data = ds.position + + config = ValidateAPConfig() + val = validate_ap( + pos_data, + from_node=from_kp, + to_node=to_kp, + config=config, + verbose=False, + ) + + rec = { + "file": file_stem, + "individual": str(individual), + "individual_idx": ind_idx, + "from_keypoint": from_kp, + "to_keypoint": to_kp, + "from_index": from_idx, + "to_index": to_idx, + "n_keypoints": n_kp, + "n_individuals": n_ind, + "n_frames": n_frames, + "validation_success": val.get("success", False), + "resultant_length": val.get("resultant_length", np.nan), + "vote_margin": val.get("vote_margin", np.nan), + "num_selected_frames": val.get("num_selected_frames", 0), + "circ_mean_dir": val.get("circ_mean_dir", np.nan), + "anterior_sign": val.get("anterior_sign", 0), + "num_clusters": val.get("num_clusters", 0), + "error_msg": val.get("error_msg", ""), + "error": False, + "error_type": "", + } + + r, m = rec["resultant_length"], rec["vote_margin"] + rec["rxm"] = r * m if not (np.isnan(r) or 
np.isnan(m)) else np.nan + + pr = val.get("pair_report") + if pr: + rec.update( + { + "pr_success": pr.success, + "pr_failure_step": pr.failure_step, + "pr_failure_reason": pr.failure_reason, + "pr_scenario": pr.scenario, + "pr_outcome": pr.outcome, + "pr_warning_message": pr.warning_message, + "pr_input_pair_in_candidates": ( + pr.input_pair_in_candidates + ), + "pr_input_pair_opposite_sides": ( + pr.input_pair_opposite_sides + ), + "pr_input_pair_separation_abs": ( + pr.input_pair_separation_abs + ), + "pr_input_pair_is_distal": pr.input_pair_is_distal, + "pr_input_pair_rank": pr.input_pair_rank, + "pr_input_pair_order_matches_inference": ( + pr.input_pair_order_matches_inference + ), + "pr_max_separation_distal": pr.max_separation_distal, + "pr_max_separation": pr.max_separation, + "pr_lateral_offset_min": pr.lateral_offset_min, + "pr_lateral_offset_max": pr.lateral_offset_max, + "pr_midpoint_pc1": pr.midpoint_pc1, + "pr_pc1_min": pr.pc1_min, + "pr_pc1_max": pr.pc1_max, + "pr_midline_dist_max": pr.midline_dist_max, + # Cascade counts (same for all records + # from the same individual) + "n_valid_nodes": int( + np.sum(~np.isnan(pr.lateral_offsets_norm)) + ) + if len(pr.lateral_offsets_norm) > 0 + else 0, + "n_step1_candidates": len(pr.sorted_candidate_nodes), + "n_step2_pairs": len(pr.valid_pairs), + "n_step3_distal": len(pr.distal_pairs), + "n_step3_proximal": len(pr.proximal_pairs), + } + ) + + if len(pr.max_separation_distal_nodes) > 0: + distal = pr.max_separation_distal_nodes + rec["suggested_from_idx"] = int(distal[0]) + rec["suggested_to_idx"] = int(distal[1]) + rec["suggested_type"] = "distal" + elif len(pr.max_separation_nodes) > 0: + prox = pr.max_separation_nodes + rec["suggested_from_idx"] = int(prox[0]) + rec["suggested_to_idx"] = int(prox[1]) + rec["suggested_type"] = "proximal" + else: + rec["suggested_from_idx"] = -1 + rec["suggested_to_idx"] = -1 + rec["suggested_type"] = "" + + # Store avg_skeleton and PC1 from validation result + avg_skel = 
def generate_rxm_tasks(slp_files):
    """Pass 1: generate one task per (file, individual) to compute R×M.

    Uses first two GT node indices as an arbitrary pair since R×M depends
    only on the individual's motion and body shape, not the input pair.
    Files with fewer than two GROUND_TRUTH nodes are skipped without being
    loaded.

    Args:
        slp_files: iterable of ``pathlib.Path`` objects pointing at SLEAP
            files; ``Path.stem`` is used as the key into GROUND_TRUTH.

    Returns:
        tasks: list of argument tuples in the order unpacked by
            ``process_single_validation``: (slp_path, file_stem, individual,
            ind_idx, from_kp, to_kp, from_idx, to_idx, n_kp, n_ind,
            n_frames).
        file_metadata: {file_stem: {"slp_file", "keypoints", "individuals",
            "n_frames", "n_kp", "n_ind", "has_individuals", "gt_indices"}}
            for every file that produced tasks.

    """
    from movement.io import load_poses

    tasks = []
    file_metadata = {}

    for slp_file in slp_files:
        # The GT table alone decides whether a file is usable, so check it
        # before paying for parsing the SLEAP file.
        gt_indices = list(GROUND_TRUTH.get(slp_file.stem, {}))
        if len(gt_indices) < 2:
            continue
        from_idx, to_idx = gt_indices[0], gt_indices[1]

        ds = load_poses.from_sleap_file(slp_file)
        keypoints = [str(k) for k in ds.coords["keypoints"].values]
        n_frames = ds.sizes["time"]
        n_kp = len(keypoints)

        # Only touch the "individuals" coordinate when the dimension exists;
        # otherwise fall back to one placeholder individual named "single".
        has_individuals = "individuals" in ds.position.dims
        if has_individuals:
            individuals = [str(i) for i in ds.coords["individuals"].values]
        else:
            individuals = ["single"]
        n_ind = len(individuals)

        file_metadata[slp_file.stem] = {
            "slp_file": slp_file,
            "keypoints": keypoints,
            "individuals": individuals,
            "n_frames": n_frames,
            "n_kp": n_kp,
            "n_ind": n_ind,
            "has_individuals": has_individuals,
            "gt_indices": gt_indices,
        }

        # One task per individual
        for ind_idx, individual in enumerate(individuals):
            tasks.append(
                (
                    str(slp_file),
                    slp_file.stem,
                    individual,
                    ind_idx,
                    keypoints[from_idx],
                    keypoints[to_idx],
                    from_idx,
                    to_idx,
                    n_kp,
                    n_ind,
                    n_frames,
                )
            )

    return tasks, file_metadata
def compute_pc1_orderings(
    file_individual_data, file_metadata, best_individuals
):
    """Rank each individual's GT nodes by their projection onto raw PC1.

    Every GT node of a file is projected onto the individual's PC1 vector
    (no anterior_sign correction is applied). The node with the largest
    projection gets rank 1, the next one rank 2, and so on. GT nodes whose
    index lies beyond the skeleton, or whose average position contains NaN,
    are left out of that individual's ordering. Individuals without an
    average skeleton or PC1 vector are skipped, as are files absent from
    ``file_metadata``.

    Returns:
        best_pc1_orderings: {file_stem: {node_idx: pc1_rank}} -
            ordering of the file's best individual (pseudo GT for Pass 2)
        all_pc1_orderings: {file_stem: {individual: {node_idx: pc1_rank}}} -
            orderings of every individual that could be ranked

    """
    best_by_file = {}
    all_by_file = {}

    for stem, per_individual in file_individual_data.items():
        if stem not in file_metadata:
            continue

        nodes_of_interest = file_metadata[stem]["gt_indices"]
        top_individual = best_individuals.get(stem)
        rankings = {}
        all_by_file[stem] = rankings

        for name, entry in per_individual.items():
            skeleton_raw = entry.get("avg_skeleton")
            axis_raw = entry.get("pc1")
            if skeleton_raw is None or axis_raw is None:
                continue

            skeleton = np.array(skeleton_raw)
            axis = np.array(axis_raw)
            n_nodes = len(skeleton)

            # Scalar position of each usable GT node along raw PC1.
            projections = {}
            for node in nodes_of_interest:
                if not node < n_nodes:
                    continue
                point = skeleton[node]
                if np.any(np.isnan(point)):
                    continue
                projections[node] = np.dot(point, axis)

            # Largest projection first; ranks are 1-based.
            descending = sorted(
                projections, key=projections.__getitem__, reverse=True
            )
            ordering = {
                node: rank for rank, node in enumerate(descending, start=1)
            }

            rankings[name] = ordering
            if name == top_individual:
                best_by_file[stem] = ordering

    return best_by_file, all_by_file
def compare_orderings_to_pseudo_gt(all_pc1_orderings, pseudo_gt_orderings):
    """Pass 2: Check each individual's PC1 ordering against the pseudo GT.

    The pseudo GT of a file is the PC1-based ordering of its best
    individual. Both orderings are turned into the list of node indices
    sorted by rank, and the two lists must be exactly equal; one swapped
    pair, or a node present in only one of them, is a mismatch.

    Returns:
        ordering_matches: {file_stem: {individual: bool}} -
            True if the individual's rank-sorted node sequence equals the
            pseudo GT's; files without any individual orderings map to {}

    """

    def nodes_by_rank(ordering):
        # Node indices arranged from rank 1 upwards.
        return sorted(ordering, key=ordering.get)

    ordering_matches = {}
    for stem, reference in pseudo_gt_orderings.items():
        reference_sequence = nodes_by_rank(reference)
        candidates = all_pc1_orderings.get(stem, {})
        ordering_matches[stem] = {
            name: nodes_by_rank(ordering) == reference_sequence
            for name, ordering in candidates.items()
        }

    return ordering_matches
def save_to_h5(  # noqa: C901
    all_results,
    keypoints_by_file,
    output_path,
    individual_accuracy=None,
    all_rxm=None,
):
    """Write validation records and per-individual arrays to an HDF5 file.

    Layout of the file created at ``output_path`` (overwritten if present):
      - attributes "created", "n_files", "n_records"
      - group "keypoints": one byte-string dataset per file
      - one root-level dataset per record field, with one entry per record;
        records lacking a field get "" (strings), -1 (ints), NaN (floats)
        or False (bools)
      - groups "skeletons", "pc1_vectors", "vel_projs_pc1": per file, one
        dataset per individual, taken from the first record of that
        (file, individual) whose value is not None
      - group "individual_accuracy" (only if ``individual_accuracy`` is
        truthy): correct / total / accuracy per file and individual
      - group "individual_rxm" (only if ``all_rxm`` is truthy): one scalar
        per file and individual
    """
    vlen_str = h5py.special_dtype(vlen=str)

    # (field names, fill value for records lacking the key, storage dtype)
    column_specs = (
        (
            (
                "file",
                "individual",
                "from_keypoint",
                "to_keypoint",
                "error_msg",
                "error_type",
                "pr_failure_step",
                "pr_failure_reason",
                "pr_outcome",
                "pr_warning_message",
                "suggested_type",
            ),
            "",
            vlen_str,
        ),
        (
            (
                "individual_idx",
                "from_index",
                "to_index",
                "n_keypoints",
                "n_individuals",
                "n_frames",
                "num_selected_frames",
                "anterior_sign",
                "num_clusters",
                "pr_scenario",
                "pr_input_pair_rank",
                "suggested_from_idx",
                "suggested_to_idx",
                "n_valid_nodes",
                "n_step1_candidates",
                "n_step2_pairs",
                "n_step3_distal",
                "n_step3_proximal",
            ),
            -1,
            np.int32,
        ),
        (
            (
                "resultant_length",
                "vote_margin",
                "rxm",
                "circ_mean_dir",
                "pr_input_pair_separation_abs",
                "pr_max_separation_distal",
                "pr_max_separation",
                "pr_lateral_offset_min",
                "pr_lateral_offset_max",
                "pr_midpoint_pc1",
                "pr_pc1_min",
                "pr_pc1_max",
                "pr_midline_dist_max",
            ),
            np.nan,
            np.float64,
        ),
        (
            (
                "validation_success",
                "error",
                "pr_success",
                "pr_input_pair_in_candidates",
                "pr_input_pair_opposite_sides",
                "pr_input_pair_is_distal",
                "pr_input_pair_order_matches_inference",
            ),
            False,
            bool,
        ),
    )

    # Slot name in the collected data -> key under which records carry it.
    record_key_for_slot = {
        "skeleton": "avg_skeleton",
        "pc1": "PC1",
        "vel_projs": "vel_projs_pc1",
    }

    with h5py.File(output_path, "w") as h5:
        h5.attrs["created"] = datetime.now().isoformat()
        h5.attrs["n_files"] = len(keypoints_by_file)
        h5.attrs["n_records"] = len(all_results)

        keypoint_group = h5.create_group("keypoints")
        for stem, names in keypoints_by_file.items():
            keypoint_group.create_dataset(stem, data=np.array(names, dtype="S"))

        # One flat dataset per scalar record field.
        for field_names, fill, dtype in column_specs:
            for field in field_names:
                column = [rec.get(field, fill) for rec in all_results]
                h5.create_dataset(field, data=np.array(column, dtype=dtype))

        # Skeleton, PC1 and velocity projections belong to an individual,
        # not to a node pair, so they are stored once per (file, individual).
        group_for_slot = {
            "skeleton": h5.create_group("skeletons"),
            "pc1": h5.create_group("pc1_vectors"),
            "vel_projs": h5.create_group("vel_projs_pc1"),
        }

        # Scan every record: the first one for an individual may hold None
        # while a later one carries usable data.
        per_individual = {}
        for rec in all_results:
            stem = rec.get("file", "")
            name = rec.get("individual", "")
            if not (stem and name):
                continue

            slot = per_individual.setdefault(
                (stem, name),
                {"skeleton": None, "pc1": None, "vel_projs": None},
            )
            for slot_name, record_key in record_key_for_slot.items():
                if slot[slot_name] is None and rec.get(record_key) is not None:
                    slot[slot_name] = rec.get(record_key)

        for (stem, name), slot in per_individual.items():
            # The three groups always gain a file sub-group together.
            if stem not in group_for_slot["skeleton"]:
                for group in group_for_slot.values():
                    group.create_group(stem)

            for slot_name, group in group_for_slot.items():
                values = slot[slot_name]
                if values is not None:
                    group[stem].create_dataset(
                        name, data=np.array(values, dtype=np.float64)
                    )

        # Save individual accuracy (raw PC1 direction diag for Fig 1)
        if individual_accuracy:
            accuracy_root = h5.create_group("individual_accuracy")
            for stem, by_individual in individual_accuracy.items():
                accuracy_file = accuracy_root.create_group(stem)
                for name, scores in by_individual.items():
                    entry = accuracy_file.create_group(name)
                    for key in ("correct", "total", "accuracy"):
                        entry.create_dataset(key, data=scores[key])

        # Save R×M values for all individuals
        if all_rxm:
            rxm_root = h5.create_group("individual_rxm")
            for stem, by_individual in all_rxm.items():
                rxm_file = rxm_root.create_group(stem)
                for name, value in by_individual.items():
                    rxm_file.create_dataset(name, data=value)
data["file"] = [decode_str(x) for x in f["file"][:]] + data["individual"] = [decode_str(x) for x in f["individual"][:]] + data["suggested_type"] = [ + decode_str(x) for x in f["suggested_type"][:] + ] + + # Index fields + data["from_index"] = np.array(f["from_index"]) + data["to_index"] = np.array(f["to_index"]) + data["suggested_from_idx"] = np.array(f["suggested_from_idx"]) + data["suggested_to_idx"] = np.array(f["suggested_to_idx"]) + + # Validation results + data["validation_success"] = np.array(f["validation_success"]) + data["anterior_sign"] = np.array(f["anterior_sign"]) + data["pr_input_pair_order_matches_inference"] = np.array( + f["pr_input_pair_order_matches_inference"] + ) + data["pr_success"] = np.array(f["pr_success"]) + data["pr_input_pair_in_candidates"] = np.array( + f["pr_input_pair_in_candidates"] + ) + data["rxm"] = np.array(f["rxm"]) + data["vote_margin"] = np.array(f["vote_margin"]) + data["resultant_length"] = np.array(f["resultant_length"]) + n = len(data["file"]) + if "circ_mean_dir" in f: + data["circ_mean_dir"] = np.array(f["circ_mean_dir"]) + else: + data["circ_mean_dir"] = np.full(n, np.nan) + if "num_selected_frames" in f: + data["num_selected_frames"] = np.array(f["num_selected_frames"]) + else: + data["num_selected_frames"] = np.zeros(n, dtype=np.int32) + if "n_frames" in f: + data["n_frames"] = np.array(f["n_frames"]) + else: + data["n_frames"] = np.zeros(n, dtype=np.int32) + + # Cascade counts (backward-compatible) + for cascade_field in [ + "n_valid_nodes", + "n_step1_candidates", + "n_step2_pairs", + "n_step3_distal", + "n_step3_proximal", + ]: + if cascade_field in f: + data[cascade_field] = np.array(f[cascade_field]) + else: + data[cascade_field] = np.zeros(n, dtype=np.int32) + + # Load skeleton and PC1 data (if present) + data["skeletons"] = {} # {file_stem: {individual: np.array}} + data["pc1_vectors"] = {} # {file_stem: {individual: np.array}} + + if "skeletons" in f: + for file_stem in f["skeletons"]: + 
def find_best_individual_per_file(data):
    """Find the individual with highest mean R×M for each file.

    Args:
        data: mapping with parallel, equally long sequences under the keys
            "file", "individual" and "rxm" (one entry per record).

    Returns:
        {file_stem: individual} for every file that has at least one
        non-NaN R×M value. The mean is taken over an individual's non-NaN
        values; when several individuals share the highest mean, the one
        encountered first in the records wins.

    """
    rxm_by_file = defaultdict(lambda: defaultdict(list))

    for file_stem, individual, rxm in zip(
        data["file"], data["individual"], data["rxm"]
    ):
        if not np.isnan(rxm):
            rxm_by_file[file_stem][individual].append(rxm)

    best_individual = {}
    for file_stem, individuals in rxm_by_file.items():
        mean_rxm = {
            ind: np.mean(values) for ind, values in individuals.items()
        }
        # max() needs no numeric sentinel, so a file is never mapped to
        # None just because all of its means are low; it returns the first
        # maximal key, matching a strict ">" scan.
        best_individual[file_stem] = max(mean_rxm, key=mean_rxm.get)

    return best_individual
def compute_gt_coverage(file_surviving_nodes):
    """Summarise how many hand-curated GT nodes survive Step 1 per file.

    Parameters
    ----------
    file_surviving_nodes : dict
        ``{file_stem: set of node indices}``; a missing file is treated
        as having no surviving nodes.

    Returns
    -------
    dict
        One entry per file in ``GROUND_TRUTH``, each a dict with the
        keys ``surviving_in_gt``, ``gt_total``, ``coverage_pct``,
        ``surviving_nodes``, ``gt_nodes``, ``surviving_in_gt_nodes``
        and ``gt_not_surviving`` (the node lists are sorted).

    """
    report = {}

    for stem, ranks in GROUND_TRUTH.items():
        gt_set = set(ranks.keys())
        kept = file_surviving_nodes.get(stem, set())
        kept_in_gt = kept & gt_set

        total = len(gt_set)
        hits = len(kept_in_gt)
        # An empty GT subset reports 0 % rather than dividing by zero
        pct = 100 * hits / total if total > 0 else 0

        report[stem] = {
            "surviving_in_gt": hits,
            "gt_total": total,
            "coverage_pct": pct,
            "surviving_nodes": sorted(kept),
            "gt_nodes": sorted(gt_set),
            "surviving_in_gt_nodes": sorted(kept_in_gt),
            "gt_not_surviving": sorted(gt_set - kept),
        }

    return report
def analyze_suggested_pairs(data, best_individual):  # noqa: C901
    """Report which node pair the 3-step filter cascade auto-selected.

    For each file's best individual, retrieves the suggested pair
    (posterior→anterior, distal vs proximal) and notes whether the
    selected nodes happen to fall within the hand-curated GT subset.
    GT membership is incidental context — the pipeline often selects
    nodes outside the GT subset, which is expected.

    Parameters
    ----------
    data : dict
        Record table with parallel sequences under ``"file"``,
        ``"individual"``, ``"suggested_from_idx"``,
        ``"suggested_to_idx"`` and ``"suggested_type"``.
    best_individual : dict
        ``{file_stem: individual}``; only the first record of that
        individual is read for each file.

    Returns
    -------
    dict
        One entry per file in ``GROUND_TRUTH``. Every entry has the keys
        ``suggested_from``, ``suggested_to``, ``suggested_type``,
        ``from_in_gt``, ``to_in_gt``, ``both_in_gt``, ``order_correct``,
        ``from_rank``, ``to_rank`` and ``status``. ``order_correct`` is
        ``None`` unless both nodes are in the GT subset; a rank is the
        string ``"NaN"`` when it is not reported.

    """
    # Placeholder used for files that have no record of the best individual
    no_suggestion = {"from_idx": -1, "to_idx": -1, "type": ""}

    # First record of the best individual wins: the code relies on the
    # suggested pair being the same for every input pair of that individual.
    file_suggested = {}
    for i, file_stem in enumerate(data["file"]):
        if data["individual"][i] != best_individual.get(file_stem):
            continue
        if file_stem in file_suggested:
            continue

        file_suggested[file_stem] = {
            "from_idx": int(data["suggested_from_idx"][i]),
            "to_idx": int(data["suggested_to_idx"][i]),
            "type": data["suggested_type"][i],
        }

    results = {}
    for file_stem, gt_ranks in GROUND_TRUTH.items():
        gt_nodes = set(gt_ranks.keys())
        suggested = file_suggested.get(file_stem, no_suggestion)

        from_idx = suggested["from_idx"]
        to_idx = suggested["to_idx"]
        stype = suggested["type"]

        # A negative index on either side means there is no usable pair
        if from_idx < 0 or to_idx < 0:
            results[file_stem] = {
                "suggested_from": from_idx,
                "suggested_to": to_idx,
                "suggested_type": stype,
                "from_in_gt": False,
                "to_in_gt": False,
                "both_in_gt": False,
                "order_correct": None,
                # Same keys as a regular record, so consumers that read
                # the ranks do not hit a KeyError on a half-valid pair.
                "from_rank": "NaN",
                "to_rank": "NaN",
                "status": "NO SUGGESTION",
            }
            continue

        from_in_gt = from_idx in gt_nodes
        to_in_gt = to_idx in gt_nodes
        both_in_gt = from_in_gt and to_in_gt

        # Ordering can only be judged when both nodes carry a GT rank.
        # Correct: from=posterior (lower rank), to=anterior (higher rank)
        order_correct = None
        if both_in_gt:
            order_correct = gt_ranks[from_idx] < gt_ranks[to_idx]

        # Determine status (GT membership is incidental — the pipeline
        # often selects nodes outside the hand-curated GT subset)
        if both_in_gt and order_correct:
            status = "Both in GT, order correct"
        elif both_in_gt and not order_correct:
            status = "Both in GT, order reversed"
        elif from_in_gt or to_in_gt:
            status = "One node in GT"
        else:
            status = "Neither node in GT"

        results[file_stem] = {
            "suggested_from": from_idx,
            "suggested_to": to_idx,
            "suggested_type": stype,
            "from_in_gt": from_in_gt,
            "to_in_gt": to_in_gt,
            "both_in_gt": both_in_gt,
            "order_correct": order_correct,
            "from_rank": gt_ranks.get(from_idx, "NaN"),
            "to_rank": gt_ranks.get(to_idx, "NaN"),
            "status": status,
        }

    return results
def get_ground_truth_order(file_stem, from_idx, to_idx):
    """Classify a node pair against the hand-curated GT AP ranking.

    Convention: (from, to) = (posterior, anterior). This matches
    compute_polarization's body_axis_keypoints convention where the vector
    points FROM posterior TO anterior.

    Returns
    -------
    int or None
        1 if from_idx is posterior to to_idx (correctly ordered),
        -1 if from_idx is anterior to to_idx (reversed),
        0 if both nodes carry the same rank,
        None if the file or either index is not in the hand-curated
        ground truth.

    """
    try:
        ranks = GROUND_TRUTH[file_stem]
    except KeyError:
        return None

    if not (from_idx in ranks and to_idx in ranks):
        return None

    rank_from, rank_to = ranks[from_idx], ranks[to_idx]

    # Higher rank = more anterior, lower rank = more posterior
    if rank_from > rank_to:
        return -1  # from is anterior, to is posterior (incorrect order)
    if rank_from < rank_to:
        return 1  # from is posterior, to is anterior (correct order)
    return 0  # same rank (not expected with unique ranks)
def analyze_results(h5_path):  # noqa: C901
    """Report filter cascade, suggested pairs, and Figure 2 data.

    Read the H5 file produced by the parallel validate_ap runs. For each
    file's best individual (highest mean R×M), log the 3-step filter
    cascade progression (with GT coverage folded into Step 1) and the
    suggested pair analysis. Return cascade stats, GT coverage, and
    suggested pair data for Figure 2 rendering.

    Parameters
    ----------
    h5_path : path-like
        Location of the H5 file; handed unchanged to ``load_h5_data``.

    Returns
    -------
    dict
        The computed ``best_individual``, ``gt_coverage``,
        ``suggested_analysis`` and ``cascade_stats`` mappings, plus the
        loaded per-record and per-individual entries that are passed
        through for plotting.

    """
    print("\n" + "─" * 70)
    print("REPORTING: Filter Cascade & Suggested Pairs")
    print("─" * 70)

    data = load_h5_data(h5_path)

    # Find best individual per file (already reported in Pass 1,
    # recomputed here from H5 using mean R×M across all records)
    best_individual = find_best_individual_per_file(data)

    # Find Step 1 surviving nodes and compute GT coverage
    file_surviving_nodes = find_step1_surviving_nodes(data, best_individual)
    gt_coverage = compute_gt_coverage(file_surviving_nodes)

    # Extract cascade progression stats from the first record of each
    # file's best individual
    cascade_stats = {}
    for i, file_stem in enumerate(data["file"]):
        if data["individual"][i] != best_individual.get(file_stem):
            continue
        if file_stem in cascade_stats:
            continue

        n_valid = int(data["n_valid_nodes"][i])
        n_s1 = int(data["n_step1_candidates"][i])
        n_s2 = int(data["n_step2_pairs"][i])
        n_s3d = int(data["n_step3_distal"][i])
        n_s3p = int(data["n_step3_proximal"][i])

        # Unordered pairs that the Step 1 survivors can form: C(n_s1, 2)
        n_candidate_pairs = n_s1 * (n_s1 - 1) // 2 if n_s1 >= 2 else 0

        cascade_stats[file_stem] = {
            "n_valid_nodes": n_valid,
            "n_step1_candidates": n_s1,
            "n_candidate_pairs": n_candidate_pairs,
            "n_step2_pairs": n_s2,
            "n_step3_distal": n_s3d,
            "n_step3_proximal": n_s3p,
        }

    # Log cascade progression with GT coverage folded into Step 1
    print("\n3-STEP FILTER CASCADE (best individual per file):")
    for file_stem in sorted(cascade_stats):
        cs = cascade_stats[file_stem]
        cov = gt_coverage.get(file_stem, {})
        label = FILE_LABELS.get(file_stem, file_stem[:15])
        n_valid = cs["n_valid_nodes"]
        n_s1 = cs["n_step1_candidates"]
        n_cp = cs["n_candidate_pairs"]
        n_s2 = cs["n_step2_pairs"]
        n_s3d = cs["n_step3_distal"]

        gt_surv = cov.get("surviving_in_gt", 0)
        gt_tot = cov.get("gt_total", 0)
        gt_str = f" (GT: {gt_surv}/{gt_tot} nodes)" if gt_tot > 0 else ""

        print(f"\n {label}:")
        print(f" Step 1 (lateral): {n_s1}/{n_valid} nodes{gt_str}")
        print(f" Step 2 (opposite): {n_s2}/{n_cp} candidate pairs")
        print(f" Step 3 (distal): {n_s3d}/{n_s2} pairs")

    # Analyze suggested pairs
    suggested_analysis = analyze_suggested_pairs(data, best_individual)

    print("\n" + "─" * 70)
    print("SUGGESTED AP PAIR: Auto-Selected Node Pair (3-Step Filter Cascade)")
    print("─" * 70)
    for file_stem in sorted(suggested_analysis):
        r = suggested_analysis[file_stem]
        label = FILE_LABELS.get(file_stem, file_stem[:15])
        print(f"\n{label} ({file_stem}):")
        gt_nodes = sorted(GROUND_TRUTH[file_stem].keys())
        print(f" Hand-curated GT node indices: {gt_nodes}")
        sf, st = r["suggested_from"], r["suggested_to"]
        print(
            f" Suggested pair: [{sf} → {st}] "
            f"(posterior → anterior, type: {r['suggested_type']})"
        )

        if sf >= 0:
            # "NO SUGGESTION" records (e.g. only the anterior index is
            # negative) may lack the rank keys; fall back to the same
            # "NaN" placeholder used for nodes outside the GT subset
            # instead of raising KeyError.
            from_rank = r.get("from_rank", "NaN")
            to_rank = r.get("to_rank", "NaN")
            print(
                f" Posterior node {sf}: "
                f"in GT = {r['from_in_gt']}, "
                f"GT rank = {from_rank}"
            )
            print(
                f" Anterior node {st}: "
                f"in GT = {r['to_in_gt']}, "
                f"GT rank = {to_rank}"
            )

            if r["both_in_gt"]:
                if r["order_correct"]:
                    order_str = (
                        "CORRECT (from_node is posterior, "
                        "to_node is anterior per GT)"
                    )
                else:
                    order_str = (
                        "REVERSED (from_node is anterior, "
                        "to_node is posterior per GT)"
                    )
                print(f" Ordering: {order_str}")

        print(f" Status: {r['status']}")
        print()

    return {
        "best_individual": best_individual,
        "gt_coverage": gt_coverage,
        "suggested_analysis": suggested_analysis,
        "cascade_stats": cascade_stats,
        "skeletons": data.get("skeletons", {}),
        "pc1_vectors": data.get("pc1_vectors", {}),
        "vel_projs_pc1": data.get("vel_projs_pc1", {}),
        "keypoints": data.get("keypoints", {}),
        # For individual plot generation
        "file": data["file"],
        "individual": data["individual"],
        "rxm": data["rxm"],
        "vote_margin": data["vote_margin"],
        "resultant_length": data["resultant_length"],
        "anterior_sign": data["anterior_sign"],
        "circ_mean_dir": data["circ_mean_dir"],
        "num_selected_frames": data["num_selected_frames"],
        "n_frames": data["n_frames"],
        # Per-individual diagnostic data: GT concordance and R×M values
        "individual_accuracy": data.get("individual_accuracy", {}),
        "individual_rxm": data.get("individual_rxm", {}),
    }
Bees) + ] + + ax1 = fig.add_subplot(gs[0, :], facecolor=bg_color) + ax1.set_xticks([]) + ax1.set_yticks([]) + for spine in ax1.spines.values(): + spine.set_visible(False) + + # Get skeletons and PC1 from stored H5 data + best_individual = analysis_data["best_individual"] + stored_skeletons = analysis_data.get("skeletons", {}) + stored_pc1 = analysis_data.get("pc1_vectors", {}) + skeleton_data = [] + + for f in files: + best_ind = best_individual.get(f) + if best_ind is None: + continue + + # Try to get stored skeleton from H5 + if f in stored_skeletons and best_ind in stored_skeletons[f]: + avg_skel = stored_skeletons[f][best_ind] # (n_keypoints, 2) + x = avg_skel[:, 0] + y = avg_skel[:, 1] + pc1 = stored_pc1.get(f, {}).get(best_ind) # (2,) or None + else: + # Fallback: load from .slp file + from movement.io import load_poses + + slp_path = SLP_DIR / f"{f}.slp" + if not slp_path.exists(): + continue + + ds = load_poses.from_sleap_file(slp_path) + if "individuals" in ds.position.dims: + pos = ds.position.sel(individuals=best_ind) + else: + pos = ds.position + + x = np.nanmean(pos.sel(space="x").values, axis=0) + y = np.nanmean(pos.sel(space="y").values, axis=0) + pc1 = None # Will compute below if needed + + # Get suggested pair for this file + default_sug = { + "suggested_from": -1, + "suggested_to": -1, + "status": "NO SUGGESTION", + } + suggested = suggested_analysis.get(f, default_sug) + + skeleton_data.append( + { + "file": f, + "label": file_labels.get(f, f[:10]), + "best_ind": best_ind, # Animal identity + "x": x.copy(), + "y": y.copy(), + "pc1": pc1, # Stored PC1 vector (or None) + "gt": GROUND_TRUTH.get(f, {}), + "suggested_from": suggested.get("suggested_from", -1), + "suggested_to": suggested.get("suggested_to", -1), + "from_in_gt": suggested.get("from_in_gt", False), + "to_in_gt": suggested.get("to_in_gt", False), + "status": suggested.get("status", "NO SUGGESTION"), + } + ) + + # Process and plot skeletons side by side + n_skeletons = len(skeleton_data) 
+ midline_plotted = False # For legend + + if n_skeletons > 0: + x_offset = 0 + spacing = 1.8 # Space between skeletons + + for idx, skel in enumerate(skeleton_data): + x, y = skel["x"], skel["y"] + gt = skel["gt"] + stored_pc1 = skel.get("pc1") + + # Get valid (non-NaN) node indices + valid_mask = ~np.isnan(x) & ~np.isnan(y) + valid_indices = np.where(valid_mask)[0] + + if len(valid_indices) < 2: + continue + + # avg_skeleton from collective.py is already centered, use directly + x_centered = x.copy() + y_centered = y.copy() + + # Use stored PC1 from collective.py if available, else compute + if stored_pc1 is not None: + pc1 = np.array(stored_pc1) + else: + # Fallback: compute PC1 via PCA on valid nodes + xv = x_centered[valid_mask] + yv = y_centered[valid_mask] + valid_coords = np.column_stack([xv, yv]) + cov_matrix = np.cov(valid_coords.T) + eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix) + pc1 = eigenvectors[:, np.argmax(eigenvalues)] + + # Ensure PC1 points toward the most anterior GT node + # (highest rank = most anterior in GT convention) + anterior_node = None + max_rank = -1 + for node_idx, rank in gt.items(): + if rank > max_rank: + max_rank = rank + anterior_node = node_idx + + ant_ok = anterior_node is not None and not np.isnan( + x_centered[anterior_node] + ) + if ant_ok: + ax, ay = x_centered[anterior_node], y_centered[anterior_node] + ant_vec = np.array([ax, ay]) + if np.dot(pc1, ant_vec) < 0: + pc1 = -pc1 # Flip to point toward anterior + + # Normalize scale (max of x/y range for square proportions) + x_range = np.nanmax(x_centered) - np.nanmin(x_centered) + y_range = np.nanmax(y_centered) - np.nanmin(y_centered) + max_range = max(x_range, y_range) + if max_range > 0: + scale = 1.5 / max_range # Scale up for larger display + x_centered *= scale + y_centered *= scale + + # Offset horizontally for side-by-side placement + x_plot = x_centered + x_offset + y_plot = y_centered + + file_color = colors[idx] + + # Compute axis extent based on 
projections onto PC1 + xv, yv = x_centered[valid_mask], y_centered[valid_mask] + proj_pc1 = xv * pc1[0] + yv * pc1[1] + min_proj, max_proj = np.min(proj_pc1), np.max(proj_pc1) + axis_extent = (max_proj - min_proj) * 0.7 + + # Draw PC1 axis at actual angle (matching color) + ae = axis_extent + ax1.plot( + [x_offset - ae * pc1[0], x_offset + ae * pc1[0]], + [0 - ae * pc1[1], 0 + ae * pc1[1]], + "-", + color=(*file_color[:3], 0.5), + linewidth=1.5, + zorder=1, + ) + + # Draw AP midline (dashed black, perpendicular to PC1) + pc1_perp = np.array([-pc1[1], pc1[0]]) # Perpendicular to PC1 + midline_extent = 0.2 + mid_y = (min_proj + max_proj) / 2 # Midpoint along PC1 + mid_x_pos = x_offset + mid_y * pc1[0] + mid_y_pos = mid_y * pc1[1] + + me = midline_extent + mx0 = mid_x_pos - me * pc1_perp[0] + mx1 = mid_x_pos + me * pc1_perp[0] + my0 = mid_y_pos - me * pc1_perp[1] + my1 = mid_y_pos + me * pc1_perp[1] + if not midline_plotted: + ax1.plot( + [mx0, mx1], + [my0, my1], + "--", + color=midline_color, + linewidth=1.5, + alpha=0.7, + zorder=2, + label="AP midline", + ) + midline_plotted = True + else: + ax1.plot( + [mx0, mx1], + [my0, my1], + "--", + color=midline_color, + linewidth=1.5, + alpha=0.7, + zorder=2, + ) + + # Draw AP bidirectional arrow next to midline + ap_off = midline_extent * 1.1 # Offset from midline center + ap_half = axis_extent * 0.125 # Half-length of arrow shaft + + # Arrow center (offset perpendicular to PC1 from midline center) + acx = mid_x_pos + ap_off * pc1_perp[0] + acy = mid_y_pos + ap_off * pc1_perp[1] + + # Arrow endpoints along PC1 direction + arrow_ant_x = acx + ap_half * pc1[0] # Anterior end (+PC1) + arrow_ant_y = acy + ap_half * pc1[1] + arrow_post_x = acx - ap_half * pc1[0] # Posterior end (-PC1) + arrow_post_y = acy - ap_half * pc1[1] + + # Draw arrow shaft + ax1.plot( + [arrow_post_x, arrow_ant_x], + [arrow_post_y, arrow_ant_y], + "-", + color=ap_arrow_color, + linewidth=1.5, + zorder=2, + ) + + # Draw triangular arrowheads at both ends 
+ head_len = 0.06 # Arrowhead length + head_wid = 0.03 # Arrowhead half-width + + # Anterior arrowhead (pointing toward +PC1) + ant_head_base_x = arrow_ant_x - head_len * pc1[0] + ant_head_base_y = arrow_ant_y - head_len * pc1[1] + hw = head_wid + ahbx, ahby = ant_head_base_x, ant_head_base_y + ant_head_tri = np.array( + [ + [arrow_ant_x, arrow_ant_y], + [ahbx + hw * pc1_perp[0], ahby + hw * pc1_perp[1]], + [ahbx - hw * pc1_perp[0], ahby - hw * pc1_perp[1]], + ] + ) + ax1.fill( + ant_head_tri[:, 0], + ant_head_tri[:, 1], + color=ap_arrow_color, + zorder=3, + ) + + # Posterior arrowhead (pointing toward -PC1) + phbx = arrow_post_x + head_len * pc1[0] + phby = arrow_post_y + head_len * pc1[1] + post_head_tri = np.array( + [ + [arrow_post_x, arrow_post_y], + [phbx + hw * pc1_perp[0], phby + hw * pc1_perp[1]], + [phbx - hw * pc1_perp[0], phby - hw * pc1_perp[1]], + ] + ) + ax1.fill( + post_head_tri[:, 0], + post_head_tri[:, 1], + color=ap_arrow_color, + zorder=3, + ) + + # Add "A" and "P" labels beyond arrow ends (along PC1 direction) + ap_lbl = 0.08 + ax1.text( + arrow_ant_x + ap_lbl * pc1[0], + arrow_ant_y + ap_lbl * pc1[1], + "A", + fontsize=10, + fontweight="bold", + color=ap_arrow_color, + ha="center", + va="center", + zorder=10, + ) + ax1.text( + arrow_post_x - ap_lbl * pc1[0], + arrow_post_y - ap_lbl * pc1[1], + "P", + fontsize=10, + fontweight="bold", + color=ap_arrow_color, + ha="center", + va="center", + zorder=10, + ) + + # Add +PC1 and -PC1 labels at ends of axis + lo = 0.05 + ax1.text( + x_offset + ae * pc1[0] + lo * pc1_perp[0], + ae * pc1[1] + lo * pc1_perp[1], + "+PC1", + fontsize=10, + fontweight="bold", + color=file_color, + ha="center", + va="center", + alpha=0.8, + zorder=10, + ) + ax1.text( + x_offset - ae * pc1[0] + lo * pc1_perp[0], + -ae * pc1[1] + lo * pc1_perp[1], + "-PC1", + fontsize=10, + fontweight="bold", + color=file_color, + ha="center", + va="center", + alpha=0.8, + zorder=10, + ) + + # Get suggested pair info first (needed for node 
coloring) + from_idx = skel.get("suggested_from", -1) + to_idx = skel.get("suggested_to", -1) + status = skel.get("status", "NO SUGGESTION") + from_in_gt = skel.get("from_in_gt", False) + to_in_gt = skel.get("to_in_gt", False) + + # Status colors (same as status legend) + status_colors_map = { + "Both in GT, order correct": "#2ecc71", + "Both in GT, order reversed": "#f39c12", + "One node in GT": "#e74c3c", + "Neither node in GT": "#95a5a6", + "NO SUGGESTION": "#bdc3c7", + } + line_color = status_colors_map.get(status, "#bdc3c7") + + # Plot nodes + n_nodes = len(x_plot) + + for node_idx in range(n_nodes): + if np.isnan(x_plot[node_idx]) or np.isnan(y_plot[node_idx]): + continue + + # Check if node is NaN (not in GT) + is_nan = node_idx not in gt + # Check if node is part of suggested pair + is_suggested = node_idx in (from_idx, to_idx) + + # Set alpha and size based on node type + size = 30 if is_nan else 50 + + # Dot color: GT=theme, suggested non-GT=status, else file_color + if not is_nan: + dot_color = gt_node_color # GT nodes use theme color + alpha = 1.0 + elif is_suggested: + # Suggested non-GT nodes use status color + dot_color = line_color + alpha = 1.0 + else: + # Purple non-GT, non-suggested nodes + dot_color = file_color + alpha = 0.30 + + ax1.scatter( + x_plot[node_idx], + y_plot[node_idx], + c=[dot_color], + s=size, + alpha=alpha, + edgecolors="white" if not is_nan else "none", + linewidths=0.5, + zorder=5, + ) + + # Add node index label for ranked nodes only + if not is_nan: + ax1.annotate( + str(node_idx), + (x_plot[node_idx], y_plot[node_idx]), + xytext=(4, 4), + textcoords="offset points", + fontsize=13, + color=gt_node_color, + fontweight="bold", + zorder=10, + ) + + # Draw arrow connecting suggested pair (color based on status) + # Arrow points from from_node (posterior) to to_node (anterior) + if from_idx >= 0 and to_idx >= 0: + from_ok = not np.isnan(x_plot[from_idx]) and not np.isnan( + y_plot[from_idx] + ) + to_ok = not 
np.isnan(x_plot[to_idx]) and not np.isnan( + y_plot[to_idx] + ) + if from_ok and to_ok: + # Draw arrow from posterior to anterior + ax1.annotate( + "", + xy=(x_plot[to_idx], y_plot[to_idx]), + xytext=(x_plot[from_idx], y_plot[from_idx]), + arrowprops=dict( + arrowstyle="-|>", + color=line_color, + lw=1.5, + mutation_scale=10, + alpha=0.85, + ), + zorder=4, + ) + + # Redraw suggested nodes with higher zorder to be on top + # GT nodes use theme color, non-GT nodes use line_color + if not np.isnan(x_plot[from_idx]): + from_color = ( + gt_node_color if from_idx in gt else line_color + ) + ax1.scatter( + x_plot[from_idx], + y_plot[from_idx], + c=[from_color], + s=50, + alpha=1.0, + edgecolors="white", + linewidths=0.5, + zorder=6, + ) + ax1.annotate( + str(from_idx), + (x_plot[from_idx], y_plot[from_idx]), + xytext=(4, 4), + textcoords="offset points", + fontsize=13, + color=from_color, + fontweight="bold", + zorder=10, + ) + + if not np.isnan(x_plot[to_idx]): + to_color = ( + gt_node_color if to_idx in gt else line_color + ) + ax1.scatter( + x_plot[to_idx], + y_plot[to_idx], + c=[to_color], + s=50, + alpha=1.0, + edgecolors="white", + linewidths=0.5, + zorder=6, + ) + ax1.annotate( + str(to_idx), + (x_plot[to_idx], y_plot[to_idx]), + xytext=(4, 4), + textcoords="offset points", + fontsize=13, + color=to_color, + fontweight="bold", + zorder=10, + ) + + # Add label below skeleton + ax1.text( + x_offset, + -1.25, + skel["label"], + ha="center", + va="top", + fontsize=12, + fontweight="bold", + color=file_color, + clip_on=False, + ) + ax1.text( + x_offset, + -1.45, + skel["best_ind"], + ha="center", + va="top", + fontsize=12, + fontweight="bold", + color=file_color, + clip_on=False, + ) + + # Add suggested pair [ # , # ] with color-coded numbers + if from_idx >= 0 and to_idx >= 0: + # Determine colors for each number based on GT membership + from_num_color = gt_node_color if from_in_gt else line_color + to_num_color = gt_node_color if to_in_gt else line_color + + # Build 
text parts with spacing + pair_y = -1.65 + char_width = 0.08 # Approximate character width in data units + + # Calculate total width and starting position for centering + from_str = str(from_idx) + to_str = str(to_idx) + # Format: "[ " (2) + from + " , " (3) + to + " ]" (2) + total_chars = 2 + len(from_str) + 3 + len(to_str) + 2 + start_x = x_offset - (total_chars * char_width) / 2 + + # Draw each part with spaces + pos = start_x + ax1.text( + pos, + pair_y, + "[ ", + ha="left", + va="top", + fontsize=14, + fontweight="bold", + color="black", + clip_on=False, + ) + pos += char_width * 2 + ax1.text( + pos, + pair_y, + from_str, + ha="left", + va="top", + fontsize=14, + fontweight="bold", + color=from_num_color, + clip_on=False, + ) + pos += char_width * len(from_str) + ax1.text( + pos, + pair_y, + " , ", + ha="left", + va="top", + fontsize=14, + fontweight="bold", + color="black", + clip_on=False, + ) + pos += char_width * 3 + ax1.text( + pos, + pair_y, + to_str, + ha="left", + va="top", + fontsize=14, + fontweight="bold", + color=to_num_color, + clip_on=False, + ) + pos += char_width * len(to_str) + ax1.text( + pos, + pair_y, + " ]", + ha="left", + va="top", + fontsize=14, + fontweight="bold", + color="black", + clip_on=False, + ) + + x_offset += spacing + + # Set axis limits (wider since spanning full row) + x_total = x_offset - spacing + ax1.set_xlim(-0.8, x_total + 0.8) + ax1.set_ylim(-1.85, 1.2) # Extended to accommodate three-line labels + + # Force equal aspect ratio for square skeleton display + ax1.set_aspect("equal", adjustable="box") + + # Add legend for AP midline and GT Node (positioned lower right) + midline_handle = Line2D( + [0], + [0], + linestyle="--", + color=midline_color, + linewidth=1.5, + alpha=0.7, + ) + gt_node_handle = Line2D( + [0], [0], marker="o", color="black", linestyle="None", markersize=6 + ) + midline_legend = ax1.legend( + [midline_handle, gt_node_handle], + ["A/P midline", "GT Node"], + loc="lower right", + fontsize=9, + 
framealpha=label_bg_alpha, + bbox_to_anchor=(1.12, -0.05), + prop={"weight": "bold"}, + facecolor=bg_color, + labelcolor=text_color, + ) + ax1.add_artist(midline_legend) + + # Determine dominant status color for arrow text + status_colors_map = { + "Both in GT, order correct": "#2ecc71", + "Both in GT, order reversed": "#f39c12", + "One node in GT": "#e74c3c", + "Neither node in GT": "#95a5a6", + "NO SUGGESTION": "#bdc3c7", + } + status_counts = {} + for f in files: + status = suggested_analysis.get(f, {}).get("status", "NO SUGGESTION") + status_counts[status] = status_counts.get(status, 0) + 1 + dominant_status = ( + max(status_counts, key=status_counts.get) + if status_counts + else "NO SUGGESTION" + ) + dominant_color = status_colors_map.get(dominant_status, "#bdc3c7") + + # Add arrow description at top left (bold, dominant status color) + fig.text( + 0.01, + 0.99, + "vector points in the inferred P\u2192A direction for suggested pair", + fontsize=8, + fontweight="normal", + color=dominant_color, + ha="left", + va="top", + ) + fig.text( + 0.01, + 0.97, + "(max AP separation, distal preferred)", + fontsize=8, + fontweight="normal", + color=dominant_color, + ha="left", + va="top", + ) + + # Add formula text at top right corner of figure + line_spacing = 0.018 + y_start = 0.99 + x_right = 1.0 + fig.text( + x_right, + y_start, + "R = √(C² + S²)", + fontsize=8, + fontweight="normal", + color=text_color, + ha="right", + va="top", + ) + fig.text( + x_right, + y_start - line_spacing, + "M = |n₊ − n₋| / (n₊ + n₋)", + fontsize=8, + fontweight="normal", + color=text_color, + ha="right", + va="top", + ) + fig.text( + x_right, + y_start - 2 * line_spacing, + "C, S = mean cos θ, mean sin θ of centroid velocities", + fontsize=8, + fontweight="normal", + color=text_color, + ha="right", + va="top", + ) + fig.text( + x_right, + y_start - 3 * line_spacing, + "n₊, n₋ = # of pos./neg. velocity proj. 
class ArrowHandler(HandlerBase):
    """Legend handler drawing an arrow with a dot at each end.

    Parameters
    ----------
    color : color
        Colour of the shaft and arrowhead; also the fallback dot colour.
    left_dot_color, right_dot_color : color, optional
        Colours of the start and end dots. A falsy value falls back to
        ``color``.

    """

    def __init__(self, color, left_dot_color=None, right_dot_color=None):
        self.color = color
        self.left_dot_color = left_dot_color or color
        self.right_dot_color = right_dot_color or color
        super().__init__()

    def create_artists(
        self,
        legend,
        orig_handle,
        xdescent,
        ydescent,
        width,
        height,
        fontsize,
        trans,
    ):
        """Return the artists ``[shaft, arrowhead, start dot, end dot]``."""
        left = xdescent + 1
        right = xdescent + width - 1
        mid = ydescent + height / 2

        # Shaft is kept short so the arrowhead stays visible
        shaft = Line2D(
            [left + 5, right - 10],
            [mid, mid],
            color=self.color,
            linewidth=3,
            transform=trans,
        )

        # Triangular head starting where the shaft ends
        head = mpatches.FancyArrow(
            right - 10,
            mid,
            5,
            0,
            width=0,
            head_width=5,
            head_length=5,
            fc=self.color,
            ec=self.color,
            transform=trans,
        )

        def _dot(x_at, dot_color):
            # Filled round marker on the arrow's centre line
            return Line2D(
                [x_at],
                [mid],
                marker="o",
                markersize=7,
                markerfacecolor=dot_color,
                markeredgecolor=dot_color,
                linestyle="None",
                transform=trans,
            )

        start_dot = _dot(left, self.left_dot_color)
        end_dot = _dot(right + 4, self.right_dot_color)

        return [shaft, head, start_dot, end_dot]
legend + fig.text( + 0.01, + 0.565, + "Suggested AP Pair:", + fontsize=13, + fontweight="bold", + color="black", + ha="left", + va="center", + ) + + # Panel 3: GT Node Ranks (descending bars) + # Reassign ranks: highest rank = most anterior (tallest bar) + # Bar width and group spacing + group_width = 0.8 + + ax3.set_title( + "Hand-Curated GT Node Rankings", + fontsize=14, + fontweight="bold", + color=text_color, + ) + + for i, f in enumerate(files): + gt = GROUND_TRUTH.get(f, {}) + file_color = colors[i] + + if not gt: + continue + + # Higher rank = more anterior, so rank directly maps to bar height + # Sort by rank descending (most anterior/highest rank first) + sorted_nodes = sorted(gt.items(), key=lambda x: x[1], reverse=True) + n_gt_nodes = len(sorted_nodes) + + for j, (node_idx, rank) in enumerate(sorted_nodes): + # Rank value is the bar height (higher = more anterior) + height = rank + + # Position within group (most anterior = leftmost = tallest) + x_pos = ( + i - group_width / 2 + (j + 0.5) * (group_width / n_gt_nodes) + ) + + # Draw bar + ax3.bar( + x_pos, + height, + width=group_width / n_gt_nodes * 0.85, + color=file_color, + edgecolor=axis_color, + linewidth=0.5, + alpha=0.75, + ) + + # Add node index as bold text on top of bar + ax3.text( + x_pos, + height + 0.1, + str(node_idx), + ha="center", + va="bottom", + fontsize=13, + fontweight="bold", + color=text_color, + ) + + ax3.set_xticks(range(n_files)) + ax3.set_xticklabels( + labels, rotation=30, ha="right", fontsize=14, fontweight="bold" + ) + # Color x-axis labels to match skeleton colors + for tick_label, color in zip( + ax3.get_xticklabels(), + colors, + strict=True, + ): + tick_label.set_color(color) + ax3.set_ylabel( + "Rank (higher = more anterior)", + fontsize=14, + fontweight="bold", + color=text_color, + ) + ax3.set_ylim(0, max(len(gt) for gt in GROUND_TRUTH.values()) + 1) + + # Make tile 3 tick labels bold + ax3.tick_params(axis="y", labelsize=10) + for label in ax3.get_yticklabels(): + 
label.set_fontweight("bold") + + # Panel 4: Filter cascade progression per dataset + ax4 = fig.add_subplot(gs[1, 1], facecolor=bg_color) + + import matplotlib as mpl + + mpl.rcParams["hatch.linewidth"] = 1.5 + + # Compute cascade data per file as raw pair counts. + # Step 1 shows C(n_candidates, 2) — pair potential of surviving nodes. + # Steps 2-3 show surviving pair counts directly. + # Since C(K,2) >= P >= D, bars are guaranteed to decrease. + all_counts = [[], [], []] + for f in files: + cs = cascade_stats.get(f, {}) + n_s1 = cs.get("n_step1_candidates", 0) + n_s2 = cs.get("n_step2_pairs", 0) + n_s3d = cs.get("n_step3_distal", 0) + + s1_pairs = n_s1 * (n_s1 - 1) // 2 if n_s1 >= 2 else 0 + all_counts[0].append(s1_pairs) + all_counts[1].append(n_s2) + all_counts[2].append(n_s3d) + + x_pos = np.arange(n_files) + bar_width = 0.25 + + # Hatch patterns distinguish steps; bar color matches dataset + step_hatches = ["--", "||", "+++"] + step_names = ["Step 1: Lateral", "Step 2: Opposite", "Step 3: Distal"] + + for j, (counts, hatch, _sn) in enumerate( + zip( + all_counts, + step_hatches, + step_names, + strict=True, + ) + ): + offset = (j - 1) * bar_width + for i, count in enumerate(counts): + c = colors[i] + fc = (c[0], c[1], c[2], 0.30) + ax4.bar( + x_pos[i] + offset, + count, + bar_width, + color=fc, + edgecolor="black", + linewidth=0.5, + hatch=hatch, + ) + bx = x_pos[i] + offset + by = count + 0.5 + ax4.text( + bx, + by, + str(count), + ha="center", + va="bottom", + fontsize=8, + fontweight="bold", + color=text_color, + ) + + # Small legend showing only hatch patterns (no color) + legend_patches = [] + for hatch, sn in zip(step_hatches, step_names, strict=True): + legend_patches.append( + mpatches.Patch( + facecolor="white", + edgecolor="black", + linewidth=0.5, + hatch=hatch, + label=sn, + ) + ) + ax4.legend( + handles=legend_patches, + loc="upper right", + fontsize=8, + framealpha=label_bg_alpha, + facecolor=bg_color, + labelcolor=text_color, + handlelength=1.5, 
def _generate_node_colors(n):
    """Return ``n`` opaque RGBA tuples that are visually distinct.

    The first 16 colours come from a fixed palette. For larger ``n`` the
    palette is cycled and every further pass alternates between a lighter
    and a darker shade of the base colour.
    """
    base_colors = [
        (0.12, 0.47, 0.71),  # Blue
        (1.00, 0.50, 0.05),  # Orange
        (0.17, 0.63, 0.17),  # Green
        (0.84, 0.15, 0.16),  # Red
        (0.58, 0.40, 0.74),  # Purple
        (0.55, 0.34, 0.29),  # Brown
        (0.89, 0.47, 0.76),  # Pink
        (0.50, 0.50, 0.50),  # Gray
        (0.74, 0.74, 0.13),  # Olive/Yellow
        (0.09, 0.75, 0.81),  # Cyan
        (0.00, 0.80, 0.60),  # Teal
        (0.90, 0.30, 0.50),  # Magenta
        (0.40, 0.20, 0.60),  # Dark purple
        (0.95, 0.70, 0.20),  # Gold
        (0.30, 0.70, 0.90),  # Sky blue
        (0.70, 0.90, 0.30),  # Lime
    ]
    n_base = len(base_colors)
    colors = []
    for i in range(n):
        r, g, b = base_colors[i % n_base]
        # 0 = palette colour, odd = lighter shade, even (>0) = darker shade
        shade_level = i // n_base
        if shade_level % 2 == 1:
            factor = 0.4 * ((shade_level + 1) // 2)
            r, g, b = (min(1.0, c + (1 - c) * factor) for c in (r, g, b))
        elif shade_level > 0:
            factor = 0.4 * (shade_level // 2)
            r, g, b = (max(0.0, c * (1 - factor)) for c in (r, g, b))
        colors.append((r, g, b, 1.0))
    return colors


def _draw_axis_segments(ax, axis_vec, gap_lo, gap_hi, outer_extent):
    """Draw an axis through the origin, leaving ``[gap_lo, gap_hi]`` blank."""
    for start, stop in ((-outer_extent, gap_lo), (gap_hi, outer_extent)):
        ax.plot(
            [start * axis_vec[0], stop * axis_vec[0]],
            [start * axis_vec[1], stop * axis_vec[1]],
            "-",
            color=(1, 1, 1, 0.65),
            linewidth=1.5,
        )


def _label_axis_ends(ax, axis_vec, extent, side_offset, axis_name):
    """Label both ends of an axis, shifted sideways by ``side_offset``."""
    for sign, text in ((1.0, f"+{axis_name}"), (-1.0, f"—{axis_name}")):
        ax.text(
            sign * extent * axis_vec[0] + side_offset[0],
            sign * extent * axis_vec[1] + side_offset[1],
            text,
            color="white",
            fontsize=10,
            ha="center",
            va="center",
        )


def _scatter_keypoints(ax, avg_skeleton, valid_idx, colors):
    """Draw every valid keypoint as a filled dot in its node colour."""
    for i in valid_idx:
        ax.scatter(
            avg_skeleton[i, 0],
            avg_skeleton[i, 1],
            s=120,
            c=[colors[i]],
            edgecolors="black",
            linewidths=0.5,
            zorder=5,
        )


def _style_skeleton_axes(ax, display_radius, *, show_ticks=True):
    """Apply the shared square, dark-theme styling of the skeleton tiles."""
    ax.set_xlim(-display_radius, display_radius)
    ax.set_ylim(-display_radius, display_radius)
    ax.set_aspect("equal")
    ax.set_xlabel("X", color="white", fontsize=10)
    ax.set_ylabel(
        "Y", color="white", fontsize=10, rotation=0, ha="right", va="center"
    )
    if show_ticks:
        ax.tick_params(colors="gray", labelsize=8)
        ax.grid(True, color="gray", alpha=0.3, linestyle="-", linewidth=0.5)
    else:
        ax.set_xticks([])
        ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_color("gray")


def _draw_projection_tile(
    ax,
    title,
    axis_vec,
    axis_name,
    projections,
    avg_skeleton,
    valid_idx,
    colors,
    *,
    body_axis_extent,
    display_radius,
    label_offset,
    centroid_alpha,
):
    """Draw one keypoint-projection tile (used for both PC1 and PC2)."""
    ax.set_title(title, fontsize=12, fontweight="normal", color="white")

    # Axis line is broken over the range actually covered by keypoints
    valid_proj = projections[valid_idx]
    _draw_axis_segments(
        ax,
        axis_vec,
        np.min(valid_proj),
        np.max(valid_proj),
        1.5 * body_axis_extent,
    )

    perp = np.array([-axis_vec[1], axis_vec[0]])
    _label_axis_ends(
        ax, axis_vec, 1.1 * body_axis_extent, perp * label_offset, axis_name
    )

    # Dashed: keypoint -> its foot on the axis; solid: origin -> that foot
    for i in valid_idx:
        tip_x, tip_y = avg_skeleton[i]
        foot_x = projections[i] * axis_vec[0]
        foot_y = projections[i] * axis_vec[1]
        faded = (*colors[i][:3], 0.5)
        ax.plot(
            [tip_x, foot_x], [tip_y, foot_y], "--", color=faded, linewidth=1
        )
        ax.plot([0, foot_x], [0, foot_y], "-", color=faded, linewidth=2)

    _scatter_keypoints(ax, avg_skeleton, valid_idx, colors)

    # Origin marker (labelled "bbox centroid" in the figure legend)
    ax.scatter(
        0,
        0,
        s=120,
        c="gray",
        edgecolors="black",
        linewidths=0.5,
        alpha=centroid_alpha,
        zorder=6,
    )
    _style_skeleton_axes(ax, display_radius)


def _draw_velocity_histogram(
    ax, vel_projs, pc1, perp, shape_radius, pos_color, neg_color
):
    """Draw a histogram of ``vel_projs`` lying along the PC1 axis.

    Bars are polygons expressed in the (``pc1``, ``perp``) basis so that the
    histogram follows the orientation of the body axis.
    """
    n_vals = len(vel_projs)
    num_bins = max(5, n_vals // 3) if n_vals < 50 else 25
    max_bar_height = 1.0 * shape_radius
    # Floor on bar height so that a bin holding a single sample stays visible
    min_bar_height = 0.15 * max_bar_height

    vp_min = np.min(vel_projs)
    vp_max = np.max(vel_projs)
    data_range = vp_max - vp_min

    # Widen very narrow data ranges so bins keep a minimum drawn width
    min_range = (shape_radius / 12) * num_bins
    if data_range < min_range:
        center = (vp_min + vp_max) / 2
        vp_min = center - min_range / 2
        vp_max = center + min_range / 2
        data_range = min_range

    edge_pad = data_range * 0.02
    bin_edges = np.linspace(vp_min - edge_pad, vp_max + edge_pad, num_bins + 1)
    bin_counts, _ = np.histogram(vel_projs, bins=bin_edges)

    max_count = np.max(bin_counts)
    if not max_count > 0:
        return

    # Count labels are rotated to read along the PC1 direction
    rot_deg = np.degrees(np.arctan2(pc1[1], pc1[0]))

    for count, lo, hi in zip(
        bin_counts, bin_edges[:-1], bin_edges[1:], strict=True
    ):
        if count == 0:
            continue

        bar_height = max((count / max_count) * max_bar_height, min_bar_height)

        # Corners: bottom-left, bottom-right, top-right, top-left
        along = np.array([lo, hi, hi, lo])
        across = np.array([0, 0, bar_height, bar_height])
        bin_center = (lo + hi) / 2

        ax.fill(
            along * pc1[0] + across * perp[0],
            along * pc1[1] + across * perp[1],
            color=pos_color if bin_center > 0 else neg_color,
            alpha=0.25,
            edgecolor="black",
            linewidth=0.5,
            zorder=2,
        )

        # Count label just inside the top edge of the bar
        label_height = bar_height - 0.02 * shape_radius
        ax.text(
            bin_center * pc1[0] + label_height * perp[0],
            bin_center * pc1[1] + label_height * perp[1],
            str(count),
            color="white",
            fontsize=5,
            ha="center",
            va="top",
            rotation=rot_deg,
            rotation_mode="anchor",
            zorder=20,
        )


def _draw_vc_arrow(ax, net_vel_display, shape_radius):
    """Draw the V_c arrow from the origin; no-op for a zero/NaN vector."""
    if not np.linalg.norm(net_vel_display) > 0:
        return

    angle = np.arctan2(net_vel_display[1], net_vel_display[0])
    cos_a, sin_a = np.cos(angle), np.sin(angle)

    ax.plot(
        [0, net_vel_display[0]],
        [0, net_vel_display[1]],
        "-",
        color="white",
        linewidth=1.2,
        zorder=7,
    )

    # Arrowhead triangle defined pointing along +x, then rotated to `angle`
    arrow_len = 0.06 * shape_radius
    arrow_wid = 0.03 * shape_radius
    local_x = np.array([0, -arrow_len, -arrow_len])
    local_y = np.array([0, arrow_wid, -arrow_wid])
    ax.fill(
        local_x * cos_a - local_y * sin_a + net_vel_display[0],
        local_x * sin_a + local_y * cos_a + net_vel_display[1],
        color="white",
        edgecolor="black",
        linewidth=0.5,
        zorder=8,
        clip_on=True,
    )

    label_gap = 0.12 * shape_radius
    ax.text(
        net_vel_display[0] + label_gap * cos_a,
        net_vel_display[1] + label_gap * sin_a,
        r"$\mathbf{V_c}$",
        color="white",
        fontsize=12,
        fontweight="bold",
        ha="left",
        va="center",
        zorder=10,
        clip_on=True,
    )


def _draw_ap_indicator(
    ax, pc1, perp, proj_lo, proj_hi, anterior_sign, shape_radius, color
):
    """Draw the dotted AP midline and the labelled A<->P double arrow."""
    midline_center = ((proj_lo + proj_hi) / 2) * pc1
    midline_half_len = 0.35 * shape_radius

    # Midline: perpendicular to PC1 at the middle of the keypoint range
    ax.plot(
        [
            midline_center[0] - midline_half_len * perp[0],
            midline_center[0] + midline_half_len * perp[0],
        ],
        [
            midline_center[1] - midline_half_len * perp[1],
            midline_center[1] + midline_half_len * perp[1],
        ],
        ":",
        color=color,
        linewidth=1.5,
        zorder=4,
    )

    # Double arrow parallel to PC1, on the side opposite the histogram
    arrow_center = midline_center - (midline_half_len * 1.1) * perp
    half_len = 0.225 * shape_radius
    anterior_dir = pc1 * anterior_sign
    ant_tip = arrow_center + half_len * anterior_dir
    post_tip = arrow_center - half_len * anterior_dir

    ax.plot(
        [post_tip[0], ant_tip[0]],
        [post_tip[1], ant_tip[1]],
        "-",
        color=color,
        linewidth=1.5,
        zorder=4,
    )

    head_len = 0.08 * shape_radius
    head_wid = 0.04 * shape_radius
    ends = ((ant_tip, 1.0, "A"), (post_tip, -1.0, "P"))

    for tip, outward, _text in ends:
        head_base = tip - outward * head_len * anterior_dir
        triangle = np.array(
            [tip, head_base + head_wid * perp, head_base - head_wid * perp]
        )
        ax.fill(
            triangle[:, 0],
            triangle[:, 1],
            color=color,
            edgecolor="black",
            linewidth=0.5,
            zorder=5,
        )

    label_gap = 0.06 * shape_radius
    for tip, outward, text in ends:
        ax.text(
            tip[0] + outward * label_gap * anterior_dir[0],
            tip[1] + outward * label_gap * anterior_dir[1],
            text,
            color=color,
            fontsize=10,
            fontweight="bold",
            ha="center",
            va="center",
            zorder=10,
        )


def _draw_rxm_scatter(
    ax, all_individuals_metrics, individual_accuracy, file_color
):
    """Scatter R×M against accuracy per individual, or show "No data"."""
    points = []
    if all_individuals_metrics and individual_accuracy:
        for ind in sorted(all_individuals_metrics):
            if ind not in individual_accuracy:
                continue
            ind_metrics = all_individuals_metrics[ind]
            points.append(
                (
                    ind_metrics["R"] * ind_metrics["M"],
                    individual_accuracy[ind]["accuracy"],
                    ind,
                )
            )

    if not points:
        ax.text(
            0.5,
            0.5,
            "No data",
            color="gray",
            fontsize=12,
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        return

    marker_shapes = ["o", "s", "^", "D", "v", "h", "*", "X", "P", "p"]
    rxm_values = [point[0] for point in points]

    # Explicit None/len test so array-like colours do not hit truthiness
    if file_color is not None and len(file_color) >= 3:
        base_color = file_color[:3]
    else:
        base_color = (0.5, 0.5, 0.5)

    # The individual with the highest R×M is highlighted with a star
    best_idx = int(np.argmax(rxm_values))

    legend_handles = []
    for i, (rxm, acc, label) in enumerate(points):
        is_best = i == best_idx
        marker = "*" if is_best else marker_shapes[i % len(marker_shapes)]
        ax.scatter(
            [rxm],
            [acc],
            s=200 if is_best else 120,
            c=[base_color],
            marker=marker,
            edgecolors="white",
            linewidths=1.0,
            alpha=0.9,
            zorder=6 if is_best else 5,
        )
        legend_handles.append(
            Line2D(
                [0],
                [0],
                marker=marker,
                color="w",
                markerfacecolor=base_color,
                markeredgecolor="white",
                markeredgewidth=1.0,
                markersize=12 if is_best else 10,
                linestyle="None",
                label=label,
            )
        )

    # Pad the x-range; fall back to a fixed span when values (nearly) coincide
    rxm_range = max(rxm_values) - min(rxm_values) if len(points) > 1 else 0.1
    if rxm_range < 0.01:
        rxm_range = 0.1
    ax.set_xlim(
        min(rxm_values) - rxm_range * 0.15,
        max(rxm_values) + rxm_range * 0.25,
    )
    # Y-axis spans -20..120 but only shows ticks for 0..100
    ax.set_ylim(-20, 120)
    ax.set_yticks([0, 20, 40, 60, 80, 100])

    ax.set_xlabel("R×M", color="white", fontsize=10)
    ax.set_ylabel(
        "Pairwise GT Concordance\n(inferred AP, %)",
        color="white",
        fontsize=10,
    )
    ax.legend(
        handles=legend_handles,
        loc="lower right",
        fontsize=8,
        facecolor="black",
        edgecolor="gray",
        labelcolor="white",
    )
    ax.grid(True, color="gray", alpha=0.3, linestyle="-", linewidth=0.5)


def _draw_dot_swatch(ax, color):
    """Legend swatch: a single filled dot."""
    ax.scatter(
        0.5, 0.5, s=100, c=[color], edgecolors="black", linewidths=0.5
    )


def _draw_bar_swatch(ax, color):
    """Legend swatch: a translucent rectangle (histogram bar)."""
    ax.add_patch(
        plt.Rectangle(
            (0.1, 0.1),
            0.8,
            0.8,
            facecolor=color,
            alpha=0.6,
            edgecolor="none",
        )
    )


def _draw_line_swatch(ax, color):
    """Legend swatch: a short dotted line."""
    ax.plot([0.1, 0.9], [0.5, 0.5], ":", color=color, linewidth=2)


def _add_legend_row(fig, x, y, size, label, draw_swatch, swatch_color):
    """Add one figure-level legend row: a tiny swatch axes plus its text."""
    width, height = size
    ax_swatch = fig.add_axes([x - width / 2, y - height / 2, width, height])
    ax_swatch.set_xlim(0, 1)
    ax_swatch.set_ylim(0, 1)
    draw_swatch(ax_swatch, swatch_color)
    ax_swatch.axis("off")
    fig.text(
        x + 0.015,
        y,
        label,
        color="white",
        fontsize=10,
        va="center",
        ha="left",
    )


def create_individual_plot(  # noqa: C901
    file_stem,
    avg_skeleton,
    pc1,
    keypoint_names,
    metrics,
    output_path,
    vel_projs=None,
    individual_name=None,
    all_individuals_metrics=None,
    file_color=None,
    individual_accuracy=None,
):
    """Create and save a detailed 2×2 SVG plot for a single file/animal.

    The figure contains four tiles showing (1) keypoint projections onto the
    first principal component, (2) keypoint projections onto the axis
    perpendicular to it, (3) the inferred anterior–posterior direction
    together with the velocity vector and an optional histogram of velocity
    projections, and (4) a scatter plot of R×M (resultant length × vote
    margin) against accuracy for every individual. A figure-level legend in
    the centre lists the valid keypoints and the other plot symbols.

    Parameters
    ----------
    file_stem : str
        The stem of the file name (without extension). Used to look up the
        display label in ``FILE_LABELS`` and in console messages.
    avg_skeleton : np.ndarray
        Array of shape ``(n_keypoints, 2)`` with the average x, y coordinates
        of each keypoint. Rows containing NaN are ignored. The origin is
        drawn and labelled as the bbox centroid, so coordinates are
        presumably centroid-relative (not verified here).
    pc1 : np.ndarray
        Length‑2 vector of the first principal component. It is normalised
        internally; the perpendicular axis (PC2) is its 90° rotation.
    keypoint_names : list[str]
        Names of the keypoints corresponding to rows of ``avg_skeleton``.
        Missing names fall back to ``node_<index>``.
    metrics : dict
        Scalar metrics for the individual. Keys read (with defaults):
        ``'anterior_sign'`` (1), ``'vote_margin'`` (0),
        ``'resultant_length'`` (0), ``'circ_mean_dir'`` (NaN, radians),
        ``'num_selected_frames'`` (0) and ``'n_frames'`` (1).
    output_path : Path | str
        File name the SVG figure is written to.
    vel_projs : np.ndarray | None, optional
        One‑dimensional array of projections of velocity vectors onto
        ``pc1``. When provided and non-empty, a histogram of these values is
        drawn in Tile 3.
    individual_name : str | None, optional
        Name of the individual animal, appended to the file label in the
        bottom caption. If falsy, only the file label appears.
    all_individuals_metrics : dict[str, dict[str, float]], optional
        Mapping from individual names to dictionaries with keys ``'R'`` and
        ``'M'``. Used to populate the scatter plot (Tile 4).
    file_color : tuple | None, optional
        RGB(A) colour used for the scatter markers in Tile 4. When ``None``,
        mid-grey is used.
    individual_accuracy : dict[str, dict[str, float]] | None, optional
        Mapping from individual names to dictionaries with key
        ``'accuracy'`` (plotted on a 0–100 axis). Only individuals present
        in both mappings are plotted; if none are, Tile 4 shows "No data".

    Returns
    -------
    None
        The figure is saved to ``output_path`` and closed. Nothing is
        written (and a message is printed) when fewer than two keypoints are
        valid or when ``pc1`` has zero or non-finite length.

    """
    # Accept both str and Path; ``.name`` is needed for the final message
    out_path = Path(output_path)

    n_keypoints = len(avg_skeleton)
    colors = _generate_node_colors(n_keypoints)

    # Keypoints with a NaN coordinate are excluded from every tile
    valid_mask = ~np.any(np.isnan(avg_skeleton), axis=1)
    valid_idx = np.where(valid_mask)[0]

    if len(valid_idx) < 2:
        print(f" Skipping {file_stem}: not enough valid keypoints")
        return

    # A zero/NaN PC1 cannot be normalised and would fill the plot with NaNs
    pc1_norm = np.linalg.norm(pc1)
    if not np.isfinite(pc1_norm) or pc1_norm == 0:
        print(f" Skipping {file_stem}: degenerate PC1 vector")
        return

    pc1 = pc1 / pc1_norm
    pc2 = np.array([-pc1[1], pc1[0]])  # 90 degree rotation of PC1

    proj_pc1 = avg_skeleton @ pc1
    proj_pc2 = avg_skeleton @ pc2

    # Display geometry, all derived from the skeleton's extent
    shape_radius = np.max(np.abs(avg_skeleton[valid_mask])) * 1.3
    display_radius = shape_radius * 1.2
    body_axis_extent = np.max(np.abs(proj_pc1[valid_mask])) * 1.3
    label_offset = 0.12 * shape_radius

    anterior_sign = int(metrics.get("anterior_sign", 1))
    resultant_length = metrics.get("resultant_length", 0)
    circ_mean_dir = metrics.get("circ_mean_dir", np.nan)
    num_selected_frames = int(metrics.get("num_selected_frames", 0))
    n_frames = int(metrics.get("n_frames", 1))

    # V_c display vector: length scales with R (R=1 -> 3*shape_radius) but is
    # clamped to 80% of the axis boundary so the tip and label stay visible
    vel_display_len = min(
        resultant_length * shape_radius * 3.0, display_radius * 0.8
    )
    if not np.isnan(circ_mean_dir):
        net_vel_display = vel_display_len * np.array(
            [np.cos(circ_mean_dir), np.sin(circ_mean_dir)]
        )
    else:
        net_vel_display = np.array([0.0, 0.0])

    tan_color = (0.82, 0.71, 0.55)
    pos_vel_color = (0.4, 0.75, 1.0)  # blue: +PC1 velocity projections
    neg_vel_color = (1.0, 0.45, 0.35)  # red: -PC1 velocity projections

    fig = plt.figure(figsize=(14, 10), facecolor="black")
    fig.suptitle(
        "AP Validation Detail: BBox-Centroid Average Skeleton",
        fontsize=12,
        fontweight="normal",
        color="white",
        y=0.97,
    )

    # Bottom captions: frame count, then dataset label (+ individual name).
    # With >=20 nodes the captions move right (beneath tile 4) so that they
    # do not collide with the long centre legend.
    file_label = FILE_LABELS.get(file_stem, file_stem[:15])
    pct = 100.0 * num_selected_frames / n_frames if n_frames > 0 else 0
    text_x = 0.78 if n_keypoints >= 20 else 0.5
    caption_style = {
        "fontsize": 14,
        "color": "white",
        "ha": "center",
        "va": "bottom",
    }
    fig.text(
        text_x,
        0.045,
        f"{num_selected_frames} segment frames ({pct:.1f}% of all)",
        fontweight="normal",
        **caption_style,
    )
    if individual_name:
        name_line = f"{file_label} ({individual_name})"
    else:
        name_line = file_label
    fig.text(text_x, 0.012, name_line, fontweight="bold", **caption_style)

    # 2x2 grid; the right column is narrower to leave room for the legend
    gs = fig.add_gridspec(
        2,
        2,
        hspace=0.25,
        wspace=0.35,
        left=0.08,
        right=0.92,
        top=0.90,
        bottom=0.10,
        width_ratios=[1, 0.8],
    )

    tile_geometry = {
        "body_axis_extent": body_axis_extent,
        "display_radius": display_radius,
        "label_offset": label_offset,
        "centroid_alpha": 0.7,
    }

    # TILE 1: Longitudinal Spread (PC1)
    _draw_projection_tile(
        fig.add_subplot(gs[0, 0], facecolor="black"),
        "Longitudinal Spread (PC1 Projections)",
        pc1,
        "PC1",
        proj_pc1,
        avg_skeleton,
        valid_idx,
        colors,
        **tile_geometry,
    )

    # TILE 2: Lateral Spread (PC2)
    _draw_projection_tile(
        fig.add_subplot(gs[0, 1], facecolor="black"),
        "Lateral Spread (PC2 Projections)",
        pc2,
        "PC2",
        proj_pc2,
        avg_skeleton,
        valid_idx,
        colors,
        **tile_geometry,
    )

    # TILE 3: Inferred AP Direction (Velocity Voting)
    ax3 = fig.add_subplot(gs[1, 0], facecolor="black")
    ax3.set_title(
        "Inferred AP Direction (Velocity Voting)",
        fontsize=12,
        fontweight="normal",
        color="white",
    )

    # PC1 axis, broken over the span of V_c's projection onto PC1
    vel_proj_scalar = np.dot(net_vel_display, pc1)
    _draw_axis_segments(
        ax3,
        pc1,
        min(0, vel_proj_scalar),
        max(0, vel_proj_scalar),
        1.5 * body_axis_extent,
    )
    ax3.plot(
        [0, vel_proj_scalar * pc1[0]],
        [0, vel_proj_scalar * pc1[1]],
        "-",
        color=(1, 1, 1, 0.5),
        linewidth=1.5,
    )

    if vel_projs is not None and len(vel_projs) > 0:
        _draw_velocity_histogram(
            ax3,
            vel_projs,
            pc1,
            pc2,
            shape_radius,
            pos_vel_color,
            neg_vel_color,
        )

    _scatter_keypoints(ax3, avg_skeleton, valid_idx, colors)

    # PC1 end labels sit on the side of the axis opposite the histogram
    _label_axis_ends(
        ax3,
        pc1,
        1.1 * body_axis_extent,
        -pc2 * (label_offset * 1.5),
        "PC1",
    )

    _draw_vc_arrow(ax3, net_vel_display, shape_radius)
    _draw_ap_indicator(
        ax3,
        pc1,
        pc2,
        np.min(proj_pc1[valid_mask]),
        np.max(proj_pc1[valid_mask]),
        anterior_sign,
        shape_radius,
        tan_color,
    )
    _style_skeleton_axes(ax3, display_radius, show_ticks=False)

    # TILE 4: R×M vs GT Accuracy Scatter Plot
    ax4 = fig.add_subplot(gs[1, 1], facecolor="black")
    _draw_rxm_scatter(
        ax4, all_individuals_metrics, individual_accuracy, file_color
    )
    ax4.tick_params(axis="y", colors="gray", labelsize=8)
    ax4.tick_params(axis="x", colors="gray", labelsize=8)
    for spine in ax4.spines.values():
        spine.set_color("gray")

    # NODE LEGEND (centre of figure): one row per valid keypoint followed by
    # the centroid, the two histogram colours and the AP midline
    legend_x = 0.47
    legend_top_y = 0.88
    legend_spacing = 0.045

    # Correct the dot width by the figure aspect so the swatch axes is square
    fig_w, fig_h = fig.get_size_inches()
    dot_h = 0.012
    dot_size = (dot_h / (fig_w / fig_h), dot_h)
    bar_size = (0.012, 0.008)
    line_size = (0.025, 0.006)

    legend_rows = []
    for i in valid_idx:
        name = keypoint_names[i] if i < len(keypoint_names) else f"node_{i}"
        legend_rows.append(
            (f"[{i}] {name}", dot_size, _draw_dot_swatch, colors[i])
        )
    legend_rows += [
        ("bbox centroid", dot_size, _draw_dot_swatch, (0.3, 0.3, 0.3)),
        ("+PC1 vel proj.", bar_size, _draw_bar_swatch, pos_vel_color),
        ("−PC1 vel proj.", bar_size, _draw_bar_swatch, neg_vel_color),
        ("AP midline", line_size, _draw_line_swatch, tan_color),
    ]
    for row, (label, size, draw_swatch, swatch_color) in enumerate(
        legend_rows
    ):
        _add_legend_row(
            fig,
            legend_x,
            legend_top_y - row * legend_spacing,
            size,
            label,
            draw_swatch,
            swatch_color,
        )

    # Save this specific figure rather than relying on pyplot's current one
    fig.savefig(
        out_path,
        format="svg",
        facecolor="black",
        edgecolor="none",
        bbox_inches="tight",
    )
    plt.close(fig)
    print(f" Saved: {out_path.name}")
stored_skeletons = data.get("skeletons", {}) + stored_pc1 = data.get("pc1_vectors", {}) + stored_vel_projs = data.get("vel_projs_pc1", {}) + keypoints_by_file = data.get("keypoints", {}) + individual_accuracy_by_file = data.get("individual_accuracy", {}) + + # File color mapping (matching ap_validation_results figure) + # Files sorted alphabetically, colors assigned in order + file_colors = { + "free-moving-2flies-ID-13nodes-1024x1024x1-30_3pxmm": ( + 0.12, + 0.47, + 0.71, + 1.0, + ), # Blue + "free-moving-2mice-noID-5nodes-1280x1024x1-1_9pxmm": ( + 0.17, + 0.63, + 0.17, + 1.0, + ), # Green + "free-moving-4gerbils-ID-14nodes-1024x1280x3-2pxmm": ( + 0.80, + 0.70, + 0.10, + 1.0, + ), # Gold + "free-moving-5mice-noID-11nodes-1280x1024x1-1_97pxmm": ( + 1.00, + 0.50, + 0.05, + 1.0, + ), # Orange + "freemoving-2bees-noID-21nodes-1535x2048x1-14pxmm": ( + 0.58, + 0.40, + 0.74, + 1.0, + ), # Purple + } + + # Collect R, M, and R×M values per file per individual + def default_metrics(): + return {"R": [], "M": [], "RxM": []} + + file_individual_data = defaultdict(lambda: defaultdict(default_metrics)) + for i in range(len(data["file"])): + file_stem = data["file"][i] + individual = data["individual"][i] + R = data["resultant_length"][i] + M = data["vote_margin"][i] + rxm = data["rxm"][i] + if not np.isnan(rxm): + file_individual_data[file_stem][individual]["R"].append(R) + file_individual_data[file_stem][individual]["M"].append(M) + file_individual_data[file_stem][individual]["RxM"].append(rxm) + + # Compute mean R, M per individual and find best by R×M + best_individual = {} + best_metrics = {} + # {file_stem: {individual: {"R": mean_R, "M": mean_M}}} + all_individuals_metrics_by_file = {} + + for file_stem, individuals in file_individual_data.items(): + best_rxm = -1 + best_ind = None + all_individuals_metrics_by_file[file_stem] = {} + + for ind, metrics_dict in individuals.items(): + mean_R = np.mean(metrics_dict["R"]) + mean_M = np.mean(metrics_dict["M"]) + mean_rxm = 
np.mean(metrics_dict["RxM"]) + all_individuals_metrics_by_file[file_stem][ind] = { + "R": mean_R, + "M": mean_M, + } + if mean_rxm > best_rxm: + best_rxm = mean_rxm + best_ind = ind + + best_individual[file_stem] = best_ind + + # Get metrics for best individual + for i in range(len(data["file"])): + is_match = ( + data["file"][i] == file_stem + and data["individual"][i] == best_ind + ) + if is_match: + best_metrics[file_stem] = { + "vote_margin": data["vote_margin"][i], + "resultant_length": data["resultant_length"][i], + "anterior_sign": data["anterior_sign"][i], + "circ_mean_dir": data["circ_mean_dir"][i], + "num_selected_frames": data["num_selected_frames"][i], + "n_frames": data["n_frames"][i], + } + break + + print("\nGenerating per-file detail plots (best individual)...") + + for file_stem, best_ind in sorted(best_individual.items()): + label = FILE_LABELS.get(file_stem, file_stem[:15]) + print(f" {label}: *{best_ind}*") + + has_skel = ( + file_stem in stored_skeletons + and best_ind in stored_skeletons[file_stem] + ) + if not has_skel: + print(" No skeleton data") + continue + + avg_skeleton = stored_skeletons[file_stem][best_ind] + + has_pc1 = file_stem in stored_pc1 and best_ind in stored_pc1[file_stem] + if not has_pc1: + print(" No PC1 data") + continue + + pc1 = stored_pc1[file_stem][best_ind] + keypoint_names = keypoints_by_file.get(file_stem, []) + metrics = best_metrics.get(file_stem, {}) + + # Get velocity projections for histogram + vel_projs = None + has_vel = ( + file_stem in stored_vel_projs + and best_ind in stored_vel_projs[file_stem] + ) + if has_vel: + vel_projs = stored_vel_projs[file_stem][best_ind] + + # Get R and M data for all individuals in this file + all_individuals_metrics = all_individuals_metrics_by_file.get( + file_stem, {} + ) + + # Get individual accuracy data for this file + individual_accuracy = individual_accuracy_by_file.get(file_stem, {}) + + # Get file color + file_color = file_colors.get(file_stem, (0.5, 0.5, 0.5, 1.0)) 
+ + # Sanitize individual name for filename + safe_ind = best_ind.replace(" ", "_") + ts_suffix = f"_{timestamp}" if timestamp else "" + output_path = output_dir / f"{file_stem}_{safe_ind}{ts_suffix}.svg" + create_individual_plot( + file_stem, + avg_skeleton, + pc1, + keypoint_names, + metrics, + output_path, + vel_projs, + individual_name=best_ind, + all_individuals_metrics=all_individuals_metrics, + file_color=file_color, + individual_accuracy=individual_accuracy, + ) + + print(f"Per-file detail plots saved to: {output_dir}") + + +def main(): # noqa: C901 + ensure_demo_datasets() + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + FIGURES_DIR.mkdir(exist_ok=True) + LOGS_DIR.mkdir(exist_ok=True) + H5_DIR.mkdir(exist_ok=True) + + # Set up log file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = LOGS_DIR / f"ap_validation_{timestamp}.log" + + with TeeOutput(log_file): + print(f"AP Validation started at {datetime.now().isoformat()}") + print(f"Log file: {log_file}") + print() + + # Check for existing H5 files + existing_h5 = sorted(H5_DIR.glob("ap_validation_*.h5")) + + if existing_h5: + # Use most recent existing file + output_path = existing_h5[-1] + print(f"Found existing H5 file: {output_path.name}") + print("Skipping generation, proceeding to analysis...") + else: + # Generate new H5 file + slp_files = sorted(SLP_DIR.glob("*.slp")) + if not slp_files: + print("No .slp files found after bootstrap") + sys.exit(1) + + # PASS 1: R×M Selection — Find best individual per file + n_files = len(slp_files) + print( + f"Pass 1: R×M Selection — finding best individual " + f"per file from {n_files} files..." 
+ ) + rxm_tasks, file_metadata = generate_rxm_tasks(slp_files) + print(f" Running {len(rxm_tasks)} R×M computations...") + + rxm_results = [] + with ProcessPoolExecutor(max_workers=N_WORKERS) as executor: + futures = [ + executor.submit(process_single_validation, task) + for task in rxm_tasks + ] + for future in as_completed(futures): + rxm_results.append(future.result()) + + best_individuals, all_rxm, file_individual_data = ( + find_best_individuals(rxm_results) + ) + best_ind_str = ", ".join( + f"{FILE_LABELS.get(k, k[:15])}: *{v}*" + for k, v in best_individuals.items() + ) + print(f" Best individuals (max R×M): {{{best_ind_str}}}") + + # Compute PC1-based orderings for all individuals + # (prerequisite for Pass 2: cross-individual consistency) + best_pc1_orderings, all_pc1_orderings = compute_pc1_orderings( + file_individual_data, file_metadata, best_individuals + ) + + # PASS 2: Cross-Individual Ordering Consistency + print( + "\nPass 2: Cross-Individual Ordering Consistency — " + "do individuals from the same video agree on node ordering?\n" + " Each individual's raw PC1 ordering of GT nodes is compared " + "against the best individual's ordering (the 'pseudo GT').\n" + " This is a CONSISTENCY check, not a correctness check — " + "high agreement means the body shape is stable across " + "individuals, but says nothing about whether the ordering " + "is anatomically correct." + ) + ordering_matches = compare_orderings_to_pseudo_gt( + all_pc1_orderings, best_pc1_orderings + ) + for file_stem, matches in ordering_matches.items(): + n_match = sum(matches.values()) + n_total = len(matches) + label = FILE_LABELS.get(file_stem, file_stem[:15]) + print( + f" {label}: {n_match}/{n_total} individuals " + f"share the best individual's ordering" + ) + + # PASS 3: Inferred AP Concordance per Individual + # For each individual, project GT nodes onto the inferred + # AP axis (anterior_sign × PC1) and compare the ordering + # against hand-curated GT. 
This tests the full pipeline + # (PCA + velocity voting) per individual. Feeds Figure 1 Tile 4. + print( + "\nPass 3: Inferred AP Concordance — " + "does each individual's velocity-inferred AP ordering " + "match hand-curated GT?\n" + " For each individual, GT nodes are projected onto " + "anterior_sign × PC1 (the inferred AP axis). All C(n,2) " + "unique pairs are tested against the hand-curated GT " + "ranking.\n" + " This tests the full pipeline (PCA + velocity voting) " + "per individual." + ) + individual_accuracy = compute_inferred_ap_concordance( + file_individual_data + ) + + # Build R, M, R×M lookup from Pass 1 results + ind_metrics = {} + for rec in rxm_results: + if not rec.get("error", False): + key = (rec["file"], rec["individual"]) + ind_metrics[key] = { + "R": rec.get("resultant_length", float("nan")), + "M": rec.get("vote_margin", float("nan")), + "rxm": rec.get("rxm", float("nan")), + } + + for file_stem, accuracies in individual_accuracy.items(): + print(f" {FILE_LABELS.get(file_stem, file_stem[:15])}:") + best_ind = best_individuals.get(file_stem) + for ind, acc in sorted(accuracies.items()): + c = acc["correct"] + t = acc["total"] + a = acc["accuracy"] + marker = "*" if ind == best_ind else "" + m = ind_metrics.get((file_stem, ind), {}) + R = m.get("R", float("nan")) + M = m.get("M", float("nan")) + rxm = m.get("rxm", float("nan")) + print( + f" {marker}{ind}{marker}: " + f"{c}/{t} pairs = {a:.1f}% " + f"(R={R:.2f}, M={M:.2f}, R×M={rxm:.2f})" + ) + + # Generate validation tasks for H5 storage + print( + "\nGenerating H5 validation records " + "(all GT pair permutations × all individuals)..." + ) + tasks, keypoints_by_file = generate_gt_validation_tasks( + file_metadata, best_individuals + ) + n_tasks = len(tasks) + print( + f" Processing {n_tasks} validation comparisons " + f"with {N_WORKERS} workers..." 
+ ) + + all_results = [] + with ProcessPoolExecutor(max_workers=N_WORKERS) as executor: + futures = [ + executor.submit(process_single_validation, task) + for task in tasks + ] + + for i, future in enumerate(as_completed(futures), 1): + result = future.result() + all_results.append(result) + if i % 50 == 0: + print(f" {i}/{len(tasks)} completed...") + + h5_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = H5_DIR / f"ap_validation_{h5_timestamp}.h5" + # Include Pass 1 results to ensure best individual's data is saved + combined_results = rxm_results + all_results + save_to_h5( + combined_results, + keypoints_by_file, + output_path, + individual_accuracy=individual_accuracy, + all_rxm=all_rxm, + ) + + print(f"\nSaved {len(all_results)} records to: {output_path}") + + # Reporting: GT coverage, suggested pairs, and Figure 2 data + analysis_data = analyze_results(output_path) + plot_validation_results(analysis_data, output_path) + + # Generate per-file 2×2 tile detail plots for each best individual + # Extract timestamp from H5 filename (ap_validation_YYYYMMDD_HHMMSS.h5) + h5_stem = output_path.stem + fig_timestamp = "_".join(h5_stem.split("_")[-2:]) + generate_individual_plots(analysis_data, FIGURES_DIR, fig_timestamp) + + print(f"\nLog saved to: {log_file}") + + +if __name__ == "__main__": + main() diff --git a/examples/compute_polarization_viz.py b/examples/compute_polarization_viz.py new file mode 100644 index 000000000..ca3455ff3 --- /dev/null +++ b/examples/compute_polarization_viz.py @@ -0,0 +1,3186 @@ +#!/usr/bin/env python3 +"""Static figure showing polarization dynamics over time. 
+ +Row 1: Overlay frames (1 second apart) - continuous mode + - Each panel after the first shows two frames superimposed: + * Current frame (panel_frame[p]) at 100% opacity + * Previous frame (panel_frame[p] - fps) at 50% opacity (additive ghost) + - Displacement arrows: velocity_node(t−fps) → velocity_node(t) + - For sparse mode: no overlay, just current frame + +Row 2: Current frame with body axis + - Shows current frame only (no overlay) + - Opaque body axis line + - Square marker at from_node, circle marker at to_node + - When pair is loaded from AP validation H5: from_node = inferred posterior, + to_node = inferred anterior (via _order_pair_by_ap). Otherwise + the user's ordering is taken as-is with no directional validation. + - Use H5 output from compute_polarization_AP_inference.py to automatically + select the suggested ordered pair from the AP inference pipeline. + Element 0 of the suggested pair (inferred posterior node) + serves as velocity keypoint. + +Row 3: Polar plots + - Orientation polarization (yellow) and heading polarization (orange) + +Usage: + python compute_polarization_viz.py +""" + +import sys +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import cast +from urllib.request import urlretrieve + +import cv2 +import h5py +import matplotlib +import matplotlib.gridspec as gridspec +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr + +from movement.io import load_poses +from movement.kinematics import compute_polarization +from movement.kinematics.body_axis import ValidateAPConfig, validate_ap + +# ── Duplicated-literal constants ───────────────────────────────────────────── +SLP_GLOB = "*.slp" +INVALID_NUMBER_MSG = " Invalid. Enter a number." 
+ENTER_INDEX_PROMPT = " Enter index: " + + +class TeeOutput: + """Context manager that duplicates stdout to both console and a file.""" + + def __init__(self, filepath): + """Initialise with target file path.""" + self.filepath = Path(filepath) + self.file = None + self.original_stdout = None + + def __enter__(self): + """Open log file and redirect stdout.""" + self.filepath.parent.mkdir(parents=True, exist_ok=True) + self.file = open(self.filepath, "w") + self.original_stdout = sys.stdout + sys.stdout = self + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Restore stdout and close the log file.""" + sys.stdout = self.original_stdout + if self.file: + self.file.close() + return False + + def write(self, text): + """Write *text* to both console and log file.""" + self.original_stdout.write(text) + self.file.write(text) + self.file.flush() + + def flush(self): + """Flush both console and log file streams.""" + self.original_stdout.flush() + self.file.flush() + + +# ── Pyramid blending helpers (extracted from adaptive_blend_frames) ────────── + + +def _build_pyramid(img, levels=4): + """Build Gaussian and Laplacian pyramids.""" + gaussian = [img.astype(np.float32)] + for _ in range(levels): + img = cv2.pyrDown(img) + gaussian.append(img.astype(np.float32)) + laplacian = [] + for i in range(levels): + size = (gaussian[i].shape[1], gaussian[i].shape[0]) + expanded = cv2.pyrUp(gaussian[i + 1], dstsize=size) + laplacian.append(gaussian[i] - expanded) + laplacian.append(gaussian[-1]) + return laplacian + + +def _blend_pyramids(lap1, lap2, mask_pyr): + """Blend two Laplacian pyramids using mask pyramid.""" + blended = [] + for l1, l2, m in zip(lap1, lap2, mask_pyr, strict=True): + if m.ndim == 2 and l1.ndim == 3: + m = m[:, :, np.newaxis] + blended.append(l1 * (1 - m) + l2 * m) + return blended + + +def _reconstruct(pyramid): + """Reconstruct image from Laplacian pyramid.""" + img = pyramid[-1] + for i in range(len(pyramid) - 2, -1, -1): + size = 
def _to_grayscale(img_u8):
    """Return a single-channel version of *img_u8*.

    Three-dimensional input is converted with ``cv2.COLOR_RGB2GRAY``; any
    other input is returned unchanged.
    """
    if img_u8.ndim != 3:
        return img_u8
    return cv2.cvtColor(img_u8, cv2.COLOR_RGB2GRAY)


def _build_position_masks(h, w, prev_positions, curr_positions, mask_radius):
    """Build blurred ghost and protection masks from animal positions.

    Returns ``(ghost_mask, protect_mask)``, both float32 arrays of shape
    ``(h, w)``. The ghost mask marks *prev_positions* and the protection
    mask marks *curr_positions*. A ``None`` position list yields an
    all-zero mask, and positions with a NaN coordinate are skipped.
    """

    def _soft_disc_mask(positions):
        """Draw a filled disc per valid position, then Gaussian-blur."""
        mask = np.zeros((h, w), dtype=np.float32)
        if positions is not None:
            for x, y in positions:
                if np.isnan(x) or np.isnan(y):
                    continue
                cv2.circle(mask, (int(x), int(y)), mask_radius, 1.0, -1)
        return cast(
            "np.ndarray",
            cv2.GaussianBlur(mask, (51, 51), 0),
        )

    ghost = _soft_disc_mask(prev_positions)
    protect = _soft_disc_mask(curr_positions)
    return ghost, protect


def _enhance_ghost(prev_u8, clahe):
    """Apply CLAHE to the ghost frame and tint it; return float32.

    For 3-D input, CLAHE is applied to channel 0 of the LAB conversion and
    the result is converted back to RGB; otherwise CLAHE is applied
    directly. A 3-D result is then tinted: red scaled by 0.85, green by
    0.95, blue by 1.1 and clipped to [0, 255].
    """
    if prev_u8.ndim == 3:
        lab = cv2.cvtColor(prev_u8, cv2.COLOR_RGB2LAB)
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])
        enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    else:
        enhanced = clahe.apply(prev_u8)

    tinted = enhanced.astype(np.float32)
    if tinted.ndim == 3:
        tinted[..., 0] *= 0.85  # reduce red
        tinted[..., 1] *= 0.95  # slight reduce green
        tinted[..., 2] = np.clip(tinted[..., 2] * 1.1, 0, 255)
    return tinted


def _compute_alpha_map(
    base_alpha,
    ghost_mask,
    protect_mask,
    motion_alpha,
    prev_positions,
    curr_positions,
    curr_gray,
):
    """Compute the per-pixel ghost opacity map.

    When either position list is provided, opacity follows the ghost mask,
    is reduced inside the protection mask, and is floored at half the base
    alpha scaled by motion. Without positions it depends on motion only.
    The map is then reduced where the current frame is bright and smoothed
    with a 15x15 Gaussian blur.
    """
    no_tracking = prev_positions is None and curr_positions is None
    if no_tracking:
        alpha_map = base_alpha * (0.3 + 0.7 * motion_alpha)
    else:
        positional = (
            base_alpha * (0.3 + 0.7 * ghost_mask) * (1 - 0.5 * protect_mask)
        )
        motion_floor = base_alpha * 0.5 * motion_alpha
        alpha_map = np.maximum(positional, motion_floor)

    brightness = curr_gray.astype(np.float32) / 255.0
    alpha_map = alpha_map * (1 - 0.5 * brightness)
    return cv2.GaussianBlur(alpha_map, (15, 15), 0)
+ + """ + curr_u8 = (curr_img * 255).astype(np.uint8) + prev_u8 = (prev_img * 255).astype(np.uint8) + h, w = curr_u8.shape[:2] + + # Step 1: Detect motion via frame difference + curr_gray = _to_grayscale(curr_u8) + prev_gray = _to_grayscale(prev_u8) + diff = cv2.absdiff(curr_gray, prev_gray) + _, motion_mask = cv2.threshold(diff, 25, 255, cv2.THRESH_BINARY) + kernel: np.ndarray = np.ones((3, 3), dtype=np.uint8) + motion_mask = cv2.dilate(motion_mask, kernel, iterations=3) + motion_mask = cv2.GaussianBlur(motion_mask, (21, 21), 0) + + # Step 2: Position-aware masks + ghost_mask, protect_mask = _build_position_masks( + h, + w, + prev_positions, + curr_positions, + mask_radius, + ) + + # Step 3: Enhance ghost with CLAHE + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + ghost_tinted = _enhance_ghost(prev_u8, clahe) + + # Step 4: Compute adaptive alpha + motion_alpha = motion_mask.astype(np.float32) / 255.0 + alpha_map = _compute_alpha_map( + base_alpha, + ghost_mask, + protect_mask, + motion_alpha, + prev_positions, + curr_positions, + curr_gray, + ) + + # Step 5: Laplacian pyramid blending + levels = 4 + curr_lap = _build_pyramid(curr_u8, levels) + ghost_lap = _build_pyramid(ghost_tinted.astype(np.uint8), levels) + + mask_pyr = [alpha_map] + for _ in range(levels): + mask_pyr.append(cv2.pyrDown(mask_pyr[-1])) + + blended_lap = _blend_pyramids(curr_lap, ghost_lap, mask_pyr) + result = _reconstruct(blended_lap) + + result = np.clip(result, 0, 255).astype(np.uint8) + return result.astype(np.float32) / 255.0 + + +@dataclass +class VideoMetadata: + """Metadata extracted from video file.""" + + fps: float | None = None + frame_count: int | None = None + width: int | None = None + height: int | None = None + duration: float | None = None + path: Path | None = None + + @property + def resolution(self) -> tuple[int, int] | None: + """Return (width, height) tuple if available.""" + if self.width is not None and self.height is not None: + return (self.width, 
self.height) + return None + + +# Disable navigation toolbar +matplotlib.rcParams["toolbar"] = "None" + +# CONFIGURATION +ROOT_PATH = Path(__file__).parent / "datasets" / "multi-animal" +SLP_DIR = ROOT_PATH / "slp" +MP4_DIR = ROOT_PATH / "mp4" + +DEMO_DATASETS = { + "free-moving-2flies-ID-13nodes-1024x1024x1-30_3pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/clips/talk_title_slide%4013150-14500.mp4", # noqa: E501 + "slp": "https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/clips/talk_title_slide%4013150-14500.slp", # noqa: E501 + }, + "free-moving-5mice-noID-11nodes-1280x1024x1-1_97pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/wang_4mice_john/clips/OFTsocial5mice-0000-00%4015488-18736.mp4", # noqa: E501 + "slp": "https://storage.googleapis.com/sleap-data/datasets/wang_4mice_john/clips/OFTsocial5mice-0000-00%4015488-18736.slp", # noqa: E501 + }, + "free-moving-2mice-noID-5nodes-1280x1024x1-1_9pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/eleni_mice/clips/20200111_USVpairs_court1_M1_F1_top-01112020145828-0000%400-2560.mp4", # noqa: E501 + "slp": "https://storage.googleapis.com/sleap-data/datasets/eleni_mice/clips/20200111_USVpairs_court1_M1_F1_top-01112020145828-0000%400-2560.slp", # noqa: E501 + }, + "freemoving-2bees-noID-21nodes-1535x2048x1-14pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/yan_bees/clips/bees_demo%4021000-23000.mp4", # noqa: E501 + "slp": "https://storage.googleapis.com/sleap-data/datasets/yan_bees/clips/bees_demo%4021000-23000.slp", # noqa: E501 + }, + "free-moving-4gerbils-ID-14nodes-1024x1280x3-2pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/nyu-gerbils/clips/2020-3-10_daytime_5mins_compressedTalmo%403200-5760.mp4", # noqa: E501 + "slp": "https://storage.googleapis.com/sleap-data/datasets/nyu-gerbils/clips/2020-3-10_daytime_5mins_compressedTalmo%403200-5760.slp", # noqa: E501 + }, +} + +# Frame interval 
is computed dynamically as fps (1 second between panels) +NUM_PANELS = 5 # Number of time points to display (0s to 4s) +MIN_CONTINUOUS_PANELS = 3 # Minimum panels for partial continuous mode + +OVERLAY_ALPHA = 0.7 # Ghost frame opacity (Row 1 overlay) + +ANIMAL_COLORS = [ + "red", + "#39FF14", + "#1E90FF", + "magenta", + "cyan", + "yellow", + "white", +] +NET_ORIENTATION_COLOR = [ + 1.00, + 0.85, + 0.20, +] # Orientation polarization vector (yellow) +NET_HEADING_COLOR = [1.00, 0.55, 0.10] # Heading polarization vector (orange) +DARK_BG = (0.08, 0.08, 0.10) # Figure background +DARK_FG = (0.92, 0.92, 0.92) # Text and axis color + +SAVE_PNG = False +SAVE_SVG = True + +# When True, prompt user interactively; when False, use auto-detected values +INTERACTIVE_NODE_SELECTION = True + +# Minimum polarization for visually meaningful vectors +MIN_VISUAL_POL = 0.15 + +# When not None, skip interactive mode and process each dataset +# with specified nodes. +# Format: {slp_file_index: (from_node_idx, to_node_idx)} +# - from_node: should be posterior keypoint (e.g., tail_base, abdomen) +# - to_node: should be anterior keypoint (e.g., nose, head) +# NOTE: manual BATCH_CONFIG has no automatic AP-inference ordering unless +# BATCH_CONFIG_AP_VALIDATED is set to True. +# Example: {0: (3, 0), 1: (4, 0)} +BATCH_CONFIG: dict[int, tuple[int, int]] | None = None + +# When True, treat BATCH_CONFIG pairs as already AP-validated +# (posterior→anterior ordering), skipping the AP inference feedback. +BATCH_CONFIG_AP_VALIDATED = False + +# When True, auto-load suggested pairs from AP validation H5 file. +# This populates BATCH_CONFIG automatically: element 0 = inferred posterior, +# element 1 = inferred anterior +# (ordered by _order_pair_by_ap in collective.py). 
# ── Shared helpers ───────────────────────────────────────────────────────────


def _list_slp_files(directory: Path) -> list[Path]:
    """Return sorted SLP files from *directory*, or empty list if missing."""
    return sorted(directory.glob(SLP_GLOB)) if directory.exists() else []


def _prompt_int_in_range(prompt: str, low: int, high: int) -> int:
    """Repeatedly prompt until the user enters an integer in [low, high)."""
    while True:
        # Only the conversion can raise ValueError, so keep the try minimal.
        try:
            value = int(input(prompt).strip())
        except ValueError:
            print(INVALID_NUMBER_MSG)
            continue
        if low <= value < high:
            return value
        print(f" Invalid. Enter {low}-{high - 1}.")


def _download_file(url: str, destination: Path) -> None:
    """Download *url* to *destination*, creating parent directories.

    The payload is first written to ``<destination name>.part`` in the same
    directory and renamed onto *destination* only after the transfer
    completed. An interrupted or failed download therefore never leaves a
    truncated file at *destination*, which matters because
    ``ensure_demo_datasets`` skips every target that already exists.

    Raises whatever ``urlretrieve`` raises; the partial file is removed
    before the exception propagates.
    """
    destination.parent.mkdir(parents=True, exist_ok=True)
    print(f"Downloading {destination.name}...")
    # The ".part" suffix keeps the temporary file out of "*.slp"/"*.mp4"
    # globs used to count already-downloaded datasets.
    partial = destination.with_name(destination.name + ".part")
    try:
        urlretrieve(url, partial)
        partial.replace(destination)
    finally:
        # No-op after a successful rename; cleans up after a failure.
        partial.unlink(missing_ok=True)
# ── H5 loading helpers ──────────────────────────────────────────────────────


def _decode_h5_strings(arr):
    """Return the items of *arr* as a list, decoding ``bytes`` to ``str``.

    Items that are not ``bytes`` are passed through unchanged.
    """
    decoded = []
    for item in arr:
        decoded.append(item.decode() if isinstance(item, bytes) else item)
    return decoded


def _find_best_per_file(files, rxm, suggested_from, suggested_to):
    """Pick, per file, the record with the highest R×M score.

    The four arguments are parallel sequences indexed by record. Records
    with an empty file stem or a NaN score are ignored; on equal scores the
    earliest record is kept.

    Returns
    -------
    dict[str, tuple[float, int, int]]
        ``{file_stem: (rxm, suggested_from_idx, suggested_to_idx)}``.

    """
    best_per_file: dict[str, tuple[float, int, int]] = {}
    for row, stem in enumerate(files):
        if not stem:
            continue
        score = rxm[row]
        if np.isnan(score):
            continue
        incumbent = best_per_file.get(stem)
        # Strict ">" keeps the first record when scores tie.
        if incumbent is None or score > incumbent[0]:
            best_per_file[stem] = (
                score,
                int(suggested_from[row]),
                int(suggested_to[row]),
            )
    return best_per_file
def _run_ap_inference(
    ds: xr.Dataset,
    from_keypoint: str,
    to_keypoint: str,
) -> tuple[dict, dict, float] | None:
    """Run AP inference per individual and return the best-scoring result.

    ``validate_ap`` is called once for every entry of the dataset's
    ``individuals`` coordinate. Among the runs whose result and pair report
    both report success, the one with the highest
    ``resultant_length * vote_margin`` (R×M) is selected; the first such run
    wins on ties.

    Parameters
    ----------
    ds : xr.Dataset
        Movement dataset with ``position`` DataArray.
    from_keypoint : str
        Name of the from_node.
    to_keypoint : str
        Name of the to_node.

    Returns
    -------
    tuple of (result_dict, pair_report, best_rxm) or None
        The selected result dict (with an added ``"individual"`` entry),
        its ``pair_report`` and its R×M score, or ``None`` when no
        individual produced a successful result.

    """
    config = ValidateAPConfig()

    # Phase 1: run the validation for every individual.
    all_results: list[dict] = []
    for individual in list(ds.coords["individuals"].values):
        outcome = validate_ap(
            ds.position.sel(individuals=individual),
            from_node=from_keypoint,
            to_node=to_keypoint,
            config=config,
            verbose=False,
        )
        outcome["individual"] = individual
        all_results.append(outcome)

    # Phase 2: keep the successful run with the largest R×M.
    best_result = None
    best_rxm = -1.0
    for outcome in all_results:
        if not outcome["success"]:
            continue
        report = outcome.get("pair_report")
        if report is None or not report.success:
            continue
        score = outcome["resultant_length"] * outcome["vote_margin"]
        if score > best_rxm:
            best_rxm = score
            best_result = outcome

    if best_result is None:
        return None
    return best_result, best_result["pair_report"], best_rxm


# ── AP scenario / ordering helpers ───────────────────────────────────────────


def _find_suggested_alternative(pr, input_set):
    """Return ``((idx_a, idx_b), label)`` for a suggested pair, or None.

    The pair report's ``max_separation_distal_nodes`` is tried first, then
    ``max_separation_nodes``. A candidate is used when it has exactly two
    entries and differs (as an unordered set) from *input_set*.
    """
    candidates = (
        ("max_separation_distal_nodes", "max-separation distal"),
        ("max_separation_nodes", "max-separation overall"),
    )
    for attribute, label in candidates:
        nodes = getattr(pr, attribute)
        if len(nodes) != 2:
            continue
        pair = (int(nodes[0]), int(nodes[1]))
        if frozenset(pair) != input_set:
            return pair, label
    return None


def _prompt_yn(prompt_text: str) -> bool:
    """Ask until the user answers y/yes (True) or n/no (False)."""
    answers = {"y": True, "yes": True, "n": False, "no": False}
    while True:
        reply = input(prompt_text).strip().lower()
        if reply in answers:
            return answers[reply]
        print(" Please enter 'y' or 'n'.")
from_keypoint: str, + to_keypoint: str, + node_names: list[str], + inference_result: tuple[dict, dict, float] | None = None, +) -> tuple[int, int] | None: + """Report the AP scenario for the user's pair and suggest alternatives. + + Runs the AP inference pipeline (or reuses pre-computed results), + reports which of the 13 mutually exclusive scenarios the user's + pair falls into, and interactively offers to switch to a better + pair if one exists. + + Parameters + ---------- + ds : xr.Dataset + Movement dataset with ``position`` DataArray. + from_keypoint : str + Name of the user-specified from_node. + to_keypoint : str + Name of the user-specified to_node. + node_names : list[str] + List of all keypoint names (for index↔name conversion). + inference_result : tuple or None, optional + Pre-computed output from ``_run_ap_inference``. If ``None``, + inference is run internally. + + Returns + ------- + tuple[int, int] or None + ``(posterior_idx, anterior_idx)`` if the user accepted a + suggested alternative pair (already AP-ordered), or ``None`` + if no switch was made. 
def check_ap_ordering(
    ds: xr.Dataset,
    from_keypoint: str,
    to_keypoint: str,
    inference_result: tuple[dict, dict, float] | None = None,
) -> bool:
    """Check the pair's ordering against AP inference; offer a flip.

    The inference result (pre-computed, or obtained from
    ``_run_ap_inference``) belongs to the best individual. If its pair
    report says the input order matches the inference, a confirmation is
    printed. Otherwise a warning is printed and the user is asked whether
    from_node and to_node should be swapped.

    Parameters
    ----------
    ds : xr.Dataset
        Movement dataset with ``position`` DataArray. Only used when
        *inference_result* is ``None``.
    from_keypoint : str
        Name of the user-specified from_node.
    to_keypoint : str
        Name of the user-specified to_node.
    inference_result : tuple or None, optional
        Pre-computed output from ``_run_ap_inference``. If ``None``,
        inference is run internally.

    Returns
    -------
    bool
        True only if the ordering was reported as reversed and the user
        accepted the flip; False when the ordering agrees, the user
        declined, or inference failed for all individuals.

    """
    print("\n AP ordering check...")

    outcome = inference_result
    if outcome is None:
        outcome = _run_ap_inference(ds, from_keypoint, to_keypoint)
    if outcome is None:
        print(
            " AP inference could not validate this pair"
            " (all individuals failed)"
        )
        return False

    result, pr, best_rxm = outcome
    if result["anterior_sign"] > 0:
        sign_str = "+"
    else:
        sign_str = "−"
    margin = result["vote_margin"]
    ind_label = result.get("individual", "unknown")

    # Statistics fragment shared by the "agrees" and "reversed" messages.
    stats = (
        f"vote margin M = {margin:.3f}, "
        f"best individual = {ind_label}, "
        f"R×M = {best_rxm:.3f}"
    )

    if pr.input_pair_order_matches_inference:
        print(
            f" AP inference agrees: {from_keypoint} is "
            f"posterior to {to_keypoint} "
            f"(anterior = {sign_str}PC1, {stats})"
        )
        return False

    print(
        f" WARNING: AP inference reports {from_keypoint} is "
        f"ANTERIOR to {to_keypoint} - ordering likely reversed!"
    )
    print(f" Inferred AP direction: {sign_str}PC1, {stats}")
    print(
        f" Appropriate order would likely be: "
        f"{to_keypoint} (posterior) -> "
        f"{from_keypoint} (anterior)"
    )

    if not _prompt_yn(" Flip from_node and to_node? [y/n]: "):
        print(" Keeping original ordering as entered.")
        return False

    print(
        f" Flipping: from={to_keypoint} (posterior), "
        f"to={from_keypoint} (anterior)"
    )
    return True
+ + """ + slp_path = ROOT_PATH / "slp" + slp_files = _list_slp_files(slp_path) + + if not slp_files: + raise FileNotFoundError(f"No SLP files found in {slp_path}") + + print("\n") + print("AVAILABLE DATASETS") + print(f"\nFound {len(slp_files)} SLP file(s):\n") + + for idx, sf in enumerate(slp_files): + print(f" [{idx}] {sf.name}") + + print() + + sel_idx = _prompt_int_in_range( + f"Select dataset [0-{len(slp_files) - 1}]: ", + 0, + len(slp_files), + ) + print(f"\nSelected: [{sel_idx}] {slp_files[sel_idx].name}") + print("─" * 60) + return sel_idx + + +# ── Video metadata helpers ─────────────────────────────────────────────────── + + +def _validate_fps_from_duration(metadata, video_path, actual_duration_ms): + """Cross-check stored FPS against actual video duration. + + Overwrites ``metadata.fps`` and ``metadata.duration`` in-place if the + stored value is too far from the computed one. + """ + if actual_duration_ms is None or actual_duration_ms <= 0: + metadata.duration = metadata.frame_count / metadata.fps + return + + actual_duration_sec = actual_duration_ms / 1000.0 + computed_fps = metadata.frame_count / actual_duration_sec + fps_diff = abs(metadata.fps - computed_fps) + + if fps_diff <= 1.0: + metadata.duration = metadata.frame_count / metadata.fps + return + + import warnings + + warnings.warn( + f"FPS mismatch in {video_path.name}:\n" + f" Stored FPS: {metadata.fps:.2f}\n" + f" Computed FPS: {computed_fps:.2f}\n" + f" Frame count: {metadata.frame_count}\n" + f" Actual duration: {actual_duration_sec:.2f}s\n" + f" Difference: {fps_diff:.2f} fps\n" + f" Using COMPUTED FPS value.", + UserWarning, + stacklevel=3, + ) + metadata.fps = computed_fps + metadata.duration = actual_duration_sec + + +def extract_video_metadata(video_path: Path) -> VideoMetadata: + """Extract metadata from video file. 
+ + Parameters + ---------- + video_path : Path + Path to the video file + + Returns + ------- + VideoMetadata + Dataclass containing fps, frame_count, width, height, duration, path. + Fields are None if extraction fails. + + """ + metadata = VideoMetadata(path=video_path) + try: + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + return metadata + + fps = cap.get(cv2.CAP_PROP_FPS) + if fps > 0: + metadata.fps = fps + + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if frame_count > 0: + metadata.frame_count = frame_count + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + if width > 0 and height > 0: + metadata.width = width + metadata.height = height + + actual_duration_ms = None + if frame_count > 0: + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1) + actual_duration_ms = cap.get(cv2.CAP_PROP_POS_MSEC) + + if metadata.fps and metadata.frame_count: + _validate_fps_from_duration( + metadata, video_path, actual_duration_ms + ) + + cap.release() + except Exception: + pass + return metadata + + +# ── Data loading helpers ───────────────────────────────────────────────────── + + +def _resolve_fps(fps, video_metadata): + """Determine FPS from user input or video metadata. + + Returns (fps, source) tuple. 
+ """ + if fps is not None: + return int(fps), "user-defined" + if video_metadata is not None and video_metadata.fps is not None: + return int(video_metadata.fps), "video metadata" + return None, None + + +def _print_video_info(video_metadata, fps, fps_source): + """Print video and FPS information.""" + if fps is not None: + print(f"FPS: {fps} (from {fps_source})") + else: + print("FPS: Unknown (will use sparse frame selection)") + + if video_metadata is None: + return + if video_metadata.frame_count is not None: + print(f"Video frames: {video_metadata.frame_count}") + if video_metadata.resolution is not None: + print( + f"Video resolution: {video_metadata.width}x{video_metadata.height}" + ) + if video_metadata.duration is not None: + print(f"Video duration: {video_metadata.duration:.2f}s") + + +def load_data(dataset_idx, fps=None, node_config=None): + """Load tracking data, returning poses as xarray Dataset. + + Video loading is deferred to plotting stage. + + Parameters + ---------- + dataset_idx : int + Index of the dataset to load (0-based). + fps : int, optional + Frames per second. If provided, overrides auto-detection. + If None, attempts: video metadata → None (sparse mode) + node_config : dict, optional + Pre-defined node configuration. + + Returns + ------- + ds : xarray.Dataset + Movement-compatible dataset. + video_file : Path or None + Path to the video file if it exists, None otherwise + video_metadata : VideoMetadata or None + Video metadata if video exists. 
+ slp_filename : str + Name of the SLP file (used for output naming) + node_config : dict + Node configuration for orientation and velocity computation + + """ + slp_files = _list_slp_files(SLP_DIR) + video_files = sorted(MP4_DIR.glob("*.mp4")) if MP4_DIR.exists() else [] + + slp_file = slp_files[dataset_idx] + + # Find matching video file + video_file = None + for vf in video_files: + if vf.stem == slp_file.stem or slp_file.stem in vf.stem: + video_file = vf + break + + video_metadata = None + if video_file is not None: + video_metadata = extract_video_metadata(video_file) + + fps, fps_source = _resolve_fps(fps, video_metadata) + + print(f"Loading SLP: {slp_file.name}") + if video_file: + print(f"Video found: {video_file.name}") + else: + print("Video: Not found (will use sparse mode without frames)") + + ds = load_poses.from_sleap_file(slp_file, fps=fps) + + node_names = list(ds.keypoints.values) + track_names = list(ds.individuals.values) + total_frames = ds.sizes["time"] + n_keypoints = ds.sizes["keypoints"] + total_animals = ds.sizes["individuals"] + + _print_video_info(video_metadata, fps, fps_source) + + print(f"Tracking frames: {total_frames}") + print(f"Individuals ({total_animals}): {track_names}") + print(f"Keypoints ({n_keypoints}): {node_names}") + + if node_config is not None: + _print_batch_node_config(node_config, node_names) + else: + node_config = prompt_node_selection(ds, frames=None) + + return ds, video_file, video_metadata, slp_file.name, node_config + + +def _print_batch_node_config(node_config, node_names): + """Print the batch-mode node configuration summary.""" + from_idx = node_config["orientation_from"] + to_idx = node_config["orientation_to"] + vel_idx = node_config["velocity_node"] + ap_ok = node_config.get("ap_validated", False) + dir_label = "(posterior → anterior)" if ap_ok else "(from → to)" + print("\nUsing batch config:") + print(f" Velocity: {node_names[vel_idx]}[{vel_idx}]") + print( + f" Orientation: 
{node_names[from_idx]}[{from_idx}] → " + f"{node_names[to_idx]}[{to_idx}] " + f"{dir_label}" + ) + + +def find_first_valid_frame(ds: xr.Dataset, min_animals: int = 2): + """Find the first frame where min_animals have all keypoints visible. + + Parameters + ---------- + ds : xarray.Dataset + Movement dataset with position DataArray. + min_animals : int + Minimum number of individuals with all keypoints visible + + Returns + ------- + tuple (frame_idx, list of individual_indices) or (None, None) if not found + + """ + position = ds.position.values + n_frames = ds.sizes["time"] + n_animals = ds.sizes["individuals"] + + for f in range(n_frames): + valid_animals = [] + for a in range(n_animals): + animal_data = position[f, :, :, a] + if not np.any(np.isnan(animal_data)): + valid_animals.append(a) + if len(valid_animals) >= min_animals: + return f, valid_animals + return None, None + + +def show_keypoint_reference( + frames: np.ndarray, ds: xr.Dataset, frame_idx: int, animal_indices: list +): + """Display a reference frame with labeled keypoints. + + Parameters + ---------- + frames : ndarray + Video frames + ds : xarray.Dataset + Movement dataset with position DataArray. + frame_idx : int + Frame index to display. + animal_indices : list + List of individual indices to show. 
+ + """ + node_names = list(ds.keypoints.values) + position = ds.position.values + + fig, ax = plt.subplots(figsize=(10, 8), facecolor=DARK_BG) + ax.set_facecolor(DARK_BG) + ax.imshow(frames[frame_idx]) + colors = plt.colormaps["tab10"](np.linspace(0, 1, len(node_names))) + + for animal_idx in animal_indices: + keypoints_xy = position[frame_idx, :, :, animal_idx] + + for i, _name in enumerate(node_names): + x, y = keypoints_xy[0, i], keypoints_xy[1, i] + if not np.isnan(x) and not np.isnan(y): + ax.plot(x, y, "o", color=colors[i], markersize=5) + ax.annotate( + f"{i}", + (x, y), + xytext=(5, 5), + textcoords="offset points", + color="white", + fontsize=8, + fontweight="bold", + ) + + ax.set_xlim(0, frames[frame_idx].shape[1]) + ax.set_ylim(frames[frame_idx].shape[0], 0) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title( + f"Keypoint Reference (Frame {frame_idx}, Animals {animal_indices})", + color=DARK_FG, + fontsize=12, + pad=10, + ) + + for spine in ax.spines.values(): + spine.set_visible(False) + + plt.tight_layout() + plt.show(block=False) + plt.pause(0.1) + + return fig + + +def prompt_node_selection(ds: xr.Dataset, frames: np.ndarray | None = None): + """Prompt user to select keypoints for velocity and orientation. + + Parameters + ---------- + ds : xarray.Dataset + Movement dataset with position DataArray. + frames : ndarray, optional + Video frames (for showing reference image). + + Returns + ------- + dict with keys: + velocity_node: int - keypoint index for velocity/heading. + orientation_from: int - "from" keypoint (ideally posterior). + orientation_to: int - "to" keypoint (ideally anterior). + ap_validated: bool - always False for interactive selection. + + """ + node_names = list(ds.keypoints.values) + n_nodes = len(node_names) + ref_fig = None + + node_list_str = " ".join( + f"{name}[{i}]" for i, name in enumerate(node_names) + ) + + if not INTERACTIVE_NODE_SELECTION: + raise ValueError( + "Interactive mode is required." 
+ " Set INTERACTIVE_NODE_SELECTION=True." + ) + + print("\n") + print("NODE SELECTION") + print(f"\nAvailable nodes: {node_list_str}") + + if frames is not None: + frame_idx, animal_indices = find_first_valid_frame(ds, min_animals=2) + if frame_idx is not None: + msg = f"Showing keypoint reference (Frame {frame_idx})" + print(f"\n{msg}, Animals {animal_indices}...") + ref_fig = show_keypoint_reference( + frames, ds, frame_idx, animal_indices + ) + + print("\nFROM node (should be posterior, e.g., tail_base):") + orientation_from = _prompt_int_in_range(ENTER_INDEX_PROMPT, 0, n_nodes) + + print("\nTO node (should be anterior, e.g., head):") + while True: + orientation_to = _prompt_int_in_range(ENTER_INDEX_PROMPT, 0, n_nodes) + if orientation_to != orientation_from: + break + print(" Must differ from FROM node.") + + print("\nVelocity node (for heading):") + velocity_node = _prompt_int_in_range(ENTER_INDEX_PROMPT, 0, n_nodes) + + if ref_fig is not None: + plt.close(ref_fig) + + print("\n" + "-" * 60) + print("Final configuration:") + print(f" Velocity: {node_names[velocity_node]}[{velocity_node}]") + from_name = node_names[orientation_from] + to_name = node_names[orientation_to] + print( + f" Orientation: {from_name}[{orientation_from}] → " + f"{to_name}[{orientation_to}] " + f"(from → to)" + ) + print("-" * 60 + "\n") + + return { + "velocity_node": velocity_node, + "orientation_from": orientation_from, + "orientation_to": orientation_to, + "ap_validated": False, + "auto_log": None, + } + + +# ── Sparse frame selection helpers ─────────────────────────────────────────── + + +def _collect_candidate_frames(orientation_valid, vel_valid, total_frames): + """Return frames where both orientation and velocity are valid.""" + return [ + f + for f in range(1, total_frames) + if orientation_valid[f] and vel_valid[f] + ] + + +def _compute_combined_polarization( + candidate_frames, orientation_pol, heading_pol, total_frames +): + """Compute combined (orientation + heading) 
polarization per frame.""" + combined_pol = np.zeros(total_frames) + for f in candidate_frames: + vals = [] + if not np.isnan(orientation_pol[f]): + vals.append(orientation_pol[f]) + if not np.isnan(heading_pol[f]): + vals.append(heading_pol[f]) + if vals: + combined_pol[f] = np.mean(vals) + return combined_pol + + +def _greedy_select_frames(source_frames, combined_pol, selected, target_count): + """Greedily add frames from *source_frames* to maximise variance. + + Mutates *selected* in-place and returns it. + """ + remaining = set(source_frames) + if not selected and remaining: + first = max(remaining, key=lambda f: combined_pol[f]) + selected.append(first) + remaining.remove(first) + + while len(selected) < target_count and remaining: + current_vals = [combined_pol[f] for f in selected] + best_f = max( + remaining, + key=lambda f: np.var(current_vals + [combined_pol[f]]), + ) + selected.append(best_f) + remaining.remove(best_f) + + return selected + + +def _find_sparse_frames( + orientation_pol, heading_pol, orientation_valid, vel_valid, total_frames +): + """Select sparse frames with maximal variance in polarization. + + Used when continuous segments aren't available or when fps is unknown. 
+ Each selected frame requires: + - orientation_valid[f]: valid orientation at frame f + - vel_valid[f]: valid velocity at frame f (implies pos_valid at f AND f-1) + + Returns + ------- + Same as find_best_segment + + """ + candidate_frames = _collect_candidate_frames( + orientation_valid, + vel_valid, + total_frames, + ) + + if len(candidate_frames) < NUM_PANELS: + n_cand = len(candidate_frames) + raise ValueError( + f"Not enough valid frames ({n_cand}) for {NUM_PANELS} panels" + ) + + combined_pol = _compute_combined_polarization( + candidate_frames, + orientation_pol, + heading_pol, + total_frames, + ) + + visible_frames = [ + f for f in candidate_frames if combined_pol[f] >= MIN_VISUAL_POL + ] + low_vis_frames = [ + f for f in candidate_frames if combined_pol[f] < MIN_VISUAL_POL + ] + + n_total = len(candidate_frames) + n_vis = len(visible_frames) + print( + f"Candidate frames: {n_total} total, " + f"{n_vis} visually clear (>= {MIN_VISUAL_POL})" + ) + + selected: list = [] + _greedy_select_frames(visible_frames, combined_pol, selected, NUM_PANELS) + if len(selected) < NUM_PANELS: + _greedy_select_frames( + low_vis_frames, combined_pol, selected, NUM_PANELS + ) + + panel_frames = np.array(sorted(selected)) + prev_frames = panel_frames - 1 + + has_orientation = np.ones(NUM_PANELS, dtype=bool) + has_heading = np.ones(NUM_PANELS, dtype=bool) + + selected_pol = [combined_pol[f] for f in panel_frames] + n_visible = sum(1 for p in selected_pol if p >= MIN_VISUAL_POL) + + print(f"Selected sparse frames: {panel_frames}") + mvp = MIN_VISUAL_POL + print(f"Previous frames (f-1): {prev_frames}") + print(f"Polarization values: {[f'{p:.3f}' for p in selected_pol]}") + print(f"Visually clear panels: {n_visible}/{NUM_PANELS} (>= {mvp})") + print("Orientation valid: all True (guaranteed by selection)") + print("Heading valid: all True (guaranteed by selection)") + + return ( + panel_frames, + prev_frames, + False, + has_orientation, + has_heading, + 1, + NUM_PANELS, + ) + + +# 
── Segment scoring helpers ────────────────────────────────────────────────── + + +def _safe_unwrap(theta_arr): + """Unwrap angles, interpolating through NaN gaps.""" + t = theta_arr.copy() + v = ~np.isnan(t) + if np.sum(v) < 2: + return t + t[~v] = np.interp(np.where(~v)[0], np.where(v)[0], t[v]) + t = np.unwrap(t) + t[~v] = np.nan + return t + + +def _metric_scores(win): + """Vectorised range × monotonicity + mean for each window.""" + wmin = np.nanmin(win, axis=1) + wmax = np.nanmax(win, axis=1) + rng = wmax - wmin + delta = np.abs(win[:, -1] - win[:, 0]) + mono = np.where(rng > 0.01, delta / rng, 0.0) + mean = np.nanmean(win, axis=1) + return rng * (0.3 + 0.7 * mono) + 0.15 * mean + + +def _compute_sweep_bonus(ot_uw, ht_uw, ng, w): + """Compute the collective-sweep scoring bonus.""" + orientation_sweep = np.abs(ot_uw[w - 1 : ng] - ot_uw[: ng - w + 1]) / np.pi + head_sweep = np.abs(ht_uw[w - 1 : ng] - ht_uw[: ng - w + 1]) / np.pi + orientation_signed = ot_uw[w - 1 : ng] - ot_uw[: ng - w + 1] + head_signed = ht_uw[w - 1 : ng] - ht_uw[: ng - w + 1] + + mean_sw = np.nanmean(np.stack([orientation_sweep, head_sweep]), axis=0) + agreement = np.where( + orientation_signed * head_signed > 0, + 1.0, + np.where( + (np.abs(orientation_signed) < 0.1) | (np.abs(head_signed) < 0.1), + 0.6, + 0.3, + ), + ) + return 8.0 * mean_sw * agreement + + +def _compute_motion_bonus(min_disp, hp_win, md_win): + """Compute the collective-motion scoring bonus.""" + if min_disp is None: + return 0.0 + vd = min_disp[~np.isnan(min_disp)] + speed_scale = max(np.percentile(vd, 75), 1.0) if len(vd) > 0 else 1.0 + norm_speed = np.nanmean(md_win, axis=1) / speed_scale + mean_hpol = np.nanmean(hp_win, axis=1) + return 10.0 * mean_hpol * norm_speed + + +def _try_continuous_window( + w, + grid, + ng, + both_valid, + orientation_pol, + heading_pol, + heading_theta, + orientation_theta, + min_disp, +): + """Score all candidate windows of width *w* on the panel grid. 
def _log_continuous_segment(
    panel_frames,
    has_orientation,
    has_heading,
    orientation_pol,
    heading_pol,
    n_panels,
    start_frame,
    best_score,
):
    """Print diagnostics for a continuous segment that was found.

    Prints the segment summary and panel frames. Then, for orientation and
    (when ``heading_pol`` is not ``None``) heading, prints the NaN-aware
    min/max polarization over the valid panels and how many of them reach
    ``MIN_VISUAL_POL``.

    Parameters
    ----------
    panel_frames : ndarray of int
        Frame index of each panel.
    has_orientation, has_heading : ndarray of bool
        Per-panel validity masks.
    orientation_pol : ndarray
        Orientation polarization, indexed by frame.
    heading_pol : ndarray or None
        Heading polarization, indexed by frame.
    n_panels : int
        Number of panels in the segment.
    start_frame : int
        First frame of the segment.
    best_score : float
        Score of the winning window.

    """

    def report(title, series, mask):
        # Summarise one polarization series over the panels flagged valid.
        values = series[panel_frames][mask]
        clear_count = np.sum(values >= MIN_VISUAL_POL)
        low = np.nanmin(values)
        high = np.nanmax(values)
        print(
            f"{title} pol range: {low:.3f} to "
            f"{high:.3f} ({clear_count}/{len(values)} panels >= "
            f"{MIN_VISUAL_POL})"
        )

    print(
        f"\n{n_panels}-panel continuous segment starting at "
        f"frame {start_frame} (score={best_score:.3f})"
    )
    print(f"Panel frames: {panel_frames}")

    if np.any(has_orientation):
        report("Orientation", orientation_pol, has_orientation)

    if np.any(has_heading) and heading_pol is not None:
        report("Heading", heading_pol, has_heading)
+ + Returns + ------- + panel_frames, prev_frames, is_continuous, has_orientation, has_heading, + frame_interval, num_panels + + """ + if fps is None: + print("FPS unknown - using sparse frame selection") + return _find_sparse_frames( + orientation_pol, + heading_pol_1f, + orientation_valid, + vel_valid, + total_frames, + ) + + frame_interval = int(fps) + grid = np.arange(0, total_frames, frame_interval) + ng = len(grid) + + window_args = ( + grid, + ng, + both_valid, + orientation_pol, + heading_pol, + heading_theta, + orientation_theta, + min_disp, + ) + + best_i, best_score = _try_continuous_window(NUM_PANELS, *window_args) + n_panels = NUM_PANELS + + if best_i is None: + print( + f"\nNo full {NUM_PANELS}-panel segment found. " + "Searching for partial continuous..." + ) + for n_panels in range(NUM_PANELS - 1, MIN_CONTINUOUS_PANELS - 1, -1): + best_i, best_score = _try_continuous_window(n_panels, *window_args) + if best_i is not None: + print( + f"Found {n_panels}-panel segment (score={best_score:.3f})" + ) + break + + if best_i is not None: + panel_frames = grid[best_i] + np.arange(n_panels) * frame_interval + prev_frames = panel_frames - frame_interval + + has_orientation = np.array( + [orientation_valid[f] for f in panel_frames] + ) + has_heading = np.array( + [ + pos_valid[prev_frames[p]] and pos_valid[panel_frames[p]] + for p in range(n_panels) + ] + ) + + _log_continuous_segment( + panel_frames, + has_orientation, + has_heading, + orientation_pol, + heading_pol, + n_panels, + grid[best_i], + best_score, + ) + + return ( + panel_frames, + prev_frames, + True, + has_orientation, + has_heading, + frame_interval, + n_panels, + ) + + print( + f"\nNo continuous segment (>={MIN_CONTINUOUS_PANELS} panels) found. " + "Selecting sparse frames..." 
+ ) + return _find_sparse_frames( + orientation_pol, + heading_pol_1f, + orientation_valid, + vel_valid, + total_frames, + ) + + +# ── Polarization log helpers ───────────────────────────────────────────────── + + +def _log_header( + log, + video_filename, + fps, + frame_interval, + num_panels, + is_continuous, + auto_log, + node_config, + node_names, +): + """Write the log file header and configuration block.""" + log("POLARIZATION ANALYSIS LOG") + log(f"\nVideo: {video_filename}") + log(f"FPS: {fps if fps is not None else 'Unknown'}") + if fps is not None: + log(f"Frame interval: {frame_interval} frames (1 second)") + else: + log(f"Frame interval: {frame_interval} frames (estimated)") + log(f"Number of panels: {num_panels}") + log(f"Segment type: {'CONTINUOUS' if is_continuous else 'SPARSE'}") + + if auto_log is not None: + log("\n" + auto_log) + + log("\n") + log("FINAL NODE CONFIGURATION") + from_idx = node_config["orientation_from"] + to_idx = node_config["orientation_to"] + vel_idx = node_config["velocity_node"] + ap_ok = node_config.get("ap_validated", False) + + from_label = "(posterior)" if ap_ok else "(from_node)" + to_label = "(anterior)" if ap_ok else "(to_node)" + + log(f"\n Velocity node: {node_names[vel_idx]}[{vel_idx}]") + log( + f" Orientation FROM: {node_names[from_idx]}[{from_idx}] {from_label}" + ) + log(f" Orientation TO: {node_names[to_idx]}[{to_idx}] {to_label}") + log(" Fallback TO: None (strict mode)") + + log("\n") + log("POLARIZATION VERIFICATION") + log("\nCoordinate system notes:") + log(" - Image coords: y increases DOWNWARD (top=0)") + log(" - Polar plot: math convention (0=E, 90=N, CCW positive)") + log(" - Y-flip: dy = -(to_y - from_y) for image→Cartesian") + log(" - Heading from 1-second displacement (matches Row 1)") + + +def _compute_unit_vectors_and_log(coords, total, log, label_prefix): + """Compute unit vectors from coordinate pairs, logging each animal. 
+ + Parameters + ---------- + coords : list of tuples + Each element: (animal_idx, from_x, from_y, to_x, to_y). + total : int + Total number of animals. + log : callable + Logging function. + label_prefix : str + Prefix for log messages. + + Returns + ------- + list of (ux, uy) + Unit vectors for valid animals. + + """ + unit_vectors = [] + for a, fx, fy, tx, ty in coords: + if np.isnan(fx) or np.isnan(tx): + nan_row = ( + f" {a:<8} {'NaN':>8} {'NaN':>8} " + f"{'NaN':>8} {'NaN':>8} │ {'---':>8} " + f"{'---':>8} {'---':>8} │ {'---':>8}" + ) + log(nan_row) + continue + + dx = tx - fx + dy_img = ty - fy + dy_cart = -dy_img + length = np.hypot(dx, dy_cart) + + if length > 0: + angle_deg = np.degrees(np.arctan2(dy_cart, dx)) + unit_vectors.append((dx / length, dy_cart / length)) + row = ( + f" {a:<8} {fx:>8.1f} {fy:>8.1f} " + f"{tx:>8.1f} {ty:>8.1f} │ {dx:>8.1f} " + f"{dy_img:>8.1f} {dy_cart:>8.1f} │ {angle_deg:>8.1f}" + ) + log(row) + return unit_vectors + + +def _log_polarization_check( + log, unit_vectors, stored_pol, stored_theta, total_animals, label +): + """Log a polarization verification check.""" + if not unit_vectors: + return + + sum_ux = sum(u[0] for u in unit_vectors) + sum_uy = sum(u[1] for u in unit_vectors) + n = len(unit_vectors) + mean_ux, mean_uy = sum_ux / n, sum_uy / n + computed_pol = np.hypot(mean_ux, mean_uy) + computed_theta_deg = np.degrees(np.arctan2(mean_uy, mean_ux)) + + stored_theta_deg = ( + np.degrees(-stored_theta) if not np.isnan(stored_theta) else np.nan + ) + + log(f"\n {label} POLARIZATION CHECK:") + log(f" Mean unit vector: ({mean_ux:.4f}, {mean_uy:.4f})") + log( + f" Computed: pol={computed_pol:.4f}, " + f"theta={computed_theta_deg:.1f}°" + ) + log(f" Stored: pol={stored_pol:.4f}, theta={stored_theta_deg:.1f}°") + + if n < total_animals: + log(f" MATCH: SKIPPED ({n}/{total_animals} valid)") + elif np.isnan(stored_pol): + log(" MATCH: ✗ NO (stored is NaN unexpectedly)") + else: + match = np.isclose(computed_pol, stored_pol, 
atol=1e-6) + log(f" MATCH: {'✓ YES' if match else '✗ NO'}") + + +def _log_panel( + log, + p, + curr_f, + prev_f, + head_x, + head_y, + tail_x, + tail_y, + vel_x, + vel_y, + orientation_pol, + heading_pol, + orientation_theta, + heading_theta, + total_animals, + node_config, + node_names, +): + """Log a single panel's orientation and heading data.""" + from_idx = node_config["orientation_from"] + ap_ok = node_config.get("ap_validated", False) + from_name = node_names[from_idx] + to_name = node_names[node_config["orientation_to"]] + + log(f"\n{'─' * 80}") + log(f"PANEL {p} ({p}s) | Frame {curr_f} | Prev frame {prev_f}") + log(f"{'─' * 80}") + + # Orientation + log(f"\n ORIENTATION ({from_name} → {to_name}):") + col_from_x = "post_x" if ap_ok else "from_x" + col_from_y = "post_y" if ap_ok else "from_y" + col_to_x = "ant_x" if ap_ok else "to_x" + col_to_y = "ant_y" if ap_ok else "to_y" + hdr = ( + f" {'Animal':<8} {col_from_x:>8} {col_from_y:>8} " + f"{col_to_x:>8} {col_to_y:>8} │ {'dx':>8} " + f"{'dy_img':>8} {'dy_cart':>8} │ {'angle°':>8}" + ) + log(hdr) + sep = ( + f" {'-' * 8} {'-' * 8} {'-' * 8} {'-' * 8} {'-' * 8} │ " + f"{'-' * 8} {'-' * 8} {'-' * 8} │ {'-' * 8}" + ) + log(sep) + + orient_coords = [ + ( + a, + tail_x[curr_f, a], + tail_y[curr_f, a], + head_x[curr_f, a], + head_y[curr_f, a], + ) + for a in range(total_animals) + ] + orient_uvs = _compute_unit_vectors_and_log( + orient_coords, + total_animals, + log, + "ORIENTATION", + ) + _log_polarization_check( + log, + orient_uvs, + orientation_pol[curr_f], + orientation_theta[curr_f], + total_animals, + "ORIENTATION", + ) + + # Heading + log(f"\n HEADING/VELOCITY (frame {prev_f} to {curr_f}):") + hdr_h = ( + f" {'Animal':<8} {'prev_x':>8} {'prev_y':>8} " + f"{'curr_x':>8} {'curr_y':>8} │ {'dx':>8} " + f"{'dy_img':>8} {'dy_cart':>8} │ {'angle°':>8}" + ) + log(hdr_h) + log(sep) + + heading_coords = [ + ( + a, + vel_x[prev_f, a], + vel_y[prev_f, a], + vel_x[curr_f, a], + vel_y[curr_f, a], + ) + for a in 
def write_polarization_log(
    panel_frames,
    prev_frames,
    head_x,
    head_y,
    tail_x,
    tail_y,
    vel_x,
    vel_y,
    orientation_pol,
    panel_heading_pol,
    orientation_theta,
    panel_heading_theta,
    total_animals,
    node_config,
    node_names,
    auto_log,
    video_filename,
    fps,
    is_continuous,
    frame_interval,
    num_panels,
    output_file=None,
):
    """Write a detailed textual log of polarization computations.

    Every line is printed to stdout and collected. If ``output_file`` is
    given, the collected lines are also written there as UTF-8. The log
    has a header (``_log_header``), one section per panel
    (``_log_panel``) and a closing marker.

    Parameters
    ----------
    panel_frames : Sequence[int]
        Frame indices of the panels.
    prev_frames : Sequence[int]
        Previous-frame index paired with each panel.
    head_x, head_y : np.ndarray
        Arrays for the to_node keypoint, passed through to ``_log_panel``.
    tail_x, tail_y : np.ndarray
        Arrays for the from_node keypoint, passed through to
        ``_log_panel``.
    vel_x, vel_y : np.ndarray
        Arrays for the velocity node, passed through to ``_log_panel``.
    orientation_pol : np.ndarray
        Orientation polarization values per frame.
    panel_heading_pol : np.ndarray
        Heading polarization values per panel.
    orientation_theta : np.ndarray
        Orientation angles (radians) per frame.
    panel_heading_theta : np.ndarray
        Heading angles (radians) per panel.
    total_animals : int
        Number of individuals.
    node_config : dict
        Node configuration dictionary.
    node_names : list[str]
        Keypoint names.
    auto_log : str | None
        Optional auto-detection log text.
    video_filename : str
        Name of the video file.
    fps : int | None
        Frames per second.
    is_continuous : bool
        Whether a continuous segment was found.
    frame_interval : int
        Frames between panels.
    num_panels : int
        Number of panels.
    output_file : Path | str | None, optional
        Optional log file path. When ``None`` nothing is written to disk.

    """
    lines = []

    def log(text=""):
        # Echo to the console and keep a copy for the optional file.
        print(text)
        lines.append(text)

    _log_header(
        log,
        video_filename,
        fps,
        frame_interval,
        num_panels,
        is_continuous,
        auto_log,
        node_config,
        node_names,
    )

    for p, curr_f in enumerate(panel_frames):
        _log_panel(
            log,
            p,
            curr_f,
            prev_frames[p],
            head_x,
            head_y,
            tail_x,
            tail_y,
            vel_x,
            vel_y,
            orientation_pol,
            panel_heading_pol,
            orientation_theta,
            panel_heading_theta,
            total_animals,
            node_config,
            node_names,
        )

    log("\n")
    log("END LOG")

    if output_file is not None:
        # The log contains non-ASCII characters (box drawing, ✓/✗, °, →).
        # The platform default encoding (e.g. cp1252) may not be able to
        # encode them, so UTF-8 is set explicitly.
        with open(output_file, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
        print(f"Log saved to: {output_file}")
def _add_figure_title(fig, cfg):
    """Write the filename and, for sparse segments, an explanatory note.

    The filename is always placed at the top-left of the figure. When
    ``cfg.is_continuous`` is false, a three-line note describing the
    sparse-mode rendering is added just below it.
    """
    # Text placement shared by both annotations.
    anchor = {"fontsize": 7, "ha": "left", "va": "top"}

    fig.text(0.03, 0.975, cfg.video_filename, color=DARK_FG, **anchor)

    if cfg.is_continuous:
        return

    sparse_note = (
        "Sparse mode: Each panel shows frame f overlaid with f-1.\n"
        "Heading vectors show single-frame displacement.\n"
        "Polar plots use same consecutive frame pairs."
    )
    fig.text(
        0.03,
        0.955,
        sparse_note,
        color=[0.7, 0.7, 0.7],
        linespacing=1.3,
        **anchor,
    )
def _set_panel_title(ax, p, data, cfg, show_prev=False):
    """Title panel *p* with its time offset or frame number(s).

    Continuous segments are labelled ``"{p}s"``; sparse segments are
    labelled with the current frame, preceded by the previous frame when
    ``show_prev`` is true. The first and last panels additionally carry a
    ``"Frame:"`` / ``"Frames:"`` prefix and extra title padding.
    """
    curr_f = data.panel_frames[p]
    prev_f = data.prev_frames[p]
    on_edge = p in (0, cfg.num_panels - 1)

    # ``short`` is the interior-panel title; ``prefix`` is prepended on the
    # first and last panels only.
    if cfg.is_continuous:
        short, prefix = f"{p}s", f"Frame: {curr_f}\n"
    elif show_prev:
        short, prefix = f"{prev_f}, {curr_f}", "Frames: "
    else:
        short, prefix = f"{curr_f}", "Frame: "

    if not on_edge:
        ax.set_title(short, color=DARK_FG, fontsize=12)
        return

    # Only the two-frame edge title uses the smaller font.
    size = 11 if (show_prev and not cfg.is_continuous) else 12
    ax.set_title(prefix + short, color=DARK_FG, fontsize=size, pad=8)
def _render_orientation_row(
    fig, gs, data, cfg, frame_w, frame_h, current_row, show_heading_row
):
    """Draw the orientation row: its heading, marker key, and panels.

    The row heading and the square/circle marker key are written as figure
    text at a height that depends on whether a heading row sits above this
    one. Each panel is then rendered and styled; panel titles are only set
    here when there is no heading row.
    """
    node_cfg = cfg.node_config
    from_node_name = cfg.node_names[node_cfg["orientation_from"]]
    to_node_name = cfg.node_names[node_cfg["orientation_to"]]

    title_y = 0.96
    if show_heading_row:
        title_y = 0.62
    sub_y = title_y - 0.02

    # The key only says Posterior/Anterior once the pair is AP-validated.
    if node_cfg.get("ap_validated", False):
        from_sub, to_sub = "Posterior", "Anterior"
    else:
        from_sub, to_sub = "(from_node)", "(to_node)"

    heading_font = {"fontsize": 12, "fontweight": "bold"}
    annotations = (
        (0.355, title_y, "Orientation Within Frame", heading_font),
        (
            0.560,
            title_y,
            f"({from_node_name} → {to_node_name})",
            heading_font,
        ),
        (0.82, title_y, f"□ = {from_node_name}", {"fontsize": 10}),
        (0.82, sub_y, f" {from_sub}", {"fontsize": 9}),
        (0.90, title_y, f"○ = {to_node_name}", {"fontsize": 10}),
        (0.90, sub_y, f" {to_sub}", {"fontsize": 9}),
    )
    for x, y, text, font in annotations:
        fig.text(x, y, text, color=DARK_FG, ha="left", va="bottom", **font)

    for p in range(cfg.num_panels):
        ax = fig.add_subplot(gs[current_row, p])
        _render_orientation_panel(ax, p, data, cfg)
        _style_panel_ax(ax, frame_w, frame_h)
        if not show_heading_row:
            _set_panel_title(ax, p, data, cfg, show_prev=False)
np.nan + b_theta = ( + -data.orientation_theta[curr_f] if data.has_orientation[p] else np.nan + ) + h_pol = data.heading_pol[p] if data.has_heading[p] else np.nan + h_theta = -data.heading_theta[p] if data.has_heading[p] else np.nan + + if ( + data.has_orientation[p] + and not np.isnan(b_pol) + and not np.isnan(b_theta) + ): + ax.annotate( + "", + xy=(b_theta, b_pol), + xytext=(b_theta, 0), + arrowprops={ + "arrowstyle": "->", + "color": NET_ORIENTATION_COLOR, + "lw": 2.5, + "mutation_scale": 12, + }, + ) + + if data.has_heading[p] and not np.isnan(h_pol) and not np.isnan(h_theta): + ax.annotate( + "", + xy=(h_theta, h_pol), + xytext=(h_theta, 0), + arrowprops={ + "arrowstyle": "->", + "color": NET_HEADING_COLOR, + "lw": 2.5, + "mutation_scale": 12, + }, + ) + + if data.has_orientation[p] and not np.isnan(b_pol): + ax.text( + 0.5, + 1.32, + f"$p_b$={b_pol:.2f}", + transform=ax.transAxes, + color=NET_ORIENTATION_COLOR, + fontsize=9, + fontweight="bold", + ha="center", + va="center", + ) + if data.has_heading[p] and not np.isnan(h_pol): + ax.text( + 0.5, + 1.20, + f"$p_h$={h_pol:.2f}", + transform=ax.transAxes, + color=NET_HEADING_COLOR, + fontsize=9, + fontweight="bold", + ha="center", + va="center", + ) + + +def _get_polar_labels(p, num_panels): + """Return theta grid labels for polar panel *p*.""" + all_angles = [0, 45, 90, 135, 180, 225, 270, 315] + if p == 0: + labels = ["", "45°", "90°", "135°", "180°", "225°", "270°", "315°"] + elif p == num_panels - 1: + labels = ["0°", "45°", "90°", "135°", "", "225°", "270°", "315°"] + else: + labels = ["", "45°", "90°", "135°", "", "225°", "270°", "315°"] + return all_angles, labels + + +def _render_polar_row(fig, gs, data, cfg, current_row): + """Render the polar plots row.""" + for p in range(cfg.num_panels): + ax = fig.add_subplot(gs[current_row, p], projection="polar") + _render_polar_panel(ax, p, data) + all_angles, labels = _get_polar_labels(p, cfg.num_panels) + ax.set_thetagrids(all_angles, labels=labels) + + +def 
def create_figure(data: FigureData, cfg: FigureConfig):
    """Create the static figure with frame panels and polar plots.

    The figure has one row per available quantity (heading and/or
    orientation) followed by a row of polar plots.

    Parameters
    ----------
    data : FigureData
        Container holding all data arrays for rendering.
    cfg : FigureConfig
        Container holding configuration and metadata.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure object containing all subplots.

    Raises
    ------
    ValueError
        If neither heading nor orientation data is available for any
        panel, or if the frame dimensions cannot be determined.

    """
    show_heading_row = np.any(data.has_heading)
    show_orientation_row = np.any(data.has_orientation)

    if show_heading_row and show_orientation_row:
        height_ratios = [1.2, 1.2, 1]
    elif show_heading_row or show_orientation_row:
        height_ratios = [1.2, 1]
    else:
        raise ValueError("No valid data to visualize")
    n_rows = len(height_ratios)

    # Resolve the frame size before any figure exists, so a failure here
    # cannot leave an orphaned pyplot figure behind.
    frame_w, frame_h = _get_frame_dimensions(cfg, data)

    fig_width = 2.5 * cfg.num_panels
    fig_height = 11 if n_rows == 3 else 8
    fig = plt.figure(figsize=(fig_width, fig_height), facecolor=DARK_BG)
    try:
        gs = gridspec.GridSpec(
            n_rows,
            cfg.num_panels,
            height_ratios=height_ratios,
            hspace=0.12,
            wspace=0.08,
            left=0.03,
            right=0.97,
            top=0.94,
            bottom=0.04,
        )

        _add_figure_title(fig, cfg)

        current_row = 0

        if show_heading_row:
            _render_heading_row(
                fig, gs, data, cfg, frame_w, frame_h, current_row
            )
            current_row += 1

        if show_orientation_row:
            _render_orientation_row(
                fig,
                gs,
                data,
                cfg,
                frame_w,
                frame_h,
                current_row,
                show_heading_row,
            )
            current_row += 1

        _render_polar_row(fig, gs, data, cfg, current_row)

        _render_legend_and_footer(
            fig, cfg, show_heading_row, show_orientation_row
        )
    except Exception:
        # pyplot keeps figures registered until they are closed; release
        # this one before propagating the error.
        plt.close(fig)
        raise

    return fig
def _load_needed_video_frames(video_file, needed, total_frames):
    """Load only the needed video frames into a dict.

    Parameters
    ----------
    video_file : Path | None
        Video to read from; ``None`` means no video is available.
    needed : Iterable[int]
        Frame indices to load. Indices outside the video are skipped.
    total_frames : int
        Number of frames in the tracking data. If the video reports fewer
        frames than this, no frames are loaded.

    Returns
    -------
    tuple[dict, bool]
        A mapping from frame index to an RGB image array, and a flag that
        is true when at least one frame was loaded.

    """
    frames = {}
    video_available = False
    if video_file is None:
        print("\nNo video file - frames will not be displayed")
        return frames, video_available

    print(f"\nLoading selected video frames from: {video_file.name}")
    cap = cv2.VideoCapture(str(video_file))
    try:
        if not cap.isOpened():
            # Report the real cause instead of the frame-count warning
            # that a zero frame count would otherwise trigger.
            print(f" Warning: Could not open video: {video_file.name}")
            print(" Video frames will not be displayed")
            return frames, video_available

        video_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if video_total < total_frames:
            print(
                f" Warning: Video has fewer frames ({video_total}) "
                f"than tracking data ({total_frames})"
            )
            print(" Video frames will not be displayed")
            return frames, video_available

        for f in sorted(needed):
            if f < 0 or f >= video_total:
                continue
            cap.set(cv2.CAP_PROP_POS_FRAMES, f)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; store RGB for display.
                frames[f] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        video_available = len(frames) > 0
        print(f" Loaded {len(frames)} frames (of {video_total} total)")
        return frames, video_available
    finally:
        # Release the capture on every path, including when reading or
        # colour conversion raises.
        cap.release()
+ + """ + ds, video_file, video_metadata, slp_filename, node_config = load_data( + dataset_idx, node_config=node_config + ) + + overlay_alpha = _determine_overlay_alpha(video_file) + + node_names = list(ds.keypoints.values) + total_animals = ds.sizes["individuals"] + total_frames = ds.sizes["time"] + fps = ds.attrs.get("fps") + + from_keypoint = node_names[node_config["orientation_from"]] + to_keypoint = node_names[node_config["orientation_to"]] + velocity_keypoint = node_names[node_config["velocity_node"]] + + if not node_config.get("ap_validated", False): + from_keypoint, to_keypoint = _apply_ap_corrections( + ds, + node_config, + node_names, + from_keypoint, + to_keypoint, + ) + + # Phase 1: Validity masks + print("\nComputing validity masks...") + ( + from_pos, + to_pos, + vel_pos_all, + is_valid, + is_pos_valid, + orientation_valid, + vel_valid, + pos_valid, + both_valid, + ) = _compute_validity_masks( + ds, from_keypoint, to_keypoint, velocity_keypoint + ) + + # Phase 2: Polarization + print("Computing polarization metrics...") + orientation_pol_da, orientation_theta_da = compute_polarization( + ds.position, + body_axis_keypoints=(from_keypoint, to_keypoint), + return_angle=True, + validate_ap=False, + ) + heading_pol_da, heading_theta_da = compute_polarization( + ds.position.sel(keypoints=velocity_keypoint), + displacement_frames=1, + return_angle=True, + ) + orientation_pol = orientation_pol_da.values + orientation_theta = orientation_theta_da.values + heading_pol = heading_pol_da.values + heading_theta = heading_theta_da.values + + # Phase 3: Segment selection + print("Finding best segment...") + fps_for_segment, fps_heading_pol, fps_heading_theta, min_disp_full = ( + _compute_fps_heading( + ds, velocity_keypoint, vel_pos_all, fps, total_frames + ) + ) + + segment_result = find_best_segment( + orientation_pol, + orientation_theta, + heading_pol, + fps_heading_pol, + fps_heading_theta, + min_disp_full, + orientation_valid, + vel_valid, + pos_valid, + 
both_valid, + total_frames, + fps_for_segment, + ) + ( + panel_frames, + prev_frames, + is_continuous, + has_orientation, + has_heading, + frame_interval, + num_panels, + ) = segment_result + + if is_continuous: + panel_heading_pol = fps_heading_pol[panel_frames] + panel_heading_theta = fps_heading_theta[panel_frames] + else: + print( + "Using precomputed heading polarization (1-frame displacement)..." + ) + panel_heading_pol = heading_pol[panel_frames] + panel_heading_theta = heading_theta[panel_frames] + + # Phase 4: Extract per-animal coordinates + print("Extracting per-animal coordinates for selected frames...") + needed = np.unique(np.concatenate([panel_frames, prev_frames])) + + head_x, head_y, tail_x, tail_y = _extract_orientation_coords( + needed, + total_frames, + total_animals, + is_valid, + from_pos, + to_pos, + ) + vel_x, vel_y = _extract_velocity_coords( + needed, + total_frames, + total_animals, + is_pos_valid, + vel_pos_all, + ) + + frames, video_available = _load_needed_video_frames( + video_file, + needed, + total_frames, + ) + + output_dir = ROOT_PATH / "exports" / "polarization-demos" + figures_dir = output_dir / "figures" + output_dir.mkdir(parents=True, exist_ok=True) + figures_dir.mkdir(exist_ok=True) + + base_stem = Path(slp_filename).stem + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + video_filename = ( + video_file.name + if video_file is not None + else f"{slp_filename} (no video)" + ) + write_polarization_log( + panel_frames, + prev_frames, + head_x, + head_y, + tail_x, + tail_y, + vel_x, + vel_y, + orientation_pol, + panel_heading_pol, + orientation_theta, + panel_heading_theta, + total_animals, + node_config, + node_names, + node_config.get("auto_log"), + video_filename, + fps, + is_continuous, + frame_interval, + num_panels, + ) + + print("\nCreating figure...") + fig_data = FigureData( + frames=frames, + panel_frames=panel_frames, + prev_frames=prev_frames, + head_x=head_x, + head_y=head_y, + tail_x=tail_x, + 
def _compute_fps_heading(
    ds, velocity_keypoint, vel_pos_all, fps, total_frames
):
    """Compute heading polarization over an ``int(fps)``-frame displacement.

    Returns ``(fps, polarization, angle, min_displacement)``. The last
    element has length ``total_frames``: it is NaN for the first
    ``int(fps)`` frames and otherwise holds the smallest displacement
    magnitude across axis 1 of the per-frame magnitudes. When ``fps`` is
    ``None`` all four elements are ``None``.
    """
    if fps is None:
        return None, None, None, None

    lag = int(fps)
    print(f" Computing {lag}-frame heading for segment scoring...")
    pol_da, theta_da = compute_polarization(
        ds.position.sel(keypoints=velocity_keypoint),
        displacement_frames=lag,
        return_angle=True,
    )

    # Displacement of the velocity keypoint between frames ``lag`` apart.
    step = vel_pos_all[lag:] - vel_pos_all[:-lag]
    # nansum counts NaN components as zero in the magnitude.
    step_len = np.sqrt(np.nansum(step**2, axis=1))

    min_disp_full = np.full(total_frames, np.nan)
    min_disp_full[lag:] = np.nanmin(step_len, axis=1)

    return fps, pol_da.values, theta_da.values, min_disp_full
tail_x[f, mask] = from_pos[f, 0, mask] + tail_y[f, mask] = from_pos[f, 1, mask] + head_x[f, mask] = to_pos[f, 0, mask] + head_y[f, mask] = to_pos[f, 1, mask] + return head_x, head_y, tail_x, tail_y + + +def _extract_velocity_coords( + needed, total_frames, total_animals, is_pos_valid, vel_pos_all +): + """Extract velocity node coordinates for the needed frames.""" + vel_x = np.full((total_frames, total_animals), np.nan) + vel_y = np.full((total_frames, total_animals), np.nan) + for f in needed: + if f < 0 or f >= total_frames: + continue + mask = is_pos_valid[f] + vel_x[f, mask] = vel_pos_all[f, 0, mask] + vel_y[f, mask] = vel_pos_all[f, 1, mask] + return vel_x, vel_y + + +def _save_figure(fig, figures_dir, base_name, fps): + """Save the figure in configured formats.""" + if SAVE_PNG: + png_file = figures_dir / (base_name + ".png") + fig.savefig( + png_file, + dpi=900, + facecolor=DARK_BG, + edgecolor="none", + bbox_inches="tight", + pad_inches=0.02, + ) + print(f"\nSaved PNG: {png_file}") + + if SAVE_SVG: + svg_file = figures_dir / (base_name + ".svg") + fig.savefig( + svg_file, + format="svg", + facecolor=DARK_BG, + edgecolor="none", + bbox_inches="tight", + pad_inches=0, + ) + print(f"Saved SVG: {svg_file}") + + +# ── Main entry point helpers ───────────────────────────────────────────────── + + +def _load_h5_config_interactive(): + """Attempt to load config from H5 file with user confirmation.""" + print("\n") + print("AP VALIDATION H5 DETECTED") + h5_config = load_suggested_pairs_from_h5() + if h5_config is None: + return None, False + + print() + if _prompt_yn("Use suggested node pairs from H5 file? 
def main():
    """Run the polarization visualization from the command line.

    Ensures the demo datasets are present, mirrors console output to a
    timestamped log file, and then either processes every dataset in the
    batch configuration or falls back to interactive selection of a single
    dataset. The batch configuration comes from ``BATCH_CONFIG`` or, when
    that is ``None`` and ``USE_AP_VALIDATION_H5`` is set, from the AP
    validation H5 file after user confirmation.
    """
    ensure_demo_datasets()

    logs_dir = ROOT_PATH / "exports" / "polarization-demos" / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    console_log = logs_dir / f"polarization_viz_{stamp}.log"

    with TeeOutput(console_log):
        print(f"Polarization Viz started at {datetime.now().isoformat()}")
        print(f"Console log: {console_log}")
        print()

        # A hard-coded batch configuration wins over the H5 suggestions.
        if BATCH_CONFIG is not None:
            batch_config = BATCH_CONFIG
            ap_validated = BATCH_CONFIG_AP_VALIDATED
        elif USE_AP_VALIDATION_H5:
            batch_config, ap_validated = _load_h5_config_interactive()
        else:
            batch_config, ap_validated = None, False

        if batch_config is None:
            _run_interactive_mode()
        else:
            _run_batch_mode(batch_config, ap_validated)

    print(f"\nConsole log saved to: {console_log}")
@dataclass
class ValidateAPConfig:
    """Configuration for the validate_ap function.

    All values are validated on construction; an invalid value raises
    ``ValueError`` naming the offending parameter.

    Parameters
    ----------
    min_valid_frac : float, default=0.6
        Minimum fraction of keypoints that must be present for a frame
        to qualify as tier-1 valid. Must lie in [0, 1].
    window_len : int, default=50
        Number of speed samples per sliding window. Positive integer.
    stride : int, default=5
        Step size between consecutive sliding window start positions.
        Positive integer.
    pct_thresh : float, default=85.0
        Percentile threshold applied to valid-window median speeds for
        high-motion classification. Must lie in [0, 100].
    min_run_len : int, default=1
        Minimum number of consecutive qualifying windows required to
        form a valid run. Positive integer.
    postural_var_ratio_thresh : float, default=2.0
        Between-segment to within-segment RMSD variance ratio above which
        postural clustering is triggered. Must be positive.
    max_clusters : int, default=4
        Upper bound on the number of clusters to evaluate during k-medoids.
        Positive integer.
    confidence_floor : float, default=0.1
        Vote margin below which the anterior inference is flagged as
        unreliable. Must lie in [0, 1].
    lateral_thresh : float, default=0.4
        Normalized lateral offset ceiling for the Step 1 lateral alignment
        filter. Must lie in [0, 1].
    edge_thresh : float, default=0.3
        Normalized midpoint distance floor for the Step 3 distal/proximal
        classification. Must lie in [0, 1].

    Raises
    ------
    ValueError
        If any parameter is outside its allowed range or, for the integer
        parameters, is not an integer (booleans are rejected).

    """

    min_valid_frac: float = 0.6
    window_len: int = 50
    stride: int = 5
    pct_thresh: float = 85.0
    min_run_len: int = 1
    postural_var_ratio_thresh: float = 2.0
    max_clusters: int = 4
    confidence_floor: float = 0.1
    lateral_thresh: float = 0.4
    edge_thresh: float = 0.3

    def __post_init__(self) -> None:
        """Validate configuration parameters."""
        for name in (
            "min_valid_frac",
            "confidence_floor",
            "lateral_thresh",
            "edge_thresh",
        ):
            value = getattr(self, name)
            # The chained comparison is False for NaN, so NaN is rejected.
            if not (0 <= value <= 1):
                raise ValueError(
                    f"{name} must be between 0 and 1, got {value}"
                )

        for name in ("window_len", "stride", "min_run_len", "max_clusters"):
            value = getattr(self, name)
            # bool is a subclass of int, so it must be excluded explicitly;
            # NumPy integer scalars are accepted alongside Python ints.
            is_integer = isinstance(
                value, (int, np.integer)
            ) and not isinstance(value, bool)
            if not is_integer or value <= 0:
                raise ValueError(
                    f"{name} must be a positive integer, got {value}"
                )

        if not (0 <= self.pct_thresh <= 100):
            raise ValueError(
                f"pct_thresh must be between 0 and 100, got {self.pct_thresh}"
            )

        # Written as ``not (x > 0)`` so that NaN is rejected as well.
        if not (self.postural_var_ratio_thresh > 0):
            raise ValueError(
                f"postural_var_ratio_thresh must be positive, "
                f"got {self.postural_var_ratio_thresh}"
            )
def _empty_node_array() -> np.ndarray:
    """Return an empty 1-D integer array (no node indices)."""
    return np.array([], dtype=int)


def _empty_pair_array() -> np.ndarray:
    """Return an empty integer array of shape (0, 2) (no node pairs)."""
    return np.zeros((0, 2), dtype=int)


def _empty_value_array() -> np.ndarray:
    """Return an empty 1-D array with NumPy's default dtype."""
    return np.array([])


@dataclass
class APNodePairReport:
    """Results of evaluating a candidate anterior-posterior keypoint pair.

    Collects everything produced by the 3-step filter cascade. A freshly
    constructed report is "empty": flags are false, strings are blank,
    scalars are 0 or NaN, and arrays have no elements.

    Attributes
    ----------
    success : bool
        Whether the evaluation pipeline completed successfully.
    failure_step : str
        Name of the step at which evaluation failed, if any.
    failure_reason : str
        Reason for failure, if any.
    scenario : int
        Scenario number (1-13) from the mutually exclusive outcomes.
    outcome : str
        Either "accept" or "warn".
    warning_message : str
        Warning message, if applicable.
    sorted_candidate_nodes : np.ndarray
        Candidate node indices after Step 1 filtering, sorted by
        ascending normalized lateral offset.
    valid_pairs : np.ndarray
        Shape (n_pairs, 2); valid node pairs after Step 2 filtering.
    valid_pairs_internode_dist : np.ndarray
        Internode separation (AP distance) for each valid pair.
    input_pair_in_candidates : bool
        Whether the input pair survived Step 1 filtering.
    input_pair_opposite_sides : bool
        Whether the input pair lies on opposite sides of the midpoint.
    input_pair_separation_abs : float
        Absolute AP separation of the input pair.
    input_pair_is_distal : bool
        Whether the input pair is classified as distal in Step 3.
    input_pair_rank : int
        Rank of the input pair by internode separation (1 = largest).
    input_pair_order_matches_inference : bool
        True when from_node has a lower AP coordinate than to_node
        (from_node is more posterior), i.e. the input ordering agrees
        with the inferred AP axis.
    pc1_coords : np.ndarray
        PC1 coordinates for each keypoint.
    ap_coords : np.ndarray
        AP (anterior-posterior) coordinates for each keypoint.
    lateral_offsets : np.ndarray
        Unsigned lateral offset from the body axis for each keypoint.
    lateral_offsets_norm : np.ndarray
        Normalized lateral offsets (0 = nearest to axis, 1 = farthest).
    lateral_offset_min : float
        Minimum lateral offset among valid keypoints.
    lateral_offset_max : float
        Maximum lateral offset among valid keypoints.
    midpoint_pc1 : float
        AP reference midpoint (average of min and max PC1 projections).
    pc1_min : float
        Minimum PC1 projection among valid keypoints.
    pc1_max : float
        Maximum PC1 projection among valid keypoints.
    midline_dist_norm : np.ndarray
        Normalized distance from the midpoint for each keypoint.
    midline_dist_max : float
        Maximum absolute distance from the midpoint.
    distal_pairs : np.ndarray
        Distal pairs (both nodes at or above edge_thresh).
    proximal_pairs : np.ndarray
        Proximal pairs (at least one node below edge_thresh).
    max_separation_distal_nodes : np.ndarray
        Node indices of the maximum-separation distal pair; element 0 is
        posterior (lower AP coord), element 1 is anterior (higher AP
        coord).
    max_separation_distal : float
        Internode separation of the max-separation distal pair.
    max_separation_nodes : np.ndarray
        Node indices of the overall maximum-separation pair; element 0 is
        posterior (lower AP coord), element 1 is anterior (higher AP
        coord).
    max_separation : float
        Internode separation of the overall max-separation pair.

    """

    # Pipeline status
    success: bool = False
    failure_step: str = ""
    failure_reason: str = ""
    scenario: int = 0
    outcome: str = ""
    warning_message: str = ""

    # Step 1 / Step 2 survivors
    sorted_candidate_nodes: np.ndarray = field(
        default_factory=_empty_node_array
    )
    valid_pairs: np.ndarray = field(default_factory=_empty_pair_array)
    valid_pairs_internode_dist: np.ndarray = field(
        default_factory=_empty_value_array
    )

    # How the user-supplied pair fared
    input_pair_in_candidates: bool = False
    input_pair_opposite_sides: bool = False
    input_pair_separation_abs: float = np.nan
    input_pair_is_distal: bool = False
    input_pair_rank: int = 0
    input_pair_order_matches_inference: bool = False

    # Per-keypoint geometry
    pc1_coords: np.ndarray = field(default_factory=_empty_value_array)
    ap_coords: np.ndarray = field(default_factory=_empty_value_array)
    lateral_offsets: np.ndarray = field(default_factory=_empty_value_array)
    lateral_offsets_norm: np.ndarray = field(
        default_factory=_empty_value_array
    )
    lateral_offset_min: float = np.nan
    lateral_offset_max: float = np.nan
    midpoint_pc1: float = np.nan
    pc1_min: float = np.nan
    pc1_max: float = np.nan
    midline_dist_norm: np.ndarray = field(default_factory=_empty_value_array)
    midline_dist_max: float = np.nan

    # Step 3 classification and best pairs
    distal_pairs: np.ndarray = field(default_factory=_empty_pair_array)
    proximal_pairs: np.ndarray = field(default_factory=_empty_pair_array)
    max_separation_distal_nodes: np.ndarray = field(
        default_factory=_empty_node_array
    )
    max_separation_distal: float = np.nan
    max_separation_nodes: np.ndarray = field(
        default_factory=_empty_node_array
    )
    max_separation: float = np.nan
def compute_bbox_centroid(
    keypoints: np.ndarray,
    tier1_valid: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return bounding-box and arithmetic centroids for tier-1 valid frames.

    The bounding-box centroid is the centre of the axis-aligned box that
    encloses every present (non-NaN) keypoint of a frame. Unlike the
    arithmetic mean it does not shift towards densely labelled body parts.

    Parameters
    ----------
    keypoints : np.ndarray
        Keypoint positions with shape (n_frames, n_keypoints, 2).
    tier1_valid : np.ndarray
        Boolean array of shape (n_frames,) flagging tier-1 valid frames.

    Returns
    -------
    bbox_centroids : np.ndarray
        Array of shape (n_frames, 2); NaN rows for frames that are not
        tier-1 valid.
    arith_centroids : np.ndarray
        Array of shape (n_frames, 2) with the mean of the present
        keypoints; NaN rows for frames that are not tier-1 valid.
    centroid_discrepancy : np.ndarray
        Array of shape (n_frames,) holding the distance between the two
        centroids divided by the bounding-box diagonal (0.0 when the
        diagonal is zero); NaN for frames that are not tier-1 valid.

    """
    n_frames = keypoints.shape[0]

    bbox_centroids = np.full((n_frames, 2), np.nan)
    arith_centroids = np.full((n_frames, 2), np.nan)
    centroid_discrepancy = np.full(n_frames, np.nan)

    for frame in range(n_frames):
        if tier1_valid[frame]:
            points = keypoints[frame]
            # Drop keypoints with a missing x or y coordinate.
            points = points[~np.isnan(points).any(axis=1)]

            lower = np.min(points, axis=0)
            upper = np.max(points, axis=0)
            bbox_centroids[frame] = (lower + upper) / 2
            arith_centroids[frame] = np.mean(points, axis=0)

            diagonal = np.linalg.norm(upper - lower)
            if diagonal > 0:
                offset = np.linalg.norm(
                    bbox_centroids[frame] - arith_centroids[frame]
                )
                centroid_discrepancy[frame] = offset / diagonal
            else:
                # All present keypoints coincide: both centroids agree.
                centroid_discrepancy[frame] = 0.0

    return bbox_centroids, arith_centroids, centroid_discrepancy
def compute_sliding_window_medians(
    speeds: np.ndarray,
    window_len: int,
    stride: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the median speed of every fully valid sliding window.

    Windows start at 0, ``stride``, ``2 * stride``, ... and each covers
    ``window_len`` consecutive speed samples. A window counts as valid only
    if none of its samples is NaN; invalid windows keep a NaN median.

    Parameters
    ----------
    speeds : np.ndarray
        Array of shape (n_speed_samples,) with speed values.
    window_len : int
        Number of speed samples per sliding window.
    stride : int
        Step size between consecutive window start positions.

    Returns
    -------
    window_starts : np.ndarray
        Array of window start indices (0-indexed). Empty when fewer than
        ``window_len`` samples are available.
    window_medians : np.ndarray
        Median speed for each window. NaN for invalid windows.
    window_all_valid : np.ndarray
        Boolean array indicating which windows are fully valid.

    """
    window_starts = np.arange(0, len(speeds) - window_len + 1, stride)
    n_windows = len(window_starts)

    window_medians = np.full(n_windows, np.nan)
    window_all_valid = np.zeros(n_windows, dtype=bool)

    for idx, start in enumerate(window_starts):
        chunk = speeds[start : start + window_len]
        if np.isnan(chunk).any():
            # A single missing sample invalidates the whole window.
            continue
        window_all_valid[idx] = True
        window_medians[idx] = np.median(chunk)

    return window_starts, window_medians, window_all_valid
def detect_runs(
    high_motion: np.ndarray,
    min_run_len: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Locate maximal runs of consecutive high-motion windows.

    Parameters
    ----------
    high_motion : np.ndarray
        Boolean array indicating high-motion windows.
    min_run_len : int
        Minimum number of consecutive qualifying windows for a valid run.

    Returns
    -------
    run_starts : np.ndarray
        Start indices of the runs that are at least ``min_run_len`` long.
    run_ends : np.ndarray
        End indices (inclusive) of those runs.
    run_lengths : np.ndarray
        Length of each returned run.

    """
    # Pad with False on both sides so runs touching either end of the
    # array still produce a rising and a falling edge.
    flags = np.concatenate([[False], high_motion, [False]]).astype(int)
    edges = np.diff(flags)

    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1) - 1
    lengths = ends - starts + 1

    long_enough = lengths >= min_run_len
    return starts[long_enough], ends[long_enough], lengths[long_enough]
def filter_segments_tier2(
    segments: np.ndarray,
    tier2_valid: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Filter segment frames to retain only tier-2 valid frames.

    Every frame covered by at least one segment (both endpoints inclusive)
    is collected once, frames that are not tier-2 valid are dropped, and
    each surviving frame is tagged with the index of the first segment
    that contains it.

    Parameters
    ----------
    segments : np.ndarray
        Array of shape (n_segments, 2) with [frame_start, frame_end].
        May be empty.
    tier2_valid : np.ndarray
        Boolean array of shape (n_frames,) indicating tier-2 valid frames.

    Returns
    -------
    selected_frames : np.ndarray
        Sorted integer array of tier-2 valid frame indices within segments.
        Empty (integer dtype) when there are no segments.
    selected_seg_id : np.ndarray
        Segment ID (0-indexed) for each selected frame.

    """
    segments = np.asarray(segments, dtype=int).reshape(-1, 2)
    if len(segments) == 0:
        # np.unique([]) would yield a float64 array, which cannot be used
        # to index tier2_valid; return well-typed empty results instead.
        return np.array([], dtype=int), np.array([], dtype=int)

    covered = np.unique(
        np.concatenate(
            [np.arange(start, end + 1) for start, end in segments]
        )
    )
    selected_frames = covered[tier2_valid[covered]]

    selected_seg_id = np.zeros(len(selected_frames), dtype=int)
    assigned = np.zeros(len(selected_frames), dtype=bool)
    for seg_index, (start, end) in enumerate(segments):
        # Only claim frames no earlier segment has claimed, so overlapping
        # segments resolve to the lowest segment index.
        inside = (
            ~assigned & (selected_frames >= start) & (selected_frames <= end)
        )
        selected_seg_id[inside] = seg_index
        assigned |= inside

    return selected_frames, selected_seg_id
def compute_pairwise_rmsd(centered_skeletons: np.ndarray) -> np.ndarray:
    """Build the symmetric matrix of RMSDs between centered skeletons.

    Each skeleton is flattened to a vector; the RMSD of two skeletons is
    the square root of the mean squared entry-wise difference of their
    vectors.

    Parameters
    ----------
    centered_skeletons : np.ndarray
        Array of shape (num_selected, n_keypoints, 2).

    Returns
    -------
    rmsd_matrix : np.ndarray
        Symmetric matrix of shape (num_selected, num_selected) with
        pairwise RMSD values. Diagonal is zero.

    """
    n_skeletons = len(centered_skeletons)
    flat = centered_skeletons.reshape(n_skeletons, -1)
    rmsd_matrix = np.zeros((n_skeletons, n_skeletons))

    # Visit each unordered pair once and mirror the value.
    for a, b in itertools.combinations(range(n_skeletons), 2):
        delta = flat[a] - flat[b]
        value = np.sqrt(np.mean(delta**2))
        rmsd_matrix[a, b] = value
        rmsd_matrix[b, a] = value

    return rmsd_matrix
+ Returns 0.0 if either distribution is empty or within variance is 0. + within_rmsds : np.ndarray + Array of within-segment RMSD values. + between_rmsds : np.ndarray + Array of between-segment RMSD values. + var_ratio_override : bool + True if variance ratio was set to 0 due to edge cases. + + """ + num_selected = len(selected_seg_id) + within_rmsds_list: list[float] = [] + between_rmsds_list: list[float] = [] + + for i in range(num_selected): + for j in range(i + 1, num_selected): + if selected_seg_id[i] == selected_seg_id[j]: + within_rmsds_list.append(rmsd_matrix[i, j]) + else: + between_rmsds_list.append(rmsd_matrix[i, j]) + + within_rmsds = np.array(within_rmsds_list) + between_rmsds = np.array(between_rmsds_list) + + var_ratio_override = False + if ( + len(within_rmsds) > 0 + and len(between_rmsds) > 0 + and np.var(within_rmsds) > 0 + ): + var_ratio = np.var(between_rmsds) / np.var(within_rmsds) + else: + var_ratio = 0.0 + var_ratio_override = True + + return var_ratio, within_rmsds, between_rmsds, var_ratio_override + + +# Clustering (k-medoids) +# ─────────────────────── + + +def _update_medoid_for_cluster( + cluster: int, + labels: np.ndarray, + medoids: np.ndarray, + dist_matrix: np.ndarray, +) -> int: + """Find the optimal medoid for a single cluster.""" + cluster_mask = labels == cluster + if not np.any(cluster_mask): + return medoids[cluster] + + cluster_indices = np.nonzero(cluster_mask)[0] + cluster_dists = dist_matrix[np.ix_(cluster_indices, cluster_indices)] + total_dists = np.sum(cluster_dists, axis=1) + best_idx = np.argmin(total_dists) + return cluster_indices[best_idx] + + +def kmedoids( + data: np.ndarray, + k: int, + max_iter: int = 100, + n_init: int = 5, + random_state: int | None = None, +) -> tuple[np.ndarray, np.ndarray, float]: + """Perform k-medoids clustering. + + Parameters + ---------- + data : np.ndarray + Array of shape (n_samples, n_features). + k : int + Number of clusters. 
+ max_iter : int, default=100 + Maximum number of iterations. + n_init : int, default=5 + Number of random initializations. + random_state : int, optional + Random seed for reproducibility. + + Returns + ------- + labels : np.ndarray + Cluster labels for each sample (0-indexed). + medoid_indices : np.ndarray + Indices of medoid samples. + inertia : float + Sum of distances from samples to their medoids. + + """ + from scipy.spatial.distance import cdist + + rng = np.random.default_rng(random_state) + n_samples = len(data) + + dist_matrix = cdist(data, data, metric="euclidean") + + best_labels: np.ndarray | None = None + best_medoids: np.ndarray | None = None + best_inertia = np.inf + + for _ in range(n_init): + medoids = rng.choice(n_samples, size=k, replace=False) + + for _ in range(max_iter): + distances_to_medoids = dist_matrix[:, medoids] + labels = np.argmin(distances_to_medoids, axis=1) + + new_medoids = np.array( + [ + _update_medoid_for_cluster(c, labels, medoids, dist_matrix) + for c in range(k) + ] + ) + + if np.array_equal(np.sort(medoids), np.sort(new_medoids)): + break + medoids = new_medoids + + distances_to_medoids = dist_matrix[:, medoids] + labels = np.argmin(distances_to_medoids, axis=1) + inertia = np.sum(distances_to_medoids[np.arange(n_samples), labels]) + + if inertia < best_inertia: + best_inertia = inertia + best_labels = labels.copy() + best_medoids = medoids.copy() + + assert best_labels is not None and best_medoids is not None + return best_labels, best_medoids, best_inertia + + +def _compute_intra_cluster_dist( + i: int, + labels: np.ndarray, + dist_matrix: np.ndarray, + n_samples: int, +) -> float: + """Compute mean distance from sample i to other samples in its cluster.""" + own_cluster = labels[i] + own_mask = labels == own_cluster + if np.sum(own_mask) > 1: + return np.mean(dist_matrix[i, own_mask & (np.arange(n_samples) != i)]) + return 0.0 + + +def _compute_nearest_cluster_dist( + i: int, + labels: np.ndarray, + dist_matrix: 
def silhouette_score(data: np.ndarray, labels: np.ndarray) -> float:
    """Compute the mean silhouette coefficient of a labelling.

    For every sample, ``a`` is its mean Euclidean distance to the other
    members of its own cluster (0.0 for singleton clusters) and ``b`` is
    the smallest mean distance to the members of any other cluster. The
    sample's coefficient is ``(b - a) / max(a, b)``, or 0.0 when that
    maximum is not positive or no other cluster exists.

    Parameters
    ----------
    data : np.ndarray
        Array of shape (n_samples, n_features).
    labels : np.ndarray
        Cluster labels for each sample.

    Returns
    -------
    score : float
        Mean silhouette score across all samples.
        Returns 0.0 if clustering is degenerate (a single cluster, or as
        many clusters as samples).

    """
    from scipy.spatial.distance import cdist

    n_samples = len(data)
    cluster_ids = np.unique(labels)
    if not 1 < len(cluster_ids) < n_samples:
        return 0.0

    dist_matrix = cdist(data, data, metric="euclidean")
    sample_idx = np.arange(n_samples)
    per_sample = np.zeros(n_samples)

    for i in range(n_samples):
        own = labels[i]
        own_mask = labels == own

        # a: cohesion with the rest of the sample's own cluster.
        if np.sum(own_mask) > 1:
            a_i = np.mean(dist_matrix[i, own_mask & (sample_idx != i)])
        else:
            a_i = 0.0

        # b: separation from the closest foreign cluster.
        b_i = np.inf
        for other in cluster_ids:
            if other == own:
                continue
            other_mask = labels == other
            if np.any(other_mask):
                b_i = min(b_i, np.mean(dist_matrix[i, other_mask]))

        if b_i != np.inf and max(a_i, b_i) > 0:
            per_sample[i] = (b_i - a_i) / max(a_i, b_i)

    return float(np.mean(per_sample))
def compute_cluster_velocities(
    selected_frames: np.ndarray,
    selected_seg_id: np.ndarray,
    cluster_mask: np.ndarray,
    segments: np.ndarray,
    bbox_centroids: np.ndarray,
) -> np.ndarray:
    """Collect centroid displacements between consecutive cluster frames.

    A displacement is produced for every pair of frames that belong to the
    cluster selected by ``cluster_mask``, share a segment ID, and are
    exactly one frame apart. Displacements containing NaN are discarded.

    Parameters
    ----------
    selected_frames : np.ndarray
        Frame indices of the selected frames.
    selected_seg_id : np.ndarray
        Segment ID for each selected frame.
    cluster_mask : np.ndarray
        Boolean mask over the selected frames picking one cluster.
    segments : np.ndarray
        Segment array; only its length (number of segments) is used.
    bbox_centroids : np.ndarray
        Per-frame centroids, indexed by frame number.

    Returns
    -------
    np.ndarray
        Array of shape (n_velocities, 2). Empty (0, 2) if no valid pairs.

    """
    cluster_frames = selected_frames[cluster_mask]
    cluster_seg_ids = selected_seg_id[cluster_mask]
    collected: list[np.ndarray] = []

    for seg_index in range(len(segments)):
        ordered = np.sort(cluster_frames[cluster_seg_ids == seg_index])
        for earlier, later in zip(ordered[:-1], ordered[1:]):
            if later == earlier + 1:
                step = bbox_centroids[later] - bbox_centroids[earlier]
                if not np.isnan(step).any():
                    collected.append(step)

    if not collected:
        return np.zeros((0, 2))
    return np.array(collected)
def compute_cluster_pca_and_anterior(
    centered_skeletons: np.ndarray,
    cluster_mask: np.ndarray,
    selected_frames: np.ndarray,
    selected_seg_id: np.ndarray,
    segments: np.ndarray,
    bbox_centroids: np.ndarray,
) -> dict:
    """Derive the body axes and anterior direction of one cluster.

    The cluster's centered skeletons are averaged, the rows of that
    average without NaN are decomposed by SVD, and the first two right
    singular vectors become PC1 and PC2. PC1 is flipped to have a
    non-negative y-component and PC2 a non-negative x-component. The
    anterior direction is then inferred by velocity voting on centroid
    displacements projected onto PC1.

    Returns
    -------
    dict
        Keys: valid, n_frames, avg_skeleton, valid_shape_rows,
        PC1, PC2, proj_pc1, proj_pc2, anterior_sign, num_positive,
        num_negative, vote_margin, resultant_length, circ_mean_dir,
        velocities, vel_projs_pc1. ``valid`` stays False (and all other
        entries keep their defaults) when the cluster is empty or fewer
        than two keypoints have a finite average position.

    """
    n_keypoints = centered_skeletons.shape[1]
    cluster_size = int(np.sum(cluster_mask))

    result: dict = {
        "valid": False,
        "n_frames": cluster_size,
        "avg_skeleton": np.full((n_keypoints, 2), np.nan),
        "valid_shape_rows": np.zeros(n_keypoints, dtype=bool),
        "PC1": np.array([1.0, 0.0]),
        "PC2": np.array([0.0, 1.0]),
        "proj_pc1": np.full(n_keypoints, np.nan),
        "proj_pc2": np.full(n_keypoints, np.nan),
        "anterior_sign": -1,
        "num_positive": 0,
        "num_negative": 0,
        "vote_margin": 0.0,
        "resultant_length": 0.0,
        "circ_mean_dir": np.nan,
        "velocities": np.zeros((0, 2)),
        "vel_projs_pc1": np.array([]),
    }

    if cluster_size == 0:
        return result

    mean_skeleton = np.mean(centered_skeletons[cluster_mask], axis=0)
    row_ok = ~np.any(np.isnan(mean_skeleton), axis=1)
    if np.sum(row_ok) < 2:
        # An axis cannot be fitted through fewer than two keypoints.
        return result

    usable_rows = mean_skeleton[row_ok]
    _, _, vt = np.linalg.svd(usable_rows, full_matrices=False)
    axis1 = vt[0]
    axis2 = vt[1] if len(vt) > 1 else np.array([0.0, 1.0])

    # Geometric sign convention: PC1 points towards +y, PC2 towards +x.
    axis1 = -axis1 if axis1[1] < 0 else axis1
    axis2 = -axis2 if axis2[0] < 0 else axis2

    coords1 = np.full(n_keypoints, np.nan)
    coords2 = np.full(n_keypoints, np.nan)
    coords1[row_ok] = usable_rows @ axis1
    coords2[row_ok] = usable_rows @ axis2

    velocities = compute_cluster_velocities(
        selected_frames,
        selected_seg_id,
        cluster_mask,
        segments,
        bbox_centroids,
    )

    result["avg_skeleton"] = mean_skeleton
    result["valid_shape_rows"] = row_ok
    result["PC1"] = axis1
    result["PC2"] = axis2
    result["proj_pc1"] = coords1
    result["proj_pc2"] = coords2
    result["velocities"] = velocities
    result.update(infer_anterior_from_velocities(velocities, axis1))
    result["valid"] = True
    return result
def apply_lateral_filter(
    report: APNodePairReport,
    valid_idx: np.ndarray,
    lateral_thresh: float,
) -> np.ndarray | None:
    """Step 1: keep keypoints whose normalized lateral offset is small.

    The lateral offsets of the keypoints in ``valid_idx`` are min-max
    normalized to [0, 1]; keypoints at or below ``lateral_thresh`` survive.
    When all offsets are equal, every keypoint gets a normalized offset of
    0 and survives. The report receives the offset range, the normalized
    offsets and the surviving nodes sorted by ascending raw offset.

    Returns sorted candidate node indices, or None (with the report's
    failure fields filled in) when fewer than two candidates remain.

    """
    offsets = report.lateral_offsets[valid_idx]
    lowest = float(np.min(offsets))
    highest = float(np.max(offsets))
    report.lateral_offset_min = lowest
    report.lateral_offset_max = highest

    if highest > lowest:
        normalized = (offsets - lowest) / (highest - lowest)
        keep = normalized <= lateral_thresh
    else:
        # Degenerate range: nothing to discriminate on, keep everything.
        normalized = np.zeros(len(offsets))
        keep = np.ones(len(offsets), dtype=bool)
    report.lateral_offsets_norm[valid_idx] = normalized

    kept_nodes = valid_idx[keep]
    candidates = kept_nodes[np.argsort(offsets[keep])]
    report.sorted_candidate_nodes = candidates.copy()

    if len(candidates) >= 2:
        return candidates

    report.failure_step = "Step 1: lateral alignment filter"
    report.failure_reason = (
        "Fewer than 2 candidates remained after filtering."
    )
    return None
def order_pair_by_ap(
    pair: np.ndarray,
    ap_coords: np.ndarray,
) -> np.ndarray:
    """Return the pair arranged as (posterior, anterior).

    The node with the lower AP coordinate is placed first, so the result
    follows the ``body_axis_keypoints=(from_node, to_node)`` convention in
    which from_node is posterior and to_node is anterior. When the two
    coordinates compare equal the input order is kept.

    Parameters
    ----------
    pair : np.ndarray
        Two node indices.
    ap_coords : np.ndarray
        AP coordinate of every node, indexed by node.

    Returns
    -------
    np.ndarray
        Integer array of length 2, posterior node first.

    """
    first, second = pair
    # Swap unless the first node is already at or below the second one.
    must_swap = not (ap_coords[first] <= ap_coords[second])
    ordered = (second, first) if must_swap else (first, second)
    return np.array(ordered, dtype=int)
def check_input_pair_in_valid(
    report: APNodePairReport,
    pairs: np.ndarray,
    seps: np.ndarray,
    pair_is_distal: np.ndarray,
    from_node: int,
    to_node: int,
) -> tuple[bool, int]:
    """Locate the input pair among the valid pairs, ignoring node order.

    When the pair is found, ``report.input_pair_is_distal`` is copied from
    ``pair_is_distal`` and ``report.input_pair_rank`` is set to the pair's
    1-based position when pairs are ordered by decreasing separation.
    The report is not touched when the pair is absent.

    Returns
    -------
    tuple[bool, int]
        ``(found, idx)`` where ``idx`` is the row of the first match in
        ``pairs``, or -1 when there is no match.

    """
    wanted = tuple(sorted((from_node, to_node)))
    match_idx = next(
        (
            row
            for row, candidate in enumerate(pairs)
            if tuple(sorted(candidate)) == wanted
        ),
        -1,
    )
    found = match_idx >= 0

    if found:
        report.input_pair_is_distal = pair_is_distal[match_idx]
        # Largest separation first; rank is the position of the match.
        descending = np.argsort(seps)[::-1]
        position = np.nonzero(descending == match_idx)[0][0]
        report.input_pair_rank = int(position) + 1

    return found, match_idx
def assign_multi_input_distal_scenario(
    report: APNodePairReport,
    pairs: np.ndarray,
    input_idx: int,
) -> APNodePairReport:
    """Pick scenario 5, 6 or 7 for a distal input pair among several pairs.

    - 5 (accept): the input pair has separation rank 1.
    - 7 (accept): the input pair is the max-separation distal pair.
    - 6 (warn): otherwise; the warning names the max-separation distal
      pair.

    Parameters
    ----------
    report : APNodePairReport
        Read for ``input_pair_rank`` and ``max_separation_distal_nodes``;
        receives ``scenario``, ``outcome`` and, for scenario 6,
        ``warning_message``.
    pairs : np.ndarray
        Valid pairs, one pair per row.
    input_idx : int
        Row of the input pair in ``pairs``.

    Returns
    -------
    APNodePairReport
        The same report object, updated in place.

    """
    best_distal = report.max_separation_distal_nodes
    input_key = tuple(sorted((pairs[input_idx, 0], pairs[input_idx, 1])))
    best_key = tuple(sorted(best_distal)) if len(best_distal) > 0 else ()

    if report.input_pair_rank == 1:
        report.scenario, report.outcome = 5, "accept"
        return report

    if input_key == best_key:
        report.scenario, report.outcome = 7, "accept"
        return report

    report.scenario, report.outcome = 6, "warn"
    report.warning_message = (
        f"Distal pair with greater separation exists: "
        f"[{best_distal[0]}, {best_distal[1]}]."
    )
    return report
def assign_multi_input_invalid_scenario(
    report: APNodePairReport,
    pair_is_distal: np.ndarray,
) -> APNodePairReport:
    """Pick scenario 12 or 13 when the input pair is not a valid pair.

    The outcome is always ``"warn"``. Scenario 12 applies when at least
    one valid pair is distal and suggests the max-separation distal pair;
    scenario 13 applies when none is distal and reports the overall
    max-separation pair.

    Parameters
    ----------
    report : APNodePairReport
        Read for ``max_separation_distal_nodes`` (scenario 12) or
        ``max_separation_nodes`` (scenario 13); receives ``outcome``,
        ``scenario`` and ``warning_message``.
    pair_is_distal : np.ndarray
        Boolean flag per valid pair.

    Returns
    -------
    APNodePairReport
        The same report object, updated in place.

    """
    report.outcome = "warn"

    if not np.any(pair_is_distal):
        report.scenario = 13
        widest = report.max_separation_nodes
        report.warning_message = (
            f"Input invalid. All pairs have proximal node(s). "
            f"Max separation: [{widest[0]}, {widest[1]}]."
        )
        return report

    report.scenario = 12
    suggestion = report.max_separation_distal_nodes
    report.warning_message = (
        f"Input invalid. Suggest max separation distal pair: "
        f"[{suggestion[0]}, {suggestion[1]}]."
    )
    return report
+ + """ + if len(pairs) == 1: + return assign_single_pair_scenario( + report, + pairs, + pair_is_distal, + input_in_valid, + ) + + if not input_in_valid: + return assign_multi_input_invalid_scenario(report, pair_is_distal) + + if report.input_pair_is_distal: + return assign_multi_input_distal_scenario( + report, + pairs, + input_idx, + ) + + return assign_multi_input_proximal_scenario(report, pair_is_distal) + + +def evaluate_ap_node_pair( + avg_skeleton: np.ndarray, + pc1_vec: np.ndarray, + anterior_sign: int, + valid_shape_rows: np.ndarray, + from_node: int, + to_node: int, + config: ValidateAPConfig, +) -> APNodePairReport: + """Evaluate an AP node pair through the 3-step filter cascade. + + Parameters + ---------- + avg_skeleton : np.ndarray + Average centered skeleton of shape (n_keypoints, 2). + pc1_vec : np.ndarray + First principal component vector of shape (2,). + anterior_sign : int + Inferred anterior direction (+1 or -1 relative to PC1). + valid_shape_rows : np.ndarray + Boolean array indicating valid (non-NaN) keypoints. + from_node : int + Index of the input from_node (body_axis_keypoints origin, + claimed posterior). 0-indexed. + to_node : int + Index of the input to_node (body_axis_keypoints target, + claimed anterior). 0-indexed. + config : ValidateAPConfig + Configuration with ``lateral_thresh`` and ``edge_thresh``. + + Returns + ------- + APNodePairReport + Complete evaluation report. 
def resolve_node_index(node: Hashable, names: list) -> int:
    """Turn a keypoint name or index-like value into an integer index.

    Parameters
    ----------
    node : Hashable
        A keypoint name (``str``), an ``int``, or any value accepted by
        ``int()``.
    names : list
        Keypoint names; the position of a name is its index.

    Returns
    -------
    int
        Position of ``node`` in ``names`` for a string; ``node`` itself
        for an ``int``; ``int(node)`` otherwise. No range check is
        performed here.

    Raises
    ------
    ValueError
        If ``node`` is a string that is not present in ``names``.

    """
    if isinstance(node, str):
        if node not in names:
            raise ValueError(f"Keypoint '{node}' not found in {names}.")
        return names.index(node)

    # Plain ints pass through untouched; other index-like values are coerced.
    return node if isinstance(node, int) else int(node)  # type: ignore[call-overload]
def run_motion_segmentation(
    keypoints: np.ndarray,
    num_frames: int,
    config: ValidateAPConfig,
    log_info,
    log_warning,
) -> dict | None:
    """Run tiered validity through segment detection.

    Computes the tier-1 and tier-2 frame masks, logs a validity report and
    a centroid-discrepancy diagnostic, then delegates segment detection to
    ``detect_motion_segments``.

    Parameters
    ----------
    keypoints : np.ndarray
        Keypoint array forwarded to ``compute_tiered_validity`` and
        ``compute_bbox_centroid``.
    num_frames : int
        Total number of frames; denominator of the logged percentages.
    config : ValidateAPConfig
        ``min_valid_frac`` is read here; the whole object is passed on to
        ``detect_motion_segments``.
    log_info, log_warning : callable
        printf-style sinks called as ``f(msg, *args)``.

    Returns
    -------
    dict or None
        Dict with keys ``tier1_valid``, ``tier2_valid``,
        ``bbox_centroids`` and ``segments``; None when fewer than two
        tier-1 frames exist (error logged) or when
        ``detect_motion_segments`` returns None.

    """
    tier1_valid, tier2_valid, _frac = compute_tiered_validity(
        keypoints, config.min_valid_frac
    )
    num_tier1 = int(np.sum(tier1_valid))
    num_tier2 = int(np.sum(tier2_valid))

    # A zero-frame input must reach the "not enough frames" branch below
    # rather than fail with ZeroDivisionError while formatting the report.
    frame_denominator = max(num_frames, 1)

    log_info(_LOG_SEPARATOR)
    log_info("Tiered Validity Report")
    log_info(_LOG_SEPARATOR)
    log_info(
        "Tier 1 (>= %.0f%% keypoints): %d / %d frames (%.2f%%)",
        config.min_valid_frac * 100,
        num_tier1,
        num_frames,
        100 * num_tier1 / frame_denominator,
    )
    log_info(
        "Tier 2 (100%% keypoints): %d / %d frames (%.2f%%)",
        num_tier2,
        num_frames,
        100 * num_tier2 / frame_denominator,
    )

    if num_tier1 < 2:
        logger.error("Not enough tier-1 valid frames.")
        return None

    bbox_centroids, _arith, centroid_disc = compute_bbox_centroid(
        keypoints, tier1_valid
    )
    # Only tier-1 frames with a defined discrepancy enter the diagnostic.
    valid_disc = centroid_disc[tier1_valid & ~np.isnan(centroid_disc)]
    if len(valid_disc) > 0:
        log_info("")
        log_info(_LOG_SEPARATOR)
        log_info("Centroid Discrepancy Diagnostic")
        log_info(_LOG_SEPARATOR)
        log_info("BBox vs arithmetic centroid (normalized by bbox diagonal):")
        log_info(
            " Median: %.4f | Mean: %.4f | Max: %.4f",
            np.median(valid_disc),
            np.mean(valid_disc),
            np.max(valid_disc),
        )
        if np.median(valid_disc) > 0.05:
            log_warning(
                "Median discrepancy > 5%% - annotation density "
                "is likely asymmetric."
            )

    segments = detect_motion_segments(
        bbox_centroids, tier1_valid, config, log_info
    )
    if segments is None:
        return None

    return {
        "tier1_valid": tier1_valid,
        "tier2_valid": tier2_valid,
        "bbox_centroids": bbox_centroids,
        "segments": segments,
    }
def select_tier2_frames(
    segments: np.ndarray,
    tier2_valid: np.ndarray,
    num_frames: int,
    log_info,
    log_warning,
) -> tuple[np.ndarray, np.ndarray, int] | None:
    """Filter segment frames to tier-2 valid only.

    Delegates the selection to ``filter_segments_tier2`` and logs how many
    of the frames covered by the segments were retained. A warning is
    issued when fewer than 30% are retained.

    Parameters
    ----------
    segments : np.ndarray
        Iterable of ``(start, end)`` frame bounds, both inclusive.
    tier2_valid : np.ndarray
        Mask forwarded to ``filter_segments_tier2``.
    num_frames : int
        Total number of frames; segment bounds are counted against
        ``0 .. num_frames - 1``.
    log_info, log_warning : callable
        printf-style sinks called as ``f(msg, *args)``.

    Returns
    -------
    tuple or None
        ``(selected_frames, selected_seg_id, num_selected)``, or None when
        fewer than two frames are selected (error logged).

    """
    selected_frames, selected_seg_id = filter_segments_tier2(
        segments, tier2_valid
    )

    # Build the frame index once instead of twice per segment.
    frame_index = np.arange(num_frames)
    num_tier1_in_segs = sum(
        np.sum((frame_index >= s[0]) & (frame_index <= s[1]))
        for s in segments
    )
    num_selected = len(selected_frames)

    log_info("")
    log_info(_LOG_SEPARATOR)
    log_info("Tier-2 Filtering on High-Motion Segments")
    log_info(_LOG_SEPARATOR)
    log_info(
        "Frames in high-motion segments (any tier): %d", num_tier1_in_segs
    )
    log_info(
        "Tier-2 valid frames retained (all keypoints present): "
        "%d (%.1f%% of segment frames)",
        num_selected,
        100 * num_selected / max(num_tier1_in_segs, 1),
    )

    retention = num_selected / max(num_tier1_in_segs, 1)
    if retention < 0.3:
        log_warning(
            "Tier 2 discards > 70%% of segment frames - "
            "body model may be unrepresentative."
        )

    if num_selected < 2:
        logger.error("Not enough tier-2 valid frames in selected segments.")
        return None

    return selected_frames, selected_seg_id, num_selected
+ + """ + rmsd_matrix = compute_pairwise_rmsd(centered_skeletons) + var_ratio, within_rmsds, between_rmsds, var_ratio_override = ( + compute_postural_variance_ratio(rmsd_matrix, frame_sel.seg_ids) + ) + + rmsd_stats = { + "within": within_rmsds, + "between": between_rmsds, + "var_ratio": var_ratio, + "override": var_ratio_override, + } + log_postural_consistency( + rmsd_stats, + config, + frame_sel.count, + log_info, + ) + + cluster_labels, num_clusters, primary_cluster = decide_and_run_clustering( + centered_skeletons, + var_ratio, + frame_sel.count, + config, + log_info, + ) + + cluster_results = [] + for c in range(num_clusters): + cluster_mask = cluster_labels == c + cr = compute_cluster_pca_and_anterior( + centered_skeletons, + cluster_mask, + frame_sel.frames, + frame_sel.seg_ids, + frame_sel.segments, + frame_sel.bbox_centroids, + ) + cluster_results.append(cr) + + pr = cluster_results[primary_cluster] + if not pr["valid"]: + logger.error("Primary cluster has invalid PCA result.") + return None + + return { + "primary_result": pr, + "cluster_results": cluster_results, + "num_clusters": num_clusters, + "primary_cluster": primary_cluster, + } + + +def log_postural_consistency( + rmsd_stats, + config, + num_selected, + log_info, +): + """Log postural consistency check results.""" + within_rmsds = rmsd_stats["within"] + between_rmsds = rmsd_stats["between"] + var_ratio = rmsd_stats["var_ratio"] + var_ratio_override = rmsd_stats["override"] + + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Postural Consistency Check") + log_info(_LOG_SEPARATOR) + + if len(within_rmsds) > 0: + log_info( + "Within-segment RMSD: mean=%.4f, std=%.4f (n=%d pairs)", + np.mean(within_rmsds), + np.std(within_rmsds), + len(within_rmsds), + ) + else: + log_info("Within-segment RMSD: N/A (no within-segment pairs)") + + if len(between_rmsds) > 0: + log_info( + "Between-segment RMSD: mean=%.4f, std=%.4f (n=%d pairs)", + np.mean(between_rmsds), + np.std(between_rmsds), + 
def decide_and_run_clustering(
    centered_skeletons,
    var_ratio,
    num_selected,
    config,
    log_info,
):
    """Cluster the selected frames only when the variance ratio demands it.

    Clustering is triggered when ``var_ratio`` exceeds
    ``config.postural_var_ratio_thresh`` and at least 6 frames were
    selected. Otherwise every frame is assigned to a single cluster and
    nothing is logged. When triggered, the work is delegated to
    ``perform_postural_clustering`` and its silhouette scores are logged.

    Returns
    -------
    tuple
        ``(cluster_labels, num_clusters, primary_cluster)``. Without
        clustering this is an all-zero integer label array of length
        ``num_selected``, 1 and 0.

    """
    triggered = (
        var_ratio > config.postural_var_ratio_thresh and num_selected >= 6
    )
    if not triggered:
        single_cluster = np.zeros(num_selected, dtype=int)
        return single_cluster, 1, 0

    outcome = perform_postural_clustering(
        centered_skeletons, config.max_clusters
    )
    labels, n_clusters, primary, best_sil, sil_by_k = outcome

    for k, score in sil_by_k:
        if np.isnan(score):
            log_info(" k=%d: clustering failed.", k)
            continue
        log_info(" k=%d: mean silhouette = %.4f", k, score)

    if n_clusters <= 1:
        log_info(
            " Clustering did not improve separation (best_sil=%.4f). "
            "Using global average.",
            best_sil,
        )
        return labels, n_clusters, primary

    frames_per_cluster = np.bincount(labels, minlength=n_clusters)
    log_info(
        " Selected k=%d clusters (silhouette=%.4f). "
        "Primary cluster=%d (%d frames)",
        n_clusters,
        best_sil,
        primary + 1,
        frames_per_cluster[primary],
    )
    return labels, n_clusters, primary
def log_pair_evaluation(
    pair_report,
    config,
    from_idx,
    to_idx,
    from_name,
    to_name,
    log_info,
):
    """Log the complete AP node pair evaluation report.

    Emits, in order: a banner with the input pair, the Step 1 report, the
    lateral-filter status of the input nodes, the Step 2 and Step 3
    details (or a note that they were not evaluated because Step 1
    failed), the filtering loss summary, and the AP order check.

    Parameters
    ----------
    pair_report
        Evaluation report; ``failure_step``, ``lateral_offsets_norm`` and
        ``sorted_candidate_nodes`` are read here, the rest by the helpers.
    config
        Configuration object passed through to the helpers.
    from_idx, to_idx : int
        Indices of the input pair.
    from_name, to_name
        Display names of the input pair.
    log_info : callable
        printf-style sink called as ``f(msg, *args)``.

    """
    for banner_line in (
        "",
        _LOG_SEPARATOR,
        "AP Node-Pair Filter Cascade (3-Step Evaluation)",
        _LOG_SEPARATOR,
    ):
        log_info(banner_line)
    log_info(
        "Input pair: [%d, %d] (%s -> %s, claimed posterior -> anterior)",
        from_idx,
        to_idx,
        from_name,
        to_name,
    )

    step1_failed = pair_report.failure_step.startswith("Step 1")

    # Nodes that received a normalised lateral offset are the valid ones.
    has_offset = ~np.isnan(pair_report.lateral_offsets_norm)
    valid_nodes = np.nonzero(has_offset)[0]
    num_candidates = len(pair_report.sorted_candidate_nodes)
    step1_loss = 1 - num_candidates / max(len(valid_nodes), 1)

    log_step1_report(pair_report, config, valid_nodes, log_info)
    log_input_node_status(pair_report, config, from_idx, to_idx, log_info)

    if step1_failed:
        log_info("")
        log_info("Step 2-3: not evaluated (Step 1 failed)")
        step2_loss, step3_frac, step2_failed = 0.0, 0.0, False
    else:
        step2_loss, step3_frac, step2_failed = log_step2_step3_details(
            pair_report,
            config,
            from_idx,
            to_idx,
            num_candidates,
            log_info,
        )

    log_loss_summary(
        step1_loss,
        step2_loss,
        step3_frac,
        step1_failed,
        step2_failed,
        log_info,
    )
    log_order_check(
        pair_report,
        from_idx,
        to_idx,
        from_name,
        to_name,
        log_info,
    )
def log_step2_report(pair_report, _config, log_info):
    """Log Step 2 opposite-sides results.

    Candidate nodes are grouped by the sign of their raw PC1 coordinate
    relative to ``midpoint_pc1``. Because ``ap_coords`` equals
    ``anterior_sign * pc1_coords``, the +PC1 side is anterior only when
    the inferred anterior sign is positive. The side labels are therefore
    chosen from the relative orientation of the two coordinate arrays
    instead of always calling the + side "anterior".

    Parameters
    ----------
    pair_report
        Report providing ``sorted_candidate_nodes``, ``valid_pairs``,
        ``midpoint_pc1``, ``pc1_coords`` and ``ap_coords``.
    _config
        Unused; kept for signature parity with the other step loggers.
    log_info : callable
        printf-style sink called as ``f(msg, *args)``.

    """
    candidates = pair_report.sorted_candidate_nodes
    num_candidates = len(candidates)
    num_possible_pairs = num_candidates * (num_candidates - 1) // 2
    num_valid_pairs = len(pair_report.valid_pairs)
    step2_loss = 1 - num_valid_pairs / max(num_possible_pairs, 1)
    m = pair_report.midpoint_pc1

    # Recover the anterior sign: sum(pc1 * ap) = sign * sum(pc1 ** 2).
    pc1_all = np.asarray(pair_report.pc1_coords, dtype=float)
    ap_all = np.asarray(pair_report.ap_coords, dtype=float)
    orientation = 0.0
    if pc1_all.shape == ap_all.shape:
        usable = ~(np.isnan(pc1_all) | np.isnan(ap_all))
        orientation = float(np.sum(pc1_all[usable] * ap_all[usable]))

    if orientation > 0:
        plus_label, minus_label = "anterior", "posterior"
    elif orientation < 0:
        plus_label, minus_label = "posterior", "anterior"
    else:
        # Orientation cannot be determined; do not claim a body direction.
        plus_label, minus_label = "+PC1", "-PC1"

    plus_strs = []
    minus_strs = []
    for node_i in candidates:
        pc1_rel = pair_report.pc1_coords[node_i] - m
        entry = f"{node_i}({pc1_rel:+.1f})"
        if pc1_rel > 0:
            plus_strs.append(entry)
        else:
            minus_strs.append(entry)

    log_info("")
    log_info(
        "Step 2 - Opposite-Sides Constraint (AP midpoint=%.2f): "
        "%d of %d candidate pairs on opposite sides [loss=%.0f%%]",
        m,
        num_valid_pairs,
        num_possible_pairs,
        100 * step2_loss,
    )
    if plus_strs:
        log_info(
            " + side (%s of midpoint): %s",
            plus_label,
            ", ".join(plus_strs),
        )
    if minus_strs:
        log_info(
            " - side (%s of midpoint): %s",
            minus_label,
            ", ".join(minus_strs),
        )
def log_loss_summary(
    step1_loss, step2_loss, step3_frac, step1_failed, step2_failed, log_info
):
    """Log the per-step filtering losses under a summary banner.

    The Step 1 line is always written. The Step 2 line is written only
    when Step 1 did not fail, and the Step 3 line only when neither Step 1
    nor Step 2 failed.

    Parameters
    ----------
    step1_loss, step2_loss, step3_frac : float
        Fractions in [0, 1]; logged as percentages.
    step1_failed, step2_failed : bool
        Whether the corresponding step failed.
    log_info : callable
        printf-style sink called as ``f(msg, *args)``.

    """
    for header_line in (
        "",
        _LOG_SEPARATOR,
        "Filtering Loss Summary",
        _LOG_SEPARATOR,
    ):
        log_info(header_line)

    log_info(
        "Step 1 (Lateral Filter): %.0f%% of valid nodes eliminated",
        100 * step1_loss,
    )
    if step1_failed:
        return

    log_info(
        "Step 2 (Opposite-Sides): %.0f%% of candidate pairs eliminated",
        100 * step2_loss,
    )
    if step2_failed:
        return

    log_info(
        "Step 3 (Distal/Proximal): %.0f%% of surviving pairs are distal",
        100 * step3_frac,
    )
def log_input_node_status(
    pair_report,
    config,
    from_idx,
    to_idx,
    log_info,
):
    """Report input nodes that did not pass the lateral filter.

    A node passes when its normalised lateral offset is not NaN and is at
    most ``config.lateral_thresh``. Nothing is logged when both input
    nodes pass; otherwise a single line lists each failing node with its
    normalised offset, from_node first.

    Parameters
    ----------
    pair_report
        Report providing ``lateral_offsets_norm``.
    config
        Object providing ``lateral_thresh``.
    from_idx, to_idx : int
        Indices of the input pair.
    log_info : callable
        printf-style sink called as ``f(msg, *args)``.

    """
    # Read both offsets up front so an invalid index fails before logging.
    offsets = [
        (node, pair_report.lateral_offsets_norm[node])
        for node in (from_idx, to_idx)
    ]

    failed = [
        f"{node}({offset:.2f})"
        for node, offset in offsets
        if np.isnan(offset) or not (offset <= config.lateral_thresh)
    ]
    if not failed:
        return

    log_info(
        " -> Input node(s) FAILED lateral filter: %s",
        ", ".join(failed),
    )
def log_step3_with_proximal_check(
    pair_report,
    config,
    from_idx,
    to_idx,
    log_info,
):
    """Log the Step 3 report and flag a proximal input pair.

    After delegating to ``log_step3_report``, an extra line is logged
    when the input pair is among the candidates, lies on opposite sides,
    and is not distal. That line shows the smaller of the two nodes'
    ``midline_dist_norm`` values against ``config.edge_thresh``.

    Returns
    -------
    float
        ``len(distal_pairs) / max(len(valid_pairs), 1)``; the clamp
        avoids a division by zero when there are no valid pairs.

    """
    log_step3_report(pair_report, config, log_info)

    distal_count = len(pair_report.distal_pairs)
    valid_count = len(pair_report.valid_pairs)
    step3_frac = distal_count / max(valid_count, 1)

    in_candidates = pair_report.input_pair_in_candidates
    on_opposite_sides = pair_report.input_pair_opposite_sides
    is_distal = pair_report.input_pair_is_distal
    if not (in_candidates and on_opposite_sides) or is_distal:
        return step3_frac

    dists = pair_report.midline_dist_norm
    closest = min(dists[from_idx], dists[to_idx])
    log_info(
        " -> Input pair is PROXIMAL (min_d=%.2f < %.2f)",
        closest,
        config.edge_thresh,
    )
    return step3_frac
Optionally performs postural clustering via k-medoids + 3. Infers the anterior direction using velocity projection voting + 4. Evaluates the candidate AP keypoint pair through a 3-step filter + cascade + + Parameters + ---------- + data : xarray.DataArray + Position data for a single individual. + from_node : int or str + Index or name of the posterior keypoint. + to_node : int or str + Index or name of the anterior keypoint. + config : ValidateAPConfig, optional + Configuration parameters. If None, uses defaults. + verbose : bool, default=True + If True, log detailed validation output to console. + + Returns + ------- + dict + Validation results including success, anterior_sign, + vote_margin, resultant_length, pair_report, etc. + + """ + if config is None: + config = ValidateAPConfig() + + log_lines: list[str] = [] + + def _log_info(msg, *args): + """Log an informational message.""" + line = msg % args if args else msg + log_lines.append(line) + if verbose: + print(line) + + def _log_warning(msg, *args): + """Log a warning message.""" + line = f"WARNING: {msg % args if args else msg}" + log_lines.append(line) + if verbose: + print(line) + + # Prepare inputs + ( + keypoints, + from_idx, + to_idx, + from_name, + to_name, + _keypoint_names, + num_frames, + ) = prepare_validation_inputs(data, from_node, to_node) + + n_keypoints = keypoints.shape[1] + result: dict = { + "success": False, + "anterior_sign": 0, + "vote_margin": 0.0, + "resultant_length": 0.0, + "num_selected_frames": 0, + "num_clusters": 1, + "primary_cluster": 0, + "pair_report": APNodePairReport(), + "PC1": np.array([1.0, 0.0]), + "PC2": np.array([0.0, 1.0]), + "avg_skeleton": np.full((n_keypoints, 2), np.nan), + "error_msg": "", + "log_lines": log_lines, + } + + # Motion segmentation + seg = run_motion_segmentation( + keypoints, + num_frames, + config, + _log_info, + _log_warning, + ) + if seg is None: + result["error_msg"] = "Motion segmentation failed." 
+ return result + + # Tier-2 frame selection + t2 = select_tier2_frames( + seg["segments"], + seg["tier2_valid"], + num_frames, + _log_info, + _log_warning, + ) + if t2 is None: + result["error_msg"] = "Not enough tier-2 valid frames." + return result + selected_frames, selected_seg_id, num_selected = t2 + result["num_selected_frames"] = num_selected + + # Build centered skeletons + _selected_centroids, centered_skeletons = build_centered_skeletons( + keypoints, selected_frames + ) + + # Bundle frame selection data + frame_sel = FrameSelection( + frames=selected_frames, + seg_ids=selected_seg_id, + segments=seg["segments"], + bbox_centroids=seg["bbox_centroids"], + count=num_selected, + ) + + # Postural clustering + PCA + anterior inference + pca = run_clustering_and_pca( + centered_skeletons, + frame_sel, + config, + _log_info, + _log_warning, + ) + if pca is None: + result["error_msg"] = "Primary cluster PCA failed." + return result + + pr = pca["primary_result"] + result["anterior_sign"] = pr["anterior_sign"] + result["vote_margin"] = pr["vote_margin"] + result["resultant_length"] = pr["resultant_length"] + result["circ_mean_dir"] = pr["circ_mean_dir"] + result["vel_projs_pc1"] = pr["vel_projs_pc1"] + result["PC1"] = pr["PC1"] + result["PC2"] = pr["PC2"] + result["avg_skeleton"] = pr["avg_skeleton"] + result["num_clusters"] = pca["num_clusters"] + result["primary_cluster"] = pca["primary_cluster"] + + # Log anterior inference + log_anterior_report( + pr, + pca["cluster_results"], + pca["num_clusters"], + pca["primary_cluster"], + config, + _log_info, + _log_warning, + ) + + # AP node-pair evaluation + pair_report = evaluate_ap_node_pair( + pr["avg_skeleton"], + pr["PC1"], + pr["anterior_sign"], + pr["valid_shape_rows"], + from_idx, + to_idx, + config, + ) + result["pair_report"] = pair_report + + log_pair_evaluation( + pair_report, + config, + from_idx, + to_idx, + from_name, + to_name, + _log_info, + ) + + result["success"] = True + return result + + +# 
def run_ap_validation(
    data: xr.DataArray,
    normalized_keypoints: tuple[Hashable, Hashable],
    ap_validation_config: dict[str, Any] | None,
) -> dict:
    """Validate the AP keypoint pair per individual and pick the best.

    ``validate_ap`` is called (with ``verbose=False``) once per entry of
    the ``individuals`` coordinate, using the same keypoint pair each
    time. Each result is tagged with its ``"individual"`` label and the
    winner is chosen by :func:`find_best_individual_by_rxm`.

    When ``data`` has no ``individuals`` dimension it is validated as a
    single individual; in that case ``best_idx`` is always ``0`` and no
    ``"individual"`` key is added.

    Parameters
    ----------
    data : xarray.DataArray
        Position data, with or without an ``individuals`` dimension.
    normalized_keypoints : tuple[Hashable, Hashable]
        The ``(from_node, to_node)`` keypoint pair.
    ap_validation_config : dict or None
        Keyword overrides passed to ``ValidateAPConfig``. ``None`` means
        no config object is built and ``config=None`` is forwarded.

    Returns
    -------
    dict
        ``{"all_results": [...], "best_idx": int}``.

    """
    config = None
    if ap_validation_config is not None:
        config = ValidateAPConfig(**ap_validation_config)

    def _validate(subject):
        # Single place for the call shared by both code paths.
        return validate_ap(
            subject,
            from_node=normalized_keypoints[0],
            to_node=normalized_keypoints[1],
            config=config,
            verbose=False,
        )

    if "individuals" in data.dims:
        all_results = []
        for label in list(data.coords["individuals"].values):
            outcome = _validate(data.sel(individuals=label))
            outcome["individual"] = label
            all_results.append(outcome)
        return {
            "all_results": all_results,
            "best_idx": find_best_individual_by_rxm(all_results),
        }

    return {"all_results": [_validate(data)], "best_idx": 0}


def find_best_individual_by_rxm(all_results: list[dict]) -> int:
    """Return the index of the successful result with the highest R*M.

    R*M is ``resultant_length * vote_margin``. Results whose
    ``"success"`` is falsy are ignored. Ties keep the earliest index.
    ``-1`` is returned when no result qualifies, i.e. none succeeded or
    no score compares greater than the initial ``-1.0`` (a NaN score
    never does).
    """
    winner = (-1, -1.0)  # (index, score) of the best entry so far
    for position, entry in enumerate(all_results):
        if entry["success"]:
            score = entry["resultant_length"] * entry["vote_margin"]
            if score > winner[1]:
                winner = (position, score)
    return winner[0]
+ + Parameters + ---------- + data : xarray.DataArray + Position data. Must contain ``time``, ``space``, and ``individuals`` as + dimensions. If ``body_axis_keypoints`` is provided, the array must also + contain a ``keypoints`` dimension. For displacement-based heading, + pre-select a keypoint (e.g., ``data.sel(keypoints="thorax")``) or the + first keypoint (index 0) will be used. + + Spatial coordinates must include ``"x"`` and ``"y"``. If additional + spatial coordinates are present (e.g., ``"z"``), they are ignored; + polarization is computed in the x/y plane. + body_axis_keypoints : tuple[Hashable, Hashable], optional + Pair of keypoint names ``(origin, target)`` used to compute heading as + the vector from origin to target. If omitted, heading is inferred from + displacement over ``displacement_frames``. + displacement_frames : int, default=1 + Number of frames used to compute displacement when + ``body_axis_keypoints`` is not provided. Must be a positive integer. + This parameter is ignored when ``body_axis_keypoints`` is provided. + return_angle : bool, default=False + If True, also return the mean angle. Returns the mean body + orientation angle when using ``body_axis_keypoints``, or the mean + movement direction angle when using displacement-based polarization. + in_degrees : bool, default=False + If True, the mean angle is returned in degrees. Otherwise, the + angle is returned in radians. Only relevant when + ``return_angle=True``. + validate_ap : bool, default=False + If True, run anterior-posterior axis validation when + ``body_axis_keypoints`` is provided. Validation is skipped for + displacement-based polarization. + ap_validation_config : dict, optional + Configuration overrides for anterior-posterior axis validation. + See ``movement.kinematics.body_axis.ValidateAPConfig`` for options. 
+ + Returns + ------- + xarray.DataArray or tuple[xarray.DataArray, xarray.DataArray] + If ``return_angle`` is False, returns a DataArray named + ``"polarization"`` with dimension ``("time",)``. + + If ``return_angle`` is True, returns + ``(polarization, mean_angle)`` where ``mean_angle`` is a DataArray + named ``"mean_angle"`` with dimension ``("time",)``. + + Notes + ----- + Missing data are excluded per individual, per frame. + + Zero-length headings are treated as invalid and excluded from the + calculation. + + The mean angle is defined from the summed unit-heading vector projected + onto the x/y plane. When using ``body_axis_keypoints``, this represents + the mean body orientation; when using displacement, it represents the + mean movement direction. When no valid headings exist, or when the summed + heading vector has zero magnitude (for example exact cancellation), the + returned angle is NaN. + + When ``validate_ap=True`` and ``body_axis_keypoints`` is provided, + anterior-posterior validation is run per individual and the result is + stored in ``polarization.attrs["ap_validation_result"]``. + + Examples + -------- + Compute orientation polarization from body-axis keypoints: + + >>> polarization = compute_polarization( + ... ds.position, + ... body_axis_keypoints=("tail_base", "neck"), + ... ) + + Compute heading polarization from displacement (pre-select keypoint): + + >>> polarization = compute_polarization( + ... ds.position.sel(keypoints="thorax") + ... ) + + If multiple keypoints exist and none is selected, the first is used: + + >>> polarization = compute_polarization(ds.position) + + Return orientation polarization with mean body orientation angle: + + >>> polarization, mean_angle = compute_polarization( + ... ds.position, + ... body_axis_keypoints=("tail_base", "neck"), + ... return_angle=True, + ... ) + + Return heading polarization with mean movement direction angle (radians): + + >>> polarization, mean_angle = compute_polarization( + ... 
ds.position.sel(keypoints="thorax"), + ... return_angle=True, + ... ) + + Return heading polarization with mean movement direction angle (degrees): + + >>> polarization, mean_angle = compute_polarization( + ... ds.position.sel(keypoints="thorax"), + ... return_angle=True, + ... in_degrees=True, + ... ) + + If multiple keypoints exist, first is used; also return mean angle: + + >>> polarization, mean_angle = compute_polarization( + ... ds.position, + ... return_angle=True, + ... ) + + Run AP validation while computing body-axis polarization: + + >>> polarization = compute_polarization( + ... ds.position, + ... body_axis_keypoints=("tail_base", "neck"), + ... validate_ap=True, + ... ) + + """ + _validate_type_data_array(data) + normalized_keypoints = _validate_position_data( + data=data, + body_axis_keypoints=body_axis_keypoints, + ) + + ap_validation_result = None + if normalized_keypoints is not None: + if validate_ap: + ap_validation_result = run_ap_validation( + data, normalized_keypoints, ap_validation_config + ) + heading_vectors = _compute_heading_from_keypoints( + data=data, + body_axis_keypoints=normalized_keypoints, + ) + else: + heading_vectors = _compute_heading_from_velocity( + data=data, + displacement_frames=displacement_frames, + ) + + heading = _select_space(heading_vectors) + + unit_headings = convert_to_unit(heading) + valid_mask = ~unit_headings.isnull().any(dim="space") + vector_sum = unit_headings.sum(dim="individuals", skipna=True) + sum_magnitude = compute_norm(vector_sum) + n_valid = valid_mask.sum(dim="individuals") + + polarization = xr.where( + n_valid > 0, + sum_magnitude / n_valid, + np.nan, + ).clip(min=0.0, max=1.0) + polarization = polarization.rename("polarization") + + if ap_validation_result is not None: + polarization.attrs["ap_validation_result"] = ap_validation_result + + if not return_angle: + return polarization + + # Normalize vector_sum to unit vector for angle computation + mean_unit_vector = vector_sum / sum_magnitude + + 
def _compute_heading_from_keypoints(
    data: xr.DataArray,
    body_axis_keypoints: tuple[Hashable, Hashable],
) -> xr.DataArray:
    """Return the origin-to-target vector for a keypoint pair.

    ``body_axis_keypoints`` is unpacked as ``(origin, target)``; the
    result is the target selection minus the origin selection, each
    taken along ``keypoints`` with ``drop=True``.
    """
    origin, target = body_axis_keypoints
    tip = data.sel(keypoints=target, drop=True)
    base = data.sel(keypoints=origin, drop=True)
    return tip - base
def _select_space(data: xr.DataArray) -> xr.DataArray:
    """Keep only the x/y space coordinates and fix the dimension order.

    The result is transposed to ``("time", "space", "individuals")``.
    """
    planar = data.sel(space=["x", "y"])
    ordered = planar.transpose("time", "space", "individuals")
    return ordered


def _validate_position_data(
    data: xr.DataArray,
    body_axis_keypoints: tuple[Hashable, Hashable] | None,
) -> tuple[Hashable, Hashable] | None:
    """Check the position array and normalise the keypoint pair.

    Checks run in this order:

    1. ``validate_dims_coords`` is asked about ``time``, ``space`` and
       ``individuals`` (with empty coordinate lists).
    2. Any dimension other than those three or ``keypoints`` is rejected.
    3. ``space`` must have coordinate labels, including ``"x"``, ``"y"``.
    4. If a keypoint pair is given, it is normalised by
       ``_normalize_body_axis_keypoints``, ``data`` must have a
       ``keypoints`` dimension, and ``validate_dims_coords`` is asked
       about both names.

    Returns
    -------
    tuple[Hashable, Hashable] or None
        ``(origin, target)``, or ``None`` when no pair was supplied.

    Raises
    ------
    ValueError
        For unsupported dimensions, missing ``space`` labels, missing
        ``"x"``/``"y"`` labels, or a keypoint pair without a
        ``keypoints`` dimension. The helpers called here may raise too.

    """
    required_dims = {"time": [], "space": [], "individuals": []}
    validate_dims_coords(data, required_dims)

    extra_dims = set(data.dims) - {
        "time",
        "space",
        "individuals",
        "keypoints",
    }
    if extra_dims:
        raise ValueError(
            f"data contains unsupported dimension(s): "
            f"{sorted(str(d) for d in extra_dims)}"
        )

    if "space" not in data.coords:
        raise ValueError(
            "data must have coordinate labels for the 'space' dimension."
        )

    labels = set(data.coords["space"].values.tolist())
    if not labels >= {"x", "y"}:
        raise ValueError(
            "data.space must include coordinate labels 'x' and 'y'."
        )

    # Displacement mode: nothing keypoint-related left to validate.
    if body_axis_keypoints is None:
        return None

    origin, target = _normalize_body_axis_keypoints(body_axis_keypoints)

    if "keypoints" not in data.dims:
        raise ValueError(
            "body_axis_keypoints requires a 'keypoints' dimension in data."
        )

    validate_dims_coords(data, {"keypoints": [origin, target]})
    return origin, target
+ ) + + try: + origin, target = body_axis_keypoints + except (TypeError, ValueError) as exc: + raise TypeError( + "body_axis_keypoints must be an iterable of exactly two " + "keypoint names." + ) from exc + + for keypoint in (origin, target): + if not isinstance(keypoint, Hashable): + raise TypeError("Each body axis keypoint must be hashable.") + + if origin == target: + raise ValueError( + "body_axis_keypoints must contain two distinct keypoint names." + ) + + return origin, target + + +def _validate_displacement_frames(displacement_frames: int) -> None: + """Validate the displacement window.""" + if isinstance(displacement_frames, (bool, np.bool_)) or not isinstance( + displacement_frames, + (int, np.integer), + ): + raise TypeError("displacement_frames must be a positive integer.") + + if displacement_frames < 1: + raise ValueError("displacement_frames must be >= 1.") + + +def _validate_type_data_array(data: xr.DataArray) -> None: + """Validate that the input is an xarray.DataArray.""" + if not isinstance(data, xr.DataArray): + raise TypeError( + f"Input data must be an xarray.DataArray, but got {type(data)}." 
+ ) diff --git a/tests/test_unit/test_kinematics/test_body_axis.py b/tests/test_unit/test_kinematics/test_body_axis.py new file mode 100644 index 000000000..9b316b6c1 --- /dev/null +++ b/tests/test_unit/test_kinematics/test_body_axis.py @@ -0,0 +1,62 @@ +# test_body_axis.py +"""Tests for the body axis validation module.""" + +from typing import Any + +import pytest + +from movement.kinematics.body_axis import ValidateAPConfig + + +class TestValidateAPConfig: + """Tests for the ValidateAPConfig dataclass parameter validation.""" + + @pytest.mark.parametrize( + ("field", "value"), + [ + ("min_valid_frac", -0.1), + ("min_valid_frac", 1.1), + ("window_len", 0), + ("window_len", -5), + ("window_len", 2.5), + ("stride", 0), + ("stride", -1), + ("stride", 1.5), + ("pct_thresh", -1), + ("pct_thresh", 101), + ("min_run_len", 0), + ("min_run_len", -1), + ("min_run_len", 1.5), + ("postural_var_ratio_thresh", 0), + ("postural_var_ratio_thresh", -1), + ("max_clusters", 0), + ("max_clusters", 2.5), + ("confidence_floor", -0.1), + ("confidence_floor", 1.1), + ("lateral_thresh", -0.1), + ("lateral_thresh", 1.1), + ("edge_thresh", -0.1), + ("edge_thresh", 1.1), + ], + ) + def test_invalid_config_values_raise(self, field: str, value: Any) -> None: + """Invalid config values should raise ValueError.""" + kwargs = {field: value} + with pytest.raises(ValueError, match="must be"): + ValidateAPConfig(**kwargs) + + def test_valid_config_does_not_raise(self) -> None: + """Valid config values should not raise any error.""" + # Should not raise + ValidateAPConfig( + min_valid_frac=0.5, + window_len=10, + stride=2, + pct_thresh=50.0, + min_run_len=2, + postural_var_ratio_thresh=1.5, + max_clusters=3, + confidence_floor=0.2, + lateral_thresh=0.3, + edge_thresh=0.2, + ) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py new file mode 100644 index 000000000..809b40bdf --- /dev/null +++ 
b/tests/test_unit/test_kinematics/test_collective.py @@ -0,0 +1,1115 @@ +# test_collective.py +"""Tests for the collective behavior metrics module.""" + +import numpy as np +import pytest +import xarray as xr + +from movement import kinematics + + +def _get_space_labels(n_space: int, space: list[str] | None) -> list[str]: + """Return space labels, defaulting to ['x', 'y'] for 2D.""" + if space is not None: + return space + if n_space == 2: + return ["x", "y"] + raise ValueError("Provide explicit `space` labels for non-2D data.") + + +def _build_coords( + data: np.ndarray, + time: list | None, + space: list[str] | None, + individuals: list | None, + keypoints: list[str] | None = None, +) -> dict: + """Build coordinate dict for a position DataArray.""" + n_time, n_space = data.shape[0], data.shape[1] + coords: dict = { + "time": time if time is not None else list(range(n_time)), + "space": _get_space_labels(n_space, space), + } + if data.ndim == 4: + n_keypoints = data.shape[2] + coords["keypoints"] = keypoints or [ + f"kp_{i}" for i in range(n_keypoints) + ] + n_individuals = data.shape[3] + else: + n_individuals = data.shape[2] + coords["individuals"] = individuals or [ + f"id_{i}" for i in range(n_individuals) + ] + return coords + + +def _make_position_dataarray( + data: np.ndarray, + *, + time: list | None = None, + individuals: list | None = None, + keypoints: list[str] | None = None, + space: list[str] | None = None, +) -> xr.DataArray: + """Create a position DataArray for tests.""" + data = np.asarray(data, dtype=float) + + dims_map = { + 3: ["time", "space", "individuals"], + 4: ["time", "space", "keypoints", "individuals"], + } + if data.ndim not in dims_map: + raise ValueError( + "Expected data with shape (time, space, individuals) or " + "(time, space, keypoints, individuals)." 
+ ) + + return xr.DataArray( + data, + dims=dims_map[data.ndim], + coords=_build_coords(data, time, space, individuals, keypoints), + name="position", + ) + + +@pytest.fixture +def aligned_positions() -> xr.DataArray: + """Two individuals moving together in +x direction.""" + data = np.array( + [ + [[0, 5], [0, 0]], + [[1, 6], [0, 0]], + [[2, 7], [0, 0]], + [[3, 8], [0, 0]], + ], + dtype=float, + ) + return _make_position_dataarray(data) + + +@pytest.fixture +def opposite_positions() -> xr.DataArray: + """Two individuals moving in opposite x directions (+x and -x).""" + data = np.array( + [ + [[0, 5], [0, 0]], + [[1, 4], [0, 0]], + [[2, 3], [0, 0]], + [[3, 2], [0, 0]], + ], + dtype=float, + ) + return _make_position_dataarray(data) + + +@pytest.fixture +def partial_alignment_positions() -> xr.DataArray: + """Three individuals: two move +x, one moves +y.""" + data = np.array( + [ + [[0, 5, 0], [0, 0, 0]], + [[1, 6, 0], [0, 0, 1]], + [[2, 7, 0], [0, 0, 2]], + [[3, 8, 0], [0, 0, 3]], + ], + dtype=float, + ) + return _make_position_dataarray(data) + + +@pytest.fixture +def perpendicular_positions() -> xr.DataArray: + """Four individuals moving in cardinal directions (+x, -x, +y, -y).""" + data = np.array( + [ + [[0, 10, 0, 0], [0, 0, 0, 10]], + [[1, 9, 0, 0], [0, 0, 1, 9]], + [[2, 8, 0, 0], [0, 0, 2, 8]], + [[3, 7, 0, 0], [0, 0, 3, 7]], + ], + dtype=float, + ) + return _make_position_dataarray(data) + + +@pytest.fixture +def keypoint_positions() -> xr.DataArray: + """Two individuals with tail_base/neck keypoints, both facing +x.""" + data = np.array( + [ + [ + [[0.0, 10.0], [1.0, 11.0]], + [[0.0, 0.0], [0.0, 0.0]], + ], + [ + [[0.5, 10.5], [1.5, 11.5]], + [[0.0, 0.0], [0.0, 0.0]], + ], + [ + [[1.0, 11.0], [2.0, 12.0]], + [[0.0, 0.0], [0.0, 0.0]], + ], + ], + dtype=float, + ) + return _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + + +class TestComputePolarizationValidation: + """Tests for input validation in compute_polarization.""" + + def 
test_requires_dataarray(self): + """Raise TypeError if input is not an xarray.DataArray.""" + with pytest.raises(TypeError, match="xarray.DataArray"): + kinematics.compute_polarization(np.zeros((3, 2, 2))) + + @pytest.mark.parametrize( + "dims", + [ + ("space", "individuals"), + ("time", "individuals"), + ("time", "space"), + ], + ids=["missing_time", "missing_space", "missing_individuals"], + ) + def test_requires_time_space_individuals(self, dims): + """Raise ValueError if required dimensions are missing.""" + data = xr.DataArray(np.zeros((2, 2)), dims=dims) + with pytest.raises(ValueError, match="time|space|individuals"): + kinematics.compute_polarization(data) + + def test_rejects_unexpected_dimensions(self): + """Raise ValueError if data contains unsupported dimensions.""" + data = xr.DataArray( + np.zeros((3, 2, 2, 2)), + dims=["time", "space", "individuals", "batch"], + coords={ + "time": [0, 1, 2], + "space": ["x", "y"], + "individuals": ["a", "b"], + "batch": [0, 1], + }, + ) + with pytest.raises(ValueError, match="unsupported dimension"): + kinematics.compute_polarization(data) + + def test_requires_x_and_y_space_labels(self): + """Raise ValueError if space dimension lacks x and y labels.""" + data = xr.DataArray( + np.zeros((3, 2, 2)), + dims=["time", "space", "individuals"], + coords={ + "time": [0, 1, 2], + "space": ["lat", "lon"], + "individuals": ["a", "b"], + }, + ) + with pytest.raises( + ValueError, match="include coordinate labels 'x' and 'y'" + ): + kinematics.compute_polarization(data) + + @pytest.mark.parametrize( + "body_axis_keypoints", + [ + "neck", + ("tail_base",), + ("tail_base", "neck", "ear"), + 123, + ], + ids=["string", "length_one", "length_three", "non_iterable"], + ) + def test_body_axis_keypoints_must_be_length_two_iterable( + self, + body_axis_keypoints, + keypoint_positions, + ): + """Raise TypeError if body_axis_keypoints is not length-two.""" + with pytest.raises(TypeError, match="exactly two keypoint names"): + 
kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=body_axis_keypoints, + ) + + def test_body_axis_keypoints_must_be_hashable(self, keypoint_positions): + """Raise TypeError if body axis keypoints are not hashable.""" + with pytest.raises(TypeError, match="hashable"): + kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=(["tail_base"], "neck"), + ) + + def test_body_axis_keypoints_require_keypoints_dimension( + self, aligned_positions + ): + """Raise ValueError if body_axis_keypoints given without keypoints.""" + with pytest.raises( + ValueError, match="requires a 'keypoints' dimension" + ): + kinematics.compute_polarization( + aligned_positions, + body_axis_keypoints=("tail_base", "neck"), + ) + + def test_body_axis_keypoints_must_exist(self, keypoint_positions): + """Raise ValueError if specified keypoints do not exist in data.""" + with pytest.raises(ValueError, match="snout|keypoints"): + kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=("tail_base", "snout"), + ) + + def test_body_axis_keypoints_must_be_distinct(self, keypoint_positions): + """Raise ValueError if origin and target keypoints are identical.""" + with pytest.raises(ValueError, match="two distinct keypoint names"): + kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=("tail_base", "tail_base"), + ) + + @pytest.mark.parametrize( + ("displacement_frames", "expected_exception"), + [ + (0, ValueError), + (-1, ValueError), + (1.5, TypeError), + (True, TypeError), + ], + ids=["zero", "negative", "float", "bool"], + ) + def test_displacement_frames_must_be_positive_integer( + self, + aligned_positions, + displacement_frames, + expected_exception, + ): + """Raise error if displacement_frames is not a positive integer.""" + with pytest.raises(expected_exception, match="positive integer|>= 1"): + kinematics.compute_polarization( + aligned_positions, + displacement_frames=displacement_frames, + ) + + def 
test_invalid_displacement_frames_is_ignored_in_keypoint_mode( + self, + keypoint_positions, + ): + """Invalid displacement_frames is ignored when keypoints are used.""" + polarization = kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=("tail_base", "neck"), + displacement_frames=0, + ) + assert np.allclose(polarization.values, 1.0, atol=1e-10) + + def test_requires_space_coordinate_labels_to_exist(self): + """Raise ValueError if the space dimension has no coordinate labels.""" + data = xr.DataArray( + np.zeros((3, 2, 2)), + dims=["time", "space", "individuals"], + coords={ + "time": [0, 1, 2], + "individuals": ["a", "b"], + }, + name="position", + ) + with pytest.raises( + ValueError, + match="coordinate labels for the 'space' dimension", + ): + kinematics.compute_polarization(data) + + def test_empty_keypoints_dimension_raises_in_displacement_mode(self): + """Raise if keypoints dimension exists but contains no entries.""" + data = xr.DataArray( + np.empty((3, 2, 0, 2)), + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": [0, 1, 2], + "space": ["x", "y"], + "keypoints": [], + "individuals": ["a", "b"], + }, + name="position", + ) + with pytest.raises(ValueError, match="at least one keypoint"): + kinematics.compute_polarization(data) + + +class TestComputePolarizationBehavior: + """Tests for polarization computation behavior.""" + + def test_aligned_motion_gives_one(self, aligned_positions): + """Polarization is 1.0 when all individuals move in same direction.""" + polarization = kinematics.compute_polarization(aligned_positions) + assert np.isnan(polarization.values[0]) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + + def test_opposite_motion_gives_zero(self, opposite_positions): + """Polarization is 0.0 when individuals move in opposite directions.""" + polarization = kinematics.compute_polarization(opposite_positions) + assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) + + def 
test_perpendicular_cardinal_directions_give_zero( + self, perpendicular_positions + ): + """Polarization is 0.0 when four individuals move in cardinal dirs.""" + polarization = kinematics.compute_polarization(perpendicular_positions) + assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) + + def test_partial_alignment_matches_expected_magnitude( + self, + partial_alignment_positions, + ): + """Polarization matches expected value for partial alignment.""" + polarization = kinematics.compute_polarization( + partial_alignment_positions + ) + expected = np.sqrt(5) / 3 + assert np.allclose(polarization.values[1:], expected, atol=1e-10) + + def test_single_individual_gives_one(self): + """Polarization is 1.0 for a single moving individual.""" + data = np.array( + [ + [[0], [0]], + [[1], [0]], + [[2], [0]], + [[3], [0]], + ], + dtype=float, + ) + polarization = kinematics.compute_polarization( + _make_position_dataarray(data) + ) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + + def test_stationary_individuals_are_excluded(self): + """Stationary individuals produce NaN polarization and angle.""" + data = np.array( + [ + [[0, 10], [0, 0]], + [[0, 10], [0, 0]], + [[0, 10], [0, 0]], + ], + dtype=float, + ) + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + assert np.all(np.isnan(polarization.values)) + assert np.all(np.isnan(mean_angle.values)) + + def test_stationary_and_moving_individuals_uses_only_valid_headings(self): + """Only moving individuals contribute to polarization.""" + data = np.array( + [ + [[0, 10], [0, 0]], + [[1, 10], [0, 0]], + [[2, 10], [0, 0]], + [[3, 10], [0, 0]], + ], + dtype=float, + ) + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) + + def 
test_one_coordinate_nan_excludes_that_individual(self): + """NaN in one coordinate excludes that individual from calculation.""" + data = np.array( + [ + [[0, 10], [0, 0]], + [[1, np.nan], [0, 0]], + [[2, 12], [0, 0]], + ], + dtype=float, + ) + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + assert np.isnan(polarization.values[0]) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) + + def test_nan_in_body_axis_heading_excludes_that_individual(self): + """NaN in keypoint position excludes that individual.""" + data = np.array( + [ + [ + [[0.0, 10.0], [1.0, 11.0]], + [[0.0, 0.0], [0.0, 0.0]], + ], + [ + [[1.0, 10.0], [2.0, np.nan]], + [[0.0, 0.0], [0.0, np.nan]], + ], + [ + [[2.0, 12.0], [3.0, 13.0]], + [[0.0, 0.0], [0.0, 0.0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + polarization = kinematics.compute_polarization( + da, + body_axis_keypoints=("tail_base", "neck"), + ) + assert np.allclose(polarization.values[[0, 2]], 1.0, atol=1e-10) + assert np.allclose(polarization.values[1], 1.0, atol=1e-10) + + def test_empty_individual_axis_returns_all_nan(self): + """Empty individuals axis returns all NaN values.""" + data = _make_position_dataarray( + np.empty((3, 2, 0)), + individuals=[], + space=["x", "y"], + ) + polarization, mean_angle = kinematics.compute_polarization( + data, + return_angle=True, + ) + assert np.all(np.isnan(polarization.values)) + assert np.all(np.isnan(mean_angle.values)) + + def test_empty_time_axis_returns_empty_outputs(self): + """Empty time axis returns empty output arrays.""" + data = xr.DataArray( + np.empty((0, 2, 0)), + dims=["time", "space", "individuals"], + coords={"time": [], "space": ["x", "y"], "individuals": []}, + name="position", + ) + polarization, mean_angle = kinematics.compute_polarization( + data, + return_angle=True, + ) + 
assert polarization.shape == (0,) + assert mean_angle.shape == (0,) + assert polarization.name == "polarization" + assert mean_angle.name == "mean_angle" + + def test_preserves_non_uniform_time_coordinates(self, aligned_positions): + """Non-uniform time coordinates are preserved in output.""" + time = [0.0, 0.25, 0.75, 1.5] + data = aligned_positions.assign_coords(time=time) + polarization, mean_angle = kinematics.compute_polarization( + data, + return_angle=True, + ) + np.testing.assert_array_equal(polarization.time.values, time) + np.testing.assert_array_equal(mean_angle.time.values, time) + + def test_polarization_is_invariant_to_individual_order(self): + """Polarization is independent of individual ordering.""" + data = np.array( + [ + [[0, 5, 0], [0, 0, 0]], + [[1, 6, 0], [0, 0, 1]], + [[2, 7, 0], [0, 0, 2]], + [[3, 8, 0], [0, 0, 3]], + ], + dtype=float, + ) + da = _make_position_dataarray(data) + da_permuted = da.isel(individuals=[2, 0, 1]) + + pol_original = kinematics.compute_polarization(da) + pol_permuted = kinematics.compute_polarization(da_permuted) + + np.testing.assert_allclose( + pol_original.values, pol_permuted.values, atol=1e-10 + ) + + def test_zero_length_body_axis_vectors_are_excluded(self): + """Zero-length body-axis headings are excluded as invalid.""" + # ind0 has coincident tail_base and neck (zero-length heading) + # ind1 has valid +x body axis heading + data = np.array( + [ + [ + [[0.0, 10.0], [0.0, 11.0]], # x: ind0 zero-length, ind1 +1 + [[0.0, 0.0], [0.0, 0.0]], # y + ], + [ + [[0.0, 10.5], [0.0, 11.5]], + [[0.0, 0.0], [0.0, 0.0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + + polarization, mean_angle = kinematics.compute_polarization( + da, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + assert np.allclose(polarization.values, 1.0, atol=1e-10) + assert np.allclose(mean_angle.values, 0.0, atol=1e-10) + + def 
test_polarization_is_invariant_to_translation( + self, + partial_alignment_positions, + ): + """Adding a constant offset does not change polarization.""" + shifted = partial_alignment_positions.copy() + shifted.loc[{"space": "x"}] = shifted.sel(space="x") + 1000.0 + shifted.loc[{"space": "y"}] = shifted.sel(space="y") - 500.0 + + pol_original = kinematics.compute_polarization( + partial_alignment_positions + ) + pol_shifted = kinematics.compute_polarization(shifted) + + np.testing.assert_allclose( + pol_original.values, + pol_shifted.values, + atol=1e-10, + equal_nan=True, + ) + + def test_polarization_is_invariant_to_positive_scaling( + self, + partial_alignment_positions, + ): + """Positive scalar multiplication preserves polarization.""" + scaled = partial_alignment_positions * 7.5 + + pol_original = kinematics.compute_polarization( + partial_alignment_positions + ) + pol_scaled = kinematics.compute_polarization(scaled) + + np.testing.assert_allclose( + pol_original.values, + pol_scaled.values, + atol=1e-10, + equal_nan=True, + ) + + def test_polarization_is_invariant_to_global_rotation( + self, + partial_alignment_positions, + ): + """A global planar rotation preserves polarization magnitude.""" + x = partial_alignment_positions.sel(space="x") + y = partial_alignment_positions.sel(space="y") + + rotated = partial_alignment_positions.copy() + rotated.loc[{"space": "x"}] = -y + rotated.loc[{"space": "y"}] = x + + pol_original = kinematics.compute_polarization( + partial_alignment_positions + ) + pol_rotated = kinematics.compute_polarization(rotated) + + np.testing.assert_allclose( + pol_original.values, + pol_rotated.values, + atol=1e-10, + equal_nan=True, + ) + + def _body_axis_baseline(self): + """Build body-axis test data and return (da, pol, angle). + + Three individuals with body axes: +x, +x, +y. + Vector sum = (2, 1), polarization = sqrt(5)/3, angle = atan2(1,2). + Absolute positions differ across frames to test body-axis + independence from location. 
+ """ + data = np.array( + [ + [ + [[0.0, 10.0, -2.0], [1.0, 11.0, -2.0]], + [[0.0, 5.0, 3.0], [0.0, 5.0, 4.0]], + ], + [ + [[100.0, 50.0, 7.0], [101.0, 51.0, 7.0]], + [[-1.0, 20.0, -3.0], [-1.0, 20.0, -2.0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + pol, angle = kinematics.compute_polarization( + da, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + return da, pol, angle + + def test_body_axis_baseline_matches_expected_values(self): + """Body-axis polarization and angle match hand-computed values.""" + _da, pol, angle = self._body_axis_baseline() + np.testing.assert_allclose(pol.values, np.sqrt(5) / 3, atol=1e-10) + np.testing.assert_allclose( + angle.values, np.arctan2(1.0, 2.0), atol=1e-10 + ) + + def test_body_axis_invariance_to_translation(self): + """Global translation does not change body-axis polarization.""" + da, pol_base, angle_base = self._body_axis_baseline() + translated = da.copy() + translated.loc[{"space": "x"}] = translated.sel(space="x") + 123.4 + translated.loc[{"space": "y"}] = translated.sel(space="y") - 56.7 + + pol, angle = kinematics.compute_polarization( + translated, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + np.testing.assert_allclose(pol.values, pol_base.values, atol=1e-10) + np.testing.assert_allclose(angle.values, angle_base.values, atol=1e-10) + + def test_body_axis_invariance_to_positive_scaling(self): + """Positive scaling preserves body-axis polarization and angle.""" + da, pol_base, angle_base = self._body_axis_baseline() + + pol, angle = kinematics.compute_polarization( + da * 4.2, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + np.testing.assert_allclose(pol.values, pol_base.values, atol=1e-10) + np.testing.assert_allclose(angle.values, angle_base.values, atol=1e-10) + + def test_body_axis_angle_rotates_under_global_rotation(self): + """Polarization preserved, angle shifts by pi/2. 
+ + Tests behavior under 90-degree rotation. + """ + da, pol_base, angle_base = self._body_axis_baseline() + rotated = da.copy() + x = da.sel(space="x") + y = da.sel(space="y") + rotated.loc[{"space": "x"}] = -y + rotated.loc[{"space": "y"}] = x + + pol, angle = kinematics.compute_polarization( + rotated, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + np.testing.assert_allclose(pol.values, pol_base.values, atol=1e-10) + + expected = angle_base.values + (np.pi / 2) + expected = (expected + np.pi) % (2 * np.pi) - np.pi + np.testing.assert_allclose(angle.values, expected, atol=1e-10) + + +class TestHeadingSourceSelection: + """Tests for heading computation mode selection.""" + + def test_body_axis_heading_valid_on_first_frame_returns_expected_angle( + self, keypoint_positions + ): + """Body-axis heading is valid from frame 0 and returns angle 0.""" + polarization, mean_angle = kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + assert np.allclose(polarization.values, 1.0, atol=1e-10) + assert np.allclose(mean_angle.values, 0.0, atol=1e-10) + + def test_displacement_mode_with_keypoints_uses_first_keypoint(self): + """Displacement mode uses first keypoint when multiple exist.""" + data = np.array( + [ + [ + [[0, 10], [0, 10]], + [[0, 0], [0, 0]], + ], + [ + [[1, 11], [1, 9]], + [[0, 0], [0, 0]], + ], + [ + [[2, 12], [2, 8]], + [[0, 0], [0, 0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["thorax", "head"]) + polarization = kinematics.compute_polarization(da) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + + def test_explicit_keypoint_selection_with_sel(self): + """Pre-selecting keypoint with .sel() uses that keypoint. + + Data shape: (time, space, keypoints, individuals). 
+ + X-coordinates across frames: + + Keypoint | Individual | Frame 0 | Frame 1 | Displacement + ---------|------------|---------|---------|------------- + thorax | ind0 | 0 | 1 | +1 (right) + thorax | ind1 | 10 | 11 | +1 (right) + head | ind0 | 0 | 1 | +1 (right) + head | ind1 | 10 | 9 | -1 (left) + + Thorax: both individuals move right -> polarization = 1.0 + Head: ind0 moves right, ind1 moves left -> polarization = 0.0 + """ + data = np.array( + [ + [ + [[0, 10], [0, 10]], + [[0, 0], [0, 0]], + ], + [ + [[1, 11], [1, 9]], + [[0, 0], [0, 0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["thorax", "head"]) + + # Without .sel(): uses thorax -> both move right -> polarization = 1.0 + pol_default = kinematics.compute_polarization(da) + assert np.allclose(pol_default.values[1], 1.0, atol=1e-10) + + # With .sel(): head selected -> ind0 right, ind1 left -> 0.0 + pol_head = kinematics.compute_polarization(da.sel(keypoints="head")) + assert np.allclose(pol_head.values[1], 0.0, atol=1e-10) + + def test_body_axis_heading_overrides_displacement_behavior(self): + """Body-axis heading overrides displacement computation.""" + data = np.array( + [ + [ + [[0.0, 0.0], [1.0, 1.0]], + [[0.0, 2.0], [0.0, 2.0]], + ], + [ + [[0.0, 0.0], [1.0, 1.0]], + [[1.0, 3.0], [1.0, 3.0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + polarization = kinematics.compute_polarization( + da, + body_axis_keypoints=("tail_base", "neck"), + displacement_frames=1000, + ) + assert np.allclose(polarization.values, 1.0, atol=1e-10) + + def test_extra_spatial_dimensions_are_ignored_for_planar_metrics(self): + """Extra spatial dimensions (z) are ignored; only x/y used.""" + data = np.array( + [ + [[0, 5], [0, 0], [0, 100]], + [[1, 6], [0, 0], [10, -100]], + [[2, 7], [0, 0], [-10, 50]], + [[3, 8], [0, 0], [999, -999]], + ], + dtype=float, + ) + da = _make_position_dataarray(data, space=["x", "y", "z"]) + polarization, 
mean_angle = kinematics.compute_polarization( + da, + return_angle=True, + ) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) + + +class TestDisplacementFrames: + """Tests for displacement_frames parameter behavior.""" + + def test_first_n_frames_are_nan(self, aligned_positions): + """First N frames are NaN when displacement_frames=N.""" + polarization, mean_angle = kinematics.compute_polarization( + aligned_positions, + displacement_frames=2, + return_angle=True, + ) + assert np.all(np.isnan(polarization.values[:2])) + assert np.all(np.isnan(mean_angle.values[:2])) + assert np.allclose(polarization.values[2:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[2:], 0.0, atol=1e-10) + + def test_nan_in_reference_frame_propagates_to_that_displacement_window( + self, + ): + """NaN in reference frame propagates through displacement window.""" + data = np.array( + [ + [[0, 5], [0, 0]], + [[np.nan, np.nan], [np.nan, np.nan]], + [[2, 7], [0, 0]], + [[3, 8], [0, 0]], + [[4, 9], [0, 0]], + ], + dtype=float, + ) + polarization = kinematics.compute_polarization( + _make_position_dataarray(data), + displacement_frames=2, + ) + assert np.isnan(polarization.values[0]) + assert np.isnan(polarization.values[1]) + assert np.allclose(polarization.values[2], 1.0, atol=1e-10) + assert np.isnan(polarization.values[3]) + assert np.allclose(polarization.values[4], 1.0, atol=1e-10) + + def test_larger_displacement_window_can_change_alignment_estimate(self): + """Larger displacement window smooths jittery movement.""" + data = np.array( + [ + [[0, 10], [0, 0]], + [[2, 9], [0, 0]], + [[1, 11], [0, 0]], + [[3, 10], [0, 0]], + [[2, 12], [0, 0]], + [[4, 11], [0, 0]], + ], + dtype=float, + ) + da = _make_position_dataarray(data) + + pol_1frame = kinematics.compute_polarization(da, displacement_frames=1) + pol_2frame = kinematics.compute_polarization(da, displacement_frames=2) + + assert 
np.allclose(pol_1frame.values[1:], 0.0, atol=1e-10) + assert np.allclose(pol_2frame.values[2:], 1.0, atol=1e-10) + + def test_displacement_frames_larger_than_time_axis_returns_all_nan( + self, + aligned_positions, + ): + """Oversized displacement windows produce no valid headings.""" + polarization, mean_angle = kinematics.compute_polarization( + aligned_positions, + displacement_frames=10, + return_angle=True, + ) + assert np.all(np.isnan(polarization.values)) + assert np.all(np.isnan(mean_angle.values)) + + +class TestReturnAngle: + """Tests for return_angle parameter behavior.""" + + def test_default_returns_only_polarization(self, aligned_positions): + """Default return is a single polarization DataArray.""" + result = kinematics.compute_polarization(aligned_positions) + assert isinstance(result, xr.DataArray) + assert result.name == "polarization" + assert result.dims == ("time",) + + def test_return_angle_true_returns_named_pair(self, aligned_positions): + """return_angle=True returns (polarization, mean_angle) tuple.""" + polarization, mean_angle = kinematics.compute_polarization( + aligned_positions, + return_angle=True, + ) + assert isinstance(polarization, xr.DataArray) + assert isinstance(mean_angle, xr.DataArray) + assert (polarization.name, mean_angle.name) == ( + "polarization", + "mean_angle", + ) + assert polarization.dims == ("time",) + assert mean_angle.dims == ("time",) + + @pytest.mark.parametrize( + ("data", "expected_angle", "use_abs"), + [ + ( + np.array( + [ + [[0, 5], [0, 0]], + [[1, 6], [0, 0]], + [[2, 7], [0, 0]], + ], + dtype=float, + ), + 0.0, + False, + ), + ( + np.array( + [ + [[0, 0], [0, 5]], + [[0, 0], [1, 6]], + [[0, 0], [2, 7]], + ], + dtype=float, + ), + np.pi / 2, + False, + ), + ( + np.array( + [ + [[10, 15], [0, 0]], + [[9, 14], [0, 0]], + [[8, 13], [0, 0]], + ], + dtype=float, + ), + np.pi, + True, + ), + ], + ids=["positive_x", "positive_y", "negative_x"], + ) + def test_mean_angle_matches_cardinal_directions( + self, + 
data, + expected_angle, + use_abs, + ): + """Mean angle matches expected value for cardinal directions.""" + _, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + values = mean_angle.values[1:] + if use_abs: + values = np.abs(values) + assert np.allclose(values, expected_angle, atol=1e-10) + + def test_mean_angle_diagonal_motion_is_pi_over_four(self): + """Mean angle is pi/4 for diagonal (+x, +y) motion.""" + data = np.array( + [ + [[0, 5], [0, 5]], + [[1, 6], [1, 6]], + [[2, 7], [2, 7]], + ], + dtype=float, + ) + _, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + assert np.allclose(mean_angle.values[1:], np.pi / 4, atol=1e-10) + + def test_mean_angle_partial_alignment_matches_vector_average( + self, + partial_alignment_positions, + ): + """Mean angle matches vector average for partial alignment.""" + _, mean_angle = kinematics.compute_polarization( + partial_alignment_positions, + return_angle=True, + ) + expected = np.arctan2(1, 2) + assert np.allclose(mean_angle.values[1:], expected, atol=1e-10) + + def test_mean_angle_is_nan_when_net_vector_cancels( + self, + opposite_positions, + perpendicular_positions, + ): + """Mean angle is NaN when heading vectors cancel out.""" + pol_opposite, angle_opposite = kinematics.compute_polarization( + opposite_positions, + return_angle=True, + ) + pol_perp, angle_perp = kinematics.compute_polarization( + perpendicular_positions, + return_angle=True, + ) + assert np.allclose(pol_opposite.values[1:], 0.0, atol=1e-10) + assert np.allclose(pol_perp.values[1:], 0.0, atol=1e-10) + assert np.all(np.isnan(angle_opposite.values[1:])) + assert np.all(np.isnan(angle_perp.values[1:])) + + def test_mean_angle_rotates_with_global_rotation( + self, + partial_alignment_positions, + ): + """Mean angle shifts by the same amount under global rotation.""" + _, angle_original = kinematics.compute_polarization( + 
partial_alignment_positions, + return_angle=True, + ) + + x = partial_alignment_positions.sel(space="x") + y = partial_alignment_positions.sel(space="y") + + rotated = partial_alignment_positions.copy() + rotated.loc[{"space": "x"}] = -y + rotated.loc[{"space": "y"}] = x + + _, angle_rotated = kinematics.compute_polarization( + rotated, + return_angle=True, + ) + + expected = angle_original.values[1:] + (np.pi / 2) + expected = (expected + np.pi) % (2 * np.pi) - np.pi + + np.testing.assert_allclose( + angle_rotated.values[1:], + expected, + atol=1e-10, + ) + + def test_mean_angle_wraparound_near_pi_is_handled_correctly(self): + """Headings near +pi and -pi should average leftward, not to zero.""" + # Two individuals moving left with tiny y-offsets in opposite dirs. + # This creates headings very close to +pi and -pi. + data = np.array( + [ + [[0.0, 0.0], [0.0, 0.0]], + [[-1.0, -1.0], [1e-6, -1e-6]], + [[-2.0, -2.0], [2e-6, -2e-6]], + ], + dtype=float, + ) + + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose( + np.abs(mean_angle.values[1:]), + np.pi, + atol=1e-6, + ) + + def test_in_degrees_true_returns_degrees(self): + """in_degrees=True returns angle in degrees.""" + # Two individuals moving in +y direction + data = np.array( + [ + [[0, 0], [0, 0]], + [[0, 0], [1, 1]], + [[0, 0], [2, 2]], + ], + dtype=float, + ) + _, mean_angle_rad = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + in_degrees=False, + ) + _, mean_angle_deg = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + in_degrees=True, + ) + # +y direction = 90 degrees = pi/2 radians + assert np.allclose(mean_angle_rad.values[1:], np.pi / 2, atol=1e-10) + assert np.allclose(mean_angle_deg.values[1:], 90.0, atol=1e-10)