From 29cc7ed28b006d781dd813b2fe2ff549c3baff19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Thu, 9 Nov 2023 16:00:49 -0600 Subject: [PATCH 1/2] add ensure_time_dtype --- nbs/validation.ipynb | 143 +++++++++++++++++++++++++++++++++--- settings.ini | 2 +- utilsforecast/_modidx.py | 8 +- utilsforecast/validation.py | 57 ++++++++++++-- 4 files changed, 191 insertions(+), 19 deletions(-) diff --git a/nbs/validation.ipynb b/nbs/validation.ipynb index 76c98df..7082648 100644 --- a/nbs/validation.ipynb +++ b/nbs/validation.ipynb @@ -31,7 +31,135 @@ "import numpy as np\n", "import pandas as pd\n", "\n", - "from utilsforecast.compat import DataFrame, pl_DataFrame" + "from utilsforecast.compat import DataFrame, Series, pl_DataFrame, pl_Series, pl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7638f2a6-0303-41a9-a103-d3055e7bc572", + "metadata": {}, + "outputs": [], + "source": [ + "from fastcore.test import test_fail" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e001379f-a552-4fd9-ac82-3f4d5c76f555", + "metadata": {}, + "outputs": [], + "source": [ + "#| polars\n", + "import polars.testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98e0f72a-3128-46d7-8d6c-cbd0c62ce68a", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", + "def _is_dt_or_int(s: Series) -> bool:\n", + " dtype = s.head(1).to_numpy().dtype\n", + " is_dt = np.issubdtype(dtype, np.datetime64)\n", + " is_int = np.issubdtype(dtype, np.integer)\n", + " return is_dt or is_int" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6679513a-4f2b-4eda-9970-e5c6725dd761", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def ensure_shallow_copy(df: pd.DataFrame) -> pd.DataFrame:\n", + " from packaging.version import Version\n", + "\n", + " if Version(pd.__version__) < Version(\"1.4\"):\n", + " # https://github.com/pandas-dev/pandas/pull/43406\n", + " df = df.copy()\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccb15f20-56e4-4d0d-96bc-830c5effff19", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def ensure_time_dtype(df: DataFrame, time_col: str = 'ds') -> DataFrame:\n", + " \"\"\"Make sure that `time_col` contains timestamps or integers.\n", + " If it contains strings, try to cast them as timestamps.\"\"\"\n", + " times = df[time_col]\n", + " if _is_dt_or_int(times):\n", + " return df\n", + " parse_err_msg = (\n", + " f\"Failed to parse '{time_col}' from string to datetime. \"\n", + " 'Please make sure that it contains valid timestamps or integers.'\n", + " )\n", + " if isinstance(times, pd.Series) and pd.api.types.is_object_dtype(times):\n", + " try:\n", + " times = pd.to_datetime(times)\n", + " except ValueError:\n", + " raise ValueError(parse_err_msg)\n", + " df = ensure_shallow_copy(df.copy(deep=False))\n", + " df[time_col] = times\n", + " elif isinstance(times, pl_Series) and times.dtype == pl.Utf8:\n", + " try:\n", + " times = times.str.to_datetime()\n", + " except pl.exceptions.ComputeError:\n", + " raise ValueError(parse_err_msg)\n", + " df = df.with_columns(times)\n", + " else:\n", + " raise ValueError(f\"'{time_col}' should have valid timestamps or integers.\")\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "604ec21b-314f-42cb-9ee0-5950fd611b97", + "metadata": {}, + "outputs": [], + "source": [ + "pd.testing.assert_frame_equal(\n", + " ensure_time_dtype(pd.DataFrame({'ds': ['2000-01-01']})),\n", + " pd.DataFrame({'ds': pd.to_datetime(['2000-01-01'])})\n", + ")\n", + "df = pd.DataFrame({'ds': [1, 2]})\n", + "assert df is ensure_time_dtype(df)\n", + "test_fail(\n", + " lambda: ensure_time_dtype(pd.DataFrame({'ds': ['2000-14-14']})),\n", + " contains='Please make sure that it contains valid timestamps',\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5335d217-e240-46df-90a0-6aeeb07a0586", + "metadata": {}, + "outputs": [], + "source": [ + "#| polars\n", + "pl.testing.assert_frame_equal(\n", + " ensure_time_dtype(pl.DataFrame({'ds': ['2000-01-01']})),\n", + " pl.DataFrame().with_columns(ds=pl.datetime(2000, 1, 1))\n", + ")\n", + "df = pl.DataFrame({'ds': [1, 2]})\n", + "assert df is ensure_time_dtype(df)\n", + "test_fail(\n", + " lambda: ensure_time_dtype(pl.DataFrame({'ds': ['hello']})),\n", + " contains='Please make sure that it contains valid timestamps',\n", + ")" ] }, { @@ -76,8 +204,8 @@ " raise ValueError(f\"The following columns are missing: {missing_cols}\")\n", "\n", " # time col\n", - " times_dtype = df[time_col].head(1).to_numpy().dtype\n", - " if not (np.issubdtype(times_dtype, np.datetime64) or np.issubdtype(times_dtype, np.integer)):\n", + " if not _is_dt_or_int(df[time_col]):\n", + " times_dtype = df[time_col].head(1).to_numpy().dtype\n", " raise ValueError(f\"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'.\")\n", "\n", " # target col\n", @@ -108,7 +236,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L12){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L13){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### validate_format\n", "\n", @@ -130,7 +258,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L12){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L13){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### validate_format\n", "\n", @@ -168,10 +296,7 @@ "source": [ "import datetime\n", "\n", - "import pandas as pd\n", - "from fastcore.test import test_fail\n", - "\n", - "from utilsforecast.compat import POLARS_INSTALLED, pl\n", + "from utilsforecast.compat import POLARS_INSTALLED\n", "from utilsforecast.data import generate_series" ] }, diff --git a/settings.ini b/settings.ini index c299ee7..9667419 100644 --- a/settings.ini +++ b/settings.ini @@ -26,7 +26,7 @@ keywords = time-series analysis forecasting language = English status = 3 user = Nixtla -requirements = numpy pandas>=1.1.1 +requirements = numpy packaging pandas>=1.1.1 plotting_requirements = matplotlib plotly plotly-resampler dev_requirements = matplotlib numba plotly polars pyarrow readme_nb = index.ipynb diff --git a/utilsforecast/_modidx.py b/utilsforecast/_modidx.py index d90b01c..b4130fe 100644 --- a/utilsforecast/_modidx.py +++ b/utilsforecast/_modidx.py @@ -166,5 +166,11 @@ 'utilsforecast/target_transforms.py'), 'utilsforecast.target_transforms._transform': ( 'target_transforms.html#_transform', 'utilsforecast/target_transforms.py')}, - 'utilsforecast.validation': { 'utilsforecast.validation.validate_format': ( 'validation.html#validate_format', + 'utilsforecast.validation': { 'utilsforecast.validation._is_dt_or_int': ( 'validation.html#_is_dt_or_int', + 'utilsforecast/validation.py'), + 'utilsforecast.validation.ensure_shallow_copy': ( 'validation.html#ensure_shallow_copy', + 'utilsforecast/validation.py'), + 'utilsforecast.validation.ensure_time_dtype': ( 'validation.html#ensure_time_dtype', + 'utilsforecast/validation.py'), + 'utilsforecast.validation.validate_format': ( 'validation.html#validate_format', 'utilsforecast/validation.py')}}} diff --git a/utilsforecast/validation.py b/utilsforecast/validation.py index 325f2e8..a2fbe7f 100644 --- a/utilsforecast/validation.py +++ b/utilsforecast/validation.py @@ -1,15 +1,59 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/validation.ipynb. # %% auto 0 -__all__ = ['validate_format'] +__all__ = ['ensure_shallow_copy', 'ensure_time_dtype', 'validate_format'] # %% ../nbs/validation.ipynb 2 import numpy as np import pandas as pd -from .compat import DataFrame, pl_DataFrame +from .compat import DataFrame, Series, pl_DataFrame, pl_Series, pl -# %% ../nbs/validation.ipynb 3 +# %% ../nbs/validation.ipynb 5 +def _is_dt_or_int(s: Series) -> bool: + dtype = s.head(1).to_numpy().dtype + is_dt = np.issubdtype(dtype, np.datetime64) + is_int = np.issubdtype(dtype, np.integer) + return is_dt or is_int + +# %% ../nbs/validation.ipynb 6 +def ensure_shallow_copy(df: pd.DataFrame) -> pd.DataFrame: + from packaging.version import Version + + if Version(pd.__version__) < Version("1.4"): + # https://github.com/pandas-dev/pandas/pull/43406 + df = df.copy() + return df + +# %% ../nbs/validation.ipynb 7 +def ensure_time_dtype(df: DataFrame, time_col: str = "ds") -> DataFrame: + """Make sure that `time_col` contains timestamps or integers. + If it contains strings, try to cast them as timestamps.""" + times = df[time_col] + if _is_dt_or_int(times): + return df + parse_err_msg = ( + f"Failed to parse '{time_col}' from string to datetime. " + "Please make sure that it contains valid timestamps or integers." + ) + if isinstance(times, pd.Series) and pd.api.types.is_object_dtype(times): + try: + times = pd.to_datetime(times) + except ValueError: + raise ValueError(parse_err_msg) + df = ensure_shallow_copy(df.copy(deep=False)) + df[time_col] = times + elif isinstance(times, pl_Series) and times.dtype == pl.Utf8: + try: + times = times.str.to_datetime() + except pl.exceptions.ComputeError: + raise ValueError(parse_err_msg) + df = df.with_columns(times) + else: + raise ValueError(f"'{time_col}' should have valid timestamps or integers.") + return df + +# %% ../nbs/validation.ipynb 10 def validate_format( df: DataFrame, id_col: str = "unique_id", @@ -44,11 +88,8 @@ def validate_format( raise ValueError(f"The following columns are missing: {missing_cols}") # time col - times_dtype = df[time_col].head(1).to_numpy().dtype - if not ( - np.issubdtype(times_dtype, np.datetime64) - or np.issubdtype(times_dtype, np.integer) - ): + if not _is_dt_or_int(df[time_col]): + times_dtype = df[time_col].head(1).to_numpy().dtype raise ValueError( f"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'." ) From ef1011fd5780744e17bbfa32679cffbae5ed8bbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Thu, 9 Nov 2023 20:11:52 -0600 Subject: [PATCH 2/2] handle pandas nullable data types in target check --- nbs/validation.ipynb | 14 +++++++++----- utilsforecast/validation.py | 10 +++++++--- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/nbs/validation.ipynb b/nbs/validation.ipynb index 7082648..7a78636 100644 --- a/nbs/validation.ipynb +++ b/nbs/validation.ipynb @@ -209,9 +209,13 @@ " raise ValueError(f\"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'.\")\n", "\n", " # target col\n", - " target_dtype = df[target_col].head(1).to_numpy().dtype\n", - " if not np.issubdtype(target_dtype, np.number):\n", - " raise ValueError(f\"The target column ('{target_col}') should have a numeric data type, got '{target_dtype}')\")" + " target = df[target_col]\n", + " if isinstance(target, pd.Series):\n", + " is_numeric = np.issubdtype(target.dtype.type, np.number)\n", + " else:\n", + " is_numeric = target.is_numeric()\n", + " if not is_numeric:\n", + " raise ValueError(f\"The target column ('{target_col}') should have a numeric data type, got '{target.dtype}')\")" ] }, { @@ -236,7 +240,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L13){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L57){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### validate_format\n", "\n", @@ -258,7 +262,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L13){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L57){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### validate_format\n", "\n", diff --git a/utilsforecast/validation.py b/utilsforecast/validation.py index a2fbe7f..564ef32 100644 --- a/utilsforecast/validation.py +++ b/utilsforecast/validation.py @@ -95,8 +95,12 @@ def validate_format( ) # target col - target_dtype = df[target_col].head(1).to_numpy().dtype - if not np.issubdtype(target_dtype, np.number): + target = df[target_col] + if isinstance(target, pd.Series): + is_numeric = np.issubdtype(target.dtype.type, np.number) + else: + is_numeric = target.is_numeric() + if not is_numeric: raise ValueError( - f"The target column ('{target_col}') should have a numeric data type, got '{target_dtype}')" + f"The target column ('{target_col}') should have a numeric data type, got '{target.dtype}')" )