Skip to content

Commit 387903e

Browse files
Merge branch 'main' into documentation
2 parents 07656c4 + ea6b44e commit 387903e

11 files changed

+173
-2
lines changed

.pre-commit-config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ repos:
33
rev: 24.2.0
44
hooks:
55
- id: black
6-
language_version: python3.10
6+
language_version: python3
77
- repo: https://github.com/pycqa/isort
88
rev: 5.13.2
99
hooks:
1010
- id: isort
11-
name: isort (python)
11+
name: isort (python)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
_target_: uni2ts.model.moirai.MoiraiForecast.load_from_checkpoint
2+
checkpoint_path: ...
3+
num_samples: 100
4+
patch_size: ???
5+
context_length: ???

cli/conf/finetune/default.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ run_name: ???
1010
seed: 0
1111
tf32: true
1212
compile: false # set to mode: default, reduce-overhead, max-autotune
13+
ckpt_path: null
1314
trainer:
1415
_target_: lightning.Trainer
1516
accelerator: auto

cli/conf/pretrain/default.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ run_name: ???
1010
seed: 0
1111
tf32: true
1212
compile: false # set to mode: default, reduce-overhead, max-autotune
13+
ckpt_path: null # set to "last" to resume training
1314
trainer:
1415
_target_: lightning.Trainer
1516
accelerator: auto

cli/train.py

+1
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def main(cfg: DictConfig):
142142
trainer.fit(
143143
model,
144144
datamodule=DataModule(cfg, train_dataset, val_dataset),
145+
ckpt_path=cfg.ckpt_path,
145146
)
146147

147148

project/benchmarks/README.md

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Benchmark
2+
This directory contains the code and scripts for benchmarking.
3+
4+
5+
## Chronos
6+
`run_chronos.py` is the code to run Chronos on a given dataset.
7+
8+
`chronos_scripts` contains the scripts to run Chronos on different datasets.
9+
10+
Example:
11+
```
12+
sh chronos_scripts/monash_chronos_base.sh
13+
```
14+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Run run_chronos.py once per dataset below with the "base" Chronos T5 model.
# run_chronos.py is invoked by relative path, so start this script from the
# directory that contains it.
model_size=base
model_path="amazon/chronos-t5-${model_size}"

# A failing dataset does not stop the loop; the remaining datasets still run.
for ds in \
    us_births saugeenday sunspot_with_missing temperature_rain_with_missing \
    covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly \
    fred_md car_parts_with_missing electricity_weekly electricity_hourly \
    solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather \
    kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts \
    bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 \
    cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly \
    m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly \
    monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly \
    m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing \
    bitcoin
do
    # Quote every expansion so values are passed as single, literal arguments.
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}"
done
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Run run_chronos.py once per dataset below with the "mini" Chronos T5 model.
# run_chronos.py is invoked by relative path, so start this script from the
# directory that contains it.
model_size=mini
model_path="amazon/chronos-t5-${model_size}"

# A failing dataset does not stop the loop; the remaining datasets still run.
for ds in \
    us_births saugeenday sunspot_with_missing temperature_rain_with_missing \
    covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly \
    fred_md car_parts_with_missing electricity_weekly electricity_hourly \
    solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather \
    kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts \
    bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 \
    cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly \
    m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly \
    monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly \
    m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing \
    bitcoin
do
    # Quote every expansion so values are passed as single, literal arguments.
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}"
done
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Run run_chronos.py once per dataset below with the "small" Chronos T5 model.
# run_chronos.py is invoked by relative path, so start this script from the
# directory that contains it.
model_size=small
model_path="amazon/chronos-t5-${model_size}"

# A failing dataset does not stop the loop; the remaining datasets still run.
for ds in \
    us_births saugeenday sunspot_with_missing temperature_rain_with_missing \
    covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly \
    fred_md car_parts_with_missing electricity_weekly electricity_hourly \
    solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather \
    kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts \
    bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 \
    cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly \
    m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly \
    monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly \
    m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing \
    bitcoin
do
    # Quote every expansion so values are passed as single, literal arguments.
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}"
done
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Run run_chronos.py once per dataset below with the "tiny" Chronos T5 model.
# run_chronos.py is invoked by relative path, so start this script from the
# directory that contains it.
model_size=tiny
model_path="amazon/chronos-t5-${model_size}"

# A failing dataset does not stop the loop; the remaining datasets still run.
for ds in \
    us_births saugeenday sunspot_with_missing temperature_rain_with_missing \
    covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly \
    fred_md car_parts_with_missing electricity_weekly electricity_hourly \
    solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather \
    kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts \
    bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 \
    cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly \
    m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly \
    monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly \
    m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing \
    bitcoin
do
    # Quote every expansion so values are passed as single, literal arguments.
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}"
done

project/benchmarks/run_chronos.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import argparse
2+
import os
3+
4+
import numpy as np
5+
import torch
6+
from chronos import ChronosPipeline
7+
from gluonts.dataset.repository import get_dataset
8+
from gluonts.dataset.split import split
9+
from gluonts.ev.metrics import (
10+
MAE,
11+
MAPE,
12+
MASE,
13+
MSE,
14+
MSIS,
15+
ND,
16+
NRMSE,
17+
RMSE,
18+
SMAPE,
19+
MeanWeightedSumQuantileLoss,
20+
)
21+
from gluonts.itertools import batcher
22+
23+
# from gluonts.model.evaluation import evaluate_forecasts
24+
from gluonts.model.forecast import SampleForecast
25+
from tqdm.auto import tqdm
26+
27+
from uni2ts.eval_util.data import get_gluonts_test_dataset
28+
from uni2ts.eval_util.evaluation import evaluate_forecasts
29+
from uni2ts.eval_util.metrics import MedianMSE
30+
31+
32+
def evaluate(pipeline, dataset, save_path, num_samples=20, batch_size=512):
    """Forecast ``dataset`` with ``pipeline``, score the forecasts, and save the metrics.

    The test split is loaded with ``get_gluonts_test_dataset``, sample
    forecasts are generated batch by batch with ``pipeline.predict``, and the
    result is scored with ``evaluate_forecasts``. If CUDA runs out of memory
    the whole forecasting pass is restarted with half the batch size.

    Args:
        pipeline: Object exposing ``predict(context, prediction_length=...,
            num_samples=..., limit_prediction_length=...)`` whose result has a
            ``.numpy()`` method (the script entry point passes a
            ``ChronosPipeline``).
        dataset: Dataset name understood by ``get_gluonts_test_dataset``.
        save_path: Path of the CSV file the metrics table is written to.
        num_samples: Number of sample paths requested per series.
        batch_size: Initial number of series forecast per ``predict`` call.

    Returns:
        The metrics table returned by ``evaluate_forecasts``, with its index
        set to ``[dataset]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        torch.cuda.OutOfMemoryError: If forecasting still runs out of memory
            with a batch size of 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    print("-" * 5, f"Evaluating {dataset}", "-" * 5)
    test_data, metadata = get_gluonts_test_dataset(dataset)
    prediction_length = metadata.prediction_length

    while True:
        try:
            # Generate forecast samples; restart from scratch on every retry so
            # no partial results from a failed pass are kept.
            forecast_batches = []
            for batch in tqdm(batcher(test_data.input, batch_size=batch_size)):
                context = [torch.tensor(entry["target"]) for entry in batch]
                forecast_batches.append(
                    pipeline.predict(
                        context,
                        prediction_length=prediction_length,
                        num_samples=num_samples,
                        limit_prediction_length=False,  # We disable the limit on prediction length.
                    ).numpy()
                )
            forecast_samples = np.concatenate(forecast_batches)
            break
        except torch.cuda.OutOfMemoryError:
            if batch_size <= 1:
                # Halving again would give a batch size of 0, which cannot
                # produce any forecasts; surface the real error instead.
                raise
            print(
                f"OutOfMemoryError at batch_size {batch_size}, reducing to {batch_size//2}"
            )
            batch_size //= 2

    # Convert forecast samples into gluonts SampleForecast objects
    sample_forecasts = []
    for item, ts in zip(forecast_samples, test_data.input):
        # The forecast starts right after the last observed context step
        # (assumes ts["start"] supports integer offset addition -- TODO confirm).
        forecast_start_date = ts["start"] + len(ts["target"])
        sample_forecasts.append(
            SampleForecast(samples=item, start_date=forecast_start_date)
        )

    # Evaluate
    metrics_df = evaluate_forecasts(
        sample_forecasts,
        test_data=test_data,
        metrics=[
            MSE(),
            MAE(),
            MAPE(),
            SMAPE(),
            MSIS(),
            RMSE(),
            NRMSE(),
            ND(),
            MASE(),
            MedianMSE(),
            MeanWeightedSumQuantileLoss(np.arange(0.1, 1.0, 0.1)),
        ],
    )
    metrics_df.index = [dataset]
    print(metrics_df)
    metrics_df.to_csv(save_path)
    print(f"Results saved to {save_path}")
    print("-" * 5, f"Evaluation of {dataset} complete", "-" * 5)
    return metrics_df
91+
92+
93+
if __name__ == "__main__":
    # Command-line entry point: load a Chronos model, evaluate it on one
    # dataset, and write the metrics to <save_dir>/<run_name>/<dataset>.csv.
    parser = argparse.ArgumentParser(
        description="Load a model and dataset, then make predictions."
    )
    parser.add_argument(
        "--model_path", type=str, required=True, help="Path to load the model"
    )
    parser.add_argument(
        "--dataset", type=str, required=True, help="Name of the dataset to use"
    )
    parser.add_argument(
        "--save_dir", type=str, default="results", help="Directory to save the results"
    )
    parser.add_argument(
        "--num_samples", type=int, default=20, help="Number of samples to generate"
    )
    parser.add_argument(
        "--batch_size", type=int, default=512, help="Batch size for generating samples"
    )
    parser.add_argument("--run_name", type=str, default="test", help="Name of the run")

    args = parser.parse_args()

    # Load Chronos
    pipeline = ChronosPipeline.from_pretrained(
        args.model_path,
        device_map="cuda:0",
        torch_dtype=torch.bfloat16,
    )

    output_dir = os.path.join(args.save_dir, args.run_name)
    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(output_dir, exist_ok=True)

    # Forward the sampling options; without this the --num_samples and
    # --batch_size flags would be parsed but silently ignored.
    evaluate(
        pipeline,
        args.dataset,
        os.path.join(output_dir, f"{args.dataset}.csv"),
        num_samples=args.num_samples,
        batch_size=args.batch_size,
    )

0 commit comments

Comments
 (0)