From edea3b73c22f0d2bae997fc109ce915da9508a82 Mon Sep 17 00:00:00 2001 From: ZoeLeibowitz Date: Tue, 7 Jan 2025 14:02:02 +0000 Subject: [PATCH] types: Fix dtype of FFP and edit sympy_dtype --- devito/symbolics/extended_sympy.py | 4 ++++ devito/symbolics/inspection.py | 9 ++++----- tests/test_symbolics.py | 17 ++++++++++++++--- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 4087bbc72c..a31fc9986e 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -252,6 +252,10 @@ def __str__(self): def field(self): return self.call + @property + def dtype(self): + return self.field.dtype + __repr__ = __str__ diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 437d48fff0..411faee26c 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -295,9 +295,8 @@ def sympy_dtype(expr, base=None): Infer the dtype of the expression. """ dtypes = {base} - {None} - for i in expr.free_symbols: - try: - dtypes.add(i.dtype) - except AttributeError: - pass + for i in expr.args: + dtype = getattr(i, 'dtype', None) + if dtype: + dtypes.add(dtype) return infer_dtype(dtypes) diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 7beb0c0b97..61fb0daef0 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -13,10 +13,10 @@ CallFromPointer, Cast, DefFunction, FieldFromPointer, INT, FieldFromComposite, IntDiv, Namespace, Rvalue, ReservedWord, ListInitializer, ccode, uxreplace, - retrieve_derivatives) + retrieve_derivatives, sympy_dtype) from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, - Symbol as dSymbol) + Symbol as dSymbol, CompositeObject) from devito.types.basic import AbstractSymbol @@ -248,6 +248,17 @@ def test_field_from_pointer(): # Free symbols assert ffp1.free_symbols == {s} + # Test dtype + f = dSymbol('f') + pfields = [(f._C_name, f._C_ctype)] + struct = CompositeObject('s1', 'myStruct', pfields) + ffp4 = FieldFromPointer(f, struct) + assert str(ffp4) == 's1->f' + assert ffp4.dtype == f.dtype + expr = 1/ffp4 + dtype = sympy_dtype(expr) + assert dtype == f.dtype + def test_field_from_composite(): s = Symbol('s') @@ -292,7 +303,7 @@ def test_extended_sympy_arithmetic(): # noncommutative o = Object(name='o', dtype=c_void_p) bar = FieldFromPointer('bar', o) - assert ccode(-1 + bar) == '-1 + o->bar' + assert ccode(-1 + bar) == 'o->bar - 1' def test_integer_abs():