diff --git a/tests/test_model.py b/tests/test_model.py index c7b24bee..e5604fce 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -358,4 +358,36 @@ def test_gencast_loss(): preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) noise_levels = torch.rand((batch_size, 1)) targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) - assert loss.forward(preds, targets, noise_levels) is not None + assert loss.forward(preds, noise_levels, targets) is not None + + +def test_gencast_denoiser(): + grid_lat = np.arange(-90, 90, 1) + grid_lon = np.arange(0, 360, 1) + input_features_dim = 10 + output_features_dim = 5 + batch_size = 3 + + denoiser = Denoiser( + grid_lon=grid_lon, + grid_lat=grid_lat, + input_features_dim=input_features_dim, + output_features_dim=output_features_dim, + hidden_dims=[16, 32], + num_blocks=3, + num_heads=4, + splits=0, + num_hops=1, + device=torch.device("cpu"), + ).eval() + + corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim)) + prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim)) + noise_levels = torch.randn((batch_size, 1)) + + with torch.no_grad(): + preds = denoiser( + corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels + ) + + assert not torch.isnan(preds).any() \ No newline at end of file