diff --git a/dgmr/__init__.py b/dgmr/__init__.py
index b9cd0b7..a1a3190 100644
--- a/dgmr/__init__.py
+++ b/dgmr/__init__.py
@@ -1,4 +1,4 @@
+from .common import ContextConditioningStack, LatentConditioningStack
 from .dgmr import DGMR
-from .generators import Sampler, Generator
-from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator
-from .common import LatentConditioningStack, ContextConditioningStack
+from .discriminators import Discriminator, SpatialDiscriminator, TemporalDiscriminator
+from .generators import Generator, Sampler
diff --git a/dgmr/common.py b/dgmr/common.py
index f6e13da..70b20fc 100644
--- a/dgmr/common.py
+++ b/dgmr/common.py
@@ -3,12 +3,13 @@
 import einops
 import torch
 import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin
 from torch.distributions import normal
-from torch.nn.utils.parametrizations import spectral_norm
 from torch.nn.modules.pixelshuffle import PixelUnshuffle
-from dgmr.layers.utils import get_conv_layer
+from torch.nn.utils.parametrizations import spectral_norm
+
 from dgmr.layers import AttentionLayer
-from huggingface_hub import PyTorchModelHubMixin
+from dgmr.layers.utils import get_conv_layer
 
 
 class GBlock(torch.nn.Module):
diff --git a/dgmr/dgmr.py b/dgmr/dgmr.py
index c2772d8..d1a71af 100644
--- a/dgmr/dgmr.py
+++ b/dgmr/dgmr.py
@@ -1,17 +1,18 @@
+import pytorch_lightning as pl
 import torch
+import torchvision
+
+from dgmr.common import ContextConditioningStack, LatentConditioningStack
+from dgmr.discriminators import Discriminator
+from dgmr.generators import Generator, Sampler
+from dgmr.hub import NowcastingModelHubMixin
 from dgmr.losses import (
-    NowcastingLoss,
     GridCellLoss,
+    NowcastingLoss,
+    grid_cell_regularizer,
     loss_hinge_disc,
     loss_hinge_gen,
-    grid_cell_regularizer,
 )
-import pytorch_lightning as pl
-import torchvision
-from dgmr.common import LatentConditioningStack, ContextConditioningStack
-from dgmr.generators import Sampler, Generator
-from dgmr.discriminators import Discriminator
-from dgmr.hub import NowcastingModelHubMixin
 
 
 class DGMR(pl.LightningModule, NowcastingModelHubMixin):
diff --git a/dgmr/discriminators.py b/dgmr/discriminators.py
index 64cde4c..a2711e3 100644
--- a/dgmr/discriminators.py
+++ b/dgmr/discriminators.py
@@ -1,9 +1,10 @@
 import torch
+import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin
 from torch.nn.modules.pixelshuffle import PixelUnshuffle
 from torch.nn.utils.parametrizations import spectral_norm
-import torch.nn.functional as F
+
 from dgmr.common import DBlock
-from huggingface_hub import PyTorchModelHubMixin
 
 
 class Discriminator(torch.nn.Module, PyTorchModelHubMixin):
diff --git a/dgmr/generators.py b/dgmr/generators.py
index bb0758c..c4b744e 100644
--- a/dgmr/generators.py
+++ b/dgmr/generators.py
@@ -1,13 +1,15 @@
+import logging
+from typing import List
+
 import einops
 import torch
 import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin
 from torch.nn.modules.pixelshuffle import PixelShuffle
 from torch.nn.utils.parametrizations import spectral_norm
-from typing import List
+
 from dgmr.common import GBlock, UpsampleGBlock
 from dgmr.layers import ConvGRU
