From 943cdff7de5a926aa438c39894a463d41c8991fd Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Thu, 7 Dec 2023 16:28:13 +0000 Subject: [PATCH] Add correctly generating icosphere graph Still need to add mapping from input grid to icosphere, and back --- graph_weather/models/graphs/ico.py | 48 ++++++++++-------------------- 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/graph_weather/models/graphs/ico.py b/graph_weather/models/graphs/ico.py index 2b0e5d7d..7ea449ad 100644 --- a/graph_weather/models/graphs/ico.py +++ b/graph_weather/models/graphs/ico.py @@ -268,17 +268,9 @@ def inside_points(vAB, vAC): / \ / \ / \ / \ .---.---.---.---. """ - out = [] - if vAB.shape[0] == 1: # Optimized code fails for nu=2 - v = [] - for i in range(1, vAB.shape[0]): - w = np.arange(1, i + 1) / (i + 1) - for k in range(i): - v.append(w[-1 - k] * vAB[i, :] + w[k] * vAC[i, :]) - - return np.array(v).reshape(-1, 3) # reshape needed for empty return - for i in range(1, vAB.shape[0]): + u = vAB.shape[0] + for i in range(0 if u == 1 else 1, u): # Linearly interpolate between vABi and vACi in `i + 1` (`j`) steps, # not including the endpoints. # This could be written as @@ -291,8 +283,8 @@ def inside_points(vAB, vAC): j = i + 1 interp_multipliers = (np.arange(1, j) / j)[:, None] out.append( - np.multiply(interp_multipliers, vAC[i, None]) - + np.multiply(1 - interp_multipliers, vAB[i, None]) + np.multiply(interp_multipliers, vAC[i, None]) + + np.multiply(1 - interp_multipliers, vAB[i, None]) ) return np.concatenate(out) @@ -306,19 +298,7 @@ def generate_icosphere_graph(resolution=1): edges = np.unique(np.sort(edges, axis=1), axis=0) return vertices, edges -vertex, edges = generate_icosphere_graph(2) -print(f"Vertices: {vertex.shape}") -print(f"Edges: {edges.shape}") -print(f"Edges: {edges}") -print(f"Vertices: {vertex}") -vertex2, edges = generate_icosphere_graph(3) -print(f"Vertices: {vertex2.shape}") -print(f"Edges: {vertex2}") -# Check if the first 12 vertices are the same -print(f"Vertices: {np.isclose(vertex2[:12], vertex)}") -exit() - -def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 3, 4, 5, 6, 7)): +def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16, 32, 64)) -> Data: """ Generate mapping from lat/lon to icosphere index. @@ -332,7 +312,7 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 3, 4, 5, 6, 7)): Args: lat_lons: List of (lat,lon) points - resolutions: Icosphere resolution levels, first 7 levels correspond to Graphcast levels + resolutions: Icosphere resolution levels, first 7 levels correspond to Graphcast levels, in ascending order of resolution """ num_latlons = len(lat_lons) verticies_per_level = [] @@ -342,13 +322,17 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 3, 4, 5, 6, 7)): verticies_per_level.append(vertices) edges_per_level.append(edges) - # TODO Align the verticies so the same positions line up, and deduplicate the edges - # Each set of verticies is the same as the one before, but with more verticies, e.g. first 12 are always the same - # Should just need to deduplicate the edges, after combining them all - # TODO Create Data object where there is a minimal amount of verticies (the overlapping ones are the same) + # Check the verticies of each pair are the same up to the resolution + for i in range(len(verticies_per_level) - 1): + for vertex_lower_index, vertex in enumerate(verticies_per_level[i]): + vertex_mapping = np.argmin(np.sum(np.abs(verticies_per_level[i + 1] - vertex), axis=1), axis=0) + # Change all edge indicies from vertex_lower_index to vertex_mapping + edges_per_level[i + 1][edges_per_level[i + 1] == vertex_lower_index] = vertex_mapping + verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above + edges = np.unique(np.sort(np.concatenate(edges_per_level), axis=1), axis=0) # TODO Create mapping from the lat/lon to the icosphere nodes - - return h3_mapping, h3_distances + ico_graph = Data(x=torch.tensor(verticies, dtype=torch.float), edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous()) + return ico_graph def generate_latent_ico_graph(h3_mapping, h3_distances):