Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
JADGardner committed Nov 27, 2023
1 parent 7676392 commit fce775f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 38 deletions.
9 changes: 4 additions & 5 deletions reni/configs/reni_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,18 @@
method_name="reni",
experiment_name="reni",
machine=MachineConfig(),
steps_per_eval_image=50000,
steps_per_eval_batch=10000000,
steps_per_eval_image=5000,
steps_per_eval_batch=50002,
steps_per_save=10000,
save_only_latest_checkpoint=False,
steps_per_eval_all_images=50000,
save_only_latest_checkpoint=True,
steps_per_eval_all_images=5000,
max_num_iterations=50001,
mixed_precision=False,
pipeline=RENIPipelineConfig(
test_mode='val',
datamanager=RENIDataManagerConfig(
dataparser=RENIDataParserConfig(
data=Path("data/RENI_HDR"),
download_data=False,
train_subset_size=None,
val_subset_size=None,
convert_to_ldr=False,
Expand Down
24 changes: 9 additions & 15 deletions reni/configs/sh_sg_envmap_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
"cosine_similarity_loss": 1.0,
"kld_loss": 0.00001,
"scale_inv_loss": 1.0,
"scale_inv_grad_loss": 1.0,
},
loss_inclusions={
"log_mse_loss": True,
Expand All @@ -72,9 +71,6 @@
"cosine_similarity_loss": False,
"kld_loss": False,
"scale_inv_loss": False,
"scale_inv_grad_loss": False,
"bce_loss": False, # For RESGAN, leave False in this config
"wgan_loss": False, # For RESGAN, leave False in this config
},
),
),
Expand Down Expand Up @@ -131,7 +127,6 @@
"cosine_similarity_loss": 1.0,
"kld_loss": 0.00001,
"scale_inv_loss": 1.0,
"scale_inv_grad_loss": 1.0,
},
loss_inclusions={
"log_mse_loss": True,
Expand All @@ -140,9 +135,6 @@
"cosine_similarity_loss": False,
"kld_loss": False,
"scale_inv_loss": False,
"scale_inv_grad_loss": False,
"bce_loss": False, # For RESGAN, leave False in this config
"wgan_loss": False, # For RESGAN, leave False in this config
},
),
),
Expand Down Expand Up @@ -194,18 +186,20 @@
apply_padding=True,
),
loss_coefficients={
"mse_loss": 10.0,
"log_mse_loss": 1.0,
"hdr_mse_loss": 1.0,
"ldr_mse_loss": 1.0,
"cosine_similarity_loss": 1.0,
"kld_loss": 0.00001,
"scale_inv_loss": 1.0,
"scale_inv_grad_loss": 1.0,
},
loss_inclusions={
"mse_loss": False,
"cosine_similarity_loss": True,
"kld_loss": True,
"scale_inv_loss": True,
"scale_inv_grad_loss": False,
"log_mse_loss": True,
"hdr_mse_loss": False,
"ldr_mse_loss": False,
"cosine_similarity_loss": False,
"kld_loss": False,
"scale_inv_loss": False,
},
),
),
Expand Down
5 changes: 1 addition & 4 deletions reni/models/reni_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,7 @@ def get_outputs(self, ray_bundle: RayBundle, batch: Optional[dict] = None):

ray_samples = self.create_ray_samples(ray_bundle.origins, ray_bundle.directions, ray_bundle.camera_indices)

rotation = None
latent_codes = None # if auto-decoder training regime latents are trainable params of the field

field_outputs = self.field.forward(ray_samples=ray_samples, rotation=rotation, latent_codes=latent_codes)
field_outputs = self.field.forward(ray_samples=ray_samples)

outputs = {
"rgb": field_outputs[RENIFieldHeadNames.RGB],
Expand Down
23 changes: 9 additions & 14 deletions reni/pipelines/reni_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,12 @@ def __init__(
self._model = typing.cast(Model, DDP(self._model, device_ids=[local_rank], find_unused_parameters=True))
dist.barrier(device_ids=[local_rank])

self.last_step_of_eval_optimisation = 0 # used to avoid fitting eval latents on each eval function call below
self.step_of_last_latent_optimisation = 0

def _optimise_evaluation_latents(self, step):
if self.step_of_last_latent_optimisation != step:
self.model.fit_eval_latents(self.datamanager)
self.step_of_last_latent_optimisation = step

def forward(self):
"""Blank forward method
Expand Down Expand Up @@ -162,11 +167,7 @@ def get_eval_loss_dict(self, step: int):
step: current iteration step
"""
self.eval()
# if we haven't already fit the eval latents this step, do it now
if self.last_step_of_eval_optimisation != step:
if self.model.config.training_regime != "vae":
self.model.fit_eval_latents(self.datamanager)
self.last_step_of_eval_optimisation = step
self._optimise_evaluation_latents(step)
ray_bundle, batch = self.datamanager.next_eval(step)
model_outputs = self.model(ray_bundle, batch)
metrics_dict = self.model.get_metrics_dict(model_outputs, batch)
Expand All @@ -184,10 +185,7 @@ def get_eval_image_metrics_and_images(self, step: int):
"""
self.eval()
# if we haven't already fit the eval latents this step, do it now
if self.last_step_of_eval_optimisation != step:
if self.model.config.training_regime != "vae":
self.model.fit_eval_latents(self.datamanager)
self.last_step_of_eval_optimisation = step
self._optimise_evaluation_latents(step)
image_idx, ray_bundle, batch = self.datamanager.next_eval_image(step)
outputs = self.model(ray_bundle, batch)
metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch)
Expand All @@ -207,10 +205,7 @@ def get_average_eval_image_metrics(self, step: Optional[int] = None, optimise_la
"""
self.eval()
# if we haven't already fit the eval latents this step, do it now
if self.last_step_of_eval_optimisation != step:
if self.model.config.training_regime != "vae" and optimise_latents:
self.model.fit_eval_latents(self.datamanager)
self.last_step_of_eval_optimisation = step
self._optimise_evaluation_latents(step)
metrics_dict_list = []
num_images = len(self.datamanager.fixed_indices_eval_dataloader)
# get all eval images
Expand Down

0 comments on commit fce775f

Please sign in to comment.