Skip to content

Commit ae9d878

Browse files
cijosefacebook-github-bot
authored andcommitted
Support color in cubify
Summary: The diff support colors in cubify for align = "center" Reviewed By: bottler Differential Revision: D53777011 fbshipit-source-id: ccb2bd1e3d89be3d1ac943eff08f40e50b0540d9
1 parent 8772fe0 commit ae9d878

File tree

2 files changed

+73
-3
lines changed

2 files changed

+73
-3
lines changed

pytorch3d/ops/cubify.py

+33-3
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
# LICENSE file in the root directory of this source tree.
66

77

8+
from typing import Optional
9+
810
import torch
911
import torch.nn.functional as F
12+
1013
from pytorch3d.common.compat import meshgrid_ij
14+
1115
from pytorch3d.structures import Meshes
1216

1317

@@ -50,7 +54,14 @@ def ravel_index(idx, dims) -> torch.Tensor:
5054

5155

5256
@torch.no_grad()
53-
def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
57+
def cubify(
58+
voxels: torch.Tensor,
59+
thresh: float,
60+
*,
61+
feats: Optional[torch.Tensor] = None,
62+
device=None,
63+
align: str = "topleft"
64+
) -> Meshes:
5465
r"""
5566
Converts a voxel to a mesh by replacing each occupied voxel with a cube
5667
consisting of 12 faces and 8 vertices. Shared vertices are merged, and
@@ -59,6 +70,9 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
5970
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
6071
thresh: A scalar threshold. If a voxel occupancy is larger than
6172
thresh, the voxel is considered occupied.
73+
feats: A FloatTensor of shape (N, K, D, H, W) containing the color information
74+
of each voxel. K is the number of channels. This is supported only when
75+
align == "center"
6276
device: The device of the output meshes
6377
align: Defines the alignment of the mesh vertices and the grid locations.
6478
Has to be one of {"topleft", "corner", "center"}. See below for explanation.
@@ -177,6 +191,7 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
177191
# boolean to linear index
178192
# NF x 2
179193
linind = torch.nonzero(faces_idx, as_tuple=False)
194+
180195
# NF x 4
181196
nyxz = unravel_index(linind[:, 0], (N, H, W, D))
182197

@@ -238,6 +253,21 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
238253
grid_verts.index_select(0, (idleverts[n] == 0).nonzero(as_tuple=False)[:, 0])
239254
for n in range(N)
240255
]
241-
faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]
242256

243-
return Meshes(verts=verts_list, faces=faces_list)
257+
textures_list = None
258+
if feats is not None and align == "center":
259+
# We return a TexturesAtlas containing one color for each face
260+
# N x K x D x H x W -> N x H x W x D x K
261+
feats = feats.permute(0, 3, 4, 2, 1)
262+
263+
# (NHWD) x K
264+
feats = feats.reshape(-1, feats.size(4))
265+
feats = torch.index_select(feats, 0, linind[:, 0])
266+
feats = feats.reshape(-1, 1, 1, feats.size(1))
267+
feats_list = list(torch.split(feats, split_size.tolist(), 0))
268+
from pytorch3d.renderer.mesh.textures import TexturesAtlas
269+
270+
textures_list = TexturesAtlas(feats_list)
271+
272+
faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]
273+
return Meshes(verts=verts_list, faces=faces_list, textures=textures_list)

tests/test_cubify.py

+40
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010
from pytorch3d.ops import cubify
11+
from pytorch3d.renderer.mesh.textures import TexturesAtlas
1112

1213
from .common_testing import TestCaseMixin
1314

@@ -313,3 +314,42 @@ def convert():
313314
torch.cuda.synchronize()
314315

315316
return convert
317+
318+
def test_cubify_with_feats(self):
319+
N, V = 3, 2
320+
device = torch.device("cuda:0")
321+
voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device)
322+
feats = torch.zeros((N, 3, V, V, V), dtype=torch.float32, device=device)
323+
# fill the feats with red color
324+
feats[:, 0, :, :, :] = 255
325+
326+
# 1st example: (top left corner, znear) is on
327+
voxels[0, 0, 0, 0] = 1.0
328+
# the color is set to green
329+
feats[0, :, 0, 0, 0] = torch.Tensor([0, 255, 0])
330+
# 2nd example: all are on
331+
voxels[1] = 1.0
332+
333+
# 3rd example
334+
voxels[2, :, :, 1] = 1.0
335+
voxels[2, 1, 1, 0] = 1.0
336+
# the color is set to yellow and blue respectively
337+
feats[2, 1, :, :, 1] = 255
338+
feats[2, :, 1, 1, 0] = torch.Tensor([0, 0, 255])
339+
meshes = cubify(voxels, 0.5, feats=feats, align="center")
340+
textures = meshes.textures
341+
self.assertTrue(textures is not None)
342+
self.assertTrue(isinstance(textures, TexturesAtlas))
343+
faces_textures = textures.faces_verts_textures_packed()
344+
red = faces_textures.new_tensor([255.0, 0.0, 0.0])
345+
green = faces_textures.new_tensor([0.0, 255.0, 0.0])
346+
blue = faces_textures.new_tensor([0.0, 0.0, 255.0])
347+
yellow = faces_textures.new_tensor([255.0, 255.0, 0.0])
348+
349+
self.assertEqual(faces_textures.shape, (100, 3, 3))
350+
faces_textures_ = faces_textures.flatten(end_dim=1)
351+
self.assertClose(faces_textures_[:36], green.expand(36, -1))
352+
self.assertClose(faces_textures_[36:180], red.expand(144, -1))
353+
self.assertClose(faces_textures_[180:228], yellow.expand(48, -1))
354+
self.assertClose(faces_textures_[228:258], blue.expand(30, -1))
355+
self.assertClose(faces_textures_[258:300], yellow.expand(42, -1))

0 commit comments

Comments
 (0)