|
| 1 | +import math |
| 2 | +import typing |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | +from colour import Color |
| 7 | +from diff_gauss import GaussianRasterizationSettings, GaussianRasterizer |
| 8 | + |
| 9 | +__all__ = ["GaussianSplatRasterizer"] |
| 10 | + |
| 11 | + |
| 12 | +# NOTE: different gaussian splatting versions produce different number and order of outputs |
| 13 | +# but all collide using the same package |
| 14 | + |
| 15 | + |
| 16 | +class GaussianSplatRasterizer(torch.nn.Module): |
| 17 | + def __init__( |
| 18 | + self, |
| 19 | + width: typing.Optional[int] = None, |
| 20 | + height: typing.Optional[int] = None, |
| 21 | + background_color: str = "black", |
| 22 | + prefiltered: bool = False, |
| 23 | + debug: bool = False, |
| 24 | + scale_modifier: float = 1.0, |
| 25 | + ): |
| 26 | + super().__init__() |
| 27 | + color = Color(background_color).get_rgb() |
| 28 | + self.register_buffer("background_color", torch.tensor(color)) |
| 29 | + self.width, self.height = width, height |
| 30 | + self.prefiltered, self.debug = prefiltered, debug |
| 31 | + self.scale_modifier = scale_modifier |
| 32 | + |
| 33 | + def forward( |
| 34 | + self, |
| 35 | + view_matrix: torch.Tensor, # [B, 4, 4] |
| 36 | + view_projection_matrix: torch.Tensor, # [B, 4, 4] |
| 37 | + camera_position: torch.Tensor, # [B, 3] |
| 38 | + positions: torch.Tensor, # [B, V, 3] |
| 39 | + sh_coeffs: torch.Tensor, # [B, SH, 3] |
| 40 | + opacities: torch.Tensor, # [B, V, 1] |
| 41 | + rotations: torch.Tensor, # [B, V, 4] |
| 42 | + scales: torch.Tensor, # [B, V, 3] |
| 43 | + features: torch.Tensor, # [B, V, K] |
| 44 | + intrinsics: torch.Tensor, # [B, 3, 3] |
| 45 | + image: typing.Optional[torch.Tensor] = None, # [B, C, H, W] |
| 46 | + background_color: typing.Optional[torch.Tensor] = None, # [B, 3] |
| 47 | + ): |
| 48 | + assert len(positions.shape) == 3 |
| 49 | + B = view_matrix.shape[0] |
| 50 | + if positions.shape[0] != B: # either many-to-many or one-to-many |
| 51 | + positions = positions.expand(B, -1, -1) |
| 52 | + sh_coeffs = sh_coeffs.expand(B, -1, -1, -1) |
| 53 | + opacities = opacities.expand(B, -1, -1) |
| 54 | + rotations = rotations.expand(B, -1, -1) |
| 55 | + scales = scales.expand(B, -1, -1) |
| 56 | + features = features.expand(B, -1, -1) |
| 57 | + bg = background_color if background_color is not None else self.background_color |
| 58 | + sh_degree = math.sqrt(sh_coeffs.shape[-2]) - 1 |
| 59 | + if bg.shape[0] != B: |
| 60 | + bg = bg.expand(B, -1) |
| 61 | + colors, radiis, depths, alphas, extras = [], [], [], [], [] |
| 62 | + for i in range(B): |
| 63 | + W = image[i].shape[-1] if image is not None else self.width |
| 64 | + H = image[i].shape[-2] if image is not None else self.height |
| 65 | + settings = GaussianRasterizationSettings( |
| 66 | + image_height=H, |
| 67 | + image_width=W, |
| 68 | + # tanfovx=2.0 * np.arctan(W / (2.0 * intrinsics[i, 0, 0].cpu().float())), |
| 69 | + # tanfovy=2.0 * np.arctan(H / (2.0 * intrinsics[i, 1, 1].cpu().float())), |
| 70 | + tanfovx=W / (2.0 * intrinsics[i, 0, 0].cpu().float()), |
| 71 | + tanfovy=H / (2.0 * intrinsics[i, 1, 1].cpu().float()), |
| 72 | + bg=bg[i], |
| 73 | + scale_modifier=self.scale_modifier, |
| 74 | + viewmatrix=view_matrix[i], |
| 75 | + projmatrix=view_projection_matrix[i], |
| 76 | + sh_degree=int(sh_degree), |
| 77 | + campos=camera_position[i], |
| 78 | + prefiltered=self.prefiltered, |
| 79 | + debug=self.debug, |
| 80 | + ) |
| 81 | + screenspace_points = torch.zeros_like( |
| 82 | + positions[i] |
| 83 | + ) # , requires_grad=True) + 0 |
| 84 | + screenspace_points.requires_grad_(True) |
| 85 | + screenspace_points.retain_grad() |
| 86 | + rasterizer = GaussianRasterizer(settings) |
| 87 | + color, depth, norm, alpha, radii, feats = rasterizer( # TODO inv_depth |
| 88 | + means3D=positions[i], |
| 89 | + means2D=screenspace_points, |
| 90 | + opacities=opacities[i], |
| 91 | + scales=scales[i], |
| 92 | + rotations=rotations[i], |
| 93 | + shs=sh_coeffs[i], |
| 94 | + extra_attrs=features[i], |
| 95 | + colors_precomp=None, |
| 96 | + cov3Ds_precomp=None, |
| 97 | + ) |
| 98 | + colors.append(color) |
| 99 | + radiis.append(radii) |
| 100 | + depths.append(depth) |
| 101 | + alphas.append(alpha) |
| 102 | + extras.append(feats) |
| 103 | + # "viewspace_points": screenspace_points, |
| 104 | + # "visibility_filter" : radii > 0, |
| 105 | + return { |
| 106 | + "color": torch.stack(colors).clamp(0, 1), |
| 107 | + "radii": torch.stack(radiis), |
| 108 | + "depth": torch.stack(depths), |
| 109 | + "alpha": torch.stack(alphas), |
| 110 | + "features": torch.stack(extras), |
| 111 | + } |
0 commit comments