From 2fefc0939c6fd982d2de06c4eccad8fa8c824d8d Mon Sep 17 00:00:00 2001 From: iulia Date: Fri, 8 May 2020 17:27:31 -0400 Subject: [PATCH] draw_graph fix --- deepposekit/utils/keypoints.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/deepposekit/utils/keypoints.py b/deepposekit/utils/keypoints.py index cb89186..c3d799c 100644 --- a/deepposekit/utils/keypoints.py +++ b/deepposekit/utils/keypoints.py @@ -66,6 +66,7 @@ def draw_graph(keypoints, height, width, output_shape, graph, sigma=1, linewidth edge_confidence_list = [] for idx, label in enumerate(labels): lines = graph[edge_labels == label] + keep_idx = np.where(lines != -1)[0] lines_idx = np.where(edge_labels == label)[0] edge_confidence = np.zeros((out_height, out_width, lines.shape[0])) zeros = np.zeros((height, width, 1), dtype=np.uint8) @@ -73,20 +74,23 @@ def draw_graph(keypoints, height, width, output_shape, graph, sigma=1, linewidth if line >= 0: pt1 = keypoints[line_idx] pt2 = keypoints[line] - line_map = cv2.line( - zeros.copy(), - (int(pt1[0]), int(pt1[1])), - (int(pt2[0]), int(pt2[1])), - 1, - linewidth, - lineType=cv2.LINE_AA, - ) - blurred = cv2.GaussianBlur( - line_map.astype(np.float64), (height + 1, width + 1), sigma - ) - resized = cv2.resize(blurred, (out_width, out_height)) + MACHINE_EPSILON - edge_confidence[..., jdx] = resized - edge_confidence = edge_confidence[..., 1:] + nan_pt1 = np.any(pt1 < 0) + nan_pt2 = np.any(pt2 < 0) + if not (nan_pt1 or nan_pt2): + line_map = cv2.line( + zeros.copy(), + (int(pt1[0]), int(pt1[1])), + (int(pt2[0]), int(pt2[1])), + 1, + linewidth, + lineType=cv2.LINE_AA, + ) + blurred = cv2.GaussianBlur( + line_map.astype(np.float64), (height + 1, width + 1), sigma + ) + resized = cv2.resize(blurred, (out_width, out_height)) + MACHINE_EPSILON + edge_confidence[..., jdx] = resized + edge_confidence = edge_confidence[:,:,keep_idx] edge_confidence_list.append(edge_confidence) confidence[..., idx] = edge_confidence.sum(-1) edge_confidence = np.concatenate(edge_confidence_list, -1)