Skip to content

Commit

Permalink
Fix example and loss normalization (#87)
Browse files Browse the repository at this point in the history
* Fix number of features in example

* Fix loss normalization and add test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
gbruno16 and pre-commit-ci[bot] authored Mar 8, 2024
1 parent 4293f2d commit 66f2b6f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 6 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ for lat in range(-90, 90, 1):
lat_lons.append((lat, lon))
model = GraphWeatherForecaster(lat_lons)

features = torch.randn((2, len(lat_lons), 78))
# Generate 78 random features + 24 non-NWP features (i.e. landsea mask)
features = torch.randn((2, len(lat_lons), 102))

target = torch.randn((2, len(lat_lons), 78))
out = model(features)

criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
loss = criterion(out, features)
loss = criterion(out, target)
loss.backward()
```

Expand Down
12 changes: 8 additions & 4 deletions graph_weather/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
class NormalizedMSELoss(torch.nn.Module):
"""Loss function described in the paper"""

def __init__(self, feature_variance: list, lat_lons: list, device="cpu"):
def __init__(
self, feature_variance: list, lat_lons: list, device="cpu", normalize: bool = False
):
"""
Normalized MSE Loss as described in the paper
Expand All @@ -31,6 +33,7 @@ def __init__(self, feature_variance: list, lat_lons: list, device="cpu"):
for lat, lon in lat_lons:
weights.append(np.cos(lat * np.pi / 180.0))
self.weights = torch.tensor(weights, dtype=torch.float)
self.normalize = normalize
assert not torch.isnan(self.weights).any()

def forward(self, pred: torch.Tensor, target: torch.Tensor):
Expand All @@ -50,10 +53,11 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor):
self.feature_variance = self.feature_variance.to(pred.device)
self.weights = self.weights.to(pred.device)

# pred = pred / self.feature_variance
# target = target / self.feature_variance

out = (pred - target) ** 2

if self.normalize:
out = out / self.feature_variance

assert not torch.isnan(out).any()
# Mean of the physical variables
out = out.mean(-1)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,24 @@ def test_forecaster_and_loss_grad_checkpoint():
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
loss.backward()


def test_normalized_loss():
lat_lons = []
for lat in range(-90, 90, 5):
for lon in range(0, 360, 5):
lat_lons.append((lat, lon))

# Generate output as strictly positive random features
out = torch.rand((2, len(lat_lons), 78)) + 0.0001
feature_variance = out**2
target = torch.zeros((2, len(lat_lons), 78))

criterion = NormalizedMSELoss(
lat_lons=lat_lons, feature_variance=feature_variance, normalize=True
)
loss = criterion(out, target)

assert not torch.isnan(loss)
# Since feature_variance = out**2 and target = 0, we expect loss = weights
assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())

0 comments on commit 66f2b6f

Please sign in to comment.