Skip to content

Commit

Permalink
fix(datasets): make GBQTableDataset serializable
Browse files Browse the repository at this point in the history
Signed-off-by: Deepyaman Datta <[email protected]>
  • Loading branch information
deepyaman committed Dec 11, 2024
1 parent 50fa3c0 commit e3e70f3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
44 changes: 26 additions & 18 deletions kedro-datasets/kedro_datasets/pandas/gbq_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
validate_on_forbidden_chars,
)

from kedro_datasets._utils import ConnectionMixin

class GBQTableDataset(AbstractDataset[None, pd.DataFrame]):

class GBQTableDataset(ConnectionMixin, AbstractDataset[None, pd.DataFrame]):
"""``GBQTableDataset`` loads and saves data from/to Google BigQuery.
It uses pandas-gbq to read and write from/to BigQuery table.
Expand Down Expand Up @@ -68,6 +70,8 @@ class GBQTableDataset(AbstractDataset[None, pd.DataFrame]):
DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {"progress_bar": False}

_CONNECTION_GROUP: ClassVar[str] = "bigquery"

def __init__( # noqa: PLR0913
self,
*,
Expand Down Expand Up @@ -114,18 +118,15 @@ def __init__( # noqa: PLR0913
self._validate_location()
validate_on_forbidden_chars(dataset=dataset, table_name=table_name)

if isinstance(credentials, dict):
credentials = Credentials(**credentials)

self._dataset = dataset
self._table_name = table_name
self._project_id = project
self._credentials = credentials
self._client = bigquery.Client(
project=self._project_id,
credentials=self._credentials,
location=self._save_args.get("location"),
)
self._connection_config = {
"project": self._project_id,
"credentials": self._credentials,
"location": self._save_args.get("location"),
}

self.metadata = metadata

Expand All @@ -137,12 +138,24 @@ def _describe(self) -> dict[str, Any]:
"save_args": self._save_args,
}

def _connect(self) -> bigquery.Client:
    """Build a BigQuery client from the stored connection configuration.

    Reads the ``"project"``, ``"credentials"`` and ``"location"`` entries of
    ``self._connection_config``. If the credentials entry is a ``dict``, it
    is expanded into a ``Credentials`` object first; any other value is
    passed to the client unchanged.

    Returns:
        A newly constructed ``bigquery.Client``.
    """
    config = self._connection_config
    creds = config["credentials"]

    # Dict credentials are materialised here rather than in ``__init__``.
    # Original note: only create the `Credentials` object once for a
    # consistent hash.
    if isinstance(creds, dict):
        creds = Credentials(**creds)

    return bigquery.Client(
        project=config["project"],
        credentials=creds,
        location=config["location"],
    )

def load(self) -> pd.DataFrame:
sql = f"select * from {self._dataset}.{self._table_name}" # nosec
self._load_args.setdefault("query_or_table", sql)
return pd_gbq.read_gbq(
project_id=self._project_id,
credentials=self._credentials,
credentials=self._connection._credentials,
**self._load_args,
)

Expand All @@ -151,14 +164,14 @@ def save(self, data: pd.DataFrame) -> None:
dataframe=data,
destination_table=f"{self._dataset}.{self._table_name}",
project_id=self._project_id,
credentials=self._credentials,
credentials=self._connection._credentials,
**self._save_args,
)

def _exists(self) -> bool:
table_ref = self._client.dataset(self._dataset).table(self._table_name)
table_ref = self._connection.dataset(self._dataset).table(self._table_name)
try:
self._client.get_table(table_ref)
self._connection.get_table(table_ref)
return True
except NotFound:
return False
Expand Down Expand Up @@ -268,11 +281,6 @@ def __init__( # noqa: PLR0913
credentials = Credentials(**credentials)

self._credentials = credentials
self._client = bigquery.Client(
project=self._project_id,
credentials=self._credentials,
location=self._load_args.get("location"),
)

# load sql query from arg or from file
if sql:
Expand Down
3 changes: 2 additions & 1 deletion kedro-datasets/tests/pandas/test_gbq_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def test_save_load_data(self, gbq_dataset, dummy_dataframe, mocker):
)
mocked_read_gbq.return_value = dummy_dataframe
mocked_df = mocker.Mock()
gbq_dataset._connection._credentials = None

gbq_dataset.save(mocked_df)
loaded_data = gbq_dataset.load()
Expand Down Expand Up @@ -206,7 +207,7 @@ def test_credentials_propagation(self, mocker):
project=PROJECT,
)

assert dataset._credentials == credentials_obj
assert dataset._connection._credentials == credentials_obj
mocked_credentials.assert_called_once_with(**credentials)
mocked_bigquery.Client.assert_called_once_with(
project=PROJECT, credentials=credentials_obj, location=None
Expand Down

0 comments on commit e3e70f3

Please sign in to comment.