diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 5f654ce6cfaea..0202930dd304b 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -266,6 +266,27 @@ def reader(self, schema) -> "DataSourceReader":
         assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
         self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)
 
+    def test_struct_array(self):
+        class TestDataSourceReader(DataSourceReader):
+            def read(self, partition):
+                yield (1, 2), [(3, (4, 5)), (6, (7, 8))]
+                yield (9, 10), [(11, (12, 13))]
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "a struct<b int, c int>, d array<struct<e int, f struct<g int, h int>>>"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("TestDataSource").load()
+        expected_data = [
+            Row(a=Row(b=1, c=2), d=[Row(e=3, f=Row(g=4, h=5)), Row(e=6, f=Row(g=7, h=8))]),
+            Row(a=Row(b=9, c=10), d=[Row(e=11, f=Row(g=12, h=13))]),
+        ]
+        assertDataFrameEqual(df, expected_data)
+
     def test_filter_pushdown(self):
         class TestDataSourceReader(DataSourceReader):
             def __init__(self):
diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py b/python/pyspark/sql/tests/test_python_streaming_datasource.py
index fa14b37b57e62..91e2d8640ad3b 100644
--- a/python/pyspark/sql/tests/test_python_streaming_datasource.py
+++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py
@@ -27,6 +27,7 @@
     SimpleDataSourceStreamReader,
     WriterCommitMessage,
 )
+from pyspark.sql.session import SparkSession
 from pyspark.sql.streaming import StreamingQueryException
 from pyspark.sql.types import Row
 from pyspark.testing.sqlutils import (
@@ -39,6 +40,8 @@
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
 class BasePythonStreamingDataSourceTestsMixin:
+    spark: SparkSession
+
     def test_basic_streaming_data_source_class(self):
         class MyDataSource(DataSource):
             ...
@@ -146,7 +149,7 @@ def check_batch(df, batch_id):
             assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)])
 
         q = df.writeStream.foreachBatch(check_batch).start()
-        while len(q.recentProgress) < 10:
+        while q.isActive and len(q.recentProgress) < 10:
             time.sleep(0.2)
         q.stop()
         q.awaitTermination()
@@ -196,7 +199,7 @@ def streamReader(self, schema):
             .option("checkpointLocation", checkpoint_dir.name)
             .start(output_dir.name)
         )
-        while not q.recentProgress:
+        while q.isActive and not q.recentProgress:
             time.sleep(0.2)
         q.stop()
         q.awaitTermination()
@@ -244,7 +247,7 @@ def check_batch(df, batch_id):
             assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)])
 
         q = df.writeStream.foreachBatch(check_batch).start()
-        while len(q.recentProgress) < 10:
+        while q.isActive and len(q.recentProgress) < 10:
             time.sleep(0.2)
         q.stop()
         q.awaitTermination()
@@ -266,7 +269,7 @@ def test_stream_writer(self):
                 .option("checkpointLocation", checkpoint_dir.name)
                 .start(output_dir.name)
             )
-            while not q.recentProgress:
+            while q.isActive and not q.recentProgress:
                 time.sleep(0.2)
 
             # Test stream writer write and commit.
@@ -281,7 +284,7 @@ def test_stream_writer(self):
             # Test StreamWriter write and abort.
             # When row id > 50, write tasks throw exception and fail.
             # 1.txt is written by StreamWriter.abort() to record the failure.
-            while q.exception() is None:
+            while q.isActive and q.exception() is None:
                 time.sleep(0.2)
             assertDataFrameEqual(
                 self.spark.read.text(os.path.join(output_dir.name, "1.txt")),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
index 44933779c26a4..b9b837f949556 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition
 
@@ -63,4 +64,37 @@ class PythonPartitionReaderFactory(
       }
     }
   }
+
+  override def supportColumnarReads(partition: InputPartition): Boolean = true
+
+  override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = {
+    new PartitionReader[ColumnarBatch] {
+
+      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics
+
+      private val outputIter = {
+        val evaluatorFactory = source.createMapInBatchColumnarEvaluatorFactory(
+          pickledReadFunc,
+          "read_from_data_source",
+          UserDefinedPythonDataSource.readInputSchema,
+          outputSchema,
+          metrics,
+          jobArtifactUUID)
+
+        val part = partition.asInstanceOf[PythonInputPartition]
+        evaluatorFactory.createEvaluator().eval(
+          part.index, Iterator.single(InternalRow(part.pickedPartition)))
+      }
+
+      override def next(): Boolean = outputIter.hasNext
+
+      override def get(): ColumnarBatch = outputIter.next()
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value})
+      }
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
index 7d80cc2728102..ccebe2e61a7e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.PythonStreamBlockId
 
 
@@ -87,4 +88,41 @@ class PythonStreamingPartitionReaderFactory(
       }
     }
   }
+
+  override def supportColumnarReads(partition: InputPartition): Boolean = {
+    // Prefetched block doesn't support columnar read because ColumnarBatch is not serializable.
+    val part = partition.asInstanceOf[PythonStreamingInputPartition]
+    part.blockId.isEmpty
+  }
+
+  override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = {
+    val part = partition.asInstanceOf[PythonStreamingInputPartition]
+
+    new PartitionReader[ColumnarBatch] {
+      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics
+
+      private val outputIter = {
+        val evaluatorFactory = source.createMapInBatchColumnarEvaluatorFactory(
+          pickledReadFunc,
+          "read_from_data_source",
+          UserDefinedPythonDataSource.readInputSchema,
+          outputSchema,
+          metrics,
+          jobArtifactUUID)
+
+        evaluatorFactory.createEvaluator().eval(
+          part.index, Iterator.single(InternalRow(part.pickedPartition)))
+      }
+
+      override def next(): Boolean = outputIter.hasNext
+
+      override def get(): ColumnarBatch = outputIter.next()
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
+      }
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 4664e957ab31f..e5cfd0b23fc66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInBatchEvaluatorFactory, PythonPlannerRunner, PythonSQLMetrics}
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInBatchColumnarEvaluatorFactory, MapInBatchEvaluatorFactory, PythonPlannerRunner, PythonSQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.sources.Filter
@@ -172,6 +172,44 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       jobArtifactUUID)
   }
 
+  /**
+   * (Executor-side) Create an iterator that executes the Python function.
+   */
+  def createMapInBatchColumnarEvaluatorFactory(
+      pickledFunc: Array[Byte],
+      funcName: String,
+      inputSchema: StructType,
+      outputSchema: StructType,
+      metrics: Map[String, SQLMetric],
+      jobArtifactUUID: Option[String]): MapInBatchColumnarEvaluatorFactory = {
+    val pythonFunc = createPythonFunction(pickledFunc)
+
+    val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
+
+    val pythonUDF = PythonUDF(
+      name = funcName,
+      func = pythonFunc,
+      dataType = outputSchema,
+      children = toAttributes(inputSchema),
+      evalType = pythonEvalType,
+      udfDeterministic = false
+    )
+
+    val conf = SQLConf.get
+
+    val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
+    new MapInBatchColumnarEvaluatorFactory(
+      Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)),
+      inputSchema,
+      pythonEvalType,
+      conf.sessionLocalTimeZone,
+      conf.arrowUseLargeVarTypes,
+      pythonRunnerConf,
+      metrics,
+      jobArtifactUUID
+    )
+  }
+
   def createPythonMetrics(): Array[CustomMetric] = {
     // Do not add other metrics such as number of rows,
     // that is already included via DSv2.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 88b63f3b2dd09..302d541e14328 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -86,3 +86,51 @@ class MapInBatchEvaluatorFactory(
     }
   }
 }
+
+class MapInBatchColumnarEvaluatorFactory(
+    chainedFunc: Seq[(ChainedPythonFunctions, Long)],
+    outputTypes: StructType,
+    pythonEvalType: Int,
+    sessionLocalTimeZone: String,
+    largeVarTypes: Boolean,
+    pythonRunnerConf: Map[String, String],
+    val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String])
+    extends PartitionEvaluatorFactory[InternalRow, ColumnarBatch] {
+
+  override def createEvaluator(): PartitionEvaluator[InternalRow, ColumnarBatch] =
+    new MapInBatchEvaluator
+
+  private class MapInBatchEvaluator extends PartitionEvaluator[InternalRow, ColumnarBatch] {
+    override def eval(
+        partitionIndex: Int,
+        inputs: Iterator[InternalRow]*): Iterator[ColumnarBatch] = {
+      assert(inputs.length == 1)
+      val inputIter = inputs.head
+      // Single function with one struct.
+      val argOffsets = Array(Array(0))
+      val context = TaskContext.get()
+
+      // Here we wrap it via another row so that Python sides understand it
+      // as a DataFrame.
+      val wrappedIter = inputIter.map(InternalRow(_))
+
+      val batchIter = Iterator(wrappedIter)
+
+      val pyRunner = new ArrowPythonRunner(
+        chainedFunc,
+        pythonEvalType,
+        argOffsets,
+        StructType(Array(StructField("struct", outputTypes))),
+        sessionLocalTimeZone,
+        largeVarTypes,
+        pythonRunnerConf,
+        pythonMetrics,
+        jobArtifactUUID,
+        None) with BatchedPythonArrowInput
+      val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
+
+      columnarBatchIter
+    }
+  }
+}
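Note (not part of the patch): with createColumnarReader above, the output of the Arrow-based Python runner can be handed to Spark as ColumnarBatches instead of being converted back to rows. The sketch below is only illustrative of a Python data source whose reader yields pyarrow.RecordBatch objects, which keeps the data in Arrow form on the Python side as well; the class names, schema, and the existing `spark` session are assumptions, not part of this change, and the example mirrors the registration pattern used in the tests above.

import pyarrow as pa

from pyspark.sql.datasource import DataSource, DataSourceReader


class ArrowBatchReader(DataSourceReader):
    def read(self, partition):
        # Yield Arrow record batches directly; pa.int32() is chosen to match
        # the "int" column declared in the schema below.
        yield pa.record_batch(
            [pa.array([1, 2, 3], type=pa.int32()), pa.array(["a", "b", "c"])],
            names=["id", "value"],
        )


class ArrowBatchDataSource(DataSource):
    def schema(self):
        return "id int, value string"

    def reader(self, schema) -> "DataSourceReader":
        return ArrowBatchReader()


# Hypothetical usage, following the same pattern as the tests above:
# spark.dataSource.register(ArrowBatchDataSource)
# spark.read.format("ArrowBatchDataSource").load().show()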