Add wandb logging support to Trainer class #166

Merged · 6 commits · Jun 13, 2024
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -38,7 +38,7 @@ jobs:

python setup.py build_ext --inplace

-        uv pip install -e .[test] --system --resolution=${{ matrix.version.resolution }}
+        uv pip install -e .[test,logging] --system --resolution=${{ matrix.version.resolution }}

- name: Run Tests
run: pytest --capture=no --cov --cov-report=xml
3 changes: 3 additions & 0 deletions .gitignore
@@ -27,3 +27,6 @@ coverage.xml
.ipynb_checkpoints
bond_graph_error.cif
test.py

# training logs
wandb
52 changes: 50 additions & 2 deletions chgnet/trainer/trainer.py
@@ -21,6 +21,12 @@
from chgnet.model.model import CHGNet
from chgnet.utils import AverageMeter, determine_device, mae, write_json

try:
import wandb
except ImportError:
wandb = None


if TYPE_CHECKING:
from torch.utils.data import DataLoader
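The module-level try/except above makes wandb a soft dependency: importing chgnet never fails when wandb is absent, and the `wandb = None` sentinel lets every call site gate on a plain `if` check. A minimal sketch of the same pattern in isolation (the `log_metrics` helper is hypothetical, not part of this PR):

```python
try:
    import wandb  # optional: only needed for Weights & Biases logging
except ImportError:
    wandb = None  # sentinel checked before every wandb call


def log_metrics(metrics: dict) -> None:
    """No-op unless wandb is installed and a run has been initialized."""
    if wandb is not None and wandb.run is not None:
        wandb.log(metrics)
```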

@@ -50,6 +56,9 @@ def __init__(
data_seed: int | None = None,
use_device: str | None = None,
check_cuda_mem: bool = False,
wandb_path: str | None = None,
wandb_init_kwargs: dict | None = None,
extra_run_config: dict | None = None,
**kwargs,
) -> None:
"""Initialize all hyper-parameters for trainer.
@@ -88,15 +97,22 @@ def __init__(
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
Default = False
wandb_path (str | None): The project and run name separated by a slash:
"project/run_name". If None, wandb logging is not used.
Default = None
wandb_init_kwargs (dict): Additional kwargs to pass to wandb.init.
Default = None
extra_run_config (dict): Additional hyper-params to be recorded by wandb
that are not included in the trainer_args. Default = None

**kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
"""
# Store trainer args for reproducibility
self.trainer_args = {
k: v
for k, v in locals().items()
if k not in {"self", "__class__", "model", "kwargs"}
-        }
-        self.trainer_args.update(kwargs)
+        } | kwargs

self.model = model
self.targets = targets
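The `} | kwargs` edit above swaps the two-step build-then-`update()` for PEP 584 dict union (Python 3.9+); on key collisions the right-hand operand wins. A quick illustration:

```python
# PEP 584 dict union (Python 3.9+): right operand wins on duplicate keys.
defaults = {"epochs": 5, "optimizer": "Adam"}
overrides = {"optimizer": "SGD"}
assert defaults | overrides == {"epochs": 5, "optimizer": "SGD"}
```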
@@ -195,6 +211,27 @@ def __init__(
] = {key: {"train": [], "val": [], "test": []} for key in self.targets}
self.best_model = None

# Initialize wandb if project/run specified
if wandb_path:
if wandb is None:
raise ImportError(
"Weights and Biases not installed. pip install wandb to use "
"wandb logging."
)
if wandb_path.count("/") == 1:
project, run_name = wandb_path.split("/")
else:
raise ValueError(
f"{wandb_path=} should be in the format 'project/run_name' "
"(no extra slashes)"
)
wandb.init(
project=project,
name=run_name,
config=self.trainer_args | (extra_run_config or {}),
**(wandb_init_kwargs or {}),
)

def train(
self,
train_loader: DataLoader,
@@ -257,6 +294,13 @@

self.save_checkpoint(epoch, val_mae, save_dir=save_dir)

# Log train/val metrics to wandb
if wandb is not None and self.trainer_args.get("wandb_path"):
wandb.log(
{f"train_{k}_mae": v for k, v in train_mae.items()}
| {f"val_{k}_mae": v for k, v in val_mae.items()}
)

if test_loader is not None:
# test best model
print("---------Evaluate Model on Test Set---------------")
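The epoch-level logging above flattens the two per-target MAE dicts into a single payload, so wandb charts get one metric per target and split. Roughly what the merged dict looks like (illustrative targets and values, not taken from a real run):

```python
train_mae = {"e": 0.031, "f": 0.065}  # per-target MAEs for one epoch
val_mae = {"e": 0.042, "f": 0.071}
payload = {f"train_{k}_mae": v for k, v in train_mae.items()} | {
    f"val_{k}_mae": v for k, v in val_mae.items()
}
# payload == {"train_e_mae": 0.031, "train_f_mae": 0.065,
#             "val_e_mae": 0.042, "val_f_mae": 0.071}
```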
@@ -279,6 +323,10 @@
self.training_history[key]["test"] = test_mae[key]
self.save(filename=os.path.join(save_dir, test_file))

# Log test metrics to wandb
if wandb is not None and self.trainer_args.get("wandb_path"):
wandb.log({f"test_{k}_mae": v for k, v in test_mae.items()})

def _train(self, train_loader: DataLoader, current_epoch: int) -> dict:
"""Train all data for one epoch.

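End to end, the new Trainer arguments are used roughly as follows. This is a hedged usage sketch: the project/run names and hyper-parameter values are placeholders, and the data loaders are assumed to be built elsewhere (e.g. with `get_train_val_test_loader` as in the tests below):

```python
from chgnet.model.model import CHGNet
from chgnet.trainer.trainer import Trainer

trainer = Trainer(
    model=CHGNet.load(),
    targets="ef",
    optimizer="Adam",
    criterion="MSE",
    learning_rate=1e-3,
    epochs=10,
    wandb_path="my-project/run-1",  # exactly one slash: "project/run_name"
    wandb_init_kwargs=dict(tags=["demo"]),  # forwarded to wandb.init
    extra_run_config=dict(batch_size=16),  # extra hyper-params recorded in wandb.config
)
# trainer.train(train_loader, val_loader)  # loaders built elsewhere
```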
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ test = ["pytest-cov>=4", "pytest>=8"]
# needed to run interactive example notebooks
examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
docs = ["lazydocs>=0.4"]
logging = ["wandb>=0.17"]

[project.urls]
Source = "https://github.com/CederGroupHub/chgnet"
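Because wandb now lives behind the optional `logging` extra, a plain `pip install chgnet` stays wandb-free; users opt in with e.g. `pip install "chgnet[logging]"`, which is also how the CI workflow above installs it via `[test,logging]`. A stdlib-only sketch for checking the extra at runtime:

```python
import importlib.util

# find_spec returns None when the package is not importable.
if importlib.util.find_spec("wandb") is None:
    raise SystemExit('wandb not found; install it, e.g. pip install "chgnet[logging]"')
```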
@@ -89,6 +90,7 @@ ignore = [
pydocstyle.convention = "google"
isort.required-imports = ["from __future__ import annotations"]
isort.split-on-trailing-comma = false
isort.known-third-party = ["wandb"]

[tool.ruff.format]
docstring-code-format = true
15 changes: 14 additions & 1 deletion tests/test_trainer.py
@@ -3,7 +3,9 @@
from typing import TYPE_CHECKING

import numpy as np
import pytest
import torch
import wandb
from pymatgen.core import Lattice, Structure

from chgnet.data.dataset import StructureData, get_train_val_test_loader
@@ -36,19 +38,24 @@
)


-def test_trainer(tmp_path: Path) -> None:
+def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
chgnet = CHGNet.load()
train_loader, val_loader, _test_loader = get_train_val_test_loader(
data, batch_size=16, train_ratio=0.9, val_ratio=0.05
)
extra_run_config = dict(some_other_hyperparam=42)
trainer = Trainer(
model=chgnet,
targets="efsm",
optimizer="Adam",
criterion="MSE",
learning_rate=1e-2,
epochs=5,
wandb_path="test/run",
wandb_init_kwargs=dict(anonymous="must"),
extra_run_config=extra_run_config,
)
assert dict(wandb.config).items() >= extra_run_config.items()
dir_name = "test_tmp_dir"
test_dir = tmp_path / dir_name
trainer.train(train_loader, val_loader, save_dir=test_dir)
@@ -63,6 +70,12 @@ def test_trainer(tmp_path: Path) -> None:
n_matches == 1
), f"Expected 1 {prefix} file, found {n_matches} in {output_files}"

# expect ImportError when passing wandb_path without wandb installed
err_msg = "Weights and Biases not installed. pip install wandb to use wandb logging"
with monkeypatch.context() as ctx, pytest.raises(ImportError, match=err_msg): # noqa: PT012
ctx.setattr("chgnet.trainer.trainer.wandb", None)
_ = Trainer(model=chgnet, wandb_path="some-org/some-project")
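Two details in this test are easy to miss: `dict.items()` returns a set-like view, so `>=` asserts that the wandb config is a superset of `extra_run_config`; and monkeypatching the module-level `wandb` name to None simulates a missing install without touching the environment. The superset check in isolation:

```python
config = {"learning_rate": 1e-2, "some_other_hyperparam": 42}
extra = {"some_other_hyperparam": 42}
assert config.items() >= extra.items()  # set-like superset comparison on items views
```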


def test_trainer_composition_model(tmp_path: Path) -> None:
chgnet = CHGNet.load()