Skip to content

Commit c40aff1

Browse files
committed
tests
1 parent 9dfbb72 commit c40aff1

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

tests/datasets/test_bigearthnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def test_getitem_s2(self, dataset: BigEarthNetV2) -> None:
216216
def test_len(self, dataset: BigEarthNetV2) -> None:
217217
"""Test dataset length."""
218218
if dataset.split == 'train':
219-
assert len(dataset) == 2
219+
assert len(dataset) == 1
220220
elif dataset.split == 'val':
221221
assert len(dataset) == 1
222222
else:

torchgeo/datasets/bigearthnet.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ def _download(self, url: str, filename: Path, md5: str) -> None:
504504
filename: output filename to write downloaded file
505505
md5: md5 of downloaded file
506506
"""
507-
if not os.path.exists(filename):
507+
if not os.path.exists(os.path.join(self.root, filename)):
508508
download_url(
509509
url, self.root, filename=filename, md5=md5 if self.checksum else None
510510
)
@@ -667,6 +667,7 @@ def __init__(
667667
self.download = download
668668
self.checksum = checksum
669669
self.class2idx = {c: i for i, c in enumerate(self.class_sets[num_classes])}
670+
self.idx2class = {i: c for i, c in enumerate(self.class_sets[num_classes])}
670671
self._verify()
671672

672673
self.metadata_df = pd.read_parquet(os.path.join(self.root, 'metadata.parquet'))
@@ -788,10 +789,8 @@ def _load_target(self, index: int) -> Tensor:
788789
Returns:
789790
the target label
790791
"""
791-
image_labels = self.metadata_df.iloc[index]['labels']
792+
indices = self.metadata_df.iloc[index]['labels']
792793

793-
# labels -> indices
794-
indices = [self.class2idx[label] for label in image_labels]
795794
image_target = torch.zeros(self.num_classes, dtype=torch.long)
796795
image_target[indices] = 1
797796
return image_target
@@ -839,7 +838,7 @@ def _download(self, url: str, filename: Path, md5: str) -> None:
839838
filename: output filename to write downloaded file
840839
md5: md5 of downloaded file
841840
"""
842-
if not os.path.exists(filename):
841+
if not os.path.exists(os.path.join(self.root, filename)):
843842
download_url(
844843
url, self.root, filename=filename, md5=md5 if self.checksum else None
845844
)

0 commit comments

Comments
 (0)