[SPARK-50661][CONNECT][SS] Fix Spark Connect Scala foreachBatch impl. to support Dataset[T]. #49323

Closed · 4 commits
@@ -135,7 +135,10 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataSt
   /** @inheritdoc */
   @Evolving
   def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = {
-    val serializedFn = SparkSerDeUtils.serialize(function)
+    // SPARK-50661: the client should send the encoder for the input dataset together with the
+    // function to the server.
+    val serializedFn =
+      SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder))
     sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder
       .setPayload(ByteString.copyFrom(serializedFn))
       .setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused.
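The `ForeachWriterPacket` serialized above is defined elsewhere (in `org.apache.spark.sql.connect.common`) and is not part of this diff. Judging only from how it is constructed here and consumed in `StreamingForeachBatchHelper` below, it is roughly the following; the field names come from this PR, the rest of the sketch is illustrative:

```scala
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder

// Sketch of the payload the client now serializes for Scala foreachBatch:
// the user function plus the encoder of the input Dataset[T].
case class ForeachWriterPacket(
    foreachWriter: AnyRef,              // here: the (Dataset[T], Long) => Unit function
    datasetEncoder: AgnosticEncoder[_]) // here: ds.agnosticEncoder of the input Dataset[T]
```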
@@ -28,9 +28,8 @@ import org.scalatest.concurrent.Futures.timeout
 import org.scalatest.time.SpanSugar._

 import org.apache.spark.SparkException
-import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, lit, udf, window}
 import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
 import org.apache.spark.sql.test.{IntegrationTestUtils, QueryTest, RemoteSparkSession}
@@ -567,7 +566,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
     }
   }

-  test("foreachBatch") {
+  test("foreachBatch with DataFrame") {
     // Starts a streaming query with a foreachBatch function, which writes batchId and row count
     // to a temp view. The test verifies that the view is populated with data.

@@ -581,7 +580,12 @@
         .option("numPartitions", "1")
         .load()
         .writeStream
-        .foreachBatch(new ForeachBatchFn(viewName))
+        .foreachBatch((df: DataFrame, batchId: Long) => {
+          val count = df.collect().map(row => row.getLong(1)).sum
+          df.sparkSession
+            .createDataFrame(Seq((batchId, count)))
+            .createOrReplaceGlobalTempView(viewName)
+        })
         .start()

       eventually(timeout(30.seconds)) { // Wait for first progress.
@@ -596,13 +600,83 @@
           .collect()
           .toSeq
         assert(rows.size > 0)
+        assert(rows.map(_.getLong(1)).sum > 0)
         logInfo(s"Rows in $tableName: $rows")
       }

       q.stop()
     }
   }

test("foreachBatch with Dataset[java.lang.Long]") {
val viewName = "test_view"
val tableName = s"global_temp.$viewName"

withTable(tableName) {
val session = spark
import session.implicits._
val q = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.option("numPartitions", "1")
.load()
.select($"value")
.as[java.lang.Long]
.writeStream
.foreachBatch((ds: Dataset[java.lang.Long], batchId: Long) => {
val count = ds.collect().map(v => v.asInstanceOf[Long]).sum
ds.sparkSession
.createDataFrame(Seq((batchId, count)))
.createOrReplaceGlobalTempView(viewName)
})
.start()

eventually(timeout(30.seconds)) { // Wait for first progress.
assert(q.lastProgress != null, "Failed to make progress")
assert(q.lastProgress.numInputRows > 0)
}

eventually(timeout(30.seconds)) {
// There should be row(s) in temporary view created by foreachBatch.
val rows = spark
.sql(s"select * from $tableName")
.collect()
.toSeq
assert(rows.size > 0)
assert(rows.map(_.getLong(1)).sum > 0)
logInfo(s"Rows in $tableName: $rows")
}

q.stop()
}
}

test("foreachBatch with Dataset[TestClass]") {
val session: SparkSession = spark
import session.implicits._
val viewName = "test_view"
val tableName = s"global_temp.$viewName"

val df = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.load()

val q = df
.selectExpr("CAST(value AS INT)")
.as[TestClass]
.writeStream
.foreachBatch((ds: Dataset[TestClass], batchId: Long) => {
val count = ds.collect().map(_.value).sum
})
.start()
eventually(timeout(30.seconds)) {
assert(q.isActive)
assert(q.exception.isEmpty)
}
q.stop()
}

abstract class EventCollector extends StreamingQueryListener {
protected def tablePostfix: String

@@ -700,14 +774,3 @@ class TestForeachWriter[T] extends ForeachWriter[T] {
 case class TestClass(value: Int) {
   override def toString: String = value.toString
 }
-
-class ForeachBatchFn(val viewName: String)
-    extends VoidFunction2[DataFrame, java.lang.Long]
-    with Serializable {
-  override def call(df: DataFrame, batchId: java.lang.Long): Unit = {
-    val count = df.count()
-    df.sparkSession
-      .createDataFrame(Seq((batchId.toLong, count)))
-      .createOrReplaceGlobalTempView(viewName)
-  }
-}
@@ -2957,10 +2957,9 @@ class SparkConnectPlanner(
         fn

       case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION =>
-        val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType](
+        StreamingForeachBatchHelper.scalaForeachBatchWrapper(
           writeOp.getForeachBatch.getScalaFunction.getPayload.toByteArray,
-          Utils.getContextOrSparkClassLoader)
-        StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)
+          sessionHolder)

       case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
         throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable
@@ -28,11 +28,14 @@ import org.apache.spark.SparkException
 import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, QUERY_ID, RUN_ID_STRING, SESSION_ID}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
+import org.apache.spark.sql.connect.common.ForeachWriterPacket
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.streaming.StreamingQuery
 import org.apache.spark.sql.streaming.StreamingQueryListener
+import org.apache.spark.util.Utils

 /**
  * A helper class for handling ForeachBatch related functionality in Spark Connect servers
@@ -88,13 +91,31 @@ object StreamingForeachBatchHelper extends Logging {
   * DataFrame, so the user code actually runs with legacy DataFrame and session..
   */
  def scalaForeachBatchWrapper(
-      fn: ForeachBatchFnType,
+      payloadBytes: Array[Byte],
       sessionHolder: SessionHolder): ForeachBatchFnType = {
+    val foreachBatchPkt =
+      Utils.deserialize[ForeachWriterPacket](payloadBytes, Utils.getContextOrSparkClassLoader)
+    val fn = foreachBatchPkt.foreachWriter.asInstanceOf[(Dataset[Any], Long) => Unit]
+    val encoder = foreachBatchPkt.datasetEncoder.asInstanceOf[AgnosticEncoder[Any]]
     // TODO(SPARK-44462): Set up Spark Connect session.
     // Do we actually need this for the first version?
     dataFrameCachingWrapper(
       (args: FnArgsWithId) => {
-        fn(args.df, args.batchId) // dfId is not used, see hack comment above.
+        // dfId is not used, see hack comment above.
+        try {
+          val ds = if (AgnosticEncoders.UnboundRowEncoder == encoder) {
+            // When the dataset is a DataFrame (Dataset[Row]).
+            args.df.asInstanceOf[Dataset[Any]]
+          } else {
+            // Recover the Dataset from the DataFrame using the encoder.
+            Dataset.apply(args.df.sparkSession, args.df.logicalPlan)(encoder)
+          }
+          fn(ds, args.batchId)
+        } catch {
+          case t: Throwable =>
+            logError(s"Calling foreachBatch fn failed", t)
+            throw t
+        }
       },
       sessionHolder)
   }
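Net effect: the client serializes the user function together with the dataset's `AgnosticEncoder`, and the server either passes the batch `DataFrame` through unchanged (when the encoder is the unbound row encoder, i.e. the user wrote a `DataFrame` function) or rebuilds a typed `Dataset` over the same logical plan before invoking the function. A minimal client-side sketch of what this enables, modeled on the tests above; the connection URL, `Event` case class, and view name are illustrative:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Illustrative element type for the typed foreachBatch call.
case class Event(value: Int)

object TypedForeachBatchExample {
  def main(args: Array[String]): Unit = {
    // Assumes a Spark Connect server reachable at this URL.
    val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()
    import spark.implicits._

    val query = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()
      .selectExpr("CAST(value AS INT) AS value")
      .as[Event] // typed Dataset on the Connect client
      .writeStream
      .foreachBatch((ds: Dataset[Event], batchId: Long) => {
        // Runs on the server; `ds` is a Dataset[Event] again thanks to the shipped encoder.
        val total = ds.collect().map(_.value).sum
        ds.sparkSession
          .createDataFrame(Seq((batchId, total)))
          .createOrReplaceGlobalTempView("typed_counts")
      })
      .start()

    query.awaitTermination()
  }
}
```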