diff --git a/keras_core/backend/jax/image.py b/keras_core/backend/jax/image.py index 7852d5e6a..421988022 100644 --- a/keras_core/backend/jax/image.py +++ b/keras_core/backend/jax/image.py @@ -162,3 +162,31 @@ def affine_transform( if need_squeeze: affined = jnp.squeeze(affined, axis=0) return affined + + +MAP_COORDINATES_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} + + +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + if fill_mode not in MAP_COORDINATES_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(MAP_COORDINATES_FILL_MODES)}. Received: " + f"fill_mode={fill_mode}" + ) + if order not in range(2): + raise ValueError( + "Invalid value for argument `order`. Expected one of " + f"{[0, 1]}. Received: order={order}" + ) + return jax.scipy.ndimage.map_coordinates( + input, coordinates, order, fill_mode, fill_value + ) diff --git a/keras_core/backend/numpy/image.py b/keras_core/backend/numpy/image.py index cd4ac272b..408d126be 100644 --- a/keras_core/backend/numpy/image.py +++ b/keras_core/backend/numpy/image.py @@ -173,3 +173,55 @@ def affine_transform( if input_dtype == "float16": affined = affined.astype(input_dtype) return affined + + +MAP_COORDINATES_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} + + +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + if fill_mode not in MAP_COORDINATES_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(MAP_COORDINATES_FILL_MODES.keys())}. Received: " + f"fill_mode={fill_mode}" + ) + if order not in range(2): + raise ValueError( + "Invalid value for argument `order`. Expected one of " + f"{[0, 1]}. Received: order={order}" + ) + # SciPy's implementation of map_coordinates handles boundaries incorrectly, + # unless mode='reflect'. For order=1, this only affects interpolation + # outside the bounds of the original array. + # https://github.com/scipy/scipy/issues/2640 + padding = [ + ( + max(-np.floor(c.min()).astype(int) + 1, 0), + max(np.ceil(c.max()).astype(int) + 1 - size, 0), + ) + for c, size in zip(coordinates, input.shape) + ] + shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)] + pad_mode = { + "nearest": "edge", + "mirror": "reflect", + "reflect": "symmetric", + }.get(fill_mode, fill_mode) + if fill_mode == "constant": + padded = np.pad( + input, padding, mode=pad_mode, constant_values=fill_value + ) + else: + padded = np.pad(input, padding, mode=pad_mode) + result = scipy.ndimage.map_coordinates( + padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value + ) + return result diff --git a/keras_core/backend/tensorflow/image.py b/keras_core/backend/tensorflow/image.py index d864b5db0..b24528049 100644 --- a/keras_core/backend/tensorflow/image.py +++ b/keras_core/backend/tensorflow/image.py @@ -1,5 +1,11 @@ +import functools +import itertools +import operator + import tensorflow as tf +from keras_core.backend.tensorflow.core import convert_to_tensor + RESIZE_INTERPOLATIONS = ( "bilinear", "nearest", @@ -119,3 +125,114 @@ def affine_transform( if need_squeeze: affined = tf.squeeze(affined, axis=0) return affined + + +def _mirror_index_fixer(index, size): + s = size - 1 # Half-wavelength of triangular wave + # Scaled, integer-valued version of the triangular wave |x - round(x)| + return tf.abs((index + s) % (2 * s) - s) + + +def _reflect_index_fixer(index, size): + return tf.math.floordiv( + _mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2 + ) + + +_INDEX_FIXERS = { + "constant": lambda index, size: index, + "nearest": lambda index, size: tf.clip_by_value(index, 0, size - 1), + "wrap": lambda index, size: index % size, + "mirror": _mirror_index_fixer, + "reflect": _reflect_index_fixer, +} + + +def _nearest_indices_and_weights(coordinate): + coordinate = ( + coordinate if coordinate.dtype.is_integer else tf.round(coordinate) + ) + index = tf.cast(coordinate, tf.int32) + weight = tf.constant(1, coordinate.dtype) + return [(index, weight)] + + +def _linear_indices_and_weights(coordinate): + lower = tf.floor(coordinate) + upper_weight = coordinate - lower + lower_weight = 1 - upper_weight + index = tf.cast(lower, tf.int32) + return [(index, lower_weight), (index + 1, upper_weight)] + + +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + input_arr = convert_to_tensor(input) + coordinate_arrs = convert_to_tensor(coordinates) + # unstack into a list of tensors for following operations + coordinate_arrs = tf.unstack(coordinate_arrs, axis=0) + fill_value = convert_to_tensor(tf.cast(fill_value, input_arr.dtype)) + + if len(coordinates) != len(input_arr.shape): + raise ValueError( + "coordinates must be a sequence of length input.shape, but " + f"{len(coordinates)} != {len(input_arr.shape)}" + ) + + index_fixer = _INDEX_FIXERS.get(fill_mode) + if index_fixer is None: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(_INDEX_FIXERS.keys())}. Received: " + f"fill_mode={fill_mode}" + ) + + def is_valid(index, size): + if fill_mode == "constant": + return (0 <= index) & (index < size) + else: + return True + + if order == 0: + interp_fun = _nearest_indices_and_weights + elif order == 1: + interp_fun = _linear_indices_and_weights + else: + raise NotImplementedError("map_coordinates currently requires order<=1") + + valid_1d_interpolations = [] + for coordinate, size in zip(coordinate_arrs, input_arr.shape): + interp_nodes = interp_fun(coordinate) + valid_interp = [] + for index, weight in interp_nodes: + fixed_index = index_fixer(index, size) + valid = is_valid(index, size) + valid_interp.append((fixed_index, valid, weight)) + valid_1d_interpolations.append(valid_interp) + + outputs = [] + for items in itertools.product(*valid_1d_interpolations): + indices, validities, weights = zip(*items) + indices = tf.transpose(tf.stack(indices)) + + def fast_path(): + return tf.transpose(tf.gather_nd(input_arr, indices)) + + def slow_path(): + all_valid = functools.reduce(operator.and_, validities) + return tf.where( + all_valid, + tf.transpose(tf.gather_nd(input_arr, indices)), + fill_value, + ) + + contribution = tf.cond(tf.reduce_all(validities), fast_path, slow_path) + outputs.append( + functools.reduce(operator.mul, weights) + * tf.cast(contribution, weights[0].dtype) + ) + result = functools.reduce(operator.add, outputs) + if input_arr.dtype.is_integer: + result = result if result.dtype.is_integer else tf.round(result) + return tf.cast(result, input_arr.dtype) diff --git a/keras_core/backend/torch/image.py b/keras_core/backend/torch/image.py index 9c85861d6..7ef82eef2 100644 --- a/keras_core/backend/torch/image.py +++ b/keras_core/backend/torch/image.py @@ -1,3 +1,7 @@ +import functools +import itertools +import operator + import torch import torch.nn.functional as tnn @@ -263,3 +267,109 @@ def affine_transform( if need_squeeze: affined = affined.squeeze(dim=0) return affined + + +def _mirror_index_fixer(index, size): + s = size - 1 # Half-wavelength of triangular wave + # Scaled, integer-valued version of the triangular wave |x - round(x)| + return torch.abs((index + s) % (2 * s) - s) + + +def _reflect_index_fixer(index, size): + return torch.floor_divide( + _mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2 + ) + + +_INDEX_FIXERS = { + "constant": lambda index, size: index, + "nearest": lambda index, size: torch.clip(index, 0, size - 1), + "wrap": lambda index, size: index % size, + "mirror": _mirror_index_fixer, + "reflect": _reflect_index_fixer, +} + + +def _is_integer(a): + if not torch.is_floating_point(a) and not torch.is_complex(a): + return True + return False + + +def _nearest_indices_and_weights(coordinate): + coordinate = ( + coordinate if _is_integer(coordinate) else torch.round(coordinate) + ) + index = coordinate.to(torch.int32) + weight = torch.tensor(1).to(torch.int32) + return [(index, weight)] + + +def _linear_indices_and_weights(coordinate): + lower = torch.floor(coordinate) + upper_weight = coordinate - lower + lower_weight = 1 - upper_weight + index = lower.to(torch.int32) + return [(index, lower_weight), (index + 1, upper_weight)] + + +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + input_arr = convert_to_tensor(input) + coordinate_arrs = [convert_to_tensor(c) for c in coordinates] + fill_value = convert_to_tensor(fill_value, input_arr.dtype) + + if len(coordinates) != len(input_arr.shape): + raise ValueError( + "coordinates must be a sequence of length input.shape, but " + f"{len(coordinates)} != {len(input_arr.shape)}" + ) + + index_fixer = _INDEX_FIXERS.get(fill_mode) + if index_fixer is None: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(_INDEX_FIXERS.keys())}. Received: " + f"fill_mode={fill_mode}" + ) + + def is_valid(index, size): + if fill_mode == "constant": + return (0 <= index) & (index < size) + else: + return True + + if order == 0: + interp_fun = _nearest_indices_and_weights + elif order == 1: + interp_fun = _linear_indices_and_weights + else: + raise NotImplementedError("map_coordinates currently requires order<=1") + + valid_1d_interpolations = [] + for coordinate, size in zip(coordinate_arrs, input_arr.shape): + interp_nodes = interp_fun(coordinate) + valid_interp = [] + for index, weight in interp_nodes: + fixed_index = index_fixer(index, size) + valid = is_valid(index, size) + valid_interp.append((fixed_index, valid, weight)) + valid_1d_interpolations.append(valid_interp) + + outputs = [] + for items in itertools.product(*valid_1d_interpolations): + indices, validities, weights = zip(*items) + if all(valid is True for valid in validities): + # fast path + contribution = input_arr[indices] + else: + all_valid = functools.reduce(operator.and_, validities) + contribution = torch.where( + all_valid, input_arr[indices], fill_value + ) + outputs.append(functools.reduce(operator.mul, weights) * contribution) + result = functools.reduce(operator.add, outputs) + if _is_integer(input_arr): + result = result if _is_integer(result) else torch.round(result) + return result.to(input_arr.dtype) diff --git a/keras_core/ops/image.py b/keras_core/ops/image.py index 7351a7425..bc8e48d2d 100644 --- a/keras_core/ops/image.py +++ b/keras_core/ops/image.py @@ -418,3 +418,90 @@ def _extract_patches( if _unbatched: patches = backend.numpy.squeeze(patches, axis=0) return patches + + +class MapCoordinates(Operation): + def __init__(self, order, fill_mode="constant", fill_value=0): + super().__init__() + self.order = order + self.fill_mode = fill_mode + self.fill_value = fill_value + + def call(self, image, coordinates): + return backend.image.map_coordinates( + image, + coordinates, + order=self.order, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + ) + + def compute_output_spec(self, image, coordinates): + if coordinates.shape[0] != len(image.shape): + raise ValueError( + "First dim of `coordinates` must be the same as the rank of " + "`image`. " + f"Received image with shape: {image.shape} and coordinate " + f"leading dim of {coordinates.shape[0]}" + ) + if len(coordinates.shape) < 2: + raise ValueError( + "Invalid coordinates rank: expected at least rank 2." + f" Received input with shape: {coordinates.shape}" + ) + return KerasTensor(coordinates.shape[1:], dtype=image.dtype) + + +@keras_core_export("keras_core.ops.image.map_coordinates") +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0 +): + """Map the input array to new coordinates by interpolation.. + + Note that interpolation near boundaries differs from the scipy function, + because we fixed an outstanding bug + [scipy/issues/2640](https://github.com/scipy/scipy/issues/2640). + + Args: + input: The input array. + coordinates: The coordinates at which input is evaluated. + order: The order of the spline interpolation. The order must be `0` or + `1`. `0` indicates the nearest neighbor and `1` indicates the linear + interpolation. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"mirror"` and `"reflect"`. Defaults to + `"constant"`. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond + the edge with the same constant value k specified by + `fill_value`. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"mirror"`: `(c d c b | a b c d | c b a b)` + The input is extended by mirroring about the edge. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. + fill_value: Value used for points outside the boundaries of the input if + `fill_mode="constant"`. Defaults to `0`. + + Returns: + Output image or batch of images. + + """ + if any_symbolic_tensors((input, coordinates)): + return MapCoordinates( + order, + fill_mode, + fill_value, + ).symbolic_call(input, coordinates) + return backend.image.map_coordinates( + input, + coordinates, + order, + fill_mode, + fill_value, + ) diff --git a/keras_core/ops/image_test.py b/keras_core/ops/image_test.py index 1e4d05880..940624f53 100644 --- a/keras_core/ops/image_test.py +++ b/keras_core/ops/image_test.py @@ -1,3 +1,5 @@ +import math + import numpy as np import pytest import scipy.ndimage @@ -34,6 +36,12 @@ def test_extract_patches(self): out = kimage.extract_patches(x, 5) self.assertEqual(out.shape, (None, 4, 4, 75)) + def test_map_coordinates(self): + input = KerasTensor([20, 20, None]) + coordinates = KerasTensor([3, 15, 15, None]) + out = kimage.map_coordinates(input, coordinates, 0) + self.assertEqual(out.shape, coordinates.shape[1:]) + class ImageOpsStaticShapeTest(testing.TestCase): def test_resize(self): @@ -55,6 +63,12 @@ def test_extract_patches(self): out = kimage.extract_patches(x, 5) self.assertEqual(out.shape, (4, 4, 75)) + def test_map_coordinates(self): + input = KerasTensor([20, 20, 3]) + coordinates = KerasTensor([3, 15, 15, 3]) + out = kimage.map_coordinates(input, coordinates, 0) + self.assertEqual(out.shape, coordinates.shape[1:]) + AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order "nearest": 0, @@ -107,6 +121,38 @@ def _compute_affine_transform_coordinates(image, transform): return coordinates +def _map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + # SciPy's implementation of map_coordinates handles boundaries incorrectly, + # unless mode='reflect'. For order=1, this only affects interpolation + # outside the bounds of the original array. + # https://github.com/scipy/scipy/issues/2640 + padding = [ + ( + max(-np.floor(c.min()).astype(int) + 1, 0), + max(np.ceil(c.max()).astype(int) + 1 - size, 0), + ) + for c, size in zip(coordinates, input.shape) + ] + shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)] + pad_mode = { + "nearest": "edge", + "mirror": "reflect", + "reflect": "symmetric", + }.get(fill_mode, fill_mode) + if fill_mode == "constant": + padded = np.pad( + input, padding, mode=pad_mode, constant_values=fill_value + ) + else: + padded = np.pad(input, padding, mode=pad_mode) + result = scipy.ndimage.map_coordinates( + padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value + ) + return result + + class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( [ @@ -350,3 +396,30 @@ def test_extract_patches( self.assertAllClose( patches_ref.numpy(), backend.convert_to_numpy(patches_out), atol=0.3 ) + + @parameterized.product( + # (input_shape, coordinates_shape) + shape=[((5,), (7,)), ((3, 4, 5), (2, 3, 4))], + # TODO: scipy.ndimage.map_coordinates does not support float16 + # TODO: torch cpu does not support round & floor for float16 + dtype=["uint8", "int32", "float32"], + order=[0, 1], + fill_mode=["constant", "nearest", "wrap", "mirror", "reflect"], + ) + def test_map_coordinates(self, shape, dtype, order, fill_mode): + input_shape, coordinates_shape = shape + input = np.arange(math.prod(input_shape), dtype=dtype).reshape( + input_shape + ) + coordinates_dtype = "float32" if "int" in dtype else dtype + coordinates = [ + (size - 1) + * np.random.uniform(size=coordinates_shape).astype( + coordinates_dtype + ) + for size in input_shape + ] + output = kimage.map_coordinates(input, coordinates, order, fill_mode) + expected = _map_coordinates(input, coordinates, order, fill_mode) + + self.assertAllClose(output, expected)