Skip to content

Commit 88498de

Browse files
committed
feat: parameter groups & splatting monads
+ dsl updates (tanh, sqrt, roll) ! fix: bug in dsl's sub order when using scalar numbers + new torch loader kwargs + image torchmetrics + quat monads + adopt optimizer + updated parameter selection to support different optimizer kwargs for different params ! fix: zip iterator + add rerun config logging ! fix: scheduling for fitting pipelines + mode for rerun scalars logging
1 parent 18b5381 commit 88498de

File tree

62 files changed

+1419
-40
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+1419
-40
lines changed

hydra_plugins/moai_dsl_plugin/moai_dsl_plugin.py

+5
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,12 @@
9292
| "asin" "(" expr ")" -> asin
9393
| "tan" "(" name ")" -> tan
9494
| "tan" "(" expr ")" -> tan
95+
| "tanh" "(" name ")" -> tanh
96+
| "tanh" "(" expr ")" -> tanh
9597
| "atan" "(" name ")" -> atan
9698
| "atan" "(" expr ")" -> atan
99+
| "sqrt" "(" name ")" -> sqrt
100+
| "sqrt" "(" expr ")" -> sqrt
97101
| "deg" "(" name ")" -> rad2deg
98102
| "deg" "(" expr ")" -> rad2deg
99103
| "rad" "(" name ")" -> deg2rad
@@ -108,6 +112,7 @@
108112
| "transpose" "(" name "," SIGNED_INT ("," SIGNED_INT)* ")" -> transpose
109113
| "flatten" "(" name "," SIGNED_INT ["," SIGNED_INT] ")" -> flatten
110114
| "repeat_interleave" "(" name "," SIGNED_INT "," SIGNED_INT ")" -> repeat
115+
| "roll" "(" name "," SIGNED_INT "," SIGNED_INT ")" -> roll
111116
| "zeros" "(" name ")" -> zeros_like
112117
| "ones" "(" name ")" -> ones_like
113118
| "rand" "(" name ")" -> rand_like

moai/conf/data/test/loader/torch.yaml

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,8 @@ batch_size: ${oc.select:experiment.batch_size,1}
55
shuffle: false
66
num_workers: ${oc.select:experiment.workers,0}
77
pin_memory: false
8-
drop_last: false
8+
drop_last: false
9+
prefetch_factor: 0
10+
persistent_workers: false
11+
pin_memory_device: ""
12+
# in_order: true

moai/conf/data/train/loader/torch.yaml

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,8 @@ batch_size: ${oc.select:experiment.batch_size,1}
55
shuffle: true
66
num_workers: ${oc.select:experiment.workers,0}
77
pin_memory: false
8-
drop_last: false
8+
drop_last: false
9+
prefetch_factor: 0
10+
persistent_workers: false
11+
pin_memory_device: ""
12+
# in_order: true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# @package model.metrics.learned_perceptual_image_patch_similarity
2+
3+
_target_: moai.validation.torchmetric.LPIPS
4+
type: LearnedPerceptualImagePatchSimilarity
5+
net_type: alex # vgg, squeeze
6+
normalize: false
7+
reduction: mean # sum, none
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# @package model.metrics.peak_signal_noise_ratio
2+
3+
_target_: moai.validation.torchmetric.ImageTorchMetric
4+
type: PeakSignalNoiseRatio
5+
data_range: 1.0
6+
base: 10.0
7+
reduction: elementwise_mean # sum, none
8+
dim: null
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# @package model.metrics.structural_similarity_index_measure
2+
3+
_target_: moai.validation.torchmetric.ImageTorchMetric
4+
type: StructuralSimilarityIndexMeasure
5+
gaussian_kernel: true
6+
sigma: 1.5
7+
kernel_size: 11
8+
reduction: elementwise_mean # sum, none
9+
data_range: 1.0
10+
k1: 0.01
11+
k2: 0.03
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# @package model.monads.gaussian_splat_parameters
2+
3+
_target_: moai.monads.generation.splats.gaussian_parameters.GaussianSplattingParameters
4+
num_splats: ???
5+
num_sets: 1
6+
skip_positions: false
7+
scale: 1.0
8+
opacity: null
9+
sh_degree: 0
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
# @package model.monads._name_
1+
# @package model.monads.clone
22

33
_target_: moai.monads.generation.tensor.Clone
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# @package model.monads._name_
1+
# @package model.monads.named_params
22

33
_target_: moai.monads.generation.tensor.Parameters
44
parameters: ???
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
# @package model.monads._name_
1+
# @package model.monads.npy
22

33
_target_: moai.monads.generation.tensor.Npy

moai/conf/model/monads/generation/tensor/ones.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# @package model.monads._name_
1+
# @package model.monads.ones
22

33
_target_: moai.monads.generation.tensor.Ones
44
shape: ???
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# @package model.monads.quaternion_composition
2+
3+
_target_: moai.monads.geometry.rotations.quaternion.QuaternionComposition
4+
normalize: false
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
# @package model.monads._name_
1+
# @package model.monads.unitquat_to_rotmat
22

