Skip to content

Commit 238d586

Browse files
committed
Revert mypy changes (#1836)
* Revert mypy changes * Remove unused import
1 parent 454ef65 commit 238d586

28 files changed

+39
-77
lines changed

torchgeo/datamodules/chesapeake.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Any, Optional
77

88
import kornia.augmentation as K
9-
import torch
109
import torch.nn as nn
1110
import torch.nn.functional as F
1211
from einops import rearrange
@@ -113,7 +112,7 @@ def __init__(
113112
self.test_splits = test_splits
114113
self.class_set = class_set
115114
self.use_prior_labels = use_prior_labels
116-
self.prior_smoothing_constant = torch.tensor(prior_smoothing_constant)
115+
self.prior_smoothing_constant = prior_smoothing_constant
117116

118117
if self.use_prior_labels:
119118
self.layers = [

torchgeo/datamodules/spacenet.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Any
77

88
import kornia.augmentation as K
9-
import torch
109
from torch import Tensor
1110

1211
from ..datasets import SpaceNet1
@@ -88,6 +87,6 @@ def on_after_batch_transfer(
8887
# We add 1 to the mask to map the current {background, building} labels to
8988
# the values {1, 2}. This is necessary because we add 0 padding to the
9089
# mask that we want to ignore in the loss function.
91-
batch["mask"] += torch.tensor(1)
90+
batch["mask"] += 1
9291

9392
return super().on_after_batch_transfer(batch, dataloader_idx)

torchgeo/datasets/benin_cashews.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,7 @@ def _load_mask(self, transform: rasterio.Affine) -> Tensor:
384384
dtype=np.uint8,
385385
)
386386

387-
mask = torch.from_numpy(mask_data)
388-
mask = mask.long()
387+
mask = torch.from_numpy(mask_data).long()
389388
return mask
390389

391390
def _check_integrity(self) -> bool:

torchgeo/datasets/bigearthnet.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,7 @@ def _load_image(self, index: int) -> Tensor:
400400
)
401401
images.append(array)
402402
arrays: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0)
403-
tensor = torch.from_numpy(arrays)
404-
tensor = tensor.float()
403+
tensor = torch.from_numpy(arrays).float()
405404
return tensor
406405

407406
def _load_target(self, index: int) -> Tensor:

torchgeo/datasets/biomassters.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,7 @@ def _load_target(self, filename: str) -> Tensor:
197197
with rasterio.open(os.path.join(self.root, "train_agbm", filename), "r") as src:
198198
arr: "np.typing.NDArray[np.float_]" = src.read()
199199

200-
target = torch.from_numpy(arr)
201-
target = target.float()
200+
target = torch.from_numpy(arr).float()
202201
return target
203202

204203
def _verify(self) -> None:

torchgeo/datasets/cloud_cover.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def plot(
384384
else:
385385
n_cols = 2
386386

387-
image, mask = sample["image"] / torch.tensor(3000), sample["mask"]
387+
image, mask = sample["image"] / 3000, sample["mask"]
388388

389389
fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5))
390390

