Skip to content

Commit 0f175e8

Browse files
committed
Create the viewer version of 3DCS
1 parent a72419f commit 0f175e8

File tree

2 files changed

+155
-76
lines changed

2 files changed

+155
-76
lines changed

examples/simple_viewer.py

Lines changed: 148 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pathlib import Path
1313
from gsplat._helper import load_test_data
1414
from gsplat.distributed import cli
15-
from gsplat.rendering import rasterization
15+
from gsplat.rendering import rasterization, rasterization_3dcs
1616

1717
from nerfview import CameraState, RenderTabState, apply_float_colormap
1818
from gsplat_viewer import GsplatViewer, GsplatRenderTabState
@@ -101,55 +101,81 @@ def main(local_rank: int, world_rank, world_size: int, args):
101101
)
102102
else:
103103
means, quats, scales, opacities, sh0, shN = [], [], [], [], [], []
104-
for ckpt_path in args.ckpt:
105-
ckpt = torch.load(ckpt_path, map_location=device)["splats"]
106-
means.append(ckpt["means"])
107-
quats.append(F.normalize(ckpt["quats"], p=2, dim=-1))
108-
scales.append(torch.exp(ckpt["scales"]))
109-
opacities.append(torch.sigmoid(ckpt["opacities"]))
110-
sh0.append(ckpt["sh0"])
111-
shN.append(ckpt["shN"])
112-
means = torch.cat(means, dim=0)
113-
quats = torch.cat(quats, dim=0)
114-
scales = torch.cat(scales, dim=0)
115-
opacities = torch.cat(opacities, dim=0)
116-
sh0 = torch.cat(sh0, dim=0)
117-
shN = torch.cat(shN, dim=0)
118-
colors = torch.cat([sh0, shN], dim=-2)
119-
sh_degree = int(math.sqrt(colors.shape[-2]) - 1)
120104

121-
# # crop
122-
# aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device)
123-
# edges = aabb[3:] - aabb[:3]
124-
# sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1)
125-
# sel = torch.where(sel)[0]
126-
# means, quats, scales, colors, opacities = (
127-
# means[sel],
128-
# quats[sel],
129-
# scales[sel],
130-
# colors[sel],
131-
# opacities[sel],
132-
# )
105+
convex_points, delta, sigma, num_points_per_convex, cumsum_of_points_per_convex = [], [], [], [], []
106+
if args.backend == "3dcs":
107+
for ckpt_path in args.ckpt:
108+
hyperparam = torch.load(os.path.join(ckpt_path, "hyperparameters.pt"), map_location=device, weights_only=False)
109+
pc = torch.load(os.path.join(ckpt_path, "point_cloud_state_dict.pt"), map_location=device, weights_only=False)
110+
convex_points.append(pc['convex_points'])
111+
delta.append(torch.exp(pc['delta']))
112+
sigma.append(torch.exp(pc['sigma']))
113+
opacities.append(torch.sigmoid(pc["opacity"]).squeeze())
114+
num_points_per_convex.append(torch.tensor([6]))
115+
cumsum_of_points_per_convex.append(hyperparam["cumsum_of_points_per_convex"])
116+
sh0.append(pc["features_dc"])
117+
shN.append(pc["features_rest"])
118+
convex_points = torch.cat(convex_points, dim=0)
119+
delta = torch.cat(delta, dim=0)
120+
sigma = torch.cat(sigma, dim=0)
121+
num_points_per_convex = torch.cat(num_points_per_convex, dim=0)
122+
cumsum_of_points_per_convex = torch.cat(cumsum_of_points_per_convex, dim=0)
123+
opacities = torch.cat(opacities, dim=0)
124+
sh0 = torch.cat(sh0, dim=0)
125+
shN = torch.cat(shN, dim=0)
126+
colors = torch.cat([sh0, shN], dim=-2)
127+
sh_degree = int(pc["active_sh_degree"])
128+
print("Number of 3D convexes:", convex_points.shape[0]*convex_points.shape[1])
129+
else:
130+
for ckpt_path in args.ckpt:
131+
ckpt = torch.load(ckpt_path, map_location=device)["splats"]
132+
means.append(ckpt["means"])
133+
quats.append(F.normalize(ckpt["quats"], p=2, dim=-1))
134+
scales.append(torch.exp(ckpt["scales"]))
135+
opacities.append(torch.sigmoid(ckpt["opacities"]))
136+
sh0.append(ckpt["sh0"])
137+
shN.append(ckpt["shN"])
138+
means = torch.cat(means, dim=0)
139+
quats = torch.cat(quats, dim=0)
140+
scales = torch.cat(scales, dim=0)
141+
opacities = torch.cat(opacities, dim=0)
142+
sh0 = torch.cat(sh0, dim=0)
143+
shN = torch.cat(shN, dim=0)
144+
colors = torch.cat([sh0, shN], dim=-2)
145+
sh_degree = int(math.sqrt(colors.shape[-2]) - 1)
146+
147+
# # crop
148+
# aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device)
149+
# edges = aabb[3:] - aabb[:3]
150+
# sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1)
151+
# sel = torch.where(sel)[0]
152+
# means, quats, scales, colors, opacities = (
153+
# means[sel],
154+
# quats[sel],
155+
# scales[sel],
156+
# colors[sel],
157+
# opacities[sel],
158+
# )
133159

