Skip to content

Commit

Permalink
Merge pull request openclimatefix#53 from openclimatefix/all-contribu…
Browse files Browse the repository at this point in the history
…tors/add-najeeb-kazmi

docs: add najeeb-kazmi as a contributor for question
  • Loading branch information
peterdudfield authored Mar 24, 2023
2 parents 534cef5 + 15ccdee commit 68f6c55
Show file tree
Hide file tree
Showing 13 changed files with 55 additions and 32 deletions.
9 changes: 9 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@
"contributions": [
"question"
]
},
{
"login": "najeeb-kazmi",
"name": "Najeeb Kazmi",
"avatar_url": "https://avatars.githubusercontent.com/u/14131235?v=4",
"profile": "https://github.com/najeeb-kazmi",
"contributions": [
"question"
]
}
],
"contributorsPerLine": 7,
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Skillful Nowcasting with Deep Generative Model of Radar (DGMR)
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-7-orange.svg?style=flat-square)](#contributors-)
[![All Contributors](https://img.shields.io/badge/all_contributors-8-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->
Implementation of DeepMind's Skillful Nowcasting GAN Deep Generative Model of Radar (DGMR) (https://arxiv.org/abs/2104.00954) in PyTorch Lightning.

Expand Down Expand Up @@ -117,6 +117,9 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<td align="center" valign="top" width="14.28%"><a href="https://github.com/primeoc"><img src="https://avatars.githubusercontent.com/u/75205487?v=4?s=100" width="100px;" alt="cameron"/><br /><sub><b>cameron</b></sub></a><br /><a href="#question-primeoc" title="Answering Questions">💬</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/zhrli"><img src="https://avatars.githubusercontent.com/u/11074703?v=4?s=100" width="100px;" alt="zhrli"/><br /><sub><b>zhrli</b></sub></a><br /><a href="#question-zhrli" title="Answering Questions">💬</a></td>
</tr>
<tr>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/najeeb-kazmi"><img src="https://avatars.githubusercontent.com/u/14131235?v=4?s=100" width="100px;" alt="Najeeb Kazmi"/><br /><sub><b>Najeeb Kazmi</b></sub></a><br /><a href="#question-najeeb-kazmi" title="Answering Questions">💬</a></td>
</tr>
</tbody>
</table>
Expand Down
6 changes: 3 additions & 3 deletions dgmr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .common import ContextConditioningStack, LatentConditioningStack
from .dgmr import DGMR
from .generators import Sampler, Generator
from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator
from .common import LatentConditioningStack, ContextConditioningStack
from .discriminators import Discriminator, SpatialDiscriminator, TemporalDiscriminator
from .generators import Generator, Sampler
7 changes: 4 additions & 3 deletions dgmr/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import einops
import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch.distributions import normal
from torch.nn.utils.parametrizations import spectral_norm
from torch.nn.modules.pixelshuffle import PixelUnshuffle
from dgmr.layers.utils import get_conv_layer
from torch.nn.utils.parametrizations import spectral_norm

from dgmr.layers import AttentionLayer
from huggingface_hub import PyTorchModelHubMixin
from dgmr.layers.utils import get_conv_layer


class GBlock(torch.nn.Module):
Expand Down
17 changes: 9 additions & 8 deletions dgmr/dgmr.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import pytorch_lightning as pl
import torch
import torchvision

from dgmr.common import ContextConditioningStack, LatentConditioningStack
from dgmr.discriminators import Discriminator
from dgmr.generators import Generator, Sampler
from dgmr.hub import NowcastingModelHubMixin
from dgmr.losses import (
NowcastingLoss,
GridCellLoss,
NowcastingLoss,
grid_cell_regularizer,
loss_hinge_disc,
loss_hinge_gen,
grid_cell_regularizer,
)
import pytorch_lightning as pl
import torchvision
from dgmr.common import LatentConditioningStack, ContextConditioningStack
from dgmr.generators import Sampler, Generator
from dgmr.discriminators import Discriminator
from dgmr.hub import NowcastingModelHubMixin


class DGMR(pl.LightningModule, NowcastingModelHubMixin):
Expand Down
5 changes: 3 additions & 2 deletions dgmr/discriminators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch.nn.modules.pixelshuffle import PixelUnshuffle
from torch.nn.utils.parametrizations import spectral_norm
import torch.nn.functional as F

from dgmr.common import DBlock
from huggingface_hub import PyTorchModelHubMixin


class Discriminator(torch.nn.Module, PyTorchModelHubMixin):
Expand Down
9 changes: 6 additions & 3 deletions dgmr/generators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging
from typing import List

import einops
import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch.nn.modules.pixelshuffle import PixelShuffle
from torch.nn.utils.parametrizations import spectral_norm
from typing import List

from dgmr.common import GBlock, UpsampleGBlock
from dgmr.layers import ConvGRU
from huggingface_hub import PyTorchModelHubMixin
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARN)
Expand All @@ -27,6 +29,7 @@ def __init__(
The sampler takes the output from the Latent and Context conditioning stacks and
creates one stack of ConvGRU layers per future timestep.
Args:
forecast_steps: Number of forecast steps
latent_channels: Number of input channels to the lowest ConvGRU layer
Expand Down
1 change: 0 additions & 1 deletion dgmr/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import torch


try:
from huggingface_hub import cached_download, hf_hub_url

Expand Down
2 changes: 1 addition & 1 deletion dgmr/layers/Attention.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import einops
import torch
import torch.nn as nn
from torch.nn import functional as F
import einops


def attention_einsum(q, k, v):
Expand Down
1 change: 1 addition & 0 deletions dgmr/layers/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch

from dgmr.layers import CoordConv


Expand Down
6 changes: 5 additions & 1 deletion dgmr/losses.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import numpy as np
import torch
import torch.nn as nn
from pytorch_msssim import SSIM, MS_SSIM
from pytorch_msssim import MS_SSIM, SSIM
from torch.nn import functional as F


class SSIMLoss(nn.Module):
def __init__(self, convert_range: bool = False, **kwargs):
"""
SSIM Loss, optionally converting input range from [-1,1] to [0,1]
Args:
convert_range:
**kwargs:
Expand All @@ -28,6 +29,7 @@ class MS_SSIMLoss(nn.Module):
def __init__(self, convert_range: bool = False, **kwargs):
"""
Multi-Scale SSIM Loss, optionally converting input range from [-1,1] to [0,1]
Args:
convert_range:
**kwargs:
Expand Down Expand Up @@ -78,6 +80,7 @@ def tv_loss(img, tv_weight):
Inputs:
- img: PyTorch Variable of shape (1, 3, H, W) holding an input image.
- tv_weight: Scalar giving the weight w_t to use for the TV loss.
Returns:
- loss: PyTorch Variable holding a scalar giving the total variation loss
for img weighted by tv_weight.
Expand Down Expand Up @@ -138,6 +141,7 @@ def forward(self, generated_images, targets):
"""
Calculates the grid cell regularizer value, assumes generated images are the mean predictions from
6 calls to the generater (Monte Carlo estimation of the expectations for the latent variable)
Args:
generated_images: Mean generated images from the generator
targets: Ground truth future frames
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from setuptools import setup, find_packages
from pathlib import Path

from setuptools import find_packages, setup

this_directory = Path(__file__).parent
install_requires = (this_directory / "requirements.txt").read_text().splitlines()
long_description = (this_directory / "README.md").read_text()
Expand Down
16 changes: 8 additions & 8 deletions train/run.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import torch.utils.data.dataset
from dgmr import DGMR
import wandb
from datasets import load_dataset
from torch.utils.data import DataLoader
from pytorch_lightning import (
LightningDataModule,
)
from pytorch_lightning.callbacks import ModelCheckpoint
import wandb
from torch.utils.data import DataLoader

from dgmr import DGMR

wandb.init(project="dgmr")
from numpy.random import default_rng
import os
import numpy as np
from pathlib import Path
import tensorflow as tf

import numpy as np
from numpy.random import default_rng
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
Expand Down Expand Up @@ -136,7 +136,7 @@ def __len__(self):
def __getitem__(self, item):
try:
row = next(self.iter_reader)
except Exception as e:
except Exception:
rng = default_rng()
self.iter_reader = iter(
self.reader.shuffle(seed=rng.integers(low=0, high=100000), buffer_size=1000)
Expand Down

0 comments on commit 68f6c55

Please sign in to comment.