Add deepar #192

Open · wants to merge 24 commits into master
2 changes: 2 additions & 0 deletions .circleci/config.yml
@@ -246,6 +246,8 @@ jobs:
coverage run flood_forecast/trainer.py -p tests/multitask_decoder.json
echo -e 'running da-meta data unit test'
coverage run flood_forecast/trainer.py -p tests/da_meta.json
echo -e 'running Deep_AR_test \n'
coverage run flood_forecast/trainer.py -p tests/DeepAR_test.json
echo -e 'running transformer bottleneck'
coverage run flood_forecast/trainer.py -p tests/transformer_bottleneck.json
echo -e 'running da_rnn probabilistic test'
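Note: the new CI step mirrors the existing unit-test steps above it, so it can be reproduced locally with the same command (assuming the repo's dependencies and `coverage` are installed):

```bash
coverage run flood_forecast/trainer.py -p tests/DeepAR_test.json
```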
6 changes: 5 additions & 1 deletion .gitignore
@@ -10,4 +10,8 @@ tests/output/
data
mypy
.mypy_cache
*.png
*.png
.vscode
.vscode/*
.idea
.idea/
18 changes: 18 additions & 0 deletions flood_forecast/deep_ar/config/lstm_kwargs.json
@@ -0,0 +1,18 @@
{
"batch_size": 64,
"cov_dim": 4,
"embedding_dim": 20,
"learning_rate": 1e-3,
"lstm_dropout": 0.1,
"lstm_hidden_dim": 40,
"lstm_layers": 3,
"num_class": 370,
"num_epochs": 20,
"predict_batch": 256,
"predict_start": 168,
"predict_steps": 24,
"sample_times": 200,
"test_predict_start": 168,
"test_window": 192,
"train_window": 192
}
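Note: `lstm_kwargs.json` carries more keys than `DeepAR.__init__` (defined in `flood_forecast/deep_ar/model.py` below) accepts, e.g. `batch_size` and `num_epochs`, so anything loading it would need to filter first. A minimal sketch of that wiring; the loader itself is not part of this PR:

```python
import json

from flood_forecast.deep_ar.model import DeepAR

# Keys DeepAR.__init__ accepts; the rest of lstm_kwargs.json (batch_size,
# num_epochs, learning_rate, ...) is training/inference configuration.
INIT_KEYS = {"num_class", "cov_dim", "lstm_dropout", "embedding_dim",
             "lstm_hidden_dim", "lstm_layers", "sample_times",
             "predict_steps", "predict_start"}

with open("flood_forecast/deep_ar/config/lstm_kwargs.json") as f:
    kwargs = json.load(f)
model = DeepAR(**{k: v for k, v in kwargs.items() if k in INIT_KEYS})
```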
137 changes: 137 additions & 0 deletions flood_forecast/deep_ar/model.py
@@ -0,0 +1,137 @@

import torch
import torch.nn as nn


class DeepAR(nn.Module):
def __init__(self,
num_class: int,
cov_dim: int,
lstm_dropout: float,
embedding_dim: int,
lstm_hidden_dim: int,
lstm_layers: int,
sample_times: int,
predict_steps: int,
predict_start: int
):
"""Initialize the DeepAR model.

        :param num_class: Number of distinct time series ids (the size of the embedding vocabulary)
        :param cov_dim: Number of covariates
        :param lstm_dropout: Dropout rate applied between LSTM layers
        :param embedding_dim: Dimension of the time series id embedding
        :param lstm_hidden_dim: Hidden dimension of the LSTM
        :param lstm_layers: Number of LSTM layers
        :param sample_times: Number of sampled trajectories drawn at prediction time
        :param predict_steps: Number of steps to predict
        :param predict_start: Step to start prediction at
"""
super(DeepAR, self).__init__()
self.params = {}

self.params["num_class"] = num_class
self.params["cov_dim"] = cov_dim
self.params["lstm_dropout"] = lstm_dropout
self.params["embedding_dim"] = embedding_dim
self.params["lstm_hidden_dim"] = lstm_hidden_dim
self.params["lstm_layers"] = lstm_layers
self.params["sample_times"] = sample_times
self.params["predict_steps"] = predict_steps
self.params["predict_start"] = predict_start
self.embedding = nn.Embedding(self.params["num_class"], self.params["embedding_dim"])

self.lstm = nn.LSTM(input_size=1 + self.params["cov_dim"] + self.params["embedding_dim"],
hidden_size=self.params["lstm_hidden_dim"],
num_layers=self.params["lstm_layers"],
bias=True,
batch_first=False,
dropout=self.params["lstm_dropout"])
        # initialize LSTM forget-gate biases to 1, as recommended by
        # Jozefowicz et al. (http://proceedings.mlr.press/v37/jozefowicz15.pdf)
        for names in self.lstm._all_weights:
            for name in filter(lambda n: "bias" in n, names):
                bias = getattr(self.lstm, name)
                n = bias.size(0)
                # PyTorch packs LSTM biases as [b_ii | b_if | b_ig | b_io],
                # so the second quarter is the forget gate
                start, end = n // 4, n // 2
                bias.data[start:end].fill_(1.)

self.relu = nn.ReLU()
self.distribution_mu = nn.Linear(self.params["lstm_hidden_dim"] * self.params["lstm_layers"], 1)
self.distribution_presigma = nn.Linear(self.params["lstm_hidden_dim"] * self.params["lstm_layers"], 1)
self.distribution_sigma = nn.Softplus()
        # persistent LSTM state, shape [lstm_layers, batch_size, lstm_hidden_dim];
        # re-initialized in forward() whenever the incoming batch size changes
        self.cell = self.init_cell(1)
        self.hidden = self.init_hidden(1)

    def forward(self, x: torch.Tensor, idx: torch.Tensor = None, hidden: torch.Tensor = None, cell: torch.Tensor = None):
'''
Predict mu and sigma of the distribution for z_t.
Args:
x: ([1, batch_size, 1+cov_dim]): z_{t-1} + x_t, note that z_0 = 0
idx ([1, batch_size]): one integer denoting the time series id
hidden ([lstm_layers, batch_size, lstm_hidden_dim]): LSTM h from time step t-1
cell ([lstm_layers, batch_size, lstm_hidden_dim]): LSTM c from time step t-1
Returns:
mu ([batch_size]): estimated mean of z_t
sigma ([batch_size]): estimated standard deviation of z_t
hidden ([lstm_layers, batch_size, lstm_hidden_dim]): LSTM h from time step t
cell ([lstm_layers, batch_size, lstm_hidden_dim]): LSTM c from time step t
'''
        if idx is None:
            # default every sample to time series id 0 when no id is supplied
            idx = torch.zeros(1, x.shape[1], dtype=torch.long, device=x.device)
        if hidden is None or cell is None:
            # fall back to the persistent state, re-initializing it if the batch size changed
            if self.hidden.shape[1] != x.shape[1]:
                self.hidden = self.init_hidden(x.shape[1])
                self.cell = self.init_cell(x.shape[1])
            hidden, cell = self.hidden, self.cell
        onehot_embed = self.embedding(idx)  # TODO: is it possible to do this only once per window instead of per step?
        lstm_input = torch.cat((x, onehot_embed), dim=2)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        self.cell = cell
        self.hidden = hidden
        # use h from all LSTM layers to calculate mu and sigma
hidden_permute = hidden.permute(1, 2, 0).contiguous().view(hidden.shape[1], -1)
pre_sigma = self.distribution_presigma(hidden_permute)
mu = self.distribution_mu(hidden_permute)
sigma = self.distribution_sigma(pre_sigma) # softplus to make sure standard deviation is positive
return torch.squeeze(mu), torch.squeeze(sigma), hidden, cell

    def init_hidden(self, batch_size):
        return torch.zeros(self.params["lstm_layers"], batch_size, self.params["lstm_hidden_dim"])

    def init_cell(self, batch_size):
        return torch.zeros(self.params["lstm_layers"], batch_size, self.params["lstm_hidden_dim"])

    def test(self, x, v_batch, id_batch, hidden, cell, sampling=False):
        """Roll the model forward over the prediction window.

        With sampling=True, draw sample_times trajectories by ancestral sampling and
        return (samples, sample_mu, sample_sigma); otherwise feed the predicted mean
        back in at each step and return (sample_mu, sample_sigma). v_batch holds the
        per-series scale (column 0) and shift (column 1) used to un-normalize outputs.
        """
batch_size = x.shape[1]
if sampling:
samples = torch.zeros(self.params["sample_times"], batch_size, self.params["predict_steps"])
for j in range(self.params["sample_times"]):
decoder_hidden = hidden
decoder_cell = cell
for t in range(self.params["predict_steps"]):
mu_de, sigma_de, decoder_hidden, decoder_cell = self(
x[self.params["predict_start"] + t].unsqueeze(0),
id_batch, decoder_hidden, decoder_cell)
gaussian = torch.distributions.normal.Normal(mu_de, sigma_de)
pred = gaussian.sample() # not scaled
samples[j, :, t] = pred * v_batch[:, 0] + v_batch[:, 1]
if t < (self.params["predict_steps"] - 1):
x[self.params["predict_start"] + t + 1, :, 0] = pred

sample_mu = torch.median(samples, dim=0)[0]
sample_sigma = samples.std(dim=0)
return samples, sample_mu, sample_sigma

else:
decoder_hidden = hidden
decoder_cell = cell
sample_mu = torch.zeros(batch_size, self.params["predict_steps"])
sample_sigma = torch.zeros(batch_size, self.params["predict_steps"])
for t in range(self.params["predict_steps"]):
mu_de, sigma_de, decoder_hidden, decoder_cell = self(x[self.params["predict_start"] + t].unsqueeze(0),
id_batch, decoder_hidden, decoder_cell)
sample_mu[:, t] = mu_de * v_batch[:, 0] + v_batch[:, 1]
sample_sigma[:, t] = sigma_de * v_batch[:, 0]
if t < (self.params["predict_steps"] - 1):
x[self.params["predict_start"] + t + 1, :, 0] = mu_de
return sample_mu, sample_sigma
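Note: a quick shape smoke test for the forward pass (a sketch with arbitrary parameter values, assuming the signature above that accepts explicit `hidden` and `cell`):

```python
import torch

from flood_forecast.deep_ar.model import DeepAR

batch_size, cov_dim = 4, 3
model = DeepAR(num_class=1, cov_dim=cov_dim, lstm_dropout=0.1,
               embedding_dim=10, lstm_hidden_dim=40, lstm_layers=2,
               sample_times=10, predict_steps=24, predict_start=5)

x = torch.zeros(1, batch_size, 1 + cov_dim)         # z_{t-1} plus covariates for one step
idx = torch.zeros(1, batch_size, dtype=torch.long)  # every sample uses series id 0
hidden = model.init_hidden(batch_size)
cell = model.init_cell(batch_size)

mu, sigma, hidden, cell = model(x, idx, hidden, cell)
print(mu.shape, sigma.shape)  # torch.Size([4]) torch.Size([4])
```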
4 changes: 3 additions & 1 deletion flood_forecast/model_dict_function.py
@@ -14,6 +14,7 @@
from flood_forecast.transformer_xl.transformer_bottleneck import DecoderTransformer
from flood_forecast.custom.dilate_loss import DilateLoss
from flood_forecast.meta_models.basic_ae import AE
from flood_forecast.deep_ar.model import DeepAR
import torch

"""
@@ -29,7 +30,8 @@
"CustomTransformerDecoder": CustomTransformerDecoder,
"DARNN": DARNN,
"DecoderTransformer": DecoderTransformer,
"BasicAE": AE
"BasicAE": AE,
"DeepAR": DeepAR
}

pytorch_criterion_dict = {
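Note: for context, this dictionary is what resolves the `"model_name"` field of a config file into a model class, so the one-line entry is all that is needed for the trainer to load `"DeepAR"`. Roughly (assuming the dictionary shown here is the module's `pytorch_model_dict`):

```python
from flood_forecast.model_dict_function import pytorch_model_dict

model_cls = pytorch_model_dict["DeepAR"]  # resolved from "model_name" in the JSON config
model = model_cls(cov_dim=3, embedding_dim=10, lstm_dropout=0.1,
                  lstm_hidden_dim=40, lstm_layers=2, num_class=1,
                  predict_start=5, predict_steps=24, sample_times=200)
```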
75 changes: 75 additions & 0 deletions tests/DeepAR_test.json
@@ -0,0 +1,75 @@
{
"model_name": "DeepAR",
"model_type": "PyTorch",
"model_params": {
"cov_dim": 3,
"embedding_dim": 10,
"lstm_dropout": 0.1,
"lstm_hidden_dim": 40,
"lstm_layers": 2,
"num_class": 1,
"predict_start": 5,
"predict_steps": 24,
"sample_times": 200
},
"dataset_params":
{ "class": "default",
"training_path": "tests/test_data/keag_small.csv",
"validation_path": "tests/test_data/keag_small.csv",
"test_path": "tests/test_data/keag_small.csv",
"batch_size":4,
"forecast_history":5,
"forecast_length":1,
"train_end": 100,
"valid_start":301,
"valid_end": 401,
"test_predict_start":125,
"test_end":500,
"target_col": ["cfs"],
"relevant_cols": ["cfs", "precip", "temp"],
"interpolate": false,
"interpolate_param": false
},
"early_stopping":{
"patience":1

},

"training_params":
{
"criterion":"NegativeLogLikelihood",
"probabilistic": true,
"optimizer": "Adam",
"optim_params":
{
},
"lr": 0.3,
"epochs": 1,
"batch_size":4

},

"GCS": false,

"wandb": {
"name": "flood_forecast_circleci",
"tags": ["dummy_run", "circleci", "DeepAR"],
"project": "repo-flood_forecast"
},
"forward_params":{},
"metrics":["MSE"],
"inference_params":
{ "num_prediction_samples": 10,
"datetime_start":"2016-05-31",
"hours_to_forecast":336,
"test_csv_path":"tests/test_data/keag_small.csv",
"dataset_params":{
"file_path": "tests/test_data/keag_small.csv",
"forecast_history":5,
"forecast_length":1,
"relevant_cols": ["cfs", "precip", "temp"],
"target_col": ["cfs"],
"interpolate_param": false
}
}
}
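Note: the config trains with `"criterion": "NegativeLogLikelihood"` and `"probabilistic": true`, pairing DeepAR's `(mu, sigma)` outputs with the observed targets. A sketch of the Gaussian negative log-likelihood such a criterion computes (an illustration, not the repo's implementation):

```python
import torch

def gaussian_nll(mu: torch.Tensor, sigma: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Average negative log-likelihood of the targets z under N(mu, sigma)."""
    dist = torch.distributions.normal.Normal(mu, sigma)
    return -dist.log_prob(z).mean()

# usage: loss = gaussian_nll(mu, sigma, targets); loss.backward()
```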