Skip to content

Commit b079f8d

Browse files
author
booth-algo
committed
Checkpoint commit, minor QoL updates
1 parent 3d03993 commit b079f8d

File tree

5 files changed

+30
-10
lines changed

5 files changed

+30
-10
lines changed

Diff for: .gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
**/__pycache__/**
44
gsplat
55
vis
6-
B075X65R3X/
6+
training-data/

Diff for: B075X65R3X.zip

-5.22 MB
Binary file not shown.

Diff for: gaussian_splatting/gauss_render.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ def __init__(self, active_sh_degree=3, white_bkgd=True, **kwargs):
175175
self.active_sh_degree = active_sh_degree
176176
self.debug = False
177177
self.white_bkgd = white_bkgd
178-
self.pix_coord = torch.stack(torch.meshgrid(torch.arange(256), torch.arange(256), indexing='xy'), dim=-1).to('cuda')
178+
# self.pix_coord = torch.stack(torch.meshgrid(torch.arange(256), torch.arange(256), indexing='xy'), dim=-1).to('cuda')
179+
self.pix_coord = None
179180

180181

181182
def build_color(self, means3D, shs, camera):
@@ -239,6 +240,9 @@ def render(self, camera, means2D, cov2d, color, opacity, depths):
239240

240241

241242
def forward(self, pc_output, camera, **kwargs):
243+
if self.pix_coord is None or self.pix_coord.shape[:2] != (camera.image_height, camera.image_width):
244+
self.pix_coord = torch.stack(torch.meshgrid(torch.arange(camera.image_height), torch.arange(camera.image_width), indexing='xy'), dim=-1).to('cuda')
245+
242246
means3D = pc_output['xyz']
243247
opacity = pc_output['opacity']
244248
scales = pc_output['scaling']

Diff for: gaussian_splatting/utils/data_utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def read_all(folder, resize_factor=1.):
4747
src_rgb , src_depth, src_alpha, src_camera = \
4848
read_image(src_rgb_file, src_pose,
4949
intrinsic, max_depth=max_depth, resize_factor=resize_factor)
50+
51+
# Extract focal lengths
52+
focal_x, focal_y = intrinsic[0, 0], intrinsic[1, 1]
5053

5154
src_rgbs.append(src_rgb)
5255
src_depths.append(src_depth)

Diff for: train.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,15 @@ def __init__(self, **kwargs):
4141

4242
def on_train_step(self):
4343
ind = np.random.choice(len(self.data['camera']))
44-
camera = self.data['camera'][ind]
44+
camera_params = self.data['camera'][ind]
45+
camera = to_viewpoint_camera(camera_params)
46+
# camera = self.data['camera'][ind]
4547
rgb = self.data['rgb'][ind]
4648
depth = self.data['depth'][ind]
4749
mask = (self.data['alpha'][ind] > 0.5)
48-
if USE_GPU_PYTORCH:
49-
camera = to_viewpoint_camera(camera)
50+
51+
# if USE_GPU_PYTORCH:
52+
# camera = to_viewpoint_camera(camera)
5053

5154
if USE_PROFILE:
5255
prof = profile(activities=[ProfilerActivity.CUDA], with_stack=True)
@@ -101,9 +104,11 @@ def log_psnr_stats(self):
101104
def on_evaluate_step(self, **kwargs):
102105
import matplotlib.pyplot as plt
103106
ind = np.random.choice(len(self.data['camera']))
104-
camera = self.data['camera'][ind]
105-
if USE_GPU_PYTORCH:
106-
camera = to_viewpoint_camera(camera)
107+
# camera = self.data['camera'][ind]
108+
# if USE_GPU_PYTORCH:
109+
# camera = to_viewpoint_camera(camera)
110+
111+
camera = to_viewpoint_camera(self.data['camera'][ind])
107112

108113
rgb = self.data['rgb'][ind].detach().cpu().numpy()
109114

@@ -114,9 +119,17 @@ def on_evaluate_step(self, **kwargs):
114119
rgb_pd = out['render'].detach().cpu().numpy()
115120
depth_pd = out['depth'].detach().cpu().numpy()[..., 0]
116121
depth = self.data['depth'][ind].detach().cpu().numpy()
122+
123+
if depth.shape != depth_pd.shape:
124+
depth = np.resize(depth, depth_pd.shape)
125+
117126
depth = np.concatenate([depth, depth_pd], axis=1)
118127
depth = (1 - depth / depth.max())
119128
depth = plt.get_cmap('jet')(depth)[..., :3]
129+
130+
if rgb.shape != rgb_pd.shape:
131+
rgb = np.resize(rgb, rgb_pd.shape)
132+
120133
image = np.concatenate([rgb, rgb_pd], axis=1)
121134
image = np.concatenate([image, depth], axis=0)
122135
utils.imwrite(str(self.results_folder / f'image-{self.step}.png'), image)
@@ -144,7 +157,7 @@ def get_test_folder(base_folder='result', prefix='test'):
144157

145158
if __name__ == "__main__":
146159
device = 'cuda'
147-
folder = './B075X65R3X'
160+
folder = './training-data/B075X65R3X'
148161
data = read_all(folder, resize_factor=0.5)
149162
data = {k: v.to(device) for k, v in data.items()}
150163
data['depth_range'] = torch.Tensor([[1,3]]*len(data['rgb'])).to(device)
@@ -197,7 +210,7 @@ def get_test_folder(base_folder='result', prefix='test'):
197210
# model=GaussModel,
198211
data=data,
199212
train_batch_size=1,
200-
train_num_steps=30,
213+
train_num_steps=1000,
201214
i_image =100,
202215
train_lr=1e-3,
203216
amp=False,

0 commit comments

Comments
 (0)