Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(datasets): Implement spark.GBQQueryDataset for reading data from BigQuery as a spark dataframe using SQL query #971

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions kedro-datasets/kedro_datasets/spark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SparkHiveDataset: Any
SparkJDBCDataset: Any
SparkStreamingDataset: Any
GBQQueryDataset: Any

__getattr__, __dir__, __all__ = lazy.attach(
__name__,
Expand All @@ -19,5 +20,6 @@
"spark_hive_dataset": ["SparkHiveDataset"],
"spark_jdbc_dataset": ["SparkJDBCDataset"],
"spark_streaming_dataset": ["SparkStreamingDataset"],
"spark_gbq_dataset": ["GBQQueryDataset"],
},
)
240 changes: 240 additions & 0 deletions kedro-datasets/kedro_datasets/spark/spark_gbq_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""``AbstractDataset`` implementation to access Spark dataframes using
``pyspark``.
"""

from __future__ import annotations

import base64
import json
import logging
from copy import deepcopy
from typing import Any, NoReturn

from kedro.io import AbstractDataset, DatasetError
from py4j.protocol import Py4JJavaError
from pyspark.sql import DataFrame

from kedro_datasets._utils.spark_utils import get_spark
import copy
import fsspec

from kedro.io.core import get_protocol_and_path

logger = logging.getLogger(__name__)


class GBQQueryDataset(AbstractDataset[None, DataFrame]):
    """``GBQQueryDataset`` loads data from Google BigQuery with a SQL query
    using the BigQuery Spark connector.

    The query is supplied either inline (``sql``) or as a path to a file
    containing the query (``filepath``); exactly one of the two must be given.
    The dataset is read-only: ``save`` always raises ``DatasetError``.

    Example usage for the
    `YAML API <https://docs.kedro.org/en/stable/data/\
    data_catalog_yaml_examples.html>`_:

    .. code-block:: yaml

        my_gbq_spark_data:
          type: spark.GBQQueryDataset
          sql: |
            SELECT * FROM your_table
          materialization_dataset: your_dataset
          materialization_project: your_project
          bq_credentials:
            file: /path/to/your/credentials.json
          fs_credentials:
            key: value

    Example usage for the
    `Python API <https://docs.kedro.org/en/stable/data/\
    advanced_data_catalog_usage.html>`_:

    .. code-block:: pycon

        >>> from kedro_datasets.spark import GBQQueryDataset
        >>>
        >>> # Define your SQL query
        >>> sql = "SELECT * FROM your_table"
        >>>
        >>> # Initialize dataset
        >>> dataset = GBQQueryDataset(
        ...     sql=sql,
        ...     materialization_dataset="your_dataset",
        ...     materialization_project="your_project",  # optional
        ...     bq_credentials=dict(file="/path/to/your/credentials.json"),  # optional
        ...     fs_credentials=dict(key="value"),  # optional
        ... )
        >>>
        >>> # Load data
        >>> df = dataset.load()
        >>>
        >>> # Example output
        >>> df.show()
    """

    _VALID_CREDENTIALS_KEYS = {"base64", "file", "json"}

    def __init__(  # noqa: PLR0913
        self,
        materialization_dataset: str,
        sql: str | None = None,
        filepath: str | None = None,
        materialization_project: str | None = None,
        load_args: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        bq_credentials: dict[str, Any] | None = None,
        fs_credentials: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Creates a new instance of ``GBQQueryDataset`` that reads the result
        of a SQL query from Google BigQuery.

        Args:
            materialization_dataset: The name of the dataset to materialize the query results.
            sql: The SQL query to execute
            filepath: A path to a file with a sql query statement.
            materialization_project: The name of the project to materialize the query results.
                Optional (defaults to the project id set by the credentials).
            load_args: Load args passed to Spark DataFrameReader load method.
                It is dependent on the selected file format. You can find
                a list of read options for each supported format
                in Spark DataFrame read documentation:
                https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html
            fs_args: Extra keyword arguments passed to the ``fsspec.filesystem``
                constructor together with ``fs_credentials``. Only used when
                ``filepath`` is provided.
            bq_credentials: Credentials to authenticate spark session with google bigquery.
                Dictionary with key specifying the type of credentials ('base64', 'file', 'json').
                Alternatively, you can pass the credentials in load_args as follows:

                When passing as `base64`:
                    `load_args={"credentials": "your_credentials"}`

                When passing as a `file`:
                    `load_args={"credentialsFile": "/path/to/your/credentials.json"}`

                When passing as a json object:
                    NOT SUPPORTED

                Read more here:
                https://github.com/GoogleCloudDataproc/spark-bigquery-connector?tab=readme-ov-file#how-do-i-authenticate-outside-gce--dataproc
            fs_credentials: Credentials to authenticate with the filesystem.
                The keyword args would be directly passed to fsspec.filesystem constructor.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.

        Raises:
            DatasetError: If both or neither of ``sql`` and ``filepath`` are provided.
        """
        if sql and filepath:
            raise DatasetError(
                "'sql' and 'filepath' arguments cannot both be provided. "
                "Please only provide one."
            )

        if not (sql or filepath):
            raise DatasetError(
                "'sql' and 'filepath' arguments cannot both be empty. "
                "Please provide a sql query or path to a sql query file."
            )

        if sql:
            self._sql = sql
            self._filepath = None
        else:
            # ``_get_sql`` and ``_describe`` read ``self._sql`` unconditionally,
            # so it must exist even when the query comes from a file.
            self._sql = None

            # TODO: Add protocol specific handling cases for different filesystems.
            protocol, path = get_protocol_and_path(str(filepath))

            self._fs_args = fs_args or {}
            self._fs_credentials = fs_credentials or {}
            self._fs_protocol = protocol

            self._fs = fsspec.filesystem(
                self._fs_protocol, **self._fs_credentials, **self._fs_args
            )
            self._filepath = path

        self._materialization_dataset = materialization_dataset
        self._materialization_project = materialization_project
        self._load_args = load_args or {}
        self._bq_credentials = bq_credentials or {}

        self._metadata = metadata

    def _get_spark_bq_credentials(self) -> dict[str, str]:
        """Translate ``bq_credentials`` into BigQuery connector read options.

        Returns:
            An empty dict when no credentials were configured; otherwise a
            single-entry dict: ``{"credentials": ...}`` for the ``base64`` and
            ``json`` keys (the ``json`` value is JSON-serialised and then
            base64-encoded) or ``{"credentialsFile": ...}`` for the ``file`` key.

        Raises:
            ValueError: If more than one key is provided, or the single key
                provided is not a non-empty 'base64', 'file' or 'json' entry.
        """
        if not self._bq_credentials:
            return {}

        if len(self._bq_credentials) > 1:
            raise ValueError(
                "Please provide only one of 'base64', 'file' or 'json' key in the credentials. "
                f"You provided: {list(self._bq_credentials.keys())}"
            )
        if self._bq_credentials.get("base64"):
            return {
                "credentials": self._bq_credentials["base64"],
            }
        if self._bq_credentials.get("file"):
            return {
                "credentialsFile": self._bq_credentials["file"],
            }
        if self._bq_credentials.get("json"):
            creds_b64 = base64.b64encode(
                json.dumps(self._bq_credentials["json"]).encode("utf-8")
            ).decode("utf-8")
            return {"credentials": creds_b64}

        raise ValueError(
            f"Please provide one of 'base64', 'file' or 'json' key in the credentials. You provided: {list(self._bq_credentials.keys())[0]}"
        )

    def _load_sql_from_filepath(self) -> str:
        """Read and return the query text from ``self._filepath`` via ``self._fs``."""
        with self._fs.open(self._filepath, "r") as f:
            return f.read()

    def _get_sql(self) -> str:
        """Return the inline query if one was given, else the query read from file."""
        if self._sql:
            return self._sql
        return self._load_sql_from_filepath()

    def _get_spark_load_args(self) -> dict[str, Any]:
        """Build the keyword arguments passed to the Spark reader's ``load``.

        Starts from a deep copy of the user supplied ``load_args`` and then
        sets ``query``, ``materializationDataset``, optionally
        ``materializationProject``, and the credential options. If the Spark
        session configuration does not have ``viewsEnabled`` set to ``"true"``,
        the option is set for this read and a warning is logged.
        """
        spark_load_args = deepcopy(self._load_args)
        spark_load_args["query"] = self._get_sql()
        spark_load_args["materializationDataset"] = self._materialization_dataset

        if self._materialization_project:
            spark_load_args["materializationProject"] = self._materialization_project

        spark_load_args.update(self._get_spark_bq_credentials())

        try:
            views_enabled_spark_conf = get_spark().conf.get("viewsEnabled")
        except Py4JJavaError:
            # Treat a failed lookup of the session option as "not enabled".
            views_enabled_spark_conf = "false"

        if views_enabled_spark_conf != "true":
            spark_load_args["viewsEnabled"] = "true"
            logger.warning(
                "The 'viewsEnabled' configuration is not set to 'true' in the SparkSession. "
                "This is required for the Spark BigQuery connector to read via a SQL query. "
                "Setting 'viewsEnabled' to 'true' for the current query read operation. "
                "This may incur additional costs!"
            )

        return spark_load_args

    def load(self) -> DataFrame:
        """Loads data from Google BigQuery.

        Returns:
            A Spark DataFrame.
        """
        spark = get_spark()
        read_obj = spark.read.format("bigquery")

        return read_obj.load(**self._get_spark_load_args())

    def save(self, data: None) -> NoReturn:
        """Always raise ``DatasetError``; this dataset does not support saving."""
        raise DatasetError("'save' is not supported on GBQQueryDataset")

    def _describe(self) -> dict[str, Any]:
        """Return the dataset's configuration as a dictionary (credentials excluded)."""
        return {
            "sql": self._sql,
            "filepath": self._filepath,
            "materialization_dataset": self._materialization_dataset,
            "materialization_project": self._materialization_project,
            "load_args": self._load_args,
            "metadata": self._metadata,
        }
