feat(datasets): support versioning data partitions (#447)
* feat(datasets): support versioning data partitions

Signed-off-by: Deepyaman Datta <[email protected]>

* Remove unused import

Signed-off-by: Deepyaman Datta <[email protected]>

* chore(datasets): use keyword arguments when needed

Signed-off-by: Deepyaman Datta <[email protected]>

* Apply suggestions from code review

Signed-off-by: Deepyaman Datta <[email protected]>

* Update kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py

Signed-off-by: Deepyaman Datta <[email protected]>

---------

Signed-off-by: Deepyaman Datta <[email protected]>
deepyaman authored Dec 11, 2023
1 parent 0923bc5 commit 28598ab
Showing 5 changed files with 167 additions and 27 deletions.
7 changes: 4 additions & 3 deletions kedro-datasets/RELEASE.md
@@ -1,9 +1,10 @@
# Upcoming Release

## Major features and improvements
-* Removed support for Python 3.7 and 3.8
-* Spark and Databricks based datasets now support [databricks-connect>=13.0](https://docs.databricks.com/en/dev-tools/databricks-connect-ref.html)
-* Bump `s3fs` to latest calendar versioning.
+* Removed support for Python 3.7 and 3.8.
+* Spark and Databricks based datasets now support [databricks-connect>=13.0](https://docs.databricks.com/en/dev-tools/databricks-connect-ref.html).
+* Bump `s3fs` to latest calendar-versioned release.
+* `PartitionedDataset` and `IncrementalDataset` now both support versioning of the underlying dataset.

## Bug fixes and other changes
* Fixed bug with loading models saved with `TensorFlowModelDataset`.
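The headline entry here is partition versioning. A minimal sketch of what the new flag enables, assuming a local folder of CSV partitions (the path and data are illustrative, not from this diff):

import pandas as pd

from kedro_datasets.partitions import PartitionedDataset

# `versioned: True` on the wrapped dataset definition is now accepted;
# each partition is written to <path>/<partition>/<version>/<filename>,
# where <version> is a generated timestamp.
pds = PartitionedDataset(
    path="data/01_raw/partitions",  # hypothetical location
    dataset={"type": "pandas.CSVDataset", "versioned": True},
)
pds.save({"p1": pd.DataFrame({"a": [1, 2]})})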
15 changes: 10 additions & 5 deletions kedro-datasets/kedro_datasets/partitions/incremental_dataset.py
@@ -22,7 +22,11 @@
from kedro.io.data_catalog import CREDENTIALS_KEY
from kedro.utils import load_obj

-from .partitioned_dataset import KEY_PROPAGATION_WARNING, PartitionedDataset
+from .partitioned_dataset import (
+    KEY_PROPAGATION_WARNING,
+    PartitionedDataset,
+    _grandparent,
+)


class IncrementalDataset(PartitionedDataset):
@@ -126,7 +130,7 @@ def __init__(  # noqa: PLR0913
                This is ignored by Kedro, but may be consumed by users or external plugins.

        Raises:
-            DatasetError: If versioning is enabled for the underlying dataset.
+            DatasetError: If versioning is enabled for the checkpoint dataset.
        """

        super().__init__(
@@ -186,6 +190,7 @@ def _list_partitions(self) -> list[str]:
        checkpoint_path = self._filesystem._strip_protocol(
            self._checkpoint_config[self._filepath_arg]
        )
+        dataset_is_versioned = VERSION_KEY in self._dataset_config

        def _is_valid_partition(partition) -> bool:
            if not partition.endswith(self._filename_suffix):
@@ -199,9 +204,9 @@ def _is_valid_partition(partition) -> bool:
            return self._comparison_func(partition_id, checkpoint)

        return sorted(
-            part
-            for part in self._filesystem.find(self._normalized_path, **self._load_args)
-            if _is_valid_partition(part)
+            _grandparent(path) if dataset_is_versioned else path
+            for path in self._filesystem.find(self._normalized_path, **self._load_args)
+            if _is_valid_partition(path)
        )

    @property
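Note the docstring change: the `DatasetError` is now reserved for a versioned checkpoint, while the wrapped dataset may be versioned; `_list_partitions` collapses each versioned hit back to its partition id via `_grandparent` before the checkpoint comparison. A rough usage sketch (path illustrative, not part of this diff):

from kedro_datasets.partitions import IncrementalDataset

ds = IncrementalDataset(
    path="data/01_raw/increments",  # hypothetical location
    dataset={"type": "pandas.CSVDataset", "versioned": True},
)
partitions = ds.load()  # only partitions newer than the stored checkpoint
ds.confirm()            # advance the checkpoint past the partitions just loaded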
26 changes: 17 additions & 9 deletions kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py
@@ -5,6 +5,7 @@

import operator
from copy import deepcopy
+from pathlib import PurePosixPath
from typing import Any, Callable
from urllib.parse import urlparse
from warnings import warn
@@ -13,7 +14,6 @@
from cachetools import Cache, cachedmethod
from kedro.io.core import (
    VERSION_KEY,
-    VERSIONED_FLAG_KEY,
    AbstractDataset,
    DatasetError,
    parse_dataset_definition,
@@ -28,6 +28,19 @@
S3_PROTOCOLS = ("s3", "s3a", "s3n")


+def _grandparent(path: str) -> str:
+    """Check and return the logical parent of the parent of the path."""
+    path_obj = PurePosixPath(path)
+    grandparent = path_obj.parents[1]
+    if grandparent.name != path_obj.name:
+        last_three_parts = path_obj.relative_to(*path_obj.parts[:-3])
+        raise DatasetError(
+            f"`{path}` is not a well-formed versioned path ending with "
+            f"`filename/timestamp/filename` (got `{last_three_parts}`)."
+        )
+    return str(grandparent)
+
+
class PartitionedDataset(AbstractDataset[dict[str, Any], dict[str, Callable[[], Any]]]):
    """``PartitionedDataset`` loads and saves partitioned file-like data using the
    underlying dataset definition. For filesystem level operations it uses `fsspec`:
@@ -174,7 +187,7 @@ def __init__(  # noqa: PLR0913
            load_args: Keyword arguments to be passed into ``find()`` method of
                the filesystem implementation.
            fs_args: Extra arguments to pass into underlying filesystem class constructor
-                (e.g. `{"project": "my-project"}` for ``GCSFileSystem``)
+                (e.g. `{"project": "my-project"}` for ``GCSFileSystem``).
            overwrite: If True, any existing partitions will be removed.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.
@@ -195,12 +208,6 @@

        dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
        self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
-        if VERSION_KEY in self._dataset_config:
-            raise DatasetError(
-                f"'{self.__class__.__name__}' does not support versioning of the "
-                f"underlying dataset. Please remove '{VERSIONED_FLAG_KEY}' flag from "
-                f"the dataset definition."
-            )

        if credentials:
            if CREDENTIALS_KEY in self._dataset_config:
@@ -248,8 +255,9 @@ def _normalized_path(self) -> str:

    @cachedmethod(cache=operator.attrgetter("_partition_cache"))
    def _list_partitions(self) -> list[str]:
+        dataset_is_versioned = VERSION_KEY in self._dataset_config
        return [
-            path
+            _grandparent(path) if dataset_is_versioned else path
            for path in self._filesystem.find(self._normalized_path, **self._load_args)
            if path.endswith(self._filename_suffix)
        ]
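The core of the change is the `_grandparent` helper: a versioned dataset saves to `<partition>/<version>/<filename>`, so `find()` now yields leaf files that must be mapped back to partition paths. A quick illustration (example paths made up):

from kedro_datasets.partitions.partitioned_dataset import _grandparent

print(_grandparent("bucket/folder/p1.csv/2020-01-01T00.00.00.000Z/p1.csv"))
# -> bucket/folder/p1.csv

# Anything not shaped like filename/timestamp/filename raises DatasetError:
_grandparent("folder/new/partition/version/partition/file")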
68 changes: 67 additions & 1 deletion kedro-datasets/tests/partitions/test_incremental_dataset.py
@@ -269,13 +269,79 @@ def test_checkpoint_type(
            ),
        ],
    )
-    def test_version_not_allowed(self, tmp_path, checkpoint_config, error_pattern):
+    def test_checkpoint_versioning_not_allowed(
+        self, tmp_path, checkpoint_config, error_pattern
+    ):
        """Test that invalid checkpoint configurations raise expected errors"""
        with pytest.raises(DatasetError, match=re.escape(error_pattern)):
            IncrementalDataset(
                path=str(tmp_path), dataset=DATASET, checkpoint=checkpoint_config
            )

+    @pytest.mark.parametrize("dataset_config", [{"type": DATASET, "versioned": True}])
+    @pytest.mark.parametrize(
+        "suffix,expected_num_parts", [("", 5), (".csv", 5), ("bad", 0)]
+    )
+    def test_versioned_dataset_save_and_load(
+        self,
+        mocker,
+        tmp_path,
+        partitioned_data_pandas,
+        dataset_config,
+        suffix,
+        expected_num_parts,
+    ):
+        """Test that saved and reloaded data matches the original one for
+        the versioned dataset."""
+        save_version = "2020-01-01T00.00.00.000Z"
+        mock_ts = mocker.patch(
+            "kedro.io.core.generate_timestamp", return_value=save_version
+        )
+        IncrementalDataset(path=str(tmp_path), dataset=dataset_config).save(
+            partitioned_data_pandas
+        )
+        mock_ts.assert_called_once()
+
+        dataset = IncrementalDataset(
+            path=str(tmp_path), dataset=dataset_config, filename_suffix=suffix
+        )
+        loaded_partitions = dataset.load()
+
+        assert len(loaded_partitions) == expected_num_parts
+
+        actual_save_versions = set()
+        for part in loaded_partitions:
+            partition_dir = tmp_path / (part + suffix)
+            actual_save_versions |= {each.name for each in partition_dir.iterdir()}
+            assert partition_dir.is_dir()
+            assert_frame_equal(
+                loaded_partitions[part], partitioned_data_pandas[part + suffix]
+            )
+
+        if expected_num_parts:
+            # all partitions were saved using the same version string
+            assert actual_save_versions == {save_version}
+
+    def test_malformed_versioned_path(self, tmp_path):
+        local_dir = tmp_path / "files"
+        local_dir.mkdir()
+
+        path = local_dir / "path/to/folder/new/partition/version/partition/file"
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text("content")
+
+        dataset = IncrementalDataset(
+            path=str(local_dir / "path/to/folder"),
+            dataset={"type": "pandas.CSVDataset", "versioned": True},
+        )
+
+        pattern = re.escape(
+            f"`{path.as_posix()}` is not a well-formed versioned path ending with "
+            f"`filename/timestamp/filename` (got `version/partition/file`)."
+        )
+        with pytest.raises(DatasetError, match=pattern):
+            dataset.load()
+
    @pytest.mark.parametrize(
        "pds_config,fs_creds,dataset_creds,checkpoint_creds",
        [
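The version-pinning trick in these tests generalises beyond pytest-mock: patch `kedro.io.core.generate_timestamp` and every partition saved in the block shares one version directory. A standalone sketch with `unittest.mock`, equivalent to the `mocker.patch` call above (path and data illustrative):

from unittest.mock import patch

import pandas as pd

from kedro_datasets.partitions import PartitionedDataset

with patch("kedro.io.core.generate_timestamp", return_value="2020-01-01T00.00.00.000Z"):
    PartitionedDataset(
        path="data/tmp_parts",  # hypothetical location
        dataset={"type": "pandas.CSVDataset", "versioned": True},
    ).save({"p1": pd.DataFrame({"a": [1]})})
# resulting layout: data/tmp_parts/p1/2020-01-01T00.00.00.000Z/p1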
78 changes: 69 additions & 9 deletions kedro-datasets/tests/partitions/test_partitioned_dataset.py
@@ -28,7 +28,7 @@ def partitioned_data_pandas():

@pytest.fixture
def local_csvs(tmp_path, partitioned_data_pandas):
-    local_dir = Path(str(tmp_path / "csvs"))
+    local_dir = tmp_path / "csvs"
    local_dir.mkdir()

    for k, data in partitioned_data_pandas.items():
@@ -38,6 +38,11 @@ def local_csvs(tmp_path, partitioned_data_pandas):
    return local_dir


+@pytest.fixture
+def filepath_csvs(tmp_path):
+    return str(tmp_path / "csvs")
+
+
LOCAL_DATASET_DEFINITION = [
    "pandas.CSVDataset",
    "kedro_datasets.pandas.CSVDataset",
@@ -291,17 +296,72 @@ def test_invalid_dataset_config(self, dataset_config, error_pattern):
    @pytest.mark.parametrize(
        "dataset_config",
        [
-            {"type": CSVDataset, "versioned": True},
-            {"type": "pandas.CSVDataset", "versioned": True},
+            {**ds_config, "versioned": True}
+            for ds_config in LOCAL_DATASET_DEFINITION
+            if isinstance(ds_config, dict)
        ],
    )
-    def test_versioned_dataset_not_allowed(self, dataset_config):
-        pattern = (
-            "'PartitionedDataset' does not support versioning of the underlying "
-            "dataset. Please remove 'versioned' flag from the dataset definition."
-        )
-        with pytest.raises(DatasetError, match=re.escape(pattern)):
-            PartitionedDataset(path=str(Path.cwd()), dataset=dataset_config)
+    @pytest.mark.parametrize(
+        "suffix,expected_num_parts", [("", 5), (".csv", 3), ("p4", 1)]
+    )
+    def test_versioned_dataset_save_and_load(
+        self,
+        mocker,
+        filepath_csvs,
+        dataset_config,
+        suffix,
+        expected_num_parts,
+        partitioned_data_pandas,
+    ):
+        """Test that saved and reloaded data matches the original one for
+        the versioned dataset."""
+        save_version = "2020-01-01T00.00.00.000Z"
+        mock_ts = mocker.patch(
+            "kedro.io.core.generate_timestamp", return_value=save_version
+        )
+        PartitionedDataset(path=filepath_csvs, dataset=dataset_config).save(
+            partitioned_data_pandas
+        )
+        mock_ts.assert_called_once()
+
+        pds = PartitionedDataset(
+            path=filepath_csvs, dataset=dataset_config, filename_suffix=suffix
+        )
+        loaded_partitions = pds.load()
+
+        assert len(loaded_partitions) == expected_num_parts
+        actual_save_versions = set()
+        for partition_id, load_func in loaded_partitions.items():
+            partition_dir = Path(filepath_csvs, partition_id + suffix)
+            actual_save_versions |= {each.name for each in partition_dir.iterdir()}
+            df = load_func()
+            assert_frame_equal(df, partitioned_data_pandas[partition_id + suffix])
+            if suffix:
+                assert not partition_id.endswith(suffix)
+
+        if expected_num_parts:
+            # all partitions were saved using the same version string
+            assert actual_save_versions == {save_version}
+
+    def test_malformed_versioned_path(self, tmp_path):
+        local_dir = tmp_path / "files"
+        local_dir.mkdir()
+
+        path = local_dir / "path/to/folder/new/partition/version/partition/file"
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text("content")
+
+        pds = PartitionedDataset(
+            path=str(local_dir / "path/to/folder"),
+            dataset={"type": "pandas.CSVDataset", "versioned": True},
+        )
+
+        pattern = re.escape(
+            f"`{path.as_posix()}` is not a well-formed versioned path ending with "
+            f"`filename/timestamp/filename` (got `version/partition/file`)."
+        )
+        with pytest.raises(DatasetError, match=pattern):
+            pds.load()

    def test_no_partitions(self, tmpdir):
        pds = PartitionedDataset(path=str(tmpdir), dataset="pandas.CSVDataset")
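Loading stays lazy with versioning: `load()` returns partition ids mapped to callables, and each callable resolves the saved version only when invoked, as the `load_func()` calls in the test above show. A brief usage sketch (path illustrative):

from kedro_datasets.partitions import PartitionedDataset

pds = PartitionedDataset(
    path="data/01_raw/csvs",  # hypothetical location
    dataset={"type": "pandas.CSVDataset", "versioned": True},
)
for partition_id, load_func in pds.load().items():
    df = load_func()  # reads this partition's latest saved version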
