From b7f08958735551be86c0eb3e3df3238de22dce15 Mon Sep 17 00:00:00 2001 From: Julius Hansjakob Date: Sat, 19 Mar 2022 15:55:00 +0100 Subject: [PATCH 1/2] fixed numerical instability in norm_weighted_sum --- pytorch3d/csrc/compositing/norm_weighted_sum.cu | 9 ++++++--- pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum.cu b/pytorch3d/csrc/compositing/norm_weighted_sum.cu index 984647172..e9efa7c8b 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum.cu +++ b/pytorch3d/csrc/compositing/norm_weighted_sum.cu @@ -113,6 +113,7 @@ __global__ void weightedSumNormCudaBackwardKernel( float sum_alpha = 0.; float sum_alphafs = 0.; + float sum_alpha_den = sum_alpha; // Iterate through the closest K points for this pixel to calculate the // cumulative sum of the alphas for this pixel for (int k = 0; k < points_idx.size(1); ++k) { @@ -126,8 +127,10 @@ __global__ void weightedSumNormCudaBackwardKernel( sum_alphafs += alphas[batch][k][j][i] * features[ch][n_idx]; } - if (sum_alpha < kEpsilon) { - sum_alpha = kEpsilon; + if (sum_alpha > kEpsilon) { + sum_alpha_den = kEpsilon; + } else { + sum_alpha_den = kEpsilon; } // Iterate again through the closest K points for this pixel to calculate @@ -147,7 +150,7 @@ __global__ void weightedSumNormCudaBackwardKernel( atomicAdd( &grad_alphas[batch][k][j][i], (features[ch][n_idx] * sum_alpha - sum_alphafs) / - (sum_alpha * sum_alpha) * grad_outputs[batch][ch][j][i]); + (sum_alpha_den * sum_alpha_den) * grad_outputs[batch][ch][j][i]); atomicAdd( &grad_features[ch][n_idx], alpha * grad_outputs[batch][ch][j][i] / sum_alpha); diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp b/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp index 840ef3d24..cba08d6c8 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp +++ b/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp @@ -101,6 +101,7 @@ std::tuple weightedSumNormCpuBackward( for (int i = 0; i < W; ++i) { float t_alpha = 0.; float t_alphafs = 0.; + float t_alpha_den = t_alpha; // Iterate through the closest K points for this pixel for (int k = 0; k < K; ++k) { int64_t n_idx = points_idx_a[b][k][j][i]; @@ -113,8 +114,10 @@ std::tuple weightedSumNormCpuBackward( t_alphafs += alphas_a[b][k][j][i] * features_a[c][n_idx]; } - if (t_alpha < kEps) { - t_alpha = kEps; + if (t_alpha > kEps) { + t_alpha_den = t_alpha; + } else { + t_alpha_den = kEps; } // Iterate through the closest K points for this pixel ordered by z @@ -128,7 +131,7 @@ std::tuple weightedSumNormCpuBackward( float alpha = alphas_a[b][k][j][i]; grad_alphas_a[b][k][j][i] += grad_outputs_a[b][c][j][i] * (features_a[c][n_idx] * t_alpha - t_alphafs) / - (t_alpha * t_alpha); + (t_alpha_den * t_alpha_den); grad_features_a[c][n_idx] += grad_outputs_a[b][c][j][i] * alpha / t_alpha; } From 3d46ee0841ea9e773020b7c202af12c5854e4d16 Mon Sep 17 00:00:00 2001 From: Julius Hansjakob Date: Sat, 19 Mar 2022 15:59:43 +0100 Subject: [PATCH 2/2] fixed numerical instability in norm_weighted_sum --- pytorch3d/csrc/compositing/norm_weighted_sum.cu | 6 +----- pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum.cu b/pytorch3d/csrc/compositing/norm_weighted_sum.cu index e9efa7c8b..6bf9e2fb5 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum.cu +++ b/pytorch3d/csrc/compositing/norm_weighted_sum.cu @@ -127,11 +127,7 @@ __global__ void weightedSumNormCudaBackwardKernel( sum_alphafs += alphas[batch][k][j][i] * features[ch][n_idx]; } - if (sum_alpha > kEpsilon) { - sum_alpha_den = kEpsilon; - } else { - sum_alpha_den = kEpsilon; - } + sum_alpha_den = max(sum_alpha, kEpsilon); // Iterate again through the closest K points for this pixel to calculate // the gradient. diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp b/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp index cba08d6c8..8801f5598 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp +++ b/pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp @@ -114,11 +114,7 @@ std::tuple weightedSumNormCpuBackward( t_alphafs += alphas_a[b][k][j][i] * features_a[c][n_idx]; } - if (t_alpha > kEps) { - t_alpha_den = t_alpha; - } else { - t_alpha_den = kEps; - } + t_alpha_den = std::max(t_alpha, kEps); // Iterate through the closest K points for this pixel ordered by z // distance.