Skip to content

Commit

Permalink
fixed bug in energy loss
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Jan 5, 2025
1 parent 84e8d55 commit 5ef8876
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
12 changes: 0 additions & 12 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def __init__(
self.criterion = CombinedLoss(
target_str=self.targets,
criterion=criterion,
is_intensive=self.model.is_intensive,
energy_loss_ratio=energy_loss_ratio,
force_loss_ratio=force_loss_ratio,
stress_loss_ratio=stress_loss_ratio,
Expand Down Expand Up @@ -725,7 +724,6 @@ def __init__(
*,
target_str: str = "ef",
criterion: str = "MSE",
is_intensive: bool = True,
energy_loss_ratio: float = 1,
force_loss_ratio: float = 1,
stress_loss_ratio: float = 0.1,
Expand All @@ -740,8 +738,6 @@ def __init__(
Default = "ef"
criterion: loss criterion to use
Default = "MSE"
is_intensive (bool): whether the energy label is intensive
Default = True
energy_loss_ratio (float): energy loss ratio in loss function
Default = 1
force_loss_ratio (float): force loss ratio in loss function
Expand All @@ -765,7 +761,6 @@ def __init__(
else:
raise NotImplementedError
self.target_str = target_str
self.is_intensive = is_intensive
self.energy_loss_ratio = energy_loss_ratio
if "f" not in self.target_str:
self.force_loss_ratio = 0
Expand Down Expand Up @@ -803,19 +798,12 @@ def forward(
if self.allow_missing_labels:
valid_value_indices = ~torch.isnan(targets["e"])
valid_e_target = targets["e"][valid_value_indices]
valid_atoms_per_graph = prediction["atoms_per_graph"][
valid_value_indices
]
valid_e_pred = prediction["e"][valid_value_indices]
if valid_e_pred.shape == torch.Size([]):
valid_e_pred = valid_e_pred.view(1)
else:
valid_e_target = targets["e"]
valid_atoms_per_graph = prediction["atoms_per_graph"]
valid_e_pred = prediction["e"]
if self.is_intensive:
valid_e_target = valid_e_target / valid_atoms_per_graph
valid_e_pred = valid_e_pred / valid_atoms_per_graph

out["loss"] += self.energy_loss_ratio * self.criterion(
valid_e_target, valid_e_pred
Expand Down
56 changes: 55 additions & 1 deletion tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from chgnet.data.dataset import StructureData, get_train_val_test_loader
from chgnet.model import CHGNet
from chgnet.trainer import Trainer
from chgnet.trainer.trainer import CombinedLoss, Trainer

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -50,6 +50,60 @@
chgnet = CHGNet.load()


def test_combined_loss() -> None:
criterion = CombinedLoss(
target_str="ef",
criterion="MSE",
energy_loss_ratio=1,
force_loss_ratio=1,
stress_loss_ratio=0.1,
mag_loss_ratio=0.1,
allow_missing_labels=False,
)
target1 = {"e": torch.Tensor([1]), "f": [torch.Tensor([[[1, 1, 1], [2, 2, 2]]])]}
prediction1 = chgnet.predict_structure(NaCl)
prediction1 = {
"e": torch.from_numpy(prediction1["e"]).unsqueeze(0),
"f": [torch.from_numpy(prediction1["f"])],
"atoms_per_graph": torch.tensor([2]),
}
out1 = criterion(
targets=target1,
prediction=prediction1,
)
target2 = {
"e": torch.Tensor([1]),
"f": [
torch.Tensor(
[
[
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
]
]
)
],
}
supercell = NaCl.make_supercell([2, 2, 1], in_place=False)
prediction2 = chgnet.predict_structure(supercell)
prediction2 = {
"e": torch.from_numpy(prediction2["e"]).unsqueeze(0),
"f": [torch.from_numpy(prediction2["f"])],
"atoms_per_graph": torch.tensor([8]),
}
out2 = criterion(
targets=target2,
prediction=prediction2,
)
assert np.isclose(out1["loss"], out2["loss"], rtol=1e-04, atol=1e-05)


def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
extra_run_config = dict(some_other_hyperparam=42)
trainer = Trainer(
Expand Down

0 comments on commit 5ef8876

Please sign in to comment.