From 1754f106bac18801f7aebcfc82a5ad61ba6637fb Mon Sep 17 00:00:00 2001 From: joscao <62660664+joscao@users.noreply.github.com> Date: Thu, 2 Nov 2023 09:33:22 +0100 Subject: [PATCH] Tame yield_one_d_systems --- loki/analyse/util_linear_algebra.py | 23 +++++++---------------- tests/test_util_linear_algebra.py | 24 +++++++++++++----------- 2 files changed, 20 insertions(+), 27 deletions(-) diff --git a/loki/analyse/util_linear_algebra.py b/loki/analyse/util_linear_algebra.py index 02c1d1f71..2c81f13a2 100644 --- a/loki/analyse/util_linear_algebra.py +++ b/loki/analyse/util_linear_algebra.py @@ -78,28 +78,19 @@ def yield_one_d_systems(matrix, right_hand_side): solution = solve_one_d_system(A, b) ``` """ - # drop completly empty rows - mask = ~np.all(np.hstack((matrix, right_hand_side)) == 0, axis=1) - matrix = matrix[mask] - right_hand_side = right_hand_side[mask] - # yield systems with empty left hand side (A) and non empty right hand side mask = np.all(matrix == 0, axis=1) - if right_hand_side[mask].size == 0: - return - - for A, b in zip(matrix[mask].T, right_hand_side[mask].T): - yield A, b + if right_hand_side[mask].size != 0: + for A in matrix[mask].T: + yield A.reshape((-1,1)), right_hand_side[mask] matrix = matrix[~mask] right_hand_side = right_hand_side[~mask] - if right_hand_side.size == 0: - return - - for A, b in zip(matrix.T, right_hand_side.T): - mask = A != 0 - yield A[mask], b[mask] + if right_hand_side.size != 0: + for A in matrix.T: + mask = A != 0 + yield A[mask].reshape((-1,1)), right_hand_side[mask] def back_substitution( diff --git a/tests/test_util_linear_algebra.py b/tests/test_util_linear_algebra.py index 2c216fa8e..caf45dda4 100644 --- a/tests/test_util_linear_algebra.py +++ b/tests/test_util_linear_algebra.py @@ -157,30 +157,32 @@ def test_is_independent_system(matrix, expected_result): ( np.array([[1, 0], [0, 1], [0, 0]]), np.array([[1], [2], [0]]), - [[1], [1]], - [[1], [2]], + [np.array([[0]]), np.array([[0]]), np.array([[1]]), np.array([[1]])], + [np.array([[0]]), np.array([[0]]), np.array([[1]]), np.array([[2]])], ), ( np.array([[1, 0], [0, 1], [0, 0]]), np.array([[1], [2], [1]]), - [[0], [1], [1]], - [[1], [1], [2]], + [np.array([[0]]), np.array([[0]]), np.array([[1]]), np.array([[1]])], + [np.array([[1]]), np.array([[1]]), np.array([[1]]), np.array([[2]])], ), ( np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]), np.array([[1], [2], [3]]), - [[0]], - [[1, 2, 3]], + [np.array([[0], [0], [0]])] * 3, + [np.array([[1], [2], [3]])] * 3, ), ( # will even split non independent systems, call is_independent_system before np.array([[2, 1], [1, 3]]), np.array([[3], [4]]), - [[2], [1]], - [[3], [4]], + [np.array([[2], [1]]), np.array([[1], [3]])], + [np.array([[3], [4]]), np.array([[3], [4]])], ), ], ) def test_yield_one_d_systems(matrix, rhs, list_of_lhs_column, list_of_rhs_column): - for index, (A, b) in enumerate(yield_one_d_systems(matrix, rhs)): - assert np.allclose(A, list_of_lhs_column[index]) - assert np.allclose(b, list_of_rhs_column[index]) + results = list(yield_one_d_systems(matrix, rhs)) + assert len(results) == len(list_of_lhs_column) == len(list_of_rhs_column) + for index, (A, b) in enumerate(results): + assert np.array_equal(A, list_of_lhs_column[index]) + assert np.array_equal(b, list_of_rhs_column[index])