From 134dd8251f77c91bec39e4f405cbc1cae9296bda Mon Sep 17 00:00:00 2001 From: Donald Campbell <125581724+donaldcampbelljr@users.noreply.github.com> Date: Wed, 19 Jun 2024 10:40:27 -0400 Subject: [PATCH] add recursive functions #334 --- looper/conductor.py | 18 +++++++++--------- looper/pipeline_interface.py | 11 ++--------- looper/utils.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 18 deletions(-) diff --git a/looper/conductor.py b/looper/conductor.py index 5f71c9be..93598431 100644 --- a/looper/conductor.py +++ b/looper/conductor.py @@ -28,7 +28,11 @@ from .const import * from .exceptions import JobSubmissionException from .processed_project import populate_sample_paths -from .utils import fetch_sample_flags, jinja_render_template_strictly +from .utils import ( + fetch_sample_flags, + jinja_render_template_strictly, + expand_nested_var_templates, +) from .const import PipelineLevel @@ -717,14 +721,10 @@ def write_script(self, pool, size): _LOGGER.debug(f"namespace pipelines: { pl_iface }") namespaces["pipeline"]["var_templates"] = pl_iface[VAR_TEMPL_KEY] or {} - for k, v in namespaces["pipeline"]["var_templates"].items(): - if isinstance(v, dict): - for key, value in v.items(): - namespaces["pipeline"]["var_templates"][k][key] = ( - jinja_render_template_strictly(value, namespaces) - ) - else: - namespaces["pipeline"]["var_templates"][k] = expandpath(v) + + namespaces["pipeline"]["var_templates"] = expand_nested_var_templates( + namespaces["pipeline"]["var_templates"], namespaces + ) # pre_submit hook namespace updates namespaces = _exec_pre_submit(pl_iface, namespaces) diff --git a/looper/pipeline_interface.py b/looper/pipeline_interface.py index 8674f3bb..1064f20f 100644 --- a/looper/pipeline_interface.py +++ b/looper/pipeline_interface.py @@ -17,7 +17,7 @@ InvalidResourceSpecificationException, PipelineInterfaceConfigError, ) -from .utils import jinja_render_template_strictly +from .utils import jinja_render_template_strictly, render_nested_var_templates __author__ = "Michal Stolarczyk" __email__ = "michal@virginia.edu" @@ -89,14 +89,7 @@ def render_var_templates(self, namespaces): var_templates = {} if curr_data: var_templates.update(curr_data) - for k, v in var_templates.items(): - if isinstance(v, dict): - for key, value in v.items(): - var_templates[k][key] = jinja_render_template_strictly( - value, namespaces - ) - else: - var_templates[k] = jinja_render_template_strictly(v, namespaces) + var_templates = render_nested_var_templates(var_templates, namespaces) return var_templates def get_pipeline_schemas(self, schema_key=INPUT_SCHEMA_KEY): diff --git a/looper/utils.py b/looper/utils.py index fce5fa88..a1b5f67a 100644 --- a/looper/utils.py +++ b/looper/utils.py @@ -825,3 +825,33 @@ def inspect_looper_config_file(looper_config_dict) -> None: print("LOOPER INSPECT") for key, value in looper_config_dict.items(): print(f"{key} {value}") + + +def expand_nested_var_templates(var_templates_dict, namespaces): + + "Takes all var_templates as a dict and recursively expands any paths." + + result = {} + + for k, v in var_templates_dict.items(): + if isinstance(v, dict): + result[k] = expand_nested_var_templates(v, namespaces) + else: + result[k] = expandpath(v) + + return result + + +def render_nested_var_templates(var_templates_dict, namespaces): + + "Takes all var_templates as a dict and recursively renders the jinja templates." + + result = {} + + for k, v in var_templates_dict.items(): + if isinstance(v, dict): + result[k] = expand_nested_var_templates(v, namespaces) + else: + result[k] = jinja_render_template_strictly(v, namespaces) + + return result