From 181b035aabf13e3ef4e9c12e8793357cc45c44f1 Mon Sep 17 00:00:00 2001 From: Haiyang Sun <75672016+haiyangsun-db@users.noreply.github.com> Date: Fri, 27 Dec 2024 16:16:49 +0100 Subject: [PATCH 1/4] Fix Spark Connect Scala foreachBatch impl. to support Dataset[T]. --- .../streaming/ClientStreamingQuerySuite.scala | 65 ++- .../connect/planner/SparkConnectPlanner.scala | 6 +- .../planner/StreamingForeachBatchHelper.scala | 28 +- .../sql/streaming/DataStreamWriter.scala | 403 ++++-------------- 4 files changed, 155 insertions(+), 347 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index b1a7d81916e92..7fdf89b56df44 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -28,9 +28,8 @@ import org.scalatest.concurrent.Futures.timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException -import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row, SparkSession} import org.apache.spark.sql.functions.{col, lit, udf, window} import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent} import org.apache.spark.sql.test.{IntegrationTestUtils, QueryTest, RemoteSparkSession} @@ -581,7 +580,12 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L .option("numPartitions", "1") .load() .writeStream - .foreachBatch(new ForeachBatchFn(viewName)) + .foreachBatch((df: DataFrame, batchId: Long) => { + val count = df.collect().map(row => row.getLong(1)).sum + df.sparkSession + .createDataFrame(Seq((batchId, count))) + .createOrReplaceGlobalTempView(viewName) + }) .start() eventually(timeout(30.seconds)) { // Wait for first progress. @@ -596,6 +600,50 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L .collect() .toSeq assert(rows.size > 0) + assert(rows.map(_.getLong(1)).sum > 0) + logInfo(s"Rows in $tableName: $rows") + } + + q.stop() + } + } + + test("foreachBatch with Dataset[java.lang.Long]") { + val viewName = "test_view" + val tableName = s"global_temp.$viewName" + + withTable(tableName) { + val session = spark + import session.implicits._ + val q = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "1") + .load() + .select($"value") + .as[java.lang.Long] + .writeStream + .foreachBatch((ds: Dataset[java.lang.Long], batchId: Long) => { + val count = ds.collect().map(v => v.asInstanceOf[Long]).sum + ds.sparkSession + .createDataFrame(Seq((batchId, count))) + .createOrReplaceGlobalTempView(viewName) + }) + .start() + + eventually(timeout(30.seconds)) { // Wait for first progress. + assert(q.lastProgress != null, "Failed to make progress") + assert(q.lastProgress.numInputRows > 0) + } + + eventually(timeout(30.seconds)) { + // There should be row(s) in temporary view created by foreachBatch. 
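+        // The closure above received a typed Dataset[java.lang.Long] rather than a DataFrame;
+        // reading back the global temp view it populated verifies that the function actually
+        // ran on the server with the encoder shipped from the client.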
+ val rows = spark + .sql(s"select * from $tableName") + .collect() + .toSeq + assert(rows.size > 0) + assert(rows.map(_.getLong(1)).sum > 0) logInfo(s"Rows in $tableName: $rows") } @@ -700,14 +748,3 @@ class TestForeachWriter[T] extends ForeachWriter[T] { case class TestClass(value: Int) { override def toString: String = value.toString } - -class ForeachBatchFn(val viewName: String) - extends VoidFunction2[DataFrame, java.lang.Long] - with Serializable { - override def call(df: DataFrame, batchId: java.lang.Long): Unit = { - val count = df.count() - df.sparkSession - .createDataFrame(Seq((batchId.toLong, count))) - .createOrReplaceGlobalTempView(viewName) - } -} diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 5ace916ba3e9e..7975f3170837d 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2957,10 +2957,10 @@ class SparkConnectPlanner( fn case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION => - val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType]( + StreamingForeachBatchHelper.scalaForeachBatchWrapper( writeOp.getForeachBatch.getScalaFunction.getPayload.toByteArray, - Utils.getContextOrSparkClassLoader) - StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder) + sessionHolder + ) case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET => throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index b6f67fe9f02f6..cc96eb39978f6 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -28,11 +28,14 @@ import org.apache.spark.SparkException import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, SimplePythonFunction, SpecialLengths, StreamingPythonRunner} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, QUERY_ID, RUN_ID_STRING, SESSION_ID} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders} +import org.apache.spark.sql.connect.common.ForeachWriterPacket import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.connect.service.SparkConnectService import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.sql.streaming.StreamingQueryListener +import org.apache.spark.util.Utils /** * A helper class for handling ForeachBatch related functionality in Spark Connect servers @@ -88,13 +91,32 @@ object StreamingForeachBatchHelper extends Logging { * DataFrame, so the user code actually runs with legacy DataFrame and session.. 
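+ * For typed Datasets (T other than Row), the wrapper below rebuilds the Dataset[T] from the
+ * incoming DataFrame using the AgnosticEncoder carried in the serialized ForeachWriterPacket.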
*/ def scalaForeachBatchWrapper( - fn: ForeachBatchFnType, + payloadBytes: Array[Byte], sessionHolder: SessionHolder): ForeachBatchFnType = { + val foreachBatchPkt = Utils.deserialize[ForeachWriterPacket]( + payloadBytes, + Utils.getContextOrSparkClassLoader) + val fn = foreachBatchPkt.foreachWriter.asInstanceOf[(Dataset[Any], Long) => Unit] + val encoder = foreachBatchPkt.datasetEncoder.asInstanceOf[AgnosticEncoder[Any]] // TODO(SPARK-44462): Set up Spark Connect session. // Do we actually need this for the first version? dataFrameCachingWrapper( (args: FnArgsWithId) => { - fn(args.df, args.batchId) // dfId is not used, see hack comment above. + // dfId is not used, see hack comment above. + try { + val ds = if (AgnosticEncoders.UnboundRowEncoder == encoder) { + // When the dataset is a DataFrame (Dataset[Row). + args.df.asInstanceOf[Dataset[Any]] + } else { + // Recover the Dataset from the DataFrame using the encoder. + Dataset.apply(args.df.sparkSession, args.df.logicalPlan)(encoder) + } + fn(ds, args.batchId) + } catch { + case t: Throwable => + logError(s"Calling foreachBatch fn failed", t) + throw t + } }, sessionHolder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index d41933c6a135c..ee442b571c820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -22,377 +22,158 @@ import java.util.concurrent.TimeoutException import scala.jdk.CollectionConverters._ -import org.apache.hadoop.fs.Path +import com.google.protobuf.ByteString import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.function.VoidFunction2 -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes -import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog, TableProvider, V1Table, V2TableWithV1Fallback} -import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference} -import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2} -import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2 -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.util.ArrayImplicits._ -import org.apache.spark.util.Utils +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.Command +import org.apache.spark.connect.proto.WriteStreamOperationStart +import org.apache.spark.sql.{api, Dataset, ForeachWriter} +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
ForeachWriterPacket} +import org.apache.spark.sql.execution.streaming.AvailableNowTrigger +import org.apache.spark.sql.execution.streaming.ContinuousTrigger +import org.apache.spark.sql.execution.streaming.OneTimeTrigger +import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger +import org.apache.spark.sql.streaming.StreamingQueryListener.QueryStartedEvent +import org.apache.spark.sql.types.NullType +import org.apache.spark.util.SparkSerDeUtils /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, * key-value stores, etc). Use `Dataset.writeStream` to access this. * - * @since 2.0.0 + * @since 3.5.0 */ @Evolving -final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends api.DataStreamWriter[T] { - type DS[U] = Dataset[U] +final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataStreamWriter[T] { + override type DS[U] = Dataset[U] /** @inheritdoc */ def outputMode(outputMode: OutputMode): this.type = { - this.outputMode = outputMode + sinkBuilder.setOutputMode(outputMode.toString.toLowerCase(Locale.ROOT)) this } /** @inheritdoc */ def outputMode(outputMode: String): this.type = { - this.outputMode = InternalOutputModes(outputMode) + sinkBuilder.setOutputMode(outputMode) this } /** @inheritdoc */ def trigger(trigger: Trigger): this.type = { - this.trigger = trigger + trigger match { + case ProcessingTimeTrigger(intervalMs) => + sinkBuilder.setProcessingTimeInterval(s"$intervalMs milliseconds") + case AvailableNowTrigger => + sinkBuilder.setAvailableNow(true) + case OneTimeTrigger => + sinkBuilder.setOnce(true) + case ContinuousTrigger(intervalMs) => + sinkBuilder.setContinuousCheckpointInterval(s"$intervalMs milliseconds") + } this } /** @inheritdoc */ def queryName(queryName: String): this.type = { - this.extraOptions += ("queryName" -> queryName) + sinkBuilder.setQueryName(queryName) this } /** @inheritdoc */ def format(source: String): this.type = { - this.source = source + sinkBuilder.setFormat(source) this } /** @inheritdoc */ @scala.annotation.varargs def partitionBy(colNames: String*): this.type = { - this.partitioningColumns = Option(colNames) - validatePartitioningAndClustering() + sinkBuilder.clearPartitioningColumnNames() + sinkBuilder.addAllPartitioningColumnNames(colNames.asJava) this } /** @inheritdoc */ @scala.annotation.varargs def clusterBy(colNames: String*): this.type = { - this.clusteringColumns = Option(colNames) - validatePartitioningAndClustering() + sinkBuilder.clearClusteringColumnNames() + sinkBuilder.addAllClusteringColumnNames(colNames.asJava) this } /** @inheritdoc */ def option(key: String, value: String): this.type = { - this.extraOptions += (key -> value) + sinkBuilder.putOptions(key, value) this } /** @inheritdoc */ def options(options: scala.collection.Map[String, String]): this.type = { - this.extraOptions ++= options + this.options(options.asJava) this } /** @inheritdoc */ def options(options: java.util.Map[String, String]): this.type = { - this.options(options.asScala) + sinkBuilder.putAllOptions(options) this } - /** @inheritdoc */ - def start(path: String): StreamingQuery = { - if (!ds.sparkSession.sessionState.conf.legacyPathOptionBehavior && - extraOptions.contains("path")) { - throw QueryCompilationErrors.setPathOptionAndCallWithPathParameterError("start") - } - startInternal(Some(path)) - } - - /** @inheritdoc */ - @throws[TimeoutException] - def start(): StreamingQuery = startInternal(None) - - /** @inheritdoc */ - @Evolving - @throws[TimeoutException] - def 
toTable(tableName: String): StreamingQuery = { - - import ds.sparkSession.sessionState.analyzer.CatalogAndIdentifier - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - val parser = ds.sparkSession.sessionState.sqlParser - val originalMultipartIdentifier = parser.parseMultipartIdentifier(tableName) - val CatalogAndIdentifier(catalog, identifier) = originalMultipartIdentifier - - // Currently we don't create a logical streaming writer node in logical plan, so cannot rely - // on analyzer to resolve it. Directly lookup only for temp view to provide clearer message. - // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. - if (ds.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) { - throw QueryCompilationErrors.tempViewNotSupportStreamingWriteError(tableName) - } - - if (!catalog.asTableCatalog.tableExists(identifier)) { - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - - val properties = normalizedClusteringCols.map { cols => - Map( - DataSourceUtils.CLUSTERING_COLUMNS_KEY -> DataSourceUtils.encodePartitioningColumns(cols)) - }.getOrElse(Map.empty) - val partitioningOrClusteringTransform = normalizedClusteringCols.map { colNames => - Array(ClusterByTransform(colNames.map(col => FieldReference(col)))).toImmutableArraySeq - }.getOrElse(partitioningColumns.getOrElse(Nil).asTransforms.toImmutableArraySeq) - - /** - * Note, currently the new table creation by this API doesn't fully cover the V2 table. - * TODO (SPARK-33638): Full support of v2 table creation - */ - val tableSpec = UnresolvedTableSpec( - properties, - Some(source), - OptionList(Seq.empty), - extraOptions.get("path"), - None, - None, - None, - external = false) - val cmd = CreateTable( - UnresolvedIdentifier(originalMultipartIdentifier), - ds.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)), - partitioningOrClusteringTransform, - tableSpec, - ignoreIfExists = false) - Dataset.ofRows(ds.sparkSession, cmd) - } - - val tableInstance = catalog.asTableCatalog.loadTable(identifier) - - def writeToV1Table(table: CatalogTable): StreamingQuery = { - if (table.tableType == CatalogTableType.VIEW) { - throw QueryCompilationErrors.streamingIntoViewNotSupportedError(tableName) - } - require(table.provider.isDefined) - if (source != table.provider.get) { - throw QueryCompilationErrors.inputSourceDiffersFromDataSourceProviderError( - source, tableName, table) - } - format(table.provider.get).startInternal( - Some(new Path(table.location).toString), catalogTable = Some(table)) - } - - import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ - tableInstance match { - case t: SupportsWrite if t.supports(STREAMING_WRITE) => - startQuery(t, extraOptions, catalogAndIdent = Some(catalog.asTableCatalog, identifier)) - case t: V2TableWithV1Fallback => - writeToV1Table(t.v1Table) - case t: V1Table => - writeToV1Table(t.v1Table) - case t => throw QueryCompilationErrors.tableNotSupportStreamingWriteError(tableName, t) - } - } - - private def startInternal( - path: Option[String], - catalogTable: Option[CatalogTable] = None): StreamingQuery = { - if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { - throw QueryCompilationErrors.cannotOperateOnHiveDataSourceFilesError("write") - } - - if (source == DataStreamWriter.SOURCE_NAME_MEMORY) { - assertNotPartitioned(DataStreamWriter.SOURCE_NAME_MEMORY) - if (extraOptions.get("queryName").isEmpty) { - throw QueryCompilationErrors.queryNameNotSpecifiedForMemorySinkError() - } - 
val sink = new MemorySink() - val resultDf = Dataset.ofRows(ds.sparkSession, - MemoryPlan(sink, DataTypeUtils.toAttributes(ds.schema))) - val recoverFromCheckpoint = outputMode == OutputMode.Complete() - val query = startQuery(sink, extraOptions, recoverFromCheckpoint = recoverFromCheckpoint, - catalogTable = catalogTable) - resultDf.createOrReplaceTempView(query.name) - query - } else if (source == DataStreamWriter.SOURCE_NAME_FOREACH) { - assertNotPartitioned(DataStreamWriter.SOURCE_NAME_FOREACH) - val sink = ForeachWriterTable[Any](foreachWriter, foreachWriterEncoder) - startQuery(sink, extraOptions, catalogTable = catalogTable) - } else if (source == DataStreamWriter.SOURCE_NAME_FOREACH_BATCH) { - assertNotPartitioned(DataStreamWriter.SOURCE_NAME_FOREACH_BATCH) - if (trigger.isInstanceOf[ContinuousTrigger]) { - throw QueryCompilationErrors.sourceNotSupportedWithContinuousTriggerError(source) - } - val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) - startQuery(sink, extraOptions, catalogTable = catalogTable) - } else { - val cls = DataSource.lookupDataSource(source, ds.sparkSession.sessionState.conf) - val disabledSources = - Utils.stringToSeq(ds.sparkSession.sessionState.conf.disabledV2StreamingWriters) - val useV1Source = disabledSources.contains(cls.getCanonicalName) || - // file source v2 does not support streaming yet. - classOf[FileDataSourceV2].isAssignableFrom(cls) - - val optionsWithPath = if (path.isEmpty) { - extraOptions - } else { - extraOptions + ("path" -> path.get) - } - - val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) { - val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - source = provider, conf = ds.sparkSession.sessionState.conf) - val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++ - optionsWithPath.originalMap - val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) - // If the source accepts external table metadata, here we pass the schema of input query - // to `getTable`. This is for avoiding schema inference, which can be very expensive. - // If the query schema is not compatible with the existing data, the behavior is undefined. 
- val outputSchema = if (provider.supportsExternalMetadata()) { - Some(ds.schema) - } else { - None - } - provider match { - case p: PythonDataSourceV2 => p.setShortName(source) - case _ => - } - val table = DataSourceV2Utils.getTableFromProvider( - provider, dsOptions, userSpecifiedSchema = outputSchema) - import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ - table match { - case table: SupportsWrite if table.supports(STREAMING_WRITE) => - table - case _ => createV1Sink(optionsWithPath) - } - } else { - createV1Sink(optionsWithPath) - } - - startQuery(sink, optionsWithPath, catalogTable = catalogTable) - } - } - - private def startQuery( - sink: Table, - newOptions: CaseInsensitiveMap[String], - recoverFromCheckpoint: Boolean = true, - catalogAndIdent: Option[(TableCatalog, Identifier)] = None, - catalogTable: Option[CatalogTable] = None): StreamingQuery = { - val useTempCheckpointLocation = DataStreamWriter.SOURCES_ALLOW_ONE_TIME_QUERY.contains(source) - - ds.sparkSession.sessionState.streamingQueryManager.startQuery( - newOptions.get("queryName"), - newOptions.get("checkpointLocation"), - ds, - newOptions.originalMap, - sink, - outputMode, - useTempCheckpointLocation = useTempCheckpointLocation, - recoverFromCheckpointLocation = recoverFromCheckpoint, - trigger = trigger, - catalogAndIdent = catalogAndIdent, - catalogTable = catalogTable) - } - - private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = { - // Do not allow the user to specify clustering columns in the options. Ignoring this option is - // consistent with the behavior of DataFrameWriter on non Path-based tables and with the - // behavior of DataStreamWriter on partitioning columns specified in options. - val optionsWithoutClusteringKey = - optionsWithPath.originalMap - DataSourceUtils.CLUSTERING_COLUMNS_KEY - - val optionsWithClusteringColumns = normalizedClusteringCols match { - case Some(cols) => optionsWithoutClusteringKey + ( - DataSourceUtils.CLUSTERING_COLUMNS_KEY -> - DataSourceUtils.encodePartitioningColumns(cols)) - case None => optionsWithoutClusteringKey - } - val ds = DataSource( - this.ds.sparkSession, - className = source, - options = optionsWithClusteringColumns, - partitionColumns = normalizedParCols.getOrElse(Nil)) - ds.createSink(outputMode) - } - /** @inheritdoc */ def foreach(writer: ForeachWriter[T]): this.type = { - foreachImplementation(writer.asInstanceOf[ForeachWriter[Any]]) - } - - private[sql] def foreachImplementation(writer: ForeachWriter[Any], - encoder: Option[ExpressionEncoder[Any]] = None): this.type = { - this.source = DataStreamWriter.SOURCE_NAME_FOREACH - this.foreachWriter = if (writer != null) { - ds.sparkSession.sparkContext.clean(writer) - } else { - throw new IllegalArgumentException("foreach writer cannot be null") - } - encoder.foreach(e => this.foreachWriterEncoder = e) + val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.agnosticEncoder)) + val scalaWriterBuilder = proto.ScalarScalaUDF + .newBuilder() + .setPayload(ByteString.copyFrom(serialized)) + sinkBuilder.getForeachWriterBuilder.setScalaFunction(scalaWriterBuilder) this } /** @inheritdoc */ @Evolving def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { - this.source = DataStreamWriter.SOURCE_NAME_FOREACH_BATCH - if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") - this.foreachBatchWriter = function + // SPARK-50661: the client should sent the encoder for the input dataset together with the + 
// function to the server. + val serializedFn = SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder)) + sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder + .setPayload(ByteString.copyFrom(serializedFn)) + .setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused. + .setNullable(true) // Unused. this } - private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => - cols.map(normalize(_, "Partition")) - } - - private def normalizedClusteringCols: Option[Seq[String]] = clusteringColumns.map { cols => - cols.map(normalize(_, "Clustering")) - } - - /** - * The given column name may not be equal to any of the existing column names if we were in - * case-insensitive context. Normalize the given column name to the real one so that we don't - * need to care about case sensitivity afterwards. - */ - private def normalize(columnName: String, columnType: String): String = { - val validColumnNames = ds.logicalPlan.output.map(_.name) - validColumnNames.find(ds.sparkSession.sessionState.analyzer.resolver(_, columnName)) - .getOrElse(throw QueryCompilationErrors.columnNotFoundInExistingColumnsError( - columnType, columnName, validColumnNames)) + /** @inheritdoc */ + def start(path: String): StreamingQuery = { + sinkBuilder.setPath(path) + start() } - private def assertNotPartitioned(operation: String): Unit = { - if (partitioningColumns.isDefined) { - throw QueryCompilationErrors.operationNotSupportPartitioningError(operation) + /** @inheritdoc */ + @throws[TimeoutException] + def start(): StreamingQuery = { + val startCmd = Command + .newBuilder() + .setWriteStreamOperationStart(sinkBuilder.build()) + .build() + + val resp = ds.sparkSession.execute(startCmd).head + if (resp.getWriteStreamOperationStartResult.hasQueryStartedEventJson) { + val event = QueryStartedEvent.fromJson( + resp.getWriteStreamOperationStartResult.getQueryStartedEventJson) + ds.sparkSession.streams.streamingQueryListenerBus.postToAll(event) } + RemoteStreamingQuery.fromStartCommandResponse(ds.sparkSession, resp) } - // Validate that partitionBy isn't used with clusterBy. 
- private def validatePartitioningAndClustering(): Unit = { - if (clusteringColumns.nonEmpty && partitioningColumns.nonEmpty) { - throw QueryCompilationErrors.clusterByWithPartitionedBy() - } + /** @inheritdoc */ + @Evolving + @throws[TimeoutException] + def toTable(tableName: String): StreamingQuery = { + sinkBuilder.setTableName(tableName) + start() } /////////////////////////////////////////////////////////////////////////////////////// @@ -413,39 +194,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends api.DataStr override def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): this.type = super.foreachBatch(function) - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = ds.sparkSession.sessionState.conf.defaultDataSourceName - - private var outputMode: OutputMode = OutputMode.Append - - private var trigger: Trigger = Trigger.ProcessingTime(0L) - - private var extraOptions = CaseInsensitiveMap[String](Map.empty) - - private var foreachWriter: ForeachWriter[Any] = _ - - private var foreachWriterEncoder: ExpressionEncoder[Any] = - ds.exprEnc.asInstanceOf[ExpressionEncoder[Any]] - - private var foreachBatchWriter: (Dataset[T], Long) => Unit = _ - - private var partitioningColumns: Option[Seq[String]] = None - - private var clusteringColumns: Option[Seq[String]] = None -} - -object DataStreamWriter { - val SOURCE_NAME_MEMORY: String = "memory" - val SOURCE_NAME_FOREACH: String = "foreach" - val SOURCE_NAME_FOREACH_BATCH: String = "foreachBatch" - val SOURCE_NAME_CONSOLE: String = "console" - val SOURCE_NAME_TABLE: String = "table" - val SOURCE_NAME_NOOP: String = "noop" - - // these writer sources are also used for one-time query, hence allow temp checkpoint location - val SOURCES_ALLOW_ONE_TIME_QUERY: Seq[String] = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH, - SOURCE_NAME_FOREACH_BATCH, SOURCE_NAME_CONSOLE, SOURCE_NAME_NOOP) + private val sinkBuilder = WriteStreamOperationStart + .newBuilder() + .setInput(ds.plan.getRoot) } From bce3cb9dcb56ff8cc88c81b80091fd32e7df9621 Mon Sep 17 00:00:00 2001 From: Haiyang Sun <75672016+haiyangsun-db@users.noreply.github.com> Date: Fri, 27 Dec 2024 16:27:34 +0100 Subject: [PATCH 2/4] fix wrong updates. --- .../sql/streaming/DataStreamWriter.scala | 4 +- .../sql/streaming/DataStreamWriter.scala | 403 ++++++++++++++---- 2 files changed, 330 insertions(+), 77 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 9fcc31e562682..ee442b571c820 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -135,7 +135,9 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataSt /** @inheritdoc */ @Evolving def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { - val serializedFn = SparkSerDeUtils.serialize(function) + // SPARK-50661: the client should sent the encoder for the input dataset together with the + // function to the server. 
+ val serializedFn = SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder)) sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder .setPayload(ByteString.copyFrom(serializedFn)) .setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index ee442b571c820..d41933c6a135c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -22,158 +22,377 @@ import java.util.concurrent.TimeoutException import scala.jdk.CollectionConverters._ -import com.google.protobuf.ByteString +import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.function.VoidFunction2 -import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.Command -import org.apache.spark.connect.proto.WriteStreamOperationStart -import org.apache.spark.sql.{api, Dataset, ForeachWriter} -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket} -import org.apache.spark.sql.execution.streaming.AvailableNowTrigger -import org.apache.spark.sql.execution.streaming.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.OneTimeTrigger -import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger -import org.apache.spark.sql.streaming.StreamingQueryListener.QueryStartedEvent -import org.apache.spark.sql.types.NullType -import org.apache.spark.util.SparkSerDeUtils +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog, TableProvider, V1Table, V2TableWithV1Fallback} +import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2 +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ +import org.apache.spark.util.Utils /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, * key-value stores, etc). Use `Dataset.writeStream` to access this. 
* - * @since 3.5.0 + * @since 2.0.0 */ @Evolving -final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataStreamWriter[T] { - override type DS[U] = Dataset[U] +final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends api.DataStreamWriter[T] { + type DS[U] = Dataset[U] /** @inheritdoc */ def outputMode(outputMode: OutputMode): this.type = { - sinkBuilder.setOutputMode(outputMode.toString.toLowerCase(Locale.ROOT)) + this.outputMode = outputMode this } /** @inheritdoc */ def outputMode(outputMode: String): this.type = { - sinkBuilder.setOutputMode(outputMode) + this.outputMode = InternalOutputModes(outputMode) this } /** @inheritdoc */ def trigger(trigger: Trigger): this.type = { - trigger match { - case ProcessingTimeTrigger(intervalMs) => - sinkBuilder.setProcessingTimeInterval(s"$intervalMs milliseconds") - case AvailableNowTrigger => - sinkBuilder.setAvailableNow(true) - case OneTimeTrigger => - sinkBuilder.setOnce(true) - case ContinuousTrigger(intervalMs) => - sinkBuilder.setContinuousCheckpointInterval(s"$intervalMs milliseconds") - } + this.trigger = trigger this } /** @inheritdoc */ def queryName(queryName: String): this.type = { - sinkBuilder.setQueryName(queryName) + this.extraOptions += ("queryName" -> queryName) this } /** @inheritdoc */ def format(source: String): this.type = { - sinkBuilder.setFormat(source) + this.source = source this } /** @inheritdoc */ @scala.annotation.varargs def partitionBy(colNames: String*): this.type = { - sinkBuilder.clearPartitioningColumnNames() - sinkBuilder.addAllPartitioningColumnNames(colNames.asJava) + this.partitioningColumns = Option(colNames) + validatePartitioningAndClustering() this } /** @inheritdoc */ @scala.annotation.varargs def clusterBy(colNames: String*): this.type = { - sinkBuilder.clearClusteringColumnNames() - sinkBuilder.addAllClusteringColumnNames(colNames.asJava) + this.clusteringColumns = Option(colNames) + validatePartitioningAndClustering() this } /** @inheritdoc */ def option(key: String, value: String): this.type = { - sinkBuilder.putOptions(key, value) + this.extraOptions += (key -> value) this } /** @inheritdoc */ def options(options: scala.collection.Map[String, String]): this.type = { - this.options(options.asJava) + this.extraOptions ++= options this } /** @inheritdoc */ def options(options: java.util.Map[String, String]): this.type = { - sinkBuilder.putAllOptions(options) + this.options(options.asScala) this } + /** @inheritdoc */ + def start(path: String): StreamingQuery = { + if (!ds.sparkSession.sessionState.conf.legacyPathOptionBehavior && + extraOptions.contains("path")) { + throw QueryCompilationErrors.setPathOptionAndCallWithPathParameterError("start") + } + startInternal(Some(path)) + } + + /** @inheritdoc */ + @throws[TimeoutException] + def start(): StreamingQuery = startInternal(None) + + /** @inheritdoc */ + @Evolving + @throws[TimeoutException] + def toTable(tableName: String): StreamingQuery = { + + import ds.sparkSession.sessionState.analyzer.CatalogAndIdentifier + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + val parser = ds.sparkSession.sessionState.sqlParser + val originalMultipartIdentifier = parser.parseMultipartIdentifier(tableName) + val CatalogAndIdentifier(catalog, identifier) = originalMultipartIdentifier + + // Currently we don't create a logical streaming writer node in logical plan, so cannot rely + // on analyzer to resolve it. Directly lookup only for temp view to provide clearer message. 
+ // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. + if (ds.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) { + throw QueryCompilationErrors.tempViewNotSupportStreamingWriteError(tableName) + } + + if (!catalog.asTableCatalog.tableExists(identifier)) { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + val properties = normalizedClusteringCols.map { cols => + Map( + DataSourceUtils.CLUSTERING_COLUMNS_KEY -> DataSourceUtils.encodePartitioningColumns(cols)) + }.getOrElse(Map.empty) + val partitioningOrClusteringTransform = normalizedClusteringCols.map { colNames => + Array(ClusterByTransform(colNames.map(col => FieldReference(col)))).toImmutableArraySeq + }.getOrElse(partitioningColumns.getOrElse(Nil).asTransforms.toImmutableArraySeq) + + /** + * Note, currently the new table creation by this API doesn't fully cover the V2 table. + * TODO (SPARK-33638): Full support of v2 table creation + */ + val tableSpec = UnresolvedTableSpec( + properties, + Some(source), + OptionList(Seq.empty), + extraOptions.get("path"), + None, + None, + None, + external = false) + val cmd = CreateTable( + UnresolvedIdentifier(originalMultipartIdentifier), + ds.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)), + partitioningOrClusteringTransform, + tableSpec, + ignoreIfExists = false) + Dataset.ofRows(ds.sparkSession, cmd) + } + + val tableInstance = catalog.asTableCatalog.loadTable(identifier) + + def writeToV1Table(table: CatalogTable): StreamingQuery = { + if (table.tableType == CatalogTableType.VIEW) { + throw QueryCompilationErrors.streamingIntoViewNotSupportedError(tableName) + } + require(table.provider.isDefined) + if (source != table.provider.get) { + throw QueryCompilationErrors.inputSourceDiffersFromDataSourceProviderError( + source, tableName, table) + } + format(table.provider.get).startInternal( + Some(new Path(table.location).toString), catalogTable = Some(table)) + } + + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ + tableInstance match { + case t: SupportsWrite if t.supports(STREAMING_WRITE) => + startQuery(t, extraOptions, catalogAndIdent = Some(catalog.asTableCatalog, identifier)) + case t: V2TableWithV1Fallback => + writeToV1Table(t.v1Table) + case t: V1Table => + writeToV1Table(t.v1Table) + case t => throw QueryCompilationErrors.tableNotSupportStreamingWriteError(tableName, t) + } + } + + private def startInternal( + path: Option[String], + catalogTable: Option[CatalogTable] = None): StreamingQuery = { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { + throw QueryCompilationErrors.cannotOperateOnHiveDataSourceFilesError("write") + } + + if (source == DataStreamWriter.SOURCE_NAME_MEMORY) { + assertNotPartitioned(DataStreamWriter.SOURCE_NAME_MEMORY) + if (extraOptions.get("queryName").isEmpty) { + throw QueryCompilationErrors.queryNameNotSpecifiedForMemorySinkError() + } + val sink = new MemorySink() + val resultDf = Dataset.ofRows(ds.sparkSession, + MemoryPlan(sink, DataTypeUtils.toAttributes(ds.schema))) + val recoverFromCheckpoint = outputMode == OutputMode.Complete() + val query = startQuery(sink, extraOptions, recoverFromCheckpoint = recoverFromCheckpoint, + catalogTable = catalogTable) + resultDf.createOrReplaceTempView(query.name) + query + } else if (source == DataStreamWriter.SOURCE_NAME_FOREACH) { + assertNotPartitioned(DataStreamWriter.SOURCE_NAME_FOREACH) + val sink = ForeachWriterTable[Any](foreachWriter, foreachWriterEncoder) + 
startQuery(sink, extraOptions, catalogTable = catalogTable) + } else if (source == DataStreamWriter.SOURCE_NAME_FOREACH_BATCH) { + assertNotPartitioned(DataStreamWriter.SOURCE_NAME_FOREACH_BATCH) + if (trigger.isInstanceOf[ContinuousTrigger]) { + throw QueryCompilationErrors.sourceNotSupportedWithContinuousTriggerError(source) + } + val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) + startQuery(sink, extraOptions, catalogTable = catalogTable) + } else { + val cls = DataSource.lookupDataSource(source, ds.sparkSession.sessionState.conf) + val disabledSources = + Utils.stringToSeq(ds.sparkSession.sessionState.conf.disabledV2StreamingWriters) + val useV1Source = disabledSources.contains(cls.getCanonicalName) || + // file source v2 does not support streaming yet. + classOf[FileDataSourceV2].isAssignableFrom(cls) + + val optionsWithPath = if (path.isEmpty) { + extraOptions + } else { + extraOptions + ("path" -> path.get) + } + + val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) { + val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + source = provider, conf = ds.sparkSession.sessionState.conf) + val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++ + optionsWithPath.originalMap + val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) + // If the source accepts external table metadata, here we pass the schema of input query + // to `getTable`. This is for avoiding schema inference, which can be very expensive. + // If the query schema is not compatible with the existing data, the behavior is undefined. + val outputSchema = if (provider.supportsExternalMetadata()) { + Some(ds.schema) + } else { + None + } + provider match { + case p: PythonDataSourceV2 => p.setShortName(source) + case _ => + } + val table = DataSourceV2Utils.getTableFromProvider( + provider, dsOptions, userSpecifiedSchema = outputSchema) + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ + table match { + case table: SupportsWrite if table.supports(STREAMING_WRITE) => + table + case _ => createV1Sink(optionsWithPath) + } + } else { + createV1Sink(optionsWithPath) + } + + startQuery(sink, optionsWithPath, catalogTable = catalogTable) + } + } + + private def startQuery( + sink: Table, + newOptions: CaseInsensitiveMap[String], + recoverFromCheckpoint: Boolean = true, + catalogAndIdent: Option[(TableCatalog, Identifier)] = None, + catalogTable: Option[CatalogTable] = None): StreamingQuery = { + val useTempCheckpointLocation = DataStreamWriter.SOURCES_ALLOW_ONE_TIME_QUERY.contains(source) + + ds.sparkSession.sessionState.streamingQueryManager.startQuery( + newOptions.get("queryName"), + newOptions.get("checkpointLocation"), + ds, + newOptions.originalMap, + sink, + outputMode, + useTempCheckpointLocation = useTempCheckpointLocation, + recoverFromCheckpointLocation = recoverFromCheckpoint, + trigger = trigger, + catalogAndIdent = catalogAndIdent, + catalogTable = catalogTable) + } + + private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = { + // Do not allow the user to specify clustering columns in the options. Ignoring this option is + // consistent with the behavior of DataFrameWriter on non Path-based tables and with the + // behavior of DataStreamWriter on partitioning columns specified in options. 
+ val optionsWithoutClusteringKey = + optionsWithPath.originalMap - DataSourceUtils.CLUSTERING_COLUMNS_KEY + + val optionsWithClusteringColumns = normalizedClusteringCols match { + case Some(cols) => optionsWithoutClusteringKey + ( + DataSourceUtils.CLUSTERING_COLUMNS_KEY -> + DataSourceUtils.encodePartitioningColumns(cols)) + case None => optionsWithoutClusteringKey + } + val ds = DataSource( + this.ds.sparkSession, + className = source, + options = optionsWithClusteringColumns, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + /** @inheritdoc */ def foreach(writer: ForeachWriter[T]): this.type = { - val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.agnosticEncoder)) - val scalaWriterBuilder = proto.ScalarScalaUDF - .newBuilder() - .setPayload(ByteString.copyFrom(serialized)) - sinkBuilder.getForeachWriterBuilder.setScalaFunction(scalaWriterBuilder) + foreachImplementation(writer.asInstanceOf[ForeachWriter[Any]]) + } + + private[sql] def foreachImplementation(writer: ForeachWriter[Any], + encoder: Option[ExpressionEncoder[Any]] = None): this.type = { + this.source = DataStreamWriter.SOURCE_NAME_FOREACH + this.foreachWriter = if (writer != null) { + ds.sparkSession.sparkContext.clean(writer) + } else { + throw new IllegalArgumentException("foreach writer cannot be null") + } + encoder.foreach(e => this.foreachWriterEncoder = e) this } /** @inheritdoc */ @Evolving def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { - // SPARK-50661: the client should sent the encoder for the input dataset together with the - // function to the server. - val serializedFn = SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder)) - sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder - .setPayload(ByteString.copyFrom(serializedFn)) - .setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused. - .setNullable(true) // Unused. + this.source = DataStreamWriter.SOURCE_NAME_FOREACH_BATCH + if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") + this.foreachBatchWriter = function this } - /** @inheritdoc */ - def start(path: String): StreamingQuery = { - sinkBuilder.setPath(path) - start() + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => + cols.map(normalize(_, "Partition")) } - /** @inheritdoc */ - @throws[TimeoutException] - def start(): StreamingQuery = { - val startCmd = Command - .newBuilder() - .setWriteStreamOperationStart(sinkBuilder.build()) - .build() - - val resp = ds.sparkSession.execute(startCmd).head - if (resp.getWriteStreamOperationStartResult.hasQueryStartedEventJson) { - val event = QueryStartedEvent.fromJson( - resp.getWriteStreamOperationStartResult.getQueryStartedEventJson) - ds.sparkSession.streams.streamingQueryListenerBus.postToAll(event) + private def normalizedClusteringCols: Option[Seq[String]] = clusteringColumns.map { cols => + cols.map(normalize(_, "Clustering")) + } + + /** + * The given column name may not be equal to any of the existing column names if we were in + * case-insensitive context. Normalize the given column name to the real one so that we don't + * need to care about case sensitivity afterwards. 
+ */ + private def normalize(columnName: String, columnType: String): String = { + val validColumnNames = ds.logicalPlan.output.map(_.name) + validColumnNames.find(ds.sparkSession.sessionState.analyzer.resolver(_, columnName)) + .getOrElse(throw QueryCompilationErrors.columnNotFoundInExistingColumnsError( + columnType, columnName, validColumnNames)) + } + + private def assertNotPartitioned(operation: String): Unit = { + if (partitioningColumns.isDefined) { + throw QueryCompilationErrors.operationNotSupportPartitioningError(operation) } - RemoteStreamingQuery.fromStartCommandResponse(ds.sparkSession, resp) } - /** @inheritdoc */ - @Evolving - @throws[TimeoutException] - def toTable(tableName: String): StreamingQuery = { - sinkBuilder.setTableName(tableName) - start() + // Validate that partitionBy isn't used with clusterBy. + private def validatePartitioningAndClustering(): Unit = { + if (clusteringColumns.nonEmpty && partitioningColumns.nonEmpty) { + throw QueryCompilationErrors.clusterByWithPartitionedBy() + } } /////////////////////////////////////////////////////////////////////////////////////// @@ -194,7 +413,39 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataSt override def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): this.type = super.foreachBatch(function) - private val sinkBuilder = WriteStreamOperationStart - .newBuilder() - .setInput(ds.plan.getRoot) + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = ds.sparkSession.sessionState.conf.defaultDataSourceName + + private var outputMode: OutputMode = OutputMode.Append + + private var trigger: Trigger = Trigger.ProcessingTime(0L) + + private var extraOptions = CaseInsensitiveMap[String](Map.empty) + + private var foreachWriter: ForeachWriter[Any] = _ + + private var foreachWriterEncoder: ExpressionEncoder[Any] = + ds.exprEnc.asInstanceOf[ExpressionEncoder[Any]] + + private var foreachBatchWriter: (Dataset[T], Long) => Unit = _ + + private var partitioningColumns: Option[Seq[String]] = None + + private var clusteringColumns: Option[Seq[String]] = None +} + +object DataStreamWriter { + val SOURCE_NAME_MEMORY: String = "memory" + val SOURCE_NAME_FOREACH: String = "foreach" + val SOURCE_NAME_FOREACH_BATCH: String = "foreachBatch" + val SOURCE_NAME_CONSOLE: String = "console" + val SOURCE_NAME_TABLE: String = "table" + val SOURCE_NAME_NOOP: String = "noop" + + // these writer sources are also used for one-time query, hence allow temp checkpoint location + val SOURCES_ALLOW_ONE_TIME_QUERY: Seq[String] = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH, + SOURCE_NAME_FOREACH_BATCH, SOURCE_NAME_CONSOLE, SOURCE_NAME_NOOP) } From 8d918de4f09b7106515394ee6e84b0a34926507c Mon Sep 17 00:00:00 2001 From: Haiyang Sun <75672016+haiyangsun-db@users.noreply.github.com> Date: Fri, 27 Dec 2024 17:00:07 +0100 Subject: [PATCH 3/4] format / lint --- .../org/apache/spark/sql/streaming/DataStreamWriter.scala | 3 ++- .../spark/sql/connect/planner/SparkConnectPlanner.scala | 3 +-- .../sql/connect/planner/StreamingForeachBatchHelper.scala | 5 ++--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 
ee442b571c820..934453ab25479 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -137,7 +137,8 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataSt def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { // SPARK-50661: the client should sent the encoder for the input dataset together with the // function to the server. - val serializedFn = SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder)) + val serializedFn = + SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder)) sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder .setPayload(ByteString.copyFrom(serializedFn)) .setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused. diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 7975f3170837d..d6ade1ac91264 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2959,8 +2959,7 @@ class SparkConnectPlanner( case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION => StreamingForeachBatchHelper.scalaForeachBatchWrapper( writeOp.getForeachBatch.getScalaFunction.getPayload.toByteArray, - sessionHolder - ) + sessionHolder) case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET => throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index cc96eb39978f6..ab6bed7152c09 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -93,9 +93,8 @@ object StreamingForeachBatchHelper extends Logging { def scalaForeachBatchWrapper( payloadBytes: Array[Byte], sessionHolder: SessionHolder): ForeachBatchFnType = { - val foreachBatchPkt = Utils.deserialize[ForeachWriterPacket]( - payloadBytes, - Utils.getContextOrSparkClassLoader) + val foreachBatchPkt = + Utils.deserialize[ForeachWriterPacket](payloadBytes, Utils.getContextOrSparkClassLoader) val fn = foreachBatchPkt.foreachWriter.asInstanceOf[(Dataset[Any], Long) => Unit] val encoder = foreachBatchPkt.datasetEncoder.asInstanceOf[AgnosticEncoder[Any]] // TODO(SPARK-44462): Set up Spark Connect session. From 2fbd4e996630c8b4dab1adcbe6cd25ca50d51baf Mon Sep 17 00:00:00 2001 From: Haiyang Sun <75672016+haiyangsun-db@users.noreply.github.com> Date: Sat, 28 Dec 2024 00:08:55 +0100 Subject: [PATCH 4/4] Address comments. 
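For context, a minimal client-side sketch (not part of the patch) of the typed
foreachBatch path exercised by the new test. The sc://localhost endpoint, the
wrapping object, and the 30s wait are assumptions for illustration; TestClass
mirrors the case class already defined in ClientStreamingQuerySuite, and in a
real deployment the jar containing it must be visible to the server (e.g. via
SparkSession.addArtifact):

    import org.apache.spark.sql.{Dataset, SparkSession}

    case class TestClass(value: Int)

    object TypedForeachBatchExample {
      def main(args: Array[String]): Unit = {
        // Connect to a Spark Connect server instead of creating a local session.
        val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()
        import spark.implicits._

        val query = spark.readStream
          .format("rate")
          .option("rowsPerSecond", "10")
          .load()
          .selectExpr("CAST(value AS INT)")
          .as[TestClass]
          .writeStream
          .foreachBatch((ds: Dataset[TestClass], batchId: Long) => {
            // Runs on the server: the batch arrives as Dataset[TestClass], rebuilt from the
            // DataFrame with the encoder that was serialized alongside the function.
            println(s"batch $batchId: sum = ${ds.collect().map(_.value).sum}")
          })
          .start()

        query.awaitTermination(30000) // wait up to 30s, then stop the query
        query.stop()
      }
    }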
--- .../sql/streaming/DataStreamWriter.scala | 2 +- .../streaming/ClientStreamingQuerySuite.scala | 28 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 934453ab25479..b2c4fcf64e70f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -135,7 +135,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataSt /** @inheritdoc */ @Evolving def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { - // SPARK-50661: the client should sent the encoder for the input dataset together with the + // SPARK-50661: the client should send the encoder for the input dataset together with the // function to the server. val serializedFn = SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder)) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index 7fdf89b56df44..199a1507a3b19 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -566,7 +566,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L } } - test("foreachBatch") { + test("foreachBatch with DataFrame") { // Starts a streaming query with a foreachBatch function, which writes batchId and row count // to a temp view. The test verifies that the view is populated with data. @@ -651,6 +651,32 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L } } + test("foreachBatch with Dataset[TestClass]") { + val session: SparkSession = spark + import session.implicits._ + val viewName = "test_view" + val tableName = s"global_temp.$viewName" + + val df = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .load() + + val q = df + .selectExpr("CAST(value AS INT)") + .as[TestClass] + .writeStream + .foreachBatch((ds: Dataset[TestClass], batchId: Long) => { + val count = ds.collect().map(_.value).sum + }) + .start() + eventually(timeout(30.seconds)) { + assert(q.isActive) + assert(q.exception.isEmpty) + } + q.stop() + } + abstract class EventCollector extends StreamingQueryListener { protected def tablePostfix: String