From da36cff1ba4b24817888c6556adecda521deb89a Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Mon, 27 May 2024 21:30:11 +0100 Subject: [PATCH 01/12] update --- src/alchemlyb/workflows/abfe.py | 34 ++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index 751220b4..90cc715c 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +from joblib import Parallel, delayed from loguru import logger from .base import WorkflowBase @@ -136,29 +137,44 @@ def read(self, read_u_nk=True, read_dHdl=True): self.u_nk_sample_list = None self.dHdl_sample_list = None - u_nk_list = [] - dHdl_list = [] - for file in self.file_list: - if read_u_nk: + if read_u_nk: + + def extract_u_nk(_extract_u_nk, file, T): try: - u_nk = self._extract_u_nk(file, T=self.T) + u_nk = _extract_u_nk(file, T) logger.info(f"Reading {len(u_nk)} lines of u_nk from {file}") - u_nk_list.append(u_nk) + return u_nk except Exception as exc: msg = f"Error reading u_nk from {file}." logger.error(msg) raise OSError(msg) from exc - if read_dHdl: + u_nk_list = Parallel(n_jobs=-1)( + delayed(extract_u_nk)(self._extract_u_nk, file, self.T) + for file in self.file_list + ) + else: + u_nk_list = [] + + if read_dHdl: + + def extract_dHdl(_extract_dHdl, file, T): try: - dhdl = self._extract_dHdl(file, T=self.T) + dhdl = _extract_dHdl(file, T) logger.info(f"Reading {len(dhdl)} lines of dhdl from {file}") - dHdl_list.append(dhdl) + return dhdl except Exception as exc: msg = f"Error reading dHdl from {file}." 
logger.error(msg) raise OSError(msg) from exc + dHdl_list = Parallel(n_jobs=-1)( + delayed(extract_dHdl)(self._extract_dHdl, file, self.T) + for file in self.file_list + ) + else: + dHdl_list = [] + # Sort the files according to the state if read_u_nk: logger.info("Sort files according to the u_nk.") From 5ed0ce56d238ca99a5f26580f63c275e327d3f51 Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Mon, 27 May 2024 21:38:18 +0100 Subject: [PATCH 02/12] update --- src/alchemlyb/workflows/abfe.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index 90cc715c..1e491c20 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -354,13 +354,9 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50): if len(self.u_nk_list) > 0: logger.info(f"Processing the u_nk data set with skiptime of {skiptime}.") - self.u_nk_sample_list = [] - for index, u_nk in enumerate(self.u_nk_list): - # Find the starting frame - + def _decorrelate_u_nk(u_nk, skiptime, threshold, index): u_nk = u_nk[u_nk.index.get_level_values("time") >= skiptime] subsample = decorrelate_u_nk(u_nk, uncorr, remove_burnin=True) - if len(subsample) < threshold: logger.warning( f"Number of u_nk {len(subsample)} " @@ -368,19 +364,24 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50): f"threshold {threshold}." ) logger.info(f"Take all the u_nk for state {index}.") - self.u_nk_sample_list.append(u_nk) + subsample = u_nk else: logger.info( f"Take {len(subsample)} uncorrelated " f"u_nk for state {index}." 
) - self.u_nk_sample_list.append(subsample) + return subsample + + self.u_nk_sample_list = Parallel(n_jobs=-1)( + delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold) + for index, u_nk in enumerate(self.u_nk_list) + ) else: logger.info("No u_nk data being subsampled") if len(self.dHdl_list) > 0: - self.dHdl_sample_list = [] - for index, dHdl in enumerate(self.dHdl_list): + + def _decorrelate_dhdl(dHdl, skiptime, threshold, index): dHdl = dHdl[dHdl.index.get_level_values("time") >= skiptime] subsample = decorrelate_dhdl(dHdl, remove_burnin=True) if len(subsample) < threshold: @@ -390,13 +391,18 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50): f"threshold {threshold}." ) logger.info(f"Take all the dHdl for state {index}.") - self.dHdl_sample_list.append(dHdl) + subsample = dHdl else: logger.info( f"Take {len(subsample)} uncorrelated " f"dHdl for state {index}." ) - self.dHdl_sample_list.append(subsample) + return subsample + + self.dHdl_sample_list = Parallel(n_jobs=-1)( + delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index) + for index, dHdl in enumerate(self.dHdl_list) + ) else: logger.info("No dHdl data being subsampled") From d348e76e79eb4463ec9cd98a02167e7c6a952ae8 Mon Sep 17 00:00:00 2001 From: zhiyi wu Date: Wed, 29 May 2024 10:23:50 +0100 Subject: [PATCH 03/12] update --- src/alchemlyb/workflows/abfe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index 1e491c20..831ae8d3 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -373,7 +373,7 @@ def _decorrelate_u_nk(u_nk, skiptime, threshold, index): return subsample self.u_nk_sample_list = Parallel(n_jobs=-1)( - delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold) + delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index) for index, u_nk in enumerate(self.u_nk_list) ) else: From fda5ef778817065415a4549e3c9ea1639ebee354 Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: 
Wed, 29 May 2024 20:30:08 +0100 Subject: [PATCH 04/12] update --- src/alchemlyb/workflows/abfe.py | 77 +++++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 19 deletions(-) diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index 831ae8d3..ca1df8b6 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -6,7 +6,13 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -from joblib import Parallel, delayed + +try: + from joblib import Parallel, delayed + + has_joblib = True +except ImportError: + has_joblib = False from loguru import logger from .base import WorkflowBase @@ -116,7 +122,7 @@ def __init__( else: raise NotImplementedError(f"{software} parser not found.") - def read(self, read_u_nk=True, read_dHdl=True): + def read(self, read_u_nk=True, read_dHdl=True, n_jobs=-1): """Read the u_nk and dHdL data from the :attr:`~alchemlyb.workflows.ABFE.file_list` @@ -126,6 +132,8 @@ def read(self, read_u_nk=True, read_dHdl=True): Whether to read the u_nk. read_dHdl : bool Whether to read the dHdl. + n_jobs : int + Number of parallel workers to use for reading the data. 
Attributes ---------- @@ -149,10 +157,16 @@ def extract_u_nk(_extract_u_nk, file, T): logger.error(msg) raise OSError(msg) from exc - u_nk_list = Parallel(n_jobs=-1)( - delayed(extract_u_nk)(self._extract_u_nk, file, self.T) - for file in self.file_list - ) + if has_joblib: + u_nk_list = Parallel(n_jobs=n_jobs)( + delayed(extract_u_nk)(self._extract_u_nk, file, self.T) + for file in self.file_list + ) + else: + u_nk_list = [ + extract_u_nk(self._extract_u_nk, file, self.T) + for file in self.file_list + ] else: u_nk_list = [] @@ -168,10 +182,17 @@ def extract_dHdl(_extract_dHdl, file, T): logger.error(msg) raise OSError(msg) from exc - dHdl_list = Parallel(n_jobs=-1)( - delayed(extract_dHdl)(self._extract_dHdl, file, self.T) - for file in self.file_list - ) + if has_joblib: + dHdl_list = Parallel(n_jobs=n_jobs)( + delayed(extract_dHdl)(self._extract_dHdl, file, self.T) + for file in self.file_list + ) + else: + dHdl_list = [ + extract_dHdl(self._extract_dHdl, file, self.T) + for file in self.file_list + ] + else: dHdl_list = [] @@ -217,6 +238,7 @@ def run( overlap="O_MBAR.pdf", breakdown=True, forwrev=None, + n_jobs=-1, *args, **kwargs, ): @@ -252,6 +274,8 @@ def run( contain u_nk, please run meth:`~alchemlyb.workflows.ABFE.check_convergence` manually with estimator='TI'. + n_jobs : int + Number of parallel workers to use for reading and decorrelating the data. Attributes ---------- @@ -323,7 +347,7 @@ def update_units(self, units=None): logger.info(f"Set unit to {units}.") self.units = units or None - def preprocess(self, skiptime=0, uncorr="dE", threshold=50): + def preprocess(self, skiptime=0, uncorr="dE", threshold=50, n_jobs=-1): """Preprocess the data by removing the equilibration time and decorrelate the date. @@ -338,6 +362,8 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50): Proceed with correlated samples if the number of uncorrelated samples is found to be less than this number. 
If 0 is given, the time series analysis will not be performed at all. Default: 50. + n_jobs : int + Number of parallel workers to use for decorrelating the data. Attributes ---------- @@ -372,10 +398,16 @@ def _decorrelate_u_nk(u_nk, skiptime, threshold, index): ) return subsample - self.u_nk_sample_list = Parallel(n_jobs=-1)( - delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index) - for index, u_nk in enumerate(self.u_nk_list) - ) + if has_joblib: + self.u_nk_sample_list = Parallel(n_jobs=n_jobs)( + delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index) + for index, u_nk in enumerate(self.u_nk_list) + ) + else: + self.u_nk_sample_list = [ + _decorrelate_u_nk(u_nk, skiptime, threshold, index) + for index, u_nk in enumerate(self.u_nk_list) + ] else: logger.info("No u_nk data being subsampled") @@ -399,10 +431,17 @@ def _decorrelate_dhdl(dHdl, skiptime, threshold, index): ) return subsample - self.dHdl_sample_list = Parallel(n_jobs=-1)( - delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index) - for index, dHdl in enumerate(self.dHdl_list) - ) + if has_joblib: + self.dHdl_sample_list = Parallel(n_jobs=n_jobs)( + delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index) + for index, dHdl in enumerate(self.dHdl_list) + ) + else: + self.dHdl_sample_list = [ + _decorrelate_dhdl(dHdl, skiptime, threshold, index) + for index, dHdl in enumerate(self.dHdl_list) + ] + else: logger.info("No dHdl data being subsampled") From 0e089a60647abdde54866b4721d2b28c70739928 Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Wed, 29 May 2024 20:33:24 +0100 Subject: [PATCH 05/12] update --- src/alchemlyb/workflows/abfe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index ca1df8b6..193edb2d 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -306,11 +306,12 @@ def run( ) logger.error(msg) raise ValueError(msg) - - self.read(use_FEP, use_TI) + 
self.read(read_u_nk=use_FEP, read_dHdl=use_FEP, n_jobs=n_jobs) if uncorr is not None: - self.preprocess(skiptime=skiptime, uncorr=uncorr, threshold=threshold) + self.preprocess( + skiptime=skiptime, uncorr=uncorr, threshold=threshold, n_jobs=n_jobs + ) if estimators is not None: self.estimate(estimators) self.generate_result() From 3fa533af9e3d893e73115dee372c4eb298b6b184 Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Wed, 29 May 2024 20:45:10 +0100 Subject: [PATCH 06/12] update --- src/alchemlyb/workflows/abfe.py | 66 +++++++++------------------------ 1 file changed, 17 insertions(+), 49 deletions(-) diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index 193edb2d..ca4f259e 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -6,13 +6,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd - -try: - from joblib import Parallel, delayed - - has_joblib = True -except ImportError: - has_joblib = False +from joblib import Parallel, delayed from loguru import logger from .base import WorkflowBase @@ -157,16 +151,10 @@ def extract_u_nk(_extract_u_nk, file, T): logger.error(msg) raise OSError(msg) from exc - if has_joblib: - u_nk_list = Parallel(n_jobs=n_jobs)( - delayed(extract_u_nk)(self._extract_u_nk, file, self.T) - for file in self.file_list - ) - else: - u_nk_list = [ - extract_u_nk(self._extract_u_nk, file, self.T) - for file in self.file_list - ] + u_nk_list = Parallel(n_jobs=n_jobs)( + delayed(extract_u_nk)(self._extract_u_nk, file, self.T) + for file in self.file_list + ) else: u_nk_list = [] @@ -182,17 +170,10 @@ def extract_dHdl(_extract_dHdl, file, T): logger.error(msg) raise OSError(msg) from exc - if has_joblib: - dHdl_list = Parallel(n_jobs=n_jobs)( - delayed(extract_dHdl)(self._extract_dHdl, file, self.T) - for file in self.file_list - ) - else: - dHdl_list = [ - extract_dHdl(self._extract_dHdl, file, self.T) - for file in self.file_list - ] - + dHdl_list = 
Parallel(n_jobs=n_jobs)( + delayed(extract_dHdl)(self._extract_dHdl, file, self.T) + for file in self.file_list + ) else: dHdl_list = [] @@ -399,16 +380,10 @@ def _decorrelate_u_nk(u_nk, skiptime, threshold, index): ) return subsample - if has_joblib: - self.u_nk_sample_list = Parallel(n_jobs=n_jobs)( - delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index) - for index, u_nk in enumerate(self.u_nk_list) - ) - else: - self.u_nk_sample_list = [ - _decorrelate_u_nk(u_nk, skiptime, threshold, index) - for index, u_nk in enumerate(self.u_nk_list) - ] + self.u_nk_sample_list = Parallel(n_jobs=n_jobs)( + delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index) + for index, u_nk in enumerate(self.u_nk_list) + ) else: logger.info("No u_nk data being subsampled") @@ -432,17 +407,10 @@ def _decorrelate_dhdl(dHdl, skiptime, threshold, index): ) return subsample - if has_joblib: - self.dHdl_sample_list = Parallel(n_jobs=n_jobs)( - delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index) - for index, dHdl in enumerate(self.dHdl_list) - ) - else: - self.dHdl_sample_list = [ - _decorrelate_dhdl(dHdl, skiptime, threshold, index) - for index, dHdl in enumerate(self.dHdl_list) - ] - + self.dHdl_sample_list = Parallel(n_jobs=n_jobs)( + delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index) + for index, dHdl in enumerate(self.dHdl_list) + ) else: logger.info("No dHdl data being subsampled") From 77eee59827347244cdecacb2d74a09f832d17633 Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Wed, 29 May 2024 20:57:46 +0100 Subject: [PATCH 07/12] fix test --- src/alchemlyb/tests/test_workflow_ABFE.py | 32 ++++++++++++++++------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/alchemlyb/tests/test_workflow_ABFE.py b/src/alchemlyb/tests/test_workflow_ABFE.py index ed25646d..86a1cff4 100644 --- a/src/alchemlyb/tests/test_workflow_ABFE.py +++ b/src/alchemlyb/tests/test_workflow_ABFE.py @@ -30,6 +30,7 @@ def workflow(tmp_path_factory): overlap="O_MBAR.pdf", 
breakdown=True, forwrev=10, + n_jobs=1, ) return workflow @@ -79,6 +80,7 @@ def test_invalid_estimator(self, workflow): overlap=None, breakdown=None, forwrev=None, + n_jobs=1, ) def test_single_estimator(self, workflow, monkeypatch): @@ -88,7 +90,12 @@ def test_single_estimator(self, workflow, monkeypatch): monkeypatch.setattr(workflow, "dHdl_sample_list", []) monkeypatch.setattr(workflow, "estimator", dict()) workflow.run( - uncorr=None, estimators="MBAR", overlap=None, breakdown=True, forwrev=None + uncorr=None, + estimators="MBAR", + overlap=None, + breakdown=True, + forwrev=None, + n_jobs=1, ) assert "MBAR" in workflow.estimator @@ -96,7 +103,12 @@ def test_single_estimator(self, workflow, monkeypatch): def test_no_forwrev(self, workflow, monkeypatch, forwrev): monkeypatch.setattr(workflow, "convergence", None) workflow.run( - uncorr=None, estimators=None, overlap=None, breakdown=None, forwrev=forwrev + uncorr=None, + estimators=None, + overlap=None, + breakdown=None, + forwrev=forwrev, + n_jobs=1, ) assert workflow.convergence is None @@ -128,7 +140,7 @@ def test_read_TI_FEP(self, workflow, monkeypatch, read_u_nk, read_dHdl): monkeypatch.setattr(workflow, "dHdl_list", []) monkeypatch.setattr(workflow, "u_nk_sample_list", []) monkeypatch.setattr(workflow, "dHdl_sample_list", []) - workflow.read(read_u_nk, read_dHdl) + workflow.read(read_u_nk, read_dHdl, n_jobs=1) if read_u_nk: assert len(workflow.u_nk_list) == 30 else: @@ -148,7 +160,7 @@ def extract_u_nk(self, T): monkeypatch.setattr(workflow, "_extract_u_nk", extract_u_nk) with pytest.raises(OSError, match=r"Error reading u_nk"): - workflow.read() + workflow.read(n_jobs=1) def test_read_invalid_dHdl(self, workflow, monkeypatch): monkeypatch.setattr(workflow, "u_nk_sample_list", []) @@ -159,7 +171,7 @@ def extract_dHdl(self, T): monkeypatch.setattr(workflow, "_extract_dHdl", extract_dHdl) with pytest.raises(OSError, match=r"Error reading dHdl"): - workflow.read() + workflow.read(n_jobs=1) class TestSubsample: 
@@ -181,7 +193,7 @@ def test_uncorr_threshold(self, workflow, monkeypatch): ) monkeypatch.setattr(workflow, "u_nk_sample_list", []) monkeypatch.setattr(workflow, "dHdl_sample_list", []) - workflow.preprocess(threshold=50) + workflow.preprocess(threshold=50, n_jobs=1) assert all([len(u_nk) == 40 for u_nk in workflow.u_nk_sample_list]) assert all([len(dHdl) == 40 for dHdl in workflow.dHdl_sample_list]) @@ -189,14 +201,14 @@ def test_no_u_nk_preprocess(self, workflow, monkeypatch): monkeypatch.setattr(workflow, "u_nk_list", []) monkeypatch.setattr(workflow, "u_nk_sample_list", []) monkeypatch.setattr(workflow, "dHdl_sample_list", []) - workflow.preprocess(threshold=50) + workflow.preprocess(threshold=50, n_jobs=1) assert len(workflow.u_nk_list) == 0 def test_no_dHdl_preprocess(self, workflow, monkeypatch): monkeypatch.setattr(workflow, "dHdl_list", []) monkeypatch.setattr(workflow, "u_nk_sample_list", []) monkeypatch.setattr(workflow, "dHdl_sample_list", []) - workflow.preprocess(threshold=50) + workflow.preprocess(threshold=50, n_jobs=1) assert len(workflow.dHdl_list) == 0 @@ -407,7 +419,7 @@ def workflow(tmp_path_factory): T=298.0, outdirectory=str(outdir), ) - workflow.read() + workflow.read(n_jobs=1) workflow.estimate(estimators="TI") return workflow @@ -437,7 +449,7 @@ def workflow(tmp_path_factory): T=298.0, outdirectory=str(outdir), ) - workflow.read() + workflow.read(n_jobs=1) workflow.estimate(estimators="BAR") return workflow From acae6b5066b3d420755fe801a6992635e94af70c Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Wed, 29 May 2024 21:12:58 +0100 Subject: [PATCH 08/12] update --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 025ff953..69c696cf 100644 --- a/environment.yml +++ b/environment.yml @@ -11,3 +11,4 @@ dependencies: - pyarrow - matplotlib - loguru +- joblib From 6b82e09328381f89fcbb61ddbea3ddbb6384321b Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Wed, 29 May 2024 21:33:18 +0100 Subject: 
[PATCH 09/12] update --- CHANGES | 5 +++- src/alchemlyb/tests/test_workflow_ABFE.py | 29 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/CHANGES b/CHANGES index 990ae9d8..03d0f57f 100644 --- a/CHANGES +++ b/CHANGES @@ -13,7 +13,7 @@ The rules for this file: * release numbers follow "Semantic Versioning" https://semver.org ------------------------------------------------------------------------------ -??/??/2024 orbeckst +??/??/2024 orbeckst, xiki-tempula * 2.3.1 @@ -21,6 +21,9 @@ Changes: - alchemlyb adopts SPEC 0 (replaces NEP 29) https://scientific-python.org/specs/spec-0000/ +Enhancements + - Parallelise read and preprocess for ABFE workflow. (PR #371) + 21/05/2024 xiki-tempula diff --git a/src/alchemlyb/tests/test_workflow_ABFE.py b/src/alchemlyb/tests/test_workflow_ABFE.py index 86a1cff4..47ec8bd6 100644 --- a/src/alchemlyb/tests/test_workflow_ABFE.py +++ b/src/alchemlyb/tests/test_workflow_ABFE.py @@ -4,6 +4,7 @@ import pytest from alchemtest.amber import load_bace_example from alchemtest.gmx import load_ABFE +from joblib import parallel_config import alchemlyb.parsing.amber from alchemlyb.workflows.abfe import ABFE @@ -457,3 +458,31 @@ def test_summary(self, workflow): """Test if if the summary is right.""" summary = workflow.generate_result() assert np.isclose(summary["BAR"]["Stages"]["TOTAL"], 1.40405980473, 0.1) + + +class TestParallel: + @pytest.fixture(scope="class") + def workflow(self, tmp_path_factory): + outdir = tmp_path_factory.mktemp("out") + (outdir / "dhdl_00.xvg").symlink_to(load_ABFE()["data"]["complex"][0]) + (outdir / "dhdl_01.xvg").symlink_to(load_ABFE()["data"]["complex"][1]) + workflow = ABFE( + units="kcal/mol", + software="GROMACS", + dir=str(outdir), + prefix="dhdl", + suffix="xvg", + T=310, + ) + with parallel_config(backend="threading"): + # The default backend is "loky", which is more robust but somehow didn't + # play well with pytest, but "loky" is perfectly fine outside pytest. 
+ workflow.read(n_jobs=2) + workflow.preprocess(n_jobs=2) + return workflow + + def test_read(self, workflow): + assert len(workflow.u_nk_list) == 2 + + def test_preprocess(self, workflow): + assert len(workflow.u_nk_sample_list) == 2 From d5fe5dddbbf6114ffa9dbcb266486ad308891315 Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Wed, 29 May 2024 21:42:51 +0100 Subject: [PATCH 10/12] make test more clear --- src/alchemlyb/tests/test_workflow_ABFE.py | 36 ++++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/alchemlyb/tests/test_workflow_ABFE.py b/src/alchemlyb/tests/test_workflow_ABFE.py index 47ec8bd6..6dc84083 100644 --- a/src/alchemlyb/tests/test_workflow_ABFE.py +++ b/src/alchemlyb/tests/test_workflow_ABFE.py @@ -1,6 +1,7 @@ import os import numpy as np +import pandas as pd import pytest from alchemtest.amber import load_bace_example from alchemtest.gmx import load_ABFE @@ -463,6 +464,23 @@ def test_summary(self, workflow): class TestParallel: @pytest.fixture(scope="class") def workflow(self, tmp_path_factory): + outdir = tmp_path_factory.mktemp("out") + (outdir / "dhdl_00.xvg").symlink_to(load_ABFE()["data"]["complex"][0]) + (outdir / "dhdl_01.xvg").symlink_to(load_ABFE()["data"]["complex"][1]) + workflow = ABFE( + units="kcal/mol", + software="GROMACS", + dir=str(outdir), + prefix="dhdl", + suffix="xvg", + T=310, + ) + workflow.read(n_jobs=1) + workflow.preprocess(n_jobs=1) + return workflow + + @pytest.fixture(scope="class") + def parallel_workflow(self, tmp_path_factory): outdir = tmp_path_factory.mktemp("out") (outdir / "dhdl_00.xvg").symlink_to(load_ABFE()["data"]["complex"][0]) (outdir / "dhdl_01.xvg").symlink_to(load_ABFE()["data"]["complex"][1]) @@ -481,8 +499,18 @@ def workflow(self, tmp_path_factory): workflow.preprocess(n_jobs=2) return workflow - def test_read(self, workflow): - assert len(workflow.u_nk_list) == 2 + def test_read(self, workflow, parallel_workflow): + pd.testing.assert_frame_equal( + 
workflow.u_nk_list[0], parallel_workflow.u_nk_list[0] ) pd.testing.assert_frame_equal( workflow.u_nk_list[1], parallel_workflow.u_nk_list[1] ) - def test_preprocess(self, workflow): - assert len(workflow.u_nk_sample_list) == 2 + def test_preprocess(self, workflow, parallel_workflow): + pd.testing.assert_frame_equal( + workflow.u_nk_sample_list[0], parallel_workflow.u_nk_sample_list[0] + ) + pd.testing.assert_frame_equal( + workflow.u_nk_sample_list[1], parallel_workflow.u_nk_sample_list[1] + ) From 54bb316defcac5f9809e9a2f164617507e470d44 Mon Sep 17 00:00:00 2001 From: zhiyi wu Date: Fri, 7 Jun 2024 14:28:24 +0100 Subject: [PATCH 11/12] fix typo --- src/alchemlyb/workflows/abfe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index ca4f259e..733104b0 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -287,7 +287,7 @@ def run( ) logger.error(msg) raise ValueError(msg) - self.read(read_u_nk=use_FEP, read_dHdl=use_FEP, n_jobs=n_jobs) + self.read(read_u_nk=use_FEP, read_dHdl=use_TI, n_jobs=n_jobs) if uncorr is not None: self.preprocess( skiptime=skiptime, uncorr=uncorr, threshold=threshold, n_jobs=n_jobs ) From c6cb74535bf09c808567ed47e3334b72d0f5c91b Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Sat, 5 Oct 2024 14:42:24 +0100 Subject: [PATCH 12/12] update --- CHANGES | 8 +++++ devtools/conda-envs/test_env.yaml | 1 + docs/workflows/alchemlyb.workflows.ABFE.rst | 37 +++++++++++++++++++++ pyproject.toml | 1 + src/alchemlyb/tests/test_workflow_ABFE.py | 28 +++++++--------- src/alchemlyb/workflows/abfe.py | 28 ++++++++-------- 6 files changed, 74 insertions(+), 29 deletions(-) diff --git a/CHANGES b/CHANGES index 11fe15a4..103526c9 100644 --- a/CHANGES +++ b/CHANGES @@ -12,6 +12,14 @@ The rules for this file: * accompany each entry with github issue/PR number (Issue #xyz) * release numbers follow "Semantic Versioning" https://semver.org +**/**/**** xiki-tempula + + * 2.5.0 + +Enhancements + - Parallelise read and
preprocess for ABFE workflow. (PR #371) + + 09/19/2024 orbeckst, jaclark5 * 2.4.1 diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index bef524b7..abe08740 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -11,6 +11,7 @@ dependencies: - matplotlib>=3.7 - loguru - pyarrow +- joblib # Testing - pytest diff --git a/docs/workflows/alchemlyb.workflows.ABFE.rst b/docs/workflows/alchemlyb.workflows.ABFE.rst index a8dc409a..f6933253 100644 --- a/docs/workflows/alchemlyb.workflows.ABFE.rst +++ b/docs/workflows/alchemlyb.workflows.ABFE.rst @@ -153,6 +153,43 @@ to the data generated at each stage of the analysis. :: >>> # Convergence analysis >>> workflow.check_convergence(10, dF_t='dF_t.pdf') +Parallelisation of Data Reading and Decorrelation +------------------------------------------------- + +The estimation step of the workflow is parallelized using JAX. However, the +reading and decorrelation stages can be parallelized using `joblib`. This is +achieved by passing the number of jobs to run in parallel via the `n_jobs` +parameter to the following methods: + +- :meth:`~alchemlyb.workflows.ABFE.read` +- :meth:`~alchemlyb.workflows.ABFE.preprocess` + +To enable parallel execution, specify the `n_jobs` parameter. Setting +`n_jobs=-1` allows the use of all available resources. :: + + >>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir, + >>> prefix='dhdl', suffix='xvg', T=298, outdirectory='./') + >>> workflow.read(n_jobs=-1) + >>> workflow.preprocess(n_jobs=-1) + +In a fully automated mode, the `n_jobs=-1` parameter can be passed directly to +the :meth:`~alchemlyb.workflows.ABFE.run` method. This will implicitly +parallelise the reading and decorrelation stages. 
:: + + >>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir, + >>> prefix='dhdl', suffix='xvg', T=298, outdirectory='./') + >>> workflow.run(n_jobs=-1) + +While the default `joblib` settings are suitable for most environments, you +can customize the parallelisation backend depending on the infrastructure. For +example, using the threading backend can be specified as follows. :: + + >>> import joblib + >>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir, + >>> prefix='dhdl', suffix='xvg', T=298, outdirectory='./') + >>> with joblib.parallel_config(backend="threading"): + >>> workflow.run(n_jobs=-1) + API Reference ------------- .. autoclass:: alchemlyb.workflows.ABFE diff --git a/pyproject.toml b/pyproject.toml index 7088f10a..e88de469 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "matplotlib>=3.7", "loguru", "pyarrow", + "joblib", ] diff --git a/src/alchemlyb/tests/test_workflow_ABFE.py b/src/alchemlyb/tests/test_workflow_ABFE.py index 6dc84083..a679641a 100644 --- a/src/alchemlyb/tests/test_workflow_ABFE.py +++ b/src/alchemlyb/tests/test_workflow_ABFE.py @@ -5,7 +5,7 @@ import pytest from alchemtest.amber import load_bace_example from alchemtest.gmx import load_ABFE -from joblib import parallel_config +import joblib import alchemlyb.parsing.amber from alchemlyb.workflows.abfe import ABFE @@ -32,7 +32,6 @@ def workflow(tmp_path_factory): overlap="O_MBAR.pdf", breakdown=True, forwrev=10, - n_jobs=1, ) return workflow @@ -82,7 +81,6 @@ def test_invalid_estimator(self, workflow): overlap=None, breakdown=None, forwrev=None, - n_jobs=1, ) def test_single_estimator(self, workflow, monkeypatch): @@ -97,7 +95,6 @@ def test_single_estimator(self, workflow, monkeypatch): overlap=None, breakdown=True, forwrev=None, - n_jobs=1, ) assert "MBAR" in workflow.estimator @@ -110,7 +107,6 @@ def test_no_forwrev(self, workflow, monkeypatch, forwrev): overlap=None, breakdown=None, forwrev=forwrev, - n_jobs=1, ) assert 
workflow.convergence is None @@ -142,7 +138,7 @@ def test_read_TI_FEP(self, workflow, monkeypatch, read_u_nk, read_dHdl): monkeypatch.setattr(workflow, "dHdl_list", []) monkeypatch.setattr(workflow, "u_nk_sample_list", []) monkeypatch.setattr(workflow, "dHdl_sample_list", []) - workflow.read(read_u_nk, read_dHdl, n_jobs=1) + workflow.read(read_u_nk, read_dHdl) if read_u_nk: assert len(workflow.u_nk_list) == 30 else: @@ -162,7 +158,7 @@ def extract_u_nk(self, T): monkeypatch.setattr(workflow, "_extract_u_nk", extract_u_nk) with pytest.raises(OSError, match=r"Error reading u_nk"): - workflow.read(n_jobs=1) + workflow.read() def test_read_invalid_dHdl(self, workflow, monkeypatch): monkeypatch.setattr(workflow, "u_nk_sample_list", []) @@ -173,7 +169,7 @@ def extract_dHdl(self, T): monkeypatch.setattr(workflow, "_extract_dHdl", extract_dHdl) with pytest.raises(OSError, match=r"Error reading dHdl"): - workflow.read(n_jobs=1) + workflow.read() class TestSubsample: @@ -195,7 +191,7 @@ def test_uncorr_threshold(self, workflow, monkeypatch): ) monkeypatch.setattr(workflow, "u_nk_sample_list", []) monkeypatch.setattr(workflow, "dHdl_sample_list", []) - workflow.preprocess(threshold=50, n_jobs=1) + workflow.preprocess(threshold=50) assert all([len(u_nk) == 40 for u_nk in workflow.u_nk_sample_list]) assert all([len(dHdl) == 40 for dHdl in workflow.dHdl_sample_list]) @@ -203,14 +199,14 @@ def test_no_u_nk_preprocess(self, workflow, monkeypatch): monkeypatch.setattr(workflow, "u_nk_list", []) monkeypatch.setattr(workflow, "u_nk_sample_list", []) monkeypatch.setattr(workflow, "dHdl_sample_list", []) - workflow.preprocess(threshold=50, n_jobs=1) + workflow.preprocess(threshold=50) assert len(workflow.u_nk_list) == 0 def test_no_dHdl_preprocess(self, workflow, monkeypatch): monkeypatch.setattr(workflow, "dHdl_list", []) monkeypatch.setattr(workflow, "u_nk_sample_list", []) monkeypatch.setattr(workflow, "dHdl_sample_list", []) - workflow.preprocess(threshold=50, n_jobs=1) + 
workflow.preprocess(threshold=50) assert len(workflow.dHdl_list) == 0 @@ -421,7 +417,7 @@ def workflow(tmp_path_factory): T=298.0, outdirectory=str(outdir), ) - workflow.read(n_jobs=1) + workflow.read() workflow.estimate(estimators="TI") return workflow @@ -451,7 +447,7 @@ def workflow(tmp_path_factory): T=298.0, outdirectory=str(outdir), ) - workflow.read(n_jobs=1) + workflow.read() workflow.estimate(estimators="BAR") return workflow @@ -475,8 +471,8 @@ def workflow(self, tmp_path_factory): suffix="xvg", T=310, ) - workflow.read(n_jobs=1) - workflow.preprocess(n_jobs=1) + workflow.read() + workflow.preprocess() return workflow @pytest.fixture(scope="class") @@ -492,7 +488,7 @@ def parallel_workflow(self, tmp_path_factory): suffix="xvg", T=310, ) - with parallel_config(backend="threading"): + with joblib.parallel_config(backend="threading"): # The default backend is "loky", which is more robust but somehow didn't # play well with pytest, but "loky" is perfectly fine outside pytest. workflow.read(n_jobs=2) diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index 733104b0..b47b2618 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -from joblib import Parallel, delayed +import joblib from loguru import logger from .base import WorkflowBase @@ -116,7 +116,7 @@ def __init__( else: raise NotImplementedError(f"{software} parser not found.") - def read(self, read_u_nk=True, read_dHdl=True, n_jobs=-1): + def read(self, read_u_nk=True, read_dHdl=True, n_jobs=1): """Read the u_nk and dHdL data from the :attr:`~alchemlyb.workflows.ABFE.file_list` @@ -128,6 +128,7 @@ def read(self, read_u_nk=True, read_dHdl=True, n_jobs=-1): Whether to read the dHdl. n_jobs : int Number of parallel workers to use for reading the data. 
+ (-1 means using all the threads) Attributes ---------- @@ -140,7 +141,6 @@ def read(self, read_u_nk=True, read_dHdl=True, n_jobs=-1): self.dHdl_sample_list = None if read_u_nk: - def extract_u_nk(_extract_u_nk, file, T): try: u_nk = _extract_u_nk(file, T) @@ -151,8 +151,8 @@ def extract_u_nk(_extract_u_nk, file, T): logger.error(msg) raise OSError(msg) from exc - u_nk_list = Parallel(n_jobs=n_jobs)( - delayed(extract_u_nk)(self._extract_u_nk, file, self.T) + u_nk_list = joblib.Parallel(n_jobs=n_jobs)( + joblib.delayed(extract_u_nk)(self._extract_u_nk, file, self.T) for file in self.file_list ) else: @@ -170,8 +170,8 @@ def extract_dHdl(_extract_dHdl, file, T): logger.error(msg) raise OSError(msg) from exc - dHdl_list = Parallel(n_jobs=n_jobs)( - delayed(extract_dHdl)(self._extract_dHdl, file, self.T) + dHdl_list = joblib.Parallel(n_jobs=n_jobs)( + joblib.delayed(extract_dHdl)(self._extract_dHdl, file, self.T) for file in self.file_list ) else: @@ -219,7 +219,7 @@ def run( overlap="O_MBAR.pdf", breakdown=True, forwrev=None, - n_jobs=-1, + n_jobs=1, *args, **kwargs, ): @@ -257,6 +257,7 @@ def run( with estimator='TI'. n_jobs : int Number of parallel workers to use for reading and decorrelating the data. + (-1 means using all the threads) Attributes ---------- @@ -329,7 +330,7 @@ def update_units(self, units=None): logger.info(f"Set unit to {units}.") self.units = units or None - def preprocess(self, skiptime=0, uncorr="dE", threshold=50, n_jobs=-1): + def preprocess(self, skiptime=0, uncorr="dE", threshold=50, n_jobs=1): """Preprocess the data by removing the equilibration time and decorrelate the date. @@ -346,6 +347,7 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50, n_jobs=-1): time series analysis will not be performed at all. Default: 50. n_jobs : int Number of parallel workers to use for decorrelating the data. 
+ (-1 means using all the threads) Attributes ---------- @@ -380,8 +382,8 @@ def _decorrelate_u_nk(u_nk, skiptime, threshold, index): ) return subsample - self.u_nk_sample_list = Parallel(n_jobs=n_jobs)( - delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index) + self.u_nk_sample_list = joblib.Parallel(n_jobs=n_jobs)( + joblib.delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index) for index, u_nk in enumerate(self.u_nk_list) ) else: @@ -407,8 +409,8 @@ def _decorrelate_dhdl(dHdl, skiptime, threshold, index): ) return subsample - self.dHdl_sample_list = Parallel(n_jobs=n_jobs)( - delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index) + self.dHdl_sample_list = joblib.Parallel(n_jobs=n_jobs)( + joblib.delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index) for index, dHdl in enumerate(self.dHdl_list) ) else: