Skip to content

Commit fcac049

Browse files
author
Hexu Zhao
committed
move extended_compute_locally to python part.
1 parent d898cb4 commit fcac049

File tree

1 file changed

+7
-56
lines changed

1 file changed

+7
-56
lines changed

diff_gaussian_rasterization/__init__.py

Lines changed: 7 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def render_gaussians(
159159
depths,
160160
radii,
161161
compute_locally,
162+
extended_compute_locally,
162163
raster_settings,
163164
cuda_args,
164165
):
@@ -169,46 +170,11 @@ def render_gaussians(
169170
depths,
170171
radii,
171172
compute_locally,
173+
extended_compute_locally,
172174
raster_settings,
173175
cuda_args,
174176
)
175177

176-
def get_extended_compute_locally(cuda_args, image_height, image_width):
177-
if isinstance(cuda_args["dist_global_strategy"], str):
178-
mp_rank = int(cuda_args["mp_rank"])
179-
dist_global_strategy = [int(x) for x in cuda_args["dist_global_strategy"].split(",")]
180-
181-
num_tile_y = (image_height + 16 - 1) // 16 #TODO: this is dangerous because 16 may change.
182-
num_tile_x = (image_width + 16 - 1) // 16
183-
tile_l = max(dist_global_strategy[mp_rank]-num_tile_x-1, 0)
184-
tile_r = min(dist_global_strategy[mp_rank+1]+num_tile_x+1, num_tile_y*num_tile_x)
185-
186-
extended_compute_locally = torch.zeros(num_tile_y*num_tile_x, dtype=torch.bool, device="cuda")
187-
extended_compute_locally[tile_l:tile_r] = True
188-
extended_compute_locally = extended_compute_locally.view(num_tile_y, num_tile_x)
189-
190-
return extended_compute_locally
191-
else:
192-
division_pos = cuda_args["dist_global_strategy"]
193-
division_pos_xs, division_pos_ys = division_pos
194-
mp_rank = int(cuda_args["mp_rank"])
195-
grid_size_x = len(division_pos_xs) - 1
196-
grid_size_y = len(division_pos_ys[0]) - 1
197-
y_rank = mp_rank // grid_size_x
198-
x_rank = mp_rank % grid_size_x
199-
200-
local_tile_x_l, local_tile_x_r = division_pos_xs[x_rank], division_pos_xs[x_rank+1]
201-
local_tile_y_l, local_tile_y_r = division_pos_ys[x_rank][y_rank], division_pos_ys[x_rank][y_rank+1]
202-
203-
num_tile_y = (image_height + 16 - 1) // 16
204-
num_tile_x = (image_width + 16 - 1) // 16
205-
206-
extended_compute_locally = torch.zeros((num_tile_y, num_tile_x), dtype=torch.bool, device="cuda")
207-
extended_compute_locally[max(local_tile_y_l-1,0):min(local_tile_y_r+1,num_tile_y),
208-
max(local_tile_x_l-1,0):min(local_tile_x_r+1,num_tile_x)] = True
209-
210-
return extended_compute_locally
211-
212178
class _RenderGaussians(torch.autograd.Function):
213179
@staticmethod
214180
def forward(
@@ -219,6 +185,7 @@ def forward(
219185
depths,
220186
radii,
221187
compute_locally,
188+
extended_compute_locally,
222189
raster_settings,
223190
cuda_args,
224191
):
@@ -231,10 +198,6 @@ def forward(
231198
# Basically, means2D is (P, 3) in python. But it is (P, 2) in cuda code.
232199
# dL_dmeans2D is always (P, 3) in both python and cuda code.
233200

234-
extended_compute_locally = get_extended_compute_locally(cuda_args,
235-
raster_settings.image_height,
236-
raster_settings.image_width)
237-
238201
# Restructure arguments the way that the C++ lib expects them
239202
args = (
240203
raster_settings.bg,
@@ -312,6 +275,7 @@ def backward(ctx, grad_color, grad_n_render, grad_n_consider, grad_n_contrib):
312275
None,
313276
None,
314277
None,
278+
None,
315279
None # this is for cuda_args
316280
)
317281

@@ -370,7 +334,7 @@ def preprocess_gaussians(self, means3D, scales, rotations, shs, opacities, cuda_
370334
raster_settings,
371335
cuda_args)
372336

373-
def render_gaussians(self, means2D, conic_opacity, rgb, depths, radii, compute_locally, cuda_args = None):
337+
def render_gaussians(self, means2D, conic_opacity, rgb, depths, radii, compute_locally, extended_compute_locally, cuda_args = None):
374338

375339
raster_settings = self.raster_settings
376340

@@ -382,11 +346,13 @@ def render_gaussians(self, means2D, conic_opacity, rgb, depths, radii, compute_l
382346
depths,
383347
radii,
384348
compute_locally,
349+
extended_compute_locally,
385350
raster_settings,
386351
cuda_args
387352
)
388353

389354
def get_local2j_ids(self, means2D, radii, cuda_args):
355+
# For each 3dgs, calculate the set of GPUs that will use this 3dgs for rendering.
390356

391357
if isinstance(cuda_args["dist_global_strategy"], str):
392358
raster_settings = self.raster_settings
@@ -448,21 +414,6 @@ def get_local2j_ids(self, means2D, radii, cuda_args):
448414
return local2j_ids, local2j_ids_bool
449415

450416

451-
def get_distribution_strategy(self, means2D, radii, cuda_args):
452-
453-
assert False, "This function is not used in the current version."
454-
455-
raster_settings = self.raster_settings
456-
457-
return _C.get_distribution_strategy(
458-
raster_settings.image_height,
459-
raster_settings.image_width,
460-
means2D,
461-
radii,
462-
raster_settings.debug,
463-
cuda_args
464-
)# the return is compute_locally
465-
466417
class _LoadImageTilesByPos(torch.autograd.Function):
467418

468419
@staticmethod

0 commit comments

Comments
 (0)