Skip to content

Commit 74754bb

Browse files
bottlerfacebook-github-bot
authored andcommittedOct 23, 2022
voxel_grid_implicit_function
Reviewed By: shapovalov Differential Revision: D40622304 fbshipit-source-id: 277515a55c46d9b8300058b439526539a7fe00a0
1 parent 611aba9 commit 74754bb

File tree

4 files changed

+1008
-0
lines changed

4 files changed

+1008
-0
lines changed
 

‎projects/implicitron_trainer/tests/experiment.yaml

+162
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,168 @@ model_factory_ImplicitronModelFactory_args:
394394
in_features: 256
395395
out_features: 3
396396
ray_dir_in_camera_coords: false
397+
implicit_function_VoxelGridImplicitFunction_args:
398+
harmonic_embedder_xyz_density_args:
399+
n_harmonic_functions: 6
400+
omega_0: 1.0
401+
logspace: true
402+
append_input: true
403+
harmonic_embedder_xyz_color_args:
404+
n_harmonic_functions: 6
405+
omega_0: 1.0
406+
logspace: true
407+
append_input: true
408+
harmonic_embedder_dir_color_args:
409+
n_harmonic_functions: 6
410+
omega_0: 1.0
411+
logspace: true
412+
append_input: true
413+
decoder_density_class_type: MLPDecoder
414+
decoder_color_class_type: MLPDecoder
415+
use_multiple_streams: true
416+
xyz_ray_dir_in_camera_coords: false
417+
scaffold_calculating_epochs: []
418+
scaffold_resolution:
419+
- 128
420+
- 128
421+
- 128
422+
scaffold_empty_space_threshold: 0.001
423+
scaffold_occupancy_chunk_size: 'inf'
424+
scaffold_max_pool_kernel_size: 3
425+
scaffold_filter_points: true
426+
volume_cropping_epochs: []
427+
voxel_grid_density_args:
428+
voxel_grid_class_type: FullResolutionVoxelGrid
429+
extents:
430+
- 2.0
431+
- 2.0
432+
- 2.0
433+
translation:
434+
- 0.0
435+
- 0.0
436+
- 0.0
437+
init_std: 0.1
438+
init_mean: 0.0
439+
hold_voxel_grid_as_parameters: true
440+
param_groups: {}
441+
voxel_grid_CPFactorizedVoxelGrid_args:
442+
align_corners: true
443+
padding: zeros
444+
mode: bilinear
445+
n_features: 1
446+
resolution_changes:
447+
0:
448+
- 128
449+
- 128
450+
- 128
451+
n_components: 24
452+
basis_matrix: true
453+
voxel_grid_FullResolutionVoxelGrid_args:
454+
align_corners: true
455+
padding: zeros
456+
mode: bilinear
457+
n_features: 1
458+
resolution_changes:
459+
0:
460+
- 128
461+
- 128
462+
- 128
463+
voxel_grid_VMFactorizedVoxelGrid_args:
464+
align_corners: true
465+
padding: zeros
466+
mode: bilinear
467+
n_features: 1
468+
resolution_changes:
469+
0:
470+
- 128
471+
- 128
472+
- 128
473+
n_components: null
474+
distribution_of_components: null
475+
basis_matrix: true
476+
voxel_grid_color_args:
477+
voxel_grid_class_type: FullResolutionVoxelGrid
478+
extents:
479+
- 2.0
480+
- 2.0
481+
- 2.0
482+
translation:
483+
- 0.0
484+
- 0.0
485+
- 0.0
486+
init_std: 0.1
487+
init_mean: 0.0
488+
hold_voxel_grid_as_parameters: true
489+
param_groups: {}
490+
voxel_grid_CPFactorizedVoxelGrid_args:
491+
align_corners: true
492+
padding: zeros
493+
mode: bilinear
494+
n_features: 1
495+
resolution_changes:
496+
0:
497+
- 128
498+
- 128
499+
- 128
500+
n_components: 24
501+
basis_matrix: true
502+
voxel_grid_FullResolutionVoxelGrid_args:
503+
align_corners: true
504+
padding: zeros
505+
mode: bilinear
506+
n_features: 1
507+
resolution_changes:
508+
0:
509+
- 128
510+
- 128
511+
- 128
512+
voxel_grid_VMFactorizedVoxelGrid_args:
513+
align_corners: true
514+
padding: zeros
515+
mode: bilinear
516+
n_features: 1
517+
resolution_changes:
518+
0:
519+
- 128
520+
- 128
521+
- 128
522+
n_components: null
523+
distribution_of_components: null
524+
basis_matrix: true
525+
decoder_density_ElementwiseDecoder_args:
526+
scale: 1.0
527+
shift: 0.0
528+
operation: IDENTITY
529+
decoder_density_MLPDecoder_args:
530+
param_groups: {}
531+
network_args:
532+
n_layers: 8
533+
output_dim: 256
534+
skip_dim: 39
535+
hidden_dim: 256
536+
input_skips:
537+
- 5
538+
skip_affine_trans: false
539+
last_layer_bias_init: null
540+
last_activation: RELU
541+
use_xavier_init: true
542+
decoder_color_ElementwiseDecoder_args:
543+
scale: 1.0
544+
shift: 0.0
545+
operation: IDENTITY
546+
decoder_color_MLPDecoder_args:
547+
param_groups: {}
548+
network_args:
549+
n_layers: 8
550+
output_dim: 256
551+
skip_dim: 39
552+
hidden_dim: 256
553+
input_skips:
554+
- 5
555+
skip_affine_trans: false
556+
last_layer_bias_init: null
557+
last_activation: RELU
558+
use_xavier_init: true
397559
view_metrics_ViewMetrics_args: {}
398560
regularization_metrics_RegularizationMetrics_args: {}
399561
optimizer_factory_ImplicitronOptimizerFactory_args:

