feat(datasets): Add CSVDataset to dask module (#627)
* Add CSVDataset to dask module (Signed-off-by: Michael Sexton <[email protected]>)
* Add tests to dask.CSVDataset (Signed-off-by: Michael Sexton <[email protected]>)
* Fix formatting issues in example usage (Signed-off-by: Michael Sexton <[email protected]>)
* Fix error in example usage that is causing test to fail (Signed-off-by: Michael Sexton <[email protected]>)
* Remove arguments from example usage (Signed-off-by: Michael Sexton <[email protected]>)
* Fix issue with folder used as path for CSV file (Signed-off-by: Michael Sexton <[email protected]>)
* Change number of partitions to fix failing assertion (Signed-off-by: Michael Sexton <[email protected]>)
* Fix syntax issue (Signed-off-by: Michael Sexton <[email protected]>)
* Remove temp path (Signed-off-by: Michael Sexton <[email protected]>)
* Add default save args (Signed-off-by: Michael Sexton <[email protected]>)
* Add to documentation and release notes (Signed-off-by: Michael Sexton <[email protected]>)
* Fix lint (Signed-off-by: Merel Theisen <[email protected]>)
* Try fix netcdfdataset doctest (Signed-off-by: Merel Theisen <[email protected]>)
* Try fix netcdfdataset doctest pointing at file (Signed-off-by: Merel Theisen <[email protected]>)
* Fix moto mock_aws import (Signed-off-by: Merel Theisen <[email protected]>)
* Fix lint and test (Signed-off-by: Ankita Katiyar <[email protected]>)
* Mypy (Signed-off-by: Ankita Katiyar <[email protected]>)
* docs test (Signed-off-by: Ankita Katiyar <[email protected]>)
* docs test (Signed-off-by: Ankita Katiyar <[email protected]>)
* docs test (Signed-off-by: Ankita Katiyar <[email protected]>)
* Fix unit tests (Signed-off-by: Ankita Katiyar <[email protected]>)
* Remove extra comments (Signed-off-by: Ankita Katiyar <[email protected]>)
* Try fix test (Signed-off-by: Ankita Katiyar <[email protected]>)
* Release notes + test (Signed-off-by: Ankita Katiyar <[email protected]>)
* Suggestion from code review (Signed-off-by: Ankita Katiyar <[email protected]>)

---------

Signed-off-by: Michael Sexton <[email protected]>
Signed-off-by: Merel Theisen <[email protected]>
Signed-off-by: Merel Theisen <[email protected]>
Signed-off-by: Ankita Katiyar <[email protected]>
Signed-off-by: Ankita Katiyar <[email protected]>
Co-authored-by: Merel Theisen <[email protected]>
Co-authored-by: Merel Theisen <[email protected]>
Co-authored-by: Ankita Katiyar <[email protected]>
Co-authored-by: Ankita Katiyar <[email protected]>
1 parent bf6596a, commit 966d989
Showing 6 changed files with 296 additions and 1 deletion.
@@ -0,0 +1,125 @@
"""``CSVDataset`` is a data set used to load and save data to CSV files using Dask | ||
dataframe""" | ||
from __future__ import annotations | ||
|
||
from copy import deepcopy | ||
from typing import Any | ||
|
||
import dask.dataframe as dd | ||
import fsspec | ||
from kedro.io.core import AbstractDataset, get_protocol_and_path | ||
|
||
|
||
class CSVDataset(AbstractDataset[dd.DataFrame, dd.DataFrame]): | ||
"""``CSVDataset`` loads and saves data to comma-separated value file(s). It uses Dask | ||
remote data services to handle the corresponding load and save operations: | ||
https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html | ||
Example usage for the | ||
`YAML API <https://kedro.readthedocs.io/en/stable/data/\ | ||
data_catalog_yaml_examples.html>`_: | ||
.. code-block:: yaml | ||
cars: | ||
type: dask.CSVDataset | ||
filepath: s3://bucket_name/path/to/folder | ||
save_args: | ||
compression: GZIP | ||
credentials: | ||
client_kwargs: | ||
aws_access_key_id: YOUR_KEY | ||
aws_secret_access_key: YOUR_SECRET | ||
Example usage for the | ||
`Python API <https://kedro.readthedocs.io/en/stable/data/\ | ||
advanced_data_catalog_usage.html>`_: | ||
.. code-block:: pycon | ||
>>> from kedro_datasets.dask import CSVDataset | ||
>>> import pandas as pd | ||
>>> import numpy as np | ||
>>> import dask.dataframe as dd | ||
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [[5, 6], [7, 8]]}) | ||
>>> ddf = dd.from_pandas(data, npartitions=1) | ||
>>> dataset = CSVDataset(filepath="path/to/folder/*.csv") | ||
>>> dataset.save(ddf) | ||
>>> reloaded = dataset.load() | ||
>>> assert np.array_equal(ddf.compute(), reloaded.compute()) | ||
""" | ||
|
||
DEFAULT_LOAD_ARGS: dict[str, Any] = {} | ||
DEFAULT_SAVE_ARGS: dict[str, Any] = {"index": False} | ||
|
||
def __init__( # noqa: PLR0913 | ||
self, | ||
filepath: str, | ||
load_args: dict[str, Any] | None = None, | ||
save_args: dict[str, Any] | None = None, | ||
credentials: dict[str, Any] | None = None, | ||
fs_args: dict[str, Any] | None = None, | ||
metadata: dict[str, Any] | None = None, | ||
) -> None: | ||
"""Creates a new instance of ``CSVDataset`` pointing to concrete | ||
CSV files. | ||
Args: | ||
filepath: Filepath in POSIX format to a CSV file | ||
CSV collection or the directory of a multipart CSV. | ||
load_args: Additional loading options `dask.dataframe.read_csv`: | ||
https://docs.dask.org/en/latest/generated/dask.dataframe.read_csv.html | ||
save_args: Additional saving options for `dask.dataframe.to_csv`: | ||
https://docs.dask.org/en/latest/generated/dask.dataframe.to_csv.html | ||
credentials: Credentials required to get access to the underlying filesystem. | ||
E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. | ||
fs_args: Optional parameters to the backend file system driver: | ||
https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html#optional-parameters | ||
metadata: Any arbitrary metadata. | ||
This is ignored by Kedro, but may be consumed by users or external plugins. | ||
""" | ||
self._filepath = filepath | ||
self._fs_args = deepcopy(fs_args) or {} | ||
self._credentials = deepcopy(credentials) or {} | ||
|
||
self.metadata = metadata | ||
|
||
# Handle default load and save arguments | ||
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) | ||
if load_args is not None: | ||
self._load_args.update(load_args) | ||
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) | ||
if save_args is not None: | ||
self._save_args.update(save_args) | ||
|
||
@property | ||
def fs_args(self) -> dict[str, Any]: | ||
"""Property of optional file system parameters. | ||
Returns: | ||
A dictionary of backend file system parameters, including credentials. | ||
""" | ||
fs_args = deepcopy(self._fs_args) | ||
fs_args.update(self._credentials) | ||
return fs_args | ||
|
||
def _describe(self) -> dict[str, Any]: | ||
return { | ||
"filepath": self._filepath, | ||
"load_args": self._load_args, | ||
"save_args": self._save_args, | ||
} | ||
|
||
def _load(self) -> dd.DataFrame: | ||
return dd.read_csv( | ||
self._filepath, storage_options=self.fs_args, **self._load_args | ||
) | ||
|
||
def _save(self, data: dd.DataFrame) -> None: | ||
data.to_csv(self._filepath, storage_options=self.fs_args, **self._save_args) | ||
|
||
def _exists(self) -> bool: | ||
protocol = get_protocol_and_path(self._filepath)[0] | ||
file_system = fsspec.filesystem(protocol=protocol, **self.fs_args) | ||
files = file_system.glob(self._filepath) | ||
return bool(files) |
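The docstring above already covers the YAML and Python APIs. As a hedged illustration of two implementation details, the `*.csv` glob in `filepath` that Dask expands to one file per partition, and the `fs_args` property that merges `credentials` into the `storage_options` handed to `dd.read_csv`/`dd.to_csv`, here is a minimal sketch; the local path, bucket name, and keys are placeholders and not part of this commit.

# Minimal usage sketch (not part of this commit); paths and keys are placeholders.
import dask.dataframe as dd
import numpy as np
import pandas as pd

from kedro_datasets.dask import CSVDataset

# Local round trip: "*" in the filepath becomes one CSV per partition on save
# and is globbed again on load.
ddf = dd.from_pandas(pd.DataFrame({"a": [1, 2], "b": [3, 4]}), npartitions=1)
local_ds = CSVDataset(filepath="data/example/*.csv")
local_ds.save(ddf)
assert local_ds.exists()  # _exists() globs the filepath on the target filesystem
assert np.array_equal(local_ds.load().compute(), ddf.compute())

# credentials and fs_args are merged by the fs_args property and passed to Dask
# as storage_options; nothing touches the bucket until load/save is called.
remote_ds = CSVDataset(
    filepath="s3://my-bucket/cars/*.csv",  # hypothetical bucket
    credentials={"key": "YOUR_KEY", "secret": "YOUR_SECRET"},
    fs_args={"anon": False},
)
assert remote_ds.fs_args == {"anon": False, "key": "YOUR_KEY", "secret": "YOUR_SECRET"}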
@@ -0,0 +1,158 @@
import boto3
import dask.dataframe as dd
import numpy as np
import pandas as pd
import pytest
from kedro.io.core import DatasetError
from moto import mock_aws
from s3fs import S3FileSystem

from kedro_datasets.dask import CSVDataset

FILE_NAME = "*.csv"
BUCKET_NAME = "test_bucket"
AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"}

# Pathlib cannot be used since it strips out the second slash from "s3://"
S3_PATH = f"s3://{BUCKET_NAME}/{FILE_NAME}"


@pytest.fixture
def mocked_s3_bucket():
    """Create a bucket for testing using moto."""
    with mock_aws():
        conn = boto3.client(
            "s3",
            aws_access_key_id="fake_access_key",
            aws_secret_access_key="fake_secret_key",
        )
        conn.create_bucket(Bucket=BUCKET_NAME)
        yield conn


@pytest.fixture
def dummy_dd_dataframe() -> dd.DataFrame:
    df = pd.DataFrame(
        {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]}
    )
    return dd.from_pandas(df, npartitions=1)


@pytest.fixture
def mocked_s3_object(tmp_path, mocked_s3_bucket, dummy_dd_dataframe: dd.DataFrame):
    """Creates test data and adds it to mocked S3 bucket."""
    pandas_df = dummy_dd_dataframe.compute()
    temporary_path = tmp_path / "test.csv"
    pandas_df.to_csv(str(temporary_path))

    mocked_s3_bucket.put_object(
        Bucket=BUCKET_NAME, Key=FILE_NAME, Body=temporary_path.read_bytes()
    )
    return mocked_s3_bucket


@pytest.fixture
def s3_dataset(load_args, save_args):
    return CSVDataset(
        filepath=S3_PATH,
        credentials=AWS_CREDENTIALS,
        load_args=load_args,
        save_args=save_args,
    )


@pytest.fixture()
def s3fs_cleanup():
    # clear cache so we get a clean slate every time we instantiate a S3FileSystem
    yield
    S3FileSystem.cachable = False


@pytest.mark.usefixtures("s3fs_cleanup")
class TestCSVDataset:
    def test_incorrect_credentials_load(self):
        """Test that incorrect credential keys won't instantiate dataset."""
        pattern = r"unexpected keyword argument"
        with pytest.raises(DatasetError, match=pattern):
            CSVDataset(
                filepath=S3_PATH,
                credentials={
                    "client_kwargs": {"access_token": "TOKEN", "access_key": "KEY"}
                },
            ).load().compute()

    @pytest.mark.parametrize("bad_credentials", [{"key": None, "secret": None}])
    def test_empty_credentials_load(self, bad_credentials):
        csv_dataset = CSVDataset(filepath=S3_PATH, credentials=bad_credentials)
        pattern = r"Failed while loading data from data set CSVDataset\(.+\)"
        with pytest.raises(DatasetError, match=pattern):
            csv_dataset.load().compute()

    @pytest.mark.xfail
    def test_pass_credentials(self, mocker):
        """Test that AWS credentials are passed successfully into boto3
        client instantiation on creating S3 connection."""
        client_mock = mocker.patch("botocore.session.Session.create_client")
        s3_dataset = CSVDataset(filepath=S3_PATH, credentials=AWS_CREDENTIALS)
        pattern = r"Failed while loading data from data set CSVDataset\(.+\)"
        with pytest.raises(DatasetError, match=pattern):
            s3_dataset.load().compute()

        assert client_mock.call_count == 1
        args, kwargs = client_mock.call_args_list[0]
        assert args == ("s3",)
        assert kwargs["aws_access_key_id"] == AWS_CREDENTIALS["key"]
        assert kwargs["aws_secret_access_key"] == AWS_CREDENTIALS["secret"]

    def test_save_data(self, s3_dataset, mocked_s3_bucket):
        """Test saving the data to S3."""
        pd_data = pd.DataFrame(
            {"col1": ["a", "b"], "col2": ["c", "d"], "col3": ["e", "f"]}
        )
        dd_data = dd.from_pandas(pd_data, npartitions=1)
        s3_dataset.save(dd_data)
        loaded_data = s3_dataset.load()
        np.array_equal(loaded_data.compute(), dd_data.compute())

    def test_load_data(self, s3_dataset, dummy_dd_dataframe, mocked_s3_object):
        """Test loading the data from S3."""
        loaded_data = s3_dataset.load()
        np.array_equal(loaded_data, dummy_dd_dataframe.compute())

    def test_exists(self, s3_dataset, dummy_dd_dataframe, mocked_s3_bucket):
        """Test `exists` method invocation for both existing and
        nonexistent data set."""
        assert not s3_dataset.exists()
        s3_dataset.save(dummy_dd_dataframe)
        assert s3_dataset.exists()

    def test_save_load_locally(self, tmp_path, dummy_dd_dataframe):
        """Test loading the data locally."""
        file_path = str(tmp_path / "some" / "dir" / FILE_NAME)
        dataset = CSVDataset(filepath=file_path)

        assert not dataset.exists()
        dataset.save(dummy_dd_dataframe)
        assert dataset.exists()
        loaded_data = dataset.load()
        dummy_dd_dataframe.compute().equals(loaded_data.compute())

    @pytest.mark.parametrize(
        "load_args", [{"k1": "v1", "index": "value"}], indirect=True
    )
    def test_load_extra_params(self, s3_dataset, load_args):
        """Test overriding the default load arguments."""
        for key, value in load_args.items():
            assert s3_dataset._load_args[key] == value

    @pytest.mark.parametrize(
        "save_args", [{"k1": "v1", "index": "value"}], indirect=True
    )
    def test_save_extra_params(self, s3_dataset, save_args):
        """Test overriding the default save arguments."""
        for key, value in save_args.items():
            assert s3_dataset._save_args[key] == value

        for key, value in s3_dataset.DEFAULT_SAVE_ARGS.items():
            assert s3_dataset._save_args[key] != value
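The `s3_dataset` fixture and the `indirect=True` parametrizations above rely on `load_args` and `save_args` fixtures that are not part of this diff; they presumably come from a shared conftest.py. A minimal sketch of what such fixtures typically look like, included only as an assumption so the tests above read as self-contained:

# Hypothetical conftest.py fixtures assumed by the tests above (not in this diff).
import pytest


@pytest.fixture(params=[None])
def load_args(request):
    # None by default; pytest.mark.parametrize(..., indirect=True) replaces it
    # with the parametrized dict, which the s3_dataset fixture passes through.
    return request.param


@pytest.fixture(params=[None])
def save_args(request):
    return request.param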