From 59ba3c8df1e29d9c5c9b278a05b0467f129b5352 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 2 Jan 2025 09:37:47 +0000 Subject: [PATCH] Specify on_epoch --- torchgeo/trainers/segmentation.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 3a58226ed92..997bfc09b6c 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -282,11 +282,19 @@ def training_step( batch_size = x.shape[0] y_hat = self(x) loss: Tensor = self.criterion(y_hat, y) - self.log('train_loss', loss, batch_size=batch_size) + self.log( + 'train_loss', + loss, + batch_size=batch_size, + on_step=True, + on_epoch=True, + ) self.train_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.train_metrics.compute().items()}, batch_size=batch_size, + on_step=True, + on_epoch=True, ) return loss @@ -305,11 +313,19 @@ def validation_step( batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log('val_loss', loss, batch_size=batch_size) + self.log( + 'val_loss', + loss, + batch_size=batch_size, + on_step=False, + on_epoch=True, + ) self.val_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.val_metrics.compute().items()}, batch_size=batch_size, + on_step=False, + on_epoch=True, ) if ( @@ -352,11 +368,19 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log('test_loss', loss, batch_size=batch_size) + self.log( + 'test_loss', + loss, + batch_size=batch_size, + on_step=False, + on_epoch=True, + ) self.test_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.test_metrics.compute().items()}, batch_size=batch_size, + on_step=False, + on_epoch=True, ) def predict_step(