Commit 166c729

Author: root
Commit message: update readme and notebooks for moe
1 parent 7a62471, commit 166c729

15 files changed: 1171 additions & 289 deletions

CITATION.cff

Lines changed: 1 addition & 1 deletion

````diff
@@ -1,6 +1,6 @@
 cff-version: 1.2.0
 title: Unified Training of Universal Time Series Forecasting Transformers
-message: If you find Chronos models useful for your research, please consider citing the associated paper.
+message: If you find Moirai useful for your research, please consider citing the associated paper.
 authors:
 - family-names: Woo
   given-names: Gerald
````

README.md

Lines changed: 41 additions & 30 deletions

````diff
@@ -2,11 +2,13 @@
 
 Uni2TS is a PyTorch based library for research and applications related to Time Series Forecasting. It provides a unified framework for large-scale pre-training, fine-tuning, inference, and evaluation of Universal Time Series Transformers.
 
-Related reading: [Moirai Paper](https://arxiv.org/abs/2402.02592), [Moirai Salesforce Blog](https://blog.salesforceairesearch.com/moirai/), [Moirai-MoE Paper](https://arxiv.org/abs/2410.10469), [Moirai-MoE AI Horizon Forecast Blog](https://aihorizonforecast.substack.com/p/moirai-moe-upgrading-moirai-with), [Moirai-MoE Jiqizhixin Blog](https://mp.weixin.qq.com/s/LQvlgxx9vU965Yzy6RuBfQ).
+Related reading: [Moirai Paper](https://arxiv.org/abs/2402.02592), [Moirai Salesforce Blog](https://blog.salesforceairesearch.com/moirai/), [Moirai-MoE Paper](https://arxiv.org/abs/2410.10469), [Moirai-MoE Salesforce Blog](https://www.salesforce.com/blog/time-series-morai-moe/), [Moirai-MoE AI Horizon Forecast Blog](https://aihorizonforecast.substack.com/p/moirai-moe-upgrading-moirai-with), [Moirai-MoE Jiqizhixin Blog](https://mp.weixin.qq.com/s/LQvlgxx9vU965Yzy6RuBfQ).
 
 ## 🎉 What's New
 
-* Oct 2024: A new model Moirai-MoE! The preprint is available on [arXiv](https://arxiv.org/abs/2410.10469), along with model weights of [small](https://huggingface.co/Salesforce/moirai-moe-1.0-R-small) and [base](https://huggingface.co/Salesforce/moirai-moe-1.0-R-base), and [code](https://github.com/SalesforceAIResearch/uni2ts/tree/main/project/moirai-moe-1) to get started.
+* Nov 2024: The first general time series forecasting benchmark [GIFT-Eval](https://github.com/SalesforceAIResearch/gift-eval) is released. Moirai-Large achieves the best performance on the [Leaderboard](https://huggingface.co/spaces/Salesforce/GIFT-Eval)!
+
+* Oct 2024: A new model, Moirai-MoE! The preprint is available on [arXiv](https://arxiv.org/abs/2410.10469), along with the model weights of [Moirai-MoE-Small](https://huggingface.co/Salesforce/moirai-moe-1.0-R-small) and [Moirai-MoE-Base](https://huggingface.co/Salesforce/moirai-moe-1.0-R-base). Get started with the [inference code](https://github.com/SalesforceAIResearch/uni2ts/tree/main/project/moirai-moe-1) and [notebook examples](https://github.com/SalesforceAIResearch/uni2ts/tree/main/example)!
 
 * Sep 2024: Released [Evaluation Code](https://github.com/SalesforceAIResearch/uni2ts/tree/main/project/benchmarks) of [TimesFM](https://arxiv.org/abs/2310.10688), [Chronos](https://arxiv.org/abs/2403.07815) and [VisionTS](https://arxiv.org/abs/2408.17253) on Monash, LSF and PF benchmarks.
 
@@ -16,22 +18,6 @@ Related reading: [Moirai Paper](https://arxiv.org/abs/2402.02592), [Moirai Sales
 
 * Mar 2024: Release of Uni2TS library, along with [Moirai Paper](https://arxiv.org/abs/2402.02592), [Moirai-1.0-R Models](https://huggingface.co/collections/Salesforce/moirai-10-r-models-65c8d3a94c51428c300e0742), and [LOTSA Data](https://huggingface.co/datasets/Salesforce/lotsa_data/).
 
-## ✅ TODO
-
-- [ ] Improve docstrings and documentation
-
-[//]: # (- [ ] Support more pre-training paradigms)
-
-[//]: # ( - [ ] (Non-)Contrastive learning)
-
-[//]: # ( - [ ] Masked Autoencoder)
-
-[//]: # ( - [ ] Next token prediction)
-
-[//]: # (- [ ] Decoder Transformer)
-
-[//]: # (- [ ] Data augmentations - down sampling, subsampling, aggregation)
-
 ## ⚙️ Installation
 
 1. Clone repository:
@@ -56,6 +42,11 @@ pip install -e '.[notebook]'
 touch .env
 ```
 
+We also support installation via PyPI.
+```shell
+pip install uni2ts
+```
+
 ## 🏃 Getting Started
 
 Let's see a simple example on how to use Uni2TS to make zero-shot forecasts from a pre-trained model.
@@ -72,8 +63,9 @@ from huggingface_hub import hf_hub_download
 
 from uni2ts.eval_util.plot import plot_single
 from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
+from uni2ts.model.moirai_moe import MoiraiMoEForecast, MoiraiMoEModule
 
-
+MODEL = "moirai-moe"  # model name: choose from {'moirai', 'moirai-moe'}
 SIZE = "small"  # model size: choose from {'small', 'base', 'large'}
 PDT = 20  # prediction length: any positive integer
 CTX = 200  # context length: any positive integer
@@ -104,16 +96,28 @@ test_data = test_template.generate_instances(
 )
 
 # Prepare pre-trained model by downloading model weights from huggingface hub
-model = MoiraiForecast(
-    module=MoiraiModule.from_pretrained(f"Salesforce/moirai-1.0-R-{SIZE}"),
-    prediction_length=PDT,
-    context_length=CTX,
-    patch_size=PSZ,
-    num_samples=100,
-    target_dim=1,
-    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
-    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
-)
+if MODEL == "moirai":
+    model = MoiraiForecast(
+        module=MoiraiModule.from_pretrained(f"Salesforce/moirai-1.1-R-{SIZE}"),
+        prediction_length=PDT,
+        context_length=CTX,
+        patch_size=PSZ,
+        num_samples=100,
+        target_dim=1,
+        feat_dynamic_real_dim=ds.num_feat_dynamic_real,
+        past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
+    )
+elif MODEL == "moirai-moe":
+    model = MoiraiMoEForecast(
+        module=MoiraiMoEModule.from_pretrained(f"Salesforce/moirai-moe-1.0-R-{SIZE}"),
+        prediction_length=PDT,
+        context_length=CTX,
+        patch_size=16,
+        num_samples=100,
+        target_dim=1,
+        feat_dynamic_real_dim=ds.num_feat_dynamic_real,
+        past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
+    )
 
 predictor = model.create_predictor(batch_size=BSZ)
 forecasts = predictor.predict(test_data.input)
@@ -243,7 +247,14 @@ If you're using this repository in your research or applications, please cite us
   year={2024}
 }
 
-@inproceedings{woo2024unified,
+@article{aksu2024gifteval,
+  title={GIFT-Eval: A Benchmark For General Time Series Forecasting Model Evaluation},
+  author={Aksu, Taha and Woo, Gerald and Liu, Juncheng and Liu, Xu and Liu, Chenghao and Savarese, Silvio and Xiong, Caiming and Sahoo, Doyen},
+  journal={arXiv preprint arXiv:2410.10393},
+  year={2024}
+}
+
+@inproceedings{woo2024moirai,
   title={Unified Training of Universal Time Series Forecasting Transformers},
   author={Woo, Gerald and Liu, Chenghao and Kumar, Akshat and Xiong, Caiming and Savarese, Silvio and Sahoo, Doyen},
   booktitle={Forty-first International Conference on Machine Learning},
````
cli/conf/eval/default.yaml

Lines changed: 1 addition & 1 deletion

````diff
@@ -20,5 +20,5 @@ metrics:
   - _target_: gluonts.ev.metrics.MeanWeightedSumQuantileLoss
     quantile_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
 batch_size: 512
-min_batch_size: 16
+min_batch_size: 1
 device: auto
````
Lines changed: 2 additions & 3 deletions

````diff
@@ -1,8 +1,7 @@
-_target_: uni2ts.model.moirai.MoiraiForecast
+_target_: uni2ts.model.moirai_moe.MoiraiMoEForecast
 module:
-  _target_: uni2ts.model.moirai.MoiraiMoEModule.from_pretrained
+  _target_: uni2ts.model.moirai_moe.MoiraiMoEModule.from_pretrained
   pretrained_model_name_or_path: Salesforce/moirai-moe-1.0-R-base
-mode: autoregressive
 num_samples: 100
 patch_size: 16
 context_length: ???
````
Lines changed: 2 additions & 3 deletions

````diff
@@ -1,8 +1,7 @@
-_target_: uni2ts.model.moirai.MoiraiForecast
+_target_: uni2ts.model.moirai_moe.MoiraiMoEForecast
 module:
-  _target_: uni2ts.model.moirai.MoiraiMoEModule.from_pretrained
+  _target_: uni2ts.model.moirai_moe.MoiraiMoEModule.from_pretrained
   pretrained_model_name_or_path: Salesforce/moirai-moe-1.0-R-small
-mode: autoregressive
 num_samples: 100
 patch_size: 16
 context_length: ???
````
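After this change, the small-model eval config resolves to the following (reconstructed from the `+` and context lines above; the file path is not captured in this view):

```yaml
_target_: uni2ts.model.moirai_moe.MoiraiMoEForecast
module:
  _target_: uni2ts.model.moirai_moe.MoiraiMoEModule.from_pretrained
  pretrained_model_name_or_path: Salesforce/moirai-moe-1.0-R-small
num_samples: 100
patch_size: 16
context_length: ???
```

Note that `mode: autoregressive` is gone entirely: the config no longer needs it because `MoiraiMoEForecast` is now a dedicated class, and `???` is Hydra's marker for a mandatory value supplied at runtime.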

example/moirai_forecast.ipynb

Lines changed: 34 additions & 18 deletions
Large diffs are not rendered by default.

example/moirai_forecast_pandas.ipynb

Lines changed: 128 additions & 65 deletions
Large diffs are not rendered by default.

project/moirai-moe-1/README.md

Lines changed: 3 additions & 4 deletions

````diff
@@ -21,15 +21,15 @@ The pre-trained weights of Moirai-MoE can be found in the following table.
 
 ## Usage
 
-Let's see a simple example on how to use pre-trained Moirai-MoE models to make forecasts.
+Let's see a simple example below on how to use pre-trained Moirai-MoE models to make forecasts. See also the notebooks in the [example folder](https://github.com/SalesforceAIResearch/uni2ts/tree/main/example) to try out Moirai-MoE.
 
 ```python
 import matplotlib.pyplot as plt
 from gluonts.dataset.repository import dataset_recipes
 
 from uni2ts.eval_util.data import get_gluonts_test_dataset
 from uni2ts.eval_util.plot import plot_next_multi
-from uni2ts.model.moirai import MoiraiForecast, MoiraiMoEModule
+from uni2ts.model.moirai_moe import MoiraiMoEForecast, MoiraiMoEModule
 
 SIZE = "small"  # model size: choose from {'small', 'base'}
 CTX = 1000  # context length: any positive integer
@@ -43,11 +43,10 @@ test_data, metadata = get_gluonts_test_dataset(
 # print(sorted(dataset_recipes.keys()))
 
 # Prepare model
-model = MoiraiForecast(
+model = MoiraiMoEForecast(
     module=MoiraiMoEModule.from_pretrained(
         f"Salesforce/moirai-moe-1.0-R-{SIZE}",
     ),
-    mode="autoregressive",
     prediction_length=metadata.prediction_length,
     context_length=CTX,
     patch_size=16,
````
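With `patch_size=16` fixed, the number of tokens the model sees is the series length divided by the patch size, rounded up, since a partial trailing patch still occupies one token. An illustrative calculation for the `CTX = 1000` setting above (the exact token accounting inside uni2ts may differ; `num_patch_tokens` is a hypothetical helper, not uni2ts API):

```python
import math


def num_patch_tokens(length: int, patch_size: int = 16) -> int:
    # Ceil-division: 1000 timesteps at patch size 16 -> 63 tokens,
    # the last token covering only the final 8 timesteps.
    return math.ceil(length / patch_size)


CTX = 1000  # context length from the usage example
print(num_patch_tokens(CTX))  # context tokens fed to the model
```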

src/uni2ts/model/moirai/__init__.py

Lines changed: 0 additions & 2 deletions

````diff
@@ -16,13 +16,11 @@
 from .finetune import MoiraiFinetune
 from .forecast import MoiraiForecast
 from .module import MoiraiModule
-from .module_moe import MoiraiMoEModule
 from .pretrain import MoiraiPretrain
 
 __all__ = [
     "MoiraiFinetune",
     "MoiraiForecast",
     "MoiraiModule",
-    "MoiraiMoEModule",
     "MoiraiPretrain",
 ]
````

src/uni2ts/model/moirai/forecast.py

Lines changed: 21 additions & 148 deletions

````diff
@@ -27,7 +27,6 @@
 from gluonts.transform import (
     AddObservedValuesIndicator,
     AsNumpyArray,
-    CausalMeanValueImputation,
     ExpandDimArray,
     TestSplitSampler,
     Transformation,
@@ -82,7 +81,6 @@ def __init__(
         module: Optional[MoiraiModule] = None,
         patch_size: int | str = "auto",
         num_samples: int = 100,
-        mode: str = "direct",
     ):
         assert (module is not None) or (
             module_kwargs is not None
@@ -334,139 +332,22 @@ def forward(
             idx = val_loss.argmin(dim=0)
             return preds[idx, torch.arange(len(idx), device=idx.device)]
         else:
-            if self.hparams.mode == "direct":
-                distr = self._get_distr(
-                    self.hparams.patch_size,
-                    past_target,
-                    past_observed_target,
-                    past_is_pad,
-                    feat_dynamic_real,
-                    observed_feat_dynamic_real,
-                    past_feat_dynamic_real,
-                    past_observed_feat_dynamic_real,
-                )
-                preds = distr.sample(
-                    torch.Size((num_samples or self.hparams.num_samples,))
-                )
-                return self._format_preds(
-                    self.hparams.patch_size, preds, past_target.shape[-1]
-                )
-
-            elif self.hparams.mode == "autoregressive":
-                context_step = self.context_token_length(self.hparams.patch_size)
-                context_token = self.hparams.target_dim * context_step
-                predict_step = self.prediction_token_length(self.hparams.patch_size)
-                predict_token = self.hparams.target_dim * predict_step
-
-                (
-                    target,
-                    observed_mask,
-                    sample_id,
-                    time_id,
-                    variate_id,
-                    prediction_mask,
-                ) = self._convert(
-                    self.hparams.patch_size,
-                    past_target,
-                    past_observed_target,
-                    past_is_pad,
-                    feat_dynamic_real=feat_dynamic_real,
-                    observed_feat_dynamic_real=observed_feat_dynamic_real,
-                    past_feat_dynamic_real=past_feat_dynamic_real,
-                    past_observed_feat_dynamic_real=past_observed_feat_dynamic_real,
-                )
-                patch_size = (
-                    torch.ones_like(time_id, dtype=torch.long) * self.hparams.patch_size
-                )
-
-                pred_index = torch.arange(
-                    start=context_step - 1, end=context_token, step=context_step
-                )
-                assign_index = torch.arange(
-                    start=context_token,
-                    end=context_token + predict_token,
-                    step=predict_step,
-                )
-
-                if predict_step == 1:
-                    distr = self.module(
-                        target,
-                        observed_mask,
-                        sample_id,
-                        time_id,
-                        variate_id,
-                        prediction_mask,
-                        patch_size,
-                    )
-                    preds = distr.sample(
-                        torch.Size((num_samples or self.hparams.num_samples,))
-                    )
-                    preds[..., assign_index, :] = preds[..., pred_index, :]
-                    return self._format_preds(
-                        self.hparams.patch_size, preds, self.hparams.target_dim
-                    )
-                else:
-                    distr = self.module(
-                        target,
-                        observed_mask,
-                        sample_id,
-                        time_id,
-                        variate_id,
-                        prediction_mask,
-                        patch_size,
-                    )
-                    preds = distr.sample(torch.Size((self.hparams.num_samples,)))
-
-                    expand_target = target.unsqueeze(0).repeat(
-                        self.hparams.num_samples, 1, 1, 1
-                    )
-                    expand_prediction_mask = prediction_mask.unsqueeze(0).repeat(
-                        self.hparams.num_samples, 1, 1
-                    )
-                    expand_observed_mask = observed_mask.unsqueeze(0).expand(
-                        self.hparams.num_samples, -1, -1, -1
-                    )
-                    expand_sample_id = sample_id.unsqueeze(0).expand(
-                        self.hparams.num_samples, -1, -1
-                    )
-                    expand_time_id = time_id.unsqueeze(0).expand(
-                        self.hparams.num_samples, -1, -1
-                    )
-                    expand_variate_id = variate_id.unsqueeze(0).expand(
-                        self.hparams.num_samples, -1, -1
-                    )
-                    expand_patch_size = patch_size.unsqueeze(0).expand(
-                        self.hparams.num_samples, -1, -1
-                    )
-
-                    expand_target[..., assign_index, :] = preds[..., pred_index, :]
-                    expand_prediction_mask[..., assign_index] = False
-
-                    remain_step = predict_step - 1
-                    while remain_step > 0:
-                        distr = self.module(
-                            expand_target,
-                            expand_observed_mask,
-                            expand_sample_id,
-                            expand_time_id,
-                            expand_variate_id,
-                            expand_prediction_mask,
-                            expand_patch_size,
-                        )
-                        preds = distr.sample(torch.Size((1,)))
-                        _, _, bs, token, ps = preds.shape
-                        preds = preds.view(-1, bs, token, ps)
-
-                        pred_index = assign_index
-                        assign_index = assign_index + 1
-                        expand_target[..., assign_index, :] = preds[..., pred_index, :]
-                        expand_prediction_mask[..., assign_index] = False
-
-                        remain_step -= 1
-
-                    return self._format_preds(
-                        self.hparams.patch_size, expand_target, self.hparams.target_dim
-                    )
+            distr = self._get_distr(
+                self.hparams.patch_size,
+                past_target,
+                past_observed_target,
+                past_is_pad,
+                feat_dynamic_real,
+                observed_feat_dynamic_real,
+                past_feat_dynamic_real,
+                past_observed_feat_dynamic_real,
+            )
+            preds = distr.sample(
+                torch.Size((num_samples or self.hparams.num_samples,))
+            )
+            return self._format_preds(
+                self.hparams.patch_size, preds, past_target.shape[-1]
+            )
@@ -1066,20 +947,12 @@ def get_default_transform(self) -> Transformation:
             dtype=np.float32,
         )
         if self.hparams.target_dim == 1:
-            transform += AddObservedValuesIndicator(
-                target_field="target",
-                output_field="observed_target",
-                imputation_method=CausalMeanValueImputation(),
-                dtype=bool,
-            )
             transform += ExpandDimArray(field="target", axis=0)
-            transform += ExpandDimArray(field="observed_target", axis=0)
-        else:
-            transform += AddObservedValuesIndicator(
-                target_field="target",
-                output_field="observed_target",
-                dtype=bool,
-            )
+        transform += AddObservedValuesIndicator(
+            target_field="target",
+            output_field="observed_target",
+            dtype=bool,
+        )
 
         if self.hparams.feat_dynamic_real_dim > 0:
             transform += AsNumpyArray(
````
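The removed branch implemented token-by-token autoregressive decoding, which now lives with `MoiraiMoEForecast` in `uni2ts.model.moirai_moe`; `MoiraiForecast` keeps only direct multi-step prediction. As a toy illustration of the loop shape being moved (predict one step, append it to the context, repeat, mirroring the `while remain_step > 0` loop), here is a generic sketch with a stand-in one-step predictor; nothing below is uni2ts API:

```python
def one_step(context: list[float]) -> float:
    # Stand-in for a model's one-step-ahead prediction:
    # mean of the last two observed/generated values.
    return (context[-1] + context[-2]) / 2.0


def autoregressive_forecast(context: list[float], horizon: int) -> list[float]:
    history = list(context)  # don't mutate the caller's context
    preds = []
    for _ in range(horizon):
        y = one_step(history)
        preds.append(y)
        history.append(y)  # the prediction becomes part of the context
    return preds


print(autoregressive_forecast([1.0, 3.0], 3))  # [2.0, 2.5, 2.25]
```

Direct decoding instead emits all horizon steps from a single forward pass, which is why the surviving `else:` branch reduces to one `_get_distr` call followed by sampling.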

0 commit comments