@@ -351,69 +351,6 @@ def render_gaussians(self, means2D, conic_opacity, rgb, depths, radii, compute_l
351351 cuda_args
352352 )
353353
354- def get_local2j_ids (self , means2D , radii , cuda_args ):
355- # For each 3dgs, calculate the set of GPUs that will use this 3dgs for rendering.
356-
357- if isinstance (cuda_args ["dist_global_strategy" ], str ):
358- raster_settings = self .raster_settings
359- mp_world_size = int (cuda_args ["mp_world_size" ])
360- mp_rank = int (cuda_args ["mp_rank" ])
361-
362- # TODO: make it more general.
363- dist_global_strategy = [int (x ) for x in cuda_args ["dist_global_strategy" ].split ("," )]
364- assert len (dist_global_strategy ) == mp_world_size + 1 , "dist_global_strategy should have length WORLD_SIZE+1"
365- assert dist_global_strategy [0 ] == 0 , "dist_global_strategy[0] should be 0"
366- dist_global_strategy = torch .tensor (dist_global_strategy , dtype = torch .int , device = means2D .device )
367-
368- args = (
369- raster_settings .image_height ,
370- raster_settings .image_width ,
371- mp_rank ,
372- mp_world_size ,
373- means2D ,
374- radii ,
375- dist_global_strategy ,
376- cuda_args
377- )
378-
379- local2j_ids_bool = _C .get_local2j_ids_bool (* args ) # local2j_ids_bool is (P, world_size) bool tensor
380-
381- else :
382- raster_settings = self .raster_settings
383- mp_world_size = int (cuda_args ["mp_world_size" ])
384- mp_rank = int (cuda_args ["mp_rank" ])
385-
386- division_pos = cuda_args ["dist_global_strategy" ]
387- division_pos_xs , division_pos_ys = division_pos
388-
389- rectangles = []
390- for y_rank in range (len (division_pos_ys [0 ])- 1 ):
391- for x_rank in range (len (division_pos_ys )):
392- local_tile_x_l , local_tile_x_r = division_pos_xs [x_rank ], division_pos_xs [x_rank + 1 ]
393- local_tile_y_l , local_tile_y_r = division_pos_ys [x_rank ][y_rank ], division_pos_ys [x_rank ][y_rank + 1 ]
394- rectangles .append ([local_tile_y_l , local_tile_y_r , local_tile_x_l , local_tile_x_r ])
395- rectangles = torch .tensor (rectangles , dtype = torch .int , device = means2D .device )# (mp_world_size, 4)
396-
397- args = (
398- raster_settings .image_height ,
399- raster_settings .image_width ,
400- mp_rank ,
401- mp_world_size ,
402- means2D ,
403- radii ,
404- rectangles ,
405- cuda_args
406- )
407-
408- local2j_ids_bool = _C .get_local2j_ids_bool_adjust_mode6 (* args ) # local2j_ids_bool is (P, world_size) bool tensor
409-
410- local2j_ids = []
411- for rk in range (mp_world_size ):
412- local2j_ids .append (local2j_ids_bool [:, rk ].nonzero ())
413-
414- return local2j_ids , local2j_ids_bool
415-
416-
417354class _LoadImageTilesByPos (torch .autograd .Function ):
418355
419356 @staticmethod
0 commit comments