Skip to content

Commit

Permalink
refactor: rewrite some comments;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Dec 6, 2023
1 parent 6dad105 commit f403b3e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pypots/data/load_specific_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def load_specific_dataset(dataset_name: str, use_cache: bool = True) -> dict:
"""
logger.info(
f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Database)..."
f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Data_Beans)..."
)
assert dataset_name in SUPPORTED_DATASETS, (
f"Dataset {dataset_name} is not supported. "
Expand Down
6 changes: 3 additions & 3 deletions pypots/imputation/csdi/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
X = self.X[idx].to(torch.float32)
X_intact, X, missing_mask, indicating_mask = mcar(X, p=self.rate)

observed_data = X_intact # i.e. originally observed data
observed_mask = missing_mask + indicating_mask # i.e. originally missing masks
observed_data = X_intact
observed_mask = missing_mask + indicating_mask
observed_tp = (
torch.arange(0, self.n_steps, dtype=torch.float32)
if self.time_points is None
else self.time_points[idx].to(torch.float32)
)
gt_mask = missing_mask # missing mask with ground truth masked for validation
gt_mask = missing_mask
for_pattern_mask = (
gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx]
)
Expand Down

0 comments on commit f403b3e

Please sign in to comment.