Skip to content

Commit

Permalink
Fix tests with Polar
Browse files Browse the repository at this point in the history
  • Loading branch information
marcopeix committed Jan 13, 2025
1 parent abe1ea2 commit c727ae3
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3546,7 +3546,7 @@
"source": [
"#| hide\n",
"#| polars\n",
"models = [LSTM(h=12, input_size=24, max_steps=5, hist_exog_list=['zeros'], scaler_type='robust')]\n",
"models = [LSTM(h=12, input_size=24, max_steps=5, scaler_type='robust')]\n",
"\n",
"# Pandas\n",
"nf = NeuralForecast(models=models, freq='M')\n",
Expand Down Expand Up @@ -3576,9 +3576,13 @@
"\n",
"def assert_equal_dfs(pandas_df, polars_df):\n",
" mapping = {k: v for k, v in inverse_renamer.items() if k in polars_df}\n",
" polars_df = polars_df.rename(mapping).to_pandas()\\\n",
" .sort_values(['unique_id', 'ds'], ascending=True)\\\n",
" .reset_index(drop=True)\n",
" pandas_df = pandas_df.reset_index(drop=True)\n",
" pd.testing.assert_frame_equal(\n",
" pandas_df,\n",
" polars_df.rename(mapping).to_pandas(),\n",
" polars_df,\n",
" )\n",
"\n",
"assert_equal_dfs(preds, preds_pl)\n",
Expand Down Expand Up @@ -3620,7 +3624,7 @@
" last_cutoff = train_end - test_size * pd.offsets.MonthEnd() - h * pd.offsets.MonthEnd()\n",
" expected_cutoffs = np.flip(np.array([last_cutoff - step_size * i * pd.offsets.MonthEnd() for i in range(n_expected_cutoffs)]))\n",
" pl_cutoffs = forecasts.filter(polars.col('uid') ==nf.uids[1]).select('cutoff').unique(maintain_order=True)\n",
" actual_cutoffs = np.array([pd.Timestamp(x['cutoff']) for x in pl_cutoffs.rows(named=True)])\n",
" actual_cutoffs = np.sort(np.array([pd.Timestamp(x['cutoff']) for x in pl_cutoffs.rows(named=True)]))\n",
" np.testing.assert_array_equal(expected_cutoffs, actual_cutoffs, err_msg=f\"{step_size=},{expected_cutoffs=},{actual_cutoffs=}\")\n",
"\n",
" # check forecast-points count per series\n",
Expand Down

0 comments on commit c727ae3

Please sign in to comment.