Skip to content

Commit

Permalink
compiler: process dtypes through printer
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 15, 2025
1 parent 81c5dc5 commit 086f142
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 207 deletions.
5 changes: 0 additions & 5 deletions devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,6 @@ def first_derivative(expr, dim, fd_order, **kwargs):
return generic_derivative(expr, dim, fd_order, 1, **kwargs)


# Backward compatibility
def first_derivative(expr, dim, fd_order, **kwargs):
return generic_derivative(expr, dim, fd_order, 1, **kwargs)


def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coefficients,
expand, weights=None):
# Always expand time derivatives to avoid issue with buffering and streaming.
Expand Down
63 changes: 34 additions & 29 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
Call, Lambda, BlankLine, Section, ListMajor)
from devito.ir.support.space import Backward
from devito.symbolics import (FieldFromComposite, FieldFromPointer,
ListInitializer, ccode, uxreplace)
ListInitializer, uxreplace)
from devito.symbolics.printer import _DevitoPrinterBase
from devito.symbolics.extended_dtypes import NoDeclStruct
from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered,
from devito.tools import (GenericVisitor, as_tuple, filter_ordered,
filter_sorted, flatten, is_external_ctype,
c_restrict_void_p, sorted_priority)
from devito.types.basic import AbstractFunction, AbstractSymbol, Basic
Expand Down Expand Up @@ -176,9 +177,10 @@ class CGen(Visitor):
Return a representation of the Iteration/Expression tree as a :module:`cgen` tree.
"""

def __init__(self, *args, compiler=None, **kwargs):
def __init__(self, *args, compiler=None, printer=None, **kwargs):
super().__init__(*args, **kwargs)
self._compiler = compiler or configuration['compiler']
self._printer = printer or _DevitoPrinterBase

# The following mappers may be customized by subclasses (that is,
# backend-specific CGen-erators)
Expand All @@ -194,6 +196,9 @@ def __init__(self, *args, compiler=None, **kwargs):
def compiler(self):
return self._compiler

def ccode(self, expr, **settings):
return self._printer(settings=settings).doprint(expr, None)

def visit(self, o, *args, **kwargs):
# Make sure the visitor is always within the generating compiler
# in case the configuration is accessed
Expand Down Expand Up @@ -233,7 +238,7 @@ def _gen_struct_decl(self, obj, masked=()):
try:
entries.append(self._gen_value(i, 0, masked=('const',)))
except AttributeError:
cstr = ctypes_to_cstr(ct)
cstr = self.ccode(ct)
if ct is c_restrict_void_p:
cstr = '%srestrict' % cstr
entries.append(c.Value(cstr, n))
Expand All @@ -255,10 +260,10 @@ def _gen_value(self, obj, mode=1, masked=()):
if getattr(obj.function, k, False) and v not in masked]

if (obj._mem_stack or obj._mem_constant) and mode == 1:
strtype = obj._C_typedata
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
strtype = self.ccode(obj._C_typedata)
strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape)
else:
strtype = ctypes_to_cstr(obj._C_ctype)
strtype = self.ccode(obj._C_ctype)
strshape = ''
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
if not obj._mem_stack:
Expand All @@ -272,7 +277,7 @@ def _gen_value(self, obj, mode=1, masked=()):
strobj = '%s%s' % (strname, strshape)

if obj.is_LocalObject and obj.cargs and mode == 1:
arguments = [ccode(i) for i in obj.cargs]
arguments = [self.ccode(i) for i in obj.cargs]
strobj = MultilineCall(strobj, arguments, True)

value = c.Value(strtype, strobj)
Expand All @@ -286,9 +291,9 @@ def _gen_value(self, obj, mode=1, masked=()):
if obj.is_Array and obj.initvalue is not None and mode == 1:
init = ListInitializer(obj.initvalue)
if not obj._mem_constant or init.is_numeric:
value = c.Initializer(value, ccode(init))
value = c.Initializer(value, self.ccode(init))
elif obj.is_LocalObject and obj.initvalue is not None and mode == 1:
value = c.Initializer(value, ccode(obj.initvalue))
value = c.Initializer(value, self.ccode(obj.initvalue))

return value

Expand Down Expand Up @@ -322,7 +327,7 @@ def _args_call(self, args):
else:
ret.append(i._C_name)
except AttributeError:
ret.append(ccode(i))
ret.append(self.ccode(i))
return ret

def _gen_signature(self, o, is_declaration=False):
Expand Down Expand Up @@ -388,7 +393,7 @@ def visit_tuple(self, o):
def visit_PointerCast(self, o):
f = o.function
i = f.indexed
cstr = i._C_typedata
cstr = self.ccode(i._C_typedata)

if f.is_PointerArray:
# lvalue
Expand All @@ -410,7 +415,7 @@ def visit_PointerCast(self, o):
else:
v = f.name
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
else:
Expand Down Expand Up @@ -443,9 +448,9 @@ def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
cstr = i._C_typedata
cstr = self.ccode(i._C_typedata)
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
a1.dim.name)
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
Expand Down Expand Up @@ -484,8 +489,8 @@ def visit_Definition(self, o):
return self._gen_value(o.function)

def visit_Expression(self, o):
lhs = ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler)
rhs = ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler)
lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler)
rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler)

if o.init:
code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs)
Expand All @@ -498,8 +503,8 @@ def visit_Expression(self, o):
return code

def visit_AugmentedExpression(self, o):
c_lhs = ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler)
c_rhs = ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler)
c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler)
c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler)
code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs))
if o.pragmas:
code = c.Module(self._visit(o.pragmas) + (code,))
Expand All @@ -518,7 +523,7 @@ def visit_Call(self, o, nested_call=False):
o.templates)
if retobj.is_Indexed or \
isinstance(retobj, (FieldFromComposite, FieldFromPointer)):
return c.Assign(ccode(retobj), call)
return c.Assign(self.ccode(retobj), call)
else:
return c.Initializer(c.Value(rettype, retobj._C_name), call)

Expand All @@ -532,9 +537,9 @@ def visit_Conditional(self, o):
then_body = c.Block(self._visit(then_body))
if else_body:
else_body = c.Block(self._visit(else_body))
return c.If(ccode(o.condition), then_body, else_body)
return c.If(self.ccode(o.condition), then_body, else_body)
else:
return c.If(ccode(o.condition), then_body)
return c.If(self.ccode(o.condition), then_body)

def visit_Iteration(self, o):
body = flatten(self._visit(i) for i in self._blankline_logic(o.children))
Expand All @@ -544,23 +549,23 @@ def visit_Iteration(self, o):

# For backward direction flip loop bounds
if o.direction == Backward:
loop_init = 'int %s = %s' % (o.index, ccode(_max))
loop_cond = '%s >= %s' % (o.index, ccode(_min))
loop_init = 'int %s = %s' % (o.index, self.ccode(_max))
loop_cond = '%s >= %s' % (o.index, self.ccode(_min))
loop_inc = '%s -= %s' % (o.index, o.limits[2])
else:
loop_init = 'int %s = %s' % (o.index, ccode(_min))
loop_cond = '%s <= %s' % (o.index, ccode(_max))
loop_init = 'int %s = %s' % (o.index, self.ccode(_min))
loop_cond = '%s <= %s' % (o.index, self.ccode(_max))
loop_inc = '%s += %s' % (o.index, o.limits[2])

# Append unbounded indices, if any
if o.uindices:
uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices]
uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices]
loop_init = c.Line(', '.join([loop_init] + uinit))

ustep = []
for i in o.uindices:
op = '=' if i.is_Modulo else '+='
ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr)))
ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr)))
loop_inc = c.Line(', '.join([loop_inc] + ustep))

# Create For header+body
Expand All @@ -577,7 +582,7 @@ def visit_Pragma(self, o):
return c.Pragma(o._generate)

def visit_While(self, o):
condition = ccode(o.condition)
condition = self.ccode(o.condition)
if o.body:
body = flatten(self._visit(i) for i in o.children)
return c.While(condition, c.Block(body))
Expand Down
20 changes: 8 additions & 12 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate,
error_mapper, is_on_device)
from devito.passes.iet.langbase import LangBB
from devito.symbolics import estimate_cost, subs_op_args
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
flatten, filter_sorted, frozendict, is_integer,
split, timed_pass, timed_region, contains_val)
from devito.tools.dtypes_lowering import ctypes_vector_mapper
from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer,
disk_layer)
from devito.types.dimension import Thickness
Expand Down Expand Up @@ -272,9 +270,6 @@ def _lower(cls, expressions, **kwargs):
# expression for which a partial or complete lowering is desired
kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs)

# Load language-specific types into the global dtype->ctype mapper
cls._load_dtype_mappings(**kwargs)

# [Eq] -> [LoweredEq]
expressions = cls._lower_exprs(expressions, **kwargs)

Expand All @@ -296,11 +291,6 @@ def _lower(cls, expressions, **kwargs):
def _rcompile_wrapper(cls, **kwargs0):
raise NotImplementedError

@classmethod
def _load_dtype_mappings(cls, **kwargs):
lang: type[LangBB] = cls._Target.DataManager.lang
ctypes_vector_mapper.update(lang.mapper.get('types', {}))

@classmethod
def _initialize_state(cls, **kwargs):
return {}
Expand Down Expand Up @@ -764,13 +754,19 @@ def _soname(self):
"""A unique name for the shared object resulting from JIT compilation."""
return Signer._digest(self, configuration)

@property
def printer(self):
return self._Target.Printer

@cached_property
def ccode(self):
try:
return self._ccode_handler(compiler=self._compiler).visit(self)
return self._ccode_handler(compiler=self._compiler,
printer=self.printer).visit(self)
except (AttributeError, TypeError):
from devito.ir.iet.visitors import CGen
return CGen(compiler=self._compiler).visit(self)
return CGen(compiler=self._compiler,
printer=self.printer).visit(self)

def _jit_compile(self):
"""
Expand Down
32 changes: 2 additions & 30 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
import numpy as np

from devito.arch.compiler import Compiler
from devito.ir import Callable, Dereference, FindSymbols, Node, SymbolRegistry, Uxreplace
from devito.ir import Callable, FindSymbols, SymbolRegistry
from devito.passes.iet.langbase import LangBB
from devito.symbolics.extended_dtypes import Float16P
from devito.tools import as_list
from devito.types.basic import AbstractSymbol, Basic, Symbol

__all__ = ['lower_dtypes']

Expand All @@ -17,34 +14,9 @@ def lower_dtypes(iet: Callable, lang: type[LangBB], compiler: Compiler,
Lowers float16 scalar types to pointers since we can't directly pass their
value. Also includes headers for complex arithmetic if needed.
"""

# Complex numbers
iet, metadata = _complex_includes(iet, lang, compiler)

# Lower float16 parameters to pointers and dereference
prefix: list[Node] = []
params_mapper: dict[AbstractSymbol, AbstractSymbol] = {}
body_mapper: dict[AbstractSymbol, Symbol] = {}

params_set = set(iet.parameters)
s: AbstractSymbol
for s in FindSymbols('abstractsymbols').visit(iet):
if s.dtype != np.float16 or s not in params_set:
continue

# Replace the parameter with a pointer; replace occurrences in the IET
# body with dereferenced symbol (using the original symbol's dtype)
ptr: AbstractSymbol = s._rebuild(dtype=Float16P, is_const=True)
val = Symbol(name=sregistry.make_name(prefix='hf'), dtype=s.dtype,
is_const=s.is_const)

params_mapper[s], body_mapper[s] = ptr, val
prefix.append(Dereference(val, ptr)) # val = *ptr

# Apply the replacements
prefix.extend(as_list(Uxreplace(body_mapper).visit(iet.body)))
params: tuple[Basic] = Uxreplace(params_mapper).visit(iet.parameters)

iet = iet._rebuild(body=prefix, parameters=params)
return iet, metadata


Expand Down
29 changes: 12 additions & 17 deletions devito/passes/iet/languages/C.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
import numpy as np

from devito.ir import Call
from devito.passes.iet.definitions import DataManager
from devito.passes.iet.orchestration import Orchestrator
from devito.passes.iet.langbase import LangBB
from devito.symbolics.extended_dtypes import (Float16P, c_complex, c_double_complex,
c_half, c_half_p)


__all__ = ['CBB', 'CDataManager', 'COrchestrator', 'c_float16', 'c_float16_p']

from devito.symbolics.extended_dtypes import c_complex, c_double_complex
from devito.symbolics.printer import _DevitoPrinterBase

c99_complex = type('_Complex float', (c_complex,), {})
c99_double_complex = type('_Complex double', (c_double_complex,), {})

c_float16 = type('_Float16', (c_half,), {})
c_float16_p = type('_Float16 *', (c_half_p,), {'_type_': c_float16})
__all__ = ['CBB', 'CDataManager', 'COrchestrator']


class CBB(LangBB):
Expand All @@ -34,10 +24,6 @@ class CBB(LangBB):
Call('memcpy', (i, j, k)),
# Complex and float16
'header-complex': 'complex.h',
'types': {np.complex128: c99_double_complex,
np.complex64: c99_complex,
np.float16: c_float16,
Float16P: c_float16_p}
}


Expand All @@ -47,3 +33,12 @@ class CDataManager(DataManager):

class COrchestrator(Orchestrator):
lang = CBB


class CDevitoPrinter(_DevitoPrinterBase):

# These cannot go through _print_xxx because they are classes not
# instances
type_mappings = {**_DevitoPrinterBase.type_mappings,
c_complex: 'float _Complex',
c_double_complex: 'double _Complex'}
Loading

0 comments on commit 086f142

Please sign in to comment.