Skip to content

Commit a3aad45

Browse files
committed
fix: integrate metric specific flows in test viz
+ various monads/components
1 parent 88498de commit a3aad45

File tree

16 files changed

+194
-8
lines changed

16 files changed

+194
-8
lines changed

moai/conf/data/test/loader/torch.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ shuffle: false
66
num_workers: ${oc.select:experiment.workers,0}
77
pin_memory: false
88
drop_last: false
9-
prefetch_factor: 0
9+
prefetch_factor: null # 0
1010
persistent_workers: false
1111
pin_memory_device: ""
1212
# in_order: true

moai/conf/data/train/loader/torch.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ shuffle: true
66
num_workers: ${oc.select:experiment.workers,0}
77
pin_memory: false
88
drop_last: false
9-
prefetch_factor: 0
9+
prefetch_factor: null # 0
1010
persistent_workers: false
1111
pin_memory_device: ""
1212
# in_order: true

moai/conf/model/components/tcnn/encoding_mlp.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
_target_: moai.components.tcnn.modules.EncodingMLP
44
input_dims: ???
5+
hidden_dims: ???
56
output_dims: ???
67
encoding_config: ???
78
seed: ${oc.select:engine.modules.manual_seed.seed,1337}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# @package model.monads.linspace
2+
3+
_target_: moai.monads.generation.tensor.torch.LinSpace
4+
start: 0.0
5+
end: 1.0
6+
steps: ???

moai/conf/model/monads/math/exp.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
# @package model.monads._name_
1+
# @package model.monads.exp
22

3-
_target_: moai.monads.math.Exponential
3+
_target_: moai.monads.math.Exponential
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# @package model.monads.normalize
2+
3+
_target_: moai.monads.math.Normalize
4+
order: 2
5+
dim: -1
6+
epsilon: 1e-12
+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
# @package model.monads._name_
1+
# @package model.monads.sigmoid
22

3-
_target_: torch.nn.Sigmoid
3+
_target_: moai.monads.math.Sigmoid
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# @package model.monads.gaussian_splat_rasterizer
2+
3+
_target_: moai.monads.render.splats.rasterize_slothfulxtx.GaussianSplatRasterizer
4+
prefiltered: false
5+
debug: false
6+
scale_modifier: 1.0
7+
width: null
8+
height: null
9+
background_color: black
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# @package model.monads.float32
2+
3+
_target_: moai.monads.tensor.cast.Float32

moai/core/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,8 @@ def test_step(
453453
if monitor := get_dict(
454454
self.monitor, f"{C._TEST_}.{C._DATASETS_}.{dataset_name}"
455455
):
456+
for step in toolz.get(C._FLOWS_, monitor, None) or []:
457+
self.named_flows[step](batch)
456458
for metric in get_list(monitor, C._METRICS_): # Metrics monitoring
457459
self.named_metrics[metric](batch)
458460
for tensor_monitor in get_list(monitor, C._MONITORS_):

moai/monads/generation/tensor/torch.py

+15
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"RandomLike",
1818
"OnesLike",
1919
"TemporalParams",
20+
"LinSpace",
2021
]
2122

2223

@@ -234,3 +235,17 @@ def __init__(
234235

235236
def forward(self, void: torch.Tensor) -> torch.nn.parameter.Parameter:
236237
return dict(self.named_parameters())
238+
239+
240+
class LinSpace(torch.nn.Module):
241+
def __init__(self, start: float, end: float, steps: int) -> None:
242+
super().__init__()
243+
self.register_buffer(
244+
"spaced", torch.linspace(start, end, steps)[np.newaxis, :, np.newaxis]
245+
)
246+
247+
def forward(
248+
self,
249+
tensor: torch.Tensor,
250+
) -> torch.Tensor:
251+
return self.spaced.expand(tensor.shape[0], -1, 1)

moai/monads/math/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
PlusOne,
1111
Rad2Deg,
1212
Scale,
13+
Sigmoid,
1314
)
1415
from moai.monads.math.dot import Dot
15-
from moai.monads.math.normalization import MinMaxNorm, Znorm
16+
from moai.monads.math.normalization import MinMaxNorm, Normalize, Znorm
1617

1718
__all__ = [
1819
"Abs",
@@ -30,4 +31,6 @@
3031
"Rad2Deg",
3132
"Deg2Rad",
3233
"Exponential",
34+
"Normalization",
35+
"Sigmoid",
3336
]

moai/monads/math/common.py

+9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"Rad2Deg",
1313
"Deg2Rad",
1414
"Exponential",
15+
"Sigmoid",
1516
]
1617

1718

@@ -104,3 +105,11 @@ def __init__(self) -> None:
104105

105106
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
106107
return torch.exp(tensor)
108+
109+
110+
class Sigmoid(torch.nn.Module):
111+
def __init__(self) -> None:
112+
super().__init__()
113+
114+
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
115+
return torch.sigmoid(tensor)

moai/monads/math/normalization.py

