@@ -504,7 +504,7 @@ def _download(self, url: str, filename: Path, md5: str) -> None:
504
504
filename: output filename to write downloaded file
505
505
md5: md5 of downloaded file
506
506
"""
507
- if not os .path .exists (filename ):
507
+ if not os .path .exists (os . path . join ( self . root , filename ) ):
508
508
download_url (
509
509
url , self .root , filename = filename , md5 = md5 if self .checksum else None
510
510
)
@@ -667,6 +667,7 @@ def __init__(
667
667
self .download = download
668
668
self .checksum = checksum
669
669
self .class2idx = {c : i for i , c in enumerate (self .class_sets [num_classes ])}
670
+ self .idx2class = {i : c for i , c in enumerate (self .class_sets [num_classes ])}
670
671
self ._verify ()
671
672
672
673
self .metadata_df = pd .read_parquet (os .path .join (self .root , 'metadata.parquet' ))
@@ -788,10 +789,8 @@ def _load_target(self, index: int) -> Tensor:
788
789
Returns:
789
790
the target label
790
791
"""
791
- image_labels = self .metadata_df .iloc [index ]['labels' ]
792
+ indices = self .metadata_df .iloc [index ]['labels' ]
792
793
793
- # labels -> indices
794
- indices = [self .class2idx [label ] for label in image_labels ]
795
794
image_target = torch .zeros (self .num_classes , dtype = torch .long )
796
795
image_target [indices ] = 1
797
796
return image_target
@@ -839,7 +838,7 @@ def _download(self, url: str, filename: Path, md5: str) -> None:
839
838
filename: output filename to write downloaded file
840
839
md5: md5 of downloaded file
841
840
"""
842
- if not os .path .exists (filename ):
841
+ if not os .path .exists (os . path . join ( self . root , filename ) ):
843
842
download_url (
844
843
url , self .root , filename = filename , md5 = md5 if self .checksum else None
845
844
)
0 commit comments