Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ensure_time_dtype and handle pandas nullable dtypes in validate_format #38

Merged
merged 2 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 141 additions & 12 deletions nbs/validation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,135 @@
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from utilsforecast.compat import DataFrame, pl_DataFrame"
"from utilsforecast.compat import DataFrame, Series, pl_DataFrame, pl_Series, pl"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7638f2a6-0303-41a9-a103-d3055e7bc572",
"metadata": {},
"outputs": [],
"source": [
"from fastcore.test import test_fail"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e001379f-a552-4fd9-ac82-3f4d5c76f555",
"metadata": {},
"outputs": [],
"source": [
"#| polars\n",
"import polars.testing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98e0f72a-3128-46d7-8d6c-cbd0c62ce68a",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def _is_dt_or_int(s: Series) -> bool:\n",
" dtype = s.head(1).to_numpy().dtype\n",
" is_dt = np.issubdtype(dtype, np.datetime64)\n",
" is_int = np.issubdtype(dtype, np.integer)\n",
" return is_dt or is_int"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6679513a-4f2b-4eda-9970-e5c6725dd761",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def ensure_shallow_copy(df: pd.DataFrame) -> pd.DataFrame:\n",
" from packaging.version import Version\n",
"\n",
" if Version(pd.__version__) < Version(\"1.4\"):\n",
" # https://github.com/pandas-dev/pandas/pull/43406\n",
" df = df.copy()\n",
" return df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ccb15f20-56e4-4d0d-96bc-830c5effff19",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def ensure_time_dtype(df: DataFrame, time_col: str = 'ds') -> DataFrame:\n",
" \"\"\"Make sure that `time_col` contains timestamps or integers.\n",
" If it contains strings, try to cast them as timestamps.\"\"\"\n",
" times = df[time_col]\n",
" if _is_dt_or_int(times):\n",
" return df\n",
" parse_err_msg = (\n",
" f\"Failed to parse '{time_col}' from string to datetime. \"\n",
" 'Please make sure that it contains valid timestamps or integers.'\n",
" )\n",
" if isinstance(times, pd.Series) and pd.api.types.is_object_dtype(times):\n",
" try:\n",
" times = pd.to_datetime(times)\n",
" except ValueError:\n",
" raise ValueError(parse_err_msg)\n",
" df = ensure_shallow_copy(df.copy(deep=False))\n",
" df[time_col] = times\n",
" elif isinstance(times, pl_Series) and times.dtype == pl.Utf8:\n",
" try:\n",
" times = times.str.to_datetime()\n",
" except pl.exceptions.ComputeError:\n",
" raise ValueError(parse_err_msg)\n",
" df = df.with_columns(times)\n",
" else:\n",
" raise ValueError(f\"'{time_col}' should have valid timestamps or integers.\")\n",
" return df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "604ec21b-314f-42cb-9ee0-5950fd611b97",
"metadata": {},
"outputs": [],
"source": [
"pd.testing.assert_frame_equal(\n",
" ensure_time_dtype(pd.DataFrame({'ds': ['2000-01-01']})),\n",
" pd.DataFrame({'ds': pd.to_datetime(['2000-01-01'])})\n",
")\n",
"df = pd.DataFrame({'ds': [1, 2]})\n",
"assert df is ensure_time_dtype(df)\n",
"test_fail(\n",
" lambda: ensure_time_dtype(pd.DataFrame({'ds': ['2000-14-14']})),\n",
" contains='Please make sure that it contains valid timestamps',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5335d217-e240-46df-90a0-6aeeb07a0586",
"metadata": {},
"outputs": [],
"source": [
"#| polars\n",
"pl.testing.assert_frame_equal(\n",
" ensure_time_dtype(pl.DataFrame({'ds': ['2000-01-01']})),\n",
" pl.DataFrame().with_columns(ds=pl.datetime(2000, 1, 1))\n",
")\n",
"df = pl.DataFrame({'ds': [1, 2]})\n",
"assert df is ensure_time_dtype(df)\n",
"test_fail(\n",
" lambda: ensure_time_dtype(pl.DataFrame({'ds': ['hello']})),\n",
" contains='Please make sure that it contains valid timestamps',\n",
")"
]
},
{
Expand Down Expand Up @@ -76,14 +204,18 @@
" raise ValueError(f\"The following columns are missing: {missing_cols}\")\n",
"\n",
" # time col\n",
" times_dtype = df[time_col].head(1).to_numpy().dtype\n",
" if not (np.issubdtype(times_dtype, np.datetime64) or np.issubdtype(times_dtype, np.integer)):\n",
" if not _is_dt_or_int(df[time_col]):\n",
" times_dtype = df[time_col].head(1).to_numpy().dtype\n",
" raise ValueError(f\"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'.\")\n",
"\n",
" # target col\n",
" target_dtype = df[target_col].head(1).to_numpy().dtype\n",
" if not np.issubdtype(target_dtype, np.number):\n",
" raise ValueError(f\"The target column ('{target_col}') should have a numeric data type, got '{target_dtype}')\")"
" target = df[target_col]\n",
" if isinstance(target, pd.Series):\n",
" is_numeric = np.issubdtype(target.dtype.type, np.number)\n",
" else:\n",
" is_numeric = target.is_numeric()\n",
" if not is_numeric:\n",
" raise ValueError(f\"The target column ('{target_col}') should have a numeric data type, got '{target.dtype}')\")"
]
},
{
Expand All @@ -108,7 +240,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L12){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L57){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### validate_format\n",
"\n",
Expand All @@ -130,7 +262,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L12){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L57){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### validate_format\n",
"\n",
Expand Down Expand Up @@ -168,10 +300,7 @@
"source": [
"import datetime\n",
"\n",
"import pandas as pd\n",
"from fastcore.test import test_fail\n",
"\n",
"from utilsforecast.compat import POLARS_INSTALLED, pl\n",
"from utilsforecast.compat import POLARS_INSTALLED\n",
"from utilsforecast.data import generate_series"
]
},
Expand Down
2 changes: 1 addition & 1 deletion settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ keywords = time-series analysis forecasting
language = English
status = 3
user = Nixtla
requirements = numpy pandas>=1.1.1
requirements = numpy packaging pandas>=1.1.1
plotting_requirements = matplotlib plotly plotly-resampler
dev_requirements = matplotlib numba plotly polars pyarrow
readme_nb = index.ipynb
Expand Down
8 changes: 7 additions & 1 deletion utilsforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,11 @@
'utilsforecast/target_transforms.py'),
'utilsforecast.target_transforms._transform': ( 'target_transforms.html#_transform',
'utilsforecast/target_transforms.py')},
'utilsforecast.validation': { 'utilsforecast.validation.validate_format': ( 'validation.html#validate_format',
'utilsforecast.validation': { 'utilsforecast.validation._is_dt_or_int': ( 'validation.html#_is_dt_or_int',
'utilsforecast/validation.py'),
'utilsforecast.validation.ensure_shallow_copy': ( 'validation.html#ensure_shallow_copy',
'utilsforecast/validation.py'),
'utilsforecast.validation.ensure_time_dtype': ( 'validation.html#ensure_time_dtype',
'utilsforecast/validation.py'),
'utilsforecast.validation.validate_format': ( 'validation.html#validate_format',
'utilsforecast/validation.py')}}}
67 changes: 56 additions & 11 deletions utilsforecast/validation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,59 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/validation.ipynb.

