Skip to content

Commit

Permalink
allow setting loss_ratio to 0
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Feb 7, 2024
1 parent 44658ec commit f889f69
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 19 deletions.
8 changes: 4 additions & 4 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def forward(
"""
out = {"loss": 0.0}
# Energy
if self.energy_loss_ratio != 0 and "e" in targets:
if "e" in targets:
if self.is_intensive:
out["loss"] += self.energy_loss_ratio * self.criterion(
targets["e"], prediction["e"]
Expand All @@ -658,7 +658,7 @@ def forward(
out["e_MAE_size"] = prediction["e"].shape[0]

# Force
if self.force_loss_ratio != 0 and "f" in targets:
if "f" in targets:
forces_pred = torch.cat(prediction["f"], dim=0)
forces_target = torch.cat(targets["f"], dim=0)
out["loss"] += self.force_loss_ratio * self.criterion(
Expand All @@ -668,7 +668,7 @@ def forward(
out["f_MAE_size"] = forces_target.shape[0]

# Stress
if self.stress_loss_ratio != 0 and "s" in targets:
if "s" in targets:
stress_pred = torch.cat(prediction["s"], dim=0)
stress_target = torch.cat(targets["s"], dim=0)
out["loss"] += self.stress_loss_ratio * self.criterion(
Expand All @@ -678,7 +678,7 @@ def forward(
out["s_MAE_size"] = stress_target.shape[0]

# Mag
if self.mag_loss_ratio != 0 and "m" in targets:
if "m" in targets:
mag_preds, mag_targets = [], []
m_mae_size = 0
for mag_pred, mag_target in zip(prediction["m"], targets["m"]):
Expand Down
20 changes: 5 additions & 15 deletions examples/fine_tuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,13 @@
"execution_count": null,
"id": "7ead933c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CHGNet initialized with 400,438 parameters\n"
]
}
],
"outputs": [],
"source": [
"import numpy as np\n",
"from pymatgen.core import Structure\n",
"\n",
"from chgnet.model import CHGNet\n",
"\n",
"# If the above line fails in Google Colab due to numpy version issue,\n",
"# please restart the runtime, and the problem will be solved\n",
"\n",
"chgnet = CHGNet.load()"
"# please restart the runtime, and the problem will be solved"
]
},
{
Expand Down Expand Up @@ -293,12 +281,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CHGNet initialized with 400,438 parameters\n"
"CHGNet v0.3.0 initialized with 412,525 parameters\n",
"CHGNet will run on cpu\n"
]
}
],
"source": [
"from chgnet.trainer import Trainer\n",
"from chgnet.model import CHGNet\n",
"\n",
"# Load pretrained CHGNet\n",
"chgnet = CHGNet.load()"
Expand Down

0 comments on commit f889f69

Please sign in to comment.