Skip to content

Commit

Permalink
feat(core): EmbeddedFlow task
Browse files Browse the repository at this point in the history
Adds an EmbeddedFlow that allow to embed subflow tasks into a parent tasks.

Fixes #6518
  • Loading branch information
loicmathieu committed Jan 2, 2025
1 parent a9ff469 commit cf5c1f3
Show file tree
Hide file tree
Showing 6 changed files with 292 additions and 20 deletions.
50 changes: 31 additions & 19 deletions core/src/main/java/io/kestra/core/runners/ExecutableUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,28 +65,11 @@ public static <T extends Task & ExecutableTask<?>> SubflowExecution<?> subflowEx
boolean inheritLabels,
Property<ZonedDateTime> scheduleDate
) throws IllegalVariableEvaluationException {
String tenantId = currentExecution.getTenantId();
String subflowNamespace = runContext.render(currentTask.subflowId().namespace());
String subflowId = runContext.render(currentTask.subflowId().flowId());
Optional<Integer> subflowRevision = currentTask.subflowId().revision();

io.kestra.core.models.flows.Flow flow = flowExecutorInterface.findByIdFromTask(
currentExecution.getTenantId(),
subflowNamespace,
subflowId,
subflowRevision,
currentExecution.getTenantId(),
currentFlow.getNamespace(),
currentFlow.getId()
)
.orElseThrow(() -> new IllegalStateException("Unable to find flow '" + subflowNamespace + "'.'" + subflowId + "' with revision '" + subflowRevision.orElse(0) + "'"));

if (flow.isDisabled()) {
throw new IllegalStateException("Cannot execute a flow which is disabled");
}

if (flow instanceof FlowWithException fwe) {
throw new IllegalStateException("Cannot execute an invalid flow: " + fwe.getException());
}
Flow flow = getSubflow(tenantId, subflowNamespace, subflowId, subflowRevision, flowExecutorInterface, currentFlow);

List<Label> newLabels = inheritLabels ? new ArrayList<>(currentExecution.getLabels()) : new ArrayList<>(systemLabels(currentExecution));
if (labels != null) {
Expand Down Expand Up @@ -122,6 +105,35 @@ public static <T extends Task & ExecutableTask<?>> SubflowExecution<?> subflowEx
.build();
}

public static Flow getSubflow(String tenantId,
String subflowNamespace,
String subflowId,
Optional<Integer> subflowRevision,
FlowExecutorInterface flowExecutorInterface,
Flow currentFlow) {

Flow flow = flowExecutorInterface.findByIdFromTask(
tenantId,
subflowNamespace,
subflowId,
subflowRevision,
tenantId,
currentFlow.getNamespace(),
currentFlow.getId()
)
.orElseThrow(() -> new IllegalStateException("Unable to find flow '" + subflowNamespace + "'.'" + subflowId + "' with revision '" + subflowRevision.orElse(0) + "'"));

if (flow.isDisabled()) {
throw new IllegalStateException("Cannot execute a flow which is disabled");
}

if (flow instanceof FlowWithException fwe) {
throw new IllegalStateException("Cannot execute an invalid flow: " + fwe.getException());
}

return flow;
}

private static List<Label> systemLabels(Execution execution) {
return Streams.of(execution.getLabels())
.filter(label -> label.key().startsWith(Label.SYSTEM_PREFIX))
Expand Down
209 changes: 209 additions & 0 deletions core/src/main/java/io/kestra/plugin/core/flow/EmbeddedFlow.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
package io.kestra.plugin.core.flow;

import com.fasterxml.jackson.annotation.JsonIgnore;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.executions.NextTaskRun;
import io.kestra.core.models.executions.TaskRun;
import io.kestra.core.models.flows.Flow;
import io.kestra.core.models.flows.FlowWithException;
import io.kestra.core.models.flows.FlowWithSource;
import io.kestra.core.models.hierarchies.AbstractGraph;
import io.kestra.core.models.hierarchies.GraphCluster;
import io.kestra.core.models.hierarchies.RelationType;
import io.kestra.core.models.tasks.FlowableTask;
import io.kestra.core.models.tasks.ResolvedTask;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.models.tasks.VoidOutput;
import io.kestra.core.runners.*;
import io.kestra.core.services.FlowService;
import io.kestra.core.utils.GraphUtils;
import io.micronaut.context.ApplicationContext;
import io.micronaut.context.event.StartupEvent;
import io.micronaut.runtime.event.annotation.EventListener;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.*;
import lombok.experimental.SuperBuilder;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;

@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
@Schema(
title = "Embeds subflow tasks into this flow."
)
@Plugin(
examples = {
@Example(
title = "Embeds subflow tasks.",
full = true,
code = """
id: parent_flow
namespace: company.team
tasks:
- id: embed_subflow
type: io.kestra.plugin.core.flow.EmbeddedFlow
namespace: company.team
flowId: subflow
"""
)
}
)
public class EmbeddedFlow extends Task implements FlowableTask<VoidOutput>, ChildFlowInterface {
static final String PLUGIN_FLOW_OUTPUTS_ENABLED = "outputs.enabled";

// FIXME no other choice for now as getErrors() and allChildTasks() has no context
@Schema(
title = "The tenantId of the subflow to be embedded."
)
@PluginProperty
private String tenantId;

@NotEmpty
@Schema(
title = "The namespace of the subflow to be embedded."
)
@PluginProperty
private String namespace;

@NotNull
@Schema(
title = "The identifier of the subflow to be embedded."
)
@PluginProperty
private String flowId;

@Schema(
title = "The revision of the subflow to be embedded.",
description = "By default, the last, i.e. the most recent, revision of the subflow is embedded."
)
@PluginProperty
@Min(value = 1)
private Integer revision;

@Override
@JsonIgnore
public List<Task> getErrors() {
Flow subflow = fetchSubflow();

return subflow.getErrors();
}

@Override
public AbstractGraph tasksTree(Execution execution, TaskRun taskRun, List<String> parentValues) throws IllegalVariableEvaluationException {
Flow subflow = fetchSubflow();

GraphCluster subGraph = new GraphCluster(this, taskRun, parentValues, RelationType.SEQUENTIAL);

GraphUtils.sequential(
subGraph,
subflow.getTasks(),
subflow.getErrors(),
taskRun,
execution
);

return subGraph;
}

@Override
public List<Task> allChildTasks() {
Flow subflow = fetchSubflow();

return Stream
.concat(
subflow.getTasks() != null ? subflow.getTasks().stream() : Stream.empty(),
subflow.getErrors() != null ? subflow.getErrors().stream() : Stream.empty()
)
.toList();
}

@Override
public List<ResolvedTask> childTasks(RunContext runContext, TaskRun parentTaskRun) throws IllegalVariableEvaluationException {
// we check that we are allowed to access the namespace TODO should we as Subflow and ForEachItem didn't do that check?
FlowService flowService = ((DefaultRunContext) runContext).getApplicationContext().getBean(FlowService.class);
flowService.checkAllowedNamespace(tenantId, namespace, runContext.flowInfo().tenantId(), runContext.flowInfo().namespace());

// we check that the task tenant is the current tenant to avoid accessing flows from another tenant
if (!Objects.equals(tenantId, runContext.flowInfo().tenantId())) {
throw new IllegalArgumentException("Cannot embeds a flow from a different tenant");
}

Flow subflow = fetchSubflow(runContext);

return FlowableUtils.resolveTasks(subflow.getTasks(), parentTaskRun);
}

@Override
public List<NextTaskRun> resolveNexts(RunContext runContext, Execution execution, TaskRun parentTaskRun) throws IllegalVariableEvaluationException {
return FlowableUtils.resolveSequentialNexts(
execution,
this.childTasks(runContext, parentTaskRun),
FlowableUtils.resolveTasks(this.getErrors(), parentTaskRun),
parentTaskRun
);
}

// This method should only be used when getSubflow(RunContext) cannot be used.
private Flow fetchSubflow() {
ApplicationContext applicationContext = ContextHelper.context();
FlowExecutorInterface flowExecutor = applicationContext.getBean(FlowExecutorInterface.class);
FlowWithSource subflow = flowExecutor.findById(tenantId, namespace, flowId, Optional.ofNullable(revision)).orElseThrow(() -> new IllegalArgumentException("Unable to find flow " + namespace + "." + flowId));

if (subflow.isDisabled()) {
throw new IllegalStateException("Cannot execute a flow which is disabled");
}

if (subflow instanceof FlowWithException fwe) {
throw new IllegalStateException("Cannot execute an invalid flow: " + fwe.getException());
}

return subflow;
}

// This method is preferred as getSubflow() as it checks current flow and subflow
private Flow fetchSubflow(RunContext runContext) {
ApplicationContext applicationContext = ContextHelper.context();
FlowExecutorInterface flowExecutor = applicationContext.getBean(FlowExecutorInterface.class);
RunContext.FlowInfo flowInfo = runContext.flowInfo();
FlowWithSource flow = flowExecutor.findById(flowInfo.tenantId(), flowInfo.namespace(), flowInfo.id(), Optional.of(flowInfo.revision()))
.orElseThrow(() -> new IllegalArgumentException("Unable to find flow " + flowInfo.namespace() + "." + flowInfo.id()));
return ExecutableUtils.getSubflow(tenantId, namespace, flowId, Optional.ofNullable(revision), flowExecutor, flow);
}

/**
* Ugly hack to provide the ApplicationContext on {{@link #allChildTasks }} &amp; {{@link #tasksTree }}
* We need to inject a way to fetch embedded subflows ...
*/
@Singleton
static class ContextHelper {
@Inject
private ApplicationContext applicationContext;

private static ApplicationContext context;

static ApplicationContext context() {
return ContextHelper.context;
}

@EventListener
void onStartup(final StartupEvent event) {
ContextHelper.context = this.applicationContext;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.kestra.plugin.core.flow;

import io.kestra.core.junit.annotations.ExecuteFlow;
import io.kestra.core.junit.annotations.KestraTest;
import io.kestra.core.junit.annotations.LoadFlows;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.flows.State;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;

@KestraTest(startRunner = true)
class EmbeddedFlowTest {
@Test
@LoadFlows("flows/valids/minimal.yaml")
@ExecuteFlow("flows/valids/embedded-flow.yaml")
void shouldEmbedTasks(Execution execution) throws Exception {
assertThat(execution.getState().getCurrent(), is(State.Type.SUCCESS));
assertThat(execution.getTaskRunList(), hasSize(2));
assertThat(execution.findTaskRunsByTaskId("embeddedFlow"), notNullValue());
assertThat(execution.findTaskRunsByTaskId("date"), notNullValue());
}

@Test
@LoadFlows({"flows/valids/minimal.yaml", "flows/valids/embedded-flow.yaml"})
@ExecuteFlow("flows/valids/embedded-parent.yaml")
void shouldEmbedTasksRecursively(Execution execution) throws Exception {
assertThat(execution.getState().getCurrent(), is(State.Type.SUCCESS));
assertThat(execution.getTaskRunList(), hasSize(3));
assertThat(execution.findTaskRunsByTaskId("embeddedParent"), notNullValue());
assertThat(execution.findTaskRunsByTaskId("embeddedFlow"), notNullValue());
assertThat(execution.findTaskRunsByTaskId("date"), notNullValue());
}
}
8 changes: 8 additions & 0 deletions core/src/test/resources/flows/valids/embedded-flow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
id: embedded-flow
namespace: io.kestra.tests

tasks:
- id: embeddedFlow
type: io.kestra.plugin.core.flow.EmbeddedFlow
namespace: io.kestra.tests
flowId: minimal
8 changes: 8 additions & 0 deletions core/src/test/resources/flows/valids/embedded-parent.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
id: embedded-parent
namespace: io.kestra.tests

tasks:
- id: embeddedParent
type: io.kestra.plugin.core.flow.EmbeddedFlow
namespace: io.kestra.tests
flowId: embedded-flow
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public AbstractJdbcFlowRepository(io.kestra.jdbc.AbstractJdbcRepository<Flow> jd
Flow deserialize = this.jdbcRepository.deserialize(source);

// raise exception for invalid flow, ex: Templates disabled
deserialize.allTasksWithChilds();
deserialize.allTasks().forEach((task) -> {});

return deserialize;
} catch (DeserializationException e) {
Expand Down

0 comments on commit cf5c1f3

Please sign in to comment.