# %% auto 0
__all__ = ['validate_format']
__all__ = ['ensure_shallow_copy', 'ensure_time_dtype', 'validate_format']

# %% ../nbs/validation.ipynb 2
import numpy as np
import pandas as pd

from .compat import DataFrame, pl_DataFrame
from .compat import DataFrame, Series, pl_DataFrame, pl_Series, pl

# %% ../nbs/validation.ipynb 3
# %% ../nbs/validation.ipynb 5
def _is_dt_or_int(s: Series) -> bool:
    """Return True when `s` holds datetime64 or integer values.

    The dtype is taken from a one-row numpy view of the series, which works
    uniformly for both pandas and polars inputs.
    """
    # Only materialize the first element; converting the whole column to
    # numpy could be expensive for large frames.
    sample_dtype = s.head(1).to_numpy().dtype
    return np.issubdtype(sample_dtype, np.datetime64) or np.issubdtype(
        sample_dtype, np.integer
    )

# %% ../nbs/validation.ipynb 6
def ensure_shallow_copy(df: pd.DataFrame) -> pd.DataFrame:
    """Return a frame that is safe to assign new columns to.

    On pandas < 1.4 assigning a column to a shallow copy could write through
    to the parent frame (https://github.com/pandas-dev/pandas/pull/43406),
    so on those versions a real copy is returned instead; newer versions get
    the input back unchanged.
    """
    from packaging.version import Version

    if Version(pd.__version__) >= Version("1.4"):
        return df
    # Old pandas: take an actual copy to avoid mutating the caller's frame.
    return df.copy()

# %% ../nbs/validation.ipynb 7
def ensure_time_dtype(df: DataFrame, time_col: str = "ds") -> DataFrame:
    """Make sure that `time_col` contains timestamps or integers.

    If the column holds strings, try to parse them as timestamps and return
    a frame with the converted column; otherwise raise ``ValueError``.
    The input frame is never mutated in place.
    """
    times = df[time_col]
    if _is_dt_or_int(times):
        # Already an accepted dtype, hand the frame back untouched.
        return df
    parse_err_msg = (
        f"Failed to parse '{time_col}' from string to datetime. "
        "Please make sure that it contains valid timestamps or integers."
    )
    if isinstance(times, pd.Series) and pd.api.types.is_object_dtype(times):
        # pandas path: object dtype may hold strings, try to parse them.
        try:
            parsed = pd.to_datetime(times)
        except ValueError:
            raise ValueError(parse_err_msg)
        # Shallow-copy first so the caller's frame keeps its original column.
        df = ensure_shallow_copy(df.copy(deep=False))
        df[time_col] = parsed
        return df
    if isinstance(times, pl_Series) and times.dtype == pl.Utf8:
        # polars path: parse the string column and swap it into a new frame.
        try:
            parsed = times.str.to_datetime()
        except pl.exceptions.ComputeError:
            raise ValueError(parse_err_msg)
        return df.with_columns(parsed)
    raise ValueError(f"'{time_col}' should have valid timestamps or integers.")

# %% ../nbs/validation.ipynb 10
def validate_format(
df: DataFrame,
id_col: str = "unique_id",
Expand Down Expand Up @@ -44,18 +88,19 @@ def validate_format(
raise ValueError(f"The following columns are missing: {missing_cols}")

# time col
times_dtype = df[time_col].head(1).to_numpy().dtype
if not (
np.issubdtype(times_dtype, np.datetime64)
or np.issubdtype(times_dtype, np.integer)
):
if not _is_dt_or_int(df[time_col]):
times_dtype = df[time_col].head(1).to_numpy().dtype
raise ValueError(
f"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'."
)

# target col
target_dtype = df[target_col].head(1).to_numpy().dtype
if not np.issubdtype(target_dtype, np.number):
target = df[target_col]
if isinstance(target, pd.Series):
is_numeric = np.issubdtype(target.dtype.type, np.number)
else:
is_numeric = target.is_numeric()
if not is_numeric:
raise ValueError(
f"The target column ('{target_col}') should have a numeric data type, got '{target_dtype}')"
f"The target column ('{target_col}') should have a numeric data type, got '{target.dtype}')"
)