Skip to content

Commit 8fe6934

Browse files
bottlerfacebook-github-bot
authored andcommitted
fix subdivide_meshes with empty mesh #1788
Summary: Simplify code fixes #1788 Reviewed By: MichaelRamamonjisoa Differential Revision: D61847675 fbshipit-source-id: 48400875d1d885bb3615bc9f4b3c7c3d822b67e7
1 parent c434957 commit 8fe6934

File tree

2 files changed

+22
-78
lines changed

2 files changed

+22
-78
lines changed

pytorch3d/ops/subdivide_meshes.py

Lines changed: 13 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -353,45 +353,16 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
353353
# e.g. verts_per_mesh = (4, 5, 6)
354354
# e.g. edges_per_mesh = (5, 7, 9)
355355

356-
V = verts_per_mesh.sum() # e.g. 15
357-
E = edges_per_mesh.sum() # e.g. 21
358-
359-
verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0) # (N,) e.g. (4, 9, 15)
360-
edges_per_mesh_cumsum = edges_per_mesh.cumsum(dim=0) # (N,) e.g. (5, 12, 21)
361-
362-
v_to_e_idx = verts_per_mesh_cumsum.clone()
363-
364-
# vertex to edge index.
365-
v_to_e_idx[1:] += edges_per_mesh_cumsum[
366-
:-1
367-
] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
368-
369-
# vertex to edge offset.
370-
v_to_e_offset = V - verts_per_mesh_cumsum # e.g. 15 - (4, 9, 15) = (11, 6, 0)
371-
v_to_e_offset[1:] += edges_per_mesh_cumsum[
372-
:-1
373-
] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
374-
e_to_v_idx = (
375-
verts_per_mesh_cumsum[:-1] + edges_per_mesh_cumsum[:-1]
376-
) # (4, 9) + (5, 12) = (9, 21)
377-
e_to_v_offset = (
378-
verts_per_mesh_cumsum[:-1] - edges_per_mesh_cumsum[:-1] - V
379-
) # (4, 9) - (5, 12) - 15 = (-16, -18)
380-
381-
# Add one new vertex per edge.
382-
idx_diffs = torch.ones(V + E, device=device, dtype=torch.int64) # (36,)
383-
idx_diffs[v_to_e_idx] += v_to_e_offset
384-
idx_diffs[e_to_v_idx] += e_to_v_offset
385-
386-
# e.g.
387-
# [
388-
# 1, 1, 1, 1, 12, 1, 1, 1, 1,
389-
# -15, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 1,
390-
# -17, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 1, 1
391-
# ]
392-
393-
verts_idx = idx_diffs.cumsum(dim=0) - 1
394-
356+
rng = torch.arange(verts_per_mesh.shape[0], device=device) # (0,1,2)
357+
verts_nums = rng.repeat_interleave(
358+
verts_per_mesh
359+
) # (0,0,0,0,1,1,1,1,1,2,2,2,2,2,2)
360+
edges_nums = rng.repeat_interleave(
361+
edges_per_mesh
362+
) # (0,0,0,0,0,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2)
363+
nums = torch.cat([verts_nums, edges_nums])
364+
365+
verts_idx = torch.argsort(nums, stable=True)
395366
# e.g.
396367
# [
397368
# 0, 1, 2, 3, 15, 16, 17, 18, 19, --> mesh 0
@@ -400,7 +371,6 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
400371
# ]
401372
# where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
402373
# [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
403-
404374
return verts_idx
405375

406376

@@ -421,44 +391,9 @@ def _create_faces_index(faces_per_mesh: torch.Tensor, device=None):
421391
"""
422392
# e.g. faces_per_mesh = [2, 5, 3]
423393

424-
F = faces_per_mesh.sum() # e.g. 10
425-
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
426-
427-
switch1_idx = faces_per_mesh_cumsum.clone()
428-
switch1_idx[1:] += (
429-
3 * faces_per_mesh_cumsum[:-1]
430-
) # e.g. (2, 7, 10) + (0, 6, 21) = (2, 13, 31)
431-
432-
switch2_idx = 2 * faces_per_mesh_cumsum # e.g. (4, 14, 20)
433-
switch2_idx[1:] += (
434-
2 * faces_per_mesh_cumsum[:-1]
435-
) # e.g. (4, 14, 20) + (0, 4, 14) = (4, 18, 34)
436-
437-
switch3_idx = 3 * faces_per_mesh_cumsum # e.g. (6, 21, 30)
438-
switch3_idx[1:] += faces_per_mesh_cumsum[
439-
:-1
440-
] # e.g. (6, 21, 30) + (0, 2, 7) = (6, 23, 37)
441-
442-
switch4_idx = 4 * faces_per_mesh_cumsum[:-1] # e.g. (8, 28)
443-
444-
switch123_offset = F - faces_per_mesh # e.g. (8, 5, 7)
445-
446-
# pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
447-
# typing.Tuple[int, ...]]` but got `Tensor`.
448-
idx_diffs = torch.ones(4 * F, device=device, dtype=torch.int64)
449-
idx_diffs[switch1_idx] += switch123_offset
450-
idx_diffs[switch2_idx] += switch123_offset
451-
idx_diffs[switch3_idx] += switch123_offset
452-
idx_diffs[switch4_idx] -= 3 * F
453-
454-
# e.g
455-
# [
456-
# 1, 1, 9, 1, 9, 1, 9, 1, -> mesh 0
457-
# -29, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, -> mesh 1
458-
# -29, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1 -> mesh 2
459-
# ]
460-
461-
faces_idx = idx_diffs.cumsum(dim=0) - 1
394+
rng = torch.arange(faces_per_mesh.shape[0], device=device) # (0,1,2)
395+
nums = rng.repeat_interleave(faces_per_mesh).repeat(4)
396+
faces_idx = torch.argsort(nums, stable=True)
462397

463398
# e.g.
464399
# [

tests/test_subdivide_meshes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,15 @@ def test_subdivide_features(self):
217217
self.assertClose(new_feats, gt_feats)
218218
self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad)
219219

220+
def test_with_empty(self):
221+
verts_list = [[[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], []]
222+
faces_list = [[[0, 1, 2], [0, 2, 3]], []]
223+
verts_list = [torch.tensor(verts, dtype=torch.float64) for verts in verts_list]
224+
face_list = [torch.tensor(faces, dtype=torch.long) for faces in faces_list]
225+
meshes = Meshes(verts=verts_list, faces=face_list)
226+
subdivided_meshes = SubdivideMeshes()(meshes)
227+
self.assertEqual(len(subdivided_meshes), 2)
228+
220229
@staticmethod
221230
def subdivide_meshes_with_init(num_meshes: int = 10, same_topo: bool = False):
222231
device = torch.device("cuda:0")

0 commit comments

Comments
 (0)