diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 56f253b523358..bc7ff53d9af36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -26,7 +26,7 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import javax.annotation.concurrent.GuardedBy import scala.collection.{mutable, Map} -import scala.jdk.CollectionConverters.{ConcurrentMapHasAsScala, MapHasAsJava} +import scala.jdk.CollectionConverters.ConcurrentMapHasAsScala import scala.ref.WeakReference import scala.util.Try @@ -165,6 +165,10 @@ class RocksDB( @volatile private var numKeysOnLoadedVersion = 0L @volatile private var numKeysOnWritingVersion = 0L + + @volatile private var numInternalKeysOnLoadedVersion = 0L + @volatile private var numInternalKeysOnWritingVersion = 0L + @volatile private var fileManagerMetrics = RocksDBFileManagerMetrics.EMPTY_METRICS // SPARK-46249 - Keep track of recorded metrics per version which can be used for querying later @@ -178,7 +182,10 @@ class RocksDB( // This is accessed and updated only between load and commit // which means it is implicitly guarded by acquireLock @GuardedBy("acquireLock") - private val colFamilyNameToIdMap = new ConcurrentHashMap[String, Short]() + private val colFamilyNameToInfoMap = new ConcurrentHashMap[String, ColumnFamilyInfo]() + + @GuardedBy("acquireLock") + private val colFamilyIdToNameMap = new ConcurrentHashMap[Short, String]() @GuardedBy("acquireLock") private val maxColumnFamilyId: AtomicInteger = new AtomicInteger(-1) @@ -186,34 +193,49 @@ class RocksDB( @GuardedBy("acquireLock") private val shouldForceSnapshot: AtomicBoolean = new AtomicBoolean(false) - /** - * Check whether the column family name is for internal column families. - * - * @param cfName - column family name - * @return - true if the column family is for internal use, false otherwise - */ - private def checkInternalColumnFamilies(cfName: String): Boolean = cfName.charAt(0) == '_' + private def getColumnFamilyInfo(cfName: String): ColumnFamilyInfo = { + colFamilyNameToInfoMap.get(cfName) + } + + private def getColumnFamilyNameForId(cfId: Short): String = { + colFamilyIdToNameMap.get(cfId) + } - // Methods to fetch column family mapping for this State Store version - def getColumnFamilyMapping: Map[String, Short] = { - colFamilyNameToIdMap.asScala + private def addToColFamilyMaps(cfName: String, cfId: Short, isInternal: Boolean): Unit = { + colFamilyNameToInfoMap.putIfAbsent(cfName, ColumnFamilyInfo(cfId, isInternal)) + colFamilyIdToNameMap.putIfAbsent(cfId, cfName) + } + + private def removeFromColFamilyMaps(cfName: String): Unit = { + val colFamilyInfo = colFamilyNameToInfoMap.get(cfName) + if (colFamilyInfo != null) { + colFamilyNameToInfoMap.remove(cfName) + colFamilyIdToNameMap.remove(colFamilyInfo.cfId) + } } - def getColumnFamilyId(cfName: String): Short = { - colFamilyNameToIdMap.get(cfName) + private def clearColFamilyMaps(): Unit = { + colFamilyNameToInfoMap.clear() + colFamilyIdToNameMap.clear() } /** - * Create RocksDB column family, if not created already + * Check if the column family exists with given name and create one if it doesn't. Users can + * create external column families storing user facing data as well as internal column families + * such as secondary indexes. 
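For reference, a minimal self-contained sketch of the two-map bookkeeping introduced above: one map resolves a family name to its ColumnFamilyInfo, the other resolves a virtual id back to the name so prefixed keys can be attributed during replay. The wrapper object and the main method are illustrative only; the map and helper names mirror the patch.

import java.util.concurrent.ConcurrentHashMap

// Stand-in for the case class this patch adds in RocksDBFileManager.scala.
case class ColumnFamilyInfo(cfId: Short, isInternal: Boolean)

// Hypothetical wrapper object, only to show how the two maps stay in sync.
object ColFamilyMapsSketch {
  private val colFamilyNameToInfoMap = new ConcurrentHashMap[String, ColumnFamilyInfo]()
  private val colFamilyIdToNameMap = new ConcurrentHashMap[Short, String]()

  def addToColFamilyMaps(cfName: String, cfId: Short, isInternal: Boolean): Unit = {
    // Both directions are recorded so keys can be resolved by name (writes)
    // and by id (changelog replay and countKeys decode the 2-byte prefix).
    colFamilyNameToInfoMap.putIfAbsent(cfName, ColumnFamilyInfo(cfId, isInternal))
    colFamilyIdToNameMap.putIfAbsent(cfId, cfName)
  }

  def removeFromColFamilyMaps(cfName: String): Unit = {
    val info = colFamilyNameToInfoMap.get(cfName)
    if (info != null) {
      colFamilyNameToInfoMap.remove(cfName)
      colFamilyIdToNameMap.remove(info.cfId)
    }
  }

  def main(args: Array[String]): Unit = {
    addToColFamilyMaps("countState", 1, isInternal = false)
    addToColFamilyMaps("_ttl_countState", 2, isInternal = true)
    assert(colFamilyIdToNameMap.get(2.toShort) == "_ttl_countState")
    removeFromColFamilyMaps("countState")
    assert(!colFamilyNameToInfoMap.containsKey("countState"))
  }
}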
Metrics for both of these types are tracked separately. + * + * @param colFamilyName - column family name + * @param isInternal - whether the column family is for internal use or not + * @return - virtual column family id */ - def createColFamilyIfAbsent(colFamilyName: String): Short = { + def createColFamilyIfAbsent(colFamilyName: String, isInternal: Boolean): Short = { if (!checkColFamilyExists(colFamilyName)) { val newColumnFamilyId = maxColumnFamilyId.incrementAndGet().toShort - colFamilyNameToIdMap.putIfAbsent(colFamilyName, newColumnFamilyId) + addToColFamilyMaps(colFamilyName, newColumnFamilyId, isInternal) shouldForceSnapshot.set(true) newColumnFamilyId } else { - colFamilyNameToIdMap.get(colFamilyName) + colFamilyNameToInfoMap.get(colFamilyName).cfId } } @@ -221,12 +243,16 @@ class RocksDB( * Remove RocksDB column family, if exists * @return columnFamilyId if it exists, else None */ - def removeColFamilyIfExists(colFamilyName: String): Option[Short] = { + def removeColFamilyIfExists(colFamilyName: String): Boolean = { if (checkColFamilyExists(colFamilyName)) { shouldForceSnapshot.set(true) - Some(colFamilyNameToIdMap.remove(colFamilyName)) + iterator(colFamilyName).foreach { kv => + remove(kv.key, colFamilyName) + } + removeFromColFamilyMaps(colFamilyName) + true } else { - None + false } } @@ -237,23 +263,19 @@ class RocksDB( * @return - true if the column family exists, false otherwise */ def checkColFamilyExists(colFamilyName: String): Boolean = { - colFamilyNameToIdMap.containsKey(colFamilyName) + db != null && colFamilyNameToInfoMap.containsKey(colFamilyName) } // This method sets the internal column family metadata to // the default values it should be set to on load private def setInitialCFInfo(): Unit = { - colFamilyNameToIdMap.clear() + clearColFamilyMaps() shouldForceSnapshot.set(false) maxColumnFamilyId.set(0) } def getColFamilyCount(isInternal: Boolean): Long = { - if (isInternal) { - colFamilyNameToIdMap.asScala.keys.toSeq.count(checkInternalColumnFamilies) - } else { - colFamilyNameToIdMap.asScala.keys.toSeq.count(!checkInternalColumnFamilies(_)) - } + colFamilyNameToInfoMap.asScala.values.toSeq.count(_.isInternal == isInternal) } // Mapping of local SST files to DFS files for file reuse. @@ -375,6 +397,7 @@ class RocksDB( // After changelog replay the numKeysOnWritingVersion will be updated to // the correct number of keys in the loaded version. numKeysOnLoadedVersion = numKeysOnWritingVersion + numInternalKeysOnLoadedVersion = numInternalKeysOnWritingVersion fileManagerMetrics = fileManager.latestLoadCheckpointMetrics } @@ -448,6 +471,7 @@ class RocksDB( // After changelog replay the numKeysOnWritingVersion will be updated to // the correct number of keys in the loaded version. numKeysOnLoadedVersion = numKeysOnWritingVersion + numInternalKeysOnLoadedVersion = numInternalKeysOnWritingVersion fileManagerMetrics = fileManager.latestLoadCheckpointMetrics } if (conf.resetStatsOnLoad) { @@ -467,29 +491,62 @@ class RocksDB( this } + /** + * Function to check if col family is internal or not based on information recorded in + * checkpoint metadata. 
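The lifecycle above (counting families by their isInternal flag, and dropping a family only after deleting its rows) can be modelled without RocksDB. A sketch under that assumption, with plain in-memory maps standing in for the real store, lock handling, and key prefixing:

import scala.collection.mutable

case class ColumnFamilyInfo(cfId: Short, isInternal: Boolean)

object ColFamilyLifecycleSketch {
  // In-memory stand-in for the key space of the single physical RocksDB
  // instance: every key carries its family's virtual id, modelled here as a
  // (cfId, key) pair for readability.
  private val store = mutable.Map[(Short, String), String]()
  private val families = mutable.Map[String, ColumnFamilyInfo]()

  def getColFamilyCount(isInternal: Boolean): Long =
    families.values.count(_.isInternal == isInternal).toLong

  // Mirrors removeColFamilyIfExists: all rows of the family are deleted
  // before the mapping entry is dropped, and the result reports existence.
  def removeColFamilyIfExists(cfName: String): Boolean = {
    families.get(cfName) match {
      case Some(info) =>
        store.keys.filter(_._1 == info.cfId).toList.foreach(store.remove)
        families.remove(cfName)
        true
      case None => false
    }
  }

  def main(args: Array[String]): Unit = {
    families += ("default" -> ColumnFamilyInfo(1, isInternal = false))
    families += ("_ttl" -> ColumnFamilyInfo(2, isInternal = true))
    store += ((1.toShort, "a") -> "v1")
    store += ((2.toShort, "a") -> "ttlRow")
    assert(getColFamilyCount(isInternal = true) == 1L)
    assert(removeColFamilyIfExists("_ttl"))
    assert(!store.keys.exists(_._1 == 2.toShort))
    assert(!removeColFamilyIfExists("_ttl"))
  }
}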
+ * @param cfName - column family name + * @param metadata - checkpoint metadata + * @return - type of column family (internal or otherwise) + */ + private def isInternalColFamily( + cfName: String, + metadata: RocksDBCheckpointMetadata): Boolean = { + if (metadata.columnFamilyTypeMap.isEmpty) { + false + } else { + metadata.columnFamilyTypeMap.get.get(cfName) match { + case Some(cfType) => + cfType + case None => + false + } + } + } + /** * Initialize key metrics based on the metadata loaded from DFS and open local RocksDB. */ private def openLocalRocksDB(metadata: RocksDBCheckpointMetadata): Unit = { setInitialCFInfo() metadata.columnFamilyMapping.foreach { mapping => - colFamilyNameToIdMap.putAll(mapping.asJava) + mapping.foreach { case (colFamilyName, cfId) => + addToColFamilyMaps(colFamilyName, cfId, isInternalColFamily(colFamilyName, metadata)) + } } metadata.maxColumnFamilyId.foreach { maxId => maxColumnFamilyId.set(maxId) } + + if (useColumnFamilies) { + createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, isInternal = false) + } + openDB() - numKeysOnWritingVersion = if (!conf.trackTotalNumberOfRows) { - // we don't track the total number of rows - discard the number being track - -1L - } else if (metadata.numKeys < 0) { - // we track the total number of rows, but the snapshot doesn't have tracking number - // need to count keys now - countKeys() - } else { - metadata.numKeys + val (numKeys, numInternalKeys) = { + if (!conf.trackTotalNumberOfRows) { + // we don't track the total number of rows - discard the number being track + (-1L, -1L) + } else if (metadata.numKeys < 0) { + // we track the total number of rows, but the snapshot doesn't have tracking number + // need to count keys now + countKeys() + } else { + (metadata.numKeys, metadata.numInternalKeys) + } } + numKeysOnWritingVersion = numKeys + numInternalKeysOnWritingVersion = numInternalKeys } def load( @@ -557,16 +614,19 @@ class RocksDB( lastSnapshotVersion = snapshotVersion openDB() - numKeysOnWritingVersion = if (!conf.trackTotalNumberOfRows) { + val (numKeys, numInternalKeys) = if (!conf.trackTotalNumberOfRows) { // we don't track the total number of rows - discard the number being track - -1L + (-1L, -1L) } else if (metadata.numKeys < 0) { // we track the total number of rows, but the snapshot doesn't have tracking number // need to count keys now countKeys() } else { - metadata.numKeys + (metadata.numKeys, metadata.numInternalKeys) } + numKeysOnWritingVersion = numKeys + numInternalKeysOnWritingVersion = numInternalKeys + if (loadedVersion != endVersion) { val versionsAndUniqueIds: Array[(Long, Option[String])] = (loadedVersion + 1 to endVersion).map((_, None)).toArray @@ -576,6 +636,7 @@ class RocksDB( // After changelog replay the numKeysOnWritingVersion will be updated to // the correct number of keys in the loaded version. 
numKeysOnLoadedVersion = numKeysOnWritingVersion + numInternalKeysOnLoadedVersion = numInternalKeysOnWritingVersion fileManagerMetrics = fileManager.latestLoadCheckpointMetrics if (conf.resetStatsOnLoad) { @@ -603,16 +664,33 @@ class RocksDB( var changelogReader: StateStoreChangelogReader = null try { changelogReader = fileManager.getChangelogReader(v, uniqueId) - changelogReader.foreach { case (recordType, key, value) => - recordType match { - case RecordType.PUT_RECORD => - put(key, value) - case RecordType.DELETE_RECORD => - remove(key) + if (useColumnFamilies) { + changelogReader.foreach { case (recordType, key, value) => + val (keyWithoutPrefix, cfName) = decodeStateRowWithPrefix(key) + recordType match { + case RecordType.PUT_RECORD => + put(keyWithoutPrefix, value, cfName) + + case RecordType.DELETE_RECORD => + remove(keyWithoutPrefix, cfName) - case RecordType.MERGE_RECORD => - merge(key, value) + case RecordType.MERGE_RECORD => + merge(keyWithoutPrefix, value, cfName) + } + } + } else { + changelogReader.foreach { case (recordType, key, value) => + recordType match { + case RecordType.PUT_RECORD => + put(key, value) + + case RecordType.DELETE_RECORD => + remove(key) + + case RecordType.MERGE_RECORD => + merge(key, value) + } } } } finally { @@ -621,28 +699,114 @@ class RocksDB( } } + /** + * Function to encode state row with virtual col family id prefix + * @param data - passed byte array to be stored in state store + * @param cfName - name of column family + * @return - encoded byte array with virtual column family id prefix + */ + private def encodeStateRowWithPrefix( + data: Array[Byte], + cfName: String): Array[Byte] = { + val cfInfo = getColumnFamilyInfo(cfName) + RocksDBStateStoreProvider.encodeStateRowWithPrefix(data, cfInfo.cfId) + } + + /** + * Function to decode state row with virtual col family id prefix + * @param data - passed byte array retrieved from state store + * @return - pair of decoded byte array without virtual column family id prefix + * and name of column family + */ + private def decodeStateRowWithPrefix(data: Array[Byte]): (Array[Byte], String) = { + val cfId = RocksDBStateStoreProvider.getColumnFamilyBytesAsId(data) + val cfName = getColumnFamilyNameForId(cfId) + val key = RocksDBStateStoreProvider.decodeStateRowWithPrefix(data) + (key, cfName) + } + /** * Get the value for the given key if present, or null. * @note This will return the last written value even if it was uncommitted. */ - def get(key: Array[Byte]): Array[Byte] = { - db.get(readOptions, key) + def get( + key: Array[Byte], + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Array[Byte] = { + val keyWithPrefix = if (useColumnFamilies) { + encodeStateRowWithPrefix(key, cfName) + } else { + key + } + + db.get(readOptions, keyWithPrefix) + } + + /** + * Function to check if value exists for a key or not depending on the operation type. + * @param oldValue - old value for the key + * @param isPutOrMerge - flag to indicate if the operation is put or merge + * @return - true if the value doesn't exist for putAndMerge operation and vice versa for remove + */ + private def checkExistingEntry( + oldValue: Array[Byte], + isPutOrMerge: Boolean): Boolean = { + if (isPutOrMerge) { + oldValue == null + } else { + oldValue != null + } + } + + /** + * Function to keep track of metrics updates around the number of keys in the store. 
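The key layout this relies on is a 2-byte virtual column family id prepended to the key bytes. A self-contained round trip of that layout, using java.nio.ByteBuffer purely for illustration; the actual helpers added later in this diff (RocksDBStateStoreProvider.encodeStateRowWithPrefix and friends) use Spark's Platform utilities and native byte order:

import java.nio.ByteBuffer

object VcfPrefixSketch {
  val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2

  // Layout mirrored from the patch: [2-byte virtual cf id][original key bytes].
  def encodeStateRowWithPrefix(data: Array[Byte], vcfId: Short): Array[Byte] = {
    val buf = ByteBuffer.allocate(VIRTUAL_COL_FAMILY_PREFIX_BYTES + data.length)
    buf.putShort(vcfId)
    buf.put(data)
    buf.array()
  }

  def getColumnFamilyBytesAsId(data: Array[Byte]): Short =
    ByteBuffer.wrap(data).getShort()

  def decodeStateRowWithPrefix(data: Array[Byte]): Array[Byte] =
    java.util.Arrays.copyOfRange(data, VIRTUAL_COL_FAMILY_PREFIX_BYTES, data.length)

  def main(args: Array[String]): Unit = {
    val key = "user-key-1".getBytes("UTF-8")
    val prefixed = encodeStateRowWithPrefix(key, vcfId = 3)
    assert(getColumnFamilyBytesAsId(prefixed) == 3)
    assert(new String(decodeStateRowWithPrefix(prefixed), "UTF-8") == "user-key-1")
  }
}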
+ * @param keyWithPrefix - key with prefix + * @param cfName - column family name + * @param isPutOrMerge - flag to indicate if the operation is put or merge + */ + private def handleMetricsUpdate( + keyWithPrefix: Array[Byte], + cfName: String, + isPutOrMerge: Boolean): Unit = { + val updateCount = if (isPutOrMerge) 1L else -1L + if (useColumnFamilies) { + if (conf.trackTotalNumberOfRows) { + val oldValue = db.get(readOptions, keyWithPrefix) + if (checkExistingEntry(oldValue, isPutOrMerge)) { + val cfInfo = getColumnFamilyInfo(cfName) + if (cfInfo.isInternal) { + numInternalKeysOnWritingVersion += updateCount + } else { + numKeysOnWritingVersion += updateCount + } + } + } + } else { + if (conf.trackTotalNumberOfRows) { + val oldValue = db.get(readOptions, keyWithPrefix) + if (checkExistingEntry(oldValue, isPutOrMerge)) { + numKeysOnWritingVersion += updateCount + } + } + } } /** * Put the given value for the given key. * @note This update is not committed to disk until commit() is called. */ - def put(key: Array[Byte], value: Array[Byte]): Unit = { - if (conf.trackTotalNumberOfRows) { - val oldValue = db.get(readOptions, key) - if (oldValue == null) { - numKeysOnWritingVersion += 1 - } + def put( + key: Array[Byte], + value: Array[Byte], + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + val keyWithPrefix = if (useColumnFamilies) { + encodeStateRowWithPrefix(key, cfName) + } else { + key } - db.put(writeOptions, key, value) - changelogWriter.foreach(_.put(key, value)) + handleMetricsUpdate(keyWithPrefix, cfName, isPutOrMerge = true) + db.put(writeOptions, keyWithPrefix, value) + changelogWriter.foreach(_.put(keyWithPrefix, value)) } /** @@ -656,31 +820,35 @@ class RocksDB( * * @note This update is not committed to disk until commit() is called. */ - def merge(key: Array[Byte], value: Array[Byte]): Unit = { - if (conf.trackTotalNumberOfRows) { - val oldValue = db.get(readOptions, key) - if (oldValue == null) { - numKeysOnWritingVersion += 1 - } + def merge( + key: Array[Byte], + value: Array[Byte], + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + val keyWithPrefix = if (useColumnFamilies) { + encodeStateRowWithPrefix(key, cfName) + } else { + key } - db.merge(writeOptions, key, value) - changelogWriter.foreach(_.merge(key, value)) + handleMetricsUpdate(keyWithPrefix, cfName, isPutOrMerge = true) + db.merge(writeOptions, keyWithPrefix, value) + changelogWriter.foreach(_.merge(keyWithPrefix, value)) } /** * Remove the key if present. * @note This update is not committed to disk until commit() is called. 
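Distilled from handleMetricsUpdate and checkExistingEntry above, a pure-function sketch of when the external versus internal key counters move. The helper name and tuple return type are hypothetical; the real method reads the old value from RocksDB and mutates the two counters in place:

// Hypothetical pure helper; returns (externalDelta, internalDelta) to apply to
// numKeysOnWritingVersion / numInternalKeysOnWritingVersion respectively.
object KeyCountDeltaSketch {
  def delta(
      oldValueExists: Boolean,
      isPutOrMerge: Boolean,
      isInternalColFamily: Boolean): (Long, Long) = {
    val counted =
      if (isPutOrMerge) !oldValueExists // put/merge adds a key only if it was absent
      else oldValueExists               // remove drops a key only if it was present
    if (!counted) (0L, 0L)
    else {
      val d = if (isPutOrMerge) 1L else -1L
      if (isInternalColFamily) (0L, d) else (d, 0L)
    }
  }

  def main(args: Array[String]): Unit = {
    assert(delta(oldValueExists = false, isPutOrMerge = true, isInternalColFamily = false) == (1L, 0L))
    assert(delta(oldValueExists = true, isPutOrMerge = true, isInternalColFamily = true) == (0L, 0L))
    assert(delta(oldValueExists = true, isPutOrMerge = false, isInternalColFamily = true) == (0L, -1L))
  }
}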
*/ - def remove(key: Array[Byte]): Unit = { - if (conf.trackTotalNumberOfRows) { - val value = db.get(readOptions, key) - if (value != null) { - numKeysOnWritingVersion -= 1 - } + def remove(key: Array[Byte], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + val keyWithPrefix = if (useColumnFamilies) { + encodeStateRowWithPrefix(key, cfName) + } else { + key } - db.delete(writeOptions, key) - changelogWriter.foreach(_.delete(key)) + + handleMetricsUpdate(keyWithPrefix, cfName, isPutOrMerge = false) + db.delete(writeOptions, keyWithPrefix) + changelogWriter.foreach(_.delete(keyWithPrefix)) } /** @@ -700,7 +868,13 @@ class RocksDB( new NextIterator[ByteArrayPair] { override protected def getNext(): ByteArrayPair = { if (iter.isValid) { - byteArrayPair.set(iter.key, iter.value) + val key = if (useColumnFamilies) { + decodeStateRowWithPrefix(iter.key)._1 + } else { + iter.key + } + + byteArrayPair.set(key, iter.value) iter.next() byteArrayPair } else { @@ -713,7 +887,18 @@ class RocksDB( } } - private def countKeys(): Long = { + /** + * Get an iterator of all committed and uncommitted key-value pairs for the given column family. + */ + def iterator(cfName: String): Iterator[ByteArrayPair] = { + if (!useColumnFamilies) { + iterator() + } else { + prefixScan(Array.empty[Byte], cfName) + } + } + + private def countKeys(): (Long, Long) = { val iter = db.newIterator() try { @@ -723,20 +908,46 @@ class RocksDB( iter.seekToFirst() var keys = 0L - while (iter.isValid) { - keys += 1 - iter.next() + var internalKeys = 0L + + if (!useColumnFamilies) { + while (iter.isValid) { + keys += 1 + iter.next() + } + } else { + var currCfInfoOpt: Option[(String, ColumnFamilyInfo)] = None + while (iter.isValid) { + val (_, cfName) = decodeStateRowWithPrefix(iter.key) + if (currCfInfoOpt.isEmpty || currCfInfoOpt.get._1 != cfName) { + currCfInfoOpt = Some((cfName, getColumnFamilyInfo(cfName))) + } + if (currCfInfoOpt.get._2.isInternal) { + internalKeys += 1 + } else { + keys += 1 + } + iter.next() + } } - keys + (keys, internalKeys) } finally { iter.close() } } - def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = { + def prefixScan( + prefix: Array[Byte], + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[ByteArrayPair] = { val iter = db.newIterator() - iter.seek(prefix) + val updatedPrefix = if (useColumnFamilies) { + encodeStateRowWithPrefix(prefix, cfName) + } else { + prefix + } + + iter.seek(updatedPrefix) // Attempt to close this iterator if there is a task failure, or a task interruption. 
Option(TaskContext.get()).foreach { tc => @@ -745,8 +956,14 @@ class RocksDB( new NextIterator[ByteArrayPair] { override protected def getNext(): ByteArrayPair = { - if (iter.isValid && iter.key().take(prefix.length).sameElements(prefix)) { - byteArrayPair.set(iter.key, iter.value) + if (iter.isValid && iter.key().take(updatedPrefix.length).sameElements(updatedPrefix)) { + val key = if (useColumnFamilies) { + decodeStateRowWithPrefix(iter.key)._1 + } else { + iter.key + } + + byteArrayPair.set(key, iter.value) iter.next() byteArrayPair } else { @@ -827,6 +1044,7 @@ class RocksDB( fileManager.setMaxSeenVersion(newVersion) numKeysOnLoadedVersion = numKeysOnWritingVersion + numInternalKeysOnLoadedVersion = numInternalKeysOnWritingVersion loadedVersion = newVersion commitLatencyMs ++= Map( "fileSync" -> fileSyncTimeMs @@ -889,7 +1107,8 @@ class RocksDB( checkpointDir, version, numKeysOnWritingVersion, - colFamilyNameToIdMap.asScala.toMap, + numInternalKeysOnWritingVersion, + colFamilyNameToInfoMap.asScala.toMap, maxColumnFamilyId.get().toShort, dfsFileSuffix, immutableFileMapping, @@ -924,6 +1143,7 @@ class RocksDB( acquire(RollbackStore) try { numKeysOnWritingVersion = numKeysOnLoadedVersion + numInternalKeysOnWritingVersion = numInternalKeysOnLoadedVersion loadedVersion = -1L lastCommitBasedStateStoreCkptId = None lastCommittedStateStoreCkptId = None @@ -1064,6 +1284,7 @@ class RocksDB( RocksDBMetrics( numKeysOnLoadedVersion, numKeysOnWritingVersion, + numInternalKeysOnWritingVersion, memoryUsage, pinnedBlocksMemUsage, totalSSTFilesBytes, @@ -1216,6 +1437,7 @@ class RocksDB( snapshot.checkpointDir, snapshot.version, snapshot.numKeys, + snapshot.numInternalKeys, snapshot.fileMapping, Some(snapshot.columnFamilyMapping), Some(snapshot.maxColumnFamilyId), @@ -1290,7 +1512,8 @@ object RocksDB extends Logging { checkpointDir: File, version: Long, numKeys: Long, - columnFamilyMapping: Map[String, Short], + numInternalKeys: Long, + columnFamilyMapping: Map[String, ColumnFamilyInfo], maxColumnFamilyId: Short, dfsFileSuffix: String, fileMapping: Map[String, RocksDBSnapshotFile], @@ -1679,6 +1902,7 @@ object RocksDBConf { case class RocksDBMetrics( numCommittedKeys: Long, numUncommittedKeys: Long, + numInternalKeys: Long, totalMemUsageBytes: Long, pinnedBlocksMemUsage: Long, totalSSTFilesBytes: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala index d077928711de9..bb1198dfccafc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala @@ -248,8 +248,9 @@ class RocksDBFileManager( checkpointDir: File, version: Long, numKeys: Long, + numInternalKeys: Long, fileMapping: Map[String, RocksDBSnapshotFile], - columnFamilyMapping: Option[Map[String, Short]] = None, + columnFamilyMapping: Option[Map[String, ColumnFamilyInfo]] = None, maxColumnFamilyId: Option[Short] = None, checkpointUniqueId: Option[String] = None): Unit = { logFilesInDir(checkpointDir, log"Saving checkpoint files " + @@ -257,8 +258,27 @@ class RocksDBFileManager( val (localImmutableFiles, localOtherFiles) = listRocksDBFiles(checkpointDir) val rocksDBFiles = saveImmutableFilesToDfs( version, localImmutableFiles, fileMapping, checkpointUniqueId) - val metadata = RocksDBCheckpointMetadata(rocksDBFiles, numKeys, columnFamilyMapping, - 
maxColumnFamilyId) + + val colFamilyIdMapping: Option[Map[String, Short]] = if (columnFamilyMapping.isDefined) { + Some(columnFamilyMapping.get.map { + case (cfName, cfInfo) => + cfName -> cfInfo.cfId + }) + } else { + None + } + + val colFamilyTypeMapping: Option[Map[String, Boolean]] = if (columnFamilyMapping.isDefined) { + Some(columnFamilyMapping.get.map { + case (cfName, cfInfo) => + cfName -> cfInfo.isInternal + }) + } else { + None + } + + val metadata = RocksDBCheckpointMetadata(rocksDBFiles, numKeys, numInternalKeys, + colFamilyIdMapping, colFamilyTypeMapping, maxColumnFamilyId) val metadataFile = localMetadataFile(checkpointDir) metadata.writeToFile(metadataFile) logInfo(log"Written metadata for version ${MDC(LogKeys.VERSION_NUM, version)}:\n" + @@ -923,6 +943,15 @@ object RocksDBFileManagerMetrics { val EMPTY_METRICS = RocksDBFileManagerMetrics(0L, 0L, 0L, None) } +/** + * Case class to keep track of column family info within checkpoint metadata. + * @param cfId - virtual column family id + * @param isInternal - whether the column family is internal or not + */ +case class ColumnFamilyInfo( + cfId: Short, + isInternal: Boolean) + /** * Classes to represent metadata of checkpoints saved to DFS. Since this is converted to JSON, any * changes to this MUST be backward-compatible. @@ -931,7 +960,9 @@ case class RocksDBCheckpointMetadata( sstFiles: Seq[RocksDBSstFile], logFiles: Seq[RocksDBLogFile], numKeys: Long, + numInternalKeys: Long, columnFamilyMapping: Option[Map[String, Short]] = None, + columnFamilyTypeMap: Option[Map[String, Boolean]] = None, maxColumnFamilyId: Option[Short] = None) { require(columnFamilyMapping.isDefined == maxColumnFamilyId.isDefined, @@ -997,6 +1028,7 @@ object RocksDBCheckpointMetadata { sstFiles.map(_.asInstanceOf[RocksDBSstFile]), logFiles.map(_.asInstanceOf[RocksDBLogFile]), numKeys, + 0, None, None ) @@ -1005,14 +1037,18 @@ object RocksDBCheckpointMetadata { def apply( rocksDBFiles: Seq[RocksDBImmutableFile], numKeys: Long, + numInternalKeys: Long, columnFamilyMapping: Option[Map[String, Short]], + columnFamilyTypeMap: Option[Map[String, Boolean]], maxColumnFamilyId: Option[Short]): RocksDBCheckpointMetadata = { val (sstFiles, logFiles) = rocksDBFiles.partition(_.isInstanceOf[RocksDBSstFile]) new RocksDBCheckpointMetadata( sstFiles.map(_.asInstanceOf[RocksDBSstFile]), logFiles.map(_.asInstanceOf[RocksDBLogFile]), numKeys, + numInternalKeys, columnFamilyMapping, + columnFamilyTypeMap, maxColumnFamilyId ) } @@ -1022,21 +1058,25 @@ object RocksDBCheckpointMetadata { sstFiles: Seq[RocksDBSstFile], logFiles: Seq[RocksDBLogFile], numKeys: Long): RocksDBCheckpointMetadata = { - new RocksDBCheckpointMetadata(sstFiles, logFiles, numKeys, None, None) + new RocksDBCheckpointMetadata(sstFiles, logFiles, numKeys, 0, None, None) } // Apply method for cases with column family information def apply( rocksDBFiles: Seq[RocksDBImmutableFile], numKeys: Long, + numInternalKeys: Long, columnFamilyMapping: Map[String, Short], + columnFamilyTypeMap: Map[String, Boolean], maxColumnFamilyId: Short): RocksDBCheckpointMetadata = { val (sstFiles, logFiles) = rocksDBFiles.partition(_.isInstanceOf[RocksDBSstFile]) new RocksDBCheckpointMetadata( sstFiles.map(_.asInstanceOf[RocksDBSstFile]), logFiles.map(_.asInstanceOf[RocksDBLogFile]), numKeys, + numInternalKeys, Some(columnFamilyMapping), + Some(columnFamilyTypeMap), Some(maxColumnFamilyId) ) } @@ -1046,13 +1086,17 @@ object RocksDBCheckpointMetadata { sstFiles: Seq[RocksDBSstFile], logFiles: Seq[RocksDBLogFile], numKeys: Long, + 
numInternalKeys: Long, columnFamilyMapping: Map[String, Short], + columnFamilyTypeMap: Map[String, Boolean], maxColumnFamilyId: Short): RocksDBCheckpointMetadata = { new RocksDBCheckpointMetadata( sstFiles, logFiles, numKeys, + numInternalKeys, Some(columnFamilyMapping), + Some(columnFamilyTypeMap), Some(maxColumnFamilyId) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 08633999bbc9c..c7b324ec32e62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StateStoreColumnFamilySchemaUtils} -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{SCHEMA_ID_PREFIX_BYTES, STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION, VIRTUAL_COL_FAMILY_PREFIX_BYTES} +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{SCHEMA_ID_PREFIX_BYTES, STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -46,7 +46,6 @@ sealed trait RocksDBKeyStateEncoder { def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] def encodeKey(row: UnsafeRow): Array[Byte] def decodeKey(keyBytes: Array[Byte]): UnsafeRow - def getColumnFamilyIdBytes(): Array[Byte] } sealed trait RocksDBValueStateEncoder { @@ -720,13 +719,14 @@ class UnsafeRowDataEncoder( * * @param keyStateEncoderSpec Specification for how to encode keys (prefix/range scan) * @param valueSchema Schema for the values to be encoded - * @param stateSchemaInfo Schema version information for both keys and values + * @param stateSchemaProvider Optional state schema provider + * @param columnFamilyName Column family name to be used */ class AvroStateEncoder( keyStateEncoderSpec: KeyStateEncoderSpec, valueSchema: StructType, stateSchemaProvider: Option[StateSchemaProvider], - columnFamilyInfo: Option[ColumnFamilyInfo] + columnFamilyName: String ) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) with Logging { private val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema) @@ -734,12 +734,12 @@ class AvroStateEncoder( // current schema IDs instantiated lazily // schema information private lazy val currentKeySchemaId: Short = getStateSchemaProvider.getCurrentStateSchemaId( - getColFamilyName, + columnFamilyName, isKey = true ) private lazy val currentValSchemaId: Short = getStateSchemaProvider.getCurrentStateSchemaId( - getColFamilyName, + columnFamilyName, isKey = false ) @@ -869,10 +869,6 @@ class AvroStateEncoder( override def supportsSchemaEvolution: Boolean = true - private def getColFamilyName: String = { - columnFamilyInfo.get.colFamilyName - } - private def getStateSchemaProvider: StateSchemaProvider = { assert(stateSchemaProvider.isDefined, "StateSchemaProvider should always be" + " defined for the Avro encoder") @@ -1265,7 +1261,7 @@ class AvroStateEncoder( val schemaIdRow = decodeStateSchemaIdRow(bytes) val writerSchema = getStateSchemaProvider.getSchemaMetadataValue( 
StateSchemaMetadataKey( - getColFamilyName, + columnFamilyName, schemaIdRow.schemaId, isKey = false ) @@ -1279,120 +1275,6 @@ class AvroStateEncoder( } } -/** - * Information about a RocksDB column family used for state storage. - * - * @param colFamilyName The name of the column family in RocksDB - * @param virtualColumnFamilyId A unique identifier for the virtual column family, - * used as a prefix in encoded state rows to distinguish - * between different column families - */ -case class ColumnFamilyInfo( - colFamilyName: String, - virtualColumnFamilyId: Short -) - -/** - * Metadata prefixes stored at the beginning of encoded state rows. - * These prefixes allow for schema evolution and column family organization - * in the state store. - * - * @param columnFamilyId Optional identifier for the virtual column family. - * When present, allows organizing state data into - * different column families in RocksDB. - */ -case class StateRowPrefix( - columnFamilyId: Option[Short] -) - -class StateRowPrefixEncoder( - useColumnFamilies: Boolean, - columnFamilyInfo: Option[ColumnFamilyInfo] -) { - - private val numColFamilyBytes = if (useColumnFamilies) { - VIRTUAL_COL_FAMILY_PREFIX_BYTES - } else { - 0 - } - - def getNumPrefixBytes: Int = numColFamilyBytes - - val out = new ByteArrayOutputStream - - /** - * Get Byte Array for the virtual column family id that is used as prefix for - * key state rows. - */ - def getColumnFamilyIdBytes(): Array[Byte] = { - assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" + - " because multiple Column is not supported for this encoder") - val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES) - val virtualColFamilyId = columnFamilyInfo.get.virtualColumnFamilyId - Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId) - encodedBytes - } - - /** - * Encodes a state row by adding schema and column family ID prefixes if enabled. - * - * @param data The byte array containing the data to be prefixed - * @return A new byte array containing the prefixed data. If no prefixing is needed - * (neither schema evolution nor column families are enabled), returns a copy - * of the input array to maintain consistency with the prefixed case. 
- */ - def encodeStateRowWithPrefix(data: Array[Byte]): Array[Byte] = { - // Create result array big enough for all prefixes plus data - val result = new Array[Byte](getNumPrefixBytes + data.length) - var offset = Platform.BYTE_ARRAY_OFFSET - - // Write column family ID if enabled - if (useColumnFamilies) { - val colFamilyId = columnFamilyInfo.get.virtualColumnFamilyId - Platform.putShort(result, offset, colFamilyId) - offset += VIRTUAL_COL_FAMILY_PREFIX_BYTES - } - - // Write the actual data - Platform.copyMemory( - data, Platform.BYTE_ARRAY_OFFSET, - result, offset, - data.length - ) - - result - } - - def decodeStateRowPrefix(stateRow: Array[Byte]): StateRowPrefix = { - var offset = Platform.BYTE_ARRAY_OFFSET - - // Read column family ID if present - val colFamilyId = if (useColumnFamilies) { - val id = Platform.getShort(stateRow, offset) - offset += VIRTUAL_COL_FAMILY_PREFIX_BYTES - Some(id) - } else { - None - } - - StateRowPrefix(colFamilyId) - } - - def decodeStateRowData(stateRow: Array[Byte]): Array[Byte] = { - val offset = Platform.BYTE_ARRAY_OFFSET + getNumPrefixBytes - - // Extract the actual data - val dataLength = stateRow.length - getNumPrefixBytes - val data = new Array[Byte](dataLength) - Platform.copyMemory( - stateRow, offset, - data, Platform.BYTE_ARRAY_OFFSET, - dataLength - ) - data - } -} - /** * Factory object for creating state encoders used by RocksDB state store. * @@ -1409,15 +1291,13 @@ object RocksDBStateEncoder extends Logging { * @param keyStateEncoderSpec Specification defining the key encoding strategy * (no prefix, prefix scan, or range scan) * @param useColumnFamilies Whether to use RocksDB column families for storage - * @param virtualColFamilyId Optional column family identifier when column families are enabled * @return A configured RocksDBKeyStateEncoder instance */ def getKeyEncoder( dataEncoder: RocksDBDataEncoder, keyStateEncoderSpec: KeyStateEncoderSpec, - useColumnFamilies: Boolean, - columnFamilyInfo: Option[ColumnFamilyInfo] = None): RocksDBKeyStateEncoder = { - keyStateEncoderSpec.toEncoder(dataEncoder, useColumnFamilies, columnFamilyInfo) + useColumnFamilies: Boolean): RocksDBKeyStateEncoder = { + keyStateEncoderSpec.toEncoder(dataEncoder, useColumnFamilies) } /** @@ -1439,21 +1319,6 @@ object RocksDBStateEncoder extends Logging { new SingleValueStateEncoder(dataEncoder, valueSchema) } } - - /** - * Encodes a virtual column family ID into a byte array suitable for RocksDB. - * - * This method creates a fixed-size byte array prefixed with the virtual column family ID, - * which is used to partition data within RocksDB. 
- * - * @param virtualColFamilyId The column family identifier to encode - * @return A byte array containing the encoded column family ID - */ - def getColumnFamilyIdBytes(virtualColFamilyId: Short): Array[Byte] = { - val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES) - Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId) - encodedBytes - } } /** @@ -1468,12 +1333,8 @@ class PrefixKeyScanStateEncoder( dataEncoder: RocksDBDataEncoder, keySchema: StructType, numColsPrefixKey: Int, - useColumnFamilies: Boolean = false, - columnFamilyInfo: Option[ColumnFamilyInfo] = None) - extends StateRowPrefixEncoder( - useColumnFamilies, - columnFamilyInfo - ) with RocksDBKeyStateEncoder with Logging { + useColumnFamilies: Boolean = false) + extends RocksDBKeyStateEncoder with Logging { private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = { keySchema.zipWithIndex.take(numColsPrefixKey) @@ -1519,13 +1380,11 @@ class PrefixKeyScanStateEncoder( remainingEncoded.length ) - // Add state row prefix using encoder - encodeStateRowWithPrefix(combinedData) + combinedData } override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = { - // First decode the metadata prefixes and get the actual key data - val keyData = decodeStateRowData(keyBytes) + val keyData = keyBytes // Get prefix key length from the start of the actual key data val prefixKeyEncodedLen = Platform.getInt(keyData, Platform.BYTE_ARRAY_OFFSET) @@ -1567,8 +1426,7 @@ class PrefixKeyScanStateEncoder( dataWithLength, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length ) - - encodeStateRowWithPrefix(dataWithLength) + dataWithLength } override def supportPrefixKeyScan: Boolean = true @@ -1610,12 +1468,8 @@ class RangeKeyScanStateEncoder( dataEncoder: RocksDBDataEncoder, keySchema: StructType, orderingOrdinals: Seq[Int], - useColumnFamilies: Boolean = false, - columnFamilyInfo: Option[ColumnFamilyInfo] = None) - extends StateRowPrefixEncoder( - useColumnFamilies, - columnFamilyInfo - ) with RocksDBKeyStateEncoder with Logging { + useColumnFamilies: Boolean = false) + extends RocksDBKeyStateEncoder with Logging { private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = { orderingOrdinals.map { ordinal => @@ -1716,12 +1570,11 @@ class RangeKeyScanStateEncoder( remainingEncoded.length ) - encodeStateRowWithPrefix(combinedData) + combinedData } override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = { - // First decode metadata prefixes to get the actual key data - val keyData = decodeStateRowData(keyBytes) + val keyData = keyBytes // Get range scan key length and extract it val prefixKeyEncodedLen = Platform.getInt(keyData, Platform.BYTE_ARRAY_OFFSET) @@ -1770,7 +1623,7 @@ class RangeKeyScanStateEncoder( rangeScanKeyEncoded.length ) - encodeStateRowWithPrefix(dataWithLength) + dataWithLength } override def supportPrefixKeyScan: Boolean = true @@ -1791,16 +1644,12 @@ class RangeKeyScanStateEncoder( class NoPrefixKeyStateEncoder( dataEncoder: RocksDBDataEncoder, keySchema: StructType, - useColumnFamilies: Boolean = false, - columnFamilyInfo: Option[ColumnFamilyInfo] = None) - extends StateRowPrefixEncoder( - useColumnFamilies, - columnFamilyInfo - ) with RocksDBKeyStateEncoder with Logging { + useColumnFamilies: Boolean = false) + extends RocksDBKeyStateEncoder with Logging { override def encodeKey(row: UnsafeRow): Array[Byte] = { if (!useColumnFamilies) { - encodeStateRowWithPrefix(dataEncoder.encodeKey(row)) + dataEncoder.encodeKey(row) } else { // First encode the row with the data encoder 
val rowBytes = dataEncoder.encodeKey(row) @@ -1814,18 +1663,17 @@ class NoPrefixKeyStateEncoder( rowBytes.length ) - encodeStateRowWithPrefix(dataWithVersion) + dataWithVersion } } override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = { if (!useColumnFamilies) { - dataEncoder.decodeKey(decodeStateRowData(keyBytes)) + dataEncoder.decodeKey(keyBytes) } else if (keyBytes == null) { null } else { - // First decode the metadata prefixes - val dataWithVersion = decodeStateRowData(keyBytes) + val dataWithVersion = keyBytes // Skip version byte to get to actual data val dataLength = dataWithVersion.length - STATE_ENCODING_NUM_VERSION_BYTES diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 2e46712bf7271..cd9fdb9469d60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution} import org.apache.spark.sql.execution.streaming.state.StateStoreEncoding.Avro import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.Platform import org.apache.spark.util.{NonFateSharingCache, Utils} private[sql] class RocksDBStateStoreProvider @@ -56,17 +57,6 @@ private[sql] class RocksDBStateStoreProvider override def version: Long = lastVersion - // Test-visible methods to fetch column family mapping for this State Store version - // Because column families are only enabled for RocksDBStateStore, these methods - // are no-ops everywhere else. - private[sql] def getColumnFamilyMapping: Map[String, Short] = { - rocksDB.getColumnFamilyMapping.toMap - } - - private[sql] def getColumnFamilyId(cfName: String): Short = { - rocksDB.getColumnFamilyId(cfName) - } - override def createColFamilyIfAbsent( colFamilyName: String, keySchema: StructType, @@ -75,7 +65,7 @@ private[sql] class RocksDBStateStoreProvider useMultipleValuesPerKey: Boolean = false, isInternal: Boolean = false): Unit = { verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) - val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName) + val cfId = rocksDB.createColFamilyIfAbsent(colFamilyName, isInternal) val dataEncoderCacheKey = StateRowEncoderCacheKey( queryRunId = getRunId(hadoopConf), operatorId = stateStoreId.operatorId, @@ -83,8 +73,6 @@ private[sql] class RocksDBStateStoreProvider stateStoreName = stateStoreId.storeName, colFamilyName = colFamilyName) - val columnFamilyInfo = Some(ColumnFamilyInfo(colFamilyName, newColFamilyId)) - // For unit tests only: TestStateSchemaProvider allows dynamically adding schemas // during unit test execution to verify schema compatibility checks and evolution logic. 
// This provider is only used in isolated unit tests where we directly instantiate @@ -101,21 +89,19 @@ private[sql] class RocksDBStateStoreProvider keyStateEncoderSpec, valueSchema, stateSchemaProvider, - columnFamilyInfo + Some(colFamilyName) ) - val keyEncoder = RocksDBStateEncoder.getKeyEncoder( dataEncoder, keyStateEncoderSpec, - useColumnFamilies, - columnFamilyInfo + useColumnFamilies ) val valueEncoder = RocksDBStateEncoder.getValueEncoder( dataEncoder, valueSchema, useMultipleValuesPerKey ) - keyValueEncoderMap.putIfAbsent(colFamilyName, (keyEncoder, valueEncoder)) + keyValueEncoderMap.putIfAbsent(colFamilyName, (keyEncoder, valueEncoder, cfId)) } override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = { @@ -124,7 +110,7 @@ private[sql] class RocksDBStateStoreProvider val kvEncoder = keyValueEncoderMap.get(colFamilyName) val value = - kvEncoder._2.decodeValue(rocksDB.get(kvEncoder._1.encodeKey(key))) + kvEncoder._2.decodeValue(rocksDB.get(kvEncoder._1.encodeKey(key), colFamilyName)) if (!isValidated && value != null && !useColumnFamilies) { StateStoreProvider.validateStateRowFormat( @@ -155,7 +141,7 @@ private[sql] class RocksDBStateStoreProvider verify(valueEncoder.supportsMultipleValuesPerKey, "valuesIterator requires a encoder " + "that supports multiple values for a single key.") - val encodedValues = rocksDB.get(keyEncoder.encodeKey(key)) + val encodedValues = rocksDB.get(keyEncoder.encodeKey(key), colFamilyName) valueEncoder.decodeValues(encodedValues) } @@ -172,7 +158,7 @@ private[sql] class RocksDBStateStoreProvider verify(key != null, "Key cannot be null") require(value != null, "Cannot merge a null value") - rocksDB.merge(keyEncoder.encodeKey(key), valueEncoder.encodeValue(value)) + rocksDB.merge(keyEncoder.encodeKey(key), valueEncoder.encodeValue(value), colFamilyName) } override def put(key: UnsafeRow, value: UnsafeRow, colFamilyName: String): Unit = { @@ -182,7 +168,7 @@ private[sql] class RocksDBStateStoreProvider verifyColFamilyOperations("put", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) - rocksDB.put(kvEncoder._1.encodeKey(key), kvEncoder._2.encodeValue(value)) + rocksDB.put(kvEncoder._1.encodeKey(key), kvEncoder._2.encodeValue(value), colFamilyName) } override def remove(key: UnsafeRow, colFamilyName: String): Unit = { @@ -191,7 +177,7 @@ private[sql] class RocksDBStateStoreProvider verifyColFamilyOperations("remove", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) - rocksDB.remove(kvEncoder._1.encodeKey(key)) + rocksDB.remove(kvEncoder._1.encodeKey(key), colFamilyName) } override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = { @@ -202,11 +188,8 @@ private[sql] class RocksDBStateStoreProvider val kvEncoder = keyValueEncoderMap.get(colFamilyName) val rowPair = new UnsafeRowPair() - // As Virtual Column Family attaches a column family prefix to the key row, - // we'll need to do prefixScan on the default column family with the same column - // family id prefix to get all rows stored in a given virtual column family if (useColumnFamilies) { - rocksDB.prefixScan(kvEncoder._1.getColumnFamilyIdBytes()).map { kv => + rocksDB.iterator(colFamilyName).map { kv => rowPair.withRows(kvEncoder._1.decodeKey(kv.key), kvEncoder._2.decodeValue(kv.value)) if (!isValidated && rowPair.value != null && !useColumnFamilies) { @@ -240,7 +223,7 @@ private[sql] class RocksDBStateStoreProvider val rowPair = new UnsafeRowPair() val prefix = kvEncoder._1.encodePrefixKey(prefixKey) - rocksDB.prefixScan(prefix).map { kv 
=> + rocksDB.prefixScan(prefix, colFamilyName).map { kv => rowPair.withRows(kvEncoder._1.decodeKey(kv.key), kvEncoder._2.decodeValue(kv.value)) rowPair @@ -327,6 +310,7 @@ private[sql] class RocksDBStateStoreProvider CUSTOM_METRIC_COMPACT_WRITTEN_BYTES -> nativeOpsMetrics("totalBytesWrittenByCompaction"), CUSTOM_METRIC_FLUSH_WRITTEN_BYTES -> nativeOpsMetrics("totalBytesWrittenByFlush"), CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE -> rocksDBMetrics.pinnedBlocksMemUsage, + CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES_KEYS -> rocksDBMetrics.numInternalKeys, CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES -> internalColFamilyCnt(), CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES -> externalColFamilyCnt() ) ++ rocksDBMetrics.zipFileBytesUncompressed.map(bytes => @@ -363,21 +347,10 @@ private[sql] class RocksDBStateStoreProvider verifyColFamilyCreationOrDeletion("remove_col_family", colFamilyName) verify(useColumnFamilies, "Column families are not supported in this store") - val result = { - val colFamilyId = rocksDB.removeColFamilyIfExists(colFamilyName) - - colFamilyId match { - case Some(vcfId) => - val colFamilyIdBytes = - RocksDBStateEncoder.getColumnFamilyIdBytes(vcfId) - rocksDB.prefixScan(colFamilyIdBytes).foreach { kv => - rocksDB.remove(kv.key) - } - true - case None => false - } + val result = rocksDB.removeColFamilyIfExists(colFamilyName) + if (result) { + keyValueEncoderMap.remove(colFamilyName) } - keyValueEncoderMap.remove(colFamilyName) result } } @@ -412,7 +385,6 @@ private[sql] class RocksDBStateStoreProvider } rocksDB // lazy initialization - var defaultColFamilyId: Option[Short] = None val dataEncoderCacheKey = StateRowEncoderCacheKey( queryRunId = getRunId(hadoopConf), @@ -430,31 +402,32 @@ private[sql] class RocksDBStateStoreProvider case _ => } - defaultColFamilyId = Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME)) - val columnFamilyInfo = - Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, defaultColFamilyId.get)) - val dataEncoder = getDataEncoder( stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema, stateSchemaProvider, - columnFamilyInfo - ) + Some(StateStore.DEFAULT_COL_FAMILY_NAME)) val keyEncoder = RocksDBStateEncoder.getKeyEncoder( dataEncoder, keyStateEncoderSpec, - useColumnFamilies, - columnFamilyInfo - ) + useColumnFamilies) val valueEncoder = RocksDBStateEncoder.getValueEncoder( dataEncoder, valueSchema, useMultipleValuesPerKey ) - keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, (keyEncoder, valueEncoder)) + + var cfId: Short = 0 + if (useColumnFamilies) { + cfId = rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, + isInternal = false) + } + + keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, + (keyEncoder, valueEncoder, cfId)) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -548,7 +521,7 @@ private[sql] class RocksDBStateStoreProvider } private val keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, - (RocksDBKeyStateEncoder, RocksDBValueStateEncoder)] + (RocksDBKeyStateEncoder, RocksDBValueStateEncoder, Short)] private val multiColFamiliesDisabledStr = "multiple column families is disabled in " + "RocksDBStateStoreProvider" @@ -679,12 +652,85 @@ object RocksDBStateStoreProvider { val STATE_ENCODING_NUM_VERSION_BYTES = 1 val STATE_ENCODING_VERSION: Byte = 0 val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2 + val SCHEMA_ID_PREFIX_BYTES = 2 private val MAX_AVRO_ENCODERS_IN_CACHE = 1000 private val AVRO_ENCODER_LIFETIME_HOURS = 1L private val 
DEFAULT_SCHEMA_IDS = StateSchemaInfo(0, 0) + /** + * Encodes a virtual column family ID into a byte array suitable for RocksDB. + * + * This method creates a fixed-size byte array prefixed with the virtual column family ID, + * which is used to partition data within RocksDB. + * + * @param virtualColFamilyId The column family identifier to encode + * @return A byte array containing the encoded column family ID + */ + def getColumnFamilyIdAsBytes(virtualColFamilyId: Short): Array[Byte] = { + val encodedBytes = new Array[Byte](RocksDBStateStoreProvider.VIRTUAL_COL_FAMILY_PREFIX_BYTES) + Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId) + encodedBytes + } + + /** + * Function to encode state row with virtual col family id prefix + * @param data - passed byte array to be stored in state store + * @param vcfId - virtual column family id + * @return - encoded byte array with virtual column family id prefix + */ + def encodeStateRowWithPrefix( + data: Array[Byte], + vcfId: Short): Array[Byte] = { + // Create result array big enough for all prefixes plus data + val result = new Array[Byte](RocksDBStateStoreProvider.VIRTUAL_COL_FAMILY_PREFIX_BYTES + + data.length) + val offset = Platform.BYTE_ARRAY_OFFSET + + RocksDBStateStoreProvider.VIRTUAL_COL_FAMILY_PREFIX_BYTES + + Platform.putShort(result, Platform.BYTE_ARRAY_OFFSET, vcfId) + + // Write the actual data + Platform.copyMemory( + data, Platform.BYTE_ARRAY_OFFSET, + result, offset, + data.length + ) + + result + } + + /** + * Function to decode virtual column family id from byte array + * @param data - passed byte array retrieved from state store + * @return - virtual column family id + */ + def getColumnFamilyBytesAsId(data: Array[Byte]): Short = { + Platform.getShort(data, Platform.BYTE_ARRAY_OFFSET) + } + + /** + * Function to decode state row with virtual col family id prefix + * @param data - passed byte array retrieved from state store + * @return - pair of decoded byte array without virtual column family id prefix + * and name of column family + */ + def decodeStateRowWithPrefix(data: Array[Byte]): Array[Byte] = { + val offset = Platform.BYTE_ARRAY_OFFSET + + RocksDBStateStoreProvider.VIRTUAL_COL_FAMILY_PREFIX_BYTES + + val key = new Array[Byte](data.length - + RocksDBStateStoreProvider.VIRTUAL_COL_FAMILY_PREFIX_BYTES) + Platform.copyMemory( + data, offset, + key, Platform.BYTE_ARRAY_OFFSET, + key.length + ) + + key + } + // Add the cache at companion object level so it persists across provider instances private val dataEncoderCache: NonFateSharingCache[StateRowEncoderCacheKey, RocksDBDataEncoder] = NonFateSharingCache( @@ -716,18 +762,20 @@ object RocksDBStateStoreProvider { keyStateEncoderSpec: KeyStateEncoderSpec, valueSchema: StructType, stateSchemaProvider: Option[StateSchemaProvider], - columnFamilyInfo: Option[ColumnFamilyInfo] = None): RocksDBDataEncoder = { + columnFamilyName: Option[String] = None): RocksDBDataEncoder = { assert(Set("avro", "unsaferow").contains(stateStoreEncoding)) RocksDBStateStoreProvider.dataEncoderCache.get( encoderCacheKey, new java.util.concurrent.Callable[RocksDBDataEncoder] { override def call(): RocksDBDataEncoder = { if (stateStoreEncoding == Avro.toString) { + assert(columnFamilyName.isDefined, + "Column family name must be defined for Avro encoding") new AvroStateEncoder( keyStateEncoderSpec, valueSchema, stateSchemaProvider, - columnFamilyInfo + columnFamilyName.get ) } else { new UnsafeRowDataEncoder( @@ -814,6 +862,9 @@ object RocksDBStateStoreProvider { val 
CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE = StateStoreCustomSizeMetric( "rocksdbPinnedBlocksMemoryUsage", "RocksDB: memory usage for pinned blocks") + val CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES_KEYS = StateStoreCustomSizeMetric( + "rocksdbNumInternalColFamiliesKeys", + "RocksDB: number of internal keys for internal column families") val CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES = StateStoreCustomSizeMetric( "rocksdbNumExternalColumnFamilies", "RocksDB: number of external column families") @@ -835,8 +886,8 @@ object RocksDBStateStoreProvider { CUSTOM_METRIC_BYTES_WRITTEN, CUSTOM_METRIC_ITERATOR_BYTES_READ, CUSTOM_METRIC_STALL_TIME, CUSTOM_METRIC_TOTAL_COMPACT_TIME, CUSTOM_METRIC_COMPACT_READ_BYTES, CUSTOM_METRIC_COMPACT_WRITTEN_BYTES, CUSTOM_METRIC_FLUSH_WRITTEN_BYTES, - CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE, CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES, - CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES) + CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE, CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES_KEYS, + CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES, CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES) } /** [[StateStoreChangeDataReader]] implementation for [[RocksDBStateStoreProvider]] */ @@ -847,35 +898,20 @@ class RocksDBStateStoreChangeDataReader( endVersion: Long, compressionCodec: CompressionCodec, keyValueEncoderMap: - ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder)], + ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder, Short)], colFamilyNameOpt: Option[String] = None) extends StateStoreChangeDataReader( fm, stateLocation, startVersion, endVersion, compressionCodec, colFamilyNameOpt) { override protected var changelogSuffix: String = "changelog" - private def getColFamilyIdBytes: Option[Array[Byte]] = { - if (colFamilyNameOpt.isDefined) { - val colFamilyName = colFamilyNameOpt.get - if (!keyValueEncoderMap.containsKey(colFamilyName)) { - throw new IllegalStateException( - s"Column family $colFamilyName not found in the key value encoder map") - } - Some(keyValueEncoderMap.get(colFamilyName)._1.getColumnFamilyIdBytes()) - } else { - None - } - } - - private val colFamilyIdBytesOpt: Option[Array[Byte]] = getColFamilyIdBytes - override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { var currRecord: (RecordType.Value, Array[Byte], Array[Byte]) = null - val currEncoder: (RocksDBKeyStateEncoder, RocksDBValueStateEncoder) = + val currEncoder: (RocksDBKeyStateEncoder, RocksDBValueStateEncoder, Short) = keyValueEncoderMap.get(colFamilyNameOpt .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)) - if (colFamilyIdBytesOpt.isDefined) { + if (colFamilyNameOpt.isDefined) { // If we are reading records for a particular column family, the corresponding vcf id // will be encoded in the key byte array. We need to extract that and compare for the // expected column family id. If it matches, we return the record. 
If not, we move to @@ -888,13 +924,16 @@ class RocksDBStateStoreChangeDataReader( } val nextRecord = reader.next() - val colFamilyIdBytes: Array[Byte] = colFamilyIdBytesOpt.get + val colFamilyIdBytes: Array[Byte] = + RocksDBStateStoreProvider.getColumnFamilyIdAsBytes(currEncoder._3) val endIndex = colFamilyIdBytes.size // Function checks for byte arrays being equal // from index 0 to endIndex - 1 (both inclusive) if (java.util.Arrays.equals(nextRecord._2, 0, endIndex, colFamilyIdBytes, 0, endIndex)) { - currRecord = nextRecord + val extractedKey = RocksDBStateStoreProvider.decodeStateRowWithPrefix(nextRecord._2) + val result = (nextRecord._1, extractedKey, nextRecord._3) + currRecord = result } } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 434a9b59e6779..8ba3fc37162c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -331,13 +331,11 @@ sealed trait KeyStateEncoderSpec { * * @param dataEncoder The encoder to handle the actual data encoding/decoding * @param useColumnFamilies Whether to use RocksDB column families - * @param virtualColFamilyId Optional column family ID when column families are used * @return A RocksDBKeyStateEncoder configured for this spec */ def toEncoder( dataEncoder: RocksDBDataEncoder, - useColumnFamilies: Boolean, - columnFamilyInfo: Option[ColumnFamilyInfo]): RocksDBKeyStateEncoder + useColumnFamilies: Boolean): RocksDBKeyStateEncoder } object KeyStateEncoderSpec { @@ -364,10 +362,9 @@ case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEn override def toEncoder( dataEncoder: RocksDBDataEncoder, - useColumnFamilies: Boolean, - columnFamilyInfo: Option[ColumnFamilyInfo]): RocksDBKeyStateEncoder = { + useColumnFamilies: Boolean): RocksDBKeyStateEncoder = { new NoPrefixKeyStateEncoder( - dataEncoder, keySchema, useColumnFamilies, columnFamilyInfo) + dataEncoder, keySchema, useColumnFamilies) } } @@ -380,13 +377,11 @@ case class PrefixKeyScanStateEncoderSpec( override def toEncoder( dataEncoder: RocksDBDataEncoder, - useColumnFamilies: Boolean, - columnFamilyInfo: Option[ColumnFamilyInfo]): RocksDBKeyStateEncoder = { + useColumnFamilies: Boolean): RocksDBKeyStateEncoder = { new PrefixKeyScanStateEncoder( - dataEncoder, keySchema, numColsPrefixKey, useColumnFamilies, columnFamilyInfo) + dataEncoder, keySchema, numColsPrefixKey, useColumnFamilies) } - override def jsonValue: JValue = { ("keyStateEncoderType" -> JString("PrefixKeyScanStateEncoderSpec")) ~ ("numColsPrefixKey" -> JInt(numColsPrefixKey)) @@ -403,10 +398,9 @@ case class RangeKeyScanStateEncoderSpec( override def toEncoder( dataEncoder: RocksDBDataEncoder, - useColumnFamilies: Boolean, - columnFamilyInfo: Option[ColumnFamilyInfo]): RocksDBKeyStateEncoder = { + useColumnFamilies: Boolean): RocksDBKeyStateEncoder = { new RangeKeyScanStateEncoder( - dataEncoder, keySchema, orderingOrdinals, useColumnFamilies, columnFamilyInfo) + dataEncoder, keySchema, orderingOrdinals, useColumnFamilies) } override def jsonValue: JValue = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala index 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index f170de66ee9df..1f4fd7f795716 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -107,8 +107,8 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
             "rocksdbTotalBytesReadByCompaction", "rocksdbTotalBytesWrittenByCompaction",
             "rocksdbTotalCompactionLatencyMs", "rocksdbWriterStallLatencyMs",
             "rocksdbTotalBytesReadThroughIterator", "rocksdbTotalBytesWrittenByFlush",
-            "rocksdbPinnedBlocksMemoryUsage", "rocksdbNumExternalColumnFamilies",
-            "rocksdbNumInternalColumnFamilies"))
+            "rocksdbPinnedBlocksMemoryUsage", "rocksdbNumInternalColFamiliesKeys",
+            "rocksdbNumExternalColumnFamilies", "rocksdbNumInternalColumnFamilies"))
         }
       } finally {
         query.stop()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index 4ac771a5b0baa..4d939db8796b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -646,7 +646,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       initialValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Create test data
@@ -669,7 +669,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       evolvedValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Decode with evolved schema
@@ -720,7 +720,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       initialValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Create test data with null value
@@ -747,7 +747,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       evolvedValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Decode original value with evolved schema
@@ -801,7 +801,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       initialValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     val proj = UnsafeProjection.create(initialValueSchema)
@@ -819,7 +819,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       evolvedValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     val decoded = encoder2.decodeValue(encoded)
@@ -866,7 +866,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       initialValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     val proj =
      UnsafeProjection.create(initialValueSchema)
@@ -884,7 +884,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       evolvedValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     val decoded = encoder2.decodeValue(encoded)
@@ -935,7 +935,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       initialValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Create test data
@@ -958,7 +958,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       evolvedValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Decode with evolved schema
@@ -1015,7 +1015,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       initialValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Create and encode data with initial schema (IntegerType)
@@ -1035,7 +1035,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       evolvedValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Should successfully decode IntegerType as LongType
@@ -1082,7 +1082,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       initialValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Create and encode data with initial order
@@ -1102,7 +1102,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       evolvedValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Should decode with correct field values despite reordering
@@ -1153,7 +1153,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       initialValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Create and encode data with initial schema (LongType)
@@ -1173,7 +1173,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       NoPrefixKeyStateEncoderSpec(keySchema),
       evolvedValueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+      StateStore.DEFAULT_COL_FAMILY_NAME
     )
 
     // Attempting to decode Long as Int should fail
@@ -1834,7 +1834,6 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     store = provider.getRocksDBStateStore(2)
     store.createColFamilyIfAbsent(colFamily3, keySchema, valueSchema,
       NoPrefixKeyStateEncoderSpec(keySchema))
-    assert(store.getColumnFamilyId(colFamily3) == 3)
     store.removeColFamilyIfExists(colFamily1)
     store.removeColFamilyIfExists(colFamily3)
     store.commit()
@@ -1843,15 +1842,12 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     // this should return the old id, because we didn't remove this colFamily for version 1
     store.createColFamilyIfAbsent(colFamily1, keySchema, valueSchema,
       NoPrefixKeyStateEncoderSpec(keySchema))
-    assert(store.getColumnFamilyId(colFamily1) == 1)
 
     store = provider.getRocksDBStateStore(3)
     store.createColFamilyIfAbsent(colFamily4, keySchema, valueSchema,
       NoPrefixKeyStateEncoderSpec(keySchema))
-    assert(store.getColumnFamilyId(colFamily4) == 4)
     store.createColFamilyIfAbsent(colFamily5, keySchema, valueSchema,
       NoPrefixKeyStateEncoderSpec(keySchema))
-    assert(store.getColumnFamilyId(colFamily5) == 5)
   }
 }
@@ -1900,6 +1896,66 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  testWithColumnFamiliesAndEncodingTypes(s"numInternalKeys metrics",
+    TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
+    tryWithProviderResource(
+      newStoreProvider(useColumnFamilies = colFamiliesEnabled)) { provider =>
+      if (colFamiliesEnabled) {
+        val store = provider.getStore(0)
+
+        // create non-internal col family and add data
+        val cfName = "testColFamily"
+        store.createColFamilyIfAbsent(cfName, keySchema, valueSchema,
+          NoPrefixKeyStateEncoderSpec(keySchema))
+        put(store, "a", 0, 1, cfName)
+        put(store, "b", 0, 2, cfName)
+        put(store, "c", 0, 3, cfName)
+        put(store, "d", 0, 4, cfName)
+        put(store, "e", 0, 5, cfName)
+
+        // create internal col family and add data
+        val internalCfName = "$testIndex"
+        store.createColFamilyIfAbsent(internalCfName, keySchema, valueSchema,
+          NoPrefixKeyStateEncoderSpec(keySchema), isInternal = true)
+        put(store, "a", 0, 1, internalCfName)
+        put(store, "m", 0, 2, internalCfName)
+        put(store, "n", 0, 3, internalCfName)
+        put(store, "b", 0, 4, internalCfName)
+
+        assert(store.commit() === 1)
+        // Verify that the metrics are correct for internal and non-internal col families
+        assert(store.metrics.numKeys === 5)
+        val metricPair = store
+          .metrics.customMetrics.find(_._1.name == "rocksdbNumInternalColFamiliesKeys")
+        assert(metricPair.isDefined && metricPair.get._2 === 4)
+        assert(rowPairsToDataSet(store.iterator(cfName)) ===
+          Set(("a", 0) -> 1, ("b", 0) -> 2, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5))
+        assert(rowPairsToDataSet(store.iterator(internalCfName)) ===
+          Set(("a", 0) -> 1, ("m", 0) -> 2, ("n", 0) -> 3, ("b", 0) -> 4))
+
+        // Reload the store and remove some keys
+        val reloadedProvider = newStoreProvider(store.id, colFamiliesEnabled)
+        val reloadedStore = reloadedProvider.getStore(1)
+        reloadedStore.createColFamilyIfAbsent(cfName, keySchema, valueSchema,
+          NoPrefixKeyStateEncoderSpec(keySchema))
+        reloadedStore.createColFamilyIfAbsent(internalCfName, keySchema, valueSchema,
+          NoPrefixKeyStateEncoderSpec(keySchema), isInternal = true)
+        remove(reloadedStore, _._1 == "b", cfName)
+        remove(reloadedStore, _._1 == "m", internalCfName)
+        assert(reloadedStore.commit() === 2)
+        // Verify that the metrics are correct for internal and non-internal col families
+        assert(reloadedStore.metrics.numKeys === 4)
+        val metricPairUpdated = reloadedStore
+          .metrics.customMetrics.find(_._1.name == "rocksdbNumInternalColFamiliesKeys")
+        assert(metricPairUpdated.isDefined && metricPairUpdated.get._2 === 3)
+        assert(rowPairsToDataSet(reloadedStore.iterator(cfName)) ===
+          Set(("a", 0) -> 1, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5))
+        assert(rowPairsToDataSet(reloadedStore.iterator(internalCfName)) ===
+          Set(("a", 0) -> 1, ("n", 0) -> 3, ("b", 0) -> 4))
+      }
+    }
+  }
+
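The test above exercises the internal/external split introduced by this patch. As a compact illustration (column family names here are hypothetical, the APIs are the same ones the test uses): keys written to a column family created with isInternal = true are reported under the rocksdbNumInternalColFamiliesKeys custom metric, while numKeys covers only non-internal column families.

  // Sketch only: `store`, keySchema and valueSchema as in the surrounding tests.
  store.createColFamilyIfAbsent("userData", keySchema, valueSchema,
    NoPrefixKeyStateEncoderSpec(keySchema))                    // keys counted in numKeys
  store.createColFamilyIfAbsent("$secondaryIndex", keySchema, valueSchema,
    NoPrefixKeyStateEncoderSpec(keySchema), isInternal = true) // keys counted in the internal-keys metric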
   test(s"validate rocksdb removeColFamilyIfExists correctness") {
     Seq(
       NoPrefixKeyStateEncoderSpec(keySchema),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 0c6824f76d016..7d4614d599733 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -381,9 +381,7 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
       keyStateEncoderSpec,
       valueSchema,
       Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 0))
-    )
-    new AvroStateEncoder(keyStateEncoderSpec, valueSchema, None, None)
+      StateStore.DEFAULT_COL_FAMILY_NAME)
   }
 
   private def createNoPrefixKeyEncoder(): RocksDBDataEncoder = {
@@ -428,7 +426,7 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
     withClue("Testing prefix scan encoding: ") {
       val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey = 2)
       val encoder = new AvroStateEncoder(prefixKeySpec, valueSchema, Some(testProvider),
-        Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 0)))
+        StateStore.DEFAULT_COL_FAMILY_NAME)
 
       // Then encode just the remaining key portion (which should include schema ID)
       val remainingKeyRow = keyProj.apply(InternalRow(null, null, 3.14))
@@ -452,7 +450,7 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
     withClue("Testing range scan encoding: ") {
       val rangeScanSpec = RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals = Seq(0, 1))
       val encoder = new AvroStateEncoder(rangeScanSpec, valueSchema, Some(testProvider),
-        Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 0)))
+        StateStore.DEFAULT_COL_FAMILY_NAME)
 
       // Encode remaining key (non-ordering columns)
       // For range scan, the remaining key schema only contains columns NOT in orderingOrdinals
@@ -560,7 +558,7 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
     val keySpec = NoPrefixKeyStateEncoderSpec(keySchema)
     val stateSchemaInfo = Some(StateSchemaInfo(keySchemaId = 0, valueSchemaId = 42))
     val avroEncoder = new AvroStateEncoder(keySpec, valueSchema, Some(testProvider),
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 0)))
+      StateStore.DEFAULT_COL_FAMILY_NAME)
     val valueEncoder = new SingleValueStateEncoder(avroEncoder, valueSchema)
 
     // Encode value
@@ -629,7 +627,11 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
 
       if (isChangelogCheckpointingEnabled) {
         assert(changelogVersionsPresent(remoteDir) === (1 to 50))
-        assert(snapshotVersionsPresent(remoteDir) === Range.inclusive(5, 50, 5))
+        if (colFamiliesEnabled) {
+          assert(snapshotVersionsPresent(remoteDir) === Seq(1) ++ Range.inclusive(5, 50, 5))
+        } else {
+          assert(snapshotVersionsPresent(remoteDir) === Range.inclusive(5, 50, 5))
+        }
       } else {
         assert(changelogVersionsPresent(remoteDir) === Seq.empty)
         assert(snapshotVersionsPresent(remoteDir) === (1 to 50))
@@ -693,14 +695,25 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
         db.commit()
         db.doMaintenance()
       }
-      assert(snapshotVersionsPresent(remoteDir) === Seq(2, 3))
+
+      if (colFamiliesEnabled) {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(1, 2, 3))
+      } else {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(2, 3))
+      }
       assert(changelogVersionsPresent(remoteDir) == Seq(1, 2, 3))
       for (version <- 3 to 4) {
         db.load(version)
         db.commit()
       }
-      assert(snapshotVersionsPresent(remoteDir) === Seq(2, 3))
+
+      if (colFamiliesEnabled) {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(1, 2, 3))
+      } else {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(2, 3))
+      }
+
       assert(changelogVersionsPresent(remoteDir) == (1 to 5))
       db.doMaintenance()
       // 3 is the latest snapshot <= maxSnapshotVersionPresent - minVersionsToRetain + 1
@@ -783,27 +796,47 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
         db.commit()
         db.doMaintenance()
       }
-      // Snapshot should not be created because minDeltasForSnapshot = 3
-      assert(snapshotVersionsPresent(remoteDir) === Seq.empty)
+
+      if (colFamiliesEnabled) {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(1))
+      } else {
+        // Snapshot should not be created because minDeltasForSnapshot = 3
+        assert(snapshotVersionsPresent(remoteDir) === Seq.empty)
+      }
+
       assert(changelogVersionsPresent(remoteDir) == Seq(1, 2))
       db.load(2, versionToUniqueId.get(2))
       db.commit()
       db.doMaintenance()
-      assert(snapshotVersionsPresent(remoteDir) === Seq(3))
-      db.load(3, versionToUniqueId.get(3))
+      if (colFamiliesEnabled) {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(1))
+      } else {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(3))
+      }
       for (version <- 3 to 7) {
         db.load(version, versionToUniqueId.get(version))
         db.commit()
         db.doMaintenance()
       }
-      assert(snapshotVersionsPresent(remoteDir) === Seq(3, 6))
+
+      if (colFamiliesEnabled) {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(1, 4, 7))
+      } else {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(3, 6))
+      }
+
       for (version <- 8 to 17) {
         db.load(version, versionToUniqueId.get(version))
         db.commit()
       }
       db.doMaintenance()
-      assert(snapshotVersionsPresent(remoteDir) === Seq(3, 6, 18))
+
+      if (colFamiliesEnabled) {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(1, 4, 7, 16))
+      } else {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(3, 6, 18))
+      }
     }
 
     // pick up from the last snapshot and the next upload will be for version 21
@@ -813,14 +846,24 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
       db.load(18, versionToUniqueId.get(18))
       db.commit()
       db.doMaintenance()
-      assert(snapshotVersionsPresent(remoteDir) === Seq(3, 6, 18))
+
+      if (colFamiliesEnabled) {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(1, 4, 7, 16, 19))
+      } else {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(3, 6, 18))
+      }
       for (version <- 19 to 20) {
         db.load(version, versionToUniqueId.get(version))
         db.commit()
       }
       db.doMaintenance()
-      assert(snapshotVersionsPresent(remoteDir) === Seq(3, 6, 18, 21))
+
+      if (colFamiliesEnabled) {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(1, 4, 7, 16, 19))
+      } else {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(3, 6, 18, 21))
+      }
     }
   }
@@ -902,7 +945,14 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
         db.remove((version - 1).toString)
         db.commit()
       }
-      assert(snapshotVersionsPresent(remoteDir) === (1 to 30))
+
+      if (enableStateStoreCheckpointIds && colFamiliesEnabled) {
+        // This is because version 30 is committed twice and snapshots are not overwritten in checkpoint v2
+        assert(snapshotVersionsPresent(remoteDir) === (1 to 30) :+ 30 :+ 31)
+      } else {
+        assert(snapshotVersionsPresent(remoteDir) === (1 to 30))
+      }
+
       assert(changelogVersionsPresent(remoteDir) === (30 to 60))
       for (version <- 1 to 60) {
         db.load(version, versionToUniqueId.get(version), readOnly = true)
@@ -918,18 +968,36 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
       }
 
       // Check that snapshots and changelogs get purged correctly.
      db.doMaintenance()
-      assert(snapshotVersionsPresent(remoteDir) === Seq(30, 60))
+
+      // Behavior is slightly different when column families are enabled with checkpoint v2
+      // since snapshot version 31 was created previously.
+      if (enableStateStoreCheckpointIds && colFamiliesEnabled) {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(31, 60, 60))
+      } else {
+        assert(snapshotVersionsPresent(remoteDir) === Seq(30, 60))
+      }
       if (enableStateStoreCheckpointIds) {
         // recommit version 60 creates another changelog file with different unique id
-        assert(changelogVersionsPresent(remoteDir) === (30 to 60) :+ 60)
+        if (colFamiliesEnabled) {
+          assert(changelogVersionsPresent(remoteDir) === (31 to 60) :+ 60)
+        } else {
+          assert(changelogVersionsPresent(remoteDir) === (30 to 60) :+ 60)
+        }
       } else {
         assert(changelogVersionsPresent(remoteDir) === (30 to 60))
       }
       // Verify the content of retained versions.
-      for (version <- 30 to 60) {
-        db.load(version, versionToUniqueId.get(version), readOnly = true)
-        assert(db.iterator().map(toStr).toSet === Set((version.toString, version.toString)))
+      if (enableStateStoreCheckpointIds && colFamiliesEnabled) {
+        for (version <- 31 to 60) {
+          db.load(version, readOnly = true)
+          assert(db.iterator().map(toStr).toSet === Set((version.toString, version.toString)))
+        }
+      } else {
+        for (version <- 30 to 60) {
+          db.load(version, readOnly = true)
+          assert(db.iterator().map(toStr).toSet === Set((version.toString, version.toString)))
+        }
       }
     }
   }
@@ -1074,6 +1142,7 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
         db.load(version, versionToUniqueId.get(version))
         assert(db.iterator().map(toStr).toSet === Set((version.toString, version.toString)))
       }
+
       for (version <- 31 to 60) {
         db.load(version - 1, versionToUniqueId.get(version - 1))
         db.put(version.toString, version.toString)
@@ -1081,7 +1150,17 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
         db.commit()
       }
       assert(changelogVersionsPresent(remoteDir) === (1 to 30))
-      assert(snapshotVersionsPresent(remoteDir) === (31 to 60))
+
+      var result: Seq[Long] = if (colFamiliesEnabled) {
+        Seq(1)
+      } else {
+        Seq.empty
+      }
+
+      (31 to 60).foreach { i =>
+        result = result :+ i
+      }
+      assert(snapshotVersionsPresent(remoteDir) === result)
       for (version <- 1 to 60) {
         db.load(version, versionToUniqueId.get(version), readOnly = true)
         assert(db.iterator().map(toStr).toSet === Set((version.toString, version.toString)))
@@ -1615,7 +1694,7 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
     val conf = dbConf.copy(minDeltasForSnapshot = 5, compactOnCommit = false)
     new File(remoteDir).delete()  // to make sure that the directory gets created
     withDB(remoteDir, conf = conf, useColumnFamilies = true) { db =>
-      db.createColFamilyIfAbsent("test")
+      db.createColFamilyIfAbsent("test", isInternal = false)
       db.load(0)
       db.put("a", "1")
       db.put("b", "2")
@@ -1716,7 +1795,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
        case true => Some(UUID.randomUUID().toString)
      }
      saveCheckpointFiles(fileManager, cpFiles1, version = 1,
-       numKeys = 101, rocksDBFileMapping, uuid)
+       numKeys = 101, rocksDBFileMapping,
+       numInternalKeys = 0, uuid)
      assert(fileManager.getLatestVersion() === 1)
      assert(numRemoteSSTFiles == 2) // 2 sst files copied
      assert(numRemoteLogFiles == 2)
@@ -1731,7 +1811,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
        "archive/00003.log" -> 2000
      )
      saveCheckpointFiles(fileManager_, cpFiles1_, version = 1,
-       numKeys = 101, new RocksDBFileMapping(), uuid)
+       numKeys = 101, new RocksDBFileMapping(),
+       numInternalKeys = 0, uuid)
      assert(fileManager_.getLatestVersion() === 1)
      assert(numRemoteSSTFiles == 4)
      assert(numRemoteLogFiles == 4)
@@ -1751,7 +1832,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
        "archive/00005.log" -> 2000
      )
      saveCheckpointFiles(fileManager_, cpFiles2,
-       version = 2, numKeys = 121, new RocksDBFileMapping(), uuid)
+       version = 2, numKeys = 121, new RocksDBFileMapping(),
+       numInternalKeys = 0, uuid)
      fileManager_.deleteOldVersions(1)
      assert(numRemoteSSTFiles <= 4) // delete files recorded in 1.zip
      assert(numRemoteLogFiles <= 5) // delete files recorded in 1.zip and orphan 00001.log
@@ -1766,7 +1848,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
        "archive/00007.log" -> 2000
      )
      saveCheckpointFiles(fileManager_, cpFiles3,
-       version = 3, numKeys = 131, new RocksDBFileMapping(), uuid)
+       version = 3, numKeys = 131, new RocksDBFileMapping(),
+       numInternalKeys = 0, uuid)
      assert(fileManager_.getLatestVersion() === 3)
      fileManager_.deleteOldVersions(1)
      assert(numRemoteSSTFiles == 1)
@@ -1812,7 +1895,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
      }
 
      saveCheckpointFiles(
-       fileManager, cpFiles1, version = 1, numKeys = 101, rocksDBFileMapping, uuid)
+       fileManager, cpFiles1, version = 1, numKeys = 101, rocksDBFileMapping,
+       numInternalKeys = 0, uuid)
      fileManager.deleteOldVersions(1)
      // Should not delete orphan files even when they are older than all existing files
      // when there is only 1 version.
@@ -1830,7 +1914,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
        "archive/00004.log" -> 2000
      )
      saveCheckpointFiles(
-       fileManager, cpFiles2, version = 2, numKeys = 101, rocksDBFileMapping, uuid)
+       fileManager, cpFiles2, version = 2, numKeys = 101, rocksDBFileMapping,
+       numInternalKeys = 0, uuid)
      assert(numRemoteSSTFiles == 5)
      assert(numRemoteLogFiles == 5)
      fileManager.deleteOldVersions(1)
@@ -1880,7 +1965,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
      }
 
      saveCheckpointFiles(
-       fileManager, cpFiles1, version = 1, numKeys = 101, fileMapping, uuid)
+       fileManager, cpFiles1, version = 1, numKeys = 101, fileMapping, numInternalKeys = 0,
+       uuid)
      assert(fileManager.getLatestVersion() === 1)
      assert(numRemoteSSTFiles == 2) // 2 sst files copied
      assert(numRemoteLogFiles == 2) // 2 log files copied
@@ -1918,7 +2004,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
      // upload version 1 again, new checkpoint will be created and SST files from
      // previously committed version 1 will not be reused.
      saveCheckpointFiles(fileManager, cpFiles1_,
-       version = 1, numKeys = 1001, fileMapping, uuid)
+       version = 1, numKeys = 1001, fileMapping,
+       numInternalKeys = 0, uuid)
      assert(numRemoteSSTFiles === 5, "shouldn't reuse old version 1 SST files" +
        " while uploading version 1 again") // 2 old + 3 new SST files
      assert(numRemoteLogFiles === 5, "shouldn't reuse old version 1 log files" +
@@ -1938,7 +2025,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
        "archive/00004.log" -> 4000
      )
      saveCheckpointFiles(fileManager, cpFiles2,
-       version = 2, numKeys = 1501, fileMapping, uuid)
+       version = 2, numKeys = 1501, fileMapping,
+       numInternalKeys = 0, uuid)
      assert(numRemoteSSTFiles === 6) // 1 new file over earlier 5 files
      assert(numRemoteLogFiles === 6) // 1 new file over earlier 6 files
      loadAndVerifyCheckpointFiles(fileManager, verificationDir,
@@ -1981,7 +2069,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
      }
      intercept[IOException] {
        saveCheckpointFiles(
-         fileManager, cpFiles, version = 1, numKeys = 101, new RocksDBFileMapping(), uuid)
+         fileManager, cpFiles, version = 1, numKeys = 101, new RocksDBFileMapping(),
+         numInternalKeys = 0, uuid)
      }
      assert(CreateAtomicTestManager.cancelCalledInCreateAtomic)
    }
@@ -2115,16 +2204,31 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
     // should always include sstFiles and numKeys
     checkJsonRoundtrip(
       RocksDBCheckpointMetadata(Seq.empty, 0L),
-      """{"sstFiles":[],"numKeys":0}"""
+      """{"sstFiles":[],"numKeys":0,"numInternalKeys":0}"""
     )
     // shouldn't include the "logFiles" field in json when it's empty
     checkJsonRoundtrip(
       RocksDBCheckpointMetadata(sstFiles, 12345678901234L),
-      """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"numKeys":12345678901234}"""
+      """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"numKeys":12345678901234,"numInternalKeys":0}"""
     )
     checkJsonRoundtrip(
       RocksDBCheckpointMetadata(sstFiles, logFiles, 12345678901234L),
-      """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"logFiles":[{"localFileName":"00001.log","dfsLogFileName":"00001-uuid.log","sizeBytes":12345678901234}],"numKeys":12345678901234}""")
+      """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"logFiles":[{"localFileName":"00001.log","dfsLogFileName":"00001-uuid.log","sizeBytes":12345678901234}],"numKeys":12345678901234,"numInternalKeys":0}""")
+
+    // verify format without including column family type
+    val cfMapping: Option[scala.collection.Map[String, Short]] = Some(Map("cf1" -> 1, "cf2" -> 2))
+    var cfTypeMap: Option[scala.collection.Map[String, Boolean]] = None
+    val maxCfId: Option[Short] = Some(2)
+    checkJsonRoundtrip(
+      RocksDBCheckpointMetadata(Seq.empty, 5L, 0L, cfMapping, cfTypeMap, maxCfId),
+      """{"sstFiles":[],"numKeys":5,"numInternalKeys":0,"columnFamilyMapping":{"cf1":1,"cf2":2},"maxColumnFamilyId":2}""")
+
+    // verify format including column family type and non-zero internal keys
+    cfTypeMap = Some(Map("cf1" -> true, "cf2" -> false))
+    checkJsonRoundtrip(
+      RocksDBCheckpointMetadata(Seq.empty, 3L, 2L, cfMapping, cfTypeMap, maxCfId),
+      """{"sstFiles":[],"numKeys":3,"numInternalKeys":2,"columnFamilyMapping":{"cf1":1,"cf2":2},"columnFamilyTypeMap":{"cf1":true,"cf2":false},"maxColumnFamilyId":2}""")
+    // scalastyle:on line.size.limit
   }
@@ -3374,8 +3478,14 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
         useColumnFamilies = useColumnFamilies)
       }
       db.load(version, versionToUniqueId.get(version))
+      if (useColumnFamilies) {
+        db.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, isInternal = false)
+      }
       func(db)
     } finally {
+      if (useColumnFamilies && db != null) {
+        db.removeColFamilyIfExists(StateStore.DEFAULT_COL_FAMILY_NAME)
+      }
       if (db != null) {
         db.close()
       }
@@ -3395,6 +3505,7 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
      version: Int,
      numKeys: Int,
      fileMapping: RocksDBFileMapping,
+     numInternalKeys: Int = 0,
      checkpointUniqueId: Option[String] = None): Unit = {
    val checkpointDir = Utils.createTempDir().getAbsolutePath // local dir to create checkpoints
    generateFiles(checkpointDir, fileToLengths)
@@ -3404,6 +3515,7 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
      checkpointDir,
      version,
      numKeys,
+     numInternalKeys,
      immutableFileMapping,
      checkpointUniqueId = checkpointUniqueId)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index adef3765c211f..08648148b4af4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -1873,22 +1873,26 @@ object StateStoreTestsHelper {
     iterator.map(rowPairToDataPair).toSet
   }
 
-  def remove(store: StateStore, condition: ((String, Int)) => Boolean): Unit = {
-    store.iterator().foreach { rowPair =>
-      if (condition(keyRowToData(rowPair.key))) store.remove(rowPair.key)
+  def remove(store: StateStore, condition: ((String, Int)) => Boolean,
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    store.iterator(colFamilyName).foreach { rowPair =>
+      if (condition(keyRowToData(rowPair.key))) store.remove(rowPair.key, colFamilyName)
     }
   }
 
-  def put(store: StateStore, key1: String, key2: Int, value: Int): Unit = {
-    store.put(dataToKeyRow(key1, key2), dataToValueRow(value))
+  def put(store: StateStore, key1: String, key2: Int, value: Int,
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    store.put(dataToKeyRow(key1, key2), dataToValueRow(value), colFamilyName)
   }
 
-  def merge(store: StateStore, key1: String, key2: Int, value: Int): Unit = {
-    store.merge(dataToKeyRow(key1, key2), dataToValueRow(value))
+  def merge(store: StateStore, key1: String, key2: Int, value: Int,
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    store.merge(dataToKeyRow(key1, key2), dataToValueRow(value), colFamilyName)
   }
 
-  def get(store: ReadStateStore, key1: String, key2: Int): Option[Int] = {
-    Option(store.get(dataToKeyRow(key1, key2))).map(valueRowToData)
+  def get(store: ReadStateStore, key1: String, key2: Int,
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Option[Int] = {
+    Option(store.get(dataToKeyRow(key1, key2), colFamilyName)).map(valueRowToData)
   }
 
   def newDir(): String = Utils.createTempDir().toString
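The helper overloads above default to StateStore.DEFAULT_COL_FAMILY_NAME, so existing call sites compile unchanged while new tests can target a specific column family. A small usage sketch (values here are hypothetical, the helpers are the ones defined above):

  // Sketch only: `store` is a StateStore created by the surrounding test harness.
  put(store, "a", 0, 1)                   // writes to the default column family
  put(store, "a", 0, 2, "testColFamily")  // writes to an explicitly named column family
  assert(get(store, "a", 0, "testColFamily") === Some(2))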
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala
index b188b92bdbb7c..d04573becf1ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala
@@ -228,15 +228,18 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest
       // - List state: 1 record in the primary, TTL, min, and count indexes
       // - Value state: 1 record in the primary, and 1 record in the TTL index
       //
-      // So in total, that amounts to 2t + 4 + 2 = 2t + 6 records.
+      // So in total, that amounts to 2t + 4 + 2 = 2t + 6 records. This covers both internal and
+      // non-internal column families. For non-internal column families alone, the total is
+      // t + 2 records.
       //
       // In this test, we have 2 unique keys, and each key occurs 3 times. Thus, the total number
-      // of keys in state is 2 * (2t + 6) where t = 3, which is 24.
+      // of keys in state is 2 * (2t + 6) where t = 3, which is 24. And the total number of
+      // records in the primary indexes is 2 * (t + 2) = 10.
       //
       // The number of updated rows is the total across the last time assertNumStateRows
       // was called, and we only update numRowsUpdated for primary key updates. We ran 6 batches
       // and each wrote 3 primary keys, so the total number of updated rows is 6 * 3 = 18.
-      assertNumStateRows(total = 24, updated = 18)
+      assertNumStateRows(total = 10, updated = 18)
     )
   }
 }
@@ -552,7 +555,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest
       //
       // It's important to check with assertNumStateRows, since the InputEvents
       // only return values for the current grouping key, not the entirety of RocksDB.
-      assertNumStateRows(total = 4, updated = 4),
+      assertNumStateRows(total = 1, updated = 4),
       // The k1 calls should both return no values. However, the k2 calls should return
       // one record each. We put these into one AddData call since we want them all to