From a3b1c1e9d8377f158b9e06d8c1dbf5b0cee72038 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 10 Dec 2024 12:37:53 +0000 Subject: [PATCH 1/3] Add eval script for body metrics --- 5_eval_body_metrics.py | 150 ++++++++++++++++++++++++++ pyproject.toml | 1 + src/egoallo/guidance_optimizer_jax.py | 6 ++ src/egoallo/sampling.py | 8 +- 4 files changed, 161 insertions(+), 4 deletions(-) create mode 100644 5_eval_body_metrics.py diff --git a/5_eval_body_metrics.py b/5_eval_body_metrics.py new file mode 100644 index 0000000..48fcc74 --- /dev/null +++ b/5_eval_body_metrics.py @@ -0,0 +1,150 @@ +"""Example script for computing body metrics on the test split of the AMASS dataset. + +This is a tidied version of the code we used to compute metrics in our paper. +Some context: https://github.com/brentyi/egoallo/issues/7 +""" + +from pathlib import Path + +import jax.tree +import numpy as np +import torch.optim.lr_scheduler +import torch.utils.data +import tyro + +from egoallo import fncsmpl +from egoallo.data.amass import EgoAmassHdf5Dataset +from egoallo.fncsmpl_extensions import get_T_world_root_from_cpf_pose +from egoallo.inference_utils import load_denoiser +from egoallo.metrics_helpers import ( + compute_foot_contact, + compute_foot_skate, + compute_head_trans, + compute_mpjpe, +) +from egoallo.sampling import run_sampling_with_stitching +from egoallo.transforms import SE3, SO3 + + +def main( + dataset_hdf5_path: Path, + dataset_files_path: Path, + subseq_len: int = 128, + guidance_inner: bool = False, + checkpoint_dir: Path = Path("./egoallo_checkpoint_april13/checkpoints_3000000/"), + smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"), + num_samples: int = 1, +) -> None: + """Compute body metrics on the test split of the AMASS dataset.""" + device = torch.device("cuda") + + # Setup. + denoiser_network = load_denoiser(checkpoint_dir).to(device) + dataset = EgoAmassHdf5Dataset( + dataset_hdf5_path, + dataset_files_path, + splits=("test",), + # We need an extra timestep in order to compute the relative CPF pose. (T_cpf_tm1_cpf_t) + subseq_len=subseq_len + 1, + cache_files=True, + slice_strategy="deterministic", + random_variable_len_proportion=0.0, + ) + body_model = fncsmpl.SmplhModel.load(smplh_npz_path).to(device) + + metrics = list[dict[str, np.ndarray]]() + + for i in range(len(dataset)): + sequence = dataset[i].to(device) + + samples = run_sampling_with_stitching( + denoiser_network, + body_model=body_model, + guidance_mode="no_hands", + guidance_inner=guidance_inner, + guidance_post=True, + Ts_world_cpf=sequence.T_world_cpf, + hamer_detections=None, + aria_detections=None, + num_samples=num_samples, + floor_z=0.0, + device=device, + guidance_verbose=False, + ) + + assert samples.hand_rotmats is not None + assert samples.betas.shape == (num_samples, subseq_len, 16) + assert samples.body_rotmats.shape == (num_samples, subseq_len, 21, 3, 3) + assert samples.hand_rotmats.shape == (num_samples, subseq_len, 30, 3, 3) + assert sequence.hand_quats is not None + + # We'll only use the body joint rotations. 
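+        # (The hand rotations are still passed to the body model below so that
+        # a full pose can be constructed; the metric computations further down
+        # then index only the first 21, i.e. body-only, joints.)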
+ pred_posed = body_model.with_shape(samples.betas).with_pose( + T_world_root=SE3.identity(device, torch.float32).wxyz_xyz, + local_quats=SO3.from_matrix( + torch.cat([samples.body_rotmats, samples.hand_rotmats], dim=2) + ).wxyz, + ) + pred_posed = pred_posed.with_new_T_world_root( + get_T_world_root_from_cpf_pose(pred_posed, sequence.T_world_cpf[1:, ...]) + ) + + label_posed = body_model.with_shape(sequence.betas[1:, ...]).with_pose( + sequence.T_world_root[1:, ...], + torch.cat( + [ + sequence.body_quats[1:, ...], + sequence.hand_quats[1:, ...], + ], + dim=1, + ), + ) + + metrics.append( + { + "mpjpe": compute_mpjpe( + label_T_world_root=label_posed.T_world_root, + label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :], + pred_T_world_root=pred_posed.T_world_root, + pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :], + per_frame_procrustes_align=False, + ), + "pampjpe": compute_mpjpe( + label_T_world_root=label_posed.T_world_root, + label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :], + pred_T_world_root=pred_posed.T_world_root, + pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :], + per_frame_procrustes_align=True, + ), + # We didn't report foot skating metrics in the paper. It's not + # really meaningful: since we optimize foot skating in the + # guidance optimizer, it's easy to "cheat" this metric. + "foot_skate": compute_foot_skate( + pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :], + ), + "foot_contact (GND)": compute_foot_contact( + pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :], + ), + "T_head": compute_head_trans( + label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :], + pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :], + ), + } + ) + + print("=" * 80) + print("=" * 80) + print("=" * 80) + print(f"Metrics ({i}/{len(dataset)} processed)") + for k, v in jax.tree.map( + lambda *x: f"{np.mean(x):.3f} +/- {np.std(x) / np.sqrt(len(metrics) * num_samples):.3f}", + *metrics, + ).items(): + print("\t", k, v) + print("=" * 80) + print("=" * 80) + print("=" * 80) + + +if __name__ == "__main__": + tyro.cli(main) diff --git a/pyproject.toml b/pyproject.toml index e94089e..92c4f11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ select = [ "PLW", # Pylint warnings. ] ignore = [ + "E731", # Do not assign a lambda expression, use a def. "E741", # Ambiguous variable name. (l, O, or I) "E501", # Line too long. "E721", # Do not compare types, use `isinstance()`. 
diff --git a/src/egoallo/guidance_optimizer_jax.py b/src/egoallo/guidance_optimizer_jax.py index da12f53..b1b92f5 100644 --- a/src/egoallo/guidance_optimizer_jax.py +++ b/src/egoallo/guidance_optimizer_jax.py @@ -39,6 +39,7 @@ def do_guidance_optimization( phase: Literal["inner", "post"], hamer_detections: None | CorrespondedHamerDetections, aria_detections: None | CorrespondedAriaHandWristPoseDetections, + verbose: bool, ) -> tuple[network.EgoDenoiseTraj, dict]: """Run an optimizer to apply foot contact constraints.""" @@ -70,6 +71,7 @@ def do_guidance_optimization( aria_detections=None if aria_detections is None else aria_detections.as_nested_dict(numpy=True), + verbose=verbose, ) rotmats = SO3( torch.from_numpy(onp.array(quats)) @@ -122,6 +124,7 @@ def _optimize_vmapped( guidance_params: JaxGuidanceParams, hamer_detections: dict | None, aria_detections: dict | None, + verbose: jdc.Static[bool], ) -> tuple[jax.Array, dict]: return jax.vmap( partial( @@ -131,6 +134,7 @@ def _optimize_vmapped( guidance_params=guidance_params, hamer_detections=hamer_detections, aria_detections=aria_detections, + verbose=verbose, ) )( betas=betas, @@ -306,6 +310,7 @@ def _optimize( guidance_params: JaxGuidanceParams, hamer_detections: dict | None, aria_detections: dict | None, + verbose: bool, ) -> tuple[jax.Array, dict]: """Apply constraints using Levenberg-Marquardt optimizer. Returns updated body_rotmats and hand_rotmats matrices.""" @@ -867,6 +872,7 @@ def skating_cost( lambda_initial=guidance_params.lambda_initial ), termination=jaxls.TerminationConfig(max_iterations=guidance_params.max_iters), + verbose=verbose, ) out_body_quats = solutions[_SmplhBodyPosesVar] assert out_body_quats.shape == (timesteps, 21, 4) diff --git a/src/egoallo/sampling.py b/src/egoallo/sampling.py index e4a51e8..2993cef 100644 --- a/src/egoallo/sampling.py +++ b/src/egoallo/sampling.py @@ -9,10 +9,7 @@ from tqdm.auto import tqdm from . import fncsmpl, network -from .guidance_optimizer_jax import ( - GuidanceMode, - do_guidance_optimization, -) +from .guidance_optimizer_jax import GuidanceMode, do_guidance_optimization from .hand_detection_structs import ( CorrespondedAriaHandWristPoseDetections, CorrespondedHamerDetections, @@ -71,6 +68,7 @@ def run_sampling_with_stitching( aria_detections: None | CorrespondedAriaHandWristPoseDetections, num_samples: int, device: torch.device, + guidance_verbose: bool = True, ) -> network.EgoDenoiseTraj: # Offset the T_world_cpf transform to place the floor at z=0 for the # denoiser network. 
All of the network outputs are local, so we don't need to
@@ -190,6 +188,7 @@ def run_sampling_with_stitching(
             phase="inner",
             hamer_detections=hamer_detections,
             aria_detections=aria_detections,
+            verbose=guidance_verbose,
         )
         x_0_packed_pred = x_0_pred.pack()
         del x_0_pred
@@ -224,6 +223,7 @@ def run_sampling_with_stitching(
             phase="post",
             hamer_detections=hamer_detections,
             aria_detections=aria_detections,
+            verbose=guidance_verbose,
         )
     assert start_time is not None
     print("RUNTIME (exclude first optimization)", time.time() - start_time)

From 62494af4b14b9dc8b002c5146d1925b069fa87be Mon Sep 17 00:00:00 2001
From: Brent Yi
Date: Tue, 10 Dec 2024 15:27:48 +0000
Subject: [PATCH 2/3] Add metrics

---
 5_eval_body_metrics.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/5_eval_body_metrics.py b/5_eval_body_metrics.py
index 48fcc74..9279795 100644
--- a/5_eval_body_metrics.py
+++ b/5_eval_body_metrics.py
@@ -1,7 +1,25 @@
 """Example script for computing body metrics on the test split of the AMASS dataset.
 
-This is a tidied version of the code we used to compute metrics in our paper.
-Some context: https://github.com/brentyi/egoallo/issues/7
+This is not the exact script we used for the paper metrics, but the details
+that matter should match. Below are some metrics from this script when our
+released checkpoint is passed in.
+
+For --subseq-len 128:
+
+    mpjpe                118.340 +/- 1.350   (in paper: 119.7 +/- 1.3)
+    pampjpe              100.026 +/- 1.349   (in paper: 101.1 +/- 1.3)
+    T_head                 0.006 +/- 0.000   (in paper: 0.0062 +/- 0.0001)
+    foot_contact (GND)     1.000 +/- 0.000   (in paper: 1.0 +/- 0.0)
+    foot_skate             0.417 +/- 0.017   (not reported in paper)
+
+
+For --subseq-len 32:
+
+    mpjpe                129.193 +/- 1.108   (in paper: 129.8 +/- 1.1)
+    pampjpe              109.489 +/- 1.147   (in paper: 109.8 +/- 1.1)
+    T_head                 0.006 +/- 0.000   (in paper: 0.0064 +/- 0.0001)
+    foot_contact (GND)     0.985 +/- 0.003   (in paper: 0.98 +/- 0.00)
+    foot_skate             0.185 +/- 0.005   (not reported in paper)
 """
 
 from pathlib import Path

From 94d7bab54090c03764518ef3a7e07912e373613c Mon Sep 17 00:00:00 2001
From: Brent Yi
Date: Tue, 10 Dec 2024 15:28:01 +0000
Subject: [PATCH 3/3] Add metrics helpers

---
 src/egoallo/metrics_helpers.py | 246 +++++++++++++++++++++++++++++++++
 1 file changed, 246 insertions(+)
 create mode 100644 src/egoallo/metrics_helpers.py

diff --git a/src/egoallo/metrics_helpers.py b/src/egoallo/metrics_helpers.py
new file mode 100644
index 0000000..f58460f
--- /dev/null
+++ b/src/egoallo/metrics_helpers.py
@@ -0,0 +1,246 @@
+from typing import Literal, overload
+
+import numpy as np
+import torch
+from jaxtyping import Float
+from torch import Tensor
+from typing_extensions import assert_never
+
+from .transforms import SO3
+
+
+def compute_foot_skate(
+    pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"],
+) -> np.ndarray:
+    (num_samples, time) = pred_Ts_world_joint.shape[:2]
+
+    # Drop the person to the floor.
+    # This is necessary for the foot skating metric to make sense for floating people...!
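+    # Following EgoEgo / kinpoly, the metric below sums the per-frame XY
+    # movement of each foot joint, masked to frames where the joint height h is
+    # under its contact threshold H and weighted by (2 - 2^(h / H)): full
+    # weight at h = 0, tapering to zero as h approaches H.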
+    pred_Ts_world_joint = pred_Ts_world_joint.clone()
+    pred_Ts_world_joint[..., 6] -= torch.min(pred_Ts_world_joint[..., 6])
+
+    foot_indices = torch.tensor([6, 7, 9, 10], device=pred_Ts_world_joint.device)
+
+    foot_positions = pred_Ts_world_joint[:, :, foot_indices, 4:7]
+    foot_positions_diff = foot_positions[:, 1:, :, :2] - foot_positions[:, :-1, :, :2]
+    assert foot_positions_diff.shape == (num_samples, time - 1, 4, 2)
+
+    foot_positions_diff_norm = torch.sum(torch.abs(foot_positions_diff), dim=-1)
+    assert foot_positions_diff_norm.shape == (num_samples, time - 1, 4)
+
+    # From EgoEgo / kinpoly.
+    H_thresh = torch.tensor(
+        # To match indices above: (ankle, ankle, toe, toe)
+        [0.08, 0.08, 0.04, 0.04],
+        device=pred_Ts_world_joint.device,
+        dtype=torch.float32,
+    )
+
+    # Threshold.
+    foot_positions_diff_norm = foot_positions_diff_norm * (
+        foot_positions[..., 1:, :, 2] < H_thresh
+    )
+    fs_per_sample = torch.sum(
+        torch.sum(
+            foot_positions_diff_norm
+            * (2 - 2 ** (foot_positions[..., 1:, :, 2] / H_thresh)),
+            dim=-1,
+        ),
+        dim=-1,
+    )
+    assert fs_per_sample.shape == (num_samples,)
+
+    return fs_per_sample.numpy(force=True)
+
+
+def compute_foot_contact(
+    pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"],
+) -> np.ndarray:
+    (num_samples, time) = pred_Ts_world_joint.shape[:2]
+
+    foot_indices = torch.tensor([6, 7, 9, 10], device=pred_Ts_world_joint.device)
+
+    # From EgoEgo / kinpoly.
+    H_thresh = torch.tensor(
+        # To match indices above: (ankle, ankle, toe, toe)
+        [0.08, 0.08, 0.04, 0.04],
+        device=pred_Ts_world_joint.device,
+        dtype=torch.float32,
+    )
+
+    foot_positions = pred_Ts_world_joint[:, :, foot_indices, 4:7]
+
+    any_contact = torch.any(
+        torch.any(foot_positions[..., 2] < H_thresh, dim=-1), dim=-1
+    ).to(torch.float32)
+    assert any_contact.shape == (num_samples,)
+
+    return any_contact.numpy(force=True)
+
+
+def compute_head_ori(
+    label_Ts_world_joint: Float[Tensor, "time 21 7"],
+    pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"],
+) -> np.ndarray:
+    (num_samples, time) = pred_Ts_world_joint.shape[:2]
+    matrix_errors = (
+        SO3(pred_Ts_world_joint[:, :, 14, :4]).as_matrix()
+        @ SO3(label_Ts_world_joint[:, 14, :4]).inverse().as_matrix()
+    ) - torch.eye(3, device=label_Ts_world_joint.device)
+    assert matrix_errors.shape == (num_samples, time, 3, 3)
+
+    return torch.mean(
+        torch.linalg.norm(matrix_errors.reshape((num_samples, time, 9)), dim=-1),
+        dim=-1,
+    ).numpy(force=True)
+
+
+def compute_head_trans(
+    label_Ts_world_joint: Float[Tensor, "time 21 7"],
+    pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"],
+) -> np.ndarray:
+    (num_samples, time) = pred_Ts_world_joint.shape[:2]
+    errors = pred_Ts_world_joint[:, :, 14, 4:7] - label_Ts_world_joint[:, 14, 4:7]
+    assert errors.shape == (num_samples, time, 3)
+
+    return torch.mean(
+        torch.linalg.norm(errors, dim=-1),
+        dim=-1,
+    ).numpy(force=True)
+
+
+def compute_mpjpe(
+    label_T_world_root: Float[Tensor, "time 7"],
+    label_Ts_world_joint: Float[Tensor, "time 21 7"],
+    pred_T_world_root: Float[Tensor, "num_samples time 7"],
+    pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"],
+    per_frame_procrustes_align: bool,
+) -> np.ndarray:
+    num_samples, time, _, _ = pred_Ts_world_joint.shape
+
+    # Concatenate the world root to the joints.
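+    # The root transform is unsqueezed to (..., 1, 7) and prepended, so the
+    # errors below are computed over 22 joints: the root plus the 21 body
+    # joints (this matches the shape asserts further down).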
+    label_Ts_world_joint = torch.cat(
+        [label_T_world_root[..., None, :], label_Ts_world_joint], dim=-2
+    )
+    pred_Ts_world_joint = torch.cat(
+        [pred_T_world_root[..., None, :], pred_Ts_world_joint], dim=-2
+    )
+    del label_T_world_root, pred_T_world_root
+
+    pred_joint_positions = pred_Ts_world_joint[:, :, :, 4:7]
+    label_joint_positions = label_Ts_world_joint[None, :, :, 4:7].repeat(
+        num_samples, 1, 1, 1
+    )
+
+    if per_frame_procrustes_align:
+        # Align the predicted joints to the labels before measuring error.
+        pred_joint_positions = procrustes_align(
+            points_y=label_joint_positions,
+            points_x=pred_joint_positions,
+            output="aligned_x",
+        )
+
+    position_differences = pred_joint_positions - label_joint_positions
+    assert position_differences.shape == (num_samples, time, 22, 3)
+
+    # Per-joint position errors, in millimeters.
+    pjpe = torch.linalg.norm(position_differences, dim=-1) * 1000.0
+    assert pjpe.shape == (num_samples, time, 22)
+
+    # Mean per-joint position errors.
+    mpjpe = torch.mean(pjpe.reshape((num_samples, -1)), dim=-1)
+    assert mpjpe.shape == (num_samples,)
+
+    return mpjpe.cpu().numpy()
+
+
+@overload
+def procrustes_align(
+    points_y: Float[Tensor, "*#batch N 3"],
+    points_x: Float[Tensor, "*#batch N 3"],
+    output: Literal["transforms"],
+    fix_scale: bool = False,
+) -> tuple[Tensor, Tensor, Tensor]: ...
+
+
+@overload
+def procrustes_align(
+    points_y: Float[Tensor, "*#batch N 3"],
+    points_x: Float[Tensor, "*#batch N 3"],
+    output: Literal["aligned_x"],
+    fix_scale: bool = False,
+) -> Tensor: ...
+
+
+def procrustes_align(
+    points_y: Float[Tensor, "*#batch N 3"],
+    points_x: Float[Tensor, "*#batch N 3"],
+    output: Literal["transforms", "aligned_x"],
+    fix_scale: bool = False,
+) -> tuple[Tensor, Tensor, Tensor] | Tensor:
+    """Similarity transform alignment using the Umeyama method. Adapted from
+    SLAHMR: https://github.com/vye16/slahmr/blob/main/slahmr/geometry/pcl.py
+
+    Minimizes:
+
+        mean( || Y - (s * (R @ X) + t) ||^2 )
+
+    with respect to s, R, and t. Returns an (s, R, t) tuple when output is
+    "transforms", and the aligned points s * (R @ X) + t when output is
+    "aligned_x".
+    """
+    *dims, N, _ = points_y.shape
+    device = points_y.device
+    N = torch.ones((*dims, 1, 1), device=device) * N
+
+    # Subtract mean.
+    my = points_y.sum(dim=-2) / N[..., 0]  # (*, 3)
+    mx = points_x.sum(dim=-2) / N[..., 0]
+    y0 = points_y - my[..., None, :]  # (*, N, 3)
+    x0 = points_x - mx[..., None, :]
+
+    # Correlation.
+    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
+    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)
+
+    S = (
+        torch.eye(3, device=device)
+        .reshape(*(1,) * (len(dims)), 3, 3)
+        .repeat(*dims, 1, 1)
+    )
+    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
+    S = torch.where(
+        neg.reshape(*dims, 1, 1),
+        S * torch.diag(torch.tensor([1, 1, -1], device=device)),
+        S,
+    )
+
+    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)
+
+    D = torch.diag_embed(D)  # (*, 3, 3)
+    if fix_scale:
+        s = torch.ones(*dims, 1, device=device, dtype=torch.float32)
+    else:
+        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
+        s = (
+            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
+                dim=-1, keepdim=True
+            )
+            / var[..., 0]
+        )  # (*, 1)
+
+    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)
+
+    assert s.shape == (*dims, 1)
+    assert R.shape == (*dims, 3, 3)
+    assert t.shape == (*dims, 3)
+
+    if output == "transforms":
+        return s, R, t
+    elif output == "aligned_x":
+        aligned_x = (
+            s[..., None, :] * torch.einsum("...ij,...nj->...ni", R, points_x)
+            + t[..., None, :]
+        )
+        assert aligned_x.shape == points_x.shape
+        return aligned_x
+    else:
+        assert_never(output)
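
A quick way to sanity-check the Umeyama alignment above is a synthetic round
trip: apply a known similarity transform to random points, then confirm that
procrustes_align recovers it. A minimal sketch, not part of the patches above;
it assumes torch is available and that egoallo is installed so that
egoallo.metrics_helpers is importable, and the transform values are arbitrary:

    import math

    import torch

    from egoallo.metrics_helpers import procrustes_align

    torch.manual_seed(0)
    points_x = torch.randn(100, 3)

    # A known similarity transform: scale 2, rotation about z, translation.
    c, sn = math.cos(0.3), math.sin(0.3)
    R_true = torch.tensor([[c, -sn, 0.0], [sn, c, 0.0], [0.0, 0.0, 1.0]])
    t_true = torch.tensor([0.5, -1.0, 2.0])
    points_y = 2.0 * points_x @ R_true.T + t_true

    # Recover (s, R, t); these should match the ground truth up to float error.
    s, R, t = procrustes_align(points_y, points_x, output="transforms")
    assert torch.allclose(s, torch.tensor([2.0]), atol=1e-4)
    assert torch.allclose(R, R_true, atol=1e-4)
    assert torch.allclose(t, t_true, atol=1e-4)

    # output="aligned_x" applies the recovered transform to points_x directly.
    aligned = procrustes_align(points_y, points_x, output="aligned_x")
    assert torch.allclose(aligned, points_y, atol=1e-4)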