Skip to content

Commit

Permalink
Simplify and provide construct_affine_array_access_function_represent…
Browse files Browse the repository at this point in the history
…ation
  • Loading branch information
joscao committed Oct 27, 2023
1 parent d9b97aa commit 14633ac
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 55 deletions.
206 changes: 151 additions & 55 deletions loki/analyse/analyse_array_data_dependency_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,29 @@
from math import gcd
from warnings import warn
from dataclasses import dataclass
from typing import Any, List
from numpy.typing import NDArray as npt_NDArray
from numpy import int_ as np_int_
import numpy as np
from ortools.linear_solver import pywraplp

Check failure on line 15 in loki/analyse/analyse_array_data_dependency_detection.py

View workflow job for this annotation

GitHub Actions / code checks (3.11)

E0401: Unable to import 'ortools.linear_solver' (import-error)

from loki.expression import (
FindVariables,
accumulate_polynomial_terms,
simplify,
is_constant,
symbols as sym,
)
from loki.analyse.util_linear_algebra import (
generate_row_echelon_form,
back_substitution,
is_independent_system,
yield_one_d_systems,
)

__all__ = ["has_data_dependency"]
__all__ = [
"has_data_dependency",
"construct_affine_array_access_function_representation",
]

NDArrayInt = npt_NDArray[np_int_]

Expand All @@ -36,6 +46,10 @@ class IterationSpaceRepresentation:
vector: NDArrayInt


# abbrivation for readability
IterSpacRepr = IterationSpaceRepresentation


