From 8692b16b1e057fd1908d01b19baa3a3508596078 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20H=C3=B6lzer?=
Date: Tue, 15 Oct 2024 17:42:44 +0200
Subject: [PATCH] Add tests for the patched inference

---
 .../src/darts_segmentation/segment.py         |  86 +-----------
 .../src/darts_segmentation/utils.py           | 129 ++++++++++++++++++
 darts-segmentation/tests/test_patches.py      | 122 +++++++++++++++++
 3 files changed, 253 insertions(+), 84 deletions(-)
 create mode 100644 darts-segmentation/src/darts_segmentation/utils.py
 create mode 100644 darts-segmentation/tests/test_patches.py

diff --git a/darts-segmentation/src/darts_segmentation/segment.py b/darts-segmentation/src/darts_segmentation/segment.py
index dfd84a1..01e317a 100644
--- a/darts-segmentation/src/darts_segmentation/segment.py
+++ b/darts-segmentation/src/darts_segmentation/segment.py
@@ -1,6 +1,5 @@
 """Functionality for segmenting tiles."""
 
-from collections.abc import Callable, Generator
 from pathlib import Path
 from typing import Any, TypedDict
 
@@ -9,6 +8,8 @@
 import torch.nn as nn
 import xarray as xr
 
+from darts_segmentation.utils import predict_in_patches
+
 
 class SMPSegmenterConfig(TypedDict):
     """Configuration for the segmentor."""
@@ -18,89 +19,6 @@ class SMPSegmenterConfig(TypedDict):
     # patch_size: int
 
 
-def patch_coords(h: int, w: int, patch_size: int, margin_size: int) -> Generator[tuple[int, int, int, int], None, None]:
-    """Yield patch coordinates based on height, width, patch size and margin size.
-
-    Args:
-        h (int): Height of the image.
-        w (int): Width of the image.
-        patch_size (int): Patch size.
-        margin_size (int): Margin size.
-
-    Yields:
-        tuple[int, int, int, int]: The patch coordinates y, x, patch_idx_h and patch_idx_w.
-
-    """
-    step_size = patch_size - margin_size
-    for y in range(0, h, step_size):
-        for x in range(0, w, step_size):
-            if y + patch_size > h:
-                y = h - patch_size
-            if x + patch_size > w:
-                x = w - patch_size
-            patch_idx_h = y // step_size
-            patch_idx_w = x // step_size
-            yield y, x, patch_idx_h, patch_idx_w
-
-
-@torch.no_grad()
-def predict_in_patches(
-    model: Callable, tensor_tiles: torch.Tensor, patch_size: int = 1024, margin_size: int = 16
-) -> torch.Tensor:
-    """Predict on a tensor.
-
-    Args:
-        model: The model to use for prediction.
-        tensor_tiles: The input tensor. Shape: (BS, C, H, W).
-        patch_size (int): The size of the patches. Defaults to 1024.
-        margin_size (int): The size of the margin. Defaults to 16.
-
-    Returns:
-        The predicted tensor.
-
-    """
-    assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
-    bs, c, h, w = tensor_tiles.shape
-    step_size = patch_size - margin_size
-    nh, nw = h // step_size, w // step_size
-
-    # Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
-    patches = torch.zeros((bs, nh, nw, c, patch_size, patch_size), device=tensor_tiles.device)
-    for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, margin_size):
-        patches[:, patch_idx_h, patch_idx_w, :] = tensor_tiles[:, :, y : y + patch_size, x : x + patch_size]
-
-    # Flatten the patches so they fit to the model
-    # (BS, N_h, N_w, C, patch_size, patch_size) -> (BS * N_h * N_w, C, patch_size, patch_size)
-    patches = patches.view(bs * nh * nw, c, patch_size, patch_size)
-
-    # Create a soft margin for the patches
-    margin_ramp = torch.cat(
-        [
-            torch.linspace(0, 1, margin_size),
-            torch.ones(patch_size - 2 * margin_size),
-            torch.linspace(1, 0, margin_size),
-        ]
-    )
-    soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1)
-
-    # Infer logits with model and turn into probabilities with sigmoid
-    patched_logits = model(patches)
-    patched_probabilities = torch.sigmoid(patched_logits)
-
-    # Reconstruct the image from the patches
-    prediction = torch.zeros(bs, h, w, device=tensor_tiles.device)
-    weights = torch.zeros(bs, h, w, device=tensor_tiles.device)
-
-    for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, margin_size):
-        patch = patched_probabilities[patch_idx_h, patch_idx_w]
-        prediction[:, y : y + patch_size, x : x + patch_size] += patch * soft_margin
-        weights[:, y : y + patch_size, x : x + patch_size] += soft_margin
-
-    # Avoid division by zero
-    weights = torch.where(weights == 0, torch.ones_like(weights), weights)
-    return prediction / weights
-
-
 class SMPSegmenter:
     """An actor that keeps a model as its state and segments tiles."""
 
