diff --git a/examples/compute_polarization_AP_inference.py b/examples/compute_polarization_AP_inference.py new file mode 100644 index 000000000..8328203f9 --- /dev/null +++ b/examples/compute_polarization_AP_inference.py @@ -0,0 +1,3958 @@ +#!/usr/bin/env python3 +"""AP inference demo script: PC1-based body-axis inference vs hand-curated GT. + +Tests the prior-free AP inference pipeline against manually curated ground +truth rankings across 5 multi-animal SLEAP datasets. Operates in three +passes: + Pass 1 — R×M Selection: find best individual per file + Pass 2 — Cross-Individual Ordering Consistency: compare raw PC1-based + orderings against best individual's ordering (pseudo GT) + Pass 3 — Inferred AP Concordance: for each + individual, compare the velocity-inferred AP ordering + (anterior_sign × PC1) of GT nodes against hand-curated GT + +After the passes, all GT pair permutations × all individuals are run +through validate_ap and stored in HDF5. analyze_results then reads back +the H5 to produce GT coverage analysis, suggested pair analysis, and +the data for Figure 2. + +Generates two types of figures: + Figure 1 — Per-file detail (2×2 tile per best individual) + Figure 2 — Cross-dataset comparison (skeletons + coverage + ordering) + +Parallelizes at pair×individual level for max throughput. +""" + +import itertools +import multiprocessing as mp +import os +import sys +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed +from datetime import datetime +from pathlib import Path +from urllib.request import urlretrieve + +import h5py +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.legend_handler import HandlerBase +from matplotlib.lines import Line2D + +# Configuration +ROOT_PATH = Path(__file__).parent / "datasets" / "multi-animal" +SLP_DIR = ROOT_PATH / "slp" +MP4_DIR = ROOT_PATH / "mp4" +OUTPUT_DIR = ROOT_PATH / "exports" / "AP-inference-demo" +FIGURES_DIR = OUTPUT_DIR / "figures" +LOGS_DIR = OUTPUT_DIR / "logs" +H5_DIR = OUTPUT_DIR / "h5" +N_WORKERS = mp.cpu_count() + +# Hand-curated ground truth AP rankings (not exhaustive) +# Format: {file_stem: {node_index: rank}} +# Convention: higher rank = more anterior (rank 1 = most posterior in subset) +GROUND_TRUTH = { + # Higher rank = more anterior + "free-moving-2flies-ID-13nodes-1024x1024x1-30_3pxmm": { + 0: 3, # head (most anterior, rank 3) + 1: 2, # thorax + 2: 1, # abdomen (most posterior, rank 1) + }, + "free-moving-2mice-noID-5nodes-1280x1024x1-1_9pxmm": { + 0: 2, # snout (anterior, rank 2) + 3: 1, # tail-base (posterior, rank 1) + }, + "free-moving-4gerbils-ID-14nodes-1024x1280x3-2pxmm": { + 0: 6, # nose (most anterior, rank 6) + 5: 5, # spine1 + 6: 4, # spine2 + 7: 3, # spine3 + 8: 2, # spine4 + 9: 1, # spine5 (most posterior, rank 1) + }, + "free-moving-5mice-noID-11nodes-1280x1024x1-1_97pxmm": { + 0: 3, # nose (most anterior, rank 3) + 1: 2, # neck + 6: 1, # tail_base (most posterior, rank 1) + }, + "freemoving-2bees-noID-21nodes-1535x2048x1-14pxmm": { + 1: 3, # head (most anterior, rank 3) + 0: 2, # thorax + 2: 1, # abdomen (most posterior, rank 1) + }, +} + +# Display labels for files +FILE_LABELS = { + "free-moving-2flies-ID-13nodes-1024x1024x1-30_3pxmm": "2Flies.slp", + "free-moving-2mice-noID-5nodes-1280x1024x1-1_9pxmm": "2Mice.slp", + "free-moving-4gerbils-ID-14nodes-1024x1280x3-2pxmm": "4Gerbils.slp", + "free-moving-5mice-noID-11nodes-1280x1024x1-1_97pxmm": "5Mice.slp", + "freemoving-2bees-noID-21nodes-1535x2048x1-14pxmm": "2Bees.slp", +} + +DEMO_DATASETS = { + "free-moving-2flies-ID-13nodes-1024x1024x1-30_3pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/clips/talk_title_slide%4013150-14500.mp4", + "slp": "https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/clips/talk_title_slide%4013150-14500.slp", + }, + "free-moving-5mice-noID-11nodes-1280x1024x1-1_97pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/wang_4mice_john/clips/OFTsocial5mice-0000-00%4015488-18736.mp4", + "slp": "https://storage.googleapis.com/sleap-data/datasets/wang_4mice_john/clips/OFTsocial5mice-0000-00%4015488-18736.slp", + }, + "free-moving-2mice-noID-5nodes-1280x1024x1-1_9pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/eleni_mice/clips/20200111_USVpairs_court1_M1_F1_top-01112020145828-0000%400-2560.mp4", + "slp": "https://storage.googleapis.com/sleap-data/datasets/eleni_mice/clips/20200111_USVpairs_court1_M1_F1_top-01112020145828-0000%400-2560.slp", + }, + "freemoving-2bees-noID-21nodes-1535x2048x1-14pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/yan_bees/clips/bees_demo%4021000-23000.mp4", + "slp": "https://storage.googleapis.com/sleap-data/datasets/yan_bees/clips/bees_demo%4021000-23000.slp", + }, + "free-moving-4gerbils-ID-14nodes-1024x1280x3-2pxmm": { + "mp4": "https://storage.googleapis.com/sleap-data/datasets/nyu-gerbils/clips/2020-3-10_daytime_5mins_compressedTalmo%403200-5760.mp4", + "slp": "https://storage.googleapis.com/sleap-data/datasets/nyu-gerbils/clips/2020-3-10_daytime_5mins_compressedTalmo%403200-5760.slp", + }, +} + + +class TeeOutput: + """Context manager that duplicates stdout to both console and a file.""" + + def __init__(self, filepath): + """Initialize with the target file path.""" + self.filepath = Path(filepath) + self.file = None + self.original_stdout = None + + def __enter__(self): + """Open the file and redirect stdout.""" + self.filepath.parent.mkdir(parents=True, exist_ok=True) + self.file = open(self.filepath, "w") + self.original_stdout = sys.stdout + sys.stdout = self + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Restore stdout and close the file.""" + sys.stdout = self.original_stdout + if self.file: + self.file.close() + return False + + def write(self, text): + """Write text to both stdout and file.""" + self.original_stdout.write(text) + self.file.write(text) + self.file.flush() + + def flush(self): + """Flush both stdout and file.""" + self.original_stdout.flush() + self.file.flush() + + +def _download_file(url, destination): + """Download a single file to the requested destination.""" + destination.parent.mkdir(parents=True, exist_ok=True) + print(f"Downloading {destination.name}...") + urlretrieve(url, destination) + + +def ensure_demo_datasets(): + """Ensure the expected demo SLP/MP4 files exist locally.""" + ROOT_PATH.mkdir(parents=True, exist_ok=True) + SLP_DIR.mkdir(parents=True, exist_ok=True) + MP4_DIR.mkdir(parents=True, exist_ok=True) + + slp_files = sorted(SLP_DIR.glob("*.slp")) + mp4_files = sorted(MP4_DIR.glob("*.mp4")) + + should_bootstrap = len(slp_files) < len(DEMO_DATASETS) or len( + mp4_files + ) < len(DEMO_DATASETS) + if not should_bootstrap: + return + + print( + "Bootstrapping demo datasets " + f"(found {len(slp_files)} .slp and {len(mp4_files)} .mp4 files)..." + ) + + for file_stem, urls in DEMO_DATASETS.items(): + slp_target = SLP_DIR / f"{file_stem}.slp" + mp4_target = MP4_DIR / f"{file_stem}.mp4" + + if not slp_target.exists(): + _download_file(urls["slp"], slp_target) + if not mp4_target.exists(): + _download_file(urls["mp4"], mp4_target) + + +def process_single_validation(args): # noqa: C901 + """Process a single (file, individual, from_kp, to_kp) validation.""" + # Import inside worker to avoid pickling issues + from movement.io import load_poses + from movement.kinematics.body_axis import ValidateAPConfig, validate_ap + + ( + slp_path, + file_stem, + individual, + ind_idx, + from_kp, + to_kp, + from_idx, + to_idx, + n_kp, + n_ind, + n_frames, + ) = args + + # Suppress stdout using context manager + with open(os.devnull, "w") as devnull: + old_stdout = sys.stdout + sys.stdout = devnull + + try: + ds = load_poses.from_sleap_file(Path(slp_path)) + if "individuals" in ds.position.dims: + pos_data = ds.position.sel(individuals=individual) + else: + pos_data = ds.position + + config = ValidateAPConfig() + val = validate_ap( + pos_data, + from_node=from_kp, + to_node=to_kp, + config=config, + verbose=False, + ) + + rec = { + "file": file_stem, + "individual": str(individual), + "individual_idx": ind_idx, + "from_keypoint": from_kp, + "to_keypoint": to_kp, + "from_index": from_idx, + "to_index": to_idx, + "n_keypoints": n_kp, + "n_individuals": n_ind, + "n_frames": n_frames, + "validation_success": val.get("success", False), + "resultant_length": val.get("resultant_length", np.nan), + "vote_margin": val.get("vote_margin", np.nan), + "num_selected_frames": val.get("num_selected_frames", 0), + "circ_mean_dir": val.get("circ_mean_dir", np.nan), + "anterior_sign": val.get("anterior_sign", 0), + "num_clusters": val.get("num_clusters", 0), + "error_msg": val.get("error_msg", ""), + "error": False, + "error_type": "", + } + + r, m = rec["resultant_length"], rec["vote_margin"] + rec["rxm"] = r * m if not (np.isnan(r) or np.isnan(m)) else np.nan + + pr = val.get("pair_report") + if pr: + rec.update( + { + "pr_success": pr.success, + "pr_failure_step": pr.failure_step, + "pr_failure_reason": pr.failure_reason, + "pr_scenario": pr.scenario, + "pr_outcome": pr.outcome, + "pr_warning_message": pr.warning_message, + "pr_input_pair_in_candidates": ( + pr.input_pair_in_candidates + ), + "pr_input_pair_opposite_sides": ( + pr.input_pair_opposite_sides + ), + "pr_input_pair_separation_abs": ( + pr.input_pair_separation_abs + ), + "pr_input_pair_is_distal": pr.input_pair_is_distal, + "pr_input_pair_rank": pr.input_pair_rank, + "pr_input_pair_order_matches_inference": ( + pr.input_pair_order_matches_inference + ), + "pr_max_separation_distal": pr.max_separation_distal, + "pr_max_separation": pr.max_separation, + "pr_lateral_offset_min": pr.lateral_offset_min, + "pr_lateral_offset_max": pr.lateral_offset_max, + "pr_midpoint_pc1": pr.midpoint_pc1, + "pr_pc1_min": pr.pc1_min, + "pr_pc1_max": pr.pc1_max, + "pr_midline_dist_max": pr.midline_dist_max, + # Cascade counts (same for all records + # from the same individual) + "n_valid_nodes": int( + np.sum(~np.isnan(pr.lateral_offsets_norm)) + ) + if len(pr.lateral_offsets_norm) > 0 + else 0, + "n_step1_candidates": len(pr.sorted_candidate_nodes), + "n_step2_pairs": len(pr.valid_pairs), + "n_step3_distal": len(pr.distal_pairs), + "n_step3_proximal": len(pr.proximal_pairs), + } + ) + + if len(pr.max_separation_distal_nodes) > 0: + distal = pr.max_separation_distal_nodes + rec["suggested_from_idx"] = int(distal[0]) + rec["suggested_to_idx"] = int(distal[1]) + rec["suggested_type"] = "distal" + elif len(pr.max_separation_nodes) > 0: + prox = pr.max_separation_nodes + rec["suggested_from_idx"] = int(prox[0]) + rec["suggested_to_idx"] = int(prox[1]) + rec["suggested_type"] = "proximal" + else: + rec["suggested_from_idx"] = -1 + rec["suggested_to_idx"] = -1 + rec["suggested_type"] = "" + + # Store avg_skeleton and PC1 from validation result + avg_skel = val.get("avg_skeleton") + pc1_vec = val.get("PC1") + if avg_skel is not None and not np.all(np.isnan(avg_skel)): + rec["avg_skeleton"] = avg_skel.tolist() # (n_keypoints, 2) + else: + rec["avg_skeleton"] = None + if pc1_vec is not None: + rec["PC1"] = pc1_vec.tolist() # (2,) + else: + rec["PC1"] = None + + # Store velocity projections for histogram + vel_projs = val.get("vel_projs_pc1") + if vel_projs is not None and len(vel_projs) > 0: + rec["vel_projs_pc1"] = vel_projs.tolist() + else: + rec["vel_projs_pc1"] = None + + return rec + + except Exception as e: + import traceback + + print( + f"WARNING: validate_ap failed for " + f"{file_stem} / {individual} " + f"({from_kp} → {to_kp}): " + f"{type(e).__name__}: {e}", + file=sys.stderr, + ) + traceback.print_exc(file=sys.stderr) + return { + "file": file_stem, + "individual": str(individual), + "individual_idx": ind_idx, + "from_keypoint": from_kp, + "to_keypoint": to_kp, + "from_index": from_idx, + "to_index": to_idx, + "n_keypoints": n_kp, + "n_individuals": n_ind, + "n_frames": n_frames, + "error": True, + "error_type": f"{type(e).__name__}: {e}", + "validation_success": False, + "rxm": np.nan, + "resultant_length": np.nan, + "vote_margin": np.nan, + "anterior_sign": 0, + } + finally: + sys.stdout = old_stdout + + +def generate_rxm_tasks(slp_files): + """Pass 1: generate one task per (file, individual) to compute R×M. + + Uses first two GT node indices as an arbitrary pair since R×M depends + only on the individual's motion and body shape, not the input pair. + Returns tasks and metadata for subsequent passes. + """ + from movement.io import load_poses + + tasks = [] + file_metadata = {} + + for slp_file in slp_files: + ds = load_poses.from_sleap_file(slp_file) + keypoints = [str(k) for k in ds.coords["keypoints"].values] + individuals = [str(i) for i in ds.coords["individuals"].values] + n_frames = ds.sizes["time"] + n_kp = len(keypoints) + has_individuals = "individuals" in ds.position.dims + n_ind = len(individuals) if has_individuals else 1 + + # Get GT node indices for this file (use first two as arbitrary pair) + gt_nodes = GROUND_TRUTH.get(slp_file.stem, {}) + gt_indices = list(gt_nodes.keys()) + if len(gt_indices) < 2: + continue + from_idx, to_idx = gt_indices[0], gt_indices[1] + + file_metadata[slp_file.stem] = { + "slp_file": slp_file, + "keypoints": keypoints, + "individuals": individuals, + "n_frames": n_frames, + "n_kp": n_kp, + "n_ind": n_ind, + "has_individuals": has_individuals, + "gt_indices": gt_indices, + } + + # One task per individual + for ind_idx in range(n_ind): + individual = individuals[ind_idx] if has_individuals else "single" + tasks.append( + ( + str(slp_file), + slp_file.stem, + individual, + ind_idx, + keypoints[from_idx], + keypoints[to_idx], + from_idx, + to_idx, + n_kp, + n_ind, + n_frames, + ) + ) + + return tasks, file_metadata + + +def find_best_individuals(rxm_results): + """Pass 1: Select best individual per file by maximum R×M. + + Returns: + best_individuals: {file_stem: individual_name} + all_rxm: {file_stem: {individual: rxm_value}} - + R×M values for all individuals + file_individual_data: {file_stem: {individual: + {"avg_skeleton": ..., "pc1": ..., "anterior_sign": ...}}} - + per-individual skeleton, PC1, and anterior_sign for downstream use + + """ + file_individual_rxm = defaultdict(dict) + file_individual_data = defaultdict(dict) + + for rec in rxm_results: + file_stem = rec["file"] + individual = rec["individual"] + rxm = rec["rxm"] + if not np.isnan(rxm): + file_individual_rxm[file_stem][individual] = rxm + file_individual_data[file_stem][individual] = { + "avg_skeleton": rec.get("avg_skeleton"), + "pc1": rec.get("PC1"), + "anterior_sign": rec.get("anterior_sign", 0), + } + + best_individuals = {} + for file_stem, individuals in file_individual_rxm.items(): + if not individuals: + continue + best_individuals[file_stem] = max(individuals, key=individuals.get) + + all_rxm = dict(file_individual_rxm) + return best_individuals, all_rxm, dict(file_individual_data) + + +def compute_pc1_orderings( + file_individual_data, file_metadata, best_individuals +): + """Project GT nodes onto each individual's PC1 and rank by projection. + + For each individual, GT nodes are projected onto that individual's + PC1 vector. Nodes are ranked by descending PC1 projection + (rank 1 = highest projection; whether this corresponds to + anterior or posterior depends on the individual's anterior_sign). + + Returns: + best_pc1_orderings: {file_stem: {node_idx: pc1_rank}} - + best individual's ordering (used as pseudo GT in Pass 2) + all_pc1_orderings: {file_stem: {individual: {node_idx: pc1_rank}}} - + all individuals' PC1-based orderings (used in Pass 2) + + """ + best_pc1_orderings = {} + all_pc1_orderings = {} + + for file_stem, ind_data in file_individual_data.items(): + if file_stem not in file_metadata: + continue + + gt_indices = file_metadata[file_stem]["gt_indices"] + best_ind = best_individuals.get(file_stem) + all_pc1_orderings[file_stem] = {} + + for individual, data in ind_data.items(): + avg_skeleton = data.get("avg_skeleton") + pc1 = data.get("pc1") + + if avg_skeleton is None or pc1 is None: + continue + + avg_skeleton = np.array(avg_skeleton) + pc1 = np.array(pc1) + + # Project GT nodes onto raw PC1 (no anterior_sign correction). + # Rank 1 = highest raw PC1 projection; whether this corresponds + # to anterior or posterior depends on anterior_sign. + gt_projections = {} + for node_idx in gt_indices: + if node_idx < len(avg_skeleton): + pos = avg_skeleton[node_idx] + if not np.any(np.isnan(pos)): + proj = np.dot(pos, pc1) + gt_projections[node_idx] = proj + + # Rank by raw PC1 projection (highest projection = rank 1) + sorted_nodes = sorted( + gt_projections.keys(), + key=lambda x: gt_projections[x], + reverse=True, + ) + ind_ordering = { + node: rank + 1 for rank, node in enumerate(sorted_nodes) + } + all_pc1_orderings[file_stem][individual] = ind_ordering + + # Store best individual's ordering separately + if individual == best_ind: + best_pc1_orderings[file_stem] = ind_ordering + + return best_pc1_orderings, all_pc1_orderings + + +def compute_inferred_ap_concordance(file_individual_data): + """Pass 3: Compare each individual's inferred AP ordering against GT. + + For each individual, GT nodes are projected onto the inferred AP axis + (anterior_sign × PC1, where positive = more anterior). All C(n,2) + unique pairs of GT nodes are compared pairwise: a pair is concordant + if the node with the higher inferred AP coordinate also has the + higher hand-curated GT rank (higher rank = more anterior). + + Returns: + individual_accuracy: {file_stem: {individual: + {"correct": n, "total": n, "accuracy": pct}}} + + """ + individual_accuracy = {} + + for file_stem, gt_ranks in GROUND_TRUTH.items(): + individual_accuracy[file_stem] = {} + ind_data = file_individual_data.get(file_stem, {}) + + for individual, data in ind_data.items(): + avg_skeleton = data.get("avg_skeleton") + pc1 = data.get("pc1") + anterior_sign = data.get("anterior_sign", 0) + + if avg_skeleton is None or pc1 is None or anterior_sign == 0: + continue + + avg_skeleton = np.array(avg_skeleton) + pc1 = np.array(pc1) + + # Inferred AP unit vector: anterior_sign × PC1 + pc1_norm = pc1 / np.linalg.norm(pc1) + e_ap = anterior_sign * pc1_norm + + # Project GT nodes onto inferred AP axis + gt_ap_coords = {} + for node_idx in gt_ranks: + if node_idx < len(avg_skeleton): + pos = avg_skeleton[node_idx] + if not np.any(np.isnan(pos)): + gt_ap_coords[node_idx] = np.dot(pos, e_ap) + + # Compare pairwise ordering against GT + correct = 0 + total = 0 + for a, b in itertools.combinations(gt_ap_coords.keys(), 2): + total += 1 + ap_a, ap_b = gt_ap_coords[a], gt_ap_coords[b] + gt_a, gt_b = gt_ranks[a], gt_ranks[b] + + # Concordant if relative ordering agrees: + # higher AP coord = more anterior should match + # higher GT rank = more anterior + if (ap_a > ap_b and gt_a > gt_b) or ( + ap_a < ap_b and gt_a < gt_b + ): + correct += 1 + + accuracy = 100 * correct / total if total > 0 else 0 + individual_accuracy[file_stem][individual] = { + "correct": correct, + "total": total, + "accuracy": accuracy, + } + + return individual_accuracy + + +def compare_orderings_to_pseudo_gt(all_pc1_orderings, pseudo_gt_orderings): + """Pass 2: Compare each individual's PC1-based ordering against pseudo GT. + + The pseudo GT is the best individual's PC1-based ordering. Comparison + is strict list equality: the sorted node-index sequences must be + identical. A single rank swap between any two nodes counts as a + mismatch. + + Returns: + ordering_matches: {file_stem: {individual: bool}} - + True if individual's PC1-based ordering matches pseudo GT + + """ + ordering_matches = {} + + for file_stem, best_ordering in pseudo_gt_orderings.items(): + ordering_matches[file_stem] = {} + ind_orderings = all_pc1_orderings.get(file_stem, {}) + + for individual, ind_ordering in ind_orderings.items(): + # Compare orderings - match if relative order of all nodes same + # Convert to sorted list of node indices by rank to compare + best_sorted = sorted( + best_ordering.keys(), key=lambda x: best_ordering[x] + ) + ind_sorted = sorted( + ind_ordering.keys(), key=lambda x: ind_ordering[x] + ) + + matches = best_sorted == ind_sorted + ordering_matches[file_stem][individual] = matches + + return ordering_matches + + +def generate_gt_validation_tasks(file_metadata, best_individuals): + """Generate H5 tasks for all GT node pair permutations. + + Tests all permutations of GT node pairs on ALL individuals. + Results compared against hand-curated GROUND_TRUTH (higher rank = more + anterior). + + Convention: (from_idx, to_idx) is "correctly ordered" if from is + posterior (lower GT rank) and to is anterior (higher GT rank), matching + the posterior→anterior direction used by compute_polarization's + body_axis_keypoints. + """ + tasks = [] + keypoints_by_file = {} + + for file_stem, meta in file_metadata.items(): + if file_stem not in best_individuals: + continue + + keypoints = meta["keypoints"] + individuals = meta["individuals"] + gt_indices = meta["gt_indices"] + slp_file = meta["slp_file"] + + keypoints_by_file[file_stem] = keypoints + + # Test all ordered GT pairs (permutations) on ALL individuals + for from_idx, to_idx in itertools.permutations(gt_indices, 2): + for ind_idx in range(meta["n_ind"]): + has_ind = meta["has_individuals"] + individual = individuals[ind_idx] if has_ind else "single" + tasks.append( + ( + str(slp_file), + file_stem, + individual, + ind_idx, + keypoints[from_idx], + keypoints[to_idx], + from_idx, + to_idx, + meta["n_kp"], + meta["n_ind"], + meta["n_frames"], + ) + ) + + return tasks, keypoints_by_file + + +def save_to_h5( # noqa: C901 + all_results, + keypoints_by_file, + output_path, + individual_accuracy=None, + all_rxm=None, +): + """Save results to HDF5.""" + n = len(all_results) + dt_str = h5py.special_dtype(vlen=str) + + str_fields = [ + "file", + "individual", + "from_keypoint", + "to_keypoint", + "error_msg", + "error_type", + "pr_failure_step", + "pr_failure_reason", + "pr_outcome", + "pr_warning_message", + "suggested_type", + ] + int_fields = [ + "individual_idx", + "from_index", + "to_index", + "n_keypoints", + "n_individuals", + "n_frames", + "num_selected_frames", + "anterior_sign", + "num_clusters", + "pr_scenario", + "pr_input_pair_rank", + "suggested_from_idx", + "suggested_to_idx", + "n_valid_nodes", + "n_step1_candidates", + "n_step2_pairs", + "n_step3_distal", + "n_step3_proximal", + ] + float_fields = [ + "resultant_length", + "vote_margin", + "rxm", + "circ_mean_dir", + "pr_input_pair_separation_abs", + "pr_max_separation_distal", + "pr_max_separation", + "pr_lateral_offset_min", + "pr_lateral_offset_max", + "pr_midpoint_pc1", + "pr_pc1_min", + "pr_pc1_max", + "pr_midline_dist_max", + ] + bool_fields = [ + "validation_success", + "error", + "pr_success", + "pr_input_pair_in_candidates", + "pr_input_pair_opposite_sides", + "pr_input_pair_is_distal", + "pr_input_pair_order_matches_inference", + ] + + with h5py.File(output_path, "w") as f: + f.attrs["created"] = datetime.now().isoformat() + f.attrs["n_files"] = len(keypoints_by_file) + f.attrs["n_records"] = n + + kp_grp = f.create_group("keypoints") + for fname, kps in keypoints_by_file.items(): + kp_grp.create_dataset(fname, data=np.array(kps, dtype="S")) + + for field in str_fields: + f.create_dataset( + field, + data=np.array( + [r.get(field, "") for r in all_results], dtype=dt_str + ), + ) + + for field in int_fields: + f.create_dataset( + field, + data=np.array( + [r.get(field, -1) for r in all_results], dtype=np.int32 + ), + ) + + for field in float_fields: + f.create_dataset( + field, + data=np.array( + [r.get(field, np.nan) for r in all_results], + dtype=np.float64, + ), + ) + + for field in bool_fields: + f.create_dataset( + field, + data=np.array( + [r.get(field, False) for r in all_results], dtype=bool + ), + ) + + # Save avg_skeleton, PC1, and vel_projs arrays + # Group by (file, individual) - these are per-individual, not per-pair + skel_grp = f.create_group("skeletons") + pc1_grp = f.create_group("pc1_vectors") + velprojs_grp = f.create_group("vel_projs_pc1") + + # First pass: collect best data for each (file, individual) pair + # Need to search all records to find valid data (not just first) + # {(file, individual): {"skeleton": ..., "pc1": ..., "vel_projs": ...}} + pair_data = {} + for r in all_results: + file_stem = r.get("file", "") + individual = r.get("individual", "") + if not file_stem or not individual: + continue + + key = (file_stem, individual) + if key not in pair_data: + pair_data[key] = { + "skeleton": None, + "pc1": None, + "vel_projs": None, + } + + # Update with valid data if current is None + pd = pair_data[key] + if pd["skeleton"] is None and r.get("avg_skeleton") is not None: + pd["skeleton"] = r.get("avg_skeleton") + if pd["pc1"] is None and r.get("PC1") is not None: + pd["pc1"] = r.get("PC1") + if pd["vel_projs"] is None and r.get("vel_projs_pc1") is not None: + pd["vel_projs"] = r.get("vel_projs_pc1") + + # Second pass: save collected data to H5 + for (file_stem, individual), data in pair_data.items(): + # Create file group if needed + if file_stem not in skel_grp: + skel_grp.create_group(file_stem) + pc1_grp.create_group(file_stem) + velprojs_grp.create_group(file_stem) + + if data["skeleton"] is not None: + arr = np.array(data["skeleton"], dtype=np.float64) + skel_grp[file_stem].create_dataset(individual, data=arr) + if data["pc1"] is not None: + arr = np.array(data["pc1"], dtype=np.float64) + pc1_grp[file_stem].create_dataset(individual, data=arr) + if data["vel_projs"] is not None: + arr = np.array(data["vel_projs"], dtype=np.float64) + velprojs_grp[file_stem].create_dataset(individual, data=arr) + + # Save individual accuracy (raw PC1 direction diag for Fig 1) + if individual_accuracy: + acc_grp = f.create_group("individual_accuracy") + for file_stem, accuracies in individual_accuracy.items(): + file_grp = acc_grp.create_group(file_stem) + for individual, acc in accuracies.items(): + ind_grp = file_grp.create_group(individual) + ind_grp.create_dataset("correct", data=acc["correct"]) + ind_grp.create_dataset("total", data=acc["total"]) + ind_grp.create_dataset("accuracy", data=acc["accuracy"]) + + # Save R×M values for all individuals + if all_rxm: + rxm_grp = f.create_group("individual_rxm") + for file_stem, individuals in all_rxm.items(): + file_grp = rxm_grp.create_group(file_stem) + for individual, rxm_val in individuals.items(): + file_grp.create_dataset(individual, data=rxm_val) + + +# Analysis functions + + +def load_h5_data(h5_path): # noqa: C901 + """Load relevant fields from H5 file.""" + + def decode_str(x): + return x.decode() if isinstance(x, bytes) else x + + data = {} + with h5py.File(h5_path, "r") as f: + # String fields + data["file"] = [decode_str(x) for x in f["file"][:]] + data["individual"] = [decode_str(x) for x in f["individual"][:]] + data["suggested_type"] = [ + decode_str(x) for x in f["suggested_type"][:] + ] + + # Index fields + data["from_index"] = np.array(f["from_index"]) + data["to_index"] = np.array(f["to_index"]) + data["suggested_from_idx"] = np.array(f["suggested_from_idx"]) + data["suggested_to_idx"] = np.array(f["suggested_to_idx"]) + + # Validation results + data["validation_success"] = np.array(f["validation_success"]) + data["anterior_sign"] = np.array(f["anterior_sign"]) + data["pr_input_pair_order_matches_inference"] = np.array( + f["pr_input_pair_order_matches_inference"] + ) + data["pr_success"] = np.array(f["pr_success"]) + data["pr_input_pair_in_candidates"] = np.array( + f["pr_input_pair_in_candidates"] + ) + data["rxm"] = np.array(f["rxm"]) + data["vote_margin"] = np.array(f["vote_margin"]) + data["resultant_length"] = np.array(f["resultant_length"]) + n = len(data["file"]) + if "circ_mean_dir" in f: + data["circ_mean_dir"] = np.array(f["circ_mean_dir"]) + else: + data["circ_mean_dir"] = np.full(n, np.nan) + if "num_selected_frames" in f: + data["num_selected_frames"] = np.array(f["num_selected_frames"]) + else: + data["num_selected_frames"] = np.zeros(n, dtype=np.int32) + if "n_frames" in f: + data["n_frames"] = np.array(f["n_frames"]) + else: + data["n_frames"] = np.zeros(n, dtype=np.int32) + + # Cascade counts (backward-compatible) + for cascade_field in [ + "n_valid_nodes", + "n_step1_candidates", + "n_step2_pairs", + "n_step3_distal", + "n_step3_proximal", + ]: + if cascade_field in f: + data[cascade_field] = np.array(f[cascade_field]) + else: + data[cascade_field] = np.zeros(n, dtype=np.int32) + + # Load skeleton and PC1 data (if present) + data["skeletons"] = {} # {file_stem: {individual: np.array}} + data["pc1_vectors"] = {} # {file_stem: {individual: np.array}} + + if "skeletons" in f: + for file_stem in f["skeletons"]: + data["skeletons"][file_stem] = {} + for individual in f["skeletons"][file_stem]: + data["skeletons"][file_stem][individual] = np.array( + f["skeletons"][file_stem][individual] + ) + + if "pc1_vectors" in f: + for file_stem in f["pc1_vectors"]: + data["pc1_vectors"][file_stem] = {} + for individual in f["pc1_vectors"][file_stem]: + data["pc1_vectors"][file_stem][individual] = np.array( + f["pc1_vectors"][file_stem][individual] + ) + + # Load velocity projections for histogram + data["vel_projs_pc1"] = {} + if "vel_projs_pc1" in f: + for file_stem in f["vel_projs_pc1"]: + data["vel_projs_pc1"][file_stem] = {} + for individual in f["vel_projs_pc1"][file_stem]: + data["vel_projs_pc1"][file_stem][individual] = np.array( + f["vel_projs_pc1"][file_stem][individual] + ) + + # Load keypoint names + data["keypoints"] = {} + if "keypoints" in f: + for file_stem in f["keypoints"]: + kp_data = f["keypoints"][file_stem][:] + data["keypoints"][file_stem] = [ + x.decode() if isinstance(x, bytes) else x for x in kp_data + ] + + # Load individual accuracy data (raw PC1 direction diagnostic) + data["individual_accuracy"] = {} + if "individual_accuracy" in f: + for file_stem in f["individual_accuracy"]: + data["individual_accuracy"][file_stem] = {} + for individual in f["individual_accuracy"][file_stem]: + ind_grp = f["individual_accuracy"][file_stem][individual] + data["individual_accuracy"][file_stem][individual] = { + "correct": int(ind_grp["correct"][()]), + "total": int(ind_grp["total"][()]), + "accuracy": float(ind_grp["accuracy"][()]), + } + + # Load R×M values for all individuals + data["individual_rxm"] = {} + if "individual_rxm" in f: + for file_stem in f["individual_rxm"]: + data["individual_rxm"][file_stem] = {} + for individual in f["individual_rxm"][file_stem]: + data["individual_rxm"][file_stem][individual] = float( + f["individual_rxm"][file_stem][individual][()] + ) + + return data + + +def find_best_individual_per_file(data): + """Find the individual with highest mean R×M for each file.""" + file_individual_rxm = defaultdict(lambda: defaultdict(list)) + + for i in range(len(data["file"])): + file_stem = data["file"][i] + individual = data["individual"][i] + rxm = data["rxm"][i] + if not np.isnan(rxm): + file_individual_rxm[file_stem][individual].append(rxm) + + best_individual = {} + for file_stem, individuals in file_individual_rxm.items(): + best_rxm = -1 + best_ind = None + for ind, rxm_list in individuals.items(): + mean_rxm = np.mean(rxm_list) + if mean_rxm > best_rxm: + best_rxm = mean_rxm + best_ind = ind + best_individual[file_stem] = best_ind + + return best_individual + + +def find_step1_surviving_nodes(data, best_individual): + """Find nodes surviving the lateral alignment filter (Step 1) per file. + + A node survives Step 1 if it appears in any pair where + pr_input_pair_in_candidates=True for the file's best individual. + """ + file_surviving_nodes = defaultdict(set) + + for i in range(len(data["file"])): + file_stem = data["file"][i] + individual = data["individual"][i] + + # Only consider best individual + if individual != best_individual.get(file_stem): + continue + + # If this pair survived Step 1, both nodes are candidates + if data["pr_input_pair_in_candidates"][i]: + from_idx = int(data["from_index"][i]) + to_idx = int(data["to_index"][i]) + file_surviving_nodes[file_stem].add(from_idx) + file_surviving_nodes[file_stem].add(to_idx) + + return file_surviving_nodes + + +def compute_gt_coverage(file_surviving_nodes): + """Compute lateral filter coverage: fraction of GT nodes surviving Step 1. + + Returns dict: {file_stem: {surviving_in_gt, gt_total, coverage_pct, ...}} + """ + coverage = {} + + for file_stem, gt_ranks in GROUND_TRUTH.items(): + gt_nodes = set(gt_ranks.keys()) + n_gt_total = len(gt_nodes) + + surviving = file_surviving_nodes.get(file_stem, set()) + surviving_in_gt = surviving & gt_nodes + n_surviving_in_gt = len(surviving_in_gt) + + if n_gt_total > 0: + coverage_pct = 100 * n_surviving_in_gt / n_gt_total + else: + coverage_pct = 0 + + coverage[file_stem] = { + "surviving_in_gt": n_surviving_in_gt, + "gt_total": n_gt_total, + "coverage_pct": coverage_pct, + "surviving_nodes": sorted(surviving), + "gt_nodes": sorted(gt_nodes), + "surviving_in_gt_nodes": sorted(surviving_in_gt), + "gt_not_surviving": sorted(gt_nodes - surviving), + } + + return coverage + + +def analyze_suggested_pairs(data, best_individual): # noqa: C901 + """Report which node pair the 3-step filter cascade auto-selected. + + For each file's best individual, retrieves the suggested pair + (posterior→anterior, distal vs proximal) and notes whether the + selected nodes happen to fall within the hand-curated GT subset. + GT membership is incidental context — the pipeline often selects + nodes outside the GT subset, which is expected. + """ + # Get suggested pair for each file (use first record for best individual) + file_suggested = {} + + for i in range(len(data["file"])): + file_stem = data["file"][i] + individual = data["individual"][i] + + if individual != best_individual.get(file_stem): + continue + + # Only record once per file (suggested pair same for all input pairs) + if file_stem in file_suggested: + continue + + suggested_from = int(data["suggested_from_idx"][i]) + suggested_to = int(data["suggested_to_idx"][i]) + suggested_type = data["suggested_type"][i] + + file_suggested[file_stem] = { + "from_idx": suggested_from, + "to_idx": suggested_to, + "type": suggested_type, + } + + # Analyze each file + results = {} + for file_stem, gt_ranks in GROUND_TRUTH.items(): + gt_nodes = set(gt_ranks.keys()) + default = {"from_idx": -1, "to_idx": -1, "type": ""} + suggested = file_suggested.get(file_stem, default) + + from_idx = suggested["from_idx"] + to_idx = suggested["to_idx"] + stype = suggested["type"] + + # Check if suggested pair is valid + if from_idx < 0 or to_idx < 0: + results[file_stem] = { + "suggested_from": from_idx, + "suggested_to": to_idx, + "suggested_type": stype, + "from_in_gt": False, + "to_in_gt": False, + "both_in_gt": False, + "order_correct": None, + "status": "NO SUGGESTION", + } + continue + + from_in_gt = from_idx in gt_nodes + to_in_gt = to_idx in gt_nodes + both_in_gt = from_in_gt and to_in_gt + + # Check ordering if both are in GT + order_correct = None + if both_in_gt: + from_rank = gt_ranks[from_idx] + to_rank = gt_ranks[to_idx] + # Correct: from=posterior (lower rank), to=anterior (higher) + order_correct = from_rank < to_rank + + # Determine status (GT membership is incidental — the pipeline + # often selects nodes outside the hand-curated GT subset) + if both_in_gt and order_correct: + status = "Both in GT, order correct" + elif both_in_gt and not order_correct: + status = "Both in GT, order reversed" + elif from_in_gt or to_in_gt: + status = "One node in GT" + else: + status = "Neither node in GT" + + results[file_stem] = { + "suggested_from": from_idx, + "suggested_to": to_idx, + "suggested_type": stype, + "from_in_gt": from_in_gt, + "to_in_gt": to_in_gt, + "both_in_gt": both_in_gt, + "order_correct": order_correct, + "from_rank": gt_ranks.get(from_idx, "NaN"), + "to_rank": gt_ranks.get(to_idx, "NaN"), + "status": status, + } + + return results + + +def get_ground_truth_order(file_stem, from_idx, to_idx): + """Determine hand-curated GT AP ordering for a node pair. + + Convention: (from, to) = (posterior, anterior). This matches + compute_polarization's body_axis_keypoints convention where the vector + points FROM posterior TO anterior. + + Returns: + 1 if from_idx is posterior to to_idx (correctly ordered) + -1 if from_idx is anterior to to_idx (reversed) + None if either index is not in the hand-curated ground truth + + """ + if file_stem not in GROUND_TRUTH: + return None + + gt = GROUND_TRUTH[file_stem] + if from_idx not in gt or to_idx not in gt: + return None + + from_rank = gt[from_idx] + to_rank = gt[to_idx] + + # Higher rank = more anterior, lower rank = more posterior + # Correct: from=posterior (lower rank), to=anterior (higher rank) + if from_rank < to_rank: + return 1 # from is posterior, to is anterior (correct order) + elif from_rank > to_rank: + return -1 # from is anterior, to is posterior (incorrect order) + else: + return 0 # same rank (shouldn't happen with non-NaN unique ranks) + + +def analyze_results(h5_path): # noqa: C901 + """Report filter cascade, suggested pairs, and Figure 2 data. + + Read the H5 file produced by the parallel validate_ap runs. For each + file's best individual (highest mean R×M), log the 3-step filter + cascade progression (with GT coverage folded into Step 1) and the + suggested pair analysis. Return cascade stats, GT coverage, and + suggested pair data for Figure 2 rendering. + """ + print("\n" + "─" * 70) + print("REPORTING: Filter Cascade & Suggested Pairs") + print("─" * 70) + + data = load_h5_data(h5_path) + n_records = len(data["file"]) + + # Find best individual per file (already reported in Pass 1, + # recomputed here from H5 using mean R×M across all records) + best_individual = find_best_individual_per_file(data) + + # Find Step 1 surviving nodes and compute GT coverage + file_surviving_nodes = find_step1_surviving_nodes(data, best_individual) + gt_coverage = compute_gt_coverage(file_surviving_nodes) + + # Extract cascade progression stats for each file's best individual + cascade_stats = {} + seen_files = set() + for i in range(n_records): + file_stem = data["file"][i] + individual = data["individual"][i] + + if individual != best_individual.get(file_stem): + continue + if file_stem in seen_files: + continue + seen_files.add(file_stem) + + n_valid = int(data["n_valid_nodes"][i]) + n_s1 = int(data["n_step1_candidates"][i]) + n_s2 = int(data["n_step2_pairs"][i]) + n_s3d = int(data["n_step3_distal"][i]) + n_s3p = int(data["n_step3_proximal"][i]) + + n_candidate_pairs = n_s1 * (n_s1 - 1) // 2 if n_s1 >= 2 else 0 + + cascade_stats[file_stem] = { + "n_valid_nodes": n_valid, + "n_step1_candidates": n_s1, + "n_candidate_pairs": n_candidate_pairs, + "n_step2_pairs": n_s2, + "n_step3_distal": n_s3d, + "n_step3_proximal": n_s3p, + } + + # Log cascade progression with GT coverage folded into Step 1 + print("\n3-STEP FILTER CASCADE (best individual per file):") + for file_stem in sorted(cascade_stats.keys()): + cs = cascade_stats[file_stem] + cov = gt_coverage.get(file_stem, {}) + label = FILE_LABELS.get(file_stem, file_stem[:15]) + n_valid = cs["n_valid_nodes"] + n_s1 = cs["n_step1_candidates"] + n_cp = cs["n_candidate_pairs"] + n_s2 = cs["n_step2_pairs"] + n_s3d = cs["n_step3_distal"] + + gt_surv = cov.get("surviving_in_gt", 0) + gt_tot = cov.get("gt_total", 0) + gt_str = f" (GT: {gt_surv}/{gt_tot} nodes)" if gt_tot > 0 else "" + + print(f"\n {label}:") + print(f" Step 1 (lateral): {n_s1}/{n_valid} nodes{gt_str}") + print(f" Step 2 (opposite): {n_s2}/{n_cp} candidate pairs") + print(f" Step 3 (distal): {n_s3d}/{n_s2} pairs") + + # Analyze suggested pairs + suggested_analysis = analyze_suggested_pairs(data, best_individual) + + print("\n" + "─" * 70) + print("SUGGESTED AP PAIR: Auto-Selected Node Pair (3-Step Filter Cascade)") + print("─" * 70) + for file_stem in sorted(suggested_analysis.keys()): + r = suggested_analysis[file_stem] + label = FILE_LABELS.get(file_stem, file_stem[:15]) + print(f"\n{label} ({file_stem}):") + gt_nodes = sorted(GROUND_TRUTH[file_stem].keys()) + print(f" Hand-curated GT node indices: {gt_nodes}") + sf, st = r["suggested_from"], r["suggested_to"] + print( + f" Suggested pair: [{sf} → {st}] " + f"(posterior → anterior, type: {r['suggested_type']})" + ) + + if r["suggested_from"] >= 0: + print( + f" Posterior node {sf}: " + f"in GT = {r['from_in_gt']}, " + f"GT rank = {r['from_rank']}" + ) + print( + f" Anterior node {st}: " + f"in GT = {r['to_in_gt']}, " + f"GT rank = {r['to_rank']}" + ) + + if r["both_in_gt"]: + if r["order_correct"]: + order_str = ( + "CORRECT (from_node is posterior, " + "to_node is anterior per GT)" + ) + else: + order_str = ( + "REVERSED (from_node is anterior, " + "to_node is posterior per GT)" + ) + print(f" Ordering: {order_str}") + + print(f" Status: {r['status']}") + print() + + return { + "best_individual": best_individual, + "gt_coverage": gt_coverage, + "suggested_analysis": suggested_analysis, + "cascade_stats": cascade_stats, + "skeletons": data.get("skeletons", {}), + "pc1_vectors": data.get("pc1_vectors", {}), + "vel_projs_pc1": data.get("vel_projs_pc1", {}), + "keypoints": data.get("keypoints", {}), + # For individual plot generation + "file": data["file"], + "individual": data["individual"], + "rxm": data["rxm"], + "vote_margin": data["vote_margin"], + "resultant_length": data["resultant_length"], + "anterior_sign": data["anterior_sign"], + "circ_mean_dir": data["circ_mean_dir"], + "num_selected_frames": data["num_selected_frames"], + "n_frames": data["n_frames"], + # Per-individual diagnostic data: GT concordance and R×M values + "individual_accuracy": data.get("individual_accuracy", {}), + "individual_rxm": data.get("individual_rxm", {}), + } + + +def plot_validation_results(analysis_data, output_path): # noqa: C901 + """Create a tiled layout figure showing validation results. + + Parameters + ---------- + analysis_data : dict + Analysis results from analyze_results() + output_path : Path + Output path for saving figure + + """ + _gt_coverage = analysis_data["gt_coverage"] # noqa: F841 + suggested_analysis = analysis_data["suggested_analysis"] + cascade_stats = analysis_data["cascade_stats"] + + # Short labels for files (filename.slp format) + file_labels = { + "free-moving-2flies-ID-13nodes-1024x1024x1-30_3pxmm": "2Flies.slp", + "free-moving-2mice-noID-5nodes-1280x1024x1-1_9pxmm": "2Mice.slp", + "free-moving-4gerbils-ID-14nodes-1024x1280x3-2pxmm": "4Gerbils.slp", + "free-moving-5mice-noID-11nodes-1280x1024x1-1_97pxmm": "5Mice.slp", + "freemoving-2bees-noID-21nodes-1535x2048x1-14pxmm": "2Bees.slp", + } + + files = sorted(GROUND_TRUTH.keys()) + n_files = len(files) + + bg_color = "white" + text_color = "black" + axis_color = "black" + midline_color = "black" + ap_arrow_color = "black" + gt_node_color = "black" + label_bg_alpha = 0.8 + + fig = plt.figure(figsize=(14, 10), facecolor=bg_color) + gs = fig.add_gridspec( + 2, 2, height_ratios=[1, 1], width_ratios=[1, 1], hspace=0.3, wspace=0.3 + ) + fig.suptitle( + "AP Validation: Average Skeleton (Best Individual by R×M)", + fontsize=14, + fontweight="bold", + color=text_color, + ) + + # Color palette + colors = [ + (0.12, 0.47, 0.71, 1.0), # Blue (2 Flies) + (0.17, 0.63, 0.17, 1.0), # Green (2 Mice) + (0.80, 0.70, 0.10, 1.0), # Gold (4 Gerbils) + (1.00, 0.50, 0.05, 1.0), # Orange (5 Mice) + (0.58, 0.40, 0.74, 1.0), # Purple (2 Bees) + ] + + ax1 = fig.add_subplot(gs[0, :], facecolor=bg_color) + ax1.set_xticks([]) + ax1.set_yticks([]) + for spine in ax1.spines.values(): + spine.set_visible(False) + + # Get skeletons and PC1 from stored H5 data + best_individual = analysis_data["best_individual"] + stored_skeletons = analysis_data.get("skeletons", {}) + stored_pc1 = analysis_data.get("pc1_vectors", {}) + skeleton_data = [] + + for f in files: + best_ind = best_individual.get(f) + if best_ind is None: + continue + + # Try to get stored skeleton from H5 + if f in stored_skeletons and best_ind in stored_skeletons[f]: + avg_skel = stored_skeletons[f][best_ind] # (n_keypoints, 2) + x = avg_skel[:, 0] + y = avg_skel[:, 1] + pc1 = stored_pc1.get(f, {}).get(best_ind) # (2,) or None + else: + # Fallback: load from .slp file + from movement.io import load_poses + + slp_path = SLP_DIR / f"{f}.slp" + if not slp_path.exists(): + continue + + ds = load_poses.from_sleap_file(slp_path) + if "individuals" in ds.position.dims: + pos = ds.position.sel(individuals=best_ind) + else: + pos = ds.position + + x = np.nanmean(pos.sel(space="x").values, axis=0) + y = np.nanmean(pos.sel(space="y").values, axis=0) + pc1 = None # Will compute below if needed + + # Get suggested pair for this file + default_sug = { + "suggested_from": -1, + "suggested_to": -1, + "status": "NO SUGGESTION", + } + suggested = suggested_analysis.get(f, default_sug) + + skeleton_data.append( + { + "file": f, + "label": file_labels.get(f, f[:10]), + "best_ind": best_ind, # Animal identity + "x": x.copy(), + "y": y.copy(), + "pc1": pc1, # Stored PC1 vector (or None) + "gt": GROUND_TRUTH.get(f, {}), + "suggested_from": suggested.get("suggested_from", -1), + "suggested_to": suggested.get("suggested_to", -1), + "from_in_gt": suggested.get("from_in_gt", False), + "to_in_gt": suggested.get("to_in_gt", False), + "status": suggested.get("status", "NO SUGGESTION"), + } + ) + + # Process and plot skeletons side by side + n_skeletons = len(skeleton_data) + midline_plotted = False # For legend + + if n_skeletons > 0: + x_offset = 0 + spacing = 1.8 # Space between skeletons + + for idx, skel in enumerate(skeleton_data): + x, y = skel["x"], skel["y"] + gt = skel["gt"] + stored_pc1 = skel.get("pc1") + + # Get valid (non-NaN) node indices + valid_mask = ~np.isnan(x) & ~np.isnan(y) + valid_indices = np.where(valid_mask)[0] + + if len(valid_indices) < 2: + continue + + # avg_skeleton from collective.py is already centered, use directly + x_centered = x.copy() + y_centered = y.copy() + + # Use stored PC1 from collective.py if available, else compute + if stored_pc1 is not None: + pc1 = np.array(stored_pc1) + else: + # Fallback: compute PC1 via PCA on valid nodes + xv = x_centered[valid_mask] + yv = y_centered[valid_mask] + valid_coords = np.column_stack([xv, yv]) + cov_matrix = np.cov(valid_coords.T) + eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix) + pc1 = eigenvectors[:, np.argmax(eigenvalues)] + + # Ensure PC1 points toward the most anterior GT node + # (highest rank = most anterior in GT convention) + anterior_node = None + max_rank = -1 + for node_idx, rank in gt.items(): + if rank > max_rank: + max_rank = rank + anterior_node = node_idx + + ant_ok = anterior_node is not None and not np.isnan( + x_centered[anterior_node] + ) + if ant_ok: + ax, ay = x_centered[anterior_node], y_centered[anterior_node] + ant_vec = np.array([ax, ay]) + if np.dot(pc1, ant_vec) < 0: + pc1 = -pc1 # Flip to point toward anterior + + # Normalize scale (max of x/y range for square proportions) + x_range = np.nanmax(x_centered) - np.nanmin(x_centered) + y_range = np.nanmax(y_centered) - np.nanmin(y_centered) + max_range = max(x_range, y_range) + if max_range > 0: + scale = 1.5 / max_range # Scale up for larger display + x_centered *= scale + y_centered *= scale + + # Offset horizontally for side-by-side placement + x_plot = x_centered + x_offset + y_plot = y_centered + + file_color = colors[idx] + + # Compute axis extent based on projections onto PC1 + xv, yv = x_centered[valid_mask], y_centered[valid_mask] + proj_pc1 = xv * pc1[0] + yv * pc1[1] + min_proj, max_proj = np.min(proj_pc1), np.max(proj_pc1) + axis_extent = (max_proj - min_proj) * 0.7 + + # Draw PC1 axis at actual angle (matching color) + ae = axis_extent + ax1.plot( + [x_offset - ae * pc1[0], x_offset + ae * pc1[0]], + [0 - ae * pc1[1], 0 + ae * pc1[1]], + "-", + color=(*file_color[:3], 0.5), + linewidth=1.5, + zorder=1, + ) + + # Draw AP midline (dashed black, perpendicular to PC1) + pc1_perp = np.array([-pc1[1], pc1[0]]) # Perpendicular to PC1 + midline_extent = 0.2 + mid_y = (min_proj + max_proj) / 2 # Midpoint along PC1 + mid_x_pos = x_offset + mid_y * pc1[0] + mid_y_pos = mid_y * pc1[1] + + me = midline_extent + mx0 = mid_x_pos - me * pc1_perp[0] + mx1 = mid_x_pos + me * pc1_perp[0] + my0 = mid_y_pos - me * pc1_perp[1] + my1 = mid_y_pos + me * pc1_perp[1] + if not midline_plotted: + ax1.plot( + [mx0, mx1], + [my0, my1], + "--", + color=midline_color, + linewidth=1.5, + alpha=0.7, + zorder=2, + label="AP midline", + ) + midline_plotted = True + else: + ax1.plot( + [mx0, mx1], + [my0, my1], + "--", + color=midline_color, + linewidth=1.5, + alpha=0.7, + zorder=2, + ) + + # Draw AP bidirectional arrow next to midline + ap_off = midline_extent * 1.1 # Offset from midline center + ap_half = axis_extent * 0.125 # Half-length of arrow shaft + + # Arrow center (offset perpendicular to PC1 from midline center) + acx = mid_x_pos + ap_off * pc1_perp[0] + acy = mid_y_pos + ap_off * pc1_perp[1] + + # Arrow endpoints along PC1 direction + arrow_ant_x = acx + ap_half * pc1[0] # Anterior end (+PC1) + arrow_ant_y = acy + ap_half * pc1[1] + arrow_post_x = acx - ap_half * pc1[0] # Posterior end (-PC1) + arrow_post_y = acy - ap_half * pc1[1] + + # Draw arrow shaft + ax1.plot( + [arrow_post_x, arrow_ant_x], + [arrow_post_y, arrow_ant_y], + "-", + color=ap_arrow_color, + linewidth=1.5, + zorder=2, + ) + + # Draw triangular arrowheads at both ends + head_len = 0.06 # Arrowhead length + head_wid = 0.03 # Arrowhead half-width + + # Anterior arrowhead (pointing toward +PC1) + ant_head_base_x = arrow_ant_x - head_len * pc1[0] + ant_head_base_y = arrow_ant_y - head_len * pc1[1] + hw = head_wid + ahbx, ahby = ant_head_base_x, ant_head_base_y + ant_head_tri = np.array( + [ + [arrow_ant_x, arrow_ant_y], + [ahbx + hw * pc1_perp[0], ahby + hw * pc1_perp[1]], + [ahbx - hw * pc1_perp[0], ahby - hw * pc1_perp[1]], + ] + ) + ax1.fill( + ant_head_tri[:, 0], + ant_head_tri[:, 1], + color=ap_arrow_color, + zorder=3, + ) + + # Posterior arrowhead (pointing toward -PC1) + phbx = arrow_post_x + head_len * pc1[0] + phby = arrow_post_y + head_len * pc1[1] + post_head_tri = np.array( + [ + [arrow_post_x, arrow_post_y], + [phbx + hw * pc1_perp[0], phby + hw * pc1_perp[1]], + [phbx - hw * pc1_perp[0], phby - hw * pc1_perp[1]], + ] + ) + ax1.fill( + post_head_tri[:, 0], + post_head_tri[:, 1], + color=ap_arrow_color, + zorder=3, + ) + + # Add "A" and "P" labels beyond arrow ends (along PC1 direction) + ap_lbl = 0.08 + ax1.text( + arrow_ant_x + ap_lbl * pc1[0], + arrow_ant_y + ap_lbl * pc1[1], + "A", + fontsize=10, + fontweight="bold", + color=ap_arrow_color, + ha="center", + va="center", + zorder=10, + ) + ax1.text( + arrow_post_x - ap_lbl * pc1[0], + arrow_post_y - ap_lbl * pc1[1], + "P", + fontsize=10, + fontweight="bold", + color=ap_arrow_color, + ha="center", + va="center", + zorder=10, + ) + + # Add +PC1 and -PC1 labels at ends of axis + lo = 0.05 + ax1.text( + x_offset + ae * pc1[0] + lo * pc1_perp[0], + ae * pc1[1] + lo * pc1_perp[1], + "+PC1", + fontsize=10, + fontweight="bold", + color=file_color, + ha="center", + va="center", + alpha=0.8, + zorder=10, + ) + ax1.text( + x_offset - ae * pc1[0] + lo * pc1_perp[0], + -ae * pc1[1] + lo * pc1_perp[1], + "-PC1", + fontsize=10, + fontweight="bold", + color=file_color, + ha="center", + va="center", + alpha=0.8, + zorder=10, + ) + + # Get suggested pair info first (needed for node coloring) + from_idx = skel.get("suggested_from", -1) + to_idx = skel.get("suggested_to", -1) + status = skel.get("status", "NO SUGGESTION") + from_in_gt = skel.get("from_in_gt", False) + to_in_gt = skel.get("to_in_gt", False) + + # Status colors (same as status legend) + status_colors_map = { + "Both in GT, order correct": "#2ecc71", + "Both in GT, order reversed": "#f39c12", + "One node in GT": "#e74c3c", + "Neither node in GT": "#95a5a6", + "NO SUGGESTION": "#bdc3c7", + } + line_color = status_colors_map.get(status, "#bdc3c7") + + # Plot nodes + n_nodes = len(x_plot) + + for node_idx in range(n_nodes): + if np.isnan(x_plot[node_idx]) or np.isnan(y_plot[node_idx]): + continue + + # Check if node is NaN (not in GT) + is_nan = node_idx not in gt + # Check if node is part of suggested pair + is_suggested = node_idx in (from_idx, to_idx) + + # Set alpha and size based on node type + size = 30 if is_nan else 50 + + # Dot color: GT=theme, suggested non-GT=status, else file_color + if not is_nan: + dot_color = gt_node_color # GT nodes use theme color + alpha = 1.0 + elif is_suggested: + # Suggested non-GT nodes use status color + dot_color = line_color + alpha = 1.0 + else: + # Purple non-GT, non-suggested nodes + dot_color = file_color + alpha = 0.30 + + ax1.scatter( + x_plot[node_idx], + y_plot[node_idx], + c=[dot_color], + s=size, + alpha=alpha, + edgecolors="white" if not is_nan else "none", + linewidths=0.5, + zorder=5, + ) + + # Add node index label for ranked nodes only + if not is_nan: + ax1.annotate( + str(node_idx), + (x_plot[node_idx], y_plot[node_idx]), + xytext=(4, 4), + textcoords="offset points", + fontsize=13, + color=gt_node_color, + fontweight="bold", + zorder=10, + ) + + # Draw arrow connecting suggested pair (color based on status) + # Arrow points from from_node (posterior) to to_node (anterior) + if from_idx >= 0 and to_idx >= 0: + from_ok = not np.isnan(x_plot[from_idx]) and not np.isnan( + y_plot[from_idx] + ) + to_ok = not np.isnan(x_plot[to_idx]) and not np.isnan( + y_plot[to_idx] + ) + if from_ok and to_ok: + # Draw arrow from posterior to anterior + ax1.annotate( + "", + xy=(x_plot[to_idx], y_plot[to_idx]), + xytext=(x_plot[from_idx], y_plot[from_idx]), + arrowprops=dict( + arrowstyle="-|>", + color=line_color, + lw=1.5, + mutation_scale=10, + alpha=0.85, + ), + zorder=4, + ) + + # Redraw suggested nodes with higher zorder to be on top + # GT nodes use theme color, non-GT nodes use line_color + if not np.isnan(x_plot[from_idx]): + from_color = ( + gt_node_color if from_idx in gt else line_color + ) + ax1.scatter( + x_plot[from_idx], + y_plot[from_idx], + c=[from_color], + s=50, + alpha=1.0, + edgecolors="white", + linewidths=0.5, + zorder=6, + ) + ax1.annotate( + str(from_idx), + (x_plot[from_idx], y_plot[from_idx]), + xytext=(4, 4), + textcoords="offset points", + fontsize=13, + color=from_color, + fontweight="bold", + zorder=10, + ) + + if not np.isnan(x_plot[to_idx]): + to_color = ( + gt_node_color if to_idx in gt else line_color + ) + ax1.scatter( + x_plot[to_idx], + y_plot[to_idx], + c=[to_color], + s=50, + alpha=1.0, + edgecolors="white", + linewidths=0.5, + zorder=6, + ) + ax1.annotate( + str(to_idx), + (x_plot[to_idx], y_plot[to_idx]), + xytext=(4, 4), + textcoords="offset points", + fontsize=13, + color=to_color, + fontweight="bold", + zorder=10, + ) + + # Add label below skeleton + ax1.text( + x_offset, + -1.25, + skel["label"], + ha="center", + va="top", + fontsize=12, + fontweight="bold", + color=file_color, + clip_on=False, + ) + ax1.text( + x_offset, + -1.45, + skel["best_ind"], + ha="center", + va="top", + fontsize=12, + fontweight="bold", + color=file_color, + clip_on=False, + ) + + # Add suggested pair [ # , # ] with color-coded numbers + if from_idx >= 0 and to_idx >= 0: + # Determine colors for each number based on GT membership + from_num_color = gt_node_color if from_in_gt else line_color + to_num_color = gt_node_color if to_in_gt else line_color + + # Build text parts with spacing + pair_y = -1.65 + char_width = 0.08 # Approximate character width in data units + + # Calculate total width and starting position for centering + from_str = str(from_idx) + to_str = str(to_idx) + # Format: "[ " (2) + from + " , " (3) + to + " ]" (2) + total_chars = 2 + len(from_str) + 3 + len(to_str) + 2 + start_x = x_offset - (total_chars * char_width) / 2 + + # Draw each part with spaces + pos = start_x + ax1.text( + pos, + pair_y, + "[ ", + ha="left", + va="top", + fontsize=14, + fontweight="bold", + color="black", + clip_on=False, + ) + pos += char_width * 2 + ax1.text( + pos, + pair_y, + from_str, + ha="left", + va="top", + fontsize=14, + fontweight="bold", + color=from_num_color, + clip_on=False, + ) + pos += char_width * len(from_str) + ax1.text( + pos, + pair_y, + " , ", + ha="left", + va="top", + fontsize=14, + fontweight="bold", + color="black", + clip_on=False, + ) + pos += char_width * 3 + ax1.text( + pos, + pair_y, + to_str, + ha="left", + va="top", + fontsize=14, + fontweight="bold", + color=to_num_color, + clip_on=False, + ) + pos += char_width * len(to_str) + ax1.text( + pos, + pair_y, + " ]", + ha="left", + va="top", + fontsize=14, + fontweight="bold", + color="black", + clip_on=False, + ) + + x_offset += spacing + + # Set axis limits (wider since spanning full row) + x_total = x_offset - spacing + ax1.set_xlim(-0.8, x_total + 0.8) + ax1.set_ylim(-1.85, 1.2) # Extended to accommodate three-line labels + + # Force equal aspect ratio for square skeleton display + ax1.set_aspect("equal", adjustable="box") + + # Add legend for AP midline and GT Node (positioned lower right) + midline_handle = Line2D( + [0], + [0], + linestyle="--", + color=midline_color, + linewidth=1.5, + alpha=0.7, + ) + gt_node_handle = Line2D( + [0], [0], marker="o", color="black", linestyle="None", markersize=6 + ) + midline_legend = ax1.legend( + [midline_handle, gt_node_handle], + ["A/P midline", "GT Node"], + loc="lower right", + fontsize=9, + framealpha=label_bg_alpha, + bbox_to_anchor=(1.12, -0.05), + prop={"weight": "bold"}, + facecolor=bg_color, + labelcolor=text_color, + ) + ax1.add_artist(midline_legend) + + # Determine dominant status color for arrow text + status_colors_map = { + "Both in GT, order correct": "#2ecc71", + "Both in GT, order reversed": "#f39c12", + "One node in GT": "#e74c3c", + "Neither node in GT": "#95a5a6", + "NO SUGGESTION": "#bdc3c7", + } + status_counts = {} + for f in files: + status = suggested_analysis.get(f, {}).get("status", "NO SUGGESTION") + status_counts[status] = status_counts.get(status, 0) + 1 + dominant_status = ( + max(status_counts, key=status_counts.get) + if status_counts + else "NO SUGGESTION" + ) + dominant_color = status_colors_map.get(dominant_status, "#bdc3c7") + + # Add arrow description at top left (bold, dominant status color) + fig.text( + 0.01, + 0.99, + "vector points in the inferred P\u2192A direction for suggested pair", + fontsize=8, + fontweight="normal", + color=dominant_color, + ha="left", + va="top", + ) + fig.text( + 0.01, + 0.97, + "(max AP separation, distal preferred)", + fontsize=8, + fontweight="normal", + color=dominant_color, + ha="left", + va="top", + ) + + # Add formula text at top right corner of figure + line_spacing = 0.018 + y_start = 0.99 + x_right = 1.0 + fig.text( + x_right, + y_start, + "R = √(C² + S²)", + fontsize=8, + fontweight="normal", + color=text_color, + ha="right", + va="top", + ) + fig.text( + x_right, + y_start - line_spacing, + "M = |n₊ − n₋| / (n₊ + n₋)", + fontsize=8, + fontweight="normal", + color=text_color, + ha="right", + va="top", + ) + fig.text( + x_right, + y_start - 2 * line_spacing, + "C, S = mean cos θ, mean sin θ of centroid velocities", + fontsize=8, + fontweight="normal", + color=text_color, + ha="right", + va="top", + ) + fig.text( + x_right, + y_start - 3 * line_spacing, + "n₊, n₋ = # of pos./neg. velocity proj. onto PC1", + fontsize=8, + fontweight="normal", + color=text_color, + ha="right", + va="top", + ) + fig.text( + 0.01, + y_start - 3 * line_spacing, + "anterior = +PC1 if n₊ > n₋, else −PC1", + fontsize=8, + fontweight="normal", + color=text_color, + ha="left", + va="top", + ) + + # Build labels list for bar charts (use animal identity) + labels = [best_individual.get(f, f[:15]) for f in files] + + # Panel 3: Suggested pair analysis + ax3 = fig.add_subplot(gs[1, 0], facecolor=bg_color) + + # Status colors for legend (kept for tile 1) + status_colors = { + "Both in GT, order correct": "#2ecc71", + "Both in GT, order reversed": "#f39c12", + "One node in GT": "#e74c3c", + "Neither node in GT": "#95a5a6", + "NO SUGGESTION": "#bdc3c7", + } + + # Create custom legend handler for arrow with dots at each end + class ArrowHandler(HandlerBase): + def __init__(self, color, left_dot_color=None, right_dot_color=None): + self.color = color + self.left_dot_color = left_dot_color if left_dot_color else color + self.right_dot_color = ( + right_dot_color if right_dot_color else color + ) + super().__init__() + + def create_artists( + self, + legend, + orig_handle, + xdescent, + ydescent, + width, + height, + fontsize, + trans, + ): + # Arrow line with triangular head + x_start = xdescent + 1 + x_end = xdescent + width - 1 + y_mid = ydescent + height / 2 + + # Draw arrow shaft (shorter to make room for visible arrowhead) + line = Line2D( + [x_start + 5, x_end - 10], + [y_mid, y_mid], + color=self.color, + linewidth=3, + transform=trans, + ) + + # Draw arrowhead as a triangle (larger and more visible) + head_width = 5 + head_length = 5 + arrow_head = mpatches.FancyArrow( + x_end - 10, + y_mid, + 5, + 0, + width=0, + head_width=head_width, + head_length=head_length, + fc=self.color, + ec=self.color, + transform=trans, + ) + + # Draw dots at each end with specified colors + ldc, rdc = self.left_dot_color, self.right_dot_color + dot_start = Line2D( + [x_start], + [y_mid], + marker="o", + markersize=7, + markerfacecolor=ldc, + markeredgecolor=ldc, + linestyle="None", + transform=trans, + ) + dot_end = Line2D( + [x_end + 4], + [y_mid], + marker="o", + markersize=7, + markerfacecolor=rdc, + markeredgecolor=rdc, + linestyle="None", + transform=trans, + ) + + return [line, arrow_head, dot_start, dot_end] + + # Create legend for status colors at bottom left of tile 1 + legend_labels = [ + ("Both in GT, order correct", "Both in GT, Match A\u2194P"), + ("Both in GT, order reversed", "Both in GT, Mismatch A\u2194P"), + ("One node in GT", "One in GT"), + ("One node in GT SWAPPED", "One in GT"), + ("Neither node in GT", "Neither in GT"), + ] + + # Create legend handles and handler map + legend_handles = [] + handler_map = {} + legend_colors = [] # Track colors for text coloring + for status, label in legend_labels: + # Get base color (strip SWAPPED suffix for lookup) + base_status = status.replace(" SWAPPED", "") + color = status_colors[base_status] + legend_colors.append(color) + handle = Line2D([], [], label=label) + legend_handles.append(handle) + # "One in GT" has black left dot (GT) and red right dot (non-GT) + if status == "One node in GT": + handler_map[handle] = ArrowHandler( + color, left_dot_color="black", right_dot_color=color + ) + elif status == "One node in GT SWAPPED": + handler_map[handle] = ArrowHandler( + color, left_dot_color=color, right_dot_color="black" + ) + else: + handler_map[handle] = ArrowHandler(color) + + leg_status = fig.legend( + handles=legend_handles, + handler_map=handler_map, + loc="lower left", + bbox_to_anchor=(0.01, 0.58), + fontsize=10, + frameon=False, + prop={"weight": "bold"}, + ) + + # Set legend text colors to match arrow colors (One in GT = black) + zipped = zip( + leg_status.get_texts(), legend_labels, legend_colors, strict=False + ) + for text, (status, _label), color in zipped: + if "One node in GT" in status: + text.set_color("black") + else: + text.set_color(color) + + # Add "Suggested AP Pair:" label below legend + fig.text( + 0.01, + 0.565, + "Suggested AP Pair:", + fontsize=13, + fontweight="bold", + color="black", + ha="left", + va="center", + ) + + # Panel 3: GT Node Ranks (descending bars) + # Reassign ranks: highest rank = most anterior (tallest bar) + # Bar width and group spacing + group_width = 0.8 + + ax3.set_title( + "Hand-Curated GT Node Rankings", + fontsize=14, + fontweight="bold", + color=text_color, + ) + + for i, f in enumerate(files): + gt = GROUND_TRUTH.get(f, {}) + file_color = colors[i] + + if not gt: + continue + + # Higher rank = more anterior, so rank directly maps to bar height + # Sort by rank descending (most anterior/highest rank first) + sorted_nodes = sorted(gt.items(), key=lambda x: x[1], reverse=True) + n_gt_nodes = len(sorted_nodes) + + for j, (node_idx, rank) in enumerate(sorted_nodes): + # Rank value is the bar height (higher = more anterior) + height = rank + + # Position within group (most anterior = leftmost = tallest) + x_pos = ( + i - group_width / 2 + (j + 0.5) * (group_width / n_gt_nodes) + ) + + # Draw bar + ax3.bar( + x_pos, + height, + width=group_width / n_gt_nodes * 0.85, + color=file_color, + edgecolor=axis_color, + linewidth=0.5, + alpha=0.75, + ) + + # Add node index as bold text on top of bar + ax3.text( + x_pos, + height + 0.1, + str(node_idx), + ha="center", + va="bottom", + fontsize=13, + fontweight="bold", + color=text_color, + ) + + ax3.set_xticks(range(n_files)) + ax3.set_xticklabels( + labels, rotation=30, ha="right", fontsize=14, fontweight="bold" + ) + # Color x-axis labels to match skeleton colors + for tick_label, color in zip( + ax3.get_xticklabels(), + colors, + strict=True, + ): + tick_label.set_color(color) + ax3.set_ylabel( + "Rank (higher = more anterior)", + fontsize=14, + fontweight="bold", + color=text_color, + ) + ax3.set_ylim(0, max(len(gt) for gt in GROUND_TRUTH.values()) + 1) + + # Make tile 3 tick labels bold + ax3.tick_params(axis="y", labelsize=10) + for label in ax3.get_yticklabels(): + label.set_fontweight("bold") + + # Panel 4: Filter cascade progression per dataset + ax4 = fig.add_subplot(gs[1, 1], facecolor=bg_color) + + import matplotlib as mpl + + mpl.rcParams["hatch.linewidth"] = 1.5 + + # Compute cascade data per file as raw pair counts. + # Step 1 shows C(n_candidates, 2) — pair potential of surviving nodes. + # Steps 2-3 show surviving pair counts directly. + # Since C(K,2) >= P >= D, bars are guaranteed to decrease. + all_counts = [[], [], []] + for f in files: + cs = cascade_stats.get(f, {}) + n_s1 = cs.get("n_step1_candidates", 0) + n_s2 = cs.get("n_step2_pairs", 0) + n_s3d = cs.get("n_step3_distal", 0) + + s1_pairs = n_s1 * (n_s1 - 1) // 2 if n_s1 >= 2 else 0 + all_counts[0].append(s1_pairs) + all_counts[1].append(n_s2) + all_counts[2].append(n_s3d) + + x_pos = np.arange(n_files) + bar_width = 0.25 + + # Hatch patterns distinguish steps; bar color matches dataset + step_hatches = ["--", "||", "+++"] + step_names = ["Step 1: Lateral", "Step 2: Opposite", "Step 3: Distal"] + + for j, (counts, hatch, _sn) in enumerate( + zip( + all_counts, + step_hatches, + step_names, + strict=True, + ) + ): + offset = (j - 1) * bar_width + for i, count in enumerate(counts): + c = colors[i] + fc = (c[0], c[1], c[2], 0.30) + ax4.bar( + x_pos[i] + offset, + count, + bar_width, + color=fc, + edgecolor="black", + linewidth=0.5, + hatch=hatch, + ) + bx = x_pos[i] + offset + by = count + 0.5 + ax4.text( + bx, + by, + str(count), + ha="center", + va="bottom", + fontsize=8, + fontweight="bold", + color=text_color, + ) + + # Small legend showing only hatch patterns (no color) + legend_patches = [] + for hatch, sn in zip(step_hatches, step_names, strict=True): + legend_patches.append( + mpatches.Patch( + facecolor="white", + edgecolor="black", + linewidth=0.5, + hatch=hatch, + label=sn, + ) + ) + ax4.legend( + handles=legend_patches, + loc="upper right", + fontsize=8, + framealpha=label_bg_alpha, + facecolor=bg_color, + labelcolor=text_color, + handlelength=1.5, + handleheight=1.0, + ) + + ax4.set_ylabel( + "Candidate Pairs", + fontsize=14, + fontweight="bold", + color=text_color, + ) + # Auto-scale y-axis with padding + max_count = max(max(c) for c in all_counts) if any(all_counts[0]) else 1 + ax4.set_ylim(0, max_count * 1.25) + ax4.set_xticks(x_pos) + ax4.set_xticklabels( + labels, rotation=30, ha="right", fontsize=14, fontweight="bold" + ) + for tick_label, color in zip( + ax4.get_xticklabels(), + colors, + strict=True, + ): + tick_label.set_color(color) + + ax4.tick_params(axis="y", labelsize=10, colors=text_color) + for label in ax4.get_yticklabels(): + label.set_fontweight("bold") + + ax4.set_title( + "3-Step Filter Cascade Progression", + fontsize=14, + fontweight="bold", + color=text_color, + ) + plt.subplots_adjust(left=0.08, right=0.95, top=0.92, bottom=0.08) + + for ax in [ax3, ax4]: + ax.tick_params(axis="y", colors=text_color, labelsize=10) + for spine in ax.spines.values(): + spine.set_color(axis_color) + ax.xaxis.label.set_color(text_color) + ax.yaxis.label.set_color(text_color) + if hasattr(ax, "title"): + ax.title.set_color(text_color) + + for tick_label, color in zip( + ax3.get_xticklabels(), + colors, + strict=True, + ): + tick_label.set_color(color) + tick_label.set_fontsize(14) + for tick_label, color in zip( + ax4.get_xticklabels(), + colors, + strict=True, + ): + tick_label.set_color(color) + tick_label.set_fontsize(14) + + ax4.yaxis.label.set_color("black") + + # Extract full timestamp (YYYYMMDD_HHMMSS) from H5 filename + stem_parts = output_path.stem.split("_") + stem_suffix = "_".join(stem_parts[-2:]) + fig_path = FIGURES_DIR / f"ap_validation_results_{stem_suffix}.svg" + plt.savefig( + fig_path, format="svg", facecolor=bg_color, bbox_inches="tight" + ) + plt.close(fig) + + print(f"\nSaved cross-dataset validation figure to: {fig_path}") + return fig_path + + +def create_individual_plot( # noqa: C901 + file_stem, + avg_skeleton, + pc1, + keypoint_names, + metrics, + output_path, + vel_projs=None, + individual_name=None, + all_individuals_metrics=None, + file_color=None, + individual_accuracy=None, +): + """Create a detailed 2×2 plot for a single file/animal. + + The resulting figure contains four tiles showing (1) the longitudinal + spread of keypoints along the first principal component, (2) the lateral + spread along the second principal component, (3) the direction of + anterior–posterior motion via centroid velocity, and (4) a scatter plot + relating the product of resultant length and vote margin (R×M) to + ground‑truth ordering accuracy. Colour schemes and labels are chosen + consistently with the summary AP validation figure. + + Parameters + ---------- + file_stem : str + The stem of the file name (without extension). Used to derive + display labels and file names for the output figure. + avg_skeleton : np.ndarray + Array of shape ``(n_keypoints, 2)`` containing the average x,y + coordinates of each keypoint after segmentation and outlier removal. + pc1 : np.ndarray + The first principal component vector (length‑2) representing the + anterior–posterior axis. + keypoint_names : list[str] + Names of the keypoints corresponding to rows of ``avg_skeleton``. + metrics : dict + Dictionary of scalar metrics summarising the individual's movement + and orientation. Expected keys include ``'anterior_sign'``, + ``'vote_margin'``, ``'resultant_length'``, ``'circ_mean_dir'``, + ``'num_selected_frames'`` and ``'n_frames'``. + output_path : Path | str + Directory or filename where the generated figure will be written. + vel_projs : np.ndarray | None, optional + One‑dimensional array of projections of velocity vectors onto + ``pc1``. When provided, a histogram of these values is drawn in + Tile 3. If ``None`` or empty, the histogram is omitted. + individual_name : str | None, optional + Name of the individual animal, shown in the bottom title and used + for the scatter plot legend. If ``None``, only the file label + appears. + all_individuals_metrics : dict[str, dict[str, float]], optional + Mapping from individual names to dictionaries containing mean + resultant length ``'R'`` and vote margin ``'M'`` values. Used to + populate the scatter plot (Tile 4). + file_color : tuple | None, optional + RGB(A) colour for this file, matching the colours used in the + summary AP validation figure. When ``None``, a default colour + palette is used. + individual_accuracy : dict[str, dict[str, float]] | None, optional + Mapping from individual names to accuracy dictionaries. If + provided, accuracy values are displayed in the scatter plot legend. + + Returns + ------- + matplotlib.figure.Figure | None + The created figure. If there are fewer than two valid keypoints + after filtering, the function prints a message and returns ``None``. + + """ + n_keypoints = len(avg_skeleton) + + # Generate unique visually distinct colors for all nodes + # Start with maximally distinct base colors, then add shades if needed + base_colors = [ + (0.12, 0.47, 0.71), # Blue + (1.00, 0.50, 0.05), # Orange + (0.17, 0.63, 0.17), # Green + (0.84, 0.15, 0.16), # Red + (0.58, 0.40, 0.74), # Purple + (0.55, 0.34, 0.29), # Brown + (0.89, 0.47, 0.76), # Pink + (0.50, 0.50, 0.50), # Gray + (0.74, 0.74, 0.13), # Olive/Yellow + (0.09, 0.75, 0.81), # Cyan + (0.00, 0.80, 0.60), # Teal + (0.90, 0.30, 0.50), # Magenta + (0.40, 0.20, 0.60), # Dark purple + (0.95, 0.70, 0.20), # Gold + (0.30, 0.70, 0.90), # Sky blue + (0.70, 0.90, 0.30), # Lime + ] + + def generate_colors(n): + """Generate n distinct colors, using shades if needed.""" + colors = [] + n_base = len(base_colors) + + if n <= n_base: + # Use base colors directly + colors = [(*base_colors[i], 1.0) for i in range(n)] + else: + # Use all base colors, then add lighter/darker shades + for i in range(n): + base_idx = i % n_base + # 0 = normal, 1 = lighter, 2 = darker, etc. + shade_level = i // n_base + + r, g, b = base_colors[base_idx] + + if shade_level == 0: + # Original color + pass + elif shade_level % 2 == 1: + # Lighter shade + factor = 0.4 * ((shade_level + 1) // 2) + r = min(1.0, r + (1 - r) * factor) + g = min(1.0, g + (1 - g) * factor) + b = min(1.0, b + (1 - b) * factor) + else: + # Darker shade + factor = 0.4 * (shade_level // 2) + r = max(0.0, r * (1 - factor)) + g = max(0.0, g * (1 - factor)) + b = max(0.0, b * (1 - factor)) + + colors.append((r, g, b, 1.0)) + + return colors + + colors = generate_colors(n_keypoints) + + # Get valid keypoints (non-NaN) + valid_mask = ~np.any(np.isnan(avg_skeleton), axis=1) + valid_idx = np.where(valid_mask)[0] + + if len(valid_idx) < 2: + print(f" Skipping {file_stem}: not enough valid keypoints") + return + + # Compute PC2 as perpendicular to PC1 + pc1 = pc1 / np.linalg.norm(pc1) # Normalize + pc2 = np.array([-pc1[1], pc1[0]]) # 90 degree rotation + + # Compute projections onto PC1 and PC2 + proj_pc1 = avg_skeleton @ pc1 + proj_pc2 = avg_skeleton @ pc2 + + # Display parameters + valid_coords = avg_skeleton[valid_mask] + shape_radius = np.max(np.abs(valid_coords)) * 1.3 + display_radius = shape_radius * 1.2 + + # Body axis extent + body_axis_extent = np.max(np.abs(proj_pc1[valid_mask])) * 1.3 + + # Use file color for all keypoints (matching summary figure) + + # Extract metrics + anterior_sign = int(metrics.get("anterior_sign", 1)) + vote_margin = metrics.get("vote_margin", 0) + resultant_length = metrics.get("resultant_length", 0) + circ_mean_dir = metrics.get("circ_mean_dir", np.nan) + num_selected_frames = int(metrics.get("num_selected_frames", 0)) + n_frames = int(metrics.get("n_frames", 1)) + + # Compute Vc (net velocity display vector) + # Scale so R=1 displays at 3*shape_radius, R scales 0-1 + # Clamp so arrow tip + label stays within display_radius + vel_display_max = shape_radius * 3.0 + vel_display_len = resultant_length * vel_display_max + max_vc_len = display_radius * 0.8 # clamp to 80% of axis boundary + vel_display_len = min(vel_display_len, max_vc_len) + if not np.isnan(circ_mean_dir): + cos_dir = np.cos(circ_mean_dir) + sin_dir = np.sin(circ_mean_dir) + net_vel_display = vel_display_len * np.array([cos_dir, sin_dir]) + else: + net_vel_display = np.array([0.0, 0.0]) + + centroid_alpha = 0.7 + label_offset = 0.12 * shape_radius + + # Create figure with dark background + fig = plt.figure(figsize=(14, 10), facecolor="black") + + # Compute R*M + rxm = resultant_length * vote_margin + + # Top title with description + fig.suptitle( + "AP Validation Detail: BBox-Centroid Average Skeleton", + fontsize=12, + fontweight="normal", + color="white", + y=0.97, + ) + + # Bottom title with frame count, dataset label, and individual name + # Move to right (beneath tile 4) if >=20 nodes to avoid legend overlap + file_label = FILE_LABELS.get(file_stem, file_stem[:15]) + _ind_label = f" ({individual_name})" if individual_name else "" + pct = 100.0 * num_selected_frames / n_frames if n_frames > 0 else 0 + if n_keypoints >= 20: + text_x, text_ha = 0.78, "center" # Right side, beneath tile 4 + else: + text_x, text_ha = 0.5, "center" # Centered + fig.text( + text_x, + 0.045, + f"{num_selected_frames} segment frames ({pct:.1f}% of all)", + fontsize=14, + fontweight="normal", + color="white", + ha=text_ha, + va="bottom", + ) + # Render file label with individual name + if individual_name: + fig.text( + text_x, + 0.012, + f"{file_label} ({individual_name})", + fontsize=14, + fontweight="bold", + color="white", + ha=text_ha, + va="bottom", + ) + else: + fig.text( + text_x, + 0.012, + file_label, + fontsize=14, + fontweight="bold", + color="white", + ha=text_ha, + va="bottom", + ) + + # Fixed width ratio for Tile 4 scatter plot + tile4_width_ratio = 0.8 + + # Create 2x2 grid with space for center legend + right_margin = 0.92 + gs = fig.add_gridspec( + 2, + 2, + hspace=0.25, + wspace=0.35, + left=0.08, + right=right_margin, + top=0.90, + bottom=0.10, + width_ratios=[1, tile4_width_ratio], + ) + + # TILE 1: Longitudinal Spread (PC1) + ax1 = fig.add_subplot(gs[0, 0], facecolor="black") + ax1.set_title( + "Longitudinal Spread (PC1 Projections)", + fontsize=12, + fontweight="normal", + color="white", + ) + + # PC1 axis line segments (broken at keypoint range) + min_proj = np.min(proj_pc1[valid_mask]) + max_proj = np.max(proj_pc1[valid_mask]) + + ax1.plot( + [-1.5 * body_axis_extent * pc1[0], min_proj * pc1[0]], + [-1.5 * body_axis_extent * pc1[1], min_proj * pc1[1]], + "-", + color=(1, 1, 1, 0.65), + linewidth=1.5, + ) + ax1.plot( + [max_proj * pc1[0], 1.5 * body_axis_extent * pc1[0]], + [max_proj * pc1[1], 1.5 * body_axis_extent * pc1[1]], + "-", + color=(1, 1, 1, 0.65), + linewidth=1.5, + ) + + # PC1 labels at edges + pc1_perp = np.array([-pc1[1], pc1[0]]) + pc_label_extent = 1.1 * body_axis_extent + ax1.text( + pc_label_extent * pc1[0] + pc1_perp[0] * label_offset, + pc_label_extent * pc1[1] + pc1_perp[1] * label_offset, + "+PC1", + color="white", + fontsize=10, + ha="center", + va="center", + ) + ax1.text( + -pc_label_extent * pc1[0] + pc1_perp[0] * label_offset, + -pc_label_extent * pc1[1] + pc1_perp[1] * label_offset, + "—PC1", + color="white", + fontsize=10, + ha="center", + va="center", + ) + + # Draw projection lines and keypoints + for i in valid_idx: + tip_x, tip_y = avg_skeleton[i] + proj_x = proj_pc1[i] * pc1[0] + proj_y = proj_pc1[i] * pc1[1] + + ax1.plot( + [tip_x, proj_x], + [tip_y, proj_y], + "--", + color=(*colors[i][:3], 0.5), + linewidth=1, + ) + ax1.plot( + [0, proj_x], + [0, proj_y], + "-", + color=(*colors[i][:3], 0.5), + linewidth=2, + ) + + for i in valid_idx: + ax1.scatter( + avg_skeleton[i, 0], + avg_skeleton[i, 1], + s=120, + c=[colors[i]], + edgecolors="black", + linewidths=0.5, + zorder=5, + ) + + ax1.scatter( + 0, + 0, + s=120, + c="gray", + edgecolors="black", + linewidths=0.5, + alpha=centroid_alpha, + zorder=6, + ) + + ax1.set_xlim(-display_radius, display_radius) + ax1.set_ylim(-display_radius, display_radius) + ax1.set_aspect("equal") + ax1.set_xlabel("X", color="white", fontsize=10) + ax1.set_ylabel( + "Y", color="white", fontsize=10, rotation=0, ha="right", va="center" + ) + ax1.tick_params(colors="gray", labelsize=8) + ax1.grid(True, color="gray", alpha=0.3, linestyle="-", linewidth=0.5) + for spine in ax1.spines.values(): + spine.set_color("gray") + + # TILE 2: Lateral Spread (PC2) + ax2 = fig.add_subplot(gs[0, 1], facecolor="black") + ax2.set_title( + "Lateral Spread (PC2 Projections)", + fontsize=12, + fontweight="normal", + color="white", + ) + + min_proj_pc2 = np.min(proj_pc2[valid_mask]) + max_proj_pc2 = np.max(proj_pc2[valid_mask]) + + ax2.plot( + [-1.5 * body_axis_extent * pc2[0], min_proj_pc2 * pc2[0]], + [-1.5 * body_axis_extent * pc2[1], min_proj_pc2 * pc2[1]], + "-", + color=(1, 1, 1, 0.65), + linewidth=1.5, + ) + ax2.plot( + [max_proj_pc2 * pc2[0], 1.5 * body_axis_extent * pc2[0]], + [max_proj_pc2 * pc2[1], 1.5 * body_axis_extent * pc2[1]], + "-", + color=(1, 1, 1, 0.65), + linewidth=1.5, + ) + + # PC2 labels at edges + pc2_perp = np.array([-pc2[1], pc2[0]]) + pc_label_extent_pc2 = 1.1 * body_axis_extent + ax2.text( + pc_label_extent_pc2 * pc2[0] + pc2_perp[0] * label_offset, + pc_label_extent_pc2 * pc2[1] + pc2_perp[1] * label_offset, + "+PC2", + color="white", + fontsize=10, + ha="center", + va="center", + ) + ax2.text( + -pc_label_extent_pc2 * pc2[0] + pc2_perp[0] * label_offset, + -pc_label_extent_pc2 * pc2[1] + pc2_perp[1] * label_offset, + "—PC2", + color="white", + fontsize=10, + ha="center", + va="center", + ) + + for i in valid_idx: + tip_x, tip_y = avg_skeleton[i] + proj_x = proj_pc2[i] * pc2[0] + proj_y = proj_pc2[i] * pc2[1] + + ax2.plot( + [tip_x, proj_x], + [tip_y, proj_y], + "--", + color=(*colors[i][:3], 0.5), + linewidth=1, + ) + ax2.plot( + [0, proj_x], + [0, proj_y], + "-", + color=(*colors[i][:3], 0.5), + linewidth=2, + ) + + for i in valid_idx: + ax2.scatter( + avg_skeleton[i, 0], + avg_skeleton[i, 1], + s=120, + c=[colors[i]], + edgecolors="black", + linewidths=0.5, + zorder=5, + ) + + ax2.scatter( + 0, + 0, + s=120, + c="gray", + edgecolors="black", + linewidths=0.5, + alpha=centroid_alpha, + zorder=6, + ) + + ax2.set_xlim(-display_radius, display_radius) + ax2.set_ylim(-display_radius, display_radius) + ax2.set_aspect("equal") + ax2.set_xlabel("X", color="white", fontsize=10) + ax2.set_ylabel( + "Y", color="white", fontsize=10, rotation=0, ha="right", va="center" + ) + ax2.tick_params(colors="gray", labelsize=8) + ax2.grid(True, color="gray", alpha=0.3, linestyle="-", linewidth=0.5) + for spine in ax2.spines.values(): + spine.set_color("gray") + + # TILE 3: Inferred AP Direction (Velocity Voting) + ax3 = fig.add_subplot(gs[1, 0], facecolor="black") + ax3.set_title( + "Inferred AP Direction (Velocity Voting)", + fontsize=12, + fontweight="normal", + color="white", + ) + + vel_proj_scalar = np.dot(net_vel_display, pc1) + min_proj_vel = min(0, vel_proj_scalar) + max_proj_vel = max(0, vel_proj_scalar) + + ax3.plot( + [-1.5 * body_axis_extent * pc1[0], min_proj_vel * pc1[0]], + [-1.5 * body_axis_extent * pc1[1], min_proj_vel * pc1[1]], + "-", + color=(1, 1, 1, 0.65), + linewidth=1.5, + ) + ax3.plot( + [max_proj_vel * pc1[0], 1.5 * body_axis_extent * pc1[0]], + [max_proj_vel * pc1[1], 1.5 * body_axis_extent * pc1[1]], + "-", + color=(1, 1, 1, 0.65), + linewidth=1.5, + ) + + vel_proj_x = vel_proj_scalar * pc1[0] + vel_proj_y = vel_proj_scalar * pc1[1] + ax3.plot( + [0, vel_proj_x], + [0, vel_proj_y], + "-", + color=(1, 1, 1, 0.5), + linewidth=1.5, + ) + + body_perp = np.array([-pc1[1], pc1[0]]) + + # Velocity projection histogram along PC1 + if vel_projs is not None and len(vel_projs) > 0: + # Adapt bins based on data size, keep consistent transparency + n_vals = len(vel_projs) + # Use a ternary expression rather than an if–else block (SIM108) + num_bins = max(5, n_vals // 3) if n_vals < 50 else 25 + hist_alpha = 0.25 + max_bar_height = 1.0 * shape_radius + + vp_min = np.min(vel_projs) + vp_max = np.max(vel_projs) + data_range = vp_max - vp_min + + # Ensure minimum bin width similar to 2Mice + min_bin_width = shape_radius / 12 + min_range = min_bin_width * num_bins + if data_range < min_range: + center = (vp_min + vp_max) / 2 + vp_min = center - min_range / 2 + vp_max = center + min_range / 2 + data_range = min_range + + edge_pad = data_range * 0.02 + bin_edges = np.linspace( + vp_min - edge_pad, vp_max + edge_pad, num_bins + 1 + ) + + bin_counts, _ = np.histogram(vel_projs, bins=bin_edges) + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + + max_count = np.max(bin_counts) + # Set minimum bar height so even count=1 is visible + min_bar_height = 0.15 * max_bar_height + if max_count > 0: + bar_heights = (bin_counts / max_count) * max_bar_height + # Apply minimum height for non-zero bins + bar_heights = np.where( + bin_counts > 0, + np.maximum(bar_heights, min_bar_height), + bar_heights, + ) + else: + bar_heights = np.zeros_like(bin_counts) + + # Draw bars as filled polygons in PC1/perp coordinate frame + for bi in range(len(bin_counts)): + if bin_counts[bi] == 0: + continue + + pc1_lo = bin_edges[bi] + pc1_hi = bin_edges[bi + 1] + bh = bar_heights[bi] + + # Four corners: bottom-left, bottom-right, top-right, top-left + corners_pc1 = np.array([pc1_lo, pc1_hi, pc1_hi, pc1_lo]) + corners_perp = np.array([0, 0, bh, bh]) + + # Transform to x,y using PC1 and body_perp basis + cx = corners_pc1 * pc1[0] + corners_perp * body_perp[0] + cy = corners_pc1 * pc1[1] + corners_perp * body_perp[1] + + # Color by sign: blue for +PC1, red for -PC1 + if bin_centers[bi] > 0: + bar_color = (0.4, 0.75, 1.0) # blue + else: + bar_color = (1.0, 0.45, 0.35) # red + + ax3.fill( + cx, + cy, + color=bar_color, + alpha=hist_alpha, + edgecolor="black", + linewidth=0.5, + zorder=2, + ) + + # Add bin count label at top of bar (centered within bar) + bin_center_pc1 = (pc1_lo + pc1_hi) / 2 + # At top of bar, slightly inside + label_perp_offset = bh - 0.02 * shape_radius + lx = bin_center_pc1 * pc1[0] + label_perp_offset * body_perp[0] + ly = bin_center_pc1 * pc1[1] + label_perp_offset * body_perp[1] + # Rotation angle matches PC1 direction + rot_deg = np.degrees(np.arctan2(pc1[1], pc1[0])) + ax3.text( + lx, + ly, + str(bin_counts[bi]), + color="white", + fontsize=5, + ha="center", + va="top", + rotation=rot_deg, + rotation_mode="anchor", + zorder=20, + ) + + for i in valid_idx: + ax3.scatter( + avg_skeleton[i, 0], + avg_skeleton[i, 1], + s=120, + c=[colors[i]], + edgecolors="black", + linewidths=0.5, + zorder=5, + ) + + # PC1 labels at endpoints of PC line, on opposite side of histogram + pc_label_extent = 1.1 * body_axis_extent # Closer to tile center + pc_label_offset = label_offset * 1.5 # Offset perpendicular to PC1 + ax3.text( + pc_label_extent * pc1[0] - body_perp[0] * pc_label_offset, + pc_label_extent * pc1[1] - body_perp[1] * pc_label_offset, + "+PC1", + color="white", + fontsize=10, + ha="center", + va="center", + ) + ax3.text( + -pc_label_extent * pc1[0] - body_perp[0] * pc_label_offset, + -pc_label_extent * pc1[1] - body_perp[1] * pc_label_offset, + "—PC1", + color="white", + fontsize=10, + ha="center", + va="center", + ) + + # Vc arrow (circular-mean velocity) + if np.linalg.norm(net_vel_display) > 0: + net_vel_angle = np.arctan2(net_vel_display[1], net_vel_display[0]) + + ax3.plot( + [0, net_vel_display[0]], + [0, net_vel_display[1]], + "-", + color="white", + linewidth=1.2, + zorder=7, + ) + + arrow_len = 0.06 * shape_radius + arrow_wid = 0.03 * shape_radius + local_tri_x = np.array([0, -arrow_len, -arrow_len]) + local_tri_y = np.array([0, arrow_wid, -arrow_wid]) + cos_a, sin_a = np.cos(net_vel_angle), np.sin(net_vel_angle) + tri_x = local_tri_x * cos_a - local_tri_y * sin_a + net_vel_display[0] + tri_y = local_tri_x * sin_a + local_tri_y * cos_a + net_vel_display[1] + ax3.fill( + tri_x, + tri_y, + color="white", + edgecolor="black", + linewidth=0.5, + zorder=8, + clip_on=True, + ) + + vc_label_offset = 0.12 * shape_radius + vc_label_x = net_vel_display[0] + vc_label_offset * np.cos( + net_vel_angle + ) + vc_label_y = net_vel_display[1] + vc_label_offset * np.sin( + net_vel_angle + ) + ax3.text( + vc_label_x, + vc_label_y, + r"$\mathbf{V_c}$", + color="white", + fontsize=12, + fontweight="bold", + ha="left", + va="center", + zorder=10, + clip_on=True, + ) + + # AP midline and bidirectional arrow (parallel to PC1) + min_p = np.min(proj_pc1[valid_mask]) + max_p = np.max(proj_pc1[valid_mask]) + midpoint_pc1 = (min_p + max_p) / 2 + midline_center = midpoint_pc1 * pc1 # Position along PC1 + midline_half_len = 0.35 * shape_radius + pc1_perp = np.array([-pc1[1], pc1[0]]) # Perpendicular to PC1 + + # Draw AP midline (dotted beige, perpendicular to PC1) + ax3.plot( + [ + midline_center[0] - midline_half_len * pc1_perp[0], + midline_center[0] + midline_half_len * pc1_perp[0], + ], + [ + midline_center[1] - midline_half_len * pc1_perp[1], + midline_center[1] + midline_half_len * pc1_perp[1], + ], + ":", + color=(0.82, 0.71, 0.55), + linewidth=1.5, + zorder=4, + ) + + # AP bidirectional arrow (along PC1, next to midline) + ap_arrow_offset = midline_half_len * 1.1 # Offset perpendicular to PC1 + # Opposite side of histogram + ap_arrow_center = midline_center - ap_arrow_offset * pc1_perp + ap_arrow_half_len = 0.225 * shape_radius # Shortened by 0.5x + + # Arrow endpoints along PC1 + ap_arrow_ant = ap_arrow_center + ap_arrow_half_len * pc1 * anterior_sign + ap_arrow_post = ap_arrow_center - ap_arrow_half_len * pc1 * anterior_sign + + # Draw arrow shaft + ax3.plot( + [ap_arrow_post[0], ap_arrow_ant[0]], + [ap_arrow_post[1], ap_arrow_ant[1]], + "-", + color=(0.82, 0.71, 0.55), + linewidth=1.5, + zorder=4, + ) + + # Draw arrowheads + head_len = 0.08 * shape_radius + head_wid = 0.04 * shape_radius + + # Anterior arrowhead + ant_head_base = ap_arrow_ant - head_len * pc1 * anterior_sign + ant_tri = np.array( + [ + ap_arrow_ant, + ant_head_base + head_wid * pc1_perp, + ant_head_base - head_wid * pc1_perp, + ] + ) + tan_color = (0.82, 0.71, 0.55) + ax3.fill( + ant_tri[:, 0], + ant_tri[:, 1], + color=tan_color, + edgecolor="black", + linewidth=0.5, + zorder=5, + ) + + # Posterior arrowhead + post_head_base = ap_arrow_post + head_len * pc1 * anterior_sign + post_tri = np.array( + [ + ap_arrow_post, + post_head_base + head_wid * pc1_perp, + post_head_base - head_wid * pc1_perp, + ] + ) + ax3.fill( + post_tri[:, 0], + post_tri[:, 1], + color=tan_color, + edgecolor="black", + linewidth=0.5, + zorder=5, + ) + + # A and P labels + ap_label_offset = 0.06 * shape_radius + ax3.text( + ap_arrow_ant[0] + ap_label_offset * pc1[0] * anterior_sign, + ap_arrow_ant[1] + ap_label_offset * pc1[1] * anterior_sign, + "A", + color=tan_color, + fontsize=10, + fontweight="bold", + ha="center", + va="center", + zorder=10, + ) + ax3.text( + ap_arrow_post[0] - ap_label_offset * pc1[0] * anterior_sign, + ap_arrow_post[1] - ap_label_offset * pc1[1] * anterior_sign, + "P", + color=tan_color, + fontsize=10, + fontweight="bold", + ha="center", + va="center", + zorder=10, + ) + + ax3.set_xlim(-display_radius, display_radius) + ax3.set_ylim(-display_radius, display_radius) + ax3.set_aspect("equal") + ax3.set_xlabel("X", color="white", fontsize=10) + ax3.set_ylabel( + "Y", color="white", fontsize=10, rotation=0, ha="right", va="center" + ) + ax3.set_xticks([]) + ax3.set_yticks([]) + for spine in ax3.spines.values(): + spine.set_color("gray") + + # TILE 4: R×M vs GT Accuracy Scatter Plot + ax4 = fig.add_subplot(gs[1, 1], facecolor="black") + + # Different marker shapes for each individual + marker_shapes = ["o", "s", "^", "D", "v", "h", "*", "X", "P", "p"] + + # Check if both metrics and accuracy data are present + has_data = ( + all_individuals_metrics + and len(all_individuals_metrics) > 0 + and individual_accuracy + and len(individual_accuracy) > 0 + ) + + if has_data: + # Collect data points for scatter plot + rxm_values = [] + accuracy_values = [] + ind_labels = [] + + for ind in sorted(all_individuals_metrics.keys()): + R = all_individuals_metrics[ind]["R"] + M = all_individuals_metrics[ind]["M"] + rxm = R * M + + if ind in individual_accuracy: + acc = individual_accuracy[ind]["accuracy"] + rxm_values.append(rxm) + accuracy_values.append(acc) + ind_labels.append(ind) + + if rxm_values: + # Use file color if provided, otherwise default to gray + base_color = file_color[:3] if file_color else (0.5, 0.5, 0.5) + + # Find best individual (highest R×M) + best_idx = np.argmax(rxm_values) + + # Plot each individual with a different marker shape + # Best individual (highest R×M) always uses star marker + legend_handles = [] + data = zip(rxm_values, accuracy_values, ind_labels, strict=True) + for i, (rxm, acc, label) in enumerate(data): + is_best = i == best_idx + marker = ( + "*" if is_best else marker_shapes[i % len(marker_shapes)] + ) + size = 200 if is_best else 120 + + zord = 6 if is_best else 5 + _scatter = ax4.scatter( + [rxm], + [acc], + s=size, + c=[base_color], + marker=marker, + edgecolors="white", + linewidths=1.0, + alpha=0.9, + zorder=zord, + ) + + # Create legend handle (white edge) + ms = 12 if is_best else 10 + legend_handles.append( + Line2D( + [0], + [0], + marker=marker, + color="w", + markerfacecolor=base_color, + markeredgecolor="white", + markeredgewidth=1.0, + markersize=ms, + linestyle="None", + label=label, + ) + ) + + # Set axis limits with padding + rxm_range = ( + max(rxm_values) - min(rxm_values) + if len(rxm_values) > 1 + else 0.1 + ) + # Handle case when all values are identical (range = 0) + if rxm_range < 0.01: + rxm_range = 0.1 + x_min = min(rxm_values) - rxm_range * 0.15 + x_max = max(rxm_values) + rxm_range * 0.25 + ax4.set_xlim(x_min, x_max) + # Y-axis: -20 to 120 range, but only show ticks 0-100 + ax4.set_ylim(-20, 120) + ax4.set_yticks([0, 20, 40, 60, 80, 100]) + + # Axis labels + ax4.set_xlabel("R×M", color="white", fontsize=10) + ax4.set_ylabel( + "Pairwise GT Concordance\n(inferred AP, %)", + color="white", + fontsize=10, + ) + + # Add legend for individual markers + ax4.legend( + handles=legend_handles, + loc="lower right", + fontsize=8, + facecolor="black", + edgecolor="gray", + labelcolor="white", + ) + + # Add grid + ax4.grid( + True, color="gray", alpha=0.3, linestyle="-", linewidth=0.5 + ) + else: + ax4.text( + 0.5, + 0.5, + "No data", + color="gray", + fontsize=12, + ha="center", + va="center", + transform=ax4.transAxes, + ) + + ax4.tick_params(axis="y", colors="gray", labelsize=8) + ax4.tick_params(axis="x", colors="gray", labelsize=8) + for spine in ax4.spines.values(): + spine.set_color("gray") + + # NODE LEGEND (Center of figure) + legend_x = 0.47 + legend_top_y = 0.88 + legend_spacing = 0.045 + + fig_w, fig_h = fig.get_size_inches() + fig_aspect = fig_w / fig_h + dot_h = 0.012 + dot_w = dot_h / fig_aspect + + legend_entry_count = 0 + for i in valid_idx: + y_pos = legend_top_y - legend_entry_count * legend_spacing + + ax_dot = fig.add_axes( + [legend_x - dot_w / 2, y_pos - dot_h / 2, dot_w, dot_h] + ) + ax_dot.set_xlim(0, 1) + ax_dot.set_ylim(0, 1) + ax_dot.scatter( + 0.5, 0.5, s=100, c=[colors[i]], edgecolors="black", linewidths=0.5 + ) + ax_dot.axis("off") + + name = keypoint_names[i] if i < len(keypoint_names) else f"node_{i}" + fig.text( + legend_x + 0.015, + y_pos, + f"[{i}] {name}", + color="white", + fontsize=10, + va="center", + ha="left", + ) + + legend_entry_count += 1 + + # Centroid entry + y_pos = legend_top_y - legend_entry_count * legend_spacing + ax_dot = fig.add_axes( + [legend_x - dot_w / 2, y_pos - dot_h / 2, dot_w, dot_h] + ) + ax_dot.set_xlim(0, 1) + ax_dot.set_ylim(0, 1) + ax_dot.scatter( + 0.5, + 0.5, + s=100, + c=[(0.3, 0.3, 0.3)], + edgecolors="black", + linewidths=0.5, + ) + ax_dot.axis("off") + fig.text( + legend_x + 0.015, + y_pos, + "bbox centroid", + color="white", + fontsize=10, + va="center", + ha="left", + ) + legend_entry_count += 1 + + # Histogram bar entries (blue = +vel, red = -vel) + bar_w = 0.012 + bar_h = 0.008 + + # Blue bar (+vel projections) + y_pos = legend_top_y - legend_entry_count * legend_spacing + ax_bar = fig.add_axes( + [legend_x - bar_w / 2, y_pos - bar_h / 2, bar_w, bar_h] + ) + ax_bar.set_xlim(0, 1) + ax_bar.set_ylim(0, 1) + ax_bar.add_patch( + plt.Rectangle( + (0.1, 0.1), + 0.8, + 0.8, + facecolor=(0.4, 0.75, 1.0), + alpha=0.6, + edgecolor="none", + ) + ) + ax_bar.axis("off") + fig.text( + legend_x + 0.015, + y_pos, + "+PC1 vel proj.", + color="white", + fontsize=10, + va="center", + ha="left", + ) + legend_entry_count += 1 + + # Red bar (-vel projections) + y_pos = legend_top_y - legend_entry_count * legend_spacing + ax_bar = fig.add_axes( + [legend_x - bar_w / 2, y_pos - bar_h / 2, bar_w, bar_h] + ) + ax_bar.set_xlim(0, 1) + ax_bar.set_ylim(0, 1) + ax_bar.add_patch( + plt.Rectangle( + (0.1, 0.1), + 0.8, + 0.8, + facecolor=(1.0, 0.45, 0.35), + alpha=0.6, + edgecolor="none", + ) + ) + ax_bar.axis("off") + fig.text( + legend_x + 0.015, + y_pos, + "−PC1 vel proj.", + color="white", + fontsize=10, + va="center", + ha="left", + ) + legend_entry_count += 1 + + # AP Midline (dotted beige line) + line_w = 0.025 + line_h = 0.006 + y_pos = legend_top_y - legend_entry_count * legend_spacing + ax_line = fig.add_axes( + [legend_x - line_w / 2, y_pos - line_h / 2, line_w, line_h] + ) + ax_line.set_xlim(0, 1) + ax_line.set_ylim(0, 1) + ax_line.plot([0.1, 0.9], [0.5, 0.5], ":", color=tan_color, linewidth=2) + ax_line.axis("off") + fig.text( + legend_x + 0.015, + y_pos, + "AP midline", + color="white", + fontsize=10, + va="center", + ha="left", + ) + + # Save figure + plt.savefig( + output_path, + format="svg", + facecolor="black", + edgecolor="none", + bbox_inches="tight", + ) + plt.close(fig) + print(f" Saved: {output_path.name}") + + +def generate_individual_plots(analysis_data, output_dir, timestamp=None): # noqa: C901 + """Generate detailed 2x2 plot for each file's best individual.""" + data = analysis_data + stored_skeletons = data.get("skeletons", {}) + stored_pc1 = data.get("pc1_vectors", {}) + stored_vel_projs = data.get("vel_projs_pc1", {}) + keypoints_by_file = data.get("keypoints", {}) + individual_accuracy_by_file = data.get("individual_accuracy", {}) + + # File color mapping (matching ap_validation_results figure) + # Files sorted alphabetically, colors assigned in order + file_colors = { + "free-moving-2flies-ID-13nodes-1024x1024x1-30_3pxmm": ( + 0.12, + 0.47, + 0.71, + 1.0, + ), # Blue + "free-moving-2mice-noID-5nodes-1280x1024x1-1_9pxmm": ( + 0.17, + 0.63, + 0.17, + 1.0, + ), # Green + "free-moving-4gerbils-ID-14nodes-1024x1280x3-2pxmm": ( + 0.80, + 0.70, + 0.10, + 1.0, + ), # Gold + "free-moving-5mice-noID-11nodes-1280x1024x1-1_97pxmm": ( + 1.00, + 0.50, + 0.05, + 1.0, + ), # Orange + "freemoving-2bees-noID-21nodes-1535x2048x1-14pxmm": ( + 0.58, + 0.40, + 0.74, + 1.0, + ), # Purple + } + + # Collect R, M, and R×M values per file per individual + def default_metrics(): + return {"R": [], "M": [], "RxM": []} + + file_individual_data = defaultdict(lambda: defaultdict(default_metrics)) + for i in range(len(data["file"])): + file_stem = data["file"][i] + individual = data["individual"][i] + R = data["resultant_length"][i] + M = data["vote_margin"][i] + rxm = data["rxm"][i] + if not np.isnan(rxm): + file_individual_data[file_stem][individual]["R"].append(R) + file_individual_data[file_stem][individual]["M"].append(M) + file_individual_data[file_stem][individual]["RxM"].append(rxm) + + # Compute mean R, M per individual and find best by R×M + best_individual = {} + best_metrics = {} + # {file_stem: {individual: {"R": mean_R, "M": mean_M}}} + all_individuals_metrics_by_file = {} + + for file_stem, individuals in file_individual_data.items(): + best_rxm = -1 + best_ind = None + all_individuals_metrics_by_file[file_stem] = {} + + for ind, metrics_dict in individuals.items(): + mean_R = np.mean(metrics_dict["R"]) + mean_M = np.mean(metrics_dict["M"]) + mean_rxm = np.mean(metrics_dict["RxM"]) + all_individuals_metrics_by_file[file_stem][ind] = { + "R": mean_R, + "M": mean_M, + } + if mean_rxm > best_rxm: + best_rxm = mean_rxm + best_ind = ind + + best_individual[file_stem] = best_ind + + # Get metrics for best individual + for i in range(len(data["file"])): + is_match = ( + data["file"][i] == file_stem + and data["individual"][i] == best_ind + ) + if is_match: + best_metrics[file_stem] = { + "vote_margin": data["vote_margin"][i], + "resultant_length": data["resultant_length"][i], + "anterior_sign": data["anterior_sign"][i], + "circ_mean_dir": data["circ_mean_dir"][i], + "num_selected_frames": data["num_selected_frames"][i], + "n_frames": data["n_frames"][i], + } + break + + print("\nGenerating per-file detail plots (best individual)...") + + for file_stem, best_ind in sorted(best_individual.items()): + label = FILE_LABELS.get(file_stem, file_stem[:15]) + print(f" {label}: *{best_ind}*") + + has_skel = ( + file_stem in stored_skeletons + and best_ind in stored_skeletons[file_stem] + ) + if not has_skel: + print(" No skeleton data") + continue + + avg_skeleton = stored_skeletons[file_stem][best_ind] + + has_pc1 = file_stem in stored_pc1 and best_ind in stored_pc1[file_stem] + if not has_pc1: + print(" No PC1 data") + continue + + pc1 = stored_pc1[file_stem][best_ind] + keypoint_names = keypoints_by_file.get(file_stem, []) + metrics = best_metrics.get(file_stem, {}) + + # Get velocity projections for histogram + vel_projs = None + has_vel = ( + file_stem in stored_vel_projs + and best_ind in stored_vel_projs[file_stem] + ) + if has_vel: + vel_projs = stored_vel_projs[file_stem][best_ind] + + # Get R and M data for all individuals in this file + all_individuals_metrics = all_individuals_metrics_by_file.get( + file_stem, {} + ) + + # Get individual accuracy data for this file + individual_accuracy = individual_accuracy_by_file.get(file_stem, {}) + + # Get file color + file_color = file_colors.get(file_stem, (0.5, 0.5, 0.5, 1.0)) + + # Sanitize individual name for filename + safe_ind = best_ind.replace(" ", "_") + ts_suffix = f"_{timestamp}" if timestamp else "" + output_path = output_dir / f"{file_stem}_{safe_ind}{ts_suffix}.svg" + create_individual_plot( + file_stem, + avg_skeleton, + pc1, + keypoint_names, + metrics, + output_path, + vel_projs, + individual_name=best_ind, + all_individuals_metrics=all_individuals_metrics, + file_color=file_color, + individual_accuracy=individual_accuracy, + ) + + print(f"Per-file detail plots saved to: {output_dir}") + + +def main(): # noqa: C901 + ensure_demo_datasets() + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + FIGURES_DIR.mkdir(exist_ok=True) + LOGS_DIR.mkdir(exist_ok=True) + H5_DIR.mkdir(exist_ok=True) + + # Set up log file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = LOGS_DIR / f"ap_validation_{timestamp}.log" + + with TeeOutput(log_file): + print(f"AP Validation started at {datetime.now().isoformat()}") + print(f"Log file: {log_file}") + print() + + # Check for existing H5 files + existing_h5 = sorted(H5_DIR.glob("ap_validation_*.h5")) + + if existing_h5: + # Use most recent existing file + output_path = existing_h5[-1] + print(f"Found existing H5 file: {output_path.name}") + print("Skipping generation, proceeding to analysis...") + else: + # Generate new H5 file + slp_files = sorted(SLP_DIR.glob("*.slp")) + if not slp_files: + print("No .slp files found after bootstrap") + sys.exit(1) + + # PASS 1: R×M Selection — Find best individual per file + n_files = len(slp_files) + print( + f"Pass 1: R×M Selection — finding best individual " + f"per file from {n_files} files..." + ) + rxm_tasks, file_metadata = generate_rxm_tasks(slp_files) + print(f" Running {len(rxm_tasks)} R×M computations...") + + rxm_results = [] + with ProcessPoolExecutor(max_workers=N_WORKERS) as executor: + futures = [ + executor.submit(process_single_validation, task) + for task in rxm_tasks + ] + for future in as_completed(futures): + rxm_results.append(future.result()) + + best_individuals, all_rxm, file_individual_data = ( + find_best_individuals(rxm_results) + ) + best_ind_str = ", ".join( + f"{FILE_LABELS.get(k, k[:15])}: *{v}*" + for k, v in best_individuals.items() + ) + print(f" Best individuals (max R×M): {{{best_ind_str}}}") + + # Compute PC1-based orderings for all individuals + # (prerequisite for Pass 2: cross-individual consistency) + best_pc1_orderings, all_pc1_orderings = compute_pc1_orderings( + file_individual_data, file_metadata, best_individuals + ) + + # PASS 2: Cross-Individual Ordering Consistency + print( + "\nPass 2: Cross-Individual Ordering Consistency — " + "do individuals from the same video agree on node ordering?\n" + " Each individual's raw PC1 ordering of GT nodes is compared " + "against the best individual's ordering (the 'pseudo GT').\n" + " This is a CONSISTENCY check, not a correctness check — " + "high agreement means the body shape is stable across " + "individuals, but says nothing about whether the ordering " + "is anatomically correct." + ) + ordering_matches = compare_orderings_to_pseudo_gt( + all_pc1_orderings, best_pc1_orderings + ) + for file_stem, matches in ordering_matches.items(): + n_match = sum(matches.values()) + n_total = len(matches) + label = FILE_LABELS.get(file_stem, file_stem[:15]) + print( + f" {label}: {n_match}/{n_total} individuals " + f"share the best individual's ordering" + ) + + # PASS 3: Inferred AP Concordance per Individual + # For each individual, project GT nodes onto the inferred + # AP axis (anterior_sign × PC1) and compare the ordering + # against hand-curated GT. This tests the full pipeline + # (PCA + velocity voting) per individual. Feeds Figure 1 Tile 4. + print( + "\nPass 3: Inferred AP Concordance — " + "does each individual's velocity-inferred AP ordering " + "match hand-curated GT?\n" + " For each individual, GT nodes are projected onto " + "anterior_sign × PC1 (the inferred AP axis). All C(n,2) " + "unique pairs are tested against the hand-curated GT " + "ranking.\n" + " This tests the full pipeline (PCA + velocity voting) " + "per individual." + ) + individual_accuracy = compute_inferred_ap_concordance( + file_individual_data + ) + + # Build R, M, R×M lookup from Pass 1 results + ind_metrics = {} + for rec in rxm_results: + if not rec.get("error", False): + key = (rec["file"], rec["individual"]) + ind_metrics[key] = { + "R": rec.get("resultant_length", float("nan")), + "M": rec.get("vote_margin", float("nan")), + "rxm": rec.get("rxm", float("nan")), + } + + for file_stem, accuracies in individual_accuracy.items(): + print(f" {FILE_LABELS.get(file_stem, file_stem[:15])}:") + best_ind = best_individuals.get(file_stem) + for ind, acc in sorted(accuracies.items()): + c = acc["correct"] + t = acc["total"] + a = acc["accuracy"] + marker = "*" if ind == best_ind else "" + m = ind_metrics.get((file_stem, ind), {}) + R = m.get("R", float("nan")) + M = m.get("M", float("nan")) + rxm = m.get("rxm", float("nan")) + print( + f" {marker}{ind}{marker}: " + f"{c}/{t} pairs = {a:.1f}% " + f"(R={R:.2f}, M={M:.2f}, R×M={rxm:.2f})" + ) + + # Generate validation tasks for H5 storage + print( + "\nGenerating H5 validation records " + "(all GT pair permutations × all individuals)..." + ) + tasks, keypoints_by_file = generate_gt_validation_tasks( + file_metadata, best_individuals + ) + n_tasks = len(tasks) + print( + f" Processing {n_tasks} validation comparisons " + f"with {N_WORKERS} workers..." + ) + + all_results = [] + with ProcessPoolExecutor(max_workers=N_WORKERS) as executor: + futures = [ + executor.submit(process_single_validation, task) + for task in tasks + ] + + for i, future in enumerate(as_completed(futures), 1): + result = future.result() + all_results.append(result) + if i % 50 == 0: + print(f" {i}/{len(tasks)} completed...") + + h5_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = H5_DIR / f"ap_validation_{h5_timestamp}.h5" + # Include Pass 1 results to ensure best individual's data is saved + combined_results = rxm_results + all_results + save_to_h5( + combined_results, + keypoints_by_file, + output_path, + individual_accuracy=individual_accuracy, + all_rxm=all_rxm, + ) + + print(f"\nSaved {len(all_results)} records to: {output_path}") + + # Reporting: GT coverage, suggested pairs, and Figure 2 data + analysis_data = analyze_results(output_path) + plot_validation_results(analysis_data, output_path) + + # Generate per-file 2×2 tile detail plots for each best individual + # Extract timestamp from H5 filename (ap_validation_YYYYMMDD_HHMMSS.h5) + h5_stem = output_path.stem + fig_timestamp = "_".join(h5_stem.split("_")[-2:]) + generate_individual_plots(analysis_data, FIGURES_DIR, fig_timestamp) + + print(f"\nLog saved to: {log_file}") + + +if __name__ == "__main__": + main() diff --git a/movement/kinematics/body_axis.py b/movement/kinematics/body_axis.py new file mode 100644 index 000000000..054af965e --- /dev/null +++ b/movement/kinematics/body_axis.py @@ -0,0 +1,2910 @@ +"""Body-axis inference and anterior-posterior validation for pose data. + +This module provides infrastructure for validating user-supplied body-axis +keypoint pairs by inferring the anterior-posterior (AP) axis from motion +data. It uses a prior-free approach combining: + +1. High-motion segment detection via tiered validity and sliding windows +2. Postural clustering via k-medoids (when posture varies across segments) +3. PCA-based body-axis extraction from centered skeletons +4. Velocity projection voting to infer anterior direction +5. A 3-step filter cascade to evaluate candidate AP keypoint pairs + +""" + +from collections.abc import Hashable +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import xarray as xr + +from movement.utils.logging import logger + +# Separator line for log output formatting +_LOG_SEPARATOR = "\u2500" * 60 + + +# Configuration and Data Classes +# ────────────────────────────── + + +@dataclass +class ValidateAPConfig: + """Configuration for the validate_ap function. + + Parameters + ---------- + min_valid_frac : float, default=0.6 + Minimum fraction of keypoints that must be present for a frame + to qualify as tier-1 valid. + window_len : int, default=50 + Number of speed samples per sliding window. + stride : int, default=5 + Step size between consecutive sliding window start positions. + pct_thresh : float, default=85.0 + Percentile threshold applied to valid-window median speeds for + high-motion classification. + min_run_len : int, default=1 + Minimum number of consecutive qualifying windows required to + form a valid run. + postural_var_ratio_thresh : float, default=2.0 + Between-segment to within-segment RMSD variance ratio above which + postural clustering is triggered. + max_clusters : int, default=4 + Upper bound on the number of clusters to evaluate during k-medoids. + confidence_floor : float, default=0.1 + Vote margin below which the anterior inference is flagged as + unreliable. + lateral_thresh : float, default=0.4 + Normalized lateral offset ceiling for the Step 1 lateral alignment + filter. + edge_thresh : float, default=0.3 + Normalized midpoint distance floor for the Step 3 distal/proximal + classification. + + """ + + min_valid_frac: float = 0.6 + window_len: int = 50 + stride: int = 5 + pct_thresh: float = 85.0 + min_run_len: int = 1 + postural_var_ratio_thresh: float = 2.0 + max_clusters: int = 4 + confidence_floor: float = 0.1 + lateral_thresh: float = 0.4 + edge_thresh: float = 0.3 + + def __post_init__(self) -> None: + """Validate configuration parameters.""" + for name in ( + "min_valid_frac", + "confidence_floor", + "lateral_thresh", + "edge_thresh", + ): + value = getattr(self, name) + if not (0 <= value <= 1): + raise ValueError( + f"{name} must be between 0 and 1, got {value}" + ) + + for name in ("window_len", "stride", "min_run_len", "max_clusters"): + value = getattr(self, name) + if not isinstance(value, int) or value <= 0: + raise ValueError( + f"{name} must be a positive integer, got {value}" + ) + + if not (0 <= self.pct_thresh <= 100): + raise ValueError( + f"pct_thresh must be between 0 and 100, got {self.pct_thresh}" + ) + + if self.postural_var_ratio_thresh <= 0: + raise ValueError( + f"postural_var_ratio_thresh must be positive, " + f"got {self.postural_var_ratio_thresh}" + ) + + +@dataclass +class FrameSelection: + """Selected frames from high-motion segmentation and tier-2 filtering. + + Bundles the frame indices, segment assignments, and related arrays + produced by the segmentation pipeline for downstream consumption + (skeleton construction, postural clustering, velocity recomputation). + + Attributes + ---------- + frames : np.ndarray + Array of selected frame indices (tier-2 valid, within segments). + seg_ids : np.ndarray + Segment ID (0-indexed) for each selected frame. + segments : np.ndarray + Array of shape (n_segments, 2) with [frame_start, frame_end]. + bbox_centroids : np.ndarray + Array of shape (n_frames, 2) with bounding-box centroids. + count : int + Number of selected frames. + + """ + + frames: np.ndarray + seg_ids: np.ndarray + segments: np.ndarray + bbox_centroids: np.ndarray + count: int + + +@dataclass +class APNodePairReport: + """Report from the AP node-pair evaluation pipeline. + + This dataclass holds all results from the 3-step filter cascade + used to evaluate a candidate anterior-posterior keypoint pair. + + Attributes + ---------- + success : bool + Whether the evaluation pipeline completed successfully. + failure_step : str + Name of the step at which evaluation failed, if any. + failure_reason : str + Reason for failure, if any. + scenario : int + Scenario number (1-13) from the mutually exclusive outcomes. + outcome : str + Either "accept" or "warn". + warning_message : str + Warning message, if applicable. + sorted_candidate_nodes : np.ndarray + Indices of candidate nodes after Step 1 filtering, sorted by + ascending normalized lateral offset. + valid_pairs : np.ndarray + Array of shape (n_pairs, 2) containing valid node pairs after + Step 2 filtering. + valid_pairs_internode_dist : np.ndarray + Internode separation (AP distance) for each valid pair. + input_pair_in_candidates : bool + Whether the input pair survived Step 1 filtering. + input_pair_opposite_sides : bool + Whether the input pair lies on opposite sides of the midpoint. + input_pair_separation_abs : float + Absolute AP separation of the input pair. + input_pair_is_distal : bool + Whether the input pair is classified as distal in Step 3. + input_pair_rank : int + Rank of the input pair by internode separation (1 = largest). + input_pair_order_matches_inference : bool + Whether from_node has a lower AP coordinate than to_node + (i.e. from_node is more posterior). True means the input pair + ordering is consistent with the inferred AP axis. + pc1_coords : np.ndarray + PC1 coordinates for each keypoint. + ap_coords : np.ndarray + AP (anterior-posterior) coordinates for each keypoint. + lateral_offsets : np.ndarray + Unsigned lateral offset from body axis for each keypoint. + lateral_offsets_norm : np.ndarray + Normalized lateral offsets (0 = nearest to axis, 1 = farthest). + lateral_offset_min : float + Minimum lateral offset among valid keypoints. + lateral_offset_max : float + Maximum lateral offset among valid keypoints. + midpoint_pc1 : float + AP reference midpoint (average of min and max PC1 projections). + pc1_min : float + Minimum PC1 projection among valid keypoints. + pc1_max : float + Maximum PC1 projection among valid keypoints. + midline_dist_norm : np.ndarray + Normalized distance from midpoint for each keypoint. + midline_dist_max : float + Maximum absolute distance from midpoint. + distal_pairs : np.ndarray + Array of distal pairs (both nodes at or above edge_thresh). + proximal_pairs : np.ndarray + Array of proximal pairs (at least one node below edge_thresh). + max_separation_distal_nodes : np.ndarray + Node indices of the maximum-separation distal pair, ordered + so that element 0 is posterior (lower AP coord) and element 1 + is anterior (higher AP coord). + max_separation_distal : float + Internode separation of the max-separation distal pair. + max_separation_nodes : np.ndarray + Node indices of the overall maximum-separation pair, ordered + so that element 0 is posterior (lower AP coord) and element 1 + is anterior (higher AP coord). + max_separation : float + Internode separation of the overall max-separation pair. + + """ + + success: bool = False + failure_step: str = "" + failure_reason: str = "" + scenario: int = 0 + outcome: str = "" + warning_message: str = "" + + sorted_candidate_nodes: np.ndarray = field( + default_factory=lambda: np.array([], dtype=int) + ) + valid_pairs: np.ndarray = field( + default_factory=lambda: np.zeros((0, 2), dtype=int) + ) + valid_pairs_internode_dist: np.ndarray = field( + default_factory=lambda: np.array([]) + ) + + input_pair_in_candidates: bool = False + input_pair_opposite_sides: bool = False + input_pair_separation_abs: float = np.nan + input_pair_is_distal: bool = False + input_pair_rank: int = 0 + input_pair_order_matches_inference: bool = False + + pc1_coords: np.ndarray = field(default_factory=lambda: np.array([])) + ap_coords: np.ndarray = field(default_factory=lambda: np.array([])) + lateral_offsets: np.ndarray = field(default_factory=lambda: np.array([])) + lateral_offsets_norm: np.ndarray = field( + default_factory=lambda: np.array([]) + ) + lateral_offset_min: float = np.nan + lateral_offset_max: float = np.nan + midpoint_pc1: float = np.nan + pc1_min: float = np.nan + pc1_max: float = np.nan + midline_dist_norm: np.ndarray = field(default_factory=lambda: np.array([])) + midline_dist_max: float = np.nan + + distal_pairs: np.ndarray = field( + default_factory=lambda: np.zeros((0, 2), dtype=int) + ) + proximal_pairs: np.ndarray = field( + default_factory=lambda: np.zeros((0, 2), dtype=int) + ) + max_separation_distal_nodes: np.ndarray = field( + default_factory=lambda: np.array([], dtype=int) + ) + max_separation_distal: float = np.nan + max_separation_nodes: np.ndarray = field( + default_factory=lambda: np.array([], dtype=int) + ) + max_separation: float = np.nan + + +# Tiered Validity and Centroid Computation +# ───────────────────────────────────────── + + +def compute_tiered_validity( + keypoints: np.ndarray, + min_valid_frac: float, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute tiered validity masks for each frame. + + Parameters + ---------- + keypoints : np.ndarray + Keypoint positions with shape (n_frames, n_keypoints, 2). + min_valid_frac : float + Minimum fraction of keypoints required for tier-1 validity. + + Returns + ------- + tier1_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-1 valid frames. + A frame is tier-1 valid if at least min_valid_frac of keypoints + are present AND at least 2 keypoints are present. + tier2_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-2 valid frames. + A frame is tier-2 valid if all keypoints are present. + frac_present : np.ndarray + Array of shape (n_frames,) with fraction of keypoints present. + + """ + _, n_keypoints, _ = keypoints.shape + + keypoint_present = ~np.any(np.isnan(keypoints), axis=2) + n_present = np.sum(keypoint_present, axis=1) + frac_present = n_present / n_keypoints + + tier2_valid = n_present == n_keypoints + tier1_valid = (frac_present >= min_valid_frac) & (n_present >= 2) + + return tier1_valid, tier2_valid, frac_present + + +def compute_bbox_centroid( + keypoints: np.ndarray, + tier1_valid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute bounding-box centroids for tier-1 valid frames. + + The bounding-box centroid is the midpoint of the axis-aligned bounding + box enclosing all present keypoints. This is density-invariant, unlike + the arithmetic mean. + + Parameters + ---------- + keypoints : np.ndarray + Keypoint positions with shape (n_frames, n_keypoints, 2). + tier1_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-1 valid frames. + + Returns + ------- + bbox_centroids : np.ndarray + Array of shape (n_frames, 2) with bounding-box centroids. + NaN for non-tier-1-valid frames. + arith_centroids : np.ndarray + Array of shape (n_frames, 2) with arithmetic-mean centroids. + NaN for non-tier-1-valid frames. Used for diagnostic comparison. + centroid_discrepancy : np.ndarray + Array of shape (n_frames,) with normalized discrepancy between + bbox and arithmetic centroids (distance / bbox_diagonal). + NaN for non-tier-1-valid frames. + + """ + n_frames = keypoints.shape[0] + + bbox_centroids = np.full((n_frames, 2), np.nan) + arith_centroids = np.full((n_frames, 2), np.nan) + centroid_discrepancy = np.full(n_frames, np.nan) + + for f in range(n_frames): + if not tier1_valid[f]: + continue + + kp_f = keypoints[f] + present_mask = ~np.any(np.isnan(kp_f), axis=1) + kp_present = kp_f[present_mask] + + bbox_min = np.min(kp_present, axis=0) + bbox_max = np.max(kp_present, axis=0) + bbox_centroids[f] = (bbox_min + bbox_max) / 2 + + arith_centroids[f] = np.mean(kp_present, axis=0) + + bbox_diag = np.linalg.norm(bbox_max - bbox_min) + if bbox_diag > 0: + discrepancy = np.linalg.norm( + bbox_centroids[f] - arith_centroids[f] + ) + centroid_discrepancy[f] = discrepancy / bbox_diag + else: + centroid_discrepancy[f] = 0.0 + + return bbox_centroids, arith_centroids, centroid_discrepancy + + +# Velocity and Motion Detection +# ────────────────────────────── + + +def compute_frame_velocities( + bbox_centroids: np.ndarray, + tier1_valid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Compute frame-to-frame centroid velocities and speeds. + + A velocity is valid only when both adjacent frames are tier-1 valid. + + Parameters + ---------- + bbox_centroids : np.ndarray + Array of shape (n_frames, 2) with bounding-box centroids. + tier1_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-1 valid frames. + + Returns + ------- + velocities : np.ndarray + Array of shape (n_frames - 1, 2) with velocity vectors. + Invalid velocities are NaN. + speeds : np.ndarray + Array of shape (n_frames - 1,) with speed scalars. + Invalid speeds are NaN. + + """ + velocities = np.diff(bbox_centroids, axis=0) + speed_valid = tier1_valid[:-1] & tier1_valid[1:] + velocities[~speed_valid] = np.nan + speeds = np.linalg.norm(velocities, axis=1) + + return velocities, speeds + + +def compute_sliding_window_medians( + speeds: np.ndarray, + window_len: int, + stride: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute median speeds for sliding windows. + + A window is valid only when every speed sample in that window is valid + (non-NaN), ensuring strict NaN-free content. + + Parameters + ---------- + speeds : np.ndarray + Array of shape (n_speed_samples,) with speed values. + window_len : int + Number of speed samples per sliding window. + stride : int + Step size between consecutive window start positions. + + Returns + ------- + window_starts : np.ndarray + Array of window start indices (0-indexed). + window_medians : np.ndarray + Median speed for each window. NaN for invalid windows. + window_all_valid : np.ndarray + Boolean array indicating which windows are fully valid. + + """ + num_speed = len(speeds) + window_starts = np.arange(0, num_speed - window_len + 1, stride) + num_windows = len(window_starts) + + window_medians = np.full(num_windows, np.nan) + window_all_valid = np.zeros(num_windows, dtype=bool) + + for k in range(num_windows): + s = window_starts[k] + e = s + window_len + w = speeds[s:e] + + if np.all(~np.isnan(w)): + window_all_valid[k] = True + window_medians[k] = np.median(w) + + return window_starts, window_medians, window_all_valid + + +def detect_high_motion_windows( + window_medians: np.ndarray, + window_all_valid: np.ndarray, + pct_thresh: float, +) -> np.ndarray: + """Identify high-motion windows based on percentile threshold. + + Parameters + ---------- + window_medians : np.ndarray + Median speed for each window. + window_all_valid : np.ndarray + Boolean array indicating which windows are fully valid. + pct_thresh : float + Percentile threshold (0-100) for high-motion classification. + + Returns + ------- + high_motion : np.ndarray + Boolean array indicating high-motion windows. + + """ + valid_medians = window_medians[window_all_valid] + if len(valid_medians) == 0: + return np.zeros(len(window_medians), dtype=bool) + + thresh = np.percentile(valid_medians, pct_thresh) + high_motion = window_all_valid & (window_medians >= thresh) + + return high_motion + + +# Run and Segment Detection +# ────────────────────────── + + +def detect_runs( + high_motion: np.ndarray, + min_run_len: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Detect runs of consecutive high-motion windows. + + A run is a maximal sequence of consecutively indexed qualifying windows. + + Parameters + ---------- + high_motion : np.ndarray + Boolean array indicating high-motion windows. + min_run_len : int + Minimum number of consecutive qualifying windows for a valid run. + + Returns + ------- + run_starts : np.ndarray + Start indices of valid runs. + run_ends : np.ndarray + End indices (inclusive) of valid runs. + run_lengths : np.ndarray + Length of each valid run. + + """ + padded = np.concatenate([[False], high_motion, [False]]) + d = np.diff(padded.astype(int)) + + run_starts_all = np.nonzero(d == 1)[0] + run_ends_all = np.nonzero(d == -1)[0] - 1 + run_lengths_all = run_ends_all - run_starts_all + 1 + + valid_mask = run_lengths_all >= min_run_len + run_starts = run_starts_all[valid_mask] + run_ends = run_ends_all[valid_mask] + run_lengths = run_lengths_all[valid_mask] + + return run_starts, run_ends, run_lengths + + +def convert_runs_to_segments( + run_starts: np.ndarray, + run_ends: np.ndarray, + window_starts: np.ndarray, + window_len: int, +) -> np.ndarray: + """Convert window runs to frame segments. + + Each run is converted to a frame interval spanning from the start frame + of the first window to the end frame of the last window. + + Parameters + ---------- + run_starts : np.ndarray + Start indices of valid runs (indices into window arrays). + run_ends : np.ndarray + End indices (inclusive) of valid runs. + window_starts : np.ndarray + Start frame indices for each window. + window_len : int + Length of each window in frames. + + Returns + ------- + segments_raw : np.ndarray + Array of shape (n_runs, 2) with [frame_start, frame_end] for each run. + + """ + n_runs = len(run_starts) + segments_raw = np.zeros((n_runs, 2), dtype=int) + + for j in range(n_runs): + s_idx = run_starts[j] + e_idx = run_ends[j] + frame_start = window_starts[s_idx] + frame_end = window_starts[e_idx] + window_len + segments_raw[j] = [frame_start, frame_end] + + return segments_raw + + +def merge_segments(segments_raw: np.ndarray) -> np.ndarray: + """Merge overlapping or abutting frame segments. + + Segments are first sorted by start frame, then merged if they overlap + or abut (next start <= current end + 1). + + Parameters + ---------- + segments_raw : np.ndarray + Array of shape (n_segments, 2) with [frame_start, frame_end]. + + Returns + ------- + segments : np.ndarray + Array of merged non-overlapping segments. + + """ + if len(segments_raw) == 0: + return segments_raw + + sorted_idx = np.argsort(segments_raw[:, 0]) + segments_sorted = segments_raw[sorted_idx] + + merged = [segments_sorted[0].tolist()] + + for j in range(1, len(segments_sorted)): + next_seg = segments_sorted[j] + curr_seg = merged[-1] + + if next_seg[0] <= curr_seg[1] + 1: + merged[-1][1] = max(curr_seg[1], next_seg[1]) + else: + merged.append(next_seg.tolist()) + + return np.array(merged, dtype=int) + + +def filter_segments_tier2( + segments: np.ndarray, + tier2_valid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Filter segment frames to retain only tier-2 valid frames. + + Parameters + ---------- + segments : np.ndarray + Array of shape (n_segments, 2) with [frame_start, frame_end]. + tier2_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-2 valid frames. + + Returns + ------- + selected_frames : np.ndarray + Array of tier-2 valid frame indices within segments. + selected_seg_id : np.ndarray + Segment ID (0-indexed) for each selected frame. + + """ + all_segment_frames: list[int] = [] + for k in range(len(segments)): + frame_start, frame_end = segments[k] + seg_frames = np.arange(frame_start, frame_end + 1) + all_segment_frames.extend(seg_frames) + + segment_frames_all = np.unique(all_segment_frames) + + tier2_mask = tier2_valid[segment_frames_all] + selected_frames = segment_frames_all[tier2_mask] + + num_selected = len(selected_frames) + selected_seg_id = np.zeros(num_selected, dtype=int) + + for j in range(num_selected): + f = selected_frames[j] + for k in range(len(segments)): + if segments[k, 0] <= f <= segments[k, 1]: + selected_seg_id[j] = k + break + + return selected_frames, selected_seg_id + + +# Skeleton Analysis +# ────────────────── + + +def build_centered_skeletons( + keypoints: np.ndarray, + selected_frames: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Build centroid-centered skeletons for selected frames. + + Uses bounding-box centroid for centering, consistent with the + segmentation centroid. + + Parameters + ---------- + keypoints : np.ndarray + Keypoint positions with shape (n_frames, n_keypoints, 2). + selected_frames : np.ndarray + Array of selected frame indices. + + Returns + ------- + selected_centroids : np.ndarray + Array of shape (num_selected, 2) with bounding-box centroids. + centered_skeletons : np.ndarray + Array of shape (num_selected, n_keypoints, 2) with + centroid-centered skeleton coordinates. + + """ + num_selected = len(selected_frames) + n_keypoints = keypoints.shape[1] + + selected_centroids = np.zeros((num_selected, 2)) + centered_skeletons = np.zeros((num_selected, n_keypoints, 2)) + + for j in range(num_selected): + f = selected_frames[j] + kp_f = keypoints[f] + + bbox_min = np.min(kp_f, axis=0) + bbox_max = np.max(kp_f, axis=0) + centroid_f = (bbox_min + bbox_max) / 2 + + selected_centroids[j] = centroid_f + centered_skeletons[j] = kp_f - centroid_f + + return selected_centroids, centered_skeletons + + +def compute_pairwise_rmsd(centered_skeletons: np.ndarray) -> np.ndarray: + """Compute pairwise RMSD between all centered skeletons. + + RMSD is computed as the square root of the mean of squared entry-wise + differences between flattened skeleton vectors. + + Parameters + ---------- + centered_skeletons : np.ndarray + Array of shape (num_selected, n_keypoints, 2). + + Returns + ------- + rmsd_matrix : np.ndarray + Symmetric matrix of shape (num_selected, num_selected) with + pairwise RMSD values. Diagonal is zero. + + """ + num_selected = len(centered_skeletons) + skel_flat = centered_skeletons.reshape(num_selected, -1) + rmsd_matrix = np.zeros((num_selected, num_selected)) + + for i in range(num_selected): + for j in range(i + 1, num_selected): + d = skel_flat[i] - skel_flat[j] + rmsd_val = np.sqrt(np.mean(d**2)) + rmsd_matrix[i, j] = rmsd_val + rmsd_matrix[j, i] = rmsd_val + + return rmsd_matrix + + +def compute_postural_variance_ratio( + rmsd_matrix: np.ndarray, + selected_seg_id: np.ndarray, +) -> tuple[float, np.ndarray, np.ndarray, bool]: + """Compute the between/within segment RMSD variance ratio. + + Parameters + ---------- + rmsd_matrix : np.ndarray + Pairwise RMSD matrix of shape (num_selected, num_selected). + selected_seg_id : np.ndarray + Segment ID for each selected frame. + + Returns + ------- + var_ratio : float + Ratio of between-segment to within-segment RMSD variance. + Returns 0.0 if either distribution is empty or within variance is 0. + within_rmsds : np.ndarray + Array of within-segment RMSD values. + between_rmsds : np.ndarray + Array of between-segment RMSD values. + var_ratio_override : bool + True if variance ratio was set to 0 due to edge cases. + + """ + num_selected = len(selected_seg_id) + within_rmsds_list: list[float] = [] + between_rmsds_list: list[float] = [] + + for i in range(num_selected): + for j in range(i + 1, num_selected): + if selected_seg_id[i] == selected_seg_id[j]: + within_rmsds_list.append(rmsd_matrix[i, j]) + else: + between_rmsds_list.append(rmsd_matrix[i, j]) + + within_rmsds = np.array(within_rmsds_list) + between_rmsds = np.array(between_rmsds_list) + + var_ratio_override = False + if ( + len(within_rmsds) > 0 + and len(between_rmsds) > 0 + and np.var(within_rmsds) > 0 + ): + var_ratio = np.var(between_rmsds) / np.var(within_rmsds) + else: + var_ratio = 0.0 + var_ratio_override = True + + return var_ratio, within_rmsds, between_rmsds, var_ratio_override + + +# Clustering (k-medoids) +# ─────────────────────── + + +def _update_medoid_for_cluster( + cluster: int, + labels: np.ndarray, + medoids: np.ndarray, + dist_matrix: np.ndarray, +) -> int: + """Find the optimal medoid for a single cluster.""" + cluster_mask = labels == cluster + if not np.any(cluster_mask): + return medoids[cluster] + + cluster_indices = np.nonzero(cluster_mask)[0] + cluster_dists = dist_matrix[np.ix_(cluster_indices, cluster_indices)] + total_dists = np.sum(cluster_dists, axis=1) + best_idx = np.argmin(total_dists) + return cluster_indices[best_idx] + + +def kmedoids( + data: np.ndarray, + k: int, + max_iter: int = 100, + n_init: int = 5, + random_state: int | None = None, +) -> tuple[np.ndarray, np.ndarray, float]: + """Perform k-medoids clustering. + + Parameters + ---------- + data : np.ndarray + Array of shape (n_samples, n_features). + k : int + Number of clusters. + max_iter : int, default=100 + Maximum number of iterations. + n_init : int, default=5 + Number of random initializations. + random_state : int, optional + Random seed for reproducibility. + + Returns + ------- + labels : np.ndarray + Cluster labels for each sample (0-indexed). + medoid_indices : np.ndarray + Indices of medoid samples. + inertia : float + Sum of distances from samples to their medoids. + + """ + from scipy.spatial.distance import cdist + + rng = np.random.default_rng(random_state) + n_samples = len(data) + + dist_matrix = cdist(data, data, metric="euclidean") + + best_labels: np.ndarray | None = None + best_medoids: np.ndarray | None = None + best_inertia = np.inf + + for _ in range(n_init): + medoids = rng.choice(n_samples, size=k, replace=False) + + for _ in range(max_iter): + distances_to_medoids = dist_matrix[:, medoids] + labels = np.argmin(distances_to_medoids, axis=1) + + new_medoids = np.array( + [ + _update_medoid_for_cluster(c, labels, medoids, dist_matrix) + for c in range(k) + ] + ) + + if np.array_equal(np.sort(medoids), np.sort(new_medoids)): + break + medoids = new_medoids + + distances_to_medoids = dist_matrix[:, medoids] + labels = np.argmin(distances_to_medoids, axis=1) + inertia = np.sum(distances_to_medoids[np.arange(n_samples), labels]) + + if inertia < best_inertia: + best_inertia = inertia + best_labels = labels.copy() + best_medoids = medoids.copy() + + assert best_labels is not None and best_medoids is not None + return best_labels, best_medoids, best_inertia + + +def _compute_intra_cluster_dist( + i: int, + labels: np.ndarray, + dist_matrix: np.ndarray, + n_samples: int, +) -> float: + """Compute mean distance from sample i to other samples in its cluster.""" + own_cluster = labels[i] + own_mask = labels == own_cluster + if np.sum(own_mask) > 1: + return np.mean(dist_matrix[i, own_mask & (np.arange(n_samples) != i)]) + return 0.0 + + +def _compute_nearest_cluster_dist( + i: int, + labels: np.ndarray, + dist_matrix: np.ndarray, + unique_labels: np.ndarray, +) -> float: + """Compute mean distance to nearest other cluster.""" + own_cluster = labels[i] + b_i = np.inf + for cluster in unique_labels: + if cluster == own_cluster: + continue + cluster_mask = labels == cluster + if np.any(cluster_mask): + mean_dist = np.mean(dist_matrix[i, cluster_mask]) + b_i = min(b_i, mean_dist) + return b_i + + +def silhouette_score(data: np.ndarray, labels: np.ndarray) -> float: + """Compute mean silhouette score. + + Parameters + ---------- + data : np.ndarray + Array of shape (n_samples, n_features). + labels : np.ndarray + Cluster labels for each sample. + + Returns + ------- + score : float + Mean silhouette score across all samples. + Returns 0.0 if clustering is degenerate. + + """ + from scipy.spatial.distance import cdist + + n_samples = len(data) + unique_labels = np.unique(labels) + n_clusters = len(unique_labels) + + if n_clusters <= 1 or n_clusters >= n_samples: + return 0.0 + + dist_matrix = cdist(data, data, metric="euclidean") + silhouette_vals = np.zeros(n_samples) + + for i in range(n_samples): + a_i = _compute_intra_cluster_dist(i, labels, dist_matrix, n_samples) + b_i = _compute_nearest_cluster_dist( + i, labels, dist_matrix, unique_labels + ) + + if b_i == np.inf: + silhouette_vals[i] = 0.0 + elif max(a_i, b_i) > 0: + silhouette_vals[i] = (b_i - a_i) / max(a_i, b_i) + else: + silhouette_vals[i] = 0.0 + + return float(np.mean(silhouette_vals)) + + +def perform_postural_clustering( + centered_skeletons: np.ndarray, + max_clusters: int, + min_silhouette: float = 0.2, +) -> tuple[np.ndarray, int, int, float, list[tuple[int, float]]]: + """Perform postural clustering using k-medoids with silhouette selection. + + Parameters + ---------- + centered_skeletons : np.ndarray + Array of shape (num_selected, n_keypoints, 2). + max_clusters : int + Maximum number of clusters to evaluate. + min_silhouette : float, default=0.2 + Minimum silhouette score to accept clustering. + + Returns + ------- + cluster_labels : np.ndarray + Cluster labels for each frame (0-indexed). + num_clusters : int + Number of clusters (1 if clustering not accepted). + primary_cluster : int + Index of largest cluster (0-indexed). + best_silhouette : float + Best silhouette score achieved. + silhouette_scores : list of (k, score) + Silhouette scores for each k evaluated. + + """ + num_selected = len(centered_skeletons) + skel_flat = centered_skeletons.reshape(num_selected, -1) + + best_k = 1 + best_sil = -np.inf + silhouette_scores = [] + + max_k = min(max_clusters, num_selected // 2) + + for k in range(2, max_k + 1): + try: + labels, _, _ = kmedoids(skel_flat, k, n_init=5) + sil = silhouette_score(skel_flat, labels) + silhouette_scores.append((k, sil)) + + if sil > best_sil: + best_sil = sil + best_k = k + except Exception: + silhouette_scores.append((k, np.nan)) + + if best_k > 1 and best_sil > min_silhouette: + cluster_labels, _, _ = kmedoids(skel_flat, best_k, n_init=10) + num_clusters = best_k + + cluster_counts = np.bincount(cluster_labels, minlength=num_clusters) + primary_cluster = int(np.argmax(cluster_counts)) + else: + cluster_labels = np.zeros(num_selected, dtype=int) + num_clusters = 1 + primary_cluster = 0 + + return ( + cluster_labels, + num_clusters, + primary_cluster, + best_sil, + silhouette_scores, + ) + + +# PCA and Anterior Inference +# ─────────────────────────── + + +def compute_cluster_velocities( + selected_frames: np.ndarray, + selected_seg_id: np.ndarray, + cluster_mask: np.ndarray, + segments: np.ndarray, + bbox_centroids: np.ndarray, +) -> np.ndarray: + """Compute velocities between adjacent consecutive frames. + + Only considers frames in the same segment and cluster. Frame pairs + where both frames are consecutive (frame[i] == frame[i-1] + 1), + in the same segment, and in the same cluster contribute a velocity + vector. + + Returns + ------- + np.ndarray + Array of shape (n_velocities, 2). Empty (0, 2) if no valid pairs. + + """ + frames_c = selected_frames[cluster_mask] + seg_ids_c = selected_seg_id[cluster_mask] + velocities_list: list[np.ndarray] = [] + + for seg_k in range(len(segments)): + seg_mask = seg_ids_c == seg_k + seg_frames = np.sort(frames_c[seg_mask]) + for fi in range(1, len(seg_frames)): + if seg_frames[fi] != seg_frames[fi - 1] + 1: + continue + curr_frame = seg_frames[fi] + prev_frame = seg_frames[fi - 1] + v = bbox_centroids[curr_frame] - bbox_centroids[prev_frame] + if np.all(~np.isnan(v)): + velocities_list.append(v) + + return np.array(velocities_list) if velocities_list else np.zeros((0, 2)) + + +def infer_anterior_from_velocities( + velocities: np.ndarray, + pc1: np.ndarray, +) -> dict: + """Infer anterior direction from velocity projections onto PC1. + + Uses strict majority vote on PC1 projection signs: anterior = +PC1 + if n_positive > n_negative, else -PC1 (ties default to -PC1). + + Also computes circular statistics on velocity angles: + - resultant_length R = sqrt(C^2 + S^2) where C = mean(cos theta), + S = mean(sin theta) + - vote_margin M = |n+ - n-| / (n+ + n-) + + Returns dict with resultant_length, circ_mean_dir, vel_projs_pc1, + num_positive, num_negative, vote_margin, anterior_sign. + + """ + result: dict = { + "resultant_length": 0.0, + "circ_mean_dir": np.nan, + "vel_projs_pc1": np.array([]), + "num_positive": 0, + "num_negative": 0, + "vote_margin": 0.0, + "anterior_sign": -1, + } + if len(velocities) == 0: + return result + + vel_angles = np.arctan2(velocities[:, 1], velocities[:, 0]) + cos_mean = np.mean(np.cos(vel_angles)) + sin_mean = np.mean(np.sin(vel_angles)) + result["resultant_length"] = np.sqrt(cos_mean**2 + sin_mean**2) + result["circ_mean_dir"] = np.arctan2(sin_mean, cos_mean) + + vel_projs = velocities @ pc1 + num_pos = int(np.sum(vel_projs > 0)) + num_neg = int(np.sum(vel_projs < 0)) + result["vel_projs_pc1"] = vel_projs + result["num_positive"] = num_pos + result["num_negative"] = num_neg + result["vote_margin"] = abs(num_pos - num_neg) / max(num_pos + num_neg, 1) + result["anterior_sign"] = +1 if num_pos > num_neg else -1 + return result + + +def compute_cluster_pca_and_anterior( + centered_skeletons: np.ndarray, + cluster_mask: np.ndarray, + selected_frames: np.ndarray, + selected_seg_id: np.ndarray, + segments: np.ndarray, + bbox_centroids: np.ndarray, +) -> dict: + """Compute SVD-based PCA and velocity-based anterior inference. + + Performs inference for one cluster. + + Performs SVD on the cluster's average centered skeleton to extract PC1/PC2, + applies the geometric sign convention, then infers the anterior direction + via velocity voting on centroid displacements projected onto PC1. + + Returns + ------- + dict + Keys: valid, n_frames, avg_skeleton, valid_shape_rows, + PC1, PC2, anterior_sign, vote_margin, resultant_length, + circ_mean_dir, velocities, vel_projs_pc1, and others. + + """ + n_keypoints = centered_skeletons.shape[1] + n_c = int(np.sum(cluster_mask)) + + result: dict = { + "valid": False, + "n_frames": n_c, + "avg_skeleton": np.full((n_keypoints, 2), np.nan), + "valid_shape_rows": np.zeros(n_keypoints, dtype=bool), + "PC1": np.array([1.0, 0.0]), + "PC2": np.array([0.0, 1.0]), + "proj_pc1": np.full(n_keypoints, np.nan), + "proj_pc2": np.full(n_keypoints, np.nan), + "anterior_sign": -1, + "num_positive": 0, + "num_negative": 0, + "vote_margin": 0.0, + "resultant_length": 0.0, + "circ_mean_dir": np.nan, + "velocities": np.zeros((0, 2)), + "vel_projs_pc1": np.array([]), + } + + if n_c == 0: + return result + + skels_c = centered_skeletons[cluster_mask] + avg_skel_c = np.mean(skels_c, axis=0) + valid_shape_rows = ~np.any(np.isnan(avg_skel_c), axis=1) + + if np.sum(valid_shape_rows) < 2: + return result + + result["avg_skeleton"] = avg_skel_c + result["valid_shape_rows"] = valid_shape_rows + + valid_rows = avg_skel_c[valid_shape_rows] + _u, _s, vt = np.linalg.svd(valid_rows, full_matrices=False) + PC1 = vt[0] + PC2 = vt[1] if len(vt) > 1 else np.array([0.0, 1.0]) + + # Geometric sign convention: + # PC1 flipped so y-component >= 0 + # PC2 flipped so x-component >= 0 + if PC1[1] < 0: + PC1 = -PC1 + if PC2[0] < 0: + PC2 = -PC2 + + result["PC1"] = PC1 + result["PC2"] = PC2 + + proj_pc1 = np.full(n_keypoints, np.nan) + proj_pc2 = np.full(n_keypoints, np.nan) + proj_pc1[valid_shape_rows] = avg_skel_c[valid_shape_rows] @ PC1 + proj_pc2[valid_shape_rows] = avg_skel_c[valid_shape_rows] @ PC2 + result["proj_pc1"] = proj_pc1 + result["proj_pc2"] = proj_pc2 + + velocities = compute_cluster_velocities( + selected_frames, + selected_seg_id, + cluster_mask, + segments, + bbox_centroids, + ) + result["velocities"] = velocities + result.update(infer_anterior_from_velocities(velocities, PC1)) + result["valid"] = True + return result + + +# AP Node-Pair Evaluation (3-Step Filter Cascade) +# ──────────────────────────────────────────────── + + +def compute_node_projections( + report: APNodePairReport, + avg_skeleton: np.ndarray, + pc1_vec: np.ndarray, + anterior_sign: int, + valid_shape_rows: np.ndarray, + from_node: int, + to_node: int, +) -> None: + """Compute raw PC1, AP-oriented, and lateral projections. + + Computes projections for all valid keypoints. + + Populates the report's coordinate arrays and determines: + - pc1_coords: raw projection onto PC1 (sign-convention only) + - ap_coords: projection onto anterior_sign * PC1 (positive = more + anterior) + - lateral_offsets: unsigned distance from the AP axis + - midpoint_pc1: average of min/max PC1 projections (AP reference point) + - input_pair_order_matches_inference: True if from_node's AP coord < + to_node's + + """ + pc1 = pc1_vec / np.linalg.norm(pc1_vec) + e_ap = anterior_sign * pc1 + e_lat = np.array([-e_ap[1], e_ap[0]]) + + valid_rows = avg_skeleton[valid_shape_rows] + report.pc1_coords[valid_shape_rows] = valid_rows @ pc1 + report.ap_coords[valid_shape_rows] = valid_rows @ e_ap + report.lateral_offsets[valid_shape_rows] = np.abs(valid_rows @ e_lat) + + if valid_shape_rows[from_node] and valid_shape_rows[to_node]: + report.input_pair_order_matches_inference = ( + report.ap_coords[from_node] < report.ap_coords[to_node] + ) + + proj_pc1_valid = report.pc1_coords[valid_shape_rows] + report.pc1_min = float(np.min(proj_pc1_valid)) + report.pc1_max = float(np.max(proj_pc1_valid)) + report.midpoint_pc1 = (report.pc1_min + report.pc1_max) / 2 + + +def apply_lateral_filter( + report: APNodePairReport, + valid_idx: np.ndarray, + lateral_thresh: float, +) -> np.ndarray | None: + """Step 1: Filter keypoints by normalized lateral offset. + + Returns sorted candidate node indices, or None on failure. + + """ + d_valid = report.lateral_offsets[valid_idx] + d_min = float(np.min(d_valid)) + d_max = float(np.max(d_valid)) + report.lateral_offset_min = d_min + report.lateral_offset_max = d_max + + if d_max > d_min: + d_norm = (d_valid - d_min) / (d_max - d_min) + report.lateral_offsets_norm[valid_idx] = d_norm + keep_mask = d_norm <= lateral_thresh + else: + report.lateral_offsets_norm[valid_idx] = np.zeros(len(d_valid)) + keep_mask = np.ones(len(d_valid), dtype=bool) + + candidate_idx = np.nonzero(keep_mask)[0] + candidates = valid_idx[candidate_idx] + sorted_order = np.argsort(d_valid[candidate_idx]) + candidates = candidates[sorted_order] + report.sorted_candidate_nodes = candidates.copy() + + if len(candidates) < 2: + report.failure_step = "Step 1: lateral alignment filter" + report.failure_reason = ( + "Fewer than 2 candidates remained after filtering." + ) + return None + return candidates + + +def find_opposite_side_pairs( + report: APNodePairReport, + candidates: np.ndarray, + from_node: int, + to_node: int, + valid_shape_rows: np.ndarray, +) -> tuple[np.ndarray, np.ndarray] | None: + """Step 2: Find candidate pairs on opposite sides of the AP midpoint. + + Returns (pairs, seps) arrays, or None on failure. + + """ + m = report.midpoint_pc1 + report.input_pair_in_candidates = (from_node in candidates) and ( + to_node in candidates + ) + + pairs_list: list[list[int]] = [] + seps_list: list[float] = [] + for ii in range(len(candidates)): + for jj in range(ii + 1, len(candidates)): + i, j = candidates[ii], candidates[jj] + if (report.pc1_coords[i] - m) * (report.pc1_coords[j] - m) < 0: + pairs_list.append([i, j]) + seps_list.append( + abs(report.ap_coords[i] - report.ap_coords[j]) + ) + + pairs = ( + np.array(pairs_list, dtype=int) + if pairs_list + else np.zeros((0, 2), dtype=int) + ) + seps = np.array(seps_list) if seps_list else np.array([]) + report.valid_pairs = pairs + report.valid_pairs_internode_dist = seps + + if valid_shape_rows[from_node] and valid_shape_rows[to_node]: + report.input_pair_opposite_sides = ( + (report.pc1_coords[from_node] - m) + * (report.pc1_coords[to_node] - m) + ) < 0 + report.input_pair_separation_abs = abs( + report.ap_coords[from_node] - report.ap_coords[to_node] + ) + + if len(pairs) == 0: + report.failure_step = "Step 2: opposite-sides constraint" + report.failure_reason = ( + "No candidate pair lies on opposite sides of the midpoint." + ) + return None + return pairs, seps + + +def order_pair_by_ap( + pair: np.ndarray, + ap_coords: np.ndarray, +) -> np.ndarray: + """Order a node pair so element 0 is posterior (lower AP coord). + + This ensures that suggested pairs always encode the + posterior->anterior direction, matching the convention used by + ``body_axis_keypoints=(from_node, to_node)`` where from_node is + posterior and to_node is anterior. + + """ + i, j = pair + if ap_coords[i] <= ap_coords[j]: + return np.array([i, j], dtype=int) + return np.array([j, i], dtype=int) + + +def classify_distal_proximal( + report: APNodePairReport, + pairs: np.ndarray, + seps: np.ndarray, + valid_shape_rows: np.ndarray, + edge_thresh: float, +) -> np.ndarray: + """Step 3: Classify pairs as distal or proximal. Returns pair_is_distal.""" + m = report.midpoint_pc1 + midline_dist = np.abs(report.pc1_coords - m) + d_max_midline = float(np.nanmax(midline_dist[valid_shape_rows])) + report.midline_dist_max = d_max_midline + + if d_max_midline > 0: + report.midline_dist_norm = midline_dist / d_max_midline + else: + report.midline_dist_norm = np.zeros(len(report.pc1_coords)) + + pair_is_distal = np.zeros(len(pairs), dtype=bool) + for k in range(len(pairs)): + i, j = pairs[k] + pair_is_distal[k] = ( + min(report.midline_dist_norm[i], report.midline_dist_norm[j]) + >= edge_thresh + ) + + report.distal_pairs = pairs[pair_is_distal] + report.proximal_pairs = pairs[~pair_is_distal] + + if len(seps) > 0: + idx_max = int(np.argmax(seps)) + report.max_separation_nodes = order_pair_by_ap( + pairs[idx_max], report.ap_coords + ) + report.max_separation = seps[idx_max] + + if np.any(pair_is_distal): + distal_seps = seps[pair_is_distal] + distal_pairs_only = pairs[pair_is_distal] + idx_max_distal = int(np.argmax(distal_seps)) + report.max_separation_distal_nodes = order_pair_by_ap( + distal_pairs_only[idx_max_distal], report.ap_coords + ) + report.max_separation_distal = distal_seps[idx_max_distal] + + return pair_is_distal + + +def check_input_pair_in_valid( + report: APNodePairReport, + pairs: np.ndarray, + seps: np.ndarray, + pair_is_distal: np.ndarray, + from_node: int, + to_node: int, +) -> tuple[bool, int]: + """Check whether input pair is among valid pairs. Returns (found, idx).""" + input_pair_sorted = tuple(sorted([from_node, to_node])) + input_in_valid = False + input_idx = -1 + + for k in range(len(pairs)): + if tuple(sorted(pairs[k])) == input_pair_sorted: + input_in_valid = True + input_idx = k + break + + if input_in_valid: + report.input_pair_is_distal = pair_is_distal[input_idx] + rank_order = np.argsort(seps)[::-1] + report.input_pair_rank = ( + int(np.nonzero(rank_order == input_idx)[0][0]) + 1 + ) + return input_in_valid, input_idx + + +# Scenario Assignment +# ──────────────────── + + +def assign_single_pair_scenario( + report: APNodePairReport, + pairs: np.ndarray, + pair_is_distal: np.ndarray, + input_in_valid: bool, +) -> APNodePairReport: + """Assign scenario when exactly one valid pair exists (scenarios 1-4).""" + if input_in_valid: + if pair_is_distal[0]: + report.scenario = 1 + report.outcome = "accept" + else: + report.scenario = 2 + report.outcome = "warn" + report.warning_message = "Input pair has proximal node(s)." + elif pair_is_distal[0]: + report.scenario = 3 + report.outcome = "warn" + report.warning_message = ( + f"Input invalid. Suggest pair [{pairs[0, 0]}, {pairs[0, 1]}]." + ) + else: + report.scenario = 4 + report.outcome = "warn" + report.warning_message = ( + f"Input invalid. Only option " + f"[{pairs[0, 0]}, {pairs[0, 1]}] has proximal node(s)." + ) + return report + + +def assign_multi_input_distal_scenario( + report: APNodePairReport, + pairs: np.ndarray, + input_idx: int, +) -> APNodePairReport: + """Assign scenario for distal input in multi-pair case (5, 6, 7).""" + input_pair_sorted = tuple( + sorted([pairs[input_idx, 0], pairs[input_idx, 1]]) + ) + max_distal_sorted = ( + tuple(sorted(report.max_separation_distal_nodes)) + if len(report.max_separation_distal_nodes) > 0 + else () + ) + + if report.input_pair_rank == 1: + report.scenario = 5 + report.outcome = "accept" + elif input_pair_sorted == max_distal_sorted: + report.scenario = 7 + report.outcome = "accept" + else: + report.scenario = 6 + report.outcome = "warn" + d = report.max_separation_distal_nodes + report.warning_message = ( + f"Distal pair with greater separation exists: [{d[0]}, {d[1]}]." + ) + return report + + +def assign_multi_input_proximal_scenario( + report: APNodePairReport, + pair_is_distal: np.ndarray, +) -> APNodePairReport: + """Assign scenario for proximal input in multi-pair case (8-11).""" + has_distal = np.any(pair_is_distal) + is_max_sep = report.input_pair_rank == 1 + + if is_max_sep and has_distal: + report.scenario = 8 + d = report.max_separation_distal_nodes + report.warning_message = ( + f"Input has proximal node(s). " + f"Distal alternative: [{d[0]}, {d[1]}]." + ) + elif is_max_sep: + report.scenario = 9 + report.warning_message = ( + "Input has proximal node(s). All pairs have proximal node(s)." + ) + elif has_distal: + report.scenario = 10 + d = report.max_separation_distal_nodes + report.warning_message = ( + f"Input has proximal node(s). " + f"Distal pair with greater separation: [{d[0]}, {d[1]}]." + ) + else: + report.scenario = 11 + report.warning_message = ( + "Input has proximal node(s). All pairs have proximal node(s)." + ) + + report.outcome = "warn" + return report + + +def assign_multi_input_invalid_scenario( + report: APNodePairReport, + pair_is_distal: np.ndarray, +) -> APNodePairReport: + """Assign scenario when input not in valid pairs (12-13).""" + has_distal = np.any(pair_is_distal) + report.outcome = "warn" + + if has_distal: + report.scenario = 12 + d = report.max_separation_distal_nodes + report.warning_message = ( + f"Input invalid. Suggest max separation distal pair: " + f"[{d[0]}, {d[1]}]." + ) + else: + report.scenario = 13 + m = report.max_separation_nodes + report.warning_message = ( + f"Input invalid. All pairs have proximal node(s). " + f"Max separation: [{m[0]}, {m[1]}]." + ) + return report + + +def assign_scenario( + report: APNodePairReport, + pairs: np.ndarray, + seps: np.ndarray, + pair_is_distal: np.ndarray, + input_in_valid: bool, + input_idx: int, +) -> APNodePairReport: + """Assign one of 13 mutually exclusive scenarios. + + Parameters + ---------- + report : APNodePairReport + The report to update with scenario information. + pairs : np.ndarray + Valid pairs array of shape (n_pairs, 2). + seps : np.ndarray + Internode separations for each pair. + pair_is_distal : np.ndarray + Boolean array indicating distal pairs. + input_in_valid : bool + Whether input pair is among valid pairs. + input_idx : int + Index of input pair in valid pairs (-1 if not present). + + Returns + ------- + APNodePairReport + Updated report with scenario, outcome, and warning_message. + + """ + if len(pairs) == 1: + return assign_single_pair_scenario( + report, + pairs, + pair_is_distal, + input_in_valid, + ) + + if not input_in_valid: + return assign_multi_input_invalid_scenario(report, pair_is_distal) + + if report.input_pair_is_distal: + return assign_multi_input_distal_scenario( + report, + pairs, + input_idx, + ) + + return assign_multi_input_proximal_scenario(report, pair_is_distal) + + +def evaluate_ap_node_pair( + avg_skeleton: np.ndarray, + pc1_vec: np.ndarray, + anterior_sign: int, + valid_shape_rows: np.ndarray, + from_node: int, + to_node: int, + config: ValidateAPConfig, +) -> APNodePairReport: + """Evaluate an AP node pair through the 3-step filter cascade. + + Parameters + ---------- + avg_skeleton : np.ndarray + Average centered skeleton of shape (n_keypoints, 2). + pc1_vec : np.ndarray + First principal component vector of shape (2,). + anterior_sign : int + Inferred anterior direction (+1 or -1 relative to PC1). + valid_shape_rows : np.ndarray + Boolean array indicating valid (non-NaN) keypoints. + from_node : int + Index of the input from_node (body_axis_keypoints origin, + claimed posterior). 0-indexed. + to_node : int + Index of the input to_node (body_axis_keypoints target, + claimed anterior). 0-indexed. + config : ValidateAPConfig + Configuration with ``lateral_thresh`` and ``edge_thresh``. + + Returns + ------- + APNodePairReport + Complete evaluation report. + + """ + n_keypoints = len(avg_skeleton) + report = APNodePairReport() + report.pc1_coords = np.full(n_keypoints, np.nan) + report.ap_coords = np.full(n_keypoints, np.nan) + report.lateral_offsets = np.full(n_keypoints, np.nan) + report.lateral_offsets_norm = np.full(n_keypoints, np.nan) + report.midline_dist_norm = np.full(n_keypoints, np.nan) + + for node, label in [(from_node, "from_node"), (to_node, "to_node")]: + if node < 0 or node >= n_keypoints: + report.failure_step = "Input validation" + report.failure_reason = ( + f"{label} must be a valid index in 0..{n_keypoints - 1}." + ) + return report + + valid_idx = np.nonzero(valid_shape_rows)[0] + if len(valid_idx) < 2: + report.failure_step = "Step 1: lateral alignment filter" + report.failure_reason = "Fewer than 2 valid nodes are available." + return report + + compute_node_projections( + report, + avg_skeleton, + pc1_vec, + anterior_sign, + valid_shape_rows, + from_node, + to_node, + ) + + candidates = apply_lateral_filter(report, valid_idx, config.lateral_thresh) + if candidates is None: + return report + + step2 = find_opposite_side_pairs( + report, + candidates, + from_node, + to_node, + valid_shape_rows, + ) + if step2 is None: + return report + pairs, seps = step2 + + pair_is_distal = classify_distal_proximal( + report, + pairs, + seps, + valid_shape_rows, + config.edge_thresh, + ) + + input_in_valid, input_idx = check_input_pair_in_valid( + report, + pairs, + seps, + pair_is_distal, + from_node, + to_node, + ) + + report = assign_scenario( + report, pairs, seps, pair_is_distal, input_in_valid, input_idx + ) + report.success = True + return report + + +# Input Preparation and Validation +# ────────────────────────────────── + + +def resolve_node_index(node: Hashable, names: list) -> int: + """Resolve a node identifier to an integer index.""" + if isinstance(node, str): + if node in names: + return names.index(node) + raise ValueError(f"Keypoint '{node}' not found in {names}.") + if isinstance(node, int): + return node + return int(node) # type: ignore[call-overload] + + +def prepare_validation_inputs( + data: xr.DataArray, + from_node: Hashable, + to_node: Hashable, +) -> tuple[np.ndarray, int, int, str, str, list[str], int]: + """Validate inputs and extract numpy arrays for AP validation. + + Returns + ------- + tuple + (keypoints, from_idx, to_idx, from_name, to_name, + keypoint_names, num_frames) + + Raises + ------ + TypeError + If data is not an xarray.DataArray. + ValueError + If dimensions or indices are invalid. + + """ + if not isinstance(data, xr.DataArray): + raise TypeError( + f"Input data must be an xarray.DataArray, but got {type(data)}." + ) + + required_dims = {"time", "space", "keypoints"} + if not required_dims.issubset(set(data.dims)): + raise ValueError( + f"data must have dimensions {required_dims}, " + f"but has {set(data.dims)}." + ) + + if "individuals" in data.dims: + if data.sizes["individuals"] != 1: + raise ValueError( + "data must be for a single individual. " + "Use data.sel(individuals='name') to select one." + ) + data = data.squeeze("individuals", drop=True) + + if "keypoints" in data.coords: + keypoint_names = list(data.coords["keypoints"].values) + else: + keypoint_names = [f"node_{i}" for i in range(data.sizes["keypoints"])] + + n_keypoints = data.sizes["keypoints"] + from_idx = resolve_node_index(from_node, keypoint_names) + to_idx = resolve_node_index(to_node, keypoint_names) + + if from_idx < 0 or from_idx >= n_keypoints: + raise ValueError( + f"from_node index {from_idx} out of range [0, {n_keypoints - 1}]." + ) + if to_idx < 0 or to_idx >= n_keypoints: + raise ValueError( + f"to_node index {to_idx} out of range [0, {n_keypoints - 1}]." + ) + + data_xy = data.sel(space=["x", "y"]) + keypoints = data_xy.transpose("time", "keypoints", "space").values + + from_name = keypoint_names[from_idx] + to_name = keypoint_names[to_idx] + num_frames = keypoints.shape[0] + + return ( + keypoints, + from_idx, + to_idx, + from_name, + to_name, + keypoint_names, + num_frames, + ) + + +# Pipeline Orchestration Functions +# ────────────────────────────────── + + +def run_motion_segmentation( + keypoints: np.ndarray, + num_frames: int, + config: ValidateAPConfig, + log_info, + log_warning, +) -> dict | None: + """Run tiered validity through segment detection. + + Returns a dict with tier1_valid, tier2_valid, bbox_centroids, + segments, or None on failure (error logged). + + """ + tier1_valid, tier2_valid, _frac = compute_tiered_validity( + keypoints, config.min_valid_frac + ) + num_tier1 = int(np.sum(tier1_valid)) + num_tier2 = int(np.sum(tier2_valid)) + + log_info(_LOG_SEPARATOR) + log_info("Tiered Validity Report") + log_info(_LOG_SEPARATOR) + log_info( + "Tier 1 (>= %.0f%% keypoints): %d / %d frames (%.2f%%)", + config.min_valid_frac * 100, + num_tier1, + num_frames, + 100 * num_tier1 / num_frames, + ) + log_info( + "Tier 2 (100%% keypoints): %d / %d frames (%.2f%%)", + num_tier2, + num_frames, + 100 * num_tier2 / num_frames, + ) + + if num_tier1 < 2: + logger.error("Not enough tier-1 valid frames.") + return None + + bbox_centroids, _arith, centroid_disc = compute_bbox_centroid( + keypoints, tier1_valid + ) + valid_disc = centroid_disc[tier1_valid & ~np.isnan(centroid_disc)] + if len(valid_disc) > 0: + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Centroid Discrepancy Diagnostic") + log_info(_LOG_SEPARATOR) + log_info("BBox vs arithmetic centroid (normalized by bbox diagonal):") + log_info( + " Median: %.4f | Mean: %.4f | Max: %.4f", + np.median(valid_disc), + np.mean(valid_disc), + np.max(valid_disc), + ) + if np.median(valid_disc) > 0.05: + log_warning( + "Median discrepancy > 5%% - annotation density " + "is likely asymmetric." + ) + + segments = detect_motion_segments( + bbox_centroids, tier1_valid, config, log_info + ) + if segments is None: + return None + + return { + "tier1_valid": tier1_valid, + "tier2_valid": tier2_valid, + "bbox_centroids": bbox_centroids, + "segments": segments, + } + + +def detect_motion_segments( + bbox_centroids: np.ndarray, + tier1_valid: np.ndarray, + config: ValidateAPConfig, + log_info, +) -> np.ndarray | None: + """Detect high-motion segments from centroid velocities. + + Returns merged segments array, or None on failure. + + """ + _, speeds = compute_frame_velocities(bbox_centroids, tier1_valid) + num_speed = len(speeds) + + if num_speed < config.window_len: + logger.error( + "window_len=%d exceeds available speed samples=%d.", + config.window_len, + num_speed, + ) + return None + + window_starts, window_medians, window_all_valid = ( + compute_sliding_window_medians( + speeds, config.window_len, config.stride + ) + ) + num_valid_windows = int(np.sum(window_all_valid)) + if num_valid_windows == 0: + logger.error("No fully valid sliding windows found.") + return None + + high_motion = detect_high_motion_windows( + window_medians, window_all_valid, config.pct_thresh + ) + num_high_motion = int(np.sum(high_motion)) + + log_info("") + log_info(_LOG_SEPARATOR) + log_info("High-Motion Window Detection") + log_info(_LOG_SEPARATOR) + log_info( + "Sliding windows (len=%d, stride=%d): " + "%d total, %d fully valid (NaN-free), " + "%d high-motion (median speed >= %dth percentile)", + config.window_len, + config.stride, + len(window_starts), + num_valid_windows, + num_high_motion, + int(config.pct_thresh), + ) + + if num_high_motion == 0: + logger.error("No high-motion windows found.") + return None + + run_starts, run_ends, _run_lengths = detect_runs( + high_motion, config.min_run_len + ) + if len(run_starts) == 0: + logger.error("No runs met min_run_len=%d.", config.min_run_len) + return None + + segments_raw = convert_runs_to_segments( + run_starts, run_ends, window_starts, config.window_len + ) + segments = merge_segments(segments_raw) + + log_info("Detected %d merged high-motion segment(s):", len(segments)) + for i, (start, end) in enumerate(segments): + log_info(" Segment %d: frames %d - %d", i + 1, start, end) + + return segments + + +def select_tier2_frames( + segments: np.ndarray, + tier2_valid: np.ndarray, + num_frames: int, + log_info, + log_warning, +) -> tuple[np.ndarray, np.ndarray, int] | None: + """Filter segment frames to tier-2 valid only. + + Returns (selected_frames, selected_seg_id, num_selected) or None. + + """ + selected_frames, selected_seg_id = filter_segments_tier2( + segments, tier2_valid + ) + + num_tier1_in_segs = sum( + np.sum( + (np.arange(num_frames) >= s[0]) & (np.arange(num_frames) <= s[1]) + ) + for s in segments + ) + num_selected = len(selected_frames) + + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Tier-2 Filtering on High-Motion Segments") + log_info(_LOG_SEPARATOR) + log_info( + "Frames in high-motion segments (any tier): %d", num_tier1_in_segs + ) + log_info( + "Tier-2 valid frames retained (all keypoints present): " + "%d (%.1f%% of segment frames)", + num_selected, + 100 * num_selected / max(num_tier1_in_segs, 1), + ) + + retention = num_selected / max(num_tier1_in_segs, 1) + if retention < 0.3: + log_warning( + "Tier 2 discards > 70%% of segment frames - " + "body model may be unrepresentative." + ) + + if num_selected < 2: + logger.error("Not enough tier-2 valid frames in selected segments.") + return None + + return selected_frames, selected_seg_id, num_selected + + +def run_clustering_and_pca( + centered_skeletons: np.ndarray, + frame_sel: FrameSelection, + config: ValidateAPConfig, + log_info, + log_warning, +) -> dict | None: + """Run postural analysis, clustering, and per-cluster PCA. + + Returns dict with primary_result, cluster_results, + num_clusters, primary_cluster, or None on failure. + + """ + rmsd_matrix = compute_pairwise_rmsd(centered_skeletons) + var_ratio, within_rmsds, between_rmsds, var_ratio_override = ( + compute_postural_variance_ratio(rmsd_matrix, frame_sel.seg_ids) + ) + + rmsd_stats = { + "within": within_rmsds, + "between": between_rmsds, + "var_ratio": var_ratio, + "override": var_ratio_override, + } + log_postural_consistency( + rmsd_stats, + config, + frame_sel.count, + log_info, + ) + + cluster_labels, num_clusters, primary_cluster = decide_and_run_clustering( + centered_skeletons, + var_ratio, + frame_sel.count, + config, + log_info, + ) + + cluster_results = [] + for c in range(num_clusters): + cluster_mask = cluster_labels == c + cr = compute_cluster_pca_and_anterior( + centered_skeletons, + cluster_mask, + frame_sel.frames, + frame_sel.seg_ids, + frame_sel.segments, + frame_sel.bbox_centroids, + ) + cluster_results.append(cr) + + pr = cluster_results[primary_cluster] + if not pr["valid"]: + logger.error("Primary cluster has invalid PCA result.") + return None + + return { + "primary_result": pr, + "cluster_results": cluster_results, + "num_clusters": num_clusters, + "primary_cluster": primary_cluster, + } + + +def log_postural_consistency( + rmsd_stats, + config, + num_selected, + log_info, +): + """Log postural consistency check results.""" + within_rmsds = rmsd_stats["within"] + between_rmsds = rmsd_stats["between"] + var_ratio = rmsd_stats["var_ratio"] + var_ratio_override = rmsd_stats["override"] + + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Postural Consistency Check") + log_info(_LOG_SEPARATOR) + + if len(within_rmsds) > 0: + log_info( + "Within-segment RMSD: mean=%.4f, std=%.4f (n=%d pairs)", + np.mean(within_rmsds), + np.std(within_rmsds), + len(within_rmsds), + ) + else: + log_info("Within-segment RMSD: N/A (no within-segment pairs)") + + if len(between_rmsds) > 0: + log_info( + "Between-segment RMSD: mean=%.4f, std=%.4f (n=%d pairs)", + np.mean(between_rmsds), + np.std(between_rmsds), + len(between_rmsds), + ) + log_info( + "Variance ratio (between/within): %.2f (threshold=%.2f)", + var_ratio, + config.postural_var_ratio_thresh, + ) + if var_ratio_override: + log_info( + " (Conservative override to zero: within-segment variance " + "is zero or no within-segment pairs)" + ) + else: + log_info("Between-segment RMSD: N/A (single segment)") + log_info("Variance ratio: N/A") + + do_clustering = ( + var_ratio > config.postural_var_ratio_thresh and num_selected >= 6 + ) + if do_clustering: + log_info(" -> Variance ratio exceeds threshold. Running clustering.") + elif var_ratio > config.postural_var_ratio_thresh and num_selected < 6: + log_info( + " -> Variance ratio exceeds threshold but too few frames (%d) " + "for clustering.", + num_selected, + ) + else: + log_info(" -> Postural consistency acceptable. Using global average.") + + +def decide_and_run_clustering( + centered_skeletons, + var_ratio, + num_selected, + config, + log_info, +): + """Decide whether to cluster; run k-medoids if triggered.""" + do_clustering = ( + var_ratio > config.postural_var_ratio_thresh and num_selected >= 6 + ) + + if not do_clustering: + return np.zeros(num_selected, dtype=int), 1, 0 + + ( + cluster_labels, + num_clusters, + primary_cluster, + best_silhouette, + silhouette_scores, + ) = perform_postural_clustering(centered_skeletons, config.max_clusters) + + for k, sil in silhouette_scores: + if np.isnan(sil): + log_info(" k=%d: clustering failed.", k) + else: + log_info(" k=%d: mean silhouette = %.4f", k, sil) + + if num_clusters > 1: + cluster_counts = np.bincount(cluster_labels, minlength=num_clusters) + log_info( + " Selected k=%d clusters (silhouette=%.4f). " + "Primary cluster=%d (%d frames)", + num_clusters, + best_silhouette, + primary_cluster + 1, + cluster_counts[primary_cluster], + ) + else: + log_info( + " Clustering did not improve separation (best_sil=%.4f). " + "Using global average.", + best_silhouette, + ) + + return cluster_labels, num_clusters, primary_cluster + + +def log_anterior_report( + pr, + cluster_results, + num_clusters, + primary_cluster, + config, + log_info, + log_warning, +): + """Log anterior direction detection and cluster agreement.""" + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Anterior Direction Inference (Velocity Voting)") + log_info(_LOG_SEPARATOR) + log_info( + "Centroid velocity projections onto PC1: " + "%d positive (+PC1), %d negative (-PC1)", + pr["num_positive"], + pr["num_negative"], + ) + + vote_margin_str = f"Vote margin M: {pr['vote_margin']:.4f}" + if pr["vote_margin"] < config.confidence_floor: + vote_margin_str += ( + f" ** BELOW CONFIDENCE FLOOR ({config.confidence_floor:.2f}) " + "- anterior assignment is unreliable **" + ) + log_warning( + "Vote margin M = %.4f is below confidence floor %.2f - " + "anterior assignment is unreliable.", + pr["vote_margin"], + config.confidence_floor, + ) + log_info(vote_margin_str) + log_info( + "Resultant length R: %.4f (0 = omnidirectional, 1 = unidirectional)", + pr["resultant_length"], + ) + log_info( + "Inferred anterior direction: %sPC1 " + "(strict majority; ties default to -PC1)", + "+" if pr["anterior_sign"] > 0 else "-", + ) + + if num_clusters <= 1: + return + + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Inter-Cluster Anterior Polarity Agreement") + log_info(_LOG_SEPARATOR) + signs = [cr["anterior_sign"] for cr in cluster_results if cr["valid"]] + if len(set(signs)) == 1: + log_info( + "All %d clusters AGREE on anterior polarity.", + num_clusters, + ) + else: + log_info( + "DISAGREEMENT: clusters assign different anterior polarities." + ) + for c, cr in enumerate(cluster_results): + if cr["valid"]: + log_info( + " Cluster %d (%d frames): anterior = %sPC1, " + "vote_margin M = %.4f, resultant_length R = %.4f", + c + 1, + cr["n_frames"], + "+" if cr["anterior_sign"] > 0 else "-", + cr["vote_margin"], + cr["resultant_length"], + ) + log_info( + " Primary result from cluster %d (largest).", + primary_cluster + 1, + ) + + +def log_pair_evaluation( + pair_report, + config, + from_idx, + to_idx, + from_name, + to_name, + log_info, +): + """Log the complete AP node pair evaluation report.""" + log_info("") + log_info(_LOG_SEPARATOR) + log_info("AP Node-Pair Filter Cascade (3-Step Evaluation)") + log_info(_LOG_SEPARATOR) + log_info( + "Input pair: [%d, %d] (%s -> %s, claimed posterior -> anterior)", + from_idx, + to_idx, + from_name, + to_name, + ) + + step1_failed = pair_report.failure_step.startswith("Step 1") + + valid_nodes = np.nonzero(~np.isnan(pair_report.lateral_offsets_norm))[0] + num_candidates = len(pair_report.sorted_candidate_nodes) + step1_loss = 1 - num_candidates / max(len(valid_nodes), 1) + + log_step1_report(pair_report, config, valid_nodes, log_info) + log_input_node_status(pair_report, config, from_idx, to_idx, log_info) + + step2_loss = 0.0 + step3_frac = 0.0 + step2_failed = False + + if step1_failed: + log_info("") + log_info("Step 2-3: not evaluated (Step 1 failed)") + else: + step2_loss, step3_frac, step2_failed = log_step2_step3_details( + pair_report, + config, + from_idx, + to_idx, + num_candidates, + log_info, + ) + + log_loss_summary( + step1_loss, + step2_loss, + step3_frac, + step1_failed, + step2_failed, + log_info, + ) + log_order_check( + pair_report, + from_idx, + to_idx, + from_name, + to_name, + log_info, + ) + + +def log_step1_report(pair_report, config, valid_nodes, log_info): + """Log Step 1 lateral filter results.""" + num_valid = len(valid_nodes) + num_candidates = len(pair_report.sorted_candidate_nodes) + step1_loss = 1 - num_candidates / max(num_valid, 1) + + pass_strs = [] + fail_strs = [] + for node_i in valid_nodes: + lat_norm = pair_report.lateral_offsets_norm[node_i] + if lat_norm <= config.lateral_thresh: + pass_strs.append(f"{node_i}({lat_norm:.2f})") + else: + fail_strs.append(f"{node_i}({lat_norm:.2f})") + + log_info("") + log_info( + "Step 1 - Lateral Alignment Filter (lateral_thresh=%.2f): " + "%d of %d valid nodes pass [loss=%.0f%%]", + config.lateral_thresh, + num_candidates, + num_valid, + 100 * step1_loss, + ) + log_info( + " Scale: 0.00 = nearest to body axis, 1.00 = farthest from body axis" + ) + if pass_strs: + log_info(" PASS: %s", ", ".join(pass_strs)) + if fail_strs: + log_info(" FAIL: %s", ", ".join(fail_strs)) + + +def log_step2_report(pair_report, _config, log_info): + """Log Step 2 opposite-sides results.""" + num_candidates = len(pair_report.sorted_candidate_nodes) + num_possible_pairs = num_candidates * (num_candidates - 1) // 2 + num_valid_pairs = len(pair_report.valid_pairs) + step2_loss = 1 - num_valid_pairs / max(num_possible_pairs, 1) + m = pair_report.midpoint_pc1 + + plus_strs = [] + minus_strs = [] + for node_i in pair_report.sorted_candidate_nodes: + pc1_rel = pair_report.pc1_coords[node_i] - m + if pc1_rel > 0: + plus_strs.append(f"{node_i}({pc1_rel:+.1f})") + else: + minus_strs.append(f"{node_i}({pc1_rel:+.1f})") + + log_info("") + log_info( + "Step 2 - Opposite-Sides Constraint (AP midpoint=%.2f): " + "%d of %d candidate pairs on opposite sides [loss=%.0f%%]", + m, + num_valid_pairs, + num_possible_pairs, + 100 * step2_loss, + ) + if plus_strs: + log_info(" + side (anterior of midpoint): %s", ", ".join(plus_strs)) + if minus_strs: + log_info(" - side (posterior of midpoint): %s", ", ".join(minus_strs)) + + +def log_step3_report(pair_report, config, log_info): + """Log Step 3 distal/proximal classification results.""" + num_distal = len(pair_report.distal_pairs) + num_proximal = len(pair_report.proximal_pairs) + num_valid_pairs = len(pair_report.valid_pairs) + step3_distal_frac = num_distal / max(num_valid_pairs, 1) + + log_info("") + log_info( + "Step 3 - Distal/Proximal Classification (edge_thresh=%.2f): " + "%d distal, %d proximal [distal fraction=%.0f%%]", + config.edge_thresh, + num_distal, + num_proximal, + 100 * step3_distal_frac, + ) + + for idx in range(num_valid_pairs): + node_i, node_j = pair_report.valid_pairs[idx] + d_i = pair_report.midline_dist_norm[node_i] + d_j = pair_report.midline_dist_norm[node_j] + min_d = min(d_i, d_j) + sep = pair_report.valid_pairs_internode_dist[idx] + status = "DISTAL" if min_d >= config.edge_thresh else "PROXIMAL" + log_info( + " [%d,%d]: min_d=%.2f, sep=%.2f [%s]", + node_i, + node_j, + min_d, + sep, + status, + ) + + +def log_loss_summary( + step1_loss, step2_loss, step3_frac, step1_failed, step2_failed, log_info +): + """Log cumulative filtering loss summary.""" + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Filtering Loss Summary") + log_info(_LOG_SEPARATOR) + log_info( + "Step 1 (Lateral Filter): %.0f%% of valid nodes eliminated", + 100 * step1_loss, + ) + if not step1_failed: + log_info( + "Step 2 (Opposite-Sides): %.0f%% of candidate pairs eliminated", + 100 * step2_loss, + ) + if not step1_failed and not step2_failed: + log_info( + "Step 3 (Distal/Proximal): %.0f%% of surviving pairs are distal", + 100 * step3_frac, + ) + + +def log_order_check( + pair_report, from_idx, to_idx, from_name, to_name, log_info +): + """Log AP ordering check for the input pair.""" + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Order Check: is from_node posterior to to_node?") + log_info(_LOG_SEPARATOR) + ap_from = pair_report.ap_coords[from_idx] + ap_to = pair_report.ap_coords[to_idx] + if np.isnan(ap_from) or np.isnan(ap_to): + log_info("Order check: cannot evaluate (invalid node coordinates)") + return + + log_info( + "AP coords: from_node %s[%d]=%.2f, to_node %s[%d]=%.2f", + from_name, + from_idx, + ap_from, + to_name, + to_idx, + ap_to, + ) + if pair_report.input_pair_order_matches_inference: + log_info( + "[%d, %d]: CONSISTENT - inference agrees that " + "from_node is posterior (lower AP coord), " + "to_node is anterior", + from_idx, + to_idx, + ) + else: + log_info( + "[%d, %d]: INCONSISTENT - inference suggests " + "from_node is anterior (higher AP coord), " + "to_node is posterior", + from_idx, + to_idx, + ) + log_info( + " -> Inferred posterior->anterior order would be [%d, %d]", + to_idx, + from_idx, + ) + + +def log_input_node_status( + pair_report, + config, + from_idx, + to_idx, + log_info, +): + """Log whether each input node passed the lateral filter.""" + lat_from = pair_report.lateral_offsets_norm[from_idx] + lat_to = pair_report.lateral_offsets_norm[to_idx] + from_pass = not np.isnan(lat_from) and lat_from <= config.lateral_thresh + to_pass = not np.isnan(lat_to) and lat_to <= config.lateral_thresh + + if from_pass and to_pass: + return + + fail_nodes = [] + if not from_pass: + fail_nodes.append(f"{from_idx}({lat_from:.2f})") + if not to_pass: + fail_nodes.append(f"{to_idx}({lat_to:.2f})") + log_info( + " -> Input node(s) FAILED lateral filter: %s", + ", ".join(fail_nodes), + ) + + +def log_step2_step3_details( + pair_report, + config, + from_idx, + to_idx, + num_candidates, + log_info, +): + """Log Step 2 and Step 3 results when Step 1 succeeded. + + Returns (step2_loss, step3_frac, step2_failed). + + """ + log_step2_report(pair_report, config, log_info) + + if ( + pair_report.input_pair_in_candidates + and not pair_report.input_pair_opposite_sides + ): + log_info(" -> Input nodes on SAME side of AP midpoint") + + num_possible = num_candidates * (num_candidates - 1) // 2 + step2_loss = 1 - len(pair_report.valid_pairs) / max(num_possible, 1) + step2_failed = pair_report.failure_step.startswith("Step 2") + + if step2_failed: + log_info("") + log_info("Step 3: not evaluated (Step 2 failed)") + return step2_loss, 0.0, True + + step3_frac = log_step3_with_proximal_check( + pair_report, + config, + from_idx, + to_idx, + log_info, + ) + return step2_loss, step3_frac, False + + +def log_step3_with_proximal_check( + pair_report, + config, + from_idx, + to_idx, + log_info, +): + """Log Step 3 results and check input pair proximal status. + + Returns step3_frac. + + """ + log_step3_report(pair_report, config, log_info) + num_distal = len(pair_report.distal_pairs) + num_valid_pairs = len(pair_report.valid_pairs) + step3_frac = num_distal / max(num_valid_pairs, 1) + + is_candidate = pair_report.input_pair_in_candidates + is_opposite = pair_report.input_pair_opposite_sides + is_proximal = not pair_report.input_pair_is_distal + if is_candidate and is_opposite and is_proximal: + d_from = pair_report.midline_dist_norm[from_idx] + d_to = pair_report.midline_dist_norm[to_idx] + log_info( + " -> Input pair is PROXIMAL (min_d=%.2f < %.2f)", + min(d_from, d_to), + config.edge_thresh, + ) + + return step3_frac + + +# Main Validation Function +# ───────────────────────── + + +def validate_ap( + data: xr.DataArray, + from_node: Hashable, + to_node: Hashable, + config: ValidateAPConfig | None = None, + verbose: bool = True, +) -> dict: + """Validate an anterior-posterior keypoint pair using body-axis inference. + + This function implements a prior-free body-axis inference pipeline that: + 1. Identifies high-motion segments using tiered validity and sliding + windows + 2. Optionally performs postural clustering via k-medoids + 3. Infers the anterior direction using velocity projection voting + 4. Evaluates the candidate AP keypoint pair through a 3-step filter + cascade + + Parameters + ---------- + data : xarray.DataArray + Position data for a single individual. + from_node : int or str + Index or name of the posterior keypoint. + to_node : int or str + Index or name of the anterior keypoint. + config : ValidateAPConfig, optional + Configuration parameters. If None, uses defaults. + verbose : bool, default=True + If True, log detailed validation output to console. + + Returns + ------- + dict + Validation results including success, anterior_sign, + vote_margin, resultant_length, pair_report, etc. + + """ + if config is None: + config = ValidateAPConfig() + + log_lines: list[str] = [] + + def _log_info(msg, *args): + """Log an informational message.""" + line = msg % args if args else msg + log_lines.append(line) + if verbose: + print(line) + + def _log_warning(msg, *args): + """Log a warning message.""" + line = f"WARNING: {msg % args if args else msg}" + log_lines.append(line) + if verbose: + print(line) + + # Prepare inputs + ( + keypoints, + from_idx, + to_idx, + from_name, + to_name, + _keypoint_names, + num_frames, + ) = prepare_validation_inputs(data, from_node, to_node) + + n_keypoints = keypoints.shape[1] + result: dict = { + "success": False, + "anterior_sign": 0, + "vote_margin": 0.0, + "resultant_length": 0.0, + "num_selected_frames": 0, + "num_clusters": 1, + "primary_cluster": 0, + "pair_report": APNodePairReport(), + "PC1": np.array([1.0, 0.0]), + "PC2": np.array([0.0, 1.0]), + "avg_skeleton": np.full((n_keypoints, 2), np.nan), + "error_msg": "", + "log_lines": log_lines, + } + + # Motion segmentation + seg = run_motion_segmentation( + keypoints, + num_frames, + config, + _log_info, + _log_warning, + ) + if seg is None: + result["error_msg"] = "Motion segmentation failed." + return result + + # Tier-2 frame selection + t2 = select_tier2_frames( + seg["segments"], + seg["tier2_valid"], + num_frames, + _log_info, + _log_warning, + ) + if t2 is None: + result["error_msg"] = "Not enough tier-2 valid frames." + return result + selected_frames, selected_seg_id, num_selected = t2 + result["num_selected_frames"] = num_selected + + # Build centered skeletons + _selected_centroids, centered_skeletons = build_centered_skeletons( + keypoints, selected_frames + ) + + # Bundle frame selection data + frame_sel = FrameSelection( + frames=selected_frames, + seg_ids=selected_seg_id, + segments=seg["segments"], + bbox_centroids=seg["bbox_centroids"], + count=num_selected, + ) + + # Postural clustering + PCA + anterior inference + pca = run_clustering_and_pca( + centered_skeletons, + frame_sel, + config, + _log_info, + _log_warning, + ) + if pca is None: + result["error_msg"] = "Primary cluster PCA failed." + return result + + pr = pca["primary_result"] + result["anterior_sign"] = pr["anterior_sign"] + result["vote_margin"] = pr["vote_margin"] + result["resultant_length"] = pr["resultant_length"] + result["circ_mean_dir"] = pr["circ_mean_dir"] + result["vel_projs_pc1"] = pr["vel_projs_pc1"] + result["PC1"] = pr["PC1"] + result["PC2"] = pr["PC2"] + result["avg_skeleton"] = pr["avg_skeleton"] + result["num_clusters"] = pca["num_clusters"] + result["primary_cluster"] = pca["primary_cluster"] + + # Log anterior inference + log_anterior_report( + pr, + pca["cluster_results"], + pca["num_clusters"], + pca["primary_cluster"], + config, + _log_info, + _log_warning, + ) + + # AP node-pair evaluation + pair_report = evaluate_ap_node_pair( + pr["avg_skeleton"], + pr["PC1"], + pr["anterior_sign"], + pr["valid_shape_rows"], + from_idx, + to_idx, + config, + ) + result["pair_report"] = pair_report + + log_pair_evaluation( + pair_report, + config, + from_idx, + to_idx, + from_name, + to_name, + _log_info, + ) + + result["success"] = True + return result + + +# Multi-Individual Validation +# ──────────────────────────── + + +def run_ap_validation( + data: xr.DataArray, + normalized_keypoints: tuple[Hashable, Hashable], + ap_validation_config: dict[str, Any] | None, +) -> dict: + """Run AP validation across all individuals, select best by R*M. + + Each individual is validated independently using the supplied keypoint + pair. R*M (resultant_length * vote_margin) is computed per individual + and depends only on the individual's motion and body shape, not on + the input pair. The best individual is the one with the highest R*M. + + Parameters + ---------- + data : xarray.DataArray + Position data with individuals dimension. + normalized_keypoints : tuple[Hashable, Hashable] + The (from_node, to_node) keypoint pair. + ap_validation_config : dict, optional + Configuration overrides for ValidateAPConfig. + + Returns + ------- + dict + Dictionary with 'all_results' (list of per-individual results) + and 'best_idx' (index of best individual by R*M). + + """ + config = ( + ValidateAPConfig(**ap_validation_config) + if ap_validation_config is not None + else None + ) + + if "individuals" not in data.dims: + single_result = validate_ap( + data, + from_node=normalized_keypoints[0], + to_node=normalized_keypoints[1], + config=config, + verbose=False, + ) + return {"all_results": [single_result], "best_idx": 0} + + individuals = list(data.coords["individuals"].values) + all_results = [] + for individual in individuals: + result = validate_ap( + data.sel(individuals=individual), + from_node=normalized_keypoints[0], + to_node=normalized_keypoints[1], + config=config, + verbose=False, + ) + result["individual"] = individual + all_results.append(result) + + best_idx = find_best_individual_by_rxm(all_results) + return {"all_results": all_results, "best_idx": best_idx} + + +def find_best_individual_by_rxm(all_results: list[dict]) -> int: + """Return index of the individual with highest R*M score.""" + best_idx = -1 + best_rxm = -1.0 + for i, result in enumerate(all_results): + if not result["success"]: + continue + rxm = result["resultant_length"] * result["vote_margin"] + if rxm > best_rxm: + best_rxm = rxm + best_idx = i + return best_idx diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index 8b0e06aa5..e9795d3fc 100644 --- a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -7,8 +7,13 @@ import numpy as np import xarray as xr +from movement.kinematics.body_axis import run_ap_validation from movement.utils.logging import logger -from movement.utils.vector import compute_norm +from movement.utils.vector import ( + compute_norm, + compute_signed_angle_2d, + convert_to_unit, +) from movement.validators.arrays import validate_dims_coords _ANGLE_EPS = 1e-12 @@ -20,6 +25,8 @@ def compute_polarization( displacement_frames: int = 1, return_angle: bool = False, in_degrees: bool = False, + validate_ap: bool = False, + ap_validation_config: dict[str, Any] | None = None, ) -> xr.DataArray | tuple[xr.DataArray, xr.DataArray]: r"""Compute polarization (group alignment) of individuals. @@ -66,6 +73,13 @@ def compute_polarization( If True, the mean angle is returned in degrees. Otherwise, the angle is returned in radians. Only relevant when ``return_angle=True``. + validate_ap : bool, default=False + If True, run anterior-posterior axis validation when + ``body_axis_keypoints`` is provided. Validation is skipped for + displacement-based polarization. + ap_validation_config : dict, optional + Configuration overrides for anterior-posterior axis validation. + See ``movement.kinematics.body_axis.ValidateAPConfig`` for options. Returns ------- @@ -91,6 +105,10 @@ def compute_polarization( heading vector has zero magnitude (for example exact cancellation), the returned angle is NaN. + When ``validate_ap=True`` and ``body_axis_keypoints`` is provided, + anterior-posterior validation is run per individual and the result is + stored in ``polarization.attrs["ap_validation_result"]``. + Examples -------- Compute orientation polarization from body-axis keypoints: @@ -140,6 +158,14 @@ def compute_polarization( ... return_angle=True, ... ) + Run AP validation while computing body-axis polarization: + + >>> polarization = compute_polarization( + ... ds.position, + ... body_axis_keypoints=("tail_base", "neck"), + ... validate_ap=True, + ... ) + """ _validate_type_data_array(data) normalized_keypoints = _validate_position_data( @@ -147,7 +173,12 @@ def compute_polarization( body_axis_keypoints=body_axis_keypoints, ) + ap_validation_result = None if normalized_keypoints is not None: + if validate_ap: + ap_validation_result = run_ap_validation( + data, normalized_keypoints, ap_validation_config + ) heading_vectors = _compute_heading_from_keypoints( data=data, body_axis_keypoints=normalized_keypoints, @@ -158,11 +189,10 @@ def compute_polarization( displacement_frames=displacement_frames, ) - heading_xy = _select_xy(heading_vectors) - norm = compute_norm(heading_xy) - valid_mask = (~heading_xy.isnull().any(dim="space")) & (norm > 0) + heading = _select_space(heading_vectors) - unit_headings = (heading_xy / norm).where(valid_mask) + unit_headings = convert_to_unit(heading) + valid_mask = ~unit_headings.isnull().any(dim="space") vector_sum = unit_headings.sum(dim="individuals", skipna=True) sum_magnitude = compute_norm(vector_sum) n_valid = valid_mask.sum(dim="individuals") @@ -174,15 +204,22 @@ def compute_polarization( ).clip(min=0.0, max=1.0) polarization = polarization.rename("polarization") + if ap_validation_result is not None: + polarization.attrs["ap_validation_result"] = ap_validation_result + if not return_angle: return polarization + # Normalize vector_sum to unit vector for angle computation + mean_unit_vector = vector_sum / sum_magnitude + + # Compute angle from positive x-axis to mean unit vector + reference = np.array([1, 0]) angle_defined = (n_valid > 0) & (sum_magnitude > _ANGLE_EPS) mean_angle = xr.where( angle_defined, - np.arctan2( - vector_sum.sel(space="y"), - vector_sum.sel(space="x"), + compute_signed_angle_2d( + mean_unit_vector, reference, v_as_left_operand=True ), np.nan, ) @@ -235,9 +272,10 @@ def _compute_heading_from_velocity( return displacement -def _select_xy(data: xr.DataArray) -> xr.DataArray: - """Select the planar x/y components and return standard dim order.""" - return data.sel(space=["x", "y"]).transpose("time", "space", "individuals") +def _select_space(data: xr.DataArray) -> xr.DataArray: + """Return data with standard dim order, selecting only x and y coords.""" + result = data.sel(space=["x", "y"]) + return result.transpose("time", "space", "individuals") def _validate_position_data( diff --git a/tests/test_unit/test_kinematics/test_body_axis.py b/tests/test_unit/test_kinematics/test_body_axis.py new file mode 100644 index 000000000..9b316b6c1 --- /dev/null +++ b/tests/test_unit/test_kinematics/test_body_axis.py @@ -0,0 +1,62 @@ +# test_body_axis.py +"""Tests for the body axis validation module.""" + +from typing import Any + +import pytest + +from movement.kinematics.body_axis import ValidateAPConfig + + +class TestValidateAPConfig: + """Tests for the ValidateAPConfig dataclass parameter validation.""" + + @pytest.mark.parametrize( + ("field", "value"), + [ + ("min_valid_frac", -0.1), + ("min_valid_frac", 1.1), + ("window_len", 0), + ("window_len", -5), + ("window_len", 2.5), + ("stride", 0), + ("stride", -1), + ("stride", 1.5), + ("pct_thresh", -1), + ("pct_thresh", 101), + ("min_run_len", 0), + ("min_run_len", -1), + ("min_run_len", 1.5), + ("postural_var_ratio_thresh", 0), + ("postural_var_ratio_thresh", -1), + ("max_clusters", 0), + ("max_clusters", 2.5), + ("confidence_floor", -0.1), + ("confidence_floor", 1.1), + ("lateral_thresh", -0.1), + ("lateral_thresh", 1.1), + ("edge_thresh", -0.1), + ("edge_thresh", 1.1), + ], + ) + def test_invalid_config_values_raise(self, field: str, value: Any) -> None: + """Invalid config values should raise ValueError.""" + kwargs = {field: value} + with pytest.raises(ValueError, match="must be"): + ValidateAPConfig(**kwargs) + + def test_valid_config_does_not_raise(self) -> None: + """Valid config values should not raise any error.""" + # Should not raise + ValidateAPConfig( + min_valid_frac=0.5, + window_len=10, + stride=2, + pct_thresh=50.0, + min_run_len=2, + postural_var_ratio_thresh=1.5, + max_clusters=3, + confidence_floor=0.2, + lateral_thresh=0.3, + edge_thresh=0.2, + ) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index 8a785d115..809b40bdf 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -17,6 +17,33 @@ def _get_space_labels(n_space: int, space: list[str] | None) -> list[str]: raise ValueError("Provide explicit `space` labels for non-2D data.") +def _build_coords( + data: np.ndarray, + time: list | None, + space: list[str] | None, + individuals: list | None, + keypoints: list[str] | None = None, +) -> dict: + """Build coordinate dict for a position DataArray.""" + n_time, n_space = data.shape[0], data.shape[1] + coords: dict = { + "time": time if time is not None else list(range(n_time)), + "space": _get_space_labels(n_space, space), + } + if data.ndim == 4: + n_keypoints = data.shape[2] + coords["keypoints"] = keypoints or [ + f"kp_{i}" for i in range(n_keypoints) + ] + n_individuals = data.shape[3] + else: + n_individuals = data.shape[2] + coords["individuals"] = individuals or [ + f"id_{i}" for i in range(n_individuals) + ] + return coords + + def _make_position_dataarray( data: np.ndarray, *, @@ -27,41 +54,22 @@ def _make_position_dataarray( ) -> xr.DataArray: """Create a position DataArray for tests.""" data = np.asarray(data, dtype=float) - n_time, n_space = data.shape[0], data.shape[1] - if data.ndim == 3: - n_individuals = data.shape[2] - ind = individuals or [f"id_{i}" for i in range(n_individuals)] - return xr.DataArray( - data, - dims=["time", "space", "individuals"], - coords={ - "time": time if time else list(range(n_time)), - "space": _get_space_labels(n_space, space), - "individuals": ind, - }, - name="position", - ) - - if data.ndim == 4: - n_keypoints, n_individuals = data.shape[2], data.shape[3] - kp = keypoints or [f"kp_{i}" for i in range(n_keypoints)] - ind = individuals or [f"id_{i}" for i in range(n_individuals)] - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time if time else list(range(n_time)), - "space": _get_space_labels(n_space, space), - "keypoints": kp, - "individuals": ind, - }, - name="position", + dims_map = { + 3: ["time", "space", "individuals"], + 4: ["time", "space", "keypoints", "individuals"], + } + if data.ndim not in dims_map: + raise ValueError( + "Expected data with shape (time, space, individuals) or " + "(time, space, keypoints, individuals)." ) - raise ValueError( - "Expected data with shape (time, space, individuals) or " - "(time, space, keypoints, individuals)." + return xr.DataArray( + data, + dims=dims_map[data.ndim], + coords=_build_coords(data, time, space, individuals, keypoints), + name="position", ) @@ -261,7 +269,7 @@ def test_body_axis_keypoints_must_be_distinct(self, keypoint_positions): ) @pytest.mark.parametrize( - "displacement_frames,expected_exception", + ("displacement_frames", "expected_exception"), [ (0, ValueError), (-1, ValueError), @@ -612,113 +620,92 @@ def test_polarization_is_invariant_to_global_rotation( equal_nan=True, ) - def test_body_axis_invariance_to_translation_scaling_rotation( - self, - ): - """Body-axis polarization is invariant to translation/scaling/rotation. + def _body_axis_baseline(self): + """Build body-axis test data and return (da, pol, angle). - Mean body angle is invariant to translation and positive scaling, and - rotates by the same amount under global planar rotation. + Three individuals with body axes: +x, +x, +y. + Vector sum = (2, 1), polarization = sqrt(5)/3, angle = atan2(1,2). + Absolute positions differ across frames to test body-axis + independence from location. """ - # Three individuals with body axes: +x, +x, +y. - # This gives a nontrivial baseline: - # vector sum = (2, 1) - # polarization = sqrt(5) / 3 - # mean angle = atan2(1, 2) - # - # Absolute positions differ across frames to ensure we are really - # testing body-axis heading (target - origin), not any accidental - # dependence on absolute location. data = np.array( [ [ - [[0.0, 10.0, -2.0], [1.0, 11.0, -2.0]], # x - [[0.0, 5.0, 3.0], [0.0, 5.0, 4.0]], # y + [[0.0, 10.0, -2.0], [1.0, 11.0, -2.0]], + [[0.0, 5.0, 3.0], [0.0, 5.0, 4.0]], ], [ - [[100.0, 50.0, 7.0], [101.0, 51.0, 7.0]], # x - [[-1.0, 20.0, -3.0], [-1.0, 20.0, -2.0]], # y + [[100.0, 50.0, 7.0], [101.0, 51.0, 7.0]], + [[-1.0, 20.0, -3.0], [-1.0, 20.0, -2.0]], ], ], dtype=float, ) da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) - - pol_base, angle_base = kinematics.compute_polarization( + pol, angle = kinematics.compute_polarization( da, body_axis_keypoints=("tail_base", "neck"), return_angle=True, ) + return da, pol, angle - expected_pol = np.sqrt(5) / 3 - expected_angle = np.arctan2(1.0, 2.0) - - np.testing.assert_allclose(pol_base.values, expected_pol, atol=1e-10) + def test_body_axis_baseline_matches_expected_values(self): + """Body-axis polarization and angle match hand-computed values.""" + _da, pol, angle = self._body_axis_baseline() + np.testing.assert_allclose(pol.values, np.sqrt(5) / 3, atol=1e-10) np.testing.assert_allclose( - angle_base.values, expected_angle, atol=1e-10 + angle.values, np.arctan2(1.0, 2.0), atol=1e-10 ) - # Global translation: should not affect body-axis vectors. + def test_body_axis_invariance_to_translation(self): + """Global translation does not change body-axis polarization.""" + da, pol_base, angle_base = self._body_axis_baseline() translated = da.copy() translated.loc[{"space": "x"}] = translated.sel(space="x") + 123.4 translated.loc[{"space": "y"}] = translated.sel(space="y") - 56.7 - pol_translated, angle_translated = kinematics.compute_polarization( + pol, angle = kinematics.compute_polarization( translated, body_axis_keypoints=("tail_base", "neck"), return_angle=True, ) + np.testing.assert_allclose(pol.values, pol_base.values, atol=1e-10) + np.testing.assert_allclose(angle.values, angle_base.values, atol=1e-10) - np.testing.assert_allclose( - pol_translated.values, pol_base.values, atol=1e-10 - ) - np.testing.assert_allclose( - angle_translated.values, angle_base.values, atol=1e-10 - ) - - # Positive scaling: should preserve directions and therefore preserve - # polarization and angle. - scaled = da * 4.2 + def test_body_axis_invariance_to_positive_scaling(self): + """Positive scaling preserves body-axis polarization and angle.""" + da, pol_base, angle_base = self._body_axis_baseline() - pol_scaled, angle_scaled = kinematics.compute_polarization( - scaled, + pol, angle = kinematics.compute_polarization( + da * 4.2, body_axis_keypoints=("tail_base", "neck"), return_angle=True, ) + np.testing.assert_allclose(pol.values, pol_base.values, atol=1e-10) + np.testing.assert_allclose(angle.values, angle_base.values, atol=1e-10) - np.testing.assert_allclose( - pol_scaled.values, pol_base.values, atol=1e-10 - ) - np.testing.assert_allclose( - angle_scaled.values, angle_base.values, atol=1e-10 - ) + def test_body_axis_angle_rotates_under_global_rotation(self): + """Polarization preserved, angle shifts by pi/2. - # Global 90-degree rotation: polarization magnitude should be - # unchanged, and mean angle should rotate by +pi/2 (with wraparound). + Tests behavior under 90-degree rotation. + """ + da, pol_base, angle_base = self._body_axis_baseline() rotated = da.copy() x = da.sel(space="x") y = da.sel(space="y") rotated.loc[{"space": "x"}] = -y rotated.loc[{"space": "y"}] = x - pol_rotated, angle_rotated = kinematics.compute_polarization( + pol, angle = kinematics.compute_polarization( rotated, body_axis_keypoints=("tail_base", "neck"), return_angle=True, ) + np.testing.assert_allclose(pol.values, pol_base.values, atol=1e-10) - np.testing.assert_allclose( - pol_rotated.values, pol_base.values, atol=1e-10 - ) - - expected_rotated_angle = angle_base.values + (np.pi / 2) - expected_rotated_angle = ( - (expected_rotated_angle + np.pi) % (2 * np.pi) - ) - np.pi - - np.testing.assert_allclose( - angle_rotated.values, expected_rotated_angle, atol=1e-10 - ) + expected = angle_base.values + (np.pi / 2) + expected = (expected + np.pi) % (2 * np.pi) - np.pi + np.testing.assert_allclose(angle.values, expected, atol=1e-10) class TestHeadingSourceSelection: @@ -852,10 +839,8 @@ def test_first_n_frames_are_nan(self, aligned_positions): displacement_frames=2, return_angle=True, ) - assert np.isnan(polarization.values[0]) - assert np.isnan(polarization.values[1]) - assert np.isnan(mean_angle.values[0]) - assert np.isnan(mean_angle.values[1]) + assert np.all(np.isnan(polarization.values[:2])) + assert np.all(np.isnan(mean_angle.values[:2])) assert np.allclose(polarization.values[2:], 1.0, atol=1e-10) assert np.allclose(mean_angle.values[2:], 0.0, atol=1e-10) @@ -936,13 +921,15 @@ def test_return_angle_true_returns_named_pair(self, aligned_positions): ) assert isinstance(polarization, xr.DataArray) assert isinstance(mean_angle, xr.DataArray) - assert polarization.name == "polarization" - assert mean_angle.name == "mean_angle" + assert (polarization.name, mean_angle.name) == ( + "polarization", + "mean_angle", + ) assert polarization.dims == ("time",) assert mean_angle.dims == ("time",) @pytest.mark.parametrize( - "data,expected_angle,use_abs", + ("data", "expected_angle", "use_abs"), [ ( np.array(