Skip to content

Commit a751a33

Browse files
author
David Svitov
committedMar 14, 2025
Implement StopThePop as an optional configuration
1 parent a976565 commit a751a33

13 files changed

+1238
-38
lines changed
 

‎cuda_rasterizer/auxiliary.h

+35-3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#define MIDDEPTH_OFFSET 5
2828
#define DISTORTION_OFFSET 6
2929
#define MEDIAN_WEIGHT_OFFSET 7
30+
#define OUTPUT_CHANNELS 8
3031

3132
// distortion helper macros
3233
#define BACKFACE_CULL 1
@@ -35,6 +36,16 @@
3536
#define FAR_PLANE 100.0
3637
#define DETACH_WEIGHT 0
3738

39+
#define TILE_SORTING 0
40+
#define PIXEL_RESORTING 0
41+
#define BUFFER_LENGTH 8
42+
43+
#define FAST_INFERENCE 0
44+
#define MAX_BILLBOARD_SIZE 1000
45+
46+
constexpr uint32_t WARP_SIZE = 32U;
47+
constexpr uint32_t WARP_MASK = 0xFFFFFFFFU;
48+
3849
// Spherical harmonics coefficients
3950
__device__ const float SH_C0 = 0.28209479177387814f;
4051
__device__ const float SH_C1 = 0.4886025119029199f;
@@ -55,12 +66,33 @@ __device__ const float SH_C3[] = {
5566
-0.5900435899266435f
5667
};
5768

69+
template<typename T>
70+
__device__ void swap_T(T& a, T& b)
71+
{
72+
T temp = a;
73+
a = b;
74+
b = temp;
75+
}
76+
5877
__forceinline__ __device__ float ndc2Pix(float v, int S)
5978
{
6079
return ((v + 1.0) * S - 1.0) * 0.5;
6180
}
6281

63-
__forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid)
82+
83+
__forceinline__ __device__ void getRect(const float2 p, float2 max_radius, uint2& rect_min, uint2& rect_max, dim3 grid)
84+
{
85+
rect_min = {
86+
min(grid.x, max((int)0, (int)floorf((p.x - max_radius.x) / BLOCK_X))),
87+
min(grid.y, max((int)0, (int)floorf((p.y - max_radius.y) / BLOCK_Y)))
88+
};
89+
rect_max = {
90+
min(grid.x, max((int)0, (int)ceilf((p.x + max_radius.x) / BLOCK_X))),
91+
min(grid.y, max((int)0, (int)ceilf((p.y + max_radius.y) / BLOCK_Y)))
92+
};
93+
}
94+
95+
__forceinline__ __device__ void getRectOld(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid)
6496
{
6597
rect_min = {
6698
min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))),
@@ -261,7 +293,7 @@ scale_to_mat(const float3 scale, const float glob_scale) {
261293
glm::mat3 S = glm::mat3(1.f);
262294
S[0][0] = glob_scale * scale.x;
263295
S[1][1] = glob_scale * scale.y;
264-
S[2][2] = glob_scale * scale.z;
296+
//S[2][2] = glob_scale * scale.z;
265297
return S;
266298
}
267299

@@ -276,4 +308,4 @@ throw std::runtime_error(cudaGetErrorString(ret)); \
276308
} \
277309
}
278310

279-
#endif
311+
#endif

‎cuda_rasterizer/backward.cu

+32-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "auxiliary.h"
1515
#include <cooperative_groups.h>
1616
#include <cooperative_groups/reduce.h>
17+
#include "stopthepop_2DGS/resorted_render.cuh"
1718
namespace cg = cooperative_groups;
1819

