diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 3332bc68d67..6649fa86bf5 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -304,4 +304,10 @@ def sympy_dtype(expr, base=None): dtypes.add(i.dtype) except AttributeError: pass - return infer_dtype(dtypes) + dtype = infer_dtype(dtypes) + + # Promote if complex + if expr.has(ImaginaryUnit): + dtype = np.promote_types(dtype, np.complex64).type + + return dtype diff --git a/devito/types/basic.py b/devito/types/basic.py index 0c27188ae95..af08d6281d9 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -436,7 +436,7 @@ def _C_ctype(self): return self.dtype elif np.issubdtype(self.dtype, np.complexfloating): rtype = self.dtype(0).real.__class__ - ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctname = '%s complex' % dtype_to_cstr(rtype) ctype = dtype_to_ctype(rtype) r = type(ctname, (ctype,), {}) return r diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 1204eb2c130..28d0a38edd3 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -66,6 +66,24 @@ def test_maxpar_option(self): assert trees[0][0] is trees[1][0] assert trees[0][1] is not trees[1][1] + def test_complex(self): + grid = Grid((5, 5)) + x, y = grid.dimensions + # Float32 complex is called complex64 in numpy + u = Function(name="u", grid=grid, dtype=np.complex64) + + eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + # Currently wrong alias type + op = Operator(eq) + op() + + # Check against numpy + dx = grid.spacing_map[x.spacing] + xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) + npres = xx + 1j*yy + np.exp(1j + dx) + + assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + class TestPassesOptional(object): diff --git a/tests/test_operator.py b/tests/test_operator.py index 1ad704f8367..376c7e5dff6 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -648,7 +648,7 @@ def test_complex(self): eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) # Currently wrong alias type - op = Operator(eq, opt='noop') + op = Operator(eq) op() # Check against numpy