140 changes: 140 additions & 0 deletions kedro-datasets/tests/spark/test_spark_gbq_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import pytest
from pyspark.sql import SparkSession
from kedro_datasets.spark.spark_gbq_dataset import GBQQueryDataset
import json
import base64
from kedro.io import DatasetError
import re


# Shared inputs used to build GBQQueryDataset instances throughout the tests.
SQL_QUERY = "SELECT * FROM table"
MATERIALIZATION_DATASET, MATERIALIZATION_PROJECT = "dataset", "project"
LOAD_ARGS = dict(key="value")
# The smallest set of keyword arguments the constructor accepts.
REQUIRED_INIT_ARGS = dict(
    sql=SQL_QUERY,
    materialization_dataset=MATERIALIZATION_DATASET,
)


@pytest.fixture
def spark_session(mocker):
    """Provide a ``MagicMock`` restricted to the ``SparkSession`` interface."""
    session_mock = mocker.MagicMock(spec=SparkSession)
    return session_mock


@pytest.fixture
def dummy_save_dataset(spark_session):
    """Provide a one-row, one-column object created through the session fixture."""
    rows = [("foo",)]
    column_names = ["bar"]
    return spark_session.createDataFrame(rows, column_names)


@pytest.fixture
def gbq_query_dataset():
    """Provide a ``GBQQueryDataset`` built from the module-level test constants."""
    init_kwargs = {
        "sql": SQL_QUERY,
        "materialization_dataset": MATERIALIZATION_DATASET,
        "materialization_project": MATERIALIZATION_PROJECT,
        "load_args": LOAD_ARGS,
    }
    return GBQQueryDataset(**init_kwargs)


