Skip to content

Commit

Permalink
finished up work that allows the modifier to define allocations as well
Browse files Browse the repository at this point in the history
  • Loading branch information
scheibelp committed Apr 4, 2024
1 parent 12f092f commit 9642b6f
Showing 1 changed file with 51 additions and 37 deletions.
88 changes: 51 additions & 37 deletions modifiers/allocation/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,34 @@ def defined_allocation_options(expander):
val = AllocOpt.as_type(alloc_opt, var_def)
except ValueError:
continue
defined[alloc_opt] = val

if val is not None:
defined[alloc_opt] = val

return defined


class AttrDict(dict):
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self['_attributes'] = set()

def __getattr__(self, *args, **kwargs):
return self.__getitem__(*args, **kwargs)

def __setattr__(self, *args, **kwargs):
self.__setitem__(*args, **kwargs)

def __delattr__(self, *args, **kwargs):
self.__delitem__(*args, **kwargs)

def __setitem__(self, key, value):
super().__setitem__(key, value)
if key != "_attributes":
self['_attributes'].add(key)

def defined(self):
return list((k, self[k]) for k in self['_attributes'])

class Allocation(BasicModifier):

Expand All @@ -86,14 +104,15 @@ def inherit_from_application(self, app):

var_defs = defined_allocation_options(app.expander)

# Calculate unset values (e.g. determine n_nodes if not set)
self.determine_allocation(var_defs)
with Allocation.as_attrs(var_defs) as v:
# Calculate unset values (e.g. determine n_nodes if not set)
self.determine_allocation(v)

self.determine_scheduler_instructions(var_defs)
self.determine_scheduler_instructions(v)

# Definitions
for var, val in var_defs.items():
app.define_variable(var.name.lower(), str(val))
# Definitions
for var, val in v.defined():
app.define_variable(var, str(val))

@staticmethod
@contextmanager
Expand All @@ -104,39 +123,33 @@ def as_attrs(var_defs):

yield v

for alloc_opt in AllocOpt:
local_val = getattr(v, alloc_opt.name.lower(), None)
if (alloc_opt not in var_defs) and local_val:
var_defs[alloc_opt] = local_val
def determine_allocation(self, v):
if not v.n_ranks:
if v.n_ranks_per_node and v.n_nodes:
v.n_ranks = v.n_nodes * v.n_ranks_per_node

def determine_allocation(self, var_defs):
with Allocation.as_attrs(var_defs) as v:
if not v.n_ranks:
if v.n_ranks_per_node and v.n_nodes:
v.n_ranks = v.n_nodes * v.n_ranks_per_node

if not v.n_nodes:
if v.n_ranks:
cpus_request_per_task = 1
multi_cpus_per_task = n_cores_per_task or n_threads or 0
cpus_request_per_task = max(multi_cpus_per_task, 1)
tasks_per_node = math.floor(cpus_per_node / cpus_request_per_task)
v.n_nodes = math.ceil(v.n_ranks / tasks_per_node)
if v.n_gpus and v.gpus_per_node:
v.n_nodes = math.ceil(v.n_gpus / float(v.gpus_per_node))

if not v.n_threads:
v.n_threads = 1
if not v.n_nodes:
if v.n_ranks:
cpus_request_per_task = 1
multi_cpus_per_task = v.n_cores_per_task or v.n_threads or 0
cpus_request_per_task = max(multi_cpus_per_task, 1)
tasks_per_node = math.floor(cpus_per_node / cpus_request_per_task)
v.n_nodes = math.ceil(v.n_ranks / tasks_per_node)
if v.n_gpus and v.gpus_per_node:
v.n_nodes = math.ceil(v.n_gpus / float(v.gpus_per_node))

if not v.n_threads:
v.n_threads = 1

def slurm_instructions(self, v):
srun_opts = []

if v.n_ranks:
srun_opts.append(f"-n {v.n_ranks}")
if v.n_gpus:
srun_opts.append(f"-n {v.n_gpus}")
srun_opts.append(f"--gpus {v.n_gpus}")
if v.n_nodes:
srun_opts.append(f"-n {v.n_nodes}")
srun_opts.append(f"-N {v.n_nodes}")

sbatch_directives = list(f"# SBATCH {x}" for x in srun_opts)

Expand All @@ -152,11 +165,12 @@ def mpi_instructions(self, v):
v.batch_submit = "{execute_experiment}"
v.allocation_directives = ""

def determine_scheduler_instructions(self, var_defs):
def determine_scheduler_instructions(self, v):
handler = {
"slurm": self.slurm_instructions,
"flux": self.flux_instructions,
"mpi": self.mpi_instructions
}
with Allocation.as_attrs(var_defs) as v:
handler[v.scheduler](v)
if v.scheduler not in handler:
raise ValueError(f"scheduler ({v.scheduler}) must be one of : " + " ".join(handler.keys()))
handler[v.scheduler](v)

0 comments on commit 9642b6f

Please sign in to comment.