12
12
from pathlib import Path
13
13
from gsplat ._helper import load_test_data
14
14
from gsplat .distributed import cli
15
- from gsplat .rendering import rasterization
15
+ from gsplat .rendering import rasterization , rasterization_3dcs
16
16
17
17
from nerfview import CameraState , RenderTabState , apply_float_colormap
18
18
from gsplat_viewer import GsplatViewer , GsplatRenderTabState
@@ -101,55 +101,81 @@ def main(local_rank: int, world_rank, world_size: int, args):
101
101
)
102
102
else :
103
103
means , quats , scales , opacities , sh0 , shN = [], [], [], [], [], []
104
- for ckpt_path in args .ckpt :
105
- ckpt = torch .load (ckpt_path , map_location = device )["splats" ]
106
- means .append (ckpt ["means" ])
107
- quats .append (F .normalize (ckpt ["quats" ], p = 2 , dim = - 1 ))
108
- scales .append (torch .exp (ckpt ["scales" ]))
109
- opacities .append (torch .sigmoid (ckpt ["opacities" ]))
110
- sh0 .append (ckpt ["sh0" ])
111
- shN .append (ckpt ["shN" ])
112
- means = torch .cat (means , dim = 0 )
113
- quats = torch .cat (quats , dim = 0 )
114
- scales = torch .cat (scales , dim = 0 )
115
- opacities = torch .cat (opacities , dim = 0 )
116
- sh0 = torch .cat (sh0 , dim = 0 )
117
- shN = torch .cat (shN , dim = 0 )
118
- colors = torch .cat ([sh0 , shN ], dim = - 2 )
119
- sh_degree = int (math .sqrt (colors .shape [- 2 ]) - 1 )
120
104
121
- # # crop
122
- # aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device)
123
- # edges = aabb[3:] - aabb[:3]
124
- # sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1)
125
- # sel = torch.where(sel)[0]
126
- # means, quats, scales, colors, opacities = (
127
- # means[sel],
128
- # quats[sel],
129
- # scales[sel],
130
- # colors[sel],
131
- # opacities[sel],
132
- # )
105
+ convex_points , delta , sigma , num_points_per_convex , cumsum_of_points_per_convex = [], [], [], [], []
106
+ if args .backend == "3dcs" :
107
+ for ckpt_path in args .ckpt :
108
+ hyperparam = torch .load (os .path .join (ckpt_path , "hyperparameters.pt" ), map_location = device , weights_only = False )
109
+ pc = torch .load (os .path .join (ckpt_path , "point_cloud_state_dict.pt" ), map_location = device , weights_only = False )
110
+ convex_points .append (pc ['convex_points' ])
111
+ delta .append (torch .exp (pc ['delta' ]))
112
+ sigma .append (torch .exp (pc ['sigma' ]))
113
+ opacities .append (torch .sigmoid (pc ["opacity" ]).squeeze ())
114
+ num_points_per_convex .append (torch .tensor ([6 ]))
115
+ cumsum_of_points_per_convex .append (hyperparam ["cumsum_of_points_per_convex" ])
116
+ sh0 .append (pc ["features_dc" ])
117
+ shN .append (pc ["features_rest" ])
118
+ convex_points = torch .cat (convex_points , dim = 0 )
119
+ delta = torch .cat (delta , dim = 0 )
120
+ sigma = torch .cat (sigma , dim = 0 )
121
+ num_points_per_convex = torch .cat (num_points_per_convex , dim = 0 )
122
+ cumsum_of_points_per_convex = torch .cat (cumsum_of_points_per_convex , dim = 0 )
123
+ opacities = torch .cat (opacities , dim = 0 )
124
+ sh0 = torch .cat (sh0 , dim = 0 )
125
+ shN = torch .cat (shN , dim = 0 )
126
+ colors = torch .cat ([sh0 , shN ], dim = - 2 )
127
+ sh_degree = int (pc ["active_sh_degree" ])
128
+ print ("Number of 3D convexes:" , convex_points .shape [0 ]* convex_points .shape [1 ])
129
+ else :
130
+ for ckpt_path in args .ckpt :
131
+ ckpt = torch .load (ckpt_path , map_location = device )["splats" ]
132
+ means .append (ckpt ["means" ])
133
+ quats .append (F .normalize (ckpt ["quats" ], p = 2 , dim = - 1 ))
134
+ scales .append (torch .exp (ckpt ["scales" ]))
135
+ opacities .append (torch .sigmoid (ckpt ["opacities" ]))
136
+ sh0 .append (ckpt ["sh0" ])
137
+ shN .append (ckpt ["shN" ])
138
+ means = torch .cat (means , dim = 0 )
139
+ quats = torch .cat (quats , dim = 0 )
140
+ scales = torch .cat (scales , dim = 0 )
141
+ opacities = torch .cat (opacities , dim = 0 )
142
+ sh0 = torch .cat (sh0 , dim = 0 )
143
+ shN = torch .cat (shN , dim = 0 )
144
+ colors = torch .cat ([sh0 , shN ], dim = - 2 )
145
+ sh_degree = int (math .sqrt (colors .shape [- 2 ]) - 1 )
146
+
147
+ # # crop
148
+ # aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device)
149
+ # edges = aabb[3:] - aabb[:3]
150
+ # sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1)
151
+ # sel = torch.where(sel)[0]
152
+ # means, quats, scales, colors, opacities = (
153
+ # means[sel],
154
+ # quats[sel],
155
+ # scales[sel],
156
+ # colors[sel],
157
+ # opacities[sel],
158
+ # )
133
159
134
- # # repeat the scene into a grid (to mimic a large-scale setting)
135
- # repeats = args.scene_grid
136
- # gridx, gridy = torch.meshgrid(
137
- # [
138
- # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
139
- # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
140
- # ],
141
- # indexing="ij",
142
- # )
143
- # grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(
144
- # -1, 3
145
- # )
146
- # means = means[None, :, :] + grid[:, None, :] * edges[None, None, :]
147
- # means = means.reshape(-1, 3)
148
- # quats = quats.repeat(repeats**2, 1)
149
- # scales = scales.repeat(repeats**2, 1)
150
- # colors = colors.repeat(repeats**2, 1, 1)
151
- # opacities = opacities.repeat(repeats**2)
152
- print ("Number of Gaussians:" , len (means ))
160
+ # # repeat the scene into a grid (to mimic a large-scale setting)
161
+ # repeats = args.scene_grid
162
+ # gridx, gridy = torch.meshgrid(
163
+ # [
164
+ # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
165
+ # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
166
+ # ],
167
+ # indexing="ij",
168
+ # )
169
+ # grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(
170
+ # -1, 3
171
+ # )
172
+ # means = means[None, :, :] + grid[:, None, :] * edges[None, None, :]
173
+ # means = means.reshape(-1, 3)
174
+ # quats = quats.repeat(repeats**2, 1)
175
+ # scales = scales.repeat(repeats**2, 1)
176
+ # colors = colors.repeat(repeats**2, 1, 1)
177
+ # opacities = opacities.repeat(repeats**2)
178
+ print ("Number of Gaussians:" , len (means ))
153
179
154
180
# register and open viewer
155
181
@torch .no_grad ()
@@ -174,34 +200,74 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState
174
200
"alpha" : "RGB" ,
175
201
}
176
202
177
- render_colors , render_alphas , info = rasterization (
178
- means , # [N, 3]
179
- quats , # [N, 4]
180
- scales , # [N, 3]
181
- opacities , # [N]
182
- colors , # [N, S, 3]
183
- viewmat [None ], # [1, 4, 4]
184
- K [None ], # [1, 3, 3]
185
- width ,
186
- height ,
187
- sh_degree = (
188
- min (render_tab_state .max_sh_degree , sh_degree )
189
- if sh_degree is not None
190
- else None
191
- ),
192
- near_plane = render_tab_state .near_plane ,
193
- far_plane = render_tab_state .far_plane ,
194
- radius_clip = render_tab_state .radius_clip ,
195
- eps2d = render_tab_state .eps2d ,
196
- backgrounds = torch .tensor ([render_tab_state .backgrounds ], device = device )
197
- / 255.0 ,
198
- render_mode = RENDER_MODE_MAP [render_tab_state .render_mode ],
199
- rasterize_mode = render_tab_state .rasterize_mode ,
200
- camera_model = render_tab_state .camera_model ,
201
- packed = False ,
202
- )
203
- render_tab_state .total_gs_count = len (means )
204
- render_tab_state .rendered_gs_count = (info ["radii" ] > 0 ).all (- 1 ).sum ().item ()
203
+ if args .backend == "gsplat" :
204
+ rasterization_fn = rasterization
205
+ elif args .backend == "3dcs" :
206
+ rasterization_fn = rasterization_3dcs
207
+ elif args .backend == "inria" :
208
+ from gsplat import rasterization_inria_wrapper
209
+
210
+ rasterization_fn = rasterization_inria_wrapper
211
+ else :
212
+ raise ValueError
213
+
214
+ if args .backend == "3dcs" :
215
+ render_colors , render_alphas , info = rasterization_fn (
216
+ convex_points ,
217
+ delta ,
218
+ sigma ,
219
+ num_points_per_convex ,
220
+ cumsum_of_points_per_convex ,
221
+ opacities ,
222
+ colors ,
223
+ viewmat [None ], # [1, 4, 4]
224
+ K [None ], # [1, 3, 3]
225
+ width ,
226
+ height ,
227
+ packed = False ,
228
+ sh_degree = (
229
+ min (render_tab_state .max_sh_degree , sh_degree )
230
+ if sh_degree is not None
231
+ else None
232
+ ),
233
+ near_plane = render_tab_state .near_plane ,
234
+ far_plane = render_tab_state .far_plane ,
235
+ radius_clip = render_tab_state .radius_clip ,
236
+ eps2d = render_tab_state .eps2d ,
237
+ backgrounds = torch .tensor ([render_tab_state .backgrounds ], device = device )
238
+ / 255.0 ,
239
+ render_mode = RENDER_MODE_MAP [render_tab_state .render_mode ],
240
+ rasterize_mode = render_tab_state .rasterize_mode ,
241
+ camera_model = render_tab_state .camera_model ,
242
+ )
243
+ else :
244
+ render_colors , render_alphas , info = rasterization (
245
+ means , # [N, 3]
246
+ quats , # [N, 4]
247
+ scales , # [N, 3]
248
+ opacities , # [N]
249
+ colors , # [N, S, 3]
250
+ viewmat [None ], # [1, 4, 4]
251
+ K [None ], # [1, 3, 3]
252
+ width ,
253
+ height ,
254
+ sh_degree = (
255
+ min (render_tab_state .max_sh_degree , sh_degree )
256
+ if sh_degree is not None
257
+ else None
258
+ ),
259
+ near_plane = render_tab_state .near_plane ,
260
+ far_plane = render_tab_state .far_plane ,
261
+ radius_clip = render_tab_state .radius_clip ,
262
+ eps2d = render_tab_state .eps2d ,
263
+ backgrounds = torch .tensor ([render_tab_state .backgrounds ], device = device )
264
+ / 255.0 ,
265
+ render_mode = RENDER_MODE_MAP [render_tab_state .render_mode ],
266
+ rasterize_mode = render_tab_state .rasterize_mode ,
267
+ camera_model = render_tab_state .camera_model ,
268
+ )
269
+ render_tab_state .total_gs_count = len (means )
270
+ render_tab_state .rendered_gs_count = (info ["radii" ] > 0 ).all (- 1 ).sum ().item ()
205
271
206
272
if render_tab_state .render_mode == "rgb" :
207
273
# colors represented with sh are not guranteed to be in [0, 1]
@@ -267,6 +333,12 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState
267
333
parser .add_argument (
268
334
"--ckpt" , type = str , nargs = "+" , default = None , help = "path to the .pt file"
269
335
)
336
+ parser .add_argument (
337
+ "--ply" , type = str , nargs = "+" , default = None , help = "path to the .ply file"
338
+ )
339
+ parser .add_argument (
340
+ "--backend" , type = str , default = "gsplat" , choices = ["gsplat" , "3dcs" , "inria" ], help = "backend to use for rendering" ,
341
+ )
270
342
parser .add_argument (
271
343
"--port" , type = int , default = 8080 , help = "port for the viewer server"
272
344
)
0 commit comments