Skip to content

Commit 2784829

Browse files
authored
Add eval script for body metrics (#9)
* Add eval script for body metrics * Add metrics * Add metrics helpers
1 parent b37f1fc commit 2784829

5 files changed

+425
-4
lines changed

5_eval_body_metrics.py

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""Example script for computing body metrics on the test split of the AMASS dataset.
2+
3+
This is not the exact script we used for the paper metrics, but should have the
4+
details that matter matched. Below are some metrics from this script when our
5+
released checkpoint is passed in.
6+
7+
For --subseq-len 128:
8+
9+
mpjpe 118.340 +/- 1.350 (in paper: 119.7 +/- 1.3)
10+
pampjpe 100.026 +/- 1.349 (in paper: 101.1 +/- 1.3)
11+
T_head 0.006 +/- 0.000 (in paper: 0.0062 +/- 0.0001)
12+
foot_contact (GND) 1.000 +/- 0.000 (in paper: 1.0 +/- 0.0)
13+
foot_skate 0.417 +/- 0.017 (not reported in paper)
14+
15+
16+
For --subseq-len 32:
17+
18+
mpjpe 129.193 +/- 1.108 (in paper: 129.8 +/- 1.1)
19+
pampjpe 109.489 +/- 1.147 (in paper: 109.8 +/- 1.1)
20+
T_head 0.006 +/- 0.000 (in paper: 0.0064 +/- 0.0001)
21+
foot_contact (GND) 0.985 +/- 0.003 (in paper: 0.98 +/- 0.00)
22+
foot_skate 0.185 +/- 0.005 (not reported in paper)
23+
"""
24+
25+
from pathlib import Path
26+
27+
import jax.tree
28+
import numpy as np
29+
import torch.optim.lr_scheduler
30+
import torch.utils.data
31+
import tyro
32+
33+
from egoallo import fncsmpl
34+
from egoallo.data.amass import EgoAmassHdf5Dataset
35+
from egoallo.fncsmpl_extensions import get_T_world_root_from_cpf_pose
36+
from egoallo.inference_utils import load_denoiser
37+
from egoallo.metrics_helpers import (
38+
compute_foot_contact,
39+
compute_foot_skate,
40+
compute_head_trans,
41+
compute_mpjpe,
42+
)
43+
from egoallo.sampling import run_sampling_with_stitching
44+
from egoallo.transforms import SE3, SO3
45+
46+
47+
def main(
48+
dataset_hdf5_path: Path,
49+
dataset_files_path: Path,
50+
subseq_len: int = 128,
51+
guidance_inner: bool = False,
52+
checkpoint_dir: Path = Path("./egoallo_checkpoint_april13/checkpoints_3000000/"),
53+
smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"),
54+
num_samples: int = 1,
55+
) -> None:
56+
"""Compute body metrics on the test split of the AMASS dataset."""
57+
device = torch.device("cuda")
58+
59+
# Setup.
60+
denoiser_network = load_denoiser(checkpoint_dir).to(device)
61+
dataset = EgoAmassHdf5Dataset(
62+
dataset_hdf5_path,
63+
dataset_files_path,
64+
splits=("test",),
65+
# We need an extra timestep in order to compute the relative CPF pose. (T_cpf_tm1_cpf_t)
66+
subseq_len=subseq_len + 1,
67+
cache_files=True,
68+
slice_strategy="deterministic",
69+
random_variable_len_proportion=0.0,
70+
)
71+
body_model = fncsmpl.SmplhModel.load(smplh_npz_path).to(device)
72+
73+
metrics = list[dict[str, np.ndarray]]()
74+
75+
for i in range(len(dataset)):
76+
sequence = dataset[i].to(device)
77+
78+
samples = run_sampling_with_stitching(
79+
denoiser_network,
80+
body_model=body_model,
81+
guidance_mode="no_hands",
82+
guidance_inner=guidance_inner,
83+
guidance_post=True,
84+
Ts_world_cpf=sequence.T_world_cpf,
85+
hamer_detections=None,
86+
aria_detections=None,
87+
num_samples=num_samples,
88+
floor_z=0.0,
89+
device=device,
90+
guidance_verbose=False,
91+
)
92+
93+
assert samples.hand_rotmats is not None
94+
assert samples.betas.shape == (num_samples, subseq_len, 16)
95+
assert samples.body_rotmats.shape == (num_samples, subseq_len, 21, 3, 3)
96+
assert samples.hand_rotmats.shape == (num_samples, subseq_len, 30, 3, 3)
97+
assert sequence.hand_quats is not None
98+
99+
# We'll only use the body joint rotations.
100+
pred_posed = body_model.with_shape(samples.betas).with_pose(
101+
T_world_root=SE3.identity(device, torch.float32).wxyz_xyz,
102+
local_quats=SO3.from_matrix(
103+
torch.cat([samples.body_rotmats, samples.hand_rotmats], dim=2)
104+
).wxyz,
105+
)
106+
pred_posed = pred_posed.with_new_T_world_root(
107+
get_T_world_root_from_cpf_pose(pred_posed, sequence.T_world_cpf[1:, ...])
108+
)
109+
110+
label_posed = body_model.with_shape(sequence.betas[1:, ...]).with_pose(
111+
sequence.T_world_root[1:, ...],
112+
torch.cat(
113+
[
114+
sequence.body_quats[1:, ...],
115+
sequence.hand_quats[1:, ...],
116+
],
117+
dim=1,
118+
),
119+
)
120+
121+
metrics.append(
122+
{
123+
"mpjpe": compute_mpjpe(
124+
label_T_world_root=label_posed.T_world_root,
125+
label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :],
126+
pred_T_world_root=pred_posed.T_world_root,
127+
pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],
128+
per_frame_procrustes_align=False,
129+
),
130+
"pampjpe": compute_mpjpe(
131+
label_T_world_root=label_posed.T_world_root,
132+
label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :],
133+
pred_T_world_root=pred_posed.T_world_root,
134+
pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],
135+
per_frame_procrustes_align=True,
136+
),
137+
# We didn't report foot skating metrics in the paper. It's not
138+
# really meaningful: since we optimize foot skating in the
139+
# guidance optimizer, it's easy to "cheat" this metric.
140+
"foot_skate": compute_foot_skate(
141+
pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],
142+
),
143+
"foot_contact (GND)": compute_foot_contact(
144+
pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],
145+
),
146+
"T_head": compute_head_trans(
147+
label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :],
148+
pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],
149+
),
150+
}
151+
)
152+
153+
print("=" * 80)
154+
print("=" * 80)
155+
print("=" * 80)
156+
print(f"Metrics ({i}/{len(dataset)} processed)")
157+
for k, v in jax.tree.map(
158+
lambda *x: f"{np.mean(x):.3f} +/- {np.std(x) / np.sqrt(len(metrics) * num_samples):.3f}",
159+
*metrics,
160+
).items():
161+
print("\t", k, v)
162+
print("=" * 80)
163+
print("=" * 80)
164+
print("=" * 80)
165+
166+
167+
if __name__ == "__main__":
168+
tyro.cli(main)

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ select = [
4545
"PLW", # Pylint warnings.
4646
]
4747
ignore = [
48+
"E731", # Do not assign a lambda expression, use a def.
4849
"E741", # Ambiguous variable name. (l, O, or I)
4950
"E501", # Line too long.
5051
"E721", # Do not compare types, use `isinstance()`.

src/egoallo/guidance_optimizer_jax.py

+6
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def do_guidance_optimization(
3939
phase: Literal["inner", "post"],
4040
hamer_detections: None | CorrespondedHamerDetections,
4141
aria_detections: None | CorrespondedAriaHandWristPoseDetections,
42+
verbose: bool,
4243
) -> tuple[network.EgoDenoiseTraj, dict]:
4344
"""Run an optimizer to apply foot contact constraints."""
4445

@@ -70,6 +71,7 @@ def do_guidance_optimization(
7071
aria_detections=None
7172
if aria_detections is None
7273
else aria_detections.as_nested_dict(numpy=True),
74+
verbose=verbose,
7375
)
7476
rotmats = SO3(
7577
torch.from_numpy(onp.array(quats))
@@ -122,6 +124,7 @@ def _optimize_vmapped(
122124
guidance_params: JaxGuidanceParams,
123125
hamer_detections: dict | None,
124126
aria_detections: dict | None,
127+
verbose: jdc.Static[bool],
125128
) -> tuple[jax.Array, dict]:
126129
return jax.vmap(
127130
partial(
@@ -131,6 +134,7 @@ def _optimize_vmapped(
131134
guidance_params=guidance_params,
132135
hamer_detections=hamer_detections,
133136
aria_detections=aria_detections,
137+
verbose=verbose,
134138
)
135139
)(
136140
betas=betas,
@@ -306,6 +310,7 @@ def _optimize(
306310
guidance_params: JaxGuidanceParams,
307311
hamer_detections: dict | None,
308312
aria_detections: dict | None,
313+
verbose: bool,
309314
) -> tuple[jax.Array, dict]:
310315
"""Apply constraints using Levenberg-Marquardt optimizer. Returns updated
311316
body_rotmats and hand_rotmats matrices."""
@@ -867,6 +872,7 @@ def skating_cost(
867872
lambda_initial=guidance_params.lambda_initial
868873
),
869874
termination=jaxls.TerminationConfig(max_iterations=guidance_params.max_iters),
875+
verbose=verbose,
870876
)
871877
out_body_quats = solutions[_SmplhBodyPosesVar]
872878
assert out_body_quats.shape == (timesteps, 21, 4)

0 commit comments

Comments
 (0)