diff --git a/loki/analyse/analyse_array_data_dependency_detection.py b/loki/analyse/analyse_array_data_dependency_detection.py index a80520cf6..1219a26da 100644 --- a/loki/analyse/analyse_array_data_dependency_detection.py +++ b/loki/analyse/analyse_array_data_dependency_detection.py @@ -8,11 +8,18 @@ from math import gcd from warnings import warn from dataclasses import dataclass +from typing import Any, List from numpy.typing import NDArray as npt_NDArray from numpy import int_ as np_int_ import numpy as np from ortools.linear_solver import pywraplp - +from loki.expression import ( + FindVariables, + accumulate_polynomial_terms, + simplify, + is_constant, + symbols as sym, +) from loki.analyse.util_linear_algebra import ( generate_row_echelon_form, back_substitution, @@ -20,7 +27,10 @@ yield_one_d_systems, ) -__all__ = ["has_data_dependency"] +__all__ = [ + "has_data_dependency", + "construct_affine_array_access_function_representation", +] NDArrayInt = npt_NDArray[np_int_] @@ -36,6 +46,10 @@ class IterationSpaceRepresentation: vector: NDArrayInt +# abbreviation for readability +IterSpacRepr = IterationSpaceRepresentation + + @dataclass class AffineArrayAccessFunctionRepresentation: """ @@ -47,6 +61,10 @@ class AffineArrayAccessFunctionRepresentation: vector: NDArrayInt +# abbreviation for readability +AffiAcceRepr = AffineArrayAccessFunctionRepresentation + + @dataclass class StaticArrayAccessInstanceRepresentation: """ @@ -54,24 +72,20 @@ class StaticArrayAccessInstanceRepresentation: an affine array access function representation. 
""" - iteration_space: IterationSpaceRepresentation - access_function: AffineArrayAccessFunctionRepresentation + iteration_space: IterSpacRepr + access_function: AffiAcceRepr - def __init__( - self, - iteration_space: IterationSpaceRepresentation, - access_function: AffineArrayAccessFunctionRepresentation, - ): - self.iteration_space = IterationSpaceRepresentation(*iteration_space) - self.access_function = AffineArrayAccessFunctionRepresentation(*access_function) + +# abbreviation for readability +StatInstRepr = StaticArrayAccessInstanceRepresentation def _assert_correct_dimensions( - first_array_access: StaticArrayAccessInstanceRepresentation, - second_array_access: StaticArrayAccessInstanceRepresentation, + first_array_access: StatInstRepr, + second_array_access: StatInstRepr, ) -> None: def assert_compatibility_of_iteration_space( - first: IterationSpaceRepresentation, second: IterationSpaceRepresentation + first: IterSpacRepr, second: IterSpacRepr ): (number_inequalties_1, number_variables_1) = first.matrix.shape (rhs_number_inequalities_1, column_count) = first.vector.shape @@ -94,8 +108,8 @@ def assert_compatibility_of_iteration_space( ) def assert_compatibility_of_access_function( - first: AffineArrayAccessFunctionRepresentation, - second: AffineArrayAccessFunctionRepresentation, + first: AffiAcceRepr, + second: AffiAcceRepr, ): (number_rows_1, number_columns_1) = first.matrix.shape (rhs_number_rows_1, column_count) = first.vector.shape @@ -118,8 +132,8 @@ def assert_compatibility_of_access_function( ) def assert_compatibility_of_access_and_iteration_unkowns( - function: AffineArrayAccessFunctionRepresentation, - space: IterationSpaceRepresentation, + function: AffiAcceRepr, + space: IterSpacRepr, ): (_, number_variables_space) = space.matrix.shape (_, number_variables_function) = function.matrix.shape @@ -260,41 +274,21 @@ def _solve_inequalties_with_ortools( return solver.Solve() -def has_data_dependency( - first_array_access: 
StaticArrayAccessInstanceRepresentation, - second_array_access: StaticArrayAccessInstanceRepresentation, +def has_data_dependency_impl( + first_array_access: StatInstRepr, + second_array_access: StatInstRepr, ) -> bool: - if not isinstance(first_array_access, StaticArrayAccessInstanceRepresentation): - first_array_access = StaticArrayAccessInstanceRepresentation( - *first_array_access - ) - - if not isinstance(second_array_access, StaticArrayAccessInstanceRepresentation): - second_array_access = StaticArrayAccessInstanceRepresentation( - *second_array_access - ) - - _assert_correct_dimensions(first_array_access, second_array_access) - - F, f = ( - first_array_access.access_function.matrix, - first_array_access.access_function.vector, - ) - F_dash, f_dash = ( - second_array_access.access_function.matrix, - second_array_access.access_function.vector, - ) - + # fmt: off (has_solution, solved_system) = _gaussian_eliminiation_for_diophantine_equations( np.hstack( ( - F, - -F_dash, - f - f_dash, + first_array_access.access_function.matrix, + -second_array_access.access_function.matrix, + first_array_access.access_function.vector - second_array_access.access_function.vector, ) ) ) - + # fmt: on if not has_solution: return False @@ -304,18 +298,25 @@ def has_data_dependency( R, S, g_hat = solved_system[:, :n], solved_system[:, n:-1], solved_system[:, -1:] - B, b = ( - first_array_access.iteration_space.matrix, - first_array_access.iteration_space.vector, + # fmt: off + K = np.block( + [ + [ + first_array_access.iteration_space.matrix, np.zeros_like(first_array_access.iteration_space.matrix), + ], + [ + np.zeros_like(second_array_access.iteration_space.matrix), second_array_access.iteration_space.matrix, + ], + ] ) - B_dash, b_dash = ( - second_array_access.iteration_space.matrix, - second_array_access.iteration_space.vector, + # fmt: on + k = np.vstack( + ( + first_array_access.iteration_space.vector, + second_array_access.iteration_space.vector, + ) ) - K = np.block([[B, 
np.zeros_like(B)], [np.zeros_like(B_dash), B_dash]]) - k = np.vstack((b, b_dash)) - n, m = R.shape K_R = K[0:n, 0:m] @@ -350,3 +351,98 @@ def has_data_dependency( ) return True # assume data dependency if all tests are inconclusive + + +def _ensure_correct_StatInstRepr(array_access: Any) -> StatInstRepr: + if isinstance(array_access, StatInstRepr): + return array_access + + if len(array_access) == 2: + iteration_space = array_access[0] + access_function = array_access[1] + + if isinstance(iteration_space, IterSpacRepr) and isinstance( + access_function, AffiAcceRepr + ): + return StatInstRepr(iteration_space, access_function) + + if isinstance(iteration_space, IterSpacRepr) and len(access_function) == 2: + return StatInstRepr(iteration_space, AffiAcceRepr(*access_function)) + + if len(iteration_space) == 2 and isinstance(access_function, AffiAcceRepr): + return StatInstRepr(IterSpacRepr(*iteration_space), access_function) + + if len(iteration_space) == 2 and len(access_function) == 2: + return StatInstRepr( + IterSpacRepr(*iteration_space), AffiAcceRepr(*access_function) + ) + + raise ValueError(f"Cannot convert {str(array_access)} to StatInstRepr") + + +def has_data_dependency( + first_array_access: StatInstRepr, second_array_access: StatInstRepr +): + first_array_access = _ensure_correct_StatInstRepr(first_array_access) + second_array_access = _ensure_correct_StatInstRepr(second_array_access) + + _assert_correct_dimensions(first_array_access, second_array_access) + + return has_data_dependency_impl(first_array_access, second_array_access) + + +def construct_affine_array_access_function_representation( + array_dimensions_expr: tuple(), additional_variables: List[str] = None +): + """ + Construct a matrix, vector representation of the access function of an array. + E.g. 
z[i], where the expression ("i", ) should be passed to this function, + y[1+3,4-j], where ("1+3", "4-j") should be passed to this function, + if var=Array(...), then var.dimensions should be passed to this function. + Returns: matrix, vector: F,f mapping a vector i within the bounds Bi+b>=0 to the + array location Fi+f + """ + + def generate_row(expr, variables): + supported_types = (sym.TypedSymbol, sym.MetaSymbol, sym.Sum, sym.Product) + if not (is_constant(expr) or isinstance(expr, supported_types)): + raise ValueError(f"Cannot derive inequality from expr {str(expr)}") + simplified_expr = simplify(expr) + terms = accumulate_polynomial_terms(simplified_expr) + const_term = terms.pop(1, 0) # Constant term or 0 + row = np.zeros(len(variables), dtype=np.dtype(int)) + + for base, coef in terms.items(): + if not len(base) == 1: + raise ValueError(f"Non-affine bound {str(simplified_expr)}") + row[variables.index(base[0].name.lower())] = coef + + return row, const_term + + def unique_order_preserving(sequence): + seen = set() + return [x for x in sequence if not (x in seen or seen.add(x))] + + if additional_variables is None: + additional_variables = list() + + for variable in additional_variables: + assert variable.lower() == variable + + variables = additional_variables.copy() + variables += list( + {v.name.lower() for v in FindVariables().visit(array_dimensions_expr)} + ) + variables = unique_order_preserving(variables) + + n = len(array_dimensions_expr) + d = len(variables) + + F = np.zeros([n, d], dtype=np.dtype(int)) + f = np.zeros([n, 1], dtype=np.dtype(int)) + + for dimension_index, sub_expr in enumerate(array_dimensions_expr): + row, constant = generate_row(sub_expr, variables) + F[dimension_index] = row + f[dimension_index, 0] = constant + return AffiAcceRepr(F, f), variables diff --git a/tests/test_analyse_array_data_dependency_detection.py b/tests/test_analyse_array_data_dependency_detection.py index 984b7a06c..4f3304d5d 100644 --- 
a/tests/test_analyse_array_data_dependency_detection.py +++ b/tests/test_analyse_array_data_dependency_detection.py @@ -10,6 +10,23 @@ from loki.analyse.analyse_array_data_dependency_detection import has_data_dependency +@pytest.mark.parametrize( + "first_access_represenetation, second_access_represenetation, error_type", + [ + # totally empty tuple + (tuple(), tuple(), ValueError), + # tuple with two empty tuples --> correct form + ((tuple(), tuple()), (tuple(), tuple()), ValueError), + # + ], +) +def test_has_data_dependency_false_inputs( + first_access_represenetation, second_access_represenetation, error_type +): + with pytest.raises(error_type): + _ = has_data_dependency(first_access_represenetation, second_access_represenetation) + + @pytest.mark.parametrize( "first_access_represenetation, second_access_represenetation, result", [