diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum.cu b/pytorch3d/csrc/compositing/norm_weighted_sum.cu index 984647172..6bf9e2fb5 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum.cu +++ b/pytorch3d/csrc/compositing/norm_weighted_sum.cu @@ -113,6 +113,7 @@ __global__ void weightedSumNormCudaBackwardKernel( float sum_alpha = 0.; float sum_alphafs = 0.; + float sum_alpha_den = sum_alpha; // Iterate through the closest K points for this pixel to calculate the // cumulative sum of the alphas for this pixel for (int k = 0; k < points_idx.size(1); ++k) { @@ -126,9 +127,7 @@ __global__ void weightedSumNormCudaBackwardKernel( sum_alphafs += alphas[batch][k][j][i] * features[ch][n_idx]; } - if (sum_alpha < kEpsilon) { - sum_alpha = kEpsilon; - } + sum_alpha_den = max(sum_alpha, kEpsilon); // Iterate again through the closest K points for this pixel to calculate // the gradient. @@ -147,7 +146,7 @@ __global__ void weightedSumNormCudaBackwardKernel( atomicAdd( &grad_alphas[batch][k][j][i], (features[ch][n_idx] * sum_alpha - sum_alphafs) / - (sum_alpha * sum_alpha) * grad_outputs[batch][ch][j][i]); + (sum_alpha_den * sum_alpha_den) * grad_outputs[batch][ch][j][i]); atomicAdd( &grad_features[ch][n_idx], alpha * grad_outputs[batch][ch][j][i] / sum_alpha); diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp b/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp index 840ef3d24..8801f5598 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp +++ b/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp @@ -101,6 +101,7 @@ std::tuple weightedSumNormCpuBackward( for (int i = 0; i < W; ++i) { float t_alpha = 0.; float t_alphafs = 0.; + float t_alpha_den = t_alpha; // Iterate through the closest K points for this pixel for (int k = 0; k < K; ++k) { int64_t n_idx = points_idx_a[b][k][j][i]; @@ -113,9 +114,7 @@ std::tuple weightedSumNormCpuBackward( t_alphafs += alphas_a[b][k][j][i] * features_a[c][n_idx]; } - if (t_alpha < kEps) { - t_alpha = kEps; - } + t_alpha_den = std::max(t_alpha, kEps); // Iterate through the closest K points for this pixel ordered by z // distance. @@ -128,7 +127,7 @@ std::tuple weightedSumNormCpuBackward( float alpha = alphas_a[b][k][j][i]; grad_alphas_a[b][k][j][i] += grad_outputs_a[b][c][j][i] * (features_a[c][n_idx] * t_alpha - t_alphafs) / - (t_alpha * t_alpha); + (t_alpha_den * t_alpha_den); grad_features_a[c][n_idx] += grad_outputs_a[b][c][j][i] * alpha / t_alpha; }