diff --git a/CHANGES b/CHANGES
index 11fe15a4..103526c9 100644
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,14 @@ The rules for this file:
   * accompany each entry with github issue/PR number (Issue #xyz)
   * release numbers follow "Semantic Versioning" https://semver.org
 
+**/**/**** xiki-tempula
+
+ * 2.5.0
+
+Enhancements
+  - Parallelise read and preprocess for ABFE workflow. (PR #371)
+
+
 09/19/2024 orbeckst, jaclark5
 
  * 2.4.1
diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml
index bef524b7..abe08740 100644
--- a/devtools/conda-envs/test_env.yaml
+++ b/devtools/conda-envs/test_env.yaml
@@ -11,6 +11,7 @@ dependencies:
 - matplotlib>=3.7
 - loguru
 - pyarrow
+- joblib
 
 # Testing
 - pytest
diff --git a/docs/workflows/alchemlyb.workflows.ABFE.rst b/docs/workflows/alchemlyb.workflows.ABFE.rst
index a8dc409a..f6933253 100644
--- a/docs/workflows/alchemlyb.workflows.ABFE.rst
+++ b/docs/workflows/alchemlyb.workflows.ABFE.rst
@@ -153,6 +153,43 @@ to the data generated at each stage of the analysis. ::
 
     >>> # Convergence analysis
     >>> workflow.check_convergence(10, dF_t='dF_t.pdf')
 
+Parallelisation of Data Reading and Decorrelation
+-------------------------------------------------
+
+The estimation step of the workflow is parallelised using JAX. The reading
+and decorrelation stages, however, can be parallelised using `joblib`, by
+passing the number of jobs to run in parallel via the `n_jobs` parameter of
+the following methods:
+
+- :meth:`~alchemlyb.workflows.ABFE.read`
+- :meth:`~alchemlyb.workflows.ABFE.preprocess`
+
+To enable parallel execution, specify the `n_jobs` parameter; setting
+`n_jobs=-1` uses all available CPUs. ::
+
+    >>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
+    >>>                 prefix='dhdl', suffix='xvg', T=298, outdirectory='./')
+    >>> workflow.read(n_jobs=-1)
+    >>> workflow.preprocess(n_jobs=-1)
+
+In fully automated mode, `n_jobs=-1` can be passed directly to the
+:meth:`~alchemlyb.workflows.ABFE.run` method, which implicitly parallelises
+the reading and decorrelation stages. ::
+
+    >>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
+    >>>                 prefix='dhdl', suffix='xvg', T=298, outdirectory='./')
+    >>> workflow.run(n_jobs=-1)
+
+While the default `joblib` settings are suitable for most environments, the
+parallelisation backend can be customised to match your infrastructure. For
+example, the threading backend is selected as follows. ::
+
+    >>> import joblib
+    >>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
+    >>>                 prefix='dhdl', suffix='xvg', T=298, outdirectory='./')
+    >>> with joblib.parallel_config(backend="threading"):
+    >>>     workflow.run(n_jobs=-1)
+
 API Reference
 -------------
 .. autoclass:: alchemlyb.workflows.ABFE
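Both `read` and `preprocess` dispatch their per-file work through joblib's
`Parallel`/`delayed` pair, the same machinery the documentation above exposes
via `n_jobs`. A minimal, self-contained sketch of that pattern, not part of
the patch itself (the `parse` helper and the file names are hypothetical
stand-ins for alchemlyb's parsers and dhdl files)::

    import joblib

    def parse(path, T):
        # stand-in for a real parser such as extract_u_nk(path, T)
        return (path, T)

    files = ["dhdl_00.xvg", "dhdl_01.xvg"]
    # n_jobs=-1 would use all available CPUs; results are returned
    # in the same order as the inputs
    results = joblib.Parallel(n_jobs=2)(
        joblib.delayed(parse)(f, 310) for f in files
    )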
diff --git a/environment.yml b/environment.yml
index 95512668..ced2721e 100644
--- a/environment.yml
+++ b/environment.yml
@@ -11,3 +11,4 @@ dependencies:
 - pyarrow
 - matplotlib>=3.7
 - loguru
+- joblib
diff --git a/pyproject.toml b/pyproject.toml
index 7088f10a..e88de469 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,7 @@ dependencies = [
     "matplotlib>=3.7",
     "loguru",
     "pyarrow",
+    "joblib",
 ]
diff --git a/src/alchemlyb/tests/test_workflow_ABFE.py b/src/alchemlyb/tests/test_workflow_ABFE.py
index ed25646d..a679641a 100644
--- a/src/alchemlyb/tests/test_workflow_ABFE.py
+++ b/src/alchemlyb/tests/test_workflow_ABFE.py
@@ -1,9 +1,11 @@
 import os
 
 import numpy as np
+import pandas as pd
 import pytest
 from alchemtest.amber import load_bace_example
 from alchemtest.gmx import load_ABFE
+import joblib
 
 import alchemlyb.parsing.amber
 from alchemlyb.workflows.abfe import ABFE
@@ -88,7 +90,11 @@ def test_single_estimator(self, workflow, monkeypatch):
         monkeypatch.setattr(workflow, "dHdl_sample_list", [])
         monkeypatch.setattr(workflow, "estimator", dict())
         workflow.run(
-            uncorr=None, estimators="MBAR", overlap=None, breakdown=True, forwrev=None
+            uncorr=None,
+            estimators="MBAR",
+            overlap=None,
+            breakdown=True,
+            forwrev=None,
         )
         assert "MBAR" in workflow.estimator
 
@@ -96,7 +102,11 @@ def test_no_forwrev(self, workflow, monkeypatch, forwrev):
         monkeypatch.setattr(workflow, "convergence", None)
         workflow.run(
-            uncorr=None, estimators=None, overlap=None, breakdown=None, forwrev=forwrev
+            uncorr=None,
+            estimators=None,
+            overlap=None,
+            breakdown=None,
+            forwrev=forwrev,
         )
         assert workflow.convergence is None
@@ -445,3 +455,58 @@ def test_summary(self, workflow):
         """Test if the summary is right."""
         summary = workflow.generate_result()
         assert np.isclose(summary["BAR"]["Stages"]["TOTAL"], 1.40405980473, 0.1)
+
+
+class TestParallel:
+    @pytest.fixture(scope="class")
+    def workflow(self, tmp_path_factory):
+        outdir = tmp_path_factory.mktemp("out")
+        (outdir / "dhdl_00.xvg").symlink_to(load_ABFE()["data"]["complex"][0])
+        (outdir / "dhdl_01.xvg").symlink_to(load_ABFE()["data"]["complex"][1])
+        workflow = ABFE(
+            units="kcal/mol",
+            software="GROMACS",
+            dir=str(outdir),
+            prefix="dhdl",
+            suffix="xvg",
+            T=310,
+        )
+        workflow.read()
+        workflow.preprocess()
+        return workflow
+
+    @pytest.fixture(scope="class")
+    def parallel_workflow(self, tmp_path_factory):
+        outdir = tmp_path_factory.mktemp("out")
+        (outdir / "dhdl_00.xvg").symlink_to(load_ABFE()["data"]["complex"][0])
+        (outdir / "dhdl_01.xvg").symlink_to(load_ABFE()["data"]["complex"][1])
+        workflow = ABFE(
+            units="kcal/mol",
+            software="GROMACS",
+            dir=str(outdir),
+            prefix="dhdl",
+            suffix="xvg",
+            T=310,
+        )
+        with joblib.parallel_config(backend="threading"):
+            # The default backend is "loky", which is more robust but does
+            # not play well with pytest; "loky" works fine outside pytest.
+            workflow.read(n_jobs=2)
+            workflow.preprocess(n_jobs=2)
+        return workflow
+
+    def test_read(self, workflow, parallel_workflow):
+        pd.testing.assert_frame_equal(
+            workflow.u_nk_list[0], parallel_workflow.u_nk_list[0]
+        )
+        pd.testing.assert_frame_equal(
+            workflow.u_nk_list[1], parallel_workflow.u_nk_list[1]
+        )
+
+    def test_preprocess(self, workflow, parallel_workflow):
+        pd.testing.assert_frame_equal(
+            workflow.u_nk_sample_list[0], parallel_workflow.u_nk_sample_list[0]
+        )
+        pd.testing.assert_frame_equal(
+            workflow.u_nk_sample_list[1], parallel_workflow.u_nk_sample_list[1]
+        )
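The TestParallel class above checks that serial and parallel execution produce
identical data frames. The same consistency check can be scripted outside
pytest; a sketch assuming the alchemtest ABFE dataset is installed and that
pointing `dir` at the dataset's own directory is acceptable for a quick check::

    import os

    import joblib
    import pandas as pd
    from alchemtest.gmx import load_ABFE
    from alchemlyb.workflows.abfe import ABFE

    # directory that holds the dhdl*.xvg files of the ABFE test dataset
    data_dir = os.path.dirname(load_ABFE()["data"]["complex"][0])

    serial = ABFE(units="kcal/mol", software="GROMACS", dir=data_dir,
                  prefix="dhdl", suffix="xvg", T=310)
    serial.read()

    parallel = ABFE(units="kcal/mol", software="GROMACS", dir=data_dir,
                    prefix="dhdl", suffix="xvg", T=310)
    with joblib.parallel_config(backend="threading"):
        parallel.read(n_jobs=2)

    # parallel reading must not change the parsed data
    for a, b in zip(serial.u_nk_list, parallel.u_nk_list):
        pd.testing.assert_frame_equal(a, b)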
diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py
index 751220b4..b47b2618 100644
--- a/src/alchemlyb/workflows/abfe.py
+++ b/src/alchemlyb/workflows/abfe.py
@@ -6,6 +6,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+import joblib
 from loguru import logger
 
 from .base import WorkflowBase
@@ -115,7 +116,7 @@ def __init__(
         else:
             raise NotImplementedError(f"{software} parser not found.")
 
-    def read(self, read_u_nk=True, read_dHdl=True):
+    def read(self, read_u_nk=True, read_dHdl=True, n_jobs=1):
         """Read the u_nk and dHdl data from the
         :attr:`~alchemlyb.workflows.ABFE.file_list`
 
@@ -125,6 +126,9 @@ def read(self, read_u_nk=True, read_dHdl=True, n_jobs=1):
             Whether to read the u_nk.
         read_dHdl : bool
             Whether to read the dHdl.
+        n_jobs : int
+            Number of parallel workers to use for reading the data.
+            (-1 means using all available CPUs)
 
         Attributes
         ----------
@@ -136,29 +140,43 @@
         self.u_nk_sample_list = None
         self.dHdl_sample_list = None
 
-        u_nk_list = []
-        dHdl_list = []
-        for file in self.file_list:
-            if read_u_nk:
+        if read_u_nk:
+
+            def extract_u_nk(_extract_u_nk, file, T):
                 try:
-                    u_nk = self._extract_u_nk(file, T=self.T)
+                    u_nk = _extract_u_nk(file, T)
                     logger.info(f"Reading {len(u_nk)} lines of u_nk from {file}")
-                    u_nk_list.append(u_nk)
+                    return u_nk
                 except Exception as exc:
                     msg = f"Error reading u_nk from {file}."
                     logger.error(msg)
                     raise OSError(msg) from exc
 
-            if read_dHdl:
+            u_nk_list = joblib.Parallel(n_jobs=n_jobs)(
+                joblib.delayed(extract_u_nk)(self._extract_u_nk, file, self.T)
+                for file in self.file_list
+            )
+        else:
+            u_nk_list = []
+
+        if read_dHdl:
+
+            def extract_dHdl(_extract_dHdl, file, T):
                 try:
-                    dhdl = self._extract_dHdl(file, T=self.T)
+                    dhdl = _extract_dHdl(file, T)
                     logger.info(f"Reading {len(dhdl)} lines of dhdl from {file}")
-                    dHdl_list.append(dhdl)
+                    return dhdl
                 except Exception as exc:
                     msg = f"Error reading dHdl from {file}."
                     logger.error(msg)
                     raise OSError(msg) from exc
 
+            dHdl_list = joblib.Parallel(n_jobs=n_jobs)(
+                joblib.delayed(extract_dHdl)(self._extract_dHdl, file, self.T)
+                for file in self.file_list
+            )
+        else:
+            dHdl_list = []
+
         # Sort the files according to the state
         if read_u_nk:
             logger.info("Sort files according to the u_nk.")
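The new `n_jobs` argument follows joblib's convention: a positive value sets
the number of workers, and -1 resolves to the number of available CPUs. joblib
can report the resolved worker count directly, which is a quick way to see
what `n_jobs=-1` means on a given machine::

    import joblib

    print(joblib.effective_n_jobs(2))   # 2
    print(joblib.effective_n_jobs(-1))  # number of CPUs on this machine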
@@ -201,6 +219,7 @@ def run(
         overlap="O_MBAR.pdf",
         breakdown=True,
         forwrev=None,
+        n_jobs=1,
         *args,
         **kwargs,
     ):
@@ -236,6 +255,9 @@
             contain u_nk, please run
             :meth:`~alchemlyb.workflows.ABFE.check_convergence` manually
             with estimator='TI'.
+        n_jobs : int
+            Number of parallel workers to use for reading and decorrelating
+            the data. (-1 means using all available CPUs)
 
         Attributes
         ----------
@@ -266,11 +288,12 @@
             )
             logger.error(msg)
             raise ValueError(msg)
-
-        self.read(use_FEP, use_TI)
+        self.read(read_u_nk=use_FEP, read_dHdl=use_TI, n_jobs=n_jobs)
 
         if uncorr is not None:
-            self.preprocess(skiptime=skiptime, uncorr=uncorr, threshold=threshold)
+            self.preprocess(
+                skiptime=skiptime, uncorr=uncorr, threshold=threshold, n_jobs=n_jobs
+            )
         if estimators is not None:
             self.estimate(estimators)
             self.generate_result()
@@ -307,7 +330,7 @@ def update_units(self, units=None):
         logger.info(f"Set unit to {units}.")
         self.units = units or None
 
-    def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
+    def preprocess(self, skiptime=0, uncorr="dE", threshold=50, n_jobs=1):
         """Preprocess the data by removing the equilibration time and
         decorrelating the data.
 
@@ -322,6 +345,9 @@
             Proceed with correlated samples if the number of uncorrelated
             samples is found to be less than this number. If 0 is given, the
             time series analysis will not be performed at all. Default: 50.
+        n_jobs : int
+            Number of parallel workers to use for decorrelating the data.
+            (-1 means using all available CPUs)
 
         Attributes
         ----------
@@ -338,13 +364,9 @@
         if len(self.u_nk_list) > 0:
             logger.info(f"Processing the u_nk data set with skiptime of {skiptime}.")
 
-            self.u_nk_sample_list = []
-            for index, u_nk in enumerate(self.u_nk_list):
-                # Find the starting frame
-
+            def _decorrelate_u_nk(u_nk, skiptime, threshold, index):
                 u_nk = u_nk[u_nk.index.get_level_values("time") >= skiptime]
                 subsample = decorrelate_u_nk(u_nk, uncorr, remove_burnin=True)
-
                 if len(subsample) < threshold:
                     logger.warning(
                         f"Number of u_nk {len(subsample)} "
@@ -352,19 +374,24 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
                         f"threshold {threshold}."
                     )
                     logger.info(f"Take all the u_nk for state {index}.")
-                    self.u_nk_sample_list.append(u_nk)
+                    subsample = u_nk
                 else:
                     logger.info(
                         f"Take {len(subsample)} uncorrelated "
                         f"u_nk for state {index}."
                     )
-                    self.u_nk_sample_list.append(subsample)
+                return subsample
+
+            self.u_nk_sample_list = joblib.Parallel(n_jobs=n_jobs)(
+                joblib.delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index)
+                for index, u_nk in enumerate(self.u_nk_list)
+            )
         else:
             logger.info("No u_nk data being subsampled")
 
         if len(self.dHdl_list) > 0:
-            self.dHdl_sample_list = []
-            for index, dHdl in enumerate(self.dHdl_list):
+
+            def _decorrelate_dhdl(dHdl, skiptime, threshold, index):
                 dHdl = dHdl[dHdl.index.get_level_values("time") >= skiptime]
                 subsample = decorrelate_dhdl(dHdl, remove_burnin=True)
                 if len(subsample) < threshold:
@@ -374,13 +401,18 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
                         f"threshold {threshold}."
                     )
                     logger.info(f"Take all the dHdl for state {index}.")
-                    self.dHdl_sample_list.append(dHdl)
+                    subsample = dHdl
                 else:
                     logger.info(
                         f"Take {len(subsample)} uncorrelated "
                         f"dHdl for state {index}."
                     )
-                    self.dHdl_sample_list.append(subsample)
+                return subsample
+
+            self.dHdl_sample_list = joblib.Parallel(n_jobs=n_jobs)(
+                joblib.delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index)
+                for index, dHdl in enumerate(self.dHdl_list)
+            )
         else:
             logger.info("No dHdl data being subsampled")
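Both parallelised stages rely on `joblib.Parallel` returning its results in
the same order as the inputs, so `u_nk_sample_list[i]` and
`dHdl_sample_list[i]` still correspond to lambda state `i` exactly as in the
old sequential loops. A tiny demonstration of that ordering guarantee::

    import joblib

    def square(x):
        return x * x

    # outputs are ordered like the inputs, regardless of completion order
    out = joblib.Parallel(n_jobs=2)(joblib.delayed(square)(i) for i in range(5))
    assert out == [0, 1, 4, 9, 16]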