-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(components): Support dynamic machine type paramters in CustomTra…
…iningJobOp (#10883) * feat(components): Support dynamic machine type paramters in CustomTrainingJobOp Signed-off-by: KevinGrantLee <[email protected]> * fix formatting Signed-off-by: KevinGrantLee <[email protected]> * Remove unused imports and try to fix type annotation error. Signed-off-by: KevinGrantLee <[email protected]> * fix formatting Signed-off-by: KevinGrantLee <[email protected]> * Fix string formatting Signed-off-by: KevinGrantLee <[email protected]> * Fix string formatting Signed-off-by: KevinGrantLee <[email protected]> * Add new test pipeline Signed-off-by: KevinGrantLee <[email protected]> * update release.md Signed-off-by: KevinGrantLee <[email protected]> * Rename recursive_replace Signed-off-by: KevinGrantLee <[email protected]> * Refactor code and add condition test pipeline Signed-off-by: KevinGrantLee <[email protected]> * fix formatting Signed-off-by: KevinGrantLee <[email protected]> * minor clean up Signed-off-by: KevinGrantLee <[email protected]> --------- Signed-off-by: KevinGrantLee <[email protected]>
- Loading branch information
1 parent
0c2bcf1
commit b57f9e8
Showing
11 changed files
with
1,225 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
64 changes: 64 additions & 0 deletions
64
...on/test_data/pipelines/pipeline_with_condition_dynamic_task_output_custom_training_job.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import google_cloud_pipeline_components.v1.custom_job as custom_job | ||
from kfp import dsl | ||
|
||
|
||
@dsl.component | ||
def flip_biased_coin_op() -> str: | ||
"""Flip a coin and output heads.""" | ||
return 'heads' | ||
|
||
|
||
@dsl.component | ||
def machine_type() -> str: | ||
return 'n1-standard-4' | ||
|
||
|
||
@dsl.component | ||
def accelerator_type() -> str: | ||
return 'NVIDIA_TESLA_P4' | ||
|
||
|
||
@dsl.component | ||
def accelerator_count() -> int: | ||
return 1 | ||
|
||
|
||
@dsl.pipeline | ||
def pipeline( | ||
project: str, | ||
location: str, | ||
encryption_spec_key_name: str = '', | ||
): | ||
flip1 = flip_biased_coin_op().set_caching_options(False) | ||
machine_type_task = machine_type() | ||
accelerator_type_task = accelerator_type() | ||
accelerator_count_task = accelerator_count() | ||
|
||
with dsl.Condition(flip1.output == 'heads'): | ||
custom_job.CustomTrainingJobOp( | ||
display_name='add-numbers', | ||
worker_pool_specs=[{ | ||
'container_spec': { | ||
'image_uri': ( | ||
'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0' | ||
), | ||
'command': ['echo'], | ||
'args': ['foo'], | ||
}, | ||
'machine_spec': { | ||
'machine_type': machine_type_task.output, | ||
'accelerator_type': accelerator_type_task.output, | ||
'accelerator_count': accelerator_count_task.output, | ||
}, | ||
'replica_count': 1, | ||
}], | ||
project=project, | ||
location=location, | ||
encryption_spec_key_name=encryption_spec_key_name, | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
from kfp import compiler | ||
compiler.Compiler().compile( | ||
pipeline_func=pipeline, package_path=__file__.replace('.py', '.yaml')) |
Oops, something went wrong.