From 7336349e05e00189d11af98e04f8c5a8b5f36b02 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Fri, 24 Mar 2023 09:54:26 +0000 Subject: [PATCH 1/3] docs: update README.md [skip ci] --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f2b79c0..5e13681 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Skillful Nowcasting with Deep Generative Model of Radar (DGMR) -[![All Contributors](https://img.shields.io/badge/all_contributors-7-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-8-orange.svg?style=flat-square)](#contributors-) Implementation of DeepMind's Skillful Nowcasting GAN Deep Generative Model of Radar (DGMR) (https://arxiv.org/abs/2104.00954) in PyTorch Lightning. @@ -117,6 +117,9 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d cameron
cameron

💬 zhrli
zhrli

💬 + + Najeeb Kazmi
Najeeb Kazmi

💬 + From 40a454ad716b4a7bcedca845077c79d32fa09d8d Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Fri, 24 Mar 2023 09:54:27 +0000 Subject: [PATCH 2/3] docs: update .all-contributorsrc [skip ci] --- .all-contributorsrc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.all-contributorsrc b/.all-contributorsrc index 1bd30b1..d151bbc 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -67,6 +67,15 @@ "contributions": [ "question" ] + }, + { + "login": "najeeb-kazmi", + "name": "Najeeb Kazmi", + "avatar_url": "https://avatars.githubusercontent.com/u/14131235?v=4", + "profile": "https://github.com/najeeb-kazmi", + "contributions": [ + "question" + ] } ], "contributorsPerLine": 7, From 15ccdeef6c7670943cd6ab4d435c17b70215f5dc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Mar 2023 09:54:34 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dgmr/__init__.py | 6 +++--- dgmr/common.py | 7 ++++--- dgmr/dgmr.py | 17 +++++++++-------- dgmr/discriminators.py | 5 +++-- dgmr/generators.py | 9 ++++++--- dgmr/hub.py | 1 - dgmr/layers/Attention.py | 2 +- dgmr/layers/utils.py | 1 + dgmr/losses.py | 6 +++++- setup.py | 3 ++- train/run.py | 16 ++++++++-------- 11 files changed, 42 insertions(+), 31 deletions(-) diff --git a/dgmr/__init__.py b/dgmr/__init__.py index b9cd0b7..a1a3190 100644 --- a/dgmr/__init__.py +++ b/dgmr/__init__.py @@ -1,4 +1,4 @@ +from .common import ContextConditioningStack, LatentConditioningStack from .dgmr import DGMR -from .generators import Sampler, Generator -from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator -from .common import LatentConditioningStack, ContextConditioningStack +from .discriminators import Discriminator, SpatialDiscriminator, TemporalDiscriminator +from .generators import Generator, Sampler diff --git a/dgmr/common.py b/dgmr/common.py index f6e13da..70b20fc 100644 --- a/dgmr/common.py +++ b/dgmr/common.py @@ -3,12 +3,13 @@ import einops import torch import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin from torch.distributions import normal -from torch.nn.utils.parametrizations import spectral_norm from torch.nn.modules.pixelshuffle import PixelUnshuffle -from dgmr.layers.utils import get_conv_layer +from torch.nn.utils.parametrizations import spectral_norm + from dgmr.layers import AttentionLayer -from huggingface_hub import PyTorchModelHubMixin +from dgmr.layers.utils import get_conv_layer class GBlock(torch.nn.Module): diff --git a/dgmr/dgmr.py b/dgmr/dgmr.py index c2772d8..d1a71af 100644 --- a/dgmr/dgmr.py +++ b/dgmr/dgmr.py @@ -1,17 +1,18 @@ +import pytorch_lightning as pl import torch +import torchvision + +from dgmr.common import ContextConditioningStack, LatentConditioningStack +from dgmr.discriminators import Discriminator +from dgmr.generators import Generator, Sampler +from dgmr.hub import NowcastingModelHubMixin from dgmr.losses import ( - NowcastingLoss, GridCellLoss, + NowcastingLoss, + grid_cell_regularizer, loss_hinge_disc, loss_hinge_gen, - grid_cell_regularizer, ) -import pytorch_lightning as pl -import torchvision -from dgmr.common import LatentConditioningStack, ContextConditioningStack -from dgmr.generators import Sampler, Generator -from dgmr.discriminators import Discriminator -from dgmr.hub import NowcastingModelHubMixin class DGMR(pl.LightningModule, NowcastingModelHubMixin): diff --git a/dgmr/discriminators.py b/dgmr/discriminators.py index 64cde4c..a2711e3 100644 --- a/dgmr/discriminators.py +++ b/dgmr/discriminators.py @@ -1,9 +1,10 @@ import torch +import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin from torch.nn.modules.pixelshuffle import PixelUnshuffle from torch.nn.utils.parametrizations import spectral_norm -import torch.nn.functional as F + from dgmr.common import DBlock -from huggingface_hub import PyTorchModelHubMixin class Discriminator(torch.nn.Module, PyTorchModelHubMixin): diff --git a/dgmr/generators.py b/dgmr/generators.py index bb0758c..c4b744e 100644 --- a/dgmr/generators.py +++ b/dgmr/generators.py @@ -1,13 +1,15 @@ +import logging +from typing import List + import einops import torch import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin from torch.nn.modules.pixelshuffle import PixelShuffle from torch.nn.utils.parametrizations import spectral_norm -from typing import List + from dgmr.common import GBlock, UpsampleGBlock from dgmr.layers import ConvGRU -from huggingface_hub import PyTorchModelHubMixin -import logging logger = logging.getLogger(__name__) logger.setLevel(logging.WARN) @@ -27,6 +29,7 @@ def __init__( The sampler takes the output from the Latent and Context conditioning stacks and creates one stack of ConvGRU layers per future timestep. + Args: forecast_steps: Number of forecast steps latent_channels: Number of input channels to the lowest ConvGRU layer diff --git a/dgmr/hub.py b/dgmr/hub.py index d6f4d23..5731b6b 100644 --- a/dgmr/hub.py +++ b/dgmr/hub.py @@ -12,7 +12,6 @@ import torch - try: from huggingface_hub import cached_download, hf_hub_url diff --git a/dgmr/layers/Attention.py b/dgmr/layers/Attention.py index 54895da..a9e251c 100644 --- a/dgmr/layers/Attention.py +++ b/dgmr/layers/Attention.py @@ -1,7 +1,7 @@ +import einops import torch import torch.nn as nn from torch.nn import functional as F -import einops def attention_einsum(q, k, v): diff --git a/dgmr/layers/utils.py b/dgmr/layers/utils.py index 0c1f35f..86682f3 100644 --- a/dgmr/layers/utils.py +++ b/dgmr/layers/utils.py @@ -1,4 +1,5 @@ import torch + from dgmr.layers import CoordConv diff --git a/dgmr/losses.py b/dgmr/losses.py index 6281f1d..061ebbb 100644 --- a/dgmr/losses.py +++ b/dgmr/losses.py @@ -1,7 +1,7 @@ import numpy as np import torch import torch.nn as nn -from pytorch_msssim import SSIM, MS_SSIM +from pytorch_msssim import MS_SSIM, SSIM from torch.nn import functional as F @@ -9,6 +9,7 @@ class SSIMLoss(nn.Module): def __init__(self, convert_range: bool = False, **kwargs): """ SSIM Loss, optionally converting input range from [-1,1] to [0,1] + Args: convert_range: **kwargs: @@ -28,6 +29,7 @@ class MS_SSIMLoss(nn.Module): def __init__(self, convert_range: bool = False, **kwargs): """ Multi-Scale SSIM Loss, optionally converting input range from [-1,1] to [0,1] + Args: convert_range: **kwargs: @@ -78,6 +80,7 @@ def tv_loss(img, tv_weight): Inputs: - img: PyTorch Variable of shape (1, 3, H, W) holding an input image. - tv_weight: Scalar giving the weight w_t to use for the TV loss. + Returns: - loss: PyTorch Variable holding a scalar giving the total variation loss for img weighted by tv_weight. @@ -138,6 +141,7 @@ def forward(self, generated_images, targets): """ Calculates the grid cell regularizer value, assumes generated images are the mean predictions from 6 calls to the generater (Monte Carlo estimation of the expectations for the latent variable) + Args: generated_images: Mean generated images from the generator targets: Ground truth future frames diff --git a/setup.py b/setup.py index 4607da8..131d4ea 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ -from setuptools import setup, find_packages from pathlib import Path +from setuptools import find_packages, setup + this_directory = Path(__file__).parent install_requires = (this_directory / "requirements.txt").read_text().splitlines() long_description = (this_directory / "README.md").read_text() diff --git a/train/run.py b/train/run.py index 8d27fc9..27b158b 100644 --- a/train/run.py +++ b/train/run.py @@ -1,19 +1,19 @@ import torch.utils.data.dataset -from dgmr import DGMR +import wandb from datasets import load_dataset -from torch.utils.data import DataLoader from pytorch_lightning import ( LightningDataModule, ) from pytorch_lightning.callbacks import ModelCheckpoint -import wandb +from torch.utils.data import DataLoader + +from dgmr import DGMR wandb.init(project="dgmr") -from numpy.random import default_rng -import os -import numpy as np from pathlib import Path -import tensorflow as tf + +import numpy as np +from numpy.random import default_rng from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers import LoggerCollection, WandbLogger from pytorch_lightning.utilities import rank_zero_only @@ -136,7 +136,7 @@ def __len__(self): def __getitem__(self, item): try: row = next(self.iter_reader) - except Exception as e: + except Exception: rng = default_rng() self.iter_reader = iter( self.reader.shuffle(seed=rng.integers(low=0, high=100000), buffer_size=1000)