Skip to content

Commit

Permalink
Fengwu ghr training (#125)
Browse files Browse the repository at this point in the history
* fengwu_ghr: initial

* fengwu_ghr: fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Interpolate initial

* ImageMetaModel

* MetaModel initial

* tested metamodel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* wrapper meta model

* RES

* load RES state_dict

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bug fix

* bug fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* env yml fix

* fengwu_ghr: initial

fengwu_ghr: fixes

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Interpolate initial

ImageMetaModel

MetaModel initial

tested metamodel

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

wrapper meta model

RES

load RES state_dict

bug fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

env yml fix

* test_wrapper_meta_model

* tests fix

* parent 743cf97
author Lorenzo Breschi <[email protected]> 1716973130 +0200
committer Lorenzo Breschi <[email protected]> 1722343516 +0200

fengwu_ghr: initial

fengwu_ghr: fixes

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Interpolate initial

ImageMetaModel

MetaModel initial

tested metamodel

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

wrapper meta model

RES

load RES state_dict

bug fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

env yml fix

fengwu_ghr: initial

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Interpolate initial

ImageMetaModel

MetaModel initial

tested metamodel

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

wrapper meta model

RES

load RES state_dict

bug fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

env yml fix

test_wrapper_meta_model

tests fix

* fengwu_ghr: initial

fengwu_ghr: fixes

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Interpolate initial

ImageMetaModel

MetaModel initial

tested metamodel

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

wrapper meta model

RES

load RES state_dict

bug fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

env yml fix

* fengwu_ghr: initial

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Interpolate initial

* ImageMetaModel

* MetaModel initial

* tested metamodel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* wrapper meta model

* RES

* load RES state_dict

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added gcsfs to env yml

* __init__.py imports

* MetaModel long coordinates

* knn_interpolate gpu patch

* era5 training

* era5 training bugfix

* lora training

* pkg does not exist

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* env.yml bugfix

* Update environment_cpu.yml

* Update environment_cuda.yml

* Update environment_cuda.yml

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jacob Bieker <[email protected]>
  • Loading branch information
3 people authored Sep 23, 2024
1 parent 98e1ba7 commit cf2904f
Show file tree
Hide file tree
Showing 9 changed files with 457 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@
# pixi environments
.pixi
.vscode/
checkpoints/
lightning_logs/
6 changes: 5 additions & 1 deletion environment_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ dependencies:
- zarr
- h3-py
- numpy
- torch_harmonics
- pyshtools
- gcsfs
- pytest
- pip:
- setuptools
- datasets
- einops
- fsspec
Expand All @@ -36,3 +39,4 @@ dependencies:
- click
- trimesh
- rtree
- torch-harmonics
8 changes: 6 additions & 2 deletions environment_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ channels:
- conda-forge
- defaults
dependencies:
- pytorch-cuda=12.1
- pytorch-cuda
- numcodecs
- pandas
- pip
Expand All @@ -25,8 +25,11 @@ dependencies:
- zarr
- h3-py
- numpy
- torch_harmonics
- pyshtools
- gcsfs
- pytest
- pip:
- setuptools
- datasets
- einops
- fsspec
Expand All @@ -37,3 +40,4 @@ dependencies:
- click
- trimesh
- rtree
- torch-harmonics
8 changes: 7 additions & 1 deletion graph_weather/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Models"""

from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel
from .fengwu_ghr.layers import (
ImageMetaModel,
LoRAModule,
MetaModel,
WrapperImageModel,
WrapperMetaModel,
)
from .layers.assimilator_decoder import AssimilatorDecoder
from .layers.assimilator_encoder import AssimilatorEncoder
from .layers.decoder import Decoder
Expand Down
3 changes: 3 additions & 0 deletions graph_weather/models/fengwu_ghr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Main import for FengWu-GHR"""

from .layers import ImageMetaModel, LoRAModule, MetaModel, WrapperImageModel, WrapperMetaModel
48 changes: 47 additions & 1 deletion graph_weather/models/fengwu_ghr/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def knn_interpolate(
squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

y_idx, x_idx = y_idx.to(x.device), x_idx.to(x.device)
weights = weights.to(x.device)

den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum")
y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum")

Expand Down Expand Up @@ -228,6 +231,7 @@ def __init__(
)

def forward(self, x):
assert x.shape[1] == self.channels, "Wrong number of channels"
device = x.device
dtype = x.dtype

Expand Down Expand Up @@ -276,7 +280,7 @@ def __init__(
super().__init__()
self.i_h, self.i_w = pair(image_size)

self.pos_x = torch.tensor(lat_lons)
self.pos_x = torch.tensor(lat_lons).to(torch.long)
self.pos_y = torch.cartesian_prod(
(torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
(torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
Expand Down Expand Up @@ -344,3 +348,45 @@ def forward(self, x):
x = rearrange(x, "n (b c) -> b n c", b=b, c=c)

return x


class LoRALayer(nn.Module):
    """Wrap a linear layer and add a trainable low-rank (LoRA) update to its output.

    The layer computes ``linear_layer(x) + x @ (B @ A).T`` where ``A`` has shape
    ``(r, in_features)`` and ``B`` has shape ``(out_features, r)``. ``B`` is
    initialised to zeros, so a freshly built layer returns exactly what the
    wrapped layer returns.
    """

    def __init__(self, linear_layer: nn.Module, r: int):
        """
        Initialize LoRALayer.
        Args:
            linear_layer (nn.Module): Linear layer to be transformed.
            r (int): rank of the low-rank matrix.
        """
        super().__init__()
        out_features, in_features = linear_layer.weight.shape

        self.A = nn.Parameter(torch.randn(r, in_features))
        self.B = nn.Parameter(torch.zeros(out_features, r))
        self.linear_layer = linear_layer

    def forward(self, x):
        """
        Apply the wrapped layer plus the low-rank update.
        Args:
            x (torch.Tensor): Input with the features on its last axis, i.e. shape
                ``(..., in_features)``, as accepted by the wrapped linear layer.
        Returns:
            torch.Tensor: Tensor of shape ``(..., out_features)``.
        """
        # The features live on the LAST axis of ``x``, so the update has to be
        # applied from the right. ``B @ A @ x`` would contract over the wrong axis.
        # Going through the rank-r bottleneck first also avoids materialising the
        # full (out_features, in_features) matrix.
        low_rank_update = (x @ self.A.t()) @ self.B.t()
        return self.linear_layer(x) + low_rank_update


class LoRAModule(nn.Module):
    """Wrap a model so that each ``nn.Linear`` found in it is replaced by a ``LoRALayer``."""

    def __init__(self, model, r=4):
        """
        Initialize LoRAModule.
        Args:
            model (nn.Module): Model to be modified with LoRA layers.
            r (int, optional): Rank of LoRA layers. Defaults to 4.
        """
        super().__init__()
        # Snapshot the module tree first: replacing children while the
        # ``named_modules()`` generator is still running mutates the dicts it walks.
        # The root module is stored under the empty name "".
        modules = dict(model.named_modules())
        for name, layer in modules.items():
            # NOTE(review): eval() only switches the module to evaluation mode; it
            # does not set requires_grad=False, so the wrapped weights stay trainable
            # unless the caller freezes them.
            layer.eval()
            if not isinstance(layer, nn.Linear):
                continue
            lora_layer = LoRALayer(layer, r)
            if not name:
                # The model itself is a single linear layer: wrap the root.
                model = lora_layer
                continue
            # ``name`` is a dotted path such as "encoder.0.fc". The replacement must be
            # set on the direct parent, because setattr() on the root with a dotted
            # name would only register a stray child and leave the real layer as is.
            parent_name, _, child_name = name.rpartition(".")
            setattr(modules[parent_name], child_name, lora_layer)
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)
33 changes: 30 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ def test_image_meta_model():

out = model(image)
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
assert out.size() == image.size()


Expand All @@ -275,7 +274,6 @@ def test_wrapper_image_meta_model():
big_model = WrapperImageModel(model, scale_factor)
out = big_model(big_image)
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
assert out.size() == big_image.size()


Expand Down Expand Up @@ -303,5 +301,34 @@ def test_meta_model():

out = model(features)
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
assert out.size() == features.size()


def test_wrapper_meta_model():
    """WrapperMetaModel returns a NaN-free output with the same size as its input."""
    # Coordinate pairs on a 5-degree grid.
    lat_lons = [(lat, lon) for lat in range(-90, 90, 5) for lon in range(0, 360, 5)]

    n_batch, n_channels = 2, 3
    upscale = 3

    base_model = MetaModel(
        lat_lons,
        image_size=20,
        patch_size=4,
        depth=1,
        heads=1,
        mlp_dim=7,
        channels=n_channels,
        dim_head=64,
    )

    big_features = torch.randn((n_batch, len(lat_lons), n_channels))
    big_model = WrapperMetaModel(lat_lons, base_model, upscale)
    out = big_model(big_features)

    has_nan = torch.isnan(out).any()
    assert not has_nan
    assert out.size() == big_features.size()
192 changes: 192 additions & 0 deletions train/era5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import xarray
from einops import rearrange
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset

from graph_weather.models import MetaModel
from graph_weather.models.losses import NormalizedMSELoss


class LitFengWuGHR(pl.LightningModule):
    """
    LightningModule that trains a FengWu-GHR ``MetaModel``.
    Attributes:
        model (MetaModel): Model built from the constructor arguments.
        criterion (NormalizedMSELoss): Loss criterion for training.
        lr : Learning rate for optimizer.
    Methods:
        __init__: Initialize the LitFengWuGHR object.
        forward: Forward pass of the model.
        training_step: Training step.
        configure_optimizers: Configure the optimizer for training.
    """

    def __init__(
        self,
        lat_lons: list,
        *,
        channels: int,
        image_size,
        patch_size=4,
        depth=5,
        heads=4,
        mlp_dim=5,
        feature_dim: int = 605,  # TODO where does this come from?
        lr: float = 3e-4,
    ):
        """
        Initialize the LitFengWuGHR object with the required args.
        Args:
            lat_lons : Latitude/longitude pairs, passed to both the model and the loss.
            channels : Forwarded to ``MetaModel`` as ``channels``.
            image_size : Forwarded to ``MetaModel`` as ``image_size``.
            patch_size : Forwarded to ``MetaModel`` as ``patch_size``.
            depth : Forwarded to ``MetaModel`` as ``depth``.
            heads : Forwarded to ``MetaModel`` as ``heads``.
            mlp_dim : Forwarded to ``MetaModel`` as ``mlp_dim``.
            feature_dim : Length of the all-ones ``feature_variance`` vector given
                to the loss.
            lr (float): Learning rate for optimizer.
        """
        super().__init__()
        self.model = MetaModel(
            lat_lons,
            image_size=image_size,
            patch_size=patch_size,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim,
            channels=channels,
        )
        # NOTE(review): the variance vector has ``feature_dim`` entries (605 by
        # default), independent of ``channels`` - confirm that NormalizedMSELoss
        # accepts this when the model output has ``channels`` features.
        self.criterion = NormalizedMSELoss(
            lat_lons=lat_lons, feature_variance=np.ones((feature_dim,))
        )
        self.lr = lr
        # Called without arguments; presumably records the constructor arguments
        # as hyperparameters (see the Lightning documentation).
        self.save_hyperparameters()

    def forward(self, x):
        """
        Forward pass: delegates to ``self.model``.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: Output tensor.
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        Training step.
        Args:
            batch (array): Batch whose second axis holds the input at index 0 and
                the target at index 1.
            batch_idx (int): Index of the current batch.
        Returns:
            torch.Tensor: Loss tensor, or None when the input or the target
            contains NaN values.
        """
        x, y = batch[:, 0], batch[:, 1]
        # No loss is computed for batches with NaNs; presumably the trainer then
        # skips the optimisation step for this batch - confirm for the strategy in use.
        if torch.isnan(x).any() or torch.isnan(y).any():
            return None
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """
        Configure the optimizer: AdamW over all parameters with learning rate ``self.lr``.
        Returns:
            torch.optim.Optimizer: Optimizer instance.
        """
        return torch.optim.AdamW(self.parameters(), lr=self.lr)


class Era5Dataset(Dataset):
    """Era5 dataset yielding pairs of consecutive time steps.

    Every variable is min-max scaled to [0, 1] over all of its time steps and
    grid points, and the data is stored as a tensor of shape ``(T, H * W, C)``.
    """

    def __init__(self, xarr, transform=None):
        """
        Arguments:
            xarr: Object whose ``to_array()`` result converts to an array laid
                out as (C, T, H, W), i.e. variables first.
            transform (callable, optional): Applied to every sample returned by
                ``__getitem__``. Defaults to None (samples are returned as is).
        """
        values = np.asarray(xarr.to_array())

        # Scale each variable on its own. Reducing over axis 0 (the variable axis)
        # would rescale every grid point across unrelated variables instead.
        reduce_axes = (1, 2, 3)
        lo = np.nanmin(values, axis=reduce_axes, keepdims=True)
        span = np.nanmax(values, axis=reduce_axes, keepdims=True) - lo
        # A constant variable has zero range; divide by 1 so it maps to 0, not NaN.
        span[span == 0] = 1

        # In-place arithmetic keeps the peak memory at one copy of the data.
        ds = torch.from_numpy(values)
        ds -= torch.from_numpy(lo)
        ds /= torch.from_numpy(span)

        # "C T H W -> T (H W) C"
        self.ds = ds.permute(1, 2, 3, 0).flatten(1, 2)
        self.transform = transform

    def __len__(self):
        """Number of (current, next) pairs: one less than the number of time steps."""
        return len(self.ds) - 1

    def __getitem__(self, index):
        """Return the time steps ``index`` and ``index + 1`` stacked on the first axis."""
        sample = self.ds[index : index + 2]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


if __name__ == "__main__":
    # --- configuration -----------------------------------------------------
    ckpt_path = Path("./checkpoints")
    patch_size = 4
    grid_step = 20  # keep every 20th grid point along latitude and longitude
    variables = [
        "2m_temperature",
        "surface_pressure",
        "10m_u_component_of_wind",
        "10m_v_component_of_wind",
    ]
    channels = len(variables)

    ckpt_path.mkdir(parents=True, exist_ok=True)

    # --- data ----------------------------------------------------------------
    # Public ERA5 archive, opened anonymously.
    full_archive = xarray.open_zarr(
        "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
        storage_options={"token": "anon"},
    )
    one_year = full_archive.sel(time=slice("2020-01-01", "2021-01-01"))
    coarse = one_year.isel(
        time=slice(100, 107),
        longitude=slice(0, 1440, grid_step),
        latitude=slice(0, 721, grid_step),
    )
    reanalysis = coarse[variables]

    size_gib = reanalysis.nbytes / (1024 ** 3)
    print(f"size: {size_gib} GiB")

    # One (lat, lon) row per grid point.
    lats = np.asarray(reanalysis["latitude"]).flatten()
    lons = np.asarray(reanalysis["longitude"]).flatten()
    lat_lons = np.array(np.meshgrid(lats, lons)).T.reshape((-1, 2))

    # --- training ------------------------------------------------------------
    checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path, save_top_k=1, monitor="loss")

    loader = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8)
    image_size = (721 // grid_step, 1440 // grid_step)
    model = LitFengWuGHR(
        lat_lons=lat_lons,
        channels=channels,
        image_size=image_size,
        patch_size=patch_size,
        depth=5,
        heads=4,
        mlp_dim=5,
    )
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=-1,
        max_epochs=100,
        precision="16-mixed",
        callbacks=[checkpoint_callback],
        log_every_n_steps=3,
    )
    trainer.fit(model, loader)

    # Save the weights of the wrapped MetaModel as they are when fit() returns.
    weights_file = ckpt_path / "best.pt"
    torch.save(model.model.state_dict(), weights_file)
Loading

0 comments on commit cf2904f

Please sign in to comment.