Skip to content

Commit

Permalink
Loki: Cleaning up imports in subroutine tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Nov 25, 2024
1 parent 140fe54 commit 9fcbb07
Showing 1 changed file with 49 additions and 49 deletions.
98 changes: 49 additions & 49 deletions loki/tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
import pytest
import numpy as np

from loki import (
Sourcefile, Module, Subroutine, FindVariables, FindNodes, Section,
Array, Scalar, Variable,
SymbolAttributes, StringLiteral, fgen, fexprgen,
VariableDeclaration, Transformer, FindTypedSymbols,
ProcedureSymbol, StatementFunction, DeferredTypeSymbol
)
from loki import Sourcefile, Module, Subroutine, fgen, fexprgen
from loki.build import jit_compile, jit_compile_lib, clean_test
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI, REGEX
from loki.types import BasicType, DerivedType, ProcedureType
from loki.ir import nodes as ir
from loki.ir import (
nodes as ir, FindNodes, FindVariables, FindTypedSymbols,
Transformer
)
from loki.types import (
BasicType, DerivedType, ProcedureType, SymbolAttributes
)


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -56,8 +56,8 @@ def test_routine_simple(tmp_path, frontend):

# Test the internals of the subroutine
routine = Subroutine.from_source(fcode, frontend=frontend)
assert isinstance(routine.body, Section)
assert isinstance(routine.spec, Section)
assert isinstance(routine.body, ir.Section)
assert isinstance(routine.spec, ir.Section)
assert len(routine.docstring) == 1
assert routine.docstring[0].text == '! This is the docstring'
assert routine.definitions == ()
Expand Down Expand Up @@ -216,9 +216,9 @@ def test_routine_arguments_add_remove(frontend):
# Create a new set of variables and add to local routine variables
x = routine.variables[1] # That's the symbol for variable 'x'
real_type = routine.symbol_attrs['scalar'] # Type of variable 'maximum'
a = Scalar(name='a', type=real_type, scope=routine)
b = Array(name='b', dimensions=(x, ), type=real_type, scope=routine)
c = Variable(name='c', type=x.type, scope=routine)
a = sym.Scalar(name='a', type=real_type, scope=routine)
b = sym.Array(name='b', dimensions=(x, ), type=real_type, scope=routine)
c = sym.Variable(name='c', type=x.type, scope=routine)

# Add new arguments and check that they are all in the routine spec
routine.arguments += (a, b, c)
Expand Down Expand Up @@ -375,9 +375,9 @@ def test_routine_variables_add_remove(frontend):
x = routine.variable_map['x'] # That's the symbol for variable 'x'
real_type = SymbolAttributes('real', kind=routine.variable_map['jprb'])
int_type = SymbolAttributes('integer')
a = Scalar(name='a', type=real_type, scope=routine)
b = Array(name='b', dimensions=(x, ), type=real_type, scope=routine)
c = Variable(name='c', type=int_type, scope=routine)
a = sym.Scalar(name='a', type=real_type, scope=routine)
b = sym.Array(name='b', dimensions=(x, ), type=real_type, scope=routine)
c = sym.Variable(name='c', type=int_type, scope=routine)

# Add new variables and check that they are all in the routine spec
routine.variables += (a, b, c)
Expand Down Expand Up @@ -493,17 +493,17 @@ def test_routine_variables_dim_shapes(frontend):
assert routine.arguments == ('v1', 'v2', 'v3(:)', 'v4(v1, v2)', 'v5(0:v1, v2 - 1)')

# Make sure variable/argument shapes on the routine work
shapes = [fexprgen(v.shape) for v in routine.arguments if isinstance(v, Array)]
shapes = [fexprgen(v.shape) for v in routine.arguments if isinstance(v, sym.Array)]
assert shapes == ['(v1,)', '(v1, v2)', '(0:v1, v2 - 1)']

# Ensure that all spec variables (including dimension symbols) are scoped correctly
spec_vars = [v for v in FindVariables(unique=False).visit(routine.spec) if v.name.lower() != 'selected_real_kind']
assert all(v.scope == routine for v in spec_vars)
assert all(isinstance(v, (Scalar, Array)) for v in spec_vars)
assert all(isinstance(v, (sym.Scalar, sym.Array)) for v in spec_vars)

# Ensure shapes of body variables are ok
b_shapes = [fexprgen(v.shape) for v in FindVariables(unique=False).visit(routine.body)
if isinstance(v, Array)]
if isinstance(v, sym.Array)]
assert b_shapes == ['(v1,)', '(v1,)', '(v1, v2)', '(0:v1, v2 - 1)']


Expand Down Expand Up @@ -540,7 +540,7 @@ def test_routine_variables_shape_propagation(tmp_path, header_path, frontend):

# Verify that all variable instances have type and shape information
variables = FindVariables().visit(routine.body)
assert all(v.shape is not None for v in variables if isinstance(v, Array))
assert all(v.shape is not None for v in variables if isinstance(v, sym.Array))

vmap = {v.name: v for v in variables}
assert fexprgen(vmap['vector'].shape) == '(x,)'
Expand Down Expand Up @@ -576,7 +576,7 @@ def test_routine_variables_shape_propagation(tmp_path, header_path, frontend):

