Skip to content

Commit

Permalink
Merge pull request #60 from victor30608/patch-1
Browse files Browse the repository at this point in the history
Fixed a bug with missing imports
  • Loading branch information
peterdudfield authored Nov 2, 2023
2 parents 3e3c95f + 524ee84 commit d37b198
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ generator = Generator(conditioning_stack=context_stack, latent_stack=latent_stac
```python
from dgmr import DGMR
import torch.nn.functional as F
import torch
model = DGMR(
forecast_steps=4,
input_channels=1,
Expand Down
3 changes: 0 additions & 3 deletions train/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


def get_wandb_logger(trainer: Trainer) -> WandbLogger:

if trainer.fast_dev_run:
raise Exception(
"Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode."
Expand All @@ -40,7 +39,6 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger:


class WatchModel(Callback):

def __init__(self, log: str = "gradients", log_freq: int = 100):
self.log = log
self.log_freq = log_freq
Expand All @@ -52,7 +50,6 @@ def on_train_start(self, trainer, pl_module):


class UploadCheckpointsAsArtifact(Callback):

def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
self.ckpt_dir = ckpt_dir
self.upload_best_only = upload_best_only
Expand Down

0 comments on commit d37b198

Please sign in to comment.