@@ -18,16 +18,16 @@ namespace cg = cooperative_groups;
1818
1919// Forward method for converting the input spherical harmonics
2020// coefficients of each Gaussian to a simple RGB color.
21- __device__ glm::vec3 computeColorFromSH (int idx , int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float * shs, bool * clamped)
21+ __device__ glm::vec3 computeColorFromSH (int point_idx, int result_idx , int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float * shs, bool * clamped)
2222{
2323 // The implementation is loosely based on code for
2424 // "Differentiable Point-Based Radiance Fields for
2525 // Efficient View Synthesis" by Zhang et al. (2022)
26- glm::vec3 pos = means[idx ];
26+ glm::vec3 pos = means[point_idx ];
2727 glm::vec3 dir = pos - campos;
2828 dir = dir / glm::length (dir);
2929
30- glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
30+ glm::vec3* sh = ((glm::vec3*)shs) + point_idx * max_coeffs;
3131 glm::vec3 result = SH_C0 * sh[0 ];
3232
3333 if (deg > 0 )
@@ -65,9 +65,9 @@ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const
6565
6666 // RGB colors are clamped to positive values. If values are
6767 // clamped, we need to keep track of this for the backward pass.
68- clamped[3 * idx + 0 ] = (result.x < 0 );
69- clamped[3 * idx + 1 ] = (result.y < 0 );
70- clamped[3 * idx + 2 ] = (result.z < 0 );
68+ clamped[3 * result_idx + 0 ] = (result.x < 0 );
69+ clamped[3 * result_idx + 1 ] = (result.y < 0 );
70+ clamped[3 * result_idx + 2 ] = (result.z < 0 );
7171 return glm::max (result, 0 .0f );
7272}
7373
@@ -213,7 +213,6 @@ __global__ void preprocessCUDA(int P, int D, int M,
213213 computeCov3D (scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6 );
214214 cov3D = cov3Ds + idx * 6 ;
215215 }
216-
217216 // Compute 2D screen-space covariance matrix
218217 float3 cov = computeCov2D (p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);
219218
@@ -242,7 +241,7 @@ __global__ void preprocessCUDA(int P, int D, int M,
242241 // spherical harmonics coefficients to RGB color.
243242 if (colors_precomp == nullptr )
244243 {
245- glm::vec3 result = computeColorFromSH (idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
244+ glm::vec3 result = computeColorFromSH (idx, idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
246245 rgb[idx * C + 0 ] = result.x ;
247246 rgb[idx * C + 1 ] = result.y ;
248247 rgb[idx * C + 2 ] = result.z ;
@@ -471,7 +470,7 @@ void FORWARD::preprocess(int P, int D, int M,
471470 uint32_t * tiles_touched,
472471 bool prefiltered)
473472{
474- preprocessCUDA<NUM_CHANNELS> << <(P + ONE_DIM_BLOCK_SIZE - 1 ) / ONE_DIM_BLOCK_SIZE , ONE_DIM_BLOCK_SIZE >> > (
473+ preprocessCUDA<NUM_CHANNELS> << <cdiv (P, ONE_DIM_BLOCK_SIZE) , ONE_DIM_BLOCK_SIZE >> > (
475474 P, D, M,
476475 means3D,
477476 scales,
@@ -498,4 +497,155 @@ void FORWARD::preprocess(int P, int D, int M,
498497 tiles_touched,
499498 prefiltered
500499 );
501- }
500+ }
501+
502+
// Batched per-Gaussian preprocessing: each thread handles one
// (Gaussian, viewpoint) pair and fills the per-pair helper buffers that
// the later stages consume.
//
// Launch layout: blockIdx.x * blockDim.x + threadIdx.x selects the Gaussian
// (point_idx); blockIdx.y selects the viewpoint (viewpoint_idx). Threads
// outside [0, P) x [0, num_viewpoints) exit immediately.
//
// Indexing conventions used below:
//  - orig_points, scales, rotations, opacities, shs are read per Gaussian
//    (point_idx).
//  - viewmatrix_arr / projmatrix_arr are read with a stride of 16 floats per
//    viewpoint; cam_pos, tan_fovx, tan_fovy are read per viewpoint.
//  - radii, tiles_touched, depths, points_xy_image, conic_opacity, cov3Ds
//    (6 floats each), rgb (C floats each) and clamped are written at the
//    flattened index viewpoint_idx * P + point_idx.
//
// The flattened index is kept in size_t so that idx * 6 and idx * C cannot
// wrap around in 32-bit arithmetic when num_viewpoints * P becomes large.
template <int C>
__global__ void preprocessCUDABatched(
	int P, int D, int M,
	const float* orig_points, const glm::vec3* scales, const float scale_modifier,
	const glm::vec4* rotations, const float* opacities, const float* shs,
	bool* clamped, const float* cov3D_precomp, const float* colors_precomp,
	const float* viewmatrix_arr, const float* projmatrix_arr, const glm::vec3* cam_pos,
	const int W, int H, const float* tan_fovx, const float* tan_fovy,
	int* radii, float2* points_xy_image, float* depths, float* cov3Ds,
	float* rgb, float4* conic_opacity, const dim3 grid, uint32_t* tiles_touched,
	bool prefiltered, const int num_viewpoints)
{
	const unsigned int point_idx = blockIdx.x * blockDim.x + threadIdx.x;
	const unsigned int viewpoint_idx = blockIdx.y;

	if (viewpoint_idx >= num_viewpoints || point_idx >= P)
		return;

	// Flattened (viewpoint, Gaussian) slot, widened to 64 bits before the
	// multiplication so large batches do not overflow.
	const size_t idx = static_cast<size_t>(viewpoint_idx) * P + point_idx;

	// Each viewpoint owns 16 consecutive floats in both matrix arrays.
	const float* viewmatrix = viewmatrix_arr + viewpoint_idx * 16;
	const float* projmatrix = projmatrix_arr + viewpoint_idx * 16;

	// Initialize radius and touched tiles to 0. If this isn't changed,
	// this Gaussian will not be processed further for this viewpoint.
	radii[idx] = 0;
	tiles_touched[idx] = 0;

	// Quit early if in_frustum() rejects the point for this viewpoint;
	// on success it also provides the view-space position.
	float3 p_view;
	if (!in_frustum(point_idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view))
		return;

	// Transform point by projecting (with a small epsilon guarding the
	// perspective divide).
	const float3 p_orig = {
		orig_points[3 * point_idx],
		orig_points[3 * point_idx + 1],
		orig_points[3 * point_idx + 2] };
	const float4 p_hom = transformPoint4x4(p_orig, projmatrix);
	const float p_w = 1.0f / (p_hom.w + 0.0000001f);
	const float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };

	// If 3D covariance matrix is precomputed, use it, otherwise compute
	// from scaling and rotation parameters into this pair's cov3Ds slot.
	const float* cov3D;
	if (cov3D_precomp != nullptr)
	{
		// NOTE(review): this reads the precomputed covariance at the flattened
		// (viewpoint, Gaussian) index, whereas scales/rotations are read per
		// Gaussian. That is only in bounds if the caller supplies
		// num_viewpoints * P entries -- confirm against the caller, or switch
		// to point_idx * 6 if the buffer holds P entries.
		cov3D = cov3D_precomp + idx * 6;
	}
	else
	{
		float* cov3D_out = cov3Ds + idx * 6;
		computeCov3D(scales[point_idx], scale_modifier, rotations[point_idx], cov3D_out);
		cov3D = cov3D_out;
	}

	// Compute 2D screen-space covariance matrix using this viewpoint's
	// field-of-view tangents.
	const float tan_x = tan_fovx[viewpoint_idx];
	const float tan_y = tan_fovy[viewpoint_idx];
	const float focal_x = W / (2.0f * tan_x);
	const float focal_y = H / (2.0f * tan_y);
	const float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_x, tan_y, cov3D, viewmatrix);

	// Invert covariance (EWA algorithm). A singular matrix cannot be
	// inverted, so the pair keeps radius 0 / tiles_touched 0.
	const float det = (cov.x * cov.z - cov.y * cov.y);
	if (det == 0.0f)
		return;
	const float det_inv = 1.f / det;
	const float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };

	// Compute extent in screen space (by finding eigenvalues of
	// 2D covariance matrix). Use extent to compute a bounding rectangle
	// of screen-space tiles that this Gaussian overlaps with. Quit if
	// rectangle covers 0 tiles.
	const float mid = 0.5f * (cov.x + cov.z);
	const float half_gap = sqrt(max(0.1f, mid * mid - det));
	const float lambda1 = mid + half_gap;
	const float lambda2 = mid - half_gap;
	const float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
	const float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };
	uint2 rect_min, rect_max;
	getRect(point_image, my_radius, rect_min, rect_max, grid);
	if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
		return;

	// If colors have been precomputed, nothing is written to rgb here;
	// otherwise convert spherical harmonics coefficients to RGB color.
	// The SH data is read per Gaussian, the result (and its clamping flags)
	// is stored per (viewpoint, Gaussian) pair. computeColorFromSH takes the
	// result index as an int, so it is narrowed explicitly at the call.
	if (colors_precomp == nullptr)
	{
		const glm::vec3 result = computeColorFromSH(point_idx, static_cast<int>(idx), D, M,
			(glm::vec3*)orig_points, cam_pos[viewpoint_idx], shs, clamped);
		rgb[idx * C + 0] = result.x;
		rgb[idx * C + 1] = result.y;
		rgb[idx * C + 2] = result.z;
	}

	// Store some useful helper data for the next steps.
	depths[idx] = p_view.z;
	radii[idx] = my_radius;
	points_xy_image[idx] = point_image;

	// Inverse 2D covariance and opacity neatly pack into one float4.
	conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[point_idx] };
	tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
}
596+
// Host-side launcher for preprocessCUDABatched.
//
// Launches a 2D grid: the x dimension holds cdiv(P, ONE_DIM_BLOCK_SIZE)
// blocks of ONE_DIM_BLOCK_SIZE threads, and the y dimension holds one row of
// blocks per viewpoint. All arguments are forwarded to the kernel unchanged
// (means3D is passed as the kernel's orig_points, means2D as its
// points_xy_image). The launch uses the default stream and performs no error
// checking or synchronization of its own.
void FORWARD::preprocess_batch(int P, int D, int M,
	const float* means3D,
	const glm::vec3* scales,
	const float scale_modifier,
	const glm::vec4* rotations,
	const float* opacities,
	const float* shs,
	bool* clamped,
	const float* cov3D_precomp,
	const float* colors_precomp,
	const float* viewmatrix,
	const float* projmatrix,
	const glm::vec3* cam_pos,
	const int W, int H,
	const float* tan_fovx, const float* tan_fovy,
	int* radii,
	float2* means2D,
	float* depths,
	float* cov3Ds,
	float* rgb,
	float4* conic_opacity,
	const dim3 grid,
	uint32_t* tiles_touched,
	bool prefiltered,
	const int num_viewpoints)
{
	// x: enough blocks to give every Gaussian a thread; y: one row per view.
	const dim3 threads_per_block(ONE_DIM_BLOCK_SIZE);
	const dim3 blocks_per_grid(cdiv(P, ONE_DIM_BLOCK_SIZE), num_viewpoints);

	preprocessCUDABatched<NUM_CHANNELS><<<blocks_per_grid, threads_per_block>>>(
		// Problem sizes
		P, D, M,
		// Per-Gaussian inputs
		means3D, scales, scale_modifier, rotations, opacities, shs,
		clamped, cov3D_precomp, colors_precomp,
		// Per-viewpoint camera data
		viewmatrix, projmatrix, cam_pos,
		W, H,
		tan_fovx, tan_fovy,
		// Outputs
		radii, means2D, depths, cov3Ds, rgb, conic_opacity,
		// Tile grid and bookkeeping
		grid, tiles_touched, prefiltered, num_viewpoints);
}
0 commit comments