Skip to content

Commit 13e4cb0

Browse files
sandeepnmenonPrapti Devansh TrivediPrapti Devansh TrivediSandeep Menonprapti19
authored
Batch kernels for forward pass of Preprocessing (#2)
* test function for rasterization tests * add mock of improved preproc * batched rasterization * Refactor GaussianRasterizerBatches class to support batched preprocess_gaussians function. * batched forward pass kernel * added headers and changed kernel structure to 1d block * Refactor GaussianRasterizerBatches class to use torch.tensor instead of math.tan in test_batched_gaussian_rasterizer_batch_processing function * add parity test * Refactor preprocess_gaussians function to handle batched and non-batched inputs in __init__.py * Refactor tan_fovy parameter to be const in CUDA rasterizer files * Refactor tan_fovx parameter to be const in CUDA rasterizer files * Refactor CUDA rasterizer files to use CUDA tensors for batched calculations * Refactor assert_tensor_equal function to compare_tensors in rasterization_tests.py * tile_grid calculated before kernel launch * Refactor compare_tensors function to fix indexing bug and handle non-matching values * Refactor GaussianRasterizationSettings class to handle raster_settings as a batch * Refactor rasterization_tests.py to use raster_settings_batch instead of batched_raster_settings * fixed namedtuple setting bug * Refactor computeColorFromSH function in forward.cu to use point_idx and result_idx instead of only idx. * fixed cuda illegal memory bug and can run for 1M gaussians --------- Co-authored-by: Prapti Devansh Trivedi <[email protected]> Co-authored-by: Prapti Devansh Trivedi <[email protected]> Co-authored-by: Sandeep Menon <[email protected]> Co-authored-by: prapti19 <[email protected]>
1 parent b0dfe34 commit 13e4cb0

File tree

11 files changed

+746
-20
lines changed

11 files changed

+746
-20
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ diff_gaussian_rasterization.egg-info/
33
dist/
44
diff_gaussian_rasterization/__pycache__/
55
*so
6+
*.pyc

cuda_rasterizer/forward.cu

Lines changed: 160 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@ namespace cg = cooperative_groups;
1818

1919
// Forward method for converting the input spherical harmonics
2020
// coefficients of each Gaussian to a simple RGB color.
21-
__device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped)
21+
__device__ glm::vec3 computeColorFromSH(int point_idx, int result_idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped)
2222
{
2323
// The implementation is loosely based on code for
2424
// "Differentiable Point-Based Radiance Fields for
2525
// Efficient View Synthesis" by Zhang et al. (2022)
26-
glm::vec3 pos = means[idx];
26+
glm::vec3 pos = means[point_idx];
2727
glm::vec3 dir = pos - campos;
2828
dir = dir / glm::length(dir);
2929

30-
glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
30+
glm::vec3* sh = ((glm::vec3*)shs) + point_idx * max_coeffs;
3131
glm::vec3 result = SH_C0 * sh[0];
3232

3333
if (deg > 0)
@@ -65,9 +65,9 @@ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const
6565

6666
// RGB colors are clamped to positive values. If values are
6767
// clamped, we need to keep track of this for the backward pass.
68-
clamped[3 * idx + 0] = (result.x < 0);
69-
clamped[3 * idx + 1] = (result.y < 0);
70-
clamped[3 * idx + 2] = (result.z < 0);
68+
clamped[3 * result_idx + 0] = (result.x < 0);
69+
clamped[3 * result_idx + 1] = (result.y < 0);
70+
clamped[3 * result_idx + 2] = (result.z < 0);
7171
return glm::max(result, 0.0f);
7272
}
7373

@@ -213,7 +213,6 @@ __global__ void preprocessCUDA(int P, int D, int M,
213213
computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
214214
cov3D = cov3Ds + idx * 6;
215215
}
216-
217216
// Compute 2D screen-space covariance matrix
218217
float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);
219218

