Skip to content

Commit a123815

Browse files
bottlerfacebook-github-bot
authored andcommitted
join_pointclouds_as_scene
Summary: New function Reviewed By: davidsonic Differential Revision: D42776590 fbshipit-source-id: 2a6e73480bcf2d1749f86bcb22d1942e3e8d3167
1 parent d388881 commit a123815

File tree

3 files changed

+61
-5
lines changed

3 files changed

+61
-5
lines changed

pytorch3d/structures/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from .meshes import join_meshes_as_batch, join_meshes_as_scene, Meshes
8-
from .pointclouds import Pointclouds
8+
from .pointclouds import (
9+
join_pointclouds_as_batch,
10+
join_pointclouds_as_scene,
11+
Pointclouds,
12+
)
913
from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list
1014
from .volumes import Volumes
1115

pytorch3d/structures/pointclouds.py

+38
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,14 @@ def __init__(self, points, normals=None, features=None) -> None:
124124
normals:
125125
Can be either
126126
127+
- None
127128
- List where each element is a tensor of shape (num_points, 3)
128129
containing the normal vector for each point.
129130
- Padded float tensor of shape (num_clouds, num_points, 3).
130131
features:
131132
Can be either
132133
134+
- None
133135
- List where each element is a tensor of shape (num_points, C)
134136
containing the features for the points in the cloud.
135137
- Padded float tensor of shape (num_clouds, num_points, C).
@@ -1260,6 +1262,42 @@ def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]) -> Pointclouds
12601262
field_list = None
12611263
else:
12621264
field_list = [p for points in field_list for p in points]
1265+
if field == "features" and any(
1266+
p.shape[1] != field_list[0].shape[1] for p in field_list[1:]
1267+
):
1268+
raise ValueError("Pointclouds must have the same number of features")
12631269
kwargs[field] = field_list
12641270

12651271
return Pointclouds(**kwargs)
1272+
1273+
1274+
def join_pointclouds_as_scene(
1275+
pointclouds: Union[Pointclouds, List[Pointclouds]]
1276+
) -> Pointclouds:
1277+
"""
1278+
Joins a batch of point cloud in the form of a Pointclouds object or a list of Pointclouds
1279+
objects as a single point cloud. If the input is a list, the Pointclouds objects in the
1280+
list must all be on the same device, and they must either all or none have features and
1281+
all or none have normals.
1282+
1283+
Args:
1284+
Pointclouds: Pointclouds object that contains a batch of point clouds, or a list of
1285+
Pointclouds objects.
1286+
1287+
Returns:
1288+
new Pointclouds object containing a single point cloud
1289+
"""
1290+
if isinstance(pointclouds, list):
1291+
pointclouds = join_pointclouds_as_batch(pointclouds)
1292+
1293+
if len(pointclouds) == 1:
1294+
return pointclouds
1295+
points = pointclouds.points_packed()
1296+
features = pointclouds.features_packed()
1297+
normals = pointclouds.normals_packed()
1298+
pointcloud = Pointclouds(
1299+
points=points[None],
1300+
features=None if features is None else features[None],
1301+
normals=None if normals is None else normals[None],
1302+
)
1303+
return pointcloud

tests/test_pointclouds.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
import numpy as np
1212
import torch
1313
from pytorch3d.structures import utils as struct_utils
14-
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
14+
from pytorch3d.structures.pointclouds import (
15+
join_pointclouds_as_batch,
16+
join_pointclouds_as_scene,
17+
Pointclouds,
18+
)
1519

1620
from .common_testing import TestCaseMixin
1721

@@ -1159,9 +1163,9 @@ def check_triple(points, points3):
11591163
normals = [torch.rand(length, 3) for length in lengths]
11601164

11611165
# Test with normals and features present
1162-
pcl = Pointclouds(points=points, features=features, normals=normals)
1163-
pcl3 = join_pointclouds_as_batch([pcl] * 3)
1164-
check_triple(pcl, pcl3)
1166+
pcl1 = Pointclouds(points=points, features=features, normals=normals)
1167+
pcl3 = join_pointclouds_as_batch([pcl1] * 3)
1168+
check_triple(pcl1, pcl3)
11651169

11661170
# Test with normals and features present for tensor backed pointclouds
11671171
N, P, D = 5, 30, 4
@@ -1173,15 +1177,25 @@ def check_triple(points, points3):
11731177
pcl3 = join_pointclouds_as_batch([pcl] * 3)
11741178
check_triple(pcl, pcl3)
11751179

1180+
# Test with inconsistent #features
1181+
with self.assertRaisesRegex(ValueError, "same number of features"):
1182+
join_pointclouds_as_batch([pcl1, pcl])
1183+
11761184
# Test without normals
11771185
pcl_nonormals = Pointclouds(points=points, features=features)
11781186
pcl3 = join_pointclouds_as_batch([pcl_nonormals] * 3)
11791187
check_triple(pcl_nonormals, pcl3)
1188+
pcl_scene = join_pointclouds_as_scene([pcl_nonormals] * 3)
1189+
self.assertEqual(len(pcl_scene), 1)
1190+
self.assertClose(pcl_scene.features_packed(), pcl3.features_packed())
11801191

11811192
# Test without features
11821193
pcl_nofeats = Pointclouds(points=points, normals=normals)
11831194
pcl3 = join_pointclouds_as_batch([pcl_nofeats] * 3)
11841195
check_triple(pcl_nofeats, pcl3)
1196+
pcl_scene = join_pointclouds_as_scene([pcl_nofeats] * 3)
1197+
self.assertEqual(len(pcl_scene), 1)
1198+
self.assertClose(pcl_scene.normals_packed(), pcl3.normals_packed())
11851199

11861200
# Check error raised if all pointclouds in the batch
11871201
# are not consistent in including normals/features

0 commit comments

Comments
 (0)