Skip to content

Commit

Permalink
Add ops.map_coordinates (#906)
Browse files Browse the repository at this point in the history
* Add `ops.image.map_coordinates`

* Add `map_coordinates`

* Remove testing file

* Fix unit test

* Address comment
  • Loading branch information
james77777778 authored Sep 19, 2023
1 parent e1e207d commit 956e89a
Show file tree
Hide file tree
Showing 6 changed files with 467 additions and 0 deletions.
28 changes: 28 additions & 0 deletions keras_core/backend/jax/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,31 @@ def affine_transform(
if need_squeeze:
affined = jnp.squeeze(affined, axis=0)
return affined


MAP_COORDINATES_FILL_MODES = {
"constant",
"nearest",
"wrap",
"mirror",
"reflect",
}


def map_coordinates(
input, coordinates, order, fill_mode="constant", fill_value=0.0
):
if fill_mode not in MAP_COORDINATES_FILL_MODES:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected one of "
f"{set(MAP_COORDINATES_FILL_MODES)}. Received: "
f"fill_mode={fill_mode}"
)
if order not in range(2):
raise ValueError(
"Invalid value for argument `order`. Expected one of "
f"{[0, 1]}. Received: order={order}"
)
return jax.scipy.ndimage.map_coordinates(
input, coordinates, order, fill_mode, fill_value
)
52 changes: 52 additions & 0 deletions keras_core/backend/numpy/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,55 @@ def affine_transform(
if input_dtype == "float16":
affined = affined.astype(input_dtype)
return affined


MAP_COORDINATES_FILL_MODES = {
"constant",
"nearest",
"wrap",
"mirror",
"reflect",
}


def map_coordinates(
input, coordinates, order, fill_mode="constant", fill_value=0.0
):
if fill_mode not in MAP_COORDINATES_FILL_MODES:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected one of "
f"{set(MAP_COORDINATES_FILL_MODES.keys())}. Received: "
f"fill_mode={fill_mode}"
)
if order not in range(2):
raise ValueError(
"Invalid value for argument `order`. Expected one of "
f"{[0, 1]}. Received: order={order}"
)
# SciPy's implementation of map_coordinates handles boundaries incorrectly,
# unless mode='reflect'. For order=1, this only affects interpolation
# outside the bounds of the original array.
# https://github.com/scipy/scipy/issues/2640
padding = [
(
max(-np.floor(c.min()).astype(int) + 1, 0),
max(np.ceil(c.max()).astype(int) + 1 - size, 0),
)
for c, size in zip(coordinates, input.shape)
]
shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)]
pad_mode = {
"nearest": "edge",
"mirror": "reflect",
"reflect": "symmetric",
}.get(fill_mode, fill_mode)
if fill_mode == "constant":
padded = np.pad(
input, padding, mode=pad_mode, constant_values=fill_value
)
else:
padded = np.pad(input, padding, mode=pad_mode)
result = scipy.ndimage.map_coordinates(
padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value
)
return result
117 changes: 117 additions & 0 deletions keras_core/backend/tensorflow/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import functools
import itertools
import operator

import tensorflow as tf

from keras_core.backend.tensorflow.core import convert_to_tensor

RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
Expand Down Expand Up @@ -119,3 +125,114 @@ def affine_transform(
if need_squeeze:
affined = tf.squeeze(affined, axis=0)
return affined


def _mirror_index_fixer(index, size):
s = size - 1 # Half-wavelength of triangular wave
# Scaled, integer-valued version of the triangular wave |x - round(x)|
return tf.abs((index + s) % (2 * s) - s)


def _reflect_index_fixer(index, size):
return tf.math.floordiv(
_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2
)


_INDEX_FIXERS = {
"constant": lambda index, size: index,
"nearest": lambda index, size: tf.clip_by_value(index, 0, size - 1),
"wrap": lambda index, size: index % size,
"mirror": _mirror_index_fixer,
"reflect": _reflect_index_fixer,
}


def _nearest_indices_and_weights(coordinate):
coordinate = (
coordinate if coordinate.dtype.is_integer else tf.round(coordinate)
)
index = tf.cast(coordinate, tf.int32)
weight = tf.constant(1, coordinate.dtype)
return [(index, weight)]


def _linear_indices_and_weights(coordinate):
lower = tf.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
index = tf.cast(lower, tf.int32)
return [(index, lower_weight), (index + 1, upper_weight)]


def map_coordinates(
input, coordinates, order, fill_mode="constant", fill_value=0.0
):
input_arr = convert_to_tensor(input)
coordinate_arrs = convert_to_tensor(coordinates)
# unstack into a list of tensors for following operations
coordinate_arrs = tf.unstack(coordinate_arrs, axis=0)
fill_value = convert_to_tensor(tf.cast(fill_value, input_arr.dtype))

if len(coordinates) != len(input_arr.shape):
raise ValueError(
"coordinates must be a sequence of length input.shape, but "
f"{len(coordinates)} != {len(input_arr.shape)}"
)

index_fixer = _INDEX_FIXERS.get(fill_mode)
if index_fixer is None:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected one of "
f"{set(_INDEX_FIXERS.keys())}. Received: "
f"fill_mode={fill_mode}"
)

