Skip to content

Commit

Permalink
rearrange jax.fake_numpy to match other contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Jun 26, 2022
1 parent be1429c commit ff1cd0c
Showing 1 changed file with 82 additions and 43 deletions.
125 changes: 82 additions & 43 deletions arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,40 +50,81 @@ def _get_fake_numpy_linalg_namespace(self):
def __getattr__(self, name):
return partial(rec_multimap_array_container, getattr(jnp, name))

# NOTE: the order of these follows the order in numpy docs
# NOTE: when adding a function here, also add it to `array_context.rst` docs!

# {{{ array creation routines

def ones_like(self, ary):
return self.full_like(ary, 1)

def full_like(self, ary, fill_value):
def _full_like(subary):
return jnp.full_like(ary, fill_value)

return self._new_like(ary, _full_like)

# }}}

# {{{ array manipulation routies

def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: jnp.reshape(ary, newshape, order=order),
a)

def transpose(self, a, axes=None):
return rec_multimap_array_container(jnp.transpose, a, axes)
def ravel(self, a, order="C"):
"""
.. warning::
def concatenate(self, arrays, axis=0):
return rec_multimap_array_container(jnp.concatenate, arrays, axis)
Since :func:`jax.numpy.reshape` does not support orders `A`` and
``K``, in such cases we fallback to using ``order = C``.
"""
if order in "AK":
from warnings import warn
warn(f"ravel with order='{order}' not supported by JAX,"
" using order=C.")
order = "C"

def where(self, criterion, then, else_):
return rec_multimap_array_container(jnp.where, criterion, then, else_)
return rec_map_array_container(
lambda subary: jnp.ravel(subary, order=order), a)

def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(sum,
partial(jnp.sum,
axis=axis,
dtype=dtype),
a)
def transpose(self, a, axes=None):
return rec_multimap_array_container(jnp.transpose, a, axes)

def min(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
def broadcast_to(self, array, shape):
return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)

def max(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
def concatenate(self, arrays, axis=0):
return rec_multimap_array_container(jnp.concatenate, arrays, axis)

def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: jnp.stack(arrays=args, axis=axis),
*arrays)

# }}}

# {{{ linear algebra

def vdot(self, x, y, dtype=None):
from arraycontext import rec_multimap_reduce_array_container

def _rec_vdot(ary1, ary2):
if dtype not in [None, numpy.find_common_type((ary1.dtype,
ary2.dtype),
())]:
raise NotImplementedError(f"{type(self)} cannot take dtype in"
" vdot.")

return jnp.vdot(ary1, ary2)

return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)

# }}}

# {{{ logic functions

def array_equal(self, a, b):
actx = self._array_context

Expand All @@ -109,35 +150,33 @@ def rec_equal(x, y):

return rec_equal(a, b)

def ravel(self, a, order="C"):
"""
.. warning::
# }}}

Since :func:`jax.numpy.reshape` does not support orders `A`` and
``K``, in such cases we fallback to using ``order = C``.
"""
if order in "AK":
from warnings import warn
warn(f"ravel with order='{order}' not supported by JAX,"
" using order=C.")
order = "C"
# {{{ mathematical functions

def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(
sum,
partial(jnp.sum, axis=axis, dtype=dtype),
a)

return rec_map_array_container(lambda subary: jnp.ravel(subary, order=order),
a)
def amin(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)

def vdot(self, x, y, dtype=None):
from arraycontext import rec_multimap_reduce_array_container
min = amin

def _rec_vdot(ary1, ary2):
if dtype not in [None, numpy.find_common_type((ary1.dtype,
ary2.dtype),
())]:
raise NotImplementedError(f"{type(self)} cannot take dtype in"
" vdot.")
def amax(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)

return jnp.vdot(ary1, ary2)
max = amax

return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
# }}}

def broadcast_to(self, array, shape):
return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
# {{{ sorting, searching and counting

def where(self, criterion, then, else_):
return rec_multimap_array_container(jnp.where, criterion, then, else_)

# }}}

0 comments on commit ff1cd0c

Please sign in to comment.