# Verify that all derived type variables have shape info
variables = FindVariables().visit(routine.body)
assert all(v.shape is not None for v in variables if isinstance(v, Array))
assert all(v.shape is not None for v in variables if isinstance(v, sym.Array))

# Verify shape info from imported derived type is propagated
vmap = {v.name: v for v in variables}
Expand Down Expand Up @@ -656,7 +656,7 @@ def test_routine_type_propagation(header_path, frontend, tmp_path):

# Verify that all variable instances have type information
variables = FindVariables().visit(routine.body)
assert all(v.type is not None for v in variables if isinstance(v, (Scalar, Array)))
assert all(v.type is not None for v in variables if isinstance(v, (sym.Scalar, sym.Array)))

vmap = {v.name: v for v in variables}
assert vmap['x'].type.dtype == BasicType.INTEGER
Expand Down Expand Up @@ -744,11 +744,11 @@ def test_routine_call_arrays(header_path, frontend, tmp_path):
assert str(call.arguments[3]) == 'matrix'
assert str(call.arguments[4]) == 'item%matrix'

assert isinstance(call.arguments[0], Scalar)
assert isinstance(call.arguments[1], Scalar)
assert isinstance(call.arguments[2], Array)
assert isinstance(call.arguments[3], Array)
assert isinstance(call.arguments[4], Array)
assert isinstance(call.arguments[0], sym.Scalar)
assert isinstance(call.arguments[1], sym.Scalar)
assert isinstance(call.arguments[2], sym.Array)
assert isinstance(call.arguments[3], sym.Array)
assert isinstance(call.arguments[4], sym.Array)

assert fexprgen(call.arguments[2].shape) == '(x,)'
assert fexprgen(call.arguments[3].shape) == '(x, y)'
Expand Down Expand Up @@ -791,10 +791,10 @@ def test_call_kwargs(frontend):
assert all(isinstance(arg, tuple) and len(arg) == 2 for arg in calls[0].kwarguments)

