From 23b21cb00e7c23620663639a533bc46b478d8df9 Mon Sep 17 00:00:00 2001 From: chiuhongyu <20734616+james77777778@users.noreply.github.com> Date: Sun, 17 Sep 2023 15:37:05 +0800 Subject: [PATCH 1/5] Add `ops.image.map_coordinates` --- check_map.py | 16 ++++ keras_core/backend/jax/image.py | 27 ++++++ keras_core/backend/numpy/image.py | 49 ++++++++++ keras_core/backend/tensorflow/image.py | 120 +++++++++++++++++++++++++ keras_core/backend/torch/image.py | 120 +++++++++++++++++++++++++ keras_core/ops/image.py | 98 ++++++++++++++++++++ keras_core/ops/image_test.py | 8 ++ 7 files changed, 438 insertions(+) create mode 100644 check_map.py diff --git a/check_map.py b/check_map.py new file mode 100644 index 000000000..82380a61a --- /dev/null +++ b/check_map.py @@ -0,0 +1,16 @@ +import numpy as np + +from keras_core.backend.jax.image import map_coordinates as jax_map_coordinates +from keras_core.backend.tensorflow.image import ( + map_coordinates as tf_map_coordinates, +) +from keras_core.backend.torch.image import ( + map_coordinates as torch_map_coordinates, +) + +data = np.arange(12).reshape((4, 3)) +coordinates = np.array([[0.5, 2], [0.5, 1]]) + +# print(jax_map_coordinates(data, coordinates, 1)) +print(tf_map_coordinates(data, coordinates, 1)) +# print(torch_map_coordinates(data, coordinates, 1)) diff --git a/keras_core/backend/jax/image.py b/keras_core/backend/jax/image.py index 7852d5e6a..7cdc1f04d 100644 --- a/keras_core/backend/jax/image.py +++ b/keras_core/backend/jax/image.py @@ -162,3 +162,30 @@ def affine_transform( if need_squeeze: affined = jnp.squeeze(affined, axis=0) return affined + + +MAP_COORDINATES_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} + + +def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): + if mode not in MAP_COORDINATES_MODES: + raise ValueError( + "Invalid value for argument `mode`. Expected one of " + f"{set(MAP_COORDINATES_MODES.keys())}. Received: " + f"mode={mode}" + ) + if order not in range(2): + raise ValueError( + "Invalid value for argument `order`. Expected one of " + f"{[0, 1]}. Received: " + f"mode={mode}" + ) + return jax.scipy.ndimage.map_coordinates( + input, coordinates, order, mode, cval + ) diff --git a/keras_core/backend/numpy/image.py b/keras_core/backend/numpy/image.py index cd4ac272b..935410240 100644 --- a/keras_core/backend/numpy/image.py +++ b/keras_core/backend/numpy/image.py @@ -173,3 +173,52 @@ def affine_transform( if input_dtype == "float16": affined = affined.astype(input_dtype) return affined + + +MAP_COORDINATES_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} + + +def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): + if mode not in MAP_COORDINATES_MODES: + raise ValueError( + "Invalid value for argument `mode`. Expected one of " + f"{set(MAP_COORDINATES_MODES.keys())}. Received: " + f"mode={mode}" + ) + if order not in range(2): + raise ValueError( + "Invalid value for argument `order`. Expected one of " + f"{[0, 1]}. Received: " + f"mode={mode}" + ) + # SciPy's implementation of map_coordinates handles boundaries incorrectly, + # unless mode='reflect'. For order=1, this only affects interpolation + # outside the bounds of the original array. + # https://github.com/scipy/scipy/issues/2640 + padding = [ + ( + max(-np.floor(c.min()).astype(int) + 1, 0), + max(np.ceil(c.max()).astype(int) + 1 - size, 0), + ) + for c, size in zip(coordinates, input.shape) + ] + shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)] + pad_mode = { + "nearest": "edge", + "mirror": "reflect", + "reflect": "symmetric", + }.get(mode, mode) + if mode == "constant": + padded = np.pad(input, padding, mode=pad_mode, constant_values=cval) + else: + padded = np.pad(input, padding, mode=pad_mode) + result = scipy.ndimage.map_coordinates( + padded, shifted_coords, order=order, mode=mode, cval=cval + ) + return result diff --git a/keras_core/backend/tensorflow/image.py b/keras_core/backend/tensorflow/image.py index d864b5db0..f6f19934f 100644 --- a/keras_core/backend/tensorflow/image.py +++ b/keras_core/backend/tensorflow/image.py @@ -1,5 +1,9 @@ +import itertools + import tensorflow as tf +from keras_core.backend.tensorflow.core import convert_to_tensor + RESIZE_INTERPOLATIONS = ( "bilinear", "nearest", @@ -119,3 +123,119 @@ def affine_transform( if need_squeeze: affined = tf.squeeze(affined, axis=0) return affined + + +def _unzip3(xyzs): + """Unzip sequence of length-3 tuples into three tuples.""" + # Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated, + # is too permissive about inputs, and does not guarantee a length-3 output. + xs = [] + ys = [] + zs = [] + for x, y, z in xyzs: + xs.append(x) + ys.append(y) + zs.append(z) + return tuple(xs), tuple(ys), tuple(zs) + + +def _mirror_index_fixer(index, size): + s = size - 1 # Half-wavelength of triangular wave + # Scaled, integer-valued version of the triangular wave |x - round(x)| + return tf.abs((index + s) % (2 * s) - s) + + +def _reflect_index_fixer(index, size): + return tf.math.floordiv( + _mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2 + ) + + +_INDEX_FIXERS = { + "constant": lambda index, size: index, + "nearest": lambda index, size: tf.clip_by_value(index, 0, size - 1), + "wrap": lambda index, size: index % size, + "mirror": _mirror_index_fixer, + "reflect": _reflect_index_fixer, +} + + +def _round_half_away_from_zero(a): + return a if a.dtype.is_integer else tf.round(a) + + +def _nearest_indices_and_weights(coordinate): + index = tf.cast(_round_half_away_from_zero(coordinate), tf.int32) + weight = tf.constant(1, coordinate.dtype) + return [(index, weight)] + + +def _linear_indices_and_weights(coordinate): + lower = tf.floor(coordinate) + upper_weight = coordinate - lower + lower_weight = 1 - upper_weight + index = tf.cast(lower, tf.int32) + return [(index, lower_weight), (index + 1, upper_weight)] + + +def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): + input_arr = convert_to_tensor(input) + coordinate_arrs = convert_to_tensor(coordinates) + cval = convert_to_tensor(tf.cast(cval, input_arr.dtype)) + + if coordinates.shape[0] != len(input_arr.shape): + raise ValueError( + "coordinates must be a sequence of length input.ndim, but " + f"{coordinates.shape[0]} != {len(input_arr.shape)}" + ) + + index_fixer = _INDEX_FIXERS.get(mode) + if index_fixer is None: + raise NotImplementedError( + f"map_coordinates does not yet support mode {mode}. " + f"Currently supported modes are {set(_INDEX_FIXERS)}." + ) + + def is_valid(index, size): + if mode == "constant": + return (0 <= index) & (index < size) + else: + return True + + if order == 0: + interp_fun = _nearest_indices_and_weights + elif order == 1: + interp_fun = _linear_indices_and_weights + else: + raise NotImplementedError("map_coordinates currently requires order<=1") + + valid_1d_interpolations = [] + for coordinate, size in zip(coordinate_arrs, input_arr.shape): + interp_nodes = interp_fun(coordinate) + valid_interp = [] + for index, weight in interp_nodes: + fixed_index = index_fixer(index, size) + valid = is_valid(index, size) + valid_interp.append((fixed_index, valid, weight)) + valid_1d_interpolations.append(valid_interp) + + outputs = [] + for items in itertools.product(*valid_1d_interpolations): + indices, validities, weights = _unzip3(items) + indices = tf.transpose(tf.stack(indices)) + if tf.reduce_all(validities): + # fast path + contribution = tf.gather_nd(input_arr, indices) + else: + all_valid = tf.reduce_all(validities) + contribution = tf.where( + all_valid, tf.gather_nd(input_arr, indices), cval + ) + outputs.append( + tf.reduce_prod(weights, axis=0) + * tf.cast(contribution, weights[0].dtype) + ) + result = tf.reduce_sum(outputs, axis=0) + if input_arr.dtype.is_integer: + result = _round_half_away_from_zero(result) + return tf.cast(result, input_arr.dtype) diff --git a/keras_core/backend/torch/image.py b/keras_core/backend/torch/image.py index 9c85861d6..97a8ff82b 100644 --- a/keras_core/backend/torch/image.py +++ b/keras_core/backend/torch/image.py @@ -1,3 +1,7 @@ +import functools +import itertools +import operator + import torch import torch.nn.functional as tnn @@ -263,3 +267,119 @@ def affine_transform( if need_squeeze: affined = affined.squeeze(dim=0) return affined + + +def _unzip3(xyzs): + """Unzip sequence of length-3 tuples into three tuples.""" + # Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated, + # is too permissive about inputs, and does not guarantee a length-3 output. + xs = [] + ys = [] + zs = [] + for x, y, z in xyzs: + xs.append(x) + ys.append(y) + zs.append(z) + return tuple(xs), tuple(ys), tuple(zs) + + +def _mirror_index_fixer(index, size): + s = size - 1 # Half-wavelength of triangular wave + # Scaled, integer-valued version of the triangular wave |x - round(x)| + return torch.abs((index + s) % (2 * s) - s) + + +def _reflect_index_fixer(index, size): + return torch.floor_divide( + _mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2 + ) + + +_INDEX_FIXERS = { + "constant": lambda index, size: index, + "nearest": lambda index, size: torch.clip(index, 0, size - 1), + "wrap": lambda index, size: index % size, + "mirror": _mirror_index_fixer, + "reflect": _reflect_index_fixer, +} + + +def _round_half_away_from_zero(a): + return ( + a + if (not torch.is_floating_point(a) and not torch.is_complex(a)) + else torch.round(a) + ) + + +def _nearest_indices_and_weights(coordinate): + index = _round_half_away_from_zero(coordinate).to(torch.int32) + weight = torch.tensor(1).to(torch.int32) + return [(index, weight)] + + +def _linear_indices_and_weights(coordinate): + lower = torch.floor(coordinate) + upper_weight = coordinate - lower + lower_weight = 1 - upper_weight + index = lower.to(torch.int32) + return [(index, lower_weight), (index + 1, upper_weight)] + + +def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): + input_arr = convert_to_tensor(input) + coordinate_arrs = convert_to_tensor(coordinates) + cval = convert_to_tensor(cval, input_arr.dtype) + + if len(coordinates) != input_arr.ndim: + raise ValueError( + "coordinates must be a sequence of length input.ndim, but " + "{} != {}".format(len(coordinates), input_arr.ndim) + ) + + index_fixer = _INDEX_FIXERS.get(mode) + if index_fixer is None: + raise NotImplementedError( + "map_coordinates does not yet support mode {}. " + "Currently supported modes are {}.".format(mode, set(_INDEX_FIXERS)) + ) + + def is_valid(index, size): + if mode == "constant": + return (0 <= index) & (index < size) + else: + return True + + if order == 0: + interp_fun = _nearest_indices_and_weights + elif order == 1: + interp_fun = _linear_indices_and_weights + else: + raise NotImplementedError("map_coordinates currently requires order<=1") + + valid_1d_interpolations = [] + for coordinate, size in zip(coordinate_arrs, input_arr.shape): + interp_nodes = interp_fun(coordinate) + valid_interp = [] + for index, weight in interp_nodes: + fixed_index = index_fixer(index, size) + valid = is_valid(index, size) + valid_interp.append((fixed_index, valid, weight)) + valid_1d_interpolations.append(valid_interp) + + outputs = [] + for items in itertools.product(*valid_1d_interpolations): + indices, validities, weights = _unzip3(items) + if all(valid is True for valid in validities): + # fast path + contribution = input_arr[indices] + else: + all_valid = functools.reduce(operator.and_, validities) + contribution = torch.where(all_valid, input_arr[indices], cval) + outputs.append(functools.reduce(operator.mul, weights) * contribution) + result = functools.reduce(operator.add, outputs) + if not torch.is_floating_point(input_arr) and not torch.is_complex( + input_arr + ): + result = _round_half_away_from_zero(result) + return result.to(input_arr.dtype) diff --git a/keras_core/ops/image.py b/keras_core/ops/image.py index 7351a7425..354627622 100644 --- a/keras_core/ops/image.py +++ b/keras_core/ops/image.py @@ -418,3 +418,101 @@ def _extract_patches( if _unbatched: patches = backend.numpy.squeeze(patches, axis=0) return patches + + +class MapCoordinates(Operation): + def __init__(self, order, mode, cval=0.0): + super().__init__() + self.order = order + self.mode = mode + self.cval = cval + + def call(self, image, coordinates): + return backend.image.map_coordinates( + image, + coordinates, + order=self.order, + mode=self.mode, + cval=self.cval, + ) + + def compute_output_spec(self, image, coordinates): + if coordinates.shape[0] != len(image.shape): + raise ValueError( + "First dim of `coordinates` must be the same as the rank of " + "`image`. " + f"Received image with shape: {image.shape} and coordinate " + f"leading dim of {coordinates.shape[0]}" + ) + if len(coordinates.shape) < 2: + raise ValueError( + "Invalid coordinates rank: expected at least rank 2." + f" Received input with shape: {coordinates.shape}" + ) + return KerasTensor(coordinates.shape[1:], dtype=image.dtype) + + +@keras_core_export("keras_core.ops.image.map_coordinates") +def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): + """Applies the image(s) onto a set of coordinates. + + Args: + image: Input image or batch of images. Must be 3D or 4D. + transform: Projective transform matrix/matrices. A vector of length 8 or + tensor of size N x 8. If one row of transform is + `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps the output point + `(x, y)` to a transformed input point + `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, + where `k = c0 x + c1 y + 1`. The transform is inverted compared to + the transform mapping input points to output points. Note that + gradients are not backpropagated into transformation parameters. + Note that `c0` and `c1` are only effective when using TensorFlow + backend and will be considered as `0` when using other backends. + interpolation: Interpolation method. Available methods are `"nearest"`, + and `"bilinear"`. Defaults to `"bilinear"`. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond + the edge with the same constant value k specified by + `fill_value`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + Note that when using torch backend, `"reflect"` is redirected to + `"mirror"` `(c d c b | a b c d | c b a b)` because torch does not + support `"reflect"`. + Note that torch backend does not support `"wrap"`. + fill_value: Value used for points outside the boundaries of the input if + `fill_mode="constant"`. Defaults to `0`. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, weight)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + + Returns: + Applied affine transform image or batch of images. + + """ + if any_symbolic_tensors((input, coordinates)): + return MapCoordinates( + order, + mode, + cval, + ).symbolic_call(input, coordinates) + return backend.image.map_coordinates( + input, + coordinates, + order, + mode, + cval, + ) diff --git a/keras_core/ops/image_test.py b/keras_core/ops/image_test.py index 1e4d05880..d1c232842 100644 --- a/keras_core/ops/image_test.py +++ b/keras_core/ops/image_test.py @@ -350,3 +350,11 @@ def test_extract_patches( self.assertAllClose( patches_ref.numpy(), backend.convert_to_numpy(patches_out), atol=0.3 ) + + def test_map_coordinates(self): + data = np.arange(12).reshape((4, 3)) + coordinates = np.array([[0.5, 2], [0.5, 1]]) + expected = np.array([2.0, 7.0]) + map_coordinates_out = kimage.map_coordinates(data, coordinates, 1) + + self.assertAllClose(map_coordinates_out, expected) From e773a7105d8156d7a043360742bb2a0e15f30bd8 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Mon, 18 Sep 2023 04:39:48 +0000 Subject: [PATCH 2/5] Add `map_coordinates` --- keras_core/backend/jax/image.py | 19 ++++--- keras_core/backend/numpy/image.py | 27 +++++---- keras_core/backend/tensorflow/image.py | 77 ++++++++++++------------- keras_core/backend/torch/image.py | 66 +++++++++------------ keras_core/ops/image.py | 75 +++++++++++------------- keras_core/ops/image_test.py | 79 ++++++++++++++++++++++++-- 6 files changed, 195 insertions(+), 148 deletions(-) diff --git a/keras_core/backend/jax/image.py b/keras_core/backend/jax/image.py index 7cdc1f04d..421988022 100644 --- a/keras_core/backend/jax/image.py +++ b/keras_core/backend/jax/image.py @@ -164,7 +164,7 @@ def affine_transform( return affined -MAP_COORDINATES_MODES = { +MAP_COORDINATES_FILL_MODES = { "constant", "nearest", "wrap", @@ -173,19 +173,20 @@ def affine_transform( } -def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): - if mode not in MAP_COORDINATES_MODES: +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + if fill_mode not in MAP_COORDINATES_FILL_MODES: raise ValueError( - "Invalid value for argument `mode`. Expected one of " - f"{set(MAP_COORDINATES_MODES.keys())}. Received: " - f"mode={mode}" + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(MAP_COORDINATES_FILL_MODES)}. Received: " + f"fill_mode={fill_mode}" ) if order not in range(2): raise ValueError( "Invalid value for argument `order`. Expected one of " - f"{[0, 1]}. Received: " - f"mode={mode}" + f"{[0, 1]}. Received: order={order}" ) return jax.scipy.ndimage.map_coordinates( - input, coordinates, order, mode, cval + input, coordinates, order, fill_mode, fill_value ) diff --git a/keras_core/backend/numpy/image.py b/keras_core/backend/numpy/image.py index 935410240..408d126be 100644 --- a/keras_core/backend/numpy/image.py +++ b/keras_core/backend/numpy/image.py @@ -175,7 +175,7 @@ def affine_transform( return affined -MAP_COORDINATES_MODES = { +MAP_COORDINATES_FILL_MODES = { "constant", "nearest", "wrap", @@ -184,18 +184,19 @@ def affine_transform( } -def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): - if mode not in MAP_COORDINATES_MODES: +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + if fill_mode not in MAP_COORDINATES_FILL_MODES: raise ValueError( - "Invalid value for argument `mode`. Expected one of " - f"{set(MAP_COORDINATES_MODES.keys())}. Received: " - f"mode={mode}" + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(MAP_COORDINATES_FILL_MODES.keys())}. Received: " + f"fill_mode={fill_mode}" ) if order not in range(2): raise ValueError( "Invalid value for argument `order`. Expected one of " - f"{[0, 1]}. Received: " - f"mode={mode}" + f"{[0, 1]}. Received: order={order}" ) # SciPy's implementation of map_coordinates handles boundaries incorrectly, # unless mode='reflect'. For order=1, this only affects interpolation @@ -213,12 +214,14 @@ def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): "nearest": "edge", "mirror": "reflect", "reflect": "symmetric", - }.get(mode, mode) - if mode == "constant": - padded = np.pad(input, padding, mode=pad_mode, constant_values=cval) + }.get(fill_mode, fill_mode) + if fill_mode == "constant": + padded = np.pad( + input, padding, mode=pad_mode, constant_values=fill_value + ) else: padded = np.pad(input, padding, mode=pad_mode) result = scipy.ndimage.map_coordinates( - padded, shifted_coords, order=order, mode=mode, cval=cval + padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value ) return result diff --git a/keras_core/backend/tensorflow/image.py b/keras_core/backend/tensorflow/image.py index f6f19934f..b24528049 100644 --- a/keras_core/backend/tensorflow/image.py +++ b/keras_core/backend/tensorflow/image.py @@ -1,4 +1,6 @@ +import functools import itertools +import operator import tensorflow as tf @@ -125,20 +127,6 @@ def affine_transform( return affined -def _unzip3(xyzs): - """Unzip sequence of length-3 tuples into three tuples.""" - # Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated, - # is too permissive about inputs, and does not guarantee a length-3 output. - xs = [] - ys = [] - zs = [] - for x, y, z in xyzs: - xs.append(x) - ys.append(y) - zs.append(z) - return tuple(xs), tuple(ys), tuple(zs) - - def _mirror_index_fixer(index, size): s = size - 1 # Half-wavelength of triangular wave # Scaled, integer-valued version of the triangular wave |x - round(x)| @@ -160,12 +148,11 @@ def _reflect_index_fixer(index, size): } -def _round_half_away_from_zero(a): - return a if a.dtype.is_integer else tf.round(a) - - def _nearest_indices_and_weights(coordinate): - index = tf.cast(_round_half_away_from_zero(coordinate), tf.int32) + coordinate = ( + coordinate if coordinate.dtype.is_integer else tf.round(coordinate) + ) + index = tf.cast(coordinate, tf.int32) weight = tf.constant(1, coordinate.dtype) return [(index, weight)] @@ -178,26 +165,31 @@ def _linear_indices_and_weights(coordinate): return [(index, lower_weight), (index + 1, upper_weight)] -def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): input_arr = convert_to_tensor(input) coordinate_arrs = convert_to_tensor(coordinates) - cval = convert_to_tensor(tf.cast(cval, input_arr.dtype)) + # unstack into a list of tensors for following operations + coordinate_arrs = tf.unstack(coordinate_arrs, axis=0) + fill_value = convert_to_tensor(tf.cast(fill_value, input_arr.dtype)) - if coordinates.shape[0] != len(input_arr.shape): + if len(coordinates) != len(input_arr.shape): raise ValueError( - "coordinates must be a sequence of length input.ndim, but " - f"{coordinates.shape[0]} != {len(input_arr.shape)}" + "coordinates must be a sequence of length input.shape, but " + f"{len(coordinates)} != {len(input_arr.shape)}" ) - index_fixer = _INDEX_FIXERS.get(mode) + index_fixer = _INDEX_FIXERS.get(fill_mode) if index_fixer is None: - raise NotImplementedError( - f"map_coordinates does not yet support mode {mode}. " - f"Currently supported modes are {set(_INDEX_FIXERS)}." + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(_INDEX_FIXERS.keys())}. Received: " + f"fill_mode={fill_mode}" ) def is_valid(index, size): - if mode == "constant": + if fill_mode == "constant": return (0 <= index) & (index < size) else: return True @@ -221,21 +213,26 @@ def is_valid(index, size): outputs = [] for items in itertools.product(*valid_1d_interpolations): - indices, validities, weights = _unzip3(items) + indices, validities, weights = zip(*items) indices = tf.transpose(tf.stack(indices)) - if tf.reduce_all(validities): - # fast path - contribution = tf.gather_nd(input_arr, indices) - else: - all_valid = tf.reduce_all(validities) - contribution = tf.where( - all_valid, tf.gather_nd(input_arr, indices), cval + + def fast_path(): + return tf.transpose(tf.gather_nd(input_arr, indices)) + + def slow_path(): + all_valid = functools.reduce(operator.and_, validities) + return tf.where( + all_valid, + tf.transpose(tf.gather_nd(input_arr, indices)), + fill_value, ) + + contribution = tf.cond(tf.reduce_all(validities), fast_path, slow_path) outputs.append( - tf.reduce_prod(weights, axis=0) + functools.reduce(operator.mul, weights) * tf.cast(contribution, weights[0].dtype) ) - result = tf.reduce_sum(outputs, axis=0) + result = functools.reduce(operator.add, outputs) if input_arr.dtype.is_integer: - result = _round_half_away_from_zero(result) + result = result if result.dtype.is_integer else tf.round(result) return tf.cast(result, input_arr.dtype) diff --git a/keras_core/backend/torch/image.py b/keras_core/backend/torch/image.py index 97a8ff82b..7ef82eef2 100644 --- a/keras_core/backend/torch/image.py +++ b/keras_core/backend/torch/image.py @@ -269,20 +269,6 @@ def affine_transform( return affined -def _unzip3(xyzs): - """Unzip sequence of length-3 tuples into three tuples.""" - # Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated, - # is too permissive about inputs, and does not guarantee a length-3 output. - xs = [] - ys = [] - zs = [] - for x, y, z in xyzs: - xs.append(x) - ys.append(y) - zs.append(z) - return tuple(xs), tuple(ys), tuple(zs) - - def _mirror_index_fixer(index, size): s = size - 1 # Half-wavelength of triangular wave # Scaled, integer-valued version of the triangular wave |x - round(x)| @@ -304,16 +290,17 @@ def _reflect_index_fixer(index, size): } -def _round_half_away_from_zero(a): - return ( - a - if (not torch.is_floating_point(a) and not torch.is_complex(a)) - else torch.round(a) - ) +def _is_integer(a): + if not torch.is_floating_point(a) and not torch.is_complex(a): + return True + return False def _nearest_indices_and_weights(coordinate): - index = _round_half_away_from_zero(coordinate).to(torch.int32) + coordinate = ( + coordinate if _is_integer(coordinate) else torch.round(coordinate) + ) + index = coordinate.to(torch.int32) weight = torch.tensor(1).to(torch.int32) return [(index, weight)] @@ -326,26 +313,29 @@ def _linear_indices_and_weights(coordinate): return [(index, lower_weight), (index + 1, upper_weight)] -def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): input_arr = convert_to_tensor(input) - coordinate_arrs = convert_to_tensor(coordinates) - cval = convert_to_tensor(cval, input_arr.dtype) + coordinate_arrs = [convert_to_tensor(c) for c in coordinates] + fill_value = convert_to_tensor(fill_value, input_arr.dtype) - if len(coordinates) != input_arr.ndim: + if len(coordinates) != len(input_arr.shape): raise ValueError( - "coordinates must be a sequence of length input.ndim, but " - "{} != {}".format(len(coordinates), input_arr.ndim) + "coordinates must be a sequence of length input.shape, but " + f"{len(coordinates)} != {len(input_arr.shape)}" ) - index_fixer = _INDEX_FIXERS.get(mode) + index_fixer = _INDEX_FIXERS.get(fill_mode) if index_fixer is None: - raise NotImplementedError( - "map_coordinates does not yet support mode {}. " - "Currently supported modes are {}.".format(mode, set(_INDEX_FIXERS)) + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(_INDEX_FIXERS.keys())}. Received: " + f"fill_mode={fill_mode}" ) def is_valid(index, size): - if mode == "constant": + if fill_mode == "constant": return (0 <= index) & (index < size) else: return True @@ -369,17 +359,17 @@ def is_valid(index, size): outputs = [] for items in itertools.product(*valid_1d_interpolations): - indices, validities, weights = _unzip3(items) + indices, validities, weights = zip(*items) if all(valid is True for valid in validities): # fast path contribution = input_arr[indices] else: all_valid = functools.reduce(operator.and_, validities) - contribution = torch.where(all_valid, input_arr[indices], cval) + contribution = torch.where( + all_valid, input_arr[indices], fill_value + ) outputs.append(functools.reduce(operator.mul, weights) * contribution) result = functools.reduce(operator.add, outputs) - if not torch.is_floating_point(input_arr) and not torch.is_complex( - input_arr - ): - result = _round_half_away_from_zero(result) + if _is_integer(input_arr): + result = result if _is_integer(result) else torch.round(result) return result.to(input_arr.dtype) diff --git a/keras_core/ops/image.py b/keras_core/ops/image.py index 354627622..c924ed940 100644 --- a/keras_core/ops/image.py +++ b/keras_core/ops/image.py @@ -421,19 +421,19 @@ def _extract_patches( class MapCoordinates(Operation): - def __init__(self, order, mode, cval=0.0): + def __init__(self, order, fill_mode="constant", fill_value=0): super().__init__() self.order = order - self.mode = mode - self.cval = cval + self.fill_mode = fill_mode + self.fill_value = fill_value def call(self, image, coordinates): return backend.image.map_coordinates( image, coordinates, order=self.order, - mode=self.mode, - cval=self.cval, + fill_mode=self.fill_mode, + fill_value=self.fill_value, ) def compute_output_spec(self, image, coordinates): @@ -453,66 +453,55 @@ def compute_output_spec(self, image, coordinates): @keras_core_export("keras_core.ops.image.map_coordinates") -def map_coordinates(input, coordinates, order, mode="constant", cval=0.0): - """Applies the image(s) onto a set of coordinates. +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0 +): + """Map the input array to new coordinates by interpolation.. + + Note that interpolation near boundaries differs from the scipy function, + because we fixed an outstanding bug + https://github.com/scipy/scipy/issues/2640. Args: - image: Input image or batch of images. Must be 3D or 4D. - transform: Projective transform matrix/matrices. A vector of length 8 or - tensor of size N x 8. If one row of transform is - `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps the output point - `(x, y)` to a transformed input point - `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, - where `k = c0 x + c1 y + 1`. The transform is inverted compared to - the transform mapping input points to output points. Note that - gradients are not backpropagated into transformation parameters. - Note that `c0` and `c1` are only effective when using TensorFlow - backend and will be considered as `0` when using other backends. - interpolation: Interpolation method. Available methods are `"nearest"`, - and `"bilinear"`. Defaults to `"bilinear"`. + input: The input array. + coordinates: The coordinates at which input is evaluated. + order: The order of the spline interpolation. The order must be `0` or + `1`. `0` indicates the nearest neighbor and `1` indicates the linear + interpolation. fill_mode: Points outside the boundaries of the input are filled according to the given mode. Available methods are `"constant"`, - `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. - - `"reflect"`: `(d c b a | a b c d | d c b a)` - The input is extended by reflecting about the edge of the last - pixel. + `"nearest"`, `"wrap"` and `"mirror"` and `"reflect"`. Defaults to + `"constant"`. - `"constant"`: `(k k k k | a b c d | k k k k)` The input is extended by filling all values beyond the edge with the same constant value k specified by `fill_value`. - - `"wrap"`: `(a b c d | a b c d | a b c d)` - The input is extended by wrapping around to the opposite edge. - `"nearest"`: `(a a a a | a b c d | d d d d)` The input is extended by the nearest pixel. - Note that when using torch backend, `"reflect"` is redirected to - `"mirror"` `(c d c b | a b c d | c b a b)` because torch does not - support `"reflect"`. - Note that torch backend does not support `"wrap"`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"mirror"`: `(c d c b | a b c d | c b a b)` + The input is extended by mirroring about the edge. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. fill_value: Value used for points outside the boundaries of the input if `fill_mode="constant"`. Defaults to `0`. - data_format: string, either `"channels_last"` or `"channels_first"`. - The ordering of the dimensions in the inputs. `"channels_last"` - corresponds to inputs with shape `(batch, height, width, channels)` - while `"channels_first"` corresponds to inputs with shape - `(batch, channels, height, weight)`. It defaults to the - `image_data_format` value found in your Keras config file at - `~/.keras/keras.json`. If you never set it, then it will be - `"channels_last"`. Returns: - Applied affine transform image or batch of images. + Output image or batch of images. """ if any_symbolic_tensors((input, coordinates)): return MapCoordinates( order, - mode, - cval, + fill_mode, + fill_value, ).symbolic_call(input, coordinates) return backend.image.map_coordinates( input, coordinates, order, - mode, - cval, + fill_mode, + fill_value, ) diff --git a/keras_core/ops/image_test.py b/keras_core/ops/image_test.py index d1c232842..cd756fefe 100644 --- a/keras_core/ops/image_test.py +++ b/keras_core/ops/image_test.py @@ -1,3 +1,5 @@ +import math + import numpy as np import pytest import scipy.ndimage @@ -34,6 +36,12 @@ def test_extract_patches(self): out = kimage.extract_patches(x, 5) self.assertEqual(out.shape, (None, 4, 4, 75)) + def test_map_coordinates(self): + input = KerasTensor([20, 20, None]) + coordinates = KerasTensor([3, 15, 15, None]) + out = kimage.map_coordinates(input, coordinates, 0) + self.assertEqual(out.shape, coordinates.shape[1:]) + class ImageOpsStaticShapeTest(testing.TestCase): def test_resize(self): @@ -55,6 +63,12 @@ def test_extract_patches(self): out = kimage.extract_patches(x, 5) self.assertEqual(out.shape, (4, 4, 75)) + def test_map_coordinates(self): + input = KerasTensor([20, 20, 3]) + coordinates = KerasTensor([3, 15, 15, 3]) + out = kimage.map_coordinates(input, coordinates, 0) + self.assertEqual(out.shape, coordinates.shape[1:]) + AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order "nearest": 0, @@ -107,6 +121,38 @@ def _compute_affine_transform_coordinates(image, transform): return coordinates +def _map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + # SciPy's implementation of map_coordinates handles boundaries incorrectly, + # unless mode='reflect'. For order=1, this only affects interpolation + # outside the bounds of the original array. + # https://github.com/scipy/scipy/issues/2640 + padding = [ + ( + max(-np.floor(c.min()).astype(int) + 1, 0), + max(np.ceil(c.max()).astype(int) + 1 - size, 0), + ) + for c, size in zip(coordinates, input.shape) + ] + shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)] + pad_mode = { + "nearest": "edge", + "mirror": "reflect", + "reflect": "symmetric", + }.get(fill_mode, fill_mode) + if fill_mode == "constant": + padded = np.pad( + input, padding, mode=pad_mode, constant_values=fill_value + ) + else: + padded = np.pad(input, padding, mode=pad_mode) + result = scipy.ndimage.map_coordinates( + padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value + ) + return result + + class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( [ @@ -351,10 +397,31 @@ def test_extract_patches( patches_ref.numpy(), backend.convert_to_numpy(patches_out), atol=0.3 ) - def test_map_coordinates(self): - data = np.arange(12).reshape((4, 3)) - coordinates = np.array([[0.5, 2], [0.5, 1]]) - expected = np.array([2.0, 7.0]) - map_coordinates_out = kimage.map_coordinates(data, coordinates, 1) + @parameterized.product( + # (input_shape, coordinates_shape) + shape=[((5,), (7,)), ((3, 4, 5), (2, 3, 4))], + dtype=["uint8", "float16", "float32"], + order=[0, 1], + fill_mode=["constant", "nearest", "wrap", "mirror", "reflect"], + ) + def test_map_coordinates(self, shape, dtype, order, fill_mode): + input_shape, coordinates_shape = shape + input = np.arange(math.prod(input_shape), dtype=dtype).reshape( + input_shape + ) + coordinates_dtype = "float32" if "int" in dtype else dtype + coordinates = [ + (size - 1) + * np.random.uniform(size=coordinates_shape).astype( + coordinates_dtype + ) + for size in input_shape + ] + output = kimage.map_coordinates(input, coordinates, order, fill_mode) + if dtype == "float16": + self.skipTest( + "scipy.ndimage.map_coordinates does not support float16" + ) + expected = _map_coordinates(input, coordinates, order, fill_mode) - self.assertAllClose(map_coordinates_out, expected) + self.assertAllClose(output, expected) From 4aa0972e6d9582f02d28b6d67518fff844492ac8 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Mon, 18 Sep 2023 04:52:20 +0000 Subject: [PATCH 3/5] Remove testing file --- check_map.py | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 check_map.py diff --git a/check_map.py b/check_map.py deleted file mode 100644 index 82380a61a..000000000 --- a/check_map.py +++ /dev/null @@ -1,16 +0,0 @@ -import numpy as np - -from keras_core.backend.jax.image import map_coordinates as jax_map_coordinates -from keras_core.backend.tensorflow.image import ( - map_coordinates as tf_map_coordinates, -) -from keras_core.backend.torch.image import ( - map_coordinates as torch_map_coordinates, -) - -data = np.arange(12).reshape((4, 3)) -coordinates = np.array([[0.5, 2], [0.5, 1]]) - -# print(jax_map_coordinates(data, coordinates, 1)) -print(tf_map_coordinates(data, coordinates, 1)) -# print(torch_map_coordinates(data, coordinates, 1)) From 53507e17fbf5e5050a68562673807b7122cf7630 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Mon, 18 Sep 2023 05:08:03 +0000 Subject: [PATCH 4/5] Fix unit test --- keras_core/ops/image_test.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/keras_core/ops/image_test.py b/keras_core/ops/image_test.py index cd756fefe..940624f53 100644 --- a/keras_core/ops/image_test.py +++ b/keras_core/ops/image_test.py @@ -400,7 +400,9 @@ def test_extract_patches( @parameterized.product( # (input_shape, coordinates_shape) shape=[((5,), (7,)), ((3, 4, 5), (2, 3, 4))], - dtype=["uint8", "float16", "float32"], + # TODO: scipy.ndimage.map_coordinates does not support float16 + # TODO: torch cpu does not support round & floor for float16 + dtype=["uint8", "int32", "float32"], order=[0, 1], fill_mode=["constant", "nearest", "wrap", "mirror", "reflect"], ) @@ -418,10 +420,6 @@ def test_map_coordinates(self, shape, dtype, order, fill_mode): for size in input_shape ] output = kimage.map_coordinates(input, coordinates, order, fill_mode) - if dtype == "float16": - self.skipTest( - "scipy.ndimage.map_coordinates does not support float16" - ) expected = _map_coordinates(input, coordinates, order, fill_mode) self.assertAllClose(output, expected) From 5a37ea6aa40f53c9eb77d2f47ad3b276fbb1d92b Mon Sep 17 00:00:00 2001 From: chiuhongyu <20734616+james77777778@users.noreply.github.com> Date: Tue, 19 Sep 2023 08:06:27 +0800 Subject: [PATCH 5/5] Address comment --- keras_core/ops/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_core/ops/image.py b/keras_core/ops/image.py index c924ed940..bc8e48d2d 100644 --- a/keras_core/ops/image.py +++ b/keras_core/ops/image.py @@ -460,7 +460,7 @@ def map_coordinates( Note that interpolation near boundaries differs from the scipy function, because we fixed an outstanding bug - https://github.com/scipy/scipy/issues/2640. + [scipy/issues/2640](https://github.com/scipy/scipy/issues/2640). Args: input: The input array.