2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -34,6 +34,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut
import org.roaringbitmap.RoaringBitmap

import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.celeborn.CelebornShuffleState
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
@@ -916,6 +917,7 @@ private[spark] class MapOutputTrackerMaster(
shuffleStatus.invalidateSerializedMergeOutputStatusCache()
}
}
CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId)
}

/**
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.celeborn.CelebornShuffleState
import org.apache.spark.executor.ExecutorBackend
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.internal.config._
@@ -419,6 +420,7 @@ object SparkEnv extends Logging {
if (isDriver) {
val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
envInstance.driverTmpDir = Some(sparkFilesDir)
CelebornShuffleState.init(envInstance)
}

envInstance
75 changes: 75 additions & 0 deletions core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.celeborn

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.SparkEnv
import org.apache.spark.internal.config.ConfigBuilder

object CelebornShuffleState {

private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ =
ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled")
.booleanConf
.createWithDefault(false)

private val CELEBORN_STAGE_RERUN_ENABLED =
ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled")
.withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure")
.booleanConf
.createWithDefault(false)

private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean()
private val stageRerunEnabled = new AtomicBoolean()
private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]()

// Called from SparkEnv.create (driver only)
def init(env: SparkEnv): Unit = {
// clean up any existing state, then initialize
skewShuffleIds.clear()

// use env.conf for all initialization, and not SQLConf
celebornOptimizeSkewedPartitionReadEnabled.set(
env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") &&
env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ))
stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED))
}

def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = {
skewShuffleIds.remove(shuffleId)
}

def registerCelebornSkewedShuffle(shuffleId: Int): Unit = {
skewShuffleIds.add(shuffleId)
}

def isCelebornSkewedShuffle(shuffleId: Int): Boolean = {
skewShuffleIds.contains(shuffleId)
}

def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = {
celebornOptimizeSkewedPartitionReadEnabled.get()
}

def celebornStageRerunEnabled: Boolean = {
stageRerunEnabled.get()
}

}
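For context, a minimal sketch of how this driver-side registry is exercised end to end. This is illustration only, not part of the patch: it assumes the Celeborn client jar is on the classpath (otherwise SparkContext creation fails when instantiating the shuffle manager), and the shuffle id is a placeholder.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.celeborn.CelebornShuffleState

val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("celeborn-state-sketch")
  // Any value containing "celeborn" enables the flag below; the actual
  // manager class requires the Celeborn client jar on the classpath.
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
  .set("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled", "true")

val sc = new SparkContext(conf) // SparkEnv.create runs CelebornShuffleState.init on the driver

assert(CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled)

// A Celeborn-aware reader registers a shuffle once it splits a skewed
// partition by chunk offsets; the DAGScheduler then treats that shuffle as
// indeterminate on fetch failure. Unregistration happens in
// MapOutputTrackerMaster.unregisterShuffle (see the first hunk above).
CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId = 0)
assert(CelebornShuffleState.isCelebornSkewedShuffle(0))
CelebornShuffleState.unregisterCelebornSkewedShuffle(0)
sc.stop()
```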
43 changes: 30 additions & 13 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -175,7 +175,8 @@ private[spark] class Executor(
val currentJars = new HashMap[String, Long]
val currentArchives = new HashMap[String, Long]
val urlClassLoader =
createClassLoader(currentJars, isStubbingEnabledForState(jobArtifactState.uuid))
createClassLoader(currentJars, isStubbingEnabledForState(jobArtifactState.uuid),
isDefaultState(jobArtifactState.uuid))
val replClassLoader = addReplClassLoaderIfNeeded(
urlClassLoader, jobArtifactState.replClassDirUri, jobArtifactState.uuid)
new IsolatedSessionState(
@@ -1029,7 +1030,8 @@ private[spark] class Executor(
*/
private def createClassLoader(
currentJars: HashMap[String, Long],
useStub: Boolean): MutableURLClassLoader = {
useStub: Boolean,
isDefaultSession: Boolean): MutableURLClassLoader = {
// Bootstrap the list of jars with the user class path.
val now = System.currentTimeMillis()
userClassPath.foreach { url =>
@@ -1041,43 +1043,57 @@
val urls = userClassPath.toArray ++ currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}
createClassLoader(urls, useStub)
createClassLoader(urls, useStub, isDefaultSession)
}

private def createClassLoader(urls: Array[URL], useStub: Boolean): MutableURLClassLoader = {
private def createClassLoader(urls: Array[URL],
useStub: Boolean,
isDefaultSession: Boolean): MutableURLClassLoader = {
logInfo(
s"Starting executor with user classpath (userClassPathFirst = $userClassPathFirst): " +
urls.mkString("'", ",", "'")
)

if (useStub) {
createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES))
createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES), isDefaultSession)
} else {
createClassLoader(urls)
createClassLoader(urls, isDefaultSession)
}
}

private def createClassLoader(urls: Array[URL]): MutableURLClassLoader = {
private def createClassLoader(urls: Array[URL],
isDefaultSession: Boolean): MutableURLClassLoader = {
// SPARK-51537: The isolated session must *inherit* the classloader from the default session,
// which has already included the global JARs specified via --jars. For Spark plugins, we
// cannot simply add the plugin JARs to the classpath of the isolated session, as this may
// cause the plugin to be reloaded, leading to potential conflicts or unexpected behavior.
val loader = if (isDefaultSession) systemLoader else defaultSessionState.replClassLoader
if (userClassPathFirst) {
new ChildFirstURLClassLoader(urls, systemLoader)
new ChildFirstURLClassLoader(urls, loader)
} else {
new MutableURLClassLoader(urls, systemLoader)
new MutableURLClassLoader(urls, loader)
}
}

private def createClassLoaderWithStub(
urls: Array[URL],
binaryName: Seq[String]): MutableURLClassLoader = {
binaryName: Seq[String],
isDefaultSession: Boolean): MutableURLClassLoader = {
// SPARK-51537: The isolated session must *inherit* the classloader from the default session,
// which has already included the global JARs specified via --jars. For Spark plugins, we
// cannot simply add the plugin JARs to the classpath of the isolated session, as this may
// cause the plugin to be reloaded, leading to potential conflicts or unexpected behavior.
val loader = if (isDefaultSession) systemLoader else defaultSessionState.replClassLoader
if (userClassPathFirst) {
// user -> (sys -> stub)
val stubClassLoader =
StubClassLoader(systemLoader, binaryName)
StubClassLoader(loader, binaryName)
new ChildFirstURLClassLoader(urls, stubClassLoader)
} else {
// sys -> user -> stub
val stubClassLoader =
StubClassLoader(null, binaryName)
new ChildFirstURLClassLoader(urls, stubClassLoader, systemLoader)
new ChildFirstURLClassLoader(urls, stubClassLoader, loader)
}
}
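To make the SPARK-51537 change above easier to review: the only behavioral difference is the parent handed to the new loader. Below is a standalone sketch of the delegation chains, under the assumption that plain URLClassLoader can stand in for Spark's MutableURLClassLoader (the child-first variant is not reproduced here).

```scala
import java.net.{URL, URLClassLoader}

// Parent selection as in the patch:
//   default session:  urls -> systemLoader
//   isolated session: urls -> defaultSessionState.replClassLoader -> ... -> systemLoader
// Inheriting the default session's REPL loader makes global --jars (and
// plugin classes) visible without loading them a second time.
def sessionLoader(
    urls: Array[URL],
    isDefaultSession: Boolean,
    systemLoader: ClassLoader,
    defaultSessionReplLoader: ClassLoader): ClassLoader = {
  val parent = if (isDefaultSession) systemLoader else defaultSessionReplLoader
  // Spark picks MutableURLClassLoader or ChildFirstURLClassLoader here,
  // depending on userClassPathFirst; both take this parent.
  new URLClassLoader(urls, parent)
}
```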

@@ -1176,7 +1192,8 @@
}
if (renewClassLoader) {
// Recreate the class loader to ensure all classes are updated.
state.urlClassLoader = createClassLoader(state.urlClassLoader.getURLs, useStub = true)
state.urlClassLoader = createClassLoader(state.urlClassLoader.getURLs,
useStub = true, isDefaultState(state.sessionUUID))
state.replClassLoader =
addReplClassLoaderIfNeeded(state.urlClassLoader, state.replClassDirUri, state.sessionUUID)
}
42 changes: 37 additions & 5 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture}

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.celeborn.CelebornShuffleState
import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
import org.apache.spark.internal.Logging
@@ -1480,7 +1481,10 @@ private[spark] class DAGScheduler(
// The operation here can make sure for the partially completed intermediate stage,
// `findMissingPartitions()` returns all partitions every time.
stage match {
case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
case sms: ShuffleMapStage if (stage.isIndeterminate ||
CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && !sms.isAvailable =>
logInfo(s"Unregistering shuffle output for stage ${stage.id}" +
s" shuffle ${sms.shuffleDep.shuffleId}")
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()
case _ =>
@@ -1854,7 +1858,18 @@
// tasks complete, they still count and we can mark the corresponding partitions as
// finished if the stage is determinate. Here we notify the task scheduler to skip running
// tasks for the same partition to save resource.
if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()) {
// CELEBORN-1856: if the stage is indeterminate, or the ShuffleMapStage is skewed and read
// by Celeborn chunk offsets, we should not call notifyPartitionCompletion; otherwise tasks
// for the same partition would be skipped, because TaskSetManager.dequeueTaskFromList skips
// any task for which TaskSetManager.successful(taskIndex) is true.
// TODO: ResultStage has result commit and other issues
val isCelebornShuffleIndeterminate = stage.isInstanceOf[ShuffleMapStage] &&
CelebornShuffleState.isCelebornSkewedShuffle(
stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId)
if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()
&& !isCelebornShuffleIndeterminate) {
taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
}

@@ -1909,7 +1924,7 @@
case smt: ShuffleMapTask =>
val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
// Ignore task completion for old attempt of indeterminate stage
val ignoreIndeterminate = stage.isIndeterminate &&
val ignoreIndeterminate = (stage.isIndeterminate || isCelebornShuffleIndeterminate) &&
task.stageAttemptId < stage.latestInfo.attemptNumber()
if (!ignoreIndeterminate) {
shuffleStage.pendingPartitions -= task.partitionId
@@ -1944,6 +1959,14 @@
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleIdToMapStage(shuffleId)

// CELEBORN-1139 supports reading skewed partitions by Celeborn chunk offsets, which makes
// the shuffle indeterminate, so abort the ResultStage directly here.
if (failedStage.isInstanceOf[ResultStage] && CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) {
val shuffleFailedReason = s"ResultStage ${failedStage.id} fetch failed and shuffle $shuffleId " +
s"is a skewed-partition read by Celeborn, so abort it."
abortStage(failedStage, shuffleFailedReason, None)
}

if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
@@ -2042,7 +2065,8 @@
// Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is
// guaranteed to be determinate, so the input data of the reducers will not change
// even if the map tasks are re-tried.
if (mapStage.isIndeterminate) {
val isCelebornSkewedShuffle = CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)
if (mapStage.isIndeterminate || isCelebornSkewedShuffle) {
// It's a little tricky to find all the succeeding stages of `mapStage`, because
// each stage only know its parents not children. Here we traverse the stages from
// the leaf nodes (the result stages of active jobs), and rollback all the stages
@@ -2053,7 +2077,15 @@

def collectStagesToRollback(stageChain: List[Stage]): Unit = {
if (stagesToRollback.contains(stageChain.head)) {
stageChain.drop(1).foreach(s => stagesToRollback += s)
stageChain.drop(1).foreach(s => {
stagesToRollback += s
s match {
case currentMapStage: ShuffleMapStage if isCelebornSkewedShuffle =>
CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId)
case _: ResultStage =>
// do nothing; a ResultStage over a Celeborn skewed-read shuffle is aborted on fetch failure
}
})
} else {
stageChain.head.parents.foreach { s =>
collectStagesToRollback(s :: stageChain)
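For intuition on why a Celeborn skewed-partition read forces the same rollback as an indeterminate stage: sub-partitions are carved out of the map output by physical chunk offsets, and a map retry can rewrite the same rows with different chunk boundaries, so a given reduce sub-task no longer covers the same rows. A toy sketch follows; all sizes are made up and this is not Celeborn's actual API.

```scala
import scala.collection.mutable.ArrayBuffer

// Greedily pack chunk indices into sub-partitions of roughly targetBytes each,
// mimicking an offset-based split of one skewed partition.
def splitByOffsets(chunkSizes: Seq[Long], targetBytes: Long): Seq[Seq[Int]] = {
  val groups = ArrayBuffer(ArrayBuffer.empty[Int])
  var acc = 0L
  chunkSizes.zipWithIndex.foreach { case (size, idx) =>
    if (acc + size > targetBytes && groups.last.nonEmpty) {
      groups += ArrayBuffer.empty[Int]
      acc = 0L
    }
    groups.last += idx
    acc += size
  }
  groups.map(_.toSeq).toSeq
}

// The first map attempt happens to produce these chunk sizes...
val attempt1 = splitByOffsets(Seq(64L, 64L, 32L, 96L), targetBytes = 128L)   // [[0, 1], [2, 3]]
// ...and a retry rewrites the same rows with different chunk boundaries:
val attempt2 = splitByOffsets(Seq(100L, 60L, 60L, 36L), targetBytes = 128L)  // [[0], [1, 2], [3]]
// The sub-partition layout changed, so reduce sub-task N reads different rows
// across attempts: exactly the indeterminacy the rollback above guards against.
assert(attempt1 != attempt2)
```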