1+ import os
2+ import uuid
13from abc import ABCMeta
24from typing import Any , Dict , List , Optional , Sequence , Tuple , Union , cast
35
1315
1416from autoPyTorch .constants import CLASSIFICATION_OUTPUTS , STRING_TO_OUTPUT_TYPES
1517from autoPyTorch .datasets .resampling_strategy import (
16- CROSS_VAL_FN ,
18+ CrossValFunc ,
19+ CrossValFuncs ,
1720 CrossValTypes ,
1821 DEFAULT_RESAMPLING_PARAMETERS ,
19- HOLDOUT_FN ,
20- HoldoutValTypes ,
21- get_cross_validators ,
22- get_holdout_validators ,
23- is_stratified ,
22+ HoldOutFunc ,
23+ HoldOutFuncs ,
24+ HoldoutValTypes
2425)
25- from autoPyTorch .utils .common import FitRequirement , hash_array_or_matrix
26+ from autoPyTorch .utils .common import FitRequirement
2627
27- BaseDatasetType = Union [Tuple [np .ndarray , np .ndarray ], Dataset ]
28+ BaseDatasetInputType = Union [Tuple [np .ndarray , np .ndarray ], Dataset ]
2829
2930
3031def check_valid_data (data : Any ) -> None :
@@ -33,7 +34,8 @@ def check_valid_data(data: Any) -> None:
3334 'The specified Data for Dataset must have both __getitem__ and __len__ attribute.' )
3435
3536
36- def type_check (train_tensors : BaseDatasetType , val_tensors : Optional [BaseDatasetType ] = None ) -> None :
37+ def type_check (train_tensors : BaseDatasetInputType ,
38+ val_tensors : Optional [BaseDatasetInputType ] = None ) -> None :
3739 """To avoid unexpected behavior, we use loops over indices."""
3840 for i in range (len (train_tensors )):
3941 check_valid_data (train_tensors [i ])
@@ -49,8 +51,8 @@ class TransformSubset(Subset):
4951 we require different transformation for each data point.
5052 This class helps to take the subset of the dataset
5153 with either training or validation transformation.
52-
53- We achieve so by adding a train flag to the pytorch subset
54+ TransformSubset achieves this by passing a train flag
55+ to the main dataset whenever it is indexed.
5456
5557 Attributes:
5658 dataset (BaseDataset/Dataset): Dataset to sample the subset
@@ -71,10 +73,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
7173class BaseDataset (Dataset , metaclass = ABCMeta ):
7274 def __init__ (
7375 self ,
74- train_tensors : BaseDatasetType ,
76+ train_tensors : BaseDatasetInputType ,
7577 dataset_name : Optional [str ] = None ,
76- val_tensors : Optional [BaseDatasetType ] = None ,
77- test_tensors : Optional [BaseDatasetType ] = None ,
78+ val_tensors : Optional [BaseDatasetInputType ] = None ,
79+ test_tensors : Optional [BaseDatasetInputType ] = None ,
7880 resampling_strategy : Union [CrossValTypes , HoldoutValTypes ] = HoldoutValTypes .holdout_validation ,
7981 resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
8082 shuffle : Optional [bool ] = True ,
@@ -106,14 +108,16 @@ def __init__(
106108 val_transforms (Optional[torchvision.transforms.Compose]):
107109 Additional Transforms to be applied to the validation/test data
108110 """
109- self .dataset_name = dataset_name if dataset_name is not None \
110- else hash_array_or_matrix (train_tensors [0 ])
111+ self .dataset_name = dataset_name
112+
113+ if self .dataset_name is None :
114+ self .dataset_name = str (uuid .uuid1 (clock_seq = os .getpid ()))
111115
112116 if not hasattr (train_tensors [0 ], 'shape' ):
113117 type_check (train_tensors , val_tensors )
114118 self .train_tensors , self .val_tensors , self .test_tensors = train_tensors , val_tensors , test_tensors
115- self .cross_validators : Dict [str , CROSS_VAL_FN ] = {}
116- self .holdout_validators : Dict [str , HOLDOUT_FN ] = {}
119+ self .cross_validators : Dict [str , CrossValFunc ] = {}
120+ self .holdout_validators : Dict [str , HoldOutFunc ] = {}
117121 self .rng = np .random .RandomState (seed = seed )
118122 self .shuffle = shuffle
119123 self .resampling_strategy = resampling_strategy
@@ -134,8 +138,8 @@ def __init__(
134138 self .is_small_preprocess = True
135139
136140 # Make sure cross validation splits are created once
137- self .cross_validators = get_cross_validators (* CrossValTypes )
138- self .holdout_validators = get_holdout_validators (* HoldoutValTypes )
141+ self .cross_validators = CrossValFuncs . get_cross_validators (* CrossValTypes )
142+ self .holdout_validators = HoldOutFuncs . get_holdout_validators (* HoldoutValTypes )
139143 self .splits = self .get_splits_from_resampling_strategy ()
140144
141145 # We also need to be able to transform the data, be it for pre-processing
@@ -263,7 +267,7 @@ def create_cross_val_splits(
263267 if not isinstance (cross_val_type , CrossValTypes ):
264268 raise NotImplementedError (f'The selected `cross_val_type` "{ cross_val_type } " is not implemented.' )
265269 kwargs = {}
266- if is_stratified (cross_val_type ):
270+ if cross_val_type . is_stratified ():
267271 # we need additional information about the data for stratification
268272 kwargs ["stratify" ] = self .train_tensors [- 1 ]
269273 splits = self .cross_validators [cross_val_type .name ](
@@ -298,7 +302,7 @@ def create_holdout_val_split(
298302 if not isinstance (holdout_val_type , HoldoutValTypes ):
299303 raise NotImplementedError (f'The specified `holdout_val_type` "{ holdout_val_type } " is not supported.' )
300304 kwargs = {}
301- if is_stratified (holdout_val_type ):
305+ if holdout_val_type . is_stratified ():
302306 # we need additional information about the data for stratification
303307 kwargs ["stratify" ] = self .train_tensors [- 1 ]
304308 train , val = self .holdout_validators [holdout_val_type .name ](val_share , self ._get_indices (), ** kwargs )
@@ -321,7 +325,8 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
321325 return (TransformSubset (self , self .splits [split_id ][0 ], train = True ),
322326 TransformSubset (self , self .splits [split_id ][1 ], train = False ))
323327
324- def replace_data (self , X_train : BaseDatasetType , X_test : Optional [BaseDatasetType ]) -> 'BaseDataset' :
328+ def replace_data (self , X_train : BaseDatasetInputType ,
329+ X_test : Optional [BaseDatasetInputType ]) -> 'BaseDataset' :
325330 """
326331 To speed up the training of small dataset, early pre-processing of the data
327332 can be made on the fly by the pipeline.
0 commit comments