
Commit

feat(components): Support dynamic machine type parameters in CustomTrainingJobOp (#10883)

* feat(components): Support dynamic machine type parameters in CustomTrainingJobOp

Signed-off-by: KevinGrantLee <[email protected]>

* fix formatting

Signed-off-by: KevinGrantLee <[email protected]>

* Remove unused imports and try to fix type annotation error.

Signed-off-by: KevinGrantLee <[email protected]>

* fix formatting

Signed-off-by: KevinGrantLee <[email protected]>

* Fix string formatting

Signed-off-by: KevinGrantLee <[email protected]>

* Fix string formatting

Signed-off-by: KevinGrantLee <[email protected]>

* Add new test pipeline

Signed-off-by: KevinGrantLee <[email protected]>

* update release.md

Signed-off-by: KevinGrantLee <[email protected]>

* Rename recursive_replace

Signed-off-by: KevinGrantLee <[email protected]>

* Refactor code and add condition test pipeline

Signed-off-by: KevinGrantLee <[email protected]>

* fix formatting

Signed-off-by: KevinGrantLee <[email protected]>

* minor clean up

Signed-off-by: KevinGrantLee <[email protected]>

---------

Signed-off-by: KevinGrantLee <[email protected]>
KevinGrantLee authored Jun 13, 2024
1 parent 0c2bcf1 commit b57f9e8
Showing 11 changed files with 1,225 additions and 10 deletions.
1 change: 1 addition & 0 deletions sdk/RELEASE.md
@@ -1,6 +1,7 @@
# Current Version (in development)

## Features
* Support dynamic machine type parameters in CustomTrainingJobOp. [\#10883](https://github.com/kubeflow/pipelines/pull/10883)

## Breaking changes
* Drop support for Python 3.7 since it has reached end-of-life. [\#10750](https://github.com/kubeflow/pipelines/pull/10750)
31 changes: 31 additions & 0 deletions sdk/python/kfp/compiler/compiler_utils.py
@@ -772,3 +772,34 @@ def get_dependencies(
dependencies[downstream_names[0]].add(upstream_names[0])

return dependencies


def recursive_replace_placeholders(data: Union[Dict, List], old_value: str,
                                   new_value: str) -> Union[Dict, List]:
    """Recursively replaces values in a nested dict/list object.

    This method is used to replace PipelineChannel objects with input parameter
    placeholders in a nested object like worker_pool_specs for custom jobs.

    Args:
        data: A nested object that can contain dictionaries and/or lists.
        old_value: The value that will be replaced.
        new_value: The value to replace the old value with.

    Returns:
        A copy of data with all occurrences of old_value replaced by new_value.
    """
    if isinstance(data, dict):
        return {
            k: recursive_replace_placeholders(v, old_value, new_value)
            for k, v in data.items()
        }
    elif isinstance(data, list):
        return [
            recursive_replace_placeholders(i, old_value, new_value)
            for i in data
        ]
    else:
        if isinstance(data, pipeline_channel.PipelineChannel):
            data = str(data)
        return new_value if data == old_value else data
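
For orientation, a minimal usage sketch of the new helper (the channel, spec, and placeholder string below are illustrative and mirror the unit test; they are not code from this diff):

from kfp.compiler import compiler_utils
from kfp.dsl import pipeline_channel

# A parameter channel standing in for an upstream task output.
channel = pipeline_channel.PipelineParameterChannel(
    name='Output', channel_type='String', task_name='machine-type')

worker_pool_specs = [{
    'machine_spec': {
        'machine_type': channel,
        'accelerator_count': 1,
    },
    'replica_count': 1,
}]

# The channel is matched via its string pattern and swapped for an input
# parameter placeholder; values that do not match pass through unchanged.
resolved = compiler_utils.recursive_replace_placeholders(
    worker_pool_specs,
    old_value=channel.pattern,
    new_value='{{$.inputs.parameters[pipelinechannel--machine-type-Output]}}')
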
61 changes: 61 additions & 0 deletions sdk/python/kfp/compiler/compiler_utils_test.py
@@ -53,6 +53,67 @@ def test_additional_input_name_for_pipeline_channel(self, channel,
expected,
compiler_utils.additional_input_name_for_pipeline_channel(channel))

    @parameterized.parameters(
        {
            'data': [{
                'container_spec': {
                    'image_uri':
                        'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0',
                    'command': ['echo'],
                    'args': ['foo']
                },
                'machine_spec': {
                    'machine_type':
                        pipeline_channel.PipelineParameterChannel(
                            name='Output',
                            channel_type='String',
                            task_name='machine-type'),
                    'accelerator_type':
                        pipeline_channel.PipelineParameterChannel(
                            name='Output',
                            channel_type='String',
                            task_name='accelerator-type'),
                    'accelerator_count':
                        1
                },
                'replica_count': 1
            }],
            'old_value':
                '{{channel:task=machine-type;name=Output;type=String;}}',
            'new_value':
                '{{$.inputs.parameters['
                'pipelinechannel--machine-type-Output'
                ']}}',
            'expected': [{
                'container_spec': {
                    'image_uri':
                        'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0',
                    'command': ['echo'],
                    'args': ['foo']
                },
                'machine_spec': {
                    'machine_type':
                        '{{$.inputs.parameters['
                        'pipelinechannel--machine-type-Output'
                        ']}}',
                    'accelerator_type':
                        pipeline_channel.PipelineParameterChannel(
                            name='Output',
                            channel_type='String',
                            task_name='accelerator-type'),
                    'accelerator_count':
                        1
                },
                'replica_count': 1
            }],
        },)
    def test_recursive_replace_placeholders(self, data, old_value, new_value,
                                            expected):
        self.assertEqual(
            expected,
            compiler_utils.recursive_replace_placeholders(
                data, old_value, new_value))


if __name__ == '__main__':
    unittest.main()
21 changes: 12 additions & 9 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
@@ -238,12 +238,14 @@ def build_task_spec_for_task(
input_name].component_input_parameter = (
component_input_parameter)

elif isinstance(input_value, str):
# Handle extra input due to string concat
elif isinstance(input_value, (str, int, float, bool, dict, list)):
pipeline_channels = (
pipeline_channel.extract_pipeline_channels_from_any(input_value)
)
for channel in pipeline_channels:
# NOTE: case like this p3 = print_and_return_str(s='Project = {}'.format(project))
# triggers this code

# value contains PipelineChannel placeholders which needs to be
# replaced. And the input needs to be added to the task spec.

@@ -265,8 +267,14 @@

additional_input_placeholder = placeholders.InputValuePlaceholder(
additional_input_name)._to_string()
input_value = input_value.replace(channel.pattern,
additional_input_placeholder)

if isinstance(input_value, str):
input_value = input_value.replace(
channel.pattern, additional_input_placeholder)
else:
input_value = compiler_utils.recursive_replace_placeholders(
input_value, channel.pattern,
additional_input_placeholder)

if channel.task_name:
# Value is produced by an upstream task.
@@ -299,11 +307,6 @@
additional_input_name].component_input_parameter = (
component_input_parameter)

pipeline_task_spec.inputs.parameters[
input_name].runtime_value.constant.string_value = input_value

elif isinstance(input_value, (str, int, float, bool, dict, list)):

pipeline_task_spec.inputs.parameters[
input_name].runtime_value.constant.CopyFrom(
to_protobuf_value(input_value))
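
The old_value/new_value pair fed to the new helper is derived from the channel itself; a condensed sketch of that derivation (the variable names are illustrative, and the exact placeholder text is whatever InputValuePlaceholder renders):

from kfp.compiler import compiler_utils
from kfp.dsl import pipeline_channel, placeholders

channel = pipeline_channel.PipelineParameterChannel(
    name='Output', channel_type='String', task_name='machine-type')

# e.g. 'pipelinechannel--machine-type-Output' for the channel above.
additional_input_name = (
    compiler_utils.additional_input_name_for_pipeline_channel(channel))

# Rendered input-parameter placeholder that replaces the channel inside
# structured inputs such as worker_pool_specs.
additional_input_placeholder = placeholders.InputValuePlaceholder(
    additional_input_name)._to_string()

# channel.pattern is the '{{channel:...}}' string the channel serializes to.
replaced = compiler_utils.recursive_replace_placeholders(
    [{'machine_spec': {'machine_type': channel}}],
    channel.pattern, additional_input_placeholder)
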
2 changes: 1 addition & 1 deletion sdk/python/kfp/dsl/pipeline_channel.py
@@ -586,7 +586,7 @@ def extract_pipeline_channels_from_string(


def extract_pipeline_channels_from_any(
payload: Union[PipelineChannel, str, list, tuple, dict]
payload: Union[PipelineChannel, str, int, float, bool, list, tuple, dict]
) -> List[PipelineChannel]:
"""Recursively extract PipelineChannels from any object or list of objects.
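
The signature widening here is annotation-only (the file shows one addition and one deletion) and tracks the new call sites: the function already walks nested dicts/lists, and scalar leaves simply contribute no channels. A small sketch with assumed values, not code from this diff:

from kfp.dsl import pipeline_channel

channel = pipeline_channel.PipelineParameterChannel(
    name='Output', channel_type='String', task_name='machine-type')

payload = [{
    'machine_spec': {
        'machine_type': channel,   # PipelineChannel leaf
        'accelerator_count': 1,    # int leaf, now covered by the annotation
    },
}]

# Returns [channel]: the one channel found anywhere in the nested payload.
channels = pipeline_channel.extract_pipeline_channels_from_any(payload)
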
@@ -0,0 +1,64 @@
import google_cloud_pipeline_components.v1.custom_job as custom_job
from kfp import dsl


@dsl.component
def flip_biased_coin_op() -> str:
    """Flip a coin and output heads."""
    return 'heads'


@dsl.component
def machine_type() -> str:
    return 'n1-standard-4'


@dsl.component
def accelerator_type() -> str:
    return 'NVIDIA_TESLA_P4'


@dsl.component
def accelerator_count() -> int:
    return 1


@dsl.pipeline
def pipeline(
    project: str,
    location: str,
    encryption_spec_key_name: str = '',
):
    flip1 = flip_biased_coin_op().set_caching_options(False)
    machine_type_task = machine_type()
    accelerator_type_task = accelerator_type()
    accelerator_count_task = accelerator_count()

    with dsl.Condition(flip1.output == 'heads'):
        custom_job.CustomTrainingJobOp(
            display_name='add-numbers',
            worker_pool_specs=[{
                'container_spec': {
                    'image_uri': (
                        'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0'
                    ),
                    'command': ['echo'],
                    'args': ['foo'],
                },
                'machine_spec': {
                    'machine_type': machine_type_task.output,
                    'accelerator_type': accelerator_type_task.output,
                    'accelerator_count': accelerator_count_task.output,
                },
                'replica_count': 1,
            }],
            project=project,
            location=location,
            encryption_spec_key_name=encryption_spec_key_name,
        )


if __name__ == '__main__':
    from kfp import compiler
    compiler.Compiler().compile(
        pipeline_func=pipeline, package_path=__file__.replace('.py', '.yaml'))