diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 9fcc31e562682..b2c4fcf64e70f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -135,7 +135,10 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataSt
   /** @inheritdoc */
   @Evolving
   def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = {
-    val serializedFn = SparkSerDeUtils.serialize(function)
+    // SPARK-50661: the client should send the encoder for the input dataset together with the
+    // function to the server.
+    val serializedFn =
+      SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder))
     sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder
       .setPayload(ByteString.copyFrom(serializedFn))
       .setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index b1a7d81916e92..199a1507a3b19 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -28,9 +28,8 @@ import org.scalatest.concurrent.Futures.timeout
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkException
-import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, lit, udf, window}
 import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
 import org.apache.spark.sql.test.{IntegrationTestUtils, QueryTest, RemoteSparkSession}
@@ -567,7 +566,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
     }
   }
 
-  test("foreachBatch") {
+  test("foreachBatch with DataFrame") {
     // Starts a streaming query with a foreachBatch function, which writes batchId and row count
     // to a temp view. The test verifies that the view is populated with data.
 
@@ -581,7 +580,12 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
         .option("numPartitions", "1")
        .load()
         .writeStream
-        .foreachBatch(new ForeachBatchFn(viewName))
+        .foreachBatch((df: DataFrame, batchId: Long) => {
+          val count = df.collect().map(row => row.getLong(1)).sum
+          df.sparkSession
+            .createDataFrame(Seq((batchId, count)))
+            .createOrReplaceGlobalTempView(viewName)
+        })
         .start()
 
       eventually(timeout(30.seconds)) { // Wait for first progress.
@@ -596,6 +600,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
           .collect()
           .toSeq
         assert(rows.size > 0)
+        assert(rows.map(_.getLong(1)).sum > 0)
         logInfo(s"Rows in $tableName: $rows")
       }
 
@@ -603,6 +608,75 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
     }
   }
 
+  test("foreachBatch with Dataset[java.lang.Long]") {
+    val viewName = "test_view"
+    val tableName = s"global_temp.$viewName"
+
+    withTable(tableName) {
+      val session = spark
+      import session.implicits._
+      val q = spark.readStream
+        .format("rate")
+        .option("rowsPerSecond", "10")
+        .option("numPartitions", "1")
+        .load()
+        .select($"value")
+        .as[java.lang.Long]
+        .writeStream
+        .foreachBatch((ds: Dataset[java.lang.Long], batchId: Long) => {
+          val count = ds.collect().map(v => v.asInstanceOf[Long]).sum
+          ds.sparkSession
+            .createDataFrame(Seq((batchId, count)))
+            .createOrReplaceGlobalTempView(viewName)
+        })
+        .start()
+
+      eventually(timeout(30.seconds)) { // Wait for first progress.
+        assert(q.lastProgress != null, "Failed to make progress")
+        assert(q.lastProgress.numInputRows > 0)
+      }
+
+      eventually(timeout(30.seconds)) {
+        // There should be row(s) in temporary view created by foreachBatch.
+        val rows = spark
+          .sql(s"select * from $tableName")
+          .collect()
+          .toSeq
+        assert(rows.size > 0)
+        assert(rows.map(_.getLong(1)).sum > 0)
+        logInfo(s"Rows in $tableName: $rows")
+      }
+
+      q.stop()
+    }
+  }
+
+  test("foreachBatch with Dataset[TestClass]") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val viewName = "test_view"
+    val tableName = s"global_temp.$viewName"
+
+    val df = spark.readStream
+      .format("rate")
+      .option("rowsPerSecond", "10")
+      .load()
+
+    val q = df
+      .selectExpr("CAST(value AS INT)")
+      .as[TestClass]
+      .writeStream
+      .foreachBatch((ds: Dataset[TestClass], batchId: Long) => {
+        val count = ds.collect().map(_.value).sum
+      })
+      .start()
+    eventually(timeout(30.seconds)) {
+      assert(q.isActive)
+      assert(q.exception.isEmpty)
+    }
+    q.stop()
+  }
+
   abstract class EventCollector extends StreamingQueryListener {
     protected def tablePostfix: String
 
@@ -700,14 +774,3 @@ class TestForeachWriter[T] extends ForeachWriter[T] {
 case class TestClass(value: Int) {
   override def toString: String = value.toString
 }
-
-class ForeachBatchFn(val viewName: String)
-    extends VoidFunction2[DataFrame, java.lang.Long]
-    with Serializable {
-  override def call(df: DataFrame, batchId: java.lang.Long): Unit = {
-    val count = df.count()
-    df.sparkSession
-      .createDataFrame(Seq((batchId.toLong, count)))
-      .createOrReplaceGlobalTempView(viewName)
-  }
-}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5ace916ba3e9e..d6ade1ac91264 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2957,10 +2957,9 @@ class SparkConnectPlanner(
           fn
 
         case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION =>
-          val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType](
+          StreamingForeachBatchHelper.scalaForeachBatchWrapper(
             writeOp.getForeachBatch.getScalaFunction.getPayload.toByteArray,
-            Utils.getContextOrSparkClassLoader)
-          StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)
+            sessionHolder)
 
        case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
          throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index b6f67fe9f02f6..ab6bed7152c09 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -28,11 +28,14 @@ import org.apache.spark.SparkException
 import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, QUERY_ID, RUN_ID_STRING, SESSION_ID}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
+import org.apache.spark.sql.connect.common.ForeachWriterPacket
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.streaming.StreamingQuery
 import org.apache.spark.sql.streaming.StreamingQueryListener
+import org.apache.spark.util.Utils
 
 /**
  * A helper class for handling ForeachBatch related functionality in Spark Connect servers
@@ -88,13 +91,31 @@ object StreamingForeachBatchHelper extends Logging {
   * DataFrame, so the user code actually runs with legacy DataFrame and session..
   */
  def scalaForeachBatchWrapper(
-      fn: ForeachBatchFnType,
+      payloadBytes: Array[Byte],
       sessionHolder: SessionHolder): ForeachBatchFnType = {
+    val foreachBatchPkt =
+      Utils.deserialize[ForeachWriterPacket](payloadBytes, Utils.getContextOrSparkClassLoader)
+    val fn = foreachBatchPkt.foreachWriter.asInstanceOf[(Dataset[Any], Long) => Unit]
+    val encoder = foreachBatchPkt.datasetEncoder.asInstanceOf[AgnosticEncoder[Any]]
     // TODO(SPARK-44462): Set up Spark Connect session.
     // Do we actually need this for the first version?
     dataFrameCachingWrapper(
       (args: FnArgsWithId) => {
-        fn(args.df, args.batchId) // dfId is not used, see hack comment above.
+        // dfId is not used, see hack comment above.
+        try {
+          val ds = if (AgnosticEncoders.UnboundRowEncoder == encoder) {
+            // When the dataset is a DataFrame (Dataset[Row]).
+            args.df.asInstanceOf[Dataset[Any]]
+          } else {
+            // Recover the Dataset from the DataFrame using the encoder.
+            Dataset.apply(args.df.sparkSession, args.df.logicalPlan)(encoder)
+          }
+          fn(ds, args.batchId)
+        } catch {
+          case t: Throwable =>
+            logError(s"Calling foreachBatch fn failed", t)
+            throw t
+        }
       },
       sessionHolder)
   }
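
For context, a minimal client-side sketch (not part of the patch) of what this change enables: because the dataset's AgnosticEncoder is now serialized alongside the foreachBatch function and used on the server to rebuild a Dataset[T] from each micro-batch DataFrame, the Scala Spark Connect client can pass a typed Dataset[T] lambda rather than only a DataFrame one. The connect string, object name, and rate-source settings below are illustrative assumptions, not values taken from the patch.

import org.apache.spark.sql.{Dataset, SparkSession}

object TypedForeachBatchSketch {
  def main(args: Array[String]): Unit = {
    // Assumed Spark Connect endpoint; adjust to the actual server address.
    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()
    import spark.implicits._

    val query = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()
      .select($"value")
      .as[java.lang.Long] // typed Dataset, not just DataFrame
      .writeStream
      .foreachBatch((batch: Dataset[java.lang.Long], batchId: Long) => {
        // Runs on the server against the Dataset recovered via the shipped encoder.
        println(s"batch $batchId had ${batch.count()} rows")
      })
      .start()

    query.awaitTermination()
  }
}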