Skip to content

Commit

Permalink
Make calls on cacheArtifacts/reconnect non-blocking (#556)
Browse files Browse the repository at this point in the history
* unblock calls on cache+reconnect

* refactor submit/cancel task

* rename
  • Loading branch information
Andyz26 authored Sep 14, 2023
1 parent adeca1f commit 3ebcf81
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ CompletableFuture<TaskExecutorID> getTaskExecutorFor(
*/
CompletableFuture<TaskExecutorGateway> getTaskExecutorGateway(TaskExecutorID taskExecutorID);

CompletableFuture<TaskExecutorGateway> reconnectTaskExecutorGateway(TaskExecutorID taskExecutorID);
CompletableFuture<Ack> reconnectGateway(TaskExecutorID taskExecutorID);

CompletableFuture<TaskExecutorRegistration> getTaskExecutorInfo(String hostName);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import io.mantisrx.server.master.resourcecluster.TaskExecutorReport.Occupied;
import io.mantisrx.server.master.resourcecluster.TaskExecutorStatusChange;
import io.mantisrx.server.master.scheduler.JobMessageRouter;
import io.mantisrx.server.worker.TaskExecutorGateway;
import io.mantisrx.server.worker.TaskExecutorGateway.TaskNotFoundException;
import io.mantisrx.shaded.com.google.common.base.Preconditions;
import io.mantisrx.shaded.com.google.common.collect.Comparators;
Expand Down Expand Up @@ -358,13 +357,14 @@ private void onTaskExecutorGatewayRequest(TaskExecutorGatewayRequest request) {
} else {
try {
if (state.isRegistered()) {
sender().tell(state.getGateway(), self());
sender().tell(state.getGatewayAsync(), self());
} else {
sender().tell(
new Status.Failure(new IllegalStateException("Unregistered TaskExecutor: " + request.getTaskExecutorID())),
self());
}
} catch (Exception e) {
log.error("onTaskExecutorGatewayRequest error: {}", request, e);
metrics.incrementCounter(
ResourceClusterActorMetrics.TE_CONNECTION_FAILURE,
TagList.create(ImmutableMap.of(
Expand Down Expand Up @@ -399,7 +399,12 @@ private void onTaskExecutorGatewayReconnectRequest(TaskExecutorGatewayReconnectR
} else {
try {
if (state.isRegistered()) {
sender().tell(state.reconnect().join(), self());
state.reconnect().whenComplete((res, throwable) -> {
if (throwable != null) {
log.error("failed to reconnect to {}", request.getTaskExecutorID(), throwable);
}
});
sender().tell(Ack.getInstance(), self());
} else {
sender().tell(
new Status.Failure(
Expand Down Expand Up @@ -813,11 +818,22 @@ private void onCacheJobArtifactsOnTaskExecutorRequest(CacheJobArtifactsOnTaskExe
TaskExecutorState state = this.executorStateManager.get(request.getTaskExecutorID());
if (state != null && state.isRegistered()) {
try {
TaskExecutorGateway gateway = state.getGateway();
// TODO(fdichiara): store URI directly to avoid remapping for each TE
List<URI> artifacts = jobArtifactsToCache.stream().map(artifactID -> URI.create(artifactID.getResourceID())).collect(Collectors.toList());

gateway.cacheJobArtifacts(new CacheJobArtifactsRequest(artifacts));
state.getGatewayAsync()
.thenComposeAsync(taskExecutorGateway ->
taskExecutorGateway.cacheJobArtifacts(new CacheJobArtifactsRequest(
jobArtifactsToCache
.stream()
.map(artifactID -> URI.create(artifactID.getResourceID()))
.collect(Collectors.toList()))))
.whenComplete((res, throwable) -> {
if (throwable != null) {
log.error("failed to cache artifact on {}", request.getTaskExecutorID(), throwable);
}
else {
log.debug("Acked from cacheJobArtifacts for {}", request.getTaskExecutorID());
}
});
} catch (Exception ex) {
log.warn("Failed to cache job artifacts in task executor {}", request.getTaskExecutorID(), ex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,29 @@ public CompletableFuture<TaskExecutorID> getTaskExecutorAssignedFor(WorkerId wor
public CompletableFuture<TaskExecutorGateway> getTaskExecutorGateway(
TaskExecutorID taskExecutorID) {
return
Patterns
.ask(resourceClusterManagerActor, new TaskExecutorGatewayRequest(taskExecutorID, clusterID), askTimeout)
.thenApply(TaskExecutorGateway.class::cast)
.toCompletableFuture();
(CompletableFuture<TaskExecutorGateway>) Patterns
.ask(resourceClusterManagerActor, new TaskExecutorGatewayRequest(taskExecutorID, clusterID),
askTimeout)
.thenComposeAsync(result -> {
if (result instanceof CompletableFuture) {
return (CompletableFuture<TaskExecutorGateway>) result;
} else {
CompletableFuture<TaskExecutorGateway> exceptionFuture = new CompletableFuture<>();
exceptionFuture.completeExceptionally(new RuntimeException(
"Unexpected object type on getTaskExecutorGateway: " + result.getClass().getName()));
return exceptionFuture;
}
});
}

@Override
public CompletableFuture<TaskExecutorGateway> reconnectTaskExecutorGateway(
public CompletableFuture<Ack> reconnectGateway(
TaskExecutorID taskExecutorID) {
return
Patterns
.ask(resourceClusterManagerActor, new TaskExecutorGatewayReconnectRequest(taskExecutorID, clusterID),
askTimeout)
.thenApply(TaskExecutorGateway.class::cast)
.thenApply(Ack.class::cast)
.toCompletableFuture();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import javax.annotation.Nullable;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -252,11 +251,11 @@ TaskExecutorRegistration getRegistration() {
return this.registration;
}

protected TaskExecutorGateway getGateway() throws ExecutionException, InterruptedException {
protected CompletableFuture<TaskExecutorGateway> getGatewayAsync() {
if (this.gateway == null) {
throw new IllegalStateException("gateway is null");
}
return this.gateway.get();
return this.gateway;
}

protected CompletableFuture<TaskExecutorGateway> reconnect() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@

import akka.actor.AbstractActorWithTimers;
import akka.actor.Props;
import akka.actor.Status.Failure;
import akka.japi.pf.ReceiveBuilder;
import com.netflix.spectator.api.Tag;
import io.mantisrx.common.Ack;
import io.mantisrx.common.metrics.Counter;
import io.mantisrx.common.metrics.Metrics;
import io.mantisrx.common.metrics.MetricsRegistry;
Expand All @@ -40,6 +42,7 @@
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import lombok.Value;
Expand Down Expand Up @@ -112,6 +115,8 @@ public Receive createReceive() {
.match(FailedToSubmitScheduleRequestEvent.class, this::onFailedToSubmitScheduleRequestEvent)
.match(RetryCancelRequestEvent.class, this::onRetryCancelRequestEvent)
.match(Noop.class, this::onNoop)
.match(Ack.class, ack -> log.debug("Received ack from {}", sender()))
.match(Failure.class, failure -> log.error("Received failure from {}: {}", sender(), failure))
.build();
}

Expand Down Expand Up @@ -139,38 +144,48 @@ private void onInitializeRunningWorkerRequest(InitializeRunningWorkerRequestEven
}

private void onAssignedScheduleRequestEvent(AssignedScheduleRequestEvent event) {
TaskExecutorGateway gateway = null;
TaskExecutorRegistration info = null;
try {
gateway = resourceCluster.getTaskExecutorGateway(event.getTaskExecutorID()).join();
info = resourceCluster.getTaskExecutorInfo(event.getTaskExecutorID()).join();
CompletableFuture<TaskExecutorGateway> gatewayFut = resourceCluster.getTaskExecutorGateway(event.getTaskExecutorID());
TaskExecutorRegistration info = resourceCluster.getTaskExecutorInfo(event.getTaskExecutorID()).join();

if (gatewayFut != null && info != null) {
CompletionStage<Object> ackFuture =
gatewayFut
.thenComposeAsync(gateway ->
gateway
.submitTask(
executeStageRequestFactory.of(
event.getScheduleRequestEvent().getRequest(),
info))
.<Object>thenApply(
dontCare -> new SubmittedScheduleRequestEvent(
event.getScheduleRequestEvent(),
event.getTaskExecutorID()))
.exceptionally(
throwable -> new FailedToSubmitScheduleRequestEvent(
event.getScheduleRequestEvent(),
event.getTaskExecutorID(), throwable))
.whenCompleteAsync((res, err) ->
{
if (err == null) {
log.debug("[Submit Task] finish with {}", res);
}
else {
log.error("[Submit Task] fail: {}", event.getTaskExecutorID(), err);
}
})

);
pipe(ackFuture, getContext().getDispatcher()).to(self());
}
} catch (Exception e) {
// we are not able to get the gateway, which either means the node is not great or some transient network issue
// we will retry the request
log.warn(
"Failed to establish connection with the task executor {}; Resubmitting the request",
"Failed to submit task with the task executor {}; Resubmitting the request",
event.getTaskExecutorID(), e);
connectionFailures.increment();
self().tell(event.getScheduleRequestEvent().onFailure(e), self());
}

if (gateway != null && info != null) {
CompletableFuture<Object> ackFuture =
gateway
.submitTask(
executeStageRequestFactory.of(event.getScheduleRequestEvent().getRequest(),
info))
.<Object>thenApply(
dontCare -> new SubmittedScheduleRequestEvent(
event.getScheduleRequestEvent(),
event.getTaskExecutorID()))
.exceptionally(
throwable -> new FailedToSubmitScheduleRequestEvent(
event.getScheduleRequestEvent(),
event.getTaskExecutorID(), throwable));

pipe(ackFuture, getContext().getDispatcher()).to(self());
}
}

private void onFailedScheduleRequestEvent(FailedToScheduleRequestEvent event) {
Expand All @@ -187,6 +202,7 @@ private void onFailedScheduleRequestEvent(FailedToScheduleRequestEvent event) {
}

private void onSubmittedScheduleRequestEvent(SubmittedScheduleRequestEvent event) {
log.debug("[Submit Task]: receive SubmittedScheduleRequestEvent: {}", event);
final TaskExecutorID taskExecutorID = event.getTaskExecutorID();
try {
final TaskExecutorRegistration info = resourceCluster.getTaskExecutorInfo(taskExecutorID)
Expand Down Expand Up @@ -222,7 +238,15 @@ private void onFailedToSubmitScheduleRequestEvent(FailedToSubmitScheduleRequestE
Throwables.getStackTraceAsString(event.throwable)));

try {
resourceCluster.reconnectTaskExecutorGateway(event.getTaskExecutorID()).join();
resourceCluster.reconnectGateway(event.getTaskExecutorID())
.whenComplete((res, throwable) -> {
if (throwable != null) {
log.error("Failed to request reconnect to gateway for {}", event.getTaskExecutorID(), throwable);
}
else {
log.debug("Acked from reconnection request for {}", event.getTaskExecutorID());
}
});
} catch (Exception e) {
log.warn(
"Failed to establish re-connection with the task executor {} on failed schedule request",
Expand All @@ -237,24 +261,24 @@ private void onCancelRequestEvent(CancelRequestEvent event) {
getTimers().cancel(getSchedulingQueueKeyFor(event.getWorkerId()));
final TaskExecutorID taskExecutorID =
resourceCluster.getTaskExecutorAssignedFor(event.getWorkerId()).join();
final TaskExecutorGateway gateway =
resourceCluster.getTaskExecutorGateway(taskExecutorID).join();

CompletableFuture<Object> cancelFuture =
gateway
.cancelTask(event.getWorkerId())
.<Object>thenApply(dontCare -> Noop.getInstance())
.exceptionally(exception -> {
Throwable actual =
ExceptionUtils.stripCompletionException(
ExceptionUtils.stripExecutionException(exception));
// no need to retry if the TaskExecutor does not know about the task anymore.
if (actual instanceof TaskNotFoundException) {
return Noop.getInstance();
} else {
return event.onFailure(actual);
}
});
resourceCluster.getTaskExecutorGateway(taskExecutorID)
.thenComposeAsync(gateway ->
gateway
.cancelTask(event.getWorkerId())
.<Object>thenApply(dontCare -> Noop.getInstance())
.exceptionally(exception -> {
Throwable actual =
ExceptionUtils.stripCompletionException(
ExceptionUtils.stripExecutionException(exception));
// no need to retry if the TaskExecutor does not know about the task anymore.
if (actual instanceof TaskNotFoundException) {
return Noop.getInstance();
} else {
return event.onFailure(actual);
}
}));

pipe(cancelFuture, context().dispatcher()).to(self());
} catch (Exception e) {
Expand Down

0 comments on commit 3ebcf81

Please sign in to comment.