
Commit 6908fab

implement option to pass bq, fs credentials and fs args
Signed-off-by: Abhishek Bhatia <[email protected]>
abhi8893 committed Jan 3, 2025
1 parent a2be311 commit 6908fab
Showing 2 changed files with 87 additions and 29 deletions.
97 changes: 78 additions & 19 deletions kedro-datasets/kedro_datasets/spark/spark_gbq_dataset.py
@@ -15,6 +15,10 @@
 from pyspark.sql import DataFrame
 
 from kedro_datasets._utils.spark_utils import get_spark
+import copy
+import fsspec
+
+from kedro.io.core import get_protocol_and_path
 
 logger = logging.getLogger(__name__)

@@ -68,75 +72,130 @@ class GBQQueryDataset(AbstractDataset[None, DataFrame]):

     def __init__(  # noqa: PLR0913
         self,
-        sql: str,
         materialization_dataset: str,
+        sql: str | None = None,
+        filepath: str | None = None,
         materialization_project: str | None = None,
         load_args: dict[str, Any] | None = None,
+        fs_args: dict[str, Any] | None = None,
+        bq_credentials: dict[str, Any] | None = None,
+        fs_credentials: dict[str, Any] | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> None:
"""Creates a new instance of ``SparkGBQDataSet`` pointing to a specific table in Google BigQuery.
Args:
sql: SQL query to execute.
materialization_dataset: The name of the dataset to materialize the query results.
sql: The SQL query to execute
filepath: A path to a file with a sql query statement.
materialization_project: The name of the project to materialize the query results.
Optional (defaults to the project id set by the credentials).
load_args: Load args passed to Spark DataFrameReader load method.
It is dependent on the selected file format. You can find
a list of read options for each supported format
in Spark DataFrame read documentation:
https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html
credentials: Credentials to authenticate spark session with google bigquery.
bq_credentials: Credentials to authenticate spark session with google bigquery.
Dictionary with key specifying the type of credentials ('base64', 'file', 'json').
Alternatively, you can pass the credentials in load_args as follows:
When passing as `base64`:
`load_args={"credentials": "your_credentials"}`
When passing as a `file`:
`load_args={"credentialsFile": "/path/to/your/credentials.json"}`
When passing as a json object:
NOT SUPPORTED
Read more here:
https://github.com/GoogleCloudDataproc/spark-bigquery-connector?tab=readme-ov-file#how-do-i-authenticate-outside-gce--dataproc
fs_credentials: Credentials to authenticate with the filesystem.
The keyword args would be directly passed to fsspec.filesystem constructor.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
-        self._sql = sql
+        if sql and filepath:
+            raise DatasetError(
+                "'sql' and 'filepath' arguments cannot both be provided. "
+                "Please provide only one."
+            )
+
+        if not (sql or filepath):
+            raise DatasetError(
+                "'sql' and 'filepath' arguments cannot both be empty. "
+                "Please provide a SQL query or a path to a SQL query file."
+            )
+
+        if sql:
+            self._sql = sql
+            self._filepath = None
+        else:
+            # TODO: Add protocol-specific handling for different filesystems.
+            protocol, path = get_protocol_and_path(str(filepath))
+
+            self._fs_args = fs_args or {}
+            self._fs_credentials = fs_credentials or {}
+            self._fs_protocol = protocol
+
+            self._fs = fsspec.filesystem(
+                self._fs_protocol, **self._fs_credentials, **self._fs_args
+            )
+            self._filepath = path
+
         self._materialization_dataset = materialization_dataset
         self._materialization_project = materialization_project
         self._load_args = load_args or {}
-        self._credentials = credentials or {}
+        self._bq_credentials = bq_credentials or {}
 
         self._metadata = metadata

-    def _get_spark_credentials(self) -> dict[str, str]:
-        if not self._credentials:
+    def _get_spark_bq_credentials(self) -> dict[str, str]:
+        if not self._bq_credentials:
             return {}
 
-        if len(self._credentials) > 1:
+        if len(self._bq_credentials) > 1:
             raise ValueError(
                 "Please provide only one of 'base64', 'file' or 'json' key in the credentials. "
-                f"You provided: {list(self._credentials.keys())}"
+                f"You provided: {list(self._bq_credentials.keys())}"
             )
-        if self._credentials.get("base64"):
+        if self._bq_credentials.get("base64"):
             return {
-                "credentials": self._credentials["base64"],
+                "credentials": self._bq_credentials["base64"],
             }
-        if self._credentials.get("file"):
+        if self._bq_credentials.get("file"):
             return {
-                "credentialsFile": self._credentials["file"],
+                "credentialsFile": self._bq_credentials["file"],
             }
-        if self._credentials.get("json"):
+        if self._bq_credentials.get("json"):
             creds_b64 = base64.b64encode(
-                json.dumps(self._credentials["json"]).encode("utf-8")
+                json.dumps(self._bq_credentials["json"]).encode("utf-8")
             ).decode("utf-8")
             return {"credentials": creds_b64}
 
         raise ValueError(
-            f"Please provide one of 'base64', 'file' or 'json' key in the credentials. You provided: {list(self._credentials.keys())[0]}"
+            f"Please provide one of 'base64', 'file' or 'json' key in the credentials. You provided: {list(self._bq_credentials.keys())[0]}"
         )

+    def _load_sql_from_filepath(self) -> str:
+        with self._fs.open(self._filepath, "r") as f:
+            return f.read()
+
+    def _get_sql(self) -> str:
+        if self._sql:
+            return self._sql
+        else:
+            return self._load_sql_from_filepath()
+
     def _get_spark_load_args(self) -> dict[str, Any]:
         spark_load_args = deepcopy(self._load_args)
-        spark_load_args["query"] = self._sql
+        spark_load_args["query"] = self._get_sql()
         spark_load_args["materializationDataset"] = self._materialization_dataset
 
         if self._materialization_project:
             spark_load_args["materializationProject"] = self._materialization_project
 
-        spark_load_args.update(self._get_spark_credentials())
+        spark_load_args.update(self._get_spark_bq_credentials())
 
         try:
             views_enabled_spark_conf = get_spark().conf.get("viewsEnabled")
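Taken together, the change supports two construction modes: inline SQL, or SQL read from a file on any fsspec-compatible filesystem. A minimal usage sketch, assuming the module path shown in this diff; the query, bucket, and key-file paths are hypothetical placeholders:

    from kedro_datasets.spark.spark_gbq_dataset import GBQQueryDataset

    # Mode 1: inline SQL, with BigQuery credentials given as a key file.
    inline_ds = GBQQueryDataset(
        sql="SELECT * FROM my_dataset.my_table",  # hypothetical query
        materialization_dataset="scratch_dataset",
        bq_credentials={"file": "/path/to/service-account.json"},
    )

    # Mode 2: SQL read from a remote file. fs_credentials and fs_args are
    # forwarded to fsspec.filesystem() for the protocol parsed from filepath
    # ("gs" here); "token" and "project" are gcsfs-specific kwargs.
    file_ds = GBQQueryDataset(
        filepath="gs://my-bucket/queries/report.sql",  # hypothetical path
        materialization_dataset="scratch_dataset",
        fs_credentials={"token": "/path/to/service-account.json"},
        fs_args={"project": "my-gcp-project"},
    )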
19 changes: 9 additions & 10 deletions kedro-datasets/tests/spark/test_spark_gbq_dataset.py
@@ -33,7 +33,6 @@ def gbq_query_dataset():
         sql=SQL_QUERY,
         materialization_dataset=MATERIALIZATION_DATASET,
         materialization_project=MATERIALIZATION_PROJECT,
-        credentials=None,
         load_args=LOAD_ARGS,
)

@@ -62,20 +61,20 @@ def test_save_not_implemented(gbq_query_dataset, dummy_save_dataset):
         ({}, {}),
     ],
 )
-def test_get_spark_credentials(gbq_query_dataset, credentials, expected_credentials):
-    gbq_query_dataset._credentials = credentials
-    assert gbq_query_dataset._get_spark_credentials() == expected_credentials
+def test_get_spark_bq_credentials(gbq_query_dataset, credentials, expected_credentials):
+    gbq_query_dataset._bq_credentials = credentials
+    assert gbq_query_dataset._get_spark_bq_credentials() == expected_credentials
 
 
-def test_invalid_credentials_key(gbq_query_dataset):
+def test_invalid_bq_credentials_key(gbq_query_dataset):
     invalid_cred_key = "invalid_cred_key"
-    gbq_query_dataset._credentials = {invalid_cred_key: "value"}
+    gbq_query_dataset._bq_credentials = {invalid_cred_key: "value"}
     with pytest.raises(
         ValueError,
         match=f"Please provide one of 'base64', 'file' or 'json' key in the credentials. You provided: {invalid_cred_key}",
     ):
-        gbq_query_dataset._get_spark_credentials()
+        gbq_query_dataset._get_spark_bq_credentials()


 @pytest.mark.parametrize(
@@ -92,16 +91,16 @@ def test_invalid_credentials_key(gbq_query_dataset):
{"base64": "base64_creds", "invalid_key": "value"},
],
)
def test_more_than_one_credentials_key(gbq_query_dataset, credentials):
gbq_query_dataset._credentials = credentials
def test_more_than_one_bq_credentials_key(gbq_query_dataset, credentials):
gbq_query_dataset._bq_credentials = credentials
pattern = re.escape(
f"Please provide only one of 'base64', 'file' or 'json' key in the credentials. You provided: {list(credentials.keys())}"
)
with pytest.raises(
ValueError,
match=pattern,
):
gbq_query_dataset._get_spark_credentials()
gbq_query_dataset._get_spark_bq_credentials()


 @pytest.mark.parametrize(
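As the renamed tests above exercise, only the 'json' credential variant transforms its input before handing it to the connector. A standalone sketch of that encoding step, using a hypothetical stub service-account dict:

    import base64
    import json

    # Mirrors the "json" branch of _get_spark_bq_credentials: serialize the
    # service-account dict and base64-encode it so the spark-bigquery
    # connector can accept it via its "credentials" option.
    service_account_info = {"type": "service_account", "project_id": "my-gcp-project"}
    creds_b64 = base64.b64encode(
        json.dumps(service_account_info).encode("utf-8")
    ).decode("utf-8")
    spark_load_args = {"credentials": creds_b64}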