33
_target_: moai.monads.geometry.rotations.Quaternion2RotationMatrix

moai/conf/model/monads/math/clamp.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# @package model.monads._name_
1+
# @package model.monads.clamp
22

33
_target_: moai.monads.math.Clamp
44
min_value: 0.0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# @package model.monads.gaussian_splat_rasterizer
2+
3+
_target_: moai.monads.render.splats.rasterize.GaussianSplatRasterizer
4+
prefiltered: false
5+
debug: false
6+
scale_modifier: 1.0
7+
width: null
8+
height: null
9+
background_color: black
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# @package model.monads.gaussian_splat_rasterizer
2+
3+
_target_: moai.monads.render.splats.rasterize_ashawkey.GaussianSplatRasterizer
4+
prefiltered: false
5+
debug: false
6+
scale_modifier: 1.0
7+
width: null
8+
height: null
9+
background_color: black
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# @package model.monads.remap
2+
3+
_target_: moai.monads.sampling.torch.remap.Remap
4+
align_corners: null
5+
sample_mode: bilinear
6+
padding_mode: zeros
+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
# @package model.monads._name_
1+
# @package model.monads.detach
22

33
_target_: moai.monads.tensor.Detach
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# @package model.monitors.config
2+
3+
_target_: moai.visualization.rerun.config.Config
4+
path: '/config'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# @package model.objectives.structural_dissimilarity
2+
3+
_target_: moai.supervision.objectives.image.ssim.StructuralDissimilarity
4+
window_size: 7
5+
dynamic_range: 1.0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# @package model.objectives.surface_feature_consistency
2+
3+
_target_: moai.supervision.objectives.mesh.consistency.SurfaceFeatureConsistency
4+
order: 2
+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# @package model.objectives.passthrough
22

3-
_target_: moai.supervision.objectives.Passthrough
3+
_target_: moai.supervision.objectives.passthrough.Passthrough
44
mode: minimize # one of [minimize, maximize]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# @package model.objectives.long_and_thin_penalty
2+
3+
_target_: moai.supervision.objectives.splats.scaling.LongAndThinPenalty
4+
max_scale: 0.008
5+
scale_ratio: 10.0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# @package model.parameters.groups.model3
2+
3+
_target_: moai.parameters.selectors.model.ModelParameterSelector
4+
modules: null # optional
5+
monads: null # optional
6+
parameters: null # optional
7+
force_grad: true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# @package model.parameters.groups.model4
2+
3+
_target_: moai.parameters.selectors.model.ModelParameterSelector
4+
modules: null # optional
5+
monads: null # optional
6+
parameters: null # optional
7+
force_grad: true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# @package model.parameters.groups.model5
2+
3+
_target_: moai.parameters.selectors.model.ModelParameterSelector
4+
modules: null # optional
5+
monads: null # optional
6+
parameters: null # optional
7+
force_grad: true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# @package model.parameters.groups.model_groups
2+
3+
_target_: moai.parameters.selectors.model_groups.ModelGroupParameterSelector
4+
groups: ???
5+
# each group (key-dict pair) contains
6+
# modules: null # optional
7+
# monads: null # optional
8+
# parameters: null # optional
9+
# force_grad: true
10+
# and extra optimizer relaed params
11+
# lr: float
12+
# etc.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# @package model.parameters.groups.model_groups1
2+
3+
_target_: moai.parameters.selectors.model_groups.ModelGroupParameterSelector
4+
groups: ???
5+
# each group (key-dict pair) contains
6+
# modules: null # optional
7+
# monads: null # optional
8+
# parameters: null # optional
9+
# force_grad: true
10+
# and extra optimizer relaed params
11+
# lr: float
12+
# etc.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# @package model.parameters.groups.model_groups2
2+
3+
_target_: moai.parameters.selectors.model_groups.ModelGroupParameterSelector
4+
groups: ???
5+
# each group (key-dict pair) contains
6+
# modules: null # optional
7+
# monads: null # optional
8+
# parameters: null # optional
9+
# force_grad: true
10+
# and extra optimizer relaed params
11+
# lr: float
12+
# etc.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# @package model.parameters.groups.model_groups3
2+
3+
_target_: moai.parameters.selectors.model_groups.ModelGroupParameterSelector
4+
groups: ???
5+
# each group (key-dict pair) contains
6+
# modules: null # optional
7+
# monads: null # optional
8+
# parameters: null # optional
9+
# force_grad: true
10+
# and extra optimizer relaed params
11+
# lr: float
12+
# etc.
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# @package model.parameters.initializers.default
22

