@@ -31,7 +31,7 @@ def preprocess_gaussians(
3131 sh ,
3232 opacities ,
3333 raster_settings ,
34- cuda_args ,
34+ cuda_args ,flag_batched = False
3535):
3636 return _PreprocessGaussians .apply (
3737 means3D ,
@@ -40,7 +40,7 @@ def preprocess_gaussians(
4040 sh ,
4141 opacities ,
4242 raster_settings ,
43- cuda_args ,
43+ cuda_args ,flag_batched
4444 )
4545
4646class _PreprocessGaussians (torch .autograd .Function ):
@@ -52,45 +52,88 @@ def forward(
5252 rotations ,
5353 sh ,
5454 opacities ,
55- raster_settings ,
56- cuda_args ,
55+ raster_settings_list ,
56+ batched_cuda_args , flag_batched
5757 ):
5858
5959 # Restructure arguments the way that the C++ lib expects them
60- args = (
61- means3D ,
62- scales ,
63- rotations ,
64- sh ,
65- opacities ,# 3dgs' parametes.
66- raster_settings .scale_modifier ,
67- raster_settings .viewmatrix ,
68- raster_settings .projmatrix ,
69- raster_settings .tanfovx ,
70- raster_settings .tanfovy ,
71- raster_settings .image_height ,
72- raster_settings .image_width ,
73- raster_settings .sh_degree ,
74- raster_settings .campos ,
75- raster_settings .prefiltered ,
76- raster_settings .debug ,#raster_settings
77- cuda_args
78- )
79-
80- # TODO: update this.
81- num_rendered , means2D , depths , radii , cov3D , conic_opacity , rgb , clamped = _C .preprocess_gaussians (* args )
60+ if flag_batched == False :
61+ args = (
62+ means3D ,
63+ scales ,
64+ rotations ,
65+ sh ,
66+ opacities ,# 3dgs' parametes.
67+ raster_settings .scale_modifier ,
68+ raster_settings .viewmatrix ,
69+ raster_settings .projmatrix ,
70+ raster_settings .tanfovx ,
71+ raster_settings .tanfovy ,
72+ raster_settings .image_height ,
73+ raster_settings .image_width ,
74+ raster_settings .sh_degree ,
75+ raster_settings .campos ,
76+ raster_settings .prefiltered ,
77+ raster_settings .debug ,#raster_settings
78+ cuda_args
79+ )
80+
81+ # TODO: update this.
82+ num_rendered , means2D , depths , radii , cov3D , conic_opacity , rgb , clamped = _C .preprocess_gaussians (* args )
83+
84+ # Keep relevant tensors for backward
85+ ctx .raster_settings = raster_settings
86+ ctx .cuda_args = cuda_args
87+ ctx .num_rendered = num_rendered
88+ ctx .save_for_backward (means3D , scales , rotations , sh , means2D , depths , radii , cov3D , conic_opacity , rgb , clamped )
89+ ctx .mark_non_differentiable (radii , depths )
90+
91+ # # TODO: double check. means2D is padded to (P, 3) in python. It is (P, 2) in cuda code.
92+ # means2D_pad = torch.zeros((means2D.shape[0], 1), dtype = means2D.dtype, device = means2D.device)
93+ # means2D = torch.cat((means2D, means2D_pad), dim = 1).contiguous()
94+ return means2D , rgb , conic_opacity , radii , depths
95+
96+ else :
97+ args_list = []
98+ for raster_settings ,cuda_args in zip (raster_settings_list ,batched_cuda_args ):
99+
100+ args = (
101+ means3D ,
102+ scales ,
103+ rotations ,
104+ sh ,
105+ opacities ,# 3dgs' parametes.
106+ raster_settings .scale_modifier ,
107+ raster_settings .viewmatrix ,
108+ raster_settings .projmatrix ,
109+ raster_settings .tanfovx ,
110+ raster_settings .tanfovy ,
111+ raster_settings .image_height ,
112+ raster_settings .image_width ,
113+ raster_settings .sh_degree ,
114+ raster_settings .campos ,
115+ raster_settings .prefiltered ,
116+ raster_settings .debug ,#raster_settings
117+ cuda_args
118+ )
119+ args_list .append (args )
120+
121+ # TODO: update this.
122+ num_rendered , means2D , depths , radii , cov3D , conic_opacity , rgb , clamped = _C .preprocess_gaussians_batches (* args_list )
123+
124+ # Keep relevant tensors for backward
125+ ctx .raster_settings = raster_settings_list
126+ ctx .cuda_args = batched_cuda_args
127+ ctx .num_rendered = num_rendered
128+ ctx .save_for_backward (means3D , scales , rotations , sh , means2D , depths , radii , cov3D , conic_opacity , rgb , clamped )
129+ ctx .mark_non_differentiable (radii , depths )
130+
131+ # # TODO: double check. means2D is padded to (P, 3) in python. It is (P, 2) in cuda code.
132+ # means2D_pad = torch.zeros((means2D.shape[0], 1), dtype = means2D.dtype, device = means2D.device)
133+ # means2D = torch.cat((means2D, means2D_pad), dim = 1).contiguous()
134+ return means2D , rgb , conic_opacity , radii , depths
82135
83- # Keep relevant tensors for backward
84- ctx .raster_settings = raster_settings
85- ctx .cuda_args = cuda_args
86- ctx .num_rendered = num_rendered
87- ctx .save_for_backward (means3D , scales , rotations , sh , means2D , depths , radii , cov3D , conic_opacity , rgb , clamped )
88- ctx .mark_non_differentiable (radii , depths )
89136
90- # # TODO: double check. means2D is padded to (P, 3) in python. It is (P, 2) in cuda code.
91- # means2D_pad = torch.zeros((means2D.shape[0], 1), dtype = means2D.dtype, device = means2D.device)
92- # means2D = torch.cat((means2D, means2D_pad), dim = 1).contiguous()
93- return means2D , rgb , conic_opacity , radii , depths
94137
95138 @staticmethod # TODO: gradient for conic_opacity is tricky. because cuda render backward generate dL_dconic and dL_dopacity sperately.
96139 def backward (ctx , grad_means2D , grad_rgb , grad_conic_opacity , grad_radii , grad_depths ):
@@ -320,14 +363,14 @@ def markVisible(self, positions):
320363 def preprocess_gaussians (self , means3D , scales , rotations , shs , opacities , batched_cuda_args = None ):
321364 # Invoke C++/CUDA rasterization routine
322365
323- return preprocess_gaussians_batches (
366+ return preprocess_gaussians (
324367 means3D ,
325368 scales ,
326369 rotations ,
327370 shs ,
328371 opacities ,
329372 self .raster_settings_list ,
330- batched_cuda_args )
373+ batched_cuda_args , True )
331374
332375class GaussianRasterizer (nn .Module ):
333376 def __init__ (self , raster_settings ):
0 commit comments