diff --git a/dgmr/hub.py b/dgmr/hub.py index 5731b6b..568883d 100644 --- a/dgmr/hub.py +++ b/dgmr/hub.py @@ -129,7 +129,7 @@ def _from_pretrained( proxies, resume_download, local_files_only, - use_auth_token, + use_auth_token=False, map_location="cpu", strict=False, **model_kwargs, diff --git a/tests/lightning_logs/version_0/events.out.tfevents.1655536486.ocf.4177843.0 b/tests/lightning_logs/version_0/events.out.tfevents.1655536486.ocf.4177843.0 deleted file mode 100644 index 8c80c48..0000000 Binary files a/tests/lightning_logs/version_0/events.out.tfevents.1655536486.ocf.4177843.0 and /dev/null differ diff --git a/tests/lightning_logs/version_0/hparams.yaml b/tests/lightning_logs/version_0/hparams.yaml deleted file mode 100644 index 5bef6ee..0000000 --- a/tests/lightning_logs/version_0/hparams.yaml +++ /dev/null @@ -1,13 +0,0 @@ -beta1: 0.0 -beta2: 0.999 -context_channels: 384 -conv_type: standard -disc_lr: 0.0002 -forecast_steps: 4 -gen_lr: 5.0e-05 -grid_lambda: 20.0 -input_channels: 1 -latent_channels: 768 -num_samples: 6 -output_shape: 128 -visualize: false diff --git a/tests/lightning_logs/version_1/events.out.tfevents.1655536602.ocf.4179322.0 b/tests/lightning_logs/version_1/events.out.tfevents.1655536602.ocf.4179322.0 deleted file mode 100644 index f3b6014..0000000 Binary files a/tests/lightning_logs/version_1/events.out.tfevents.1655536602.ocf.4179322.0 and /dev/null differ diff --git a/tests/lightning_logs/version_1/hparams.yaml b/tests/lightning_logs/version_1/hparams.yaml deleted file mode 100644 index 5bef6ee..0000000 --- a/tests/lightning_logs/version_1/hparams.yaml +++ /dev/null @@ -1,13 +0,0 @@ -beta1: 0.0 -beta2: 0.999 -context_channels: 384 -conv_type: standard -disc_lr: 0.0002 -forecast_steps: 4 -gen_lr: 5.0e-05 -grid_lambda: 20.0 -input_channels: 1 -latent_channels: 768 -num_samples: 6 -output_shape: 128 -visualize: false diff --git a/tests/lightning_logs/version_2/events.out.tfevents.1655536666.ocf.4180026.0 b/tests/lightning_logs/version_2/events.out.tfevents.1655536666.ocf.4180026.0 deleted file mode 100644 index 159ba9f..0000000 Binary files a/tests/lightning_logs/version_2/events.out.tfevents.1655536666.ocf.4180026.0 and /dev/null differ diff --git a/tests/lightning_logs/version_2/hparams.yaml b/tests/lightning_logs/version_2/hparams.yaml deleted file mode 100644 index 5bef6ee..0000000 --- a/tests/lightning_logs/version_2/hparams.yaml +++ /dev/null @@ -1,13 +0,0 @@ -beta1: 0.0 -beta2: 0.999 -context_channels: 384 -conv_type: standard -disc_lr: 0.0002 -forecast_steps: 4 -gen_lr: 5.0e-05 -grid_lambda: 20.0 -input_channels: 1 -latent_channels: 768 -num_samples: 6 -output_shape: 128 -visualize: false diff --git a/tests/lightning_logs/version_3/events.out.tfevents.1655536730.ocf.4180814.0 b/tests/lightning_logs/version_3/events.out.tfevents.1655536730.ocf.4180814.0 deleted file mode 100644 index 8f88f35..0000000 Binary files a/tests/lightning_logs/version_3/events.out.tfevents.1655536730.ocf.4180814.0 and /dev/null differ diff --git a/tests/lightning_logs/version_3/hparams.yaml b/tests/lightning_logs/version_3/hparams.yaml deleted file mode 100644 index 5bef6ee..0000000 --- a/tests/lightning_logs/version_3/hparams.yaml +++ /dev/null @@ -1,13 +0,0 @@ -beta1: 0.0 -beta2: 0.999 -context_channels: 384 -conv_type: standard -disc_lr: 0.0002 -forecast_steps: 4 -gen_lr: 5.0e-05 -grid_lambda: 20.0 -input_channels: 1 -latent_channels: 768 -num_samples: 6 -output_shape: 128 -visualize: false diff --git a/tests/lightning_logs/version_4/events.out.tfevents.1655536810.ocf.4181676.0 b/tests/lightning_logs/version_4/events.out.tfevents.1655536810.ocf.4181676.0 deleted file mode 100644 index ea676ea..0000000 Binary files a/tests/lightning_logs/version_4/events.out.tfevents.1655536810.ocf.4181676.0 and /dev/null differ diff --git a/tests/lightning_logs/version_4/hparams.yaml b/tests/lightning_logs/version_4/hparams.yaml deleted file mode 100644 index 5bef6ee..0000000 --- a/tests/lightning_logs/version_4/hparams.yaml +++ /dev/null @@ -1,13 +0,0 @@ -beta1: 0.0 -beta2: 0.999 -context_channels: 384 -conv_type: standard -disc_lr: 0.0002 -forecast_steps: 4 -gen_lr: 5.0e-05 -grid_lambda: 20.0 -input_channels: 1 -latent_channels: 768 -num_samples: 6 -output_shape: 128 -visualize: false diff --git a/tests/lightning_logs/version_5/events.out.tfevents.1655536886.ocf.4182513.0 b/tests/lightning_logs/version_5/events.out.tfevents.1655536886.ocf.4182513.0 deleted file mode 100644 index 08085d9..0000000 Binary files a/tests/lightning_logs/version_5/events.out.tfevents.1655536886.ocf.4182513.0 and /dev/null differ diff --git a/tests/lightning_logs/version_5/hparams.yaml b/tests/lightning_logs/version_5/hparams.yaml deleted file mode 100644 index 5bef6ee..0000000 --- a/tests/lightning_logs/version_5/hparams.yaml +++ /dev/null @@ -1,13 +0,0 @@ -beta1: 0.0 -beta2: 0.999 -context_channels: 384 -conv_type: standard -disc_lr: 0.0002 -forecast_steps: 4 -gen_lr: 5.0e-05 -grid_lambda: 20.0 -input_channels: 1 -latent_channels: 768 -num_samples: 6 -output_shape: 128 -visualize: false diff --git a/tests/lightning_logs/version_6/events.out.tfevents.1655536960.ocf.4183415.0 b/tests/lightning_logs/version_6/events.out.tfevents.1655536960.ocf.4183415.0 deleted file mode 100644 index 36a9597..0000000 Binary files a/tests/lightning_logs/version_6/events.out.tfevents.1655536960.ocf.4183415.0 and /dev/null differ diff --git a/tests/lightning_logs/version_6/hparams.yaml b/tests/lightning_logs/version_6/hparams.yaml deleted file mode 100644 index 5bef6ee..0000000 --- a/tests/lightning_logs/version_6/hparams.yaml +++ /dev/null @@ -1,13 +0,0 @@ -beta1: 0.0 -beta2: 0.999 -context_channels: 384 -conv_type: standard -disc_lr: 0.0002 -forecast_steps: 4 -gen_lr: 5.0e-05 -grid_lambda: 20.0 -input_channels: 1 -latent_channels: 768 -num_samples: 6 -output_shape: 128 -visualize: false diff --git a/tests/lightning_logs/version_7/events.out.tfevents.1655537037.ocf.4184301.0 b/tests/lightning_logs/version_7/events.out.tfevents.1655537037.ocf.4184301.0 deleted file mode 100644 index d135b0b..0000000 Binary files a/tests/lightning_logs/version_7/events.out.tfevents.1655537037.ocf.4184301.0 and /dev/null differ diff --git a/tests/lightning_logs/version_7/hparams.yaml b/tests/lightning_logs/version_7/hparams.yaml deleted file mode 100644 index 5bef6ee..0000000 --- a/tests/lightning_logs/version_7/hparams.yaml +++ /dev/null @@ -1,13 +0,0 @@ -beta1: 0.0 -beta2: 0.999 -context_channels: 384 -conv_type: standard -disc_lr: 0.0002 -forecast_steps: 4 -gen_lr: 5.0e-05 -grid_lambda: 20.0 -input_channels: 1 -latent_channels: 768 -num_samples: 6 -output_shape: 128 -visualize: false diff --git a/tests/test_model.py b/tests/test_model.py index 4f6f090..14feb32 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -307,7 +307,7 @@ def test_load_dgmr_from_hf(): _ = Generator(conditioning_stack=ctz, latent_stack=lat, sampler=sam) -@pytest.mark.skip("Takes too long") +#@pytest.mark.skip("Takes too long") def test_train_dgmr(): forecast_steps = 8 @@ -324,7 +324,7 @@ def __getitem__(self, idx): train_loader = torch.utils.data.DataLoader(DS(), batch_size=1) val_loader = torch.utils.data.DataLoader(DS(), batch_size=1) - trainer = Trainer(gpus=0, max_epochs=1) + trainer = Trainer(accelerator="cpu", max_epochs=1) model = DGMR(forecast_steps=forecast_steps) trainer.fit(model, train_loader, val_loader) diff --git a/train/run.py b/train/run.py index 27b158b..efc1939 100644 --- a/train/run.py +++ b/train/run.py @@ -15,12 +15,11 @@ import numpy as np from numpy.random import default_rng from pytorch_lightning import Callback, Trainer -from pytorch_lightning.loggers import LoggerCollection, WandbLogger +from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities import rank_zero_only def get_wandb_logger(trainer: Trainer) -> WandbLogger: - """Safely get Weights&Biases logger from Trainer.""" if trainer.fast_dev_run: raise Exception( @@ -41,7 +40,6 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger: class WatchModel(Callback): - """Make wandb watch model at the beginning of the run.""" def __init__(self, log: str = "gradients", log_freq: int = 100): self.log = log @@ -54,7 +52,6 @@ def on_train_start(self, trainer, pl_module): class UploadCheckpointsAsArtifact(Callback): - """Upload checkpoints to wandb as an artifact, at the end of run.""" def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): self.ckpt_dir = ckpt_dir @@ -209,7 +206,7 @@ def val_dataloader(self): max_epochs=1000, logger=wandb_logger, callbacks=[model_checkpoint], - gpus=6, + accelerator="auto", precision=32, # accelerator="tpu", devices=8 )