Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rnwzd committed Jul 2, 2024
1 parent fe82edc commit 325fd0e
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,4 +358,36 @@ def test_gencast_loss():
preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim))
noise_levels = torch.rand((batch_size, 1))
targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim))
assert loss.forward(preds, targets, noise_levels) is not None
assert loss.forward(preds, noise_levels, targets) is not None


def test_gencast_denoiser():
grid_lat = np.arange(-90, 90, 1)
grid_lon = np.arange(0, 360, 1)
input_features_dim = 10
output_features_dim = 5
batch_size = 3

denoiser = Denoiser(
grid_lon=grid_lon,
grid_lat=grid_lat,
input_features_dim=input_features_dim,
output_features_dim=output_features_dim,
hidden_dims=[16, 32],
num_blocks=3,
num_heads=4,
splits=0,
num_hops=1,
device=torch.device("cpu"),
).eval()

corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim))
prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim))
noise_levels = torch.randn((batch_size, 1))

with torch.no_grad():
preds = denoiser(
corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels
)

assert not torch.isnan(preds).any()

0 comments on commit 325fd0e

Please sign in to comment.