Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(datasets): share the cache of Ibis connections #941

Merged
merged 5 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kedro-datasets/kedro_datasets/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .connection_mixin import ConnectionMixin # noqa: F401
32 changes: 32 additions & 0 deletions kedro-datasets/kedro_datasets/_utils/connection_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from collections.abc import Hashable
from typing import Any, ClassVar


class ConnectionMixin(ABC):
_CONNECTION_GROUP: ClassVar[str]

_connection_config: dict[str, Any]

_connections: ClassVar[dict[Hashable, Any]] = {}

@abstractmethod
def _connect(self) -> Any:
... # pragma: no cover

@property
def _connection(self) -> Any:
def hashable(value: Any) -> Hashable:
"""Return a hashable key for a potentially-nested object."""
if isinstance(value, dict):
return tuple((k, hashable(v)) for k, v in sorted(value.items()))
if isinstance(value, list):
return tuple(hashable(x) for x in value)
return value

cls = type(self)
key = self._CONNECTION_GROUP, hashable(self._connection_config)
if key not in cls._connections:
cls._connections[key] = self._connect()

return cls._connections[key]
33 changes: 12 additions & 21 deletions kedro-datasets/kedro_datasets/ibis/file_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import ibis.expr.types as ir
from kedro.io import AbstractVersionedDataset, DatasetError, Version

from kedro_datasets._utils import ConnectionMixin

if TYPE_CHECKING:
from ibis import BaseBackend


class FileDataset(AbstractVersionedDataset[ir.Table, ir.Table]):
class FileDataset(ConnectionMixin, AbstractVersionedDataset[ir.Table, ir.Table]):
"""``FileDataset`` loads/saves data from/to a specified file format.

Example usage for the
Expand Down Expand Up @@ -73,7 +75,7 @@ class FileDataset(AbstractVersionedDataset[ir.Table, ir.Table]):
DEFAULT_LOAD_ARGS: ClassVar[dict[str, Any]] = {}
DEFAULT_SAVE_ARGS: ClassVar[dict[str, Any]] = {}

_connections: ClassVar[dict[tuple[tuple[str, str]], BaseBackend]] = {}
_CONNECTION_GROUP: ClassVar[str] = "ibis"

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -143,28 +145,17 @@ def __init__( # noqa: PLR0913
if save_args is not None:
self._save_args.update(save_args)

def _connect(self) -> BaseBackend:
import ibis

config = deepcopy(self._connection_config)
backend = getattr(ibis, config.pop("backend"))
return backend.connect(**config)

@property
def connection(self) -> BaseBackend:
"""The ``Backend`` instance for the connection configuration."""

def hashable(value):
"""Return a hashable key for a potentially-nested object."""
if isinstance(value, dict):
return tuple((k, hashable(v)) for k, v in sorted(value.items()))
if isinstance(value, list):
return tuple(hashable(x) for x in value)
return value

cls = type(self)
key = hashable(self._connection_config)
if key not in cls._connections:
import ibis

config = deepcopy(self._connection_config)
backend = getattr(ibis, config.pop("backend"))
cls._connections[key] = backend.connect(**config)

return cls._connections[key]
return self._connection

def load(self) -> ir.Table:
load_path = self._get_load_path()
Expand Down
32 changes: 11 additions & 21 deletions kedro-datasets/kedro_datasets/ibis/table_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from kedro.io import AbstractDataset, DatasetError

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._utils import ConnectionMixin

if TYPE_CHECKING:
from ibis import BaseBackend


class TableDataset(AbstractDataset[ir.Table, ir.Table]):
class TableDataset(ConnectionMixin, AbstractDataset[ir.Table, ir.Table]):
"""``TableDataset`` loads/saves data from/to Ibis table expressions.

Example usage for the
Expand Down Expand Up @@ -70,7 +71,7 @@ class TableDataset(AbstractDataset[ir.Table, ir.Table]):
"overwrite": True,
}

_connections: ClassVar[dict[tuple[tuple[str, str]], BaseBackend]] = {}
_CONNECTION_GROUP: ClassVar[str] = "ibis"

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -145,28 +146,17 @@ def __init__( # noqa: PLR0913

self._materialized = self._save_args.pop("materialized")

def _connect(self) -> BaseBackend:
import ibis

config = deepcopy(self._connection_config)
backend = getattr(ibis, config.pop("backend"))
return backend.connect(**config)

@property
def connection(self) -> BaseBackend:
"""The ``Backend`` instance for the connection configuration."""

def hashable(value):
"""Return a hashable key for a potentially-nested object."""
if isinstance(value, dict):
return tuple((k, hashable(v)) for k, v in sorted(value.items()))
if isinstance(value, list):
return tuple(hashable(x) for x in value)
return value

cls = type(self)
key = hashable(self._connection_config)
if key not in cls._connections:
import ibis

config = deepcopy(self._connection_config)
backend = getattr(ibis, config.pop("backend"))
cls._connections[key] = backend.connect(**config)

return cls._connections[key]
return self._connection

def load(self) -> ir.Table:
if self._filepath is not None:
Expand Down
4 changes: 2 additions & 2 deletions kedro-datasets/tests/ibis/test_file_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def dummy_table():


class TestFileDataset:
def test_save_and_load(self, file_dataset, dummy_table, database):
def test_save_and_load(self, file_dataset, dummy_table):
"""Test saving and reloading the data set."""
file_dataset.save(dummy_table)
reloaded = file_dataset.load()
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_connection_config(self, mocker, file_dataset, connection_config, key):
)
mocker.patch(f"ibis.{backend}")
file_dataset.load()
assert key in file_dataset._connections
assert ("ibis", key) in file_dataset._connections


class TestFileDatasetVersioned:
Expand Down
22 changes: 20 additions & 2 deletions kedro-datasets/tests/ibis/test_table_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from kedro.io import DatasetError
from pandas.testing import assert_frame_equal

from kedro_datasets.ibis import TableDataset
from kedro_datasets.ibis import FileDataset, TableDataset

_SENTINEL = object()

Expand Down Expand Up @@ -56,6 +56,17 @@ def dummy_table(table_dataset_from_csv):
return table_dataset_from_csv.load()


@pytest.fixture
def file_dataset(filepath_csv, connection_config, load_args, save_args):
return FileDataset(
filepath=filepath_csv,
file_format="csv",
connection=connection_config,
load_args=load_args,
save_args=save_args,
)


class TestTableDataset:
def test_save_and_load(self, table_dataset, dummy_table, database):
"""Test saving and reloading the dataset."""
Expand Down Expand Up @@ -146,4 +157,11 @@ def test_connection_config(self, mocker, table_dataset, connection_config, key):
)
mocker.patch(f"ibis.{backend}")
table_dataset.load()
assert key in table_dataset._connections
assert ("ibis", key) in table_dataset._connections

def test_save_data_loaded_using_file_dataset(self, file_dataset, table_dataset):
"""Test interoperability of Ibis datasets sharing a database."""
dummy_table = file_dataset.load()
assert not table_dataset.exists()
table_dataset.save(dummy_table)
assert table_dataset.exists()
Loading