|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import math |
| 8 | +import warnings |
| 9 | +from dataclasses import fields |
| 10 | +from typing import Callable, Dict, Optional, Tuple, Union |
| 11 | + |
| 12 | +import torch |
| 13 | + |
| 14 | +from omegaconf import DictConfig |
| 15 | + |
| 16 | +from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase |
| 17 | +from pytorch3d.implicitron.models.implicit_function.decoding_functions import ( |
| 18 | + DecoderFunctionBase, |
| 19 | +) |
| 20 | +from pytorch3d.implicitron.models.implicit_function.voxel_grid import VoxelGridModule |
| 21 | +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle |
| 22 | +from pytorch3d.implicitron.tools.config import ( |
| 23 | + enable_get_default_args, |
| 24 | + get_default_args_field, |
| 25 | + registry, |
| 26 | + run_auto_creation, |
| 27 | +) |
| 28 | +from pytorch3d.renderer import ray_bundle_to_ray_points |
| 29 | +from pytorch3d.renderer.cameras import CamerasBase |
| 30 | +from pytorch3d.renderer.implicit import HarmonicEmbedding |
| 31 | + |
| 32 | +enable_get_default_args(HarmonicEmbedding) |
| 33 | + |
| 34 | + |
| 35 | +@registry.register |
| 36 | +# pyre-ignore[13] |
| 37 | +class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): |
| 38 | + """ |
| 39 | + This implicit function consists of two streams, one for the density calculation and one |
| 40 | + for the color calculation. Each of these streams has three main parts: |
| 41 | + 1) Voxel grids: |
| 42 | + They take the (x, y, z) position and return the embedding of that point. |
| 43 | +        These components are replaceable: you can make your own or choose one of
| 44 | + several options. |
| 45 | + 2) Harmonic embeddings: |
| 46 | +        Convert each feature into a series of 'harmonic features', i.e. the feature is
| 47 | +        passed through sine and cosine functions. The input is of shape [minibatch, ..., D],
| 48 | +        the output of shape [minibatch, ..., (n_harmonic_functions * 2 + int(append_input)) * D].
| 49 | +        The input is appended by default. To make this step behave like the identity, set
| 50 | +        n_harmonic_functions=0 and append_input=True. (See the short example after this list.)
| 51 | + 3) Decoding functions: |
| 52 | +        The decoder is an instance of DecoderFunctionBase and converts the embedding
| 53 | +        of a spatial location to density/color. Examples are Identity, which returns its
| 54 | +        input, and the MLP, which uses a fully connected neural network to transform the input.
| 55 | +        These components are replaceable: you can make your own or choose from
| 56 | +        several options.
| 57 | +
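    As an informal illustration of the harmonic embedding step (a sketch with example
    values, not a configuration used by this class), n_harmonic_functions=2 with
    append_input=True maps a D-dimensional feature to a (2 * 2 + 1) * D dimensional one:

        harmonic_embedder = HarmonicEmbedding(n_harmonic_functions=2, append_input=True)
        embedding = harmonic_embedder(torch.randn(8, 10, 4))
        assert embedding.shape == (8, 10, (2 * 2 + 1) * 4)
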
| 58 | + Calculating density is done in three steps: |
| 59 | + 1) Evaluating the voxel grid on points |
| 60 | + 2) Embedding the outputs with harmonic embedding |
| 61 | + 3) Passing through the Density decoder |
| 62 | +
| 63 | +    To calculate the color we need the embedding and the viewing direction; this is done in five steps:
| 64 | +    1) Transforming the viewing direction with the camera
| 65 | + 2) Evaluating the voxel grid on points |
| 66 | + 3) Embedding the outputs with harmonic embedding |
| 67 | + 4) Embedding the normalized direction with harmonic embedding |
| 68 | + 5) Passing everything through the Color decoder |
| 69 | +
| 70 | +    If using the Implicitron configuration system, the input_dim of the decoding functions
| 71 | +    will be set to the output_dim of the harmonic embeddings.
| 72 | +
| 73 | +    A speed-up comes from using the scaffold, a low-resolution voxel grid.
| 74 | +    The scaffold is referred to as the "binary occupancy grid mask" in the TensoRF paper and as
| 75 | +    the "AlphaMask" in the official TensoRF implementation.
| 76 | + The scaffold is used in: |
| 77 | + 1) filtering points in empty space |
| 78 | +        - controlled by the `scaffold_filter_points` boolean. If set to True, points which
| 79 | +          the scaffold predicts to be in empty space will return 0 density and
| 80 | + (0, 0, 0) color. |
| 81 | + 2) calculating the bounding box of an object and cropping the voxel grids |
| 82 | + - controlled by `volume_cropping_epochs`. |
| 83 | + - at those epochs the implicit function will find the bounding box of an object |
| 84 | +          inside it and crop the density and color grids. Cropping of the voxel grids means
| 85 | + preserving only voxel values that are inside the bounding box and changing the |
| 86 | + resolution to match the original, while preserving the new cropped location in |
| 87 | + world coordinates. |
| 88 | +
| 89 | +    The scaffold has to exist before attempting filtering and cropping, and is created at
| 90 | +    `scaffold_calculating_epochs`. Each voxel in the scaffold is labeled as having density 1 if
| 91 | +    the density at its center point evaluates to greater than `scaffold_empty_space_threshold`.
| 92 | +    3D max pooling is performed on these point densities before thresholding.
| 93 | + Scaffold features are off by default. |
| 94 | +
| 95 | + Members: |
| 96 | + voxel_grid_density (VoxelGridBase): voxel grid to use for density estimation |
| 97 | + voxel_grid_color (VoxelGridBase): voxel grid to use for color estimation |
| 98 | +
| 99 | +        harmonic_embedder_xyz_density (HarmonicEmbedding): Function to transform the outputs of
| 100 | +            the voxel_grid_density in the density stream
| 101 | +        harmonic_embedder_xyz_color (HarmonicEmbedding): Function to transform the outputs of
| 102 | +            the voxel_grid_color in the color stream
| 103 | +        harmonic_embedder_dir_color (HarmonicEmbedding): Function to transform the normalized
| 104 | +            viewing directions in the color stream
| 105 | +
| 106 | + decoder_density (DecoderFunctionBase): decoder function to use for density estimation |
| 107 | +        decoder_color (DecoderFunctionBase): decoder function to use for color estimation
| 108 | +
| 109 | + use_multiple_streams (bool): if you want the density and color calculations to run on |
| 110 | +            different CUDA streams, set this to True. Default True.
| 111 | + xyz_ray_dir_in_camera_coords (bool): This is true if the directions are given in |
| 112 | + camera coordinates. Default False. |
| 113 | +
| 114 | +        voxel_grid_scaffold (VoxelGridModule): the voxel grid module which holds the scaffold.
| 115 | +            Its extents and translation are set to those of voxel_grid_density.
| 116 | + scaffold_calculating_epochs (Tuple[int, ...]): at which epochs to recalculate the |
| 117 | + scaffold. (The scaffold will be created automatically at the beginning of |
| 118 | + the calculation.) |
| 119 | + scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying |
| 120 | +            voxel grid which stores the scaffold
| 121 | + scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than |
| 122 | +            this value at a point, that point is considered to be in empty space and the
| 123 | +            scaffold at that point will evaluate as empty space.
| 124 | + scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate |
| 125 | + at the same time. To calculate the scaffold we need to query `get_density()` at |
| 126 | +            every voxel. This calculation can be split into chunks of xy planes: one plane per
| 127 | +            chunk gives the lowest memory usage, computing the whole scaffold in a single chunk
| 128 | +            has a higher memory footprint, and any other number of planes is also possible.
| 129 | + Setting to 'inf' calculates all planes at the same time. Defaults to 'inf'. |
| 130 | + scaffold_max_pool_kernel_size (int): Size of the pooling region to use when |
| 131 | + calculating the scaffold. Defaults to 3. |
| 132 | + scaffold_filter_points (bool): If set to True the points will be filtered using |
| 133 | + `self.voxel_grid_scaffold`. Filtered points will be predicted as having 0 density |
| 134 | + and (0, 0, 0) color. The points which were not evaluated as empty space will be |
| 135 | + passed through the steps outlined above. |
| 136 | +        volume_cropping_epochs (Tuple[int, ...]): on which epochs to crop the voxel grids to fit the object's
| 137 | + bounding box. Scaffold has to be calculated before cropping. |
| 138 | + """ |
| 139 | + |
| 140 | + # ---- voxel grid for density |
| 141 | + voxel_grid_density: VoxelGridModule |
| 142 | + |
| 143 | + # ---- voxel grid for color |
| 144 | + voxel_grid_color: VoxelGridModule |
| 145 | + |
| 146 | + # ---- harmonic embeddings density |
| 147 | + harmonic_embedder_xyz_density_args: DictConfig = get_default_args_field( |
| 148 | + HarmonicEmbedding |
| 149 | + ) |
| 150 | + harmonic_embedder_xyz_color_args: DictConfig = get_default_args_field( |
| 151 | + HarmonicEmbedding |
| 152 | + ) |
| 153 | + harmonic_embedder_dir_color_args: DictConfig = get_default_args_field( |
| 154 | + HarmonicEmbedding |
| 155 | + ) |
| 156 | + |
| 157 | + # ---- decoder function for density |
| 158 | + decoder_density_class_type: str = "MLPDecoder" |
| 159 | + decoder_density: DecoderFunctionBase |
| 160 | + |
| 161 | + # ---- decoder function for color |
| 162 | + decoder_color_class_type: str = "MLPDecoder" |
| 163 | + decoder_color: DecoderFunctionBase |
| 164 | + |
| 165 | + # ---- cuda streams |
| 166 | + use_multiple_streams: bool = True |
| 167 | + |
| 168 | + # ---- camera |
| 169 | + xyz_ray_dir_in_camera_coords: bool = False |
| 170 | + |
| 171 | + # --- scaffold |
| 172 | + # voxel_grid_scaffold: VoxelGridModule |
| 173 | + scaffold_calculating_epochs: Tuple[int, ...] = () |
| 174 | + scaffold_resolution: Tuple[int, int, int] = (128, 128, 128) |
| 175 | + scaffold_empty_space_threshold: float = 0.001 |
| 176 | + scaffold_occupancy_chunk_size: Union[str, int] = "inf" |
| 177 | + scaffold_max_pool_kernel_size: int = 3 |
| 178 | + scaffold_filter_points: bool = True |
| 179 | + |
| 180 | + # --- cropping |
| 181 | + volume_cropping_epochs: Tuple[int, ...] = () |
| 182 | + |
| 183 | + def __post_init__(self) -> None: |
| 184 | + super().__init__() |
| 185 | + run_auto_creation(self) |
| 186 | + # pyre-ignore[16] |
| 187 | + self.voxel_grid_scaffold = self._create_voxel_grid_scaffold() |
| 188 | + # pyre-ignore[16] |
| 189 | + self.harmonic_embedder_xyz_density = HarmonicEmbedding( |
| 190 | + **self.harmonic_embedder_xyz_density_args |
| 191 | + ) |
| 192 | + # pyre-ignore[16] |
| 193 | + self.harmonic_embedder_xyz_color = HarmonicEmbedding( |
| 194 | + **self.harmonic_embedder_xyz_color_args |
| 195 | + ) |
| 196 | + # pyre-ignore[16] |
| 197 | + self.harmonic_embedder_dir_color = HarmonicEmbedding( |
| 198 | + **self.harmonic_embedder_dir_color_args |
| 199 | + ) |
| 200 | + # pyre-ignore[16] |
| 201 | + self._scaffold_ready = False |
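        # The scaffold is only populated (and used for filtering/cropping) once
        # _get_scaffold() has run on one of `scaffold_calculating_epochs`.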
| 202 | + if type(self.scaffold_occupancy_chunk_size) != int: |
| 203 | + if self.scaffold_occupancy_chunk_size != "inf": |
| 204 | + raise ValueError( |
| 205 | + "`scaffold_occupancy_chunk_size` has to be int or 'inf'." |
| 206 | + ) |
| 207 | + |
| 208 | + def forward( |
| 209 | + self, |
| 210 | + ray_bundle: ImplicitronRayBundle, |
| 211 | + fun_viewpool=None, |
| 212 | + camera: Optional[CamerasBase] = None, |
| 213 | + global_code=None, |
| 214 | + **kwargs, |
| 215 | + ) -> Tuple[torch.Tensor, torch.Tensor, Dict]: |
| 216 | + """ |
| 217 | + The forward function accepts the parametrizations of 3D points sampled along |
| 218 | + projection rays. The forward pass is responsible for attaching a 3D vector |
| 219 | + and a 1D scalar representing the point's RGB color and opacity respectively. |
| 220 | +
| 221 | + Args: |
| 222 | + ray_bundle: An ImplicitronRayBundle object containing the following variables: |
| 223 | + origins: A tensor of shape `(minibatch, ..., 3)` denoting the |
| 224 | + origins of the sampling rays in world coords. |
| 225 | + directions: A tensor of shape `(minibatch, ..., 3)` |
| 226 | + containing the direction vectors of sampling rays in world coords. |
| 227 | + lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)` |
| 228 | + containing the lengths at which the rays are sampled. |
| 229 | + fun_viewpool: an optional callback with the signature |
| 230 | +                fun_viewpool(points) -> pooled_features
| 231 | + where points is a [N_TGT x N x 3] tensor of world coords, |
| 232 | + and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor |
| 233 | + of the features pooled from the context images. |
| 234 | + camera: A camera model which will be used to transform the viewing |
| 235 | + directions |
| 236 | +
| 237 | + Returns: |
| 238 | + rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)` |
| 239 | +                denoting the opacity of each ray point.
| 240 | + rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)` |
| 241 | + denoting the color of each ray point. |
| 242 | + """ |
| 243 | + # ########## convert the ray parametrizations to world coordinates ########## # |
| 244 | + # points.shape = [minibatch x n_rays_width x n_rays_height x pts_per_ray x 3] |
| 245 | + # pyre-ignore[6] |
| 246 | + points = ray_bundle_to_ray_points(ray_bundle) |
| 247 | + directions = ray_bundle.directions.reshape(-1, 3) |
| 248 | + input_shape = points.shape |
| 249 | + points = points.view(-1, 3) |
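        # After flattening, `points` holds n_rays * pts_per_ray rows while `directions`
        # holds one row per ray; get_color() repeats the direction embeddings accordingly.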
| 250 | + |
| 251 | + # ########## filter the points using the scaffold ########## # |
| 252 | + if self._scaffold_ready and self.scaffold_filter_points: |
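            # The scaffold stores a 0/1 occupancy value per voxel; a positive value at
            # a sampled point means the point may be occupied and is kept for evaluation.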
| 253 | + # pyre-ignore[29] |
| 254 | + non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0 |
| 255 | + points = points[non_empty_points] |
| 256 | + directions = directions[non_empty_points] |
| 257 | + if len(points) == 0: |
| 258 | + warnings.warn( |
| 259 | +                    "The scaffold has filtered all the points. "
| 260 | + "The voxel grids and decoding functions will not be run." |
| 261 | + ) |
| 262 | + return ( |
| 263 | + points.new_zeros((*input_shape[:-1], 1)), |
| 264 | + points.new_zeros((*input_shape[:-1], 3)), |
| 265 | + {}, |
| 266 | + ) |
| 267 | + |
| 268 | + # ########## calculate color and density ########## # |
| 269 | + rays_densities, rays_colors = self.calculate_density_and_color( |
| 270 | + points, directions, camera |
| 271 | + ) |
| 272 | + |
| 273 | + if not (self._scaffold_ready and self.scaffold_filter_points): |
| 274 | + return ( |
| 275 | + rays_densities.view((*input_shape[:-1], rays_densities.shape[-1])), |
| 276 | + rays_colors.view((*input_shape[:-1], rays_colors.shape[-1])), |
| 277 | + {}, |
| 278 | + ) |
| 279 | + |
| 280 | + # ########## merge scaffold calculated points ########## # |
| 281 | + # Create a zeroed tensor corresponding to a point with density=0 and fill it |
| 282 | + # with calculated density for points which are not in empty space. Do the |
| 283 | + # same for color |
| 284 | + rays_densities_combined = rays_densities.new_zeros( |
| 285 | + (math.prod(input_shape[:-1]), rays_densities.shape[-1]) |
| 286 | + ) |
| 287 | + rays_colors_combined = rays_colors.new_zeros( |
| 288 | + (math.prod(input_shape[:-1]), rays_colors.shape[-1]) |
| 289 | + ) |
| 290 | + # pyre-ignore[61] |
| 291 | + rays_densities_combined[non_empty_points] = rays_densities |
| 292 | + # pyre-ignore[61] |
| 293 | + rays_colors_combined[non_empty_points] = rays_colors |
| 294 | + |
| 295 | + return ( |
| 296 | + rays_densities_combined.view((*input_shape[:-1], rays_densities.shape[-1])), |
| 297 | + rays_colors_combined.view((*input_shape[:-1], rays_colors.shape[-1])), |
| 298 | + {}, |
| 299 | + ) |
| 300 | + |
| 301 | + def calculate_density_and_color( |
| 302 | + self, |
| 303 | + points: torch.Tensor, |
| 304 | + directions: torch.Tensor, |
| 305 | + camera: Optional[CamerasBase] = None, |
| 306 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 307 | + """ |
| 308 | + Calculates density and color at `points`. |
| 309 | +        If enabled, uses separate CUDA streams for the two calculations.
| 310 | +
| 311 | + Args: |
| 312 | + points: points at which to calculate density and color. |
| 313 | + Tensor of shape [..., 3]. |
| 314 | +            directions: the directions from which the points are viewed.
| 315 | + Tensor of shape [..., 3]. |
| 316 | + camera: A camera model which will be used to transform the viewing |
| 317 | + directions |
| 318 | + Returns: |
| 319 | +            Tuple of density (tensor of shape [..., 1]) and color
| 320 | +                (tensor of shape [..., 3])
| 321 | + """ |
| 322 | + if self.use_multiple_streams and points.is_cuda: |
| 323 | + current_stream = torch.cuda.current_stream(points.device) |
| 324 | + other_stream = torch.cuda.Stream(points.device) |
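            # Ensure the side stream sees all work already queued on the current
            # stream (such as the kernels that produced `points`) before branching off.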
| 325 | + other_stream.wait_stream(current_stream) |
| 326 | + |
| 327 | + with torch.cuda.stream(other_stream): |
| 328 | + # rays_densities.shape = |
| 329 | + # [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim] |
| 330 | + rays_densities = self.get_density(points) |
| 331 | + |
| 332 | + # rays_colors.shape = |
| 333 | + # [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim] |
| 334 | + rays_colors = self.get_color(points, camera, directions) |
| 335 | + |
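            # Colors were computed on the current stream; wait for the density
            # computation on the side stream to finish before returning both results.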
| 336 | + current_stream.wait_stream(other_stream) |
| 337 | + else: |
| 338 | + # Same calculation as above, just serial. |
| 339 | + rays_densities = self.get_density(points) |
| 340 | + rays_colors = self.get_color(points, camera, directions) |
| 341 | + return rays_densities, rays_colors |
| 342 | + |
| 343 | + def get_density(self, points: torch.Tensor) -> torch.Tensor: |
| 344 | + """ |
| 345 | + Calculates density at points: |
| 346 | + 1) Evaluates the voxel grid on points |
| 347 | + 2) Embeds the outputs with harmonic embedding |
| 348 | + 3) Passes everything through the Density decoder |
| 349 | +
| 350 | + Args: |
| 351 | + points: tensor of shape [..., 3] |
| 352 | +                where the last dimension holds the (x, y, z) coordinates of the points
| 353 | + Returns: |
| 354 | + calculated densities of shape [..., density_dim], `density_dim` is the |
| 355 | + feature dimensionality which `decoder_density` returns |
| 356 | + """ |
| 357 | + embeds_density = self.voxel_grid_density(points) |
| 358 | + # pyre-ignore[29] |
| 359 | + harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density) |
| 360 | + # shape = [..., density_dim] |
| 361 | + return self.decoder_density(harmonic_embedding_density) |
| 362 | + |
| 363 | + def get_color( |
| 364 | + self, |
| 365 | + points: torch.Tensor, |
| 366 | + camera: Optional[CamerasBase], |
| 367 | + directions: torch.Tensor, |
| 368 | + ) -> torch.Tensor: |
| 369 | + """ |
| 370 | + Calculates color at points using the viewing direction: |
| 371 | +        1) Transforms the viewing direction with the camera
| 372 | + 2) Evaluates the voxel grid on points |
| 373 | + 3) Embeds the outputs with harmonic embedding |
| 374 | + 4) Embeds the normalized direction with harmonic embedding |
| 375 | + 5) Passes everything through the Color decoder |
| 376 | + Args: |
| 377 | + points: tensor of shape (..., 3) |
| 378 | +                where the last dimension holds the (x, y, z) coordinates of the points
| 379 | + camera: A camera model which will be used to transform the viewing |
| 380 | + directions |
| 381 | + directions: A tensor of shape `(..., 3)` |
| 382 | + containing the direction vectors of sampling rays in world coords. |
| 383 | + """ |
| 384 | + # ########## transform direction ########## # |
| 385 | + if self.xyz_ray_dir_in_camera_coords: |
| 386 | + if camera is None: |
| 387 | + raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords") |
| 388 | + directions = directions @ camera.R |
| 389 | + |
| 390 | + # ########## get voxel grid output ########## # |
| 391 | + # embeds_color.shape = [..., pts_per_ray, n_features] |
| 392 | + embeds_color = self.voxel_grid_color(points) |
| 393 | + |
| 394 | + # ########## embed with the harmonic function ########## # |
| 395 | + # Obtain the harmonic embedding of the voxel grid output. |
| 396 | + # pyre-ignore[29] |
| 397 | + harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color) |
| 398 | + |
| 399 | + # Normalize the ray_directions to unit l2 norm. |
| 400 | + rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1) |
| 401 | + # Obtain the harmonic embedding of the normalized ray directions. |
| 402 | + # pyre-ignore[29] |
| 403 | + harmonic_embedding_dir = self.harmonic_embedder_dir_color( |
| 404 | + rays_directions_normed |
| 405 | + ) |
| 406 | + |
| 407 | + n_rays = directions.shape[0] |
| 408 | + points_per_ray: int = points.shape[0] // n_rays |
| 409 | + |
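        # Each ray direction embedding is repeated once per point on the ray so that
        # it can be concatenated with the per-point color embedding below.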
| 410 | + harmonic_embedding_dir = torch.repeat_interleave( |
| 411 | + harmonic_embedding_dir, points_per_ray, dim=0 |
| 412 | + ) |
| 413 | + |
| 414 | + # total color embedding is concatenation of the harmonic embedding of voxel grid |
| 415 | + # output and harmonic embedding of the normalized direction |
| 416 | + total_color_embedding = torch.cat( |
| 417 | + (harmonic_embedding_color, harmonic_embedding_dir), dim=-1 |
| 418 | + ) |
| 419 | + |
| 420 | + # ########## evaluate color with the decoding function ########## # |
| 421 | + # rays_colors.shape = [..., pts_per_ray, 3] in [0-1] |
| 422 | + return self.decoder_color(total_color_embedding) |
| 423 | + |
| 424 | + @staticmethod |
| 425 | + def allows_multiple_passes() -> bool: |
| 426 | + """ |
| 427 | + Returns True as this implicit function allows |
| 428 | + multiple passes. Overridden from ImplicitFunctionBase. |
| 429 | + """ |
| 430 | + return True |
| 431 | + |
| 432 | + def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]: |
| 433 | + """ |
| 434 | + Method which expresses interest in subscribing to optimization epoch updates. |
| 435 | + This implicit function subscribes to epochs to calculate the scaffold and to |
| 436 | + crop voxel grids, so this method combines wanted epochs and wraps their callbacks. |
| 437 | +
| 438 | + Returns: |
| 439 | +            the list of epochs on which to call the callback and the callback itself.
| 440 | +                The callback must be supplied with one argument, the epoch, and returns
| 441 | +                True if a parameter change has happened, else False.
| 442 | + """ |
| 443 | + |
| 444 | + def callback(epoch) -> bool: |
| 445 | + change = False |
| 446 | + if epoch in self.scaffold_calculating_epochs: |
| 447 | + change = self._get_scaffold(epoch) |
| 448 | + if epoch in self.volume_cropping_epochs: |
| 449 | + change = self._crop(epoch) or change |
| 450 | + return change |
| 451 | + |
| 452 | + # remove duplicates |
| 453 | + call_epochs = list( |
| 454 | + set(self.scaffold_calculating_epochs) | set(self.volume_cropping_epochs) |
| 455 | + ) |
| 456 | + return call_epochs, callback |
| 457 | + |
| 458 | + def _crop(self, epoch: int) -> bool: |
| 459 | + """ |
| 460 | + Finds the bounding box of an object represented in the scaffold and crops |
| 461 | +        density and color voxel grids to match that bounding box. If the density of the
| 462 | +        scaffold is 0 everywhere (there is no object in it), no change will
| 463 | + happen. |
| 464 | +
| 465 | + Args: |
| 466 | + epoch: ignored |
| 467 | + Returns: |
| 468 | + True (indicating that parameter change has happened) if there is |
| 469 | + an object inside, else False. |
| 470 | + """ |
| 471 | + # find bounding box |
| 472 | + # pyre-ignore[16] |
| 473 | + points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch) |
| 474 | + assert self._scaffold_ready, "Scaffold has to be calculated before cropping." |
| 475 | + # pyre-ignore[29] |
| 476 | + occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0 |
| 477 | + non_zero_idxs = torch.nonzero(occupancy) |
| 478 | + if len(non_zero_idxs) == 0: |
| 479 | + return False |
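        # The extreme occupied indices give the axis-aligned bounding box of the object
        # in grid coordinates; indexing into `points` converts its corners to world coords.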
| 480 | + min_indices = tuple(torch.min(non_zero_idxs, dim=0)[0]) |
| 481 | + max_indices = tuple(torch.max(non_zero_idxs, dim=0)[0]) |
| 482 | + min_point, max_point = points[min_indices], points[max_indices] |
| 483 | + |
| 484 | + # crop the voxel grids |
| 485 | + self.voxel_grid_density.crop_self(min_point, max_point) |
| 486 | + self.voxel_grid_color.crop_self(min_point, max_point) |
| 487 | + return True |
| 488 | + |
| 489 | + @torch.no_grad() |
| 490 | + def _get_scaffold(self, epoch: int) -> bool: |
| 491 | + """ |
| 492 | + Creates a low resolution grid which is used to filter points that are in empty |
| 493 | + space. |
| 494 | +
| 495 | + Args: |
| 496 | + epoch: epoch on which it is called, ignored inside method |
| 497 | + Returns: |
| 498 | +            Always False. The method modifies the `self.voxel_grid_scaffold` member in place.
| 499 | + """ |
| 500 | + |
| 501 | + planes = [] |
| 502 | + # pyre-ignore[16] |
| 503 | + points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch) |
| 504 | + |
| 505 | + chunk_size = ( |
| 506 | + self.scaffold_occupancy_chunk_size |
| 507 | + if type(self.scaffold_occupancy_chunk_size) == int |
| 508 | + else points.shape[-1] |
| 509 | + ) |
| 510 | + for k in range(0, points.shape[-1], chunk_size): |
| 511 | + points_in_planes = points[..., k : k + chunk_size] |
| 512 | + planes.append(self.get_density(points_in_planes)[..., 0]) |
| 513 | + |
| 514 | + density_cube = torch.cat(planes, dim=-1) |
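        # Max pooling dilates high-density regions so that occupancy is not lost for
        # surfaces falling between scaffold voxel centers.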
| 515 | + density_cube = torch.nn.functional.max_pool3d( |
| 516 | + density_cube[None, None], |
| 517 | + kernel_size=self.scaffold_max_pool_kernel_size, |
| 518 | + padding=self.scaffold_max_pool_kernel_size // 2, |
| 519 | + stride=1, |
| 520 | + ) |
| 521 | + occupancy_cube = density_cube > self.scaffold_empty_space_threshold |
| 522 | + # pyre-ignore[16] |
| 523 | + self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float() |
| 524 | + # pyre-ignore[16] |
| 525 | + self._scaffold_ready = True |
| 526 | + |
| 527 | + return False |
| 528 | + |
| 529 | + @classmethod |
| 530 | + def decoder_density_tweak_args(cls, type, args: DictConfig) -> None: |
| 531 | + args.pop("input_dim", None) |
| 532 | + |
| 533 | + def create_decoder_density_impl(self, type, args: DictConfig) -> None: |
| 534 | + """ |
| 535 | +        Decoding functions come after the harmonic embedding and the voxel grid. So that the
| 536 | +        input dimension of the decoder does not have to be specified in the config file, this
| 537 | +        function calculates the required input dimension and sets the input dimension of the
| 538 | +        decoding function to this value.
| 539 | + """ |
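        # For example (hypothetical numbers): a voxel grid with output dim 27 and a
        # harmonic embedding with n_harmonic_functions=6, append_input=True give
        # input_dim = 27 * (6 * 2 + 1) = 351.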
| 540 | + grid_args = self.voxel_grid_density_args |
| 541 | + # pyre-ignore[6] |
| 542 | + grid_output_dim = VoxelGridModule.get_output_dim(grid_args) |
| 543 | + |
| 544 | + embedder_args = self.harmonic_embedder_xyz_density_args |
| 545 | + input_dim = HarmonicEmbedding.get_output_dim_static( |
| 546 | + grid_output_dim, |
| 547 | + embedder_args["n_harmonic_functions"], |
| 548 | + embedder_args["append_input"], |
| 549 | + ) |
| 550 | + |
| 551 | + cls = registry.get(DecoderFunctionBase, type) |
| 552 | + need_input_dim = any(field.name == "input_dim" for field in fields(cls)) |
| 553 | + if need_input_dim: |
| 554 | + self.decoder_density = cls(input_dim=input_dim, **args) |
| 555 | + else: |
| 556 | + self.decoder_density = cls(**args) |
| 557 | + |
| 558 | + @classmethod |
| 559 | + def decoder_color_tweak_args(cls, type, args: DictConfig) -> None: |
| 560 | + args.pop("input_dim", None) |
| 561 | + |
| 562 | + def create_decoder_color_impl(self, type, args: DictConfig) -> None: |
| 563 | + """ |
| 564 | +        Decoding functions come after the harmonic embedding and the voxel grid. So that the
| 565 | +        input dimension of the decoder does not have to be specified in the config file, this
| 566 | +        function calculates the required input dimension and sets the input dimension of the
| 567 | +        decoding function to this value.
| 568 | + """ |
| 569 | + grid_args = self.voxel_grid_color_args |
| 570 | + # pyre-ignore[6] |
| 571 | + grid_output_dim = VoxelGridModule.get_output_dim(grid_args) |
| 572 | + |
| 573 | + embedder_args = self.harmonic_embedder_xyz_color_args |
| 574 | + input_dim0 = HarmonicEmbedding.get_output_dim_static( |
| 575 | + grid_output_dim, |
| 576 | + embedder_args["n_harmonic_functions"], |
| 577 | + embedder_args["append_input"], |
| 578 | + ) |
| 579 | + |
| 580 | + dir_dim = 3 |
| 581 | + embedder_args = self.harmonic_embedder_dir_color_args |
| 582 | + input_dim1 = HarmonicEmbedding.get_output_dim_static( |
| 583 | + dir_dim, |
| 584 | + embedder_args["n_harmonic_functions"], |
| 585 | + embedder_args["append_input"], |
| 586 | + ) |
| 587 | + |
| 588 | + input_dim = input_dim0 + input_dim1 |
| 589 | + |
| 590 | + cls = registry.get(DecoderFunctionBase, type) |
| 591 | + need_input_dim = any(field.name == "input_dim" for field in fields(cls)) |
| 592 | + if need_input_dim: |
| 593 | + self.decoder_color = cls(input_dim=input_dim, **args) |
| 594 | + else: |
| 595 | + self.decoder_color = cls(**args) |
| 596 | + |
| 597 | + def _create_voxel_grid_scaffold(self) -> VoxelGridModule: |
| 598 | + """ |
| 599 | +        Creates the object to become self.voxel_grid_scaffold:
| 600 | +            - makes `self.voxel_grid_scaffold` have the same world to local mapping as
| 601 | + `self.voxel_grid_density` |
| 602 | + """ |
| 603 | + return VoxelGridModule( |
| 604 | + # pyre-ignore[29] |
| 605 | + extents=self.voxel_grid_density_args["extents"], |
| 606 | + # pyre-ignore[29] |
| 607 | + translation=self.voxel_grid_density_args["translation"], |
| 608 | + voxel_grid_class_type="FullResolutionVoxelGrid", |
| 609 | + hold_voxel_grid_as_parameters=False, |
| 610 | + voxel_grid_FullResolutionVoxelGrid_args={ |
| 611 | + "resolution_changes": {0: self.scaffold_resolution}, |
| 612 | + "padding": "zeros", |
| 613 | + "align_corners": True, |
| 614 | + "mode": "trilinear", |
| 615 | + }, |
| 616 | + ) |
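

# The helper below is an illustrative sketch added for exposition and is not part of
# the original module. It shows the two streams described in the class docstring
# (voxel grid -> harmonic embedding -> decoder, with an extra direction embedding for
# color) by calling the public methods defined above. It assumes `points` contains
# `pts_per_ray` samples for every row of `directions`, mirroring the layout in forward().
def _example_density_and_color_sketch(
    implicit_function: VoxelGridImplicitFunction,
    points: torch.Tensor,  # (n_rays * pts_per_ray, 3) world coordinates
    directions: torch.Tensor,  # (n_rays, 3) viewing directions in world coordinates
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Density stream: voxel grid -> harmonic embedding -> density decoder.
    densities = implicit_function.get_density(points)
    # Color stream: additionally embeds the normalized viewing directions; no camera
    # is passed, so the directions are assumed to already be in world coordinates.
    colors = implicit_function.get_color(points, None, directions)
    return densities, colors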