[SPARK-51637][PYTHON] Implement createColumnarReader for Python Data Source #50414

Open. Wants to merge 3 commits into base: master
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -266,6 +266,27 @@ def reader(self, schema) -> "DataSourceReader":
         assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
         self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)
 
+    def test_struct_array(self):
+        class TestDataSourceReader(DataSourceReader):
+            def read(self, partition):
+                yield (1, 2), [(3, (4, 5)), (6, (7, 8))]
+                yield (9, 10), [(11, (12, 13))]
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "a struct<b:int, c:int>, d array<struct<e:int, f:struct<g:int, h:int>>>"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("TestDataSource").load()
+        expected_data = [
+            Row(a=Row(b=1, c=2), d=[Row(e=3, f=Row(g=4, h=5)), Row(e=6, f=Row(g=7, h=8))]),
+            Row(a=Row(b=9, c=10), d=[Row(e=11, f=Row(g=12, h=13))]),
+        ]
+        assertDataFrameEqual(df, expected_data)
+
     def test_filter_pushdown(self):
         class TestDataSourceReader(DataSourceReader):
             def __init__(self):
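For context, the snippet below is a standalone version of the flow the new test_struct_array test exercises: define a reader that yields plain Python tuples, register the data source, and read it back as a DataFrame with nested struct and array-of-struct columns. It is an illustrative sketch (the class and format names here are made up), not part of this PR, and assumes a local SparkSession with PyArrow installed.

```python
from pyspark.sql import SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader


class StructArrayReader(DataSourceReader):
    def read(self, partition):
        # Plain Python tuples; the Python worker converts them to Arrow batches.
        yield (1, 2), [(3, (4, 5)), (6, (7, 8))]
        yield (9, 10), [(11, (12, 13))]


class StructArrayDemo(DataSource):
    @classmethod
    def name(cls):
        return "struct_array_demo"

    def schema(self):
        return "a struct<b:int, c:int>, d array<struct<e:int, f:struct<g:int, h:int>>>"

    def reader(self, schema):
        return StructArrayReader()


spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(StructArrayDemo)
spark.read.format("struct_array_demo").load().show(truncate=False)
```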
13 changes: 8 additions & 5 deletions python/pyspark/sql/tests/test_python_streaming_datasource.py
@@ -27,6 +27,7 @@
     SimpleDataSourceStreamReader,
     WriterCommitMessage,
 )
+from pyspark.sql.session import SparkSession
 from pyspark.sql.streaming import StreamingQueryException
 from pyspark.sql.types import Row
 from pyspark.testing.sqlutils import (
@@ -39,6 +40,8 @@

 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
 class BasePythonStreamingDataSourceTestsMixin:
+    spark: SparkSession
+
     def test_basic_streaming_data_source_class(self):
         class MyDataSource(DataSource):
             ...
@@ -146,7 +149,7 @@ def check_batch(df, batch_id):
             assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)])
 
         q = df.writeStream.foreachBatch(check_batch).start()
-        while len(q.recentProgress) < 10:
+        while q.isActive and len(q.recentProgress) < 10:
             time.sleep(0.2)
         q.stop()
         q.awaitTermination()
@@ -196,7 +199,7 @@ def streamReader(self, schema):
.option("checkpointLocation", checkpoint_dir.name)
.start(output_dir.name)
)
while not q.recentProgress:
while q.isActive and not q.recentProgress:
time.sleep(0.2)
q.stop()
q.awaitTermination()
@@ -244,7 +247,7 @@ def check_batch(df, batch_id):
             assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)])
 
         q = df.writeStream.foreachBatch(check_batch).start()
-        while len(q.recentProgress) < 10:
+        while q.isActive and len(q.recentProgress) < 10:
             time.sleep(0.2)
         q.stop()
         q.awaitTermination()
@@ -266,7 +269,7 @@ def test_stream_writer(self):
.option("checkpointLocation", checkpoint_dir.name)
.start(output_dir.name)
)
while not q.recentProgress:
while q.isActive and not q.recentProgress:
time.sleep(0.2)

# Test stream writer write and commit.
@@ -281,7 +284,7 @@ def test_stream_writer(self):
             # Test StreamWriter write and abort.
             # When row id > 50, write tasks throw exception and fail.
             # 1.txt is written by StreamWriter.abort() to record the failure.
-            while q.exception() is None:
+            while q.isActive and q.exception() is None:
                 time.sleep(0.2)
             assertDataFrameEqual(
                 self.spark.read.text(os.path.join(output_dir.name, "1.txt")),
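The `q.isActive and ...` guards added above keep these polling loops from spinning forever if the streaming query fails before it reports any progress; once the loop exits, `q.awaitTermination()` (or `q.exception()`) surfaces the underlying error. A minimal, self-contained sketch of the pattern, using the built-in rate source and memory sink purely for illustration:

```python
import time

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.readStream.format("rate").load()  # built-in rate source, for illustration

q = df.writeStream.format("memory").queryName("progress_demo").start()

# Poll for progress, but stop as soon as the query is no longer active
# (for example because a task failed); otherwise the loop could never exit.
while q.isActive and not q.recentProgress:
    time.sleep(0.2)

q.stop()
q.awaitTermination()  # raises StreamingQueryException if the query failed
```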
@@ -22,6 +22,7 @@ import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 
 case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition
@@ -63,4 +64,37 @@ class PythonPartitionReaderFactory(
       }
     }
   }
+
+  override def supportColumnarReads(partition: InputPartition): Boolean = true
+
+  override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = {
+    new PartitionReader[ColumnarBatch] {
+
+      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics
+
+      private val outputIter = {
+        val evaluatorFactory = source.createMapInBatchColumnarEvaluatorFactory(
+          pickledReadFunc,
+          "read_from_data_source",
+          UserDefinedPythonDataSource.readInputSchema,
+          outputSchema,
+          metrics,
+          jobArtifactUUID)
+
+        val part = partition.asInstanceOf[PythonInputPartition]
+        evaluatorFactory.createEvaluator().eval(
+          part.index, Iterator.single(InternalRow(part.pickedPartition)))
+      }
+
+      override def next(): Boolean = outputIter.hasNext
+
+      override def get(): ColumnarBatch = outputIter.next()
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
+      }
+    }
+  }
 }
@@ -25,6 +25,7 @@ import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.PythonStreamBlockId
 
 
@@ -87,4 +88,41 @@ class PythonStreamingPartitionReaderFactory(
       }
     }
   }
+
+  override def supportColumnarReads(partition: InputPartition): Boolean = {
+    // Prefetched block doesn't support columnar read because ColumnarBatch is not serializable.
+    val part = partition.asInstanceOf[PythonStreamingInputPartition]
+    part.blockId.isEmpty
+  }
+
+  override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = {
+    val part = partition.asInstanceOf[PythonStreamingInputPartition]
+
+    new PartitionReader[ColumnarBatch] {
+      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics
+
+      private val outputIter = {
+        val evaluatorFactory = source.createMapInBatchColumnarEvaluatorFactory(
+          pickledReadFunc,
+          "read_from_data_source",
+          UserDefinedPythonDataSource.readInputSchema,
+          outputSchema,
+          metrics,
+          jobArtifactUUID)
+
+        evaluatorFactory.createEvaluator().eval(
+          part.index, Iterator.single(InternalRow(part.pickedPartition)))
+      }
+
+      override def next(): Boolean = outputIter.hasNext
+
+      override def get(): ColumnarBatch = outputIter.next()
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
+      }
+    }
+  }
 }
@@ -38,7 +38,7 @@ import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInBatchEvaluatorFactory, PythonPlannerRunner, PythonSQLMetrics}
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInBatchColumnarEvaluatorFactory, MapInBatchEvaluatorFactory, PythonPlannerRunner, PythonSQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.sources.Filter
@@ -172,6 +172,44 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       jobArtifactUUID)
   }
 
+  /**
+   * (Executor-side) Create an evaluator factory that executes the Python function.
+   */
+  def createMapInBatchColumnarEvaluatorFactory(
+      pickledFunc: Array[Byte],
+      funcName: String,
+      inputSchema: StructType,
+      outputSchema: StructType,
+      metrics: Map[String, SQLMetric],
+      jobArtifactUUID: Option[String]): MapInBatchColumnarEvaluatorFactory = {
+    val pythonFunc = createPythonFunction(pickledFunc)
+
+    val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
+
+    val pythonUDF = PythonUDF(
+      name = funcName,
+      func = pythonFunc,
+      dataType = outputSchema,
+      children = toAttributes(inputSchema),
+      evalType = pythonEvalType,
+      udfDeterministic = false
+    )
+
+    val conf = SQLConf.get
+
+    val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
+    new MapInBatchColumnarEvaluatorFactory(
+      Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)),
+      inputSchema,
+      pythonEvalType,
+      conf.sessionLocalTimeZone,
+      conf.arrowUseLargeVarTypes,
+      pythonRunnerConf,
+      metrics,
+      jobArtifactUUID
+    )
+  }
+
   def createPythonMetrics(): Array[CustomMetric] = {
     // Do not add other metrics such as number of rows,
     // that is already included via DSv2.
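createMapInBatchColumnarEvaluatorFactory drives the Python worker with `PythonEvalType.SQL_MAP_ARROW_ITER_UDF`, the same eval type that backs `DataFrame.mapInArrow`, so the worker exchanges `pyarrow.RecordBatch`es with the JVM. For intuition, this is roughly what that contract looks like at the user level; the snippet is an illustrative aside, not part of this PR:

```python
import pyarrow as pa
import pyarrow.compute as pc

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(8)


def double_id(batches):
    # Receives an iterator of pyarrow.RecordBatch and must yield RecordBatches,
    # the same Arrow-based exchange used by the data source read path above.
    for batch in batches:
        doubled = pc.multiply(batch.column(0), 2)
        yield pa.RecordBatch.from_arrays([doubled], names=["id"])


df.mapInArrow(double_id, schema="id long").show()
```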
@@ -86,3 +86,51 @@ class MapInBatchEvaluatorFactory(
     }
   }
 }
+
+class MapInBatchColumnarEvaluatorFactory(
+    chainedFunc: Seq[(ChainedPythonFunctions, Long)],
+    outputTypes: StructType,
+    pythonEvalType: Int,
+    sessionLocalTimeZone: String,
+    largeVarTypes: Boolean,
+    pythonRunnerConf: Map[String, String],
+    val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String])
+  extends PartitionEvaluatorFactory[InternalRow, ColumnarBatch] {
+
+  override def createEvaluator(): PartitionEvaluator[InternalRow, ColumnarBatch] =
+    new MapInBatchEvaluator
+
+  private class MapInBatchEvaluator extends PartitionEvaluator[InternalRow, ColumnarBatch] {
+    override def eval(
+        partitionIndex: Int,
+        inputs: Iterator[InternalRow]*): Iterator[ColumnarBatch] = {
+      assert(inputs.length == 1)
+      val inputIter = inputs.head
+      // Single function with one struct.
+      val argOffsets = Array(Array(0))
+      val context = TaskContext.get()
+
+      // Here we wrap it via another row so that Python sides understand it
+      // as a DataFrame.
+      val wrappedIter = inputIter.map(InternalRow(_))
+
+      val batchIter = Iterator(wrappedIter)
+
+      val pyRunner = new ArrowPythonRunner(
+        chainedFunc,
+        pythonEvalType,
+        argOffsets,
+        StructType(Array(StructField("struct", outputTypes))),
+        sessionLocalTimeZone,
+        largeVarTypes,
+        pythonRunnerConf,
+        pythonMetrics,
+        jobArtifactUUID,
+        None) with BatchedPythonArrowInput
+      val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
+
+      columnarBatchIter
+    }
+  }
+}
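With `supportColumnarReads` returning true, the Arrow batches coming back from the Python worker are handed to the DSv2 scan as `ColumnarBatch` directly, instead of being eagerly converted to rows inside the reader as the existing `createReader` path does. Readers that already produce Arrow data line up with this best; recent PySpark builds also allow `DataSourceReader.read` to yield `pyarrow.RecordBatch` objects directly (check the Python Data Source API docs for your version). A sketch of such a reader, with illustrative names and not taken from this PR:

```python
import pyarrow as pa

from pyspark.sql.datasource import DataSource, DataSourceReader


class ArrowBatchReader(DataSourceReader):
    def read(self, partition):
        # Yielding pyarrow.RecordBatch keeps the data in Arrow format end to end;
        # plain tuples also work and are converted to Arrow by the worker.
        yield pa.RecordBatch.from_pydict({"x": [1, 2, 3], "y": ["a", "b", "c"]})


class ArrowBatchSource(DataSource):
    @classmethod
    def name(cls):
        return "arrow_batch_demo"

    def schema(self):
        return "x bigint, y string"

    def reader(self, schema):
        return ArrowBatchReader()
```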