1920
// Backward pass for conversion of spherical harmonics to RGB for
@@ -257,11 +258,6 @@ renderCUDA(
257258
float last_alpha = 0;
258259
float last_color[C] = { 0 };
259260

260-
// Gradient of pixel coordinate w.r.t. normalized
261-
// screen-space viewport corrdinates (-1 to 1)
262-
const float ddelx_dx = 0.5 * W;
263-
const float ddely_dy = 0.5 * H;
264-
265261
// Traverse all Gaussians
266262
for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
267263
{
@@ -744,6 +740,8 @@ void BACKWARD::render(
744740
const float* depths,
745741
const float* final_Ts,
746742
const uint32_t* n_contrib,
743+
const float* out_color,
744+
const float* out_others,
747745
const float* dL_dpixels,
748746
const float* dL_depths,
749747
float * dL_dtransMat,
@@ -753,6 +751,34 @@ void BACKWARD::render(
753751
float* dL_dtexture_alpha,
754752
float* dL_dtexture_color)
755753
{
754+
#if PIXEL_RESORTING
755+
renderkBufferBackwardCUDA<NUM_CHANNELS> << <grid, block >> >(
756+
ranges,
757+
point_list,
758+
W, H,
759+
focal_x, focal_y,
760+
bg_color,
761+
texture_alpha,
762+
texture_color,
763+
texture_size,
764+
means2D,
765+
normal_array,
766+
transMats,
767+
colors,
768+
depths,
769+
final_Ts,
770+
n_contrib,
771+
out_color,
772+
out_others,
773+
dL_dpixels,
774+
dL_depths,
775+
dL_dtransMat,
776+
dL_dmean2D,
777+
dL_dnormal3D,
778+
dL_dcolors,
779+
dL_dtexture_alpha,
780+
dL_dtexture_color);
781+
#else
756782
renderCUDA<NUM_CHANNELS> << <grid, block >> >(
757783
ranges,
758784
point_list,
@@ -777,4 +803,5 @@ void BACKWARD::render(
777803
dL_dcolors,
778804
dL_dtexture_alpha,
779805
dL_dtexture_color);
806+
#endif
780807
}

‎cuda_rasterizer/backward.h

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ namespace BACKWARD
3737
const float* depths,
3838
const float* final_Ts,
3939
const uint32_t* n_contrib,
40+
const float* out_color,
41+
const float* out_others,
4042
const float* dL_dpixels,
4143
const float* dL_depths,
4244
float * dL_dtransMat,

‎cuda_rasterizer/forward.cu

+93-26
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include "forward.h"
1313
#include "grid_sample.h"
1414
#include "auxiliary.h"
15+
#include "stopthepop_2DGS/stopthepop_common.cuh"
16+
#include "stopthepop_2DGS/resorted_render.cuh"
1517
#include <cooperative_groups.h>
1618
#include <cooperative_groups/reduce.h>
1719
namespace cg = cooperative_groups;
@@ -181,6 +183,7 @@ __global__ void preprocessCUDA(int P, int D, int M,
181183
const float tan_fovx, const float tan_fovy,
182184
const float focal_x, const float focal_y,
183185
int* radii,
186+
float2* rects,
184187
float2* points_xy_image,
185188
float* depths,
186189
float* transMats,
@@ -233,9 +236,18 @@ __global__ void preprocessCUDA(int P, int D, int M,
233236
float radius = ceil(truncated_R * max(max(extent.x, extent.y), FilterSize));
234237

235238
uint2 rect_min, rect_max;
236-
getRect(center, radius, rect_min, rect_max, grid);
237-
if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
238-
return;
239+
#if FAST_INFERENCE
240+
if (radius > MAX_BILLBOARD_SIZE)
241+
getRectOld(center, radius, rect_min, rect_max, grid);
242+
else
243+
getRect(center, extent, rect_min, rect_max, grid);
244+
#else
245+
getRectOld(center, radius, rect_min, rect_max, grid);
246+
#endif
247+
248+
if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0) {
249+
return;
250+
}
239251

240252
// compute colors
241253
if (colors_precomp == nullptr) {
@@ -246,6 +258,7 @@ __global__ void preprocessCUDA(int P, int D, int M,
246258
}
247259

248260
depths[idx] = p_view.z;
261+
rects[idx] = extent;
249262
radii[idx] = (int)radius;
250263
points_xy_image[idx] = center;
251264
// store them in float4
@@ -299,7 +312,6 @@ renderCUDA(
299312

300313
// Allocate storage for batches of collectively fetched data.
301314
__shared__ int collected_id[BLOCK_SIZE];
302-
__shared__ float2 collected_xy[BLOCK_SIZE];
303315
__shared__ float3 collected_normal[BLOCK_SIZE];
304316
__shared__ float3 collected_Tu[BLOCK_SIZE];
305317
__shared__ float3 collected_Tv[BLOCK_SIZE];
@@ -319,7 +331,7 @@ renderCUDA(
319331
float dist1 = {0};
320332
float dist2 = {0};
321333
float distortion = {0};
322-
float median_depth = {0};
334+
float median_depth = {100};
323335
float median_weight = {0};
324336
float median_contributor = {-1};
325337

@@ -339,7 +351,6 @@ renderCUDA(
339351
{
340352
int coll_id = point_list[range.x + progress];
341353
collected_id[block.thread_rank()] = coll_id;
342-
collected_xy[block.thread_rank()] = points_xy_image[coll_id];
343354
collected_normal[block.thread_rank()] = normal_array[coll_id];
344355
collected_Tu[block.thread_rank()] = {transMats[9 * coll_id+0], transMats[9 * coll_id+1], transMats[9 * coll_id+2]};
345356
collected_Tv[block.thread_rank()] = {transMats[9 * coll_id+3], transMats[9 * coll_id+4], transMats[9 * coll_id+5]};
@@ -409,7 +420,7 @@ renderCUDA(
409420
float error = mapped_depth * mapped_depth * A + dist2 - 2 * mapped_depth * dist1;
410421
distortion += error * alpha * T;
411422

412-
if (T > 0.5) {
423+
if (T > 0.5 && alpha > 0.05) {
413424
median_depth = depth;
414425
median_weight = alpha * T;
415426
median_contributor = contributor;
@@ -484,25 +495,48 @@ void FORWARD::render(
484495
float* out_others,
485496
float* impact)
486497
{
487-
renderCUDA<NUM_CHANNELS> << <grid, block >> > (
488-
ranges,
489-
point_list,
490-
W, H,
491-
focal_x, focal_y,
492-
means2D,
493-
colors,
494-
texture_alpha,
495-
texture_color,
496-
texture_size,
497-
transMats,
498-
depths,
499-
normal_array,
500-
final_T,
501-
n_contrib,
502-
bg_color,
503-
out_color,
504-
out_others,
505-
impact);
498+
499+
#if PIXEL_RESORTING
500+
renderBufferCUDA<NUM_CHANNELS> << <grid, block >> > (
501+
ranges,
502+
point_list,
503+
W, H,
504+
focal_x, focal_y,
505+
means2D,
506+
colors,
507+
texture_alpha,
508+
texture_color,
509+
texture_size,
510+
transMats,
511+
depths,
512+
normal_array,
513+
final_T,
514+
n_contrib,
515+
bg_color,
516+
out_color,
517+
out_others,
518+
impact);
519+
#else
520+
renderCUDA<NUM_CHANNELS> << <grid, block >> > (
521+
ranges,
522+
point_list,
523+
W, H,
524+
focal_x, focal_y,
525+
means2D,
526+
colors,
527+
texture_alpha,
528+
texture_color,
529+
texture_size,
530+
transMats,
531+
depths,
532+
normal_array,
533+
final_T,
534+
n_contrib,
535+
bg_color,
536+
out_color,
537+
out_others,
538+
impact);
539+
#endif
506540
}
507541

508542
void FORWARD::preprocess(int P, int D, int M,
@@ -521,6 +555,7 @@ void FORWARD::preprocess(int P, int D, int M,
521555
const float focal_x, const float focal_y,
522556
const float tan_fovx, const float tan_fovy,
523557
int* radii,
558+
float2* rects,
524559
float2* means2D,
525560
float* depths,
526561
float* transMats,
@@ -547,6 +582,7 @@ void FORWARD::preprocess(int P, int D, int M,
547582
tan_fovx, tan_fovy,
548583
focal_x, focal_y,
549584
radii,
585+
rects,
550586
means2D,
551587
depths,
552588
transMats,
@@ -557,3 +593,34 @@ void FORWARD::preprocess(int P, int D, int M,
557593
prefiltered
558594
);
559595
}
596+
597+
void FORWARD::duplicate(
598+
int P,
599+
int W, int H,
600+
const float focal_x, const float focal_y,
601+
const float2* means2D,
602+
const float* depths,
603+
const float2* scales,
604+
const float* view2gaussians,
605+
const uint32_t* offsets,
606+
const int* radii,
607+
const float2* rects,
608+
uint64_t* gaussian_keys_unsorted,
609+
uint32_t* gaussian_values_unsorted,
610+
dim3 grid)
611+
{
612+
duplicateWithKeys_extended<false, true> << <(P + 255) / 256, 256 >> >(
613+
P, W, H, focal_x, focal_y,
614+
means2D,
615+
depths,
616+
scales,
617+
view2gaussians,
618+
offsets,
619+
radii,
620+
rects,
621+
gaussian_keys_unsorted,
622+
gaussian_values_unsorted,
623+
grid
624+
);
625+
626+
}

‎cuda_rasterizer/forward.h

+16
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ namespace FORWARD
3737
const float focal_x, float focal_y,
3838
const float tan_fovx, float tan_fovy,
3939
int* radii,
40+
float2* rects,
4041
float2* points_xy_image,
4142
float* depths,
4243
float* transMats,
@@ -67,6 +68,21 @@ namespace FORWARD
6768
float* out_color,
6869
float* out_others,
6970
float* impact);
71+
72+
void duplicate(
73+
int P,
74+
int W, int H,
75+
const float focal_x, const float focal_y,
76+
const float2 *means2D,
77+
const float* depths,
78+
const float2* scales,
79+
const float* view2gaussians,
80+
const uint32_t* offsets,
81+
const int* radii,
82+
const float2* rects,
83+
uint64_t* gaussian_keys_unsorted,
84+
uint32_t* gaussian_values_unsorted,
85+
dim3 grid);
7086
}
7187

7288

‎cuda_rasterizer/rasterizer.h

+2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ namespace CudaRasterizer
7575
const float* campos,
7676
const float tan_fovx, float tan_fovy,
7777
const int* radii,
78+
const float* out_color,
79+
const float* out_others,
7880
char* geom_buffer,
7981
char* binning_buffer,
8082
char* image_buffer,

‎cuda_rasterizer/rasterizer_impl.cu

+35-2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ __global__ void duplicateWithKeys(
7575
uint64_t* gaussian_keys_unsorted,
7676
uint32_t* gaussian_values_unsorted,
7777
int* radii,
78+
float2* rects,
7879
dim3 grid)
7980
{
8081
auto idx = cg::this_grid().thread_rank();
@@ -88,7 +89,14 @@ __global__ void duplicateWithKeys(
8889
uint32_t off = (idx == 0) ? 0 : offsets[idx - 1];
8990
uint2 rect_min, rect_max;
9091

91-
getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid);
92+
#if FAST_INFERENCE
93+
if (radii[idx] > MAX_BILLBOARD_SIZE)
94+
getRectOld(points_xy[idx], radii[idx], rect_min, rect_max, grid);
95+
else
96+
getRect(points_xy[idx], rects[idx], rect_min, rect_max, grid);
97+
#else
98+
getRectOld(points_xy[idx], radii[idx], rect_min, rect_max, grid);
99+
#endif
92100

93101
// For each tile that the bounding rect overlaps, emit a
94102
// key/value pair. The key is | tile ID | depth |,
@@ -158,6 +166,7 @@ CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& ch
158166
obtain(chunk, geom.depths, P, 128);
159167
obtain(chunk, geom.clamped, P * 3, 128);
160168
obtain(chunk, geom.internal_radii, P, 128);
169+
obtain(chunk, geom.rects2D, P, 128);
161170
obtain(chunk, geom.means2D, P, 128);
162171
obtain(chunk, geom.transMat, P * 9, 128);
163172
obtain(chunk, geom.normal, P, 128);
@@ -265,6 +274,7 @@ int CudaRasterizer::Rasterizer::forward(
265274
focal_x, focal_y,
266275
tan_fovx, tan_fovy,
267276
radii,
277+
geomState.rects2D,
268278
geomState.means2D,
269279
geomState.depths,
270280
geomState.transMat,
@@ -286,7 +296,24 @@ int CudaRasterizer::Rasterizer::forward(
286296
size_t binning_chunk_size = required<BinningState>(num_rendered);
287297
char* binning_chunkptr = binningBuffer(binning_chunk_size);
288298
BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered);
299+
300+
const float* transMat_ptr = transMat_precomp != nullptr ? transMat_precomp : geomState.transMat;
289301

302+
#if TILE_SORTING
303+
FORWARD::duplicate(
304+
P, width, height, focal_x, focal_y,
305+
geomState.means2D,
306+
geomState.depths,
307+
(float2*)scales,
308+
transMat_ptr,
309+
geomState.point_offsets,
310+
radii,
311+
geomState.rects2D,
312+
binningState.point_list_keys_unsorted,
313+
binningState.point_list_unsorted,
314+
tile_grid
315+
);
316+
#else
290317
// For each instance to be rendered, produce adequate [ tile | depth ] key
291318
// and corresponding dublicated Gaussian indices to be sorted
292319
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
@@ -297,7 +324,10 @@ int CudaRasterizer::Rasterizer::forward(
297324
binningState.point_list_keys_unsorted,
298325
binningState.point_list_unsorted,
299326
radii,
327+
geomState.rects2D,
300328
tile_grid)
329+
#endif
330+
301331
CHECK_CUDA(, debug)
302332

303333
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
@@ -322,7 +352,6 @@ int CudaRasterizer::Rasterizer::forward(
322352

323353
// Let each tile blend its range of Gaussians independently in parallel
324354
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb;
325-
const float* transMat_ptr = transMat_precomp != nullptr ? transMat_precomp : geomState.transMat;
326355
CHECK_CUDA(FORWARD::render(
327356
tile_grid, block,
328357
imgState.ranges,
@@ -368,6 +397,8 @@ void CudaRasterizer::Rasterizer::backward(
368397
const float* campos,
369398
const float tan_fovx, float tan_fovy,
370399
const int* radii,
400+
const float* out_color,
401+
const float* out_others,
371402
char* geom_buffer,
372403
char* binning_buffer,
373404
char* img_buffer,
@@ -424,6 +455,8 @@ void CudaRasterizer::Rasterizer::backward(
424455
depth_ptr,
425456
imgState.accum_alpha,
426457
imgState.n_contrib,
458+
out_color,
459+
out_others,
427460
dL_dpix,
428461
dL_depths,
429462
dL_dtransMat,

‎cuda_rasterizer/rasterizer_impl.h

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace CudaRasterizer
3333
char* scanning_space;
3434
bool* clamped;
3535
int* internal_radii;
36+
float2* rects2D;
3637
float2* means2D;
3738
float* transMat;
3839
float3* normal;

‎cuda_rasterizer/stopthepop_2DGS/resorted_render.cuh

+693
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
/*
2+
* Copyright (C) 2024, Graz University of Technology
3+
* This code is licensed under the MIT license (see LICENSE.txt in this folder for details)
4+
*/
5+
6+
#pragma once
7+
8+
#include "../auxiliary.h"
9+
10+
#include <cooperative_groups.h>
11+
namespace cg = cooperative_groups;
12+
13+
__device__ __inline__ uint64_t constructSortKey(uint32_t tile_id, float depth)
14+
{
15+
uint64_t key = tile_id;
16+
key <<= 32;
17+
key |= *((uint32_t*)&depth);
18+
return key;
19+
}
20+
21+
// Given a ray and a Gaussian primitive, compute the intersection depth.
22+
__device__ __inline__ bool getIntersectPoint(
23+
const int W, const int H,
24+
const float fx, const float fy,
25+
const float2 scale,
26+
const glm::vec2 pixel_center,
27+
const float* view2gaussian,
28+
float& depth
29+
){
30+
31+
// Fisrt compute two homogeneous planes, See Eq. (8)
32+
float3 Tu = {view2gaussian[0], view2gaussian[1], view2gaussian[2]};
33+
float3 Tv = {view2gaussian[3], view2gaussian[4], view2gaussian[5]};
34+
float3 Tw = {view2gaussian[6], view2gaussian[7], view2gaussian[8]};
35+
float3 k = {-Tu.x + pixel_center.x * Tw.x, -Tu.y + pixel_center.x * Tw.y, -Tu.z + pixel_center.x * Tw.z};
36+
float3 l = {-Tv.x + pixel_center.y * Tw.x, -Tv.y + pixel_center.y * Tw.y, -Tv.z + pixel_center.y * Tw.z};
37+
// cross product of two planes is a line (i.e., homogeneous point), See Eq. (10)
38+
float3 p = crossProduct(k, l);
39+
40+
if (p.z == 0.0) return false; // there is not intersection
41+
// TODO: no intersection if distance < scale
42+
43+
// 3d homogeneous point to 2d point on the splat
44+
float2 s = {p.x / p.z, p.y / p.z};
45+
// 3d distance. Compute Mahalanobis distance in the canonical splat' space
46+
float rho3d = (s.x * s.x + s.y * s.y);
47+
48+
depth = (s.x * Tw.x + s.y * Tw.y) + Tw.z; // splat depth
49+
return true;
50+
}
51+
52+
53+
template<bool TILE_BASED_CULLING = false, bool LOAD_BALANCING = true>
54+
__global__ void duplicateWithKeys_extended(
55+
int P,
56+
int W, int H,
57+
const float focal_x, const float focal_y,
58+
const float2* __restrict__ points_xy,
59+
const float* __restrict__ depths,
60+
const float2* __restrict__ scales,
61+
const float* __restrict__ view2gaussians,
62+
const uint32_t* __restrict__ offsets,
63+
const int* __restrict__ radii,
64+
const float2* __restrict__ rects,
65+
uint64_t* __restrict__ gaussian_keys_unsorted,
66+
uint32_t* __restrict__ gaussian_values_unsorted,
67+
dim3 grid)
68+
{
69+
auto block = cg::this_thread_block();
70+
auto warp = cg::tiled_partition<WARP_SIZE>(block);
71+
72+
// Since the projection of the quadratic surface on the image is non-convex,
73+
// there is no explicit solution for computing the pixel with the maximum weight on the image,
74+
// and tile-based culling is not performed.
75+
constexpr bool EVAL_MAX_CONTRIB_POS = false;
76+
constexpr bool PER_TILE_DEPTH = true;
77+
78+
#define RETURN_OR_INACTIVE() if constexpr(LOAD_BALANCING) { active = false; } else { return; }
79+
uint32_t idx = cg::this_grid().thread_rank();
80+
bool active = true;
81+
if (idx >= P) {
82+
RETURN_OR_INACTIVE();
83+
idx = P - 1;
84+
}
85+
86+
const int radius = radii[idx];
87+
if (radius <= 0) {
88+
RETURN_OR_INACTIVE();
89+
}
90+
91+
// If the thread exceeds the Gaussian index, the Gaussian projection is zero,
92+
// and there are no Gaussians to process in the current warp, return.
93+
if constexpr(LOAD_BALANCING)
94+
if (__ballot_sync(WARP_MASK, active) == 0)
95+
return;
96+
97+
// Find this Gaussian's offset in buffer for writing keys/values.
98+
uint32_t off_init = (idx == 0) ? 0 : offsets[idx - 1];
99+
100+
const int offset_to_init = offsets[idx];
101+
const float global_depth_init = depths[idx];
102+
103+
const float2 xy_init = points_xy[idx];
104+
const float2 rect_dims_init = rects[idx];
105+
106+
__shared__ float2 s_xy[BLOCK_SIZE];
107+
__shared__ float2 s_rect_dims[BLOCK_SIZE];
108+
__shared__ float s_radius[BLOCK_SIZE];
109+
s_xy[block.thread_rank()] = xy_init;
110+
s_rect_dims[block.thread_rank()] = rect_dims_init;
111+
s_radius[block.thread_rank()] = radius;
112+
113+
uint2 rect_min_init, rect_max_init;
114+
#if FAST_INFERENCE
115+
if (radius > MAX_BILLBOARD_SIZE)
116+
getRectOld(xy_init, radius, rect_min_init, rect_max_init, grid);
117+
else
118+
getRect(xy_init, rect_dims_init, rect_min_init, rect_max_init, grid);
119+
# else
120+
getRectOld(xy_init, radius, rect_min_init, rect_max_init, grid);
121+
#endif
122+
123+
__shared__ float s_view2gaussians[BLOCK_SIZE * 9];
124+
__shared__ float2 s_scales[BLOCK_SIZE];
125+
126+
if (PER_TILE_DEPTH)
127+
{
128+
s_scales[block.thread_rank()] = scales[idx];
129+
for (int ii = 0; ii < 9; ii++)
130+
s_view2gaussians[9 * block.thread_rank() + ii] = view2gaussians[idx * 9 + ii];
131+
}
132+
133+
constexpr uint32_t SEQUENTIAL_TILE_THRESH = 32U; // all tiles above this threshold will be computed cooperatively
134+
const uint32_t rect_width_init = (rect_max_init.x - rect_min_init.x);
135+
const uint32_t tile_count_init = (rect_max_init.y - rect_min_init.y) * rect_width_init;
136+
137+
// Generate no key/value pair for invisible Gaussians
138+
if (tile_count_init == 0) {
139+
RETURN_OR_INACTIVE();
140+
}
141+
auto tile_function = [&](const int W, const int H,
142+
const float fx, const float fy,
143+
float2 xy,
144+
int x, int y,// tile ID
145+
const float2 scale,
146+
const float* view2gaussian,
147+
const float global_depth,
148+
float& depth)
149+
{
150+
const glm::vec2 tile_min(x * BLOCK_X, y * BLOCK_Y);
151+
const glm::vec2 tile_max((x + 1) * BLOCK_X - 1, (y + 1) * BLOCK_Y - 1); // 像素坐标
152+
153+
glm::vec2 max_pos;
154+
if constexpr (PER_TILE_DEPTH)
155+
{
156+
glm::vec2 target_pos = {max(min(xy.x, tile_max.x), tile_min.x), max(min(xy.y, tile_max.y), tile_min.y)};
157+
158+
// Or select the tile's center pixel as the target_pos.
159+
// const glm::vec2 tile_center = (tile_min + tile_max) * 0.5f;
160+
// glm::vec2 target_pos = tile_center;
161+
162+
bool intersect = getIntersectPoint(
163+
W, H, fx, fy, scale, target_pos, view2gaussian, depth); // Compute the intersection point of the quadratic surface.
164+
if (intersect)
165+
depth = max(0.0f, depth);
166+
else // If there is no intersection, sort by the Gaussian centroid.
167+
depth = global_depth;
168+
}
169+
else
170+
{
171+
depth = global_depth;
172+
}
173+
174+
// Since the quadratic surface is non-convex, tile-based culling is not performed.
175+
// return (!TILE_BASED_CULLING) || max_opac_factor <= opacity_factor_threshold;
176+
return true;
177+
};
178+
179+
if (active)
180+
{
181+
const float2 scale_init = {
182+
s_scales[block.thread_rank()].x,
183+
s_scales[block.thread_rank()].y};
184+
185+
float view2gaussian_init[9];
186+
for (int ii = 0; ii < 9; ii++)
187+
view2gaussian_init[ii] = s_view2gaussians[9 * block.thread_rank() + ii];
188+
189+
for (uint32_t tile_idx = 0; tile_idx < tile_count_init && (!LOAD_BALANCING || tile_idx < SEQUENTIAL_TILE_THRESH); tile_idx++)
190+
{
191+
const int y = (tile_idx / rect_width_init) + rect_min_init.y;
192+
const int x = (tile_idx % rect_width_init) + rect_min_init.x;
193+
194+
float depth;
195+
bool write_tile = tile_function(
196+
W, H, focal_x, focal_y,
197+
xy_init, x, y, scale_init, view2gaussian_init, global_depth_init, depth);
198+
if (write_tile)
199+
{
200+
if (off_init < offset_to_init)
201+
{
202+
const uint32_t tile_id = y * grid.x + x;
203+
gaussian_values_unsorted[off_init] = idx;
204+
gaussian_keys_unsorted[off_init] = constructSortKey(tile_id, depth);
205+
}
206+
else
207+
{
208+
#ifdef DUPLICATE_OPT_DEBUG
209+
printf("Error (sequential): Too little memory reserved in preprocess: off=%d off_to=%d idx=%d\n", off_init, offset_to_init, idx);
210+
#endif
211+
}
212+
off_init++;
213+
}
214+
}
215+
}
216+
217+
#undef RETURN_OR_INACTIVE
218+
219+
if (!LOAD_BALANCING) // Coordinate to handle the unprocessed tasks of other threads within the same warp.
220+
return;
221+
222+
const uint32_t idx_init = idx; // Current thread idx.
223+
const uint32_t lane_idx = cg::this_thread_block().thread_rank() % WARP_SIZE;
224+
const uint32_t warp_idx = cg::this_thread_block().thread_rank() / WARP_SIZE;
225+
unsigned int lane_mask_allprev_excl = 0xFFFFFFFFU >> (WARP_SIZE - lane_idx);
226+
227+
const int32_t compute_cooperatively = active && tile_count_init > SEQUENTIAL_TILE_THRESH; // Determine whether additional idle threads are needed for computation.
228+
const uint32_t remaining_threads = __ballot_sync(WARP_MASK, compute_cooperatively);
229+
if (remaining_threads == 0)
230+
return;
231+
232+
uint32_t n_remaining_threads = __popc(remaining_threads); // The number of threads required for collaborative computation.
233+
for (int n = 0; n < n_remaining_threads && n < WARP_SIZE; n++)
234+
{
235+
int i = __fns(remaining_threads, 0, n+1); // find lane index of next remaining thread
236+
237+
uint32_t idx_coop = __shfl_sync(WARP_MASK, idx_init, i);
238+
uint32_t off_coop = __shfl_sync(WARP_MASK, off_init, i);
239+
240+
const uint32_t offset_to = __shfl_sync(WARP_MASK, offset_to_init, i);
241+
const float global_depth = __shfl_sync(WARP_MASK, global_depth_init, i);
242+
243+
const float2 xy = s_xy[warp.meta_group_rank() * WARP_SIZE + i];
244+
const float2 rect_dims = s_rect_dims[warp.meta_group_rank() * WARP_SIZE + i];
245+
const float rad = s_radius[warp.meta_group_rank() * WARP_SIZE + i];
246+
const float2 scale = {
247+
s_scales[warp.meta_group_rank() * WARP_SIZE + i].x,
248+
s_scales[warp.meta_group_rank() * WARP_SIZE + i].y};
249+
float view2gaussian[9];
250+
for (int ii = 0; ii < 9; ii++)
251+
view2gaussian[ii] = s_view2gaussians[9 * (warp.meta_group_rank() * WARP_SIZE + i) + ii];
252+
253+
uint2 rect_min, rect_max;
254+
#if FAST_INFERENCE
255+
if (radius > MAX_BILLBOARD_SIZE)
256+
getRectOld(xy, rad, rect_min, rect_max, grid);
257+
else
258+
getRect(xy, rect_dims, rect_min, rect_max, grid);
259+
#else
260+
getRectOld(xy, rad, rect_min, rect_max, grid);
261+
#endif
262+
263+
const uint32_t rect_width = (rect_max.x - rect_min.x);
264+
const uint32_t tile_count = (rect_max.y - rect_min.y) * rect_width;
265+
const uint32_t remaining_tile_count = tile_count - SEQUENTIAL_TILE_THRESH;
266+
const int32_t n_iterations = (remaining_tile_count + WARP_SIZE - 1) / WARP_SIZE;
267+
for (int it = 0; it < n_iterations; it++)
268+
{
269+
int tile_idx = it * WARP_SIZE + lane_idx + SEQUENTIAL_TILE_THRESH; // it*32 + local_warp_idx + 32
270+
int active_curr_it = tile_idx < tile_count;
271+
272+
int y = (tile_idx / rect_width) + rect_min.y;
273+
int x = (tile_idx % rect_width) + rect_min.x;
274+
275+
float depth;
276+
bool write_tile = tile_function(
277+
W, H, focal_x, focal_y,
278+
xy, x, y, scale, view2gaussian, global_depth, depth
279+
);
280+
281+
const uint32_t write = active_curr_it && write_tile;
282+
283+
uint32_t n_writes, write_offset;
284+
if constexpr (!TILE_BASED_CULLING)
285+
{
286+
n_writes = WARP_SIZE;
287+
write_offset = off_coop + lane_idx;
288+
}
289+
else
290+
{
291+
const uint32_t write_ballot = __ballot_sync(WARP_MASK, write);
292+
n_writes = __popc(write_ballot);
293+
294+
const uint32_t write_offset_it = __popc(write_ballot & lane_mask_allprev_excl);
295+
write_offset = off_coop + write_offset_it;
296+
}
297+
298+
if (write)
299+
{
300+
if (write_offset < offset_to)
301+
{
302+
const uint32_t tile_id = y * grid.x + x;
303+
gaussian_values_unsorted[write_offset] = idx_coop;
304+
gaussian_keys_unsorted[write_offset] = constructSortKey(tile_id, depth);
305+
}
306+
#ifdef DUPLICATE_OPT_DEBUG
307+
else
308+
{
309+
printf("Error (parallel): Too little memory reserved in preprocess: off=%d off_to=%d idx=%d tile_count=%d it=%d | x=%d y=%d rect=(%d %d - %d %d)\n",
310+
write_offset, offset_to, idx_coop, tile_count, it, x, y, rect_min.x, rect_min.y, rect_max.x, rect_max.y);
311+
}
312+
#endif
313+
}
314+
off_coop += n_writes;
315+
}
316+
317+
__syncwarp();
318+
}
319+
}

‎diff_bbsplat_rasterization/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def forward(
9898
# Keep relevant tensors for backward
9999
ctx.raster_settings = raster_settings
100100
ctx.num_rendered = num_rendered
101-
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, texture_alpha, textured_color, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
101+
ctx.save_for_backward(color, depth, colors_precomp, means3D, scales, rotations, texture_alpha, textured_color, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
102102

103103
return color, radii, impact, depth
104104

@@ -107,12 +107,14 @@ def backward(ctx, grad_out_color, grad_radii, grad_impact, grad_depth):
107107
# Restore necessary values from context
108108
num_rendered = ctx.num_rendered
109109
raster_settings = ctx.raster_settings
110-
colors_precomp, means3D, scales, rotations, texture_alpha, textured_color, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors
110+
out_colors, out_others, colors_precomp, means3D, scales, rotations, texture_alpha, textured_color, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors
111111

112112
# Restructure args as C++ method expects them
113113
args = (raster_settings.bg,
114114
means3D,
115115
radii,
116+
out_colors,
117+
out_others,
116118
colors_precomp,
117119
scales,
118120
rotations,

‎rasterize_points.cu

+4
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
164164
const torch::Tensor& background,
165165
const torch::Tensor& means3D,
166166
const torch::Tensor& radii,
167+
const torch::Tensor& out_colors,
168+
const torch::Tensor& out_others,
167169
const torch::Tensor& colors,
168170
const torch::Tensor& scales,
169171
const torch::Tensor& rotations,
@@ -248,6 +250,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
248250
tan_fovx,
249251
tan_fovy,
250252
radii.contiguous().data<int>(),
253+
out_colors.contiguous().data<float>(),
254+
out_others.contiguous().data<float>(),
251255
reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
252256
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
253257
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),

‎rasterize_points.h

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
4343
const torch::Tensor& background,
4444
const torch::Tensor& means3D,
4545
const torch::Tensor& radii,
46+
const torch::Tensor& out_color,
47+
const torch::Tensor& out_others,
4648
const torch::Tensor& colors,
4749
const torch::Tensor& scales,
4850
const torch::Tensor& rotations,

0 commit comments

Comments
 (0)
Please sign in to comment.