Skip to content

Commit

Permalink
Add tests for spark.GBQQueryDataset
Browse files Browse the repository at this point in the history
Signed-off-by: Abhishek Bhatia <[email protected]>
  • Loading branch information
abhi8893 committed Jan 1, 2025
1 parent 4d7091a commit cc6c930
Showing 1 changed file with 137 additions and 0 deletions.
137 changes: 137 additions & 0 deletions kedro-datasets/tests/spark/test_spark_gbq_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import pytest
from pyspark.sql import SparkSession
from kedro_datasets.spark.spark_gbq_dataset import GBQQueryDataset
import json
import base64
from kedro.io import DatasetError
import re


# Shared inputs for the GBQQueryDataset tests in this module.
SQL_QUERY = "SELECT * FROM table"
MATERIALIZATION_DATASET = "dataset"
MATERIALIZATION_PROJECT = "project"

# Passed as ``load_args`` when the dataset fixture is constructed.
LOAD_ARGS = dict(key="value")

# Keyword arguments the load test uses as its baseline constructor call.
REQUIRED_INIT_ARGS = dict(
    sql=SQL_QUERY,
    materialization_dataset=MATERIALIZATION_DATASET,
)

# Placeholder payload handed to ``save``; the name's spelling (three M's) is
# kept because the save test references it as-is.
DUMMMY_SAVE_DATA = "dummy_save_data"


@pytest.fixture
def spark_session(mocker):
    """Provide a mocked Spark session for the tests.

    The returned object is a ``MagicMock`` created with ``spec=SparkSession``;
    this fixture does not start a real Spark session.

    Args:
        mocker: Mocking fixture exposing ``MagicMock`` (presumably the one
            from pytest-mock -- confirm against the project's test deps).
    """
    session_double = mocker.MagicMock(spec=SparkSession)
    return session_double


@pytest.fixture
def gbq_query_dataset():
    """Build a ``GBQQueryDataset`` from the module-level test constants.

    The dataset is created with ``credentials=None``; tests that exercise
    credential handling overwrite ``_credentials`` on the instance afterwards.
    """
    init_kwargs = {
        "sql": SQL_QUERY,
        "materialization_dataset": MATERIALIZATION_DATASET,
        "materialization_project": MATERIALIZATION_PROJECT,
        "credentials": None,
        "load_args": LOAD_ARGS,
    }
    return GBQQueryDataset(**init_kwargs)


def test_save_not_implemented(gbq_query_dataset):
    """Calling ``save`` on the dataset is expected to raise ``DatasetError``."""
    expected_message = r"'save' is not supported on GBQQueryDataset"
    with pytest.raises(DatasetError, match=expected_message):
        gbq_query_dataset.save(DUMMMY_SAVE_DATA)


@pytest.mark.parametrize(
    ("credentials", "expected_credentials"),
    [
        # A "base64" value is expected under the "credentials" option as-is.
        ({"base64": "base64_creds"}, {"credentials": "base64_creds"}),
        # A "file" path is expected under the "credentialsFile" option.
        (
            {"file": "/path/to/creds.json"},
            {"credentialsFile": "/path/to/creds.json"},
        ),
        # A "json" mapping is expected to be dumped to JSON and base64-encoded.
        (
            {"json": {"type": "service_account"}},
            {
                "credentials": base64.b64encode(
                    json.dumps({"type": "service_account"}).encode("utf-8")
                ).decode("utf-8")
            },
        ),
        # Empty credentials are expected to yield no options.
        ({}, {}),
    ],
)
def test_get_spark_credentials(gbq_query_dataset, credentials, expected_credentials):
    """Check the options ``_get_spark_credentials`` returns for each input.

    The fixture dataset's ``_credentials`` attribute is overwritten with the
    parametrized value before the private helper is called.
    """
    gbq_query_dataset._credentials = credentials
    spark_credentials = gbq_query_dataset._get_spark_credentials()
    assert spark_credentials == expected_credentials


def test_invalid_credentials_key(gbq_query_dataset):
    """A single unrecognised credentials key is expected to raise ``ValueError``.

    The fixture dataset's ``_credentials`` attribute is replaced with a
    one-entry dict whose key is none of ``base64``, ``file`` or ``json``, and
    the error message is checked against the expected text.
    """
    invalid_cred_key = "invalid_cred_key"
    gbq_query_dataset._credentials = {invalid_cred_key: "value"}

    # ``pytest.raises(match=...)`` treats its argument as a regular expression,
    # so the message is escaped to make "." match only a literal dot -- the
    # same approach test_more_than_one_credentials_key uses.
    pattern = re.escape(
        "Please provide one of 'base64', 'file' or 'json' key in the credentials. "
        f"You provided: {invalid_cred_key}"
    )
    with pytest.raises(ValueError, match=pattern):
        gbq_query_dataset._get_spark_credentials()


@pytest.mark.parametrize(
    "credentials",
    [
        # Every pair of the three recognised keys.
        {"base64": "base64_creds", "file": "/path/to/creds.json"},
        {"base64": "base64_creds", "json": {"type": "service_account"}},
        {"file": "/path/to/creds.json", "json": {"type": "service_account"}},
        # All three recognised keys together.
        {
            "base64": "base64_creds",
            "file": "/path/to/creds.json",
            "json": {"type": "service_account"},
        },
        # One recognised key alongside an unrecognised one.
        {"base64": "base64_creds", "invalid_key": "value"},
    ],
)
def test_more_than_one_credentials_key(gbq_query_dataset, credentials):
    """Credentials with more than one key are expected to raise ``ValueError``.

    The expected message lists the provided keys; it is escaped because
    ``pytest.raises`` interprets ``match`` as a regular expression and the
    rendered list contains brackets.
    """
    gbq_query_dataset._credentials = credentials
    provided_keys = list(credentials)
    expected_message = (
        "Please provide only one of 'base64', 'file' or 'json' key in the "
        f"credentials. You provided: {provided_keys}"
    )
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        gbq_query_dataset._get_spark_credentials()


@pytest.mark.parametrize(
    ("init_args", "expected_load_args"),
    [
        # Only the required constructor arguments.
        (
            REQUIRED_INIT_ARGS,
            dict(
                query=REQUIRED_INIT_ARGS["sql"],
                materializationDataset=REQUIRED_INIT_ARGS["materialization_dataset"],
                viewsEnabled="true",
            ),
        ),
        # Required arguments plus an explicit materialization project.
        (
            dict(REQUIRED_INIT_ARGS, materialization_project=MATERIALIZATION_PROJECT),
            dict(
                query=REQUIRED_INIT_ARGS["sql"],
                materializationDataset=REQUIRED_INIT_ARGS["materialization_dataset"],
                materializationProject=MATERIALIZATION_PROJECT,
                viewsEnabled="true",
            ),
        ),
    ],
)
def test_load(mocker, spark_session, init_args, expected_load_args):
    """Check which reader options ``load`` passes to the mocked Spark session.

    ``get_spark`` in ``kedro_datasets.spark.spark_gbq_dataset`` is patched to
    return the mocked session. After ``load`` runs, the test asserts that the
    reader was requested once with format ``"bigquery"`` and that its ``load``
    was called once with exactly the expected keyword arguments.
    """
    # Construct the dataset first, before ``get_spark`` is patched.
    dataset = GBQQueryDataset(**init_args)

    # Reader double returned by ``spark_session.read.format(...)``.
    reader = mocker.MagicMock()
    reader.load.return_value = mocker.MagicMock()
    spark_session.read.format.return_value = reader

    mocker.patch(
        "kedro_datasets.spark.spark_gbq_dataset.get_spark",
        return_value=spark_session,
    )

    dataset.load()

    spark_session.read.format.assert_called_once_with("bigquery")
    reader.load.assert_called_once_with(**expected_load_args)

0 comments on commit cc6c930

Please sign in to comment.