diff --git a/darts-segmentation/src/darts_segmentation/utils.py b/darts-segmentation/src/darts_segmentation/utils.py
new file mode 100644
index 0000000..d1940d9
--- /dev/null
+++ b/darts-segmentation/src/darts_segmentation/utils.py
@@ -0,0 +1,129 @@
+"""Shared utilities for the inference modules."""
+
+import math
+from collections.abc import Callable, Generator
+
+import torch
+
+
+def patch_coords(h: int, w: int, patch_size: int, overlap: int) -> Generator[tuple[int, int, int, int], None, None]:
+    """Yield patch coordinates based on height, width, patch size and overlap.
+
+    Args:
+        h (int): Height of the image.
+        w (int): Width of the image.
+        patch_size (int): Patch size.
+        overlap (int): Overlap between neighbouring patches.
+
+    Yields:
+        tuple[int, int, int, int]: The patch coordinates y, x, patch_idx_y and patch_idx_x.
+
+    """
+    step_size = patch_size - overlap
+    for patch_idx_y, y in enumerate(range(0, h, step_size)):
+        for patch_idx_x, x in enumerate(range(0, w, step_size)):
+            # Shift patches which would exceed the image back inside, so all patches have the same size
+            if y + patch_size > h:
+                y = h - patch_size
+            if x + patch_size > w:
+                x = w - patch_size
+            yield y, x, patch_idx_y, patch_idx_x
+
+
+@torch.no_grad()
+def create_patches(
+    tensor_tiles: torch.Tensor, patch_size: int, overlap: int, return_coords: bool = False
+) -> torch.Tensor:
+    """Create patches from a tensor.
+
+    Args:
+        tensor_tiles (torch.Tensor): The input tensor. Shape: (BS, C, H, W).
+        patch_size (int): The size of the patches.
+        overlap (int): The size of the overlap.
+        return_coords (bool, optional): Whether to return the coordinates of the patches.
+            Can be used for debugging. Defaults to False.
+
+    Returns:
+        torch.Tensor: The patches. Shape: (BS, N_h, N_w, C, patch_size, patch_size).
+            If return_coords is True, a tuple of the patches and their coordinates (N_h, N_w, 5) is returned.
+
+    """
+    assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
+    bs, c, h, w = tensor_tiles.shape
+    assert h > patch_size > overlap
+    assert w > patch_size > overlap
+
+    step_size = patch_size - overlap
+    nh, nw = math.ceil(h / step_size), math.ceil(w / step_size)
+    # Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
+    patches = torch.zeros((bs, nh, nw, c, patch_size, patch_size), device=tensor_tiles.device)
+    coords = torch.zeros((nh, nw, 5))
+    for i, (y, x, patch_idx_h, patch_idx_w) in enumerate(patch_coords(h, w, patch_size, overlap)):
+        patches[:, patch_idx_h, patch_idx_w, :] = tensor_tiles[:, :, y : y + patch_size, x : x + patch_size]
+        coords[patch_idx_h, patch_idx_w, :] = torch.tensor([i, y, x, patch_idx_h, patch_idx_w])
+    if return_coords:
+        return patches, coords
+    else:
+        return patches
+
+
+@torch.no_grad()
+def predict_in_patches(
+    model: Callable, tensor_tiles: torch.Tensor, patch_size: int = 1024, overlap: int = 16
+) -> torch.Tensor:
+    """Predict on a tensor.
+
+    Args:
+        model: The model to use for prediction.
+        tensor_tiles: The input tensor. Shape: (BS, C, H, W).
+        patch_size (int): The size of the patches. Defaults to 1024.
+        overlap (int): The size of the overlap. Defaults to 16.
+
+    Returns:
+        The predicted tensor.
+
+    """
+    assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
+    # Add a 1px border to avoid pixel loss when applying the soft margin
+    tensor_tiles = torch.nn.functional.pad(tensor_tiles, (1, 1, 1, 1), mode="reflect")
+    bs, c, h, w = tensor_tiles.shape
+    step_size = patch_size - overlap
+    nh, nw = math.ceil(h / step_size), math.ceil(w / step_size)
+
+    # Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
+    patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)
+
+    # Flatten the patches so they fit to the model
+    # (BS, N_h, N_w, C, patch_size, patch_size) -> (BS * N_h * N_w, C, patch_size, patch_size)
+    patches = patches.view(bs * nh * nw, c, patch_size, patch_size)
+
+    # Create a soft margin for the patches
+    margin_ramp = torch.cat(
+        [
+            torch.linspace(0, 1, overlap, device=tensor_tiles.device),
+            torch.ones(patch_size - 2 * overlap, device=tensor_tiles.device),
+            torch.linspace(1, 0, overlap, device=tensor_tiles.device),
+        ]
+    )
+    soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1)
+
+    # Infer logits with model and turn into probabilities with sigmoid
+    patched_probabilities = torch.sigmoid(model(patches)).squeeze(1)
+    patched_probabilities = patched_probabilities.view(bs, nh, nw, patch_size, patch_size)
+
+    # Reconstruct the image from the patches
+    prediction = torch.zeros(bs, h, w, device=tensor_tiles.device)
+    weights = torch.zeros(bs, h, w, device=tensor_tiles.device)
+
+    for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, overlap):
+        patch = patched_probabilities[:, patch_idx_h, patch_idx_w]
+        prediction[:, y : y + patch_size, x : x + patch_size] += patch * soft_margin
+        weights[:, y : y + patch_size, x : x + patch_size] += soft_margin
+
+    # Avoid division by zero
+    weights = torch.where(weights == 0, torch.ones_like(weights), weights)
+    prediction = prediction / weights
+
+    # Remove the 1px border
+    prediction = prediction[:, 1:-1, 1:-1]
+    return prediction
diff --git a/darts-segmentation/tests/test_patches.py b/darts-segmentation/tests/test_patches.py
new file mode 100644
index 0000000..8acc18d
--- /dev/null
+++ b/darts-segmentation/tests/test_patches.py
@@ -0,0 +1,122 @@
+"""Tests for the utility functions used for patched and stacked prediction."""
+
+import math
+
+import pytest
+import torch
+
+from darts_segmentation.utils import create_patches, patch_coords, predict_in_patches
+
+
+@pytest.mark.parametrize("size", [10, 60, 500, 2000])
+@pytest.mark.parametrize("patch_size", [8, 64, 1024])
+@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64])
+def test_patch_prediction(size: int, patch_size: int, overlap: int):
+    """Tests the prediction function with a mock model (*2) and a random tensor."""
+    # Skip tests for invalid parameter to be able to do larger sweeps
+    if not size > patch_size > overlap:
+        return
+
+    def model(x):
+        return 2 * x
+
+    h, w = size, size
+    tensor_tiles = torch.rand((3, 1, h, w))
+    prediction = predict_in_patches(model, tensor_tiles, patch_size=patch_size, overlap=overlap)
+    prediction_true = torch.sigmoid(2 * tensor_tiles).squeeze(1)
+    assert prediction.shape == (3, h, w)
+    torch.testing.assert_close(prediction, prediction_true)
+
+
+@pytest.mark.parametrize("size", [10, 60, 500, 2000])
+@pytest.mark.parametrize("patch_size", [8, 64, 1024])
+@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64])
+def test_create_patches(size: int, patch_size: int, overlap: int):
+    """Tests the creation of patches."""
+    # Skip tests for invalid parameter to be able to do larger sweeps
+    if not size > patch_size > overlap:
+        return
+
+    h, w = size, size
+    tensor_tiles = torch.rand((3, 1, h, w))
+    patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)
+    n_patches_h = math.ceil(h / (patch_size - overlap))
+    n_patches_w = math.ceil(w / (patch_size - overlap))
+    assert patches.shape == (3, n_patches_h, n_patches_w, 1, patch_size, patch_size)
+
+    step_size = patch_size - overlap
+    for i in range(n_patches_h):
+        for j in range(n_patches_w):
+            ipx = i * step_size
+            jpx = j * step_size
+            if ipx + patch_size > h:
+                ipx = h - patch_size
+            if jpx + patch_size > w:
+                jpx = w - patch_size
+            patch = patches[:, i, j]
+            true_patch = tensor_tiles[:, :, ipx : ipx + patch_size, jpx : jpx + patch_size]
+            assert patch.shape == (3, 1, patch_size, patch_size)
+            assert torch.allclose(patch, true_patch)
+
+
+def test_patch_coords_example_generator():
+    """Tests the generation of patch-coordinates against a fixed example.
+
+    Tests the first 20 patch-coordinates for a tile of size 60x60px with a patch-size of 8 and an overlap of 3.
+    """
+    expected = [
+        (0, (0, 0, 0, 0)),
+        (1, (0, 5, 0, 1)),
+        (2, (0, 10, 0, 2)),
+        (3, (0, 15, 0, 3)),
+        (4, (0, 20, 0, 4)),
+        (5, (0, 25, 0, 5)),
+        (6, (0, 30, 0, 6)),
+        (7, (0, 35, 0, 7)),
+        (8, (0, 40, 0, 8)),
+        (9, (0, 45, 0, 9)),
+        (10, (0, 50, 0, 10)),
+        (11, (0, 52, 0, 11)),
+        (12, (5, 0, 1, 0)),
+        (13, (5, 5, 1, 1)),
+        (14, (5, 10, 1, 2)),
+        (15, (5, 15, 1, 3)),
+        (16, (5, 20, 1, 4)),
+        (17, (5, 25, 1, 5)),
+        (18, (5, 30, 1, 6)),
+        (19, (5, 35, 1, 7)),
+    ]
+    actual = list(enumerate(patch_coords(60, 60, 8, 3)))[:20]
+    for expected_coords, actual_coords in zip(expected, actual):
+        n_exp, (y_exp, x_exp, patch_idx_y_exp, patch_idx_x_exp) = expected_coords
+        n_act, (y_act, x_act, patch_idx_y_act, patch_idx_x_act) = actual_coords
+
+        assert n_exp == n_act
+        assert y_exp == y_act
+        assert x_exp == x_act
+        assert patch_idx_y_exp == patch_idx_y_act
+        assert patch_idx_x_exp == patch_idx_x_act
+
+
+@pytest.mark.parametrize("size", [10, 60, 500, 2000])
+@pytest.mark.parametrize("patch_size", [8, 64, 1024])
+@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64])
+def test_patch_coords_generator_logical(size: int, patch_size: int, overlap: int):
+    """Tests the logical consistency of the generated patch-coordinates.
+
+    Checks that all patches lie within the tile and that the patches are enumerated in row-major order.
+    """
+    # Skip tests for invalid parameter to be able to do larger sweeps
+    if not size > patch_size > overlap:
+        return
+
+    coords = list(enumerate(patch_coords(size, size, patch_size, overlap)))
+    n_patches_w = math.ceil(size / (patch_size - overlap))
+    for n, (y, x, patch_idx_y, patch_idx_x) in coords:
+        assert y >= 0
+        assert x >= 0
+        assert patch_idx_y >= 0
+        assert patch_idx_x >= 0
+        assert y + patch_size <= size
+        assert x + patch_size <= size
+        assert n == patch_idx_y * n_patches_w + patch_idx_x