From 93caaaa5b09dfc52d1b3ad4169eec65c8c3a2c88 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 15 Jan 2025 15:49:41 +0000 Subject: [PATCH] compiler: Fixup pow_to_mul --- devito/symbolics/manipulation.py | 2 +- tests/test_dse.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 6ca746adcc..f5992ac8be 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -334,7 +334,7 @@ def pow_to_mul(expr): # but at least we traverse the base looking for other Pows return expr.func(pow_to_mul(base), exp, evaluate=False) elif exp > 0: - return Mul(*[base]*int(exp), evaluate=False) + return Mul(*[pow_to_mul(base)]*int(exp), evaluate=False) else: # SymPy represents 1/x as Pow(x,-1). Also, it represents # 2/x as Mul(2, Pow(x, -1)). So we shouldn't end up here, diff --git a/tests/test_dse.py b/tests/test_dse.py index bde9c1b27e..06e00ab182 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -74,6 +74,7 @@ def test_scheduling_after_rewrite(): ('sqrt(fa[x]**4)', 'sqrt(fa[x]*fa[x]*fa[x]*fa[x])'), ('sqrt(fa[x])**2', 'fa[x]'), ('fa[x]**-2', '1/(fa[x]*fa[x])'), + ('cos(fa[x]*fa[x])*cos(fa[x]*fa[x])', 'cos(fa[x]*fa[x])*cos(fa[x]*fa[x])'), ]) def test_pow_to_mul(expr, expected): grid = Grid((4, 5))