def is_valid(index, size):
if fill_mode == "constant":
return (0 <= index) & (index < size)
else:
return True

if order == 0:
interp_fun = _nearest_indices_and_weights
elif order == 1:
interp_fun = _linear_indices_and_weights
else:
raise NotImplementedError("map_coordinates currently requires order<=1")

valid_1d_interpolations = []
for coordinate, size in zip(coordinate_arrs, input_arr.shape):
interp_nodes = interp_fun(coordinate)
valid_interp = []
for index, weight in interp_nodes:
fixed_index = index_fixer(index, size)
valid = is_valid(index, size)
valid_interp.append((fixed_index, valid, weight))
valid_1d_interpolations.append(valid_interp)

outputs = []
for items in itertools.product(*valid_1d_interpolations):
indices, validities, weights = zip(*items)
indices = tf.transpose(tf.stack(indices))

def fast_path():
return tf.transpose(tf.gather_nd(input_arr, indices))

def slow_path():
all_valid = functools.reduce(operator.and_, validities)
return tf.where(
all_valid,
tf.transpose(tf.gather_nd(input_arr, indices)),
fill_value,
)

contribution = tf.cond(tf.reduce_all(validities), fast_path, slow_path)
outputs.append(
functools.reduce(operator.mul, weights)
* tf.cast(contribution, weights[0].dtype)
)
result = functools.reduce(operator.add, outputs)
if input_arr.dtype.is_integer:
result = result if result.dtype.is_integer else tf.round(result)
return tf.cast(result, input_arr.dtype)
110 changes: 110 additions & 0 deletions keras_core/backend/torch/image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import functools
import itertools
import operator

import torch
import torch.nn.functional as tnn

Expand Down Expand Up @@ -263,3 +267,109 @@ def affine_transform(
if need_squeeze:
affined = affined.squeeze(dim=0)
return affined


def _mirror_index_fixer(index, size):
s = size - 1 # Half-wavelength of triangular wave
# Scaled, integer-valued version of the triangular wave |x - round(x)|
return torch.abs((index + s) % (2 * s) - s)


def _reflect_index_fixer(index, size):
return torch.floor_divide(
_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2
)


_INDEX_FIXERS = {
"constant": lambda index, size: index,
"nearest": lambda index, size: torch.clip(index, 0, size - 1),
"wrap": lambda index, size: index % size,
"mirror": _mirror_index_fixer,
"reflect": _reflect_index_fixer,
}


def _is_integer(a):
if not torch.is_floating_point(a) and not torch.is_complex(a):
return True
return False


def _nearest_indices_and_weights(coordinate):
coordinate = (
coordinate if _is_integer(coordinate) else torch.round(coordinate)
)
index = coordinate.to(torch.int32)
weight = torch.tensor(1).to(torch.int32)
return [(index, weight)]


def _linear_indices_and_weights(coordinate):
lower = torch.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
index = lower.to(torch.int32)
return [(index, lower_weight), (index + 1, upper_weight)]


def map_coordinates(
input, coordinates, order, fill_mode="constant", fill_value=0.0
):
input_arr = convert_to_tensor(input)
coordinate_arrs = [convert_to_tensor(c) for c in coordinates]
fill_value = convert_to_tensor(fill_value, input_arr.dtype)

if len(coordinates) != len(input_arr.shape):
raise ValueError(
"coordinates must be a sequence of length input.shape, but "
f"{len(coordinates)} != {len(input_arr.shape)}"
)

index_fixer = _INDEX_FIXERS.get(fill_mode)
if index_fixer is None:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected one of "
f"{set(_INDEX_FIXERS.keys())}. Received: "
f"fill_mode={fill_mode}"
)

def is_valid(index, size):
if fill_mode == "constant":
return (0 <= index) & (index < size)
else:
return True

if order == 0:
interp_fun = _nearest_indices_and_weights
elif order == 1:
interp_fun = _linear_indices_and_weights
else:
raise NotImplementedError("map_coordinates currently requires order<=1")

valid_1d_interpolations = []
for coordinate, size in zip(coordinate_arrs, input_arr.shape):
interp_nodes = interp_fun(coordinate)
valid_interp = []
for index, weight in interp_nodes:
fixed_index = index_fixer(index, size)
valid = is_valid(index, size)
valid_interp.append((fixed_index, valid, weight))
valid_1d_interpolations.append(valid_interp)

outputs = []
for items in itertools.product(*valid_1d_interpolations):
indices, validities, weights = zip(*items)
if all(valid is True for valid in validities):
# fast path
contribution = input_arr[indices]
else:
all_valid = functools.reduce(operator.and_, validities)
contribution = torch.where(
all_valid, input_arr[indices], fill_value
)
outputs.append(functools.reduce(operator.mul, weights) * contribution)
result = functools.reduce(operator.add, outputs)
if _is_integer(input_arr):
result = result if _is_integer(result) else torch.round(result)
return result.to(input_arr.dtype)
Loading

0 comments on commit 956e89a

Please sign in to comment.