From 3209924b2ceb521609da0bbbd120b1ffb4348506 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Wed, 15 Jan 2025 14:00:28 +0100 Subject: [PATCH] Fix falsy commited file --- darts-segmentation/src/darts_segmentation/segment.py | 4 +++- darts-segmentation/src/darts_segmentation/training/module.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/darts-segmentation/src/darts_segmentation/segment.py b/darts-segmentation/src/darts_segmentation/segment.py index 16c109d..545534d 100644 --- a/darts-segmentation/src/darts_segmentation/segment.py +++ b/darts-segmentation/src/darts_segmentation/segment.py @@ -68,7 +68,9 @@ def __init__(self, model_checkpoint: Path | str, device: torch.device = DEFAULT_ self.device = device ckpt = torch.load(model_checkpoint, map_location=self.device) self.config = validate_config(ckpt["config"]) - self.model = smp.create_model(**self.config["model"], encoder_weights=None) + # Overwrite the encoder weights with None, because we load our own + self.config["model"] |= {"encoder_weights": None} + self.model = smp.create_model(**self.config["model"]) self.model.to(self.device) self.model.load_state_dict(ckpt["statedict"]) self.model.eval() diff --git a/darts-segmentation/src/darts_segmentation/training/module.py b/darts-segmentation/src/darts_segmentation/training/module.py index 1af4d97..7f86ea8 100644 --- a/darts-segmentation/src/darts_segmentation/training/module.py +++ b/darts-segmentation/src/darts_segmentation/training/module.py @@ -131,7 +131,7 @@ def validation_step(self, batch, batch_idx): # noqa: D102 # Create figures for the samples (plot at maximum 24) is_last_batch = self.trainer.num_val_batches == (batch_idx + 1) - max_batch_idx = 6 # Does only work if NOT last batch, since last batch may be smaller + max_batch_idx = (24 // x.shape[0]) - 1 # Does only work if NOT last batch, since last batch may be smaller # If num_val_batches is 1 then this batch is the last one, but we still want to log it. despite its size # Will plot the first 24 samples of the first batch if batch-size is larger than 24 should_log_batch = (