2323)
2424from autoPyTorch .utils .common import FitRequirement , hash_array_or_matrix
2525
26- BaseDatasetType = Union [Tuple [np .ndarray , np .ndarray ], Dataset ]
26+ BaseDatasetInputType = Union [Tuple [np .ndarray , np .ndarray ], Dataset ]
2727
2828
2929def check_valid_data (data : Any ) -> None :
@@ -32,10 +32,9 @@ def check_valid_data(data: Any) -> None:
3232 'The specified Data for Dataset must have both __getitem__ and __len__ attribute.' )
3333
3434
35- def type_check (train_tensors : BaseDatasetType , val_tensors : Optional [BaseDatasetType ] = None ) -> None :
36- """To avoid unexpected behavior, we use loops over indices."""
37- for i in range (len (train_tensors )):
38- check_valid_data (train_tensors [i ])
35+ def type_check (train_tensors : BaseDatasetInputType , val_tensors : Optional [BaseDatasetInputType ] = None ) -> None :
36+ for train_tensor in train_tensors :
37+ check_valid_data (train_tensor )
3938 if val_tensors is not None :
4039 for i in range (len (val_tensors )):
4140 check_valid_data (val_tensors [i ])
@@ -63,10 +62,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
6362class BaseDataset (Dataset , metaclass = ABCMeta ):
6463 def __init__ (
6564 self ,
66- train_tensors : BaseDatasetType ,
65+ train_tensors : BaseDatasetInputType ,
6766 dataset_name : Optional [str ] = None ,
68- val_tensors : Optional [BaseDatasetType ] = None ,
69- test_tensors : Optional [BaseDatasetType ] = None ,
67+ val_tensors : Optional [BaseDatasetInputType ] = None ,
68+ test_tensors : Optional [BaseDatasetInputType ] = None ,
7069 resampling_strategy : Union [CrossValTypes , HoldoutValTypes ] = HoldoutValTypes .holdout_validation ,
7170 resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
7271 shuffle : Optional [bool ] = True ,
@@ -313,7 +312,7 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
313312 return (TransformSubset (self , self .splits [split_id ][0 ], train = True ),
314313 TransformSubset (self , self .splits [split_id ][1 ], train = False ))
315314
316- def replace_data (self , X_train : BaseDatasetType , X_test : Optional [BaseDatasetType ]) -> 'BaseDataset' :
315+ def replace_data (self , X_train : BaseDatasetInputType , X_test : Optional [BaseDatasetInputType ]) -> 'BaseDataset' :
317316 """
318317 To speed up the training of small dataset, early pre-processing of the data
319318 can be made on the fly by the pipeline.
0 commit comments