‎pytorch3d/implicitron/models/generic_model.py

+3
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@
5252
SRNHyperNetImplicitFunction,
5353
SRNImplicitFunction,
5454
)
55+
from .implicit_function.voxel_grid_implicit_function import ( # noqa
56+
VoxelGridImplicitFunction,
57+
)
5558

5659
from .renderer.base import (
5760
BaseRenderer,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,616 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
import warnings
9+
from dataclasses import fields
10+
from typing import Callable, Dict, Optional, Tuple, Union
11+
12+
import torch
13+
14+
from omegaconf import DictConfig
15+
16+
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
17+
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
18+
DecoderFunctionBase,
19+
)
20+
from pytorch3d.implicitron.models.implicit_function.voxel_grid import VoxelGridModule
21+
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
22+
from pytorch3d.implicitron.tools.config import (
23+
enable_get_default_args,
24+
get_default_args_field,
25+
registry,
26+
run_auto_creation,
27+
)
28+
from pytorch3d.renderer import ray_bundle_to_ray_points
29+
from pytorch3d.renderer.cameras import CamerasBase
30+
from pytorch3d.renderer.implicit import HarmonicEmbedding
31+
32+
enable_get_default_args(HarmonicEmbedding)
33+
34+
35+
@registry.register
36+
# pyre-ignore[13]
37+
class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
38+
"""
39+
This implicit function consists of two streams, one for the density calculation and one
40+
for the color calculation. Each of these streams has three main parts:
41+
1) Voxel grids:
42+
They take the (x, y, z) position and return the embedding of that point.
43+
These components are replaceable, you can make your own or choose one of
44+
several options.
45+
2) Harmonic embeddings:
46+
Convert each feature into series of 'harmonic features', feature is passed through
47+
sine and cosine functions. Input is of shape [minibatch, ..., D] output
48+
[minibatch, ..., (n_harmonic_functions * 2 + int(append_input)) * D]. Appends
49+
input by default. If you want it to behave like identity, put n_harmonic_functions=0
50+
and append_input=True.
51+
3) Decoding functions:
52+
The decoder is an instance of the DecoderFunctionBase and converts the embedding
53+
of a spatial location to density/color. Examples are Identity which returns its
54+
input and the MLP which uses fully connected nerual network to transform the input.
55+
These components are replaceable, you can make your own or choose from
56+
several options.
57+
58+
Calculating density is done in three steps:
59+
1) Evaluating the voxel grid on points
60+
2) Embedding the outputs with harmonic embedding
61+
3) Passing through the Density decoder
62+
63+
To calculate the color we need the embedding and the viewing direction, it has five steps:
64+
1) Transforming the viewing direction with camera
65+
2) Evaluating the voxel grid on points
66+
3) Embedding the outputs with harmonic embedding
67+
4) Embedding the normalized direction with harmonic embedding
68+
5) Passing everything through the Color decoder
69+
70+
If using the Implicitron configuration system the input_dim to the decoding functions will
71+
be set to the output_dim of the Harmonic embeddings.
72+
73+
A speed up comes from using the scaffold, a low resolution voxel grid.
74+
The scaffold is referenced as "binary occupancy grid mask" in TensoRF paper and "AlphaMask"
75+
in official TensoRF implementation.
76+
The scaffold is used in:
77+
1) filtering points in empty space
78+
- controlled by `scaffold_filter_points` boolean. If set to True, points for which
79+
scaffold predicts that are in empty space will return 0 density and
80+
(0, 0, 0) color.
81+
2) calculating the bounding box of an object and cropping the voxel grids
82+
- controlled by `volume_cropping_epochs`.
83+
- at those epochs the implicit function will find the bounding box of an object
84+
inside it and crop density and color grids. Cropping of the voxel grids means
85+
preserving only voxel values that are inside the bounding box and changing the
86+
resolution to match the original, while preserving the new cropped location in
87+
world coordinates.
88+
89+
The scaffold has to exist before attempting filtering and cropping, and is created on
90+
`scaffold_calculating_epochs`. Each voxel in the scaffold is labeled as having density 1 if
91+
the point in the center of it evaluates to greater than `scaffold_empty_space_threshold`.
92+
3D max pooling is performed on the densities of the points in 3D.
93+
Scaffold features are off by default.
94+
95+
Members:
96+
voxel_grid_density (VoxelGridBase): voxel grid to use for density estimation
97+
voxel_grid_color (VoxelGridBase): voxel grid to use for color estimation
98+
99+
harmonic_embedder_xyz_density (HarmonicEmbedder): Function to transform the outputs of
100+
the voxel_grid_density
101+
harmonic_embedder_xyz_color (HarmonicEmbedder): Function to transform the outputs of
102+
the voxel_grid_color for density
103+
harmonic_embedder_dir_color (HarmonicEmbedder): Function to transform the outputs of
104+
the voxel_grid_color for color
105+
106+
decoder_density (DecoderFunctionBase): decoder function to use for density estimation
107+
color_density (DecoderFunctionBase): decoder function to use for color estimation
108+
109+
use_multiple_streams (bool): if you want the density and color calculations to run on
110+
different cuda streams set this to True. Default True.
111+
xyz_ray_dir_in_camera_coords (bool): This is true if the directions are given in
112+
camera coordinates. Default False.
113+
114+
voxel_grid_scaffold (VoxelGridModule): which holds the scaffold. Extents and
115+
translation of it are set to those of voxel_grid_density.
116+
scaffold_calculating_epochs (Tuple[int, ...]): at which epochs to recalculate the
117+
scaffold. (The scaffold will be created automatically at the beginning of
118+
the calculation.)
119+
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
120+
voxel grid which stores scaffold
121+
scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
122+
this it will be considered as empty space and the scaffold at that point would
123+
evaluate as empty space.
124+
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
125+
at the same time. To calculate the scaffold we need to query `get_density()` at
126+
every voxel, this calculation can be split into scaffold depth number of xy plane
127+
calculations if you want the lowest memory usage, one calculation to calculate the
128+
whole scaffold, but with higher memory footprint or any other number of planes.
129+
Setting to 'inf' calculates all planes at the same time. Defaults to 'inf'.
130+
scaffold_max_pool_kernel_size (int): Size of the pooling region to use when
131+
calculating the scaffold. Defaults to 3.
132+
scaffold_filter_points (bool): If set to True the points will be filtered using
133+
`self.voxel_grid_scaffold`. Filtered points will be predicted as having 0 density
134+
and (0, 0, 0) color. The points which were not evaluated as empty space will be
135+
passed through the steps outlined above.
136+
volume_cropping_epochs: on which epochs to crop the voxel grids to fit the object's
137+
bounding box. Scaffold has to be calculated before cropping.
138+
"""
139+
140+
# ---- voxel grid for density
141+
voxel_grid_density: VoxelGridModule
142+
143+
# ---- voxel grid for color
144+
voxel_grid_color: VoxelGridModule
145+
146+
# ---- harmonic embeddings density
147+
harmonic_embedder_xyz_density_args: DictConfig = get_default_args_field(
148+
HarmonicEmbedding
149+
)
150+
harmonic_embedder_xyz_color_args: DictConfig = get_default_args_field(
151+
HarmonicEmbedding
152+
)
153+
harmonic_embedder_dir_color_args: DictConfig = get_default_args_field(
154+
HarmonicEmbedding
155+
)
156+
157+
# ---- decoder function for density
158+
decoder_density_class_type: str = "MLPDecoder"
159+
decoder_density: DecoderFunctionBase
160+
161+
# ---- decoder function for color
162+
decoder_color_class_type: str = "MLPDecoder"
163+
decoder_color: DecoderFunctionBase
164+
165+
# ---- cuda streams
166+
use_multiple_streams: bool = True
167+
168+
# ---- camera
169+
xyz_ray_dir_in_camera_coords: bool = False
170+
171+
# --- scaffold
172+
# voxel_grid_scaffold: VoxelGridModule
173+
scaffold_calculating_epochs: Tuple[int, ...] = ()
174+
scaffold_resolution: Tuple[int, int, int] = (128, 128, 128)
175+
scaffold_empty_space_threshold: float = 0.001
176+
scaffold_occupancy_chunk_size: Union[str, int] = "inf"
177+
scaffold_max_pool_kernel_size: int = 3
178+
scaffold_filter_points: bool = True
179+
180+
# --- cropping
181+
volume_cropping_epochs: Tuple[int, ...] = ()
182+
183+
def __post_init__(self) -> None:
184+
super().__init__()
185+
run_auto_creation(self)
186+
# pyre-ignore[16]
187+
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
188+
# pyre-ignore[16]
189+
self.harmonic_embedder_xyz_density = HarmonicEmbedding(
190+
**self.harmonic_embedder_xyz_density_args
191+
)
192+
# pyre-ignore[16]
193+
self.harmonic_embedder_xyz_color = HarmonicEmbedding(
194+
**self.harmonic_embedder_xyz_color_args
195+
)
196+
# pyre-ignore[16]
197+
self.harmonic_embedder_dir_color = HarmonicEmbedding(
198+
**self.harmonic_embedder_dir_color_args
199+
)
200+
# pyre-ignore[16]
201+
self._scaffold_ready = False
202+
if type(self.scaffold_occupancy_chunk_size) != int:
203+
if self.scaffold_occupancy_chunk_size != "inf":
204+
raise ValueError(
205+
"`scaffold_occupancy_chunk_size` has to be int or 'inf'."
206+
)
207+
208+
def forward(
209+
self,
210+
ray_bundle: ImplicitronRayBundle,
211+
fun_viewpool=None,
212+
camera: Optional[CamerasBase] = None,
213+
global_code=None,
214+
**kwargs,
215+
) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
216+
"""
217+
The forward function accepts the parametrizations of 3D points sampled along
218+
projection rays. The forward pass is responsible for attaching a 3D vector
219+
and a 1D scalar representing the point's RGB color and opacity respectively.
220+
221+
Args:
222+
ray_bundle: An ImplicitronRayBundle object containing the following variables:
223+
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
224+
origins of the sampling rays in world coords.
225+
directions: A tensor of shape `(minibatch, ..., 3)`
226+
containing the direction vectors of sampling rays in world coords.
227+
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
228+
containing the lengths at which the rays are sampled.
229+
fun_viewpool: an optional callback with the signature
230+
fun_fiewpool(points) -> pooled_features
231+
where points is a [N_TGT x N x 3] tensor of world coords,
232+
and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
233+
of the features pooled from the context images.
234+
camera: A camera model which will be used to transform the viewing
235+
directions
236+
237+
Returns:
238+
rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
239+
denoting the opacitiy of each ray point.
240+
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
241+
denoting the color of each ray point.
242+
"""
243+
# ########## convert the ray parametrizations to world coordinates ########## #
244+
# points.shape = [minibatch x n_rays_width x n_rays_height x pts_per_ray x 3]
245+
# pyre-ignore[6]
246+
points = ray_bundle_to_ray_points(ray_bundle)
247+
directions = ray_bundle.directions.reshape(-1, 3)
248+
input_shape = points.shape
249+
points = points.view(-1, 3)
250+
251+
# ########## filter the points using the scaffold ########## #
252+
if self._scaffold_ready and self.scaffold_filter_points:
253+
# pyre-ignore[29]
254+
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
255+
points = points[non_empty_points]
256+
directions = directions[non_empty_points]
257+
if len(points) == 0:
258+
warnings.warn(
259+
"The scaffold has filtered all the points."
260+
"The voxel grids and decoding functions will not be run."
261+
)
262+
return (
263+
points.new_zeros((*input_shape[:-1], 1)),
264+
points.new_zeros((*input_shape[:-1], 3)),
265+
{},
266+
)
267+
268+
# ########## calculate color and density ########## #
269+
rays_densities, rays_colors = self.calculate_density_and_color(
270+
points, directions, camera
271+
)
272+
273+
if not (self._scaffold_ready and self.scaffold_filter_points):
274+
return (
275+
rays_densities.view((*input_shape[:-1], rays_densities.shape[-1])),
276+
rays_colors.view((*input_shape[:-1], rays_colors.shape[-1])),
277+
{},
278+
)
279+
280+
# ########## merge scaffold calculated points ########## #
281+
# Create a zeroed tensor corresponding to a point with density=0 and fill it
282+
# with calculated density for points which are not in empty space. Do the
283+
# same for color
284+
rays_densities_combined = rays_densities.new_zeros(
285+
(math.prod(input_shape[:-1]), rays_densities.shape[-1])
286+
)
287+
rays_colors_combined = rays_colors.new_zeros(
288+
(math.prod(input_shape[:-1]), rays_colors.shape[-1])
289+
)
290+
# pyre-ignore[61]
291+
rays_densities_combined[non_empty_points] = rays_densities
292+
# pyre-ignore[61]
293+
rays_colors_combined[non_empty_points] = rays_colors
294+
295+
return (
296+
rays_densities_combined.view((*input_shape[:-1], rays_densities.shape[-1])),
297+
rays_colors_combined.view((*input_shape[:-1], rays_colors.shape[-1])),
298+
{},
299+
)
300+
301+
def calculate_density_and_color(
302+
self,
303+
points: torch.Tensor,
304+
directions: torch.Tensor,
305+
camera: Optional[CamerasBase] = None,
306+
) -> Tuple[torch.Tensor, torch.Tensor]:
307+
"""
308+
Calculates density and color at `points`.
309+
If enabled use cuda streams.
310+
311+
Args:
312+
points: points at which to calculate density and color.
313+
Tensor of shape [..., 3].
314+
directions: from which directions are the points viewed
315+
Tensor of shape [..., 3].
316+
camera: A camera model which will be used to transform the viewing
317+
directions
318+
Returns:
319+
Tuple of color (tensor of shape [..., 3]) and density
320+
(tensor of shape [..., 1])
321+
"""
322+
if self.use_multiple_streams and points.is_cuda:
323+
current_stream = torch.cuda.current_stream(points.device)
324+
other_stream = torch.cuda.Stream(points.device)
325+
other_stream.wait_stream(current_stream)
326+
327+
with torch.cuda.stream(other_stream):
328+
# rays_densities.shape =
329+
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
330+
rays_densities = self.get_density(points)
331+
332+
# rays_colors.shape =
333+
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
334+
rays_colors = self.get_color(points, camera, directions)
335+
336+
current_stream.wait_stream(other_stream)
337+
else:
338+
# Same calculation as above, just serial.
339+
rays_densities = self.get_density(points)
340+
rays_colors = self.get_color(points, camera, directions)
341+
return rays_densities, rays_colors
342+
343+
def get_density(self, points: torch.Tensor) -> torch.Tensor:
344+
"""
345+
Calculates density at points:
346+
1) Evaluates the voxel grid on points
347+
2) Embeds the outputs with harmonic embedding
348+
3) Passes everything through the Density decoder
349+
350+
Args:
351+
points: tensor of shape [..., 3]
352+
where the last dimension is the points in the (x, y, z)
353+
Returns:
354+
calculated densities of shape [..., density_dim], `density_dim` is the
355+
feature dimensionality which `decoder_density` returns
356+
"""
357+
embeds_density = self.voxel_grid_density(points)
358+
# pyre-ignore[29]
359+
harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
360+
# shape = [..., density_dim]
361+
return self.decoder_density(harmonic_embedding_density)
362+
363+
def get_color(
364+
self,
365+
points: torch.Tensor,
366+
camera: Optional[CamerasBase],
367+
directions: torch.Tensor,
368+
) -> torch.Tensor:
369+
"""
370+
Calculates color at points using the viewing direction:
371+
1) Transforms the viewing direction with camera
372+
2) Evaluates the voxel grid on points
373+
3) Embeds the outputs with harmonic embedding
374+
4) Embeds the normalized direction with harmonic embedding
375+
5) Passes everything through the Color decoder
376+
Args:
377+
points: tensor of shape (..., 3)
378+
where the last dimension is the points in the (x, y, z)
379+
camera: A camera model which will be used to transform the viewing
380+
directions
381+
directions: A tensor of shape `(..., 3)`
382+
containing the direction vectors of sampling rays in world coords.
383+
"""
384+
# ########## transform direction ########## #
385+
if self.xyz_ray_dir_in_camera_coords:
386+
if camera is None:
387+
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
388+
directions = directions @ camera.R
389+
390+
# ########## get voxel grid output ########## #
391+
# embeds_color.shape = [..., pts_per_ray, n_features]
392+
embeds_color = self.voxel_grid_color(points)
393+
394+
# ########## embed with the harmonic function ########## #
395+
# Obtain the harmonic embedding of the voxel grid output.
396+
# pyre-ignore[29]
397+
harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
398+
399+
# Normalize the ray_directions to unit l2 norm.
400+
rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
401+
# Obtain the harmonic embedding of the normalized ray directions.
402+
# pyre-ignore[29]
403+
harmonic_embedding_dir = self.harmonic_embedder_dir_color(
404+
rays_directions_normed
405+
)
406+
407+
n_rays = directions.shape[0]
408+
points_per_ray: int = points.shape[0] // n_rays
409+
410+
harmonic_embedding_dir = torch.repeat_interleave(
411+
harmonic_embedding_dir, points_per_ray, dim=0
412+
)
413+
414+
# total color embedding is concatenation of the harmonic embedding of voxel grid
415+
# output and harmonic embedding of the normalized direction
416+
total_color_embedding = torch.cat(
417+
(harmonic_embedding_color, harmonic_embedding_dir), dim=-1
418+
)
419+
420+
# ########## evaluate color with the decoding function ########## #
421+
# rays_colors.shape = [..., pts_per_ray, 3] in [0-1]
422+
return self.decoder_color(total_color_embedding)
423+
424+
@staticmethod
425+
def allows_multiple_passes() -> bool:
426+
"""
427+
Returns True as this implicit function allows
428+
multiple passes. Overridden from ImplicitFunctionBase.
429+
"""
430+
return True
431+
432+
def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
433+
"""
434+
Method which expresses interest in subscribing to optimization epoch updates.
435+
This implicit function subscribes to epochs to calculate the scaffold and to
436+
crop voxel grids, so this method combines wanted epochs and wraps their callbacks.
437+
438+
Returns:
439+
list of epochs on which to call a callable and callable to be called on
440+
particular epoch. The callable returns True if parameter change has
441+
happened else False and it must be supplied with one argument, epoch.
442+
"""
443+
444+
def callback(epoch) -> bool:
445+
change = False
446+
if epoch in self.scaffold_calculating_epochs:
447+
change = self._get_scaffold(epoch)
448+
if epoch in self.volume_cropping_epochs:
449+
change = self._crop(epoch) or change
450+
return change
451+
452+
# remove duplicates
453+
call_epochs = list(
454+
set(self.scaffold_calculating_epochs) | set(self.volume_cropping_epochs)
455+
)
456+
return call_epochs, callback
457+
458+
def _crop(self, epoch: int) -> bool:
459+
"""
460+
Finds the bounding box of an object represented in the scaffold and crops
461+
density and color voxel grids to match that bounding box. If density of the
462+
scaffold is 0 everywhere (there is no object in it) no change will
463+
happen.
464+
465+
Args:
466+
epoch: ignored
467+
Returns:
468+
True (indicating that parameter change has happened) if there is
469+
an object inside, else False.
470+
"""
471+
# find bounding box
472+
# pyre-ignore[16]
473+
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
474+
assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
475+
# pyre-ignore[29]
476+
occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
477+
non_zero_idxs = torch.nonzero(occupancy)
478+
if len(non_zero_idxs) == 0:
479+
return False
480+
min_indices = tuple(torch.min(non_zero_idxs, dim=0)[0])
481+
max_indices = tuple(torch.max(non_zero_idxs, dim=0)[0])
482+
min_point, max_point = points[min_indices], points[max_indices]
483+
484+
# crop the voxel grids
485+
self.voxel_grid_density.crop_self(min_point, max_point)
486+
self.voxel_grid_color.crop_self(min_point, max_point)
487+
return True
488+
489+
@torch.no_grad()
490+
def _get_scaffold(self, epoch: int) -> bool:
491+
"""
492+
Creates a low resolution grid which is used to filter points that are in empty
493+
space.
494+
495+
Args:
496+
epoch: epoch on which it is called, ignored inside method
497+
Returns:
498+
Always False: Modifies `self.voxel_grid_scaffold` member.
499+
"""
500+
501+
planes = []
502+
# pyre-ignore[16]
503+
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
504+
505+
chunk_size = (
506+
self.scaffold_occupancy_chunk_size
507+
if type(self.scaffold_occupancy_chunk_size) == int
508+
else points.shape[-1]
509+
)
510+
for k in range(0, points.shape[-1], chunk_size):
511+
points_in_planes = points[..., k : k + chunk_size]
512+
planes.append(self.get_density(points_in_planes)[..., 0])
513+
514+
density_cube = torch.cat(planes, dim=-1)
515+
density_cube = torch.nn.functional.max_pool3d(
516+
density_cube[None, None],
517+
kernel_size=self.scaffold_max_pool_kernel_size,
518+
padding=self.scaffold_max_pool_kernel_size // 2,
519+
stride=1,
520+
)
521+
occupancy_cube = density_cube > self.scaffold_empty_space_threshold
522+
# pyre-ignore[16]
523+
self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
524+
# pyre-ignore[16]
525+
self._scaffold_ready = True
526+
527+
return False
528+
529+
@classmethod
530+
def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
531+
args.pop("input_dim", None)
532+
533+
def create_decoder_density_impl(self, type, args: DictConfig) -> None:
534+
"""
535+
Decoding functions come after harmonic embedding and voxel grid. In order to not
536+
calculate the input dimension of the decoder in the config file this function
537+
calculates the required input dimension and sets the input dimension of the
538+
decoding function to this value.
539+
"""
540+
grid_args = self.voxel_grid_density_args
541+
# pyre-ignore[6]
542+
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
543+
544+
embedder_args = self.harmonic_embedder_xyz_density_args
545+
input_dim = HarmonicEmbedding.get_output_dim_static(
546+
grid_output_dim,
547+
embedder_args["n_harmonic_functions"],
548+
embedder_args["append_input"],
549+
)
550+
551+
cls = registry.get(DecoderFunctionBase, type)
552+
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
553+
if need_input_dim:
554+
self.decoder_density = cls(input_dim=input_dim, **args)
555+
else:
556+
self.decoder_density = cls(**args)
557+
558+
@classmethod
559+
def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
560+
args.pop("input_dim", None)
561+
562+
def create_decoder_color_impl(self, type, args: DictConfig) -> None:
563+
"""
564+
Decoding functions come after harmonic embedding and voxel grid. In order to not
565+
calculate the input dimension of the decoder in the config file this function
566+
calculates the required input dimension and sets the input dimension of the
567+
decoding function to this value.
568+
"""
569+
grid_args = self.voxel_grid_color_args
570+
# pyre-ignore[6]
571+
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
572+
573+
embedder_args = self.harmonic_embedder_xyz_color_args
574+
input_dim0 = HarmonicEmbedding.get_output_dim_static(
575+
grid_output_dim,
576+
embedder_args["n_harmonic_functions"],
577+
embedder_args["append_input"],
578+
)
579+
580+
dir_dim = 3
581+
embedder_args = self.harmonic_embedder_dir_color_args
582+
input_dim1 = HarmonicEmbedding.get_output_dim_static(
583+
dir_dim,
584+
embedder_args["n_harmonic_functions"],
585+
embedder_args["append_input"],
586+
)
587+
588+
input_dim = input_dim0 + input_dim1
589+
590+
cls = registry.get(DecoderFunctionBase, type)
591+
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
592+
if need_input_dim:
593+
self.decoder_color = cls(input_dim=input_dim, **args)
594+
else:
595+
self.decoder_color = cls(**args)
596+
597+
def _create_voxel_grid_scaffold(self) -> VoxelGridModule:
598+
"""
599+
Creates object to become self.voxel_grid_scaffold:
600+
- makes `self.voxel_grid_scaffold` have same world to local mapping as
601+
`self.voxel_grid_density`
602+
"""
603+
return VoxelGridModule(
604+
# pyre-ignore[29]
605+
extents=self.voxel_grid_density_args["extents"],
606+
# pyre-ignore[29]
607+
translation=self.voxel_grid_density_args["translation"],
608+
voxel_grid_class_type="FullResolutionVoxelGrid",
609+
hold_voxel_grid_as_parameters=False,
610+
voxel_grid_FullResolutionVoxelGrid_args={
611+
"resolution_changes": {0: self.scaffold_resolution},
612+
"padding": "zeros",
613+
"align_corners": True,
614+
"mode": "trilinear",
615+
},
616+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import unittest
9+
10+
import torch
11+
12+
from omegaconf import DictConfig, OmegaConf
13+
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (
14+
VoxelGridImplicitFunction,
15+
)
16+
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
17+
18+
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
19+
from pytorch3d.renderer import ray_bundle_to_ray_points
20+
from tests.common_testing import TestCaseMixin
21+
22+
23+
class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
24+
def setUp(self) -> None:
25+
torch.manual_seed(42)
26+
expand_args_fields(VoxelGridImplicitFunction)
27+
28+
def _get_simple_implicit_function(self, scaffold_res=16):
29+
default_cfg = get_default_args(VoxelGridImplicitFunction)
30+
custom_cfg = DictConfig(
31+
{
32+
"voxel_grid_density_args": {
33+
"voxel_grid_FullResolutionVoxelGrid_args": {"n_features": 7}
34+
},
35+
"decoder_density_class_type": "ElementwiseDecoder",
36+
"decoder_color_class_type": "MLPDecoder",
37+
"decoder_color_MLPDecoder_args": {
38+
"network_args": {
39+
"n_layers": 2,
40+
"output_dim": 3,
41+
"hidden_dim": 128,
42+
}
43+
},
44+
"scaffold_resolution": (scaffold_res, scaffold_res, scaffold_res),
45+
}
46+
)
47+
cfg = OmegaConf.merge(default_cfg, custom_cfg)
48+
return VoxelGridImplicitFunction(**cfg)
49+
50+
def test_forward(self) -> None:
51+
"""
52+
Test one forward of VoxelGridImplicitFunction.
53+
"""
54+
func = self._get_simple_implicit_function()
55+
56+
n_grids, n_points = 10, 9
57+
raybundle = ImplicitronRayBundle(
58+
origins=torch.randn(n_grids, 2, 3, 3),
59+
directions=torch.randn(n_grids, 2, 3, 3),
60+
lengths=torch.randn(n_grids, 2, 3, n_points),
61+
xys=0,
62+
)
63+
func(raybundle)
64+
65+
def test_scaffold_formation(self):
66+
"""
67+
Test calculating the scaffold.
68+
69+
We define a custom density function and make the implicit function use it
70+
After calculating the scaffold we compare the density of our custom
71+
density function with densities from the scaffold.
72+
"""
73+
device = "cuda" if torch.cuda.is_available() else "cpu"
74+
func = self._get_simple_implicit_function().to(device)
75+
func.scaffold_max_pool_kernel_size = 1
76+
77+
def new_density(points):
78+
"""
79+
Density function which returns 1 if p>(0.5, 0.5, 0.5) or
80+
p < (-0.5, -0.5, -0.5) else 0
81+
"""
82+
inshape = points.shape
83+
points = points.view(-1, 3)
84+
out = []
85+
for p in points:
86+
if torch.all(p > 0.5) or torch.all(p < -0.5):
87+
out.append(torch.tensor([[1.0]]))
88+
else:
89+
out.append(torch.tensor([[0.0]]))
90+
return torch.cat(out).view(*inshape[:-1], 1).to(device)
91+
92+
func.get_density = new_density
93+
func._get_scaffold(0)
94+
95+
points = torch.tensor(
96+
[
97+
[0, 0, 0],
98+
[1, 1, 1],
99+
[1, 0, 0],
100+
[0.1, 0, 0],
101+
[10, 1, -1],
102+
[-0.8, -0.7, -0.9],
103+
]
104+
).to(device)
105+
expected = new_density(points).float().to(device)
106+
assert torch.allclose(func.voxel_grid_scaffold(points), expected), (
107+
func.voxel_grid_scaffold(points),
108+
expected,
109+
)
110+
111+
def test_scaffold_filtering(self, n_test_points=100):
112+
"""
113+
Test that filtering points with scaffold works.
114+
115+
We define a scaffold and make the implicit function use it. We also
116+
define new density and color functions which check that all passed
117+
points are not in empty space (with scaffold function). In the end
118+
we compare the result from the implicit function with one calculated
119+
simple python, this checks that the points were merged correectly.
120+
"""
121+
device = "cuda"
122+
func = self._get_simple_implicit_function().to(device)
123+
124+
def scaffold(points):
125+
"""'
126+
Function to deterministically and randomly enough assign a point
127+
to empty or occupied space.
128+
Return 1 if second digit of sum after 0 is odd else 0
129+
"""
130+
return (
131+
((points.sum(dim=-1, keepdim=True) * 10**2 % 10).long() % 2) == 1
132+
).float()
133+
134+
def new_density(points):
135+
# check if all passed points should be passed here
136+
assert torch.all(scaffold(points)), (scaffold(points), points.shape)
137+
return points.sum(dim=-1, keepdim=True)
138+
139+
def new_color(points, camera, directions):
140+
# check if all passed points should be passed here
141+
assert torch.all(scaffold(points)) # , (scaffold(points), points)
142+
return points * 2
143+
144+
# check both computation paths that they contain only points
145+
# which are not in empty space
146+
func.get_density = new_density
147+
func.get_color = new_color
148+
func.voxel_grid_scaffold.forward = scaffold
149+
func._scaffold_ready = True
150+
151+
bundle = ImplicitronRayBundle(
152+
origins=torch.rand((n_test_points, 2, 1, 3), device=device),
153+
directions=torch.rand((n_test_points, 2, 1, 3), device=device),
154+
lengths=torch.rand((n_test_points, 2, 1, 4), device=device),
155+
xys=None,
156+
)
157+
points = ray_bundle_to_ray_points(bundle)
158+
result_density, result_color, _ = func(bundle)
159+
160+
# construct the wanted result 'by hand'
161+
flat_points = points.view(-1, 3)
162+
expected_result_density, expected_result_color = [], []
163+
for point in flat_points:
164+
if scaffold(point) == 1:
165+
expected_result_density.append(point.sum(dim=-1, keepdim=True))
166+
expected_result_color.append(point * 2)
167+
else:
168+
expected_result_density.append(point.new_zeros((1,)))
169+
expected_result_color.append(point.new_zeros((3,)))
170+
expected_result_density = torch.stack(expected_result_density, dim=0).view(
171+
*points.shape[:-1], 1
172+
)
173+
expected_result_color = torch.stack(expected_result_color, dim=0).view(
174+
*points.shape[:-1], 3
175+
)
176+
177+
# check that thre result is expected
178+
assert torch.allclose(result_density, expected_result_density), (
179+
result_density,
180+
expected_result_density,
181+
)
182+
assert torch.allclose(result_color, expected_result_color), (
183+
result_color,
184+
expected_result_color,
185+
)
186+
187+
def test_cropping(self, scaffold_res=9):
188+
"""
189+
Tests whether implicit function finds the bounding box of the object and sends
190+
correct min and max points to voxel grids for rescaling.
191+
"""
192+
device = "cuda" if torch.cuda.is_available() else "cpu"
193+
func = self._get_simple_implicit_function(scaffold_res=scaffold_res).to(device)
194+
195+
assert scaffold_res >= 8
196+
div = (scaffold_res - 1) / 2
197+
true_min_point = torch.tensor(
198+
[-3 / div, 0 / div, -3 / div],
199+
device=device,
200+
)
201+
true_max_point = torch.tensor(
202+
[1 / div, 2 / div, 3 / div],
203+
device=device,
204+
)
205+
206+
def new_scaffold(points):
207+
# 1 if between true_min and true_max point else 0
208+
# return points.new_ones((*points.shape[:-1], 1))
209+
return (
210+
torch.logical_and(true_min_point <= points, points <= true_max_point)
211+
.all(dim=-1)
212+
.float()[..., None]
213+
)
214+
215+
called_crop = []
216+
217+
def assert_min_max_points(min_point, max_point):
218+
called_crop.append(1)
219+
self.assertClose(min_point, true_min_point)
220+
self.assertClose(max_point, true_max_point)
221+
222+
func.voxel_grid_density.crop_self = assert_min_max_points
223+
func.voxel_grid_color.crop_self = assert_min_max_points
224+
func.voxel_grid_scaffold.forward = new_scaffold
225+
func._scaffold_ready = True
226+
func._crop(epoch=0)
227+
assert len(called_crop) == 2

0 commit comments

Comments
 (0)
Please sign in to comment.