diff --git a/loki/expression/symbolic.py b/loki/expression/symbolic.py index effb4ef94..e8c5fb09b 100644 --- a/loki/expression/symbolic.py +++ b/loki/expression/symbolic.py @@ -585,21 +585,25 @@ def map_comparison(self, expr, *args, **kwargs): def map_logical_and(self, expr, *args, **kwargs): children = tuple(self.rec(child, *args, **kwargs) for child in expr.children) - if all(isinstance(c, sym.LogicLiteral) for c in children): - if all(c == 'True' for c in children): - return sym.LogicLiteral('True') - return sym.LogicLiteral('False') + if self.enabled_simplifications & Simplification.LogicEvaluation: + if any(c == 'False' for c in children): + return sym.LogicLiteral('False') + if any(c == 'True' for c in children): + # Trim all literals and return .true. if all were .true. + children = tuple(c for c in children if not c == 'True') - return sym.LogicalAnd(children) + return sym.LogicalAnd(children) if len(children) > 0 else sym.LogicLiteral('True') def map_logical_or(self, expr, *args, **kwargs): children = tuple(self.rec(child, *args, **kwargs) for child in expr.children) - if all(isinstance(c, sym.LogicLiteral) for c in children): + if self.enabled_simplifications & Simplification.LogicEvaluation: if any(c == 'True' for c in children): return sym.LogicLiteral('True') - return sym.LogicLiteral('False') + if any(c == 'False' for c in children): + # Trim all literals and return .false. if all were .false. + children = tuple(c for c in children if not c == 'False') - return sym.LogicalOr(children) + return sym.LogicalOr(children) if len(children) > 0 else sym.LogicLiteral('False') def simplify(expr, enabled_simplifications=Simplification.ALL): diff --git a/loki/expression/tests/test_symbolic.py b/loki/expression/tests/test_symbolic.py index d0f2c34ef..16f1f0a35 100644 --- a/loki/expression/tests/test_symbolic.py +++ b/loki/expression/tests/test_symbolic.py @@ -197,6 +197,10 @@ def test_simplify_collect_coefficients(source, ref): ('.false. .or. .false.', 'False'), ('2 == 1 .and. 1 == 1', 'False'), ('2 == 1 .or. 1 == 1', 'True'), + ('.true. .or. a', 'True'), + ('.false. .or. a', 'a'), + ('.false. .and. a', 'False'), + ('.true. .and. a', 'a'), ]) def test_simplify_logic_evaluation(source, ref): scope = Scope()