Skip to content

Commit f6c2ca6

Browse files
ezyangfacebook-github-bot
authored andcommitted
Prepare for "Fix type-safety of torch.nn.Module instances": wave 2
Summary: See D52890934 Reviewed By: malfet, r-barnes Differential Revision: D66245100 fbshipit-source-id: 019058106ac7eaacf29c1c55912922ea55894d23
1 parent e20cbe9 commit f6c2ca6

23 files changed

+147
-1
lines changed

projects/implicitron_trainer/impl/optimizer_factory.py

+1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __call__(
123123
"""
124124
# Get the parameters to optimize
125125
if hasattr(model, "_get_param_groups"): # use the model function
126+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
126127
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
127128
else:
128129
p_groups = [

projects/implicitron_trainer/impl/training_loop.py

+1
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ def _training_or_validation_epoch(
395395
):
396396
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
397397
if hasattr(model, "visualize"):
398+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
398399
model.visualize(
399400
viz,
400401
visdom_env_imgs,

pytorch3d/implicitron/dataset/utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ def adjust_camera_to_bbox_crop_(
329329

330330
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
331331
camera.focal_length[0],
332+
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
332333
camera.principal_point[0],
333334
image_size_wh,
334335
)
@@ -341,6 +342,7 @@ def adjust_camera_to_bbox_crop_(
341342
)
342343

343344
camera.focal_length = focal_length[None]
345+
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
344346
camera.principal_point = principal_point_cropped[None]
345347

346348

@@ -352,6 +354,7 @@ def adjust_camera_to_image_scale_(
352354
) -> PerspectiveCameras:
353355
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
354356
camera.focal_length[0],
357+
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
355358
camera.principal_point[0],
356359
original_size_wh,
357360
)
@@ -368,6 +371,7 @@ def adjust_camera_to_image_scale_(
368371
image_size_wh_output,
369372
)
370373
camera.focal_length = focal_length_scaled[None]
374+
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
371375
camera.principal_point = principal_point_scaled[None]
372376

373377

pytorch3d/implicitron/models/feature_extractor/resnet_feature_extractor.py

+11
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,15 @@ def _get_resnet_stage_feature_name(self, stage) -> str:
142142
return f"res_layer_{stage + 1}"
143143

144144
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
145+
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
146+
# `Union[Tensor, Module]`.
147+
# pyre-fixme[58]: `/` is not supported for operand types `Tensor` and
148+
# `Union[Tensor, Module]`.
145149
return (img - self._resnet_mean) / self._resnet_std
146150

147151
def get_feat_dims(self) -> int:
152+
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
153+
# not a function.
148154
return sum(self._feat_dim.values())
149155

150156
def forward(
@@ -183,7 +189,12 @@ def forward(
183189
else:
184190
imgs_normed = imgs_resized
185191
# is not a function.
192+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
186193
feats = self.stem(imgs_normed)
194+
# pyre-fixme[6]: For 1st argument expected `Iterable[_T1]` but got
195+
# `Union[Tensor, Module]`.
196+
# pyre-fixme[6]: For 2nd argument expected `Iterable[_T2]` but got
197+
# `Union[Tensor, Module]`.
187198
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
188199
feats = layer(feats)
189200
# just a sanity check below

pytorch3d/implicitron/models/generic_model.py

+4
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,8 @@ def curried_viewpooler(pts):
478478
)
479479
custom_args["global_code"] = global_code
480480

481+
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
482+
# function.
481483
for func in self._implicit_functions:
482484
func.bind_args(**custom_args)
483485

@@ -500,6 +502,8 @@ def curried_viewpooler(pts):
500502
# Unbind the custom arguments to prevent pytorch from storing
501503
# large buffers of intermediate results due to points in the
502504
# bound arguments.
505+
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
506+
# function.
503507
for func in self._implicit_functions:
504508
func.unbind_args()
505509

pytorch3d/implicitron/models/global_encoder/autodecoder.py

+3
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def _build_key_map(
7171
return key_map
7272

7373
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
74+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`.
7475
return (self._autodecoder_codes.weight**2).mean()
7576

7677
def get_encoding_dim(self) -> int:
@@ -95,13 +96,15 @@ def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tenso
9596
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
9697
# `Tensor`.
9798
x = torch.tensor(
99+
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
98100
[self._key_map[elem] for elem in x],
99101
dtype=torch.long,
100102
device=next(self.parameters()).device,
101103
)
102104
except StopIteration:
103105
raise ValueError("Not enough n_instances in the autodecoder") from None
104106

107+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
105108
return self._autodecoder_codes(x)
106109

107110
def _load_key_map_hook(

pytorch3d/implicitron/models/global_encoder/global_encoder.py

+1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def forward(
122122
if frame_timestamp.shape[-1] != 1:
123123
raise ValueError("Frame timestamp's last dimensions should be one.")
124124
time = frame_timestamp / self.time_divisor
125+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
125126
return self._harmonic_embedding(time)
126127

127128
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:

pytorch3d/implicitron/models/implicit_function/decoding_functions.py

+5
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,14 @@ def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
232232
# if the skip tensor is None, we use `x` instead.
233233
z = x
234234
skipi = 0
235+
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
236+
# `Union[Tensor, Module]`.
235237
for li, layer in enumerate(self.mlp):
238+
# pyre-fixme[58]: `in` is not supported for right operand type
239+
# `Union[Tensor, Module]`.
236240
if li in self._input_skips:
237241
if self._skip_affine_trans:
242+
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
238243
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
239244
else:
240245
y = torch.cat((y, z), dim=-1)

pytorch3d/implicitron/models/implicit_function/idr_feature_field.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,16 @@ def forward(
141141
self.embed_fn is None and fun_viewpool is None and global_code is None
142142
):
143143
return torch.tensor(
144-
[], device=rays_points_world.device, dtype=rays_points_world.dtype
144+
[],
145+
device=rays_points_world.device,
146+
dtype=rays_points_world.dtype,
147+
# pyre-fixme[6]: For 2nd argument expected `Union[int, SymInt]` but got
148+
# `Union[Module, Tensor]`.
145149
).view(0, self.out_dim)
146150

147151
embeddings = []
148152
if self.embed_fn is not None:
153+
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
149154
embeddings.append(self.embed_fn(rays_points_world))
150155

151156
if fun_viewpool is not None:
@@ -164,13 +169,19 @@ def forward(
164169

165170
embedding = torch.cat(embeddings, dim=-1)
166171
x = embedding
172+
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
173+
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
167174
for layer_idx in range(self.num_layers - 1):
168175
if layer_idx in self.skip_in:
169176
x = torch.cat([x, embedding], dim=-1) / 2**0.5
170177

178+
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
171179
x = self.linear_layers[layer_idx](x)
172180

181+
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
182+
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
173183
if layer_idx < self.num_layers - 2:
184+
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
174185
x = self.softplus(x)
175186

176187
return x

pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py

+8
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,10 @@ def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
123123
# Normalize the ray_directions to unit l2 norm.
124124
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
125125
# Obtain the harmonic embedding of the normalized ray directions.
126+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
126127
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
127128

129+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
128130
return self.color_layer((self.intermediate_linear(features), rays_embedding))
129131

130132
@staticmethod
@@ -195,6 +197,8 @@ def forward(
195197
embeds = create_embeddings_for_implicit_function(
196198
xyz_world=rays_points_world,
197199
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
200+
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
201+
# got `Union[None, Tensor, Module]`.
198202
xyz_embedding_function=(
199203
self.harmonic_embedding_xyz if self.input_xyz else None
200204
),
@@ -206,19 +210,23 @@ def forward(
206210
)
207211

208212
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
213+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
209214
features = self.xyz_encoder(embeds)
210215
# features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
211216
# NNs operate on the flattenned rays; reshaping to the correct spatial size
212217
# TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
213218
features = features.reshape(*rays_points_world.shape[:-1], -1)
214219

220+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
215221
raw_densities = self.density_layer(features)
216222
# raw_densities.shape = [minibatch x ... x 1] in [0-1]
217223

218224
if self.xyz_ray_dir_in_camera_coords:
219225
if camera is None:
220226
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
221227

228+
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
229+
# `Union[Tensor, Module]`.
222230
directions = ray_bundle.directions @ camera.R
223231
else:
224232
directions = ray_bundle.directions

pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py

+14
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def forward(
103103

104104
embeds = create_embeddings_for_implicit_function(
105105
xyz_world=rays_points_world,
106+
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
107+
# got `Union[Tensor, Module]`.
106108
xyz_embedding_function=self._harmonic_embedding,
107109
global_code=global_code,
108110
fun_viewpool=fun_viewpool,
@@ -112,6 +114,7 @@ def forward(
112114

113115
# Before running the network, we have to resize embeds to ndims=3,
114116
# otherwise the SRN layers consume huge amounts of memory.
117+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
115118
raymarch_features = self._net(
116119
embeds.view(embeds.shape[0], -1, embeds.shape[-1])
117120
)
@@ -166,7 +169,9 @@ def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
166169
# Normalize the ray_directions to unit l2 norm.
167170
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
168171
# Obtain the harmonic embedding of the normalized ray directions.
172+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
169173
rays_embedding = self._harmonic_embedding(rays_directions_normed)
174+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
170175
return self._color_layer((features, rays_embedding))
171176

172177
def forward(
@@ -195,20 +200,24 @@ def forward(
195200
denoting the color of each ray point.
196201
"""
197202
# raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
203+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
198204
features = self._net(raymarch_features)
199205
# features.shape = [minibatch x ... x self.n_hidden_units]
200206

201207
if self.ray_dir_in_camera_coords:
202208
if camera is None:
203209
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
204210

211+
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
212+
# `Union[Tensor, Module]`.
205213
directions = ray_bundle.directions @ camera.R
206214
else:
207215
directions = ray_bundle.directions
208216

209217
# NNs operate on the flattenned rays; reshaping to the correct spatial size
210218
features = features.reshape(*raymarch_features.shape[:-1], -1)
211219

220+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
212221
raw_densities = self._density_layer(features)
213222

214223
rays_colors = self._get_colors(features, directions)
@@ -269,6 +278,7 @@ def _run_hypernet(self, global_code: torch.Tensor) -> Tuple[SRNRaymarchFunction]
269278
srn_raymarch_function.
270279
"""
271280

281+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
272282
net = self._hypernet(global_code)
273283

274284
# use the hyper-net generated network to instantiate the raymarch module
@@ -304,6 +314,8 @@ def forward(
304314
# across LSTM iterations for the same global_code.
305315
if self.cached_srn_raymarch_function is None:
306316
# generate the raymarching network from the hypernet
317+
# pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
318+
# `cached_srn_raymarch_function`.
307319
self.cached_srn_raymarch_function = self._run_hypernet(global_code)
308320
(srn_raymarch_function,) = cast(
309321
Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
@@ -331,6 +343,7 @@ def __post_init__(self):
331343
def create_raymarch_function(self) -> None:
332344
self.raymarch_function = SRNRaymarchFunction(
333345
latent_dim=self.latent_dim,
346+
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
334347
**self.raymarch_function_args,
335348
)
336349

@@ -389,6 +402,7 @@ def create_hypernet(self) -> None:
389402
self.hypernet = SRNRaymarchHyperNet(
390403
latent_dim=self.latent_dim,
391404
latent_dim_hypernet=self.latent_dim_hypernet,
405+
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
392406
**self.hypernet_args,
393407
)
394408

pytorch3d/implicitron/models/implicit_function/voxel_grid.py

+11
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def change_individual_resolution(tensor, wanted_resolution):
269269
for name, tensor in vars(grid_values_with_wanted_resolution).items()
270270
}
271271

272+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
272273
return self.values_type(**params), True
273274

274275
def get_resolution_change_epochs(self) -> Tuple[int, ...]:
@@ -882,6 +883,7 @@ def forward(self, points: torch.Tensor) -> torch.Tensor:
882883
torch.Tensor of shape (..., n_features)
883884
"""
884885
locator = self._get_volume_locator()
886+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
885887
grid_values = self.voxel_grid.values_type(**self.params)
886888
# voxel grids operate with extra n_grids dimension, which we fix to one
887889
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
@@ -895,6 +897,7 @@ def set_voxel_grid_parameters(self, params: VoxelGridValuesBase) -> None:
895897
replace current parameters
896898
"""
897899
if self.hold_voxel_grid_as_parameters:
900+
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
898901
self.params = torch.nn.ParameterDict(
899902
{
900903
k: torch.nn.Parameter(val)
@@ -945,6 +948,7 @@ def _apply_epochs(self, epoch: int) -> bool:
945948
Returns:
946949
True if parameter change has happened else False.
947950
"""
951+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
948952
grid_values = self.voxel_grid.values_type(**self.params)
949953
grid_values, change = self.voxel_grid.change_resolution(
950954
grid_values, epoch=epoch
@@ -992,16 +996,21 @@ def _create_parameters_with_new_size(
992996
"""
993997
'''
994998
new_params = {}
999+
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
1000+
# function.
9951001
for name in self.params:
9961002
key = prefix + "params." + name
9971003
if key in state_dict:
9981004
new_params[name] = torch.zeros_like(state_dict[key])
1005+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
9991006
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
10001007

10011008
def get_device(self) -> torch.device:
10021009
"""
10031010
Returns torch.device on which module parameters are located
10041011
"""
1012+
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
1013+
# not a function.
10051014
return next(val for val in self.params.values() if val is not None).device
10061015

10071016
def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
@@ -1018,13 +1027,15 @@ def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
10181027
"""
10191028
locator = self._get_volume_locator()
10201029
# torch.nn.modules.module.Module]` is not a function.
1030+
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
10211031
old_grid_values = self.voxel_grid.values_type(**self.params)
10221032
new_grid_values = self.voxel_grid.crop_world(
10231033
min_point, max_point, old_grid_values, locator
10241034
)
10251035
grid_values, _ = self.voxel_grid.change_resolution(
10261036
new_grid_values, grid_values_with_wanted_resolution=old_grid_values
10271037
)
1038+
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
10281039
self.params = torch.nn.ParameterDict(
10291040
{
10301041
k: torch.nn.Parameter(val)

0 commit comments

Comments
 (0)