
Commit

correct some comments
ashawkey committed Oct 6, 2022
1 parent 3d350ce commit 6f790c0
Showing 4 changed files with 12 additions and 52 deletions.
41 changes: 0 additions & 41 deletions encoding.py
@@ -2,46 +2,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-class FreqEncoder(nn.Module):
-    def __init__(self, input_dim, max_freq_log2, N_freqs,
-                 log_sampling=True, include_input=True,
-                 periodic_fns=(torch.sin, torch.cos)):
-
-        super().__init__()
-
-        self.input_dim = input_dim
-        self.include_input = include_input
-        self.periodic_fns = periodic_fns
-
-        self.output_dim = 0
-        if self.include_input:
-            self.output_dim += self.input_dim
-
-        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)
-
-        if log_sampling:
-            self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
-        else:
-            self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)
-
-        self.freq_bands = self.freq_bands.numpy().tolist()
-
-    def forward(self, input, **kwargs):
-
-        out = []
-        if self.include_input:
-            out.append(input)
-
-        for i in range(len(self.freq_bands)):
-            freq = self.freq_bands[i]
-            for p_fn in self.periodic_fns:
-                out.append(p_fn(input * freq))
-
-        out = torch.cat(out, dim=-1)
-
-
-        return out
-
 def get_encoder(encoding, input_dim=3,
                 multires=6,
                 degree=4,
@@ -52,7 +12,6 @@ def get_encoder(encoding, input_dim=3,
         return lambda x, **kwargs: x, input_dim
 
     elif encoding == 'frequency':
-        #encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)
        from freqencoder import FreqEncoder
        encoder = FreqEncoder(input_dim=input_dim, degree=multires)
 
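
With the pure-PyTorch FreqEncoder removed, the 'frequency' branch relies entirely on the compiled freqencoder extension. A minimal usage sketch, assuming the CUDA extension is built and that get_encoder returns an (encoder, output_dim) pair as its identity branch does:

import torch
from encoding import get_encoder

# 'frequency' now dispatches to the CUDA FreqEncoder; multires sets the number of
# frequency bands (the old max_freq_log2 / N_freqs / log_sampling arguments are gone).
encoder, out_dim = get_encoder('frequency', input_dim=3, multires=6)

x = torch.rand(1024, 3, device='cuda')  # assumed: the extension expects CUDA tensors
feats = encoder(x)                      # [1024, out_dim] sin/cos frequency features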
20 changes: 10 additions & 10 deletions main_nerf.py
@@ -5,17 +5,15 @@
 from nerf.utils import *
 from optimizer import Shampoo
 
-from nerf.sd import StableDiffusion
-from nerf.clip import CLIP
 from nerf.gui import NeRFGUI
 
 # torch.autograd.set_detect_anomaly(True)
 
 if __name__ == '__main__':
 
     parser = argparse.ArgumentParser()
-    parser.add_argument('--text', help="text prompt")
-    parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --preload")
+    parser.add_argument('--text', default=None, help="text prompt")
+    parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --dir_text")
     parser.add_argument('--test', action='store_true', help="test mode")
     parser.add_argument('--workspace', type=str, default='workspace')
     parser.add_argument('--guidance', type=str, default='stable-diffusion', help='choose from [stable-diffusion, clip]')
@@ -31,24 +29,24 @@
     parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
     parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
     parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
-    parser.add_argument('--albedo_iters', type=int, default=15000, help="training iters")
+    parser.add_argument('--albedo_iters', type=int, default=15000, help="training iters that only use albedo shading")
     # model options
     parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
     parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
     # network backbone
     parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
     parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
     # rendering resolution in training
-    parser.add_argument('--w', type=int, default=64, help="render width for CLIP training (<=224)")
-    parser.add_argument('--h', type=int, default=64, help="render height for CLIP training (<=224)")
+    parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
+    parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
 
     ### dataset options
     parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
     parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
     parser.add_argument('--min_near', type=float, default=0.1, help="minimum near distance for camera")
     parser.add_argument('--radius_range', type=float, nargs='*', default=[1.0, 1.5], help="training camera radius range")
     parser.add_argument('--fovy_range', type=float, nargs='*', default=[40, 70], help="training camera fovy range")
-    parser.add_argument('--dir_text', action='store_true', help="direction encoded text prompt")
+    parser.add_argument('--dir_text', action='store_true', help="direction-encode the text prompt, by appending front/side/back/overhead view")
 
     ### GUI options
     parser.add_argument('--gui', action='store_true', help="start a GUI")
@@ -58,7 +56,7 @@
     parser.add_argument('--fovy', type=float, default=60, help="default GUI camera fovy")
     parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction")
     parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction")
-    parser.add_argument('--max_spp', type=int, default=64, help="GUI rendering max sample per pixel")
+    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
 
     opt = parser.parse_args()
 
@@ -87,7 +85,7 @@
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
     if opt.test:
-        guidance = None # do not load guidance at test
+        guidance = None # no need to load guidance model at test
 
         trainer = Trainer('ngp', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)
 
@@ -103,8 +101,10 @@
     else:
 
         if opt.guidance == 'stable-diffusion':
+            from nerf.sd import StableDiffusion
             guidance = StableDiffusion(device)
         elif opt.guidance == 'clip':
+            from nerf.clip import CLIP
             guidance = CLIP(device)
         else:
             raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')
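
The -O shortcut now stands for --fp16 --cuda_ray --dir_text instead of --fp16 --cuda_ray --preload. A hedged sketch of how the shortcut is presumably expanded after parsing (the expansion code itself is not part of this diff):

opt = parser.parse_args()

# assumed expansion of the -O convenience flag, following its updated help string
if opt.O:
    opt.fp16 = True
    opt.cuda_ray = True
    opt.dir_text = True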
1 change: 1 addition & 0 deletions nerf/network_grid.py
@@ -64,6 +64,7 @@ def __init__(self,
         else:
             self.bg_net = None
 
+    # add a density blob to the scene center
     def gaussian(self, x):
         # x: [B, N, 3]
 
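
The newly documented gaussian helper biases density toward the scene center so optimization starts from a centered blob. A standalone sketch of such a blob, with illustrative height/radius constants that are assumptions rather than values from this commit:

import torch

def density_blob(x, height=10.0, radius=0.2):
    # x: [B, N, 3] query points; returns extra density that peaks at the origin
    d2 = (x ** 2).sum(-1)  # squared distance to the scene center
    return height * torch.exp(-d2 / (2 * radius ** 2))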
2 changes: 1 addition & 1 deletion nerf/utils.py
@@ -209,6 +209,7 @@ def __init__(self,
         self.guidance = guidance
 
         if self.guidance is not None:
+            assert ref_text is not None, 'Training must provide a text prompt!'
 
             for p in self.guidance.parameters():
                 p.requires_grad = False
@@ -401,7 +402,6 @@ def eval_step(self, data):
 
         return pred_rgb, pred_depth, loss
 
-    # moved out bg_color and perturb for more flexible control...
    def test_step(self, data, bg_color=None, perturb=False):
        rays_o = data['rays_o'] # [B, N, 3]
        rays_d = data['rays_d'] # [B, N, 3]
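
Since bg_color and perturb are explicit keyword arguments of test_step, callers can override them per call. A hedged usage sketch, assuming trainer, data, and device come from the surrounding training script and that a 3-vector is an acceptable bg_color:

import torch

white_bg = torch.ones(3, device=device)  # assumed background color format
outputs = trainer.test_step(data, bg_color=white_bg, perturb=False)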
