From fce775f25cc4cfad388d8ad073238e64f8d53b72 Mon Sep 17 00:00:00 2001
From: jadgardner
Date: Mon, 27 Nov 2023 14:41:22 +0000
Subject: [PATCH] refactoring

---
 reni/configs/reni_config.py          |  9 ++++-----
 reni/configs/sh_sg_envmap_configs.py | 24 +++++++++---------------
 reni/models/reni_model.py            |  5 +----
 reni/pipelines/reni_pipeline.py      | 23 +++++++++--------------
 4 files changed, 23 insertions(+), 38 deletions(-)

diff --git a/reni/configs/reni_config.py b/reni/configs/reni_config.py
index f526e42..1949367 100644
--- a/reni/configs/reni_config.py
+++ b/reni/configs/reni_config.py
@@ -22,11 +22,11 @@
         method_name="reni",
         experiment_name="reni",
         machine=MachineConfig(),
-        steps_per_eval_image=50000,
-        steps_per_eval_batch=10000000,
+        steps_per_eval_image=5000,
+        steps_per_eval_batch=50002,
         steps_per_save=10000,
-        save_only_latest_checkpoint=False,
-        steps_per_eval_all_images=50000,
+        save_only_latest_checkpoint=True,
+        steps_per_eval_all_images=5000,
         max_num_iterations=50001,
         mixed_precision=False,
         pipeline=RENIPipelineConfig(
@@ -34,7 +34,6 @@
             datamanager=RENIDataManagerConfig(
                 dataparser=RENIDataParserConfig(
                     data=Path("data/RENI_HDR"),
-                    download_data=False,
                     train_subset_size=None,
                     val_subset_size=None,
                     convert_to_ldr=False,
diff --git a/reni/configs/sh_sg_envmap_configs.py b/reni/configs/sh_sg_envmap_configs.py
index fd0ecf8..4b1e8f5 100644
--- a/reni/configs/sh_sg_envmap_configs.py
+++ b/reni/configs/sh_sg_envmap_configs.py
@@ -63,7 +63,6 @@
                     "cosine_similarity_loss": 1.0,
                     "kld_loss": 0.00001,
                     "scale_inv_loss": 1.0,
-                    "scale_inv_grad_loss": 1.0,
                 },
                 loss_inclusions={
                     "log_mse_loss": True,
@@ -72,9 +71,6 @@
                     "cosine_similarity_loss": False,
                     "kld_loss": False,
                     "scale_inv_loss": False,
-                    "scale_inv_grad_loss": False,
-                    "bce_loss": False,  # For RESGAN, leave False in this config
-                    "wgan_loss": False,  # For RESGAN, leave False in this config
                 },
             ),
         ),
@@ -131,7 +127,6 @@
                     "cosine_similarity_loss": 1.0,
                     "kld_loss": 0.00001,
                     "scale_inv_loss": 1.0,
-                    "scale_inv_grad_loss": 1.0,
                 },
                 loss_inclusions={
                     "log_mse_loss": True,
@@ -140,9 +135,6 @@
                     "cosine_similarity_loss": False,
                     "kld_loss": False,
                     "scale_inv_loss": False,
-                    "scale_inv_grad_loss": False,
-                    "bce_loss": False,  # For RESGAN, leave False in this config
-                    "wgan_loss": False,  # For RESGAN, leave False in this config
                 },
             ),
         ),
@@ -194,18 +186,20 @@
                     apply_padding=True,
                 ),
                 loss_coefficients={
-                    "mse_loss": 10.0,
+                    "log_mse_loss": 1.0,
+                    "hdr_mse_loss": 1.0,
+                    "ldr_mse_loss": 1.0,
                     "cosine_similarity_loss": 1.0,
                     "kld_loss": 0.00001,
                     "scale_inv_loss": 1.0,
-                    "scale_inv_grad_loss": 1.0,
                 },
                 loss_inclusions={
-                    "mse_loss": False,
-                    "cosine_similarity_loss": True,
-                    "kld_loss": True,
-                    "scale_inv_loss": True,
-                    "scale_inv_grad_loss": False,
+                    "log_mse_loss": True,
+                    "hdr_mse_loss": False,
+                    "ldr_mse_loss": False,
+                    "cosine_similarity_loss": False,
+                    "kld_loss": False,
+                    "scale_inv_loss": False,
                 },
             ),
         ),
