Commit cb850b0

add uniform sphere camera sampling from https://github.com/tryba/stab…
ashawkey committed Nov 4, 2022
1 parent f70f214 commit cb850b0
Showing 2 changed files with 51 additions and 23 deletions.
main.py (17 changes: 9 additions & 8 deletions)
@@ -34,6 +34,7 @@
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
parser.add_argument('--albedo_iters', type=int, default=1000, help="training iters that only use albedo shading")
parser.add_argument('--uniform_sphere_rate', type=float, default=0.5, help="probability of sampling the camera location uniformly on the sphere surface (instead of uniformly in angles)")
# model options
parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
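
A note on why the new option exists: sampling theta and phi uniformly (the default branch kept in rand_poses below) is not uniform over the sphere, because the surface area element is sin(theta) dtheta dphi, so camera poses cluster near the poles. A quick numeric illustration (ours, not part of the commit):

import numpy as np

theta = np.random.rand(100_000) * np.pi    # uniform polar angle, as in the angular sampler
y = np.cos(theta)                          # camera height on the unit sphere
# truly uniform sphere sampling would give ~0.25 per bin; uniform theta gives
# roughly [0.33, 0.17, 0.17, 0.33], i.e. poses piled up near the poles
print(np.histogram(y, bins=4, range=(-1, 1))[0] / 100_000)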
@@ -125,6 +126,14 @@

else:

train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()

optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
# optimizer = lambda model: Shampoo(model.get_params(opt.lr))

scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
# scheduler = lambda optimizer: optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=opt.iters, pct_start=0.1)

if opt.guidance == 'stable-diffusion':
from nerf.sd import StableDiffusion
guidance = StableDiffusion(device)
@@ -134,14 +143,6 @@
else:
raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')

optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
# optimizer = lambda model: Shampoo(model.get_params(opt.lr))

train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()

scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
# scheduler = lambda optimizer: optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=opt.iters, pct_start=0.1)

trainer = Trainer('df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True)

if opt.gui:
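The block above is only moved: the dataloader, optimizer, and scheduler are now built before the guidance model. For reference, the LambdaLR lambda scales the base learning rate by 0.1 ** min(iter / opt.iters, 1), a smooth tenfold decay over training that is then held. A quick check (illustrative only):

lr0, iters = 1e-3, 10_000
for it in (0, iters // 2, iters, 2 * iters):
    print(it, lr0 * 0.1 ** min(it / iters, 1))
# 0 -> 1.0e-3, 5000 -> ~3.2e-4, 10000 -> 1.0e-4, 20000 -> 1.0e-4 (held)
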
nerf/provider.py (57 changes: 42 additions & 15 deletions)
@@ -10,18 +10,28 @@
import trimesh

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from .utils import get_rays, safe_normalize

def visualize_poses(poses, size=0.1):
# poses: [B, 4, 4]
DIR_COLORS = np.array([
[255, 0, 0, 255], # front
[0, 255, 0, 255], # side
[0, 0, 255, 255], # back
[255, 255, 0, 255], # side
[255, 0, 255, 255], # overhead
[0, 255, 255, 255], # bottom
], dtype=np.uint8)

def visualize_poses(poses, dirs, size=0.1):
# poses: [B, 4, 4], dirs: [B]

axes = trimesh.creation.axis(axis_length=4)
sphere = trimesh.creation.icosphere(radius=1)
objects = [axes, sphere]

for pose in poses:
for pose, dir in zip(poses, dirs):
# a camera is visualized with 8 line segments.
pos = pose[:3, 3]
a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
@@ -31,6 +41,10 @@ def visualize_poses(poses, size=0.1):

segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]])
segs = trimesh.load_path(segs)

# different color for different dirs
segs.colors = DIR_COLORS[[dir]].repeat(len(segs.entities), 0)

objects.append(segs)

trimesh.Scene(objects).show()
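
The list index in DIR_COLORS[[dir]] above is deliberate: plain integer indexing would drop the leading axis, while the list form keeps it, so repeat can tile one RGBA row per line segment. A minimal shape check (illustrative):

import numpy as np

DIR_COLORS = np.array([[255, 0, 0, 255], [0, 255, 0, 255]], dtype=np.uint8)
dir = 1
print(DIR_COLORS[dir].shape)                 # (4,)   integer index drops the axis
print(DIR_COLORS[[dir]].shape)               # (1, 4) list index keeps it
print(DIR_COLORS[[dir]].repeat(8, 0).shape)  # (8, 4) one color per camera segment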
@@ -55,7 +69,7 @@ def get_view_direction(thetas, phis, overhead, front):
return res


def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 100], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, jitter=False):
def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 100], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, jitter=False, uniform_sphere_rate=0.5):
''' generate random poses from an orbit camera
Args:
size: batch size of generated poses.
Expand All @@ -73,14 +87,28 @@ def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 100], phi_ra
angle_front = np.deg2rad(angle_front)

radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]

centers = torch.stack([
radius * torch.sin(thetas) * torch.sin(phis),
radius * torch.cos(thetas),
radius * torch.sin(thetas) * torch.cos(phis),
], dim=-1) # [B, 3]
if random.random() < uniform_sphere_rate:
unit_centers = F.normalize(
torch.stack([
(torch.rand(size, device=device) - 0.5) * 2.0,
torch.rand(size, device=device),
(torch.rand(size, device=device) - 0.5) * 2.0,
], dim=-1), p=2, dim=1
)
thetas = torch.acos(unit_centers[:,1])
phis = torch.atan2(unit_centers[:,0], unit_centers[:,2])
phis[phis < 0] += 2 * np.pi
centers = unit_centers * radius.unsqueeze(-1)
else:
thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]

centers = torch.stack([
radius * torch.sin(thetas) * torch.sin(phis),
radius * torch.cos(thetas),
radius * torch.sin(thetas) * torch.cos(phis),
], dim=-1) # [B, 3]

targets = 0
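
Two details of the new branch are worth noting. First, the y component is drawn from torch.rand, so it is non-negative and the sampled cameras stay on the upper hemisphere. Second, normalizing points drawn uniformly from a box is only approximately uniform over the sphere (directions toward the box corners are slightly over-represented); an exactly uniform variant would normalize Gaussian samples instead. A sketch of that alternative (ours, not what this commit implements):

import torch
import torch.nn.functional as F

def uniform_hemisphere(size, device='cpu'):
    vec = torch.randn(size, 3, device=device)  # isotropic Gaussian: normalized direction is uniform on the sphere
    vec[:, 1] = vec[:, 1].abs()                # fold onto the upper hemisphere (y >= 0)
    return F.normalize(vec, p=2, dim=1)        # unit camera directions

Either way, calling rand_poses(..., uniform_sphere_rate=1) forces this branch on every batch, which is what the updated debug snippet below uses for visual sanity checks.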

@@ -167,8 +195,8 @@ def __init__(self, opt, device, type='train', H=256, W=256, size=100):
self.cy = self.W / 2

# [debug] visualize poses
# poses, dirs = rand_poses(100, self.device, return_dirs=self.opt.dir_text, radius_range=self.radius_range)
# visualize_poses(poses.detach().cpu().numpy())
# poses, dirs = rand_poses(100, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose, uniform_sphere_rate=1)
# visualize_poses(poses.detach().cpu().numpy(), dirs.detach().cpu().numpy())


def collate(self, index):
@@ -177,7 +205,7 @@ def collate(self, index):

if self.training:
# random pose on the fly
poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose)
poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose, uniform_sphere_rate=self.opt.uniform_sphere_rate)

# random focal
fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
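
fov is sampled per batch in degrees from fovy_range; the intrinsics then presumably follow the standard pinhole relation f = H / (2 tan(fov / 2)). A sketch of that relation (the exact conversion in the rest of provider.py is assumed, not shown in this diff):

import numpy as np

def focal_from_fov(H, fov_deg):
    # vertical focal length in pixels for a pinhole camera
    return H / (2 * np.tan(np.deg2rad(fov_deg) / 2))

print(focal_from_fov(256, 60))  # ~221.7 px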
@@ -210,5 +238,4 @@ def collate(self, index):

def dataloader(self):
loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
loader._data = self # an ugly fix... we need to access dataset in trainer.
return loader
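
Attaching loader._data works because DataLoader only guards a handful of attributes (batch_size, sampler, and so on) after construction, but it is fragile, as the comment admits. A tidier sketch (ours, not what the repo does) is a thin subclass that carries the provider explicitly:

from torch.utils.data import DataLoader

class ProviderLoader(DataLoader):
    """DataLoader that keeps an explicit reference to its pose provider."""
    def __init__(self, provider, **kwargs):
        super().__init__(list(range(provider.size)), **kwargs)
        self.provider = provider  # the trainer reads loader.provider instead of loader._data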
