feat: enable return_labels in Dataset classes;
WenjieDu committed Apr 24, 2023
1 parent ea04dd6 commit 0787260
Showing 12 changed files with 93 additions and 42 deletions.
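The pattern across all twelve files is uniform: every training/validation Dataset construction now passes file_type through explicitly, and every inference path (classify/cluster/impute) passes return_labels=False, so __getitem__() yields the same schema at every stage whether or not labels exist. Below is a minimal demonstration of the new flag using the in-memory dict format the docstrings describe; the snippet is illustrative, not part of the commit, and it assumes the arrays are already torch tensors, as the classes' internal .to() calls suggest.

import torch
from pypots.data.base import BaseDataset

# Toy data in the dict format the Dataset classes accept:
# X holds incomplete series, y holds one integer label per sample.
X = torch.randn(16, 24, 5)
X[torch.rand_like(X) < 0.1] = float("nan")  # introduce missingness
data = {"X": X, "y": torch.randint(0, 2, (16,))}

train_ds = BaseDataset(data)                      # return_labels defaults to True
test_ds = BaseDataset(data, return_labels=False)  # same data, labels suppressed

print(len(train_ds[0]))  # sample includes the label tensor
print(len(test_ds[0]))   # one field fewer: the label is gated off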
6 changes: 3 additions & 3 deletions pypots/classification/brits.py
@@ -333,7 +333,7 @@ def fit(
Trained classifier.
"""

- training_set = DatasetForBRITS(train_set)
+ training_set = DatasetForBRITS(train_set, file_type=file_type)
training_loader = DataLoader(
training_set,
batch_size=self.batch_size,
@@ -344,7 +344,7 @@ def fit(
if val_set is None:
self._train_model(training_loader)
else:
- val_set = DatasetForBRITS(val_set)
+ val_set = DatasetForBRITS(val_set, file_type=file_type)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
@@ -374,7 +374,7 @@ def classify(self, X: Union[dict, str], file_type: str = "h5py"):
Classification results of the given samples.
"""
self.model.eval() # set the model as eval status to freeze it.
- test_set = DatasetForBRITS(X, file_type)
+ test_set = DatasetForBRITS(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
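With this change, the classifier accepts the same fit()/classify() API whether data arrives as an in-memory dict or a path to an h5 file. A usage sketch follows; the BRITS constructor arguments are taken from the PyPOTS documentation of this era and should be treated as assumptions to verify against the installed version.

import numpy as np
from pypots.classification import BRITS

# Toy dataset: 64 samples of 48-step, 4-feature series with 10% missingness.
X = np.random.randn(64, 48, 4)
X[np.random.rand(*X.shape) < 0.1] = np.nan
y = np.random.randint(0, 2, size=64)

model = BRITS(n_steps=48, n_features=4, rnn_hidden_size=32, n_classes=2, epochs=3)
model.fit({"X": X, "y": y})         # builds DatasetForBRITS(train_set, file_type=file_type)
results = model.classify({"X": X})  # builds DatasetForBRITS(X, return_labels=False, ...)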
6 changes: 3 additions & 3 deletions pypots/classification/grud.py
@@ -286,7 +286,7 @@ def fit(
Trained classifier.
"""

- training_set = DatasetForGRUD(train_set, file_type)
+ training_set = DatasetForGRUD(train_set, file_type=file_type)
training_loader = DataLoader(
training_set,
batch_size=self.batch_size,
@@ -297,7 +297,7 @@ def fit(
if val_set is None:
self._train_model(training_loader)
else:
- val_set = DatasetForGRUD(val_set)
+ val_set = DatasetForGRUD(val_set, file_type=file_type)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
@@ -327,7 +327,7 @@ def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
Classification results of the given samples.
"""
self.model.eval() # set the model as eval status to freeze it.
- test_set = DatasetForGRUD(X, file_type)
+ test_set = DatasetForGRUD(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
6 changes: 3 additions & 3 deletions pypots/classification/raindrop.py
@@ -803,7 +803,7 @@ def fit(
Trained model.
"""

- training_set = DatasetForGRUD(train_set)
+ training_set = DatasetForGRUD(train_set, file_type=file_type)
training_loader = DataLoader(
training_set,
batch_size=self.batch_size,
@@ -814,7 +814,7 @@ def fit(
if val_set is None:
self._train_model(training_loader)
else:
- val_set = DatasetForGRUD(val_set)
+ val_set = DatasetForGRUD(val_set, file_type=file_type)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
@@ -844,7 +844,7 @@ def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
Classification results of the given samples.
"""
self.model.eval() # set the model as eval status to freeze it.
- test_set = DatasetForGRUD(X, file_type)
+ test_set = DatasetForGRUD(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
4 changes: 2 additions & 2 deletions pypots/clustering/crli.py
@@ -577,7 +577,7 @@ def fit(
The type of the given file if train_set is a path string.
"""
- training_set = DatasetForGRUD(train_set, file_type)
+ training_set = DatasetForGRUD(train_set, file_type=file_type)
training_loader = DataLoader(
training_set,
batch_size=self.batch_size,
@@ -610,7 +610,7 @@ def cluster(
Clustering results.
"""
self.model.eval() # set the model as eval status to freeze it.
- test_set = DatasetForGRUD(X, file_type)
+ test_set = DatasetForGRUD(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
4 changes: 2 additions & 2 deletions pypots/clustering/vader.py
@@ -664,7 +664,7 @@ def fit(
self : object,
Trained classifier.
"""
- training_set = DatasetForGRUD(train_set, file_type)
+ training_set = DatasetForGRUD(train_set, file_type=file_type)
training_loader = DataLoader(
training_set,
batch_size=self.batch_size,
@@ -693,7 +693,7 @@ def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
Clustering results.
"""
self.model.eval() # set the model as eval status to freeze it.
- test_set = DatasetForGRUD(X, file_type)
+ test_set = DatasetForGRUD(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
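Both clustering models reuse DatasetForGRUD, and their cluster() methods now suppress labels even when the input dict carries a y (for instance, ground truth kept around for external cluster validation). A sketch of that behavior with in-memory tensors; illustrative only, not part of the commit.

import torch
from torch.utils.data import DataLoader
from pypots.data.dataset_for_grud import DatasetForGRUD

X = torch.randn(32, 24, 6)
X[torch.rand_like(X) < 0.2] = float("nan")
data = {"X": X, "y": torch.randint(0, 3, (32,))}  # y kept only for evaluation

# return_labels=False keeps y out of the batches fed to the clustering model,
# so every batch has the same label-free schema.
test_set = DatasetForGRUD(data, return_labels=False)
batch = next(iter(DataLoader(test_set, batch_size=8)))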
24 changes: 19 additions & 5 deletions pypots/data/base.py
@@ -29,16 +29,31 @@ class BaseDataset(Dataset):
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
+ return_labels : bool, default = True,
+     Whether to return labels in __getitem__() if they exist in the given data. If `True`, for example
+     during training of classification models, the Dataset class returns labels in __getitem__() for model
+     input. Otherwise, labels are left out of the returned data. This parameter exists because the same
+     Dataset class serves all training/validating/testing stages, and _fetch_data_from_file() works for all
+     three; big datasets stored in h5 files have both X and y saved, but labels should not be read from the
+     file for validating and testing. Hence this parameter makes the distinction.
file_type : str, default = "h5py"
The type of the given file if train_set and val_set are path strings.
"""

- def __init__(self, data: Union[dict, str], file_type: str = "h5py"):
+ def __init__(
+     self,
+     data: Union[dict, str],
+     return_labels: bool = True,
+     file_type: str = "h5py",
+ ):
super().__init__()
# types and shapes had been checked after X and y input into the model
# So they are safe to use here. No need to check again.

self.data = data
self.return_labels = return_labels
if isinstance(self.data, str): # data from file
# check if the given file type is supported
assert (
@@ -194,7 +209,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
missing_mask.to(torch.float32),
]

- if self.y is not None:
+ if self.y is not None and self.return_labels:
sample.append(self.y[idx].to(torch.long))

return sample
@@ -269,9 +284,8 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
missing_mask.to(torch.float32),
]

- if (
-     "y" in self.file_handle.keys()
- ):  # if the dataset has labels, then fetch it from the file
+ # if the dataset has labels and is for training, then fetch it from the file
+ if "y" in self.file_handle.keys() and self.return_labels:
sample.append(self.file_handle["y"][idx].to(torch.long))

return sample
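Stripped of PyPOTS specifics, the gate reduces to a few lines. A self-contained re-statement of the pattern (not the PyPOTS code itself):

import torch
from torch.utils.data import Dataset

class GatedLabelDataset(Dataset):
    """Minimal re-statement of the return_labels gate (illustrative)."""

    def __init__(self, X, y=None, return_labels=True):
        self.X, self.y, self.return_labels = X, y, return_labels

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = [torch.tensor(idx), self.X[idx].to(torch.float32)]
        # Append labels only when they exist AND the caller asked for them,
        # so one Dataset class serves training, validation, and testing alike.
        if self.y is not None and self.return_labels:
            sample.append(self.y[idx].to(torch.long))
        return sample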
24 changes: 19 additions & 5 deletions pypots/data/dataset_for_brits.py
@@ -27,12 +27,26 @@ class DatasetForBRITS(BaseDataset):
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
+ return_labels : bool, default = True,
+     Whether to return labels in __getitem__() if they exist in the given data. If `True`, for example
+     during training of classification models, the Dataset class returns labels in __getitem__() for model
+     input. Otherwise, labels are left out of the returned data. This parameter exists because the same
+     Dataset class serves all training/validating/testing stages, and _fetch_data_from_file() works for all
+     three; big datasets stored in h5 files have both X and y saved, but labels should not be read from the
+     file for validating and testing. Hence this parameter makes the distinction.
file_type : str, default = "h5py"
The type of the given file if train_set and val_set are path strings.
"""

- def __init__(self, data: Union[dict, str], file_type: str = "h5py"):
-     super().__init__(data, file_type)
+ def __init__(
+     self,
+     data: Union[dict, str],
+     return_labels: bool = True,
+     file_type: str = "h5py",
+ ):
+     super().__init__(data, return_labels, file_type)

if not isinstance(self.data, str):
# calculate all delta here.
@@ -96,7 +110,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
self.processed_data["backward"]["delta"][idx].to(torch.float32),
]

- if self.y is not None:
+ if self.y is not None and self.return_labels:
sample.append(self.y[idx].to(torch.long))

return sample
@@ -147,8 +161,8 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
backward["deltas"],
]

- # if the dataset has labels, then fetch it from the file
- if "y" in self.file_handle.keys():
+ # if the dataset has labels and is for training, then fetch it from the file
+ if "y" in self.file_handle.keys() and self.return_labels:
sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long))

return sample
24 changes: 19 additions & 5 deletions pypots/data/dataset_for_grud.py
@@ -29,12 +29,26 @@ class DatasetForGRUD(BaseDataset):
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
+ return_labels : bool, default = True,
+     Whether to return labels in __getitem__() if they exist in the given data. If `True`, for example
+     during training of classification models, the Dataset class returns labels in __getitem__() for model
+     input. Otherwise, labels are left out of the returned data. This parameter exists because the same
+     Dataset class serves all training/validating/testing stages, and _fetch_data_from_file() works for all
+     three; big datasets stored in h5 files have both X and y saved, but labels should not be read from the
+     file for validating and testing. Hence this parameter makes the distinction.
file_type : str, default = "h5py"
The type of the given file if train_set and val_set are path strings.
"""

- def __init__(self, data: Union[dict, str], file_type: str = "h5py"):
-     super().__init__(data, file_type)
+ def __init__(
+     self,
+     data: Union[dict, str],
+     return_labels: bool = True,
+     file_type: str = "h5py",
+ ):
+     super().__init__(data, return_labels, file_type)
self.locf = LOCF()

if not isinstance(self.data, str): # data from array
@@ -86,7 +100,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
self.empirical_mean.to(torch.float32),
]

- if self.y is not None:
+ if self.y is not None and self.return_labels:
sample.append(self.y[idx].to(torch.long))

return sample
@@ -127,8 +141,8 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
empirical_mean,
]

- # if the dataset has labels, then fetch it from the file
- if "y" in self.file_handle.keys():
+ # if the dataset has labels and is for training, then fetch it from the file
+ if "y" in self.file_handle.keys() and self.return_labels:
sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long))

return sample
19 changes: 14 additions & 5 deletions pypots/data/dataset_for_mit.py
@@ -29,6 +29,15 @@ class DatasetForMIT(BaseDataset):
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
+ return_labels : bool, default = True,
+     Whether to return labels in __getitem__() if they exist in the given data. If `True`, for example
+     during training of classification models, the Dataset class returns labels in __getitem__() for model
+     input. Otherwise, labels are left out of the returned data. This parameter exists because the same
+     Dataset class serves all training/validating/testing stages, and _fetch_data_from_file() works for all
+     three; big datasets stored in h5 files have both X and y saved, but labels should not be read from the
+     file for validating and testing. Hence this parameter makes the distinction.
file_type : str, default = "h5py"
The type of the given file if train_set and val_set are path strings.
@@ -44,10 +53,11 @@ class DatasetForMIT(BaseDataset):
def __init__(
self,
data: Union[dict, str],
+ return_labels: bool = True,
file_type: str = "h5py",
rate: float = 0.2,
):
- super().__init__(data, file_type)
+ super().__init__(data, return_labels, file_type)
self.rate = rate

def _fetch_data_from_array(self, idx: int) -> Iterable:
@@ -89,7 +99,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
indicating_mask.to(torch.float32),
]

- if self.y is not None:
+ if self.y is not None and self.return_labels:
sample.append(self.y[idx].to(torch.long))

return sample
@@ -123,9 +133,8 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
indicating_mask.to(torch.float32),
]

- if (
-     "y" in self.file_handle.keys()
- ):  # if the dataset has labels, then fetch it from the file
+ # if the dataset has labels and is for training, then fetch it from the file
+ if "y" in self.file_handle.keys() and self.return_labels:
sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long))

return sample
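DatasetForMIT additionally holds out a fraction rate (default 0.2) of the observed values as reconstruction targets for masked-imputation training, and return_labels slots in before file_type just as elsewhere. A construction sketch with in-memory tensors; illustrative only, not part of the commit.

import torch
from pypots.data.dataset_for_mit import DatasetForMIT

X = torch.randn(32, 24, 6)
X[torch.rand_like(X) < 0.2] = float("nan")
data = {"X": X, "y": torch.randint(0, 2, (32,))}

# Training view: 20% of observed values are artificially masked (rate=0.2)
# to create held-out imputation targets; labels flow through to the model.
train_view = DatasetForMIT(data, return_labels=True, rate=0.2)

# Inference-style view over the same data: identical fields, minus the label.
test_view = DatasetForMIT(data, return_labels=False, rate=0.2)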
6 changes: 3 additions & 3 deletions pypots/imputation/brits.py
@@ -650,7 +650,7 @@ def fit(
The type of the given file if train_set and val_set are path strings.
"""
- training_set = DatasetForBRITS(train_set, file_type)
+ training_set = DatasetForBRITS(train_set, file_type=file_type)
training_loader = DataLoader(
training_set,
batch_size=self.batch_size,
@@ -675,7 +675,7 @@ def fit(
"indicating_mask": hf["indicating_mask"][:],
}

- val_set = DatasetForBRITS(val_set)
+ val_set = DatasetForBRITS(val_set, file_type=file_type)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
@@ -710,7 +710,7 @@ def impute(
Imputed data.
"""
self.model.eval() # set the model as eval status to freeze it.
- test_set = DatasetForBRITS(X)
+ test_set = DatasetForBRITS(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
6 changes: 3 additions & 3 deletions pypots/imputation/saits.py
@@ -334,7 +334,7 @@ def fit(
The type of the given file if train_set and val_set are path strings.
"""
- training_set = DatasetForMIT(train_set, file_type)
+ training_set = DatasetForMIT(train_set, file_type=file_type)
training_loader = DataLoader(
training_set,
batch_size=self.batch_size,
@@ -358,7 +358,7 @@ def fit(
"indicating_mask": hf["indicating_mask"][:],
}

- val_set = BaseDataset(val_set)
+ val_set = BaseDataset(val_set, file_type=file_type)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
@@ -392,7 +392,7 @@ def impute(
Imputed data.
"""
self.model.eval() # set the model as eval status to freeze it.
- test_set = BaseDataset(X, file_type)
+ test_set = BaseDataset(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
6 changes: 3 additions & 3 deletions pypots/imputation/transformer.py
@@ -446,7 +446,7 @@ def fit(
"""

- training_set = DatasetForMIT(train_set, file_type)
+ training_set = DatasetForMIT(train_set, file_type=file_type)
training_loader = DataLoader(
training_set,
batch_size=self.batch_size,
@@ -470,7 +470,7 @@ def fit(
"indicating_mask": hf["indicating_mask"][:],
}

- val_set = BaseDataset(val_set)
+ val_set = BaseDataset(val_set, file_type=file_type)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
@@ -500,7 +500,7 @@ def impute(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
Imputed data.
"""
self.model.eval() # set the model as eval status to freeze it.
- test_set = BaseDataset(X, file_type)
+ test_set = BaseDataset(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
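End to end, an imputation model now exercises both paths: fit() keeps labels available (harmless when they are absent), while impute() forces return_labels=False. A usage sketch in the style of the PyPOTS README of this era; the SAITS hyperparameter names are assumptions to verify against the installed version.

import numpy as np
from pypots.imputation import SAITS

X = np.random.randn(64, 48, 4)
X[np.random.rand(*X.shape) < 0.15] = np.nan  # 15% missing completely at random

saits = SAITS(n_steps=48, n_features=4, n_layers=2, d_model=64, d_inner=64,
              n_head=4, d_k=16, d_v=16, dropout=0.1, epochs=3)
saits.fit({"X": X})               # internally DatasetForMIT(..., file_type=file_type)
imputed = saits.impute({"X": X})  # internally BaseDataset(X, return_labels=False, ...)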
