-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0416ce8
commit 8692b16
Showing
3 changed files
with
253 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
"""Shared utilities for the inference modules.""" | ||
|
||
import math | ||
from collections.abc import Callable, Generator | ||
|
||
import torch | ||
|
||
|
||
def patch_coords(h: int, w: int, patch_size: int, overlap: int) -> Generator[tuple[int, int, int, int], None, None]: | ||
"""Yield patch coordinates based on height, width, patch size and margin size. | ||
Args: | ||
h (int): Height of the image. | ||
w (int): Width of the image. | ||
patch_size (int): Patch size. | ||
overlap (int): Margin size. | ||
Yields: | ||
tuple[int, int, int, int]: The patch coordinates y, x, patch_idx_y and patch_idx_x. | ||
""" | ||
step_size = patch_size - overlap | ||
for patch_idx_y, y in enumerate(range(0, h, step_size)): | ||
for patch_idx_x, x in enumerate(range(0, w, step_size)): | ||
if y + patch_size > h: | ||
y = h - patch_size | ||
if x + patch_size > w: | ||
x = w - patch_size | ||
yield y, x, patch_idx_y, patch_idx_x | ||
|
||
|
||
@torch.no_grad() | ||
def create_patches( | ||
tensor_tiles: torch.Tensor, patch_size: int, overlap: int, return_coords: bool = False | ||
) -> torch.Tensor: | ||
"""Create patches from a tensor. | ||
Args: | ||
tensor_tiles (torch.Tensor): The input tensor. Shape: (BS, C, H, W). | ||
patch_size (int, optional): The size of the patches. | ||
overlap (int, optional): The size of the overlap. | ||
return_coords (bool, optional): Whether to return the coordinates of the patches. | ||
Can be used for debugging. Defaults to False. | ||
Returns: | ||
torch.Tensor: The patches. Shape: (BS, N_h, N_w, C, patch_size, patch_size). | ||
""" | ||
assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}" | ||
bs, c, h, w = tensor_tiles.shape | ||
assert h > patch_size > overlap | ||
assert w > patch_size > overlap | ||
|
||
step_size = patch_size - overlap | ||
nh, nw = math.ceil(h / step_size), math.ceil(w / step_size) | ||
# Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size) | ||
patches = torch.zeros((bs, nh, nw, c, patch_size, patch_size), device=tensor_tiles.device) | ||
coords = torch.zeros((nh, nw, 5)) | ||
for i, (y, x, patch_idx_h, patch_idx_w) in enumerate(patch_coords(h, w, patch_size, overlap)): | ||
patches[:, patch_idx_h, patch_idx_w, :] = tensor_tiles[:, :, y : y + patch_size, x : x + patch_size] | ||
coords[patch_idx_h, patch_idx_w, :] = torch.tensor([i, y, x, patch_idx_h, patch_idx_w]) | ||
if return_coords: | ||
return patches, coords | ||
else: | ||
return patches | ||
|
||
|
||
@torch.no_grad() | ||
def predict_in_patches( | ||
model: Callable, tensor_tiles: torch.Tensor, patch_size: int = 1024, overlap: int = 16 | ||
) -> torch.Tensor: | ||
"""Predict on a tensor. | ||
Args: | ||
model: The model to use for prediction. | ||
tensor_tiles: The input tensor. Shape: (BS, C, H, W). | ||
patch_size (int): The size of the patches. Defaults to 1024. | ||
overlap (int): The size of the overlap. Defaults to 16. | ||
Returns: | ||
The predicted tensor. | ||
""" | ||
assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}" | ||
# Add a 1px border to avoid pixel loss when applying the soft margin | ||
tensor_tiles = torch.nn.functional.pad(tensor_tiles, (1, 1, 1, 1), mode="reflect") | ||
bs, c, h, w = tensor_tiles.shape | ||
step_size = patch_size - overlap | ||
nh, nw = math.ceil(h / step_size), math.ceil(w / step_size) | ||
|
||
# Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size) | ||
patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap) | ||
|
||
print(patches.shape) | ||
# Flatten the patches so they fit to the model | ||
# (BS, N_h, N_w, C, patch_size, patch_size) -> (BS * N_h * N_w, C, patch_size, patch_size) | ||
patches = patches.view(bs * nh * nw, c, patch_size, patch_size) | ||
|
||
# Create a soft margin for the patches | ||
margin_ramp = torch.cat( | ||
[ | ||
torch.linspace(0, 1, overlap), | ||
torch.ones(patch_size - 2 * overlap), | ||
torch.linspace(1, 0, overlap), | ||
] | ||
) | ||
soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1) | ||
|
||
# Infer logits with model and turn into probabilities with sigmoid | ||
patched_probabilities = torch.sigmoid(model(patches)).squeeze(1) | ||
|
||
patched_probabilities = patched_probabilities.view(bs, nh, nw, patch_size, patch_size) | ||
|
||
# Reconstruct the image from the patches | ||
prediction = torch.zeros(bs, h, w, device=tensor_tiles.device) | ||
weights = torch.zeros(bs, h, w, device=tensor_tiles.device) | ||
|
||
for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, overlap): | ||
patch = patched_probabilities[:, patch_idx_h, patch_idx_w] | ||
prediction[:, y : y + patch_size, x : x + patch_size] += patch * soft_margin | ||
weights[:, y : y + patch_size, x : x + patch_size] += soft_margin | ||
|
||
# Avoid division by zero | ||
weights = torch.where(weights == 0, torch.ones_like(weights), weights) | ||
prediction = prediction / weights | ||
|
||
# Remove the 1px border | ||
prediction = prediction[:, 1:-1, 1:-1] | ||
return prediction |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""Tests for the utility functions used for patched and stacked prediction.""" | ||
|
||
import math | ||
|
||
import pytest | ||
import torch | ||
|
||
from darts_segmentation.utils import create_patches, patch_coords, predict_in_patches | ||
|
||
|
||
@pytest.mark.parametrize("size", [10, 60, 500, 2000]) | ||
@pytest.mark.parametrize("patch_size", [8, 64, 1024]) | ||
@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64]) | ||
def test_patch_prediction(size: int, patch_size: int, overlap: int): | ||
"""Tests the prediction function with a mock model (*2) and a random tensor.""" | ||
# Skip tests for invalid parameter to be able to to larger sweeps | ||
if not size > patch_size > overlap: | ||
return | ||
|
||
def model(x): | ||
return 2 * x | ||
|
||
h, w = size, size | ||
tensor_tiles = torch.rand((3, 1, h, w)) | ||
prediction = predict_in_patches(model, tensor_tiles, patch_size=patch_size, overlap=overlap) | ||
prediction_true = torch.sigmoid(2 * tensor_tiles).squeeze(1) | ||
assert prediction.shape == (3, h, w) | ||
torch.testing.assert_allclose(prediction, prediction_true) | ||
|
||
|
||
@pytest.mark.parametrize("size", [10, 60, 500, 2000]) | ||
@pytest.mark.parametrize("patch_size", [8, 64, 1024]) | ||
@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64]) | ||
def test_create_patches(size: int, patch_size: int, overlap: int): | ||
"""Tests the creation of patches.""" | ||
# Skip tests for invalid parameter to be able to to larger sweeps | ||
if not size > patch_size > overlap: | ||
return | ||
|
||
h, w = size, size | ||
tensor_tiles = torch.rand((3, 1, h, w)) | ||
patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap) | ||
n_patches_h = math.ceil(h / (patch_size - overlap)) | ||
n_patches_w = math.ceil(h / (patch_size - overlap)) | ||
assert patches.shape == (3, n_patches_h, n_patches_w, 1, patch_size, patch_size) | ||
|
||
step_size = patch_size - overlap | ||
for i in range(n_patches_h): | ||
for j in range(n_patches_w): | ||
ipx = i * step_size | ||
jpx = j * step_size | ||
if ipx + patch_size > h: | ||
ipx = h - patch_size | ||
if jpx + patch_size > w: | ||
jpx = w - patch_size | ||
patch = patches[:, i, j] | ||
true_patch = tensor_tiles[:, :, ipx : ipx + patch_size, jpx : jpx + patch_size] | ||
assert patch.shape == (3, 1, patch_size, patch_size) | ||
assert torch.allclose(patch, true_patch) | ||
|
||
|
||
def test_patch_coords_example_generator(): | ||
"""Tests the generation of the generation of patch-coordinates. | ||
Tests the first 20 patch-coordinates for a tile fo size 60x60px with a patch-size of 8 and an overlap of 3. | ||
""" | ||
expected = [ | ||
(0, (0, 0, 0, 0)), | ||
(1, (0, 5, 0, 1)), | ||
(2, (0, 10, 0, 2)), | ||
(3, (0, 15, 0, 3)), | ||
(4, (0, 20, 0, 4)), | ||
(5, (0, 25, 0, 5)), | ||
(6, (0, 30, 0, 6)), | ||
(7, (0, 35, 0, 7)), | ||
(8, (0, 40, 0, 8)), | ||
(9, (0, 45, 0, 9)), | ||
(10, (0, 50, 0, 10)), | ||
(11, (0, 52, 0, 11)), | ||
(12, (5, 0, 1, 0)), | ||
(13, (5, 5, 1, 1)), | ||
(14, (5, 10, 1, 2)), | ||
(15, (5, 15, 1, 3)), | ||
(16, (5, 20, 1, 4)), | ||
(17, (5, 25, 1, 5)), | ||
(18, (5, 30, 1, 6)), | ||
(19, (5, 35, 1, 7)), | ||
] | ||
actual = list(enumerate(patch_coords(60, 60, 8, 3)))[:20] | ||
for expected_coords, actual_coords in zip(expected, actual): | ||
n_exp, (y_exp, x_exp, patch_idx_y_exp, patch_idx_x_exp) = expected_coords | ||
n_act, (y_act, x_act, patch_idx_y_act, patch_idx_x_act) = actual_coords | ||
|
||
assert n_exp == n_act | ||
assert y_exp == y_act | ||
assert x_exp == x_act | ||
assert patch_idx_y_exp == patch_idx_y_act | ||
assert patch_idx_x_exp == patch_idx_x_act | ||
|
||
|
||
@pytest.mark.parametrize("size", [10, 60, 500, 2000]) | ||
@pytest.mark.parametrize("patch_size", [8, 64, 1024]) | ||
@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64]) | ||
def test_patch_coords_generator_logical(size: int, patch_size: int, overlap: int): | ||
"""Tests the generation of the generation of patch-coordinates. | ||
Tests the first 20 patch-coordinates for a tile fo size 60x60px with a patch-size of 8 and an overlap of 3. | ||
""" | ||
# Skip tests for invalid parameter to be able to to larger sweeps | ||
if not size > patch_size > overlap: | ||
return | ||
|
||
coords = list(enumerate(patch_coords(size, size, patch_size, overlap))) | ||
n_patches_h = math.ceil(size / (patch_size - overlap)) | ||
for n, (y, x, patch_idx_y, patch_idx_x) in coords: | ||
assert y >= 0 | ||
assert x >= 0 | ||
assert patch_idx_y >= 0 | ||
assert patch_idx_x >= 0 | ||
assert y + patch_size <= size | ||
assert x + patch_size <= size | ||
assert n == patch_idx_y * n_patches_h + patch_idx_x |