From 9cdb87f97d8157fb61ce5f4d95f14a03d2bfe5bc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Dec 2023 18:24:44 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/graphs/ico.py | 41 ++++++++++++++++++---------- graph_weather/models/graphs/utils.py | 15 ++++------ 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/graph_weather/models/graphs/ico.py b/graph_weather/models/graphs/ico.py index 8668bddd..b1599181 100644 --- a/graph_weather/models/graphs/ico.py +++ b/graph_weather/models/graphs/ico.py @@ -38,7 +38,6 @@ from sklearn.neighbors import NearestNeighbors - def icosphere(nu=1, nr_verts=None): """ Returns a geodesic icosahedron with subdivision frequency nu. Frequency @@ -323,6 +322,7 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti from torch_geometric.data import Data from graph_weather.models.graphs.utils import deg2rad, latlon2xyz, xyz2latlon, get_edge_len from sklearn.neighbors import NearestNeighbors + num_latlons = len(lat_lons) verticies_per_level = [] edges_per_level = [] @@ -333,7 +333,7 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti # Check the verticies of each pair are the same up to the resolution for i in range(len(verticies_per_level) - 2): - #print(edges_per_level[i]) + # print(edges_per_level[i]) for vertex_lower_index, vertex in enumerate(verticies_per_level[i]): # Go through each index in the current level, finding the closest vertex in the next level # Should check all verticies in the next level, and find the closest one @@ -347,14 +347,14 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti edges_per_level[i][edge_index][0] = vertex_upper_index if edge[1] == vertex_lower_index: edges_per_level[i][edge_index][1] = vertex_upper_index - #print(edges_per_level[i][edges_per_level[i] == vertex_lower_index]) + # print(edges_per_level[i][edges_per_level[i] == vertex_lower_index]) # The vertex is the same, so the edges in the current level that equal vertex_lower_index should be changed to equal vertex_upper_index - #edges_per_level[i][edges_per_level[i] == vertex_upper_index] = vertex_upper_index + # edges_per_level[i][edges_per_level[i] == vertex_upper_index] = vertex_upper_index if multiple_equals > 1: print(f"Multiple equals: {multiple_equals}") - #print(edges_per_level[i]) - #print("------------------") - verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above + # print(edges_per_level[i]) + # print("------------------") + verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above edges = np.sort(np.concatenate(edges_per_level), axis=1) print(f"Number of edges: {len(edges)}") if bidirectional: @@ -363,16 +363,24 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti # TODO Create mapping from the lat/lon to the icosphere nodes print(f"Now Number of edges: {len(edges)}") print(f"Number of unique edges: {len(np.unique(edges, axis=0))}") - print(f"Max number of repeated edges: {np.max(np.unique(edges, axis=0, return_counts=True)[1])}") + print( + f"Max number of repeated edges: {np.max(np.unique(edges, axis=0, return_counts=True)[1])}" + ) u, c = np.unique(edges, axis=0, return_counts=True) print(f"First 20 duplicates: {u[c > 1][:20]}") xyz_grid = latlon2xyz(torch.tensor(lat_lons, dtype=torch.float)) # Find the closest vertex to each point vertex_mapping = np.argmin(np.sum(np.abs(verticies - xyz_grid[:, None]), axis=2), axis=1) # Features will need to be in the same lat/lon order as given, and added to the verticies - ico_graph = Data(pos=torch.tensor(verticies, dtype=torch.float), - edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous()) - max_edge_len = np.max(get_edge_len(ico_graph.pos[edges_per_level[-1][:, 0]], ico_graph.pos[edges_per_level[-1][:, 1]])) + ico_graph = Data( + pos=torch.tensor(verticies, dtype=torch.float), + edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous(), + ) + max_edge_len = np.max( + get_edge_len( + ico_graph.pos[edges_per_level[-1][:, 0]], ico_graph.pos[edges_per_level[-1][:, 1]] + ) + ) # create the grid2mesh bipartite graph cartesian_grid = latlon2xyz(lat_lons) n_nbrs = 4 @@ -389,10 +397,14 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti ico_graph.validate(raise_on_error=True) return ico_graph -generate_icosphere_mapping([(0,0), (0,1), (1,0), (1,1)]) -def get_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data): - max_edge_len = np.max(get_edge_len(mesh.pos[mesh.edge_index[:,0]], mesh.pos[mesh.edge_index[:,1]])) +generate_icosphere_mapping([(0, 0), (0, 1), (1, 0), (1, 1)]) + + +def get_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data): + max_edge_len = np.max( + get_edge_len(mesh.pos[mesh.edge_index[:, 0]], mesh.pos[mesh.edge_index[:, 1]]) + ) # create the grid2mesh bipartite graph cartesian_grid = latlon2xyz(lat_lons) @@ -407,6 +419,7 @@ def get_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data): src.append(i) dst.append(indices[i][j]) + def generate_latent_ico_graph(h3_mapping, h3_distances): """ Generate latent h3 graph. diff --git a/graph_weather/models/graphs/utils.py b/graph_weather/models/graphs/utils.py index b6d868d2..fe76a4bb 100644 --- a/graph_weather/models/graphs/utils.py +++ b/graph_weather/models/graphs/utils.py @@ -143,6 +143,7 @@ def rad2deg(rad): """ return rad * 180 / np.pi + def azimuthal_angle(lon: Tensor) -> Tensor: """ Gives the azimuthal angle of a point on the sphere @@ -179,9 +180,7 @@ def polar_angle(lat: Tensor) -> Tensor: return angle -def geospatial_rotation( - invar: Tensor, theta: Tensor, axis: str, unit: str = "rad" -) -> Tensor: +def geospatial_rotation(invar: Tensor, theta: Tensor, axis: str, unit: str = "rad") -> Tensor: """Rotation using right hand rule Parameters @@ -296,9 +295,7 @@ def add_edge_features(graph: Data, pos: Tensor, normalize: bool = True) -> Data: # normalize using the longest edge if normalize: max_disp_norm = torch.max(disp_norm) - graph["edge_attr"] = torch.cat( - (disp / max_disp_norm, disp_norm / max_disp_norm), dim=-1 - ) + graph["edge_attr"] = torch.cat((disp / max_disp_norm, disp_norm / max_disp_norm), dim=-1) else: graph["edge_attr"] = torch.cat((disp, disp_norm), dim=-1) return graph @@ -322,9 +319,7 @@ def add_node_features(graph: Data, pos: Tensor) -> Data: """ latlon = xyz2latlon(pos) lat, lon = latlon[:, 0], latlon[:, 1] - graph["x"] = torch.stack( - (torch.cos(lat), torch.sin(lon), torch.cos(lon)), dim=-1 - ) + graph["x"] = torch.stack((torch.cos(lat), torch.sin(lon), torch.cos(lon)), dim=-1) return graph @@ -383,5 +378,5 @@ def _format_axes(ax): for angle in range(0, 360): ax.view_init(30, angle) plt.draw() - plt.pause(.01) + plt.pause(0.01) plt.show()