@@ -242,7 +241,7 @@ __global__ void preprocessCUDA(int P, int D, int M,
242241
// spherical harmonics coefficients to RGB color.
243242
if (colors_precomp == nullptr)
244243
{
245-
glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
244+
glm::vec3 result = computeColorFromSH(idx, idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
246245
rgb[idx * C + 0] = result.x;
247246
rgb[idx * C + 1] = result.y;
248247
rgb[idx * C + 2] = result.z;
@@ -471,7 +470,7 @@ void FORWARD::preprocess(int P, int D, int M,
471470
uint32_t* tiles_touched,
472471
bool prefiltered)
473472
{
474-
preprocessCUDA<NUM_CHANNELS> << <(P + ONE_DIM_BLOCK_SIZE - 1) / ONE_DIM_BLOCK_SIZE, ONE_DIM_BLOCK_SIZE >> > (
473+
preprocessCUDA<NUM_CHANNELS> << <cdiv(P, ONE_DIM_BLOCK_SIZE), ONE_DIM_BLOCK_SIZE >> > (
475474
P, D, M,
476475
means3D,
477476
scales,
@@ -498,4 +497,155 @@ void FORWARD::preprocess(int P, int D, int M,
498497
tiles_touched,
499498
prefiltered
500499
);
501-
}
500+
}
501+
502+
503+
// Batched per-Gaussian preprocessing: each thread handles one
// (viewpoint, Gaussian) pair.
//
// Launch layout:
//   - blockIdx.x * blockDim.x + threadIdx.x selects the Gaussian,
//   - blockIdx.y selects the viewpoint.
//   Threads whose Gaussian index is >= P or whose viewpoint index is
//   >= num_viewpoints exit without touching memory.
//
// Indexing:
//   - orig_points, scales, rotations, opacities and shs are read at the
//     Gaussian index (shared by all viewpoints).
//   - viewmatrix_arr / projmatrix_arr (16 floats per viewpoint), cam_pos,
//     tan_fovx and tan_fovy are read at the viewpoint index.
//   - radii, tiles_touched, depths, points_xy_image, conic_opacity, cov3Ds
//     (6 floats per slot), rgb (C floats per slot) and clamped are written at
//     slot = viewpoint * P + gaussian.
//
// radii and tiles_touched are zeroed up front; every early return below
// therefore leaves the slot marked as "not processed further".
template<int C>
__global__ void preprocessCUDABatched(
    int P, int D, int M,
    const float* orig_points, const glm::vec3* scales, const float scale_modifier,
    const glm::vec4* rotations, const float* opacities, const float* shs,
    bool* clamped, const float* cov3D_precomp, const float* colors_precomp,
    const float* viewmatrix_arr, const float* projmatrix_arr, const glm::vec3* cam_pos,
    const int W, int H, const float* tan_fovx, const float* tan_fovy,
    int* radii, float2* points_xy_image, float* depths, float* cov3Ds,
    float* rgb, float4* conic_opacity, const dim3 grid, uint32_t* tiles_touched,
    bool prefiltered, const int num_viewpoints)
{
    const unsigned int gauss = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int view = blockIdx.y;

    // Guard the grid tail in both dimensions.
    if (view >= num_viewpoints || gauss >= P)
        return;

    // Flat output slot for this (viewpoint, Gaussian) pair.
    // NOTE(review): 32-bit unsigned arithmetic; slot * 6 below wraps once
    // num_viewpoints * P * 6 exceeds 2^32 - confirm batch sizes stay below that.
    const unsigned int slot = view * P + gauss;

    // Each viewpoint owns a consecutive block of 16 floats per matrix.
    const float* viewmatrix = &viewmatrix_arr[view * 16];
    const float* projmatrix = &projmatrix_arr[view * 16];

    // Default to "invisible": zero radius and zero touched tiles.
    radii[slot] = 0;
    tiles_touched[slot] = 0;

    // Near culling; p_view is filled in by in_frustum.
    float3 p_view;
    const bool visible = in_frustum(gauss, orig_points, viewmatrix, projmatrix, prefiltered, p_view);
    if (!visible)
        return;

    // Project the Gaussian center: homogeneous transform, then perspective divide.
    const float* center = orig_points + 3 * gauss;
    float3 p_orig = { center[0], center[1], center[2] };

    float4 p_hom = transformPoint4x4(p_orig, projmatrix);
    // The small epsilon keeps the reciprocal finite when p_hom.w is zero.
    float p_w = 1.0f / (p_hom.w + 0.0000001f);
    float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };

    // 3D covariance: use the precomputed one if given, otherwise build it from
    // scale and rotation and store it in this slot of cov3Ds.
    // NOTE(review): cov3D_precomp is indexed per (viewpoint, Gaussian) slot,
    // whereas scales/rotations are per Gaussian - confirm the caller's layout.
    const float* cov3D = nullptr;
    if (cov3D_precomp == nullptr) {
        float* cov_out = cov3Ds + slot * 6;
        computeCov3D(scales[gauss], scale_modifier, rotations[gauss], cov_out);
        cov3D = cov_out;
    } else {
        cov3D = cov3D_precomp + slot * 6;
    }

    // 2D screen-space covariance, using this viewpoint's field of view.
    const float view_tan_x = tan_fovx[view];
    const float view_tan_y = tan_fovy[view];
    const float focal_x = W / (2.0f * view_tan_x);
    const float focal_y = H / (2.0f * view_tan_y);
    float3 cov = computeCov2D(p_orig, focal_x, focal_y, view_tan_x, view_tan_y, cov3D, viewmatrix);

    // Invert covariance (EWA algorithm); a singular matrix cannot be inverted.
    float det = (cov.x * cov.z - cov.y * cov.y);
    if (det == 0.0f)
        return;
    float det_inv = 1.f / det;
    float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };

    // Screen-space extent from the eigenvalues of the 2D covariance; the
    // radius is three standard deviations of the larger one, rounded up.
    float mid = 0.5f * (cov.x + cov.z);
    float spread = sqrt(max(0.1f, mid * mid - det));
    float lambda1 = mid + spread;
    float lambda2 = mid - spread;
    float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));

    // Bounding rectangle of tiles overlapped by the Gaussian; quit if empty.
    float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };
    uint2 rect_min, rect_max;
    getRect(point_image, my_radius, rect_min, rect_max, grid);
    const unsigned int covered_tiles = (rect_max.x - rect_min.x) * (rect_max.y - rect_min.y);
    if (covered_tiles == 0)
        return;

    // Without precomputed colors, evaluate spherical harmonics for this
    // viewpoint's camera position. SH data is read at the Gaussian index,
    // the clamped flags are written at the output slot.
    if (colors_precomp == nullptr) {
        glm::vec3 color = computeColorFromSH(gauss, slot, D, M, (glm::vec3*)orig_points, cam_pos[view], shs, clamped);
        float* rgb_out = rgb + slot * C;
        rgb_out[0] = color.x;
        rgb_out[1] = color.y;
        rgb_out[2] = color.z;
    }

    // Helper data for the following steps.
    depths[slot] = p_view.z;
    radii[slot] = my_radius;
    points_xy_image[slot] = point_image;

    // Inverse 2D covariance and opacity are packed into one float4.
    conic_opacity[slot] = { conic.x, conic.y, conic.z, opacities[gauss] };
    tiles_touched[slot] = covered_tiles;
}
596+
597+
// Host-side launcher for preprocessCUDABatched.
//
// Launches a 2D grid: the x dimension covers the P Gaussians in blocks of
// ONE_DIM_BLOCK_SIZE threads, the y dimension has one entry per viewpoint.
// All pointer arguments are forwarded unchanged to the kernel; `grid` is the
// tile grid handed to the kernel, not the launch grid.
//
// The launch is asynchronous and no error is queried here; callers are
// expected to do their own CUDA error checking.
//
// If there are no Gaussians or no viewpoints the function returns without
// launching, because a grid with a zero-sized dimension is an invalid launch
// configuration and would leave an error behind for the next CUDA error query.
void FORWARD::preprocess_batch(int P, int D, int M,
    const float* means3D,
    const glm::vec3* scales,
    const float scale_modifier,
    const glm::vec4* rotations,
    const float* opacities,
    const float* shs,
    bool* clamped,
    const float* cov3D_precomp,
    const float* colors_precomp,
    const float* viewmatrix,
    const float* projmatrix,
    const glm::vec3* cam_pos,
    const int W, int H,
    const float* tan_fovx, const float* tan_fovy,
    int* radii,
    float2* means2D,
    float* depths,
    float* cov3Ds,
    float* rgb,
    float4* conic_opacity,
    const dim3 grid,
    uint32_t* tiles_touched,
    bool prefiltered,
    const int num_viewpoints)
{
    // Nothing to preprocess: skip the (invalid) zero-sized launch.
    if (P <= 0 || num_viewpoints <= 0)
        return;

    // x: ceil(P / block size) blocks over the Gaussians, y: one row per viewpoint.
    const dim3 launch_grid(cdiv(P, ONE_DIM_BLOCK_SIZE), num_viewpoints);
    preprocessCUDABatched<NUM_CHANNELS><<<launch_grid, ONE_DIM_BLOCK_SIZE>>>(
        P, D, M,
        means3D,
        scales,
        scale_modifier,
        rotations,
        opacities,
        shs,
        clamped,
        cov3D_precomp,
        colors_precomp,
        viewmatrix,
        projmatrix,
        cam_pos,
        W, H,
        tan_fovx, tan_fovy,
        radii,
        means2D,
        depths,
        cov3Ds,
        rgb,
        conic_opacity,
        grid,
        tiles_touched,
        prefiltered,
        num_viewpoints
    );
}

