From dfe9dde02fb0fd79da518e25e6989deee64e0321 Mon Sep 17 00:00:00 2001
From: Revital Sur
Date: Wed, 25 Sep 2024 09:21:50 +0300
Subject: [PATCH] Add nodes toleration to Ray nodes

Signed-off-by: Revital Sur
---
 .../src/runtime_utils/remote_jobs_utils.py | 50 +++++++++----------
 1 file changed, 23 insertions(+), 27 deletions(-)

diff --git a/kfp/kfp_support_lib/shared_workflow_support/src/runtime_utils/remote_jobs_utils.py b/kfp/kfp_support_lib/shared_workflow_support/src/runtime_utils/remote_jobs_utils.py
index eb6a2e2c39..007345dfdf 100644
--- a/kfp/kfp_support_lib/shared_workflow_support/src/runtime_utils/remote_jobs_utils.py
+++ b/kfp/kfp_support_lib/shared_workflow_support/src/runtime_utils/remote_jobs_utils.py
@@ -25,9 +25,9 @@
     ClusterSpec,
     HeadNodeSpec,
     RayJobRequest,
-    Template,
     WorkerNodeSpec,
     environment_variables_decoder,
+    template_decoder,
     volume_decoder,
 )
 from ray.job_submission import JobStatus
@@ -121,20 +121,18 @@ def create_ray_cluster(
         """
         # start with templates
         # head_node
-        cpus = head_node.get("cpu", 1)
-        memory = head_node.get("memory", 1)
-        gpus = head_node.get("gpu", 0)
-        accelerator = head_node.get("gpu_accelerator", None)
+        dct = {}
+        dct["cpu"] = head_node.get("cpu", 1)
+        dct["memory"] = head_node.get("memory", 1)
+        dct["gpu"] = head_node.get("gpu", 0)
+        dct["gpu_accelerator"] = head_node.get("gpu_accelerator", None)
         head_node_template_name = f"{name}-head-template"
-        _, _ = self.api_server_client.delete_compute_template(ns=namespace, name=head_node_template_name)
-        head_template = Template(
-            name=head_node_template_name,
-            namespace=namespace,
-            cpu=cpus,
-            memory=memory,
-            gpu=gpus,
-            gpu_accelerator=accelerator,
-        )
+        dct["name"] = head_node_template_name
+        dct["namespace"] = namespace
+        if "tolerations" in head_node:
+            dct["tolerations"] = head_node.get("tolerations")
+        _, _ = self.api_server_client.delete_compute_template(ns=namespace, name=dct["name"])
+        head_template = template_decoder(dct)
         status, error = self.api_server_client.create_compute_template(head_template)
         if status != 200:
             return status, error
@@ -142,20 +140,18 @@ def create_ray_cluster(
         index = 0
         # For every worker group
         for worker_node in worker_nodes:
-            cpus = worker_node.get("cpu", 1)
-            memory = worker_node.get("memory", 1)
-            gpus = worker_node.get("gpu", 0)
-            accelerator = worker_node.get("gpu_accelerator", None)
+            dct = {}
+            dct["cpu"] = worker_node.get("cpu", 1)
+            dct["memory"] = worker_node.get("memory", 1)
+            dct["gpu"] = worker_node.get("gpu", 0)
+            dct["gpu_accelerator"] = worker_node.get("gpu_accelerator", None)
             worker_node_template_name = f"{name}-worker-template-{index}"
-            _, _ = self.api_server_client.delete_compute_template(ns=namespace, name=worker_node_template_name)
-            worker_template = Template(
-                name=worker_node_template_name,
-                namespace=namespace,
-                cpu=cpus,
-                memory=memory,
-                gpu=gpus,
-                gpu_accelerator=accelerator,
-            )
+            dct["name"] = worker_node_template_name
+            dct["namespace"] = namespace
+            if "tolerations" in worker_node:
+                dct["tolerations"] = worker_node.get("tolerations")
+            _, _ = self.api_server_client.delete_compute_template(ns=namespace, name=dct["name"])
+            worker_template = template_decoder(dct)
             status, error = self.api_server_client.create_compute_template(worker_template)
            if status != 200:
                 return status, error
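
For illustration, a minimal sketch of a head_node spec that would exercise the new tolerations path through create_ray_cluster(). The taint key, value, and resource numbers are hypothetical, and the key/operator/value/effect layout follows the standard Kubernetes toleration schema, which template_decoder is assumed to accept unchanged.

    # Hypothetical caller-side spec; only the optional "tolerations" list is new in this patch.
    head_node = {
        "cpu": 2,
        "memory": 8,
        "gpu": 0,
        "tolerations": [
            {
                "key": "dedicated",      # hypothetical taint key
                "operator": "Equal",
                "value": "ray",          # hypothetical taint value
                "effect": "NoSchedule",
            }
        ],
    }
    # Each entry in worker_nodes may carry the same optional "tolerations" list.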