diff --git a/src/aiidalab_qe/app/result/components/status/status.py b/src/aiidalab_qe/app/result/components/status/status.py index c86fea262..7deec95ef 100644 --- a/src/aiidalab_qe/app/result/components/status/status.py +++ b/src/aiidalab_qe/app/result/components/status/status.py @@ -118,54 +118,29 @@ def __init__(self, node, level=0, **kwargs): self.uuid = node.uuid self.level = level self.label = ipw.HTML(self._humanize_title(node)) - self.state = "" self.emoji = ipw.HTML() - self.status = ipw.HTML() - self.inspect = ipw.Button( - description="Inspect", - button_style="info", - layout=ipw.Layout(width="fit-content", margin="0 0 0 5px"), - ) - self.pks = set() + self.state = ipw.HTML() self.title = ipw.HBox( children=[ ipw.HTML(self._get_indentation(level)), self.emoji, self.label, - self.status, - self.inspect if isinstance(node, orm.CalcJobNode) else ipw.HTML(), + ipw.HTML(" | "), + self.state, ], layout=ipw.Layout(align_items="center"), ) - self.branches = ipw.VBox() super().__init__( children=[ self.title, - self.branches, ], **kwargs, ) - def update(self): - node = orm.load_node(self.uuid) - self._add_children(node) - self.state = self._get_state(node) - self.emoji.value = self._get_emoji(self.state) - self.status.value = self._get_status(node) - for branch in self.branches.children: - if isinstance(branch, TreeNode): - branch.update() - - def _add_children(self, node): - for child in node.called: - if child.pk in self.pks: - continue - if child.process_label == "BandsWorkChain": - self._add_children(child) - else: - branch = TreeNode(child, level=self.level + 1) - self.branches.children += (branch,) - self.pks.add(child.pk) + def update(self, node=None): + node = node or orm.load_node(self.uuid) + self.state.value = self._get_state(node) + self.emoji.value = self._get_emoji(self.state.value) def _get_indentation(self, level=0): return " " * 8 * level @@ -180,28 +155,6 @@ def _get_emoji(self, state): "excepted": "❌", }.get(state, "❓") - def _get_status(self, node): - return f"({self._get_tally(node)}{self.state})" - - def _get_tally(self, node): - if not isinstance(node, orm.WorkflowNode): - return "" - inputs = node.get_metadata_inputs() - processes = [key for key in inputs.keys() if key != "metadata"] - total = len(processes) - if node.process_label == "PwBaseWorkChain" and "kpoints" not in node.inputs: - total += 1 # k-point grid generation - if node.process_label == "PwBandsWorkChain": - total += 1 # high-symmetry k-point generation - finished = len( - [ - child.process_state - for child in node.called - if child.process_state is ProcessState.FINISHED - ] - ) - return f"{finished}/{total} job{'s' if total > 1 else ''}; " - def _get_state(self, node): if not hasattr(node, "process_state"): return "queued" @@ -235,13 +188,81 @@ def _humanize_title(self, node): return mappings.get(title, title) +class WorkChainTreeNode(TreeNode): + def __init__(self, node, **kwargs): + super().__init__(node, **kwargs) + self.tally = ipw.HTML() + self.title.children += (ipw.HTML(" | "),) + self.title.children += (self.tally,) + self.pks = set() + self.branches = ipw.VBox() + self.children += (self.branches,) + + def update(self, node=None): + node = node or orm.load_node(self.uuid) + super().update(node) + self.tally.value = self._get_tally(node) + self._add_children(node) + branch: TreeNode + for branch in self.branches.children: + branch.update() + + def _add_children(self, node): + for child in node.called: + if child.pk in self.pks: + continue + if child.process_label == "BandsWorkChain": + self._add_children(child) + else: + TreeNodeClass = ( + WorkChainTreeNode + if isinstance(child, orm.WorkflowNode) + else CalculationTreeNode + ) + branch = TreeNodeClass(child, level=self.level + 1) + self.branches.children += (branch,) + self.pks.add(child.pk) + + def _get_tally(self, node): + inputs = node.get_metadata_inputs() + processes = [key for key in inputs.keys() if key != "metadata"] + total = len(processes) + if node.process_label == "PwBaseWorkChain" and "kpoints" not in node.inputs: + total += 1 # k-point grid generation + if node.process_label == "PwBandsWorkChain": + total += 1 # high-symmetry k-point generation + finished = len( + [ + child.process_state + for child in node.called + if child.process_state is ProcessState.FINISHED + ] + ) + return f"{finished}/{total} job{'s' if total > 1 else ''}" + + +class CalculationTreeNode(TreeNode): + def __init__(self, node, **kwargs): + super().__init__(node, **kwargs) + self.inspect = ipw.Button( + description="Inspect", + button_style="info", + layout=ipw.Layout(width="fit-content", margin="0 0 0 5px"), + ) + self.title.children += (self.inspect,) + + class SimplifiedProcessTreeModel(Model, HasProcess): """""" class SimplifiedProcessTree(ipw.VBox): def __init__(self, model: SimplifiedProcessTreeModel, **kwargs): - super().__init__(**kwargs) + self.loading_message = LoadingWidget("Loading process tree") + super().__init__( + children=[self.loading_message], + **kwargs, + ) self.add_class("simplified-process-tree") self._model = model self._model.observe( @@ -258,7 +279,7 @@ def render(self): if self.rendered: return root = self._model.fetch_process_node() - self.trunk = TreeNode(root) + self.trunk = WorkChainTreeNode(root) self.rendered = True self._update() self.children = [self.trunk]