diff --git a/graph_weather/data/utils.py b/graph_weather/data/utils.py
new file mode 100644
index 00000000..e69de29b
diff --git a/graph_weather/models/graphs/__init__.py b/graph_weather/models/graphs/__init__.py
new file mode 100644
index 00000000..71d50284
--- /dev/null
+++ b/graph_weather/models/graphs/__init__.py
@@ -0,0 +1 @@
+"""Set of graph classes for generating different meshes"""
\ No newline at end of file
diff --git a/graph_weather/models/graphs/hexagonal.py b/graph_weather/models/graphs/hexagonal.py
new file mode 100644
index 00000000..00ee82f5
--- /dev/null
+++ b/graph_weather/models/graphs/hexagonal.py
@@ -0,0 +1,17 @@
+"""Generate a hexagonal global grid using Uber's H3 library."""
+import h3
+import numpy as np
+
+
+def generate_hexagonal_grid(resolution: int = 2) -> tuple:
+    """Generate a hexagonal global grid using Uber's H3 library.
+
+    Args:
+        resolution: H3 resolution level
+
+    Returns:
+        Tuple of the sorted H3 cell indices and a dict mapping each index to its position
+    """
+    base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution)))
+    base_h3_map = {h_i: i for i, h_i in enumerate(base_h3_grid)}
+    return np.array(base_h3_grid), base_h3_map
diff --git a/graph_weather/models/graphs/ico.py b/graph_weather/models/graphs/ico.py
new file mode 100644
index 00000000..e3f25fcd
--- /dev/null
+++ b/graph_weather/models/graphs/ico.py
@@ -0,0 +1,275 @@
+'''
+Creating a geodesic icosahedron with a given (integer) subdivision frequency (and
+not by recursively applying Loop-like subdivision).
+
+The advantage of subdivision frequency compared to recursive subdivision is in
+controlling the mesh resolution. Mesh resolution grows quadratically with
+subdivision frequency, while it grows exponentially with iterations of
+recursive subdivision. To be precise, using recursive
+subdivision (each iteration being a subdivision with frequency nu=2), the
+possible number of vertices grows with iterations i as
+  [12+10*(2**i+1)*(2**i-1) for i in range(10)]
+which gives
+  [12, 42, 162, 642, 2562, 10242, 40962, 163842, 655362, 2621442].
+Notice, for example, that there is no mesh having between 2562 and 10242 vertices.
+Using subdivision frequency, the possible number of vertices grows with nu as
+  [12+10*(nu+1)*(nu-1) for nu in range(1,33)]
+which gives
+  [12, 42, 92, 162, 252, 362, 492, 642, 812, 1002, 1212, 1442, 1692, 1962,
+  2252, 2562, 2892, 3242, 3612, 4002, 4412, 4842, 5292, 5762, 6252, 6762,
+  7292, 7842, 8412, 9002, 9612, 10242]
+where nu = 32 gives 10242 vertices, and there are 15 meshes having between
+2562 and 10242 vertices. The advantage is even more pronounced when using
+higher resolutions.
+
+Author: vand@dtu.dk, 2014, 2017, 2021.
+Originally developed in connection with
+https://ieeexplore.ieee.org/document/7182720
+
+This code is copied in because there is an improvement to the inside_points function that
+has not been merged upstream and that speeds up generation 5-8x. See https://github.com/vedranaa/icosphere/pull/3
+
+'''
+
+import numpy as np
+
+
+def icosphere(nu=1, nr_verts=None):
+    '''
+    Returns a geodesic icosahedron with subdivision frequency nu. Frequency
+    nu = 1 returns a regular unit icosahedron, and nu > 1 performs subdivision.
+    If nr_verts is given, nu will be adjusted such that the icosphere contains
+    at least nr_verts vertices. Returned faces are zero-indexed!
+
+    Parameters
+    ----------
+    nu : subdivision frequency, integer (larger than 1 to make a change).
+    nr_verts: desired number of mesh vertices, if given, nu may be increased.
+
+
+    Returns
+    -------
+    subvertices : vertex list, numpy array of shape (12+10*(nu+1)*(nu-1), 3)
+    subfaces : face list, numpy array of shape (20*nu**2, 3)
+
+    '''
+
+    # Unit icosahedron
+    (vertices, faces) = icosahedron()
+
+    # If nr_verts is given, compute an appropriate subdivision frequency nu.
+    # We know nr_verts = 12+10*(nu+1)*(nu-1)
+    if nr_verts is not None:
+        nu_min = np.ceil(np.sqrt(max(1 + (nr_verts - 12) / 10, 1)))
+        nu = max(nu, nu_min)
+
+    # Subdividing
+    if nu > 1:
+        (vertices, faces) = subdivide_mesh(vertices, faces, nu)
+        vertices = vertices / np.sqrt(np.sum(vertices ** 2, axis=1, keepdims=True))
+
+    return (vertices, faces)
+
+
+def icosahedron():
+    ''' Regular unit icosahedron. '''
+
+    # 12 principal directions in 3D space: points on a unit icosahedron
+    phi = (1 + np.sqrt(5)) / 2
+    vertices = np.array([
+        [0, 1, phi], [0, -1, phi], [1, phi, 0],
+        [-1, phi, 0], [phi, 0, 1], [-phi, 0, 1]]) / np.sqrt(1 + phi ** 2)
+    vertices = np.r_[vertices, -vertices]
+
+    # 20 faces
+    faces = np.array([
+        [0, 5, 1], [0, 3, 5], [0, 2, 3], [0, 4, 2], [0, 1, 4],
+        [1, 5, 8], [5, 3, 10], [3, 2, 7], [2, 4, 11], [4, 1, 9],
+        [7, 11, 6], [11, 9, 6], [9, 8, 6], [8, 10, 6], [10, 7, 6],
+        [2, 11, 7], [4, 9, 11], [1, 8, 9], [5, 10, 8], [3, 7, 10]], dtype=int)
+
+    return (vertices, faces)
+
+
+def subdivide_mesh(vertices, faces, nu):
+    '''
+    Subdivides a mesh by adding vertices on mesh edges and faces. Each edge
+    will be divided into nu segments. (For example, for nu=2 one vertex is added
+    on each mesh edge, for nu=3 two vertices are added on each mesh edge and
+    one vertex is added on each face.) If V and F are the number of mesh vertices
+    and the number of mesh faces of the input mesh, the subdivided mesh contains
+    V + F*(nu+1)*(nu-1)/2 vertices and F*nu**2 faces.
+
+    Parameters
+    ----------
+    vertices : vertex list, numpy array of shape (V,3)
+    faces : face list, numpy array of shape (F,3). Zero indexed.
+    nu : subdivision frequency, integer (larger than 1 to make a change).
+
+    Returns
+    -------
+    subvertices : vertex list, numpy array of shape (V + F*(nu+1)*(nu-1)/2, 3)
+    subfaces : face list, numpy array of shape (F*nu**2, 3)
+
+    Author: vand at dtu.dk, 8.12.2017. Translated to python 6.4.2021
+
+    '''
+
+    edges = np.r_[faces[:, :-1], faces[:, 1:], faces[:, [0, 2]]]
+    edges = np.unique(np.sort(edges, axis=1), axis=0)
+    F = faces.shape[0]
+    V = vertices.shape[0]
+    E = edges.shape[0]
+    subfaces = np.empty((F * nu ** 2, 3), dtype=int)
+    subvertices = np.empty((V + E * (nu - 1) + F * (nu - 1) * (nu - 2) // 2, 3))
+
+    subvertices[:V] = vertices
+
+    # Dictionary for accessing edge index from indices of edge vertices.
+    edge_indices = dict()
+    for i in range(V):
+        edge_indices[i] = dict()
+    for i in range(E):
+        edge_indices[edges[i, 0]][edges[i, 1]] = i
+        edge_indices[edges[i, 1]][edges[i, 0]] = -i
+
+    template = faces_template(nu)
+    ordering = vertex_ordering(nu)
+    reordered_template = ordering[template]
+
+    # At this point we have V vertices, and now we add (nu-1) vertices per edge
+    # (on-edge vertices).
+    w = np.arange(1, nu) / nu  # interpolation weights
+    for e in range(E):
+        edge = edges[e]
+        for k in range(nu - 1):
+            subvertices[V + e * (nu - 1) + k] = (w[-1 - k] * vertices[edge[0]]
+                                                 + w[k] * vertices[edge[1]])
+
+    # At this point we have E*(nu-1)+V vertices, and we add (nu-1)*(nu-2)/2
+    # vertices per face (on-face vertices).
+    r = np.arange(nu - 1)
+    for f in range(F):
+        # First, fixing connectivity. We get hold of the indices of all
+        # vertices involved in this subface: original, on-edges and on-faces.
+        T = np.arange(f * (nu - 1) * (nu - 2) // 2 + E * (nu - 1) + V,
+                      (f + 1) * (nu - 1) * (nu - 2) // 2 + E * (nu - 1) + V)  # will be added
+        eAB = edge_indices[faces[f, 0]][faces[f, 1]]
+        eAC = edge_indices[faces[f, 0]][faces[f, 2]]
+        eBC = edge_indices[faces[f, 1]][faces[f, 2]]
+        AB = reverse(abs(eAB) * (nu - 1) + V + r, eAB < 0)  # already added
+        AC = reverse(abs(eAC) * (nu - 1) + V + r, eAC < 0)  # already added
+        BC = reverse(abs(eBC) * (nu - 1) + V + r, eBC < 0)  # already added
+        VEF = np.r_[faces[f], AB, AC, BC, T]
+        subfaces[f * nu ** 2:(f + 1) * nu ** 2, :] = VEF[reordered_template]
+        # Now geometry, computing positions of face vertices.
+        subvertices[T, :] = inside_points(subvertices[AB, :], subvertices[AC, :])
+
+    return (subvertices, subfaces)
+
+
+def reverse(vector, flag):
+    ''' For reversing the direction of an edge. '''
+
+    if flag:
+        vector = vector[::-1]
+    return (vector)
+
+
+def faces_template(nu):
+    '''
+    Template for linking subfaces                  0
+    in a subdivision of a face.                   / \
+    Returns faces with vertex                    1---2
+    indexing given by reading order             / \ / \
+    (as illustrated).                          3---4---5
+                                              / \ / \ / \
+                                             6---7---8---9
+                                            / \ / \ / \ / \
+                                           10--11--12--13--14
+    '''
+
+    faces = []
+    # looping in layers of triangles
+    for i in range(nu):
+        vertex0 = i * (i + 1) // 2
+        skip = i + 1
+        for j in range(i):  # adding pairs of triangles, will not run for i==0
+            faces.append([j + vertex0, j + vertex0 + skip, j + vertex0 + skip + 1])
+            faces.append([j + vertex0, j + vertex0 + skip + 1, j + vertex0 + 1])
+        # adding the last (unpaired, rightmost) triangle
+        faces.append([i + vertex0, i + vertex0 + skip, i + vertex0 + skip + 1])
+
+    return (np.array(faces))
+
+
+def vertex_ordering(nu):
+    '''
+    Permutation for ordering of                    0
+    face vertices which transforms                / \
+    reading-order indexing into indexing         3---6
+    first corner vertices, then on-edge         / \ / \
+    vertices, and then on-face vertices        4---12--7
+    (as illustrated).                         / \ / \ / \
+                                             5---13--14--8
+                                            / \ / \ / \ / \
+                                           1---9--10--11---2
+    '''
+
+    left = [j for j in range(3, nu + 2)]
+    right = [j for j in range(nu + 2, 2 * nu + 1)]
+    bottom = [j for j in range(2 * nu + 1, 3 * nu)]
+    inside = [j for j in range(3 * nu, (nu + 1) * (nu + 2) // 2)]
+
+    o = [0]  # topmost corner
+    for i in range(nu - 1):
+        o.append(left[i])
+        o = o + inside[i * (i - 1) // 2:i * (i + 1) // 2]
+        o.append(right[i])
+    o = o + [1] + bottom + [2]
+
+    return (np.array(o))
+
+
+def inside_points(vAB, vAC):
+    """
+    Returns coordinates of the inside             .
+    (on-face) vertices (marked by star)          / \
+    for subdivision of the face ABC when      vAB0---vAC0
+    given coordinates of the on-edge           / \   / \
+    vertices AB[i] and AC[i].               vAB1---*---vAC1
+                                             / \ / \ / \
+                                         vAB2---*---*---vAC2
+                                          / \ / \ / \ / \
+                                          .---.---.---.---.
+    """
+
+    out = []
+    for i in range(1, vAB.shape[0]):
+        # Linearly interpolate between vABi and vACi in `i + 1` (`j`) steps,
+        # not including the endpoints.
+        # This could be written as
+        #   vABi = vAB[i, :]
+        #   vACi = vAC[i, :]
+        #   interp_multipliers = np.arange(1, j) / j
+        #   res = np.outer(interp_multipliers, vACi) + np.outer(1 - interp_multipliers, vABi)
+        # but that would involve some extra work on `np.outer`'s part that we can
+        # do ourselves since we know the shapes we're working with.
+        j = i + 1
+        interp_multipliers = (np.arange(1, j) / j)[:, None]
+        out.append(
+            np.multiply(interp_multipliers, vAC[i, None])
+            + np.multiply(1 - interp_multipliers, vAB[i, None])
+        )
+    return np.concatenate(out)
+
+
+def generate_icosphere_graph(resolution=1):
+    """
+    Generate a graph of the icosphere with the given level of subdivision.
+    """
+    vertices, faces = icosphere(resolution)
+    edges = np.r_[faces[:, :-1], faces[:, 1:], faces[:, [0, 2]]]
+    edges = np.unique(np.sort(edges, axis=1), axis=0)
+    # TODO: Convert vertices and edges into PyTorch tensors and return them
+    raise NotImplementedError("TODO: Make into PyTorch Tensors and return vertices, edges")
diff --git a/train/deepspeed_graph.py b/train/deepspeed_graph.py
deleted file mode 100644
index 23ede888..00000000
--- a/train/deepspeed_graph.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import pytorch_lightning as pl
-import torch
-from pytorch_lightning import Trainer
-
-from graph_weather import GraphWeatherForecaster
-
-lat_lons = []
-for lat in range(-90, 90, 1):
-    for lon in range(0, 360, 1):
-        lat_lons.append((lat, lon))
-
-
-class LitModel(pl.LightningModule):
-    def __init__(self, lat_lons, feature_dim, aux_dim):
-        super().__init__()
-        self.model = GraphWeatherForecaster(
-            lat_lons=lat_lons, feature_dim=feature_dim, aux_dim=aux_dim
-        )
-
-    def training_step(self, batch):
-        x, y = batch
-        x = x.half()
-        y = y.half()
-        out = self.forward(x)
-        criterion = torch.nn.MSELoss()
-        loss = criterion(out, y)
-        return loss
-
-    def configure_optimizers(self):
-        return torch.optim.AdamW(self.parameters())
-
-    def forward(self, x):
-        return self.model(x)
-
-
-# Fake data
-from torch.utils.data import DataLoader, Dataset
-
-
-class FakeDataset(Dataset):
-    def __init__(self):
-        super(FakeDataset, self).__init__()
-
-    def __len__(self):
-        return 64000
-
-    def __getitem__(self, item):
-        return torch.randn((64800, 605 + 32)), torch.randn((64800, 605))
-
-
-model = LitModel(lat_lons=lat_lons, feature_dim=605, aux_dim=32)
-trainer = Trainer(
-    accelerator="gpu",
-    devices=1,
-    strategy="deepspeed_stage_3_offload",
-    precision=16,
-    max_epochs=10,
-    limit_train_batches=2000,
-)
-dataset = FakeDataset()
-train_dataloader = DataLoader(
-    dataset, batch_size=1, num_workers=1, pin_memory=True, prefetch_factor=1
-)
-trainer.fit(model=model, train_dataloaders=train_dataloader)
diff --git a/train/pl_graph_weather.py b/train/pl_graph_weather.py
deleted file mode 100644
index 40ed7f21..00000000
--- a/train/pl_graph_weather.py
+++ /dev/null
@@ -1,355 +0,0 @@
-"""PyTorch Lightning training script for the weather forecasting model"""
-import click
-import datasets
-import numpy as np
-import pandas as pd
-import pytorch_lightning as pl
-import torch
-from pysolar.util import extraterrestrial_irrad
-from pytorch_lightning.callbacks import ModelCheckpoint
-from torch.utils.data import DataLoader
-
-from graph_weather import GraphWeatherForecaster
-from graph_weather.data import const
-from graph_weather.models.losses import NormalizedMSELoss
-
-const.FORECAST_MEANS = {var: np.asarray(value) for var, value in const.FORECAST_MEANS.items()}
-const.FORECAST_STD = {var: np.asarray(value) for var, value in const.FORECAST_STD.items()}
-
-
-def worker_init_fn(worker_id):
-    np.random.seed(np.random.get_state()[1][0] + worker_id)
-
-
-def get_mean_stds():
-    names = [
-        "CLMR",
-        "GRLE",
-        "VVEL",
-        "VGRD",
-        "UGRD",
-        "O3MR",
-        "CAPE",
-        "TMP",
-        "PLPL",
-        "DZDT",
-        "CIN",
-        "HGT",
-        "RH",
-        "ICMR",
-        "SNMR",
-        "SPFH",
-        "RWMR",
-        "TCDC",
-        "ABSV",
-    ]
-    means = {}
-    stds = {}
-    # For pressure level values
-    for n in names:
-        if (
-            len(
sorted( - [ - float(var.split(".", 1)[-1].split("_")[0]) - for var in const.FORECAST_MEANS - if "mb" in var and n in var and "-" not in var - ] - ) - ) - > 0 - ): - means[n + "_mb"] = [] - stds[n + "_mb"] = [] - for value in sorted( - [ - float(var.split(".", 1)[-1].split("_")[0]) - for var in const.FORECAST_MEANS - if "mb" in var and n in var and "-" not in var - ] - ): - # Is floats now, but will be fixed - if value >= 1: - value = int(value) - var_name = f"{n}.{value}_mb" - # print(var_name) - - means[n + "_mb"].append(const.FORECAST_MEANS[var_name]) - stds[n + "_mb"].append(const.FORECAST_STD[var_name]) - means[n + "_mb"] = np.mean(np.stack(means[n + "_mb"], axis=-1)) - stds[n + "_mb"] = np.mean(np.stack(stds[n + "_mb"], axis=-1)) - - # For surface values - for n in list( - set( - [ - var.split(".", 1)[0] - for var in const.FORECAST_MEANS - if "surface" in var - and "level" not in var - and "2e06" not in var - and "below" not in var - and "atmos" not in var - and "tropo" not in var - and "iso" not in var - and "planetary_boundary_layer" not in var - ] - ) - ): - means[n] = const.FORECAST_MEANS[n + ".surface"] - stds[n] = const.FORECAST_STD[n + ".surface"] - - # For Cloud levels - for n in list( - set( - [ - var.split(".", 1)[0] - for var in const.FORECAST_MEANS - if "sigma" not in var - and "level" not in var - and "2e06" not in var - and "below" not in var - and "atmos" not in var - and "tropo" not in var - and "iso" not in var - and "planetary_boundary_layer" not in var - ] - ) - ): - if "LCDC" in n: # or "MCDC" in n or "HCDC" in n: - means[n] = const.FORECAST_MEANS["LCDC.low_cloud_layer"] - stds[n] = const.FORECAST_STD["LCDC.low_cloud_layer"] - if "MCDC" in n: # or "HCDC" in n: - means[n] = const.FORECAST_MEANS["MCDC.middle_cloud_layer"] - stds[n] = const.FORECAST_STD["MCDC.middle_cloud_layer"] - if "HCDC" in n: - means[n] = const.FORECAST_MEANS["HCDC.high_cloud_layer"] - stds[n] = const.FORECAST_STD["HCDC.high_cloud_layer"] - - # Now for each of these - means["max_wind"] = [] - stds["max_wind"] = [] - for n in sorted([var for var in const.FORECAST_MEANS if "max_wind" in var]): - means["max_wind"].append(const.FORECAST_MEANS[n]) - stds["max_wind"].append(const.FORECAST_STD[n]) - means["max_wind"] = np.stack(means["max_wind"], axis=-1) - stds["max_wind"] = np.stack(stds["max_wind"], axis=-1) - - for i in [2, 10, 20, 30, 40, 50, 80, 100]: - means[f"{i}m_above_ground"] = [] - stds[f"{i}m_above_ground"] = [] - for n in sorted([var for var in const.FORECAST_MEANS if f"{i}_m_above_ground" in var]): - means[f"{i}m_above_ground"].append(const.FORECAST_MEANS[n]) - stds[f"{i}m_above_ground"].append(const.FORECAST_STD[n]) - means[f"{i}m_above_ground"] = np.stack(means[f"{i}m_above_ground"], axis=-1) - stds[f"{i}m_above_ground"] = np.stack(stds[f"{i}m_above_ground"], axis=-1) - return means, stds - - -means, stds = get_mean_stds() - - -def process_data(data): - data.update( - { - key: np.expand_dims(np.asarray(value), axis=-1) - for key, value in data.items() - if key.replace("current_", "").replace("next_", "") in means.keys() - and np.asarray(value).ndim == 2 - } - ) # Add third dimension for ones with 2 - input_data = { - key.replace("current_", ""): torch.from_numpy( - (value - means[key.replace("current_", "")]) / stds[key.replace("current_", "")] - ) - for key, value in data.items() - if "current" in key and "time" not in key - } - output_data = { - key.replace("next_", ""): torch.from_numpy( - (value - means[key.replace("next_", "")]) / stds[key.replace("next_", "")] - ) - for key, value 
in data.items() - if "next" in key and "time" not in key - } - lat_lons = np.array( - np.meshgrid(np.asarray(data["latitude"]).flatten(), np.asarray(data["longitude"]).flatten()) - ).T.reshape((-1, 2)) - sin_lat_lons = np.sin(lat_lons * np.pi / 180.0) - cos_lat_lons = np.cos(lat_lons * np.pi / 180.0) - date = pd.to_datetime(data["timestamps"][0], utc=True) - solar_times = [ - np.array( - [ - extraterrestrial_irrad( - when=date.to_pydatetime(), latitude_deg=lat, longitude_deg=lon - ) - for lat, lon in lat_lons - ] - ) - ] - for when in pd.date_range( - date - pd.Timedelta("12 hours"), date + pd.Timedelta("12 hours"), freq="1H" - ): - solar_times.append( - np.array( - [ - extraterrestrial_irrad( - when=when.to_pydatetime(), latitude_deg=lat, longitude_deg=lon - ) - for lat, lon in lat_lons - ] - ) - ) - solar_times = np.array(solar_times) - # Normalize to between -1 and 1 - solar_times -= const.SOLAR_MEAN - solar_times /= const.SOLAR_STD - input_data = torch.concat([value for _, value in input_data.items()], dim=-1) - output_data = torch.concat([value for _, value in output_data.items()], dim=-1) - input_data = input_data.transpose(0, 1).reshape(-1, input_data.shape[-1]) - output_data = output_data.transpose(0, 1).reshape(-1, input_data.shape[-1]) - day_of_year = pd.to_datetime(data["timestamps"][0], utc=True).dayofyear / 366.0 - sin_of_year = np.ones_like(lat_lons)[:, 0] * np.sin(day_of_year) - cos_of_year = np.ones_like(lat_lons)[:, 0] * np.cos(day_of_year) - to_concat = [ - input_data, - torch.permute(torch.from_numpy(solar_times), (1, 0)), - torch.from_numpy(sin_lat_lons), - torch.from_numpy(cos_lat_lons), - torch.from_numpy(np.expand_dims(sin_of_year, axis=-1)), - torch.from_numpy(np.expand_dims(cos_of_year, axis=-1)), - ] # , landsea_fixed] - input_data = torch.concat(to_concat, dim=-1) - new_data = { - "input": input_data.float().numpy(), - "output": output_data.float().numpy(), - "has_nans": not np.isnan(input_data.float().numpy()).any() - and not np.isnan(output_data.float().numpy()).any(), - } - return new_data - - -class GraphDataModule(pl.LightningDataModule): - def __init__(self, deg: str = "2.0", batch_size: int = 1): - super().__init__() - self.batch_size = batch_size - self.dataset = datasets.load_dataset( - "openclimatefix/gfs-surface-pressure-2deg", split="train+validation", streaming=False - ) - features = datasets.Features( - { - "input": datasets.Array2D(shape=(16380, 637), dtype="float32"), - "output": datasets.Array2D(shape=(16380, 605), dtype="float32"), - "has_nans": datasets.Value("bool"), - } - ) - self.dataset = ( - self.dataset.map( - process_data, - remove_columns=self.dataset.column_names, - features=features, - num_proc=16, - writer_batch_size=2, - ) - .filter(lambda x: x["has_nans"]) - .with_format("torch") - ) - - def train_dataloader(self): - return DataLoader(self.dataset, batch_size=self.batch_size, num_workers=2) - - -class LitGraphForecaster(pl.LightningModule): - def __init__( - self, - lat_lons: list, - feature_dim: int = 605, - aux_dim: int = 32, - hidden_dim: int = 64, - num_blocks: int = 3, - lr: float = 3e-4, - ): - super().__init__() - self.model = GraphWeatherForecaster( - lat_lons, - feature_dim=feature_dim, - aux_dim=aux_dim, - hidden_dim_decoder=hidden_dim, - hidden_dim_processor_node=hidden_dim, - hidden_dim_processor_edge=hidden_dim, - num_blocks=num_blocks, - ) - self.criterion = NormalizedMSELoss( - lat_lons=lat_lons, feature_variance=np.ones((feature_dim,)) - ) - self.lr = lr - self.save_hyperparameters() - - def forward(self, x): - 
return self.model(x) - - def training_step(self, batch, batch_idx): - x, y = batch["input"], batch["output"] - if torch.isnan(x).any() or torch.isnan(y).any(): - return None - y_hat = self.forward(x) - loss = self.criterion(y_hat, y) - return loss - - def configure_optimizers(self): - return torch.optim.AdamW(self.parameters(), lr=self.lr) - - -@click.command() -@click.option( - "--num-blocks", - default=5, - help="Where to save the zarr files", - type=click.INT, -) -@click.option( - "--hidden", - default=32, - help="Where to save the zarr files", - type=click.INT, -) -@click.option( - "--batch", - default=1, - help="Where to save the zarr files", - type=click.INT, -) -@click.option( - "--gpus", - default=1, - help="Where to save the zarr files", - type=click.INT, -) -def run(num_blocks, hidden, batch, gpus): - hf_ds = datasets.load_dataset( - "openclimatefix/gfs-surface-pressure-2deg", split="train", streaming=False - ) - example_batch = next(iter(hf_ds)) - lat_lons = np.array( - np.meshgrid( - np.asarray(example_batch["latitude"]).flatten(), - np.asarray(example_batch["longitude"]).flatten(), - ) - ).T.reshape((-1, 2)) - checkpoint_callback = ModelCheckpoint(dirpath="./", save_top_k=2, monitor="loss") - dset = GraphDataModule(batch_size=batch) - model = LitGraphForecaster(lat_lons=lat_lons, num_blocks=num_blocks, hidden_dim=hidden) - trainer = pl.Trainer( - accelerator="gpu", - devices=gpus, - max_epochs=100, - precision=16, - callbacks=[checkpoint_callback], - ) - # strategy="deepspeed_stage_2_offload") - trainer.fit(model, dset) - - -if __name__ == "__main__": - run() diff --git a/train/run_fulll.py b/train/run_fulll.py deleted file mode 100644 index fdf65bc5..00000000 --- a/train/run_fulll.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Training script for training the weather forecasting model""" -import json - -import numpy as np -import torch -import torch.optim as optim -import torchvision.transforms as transforms -import xarray as xr -from torch.utils.data import DataLoader, Dataset - -from graph_weather import GraphWeatherForecaster -from graph_weather.data import const -from graph_weather.models.losses import NormalizedMSELoss - - -class XrDataset(Dataset): - def __init__(self): - super().__init__() - with open("hf_forecasts.json", "r") as f: - files = json.load(f) - self.filepaths = [ - "zip:///::https://huggingface.co/datasets/openclimatefix/gfs-reforecast/resolve/main/" - + f - for f in files - ] - self.data = xr.open_mfdataset( - self.filepaths, engine="zarr", concat_dim="reftime", combine="nested" - ).sortby("reftime") - - def __len__(self): - return len(self.filepaths) - - def __getitem__(self, item): - start_idx = np.random.randint(0, 15) - data = self.data.isel(reftime=item, time=slice(start_idx, start_idx + 1)) - - start = data.isel(time=0) - end = data.isel(time=1) - # Stack the data into a large data cube - input_data = np.stack( - [ - (start[f"{var}"].values - const.FORECAST_MEANS[f"{var}"]) - / (const.FORECAST_STD[f"{var}"] + 0.0001) - for var in start.data_vars - if "mb" in var or "surface" in var - ], - axis=-1, - ) - input_data = np.nan_to_num(input_data) - assert not np.isnan(input_data).any() - output_data = np.stack( - [ - (end[f"{var}"].values - const.FORECAST_MEANS[f"{var}"]) - / (const.FORECAST_STD[f"{var}"] + 0.0001) - for var in end.data_vars - if "mb" in var or "surface" in var - ], - axis=-1, - ) - output_data = np.nan_to_num(output_data) - assert not np.isnan(output_data).any() - transform = transforms.Compose([transforms.ToTensor()]) - # Normalize now - 
return ( - transform(input_data).transpose(0, 1).reshape(-1, input_data.shape[-1]), - transform(output_data).transpose(0, 1).reshape(-1, input_data.shape[-1]), - ) - - -with open("hf_forecasts.json", "r") as f: - files = json.load(f) -files = [ - "zip:///::https://huggingface.co/datasets/openclimatefix/gfs-reforecast/resolve/main/" + f - for f in files -] -data = ( - xr.open_zarr(files[0], consolidated=True).isel(time=0) - # .coarsen(latitude=8, boundary="pad") - # .mean() - # .coarsen(longitude=8) - # .mean() -) -print(data) -# print("Done coarsening") -lat_lons = np.array(np.meshgrid(data.latitude.values, data.longitude.values)).T.reshape(-1, 2) -device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu") -# Get the variance of the variables -feature_variances = [] -for var in data.data_vars: - if "mb" in var or "surface" in var: - feature_variances.append(const.FORECAST_DIFF_STD[var] ** 2) -criterion = NormalizedMSELoss( - lat_lons=lat_lons, feature_variance=feature_variances, device=device -).to(device) -means = [] -dataset = DataLoader(XrDataset(), batch_size=1, num_workers=32) -model = GraphWeatherForecaster(lat_lons, feature_dim=597, num_blocks=6).to(device) -optimizer = optim.AdamW(model.parameters(), lr=0.001) -print("Done Setup") -import time - -for epoch in range(100): # loop over the dataset multiple times - running_loss = 0.0 - start = time.time() - print(f"Start Epoch: {epoch}") - for i, data in enumerate(dataset): - # get the inputs; data is a list of [inputs, labels] - inputs, labels = data[0].to(device), data[1].to(device) - # zero the parameter gradients - optimizer.zero_grad() - - # forward + backward + optimize - outputs = model(inputs) - - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - - # print statistics - running_loss += loss.item() - end = time.time() - print( - f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / (i + 1):.3f} Time: {end - start} sec" - ) - if epoch % 5 == 0: - assert not np.isnan(running_loss) - model.push_to_hub( - "graph-weather-forecaster-2.0deg", - organization="openclimatefix", - commit_message=f"Add model Epoch={epoch}", - ) - -print("Finished Training")