cuda_rasterizer/forward.h

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,35 @@ namespace FORWARD
4545
float4* conic_opacity,
4646
const dim3 grid,
4747
uint32_t* tiles_touched,
48-
bool prefiltered);
48+
bool prefiltered
49+
);
50+
51+
void preprocess_batch(int P, int D, int M,
52+
const float* means3D,
53+
const glm::vec3* scales,
54+
const float scale_modifier,
55+
const glm::vec4* rotations,
56+
const float* opacities,
57+
const float* shs,
58+
bool* clamped,
59+
const float* cov3D_precomp,
60+
const float* colors_precomp,
61+
const float* viewmatrix,
62+
const float* projmatrix,
63+
const glm::vec3* cam_pos,
64+
const int W, int H,
65+
const float* tan_fovx, const float* tan_fovy,
66+
int* radii,
67+
float2* means2D,
68+
float* depths,
69+
float* cov3Ds,
70+
float* rgb,
71+
float4* conic_opacity,
72+
const dim3 grid,
73+
uint32_t* tiles_touched,
74+
bool prefiltered,
75+
const int num_viewpoints
76+
);
4977

5078
// Main rasterization method.
5179
void render(
@@ -61,7 +89,8 @@ namespace FORWARD
6189
uint32_t* n_contrib2loss,
6290
const int* compute_locally_1D_2D_map,
6391
const float* bg_color,
64-
float* out_color);
92+
float* out_color
93+
);
6594
}
6695

6796

cuda_rasterizer/rasterizer.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,31 @@ namespace CudaRasterizer
6565
bool debug,//raster_settings
6666
const pybind11::dict &args);
6767

68+
static int preprocessForwardBatches(
69+
float2* means2D,
70+
float* depths,
71+
int* radii,
72+
float* cov3D,
73+
float4* conic_opacity,
74+
float* rgb,
75+
bool* clamped,//the above are all per-Gaussian intermediate results.
76+
const int P, int D, int M,
77+
const int width, int height,
78+
const float* means3D,
79+
const float* scales,
80+
const float* rotations,
81+
const float* shs,
82+
const float* opacities,//3dgs parameters
83+
const float scale_modifier,
84+
const float* viewmatrix,
85+
const float* projmatrix,
86+
const float* cam_pos,
87+
const float* tan_fovx, const float* tan_fovy,
88+
const bool prefiltered,
89+
const int num_viewpoints,
90+
bool debug,//raster_settings
91+
const pybind11::dict &args);
92+
6893
static void preprocessBackward(
6994
const int* radii,
7095
const float* cov3D,

0 commit comments

Comments
 (0)