Skip to content

Commit d24c2b7

Browse files
support alpha; relax pytorch requirement (#94)
* support alpha for both marching and rendering; relax pytorch requirement * bump version
1 parent e9aa8d3 commit d24c2b7

File tree

5 files changed

+80
-30
lines changed

5 files changed

+80
-30
lines changed

examples/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices):
9999
alpha_thre=alpha_thre,
100100
)
101101
rgb, opacity, depth = rendering(
102-
rgb_sigma_fn,
103102
packed_info,
104103
t_starts,
105104
t_ends,
105+
rgb_sigma_fn=rgb_sigma_fn,
106106
render_bkgd=render_bkgd,
107107
)
108108
chunk_results = [rgb, opacity, depth, len(t_starts)]

nerfacc/ray_marching.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ def ray_marching(
2222
scene_aabb: Optional[torch.Tensor] = None,
2323
# binarized grid for skipping empty space
2424
grid: Optional[Grid] = None,
25-
# sigma function for skipping invisible space
25+
# sigma/alpha function for skipping invisible space
2626
sigma_fn: Optional[Callable] = None,
27+
alpha_fn: Optional[Callable] = None,
2728
early_stop_eps: float = 1e-4,
2829
alpha_thre: float = 0.0,
2930
# rendering options
@@ -61,6 +62,12 @@ def ray_marching(
6162
by evaluating the density along the ray with `sigma_fn`. It should be a
6263
function that takes in samples {t_starts (N, 1), t_ends (N, 1),
6364
ray indices (N,)} and returns the post-activation density values (N, 1).
65+
You should only provide either `sigma_fn` or `alpha_fn`.
66+
alpha_fn: Optional. If provided, the marching will skip the invisible space
67+
by evaluating the density along the ray with `alpha_fn`. It should be a
68+
function that takes in samples {t_starts (N, 1), t_ends (N, 1),
69+
ray indices (N,)} and returns the post-activation opacity values (N, 1).
70+
You should only provide either `sigma_fn` or `alpha_fn`.
6471
early_stop_eps: Early stop threshold for skipping invisible space. Default: 1e-4.
6572
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
6673
near_plane: Optional. Near plane distance. If provided, it will be used
@@ -128,6 +135,10 @@ def ray_marching(
128135
"""
129136
if not rays_o.is_cuda:
130137
raise NotImplementedError("Only support cuda inputs.")
138+
if alpha_fn is not None and sigma_fn is not None:
139+
raise ValueError(
140+
"Only one of `alpha_fn` and `sigma_fn` should be provided."
141+
)
131142

132143
# logic for t_min and t_max:
133144
# 1. if t_min and t_max are given, use them with highest priority.
@@ -184,14 +195,20 @@ def ray_marching(
184195
)
185196

186197
# skip invisible space
187-
if sigma_fn is not None:
198+
if sigma_fn is not None or alpha_fn is not None:
188199
# Query sigma without gradients
189200
ray_indices = unpack_info(packed_info)
190-
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
191-
assert (
192-
sigmas.shape == t_starts.shape
193-
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
194-
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
201+
if sigma_fn is not None:
202+
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
203+
assert (
204+
sigmas.shape == t_starts.shape
205+
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
206+
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
207+
elif alpha_fn is not None:
208+
alphas = alpha_fn(t_starts, t_ends, ray_indices.long())
209+
assert (
210+
alphas.shape == t_starts.shape
211+
), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
195212

196213
# Compute visibility of the samples, and filter out invisible samples
197214
visibility, packed_info_visible = render_visibility(

nerfacc/vol_rendering.py

+50-19
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313

1414

1515
def rendering(
16-
# radiance field
17-
rgb_sigma_fn: Callable,
1816
# ray marching results
1917
packed_info: torch.Tensor,
2018
t_starts: torch.Tensor,
2119
t_ends: torch.Tensor,
20+
# radiance field
21+
rgb_sigma_fn: Optional[Callable] = None,
22+
rgb_alpha_fn: Optional[Callable] = None,
2223
# rendering options
2324
early_stop_eps: float = 1e-4,
2425
alpha_thre: float = 0.0,
@@ -33,12 +34,17 @@ def rendering(
3334
This function is not differentiable to `t_starts`, `t_ends`.
3435
3536
Args:
36-
rgb_sigma_fn: A function that takes in samples {t_starts (N, 1), t_ends (N, 1), \
37-
ray indices (N,)} and returns the post-activation rgb (N, 3) and density \
38-
values (N, 1).
3937
packed_info: Packed ray marching info. See :func:`ray_marching` for details.
4038
t_starts: Per-sample start distance. Tensor with shape (n_samples, 1).
4139
t_ends: Per-sample end distance. Tensor with shape (n_samples, 1).
40+
rgb_sigma_fn: A function that takes in samples {t_starts (N, 1), t_ends (N, 1), \
41+
ray indices (N,)} and returns the post-activation rgb (N, 3) and density \
42+
values (N, 1). At least one of `rgb_sigma_fn` and `rgb_alpha_fn` should be \
43+
specified.
44+
rgb_alpha_fn: A function that takes in samples {t_starts (N, 1), t_ends (N, 1), \
45+
ray indices (N,)} and returns the post-activation rgb (N, 3) and opacity \
46+
values (N, 1). At least one of `rgb_sigma_fn` and `rgb_alpha_fn` should be \
47+
specified.
4248
early_stop_eps: Early stop threshold during trasmittance accumulation. Default: 1e-4.
4349
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
4450
render_bkgd: Optional. Background color. Tensor with shape (3,).
@@ -76,22 +82,47 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices):
7682
print(colors.shape, opacities.shape, depths.shape)
7783
7884
"""
85+
if callable(packed_info):
86+
raise RuntimeError(
87+
"You maybe want to use the nerfacc<=0.2.1 version. For nerfacc>0.2.1, "
88+
"The first argument of `rendering` should be the packed ray packed info. "
89+
"See the latest documentation for details: "
90+
"https://www.nerfacc.com/en/latest/apis/rendering.html#nerfacc.rendering"
91+
)
92+
93+
if rgb_sigma_fn is None and rgb_alpha_fn is None:
94+
raise ValueError(
95+
"At least one of `rgb_sigma_fn` and `rgb_alpha_fn` should be specified."
96+
)
97+
7998
n_rays = packed_info.shape[0]
8099
ray_indices = unpack_info(packed_info)
81100

82-
# Query sigma and color with gradients
83-
rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices.long())
84-
assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
85-
rgbs.shape
86-
)
87-
assert (
88-
sigmas.shape == t_starts.shape
89-
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
90-
91-
# Rendering: compute weights and ray indices.
92-
weights = render_weight_from_density(
93-
packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre
94-
)
101+
# Query sigma/alpha and color with gradients
102+
if rgb_sigma_fn is not None:
103+
rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices.long())
104+
assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
105+
rgbs.shape
106+
)
107+
assert (
108+
sigmas.shape == t_starts.shape
109+
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
110+
# Rendering: compute weights and ray indices.
111+
weights = render_weight_from_density(
112+
packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre
113+
)
114+
elif rgb_alpha_fn is not None:
115+
rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices.long())
116+
assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
117+
rgbs.shape
118+
)
119+
assert (
120+
alphas.shape == t_starts.shape
121+
), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
122+
# Rendering: compute weights and ray indices.
123+
weights = render_weight_from_alpha(
124+
packed_info, alphas, early_stop_eps, alpha_thre
125+
)
95126

