diff --git a/main.py b/main.py
index 12aae691..5dd9a95b 100644
--- a/main.py
+++ b/main.py
@@ -34,6 +34,7 @@
 parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
 parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
 parser.add_argument('--albedo_iters', type=int, default=1000, help="training iters that only use albedo shading")
+parser.add_argument('--uniform_sphere_rate', type=float, default=0.5, help="likelihood of sampling camera location uniformly on the sphere surface area")
 # model options
 parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
 parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
@@ -125,6 +126,14 @@ else:
 
+    train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()
+
+    optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+    # optimizer = lambda model: Shampoo(model.get_params(opt.lr))
+
+    scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
+    # scheduler = lambda optimizer: optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=opt.iters, pct_start=0.1)
+
     if opt.guidance == 'stable-diffusion':
         from nerf.sd import StableDiffusion
         guidance = StableDiffusion(device)
@@ -134,14 +143,6 @@ else:
         raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')
 
-    optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
-    # optimizer = lambda model: Shampoo(model.get_params(opt.lr))
-
-    train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()
-
-    scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
-    # scheduler = lambda optimizer: optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=opt.iters, pct_start=0.1)
-
     trainer = Trainer('df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True)
 
     if opt.gui:
diff --git a/nerf/provider.py b/nerf/provider.py
index 71d5d768..2cc8cbd5 100644
--- a/nerf/provider.py
+++ b/nerf/provider.py
@@ -10,18 +10,28 @@ import trimesh
 
 import torch
+import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
 from .utils import get_rays, safe_normalize
 
-def visualize_poses(poses, size=0.1):
-    # poses: [B, 4, 4]
+DIR_COLORS = np.array([
+    [255, 0, 0, 255], # front
+    [0, 255, 0, 255], # side
+    [0, 0, 255, 255], # back
+    [255, 255, 0, 255], # side
+    [255, 0, 255, 255], # overhead
+    [0, 255, 255, 255], # bottom
+], dtype=np.uint8)
+
+def visualize_poses(poses, dirs, size=0.1):
+    # poses: [B, 4, 4], dirs: [B]
 
     axes = trimesh.creation.axis(axis_length=4)
     sphere = trimesh.creation.icosphere(radius=1)
     objects = [axes, sphere]
 
-    for pose in poses:
+    for pose, dir in zip(poses, dirs):
         # a camera is visualized with 8 line segments.
         pos = pose[:3, 3]
         a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
@@ -31,6 +41,10 @@ def visualize_poses(poses, size=0.1):
 
         segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]])
         segs = trimesh.load_path(segs)
+
+        # different color for different dirs
+        segs.colors = DIR_COLORS[[dir]].repeat(len(segs.entities), 0)
+
         objects.append(segs)
 
     trimesh.Scene(objects).show()
@@ -55,7 +69,7 @@ def get_view_direction(thetas, phis, overhead, front):
     return res
 
 
-def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 100], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, jitter=False):
+def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 100], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, jitter=False, uniform_sphere_rate=0.5):
    ''' generate random poses from an orbit camera
    Args:
        size: batch size of generated poses.
@@ -73,14 +87,28 @@ def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 100], phi_ra
     angle_front = np.deg2rad(angle_front)
 
     radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
-    thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
-    phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
-    centers = torch.stack([
-        radius * torch.sin(thetas) * torch.sin(phis),
-        radius * torch.cos(thetas),
-        radius * torch.sin(thetas) * torch.cos(phis),
-    ], dim=-1) # [B, 3]
+    if random.random() < uniform_sphere_rate:
+        unit_centers = F.normalize(
+            torch.stack([
+                (torch.rand(size, device=device) - 0.5) * 2.0,
+                torch.rand(size, device=device),
+                (torch.rand(size, device=device) - 0.5) * 2.0,
+            ], dim=-1), p=2, dim=1
+        )
+        thetas = torch.acos(unit_centers[:,1])
+        phis = torch.atan2(unit_centers[:,0], unit_centers[:,2])
+        phis[phis < 0] += 2 * np.pi
+        centers = unit_centers * radius.unsqueeze(-1)
+    else:
+        thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
+        phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
+
+        centers = torch.stack([
+            radius * torch.sin(thetas) * torch.sin(phis),
+            radius * torch.cos(thetas),
+            radius * torch.sin(thetas) * torch.cos(phis),
+        ], dim=-1) # [B, 3]
 
     targets = 0
 
@@ -167,8 +195,8 @@ def __init__(self, opt, device, type='train', H=256, W=256, size=100):
         self.cy = self.W / 2
 
         # [debug] visualize poses
-        # poses, dirs = rand_poses(100, self.device, return_dirs=self.opt.dir_text, radius_range=self.radius_range)
-        # visualize_poses(poses.detach().cpu().numpy())
+        # poses, dirs = rand_poses(100, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose, uniform_sphere_rate=1)
+        # visualize_poses(poses.detach().cpu().numpy(), dirs.detach().cpu().numpy())
 
 
     def collate(self, index):
@@ -177,7 +205,7 @@ def collate(self, index):
 
         if self.training:
             # random pose on the fly
-            poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose)
+            poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose, uniform_sphere_rate=self.opt.uniform_sphere_rate)
 
             # random focal
             fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
@@ -210,5 +238,4 @@ def collate(self, index):
 
     def dataloader(self):
         loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
-        loader._data = self # an ugly fix... we need to access dataset in trainer.
         return loader
\ No newline at end of file
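
A note on the relocated optimizer/scheduler block in main.py: only its position changes (it now precedes the guidance setup, so `train_loader`, `optimizer`, and `scheduler` are defined together); the `LambdaLR` schedule itself is untouched. Its lambda decays the learning-rate multiplier exponentially from 1.0 at iteration 0 to 0.1 at `opt.iters`, clamped thereafter. A quick standalone check of that lambda, with a hypothetical `iters` standing in for `opt.iters`:

```python
# Standalone sketch (not part of the patch): behavior of the LambdaLR lambda.
iters = 10000  # stand-in for opt.iters

factor = lambda it: 0.1 ** min(it / iters, 1)

assert factor(0) == 1.0                               # full learning rate at the start
assert abs(factor(iters // 2) - 0.1 ** 0.5) < 1e-12   # smooth exponential decay
assert factor(iters) == factor(2 * iters) == 0.1      # clamped at 0.1x afterwards
```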
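On the new `uniform_sphere_rate` branch in `rand_poses`: it normalizes points drawn uniformly from the box [-1, 1] x [0, 1] x [-1, 1], which confines cameras to the upper hemisphere (y is up) but is only approximately uniform over the sphere surface, with density mildly biased toward the box diagonals. The sketch below, a standalone illustration rather than part of the patch, mirrors that branch and adds an exactly-uniform alternative via normalized Gaussian samples:

```python
# Standalone sketch (not part of the patch).
import numpy as np
import torch
import torch.nn.functional as F

size, device = 1024, 'cpu'

# Patch's approach: normalize box-uniform samples with y >= 0 (upper hemisphere).
# Approximately uniform; slightly denser toward the box diagonals.
unit_centers = F.normalize(
    torch.stack([
        (torch.rand(size, device=device) - 0.5) * 2.0,
        torch.rand(size, device=device),
        (torch.rand(size, device=device) - 0.5) * 2.0,
    ], dim=-1), p=2, dim=1
)

# Exactly uniform on the upper hemisphere: normalize isotropic Gaussian samples
# and reflect negative y into the upper half.
g = torch.randn(size, 3, device=device)
g[:, 1] = g[:, 1].abs()
uniform_centers = F.normalize(g, p=2, dim=1)

# Angle recovery, same convention as the patch: y is up, so the polar angle
# comes from y and the azimuth from (x, z).
thetas = torch.acos(unit_centers[:, 1])     # in [0, pi/2] on the hemisphere
phis = torch.atan2(unit_centers[:, 0], unit_centers[:, 2])
phis[phis < 0] += 2 * np.pi                 # map to [0, 2*pi)
```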
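The `segs.colors` assignment in `visualize_poses` relies on `DIR_COLORS[[dir]]` keeping a leading length-1 axis, so `.repeat(len(segs.entities), 0)` tiles one RGBA row per line segment (eight per camera, per the comment in the loop). A minimal standalone check of that indexing, reusing the patch's table:

```python
import numpy as np

# Same table as the patch; row order follows its comments
# (0 front, 1 side, 2 back, 3 side, 4 overhead, 5 bottom).
DIR_COLORS = np.array([
    [255, 0, 0, 255],    # front
    [0, 255, 0, 255],    # side
    [0, 0, 255, 255],    # back
    [255, 255, 0, 255],  # side
    [255, 0, 255, 255],  # overhead
    [0, 255, 255, 255],  # bottom
], dtype=np.uint8)

d = 2  # e.g. a 'back'-facing pose
colors = DIR_COLORS[[d]].repeat(8, 0)  # (8, 4): one RGBA row per segment
assert colors.shape == (8, 4) and (colors == DIR_COLORS[d]).all()
```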