Skip to content

Commit

Permalink
compiler: fix alias dtype with complex numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 24, 2024
1 parent e6b2fe3 commit ed90efb
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 3 deletions.
8 changes: 7 additions & 1 deletion devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,10 @@ def sympy_dtype(expr, base=None):
dtypes.add(i.dtype)
except AttributeError:
pass
return infer_dtype(dtypes)
dtype = infer_dtype(dtypes)

# Promote if complex
if expr.has(ImaginaryUnit):
dtype = np.promote_types(dtype, np.complex64).type

return dtype
2 changes: 1 addition & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def _C_ctype(self):
return self.dtype
elif np.issubdtype(self.dtype, np.complexfloating):
rtype = self.dtype(0).real.__class__
ctname = '%s _Complex' % dtype_to_cstr(rtype)
ctname = '%s complex' % dtype_to_cstr(rtype)
ctype = dtype_to_ctype(rtype)
r = type(ctname, (ctype,), {})
return r
Expand Down
18 changes: 18 additions & 0 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,24 @@ def test_maxpar_option(self):
assert trees[0][0] is trees[1][0]
assert trees[0][1] is not trees[1][1]

def test_complex(self):
grid = Grid((5, 5))
x, y = grid.dimensions
# Float32 complex is called complex64 in numpy
u = Function(name="u", grid=grid, dtype=np.complex64)

eq = Eq(u, x + 1j*y + exp(1j + x.spacing))
# Currently wrong alias type
op = Operator(eq)
op()

# Check against numpy
dx = grid.spacing_map[x.spacing]
xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5))
npres = xx + 1j*yy + np.exp(1j + dx)

assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0)


class TestPassesOptional(object):

Expand Down
2 changes: 1 addition & 1 deletion tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def test_complex(self):

eq = Eq(u, x + 1j*y + exp(1j + x.spacing))
# Currently wrong alias type
op = Operator(eq, opt='noop')
op = Operator(eq)
op()

# Check against numpy
Expand Down

0 comments on commit ed90efb

Please sign in to comment.