Skip to content

Commit

Permalink
Add tests for the patched inference
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Oct 15, 2024
1 parent 0416ce8 commit 8692b16
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 84 deletions.
86 changes: 2 additions & 84 deletions darts-segmentation/src/darts_segmentation/segment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Functionality for segmenting tiles."""

from collections.abc import Callable, Generator
from pathlib import Path
from typing import Any, TypedDict

Expand All @@ -9,6 +8,8 @@
import torch.nn as nn
import xarray as xr

from darts_segmentation.utils import predict_in_patches


class SMPSegmenterConfig(TypedDict):
"""Configuration for the segmentor."""
Expand All @@ -18,89 +19,6 @@ class SMPSegmenterConfig(TypedDict):
# patch_size: int


def patch_coords(h: int, w: int, patch_size: int, margin_size: int) -> Generator[tuple[int, int, int, int], None, None]:
"""Yield patch coordinates based on height, width, patch size and margin size.
Args:
h (int): Height of the image.
w (int): Width of the image.
patch_size (int): Patch size.
margin_size (int): Margin size.
Yields:
tuple[int, int, int, int]: The patch coordinates y, x, patch_idx_h and patch_idx_w.
"""
step_size = patch_size - margin_size
for y in range(0, h, step_size):
for x in range(0, w, step_size):
if y + patch_size > h:
y = h - patch_size
if x + patch_size > w:
x = w - patch_size
patch_idx_h = y // step_size
patch_idx_w = x // step_size
yield y, x, patch_idx_h, patch_idx_w


@torch.no_grad()
def predict_in_patches(
model: Callable, tensor_tiles: torch.Tensor, patch_size: int = 1024, margin_size: int = 16
) -> torch.Tensor:
"""Predict on a tensor.
Args:
model: The model to use for prediction.
tensor_tiles: The input tensor. Shape: (BS, C, H, W).
patch_size (int): The size of the patches. Defaults to 1024.
margin_size (int): The size of the margin. Defaults to 16.
Returns:
The predicted tensor.
"""
assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
bs, c, h, w = tensor_tiles.shape
step_size = patch_size - margin_size
nh, nw = h // step_size, w // step_size

# Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
patches = torch.zeros((bs, nh, nw, c, patch_size, patch_size), device=tensor_tiles.device)
for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, margin_size):
patches[:, patch_idx_h, patch_idx_w, :] = tensor_tiles[:, :, y : y + patch_size, x : x + patch_size]

# Flatten the patches so they fit to the model
# (BS, N_h, N_w, C, patch_size, patch_size) -> (BS * N_h * N_w, C, patch_size, patch_size)
patches = patches.view(bs * nh * nw, c, patch_size, patch_size)

# Create a soft margin for the patches
margin_ramp = torch.cat(
[
torch.linspace(0, 1, margin_size),
torch.ones(patch_size - 2 * margin_size),
torch.linspace(1, 0, margin_size),
]
)
soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1)

# Infer logits with model and turn into probabilities with sigmoid
patched_logits = model(patches)
patched_probabilities = torch.sigmoid(patched_logits)

# Reconstruct the image from the patches
prediction = torch.zeros(bs, h, w, device=tensor_tiles.device)
weights = torch.zeros(bs, h, w, device=tensor_tiles.device)

for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, margin_size):
patch = patched_probabilities[patch_idx_h, patch_idx_w]
prediction[:, y : y + patch_size, x : x + patch_size] += patch * soft_margin
weights[:, y : y + patch_size, x : x + patch_size] += soft_margin

# Avoid division by zero
weights = torch.where(weights == 0, torch.ones_like(weights), weights)
return prediction / weights


class SMPSegmenter:
"""An actor that keeps a model as its state and segments tiles."""

Expand Down
129 changes: 129 additions & 0 deletions darts-segmentation/src/darts_segmentation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Shared utilities for the inference modules."""

import math
from collections.abc import Callable, Generator

import torch


def patch_coords(h: int, w: int, patch_size: int, overlap: int) -> Generator[tuple[int, int, int, int], None, None]:
"""Yield patch coordinates based on height, width, patch size and margin size.
Args:
h (int): Height of the image.
w (int): Width of the image.
patch_size (int): Patch size.
overlap (int): Margin size.
Yields:
tuple[int, int, int, int]: The patch coordinates y, x, patch_idx_y and patch_idx_x.
"""
step_size = patch_size - overlap
for patch_idx_y, y in enumerate(range(0, h, step_size)):
for patch_idx_x, x in enumerate(range(0, w, step_size)):
if y + patch_size > h:
y = h - patch_size
if x + patch_size > w:
x = w - patch_size
yield y, x, patch_idx_y, patch_idx_x


