Skip to content

Commit 791a068

Browse files
bottlerfacebook-github-bot
authored andcommitted
avoid math.prod for python 3.7
Summary: This makes the new volumes tutorial work on google colab. Reviewed By: kjchalup Differential Revision: D38501906 fbshipit-source-id: a606a357e929dae903dc4d9067bd1519f05b1458
1 parent c49ebad commit 791a068

File tree

4 files changed

+21
-11
lines changed

4 files changed

+21
-11
lines changed

pytorch3d/common/compat.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
"""
13-
Some functions which depend on PyTorch versions.
13+
Some functions which depend on PyTorch or Python versions.
1414
"""
1515

1616

@@ -79,3 +79,12 @@ def meshgrid_ij(
7979
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but got
8080
# `Union[Sequence[Tensor], Tensor]`.
8181
return torch.meshgrid(*A)
82+
83+
84+
def prod(iterable, *, start=1):
85+
"""
86+
Like math.prod in Python 3.8 and later.
87+
"""
88+
for i in iterable:
89+
start *= i
90+
return start

pytorch3d/implicitron/models/generic_model.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch
1818
import tqdm
1919
from omegaconf import DictConfig
20+
from pytorch3d.common.compat import prod
2021
from pytorch3d.implicitron.models.metrics import (
2122
RegularizationMetricsBase,
2223
ViewMetricsBase,
@@ -919,7 +920,7 @@ def _chunk_generator(
919920
f"by n_pts_per_ray ({n_pts_per_ray})"
920921
)
921922

922-
n_rays = math.prod(spatial_dim)
923+
n_rays = prod(spatial_dim)
923924
# special handling for raytracing-based methods
924925
n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
925926
chunk_size_in_rays = -(-n_rays // n_chunks)
@@ -935,9 +936,9 @@ def _chunk_generator(
935936
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
936937
:, start_idx:end_idx
937938
],
938-
lengths=ray_bundle.lengths.reshape(
939-
batch_size, math.prod(spatial_dim), n_pts_per_ray
940-
)[:, start_idx:end_idx],
939+
lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
940+
:, start_idx:end_idx
941+
],
941942
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
942943
)
943944
extra_args = kwargs.copy()

pytorch3d/implicitron/models/implicit_function/utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import math
87
from typing import Callable, Optional
98

109
import torch
10+
from pytorch3d.common.compat import prod
1111
from pytorch3d.renderer.cameras import CamerasBase
1212

1313

@@ -52,7 +52,7 @@ def create_embeddings_for_implicit_function(
5252
embeds = torch.empty(
5353
bs,
5454
1,
55-
math.prod(spatial_size),
55+
prod(spatial_size),
5656
pts_per_ray,
5757
0,
5858
dtype=xyz_world.dtype,
@@ -62,7 +62,7 @@ def create_embeddings_for_implicit_function(
6262
embeds = xyz_embedding_function(ray_points_for_embed).reshape(
6363
bs,
6464
1,
65-
math.prod(spatial_size),
65+
prod(spatial_size),
6666
pts_per_ray,
6767
-1,
6868
) # flatten spatial, add n_src dim
@@ -73,7 +73,7 @@ def create_embeddings_for_implicit_function(
7373
embed_shape = (
7474
bs,
7575
embeds_viewpooled.shape[1],
76-
math.prod(spatial_size),
76+
prod(spatial_size),
7777
pts_per_ray,
7878
-1,
7979
)

pytorch3d/implicitron/models/renderer/sdf_renderer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
# implicit_differentiable_renderer.py
44
# Copyright (c) 2020 Lior Yariv
55
import functools
6-
import math
76
from typing import List, Optional, Tuple
87

98
import torch
109
from omegaconf import DictConfig
10+
from pytorch3d.common.compat import prod
1111
from pytorch3d.implicitron.tools.config import (
1212
get_default_args_field,
1313
registry,
@@ -105,7 +105,7 @@ def forward(
105105

106106
# object_mask: silhouette of the object
107107
batch_size, *spatial_size, _ = ray_bundle.lengths.shape
108-
num_pixels = math.prod(spatial_size)
108+
num_pixels = prod(spatial_size)
109109

110110
cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3)
111111
ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3)

0 commit comments

Comments
 (0)