From be1a526b058707eb136ede9397b54d971374fb1f Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Fri, 14 May 2021 22:02:31 -0500 Subject: [PATCH] [pl]: update to pl 1.3+ --- captioning/utils/misc.py | 12 ++--- tools/train_pl.py | 101 +++++++++++++++++++++++---------------- 2 files changed, 67 insertions(+), 46 deletions(-) diff --git a/captioning/utils/misc.py b/captioning/utils/misc.py index 8bd3193d..7bf84f0c 100644 --- a/captioning/utils/misc.py +++ b/captioning/utils/misc.py @@ -157,7 +157,7 @@ def length_average(length, logprobs, alpha=0.): return logprobs / length -class NoamOpt(object): +class NoamOpt(torch.optim.Optimizer): "Optim wrapper that implements rate." def __init__(self, model_size, factor, warmup, optimizer): self.optimizer = optimizer @@ -167,14 +167,14 @@ def __init__(self, model_size, factor, warmup, optimizer): self.model_size = model_size self._rate = 0 - def step(self): + def step(self, *args, **kwargs): "Update parameters and rate" self._step += 1 rate = self.rate() for p in self.optimizer.param_groups: p['lr'] = rate self._rate = rate - self.optimizer.step() + self.optimizer.step(*args, **kwargs) def rate(self, step = None): "Implement `lrate` above" @@ -198,16 +198,16 @@ def load_state_dict(self, state_dict): del state_dict['_step'] self.optimizer.load_state_dict(state_dict) -class ReduceLROnPlateau(object): +class ReduceLROnPlateau(torch.optim.Optimizer): "Optim wrapper that implements rate." def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08): self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps) self.optimizer = optimizer self.current_lr = get_lr(optimizer) - def step(self): + def step(self, *args, **kwargs): "Update parameters and rate" - self.optimizer.step() + self.optimizer.step(*args, **kwargs) def scheduler_step(self, val): self.scheduler.step(val) diff --git a/tools/train_pl.py b/tools/train_pl.py index 58529542..c397e052 100644 --- a/tools/train_pl.py +++ b/tools/train_pl.py @@ -104,6 +104,7 @@ def training_step(self, data, batch_idx): loss = model_out.pop('loss') loss = torch.topk(loss, k=int(loss.shape[0] * (1-self.opt.drop_worst_rate)), largest=False)[0].mean() + # Prepare for logging info data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1] data_time = torch.tensor(data_time) @@ -117,13 +118,11 @@ def training_step(self, data, batch_idx): logger_logs['training_loss'] = loss logger_logs['data_time'] = data_time - output = { - 'loss': loss, - 'log': logger_logs, - 'progress_bar': {'data_time': data_time} - } + for k, v in logger_logs.items(): + self.log(k, v, on_epoch=(k=='training_loss'), prog_bar=(k=='data_time')) + # logged - return output + return loss def validation_step(self, data, batch_idx): model = self.model @@ -202,7 +201,7 @@ def validation_step(self, data, batch_idx): fc_feats, att_feats, att_masks, data], eval_kwargs) output = { - 'val_loss': loss, + 'loss': loss, 'predictions': predictions, 'n_predictions': n_predictions, } @@ -211,7 +210,7 @@ def validation_step(self, data, batch_idx): def test_step(self, *args, **kwargs): return self.validation_step(*args, **kwargs) - def validation_epoch_end(self, outputs, split='val'): + def split_epoch_end(self, outputs, split='val'): outputs = d2comm.gather(outputs) # master node if d2comm.is_main_process(): @@ -219,7 +218,7 @@ def validation_epoch_end(self, outputs, split='val'): outputs = sum(outputs, []) opt = self.opt - val_loss_mean = sum([_['val_loss'] + loss_mean = sum([_['loss'].item() for _ in outputs]) / len(outputs) predictions = sum([_['predictions'] for _ in outputs], []) @@ -247,13 +246,13 @@ def validation_epoch_end(self, outputs, split='val'): if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: - optimizer.scheduler_step(val_loss_mean) + optimizer.scheduler_step(loss_mean) out = { - 'val_loss': val_loss_mean + 'loss': loss_mean } out.update(lang_stats) - out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -val_loss_mean + out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -loss_mean else: out = {} @@ -263,23 +262,25 @@ def validation_epoch_end(self, outputs, split='val'): # must all be tensors out = {k: torch.tensor(v) if not torch.is_tensor( v) else v for k, v in out.items()} - return { - 'progress_bar': {'val_loss': out['val_loss']}, - 'log': out, - } + + return out + + def validation_epoch_end(self, outputs): + out = self.split_epoch_end(outputs, 'val') + out['val_loss'] = out.pop('loss') + for k,v in out.items(): + self.log(k, v) + return out def test_epoch_end(self, outputs): - out = self.validation_epoch_end(outputs, 'test') - out['progress_bar'] = { - 'test_loss': out['progress_bar']['val_loss'] - } - out['log']['test_loss'] = out['log']['val_loss'] - del out['log']['val_loss'] - del out['log']['to_monitor'] - out['log'] = {'test_'+k if 'test' not in k else k:v \ - for k,v in out['log'].items()} + out = self.split_epoch_end(outputs, 'test') + out['test_loss'] = out.pop('loss') + out = {'test_'+k if 'test' not in k else k: v + for k, v in out.items()} + for k,v in out.items(): + self.log(k, v) return out - + def configure_optimizers(self): opt = self.opt model = self.model @@ -309,11 +310,11 @@ def optimizer_step(self, epoch, batch_idx, optimizer, super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs) - def state_dict(self): + def state_dict(self, *args, **kwargs): """ Save the model state dict as well as opt and vocab """ - state_dict = self.model.state_dict() + state_dict = self.model.state_dict(*args, **kwargs) device = next(iter(state_dict.values())).device assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case' state_dict.update({ @@ -345,10 +346,16 @@ def load_state_dict(self, state_dict=None, strict=True): raise KeyError self.model.load_state_dict(state_dict, strict) + def get_progress_bar_dict(self): + # don't show the version number + items = super().get_progress_bar_dict() + items.pop("v_num", None) + return items + class OnEpochStartCallback(pl.Callback): - def on_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer, pl_module): # Update lr/training stage/scheduled sampling prob etc. opt = pl_module.opt model = pl_module.model @@ -402,21 +409,32 @@ class ModelCheckpoint(pl.callbacks.ModelCheckpoint): def on_keyboard_interrupt(self, trainer, pl_module): # Save model when keyboard interrupt - filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt') - self._save_model(filepath) + filepath = os.path.join(self.dirpath, pl_module.opt.id+'_interrupt.ckpt') + self._save_model(trainer, filepath=filepath) opt = opts.parse_opt() checkpoint_callback = ModelCheckpoint( - filepath=opt.checkpoint_path, + dirpath=opt.checkpoint_path, + filename=opt.id+'_{epoch}-{step}', save_last=True, save_top_k=1, verbose=True, monitor='to_monitor', mode='max', - prefix=opt.id+'_', ) +checkpoint_callback.CHECKPOINT_NAME_LAST = opt.id+"_last" + + +tb_logger = pl.loggers.TensorBoardLogger(opt.checkpoint_path + + '/lightning_logs/', + name='', + version=0) +wandb_logger = pl.loggers.WandbLogger(name=opt.id, + id=opt.id, + project='captioning', + log_model=True) print(""" val_image_use, @@ -438,29 +456,32 @@ def on_keyboard_interrupt(self, trainer, pl_module): lit = LitModel(opt) # warning grad_clip_mode is ignored. trainer = pl.Trainer( + logger=[tb_logger, wandb_logger], callbacks=[ OnEpochStartCallback(), - pl.callbacks.lr_logger.LearningRateLogger() + pl.callbacks.LearningRateMonitor(), + checkpoint_callback, ], default_root_dir=opt.checkpoint_path, resume_from_checkpoint=resume_from, - distributed_backend='ddp', + accelerator='ddp', check_val_every_n_epoch=1, max_epochs=opt.max_epochs, + gradient_clip_algorithm=opt.grad_clip_mode, gradient_clip_val=opt.grad_clip_value, gpus=torch.cuda.device_count(), - checkpoint_callback=checkpoint_callback, log_gpu_memory='min_max', - log_save_interval=opt.losses_log_every, - profiler=True, - row_log_interval=10, # what is it? + log_every_n_steps=opt.losses_log_every, + profiler='simple', num_sanity_val_steps=0, - # limit_train_batches=500, + # limit_train_batches=100, # progress_bar_refresh_rate=0, # fast_dev_run=True, ) if os.getenv('EVALUATE', '0') == '1': + lit.load_state_dict( + torch.load(resume_from, map_location='cpu')['state_dict'], strict=False) trainer.test(lit) else: trainer.fit(lit)