Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 8, 2023
1 parent f2eb1b3 commit 9cdb87f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 24 deletions.
41 changes: 27 additions & 14 deletions graph_weather/models/graphs/ico.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from sklearn.neighbors import NearestNeighbors



def icosphere(nu=1, nr_verts=None):
"""
Returns a geodesic icosahedron with subdivision frequency nu. Frequency
Expand Down Expand Up @@ -323,6 +322,7 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti
from torch_geometric.data import Data
from graph_weather.models.graphs.utils import deg2rad, latlon2xyz, xyz2latlon, get_edge_len
from sklearn.neighbors import NearestNeighbors

num_latlons = len(lat_lons)
verticies_per_level = []
edges_per_level = []
Expand All @@ -333,7 +333,7 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti

# Check the verticies of each pair are the same up to the resolution
for i in range(len(verticies_per_level) - 2):
#print(edges_per_level[i])
# print(edges_per_level[i])
for vertex_lower_index, vertex in enumerate(verticies_per_level[i]):
# Go through each index in the current level, finding the closest vertex in the next level
# Should check all verticies in the next level, and find the closest one
Expand All @@ -347,14 +347,14 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti
edges_per_level[i][edge_index][0] = vertex_upper_index
if edge[1] == vertex_lower_index:
edges_per_level[i][edge_index][1] = vertex_upper_index
#print(edges_per_level[i][edges_per_level[i] == vertex_lower_index])
# print(edges_per_level[i][edges_per_level[i] == vertex_lower_index])
# The vertex is the same, so the edges in the current level that equal vertex_lower_index should be changed to equal vertex_upper_index
#edges_per_level[i][edges_per_level[i] == vertex_upper_index] = vertex_upper_index
# edges_per_level[i][edges_per_level[i] == vertex_upper_index] = vertex_upper_index
if multiple_equals > 1:
print(f"Multiple equals: {multiple_equals}")
#print(edges_per_level[i])
#print("------------------")
verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above
# print(edges_per_level[i])
# print("------------------")
verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above
edges = np.sort(np.concatenate(edges_per_level), axis=1)
print(f"Number of edges: {len(edges)}")
if bidirectional:
Expand All @@ -363,16 +363,24 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti
# TODO Create mapping from the lat/lon to the icosphere nodes
print(f"Now Number of edges: {len(edges)}")
print(f"Number of unique edges: {len(np.unique(edges, axis=0))}")
print(f"Max number of repeated edges: {np.max(np.unique(edges, axis=0, return_counts=True)[1])}")
print(
f"Max number of repeated edges: {np.max(np.unique(edges, axis=0, return_counts=True)[1])}"
)
u, c = np.unique(edges, axis=0, return_counts=True)
print(f"First 20 duplicates: {u[c > 1][:20]}")
xyz_grid = latlon2xyz(torch.tensor(lat_lons, dtype=torch.float))
# Find the closest vertex to each point
vertex_mapping = np.argmin(np.sum(np.abs(verticies - xyz_grid[:, None]), axis=2), axis=1)
# Features will need to be in the same lat/lon order as given, and added to the verticies
ico_graph = Data(pos=torch.tensor(verticies, dtype=torch.float),
edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous())
max_edge_len = np.max(get_edge_len(ico_graph.pos[edges_per_level[-1][:, 0]], ico_graph.pos[edges_per_level[-1][:, 1]]))
ico_graph = Data(
pos=torch.tensor(verticies, dtype=torch.float),
edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous(),
)
max_edge_len = np.max(
get_edge_len(
ico_graph.pos[edges_per_level[-1][:, 0]], ico_graph.pos[edges_per_level[-1][:, 1]]
)
)
# create the grid2mesh bipartite graph
cartesian_grid = latlon2xyz(lat_lons)
n_nbrs = 4
Expand All @@ -389,10 +397,14 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirecti
ico_graph.validate(raise_on_error=True)
return ico_graph

generate_icosphere_mapping([(0,0), (0,1), (1,0), (1,1)])
def get_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data):

max_edge_len = np.max(get_edge_len(mesh.pos[mesh.edge_index[:,0]], mesh.pos[mesh.edge_index[:,1]]))
generate_icosphere_mapping([(0, 0), (0, 1), (1, 0), (1, 1)])


def get_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data):
max_edge_len = np.max(
get_edge_len(mesh.pos[mesh.edge_index[:, 0]], mesh.pos[mesh.edge_index[:, 1]])
)

# create the grid2mesh bipartite graph
cartesian_grid = latlon2xyz(lat_lons)
Expand All @@ -407,6 +419,7 @@ def get_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data):
src.append(i)
dst.append(indices[i][j])


def generate_latent_ico_graph(h3_mapping, h3_distances):
"""
Generate latent h3 graph.
Expand Down
15 changes: 5 additions & 10 deletions graph_weather/models/graphs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def rad2deg(rad):
"""
return rad * 180 / np.pi


def azimuthal_angle(lon: Tensor) -> Tensor:
"""
Gives the azimuthal angle of a point on the sphere
Expand Down Expand Up @@ -179,9 +180,7 @@ def polar_angle(lat: Tensor) -> Tensor:
return angle


def geospatial_rotation(
invar: Tensor, theta: Tensor, axis: str, unit: str = "rad"
) -> Tensor:
def geospatial_rotation(invar: Tensor, theta: Tensor, axis: str, unit: str = "rad") -> Tensor:
"""Rotation using right hand rule
Parameters
Expand Down Expand Up @@ -296,9 +295,7 @@ def add_edge_features(graph: Data, pos: Tensor, normalize: bool = True) -> Data:
# normalize using the longest edge
if normalize:
max_disp_norm = torch.max(disp_norm)
graph["edge_attr"] = torch.cat(
(disp / max_disp_norm, disp_norm / max_disp_norm), dim=-1
)
graph["edge_attr"] = torch.cat((disp / max_disp_norm, disp_norm / max_disp_norm), dim=-1)
else:
graph["edge_attr"] = torch.cat((disp, disp_norm), dim=-1)
return graph
Expand All @@ -322,9 +319,7 @@ def add_node_features(graph: Data, pos: Tensor) -> Data:
"""
latlon = xyz2latlon(pos)
lat, lon = latlon[:, 0], latlon[:, 1]
graph["x"] = torch.stack(
(torch.cos(lat), torch.sin(lon), torch.cos(lon)), dim=-1
)
graph["x"] = torch.stack((torch.cos(lat), torch.sin(lon), torch.cos(lon)), dim=-1)
return graph


Expand Down Expand Up @@ -383,5 +378,5 @@ def _format_axes(ax):
for angle in range(0, 360):
ax.view_init(30, angle)
plt.draw()
plt.pause(.01)
plt.pause(0.01)
plt.show()

0 comments on commit 9cdb87f

Please sign in to comment.