-
-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add model templates for each task;
- Loading branch information
Showing
22 changed files
with
735 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
prune pypots/tests | ||
prune pypots/*/template |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# The template for new models to be included | ||
|
||
Congrats! You've made it this far! | ||
We really appreciate that you have taken the time to commit your model to PyPOTS! | ||
Once your model gets included in PyPOTS, it will be available to all PyPOTS users. | ||
Your research work will reach out to more people and be more impactful. | ||
|
||
This template is created to help you quickly get started with your model's inclusion. | ||
Your model's main body should be implemented in the `model.py` file. | ||
If your model consists of multiple modules, put them in the `modules.py`. | ||
`dataset.py` should contain the Dataset class assembling the input data for your model. | ||
|
||
Please follow the instructions below to complete your model's inclusion: | ||
|
||
1. Rename this folder `template` to your model's name; | ||
2. Implement your model according to the `TODO` comments, add necessary comments and docstrings, | ||
write necessary tests and run them on your local machine to ensure everything works well; | ||
3. Delete this README file and all TODO comments; | ||
4. Raise an issue first to request add your new model, then make a PR to commit your code the `dev` branch of PyPOTS; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
""" | ||
The package of the classification model "Your model name". | ||
Refer to the paper "Your paper citation". | ||
TODO: modify the above description with your model's information. | ||
""" | ||
|
||
# Created by Your Name <Your contact email> TODO: modify the author information. | ||
# License: GLP-v3 | ||
|
||
# TODO: ensure the import is correct | ||
from pypots.classification.template.model import YourNewModel | ||
|
||
__all__ = [ | ||
"YourNewModel", # TODO: ensure the name is correct | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
Dataset class for YourNewModel. | ||
TODO: modify the above description with your model's information. | ||
""" | ||
|
||
# Created by Your Name <Your contact email> TODO: modify the author information. | ||
# License: GLP-v3 | ||
|
||
from typing import Union, Iterable | ||
|
||
from pypots.data.base import BaseDataset | ||
|
||
|
||
class DatasetForYourNewModel(BaseDataset): | ||
def __init__( | ||
self, | ||
data: Union[dict, str], | ||
return_labels: bool = True, | ||
file_type: str = "h5py", | ||
): | ||
super().__init__(data, return_labels, file_type) | ||
|
||
def _fetch_data_from_array(self, idx: int) -> Iterable: | ||
raise NotImplementedError | ||
|
||
def _fetch_data_from_file(self, idx: int) -> Iterable: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
""" | ||
The implementation of YourNewModel for the partially-observed time-series classification task. | ||
Refer to the paper "Your paper citation". | ||
""" | ||
|
||
# Created by Your Name <Your contact email> TODO: modify the author information. | ||
# License: GLP-v3 | ||
|
||
from typing import Union, Optional | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
# TODO: import the base class from the classification package in PyPOTS. | ||
# Here I suppose this is a neural-network classification model. | ||
# You should make your model inherent BaseClassifier if it is not a NN. | ||
# from pypots.classification.base import BaseClassifier | ||
from pypots.classification.base import BaseNNClassifier | ||
|
||
from pypots.optim.adam import Adam | ||
from pypots.optim.base import Optimizer | ||
|
||
|
||
# TODO: define your new model here. | ||
# It could be a neural network model or a non-neural network algorithm (e.g. written in numpy). | ||
# Your model should be implemented with PyTorch and subclass torch.nn.Module if it is a neural network. | ||
# Note that your main algorithm is defined in this class, and this class usually won't be exposed to users. | ||
class _YourNewModel(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, inputs: dict) -> dict: | ||
# TODO: define your model's forward propagation process here. | ||
# The input is a dict, and the output `results` should also be a dict. | ||
# `results` must contains the key `loss` which is will be used for backward propagation to update the model. | ||
|
||
loss = None | ||
results = { | ||
"loss": loss, | ||
} | ||
return results | ||
|
||
|
||
# TODO: define your new model's wrapper here. | ||
# It should be a subclass of a base class defined in PyPOTS task packages (e.g. | ||
# BaseNNClassifier of PyPOTS classification task package). It has to implement all abstract methods of the base class. | ||
# Note that this class is a wrapper of your new model and will be directly exposed to users. | ||
class YourNewModel(BaseNNClassifier): | ||
def __init__( | ||
self, | ||
# TODO: add your model's hyper-parameters here | ||
n_classes: int, | ||
batch_size: int, | ||
epochs: int, | ||
patience: int, | ||
num_workers: int = 0, | ||
optimizer: Optional[Optimizer] = Adam(), | ||
device: Optional[Union[str, torch.device]] = None, | ||
saving_path: str = None, | ||
model_saving_strategy: Optional[str] = "best", | ||
): | ||
super().__init__( | ||
n_classes, | ||
batch_size, | ||
epochs, | ||
patience, | ||
num_workers, | ||
device, | ||
saving_path, | ||
model_saving_strategy, | ||
) | ||
# set up the hyper-parameters | ||
# TODO: set up your model's hyper-parameters here | ||
|
||
# set up the model | ||
self.model = _YourNewModel() | ||
self.model = self.model.to(self.device) | ||
self._print_model_size() | ||
|
||
# set up the optimizer | ||
self.optimizer = optimizer | ||
self.optimizer.init_optimizer(self.model.parameters()) | ||
|
||
def _assemble_input_for_training(self, data: list) -> dict: | ||
raise NotImplementedError | ||
|
||
def _assemble_input_for_validating(self, data: list) -> dict: | ||
raise NotImplementedError | ||
|
||
def _assemble_input_for_testing(self, data: list) -> dict: | ||
raise NotImplementedError | ||
|
||
def fit( | ||
self, | ||
train_set: Union[dict, str], | ||
val_set: Optional[Union[dict, str]] = None, | ||
file_type: str = "h5py", | ||
) -> None: | ||
raise NotImplementedError | ||
|
||
def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
""" | ||
The implementation of the modules for YourNewModel. | ||
Refer to the paper "Your paper citation". | ||
""" | ||
|
||
# Created by Your Name <Your contact email> TODO: modify the author information. | ||
# License: GLP-v3 | ||
|
||
|
||
# TODO: this file is not necessary. If your new model has customized layers or modules, please put them here. | ||
# Otherwise, please delete this modules.py file, don't commit it to the repository. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# The template for new models to be included | ||
|
||
Congrats! You've made it this far! | ||
We really appreciate that you have taken the time to commit your model to PyPOTS! | ||
Once your model gets included in PyPOTS, it will be available to all PyPOTS users. | ||
Your research work will reach out to more people and be more impactful. | ||
|
||
This template is created to help you quickly get started with your model's inclusion. | ||
Your model's main body should be implemented in the `model.py` file. | ||
If your model consists of multiple modules, put them in the `modules.py`. | ||
`dataset.py` should contain the Dataset class assembling the input data for your model. | ||
|
||
Please follow the instructions below to complete your model's inclusion: | ||
|
||
1. Rename this folder `template` to your model's name; | ||
2. Implement your model according to the `TODO` comments, add necessary comments and docstrings, | ||
write necessary tests and run them on your local machine to ensure everything works well; | ||
3. Delete this README file and all TODO comments; | ||
4. Raise an issue first to request add your new model, then make a PR to commit your code the `dev` branch of PyPOTS; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
""" | ||
The package of the clustering model "Your model name". | ||
Refer to the paper "Your paper citation". | ||
TODO: modify the above description with your model's information. | ||
""" | ||
|
||
# Created by Your Name <Your contact email> TODO: modify the author information. | ||
# License: GLP-v3 | ||
|
||
# TODO: ensure the import is correct | ||
from pypots.clustering.template.model import YourNewModel | ||
|
||
__all__ = [ | ||
"YourNewModel", # TODO: ensure the name is correct | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
Dataset class for YourNewModel. | ||
TODO: modify the above description with your model's information. | ||
""" | ||
|
||
# Created by Your Name <Your contact email> TODO: modify the author information. | ||
# License: GLP-v3 | ||
|
||
from typing import Union, Iterable | ||
|
||
from pypots.data.base import BaseDataset | ||
|
||
|
||
class DatasetForYourNewModel(BaseDataset): | ||
def __init__( | ||
self, | ||
data: Union[dict, str], | ||
return_labels: bool = True, | ||
file_type: str = "h5py", | ||
): | ||
super().__init__(data, return_labels, file_type) | ||
|
||
def _fetch_data_from_array(self, idx: int) -> Iterable: | ||
raise NotImplementedError | ||
|
||
def _fetch_data_from_file(self, idx: int) -> Iterable: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
""" | ||
The implementation of YourNewModel for the partially-observed time-series clustering task. | ||
Refer to the paper "Your paper citation". | ||
""" | ||
|
||
# Created by Your Name <Your contact email> TODO: modify the author information. | ||
# License: GLP-v3 | ||
|
||
from typing import Union, Optional | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
# TODO: import the base class from the clustering package in PyPOTS. | ||
# Here I suppose this is a neural-network clustering model. | ||
# You should make your model inherent BaseClusterer if it is not a NN. | ||
# from pypots.clustering.base import BaseClusterer | ||
from pypots.clustering.base import BaseNNClusterer | ||
|
||
from pypots.optim.adam import Adam | ||
from pypots.optim.base import Optimizer | ||
|
||
|
||
# TODO: define your new model here. | ||
# It could be a neural network model or a non-neural network algorithm (e.g. written in numpy). | ||
# Your model should be implemented with PyTorch and subclass torch.nn.Module if it is a neural network. | ||
# Note that your main algorithm is defined in this class, and this class usually won't be exposed to users. | ||
class _YourNewModel(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, inputs: dict) -> dict: | ||
# TODO: define your model's forward propagation process here. | ||
# The input is a dict, and the output `results` should also be a dict. | ||
# `results` must contains the key `loss` which is will be used for backward propagation to update the model. | ||
|
||
loss = None | ||
results = { | ||
"loss": loss, | ||
} | ||
return results | ||
|
||
|
||
# TODO: define your new model's wrapper here. | ||
# It should be a subclass of a base class defined in PyPOTS task packages (e.g. | ||
# BaseNNClusterer of PyPOTS clustering task package), and it has to implement all abstract methods of the base class. | ||
# Note that this class is a wrapper of your new model and will be directly exposed to users. | ||
class YourNewModel(BaseNNClusterer): | ||
def __init__( | ||
self, | ||
# TODO: add your model's hyper-parameters here | ||
n_clusters: int, | ||
batch_size: int, | ||
epochs: int, | ||
patience: int, | ||
num_workers: int = 0, | ||
optimizer: Optional[Optimizer] = Adam(), | ||
device: Optional[Union[str, torch.device]] = None, | ||
saving_path: str = None, | ||
model_saving_strategy: Optional[str] = "best", | ||
): | ||
super().__init__( | ||
n_clusters, | ||
batch_size, | ||
epochs, | ||
patience, | ||
num_workers, | ||
device, | ||
saving_path, | ||
model_saving_strategy, | ||
) | ||
# set up the hyper-parameters | ||
# TODO: set up your model's hyper-parameters here | ||
|
||
# set up the model | ||
self.model = _YourNewModel() | ||
self.model = self.model.to(self.device) | ||
self._print_model_size() | ||
|
||
# set up the optimizer | ||
self.optimizer = optimizer | ||
self.optimizer.init_optimizer(self.model.parameters()) | ||
|
||
def _assemble_input_for_training(self, data: list) -> dict: | ||
raise NotImplementedError | ||
|
||
def _assemble_input_for_validating(self, data: list) -> dict: | ||
raise NotImplementedError | ||
|
||
def _assemble_input_for_testing(self, data: list) -> dict: | ||
raise NotImplementedError | ||
|
||
def fit( | ||
self, | ||
train_set: Union[dict, str], | ||
val_set: Optional[Union[dict, str]] = None, | ||
file_type: str = "h5py", | ||
) -> None: | ||
raise NotImplementedError | ||
|
||
def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
""" | ||
The implementation of the modules for YourNewModel. | ||
Refer to the paper "Your paper citation". | ||
""" | ||
|
||
# Created by Your Name <Your contact email> TODO: modify the author information. | ||
# License: GLP-v3 | ||
|
||
|
||
# TODO: this file is not necessary. If your new model has customized layers or modules, please put them here. | ||
# Otherwise, please delete this modules.py file, don't commit it to the repository. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.