diff --git a/pypots/data/load_specific_datasets.py b/pypots/data/load_specific_datasets.py index 174b148c..c43864cd 100644 --- a/pypots/data/load_specific_datasets.py +++ b/pypots/data/load_specific_datasets.py @@ -57,7 +57,7 @@ def load_specific_dataset(dataset_name: str, use_cache: bool = True) -> dict: """ logger.info( - f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Database)..." + f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Data_Beans)..." ) assert dataset_name in SUPPORTED_DATASETS, ( f"Dataset {dataset_name} is not supported. " diff --git a/pypots/imputation/csdi/data.py b/pypots/imputation/csdi/data.py index b2798ca0..e0cfc894 100644 --- a/pypots/imputation/csdi/data.py +++ b/pypots/imputation/csdi/data.py @@ -69,14 +69,14 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: X = self.X[idx].to(torch.float32) X_intact, X, missing_mask, indicating_mask = mcar(X, p=self.rate) - observed_data = X_intact # i.e. originally observed data - observed_mask = missing_mask + indicating_mask # i.e. originally missing masks + observed_data = X_intact + observed_mask = missing_mask + indicating_mask observed_tp = ( torch.arange(0, self.n_steps, dtype=torch.float32) if self.time_points is None else self.time_points[idx].to(torch.float32) ) - gt_mask = missing_mask # missing mask with ground truth masked for validation + gt_mask = missing_mask for_pattern_mask = ( gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx] )