Skip to content

Commit

Permalink
Fix duplicate patches if size is matching
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Oct 15, 2024
1 parent 8692b16 commit 1caf3d4
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 20 deletions.
21 changes: 16 additions & 5 deletions darts-segmentation/src/darts_segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def patch_coords(h: int, w: int, patch_size: int, overlap: int) -> Generator[tup
"""
step_size = patch_size - overlap
for patch_idx_y, y in enumerate(range(0, h, step_size)):
for patch_idx_x, x in enumerate(range(0, w, step_size)):
# Substract the overlap from h and w so that an exact match of the last patch won't create a duplicate
for patch_idx_y, y in enumerate(range(0, h - overlap, step_size)):
for patch_idx_x, x in enumerate(range(0, w - overlap, step_size)):
if y + patch_size > h:
y = h - patch_size
if x + patch_size > w:
Expand Down Expand Up @@ -52,7 +53,17 @@ def create_patches(
assert w > patch_size > overlap

step_size = patch_size - overlap
nh, nw = math.ceil(h / step_size), math.ceil(w / step_size)

# The problem with unfold is that is cuts off the last patch if it doesn't fit exactly
# Padding could help, but then the next problem is that the view needs to get reshaped (copied in memory)
# to fit the model input shape. Such a complex view can't be inserted into the model.
# Since we need, doing it manually is currently our best choice, since be can avoid the padding.
# patches = (
# tensor_tiles.unfold(2, patch_size, step_size).unfold(3, patch_size, step_size).transpose(1, 2).transpose(2, 3)
# )
# return patches

nh, nw = math.ceil((h - overlap) / step_size), math.ceil((w - overlap) / step_size)
# Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
patches = torch.zeros((bs, nh, nw, c, patch_size, patch_size), device=tensor_tiles.device)
coords = torch.zeros((nh, nw, 5))
Expand Down Expand Up @@ -86,7 +97,7 @@ def predict_in_patches(
tensor_tiles = torch.nn.functional.pad(tensor_tiles, (1, 1, 1, 1), mode="reflect")
bs, c, h, w = tensor_tiles.shape
step_size = patch_size - overlap
nh, nw = math.ceil(h / step_size), math.ceil(w / step_size)
nh, nw = math.ceil((h - overlap) / step_size), math.ceil((w - overlap) / step_size)

# Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)
Expand Down Expand Up @@ -124,6 +135,6 @@ def predict_in_patches(
weights = torch.where(weights == 0, torch.ones_like(weights), weights)
prediction = prediction / weights

# Remove the 1px border
# Remove the 1px border and the padding
prediction = prediction[:, 1:-1, 1:-1]
return prediction
34 changes: 19 additions & 15 deletions darts-segmentation/tests/test_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@

from darts_segmentation.utils import create_patches, patch_coords, predict_in_patches

test_sizes = [10, 23, 60, 2000]
test_patch_sizes = [8, 64, 1024]
test_overlaps = [0, 1, 3, 16, 64]

@pytest.mark.parametrize("size", [10, 60, 500, 2000])
@pytest.mark.parametrize("patch_size", [8, 64, 1024])
@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64])

@pytest.mark.parametrize("size", test_sizes)
@pytest.mark.parametrize("patch_size", test_patch_sizes)
@pytest.mark.parametrize("overlap", test_overlaps)
def test_patch_prediction(size: int, patch_size: int, overlap: int):
"""Tests the prediction function with a mock model (*2) and a random tensor."""
# Skip tests for invalid parameter to be able to to larger sweeps
if not size > patch_size > overlap:
return
pytest.skip("unsupported configuration")

def model(x):
return 2 * x
Expand All @@ -28,20 +32,20 @@ def model(x):
torch.testing.assert_allclose(prediction, prediction_true)


@pytest.mark.parametrize("size", [10, 60, 500, 2000])
@pytest.mark.parametrize("patch_size", [8, 64, 1024])
@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64])
@pytest.mark.parametrize("size", test_sizes)
@pytest.mark.parametrize("patch_size", test_patch_sizes)
@pytest.mark.parametrize("overlap", test_overlaps)
def test_create_patches(size: int, patch_size: int, overlap: int):
"""Tests the creation of patches."""
# Skip tests for invalid parameter to be able to to larger sweeps
if not size > patch_size > overlap:
return
pytest.skip("unsupported configuration")

h, w = size, size
tensor_tiles = torch.rand((3, 1, h, w))
patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)
n_patches_h = math.ceil(h / (patch_size - overlap))
n_patches_w = math.ceil(h / (patch_size - overlap))
n_patches_h = math.ceil((h - overlap) / (patch_size - overlap))
n_patches_w = math.ceil((w - overlap) / (patch_size - overlap))
assert patches.shape == (3, n_patches_h, n_patches_w, 1, patch_size, patch_size)

step_size = patch_size - overlap
Expand Down Expand Up @@ -98,20 +102,20 @@ def test_patch_coords_example_generator():
assert patch_idx_x_exp == patch_idx_x_act


@pytest.mark.parametrize("size", [10, 60, 500, 2000])
@pytest.mark.parametrize("patch_size", [8, 64, 1024])
@pytest.mark.parametrize("overlap", [0, 1, 3, 8, 16, 64])
@pytest.mark.parametrize("size", test_sizes)
@pytest.mark.parametrize("patch_size", test_patch_sizes)
@pytest.mark.parametrize("overlap", test_overlaps)
def test_patch_coords_generator_logical(size: int, patch_size: int, overlap: int):
"""Tests the generation of the generation of patch-coordinates.
Tests the first 20 patch-coordinates for a tile fo size 60x60px with a patch-size of 8 and an overlap of 3.
"""
# Skip tests for invalid parameter to be able to to larger sweeps
if not size > patch_size > overlap:
return
pytest.skip("unsupported configuration")

coords = list(enumerate(patch_coords(size, size, patch_size, overlap)))
n_patches_h = math.ceil(size / (patch_size - overlap))
n_patches_h = math.ceil((size - overlap) / (patch_size - overlap))
for n, (y, x, patch_idx_y, patch_idx_x) in coords:
assert y >= 0
assert x >= 0
Expand Down

0 comments on commit 1caf3d4

Please sign in to comment.