Skip to content

Commit

Permalink
update to handle trevi fountain
Browse files Browse the repository at this point in the history
  • Loading branch information
JADGardner committed Nov 15, 2023
1 parent c923b2e commit 1777e72
Show file tree
Hide file tree
Showing 11 changed files with 26,940 additions and 155 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,5 @@ camera_paths/
*/**/.DS_Store
*/**/._.DS_Store
/reni_neus.egg-info
/checkpoints/
/checkpoints/
models/
353 changes: 271 additions & 82 deletions notebooks/test.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ns_reni
26,441 changes: 26,441 additions & 0 deletions publication/figures_and_tables.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies=[
"ftfy",
"regex",
"torchtyping",
"kaleido",
]

[tool.setuptools.packages.find]
Expand Down
14 changes: 9 additions & 5 deletions reni_neus/configs/reni_neus_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,19 @@
test_mode='test',
datamanager=RENINeuSDataManagerConfig(
dataparser=NeRFOSRCityScapesDataParserConfig(
scene="site1",
scene="trevi",
auto_scale_poses=True,
crop_to_equal_size=True,
crop_to_equal_size=False,
pad_to_equal_size=False,
scene_scale=1.0, # AABB
mask_vegetation=False,
mask_vegetation=True,
mask_out_of_view_frustum_objects=True,
# session_holdout_indices=[0, 0, 0, 0, 0], # site 1
# session_holdout_indices=[3, 13, 2, 7, 9], # site 2
# session_holdout_indices=[19, 15, 16, 13, 11], # site 3 # potentially [19, 15, 16, 13, 15]
# session_holdout_indices=[0, 0, 0, 0, 0], # trevi
),
train_num_images_to_sample_from=1,
train_num_images_to_sample_from=-1,
train_num_times_to_repeat_images=-1, # number of iterations before resampling a new subset
pixel_sampler=RENINeuSPixelSamplerConfig(),
images_on_gpu=False,
Expand Down Expand Up @@ -145,7 +149,7 @@
),
},
},
eval_latent_optimise_method="nerf_osr_holdout", # per_image, nerf_osr_holdout, nerf_osr_envmap
eval_latent_optimise_method="per_image", # per_image, nerf_osr_holdout, nerf_osr_envmap (can't run nerf_osr with trevi)
eval_latent_sample_region="full_image",
illumination_field_ckpt_path=Path("outputs/reni/reni_plus_plus_models/latent_dim_100/"),
illumination_field_ckpt_step=50000,
Expand Down
56 changes: 31 additions & 25 deletions reni_neus/data/datamanagers/reni_neus_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def __init__(
# TODO: This is a mess; only one of test or val can be in use at a time, so a single variable would suffice.
# This will need updating in the pipeline and model too.
if self.eval_latent_optimise_method == "per_image":
self.num_test = len(test_outputs.image_filenames)
self.num_val = len(val_outputs.image_filenames)
self.num_test = len(self.test_outputs.image_filenames)
self.num_val = len(self.val_outputs.image_filenames)
else:
self.num_test = len(self.eval_dataset.metadata["session_to_indices"].keys())
self.num_val = len(self.eval_dataset.metadata["session_to_indices"].keys())
Expand Down Expand Up @@ -143,12 +143,12 @@ def create_train_dataset(self) -> RENINeuSDataset:
)

def create_eval_dataset(self) -> RENINeuSDataset:
test_outputs = self.dataparser.get_dataparser_outputs("test")
val_outputs = self.dataparser.get_dataparser_outputs("val")
self.test_outputs = self.dataparser.get_dataparser_outputs("test")
self.val_outputs = self.dataparser.get_dataparser_outputs("val")
# self.num_test = len(test_outputs.image_filenames)
# self.num_val = len(val_outputs.image_filenames)
return RENINeuSDataset(
dataparser_outputs=test_outputs if self.test_mode == "test" else val_outputs,
dataparser_outputs=self.test_outputs if self.test_mode == "test" else self.val_outputs,
scale_factor=self.config.camera_res_scale_factor,
split=self.test_split,
)
Expand All @@ -171,6 +171,12 @@ def setup_eval(self):
self.iter_eval_image_dataloader = iter(self.eval_image_dataloader)
self.eval_pixel_sampler = self._get_pixel_sampler(self.eval_dataset, self.config.eval_num_rays_per_batch)
self.eval_ray_generator = RayGenerator(self.eval_dataset.cameras.to(self.device))

self.eval_dataloader = RandIndicesEvalDataloader(
input_dataset=self.eval_dataset,
device=self.device,
num_workers=self.world_size * 4,
)
else:
### This is for NeRF-OSR relighting benchmark ###
session_image_idxs = self.eval_dataset.metadata["session_holdout_indices"] # idx of holdout relative to session
Expand All @@ -195,7 +201,7 @@ def setup_eval(self):
self.iter_eval_session_holdout_dataloader = iter(self.eval_session_holdout_dataloader)
# image_idxs_eval = [x for x in range(len(self.eval_dataset))]
# image_idxs_eval = [idx for idx in image_idxs_eval if idx not in image_idxs_holdout]
image_idxs_eval = self.eval_dataset.test_eval_mask_dict.keys()
image_idxs_eval = list(self.eval_dataset.test_eval_mask_dict.keys())
self.eval_session_compare_dataloader = SelectedIndicesCacheDataloader(
self.eval_dataset,
num_images_to_sample_from=self.config.eval_num_images_to_sample_from,
Expand All @@ -209,21 +215,13 @@ def setup_eval(self):
)
self.iter_eval_session_compare_dataloader = iter(self.eval_session_compare_dataloader)

# full images
if self.eval_latent_optimise_method == "per_image":
self.eval_dataloader = RandIndicesEvalDataloader(
input_dataset=self.eval_dataset,
device=self.device,
num_workers=self.world_size * 4,
)
else:
self.eval_dataloader = FixedIndicesEvalDataloader(
input_dataset=self.eval_dataset,
image_indices=tuple(image_idxs_eval),
device=self.device,
num_workers=self.world_size * 4,
)
self.iter_eval_dataloader = iter(self.eval_dataloader)
self.eval_dataloader = FixedIndicesEvalDataloader(
input_dataset=self.eval_dataset,
image_indices=tuple(image_idxs_eval),
device=self.device,
num_workers=self.world_size * 4,
)
self.iter_eval_dataloader = iter(self.eval_dataloader)


def next_eval_image(self, step: int) -> Tuple[int, RayBundle, Dict]:
Expand Down Expand Up @@ -254,7 +252,10 @@ def get_sky_ray_bundle(self, number_of_rays: int) -> Tuple[RayBundle, Dict]:
# choose random
image_batch = next(self.iter_train_image_dataloader)
assert self.train_pixel_sampler is not None
batch = self.train_pixel_sampler.collate_sky_ray_batch(image_batch, num_rays_per_batch=number_of_rays)
if isinstance(image_batch["image"], list):
batch = self.train_pixel_sampler.collate_sky_ray_batch_list(image_batch, num_rays_per_batch=number_of_rays)
else:
batch = self.train_pixel_sampler.collate_sky_ray_batch(image_batch, num_rays_per_batch=number_of_rays)
ray_indices = batch["indices"].cpu()
ray_bundle = self.train_ray_generator(ray_indices)
return ray_bundle
Expand All @@ -266,9 +267,14 @@ def get_eval_image_half_bundle(
image_batch = next(self.iter_eval_image_dataloader)
assert self.eval_pixel_sampler is not None
assert isinstance(image_batch, dict)
batch = self.eval_pixel_sampler.collate_image_half(
batch=image_batch, num_rays_per_batch=self.config.eval_num_rays_per_batch, sample_region=sample_region
)
if isinstance(image_batch["image"], list):
batch = self.eval_pixel_sampler.collate_image_half_list(
batch=image_batch, num_rays_per_batch=self.config.eval_num_rays_per_batch, sample_region=sample_region
)
else:
batch = self.eval_pixel_sampler.collate_image_half(
batch=image_batch, num_rays_per_batch=self.config.eval_num_rays_per_batch, sample_region=sample_region
)
ray_indices = batch["indices"].cpu()
ray_bundle = self.eval_ray_generator(ray_indices)
return ray_bundle, batch
Expand Down
85 changes: 45 additions & 40 deletions reni_neus/data/dataparsers/nerfosr_cityscapes_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,48 +313,53 @@ def _generate_dataparser_outputs(self, split="train"):
)

# load a single envmap to get its size
envmap = Image.open(envmap_filenames[0])
envmap_width, envmap_height = envmap.size
envmap_cameras = None
if len(envmap_filenames) > 0:
envmap = Image.open(envmap_filenames[0])
envmap_width, envmap_height = envmap.size

c2w = torch.tensor([[[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]]], dtype=torch.float32).repeat(
len(envmap_filenames), 1, 1
)
c2w = torch.tensor([[[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]]], dtype=torch.float32).repeat(
len(envmap_filenames), 1, 1
)

