From 80e1f0232c5f680daa9b2c89faad20741e890bf7 Mon Sep 17 00:00:00 2001 From: Googler Date: Tue, 21 May 2024 17:31:50 -0700 Subject: [PATCH] feat(components): Use GetModel integration test to manually test write_user_defined_error function Signed-off-by: Googler PiperOrigin-RevId: 635979715 --- .../container/utils/error_surfacing.py | 4 +- .../v1/model/get_model/remote_runner.py | 43 ++++++++++++++++--- .../proto/task_error_pb2.py | 1 - 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/utils/error_surfacing.py b/components/google-cloud/google_cloud_pipeline_components/container/utils/error_surfacing.py index 3c9919beeb01..ebbf0ab18072 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/utils/error_surfacing.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/utils/error_surfacing.py @@ -20,10 +20,10 @@ from google_cloud_pipeline_components.proto import task_error_pb2 -def write_user_defined_error( +def write_customized_error( executor_input: str, error: task_error_pb2.TaskError ): - """Writes a TaskError to a JSON file ('executor_error.json') in the output directory specified in the executor input. + """Writes a TaskError customized by the author of the pipelines to a JSON file ('executor_error.json') in the output directory specified in the executor input. Args: executor_input: JSON string containing executor input data. diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py index 797f8c6f5349..3c74ab7a93e5 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py @@ -12,13 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. """Remote runner for Get Model based on the Vertex AI SDK.""" - +import contextlib +from typing import Tuple, Type, Union from google.api_core.client_options import ClientOptions from google.cloud import aiplatform_v1 as aip_v1 from google_cloud_pipeline_components.container.utils import artifact_utils +from google_cloud_pipeline_components.container.utils import error_surfacing +from google_cloud_pipeline_components.proto import task_error_pb2 from google_cloud_pipeline_components.types import artifact_types +@contextlib.contextmanager +def catch_write_and_raise( + executor_input: str, + exception_types: Union[ + Type[Exception], Tuple[Type[Exception], ...] + ] = Exception, +): + """Context manager to catch specified exceptions, log them using error_surfacing, and then re-raise.""" + try: + yield + except exception_types as e: + task_error = task_error_pb2.TaskError() + task_error.error_message = str(e) + error_surfacing.write_customized_error(executor_input, task_error) + raise + + def get_model( executor_input, model_name: str, @@ -26,11 +46,16 @@ def get_model( location: str, ) -> None: """Get model.""" - if not location or not project: - raise ValueError( - 'Model resource name must be in the format' - ' projects/{project}/locations/{location}/models/{model_name}' - ) + with catch_write_and_raise( + executor_input=executor_input, + exception_types=ValueError, + ): + if not location or not project: + model_name_error_message = ( + 'Model resource name must be in the format' + ' projects/{project}/locations/{location}/models/{model_name}' + ) + raise ValueError(model_name_error_message) api_endpoint = location + '-aiplatform.googleapis.com' vertex_uri_prefix = f'https://{api_endpoint}/v1/' model_resource_name = ( @@ -40,7 +65,11 @@ def get_model( client_options = ClientOptions(api_endpoint=api_endpoint) client = aip_v1.ModelServiceClient(client_options=client_options) request = aip_v1.GetModelRequest(name=model_resource_name) - get_model_response = client.get_model(request) + with catch_write_and_raise( + executor_input=executor_input, + exception_types=Exception, + ): + get_model_response = client.get_model(request) resp_model_name_without_version = get_model_response.name.split('@', 1)[0] model_resource_name = ( f'{resp_model_name_without_version}@{get_model_response.version_id}' diff --git a/components/google-cloud/google_cloud_pipeline_components/proto/task_error_pb2.py b/components/google-cloud/google_cloud_pipeline_components/proto/task_error_pb2.py index baaaec862ea4..1dbf8868519f 100755 --- a/components/google-cloud/google_cloud_pipeline_components/proto/task_error_pb2.py +++ b/components/google-cloud/google_cloud_pipeline_components/proto/task_error_pb2.py @@ -5,7 +5,6 @@ """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports)