From f5d828f184cb6b17efee256eb5bcebcbcc815177 Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Wed, 11 Dec 2024 00:21:10 -0700 Subject: [PATCH] fix(datasets): make `GBQTableDataset` serializable Signed-off-by: Deepyaman Datta --- .../kedro_datasets/pandas/gbq_dataset.py | 45 +++++++++++-------- .../tests/pandas/test_gbq_dataset.py | 7 +-- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py b/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py index a38b3b82c..6130a6a19 100644 --- a/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py @@ -22,8 +22,10 @@ validate_on_forbidden_chars, ) +from kedro_datasets._utils import ConnectionMixin -class GBQTableDataset(AbstractDataset[None, pd.DataFrame]): + +class GBQTableDataset(ConnectionMixin, AbstractDataset[None, pd.DataFrame]): """``GBQTableDataset`` loads and saves data from/to Google BigQuery. It uses pandas-gbq to read and write from/to BigQuery table. @@ -68,6 +70,8 @@ class GBQTableDataset(AbstractDataset[None, pd.DataFrame]): DEFAULT_LOAD_ARGS: dict[str, Any] = {} DEFAULT_SAVE_ARGS: dict[str, Any] = {"progress_bar": False} + _CONNECTION_GROUP: ClassVar[str] = "bigquery" + def __init__( # noqa: PLR0913 self, *, @@ -114,18 +118,14 @@ def __init__( # noqa: PLR0913 self._validate_location() validate_on_forbidden_chars(dataset=dataset, table_name=table_name) - if isinstance(credentials, dict): - credentials = Credentials(**credentials) - self._dataset = dataset self._table_name = table_name self._project_id = project - self._credentials = credentials - self._client = bigquery.Client( - project=self._project_id, - credentials=self._credentials, - location=self._save_args.get("location"), - ) + self._connection_config = { + "project": self._project_id, + "credentials": credentials, + "location": self._save_args.get("location"), + } self.metadata = metadata @@ -137,12 +137,24 @@ def _describe(self) -> dict[str, Any]: "save_args": self._save_args, } + def _connect(self) -> bigquery.Client: + credentials = self._connection_config["credentials"] + if isinstance(credentials, dict): + # Only create `Credentials` object once for consistent hash. + credentials = Credentials(**credentials) + + return bigquery.Client( + project=self._connection_config["project"], + credentials=credentials, + location=self._connection_config["location"], + ) + def load(self) -> pd.DataFrame: sql = f"select * from {self._dataset}.{self._table_name}" # nosec self._load_args.setdefault("query_or_table", sql) return pd_gbq.read_gbq( project_id=self._project_id, - credentials=self._credentials, + credentials=self._connection._credentials, **self._load_args, ) @@ -151,14 +163,14 @@ def save(self, data: pd.DataFrame) -> None: dataframe=data, destination_table=f"{self._dataset}.{self._table_name}", project_id=self._project_id, - credentials=self._credentials, + credentials=self._connection._credentials, **self._save_args, ) def _exists(self) -> bool: - table_ref = self._client.dataset(self._dataset).table(self._table_name) + table_ref = self._connection.dataset(self._dataset).table(self._table_name) try: - self._client.get_table(table_ref) + self._connection.get_table(table_ref) return True except NotFound: return False @@ -268,11 +280,6 @@ def __init__( # noqa: PLR0913 credentials = Credentials(**credentials) self._credentials = credentials - self._client = bigquery.Client( - project=self._project_id, - credentials=self._credentials, - location=self._load_args.get("location"), - ) # load sql query from arg or from file if sql: diff --git a/kedro-datasets/tests/pandas/test_gbq_dataset.py b/kedro-datasets/tests/pandas/test_gbq_dataset.py index 19767f15b..03f7f5fab 100644 --- a/kedro-datasets/tests/pandas/test_gbq_dataset.py +++ b/kedro-datasets/tests/pandas/test_gbq_dataset.py @@ -141,6 +141,7 @@ def test_save_load_data(self, gbq_dataset, dummy_dataframe, mocker): ) mocked_read_gbq.return_value = dummy_dataframe mocked_df = mocker.Mock() + gbq_dataset._connection._credentials = None gbq_dataset.save(mocked_df) loaded_data = gbq_dataset.load() @@ -205,8 +206,8 @@ def test_credentials_propagation(self, mocker): credentials=credentials, project=PROJECT, ) + dataset.exists() # Do something to trigger the client creation. - assert dataset._credentials == credentials_obj mocked_credentials.assert_called_once_with(**credentials) mocked_bigquery.Client.assert_called_once_with( project=PROJECT, credentials=credentials_obj, location=None @@ -238,7 +239,6 @@ def test_credentials_propagation(self, mocker): "kedro_datasets.pandas.gbq_dataset.Credentials", return_value=credentials_obj, ) - mocked_bigquery = mocker.patch("kedro_datasets.pandas.gbq_dataset.bigquery") dataset = GBQQueryDataset( sql=SQL_QUERY, @@ -248,9 +248,6 @@ def test_credentials_propagation(self, mocker): assert dataset._credentials == credentials_obj mocked_credentials.assert_called_once_with(**credentials) - mocked_bigquery.Client.assert_called_once_with( - project=PROJECT, credentials=credentials_obj, location=None - ) def test_load(self, mocker, gbq_sql_dataset, dummy_dataframe): """Test `load` method invocation"""