improve non-cuda-ray mode
ashawkey committed Oct 9, 2022
1 parent fa5ca19 commit 3de5f93
Showing 3 changed files with 27 additions and 18 deletions.
main.py: 15 changes (11 additions, 4 deletions)
@@ -28,8 +28,8 @@
parser.add_argument('--ckpt', type=str, default='latest')
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=128, help="num steps sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=64, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
parser.add_argument('--albedo_iters', type=int, default=15000, help="training iters that only use albedo shading")
@@ -40,8 +40,8 @@
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
# rendering resolution in training, decrease this if CUDA OOM.
parser.add_argument('--w', type=int, default=128, help="render width for NeRF in training")
parser.add_argument('--h', type=int, default=128, help="render height for NeRF in training")
parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")

### dataset options
@@ -55,6 +55,7 @@
parser.add_argument('--angle_front', type=float, default=30, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")

parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")

### GUI options
@@ -72,10 +73,16 @@
if opt.O:
opt.fp16 = True
opt.dir_text = True
# use occupancy grid to prune ray sampling, faster rendering.
opt.cuda_ray = True
opt.lambda_entropy = 1e-4
opt.lambda_opacity = 0

elif opt.O2:
opt.fp16 = True
opt.dir_text = True
opt.lambda_entropy = 1e-3
opt.lambda_opacity = 1e-3 # no occupancy grid, so use a stronger opacity loss.

if opt.backbone == 'vanilla':
from nerf.network import NeRFNetwork
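The main.py changes keep the non-cuda-ray sample budget at 128 per ray but split it: 64 uniform coarse steps (--num_steps) plus 64 importance-sampled fine steps (--upsample_steps), and they wire the new --lambda_opacity weight into the -O / -O2 presets. The sketch below illustrates the coarse-to-fine idea that --upsample_steps refers to; it is a simplified inverse-CDF resampler written for this note (the helper name and the omission of in-bin interpolation are my own), not the repository's implementation.

import torch

def upsample_z_vals(z_vals, weights, n_upsample):
    """Hedged sketch: draw extra depth samples where the coarse pass put weight.

    z_vals:  (N_rays, num_steps) sorted coarse sample depths
    weights: (N_rays, num_steps) rendering weights from the coarse pass
    """
    pdf = weights + 1e-5                      # avoid zero-probability bins
    pdf = pdf / pdf.sum(dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, dim=-1)           # monotone CDF per ray
    u = torch.rand(z_vals.shape[0], n_upsample, device=z_vals.device)
    idx = torch.searchsorted(cdf, u, right=True).clamp(max=cdf.shape[-1] - 1)
    new_z = torch.gather(z_vals, -1, idx)     # pick depths of the sampled bins
    # merge and re-sort so the renderer sees monotonically increasing depths
    return torch.sort(torch.cat([z_vals, new_z], dim=-1), dim=-1).values

With the new defaults (64 + 64), half of the per-ray budget ends up concentrated near surfaces, which is presumably why the uniform step count could drop from 128 to 64.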
nerf/sd.py: 2 changes (1 addition, 1 deletion)
@@ -20,7 +20,7 @@ def __init__(self, device):
print(f'[INFO] loaded hugging face access token from ./TOKEN!')
except FileNotFoundError as e:
self.token = True
print(f'[INFO] try to load hugging face access token from the default plase, make sure you have run `huggingface-cli login`.')
print(f'[INFO] try to load hugging face access token from the default place, make sure you have run `huggingface-cli login`.')

self.device = device
self.num_train_timesteps = 1000
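The nerf/sd.py change only fixes the wording of the fallback log message. For context, here is a minimal sketch of the token fallback that message describes; paths and variable names are illustrative, not the file's exact code.

# Minimal sketch of the fallback the log message refers to (illustrative only).
try:
    with open('./TOKEN', 'r') as f:
        token = f.read().strip()              # explicit token stored in ./TOKEN
    print('[INFO] loaded hugging face access token from ./TOKEN!')
except FileNotFoundError:
    # passing True lets the hub client reuse the credentials cached by
    # `huggingface-cli login`
    token = True
    print('[INFO] trying to load hugging face access token from the default place, '
          'make sure you have run `huggingface-cli login`.')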
nerf/utils.py: 28 changes (15 additions, 13 deletions)
@@ -330,11 +330,11 @@ def train_step(self, data):
if rand > 0.8:
shading = 'albedo'
ambient_ratio = 1.0
elif rand > 0.4:
shading = 'lambertian'
ambient_ratio = 0.1
# elif rand > 0.4:
# shading = 'textureless'
# ambient_ratio = 0.1
else:
shading = 'textureless'
shading = 'lambertian'
ambient_ratio = 0.1

# _t = time.time()
@@ -355,22 +355,24 @@

# encode pred_rgb to latents
# _t = time.time()
loss_guidance = self.guidance.train_step(text_z, pred_rgb)
loss = self.guidance.train_step(text_z, pred_rgb)
# torch.cuda.synchronize(); print(f'[TIME] total guiding {time.time() - _t:.4f}s')

# occupancy loss
pred_ws = outputs['weights_sum'].reshape(B, 1, H, W)
# mask_ws = outputs['mask'].reshape(B, 1, H, W) # near < far

# loss_ws = (pred_ws ** 2 + 0.01).sqrt().mean()
if self.opt.lambda_opacity > 0:
loss_opacity = (pred_ws ** 2).mean()
loss = loss + self.opt.lambda_opacity * loss_opacity

alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
# alphas = alphas ** 2 # skewed entropy, favors 0 over 1
loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()

loss = loss_guidance + self.opt.lambda_entropy * loss_entropy
if self.opt.lambda_entropy > 0:
alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
# alphas = alphas ** 2 # skewed entropy, favors 0 over 1
loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()

loss = loss + self.opt.lambda_entropy * loss_entropy

if 'loss_orient' in outputs:
if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
loss_orient = outputs['loss_orient']
loss = loss + self.opt.lambda_orient * loss_orient

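In nerf/utils.py the guidance loss becomes the base loss, the 'textureless' shading branch is commented out in favour of lambertian, and both regularizers are now gated on their weights: a squared-opacity penalty (new, weighted by --lambda_opacity) and the existing binary-entropy term (--lambda_entropy). Below is a compact sketch of those two regularizers, written for this note rather than copied from the file.

import torch

def opacity_and_entropy_reg(weights_sum, lambda_opacity, lambda_entropy):
    """Hedged sketch of the gated regularizers above (not the file's exact code).

    weights_sum: accumulated per-pixel alpha in [0, 1], shape (B, 1, H, W).
    """
    loss = weights_sum.new_zeros(())
    if lambda_opacity > 0:
        # push accumulated opacity toward zero: discourages filling empty space
        loss = loss + lambda_opacity * (weights_sum ** 2).mean()
    if lambda_entropy > 0:
        # binary entropy of alpha: each pixel should be clearly empty or clearly solid
        alphas = weights_sum.clamp(1e-5, 1 - 1e-5)
        entropy = - alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)
        loss = loss + lambda_entropy * entropy.mean()
    return loss

This lines up with the new -O2 preset in main.py, which enables the opacity term (and a stronger entropy weight) precisely because the pure-PyTorch renderer has no occupancy grid to keep empty space empty.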
