Updates to newer versions of requirements
jacobbieker committed Oct 23, 2023
1 parent 478fee0 commit 7abec47
Showing 19 changed files with 5 additions and 112 deletions.
2 changes: 1 addition & 1 deletion dgmr/hub.py
@@ -129,7 +129,7 @@ def _from_pretrained(
proxies,
resume_download,
local_files_only,
- use_auth_token,
+ use_auth_token=False,
map_location="cpu",
strict=False,
**model_kwargs,
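Note: the only change in this hunk is that use_auth_token is now passed explicitly as a keyword set to False. For reference, recent huggingface_hub releases deprecate use_auth_token in favor of token; a minimal sketch of an unauthenticated download under the newer API (repo_id and filename are illustrative placeholders, not values from this file):

    from huggingface_hub import hf_hub_download

    # token=False disables authentication, mirroring use_auth_token=False;
    # repo_id/filename below are placeholders, not the values used in dgmr/hub.py.
    weights_path = hf_hub_download(
        repo_id="some-org/some-model",  # illustrative
        filename="pytorch_model.bin",   # illustrative
        token=False,
    )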
Binary file not shown.
13 changes: 0 additions & 13 deletions tests/lightning_logs/version_0/hparams.yaml

This file was deleted.

Binary file not shown.
13 changes: 0 additions & 13 deletions tests/lightning_logs/version_1/hparams.yaml

This file was deleted.

Binary file not shown.
13 changes: 0 additions & 13 deletions tests/lightning_logs/version_2/hparams.yaml

This file was deleted.

Binary file not shown.
13 changes: 0 additions & 13 deletions tests/lightning_logs/version_3/hparams.yaml

This file was deleted.

Binary file not shown.
13 changes: 0 additions & 13 deletions tests/lightning_logs/version_4/hparams.yaml

This file was deleted.

Binary file not shown.
13 changes: 0 additions & 13 deletions tests/lightning_logs/version_5/hparams.yaml

This file was deleted.

Binary file not shown.
13 changes: 0 additions & 13 deletions tests/lightning_logs/version_6/hparams.yaml

This file was deleted.

Binary file not shown.
13 changes: 0 additions & 13 deletions tests/lightning_logs/version_7/hparams.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions tests/test_model.py
@@ -307,7 +307,7 @@ def test_load_dgmr_from_hf():
_ = Generator(conditioning_stack=ctz, latent_stack=lat, sampler=sam)


- @pytest.mark.skip("Takes too long")
+ #@pytest.mark.skip("Takes too long")
def test_train_dgmr():
forecast_steps = 8

@@ -324,7 +324,7 @@ def __getitem__(self, idx):
train_loader = torch.utils.data.DataLoader(DS(), batch_size=1)
val_loader = torch.utils.data.DataLoader(DS(), batch_size=1)

- trainer = Trainer(gpus=0, max_epochs=1)
+ trainer = Trainer(accelerator="cpu", max_epochs=1)
model = DGMR(forecast_steps=forecast_steps)

trainer.fit(model, train_loader, val_loader)
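For context: Trainer(gpus=...) was removed in PyTorch Lightning 2.x, which is why the test now selects hardware via accelerator. A minimal sketch of the replacement arguments (values are illustrative):

    from pytorch_lightning import Trainer

    # PL >= 2.0: device selection uses accelerator/devices instead of the removed gpus argument.
    trainer = Trainer(accelerator="cpu", max_epochs=1)   # was Trainer(gpus=0, max_epochs=1)
    # trainer = Trainer(accelerator="gpu", devices=1)    # was Trainer(gpus=1)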
7 changes: 2 additions & 5 deletions train/run.py
@@ -15,12 +15,11 @@
import numpy as np
from numpy.random import default_rng
from pytorch_lightning import Callback, Trainer
- from pytorch_lightning.loggers import LoggerCollection, WandbLogger
+ from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_only


def get_wandb_logger(trainer: Trainer) -> WandbLogger:
"""Safely get Weights&Biases logger from Trainer."""

if trainer.fast_dev_run:
raise Exception(
@@ -41,7 +40,6 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger:


class WatchModel(Callback):
"""Make wandb watch model at the beginning of the run."""

def __init__(self, log: str = "gradients", log_freq: int = 100):
self.log = log
@@ -54,7 +52,6 @@ def on_train_start(self, trainer, pl_module):


class UploadCheckpointsAsArtifact(Callback):
"""Upload checkpoints to wandb as an artifact, at the end of run."""

def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
self.ckpt_dir = ckpt_dir
@@ -209,7 +206,7 @@ def val_dataloader(self):
max_epochs=1000,
logger=wandb_logger,
callbacks=[model_checkpoint],
- gpus=6,
+ accelerator="auto",
precision=32,
# accelerator="tpu", devices=8
)
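LoggerCollection no longer exists in recent PyTorch Lightning releases; Trainer.loggers returns a plain list instead. A sketch of how get_wandb_logger can locate the WandbLogger without LoggerCollection (an assumption about the intended replacement, not the committed code):

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import WandbLogger

    def get_wandb_logger(trainer: Trainer) -> WandbLogger:
        # trainer.loggers is a list in PL >= 1.8, so no LoggerCollection handling is needed.
        for logger in trainer.loggers:
            if isinstance(logger, WandbLogger):
                return logger
        raise Exception("WandbLogger was not found in trainer.loggers")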
