From 0aa295619fd0aa695c93ebe8cc46a958aa056b48 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 8 Jan 2025 08:52:40 +0000 Subject: [PATCH] Expression: Remove logical literals in and/or ops during simplify The current implementation would only resolve .and./.or. operations if all arguments are logical literals. Insterad, we can evaluate `.true. and a` => `.true.` and `.false. and a` => `a`, and vice versa for `or` operations. --- loki/expression/symbolic.py | 20 ++++++++++++-------- loki/expression/tests/test_symbolic.py | 4 ++++ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/loki/expression/symbolic.py b/loki/expression/symbolic.py index effb4ef94..e8c5fb09b 100644 --- a/loki/expression/symbolic.py +++ b/loki/expression/symbolic.py @@ -585,21 +585,25 @@ def map_comparison(self, expr, *args, **kwargs): def map_logical_and(self, expr, *args, **kwargs): children = tuple(self.rec(child, *args, **kwargs) for child in expr.children) - if all(isinstance(c, sym.LogicLiteral) for c in children): - if all(c == 'True' for c in children): - return sym.LogicLiteral('True') - return sym.LogicLiteral('False') + if self.enabled_simplifications & Simplification.LogicEvaluation: + if any(c == 'False' for c in children): + return sym.LogicLiteral('False') + if any(c == 'True' for c in children): + # Trim all literals and return .true. if all were .true. + children = tuple(c for c in children if not c == 'True') - return sym.LogicalAnd(children) + return sym.LogicalAnd(children) if len(children) > 0 else sym.LogicLiteral('True') def map_logical_or(self, expr, *args, **kwargs): children = tuple(self.rec(child, *args, **kwargs) for child in expr.children) - if all(isinstance(c, sym.LogicLiteral) for c in children): + if self.enabled_simplifications & Simplification.LogicEvaluation: if any(c == 'True' for c in children): return sym.LogicLiteral('True') - return sym.LogicLiteral('False') + if any(c == 'False' for c in children): + # Trim all literals and return .false. if all were .false. + children = tuple(c for c in children if not c == 'False') - return sym.LogicalOr(children) + return sym.LogicalOr(children) if len(children) > 0 else sym.LogicLiteral('False') def simplify(expr, enabled_simplifications=Simplification.ALL): diff --git a/loki/expression/tests/test_symbolic.py b/loki/expression/tests/test_symbolic.py index d0f2c34ef..16f1f0a35 100644 --- a/loki/expression/tests/test_symbolic.py +++ b/loki/expression/tests/test_symbolic.py @@ -197,6 +197,10 @@ def test_simplify_collect_coefficients(source, ref): ('.false. .or. .false.', 'False'), ('2 == 1 .and. 1 == 1', 'False'), ('2 == 1 .or. 1 == 1', 'True'), + ('.true. .or. a', 'True'), + ('.false. .or. a', 'a'), + ('.false. .and. a', 'False'), + ('.true. .and. a', 'a'), ]) def test_simplify_logic_evaluation(source, ref): scope = Scope()