diff --git a/kedro-datasets/tests/spark/conftest.py b/kedro-datasets/tests/spark/conftest.py
index fa7504f0a..389217bac 100644
--- a/kedro-datasets/tests/spark/conftest.py
+++ b/kedro-datasets/tests/spark/conftest.py
@@ -27,7 +27,7 @@ def _setup_spark_session():
     ).getOrCreate()
 
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(scope="module")
 def spark_session(tmp_path_factory):
     # When running these spark tests with pytest-xdist, we need to make sure
     # that the spark session setup on each test process don't interfere with each other.
@@ -40,3 +40,6 @@ def spark_session(tmp_path_factory):
     spark = _setup_spark_session()
     yield spark
     spark.stop()
+    # Ensure that the spark session is not used after it is stopped
+    # https://stackoverflow.com/a/41512072
+    spark._instantiatedContext = None
diff --git a/kedro-datasets/tests/spark/test_deltatable_dataset.py b/kedro-datasets/tests/spark/test_deltatable_dataset.py
index 24ad7a3c6..55db00f98 100644
--- a/kedro-datasets/tests/spark/test_deltatable_dataset.py
+++ b/kedro-datasets/tests/spark/test_deltatable_dataset.py
@@ -7,7 +7,6 @@
 from kedro.runner import ParallelRunner
 from packaging.version import Version
 from pyspark import __version__
-from pyspark.sql import SparkSession
 from pyspark.sql.types import IntegerType, StringType, StructField, StructType
 from pyspark.sql.utils import AnalysisException
 
@@ -17,7 +16,7 @@
 
 
 @pytest.fixture
-def sample_spark_df():
+def sample_spark_df(spark_session):
     schema = StructType(
         [
             StructField("name", StringType(), True),
@@ -27,7 +26,7 @@
 
     data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
 
-    return SparkSession.builder.getOrCreate().createDataFrame(data, schema)
+    return spark_session.createDataFrame(data, schema)
 
 
 class TestDeltaTableDataset:
diff --git a/kedro-datasets/tests/spark/test_memory_dataset.py b/kedro-datasets/tests/spark/test_memory_dataset.py
index 8dd469217..d15b0159d 100644
--- a/kedro-datasets/tests/spark/test_memory_dataset.py
+++ b/kedro-datasets/tests/spark/test_memory_dataset.py
@@ -1,13 +1,11 @@
 import pytest
 from kedro.io import MemoryDataset
 from pyspark.sql import DataFrame as SparkDataFrame
-from pyspark.sql import SparkSession
 from pyspark.sql.functions import col, when
 
 
-def _update_spark_df(data, idx, jdx, value):
-    session = SparkSession.builder.getOrCreate()
-    data = session.createDataFrame(data.rdd.zipWithIndex()).select(
+def _update_spark_df(spark_session, data, idx, jdx, value):
+    data = spark_session.createDataFrame(data.rdd.zipWithIndex()).select(
         col("_1.*"), col("_2").alias("__id")
     )
     cname = data.columns[idx]
@@ -34,19 +32,21 @@ def memory_dataset(spark_data_frame):
     return MemoryDataset(data=spark_data_frame)
 
 
-def test_load_modify_original_data(memory_dataset, spark_data_frame):
+def test_load_modify_original_data(spark_session, memory_dataset, spark_data_frame):
     """Check that the data set object is not updated when the original
     SparkDataFrame is changed."""
-    spark_data_frame = _update_spark_df(spark_data_frame, 1, 1, -5)
+    spark_data_frame = _update_spark_df(spark_session, spark_data_frame, 1, 1, -5)
     assert not _check_equals(memory_dataset.load(), spark_data_frame)
 
 
-def test_save_modify_original_data(spark_data_frame):
+def test_save_modify_original_data(spark_session, spark_data_frame):
     """Check that the data set object is not updated when the original
     SparkDataFrame is changed."""
     memory_dataset = MemoryDataset()
     memory_dataset.save(spark_data_frame)
-    spark_data_frame = _update_spark_df(spark_data_frame, 1, 1, "new value")
+    spark_data_frame = _update_spark_df(
+        spark_session, spark_data_frame, 1, 1, "new value"
+    )
     assert not _check_equals(memory_dataset.load(), spark_data_frame)
 
 
diff --git a/kedro-datasets/tests/spark/test_spark_dataset.py b/kedro-datasets/tests/spark/test_spark_dataset.py
index e4eed4481..a9cb40b29 100644
--- a/kedro-datasets/tests/spark/test_spark_dataset.py
+++ b/kedro-datasets/tests/spark/test_spark_dataset.py
@@ -15,7 +15,6 @@
 from moto import mock_aws
 from packaging.version import Version as PackagingVersion
 from pyspark import __version__
-from pyspark.sql import SparkSession
 from pyspark.sql.functions import col
 from pyspark.sql.types import (
     FloatType,
@@ -102,7 +101,7 @@ def versioned_dataset_s3(version):
 
 
 @pytest.fixture
-def sample_spark_df():
+def sample_spark_df(spark_session):
     schema = StructType(
         [
             StructField("name", StringType(), True),
@@ -112,7 +111,7 @@
 
     data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
 
-    return SparkSession.builder.getOrCreate().createDataFrame(data, schema)
+    return spark_session.createDataFrame(data, schema)
 
 
 @pytest.fixture
diff --git a/kedro-datasets/tests/spark/test_spark_hive_dataset.py b/kedro-datasets/tests/spark/test_spark_hive_dataset.py
index 5f11674dd..b8dcf2870 100644
--- a/kedro-datasets/tests/spark/test_spark_hive_dataset.py
+++ b/kedro-datasets/tests/spark/test_spark_hive_dataset.py
@@ -45,6 +45,9 @@ def spark_session():
             # in this module so that it always exits last and stops the spark session
             # after tests are finished.
             spark.stop()
+            # Ensure that the spark session is not used after it is stopped
+            # https://stackoverflow.com/a/41512072
+            spark._instantiatedContext = None
     except PermissionError:  # pragma: no cover
         # On Windows machine TemporaryDirectory can't be removed because some
         # files are still used by Java process.
@@ -68,7 +71,7 @@ def spark_session():
 @pytest.fixture(scope="module", autouse=True)
 def spark_test_databases(spark_session):
     """Setup spark test databases for all tests in this module."""
-    dataset = _generate_spark_df_one()
+    dataset = _generate_spark_df_one(spark_session)
     dataset.createOrReplaceTempView("tmp")
     databases = ["default_1", "default_2"]
 
@@ -100,7 +103,7 @@ def indexRDD(data_frame):
     )
 
 
-def _generate_spark_df_one():
+def _generate_spark_df_one(spark_session):
     schema = StructType(
         [
             StructField("name", StringType(), True),
@@ -108,10 +111,10 @@
         ]
     )
     data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
-    return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1)
+    return spark_session.createDataFrame(data, schema).coalesce(1)
 
 
-def _generate_spark_df_upsert():
+def _generate_spark_df_upsert(spark_session):
     schema = StructType(
         [
             StructField("name", StringType(), True),
@@ -119,10 +122,10 @@
         ]
     )
     data = [("Alex", 99), ("Jeremy", 55)]
-    return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1)
+    return spark_session.createDataFrame(data, schema).coalesce(1)
 
 
-def _generate_spark_df_upsert_expected():
+def _generate_spark_df_upsert_expected(spark_session):
     schema = StructType(
         [
             StructField("name", StringType(), True),
@@ -130,7 +133,7 @@
         ]
     )
     data = [("Alex", 99), ("Bob", 12), ("Clarke", 65), ("Dave", 29), ("Jeremy", 55)]
-    return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1)
+    return spark_session.createDataFrame(data, schema).coalesce(1)
 
 
 class TestSparkHiveDataset:
@@ -144,11 +147,11 @@ def test_cant_pickle(self):
             )
         )
 
-    def test_read_existing_table(self):
+    def test_read_existing_table(self, spark_session):
         dataset = SparkHiveDataset(
             database="default_1", table="table_1", write_mode="overwrite", save_args={}
         )
-        assert_df_equal(_generate_spark_df_one(), dataset.load())
+        assert_df_equal(_generate_spark_df_one(spark_session), dataset.load())
 
     def test_overwrite_empty_table(self, spark_session):
         spark_session.sql(
@@ -159,8 +162,8 @@ def test_overwrite_empty_table(self, spark_session):
             table="test_overwrite_empty_table",
             write_mode="overwrite",
         )
-        dataset.save(_generate_spark_df_one())
-        assert_df_equal(dataset.load(), _generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
+        assert_df_equal(dataset.load(), _generate_spark_df_one(spark_session))
 
     def test_overwrite_not_empty_table(self, spark_session):
         spark_session.sql(
@@ -171,9 +174,9 @@ def test_overwrite_not_empty_table(self, spark_session):
             table="test_overwrite_full_table",
             write_mode="overwrite",
         )
-        dataset.save(_generate_spark_df_one())
-        dataset.save(_generate_spark_df_one())
-        assert_df_equal(dataset.load(), _generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
+        dataset.save(_generate_spark_df_one(spark_session))
+        assert_df_equal(dataset.load(), _generate_spark_df_one(spark_session))
 
     def test_insert_not_empty_table(self, spark_session):
         spark_session.sql(
@@ -184,10 +187,13 @@ def test_insert_not_empty_table(self, spark_session):
             table="test_insert_not_empty_table",
             write_mode="append",
         )
-        dataset.save(_generate_spark_df_one())
-        dataset.save(_generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
+        dataset.save(_generate_spark_df_one(spark_session))
         assert_df_equal(
-            dataset.load(), _generate_spark_df_one().union(_generate_spark_df_one())
+            dataset.load(),
+            _generate_spark_df_one(spark_session).union(
+                _generate_spark_df_one(spark_session)
+            ),
         )
 
     def test_upsert_config_err(self):
@@ -207,9 +213,10 @@ def test_upsert_empty_table(self, spark_session):
             write_mode="upsert",
             table_pk=["name"],
         )
-        dataset.save(_generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
         assert_df_equal(
-            dataset.load().sort("name"), _generate_spark_df_one().sort("name")
+            dataset.load().sort("name"),
+            _generate_spark_df_one(spark_session).sort("name"),
         )
 
     def test_upsert_not_empty_table(self, spark_session):
@@ -222,15 +229,15 @@ def test_upsert_not_empty_table(self, spark_session):
             write_mode="upsert",
             table_pk=["name"],
         )
-        dataset.save(_generate_spark_df_one())
-        dataset.save(_generate_spark_df_upsert())
+        dataset.save(_generate_spark_df_one(spark_session))
+        dataset.save(_generate_spark_df_upsert(spark_session))
 
         assert_df_equal(
             dataset.load().sort("name"),
-            _generate_spark_df_upsert_expected().sort("name"),
+            _generate_spark_df_upsert_expected(spark_session).sort("name"),
         )
 
-    def test_invalid_pk_provided(self):
+    def test_invalid_pk_provided(self, spark_session):
         _test_columns = ["column_doesnt_exist"]
         dataset = SparkHiveDataset(
             database="default_1",
@@ -245,7 +252,7 @@ def test_invalid_pk_provided(self):
                 f"not found in table default_1.table_1",
             ),
         ):
-            dataset.save(_generate_spark_df_one())
+            dataset.save(_generate_spark_df_one(spark_session))
 
     def test_invalid_write_mode_provided(self):
         pattern = (
@@ -277,15 +284,16 @@ def test_invalid_schema_insert(self, spark_session):
             r"Present on insert only: \[\('age', 'int'\)\]\n"
             r"Present on schema only: \[\('additional_column_on_hive', 'int'\)\]",
         ):
-            dataset.save(_generate_spark_df_one())
+            dataset.save(_generate_spark_df_one(spark_session))
 
-    def test_insert_to_non_existent_table(self):
+    def test_insert_to_non_existent_table(self, spark_session):
         dataset = SparkHiveDataset(
             database="default_1", table="table_not_yet_created", write_mode="append"
         )
-        dataset.save(_generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
         assert_df_equal(
-            dataset.load().sort("name"), _generate_spark_df_one().sort("name")
+            dataset.load().sort("name"),
+            _generate_spark_df_one(spark_session).sort("name"),
         )
 
     def test_read_from_non_existent_table(self):
@@ -300,12 +308,12 @@ def test_read_from_non_existent_table(self):
         ):
             dataset.load()
 
-    def test_save_delta_format(self, mocker):
+    def test_save_delta_format(self, mocker, spark_session):
         dataset = SparkHiveDataset(
             database="default_1", table="delta_table", save_args={"format": "delta"}
         )
         mocked_save = mocker.patch("pyspark.sql.DataFrameWriter.saveAsTable")
-        dataset.save(_generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
         mocked_save.assert_called_with(
             "default_1.delta_table", mode="errorifexists", format="delta"
         )
diff --git a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py
index 6ac8d8b0b..36941a11c 100644
--- a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py
+++ b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py
@@ -6,7 +6,6 @@
 from moto import mock_aws
 from packaging.version import Version
 from pyspark import __version__
-from pyspark.sql import SparkSession
 from pyspark.sql.types import IntegerType, StringType, StructField, StructType
 from pyspark.sql.utils import AnalysisException
 
@@ -43,15 +42,13 @@ def sample_spark_df_schema() -> StructType:
 
 
 @pytest.fixture
-def sample_spark_streaming_df(tmp_path, sample_spark_df_schema):
+def sample_spark_streaming_df(spark_session, tmp_path, sample_spark_df_schema):
     """Create a sample dataframe for streaming"""
     data = [("0001", 2), ("0001", 7), ("0002", 4)]
     schema_path = (tmp_path / SCHEMA_FILE_NAME).as_posix()
     with open(schema_path, "w", encoding="utf-8") as f:
         json.dump(sample_spark_df_schema.jsonValue(), f)
-    return SparkSession.builder.getOrCreate().createDataFrame(
-        data, sample_spark_df_schema
-    )
+    return spark_session.createDataFrame(data, sample_spark_df_schema)
 
 
 @pytest.fixture
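For context, here is a minimal sketch (not part of the diff) of the fixture pattern these changes adopt: a module-scoped, non-autouse `spark_session` fixture that tests and data fixtures request explicitly instead of calling `SparkSession.builder.getOrCreate()`. The builder options, app name, sample schema, and test below are illustrative assumptions, not the repository's actual `conftest.py`.

```python
# Illustrative sketch only -- mirrors the pattern in the diff, not the actual conftest.py.
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType


@pytest.fixture(scope="module")
def spark_session():
    # Non-autouse: only tests and fixtures that list `spark_session` as an
    # argument receive the session, so nothing uses it implicitly.
    spark = SparkSession.builder.master("local[1]").appName("tests").getOrCreate()
    yield spark
    spark.stop()
    # Same workaround as in the diff (a private PySpark attribute): make sure a
    # stopped session is not silently reused by a later getOrCreate() call.
    spark._instantiatedContext = None


@pytest.fixture
def sample_spark_df(spark_session):
    # Build the DataFrame from the injected session rather than
    # SparkSession.builder.getOrCreate(), matching the fixtures in the diff.
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
        ]
    )
    return spark_session.createDataFrame([("Alex", 31), ("Bob", 12)], schema)


def test_sample_df_row_count(sample_spark_df):
    assert sample_spark_df.count() == 2
```

The intent suggested by the diff is that every test DataFrame is built from the session pytest manages, so no code path recreates or reuses a session via `getOrCreate()` after it has been stopped.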