Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

numerical instability in compositing.norm_weighted_sum #1135

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions pytorch3d/csrc/compositing/norm_weighted_sum.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ __global__ void weightedSumNormCudaBackwardKernel(

float sum_alpha = 0.;
float sum_alphafs = 0.;
float sum_alpha_den = sum_alpha;
// Iterate through the closest K points for this pixel to calculate the
// cumulative sum of the alphas for this pixel
for (int k = 0; k < points_idx.size(1); ++k) {
Expand All @@ -126,9 +127,7 @@ __global__ void weightedSumNormCudaBackwardKernel(
sum_alphafs += alphas[batch][k][j][i] * features[ch][n_idx];
}

if (sum_alpha < kEpsilon) {
sum_alpha = kEpsilon;
}
sum_alpha_den = max(sum_alpha, kEpsilon);

// Iterate again through the closest K points for this pixel to calculate
// the gradient.
Expand All @@ -147,7 +146,7 @@ __global__ void weightedSumNormCudaBackwardKernel(
atomicAdd(
&grad_alphas[batch][k][j][i],
(features[ch][n_idx] * sum_alpha - sum_alphafs) /
(sum_alpha * sum_alpha) * grad_outputs[batch][ch][j][i]);
(sum_alpha_den * sum_alpha_den) * grad_outputs[batch][ch][j][i]);
atomicAdd(
&grad_features[ch][n_idx],
alpha * grad_outputs[batch][ch][j][i] / sum_alpha);
Expand Down
7 changes: 3 additions & 4 deletions pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCpuBackward(
for (int i = 0; i < W; ++i) {
float t_alpha = 0.;
float t_alphafs = 0.;
float t_alpha_den = t_alpha;
// Iterate through the closest K points for this pixel
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
Expand All @@ -113,9 +114,7 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCpuBackward(
t_alphafs += alphas_a[b][k][j][i] * features_a[c][n_idx];
}

if (t_alpha < kEps) {
t_alpha = kEps;
}
t_alpha_den = std::max(t_alpha, kEps);

// Iterate through the closest K points for this pixel ordered by z
// distance.
Expand All @@ -128,7 +127,7 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCpuBackward(
float alpha = alphas_a[b][k][j][i];
grad_alphas_a[b][k][j][i] += grad_outputs_a[b][c][j][i] *
(features_a[c][n_idx] * t_alpha - t_alphafs) /
(t_alpha * t_alpha);
(t_alpha_den * t_alpha_den);
grad_features_a[c][n_idx] +=
grad_outputs_a[b][c][j][i] * alpha / t_alpha;
}
Expand Down