Skip to content

Commit

Permalink
Fix falsy commited file
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Jan 15, 2025
1 parent 15f045c commit 3209924
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion darts-segmentation/src/darts_segmentation/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def __init__(self, model_checkpoint: Path | str, device: torch.device = DEFAULT_
self.device = device
ckpt = torch.load(model_checkpoint, map_location=self.device)
self.config = validate_config(ckpt["config"])
self.model = smp.create_model(**self.config["model"], encoder_weights=None)
# Overwrite the encoder weights with None, because we load our own
self.config["model"] |= {"encoder_weights": None}
self.model = smp.create_model(**self.config["model"])
self.model.to(self.device)
self.model.load_state_dict(ckpt["statedict"])
self.model.eval()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def validation_step(self, batch, batch_idx): # noqa: D102

# Create figures for the samples (plot at maximum 24)
is_last_batch = self.trainer.num_val_batches == (batch_idx + 1)
max_batch_idx = 6 # Does only work if NOT last batch, since last batch may be smaller
max_batch_idx = (24 // x.shape[0]) - 1 # Does only work if NOT last batch, since last batch may be smaller
# If num_val_batches is 1 then this batch is the last one, but we still want to log it. despite its size
# Will plot the first 24 samples of the first batch if batch-size is larger than 24
should_log_batch = (
Expand Down

0 comments on commit 3209924

Please sign in to comment.