Skip to content

Commit

Permalink
api: remove un-needed dtype reconstruction mode
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 17, 2025
1 parent e1324b8 commit dcd0bd4
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 32 deletions.
12 changes: 6 additions & 6 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ctypes
import numpy as np

from devito.arch.compiler import Compiler
Expand All @@ -18,15 +17,16 @@ def _complex_includes(iet: Callable, lang: type[LangBB] = None,
"""

# Check if there are complex numbers that always take dtype precedence
types = set()
is_complex = False
for f in FindSymbols().visit(iet):
try:
if not issubclass(f.dtype, ctypes._Pointer):
types.add(f.dtype)
if np.issubdtype(f.dtype, np.complexfloating):
is_complex = True
break
except TypeError:
pass
continue

if not any(np.issubdtype(d, np.complexfloating) for d in types):
if not is_complex:
return iet, {}

metadata = {}
Expand Down
2 changes: 1 addition & 1 deletion devito/symbolics/extended_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def cast_mapper(arg):
FLOAT = cast_mapper(np.float32)
DOUBLE = cast_mapper(np.float64)
ULONG = cast_mapper(np.uint64)
UINTP = cast_mapper(np.uint32)
UINTP = cast_mapper((np.uint32, '*'))


# Standard ones, needed as class for e.g. single dispatch
Expand Down
2 changes: 1 addition & 1 deletion devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def __new__(cls, base, dtype=None, stars=None, **kwargs):
except (NameError, SyntaxError):
# E.g., `_base_typ` is "char" or "unsigned long"
pass
except TypeError:
except (ValueError, TypeError):
# `base` ain't a number
pass

Expand Down
30 changes: 7 additions & 23 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from devito.data import default_allocator
from devito.parameters import configuration
from devito.tools import (Pickable, as_tuple, dtype_to_ctype,
frozendict, memoized_meth, sympy_mutex, CustomDtype,
Reconstructable)
frozendict, memoized_meth, sympy_mutex, CustomDtype)
from devito.types.args import ArgProvider
from devito.types.caching import Cached, Uncached
from devito.types.lazy import Evaluable
Expand Down Expand Up @@ -881,16 +880,14 @@ def __new__(cls, *args, **kwargs):
name = kwargs.get('name')
alias = kwargs.get('alias')
function = kwargs.get('function')
dtype = kwargs.get('dtype')
if alias is True or (function and function.name != name):
function = kwargs['function'] = None

# If same name/indices and `function` isn't None, then it's
# definitely a reconstruction
if function is not None and \
function.name == name and \
function.indices == indices and \
function.dtype == dtype:
function.indices == indices:
# Special case: a syntactically identical alias of `function`, so
# let's just return `function` itself
return function
Expand Down Expand Up @@ -1231,8 +1228,7 @@ def bound_symbols(self):
@cached_property
def indexed(self):
"""The wrapped IndexedData object."""
return IndexedData(self.name, shape=self._shape, function=self.function,
dtype=self.dtype)
return IndexedData(self.name, shape=self._shape, function=self.function)

@cached_property
def dmap(self):
Expand Down Expand Up @@ -1507,14 +1503,13 @@ class IndexedBase(sympy.IndexedBase, Basic, Pickable):
__rargs__ = ('label', 'shape')
__rkwargs__ = ('function',)

def __new__(cls, label, shape, function=None, dtype=None):
def __new__(cls, label, shape, function=None):
# Make sure `label` is a devito.Symbol, not a sympy.Symbol
if isinstance(label, str):
label = Symbol(name=label, dtype=None)
with sympy_mutex:
obj = sympy.IndexedBase.__new__(cls, label, shape)
obj.function = function
obj._dtype = dtype or function.dtype
return obj

func = Pickable._rebuild
Expand Down Expand Up @@ -1548,7 +1543,7 @@ def indices(self):

@property
def dtype(self):
return self._dtype
return self.function.dtype

@cached_property
def free_symbols(self):
Expand Down Expand Up @@ -1610,7 +1605,7 @@ def _C_ctype(self):
return self.function._C_ctype


class Indexed(sympy.Indexed, Reconstructable):
class Indexed(sympy.Indexed):

# The two type flags have changed in upstream sympy as of version 1.1,
# but the below interpretation is used throughout the compiler to
Expand All @@ -1622,17 +1617,6 @@ class Indexed(sympy.Indexed, Reconstructable):

is_Dimension = False

__rargs__ = ('base', 'indices')
__rkwargs__ = ('dtype',)

def __new__(cls, base, *indices, dtype=None, **kwargs):
if len(indices) == 1:
indices = as_tuple(indices[0])
newobj = sympy.Indexed.__new__(cls, base, *indices)
newobj._dtype = dtype or base.dtype

return newobj

@memoized_meth
def __str__(self):
return super().__str__()
Expand All @@ -1654,7 +1638,7 @@ def function(self):

@property
def dtype(self):
return self._dtype
return self.function.dtype

@property
def name(self):
Expand Down
2 changes: 1 addition & 1 deletion devito/types/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class FIndexed(Indexed, Pickable):
__rkwargs__ = ('strides_map', 'accessor')

def __new__(cls, base, *args, strides_map=None, accessor=None):
obj = super().__new__(cls, base, args)
obj = super().__new__(cls, base, *args)
obj.strides_map = frozendict(strides_map or {})
obj.accessor = accessor

Expand Down

0 comments on commit dcd0bd4

Please sign in to comment.