assert calls[0].kwarguments[0][0] == 'kprocs'
assert (isinstance(calls[0].kwarguments[0][1], Scalar) and
assert (isinstance(calls[0].kwarguments[0][1], sym.Scalar) and
calls[0].kwarguments[0][1].name == 'kprocs')

assert calls[0].kwarguments[1] == ('cdstring', StringLiteral('routine_call_kwargs'))
assert calls[0].kwarguments[1] == ('cdstring', sym.StringLiteral('routine_call_kwargs'))


@pytest.mark.parametrize('frontend', available_frontends())
Expand All @@ -812,7 +812,7 @@ def test_call_args_kwargs(frontend):
assert calls[0].name == 'mpl_send'
assert len(calls[0].arguments) == 3
assert all(a.name == b.name for a, b in zip(calls[0].arguments, routine.arguments))
assert calls[0].kwarguments == (('cdstring', StringLiteral('routine_call_args_kwargs')),)
assert calls[0].kwarguments == (('cdstring', sym.StringLiteral('routine_call_args_kwargs')),)


@pytest.mark.parametrize('frontend', available_frontends())
Expand Down Expand Up @@ -1178,15 +1178,15 @@ def test_external_stmt(tmp_path, frontend):
routine = source['routine_external_stmt']
assert len(routine.arguments) == 8

for decl in FindNodes(VariableDeclaration).visit(routine.spec):
for decl in FindNodes(ir.VariableDeclaration).visit(routine.spec):
# Skip local variables
if decl.symbols[0].name in ('invar', 'outvar', 'tmp'):
continue
# Is the EXTERNAL attribute set?
assert decl.external
for v in decl.symbols:
# Are procedure names represented as Scalar objects?
assert isinstance(v, ProcedureSymbol)
assert isinstance(v, sym.ProcedureSymbol)
assert isinstance(v.type.dtype, ProcedureType)
assert v.type.external is True
assert v.type.dtype.procedure == BasicType.DEFERRED
Expand Down Expand Up @@ -1499,12 +1499,12 @@ def test_subroutine_stmt_func(tmp_path, frontend):
# OMNI inlines statement functions, so we can only check correct representation
# for fparser
if frontend != OMNI:
stmt_func_decls = {d.variable: d for d in FindNodes(StatementFunction).visit(routine.spec)}
stmt_func_decls = {d.variable: d for d in FindNodes(ir.StatementFunction).visit(routine.spec)}
assert len(stmt_func_decls) == 3

for name in ('plus', 'minus', 'mult'):
var = routine.variable_map[name]
assert isinstance(var, ProcedureSymbol)
assert isinstance(var, sym.ProcedureSymbol)
assert isinstance(var.type.dtype, ProcedureType)
assert var.type.dtype.procedure is stmt_func_decls[var]
assert stmt_func_decls[var].source is not None
Expand All @@ -1530,8 +1530,8 @@ def test_mixed_declaration_interface(frontend):

with pytest.raises(AssertionError) as error:
routine = Subroutine.from_source(fcode, frontend=frontend)
assert isinstance(routine.body, Section)
assert isinstance(routine.spec, Section)
assert isinstance(routine.body, ir.Section)
assert isinstance(routine.spec, ir.Section)
_ = routine.interface

assert "Declarations must have intents" in str(error.value)
Expand All @@ -1556,7 +1556,7 @@ def test_subroutine_prefix(frontend):
assert routine.return_type.dtype is BasicType.REAL

assert routine.name in routine.symbol_map
decl = [d for d in FindNodes(VariableDeclaration).visit(routine.spec) if routine.name in d.symbols]
decl = [d for d in FindNodes(ir.VariableDeclaration).visit(routine.spec) if routine.name in d.symbols]
assert len(decl) == 1
decl = decl[0]

Expand Down Expand Up @@ -1752,15 +1752,15 @@ def test_subroutine_lazy_arguments_incomplete1(frontend):
assert routine.arguments == ()
assert routine.argnames == []
assert routine._dummies == ()
assert all(isinstance(arg, DeferredTypeSymbol) for arg in routine.arguments)
assert all(isinstance(arg, sym.DeferredTypeSymbol) for arg in routine.arguments)

routine.make_complete(frontend=frontend)
assert not routine._incomplete
assert routine.arguments == ('n', 'a(n)', 'b(n)', 'd(n)')
assert routine.argnames == ['n', 'a', 'b', 'd']
assert routine._dummies == ('n', 'a', 'b', 'd')
assert isinstance(routine.arguments[0], Scalar)
assert all(isinstance(arg, Array) for arg in routine.arguments[1:])
assert isinstance(routine.arguments[0], sym.Scalar)
assert all(isinstance(arg, sym.Array) for arg in routine.arguments[1:])


@pytest.mark.parametrize('frontend', available_frontends())
Expand Down Expand Up @@ -1841,15 +1841,15 @@ def test_subroutine_lazy_arguments_incomplete2(frontend):
assert routine.arguments == ()
assert routine.argnames == []
assert routine._dummies == ()
assert all(isinstance(arg, DeferredTypeSymbol) for arg in routine.arguments)
assert all(isinstance(arg, sym.DeferredTypeSymbol) for arg in routine.arguments)

routine.make_complete(frontend=frontend)
assert not routine._incomplete
assert routine.arguments == argnames_with_dim
assert [arg.upper() for arg in routine.argnames] == [arg.upper() for arg in argnames]
assert routine._dummies == argnames
assert all(isinstance(arg, Scalar) for arg in routine.arguments[:4])
assert all(isinstance(arg, Array) for arg in routine.arguments[4:])
assert all(isinstance(arg, sym.Scalar) for arg in routine.arguments[:4])
assert all(isinstance(arg, sym.Array) for arg in routine.arguments[4:])


@pytest.mark.parametrize('frontend', available_frontends())
Expand Down Expand Up @@ -2092,18 +2092,18 @@ def test_enrich_derived_types(tmp_path, frontend):
assert yda_array.type.dtype.typedef is field_3rb_tdef
assert yda_array_p.type.dtype is BasicType.REAL
assert yda_array_p.type.shape == (':', ':', ':')
assert isinstance(yda_array_p, Array)
assert isinstance(yda_array_p, sym.Array)

# Double-check body and spec expressions
decls = FindNodes(ir.VariableDeclaration).visit(routine.spec)
assert len(decls) == 1
assert len(decls[0].symbols) == 1
assert isinstance(decls[0].symbols[0], Scalar)
assert isinstance(decls[0].symbols[0], sym.Scalar)
assert decls[0].symbols[0].type.dtype.typedef == field_3rb_tdef

assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 1
assert isinstance(assigns[0].lhs, Array)
assert isinstance(assigns[0].lhs, sym.Array)
assert assigns[0].lhs.type.dtype == BasicType.REAL
assert assigns[0].lhs.type.shape == (':', ':', ':')
assert assigns[0].lhs.parent.type.dtype.typedef == field_3rb_tdef
Expand Down Expand Up @@ -2145,7 +2145,7 @@ def test_subroutine_deep_clone(frontend):
map_nodes={}
for assign in FindNodes(ir.Assignment).visit(new_routine.body):
map_nodes[assign] = ir.CallStatement(
name=DeferredTypeSymbol(name='testcall'), arguments=(assign.lhs,), scope=new_routine
name=sym.DeferredTypeSymbol(name='testcall'), arguments=(assign.lhs,), scope=new_routine
)
new_routine.body = Transformer(map_nodes).visit(new_routine.body)

Expand Down Expand Up @@ -2291,7 +2291,7 @@ def test_resolve_typebound_var(frontend, tmp_path):

# Instead, we can creatae a deferred type variable in the scope and
# resolve members relative to it
not_tt = Variable(name='not_tt', scope=routine)
not_tt = sym.Variable(name='not_tt', scope=routine)
assert not_tt.type.dtype == BasicType.DEFERRED # pylint: disable=no-member
not_tt_invalid = not_tt.get_derived_type_member('invalid') # pylint: disable=no-member
assert not_tt_invalid == 'not_tt%invalid'
Expand Down

0 comments on commit 9fcbb07

Please sign in to comment.