From f889f6933d402664050cb4eabe2a4beb81a12e21 Mon Sep 17 00:00:00 2001 From: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> Date: Wed, 7 Feb 2024 12:47:16 -0800 Subject: [PATCH] allow setting loss_ratio to 0 --- chgnet/trainer/trainer.py | 8 ++++---- examples/fine_tuning.ipynb | 20 +++++--------------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index 886c4d39..3f90ac8f 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -643,7 +643,7 @@ def forward( """ out = {"loss": 0.0} # Energy - if self.energy_loss_ratio != 0 and "e" in targets: + if "e" in targets: if self.is_intensive: out["loss"] += self.energy_loss_ratio * self.criterion( targets["e"], prediction["e"] @@ -658,7 +658,7 @@ def forward( out["e_MAE_size"] = prediction["e"].shape[0] # Force - if self.force_loss_ratio != 0 and "f" in targets: + if "f" in targets: forces_pred = torch.cat(prediction["f"], dim=0) forces_target = torch.cat(targets["f"], dim=0) out["loss"] += self.force_loss_ratio * self.criterion( @@ -668,7 +668,7 @@ def forward( out["f_MAE_size"] = forces_target.shape[0] # Stress - if self.stress_loss_ratio != 0 and "s" in targets: + if "s" in targets: stress_pred = torch.cat(prediction["s"], dim=0) stress_target = torch.cat(targets["s"], dim=0) out["loss"] += self.stress_loss_ratio * self.criterion( @@ -678,7 +678,7 @@ def forward( out["s_MAE_size"] = stress_target.shape[0] # Mag - if self.mag_loss_ratio != 0 and "m" in targets: + if "m" in targets: mag_preds, mag_targets = [], [] m_mae_size = 0 for mag_pred, mag_target in zip(prediction["m"], targets["m"]): diff --git a/examples/fine_tuning.ipynb b/examples/fine_tuning.ipynb index b844231f..4984dba3 100644 --- a/examples/fine_tuning.ipynb +++ b/examples/fine_tuning.ipynb @@ -27,25 +27,13 @@ "execution_count": null, "id": "7ead933c", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CHGNet initialized with 400,438 parameters\n" - ] - } - ], + "outputs": [], "source": [ "import numpy as np\n", "from pymatgen.core import Structure\n", "\n", - "from chgnet.model import CHGNet\n", - "\n", "# If the above line fails in Google Colab due to numpy version issue,\n", - "# please restart the runtime, and the problem will be solved\n", - "\n", - "chgnet = CHGNet.load()" + "# please restart the runtime, and the problem will be solved" ] }, { @@ -293,12 +281,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "CHGNet initialized with 400,438 parameters\n" + "CHGNet v0.3.0 initialized with 412,525 parameters\n", + "CHGNet will run on cpu\n" ] } ], "source": [ "from chgnet.trainer import Trainer\n", + "from chgnet.model import CHGNet\n", "\n", "# Load pretrained CHGNet\n", "chgnet = CHGNet.load()"