Skip to content

Commit

Permalink
Allow metadata to be loaded from the serialised parquet file. (#340)
Browse files Browse the repository at this point in the history
* fix #331 
* Allow metadata to be loaded from the serialised parquet file.
* add tests
* update changelog
  • Loading branch information
xiki-tempula authored Jan 3, 2024
1 parent 1360a0d commit 99048eb
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 13 deletions.
4 changes: 3 additions & 1 deletion CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ The rules for this file:

------------------------------------------------------------------------------

*/*/2023 hl2500
*/*/2023 hl2500, xiki-tempula

* 2.2.0

Changes
- Require pandas >= 2.1 (PR #340)
- For pandas>=2.1, metadata will be loaded from the parquet file (issue #331, PR #340).
- Add support for Python 3.12, remove Python 3.8 support (issue #341, PR #304).

Enhancements
Expand Down
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
dependencies:
- python
- numpy
- pandas
- pandas>=2.1
- pymbar>=4
- scipy
- scikit-learn
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
dependencies:
- python
- numpy
- pandas
- pandas>=2.1
- pymbar>=4
- scipy
- scikit-learn
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
tests_require=["pytest", "alchemtest"],
install_requires=[
"numpy",
"pandas>=1.4",
"pandas>=2.1",
"pymbar>=4",
"scipy",
"scikit-learn",
Expand Down
39 changes: 36 additions & 3 deletions src/alchemlyb/parsing/parquet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,42 @@
import pandas as pd
from loguru import logger

from . import _init_attrs


@_init_attrs
def _read_parquet_with_metadata(path: str, T: float) -> pd.DataFrame:
    """
    Load a parquet file and reconcile its metadata with the given temperature.

    When the loaded frame has no ``temperature`` entry in ``attrs``, a warning
    is logged and ``attrs`` is filled with ``temperature=T`` and
    ``energy_unit="kT"``. When the entry is present it must equal `T`.

    Parameters
    ----------
    path : str
        Path to parquet file to extract dataframe from.
    T : float
        Temperature in Kelvin of the simulations.

    Returns
    -------
    DataFrame
        The frame read from `path`; its ``attrs`` are only written to when
        no ``temperature`` entry was stored in the file.

    Raises
    ------
    ValueError
        If the ``temperature`` stored with the frame differs from `T`.
    """
    frame = pd.read_parquet(path)

    if "temperature" in frame.attrs:
        # Metadata survived serialisation: it must agree with the caller.
        stored_temperature = frame.attrs["temperature"]
        if stored_temperature != T:
            raise ValueError(
                f"Temperature in the input ({T}) doesn't match the temperature "
                f"in the dataframe ({stored_temperature})."
            )
        return frame

    # No metadata in the file: fall back to the caller-supplied temperature.
    logger.warning(
        f"No temperature metadata found in {path}. "
        f"Serialise the Dataframe with pandas>=2.1 to preserve the metadata."
    )
    frame.attrs["temperature"] = T
    frame.attrs["energy_unit"] = "kT"
    return frame


def extract_u_nk(path, T):
r"""Return reduced potentials `u_nk` (unit: kT) from a pandas parquet file.
Expand Down Expand Up @@ -36,7 +69,7 @@ def extract_u_nk(path, T):
.. versionadded:: 2.1.0
"""
u_nk = pd.read_parquet(path)
u_nk = _read_parquet_with_metadata(path, T)
columns = list(u_nk.columns)
if isinstance(columns[0], str) and columns[0][0] == "(":
new_columns = []
Expand Down Expand Up @@ -81,4 +114,4 @@ def extract_dHdl(path, T):
.. versionadded:: 2.1.0
"""
return pd.read_parquet(path)
return _read_parquet_with_metadata(path, T)
2 changes: 1 addition & 1 deletion src/alchemlyb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def gmx_ABFE():


@pytest.fixture
def gmx_ABFE_complex_n_uk(gmx_ABFE):
def gmx_ABFE_complex_u_nk(gmx_ABFE):
return [gmx.extract_u_nk(file, T=300) for file in gmx_ABFE["complex"]]


Expand Down
35 changes: 34 additions & 1 deletion src/alchemlyb/tests/parsing/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,45 @@ def test_extract_dHdl(dHdl_list, request, tmp_path):
new_dHdl = extract_dHdl(str(tmp_path / "dhdl.parquet"), T=300)
assert (new_dHdl.columns == dHdl.columns).all()
assert (new_dHdl.index == dHdl.index).all()
assert new_dHdl.attrs["temperature"] == 300
assert new_dHdl.attrs["energy_unit"] == "kT"


@pytest.mark.parametrize("u_nk_list", ["gmx_benzene_VDW_u_nk", "gmx_ABFE_complex_n_uk"])
@pytest.mark.parametrize("u_nk_list", ["gmx_benzene_VDW_u_nk", "gmx_ABFE_complex_u_nk"])
def test_extract_u_nk(u_nk_list, request, tmp_path):
    """Round-trip a u_nk frame through parquet and read it with ``extract_u_nk``.

    The first frame of the requested fixture is written to a temporary
    parquet file and loaded again with ``T=300``. Columns, index and the
    ``temperature`` / ``energy_unit`` attrs of the result are checked.
    """
    # This test must not be called ``test_extract_dHdl``: a second function
    # with that name rebinds the module attribute, so pytest would silently
    # stop collecting the dHdl round-trip test defined earlier in the module.
    u_nk = request.getfixturevalue(u_nk_list)[0]
    parquet_path = str(tmp_path / "u_nk.parquet")
    u_nk.to_parquet(path=parquet_path, index=True)
    new_u_nk = extract_u_nk(parquet_path, T=300)
    assert (new_u_nk.columns == u_nk.columns).all()
    assert (new_u_nk.index == u_nk.index).all()
    assert new_u_nk.attrs["temperature"] == 300
    assert new_u_nk.attrs["energy_unit"] == "kT"


@pytest.fixture()
def u_nk(gmx_ABFE_complex_u_nk):
    """Provide the first entry of the ``gmx_ABFE_complex_u_nk`` fixture."""
    first_u_nk = gmx_ABFE_complex_u_nk[0]
    return first_u_nk


def test_no_T(u_nk, tmp_path, caplog):
    """A parquet file written without attrs triggers the metadata warning.

    The frame's ``attrs`` are cleared before serialisation, then the file is
    read back with ``extract_u_nk`` and the captured log text is inspected.
    """
    parquet_path = str(tmp_path / "temp.parquet")
    expected_hint = "Serialise the Dataframe with pandas>=2.1 to preserve the metadata."

    # Drop all metadata so the reader has to fall back to the supplied T.
    u_nk.attrs = {}
    u_nk.to_parquet(path=parquet_path, index=True)
    extract_u_nk(parquet_path, 300)

    assert expected_hint in caplog.text


def test_wrong_T(u_nk, tmp_path):
    """Reading with a temperature that differs from the stored one raises.

    The frame is serialised as delivered by the ``u_nk`` fixture and read
    back with ``T=400``; ``extract_u_nk`` is expected to raise ``ValueError``.
    """
    # NOTE(review): relies on the fixture frame carrying a ``temperature``
    # attr other than 400 (presumably 300 from the parser) - confirm.
    parquet_path = str(tmp_path / "temp.parquet")
    u_nk.to_parquet(path=parquet_path, index=True)
    with pytest.raises(ValueError, match="doesn't match the temperature"):
        extract_u_nk(parquet_path, 400)


def test_metadata_unchanged(u_nk, tmp_path):
    """Metadata stored in the parquet file is returned as written.

    The frame is given ``temperature=400`` and ``energy_unit="kcal/mol"``
    before serialisation; reading it back with ``T=400`` must report those
    same two values in ``attrs``.
    """
    parquet_path = str(tmp_path / "temp.parquet")

    u_nk.attrs = {"temperature": 400, "energy_unit": "kcal/mol"}
    u_nk.to_parquet(path=parquet_path, index=True)

    reloaded = extract_u_nk(parquet_path, 400)
    assert reloaded.attrs["temperature"] == 400
    assert reloaded.attrs["energy_unit"] == "kcal/mol"
6 changes: 3 additions & 3 deletions src/alchemlyb/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def u_nk(gmx_benzene_Coulomb_u_nk):


@pytest.fixture()
def multi_index_u_nk(gmx_ABFE_complex_n_uk):
return gmx_ABFE_complex_n_uk[0]
def multi_index_u_nk(gmx_ABFE_complex_u_nk):
return gmx_ABFE_complex_u_nk[0]


@pytest.fixture()
Expand Down Expand Up @@ -470,7 +470,7 @@ def test_decorrelate_dhdl_multiple_l(multi_index_dHdl):
)


def test_raise_non_uk(multi_index_dHdl):
def test_raise_nou_nk(multi_index_dHdl):
with pytest.raises(ValueError):
decorrelate_u_nk(
multi_index_dHdl,
Expand Down
2 changes: 1 addition & 1 deletion src/alchemlyb/tests/test_workflow_ABFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_single_estimator_ti(self, workflow, monkeypatch):
summary = workflow.generate_result()
assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 21.51472826028906, 0.1)

def test_unprocessed_n_uk(self, workflow, monkeypatch):
def test_unprocessed_u_nk(self, workflow, monkeypatch):
monkeypatch.setattr(workflow, "u_nk_sample_list", None)
monkeypatch.setattr(workflow, "estimator", dict())
workflow.estimate()
Expand Down

0 comments on commit 99048eb

Please sign in to comment.