From 95e78e0cf3e41fafffc6cda6229b59b32d82d3ce Mon Sep 17 00:00:00 2001 From: Luik Date: Tue, 30 Jul 2024 17:46:56 +0200 Subject: [PATCH 01/24] Start adding eventsourcing to the postresql db --- biomero/__init__.py | 4 +- biomero/aggregates.py | 256 ++++++++++++++++++++++++++++++++++++++++ biomero/slurm_client.py | 126 ++++++++++++++------ pyproject.toml | 3 +- 4 files changed, 354 insertions(+), 35 deletions(-) create mode 100644 biomero/aggregates.py diff --git a/biomero/__init__.py b/biomero/__init__.py index 638c557..3e86257 100644 --- a/biomero/__init__.py +++ b/biomero/__init__.py @@ -11,4 +11,6 @@ import pkg_resources __version__ = pkg_resources.get_distribution(__package__).version except pkg_resources.DistributionNotFound: - __version__ = "Version not found" \ No newline at end of file + __version__ = "Version not found" + +from .aggregates import * \ No newline at end of file diff --git a/biomero/aggregates.py b/biomero/aggregates.py new file mode 100644 index 0000000..6ad78ec --- /dev/null +++ b/biomero/aggregates.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Torec Luik +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from eventsourcing.domain import Aggregate, event +from eventsourcing.application import Application +from uuid import UUID +from typing import Any, Dict, List +from fabric import Result +import logging + + +# Create a logger for this module +logger = logging.getLogger(__name__) + + +class ResultDict(dict): + def __init__(self, result: Result): + super().__init__() + self['command'] = result.command + self['env'] = result.env + self['stdout'] = result.stdout + self['stderr'] = result.stderr + + +# When updating Aggregate classes, take care of versioning for compatibility: +# Bump class_version and define @staticmethod upcast_vX_vY(state) +# for Aggregate (and Event(s)!) 
+# @See https://eventsourcing.readthedocs.io/en/stable/topics/domain.html#versioning + + +class WorkflowRun(Aggregate): + INITIAL_VERSION = 0 + + class WorkflowInitiated(Aggregate.Created): + name: str + description: str + user: int + group: int + + @event(WorkflowInitiated) + def __init__(self, name: str, + description: str, + user: int, + group: int): + self.name = name + self.description = description + self.user = user + self.group = group + self.tasks = [] + logger.debug(f"Initializing WorkflowRun: name={name}, description={description}, user={user}, group={group}") + + class TaskAdded(Aggregate.Event): + task_id: UUID + + @event(TaskAdded) + def add_task(self, task_id: UUID): + logger.debug(f"Adding task to WorkflowRun: task_id={task_id}") + self.tasks.append(task_id) + + class WorkflowStarted(Aggregate.Event): + pass + + @event(WorkflowStarted) + def start_workflow(self): + logger.debug(f"Starting workflow: id={self.id}") + pass + + class WorkflowCompleted(Aggregate.Event): + pass + + @event(WorkflowCompleted) + def complete_workflow(self): + logger.debug(f"Completing workflow: id={self.id}") + pass + + class WorkflowFailed(Aggregate.Event): + error_message: str + + @event(WorkflowFailed) + def fail_workflow(self, error_message: str): + logger.debug(f"Failing workflow: id={self.id}, error_message={error_message}") + pass + + +class Task(Aggregate): + INITIAL_VERSION = 0 + + class TaskCreated(Aggregate.Created): + workflow_id: UUID + task_name: str + task_version: str + input_data: Dict[str, Any] + params: Dict[str, Any] + + @event(TaskCreated) + def __init__(self, + workflow_id: UUID, + task_name: str, + task_version: str, + input_data: Dict[str, Any], + params: Dict[str, Any] + ): + self.workflow_id = workflow_id + self.task_name = task_name + self.task_version = task_version + self.input_data = input_data + self.params = params + self.job_ids = [] + self.results = [] + self.result_message = None + logger.debug(f"Initializing Task: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}") + + class JobIdAdded(Aggregate.Event): + job_id: str + + @event(JobIdAdded) + def add_job_id(self, job_id): + logger.debug(f"Adding job_id to Task: task_id={self.id}, job_id={job_id}") + self.job_ids.append(job_id) + + class ResultAdded(Aggregate.Event): + result: ResultDict + + def add_result(self, result: Result): + logger.debug(f"Adding result to Task: task_id={self.id}, result={result}") + result = ResultDict(result) + self._add_result(result) + + @event(ResultAdded) + def _add_result(self, result: ResultDict): + logger.debug(f"Adding result to Task results: task_id={self.id}, result={result}") + self.results.append(result) + + class TaskStarted(Aggregate.Event): + pass + + @event(TaskStarted) + def start_task(self): + logger.debug(f"Starting task: id={self.id}") + pass + + class TaskCompleted(Aggregate.Event): + result: str + + @event(TaskCompleted) + def complete_task(self, result: str): + logger.debug(f"Completing task: id={self.id}, result={result}") + self.result_message = result + + class TaskFailed(Aggregate.Event): + error_message: str + + @event(TaskFailed) + def fail_task(self, error_message: str): + logger.debug(f"Failing task: id={self.id}, error_message={error_message}") + pass + + +class WorkflowTracker(Application): + + def initiate_workflow(self, + name: str, + description: str, + user: int, + group: int) -> UUID: + logger.debug(f"Initiating workflow: name={name}, description={description}, user={user}, group={group}") + workflow = WorkflowRun(name, 
description, user, group) + self.save(workflow) + return workflow.id + + def add_task_to_workflow(self, + workflow_id: UUID, + task_name: str, + task_version: str, + input_data: Dict[str, Any], + kwargs: Dict[str, Any] + ) -> UUID: + logger.debug(f"Adding task to workflow: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}") + + task = Task(workflow_id, + task_name, + task_version, + input_data, + kwargs) + self.save(task) + workflow = self.repository.get(workflow_id) + workflow.add_task(task.id) + self.save(workflow) + return task.id + + def start_workflow(self, workflow_id: UUID): + logger.debug(f"Starting workflow: workflow_id={workflow_id}") + + workflow = self.repository.get(workflow_id) + workflow.start_workflow() + self.save(workflow) + + def complete_workflow(self, workflow_id: UUID): + logger.debug(f"Completing workflow: workflow_id={workflow_id}") + + workflow = self.repository.get(workflow_id) + workflow.complete_workflow() + self.save(workflow) + + def fail_workflow(self, workflow_id: UUID, error_message: str): + logger.debug(f"Failing workflow: workflow_id={workflow_id}, error_message={error_message}") + + workflow = self.repository.get(workflow_id) + workflow.fail_workflow(error_message) + self.save(workflow) + + def start_task(self, task_id: UUID): + logger.debug(f"Starting task: task_id={task_id}") + + task = self.repository.get(task_id) + task.start_task() + self.save(task) + + def complete_task(self, task_id: UUID, message: str): + logger.debug(f"Completing task: task_id={task_id}, message={message}") + + task = self.repository.get(task_id) + task.complete_task(message) + self.save(task) + + def fail_task(self, task_id: UUID, error_message: str): + logger.debug(f"Failing task: task_id={task_id}, error_message={error_message}") + + task = self.repository.get(task_id) + task.fail_task(error_message) + self.save(task) + + def add_job_id(self, task_id, slurm_job_id): + logger.debug(f"Adding job_id to task: task_id={task_id}, slurm_job_id={slurm_job_id}") + + task = self.repository.get(task_id) + task.add_job_id(slurm_job_id) + self.save(task) + + def add_result(self, task_id, result): + logger.debug(f"Adding result to task: task_id={task_id}, result={result}") + + task = self.repository.get(task_id) + task.add_result(result) + self.save(task) diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index 1b5cdc2..8c80d95 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, List, Optional, Tuple, Any +from uuid import UUID from fabric import Connection, Result from fabric.transfer import Result as TransferResult from invoke.exceptions import UnexpectedExit @@ -29,6 +30,7 @@ from importlib_resources import files import io import os +from biomero.aggregates import WorkflowTracker logger = logging.getLogger(__name__) @@ -49,15 +51,14 @@ class SlurmJob: Args: submit_result (Result): The result of submitting the job. - job_id (int): The Slurm job ID. 
- - Example: + job_id (int): The Sluslurm_job_id # Submit some job with the SlurmClient - submit_result, job_id = slurmClient.run_workflow( - workflow_name, workflow_version, input_data, email, time, **kwargs) + submit_result, job_id, wf_id, task_id = slurmClient.run_workflow( + workflow_name, workflow_version, input_data, email, time, wf_id, + **kwargs) # Create a SlurmJob instance - slurmJob = SlurmJob(submit_result, job_id) + slurmJob = SlurmJob(submit_result, job_id, wf_id, task_id) if not slurmJob.ok: logger.warning(f"Error with job: {slurmJob.get_error()}") @@ -76,7 +77,9 @@ class SlurmJob: def __init__(self, submit_result: Result, - job_id: int): + job_id: int, + wf_id: UUID, + task_id: UUID): """ Initialize a SlurmJob instance. @@ -85,6 +88,8 @@ def __init__(self, job_id (int): The Slurm job ID. """ self.job_id = job_id + self.wf_id = wf_id + self.task_id = task_id self.submit_result = submit_result self.ok = self.submit_result.ok self.job_state = None @@ -221,7 +226,7 @@ class SlurmClient(Connection): Example 2: # Create a SlurmClient and setup Slurm (download containers etc.) - with SlurmClient.from_config(init_slurm=True) as client: + with SlurmClient.from_config(inislurm_job_id client.run_workflow(...) @@ -276,6 +281,7 @@ def __init__(self, slurm_script_path: str = _DEFAULT_SLURM_GIT_SCRIPT_PATH, slurm_script_repo: str = None, init_slurm: bool = False, + track_workflows: bool = True, ): """Initializes a new instance of the SlurmClient class. @@ -380,6 +386,14 @@ def __init__(self, self.init_workflows() self.validate(validate_slurm_setup=init_slurm) + + # Setup workflow tracking + self.track_workflows = track_workflows + if self.track_workflows: # use configured persistence from env + self.workflowTracker = WorkflowTracker() + else: # turn off persistence, override + self.workflowTracker = WorkflowTracker(env={ + "PERSISTENCE_MODULE": ""}) def init_workflows(self, force_update: bool = False): """ @@ -1238,8 +1252,9 @@ def run_workflow(self, input_data: str, email: Optional[str] = None, time: Optional[str] = None, + wf_id: Optional[UUID] = None, **kwargs - ) -> Tuple[Result, int]: + ) -> Tuple[Result, int, UUID, UUID]: """ Run a specified workflow on Slurm using the given parameters. @@ -1252,24 +1267,48 @@ def run_workflow(self, email (str, optional): Email address for Slurm job notifications. time (str, optional): Time limit for the Slurm job in the format HH:MM:SS. + wf_id (UUID, optional): Workflow ID for tracking purposes. If not provided, a new one is created. **kwargs: Additional keyword arguments for the workflow. Returns: - Tuple[Result, int]: - A tuple containing the result of starting the workflow job and - the Slurm job ID, or -1 if the job ID could not be extracted. + Tuple[Result, int, UUID, UUID]: + A tuple containing the result of starting the workflow job, + the Slurm job ID, the workflow ID, and the task ID. + If the Slurm job ID could not be extracted, it returns -1 for the job ID. Note: - The Slurm job ID is extracted from the result of the - `run_commands` method. - """ + The Slurm job ID is extracted from the result of the `run_commands` method. + If `track_workflows` is enabled, workflow and task tracking is performed. 
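+
+        Example:
+            # Illustrative sketch only: the workflow name, version, folder and
+            # extra parameter are hypothetical, and `slurmClient` is assumed
+            # to be an already configured SlurmClient instance.
+            res, job_id, wf_id, task_id = slurmClient.run_workflow(
+                "cellpose", "v1.2.7", "my_data_folder",
+                time="00:30:00", nuc_channel=1)
+            if job_id == -1:
+                logger.warning(f"Submission failed: {res.stderr}")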
+ """ + if not wf_id: + wf_id = self.workflowTracker.initiate_workflow( + workflow_name, + workflow_version, + -1, + -1 + ) + task_id = self.workflowTracker.add_task_to_workflow( + wf_id, + workflow_name, + workflow_version, + input_data, + kwargs) + logger.debug(f"Added new task {task_id} to workflow {wf_id}") + sbatch_cmd, sbatch_env = self.get_workflow_command( workflow_name, workflow_version, input_data, email, time, **kwargs) print(f"Running {workflow_name} job on {input_data} on Slurm:\ {sbatch_cmd} w/ {sbatch_env}") logger.info(f"Running {workflow_name} job on {input_data} on Slurm") res = self.run_commands([sbatch_cmd], sbatch_env) - return res, self.extract_job_id(res) + slurm_job_id = self.extract_job_id(res) + + if self.track_workflows and task_id: + self.workflowTracker.start_task(task_id) + self.workflowTracker.add_job_id(task_id, slurm_job_id) + self.workflowTracker.add_result(task_id, res) + + return res, slurm_job_id, wf_id, task_id def run_workflow_job(self, workflow_name: str, @@ -1277,6 +1316,7 @@ def run_workflow_job(self, input_data: str, email: Optional[str] = None, time: Optional[str] = None, + wf_id: Optional[UUID] = None, **kwargs ) -> SlurmJob: """ @@ -1288,19 +1328,23 @@ def run_workflow_job(self, input_data (str): Name of the input data folder containing input image files. email (str, optional): Email address for Slurm job notifications. time (str, optional): Time limit for the Slurm job in the format HH:MM:SS. + wf_id (UUID, optional): Workflow ID for tracking purposes. If not provided, a new one is created. **kwargs: Additional keyword arguments for the workflow. Returns: SlurmJob: A SlurmJob instance representing the started workflow job. """ - result, job_id = self.run_workflow( - workflow_name, workflow_version, input_data, email, time, **kwargs) - return SlurmJob(result, job_id) + result, job_id, wf_id, task_id = self.run_workflow( + workflow_name, workflow_version, input_data, email, time, wf_id, + **kwargs) + return SlurmJob(result, job_id, wf_id, task_id) - def run_conversion_workflow_job(self, folder_name: str, + def run_conversion_workflow_job(self, + folder_name: str, source_format: str = 'zarr', - target_format: str = 'tiff' - ) -> Tuple[Result, int]: + target_format: str = 'tiff', + wf_id: UUID = None + ) -> SlurmJob: """ Run the data conversion workflow on Slurm using the given data folder. @@ -1310,9 +1354,8 @@ def run_conversion_workflow_job(self, folder_name: str, target_format (str): Target data format after conversion (default is 'tiff'). Returns: - Tuple[Result, int]: - A tuple containing the result of starting the conversion job and - the Slurm job ID, or -1 if the job ID could not be extracted. + SlurmJob: + the conversion job Warning: The default implementation only supports conversion from 'zarr' to 'tiff'. 
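The new `wf_id` parameter exists so that a data-conversion job and the analysis job that follows it can be tracked as a single workflow run. A minimal sketch of that chaining, assuming an initialized client, an open OMERO connection `omeroConn`, and an already-uploaded data folder (the folder name, workflow name and version below are hypothetical):

    conv_job = client.run_conversion_workflow_job("my_folder", "zarr", "tiff")
    conv_job.wait_for_completion(client, omeroConn)
    wf_job = client.run_workflow_job("cellpose", "v1.2.7", "my_folder",
                                     wf_id=conv_job.wf_id)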
@@ -1324,7 +1367,7 @@ def run_conversion_workflow_job(self, folder_name: str, # Construct all commands to run consecutively data_path = f"{self.slurm_data_path}/{folder_name}" - conversion_cmd, sbatch_env = self.get_conversion_command( + conversion_cmd, sbatch_env, chosen_converter, version = self.get_conversion_command( data_path, config_file, source_format, target_format) commands = [ f"find \"{data_path}/data/in\" -name \"*.{source_format}\" | awk '{{print NR, $0}}' > \"{config_file}\"", @@ -1332,11 +1375,26 @@ def run_conversion_workflow_job(self, folder_name: str, f"echo \"Number of .{source_format} files: $N\"", conversion_cmd ] + + if not wf_id: + wf_id = self.workflowTracker.initiate_workflow( + "conversion", + -1, + -1, + -1 + ) + task_id = self.workflowTracker.add_task_to_workflow( + wf_id, + chosen_converter, + version, + data_path, + sbatch_env + ) # Run all commands consecutively res = self.run_commands(commands, sbatch_env) - - return SlurmJob(res, self.extract_job_id(res)) + + return SlurmJob(res, self.extract_job_id(res), wf_id, task_id) def extract_job_id(self, result: Result) -> int: """ @@ -1761,7 +1819,7 @@ def get_workflow_command(self, def get_conversion_command(self, data_path: str, config_file: str, source_format: str = 'zarr', - target_format: str = 'tiff') -> Tuple[str, Dict]: + target_format: str = 'tiff') -> Tuple[str, Dict, str, str]: """ Generate Slurm conversion command and environment variables for data conversion. @@ -1772,9 +1830,9 @@ def get_conversion_command(self, data_path: str, target_format (str): Target data format (default is 'tiff'). Returns: - Tuple[str, Dict]: + Tuple[str, Dict, str, str]: A tuple containing the Slurm conversion command and - the environment variables. + the environment variables, followed by the converter image name and version. Warning: The default implementation only supports conversion from 'zarr' to 'tiff'. 
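Because `get_conversion_command` now also reports which converter image and version were chosen, callers unpack four values instead of two. A sketch of the new return shape (the paths below are placeholders):

    cmd, env, converter_image, version = slurmClient.get_conversion_command(
        "/data/my-scratch/data/my_folder", "config_file.txt", "zarr", "tiff")
    logger.info(f"Converting with {converter_image} (version {version})")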
@@ -1790,11 +1848,13 @@ def get_conversion_command(self, data_path: str, f"Conversion from {source_format} to {target_format} is not supported by default!") chosen_converter = f"convert_{source_format}_to_{target_format}_latest.sif" + version = None if self.converter_images: image = self.converter_images[f"{source_format}_to_{target_format}"] version, image = self.parse_docker_image_version(image) if version: - chosen_converter = f"convert_{source_format}_to_{target_format}_{version}.sif" + chosen_converter = f"convert_{source_format}_to_{target_format}_{version}.sif" + version = version or "latest" logger.info(f"Converting with {chosen_converter}") sbatch_env = { @@ -1808,7 +1868,7 @@ def get_conversion_command(self, data_path: str, conversion_cmd = "sbatch --job-name=conversion --export=ALL,CONFIG_PATH=\"$PWD/$CONFIG_FILE\" --array=1-$N \"$SCRIPT_PATH/convert_job_array.sh\"" # conversion_cmd_waiting = "sbatch --job-name=conversion --export=ALL,CONFIG_PATH=\"$PWD/$CONFIG_FILE\" --array=1-$N --wait $SCRIPT_PATH/convert_job_array.sh" - return conversion_cmd, sbatch_env + return conversion_cmd, sbatch_env, chosen_converter, version def workflow_params_to_envvars(self, **kwargs) -> Dict: """ diff --git a/pyproject.toml b/pyproject.toml index 9d43ec4..834adad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,8 @@ dependencies = [ "requests-cache==1.1.1", "fabric==3.1.0", "paramiko==3.4.0", - "importlib_resources>=5.4.0" + "importlib_resources>=5.4.0", + "eventsourcing[crypto,postgres-dev]==9.2.22" ] [tool.setuptools.packages] From 399c6615cab03b469482323c19b9ad3795d10d36 Mon Sep 17 00:00:00 2001 From: Luik Date: Wed, 31 Jul 2024 14:41:06 +0200 Subject: [PATCH 02/24] Store full workflows --- biomero/aggregates.py | 16 ++++++++++++++++ biomero/slurm_client.py | 9 +++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/biomero/aggregates.py b/biomero/aggregates.py index 6ad78ec..fad96d9 100644 --- a/biomero/aggregates.py +++ b/biomero/aggregates.py @@ -119,6 +119,7 @@ def __init__(self, self.job_ids = [] self.results = [] self.result_message = None + self.status = None logger.debug(f"Initializing Task: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}") class JobIdAdded(Aggregate.Event): @@ -128,6 +129,14 @@ class JobIdAdded(Aggregate.Event): def add_job_id(self, job_id): logger.debug(f"Adding job_id to Task: task_id={self.id}, job_id={job_id}") self.job_ids.append(job_id) + + class StatusUpdated(Aggregate.Event): + status: str + + @event(StatusUpdated) + def update_task_status(self, status): + logger.debug(f"Adding status to Task: task_id={self.id}, status={status}") + self.status = status class ResultAdded(Aggregate.Event): result: ResultDict @@ -254,3 +263,10 @@ def add_result(self, task_id, result): task = self.repository.get(task_id) task.add_result(result) self.save(task) + + def update_task_status(self, task_id, status): + logger.debug(f"Adding status to task: task_id={task_id}, status={status}") + + task = self.repository.get(task_id) + task.update_task_status(status) + self.save(task) diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index 8c80d95..d57ec3e 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -74,12 +74,14 @@ class SlurmJob: raise e """ + SLURM_POLLING_INTERVAL = 10 # seconds def __init__(self, submit_result: Result, job_id: int, wf_id: UUID, - task_id: UUID): + task_id: UUID, + slurm_polling_interval: int = SLURM_POLLING_INTERVAL): """ Initialize a SlurmJob instance. 
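With the polling interval now an argument instead of a hard-coded 10 seconds, a job can be set up to poll Slurm less aggressively. A brief sketch, assuming an existing submission `Result`, the tracking IDs, a client and an OMERO connection:

    job = SlurmJob(submit_result, job_id, wf_id, task_id,
                   slurm_polling_interval=30)  # check status every 30s
    job.wait_for_completion(slurmClient, omeroConn)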
@@ -90,6 +92,7 @@ def __init__(self, self.job_id = job_id self.wf_id = wf_id self.task_id = task_id + self.slurm_polling_interval = slurm_polling_interval self.submit_result = submit_result self.ok = self.submit_result.ok self.job_state = None @@ -124,7 +127,9 @@ def wait_for_completion(self, slurmClient, omeroConn) -> str: self.job_state = job_status_dict[self.job_id] # wait for 10 seconds before checking again omeroConn.keepAlive() # keep the OMERO connection alive - timesleep.sleep(10) + slurmClient.workflowTracker.update_task_status(self.task_id, + self.job_state) + timesleep.sleep(self.slurm_polling_interval) logger.info(f"Job {self.job_id} finished: {self.job_state}") logger.info( f"You can get the logfile using `Slurm Get Update` on job {self.job_id}") From 1557f45399f91e67e055361b9366327190564cb1 Mon Sep 17 00:00:00 2001 From: Luik Date: Tue, 6 Aug 2024 16:09:22 +0200 Subject: [PATCH 03/24] Add a view table for user-job view --- biomero/__init__.py | 2 +- biomero/{aggregates.py => eventsourcing.py} | 227 ++++++++++++++++++-- biomero/slurm_client.py | 13 +- pyproject.toml | 3 +- 4 files changed, 218 insertions(+), 27 deletions(-) rename biomero/{aggregates.py => eventsourcing.py} (52%) diff --git a/biomero/__init__.py b/biomero/__init__.py index 3e86257..e0a2329 100644 --- a/biomero/__init__.py +++ b/biomero/__init__.py @@ -13,4 +13,4 @@ except pkg_resources.DistributionNotFound: __version__ = "Version not found" -from .aggregates import * \ No newline at end of file +from .eventsourcing import * \ No newline at end of file diff --git a/biomero/aggregates.py b/biomero/eventsourcing.py similarity index 52% rename from biomero/aggregates.py rename to biomero/eventsourcing.py index fad96d9..c8cc7f6 100644 --- a/biomero/aggregates.py +++ b/biomero/eventsourcing.py @@ -12,17 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os from eventsourcing.domain import Aggregate, event -from eventsourcing.application import Application -from uuid import UUID +from eventsourcing.application import Application, AggregateNotFound +from eventsourcing.system import ProcessApplication +from eventsourcing.dispatch import singledispatchmethod +from uuid import NAMESPACE_URL, UUID, uuid5 from typing import Any, Dict, List from fabric import Result import logging +from sqlalchemy import create_engine, text, Column, Integer, String, URL +from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy.exc import IntegrityError +from sqlalchemy.schema import CreateTable + # Create a logger for this module logger = logging.getLogger(__name__) +# -------------------- DOMAIN MODEL -------------------- # + class ResultDict(dict): def __init__(self, result: Result): @@ -49,8 +59,8 @@ class WorkflowInitiated(Aggregate.Created): group: int @event(WorkflowInitiated) - def __init__(self, name: str, - description: str, + def __init__(self, name: str, + description: str, user: int, group: int): self.name = name @@ -104,9 +114,9 @@ class TaskCreated(Aggregate.Created): params: Dict[str, Any] @event(TaskCreated) - def __init__(self, - workflow_id: UUID, - task_name: str, + def __init__(self, + workflow_id: UUID, + task_name: str, task_version: str, input_data: Dict[str, Any], params: Dict[str, Any] @@ -129,7 +139,7 @@ class JobIdAdded(Aggregate.Event): def add_job_id(self, job_id): logger.debug(f"Adding job_id to Task: task_id={self.id}, job_id={job_id}") self.job_ids.append(job_id) - + class StatusUpdated(Aggregate.Event): status: str @@ -176,11 +186,32 @@ def fail_task(self, error_message: str): pass +class JobAccount(Aggregate): + INITIAL_VERSION = 0 + + def __init__(self, user_id, group_id): + self.user_id = user_id + self.group_id = group_id + self.jobs = [] + + @classmethod + def create_id(cls, user_id, group_id): + return uuid5(NAMESPACE_URL, f'/jobaccount/{group_id}/{user_id}') + + @event('JobAdded') + def add_job(self, job_id): + logger.debug(f"Adding job: id={self.id}, job={job_id}, user={self.user_id}, group={self.group_id}") + self.jobs.append(job_id) + + +# -------------------- APPLICATIONS -------------------- # + + class WorkflowTracker(Application): - def initiate_workflow(self, - name: str, - description: str, + def initiate_workflow(self, + name: str, + description: str, user: int, group: int) -> UUID: logger.debug(f"Initiating workflow: name={name}, description={description}, user={user}, group={group}") @@ -188,8 +219,8 @@ def initiate_workflow(self, self.save(workflow) return workflow.id - def add_task_to_workflow(self, - workflow_id: UUID, + def add_task_to_workflow(self, + workflow_id: UUID, task_name: str, task_version: str, input_data: Dict[str, Any], @@ -197,8 +228,8 @@ def add_task_to_workflow(self, ) -> UUID: logger.debug(f"Adding task to workflow: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}") - task = Task(workflow_id, - task_name, + task = Task(workflow_id, + task_name, task_version, input_data, kwargs) @@ -228,7 +259,7 @@ def fail_workflow(self, workflow_id: UUID, error_message: str): workflow = self.repository.get(workflow_id) workflow.fail_workflow(error_message) self.save(workflow) - + def start_task(self, task_id: UUID): logger.debug(f"Starting task: task_id={task_id}") @@ -249,24 +280,178 @@ def fail_task(self, task_id: UUID, error_message: str): task = self.repository.get(task_id) task.fail_task(error_message) self.save(task) - + def add_job_id(self, 
task_id, slurm_job_id): logger.debug(f"Adding job_id to task: task_id={task_id}, slurm_job_id={slurm_job_id}") task = self.repository.get(task_id) task.add_job_id(slurm_job_id) self.save(task) - + def add_result(self, task_id, result): logger.debug(f"Adding result to task: task_id={task_id}, result={result}") - + task = self.repository.get(task_id) task.add_result(result) self.save(task) - + def update_task_status(self, task_id, status): logger.debug(f"Adding status to task: task_id={task_id}, status={status}") - + task = self.repository.get(task_id) task.update_task_status(status) self.save(task) + + +#--------------------- VIEWS ---------------------------- # + +# Base class for declarative class definitions +Base = declarative_base() + + +class JobView(Base): + __tablename__ = 'biomero_job_view' + + slurm_job_id = Column(Integer, primary_key=True) + user = Column(Integer, nullable=False) + group = Column(Integer, nullable=False) + + +class JobAccounting(ProcessApplication): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Read database configuration from environment variables + database_url = URL.create( + drivername="postgresql+psycopg2", + username=os.getenv('POSTGRES_USER'), + password=os.getenv('POSTGRES_PASSWORD'), + host=os.getenv('POSTGRES_HOST', 'localhost'), + port=os.getenv('POSTGRES_PORT', 5432), + database=os.getenv('POSTGRES_DBNAME') + ) + + # Set up SQLAlchemy engine and session + self.engine = create_engine(database_url) + self.SessionLocal = sessionmaker(bind=self.engine) + + # State tracking + self.workflows = {} # {wf_id: {"user": user, "group": group}} + self.tasks = {} # {task_id: wf_id} + self.jobs = {} # {job_id: (task_id, user, group)} + + # Create defined tables (subclasses of Base) if they don't exist + Base.metadata.create_all(self.engine) + + @singledispatchmethod + def policy(self, domain_event, process_event): + """Default policy""" + + @policy.register(WorkflowRun.WorkflowInitiated) + def _(self, domain_event, process_event): + """Handle WorkflowInitiated event""" + user = domain_event.user + group = domain_event.group + wf_id = domain_event.originator_id + + # Track workflow + self.workflows[wf_id] = {"user": user, "group": group} + logger.debug(f"Workflow initiated: wf_id={wf_id}, user={user}, group={group}") + + # Optionally, persist this state if needed + # Optionally, add an event to do that, then save via collect + # process_event.collect_events(jobaccount, wfView) + + @policy.register(WorkflowRun.TaskAdded) + def _(self, domain_event, process_event): + """Handle TaskAdded event""" + task_id = domain_event.task_id + wf_id = domain_event.originator_id + + # Track task + self.tasks[task_id] = wf_id + logger.debug(f"Task added: task_id={task_id}, wf_id={wf_id}") + + # Optionally, persist this state if needed + # use .collect_events(agg) instead of .save(agg) + # process_event.collect_events(taskView) + + @policy.register(Task.JobIdAdded) + def _(self, domain_event, process_event): + """Handle JobIdAdded event""" + # Grab event payload + job_id = domain_event.job_id + task_id = domain_event.originator_id + + # Find workflow and user/group for the task + wf_id = self.tasks.get(task_id) + if wf_id: + workflow_info = self.workflows.get(wf_id) + if workflow_info: + user = workflow_info["user"] + group = workflow_info["group"] + + # Track job + self.jobs[job_id] = (task_id, user, group) + logger.debug(f"Job added: job_id={job_id}, task_id={task_id}, user={user}, group={group}") + + + # Update view table + 
self.update_view_table(job_id, user, group) + else: + logger.debug(f"JobIdAdded event ignored: task_id={task_id} not found in tasks") + + # use .collect_events(agg) instead of .save(agg) + # process_event.collect_events(jobaccount) + + def update_view_table(self, job_id, user, group): + """Update the view table with new job information.""" + with self.SessionLocal() as session: + try: + new_job = JobView(slurm_job_id=job_id, user=user, group=group) + session.add(new_job) + session.commit() + logger.debug(f"Inserted job into view table: job_id={job_id}, user={user}, group={group}") + except IntegrityError: + session.rollback() + # Handle the case where the job already exists in the table if necessary + logger.error(f"Failed to insert job into view table (already exists?): job_id={job_id}, user={user}, group={group}") + + def get_jobs(self, user=None, group=None): + """Retrieve jobs for a specific user and/or group. + + Parameters: + - user (int, optional): The user ID to filter by. + - group (int, optional): The group ID to filter by. + + Returns: + - Dictionary of user IDs to lists of job IDs if no user is specified. + - Dictionary with a single user ID key and a list of job IDs if user is specified. + + Raises: + - ValueError: If neither user nor group is provided. + """ + if user is None and group is None: + # Retrieve all jobs grouped by user + with self.SessionLocal() as session: + jobs = session.query(JobView.user, JobView.slurm_job_id).all() + user_jobs = {} + for user_id, job_id in jobs: + if user_id not in user_jobs: + user_jobs[user_id] = [] + user_jobs[user_id].append(job_id) + return user_jobs + else: + with self.SessionLocal() as session: + query = session.query(JobView.slurm_job_id) + + if user is not None: + query = query.filter_by(user=user) + + if group is not None: + query = query.filter_by(group=group) + + jobs = query.all() + result = {user: [job.slurm_job_id for job in jobs]} + logger.debug(f"Retrieved jobs for user={user} and group={group}: {result}") + return result diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index d57ec3e..e813c6c 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -30,7 +30,8 @@ from importlib_resources import files import io import os -from biomero.aggregates import WorkflowTracker +from biomero.eventsourcing import WorkflowTracker, JobAccounting +from eventsourcing.system import System, SingleThreadedRunner logger = logging.getLogger(__name__) @@ -392,13 +393,17 @@ def __init__(self, self.init_workflows() self.validate(validate_slurm_setup=init_slurm) - # Setup workflow tracking + # Setup workflow tracking and accounting self.track_workflows = track_workflows + system = System(pipes=[[WorkflowTracker, JobAccounting]]) if self.track_workflows: # use configured persistence from env - self.workflowTracker = WorkflowTracker() + runner = SingleThreadedRunner(system) else: # turn off persistence, override - self.workflowTracker = WorkflowTracker(env={ + runner = SingleThreadedRunner(system, env={ "PERSISTENCE_MODULE": ""}) + runner.start() + self.workflowTracker = runner.get(WorkflowTracker) + self.jobAccounting = runner.get(JobAccounting) def init_workflows(self, force_update: bool = False): """ diff --git a/pyproject.toml b/pyproject.toml index 834adad..0897d44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,8 @@ dependencies = [ "fabric==3.1.0", "paramiko==3.4.0", "importlib_resources>=5.4.0", - "eventsourcing[crypto,postgres-dev]==9.2.22" + "eventsourcing[crypto,postgres-dev]==9.2.22", + 
"sqlalchemy==2.0.32" ] [tool.setuptools.packages] From 389bafcc42a3b30cfe97d4801620a377b9b43a28 Mon Sep 17 00:00:00 2001 From: Luik Date: Wed, 7 Aug 2024 12:13:33 +0200 Subject: [PATCH 04/24] Split views; Add jobid tracking for conversion --- biomero/__init__.py | 3 +- biomero/eventsourcing.py | 180 +-------------------------------------- biomero/slurm_client.py | 12 ++- biomero/views.py | 167 ++++++++++++++++++++++++++++++++++++ pyproject.toml | 8 +- 5 files changed, 186 insertions(+), 184 deletions(-) create mode 100644 biomero/views.py diff --git a/biomero/__init__.py b/biomero/__init__.py index e0a2329..40c84cf 100644 --- a/biomero/__init__.py +++ b/biomero/__init__.py @@ -13,4 +13,5 @@ except pkg_resources.DistributionNotFound: __version__ = "Version not found" -from .eventsourcing import * \ No newline at end of file +from .eventsourcing import * +from .views import * \ No newline at end of file diff --git a/biomero/eventsourcing.py b/biomero/eventsourcing.py index c8cc7f6..e288999 100644 --- a/biomero/eventsourcing.py +++ b/biomero/eventsourcing.py @@ -14,18 +14,11 @@ # limitations under the License. import os from eventsourcing.domain import Aggregate, event -from eventsourcing.application import Application, AggregateNotFound -from eventsourcing.system import ProcessApplication -from eventsourcing.dispatch import singledispatchmethod +from eventsourcing.application import Application from uuid import NAMESPACE_URL, UUID, uuid5 from typing import Any, Dict, List from fabric import Result import logging -from sqlalchemy import create_engine, text, Column, Integer, String, URL -from sqlalchemy.orm import sessionmaker, declarative_base -from sqlalchemy.exc import IntegrityError -from sqlalchemy.schema import CreateTable - # Create a logger for this module @@ -186,24 +179,6 @@ def fail_task(self, error_message: str): pass -class JobAccount(Aggregate): - INITIAL_VERSION = 0 - - def __init__(self, user_id, group_id): - self.user_id = user_id - self.group_id = group_id - self.jobs = [] - - @classmethod - def create_id(cls, user_id, group_id): - return uuid5(NAMESPACE_URL, f'/jobaccount/{group_id}/{user_id}') - - @event('JobAdded') - def add_job(self, job_id): - logger.debug(f"Adding job: id={self.id}, job={job_id}, user={self.user_id}, group={self.group_id}") - self.jobs.append(job_id) - - # -------------------- APPLICATIONS -------------------- # @@ -303,155 +278,4 @@ def update_task_status(self, task_id, status): self.save(task) -#--------------------- VIEWS ---------------------------- # - -# Base class for declarative class definitions -Base = declarative_base() - - -class JobView(Base): - __tablename__ = 'biomero_job_view' - - slurm_job_id = Column(Integer, primary_key=True) - user = Column(Integer, nullable=False) - group = Column(Integer, nullable=False) - - -class JobAccounting(ProcessApplication): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # Read database configuration from environment variables - database_url = URL.create( - drivername="postgresql+psycopg2", - username=os.getenv('POSTGRES_USER'), - password=os.getenv('POSTGRES_PASSWORD'), - host=os.getenv('POSTGRES_HOST', 'localhost'), - port=os.getenv('POSTGRES_PORT', 5432), - database=os.getenv('POSTGRES_DBNAME') - ) - - # Set up SQLAlchemy engine and session - self.engine = create_engine(database_url) - self.SessionLocal = sessionmaker(bind=self.engine) - - # State tracking - self.workflows = {} # {wf_id: {"user": user, "group": group}} - self.tasks = {} # {task_id: wf_id} - self.jobs = {} 
# {job_id: (task_id, user, group)} - - # Create defined tables (subclasses of Base) if they don't exist - Base.metadata.create_all(self.engine) - - @singledispatchmethod - def policy(self, domain_event, process_event): - """Default policy""" - - @policy.register(WorkflowRun.WorkflowInitiated) - def _(self, domain_event, process_event): - """Handle WorkflowInitiated event""" - user = domain_event.user - group = domain_event.group - wf_id = domain_event.originator_id - - # Track workflow - self.workflows[wf_id] = {"user": user, "group": group} - logger.debug(f"Workflow initiated: wf_id={wf_id}, user={user}, group={group}") - - # Optionally, persist this state if needed - # Optionally, add an event to do that, then save via collect - # process_event.collect_events(jobaccount, wfView) - - @policy.register(WorkflowRun.TaskAdded) - def _(self, domain_event, process_event): - """Handle TaskAdded event""" - task_id = domain_event.task_id - wf_id = domain_event.originator_id - - # Track task - self.tasks[task_id] = wf_id - logger.debug(f"Task added: task_id={task_id}, wf_id={wf_id}") - - # Optionally, persist this state if needed - # use .collect_events(agg) instead of .save(agg) - # process_event.collect_events(taskView) - - @policy.register(Task.JobIdAdded) - def _(self, domain_event, process_event): - """Handle JobIdAdded event""" - # Grab event payload - job_id = domain_event.job_id - task_id = domain_event.originator_id - - # Find workflow and user/group for the task - wf_id = self.tasks.get(task_id) - if wf_id: - workflow_info = self.workflows.get(wf_id) - if workflow_info: - user = workflow_info["user"] - group = workflow_info["group"] - - # Track job - self.jobs[job_id] = (task_id, user, group) - logger.debug(f"Job added: job_id={job_id}, task_id={task_id}, user={user}, group={group}") - - - # Update view table - self.update_view_table(job_id, user, group) - else: - logger.debug(f"JobIdAdded event ignored: task_id={task_id} not found in tasks") - - # use .collect_events(agg) instead of .save(agg) - # process_event.collect_events(jobaccount) - - def update_view_table(self, job_id, user, group): - """Update the view table with new job information.""" - with self.SessionLocal() as session: - try: - new_job = JobView(slurm_job_id=job_id, user=user, group=group) - session.add(new_job) - session.commit() - logger.debug(f"Inserted job into view table: job_id={job_id}, user={user}, group={group}") - except IntegrityError: - session.rollback() - # Handle the case where the job already exists in the table if necessary - logger.error(f"Failed to insert job into view table (already exists?): job_id={job_id}, user={user}, group={group}") - - def get_jobs(self, user=None, group=None): - """Retrieve jobs for a specific user and/or group. - - Parameters: - - user (int, optional): The user ID to filter by. - - group (int, optional): The group ID to filter by. - - Returns: - - Dictionary of user IDs to lists of job IDs if no user is specified. - - Dictionary with a single user ID key and a list of job IDs if user is specified. - - Raises: - - ValueError: If neither user nor group is provided. 
- """ - if user is None and group is None: - # Retrieve all jobs grouped by user - with self.SessionLocal() as session: - jobs = session.query(JobView.user, JobView.slurm_job_id).all() - user_jobs = {} - for user_id, job_id in jobs: - if user_id not in user_jobs: - user_jobs[user_id] = [] - user_jobs[user_id].append(job_id) - return user_jobs - else: - with self.SessionLocal() as session: - query = session.query(JobView.slurm_job_id) - - if user is not None: - query = query.filter_by(user=user) - - if group is not None: - query = query.filter_by(group=group) - - jobs = query.all() - result = {user: [job.slurm_job_id for job in jobs]} - logger.debug(f"Retrieved jobs for user={user} and group={group}: {result}") - return result + diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index e813c6c..63892f0 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -30,7 +30,8 @@ from importlib_resources import files import io import os -from biomero.eventsourcing import WorkflowTracker, JobAccounting +from biomero.eventsourcing import WorkflowTracker +from biomero.views import JobAccounting from eventsourcing.system import System, SingleThreadedRunner logger = logging.getLogger(__name__) @@ -1404,7 +1405,14 @@ def run_conversion_workflow_job(self, # Run all commands consecutively res = self.run_commands(commands, sbatch_env) - return SlurmJob(res, self.extract_job_id(res), wf_id, task_id) + slurm_job_id = self.extract_job_id(res) + + if self.track_workflows and task_id: + self.workflowTracker.start_task(task_id) + self.workflowTracker.add_job_id(task_id, slurm_job_id) + self.workflowTracker.add_result(task_id, res) + + return SlurmJob(res, slurm_job_id, wf_id, task_id) def extract_job_id(self, result: Result) -> int: """ diff --git a/biomero/views.py b/biomero/views.py new file mode 100644 index 0000000..a753fb0 --- /dev/null +++ b/biomero/views.py @@ -0,0 +1,167 @@ +import os + +from eventsourcing.system import ProcessApplication +from eventsourcing.dispatch import singledispatchmethod +from uuid import NAMESPACE_URL, UUID, uuid5 +from typing import Any, Dict, List +import logging +from sqlalchemy import create_engine, text, Column, Integer, String, URL +from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy.exc import IntegrityError +from biomero.eventsourcing import WorkflowRun, Task + + +logger = logging.getLogger(__name__) + +# --------------------- VIEWS ---------------------------- # + +# Base class for declarative class definitions +Base = declarative_base() + + +class JobView(Base): + __tablename__ = 'biomero_job_view' + + slurm_job_id = Column(Integer, primary_key=True) + user = Column(Integer, nullable=False) + group = Column(Integer, nullable=False) + + +class JobAccounting(ProcessApplication): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Read database configuration from environment variables + database_url = URL.create( + drivername="postgresql+psycopg2", + username=os.getenv('POSTGRES_USER'), + password=os.getenv('POSTGRES_PASSWORD'), + host=os.getenv('POSTGRES_HOST', 'localhost'), + port=os.getenv('POSTGRES_PORT', 5432), + database=os.getenv('POSTGRES_DBNAME') + ) + + # Set up SQLAlchemy engine and session + self.engine = create_engine(database_url) + self.SessionLocal = sessionmaker(bind=self.engine) + + # State tracking + self.workflows = {} # {wf_id: {"user": user, "group": group}} + self.tasks = {} # {task_id: wf_id} + self.jobs = {} # {job_id: (task_id, user, group)} + + # Create defined tables 
(subclasses of Base) if they don't exist + Base.metadata.create_all(self.engine) + + @singledispatchmethod + def policy(self, domain_event, process_event): + """Default policy""" + + @policy.register(WorkflowRun.WorkflowInitiated) + def _(self, domain_event, process_event): + """Handle WorkflowInitiated event""" + user = domain_event.user + group = domain_event.group + wf_id = domain_event.originator_id + + # Track workflow + self.workflows[wf_id] = {"user": user, "group": group} + logger.debug(f"Workflow initiated: wf_id={wf_id}, user={user}, group={group}") + + # Optionally, persist this state if needed + # Optionally, add an event to do that, then save via collect + # process_event.collect_events(jobaccount, wfView) + + @policy.register(WorkflowRun.TaskAdded) + def _(self, domain_event, process_event): + """Handle TaskAdded event""" + task_id = domain_event.task_id + wf_id = domain_event.originator_id + + # Track task + self.tasks[task_id] = wf_id + logger.debug(f"Task added: task_id={task_id}, wf_id={wf_id}") + + # Optionally, persist this state if needed + # use .collect_events(agg) instead of .save(agg) + # process_event.collect_events(taskView) + + @policy.register(Task.JobIdAdded) + def _(self, domain_event, process_event): + """Handle JobIdAdded event""" + # Grab event payload + job_id = domain_event.job_id + task_id = domain_event.originator_id + + # Find workflow and user/group for the task + wf_id = self.tasks.get(task_id) + if wf_id: + workflow_info = self.workflows.get(wf_id) + if workflow_info: + user = workflow_info["user"] + group = workflow_info["group"] + + # Track job + self.jobs[job_id] = (task_id, user, group) + logger.debug(f"Job added: job_id={job_id}, task_id={task_id}, user={user}, group={group}") + + + # Update view table + self.update_view_table(job_id, user, group) + else: + logger.debug(f"JobIdAdded event ignored: task_id={task_id} not found in tasks") + + # use .collect_events(agg) instead of .save(agg) + # process_event.collect_events(jobaccount) + + def update_view_table(self, job_id, user, group): + """Update the view table with new job information.""" + with self.SessionLocal() as session: + try: + new_job = JobView(slurm_job_id=job_id, user=user, group=group) + session.add(new_job) + session.commit() + logger.debug(f"Inserted job into view table: job_id={job_id}, user={user}, group={group}") + except IntegrityError: + session.rollback() + # Handle the case where the job already exists in the table if necessary + logger.error(f"Failed to insert job into view table (already exists?): job_id={job_id}, user={user}, group={group}") + + def get_jobs(self, user=None, group=None): + """Retrieve jobs for a specific user and/or group. + + Parameters: + - user (int, optional): The user ID to filter by. + - group (int, optional): The group ID to filter by. + + Returns: + - Dictionary of user IDs to lists of job IDs if no user is specified. + - Dictionary with a single user ID key and a list of job IDs if user is specified. + + Raises: + - ValueError: If neither user nor group is provided. 
+ """ + if user is None and group is None: + # Retrieve all jobs grouped by user + with self.SessionLocal() as session: + jobs = session.query(JobView.user, JobView.slurm_job_id).all() + user_jobs = {} + for user_id, job_id in jobs: + if user_id not in user_jobs: + user_jobs[user_id] = [] + user_jobs[user_id].append(job_id) + return user_jobs + else: + with self.SessionLocal() as session: + query = session.query(JobView.slurm_job_id) + + if user is not None: + query = query.filter_by(user=user) + + if group is not None: + query = query.filter_by(group=group) + + jobs = query.all() + result = {user: [job.slurm_job_id for job in jobs]} + logger.debug(f"Retrieved jobs for user={user} and group={group}: {result}") + return result \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0897d44..95e10ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,9 @@ dependencies = [ "fabric==3.1.0", "paramiko==3.4.0", "importlib_resources>=5.4.0", - "eventsourcing[crypto,postgres-dev]==9.2.22", - "sqlalchemy==2.0.32" + "eventsourcing[crypto]==9.2.22", + "sqlalchemy==2.0.32", + "psycopg2==2.9.9" ] [tool.setuptools.packages] @@ -42,7 +43,8 @@ find = {} # Scan the project directory with the default parameters [project.optional-dependencies] test = [ "pytest", - "mock" + "mock", + "psycopg2-binary" ] [project.urls] From 0639508d73c7ed01499e760721334ed19095436e Mon Sep 17 00:00:00 2001 From: Luik Date: Wed, 7 Aug 2024 13:54:58 +0200 Subject: [PATCH 05/24] add option for in memory sqlite db (for tests) --- biomero/views.py | 27 +++++++++++++++++++-------- tests/unit/test_slurm_client.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/biomero/views.py b/biomero/views.py index a753fb0..6b1e9db 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -32,14 +32,25 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Read database configuration from environment variables - database_url = URL.create( - drivername="postgresql+psycopg2", - username=os.getenv('POSTGRES_USER'), - password=os.getenv('POSTGRES_PASSWORD'), - host=os.getenv('POSTGRES_HOST', 'localhost'), - port=os.getenv('POSTGRES_PORT', 5432), - database=os.getenv('POSTGRES_DBNAME') - ) + persistence_mod = os.getenv('PERSISTENCE_MODULE') + if 'postgres' in persistence_mod: + logger.info("Using postgres database") + database_url = URL.create( + drivername="postgresql+psycopg2", + username=os.getenv('POSTGRES_USER'), + password=os.getenv('POSTGRES_PASSWORD'), + host=os.getenv('POSTGRES_HOST', 'localhost'), + port=os.getenv('POSTGRES_PORT', 5432), + database=os.getenv('POSTGRES_DBNAME') + ) + elif 'sqlite' in persistence_mod: + logger.info("Using sqlite in-mem database") + database_url = URL.create( + drivername="sqlite", + database=os.getenv('SQLITE_DBNAME') + ) + else: + raise NotImplementedError(f"Can't handle {persistence_mod}") # Set up SQLAlchemy engine and session self.engine = create_engine(database_url) diff --git a/tests/unit/test_slurm_client.py b/tests/unit/test_slurm_client.py index 9e4c54c..0337915 100644 --- a/tests/unit/test_slurm_client.py +++ b/tests/unit/test_slurm_client.py @@ -3,6 +3,37 @@ import mock from mock import patch, MagicMock from paramiko import SSHException +import os + +# using actual env vars +# @pytest.fixture(scope='session', autouse=True) +# def set_env_vars(): +# # Set environment variables +# os.environ["PERSISTENCE_MODULE"] = "eventsourcing.sqlite" +# os.environ["SQLITE_DBNAME"] = ":memory:" + +# # Optional: 
Return a dictionary of the set variables if needed elsewhere +# yield { +# "PERSISTENCE_MODULE": "eventsourcing.sqlite", +# "SQLITE_DBNAME": ":memory:", +# } + +# # Optionally, clean up the environment variables after tests are done +# del os.environ["PERSISTENCE_MODULE"] +# del os.environ["SQLITE_DBNAME"] + + +@pytest.fixture(autouse=True) +def mock_env_vars(): + # Define mock environment variables + mock_env = { + "PERSISTENCE_MODULE": "eventsourcing.sqlite", + "SQLITE_DBNAME": ":memory:", + } + + # Patch os.getenv to return values from the mock environment + with patch('os.getenv', lambda key, default=None: mock_env.get(key, default)): + yield @pytest.fixture From 2bee749894ac5df108ded3ba2fe85d4ec8bfee11 Mon Sep 17 00:00:00 2001 From: Luik Date: Wed, 7 Aug 2024 15:15:52 +0200 Subject: [PATCH 06/24] Add a SerializableMagicMock for the eventsourcing / postgres stuff --- tests/unit/test_slurm_client.py | 52 ++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/tests/unit/test_slurm_client.py b/tests/unit/test_slurm_client.py index 0337915..08f5d4b 100644 --- a/tests/unit/test_slurm_client.py +++ b/tests/unit/test_slurm_client.py @@ -3,8 +3,7 @@ import mock from mock import patch, MagicMock from paramiko import SSHException -import os - +# import os # using actual env vars # @pytest.fixture(scope='session', autouse=True) # def set_env_vars(): @@ -36,6 +35,11 @@ def mock_env_vars(): yield +class SerializableMagicMock(MagicMock, dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @pytest.fixture @patch('biomero.slurm_client.Connection.create_session') @patch('biomero.slurm_client.Connection.open') @@ -169,7 +173,7 @@ def test_get_logfile_from_slurm(mock_get, slurm_client): @patch('biomero.slurm_client.logger') -@patch.object(SlurmClient, 'run_commands', return_value=MagicMock(ok=True, stdout="")) +@patch.object(SlurmClient, 'run_commands', return_value=SerializableMagicMock(ok=True, stdout="")) def test_zip_data_on_slurm_server(mock_run_commands, mock_logger, slurm_client): # GIVEN data_location = "/local/path/to/store" @@ -188,7 +192,7 @@ def test_zip_data_on_slurm_server(mock_run_commands, mock_logger, slurm_client): @patch('biomero.slurm_client.logger') -@patch.object(SlurmClient, 'get', return_value=MagicMock(ok=True, stdout="")) +@patch.object(SlurmClient, 'get', return_value=SerializableMagicMock(ok=True, stdout="")) def test_copy_zip_locally(mock_get, mock_logger, slurm_client): # GIVEN local_tmp_storage = "/local/path/to/store" @@ -247,8 +251,8 @@ def test_get_workflow_command(slurm_client, @pytest.mark.parametrize("source_format, target_format", [("zarr", "tiff"), ("xyz", "abc")]) -@patch('biomero.slurm_client.SlurmClient.run_commands') -@patch('fabric.Result') +@patch('biomero.slurm_client.SlurmClient.run_commands', new_callable=SerializableMagicMock) +@patch('fabric.Result', new_callable=SerializableMagicMock) def test_run_conversion_workflow_job(mock_result, mock_run_commands, slurm_client, source_format, target_format): # GIVEN folder_name = "example_folder" @@ -297,8 +301,8 @@ def test_run_conversion_workflow_job(mock_result, mock_run_commands, slurm_clien @pytest.mark.parametrize("source_format, target_format, version", [("zarr", "tiff", "1.0"), ("xyz", "abc", "v38.20-alpha.4")]) -@patch('biomero.slurm_client.SlurmClient.run_commands') -@patch('fabric.Result') +@patch('biomero.slurm_client.SlurmClient.run_commands', new_callable=SerializableMagicMock) +@patch('fabric.Result', 
new_callable=SerializableMagicMock) def test_run_conversion_workflow_job_versioned(mock_result, mock_run_commands, slurm_client, source_format, target_format, version): # GIVEN folder_name = "example_folder" @@ -539,7 +543,7 @@ def test_extract_data_location_from_log_exc(mock_run_commands, slurm_job_id = "123" logfile = "path/to/logfile.txt" expected_data_location = '/path/to/data' - mock_run_commands.return_value = mock.MagicMock( + mock_run_commands.return_value = SerializableMagicMock( ok=False, stdout=expected_data_location) # WHEN @@ -558,7 +562,7 @@ def test_extract_data_location_from_log_2(mock_run_commands, # GIVEN slurm_job_id = "123" expected_data_location = '/path/to/data' - mock_run_commands.return_value = mock.MagicMock( + mock_run_commands.return_value = SerializableMagicMock( ok=True, stdout=expected_data_location) # WHEN @@ -578,7 +582,7 @@ def test_extract_data_location_from_log(mock_run_commands, slurm_job_id = "123" logfile = "path/to/logfile.txt" expected_data_location = '/path/to/data' - mock_run_commands.return_value = mock.MagicMock( + mock_run_commands.return_value = SerializableMagicMock( ok=True, stdout=expected_data_location) # WHEN @@ -607,7 +611,7 @@ def test_get_job_status_command(slurm_client): def test_check_job_status(mock_run_commands, slurm_client): # GIVEN - mock_run_commands.return_value = mock.MagicMock( + mock_run_commands.return_value = SerializableMagicMock( ok=True, stdout="12345 RUNNING\n67890 COMPLETED") # WHEN @@ -671,7 +675,7 @@ def test_check_job_array_status(mock_run_commands, slurm_client): 2339_12 COMPLETED 2024-02-29T10:34:53 2339_[13-94+ PENDING Unknown""" - mock_run_commands.return_value = MagicMock(ok=True, stdout=mock_stdout) + mock_run_commands.return_value = SerializableMagicMock(ok=True, stdout=mock_stdout) # WHEN job_status_dict, _ = slurm_client.check_job_status([2304, 2339]) @@ -689,7 +693,7 @@ def test_check_job_array_status(mock_run_commands, slurm_client): def test_check_job_status_exc(mock_run_commands, mock_logger, slurm_client): # GIVEN - return_mock = mock.MagicMock( + return_mock = SerializableMagicMock( ok=False, stdout="12345 RUNNING\n67890 COMPLETED") mock_run_commands.return_value = return_mock @@ -710,7 +714,7 @@ def test_check_job_status_exc(mock_run_commands, def test_check_job_status_exc2(mock_run_commands, _mock_timesleep, mock_logger, slurm_client): # GIVEN - mock_run_commands.return_value = mock.MagicMock( + mock_run_commands.return_value = SerializableMagicMock( ok=True, stdout=None) # WHEN @@ -758,8 +762,8 @@ def test_update_slurm_scripts(mock_generate_job, mock_workflow_params_to_subs, mock_workflow_params_to_subs.return_value = { 'PARAMS': '--param1 $PARAM1_NAME'} mock_generate_job.return_value = "GeneratedJobScript" - mock_put.return_value = MagicMock(ok=True) - mock_run.return_value = MagicMock(ok=True) + mock_put.return_value = SerializableMagicMock(ok=True) + mock_run.return_value = SerializableMagicMock(ok=True) # WHEN slurm_client.update_slurm_scripts(generate_jobs=True) @@ -814,7 +818,7 @@ def test_list_completed_jobs(mock_run_commands, # Mocking the run_commands method stdout_content = "98765\n43210\n" - mock_run_commands.return_value = mock.MagicMock( + mock_run_commands.return_value = SerializableMagicMock( ok=True, stdout=stdout_content) # WHEN @@ -840,7 +844,7 @@ def test_list_active_jobs(mock_run_commands, # Mocking the run_commands method stdout_content = "12345\n67890\n" - mock_run_commands.return_value = mock.MagicMock( + mock_run_commands.return_value = SerializableMagicMock( ok=True, 
stdout=stdout_content) # WHEN @@ -866,7 +870,7 @@ def test_run_commands(mock_run, slurm_client): sep = ' && ' # Mocking the run method - mock_run.return_value = mock.MagicMock( + mock_run.return_value = SerializableMagicMock( ok=True, stdout="Command executed successfully") # WHEN @@ -900,7 +904,7 @@ def test_get_active_job_progress(mock_get_recent_log_command, # Mocking the run_commands method stdout_content = "Progress: 50%\nSome other text\nProgress: 75%\n" - mock_run_commands.return_value = mock.MagicMock( + mock_run_commands.return_value = SerializableMagicMock( ok=True, stdout=stdout_content) # WHEN @@ -927,7 +931,7 @@ def test_cleanup_tmp_files_loc(mock_extract_data_location, mock_run_commands, data_location = "/path" logfile = "/path/to/logfile" - mock_run_commands.return_value = mock.MagicMock(ok=True) + mock_run_commands.return_value = SerializableMagicMock(ok=True) # WHEN result = slurm_client.cleanup_tmp_files( @@ -961,7 +965,7 @@ def test_cleanup_tmp_files(mock_extract_data_location, mock_run_commands, found_location = '/path' mock_extract_data_location.return_value = found_location - mock_run_commands.return_value = mock.MagicMock(ok=True) + mock_run_commands.return_value = SerializableMagicMock(ok=True) # WHEN result = slurm_client.cleanup_tmp_files( @@ -998,7 +1002,7 @@ def test_from_config(mock_ConfigParser, mock_SlurmClient.return_value = None # Create a MagicMock instance to represent the ConfigParser object - mock_configparser_instance = mock.MagicMock() + mock_configparser_instance = MagicMock() # Set the behavior or attributes of the mock_configparser_instance as needed mock_configparser_instance.read.return_value = None From c662f0b645b4fb40826a56c0498ce9bbcf9c1037 Mon Sep 17 00:00:00 2001 From: Luik Date: Thu, 8 Aug 2024 10:23:20 +0200 Subject: [PATCH 07/24] Add a job progress view table --- biomero/eventsourcing.py | 37 ++++++++---- biomero/slurm_client.py | 21 +++++-- biomero/views.py | 125 +++++++++++++++++++++++++++++++++------ 3 files changed, 147 insertions(+), 36 deletions(-) diff --git a/biomero/eventsourcing.py b/biomero/eventsourcing.py index e288999..40a7ff4 100644 --- a/biomero/eventsourcing.py +++ b/biomero/eventsourcing.py @@ -140,6 +140,14 @@ class StatusUpdated(Aggregate.Event): def update_task_status(self, status): logger.debug(f"Adding status to Task: task_id={self.id}, status={status}") self.status = status + + class ProgressUpdated(Aggregate.Event): + progress: str + + @event(ProgressUpdated) + def update_task_progress(self, progress): + logger.debug(f"Adding progress to Task: task_id={self.id}, progress={progress}") + self.progress = progress class ResultAdded(Aggregate.Event): result: ResultDict @@ -189,7 +197,7 @@ def initiate_workflow(self, description: str, user: int, group: int) -> UUID: - logger.debug(f"Initiating workflow: name={name}, description={description}, user={user}, group={group}") + logger.debug(f"[WFT] Initiating workflow: name={name}, description={description}, user={user}, group={group}") workflow = WorkflowRun(name, description, user, group) self.save(workflow) return workflow.id @@ -201,7 +209,7 @@ def add_task_to_workflow(self, input_data: Dict[str, Any], kwargs: Dict[str, Any] ) -> UUID: - logger.debug(f"Adding task to workflow: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}") + logger.debug(f"[WFT] Adding task to workflow: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}") task = Task(workflow_id, task_name, @@ -215,67 +223,74 @@ def add_task_to_workflow(self, 
return task.id def start_workflow(self, workflow_id: UUID): - logger.debug(f"Starting workflow: workflow_id={workflow_id}") + logger.debug(f"[WFT] Starting workflow: workflow_id={workflow_id}") workflow = self.repository.get(workflow_id) workflow.start_workflow() self.save(workflow) def complete_workflow(self, workflow_id: UUID): - logger.debug(f"Completing workflow: workflow_id={workflow_id}") + logger.debug(f"[WFT] Completing workflow: workflow_id={workflow_id}") workflow = self.repository.get(workflow_id) workflow.complete_workflow() self.save(workflow) def fail_workflow(self, workflow_id: UUID, error_message: str): - logger.debug(f"Failing workflow: workflow_id={workflow_id}, error_message={error_message}") + logger.debug(f"[WFT] Failing workflow: workflow_id={workflow_id}, error_message={error_message}") workflow = self.repository.get(workflow_id) workflow.fail_workflow(error_message) self.save(workflow) def start_task(self, task_id: UUID): - logger.debug(f"Starting task: task_id={task_id}") + logger.debug(f"[WFT] Starting task: task_id={task_id}") task = self.repository.get(task_id) task.start_task() self.save(task) def complete_task(self, task_id: UUID, message: str): - logger.debug(f"Completing task: task_id={task_id}, message={message}") + logger.debug(f"[WFT] Completing task: task_id={task_id}, message={message}") task = self.repository.get(task_id) task.complete_task(message) self.save(task) def fail_task(self, task_id: UUID, error_message: str): - logger.debug(f"Failing task: task_id={task_id}, error_message={error_message}") + logger.debug(f"[WFT] Failing task: task_id={task_id}, error_message={error_message}") task = self.repository.get(task_id) task.fail_task(error_message) self.save(task) def add_job_id(self, task_id, slurm_job_id): - logger.debug(f"Adding job_id to task: task_id={task_id}, slurm_job_id={slurm_job_id}") + logger.debug(f"[WFT] Adding job_id to task: task_id={task_id}, slurm_job_id={slurm_job_id}") task = self.repository.get(task_id) task.add_job_id(slurm_job_id) self.save(task) def add_result(self, task_id, result): - logger.debug(f"Adding result to task: task_id={task_id}, result={result}") + logger.debug(f"[WFT] Adding result to task: task_id={task_id}, result={result}") task = self.repository.get(task_id) task.add_result(result) self.save(task) def update_task_status(self, task_id, status): - logger.debug(f"Adding status to task: task_id={task_id}, status={status}") + logger.debug(f"[WFT] Adding status to task: task_id={task_id}, status={status}") task = self.repository.get(task_id) task.update_task_status(status) self.save(task) + + def update_task_progress(self, task_id, progress): + logger.debug(f"[WFT] Adding progress to task: task_id={task_id}, progress={progress}") + + task = self.repository.get(task_id) + task.update_task_progress(progress) + self.save(task) diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index 63892f0..842f770 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -31,7 +31,7 @@ import io import os from biomero.eventsourcing import WorkflowTracker -from biomero.views import JobAccounting +from biomero.views import JobAccounting, JobProgress from eventsourcing.system import System, SingleThreadedRunner logger = logging.getLogger(__name__) @@ -98,6 +98,7 @@ def __init__(self, self.submit_result = submit_result self.ok = self.submit_result.ok self.job_state = None + self.progress = None self.error_message = self.submit_result.stderr if hasattr(self.submit_result, 'stderr') else '' def 
wait_for_completion(self, slurmClient, omeroConn) -> str: @@ -121,6 +122,7 @@ def wait_for_completion(self, slurmClient, omeroConn) -> str: "TIMEOUT+"): job_status_dict, poll_result = slurmClient.check_job_status( [self.job_id]) + self.progress = slurmClient.get_active_job_progress(self.job_id) if not poll_result.ok: logger.warning( f"Error checking job status:{poll_result.stderr}") @@ -131,6 +133,8 @@ def wait_for_completion(self, slurmClient, omeroConn) -> str: omeroConn.keepAlive() # keep the OMERO connection alive slurmClient.workflowTracker.update_task_status(self.task_id, self.job_state) + slurmClient.workflowTracker.update_task_progress( + self.task_id, self.progress) timesleep.sleep(self.slurm_polling_interval) logger.info(f"Job {self.job_id} finished: {self.job_state}") logger.info( @@ -396,7 +400,10 @@ def __init__(self, # Setup workflow tracking and accounting self.track_workflows = track_workflows - system = System(pipes=[[WorkflowTracker, JobAccounting]]) + system = System(pipes=[ + [WorkflowTracker, JobAccounting], + [WorkflowTracker, JobProgress] + ]) if self.track_workflows: # use configured persistence from env runner = SingleThreadedRunner(system) else: # turn off persistence, override @@ -405,6 +412,7 @@ def __init__(self, runner.start() self.workflowTracker = runner.get(WorkflowTracker) self.jobAccounting = runner.get(JobAccounting) + self.jobProgress = runner.get(JobProgress) def init_workflows(self, force_update: bool = False): """ @@ -896,7 +904,7 @@ def get_recent_log_command(self, log_file: str, n: int = 10) -> str: def get_active_job_progress(self, slurm_job_id: str, pattern: str = r"\d+%", - env: Optional[Dict[str, str]] = None) -> str: + env: Optional[Dict[str, str]] = None) -> Any: """ Get the progress of an active Slurm job from its logfiles. @@ -909,7 +917,7 @@ def get_active_job_progress(self, to set when running the command. Defaults to None. Returns: - str: The progress of the Slurm job. + Any: The progress of the Slurm job according to the pattern, or None. 
""" cmdlist = [] cmd = self.get_recent_log_command( @@ -922,13 +930,14 @@ def get_active_job_progress(self, except Exception as e: logger.error(f"Issue with run command: {e}") # Match the specified pattern in the result's stdout + latest_progress = None try: latest_progress = re.findall( pattern, result.stdout)[-1] except Exception as e: - logger.error(f"Issue with extracting progress: {e}") + logger.warning(f"Issue with extracting progress: {e}") - return f"Progress: {latest_progress}\n" + return latest_progress def run_commands(self, cmdlist: List[str], env: Optional[Dict[str, str]] = None, diff --git a/biomero/views.py b/biomero/views.py index 6b1e9db..2a85aca 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -19,18 +19,8 @@ Base = declarative_base() -class JobView(Base): - __tablename__ = 'biomero_job_view' - - slurm_job_id = Column(Integer, primary_key=True) - user = Column(Integer, nullable=False) - group = Column(Integer, nullable=False) - - -class JobAccounting(ProcessApplication): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - +class BaseApplication: + def __init__(self): # Read database configuration from environment variables persistence_mod = os.getenv('PERSISTENCE_MODULE') if 'postgres' in persistence_mod: @@ -51,18 +41,40 @@ def __init__(self, *args, **kwargs): ) else: raise NotImplementedError(f"Can't handle {persistence_mod}") - + # Set up SQLAlchemy engine and session self.engine = create_engine(database_url) self.SessionLocal = sessionmaker(bind=self.engine) + # Create defined tables (subclasses of Base) if they don't exist + Base.metadata.create_all(self.engine) + + +class JobView(Base): + __tablename__ = 'biomero_job_view' + + slurm_job_id = Column(Integer, primary_key=True) + user = Column(Integer, nullable=False) + group = Column(Integer, nullable=False) + + +class JobProgressView(Base): + __tablename__ = 'biomero_job_progress_view' + + slurm_job_id = Column(Integer, primary_key=True) + status = Column(String, nullable=False) + progress = Column(String, nullable=True) + + +class JobAccounting(ProcessApplication, BaseApplication): + def __init__(self, *args, **kwargs): + ProcessApplication.__init__(self, *args, **kwargs) + BaseApplication.__init__(self) + # State tracking self.workflows = {} # {wf_id: {"user": user, "group": group}} self.tasks = {} # {task_id: wf_id} - self.jobs = {} # {job_id: (task_id, user, group)} - - # Create defined tables (subclasses of Base) if they don't exist - Base.metadata.create_all(self.engine) + self.jobs = {} # {job_id: (task_id, user, group)} @singledispatchmethod def policy(self, domain_event, process_event): @@ -116,7 +128,7 @@ def _(self, domain_event, process_event): self.jobs[job_id] = (task_id, user, group) logger.debug(f"Job added: job_id={job_id}, task_id={task_id}, user={user}, group={group}") - + # Update view table self.update_view_table(job_id, user, group) else: @@ -175,4 +187,79 @@ def get_jobs(self, user=None, group=None): jobs = query.all() result = {user: [job.slurm_job_id for job in jobs]} logger.debug(f"Retrieved jobs for user={user} and group={group}: {result}") - return result \ No newline at end of file + return result + + +class JobProgress(ProcessApplication, BaseApplication): + def __init__(self, *args, **kwargs): + ProcessApplication.__init__(self, *args, **kwargs) + BaseApplication.__init__(self) + + # State tracking + self.task_to_job = {} # {task_id: job_id} + self.job_status = {} # {job_id: {"status": status, "progress": progress}} + + @singledispatchmethod + def 
policy(self, domain_event, process_event): + """Default policy""" + + @policy.register(Task.JobIdAdded) + def _(self, domain_event, process_event): + """Handle JobIdAdded event""" + job_id = domain_event.job_id + task_id = domain_event.originator_id + + # Track task to job mapping + self.task_to_job[task_id] = job_id + logger.debug(f"JobId added: job_id={job_id}, task_id={task_id}") + + @policy.register(Task.StatusUpdated) + def _(self, domain_event, process_event): + """Handle StatusUpdated event""" + task_id = domain_event.originator_id + status = domain_event.status + + job_id = self.task_to_job.get(task_id) + if job_id is not None: + if job_id in self.job_status: + self.job_status[job_id]["status"] = status + else: + self.job_status[job_id] = {"status": status, "progress": None} + + logger.debug(f"Status updated: job_id={job_id}, status={status}") + # Update view table + self.update_view_table(job_id) + + @policy.register(Task.ProgressUpdated) + def _(self, domain_event, process_event): + """Handle ProgressUpdated event""" + task_id = domain_event.originator_id + progress = domain_event.progress + + job_id = self.task_to_job.get(task_id) + if job_id is not None: + if job_id in self.job_status: + self.job_status[job_id]["progress"] = progress + else: + self.job_status[job_id] = {"status": "UNKNOWN", "progress": progress} + + logger.debug(f"Progress updated: job_id={job_id}, progress={progress}") + # Update view table + self.update_view_table(job_id) + + def update_view_table(self, job_id): + """Update the view table with new job status and progress information.""" + with self.SessionLocal() as session: + try: + job_info = self.job_status[job_id] + new_job_progress = JobProgressView( + slurm_job_id=job_id, + status=job_info["status"], + progress=job_info["progress"] + ) + session.merge(new_job_progress) # Use merge to insert or update + session.commit() + logger.debug(f"Inserted/Updated job progress in view table: job_id={job_id}, status={job_info['status']}, progress={job_info['progress']}") + except IntegrityError: + session.rollback() + logger.error(f"Failed to insert/update job progress in view table: job_id={job_id}") From 1f07e44e7c9146ee94919d4b81ace4590a05f3d5 Mon Sep 17 00:00:00 2001 From: Luik Date: Thu, 8 Aug 2024 10:26:39 +0200 Subject: [PATCH 08/24] Fix test with changed output --- tests/unit/test_slurm_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_slurm_client.py b/tests/unit/test_slurm_client.py index 08f5d4b..30b7d11 100644 --- a/tests/unit/test_slurm_client.py +++ b/tests/unit/test_slurm_client.py @@ -915,7 +915,7 @@ def test_get_active_job_progress(mock_get_recent_log_command, log_file=slurm_client._LOGFILE.format(slurm_job_id=slurm_job_id)) mock_run_commands.assert_called_once_with([log_cmd], env={}) - assert result == "Progress: 75%\n" + assert result == "75%" @patch('biomero.slurm_client.SlurmClient.run_commands') From 2fcf853e77ea6664f2a5af0bf7dfb374db939cfa Mon Sep 17 00:00:00 2001 From: Luik Date: Mon, 12 Aug 2024 14:19:11 +0200 Subject: [PATCH 09/24] Add a workflow stats viewer --- biomero/slurm_client.py | 6 +- biomero/views.py | 341 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 327 insertions(+), 20 deletions(-) diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index 842f770..5818522 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -31,7 +31,7 @@ import io import os from biomero.eventsourcing import WorkflowTracker -from biomero.views import JobAccounting, JobProgress 
+from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics from eventsourcing.system import System, SingleThreadedRunner logger = logging.getLogger(__name__) @@ -402,7 +402,8 @@ def __init__(self, self.track_workflows = track_workflows system = System(pipes=[ [WorkflowTracker, JobAccounting], - [WorkflowTracker, JobProgress] + [WorkflowTracker, JobProgress], + [WorkflowTracker, WorkflowAnalytics] ]) if self.track_workflows: # use configured persistence from env runner = SingleThreadedRunner(system) @@ -413,6 +414,7 @@ def __init__(self, self.workflowTracker = runner.get(WorkflowTracker) self.jobAccounting = runner.get(JobAccounting) self.jobProgress = runner.get(JobProgress) + self.workflowAnalytics = runner.get(WorkflowAnalytics) def init_workflows(self, force_update: bool = False): """ diff --git a/biomero/views.py b/biomero/views.py index 2a85aca..67dfecf 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -5,20 +5,55 @@ from uuid import NAMESPACE_URL, UUID, uuid5 from typing import Any, Dict, List import logging -from sqlalchemy import create_engine, text, Column, Integer, String, URL +from sqlalchemy import create_engine, text, Column, Integer, String, URL, DateTime, Float from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.exc import IntegrityError +from sqlalchemy.sql import func +from sqlalchemy.dialects.postgresql import UUID as PGUUID from biomero.eventsourcing import WorkflowRun, Task logger = logging.getLogger(__name__) -# --------------------- VIEWS ---------------------------- # +# --------------------- VIEWS DB tables/classes ---------------------------- # # Base class for declarative class definitions Base = declarative_base() +class JobView(Base): + __tablename__ = 'biomero_job_view' + + slurm_job_id = Column(Integer, primary_key=True) + user = Column(Integer, nullable=False) + group = Column(Integer, nullable=False) + + +class JobProgressView(Base): + __tablename__ = 'biomero_job_progress_view' + + slurm_job_id = Column(Integer, primary_key=True) + status = Column(String, nullable=False) + progress = Column(String, nullable=True) + + +class TaskExecution(Base): + __tablename__ = 'biomero_task_execution' + + task_id = Column(PGUUID(as_uuid=True), primary_key=True) + task_name = Column(String, nullable=False) + task_version = Column(String) + user_id = Column(Integer, nullable=True) + group_id = Column(Integer, nullable=True) + status = Column(String, nullable=False) + start_time = Column(DateTime, nullable=False) + end_time = Column(DateTime, nullable=True) + error_type = Column(String, nullable=True) + + +# ------------------- View Listener Applications ------------------ # + + class BaseApplication: def __init__(self): # Read database configuration from environment variables @@ -50,22 +85,6 @@ def __init__(self): Base.metadata.create_all(self.engine) -class JobView(Base): - __tablename__ = 'biomero_job_view' - - slurm_job_id = Column(Integer, primary_key=True) - user = Column(Integer, nullable=False) - group = Column(Integer, nullable=False) - - -class JobProgressView(Base): - __tablename__ = 'biomero_job_progress_view' - - slurm_job_id = Column(Integer, primary_key=True) - status = Column(String, nullable=False) - progress = Column(String, nullable=True) - - class JobAccounting(ProcessApplication, BaseApplication): def __init__(self, *args, **kwargs): ProcessApplication.__init__(self, *args, **kwargs) @@ -263,3 +282,289 @@ def update_view_table(self, job_id): except IntegrityError: session.rollback() logger.error(f"Failed to 
insert/update job progress in view table: job_id={job_id}") + + +class WorkflowAnalytics(BaseApplication, ProcessApplication): + def __init__(self, *args, **kwargs): + ProcessApplication.__init__(self, *args, **kwargs) + BaseApplication.__init__(self) + + # State tracking for workflows and tasks + self.workflows = {} # {wf_id: {"user": user, "group": group}} + self.tasks = {} # {task_id: {"wf_id": wf_id, "task_name": task_name, "task_version": task_version, "start_time": timestamp, "status": status, "end_time": timestamp, "error_type": error_type}} + + @singledispatchmethod + def policy(self, domain_event, process_event): + """Default policy""" + pass + + @policy.register(WorkflowRun.WorkflowInitiated) + def _(self, domain_event, process_event): + """Handle WorkflowInitiated event""" + user = domain_event.user + group = domain_event.group + wf_id = domain_event.originator_id + + # Track workflow + self.workflows[wf_id] = {"user": user, "group": group} + logger.debug(f"Workflow initiated: wf_id={wf_id}, user={user}, group={group}") + + @policy.register(WorkflowRun.TaskAdded) + def _(self, domain_event, process_event): + """Handle TaskAdded event""" + task_id = domain_event.task_id + wf_id = domain_event.originator_id + + # Add workflow ID to the existing task information + if task_id in self.tasks: + self.tasks[task_id]["wf_id"] = wf_id + else: + # In case TaskAdded arrives before TaskCreated (unlikely but possible) + self.tasks[task_id] = {"wf_id": wf_id} + + logger.debug(f"Task added: task_id={task_id}, wf_id={wf_id}") + + @policy.register(Task.TaskCreated) + def _(self, domain_event, process_event): + """Handle TaskCreated event""" + task_id = domain_event.originator_id + task_name = domain_event.task_name + task_version = domain_event.task_version + timestamp_created = domain_event.timestamp + + # Track task creation details + if task_id in self.tasks: + self.tasks[task_id].update({ + "task_name": task_name, + "task_version": task_version, + "start_time": timestamp_created + }) + else: + # Initialize task tracking if TaskAdded hasn't been processed yet + self.tasks[task_id] = { + "task_name": task_name, + "task_version": task_version, + "start_time": timestamp_created + } + + logger.debug(f"Task created: task_id={task_id}, task_name={task_name}, timestamp={timestamp_created}") + self.update_view_table(task_id) + + @policy.register(Task.StatusUpdated) + def _(self, domain_event, process_event): + """Handle StatusUpdated event""" + task_id = domain_event.originator_id + status = domain_event.status + + # Update task with status + if task_id in self.tasks: + self.tasks[task_id]["status"] = status + logger.debug(f"Task status updated: task_id={task_id}, status={status}") + self.update_view_table(task_id) + + @policy.register(Task.TaskCompleted) + def _(self, domain_event, process_event): + """Handle TaskCompleted event""" + task_id = domain_event.originator_id + timestamp_completed = domain_event.timestamp + + # Update task with end time + if task_id in self.tasks: + self.tasks[task_id]["end_time"] = timestamp_completed + logger.debug(f"Task completed: task_id={task_id}, end_time={timestamp_completed}") + self.update_view_table(task_id) + + @policy.register(Task.TaskFailed) + def _(self, domain_event, process_event): + """Handle TaskFailed event""" + task_id = domain_event.originator_id + timestamp_failed = domain_event.timestamp + error_message = domain_event.error_message + + # Update task with end time and error message + if task_id in self.tasks: + self.tasks[task_id]["end_time"] = 
timestamp_failed + self.tasks[task_id]["error_type"] = error_message + logger.debug(f"Task failed: task_id={task_id}, end_time={timestamp_failed}, error={error_message}") + self.update_view_table(task_id) + + def update_view_table(self, task_id): + """Update the view table with new task execution information.""" + task_info = self.tasks.get(task_id) + if not task_info: + return # Skip if task information is incomplete + + wf_id = task_info.get("wf_id") + user_id = None + group_id = None + + # Retrieve user and group from workflow + if wf_id and wf_id in self.workflows: + user_id = self.workflows[wf_id]["user"] + group_id = self.workflows[wf_id]["group"] + + with self.SessionLocal() as session: + try: + existing_task = session.query(TaskExecution).filter_by(task_id=task_id).first() + if existing_task: + # Update existing task execution record + existing_task.task_name = task_info.get("task_name", existing_task.task_name) + existing_task.task_version = task_info.get("task_version", existing_task.task_version) + existing_task.user_id = user_id + existing_task.group_id = group_id + existing_task.status = task_info.get("status", existing_task.status) + existing_task.start_time = task_info.get("start_time", existing_task.start_time) + existing_task.end_time = task_info.get("end_time", existing_task.end_time) + existing_task.error_type = task_info.get("error_type", existing_task.error_type) + else: + # Create a new task execution record + new_task_execution = TaskExecution( + task_id=task_id, + task_name=task_info.get("task_name"), + task_version=task_info.get("task_version"), + user_id=user_id, + group_id=group_id, + status=task_info.get("status"), + start_time=task_info.get("start_time"), + end_time=task_info.get("end_time"), + error_type=task_info.get("error_type") + ) + session.add(new_task_execution) + + session.commit() + logger.debug(f"Updated/Inserted task execution into view table: task_id={task_id}, task_name={task_info.get('task_name')}") + except IntegrityError: + session.rollback() + logger.error(f"Failed to insert/update task execution into view table: task_id={task_id}") + + def get_task_counts(self, user=None, group=None): + """Retrieve task execution counts grouped by task name and version. + + Parameters: + - user (int, optional): The user ID to filter by. + - group (int, optional): The group ID to filter by. + + Returns: + - Dictionary of task names and versions to counts. + """ + with self.SessionLocal() as session: + query = session.query( + TaskExecution.task_name, + TaskExecution.task_version, + func.count(TaskExecution.task_name) + ).group_by(TaskExecution.task_name, TaskExecution.task_version) + + if user is not None: + query = query.filter_by(user_id=user) + + if group is not None: + query = query.filter_by(group_id=group) + + task_counts = query.all() + result = { + (task_name, task_version): count + for task_name, task_version, count in task_counts + } + logger.debug(f"Retrieved task counts: {result}") + return result + + def get_average_task_duration(self, user=None, group=None): + """Retrieve the average task duration grouped by task name and version. + + Parameters: + - user (int, optional): The user ID to filter by. + - group (int, optional): The group ID to filter by. + + Returns: + - Dictionary of task names and versions to average duration (in seconds). 
+ """ + with self.SessionLocal() as session: + query = session.query( + TaskExecution.task_name, + TaskExecution.task_version, + func.avg( + func.extract('epoch', TaskExecution.end_time) - func.extract('epoch', TaskExecution.start_time) + ).label('avg_duration') + ).filter(TaskExecution.end_time.isnot(None)) # Only include completed tasks + query = query.group_by(TaskExecution.task_name, TaskExecution.task_version) + + if user is not None: + query = query.filter_by(user_id=user) + + if group is not None: + query = query.filter_by(group_id=group) + + task_durations = query.all() + result = { + (task_name, task_version): avg_duration + for task_name, task_version, avg_duration in task_durations + } + logger.debug(f"Retrieved average task durations: {result}") + return result + + def get_task_failures(self, user=None, group=None): + """Retrieve tasks that failed, grouped by task name and version. + + Parameters: + - user (int, optional): The user ID to filter by. + - group (int, optional): The group ID to filter by. + + Returns: + - Dictionary of task names and versions to lists of failure reasons. + """ + with self.SessionLocal() as session: + query = session.query( + TaskExecution.task_name, + TaskExecution.task_version, + TaskExecution.error_type + ).filter(TaskExecution.error_type.isnot(None)) # Only include failed tasks + query = query.group_by(TaskExecution.task_name, TaskExecution.task_version, TaskExecution.error_type) + + if user is not None: + query = query.filter_by(user_id=user) + + if group is not None: + query = query.filter_by(group_id=group) + + task_failures = query.all() + result = {} + for task_name, task_version, error_type in task_failures: + key = (task_name, task_version) + if key not in result: + result[key] = [] + result[key].append(error_type) + + logger.debug(f"Retrieved task failures: {result}") + return result + + def get_task_usage_over_time(self, task_name, user=None, group=None): + """Retrieve task usage over time for a specific task. + + Parameters: + - task_name (str): The name of the task to filter by. + - user (int, optional): The user ID to filter by. + - group (int, optional): The group ID to filter by. + + Returns: + - Dictionary mapping date to the count of task executions on that date. 
+ """ + with self.SessionLocal() as session: + query = session.query( + func.date(TaskExecution.start_time), + func.count(TaskExecution.task_name) + ).filter(TaskExecution.task_name == task_name) + query = query.group_by(func.date(TaskExecution.start_time)) + + if user is not None: + query = query.filter_by(user_id=user) + + if group is not None: + query = query.filter_by(group_id=group) + + usage_over_time = query.all() + result = { + date: count + for date, count in usage_over_time + } + logger.debug(f"Retrieved task usage over time for {task_name}: {result}") + return result From c7769f35f7b4b25f00f0a3a9d3648c02e28e888d Mon Sep 17 00:00:00 2001 From: Luik Date: Thu, 15 Aug 2024 11:37:04 +0200 Subject: [PATCH 10/24] Use SQLAlchemy backend and scoped session --- biomero/slurm_client.py | 14 ++++- biomero/views.py | 106 +++++++++++++++++--------------- pyproject.toml | 3 +- tests/unit/test_slurm_client.py | 37 ++++------- 4 files changed, 83 insertions(+), 77 deletions(-) diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index 5818522..4892913 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -31,7 +31,7 @@ import io import os from biomero.eventsourcing import WorkflowTracker -from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics +from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics, EngineManager from eventsourcing.system import System, SingleThreadedRunner logger = logging.getLogger(__name__) @@ -406,7 +406,9 @@ def __init__(self, [WorkflowTracker, WorkflowAnalytics] ]) if self.track_workflows: # use configured persistence from env - runner = SingleThreadedRunner(system) + scoped_session_topic = EngineManager.create_scoped_session() + runner = SingleThreadedRunner(system, env={ + 'SQLALCHEMY_SCOPED_SESSION_TOPIC': scoped_session_topic}) else: # turn off persistence, override runner = SingleThreadedRunner(system, env={ "PERSISTENCE_MODULE": ""}) @@ -415,6 +417,14 @@ def __init__(self, self.jobAccounting = runner.get(JobAccounting) self.jobProgress = runner.get(JobProgress) self.workflowAnalytics = runner.get(WorkflowAnalytics) + + def __exit__(self, exc_type, exc_val, exc_tb): + # Ensure to call the parent class's __exit__ + # to clean up Connection resources + super().__exit__(exc_type, exc_val, exc_tb) + # Cleanup resources specific to SlurmClient + EngineManager.close_engine() + # If we have any other resources to close or cleanup, do it here def init_workflows(self, force_update: bool = False): """ diff --git a/biomero/views.py b/biomero/views.py index 67dfecf..40a43eb 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -2,11 +2,12 @@ from eventsourcing.system import ProcessApplication from eventsourcing.dispatch import singledispatchmethod +from eventsourcing.utils import get_topic from uuid import NAMESPACE_URL, UUID, uuid5 from typing import Any, Dict, List import logging from sqlalchemy import create_engine, text, Column, Integer, String, URL, DateTime, Float -from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy.orm import sessionmaker, declarative_base, scoped_session from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import func from sqlalchemy.dialects.postgresql import UUID as PGUUID @@ -53,48 +54,60 @@ class TaskExecution(Base): # ------------------- View Listener Applications ------------------ # - -class BaseApplication: - def __init__(self): - # Read database configuration from environment variables - persistence_mod = os.getenv('PERSISTENCE_MODULE') - if 'postgres' 
in persistence_mod: - logger.info("Using postgres database") - database_url = URL.create( - drivername="postgresql+psycopg2", - username=os.getenv('POSTGRES_USER'), - password=os.getenv('POSTGRES_PASSWORD'), - host=os.getenv('POSTGRES_HOST', 'localhost'), - port=os.getenv('POSTGRES_PORT', 5432), - database=os.getenv('POSTGRES_DBNAME') - ) - elif 'sqlite' in persistence_mod: - logger.info("Using sqlite in-mem database") - database_url = URL.create( - drivername="sqlite", - database=os.getenv('SQLITE_DBNAME') - ) - else: - raise NotImplementedError(f"Can't handle {persistence_mod}") - - # Set up SQLAlchemy engine and session - self.engine = create_engine(database_url) - self.SessionLocal = sessionmaker(bind=self.engine) +class EngineManager: + _engine = None + _scoped_session_topic = None + _session = None - # Create defined tables (subclasses of Base) if they don't exist - Base.metadata.create_all(self.engine) - + @classmethod + def create_scoped_session(cls): + if cls._engine is None: + persistence_mod = os.getenv('PERSISTENCE_MODULE') + if 'sqlalchemy' in persistence_mod: + logger.info("Using sqlalchemy database") + database_url=os.getenv('SQLALCHEMY_URL') + cls._engine = create_engine(database_url) + else: + raise NotImplementedError(f"Can't handle {persistence_mod}") + + # setup tables if needed + Base.metadata.create_all(cls._engine) + + # Create a scoped_session object. + cls._session = scoped_session( + sessionmaker(autocommit=False, autoflush=False, + bind=cls._engine) + ) + + class MyScopedSessionAdapter: + def __getattribute__(self, item: str) -> None: + return getattr(cls._session, item) + + # Produce the topic of the scoped session adapter class. + cls._scoped_session_topic = get_topic(MyScopedSessionAdapter) -class JobAccounting(ProcessApplication, BaseApplication): + return cls._scoped_session_topic + + @classmethod + def close_engine(cls): + if cls._engine is not None: + cls._session.remove() + cls._engine.dispose() + cls._engine = None + cls._session = None + cls._scoped_session_topic = None + logger.info("Database engine disposed.") + + +class JobAccounting(ProcessApplication): def __init__(self, *args, **kwargs): ProcessApplication.__init__(self, *args, **kwargs) - BaseApplication.__init__(self) # State tracking self.workflows = {} # {wf_id: {"user": user, "group": group}} self.tasks = {} # {task_id: wf_id} self.jobs = {} # {job_id: (task_id, user, group)} - + @singledispatchmethod def policy(self, domain_event, process_event): """Default policy""" @@ -147,7 +160,6 @@ def _(self, domain_event, process_event): self.jobs[job_id] = (task_id, user, group) logger.debug(f"Job added: job_id={job_id}, task_id={task_id}, user={user}, group={group}") - # Update view table self.update_view_table(job_id, user, group) else: @@ -158,7 +170,7 @@ def _(self, domain_event, process_event): def update_view_table(self, job_id, user, group): """Update the view table with new job information.""" - with self.SessionLocal() as session: + with self.recorder.transaction() as session: try: new_job = JobView(slurm_job_id=job_id, user=user, group=group) session.add(new_job) @@ -185,7 +197,7 @@ def get_jobs(self, user=None, group=None): """ if user is None and group is None: # Retrieve all jobs grouped by user - with self.SessionLocal() as session: + with self.recorder.transaction() as session: jobs = session.query(JobView.user, JobView.slurm_job_id).all() user_jobs = {} for user_id, job_id in jobs: @@ -194,7 +206,7 @@ def get_jobs(self, user=None, group=None): user_jobs[user_id].append(job_id) return 
user_jobs else: - with self.SessionLocal() as session: + with self.recorder.transaction() as session: query = session.query(JobView.slurm_job_id) if user is not None: @@ -209,10 +221,9 @@ def get_jobs(self, user=None, group=None): return result -class JobProgress(ProcessApplication, BaseApplication): +class JobProgress(ProcessApplication): def __init__(self, *args, **kwargs): ProcessApplication.__init__(self, *args, **kwargs) - BaseApplication.__init__(self) # State tracking self.task_to_job = {} # {task_id: job_id} @@ -268,7 +279,7 @@ def _(self, domain_event, process_event): def update_view_table(self, job_id): """Update the view table with new job status and progress information.""" - with self.SessionLocal() as session: + with self.recorder.transaction() as session: try: job_info = self.job_status[job_id] new_job_progress = JobProgressView( @@ -284,10 +295,9 @@ def update_view_table(self, job_id): logger.error(f"Failed to insert/update job progress in view table: job_id={job_id}") -class WorkflowAnalytics(BaseApplication, ProcessApplication): +class WorkflowAnalytics(ProcessApplication): def __init__(self, *args, **kwargs): ProcessApplication.__init__(self, *args, **kwargs) - BaseApplication.__init__(self) # State tracking for workflows and tasks self.workflows = {} # {wf_id: {"user": user, "group": group}} @@ -403,7 +413,7 @@ def update_view_table(self, task_id): user_id = self.workflows[wf_id]["user"] group_id = self.workflows[wf_id]["group"] - with self.SessionLocal() as session: + with self.recorder.transaction() as session: try: existing_task = session.query(TaskExecution).filter_by(task_id=task_id).first() if existing_task: @@ -447,7 +457,7 @@ def get_task_counts(self, user=None, group=None): Returns: - Dictionary of task names and versions to counts. """ - with self.SessionLocal() as session: + with self.recorder.transaction() as session: query = session.query( TaskExecution.task_name, TaskExecution.task_version, @@ -478,7 +488,7 @@ def get_average_task_duration(self, user=None, group=None): Returns: - Dictionary of task names and versions to average duration (in seconds). """ - with self.SessionLocal() as session: + with self.recorder.transaction() as session: query = session.query( TaskExecution.task_name, TaskExecution.task_version, @@ -512,7 +522,7 @@ def get_task_failures(self, user=None, group=None): Returns: - Dictionary of task names and versions to lists of failure reasons. """ - with self.SessionLocal() as session: + with self.recorder.transaction() as session: query = session.query( TaskExecution.task_name, TaskExecution.task_version, @@ -548,7 +558,7 @@ def get_task_usage_over_time(self, task_name, user=None, group=None): Returns: - Dictionary mapping date to the count of task executions on that date. 
""" - with self.SessionLocal() as session: + with self.recorder.transaction() as session: query = session.query( func.date(TaskExecution.start_time), func.count(TaskExecution.task_name) diff --git a/pyproject.toml b/pyproject.toml index 95e10ea..46d92b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,8 @@ dependencies = [ "importlib_resources>=5.4.0", "eventsourcing[crypto]==9.2.22", "sqlalchemy==2.0.32", - "psycopg2==2.9.9" + "psycopg2==2.9.9", + "eventsourcing_sqlalchemy==0.7" ] [tool.setuptools.packages] diff --git a/tests/unit/test_slurm_client.py b/tests/unit/test_slurm_client.py index 30b7d11..2529112 100644 --- a/tests/unit/test_slurm_client.py +++ b/tests/unit/test_slurm_client.py @@ -3,36 +3,21 @@ import mock from mock import patch, MagicMock from paramiko import SSHException -# import os -# using actual env vars -# @pytest.fixture(scope='session', autouse=True) -# def set_env_vars(): -# # Set environment variables -# os.environ["PERSISTENCE_MODULE"] = "eventsourcing.sqlite" -# os.environ["SQLITE_DBNAME"] = ":memory:" +import os -# # Optional: Return a dictionary of the set variables if needed elsewhere -# yield { -# "PERSISTENCE_MODULE": "eventsourcing.sqlite", -# "SQLITE_DBNAME": ":memory:", -# } -# # Optionally, clean up the environment variables after tests are done -# del os.environ["PERSISTENCE_MODULE"] -# del os.environ["SQLITE_DBNAME"] +@pytest.fixture(autouse=True) +def set_env_vars(): + # Set environment variables directly + os.environ["PERSISTENCE_MODULE"] = "eventsourcing_sqlalchemy" + os.environ["SQLALCHEMY_URL"] = "sqlite:///:memory:" + # Yield to let the test run + yield -@pytest.fixture(autouse=True) -def mock_env_vars(): - # Define mock environment variables - mock_env = { - "PERSISTENCE_MODULE": "eventsourcing.sqlite", - "SQLITE_DBNAME": ":memory:", - } - - # Patch os.getenv to return values from the mock environment - with patch('os.getenv', lambda key, default=None: mock_env.get(key, default)): - yield + # Optionally, clean up the environment variables after the test + del os.environ["PERSISTENCE_MODULE"] + del os.environ["SQLALCHEMY_URL"] class SerializableMagicMock(MagicMock, dict): From 0e77745808d01aeda4d2bb0c2e369f036393fbe2 Mon Sep 17 00:00:00 2001 From: Luik Date: Wed, 21 Aug 2024 11:21:34 +0200 Subject: [PATCH 11/24] Use SQLAlchemy backbone so we can use ScopedSession for all single thread listeners. 
--- biomero/database.py | 94 ++++++++++++++++++++++++++++++ biomero/eventsourcing.py | 15 ++++- biomero/slurm_client.py | 6 +- biomero/views.py | 122 +++++++++------------------------------ 4 files changed, 140 insertions(+), 97 deletions(-) create mode 100644 biomero/database.py diff --git a/biomero/database.py b/biomero/database.py new file mode 100644 index 0000000..db2db56 --- /dev/null +++ b/biomero/database.py @@ -0,0 +1,94 @@ +from eventsourcing.utils import get_topic +import logging +from sqlalchemy import create_engine, text, Column, Integer, String, URL, DateTime, Float +from sqlalchemy.orm import sessionmaker, declarative_base, scoped_session +from sqlalchemy.dialects.postgresql import UUID as PGUUID +import os + +logger = logging.getLogger(__name__) + +# --------------------- VIEWS DB tables/classes ---------------------------- # + +# Base class for declarative class definitions +Base = declarative_base() + + +class JobView(Base): + __tablename__ = 'biomero_job_view' + + slurm_job_id = Column(Integer, primary_key=True) + user = Column(Integer, nullable=False) + group = Column(Integer, nullable=False) + + +class JobProgressView(Base): + __tablename__ = 'biomero_job_progress_view' + + slurm_job_id = Column(Integer, primary_key=True) + status = Column(String, nullable=False) + progress = Column(String, nullable=True) + + +class TaskExecution(Base): + __tablename__ = 'biomero_task_execution' + + task_id = Column(PGUUID(as_uuid=True), primary_key=True) + task_name = Column(String, nullable=False) + task_version = Column(String) + user_id = Column(Integer, nullable=True) + group_id = Column(Integer, nullable=True) + status = Column(String, nullable=False) + start_time = Column(DateTime, nullable=False) + end_time = Column(DateTime, nullable=True) + error_type = Column(String, nullable=True) + + +class EngineManager: + _engine = None + _scoped_session_topic = None + _session = None + + @classmethod + def create_scoped_session(cls): + if cls._engine is None: + persistence_mod = os.getenv('PERSISTENCE_MODULE') + if 'sqlalchemy' in persistence_mod: + logger.info("Using sqlalchemy database") + database_url = os.getenv('SQLALCHEMY_URL') + cls._engine = create_engine(database_url) + else: + raise NotImplementedError(f"Can't handle {persistence_mod}") + + # setup tables if needed + Base.metadata.create_all(cls._engine) + + # Create a scoped_session object. + cls._session = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=cls._engine) + ) + + class MyScopedSessionAdapter: + def __getattribute__(self, item: str) -> None: + return getattr(cls._session, item) + + # Produce the topic of the scoped session adapter class. + cls._scoped_session_topic = get_topic(MyScopedSessionAdapter) + + return cls._scoped_session_topic + + @classmethod + def get_session(cls): + return cls._session() + + @classmethod + def commit(cls): + cls._session.commit() + + @classmethod + def close_engine(cls): + if cls._engine is not None: + cls._session.remove() + cls._engine.dispose() + cls._engine = None + cls._session = None + cls._scoped_session_topic = None diff --git a/biomero/eventsourcing.py b/biomero/eventsourcing.py index 40a7ff4..11b2a5d 100644 --- a/biomero/eventsourcing.py +++ b/biomero/eventsourcing.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os from eventsourcing.domain import Aggregate, event from eventsourcing.application import Application from uuid import NAMESPACE_URL, UUID, uuid5 from typing import Any, Dict, List from fabric import Result import logging +from biomero.database import EngineManager # Create a logger for this module @@ -200,6 +200,7 @@ def initiate_workflow(self, logger.debug(f"[WFT] Initiating workflow: name={name}, description={description}, user={user}, group={group}") workflow = WorkflowRun(name, description, user, group) self.save(workflow) + EngineManager.commit() return workflow.id def add_task_to_workflow(self, @@ -217,9 +218,11 @@ def add_task_to_workflow(self, input_data, kwargs) self.save(task) + EngineManager.commit() workflow = self.repository.get(workflow_id) workflow.add_task(task.id) self.save(workflow) + EngineManager.commit() return task.id def start_workflow(self, workflow_id: UUID): @@ -228,6 +231,7 @@ def start_workflow(self, workflow_id: UUID): workflow = self.repository.get(workflow_id) workflow.start_workflow() self.save(workflow) + EngineManager.commit() def complete_workflow(self, workflow_id: UUID): logger.debug(f"[WFT] Completing workflow: workflow_id={workflow_id}") @@ -235,6 +239,7 @@ def complete_workflow(self, workflow_id: UUID): workflow = self.repository.get(workflow_id) workflow.complete_workflow() self.save(workflow) + EngineManager.commit() def fail_workflow(self, workflow_id: UUID, error_message: str): logger.debug(f"[WFT] Failing workflow: workflow_id={workflow_id}, error_message={error_message}") @@ -242,6 +247,7 @@ def fail_workflow(self, workflow_id: UUID, error_message: str): workflow = self.repository.get(workflow_id) workflow.fail_workflow(error_message) self.save(workflow) + EngineManager.commit() def start_task(self, task_id: UUID): logger.debug(f"[WFT] Starting task: task_id={task_id}") @@ -249,6 +255,7 @@ def start_task(self, task_id: UUID): task = self.repository.get(task_id) task.start_task() self.save(task) + EngineManager.commit() def complete_task(self, task_id: UUID, message: str): logger.debug(f"[WFT] Completing task: task_id={task_id}, message={message}") @@ -256,6 +263,7 @@ def complete_task(self, task_id: UUID, message: str): task = self.repository.get(task_id) task.complete_task(message) self.save(task) + EngineManager.commit() def fail_task(self, task_id: UUID, error_message: str): logger.debug(f"[WFT] Failing task: task_id={task_id}, error_message={error_message}") @@ -263,6 +271,7 @@ def fail_task(self, task_id: UUID, error_message: str): task = self.repository.get(task_id) task.fail_task(error_message) self.save(task) + EngineManager.commit() def add_job_id(self, task_id, slurm_job_id): logger.debug(f"[WFT] Adding job_id to task: task_id={task_id}, slurm_job_id={slurm_job_id}") @@ -270,6 +279,7 @@ def add_job_id(self, task_id, slurm_job_id): task = self.repository.get(task_id) task.add_job_id(slurm_job_id) self.save(task) + EngineManager.commit() def add_result(self, task_id, result): logger.debug(f"[WFT] Adding result to task: task_id={task_id}, result={result}") @@ -277,6 +287,7 @@ def add_result(self, task_id, result): task = self.repository.get(task_id) task.add_result(result) self.save(task) + EngineManager.commit() def update_task_status(self, task_id, status): logger.debug(f"[WFT] Adding status to task: task_id={task_id}, status={status}") @@ -284,6 +295,7 @@ def update_task_status(self, task_id, status): task = self.repository.get(task_id) task.update_task_status(status) self.save(task) + EngineManager.commit() def 
update_task_progress(self, task_id, progress): logger.debug(f"[WFT] Adding progress to task: task_id={task_id}, progress={progress}") @@ -291,6 +303,7 @@ def update_task_progress(self, task_id, progress): task = self.repository.get(task_id) task.update_task_progress(progress) self.save(task) + EngineManager.commit() diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index 4892913..60725f0 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -31,7 +31,8 @@ import io import os from biomero.eventsourcing import WorkflowTracker -from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics, EngineManager +from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics +from biomero.database import EngineManager from eventsourcing.system import System, SingleThreadedRunner logger = logging.getLogger(__name__) @@ -1407,7 +1408,7 @@ def run_conversion_workflow_job(self, f"echo \"Number of .{source_format} files: $N\"", conversion_cmd ] - + logger.debug(f"wf_id: {wf_id}") if not wf_id: wf_id = self.workflowTracker.initiate_workflow( "conversion", @@ -1415,6 +1416,7 @@ def run_conversion_workflow_job(self, -1, -1 ) + logger.debug(f"wf_id: {wf_id}") task_id = self.workflowTracker.add_task_to_workflow( wf_id, chosen_converter, diff --git a/biomero/views.py b/biomero/views.py index 40a43eb..5b12e79 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -11,93 +11,17 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import func from sqlalchemy.dialects.postgresql import UUID as PGUUID +from sqlalchemy import event +from sqlalchemy.engine import Engine from biomero.eventsourcing import WorkflowRun, Task +from biomero.database import EngineManager, JobView, TaskExecution, JobProgressView logger = logging.getLogger(__name__) -# --------------------- VIEWS DB tables/classes ---------------------------- # - -# Base class for declarative class definitions -Base = declarative_base() - - -class JobView(Base): - __tablename__ = 'biomero_job_view' - - slurm_job_id = Column(Integer, primary_key=True) - user = Column(Integer, nullable=False) - group = Column(Integer, nullable=False) - - -class JobProgressView(Base): - __tablename__ = 'biomero_job_progress_view' - - slurm_job_id = Column(Integer, primary_key=True) - status = Column(String, nullable=False) - progress = Column(String, nullable=True) - - -class TaskExecution(Base): - __tablename__ = 'biomero_task_execution' - - task_id = Column(PGUUID(as_uuid=True), primary_key=True) - task_name = Column(String, nullable=False) - task_version = Column(String) - user_id = Column(Integer, nullable=True) - group_id = Column(Integer, nullable=True) - status = Column(String, nullable=False) - start_time = Column(DateTime, nullable=False) - end_time = Column(DateTime, nullable=True) - error_type = Column(String, nullable=True) - # ------------------- View Listener Applications ------------------ # -class EngineManager: - _engine = None - _scoped_session_topic = None - _session = None - - @classmethod - def create_scoped_session(cls): - if cls._engine is None: - persistence_mod = os.getenv('PERSISTENCE_MODULE') - if 'sqlalchemy' in persistence_mod: - logger.info("Using sqlalchemy database") - database_url=os.getenv('SQLALCHEMY_URL') - cls._engine = create_engine(database_url) - else: - raise NotImplementedError(f"Can't handle {persistence_mod}") - - # setup tables if needed - Base.metadata.create_all(cls._engine) - - # Create a scoped_session object. 
- cls._session = scoped_session( - sessionmaker(autocommit=False, autoflush=False, - bind=cls._engine) - ) - - class MyScopedSessionAdapter: - def __getattribute__(self, item: str) -> None: - return getattr(cls._session, item) - - # Produce the topic of the scoped session adapter class. - cls._scoped_session_topic = get_topic(MyScopedSessionAdapter) - - return cls._scoped_session_topic - - @classmethod - def close_engine(cls): - if cls._engine is not None: - cls._session.remove() - cls._engine.dispose() - cls._engine = None - cls._session = None - cls._scoped_session_topic = None - logger.info("Database engine disposed.") - class JobAccounting(ProcessApplication): def __init__(self, *args, **kwargs): @@ -170,16 +94,16 @@ def _(self, domain_event, process_event): def update_view_table(self, job_id, user, group): """Update the view table with new job information.""" - with self.recorder.transaction() as session: + with EngineManager.get_session() as session: try: new_job = JobView(slurm_job_id=job_id, user=user, group=group) session.add(new_job) session.commit() logger.debug(f"Inserted job into view table: job_id={job_id}, user={user}, group={group}") - except IntegrityError: + except IntegrityError as e: session.rollback() # Handle the case where the job already exists in the table if necessary - logger.error(f"Failed to insert job into view table (already exists?): job_id={job_id}, user={user}, group={group}") + logger.error(f"Failed to insert job into view table (already exists?): job_id={job_id}, user={user}, group={group}. Error {e}") def get_jobs(self, user=None, group=None): """Retrieve jobs for a specific user and/or group. @@ -197,7 +121,7 @@ def get_jobs(self, user=None, group=None): """ if user is None and group is None: # Retrieve all jobs grouped by user - with self.recorder.transaction() as session: + with EngineManager.get_session() as session: jobs = session.query(JobView.user, JobView.slurm_job_id).all() user_jobs = {} for user_id, job_id in jobs: @@ -206,7 +130,7 @@ def get_jobs(self, user=None, group=None): user_jobs[user_id].append(job_id) return user_jobs else: - with self.recorder.transaction() as session: + with EngineManager.get_session() as session: query = session.query(JobView.slurm_job_id) if user is not None: @@ -279,7 +203,7 @@ def _(self, domain_event, process_event): def update_view_table(self, job_id): """Update the view table with new job status and progress information.""" - with self.recorder.transaction() as session: + with EngineManager.get_session() as session: try: job_info = self.job_status[job_id] new_job_progress = JobProgressView( @@ -295,6 +219,12 @@ def update_view_table(self, job_id): logger.error(f"Failed to insert/update job progress in view table: job_id={job_id}") +# @event.listens_for(Engine, "before_cursor_execute") +# def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): +# logger.debug(f"SQL: {statement}") +# logger.debug(f"Parameters: {parameters}") + + class WorkflowAnalytics(ProcessApplication): def __init__(self, *args, **kwargs): ProcessApplication.__init__(self, *args, **kwargs) @@ -347,14 +277,16 @@ def _(self, domain_event, process_event): self.tasks[task_id].update({ "task_name": task_name, "task_version": task_version, - "start_time": timestamp_created + "start_time": timestamp_created, + "status": "CREATED" }) else: # Initialize task tracking if TaskAdded hasn't been processed yet self.tasks[task_id] = { "task_name": task_name, "task_version": task_version, - "start_time": timestamp_created + 
"start_time": timestamp_created, + "status": "CREATED" } logger.debug(f"Task created: task_id={task_id}, task_name={task_name}, timestamp={timestamp_created}") @@ -413,9 +345,10 @@ def update_view_table(self, task_id): user_id = self.workflows[wf_id]["user"] group_id = self.workflows[wf_id]["group"] - with self.recorder.transaction() as session: + with EngineManager.get_session() as session: try: existing_task = session.query(TaskExecution).filter_by(task_id=task_id).first() + logger.debug(f"Existing: {existing_task}. vs info {task_info}") if existing_task: # Update existing task execution record existing_task.task_name = task_info.get("task_name", existing_task.task_name) @@ -443,9 +376,10 @@ def update_view_table(self, task_id): session.commit() logger.debug(f"Updated/Inserted task execution into view table: task_id={task_id}, task_name={task_info.get('task_name')}") - except IntegrityError: + except IntegrityError as e: session.rollback() - logger.error(f"Failed to insert/update task execution into view table: task_id={task_id}") + logger.error(f"Failed to insert/update task execution into view table: task_id={task_id}, error={str(e)}") + logger.debug(f"Task info: {task_info}") def get_task_counts(self, user=None, group=None): """Retrieve task execution counts grouped by task name and version. @@ -457,7 +391,7 @@ def get_task_counts(self, user=None, group=None): Returns: - Dictionary of task names and versions to counts. """ - with self.recorder.transaction() as session: + with EngineManager.get_session() as session: query = session.query( TaskExecution.task_name, TaskExecution.task_version, @@ -488,7 +422,7 @@ def get_average_task_duration(self, user=None, group=None): Returns: - Dictionary of task names and versions to average duration (in seconds). """ - with self.recorder.transaction() as session: + with EngineManager.get_session() as session: query = session.query( TaskExecution.task_name, TaskExecution.task_version, @@ -522,7 +456,7 @@ def get_task_failures(self, user=None, group=None): Returns: - Dictionary of task names and versions to lists of failure reasons. """ - with self.recorder.transaction() as session: + with EngineManager.get_session() as session: query = session.query( TaskExecution.task_name, TaskExecution.task_version, @@ -558,7 +492,7 @@ def get_task_usage_over_time(self, task_name, user=None, group=None): Returns: - Dictionary mapping date to the count of task executions on that date. 
""" - with self.recorder.transaction() as session: + with EngineManager.get_session() as session: query = session.query( func.date(TaskExecution.start_time), func.count(TaskExecution.task_name) From 5eac28d3f6dfebf87bcf004cb8ac1d7b1541854c Mon Sep 17 00:00:00 2001 From: Luik Date: Thu, 22 Aug 2024 11:08:01 +0200 Subject: [PATCH 12/24] Bump biomero minimal python version to 3.8 for eventsourcing_sqlalchemy --- .github/workflows/python-package.yml | 2 +- .github/workflows/python-publish.yml | 2 +- README.md | 2 +- biomero/__init__.py | 16 ++++------------ pyproject.toml | 5 ++--- resources/tutorials/tutorial_Azure_slurm.md | 12 ++++++------ 6 files changed, 15 insertions(+), 24 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index eb0050c..51ea19a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.7", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index aa288ce..21a5259 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -27,7 +27,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.7' + python-version: '3.8' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/README.md b/README.md index 9ea081f..b36107d 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ Your Slurm cluster/login node needs to have: Your OMERO _processing_ node needs to have: 1. SSH client and access to the Slurm cluster (w/ private key / headless) 2. SCP access to the Slurm cluster -3. Python3.7+ +3. Python3.8+ 4. 
This library installed - Latest release on PyPI `python3 -m pip install biomero` - or latest Github version `python3 -m pip install 'git+https://github.com/NL-BioImaging/biomero'` diff --git a/biomero/__init__.py b/biomero/__init__.py index 40c84cf..d3729d4 100644 --- a/biomero/__init__.py +++ b/biomero/__init__.py @@ -1,17 +1,9 @@ from .slurm_client import SlurmClient - +import importlib.metadata try: - import importlib.metadata - try: - __version__ = importlib.metadata.version(__package__) - except importlib.metadata.PackageNotFoundError: - __version__ = "Version not found" -except ModuleNotFoundError: # Python 3.7 - try: - import pkg_resources - __version__ = pkg_resources.get_distribution(__package__).version - except pkg_resources.DistributionNotFound: - __version__ = "Version not found" + __version__ = importlib.metadata.version(__package__) +except importlib.metadata.PackageNotFoundError: + __version__ = "Version not found" from .eventsourcing import * from .views import * \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 46d92b4..9ab0fdc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,12 +12,11 @@ authors = [ ] description = "A python library for easy connecting between OMERO (jobs) and a Slurm cluster" readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.8" keywords = ["omero", "slurm", "high-performance-computing", "fair", "image-analysis", "bioimaging", "high-throughput-screening", "high-content-screening", "cytomine", "biomero", "biaflows"] license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -33,7 +32,7 @@ dependencies = [ "eventsourcing[crypto]==9.2.22", "sqlalchemy==2.0.32", "psycopg2==2.9.9", - "eventsourcing_sqlalchemy==0.7" + "eventsourcing_sqlalchemy==0.7" # requires py 3.8 ] [tool.setuptools.packages] diff --git a/resources/tutorials/tutorial_Azure_slurm.md b/resources/tutorials/tutorial_Azure_slurm.md index 6f9d22b..5829735 100644 --- a/resources/tutorials/tutorial_Azure_slurm.md +++ b/resources/tutorials/tutorial_Azure_slurm.md @@ -263,9 +263,9 @@ Fill in the actual ip, this is just a placeholder! - chmod the config to 700 too: `sudo chmod 700 .ssh/config` - Ready! `ssh localslurm` (or whatever you called the alias) -4. Let's edit the BIOMERO configuration `slurm-config.ini`, located in the worker-processor node +4. Let's edit the BIOMERO configuration `slurm-config.ini`, located in the biomeroworker node - - `vi ~/NL-BIOMERO/worker-processor/slurm-config.ini` + - `vi ~/NL-BIOMERO/biomeroworker/slurm-config.ini` - Change the `host` if you did not use the `localslurm` alias in the config above. - Change ALL the `[SLURM]` paths to match our new slurm setup: @@ -308,7 +308,7 @@ slurm_script_path=slurm-scripts - `cd NL-BIOMERO` - `docker compose down` - `docker compose up -d --build` - - Now `docker logs -f nl-biomero-omeroworker-processor-1` should show some good logs leading to: `Starting node omeroworker-processor`. + - Now `docker logs -f nl-biomero-biomeroworker-1` should show some good logs leading to: `Starting node biomeroworker`. ## 6. Showtime! 
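Note on the importlib.metadata switch in biomero/__init__.py above: the same standard-library lookup works for downstream code that wants to report which biomero release is installed. A minimal sketch (the "biomero" distribution name is the only assumption; it mirrors the try/except fallback the patch introduces):

import importlib.metadata

try:
    # Same call the patched __init__.py uses, but with an explicit distribution name
    version = importlib.metadata.version("biomero")
except importlib.metadata.PackageNotFoundError:
    # Mirrors the "Version not found" fallback in __init__.py
    version = "Version not found"
print(version)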
@@ -444,12 +444,12 @@ sbatch: error: Batch job submission failed: Requested node configuration is not We will do this ad-hoc, by changing the configuration for CellPose in the `slurm-config.ini` in our installation: - - First, edit the config on the main VM with `vi worker-processor/slurm-config.ini` + - First, edit the config on the main VM with `vi biomeroworker/slurm-config.ini` - Add this line to your workflows `_job_cpus-per-task=2`, e.g. `cellpose_job_cpus-per-task=2` - save file (`:wq`) - Don't forget to open your .ssh to the container `chmod -R 777 ~/.ssh` (and close it later) - - Restart the biomero container(s) (`docker compose down` & `docker compose up -d --build`, perhaps specifically for `omeroworker-processor`). - - Check logs to see if biomero started up properly `docker logs -f nl-biomero-omeroworker-processor-1` + - Restart the biomero container(s) (`docker compose down` & `docker compose up -d --build`, perhaps specifically for `biomeroworker`). + - Check logs to see if biomero started up properly `docker logs -f nl-biomero-biomeroworker-1` 6. Next, time to segment! Time to spin up those SLURM compute nodes: From 91546ab4cbaf6506763872d75cad36ed4d5ba2b1 Mon Sep 17 00:00:00 2001 From: Luik Date: Thu, 22 Aug 2024 12:00:12 +0200 Subject: [PATCH 13/24] Make tracker&listeners configurable --- biomero/database.py | 13 +- biomero/eventsourcing.py | 13 ++ biomero/slurm_client.py | 245 ++++++++++++++++++++++---------- tests/unit/test_slurm_client.py | 156 +++++++++++++++++++- 4 files changed, 335 insertions(+), 92 deletions(-) diff --git a/biomero/database.py b/biomero/database.py index db2db56..91916b0 100644 --- a/biomero/database.py +++ b/biomero/database.py @@ -49,15 +49,12 @@ class EngineManager: _session = None @classmethod - def create_scoped_session(cls): + def create_scoped_session(cls, sqlalchemy_url=None): if cls._engine is None: - persistence_mod = os.getenv('PERSISTENCE_MODULE') - if 'sqlalchemy' in persistence_mod: - logger.info("Using sqlalchemy database") - database_url = os.getenv('SQLALCHEMY_URL') - cls._engine = create_engine(database_url) - else: - raise NotImplementedError(f"Can't handle {persistence_mod}") + # Note, we only allow sqlalchemy eventsourcing module + if not sqlalchemy_url: + sqlalchemy_url = os.getenv('SQLALCHEMY_URL') + cls._engine = create_engine(sqlalchemy_url) # setup tables if needed Base.metadata.create_all(cls._engine) diff --git a/biomero/eventsourcing.py b/biomero/eventsourcing.py index 11b2a5d..25619dc 100644 --- a/biomero/eventsourcing.py +++ b/biomero/eventsourcing.py @@ -189,6 +189,19 @@ def fail_task(self, error_message: str): # -------------------- APPLICATIONS -------------------- # +class NoOpWorkflowTracker: + def __getattr__(self, name): + """ + Override attribute access to make all methods no-op. + Logs a warning with the function name and parameters passed. 
+ """ + def no_op_function(*args, **kwargs): + logger.debug(f"[No-op] Called function: {name} with args: {args}, kwargs: {kwargs}") + # Optionally log the parameters or add warnings instead + return None + + return no_op_function + class WorkflowTracker(Application): diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index 60725f0..a62b5d3 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -30,7 +30,7 @@ from importlib_resources import files import io import os -from biomero.eventsourcing import WorkflowTracker +from biomero.eventsourcing import WorkflowTracker, NoOpWorkflowTracker from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics from biomero.database import EngineManager from eventsourcing.system import System, SingleThreadedRunner @@ -294,82 +294,99 @@ def __init__(self, slurm_script_repo: str = None, init_slurm: bool = False, track_workflows: bool = True, - ): - """Initializes a new instance of the SlurmClient class. + enable_job_accounting: bool = True, + enable_job_progress: bool = True, + enable_workflow_analytics: bool = True, + sqlalchemy_url: str = None): + """ + Initializes a new instance of the SlurmClient class. - It is preferable to use #from_config(...) method to initialize - parameters from a config file. + It is preferable to use the `#from_config(...)` method to initialize + parameters from a configuration file. Args: - host (str, optional): The hostname or IP address of the remote - server. Defaults to _DEFAULT_HOST. + host (str, optional): The hostname or IP address of the remote + server. Defaults to `_DEFAULT_HOST`. user (str, optional): The username to use when connecting to - the remote server. Defaults to None, which defaults - to config.user. + the remote server. Defaults to None, which falls back to + `config.user`. port (int, optional): The SSH port to use when connecting. - Defaults to None, which defaults to config.port. + Defaults to None, which falls back to `config.port`. config (str, optional): Path to the SSH config file. - Defaults to None, which defaults to your SSH config file. + Defaults to None, which falls back to your SSH config file. gateway (Connection, optional): An optional gateway for connecting through a jump host. Defaults to None. forward_agent (bool, optional): Whether to forward the local SSH agent to the remote server. Defaults to None, which - defaults to config.forward_agent. + falls back to `config.forward_agent`. connect_timeout (int, optional): Timeout for establishing the SSH - connection. Defaults to None, which defaults - to config.timeouts.connect. + connection. Defaults to None, which falls back to + `config.timeouts.connect`. connect_kwargs (dict, optional): Additional keyword arguments for - the underlying SSH connection. Handed verbatim to + the underlying SSH connection. These are passed verbatim to `SSHClient.connect `. Defaults to None. - inline_ssh_env (bool, optional): Whether to use inline SSH - environment. This is necessary if the remote server has - a restricted ``AcceptEnv`` setting (which is the common - default). Defaults to _DEFAULT_INLINE_SSH_ENV. - slurm_data_path (str, optional): The path to the directory - containing the data files for Slurm jobs. - Defaults to _DEFAULT_SLURM_DATA_PATH. - slurm_images_path (str, optional): The path to the directory - containing the Singularity images for Slurm jobs. - Defaults to _DEFAULT_SLURM_IMAGES_PATH. 
- slurm_converters_path (str, optional): The path to the directory - containing the Singularity images for file converters. - Defaults to _DEFAULT_SLURM_CONVERTERS_PATH. + inline_ssh_env (bool, optional): Whether to use inline SSH + environment variables. This is necessary if the remote server + has a restricted `AcceptEnv` setting (the common default). + Defaults to `_DEFAULT_INLINE_SSH_ENV`. + slurm_data_path (str, optional): The path to the directory + containing the data files for Slurm jobs. + Defaults to `_DEFAULT_SLURM_DATA_PATH`. + slurm_images_path (str, optional): The path to the directory + containing the Singularity images for Slurm jobs. + Defaults to `_DEFAULT_SLURM_IMAGES_PATH`. + slurm_converters_path (str, optional): The path to the directory + containing the Singularity images for file converters. + Defaults to `_DEFAULT_SLURM_CONVERTERS_PATH`. slurm_model_paths (dict, optional): A dictionary containing the - paths to the Singularity images for specific Slurm job models. + paths to the Singularity images for specific Slurm job models. Defaults to None. slurm_model_repos (dict, optional): A dictionary containing the - git repositories of Singularity images for specific Slurm - job models. - Defaults to None. + Git repositories of Singularity images for specific Slurm + job models. Defaults to None. slurm_model_images (dict, optional): A dictionary containing the - dockerhub of the Singularity images for specific Slurm - job models. Will fill automatically from the data in the git - repository if you set init_slurm. + DockerHub images of the Singularity images for specific + Slurm job models. Will be filled automatically from the + data in the Git repository if `init_slurm` is set to True. Defaults to None. - converter_images (dict, optional): A dictionairy containing the - dockerhub of the Singularity images for converters. - Will default to building converter available in this package - on Slurm instead if not configured. + converter_images (dict, optional): A dictionary containing the + DockerHub images of the Singularity images for file converters. + Will default to building the converter available in this package + on Slurm instead if not configured. Defaults to None. - slurm_model_jobs (dict, optional): A dictionary containing - information about specific Slurm job models. + slurm_model_jobs (dict, optional): A dictionary containing + information about specific Slurm job models. Defaults to None. - slurm_model_jobs_params (dict, optional): A dictionary containing - parameters for specific Slurm job models. + slurm_model_jobs_params (dict, optional): A dictionary containing + parameters for specific Slurm job models. Defaults to None. - slurm_script_path (str, optional): The path to the directory - containing the Slurm job submission scripts on Slurm. - Defaults to _DEFAULT_SLURM_GIT_SCRIPT_PATH. - slurm_script_repo (str, optional): The git https URL for cloning - the repo containing the Slurm job submission scripts. + slurm_script_path (str, optional): The path to the directory + containing the Slurm job submission scripts on Slurm. + Defaults to `_DEFAULT_SLURM_GIT_SCRIPT_PATH`. + slurm_script_repo (str, optional): The Git HTTPS URL for cloning + the repository containing the Slurm job submission scripts. Defaults to None. - init_slurm (bool): Whether to set up the required structures + init_slurm (bool, optional): Whether to set up the required structures on Slurm after initiating this client. 
This includes creating - missing folders, downloading container images, cloning git,etc. - This will take a while at first but will validate your setup. - Defaults to False to save time. + missing folders, downloading container images, cloning Git, etc. + This process will take some time initially but will validate + your setup. Defaults to False to save time. + track_workflows (bool, optional): Whether to track workflows. + Defaults to True. + enable_job_accounting (bool, optional): Whether to enable job + accounting. Defaults to True. + enable_job_progress (bool, optional): Whether to track job + progress. Defaults to True. + enable_workflow_analytics (bool, optional): Whether to enable + workflow analytics. Defaults to True. + sqlalchemy_url (str, optional): URL for eventsourcing database + connection. Defaults to None, which falls back to the + `SQLALCHEMY_URL` environment variable. Note that it will + always be overridden with the environment variable + `SQLALCHEMY_URL`, if that is set. """ + super(SlurmClient, self).__init__(host, user, port, @@ -400,25 +417,79 @@ def __init__(self, self.validate(validate_slurm_setup=init_slurm) # Setup workflow tracking and accounting + # Initialize the analytics settings self.track_workflows = track_workflows - system = System(pipes=[ - [WorkflowTracker, JobAccounting], - [WorkflowTracker, JobProgress], - [WorkflowTracker, WorkflowAnalytics] - ]) - if self.track_workflows: # use configured persistence from env - scoped_session_topic = EngineManager.create_scoped_session() + self.enable_job_accounting = enable_job_accounting + self.enable_job_progress = enable_job_progress + self.enable_workflow_analytics = enable_workflow_analytics + + # Initialize the analytics system + self.sqlalchemy_url = sqlalchemy_url + self.initialize_analytics_system() + + def initialize_analytics_system(self): + """ + Initialize the analytics system based on the analytics configuration + passed to the constructor. + """ + # Get persistence settings, prioritize environment variables + persistence_module = os.getenv("PERSISTENCE_MODULE", "eventsourcing_sqlalchemy") + if persistence_module != "eventsourcing_sqlalchemy": + raise NotImplementedError(f"Can't handle {persistence_module}. 
Currently only supports 'eventsourcing_sqlalchemy' as PERSISTENCE_MODULE") + + sqlalchemy_url = os.getenv("SQLALCHEMY_URL", self.sqlalchemy_url) + if not sqlalchemy_url: + raise ValueError("SQLALCHEMY_URL must be set either in init, config ('sqlalchemy_url') or as an environment variable.") + if sqlalchemy_url != self.sqlalchemy_url: + logger.info("Overriding configured SQLALCHEMY_URL with env var SQLALCHEMY_URL.") + + # Build the system based on the analytics configuration + pipes = [] + runner = None + if self.track_workflows: + # Add JobAccounting to the pipeline if enabled + if self.enable_job_accounting: + pipes.append([WorkflowTracker, JobAccounting]) + + # Add JobProgress to the pipeline if enabled + if self.enable_job_progress: + pipes.append([WorkflowTracker, JobProgress]) + + # Add WorkflowAnalytics to the pipeline if enabled + if self.enable_workflow_analytics: + pipes.append([WorkflowTracker, WorkflowAnalytics]) + + # Add onlys WorkflowTracker if no listeners are enabled + if not pipes: + pipes = [[WorkflowTracker]] + + system = System(pipes=pipes) + scoped_session_topic = EngineManager.create_scoped_session( + sqlalchemy_url=sqlalchemy_url) runner = SingleThreadedRunner(system, env={ - 'SQLALCHEMY_SCOPED_SESSION_TOPIC': scoped_session_topic}) + 'SQLALCHEMY_SCOPED_SESSION_TOPIC': scoped_session_topic, + 'PERSISTENCE_MODULE': persistence_module}) + runner.start() + self.workflowTracker = runner.get(WorkflowTracker) else: # turn off persistence, override - runner = SingleThreadedRunner(system, env={ - "PERSISTENCE_MODULE": ""}) - runner.start() - self.workflowTracker = runner.get(WorkflowTracker) - self.jobAccounting = runner.get(JobAccounting) - self.jobProgress = runner.get(JobProgress) - self.workflowAnalytics = runner.get(WorkflowAnalytics) + logger.warning("Tracking workflows is disabled. No-op WorkflowTracker will be used.") + self.workflowTracker = NoOpWorkflowTracker() + + if self.track_workflows and self.enable_job_accounting: + self.jobAccounting = runner.get(JobAccounting) + else: + self.jobAccounting = NoOpWorkflowTracker() + if self.track_workflows and self.enable_job_progress: + self.jobProgress = runner.get(JobProgress) + else: + self.jobProgress = NoOpWorkflowTracker() + + if self.track_workflows and self.enable_workflow_analytics: + self.workflowAnalytics = runner.get(WorkflowAnalytics) + else: + self.workflowAnalytics = NoOpWorkflowTracker() + def __exit__(self, exc_type, exc_val, exc_tb): # Ensure to call the parent class's __exit__ # to clean up Connection resources @@ -714,17 +785,16 @@ def from_config(cls, configfile: str = '', - /etc/slurm-config.ini - ~/slurm-config.ini - Note that this is only for the SLURM specific values that we added. + Note that this is only for the SLURM-specific values that we added. Most configuration values are set via configuration mechanisms from - Fabric library, - like SSH settings being loaded from SSH config, /etc/fabric.yml or - environment variables. + Fabric library, like SSH settings being loaded from SSH config, + /etc/fabric.yml or environment variables. See Fabric's documentation for more info on configuration if needed. Args: configfile (str): The path to your configuration file. Optional. init_slurm (bool): Initiate / validate slurm setup. Optional - Might take some time the first time with downloading etc. + Might take some time the first time with downloading, etc. Returns: SlurmClient: A new SlurmClient object. 
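To make the new analytics toggles concrete: the from_config body in the next hunk reads an optional ANALYTICS section from slurm-config.ini, with every option falling back to a default when absent. A sketch of such a section (option names follow the getboolean/get calls shown below; the database URL is a placeholder, and the SQLALCHEMY_URL environment variable takes precedence over it when set):

[ANALYTICS]
# All toggles default to True when the section or option is missing
track_workflows=True
enable_job_accounting=True
enable_job_progress=True
enable_workflow_analytics=True
# Placeholder URL; overridden by the SQLALCHEMY_URL environment variable if that is set
sqlalchemy_url=postgresql+psycopg2://biomero:secret@localhost:5432/biomero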
@@ -735,6 +805,7 @@ def from_config(cls, configfile: str = '', configs.read([cls._DEFAULT_CONFIG_PATH_1, cls._DEFAULT_CONFIG_PATH_2, configfile]) + # Read the required parameters from the configuration file, # fallback to defaults host = configs.get("SSH", "host", fallback=cls._DEFAULT_HOST) @@ -782,7 +853,6 @@ def from_config(cls, configfile: str = '', ) # Parse converters, if available - # Should be key=value where key is a name and value a docker image try: converter_items = configs.items("CONVERTERS") if converter_items: @@ -790,7 +860,22 @@ def from_config(cls, configfile: str = '', else: converter_images = None # Section exists but is empty except configparser.NoSectionError: - converter_images = None # Section does not exist + converter_images = None # Section does not exist + + # Read the analytics section, if available + try: + track_workflows = configs.getboolean('ANALYTICS', 'track_workflows', fallback=True) + enable_job_accounting = configs.getboolean('ANALYTICS', 'enable_job_accounting', fallback=True) + enable_job_progress = configs.getboolean('ANALYTICS', 'enable_job_progress', fallback=True) + enable_workflow_analytics = configs.getboolean('ANALYTICS', 'enable_workflow_analytics', fallback=True) + sqlalchemy_url = configs.get('ANALYTICS', 'sqlalchemy_url', fallback=None) + except configparser.NoSectionError: + # If the ANALYTICS section is missing, fallback to default values + track_workflows = True + enable_job_accounting = True + enable_job_progress = True + enable_workflow_analytics = True + sqlalchemy_url = None # Create the SlurmClient object with the parameters read from # the config file @@ -807,7 +892,13 @@ def from_config(cls, configfile: str = '', slurm_model_jobs_params=slurm_model_jobs_params, slurm_script_path=slurm_script_path, slurm_script_repo=slurm_script_repo, - init_slurm=init_slurm) + init_slurm=init_slurm, + # Pass analytics settings to the constructor + track_workflows=track_workflows, + enable_job_accounting=enable_job_accounting, + enable_job_progress=enable_job_progress, + enable_workflow_analytics=enable_workflow_analytics, + sqlalchemy_url=sqlalchemy_url) def cleanup_tmp_files(self, slurm_job_id: str, @@ -1336,7 +1427,7 @@ def run_workflow(self, res = self.run_commands([sbatch_cmd], sbatch_env) slurm_job_id = self.extract_job_id(res) - if self.track_workflows and task_id: + if task_id: self.workflowTracker.start_task(task_id) self.workflowTracker.add_job_id(task_id, slurm_job_id) self.workflowTracker.add_result(task_id, res) @@ -1430,7 +1521,7 @@ def run_conversion_workflow_job(self, slurm_job_id = self.extract_job_id(res) - if self.track_workflows and task_id: + if task_id: self.workflowTracker.start_task(task_id) self.workflowTracker.add_job_id(task_id, slurm_job_id) self.workflowTracker.add_result(task_id, res) diff --git a/tests/unit/test_slurm_client.py b/tests/unit/test_slurm_client.py index 2529112..b90b738 100644 --- a/tests/unit/test_slurm_client.py +++ b/tests/unit/test_slurm_client.py @@ -1,4 +1,6 @@ -from biomero import SlurmClient +import logging +from uuid import uuid4 +from biomero import SlurmClient, NoOpWorkflowTracker import pytest import mock from mock import patch, MagicMock @@ -969,6 +971,131 @@ def test_cleanup_tmp_files(mock_extract_data_location, mock_run_commands, assert result.ok is True +@pytest.mark.parametrize("track_workflows, enable_job_accounting, enable_job_progress, enable_workflow_analytics, expected_tracker_classes", [ + # Case when everything is enabled + (True, True, True, True, {"workflowTracker": 
"WorkflowTracker", "jobAccounting": "JobAccounting", "jobProgress": "JobProgress", "workflowAnalytics": "WorkflowAnalytics"}), + + # Case when tracking is disabled (NoOp for all) + (False, True, True, True, {"workflowTracker": "NoOpWorkflowTracker", "jobAccounting": "NoOpWorkflowTracker", "jobProgress": "NoOpWorkflowTracker", "workflowAnalytics": "NoOpWorkflowTracker"}), + + # Case when only accounting is disabled + (True, False, True, True, {"workflowTracker": "WorkflowTracker", "jobAccounting": "NoOpWorkflowTracker", "jobProgress": "JobProgress", "workflowAnalytics": "WorkflowAnalytics"}), + + # Case when only progress is disabled + (True, True, False, True, {"workflowTracker": "WorkflowTracker", "jobAccounting": "JobAccounting", "jobProgress": "NoOpWorkflowTracker", "workflowAnalytics": "WorkflowAnalytics"}), + + # Case when only analytics is disabled + (True, True, True, False, {"workflowTracker": "WorkflowTracker", "jobAccounting": "JobAccounting", "jobProgress": "JobProgress", "workflowAnalytics": "NoOpWorkflowTracker"}), + + # Case when all listeners are disabled + (True, False, False, False, {"workflowTracker": "WorkflowTracker", "jobAccounting": "NoOpWorkflowTracker", "jobProgress": "NoOpWorkflowTracker", "workflowAnalytics": "NoOpWorkflowTracker"}) +]) +@patch('biomero.slurm_client.Connection.create_session') +@patch('biomero.slurm_client.Connection.open') +@patch('biomero.slurm_client.Connection.put') +@patch('biomero.slurm_client.Connection.run') +def test_workflow_tracker_and_listeners_no_op( + mock_run, mock_put, mock_open, mock_session, + track_workflows, enable_job_accounting, enable_job_progress, enable_workflow_analytics, + expected_tracker_classes, caplog): + """ + Test that the WorkflowTracker, JobAccounting, JobProgress, and WorkflowAnalytics + are set to NoOpWorkflowTracker when tracking is disabled or listeners are disabled. 
+ """ + # GIVEN + slurm_client = SlurmClient( + host="localhost", + port=8022, + user="slurm", + slurm_data_path="datapath", + slurm_images_path="imagespath", + slurm_script_path="scriptpath", + slurm_converters_path="converterspath", + slurm_script_repo="repo-url", + slurm_model_paths={'wf': 'path'}, + slurm_model_images={'wf': 'image'}, + slurm_model_repos={'wf': 'https://github.com/example/workflow1'}, + track_workflows=track_workflows, + enable_job_accounting=enable_job_accounting, + enable_job_progress=enable_job_progress, + enable_workflow_analytics=enable_workflow_analytics + ) + + # THEN + # Check workflow tracker + assert slurm_client.workflowTracker.__class__.__name__ == expected_tracker_classes['workflowTracker'] + + # Check job accounting listener + assert slurm_client.jobAccounting.__class__.__name__ == expected_tracker_classes['jobAccounting'] + + # Check job progress listener + assert slurm_client.jobProgress.__class__.__name__ == expected_tracker_classes['jobProgress'] + + # Check workflow analytics listener + assert slurm_client.workflowAnalytics.__class__.__name__ == expected_tracker_classes['workflowAnalytics'] + + # WHEN (No-Op calls) + if isinstance(slurm_client.workflowTracker, NoOpWorkflowTracker): + # Call NoOp methods on workflowTracker + slurm_client.workflowTracker.start_workflow("dummy_workflow") + slurm_client.workflowTracker.update_status("dummy_status") + # WHEN + workflow_id = uuid4() + task_name = 'example_task' + task_version = '1.0' + input_data = {'key': 'value'} + kwargs = {'param1': 'value1', 'param2': 'value2'} + with caplog.at_level(logging.DEBUG): + # Call methods that should be no-ops + slurm_client.workflowTracker.add_task_to_workflow( + workflow_id=workflow_id, + task_name=task_name, + task_version=task_version, + input_data=input_data, + kwargs=kwargs + ) + + # THEN + # Assert that appropriate log messages are generated for NoOpWorkflowTracker + for record in caplog.records: + if record.message.startswith("[No-op] Called function: add_task_to_workflow"): + assert f"'workflow_id': {repr(workflow_id)}" in record.message + assert f"'task_name': {repr(task_name)}" in record.message + assert f"'task_version': {repr(task_version)}" in record.message + assert f"'input_data': {repr(input_data)}" in record.message + assert f"'kwargs': {repr(kwargs)}" in record.message + + if isinstance(slurm_client.jobAccounting, NoOpWorkflowTracker): + # Call NoOp methods on jobAccounting + slurm_client.jobAccounting.record_job("dummy_job") + slurm_client.jobAccounting.get_job_status("dummy_job") + + if isinstance(slurm_client.jobProgress, NoOpWorkflowTracker): + # Call NoOp methods on jobProgress + slurm_client.jobProgress.track_progress("dummy_progress") + + if isinstance(slurm_client.workflowAnalytics, NoOpWorkflowTracker): + # Call NoOp methods on workflowAnalytics + slurm_client.workflowAnalytics.generate_report("dummy_report") + + # THEN + with caplog.at_level(logging.DEBUG): + # Assert that appropriate log messages are generated for NoOpWorkflowTracker + for tracker_name, tracker_class in expected_tracker_classes.items(): + tracker = getattr(slurm_client, tracker_name) + if tracker_class == "NoOpWorkflowTracker": + # Call a method to trigger the no-op behavior + tracker.some_method() + # Check the log output + assert any(record.message.startswith("[No-op] Called function: some_method") for record in caplog.records) + + + # THEN (No actions should be performed, and nothing should break) + # We can only verify that no exceptions are raised and no actual work was 
done + # Logs or other side-effects can be checked if logging is added to NoOpWorkflowTracker + assert True # No exception means test passes + + @patch('biomero.slurm_client.Connection.create_session') @patch('biomero.slurm_client.Connection.open') @patch('biomero.slurm_client.Connection.put') @@ -993,7 +1120,18 @@ def test_from_config(mock_ConfigParser, mock_configparser_instance.read.return_value = None mv = "configvalue" mock_configparser_instance.get.return_value = mv - mock_configparser_instance.getboolean.return_value = True + mock_configparser_instance.getboolean.side_effect = lambda section, option, fallback: { + 'track_workflows': True, + 'enable_job_accounting': True, + 'enable_job_progress': True, + 'enable_workflow_analytics': True + }.get(option, fallback) + + # Set up mock for 'sqlalchemy_url' (new addition) + mock_configparser_instance.get.side_effect = lambda section, option, fallback=None: { + ('ANALYTICS', 'sqlalchemy_url'): 'sqlite:///test.db', + }.get((section, option), mv) + model_dict = { "m1": "v1" } @@ -1028,10 +1166,8 @@ def items_side_effect(section): return conv_dict else: return {}.items() - + mock_configparser_instance.items.side_effect = items_side_effect - # mock_configparser_instance.items.return_value = {**model_dict, **repo_dict, - # **job_dict, **jp_dict} # Configure the MagicMock to return the mock_configparser_instance when called mock_ConfigParser.return_value = mock_configparser_instance @@ -1061,10 +1197,16 @@ def items_side_effect(section): slurm_model_jobs_params=jp_dict_out, # expected slurm_model_jobs_params value, slurm_script_path=mv, # expected slurm_script_path value, slurm_script_repo=mv, # expected slurm_script_repo value, - init_slurm=init_slurm + init_slurm=init_slurm, + track_workflows=True, # expected track_workflows value + enable_job_accounting=True, # expected enable_job_accounting value + enable_job_progress=True, # expected enable_job_progress value + enable_workflow_analytics=True, # expected enable_workflow_analytics value + sqlalchemy_url='sqlite:///test.db' # expected sqlalchemy_url value ) - + + def test_parse_docker_image_with_version(slurm_client): version, image_name = slurm_client.parse_docker_image_version("example_image:1.0") assert version == "1.0" From 3c3857d08ca788eb005275603da95edb8598e431 Mon Sep 17 00:00:00 2001 From: Luik Date: Wed, 28 Aug 2024 16:26:30 +0200 Subject: [PATCH 14/24] Add docstring and tests for eventsourcing db --- .github/workflows/python-package.yml | 6 +- biomero/database.py | 85 +- biomero/eventsourcing.py | 260 ++++- biomero/slurm_client.py | 18 +- biomero/views.py | 15 +- pyproject.toml | 1 + tests/unit/test_eventsourcing.py | 1507 ++++++++++++++++++++++++++ tests/unit/test_slurm_client.py | 8 +- 8 files changed, 1876 insertions(+), 24 deletions(-) create mode 100644 tests/unit/test_eventsourcing.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 51ea19a..cdc8ae1 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -36,4 +36,8 @@ jobs: flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | - python -m pytest + python -m pytest --cov=biomero --cov-report=xml + - name: Coveralls GitHub Action + uses: coverallsapp/github-action@v2.3.0 + + diff --git a/biomero/database.py b/biomero/database.py index 91916b0..67bb909 100644 --- a/biomero/database.py +++ b/biomero/database.py @@ -1,4 +1,18 @@ -from eventsourcing.utils import get_topic +# -*- coding: utf-8 -*- +# Copyright 2024 Torec Luik +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from eventsourcing.utils import get_topic, clear_topic_cache import logging from sqlalchemy import create_engine, text, Column, Integer, String, URL, DateTime, Float from sqlalchemy.orm import sessionmaker, declarative_base, scoped_session @@ -14,6 +28,14 @@ class JobView(Base): + """ + SQLAlchemy model for the 'biomero_job_view' table. + + Attributes: + slurm_job_id (Integer): The unique identifier for the Slurm job. + user (Integer): The ID of the user who submitted the job. + group (Integer): The group ID associated with the job. + """ __tablename__ = 'biomero_job_view' slurm_job_id = Column(Integer, primary_key=True) @@ -22,6 +44,14 @@ class JobView(Base): class JobProgressView(Base): + """ + SQLAlchemy model for the 'biomero_job_progress_view' table. + + Attributes: + slurm_job_id (Integer): The unique identifier for the Slurm job. + status (String): The current status of the Slurm job. + progress (String, optional): The progress status of the Slurm job. + """ __tablename__ = 'biomero_job_progress_view' slurm_job_id = Column(Integer, primary_key=True) @@ -30,6 +60,20 @@ class JobProgressView(Base): class TaskExecution(Base): + """ + SQLAlchemy model for the 'biomero_task_execution' table. + + Attributes: + task_id (PGUUID): The unique identifier for the task. + task_name (String): The name of the task. + task_version (String): The version of the task. + user_id (Integer, optional): The ID of the user who initiated the task. + group_id (Integer, optional): The group ID associated with the task. + status (String): The current status of the task. + start_time (DateTime): The time when the task started. + end_time (DateTime, optional): The time when the task ended. + error_type (String, optional): Type of error encountered during execution, if any. + """ __tablename__ = 'biomero_task_execution' task_id = Column(PGUUID(as_uuid=True), primary_key=True) @@ -44,12 +88,33 @@ class TaskExecution(Base): class EngineManager: + """ + Manages the SQLAlchemy engine and session lifecycle. + + Class Attributes: + _engine: The SQLAlchemy engine used to connect to the database. + _scoped_session_topic: The topic of the scoped session. + _session: The scoped session used for database operations. 
+ """ _engine = None _scoped_session_topic = None _session = None @classmethod - def create_scoped_session(cls, sqlalchemy_url=None): + def create_scoped_session(cls, sqlalchemy_url: str = None): + """ + Creates and returns a scoped session for interacting with the database. + + If the engine doesn't already exist, it initializes the SQLAlchemy engine + and sets up the scoped session. + + Args: + sqlalchemy_url (str, optional): The SQLAlchemy database URL. If not provided, + the method will retrieve the value from the 'SQLALCHEMY_URL' environment variable. + + Returns: + str: The topic of the scoped session adapter class. + """ if cls._engine is None: # Note, we only allow sqlalchemy eventsourcing module if not sqlalchemy_url: @@ -75,17 +140,33 @@ def __getattribute__(self, item: str) -> None: @classmethod def get_session(cls): + """ + Retrieves the current scoped session. + + Returns: + Session: The SQLAlchemy session for interacting with the database. + """ return cls._session() @classmethod def commit(cls): + """ + Commits the current transaction in the scoped session. + """ cls._session.commit() @classmethod def close_engine(cls): + """ + Closes the database engine and cleans up the session. + + This method disposes of the SQLAlchemy engine, removes the session, + and resets all associated class attributes to `None`. + """ if cls._engine is not None: cls._session.remove() cls._engine.dispose() cls._engine = None cls._session = None cls._scoped_session_topic = None + clear_topic_cache() diff --git a/biomero/eventsourcing.py b/biomero/eventsourcing.py index 25619dc..99d2865 100644 --- a/biomero/eventsourcing.py +++ b/biomero/eventsourcing.py @@ -28,6 +28,18 @@ class ResultDict(dict): + """ + A dictionary subclass that stores details from a Fabric Result object. + + Args: + result (Result): The Fabric result object. + + Attributes: + command (str): The command that was executed. + env (dict): The environment variables during the command execution. + stdout (str): The standard output from the command. + stderr (str): The standard error from the command. + """ def __init__(self, result: Result): super().__init__() self['command'] = result.command @@ -43,9 +55,28 @@ def __init__(self, result: Result): class WorkflowRun(Aggregate): + """ + Represents a workflow run as an aggregate in the domain model. + + Attributes: + name (str): The name of the workflow. + description (str): A description of the workflow. + user (int): The ID of the user who initiated the workflow. + group (int): The ID of the group associated with the workflow. + tasks (list): A list of task UUIDs associated with the workflow. + """ INITIAL_VERSION = 0 class WorkflowInitiated(Aggregate.Created): + """ + Event triggered when a new workflow is initiated. + + Attributes: + name (str): The name of the workflow. + description (str): A description of the workflow. + user (int): The ID of the user who initiated the workflow. + group (int): The group ID associated with the workflow. + """ name: str description: str user: int @@ -64,6 +95,12 @@ def __init__(self, name: str, logger.debug(f"Initializing WorkflowRun: name={name}, description={description}, user={user}, group={group}") class TaskAdded(Aggregate.Event): + """ + Event triggered when a task is added to the workflow. + + Attributes: + task_id (UUID): The UUID of the task added to the workflow. 
+ """ task_id: UUID @event(TaskAdded) @@ -72,6 +109,9 @@ def add_task(self, task_id: UUID): self.tasks.append(task_id) class WorkflowStarted(Aggregate.Event): + """ + Event triggered when the workflow starts. + """ pass @event(WorkflowStarted) @@ -80,6 +120,9 @@ def start_workflow(self): pass class WorkflowCompleted(Aggregate.Event): + """ + Event triggered when the workflow completes. + """ pass @event(WorkflowCompleted) @@ -88,6 +131,12 @@ def complete_workflow(self): pass class WorkflowFailed(Aggregate.Event): + """ + Event triggered when the workflow fails. + + Attributes: + error_message (str): The error message indicating why the workflow failed. + """ error_message: str @event(WorkflowFailed) @@ -97,9 +146,33 @@ def fail_workflow(self, error_message: str): class Task(Aggregate): + """ + Represents a task in a workflow as an aggregate in the domain model. + + Attributes: + workflow_id (UUID): The UUID of the associated workflow. + task_name (str): The name of the task. + task_version (str): The version of the task. + input_data (dict): Input data for the task. + params (dict): Additional parameters for the task. + job_ids (list): List of job IDs associated with the task. + results (list): List of results from the task execution. + result_message (str): Message related to the result of the task. + status (str): The current status of the task. + """ INITIAL_VERSION = 0 class TaskCreated(Aggregate.Created): + """ + Event triggered when a task is created. + + Attributes: + workflow_id (UUID): The UUID of the associated workflow. + task_name (str): The name of the task. + task_version (str): The version of the task. + input_data (dict): Input data for the task. + params (dict): Additional parameters for the task. + """ workflow_id: UUID task_name: str task_version: str @@ -126,6 +199,12 @@ def __init__(self, logger.debug(f"Initializing Task: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}") class JobIdAdded(Aggregate.Event): + """ + Event triggered when a job ID is added to the task. + + Attributes: + job_id (str): The job ID added to the task. + """ job_id: str @event(JobIdAdded) @@ -134,6 +213,12 @@ def add_job_id(self, job_id): self.job_ids.append(job_id) class StatusUpdated(Aggregate.Event): + """ + Event triggered when the task's status is updated. + + Attributes: + status (str): The updated status of the task. + """ status: str @event(StatusUpdated) @@ -142,6 +227,12 @@ def update_task_status(self, status): self.status = status class ProgressUpdated(Aggregate.Event): + """ + Event triggered when the task's progress is updated. + + Attributes: + progress (str): The updated progress of the task. + """ progress: str @event(ProgressUpdated) @@ -150,6 +241,12 @@ def update_task_progress(self, progress): self.progress = progress class ResultAdded(Aggregate.Event): + """ + Event triggered when a result is added to the task. + + Attributes: + result (ResultDict): The result dictionary added to the task. + """ result: ResultDict def add_result(self, result: Result): @@ -163,6 +260,9 @@ def _add_result(self, result: ResultDict): self.results.append(result) class TaskStarted(Aggregate.Event): + """ + Event triggered when the task starts. + """ pass @event(TaskStarted) @@ -171,6 +271,12 @@ def start_task(self): pass class TaskCompleted(Aggregate.Event): + """ + Event triggered when the task completes. + + Attributes: + result (str): The result message of the task. 
+ """ result: str @event(TaskCompleted) @@ -179,37 +285,83 @@ def complete_task(self, result: str): self.result_message = result class TaskFailed(Aggregate.Event): + """ + Event triggered when the task fails. + + Attributes: + error_message (str): The error message indicating why the task failed. + """ error_message: str @event(TaskFailed) def fail_task(self, error_message: str): logger.debug(f"Failing task: id={self.id}, error_message={error_message}") + self.result_message = error_message pass # -------------------- APPLICATIONS -------------------- # class NoOpWorkflowTracker: + """ + A no-operation workflow tracker that makes all method calls no-op. + + All method calls to this class will return None and log the function name + and its parameters for debugging purposes. + """ def __getattr__(self, name): """ - Override attribute access to make all methods no-op. - Logs a warning with the function name and parameters passed. + Override attribute access to return a no-op function for undefined methods. + + Args: + name (str): The name of the attribute or method being accessed. + + Returns: + function: A no-op function that logs its name and arguments. """ def no_op_function(*args, **kwargs): logger.debug(f"[No-op] Called function: {name} with args: {args}, kwargs: {kwargs}") - # Optionally log the parameters or add warnings instead return None return no_op_function class WorkflowTracker(Application): + """ + Application service class for managing workflow and task lifecycle operations. + + Methods: + initiate_workflow: Creates a new workflow. + add_task_to_workflow: Adds a new task to an existing workflow. + start_workflow: Starts a workflow by its UUID. + complete_workflow: Marks a workflow as completed. + fail_workflow: Marks a workflow as failed with an error message. + start_task: Starts a task by its UUID. + complete_task: Marks a task as completed with a result message. + fail_task: Marks a task as failed with an error message. + add_job_id: Adds a job ID to a task. + add_result: Adds a result to a task. + update_task_status: Updates the status of a task. + update_task_progress: Updates the progress of a task. + """ def initiate_workflow(self, name: str, description: str, user: int, group: int) -> UUID: + """ + Initiates a new workflow. + + Args: + name (str): The name of the workflow. + description (str): A description of the workflow. + user (int): The ID of the user initiating the workflow. + group (int): The group ID associated with the workflow. + + Returns: + UUID: The UUID of the newly initiated workflow. + """ logger.debug(f"[WFT] Initiating workflow: name={name}, description={description}, user={user}, group={group}") workflow = WorkflowRun(name, description, user, group) self.save(workflow) @@ -223,6 +375,19 @@ def add_task_to_workflow(self, input_data: Dict[str, Any], kwargs: Dict[str, Any] ) -> UUID: + """ + Adds a task to the specified workflow. + + Args: + workflow_id (UUID): The UUID of the workflow. + task_name (str): The name of the task. + task_version (str): The version of the task. + input_data (dict): Input data for the task. + kwargs (dict): Additional parameters for the task. + + Returns: + UUID: The UUID of the newly added task. 
+ """ logger.debug(f"[WFT] Adding task to workflow: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}") task = Task(workflow_id, @@ -232,88 +397,155 @@ def add_task_to_workflow(self, kwargs) self.save(task) EngineManager.commit() - workflow = self.repository.get(workflow_id) + workflow: WorkflowRun = self.repository.get(workflow_id) workflow.add_task(task.id) self.save(workflow) EngineManager.commit() return task.id def start_workflow(self, workflow_id: UUID): + """ + Starts the workflow with the given UUID. + + Args: + workflow_id (UUID): The UUID of the workflow to start. + """ logger.debug(f"[WFT] Starting workflow: workflow_id={workflow_id}") - workflow = self.repository.get(workflow_id) + workflow: WorkflowRun = self.repository.get(workflow_id) workflow.start_workflow() self.save(workflow) EngineManager.commit() def complete_workflow(self, workflow_id: UUID): + """ + Marks the workflow with the given UUID as completed. + + Args: + workflow_id (UUID): The UUID of the workflow to complete. + """ logger.debug(f"[WFT] Completing workflow: workflow_id={workflow_id}") - workflow = self.repository.get(workflow_id) + workflow: WorkflowRun = self.repository.get(workflow_id) workflow.complete_workflow() self.save(workflow) EngineManager.commit() def fail_workflow(self, workflow_id: UUID, error_message: str): + """ + Marks the workflow with the given UUID as failed with an error message. + + Args: + workflow_id (UUID): The UUID of the workflow to fail. + error_message (str): The error message describing the failure. + """ logger.debug(f"[WFT] Failing workflow: workflow_id={workflow_id}, error_message={error_message}") - workflow = self.repository.get(workflow_id) + workflow: WorkflowRun = self.repository.get(workflow_id) workflow.fail_workflow(error_message) self.save(workflow) EngineManager.commit() def start_task(self, task_id: UUID): + """ + Starts the task with the given UUID. + + Args: + task_id (UUID): The UUID of the task to start. + """ logger.debug(f"[WFT] Starting task: task_id={task_id}") - task = self.repository.get(task_id) + task: Task = self.repository.get(task_id) task.start_task() self.save(task) EngineManager.commit() def complete_task(self, task_id: UUID, message: str): + """ + Marks the task with the given UUID as completed with a result message. + + Args: + task_id (UUID): The UUID of the task to complete. + message (str): The result message of the task. + """ logger.debug(f"[WFT] Completing task: task_id={task_id}, message={message}") - task = self.repository.get(task_id) + task: Task = self.repository.get(task_id) task.complete_task(message) self.save(task) EngineManager.commit() def fail_task(self, task_id: UUID, error_message: str): + """ + Marks the task with the given UUID as failed with an error message. + + Args: + task_id (UUID): The UUID of the task to fail. + error_message (str): The error message describing the failure. + """ logger.debug(f"[WFT] Failing task: task_id={task_id}, error_message={error_message}") - task = self.repository.get(task_id) + task: Task = self.repository.get(task_id) task.fail_task(error_message) self.save(task) EngineManager.commit() def add_job_id(self, task_id, slurm_job_id): + """ + Adds a Slurm job ID to the task with the given UUID. + + Args: + task_id (UUID): The UUID of the task. + slurm_job_id (str): The Slurm job ID to associate with the task. 
+ """ logger.debug(f"[WFT] Adding job_id to task: task_id={task_id}, slurm_job_id={slurm_job_id}") - task = self.repository.get(task_id) + task: Task = self.repository.get(task_id) task.add_job_id(slurm_job_id) self.save(task) EngineManager.commit() def add_result(self, task_id, result): + """ + Adds a result to the task with the given UUID. + + Args: + task_id (UUID): The UUID of the task. + result (Result): The Fabric result object to add to the task. + """ logger.debug(f"[WFT] Adding result to task: task_id={task_id}, result={result}") - task = self.repository.get(task_id) + task: Task = self.repository.get(task_id) task.add_result(result) self.save(task) EngineManager.commit() def update_task_status(self, task_id, status): + """ + Updates the status of the task with the given UUID. + + Args: + task_id (UUID): The UUID of the task. + status (str): The new status of the task. + """ logger.debug(f"[WFT] Adding status to task: task_id={task_id}, status={status}") - task = self.repository.get(task_id) + task: Task = self.repository.get(task_id) task.update_task_status(status) self.save(task) EngineManager.commit() def update_task_progress(self, task_id, progress): + """ + Updates the progress of the task with the given UUID. + + Args: + task_id (UUID): The UUID of the task. + progress (str): The updated progress of the task. + """ logger.debug(f"[WFT] Adding progress to task: task_id={task_id}, progress={progress}") - task = self.repository.get(task_id) + task: Task = self.repository.get(task_id) task.update_task_progress(progress) self.save(task) EngineManager.commit() diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index a62b5d3..69d5774 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -50,11 +50,13 @@ class SlurmJob: submit_result (Result): The result of submitting the job. ok (bool): Indicates whether the job submission was successful. job_state (str): The current state of the Slurm job. - error_message (str): The error message, if any. + progress (str): The progress of the Slurm job. + error_message (str): The error message, if any, encountered during job submission. + wf_id (UUID): The workflow ID associated with the job. + task_id (UUID): The task ID within the workflow. + slurm_polling_interval (int): The polling interval (in seconds) for checking the job status. - Args: - submit_result (Result): The result of submitting the job. - job_id (int): The Sluslurm_job_id + Example: # Submit some job with the SlurmClient submit_result, job_id, wf_id, task_id = slurmClient.run_workflow( workflow_name, workflow_version, input_data, email, time, wf_id, @@ -91,6 +93,10 @@ def __init__(self, Args: submit_result (Result): The result of submitting the job. job_id (int): The Slurm job ID. + wf_id (UUID): The workflow ID associated with this job. + task_id (UUID): The task ID within the workflow. + slurm_polling_interval (int, optional): The interval in seconds for + polling the job status. Defaults to SLURM_POLLING_INTERVAL. """ self.job_id = job_id self.wf_id = wf_id @@ -218,6 +224,7 @@ class SlurmClient(Connection): containing the Slurm job submission scripts. Optional. Example: + # Create a SlurmClient object as contextmanager with SlurmClient.from_config() as client: @@ -236,9 +243,10 @@ class SlurmClient(Connection): print(result.stdout) Example 2: + # Create a SlurmClient and setup Slurm (download containers etc.) - with SlurmClient.from_config(inislurm_job_id + with SlurmClient.from_config(init_slurm=True) as client: client.run_workflow(...) 
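For orientation alongside the docstring fixes above, a short usage sketch of the tracking attributes this patch series adds to SlurmClient (assumptions: a Slurm host reachable via slurm-config.ini and a valid SQLALCHEMY_URL; attribute and method names follow the constructor, view, and NoOpWorkflowTracker changes shown earlier in these patches):

from biomero import SlurmClient

with SlurmClient.from_config() as client:
    # With track_workflows enabled, the view applications are live;
    # get_jobs() with no arguments returns tracked Slurm job ids grouped per user
    jobs_per_user = client.jobAccounting.get_jobs()
    print(jobs_per_user)

    # With track_workflows=False, workflowTracker and the three listeners are
    # NoOpWorkflowTracker instances: any call logs a "[No-op]" debug message
    # and returns None instead of touching the database.
    wf_id = client.workflowTracker.initiate_workflow(
        "example", "manual run", user=1, group=1)  # None when tracking is off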
diff --git a/biomero/views.py b/biomero/views.py index 5b12e79..f5e0330 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -1,5 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Torec Luik +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os - from eventsourcing.system import ProcessApplication from eventsourcing.dispatch import singledispatchmethod from eventsourcing.utils import get_topic diff --git a/pyproject.toml b/pyproject.toml index 9ab0fdc..f874f3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ find = {} # Scan the project directory with the default parameters [project.optional-dependencies] test = [ "pytest", + "pytest-cov", "mock", "psycopg2-binary" ] diff --git a/tests/unit/test_eventsourcing.py b/tests/unit/test_eventsourcing.py new file mode 100644 index 0000000..f6c9278 --- /dev/null +++ b/tests/unit/test_eventsourcing.py @@ -0,0 +1,1507 @@ +from datetime import datetime, timedelta, timezone +import json +import os +from unittest.mock import Mock, patch +import uuid +import pytest +from biomero.eventsourcing import Task, WorkflowTracker +from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics +from biomero.database import EngineManager, JobProgressView, JobView, TaskExecution +from uuid import UUID +import logging +from eventsourcing.system import System, SingleThreadedRunner + +# Configure logging +logging.basicConfig(level=logging.INFO) + +# Fixture for setting up the environment variables and session + + +@pytest.fixture(autouse=True) +def set_env_vars_and_session(): + # Set the necessary environment variables for testing + os.environ["PERSISTENCE_MODULE"] = "eventsourcing_sqlalchemy" + os.environ["SQLALCHEMY_URL"] = "sqlite:///:memory:" + + # Initialize the scoped session for testing + EngineManager.create_scoped_session() + + # Yield control to the test function + yield + + # Clean up environment variables after the test + del os.environ["PERSISTENCE_MODULE"] + del os.environ["SQLALCHEMY_URL"] + + # Close and clean up the database session + EngineManager.close_engine() + + +@pytest.fixture +def workflow_tracker(): + # Fixture to set up the WorkflowTracker application + return WorkflowTracker() + + +@pytest.fixture +def workflow_tracker_and_job_accounting(): + """Fixture to initialize System and SingleThreadedRunner with WorkflowTracker and JobAccounting.""" + # Create a System instance with the necessary components + system = System(pipes=[[WorkflowTracker, JobAccounting]]) + runner = SingleThreadedRunner(system) + runner.start() + + # Yield the instances of WorkflowTracker and JobAccounting + yield runner.get(WorkflowTracker), runner.get(JobAccounting) + + # Cleanup after tests + runner.stop() + + +@pytest.fixture +def workflow_tracker_and_job_progress(): + """Fixture to initialize System and SingleThreadedRunner with WorkflowTracker and JobProgress.""" + # Create a System instance with the necessary components + system = System(pipes=[[WorkflowTracker, JobProgress]]) + runner = SingleThreadedRunner(system) + runner.start() + 
+ # Yield the instances of WorkflowTracker and JobProgress + yield runner.get(WorkflowTracker), runner.get(JobProgress) + + # Cleanup after tests + runner.stop() + + +@pytest.fixture +def workflow_tracker_and_workflow_analytics(): + """Fixture to initialize WorkflowTracker and WorkflowAnalytics.""" + # Create a System instance with the necessary components + system = System(pipes=[[WorkflowTracker, WorkflowAnalytics]]) + runner = SingleThreadedRunner(system) + runner.start() + + # Yield the instances of WorkflowTracker and WorkflowAnalytics + yield runner.get(WorkflowTracker), runner.get(WorkflowAnalytics) + + # Cleanup after tests + runner.stop() + + +def test_initiate_workflow(workflow_tracker): + # Initiating a workflow + workflow_id = workflow_tracker.initiate_workflow( + name="Test Workflow", + description="This is a test workflow", + user=1, + group=1 + ) + + # Check if the returned ID is a valid UUID + assert isinstance(workflow_id, UUID) + + # Check if the workflow has been added to the repository + workflow = workflow_tracker.repository.get(workflow_id) + assert workflow is not None + assert workflow.name == "Test Workflow" + assert workflow.description == "This is a test workflow" + assert workflow.user == 1 + assert workflow.group == 1 + assert len(workflow.tasks) == 0 # Initially, no tasks should be present + + +def test_add_task_to_workflow(workflow_tracker): + # Initiate a workflow + workflow_id = workflow_tracker.initiate_workflow( + name="Test Workflow with Task", + description="Workflow description", + user=1, + group=1 + ) + + # Add a task to the workflow + task_id = workflow_tracker.add_task_to_workflow( + workflow_id=workflow_id, + task_name="Task 1", + task_version="1.0", + input_data={"input_key": "input_value"}, + kwargs={} + ) + assert isinstance(task_id, UUID) + + # Check the task and workflow are updated + workflow = workflow_tracker.repository.get(workflow_id) + assert len(workflow.tasks) == 1 + assert workflow.tasks[0] == task_id + task = workflow_tracker.repository.get(task_id) + assert task.task_name == "Task 1" + assert task.workflow_id == workflow_id + + +def test_start_workflow(workflow_tracker, caplog): + """ + Test starting a workflow and checking for the WorkflowStarted event. 
+ """ + # Enable capturing of log messages + with caplog.at_level("DEBUG"): + # Initiate a workflow + workflow_id = workflow_tracker.initiate_workflow( + name="Workflow to Start", + description="Description of workflow", + user=1, + group=1 + ) + + # Start the workflow + workflow_tracker.start_workflow(workflow_id) + + # Verify the workflow has emitted the WorkflowStarted event via logs + assert any( + "Starting workflow" in record.message and str( + workflow_id) in record.message + for record in caplog.records + ) + + +def test_complete_workflow(workflow_tracker): + # Initiate, start, and complete a workflow + workflow_id = workflow_tracker.initiate_workflow( + name="Workflow to Complete", + description="Description of workflow", + user=1, + group=1 + ) + workflow_tracker.start_workflow(workflow_id) + workflow_tracker.complete_workflow(workflow_id) + + # Retrieve and print notifications + notifications = workflow_tracker.notification_log.select(start=1, limit=10) + + # Sort notifications by ID to ensure proper ordering + notifications_sorted = sorted(notifications, key=lambda n: n.id) + + # Define the expected order of topics + expected_topics = [ + 'biomero.eventsourcing:WorkflowRun.WorkflowInitiated', + 'biomero.eventsourcing:WorkflowRun.WorkflowStarted', + 'biomero.eventsourcing:WorkflowRun.WorkflowCompleted', + ] + + # Extract topics in the order of sorted notifications + actual_topics = [ + notification.topic for notification in notifications_sorted] + + # Check if the actual topics match the expected topics + assert actual_topics == expected_topics, ( + f"Expected topics order: {expected_topics}, but got: {actual_topics}" + ) + + +def test_fail_workflow(workflow_tracker): + # GIVEN a workflow is initiated + workflow_id = workflow_tracker.initiate_workflow( + name="Workflow to Fail", + description="Description of workflow", + user=1, + group=1 + ) + + # WHEN the workflow is marked as failed + workflow_tracker.fail_workflow(workflow_id, error_message="Some error") + + # THEN verify the workflow has a failure notification + notifications = workflow_tracker.notification_log.select(start=1, limit=10) + notifications_sorted = sorted(notifications, key=lambda n: n.id) + + # Expected topics in the order they should appear + expected_topics = [ + 'biomero.eventsourcing:WorkflowRun.WorkflowInitiated', + 'biomero.eventsourcing:WorkflowRun.WorkflowFailed' + ] + + # Verify the order of notification topics + actual_topics = [n.topic for n in notifications_sorted] + assert actual_topics == expected_topics, ( + f"Expected topics order: {expected_topics}, but got: {actual_topics}" + ) + + # Verify the failure notification contains the correct error message + failure_notification = next( + (n for n in notifications_sorted if n.topic == + 'biomero.eventsourcing:WorkflowRun.WorkflowFailed'), + None + ) + assert failure_notification is not None, "Expected a WorkflowFailed notification" + + # Decode the state to verify the error message + state_str = failure_notification.state.decode('utf-8') + assert '"error_message":"Some error"' in state_str, ( + f"Expected 'error_message: Some error', but got: {state_str}" + ) + + +def test_complete_task(workflow_tracker): + # GIVEN a workflow with a task + workflow_id = workflow_tracker.initiate_workflow( + name="Workflow with Task", + description="Description of workflow", + user=1, + group=1 + ) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id=workflow_id, + task_name="Task 1", + task_version="1.0", + input_data={"input_key": "input_value"}, + 
kwargs={} + ) + + # WHEN the task is started and then completed + workflow_tracker.start_task(task_id) + workflow_tracker.complete_task( + task_id, message="Task completed successfully") + + # THEN verify the task completion notification + notifications = workflow_tracker.notification_log.select(start=1, limit=10) + notifications_sorted = sorted(notifications, key=lambda n: n.id) + + # Expected topics in the order they should appear + expected_topics = [ + 'biomero.eventsourcing:WorkflowRun.WorkflowInitiated', + 'biomero.eventsourcing:Task.TaskCreated', + 'biomero.eventsourcing:WorkflowRun.TaskAdded', + 'biomero.eventsourcing:Task.TaskStarted', + 'biomero.eventsourcing:Task.TaskCompleted' + ] + + # Verify the order of notification topics + actual_topics = [n.topic for n in notifications_sorted] + assert actual_topics == expected_topics, ( + f"Expected topics order: {expected_topics}, but got: {actual_topics}" + ) + + # THEN Verify the task is marked as completed + task = workflow_tracker.repository.get(task_id) + assert task.result_message == "Task completed successfully" + + +def test_fail_task(workflow_tracker): + # GIVEN a workflow with a task + workflow_id = workflow_tracker.initiate_workflow( + name="Workflow with Task", + description="Description of workflow", + user=1, + group=1 + ) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id=workflow_id, + task_name="Task 1", + task_version="1.0", + input_data={"input_key": "input_value"}, + kwargs={} + ) + + # WHEN the task is failed + workflow_tracker.fail_task( + task_id, error_message="Task failed due to an error") + + # THEN verify the task failure notification + notifications = workflow_tracker.notification_log.select(start=1, limit=10) + notifications_sorted = sorted(notifications, key=lambda n: n.id) + + # Expected topics in the order they should appear + expected_topics = [ + 'biomero.eventsourcing:WorkflowRun.WorkflowInitiated', + 'biomero.eventsourcing:Task.TaskCreated', + 'biomero.eventsourcing:WorkflowRun.TaskAdded', + 'biomero.eventsourcing:Task.TaskFailed' + ] + + # Verify the order of notification topics + actual_topics = [n.topic for n in notifications_sorted] + assert actual_topics == expected_topics, ( + f"Expected topics order: {expected_topics}, but got: {actual_topics}" + ) + + # THEN Verify the task is marked as failed + task: Task = workflow_tracker.repository.get(task_id) + assert task.result_message == "Task failed due to an error" + + +def test_add_job_id(workflow_tracker): + # GIVEN a workflow with a task + workflow_id = workflow_tracker.initiate_workflow( + name="Workflow with Task", + description="Description of workflow", + user=1, + group=1 + ) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id=workflow_id, + task_name="Task 1", + task_version="1.0", + input_data={"input_key": "input_value"}, + kwargs={} + ) + + # WHEN a job ID is added to the task + job_id = "12345" + workflow_tracker.add_job_id(task_id, slurm_job_id=job_id) + + # THEN verify the job ID addition notification + notifications = workflow_tracker.notification_log.select(start=1, limit=10) + notifications_sorted = sorted(notifications, key=lambda n: n.id) + + # Expected topics in the order they should appear + expected_topics = [ + 'biomero.eventsourcing:WorkflowRun.WorkflowInitiated', + 'biomero.eventsourcing:Task.TaskCreated', + 'biomero.eventsourcing:WorkflowRun.TaskAdded', + 'biomero.eventsourcing:Task.JobIdAdded' + ] + + # Verify the order of notification topics + actual_topics = [n.topic for n in 
notifications_sorted] + assert actual_topics == expected_topics, ( + f"Expected topics order: {expected_topics}, but got: {actual_topics}" + ) + + # THEN Verify the job ID is added correctly + task = workflow_tracker.repository.get(task_id) + assert task.job_ids == [job_id] + + +def test_add_result(workflow_tracker): + # GIVEN a workflow with a task + workflow_id = workflow_tracker.initiate_workflow( + name="Workflow with Task", + description="Description of workflow", + user=1, + group=1 + ) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id=workflow_id, + task_name="Task 1", + task_version="1.0", + input_data={"input_key": "input_value"}, + kwargs={} + ) + expected_result = { + "command": "lss -lah", + "env": "MY_PASS=SECRET", + "stdout": "\n", + "stderr": "oops did you mean ls?" + } + + # Mock the Result object from the fabric library + with patch('fabric.Result') as MockResult: + # Create a mock instance with specific return values + mock_result = MockResult.return_value + mock_result.command = "lss -lah" + mock_result.env = "MY_PASS=SECRET" + mock_result.stdout = "\n" + mock_result.stderr = "oops did you mean ls?" + + # WHEN a result is added to the task + workflow_tracker.add_result(task_id, mock_result) + + # THEN verify the result addition notification + notifications = workflow_tracker.notification_log.select(start=1, limit=10) + notifications_sorted = sorted(notifications, key=lambda n: n.id) + + # Expected topics in the order they should appear + expected_topics = [ + 'biomero.eventsourcing:WorkflowRun.WorkflowInitiated', + 'biomero.eventsourcing:Task.TaskCreated', + 'biomero.eventsourcing:WorkflowRun.TaskAdded', + 'biomero.eventsourcing:Task.ResultAdded' + ] + + # Verify the order of notification topics + actual_topics = [n.topic for n in notifications_sorted] + assert actual_topics == expected_topics, ( + f"Expected topics order: {expected_topics}, but got: {actual_topics}" + ) + + # THEN Verify the result is added correctly + task = workflow_tracker.repository.get(task_id) + assert task.results == [expected_result] + + +def test_update_task_status(workflow_tracker): + # GIVEN a workflow with a task + workflow_id = workflow_tracker.initiate_workflow( + name="Workflow with Task", + description="Description of workflow", + user=1, + group=1 + ) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id=workflow_id, + task_name="Task 1", + task_version="1.0", + input_data={"input_key": "input_value"}, + kwargs={} + ) + + # WHEN the task status is updated + workflow_tracker.update_task_status(task_id, status="In Progress") + + # THEN verify the task status update notification + notifications = workflow_tracker.notification_log.select(start=1, limit=10) + notifications_sorted = sorted(notifications, key=lambda n: n.id) + + # Expected topics in the order they should appear + expected_topics = [ + 'biomero.eventsourcing:WorkflowRun.WorkflowInitiated', + 'biomero.eventsourcing:Task.TaskCreated', + 'biomero.eventsourcing:WorkflowRun.TaskAdded', + 'biomero.eventsourcing:Task.StatusUpdated' + ] + + # Verify the order of notification topics + actual_topics = [n.topic for n in notifications_sorted] + assert actual_topics == expected_topics, ( + f"Expected topics order: {expected_topics}, but got: {actual_topics}" + ) + + # THEN Verify the status is updated correctly + task = workflow_tracker.repository.get(task_id) + assert task.status == "In Progress" + + +def test_update_task_progress(workflow_tracker): + # GIVEN a workflow with a task + workflow_id = 
workflow_tracker.initiate_workflow( + name="Workflow with Task", + description="Description of workflow", + user=1, + group=1 + ) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id=workflow_id, + task_name="Task 1", + task_version="1.0", + input_data={"input_key": "input_value"}, + kwargs={} + ) + + # WHEN the task progress is updated + workflow_tracker.update_task_progress(task_id, progress="50%") + + # THEN verify the task progress update notification + notifications = workflow_tracker.notification_log.select(start=1, limit=10) + notifications_sorted = sorted(notifications, key=lambda n: n.id) + + # Expected topics in the order they should appear + expected_topics = [ + 'biomero.eventsourcing:WorkflowRun.WorkflowInitiated', + 'biomero.eventsourcing:Task.TaskCreated', + 'biomero.eventsourcing:WorkflowRun.TaskAdded', + 'biomero.eventsourcing:Task.ProgressUpdated' + ] + + # Verify the order of notification topics + actual_topics = [n.topic for n in notifications_sorted] + assert actual_topics == expected_topics, ( + f"Expected topics order: {expected_topics}, but got: {actual_topics}" + ) + + # THEN Verify the progress is updated correctly + task = workflow_tracker.repository.get(task_id) + assert task.progress == "50%" + + +def test_job_acc_workflow_initiated(workflow_tracker_and_job_accounting): + # GIVEN a WorkflowTracker event system and job accounting listener + workflow_tracker: WorkflowTracker + job_accounting: JobAccounting + workflow_tracker, job_accounting = workflow_tracker_and_job_accounting + + # WHEN a new workflow is initiated + workflow_id = workflow_tracker.initiate_workflow("Test Workflow", + "Test Description", + user=1, + group=2) + + # THEN verify internal state in JobAccounting + assert workflow_id in job_accounting.workflows + assert job_accounting.workflows[workflow_id] == {"user": 1, "group": 2} + + +def test_job_acc_task_added(workflow_tracker_and_job_accounting): + # GIVEN a WorkflowTracker event system and job accounting listener + workflow_tracker: WorkflowTracker + job_accounting: JobAccounting + workflow_tracker, job_accounting = workflow_tracker_and_job_accounting + + # WHEN a new workflow is initiated + workflow_id = workflow_tracker.initiate_workflow("Test Workflow", + "Test Description", + user=1, + group=2) + # And a task is added + task_id = workflow_tracker.add_task_to_workflow( + workflow_id=workflow_id, + task_name="task", + task_version="v1", + input_data={"foo": "bar"}, + kwargs={"bar": "baz"} + ) + + # THEN verify internal state in JobAccounting + assert workflow_id in job_accounting.workflows + assert job_accounting.workflows[workflow_id] == {"user": 1, "group": 2} + assert job_accounting.tasks[task_id] == workflow_id + + +def test_job_acc_job_id_added(workflow_tracker_and_job_accounting): + # GIVEN a WorkflowTracker event system and job accounting listener + workflow_tracker: WorkflowTracker + job_accounting: JobAccounting + workflow_tracker, job_accounting = workflow_tracker_and_job_accounting + + # WHEN a workflow is initiated and a task is added + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task", "v1", {"foo": "bar"}, {"bar": "baz"}) + + # Add a job ID to the task + job_id = "12345" + workflow_tracker.add_job_id(task_id, job_id) + + # THEN verify internal state in JobAccounting + assert job_id in job_accounting.jobs + assert job_accounting.jobs[job_id] == (task_id, 1, 2) + + # Verify the JobView entry 
using SQLAlchemy + with EngineManager.get_session() as session: + job_view_entry = session.query(JobView).filter_by( + slurm_job_id=job_id).first() + assert job_view_entry is not None + # Assuming job_id is stored as an integer + assert job_view_entry.slurm_job_id == int(job_id) + assert job_view_entry.user == 1 + assert job_view_entry.group == 2 + + +def test_job_acc_update_view_table(workflow_tracker_and_job_accounting): + # GIVEN a WorkflowTracker event system and job accounting listener + workflow_tracker: WorkflowTracker + job_accounting: JobAccounting + workflow_tracker, job_accounting = workflow_tracker_and_job_accounting + + # WHEN updating the view table with job information + job_id = "67890" + user = 1 + group = 2 + job_accounting.update_view_table(job_id=job_id, user=user, group=group) + + # THEN verify the JobView entry using SQLAlchemy + with EngineManager.get_session() as session: + job_view_entry = session.query(JobView).filter_by( + slurm_job_id=int(job_id)).first() + assert job_view_entry is not None + assert job_view_entry.slurm_job_id == int(job_id) + assert job_view_entry.user == user + assert job_view_entry.group == group + + +def test_job_acc_get_jobs_for_user(workflow_tracker_and_job_accounting): + # GIVEN a WorkflowTracker event system and job accounting listener + workflow_tracker: WorkflowTracker + job_accounting: JobAccounting + workflow_tracker, job_accounting = workflow_tracker_and_job_accounting + + # Simulate adding jobs + job_accounting.update_view_table(job_id=100, user=1, group=2) + job_accounting.update_view_table(job_id=200, user=1, group=2) + job_accounting.update_view_table(job_id=300, user=2, group=3) + + # WHEN retrieving jobs for a specific user + jobs = job_accounting.get_jobs(user=1) + + # THEN verify the jobs are correctly retrieved + assert jobs == {1: [100, 200]} + + # Verify JobView entries using SQLAlchemy + with EngineManager.get_session() as session: + job_views = session.query(JobView).filter(JobView.user == 1).all() + job_ids = [job_view.slurm_job_id for job_view in job_views] + assert set(job_ids) == {100, 200} + + +def test_job_acc_get_jobs_for_group(workflow_tracker_and_job_accounting): + # GIVEN a WorkflowTracker event system and job accounting listener + workflow_tracker: WorkflowTracker + job_accounting: JobAccounting + workflow_tracker, job_accounting = workflow_tracker_and_job_accounting + + # Simulate adding jobs + job_accounting.update_view_table(job_id=400, user=1, group=2) + job_accounting.update_view_table(job_id=500, user=2, group=2) + job_accounting.update_view_table(job_id=600, user=2, group=3) + + # WHEN retrieving jobs for a specific group + jobs = job_accounting.get_jobs(group=2) + + # THEN verify the jobs are correctly retrieved + assert jobs == {None: [400, 500]} + + # Verify JobView entries using SQLAlchemy + with EngineManager.get_session() as session: + job_views = session.query(JobView).filter(JobView.group == 2).all() + job_ids = [job_view.slurm_job_id for job_view in job_views] + assert set(job_ids) == {400, 500} + + +def test_job_acc_get_jobs_all(workflow_tracker_and_job_accounting): + # GIVEN a WorkflowTracker event system and job accounting listener + workflow_tracker: WorkflowTracker + job_accounting: JobAccounting + workflow_tracker, job_accounting = workflow_tracker_and_job_accounting + + # Simulate adding jobs + job_accounting.update_view_table(job_id=700, user=1, group=2) + job_accounting.update_view_table(job_id=800, user=1, group=2) + job_accounting.update_view_table(job_id=900, user=2, group=3) + + # 
WHEN retrieving all jobs + jobs = job_accounting.get_jobs() + + # THEN verify the jobs are correctly grouped by user + assert jobs == {1: [700, 800], 2: [900]} + + # Verify JobView entries using SQLAlchemy + with EngineManager.get_session() as session: + job_views = session.query(JobView).all() + user_jobs = {} + for job_view in job_views: + if job_view.user not in user_jobs: + user_jobs[job_view.user] = [] + user_jobs[job_view.user].append(job_view.slurm_job_id) + assert user_jobs == {1: [700, 800], 2: [900]} + + +def test_job_progress_job_id_added(workflow_tracker_and_job_progress): + # GIVEN a WorkflowTracker event system and job progress listener + workflow_tracker: WorkflowTracker + job_progress: JobProgress + workflow_tracker, job_progress = workflow_tracker_and_job_progress + + # WHEN a workflow is initiated and a task is added + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task", "v1", {"foo": "bar"}, {"bar": "baz"}) + + # Add a job ID to the task + job_id = "12345" + workflow_tracker.add_job_id(task_id, job_id) + + # THEN verify internal state in JobProgress + assert task_id in job_progress.task_to_job + assert job_progress.task_to_job[task_id] == job_id + + # Verify JobProgressView entries using SQLAlchemy + with EngineManager.get_session() as session: + job_progress_views = session.query(JobProgressView).filter( + JobProgressView.slurm_job_id == job_id).all() + assert len(job_progress_views) == 0 + + +def test_job_progress_status_updated(workflow_tracker_and_job_progress): + # GIVEN a WorkflowTracker event system and job progress listener + workflow_tracker: WorkflowTracker + job_progress: JobProgress + workflow_tracker, job_progress = workflow_tracker_and_job_progress + + # WHEN a workflow is initiated and a task is added + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task", "v1", {"foo": "bar"}, {"bar": "baz"}) + + # Add a job ID to the task + job_id = 200 + workflow_tracker.add_job_id(task_id, job_id) + + # Update the task status + status = "InProgress" + workflow_tracker.update_task_status(task_id, status) + + # THEN verify internal state in JobProgress + assert job_id in job_progress.job_status + assert job_progress.job_status[job_id]["status"] == status + + # Verify JobProgressView entries using SQLAlchemy + with EngineManager.get_session() as session: + job_progress_views = session.query(JobProgressView).filter( + JobProgressView.slurm_job_id == job_id).all() + assert len(job_progress_views) == 1 # Expecting exactly one entry + job_progress_view = job_progress_views[0] + assert job_progress_view.slurm_job_id == job_id + assert job_progress_view.status == status + assert job_progress_view.progress is None # No progress set in this test + + # Update the task status + status = "COMPLETED" + workflow_tracker.update_task_status(task_id, status) + + # THEN verify internal state in JobProgress + assert job_id in job_progress.job_status + assert job_progress.job_status[job_id]["status"] == status + + # Verify JobProgressView entries using SQLAlchemy + with EngineManager.get_session() as session: + job_progress_views = session.query(JobProgressView).filter( + JobProgressView.slurm_job_id == job_id).all() + assert len(job_progress_views) == 1 # Expecting exactly one entry + job_progress_view = job_progress_views[0] + assert 
job_progress_view.slurm_job_id == job_id + assert job_progress_view.status == status + assert job_progress_view.progress is None # No progress set in this test + + +def test_job_progress_progress_updated(workflow_tracker_and_job_progress): + # GIVEN a WorkflowTracker event system and job progress listener + workflow_tracker: WorkflowTracker + job_progress: JobProgress + workflow_tracker, job_progress = workflow_tracker_and_job_progress + + # WHEN a workflow is initiated and a task is added + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task", "v1", {"foo": "bar"}, {"bar": "baz"}) + + # Add a job ID to the task + job_id = 12345 # Use a simple integer for the job ID + workflow_tracker.add_job_id(task_id, job_id) + + # Update the task progress + progress = "50%" + workflow_tracker.update_task_progress(task_id, progress) + + # THEN verify internal state in JobProgress + assert job_id in job_progress.job_status + assert job_progress.job_status[job_id]["progress"] == progress + + # Verify JobProgressView entries using SQLAlchemy + with EngineManager.get_session() as session: + job_progress_views = session.query(JobProgressView).filter( + JobProgressView.slurm_job_id == job_id).all() + assert len(job_progress_views) == 1 + job_progress_view = job_progress_views[0] + assert job_progress_view.slurm_job_id == job_id + assert job_progress_view.status == "UNKNOWN" + assert job_progress_view.progress == progress + + # Update the task progress + progress = "100%" + workflow_tracker.update_task_progress(task_id, progress) + + # THEN verify internal state in JobProgress + assert job_id in job_progress.job_status + assert job_progress.job_status[job_id]["progress"] == progress + + # Verify JobProgressView entries using SQLAlchemy + with EngineManager.get_session() as session: + job_progress_views = session.query(JobProgressView).filter( + JobProgressView.slurm_job_id == job_id).all() + assert len(job_progress_views) == 1 + job_progress_view = job_progress_views[0] + assert job_progress_view.slurm_job_id == job_id + assert job_progress_view.status == "UNKNOWN" + assert job_progress_view.progress == progress + + +def test_job_progress_update_view_table(workflow_tracker_and_job_progress): + # GIVEN a WorkflowTracker event system and job progress listener + workflow_tracker: WorkflowTracker + job_progress: JobProgress + workflow_tracker, job_progress = workflow_tracker_and_job_progress + + # WHEN a job status is set + job_id = 12345 + job_progress.job_status[job_id] = {"status": "RUNNING", "progress": "50%"} + + # Force update to the view table + job_progress.update_view_table(job_id) + + # THEN verify JobProgressView entries using SQLAlchemy + with EngineManager.get_session() as session: + job_progress_view = session.query(JobProgressView).filter( + JobProgressView.slurm_job_id == job_id).one_or_none() + assert job_progress_view is not None + assert job_progress_view.slurm_job_id == job_id + assert job_progress_view.status == "RUNNING" + assert job_progress_view.progress == "50%" + + +def test_job_progress_update_view_table_failed(workflow_tracker_and_job_progress, caplog): + # GIVEN a WorkflowTracker event system and job progress listener + workflow_tracker: WorkflowTracker + job_progress: JobProgress + workflow_tracker, job_progress = workflow_tracker_and_job_progress + + # WHEN a job status is not set + job_id = 12345 + job_progress.job_status[job_id] = {"status": None, 
"progress": "50%"} + + # Force update to the view table + with caplog.at_level("ERROR"): + job_progress.update_view_table(job_id) + + # THEN + assert f"Failed to insert/update job progress in view table: job_id={job_id}" in caplog.text + + +def test_wfanalytics_workflow_initiated(workflow_tracker_and_workflow_analytics): + # GIVEN a WorkflowTracker event system and workflow analytics listener + workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # WHEN a workflow is initiated + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + + # THEN verify internal state in WorkflowAnalytics + assert workflow_id in workflow_analytics.workflows + assert workflow_analytics.workflows[workflow_id] == {"user": 1, "group": 2} + + +def test_wfanalytics_task_added(workflow_tracker_and_workflow_analytics, caplog): + # GIVEN a WorkflowTracker event system and workflow analytics listener + workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # WHEN a workflow is initiated and a task is added + before = datetime.now(timezone.utc) + with caplog.at_level("DEBUG"): + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task", "v1", {"foo": "bar"}, {"bar": "baz"}) # both + + # THEN verify internal state in WorkflowAnalytics + assert task_id in workflow_analytics.tasks + task_info = workflow_analytics.tasks[task_id] + assert task_info["wf_id"] == workflow_id + assert task_info["task_name"] == "task" + assert task_info["task_version"] == "v1" + assert task_info["status"] == "CREATED" + now = datetime.now(timezone.utc) + assert task_info["start_time"] >= before + assert task_info["start_time"] <= now + wf_info = workflow_analytics.workflows[workflow_id] + assert wf_info["user"] == 1 + assert wf_info["group"] == 2 + wf_id = workflow_analytics.tasks.get(task_id).get("wf_id") + assert wf_id == workflow_id + assert wf_id in workflow_analytics.workflows + assert workflow_analytics.workflows[wf_id]["user"] == 1 + + # THEN check logs for WorkflowInitiated event + assert f"Workflow initiated: wf_id={workflow_id}, user=1, group=2" in caplog.text + + # THEN check logs for TaskAdded event + assert f"Task added: task_id={task_id}, wf_id={workflow_id}" in caplog.text + + # THEN check logs for TaskCreated event + assert f"Task created: task_id={task_id}, task_name=task, timestamp=" in caplog.text + + +def test_wfanalytics_task_created(workflow_tracker_and_workflow_analytics): + # GIVEN a WorkflowTracker event system and workflow analytics listener + workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # WHEN a workflow is initiated and a task is created + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_tracker.update_task_progress(task_id, "CREATED") + + # THEN verify internal state in WorkflowAnalytics + assert task_id in workflow_analytics.tasks + task_info = workflow_analytics.tasks[task_id] + assert task_info["wf_id"] == workflow_id + assert task_info["task_name"] == 
"task" + assert task_info["task_version"] == "v1" + assert task_info["status"] == "CREATED" + assert task_info["start_time"] is not None + wf_info = workflow_analytics.workflows[workflow_id] + assert wf_info["user"] == 1 + assert wf_info["group"] == 2 + wf_id = workflow_analytics.tasks.get(task_id).get("wf_id") + assert wf_id == workflow_id + assert wf_id in workflow_analytics.workflows + assert workflow_analytics.workflows[wf_id]["user"] == 1 + + +def test_wfanalytics_task_completed(workflow_tracker_and_workflow_analytics, caplog): + # GIVEN a WorkflowTracker event system and workflow analytics listener + workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # WHEN a workflow is initiated and a task is completed + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task", "v1", {"foo": "bar"}, {"bar": "baz"}) + with caplog.at_level("DEBUG"): + workflow_tracker.complete_task(task_id, "done") + + # THEN verify internal state in WorkflowAnalytics + assert task_id in workflow_analytics.tasks + task_info = workflow_analytics.tasks[task_id] + assert task_info["wf_id"] == workflow_id + assert task_info["task_name"] == "task" + assert task_info["task_version"] == "v1" + assert task_info["status"] == "CREATED" + assert task_info["start_time"] is not None + + # Verify TaskExecution entries using SQLAlchemy + with EngineManager.get_session() as session: + task_execution: TaskExecution = session.query(TaskExecution).filter_by( + task_id=task_id).first() + assert task_execution is not None + assert task_execution.task_id == task_id + assert task_execution.task_name == "task" + assert task_execution.task_version == "v1" + assert task_execution.status == "CREATED" + assert task_execution.start_time is not None + assert task_execution.end_time is not None + assert task_execution.user_id == 1 + assert task_execution.group_id == 2 + + # THEN + completed = task_info["end_time"] + assert f"Task completed: task_id={task_id}, end_time={completed}" in caplog.text + +def test_wfanalytics_task_failed(workflow_tracker_and_workflow_analytics, caplog): + # GIVEN a WorkflowTracker event system and workflow analytics listener + workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # WHEN a workflow is initiated and a task is completed + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task", "v1", {"foo": "bar"}, {"bar": "baz"}) + with caplog.at_level("DEBUG"): + error_message = "failed" + workflow_tracker.fail_task(task_id, error_message) + + # THEN verify internal state in WorkflowAnalytics + assert task_id in workflow_analytics.tasks + task_info = workflow_analytics.tasks[task_id] + assert task_info["wf_id"] == workflow_id + assert task_info["task_name"] == "task" + assert task_info["task_version"] == "v1" + assert task_info["status"] == "CREATED" + assert task_info["start_time"] is not None + assert task_info["end_time"] is not None + assert task_info["error_type"] == error_message + + # Verify TaskExecution entries using SQLAlchemy + with EngineManager.get_session() as session: + task_execution: TaskExecution = session.query(TaskExecution).filter_by( + 
task_id=task_id).first() + assert task_execution is not None + assert task_execution.task_id == task_id + assert task_execution.task_name == "task" + assert task_execution.task_version == "v1" + assert task_execution.status == "CREATED" + assert task_execution.start_time is not None + assert task_execution.end_time is not None + assert task_execution.error_type == error_message + assert task_execution.user_id == 1 + assert task_execution.group_id == 2 + + # THEN + t_failed = task_info["end_time"] + assert f"Task failed: task_id={task_id}, end_time={t_failed}, error={error_message}" in caplog.text + + +def test_wfanalytics_update_view_table(workflow_tracker_and_workflow_analytics, caplog): + # GIVEN a WorkflowTracker event system and workflow analytics listener + workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # WHEN a workflow is initiated and a task is created + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_tracker.update_task_progress(task_id, "CREATED") + + # Verify TaskExecution entries using SQLAlchemy + with EngineManager.get_session() as session: + task_execution: TaskExecution = session.query(TaskExecution).filter_by( + task_id=task_id).first() + assert task_execution is not None + assert task_execution.task_id == task_id + assert task_execution.task_name == "task" + assert task_execution.task_version == "v1" + assert task_execution.status == "CREATED" + assert task_execution.start_time is not None + assert task_execution.user_id is None # No update sent to DB yet + assert task_execution.group_id is None # No update sent to DB yet + + # Verify that the entry was added to the SQLAlchemy table + workflow_analytics.update_view_table(task_id) + + with EngineManager.get_session() as session: + # Ensure the new task is inserted + task_execution: TaskExecution = session.query( + TaskExecution).filter_by(task_id=task_id).first() + assert task_execution is not None + assert task_execution.task_id == task_id + assert task_execution.task_name == "task" + assert task_execution.task_version == "v1" + assert task_execution.status == "CREATED" + assert task_execution.start_time is not None + assert task_execution.user_id == 1 + assert task_execution.group_id == 2 + + # Update the task status + workflow_tracker.update_task_status(task_id, "RUNNING") + workflow_analytics.tasks[task_id]["status"] = "RUNNING" + workflow_analytics.update_view_table(task_id) + + with EngineManager.get_session() as session: + # Ensure the status was updated + task_execution: TaskExecution = session.query( + TaskExecution).filter_by(task_id=task_id).first() + assert task_execution is not None + assert task_execution.status == "RUNNING" + + # Update the task's name and version + workflow_analytics.tasks[task_id]["task_name"] = "updated_task" + workflow_analytics.tasks[task_id]["task_version"] = "v2" + workflow_analytics.update_view_table(task_id) + + with EngineManager.get_session() as session: + # Ensure the task name and version were updated + task_execution: TaskExecution = session.query( + TaskExecution).filter_by(task_id=task_id).first() + assert task_execution is not None + assert task_execution.task_name == "updated_task" + assert task_execution.task_version == "v2" + + # Simulate the case where end_time is updated + end_time 
= datetime.now() + workflow_analytics.tasks[task_id]["end_time"] = end_time + workflow_analytics.update_view_table(task_id) + + with EngineManager.get_session() as session: + # Ensure the end_time was updated + task_execution: TaskExecution = session.query( + TaskExecution).filter_by(task_id=task_id).first() + assert task_execution is not None + assert task_execution.end_time == end_time + + # Simulate error_type being added + workflow_analytics.tasks[task_id]["error_type"] = "TaskError" + workflow_analytics.update_view_table(task_id) + + with EngineManager.get_session() as session: + # Ensure the error_type was updated + task_execution: TaskExecution = session.query( + TaskExecution).filter_by(task_id=task_id).first() + assert task_execution is not None + assert task_execution.error_type == "TaskError" + + # no-op + task_id2 = uuid.uuid4() + workflow_analytics.update_view_table(task_id2) + with EngineManager.get_session() as session: + # Ensure the end_time was updated + task_execution: TaskExecution = session.query( + TaskExecution).filter_by(task_id=task_id2).first() + assert task_execution is None + + + # rollback/integrityerror + with caplog.at_level("ERROR"): + workflow_analytics.tasks[task_id]["status"] = None + workflow_analytics.update_view_table(task_id) + + assert f"Failed to insert/update task execution into view table: task_id={task_id}, error=" in caplog.text + + +def test_wfanalytics_get_task_counts_with_filters(workflow_tracker_and_workflow_analytics): + # GIVEN a WorkflowTracker event system and workflow analytics listener + workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # GIVEN task data added to workflow_analytics.tasks for different users and groups + workflow_id_1 = workflow_tracker.initiate_workflow( + "Test Workflow 1", "Test Description 1", user=1, group=2) + task_id1 = workflow_tracker.add_task_to_workflow( + workflow_id_1, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + task_id2 = workflow_tracker.add_task_to_workflow( + workflow_id_1, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + + workflow_id_2 = workflow_tracker.initiate_workflow( + "Test Workflow 2", "Test Description 2", user=2, group=2) + task_id3 = workflow_tracker.add_task_to_workflow( + workflow_id_2, "task2", "v1", {"foo": "bar"}, {"bar": "baz"}) + + workflow_id_3 = workflow_tracker.initiate_workflow( + "Test Workflow 3", "Test Description 3", user=1, group=3) + task_id4 = workflow_tracker.add_task_to_workflow( + workflow_id_3, "task3", "v1", {"foo": "bar"}, {"bar": "baz"}) + + workflow_analytics.update_view_table(task_id1) + workflow_analytics.update_view_table(task_id2) + workflow_analytics.update_view_table(task_id3) + workflow_analytics.update_view_table(task_id4) + + # WHEN calling get_task_counts without filters (should return counts for all tasks) + result = workflow_analytics.get_task_counts() + + # THEN we should get the correct counts + expected_result = { + ("task1", "v1"): 2, # Two tasks for workflow 1 + ("task2", "v1"): 1, # One task for workflow 2 + ("task3", "v1"): 1 # One task for workflow 3 + } + assert result == expected_result + + # WHEN calling get_task_counts filtered by user=1 + result_user_1 = workflow_analytics.get_task_counts(user=1) + + # THEN we should get the counts for tasks belonging to user 1 only + expected_result_user_1 = { + ("task1", "v1"): 2, # Two tasks for user 1 (workflow 1) + ("task3", "v1"): 1 # One task for user 1 (workflow 3) + } + assert 
result_user_1 == expected_result_user_1 + + # WHEN calling get_task_counts filtered by group=2 + result_group_2 = workflow_analytics.get_task_counts(group=2) + + # THEN we should get the counts for tasks belonging to group 2 only + expected_result_group_2 = { + ("task1", "v1"): 2, # Two tasks for group 2 (workflow 1) + ("task2", "v1"): 1 # One task for group 2 (workflow 2) + } + assert result_group_2 == expected_result_group_2 + + # WHEN calling get_task_counts filtered by both user=1 and group=2 + result_user_1_group_2 = workflow_analytics.get_task_counts(user=1, group=2) + + # THEN we should get the counts for tasks belonging to both user 1 and group 2 (workflow 1) + expected_result_user_1_group_2 = { + ("task1", "v1"): 2 # Two tasks for user 1 and group 2 (workflow 1) + } + assert result_user_1_group_2 == expected_result_user_1_group_2 + + + +def test_wfanalytics_get_average_task_duration(workflow_tracker_and_workflow_analytics): + # GIVEN a WorkflowTracker event system and workflow analytics listener + workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # GIVEN task data added to workflow_analytics.tasks with start and end times + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + + # Simulating task creation with durations + task_id_1 = workflow_tracker.add_task_to_workflow( + workflow_id, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_1]["start_time"] = datetime( + 2023, 8, 1, 12, 0, 0) + workflow_analytics.tasks[task_id_1]["end_time"] = datetime( + 2023, 8, 1, 12, 30, 0) # 30 mins duration + + task_id_2 = workflow_tracker.add_task_to_workflow( + workflow_id, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_2]["start_time"] = datetime( + 2023, 8, 1, 13, 0, 0) + workflow_analytics.tasks[task_id_2]["end_time"] = datetime( + 2023, 8, 1, 13, 45, 0) # 45 mins duration + + task_id_3 = workflow_tracker.add_task_to_workflow( + workflow_id, "task2", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_3]["start_time"] = datetime( + 2023, 8, 1, 14, 0, 0) + workflow_analytics.tasks[task_id_3]["end_time"] = datetime( + 2023, 8, 1, 14, 20, 0) # 20 mins duration + + # Manually update the view table with the tasks' details + workflow_analytics.update_view_table(task_id_1) + workflow_analytics.update_view_table(task_id_2) + workflow_analytics.update_view_table(task_id_3) + + # WHEN calling get_average_task_duration + result = workflow_analytics.get_average_task_duration() + + # THEN we should get the correct average durations + expected_result = { + # Average of 30 and 45 minutes, converted to seconds + ("task1", "v1"): (30 * 60 + 45 * 60) / 2, + ("task2", "v1"): 20 * 60 # 20 minutes in seconds + } + assert result == expected_result + + +def test_wfanalytics_get_average_task_duration_with_filters(workflow_tracker_and_workflow_analytics): + # GIVEN a WorkflowTracker event system and workflow analytics listener + workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # GIVEN task data added to workflow_analytics.tasks with start and end times for different users and groups + workflow_id_1 = workflow_tracker.initiate_workflow( + "Test Workflow 1", "Test Description", user=1, group=2) + + workflow_id_2 = workflow_tracker.initiate_workflow( + "Test 
Workflow 2", "Test Description", user=3, group=4) + + # Simulating task creation with durations for user 1, group 2 + task_id_1 = workflow_tracker.add_task_to_workflow( + workflow_id_1, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_1]["start_time"] = datetime( + 2023, 8, 1, 12, 0, 0) + workflow_analytics.tasks[task_id_1]["end_time"] = datetime( + 2023, 8, 1, 12, 30, 0) # 30 mins duration + + task_id_2 = workflow_tracker.add_task_to_workflow( + workflow_id_1, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_2]["start_time"] = datetime( + 2023, 8, 1, 13, 0, 0) + workflow_analytics.tasks[task_id_2]["end_time"] = datetime( + 2023, 8, 1, 13, 45, 0) # 45 mins duration + + # Simulating task creation with durations for user 3, group 4 + task_id_3 = workflow_tracker.add_task_to_workflow( + workflow_id_2, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_3]["start_time"] = datetime( + 2023, 8, 1, 14, 0, 0) + workflow_analytics.tasks[task_id_3]["end_time"] = datetime( + 2023, 8, 1, 14, 30, 0) # 30 mins duration + + task_id_4 = workflow_tracker.add_task_to_workflow( + workflow_id_2, "task2", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_4]["start_time"] = datetime( + 2023, 8, 1, 15, 0, 0) + workflow_analytics.tasks[task_id_4]["end_time"] = datetime( + 2023, 8, 1, 15, 20, 0) # 20 mins duration + + # Manually update the view table with the tasks' details + workflow_analytics.update_view_table(task_id_1) + workflow_analytics.update_view_table(task_id_2) + workflow_analytics.update_view_table(task_id_3) + workflow_analytics.update_view_table(task_id_4) + + # WHEN calling get_average_task_duration without filters + result_all = workflow_analytics.get_average_task_duration() + + # THEN we should get the correct average durations for all tasks + expected_result_all = { + # Average of 30, 45, and 30 minutes + ("task1", "v1"): (30 * 60 + 45 * 60 + 30 * 60) / 3, + ("task2", "v1"): 20 * 60 # 20 minutes in seconds + } + assert result_all == expected_result_all + + # WHEN calling get_average_task_duration filtered by user=1 (group=2 implicitly since user=1 only belongs to group=2) + result_user_1 = workflow_analytics.get_average_task_duration(user=1) + + # THEN we should get the correct average durations for user 1's tasks + expected_result_user_1 = { + ("task1", "v1"): (30 * 60 + 45 * 60) / 2 # Average of 30 and 45 minutes + } + assert result_user_1 == expected_result_user_1 + + # WHEN calling get_average_task_duration filtered by group=4 (tasks by user 3 in group 4) + result_group_4 = workflow_analytics.get_average_task_duration(group=4) + + # THEN we should get the correct average durations for group 4's tasks + expected_result_group_4 = { + ("task1", "v1"): 30 * 60, # 30 minutes + ("task2", "v1"): 20 * 60 # 20 minutes + } + assert result_group_4 == expected_result_group_4 + + # WHEN calling get_average_task_duration filtered by both user=3 and group=4 + result_user_3_group_4 = workflow_analytics.get_average_task_duration( + user=3, group=4) + + # THEN we should get the correct average durations for user 3 in group 4 + expected_result_user_3_group_4 = { + ("task1", "v1"): 30 * 60, # 30 minutes + ("task2", "v1"): 20 * 60 # 20 minutes + } + assert result_user_3_group_4 == expected_result_user_3_group_4 + + +def test_wfanalytics_get_task_failures_with_filters(workflow_tracker_and_workflow_analytics): + # GIVEN a WorkflowTracker event system and workflow analytics listener + 
workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # GIVEN task data added to workflow_analytics.tasks with failure reasons for different users and groups + workflow_id_1 = workflow_tracker.initiate_workflow( + "Test Workflow 1", "Test Description", user=1, group=2) + + workflow_id_2 = workflow_tracker.initiate_workflow( + "Test Workflow 2", "Test Description", user=3, group=4) + + # Simulating task creation with failures for user 1, group 2 + task_id_1 = workflow_tracker.add_task_to_workflow( + workflow_id_1, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_1]["error_type"] = "ErrorA" + + task_id_2 = workflow_tracker.add_task_to_workflow( + workflow_id_1, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_2]["error_type"] = "ErrorB" + + # Simulating task creation with failures for user 3, group 4 + task_id_3 = workflow_tracker.add_task_to_workflow( + workflow_id_2, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_3]["error_type"] = "ErrorC" + + task_id_4 = workflow_tracker.add_task_to_workflow( + workflow_id_2, "task2", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_4]["error_type"] = "ErrorD" + + # Manually update the view table with the tasks' details + workflow_analytics.update_view_table(task_id_1) + workflow_analytics.update_view_table(task_id_2) + workflow_analytics.update_view_table(task_id_3) + workflow_analytics.update_view_table(task_id_4) + + # WHEN calling get_task_failures without filters + result_all = workflow_analytics.get_task_failures() + + # THEN we should get the correct failures for all tasks + expected_result_all = { + ("task1", "v1"): ["ErrorA", "ErrorB", "ErrorC"], + ("task2", "v1"): ["ErrorD"] + } + assert result_all == expected_result_all + + # WHEN calling get_task_failures filtered by user=1 (group=2 implicitly since user=1 only belongs to group=2) + result_user_1 = workflow_analytics.get_task_failures(user=1) + + # THEN we should get the correct failures for user 1's tasks + expected_result_user_1 = { + ("task1", "v1"): ["ErrorA", "ErrorB"] + } + assert result_user_1 == expected_result_user_1 + + # WHEN calling get_task_failures filtered by group=4 (tasks by user 3 in group 4) + result_group_4 = workflow_analytics.get_task_failures(group=4) + + # THEN we should get the correct failures for group 4's tasks + expected_result_group_4 = { + ("task1", "v1"): ["ErrorC"], + ("task2", "v1"): ["ErrorD"] + } + assert result_group_4 == expected_result_group_4 + + # WHEN calling get_task_failures filtered by both user=3 and group=4 + result_user_3_group_4 = workflow_analytics.get_task_failures(user=3, group=4) + + # THEN we should get the correct failures for user 3 in group 4 + expected_result_user_3_group_4 = { + ("task1", "v1"): ["ErrorC"], + ("task2", "v1"): ["ErrorD"] + } + assert result_user_3_group_4 == expected_result_user_3_group_4 + + +def test_wfanalytics_get_task_usage_over_time_with_filters(workflow_tracker_and_workflow_analytics): + # GIVEN a WorkflowTracker event system and workflow analytics listener + workflow_tracker: WorkflowTracker + workflow_analytics: WorkflowAnalytics + workflow_tracker, workflow_analytics = workflow_tracker_and_workflow_analytics + + # GIVEN task data added to workflow_analytics.tasks with specific execution times for different users and groups + workflow_id_1 = workflow_tracker.initiate_workflow( + "Test 
Workflow 1", "Test Description", user=1, group=2) + workflow_id_2 = workflow_tracker.initiate_workflow( + "Test Workflow 2", "Test Description", user=3, group=4) + + # Simulating task executions for task1 on different dates for user 1, group 2 + task_id_1 = workflow_tracker.add_task_to_workflow( + workflow_id_1, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_1]["start_time"] = datetime(2023, 8, 1, 12, 0, 0) + + task_id_2 = workflow_tracker.add_task_to_workflow( + workflow_id_1, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_2]["start_time"] = datetime(2023, 8, 2, 13, 0, 0) + + # Simulating task executions for task1 on different dates for user 3, group 4 + task_id_3 = workflow_tracker.add_task_to_workflow( + workflow_id_2, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_3]["start_time"] = datetime(2023, 8, 2, 14, 0, 0) + + task_id_4 = workflow_tracker.add_task_to_workflow( + workflow_id_2, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_analytics.tasks[task_id_4]["start_time"] = datetime(2023, 8, 3, 15, 0, 0) + + # Manually update the view table with the tasks' details + workflow_analytics.update_view_table(task_id_1) + workflow_analytics.update_view_table(task_id_2) + workflow_analytics.update_view_table(task_id_3) + workflow_analytics.update_view_table(task_id_4) + + # WHEN calling get_task_usage_over_time without filters for task1 + result_all = workflow_analytics.get_task_usage_over_time(task_name="task1") + + # THEN we should get the correct usage counts over time for all users and groups + expected_result_all = { + str(datetime(2023, 8, 1).date()): 1, + str(datetime(2023, 8, 2).date()): 2, + str(datetime(2023, 8, 3).date()): 1 + } + assert result_all == expected_result_all + + # WHEN calling get_task_usage_over_time filtered by user=1 (group=2 implicitly) + result_user_1 = workflow_analytics.get_task_usage_over_time(task_name="task1", user=1) + + # THEN we should get the correct usage counts over time for user 1 + expected_result_user_1 = { + str(datetime(2023, 8, 1).date()): 1, + str(datetime(2023, 8, 2).date()): 1 + } + assert result_user_1 == expected_result_user_1 + + # WHEN calling get_task_usage_over_time filtered by group=4 (tasks by user 3 in group 4) + result_group_4 = workflow_analytics.get_task_usage_over_time(task_name="task1", group=4) + + # THEN we should get the correct usage counts over time for group 4 + expected_result_group_4 = { + str(datetime(2023, 8, 2).date()): 1, + str(datetime(2023, 8, 3).date()): 1 + } + assert result_group_4 == expected_result_group_4 + + # WHEN calling get_task_usage_over_time filtered by both user=3 and group=4 + result_user_3_group_4 = workflow_analytics.get_task_usage_over_time(task_name="task1", user=3, group=4) + + # THEN we should get the correct usage counts over time for user 3 in group 4 + expected_result_user_3_group_4 = { + str(datetime(2023, 8, 2).date()): 1, + str(datetime(2023, 8, 3).date()): 1 + } + assert result_user_3_group_4 == expected_result_user_3_group_4 + diff --git a/tests/unit/test_slurm_client.py b/tests/unit/test_slurm_client.py index b90b738..c1f883e 100644 --- a/tests/unit/test_slurm_client.py +++ b/tests/unit/test_slurm_client.py @@ -1,6 +1,8 @@ import logging from uuid import uuid4 -from biomero import SlurmClient, NoOpWorkflowTracker +from biomero.slurm_client import SlurmClient +from biomero.eventsourcing import NoOpWorkflowTracker +from biomero.database import EngineManager import pytest import 
mock from mock import patch, MagicMock @@ -26,6 +28,8 @@ class SerializableMagicMock(MagicMock, dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) +# Configure logging +logging.basicConfig(level=logging.INFO) @pytest.fixture @patch('biomero.slurm_client.Connection.create_session') @@ -35,11 +39,13 @@ def __init__(self, *args, **kwargs): def slurm_client(_mock_run, _mock_put, _mock_open, _mock_session): + logging.info("EngineManager.__dict__: %s", EngineManager.__dict__) return SlurmClient("localhost", 8022, "slurm") @patch.object(SlurmClient, 'run_commands_split_out') def test_get_all_image_versions_and_data_files(mock_run_commands_split_out, slurm_client): + # GIVEN slurm_client.slurm_images_path = "/path/to/slurm/images" slurm_client.slurm_data_path = "/path/to/slurm/data" From 670c1bce0bf1fe1de430298f96f1db093523a343 Mon Sep 17 00:00:00 2001 From: Luik Date: Tue, 3 Sep 2024 17:59:54 +0200 Subject: [PATCH 15/24] Store task_id also in JobView table. Rebuild tables. --- README.md | 2 +- biomero/constants.py | 1 + biomero/database.py | 4 +- biomero/slurm_client.py | 57 +++++++++++++++++++++++++++-- biomero/views.py | 8 ++-- tests/unit/test_eventsourcing.py | 49 ++++++++++++++++++++----- tests/unit/test_slurm_client.py | 63 +++++++++++++++++++++++++++++++- 7 files changed, 163 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index b36107d..81edf3d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # BIOMERO - BioImage analysis in OMERO -[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![DOI](https://zenodo.org/badge/638954891.svg)](https://zenodo.org/badge/latestdoi/638954891) [![PyPI - Version](https://img.shields.io/pypi/v/biomero)](https://pypi.org/project/biomero/) [![PyPI - Python Versions](https://img.shields.io/pypi/pyversions/biomero)](https://pypi.org/project/biomero/) ![Slurm](https://img.shields.io/badge/Slurm-21.08.6-blue.svg) ![OMERO](https://img.shields.io/badge/OMERO-5.6.8-blue.svg) [![fair-software.eu](https://img.shields.io/badge/fair--software.eu-%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F-green)](https://fair-software.eu) [![OpenSSF Best Practices](https://bestpractices.coreinfrastructure.org/projects/7530/badge)](https://bestpractices.coreinfrastructure.org/projects/7530) [![Sphinx build](https://github.com/NL-BioImaging/biomero/actions/workflows/sphinx.yml/badge.svg?branch=main)](https://github.com/NL-BioImaging/biomero/actions/workflows/sphinx.yml) [![pages-build-deployment](https://github.com/NL-BioImaging/biomero/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/NL-BioImaging/biomero/actions/workflows/pages/pages-build-deployment) [![python-package build](https://github.com/NL-BioImaging/biomero/actions/workflows/python-package.yml/badge.svg)](https://github.com/NL-BioImaging/biomero/actions/workflows/python-package.yml) [![python-publish build](https://github.com/NL-BioImaging/biomero/actions/workflows/python-publish.yml/badge.svg?branch=main)](https://github.com/NL-BioImaging/biomero/actions/workflows/python-publish.yml) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![DOI](https://zenodo.org/badge/638954891.svg)](https://zenodo.org/badge/latestdoi/638954891) [![PyPI - Version](https://img.shields.io/pypi/v/biomero)](https://pypi.org/project/biomero/) [![PyPI - Python 
Versions](https://img.shields.io/pypi/pyversions/biomero)](https://pypi.org/project/biomero/) ![Slurm](https://img.shields.io/badge/Slurm-21.08.6-blue.svg) ![OMERO](https://img.shields.io/badge/OMERO-5.6.8-blue.svg) [![fair-software.eu](https://img.shields.io/badge/fair--software.eu-%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F-green)](https://fair-software.eu) [![OpenSSF Best Practices](https://bestpractices.coreinfrastructure.org/projects/7530/badge)](https://bestpractices.coreinfrastructure.org/projects/7530) [![Sphinx build](https://github.com/NL-BioImaging/biomero/actions/workflows/sphinx.yml/badge.svg?branch=main)](https://github.com/NL-BioImaging/biomero/actions/workflows/sphinx.yml) [![pages-build-deployment](https://github.com/NL-BioImaging/biomero/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/NL-BioImaging/biomero/actions/workflows/pages/pages-build-deployment) [![python-package build](https://github.com/NL-BioImaging/biomero/actions/workflows/python-package.yml/badge.svg)](https://github.com/NL-BioImaging/biomero/actions/workflows/python-package.yml) [![python-publish build](https://github.com/NL-BioImaging/biomero/actions/workflows/python-publish.yml/badge.svg?branch=main)](https://github.com/NL-BioImaging/biomero/actions/workflows/python-publish.yml) [![Coverage Status](https://coveralls.io/repos/github/NL-BioImaging/biomero/badge.svg?branch=main)](https://coveralls.io/github/NL-BioImaging/biomero?branch=main) The **BIOMERO** framework, for **B**io**I**mage analysis in **OMERO**, allows you to run (FAIR) bioimage analysis workflows directly from OMERO on a high-performance compute (HPC) cluster, remotely through SSH. diff --git a/biomero/constants.py b/biomero/constants.py index 7a5dd9a..95cf777 100644 --- a/biomero/constants.py +++ b/biomero/constants.py @@ -16,6 +16,7 @@ IMAGE_EXPORT_SCRIPT = "_SLURM_Image_Transfer.py" IMAGE_IMPORT_SCRIPT = "SLURM_Get_Results.py" +CONVERSION_SCRIPT = "SLURM_Remote_Conversion.py" RUN_WF_SCRIPT = "SLURM_Run_Workflow.py" diff --git a/biomero/database.py b/biomero/database.py index 67bb909..3c8f41a 100644 --- a/biomero/database.py +++ b/biomero/database.py @@ -35,12 +35,14 @@ class JobView(Base): slurm_job_id (Integer): The unique identifier for the Slurm job. user (Integer): The ID of the user who submitted the job. group (Integer): The group ID associated with the job. + task_id (UUID): The unique identifier for the biomero task """ __tablename__ = 'biomero_job_view' slurm_job_id = Column(Integer, primary_key=True) user = Column(Integer, nullable=False) group = Column(Integer, nullable=False) + task_id = Column(PGUUID(as_uuid=True)) class JobProgressView(Base): @@ -121,7 +123,7 @@ def create_scoped_session(cls, sqlalchemy_url: str = None): sqlalchemy_url = os.getenv('SQLALCHEMY_URL') cls._engine = create_engine(sqlalchemy_url) - # setup tables if needed + # setup tables if they don't exist yet Base.metadata.create_all(cls._engine) # Create a scoped_session object. 
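Note on the new JobView.task_id column added above: it links each recorded Slurm job back to the BIOMERO task that submitted it. A minimal usage sketch follows (an illustration only, not code from this patch), assuming EngineManager.create_scoped_session() has already been called, as in the test fixtures:

    from typing import List
    from uuid import UUID

    from biomero.database import EngineManager, JobView


    def slurm_jobs_for_task(task_id: UUID) -> List[int]:
        """Return the slurm_job_ids recorded for one BIOMERO task (may be empty)."""
        # JobView rows are written by JobAccounting.update_view_table(); the
        # task_id column is nullable by default, so rows without it may exist.
        with EngineManager.get_session() as session:
            rows = session.query(JobView).filter_by(task_id=task_id).all()
            return [row.slurm_job_id for row in rows]

A lookup like this is what the extra column enables; the views.py change later in this patch passes task_id through update_view_table() so that new JobView rows carry it.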
diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index 69d5774..711d7c8 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -32,8 +32,10 @@ import os from biomero.eventsourcing import WorkflowTracker, NoOpWorkflowTracker from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics -from biomero.database import EngineManager +from biomero.database import EngineManager, JobProgressView, JobView, TaskExecution from eventsourcing.system import System, SingleThreadedRunner +from sqlalchemy.exc import IntegrityError +from sqlalchemy.sql import text logger = logging.getLogger(__name__) @@ -433,12 +435,15 @@ def __init__(self, # Initialize the analytics system self.sqlalchemy_url = sqlalchemy_url - self.initialize_analytics_system() + self.initialize_analytics_system(reset_tables=init_slurm) - def initialize_analytics_system(self): + def initialize_analytics_system(self, reset_tables=False): """ Initialize the analytics system based on the analytics configuration passed to the constructor. + + Args: + reset_tables (bool): If True, drops and recreates all views. """ # Get persistence settings, prioritize environment variables persistence_module = os.getenv("PERSISTENCE_MODULE", "eventsourcing_sqlalchemy") @@ -483,6 +488,50 @@ def initialize_analytics_system(self): logger.warning("Tracking workflows is disabled. No-op WorkflowTracker will be used.") self.workflowTracker = NoOpWorkflowTracker() + self.setup_listeners(runner, reset_tables) + + def setup_listeners(self, runner, reset_tables): + # Only when people run init script, we just drop and rebuild. + self.get_listeners(runner) + + # Optionally drop and recreate tables + if reset_tables: + logger.info("Resetting view tables.") + tables = [] + # gather the listener tables + listeners = [self.jobAccounting, + self.jobProgress, + self.workflowAnalytics] + for listener in listeners: + if listener: + tables.append(listener.recorder.tracking_table_name) + tables.append(listener.recorder.events_table_name) + runner.stop() + # gather the view tables + tables.append(TaskExecution.__tablename__) + tables.append(JobProgressView.__tablename__) + tables.append(JobView.__tablename__) + with EngineManager.get_session() as session: + try: + # Begin a transaction + for table in tables: + # Drop the table if it exists + logger.info(f"Dropping table {table}") + drop_table_sql = text(f'DROP TABLE IF EXISTS {table}') + session.execute(drop_table_sql) + # Only when people run init script, we just drop and rebuild. 
+ session.commit() + logger.info("Dropped view tables successfully") + except IntegrityError as e: + logger.error(e) + session.rollback() + raise Exception(f"Error trying to reset the view tables: {e}") + + EngineManager.close_engine() # close current sql session + # restart runner, listeners and recreate views + self.initialize_analytics_system(reset_tables=False) + + def get_listeners(self, runner): if self.track_workflows and self.enable_job_accounting: self.jobAccounting = runner.get(JobAccounting) else: @@ -1518,7 +1567,7 @@ def run_conversion_workflow_job(self, logger.debug(f"wf_id: {wf_id}") task_id = self.workflowTracker.add_task_to_workflow( wf_id, - chosen_converter, + f"convert_{source_format}_to_{target_format}".upper(), version, data_path, sbatch_env diff --git a/biomero/views.py b/biomero/views.py index f5e0330..5c6a831 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -98,21 +98,21 @@ def _(self, domain_event, process_event): logger.debug(f"Job added: job_id={job_id}, task_id={task_id}, user={user}, group={group}") # Update view table - self.update_view_table(job_id, user, group) + self.update_view_table(job_id, user, group, task_id) else: logger.debug(f"JobIdAdded event ignored: task_id={task_id} not found in tasks") # use .collect_events(agg) instead of .save(agg) # process_event.collect_events(jobaccount) - def update_view_table(self, job_id, user, group): + def update_view_table(self, job_id, user, group, task_id): """Update the view table with new job information.""" with EngineManager.get_session() as session: try: - new_job = JobView(slurm_job_id=job_id, user=user, group=group) + new_job = JobView(slurm_job_id=job_id, user=user, group=group, task_id=task_id) session.add(new_job) session.commit() - logger.debug(f"Inserted job into view table: job_id={job_id}, user={user}, group={group}") + logger.debug(f"Inserted job into view table: job_id={job_id}, user={user}, group={group}, task_id={task_id}") except IntegrityError as e: session.rollback() # Handle the case where the job already exists in the table if necessary diff --git a/tests/unit/test_eventsourcing.py b/tests/unit/test_eventsourcing.py index f6c9278..34802b4 100644 --- a/tests/unit/test_eventsourcing.py +++ b/tests/unit/test_eventsourcing.py @@ -88,6 +88,25 @@ def workflow_tracker_and_workflow_analytics(): runner.stop() +def test_runner(): + # Create a System instance with the necessary components + system = System(pipes=[[WorkflowTracker, JobAccounting]]) + runner = SingleThreadedRunner(system) + runner.start() + + # Get the application + wft = runner.get(WorkflowTracker) + + # when + assert wft.closing.is_set() is False + runner.stop() + assert wft.closing.is_set() is True + + # runner.start() + # wft2 = runner.get(WorkflowTracker) + # assert wft2.closing.is_set() is False + + def test_initiate_workflow(workflow_tracker): # Initiating a workflow workflow_id = workflow_tracker.initiate_workflow( @@ -596,7 +615,8 @@ def test_job_acc_update_view_table(workflow_tracker_and_job_accounting): job_id = "67890" user = 1 group = 2 - job_accounting.update_view_table(job_id=job_id, user=user, group=group) + job_accounting.update_view_table(job_id=job_id, user=user, group=group, + task_id=uuid.uuid4()) # THEN verify the JobView entry using SQLAlchemy with EngineManager.get_session() as session: @@ -615,9 +635,12 @@ def test_job_acc_get_jobs_for_user(workflow_tracker_and_job_accounting): workflow_tracker, job_accounting = workflow_tracker_and_job_accounting # Simulate adding jobs - 
job_accounting.update_view_table(job_id=100, user=1, group=2) - job_accounting.update_view_table(job_id=200, user=1, group=2) - job_accounting.update_view_table(job_id=300, user=2, group=3) + job_accounting.update_view_table(job_id=100, user=1, group=2, + task_id=uuid.uuid4()) + job_accounting.update_view_table(job_id=200, user=1, group=2, + task_id=uuid.uuid4()) + job_accounting.update_view_table(job_id=300, user=2, group=3, + task_id=uuid.uuid4()) # WHEN retrieving jobs for a specific user jobs = job_accounting.get_jobs(user=1) @@ -639,9 +662,12 @@ def test_job_acc_get_jobs_for_group(workflow_tracker_and_job_accounting): workflow_tracker, job_accounting = workflow_tracker_and_job_accounting # Simulate adding jobs - job_accounting.update_view_table(job_id=400, user=1, group=2) - job_accounting.update_view_table(job_id=500, user=2, group=2) - job_accounting.update_view_table(job_id=600, user=2, group=3) + job_accounting.update_view_table(job_id=400, user=1, group=2, + task_id=uuid.uuid4()) + job_accounting.update_view_table(job_id=500, user=2, group=2, + task_id=uuid.uuid4()) + job_accounting.update_view_table(job_id=600, user=2, group=3, + task_id=uuid.uuid4()) # WHEN retrieving jobs for a specific group jobs = job_accounting.get_jobs(group=2) @@ -663,9 +689,12 @@ def test_job_acc_get_jobs_all(workflow_tracker_and_job_accounting): workflow_tracker, job_accounting = workflow_tracker_and_job_accounting # Simulate adding jobs - job_accounting.update_view_table(job_id=700, user=1, group=2) - job_accounting.update_view_table(job_id=800, user=1, group=2) - job_accounting.update_view_table(job_id=900, user=2, group=3) + job_accounting.update_view_table(job_id=700, user=1, group=2, + task_id=uuid.uuid4()) + job_accounting.update_view_table(job_id=800, user=1, group=2, + task_id=uuid.uuid4()) + job_accounting.update_view_table(job_id=900, user=2, group=3, + task_id=uuid.uuid4()) # WHEN retrieving all jobs jobs = job_accounting.get_jobs() diff --git a/tests/unit/test_slurm_client.py b/tests/unit/test_slurm_client.py index c1f883e..c89eab5 100644 --- a/tests/unit/test_slurm_client.py +++ b/tests/unit/test_slurm_client.py @@ -2,12 +2,13 @@ from uuid import uuid4 from biomero.slurm_client import SlurmClient from biomero.eventsourcing import NoOpWorkflowTracker -from biomero.database import EngineManager +from biomero.database import EngineManager, TaskExecution, JobProgressView, JobView import pytest import mock from mock import patch, MagicMock from paramiko import SSHException import os +from sqlalchemy import inspect @pytest.fixture(autouse=True) @@ -977,6 +978,66 @@ def test_cleanup_tmp_files(mock_extract_data_location, mock_run_commands, assert result.ok is True +def table_exists(session, table_name): + inspector = inspect(session.bind) + return inspector.has_table(table_name) + +@patch('biomero.slurm_client.Connection.create_session') +@patch('biomero.slurm_client.Connection.open') +@patch('biomero.slurm_client.Connection.put') +@patch('biomero.slurm_client.Connection.run') +def test_sqlalchemy_tables_exist(mock_run, mock_put, mock_open, mock_session, caplog): + """ + Test that after initializing SlurmClient with all listeners enabled, + the relevant SQLAlchemy tables exist in the database. 
+ """ + # Initialize the analytics system (with reset_tables=False to not drop them) + with caplog.at_level(logging.INFO): + slurm_client = SlurmClient( + host="localhost", + port=8022, + user="slurm", + slurm_data_path="datapath", + slurm_images_path="imagespath", + slurm_script_path="scriptpath", + slurm_converters_path="converterspath", + slurm_script_repo="repo-url", + slurm_model_paths={'wf': 'path'}, + slurm_model_images={'wf': 'image'}, + slurm_model_repos={'wf': 'https://github.com/example/workflow1'}, + track_workflows=True, # Enable workflow tracking + enable_job_accounting=True, # Enable job accounting + enable_job_progress=True, # Enable job progress tracking + enable_workflow_analytics=True, # Enable workflow analytics + init_slurm=True # Trigger the initialization of Slurm + ) + + # Check that the expected log message for table drops is present + assert any("Dropped view tables successfully" in record.message for record in caplog.records), \ + "Expected log message 'Dropped view tables successfully' was not found." + + # Check that the expected tables exist + with EngineManager.get_session() as session: + expected_tables = [ + # Listener event and tracking tables + slurm_client.jobAccounting.recorder.tracking_table_name, + slurm_client.jobAccounting.recorder.events_table_name, + slurm_client.jobProgress.recorder.tracking_table_name, + slurm_client.jobProgress.recorder.events_table_name, + slurm_client.workflowAnalytics.recorder.tracking_table_name, + slurm_client.workflowAnalytics.recorder.events_table_name, + + # Views + TaskExecution.__tablename__, + JobProgressView.__tablename__, + JobView.__tablename__ + ] + + # Ensure each expected table exists in the database + for table in expected_tables: + assert table_exists(session, table), f"Table {table} should exist but does not." 
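The table_exists helper above relies on SQLAlchemy's runtime inspection API rather than raw information_schema queries, so the same check covers both the event-store tables and the view tables. A self-contained illustration of that pattern (the in-memory SQLite engine is purely for demonstration; BIOMERO itself targets PostgreSQL):

    from sqlalchemy import Column, Integer, create_engine, inspect
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class DemoView(Base):
        __tablename__ = "demo_view"
        id = Column(Integer, primary_key=True)

    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)             # create tables if they don't exist yet

    inspector = inspect(engine)
    print(inspector.has_table("demo_view"))      # True
    print(inspector.has_table("missing_table"))  # False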
+ + @pytest.mark.parametrize("track_workflows, enable_job_accounting, enable_job_progress, enable_workflow_analytics, expected_tracker_classes", [ # Case when everything is enabled (True, True, True, True, {"workflowTracker": "WorkflowTracker", "jobAccounting": "JobAccounting", "jobProgress": "JobProgress", "workflowAnalytics": "WorkflowAnalytics"}), From 9aab515ed2c008d68f7b4215638974ccd9be2ad6 Mon Sep 17 00:00:00 2001 From: Luik Date: Tue, 10 Sep 2024 16:06:37 +0200 Subject: [PATCH 16/24] Add workflowprogress view table --- biomero/constants.py | 13 ++- biomero/database.py | 26 +++++- biomero/slurm_client.py | 32 +++++++- biomero/views.py | 170 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 233 insertions(+), 8 deletions(-) diff --git a/biomero/constants.py b/biomero/constants.py index 95cf777..e846c19 100644 --- a/biomero/constants.py +++ b/biomero/constants.py @@ -107,4 +107,15 @@ class transfer: FORMAT_OMETIFF = 'OME-TIFF' FORMAT_ZARR = 'ZARR' FOLDER = "Folder_Name" - FOLDER_DEFAULT = 'SLURM_IMAGES_' \ No newline at end of file + FOLDER_DEFAULT = 'SLURM_IMAGES_' + + +class workflow_status: + INITIALIZING = "INITIALIZING" + TRANSFERRING = "TRANSFERRING" + CONVERTING = "CONVERTING" + RETRIEVING = "RETRIEVING" + DONE = "DONE" + FAILED = "FAILED" + RUNNING = "RUNNING" + JOB_STATUS = "JOB_" \ No newline at end of file diff --git a/biomero/database.py b/biomero/database.py index 3c8f41a..71efbe4 100644 --- a/biomero/database.py +++ b/biomero/database.py @@ -59,8 +59,32 @@ class JobProgressView(Base): slurm_job_id = Column(Integer, primary_key=True) status = Column(String, nullable=False) progress = Column(String, nullable=True) - + +class WorkflowProgressView(Base): + """ + SQLAlchemy model for the 'workflow_progress_view' table. + + Attributes: + workflow_id (PGUUID): The unique identifier for the workflow (primary key). + status (String, optional): The current status of the workflow. + progress (String, optional): The progress status of the workflow. + user (String, optional): The user who initiated the workflow. + group (String, optional): The group associated with the workflow. + name (String, optional): The name of the workflow + """ + __tablename__ = 'biomero_workflow_progress_view' + + workflow_id = Column(PGUUID(as_uuid=True), primary_key=True) + status = Column(String, nullable=True) + progress = Column(String, nullable=True) + user = Column(Integer, nullable=True) + group = Column(Integer, nullable=True) + name = Column(String, nullable=True) + task = Column(String, nullable=True) + start_time = Column(DateTime, nullable=False) + + class TaskExecution(Base): """ SQLAlchemy model for the 'biomero_task_execution' table. 
diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index 711d7c8..b02adff 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -31,8 +31,8 @@ import io import os from biomero.eventsourcing import WorkflowTracker, NoOpWorkflowTracker -from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics -from biomero.database import EngineManager, JobProgressView, JobView, TaskExecution +from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics, WorkflowProgress +from biomero.database import EngineManager, JobProgressView, JobView, TaskExecution, WorkflowProgressView from eventsourcing.system import System, SingleThreadedRunner from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import text @@ -467,6 +467,7 @@ def initialize_analytics_system(self, reset_tables=False): # Add JobProgress to the pipeline if enabled if self.enable_job_progress: pipes.append([WorkflowTracker, JobProgress]) + pipes.append([WorkflowTracker, WorkflowProgress]) # Add WorkflowAnalytics to the pipeline if enabled if self.enable_workflow_analytics: @@ -500,7 +501,8 @@ def setup_listeners(self, runner, reset_tables): tables = [] # gather the listener tables listeners = [self.jobAccounting, - self.jobProgress, + self.jobProgress, + self.wfProgress, self.workflowAnalytics] for listener in listeners: if listener: @@ -510,6 +512,7 @@ def setup_listeners(self, runner, reset_tables): # gather the view tables tables.append(TaskExecution.__tablename__) tables.append(JobProgressView.__tablename__) + tables.append(WorkflowProgressView.__tablename__) tables.append(JobView.__tablename__) with EngineManager.get_session() as session: try: @@ -530,22 +533,43 @@ def setup_listeners(self, runner, reset_tables): EngineManager.close_engine() # close current sql session # restart runner, listeners and recreate views self.initialize_analytics_system(reset_tables=False) + # Update the view tables again + listeners = [self.jobAccounting, + self.jobProgress, + self.wfProgress, + self.workflowAnalytics] + for listener in listeners: + if listener: + self.bring_listener_uptodate(listener) def get_listeners(self, runner): if self.track_workflows and self.enable_job_accounting: - self.jobAccounting = runner.get(JobAccounting) + self.jobAccounting = runner.get(JobAccounting) else: self.jobAccounting = NoOpWorkflowTracker() if self.track_workflows and self.enable_job_progress: self.jobProgress = runner.get(JobProgress) + self.wfProgress = runner.get(WorkflowProgress) else: self.jobProgress = NoOpWorkflowTracker() + self.wfProgress = NoOpWorkflowTracker() if self.track_workflows and self.enable_workflow_analytics: self.workflowAnalytics = runner.get(WorkflowAnalytics) else: self.workflowAnalytics = NoOpWorkflowTracker() + + def bring_listener_uptodate(self, listener): + with EngineManager.get_session() as session: + try: + # Begin a transaction + listener.pull_and_process(leader_name=WorkflowTracker.__name__, start=1) + session.commit() + logger.info("Updated listener successfully") + except IntegrityError as e: + logger.error(e) + session.rollback() def __exit__(self, exc_type, exc_val, exc_tb): # Ensure to call the parent class's __exit__ diff --git a/biomero/views.py b/biomero/views.py index 5c6a831..9cfe496 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -27,8 +27,8 @@ from sqlalchemy import event from sqlalchemy.engine import Engine from biomero.eventsourcing import WorkflowRun, Task -from biomero.database import EngineManager, JobView, TaskExecution, JobProgressView - +from 
biomero.database import EngineManager, JobView, TaskExecution, JobProgressView, WorkflowProgressView +from biomero.constants import workflow_status as wfs logger = logging.getLogger(__name__) @@ -158,6 +158,172 @@ def get_jobs(self, user=None, group=None): return result +class WorkflowProgress(ProcessApplication): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # State tracking: {workflow_id: {"status": status, "progress": progress, "user": user, "group": group}} + self.workflows = {} + self.tasks = {} # {task_id: {"workflow_id": wf_id, "task_name": task_name}} + + @singledispatchmethod + def policy(self, domain_event, process_event): + """Default policy""" + pass + + @policy.register(WorkflowRun.WorkflowInitiated) + def _(self, domain_event, process_event): + """Handle WorkflowInitiated event""" + user = domain_event.user + group = domain_event.group + wf_id = domain_event.originator_id + name = domain_event.name + start_time = domain_event.timestamp + + # Track workflow with user, group, and INITIATED status + self.workflows[wf_id] = {"status": wfs.INITIALIZING, + "progress": "0%", + "user": user, + "group": group, + "name": name, + "task": None, + "start_time": start_time} + logger.debug(f"Workflow initiated: wf_id={wf_id}, name={name}, user={user}, group={group}, status={wfs.INITIALIZING}") + self.update_view_table(wf_id) + + @policy.register(WorkflowRun.WorkflowCompleted) + def _(self, domain_event, process_event): + wf_id = domain_event.originator_id + self.workflows[wf_id]["status"] = wfs.DONE + self.workflows[wf_id]["progress"] = "100%" + logger.debug(f"Status updated: wf_id={wf_id}, status={wfs.DONE}") + self.update_view_table(wf_id) + + @policy.register(WorkflowRun.WorkflowFailed) + def _(self, domain_event, process_event): + wf_id = domain_event.originator_id + error = domain_event.error_message + self.workflows[wf_id]["status"] = wfs.FAILED + logger.debug(f"Status updated: wf_id={wf_id}, status={wfs.FAILED}") + self.update_view_table(wf_id) + + @policy.register(WorkflowRun.TaskAdded) + def _(self, domain_event, process_event): + """Handle TaskAdded event""" + task_id = domain_event.task_id + wf_id = domain_event.originator_id + + # Track task to workflow mapping + if task_id in self.tasks: + self.tasks[task_id]["workflow_id"] = wf_id + if wf_id in self.workflows: + self.workflows[wf_id]["task"] = self.tasks[task_id]["task_name"] + logger.debug(f"Task added: task_id={task_id}, wf_id={wf_id}") + + @policy.register(Task.TaskCreated) + def _(self, domain_event, process_event): + task_id = domain_event.originator_id + task_name = domain_event.task_name + + # store task name + self.tasks[task_id] = { + "task_name": task_name, + "workflow_id": None, + "progress": None + } + logger.debug(f"Task created: task_id={task_id}, task_name={task_name}") + + @policy.register(Task.StatusUpdated) + def _(self, domain_event, process_event): + """Handle Task StatusUpdated event""" + task_id = domain_event.originator_id + status = domain_event.status + + # Get the workflow ID and task name associated with this task + task_info = self.tasks.get(task_id) + if task_info: + wf_id = task_info["workflow_id"] + task_name = task_info["task_name"] + + if wf_id and wf_id in self.workflows: + # Determine status based on task_name + if task_name == '_SLURM_Image_Transfer.py': + workflow_status = wfs.TRANSFERRING + workflow_prog = "5%" + elif task_name.startswith('convert_'): + workflow_status = wfs.CONVERTING + workflow_prog = "25%" + elif task_name == 'SLURM_Get_Results.py': + 
workflow_status = wfs.RETRIEVING + workflow_prog = "90%" + elif task_name == 'SLURM_Run_Workflow.py': + workflow_status = wfs.RUNNING + workflow_prog = "50%" + else: + # Default to JOB_STATUS prefix for unknown task types + workflow_status = wfs.JOB_STATUS + status + workflow_prog = "50%" + if "task_progress" in self.workflows[wf_id]: + task_prog = self.workflows[wf_id]["task_progress"] + if task_prog: + # Initial string and baseline + upper_limit_str = "90%" # Upper limit string + # Step 1: Extract integers from the strings + current_val = int(task_prog.strip('%')) + baseline_val = int(workflow_prog.strip('%')) + upper_limit_val = int(upper_limit_str.strip('%')) + + # Step 2: Interpolation logic + # Map the current_val (43) between the range 50 and 90 + # Formula for linear interpolation: new_val = baseline + (current_val / 100) * (upper_limit - baseline) + interpolated_val = baseline_val + ((current_val / 100) * (upper_limit_val - baseline_val)) + workflow_prog = f"{interpolated_val}%" + + + # Update the workflow status + self.workflows[wf_id]["status"] = workflow_status + self.workflows[wf_id]["progress"] = workflow_prog + logger.debug(f"Status updated: wf_id={wf_id}, task_id={task_id}, status={workflow_status}") + self.update_view_table(wf_id) + + @policy.register(Task.ProgressUpdated) + def _(self, domain_event, process_event): + """Handle ProgressUpdated event""" + task_id = domain_event.originator_id + progress = domain_event.progress + + if task_id in self.tasks: + self.tasks[task_id]["progress"] = progress + wf_id = self.tasks[task_id]["workflow_id"] + if wf_id and wf_id in self.workflows: + self.workflows[wf_id]["task_progress"] = progress + logger.debug(f"(Task) Progress updated: wf_id={wf_id}, progress={progress}") + self.update_view_table(wf_id) + + def update_view_table(self, wf_id): + """Update the view table with new workflow status, progress, user, and group.""" + with EngineManager.get_session() as session: + workflow_info = self.workflows[wf_id] + try: + new_workflow_progress = WorkflowProgressView( + workflow_id=wf_id, + status=workflow_info["status"], + progress=workflow_info["progress"], + user=workflow_info["user"], + group=workflow_info["group"], + name=workflow_info["name"], + task=workflow_info["task"], + start_time=workflow_info["start_time"] + ) + session.merge(new_workflow_progress) + session.commit() + logger.debug(f"Inserted wf progress in view table: wf_id={wf_id} wf_info={workflow_info}") + except IntegrityError: + session.rollback() + logger.error(f"Failed to insert/update wf progress in view table: wf_id={wf_id} wf_info={workflow_info}") + + + class JobProgress(ProcessApplication): def __init__(self, *args, **kwargs): ProcessApplication.__init__(self, *args, **kwargs) From c9587bd33ca2b04a793716454b156caeda010cde Mon Sep 17 00:00:00 2001 From: Luik Date: Tue, 1 Oct 2024 11:44:10 +0200 Subject: [PATCH 17/24] Better logging. More comitting. --- biomero/database.py | 2 +- biomero/eventsourcing.py | 29 +++++++-------- biomero/slurm_client.py | 6 ++-- biomero/views.py | 76 +++++++++++++++++++++++++--------------- 4 files changed, 66 insertions(+), 47 deletions(-) diff --git a/biomero/database.py b/biomero/database.py index 71efbe4..2cc69a3 100644 --- a/biomero/database.py +++ b/biomero/database.py @@ -152,7 +152,7 @@ def create_scoped_session(cls, sqlalchemy_url: str = None): # Create a scoped_session object. 
cls._session = scoped_session( - sessionmaker(autocommit=False, autoflush=False, bind=cls._engine) + sessionmaker(autocommit=False, autoflush=True, bind=cls._engine) ) class MyScopedSessionAdapter: diff --git a/biomero/eventsourcing.py b/biomero/eventsourcing.py index 99d2865..cff7f3e 100644 --- a/biomero/eventsourcing.py +++ b/biomero/eventsourcing.py @@ -92,7 +92,7 @@ def __init__(self, name: str, self.user = user self.group = group self.tasks = [] - logger.debug(f"Initializing WorkflowRun: name={name}, description={description}, user={user}, group={group}") + # logger.debug(f"Initializing WorkflowRun: name={name}, description={description}, user={user}, group={group}") class TaskAdded(Aggregate.Event): """ @@ -105,7 +105,7 @@ class TaskAdded(Aggregate.Event): @event(TaskAdded) def add_task(self, task_id: UUID): - logger.debug(f"Adding task to WorkflowRun: task_id={task_id}") + # logger.debug(f"Adding task to WorkflowRun: task_id={task_id}") self.tasks.append(task_id) class WorkflowStarted(Aggregate.Event): @@ -116,7 +116,7 @@ class WorkflowStarted(Aggregate.Event): @event(WorkflowStarted) def start_workflow(self): - logger.debug(f"Starting workflow: id={self.id}") + # logger.debug(f"Starting workflow: id={self.id}") pass class WorkflowCompleted(Aggregate.Event): @@ -127,7 +127,7 @@ class WorkflowCompleted(Aggregate.Event): @event(WorkflowCompleted) def complete_workflow(self): - logger.debug(f"Completing workflow: id={self.id}") + # logger.debug(f"Completing workflow: id={self.id}") pass class WorkflowFailed(Aggregate.Event): @@ -141,7 +141,7 @@ class WorkflowFailed(Aggregate.Event): @event(WorkflowFailed) def fail_workflow(self, error_message: str): - logger.debug(f"Failing workflow: id={self.id}, error_message={error_message}") + # logger.debug(f"Failing workflow: id={self.id}, error_message={error_message}") pass @@ -196,7 +196,8 @@ def __init__(self, self.results = [] self.result_message = None self.status = None - logger.debug(f"Initializing Task: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}") + # Not logging on aggregates, they get reconstructed so much + # logger.debug(f"Initializing Task: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}") class JobIdAdded(Aggregate.Event): """ @@ -209,7 +210,7 @@ class JobIdAdded(Aggregate.Event): @event(JobIdAdded) def add_job_id(self, job_id): - logger.debug(f"Adding job_id to Task: task_id={self.id}, job_id={job_id}") + # logger.debug(f"Adding job_id to Task: task_id={self.id}, job_id={job_id}") self.job_ids.append(job_id) class StatusUpdated(Aggregate.Event): @@ -223,7 +224,7 @@ class StatusUpdated(Aggregate.Event): @event(StatusUpdated) def update_task_status(self, status): - logger.debug(f"Adding status to Task: task_id={self.id}, status={status}") + # logger.debug(f"Adding status to Task: task_id={self.id}, status={status}") self.status = status class ProgressUpdated(Aggregate.Event): @@ -237,7 +238,7 @@ class ProgressUpdated(Aggregate.Event): @event(ProgressUpdated) def update_task_progress(self, progress): - logger.debug(f"Adding progress to Task: task_id={self.id}, progress={progress}") + # logger.debug(f"Adding progress to Task: task_id={self.id}, progress={progress}") self.progress = progress class ResultAdded(Aggregate.Event): @@ -250,13 +251,13 @@ class ResultAdded(Aggregate.Event): result: ResultDict def add_result(self, result: Result): - logger.debug(f"Adding result to Task: task_id={self.id}, result={result}") + # logger.debug(f"Adding result to Task: 
task_id={self.id}, result={result}") result = ResultDict(result) self._add_result(result) @event(ResultAdded) def _add_result(self, result: ResultDict): - logger.debug(f"Adding result to Task results: task_id={self.id}, result={result}") + # logger.debug(f"Adding result to Task results: task_id={self.id}, result={result}") self.results.append(result) class TaskStarted(Aggregate.Event): @@ -267,7 +268,7 @@ class TaskStarted(Aggregate.Event): @event(TaskStarted) def start_task(self): - logger.debug(f"Starting task: id={self.id}") + # logger.debug(f"Starting task: id={self.id}") pass class TaskCompleted(Aggregate.Event): @@ -281,7 +282,7 @@ class TaskCompleted(Aggregate.Event): @event(TaskCompleted) def complete_task(self, result: str): - logger.debug(f"Completing task: id={self.id}, result={result}") + # logger.debug(f"Completing task: id={self.id}, result={result}") self.result_message = result class TaskFailed(Aggregate.Event): @@ -295,7 +296,7 @@ class TaskFailed(Aggregate.Event): @event(TaskFailed) def fail_task(self, error_message: str): - logger.debug(f"Failing task: id={self.id}, error_message={error_message}") + # logger.debug(f"Failing task: id={self.id}, error_message={error_message}") self.result_message = error_message pass diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py index b02adff..af581cd 100644 --- a/biomero/slurm_client.py +++ b/biomero/slurm_client.py @@ -505,7 +505,7 @@ def setup_listeners(self, runner, reset_tables): self.wfProgress, self.workflowAnalytics] for listener in listeners: - if listener: + if not isinstance(listener, NoOpWorkflowTracker): tables.append(listener.recorder.tracking_table_name) tables.append(listener.recorder.events_table_name) runner.stop() @@ -560,11 +560,11 @@ def get_listeners(self, runner): else: self.workflowAnalytics = NoOpWorkflowTracker() - def bring_listener_uptodate(self, listener): + def bring_listener_uptodate(self, listener, start=1): with EngineManager.get_session() as session: try: # Begin a transaction - listener.pull_and_process(leader_name=WorkflowTracker.__name__, start=1) + listener.pull_and_process(leader_name=WorkflowTracker.__name__, start=start) session.commit() logger.info("Updated listener successfully") except IntegrityError as e: diff --git a/biomero/views.py b/biomero/views.py index 9cfe496..93935b4 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -63,6 +63,7 @@ def _(self, domain_event, process_event): # Optionally, persist this state if needed # Optionally, add an event to do that, then save via collect # process_event.collect_events(jobaccount, wfView) + EngineManager.commit() @policy.register(WorkflowRun.TaskAdded) def _(self, domain_event, process_event): @@ -77,6 +78,7 @@ def _(self, domain_event, process_event): # Optionally, persist this state if needed # use .collect_events(agg) instead of .save(agg) # process_event.collect_events(taskView) + EngineManager.commit() @policy.register(Task.JobIdAdded) def _(self, domain_event, process_event): @@ -104,6 +106,7 @@ def _(self, domain_event, process_event): # use .collect_events(agg) instead of .save(agg) # process_event.collect_events(jobaccount) + EngineManager.commit() def update_view_table(self, job_id, user, group, task_id): """Update the view table with new job information.""" @@ -188,24 +191,27 @@ def _(self, domain_event, process_event): "name": name, "task": None, "start_time": start_time} - logger.debug(f"Workflow initiated: wf_id={wf_id}, name={name}, user={user}, group={group}, status={wfs.INITIALIZING}") + logger.debug(f"[WFP] 
Workflow initiated: wf_id={wf_id}, name={name}, user={user}, group={group}, status={wfs.INITIALIZING} -- {domain_event.__dict__}") self.update_view_table(wf_id) + EngineManager.commit() @policy.register(WorkflowRun.WorkflowCompleted) def _(self, domain_event, process_event): wf_id = domain_event.originator_id self.workflows[wf_id]["status"] = wfs.DONE self.workflows[wf_id]["progress"] = "100%" - logger.debug(f"Status updated: wf_id={wf_id}, status={wfs.DONE}") + logger.debug(f"[WFP] Status updated: wf_id={wf_id}, status={wfs.DONE} -- {domain_event.__dict__}") self.update_view_table(wf_id) + EngineManager.commit() @policy.register(WorkflowRun.WorkflowFailed) def _(self, domain_event, process_event): wf_id = domain_event.originator_id error = domain_event.error_message self.workflows[wf_id]["status"] = wfs.FAILED - logger.debug(f"Status updated: wf_id={wf_id}, status={wfs.FAILED}") + logger.debug(f"[WFP] Status updated: wf_id={wf_id}, status={wfs.FAILED} -- {domain_event.__dict__}") self.update_view_table(wf_id) + EngineManager.commit() @policy.register(WorkflowRun.TaskAdded) def _(self, domain_event, process_event): @@ -218,7 +224,8 @@ def _(self, domain_event, process_event): self.tasks[task_id]["workflow_id"] = wf_id if wf_id in self.workflows: self.workflows[wf_id]["task"] = self.tasks[task_id]["task_name"] - logger.debug(f"Task added: task_id={task_id}, wf_id={wf_id}") + logger.debug(f"[WFP] Task added: task_id={task_id}, wf_id={wf_id} -- {domain_event.__dict__}") + EngineManager.commit() @policy.register(Task.TaskCreated) def _(self, domain_event, process_event): @@ -231,7 +238,8 @@ def _(self, domain_event, process_event): "workflow_id": None, "progress": None } - logger.debug(f"Task created: task_id={task_id}, task_name={task_name}") + logger.debug(f"[WFP] Task created: task_id={task_id}, task_name={task_name} -- {domain_event.__dict__}") + EngineManager.commit() @policy.register(Task.StatusUpdated) def _(self, domain_event, process_event): @@ -283,8 +291,9 @@ def _(self, domain_event, process_event): # Update the workflow status self.workflows[wf_id]["status"] = workflow_status self.workflows[wf_id]["progress"] = workflow_prog - logger.debug(f"Status updated: wf_id={wf_id}, task_id={task_id}, status={workflow_status}") + logger.debug(f"[WFP] Status updated: wf_id={wf_id}, task_id={task_id}, status={workflow_status} -- {domain_event.__dict__}") self.update_view_table(wf_id) + EngineManager.commit() @policy.register(Task.ProgressUpdated) def _(self, domain_event, process_event): @@ -297,8 +306,9 @@ def _(self, domain_event, process_event): wf_id = self.tasks[task_id]["workflow_id"] if wf_id and wf_id in self.workflows: self.workflows[wf_id]["task_progress"] = progress - logger.debug(f"(Task) Progress updated: wf_id={wf_id}, progress={progress}") - self.update_view_table(wf_id) + logger.debug(f"[WFP] (Task) Progress updated: wf_id={wf_id}, progress={progress} -- {domain_event.__dict__}") + self.update_view_table(wf_id) + EngineManager.commit() def update_view_table(self, wf_id): """Update the view table with new workflow status, progress, user, and group.""" @@ -317,11 +327,10 @@ def update_view_table(self, wf_id): ) session.merge(new_workflow_progress) session.commit() - logger.debug(f"Inserted wf progress in view table: wf_id={wf_id} wf_info={workflow_info}") + logger.debug(f"[WFP] Inserted wf progress in view table: wf_id={wf_id} wf_info={workflow_info}") except IntegrityError: session.rollback() - logger.error(f"Failed to insert/update wf progress in view table: wf_id={wf_id} 
wf_info={workflow_info}") - + logger.error(f"[WFP] Failed to insert/update wf progress in view table: wf_id={wf_id} wf_info={workflow_info}") class JobProgress(ProcessApplication): @@ -344,7 +353,8 @@ def _(self, domain_event, process_event): # Track task to job mapping self.task_to_job[task_id] = job_id - logger.debug(f"JobId added: job_id={job_id}, task_id={task_id}") + logger.debug(f"[JP] JobId added: job_id={job_id}, task_id={task_id} -- {domain_event.__dict__}") + EngineManager.commit() @policy.register(Task.StatusUpdated) def _(self, domain_event, process_event): @@ -359,9 +369,10 @@ def _(self, domain_event, process_event): else: self.job_status[job_id] = {"status": status, "progress": None} - logger.debug(f"Status updated: job_id={job_id}, status={status}") + logger.debug(f"[JP] Status updated: job_id={job_id}, status={status} -- {domain_event.__dict__}") # Update view table self.update_view_table(job_id) + EngineManager.commit() @policy.register(Task.ProgressUpdated) def _(self, domain_event, process_event): @@ -376,9 +387,10 @@ def _(self, domain_event, process_event): else: self.job_status[job_id] = {"status": "UNKNOWN", "progress": progress} - logger.debug(f"Progress updated: job_id={job_id}, progress={progress}") + logger.debug(f"[JP] Progress updated: job_id={job_id}, progress={progress} -- {domain_event.__dict__}") # Update view table self.update_view_table(job_id) + EngineManager.commit() def update_view_table(self, job_id): """Update the view table with new job status and progress information.""" @@ -392,10 +404,10 @@ def update_view_table(self, job_id): ) session.merge(new_job_progress) # Use merge to insert or update session.commit() - logger.debug(f"Inserted/Updated job progress in view table: job_id={job_id}, status={job_info['status']}, progress={job_info['progress']}") + logger.debug(f"[JP] Inserted/Updated job progress in view table: job_id={job_id}, status={job_info['status']}, progress={job_info['progress']}") except IntegrityError: session.rollback() - logger.error(f"Failed to insert/update job progress in view table: job_id={job_id}") + logger.error(f"[JP] Failed to insert/update job progress in view table: job_id={job_id}") # @event.listens_for(Engine, "before_cursor_execute") @@ -426,7 +438,8 @@ def _(self, domain_event, process_event): # Track workflow self.workflows[wf_id] = {"user": user, "group": group} - logger.debug(f"Workflow initiated: wf_id={wf_id}, user={user}, group={group}") + logger.debug(f"[WFA] Workflow initiated: wf_id={wf_id}, user={user}, group={group} -- {domain_event.__dict__}") + EngineManager.commit() @policy.register(WorkflowRun.TaskAdded) def _(self, domain_event, process_event): @@ -441,7 +454,8 @@ def _(self, domain_event, process_event): # In case TaskAdded arrives before TaskCreated (unlikely but possible) self.tasks[task_id] = {"wf_id": wf_id} - logger.debug(f"Task added: task_id={task_id}, wf_id={wf_id}") + logger.debug(f"[WFA] Task added: task_id={task_id}, wf_id={wf_id} -- {domain_event.__dict__}") + EngineManager.commit() @policy.register(Task.TaskCreated) def _(self, domain_event, process_event): @@ -468,8 +482,9 @@ def _(self, domain_event, process_event): "status": "CREATED" } - logger.debug(f"Task created: task_id={task_id}, task_name={task_name}, timestamp={timestamp_created}") + logger.debug(f"[WFA] Task created: task_id={task_id}, task_name={task_name}, timestamp={timestamp_created} -- {domain_event.__dict__}") self.update_view_table(task_id) + EngineManager.commit() @policy.register(Task.StatusUpdated) def _(self, 
domain_event, process_event): @@ -480,8 +495,9 @@ def _(self, domain_event, process_event): # Update task with status if task_id in self.tasks: self.tasks[task_id]["status"] = status - logger.debug(f"Task status updated: task_id={task_id}, status={status}") + logger.debug(f"[WFA] Task status updated: task_id={task_id}, status={status} -- {domain_event.__dict__}") self.update_view_table(task_id) + EngineManager.commit() @policy.register(Task.TaskCompleted) def _(self, domain_event, process_event): @@ -492,8 +508,9 @@ def _(self, domain_event, process_event): # Update task with end time if task_id in self.tasks: self.tasks[task_id]["end_time"] = timestamp_completed - logger.debug(f"Task completed: task_id={task_id}, end_time={timestamp_completed}") + logger.debug(f"[WFA] Task completed: task_id={task_id}, end_time={timestamp_completed} -- {domain_event.__dict__}") self.update_view_table(task_id) + EngineManager.commit() @policy.register(Task.TaskFailed) def _(self, domain_event, process_event): @@ -506,8 +523,9 @@ def _(self, domain_event, process_event): if task_id in self.tasks: self.tasks[task_id]["end_time"] = timestamp_failed self.tasks[task_id]["error_type"] = error_message - logger.debug(f"Task failed: task_id={task_id}, end_time={timestamp_failed}, error={error_message}") + logger.debug(f"[WFA] Task failed: task_id={task_id}, end_time={timestamp_failed}, error={error_message} -- {domain_event.__dict__}") self.update_view_table(task_id) + EngineManager.commit() def update_view_table(self, task_id): """Update the view table with new task execution information.""" @@ -554,11 +572,11 @@ def update_view_table(self, task_id): session.add(new_task_execution) session.commit() - logger.debug(f"Updated/Inserted task execution into view table: task_id={task_id}, task_name={task_info.get('task_name')}") + logger.debug(f"[WFA] Updated/Inserted task execution into view table: task_id={task_id}, task_name={task_info.get('task_name')}") except IntegrityError as e: session.rollback() - logger.error(f"Failed to insert/update task execution into view table: task_id={task_id}, error={str(e)}") - logger.debug(f"Task info: {task_info}") + logger.error(f"[WFA] Failed to insert/update task execution into view table: task_id={task_id}, error={str(e)}") + logger.debug(f"[WFA] Task info: {task_info}") def get_task_counts(self, user=None, group=None): """Retrieve task execution counts grouped by task name and version. 
@@ -588,7 +606,7 @@ def get_task_counts(self, user=None, group=None): (task_name, task_version): count for task_name, task_version, count in task_counts } - logger.debug(f"Retrieved task counts: {result}") + logger.debug(f"[WFA] Retrieved task counts: {result}") return result def get_average_task_duration(self, user=None, group=None): @@ -622,7 +640,7 @@ def get_average_task_duration(self, user=None, group=None): (task_name, task_version): avg_duration for task_name, task_version, avg_duration in task_durations } - logger.debug(f"Retrieved average task durations: {result}") + logger.debug(f"[WFA] Retrieved average task durations: {result}") return result def get_task_failures(self, user=None, group=None): @@ -657,7 +675,7 @@ def get_task_failures(self, user=None, group=None): result[key] = [] result[key].append(error_type) - logger.debug(f"Retrieved task failures: {result}") + logger.debug(f"[WFA] Retrieved task failures: {result}") return result def get_task_usage_over_time(self, task_name, user=None, group=None): @@ -689,5 +707,5 @@ def get_task_usage_over_time(self, task_name, user=None, group=None): date: count for date, count in usage_over_time } - logger.debug(f"Retrieved task usage over time for {task_name}: {result}") + logger.debug(f"[WFA] Retrieved task usage over time for {task_name}: {result}") return result From 48007250a3c14d171d77080827fbef52ba9dc379 Mon Sep 17 00:00:00 2001 From: Luik Date: Tue, 8 Oct 2024 14:28:44 +0200 Subject: [PATCH 18/24] Use modified eventsourcing_alchemy lib --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f874f3f..3bd31bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,10 +29,10 @@ dependencies = [ "fabric==3.1.0", "paramiko==3.4.0", "importlib_resources>=5.4.0", - "eventsourcing[crypto]==9.2.22", + "eventsourcing==9.3", "sqlalchemy==2.0.32", "psycopg2==2.9.9", - "eventsourcing_sqlalchemy==0.7" # requires py 3.8 + "eventsourcing_sqlalchemy @ git+https://github.com/Cellular-Imaging-Amsterdam-UMC/eventsourcing-sqlalchemy@main#egg=eventsourcing_sqlalchemy" # Specify the Git repo ] [tool.setuptools.packages] From 3458600591a1bbde5185f227d25a3b3d2cde149c Mon Sep 17 00:00:00 2001 From: Luik Date: Tue, 8 Oct 2024 16:21:42 +0200 Subject: [PATCH 19/24] Update test to cover workflowProgress view --- biomero/views.py | 9 -- tests/unit/test_eventsourcing.py | 206 ++++++++++++++++++++++++++++++- 2 files changed, 202 insertions(+), 13 deletions(-) diff --git a/biomero/views.py b/biomero/views.py index 93935b4..1f0170e 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -12,20 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os from eventsourcing.system import ProcessApplication from eventsourcing.dispatch import singledispatchmethod -from eventsourcing.utils import get_topic -from uuid import NAMESPACE_URL, UUID, uuid5 -from typing import Any, Dict, List import logging -from sqlalchemy import create_engine, text, Column, Integer, String, URL, DateTime, Float -from sqlalchemy.orm import sessionmaker, declarative_base, scoped_session from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import func -from sqlalchemy.dialects.postgresql import UUID as PGUUID -from sqlalchemy import event -from sqlalchemy.engine import Engine from biomero.eventsourcing import WorkflowRun, Task from biomero.database import EngineManager, JobView, TaskExecution, JobProgressView, WorkflowProgressView from biomero.constants import workflow_status as wfs diff --git a/tests/unit/test_eventsourcing.py b/tests/unit/test_eventsourcing.py index 34802b4..231ccb4 100644 --- a/tests/unit/test_eventsourcing.py +++ b/tests/unit/test_eventsourcing.py @@ -1,12 +1,12 @@ -from datetime import datetime, timedelta, timezone -import json +from datetime import datetime, timezone import os -from unittest.mock import Mock, patch +from unittest.mock import patch import uuid import pytest from biomero.eventsourcing import Task, WorkflowTracker -from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics +from biomero.views import JobAccounting, JobProgress, WorkflowAnalytics, WorkflowProgress from biomero.database import EngineManager, JobProgressView, JobView, TaskExecution +from biomero.constants import workflow_status as wfs from uuid import UUID import logging from eventsourcing.system import System, SingleThreadedRunner @@ -71,6 +71,21 @@ def workflow_tracker_and_job_progress(): # Cleanup after tests runner.stop() + + +@pytest.fixture +def workflow_tracker_and_workflow_progress(): + """Fixture to initialize System and SingleThreadedRunner with WorkflowTracker and JobProgress.""" + # Create a System instance with the necessary components + system = System(pipes=[[WorkflowTracker, WorkflowProgress]]) + runner = SingleThreadedRunner(system) + runner.start() + + # Yield the instances of WorkflowTracker and JobProgress + yield runner.get(WorkflowTracker), runner.get(WorkflowProgress) + + # Cleanup after tests + runner.stop() @pytest.fixture @@ -713,6 +728,189 @@ def test_job_acc_get_jobs_all(workflow_tracker_and_job_accounting): assert user_jobs == {1: [700, 800], 2: [900]} +def test_workflow_progress_workflow_initiated(workflow_tracker_and_workflow_progress): + # GIVEN a WorkflowTracker event system and workflow progress listener + workflow_tracker: WorkflowTracker + workflow_progress: WorkflowProgress + workflow_tracker, workflow_progress = workflow_tracker_and_workflow_progress + + # WHEN a workflow is initiated + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + + # THEN verify internal state in WorkflowProgress + assert workflow_id in workflow_progress.workflows + assert workflow_progress.workflows[workflow_id]["status"] == wfs.INITIALIZING + assert workflow_progress.workflows[workflow_id]["progress"] == "0%" + assert workflow_progress.workflows[workflow_id]["user"] == 1 + assert workflow_progress.workflows[workflow_id]["group"] == 2 + + +def test_workflow_progress_workflow_completed(workflow_tracker_and_workflow_progress): + # GIVEN a WorkflowTracker event system and workflow progress listener + workflow_tracker: WorkflowTracker + workflow_progress: WorkflowProgress 
+ workflow_tracker, workflow_progress = workflow_tracker_and_workflow_progress + + # WHEN a workflow is initiated + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + + # Complete the workflow + workflow_tracker.complete_workflow(workflow_id) + + # THEN verify internal state in WorkflowProgress + assert workflow_id in workflow_progress.workflows + assert workflow_progress.workflows[workflow_id]["status"] == wfs.DONE + assert workflow_progress.workflows[workflow_id]["progress"] == "100%" + + +def test_workflow_progress_workflow_failed(workflow_tracker_and_workflow_progress): + # GIVEN a WorkflowTracker event system and workflow progress listener + workflow_tracker: WorkflowTracker + workflow_progress: WorkflowProgress + workflow_tracker, workflow_progress = workflow_tracker_and_workflow_progress + + # WHEN a workflow is initiated + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + + # Mark the workflow as failed + error_message = "An error occurred" + workflow_tracker.fail_workflow(workflow_id, error_message) + + # THEN verify internal state in WorkflowProgress + assert workflow_id in workflow_progress.workflows + assert workflow_progress.workflows[workflow_id]["status"] == wfs.FAILED + + +def test_workflow_progress_task_added(workflow_tracker_and_workflow_progress): + # GIVEN a WorkflowTracker event system and workflow progress listener + workflow_tracker: WorkflowTracker + workflow_progress: WorkflowProgress + workflow_tracker, workflow_progress = workflow_tracker_and_workflow_progress + + # WHEN a workflow is initiated + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + + # Add a task to the workflow + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + + # THEN verify internal state in WorkflowProgress + assert task_id in workflow_progress.tasks + assert workflow_progress.tasks[task_id]["workflow_id"] == workflow_id + assert workflow_progress.workflows[workflow_id]["task"] == "task1" + + +def test_workflow_progress_task_status_updated(workflow_tracker_and_workflow_progress): + # GIVEN a WorkflowTracker event system and workflow progress listener + workflow_tracker: WorkflowTracker + workflow_progress: WorkflowProgress + workflow_tracker, workflow_progress = workflow_tracker_and_workflow_progress + + # WHEN a workflow is initiated and a task is added + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + + # Update the task status + status = "InProgress" + workflow_tracker.update_task_status(task_id, status) + + # THEN verify internal state in WorkflowProgress + assert task_id in workflow_progress.tasks + assert workflow_progress.tasks[task_id]["workflow_id"] == workflow_id + assert workflow_progress.workflows[workflow_id]["status"] == wfs.JOB_STATUS + status + + +def test_workflow_progress_task_progress_updated(workflow_tracker_and_workflow_progress): + # GIVEN a WorkflowTracker event system and workflow progress listener + workflow_tracker: WorkflowTracker + workflow_progress: WorkflowProgress + workflow_tracker, workflow_progress = workflow_tracker_and_workflow_progress + + # WHEN a workflow is initiated and a task is added + workflow_id = workflow_tracker.initiate_workflow( + "Test 
Workflow", "Test Description", user=1, group=2) + task_id = workflow_tracker.add_task_to_workflow( + workflow_id, "task1", "v1", {"foo": "bar"}, {"bar": "baz"}) + + # Update the task progress + progress = "25%" + workflow_tracker.update_task_progress(task_id, progress) + + # THEN verify internal state in WorkflowProgress + assert workflow_progress.tasks[task_id]["progress"] == progress + assert workflow_progress.workflows[workflow_id]["task_progress"] == progress + + +def test_workflow_progress_all_statuses(workflow_tracker_and_workflow_progress): + # GIVEN a WorkflowTracker event system and workflow progress listener + workflow_tracker: WorkflowTracker + workflow_progress: WorkflowProgress + workflow_tracker, workflow_progress = workflow_tracker_and_workflow_progress + + # WHEN a workflow is initiated + workflow_id = workflow_tracker.initiate_workflow( + "Test Workflow", "Test Description", user=1, group=2) + + # Add tasks with names that will trigger all branches + task_names = [ + ('_SLURM_Image_Transfer.py', 'InProgress'), + ('convert_image', 'InProgress'), # should match the convert_ condition + ('SLURM_Get_Results.py', 'InProgress'), + ('SLURM_Run_Workflow.py', 'InProgress'), + ('unknown_task', 'InProgress') + ] + + for task_name, status in task_names: + task_id = workflow_tracker.add_task_to_workflow(workflow_id, task_name, "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_tracker.update_task_status(task_id, status) # Update status to trigger logic + + # Check expected status and progress after each update + if task_name == '_SLURM_Image_Transfer.py': + expected_status = wfs.TRANSFERRING + expected_progress = "5%" + elif task_name.startswith('convert_'): + expected_status = wfs.CONVERTING + expected_progress = "25%" + elif task_name == 'SLURM_Get_Results.py': + expected_status = wfs.RETRIEVING + expected_progress = "90%" + elif task_name == 'SLURM_Run_Workflow.py': + expected_status = wfs.RUNNING + expected_progress = "50%" + else: + expected_status = wfs.JOB_STATUS + status + expected_progress = "50%" + + # Validate after each task status update + assert workflow_progress.workflows[workflow_id]["status"] == expected_status + assert workflow_progress.workflows[workflow_id]["progress"] == expected_progress + + # Introduce task progress for interpolation + # Assume a task that updates its progress + task_id = workflow_tracker.add_task_to_workflow(workflow_id, 'some_task', "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_tracker.update_task_progress(task_id, 43) # Simulate a progress update of 43% + + # Manually set a previous task's progress to trigger interpolation logic + previous_task_id = workflow_tracker.add_task_to_workflow(workflow_id, 'previous_task', "v1", {"foo": "bar"}, {"bar": "baz"}) + workflow_tracker.update_task_progress(previous_task_id, "43%") # Simulate a previous progress of 50% + + # Trigger the status update for the last task + workflow_tracker.update_task_status(task_id, 'InProgress') + + # Check the workflow's updated progress using interpolation + expected_interpolated_progress = "67.2%" + + # Assert final workflow status and interpolated progress + assert workflow_progress.workflows[workflow_id]["status"] == wfs.JOB_STATUS + 'InProgress' # Final expected status + assert workflow_progress.workflows[workflow_id]["progress"] == expected_interpolated_progress # Check interpolated progress + + def test_job_progress_job_id_added(workflow_tracker_and_job_progress): # GIVEN a WorkflowTracker event system and job progress listener workflow_tracker: WorkflowTracker From 
f411abd12549329ebc86c632a64a489eaf6e261d Mon Sep 17 00:00:00 2001
From: Luik
Date: Tue, 8 Oct 2024 16:24:06 +0200
Subject: [PATCH 20/24] cleanup test

---
 tests/unit/test_eventsourcing.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/tests/unit/test_eventsourcing.py b/tests/unit/test_eventsourcing.py
index 231ccb4..695564c 100644
--- a/tests/unit/test_eventsourcing.py
+++ b/tests/unit/test_eventsourcing.py
@@ -894,12 +894,8 @@ def test_workflow_progress_all_statuses(workflow_tracker_and_workflow_progress):
     # Introduce task progress for interpolation
     # Assume a task that updates its progress
     task_id = workflow_tracker.add_task_to_workflow(workflow_id, 'some_task', "v1", {"foo": "bar"}, {"bar": "baz"})
-    workflow_tracker.update_task_progress(task_id, 43)  # Simulate a progress update of 43%
+    workflow_tracker.update_task_progress(task_id, "43%")  # Simulate a progress update of 43%
 
-    # Manually set a previous task's progress to trigger interpolation logic
-    previous_task_id = workflow_tracker.add_task_to_workflow(workflow_id, 'previous_task', "v1", {"foo": "bar"}, {"bar": "baz"})
-    workflow_tracker.update_task_progress(previous_task_id, "43%")  # Simulate a previous progress of 50%
-
     # Trigger the status update for the last task
     workflow_tracker.update_task_status(task_id, 'InProgress')
 

From 655b58020dd8e45247bc2ee6024c41379d8b5b34 Mon Sep 17 00:00:00 2001
From: Luik
Date: Wed, 9 Oct 2024 15:13:56 +0200
Subject: [PATCH 21/24] pin eventsourcing-sqlalchemy to main commit

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 3bd31bd..cb38e72 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,7 @@ dependencies = [
     "eventsourcing==9.3",
     "sqlalchemy==2.0.32",
     "psycopg2==2.9.9",
-    "eventsourcing_sqlalchemy @ git+https://github.com/Cellular-Imaging-Amsterdam-UMC/eventsourcing-sqlalchemy@main#egg=eventsourcing_sqlalchemy" # Specify the Git repo
+    "eventsourcing_sqlalchemy @ git+https://github.com/pyeventsourcing/eventsourcing-sqlalchemy@105a48a9ffa5e5573b24f3e19a48d238135ac91d#egg=eventsourcing_sqlalchemy" # Pinned to a specific commit
 ]
 
 [tool.setuptools.packages]

From d8ba8538a35c5105113b3ccd21f2576cb5715587 Mon Sep 17 00:00:00 2001
From: Luik
Date: Mon, 21 Oct 2024 17:47:14 +0200
Subject: [PATCH 22/24] Pin to release 0.9 instead of commit

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index cb38e72..436a85d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,7 @@ dependencies = [
     "eventsourcing==9.3",
     "sqlalchemy==2.0.32",
     "psycopg2==2.9.9",
-    "eventsourcing_sqlalchemy @ git+https://github.com/pyeventsourcing/eventsourcing-sqlalchemy@105a48a9ffa5e5573b24f3e19a48d238135ac91d#egg=eventsourcing_sqlalchemy" # Pinned to a specific commit
+    "eventsourcing_sqlalchemy==0.9"
 ]
 
 [tool.setuptools.packages]

From 8a9e71781f68873fdc2c2009f60c9265c6f2123d Mon Sep 17 00:00:00 2001
From: Luik
Date: Mon, 18 Nov 2024 15:58:52 +0100
Subject: [PATCH 23/24] Turn prints into logs

---
 biomero/slurm_client.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/biomero/slurm_client.py b/biomero/slurm_client.py
index af581cd..7172d0e 100644
--- a/biomero/slurm_client.py
+++ b/biomero/slurm_client.py
@@ -918,12 +918,11 @@ def from_config(cls, configfile: str = '',
                 slurm_model_jobs[k[:-len(suffix_job)]] = v
                 slurm_model_jobs_params[k[:-len(suffix_job)]] = []
             elif job_param_match:
-                print(f"Match: {slurm_model_jobs_params}")
slurm_model_jobs_params[job_param_match.group(1)].append( f" --{job_param_match.group(2)}={v}") - print(f"Added: {slurm_model_jobs_params}") else: slurm_model_paths[k] = v + logger.info(f"Using job params: {slurm_model_jobs_params}") slurm_script_path = configs.get( "SLURM", "slurm_script_path", From baf5ae3cef20a8f67c781b366b017f59edea0418 Mon Sep 17 00:00:00 2001 From: Luik Date: Thu, 28 Nov 2024 13:53:03 +0100 Subject: [PATCH 24/24] Add a jobid-to-task query --- biomero/views.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/biomero/views.py b/biomero/views.py index 1f0170e..fa6fdb3 100644 --- a/biomero/views.py +++ b/biomero/views.py @@ -151,6 +151,25 @@ def get_jobs(self, user=None, group=None): logger.debug(f"Retrieved jobs for user={user} and group={group}: {result}") return result + def get_task_id(self, job_id): + """ + Retrieve the task ID associated with a given job ID. + + Parameters: + - job_id (int): The job ID (slurm_job_id) to look up. + + Returns: + - UUID: The task ID associated with the job ID, or None if not found. + """ + with EngineManager.get_session() as session: + try: + task_id = session.query(JobView.task_id).filter(JobView.slurm_job_id == job_id).one_or_none() + logger.debug(f"Retrieved task_id={task_id[0] if task_id else None} for job_id={job_id}") + return task_id[0] if task_id else None + except Exception as e: + logger.error(f"Failed to retrieve task_id for job_id={job_id}: {e}") + return None + class WorkflowProgress(ProcessApplication): def __init__(self, *args, **kwargs):
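With get_task_id in place, a Slurm job id reported by the cluster can be resolved to the BIOMERO task that submitted it without replaying the event store. A usage sketch, assuming a SlurmClient configured with workflow tracking and job accounting enabled and a job view table that has already been populated (the job id is a placeholder):

    from biomero.slurm_client import SlurmClient

    with SlurmClient.from_config() as slurm_client:
        # Placeholder Slurm job id; in practice this comes from a submitted job.
        task_id = slurm_client.jobAccounting.get_task_id(job_id=12345)
        print(f"Slurm job 12345 was submitted by BIOMERO task {task_id}")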