Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 35 additions & 70 deletions examples/02. Train a model and perform backtest.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,23 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "066796c6",
"metadata": {
"ExecuteTime": {
"end_time": "2022-02-09T16:41:59.248166Z",
"start_time": "2022-02-09T16:41:59.231129Z"
}
},
"outputs": [],
"metadata": {},
"source": [
"import pandas as pd\n",
"from openstef.pipeline.train_create_forecast_backtest import train_model_and_forecast_back_test\n",
"from openstef.metrics.figure import plot_feature_importance\n",
"from openstef.data_classes.model_specifications import ModelSpecificationDataClass\n",
"from openstef.data_classes.prediction_job import PredictionJobDataClass # TODO, import from openstef when available\n",
"from openstef.plotting.load_forecast_plotter import LoadForecastPlotter\n",
"\n",
"# Set plotly as the default pandas plotting backend\n",
"pd.options.plotting.backend = 'plotly'\n",
"import plotly.io as pio\n",
"pio.renderers.default = \"plotly_mimetype+notebook\""
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -49,15 +45,8 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "86ba5377",
"metadata": {
"ExecuteTime": {
"end_time": "2022-02-09T16:43:47.845079Z",
"start_time": "2022-02-09T16:43:47.631699Z"
}
},
"outputs": [],
"metadata": {},
"source": [
"# Define properties of training/prediction. We call this a 'prediction_job' \n",
"pj=PredictionJobDataClass(id=287,\n",
Expand All @@ -79,7 +68,9 @@
"\n",
"# Load input data\n",
"input_data = pd.read_csv('data/get_model_input_pid_287.csv', index_col='index', parse_dates=True)\n"
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -94,15 +85,8 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "3695e036",
"metadata": {
"ExecuteTime": {
"end_time": "2022-02-09T16:48:17.967356Z",
"start_time": "2022-02-09T16:48:03.036309Z"
}
},
"outputs": [],
"metadata": {},
"source": [
"# Perform the backtest\n",
"n_folds = 2\n",
Expand All @@ -118,7 +102,9 @@
"# If n_folds>1, model is a list of models. In that case, only use the first model\n",
"if n_folds>1:\n",
" model=model[0]"
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -136,34 +122,22 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b0c71ae",
"metadata": {
"ExecuteTime": {
"end_time": "2022-02-09T16:48:22.244769Z",
"start_time": "2022-02-09T16:48:20.133250Z"
}
},
"outputs": [],
"metadata": {},
"source": [
"for horizon in set(forecast.horizon):\n",
" fig = forecast.loc[forecast.horizon==horizon,['quantile_P10','quantile_P30',\n",
" 'quantile_P50','quantile_P70','quantile_P90','realised','forecast']].plot(\n",
" title=f\"Horizon: {horizon}\")\n",
" fig.update_traces(\n",
" line=dict(color=\"green\", width=1), fill='tonexty', fillcolor='rgba(0, 255, 0, 0.1)',\n",
" selector=lambda x: 'quantile' in x.name and x.name != 'quantile_P10')\n",
" fig.update_traces(\n",
" line=dict(color=\"green\", width=1),\n",
" selector=lambda x: 'quantile_P10' == x.name)\n",
" fig.update_traces(\n",
" line=dict(color=\"red\", width=2),\n",
" selector=lambda x: 'realised' in x.name)\n",
" fig.update_traces(\n",
" line=dict(color=\"blue\", width=2),\n",
" selector=lambda x: 'forecast' in x.name)\n",
" data = forecast.loc[forecast.horizon==horizon]\n",
"\n",
" fig = LoadForecastPlotter().plot(\n",
" realized=data[\"realised\"],\n",
" forecast=data[\"forecast\"],\n",
" quantiles=data.filter(regex=\"quantile\")\n",
" )\n",
" fig.update_layout(dict(title=f'Horizon {horizon}'))\n",
" fig.show()"
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -175,20 +149,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "54426833",
"metadata": {
"ExecuteTime": {
"end_time": "2022-02-09T16:48:22.333170Z",
"start_time": "2022-02-09T16:48:22.246296Z"
},
"scrolled": false
},
"outputs": [],
"source": [
"forecast['err']=forecast['realised']-forecast['forecast']\n",
"forecast['err'].plot()"
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -200,18 +170,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e73f034",
"metadata": {},
"outputs": [],
"source": [
"mea = forecast.pivot_table(index='horizon', values=['err'], aggfunc=lambda x: x.abs().mean())\n",
"mea.index=mea.index.astype(str)\n",
"fig = mea.plot(kind='bar')\n",
"fig.update_layout(dict(title='MAE',\n",
" xaxis=dict(title='horizon'),\n",
" yaxis=dict(title='MAE [MW]')))"
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -223,18 +193,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6cb2814",
"metadata": {
"ExecuteTime": {
"end_time": "2022-02-09T16:48:31.314864Z",
"start_time": "2022-02-09T16:48:31.281873Z"
}
},
"outputs": [],
"metadata": {},
"source": [
"plot_feature_importance(model.feature_importance_dataframe)"
]
],
"outputs": [],
"execution_count": null
}
],
"metadata": {
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
jupyter==1.0.0
jupyterlab~=3.6.2
openstef==3.4.7
openstef==3.4.64
Loading