Skip to content

Commit

Permalink
Tame yield_one_d_systems
Browse files Browse the repository at this point in the history
  • Loading branch information
joscao committed Nov 2, 2023
1 parent 7de1b43 commit 1754f10
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 27 deletions.
23 changes: 7 additions & 16 deletions loki/analyse/util_linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,19 @@ def yield_one_d_systems(matrix, right_hand_side):
solution = solve_one_d_system(A, b)
```
"""
# drop completly empty rows
mask = ~np.all(np.hstack((matrix, right_hand_side)) == 0, axis=1)
matrix = matrix[mask]
right_hand_side = right_hand_side[mask]

# yield systems with empty left hand side (A) and non empty right hand side
mask = np.all(matrix == 0, axis=1)
if right_hand_side[mask].size == 0:
return

for A, b in zip(matrix[mask].T, right_hand_side[mask].T):
yield A, b
if right_hand_side[mask].size != 0:
for A in matrix[mask].T:
yield A.reshape((-1,1)), right_hand_side[mask]

matrix = matrix[~mask]
right_hand_side = right_hand_side[~mask]

if right_hand_side.size == 0:
return

for A, b in zip(matrix.T, right_hand_side.T):
mask = A != 0
yield A[mask], b[mask]
if right_hand_side.size != 0:
for A in matrix.T:
mask = A != 0
yield A[mask].reshape((-1,1)), right_hand_side[mask]


def back_substitution(
Expand Down
24 changes: 13 additions & 11 deletions tests/test_util_linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,30 +157,32 @@ def test_is_independent_system(matrix, expected_result):
(
np.array([[1, 0], [0, 1], [0, 0]]),
np.array([[1], [2], [0]]),
[[1], [1]],
[[1], [2]],
[np.array([[0]]), np.array([[0]]), np.array([[1]]), np.array([[1]])],
[np.array([[0]]), np.array([[0]]), np.array([[1]]), np.array([[2]])],
),
(
np.array([[1, 0], [0, 1], [0, 0]]),
np.array([[1], [2], [1]]),
[[0], [1], [1]],
[[1], [1], [2]],
[np.array([[0]]), np.array([[0]]), np.array([[1]]), np.array([[1]])],
[np.array([[1]]), np.array([[1]]), np.array([[1]]), np.array([[2]])],
),
(
np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]),
np.array([[1], [2], [3]]),
[[0]],
[[1, 2, 3]],
[np.array([[0], [0], [0]])] * 3,
[np.array([[1], [2], [3]])] * 3,
),
( # will even split non independent systems, call is_independent_system before
np.array([[2, 1], [1, 3]]),
np.array([[3], [4]]),
[[2], [1]],
[[3], [4]],
[np.array([[2], [1]]), np.array([[1], [3]])],
[np.array([[3], [4]]), np.array([[3], [4]])],
),
],
)
def test_yield_one_d_systems(matrix, rhs, list_of_lhs_column, list_of_rhs_column):
for index, (A, b) in enumerate(yield_one_d_systems(matrix, rhs)):
assert np.allclose(A, list_of_lhs_column[index])
assert np.allclose(b, list_of_rhs_column[index])
results = list(yield_one_d_systems(matrix, rhs))
assert len(results) == len(list_of_lhs_column) == len(list_of_rhs_column)
for index, (A, b) in enumerate(results):
assert np.array_equal(A, list_of_lhs_column[index])
assert np.array_equal(b, list_of_rhs_column[index])

0 comments on commit 1754f10

Please sign in to comment.