Skip to content

Commit

Permalink
symbolics: specialize sizeof
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 16, 2025
1 parent a38b3ef commit b3b39af
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion devito/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def reinit_compiler(val):
# Setup target platform and compiler
configuration.add('platform', 'cpu64', list(platform_registry),
callback=lambda i: platform_registry[i]())
configuration.add('compiler', 'custom', list(compiler_registry),
configuration.add('compiler', 'custom', compiler_registry,
callback=lambda i: compiler_registry[i](name=i))

# Setup language for shared-memory parallelism
Expand Down
6 changes: 5 additions & 1 deletion devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,12 +988,16 @@ class CompilerRegistry(dict):
"""

def __getitem__(self, key):
if isinstance(key, Compiler):
key = key.name

if key.startswith('gcc-'):
i = key.split('-')[1]
return partial(GNUCompiler, suffix=i)

return super().__getitem__(key)

def __contains__(self, k):
def __contains__(self, key):
if isinstance(k, Compiler):
k = k.name
return k in self.keys() or k.startswith('gcc-')
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
SizeOf, VOID, Keyword, pow_to_mul)
SizeOf, VOID, pow_to_mul)
from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten
from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap,
DeviceRM, Eq, Symbol)
Expand Down Expand Up @@ -279,7 +279,7 @@ def _alloc_pointed_array_on_high_bw_mem(self, site, obj, storage):

memptr = VOID(Byref(obj._C_symbol), '**')
alignment = obj._data_alignment
nbytes = SizeOf(Keyword('%s*' % obj._C_typedata))*obj.dim.symbolic_size
nbytes = SizeOf(obj._C_typedata, stars='*')*obj.dim.symbolic_size
alloc0 = self.lang['host-alloc'](memptr, alignment, nbytes)

free0 = self.lang['host-free'](obj._C_symbol)
Expand Down
5 changes: 3 additions & 2 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,11 +766,12 @@ def __new__(cls, base=''):
# DefFunction, unlike sympy.Function, generates e.g. `sizeof(float)`, not `sizeof(float_)`
class SizeOf(DefFunction):

__rargs__ = ('intype',)
__rargs__ = ('intype', 'stars')

def __new__(cls, intype, **kwargs):
def __new__(cls, intype, stars=None, **kwargs):
newobj = super().__new__(cls, 'sizeof', arguments=[str(intype)], **kwargs)
newobj.intype = intype
newobj.stars = stars or ''

return newobj

Expand Down
4 changes: 3 additions & 1 deletion devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,9 @@ def _print_DefFunction(self, expr):
template = ''
return "%s%s(%s)" % (expr.name, template, ','.join(arguments))

def _print_SizeOf(self, expr):
return f'sizeof({self._print(expr.intype)}{self._print(expr.stars)})'

_print_MathFunction = _print_DefFunction

def _print_Fallback(self, expr):
Expand All @@ -359,7 +362,6 @@ def _print_Fallback(self, expr):
_print_IndexSum = _print_Fallback
_print_ReservedWord = _print_Fallback
_print_Basic = _print_Fallback
_print_SizeOf = _print_DefFunction


# Lifted from SymPy so that we go through our own `_print_math_func`
Expand Down

0 comments on commit b3b39af

Please sign in to comment.