@@ -353,45 +353,16 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
353
353
# e.g. verts_per_mesh = (4, 5, 6)
354
354
# e.g. edges_per_mesh = (5, 7, 9)
355
355
356
- V = verts_per_mesh .sum () # e.g. 15
357
- E = edges_per_mesh .sum () # e.g. 21
358
-
359
- verts_per_mesh_cumsum = verts_per_mesh .cumsum (dim = 0 ) # (N,) e.g. (4, 9, 15)
360
- edges_per_mesh_cumsum = edges_per_mesh .cumsum (dim = 0 ) # (N,) e.g. (5, 12, 21)
361
-
362
- v_to_e_idx = verts_per_mesh_cumsum .clone ()
363
-
364
- # vertex to edge index.
365
- v_to_e_idx [1 :] += edges_per_mesh_cumsum [
366
- :- 1
367
- ] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
368
-
369
- # vertex to edge offset.
370
- v_to_e_offset = V - verts_per_mesh_cumsum # e.g. 15 - (4, 9, 15) = (11, 6, 0)
371
- v_to_e_offset [1 :] += edges_per_mesh_cumsum [
372
- :- 1
373
- ] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
374
- e_to_v_idx = (
375
- verts_per_mesh_cumsum [:- 1 ] + edges_per_mesh_cumsum [:- 1 ]
376
- ) # (4, 9) + (5, 12) = (9, 21)
377
- e_to_v_offset = (
378
- verts_per_mesh_cumsum [:- 1 ] - edges_per_mesh_cumsum [:- 1 ] - V
379
- ) # (4, 9) - (5, 12) - 15 = (-16, -18)
380
-
381
- # Add one new vertex per edge.
382
- idx_diffs = torch .ones (V + E , device = device , dtype = torch .int64 ) # (36,)
383
- idx_diffs [v_to_e_idx ] += v_to_e_offset
384
- idx_diffs [e_to_v_idx ] += e_to_v_offset
385
-
386
- # e.g.
387
- # [
388
- # 1, 1, 1, 1, 12, 1, 1, 1, 1,
389
- # -15, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 1,
390
- # -17, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 1, 1
391
- # ]
392
-
393
- verts_idx = idx_diffs .cumsum (dim = 0 ) - 1
394
-
356
+ rng = torch .arange (verts_per_mesh .shape [0 ], device = device ) # (0,1,2)
357
+ verts_nums = rng .repeat_interleave (
358
+ verts_per_mesh
359
+ ) # (0,0,0,0,1,1,1,1,1,2,2,2,2,2,2)
360
+ edges_nums = rng .repeat_interleave (
361
+ edges_per_mesh
362
+ ) # (0,0,0,0,0,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2)
363
+ nums = torch .cat ([verts_nums , edges_nums ])
364
+
365
+ verts_idx = torch .argsort (nums , stable = True )
395
366
# e.g.
396
367
# [
397
368
# 0, 1, 2, 3, 15, 16, 17, 18, 19, --> mesh 0
@@ -400,7 +371,6 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
400
371
# ]
401
372
# where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
402
373
# [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
403
-
404
374
return verts_idx
405
375
406
376
@@ -421,44 +391,9 @@ def _create_faces_index(faces_per_mesh: torch.Tensor, device=None):
421
391
"""
422
392
# e.g. faces_per_mesh = [2, 5, 3]
423
393
424
- F = faces_per_mesh .sum () # e.g. 10
425
- faces_per_mesh_cumsum = faces_per_mesh .cumsum (dim = 0 ) # (N,) e.g. (2, 7, 10)
426
-
427
- switch1_idx = faces_per_mesh_cumsum .clone ()
428
- switch1_idx [1 :] += (
429
- 3 * faces_per_mesh_cumsum [:- 1 ]
430
- ) # e.g. (2, 7, 10) + (0, 6, 21) = (2, 13, 31)
431
-
432
- switch2_idx = 2 * faces_per_mesh_cumsum # e.g. (4, 14, 20)
433
- switch2_idx [1 :] += (
434
- 2 * faces_per_mesh_cumsum [:- 1 ]
435
- ) # e.g. (4, 14, 20) + (0, 4, 14) = (4, 18, 34)
436
-
437
- switch3_idx = 3 * faces_per_mesh_cumsum # e.g. (6, 21, 30)
438
- switch3_idx [1 :] += faces_per_mesh_cumsum [
439
- :- 1
440
- ] # e.g. (6, 21, 30) + (0, 2, 7) = (6, 23, 37)
441
-
442
- switch4_idx = 4 * faces_per_mesh_cumsum [:- 1 ] # e.g. (8, 28)
443
-
444
- switch123_offset = F - faces_per_mesh # e.g. (8, 5, 7)
445
-
446
- # pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
447
- # typing.Tuple[int, ...]]` but got `Tensor`.
448
- idx_diffs = torch .ones (4 * F , device = device , dtype = torch .int64 )
449
- idx_diffs [switch1_idx ] += switch123_offset
450
- idx_diffs [switch2_idx ] += switch123_offset
451
- idx_diffs [switch3_idx ] += switch123_offset
452
- idx_diffs [switch4_idx ] -= 3 * F
453
-
454
- # e.g
455
- # [
456
- # 1, 1, 9, 1, 9, 1, 9, 1, -> mesh 0
457
- # -29, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, -> mesh 1
458
- # -29, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1 -> mesh 2
459
- # ]
460
-
461
- faces_idx = idx_diffs .cumsum (dim = 0 ) - 1
394
+ rng = torch .arange (faces_per_mesh .shape [0 ], device = device ) # (0,1,2)
395
+ nums = rng .repeat_interleave (faces_per_mesh ).repeat (4 )
396
+ faces_idx = torch .argsort (nums , stable = True )
462
397
463
398
# e.g.
464
399
# [
0 commit comments