Skip to content

Commit 2eb32ea

Browse files
committed
Add backward to mean3D in depth back pass
1 parent 8fa430b commit 2eb32ea

File tree

7 files changed

+52
-19
lines changed

7 files changed

+52
-19
lines changed

cuda_rasterizer/backward.cu

+35-10
Original file line numberDiff line numberDiff line change
@@ -353,11 +353,13 @@ __global__ void preprocessCUDA(
353353
const glm::vec3* scales,
354354
const glm::vec4* rotations,
355355
const float scale_modifier,
356+
const float* view,
356357
const float* proj,
357358
const glm::vec3* campos,
358359
const float3* dL_dmean2D,
359360
glm::vec3* dL_dmeans,
360361
float* dL_dcolor,
362+
float* dL_ddepth,
361363
float* dL_dcov3D,
362364
float* dL_dsh,
363365
glm::vec3* dL_dscale,
@@ -386,6 +388,20 @@ __global__ void preprocessCUDA(
386388
// of cov2D and following SH conversion also affects it.
387389
dL_dmeans[idx] += dL_dmean;
388390

391+
// the w must be equal to 1 for view^T * [x,y,z,1]
392+
float3 m_view = transformPoint4x3(m, view);
393+
394+
// Compute loss gradient w.r.t. 3D means due to gradients of depth
395+
// from rendering procedure
396+
glm::vec3 dL_dmean2;
397+
float mul3 = view[2] * m.x + view[6] * m.y + view[10] * m.z + view[14];
398+
dL_dmean2.x = (view[2] - view[3] * mul3) * dL_ddepth[idx];
399+
dL_dmean2.y = (view[6] - view[7] * mul3) * dL_ddepth[idx];
400+
dL_dmean2.z = (view[10] - view[11] * mul3) * dL_ddepth[idx];
401+
402+
// That's the third part of the mean gradient.
403+
dL_dmeans[idx] += dL_dmean2;
404+
389405
// Compute gradient updates due to computing colors from SHs
390406
if (shs)
391407
computeColorFromSH(idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh);
@@ -410,11 +426,12 @@ renderCUDA(
410426
const float* __restrict__ final_Ts,
411427
const uint32_t* __restrict__ n_contrib,
412428
const float* __restrict__ dL_dpixels,
413-
const float* __restrict__ dL_depths,
429+
const float* __restrict__ dL_dpixel_depths,
414430
float3* __restrict__ dL_dmean2D,
415431
float4* __restrict__ dL_dconic2D,
416432
float* __restrict__ dL_dopacity,
417-
float* __restrict__ dL_dcolors)
433+
float* __restrict__ dL_dcolors,
434+
float* __restrict__ dL_ddepths)
418435
{
419436
// We rasterize again. Compute necessary block info.
420437
auto block = cg::this_thread_block();
@@ -451,12 +468,12 @@ renderCUDA(
451468

452469
float accum_rec[C] = { 0 };
453470
float dL_dpixel[C];
454-
float dL_depth;
471+
float dL_dpixel_depth;
455472
float accum_depth_rec = 0;
456473
if (inside){
457474
for (int i = 0; i < C; i++)
458475
dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
459-
dL_depth = dL_depths[pix_id];
476+
dL_dpixel_depth = dL_dpixel_depths[pix_id];
460477
}
461478

462479
float last_alpha = 0;
@@ -483,7 +500,7 @@ renderCUDA(
483500
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
484501
for (int i = 0; i < C; i++)
485502
collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
486-
collected_depths[block.thread_rank()] = depths[coll_id];
503+
collected_depths[block.thread_rank()] = depths[coll_id];
487504
}
488505
block.sync();
489506

@@ -511,6 +528,7 @@ renderCUDA(
511528

512529
T = T / (1.f - alpha);
513530
const float dchannel_dcolor = alpha * T;
531+
const float dpixel_depth_ddepth = alpha * T;
514532

515533
// Propagate gradients to per-Gaussian colors and keep
516534
// gradients w.r.t. alpha (blending factor for a Gaussian/pixel
@@ -534,7 +552,9 @@ renderCUDA(
534552
const float c_d = collected_depths[j];
535553
accum_depth_rec = last_alpha * last_depth + (1.f - last_alpha) * accum_depth_rec;
536554
last_depth = c_d;
537-
dL_dalpha += (c_d - accum_depth_rec) * dL_depth;
555+
dL_dalpha += (c_d - accum_depth_rec) * dL_dpixel_depth;
556+
atomicAdd(&(dL_ddepths[global_id]), dpixel_depth_ddepth * dL_dpixel_depth);
557+
538558
dL_dalpha *= T;
539559
// Update last alpha (to be used in the next iteration)
540560
last_alpha = alpha;
@@ -588,6 +608,7 @@ void BACKWARD::preprocess(
588608
const float* dL_dconic,
589609
glm::vec3* dL_dmean3D,
590610
float* dL_dcolor,
611+
float* dL_ddepth,
591612
float* dL_dcov3D,
592613
float* dL_dsh,
593614
glm::vec3* dL_dscale,
@@ -623,11 +644,13 @@ void BACKWARD::preprocess(
623644
(glm::vec3*)scales,
624645
(glm::vec4*)rotations,
625646
scale_modifier,
647+
viewmatrix,
626648
projmatrix,
627649
campos,
628650
(float3*)dL_dmean2D,
629651
(glm::vec3*)dL_dmean3D,
630652
dL_dcolor,
653+
dL_ddepth,
631654
dL_dcov3D,
632655
dL_dsh,
633656
dL_dscale,
@@ -647,11 +670,12 @@ void BACKWARD::render(
647670
const float* final_Ts,
648671
const uint32_t* n_contrib,
649672
const float* dL_dpixels,
650-
const float* dL_depths,
673+
const float* dL_dpixel_depths,
651674
float3* dL_dmean2D,
652675
float4* dL_dconic2D,
653676
float* dL_dopacity,
654-
float* dL_dcolors)
677+
float* dL_dcolors,
678+
float* dL_ddepths)
655679
{
656680
renderCUDA<NUM_CHANNELS> << <grid, block >> >(
657681
ranges,
@@ -665,10 +689,11 @@ void BACKWARD::render(
665689
final_Ts,
666690
n_contrib,
667691
dL_dpixels,
668-
dL_depths,
692+
dL_dpixel_depths,
669693
dL_dmean2D,
670694
dL_dconic2D,
671695
dL_dopacity,
672-
dL_dcolors
696+
dL_dcolors,
697+
dL_ddepths
673698
);
674699
}

cuda_rasterizer/backward.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,12 @@ namespace BACKWARD
3333
const float* final_Ts,
3434
const uint32_t* n_contrib,
3535
const float* dL_dpixels,
36-
const float* dL_depths,
36+
const float* dL_dpixel_depths,
3737
float3* dL_dmean2D,
3838
float4* dL_dconic2D,
3939
float* dL_dopacity,
40-
float* dL_dcolors);
40+
float* dL_dcolors,
41+
float* dL_ddepths);
4142

4243
void preprocess(
4344
int P, int D, int M,
@@ -58,6 +59,7 @@ namespace BACKWARD
5859
const float* dL_dconics,
5960
glm::vec3* dL_dmeans,
6061
float* dL_dcolor,
62+
float* dL_ddepth,
6163
float* dL_dcov3D,
6264
float* dL_dsh,
6365
glm::vec3* dL_dscale,

cuda_rasterizer/forward.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ namespace FORWARD
5555
int W, int H,
5656
const float2* points_xy_image,
5757
const float* features,
58-
const float* depths,
58+
const float* depth,
5959
const float4* conic_opacity,
6060
float* final_T,
6161
uint32_t* n_contrib,

cuda_rasterizer/rasterizer.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,12 @@ namespace CudaRasterizer
7373
char* binning_buffer,
7474
char* image_buffer,
7575
const float* dL_dpix,
76-
const float* dL_depths,
76+
const float* dL_dpix_depth,
7777
float* dL_dmean2D,
7878
float* dL_dconic,
7979
float* dL_dopacity,
8080
float* dL_dcolor,
81+
float* dL_ddepth,
8182
float* dL_dmean3D,
8283
float* dL_dcov3D,
8384
float* dL_dsh,

cuda_rasterizer/rasterizer_impl.cu

+6-3
Original file line numberDiff line numberDiff line change
@@ -360,11 +360,12 @@ void CudaRasterizer::Rasterizer::backward(
360360
char* binning_buffer,
361361
char* img_buffer,
362362
const float* dL_dpix,
363-
const float* dL_depths,
363+
const float* dL_dpix_depth,
364364
float* dL_dmean2D,
365365
float* dL_dconic,
366366
float* dL_dopacity,
367367
float* dL_dcolor,
368+
float* dL_ddepth,
368369
float* dL_dmean3D,
369370
float* dL_dcov3D,
370371
float* dL_dsh,
@@ -406,11 +407,12 @@ void CudaRasterizer::Rasterizer::backward(
406407
imgState.accum_alpha,
407408
imgState.n_contrib,
408409
dL_dpix,
409-
dL_depths,
410+
dL_dpix_depth,
410411
(float3*)dL_dmean2D,
411412
(float4*)dL_dconic,
412413
dL_dopacity,
413-
dL_dcolor), debug)
414+
dL_dcolor,
415+
dL_ddepth), debug)
414416

415417
// Take care of the rest of preprocessing. Was the precomputed covariance
416418
// given to us or a scales/rot pair? If precomputed, pass that. If not,
@@ -434,6 +436,7 @@ void CudaRasterizer::Rasterizer::backward(
434436
dL_dconic,
435437
(glm::vec3*)dL_dmean3D,
436438
dL_dcolor,
439+
dL_ddepth,
437440
dL_dcov3D,
438441
dL_dsh,
439442
(glm::vec3*)dL_dscale,

rasterize_points.cu

+3-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
131131
const float tan_fovx,
132132
const float tan_fovy,
133133
const torch::Tensor& dL_dout_color,
134-
const torch::Tensor& dL_dout_depth,
134+
const torch::Tensor& dL_dout_depth,
135135
const torch::Tensor& sh,
136136
const int degree,
137137
const torch::Tensor& campos,
@@ -154,6 +154,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
154154
torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options());
155155
torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options());
156156
torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options());
157+
torch::Tensor dL_ddepths = torch::zeros({P, 1}, means3D.options());
157158
torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options());
158159
torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options());
159160
torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options());
@@ -188,6 +189,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
188189
dL_dconic.contiguous().data<float>(),
189190
dL_dopacity.contiguous().data<float>(),
190191
dL_dcolors.contiguous().data<float>(),
192+
dL_ddepths.contiguous().data<float>(),
191193
dL_dmeans3D.contiguous().data<float>(),
192194
dL_dcov3D.contiguous().data<float>(),
193195
dL_dsh.contiguous().data<float>(),

rasterize_points.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
5252
const float tan_fovx,
5353
const float tan_fovy,
5454
const torch::Tensor& dL_dout_color,
55-
const torch::Tensor& dL_dout_depth,
55+
const torch::Tensor& dL_dout_depth,
5656
const torch::Tensor& sh,
5757
const int degree,
5858
const torch::Tensor& campos,

0 commit comments

Comments
 (0)