Skip to content

Commit

Permalink
update for moirai-moe
Browse files Browse the repository at this point in the history
  • Loading branch information
liuxu77 committed Nov 1, 2024
1 parent 959622c commit 1c32b6e
Show file tree
Hide file tree
Showing 15 changed files with 627 additions and 48 deletions.
30 changes: 18 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
# Unified Training of Universal Time Series Forecasting Transformers
[Paper](https://arxiv.org/abs/2402.02592) | [Blog Post](https://blog.salesforceairesearch.com/moirai/)
# Unified Training of Universal Time Series Transformers

Uni2TS is a PyTorch based library for research and applications related to Time Series Transformers.
This library aims to provide a unified solution to large-scale pre-training of Universal Time Series Transformers.
Uni2TS also provides tools for fine-tuning, inference, and evaluation for time series forecasting.
Uni2TS is a PyTorch based library for research and applications related to Time Series Forecasting. It provides a unified framework for large-scale pre-training, fine-tuning, inference, and evaluation of Universal Time Series Transformers.

Related reading: [Moirai Paper](https://arxiv.org/abs/2402.02592), [Moirai Salesforce Blog](https://blog.salesforceairesearch.com/moirai/), [Moirai-MoE Paper](https://arxiv.org/abs/2410.10469), [Moirai-MoE AI Horizon Forecast Blog](https://aihorizonforecast.substack.com/p/moirai-moe-upgrading-moirai-with), [Moirai-MoE Jiqizhixin Blog](https://mp.weixin.qq.com/s/LQvlgxx9vU965Yzy6RuBfQ).

## 🎉 What's New

* Oct 2024: A new model Moirai-MoE! The preprint is now available on [arXiv](https://arxiv.org/abs/2410.10469). Model weights to be released soon.
* Oct 2024: A new model Moirai-MoE! The preprint is available on [arXiv](https://arxiv.org/abs/2410.10469), along with model weights of [small](https://huggingface.co/Salesforce/moirai-moe-1.0-R-small) and [base](https://huggingface.co/Salesforce/moirai-moe-1.0-R-base), and [simple example](https://github.com/SalesforceAIResearch/uni2ts/project/moirai-moe-1) to get started.

* Jun 2024: Released Moirai-1.1-R model weights in [small](https://huggingface.co/Salesforce/moirai-1.1-R-small), [base](https://huggingface.co/Salesforce/moirai-1.1-R-base), and [large](https://huggingface.co/Salesforce/moirai-1.1-R-large).

* May 2024: The Uni2TS paper has been accepted to ICML 2024 as an Oral presentation!
* May 2024: The [Moirai Paper](https://arxiv.org/abs/2402.02592) has been accepted to ICML 2024 as an Oral presentation!

* Mar 2024: Release of Uni2TS library, along with [Moirai-1.0-R](https://huggingface.co/collections/Salesforce/moirai-10-r-models-65c8d3a94c51428c300e0742) and [LOTSA data](https://huggingface.co/datasets/Salesforce/lotsa_data/)!
* Mar 2024: Release of Uni2TS library, along with [Moirai Paper](https://arxiv.org/abs/2402.02592), [Moirai-1.0-R Models](https://huggingface.co/collections/Salesforce/moirai-10-r-models-65c8d3a94c51428c300e0742), and [LOTSA Data](https://huggingface.co/datasets/Salesforce/lotsa_data/).

## ✅ TODO

Expand Down Expand Up @@ -230,15 +229,22 @@ python -m cli.train \
data=lotsa_v1_unweighted
```

## 👀 Citing Uni2TS
## 👀 Citation

If you're using Uni2TS in your research or applications, please cite it using this BibTeX:
If you're using this repository in your research or applications, please cite using the following BibTeX:

```markdown
@article{woo2024unified,
@article{liu2024moiraimoe,
title={Moirai-MoE: Empowering Time Series Foundation Models with Sparse Mixture of Experts},
author={Liu, Xu and Liu, Juncheng and Woo, Gerald and Aksu, Taha and Liang, Yuxuan and Zimmermann, Roger and Liu, Chenghao and Savarese, Silvio and Xiong, Caiming and Sahoo, Doyen},
journal={arXiv preprint arXiv:2410.10469},
year={2024}
}

@inproceedings{woo2024unified,
title={Unified Training of Universal Time Series Forecasting Transformers},
author={Woo, Gerald and Liu, Chenghao and Kumar, Akshat and Xiong, Caiming and Savarese, Silvio and Sahoo, Doyen},
journal={arXiv preprint arXiv:2402.02592},
booktitle={Forty-first International Conference on Machine Learning},
year={2024}
}
```
8 changes: 8 additions & 0 deletions cli/conf/eval/model/moirai_moe_1.0_R_base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
_target_: uni2ts.model.moirai.MoiraiForecast
module:
_target_: uni2ts.model.moirai.MoiraiMoEModule.from_pretrained
pretrained_model_name_or_path: Salesforce/moirai-moe-1.0-R-base
mode: autoregressive
num_samples: 100
patch_size: 16
context_length: ???
8 changes: 8 additions & 0 deletions cli/conf/eval/model/moirai_moe_1.0_R_small.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
_target_: uni2ts.model.moirai.MoiraiForecast
module:
_target_: uni2ts.model.moirai.MoiraiMoEModule.from_pretrained
pretrained_model_name_or_path: Salesforce/moirai-moe-1.0-R-small
mode: autoregressive
num_samples: 100
patch_size: 16
context_length: ???
99 changes: 99 additions & 0 deletions project/moirai-moe-1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Moirai-MoE-1.0-R

Our paper [Moirai-MoE: Empowering Time Series Foundation Models with Sparse Mixture of Experts](https://arxiv.org/abs/2410.10469) introduces the first mixture-of-experts time series foundation model.

The figure below presents the major difference between Moirai-MoE and Moirai. Compared to Moirai using multi-heuristic-defined input/output projection layers to model time series with different frequencies, Moirai-MoE utilizes a single input/output projection layer while delegating the task of capturing diverse time series patterns to the sparse mixture of experts Transformers. With these designs, the specialization of Moirai-MoE is achieved in a data-driven manner and operates at the token level.

<p align="center">
<img src="./img/framework.png" height="200" alt="" align=center />
</p>


## Models

The pre-trained weights of Moirai-MoE can be found in the following table.

| Model | # Activated Parameters | # Total Parameters |
| :---: | :---: | :---: |
| [Moirai-MoE-1.0-R-Small](https://huggingface.co/Salesforce/moirai-moe-1.0-R-small) | 11m | 117m |
| [Moirai-MoE-1.0-R-Base](https://huggingface.co/Salesforce/moirai-moe-1.0-R-base) | 86m | 935m |


## Usage

Let's see a simple example on how to use pre-trained Moirai-MoE models to make forecasts.

```python
import matplotlib.pyplot as plt
from gluonts.dataset.repository import dataset_recipes

from uni2ts.eval_util.data import get_gluonts_test_dataset
from uni2ts.eval_util.plot import plot_next_multi
from uni2ts.model.moirai import MoiraiForecast, MoiraiMoEModule

SIZE = "small" # model size: choose from {'small', 'base'}
CTX = 1000 # context length: any positive integer
BSZ = 32 # batch size: any positive integer

# Load dataset
test_data, metadata = get_gluonts_test_dataset(
"electricity", prediction_length=None, regenerate=False
)
# Uncomment the below line to find other datasets
# print(sorted(dataset_recipes.keys()))

# Prepare model
model = MoiraiForecast(
module=MoiraiMoEModule.from_pretrained(
f"Salesforce/moirai-moe-1.0-R-{SIZE}",
),
mode="autoregressive",
prediction_length=metadata.prediction_length,
context_length=CTX,
patch_size=16,
num_samples=100,
target_dim=metadata.target_dim,
feat_dynamic_real_dim=metadata.feat_dynamic_real_dim,
past_feat_dynamic_real_dim=metadata.past_feat_dynamic_real_dim,
)

predictor = model.create_predictor(batch_size=BSZ)
forecasts = predictor.predict(test_data.input)

input_it = iter(test_data.input)
label_it = iter(test_data.label)
forecast_it = iter(forecasts)

# Visualize forecasts
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(25, 10))
plot_next_multi(
axes,
input_it,
label_it,
forecast_it,
context_length=200,
intervals=(0.5, 0.9),
dim=None,
name="pred",
show_label=True,
)
```


## Results

Extensive experiments on 39 datasets demonstrate the superiority of Moirai-MoE over existing foundation models in both in-distribution and zero-shot scenarios.

<p align="center">
<img src="./img/in-dist.png" height="200" alt="" align=center />
</p>

The above figure presents the in-distribution evaluation using a total of 29 datasets from the Monash benchmark. The evaluation results show that Moirai-MoE beats all competitors.

<p align="center">
<img src="./img/zero-shot.png" height="200" alt="" align=center />
</p>

The above table shows a zero-shot forecasting evaluation on 10 datasets and Moirai-MoE-Base achieves the best zero-shot performance.

We will soon release scripts to reproduce the results.
Binary file added project/moirai-moe-1/img/framework.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added project/moirai-moe-1/img/in-dist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added project/moirai-moe-1/img/zero-shot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 12 additions & 0 deletions src/uni2ts/common/torch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ def packed_attention_mask(
return attention_mask


def packed_causal_attention_mask(
sample_id: Int[torch.Tensor, "*batch seq_len"],
time_id: Int[torch.Tensor, "*batch seq_len"]
) -> Bool[torch.Tensor, "*batch seq_len seq_len"]:
attention_mask = packed_attention_mask(sample_id)
expanded_id1 = time_id.unsqueeze(-2)
expanded_id2 = time_id.unsqueeze(-1)
compare_res = expanded_id1 <= expanded_id2
attention_mask = attention_mask * compare_res
return attention_mask


def mask_fill(
tensor: Float[torch.Tensor, "*batch dim"],
mask: Bool[torch.Tensor, "*batch"],
Expand Down
3 changes: 2 additions & 1 deletion src/uni2ts/model/moirai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .finetune import MoiraiFinetune
from .forecast import MoiraiForecast
from .module import MoiraiModule
from .module_moe import MoiraiMoEModule
from .pretrain import MoiraiPretrain

__all__ = ["MoiraiFinetune", "MoiraiForecast", "MoiraiModule", "MoiraiPretrain"]
__all__ = ["MoiraiFinetune", "MoiraiForecast", "MoiraiModule", "MoiraiMoEModule", "MoiraiPretrain"]
143 changes: 123 additions & 20 deletions src/uni2ts/model/moirai/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ExpandDimArray,
TestSplitSampler,
Transformation,
CausalMeanValueImputation,
)
from gluonts.transform.split import TFTInstanceSplitter
from jaxtyping import Bool, Float, Int
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
module: Optional[MoiraiModule] = None,
patch_size: int | str = "auto",
num_samples: int = 100,
mode: str = "direct",
):
assert (module is not None) or (
module_kwargs is not None
Expand Down Expand Up @@ -332,20 +334,113 @@ def forward(
idx = val_loss.argmin(dim=0)
return preds[idx, torch.arange(len(idx), device=idx.device)]
else:
distr = self._get_distr(
self.hparams.patch_size,
past_target,
past_observed_target,
past_is_pad,
feat_dynamic_real,
observed_feat_dynamic_real,
past_feat_dynamic_real,
past_observed_feat_dynamic_real,
)
preds = distr.sample(torch.Size((num_samples or self.hparams.num_samples,)))
return self._format_preds(
self.hparams.patch_size, preds, past_target.shape[-1]
)
if self.hparams.mode == "direct":
distr = self._get_distr(
self.hparams.patch_size,
past_target,
past_observed_target,
past_is_pad,
feat_dynamic_real,
observed_feat_dynamic_real,
past_feat_dynamic_real,
past_observed_feat_dynamic_real,
)
preds = distr.sample(torch.Size((num_samples or self.hparams.num_samples,)))
return self._format_preds(
self.hparams.patch_size, preds, past_target.shape[-1]
)

elif self.hparams.mode == "autoregressive":
context_step = self.context_token_length(self.hparams.patch_size)
context_token = self.hparams.target_dim * context_step
predict_step = self.prediction_token_length(self.hparams.patch_size)
predict_token = self.hparams.target_dim * predict_step

(
target,
observed_mask,
sample_id,
time_id,
variate_id,
prediction_mask,
) = self._convert(
self.hparams.patch_size,
past_target,
past_observed_target,
past_is_pad,
feat_dynamic_real=feat_dynamic_real,
observed_feat_dynamic_real=observed_feat_dynamic_real,
past_feat_dynamic_real=past_feat_dynamic_real,
past_observed_feat_dynamic_real=past_observed_feat_dynamic_real,
)
patch_size = torch.ones_like(time_id, dtype=torch.long) * self.hparams.patch_size

pred_index = torch.arange(start=context_step-1, end=context_token, step=context_step)
assign_index = torch.arange(start=context_token, end=context_token+predict_token, step=predict_step)

if predict_step == 1:
distr = self.module(
target,
observed_mask,
sample_id,
time_id,
variate_id,
prediction_mask,
patch_size,
)
preds = distr.sample(torch.Size((num_samples or self.hparams.num_samples,)))
preds[..., assign_index, :] = preds[..., pred_index, :]
return self._format_preds(
self.hparams.patch_size, preds, self.hparams.target_dim
)
else:
distr = self.module(
target,
observed_mask,
sample_id,
time_id,
variate_id,
prediction_mask,
patch_size,
)
preds = distr.sample(torch.Size((self.hparams.num_samples,)))

expand_target = target.unsqueeze(0).repeat(self.hparams.num_samples, 1, 1, 1)
expand_prediction_mask = prediction_mask.unsqueeze(0).repeat(self.hparams.num_samples, 1, 1)
expand_observed_mask = observed_mask.unsqueeze(0).expand(self.hparams.num_samples, -1, -1, -1)
expand_sample_id = sample_id.unsqueeze(0).expand(self.hparams.num_samples, -1, -1)
expand_time_id = time_id.unsqueeze(0).expand(self.hparams.num_samples, -1, -1)
expand_variate_id = variate_id.unsqueeze(0).expand(self.hparams.num_samples, -1, -1)
expand_patch_size = patch_size.unsqueeze(0).expand(self.hparams.num_samples, -1, -1)

expand_target[..., assign_index, :] = preds[..., pred_index, :]
expand_prediction_mask[..., assign_index] = False

remain_step = predict_step - 1
while remain_step > 0:
distr = self.module(
expand_target,
expand_observed_mask,
expand_sample_id,
expand_time_id,
expand_variate_id,
expand_prediction_mask,
expand_patch_size,
)
preds = distr.sample(torch.Size((1,)))
_, _, bs, token, ps = preds.shape
preds = preds.view(-1, bs, token, ps)

pred_index = assign_index
assign_index = assign_index + 1
expand_target[..., assign_index, :] = preds[..., pred_index, :]
expand_prediction_mask[..., assign_index] = False

remain_step -= 1

return self._format_preds(
self.hparams.patch_size, expand_target, self.hparams.target_dim
)

def _val_loss(
self,
Expand Down Expand Up @@ -486,7 +581,7 @@ def _generate_time_id(
"max",
patch=patch_size,
)
past_seq_id = torch.clamp(past_seq_id.cumsum(dim=-1) - 1, min=0)
past_seq_id = torch.clamp(past_seq_id.cummax(dim=-1).values.cumsum(dim=-1) - 1, min=0)
batch_shape = " ".join(map(str, past_observed_target.shape[:-2]))
future_seq_id = (
repeat(
Expand Down Expand Up @@ -943,12 +1038,20 @@ def get_default_transform(self) -> Transformation:
dtype=np.float32,
)
if self.hparams.target_dim == 1:
transform += AddObservedValuesIndicator(
target_field="target",
output_field="observed_target",
imputation_method=CausalMeanValueImputation(),
dtype=bool,
)
transform += ExpandDimArray(field="target", axis=0)
transform += AddObservedValuesIndicator(
target_field="target",
output_field="observed_target",
dtype=bool,
)
transform += ExpandDimArray(field="observed_target", axis=0)
else:
transform += AddObservedValuesIndicator(
target_field="target",
output_field="observed_target",
dtype=bool,
)

if self.hparams.feat_dynamic_real_dim > 0:
transform += AsNumpyArray(
Expand Down
1 change: 1 addition & 0 deletions src/uni2ts/model/moirai/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
dropout_p=dropout_p,
norm_layer=RMSNorm,
activation=F.silu,
use_moe=False,
use_glu=True,
use_qk_norm=True,
var_attn_bias_layer=partial(BinaryAttentionBias),
Expand Down
Loading

0 comments on commit 1c32b6e

Please sign in to comment.