Skip to content

Commit

Permalink
Simplify parameter handling logic
Browse files Browse the repository at this point in the history
Signed-off-by: droctothorpe <[email protected]>
Co-authored-by: zazulam <[email protected]>
  • Loading branch information
droctothorpe and zazulam committed Sep 18, 2024
1 parent d920cf1 commit ee7f6c9
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 52 deletions.
84 changes: 34 additions & 50 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1218,63 +1218,48 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
return nil, paramError(err)
}

// The producer is the task that produces the output that we need to
// consume.
producer := tasks[taskOutput.GetProducerTask()]
glog.V(4).Info("producer: ", producer)
currentTask := producer
// If the producer is a DAG, AND its output / producer subtask is
// ALSO a DAG, then we need to cycle through this loop until we
// arrive at a non-DAG subtask and essentially bubble up that
// non-DAG subtask so that its value can be consumed.
producerSubTaskMaybeDAG := true
for producerSubTaskMaybeDAG {
// The producer is the task that produces the output that we need to
// consume.
producer := tasks[taskOutput.GetProducerTask()]

glog.V(4).Info("producer: ", producer)

// Get the producer's outputs.
_, producerOutputParameters, err := producer.GetParameters()
if err != nil {
return nil, paramError(fmt.Errorf("get producer output parameters: %w", err))
}
glog.V(4).Info("producer output parameters: ", producerOutputParameters)
// Deserialize them.
var producerOutputParametersMap map[string]string
b, err := producerOutputParameters["Output"].GetStructValue().MarshalJSON()
currentSubTaskMaybeDAG := true
for currentSubTaskMaybeDAG {
glog.V(4).Info("currentTask: ", currentTask.TaskName())
_, outputParametersCustomProperty, err := currentTask.GetParameters()
if err != nil {
return nil, err
}
json.Unmarshal(b, &producerOutputParametersMap)
glog.V(4).Info("producerOutputParametersMap: ", producerOutputParametersMap)

// If the producer's output includes a producer subtask, which means
// that the producer is a DAG that is getting its output from one of
// the tasks in the DAG, then we want to roll up the output from the
// producer subtask to the producer, so that the downstream logic
// can retrieve it appropriately.
if producerSubTask, ok := producerOutputParametersMap["producer_subtask"]; ok {
glog.V(4).Infof(
"Overriding producer task, %v, output with producer_subtask, %v, output.",
producer.TaskName(),
producerSubTask,
)
_, producerOutputParameters, err = tasks[producerSubTask].GetParameters()
// If the current task is a DAG:
if *currentTask.GetExecution().Type == "system.DAGExecution" {
// Since currentTask is a DAG, we need to deserialize its
// output parameter map so that we can look its
// corresponding producer sub-task, reassign currentTask,
// and iterate through this loop again.
var outputParametersMap map[string]string
b, err := outputParametersCustomProperty["Output"].GetStructValue().MarshalJSON()
if err != nil {
return nil, err
}
glog.V(4).Info("producerSubTask output parameters: ", producerOutputParameters)
// The only reason we're updating this is to make the downstream
// logging more accurate.
taskOutput.ProducerTask = producerOutputParametersMap["producer_subtask"]
// Grab the value of the producer output.
producerOutputParameterValue, ok := producerOutputParameters[taskOutput.GetOutputParameterKey()]
if !ok {
return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask()))
}
// Update the input to be the producer output value.
inputs.ParameterValues[name] = producerOutputParameterValue
json.Unmarshal(b, &outputParametersMap)
glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap)
subTaskName := outputParametersMap["producer_subtask"]
glog.V(4).Infof(
"Overriding currentTask, %v, output with currentTask's producer_subtask, %v, output.",
currentTask.TaskName(),
subTaskName,
)

// Reassign sub-task before running through the loop again.
currentTask = tasks[subTaskName]
} else {
// The producer subtask is not a DAG, so we exit the loop.
producerSubTaskMaybeDAG = false
inputs.ParameterValues[name] = producerOutputParameters[taskOutput.GetOutputParameterKey()]
inputs.ParameterValues[name] = outputParametersCustomProperty[taskOutput.GetOutputParameterKey()]
// Exit the loop.
currentSubTaskMaybeDAG = false
}
}

Expand Down Expand Up @@ -1333,8 +1318,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
var outputArtifactKey string
currentSubTaskMaybeDAG := true
for currentSubTaskMaybeDAG {
// If the current task is a DAG:
glog.V(4).Info("currentTask: ", currentTask.TaskName())
// If the current task is a DAG:
if *currentTask.GetExecution().Type == "system.DAGExecution" {
// Get the sub-task.
outputArtifactsCustomProperty := currentTask.GetExecution().GetCustomProperties()["output_artifacts"]
Expand All @@ -1351,13 +1336,12 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
outputArtifactKey = artifactSelectors[0].OutputArtifactKey
glog.V(4).Info("subTaskName: ", subTaskName)
glog.V(4).Info("outputArtifactKey: ", outputArtifactKey)
currentSubTask := tasks[subTaskName]
// If the sub-task is a DAG, reassign currentTask and run
// through the loop again.
currentTask = currentSubTask
currentTask = tasks[subTaskName]
// }
} else {
// Base case, subtask is a container, not a DAG.
// Base case, currentTask is a container, not a DAG.
outputs, err := mlmd.GetOutputArtifactsByExecutionId(ctx, currentTask.GetID())
if err != nil {
return nil, artifactError(err)
Expand Down
2 changes: 0 additions & 2 deletions backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ type ExecutionConfig struct {
ParentDagID int64 // parent DAG execution ID. Only the root DAG does not have a parent DAG.
InputParameters map[string]*structpb.Value
OutputParameters map[string]*structpb.Value
// OutputArtifacts map[string]*structpb.Value
// OutputArtifacts []*pipelinespec.DagOutputsSpec_ArtifactSelectorSpec
OutputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec
InputArtifactIDs map[string][]int64
IterationIndex *int // Index of the iteration.
Expand Down

0 comments on commit ee7f6c9

Please sign in to comment.