3-
_target_: moai.parameters.initialization.default.Default
3+
_target_: moai.parameters.initialization.default.Default
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# @package model.parameters.optimizers.adopt
2+
3+
_target_: moai.parameters.optimization.optimizers.adopt.ADOPT
4+
lr: 0.0001
5+
betas: [0.9, 0.999]
6+
eps: 1e-6 # 1e-8
7+
weight_decay: 0.0
8+
decouple: false
9+
foreach: null
10+
maximize: false
11+
capturable: false
12+
differentiable: false
13+
fused: null

moai/core/execution/expression.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import dataclasses
2+
import numbers
23
import typing
34

45
import benedict
@@ -74,21 +75,27 @@ def forward(self, td, tmp) -> None:
7475
@dataclasses.dataclass(repr=False, unsafe_hash=True)
7576
class BinaryOperationScalar(torch.nn.Module):
7677
operation: str
77-
lhs: str
78-
rhs: typing.Union[float, int]
78+
lhs: typing.Union[float, int, str]
79+
rhs: typing.Union[float, int, str]
7980
index: int
8081

8182
def __post_init__(self):
8283
super().__init__()
8384
self.op = getattr(torch, self.operation)
85+
self.is_lhs_scalar = isinstance(self.lhs, numbers.Number)
8486

8587
def __repr__(self):
8688
return f"{self.operation}:{self.lhs},{self.rhs}"
8789

8890
def forward(self, td, tmp) -> None:
89-
tmp[f"result{self.index}"] = self.op(
90-
toolz.get_in(self.lhs.split("."), tmp), self.rhs
91-
)
91+
if self.is_lhs_scalar:
92+
tmp[f"result{self.index}"] = self.op(
93+
-toolz.get_in(self.rhs.split("."), tmp), -self.lhs
94+
)
95+
else:
96+
tmp[f"result{self.index}"] = self.op(
97+
toolz.get_in(self.lhs.split("."), tmp), self.rhs
98+
)
9299

93100

94101
@dataclasses.dataclass(repr=False, unsafe_hash=True)
@@ -283,7 +290,7 @@ def _binary(self, name, lhs, rhs):
283290
if rhs is None:
284291
rhs = self.results.pop()
285292
if not isinstance(lhs, str):
286-
m = BinaryOperationScalar(name, rhs, lhs, self.index)
293+
m = BinaryOperationScalar(name, lhs, rhs, self.index)
287294
elif not isinstance(rhs, str):
288295
m = BinaryOperationScalar(name, lhs, rhs, self.index)
289296
else:
@@ -572,6 +579,10 @@ def acos(self, key):
572579
def tan(self, key):
573580
pass
574581

582+
@unary
583+
def tanh(self, key):
584+
pass
585+
575586
@unary
576587
def atan(self, key):
577588
pass
@@ -581,7 +592,7 @@ def abs(self, key):
581592
pass
582593

583594
@unary
584-
def abs(self, key):
595+
def sqrt(self, key):
585596
pass
586597

587598
@unary
@@ -716,6 +727,10 @@ def repeat(self, key, *dims):
716727
dims = list(map(int, dims))
717728
self._transform_operation("repeat_interleave", key, dims)
718729

730+
def roll(self, key, shift, dim):
731+
key = self.extract(key) # NOTE: only supports single shift/dim
732+
self._transform_operation("roll", key, [int(shift), int(dim)])
733+
719734
def unsqueeze(self, key, *dims):
720735
if not isinstance(key, str) or isinstance(key, Token): # NOTE: is lark.Tree
721736
key = self.extract(key)

moai/core/execution/metrics.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(
6262
extra_params = set(metric_kwargs.keys()) - set(sig_params) - set([C._OUT_])
6363
if extra_params:
6464
log.error(
65-
f"The parameters [{extra_params}] are not part of the `{key}` metric signature."
65+
f"The parameters [{extra_params}] are not part of the `{key}` metric signature ({list(sig.parameters.keys())})."
6666
)
6767
metric_kwargs = mic._dict_of_lists_to_list_of_dicts(metric_kwargs)
6868
for j, params in enumerate(metric_kwargs):

moai/core/execution/models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
extra_params = set(model_params.keys()) - set(sig_params) - set([C._OUT_])
4444
if extra_params:
4545
log.error(
46-
f"The parameters [{extra_params}] are not part of the `{key}` model signature."
46+
f"The parameters [{extra_params}] are not part of the `{key}` model signature ({list(sig.parameters.keys())})."
4747
)
4848
model_params = mic._dict_of_lists_to_list_of_dicts(model_params)
4949
for j, params in enumerate(model_params):

moai/core/execution/monads.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
extra_params = set(graph_params.keys()) - set(sig_params) - set([C._OUT_])
7070
if extra_params:
7171
log.error(
72-
f"The parameters [{extra_params}] are not part of the `{key}` monad signature."
72+
f"The parameters [{extra_params}] are not part of the `{key}` monad signature ({list(sig.parameters.keys())})."
7373
)
7474
graph_params = mic._dict_of_lists_to_list_of_dicts(graph_params)
7575
for j, params in enumerate(graph_params):

0 commit comments

Comments
 (0)