Skip to content

Commit

Permalink
Merge pull request #467 from ecmwf-ifs/naml-symbolics-better-logical-…
Browse files Browse the repository at this point in the history
…eval

Expression: Remove logical literals in and/or operations during simplify
  • Loading branch information
reuterbal authored Jan 9, 2025
2 parents ad1e2bc + 0aa2956 commit 33c08c3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
20 changes: 12 additions & 8 deletions loki/expression/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,21 +585,25 @@ def map_comparison(self, expr, *args, **kwargs):

def map_logical_and(self, expr, *args, **kwargs):
children = tuple(self.rec(child, *args, **kwargs) for child in expr.children)
if all(isinstance(c, sym.LogicLiteral) for c in children):
if all(c == 'True' for c in children):
return sym.LogicLiteral('True')
return sym.LogicLiteral('False')
if self.enabled_simplifications & Simplification.LogicEvaluation:
if any(c == 'False' for c in children):
return sym.LogicLiteral('False')
if any(c == 'True' for c in children):
# Trim all literals and return .true. if all were .true.
children = tuple(c for c in children if not c == 'True')

return sym.LogicalAnd(children)
return sym.LogicalAnd(children) if len(children) > 0 else sym.LogicLiteral('True')

def map_logical_or(self, expr, *args, **kwargs):
children = tuple(self.rec(child, *args, **kwargs) for child in expr.children)
if all(isinstance(c, sym.LogicLiteral) for c in children):
if self.enabled_simplifications & Simplification.LogicEvaluation:
if any(c == 'True' for c in children):
return sym.LogicLiteral('True')
return sym.LogicLiteral('False')
if any(c == 'False' for c in children):
# Trim all literals and return .false. if all were .false.
children = tuple(c for c in children if not c == 'False')

return sym.LogicalOr(children)
return sym.LogicalOr(children) if len(children) > 0 else sym.LogicLiteral('False')


def simplify(expr, enabled_simplifications=Simplification.ALL):
Expand Down
4 changes: 4 additions & 0 deletions loki/expression/tests/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ def test_simplify_collect_coefficients(source, ref):
('.false. .or. .false.', 'False'),
('2 == 1 .and. 1 == 1', 'False'),
('2 == 1 .or. 1 == 1', 'True'),
('.true. .or. a', 'True'),
('.false. .or. a', 'a'),
('.false. .and. a', 'False'),
('.true. .and. a', 'a'),
])
def test_simplify_logic_evaluation(source, ref):
scope = Scope()
Expand Down

0 comments on commit 33c08c3

Please sign in to comment.