
Parallel read and preprocess the data #371

Merged
merged 15 commits on Oct 6, 2024
8 changes: 8 additions & 0 deletions CHANGES
@@ -12,6 +12,14 @@ The rules for this file:
* accompany each entry with github issue/PR number (Issue #xyz)
* release numbers follow "Semantic Versioning" https://semver.org

**/**/**** xiki-tempula

* 2.5.0

Enhancements
- Parallelise read and preprocess for ABFE workflow. (PR #371)


09/19/2024 orbeckst, jaclark5

* 2.4.1
1 change: 1 addition & 0 deletions devtools/conda-envs/test_env.yaml
@@ -11,6 +11,7 @@ dependencies:
- matplotlib>=3.7
- loguru
- pyarrow
- joblib

# Testing
- pytest
37 changes: 37 additions & 0 deletions docs/workflows/alchemlyb.workflows.ABFE.rst
@@ -153,6 +153,43 @@ to the data generated at each stage of the analysis. ::
>>> # Convergence analysis
>>> workflow.check_convergence(10, dF_t='dF_t.pdf')

Parallelisation of Data Reading and Decorrelation
-------------------------------------------------

The estimation step of the workflow is parallelised using JAX. The reading
and decorrelation stages can additionally be parallelised using `joblib`, by
passing the number of jobs to run in parallel via the `n_jobs` parameter of
the following methods:

- :meth:`~alchemlyb.workflows.ABFE.read`
- :meth:`~alchemlyb.workflows.ABFE.preprocess`

To enable parallel execution, specify the `n_jobs` parameter; setting
`n_jobs=-1` uses all available processors. ::

>>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
>>> prefix='dhdl', suffix='xvg', T=298, outdirectory='./')
>>> workflow.read(n_jobs=-1)
>>> workflow.preprocess(n_jobs=-1)

In fully automated mode, `n_jobs=-1` can be passed directly to the
:meth:`~alchemlyb.workflows.ABFE.run` method, which implicitly parallelises
the reading and decorrelation stages. ::

>>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
>>> prefix='dhdl', suffix='xvg', T=298, outdirectory='./')
>>> workflow.run(n_jobs=-1)

While the default `joblib` settings are suitable for most environments, the
parallelisation backend can be customised to match the infrastructure. For
example, the threading backend can be selected as follows. ::

>>> import joblib
>>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
>>> prefix='dhdl', suffix='xvg', T=298, outdirectory='./')
>>> with joblib.parallel_config(backend="threading"):
>>>     workflow.run(n_jobs=-1)
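
The worker count does not have to be hard-coded: :func:`joblib.cpu_count`
reports the number of CPUs available to `joblib` and gives a sensible upper
bound for `n_jobs`. The sketch below is illustrative rather than
prescriptive. ::

>>> import joblib
>>> n_cpus = joblib.cpu_count()
>>> workflow.read(n_jobs=n_cpus)
>>> workflow.preprocess(n_jobs=n_cpus)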

API Reference
-------------
.. autoclass:: alchemlyb.workflows.ABFE
1 change: 1 addition & 0 deletions environment.yml
@@ -11,3 +11,4 @@ dependencies:
- pyarrow
- matplotlib>=3.7
- loguru
- joblib
1 change: 1 addition & 0 deletions pyproject.toml
@@ -50,6 +50,7 @@ dependencies = [
"matplotlib>=3.7",
"loguru",
"pyarrow",
"joblib",
]


69 changes: 67 additions & 2 deletions src/alchemlyb/tests/test_workflow_ABFE.py
@@ -1,9 +1,11 @@
import os

import numpy as np
import pandas as pd
import pytest
from alchemtest.amber import load_bace_example
from alchemtest.gmx import load_ABFE
import joblib

import alchemlyb.parsing.amber
from alchemlyb.workflows.abfe import ABFE
@@ -88,15 +90,23 @@ def test_single_estimator(self, workflow, monkeypatch):
monkeypatch.setattr(workflow, "dHdl_sample_list", [])
monkeypatch.setattr(workflow, "estimator", dict())
workflow.run(
uncorr=None, estimators="MBAR", overlap=None, breakdown=True, forwrev=None
uncorr=None,
estimators="MBAR",
overlap=None,
breakdown=True,
forwrev=None,
)
assert "MBAR" in workflow.estimator

@pytest.mark.parametrize("forwrev", [None, False, 0])
def test_no_forwrev(self, workflow, monkeypatch, forwrev):
monkeypatch.setattr(workflow, "convergence", None)
workflow.run(
uncorr=None, estimators=None, overlap=None, breakdown=None, forwrev=forwrev
uncorr=None,
estimators=None,
overlap=None,
breakdown=None,
forwrev=forwrev,
)
assert workflow.convergence is None

@@ -445,3 +455,58 @@ def test_summary(self, workflow):
"""Test if if the summary is right."""
summary = workflow.generate_result()
assert np.isclose(summary["BAR"]["Stages"]["TOTAL"], 1.40405980473, 0.1)


class TestParallel:
@pytest.fixture(scope="class")
def workflow(self, tmp_path_factory):
outdir = tmp_path_factory.mktemp("out")
(outdir / "dhdl_00.xvg").symlink_to(load_ABFE()["data"]["complex"][0])
(outdir / "dhdl_01.xvg").symlink_to(load_ABFE()["data"]["complex"][1])
workflow = ABFE(
units="kcal/mol",
software="GROMACS",
dir=str(outdir),
prefix="dhdl",
suffix="xvg",
T=310,
)
workflow.read()
workflow.preprocess()
return workflow

@pytest.fixture(scope="class")
def parallel_workflow(self, tmp_path_factory):
outdir = tmp_path_factory.mktemp("out")
(outdir / "dhdl_00.xvg").symlink_to(load_ABFE()["data"]["complex"][0])
(outdir / "dhdl_01.xvg").symlink_to(load_ABFE()["data"]["complex"][1])
workflow = ABFE(
units="kcal/mol",
software="GROMACS",
dir=str(outdir),
prefix="dhdl",
suffix="xvg",
T=310,
)
with joblib.parallel_config(backend="threading"):
# The default backend is "loky", which is more robust but did not play
# well with pytest in this setup; outside pytest, "loky" works fine.
workflow.read(n_jobs=2)
workflow.preprocess(n_jobs=2)
return workflow

def test_read(self, workflow, parallel_workflow):
pd.testing.assert_frame_equal(
workflow.u_nk_list[0], parallel_workflow.u_nk_list[0]
)
pd.testing.assert_frame_equal(
workflow.u_nk_list[1], parallel_workflow.u_nk_list[1]
)

def test_preprocess(self, workflow, parallel_workflow):
pd.testing.assert_frame_equal(
workflow.u_nk_sample_list[0], parallel_workflow.u_nk_sample_list[0]
)
pd.testing.assert_frame_equal(
workflow.u_nk_sample_list[1], parallel_workflow.u_nk_sample_list[1]
)
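
The equality checks above rely on `joblib.Parallel` returning results in the
order the delayed calls were submitted, so the parallel fixture fills its
lists in the same per-state order as the serial one. A minimal standalone
sketch of that ordering guarantee (toy function, not alchemlyb code):

import joblib

def square(x):
    return x * x

# Results come back in submission order, regardless of which worker
# finishes first.
results = joblib.Parallel(n_jobs=2)(joblib.delayed(square)(i) for i in range(8))
assert results == [i * i for i in range(8)]
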
82 changes: 57 additions & 25 deletions src/alchemlyb/workflows/abfe.py
@@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import joblib
from loguru import logger

from .base import WorkflowBase
@@ -115,7 +116,7 @@ def __init__(
else:
raise NotImplementedError(f"{software} parser not found.")

def read(self, read_u_nk=True, read_dHdl=True):
def read(self, read_u_nk=True, read_dHdl=True, n_jobs=1):
"""Read the u_nk and dHdL data from the
:attr:`~alchemlyb.workflows.ABFE.file_list`

@@ -125,6 +126,9 @@ def read(self, read_u_nk=True, read_dHdl=True):
Whether to read the u_nk.
read_dHdl : bool
Whether to read the dHdl.
n_jobs : int
Number of parallel workers to use for reading the data.
(-1 means using all available processors)

Attributes
----------
@@ -136,29 +140,43 @@
self.u_nk_sample_list = None
self.dHdl_sample_list = None

u_nk_list = []
dHdl_list = []
for file in self.file_list:
if read_u_nk:
if read_u_nk:
def extract_u_nk(_extract_u_nk, file, T):
try:
u_nk = self._extract_u_nk(file, T=self.T)
u_nk = _extract_u_nk(file, T)
logger.info(f"Reading {len(u_nk)} lines of u_nk from {file}")
u_nk_list.append(u_nk)
return u_nk
except Exception as exc:
msg = f"Error reading u_nk from {file}."
logger.error(msg)
raise OSError(msg) from exc

if read_dHdl:
u_nk_list = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(extract_u_nk)(self._extract_u_nk, file, self.T)
for file in self.file_list
)
else:
u_nk_list = []

if read_dHdl:

def extract_dHdl(_extract_dHdl, file, T):
try:
dhdl = self._extract_dHdl(file, T=self.T)
dhdl = _extract_dHdl(file, T)
logger.info(f"Reading {len(dhdl)} lines of dhdl from {file}")
dHdl_list.append(dhdl)
return dhdl
except Exception as exc:
msg = f"Error reading dHdl from {file}."
logger.error(msg)
raise OSError(msg) from exc

dHdl_list = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(extract_dHdl)(self._extract_dHdl, file, self.T)
for file in self.file_list
)
else:
dHdl_list = []

# Sort the files according to the state
if read_u_nk:
logger.info("Sort files according to the u_nk.")
@@ -201,6 +219,7 @@ def run(
overlap="O_MBAR.pdf",
breakdown=True,
forwrev=None,
n_jobs=1,
*args,
**kwargs,
):
@@ -236,6 +255,9 @@
contain u_nk, please run
:meth:`~alchemlyb.workflows.ABFE.check_convergence` manually
with estimator='TI'.
n_jobs : int
Number of parallel workers to use for reading and decorrelating the data.
(-1 means using all available processors)

Attributes
----------
@@ -266,11 +288,12 @@
)
logger.error(msg)
raise ValueError(msg)

self.read(use_FEP, use_TI)
self.read(read_u_nk=use_FEP, read_dHdl=use_TI, n_jobs=n_jobs)

if uncorr is not None:
self.preprocess(skiptime=skiptime, uncorr=uncorr, threshold=threshold)
self.preprocess(
skiptime=skiptime, uncorr=uncorr, threshold=threshold, n_jobs=n_jobs
)
if estimators is not None:
self.estimate(estimators)
self.generate_result()
@@ -307,7 +330,7 @@ def update_units(self, units=None):
logger.info(f"Set unit to {units}.")
self.units = units or None

def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
def preprocess(self, skiptime=0, uncorr="dE", threshold=50, n_jobs=1):
"""Preprocess the data by removing the equilibration time and
decorrelating the data.

Expand All @@ -322,6 +345,9 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
Proceed with correlated samples if the number of uncorrelated
samples is found to be less than this number. If 0 is given, the
time series analysis will not be performed at all. Default: 50.
n_jobs : int
Number of parallel workers to use for decorrelating the data.
(-1 means using all available processors)

Attributes
----------
@@ -338,33 +364,34 @@
if len(self.u_nk_list) > 0:
logger.info(f"Processing the u_nk data set with skiptime of {skiptime}.")

self.u_nk_sample_list = []
for index, u_nk in enumerate(self.u_nk_list):
# Find the starting frame

def _decorrelate_u_nk(u_nk, skiptime, threshold, index):
u_nk = u_nk[u_nk.index.get_level_values("time") >= skiptime]
subsample = decorrelate_u_nk(u_nk, uncorr, remove_burnin=True)

if len(subsample) < threshold:
logger.warning(
f"Number of u_nk {len(subsample)} "
f"for state {index} is less than the "
f"threshold {threshold}."
)
logger.info(f"Take all the u_nk for state {index}.")
self.u_nk_sample_list.append(u_nk)
subsample = u_nk
else:
logger.info(
f"Take {len(subsample)} uncorrelated "
f"u_nk for state {index}."
)
self.u_nk_sample_list.append(subsample)
return subsample

self.u_nk_sample_list = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index)
for index, u_nk in enumerate(self.u_nk_list)
)
else:
logger.info("No u_nk data being subsampled")

if len(self.dHdl_list) > 0:
self.dHdl_sample_list = []
for index, dHdl in enumerate(self.dHdl_list):

def _decorrelate_dhdl(dHdl, skiptime, threshold, index):
dHdl = dHdl[dHdl.index.get_level_values("time") >= skiptime]
subsample = decorrelate_dhdl(dHdl, remove_burnin=True)
if len(subsample) < threshold:
@@ -374,13 +401,18 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
f"threshold {threshold}."
)
logger.info(f"Take all the dHdl for state {index}.")
self.dHdl_sample_list.append(dHdl)
subsample = dHdl
else:
logger.info(
f"Take {len(subsample)} uncorrelated "
f"dHdl for state {index}."
)
self.dHdl_sample_list.append(subsample)
return subsample

self.dHdl_sample_list = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index)
for index, dHdl in enumerate(self.dHdl_list)
)
else:
logger.info("No dHdl data being subsampled")
