Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(datasets): verify file exists if on Polars 1.0 #957

Merged
merged 9 commits into from
Dec 6, 2024
10 changes: 8 additions & 2 deletions kedro-datasets/kedro_datasets/polars/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,21 @@ class CSVDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]):

.. code-block:: pycon

>>> from kedro_datasets.polars import CSVDataset
>>> import sys
>>>
>>> import polars as pl
>>> import pytest
>>> from kedro_datasets.polars import CSVDataset
>>>
>>> if sys.platform.startswith("win"):
... pytest.skip("this doctest fails on Windows CI runner")
...
>>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = CSVDataset(filepath=tmp_path / "test.csv")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.frame_equal(reloaded)
>>> assert data.equals(reloaded)

"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class EagerPolarsDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]):
>>> dataset = EagerPolarsDataset(filepath=tmp_path / "test.parquet", file_format="parquet")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.frame_equal(reloaded)
>>> assert data.equals(reloaded)

"""

Expand Down
6 changes: 5 additions & 1 deletion kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"""
from __future__ import annotations

import errno
import logging
import os
from copy import deepcopy
from pathlib import PurePosixPath
from typing import Any, ClassVar
Expand Down Expand Up @@ -69,7 +71,7 @@ class LazyPolarsDataset(
>>> dataset = LazyPolarsDataset(filepath=tmp_path / "test.csv", file_format="csv")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.frame_equal(reloaded.collect())
>>> assert data.equals(reloaded.collect())

"""

Expand Down Expand Up @@ -199,6 +201,8 @@ def _describe(self) -> dict[str, Any]:

def load(self) -> pl.LazyFrame:
load_path = str(self._get_load_path())
if not self._exists():
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), load_path)

if self._protocol == "file":
# With local filesystems, we can use Polar's build-in I/O method:
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ test = [
"pandas>=2.0",
"Pillow~=10.0",
"plotly>=4.8.0, <6.0",
"polars[xlsx2csv, deltalake]~=0.18.0",
"polars[deltalake,xlsx2csv]>=1.0",
"pyarrow>=1.0; python_version < '3.11'",
"pyarrow>=7.0; python_version >= '3.11'", # Adding to avoid numpy build errors
"pyodbc~=5.0",
Expand Down
10 changes: 10 additions & 0 deletions kedro-datasets/tests/polars/test_csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,14 @@ def mocked_csv_in_s3(mocked_s3_bucket, mocked_dataframe: pl.DataFrame):


class TestCSVDataset:
@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_save_and_load(self, csv_dataset, dummy_dataframe):
"""Test saving and reloading the dataset."""
csv_dataset.save(dummy_dataframe)
reloaded = csv_dataset.load()
assert_frame_equal(dummy_dataframe, reloaded)

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_exists(self, csv_dataset, dummy_dataframe):
"""Test `exists` method invocation for both existing and
nonexistent dataset."""
Expand Down Expand Up @@ -202,13 +204,15 @@ def test_version_str_repr(self, load_version, save_version):
assert "load_args={'rechunk': True}" in str(ds)
assert "load_args={'rechunk': True}" in str(ds_versioned)

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_save_and_load(self, versioned_csv_dataset, dummy_dataframe):
"""Test that saved and reloaded data matches the original one for
the versioned dataset."""
versioned_csv_dataset.save(dummy_dataframe)
reloaded_df = versioned_csv_dataset.load()
assert_frame_equal(dummy_dataframe, reloaded_df)

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_csv):
"""Test that if a new version is created mid-run, by an
external system, it won't be loaded in the current run."""
Expand All @@ -232,6 +236,7 @@ def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_c
ds_new.resolve_load_version() == v_new
) # new version is discoverable by a new instance

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_multiple_saves(self, dummy_dataframe, filepath_csv):
"""Test multiple cycles of save followed by load for the same dataset"""
ds_versioned = CSVDataset(filepath=filepath_csv, version=Version(None, None))
Expand All @@ -254,6 +259,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv):
ds_new = CSVDataset(filepath=filepath_csv, version=Version(None, None))
assert ds_new.resolve_load_version() == second_load_version

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_release_instance_cache(self, dummy_dataframe, filepath_csv):
"""Test that cache invalidation does not affect other instances"""
ds_a = CSVDataset(filepath=filepath_csv, version=Version(None, None))
Expand Down Expand Up @@ -282,12 +288,14 @@ def test_no_versions(self, versioned_csv_dataset):
with pytest.raises(DatasetError, match=pattern):
versioned_csv_dataset.load()

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_exists(self, versioned_csv_dataset, dummy_dataframe):
"""Test `exists` method invocation for versioned dataset."""
assert not versioned_csv_dataset.exists()
versioned_csv_dataset.save(dummy_dataframe)
assert versioned_csv_dataset.exists()

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe):
"""Check the error when attempting to override the dataset if the
corresponding CSV file for a given save version already exists."""
Expand All @@ -299,6 +307,7 @@ def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe):
with pytest.raises(DatasetError, match=pattern):
versioned_csv_dataset.save(dummy_dataframe)

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
@pytest.mark.parametrize(
"load_version", ["2019-01-01T23.59.59.999Z"], indirect=True
)
Expand All @@ -325,6 +334,7 @@ def test_http_filesystem_no_versioning(self):
filepath="https://example.com/file.csv", version=Version(None, None)
)

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_versioning_existing_dataset(
self, csv_dataset, versioned_csv_dataset, dummy_dataframe
):
Expand Down
1 change: 1 addition & 0 deletions kedro-datasets/tests/polars/test_eager_polars_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def excel_dataset(dummy_dataframe: pl.DataFrame, filepath_excel):
return EagerPolarsDataset(
filepath=filepath_excel.as_posix(),
file_format="excel",
load_args={"engine": "xlsx2csv"},
)


Expand Down
Loading