@torch.no_grad()
def create_patches(
tensor_tiles: torch.Tensor, patch_size: int, overlap: int, return_coords: bool = False
) -> torch.Tensor:
"""Create patches from a tensor.
Args:
tensor_tiles (torch.Tensor): The input tensor. Shape: (BS, C, H, W).
patch_size (int, optional): The size of the patches.
overlap (int, optional): The size of the overlap.
return_coords (bool, optional): Whether to return the coordinates of the patches.
Can be used for debugging. Defaults to False.
Returns:
torch.Tensor: The patches. Shape: (BS, N_h, N_w, C, patch_size, patch_size).
"""
assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
bs, c, h, w = tensor_tiles.shape
assert h > patch_size > overlap
assert w > patch_size > overlap

step_size = patch_size - overlap
nh, nw = math.ceil(h / step_size), math.ceil(w / step_size)
# Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
patches = torch.zeros((bs, nh, nw, c, patch_size, patch_size), device=tensor_tiles.device)
coords = torch.zeros((nh, nw, 5))
for i, (y, x, patch_idx_h, patch_idx_w) in enumerate(patch_coords(h, w, patch_size, overlap)):
patches[:, patch_idx_h, patch_idx_w, :] = tensor_tiles[:, :, y : y + patch_size, x : x + patch_size]
coords[patch_idx_h, patch_idx_w, :] = torch.tensor([i, y, x, patch_idx_h, patch_idx_w])
if return_coords:
return patches, coords
else:
return patches


@torch.no_grad()
def predict_in_patches(
model: Callable, tensor_tiles: torch.Tensor, patch_size: int = 1024, overlap: int = 16
) -> torch.Tensor:
"""Predict on a tensor.
Args:
model: The model to use for prediction.
tensor_tiles: The input tensor. Shape: (BS, C, H, W).
patch_size (int): The size of the patches. Defaults to 1024.
overlap (int): The size of the overlap. Defaults to 16.
Returns:
The predicted tensor.
"""
assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
# Add a 1px border to avoid pixel loss when applying the soft margin
tensor_tiles = torch.nn.functional.pad(tensor_tiles, (1, 1, 1, 1), mode="reflect")
bs, c, h, w = tensor_tiles.shape
step_size = patch_size - overlap
nh, nw = math.ceil(h / step_size), math.ceil(w / step_size)

# Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)

print(patches.shape)
# Flatten the patches so they fit to the model
# (BS, N_h, N_w, C, patch_size, patch_size) -> (BS * N_h * N_w, C, patch_size, patch_size)
patches = patches.view(bs * nh * nw, c, patch_size, patch_size)

# Create a soft margin for the patches
margin_ramp = torch.cat(
[
torch.linspace(0, 1, overlap),
torch.ones(patch_size - 2 * overlap),
torch.linspace(1, 0, overlap),
]
)
soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1)

# Infer logits with model and turn into probabilities with sigmoid
patched_probabilities = torch.sigmoid(model(patches)).squeeze(1)

patched_probabilities = patched_probabilities.view(bs, nh, nw, patch_size, patch_size)

# Reconstruct the image from the patches
prediction = torch.zeros(bs, h, w, device=tensor_tiles.device)
weights = torch.zeros(bs, h, w, device=tensor_tiles.device)

for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, overlap):
patch = patched_probabilities[:, patch_idx_h, patch_idx_w]
prediction[:, y : y + patch_size, x : x + patch_size] += patch * soft_margin
weights[:, y : y + patch_size, x : x + patch_size] += soft_margin

# Avoid division by zero
weights = torch.where(weights == 0, torch.ones_like(weights), weights)
prediction = prediction / weights

