From 66f2b6ff18214a3ceda6e5b24dd2cfe2dba8fc7d Mon Sep 17 00:00:00 2001 From: gbruno16 <72879691+gbruno16@users.noreply.github.com> Date: Fri, 8 Mar 2024 10:57:42 +0100 Subject: [PATCH] Fix example and loss normalization (#87) * Fix number of features in example * Fix loss normalization and add test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- README.md | 7 +++++-- graph_weather/models/losses.py | 12 ++++++++---- tests/test_model.py | 21 +++++++++++++++++++++ 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 7117fd39..23570ce8 100644 --- a/README.md +++ b/README.md @@ -30,11 +30,14 @@ for lat in range(-90, 90, 1): lat_lons.append((lat, lon)) model = GraphWeatherForecaster(lat_lons) -features = torch.randn((2, len(lat_lons), 78)) +# Generate 78 random features + 24 non-NWP features (e.g. land-sea mask) +features = torch.randn((2, len(lat_lons), 102)) +target = torch.randn((2, len(lat_lons), 78)) out = model(features) + criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) -loss = criterion(out, features) +loss = criterion(out, target) loss.backward() ``` diff --git a/graph_weather/models/losses.py b/graph_weather/models/losses.py index 5e5d27ec..e36c65cf 100644 --- a/graph_weather/models/losses.py +++ b/graph_weather/models/losses.py @@ -6,7 +6,9 @@ class NormalizedMSELoss(torch.nn.Module): """Loss function described in the paper""" - def __init__(self, feature_variance: list, lat_lons: list, device="cpu"): + def __init__( + self, feature_variance: list, lat_lons: list, device="cpu", normalize: bool = False + ): """ Normalized MSE Loss as described in the paper @@ -31,6 +33,7 @@ def __init__(self, feature_variance: list, lat_lons: list, device="cpu"): for lat, lon in lat_lons: weights.append(np.cos(lat * np.pi / 180.0)) self.weights = 
torch.tensor(weights, dtype=torch.float) + self.normalize = normalize assert not torch.isnan(self.weights).any() def forward(self, pred: torch.Tensor, target: torch.Tensor): @@ -50,10 +53,11 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor): self.feature_variance = self.feature_variance.to(pred.device) self.weights = self.weights.to(pred.device) - # pred = pred / self.feature_variance - # target = target / self.feature_variance - out = (pred - target) ** 2 + + if self.normalize: + out = out / self.feature_variance + assert not torch.isnan(out).any() # Mean of the physical variables out = out.mean(-1) diff --git a/tests/test_model.py b/tests/test_model.py index ff9f945e..58904292 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -201,3 +201,24 @@ def test_forecaster_and_loss_grad_checkpoint(): assert not torch.isnan(out).any() assert not torch.isnan(out).any() loss.backward() + + +def test_normalized_loss(): + lat_lons = [] + for lat in range(-90, 90, 5): + for lon in range(0, 360, 5): + lat_lons.append((lat, lon)) + + # Generate output as strictly positive random features + out = torch.rand((2, len(lat_lons), 78)) + 0.0001 + feature_variance = out**2 + target = torch.zeros((2, len(lat_lons), 78)) + + criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=feature_variance, normalize=True + ) + loss = criterion(out, target) + + assert not torch.isnan(loss) + # Since feature_variance = out**2 and target = 0, we expect loss = mean of the weights + assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())