@@ -159,6 +159,7 @@ def render_gaussians(
159159 depths ,
160160 radii ,
161161 compute_locally ,
162+ extended_compute_locally ,
162163 raster_settings ,
163164 cuda_args ,
164165):
@@ -169,46 +170,11 @@ def render_gaussians(
169170 depths ,
170171 radii ,
171172 compute_locally ,
173+ extended_compute_locally ,
172174 raster_settings ,
173175 cuda_args ,
174176 )
175177
176- def get_extended_compute_locally (cuda_args , image_height , image_width ):
177- if isinstance (cuda_args ["dist_global_strategy" ], str ):
178- mp_rank = int (cuda_args ["mp_rank" ])
179- dist_global_strategy = [int (x ) for x in cuda_args ["dist_global_strategy" ].split ("," )]
180-
181- num_tile_y = (image_height + 16 - 1 ) // 16 #TODO: this is dangerous because 16 may change.
182- num_tile_x = (image_width + 16 - 1 ) // 16
183- tile_l = max (dist_global_strategy [mp_rank ]- num_tile_x - 1 , 0 )
184- tile_r = min (dist_global_strategy [mp_rank + 1 ]+ num_tile_x + 1 , num_tile_y * num_tile_x )
185-
186- extended_compute_locally = torch .zeros (num_tile_y * num_tile_x , dtype = torch .bool , device = "cuda" )
187- extended_compute_locally [tile_l :tile_r ] = True
188- extended_compute_locally = extended_compute_locally .view (num_tile_y , num_tile_x )
189-
190- return extended_compute_locally
191- else :
192- division_pos = cuda_args ["dist_global_strategy" ]
193- division_pos_xs , division_pos_ys = division_pos
194- mp_rank = int (cuda_args ["mp_rank" ])
195- grid_size_x = len (division_pos_xs ) - 1
196- grid_size_y = len (division_pos_ys [0 ]) - 1
197- y_rank = mp_rank // grid_size_x
198- x_rank = mp_rank % grid_size_x
199-
200- local_tile_x_l , local_tile_x_r = division_pos_xs [x_rank ], division_pos_xs [x_rank + 1 ]
201- local_tile_y_l , local_tile_y_r = division_pos_ys [x_rank ][y_rank ], division_pos_ys [x_rank ][y_rank + 1 ]
202-
203- num_tile_y = (image_height + 16 - 1 ) // 16
204- num_tile_x = (image_width + 16 - 1 ) // 16
205-
206- extended_compute_locally = torch .zeros ((num_tile_y , num_tile_x ), dtype = torch .bool , device = "cuda" )
207- extended_compute_locally [max (local_tile_y_l - 1 ,0 ):min (local_tile_y_r + 1 ,num_tile_y ),
208- max (local_tile_x_l - 1 ,0 ):min (local_tile_x_r + 1 ,num_tile_x )] = True
209-
210- return extended_compute_locally
211-
212178class _RenderGaussians (torch .autograd .Function ):
213179 @staticmethod
214180 def forward (
@@ -219,6 +185,7 @@ def forward(
219185 depths ,
220186 radii ,
221187 compute_locally ,
188+ extended_compute_locally ,
222189 raster_settings ,
223190 cuda_args ,
224191 ):
@@ -231,10 +198,6 @@ def forward(
231198 # Basically, means2D is (P, 3) in python. But it is (P, 2) in cuda code.
232199 # dL_dmeans2D is always (P, 3) in both python and cuda code.
233200
234- extended_compute_locally = get_extended_compute_locally (cuda_args ,
235- raster_settings .image_height ,
236- raster_settings .image_width )
237-
238201 # Restructure arguments the way that the C++ lib expects them
239202 args = (
240203 raster_settings .bg ,
@@ -312,6 +275,7 @@ def backward(ctx, grad_color, grad_n_render, grad_n_consider, grad_n_contrib):
312275 None ,
313276 None ,
314277 None ,
278+ None ,
315279 None # this is for cuda_args
316280 )
317281
@@ -370,7 +334,7 @@ def preprocess_gaussians(self, means3D, scales, rotations, shs, opacities, cuda_
370334 raster_settings ,
371335 cuda_args )
372336
373- def render_gaussians (self , means2D , conic_opacity , rgb , depths , radii , compute_locally , cuda_args = None ):
337+ def render_gaussians (self , means2D , conic_opacity , rgb , depths , radii , compute_locally , extended_compute_locally , cuda_args = None ):
374338
375339 raster_settings = self .raster_settings
376340
@@ -382,11 +346,13 @@ def render_gaussians(self, means2D, conic_opacity, rgb, depths, radii, compute_l
382346 depths ,
383347 radii ,
384348 compute_locally ,
349+ extended_compute_locally ,
385350 raster_settings ,
386351 cuda_args
387352 )
388353
389354 def get_local2j_ids (self , means2D , radii , cuda_args ):
355+ # For each 3dgs, calculate the set of GPUs that will use this 3dgs for rendering.
390356
391357 if isinstance (cuda_args ["dist_global_strategy" ], str ):
392358 raster_settings = self .raster_settings
@@ -448,21 +414,6 @@ def get_local2j_ids(self, means2D, radii, cuda_args):
448414 return local2j_ids , local2j_ids_bool
449415
450416
451- def get_distribution_strategy (self , means2D , radii , cuda_args ):
452-
453- assert False , "This function is not used in the current version."
454-
455- raster_settings = self .raster_settings
456-
457- return _C .get_distribution_strategy (
458- raster_settings .image_height ,
459- raster_settings .image_width ,
460- means2D ,
461- radii ,
462- raster_settings .debug ,
463- cuda_args
464- )# the return is compute_locally
465-
466417class _LoadImageTilesByPos (torch .autograd .Function ):
467418
468419 @staticmethod
0 commit comments