134-
# # repeat the scene into a grid (to mimic a large-scale setting)
135-
# repeats = args.scene_grid
136-
# gridx, gridy = torch.meshgrid(
137-
# [
138-
# torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
139-
# torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
140-
# ],
141-
# indexing="ij",
142-
# )
143-
# grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(
144-
# -1, 3
145-
# )
146-
# means = means[None, :, :] + grid[:, None, :] * edges[None, None, :]
147-
# means = means.reshape(-1, 3)
148-
# quats = quats.repeat(repeats**2, 1)
149-
# scales = scales.repeat(repeats**2, 1)
150-
# colors = colors.repeat(repeats**2, 1, 1)
151-
# opacities = opacities.repeat(repeats**2)
152-
print("Number of Gaussians:", len(means))
160+
# # repeat the scene into a grid (to mimic a large-scale setting)
161+
# repeats = args.scene_grid
162+
# gridx, gridy = torch.meshgrid(
163+
# [
164+
# torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
165+
# torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
166+
# ],
167+
# indexing="ij",
168+
# )
169+
# grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(
170+
# -1, 3
171+
# )
172+
# means = means[None, :, :] + grid[:, None, :] * edges[None, None, :]
173+
# means = means.reshape(-1, 3)
174+
# quats = quats.repeat(repeats**2, 1)
175+
# scales = scales.repeat(repeats**2, 1)
176+
# colors = colors.repeat(repeats**2, 1, 1)
177+
# opacities = opacities.repeat(repeats**2)
178+
print("Number of Gaussians:", len(means))
153179

154180
# register and open viewer
155181
@torch.no_grad()
@@ -174,34 +200,74 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState
174200
"alpha": "RGB",
175201
}
176202

