From ff1cd0cfb4aefde053bd75a21a6b812790130439 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 26 Jun 2022 16:07:40 +0300 Subject: [PATCH] rearrange jax.fake_numpy to match other contexts --- arraycontext/impl/jax/fake_numpy.py | 125 ++++++++++++++++++---------- 1 file changed, 82 insertions(+), 43 deletions(-) diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index d0466eee..8a72d9aa 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -50,40 +50,81 @@ def _get_fake_numpy_linalg_namespace(self): def __getattr__(self, name): return partial(rec_multimap_array_container, getattr(jnp, name)) + # NOTE: the order of these follows the order in numpy docs + # NOTE: when adding a function here, also add it to `array_context.rst` docs! + + # {{{ array creation routines + + def ones_like(self, ary): + return self.full_like(ary, 1) + + def full_like(self, ary, fill_value): + def _full_like(subary): + return jnp.full_like(ary, fill_value) + + return self._new_like(ary, _full_like) + + # }}} + + # {{{ array manipulation routies + def reshape(self, a, newshape, order="C"): return rec_map_array_container( lambda ary: jnp.reshape(ary, newshape, order=order), a) - def transpose(self, a, axes=None): - return rec_multimap_array_container(jnp.transpose, a, axes) + def ravel(self, a, order="C"): + """ + .. warning:: - def concatenate(self, arrays, axis=0): - return rec_multimap_array_container(jnp.concatenate, arrays, axis) + Since :func:`jax.numpy.reshape` does not support orders `A`` and + ``K``, in such cases we fallback to using ``order = C``. + """ + if order in "AK": + from warnings import warn + warn(f"ravel with order='{order}' not supported by JAX," + " using order=C.") + order = "C" - def where(self, criterion, then, else_): - return rec_multimap_array_container(jnp.where, criterion, then, else_) + return rec_map_array_container( + lambda subary: jnp.ravel(subary, order=order), a) - def sum(self, a, axis=None, dtype=None): - return rec_map_reduce_array_container(sum, - partial(jnp.sum, - axis=axis, - dtype=dtype), - a) + def transpose(self, a, axes=None): + return rec_multimap_array_container(jnp.transpose, a, axes) - def min(self, a, axis=None): - return rec_map_reduce_array_container( - partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a) + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array) - def max(self, a, axis=None): - return rec_map_reduce_array_container( - partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a) + def concatenate(self, arrays, axis=0): + return rec_multimap_array_container(jnp.concatenate, arrays, axis) def stack(self, arrays, axis=0): return rec_multimap_array_container( lambda *args: jnp.stack(arrays=args, axis=axis), *arrays) + # }}} + + # {{{ linear algebra + + def vdot(self, x, y, dtype=None): + from arraycontext import rec_multimap_reduce_array_container + + def _rec_vdot(ary1, ary2): + if dtype not in [None, numpy.find_common_type((ary1.dtype, + ary2.dtype), + ())]: + raise NotImplementedError(f"{type(self)} cannot take dtype in" + " vdot.") + + return jnp.vdot(ary1, ary2) + + return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y) + + # }}} + + # {{{ logic functions + def array_equal(self, a, b): actx = self._array_context @@ -109,35 +150,33 @@ def rec_equal(x, y): return rec_equal(a, b) - def ravel(self, a, order="C"): - """ - .. warning:: + # }}} - Since :func:`jax.numpy.reshape` does not support orders `A`` and - ``K``, in such cases we fallback to using ``order = C``. - """ - if order in "AK": - from warnings import warn - warn(f"ravel with order='{order}' not supported by JAX," - " using order=C.") - order = "C" + # {{{ mathematical functions + + def sum(self, a, axis=None, dtype=None): + return rec_map_reduce_array_container( + sum, + partial(jnp.sum, axis=axis, dtype=dtype), + a) - return rec_map_array_container(lambda subary: jnp.ravel(subary, order=order), - a) + def amin(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a) - def vdot(self, x, y, dtype=None): - from arraycontext import rec_multimap_reduce_array_container + min = amin - def _rec_vdot(ary1, ary2): - if dtype not in [None, numpy.find_common_type((ary1.dtype, - ary2.dtype), - ())]: - raise NotImplementedError(f"{type(self)} cannot take dtype in" - " vdot.") + def amax(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a) - return jnp.vdot(ary1, ary2) + max = amax - return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y) + # }}} - def broadcast_to(self, array, shape): - return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array) + # {{{ sorting, searching and counting + + def where(self, criterion, then, else_): + return rec_multimap_array_container(jnp.where, criterion, then, else_) + + # }}}