def test_save_not_implemented(gbq_query_dataset, dummy_save_dataset):
    """Calling ``save`` must raise ``DatasetError`` with the 'not supported' message."""
    expected_message = r"'save' is not supported on GBQQueryDataset"
    with pytest.raises(DatasetError, match=expected_message):
        gbq_query_dataset.save(dummy_save_dataset)


@pytest.mark.parametrize(
    ("credentials", "expected_credentials"),
    [
        # base64 string is forwarded unchanged under the "credentials" option
        ({"base64": "base64_creds"}, {"credentials": "base64_creds"}),
        # file path is forwarded under the "credentialsFile" option
        ({"file": "/path/to/creds.json"}, {"credentialsFile": "/path/to/creds.json"}),
        # json object is serialised and base64-encoded into "credentials"
        (
            {"json": {"type": "service_account"}},
            {
                "credentials": base64.b64encode(
                    json.dumps({"type": "service_account"}).encode("utf-8")
                ).decode("utf-8")
            },
        ),
        # no credentials configured -> no options added
        ({}, {}),
    ],
)
def test_get_spark_bq_credentials(gbq_query_dataset, credentials, expected_credentials):
    """Each supported credential form maps to the expected reader options."""
    gbq_query_dataset._bq_credentials = credentials
    actual_options = gbq_query_dataset._get_spark_bq_credentials()
    assert actual_options == expected_credentials


