Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ops.map_coordinates #906

Merged
merged 6 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions keras_core/backend/jax/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,31 @@ def affine_transform(
if need_squeeze:
affined = jnp.squeeze(affined, axis=0)
return affined


MAP_COORDINATES_FILL_MODES = {
"constant",
"nearest",
"wrap",
"mirror",
"reflect",
}


def map_coordinates(
input, coordinates, order, fill_mode="constant", fill_value=0.0
):
if fill_mode not in MAP_COORDINATES_FILL_MODES:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected one of "
f"{set(MAP_COORDINATES_FILL_MODES)}. Received: "
f"fill_mode={fill_mode}"
)
if order not in range(2):
raise ValueError(
"Invalid value for argument `order`. Expected one of "
f"{[0, 1]}. Received: order={order}"
)
return jax.scipy.ndimage.map_coordinates(
input, coordinates, order, fill_mode, fill_value
)
52 changes: 52 additions & 0 deletions keras_core/backend/numpy/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,55 @@ def affine_transform(
if input_dtype == "float16":
affined = affined.astype(input_dtype)
return affined


MAP_COORDINATES_FILL_MODES = {
"constant",
"nearest",
"wrap",
"mirror",
"reflect",
}


def map_coordinates(
input, coordinates, order, fill_mode="constant", fill_value=0.0
):
if fill_mode not in MAP_COORDINATES_FILL_MODES:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected one of "
f"{set(MAP_COORDINATES_FILL_MODES.keys())}. Received: "
f"fill_mode={fill_mode}"
)
if order not in range(2):
raise ValueError(
"Invalid value for argument `order`. Expected one of "
f"{[0, 1]}. Received: order={order}"
)
# SciPy's implementation of map_coordinates handles boundaries incorrectly,
# unless mode='reflect'. For order=1, this only affects interpolation
# outside the bounds of the original array.
# https://github.com/scipy/scipy/issues/2640
padding = [
(
max(-np.floor(c.min()).astype(int) + 1, 0),
max(np.ceil(c.max()).astype(int) + 1 - size, 0),
)
for c, size in zip(coordinates, input.shape)
]
shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)]
pad_mode = {
"nearest": "edge",
"mirror": "reflect",
"reflect": "symmetric",
}.get(fill_mode, fill_mode)
if fill_mode == "constant":
padded = np.pad(
input, padding, mode=pad_mode, constant_values=fill_value
)
else:
padded = np.pad(input, padding, mode=pad_mode)
result = scipy.ndimage.map_coordinates(
padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value
)
return result
117 changes: 117 additions & 0 deletions keras_core/backend/tensorflow/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import functools
import itertools
import operator

import tensorflow as tf

from keras_core.backend.tensorflow.core import convert_to_tensor

RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
Expand Down Expand Up @@ -119,3 +125,114 @@ def affine_transform(
if need_squeeze:
affined = tf.squeeze(affined, axis=0)
return affined


def _mirror_index_fixer(index, size):
s = size - 1 # Half-wavelength of triangular wave
# Scaled, integer-valued version of the triangular wave |x - round(x)|
return tf.abs((index + s) % (2 * s) - s)


def _reflect_index_fixer(index, size):
return tf.math.floordiv(
_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2
)


_INDEX_FIXERS = {
"constant": lambda index, size: index,
"nearest": lambda index, size: tf.clip_by_value(index, 0, size - 1),
"wrap": lambda index, size: index % size,
"mirror": _mirror_index_fixer,
"reflect": _reflect_index_fixer,
}


def _nearest_indices_and_weights(coordinate):
coordinate = (
coordinate if coordinate.dtype.is_integer else tf.round(coordinate)
)
index = tf.cast(coordinate, tf.int32)
weight = tf.constant(1, coordinate.dtype)
return [(index, weight)]


def _linear_indices_and_weights(coordinate):
lower = tf.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
index = tf.cast(lower, tf.int32)
return [(index, lower_weight), (index + 1, upper_weight)]


def map_coordinates(
input, coordinates, order, fill_mode="constant", fill_value=0.0
):
input_arr = convert_to_tensor(input)
coordinate_arrs = convert_to_tensor(coordinates)
# unstack into a list of tensors for following operations
coordinate_arrs = tf.unstack(coordinate_arrs, axis=0)
fill_value = convert_to_tensor(tf.cast(fill_value, input_arr.dtype))

if len(coordinates) != len(input_arr.shape):
raise ValueError(
"coordinates must be a sequence of length input.shape, but "
f"{len(coordinates)} != {len(input_arr.shape)}"
)

index_fixer = _INDEX_FIXERS.get(fill_mode)
if index_fixer is None:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected one of "
f"{set(_INDEX_FIXERS.keys())}. Received: "
f"fill_mode={fill_mode}"
)

def is_valid(index, size):
if fill_mode == "constant":
return (0 <= index) & (index < size)
else:
return True

