Skip to content

Commit

Permalink
fix_loss_detach
Browse files Browse the repository at this point in the history
  • Loading branch information
elephaint committed Oct 8, 2024
1 parent ccf8b2d commit 8cba223
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
9 changes: 5 additions & 4 deletions nbs/common.base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,6 @@
" if self.val_size == 0:\n",
" return\n",
" losses = torch.stack(self.validation_step_outputs)\n",
" avg_loss = losses.mean().detach().item()\n",
" avg_loss = losses.mean().detach()\n",
" self.log(\n",
" \"ptl/val_loss\",\n",
Expand Down Expand Up @@ -1167,14 +1166,15 @@
" print('outsample_y', torch.isnan(outsample_y).sum())\n",
" raise Exception('Loss is NaN, training stopped.')\n",
"\n",
" train_loss_log = loss.detach().item()\n",
" self.log(\n",
" 'train_loss',\n",
" loss.detach(),\n",
" train_loss_log,\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
" )\n",
" self.train_trajectories.append((self.global_step, loss.detach()))\n",
" self.train_trajectories.append((self.global_step, train_loss_log))\n",
"\n",
" self.h = self.horizon_backup\n",
"\n",
Expand Down Expand Up @@ -1245,9 +1245,10 @@
" if torch.isnan(valid_loss):\n",
" raise Exception('Loss is NaN, training stopped.')\n",
"\n",
" valid_loss_log = valid_loss.detach()\n",
" self.log(\n",
" 'valid_loss',\n",
" valid_loss.detach(),\n",
" valid_loss_log.item(),\n",
" batch_size=batch_size,\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
Expand Down
9 changes: 5 additions & 4 deletions neuralforecast/common/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,6 @@ def on_validation_epoch_end(self):
if self.val_size == 0:
return
losses = torch.stack(self.validation_step_outputs)
avg_loss = losses.mean().detach().item()
avg_loss = losses.mean().detach()
self.log(
"ptl/val_loss",
Expand Down Expand Up @@ -1268,14 +1267,15 @@ def training_step(self, batch, batch_idx):
print("outsample_y", torch.isnan(outsample_y).sum())
raise Exception("Loss is NaN, training stopped.")

train_loss_log = loss.detach().item()
self.log(
"train_loss",
loss.detach(),
train_loss_log,
batch_size=outsample_y.size(0),
prog_bar=True,
on_epoch=True,
)
self.train_trajectories.append((self.global_step, loss.detach()))
self.train_trajectories.append((self.global_step, train_loss_log))

self.h = self.horizon_backup

Expand Down Expand Up @@ -1361,9 +1361,10 @@ def validation_step(self, batch, batch_idx):
if torch.isnan(valid_loss):
raise Exception("Loss is NaN, training stopped.")

valid_loss_log = valid_loss.detach()
self.log(
"valid_loss",
valid_loss.detach(),
valid_loss_log.item(),
batch_size=batch_size,
prog_bar=True,
on_epoch=True,
Expand Down

0 comments on commit 8cba223

Please sign in to comment.