diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index d450071f477..0d7354cb76a 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -101,6 +101,8 @@ private void initializeLifecycleManager(String appId) { if (celebornConf.clientFetchThrowsFetchFailure()) { MapOutputTrackerMaster mapOutputTracker = (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker(); + lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck( + taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId)); lifecycleManager.registerShuffleTrackerCallback( shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId)); } diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 3f2e4709750..77a6feef036 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -20,6 +20,8 @@ import java.io.IOException; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; import scala.Option; @@ -35,6 +37,9 @@ import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.scheduler.ShuffleMapStage; +import org.apache.spark.scheduler.TaskInfo; +import org.apache.spark.scheduler.TaskSchedulerImpl; +import org.apache.spark.scheduler.TaskSetManager; import org.apache.spark.sql.execution.UnsafeRowSerializer; import org.apache.spark.sql.execution.metric.SQLMetric; import org.apache.spark.storage.BlockManagerId; @@ -203,4 +208,47 @@ public static void cancelShuffle(int shuffleId, String reason) { logger.error("Can not get active SparkContext, skip cancelShuffle."); } } + + private static final DynFields.UnboundField> + TASK_ID_TO_TASK_SET_MANAGER_FIELD = + DynFields.builder() + .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager") + .defaultAlwaysNull() + .build(); + private static final DynFields.UnboundField> TASK_INFOS_FIELD = + DynFields.builder().hiddenImpl(TaskSetManager.class, "taskInfos").defaultAlwaysNull().build(); + + public static boolean taskAnotherAttemptRunning(long taskId) { + if (SparkContext$.MODULE$.getActive().nonEmpty()) { + TaskSchedulerImpl taskScheduler = + (TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler(); + ConcurrentHashMap taskIdToTaskSetManager = + TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get(); + TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId); + if (taskSetManager != null) { + HashMap taskInfos = TASK_INFOS_FIELD.bind(taskSetManager).get(); + TaskInfo taskInfo = taskInfos.get(taskId); + if (taskInfo != null) { + return taskSetManager.taskAttempts()[taskInfo.index()].exists( + ti -> { + if (ti.running() && ti.attemptNumber() != taskInfo.attemptNumber()) { + LOG.info("Another attempt of task {} is running: {}.", taskInfo, ti); + return true; + } else { + return false; + } + }); + } else { + LOG.error("Can not get TaskInfo for taskId: {}", taskId); + return false; + } + } else { + LOG.error("Can not get TaskSetManager for taskId: {}", taskId); + return false; + } + } else { + LOG.error("Can not get active SparkContext, skip cancelShuffle."); + return false; + } + } } diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index fb55e741886..589a2f2b90d 100644 --- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -98,6 +98,7 @@ class CelebornShuffleReader[K, C]( shuffleId, partitionId, encodedAttemptId, + context.taskAttemptId(), startMapIndex, endMapIndex, metricsCallback) diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 8541ad22324..d4434cb765f 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -143,7 +143,8 @@ private void initializeLifecycleManager(String appId) { if (celebornConf.clientFetchThrowsFetchFailure()) { MapOutputTrackerMaster mapOutputTracker = (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker(); - + lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck( + taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId)); lifecycleManager.registerShuffleTrackerCallback( shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId)); } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index d8a237bc459..5aecbc367c6 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn; +import java.util.HashMap; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; import scala.Option; @@ -33,6 +35,9 @@ import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.scheduler.ShuffleMapStage; +import org.apache.spark.scheduler.TaskInfo; +import org.apache.spark.scheduler.TaskSchedulerImpl; +import org.apache.spark.scheduler.TaskSetManager; import org.apache.spark.shuffle.ShuffleHandle; import org.apache.spark.shuffle.ShuffleReadMetricsReporter; import org.apache.spark.shuffle.ShuffleReader; @@ -319,4 +324,47 @@ public static void cancelShuffle(int shuffleId, String reason) { LOG.error("Can not get active SparkContext, skip cancelShuffle."); } } + + private static final DynFields.UnboundField> + TASK_ID_TO_TASK_SET_MANAGER_FIELD = + DynFields.builder() + .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager") + .defaultAlwaysNull() + .build(); + private static final DynFields.UnboundField> TASK_INFOS_FIELD = + DynFields.builder().hiddenImpl(TaskSetManager.class, "taskInfos").defaultAlwaysNull().build(); + + public static boolean taskAnotherAttemptRunning(long taskId) { + if (SparkContext$.MODULE$.getActive().nonEmpty()) { + TaskSchedulerImpl taskScheduler = + (TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler(); + ConcurrentHashMap taskIdToTaskSetManager = + TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get(); + TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId); + if (taskSetManager != null) { + HashMap taskInfos = TASK_INFOS_FIELD.bind(taskSetManager).get(); + TaskInfo taskInfo = taskInfos.get(taskId); + if (taskInfo != null) { + return taskSetManager.taskAttempts()[taskInfo.index()].exists( + ti -> { + if (ti.running() && ti.attemptNumber() != taskInfo.attemptNumber()) { + LOG.info("Another attempt of task {} is running: {}.", taskInfo, ti); + return true; + } else { + return false; + } + }); + } else { + LOG.error("Can not get TaskInfo for taskId: {}", taskId); + return false; + } + } else { + LOG.error("Can not get TaskSetManager for taskId: {}", taskId); + return false; + } + } else { + LOG.error("Can not get active SparkContext, skip cancelShuffle."); + return false; + } + } } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index a549464fefd..3cbe952c6cc 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -215,6 +215,7 @@ class CelebornShuffleReader[K, C]( handle.shuffleId, partitionId, encodedAttemptId, + context.taskAttemptId(), startMapIndex, endMapIndex, if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER @@ -371,7 +372,10 @@ class CelebornShuffleReader[K, C]( private def handleFetchExceptions(shuffleId: Int, partitionId: Int, ce: Throwable) = { if (throwsFetchFailure && - shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) { + shuffleClient.reportShuffleFetchFailure( + handle.shuffleId, + shuffleId, + context.taskAttemptId())) { logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", ce) throw new FetchFailedException( null, diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index efa9641f671..673c9382437 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -224,6 +224,7 @@ public CelebornInputStream readPartition( int shuffleId, int partitionId, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, MetricsCallback metricsCallback) @@ -233,6 +234,7 @@ public CelebornInputStream readPartition( shuffleId, partitionId, attemptNumber, + taskId, startMapIndex, endMapIndex, null, @@ -247,6 +249,7 @@ public abstract CelebornInputStream readPartition( int appShuffleId, int partitionId, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, ExceptionMaker exceptionMaker, @@ -276,7 +279,7 @@ public abstract int getShuffleId( * cleanup for spark app. It must be a sync call and make sure the cleanup is done, otherwise, * incorrect shuffle data can be fetched in re-run tasks */ - public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId); + public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId); /** * Report barrier task failure. When any barrier task fails, all (shuffle) output for that stage diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index de849a8847a..ea0c3c87de2 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -622,11 +622,12 @@ public int getShuffleId( } @Override - public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) { + public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) { PbReportShuffleFetchFailure pbReportShuffleFetchFailure = PbReportShuffleFetchFailure.newBuilder() .setAppShuffleId(appShuffleId) .setShuffleId(shuffleId) + .setTaskId(taskId) .build(); PbReportShuffleFetchFailureResponse pbReportShuffleFetchFailureResponse = lifecycleManagerRef.askSync( @@ -1752,6 +1753,7 @@ public CelebornInputStream readPartition( int appShuffleId, int partitionId, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, ExceptionMaker exceptionMaker, @@ -1790,6 +1792,7 @@ public CelebornInputStream readPartition( streamHandlers, mapAttempts, attemptNumber, + taskId, startMapIndex, endMapIndex, fetchExcludedWorkers, diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index dfbb7c502b8..5415607d9b6 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -56,6 +56,7 @@ public static CelebornInputStream create( ArrayList streamHandlers, int[] attempts, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, ConcurrentHashMap fetchExcludedWorkers, @@ -77,6 +78,7 @@ public static CelebornInputStream create( streamHandlers, attempts, attemptNumber, + taskId, startMapIndex, endMapIndex, fetchExcludedWorkers, @@ -130,6 +132,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private ArrayList streamHandlers; private int[] attempts; private final int attemptNumber; + private final long taskId; private final int startMapIndex; private final int endMapIndex; @@ -179,6 +182,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { ArrayList streamHandlers, int[] attempts, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, ConcurrentHashMap fetchExcludedWorkers, @@ -198,6 +202,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { } this.attempts = attempts; this.attemptNumber = attemptNumber; + this.taskId = taskId; this.startMapIndex = startMapIndex; this.endMapIndex = endMapIndex; this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled(); @@ -673,7 +678,7 @@ private boolean fillBuffer() throws IOException { ioe = new IOException(e); } if (exceptionMaker != null) { - if (shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId)) { + if (shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId, taskId)) { /* * [[ExceptionMaker.makeException]], for spark applications with celeborn.client.spark.fetch.throwsFetchFailure enabled will result in creating * a FetchFailedException; and that will make the TaskContext as failed with shuffle fetch issues - see SPARK-19276 for more. diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 47c27bcedcd..e7f9995d28a 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -441,8 +441,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends case pb: PbReportShuffleFetchFailure => val appShuffleId = pb.getAppShuffleId val shuffleId = pb.getShuffleId - logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId") - handleReportShuffleFetchFailure(context, appShuffleId, shuffleId) + val taskId = pb.getTaskId + logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId taskId $taskId") + handleReportShuffleFetchFailure(context, appShuffleId, shuffleId, taskId) case pb: PbReportBarrierStageAttemptFailure => val appShuffleId = pb.getAppShuffleId @@ -931,7 +932,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends private def handleReportShuffleFetchFailure( context: RpcCallContext, appShuffleId: Int, - shuffleId: Int): Unit = { + shuffleId: Int, + taskId: Long): Unit = { val shuffleIds = shuffleIdMapping.get(appShuffleId) if (shuffleIds == null) { @@ -941,9 +943,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleIds.synchronized { shuffleIds.find(e => e._2._1 == shuffleId) match { case Some((appShuffleIdentifier, (shuffleId, true))) => - logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId") - ret = invokeAppShuffleTrackerCallback(appShuffleId) - shuffleIds.put(appShuffleIdentifier, (shuffleId, false)) + if (invokeReportTaskShuffleFetchFailurePreCheck(taskId)) { + logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId") + ret = invokeAppShuffleTrackerCallback(appShuffleId) + shuffleIds.put(appShuffleIdentifier, (shuffleId, false)) + } else { + logInfo( + s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId taskId $taskId") + ret = false + } case Some((appShuffleIdentifier, (shuffleId, false))) => logInfo( s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId, " + @@ -1006,6 +1014,22 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } } + private def invokeReportTaskShuffleFetchFailurePreCheck(taskId: Long): Boolean = { + reportTaskShuffleFetchFailurePreCheck match { + case Some(precheck) => + try { + precheck.apply(taskId) + } catch { + case t: Throwable => + logError(t.toString) + false + } + case None => + throw new UnsupportedOperationException( + "unexpected! reportTaskShuffleFetchFailurePreCheck is not registered") + } + } + private def handleStageEnd(shuffleId: Int): Unit = { // check whether shuffle has registered if (!registeredShuffle.contains(shuffleId)) { @@ -1766,6 +1790,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends workerStatusTracker.registerWorkerStatusListener(workerStatusListener) } + @volatile private var reportTaskShuffleFetchFailurePreCheck + : Option[Function[java.lang.Long, Boolean]] = None + def registerReportTaskShuffleFetchFailurePreCheck(preCheck: Function[java.lang.Long, Boolean]) + : Unit = { + reportTaskShuffleFetchFailurePreCheck = Some(preCheck) + } + @volatile private var appShuffleTrackerCallback: Option[Consumer[Integer]] = None def registerShuffleTrackerCallback(callback: Consumer[Integer]): Unit = { appShuffleTrackerCallback = Some(callback) diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java index a190c3e1bc7..c724467ce41 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -130,6 +130,7 @@ public CelebornInputStream readPartition( int appShuffleId, int partitionId, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, ExceptionMaker exceptionMaker, @@ -179,7 +180,7 @@ public int getShuffleId( } @Override - public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) { + public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) { return true; } diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index d5631d5b055..8cd8bedf76d 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -158,6 +158,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { 1, 1, 0, + 0, Integer.MAX_VALUE, null, null, @@ -173,6 +174,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { 3, 1, 0, + 0, Integer.MAX_VALUE, null, null, diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index be09e83c7b7..98abed5e301 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -390,6 +390,7 @@ message PbGetShuffleIdResponse { message PbReportShuffleFetchFailure { int32 appShuffleId = 1; int32 shuffleId = 2; + int64 taskId = 3; } message PbReportShuffleFetchFailureResponse { diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index d62a78515c5..0ff646b8f8b 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -110,6 +110,7 @@ trait ReadWriteTestBase extends AnyFunSuite 0, 0, 0, + 0, Integer.MAX_VALUE, null, null,