update config
JADGardner committed Nov 27, 2023
1 parent 61ec227 commit 7676392
Showing 5 changed files with 13 additions and 66 deletions.
reni/configs/reni_config.py (6 changes: 3 additions & 3 deletions)

@@ -30,7 +30,7 @@
     max_num_iterations=50001,
     mixed_precision=False,
     pipeline=RENIPipelineConfig(
-        test_mode='test',
+        test_mode='val',
         datamanager=RENIDataManagerConfig(
             dataparser=RENIDataParserConfig(
                 data=Path("data/RENI_HDR"),
@@ -66,7 +66,7 @@
         axis_of_invariance="z",  # Nerfstudio world space is z-up # old reni implementation was y-up
         positional_encoding="NeRF",
         encoded_input="Directions",  # "InvarDirection", "Directions", "Conditioning", "Both", "None"
-        latent_dim=36,  # N for a latent code size of (N x 3) # 9, 36, 49, 100 (for paper sizes)
+        latent_dim=100,  # N for a latent code size of (N x 3) # 9, 36, 49, 100 (for paper sizes)
         hidden_features=128,  # ALL
         hidden_layers=9,  # SIRENs
         mapping_layers=5,  # FiLM MAPPING NETWORK
@@ -77,7 +77,7 @@
         last_layer_linear=True,  # SIRENs
         fixed_decoder=False,  # ALL
         trainable_scale=False,  # Used in inverse setting
-        old_implementation=False,
+        old_implementation=False,  # Used to match prior RENI implementation, world space is y-up
     ),
     eval_latent_optimizer={
         "eval_latents": {
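For context on the latent_dim change: per the inline comment, N gives a latent code of size (N x 3), so this commit switches from the 36 paper size to the 100 paper size. A minimal sketch of what that implies for per-image latent codes, assuming the usual one-code-per-training-image parameterisation (the Parameter shape and image count below are illustrative assumptions, not taken from this diff):

    import torch

    num_images = 1694   # hypothetical dataset size
    latent_dim = 100    # N, as now set in reni_config.py

    # Assumed parameterisation: one (N x 3) latent code per training image.
    latents = torch.nn.Parameter(torch.zeros(num_images, latent_dim, 3))
    print(latents.shape)  # torch.Size([1694, 100, 3])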
reni/data/datamanagers/reni_datamanager.py (58 changes: 7 additions & 51 deletions)

@@ -97,7 +97,6 @@ def __init__(
         self.sampler = None
         self.test_mode = test_mode
         self.test_split = "test" if test_mode in ["test", "inference"] else "val"
-        self.using_scale_inv_grad_loss = kwargs.get("using_scale_inv_grad_loss", False)
         self.dataparser_config = self.config.dataparser
         if self.config.data is not None:
             self.config.dataparser.data = Path(self.config.data)
@@ -221,22 +220,8 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
         assert self.train_pixel_sampler is not None
         assert isinstance(image_batch, dict)
         batch = self.train_pixel_sampler.sample(image_batch)
-        if self.using_scale_inv_grad_loss:
-            finite_diff_grad_indices = batch["indices"].clone()
-            # update the camera index to be sampled_idxs
-            finite_diff_grad_indices[:, 0] = batch["sampled_idxs"].clone()
-            # roll the x values by 1 using datamanager.image_width as the modulus
-            finite_diff_grad_indices[:, 2] = (finite_diff_grad_indices[:, 2] + 1) % self.image_width
-            finite_diff_batch = self.train_pixel_sampler.sample(image_batch, indices=finite_diff_grad_indices)
-            ray_indices = torch.cat([batch["indices"], finite_diff_batch["indices"]], dim=0)  # [2N, 3]
-            for key in batch:
-                batch[key] = torch.stack([batch[key], finite_diff_batch[key]], dim=0)  # [2, N, ...]
-            batch.pop("sampled_idxs")
-            ray_bundle = self.train_ray_generator(ray_indices)  # [2N, 3]
-            ray_bundle = self.stack_ray_bundle(ray_bundle, self.config.train_num_rays_per_batch)  # [2, N, 3]
-        else:
-            ray_indices = batch["indices"]  # [N, 3]
-            ray_bundle = self.train_ray_generator(ray_indices)  # [N, 3]
+        ray_indices = batch["indices"]  # [N, 3]
+        ray_bundle = self.train_ray_generator(ray_indices)  # [N, 3]
         return ray_bundle, batch

     def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
@@ -246,22 +231,8 @@
         assert self.eval_pixel_sampler is not None
         assert isinstance(image_batch, dict)
         batch = self.eval_pixel_sampler.sample(image_batch)
-        if self.using_scale_inv_grad_loss:
-            finite_diff_grad_indices = batch["indices"].clone()
-            # update the camera index to be sampled_idxs
-            finite_diff_grad_indices[:, 0] = batch["sampled_idxs"].clone()
-            # roll the x values by 1 using datamanager.image_width as the modulus
-            finite_diff_grad_indices[:, 2] = (finite_diff_grad_indices[:, 2] + 1) % self.image_width
-            finite_diff_batch = self.eval_pixel_sampler.sample(image_batch, indices=finite_diff_grad_indices)
-            ray_indices = torch.cat([batch["indices"], finite_diff_batch["indices"]], dim=0)  # [2N, 3]
-            for key in batch:
-                batch[key] = torch.stack([batch[key], finite_diff_batch[key]], dim=0)  # [2, N, ...]
-            batch.pop("sampled_idxs")
-            ray_bundle = self.eval_ray_generator(ray_indices)  # [2N, 3]
-            ray_bundle = self.stack_ray_bundle(ray_bundle, self.config.eval_num_rays_per_batch)  # [2, N, 3]
-        else:
-            ray_indices = batch["indices"]
-            ray_bundle = self.eval_ray_generator(ray_indices)
+        ray_indices = batch["indices"]
+        ray_bundle = self.eval_ray_generator(ray_indices)
         return ray_bundle, batch

     def next_eval_image(self, idx: int) -> Tuple[int, RayBundle, Dict]:
@@ -275,24 +246,9 @@
         assert self.eval_image_pixel_sampler is not None
         assert isinstance(image_batch, dict)
         batch = self.eval_image_pixel_sampler.sample(image_batch)
-        if self.using_scale_inv_grad_loss:
-            finite_diff_grad_indices = batch["indices"].clone()
-            # update the camera index to be sampled_idxs
-            finite_diff_grad_indices[:, 0] = batch["sampled_idxs"].clone()
-            # roll the x values by 1 using datamanager.image_width as the modulus
-            finite_diff_grad_indices[:, 2] = (finite_diff_grad_indices[:, 2] + 1) % self.image_width
-            finite_diff_batch = self.eval_image_pixel_sampler.sample(image_batch, indices=finite_diff_grad_indices)
-            ray_indices = torch.cat([batch["indices"], finite_diff_batch["indices"]], dim=0)  # [2N, 3]
-            for key in batch:
-                batch[key] = torch.stack([batch[key], finite_diff_batch[key]], dim=0)  # [2, N, ...]
-            batch.pop("sampled_idxs")
-            ray_bundle = self.eval_ray_generator(ray_indices)  # [2N, 3]
-            ray_bundle = self.stack_ray_bundle(ray_bundle, self.image_height * self.image_width)  # [2, N, 3]
-            image_idx = int(ray_bundle.camera_indices[0, 0, 0])
-        else:
-            ray_indices = batch["indices"]
-            ray_bundle = self.eval_ray_generator(ray_indices)
-            image_idx = int(ray_bundle.camera_indices[0, 0])
+        ray_indices = batch["indices"]
+        ray_bundle = self.eval_ray_generator(ray_indices)
+        image_idx = int(ray_bundle.camera_indices[0, 0])
         return image_idx, ray_bundle, batch

     def create_train_dataset(self) -> RENIDataset:
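For reference, the three deleted branches in next_train, next_eval, and next_eval_image all implemented the same finite-difference indexing for the scale-invariant gradient loss: each sampled pixel was paired with its horizontal neighbour by rolling the column index by one, modulo the image width, so rays could be generated for both pixels of each pair. A standalone sketch of that indexing, with tensor names following the removed code and shapes/ranges chosen for illustration:

    import torch

    N, image_height, image_width = 4, 16, 32

    # Columns are (camera index, row y, column x), as in batch["indices"].
    indices = torch.stack([
        torch.zeros(N, dtype=torch.long),
        torch.randint(0, image_height, (N,)),
        torch.randint(0, image_width, (N,)),
    ], dim=-1)  # [N, 3]

    # Horizontal neighbour: roll x by 1, wrapping with image_width as the modulus.
    finite_diff_grad_indices = indices.clone()
    finite_diff_grad_indices[:, 2] = (finite_diff_grad_indices[:, 2] + 1) % image_width

    # Concatenating gives [2N, 3]; pixel i and its neighbour sit at rows i and i + N.
    ray_indices = torch.cat([indices, finite_diff_grad_indices], dim=0)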
reni/data/dataparsers/reni_inverse_dataparser.py (10 changes: 1 addition & 9 deletions)

@@ -42,8 +42,6 @@ class RENIInverseDataParserConfig(DataParserConfig):
"""target class to instantiate"""
data: Path = Path("data/RENI_HDR")
"""Directory specifying location of data."""
download_data: bool = False
"""Whether to download data."""
envmap_remove_indicies: Optional[list] = None
"""Indicies of environment maps to remove."""
specular_terms = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
@@ -68,13 +66,7 @@ def __init__(self, config: RENIInverseDataParserConfig):
     def _generate_dataparser_outputs(self, split="train"):
         path = self.data / split

-        # if it doesn't exist, download the data
-        url = "https://www.dropbox.com/s/15gn7zlzgua7s8n/RENI_HDR.zip?dl=1"
-        if not path.exists() and self.config.download_data:
-            wget.download(url, out=str(self.data) + ".zip")
-            with zipfile.ZipFile(str(self.data) + ".zip", "r") as zip_ref:
-                zip_ref.extractall(str(self.data))
-            Path(str(self.data) + ".zip").unlink()
+        assert path.exists(), f"Path {path} does not exist."

         # get paths for all images in the directory
         environment_maps_filenames = sorted(path.glob("*.exr"))
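With the download branch gone, data/RENI_HDR must exist before _generate_dataparser_outputs runs or the new assert fires. A one-off fetch script under the same assumptions as the removed code (same Dropbox URL, same wget and zipfile approach):

    import zipfile
    from pathlib import Path

    import wget  # pip install wget

    data = Path("data/RENI_HDR")
    url = "https://www.dropbox.com/s/15gn7zlzgua7s8n/RENI_HDR.zip?dl=1"

    if not data.exists():
        archive = str(data) + ".zip"
        wget.download(url, out=archive)  # download the zip next to the target directory
        with zipfile.ZipFile(archive, "r") as zip_ref:
            zip_ref.extractall(str(data))
        Path(archive).unlink()  # tidy up the archive after extraction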
reni/data/datasets/reni_dataset.py (2 changes: 2 additions & 0 deletions)

@@ -119,6 +119,8 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]:
         num_cols_to_roll = int(np.round(img_width * angle_rad / (2 * np.pi)))
         image = np.roll(image, -num_cols_to_roll, axis=1)

+        # only use the first 3 channels
+        image = image[:, :, :3]

         assert np.all(np.isfinite(image)), "Image contains non finite values."
         assert np.all(image >= 0), "Image contains negative values."
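The added slice guards the channel count: HDR .exr environment maps can carry a fourth (alpha) channel, and image[:, :, :3] keeps only RGB before the finiteness and non-negativity asserts run. A tiny illustration with a stand-in array (the RGBA input below is hypothetical, not from the repo):

    import numpy as np

    rgba = np.random.rand(16, 32, 4).astype(np.float32)  # stand-in H x W x 4 envmap
    rgb = rgba[:, :, :3]  # only use the first 3 channels
    assert rgb.shape == (16, 32, 3)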
reni/pipelines/reni_pipeline.py (3 changes: 0 additions & 3 deletions)

@@ -95,14 +95,11 @@ def __init__(
         self.config = config
         self.test_mode = test_mode if self.config.test_mode is None else self.config.test_mode

-        self.using_scale_inv_grad_loss = self.config.model.loss_inclusions["scale_inv_grad_loss"]
-
         self.datamanager: RENIDataManager = config.datamanager.setup(
             device=device,
             test_mode=self.test_mode,
             world_size=world_size,
             local_rank=local_rank,
-            using_scale_inv_grad_loss=self.using_scale_inv_grad_loss,
         )
         self.datamanager.to(device)
         assert self.datamanager.train_dataset is not None, "Missing input dataset"
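After this change the datamanager is no longer told about the scale-invariant gradient loss at setup time. Should anything still need the flag, it remains reachable through the model config; a hedged sketch (the helper below is hypothetical, mirroring only the deleted lookup):

    # Hypothetical consumer reading the flag from the model config, mirroring the
    # deleted self.config.model.loss_inclusions["scale_inv_grad_loss"] line.
    def wants_scale_inv_grad_loss(config) -> bool:
        return bool(config.model.loss_inclusions["scale_inv_grad_loss"])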
