From 7336349e05e00189d11af98e04f8c5a8b5f36b02 Mon Sep 17 00:00:00 2001
From: "allcontributors[bot]"
<46447321+allcontributors[bot]@users.noreply.github.com>
Date: Fri, 24 Mar 2023 09:54:26 +0000
Subject: [PATCH 1/3] docs: update README.md [skip ci]
---
README.md | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index f2b79c0..5e13681 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# Skillful Nowcasting with Deep Generative Model of Radar (DGMR)
-[![All Contributors](https://img.shields.io/badge/all_contributors-7-orange.svg?style=flat-square)](#contributors-)
+[![All Contributors](https://img.shields.io/badge/all_contributors-8-orange.svg?style=flat-square)](#contributors-)
Implementation of DeepMind's Skillful Nowcasting GAN Deep Generative Model of Radar (DGMR) (https://arxiv.org/abs/2104.00954) in PyTorch Lightning.
@@ -117,6 +117,9 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
cameron 💬 |
zhrli 💬 |
+
+ Najeeb Kazmi 💬 |
+
From 40a454ad716b4a7bcedca845077c79d32fa09d8d Mon Sep 17 00:00:00 2001
From: "allcontributors[bot]"
<46447321+allcontributors[bot]@users.noreply.github.com>
Date: Fri, 24 Mar 2023 09:54:27 +0000
Subject: [PATCH 2/3] docs: update .all-contributorsrc [skip ci]
---
.all-contributorsrc | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/.all-contributorsrc b/.all-contributorsrc
index 1bd30b1..d151bbc 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -67,6 +67,15 @@
"contributions": [
"question"
]
+ },
+ {
+ "login": "najeeb-kazmi",
+ "name": "Najeeb Kazmi",
+ "avatar_url": "https://avatars.githubusercontent.com/u/14131235?v=4",
+ "profile": "https://github.com/najeeb-kazmi",
+ "contributions": [
+ "question"
+ ]
}
],
"contributorsPerLine": 7,
From 15ccdeef6c7670943cd6ab4d435c17b70215f5dc Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 24 Mar 2023 09:54:34 +0000
Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---
dgmr/__init__.py | 6 +++---
dgmr/common.py | 7 ++++---
dgmr/dgmr.py | 17 +++++++++--------
dgmr/discriminators.py | 5 +++--
dgmr/generators.py | 9 ++++++---
dgmr/hub.py | 1 -
dgmr/layers/Attention.py | 2 +-
dgmr/layers/utils.py | 1 +
dgmr/losses.py | 6 +++++-
setup.py | 3 ++-
train/run.py | 16 ++++++++--------
11 files changed, 42 insertions(+), 31 deletions(-)
diff --git a/dgmr/__init__.py b/dgmr/__init__.py
index b9cd0b7..a1a3190 100644
--- a/dgmr/__init__.py
+++ b/dgmr/__init__.py
@@ -1,4 +1,4 @@
+from .common import ContextConditioningStack, LatentConditioningStack
from .dgmr import DGMR
-from .generators import Sampler, Generator
-from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator
-from .common import LatentConditioningStack, ContextConditioningStack
+from .discriminators import Discriminator, SpatialDiscriminator, TemporalDiscriminator
+from .generators import Generator, Sampler
diff --git a/dgmr/common.py b/dgmr/common.py
index f6e13da..70b20fc 100644
--- a/dgmr/common.py
+++ b/dgmr/common.py
@@ -3,12 +3,13 @@
import einops
import torch
import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin
from torch.distributions import normal
-from torch.nn.utils.parametrizations import spectral_norm
from torch.nn.modules.pixelshuffle import PixelUnshuffle
-from dgmr.layers.utils import get_conv_layer
+from torch.nn.utils.parametrizations import spectral_norm
+
from dgmr.layers import AttentionLayer
-from huggingface_hub import PyTorchModelHubMixin
+from dgmr.layers.utils import get_conv_layer
class GBlock(torch.nn.Module):
diff --git a/dgmr/dgmr.py b/dgmr/dgmr.py
index c2772d8..d1a71af 100644
--- a/dgmr/dgmr.py
+++ b/dgmr/dgmr.py
@@ -1,17 +1,18 @@
+import pytorch_lightning as pl
import torch
+import torchvision
+
+from dgmr.common import ContextConditioningStack, LatentConditioningStack
+from dgmr.discriminators import Discriminator
+from dgmr.generators import Generator, Sampler
+from dgmr.hub import NowcastingModelHubMixin
from dgmr.losses import (
- NowcastingLoss,
GridCellLoss,
+ NowcastingLoss,
+ grid_cell_regularizer,
loss_hinge_disc,
loss_hinge_gen,
- grid_cell_regularizer,
)
-import pytorch_lightning as pl
-import torchvision
-from dgmr.common import LatentConditioningStack, ContextConditioningStack
-from dgmr.generators import Sampler, Generator
-from dgmr.discriminators import Discriminator
-from dgmr.hub import NowcastingModelHubMixin
class DGMR(pl.LightningModule, NowcastingModelHubMixin):
diff --git a/dgmr/discriminators.py b/dgmr/discriminators.py
index 64cde4c..a2711e3 100644
--- a/dgmr/discriminators.py
+++ b/dgmr/discriminators.py
@@ -1,9 +1,10 @@
import torch
+import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin
from torch.nn.modules.pixelshuffle import PixelUnshuffle
from torch.nn.utils.parametrizations import spectral_norm
-import torch.nn.functional as F
+
from dgmr.common import DBlock
-from huggingface_hub import PyTorchModelHubMixin
class Discriminator(torch.nn.Module, PyTorchModelHubMixin):
diff --git a/dgmr/generators.py b/dgmr/generators.py
index bb0758c..c4b744e 100644
--- a/dgmr/generators.py
+++ b/dgmr/generators.py
@@ -1,13 +1,15 @@
+import logging
+from typing import List
+
import einops
import torch
import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin
from torch.nn.modules.pixelshuffle import PixelShuffle
from torch.nn.utils.parametrizations import spectral_norm
-from typing import List
+
from dgmr.common import GBlock, UpsampleGBlock
from dgmr.layers import ConvGRU
-from huggingface_hub import PyTorchModelHubMixin
-import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARN)
@@ -27,6 +29,7 @@ def __init__(
The sampler takes the output from the Latent and Context conditioning stacks and
creates one stack of ConvGRU layers per future timestep.
+
Args:
forecast_steps: Number of forecast steps
latent_channels: Number of input channels to the lowest ConvGRU layer
diff --git a/dgmr/hub.py b/dgmr/hub.py
index d6f4d23..5731b6b 100644
--- a/dgmr/hub.py
+++ b/dgmr/hub.py
@@ -12,7 +12,6 @@
import torch
-
try:
from huggingface_hub import cached_download, hf_hub_url
diff --git a/dgmr/layers/Attention.py b/dgmr/layers/Attention.py
index 54895da..a9e251c 100644
--- a/dgmr/layers/Attention.py
+++ b/dgmr/layers/Attention.py
@@ -1,7 +1,7 @@
+import einops
import torch
import torch.nn as nn
from torch.nn import functional as F
-import einops
def attention_einsum(q, k, v):
diff --git a/dgmr/layers/utils.py b/dgmr/layers/utils.py
index 0c1f35f..86682f3 100644
--- a/dgmr/layers/utils.py
+++ b/dgmr/layers/utils.py
@@ -1,4 +1,5 @@
import torch
+
from dgmr.layers import CoordConv
diff --git a/dgmr/losses.py b/dgmr/losses.py
index 6281f1d..061ebbb 100644
--- a/dgmr/losses.py
+++ b/dgmr/losses.py
@@ -1,7 +1,7 @@
import numpy as np
import torch
import torch.nn as nn
-from pytorch_msssim import SSIM, MS_SSIM
+from pytorch_msssim import MS_SSIM, SSIM
from torch.nn import functional as F
@@ -9,6 +9,7 @@ class SSIMLoss(nn.Module):
def __init__(self, convert_range: bool = False, **kwargs):
"""
SSIM Loss, optionally converting input range from [-1,1] to [0,1]
+
Args:
convert_range:
**kwargs:
@@ -28,6 +29,7 @@ class MS_SSIMLoss(nn.Module):
def __init__(self, convert_range: bool = False, **kwargs):
"""
Multi-Scale SSIM Loss, optionally converting input range from [-1,1] to [0,1]
+
Args:
convert_range:
**kwargs:
@@ -78,6 +80,7 @@ def tv_loss(img, tv_weight):
Inputs:
- img: PyTorch Variable of shape (1, 3, H, W) holding an input image.
- tv_weight: Scalar giving the weight w_t to use for the TV loss.
+
Returns:
- loss: PyTorch Variable holding a scalar giving the total variation loss
for img weighted by tv_weight.
@@ -138,6 +141,7 @@ def forward(self, generated_images, targets):
"""
Calculates the grid cell regularizer value, assumes generated images are the mean predictions from
6 calls to the generater (Monte Carlo estimation of the expectations for the latent variable)
+
Args:
generated_images: Mean generated images from the generator
targets: Ground truth future frames
diff --git a/setup.py b/setup.py
index 4607da8..131d4ea 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,7 @@
-from setuptools import setup, find_packages
from pathlib import Path
+from setuptools import find_packages, setup
+
this_directory = Path(__file__).parent
install_requires = (this_directory / "requirements.txt").read_text().splitlines()
long_description = (this_directory / "README.md").read_text()
diff --git a/train/run.py b/train/run.py
index 8d27fc9..27b158b 100644
--- a/train/run.py
+++ b/train/run.py
@@ -1,19 +1,19 @@
import torch.utils.data.dataset
-from dgmr import DGMR
+import wandb
from datasets import load_dataset
-from torch.utils.data import DataLoader
from pytorch_lightning import (
LightningDataModule,
)
from pytorch_lightning.callbacks import ModelCheckpoint
-import wandb
+from torch.utils.data import DataLoader
+
+from dgmr import DGMR
wandb.init(project="dgmr")
-from numpy.random import default_rng
-import os
-import numpy as np
from pathlib import Path
-import tensorflow as tf
+
+import numpy as np
+from numpy.random import default_rng
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
@@ -136,7 +136,7 @@ def __len__(self):
def __getitem__(self, item):
try:
row = next(self.iter_reader)
- except Exception as e:
+ except Exception:
rng = default_rng()
self.iter_reader = iter(
self.reader.shuffle(seed=rng.integers(low=0, high=100000), buffer_size=1000)