Skip to content

Commit f319e3e

Browse files
authored
Feature/detect format function (#144)
* CDL: minor doc typo fix * Undoing some changes that got mixed in * Add detect_pose_format function and SupportedPoseFormat Literal * detect_known_pose_format and tests for it. * various cleanup changes, style changes * missing import * undo black formatting for face contours and ignore_names * SupportedPoseFormat->KnownPoseFormat * Unreachable raise ValueErrors fixed * generic utils type annotations * change detect_known_format to take Pose or PoseHeader * Reraise ImportError if mediapipe is not installed * conftest update to supply unknown-format fake poses * nicer formatting for plane_info and line_info * fix import in generic_test.py * add some pylint disables, consistent with pose-evaluation * Change import in conftest.py * change import style in generic.py * change more imports * Fix a few type issues * Change matrix strategy fail-fast to false, so that we can still run tests if Python 3.8 does not work * Union for type annotation backwards compatibility * Add checks for NotImplementedError * Fix correct_wrist modifying input, and wrong shape for stacked conf. Also added a function to check fake_pose and its outputs * Simplify get_component_names and fix spacing * fix test_get_component_names
1 parent 56e6717 commit f319e3e

File tree

12 files changed

+419
-74
lines changed

12 files changed

+419
-74
lines changed

.github/workflows/python.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
16+
fail-fast: false
1617

1718
steps:
1819
- uses: actions/checkout@v3

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
.idea/
22
.DS_Store
3+
.vscode/
4+
.coverage
5+
.coveragerc
6+
coverage.lcov

src/python/ComfyUI-Pose-Format/nodes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import cv2
44
import torch
5-
from pose_format import Pose
5+
from pose_format.pose import Pose
66
from pose_format.pose_visualizer import PoseVisualizer
77
from pose_format.utils.generic import reduce_holistic
88
from pose_format.utils.openpose import OpenPose_Components

src/python/pose_format/bin/pose_visualizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import argparse
44
import os
55

6-
from pose_format import Pose
6+
from pose_format.pose import Pose
77
from pose_format.pose_visualizer import PoseVisualizer
88
from pose_format.utils.generic import pose_normalization_info
99

src/python/pose_format/pose.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from itertools import chain
2-
from typing import BinaryIO, Dict, List, Tuple, Type
2+
from typing import BinaryIO, Dict, List, Tuple, Type, Union
33

44
import numpy as np
55
import numpy.ma as ma
@@ -87,7 +87,7 @@ def focus(self):
8787
dimensions = (maxs - mins).tolist()
8888
self.header.dimensions = PoseHeaderDimensions(*dimensions)
8989

90-
def normalize(self, info: PoseNormalizationInfo|None=None, scale_factor: float = 1) -> "Pose":
90+
def normalize(self, info: Union[PoseNormalizationInfo,None]=None, scale_factor: float = 1) -> "Pose":
9191
"""
9292
Normalize the points to a fixed distance between two particular points.
9393
@@ -203,7 +203,7 @@ def frame_dropout_normal(self, dropout_mean: float = 0.5, dropout_std: float = 0
203203
body, selected_indexes = self.body.frame_dropout_normal(dropout_mean=dropout_mean, dropout_std=dropout_std)
204204
return Pose(header=self.header, body=body), selected_indexes
205205

206-
def get_components(self, components: List[str], points: Dict[str, List[str]] = None):
206+
def get_components(self, components: List[str], points: Union[Dict[str, List[str]],None] = None):
207207
"""
208208
get pose components based on criteria.
209209
+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import copy
2+
from typing import List, get_args
3+
import pytest
4+
from pose_format.pose import Pose
5+
from pose_format.utils.generic import get_standard_components_for_known_format, fake_pose, KnownPoseFormat
6+
7+
@pytest.fixture
8+
def fake_poses(request) -> List[Pose]:
9+
# Access the parameter passed to the fixture
10+
known_format = request.param
11+
count = getattr(request, "count", 3)
12+
known_formats = get_args(KnownPoseFormat)
13+
if known_format in known_formats:
14+
15+
components = get_standard_components_for_known_format(known_format)
16+
return copy.deepcopy([fake_pose(i * 10 + 10, components=components) for i in range(count)])
17+
else:
18+
# get openpose
19+
fake_poses_list = [fake_pose(i * 10 + 10) for i in range(count)]
20+
for i, pose in enumerate(fake_poses_list):
21+
for component in pose.header.components:
22+
component.name = f"unknown_component_{i}_formerly_{component.name}"
23+
return copy.deepcopy(fake_poses_list)

0 commit comments

Comments
 (0)