envmap_cameras = Cameras(
fx=torch.tensor(envmap_height, dtype=torch.float32).repeat(len(envmap_filenames)),
fy=torch.tensor(envmap_height, dtype=torch.float32).repeat(len(envmap_filenames)),
cx=torch.tensor(envmap_width // 2, dtype=torch.float32).repeat(len(envmap_filenames)),
cy=torch.tensor(envmap_height // 2, dtype=torch.float32).repeat(len(envmap_filenames)),
camera_to_worlds=c2w,
camera_type=CameraType.EQUIRECTANGULAR,
)
envmap_cameras = Cameras(
fx=torch.tensor(envmap_height, dtype=torch.float32).repeat(len(envmap_filenames)),
fy=torch.tensor(envmap_height, dtype=torch.float32).repeat(len(envmap_filenames)),
cx=torch.tensor(envmap_width // 2, dtype=torch.float32).repeat(len(envmap_filenames)),
cy=torch.tensor(envmap_height // 2, dtype=torch.float32).repeat(len(envmap_filenames)),
camera_to_worlds=c2w,
camera_type=CameraType.EQUIRECTANGULAR,
)

# --- session IDs ---
# names of sessions are the folders within scene_dir/ENV_MAP_CC
sessions = [os.path.basename(x) for x in glob.glob(f"{scene_dir}/ENV_MAP_CC/*")]
session_to_indices = defaultdict(list)

for idx, filename in enumerate(image_filenames):
# if filename contains a session name, use that as the session ID
# if no match just skip so as to not have sessions with no images
for session in sessions:
if session in filename:
session_to_indices[session].append(int(idx))

# update keys from strings to integers from 0 to len(session_to_indices) - 1
session_to_indices = {i: session_to_indices[k] for i, k in enumerate(session_to_indices.keys())}

# also create mapping from indices to sessions
indices_to_session = defaultdict(list)
for session_idx, indices in session_to_indices.items():
for idx in indices:
indices_to_session[idx] = session_idx

if split in ["validation", "test"]:
session_to_indices = dict(session_to_indices)
assert len(self.config.session_holdout_indices) == len(
session_to_indices
), "number of relative eval indicies must match number of unique sessions"
session_to_indices = None
indices_to_session = None
if scene != "trevi":
# --- session IDs ---
# names of sessions are the folders within scene_dir/ENV_MAP_CC
sessions = [os.path.basename(x) for x in glob.glob(f"{scene_dir}/ENV_MAP_CC/*")]
session_to_indices = defaultdict(list)

for idx, filename in enumerate(image_filenames):
# if filename contains a session name, use that as the session ID
# if no match just skip so as to not have sessions with no images
for session in sessions:
if session in filename:
session_to_indices[session].append(int(idx))

# update keys from strings to integers from 0 to len(session_to_indices) - 1
session_to_indices = {i: session_to_indices[k] for i, k in enumerate(session_to_indices.keys())}

# also create mapping from indices to sessions
indices_to_session = defaultdict(list)
for session_idx, indices in session_to_indices.items():
for idx in indices:
indices_to_session[idx] = session_idx

if split in ["validation", "test"]:
session_to_indices = dict(session_to_indices)
assert len(self.config.session_holdout_indices) == len(
session_to_indices
), "number of relative eval indicies must match number of unique sessions"

# --- masks ---
mask_filenames = None
Expand Down Expand Up @@ -407,7 +412,7 @@ def _generate_dataparser_outputs(self, split="train"):


test_eval_mask_dict = {}
if split == 'test':
if split == 'test' and scene != "trevi":
def get_filename_without_extension(path):
return path.split('/')[-1].split('.')[0]
test_eval_mask_filenames = _find_files(f"{split_dir}/mask", exts=["*.png", "*.jpg", "*.JPG", "*.PNG"])
Expand Down
3 changes: 2 additions & 1 deletion reni_neus/data/datasets/reni_neus_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float =

self.metadata["c2w"] = dataparser_outputs.cameras.camera_to_worlds
self.envmap_cameras = deepcopy(self.metadata["envmap_cameras"])
self.metadata["num_sessions"] = len(dataparser_outputs.metadata["session_to_indices"])
if dataparser_outputs.metadata["session_to_indices"] is not None:
self.metadata["num_sessions"] = len(dataparser_outputs.metadata["session_to_indices"])
self.test_eval_mask_dict = dataparser_outputs.metadata["test_eval_mask_dict"]
self.out_of_view_frustum_objects_masks = dataparser_outputs.metadata["out_of_view_frustum_objects_masks"]
self.split = split
Expand Down
Loading

0 comments on commit 1777e72

Please sign in to comment.