diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 2a8e15cd09ccf..62b98a41b864d 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -964,8 +964,7 @@ public void reset() {
     updatePeakMemoryUsed();
     numKeys = 0;
     numValues = 0;
-    freeArray(longArray);
-    longArray = null;
+    freeInternalArray();
     while (dataPages.size() > 0) {
       MemoryBlock dataPage = dataPages.removeLast();
       freePage(dataPage);
@@ -976,6 +975,17 @@ public void reset() {
     pageCursor = 0;
   }
 
+  /**
+   * Free array memory to reduce the memory footprint in case of a fallback
+   * from a hash-based aggregation to the sort-based one.
+   */
+  public void freeInternalArray() {
+    if (longArray != null) {
+      freeArray(longArray);
+      longArray = null;
+    }
+  }
+
   /**
    * Grows the size of the hash table and re-hash everything.
    */
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 6f2d12e6b790a..c4587cc340b8f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -114,7 +114,9 @@ public UnsafeKVExternalSorter(
       // so that we can always reuse the pointer array.
       if (map.numValues() > pointerArray.size() / 4) {
         // Here we ask the map to allocate memory, so that the memory manager won't ask the map
-        // to spill, if the memory is not enough.
+        // to spill, if the memory is not enough. Also, we free the redundant internal map array
+        // to reduce the overall memory footprint.
+        map.freeInternalArray();
         pointerArray = map.allocateArray(map.numValues() * 4L);
       }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index b3370b6733d92..1a12162248af8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -242,7 +242,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
     try {
       val context = new TaskContextImpl(0, 0, 0, 0, 0, 1, taskMemoryManager, new Properties(), null)
       TaskContext.setTaskContext(context)
-      val expectedSpillSize = map.getTotalMemoryConsumption
+      val expectedSpillSize = expectedSpillSizeForMapWithDuplicateKeys(map)
       val sorter = new UnsafeKVExternalSorter(
         schema,
         schema,
@@ -267,7 +267,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
     try {
      val context = new TaskContextImpl(0, 0, 0, 0, 0, 1, taskMemoryManager, new Properties(), null)
       TaskContext.setTaskContext(context)
-      val expectedSpillSize = map1.getTotalMemoryConsumption + map2.getTotalMemoryConsumption
+      val expectedSpillSize = expectedSpillSizeForMapWithDuplicateKeys(map1) +
+        expectedSpillSizeForMapWithDuplicateKeys(map2)
       val sorter1 = new UnsafeKVExternalSorter(
         schema,
         schema,
@@ -309,4 +310,10 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
     }
     map
   }
+
+  private def expectedSpillSizeForMapWithDuplicateKeys(map: BytesToBytesMap): Long = {
+    val internalArrayMemoryUsed = Option(map.getArray).map(_.memoryBlock().size()).getOrElse(0L)
+    map.getTotalMemoryConsumption - internalArrayMemoryUsed
+  }
+
 }
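
For reference (not part of the patch): a minimal sketch of the memory accounting that the new freeInternalArray() method and the test helper above rely on. It assumes the patch is applied and that a populated BytesToBytesMap is passed in; the wrapper class and method names below are illustrative only.

import org.apache.spark.unsafe.map.BytesToBytesMap;

final class FreeInternalArraySketch {
  // Illustrative helper: `map` is assumed to be a populated BytesToBytesMap.
  static void illustrate(BytesToBytesMap map) {
    // getTotalMemoryConsumption() covers the data pages plus the internal long array.
    long arrayBytes = map.getArray().memoryBlock().size();
    long before = map.getTotalMemoryConsumption();

    // Freeing the internal array releases only the array's share; the data pages stay
    // allocated, which is why the suite subtracts the array size from the expected spill size.
    map.freeInternalArray();
    long after = map.getTotalMemoryConsumption();
    assert after == before - arrayBytes;

    // The call is idempotent: longArray is already null, so nothing is freed twice.
    map.freeInternalArray();
  }
}

Freeing the array before map.allocateArray(map.numValues() * 4L) means the old hash array and the new pointer array are not charged to the task at the same time, which is the footprint reduction the patch's comment describes.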