Skip to content

Commit

Permalink
Refactor grid to mesh and mesh to grid
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Dec 9, 2023
1 parent 9cdb87f commit d705712
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 35 deletions.
41 changes: 7 additions & 34 deletions graph_weather/models/graphs/ico.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@

import numpy as np
import torch
from torch_geometric.data import Data
from graph_weather.models.graphs.utils import deg2rad, latlon2xyz, xyz2latlon, get_edge_len
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import Data, HeteroData
from graph_weather.models.graphs.utils import generate_grid_to_mesh, generate_mesh_to_grid


def icosphere(nu=1, nr_verts=None):
Expand Down Expand Up @@ -368,9 +367,6 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti
)
u, c = np.unique(edges, axis=0, return_counts=True)
print(f"First 20 duplicates: {u[c > 1][:20]}")
xyz_grid = latlon2xyz(torch.tensor(lat_lons, dtype=torch.float))
# Find the closest vertex to each point
vertex_mapping = np.argmin(np.sum(np.abs(verticies - xyz_grid[:, None]), axis=2), axis=1)
# Features will need to be in the same lat/lon order as given, and added to the verticies
ico_graph = Data(
pos=torch.tensor(verticies, dtype=torch.float),
Expand All @@ -381,43 +377,20 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti
ico_graph.pos[edges_per_level[-1][:, 0]], ico_graph.pos[edges_per_level[-1][:, 1]]
)
)
# create the grid2mesh bipartite graph
cartesian_grid = latlon2xyz(lat_lons)
n_nbrs = 4
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(ico_graph.pos)
distances, indices = neighbors.kneighbors(cartesian_grid)

src, dst = [], []
for i in range(len(cartesian_grid)):
for j in range(n_nbrs):
if distances[i][j] <= 0.6 * max_edge_len:
src.append(i)
dst.append(indices[i][j])
# Check that the graph is valid
ico_graph.validate(raise_on_error=True)

# Generate grid to mesh and mesh to graph
grid_to_mesh = generate_grid_to_mesh(lat_lons, ico_graph, max_edge_length=max_edge_len)
mesh_to_grid = generate_mesh_to_grid(lat_lons, ico_graph)

return ico_graph


generate_icosphere_mapping([(0, 0), (0, 1), (1, 0), (1, 1)])


def get_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data):
max_edge_len = np.max(
get_edge_len(mesh.pos[mesh.edge_index[:, 0]], mesh.pos[mesh.edge_index[:, 1]])
)

# create the grid2mesh bipartite graph
cartesian_grid = latlon2xyz(lat_lons)
n_nbrs = 4
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(mesh.pos)
distances, indices = neighbors.kneighbors(cartesian_grid)

src, dst = [], []
for i in range(len(cartesian_grid)):
for j in range(n_nbrs):
if distances[i][j] <= 0.6 * max_edge_len:
src.append(i)
dst.append(indices[i][j])


def generate_latent_ico_graph(h3_mapping, h3_distances):
Expand Down
61 changes: 60 additions & 1 deletion graph_weather/models/graphs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
"""

import torch
from typing import Optional, Tuple
from torch import Tensor, testing
import numpy as np
from torch_geometric.data import Data, HeteroData
from sklearn.neighbors import NearestNeighbors


def latlon2xyz(latlon: Tensor, radius: float = 1, unit: str = "deg") -> Tensor:
Expand Down Expand Up @@ -307,7 +309,7 @@ def add_node_features(graph: Data, pos: Tensor) -> Data:
Parameters
----------
graph : DGLGraph
graph : Data
The graph to add node features to.
pos : Tensor
The node positions.
Expand All @@ -323,6 +325,63 @@ def add_node_features(graph: Data, pos: Tensor) -> Data:
return graph


def generate_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data, max_edge_length: Optional[float] = None) -> HeteroData:
if max_edge_length is None:
max_edge_len = np.max(
get_edge_len(mesh.pos[mesh.edge_index[:, 0]], mesh.pos[mesh.edge_index[:, 1]])
)
else:
max_edge_len = max_edge_length

# create the grid2mesh bipartite graph
cartesian_grid = latlon2xyz(lat_lons)
n_nbrs = 4
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(mesh.pos)
distances, indices = neighbors.kneighbors(cartesian_grid)

src, dst = [], []
for i in range(len(cartesian_grid)):
for j in range(n_nbrs):
if distances[i][j] <= 0.6 * max_edge_len:
src.append(i)
dst.append(indices[i][j])
# This is in COO format now, and it is not bidirectional, so no copying
grid2mesh = HeteroData()
grid2mesh["grid"].pos = torch.tensor(cartesian_grid, dtype=torch.float)
grid2mesh["mesh"].pos = mesh.pos
grid2mesh["grid", "to", "mesh"].edge_index = torch.tensor([src, dst], dtype=torch.long)
# Add edge features
grid2mesh = add_edge_features(grid2mesh, (grid2mesh["grid"].pos, grid2mesh["mesh"].pos))
return grid2mesh


def generate_mesh_to_grid(lat_lons: torch.Tensor, mesh: Data):
# create the mesh2grid bipartite graph
cartesian_grid = latlon2xyz(lat_lons)
n_nbrs = 1
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(
mesh.pos
)
_, indices = neighbors.kneighbors(cartesian_grid)
indices = indices.flatten()

src = [
p
for i in indices
for p in mesh.pos[i]
]
dst = [i for i in range(len(cartesian_grid)) for _ in range(3)]

mesh2grid = HeteroData()
mesh2grid["mesh"].pos = mesh.pos
mesh2grid["grid"].pos = torch.tensor(cartesian_grid, dtype=torch.float)
mesh2grid["mesh", "to", "grid"].edge_index = torch.tensor([src, dst], dtype=torch.long)
# Add edge features
mesh2grid = add_edge_features(mesh2grid, (mesh2grid["mesh"].pos, mesh2grid["grid"].pos))

return mesh2grid


def plot_graph(graph: Data, **kwargs):
"""Plots the graph.
Expand Down

0 comments on commit d705712

Please sign in to comment.