def test_invalid_bq_credentials_key(gbq_query_dataset):
    """A single unsupported credentials key must raise ``ValueError`` naming that key."""
    invalid_cred_key = "invalid_cred_key"
    gbq_query_dataset._bq_credentials = {invalid_cred_key: "value"}
    # ``match`` is applied as a regular expression, so escape the literal
    # message; otherwise the '.' in it would match any character.
    pattern = re.escape(
        f"Please provide one of 'base64', 'file' or 'json' key in the credentials. You provided: {invalid_cred_key}"
    )
    with pytest.raises(ValueError, match=pattern):
        gbq_query_dataset._get_spark_bq_credentials()


@pytest.mark.parametrize(
    "credentials",
    [
        # every pair of supported keys
        {"base64": "base64_creds", "file": "/path/to/creds.json"},
        {"base64": "base64_creds", "json": {"type": "service_account"}},
        {"file": "/path/to/creds.json", "json": {"type": "service_account"}},
        # all three supported keys at once
        {
            "base64": "base64_creds",
            "file": "/path/to/creds.json",
            "json": {"type": "service_account"},
        },
        # a supported key combined with an unsupported one
        {"base64": "base64_creds", "invalid_key": "value"},
    ],
)
def test_more_than_one_bq_credentials_key(gbq_query_dataset, credentials):
    """Supplying more than one credentials key must raise ``ValueError`` listing the keys."""
    expected_message = re.escape(
        f"Please provide only one of 'base64', 'file' or 'json' key in the credentials. You provided: {list(credentials.keys())}"
    )
    gbq_query_dataset._bq_credentials = credentials
    with pytest.raises(ValueError, match=expected_message):
        gbq_query_dataset._get_spark_bq_credentials()


@pytest.mark.parametrize(
    ("init_args", "expected_load_args"),
    [
        # only the required constructor arguments
        (
            REQUIRED_INIT_ARGS,
            {
                "query": REQUIRED_INIT_ARGS["sql"],
                "materializationDataset": REQUIRED_INIT_ARGS["materialization_dataset"],
                "viewsEnabled": "true",
            },
        ),
        # required arguments plus an explicit materialization project
        (
            dict(REQUIRED_INIT_ARGS, materialization_project=MATERIALIZATION_PROJECT),
            {
                "query": REQUIRED_INIT_ARGS["sql"],
                "materializationDataset": REQUIRED_INIT_ARGS["materialization_dataset"],
                "materializationProject": MATERIALIZATION_PROJECT,
                "viewsEnabled": "true",
            },
        ),
    ],
)
def test_load(mocker, spark_session, init_args, expected_load_args):
    """``load`` must read with the "bigquery" format and the expected options."""
    dataset = GBQQueryDataset(**init_args)
    # Route the dataset module's session lookup to the mocked session.
    mocker.patch(
        "kedro_datasets.spark.spark_gbq_dataset.get_spark", return_value=spark_session
    )
    reader = mocker.MagicMock()
    reader.load.return_value = mocker.MagicMock()
    spark_session.read.format.return_value = reader

    dataset.load()

    spark_session.read.format.assert_called_once_with("bigquery")
    reader.load.assert_called_once_with(**expected_load_args)
Loading