Skip to content

Commit

Permalink
compiler: fix std math func names
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 16, 2025
1 parent 54cd1fc commit 62f2deb
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
3 changes: 2 additions & 1 deletion devito/passes/iet/languages/CXX.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,13 @@ class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter):
_default_settings = {**_DevitoPrinterBase._default_settings,
**CXX11CodePrinter._default_settings}
_ns = "std::"
_func_litterals = {}

# These cannot go through _print_xxx because they are classes not
# instances
type_mappings = {**_DevitoPrinterBase.type_mappings,
c_complex: 'std::complex<float>',
c_double_complex: 'std::complex<float>',
c_double_complex: 'std::complex<double>',
**CXX11CodePrinter.type_mappings}

def _print_ImaginaryUnit(self, expr):
Expand Down
7 changes: 5 additions & 2 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,13 @@ def has_integer_args(*args):
return res


def sympy_dtype(expr, base=None):
def sympy_dtype(expr, base=None, default=None):
"""
Infer the dtype of the expression.
"""
if expr is None:
return default

dtypes = {base} - {None}
for i in expr.free_symbols:
try:
Expand All @@ -312,7 +315,7 @@ def sympy_dtype(expr, base=None):
is_im = np.issubdtype(dtype, np.complexfloating)
if expr.has(ImaginaryUnit) and not is_im:
if dtype is None:
dtype = np.complex64
dtype = default or np.complex64
else:
dtype = np.promote_types(dtype, np.complex64).type

Expand Down
6 changes: 3 additions & 3 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


_prec_litterals = {np.float16: 'F16', np.float32: 'F', np.complex64: 'F'}
_func_litterals = {np.float32: 'f', np.complex64: 'f', Real: 'f'}


class _DevitoPrinterBase(C99CodePrinter):
Expand All @@ -42,6 +41,7 @@ class _DevitoPrinterBase(C99CodePrinter):
**C99CodePrinter._default_settings}

_func_prefix = {np.float32: 'f', np.float64: 'f'}
_func_litterals = {np.float32: 'f', np.complex64: 'f', Real: 'f'}

@property
def dtype(self):
Expand All @@ -62,7 +62,7 @@ def doprint(self, expr, assign_to=None):
return self._print(expr)

def _prec(self, expr):
dtype = sympy_dtype(expr) if expr is not None else None
dtype = sympy_dtype(expr, default=self.dtype)
if dtype is None or np.issubdtype(dtype, np.integer):
real = any(isinstance(i, Float) for i in expr.atoms())
stype = self.dtype if real else np.int32
Expand All @@ -74,7 +74,7 @@ def prec_literal(self, expr):
return _prec_litterals.get(self._prec(expr), '')

def func_literal(self, expr):
return _func_litterals.get(self._prec(expr), '')
return self._func_litterals.get(self._prec(expr), '')

def func_prefix(self, expr, abs=False):
prefix = self._func_prefix.get(self._prec(expr), '')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_complex(self, dtype):
xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5))
npres = xx + 1j*yy + np.exp(1j + dx) * (1.0 + 2.0j)

assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0)
assert np.allclose(u.data, npres.T, rtol=5e-7, atol=0)


class TestPassesOptional:
Expand Down

0 comments on commit 62f2deb

Please sign in to comment.