Skip to content

Commit

Permalink
symbolics: rework Cast
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 17, 2025
1 parent 32f20ee commit 2bc6ab9
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 111 deletions.
8 changes: 4 additions & 4 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def _make_msg(self, f, hse, key):
return MPIMsg('msg%d' % key, f, halos)

def _make_sendrecv(self, f, hse, key, msg=None):
cast = cast_mapper[(f.c0.dtype, '*')]
cast = cast_mapper((f.c0.dtype, '*'))
comm = f.grid.distributor._obj_comm

bufg = FieldFromPointer(msg._C_field_bufg, msg)
Expand Down Expand Up @@ -671,7 +671,7 @@ def _call_compute(self, hs, compute, *args):
return compute.make_call(dynamic_args_mapper=hs.omapper.core)

def _make_wait(self, f, hse, key, msg=None):
cast = cast_mapper[(f.c0.dtype, '*')]
cast = cast_mapper((f.c0.dtype, '*'))

bufs = FieldFromPointer(msg._C_field_bufs, msg)

Expand Down Expand Up @@ -772,7 +772,7 @@ def _call_sendrecv(self, *args):
return

def _make_haloupdate(self, f, hse, key, *args, msg=None):
cast = cast_mapper[(f.c0.dtype, '*')]
cast = cast_mapper((f.c0.dtype, '*'))
comm = f.grid.distributor._obj_comm

fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}
Expand Down Expand Up @@ -819,7 +819,7 @@ def _call_haloupdate(self, name, f, hse, msg):
return HaloUpdateCall(name, args)

def _make_halowait(self, f, hse, key, *args, msg=None):
cast = cast_mapper[(f.c0.dtype, '*')]
cast = cast_mapper((f.c0.dtype, '*'))

fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