177-
render_colors, render_alphas, info = rasterization(
178-
means, # [N, 3]
179-
quats, # [N, 4]
180-
scales, # [N, 3]
181-
opacities, # [N]
182-
colors, # [N, S, 3]
183-
viewmat[None], # [1, 4, 4]
184-
K[None], # [1, 3, 3]
185-
width,
186-
height,
187-
sh_degree=(
188-
min(render_tab_state.max_sh_degree, sh_degree)
189-
if sh_degree is not None
190-
else None
191-
),
192-
near_plane=render_tab_state.near_plane,
193-
far_plane=render_tab_state.far_plane,
194-
radius_clip=render_tab_state.radius_clip,
195-
eps2d=render_tab_state.eps2d,
196-
backgrounds=torch.tensor([render_tab_state.backgrounds], device=device)
197-
/ 255.0,
198-
render_mode=RENDER_MODE_MAP[render_tab_state.render_mode],
199-
rasterize_mode=render_tab_state.rasterize_mode,
200-
camera_model=render_tab_state.camera_model,
201-
packed=False,
202-
)
203-
render_tab_state.total_gs_count = len(means)
204-
render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()
203+
if args.backend == "gsplat":
204+
rasterization_fn = rasterization
205+
elif args.backend == "3dcs":
206+
rasterization_fn = rasterization_3dcs
207+
elif args.backend == "inria":
208+
from gsplat import rasterization_inria_wrapper
209+
210+
rasterization_fn = rasterization_inria_wrapper
211+
else:
212+
raise ValueError
213+
214+
if args.backend == "3dcs":
215+
render_colors, render_alphas, info = rasterization_fn(
216+
convex_points,
217+
delta,
218+
sigma,
219+
num_points_per_convex,
220+
cumsum_of_points_per_convex,
221+
opacities,
222+
colors,
223+
viewmat[None], # [1, 4, 4]
224+
K[None], # [1, 3, 3]
225+
width,
226+
height,
227+
packed=False,
228+
sh_degree=(
229+
min(render_tab_state.max_sh_degree, sh_degree)
230+
if sh_degree is not None
231+
else None
232+
),
233+
near_plane=render_tab_state.near_plane,
234+
far_plane=render_tab_state.far_plane,
235+
radius_clip=render_tab_state.radius_clip,
236+
eps2d=render_tab_state.eps2d,
237+
backgrounds=torch.tensor([render_tab_state.backgrounds], device=device)
238+
/ 255.0,
239+
render_mode=RENDER_MODE_MAP[render_tab_state.render_mode],
240+
rasterize_mode=render_tab_state.rasterize_mode,
241+
camera_model=render_tab_state.camera_model,
242+
)
243+
else:
244+
render_colors, render_alphas, info = rasterization(
245+
means, # [N, 3]
246+
quats, # [N, 4]
247+
scales, # [N, 3]
248+
opacities, # [N]
249+
colors, # [N, S, 3]
250+
viewmat[None], # [1, 4, 4]
251+
K[None], # [1, 3, 3]
252+
width,
253+
height,
254+
sh_degree=(
255+
min(render_tab_state.max_sh_degree, sh_degree)
256+
if sh_degree is not None
257+
else None
258+
),
259+
near_plane=render_tab_state.near_plane,
260+
far_plane=render_tab_state.far_plane,
261+
radius_clip=render_tab_state.radius_clip,
262+
eps2d=render_tab_state.eps2d,
263+
backgrounds=torch.tensor([render_tab_state.backgrounds], device=device)
264+
/ 255.0,
265+
render_mode=RENDER_MODE_MAP[render_tab_state.render_mode],
266+
rasterize_mode=render_tab_state.rasterize_mode,
267+
camera_model=render_tab_state.camera_model,
268+
)
269+
render_tab_state.total_gs_count = len(means)
270+
render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()
205271

206272
if render_tab_state.render_mode == "rgb":
207273
# colors represented with sh are not guranteed to be in [0, 1]
@@ -267,6 +333,12 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState
267333
parser.add_argument(
268334
"--ckpt", type=str, nargs="+", default=None, help="path to the .pt file"
269335
)
336+
parser.add_argument(
337+
"--ply", type=str, nargs="+", default=None, help="path to the .ply file"
338+
)
339+
parser.add_argument(
340+
"--backend", type=str, default="gsplat", choices=["gsplat", "3dcs", "inria"], help="backend to use for rendering",
341+
)
270342
parser.add_argument(
271343
"--port", type=int, default=8080, help="port for the viewer server"
272344
)

gsplat/_helper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
import torch
66
import torch.nn.functional as F
77

8+
def load_ply_data(
9+
data_path: Optional[str] = None,
10+
device="cuda",
11+
scene_crop: Tuple[float, float, float, float, float, float] = (-2, -2, -2, 2, 2, 2),
12+
scene_grid: int = 1,
13+
):
14+
assert True
815

916
def load_test_data(
1017
data_path: Optional[str] = None,

0 commit comments

Comments
 (0)