-from huggingface_hub import PyTorchModelHubMixin
-import logging
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARN)
@@ -27,6 +29,7 @@ def __init__(
 
         The sampler takes the output from the Latent and Context conditioning stacks and
         creates one stack of ConvGRU layers per future timestep.
+
         Args:
             forecast_steps: Number of forecast steps
             latent_channels: Number of input channels to the lowest ConvGRU layer
diff --git a/dgmr/hub.py b/dgmr/hub.py
index d6f4d23..5731b6b 100644
--- a/dgmr/hub.py
+++ b/dgmr/hub.py
@@ -12,7 +12,6 @@
 
 import torch
 
-
 try:
     from huggingface_hub import cached_download, hf_hub_url
 
diff --git a/dgmr/layers/Attention.py b/dgmr/layers/Attention.py
index 54895da..a9e251c 100644
--- a/dgmr/layers/Attention.py
+++ b/dgmr/layers/Attention.py
@@ -1,7 +1,7 @@
+import einops
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-import einops
 
 
 def attention_einsum(q, k, v):
diff --git a/dgmr/layers/utils.py b/dgmr/layers/utils.py
index 0c1f35f..86682f3 100644
--- a/dgmr/layers/utils.py
+++ b/dgmr/layers/utils.py
@@ -1,4 +1,5 @@
 import torch
+
 from dgmr.layers import CoordConv
 
 
diff --git a/dgmr/losses.py b/dgmr/losses.py
index 6281f1d..061ebbb 100644
--- a/dgmr/losses.py
+++ b/dgmr/losses.py
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from pytorch_msssim import SSIM, MS_SSIM
+from pytorch_msssim import MS_SSIM, SSIM
 from torch.nn import functional as F
 
 
@@ -9,6 +9,7 @@ class SSIMLoss(nn.Module):
     def __init__(self, convert_range: bool = False, **kwargs):
         """
         SSIM Loss, optionally converting input range from [-1,1] to [0,1]
+
         Args:
             convert_range:
             **kwargs:
@@ -28,6 +29,7 @@ class MS_SSIMLoss(nn.Module):
     def __init__(self, convert_range: bool = False, **kwargs):
         """
         Multi-Scale SSIM Loss, optionally converting input range from [-1,1] to [0,1]
+
         Args:
             convert_range:
             **kwargs:
@@ -78,6 +80,7 @@ def tv_loss(img, tv_weight):
     Inputs:
     - img: PyTorch Variable of shape (1, 3, H, W) holding an input image.
     - tv_weight: Scalar giving the weight w_t to use for the TV loss.
+
     Returns:
     - loss: PyTorch Variable holding a scalar giving the total variation loss
       for img weighted by tv_weight.
@@ -138,6 +141,7 @@ def forward(self, generated_images, targets):
         """
         Calculates the grid cell regularizer value, assumes generated images are the mean predictions
         from 6 calls to the generator (Monte Carlo estimation of the expectations for the latent variable)
+
         Args:
             generated_images: Mean generated images from the generator
             targets: Ground truth future frames
diff --git a/setup.py b/setup.py
index 4607da8..131d4ea 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,7 @@
-from setuptools import setup, find_packages
 from pathlib import Path
 
+from setuptools import find_packages, setup
+
 this_directory = Path(__file__).parent
 install_requires = (this_directory / "requirements.txt").read_text().splitlines()
 long_description = (this_directory / "README.md").read_text()
diff --git a/train/run.py b/train/run.py
index 8d27fc9..27b158b 100644
--- a/train/run.py
+++ b/train/run.py
@@ -1,19 +1,19 @@
 import torch.utils.data.dataset
-from dgmr import DGMR
+import wandb
 from datasets import load_dataset
-from torch.utils.data import DataLoader
 from pytorch_lightning import (
     LightningDataModule,
 )
 from pytorch_lightning.callbacks import ModelCheckpoint
-import wandb
+from torch.utils.data import DataLoader
+
+from dgmr import DGMR
 
 wandb.init(project="dgmr")
-from numpy.random import default_rng
-import os
-import numpy as np
 from pathlib import Path
-import tensorflow as tf
+
+import numpy as np
+from numpy.random import default_rng
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import LoggerCollection, WandbLogger
 from pytorch_lightning.utilities import rank_zero_only
@@ -136,7 +136,7 @@ def __len__(self):
     def __getitem__(self, item):
         try:
             row = next(self.iter_reader)
-        except Exception as e:
+        except Exception:
             rng = default_rng()
             self.iter_reader = iter(
                 self.reader.shuffle(seed=rng.integers(low=0, high=100000), buffer_size=1000)