diff --git a/devito/finite_differences/finite_difference.py b/devito/finite_differences/finite_difference.py index b664e57c2d..45f4060715 100644 --- a/devito/finite_differences/finite_difference.py +++ b/devito/finite_differences/finite_difference.py @@ -143,11 +143,6 @@ def first_derivative(expr, dim, fd_order, **kwargs): return generic_derivative(expr, dim, fd_order, 1, **kwargs) -# Backward compatibility -def first_derivative(expr, dim, fd_order, **kwargs): - return generic_derivative(expr, dim, fd_order, 1, **kwargs) - - def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coefficients, expand, weights=None): # Always expand time derivatives to avoid issue with buffering and streaming. diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 0432a6e77a..622b1ed95b 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -19,9 +19,10 @@ Call, Lambda, BlankLine, Section, ListMajor) from devito.ir.support.space import Backward from devito.symbolics import (FieldFromComposite, FieldFromPointer, - ListInitializer, ccode, uxreplace) + ListInitializer, uxreplace) +from devito.symbolics.printer import _DevitoPrinterBase from devito.symbolics.extended_dtypes import NoDeclStruct -from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered, +from devito.tools import (GenericVisitor, as_tuple, filter_ordered, filter_sorted, flatten, is_external_ctype, c_restrict_void_p, sorted_priority) from devito.types.basic import AbstractFunction, AbstractSymbol, Basic @@ -176,9 +177,10 @@ class CGen(Visitor): Return a representation of the Iteration/Expression tree as a :module:`cgen` tree. """ - def __init__(self, *args, compiler=None, **kwargs): + def __init__(self, *args, compiler=None, printer=None, **kwargs): super().__init__(*args, **kwargs) self._compiler = compiler or configuration['compiler'] + self._printer = printer or _DevitoPrinterBase # The following mappers may be customized by subclasses (that is, # backend-specific CGen-erators) @@ -194,6 +196,9 @@ def __init__(self, *args, compiler=None, **kwargs): def compiler(self): return self._compiler + def ccode(self, expr, **settings): + return self._printer(settings=settings).doprint(expr, None) + def visit(self, o, *args, **kwargs): # Make sure the visitor always is within the generating compiler # in case the configuration is accessed @@ -233,7 +238,7 @@ def _gen_struct_decl(self, obj, masked=()): try: entries.append(self._gen_value(i, 0, masked=('const',))) except AttributeError: - cstr = ctypes_to_cstr(ct) + cstr = self.ccode(ct) if ct is c_restrict_void_p: cstr = '%srestrict' % cstr entries.append(c.Value(cstr, n)) @@ -255,10 +260,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = obj._C_typedata - strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) + strtype = self.ccode(obj._C_typedata) + strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape) else: - strtype = ctypes_to_cstr(obj._C_ctype) + strtype = self.ccode(obj._C_ctype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -272,7 +277,7 @@ def _gen_value(self, obj, mode=1, masked=()): strobj = '%s%s' % (strname, strshape) if obj.is_LocalObject and obj.cargs and mode == 1: - arguments = [ccode(i) for i in obj.cargs] + arguments = [self.ccode(i) for i in obj.cargs] strobj = MultilineCall(strobj, arguments, True) value = c.Value(strtype, strobj) @@ -286,9 +291,9 @@ def _gen_value(self, obj, mode=1, masked=()): if obj.is_Array and obj.initvalue is not None and mode == 1: init = ListInitializer(obj.initvalue) if not obj._mem_constant or init.is_numeric: - value = c.Initializer(value, ccode(init)) + value = c.Initializer(value, self.ccode(init)) elif obj.is_LocalObject and obj.initvalue is not None and mode == 1: - value = c.Initializer(value, ccode(obj.initvalue)) + value = c.Initializer(value, self.ccode(obj.initvalue)) return value @@ -322,7 +327,7 @@ def _args_call(self, args): else: ret.append(i._C_name) except AttributeError: - ret.append(ccode(i)) + ret.append(self.ccode(i)) return ret def _gen_signature(self, o, is_declaration=False): @@ -388,7 +393,7 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed - cstr = i._C_typedata + cstr = self.ccode(i._C_typedata) if f.is_PointerArray: # lvalue @@ -410,7 +415,7 @@ def visit_PointerCast(self, o): else: v = f.name if o.flat is None: - shape = ''.join("[%s]" % ccode(i) for i in o.castshape) + shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape) rshape = '(*)%s' % shape lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape)) else: @@ -443,9 +448,9 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed - cstr = i._C_typedata + cstr = self.ccode(i._C_typedata) if o.flat is None: - shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) + shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:]) rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, a1.dim.name) lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape)) @@ -484,8 +489,8 @@ def visit_Definition(self, o): return self._gen_value(o.function) def visit_Expression(self, o): - lhs = ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) - rhs = ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) + lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) + rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) if o.init: code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs) @@ -498,8 +503,8 @@ def visit_Expression(self, o): return code def visit_AugmentedExpression(self, o): - c_lhs = ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) - c_rhs = ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) + c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) + c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs)) if o.pragmas: code = c.Module(self._visit(o.pragmas) + (code,)) @@ -518,7 +523,7 @@ def visit_Call(self, o, nested_call=False): o.templates) if retobj.is_Indexed or \ isinstance(retobj, (FieldFromComposite, FieldFromPointer)): - return c.Assign(ccode(retobj), call) + return c.Assign(self.ccode(retobj), call) else: return c.Initializer(c.Value(rettype, retobj._C_name), call) @@ -532,9 +537,9 @@ def visit_Conditional(self, o): then_body = c.Block(self._visit(then_body)) if else_body: else_body = c.Block(self._visit(else_body)) - return c.If(ccode(o.condition), then_body, else_body) + return c.If(self.ccode(o.condition), then_body, else_body) else: - return c.If(ccode(o.condition), then_body) + return c.If(self.ccode(o.condition), then_body) def visit_Iteration(self, o): body = flatten(self._visit(i) for i in self._blankline_logic(o.children)) @@ -544,23 +549,23 @@ def visit_Iteration(self, o): # For backward direction flip loop bounds if o.direction == Backward: - loop_init = 'int %s = %s' % (o.index, ccode(_max)) - loop_cond = '%s >= %s' % (o.index, ccode(_min)) + loop_init = 'int %s = %s' % (o.index, self.ccode(_max)) + loop_cond = '%s >= %s' % (o.index, self.ccode(_min)) loop_inc = '%s -= %s' % (o.index, o.limits[2]) else: - loop_init = 'int %s = %s' % (o.index, ccode(_min)) - loop_cond = '%s <= %s' % (o.index, ccode(_max)) + loop_init = 'int %s = %s' % (o.index, self.ccode(_min)) + loop_cond = '%s <= %s' % (o.index, self.ccode(_max)) loop_inc = '%s += %s' % (o.index, o.limits[2]) # Append unbounded indices, if any if o.uindices: - uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices] + uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices] loop_init = c.Line(', '.join([loop_init] + uinit)) ustep = [] for i in o.uindices: op = '=' if i.is_Modulo else '+=' - ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr))) + ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr))) loop_inc = c.Line(', '.join([loop_inc] + ustep)) # Create For header+body @@ -577,7 +582,7 @@ def visit_Pragma(self, o): return c.Pragma(o._generate) def visit_While(self, o): - condition = ccode(o.condition) + condition = self.ccode(o.condition) if o.body: body = flatten(self._visit(i) for i in o.children) return c.While(condition, c.Block(body)) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 6a89dabd9c..8561a8e911 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -27,12 +27,10 @@ from devito.passes import (Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, error_mapper, is_on_device) -from devito.passes.iet.langbase import LangBB from devito.symbolics import estimate_cost, subs_op_args from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple, flatten, filter_sorted, frozendict, is_integer, split, timed_pass, timed_region, contains_val) -from devito.tools.dtypes_lowering import ctypes_vector_mapper from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer, disk_layer) from devito.types.dimension import Thickness @@ -272,9 +270,6 @@ def _lower(cls, expressions, **kwargs): # expression for which a partial or complete lowering is desired kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs) - # Load language-specific types into the global dtype->ctype mapper - cls._load_dtype_mappings(**kwargs) - # [Eq] -> [LoweredEq] expressions = cls._lower_exprs(expressions, **kwargs) @@ -296,11 +291,6 @@ def _lower(cls, expressions, **kwargs): def _rcompile_wrapper(cls, **kwargs0): raise NotImplementedError - @classmethod - def _load_dtype_mappings(cls, **kwargs): - lang: type[LangBB] = cls._Target.DataManager.lang - ctypes_vector_mapper.update(lang.mapper.get('types', {})) - @classmethod def _initialize_state(cls, **kwargs): return {} @@ -764,13 +754,19 @@ def _soname(self): """A unique name for the shared object resulting from JIT compilation.""" return Signer._digest(self, configuration) + @property + def printer(self): + return self._Target.Printer + @cached_property def ccode(self): try: - return self._ccode_handler(compiler=self._compiler).visit(self) + return self._ccode_handler(compiler=self._compiler, + printer=self.printer).visit(self) except (AttributeError, TypeError): from devito.ir.iet.visitors import CGen - return CGen(compiler=self._compiler).visit(self) + return CGen(compiler=self._compiler, + printer=self.printer).visit(self) def _jit_compile(self): """ diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 6f698a659d..f1606a73ff 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -2,11 +2,8 @@ import numpy as np from devito.arch.compiler import Compiler -from devito.ir import Callable, Dereference, FindSymbols, Node, SymbolRegistry, Uxreplace +from devito.ir import Callable, FindSymbols, SymbolRegistry from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import Float16P -from devito.tools import as_list -from devito.types.basic import AbstractSymbol, Basic, Symbol __all__ = ['lower_dtypes'] @@ -17,34 +14,9 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler, Lowers float16 scalar types to pointers since we can't directly pass their value. Also includes headers for complex arithmetic if needed. """ - + # Complex numbers iet, metadata = _complex_includes(iet, lang, compiler) - # Lower float16 parameters to pointers and dereference - prefix: list[Node] = [] - params_mapper: dict[AbstractSymbol, AbstractSymbol] = {} - body_mapper: dict[AbstractSymbol, Symbol] = {} - - params_set = set(iet.parameters) - s: AbstractSymbol - for s in FindSymbols('abstractsymbols').visit(iet): - if s.dtype != np.float16 or s not in params_set: - continue - - # Replace the parameter with a pointer; replace occurences in the IET - # body with dereferenced symbol (using the original symbol's dtype) - ptr: AbstractSymbol = s._rebuild(dtype=Float16P, is_const=True) - val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=s.dtype, - is_const=s.is_const) - - params_mapper[s], body_mapper[s] = ptr, val - prefix.append(Dereference(val, ptr)) # val = *ptr - - # Apply the replacements - prefix.extend(as_list(Uxreplace(body_mapper).visit(iet.body))) - params: tuple[Basic] = Uxreplace(params_mapper).visit(iet.parameters) - - iet = iet._rebuild(body=prefix, parameters=params) return iet, metadata diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 069aa10320..7efdaa44ff 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,21 +1,11 @@ -import numpy as np - from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import (Float16P, c_complex, c_double_complex, - c_half, c_half_p) - - -__all__ = ['CBB', 'CDataManager', 'COrchestrator', 'c_float16', 'c_float16_p'] - +from devito.symbolics.extended_dtypes import c_complex, c_double_complex +from devito.symbolics.printer import _DevitoPrinterBase -c99_complex = type('_Complex float', (c_complex,), {}) -c99_double_complex = type('_Complex double', (c_double_complex,), {}) - -c_float16 = type('_Float16', (c_half,), {}) -c_float16_p = type('_Float16 *', (c_half_p,), {'_type_': c_float16}) +__all__ = ['CBB', 'CDataManager', 'COrchestrator'] class CBB(LangBB): @@ -34,10 +24,6 @@ class CBB(LangBB): Call('memcpy', (i, j, k)), # Complex and float16 'header-complex': 'complex.h', - 'types': {np.complex128: c99_double_complex, - np.complex64: c99_complex, - np.float16: c_float16, - Float16P: c_float16_p} } @@ -47,3 +33,12 @@ class CDataManager(DataManager): class COrchestrator(Orchestrator): lang = CBB + + +class CDevitoPrinter(_DevitoPrinterBase): + + # These cannot go through _print_xxx because they are classes not + # instances + type_mappings = {**_DevitoPrinterBase.type_mappings, + c_complex: 'float _Complex', + c_double_complex: 'double _Complex'} diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 1174a27f8d..aa9e9118de 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,9 +1,9 @@ -import numpy as np +from sympy.printing.cxx import CXX11CodePrinter from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.passes.iet.languages.C import c_float16, c_float16_p -from devito.symbolics.extended_dtypes import Float16P, c_complex, c_double_complex +from devito.symbolics.printer import _DevitoPrinterBase +from devito.symbolics.extended_dtypes import c_complex, c_double_complex __all__ = ['CXXBB'] @@ -45,10 +45,6 @@ """ -cxx_complex = type('std::complex', (c_complex,), {}) -cxx_double_complex = type('std::complex', (c_double_complex,), {}) - - class CXXBB(LangBB): mapper = { @@ -67,8 +63,16 @@ class CXXBB(LangBB): 'header-complex': 'complex', 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, - "types": {np.complex128: cxx_double_complex, - np.complex64: cxx_complex, - np.float16: c_float16, - Float16P: c_float16_p} } + + +class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter): + + _default_settings = {**_DevitoPrinterBase._default_settings, + **CXX11CodePrinter._default_settings} + + # These cannot go through _print_xxx because they are classes not + # instances + type_mappings = {c_complex: 'std::complex', + c_double_complex: 'std::complex', + **CXX11CodePrinter.type_mappings} diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index bcf5660ac7..1718a5269a 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -9,7 +9,7 @@ from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB, PragmaIteration, PragmaTransfer) -from devito.passes.iet.languages.CXX import CXXBB +from devito.passes.iet.languages.CXX import CXXBB, CXXDevitoPrinter from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration from devito.symbolics import FieldFromPointer, Macro, cast_mapper from devito.tools import filter_ordered, UnboundTuple @@ -263,3 +263,8 @@ def place_devptr(self, iet, **kwargs): class AccOrchestrator(Orchestrator): lang = AccBB + + +class AccDevitoPrinter(CXXDevitoPrinter): + + pass diff --git a/devito/passes/iet/languages/targets.py b/devito/passes/iet/languages/targets.py index 4ac8d94398..66137a53e7 100644 --- a/devito/passes/iet/languages/targets.py +++ b/devito/passes/iet/languages/targets.py @@ -1,9 +1,9 @@ -from devito.passes.iet.languages.C import CDataManager, COrchestrator +from devito.passes.iet.languages.C import CDataManager, COrchestrator, CDevitoPrinter from devito.passes.iet.languages.openmp import (SimdOmpizer, Ompizer, DeviceOmpizer, OmpDataManager, DeviceOmpDataManager, OmpOrchestrator, DeviceOmpOrchestrator) from devito.passes.iet.languages.openacc import (DeviceAccizer, DeviceAccDataManager, - AccOrchestrator) + AccOrchestrator, AccDevitoPrinter) from devito.passes.iet.instrument import instrument __all__ = ['CTarget', 'OmpTarget', 'DeviceOmpTarget', 'DeviceAccTarget'] @@ -13,6 +13,7 @@ class Target: Parizer = None DataManager = None Orchestrator = None + Printer = None @classmethod def lang(cls): @@ -27,21 +28,25 @@ class CTarget(Target): Parizer = SimdOmpizer DataManager = CDataManager Orchestrator = COrchestrator + Printer = CDevitoPrinter class OmpTarget(Target): Parizer = Ompizer DataManager = OmpDataManager Orchestrator = OmpOrchestrator + Printer = CDevitoPrinter class DeviceOmpTarget(Target): Parizer = DeviceOmpizer DataManager = DeviceOmpDataManager Orchestrator = DeviceOmpOrchestrator + Printer = CDevitoPrinter class DeviceAccTarget(Target): Parizer = DeviceAccizer DataManager = DeviceAccDataManager Orchestrator = AccOrchestrator + Printer = AccDevitoPrinter diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 8c90e48986..e9772ee744 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -3,11 +3,10 @@ from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa - int2, int3, int4) + int2, int3, int4, ctypes_vector_mapper) -__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', - 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', - 'c_half', 'c_half_p', 'Float16P'] +__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', # noqa + 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex'] limits_mapper = { @@ -50,34 +49,8 @@ def from_param(cls, val): return cls(val.real, val.imag) -class c_half(ctypes.c_uint16): - """Ctype for non-scalar half floats""" - - @classmethod - def from_param(cls, val): - return cls(np.float16(val).view(np.uint16)) - - -class c_half_p(ctypes.POINTER(c_half)): - """ - Ctype for half scalars; we can't directly pass _Float16 values so - we use a pointer and dereference (see `passes.iet.dtypes`) - """ - - @classmethod - def from_param(cls, val): - arr = np.array(val, dtype=np.float16) - return arr.ctypes.data_as(cls) - - -class Float16P(np.float16): - """ - Dummy dtype for a scalar half value that has been mapped to a pointer. - This is needed because we can't directly pass in the values; we map to - pointers and dereference in the kernel. See `passes.iet.dtypes`. - """ - - pass +ctypes_vector_mapper.update({np.complex64: c_complex, + np.complex128: c_double_complex}) class CustomType(ReservedWord): diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 1853b02044..29b429b6c1 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -11,19 +11,20 @@ from sympy.core import S from sympy.core.numbers import equal_valued, Float +from sympy.printing.c import C99CodePrinter from sympy.logic.boolalg import BooleanFunction from sympy.printing.precedence import PRECEDENCE_VALUES, precedence -from sympy.printing.c import C99CodePrinter from devito import configuration from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype from devito.types.basic import AbstractFunction +from devito.tools import ctypes_to_cstr __all__ = ['ccode'] -class CodePrinter(C99CodePrinter): +class _DevitoPrinterBase(C99CodePrinter): """ Decorator for sympy.printing.ccode.CCodePrinter. @@ -47,6 +48,13 @@ def dtype(self): def compiler(self): return self._settings['compiler'] or configuration['compiler'] + def doprint(self, expr, assign_to=None): + """ + The sympy code printer does a lot of extra we do not need as we handle all of + it in the compiler so we directly defaults to `_print` + """ + return self._print(expr) + def single_prec(self, expr=None, with_f=False): no_f = self.compiler._cpp and not with_f if no_f and expr is not None: @@ -72,6 +80,12 @@ def parenthesize(self, item, level, strict=False): return "(%s)" % self._print(item) return super().parenthesize(item, level, strict=strict) + def _print_type(self, expr): + try: + return self.type_mappings[expr] + except KeyError: + return ctypes_to_cstr(expr) + def _print_Function(self, expr): if isinstance(expr, AbstractFunction): return str(expr) @@ -269,7 +283,7 @@ def _print_ImaginaryUnit(self, expr): def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) - _print_EvalDerivative = C99CodePrinter._print_Add + _print_EvalDerivative = _print_Add def _print_CallFromPointer(self, expr): indices = [self._print(i) for i in expr.params] @@ -349,7 +363,7 @@ def _print_Fallback(self, expr): # Lifted from SymPy so that we go through our own `_print_math_func` for k in ('exp log sin cos tan ceiling floor').split(): - setattr(CodePrinter, '_print_%s' % k, CodePrinter._print_math_func) + setattr(_DevitoPrinterBase, '_print_%s' % k, _DevitoPrinterBase._print_math_func) # Always parenthesize IntDiv and InlineIf within expressions @@ -373,10 +387,10 @@ def ccode(expr, **settings): The resulting code as a C++ string. If something went south, returns the input ``expr`` itself. """ - return CodePrinter(settings=settings).doprint(expr, None) + return _DevitoPrinterBase(settings=settings).doprint(expr, None) # Sympy 1.11 has introduced a bug in `_print_Add`, so we enforce here # to always use the correct one from our printer if Version(sympy.__version__) >= Version("1.11"): - setattr(sympy.printing.str.StrPrinter, '_print_Add', CodePrinter._print_Add) + setattr(sympy.printing.str.StrPrinter, '_print_Add', _DevitoPrinterBase._print_Add) diff --git a/devito/types/basic.py b/devito/types/basic.py index aa9c5e0b29..8e131c7357 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -13,7 +13,7 @@ from devito.data import default_allocator from devito.parameters import configuration -from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, +from devito.tools import (Pickable, as_tuple, dtype_to_ctype, frozendict, memoized_meth, sympy_mutex, CustomDtype, Reconstructable) from devito.types.args import ArgProvider @@ -95,7 +95,7 @@ def _C_typedata(self): if _type is c_char_p: _type = c_char - return ctypes_to_cstr(_type) + return _type @abc.abstractproperty def _C_ctype(self): diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index a330096b36..124dbba5c4 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -52,7 +52,7 @@ def _config_kwargs(platform: str, language: str, compiler: str) -> dict[str, str ] -@pytest.mark.parametrize('dtype', [np.float16, np.complex64, np.complex128]) +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: """ @@ -79,7 +79,7 @@ def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> N assert _c._C_ctype == lang_types[_c.dtype] -@pytest.mark.parametrize('dtype', [np.float16, np.complex64, np.complex128]) +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: """ @@ -106,39 +106,7 @@ def test_cse_ctypes(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None assert match.group(1) == ctypes_to_cstr(lang_types[dtype]) -def test_half_params() -> None: - """ - Tests float16 input parameters: scalars should be lowered to pointers - and dereferenced; other parameters should keep the original dtype. - """ - - grid = Grid(shape=(5, 5), dtype=np.float16) - x, y = grid.dimensions - - c = Constant(name='c', dtype=np.float16) - u = Function(name='u', grid=grid) - eq = Eq(u, x * x.spacing + c * y * y.spacing) - op = Operator(eq) - - # Check that lowered parameters have the correct dtypes - params: dict[str, Basic] = {p.name: p for p in op.parameters} - _u, _c, _dx, _dy = params['u'], params['c'], params['h_x'], params['h_y'] - - assert _u.dtype == np.float16 - assert _c.dtype == Float16P - assert _dx.dtype == Float16P - assert _dy.dtype == Float16P - - # Ensure the mapped pointer-to-half symbols are dereferenced - derefs: set[Symbol] = {n.pointer for n in op.body.body - if isinstance(n, Dereference)} - assert _c in derefs - assert _dx in derefs - assert _dy in derefs - - -@pytest.mark.parametrize('dtype', [np.float16, np.float32, - np.complex64, np.complex128]) +@pytest.mark.parametrize('dtype', [np.float32, np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) def test_complex_headers(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: np.dtype @@ -192,7 +160,7 @@ def test_imag_unit(dtype: np.complexfloating, kwargs: dict[str, str]) -> None: assert unit_str in str(op) -@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64, +@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128]) @pytest.mark.parametrize(['sym', 'fun'], [(sympy.exp, np.exp), (sympy.log, np.log), @@ -248,30 +216,6 @@ def test_complex_override(dtype: np.dtype[np.complexfloating]) -> None: assert np.allclose(u.data.T, expected) -def test_half_time_deriv() -> None: - """ - Tests taking the time derivative of a float16 function. - """ - - grid = Grid(shape=(5, 5)) - x, y = grid.dimensions - t = grid.time_dim - - f = TimeFunction(name='f', grid=grid, space_order=2, dtype=np.float16) - g = Function(name='g', grid=grid, dtype=np.float16) - eqns = [Eq(f.forward, t * x * x.spacing + - y * y.spacing), - Eq(g, f.dt)] - op = Operator(eqns) - op.apply(time=10, dt=1.0) - - # Check against expected result - dx = grid.spacing_map[x.spacing] - xx = np.repeat(np.linspace(0, 4, 5, dtype=np.float16)[np.newaxis, :], 5, axis=0) - expected = xx * np.float16(dx) - assert np.allclose(g.data.T, expected) - - @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) def test_complex_time_deriv(dtype: np.dtype[np.complexfloating]) -> None: """