From 4d96e297d294d905885b653539051130bad9616c Mon Sep 17 00:00:00 2001 From: "Wang, Fei" Date: Thu, 14 Nov 2024 15:25:31 -0800 Subject: [PATCH 1/4] [CELEBORN-1720] Prevent stage re-run if another attempt of the task is running or successful task_id Align the LOG spark3 only Spark 2 (#30) ut (#31) revert ut Refine the check ut (#33) log Ut for spark utils (#36) comments record the reported shuffle fetch failure tasks (#42) nit Address comments from mridul (#44) * revert logger => LOG * taskScheduler instance lock and stage uniq id * docs * listener * spark 2 * comments * test --- .../task/reduce/CelebornShuffleConsumer.java | 1 + .../shuffle/celeborn/SparkShuffleManager.java | 5 + .../spark/shuffle/celeborn/SparkUtils.java | 143 ++++++++++++++++++ .../celeborn/CelebornShuffleReader.scala | 11 +- ...eFetchFailureReportTaskCleanListener.scala | 28 ++++ .../shuffle/celeborn/SparkShuffleManager.java | 4 + .../spark/shuffle/celeborn/SparkUtils.java | 143 ++++++++++++++++++ .../celeborn/CelebornShuffleReader.scala | 3 +- ...eFetchFailureReportTaskCleanListener.scala | 28 ++++ .../apache/celeborn/client/ShuffleClient.java | 5 +- .../celeborn/client/ShuffleClientImpl.java | 5 +- .../client/read/CelebornInputStream.java | 7 +- .../celeborn/client/LifecycleManager.scala | 42 ++++- .../celeborn/client/DummyShuffleClient.java | 3 +- .../client/WithShuffleClientSuite.scala | 2 + common/src/main/proto/TransportMessages.proto | 1 + .../shuffle/celeborn/SparkUtilsSuite.scala | 96 ++++++++++++ .../deploy/cluster/ReadWriteTestBase.scala | 1 + 18 files changed, 515 insertions(+), 13 deletions(-) create mode 100644 client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/ShuffleFetchFailureReportTaskCleanListener.scala create mode 100644 client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/ShuffleFetchFailureReportTaskCleanListener.scala create mode 100644 tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala diff --git 
a/client-mr/mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/CelebornShuffleConsumer.java b/client-mr/mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/CelebornShuffleConsumer.java index 778e243faec..3c3356006b3 100644 --- a/client-mr/mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/CelebornShuffleConsumer.java +++ b/client-mr/mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/CelebornShuffleConsumer.java @@ -155,6 +155,7 @@ public void incReadTime(long time) {} reduceId.getTaskID().getId(), reduceId.getId(), 0, + 0, Integer.MAX_VALUE, metricsCallback); CelebornShuffleFetcher shuffleReader = diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index ce0093bd3f9..67f29ecedee 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -101,6 +101,11 @@ private void initializeLifecycleManager(String appId) { if (celebornConf.clientStageRerunEnabled()) { MapOutputTrackerMaster mapOutputTracker = (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker(); + + lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck( + taskId -> !SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId)); + SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener()); + lifecycleManager.registerShuffleTrackerCallback( shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId)); } diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 3f2e4709750..7ac0f658310 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -20,12 +20,19 @@ import java.io.IOException; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; +import java.util.stream.Collectors; import scala.Option; import scala.Some; import scala.Tuple2; +import com.google.common.annotations.VisibleForTesting; import org.apache.spark.BarrierTaskContext; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; @@ -35,6 +42,10 @@ import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.scheduler.ShuffleMapStage; +import org.apache.spark.scheduler.SparkListener; +import org.apache.spark.scheduler.TaskInfo; +import org.apache.spark.scheduler.TaskSchedulerImpl; +import org.apache.spark.scheduler.TaskSetManager; import org.apache.spark.sql.execution.UnsafeRowSerializer; import org.apache.spark.sql.execution.metric.SQLMetric; import org.apache.spark.storage.BlockManagerId; @@ -43,6 +54,7 @@ import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.util.JavaUtils; import org.apache.celeborn.common.util.Utils; import org.apache.celeborn.reflect.DynFields; @@ -203,4 +215,135 @@ public static void cancelShuffle(int shuffleId, String reason) { logger.error("Can not get active SparkContext, skip cancelShuffle."); } } + + private static final DynFields.UnboundField> + TASK_ID_TO_TASK_SET_MANAGER_FIELD = + DynFields.builder() + .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager") + .defaultAlwaysNull() + .build(); + private static final DynFields.UnboundField> + TASK_INFOS_FIELD = + DynFields.builder() + .hiddenImpl(TaskSetManager.class, "taskInfos") + .defaultAlwaysNull() + .build(); + + /** + 
* TaskSetManager - it is not designed to be used outside the spark scheduler. Please be careful. + */ + @VisibleForTesting + protected static TaskSetManager getTaskSetManager(TaskSchedulerImpl taskScheduler, long taskId) { + synchronized (taskScheduler) { + ConcurrentHashMap taskIdToTaskSetManager = + TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get(); + return taskIdToTaskSetManager.get(taskId); + } + } + + @VisibleForTesting + protected static Tuple2> getTaskAttempts( + TaskSetManager taskSetManager, long taskId) { + if (taskSetManager != null) { + scala.Option taskInfoOption = + TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId); + if (taskInfoOption.isDefined()) { + TaskInfo taskInfo = taskInfoOption.get(); + List taskAttempts = + scala.collection.JavaConverters.asJavaCollectionConverter( + taskSetManager.taskAttempts()[taskInfo.index()]) + .asJavaCollection().stream() + .collect(Collectors.toList()); + return Tuple2.apply(taskInfo, taskAttempts); + } else { + logger.error("Can not get TaskInfo for taskId: {}", taskId); + return null; + } + } else { + logger.error("Can not get TaskSetManager for taskId: {}", taskId); + return null; + } + } + + protected static Map> reportedStageShuffleFetchFailureTaskIds = + JavaUtils.newConcurrentHashMap(); + + protected static void removeStageReportedShuffleFetchFailureTaskIds( + int stageId, int stageAttemptId) { + reportedStageShuffleFetchFailureTaskIds.remove(stageId + "-" + stageAttemptId); + } + + /** + * Only used to check for the shuffle fetch failure task whether another attempt is running or + * successful. If another attempt(excluding the reported shuffle fetch failure tasks in current + * stage) is running or successful, return true. Otherwise, return false. 
+ */ + public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { + SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null); + if (sparkContext == null) { + logger.error("Can not get active SparkContext."); + return false; + } + TaskSchedulerImpl taskScheduler = (TaskSchedulerImpl) sparkContext.taskScheduler(); + synchronized (taskScheduler) { + TaskSetManager taskSetManager = getTaskSetManager(taskScheduler, taskId); + if (taskSetManager != null) { + int stageId = taskSetManager.stageId(); + int stageAttemptId = taskSetManager.taskSet().stageAttemptId(); + String stageUniqId = stageId + "-" + stageAttemptId; + Set reportedStageTaskIds = + reportedStageShuffleFetchFailureTaskIds.computeIfAbsent( + stageUniqId, k -> new HashSet<>()); + reportedStageTaskIds.add(taskId); + + Tuple2> taskAttempts = getTaskAttempts(taskSetManager, taskId); + + if (taskAttempts == null) return false; + + TaskInfo taskInfo = taskAttempts._1(); + for (TaskInfo ti : taskAttempts._2()) { + if (ti.taskId() != taskId) { + if (reportedStageTaskIds.contains(ti.taskId())) { + logger.info( + "StageId={} index={} taskId={} attempt={} another attempt {} has reported shuffle fetch failure, ignore it.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber(), + ti.attemptNumber()); + } else if (ti.successful()) { + logger.info( + "StageId={} index={} taskId={} attempt={} another attempt {} is successful.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber(), + ti.attemptNumber()); + return true; + } else if (ti.running()) { + logger.info( + "StageId={} index={} taskId={} attempt={} another attempt {} is running.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber(), + ti.attemptNumber()); + return true; + } + } + } + return false; + } else { + logger.error("Can not get TaskSetManager for taskId: {}", taskId); + return false; + } + } + } + + public static void addSparkListener(SparkListener listener) { + SparkContext 
sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null); + if (sparkContext != null) { + sparkContext.addSparkListener(listener); + } + } } diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index fb55e741886..a60c68f3dd1 100644 --- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -98,6 +98,7 @@ class CelebornShuffleReader[K, C]( shuffleId, partitionId, encodedAttemptId, + context.taskAttemptId(), startMapIndex, endMapIndex, metricsCallback) @@ -124,7 +125,10 @@ class CelebornShuffleReader[K, C]( exceptionRef.get() match { case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => if (handle.throwsFetchFailure && - shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) { + shuffleClient.reportShuffleFetchFailure( + handle.shuffleId, + shuffleId, + context.taskAttemptId())) { throw new FetchFailedException( null, handle.shuffleId, @@ -158,7 +162,10 @@ class CelebornShuffleReader[K, C]( } catch { case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => if (handle.throwsFetchFailure && - shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) { + shuffleClient.reportShuffleFetchFailure( + handle.shuffleId, + shuffleId, + context.taskAttemptId())) { throw new FetchFailedException( null, handle.shuffleId, diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/ShuffleFetchFailureReportTaskCleanListener.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/ShuffleFetchFailureReportTaskCleanListener.scala new file mode 100644 index 00000000000..2e85f969cb9 --- /dev/null +++ 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/ShuffleFetchFailureReportTaskCleanListener.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn + +import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} + +class ShuffleFetchFailureReportTaskCleanListener extends SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + SparkUtils.removeStageReportedShuffleFetchFailureTaskIds( + stageCompleted.stageInfo.stageId, + stageCompleted.stageInfo.attemptNumber()) + } +} diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index af3c400ec7c..df28143c674 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -144,6 +144,10 @@ private void initializeLifecycleManager(String appId) { MapOutputTrackerMaster mapOutputTracker = (MapOutputTrackerMaster) 
SparkEnv.get().mapOutputTracker(); + lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck( + taskId -> !SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId)); + SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener()); + lifecycleManager.registerShuffleTrackerCallback( shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId)); } diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index d8a237bc459..6c2e5120e04 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -17,12 +17,19 @@ package org.apache.spark.shuffle.celeborn; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; +import java.util.stream.Collectors; import scala.Option; import scala.Some; import scala.Tuple2; +import com.google.common.annotations.VisibleForTesting; import org.apache.spark.BarrierTaskContext; import org.apache.spark.MapOutputTrackerMaster; import org.apache.spark.SparkConf; @@ -33,6 +40,10 @@ import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.scheduler.ShuffleMapStage; +import org.apache.spark.scheduler.SparkListener; +import org.apache.spark.scheduler.TaskInfo; +import org.apache.spark.scheduler.TaskSchedulerImpl; +import org.apache.spark.scheduler.TaskSetManager; import org.apache.spark.shuffle.ShuffleHandle; import org.apache.spark.shuffle.ShuffleReadMetricsReporter; import org.apache.spark.shuffle.ShuffleReader; @@ -46,6 +57,7 @@ import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.common.CelebornConf; +import 
org.apache.celeborn.common.util.JavaUtils; import org.apache.celeborn.reflect.DynConstructors; import org.apache.celeborn.reflect.DynFields; import org.apache.celeborn.reflect.DynMethods; @@ -319,4 +331,135 @@ public static void cancelShuffle(int shuffleId, String reason) { LOG.error("Can not get active SparkContext, skip cancelShuffle."); } } + + private static final DynFields.UnboundField> + TASK_ID_TO_TASK_SET_MANAGER_FIELD = + DynFields.builder() + .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager") + .defaultAlwaysNull() + .build(); + private static final DynFields.UnboundField> + TASK_INFOS_FIELD = + DynFields.builder() + .hiddenImpl(TaskSetManager.class, "taskInfos") + .defaultAlwaysNull() + .build(); + + /** + * TaskSetManager - it is not designed to be used outside the spark scheduler. Please be careful. + */ + @VisibleForTesting + protected static TaskSetManager getTaskSetManager(TaskSchedulerImpl taskScheduler, long taskId) { + synchronized (taskScheduler) { + ConcurrentHashMap taskIdToTaskSetManager = + TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get(); + return taskIdToTaskSetManager.get(taskId); + } + } + + @VisibleForTesting + protected static Tuple2> getTaskAttempts( + TaskSetManager taskSetManager, long taskId) { + if (taskSetManager != null) { + scala.Option taskInfoOption = + TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId); + if (taskInfoOption.isDefined()) { + TaskInfo taskInfo = taskInfoOption.get(); + List taskAttempts = + scala.collection.JavaConverters.asJavaCollectionConverter( + taskSetManager.taskAttempts()[taskInfo.index()]) + .asJavaCollection().stream() + .collect(Collectors.toList()); + return Tuple2.apply(taskInfo, taskAttempts); + } else { + LOG.error("Can not get TaskInfo for taskId: {}", taskId); + return null; + } + } else { + LOG.error("Can not get TaskSetManager for taskId: {}", taskId); + return null; + } + } + + protected static Map> reportedStageShuffleFetchFailureTaskIds = + 
JavaUtils.newConcurrentHashMap(); + + protected static void removeStageReportedShuffleFetchFailureTaskIds( + int stageId, int stageAttemptId) { + reportedStageShuffleFetchFailureTaskIds.remove(stageId + "-" + stageAttemptId); + } + + /** + * Only used to check for the shuffle fetch failure task whether another attempt is running or + * successful. If another attempt(excluding the reported shuffle fetch failure tasks in current + * stage) is running or successful, return true. Otherwise, return false. + */ + public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { + SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null); + if (sparkContext == null) { + LOG.error("Can not get active SparkContext."); + return false; + } + TaskSchedulerImpl taskScheduler = (TaskSchedulerImpl) sparkContext.taskScheduler(); + synchronized (taskScheduler) { + TaskSetManager taskSetManager = getTaskSetManager(taskScheduler, taskId); + if (taskSetManager != null) { + int stageId = taskSetManager.stageId(); + int stageAttemptId = taskSetManager.taskSet().stageAttemptId(); + String stageUniqId = stageId + "-" + stageAttemptId; + Set reportedStageTaskIds = + reportedStageShuffleFetchFailureTaskIds.computeIfAbsent( + stageUniqId, k -> new HashSet<>()); + reportedStageTaskIds.add(taskId); + + Tuple2> taskAttempts = getTaskAttempts(taskSetManager, taskId); + + if (taskAttempts == null) return false; + + TaskInfo taskInfo = taskAttempts._1(); + for (TaskInfo ti : taskAttempts._2()) { + if (ti.taskId() != taskId) { + if (reportedStageTaskIds.contains(ti.taskId())) { + LOG.info( + "StageId={} index={} taskId={} attempt={} another attempt {} has reported shuffle fetch failure, ignore it.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber(), + ti.attemptNumber()); + } else if (ti.successful()) { + LOG.info( + "StageId={} index={} taskId={} attempt={} another attempt {} is successful.", + stageId, + taskInfo.index(), + taskId, + 
taskInfo.attemptNumber(), + ti.attemptNumber()); + return true; + } else if (ti.running()) { + LOG.info( + "StageId={} index={} taskId={} attempt={} another attempt {} is running.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber(), + ti.attemptNumber()); + return true; + } + } + } + return false; + } else { + LOG.error("Can not get TaskSetManager for taskId: {}", taskId); + return false; + } + } + } + + public static void addSparkListener(SparkListener listener) { + SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null); + if (sparkContext != null) { + sparkContext.addSparkListener(listener); + } + } } diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index e3e1c9198f0..ad802f8cef9 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -215,6 +215,7 @@ class CelebornShuffleReader[K, C]( handle.shuffleId, partitionId, encodedAttemptId, + context.taskAttemptId(), startMapIndex, endMapIndex, if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER @@ -375,7 +376,7 @@ class CelebornShuffleReader[K, C]( partitionId: Int, ce: Throwable) = { if (throwsFetchFailure && - shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId)) { + shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId, context.taskAttemptId())) { logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", ce) throw new FetchFailedException( null, diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/ShuffleFetchFailureReportTaskCleanListener.scala 
b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/ShuffleFetchFailureReportTaskCleanListener.scala new file mode 100644 index 00000000000..2e85f969cb9 --- /dev/null +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/ShuffleFetchFailureReportTaskCleanListener.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.celeborn + +import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} + +class ShuffleFetchFailureReportTaskCleanListener extends SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + SparkUtils.removeStageReportedShuffleFetchFailureTaskIds( + stageCompleted.stageInfo.stageId, + stageCompleted.stageInfo.attemptNumber()) + } +} diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index efa9641f671..673c9382437 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -224,6 +224,7 @@ public CelebornInputStream readPartition( int shuffleId, int partitionId, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, MetricsCallback metricsCallback) @@ -233,6 +234,7 @@ public CelebornInputStream readPartition( shuffleId, partitionId, attemptNumber, + taskId, startMapIndex, endMapIndex, null, @@ -247,6 +249,7 @@ public abstract CelebornInputStream readPartition( int appShuffleId, int partitionId, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, ExceptionMaker exceptionMaker, @@ -276,7 +279,7 @@ public abstract int getShuffleId( * cleanup for spark app. It must be a sync call and make sure the cleanup is done, otherwise, * incorrect shuffle data can be fetched in re-run tasks */ - public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId); + public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId); /** * Report barrier task failure. 
When any barrier task fails, all (shuffle) output for that stage diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index cde9fc043c2..48412f81f12 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -630,11 +630,12 @@ public int getShuffleId( } @Override - public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) { + public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) { PbReportShuffleFetchFailure pbReportShuffleFetchFailure = PbReportShuffleFetchFailure.newBuilder() .setAppShuffleId(appShuffleId) .setShuffleId(shuffleId) + .setTaskId(taskId) .build(); PbReportShuffleFetchFailureResponse pbReportShuffleFetchFailureResponse = lifecycleManagerRef.askSync( @@ -1845,6 +1846,7 @@ public CelebornInputStream readPartition( int appShuffleId, int partitionId, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, ExceptionMaker exceptionMaker, @@ -1883,6 +1885,7 @@ public CelebornInputStream readPartition( streamHandlers, mapAttempts, attemptNumber, + taskId, startMapIndex, endMapIndex, fetchExcludedWorkers, diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 2b02da2baf4..fc295a746c8 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -55,6 +55,7 @@ public static CelebornInputStream create( ArrayList streamHandlers, int[] attempts, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, ConcurrentHashMap fetchExcludedWorkers, @@ -76,6 +77,7 @@ public static CelebornInputStream create( streamHandlers, attempts, attemptNumber, + taskId, 
startMapIndex, endMapIndex, fetchExcludedWorkers, @@ -129,6 +131,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private ArrayList streamHandlers; private int[] attempts; private final int attemptNumber; + private final long taskId; private final int startMapIndex; private final int endMapIndex; @@ -178,6 +181,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { ArrayList streamHandlers, int[] attempts, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, ConcurrentHashMap fetchExcludedWorkers, @@ -197,6 +201,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { } this.attempts = attempts; this.attemptNumber = attemptNumber; + this.taskId = taskId; this.startMapIndex = startMapIndex; this.endMapIndex = endMapIndex; this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled(); @@ -654,7 +659,7 @@ private boolean fillBuffer() throws IOException { ioe = new IOException(e); } if (exceptionMaker != null) { - if (shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId)) { + if (shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId, taskId)) { /* * [[ExceptionMaker.makeException]], for spark applications with celeborn.client.spark.fetch.throwsFetchFailure enabled will result in creating * a FetchFailedException; and that will make the TaskContext as failed with shuffle fetch issues - see SPARK-19276 for more. 
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 4f69a05bc34..94b5a0676bf 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -445,8 +445,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends case pb: PbReportShuffleFetchFailure => val appShuffleId = pb.getAppShuffleId val shuffleId = pb.getShuffleId - logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId") - handleReportShuffleFetchFailure(context, appShuffleId, shuffleId) + val taskId = pb.getTaskId + logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId taskId $taskId") + handleReportShuffleFetchFailure(context, appShuffleId, shuffleId, taskId) case pb: PbReportBarrierStageAttemptFailure => val appShuffleId = pb.getAppShuffleId @@ -935,7 +936,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends private def handleReportShuffleFetchFailure( context: RpcCallContext, appShuffleId: Int, - shuffleId: Int): Unit = { + shuffleId: Int, + taskId: Long): Unit = { val shuffleIds = shuffleIdMapping.get(appShuffleId) if (shuffleIds == null) { @@ -945,9 +947,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleIds.synchronized { shuffleIds.find(e => e._2._1 == shuffleId) match { case Some((appShuffleIdentifier, (shuffleId, true))) => - logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId") - ret = invokeAppShuffleTrackerCallback(appShuffleId) - shuffleIds.put(appShuffleIdentifier, (shuffleId, false)) + if (invokeReportTaskShuffleFetchFailurePreCheck(taskId)) { + logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId") + ret = 
invokeAppShuffleTrackerCallback(appShuffleId) + shuffleIds.put(appShuffleIdentifier, (shuffleId, false)) + } else { + logInfo( + s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId taskId $taskId") + ret = false + } case Some((appShuffleIdentifier, (shuffleId, false))) => logInfo( s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId, " + @@ -1010,6 +1018,20 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } } + private def invokeReportTaskShuffleFetchFailurePreCheck(taskId: Long): Boolean = { + reportTaskShuffleFetchFailurePreCheck match { + case Some(preCheck) => + try { + preCheck.apply(taskId) + } catch { + case t: Throwable => + logError(s"Error preChecking the shuffle fetch failure reported by task: $taskId", t) + false + } + case None => true + } + } + private def handleStageEnd(shuffleId: Int): Unit = { // check whether shuffle has registered if (!registeredShuffle.contains(shuffleId)) { @@ -1770,6 +1792,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends workerStatusTracker.registerWorkerStatusListener(workerStatusListener) } + @volatile private var reportTaskShuffleFetchFailurePreCheck + : Option[java.util.function.Function[java.lang.Long, Boolean]] = None + def registerReportTaskShuffleFetchFailurePreCheck(preCheck: java.util.function.Function[ + java.lang.Long, + Boolean]): Unit = { + reportTaskShuffleFetchFailurePreCheck = Some(preCheck) + } + @volatile private var appShuffleTrackerCallback: Option[Consumer[Integer]] = None def registerShuffleTrackerCallback(callback: Consumer[Integer]): Unit = { appShuffleTrackerCallback = Some(callback) diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java index a78e34be6f0..77a9c784c4a 100644 --- 
a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -130,6 +130,7 @@ public CelebornInputStream readPartition( int appShuffleId, int partitionId, int attemptNumber, + long taskId, int startMapIndex, int endMapIndex, ExceptionMaker exceptionMaker, @@ -179,7 +180,7 @@ public int getShuffleId( } @Override - public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) { + public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) { return true; } diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index d5631d5b055..8cd8bedf76d 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -158,6 +158,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { 1, 1, 0, + 0, Integer.MAX_VALUE, null, null, @@ -173,6 +174,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { 3, 1, 0, + 0, Integer.MAX_VALUE, null, null, diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index d7b439a68aa..8aa59bcb29b 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -391,6 +391,7 @@ message PbGetShuffleIdResponse { message PbReportShuffleFetchFailure { int32 appShuffleId = 1; int32 shuffleId = 2; + int64 taskId = 3; } message PbReportShuffleFetchFailureResponse { diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala new file mode 100644 index 00000000000..2edfaf898e4 --- /dev/null +++ 
b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn + +import org.apache.spark.SparkConf +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.sql.SparkSession +import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Futures.{interval, timeout} +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.time.SpanSugar.convertIntToGrainOfTime + +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.protocol.ShuffleMode +import org.apache.celeborn.tests.spark.SparkTestBase + +class SparkUtilsSuite extends AnyFunSuite + with SparkTestBase + with BeforeAndAfterEach { + + override def beforeEach(): Unit = { + ShuffleClient.reset() + } + + override def afterEach(): Unit = { + System.gc() + } + + test("check if fetch failure task another attempt is running or successful") { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") + val sparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) 
+ .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .getOrCreate() + + try { + val sc = sparkSession.sparkContext + val jobThread = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100, 2) + .repartition(1) + .mapPartitions { iter => + Thread.sleep(3000) + iter + }.collect() + } catch { + case _: InterruptedException => + } + } + } + jobThread.start() + + val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + eventually(timeout(3.seconds), interval(100.milliseconds)) { + val taskId = 0 + val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId) + assert(taskSetManager != null) + assert(SparkUtils.getTaskAttempts(taskSetManager, taskId)._2.size() == 1) + assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId)) + assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 1) + } + + sparkSession.sparkContext.cancelAllJobs() + + jobThread.interrupt() + + eventually(timeout(3.seconds), interval(100.milliseconds)) { + assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 0) + } + } finally { + sparkSession.stop() + } + } +} diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index d62a78515c5..0ff646b8f8b 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -110,6 +110,7 @@ trait ReadWriteTestBase extends AnyFunSuite 0, 0, 0, + 0, Integer.MAX_VALUE, null, null, From c35172c8e55b97c410075c02d6ba13c4f9055466 Mon Sep 17 00:00:00 2001 From: Fei Wang Date: Fri, 27 
Dec 2024 15:41:15 -0800 Subject: [PATCH 2/4] fetch failure integration testing (#46) --- .../spark/CelebornFetchFailureSuite.scala | 66 ++----------------- .../celeborn/tests/spark/SparkTestBase.scala | 62 ++++++++++++++++- .../shuffle/celeborn/SparkUtilsSuite.scala | 26 +++++--- 3 files changed, 84 insertions(+), 70 deletions(-) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala index 1703ad0b8f2..dd0f3840149 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala @@ -17,22 +17,19 @@ package org.apache.celeborn.tests.spark -import java.io.{File, IOException} +import java.io.IOException import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{BarrierTaskContext, ShuffleDependency, SparkConf, SparkContextHelper, SparkException, TaskContext} import org.apache.spark.celeborn.ExceptionMakerHelper import org.apache.spark.rdd.RDD -import org.apache.spark.shuffle.ShuffleHandle -import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager} +import org.apache.spark.shuffle.celeborn.{SparkShuffleManager, SparkUtils, TestCelebornShuffleManager} import org.apache.spark.sql.SparkSession import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.apache.celeborn.client.ShuffleClient -import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.protocol.ShuffleMode -import org.apache.celeborn.service.deploy.worker.Worker class CelebornFetchFailureSuite extends AnyFunSuite with SparkTestBase @@ -46,57 +43,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite System.gc() } - var workerDirs: Seq[String] = Seq.empty - - override def 
createWorker(map: Map[String, String]): Worker = { - val storageDir = createTmpDir() - this.synchronized { - workerDirs = workerDirs :+ storageDir - } - super.createWorker(map, storageDir) - } - - class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook { - var executed: AtomicBoolean = new AtomicBoolean(false) - val lock = new Object - - override def exec( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext): Unit = { - if (executed.get() == true) return - - lock.synchronized { - handle match { - case h: CelebornShuffleHandle[_, _, _] => { - val appUniqueId = h.appUniqueId - val shuffleClient = ShuffleClient.get( - h.appUniqueId, - h.lifecycleManagerHost, - h.lifecycleManagerPort, - conf, - h.userIdentifier, - h.extension) - val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) - val allFiles = workerDirs.map(dir => { - new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") - }) - val datafile = allFiles.filter(_.exists()) - .flatMap(_.listFiles().iterator).headOption - datafile match { - case Some(file) => file.delete() - case None => throw new RuntimeException("unexpected, there must be some data file" + - s" under ${workerDirs.mkString(",")}") - } - } - case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here") - } - executed.set(true) - } - } - } - test("celeborn spark integration test - Fetch Failure") { if (Spark3OrNewer) { val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") @@ -111,7 +57,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderGetHook(celebornConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) TestCelebornShuffleManager.registerReaderGetHook(hook) val value = Range(1, 10000).mkString(",") @@ -184,7 +130,7 @@ class 
CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderGetHook(celebornConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) TestCelebornShuffleManager.registerReaderGetHook(hook) import sparkSession.implicits._ @@ -215,7 +161,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderGetHook(celebornConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) TestCelebornShuffleManager.registerReaderGetHook(hook) val sc = sparkSession.sparkContext @@ -255,7 +201,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderGetHook(celebornConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) TestCelebornShuffleManager.registerReaderGetHook(hook) val sc = sparkSession.sparkContext diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala index c92ec4c9d3c..999abc053d3 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala @@ -17,19 +17,26 @@ package org.apache.celeborn.tests.spark +import java.io.File +import java.util.concurrent.atomic.AtomicBoolean + import scala.util.Random -import org.apache.spark.SPARK_VERSION -import org.apache.spark.SparkConf +import org.apache.spark.{SPARK_VERSION, SparkConf, TaskContext} +import org.apache.spark.shuffle.ShuffleHandle +import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkUtils} import org.apache.spark.sql.SparkSession import 
org.apache.spark.sql.internal.SQLConf import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.funsuite.AnyFunSuite +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.CelebornConf._ import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.protocol.ShuffleMode import org.apache.celeborn.service.deploy.MiniClusterFeature +import org.apache.celeborn.service.deploy.worker.Worker trait SparkTestBase extends AnyFunSuite with Logging with MiniClusterFeature with BeforeAndAfterAll with BeforeAndAfterEach { @@ -52,6 +59,16 @@ trait SparkTestBase extends AnyFunSuite shutdownMiniCluster() } + var workerDirs: Seq[String] = Seq.empty + + override def createWorker(map: Map[String, String]): Worker = { + val storageDir = createTmpDir() + this.synchronized { + workerDirs = workerDirs :+ storageDir + } + super.createWorker(map, storageDir) + } + def updateSparkConf(sparkConf: SparkConf, mode: ShuffleMode): SparkConf = { sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") sparkConf.set( @@ -98,4 +115,45 @@ trait SparkTestBase extends AnyFunSuite val outMap = result.collect().map(row => row.getString(0) -> row.getLong(1)).toMap outMap } + + class ShuffleReaderFetchFailureGetHook(conf: CelebornConf) extends ShuffleManagerHook { + var executed: AtomicBoolean = new AtomicBoolean(false) + val lock = new Object + + override def exec( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): Unit = { + if (executed.get() == true) return + + lock.synchronized { + handle match { + case h: CelebornShuffleHandle[_, _, _] => { + val appUniqueId = h.appUniqueId + val shuffleClient = ShuffleClient.get( + h.appUniqueId, + h.lifecycleManagerHost, + h.lifecycleManagerPort, + conf, + h.userIdentifier, + h.extension) + val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) + 
val allFiles = workerDirs.map(dir => { + new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") + }) + val datafile = allFiles.filter(_.exists()) + .flatMap(_.listFiles().iterator).sortBy(_.getName).headOption + datafile match { + case Some(file) => file.delete() + case None => throw new RuntimeException("unexpected, there must be some data file" + + s" under ${workerDirs.mkString(",")}") + } + } + case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here") + } + executed.set(true) + } + } + } } diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index 2edfaf898e4..2d753ff7b17 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn +import scala.collection.JavaConverters._ + import org.apache.spark.SparkConf import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.sql.SparkSession @@ -54,13 +56,19 @@ class SparkUtilsSuite extends AnyFunSuite "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") .getOrCreate() + val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) + TestCelebornShuffleManager.registerReaderGetHook(hook) + try { val sc = sparkSession.sparkContext val jobThread = new Thread { override def run(): Unit = { try { - sc.parallelize(1 to 100, 2) - .repartition(1) + val value = Range(1, 10000).mkString(",") + sc.parallelize(1 to 10000, 2) + .map { i => (i, value) } + .groupByKey(10) .mapPartitions { iter => Thread.sleep(3000) iter @@ -73,13 +81,15 @@ class SparkUtilsSuite extends AnyFunSuite jobThread.start() val taskScheduler = 
sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] - eventually(timeout(3.seconds), interval(100.milliseconds)) { - val taskId = 0 - val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId) + eventually(timeout(30.seconds), interval(0.milliseconds)) { + assert(hook.executed.get() == true) + val reportedTaskId = + SparkUtils.reportedStageShuffleFetchFailureTaskIds.values().asScala.flatMap( + _.asScala).head + val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, reportedTaskId) assert(taskSetManager != null) - assert(SparkUtils.getTaskAttempts(taskSetManager, taskId)._2.size() == 1) - assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId)) - assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 1) + assert(SparkUtils.getTaskAttempts(taskSetManager, reportedTaskId)._2.size() == 1) + assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(reportedTaskId)) } sparkSession.sparkContext.cancelAllJobs() From 1f252e770746b67d6dcb7b0f1dfdcebd98492546 Mon Sep 17 00:00:00 2001 From: "Wang, Fei" Date: Fri, 27 Dec 2024 17:28:01 -0800 Subject: [PATCH 3/4] spark 3 fetch failure --- .../shuffle/celeborn/SparkUtilsSuite.scala | 96 ++++++++++--------- 1 file changed, 49 insertions(+), 47 deletions(-) diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index 2d753ff7b17..0241ccc1f84 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -45,62 +45,64 @@ class SparkUtilsSuite extends AnyFunSuite } test("check if fetch failure task another attempt is running or successful") { - val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") - val sparkSession = SparkSession.builder() - .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) - 
.config("spark.sql.shuffle.partitions", 2) - .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) - .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") - .config( - "spark.shuffle.manager", - "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") - .getOrCreate() + if (Spark3OrNewer) { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") + val sparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .getOrCreate() - val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) - TestCelebornShuffleManager.registerReaderGetHook(hook) + val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) + TestCelebornShuffleManager.registerReaderGetHook(hook) - try { - val sc = sparkSession.sparkContext - val jobThread = new Thread { - override def run(): Unit = { - try { - val value = Range(1, 10000).mkString(",") - sc.parallelize(1 to 10000, 2) - .map { i => (i, value) } - .groupByKey(10) - .mapPartitions { iter => - Thread.sleep(3000) - iter - }.collect() - } catch { - case _: InterruptedException => + try { + val sc = sparkSession.sparkContext + val jobThread = new Thread { + override def run(): Unit = { + try { + val value = Range(1, 10000).mkString(",") + sc.parallelize(1 to 10000, 2) + .map { i => (i, value) } + .groupByKey(10) + .mapPartitions { iter => + Thread.sleep(3000) + iter + }.collect() + } catch { + case _: InterruptedException => + } } } - } - jobThread.start() + jobThread.start() - val 
taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] - eventually(timeout(30.seconds), interval(0.milliseconds)) { - assert(hook.executed.get() == true) - val reportedTaskId = - SparkUtils.reportedStageShuffleFetchFailureTaskIds.values().asScala.flatMap( - _.asScala).head - val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, reportedTaskId) - assert(taskSetManager != null) - assert(SparkUtils.getTaskAttempts(taskSetManager, reportedTaskId)._2.size() == 1) - assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(reportedTaskId)) - } + val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + eventually(timeout(30.seconds), interval(0.milliseconds)) { + assert(hook.executed.get() == true) + val reportedTaskId = + SparkUtils.reportedStageShuffleFetchFailureTaskIds.values().asScala.flatMap( + _.asScala).head + val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, reportedTaskId) + assert(taskSetManager != null) + assert(SparkUtils.getTaskAttempts(taskSetManager, reportedTaskId)._2.size() == 1) + assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(reportedTaskId)) + } - sparkSession.sparkContext.cancelAllJobs() + sparkSession.sparkContext.cancelAllJobs() - jobThread.interrupt() + jobThread.interrupt() - eventually(timeout(3.seconds), interval(100.milliseconds)) { - assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 0) + eventually(timeout(3.seconds), interval(100.milliseconds)) { + assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 0) + } + } finally { + sparkSession.stop() } - } finally { - sparkSession.stop() } } } From fbca79ab6efb926513181b929405e645db1a69cc Mon Sep 17 00:00:00 2001 From: "Wang, Fei" Date: Fri, 27 Dec 2024 17:30:38 -0800 Subject: [PATCH 4/4] basic methods testing for both spark2 and spark3 --- .../shuffle/celeborn/SparkUtilsSuite.scala | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git 
a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index 0241ccc1f84..6b4bc13b8bf 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -105,4 +105,56 @@ class SparkUtilsSuite extends AnyFunSuite } } } + + test("getTaskSetManager/getTaskAttempts test") { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") + val sparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .getOrCreate() + + try { + val sc = sparkSession.sparkContext + val jobThread = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100, 2) + .repartition(1) + .mapPartitions { iter => + Thread.sleep(3000) + iter + }.collect() + } catch { + case _: InterruptedException => + } + } + } + jobThread.start() + + val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + eventually(timeout(3.seconds), interval(100.milliseconds)) { + val taskId = 0 + val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId) + assert(taskSetManager != null) + assert(SparkUtils.getTaskAttempts(taskSetManager, taskId)._2.size() == 1) + assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId)) + assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 1) + } + + sparkSession.sparkContext.cancelAllJobs() + + jobThread.interrupt() + + eventually(timeout(3.seconds), interval(100.milliseconds)) { + 
assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 0) + } + } finally { + sparkSession.stop() + } + } }