From 96ffdc2f3ce20745f6fad12051c5d3063181d79a Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Thu, 27 Jun 2024 11:24:45 -0400 Subject: [PATCH] add test_wandb_init + test_wandb_log_frequency --- tests/test_trainer.py | 86 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 78 insertions(+), 8 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 818f1930..6fdb7310 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from unittest.mock import patch import numpy as np import pytest @@ -36,13 +37,13 @@ stresses=stresses, magmoms=magmoms, ) +train_loader, val_loader, _test_loader = get_train_val_test_loader( + data, batch_size=16, train_ratio=0.9, val_ratio=0.05 +) +chgnet = CHGNet.load() def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: - chgnet = CHGNet.load() - train_loader, val_loader, _test_loader = get_train_val_test_loader( - data, batch_size=16, train_ratio=0.9, val_ratio=0.05 - ) extra_run_config = dict(some_other_hyperparam=42) trainer = Trainer( model=chgnet, @@ -81,12 +82,8 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: def test_trainer_composition_model(tmp_path: Path) -> None: - chgnet = CHGNet.load() for param in chgnet.composition_model.parameters(): assert param.requires_grad is False - train_loader, val_loader, _test_loader = get_train_val_test_loader( - data, batch_size=16, train_ratio=0.9, val_ratio=0.05 - ) trainer = Trainer( model=chgnet, targets="efsm", @@ -115,3 +112,76 @@ def test_trainer_composition_model(tmp_path: Path) -> None: expect[0][10] = 0 expect[0][16] = 0 assert torch.all(comparison == expect) + + +@pytest.fixture() +def mock_wandb(): + with patch("chgnet.trainer.trainer.wandb") as mock: + yield mock + + +def test_wandb_init(mock_wandb): + chgnet = CHGNet.load() + _trainer = Trainer( + model=chgnet, + wandb_path="test-project/test-run", + wandb_init_kwargs={"tags": ["test"]}, + ) + expected_config = { + "targets": "ef", + "energy_loss_ratio": 1, + "force_loss_ratio": 1, + "stress_loss_ratio": 0.1, + "mag_loss_ratio": 0.1, + "optimizer": "Adam", + "scheduler": "CosLR", + "criterion": "MSE", + "epochs": 50, + "starting_epoch": 0, + "learning_rate": 0.001, + "print_freq": 100, + "torch_seed": None, + "data_seed": None, + "use_device": None, + "check_cuda_mem": False, + "wandb_path": "test-project/test-run", + "wandb_init_kwargs": {"tags": ["test"]}, + "extra_run_config": None, + } + mock_wandb.init.assert_called_once_with( + project="test-project", name="test-run", config=expected_config, tags=["test"] + ) + + +def test_wandb_log_frequency(mock_wandb): + trainer = Trainer(model=chgnet, wandb_path="test-project/test-run", epochs=1) + + # Test epoch logging + trainer.train(train_loader, val_loader, wandb_log_freq="epoch", save_dir="") + assert ( + mock_wandb.log.call_count == 2 * trainer.epochs + ), "Expected one train and one val log per epoch" + + mock_wandb.log.reset_mock() + + # Test batch logging + trainer.train(train_loader, val_loader, wandb_log_freq="batch", save_dir="") + expected_batch_calls = trainer.epochs * len(train_loader) + assert ( + mock_wandb.log.call_count > expected_batch_calls + ), "Expected more calls for batch logging" + + # Test log content (for both epoch and batch logging) + for call_args in mock_wandb.log.call_args_list: + logged_data = call_args[0][0] + assert isinstance(logged_data, dict), "Logged data should be a dictionary" + assert any( + 
key.endswith("_mae") for key in logged_data + ), "Logged data should contain MAE metrics" + + mock_wandb.log.reset_mock() + + # Test no logging when wandb_path is not provided + trainer_no_wandb = Trainer(model=chgnet, epochs=1) + trainer_no_wandb.train(train_loader, val_loader) + mock_wandb.log.assert_not_called()