Skip to content

Commit

Permalink
Merge pull request #744 from abja-dhi/abja-updated_nodeCellID
Browse files Browse the repository at this point in the history
Update _FM_utils.py
  • Loading branch information
ecomodeller authored Nov 19, 2024
2 parents 81e10c8 + 12348f4 commit b441e98
Showing 1 changed file with 33 additions and 9 deletions.
42 changes: 33 additions & 9 deletions mikeio/spatial/_FM_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from matplotlib.figure import Figure
from matplotlib.tri import Triangulation
import numpy as np
from scipy.sparse import csr_matrix
from collections import namedtuple

from ._utils import _relative_cumulative_distance
Expand Down Expand Up @@ -643,6 +644,35 @@ def _to_polygons(node_coordinates: np.ndarray, element_table: np.ndarray) -> lis
return polygons


def _create_node_element_matrix(
element_table: np.ndarray, num_nodes: int
) -> csr_matrix:
"""Creates a sparse node-element connectivity matrix from a given element table.
Parameters
----------
element_table : np.array
The element table (A 2D array where each row represents an element and each
column corresponds to a node index involved in the element.)
num_nodes : int
The total number of nodes in the mesh.
Returns
-------
scipy.sparse.csr_matrix
A sparse matrix of shape (num_nodes, number of elements), where the entry
(i, j) is 1 if node i is part of element j, and 0 otherwise.
"""
row_ind = element_table.ravel()
col_ind = np.repeat(np.arange(element_table.shape[0]), element_table.shape[1])
data = np.ones(len(row_ind), dtype=int)
connectivity_matrix = csr_matrix(
(data, (row_ind, col_ind)), shape=(num_nodes, element_table.shape[0])
)
return connectivity_matrix


def _get_node_centered_data(
node_coordinates: np.ndarray,
element_table: np.ndarray,
Expand Down Expand Up @@ -675,17 +705,11 @@ def _get_node_centered_data(
elem_table, ec, data = __create_tri_only_element_table(
nc, element_table, element_coordinates, data
)
connectivity_matrix = _create_node_element_matrix(elem_table, nc.shape[0])

node_cellID = [
list(np.argwhere(elem_table == i)[:, 0])
for i in np.unique(
elem_table.reshape(
-1,
)
)
]
node_centered_data = np.zeros(shape=nc.shape[0])
for n, item in enumerate(node_cellID):
for n in range(connectivity_matrix.shape[0]):
item = connectivity_matrix.getrow(n).indices
I = ec[item][:, :2] - nc[n][:2]
I2 = (I**2).sum(axis=0)
Ixy = (I[:, 0] * I[:, 1]).sum(axis=0)
Expand Down

0 comments on commit b441e98

Please sign in to comment.