Skip to content

Commit

Permalink
Add support for 3d maps in inner_box.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 648681922
  • Loading branch information
mjanusz authored and copybara-github committed Jul 2, 2024
1 parent d4e3f5e commit 9a2546d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 17 deletions.
45 changes: 28 additions & 17 deletions map_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,11 @@ def to_relative(


def fill_missing(
coord_map: np.ndarray, *, extrapolate=False, invalid_to_zero=False,
interpolate_first=True
coord_map: np.ndarray,
*,
extrapolate=False,
invalid_to_zero=False,
interpolate_first=True,
) -> np.ndarray:
"""Fills missing entries in a coordinate map.
Expand All @@ -247,13 +250,14 @@ def fill_missing(
s = coord_map.shape
dim = s[0]
if dim == 2:
query_coords = np.mgrid[:s[-2], :s[-1]] # yx
query_coords = np.mgrid[: s[-2], : s[-1]] # yx
elif dim == 3:
query_coords = np.mgrid[:s[-3], :s[-2], :s[-1]] # zyx
query_coords = np.mgrid[: s[-3], : s[-2], : s[-1]] # zyx

query_points = tuple([q.ravel() for q in query_coords[::-1]]) # xy[z]

rets = []

def _process_map(curr_map):
ret = curr_map.copy()
valid = np.all(np.isfinite(curr_map), axis=0)
Expand All @@ -267,9 +271,7 @@ def _process_map(curr_map):
if interpolate_first:
try:
intp = _interpolate_points(
points,
query_points, #
*[c[valid] for c in curr_map]
points, query_points, *[c[valid] for c in curr_map]
)
for i, update in enumerate(intp):
ret[i, ...] = update.reshape(s[1:])
Expand Down Expand Up @@ -346,39 +348,50 @@ def inner_box(
"""Returns a box within which all nodes are mapped to by coord map.
Args:
coord_map: [2, z, y, x] coordinate map in relative format
coord_map: [2 or 3, z, y, x] coordinate map in relative format
box: bounding box from which the coordinate map was extracted
stride: distance between nearest neighbors of the coordinate map
Returns:
bounding box, all (u, v) points contained within which have
an entry in the (x, y) -> (u, v) map
bounding box, all (u, v[, w]) points contained within which have
an entry in the (x, y[, z]) -> (u, v[, w]) map
"""
assert coord_map.shape[0] == 2
assert coord_map.shape[0] in (2, 3)

# Part of the map might be invalid, in which case we extrapolate
# in order to get a fully valid array.
int_map = to_absolute(fill_missing(coord_map, extrapolate=True), stride, box)
x0 = np.max(np.min(int_map[0, ...], axis=-1))
y0 = np.max(np.min(int_map[1, ...], axis=-2))
x1 = np.min(np.max(int_map[0, ...], axis=-1))
y0 = np.max(np.min(int_map[1, ...], axis=-2))
y1 = np.min(np.max(int_map[1, ...], axis=-2))

x0 = int(-(-x0 // stride))
y0 = int(-(-y0 // stride))
x1 = x1 // stride
y1 = y1 // stride

if coord_map.shape[0] == 2:
return bounding_box.BoundingBox(
start=(x0, y0, box.start[2]),
size=(x1 - x0 + 1, y1 - y0 + 1, box.size[2]),
)

z0 = np.max(np.min(int_map[2, ...], axis=-3))
z1 = np.min(np.max(int_map[2, ...], axis=-3))
z0 = int(-(-z0 // stride))
z1 = z1 // stride

return bounding_box.BoundingBox(
start=(x0, y0, box.start[2]), size=(x1 - x0 + 1, y1 - y0 + 1, box.size[2])
start=(x0, y0, z0), size=(x1 - x0 + 1, y1 - y0 + 1, z1 - z0 + 1)
)


def invert_map(
coord_map: np.ndarray,
src_box: bounding_box.BoundingBox,
dst_box: bounding_box.BoundingBox,
stride: StrideZYX
stride: StrideZYX,
) -> np.ndarray:
"""Inverts a coordinate map.
Expand Down Expand Up @@ -769,9 +782,7 @@ def mask_irregular(


def make_affine_map(
matrix: np.ndarray,
box: bounding_box.BoundingBox,
stride: StrideZYX
matrix: np.ndarray, box: bounding_box.BoundingBox, stride: StrideZYX
) -> np.ndarray:
"""Builds a coordinate map for an affine transform.
Expand Down
13 changes: 13 additions & 0 deletions tests/map_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,19 @@ def test_inner_box(self):
bounding_box.BoundingBox(start=(100, 200, 10), size=(50, 50, 1)),
)

def test_inner_box3d(self):
box = bounding_box.BoundingBox(start=(100, 200, 200), size=(50, 50, 50))
coord_map = np.zeros([3, 50, 50, 50])
coord_map[2, ...] = -30
coord_map[2, 0, :, :] = -40
coord_map[2, -1, :, :] = -25
inner_box = map_utils.inner_box(coord_map, box, stride=10)

self.assertEqual(
inner_box,
bounding_box.BoundingBox(start=(100, 200, 196), size=(50, 50, 51)),
)

def test_invert_map(self):
box = bounding_box.BoundingBox(start=(100, 200, 10), size=(50, 50, 1))
_, hx = np.mgrid[:50, :50]
Expand Down

0 comments on commit 9a2546d

Please sign in to comment.