@dataclass
class AffineArrayAccessFunctionRepresentation:
"""
Expand All @@ -47,31 +61,31 @@ class AffineArrayAccessFunctionRepresentation:
vector: NDArrayInt


# abbrivation for readability
AffiAcceRepr = AffineArrayAccessFunctionRepresentation


@dataclass
class StaticArrayAccessInstanceRepresentation:
"""
Represents a static array access instance using an iteration space representation and
an affine array access function representation.
"""

iteration_space: IterationSpaceRepresentation
access_function: AffineArrayAccessFunctionRepresentation
iteration_space: IterSpacRepr
access_function: AffiAcceRepr

def __init__(
self,
iteration_space: IterationSpaceRepresentation,
access_function: AffineArrayAccessFunctionRepresentation,
):
self.iteration_space = IterationSpaceRepresentation(*iteration_space)
self.access_function = AffineArrayAccessFunctionRepresentation(*access_function)

# abbrivation for readability
StatInstRepr = StaticArrayAccessInstanceRepresentation


def _assert_correct_dimensions(
first_array_access: StaticArrayAccessInstanceRepresentation,
second_array_access: StaticArrayAccessInstanceRepresentation,
first_array_access: StatInstRepr,
second_array_access: StatInstRepr,
) -> None:
def assert_compatibility_of_iteration_space(
first: IterationSpaceRepresentation, second: IterationSpaceRepresentation
first: IterSpacRepr, second: IterSpacRepr
):
(number_inequalties_1, number_variables_1) = first.matrix.shape
(rhs_number_inequalities_1, column_count) = first.vector.shape
Expand All @@ -94,8 +108,8 @@ def assert_compatibility_of_iteration_space(
)

def assert_compatibility_of_access_function(
first: AffineArrayAccessFunctionRepresentation,
second: AffineArrayAccessFunctionRepresentation,
first: AffiAcceRepr,
second: AffiAcceRepr,
):
(number_rows_1, number_columns_1) = first.matrix.shape
(rhs_number_rows_1, column_count) = first.vector.shape
Expand All @@ -118,8 +132,8 @@ def assert_compatibility_of_access_function(
)

def assert_compatibility_of_access_and_iteration_unkowns(
function: AffineArrayAccessFunctionRepresentation,
space: IterationSpaceRepresentation,
function: AffiAcceRepr,
space: IterSpacRepr,
):
(_, number_variables_space) = space.matrix.shape
(_, number_variables_function) = function.matrix.shape
Expand Down Expand Up @@ -260,41 +274,21 @@ def _solve_inequalties_with_ortools(
return solver.Solve()


def has_data_dependency(
first_array_access: StaticArrayAccessInstanceRepresentation,
second_array_access: StaticArrayAccessInstanceRepresentation,
def has_data_dependency_impl(
first_array_access: StatInstRepr,
second_array_access: StatInstRepr,
) -> bool:
if not isinstance(first_array_access, StaticArrayAccessInstanceRepresentation):
first_array_access = StaticArrayAccessInstanceRepresentation(
*first_array_access
)

if not isinstance(second_array_access, StaticArrayAccessInstanceRepresentation):
second_array_access = StaticArrayAccessInstanceRepresentation(
*second_array_access
)

_assert_correct_dimensions(first_array_access, second_array_access)

F, f = (
first_array_access.access_function.matrix,
first_array_access.access_function.vector,
)
F_dash, f_dash = (
second_array_access.access_function.matrix,
second_array_access.access_function.vector,
)

# fmt: off
(has_solution, solved_system) = _gaussian_eliminiation_for_diophantine_equations(
np.hstack(
(
F,
-F_dash,
f - f_dash,
first_array_access.access_function.matrix,
-second_array_access.access_function.matrix,
first_array_access.access_function.vector - second_array_access.access_function.vector,
)
)
)

# fmt: on
if not has_solution:
return False

Expand All @@ -304,18 +298,25 @@ def has_data_dependency(

R, S, g_hat = solved_system[:, :n], solved_system[:, n:-1], solved_system[:, -1:]

B, b = (
first_array_access.iteration_space.matrix,
first_array_access.iteration_space.vector,
# fmt: off
K = np.block(
[
[
first_array_access.iteration_space.matrix, np.zeros_like(first_array_access.iteration_space.matrix),
],
[
np.zeros_like(second_array_access.iteration_space.matrix), second_array_access.iteration_space.matrix,
],
]
)
B_dash, b_dash = (
second_array_access.iteration_space.matrix,
second_array_access.iteration_space.vector,
# fmt: on
k = np.vstack(
(
first_array_access.iteration_space.vector,
second_array_access.iteration_space.vector,
)
)

K = np.block([[B, np.zeros_like(B)], [np.zeros_like(B_dash), B_dash]])
k = np.vstack((b, b_dash))

n, m = R.shape
K_R = K[0:n, 0:m]

Expand Down Expand Up @@ -350,3 +351,98 @@ def has_data_dependency(
)

return True # assume data dependency if all tests are inconclusive


def _ensure_correct_StatInstRepr(array_access: Any) -> StatInstRepr:
if isinstance(array_access, StatInstRepr):
return array_access

if len(array_access) == 2:
iteration_space = array_access[0]
access_function = array_access[1]

if isinstance(iteration_space, IterSpacRepr) and isinstance(
access_function, AffiAcceRepr
):
return StatInstRepr(iteration_space, access_function)

if isinstance(iteration_space, IterSpacRepr) and len(access_function) == 2:
return StatInstRepr(iteration_space, AffiAcceRepr(*access_function))

if len(iteration_space) == 2 and isinstance(access_function, AffiAcceRepr):
return StatInstRepr(IterSpacRepr(*iteration_space), access_function)

if len(iteration_space) == 2 and len(access_function) == 2:
return StatInstRepr(
IterSpacRepr(*iteration_space), AffiAcceRepr(*access_function)
)

raise ValueError(f"Cannot convert {str(array_access)} to StatInstRepr")


def has_data_dependency(
first_array_access: StatInstRepr, second_array_access: StatInstRepr
):
first_array_access = _ensure_correct_StatInstRepr(first_array_access)
second_array_access = _ensure_correct_StatInstRepr(second_array_access)

_assert_correct_dimensions(first_array_access, second_array_access)

return has_data_dependency_impl(first_array_access, second_array_access)


def construct_affine_array_access_function_representation(
array_dimensions_expr: tuple(), additional_variables: List[str] = None
):
"""
Construct a matrix, vector representation of the access function of an array.
E.g. z[i], where the expression ("i", ) should be passed to this function,
y[1+3,4-j], where ("1+3", "4-j") should be passed to this function,
if var=Array(...), then var.dimensions should be passed to this function.
Returns: matrix, vector: F,f mapping a vecector i within the bounds Bi+b>=0 to the
array location Fi+f
"""

def generate_row(expr, variables):
supported_types = (sym.TypedSymbol, sym.MetaSymbol, sym.Sum, sym.Product)
if not (is_constant(expr) or isinstance(expr, supported_types)):
raise ValueError(f"Cannot derive inequality from expr {str(expr)}")
simplified_expr = simplify(expr)
terms = accumulate_polynomial_terms(simplified_expr)
const_term = terms.pop(1, 0) # Constant term or 0
row = np.zeros(len(variables), dtype=np.dtype(int))

for base, coef in terms.items():
if not len(base) == 1:
raise ValueError(f"Non-affine bound {str(simplified_expr)}")
row[variables.index(base[0].name.lower())] = coef

return row, const_term

def unique_order_preserving(sequence):
seen = set()
return [x for x in sequence if not (x in seen or seen.add(x))]

if additional_variables is None:
additional_variables = list()

Check failure on line 427 in loki/analyse/analyse_array_data_dependency_detection.py

View workflow job for this annotation

GitHub Actions / code checks (3.11)

R1734: Consider using [] instead of list() (use-list-literal)

for variable in additional_variables:
assert variable.lower() == variable

variables = additional_variables.copy()
variables += list(
{v.name.lower() for v in FindVariables().visit(array_dimensions_expr)}
)
variables = unique_order_preserving(variables)

n = len(array_dimensions_expr)
d = len(variables)

F = np.zeros([n, d], dtype=np.dtype(int))
f = np.zeros([n, 1], dtype=np.dtype(int))

for dimension_index, sub_expr in enumerate(array_dimensions_expr):
row, constant = generate_row(sub_expr, variables)
F[dimension_index] = row
f[dimension_index, 0] = constant
return AffiAcceRepr(F, f), variables
17 changes: 17 additions & 0 deletions tests/test_analyse_array_data_dependency_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@
from loki.analyse.analyse_array_data_dependency_detection import has_data_dependency


@pytest.mark.parametrize(
"first_access_represenetation, second_access_represenetation, error_type",
[
# totaly empty tuple
(tuple(), tuple(), ValueError),
# tuple with two empty tuple --> correct form
((tuple(), tuple()), (tuple(), tuple()), ValueError),
#
],
)
def test_has_data_dependency_false_inputs(
first_access_represenetation, second_access_represenetation, error_type
):
with pytest.raises(error_type):
_ = has_data_dependency(first_access_represenetation, second_access_represenetation)


@pytest.mark.parametrize(
"first_access_represenetation, second_access_represenetation, result",
[
Expand Down

0 comments on commit 14633ac

Please sign in to comment.