diff --git a/reni/models/reni_model.py b/reni/models/reni_model.py
index cb397b4..14fe62d 100644
--- a/reni/models/reni_model.py
+++ b/reni/models/reni_model.py
@@ -152,10 +152,7 @@ def get_outputs(self, ray_bundle: RayBundle, batch: Optional[dict] = None):
 
         ray_samples = self.create_ray_samples(ray_bundle.origins, ray_bundle.directions, ray_bundle.camera_indices)
 
-        rotation = None
-        latent_codes = None  # if auto-decoder training regime latents are trainable params of the field
-
-        field_outputs = self.field.forward(ray_samples=ray_samples, rotation=rotation, latent_codes=latent_codes)
+        field_outputs = self.field.forward(ray_samples=ray_samples)
 
         outputs = {
             "rgb": field_outputs[RENIFieldHeadNames.RGB],
diff --git a/reni/pipelines/reni_pipeline.py b/reni/pipelines/reni_pipeline.py
index 6799436..786e65b 100644
--- a/reni/pipelines/reni_pipeline.py
+++ b/reni/pipelines/reni_pipeline.py
@@ -128,7 +128,12 @@ def __init__(
             self._model = typing.cast(Model, DDP(self._model, device_ids=[local_rank], find_unused_parameters=True))
             dist.barrier(device_ids=[local_rank])
 
-        self.last_step_of_eval_optimisation = 0  # used to avoid fitting eval latents on each eval function call below
+        self.step_of_last_latent_optimisation = 0
+
+    def _optimise_evaluation_latents(self, step):
+        if self.step_of_last_latent_optimisation != step:
+            self.model.fit_eval_latents(self.datamanager)
+            self.step_of_last_latent_optimisation = step
 
     def forward(self):
         """Blank forward method
@@ -162,11 +167,7 @@ def get_eval_loss_dict(self, step: int):
             step: current iteration step
         """
         self.eval()
-        # if we haven't already fit the eval latents this step, do it now
-        if self.last_step_of_eval_optimisation != step:
-            if self.model.config.training_regime != "vae":
-                self.model.fit_eval_latents(self.datamanager)
-            self.last_step_of_eval_optimisation = step
+        self._optimise_evaluation_latents(step)
         ray_bundle, batch = self.datamanager.next_eval(step)
         model_outputs = self.model(ray_bundle, batch)
         metrics_dict = self.model.get_metrics_dict(model_outputs, batch)
@@ -184,10 +185,7 @@ def get_eval_image_metrics_and_images(self, step: int):
         """
         self.eval()
         # if we haven't already fit the eval latents this step, do it now
-        if self.last_step_of_eval_optimisation != step:
-            if self.model.config.training_regime != "vae":
-                self.model.fit_eval_latents(self.datamanager)
-            self.last_step_of_eval_optimisation = step
+        self._optimise_evaluation_latents(step)
         image_idx, ray_bundle, batch = self.datamanager.next_eval_image(step)
         outputs = self.model(ray_bundle, batch)
         metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch)
@@ -207,10 +205,7 @@ def get_average_eval_image_metrics(self, step: Optional[int] = None, optimise_la
         """
         self.eval()
         # if we haven't already fit the eval latents this step, do it now
-        if self.last_step_of_eval_optimisation != step:
-            if self.model.config.training_regime != "vae" and optimise_latents:
-                self.model.fit_eval_latents(self.datamanager)
-            self.last_step_of_eval_optimisation = step
+        self._optimise_evaluation_latents(step)
         metrics_dict_list = []
         num_images = len(self.datamanager.fixed_indices_eval_dataloader)
         # get all eval images