From 5c7d940409ffb5f1d60ae015a6c54978f79d4d0d Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 27 Dec 2024 14:43:40 -0800 Subject: [PATCH] all sorts of stuff fails --- .../state/metadata/StateMetadataSource.scala | 2 +- .../TransformWithStateInPandasExec.scala | 3 +- .../FlatMapGroupsWithStateExec.scala | 4 +- .../streaming/IncrementalExecution.scala | 66 ++++++++--- .../StateStoreColumnFamilySchemaUtils.scala | 20 ++-- .../StreamingSymmetricHashJoinExec.scala | 7 +- .../streaming/TransformWithStateExec.scala | 5 +- .../TransformWithStateVariableUtils.scala | 10 +- .../state/OperatorStateMetadata.scala | 5 +- .../streaming/state/RocksDBStateEncoder.scala | 108 ++++++++++++++++- .../state/RocksDBStateStoreProvider.scala | 23 ++-- .../streaming/state/SchemaHelper.scala | 10 ++ .../StateSchemaCompatibilityChecker.scala | 110 ++++++++++++------ .../streaming/statefulOperators.scala | 47 ++++---- .../execution/streaming/streamingLimits.scala | 4 +- .../state/OperatorStateMetadataSuite.scala | 2 +- .../streaming/state/RocksDBSuite.scala | 28 ++--- ...StateSchemaCompatibilityCheckerSuite.scala | 18 +-- .../streaming/TransformWithStateSuite.scala | 12 +- .../TransformWithValueStateTTLSuite.scala | 20 ++-- 20 files changed, 339 insertions(+), 165 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 334076f7aa882..614f4a95f2e9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -270,7 +270,7 @@ class StateMetadataPartitionReader( operatorStateMetadata.version, v2.operatorPropertiesJson, -1, // numColsPrefixKey is not available in OperatorStateMetadataV2 - Some(stateStoreMetadata.stateSchemaFilePaths(stateStoreMetadata.stateSchemaId)) + Some(stateStoreMetadata.stateSchemaFilePaths.last) ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala index f302e70ac3f92..2e4572ee4c2d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala @@ -149,7 +149,8 @@ case class TransformWithStateInPandasExec( initialStateGroupingAttrs.map(SortOrder(_, Ascending))) override def operatorStateMetadata( - stateSchemaPaths: List[String]): OperatorStateMetadata = { + stateSchemaPaths: List[List[String]] + ): OperatorStateMetadata = { getOperatorStateMetadata(stateSchemaPaths, getStateInfo, shortName, timeMode, outputMode) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 34b14d5bf3b05..58d2a19989cbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -196,8 +196,8 @@ trait FlatMapGroupsWithStateExecBase hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): 
List[StateSchemaValidationResult] = { - val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - groupingAttributes.toStructType, stateManager.stateSchema)) + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, + groupingAttributes.toStructType, 0, stateManager.stateSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 161abc8b3544d..b2555ec073e57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -246,9 +246,9 @@ class IncrementalExecution( } else { None } - val stateSchemaMapping = ssw.stateSchemaMapping(schemaValidationResult, + val stateSchemaList = ssw.stateSchemaList(schemaValidationResult, oldMetadata) - val metadata = ssw.operatorStateMetadata(stateSchemaMapping) + val metadata = ssw.operatorStateMetadata(stateSchemaList) oldMetadata match { case Some(oldMetadata) => ssw.validateNewMetadata(oldMetadata, metadata) case None => @@ -260,7 +260,7 @@ class IncrementalExecution( Some(currentBatchId)) metadataWriter.write(metadata) if (ssw.supportsSchemaEvolution) { - val stateSchemaMetadata = createStateSchemaMetadata(stateSchemaMapping.head) + val stateSchemaMetadata = createStateSchemaMetadata(stateSchemaList.head) stateSchemaMetadatas.put(ssw.getStateInfo.operatorId, stateSchemaMetadata) // Create new instance with copied fields but updated stateInfo ssw match { @@ -279,22 +279,54 @@ class IncrementalExecution( } private def createStateSchemaMetadata( - stateSchemaMapping: Map[Short, String] + stateSchemaFiles: List[String] ): StateSchemaBroadcast = { val fm = CheckpointFileManager.create(new Path(checkpointLocation), hadoopConf) - val stateSchemas = stateSchemaMapping.flatMap { case (stateSchemaId, stateSchemaPath) => - val inStream = fm.open(new Path(stateSchemaPath)) - StateSchemaCompatibilityChecker.readSchemaFile(inStream).map { schema => - StateSchemaMetadataKey( - schema.colFamilyName, stateSchemaId) -> - StateSchemaMetadataValue( - schema.valueSchema, SchemaConverters.toAvroType(schema.valueSchema)) - }.toMap - } - StateSchemaBroadcast( - sparkSession.sparkContext.broadcast( - StateSchemaMetadata(stateSchemas.keys.map(_.schemaId).max, stateSchemas) - )) + + // Build up our map of schema metadata + val activeSchemas = stateSchemaFiles.zipWithIndex.foldLeft( + Map.empty[StateSchemaMetadataKey, StateSchemaMetadataValue]) { + case (schemas, (stateSchemaFile, schemaIndex)) => + val fsDataInputStream = fm.open(new Path(stateSchemaFile)) + val colFamilySchemas = StateSchemaCompatibilityChecker.readSchemaFile(fsDataInputStream) + + // For each column family, create metadata entries for both key and value schemas + val schemaEntries = colFamilySchemas.flatMap { colFamilySchema => + // Create key schema metadata + val keyAvroSchema = SchemaConverters.toAvroType(colFamilySchema.keySchema) + val keyEntry = StateSchemaMetadataKey( + colFamilySchema.colFamilyName, + colFamilySchema.keySchemaId, + isKey = true + ) -> StateSchemaMetadataValue( + colFamilySchema.keySchema, + keyAvroSchema + ) + + // Create value schema metadata + val valueAvroSchema = 
SchemaConverters.toAvroType(colFamilySchema.valueSchema) + val valueEntry = StateSchemaMetadataKey( + colFamilySchema.colFamilyName, + colFamilySchema.valueSchemaId, + isKey = false + ) -> StateSchemaMetadataValue( + colFamilySchema.valueSchema, + valueAvroSchema + ) + + Seq(keyEntry, valueEntry) + } + + // Add new entries to our accumulated map + schemas ++ schemaEntries.toMap + } + + // Create the final metadata and wrap it in a broadcast + val metadata = StateSchemaMetadata( + activeSchemas = activeSchemas + ) + + StateSchemaBroadcast(sparkSession.sparkContext.broadcast(metadata)) } object StateOpIdRule extends SparkPlanPartialRule { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index eb019289f9a5e..858c74a9eae8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -69,8 +69,8 @@ object StateStoreColumnFamilySchemaUtils { valEncoder: Encoder[T], hasTtl: Boolean): StateStoreColFamilySchema = { StateStoreColFamilySchema( - stateName, - keyEncoder.schema, + stateName, 0, + keyEncoder.schema, 0, getValueSchemaWithTTL(valEncoder.schema, hasTtl), Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema))) } @@ -81,8 +81,8 @@ object StateStoreColumnFamilySchemaUtils { valEncoder: Encoder[T], hasTtl: Boolean): StateStoreColFamilySchema = { StateStoreColFamilySchema( - stateName, - keyEncoder.schema, + stateName, 0, + keyEncoder.schema, 0, getValueSchemaWithTTL(valEncoder.schema, hasTtl), Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema))) } @@ -95,8 +95,8 @@ object StateStoreColumnFamilySchemaUtils { hasTtl: Boolean): StateStoreColFamilySchema = { val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema) StateStoreColFamilySchema( - stateName, - compositeKeySchema, + stateName, 0, + compositeKeySchema, 0, getValueSchemaWithTTL(valEncoder.schema, hasTtl), Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)), Some(userKeyEnc.schema)) @@ -107,8 +107,8 @@ object StateStoreColumnFamilySchemaUtils { keySchema: StructType, valSchema: StructType): StateStoreColFamilySchema = { StateStoreColFamilySchema( - stateName, - keySchema, + stateName, 0, + keySchema, 0, valSchema, Some(PrefixKeyScanStateEncoderSpec(keySchema, 1))) } @@ -118,8 +118,8 @@ object StateStoreColumnFamilySchemaUtils { keySchema: StructType, valSchema: StructType): StateStoreColFamilySchema = { StateStoreColFamilySchema( - stateName, - keySchema, + stateName, 0, + keySchema, 0, valSchema, Some(RangeKeyScanStateEncoderSpec(keySchema, Seq(0)))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index cfde2a308f7f5..6852cb85eb832 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -228,7 +228,8 @@ case class StreamingSymmetricHashJoinExec( SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) override def operatorStateMetadata( - stateSchemaPaths: List[Map[Short, String]] = List.empty): 
OperatorStateMetadata = { + stateSchemaPaths: List[List[String]] = List.empty + ): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = @@ -263,8 +264,8 @@ case class StreamingSymmetricHashJoinExec( // validate and maybe evolve schema for all state stores across both sides of the join result.map { case (stateStoreName, (keySchema, valueSchema)) => - val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - keySchema, valueSchema)) + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, + keySchema, 0, valueSchema)) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion, storeName = stateStoreName) }.toList diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 668159ad08dea..fc74278fdfce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -462,12 +462,13 @@ case class TransformWithStateExec( stateSchemaVersion: Int): List[StateSchemaValidationResult] = { val info = getStateInfo validateAndWriteStateSchema(hadoopConf, batchId, stateSchemaVersion, - info, session, operatorStateMetadataVersion) + info, session, operatorStateMetadataVersion, conf.stateStoreEncodingFormat) } /** Metadata of this stateful operator and its states stores. */ override def operatorStateMetadata( - stateSchemaPaths: List[String]): OperatorStateMetadata = { + stateSchemaPaths: List[List[String]] + ): OperatorStateMetadata = { val info = getStateInfo getOperatorStateMetadata(stateSchemaPaths, info, shortName, timeMode, outputMode) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala index 34dddeab59d29..e0bb5cfeb7102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala @@ -176,7 +176,7 @@ trait TransformWithStateMetadataUtils extends Logging { def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] def getOperatorStateMetadata( - stateSchemaPaths: List[String], + stateSchemaPaths: List[List[String]], info: StatefulOperatorStateInfo, shortName: String, timeMode: TimeMode, @@ -201,7 +201,8 @@ trait TransformWithStateMetadataUtils extends Logging { stateSchemaVersion: Int, info: StatefulOperatorStateInfo, session: SparkSession, - operatorStateMetadataVersion: Int = 2): List[StateSchemaValidationResult] = { + operatorStateMetadataVersion: Int = 2, + stateStoreEncodingFormat: String = "unsaferow"): List[StateSchemaValidationResult] = { assert(stateSchemaVersion >= 3) val newSchemas = getColFamilySchemas() val stateSchemaDir = stateSchemaDirPath(info) @@ -223,7 +224,7 @@ trait TransformWithStateMetadataUtils extends Logging { case Some(metadata) => metadata match { case v2: OperatorStateMetadataV2 => - Some(new Path(v2.stateStoreInfo.head.stateSchemaFilePath)) + Some(new Path(v2.stateStoreInfo.head.stateSchemaFilePaths.last)) 
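+            // The schema file list is append-only (oldest to newest), so the last
+            // entry above is the most recently written schema file for this store.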
case _ => None } case None => None @@ -234,7 +235,8 @@ trait TransformWithStateMetadataUtils extends Logging { newSchemas.values.toList, session.sessionState, stateSchemaVersion, storeName = StateStoreId.DEFAULT_STORE_NAME, oldSchemaFilePath = oldStateSchemaFilePath, - newSchemaFilePath = Some(newStateSchemaFilePath))) + newSchemaFilePath = Some(newStateSchemaFilePath), + schemaEvolutionEnabled = stateStoreEncodingFormat == "avro")) } def validateNewMetadataForTWS( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index c69ea7cf00b3e..99bc7ede2227f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -51,8 +51,7 @@ case class StateStoreMetadataV2( storeName: String, numColsPrefixKey: Int, numPartitions: Int, - stateSchemaId: Short, - stateSchemaFilePaths: Map[Short, String]) + stateSchemaFilePaths: List[String]) extends StateStoreMetadata with Serializable object StateStoreMetadataV2 { @@ -470,7 +469,7 @@ class OperatorStateMetadataV2FileManager( val earliestBatchToKeep = latestMetadata match { case Some(OperatorStateMetadataV2(_, stateStoreInfo, _)) => val ssInfo = stateStoreInfo.head - val schemaFilePath = ssInfo.stateSchemaFilePaths.minBy(_._1)._2 + val schemaFilePath = ssInfo.stateSchemaFilePaths.head new Path(schemaFilePath).getName.split("_").head.toLong case _ => 0 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 46b4ad205c2fd..d290d0e30faa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -26,6 +26,7 @@ import org.apache.avro.Schema import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} import org.apache.avro.io.{DecoderFactory, EncoderFactory} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.InternalRow @@ -51,6 +52,78 @@ sealed trait RocksDBValueStateEncoder { def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow] } +/** + * Broadcasts schema metadata information for stateful operators in a streaming query. + * + * This class provides a way to distribute schema evolution information to all executors + * via Spark's broadcast mechanism. Each stateful operator in a streaming query maintains + * its own instance of this class to track schema versions and evolution. 
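+ *
+ * A rough usage sketch (illustrative only; names such as "countState" and the
+ * surrounding encoder wiring are placeholders):
+ * {{{
+ *   // find the newest value-schema id registered for a column family
+ *   val valueSchemaId = stateSchemaBroadcast.getCurrentStateSchemaId("countState", isKey = false)
+ *   // resolve the SQL and Avro schemas that were written with that id
+ *   val schemas = stateSchemaBroadcast.getSchemaMetadataValue(
+ *     StateSchemaMetadataKey("countState", valueSchemaId, isKey = false))
+ * }}}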
+ *
+ * @param broadcast Spark broadcast variable containing the schema metadata
+ */
+case class StateSchemaBroadcast(
+    broadcast: Broadcast[StateSchemaMetadata]
+) extends Logging {
+
+  /**
+   * Retrieves the schema information for a given column family and schema version.
+   *
+   * @param key A combination of column family name and schema ID
+   * @return The corresponding schema metadata value containing both SQL and Avro schemas
+   */
+  def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue = {
+    broadcast.value.activeSchemas(key)
+  }
+
+  def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short = {
+    broadcast.value.activeSchemas
+      .keys
+      .filter(key =>
+        key.colFamilyName == colFamilyName &&
+          key.isKey == isKey
+      )
+      .map(_.schemaId).max
+  }
+}
+
+/**
+ * Contains schema evolution metadata for a stateful operator.
+ *
+ * @param activeSchemas Map of all active schema versions, keyed by column family and schema ID.
+ *                      This includes both the current schema and any previous schemas that
+ *                      may still exist in the state store.
+ */
+case class StateSchemaMetadata(
+    activeSchemas: Map[StateSchemaMetadataKey, StateSchemaMetadataValue]
+)
+
+/**
+ * Composite key for looking up schema metadata, combining column family, schema version,
+ * and whether the schema describes the key or the value.
+ *
+ * @param colFamilyName Name of the RocksDB column family this schema applies to
+ * @param schemaId Version identifier for this schema
+ * @param isKey Whether this entry refers to the key schema (true) or the value schema (false)
+ */
+case class StateSchemaMetadataKey(
+    colFamilyName: String,
+    schemaId: Short,
+    isKey: Boolean
+)
+
+/**
+ * Contains both SQL and Avro representations of a schema version.
+ *
+ * The SQL schema represents the logical structure while the Avro schema is used
+ * for evolution compatibility checking and serialization.
+ *
+ * @param sqlSchema The Spark SQL schema definition
+ * @param avroSchema The equivalent Avro schema used for compatibility checking
+ */
+case class StateSchemaMetadataValue(
+    sqlSchema: StructType,
+    avroSchema: Schema
+)
+
 /**
  * Contains schema version information for both key and value schemas in a state store.
* This information is used to support schema evolution, allowing state schemas to be @@ -297,7 +370,8 @@ abstract class RocksDBDataEncoder( class UnsafeRowDataEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, valueSchema: StructType, - stateSchemaInfo: Option[StateSchemaInfo] + stateSchemaBroadcast: Option[StateSchemaBroadcast], + columnFamilyInfo: Option[ColumnFamilyInfo] ) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) { override def supportsSchemaEvolution: Boolean = false @@ -528,9 +602,22 @@ class UnsafeRowDataEncoder( class AvroStateEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, valueSchema: StructType, - stateSchemaInfo: Option[StateSchemaInfo] + stateSchemaBroadcast: Option[StateSchemaBroadcast], + columnFamilyInfo: Option[ColumnFamilyInfo] ) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) with Logging { + + // schema information + private val currentKeySchemaId: Short = getStateSchemaBroadcast.getCurrentStateSchemaId( + getColFamilyName, + isKey = true + ) + + private val currentValSchemaId: Short = getStateSchemaBroadcast.getCurrentStateSchemaId( + getColFamilyName, + isKey = false + ) + private val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema) // Avro schema used by the avro encoders private lazy val keyAvroType: Schema = SchemaConverters.toAvroType(keySchema) @@ -546,10 +633,11 @@ class AvroStateEncoder( StructType(keySchema.take (numColsPrefixKey)) case _ => throw unsupportedOperationForKeyStateEncoder("prefixKeySchema") } + private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema) private lazy val prefixKeyProj = UnsafeProjection.create(prefixKeySchema) - // Range Key schema nd projection definitions used by the Avro Serializers and + // Range Key schema and projection definitions used by the Avro Serializers and // Deserializers private lazy val rangeScanKeyFieldsWithOrdinal = keyStateEncoderSpec match { case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) => @@ -656,6 +744,14 @@ class AvroStateEncoder( override def supportsSchemaEvolution: Boolean = true + private def getColFamilyName: String = { + columnFamilyInfo.get.colFamilyName + } + + private def getStateSchemaBroadcast: StateSchemaBroadcast = { + stateSchemaBroadcast.get + } + /** * This method takes an UnsafeRow, and serializes to a byte array using Avro encoding. 
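+   * Note that only the Avro payload is produced here; callers such as `encodeKey` and
+   * `encodeValue` prepend the current state schema id (via `encodeWithStateSchemaId`)
+   * so that the writer schema can be looked up from the broadcast metadata on read.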
*/ @@ -710,7 +806,7 @@ class AvroStateEncoder( encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, keyAvroType, out) // prepend stateSchemaId to the Avro-encoded key portion for NoPrefixKeys encodeWithStateSchemaId( - StateSchemaIdRow(stateSchemaInfo.get.keySchemaId, avroRow)) + StateSchemaIdRow(currentKeySchemaId, avroRow)) case PrefixKeyScanStateEncoderSpec(_, _) => encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, prefixKeyAvroType, out) case _ => throw unsupportedOperationForKeyStateEncoder("encodeKey") @@ -727,7 +823,7 @@ class AvroStateEncoder( } // prepend stateSchemaId to the remaining key portion encodeWithStateSchemaId( - StateSchemaIdRow(stateSchemaInfo.get.keySchemaId, avroRow)) + StateSchemaIdRow(currentKeySchemaId, avroRow)) } /** @@ -872,7 +968,7 @@ class AvroStateEncoder( override def encodeValue(row: UnsafeRow): Array[Byte] = { val avroRow = encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out) // prepend stateSchemaId to the Avro-encoded value portion - encodeWithStateSchemaId(StateSchemaIdRow(stateSchemaInfo.get.valueSchemaId, avroRow)) + encodeWithStateSchemaId(StateSchemaIdRow(currentValSchemaId, avroRow)) } override def decodeKey(bytes: Array[Byte]): UnsafeRow = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 74684335c9f88..1e477a6fbbcb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, TimeUnit} + import scala.util.control.NonFatal + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -30,7 +32,6 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.DEFAULT_SCHEMA_IDS import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{NonFateSharingCache, Utils} @@ -101,8 +102,7 @@ private[sql] class RocksDBStateStoreProvider val valueEncoder = RocksDBStateEncoder.getValueEncoder( dataEncoder, valueSchema, - useMultipleValuesPerKey, - stateSchemaBroadcast.map(_.getCurrentSchemaId) + useMultipleValuesPerKey ) keyValueEncoderMap.putIfAbsent(colFamilyName, (keyEncoder, valueEncoder)) } @@ -418,7 +418,7 @@ private[sql] class RocksDBStateStoreProvider } val dataEncoder = getDataEncoder( - stateStoreEncoding, + "unsaferow", dataEncoderCacheKey, keyStateEncoderSpec, valueSchema, @@ -435,8 +435,7 @@ private[sql] class RocksDBStateStoreProvider val valueEncoder = RocksDBStateEncoder.getValueEncoder( dataEncoder, valueSchema, - useMultipleValuesPerKey, - stateSchemaBroadcast.map(_.getCurrentSchemaId) + useMultipleValuesPerKey ) keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, (keyEncoder, valueEncoder)) } @@ -706,10 +705,18 @@ object RocksDBStateStoreProvider { override def call(): RocksDBDataEncoder = { if (stateStoreEncoding == "avro") { new AvroStateEncoder( - 
keyStateEncoderSpec, valueSchema, stateSchemaBroadcast, Some(DEFAULT_SCHEMA_IDS), columnFamilyInfo) + keyStateEncoderSpec, + valueSchema, + stateSchemaBroadcast, + columnFamilyInfo + ) } else { new UnsafeRowDataEncoder( - keyStateEncoderSpec, valueSchema, stateSchemaBroadcast, None, columnFamilyInfo) + keyStateEncoderSpec, + valueSchema, + stateSchemaBroadcast, + columnFamilyInfo + ) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index f6737307fad13..d67eb40fde2c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -70,7 +70,9 @@ object SchemaHelper { val keySchemaStr = inputStream.readUTF() val valueSchemaStr = inputStream.readUTF() List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + 0, StructType.fromString(keySchemaStr), + 0, StructType.fromString(valueSchemaStr))) } } @@ -83,7 +85,9 @@ object SchemaHelper { val valueSchemaStr = readJsonSchema(inputStream) List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + 0, StructType.fromString(keySchemaStr), + 0, StructType.fromString(valueSchemaStr))) } } @@ -97,7 +101,9 @@ object SchemaHelper { (0 until numEntries).map { _ => // read the col family name and the key and value schema val colFamilyName = inputStream.readUTF() + val keySchemaId = inputStream.readShort() val keySchemaStr = readJsonSchema(inputStream) + val valSchemaId = inputStream.readShort() val valueSchemaStr = readJsonSchema(inputStream) val keySchema = StructType.fromString(keySchemaStr) @@ -111,7 +117,9 @@ object SchemaHelper { val userKeyEncoderSchema = Try(StructType.fromString(userKeyEncoderSchemaStr)).toOption StateStoreColFamilySchema(colFamilyName, + keySchemaId, keySchema, + valSchemaId, StructType.fromString(valueSchemaStr), Some(encoderSpec), userKeyEncoderSchema) @@ -206,7 +214,9 @@ object SchemaHelper { stateStoreColFamilySchema.foreach { colFamilySchema => assert(colFamilySchema.keyStateEncoderSpec.isDefined) outputStream.writeUTF(colFamilySchema.colFamilyName) + outputStream.writeShort(colFamilySchema.keySchemaId) writeJsonSchema(outputStream, colFamilySchema.keySchema.json) + outputStream.writeShort(colFamilySchema.valueSchemaId) writeJsonSchema(outputStream, colFamilySchema.valueSchema.json) writeJsonSchema(outputStream, colFamilySchema.keyStateEncoderSpec.get.json) // write user key encoder schema if provided and empty json otherwise diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 0db0b174cdf84..b6a45d4247c05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -67,7 +67,9 @@ case class AvroEncoder( // Used to represent the schema of a column family in the state store case class StateStoreColFamilySchema( colFamilyName: String, + keySchemaId: Short, keySchema: StructType, + valueSchemaId: Short, valueSchema: StructType, keyStateEncoderSpec: Option[KeyStateEncoderSpec] = None, userKeyEncoderSchema: Option[StructType] = None @@ -156,39 +158,64 @@ class 
StateSchemaCompatibilityChecker( oldSchema: StateStoreColFamilySchema, newSchema: StateStoreColFamilySchema, ignoreValueSchema: Boolean, - schemaEvolutionEnabled: Boolean) : Boolean = { - val (storedKeySchema, storedValueSchema) = (oldSchema.keySchema, - oldSchema.valueSchema) + schemaEvolutionEnabled: Boolean): StateStoreColFamilySchema = { + + val (storedKeySchema, storedValueSchema) = (oldSchema.keySchema, oldSchema.valueSchema) val (keySchema, valueSchema) = (newSchema.keySchema, newSchema.valueSchema) + def incrementSchemaId(id: Short): Short = (id + 1).toShort + + // Initialize with old schema IDs + var resultSchema = newSchema.copy( + keySchemaId = oldSchema.keySchemaId, + valueSchemaId = oldSchema.valueSchemaId + ) + if (storedKeySchema.equals(keySchema) && (ignoreValueSchema || storedValueSchema.equals(valueSchema))) { - // schema is exactly same - false + // Schema is exactly same - return old schema + oldSchema } else if (!ignoreValueSchema && schemaEvolutionEnabled) { - // By this point, we know that old value schema is not equal to new value schema + // Check value schema evolution val oldAvroSchema = SchemaConverters.toAvroType(storedValueSchema) val newAvroSchema = SchemaConverters.toAvroType(valueSchema) val validator = new SchemaValidatorBuilder().canReadStrategy.validateAll() - // This will throw a SchemaValidation exception if the schema has evolved in an - // unacceptable way. validator.validate(newAvroSchema, Iterable(oldAvroSchema).asJava) - // If no exception is thrown, then we know that the schema evolved in an - // acceptable way - true - } else if (!schemasCompatible(storedKeySchema, keySchema)) { - throw StateStoreErrors.stateStoreKeySchemaNotCompatible(storedKeySchema.toString, - keySchema.toString) - } else if (!ignoreValueSchema && !schemasCompatible(storedValueSchema, valueSchema)) { - throw StateStoreErrors.stateStoreValueSchemaNotCompatible(storedValueSchema.toString, - valueSchema.toString) + + // Schema evolved - increment value schema ID + resultSchema.copy(valueSchemaId = incrementSchemaId(oldSchema.valueSchemaId)) } else { - logInfo("Detected schema change which is compatible. Allowing to put rows.") - true + // Check compatibility + if (!schemasCompatible(storedKeySchema, keySchema)) { + throw StateStoreErrors.stateStoreKeySchemaNotCompatible( + storedKeySchema.toString, keySchema.toString) + } + if (!ignoreValueSchema && !schemasCompatible(storedValueSchema, valueSchema)) { + throw StateStoreErrors.stateStoreValueSchemaNotCompatible( + storedValueSchema.toString, valueSchema.toString) + } + + // Schema changed but compatible - increment IDs as needed + val needsKeyUpdate = !storedKeySchema.equals(keySchema) + val needsValueUpdate = !ignoreValueSchema && !storedValueSchema.equals(valueSchema) + + resultSchema.copy( + keySchemaId = if (needsKeyUpdate) { + incrementSchemaId(oldSchema.keySchemaId) + } else { + oldSchema.keySchemaId + }, + valueSchemaId = if (needsValueUpdate) { + incrementSchemaId(oldSchema.valueSchemaId) + } else { + oldSchema.valueSchemaId + } + ) } } + /** * Function to validate the new state store schema and evolve the schema if required. 
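+   *
+   * When a column family's schema changes compatibly, its schema id is incremented and,
+   * for schema format v3, a new schema file is written; unchanged column families keep
+   * their previous ids. Illustrative example (names and schemas are placeholders):
+   * {{{
+   *   // first run: key and value schema ids both start at 0
+   *   StateStoreColFamilySchema("countState", 0, keySchema, 0, valueSchema)
+   *   // after a compatible value schema evolution, only the value schema id advances
+   *   StateStoreColFamilySchema("countState", 0, keySchema, 1, newValueSchema)
+   * }}}
+   *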
* @param newStateSchema - proposed new state store schema by the operator @@ -202,31 +229,44 @@ class StateSchemaCompatibilityChecker( stateSchemaVersion: Int, schemaEvolutionEnabled: Boolean): Boolean = { val existingStateSchemaList = getExistingKeyAndValueSchema() - val newStateSchemaList = newStateSchema if (existingStateSchemaList.isEmpty) { - // write the schema file if it doesn't exist - createSchemaFile(newStateSchemaList.sortBy(_.colFamilyName), stateSchemaVersion) + // Initialize schemas with ID 0 when no existing schema + val initializedSchemas = newStateSchema.map(schema => + schema.copy(keySchemaId = 0, valueSchemaId = 0) + ) + createSchemaFile(initializedSchemas.sortBy(_.colFamilyName), stateSchemaVersion) true } else { - // validate if the new schema is compatible with the existing schema - val existingSchemaMap = existingStateSchemaList.map { schema => + val existingSchemaMap = existingStateSchemaList.map(schema => schema.colFamilyName -> schema - }.toMap - // For each new state variable, we want to compare it to the old state variable - // schema with the same name - val hasEvolvedSchema = newStateSchemaList.exists { newSchema => - existingSchemaMap.get(newSchema.colFamilyName) - .exists(existingSchema => check( - existingSchema, newSchema, ignoreValueSchema, schemaEvolutionEnabled)) + ).toMap + + // Process each new schema and track if any have evolved + val (evolvedSchemas, hasEvolutions) = newStateSchema.foldLeft( + (List.empty[StateStoreColFamilySchema], false)) { + case ((schemas, evolved), newSchema) => + existingSchemaMap.get(newSchema.colFamilyName) match { + case Some(existingSchema) => + val updatedSchema = check( + existingSchema, newSchema, ignoreValueSchema, schemaEvolutionEnabled) + val hasEvolved = !updatedSchema.equals(existingSchema) + (updatedSchema :: schemas, evolved || hasEvolved) + case None => + // New column family - initialize with schema ID 0 + val newSchemaWithIds = newSchema.copy(keySchemaId = 0, valueSchemaId = 0) + (newSchemaWithIds :: schemas, true) + } } + val colFamiliesAddedOrRemoved = - (newStateSchemaList.map(_.colFamilyName).toSet != existingSchemaMap.keySet) - val newSchemaFileWritten = hasEvolvedSchema || colFamiliesAddedOrRemoved + (newStateSchema.map(_.colFamilyName).toSet != existingSchemaMap.keySet) + val newSchemaFileWritten = hasEvolutions || colFamiliesAddedOrRemoved + if (stateSchemaVersion == SCHEMA_FORMAT_V3 && newSchemaFileWritten) { - createSchemaFile(newStateSchemaList.sortBy(_.colFamilyName), stateSchemaVersion) + createSchemaFile(evolvedSchemas.sortBy(_.colFamilyName), stateSchemaVersion) } - // TODO: [SPARK-49535] Write Schema files after schema has changed for StateSchemaV3 + newSchemaFileWritten } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 2e76c3e814a4c..933d239d2d52b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -234,31 +234,28 @@ trait StateStoreWriter } } - def stateSchemaMapping( + def stateSchemaList( stateSchemaValidationResults: List[StateSchemaValidationResult], - oldMetadata: Option[OperatorStateMetadata]): List[Map[Short, String]] = { + oldMetadata: Option[OperatorStateMetadata]): List[List[String]] = { - def getExistingStateInfo(metadata: OperatorStateMetadataV2): (Short, Map[Short, String]) = { + def 
getExistingStateInfo(metadata: OperatorStateMetadataV2): List[String] = { val ssInfo = metadata.stateStoreInfo.head - (ssInfo.stateSchemaId, ssInfo.stateSchemaFilePaths) + ssInfo.stateSchemaFilePaths } val validationResult = stateSchemaValidationResults.head - def nextSchemaId(currentId: Short): Short = (currentId + 1).toShort - oldMetadata match { case Some(v2: OperatorStateMetadataV2) => - val (oldSchemaId, oldSchemaPaths) = getExistingStateInfo(v2) + val oldSchemaPaths = getExistingStateInfo(v2) if (validationResult.evolvedSchema) { - List(oldSchemaPaths + (nextSchemaId(oldSchemaId) -> validationResult.schemaPath)) + List(oldSchemaPaths ++ List(validationResult.schemaPath)) } else { List(oldSchemaPaths) } - case _ => // No previous metadata - start with schema ID 0 - List(Map(0.toShort -> validationResult.schemaPath)) + List(List(validationResult.schemaPath)) } } @@ -350,7 +347,8 @@ trait StateStoreWriter /** Metadata of this stateful operator and its states stores. */ def operatorStateMetadata( - stateSchemaPaths: List[Map[Short, String]] = List.empty): OperatorStateMetadata = { + stateSchemaPaths: List[List[String]] = List.empty + ): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = @@ -614,8 +612,8 @@ case class StateStoreRestoreExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { - val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - keyExpressions.toStructType, stateManager.getStateValueSchema)) + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, + keyExpressions.toStructType, 0, stateManager.getStateValueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } @@ -686,8 +684,8 @@ case class StateStoreSaveExec( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { - val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - keyExpressions.toStructType, stateManager.getStateValueSchema)) + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, + keyExpressions.toStructType, 0, stateManager.getStateValueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } @@ -900,8 +898,8 @@ case class SessionWindowStateStoreRestoreExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { - val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - stateManager.getStateKeySchema, stateManager.getStateValueSchema)) + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, + stateManager.getStateKeySchema, 0, stateManager.getStateValueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } @@ -991,8 +989,8 @@ case class SessionWindowStateStoreSaveExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { - val newStateSchema = 
List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - stateManager.getStateKeySchema, stateManager.getStateValueSchema)) + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, + stateManager.getStateKeySchema, 0, stateManager.getStateValueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } @@ -1102,7 +1100,8 @@ case class SessionWindowStateStoreSaveExec( } override def operatorStateMetadata( - stateSchemaPaths: List[Map[Short, String]] = List.empty): OperatorStateMetadata = { + stateSchemaPaths: List[List[String]] = List.empty + ): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = Array(StateStoreMetadataV1( @@ -1312,8 +1311,8 @@ case class StreamingDeduplicateExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { - val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - keyExpressions.toStructType, schemaForValueRow)) + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, + keyExpressions.toStructType, 0, schemaForValueRow)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion, extraOptions = extraOptionOnStateStore)) @@ -1392,8 +1391,8 @@ case class StreamingDeduplicateWithinWatermarkExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { - val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - keyExpressions.toStructType, schemaForValueRow)) + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, + keyExpressions.toStructType, 0, schemaForValueRow)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion, extraOptions = extraOptionOnStateStore)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala index 0be2450c0ed16..e942aaccacc07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala @@ -50,8 +50,8 @@ case class StreamingGlobalLimitExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { - val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - keySchema, valueSchema)) + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, + keySchema, 0, valueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala index 
b881c6b888768..740b5b0625101 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala @@ -152,7 +152,7 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { // dynamically by the operator. val expectedMetadata = OperatorStateMetadataV2(OperatorInfoV1(0, "transformWithStateExec"), Array(StateStoreMetadataV2( - "default", 0, numShufflePartitions, 0, Map.empty)), + "default", 0, numShufflePartitions, List.empty)), "") checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata, 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 99d2f25a44e10..e516ac37ac06f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -344,7 +344,7 @@ class RocksDBStateEncoderSuite extends SparkFunSuite { keySchemaId = 0, valueSchemaId = 0 )) - new AvroStateEncoder(keyStateEncoderSpec, valueSchema, stateSchemaInfo) + new AvroStateEncoder(keyStateEncoderSpec, valueSchema, None, None) } private def createNoPrefixKeyEncoder(): RocksDBDataEncoder = { @@ -380,7 +380,7 @@ class RocksDBStateEncoderSuite extends SparkFunSuite { withClue("Testing prefix scan encoding: ") { val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey = 2) val stateSchemaInfo = Some(StateSchemaInfo(keySchemaId = 42, valueSchemaId = 0)) - val encoder = new AvroStateEncoder(prefixKeySpec, valueSchema, stateSchemaInfo) + val encoder = new AvroStateEncoder(prefixKeySpec, valueSchema, None, None) // Then encode just the remaining key portion (which should include schema ID) val remainingKeyRow = keyProj.apply(InternalRow(null, null, 3.14)) @@ -396,7 +396,7 @@ class RocksDBStateEncoderSuite extends SparkFunSuite { withClue("Testing range scan encoding: ") { val rangeScanSpec = RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals = Seq(0, 1)) val stateSchemaInfo = Some(StateSchemaInfo(keySchemaId = 24, valueSchemaId = 0)) - val encoder = new AvroStateEncoder(rangeScanSpec, valueSchema, stateSchemaInfo) + val encoder = new AvroStateEncoder(rangeScanSpec, valueSchema, None, None) // Encode remaining key (non-ordering columns) // For range scan, the remaining key schema only contains columns NOT in orderingOrdinals @@ -494,7 +494,7 @@ class RocksDBStateEncoderSuite extends SparkFunSuite { withClue("Testing single value encoder: ") { val keySpec = NoPrefixKeyStateEncoderSpec(keySchema) val stateSchemaInfo = Some(StateSchemaInfo(keySchemaId = 0, valueSchemaId = 42)) - val avroEncoder = new AvroStateEncoder(keySpec, valueSchema, stateSchemaInfo) + val avroEncoder = new AvroStateEncoder(keySpec, valueSchema, None, None) val valueEncoder = new SingleValueStateEncoder(avroEncoder, valueSchema) // Encode value @@ -1946,10 +1946,10 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession val keyRowPrefixEncoder = new StateRowPrefixEncoder( useColumnFamilies = colFamiliesEnabled, - colFamilyInfo, supportSchemaEvolution = schemaEvolutionEnabled) + colFamilyInfo) val valueRowPrefixEncoder = new StateRowPrefixEncoder( - false, None, supportSchemaEvolution = schemaEvolutionEnabled) + false, None) // Create some test data 
with known prefix values val testData = "test data".getBytes @@ -1965,10 +1965,6 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession // Verify key prefixes val keyPrefix = keyRowPrefixEncoder.decodeStateRowPrefix(encodedKey) - assert(keyPrefix.schemaId.isDefined == schemaEvolutionEnabled) - if (schemaEvolutionEnabled) { - assert(keyPrefix.schemaId.get === keyRowPrefixEncoder.getCurrentSchemaId) - } if (colFamiliesEnabled) { assert(keyPrefix.columnFamilyId.isDefined) assert(keyPrefix.columnFamilyId.get === 1) @@ -1978,11 +1974,6 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession // Verify value prefixes val valuePrefix = valueRowPrefixEncoder.decodeStateRowPrefix(retrievedValue) - assert(valuePrefix.schemaId.isDefined == schemaEvolutionEnabled) - if (schemaEvolutionEnabled) { - assert(valuePrefix.schemaId.get === valueRowPrefixEncoder.getCurrentSchemaId) - } - assert(valuePrefix.columnFamilyId.isEmpty) // Values don't have column family IDs // Verify the actual data after stripping prefixes val retrievedKeyData = keyRowPrefixEncoder.decodeStateRowData(encodedKey) @@ -2061,8 +2052,7 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession val valueEncoder = RocksDBStateEncoder.getValueEncoder( dataEncoder, valueSchema = valueSchema, - useMultipleValuesPerKey = useMultipleValues, - None + useMultipleValuesPerKey = useMultipleValues ) // Encode and write to DB @@ -2088,13 +2078,9 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession if (schemaEvolutionEnabled) { val keyPrefix = keyEncoder.asInstanceOf[StateRowPrefixEncoder] .decodeStateRowPrefix(encodedKey) - assert(keyPrefix.schemaId.isDefined == schemaEvolutionEnabled) - assert(keyPrefix.schemaId.get === 0) // default schema ID val valuePrefix = valueEncoder.asInstanceOf[StateRowPrefixEncoder] .decodeStateRowPrefix(encodedValue) - assert(valuePrefix.schemaId.isDefined == schemaEvolutionEnabled) - assert(valuePrefix.schemaId.get === 0) } // Verify column family prefix if enabled diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala index 99483bc0ee8dc..36d6888ff850b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala @@ -238,7 +238,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { val providerId = StateStoreProviderId( StateStoreId(dir, opId, partitionId), queryId) val storeColFamilySchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - keySchema, valueSchema)) + 0, keySchema, 0, valueSchema)) val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf) checker.createSchemaFile(storeColFamilySchema, SchemaHelper.SchemaWriter.createSchemaWriter(1)) @@ -259,15 +259,15 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { val runId = UUID.randomUUID() val stateInfo = StatefulOperatorStateInfo(dir, runId, opId, 0, 200) val storeColFamilySchema = List( - StateStoreColFamilySchema("test1", keySchema, valueSchema, + StateStoreColFamilySchema("test1", 0, keySchema, 0, valueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, keySchema, encoderSpecStr)), - 
StateStoreColFamilySchema("test2", longKeySchema, longValueSchema, + StateStoreColFamilySchema("test2", 0, longKeySchema, 0, longValueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, longKeySchema, encoderSpecStr)), - StateStoreColFamilySchema("test3", keySchema65535Bytes, valueSchema65535Bytes, + StateStoreColFamilySchema("test3", 0, keySchema65535Bytes, 0, valueSchema65535Bytes, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, keySchema65535Bytes)), - StateStoreColFamilySchema("test4", keySchema, valueSchema, + StateStoreColFamilySchema("test4", 0, keySchema, 0, valueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, keySchema, encoderSpecStr), userKeyEncoderSchema = Some(structSchema))) @@ -391,7 +391,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { } val oldStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - oldKeySchema, oldValueSchema, + 0, oldKeySchema, 0, oldValueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, oldKeySchema))) val newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion) val result = Try( @@ -407,7 +407,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { } else { intercept[SparkUnsupportedOperationException] { val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - newKeySchema, newValueSchema, + 0, newKeySchema, 0, newValueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, newKeySchema))) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, newStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, @@ -459,7 +459,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { } val oldStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - oldKeySchema, oldValueSchema, + 0, oldKeySchema, 0, oldValueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, oldKeySchema))) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, oldStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, @@ -468,7 +468,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { extraOptions = extraOptions) val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - newKeySchema, newValueSchema, + 0, newKeySchema, 0, newValueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, newKeySchema))) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, newStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index de8e7d372150c..15be385479c75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -1123,15 +1123,15 @@ class TransformWithStateSuite extends StateStoreMetricsTest val keySchema = new StructType().add("value", StringType) val schema0 = StateStoreColFamilySchema( - "countState", - keySchema, + "countState", 0, + keySchema, 0, new StructType().add("value", LongType, false), Some(NoPrefixKeyStateEncoderSpec(keySchema)), None ) val schema1 = 
StateStoreColFamilySchema( - "listState", - keySchema, + "listState", 0, + keySchema, 0, new StructType() .add("id", LongType, false) .add("name", StringType), @@ -1146,8 +1146,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest .add("key", new StructType().add("value", StringType)) .add("userKey", userKeySchema) val schema2 = StateStoreColFamilySchema( - "mapState", - compositeKeySchema, + "mapState", 0, + compositeKeySchema, 0, new StructType().add("value", StringType), Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)), Option(userKeySchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 4c7f3a06ea7b9..8f7145859c6ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -285,13 +285,13 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { .add("expiryTimestampMs", LongType, nullable = false) val schemaForValueRow: StructType = StructType(Array(StructField("__dummy__", NullType))) val schema0 = StateStoreColFamilySchema( - TimerStateUtils.getTimerStateVarName(TimeMode.ProcessingTime().toString), - schemaForKeyRow, + TimerStateUtils.getTimerStateVarName(TimeMode.ProcessingTime().toString), 0, + schemaForKeyRow, 0, schemaForValueRow, Some(PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1))) val schema1 = StateStoreColFamilySchema( - "valueStateTTL", - keySchema, + "valueStateTTL", 0, + keySchema, 0, new StructType().add("value", new StructType() .add("value", IntegerType, false)) @@ -300,15 +300,15 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { None ) val schema2 = StateStoreColFamilySchema( - "valueState", - keySchema, + "valueState", 0, + keySchema, 0, new StructType().add("value", IntegerType, false), Some(NoPrefixKeyStateEncoderSpec(keySchema)), None ) val schema3 = StateStoreColFamilySchema( - "listState", - keySchema, + "listState", 0, + keySchema, 0, new StructType().add("value", new StructType() .add("id", LongType, false) @@ -325,8 +325,8 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { .add("key", new StructType().add("value", StringType)) .add("userKey", userKeySchema) val schema4 = StateStoreColFamilySchema( - "mapState", - compositeKeySchema, + "mapState", 0, + compositeKeySchema, 0, new StructType().add("value", new StructType() .add("value", StringType))