if order == 0:
interp_fun = _nearest_indices_and_weights
elif order == 1:
interp_fun = _linear_indices_and_weights
else:
raise NotImplementedError("map_coordinates currently requires order<=1")

valid_1d_interpolations = []
for coordinate, size in zip(coordinate_arrs, input_arr.shape):
interp_nodes = interp_fun(coordinate)
valid_interp = []
for index, weight in interp_nodes:
fixed_index = index_fixer(index, size)
valid = is_valid(index, size)
valid_interp.append((fixed_index, valid, weight))
valid_1d_interpolations.append(valid_interp)

outputs = []
for items in itertools.product(*valid_1d_interpolations):
indices, validities, weights = zip(*items)
indices = tf.transpose(tf.stack(indices))

def fast_path():
return tf.transpose(tf.gather_nd(input_arr, indices))

def slow_path():
all_valid = functools.reduce(operator.and_, validities)
return tf.where(
all_valid,
tf.transpose(tf.gather_nd(input_arr, indices)),
fill_value,
)

contribution = tf.cond(tf.reduce_all(validities), fast_path, slow_path)
outputs.append(
functools.reduce(operator.mul, weights)
* tf.cast(contribution, weights[0].dtype)
)
result = functools.reduce(operator.add, outputs)
if input_arr.dtype.is_integer:
result = result if result.dtype.is_integer else tf.round(result)
return tf.cast(result, input_arr.dtype)
110 changes: 110 additions & 0 deletions keras_core/backend/torch/image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import functools
import itertools
import operator

import torch
import torch.nn.functional as tnn

Expand Down Expand Up @@ -263,3 +267,109 @@ def affine_transform(
if need_squeeze:
affined = affined.squeeze(dim=0)
return affined


def _mirror_index_fixer(index, size):
s = size - 1 # Half-wavelength of triangular wave
# Scaled, integer-valued version of the triangular wave |x - round(x)|
return torch.abs((index + s) % (2 * s) - s)


def _reflect_index_fixer(index, size):
return torch.floor_divide(
_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2
)


_INDEX_FIXERS = {
"constant": lambda index, size: index,
"nearest": lambda index, size: torch.clip(index, 0, size - 1),
"wrap": lambda index, size: index % size,
"mirror": _mirror_index_fixer,
"reflect": _reflect_index_fixer,
}


def _is_integer(a):
if not torch.is_floating_point(a) and not torch.is_complex(a):
return True
return False


def _nearest_indices_and_weights(coordinate):
coordinate = (
coordinate if _is_integer(coordinate) else torch.round(coordinate)
)
index = coordinate.to(torch.int32)
weight = torch.tensor(1).to(torch.int32)
return [(index, weight)]


def _linear_indices_and_weights(coordinate):
lower = torch.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
index = lower.to(torch.int32)
return [(index, lower_weight), (index + 1, upper_weight)]


def map_coordinates(
input, coordinates, order, fill_mode="constant", fill_value=0.0
):
input_arr = convert_to_tensor(input)
coordinate_arrs = [convert_to_tensor(c) for c in coordinates]
fill_value = convert_to_tensor(fill_value, input_arr.dtype)

if len(coordinates) != len(input_arr.shape):
raise ValueError(
"coordinates must be a sequence of length input.shape, but "
f"{len(coordinates)} != {len(input_arr.shape)}"
)

index_fixer = _INDEX_FIXERS.get(fill_mode)
if index_fixer is None:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected one of "
f"{set(_INDEX_FIXERS.keys())}. Received: "
f"fill_mode={fill_mode}"
)

def is_valid(index, size):
if fill_mode == "constant":
return (0 <= index) & (index < size)
else:
return True

if order == 0:
interp_fun = _nearest_indices_and_weights
elif order == 1:
interp_fun = _linear_indices_and_weights
else:
raise NotImplementedError("map_coordinates currently requires order<=1")

valid_1d_interpolations = []
for coordinate, size in zip(coordinate_arrs, input_arr.shape):
interp_nodes = interp_fun(coordinate)
valid_interp = []
for index, weight in interp_nodes:
fixed_index = index_fixer(index, size)
valid = is_valid(index, size)
valid_interp.append((fixed_index, valid, weight))
valid_1d_interpolations.append(valid_interp)

outputs = []
for items in itertools.product(*valid_1d_interpolations):
indices, validities, weights = zip(*items)
if all(valid is True for valid in validities):
# fast path
contribution = input_arr[indices]
else:
all_valid = functools.reduce(operator.and_, validities)
contribution = torch.where(
all_valid, input_arr[indices], fill_value
)
outputs.append(functools.reduce(operator.mul, weights) * contribution)
result = functools.reduce(operator.add, outputs)
if _is_integer(input_arr):
result = result if _is_integer(result) else torch.round(result)
return result.to(input_arr.dtype)
Loading