5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
8
+ from typing import Optional
9
+
8
10
import torch
9
11
import torch .nn .functional as F
12
+
10
13
from pytorch3d .common .compat import meshgrid_ij
14
+
11
15
from pytorch3d .structures import Meshes
12
16
13
17
@@ -50,7 +54,14 @@ def ravel_index(idx, dims) -> torch.Tensor:
50
54
51
55
52
56
@torch .no_grad ()
53
- def cubify (voxels , thresh , device = None , align : str = "topleft" ) -> Meshes :
57
+ def cubify (
58
+ voxels : torch .Tensor ,
59
+ thresh : float ,
60
+ * ,
61
+ feats : Optional [torch .Tensor ] = None ,
62
+ device = None ,
63
+ align : str = "topleft"
64
+ ) -> Meshes :
54
65
r"""
55
66
Converts a voxel to a mesh by replacing each occupied voxel with a cube
56
67
consisting of 12 faces and 8 vertices. Shared vertices are merged, and
@@ -59,6 +70,9 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
59
70
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
60
71
thresh: A scalar threshold. If a voxel occupancy is larger than
61
72
thresh, the voxel is considered occupied.
73
+ feats: A FloatTensor of shape (N, K, D, H, W) containing the color information
74
+ of each voxel. K is the number of channels. This is supported only when
75
+ align == "center"
62
76
device: The device of the output meshes
63
77
align: Defines the alignment of the mesh vertices and the grid locations.
64
78
Has to be one of {"topleft", "corner", "center"}. See below for explanation.
@@ -177,6 +191,7 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
177
191
# boolean to linear index
178
192
# NF x 2
179
193
linind = torch .nonzero (faces_idx , as_tuple = False )
194
+
180
195
# NF x 4
181
196
nyxz = unravel_index (linind [:, 0 ], (N , H , W , D ))
182
197
@@ -238,6 +253,21 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
238
253
grid_verts .index_select (0 , (idleverts [n ] == 0 ).nonzero (as_tuple = False )[:, 0 ])
239
254
for n in range (N )
240
255
]
241
- faces_list = [nface - idlenum [n ][nface ] for n , nface in enumerate (faces_list )]
242
256
243
- return Meshes (verts = verts_list , faces = faces_list )
257
+ textures_list = None
258
+ if feats is not None and align == "center" :
259
+ # We return a TexturesAtlas containing one color for each face
260
+ # N x K x D x H x W -> N x H x W x D x K
261
+ feats = feats .permute (0 , 3 , 4 , 2 , 1 )
262
+
263
+ # (NHWD) x K
264
+ feats = feats .reshape (- 1 , feats .size (4 ))
265
+ feats = torch .index_select (feats , 0 , linind [:, 0 ])
266
+ feats = feats .reshape (- 1 , 1 , 1 , feats .size (1 ))
267
+ feats_list = list (torch .split (feats , split_size .tolist (), 0 ))
268
+ from pytorch3d .renderer .mesh .textures import TexturesAtlas
269
+
270
+ textures_list = TexturesAtlas (feats_list )
271
+
272
+ faces_list = [nface - idlenum [n ][nface ] for n , nface in enumerate (faces_list )]
273
+ return Meshes (verts = verts_list , faces = faces_list , textures = textures_list )
0 commit comments