class _MissingAttribute(AttributeError, KeyError):
    """Raised when an AttrDict has no entry for the requested name.

    Deriving from AttributeError makes ``hasattr``, ``getattr(obj, name,
    default)`` and ``copy.deepcopy`` behave normally; deriving from KeyError
    keeps callers that catch KeyError on attribute access working.
    """


class AttrDict(dict):
    """A dict whose entries can also be read, written and deleted as attributes.

    Keys assigned after construction (through item or attribute assignment)
    are recorded in a set stored under the reserved ``'_attributes'`` key.
    Entries passed to the constructor are NOT recorded, because
    ``dict.__init__`` does not go through ``__setitem__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self['_attributes'] = set()

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so real dict
        # methods are never shadowed by entries.
        try:
            return self[name]
        except KeyError:
            raise _MissingAttribute(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise _MissingAttribute(name) from None

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        if key != "_attributes":
            # setdefault recreates the tracking set if it was removed
            # (e.g. by clear()) instead of failing with a KeyError.
            super().setdefault('_attributes', set()).add(key)

    def __delitem__(self, key):
        super().__delitem__(key)
        if key != "_attributes":
            # Keep the tracking set in sync so defined() never looks up a
            # key that no longer exists.
            self.get('_attributes', set()).discard(key)

    def defined(self):
        """Return ``(key, value)`` pairs for every tracked key still present.

        Pairs come back in the dict's insertion order, so the result is
        deterministic from run to run (iterating the tracking set directly
        is not).
        """
        tracked = self.get('_attributes', ())
        return [(key, value) for key, value in self.items() if key in tracked]
def determine_allocation(self, v):
    """Fill in allocation values that were not set explicitly.

    ``v`` is an attribute-style view of the allocation options and is
    updated in place:

    * ``n_ranks`` is derived from ``n_nodes * n_ranks_per_node`` when unset
      and both of those are set.
    * ``n_nodes`` is derived, when unset, from the GPU request
      (``n_gpus`` / ``gpus_per_node``) if both are set, otherwise from the
      rank request and the per-node CPU count.
    * ``n_threads`` defaults to 1.

    Raises:
        ValueError: if ``n_nodes`` has to be derived from ``n_ranks`` but
            the per-node CPU count is unset or too small to host one task.
    """
    if not v.n_ranks and v.n_ranks_per_node and v.n_nodes:
        v.n_ranks = v.n_nodes * v.n_ranks_per_node

    if not v.n_nodes:
        if v.n_gpus and v.gpus_per_node:
            # The GPU-based size always took precedence over the rank-based
            # one (it was assigned last), so compute it directly.
            v.n_nodes = math.ceil(v.n_gpus / v.gpus_per_node)
        elif v.n_ranks:
            # NOTE(review): the previous code read a bare, undefined
            # `cpus_per_node` (NameError). Reading it from `v` mirrors
            # `v.gpus_per_node` - confirm the option name.
            cpus_per_node = v.cpus_per_node
            if not cpus_per_node:
                raise ValueError(
                    "cpus_per_node must be set to determine n_nodes "
                    "from n_ranks"
                )
            cpus_request_per_task = max(
                v.n_cores_per_task or v.n_threads or 0, 1
            )
            tasks_per_node = math.floor(cpus_per_node / cpus_request_per_task)
            if tasks_per_node < 1:
                raise ValueError(
                    f"each task requests {cpus_request_per_task} cpus but "
                    f"cpus_per_node is only {cpus_per_node}"
                )
            v.n_nodes = math.ceil(v.n_ranks / tasks_per_node)

    if not v.n_threads:
        v.n_threads = 1
def determine_scheduler_instructions(self, v):
    """Run the instruction builder that matches ``v.scheduler``.

    The builder is called with ``v`` as its only argument.

    Raises:
        ValueError: if ``v.scheduler`` is not "slurm", "flux" or "mpi".
    """
    builders = {
        "slurm": self.slurm_instructions,
        "flux": self.flux_instructions,
        "mpi": self.mpi_instructions,
    }

    if v.scheduler in builders:
        builders[v.scheduler](v)
        return

    supported = " ".join(builders)
    raise ValueError(f"scheduler ({v.scheduler}) must be one of : " + supported)