96127
# Rendering: accumulate rgbs, opacities, and depths along the rays.
97128
colors = accumulate_along_rays(
@@ -244,7 +275,7 @@ def render_weight_from_alpha(
244275
early_stop_eps: float = 1e-4,
245276
alpha_thre: float = 0.0,
246277
) -> torch.Tensor:
247-
"""Compute transmittance weights from density.
278+
"""Compute transmittance weights from opacity.
248279
249280
Args:
250281
packed_info: Stores information on which samples belong to the same ray. \

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "nerfacc"
7-
version = "0.2.1"
7+
version = "0.2.2"
88
description = "A General NeRF Acceleration Toolbox."
99
readme = "README.md"
1010
authors = [{name = "Ruilong", email = "[email protected]"}]
@@ -14,7 +14,7 @@ dependencies = [
1414
"importlib_metadata>=5.0.0; python_version<'3.8'",
1515
"ninja>=1.10.2.3",
1616
"pybind11>=2.10.0",
17-
"torch>=1.12.0",
17+
"torch", # tested with 1.12.0
1818
"rich>=12"
1919
]
2020

tests/test_rendering.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices):
120120
t_starts = torch.rand_like(sigmas)
121121
t_ends = torch.rand_like(sigmas) + 1.0
122122

123-
_, _, _ = rendering(rgb_sigma_fn, packed_info, t_starts, t_ends)
123+
_, _, _ = rendering(
124+
packed_info, t_starts, t_ends, rgb_sigma_fn=rgb_sigma_fn
125+
)
124126

125127

126128
if __name__ == "__main__":

0 commit comments

Comments
 (0)