diff --git a/graph_weather/models/graphs/ico.py b/graph_weather/models/graphs/ico.py index b1599181..29707050 100644 --- a/graph_weather/models/graphs/ico.py +++ b/graph_weather/models/graphs/ico.py @@ -33,9 +33,8 @@ import numpy as np import torch -from torch_geometric.data import Data -from graph_weather.models.graphs.utils import deg2rad, latlon2xyz, xyz2latlon, get_edge_len -from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import Data, HeteroData +from graph_weather.models.graphs.utils import generate_grid_to_mesh, generate_mesh_to_grid def icosphere(nu=1, nr_verts=None): @@ -368,9 +367,6 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti ) u, c = np.unique(edges, axis=0, return_counts=True) print(f"First 20 duplicates: {u[c > 1][:20]}") - xyz_grid = latlon2xyz(torch.tensor(lat_lons, dtype=torch.float)) - # Find the closest vertex to each point - vertex_mapping = np.argmin(np.sum(np.abs(verticies - xyz_grid[:, None]), axis=2), axis=1) # Features will need to be in the same lat/lon order as given, and added to the verticies ico_graph = Data( pos=torch.tensor(verticies, dtype=torch.float), @@ -381,43 +377,20 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti ico_graph.pos[edges_per_level[-1][:, 0]], ico_graph.pos[edges_per_level[-1][:, 1]] ) ) - # create the grid2mesh bipartite graph - cartesian_grid = latlon2xyz(lat_lons) - n_nbrs = 4 - neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(ico_graph.pos) - distances, indices = neighbors.kneighbors(cartesian_grid) - - src, dst = [], [] - for i in range(len(cartesian_grid)): - for j in range(n_nbrs): - if distances[i][j] <= 0.6 * max_edge_len: - src.append(i) - dst.append(indices[i][j]) # Check that the graph is valid ico_graph.validate(raise_on_error=True) + + # Generate grid to mesh and mesh to graph + grid_to_mesh = generate_grid_to_mesh(lat_lons, ico_graph, max_edge_length=max_edge_len) + mesh_to_grid = generate_mesh_to_grid(lat_lons, ico_graph) + return ico_graph generate_icosphere_mapping([(0, 0), (0, 1), (1, 0), (1, 1)]) -def get_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data): - max_edge_len = np.max( - get_edge_len(mesh.pos[mesh.edge_index[:, 0]], mesh.pos[mesh.edge_index[:, 1]]) - ) - # create the grid2mesh bipartite graph - cartesian_grid = latlon2xyz(lat_lons) - n_nbrs = 4 - neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(mesh.pos) - distances, indices = neighbors.kneighbors(cartesian_grid) - - src, dst = [], [] - for i in range(len(cartesian_grid)): - for j in range(n_nbrs): - if distances[i][j] <= 0.6 * max_edge_len: - src.append(i) - dst.append(indices[i][j]) def generate_latent_ico_graph(h3_mapping, h3_distances): diff --git a/graph_weather/models/graphs/utils.py b/graph_weather/models/graphs/utils.py index fe76a4bb..302e4793 100644 --- a/graph_weather/models/graphs/utils.py +++ b/graph_weather/models/graphs/utils.py @@ -20,9 +20,11 @@ """ import torch +from typing import Optional, Tuple from torch import Tensor, testing import numpy as np from torch_geometric.data import Data, HeteroData +from sklearn.neighbors import NearestNeighbors def latlon2xyz(latlon: Tensor, radius: float = 1, unit: str = "deg") -> Tensor: @@ -307,7 +309,7 @@ def add_node_features(graph: Data, pos: Tensor) -> Data: Parameters ---------- - graph : DGLGraph + graph : Data The graph to add node features to. pos : Tensor The node positions. @@ -323,6 +325,63 @@ def add_node_features(graph: Data, pos: Tensor) -> Data: return graph +def generate_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data, max_edge_length: Optional[float] = None) -> HeteroData: + if max_edge_length is None: + max_edge_len = np.max( + get_edge_len(mesh.pos[mesh.edge_index[:, 0]], mesh.pos[mesh.edge_index[:, 1]]) + ) + else: + max_edge_len = max_edge_length + + # create the grid2mesh bipartite graph + cartesian_grid = latlon2xyz(lat_lons) + n_nbrs = 4 + neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(mesh.pos) + distances, indices = neighbors.kneighbors(cartesian_grid) + + src, dst = [], [] + for i in range(len(cartesian_grid)): + for j in range(n_nbrs): + if distances[i][j] <= 0.6 * max_edge_len: + src.append(i) + dst.append(indices[i][j]) + # This is in COO format now, and it is not bidirectional, so no copying + grid2mesh = HeteroData() + grid2mesh["grid"].pos = torch.tensor(cartesian_grid, dtype=torch.float) + grid2mesh["mesh"].pos = mesh.pos + grid2mesh["grid", "to", "mesh"].edge_index = torch.tensor([src, dst], dtype=torch.long) + # Add edge features + grid2mesh = add_edge_features(grid2mesh, (grid2mesh["grid"].pos, grid2mesh["mesh"].pos)) + return grid2mesh + + +def generate_mesh_to_grid(lat_lons: torch.Tensor, mesh: Data): + # create the mesh2grid bipartite graph + cartesian_grid = latlon2xyz(lat_lons) + n_nbrs = 1 + neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit( + mesh.pos + ) + _, indices = neighbors.kneighbors(cartesian_grid) + indices = indices.flatten() + + src = [ + p + for i in indices + for p in mesh.pos[i] + ] + dst = [i for i in range(len(cartesian_grid)) for _ in range(3)] + + mesh2grid = HeteroData() + mesh2grid["mesh"].pos = mesh.pos + mesh2grid["grid"].pos = torch.tensor(cartesian_grid, dtype=torch.float) + mesh2grid["mesh", "to", "grid"].edge_index = torch.tensor([src, dst], dtype=torch.long) + # Add edge features + mesh2grid = add_edge_features(mesh2grid, (mesh2grid["mesh"].pos, mesh2grid["grid"].pos)) + + return mesh2grid + + def plot_graph(graph: Data, **kwargs): """Plots the graph.