torchgeo/datasets/cowc.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,7 @@ def _load_image(self, index: int) -> Tensor:
148148
filename = os.path.join(self.root, self.images[index])
149149
with Image.open(filename) as img:
150150
array: "np.typing.NDArray[np.int_]" = np.array(img)
151-
tensor = torch.from_numpy(array)
152-
tensor = tensor.float()
151+
tensor = torch.from_numpy(array).float()
153152
# Convert from HxWxC to CxHxW
154153
tensor = tensor.permute((2, 0, 1))
155154
return tensor
@@ -164,8 +163,7 @@ def _load_target(self, index: int) -> Tensor:
164163
the target
165164
"""
166165
target = int(self.targets[index])
167-
tensor = torch.tensor(target)
168-
tensor = tensor.float()
166+
tensor = torch.tensor(target).float()
169167
return tensor
170168

171169
def _check_integrity(self) -> bool:

torchgeo/datasets/cyclone.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,7 @@ def _load_image(self, directory: str) -> Tensor:
162162
img = img.resize(size=(self.size, self.size), resample=resample)
163163
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
164164
tensor = torch.from_numpy(array)
165-
tensor = tensor.permute((2, 0, 1))
166-
tensor = tensor.float()
165+
tensor = tensor.permute((2, 0, 1)).float()
167166
return tensor
168167

169168
def _load_features(self, directory: str) -> dict[str, Any]:

torchgeo/datasets/etci2021.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,7 @@ def _load_image(self, path: str) -> Tensor:
208208
filename = os.path.join(path)
209209
with Image.open(filename) as img:
210210
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
211-
tensor = torch.from_numpy(array)
212-
tensor = tensor.float()
211+
tensor = torch.from_numpy(array).float()
213212
# Convert from HxWxC to CxHxW
214213
tensor = tensor.permute((2, 0, 1))
215214
return tensor

torchgeo/datasets/idtrees.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,7 @@ def plot(
501501
assert len(hsi_indices) == 3
502502

503503
def normalize(x: Tensor) -> Tensor:
504-
# https://github.com/pytorch/pytorch/issues/116327
505-
out: Tensor = (x - x.min()) / (x.max() - x.min())
506-
return out
504+
return (x - x.min()) / (x.max() - x.min())
507505

508506
ncols = 3
509507

torchgeo/datasets/inria.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,7 @@ def _load_image(self, path: str) -> Tensor:
130130
"""
131131
with rio.open(path) as img:
132132
array = img.read().astype(np.int32)
133-
tensor = torch.from_numpy(array)
134-
tensor = tensor.float()
133+
tensor = torch.from_numpy(array).float()
135134
return tensor
136135

137136
def _load_target(self, path: str) -> Tensor:
@@ -146,8 +145,7 @@ def _load_target(self, path: str) -> Tensor:
146145
with rio.open(path) as img:
147146
array = img.read().astype(np.int32)
148147
array = np.clip(array, 0, 1)
149-
mask = torch.from_numpy(array[0])
150-
mask = mask.long()
148+
mask = torch.from_numpy(array[0]).long()
151149
return mask
152150

153151
def __len__(self) -> int:

torchgeo/datasets/landcoverai.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,7 @@ def _load_image(self, id_: str) -> Tensor:
369369
filename = os.path.join(self.root, "output", id_ + ".jpg")
370370
with Image.open(filename) as img:
371371
array: "np.typing.NDArray[np.int_]" = np.array(img)
372-
tensor = torch.from_numpy(array)
373-
tensor = tensor.float()
372+
tensor = torch.from_numpy(array).float()
374373
# Convert from HxWxC to CxHxW
375374
tensor = tensor.permute((2, 0, 1))
376375
return tensor
@@ -388,8 +387,7 @@ def _load_target(self, id_: str) -> Tensor:
388387
filename = os.path.join(self.root, "output", id_ + "_m.png")
389388
with Image.open(filename) as img:
390389
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L"))
391-
tensor = torch.from_numpy(array)
392-
tensor = tensor.long()
390+
tensor = torch.from_numpy(array).long()
393391
return tensor
394392

395393
def _verify_data(self) -> bool:

torchgeo/datasets/loveda.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,7 @@ def _load_image(self, path: str) -> Tensor:
215215
filename = os.path.join(path)
216216
with Image.open(filename) as img:
217217
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
218-
tensor = torch.from_numpy(array)
219-
tensor = tensor.float()
218+
tensor = torch.from_numpy(array).float()
220219
# Convert from HxWxC to CxHxW
221220
tensor = tensor.permute((2, 0, 1))
222221
return tensor

torchgeo/datasets/mapinwild.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,7 @@ def _load_raster(self, filename: int, source: str) -> Tensor:
218218
array: "np.typing.NDArray[np.int_]" = np.stack(raw_array, axis=0)
219219
if array.dtype == np.uint16:
220220
array = array.astype(np.int32)
221-
tensor = torch.from_numpy(array)
222-
tensor = tensor.float()
221+
tensor = torch.from_numpy(array).float()
223222
return tensor
224223