Expand Down
9 changes: 7 additions & 2 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ def _complex_includes(iet: Callable, lang: type[LangBB] = None,
"""

# Check if there are complex numbers that always take dtype precedence
types = {f.dtype for f in FindSymbols().visit(iet)
if not issubclass(f.dtype, ctypes._Pointer)}
types = set()
for f in FindSymbols().visit(iet):
try:
if not issubclass(f.dtype, ctypes._Pointer):
types.add(f.dtype)
except TypeError:
pass

if not any(np.issubdtype(d, np.complexfloating) for d in types):
return iet, {}
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/iet/languages/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,11 @@ def place_devptr(self, iet, **kwargs):

dpf = List(body=[
self.lang.mapper['map-serial-present'](hp, tdp),
Block(body=DummyExpr(tdp, cast_mapper[tdp.dtype](hp)))
Block(body=DummyExpr(tdp, cast_mapper(tdp.dtype)(hp)))
])

ffp = FieldFromPointer(f._C_field_dmap, f._C_symbol)
ctdp = cast_mapper[(hp.dtype, '*')](tdp)
ctdp = cast_mapper((hp.dtype, '*'))(tdp)
cast = DummyExpr(ffp, ctdp)

ret = Return(ctdp)
Expand Down
101 changes: 22 additions & 79 deletions devito/symbolics/extended_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa
int2, int3, int4, ctypes_vector_mapper)

__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', # noqa
__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'BaseCast', # noqa
'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex']


Expand Down Expand Up @@ -68,94 +68,37 @@ class CustomType(ReservedWord):
globals()[clsp.__name__] = clsp


class CHAR(Cast):
_base_typ = 'char'
def no_dtype(kwargs):
return {k: v for k, v in kwargs.items() if k != 'dtype'}


class SHORT(Cast):
_base_typ = 'short'
def cast_mapper(arg):
try:
assert len(arg) == 2 and arg[1] == '*'
return lambda v, **kw: CastStar(v, dtype=arg[0], **no_dtype(kw))
except TypeError:
return lambda v, **kw: Cast(v, dtype=arg, **no_dtype(kw))


class USHORT(Cast):
_base_typ = 'unsigned short'
FLOAT = cast_mapper(np.float32)
DOUBLE = cast_mapper(np.float64)
ULONG = cast_mapper(np.uint64)
UINTP = cast_mapper(np.uint32)


class UCHAR(Cast):
_base_typ = 'unsigned char'
# Standard ones, needed as class for e.g. single dispatch
class BaseCast(Cast):

def __new__(cls, base, stars=None, **kwargs):
kwargs['dtype'] = cls._dtype
return super().__new__(cls, base, stars=stars, **kwargs)

class LONG(Cast):
_base_typ = 'long'

class VOID(BaseCast):

class ULONG(Cast):
_base_typ = 'unsigned long'
_dtype = type('void', (ctypes.c_int,), {})


class CFLOAT(Cast):
_base_typ = 'float'
class INT(BaseCast):


class CDOUBLE(Cast):
_base_typ = 'double'


class VOID(Cast):
_base_typ = 'void'


class CHARP(CastStar):
base = CHAR


class UCHARP(CastStar):
base = UCHAR


class SHORTP(CastStar):
base = SHORT


class USHORTP(CastStar):
base = USHORT


class CFLOATP(CastStar):
base = CFLOAT


class CDOUBLEP(CastStar):
base = CDOUBLE


cast_mapper = {
np.int8: CHAR,
np.uint8: UCHAR,
np.int16: SHORT, # noqa
np.uint16: USHORT, # noqa
int: INT, # noqa
np.int32: INT, # noqa
np.int64: LONG,
np.uint64: ULONG,
np.float32: FLOAT, # noqa
float: DOUBLE, # noqa
np.float64: DOUBLE, # noqa

(np.int8, '*'): CHARP,
(np.uint8, '*'): UCHARP,
(int, '*'): INTP, # noqa
(np.uint16, '*'): USHORTP, # noqa
(np.int16, '*'): SHORTP, # noqa
(np.int32, '*'): INTP, # noqa
(np.int64, '*'): INTP, # noqa
(np.float32, '*'): FLOATP, # noqa
(float, '*'): DOUBLEP, # noqa
(np.float64, '*'): DOUBLEP, # noqa
}

for base_name in ['int', 'float', 'double']:
for i in [2, 3, 4]:
v = '%s%d' % (base_name, i)
cls = locals()[v]
cast_mapper[cls] = locals()[v.upper()]
cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()]
_dtype = np.int32
41 changes: 29 additions & 12 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from devito.finite_differences.elementary import Min, Max
from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa
float3, float4, double2, double3, double4, int2, int3,
int4)
int4, dtype_to_ctype, ctypes_to_cstr, ctypes_vector_mapper)
from devito.types import Symbol
from devito.types.basic import Basic

Expand Down Expand Up @@ -382,15 +382,17 @@ class Cast(UnaryOp):
Symbolic representation of the C notation `(type)expr`.
"""

_base_typ = ''
__rargs__ = ('base', )
__rkwargs__ = ('stars', 'dtype')

__rkwargs__ = ('stars',)

def __new__(cls, base, stars=None, **kwargs):
def __new__(cls, base, dtype=None, stars=None, **kwargs):
# Attempt simplifcation
# E.g., `FLOAT(32) -> 32.0` of type `sympy.Float`
try:
return sympify(eval(cls._base_typ)(base))
if isinstance(dtype, str):
return sympify(eval(dtype)(base))
else:
return sympify(dtype(base))
except (NameError, SyntaxError):
# E.g., `_base_typ` is "char" or "unsigned long"
pass
Expand All @@ -399,9 +401,22 @@ def __new__(cls, base, stars=None, **kwargs):
pass

obj = super().__new__(cls, base)
obj._stars = stars
obj._stars = stars or ''
obj._dtype = cls.__process_dtype__(dtype)
return obj

@classmethod
def __process_dtype__(cls, dtype):
if isinstance(dtype, str):
return dtype
dtype = ctypes_vector_mapper.get(dtype, dtype)
try:
dtype = ctypes_to_cstr(dtype_to_ctype(dtype))
except:
pass

return dtype

def _hashable_content(self):
return super()._hashable_content() + (self._stars,)

Expand All @@ -411,9 +426,13 @@ def _hashable_content(self):
def stars(self):
return self._stars

@property
def dtype(self):
return self._dtype

@property
def typ(self):
return '%s%s' % (self._base_typ, self.stars or '')
return '%s%s' % (self.dtype, self.stars or '')

@property
def _op(self):
Expand Down Expand Up @@ -753,10 +772,8 @@ def __str__(self):

class CastStar:

base = None

def __new__(cls, base=''):
return cls.base(base, '*')
def __new__(cls, base, dtype=None, ase=''):
return Cast(base, dtype=dtype, stars='*')


# Some other utility objects
Expand Down
15 changes: 10 additions & 5 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,16 @@ def _print_InlineIf(self, expr):
PREC = precedence(expr)
return self.parenthesize("(%s) ? %s : %s" % (cond, true_expr, false_expr), PREC)

def _print_UnaryOp(self, expr):
if expr.base.is_Symbol:
return "%s%s" % (expr._op, self._print(expr.base))
else:
return "%s(%s)" % (expr._op, self._print(expr.base))
def _print_UnaryOp(self, expr, op=None):
op = op or expr._op
base = self._print(expr.base)
if not expr.base.is_Symbol:
base = f'({base})'
return f'{op}{base}'

def _print_Cast(self, expr):
cast = f'({self._print(expr.dtype)}{self._print(expr.stars)})'
return self._print_UnaryOp(expr, op=cast)

def _print_ComponentAccess(self, expr):
return "%s.%s" % (self._print(expr.base), expr.sindex)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
import sympy

from devito import Constant, Eq, Function, Grid, Operator
from devito import Constant, Eq, Function, Grid, Operator, exp, log, sin
from devito.passes.iet.langbase import LangBB
from devito.passes.iet.languages.C import CBB
from devito.passes.iet.languages.openacc import AccBB
Expand Down Expand Up @@ -161,9 +161,9 @@ def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None:

@pytest.mark.parametrize('dtype', [np.float32, np.float64,
np.complex64, np.complex128])
@pytest.mark.parametrize(['sym', 'fun'], [(sympy.exp, np.exp),
(sympy.log, np.log),
(sympy.sin, np.sin)])
@pytest.mark.parametrize(['sym', 'fun'], [(exp, np.exp),
(log, np.log),
(sin, np.sin)])
def test_math_functions(dtype: np.dtype[np.inexact],
sym: sympy.Function, fun: np.ufunc) -> None:
"""
Expand Down
6 changes: 3 additions & 3 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
CallFromPointer, Cast, DefFunction, FieldFromPointer,
INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
ReservedWord, ListInitializer, uxreplace, ccode,
retrieve_derivatives)
retrieve_derivatives, BaseCast)
from devito.tools import as_tuple
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
Symbol as dSymbol)
Expand Down Expand Up @@ -394,8 +394,8 @@ def test_rvalue():
def test_cast():
s = Symbol(name='s', dtype=np.float32)

class BarCast(Cast):
_base_typ = 'bar'
class BarCast(BaseCast):
_dtype = 'bar'

v = BarCast(s, '**')
assert ccode(v) == '(bar**)s'
Expand Down

0 comments on commit 2bc6ab9

Please sign in to comment.