Skip to content

Commit f844a97

Browse files
committed
Add Interface for depth compuation in function parameters
1 parent fc0cfe9 commit f844a97

File tree

4 files changed

+11
-6
lines changed

4 files changed

+11
-6
lines changed

cuda_rasterizer/forward.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,4 +452,4 @@ void FORWARD::preprocess(int P, int D, int M,
452452
tiles_touched,
453453
prefiltered
454454
);
455-
}
455+
}

cuda_rasterizer/rasterizer_impl.cu

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chun
195195

196196
// Forward rendering procedure for differentiable rasterization
197197
// of Gaussians.
198-
int CudaRasterizer::Rasterizer::forward(
198+
int CudaRasterizer::Rasterizer::forward( // TODO
199199
std::function<char* (size_t)> geometryBuffer,
200200
std::function<char* (size_t)> binningBuffer,
201201
std::function<char* (size_t)> imageBuffer,
@@ -216,6 +216,7 @@ int CudaRasterizer::Rasterizer::forward(
216216
const float tan_fovx, float tan_fovy,
217217
const bool prefiltered,
218218
float* out_color,
219+
float* out_depth,
219220
int* radii)
220221
{
221222
const float focal_y = height / (2.0f * tan_fovy);
@@ -326,11 +327,13 @@ int CudaRasterizer::Rasterizer::forward(
326327
width, height,
327328
geomState.means2D,
328329
feature_ptr,
330+
geomState.depths,
329331
geomState.conic_opacity,
330332
imgState.accum_alpha,
331333
imgState.n_contrib,
332334
background,
333-
out_color);
335+
out_color,
336+
out_depth);
334337

335338
return num_rendered;
336339
}

diff_gaussian_rasterization/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,13 @@ def forward(
7575
)
7676

7777
# Invoke C++/CUDA rasterizer
78-
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
78+
num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
7979

8080
# Keep relevant tensors for backward
8181
ctx.raster_settings = raster_settings
8282
ctx.num_rendered = num_rendered
8383
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
84-
return color, radii
84+
return color, radii, depth
8585

8686
@staticmethod
8787
def backward(ctx, grad_out_color, _):

rasterize_points.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ RasterizeGaussiansCUDA(
6565
auto float_opts = means3D.options().dtype(torch::kFloat32);
6666

6767
torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
68+
torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts);
6869
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
6970

7071
torch::Device device(torch::kCUDA);
@@ -107,9 +108,10 @@ RasterizeGaussiansCUDA(
107108
tan_fovy,
108109
prefiltered,
109110
out_color.contiguous().data<float>(),
111+
out_depth.contiguous().data<float>(),
110112
radii.contiguous().data<int>());
111113
}
112-
return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
114+
return std::make_tuple(rendered, out_color, out_depth, radii, geomBuffer, binningBuffer, imgBuffer);
113115
}
114116

115117
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>

0 commit comments

Comments
 (0)