225224
def _verify(self, url: str, md5: Optional[str] = None) -> None:

torchgeo/datasets/nasa_marine_debris.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,7 @@ def _load_image(self, path: str) -> Tensor:
130130
"""
131131
with rasterio.open(path) as f:
132132
array = f.read()
133-
tensor = torch.from_numpy(array)
134-
tensor = tensor.float()
133+
tensor = torch.from_numpy(array).float()
135134
return tensor
136135

137136
def _load_target(self, path: str) -> Tensor:

torchgeo/datasets/oscd.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,7 @@ def _load_image(self, paths: Sequence[str]) -> Tensor:
204204
with Image.open(path) as img:
205205
images.append(np.array(img))
206206
array: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0).astype(np.int_)
207-
tensor = torch.from_numpy(array)
208-
tensor = tensor.float()
207+
tensor = torch.from_numpy(array).float()
209208
return tensor
210209

211210
def _load_target(self, path: str) -> Tensor:

torchgeo/datasets/pastis.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,7 @@ def _load_semantic_targets(self, index: int) -> Tensor:
232232
# See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501
233233
# even though the mask file is 3 bands, we just select the first band
234234
array = np.load(self.files[index]["semantic"])[0].astype(np.uint8)
235-
tensor = torch.from_numpy(array)
236-
tensor = tensor.long()
235+
tensor = torch.from_numpy(array).long()
237236
return tensor
238237

239238
def _load_instance_targets(self, index: int) -> tuple[Tensor, Tensor, Tensor]:

torchgeo/datasets/potsdam.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ def _load_image(self, index: int) -> Tensor:
187187
path = self.files[index]["image"]
188188
with rasterio.open(path) as f:
189189
array = f.read()
190-
tensor = torch.from_numpy(array)
191-
tensor = tensor.float()
190+
tensor = torch.from_numpy(array).float()
192191
return tensor
193192

194193
def _load_target(self, index: int) -> Tensor:

torchgeo/datasets/seasonet.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,7 @@ def _load_target(self, index: int) -> Tensor:
350350
path = self.files.iloc[index][0]
351351
with rasterio.open(f"{path}_labels.tif") as f:
352352
array = f.read() - 1
353-
tensor = torch.from_numpy(array)
354-
tensor = tensor.squeeze()
355-
tensor = tensor.long()
353+
tensor = torch.from_numpy(array).squeeze().long()
356354
return tensor
357355

358356
def _verify(self) -> None:

torchgeo/datasets/skippd.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,7 @@ def _load_image(self, index: int) -> Tensor:
173173
else:
174174
arr = rearrange(arr, "h w c -> c h w")
175175

176-
tensor = torch.from_numpy(arr)
177-
tensor = tensor.to(torch.float32)
176+
tensor = torch.from_numpy(arr).to(torch.float32)
178177
return tensor
179178

180179
def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]:

torchgeo/datasets/spacenet.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,7 @@ def _load_mask(
203203
dtype=np.uint8,
204204
)
205205

206-
mask = torch.from_numpy(mask_data)
207-
mask = mask.long()
206+
mask = torch.from_numpy(mask_data).long()
208207

209208
return mask
210209

@@ -733,8 +732,7 @@ def _load_mask(
733732
dtype=np.uint8,
734733
)
735734

736-
mask = torch.from_numpy(mask_data)
737-
mask = mask.long()
735+
mask = torch.from_numpy(mask_data).long()
738736
return mask
739737

740738
def plot(

torchgeo/datasets/ssl4eo_benchmark.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,7 @@ def _load_image(self, path: str) -> Tensor:
314314
image
315315
"""
316316
with rasterio.open(path) as src:
317-
image = torch.from_numpy(src.read())
318-
image = image.float()
317+
image = torch.from_numpy(src.read()).float()
319318
return image
320319

