Skip to content

Commit

Permalink
Add nodes toleration to Ray nodes
Browse files Browse the repository at this point in the history
Signed-off-by: Revital Sur <[email protected]>
  • Loading branch information
revit13 committed Sep 25, 2024
1 parent f5d9680 commit dfe9dde
Showing 1 changed file with 23 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
ClusterSpec,
HeadNodeSpec,
RayJobRequest,
Template,
WorkerNodeSpec,
environment_variables_decoder,
template_decoder,
volume_decoder,
)
from ray.job_submission import JobStatus
Expand Down Expand Up @@ -121,41 +121,37 @@ def create_ray_cluster(
"""
# start with templates
# head_node
cpus = head_node.get("cpu", 1)
memory = head_node.get("memory", 1)
gpus = head_node.get("gpu", 0)
accelerator = head_node.get("gpu_accelerator", None)
dct = {}
dct["cpu"] = head_node.get("cpu", 1)
dct["memory"] = head_node.get("memory", 1)
dct["gpu"] = head_node.get("gpu", 0)
dct["gpu_accelerator"] = head_node.get("gpu_accelerator", None)
head_node_template_name = f"{name}-head-template"
_, _ = self.api_server_client.delete_compute_template(ns=namespace, name=head_node_template_name)
head_template = Template(
name=head_node_template_name,
namespace=namespace,
cpu=cpus,
memory=memory,
gpu=gpus,
gpu_accelerator=accelerator,
)
dct["name"] = head_node_template_name
dct["namespace"] = namespace
if "tolerations" in head_node:
dct["tolerations"] = head_node.get("tolerations")
_, _ = self.api_server_client.delete_compute_template(ns=namespace, name=dct["name"])
head_template = template_decoder(dct)
status, error = self.api_server_client.create_compute_template(head_template)
if status != 200:
return status, error
worker_template_names = [""] * len(worker_nodes)
index = 0
# For every worker group
for worker_node in worker_nodes:
cpus = worker_node.get("cpu", 1)
memory = worker_node.get("memory", 1)
gpus = worker_node.get("gpu", 0)
accelerator = worker_node.get("gpu_accelerator", None)
dct = {}
dct["cpu"] = worker_node.get("cpu", 1)
dct["memory"] = worker_node.get("memory", 1)
dct["gpu"] = worker_node.get("gpu", 0)
dct["gpu_accelerator"] = worker_node.get("gpu_accelerator", None)
worker_node_template_name = f"{name}-worker-template-{index}"
_, _ = self.api_server_client.delete_compute_template(ns=namespace, name=worker_node_template_name)
worker_template = Template(
name=worker_node_template_name,
namespace=namespace,
cpu=cpus,
memory=memory,
gpu=gpus,
gpu_accelerator=accelerator,
)
dct["name"] = worker_node_template_name
dct["namespace"] = namespace
if "tolerations" in worker_node:
dct["tolerations"] = worker_node.get("tolerations")
_, _ = self.api_server_client.delete_compute_template(ns=namespace, name=dct["name"])
worker_template = template_decoder(dct)
status, error = self.api_server_client.create_compute_template(worker_template)
if status != 200:
return status, error
Expand Down

0 comments on commit dfe9dde

Please sign in to comment.