From f2dea76d7d02762b42426108c9b2f2524d0d39d1 Mon Sep 17 00:00:00 2001 From: Edan Bainglass Date: Mon, 6 Jan 2025 07:53:48 +0000 Subject: [PATCH] Make total jobs count restart-aware --- .../app/result/components/status/tree.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/aiidalab_qe/app/result/components/status/tree.py b/src/aiidalab_qe/app/result/components/status/tree.py index 46cf4b1dc..5d079a3ef 100644 --- a/src/aiidalab_qe/app/result/components/status/tree.py +++ b/src/aiidalab_qe/app/result/components/status/tree.py @@ -6,7 +6,7 @@ import traitlets as tl from aiida import orm -from aiida.engine import ProcessState +from aiida.engine import BaseRestartWorkChain, ProcessState from aiidalab_qe.common.mixins import HasProcess from aiidalab_qe.common.mvc import Model from aiidalab_qe.common.widgets import LoadingWidget @@ -266,13 +266,7 @@ def _add_branches(self, node=None): self.pks.add(child.pk) def _get_tally(self, node): - inputs = node.get_metadata_inputs() - processes = [key for key in inputs.keys() if key != "metadata"] - total = len(processes) - if node.process_label == "PwBaseWorkChain" and "kpoints" not in node.inputs: - total += 1 # k-point grid generation - if node.process_label == "PwBandsWorkChain": - total += 1 # high-symmetry k-point generation + total = self._get_total(node) finished = len( [ child.process_state @@ -280,7 +274,23 @@ def _get_tally(self, node): if child.process_state is ProcessState.FINISHED ] ) - return f"{finished}/{total} job{'s' if total > 1 else ''}" + tally = f"{finished}/{total}" + tally += "*" if isinstance(node, BaseRestartWorkChain) else "" + tally += " job" if total == 1 else " jobs" + return tally + + def _get_total(self, node): + if isinstance(node, BaseRestartWorkChain): + total = len(node.called) + else: + inputs = node.get_metadata_inputs() + processes = [key for key in inputs.keys() if key != "metadata"] + total = len(processes) + if node.process_label == "PwBaseWorkChain" and "kpoints" not in node.inputs: + total += 1 # k-point grid generation + if node.process_label == "PwBandsWorkChain": + total += 1 # high-symmetry k-point generation + return total def _toggle_branches(self, _): if self.collapsed: