Skip to content

Commit

Permalink
feat(generic): add simplified type inference mechanism ExpressionType…
Browse files Browse the repository at this point in the history
…Mapper
  • Loading branch information
quepas committed Jul 2, 2024
1 parent 24564a0 commit fe2a819
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


__all__ = ['LokiStringifyMapper', 'ExpressionRetriever', 'ExpressionDimensionsMapper',
'ExpressionCallbackMapper', 'SubstituteExpressionsMapper',
'ExpressionTypeMapper', 'ExpressionCallbackMapper', 'SubstituteExpressionsMapper',
'LokiIdentityMapper', 'AttachScopesMapper', 'DetachScopesMapper']


Expand Down Expand Up @@ -426,6 +426,43 @@ def map_inline_do(self, expr, *args, **kwargs):
return self.rec(expr.bounds, *args, **kwargs)


class ExpressionTypeMapper(Mapper):
"""
A visitor for an expression that determines the type of the expression.
This is a WIP implementation (missing, e.g.: handling of kinds, implicit type conversions)
"""
# pylint: disable=abstract-method,unused-argument

def map_float_literal(self, expr, *args, **kwargs):
return BasicType.REAL

def map_int_literal(self, expr, *args, **kwargs):
return BasicType.INTEGER

def map_logic_literal(self, expr, *args, **kwargs):
return BasicType.LOGICAL

def map_string_literal(self, expr, *args, **kwargs):
return BasicType.CHARACTER

def map_scalar(self, expr, *args, **kwargs):
return expr.type.dtype

map_array = map_scalar

def map_sum(self, expr, *args, **kwargs):
left = self.rec(expr.children[0], *args, **kwargs)
right = self.rec(expr.children[1], *args, **kwargs)
# INTEGER can be promoted to REAL
if left == BasicType.REAL and right == BasicType.INTEGER \
or left == BasicType.INTEGER and right == BasicType.REAL:
return BasicType.REAL
if left != right:
raise ValueError(f'Non-matching types: {str(left)} and {str(right)}')
return left

map_product = map_sum

class ExpressionCallbackMapper(CombineMapper):
"""
A visitor for expressions that returns the combined result of a specified callback function.
Expand Down

0 comments on commit fe2a819

Please sign in to comment.