Skip to content

Commit

Permalink
support negative text prompt, change to stable-diffusion v1.5
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Oct 22, 2022
1 parent 47d0083 commit 5d9dde6
Show file tree
Hide file tree
Showing 11 changed files with 99 additions and 211 deletions.
4 changes: 1 addition & 3 deletions gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
# network backbone
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, vanilla]")
# rendering resolution in training, decrease this if CUDA OOM.
parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
Expand Down Expand Up @@ -78,8 +78,6 @@

if opt.backbone == 'vanilla':
from nerf.network import NeRFNetwork
elif opt.backbone == 'tcnn':
from nerf.network_tcnn import NeRFNetwork
elif opt.backbone == 'grid':
from nerf.network_grid import NeRFNetwork
else:
Expand Down
12 changes: 8 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

parser = argparse.ArgumentParser()
parser.add_argument('--text', default=None, help="text prompt")
parser.add_argument('--negative', default='', type=str, help="negative text prompt")
parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --dir_text")
parser.add_argument('-O2', action='store_true', help="equals --fp16 --dir_text")
parser.add_argument('--test', action='store_true', help="test mode")
Expand All @@ -38,7 +39,7 @@
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
# network backbone
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, vanilla]")
# rendering resolution in training, decrease this if CUDA OOM.
parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
Expand All @@ -51,12 +52,14 @@
parser.add_argument('--radius_range', type=float, nargs='*', default=[1.0, 1.5], help="training camera radius range")
parser.add_argument('--fovy_range', type=float, nargs='*', default=[40, 70], help="training camera fovy range")
parser.add_argument('--dir_text', action='store_true', help="direction-encode the text prompt, by appending front/side/back/overhead view")
parser.add_argument('--negative_dir_text', action='store_true', help="also use negative dir text prompt.")
parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region")
parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")

parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")
parser.add_argument('--lambda_smooth', type=float, default=0, help="loss scale for orientation")

### GUI options
parser.add_argument('--gui', action='store_true', help="start a GUI")
Expand All @@ -73,21 +76,22 @@
if opt.O:
opt.fp16 = True
opt.dir_text = True
# use occupancy grid to prune ray sampling, faster rendering.
opt.negative_dir_text = True
opt.cuda_ray = True

# opt.lambda_entropy = 1e-4
# opt.lambda_opacity = 0

elif opt.O2:
opt.fp16 = True
opt.dir_text = True
opt.negative_dir_text = True

opt.lambda_entropy = 1e-4 # necessary to keep non-empty
opt.lambda_opacity = 3e-3 # no occupancy grid, so use a stronger opacity loss.

if opt.backbone == 'vanilla':
from nerf.network import NeRFNetwork
elif opt.backbone == 'tcnn':
from nerf.network_tcnn import NeRFNetwork
elif opt.backbone == 'grid':
from nerf.network_grid import NeRFNetwork
else:
Expand Down
4 changes: 3 additions & 1 deletion nerf/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def __init__(self, device):
# self.gaussian_blur = T.GaussianBlur(15, sigma=(0.1, 10))


def get_text_embeds(self, prompt):
def get_text_embeds(self, prompt, negative_prompt):

# NOTE: negative_prompt is ignored for CLIP.

text = clip.tokenize(prompt).to(self.device)
text_z = self.clip_model.encode_text(text)
Expand Down
3 changes: 3 additions & 0 deletions nerf/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def register_dpg(self):
# text prompt
if self.opt.text is not None:
dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")

if self.opt.negative != '':
dpg.add_text("negative text: " + self.opt.negative, tag="_log_prompt_negative_text")

# button theme
with dpg.theme() as theme_button:
Expand Down
13 changes: 13 additions & 0 deletions nerf/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ def finite_difference_normal(self, x, epsilon=1e-2):

return normal

def normal(self, x):

with torch.enable_grad():
x.requires_grad_(True)
sigma, albedo = self.common_forward(x)
# query gradient
normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]

# normalize...
normal = safe_normalize(normal)
normal[torch.isnan(normal)] = 0
return normal

def forward(self, x, d, l=None, ratio=1, shading='albedo'):
# x: [N, 3], in [-bound, bound]
# d: [N, 3], view direction, nomalized in [-1, 1]
Expand Down
24 changes: 12 additions & 12 deletions nerf/network_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self,
self.num_layers = num_layers
self.hidden_dim = hidden_dim

self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, desired_resolution=2048 * self.bound)
self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, log2_hashmap_size=16, desired_resolution=2048 * self.bound)

self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)

Expand Down Expand Up @@ -103,6 +103,16 @@ def finite_difference_normal(self, x, epsilon=1e-2):
], dim=-1)

return normal


def normal(self, x):

normal = self.finite_difference_normal(x)
normal = safe_normalize(normal)
normal[torch.isnan(normal)] = 0

return normal


def forward(self, x, d, l=None, ratio=1, shading='albedo'):
# x: [N, 3], in [-bound, bound]
Expand All @@ -119,17 +129,7 @@ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
# query normal

sigma, albedo = self.common_forward(x)
normal = self.finite_difference_normal(x)

# with torch.enable_grad():
# x.requires_grad_(True)
# sigma, albedo = self.common_forward(x)
# # query gradient
# normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]

# normalize...
normal = safe_normalize(normal)
normal[torch.isnan(normal)] = 0
normal = self.normal(x)

# lambertian shading
lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]
Expand Down
174 changes: 0 additions & 174 deletions nerf/network_tcnn.py

This file was deleted.

17 changes: 13 additions & 4 deletions nerf/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,14 +399,17 @@ def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, light_d=None, a
sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading)
rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]

#print(xyzs.shape, 'valid_rgb:', mask.sum().item())
# orientation loss
if normals is not None:
# orientation loss
normals = normals.view(N, -1, 3)
# print(weights.shape, normals.shape, dirs.shape)
loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
results['loss_orient'] = loss_orient.mean()

# surface normal smoothness
normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2).view(N, -1, 3)
loss_smooth = (normals - normals_perturb).abs()
results['loss_smooth'] = loss_smooth.mean()

# calculate weight_sum (mask)
weights_sum = weights.sum(dim=-1) # [N]

Expand Down Expand Up @@ -478,12 +481,18 @@ def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0,

weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)

# orientation loss
# normals related regularizations
if normals is not None:
# orientation loss
weights = 1 - torch.exp(-sigmas)
loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
results['loss_orient'] = loss_orient.mean()

# surface normal smoothness
normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)
loss_smooth = (normals - normals_perturb).abs()
results['loss_smooth'] = loss_smooth.mean()

else:

# allocate outputs
Expand Down
Loading

0 comments on commit 5d9dde6

Please sign in to comment.