diff --git a/kedro-datasets/tests/databricks/conftest.py b/kedro-datasets/tests/databricks/conftest.py
index ccc0c78ad..336e8168d 100644
--- a/kedro-datasets/tests/databricks/conftest.py
+++ b/kedro-datasets/tests/databricks/conftest.py
@@ -4,11 +4,15 @@ discover them automatically. More info here:
 https://docs.pytest.org/en/latest/fixture.html
 """
+import os
 
 # importlib_metadata needs backport for python 3.8 and older
 import importlib_metadata
+import pandas as pd
 import pytest
 from pyspark.sql import SparkSession
+from pyspark.sql.types import IntegerType, StringType, StructField, StructType
+
 
 DELTA_VERSION = importlib_metadata.version("delta-spark")
 
 
@@ -28,3 +32,170 @@ def spark_session():
     spark.sql("create database if not exists test")
     yield spark
     spark.sql("drop database test cascade;")
+
+
+@pytest.fixture
+def sample_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def upsert_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32), ("Evan", 23)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def mismatched_upsert_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+            StructField("height", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32, 174), ("Evan", 23, 166)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def subset_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+            StructField("height", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32, 174), ("Evan", 23, 166)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def subset_pandas_df():
+    return pd.DataFrame(
+        {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]}
+    )
+
+
+@pytest.fixture
+def subset_expected_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32), ("Evan", 23)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def sample_pandas_df():
+    return pd.DataFrame(
+        {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]}
+    )
+
+
+@pytest.fixture
+def append_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Evan", 23), ("Frank", 13)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_append_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 31),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+        ("Frank", 13),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_upsert_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 32),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 31),
+        ("Alex", 32),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def external_location():
+    return os.environ.get("DATABRICKS_EXTERNAL_LOCATION")
diff --git a/kedro-datasets/tests/databricks/test_base_table_dataset.py b/kedro-datasets/tests/databricks/test_base_table_dataset.py
index 5cc88e8df..ac0c0cc27 100644
--- a/kedro-datasets/tests/databricks/test_base_table_dataset.py
+++ b/kedro-datasets/tests/databricks/test_base_table_dataset.py
@@ -3,179 +3,12 @@
 import pandas as pd
 import pytest
 from kedro.io.core import DatasetError, Version, VersionNotFoundError
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import DataFrame
 from pyspark.sql.types import IntegerType, StringType, StructField, StructType
 
 from kedro_datasets.databricks._base_table_dataset import BaseTableDataset
 
 
-@pytest.fixture
-def sample_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def upsert_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32), ("Evan", 23)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def mismatched_upsert_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-            StructField("height", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32, 174), ("Evan", 23, 166)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def subset_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-            StructField("height", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32, 174), ("Evan", 23, 166)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def subset_pandas_df():
-    return pd.DataFrame(
-        {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]}
-    )
-
-
-@pytest.fixture
-def subset_expected_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32), ("Evan", 23)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def sample_pandas_df():
-    return pd.DataFrame(
-        {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]}
-    )
-
-
-@pytest.fixture
-def append_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Evan", 23), ("Frank", 13)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def expected_append_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [
-        ("Alex", 31),
-        ("Bob", 12),
-        ("Clarke", 65),
-        ("Dave", 29),
-        ("Evan", 23),
-        ("Frank", 13),
-    ]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def expected_upsert_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [
-        ("Alex", 32),
-        ("Bob", 12),
-        ("Clarke", 65),
-        ("Dave", 29),
-        ("Evan", 23),
-    ]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [
-        ("Alex", 31),
-        ("Alex", 32),
-        ("Bob", 12),
-        ("Clarke", 65),
-        ("Dave", 29),
-        ("Evan", 23),
-    ]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def external_location():
-    return os.environ.get("DATABRICKS_EXTERNAL_LOCATION")
-
-
 class TestBaseTableDataset:
     def test_full_table(self):
         unity_ds = BaseTableDataset(catalog="test", database="test", table="test")
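Note for reviewers: pytest resolves fixtures defined in conftest.py by parameter name for every test module in the same directory, with no import required, which is why moving these fixtures needs no changes at their call sites. A minimal sketch of a consuming test, assuming the shared conftest.py above (the file name, test name, and assertions are illustrative and not part of this diff):

# tests/databricks/test_fixture_usage.py -- hypothetical file, for illustration only
def test_uses_shared_fixtures(sample_spark_df, external_location):
    # sample_spark_df is the four-row Spark DataFrame defined in conftest.py;
    # pytest injects it by matching the parameter name to the fixture name.
    assert sample_spark_df.count() == 4
    # external_location reads DATABRICKS_EXTERNAL_LOCATION, so it is None
    # whenever that environment variable is unset.
    assert external_location is None or isinstance(external_location, str)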