# Remove the 1px border
prediction = prediction[:, 1:-1, 1:-1]
return prediction
122 changes: 122 additions & 0 deletions darts-segmentation/tests/test_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Tests for the utility functions used for patched and stacked prediction."""

import math

import pytest
import torch

from darts_segmentation.utils import create_patches, patch_coords, predict_in_patches


@pytest.mark.parametrize("size", [10, 60, 500, 2000])
@pytest.mark.parametrize("patch_size", [8, 64, 1024])
@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64])
def test_patch_prediction(size: int, patch_size: int, overlap: int):
"""Tests the prediction function with a mock model (*2) and a random tensor."""
# Skip tests for invalid parameter to be able to to larger sweeps
if not size > patch_size > overlap:
return

def model(x):
return 2 * x

h, w = size, size
tensor_tiles = torch.rand((3, 1, h, w))
prediction = predict_in_patches(model, tensor_tiles, patch_size=patch_size, overlap=overlap)
prediction_true = torch.sigmoid(2 * tensor_tiles).squeeze(1)
assert prediction.shape == (3, h, w)
torch.testing.assert_allclose(prediction, prediction_true)


@pytest.mark.parametrize("size", [10, 60, 500, 2000])
@pytest.mark.parametrize("patch_size", [8, 64, 1024])
@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64])
def test_create_patches(size: int, patch_size: int, overlap: int):
"""Tests the creation of patches."""
# Skip tests for invalid parameter to be able to to larger sweeps
if not size > patch_size > overlap:
return

h, w = size, size
tensor_tiles = torch.rand((3, 1, h, w))
patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)
n_patches_h = math.ceil(h / (patch_size - overlap))
n_patches_w = math.ceil(h / (patch_size - overlap))
assert patches.shape == (3, n_patches_h, n_patches_w, 1, patch_size, patch_size)

step_size = patch_size - overlap
for i in range(n_patches_h):
for j in range(n_patches_w):
ipx = i * step_size
jpx = j * step_size
if ipx + patch_size > h:
ipx = h - patch_size
if jpx + patch_size > w:
jpx = w - patch_size
patch = patches[:, i, j]
true_patch = tensor_tiles[:, :, ipx : ipx + patch_size, jpx : jpx + patch_size]
assert patch.shape == (3, 1, patch_size, patch_size)
assert torch.allclose(patch, true_patch)


def test_patch_coords_example_generator():
"""Tests the generation of the generation of patch-coordinates.
Tests the first 20 patch-coordinates for a tile fo size 60x60px with a patch-size of 8 and an overlap of 3.
"""
expected = [
(0, (0, 0, 0, 0)),
(1, (0, 5, 0, 1)),
(2, (0, 10, 0, 2)),
(3, (0, 15, 0, 3)),
(4, (0, 20, 0, 4)),
(5, (0, 25, 0, 5)),
(6, (0, 30, 0, 6)),
(7, (0, 35, 0, 7)),
(8, (0, 40, 0, 8)),
(9, (0, 45, 0, 9)),
(10, (0, 50, 0, 10)),
(11, (0, 52, 0, 11)),
(12, (5, 0, 1, 0)),
(13, (5, 5, 1, 1)),
(14, (5, 10, 1, 2)),
(15, (5, 15, 1, 3)),
(16, (5, 20, 1, 4)),
(17, (5, 25, 1, 5)),
(18, (5, 30, 1, 6)),
(19, (5, 35, 1, 7)),
]
actual = list(enumerate(patch_coords(60, 60, 8, 3)))[:20]
for expected_coords, actual_coords in zip(expected, actual):
n_exp, (y_exp, x_exp, patch_idx_y_exp, patch_idx_x_exp) = expected_coords
n_act, (y_act, x_act, patch_idx_y_act, patch_idx_x_act) = actual_coords

assert n_exp == n_act
assert y_exp == y_act
assert x_exp == x_act
assert patch_idx_y_exp == patch_idx_y_act
assert patch_idx_x_exp == patch_idx_x_act


@pytest.mark.parametrize("size", [10, 60, 500, 2000])
@pytest.mark.parametrize("patch_size", [8, 64, 1024])
@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64])
def test_patch_coords_generator_logical(size: int, patch_size: int, overlap: int):
"""Tests the generation of the generation of patch-coordinates.
Tests the first 20 patch-coordinates for a tile fo size 60x60px with a patch-size of 8 and an overlap of 3.
"""
# Skip tests for invalid parameter to be able to to larger sweeps
if not size > patch_size > overlap:
return

coords = list(enumerate(patch_coords(size, size, patch_size, overlap)))
n_patches_h = math.ceil(size / (patch_size - overlap))
for n, (y, x, patch_idx_y, patch_idx_x) in coords:
assert y >= 0
assert x >= 0
assert patch_idx_y >= 0
assert patch_idx_x >= 0
assert y + patch_size <= size
assert x + patch_size <= size
assert n == patch_idx_y * n_patches_h + patch_idx_x

0 comments on commit 8692b16

Please sign in to comment.