+13
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
3131
# .mul(scale.view(b, 1, 1, 1))\
3232
# .add(self.min)
3333
return torch.addcmul(self.min, x - expand_dims(mins, x), expand_dims(scale, x))
34+
35+
36+
class Normalize(torch.nn.Module):
37+
def __init__(self, order: int = 2, dim: int = -1, epsilon: float = 1e-12):
38+
super().__init__()
39+
self.order = order
40+
self.dim = dim
41+
self.eps = epsilon
42+
43+
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
44+
return torch.nn.functional.normalize(
45+
tensor, p=self.order, dim=self.dim, eps=self.eps
46+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import math
2+
import typing
3+
4+
import numpy as np
5+
import torch
6+
from colour import Color
7+
from diff_gauss import GaussianRasterizationSettings, GaussianRasterizer
8+
9+
__all__ = ["GaussianSplatRasterizer"]
10+
11+
12+
# NOTE: different gaussian splatting versions produce different number and order of outputs
13+
# but all collide using the same package
14+
15+
16+
class GaussianSplatRasterizer(torch.nn.Module):
17+
def __init__(
18+
self,
19+
width: typing.Optional[int] = None,
20+
height: typing.Optional[int] = None,
21+
background_color: str = "black",
22+
prefiltered: bool = False,
23+
debug: bool = False,
24+
scale_modifier: float = 1.0,
25+
):
26+
super().__init__()
27+
color = Color(background_color).get_rgb()
28+
self.register_buffer("background_color", torch.tensor(color))
29+
self.width, self.height = width, height
30+
self.prefiltered, self.debug = prefiltered, debug
31+
self.scale_modifier = scale_modifier
32+
33+
def forward(
34+
self,
35+
view_matrix: torch.Tensor, # [B, 4, 4]
36+
view_projection_matrix: torch.Tensor, # [B, 4, 4]
37+
camera_position: torch.Tensor, # [B, 3]
38+
positions: torch.Tensor, # [B, V, 3]
39+
sh_coeffs: torch.Tensor, # [B, SH, 3]
40+
opacities: torch.Tensor, # [B, V, 1]
41+
rotations: torch.Tensor, # [B, V, 4]
42+
scales: torch.Tensor, # [B, V, 3]
43+
features: torch.Tensor, # [B, V, K]
44+
intrinsics: torch.Tensor, # [B, 3, 3]
45+
image: typing.Optional[torch.Tensor] = None, # [B, C, H, W]
46+
background_color: typing.Optional[torch.Tensor] = None, # [B, 3]
47+
):
48+
assert len(positions.shape) == 3
49+
B = view_matrix.shape[0]
50+
if positions.shape[0] != B: # either many-to-many or one-to-many
51+
positions = positions.expand(B, -1, -1)
52+
sh_coeffs = sh_coeffs.expand(B, -1, -1, -1)
53+
opacities = opacities.expand(B, -1, -1)
54+
rotations = rotations.expand(B, -1, -1)
55+
scales = scales.expand(B, -1, -1)
56+
features = features.expand(B, -1, -1)
57+
bg = background_color if background_color is not None else self.background_color
58+
sh_degree = math.sqrt(sh_coeffs.shape[-2]) - 1
59+
if bg.shape[0] != B:
60+
bg = bg.expand(B, -1)
61+
colors, radiis, depths, alphas, extras = [], [], [], [], []
62+
for i in range(B):
63+
W = image[i].shape[-1] if image is not None else self.width
64+
H = image[i].shape[-2] if image is not None else self.height
65+
settings = GaussianRasterizationSettings(
66+
image_height=H,
67+
image_width=W,
68+
# tanfovx=2.0 * np.arctan(W / (2.0 * intrinsics[i, 0, 0].cpu().float())),
69+
# tanfovy=2.0 * np.arctan(H / (2.0 * intrinsics[i, 1, 1].cpu().float())),
70+
tanfovx=W / (2.0 * intrinsics[i, 0, 0].cpu().float()),
71+
tanfovy=H / (2.0 * intrinsics[i, 1, 1].cpu().float()),
72+
bg=bg[i],
73+
scale_modifier=self.scale_modifier,
74+
viewmatrix=view_matrix[i],
75+
projmatrix=view_projection_matrix[i],
76+
sh_degree=int(sh_degree),
77+
campos=camera_position[i],
78+
prefiltered=self.prefiltered,
79+
debug=self.debug,
80+
)
81+
screenspace_points = torch.zeros_like(
82+
positions[i]
83+
) # , requires_grad=True) + 0
84+
screenspace_points.requires_grad_(True)
85+
screenspace_points.retain_grad()
86+
rasterizer = GaussianRasterizer(settings)
87+
color, depth, norm, alpha, radii, feats = rasterizer( # TODO inv_depth
88+
means3D=positions[i],
89+
means2D=screenspace_points,
90+
opacities=opacities[i],
91+
scales=scales[i],
92+
rotations=rotations[i],
93+
shs=sh_coeffs[i],
94+
extra_attrs=features[i],
95+
colors_precomp=None,
96+
cov3Ds_precomp=None,
97+
)
98+
colors.append(color)
99+
radiis.append(radii)
100+
depths.append(depth)
101+
alphas.append(alpha)
102+
extras.append(feats)
103+
# "viewspace_points": screenspace_points,
104+
# "visibility_filter" : radii > 0,
105+
return {
106+
"color": torch.stack(colors).clamp(0, 1),
107+
"radii": torch.stack(radiis),
108+
"depth": torch.stack(depths),
109+
"alpha": torch.stack(alphas),
110+
"features": torch.stack(extras),
111+
}

moai/monads/tensor/cast.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22

3-
__all__ = ["Int32"]
3+
__all__ = ["Int32, Float32"]
44

55

66
class Int32(torch.nn.Module):
@@ -9,3 +9,11 @@ def __init__(self) -> None:
99

1010
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
1111
return tensor.int()
12+
13+
14+
class Float32(torch.nn.Module):
15+
def __init__(self) -> None:
16+
super().__init__()
17+
18+
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
19+
return tensor.float()

0 commit comments

Comments
 (0)