321320
def _load_mask(self, path: str) -> Tensor:
@@ -328,8 +327,7 @@ def _load_mask(self, path: str) -> Tensor:
328327
mask
329328
"""
330329
with rasterio.open(path) as src:
331-
mask = torch.from_numpy(src.read())
332-
mask = mask.long()
330+
mask = torch.from_numpy(src.read()).long()
333331
mask = self.ordinal_map[mask]
334332
return mask
335333

torchgeo/datasets/usavars.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,7 @@ def _load_image(self, path: str) -> Tensor:
182182
"""
183183
with rasterio.open(path) as f:
184184
array: "np.typing.NDArray[np.int_]" = f.read()
185-
tensor = torch.from_numpy(array)
186-
tensor = tensor.float()
185+
tensor = torch.from_numpy(array).float()
187186
return tensor
188187

189188
def _verify(self) -> None:

torchgeo/losses/qr.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import torch
77
import torch.nn.functional as F
8-
from torch import Tensor
98
from torch.nn.modules import Module
109

1110

@@ -29,16 +28,12 @@ def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
2928
qr loss
3029
"""
3130
q = probs
32-
# https://github.com/pytorch/pytorch/issues/116327
33-
q_bar: Tensor = q.mean(dim=(0, 2, 3))
34-
log_q_bar = torch.log(q_bar)
35-
qbar_log_S: Tensor = q_bar * log_q_bar
36-
qbar_log_S = qbar_log_S.sum()
31+
q_bar = q.mean(dim=(0, 2, 3))
32+
qbar_log_S = (q_bar * torch.log(q_bar)).sum()
3733

38-
q_log_p = torch.einsum("bcxy,bcxy->bxy", q, torch.log(target))
39-
q_log_p = q_log_p.mean()
34+
q_log_p = torch.einsum("bcxy,bcxy->bxy", q, torch.log(target)).mean()
4035

41-
loss: Tensor = qbar_log_S - q_log_p
36+
loss = qbar_log_S - q_log_p
4237
return loss
4338

4439

@@ -67,7 +62,6 @@ def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
6762
z = q / q.norm(p=1, dim=(0, 2, 3), keepdim=True).clamp_min(1e-12).expand_as(q)
6863
r = F.normalize(z * target, p=1, dim=1)
6964

70-
loss = torch.einsum("bcxy,bcxy->bxy", r, torch.log(r) - torch.log(q))
71-
loss = loss.mean()
65+
loss = torch.einsum("bcxy,bcxy->bxy", r, torch.log(r) - torch.log(q)).mean()
7266

7367
return loss

torchgeo/models/rcf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
),
9595
)
9696
self.register_buffer(
97-
"biases", torch.zeros(num_patches, requires_grad=False) + torch.tensor(bias)
97+
"biases", torch.zeros(num_patches, requires_grad=False) + bias
9898
)
9999

100100
if mode == "empirical":

torchgeo/samplers/batch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def __init__(
128128
# torch.multinomial requires float probabilities > 0
129129
self.areas = torch.tensor(areas, dtype=torch.float)
130130
if torch.sum(self.areas) == 0:
131-
self.areas += torch.tensor(1)
131+
self.areas += 1
132132

133133
def __iter__(self) -> Iterator[list[BoundingBox]]:
134134
"""Return the indices of a dataset.

torchgeo/samplers/single.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def __init__(
128128
# torch.multinomial requires float probabilities > 0
129129
self.areas = torch.tensor(areas, dtype=torch.float)
130130
if torch.sum(self.areas) == 0:
131-
self.areas += torch.tensor(1)
131+
self.areas += 1
132132

133133
def __iter__(self) -> Iterator[BoundingBox]:
134134
"""Return the index of a dataset.

torchgeo/transforms/color.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,8 @@ def apply_transform(
7070
Returns:
7171
The augmented input.
7272
"""
73-
weights = flags["weights"]
74-
weights = weights[..., :, None, None]
75-
weights = weights.to(input.device)
76-
out: Tensor = input * weights
73+
weights = flags["weights"][..., :, None, None].to(input.device)
74+
out = input * weights
7775
out = out.sum(dim=-3)
78-
out = out.unsqueeze(-3)
79-
out = out.expand(input.shape)
76+
out = out.unsqueeze(-3).expand(input.shape)
8077
return out

0 commit comments

Comments
 (0)