diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index c4fabe05..b6a12480 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -38,7 +38,7 @@ jobs:
           python setup.py build_ext --inplace
-          uv pip install -e .[test] --system --resolution=${{ matrix.version.resolution }}
+          uv pip install -e .[test,logging] --system --resolution=${{ matrix.version.resolution }}
 
       - name: Run Tests
         run: pytest --capture=no --cov --cov-report=xml
diff --git a/.gitignore b/.gitignore
index aea32c2b..ac65ad05 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,3 +27,6 @@ coverage.xml
 .ipynb_checkpoints
 bond_graph_error.cif
 test.py
+
+# training logs
+wandb
diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py
index c7ae2086..8575b372 100644
--- a/chgnet/trainer/trainer.py
+++ b/chgnet/trainer/trainer.py
@@ -21,6 +21,12 @@ from chgnet.model.model import CHGNet
 from chgnet.utils import AverageMeter, determine_device, mae, write_json
 
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
+
 if TYPE_CHECKING:
     from torch.utils.data import DataLoader
@@ -50,6 +56,9 @@ def __init__(
         data_seed: int | None = None,
         use_device: str | None = None,
         check_cuda_mem: bool = False,
+        wandb_path: str | None = None,
+        wandb_init_kwargs: dict | None = None,
+        extra_run_config: dict | None = None,
         **kwargs,
     ) -> None:
         """Initialize all hyper-parameters for trainer.
@@ -88,6 +97,14 @@ def __init__(
                 Default = None
             check_cuda_mem (bool): Whether to use cuda with most available memory
                 Default = False
+            wandb_path (str | None): The project and run name separated by a slash:
+                "project/run_name". If None, wandb logging is not used.
+                Default = None
+            wandb_init_kwargs (dict): Additional kwargs to pass to wandb.init.
+                Default = None
+            extra_run_config (dict): Additional hyper-params to be recorded by wandb
+                that are not included in the trainer_args. Default = None
+            **kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
         """
         # Store trainer args for reproducibility
         self.trainer_args = {
@@ -95,8 +112,7 @@ def __init__(
             k: v
             for k, v in locals().items()
             if k not in {"self", "__class__", "model", "kwargs"}
-        }
-        self.trainer_args.update(kwargs)
+        } | kwargs
 
         self.model = model
         self.targets = targets
@@ -195,6 +211,27 @@ def __init__(
         ] = {key: {"train": [], "val": [], "test": []} for key in self.targets}
         self.best_model = None
 
+        # Initialize wandb if project/run specified
+        if wandb_path:
+            if wandb is None:
+                raise ImportError(
+                    "Weights and Biases not installed. pip install wandb to use "
+                    "wandb logging."
+                )
+            if wandb_path.count("/") == 1:
+                project, run_name = wandb_path.split("/")
+            else:
+                raise ValueError(
+                    f"{wandb_path=} should be in the format 'project/run_name' "
+                    "(no extra slashes)"
+                )
+            wandb.init(
+                project=project,
+                name=run_name,
+                config=self.trainer_args | (extra_run_config or {}),
+                **(wandb_init_kwargs or {}),
+            )
+
     def train(
         self,
         train_loader: DataLoader,
@@ -257,6 +294,13 @@ def train(
                 self.save_checkpoint(epoch, val_mae, save_dir=save_dir)
 
+            # Log train/val metrics to wandb
+            if wandb is not None and self.trainer_args.get("wandb_path"):
+                wandb.log(
+                    {f"train_{k}_mae": v for k, v in train_mae.items()}
+                    | {f"val_{k}_mae": v for k, v in val_mae.items()}
+                )
+
         if test_loader is not None:
             # test best model
             print("---------Evaluate Model on Test Set---------------")
@@ -279,6 +323,10 @@ def train(
                 self.training_history[key]["test"] = test_mae[key]
             self.save(filename=os.path.join(save_dir, test_file))
 
+            # Log test metrics to wandb
+            if wandb is not None and self.trainer_args.get("wandb_path"):
+                wandb.log({f"test_{k}_mae": v for k, v in test_mae.items()})
+
     def _train(self, train_loader: DataLoader, current_epoch: int) -> dict:
         """Train all data for one epoch.
diff --git a/pyproject.toml b/pyproject.toml
index dc6c0f49..470842b5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ test = ["pytest-cov>=4", "pytest>=8"]
 # needed to run interactive example notebooks
 examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
 docs = ["lazydocs>=0.4"]
+logging = ["wandb>=0.17"]
 
 [project.urls]
 Source = "https://github.com/CederGroupHub/chgnet"
@@ -89,6 +90,7 @@ ignore = [
 pydocstyle.convention = "google"
 isort.required-imports = ["from __future__ import annotations"]
 isort.split-on-trailing-comma = false
+isort.known-third-party = ["wandb"]
 
 [tool.ruff.format]
 docstring-code-format = true
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index bcf44f64..9403534c 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -3,7 +3,9 @@
 from typing import TYPE_CHECKING
 
 import numpy as np
+import pytest
 import torch
+import wandb
 from pymatgen.core import Lattice, Structure
 
 from chgnet.data.dataset import StructureData, get_train_val_test_loader
@@ -36,11 +38,12 @@
 )
 
 
-def test_trainer(tmp_path: Path) -> None:
+def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
     chgnet = CHGNet.load()
     train_loader, val_loader, _test_loader = get_train_val_test_loader(
         data, batch_size=16, train_ratio=0.9, val_ratio=0.05
     )
+    extra_run_config = dict(some_other_hyperparam=42)
     trainer = Trainer(
         model=chgnet,
         targets="efsm",
@@ -48,7 +51,11 @@ def test_trainer(tmp_path: Path) -> None:
         criterion="MSE",
         learning_rate=1e-2,
         epochs=5,
+        wandb_path="test/run",
+        wandb_init_kwargs=dict(anonymous="must"),
+        extra_run_config=extra_run_config,
     )
+    assert dict(wandb.config).items() >= extra_run_config.items()
     dir_name = "test_tmp_dir"
     test_dir = tmp_path / dir_name
     trainer.train(train_loader, val_loader, save_dir=test_dir)
@@ -63,6 +70,12 @@ def test_trainer(tmp_path: Path) -> None:
             n_matches == 1
         ), f"Expected 1 {prefix} file, found {n_matches} in {output_files}"
 
+    # expect ImportError when passing wandb_path without wandb installed
+    err_msg = "Weights and Biases not installed. pip install wandb to use wandb logging"
+    with monkeypatch.context() as ctx, pytest.raises(ImportError, match=err_msg):  # noqa: PT012
+        ctx.setattr("chgnet.trainer.trainer.wandb", None)
+        _ = Trainer(model=chgnet, wandb_path="some-org/some-project")
+
 
 def test_trainer_composition_model(tmp_path: Path) -> None:
     chgnet = CHGNet.load()
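Usage sketch (illustrative, not part of the patch): how the three new Trainer arguments fit together, using only signatures that appear in the diff above. The dataset variable, the project/run names, and the save_dir are placeholders; logging only activates when wandb is installed, e.g. via the new optional extra with pip install chgnet[logging].

from chgnet.data.dataset import get_train_val_test_loader
from chgnet.model.model import CHGNet
from chgnet.trainer.trainer import Trainer

# placeholder: any chgnet StructureData instance holding the fitted targets
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset, batch_size=16, train_ratio=0.9, val_ratio=0.05
)

trainer = Trainer(
    model=CHGNet.load(),
    targets="efsm",
    optimizer="Adam",
    criterion="MSE",
    learning_rate=1e-2,
    epochs=5,
    # new in this diff: "project/run_name" enables logging of per-epoch
    # train/val (and final test) MAEs; the default None disables wandb
    wandb_path="my-project/my-run",
    wandb_init_kwargs=dict(anonymous="must"),  # forwarded verbatim to wandb.init
    extra_run_config=dict(batch_size=16),  # merged into wandb.config with trainer_args
)
trainer.train(train_loader, val_loader, test_loader, save_dir="checkpoints")

If wandb_path is set while wandb is not importable, Trainer raises ImportError at construction time, which is exactly what the monkeypatched test added above asserts.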