Commit c03006a

fix(datasets): Fix spark tests
Signed-off-by: Felipe Monroy <[email protected]>
felipemonroy committed Jul 20, 2024
1 parent 142342d commit c03006a
Showing 6 changed files with 56 additions and 50 deletions.
5 changes: 4 additions & 1 deletion kedro-datasets/tests/spark/conftest.py
@@ -27,7 +27,7 @@ def _setup_spark_session():
).getOrCreate()


@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(scope="module")
def spark_session(tmp_path_factory):
# When running these spark tests with pytest-xdist, we need to make sure
# that the spark session set up in each test process doesn't interfere with the others.
@@ -40,3 +40,6 @@ def spark_session(tmp_path_factory):
spark = _setup_spark_session()
yield spark
spark.stop()
# Ensure that the spark session is not used after it is stopped
# https://stackoverflow.com/a/41512072
spark._instantiatedContext = None
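
For tests, the practical effect of dropping autouse=True is that a module now has to request the session explicitly. A minimal sketch of that opt-in pattern (illustrative only, not part of this commit; it assumes the module-scoped spark_session fixture defined above):

import pytest


@pytest.fixture
def sample_df(spark_session):
    # spark_session is the module-scoped fixture from conftest.py; requesting it
    # here means the session is created lazily, only for tests that need Spark.
    return spark_session.createDataFrame([("Alex", 31), ("Bob", 12)], ["name", "age"])


def test_row_count(sample_df):
    assert sample_df.count() == 2
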
5 changes: 2 additions & 3 deletions kedro-datasets/tests/spark/test_deltatable_dataset.py
@@ -7,7 +7,6 @@
from kedro.runner import ParallelRunner
from packaging.version import Version
from pyspark import __version__
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
from pyspark.sql.utils import AnalysisException

@@ -17,7 +16,7 @@


@pytest.fixture
def sample_spark_df():
def sample_spark_df(spark_session):
schema = StructType(
[
StructField("name", StringType(), True),
@@ -27,7 +26,7 @@ def sample_spark_df():

data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]

return SparkSession.builder.getOrCreate().createDataFrame(data, schema)
return spark_session.createDataFrame(data, schema)


class TestDeltaTableDataset:
16 changes: 8 additions & 8 deletions kedro-datasets/tests/spark/test_memory_dataset.py
@@ -1,13 +1,11 @@
import pytest
from kedro.io import MemoryDataset
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when


def _update_spark_df(data, idx, jdx, value):
session = SparkSession.builder.getOrCreate()
data = session.createDataFrame(data.rdd.zipWithIndex()).select(
def _update_spark_df(spark_session, data, idx, jdx, value):
data = spark_session.createDataFrame(data.rdd.zipWithIndex()).select(
col("_1.*"), col("_2").alias("__id")
)
cname = data.columns[idx]
@@ -34,19 +32,21 @@ def memory_dataset(spark_data_frame):
return MemoryDataset(data=spark_data_frame)


def test_load_modify_original_data(memory_dataset, spark_data_frame):
def test_load_modify_original_data(spark_session, memory_dataset, spark_data_frame):
"""Check that the data set object is not updated when the original
SparkDataFrame is changed."""
spark_data_frame = _update_spark_df(spark_data_frame, 1, 1, -5)
spark_data_frame = _update_spark_df(spark_session, spark_data_frame, 1, 1, -5)
assert not _check_equals(memory_dataset.load(), spark_data_frame)


def test_save_modify_original_data(spark_data_frame):
def test_save_modify_original_data(spark_session, spark_data_frame):
"""Check that the data set object is not updated when the original
SparkDataFrame is changed."""
memory_dataset = MemoryDataset()
memory_dataset.save(spark_data_frame)
spark_data_frame = _update_spark_df(spark_data_frame, 1, 1, "new value")
spark_data_frame = _update_spark_df(
spark_session, spark_data_frame, 1, 1, "new value"
)

assert not _check_equals(memory_dataset.load(), spark_data_frame)

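
The same dependency-injection idea carries over to plain helpers such as _update_spark_df: the session is passed in from the test instead of being looked up via SparkSession.builder.getOrCreate() inside the helper. A rough sketch under the same assumption (the spark_session fixture from conftest.py; the helper and test names below are made up for illustration):

def _make_sample_df(spark_session, rows):
    # Build the frame on the injected session rather than on whatever
    # SparkSession.builder.getOrCreate() would return, which may be a session
    # that an earlier test module has already stopped.
    return spark_session.createDataFrame(rows, ["name", "age"])


def test_helper_uses_injected_session(spark_session):
    df = _make_sample_df(spark_session, [("Alex", 31), ("Bob", 12)])
    assert df.count() == 2
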
5 changes: 2 additions & 3 deletions kedro-datasets/tests/spark/test_spark_dataset.py
@@ -15,7 +15,6 @@
from moto import mock_aws
from packaging.version import Version as PackagingVersion
from pyspark import __version__
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import (
FloatType,
@@ -102,7 +101,7 @@ def versioned_dataset_s3(version):


@pytest.fixture
def sample_spark_df():
def sample_spark_df(spark_session):
schema = StructType(
[
StructField("name", StringType(), True),
@@ -112,7 +111,7 @@ def sample_spark_df():

data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]

return SparkSession.builder.getOrCreate().createDataFrame(data, schema)
return spark_session.createDataFrame(data, schema)


@pytest.fixture
68 changes: 38 additions & 30 deletions kedro-datasets/tests/spark/test_spark_hive_dataset.py
@@ -45,6 +45,9 @@ def spark_session():
# in this module so that it always exits last and stops the spark session
# after tests are finished.
spark.stop()
# Ensure that the spark session is not used after it is stopped
# https://stackoverflow.com/a/41512072
spark._instantiatedContext = None
except PermissionError: # pragma: no cover
# On Windows machines the TemporaryDirectory can't be removed because some
# files are still being used by the Java process.
@@ -68,7 +71,7 @@ def spark_session():
@pytest.fixture(scope="module", autouse=True)
def spark_test_databases(spark_session):
"""Setup spark test databases for all tests in this module."""
dataset = _generate_spark_df_one()
dataset = _generate_spark_df_one(spark_session)
dataset.createOrReplaceTempView("tmp")
databases = ["default_1", "default_2"]

@@ -100,37 +103,37 @@ def indexRDD(data_frame):
)


def _generate_spark_df_one():
def _generate_spark_df_one(spark_session):
schema = StructType(
[
StructField("name", StringType(), True),
StructField("age", IntegerType(), True),
]
)
data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1)
return spark_session.createDataFrame(data, schema).coalesce(1)


def _generate_spark_df_upsert():
def _generate_spark_df_upsert(spark_session):
schema = StructType(
[
StructField("name", StringType(), True),
StructField("age", IntegerType(), True),
]
)
data = [("Alex", 99), ("Jeremy", 55)]
return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1)
return spark_session.createDataFrame(data, schema).coalesce(1)


def _generate_spark_df_upsert_expected():
def _generate_spark_df_upsert_expected(spark_session):
schema = StructType(
[
StructField("name", StringType(), True),
StructField("age", IntegerType(), True),
]
)
data = [("Alex", 99), ("Bob", 12), ("Clarke", 65), ("Dave", 29), ("Jeremy", 55)]
return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1)
return spark_session.createDataFrame(data, schema).coalesce(1)


class TestSparkHiveDataset:
@@ -144,11 +147,11 @@ def test_cant_pickle(self):
)
)

def test_read_existing_table(self):
def test_read_existing_table(self, spark_session):
dataset = SparkHiveDataset(
database="default_1", table="table_1", write_mode="overwrite", save_args={}
)
assert_df_equal(_generate_spark_df_one(), dataset.load())
assert_df_equal(_generate_spark_df_one(spark_session), dataset.load())

def test_overwrite_empty_table(self, spark_session):
spark_session.sql(
@@ -159,8 +162,8 @@ def test_overwrite_empty_table(self, spark_session):
table="test_overwrite_empty_table",
write_mode="overwrite",
)
dataset.save(_generate_spark_df_one())
assert_df_equal(dataset.load(), _generate_spark_df_one())
dataset.save(_generate_spark_df_one(spark_session))
assert_df_equal(dataset.load(), _generate_spark_df_one(spark_session))

def test_overwrite_not_empty_table(self, spark_session):
spark_session.sql(
@@ -171,9 +174,9 @@ def test_overwrite_not_empty_table(self, spark_session):
table="test_overwrite_full_table",
write_mode="overwrite",
)
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_one())
assert_df_equal(dataset.load(), _generate_spark_df_one())
dataset.save(_generate_spark_df_one(spark_session))
dataset.save(_generate_spark_df_one(spark_session))
assert_df_equal(dataset.load(), _generate_spark_df_one(spark_session))

def test_insert_not_empty_table(self, spark_session):
spark_session.sql(
@@ -184,10 +187,13 @@ def test_insert_not_empty_table(self, spark_session):
table="test_insert_not_empty_table",
write_mode="append",
)
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_one(spark_session))
dataset.save(_generate_spark_df_one(spark_session))
assert_df_equal(
dataset.load(), _generate_spark_df_one().union(_generate_spark_df_one())
dataset.load(),
_generate_spark_df_one(spark_session).union(
_generate_spark_df_one(spark_session)
),
)

def test_upsert_config_err(self):
Expand All @@ -207,9 +213,10 @@ def test_upsert_empty_table(self, spark_session):
write_mode="upsert",
table_pk=["name"],
)
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_one(spark_session))
assert_df_equal(
dataset.load().sort("name"), _generate_spark_df_one().sort("name")
dataset.load().sort("name"),
_generate_spark_df_one(spark_session).sort("name"),
)

def test_upsert_not_empty_table(self, spark_session):
@@ -222,15 +229,15 @@ def test_upsert_not_empty_table(self, spark_session):
write_mode="upsert",
table_pk=["name"],
)
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_upsert())
dataset.save(_generate_spark_df_one(spark_session))
dataset.save(_generate_spark_df_upsert(spark_session))

assert_df_equal(
dataset.load().sort("name"),
_generate_spark_df_upsert_expected().sort("name"),
_generate_spark_df_upsert_expected(spark_session).sort("name"),
)

def test_invalid_pk_provided(self):
def test_invalid_pk_provided(self, spark_session):
_test_columns = ["column_doesnt_exist"]
dataset = SparkHiveDataset(
database="default_1",
@@ -245,7 +252,7 @@ def test_invalid_pk_provided(self):
f"not found in table default_1.table_1",
),
):
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_one(spark_session))

def test_invalid_write_mode_provided(self):
pattern = (
@@ -277,15 +284,16 @@ def test_invalid_schema_insert(self, spark_session):
r"Present on insert only: \[\('age', 'int'\)\]\n"
r"Present on schema only: \[\('additional_column_on_hive', 'int'\)\]",
):
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_one(spark_session))

def test_insert_to_non_existent_table(self):
def test_insert_to_non_existent_table(self, spark_session):
dataset = SparkHiveDataset(
database="default_1", table="table_not_yet_created", write_mode="append"
)
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_one(spark_session))
assert_df_equal(
dataset.load().sort("name"), _generate_spark_df_one().sort("name")
dataset.load().sort("name"),
_generate_spark_df_one(spark_session).sort("name"),
)

def test_read_from_non_existent_table(self):
@@ -300,12 +308,12 @@ def test_read_from_non_existent_table(self):
):
dataset.load()

def test_save_delta_format(self, mocker):
def test_save_delta_format(self, mocker, spark_session):
dataset = SparkHiveDataset(
database="default_1", table="delta_table", save_args={"format": "delta"}
)
mocked_save = mocker.patch("pyspark.sql.DataFrameWriter.saveAsTable")
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_one(spark_session))
mocked_save.assert_called_with(
"default_1.delta_table", mode="errorifexists", format="delta"
)
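
Both this module's spark_session fixture and the shared one in conftest.py now reset _instantiatedContext after stopping the session, following the linked Stack Overflow answer, so that a stopped session is not picked up again by later tests. A rough standalone sketch of that teardown pattern (illustrative; the exact caching behaviour that getOrCreate() relies on differs between PySpark versions):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
try:
    spark.createDataFrame([("Alex", 31)], ["name", "age"]).count()
finally:
    spark.stop()
    # Drop the cached reference so a later getOrCreate() builds a fresh session
    # instead of handing back this stopped one.
    spark._instantiatedContext = None
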
7 changes: 2 additions & 5 deletions kedro-datasets/tests/spark/test_spark_streaming_dataset.py
@@ -6,7 +6,6 @@
from moto import mock_aws
from packaging.version import Version
from pyspark import __version__
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
from pyspark.sql.utils import AnalysisException

@@ -43,15 +42,13 @@ def sample_spark_df_schema() -> StructType:


@pytest.fixture
def sample_spark_streaming_df(tmp_path, sample_spark_df_schema):
def sample_spark_streaming_df(spark_session, tmp_path, sample_spark_df_schema):
"""Create a sample dataframe for streaming"""
data = [("0001", 2), ("0001", 7), ("0002", 4)]
schema_path = (tmp_path / SCHEMA_FILE_NAME).as_posix()
with open(schema_path, "w", encoding="utf-8") as f:
json.dump(sample_spark_df_schema.jsonValue(), f)
return SparkSession.builder.getOrCreate().createDataFrame(
data, sample_spark_df_schema
)
return spark_session.createDataFrame(data, sample_spark_df_schema)


@pytest.fixture
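
The streaming fixture serialises its schema to JSON because a file-based streaming read needs an explicit schema (inference is off by default for streaming sources). A small sketch of that round trip (illustrative; the column names and paths are made up, and the streaming read is only defined, not started):

import json

from pyspark.sql.types import IntegerType, StringType, StructField, StructType


def test_schema_round_trip(spark_session, tmp_path):
    schema = StructType(
        [
            StructField("code", StringType(), True),
            StructField("value", IntegerType(), True),
        ]
    )
    # Round-trip the schema through JSON, as the fixture above does.
    schema_path = tmp_path / "schema.json"
    schema_path.write_text(json.dumps(schema.jsonValue()), encoding="utf-8")
    loaded = StructType.fromJson(json.loads(schema_path.read_text(encoding="utf-8")))
    assert loaded == schema

    # Define a streaming read against the loaded schema.
    input_dir = tmp_path / "input"
    input_dir.mkdir()
    stream_df = spark_session.readStream.schema(loaded).json(input_dir.as_posix())
    assert stream_df.isStreaming
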
