[pl]: update to pl 1.3+
ruotianluo committed May 16, 2021
1 parent b2fe2bf commit 5076fbc
Showing 2 changed files with 68 additions and 46 deletions.
12 changes: 6 additions & 6 deletions captioning/utils/misc.py
@@ -157,7 +157,7 @@ def length_average(length, logprobs, alpha=0.):
return logprobs / length


class NoamOpt(object):
class NoamOpt(torch.optim.Optimizer):
"Optim wrapper that implements rate."
def __init__(self, model_size, factor, warmup, optimizer):
self.optimizer = optimizer
@@ -167,14 +167,14 @@ def __init__(self, model_size, factor, warmup, optimizer):
self.model_size = model_size
self._rate = 0

def step(self):
def step(self, *args, **kwargs):
"Update parameters and rate"
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p['lr'] = rate
self._rate = rate
self.optimizer.step()
self.optimizer.step(*args, **kwargs)

def rate(self, step = None):
"Implement `lrate` above"
@@ -198,16 +198,16 @@ def load_state_dict(self, state_dict):
del state_dict['_step']
self.optimizer.load_state_dict(state_dict)

class ReduceLROnPlateau(object):
class ReduceLROnPlateau(torch.optim.Optimizer):
"Optim wrapper that implements rate."
def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
self.optimizer = optimizer
self.current_lr = get_lr(optimizer)

def step(self):
def step(self, *args, **kwargs):
"Update parameters and rate"
self.optimizer.step()
self.optimizer.step(*args, **kwargs)

def scheduler_step(self, val):
self.scheduler.step(val)
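For context on the misc.py change: Lightning 1.3 validates the objects returned from configure_optimizers and, in automatic optimization, calls optimizer.step() with its own keyword arguments (such as a closure), which is presumably why both wrappers now subclass torch.optim.Optimizer and forward *args/**kwargs to the inner optimizer. A minimal sketch of that wrapper pattern, with illustrative names that are not part of the repository:

import torch

class WrappedOptimizer(torch.optim.Optimizer):  # hypothetical wrapper, not from the repo
    def __init__(self, optimizer):
        # Skip Optimizer.__init__ on purpose: all state lives on the inner optimizer.
        self.optimizer = optimizer

    def step(self, *args, **kwargs):
        # Lightning may call step(closure=...); forward whatever it passes.
        return self.optimizer.step(*args, **kwargs)

    def __getattr__(self, name):
        # Anything not defined here (param_groups, state, ...) resolves to the inner optimizer.
        return getattr(self.optimizer, name)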
102 changes: 62 additions & 40 deletions tools/train_pl.py
@@ -104,6 +104,7 @@ def training_step(self, data, batch_idx):
loss = model_out.pop('loss')
loss = torch.topk(loss, k=int(loss.shape[0] * (1-self.opt.drop_worst_rate)), largest=False)[0].mean()

# Prepare for logging info
data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
data_time = torch.tensor(data_time)

@@ -117,13 +118,11 @@ def training_step(self, data, batch_idx):
logger_logs['training_loss'] = loss
logger_logs['data_time'] = data_time

output = {
'loss': loss,
'log': logger_logs,
'progress_bar': {'data_time': data_time}
}
for k, v in logger_logs.items():
self.log(k, v, on_epoch=(k=='training_loss'), prog_bar=(k=='data_time'))
# logged

return output
return loss

def validation_step(self, data, batch_idx):
model = self.model
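For context on the training_step change above: in the Lightning versions this commit targets, scalars are reported through self.log rather than by returning 'log'/'progress_bar' dictionaries, and the step itself just returns the loss. A minimal, self-contained sketch of that pattern (module and values are illustrative, not from the repo):

import torch
import pytorch_lightning as pl

class LoggingExample(pl.LightningModule):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        self.log('training_loss', loss, on_epoch=True)  # what used to go in the 'log' dict
        self.log('data_time', 0.0, prog_bar=True)        # what used to go in 'progress_bar'
        return loss                                      # the bare loss replaces the output dict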
@@ -202,24 +201,25 @@ def validation_step(self, data, batch_idx):
fc_feats, att_feats, att_masks, data], eval_kwargs)

output = {
'val_loss': loss,
'loss': loss,
'predictions': predictions,
'n_predictions': n_predictions,
}
self.log('loss', loss)
return output

def test_step(self, *args, **kwargs):
return self.validation_step(*args, **kwargs)

def validation_epoch_end(self, outputs, split='val'):
def split_epoch_end(self, outputs, split='val'):
outputs = d2comm.gather(outputs)
# master node
if d2comm.is_main_process():
assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
outputs = sum(outputs, [])

opt = self.opt
val_loss_mean = sum([_['val_loss']
loss_mean = sum([_['loss'].item()
for _ in outputs]) / len(outputs)

predictions = sum([_['predictions'] for _ in outputs], [])
@@ -247,13 +247,13 @@ def validation_epoch_end(self, outputs, split='val'):
if 'CIDEr' in lang_stats:
optimizer.scheduler_step(-lang_stats['CIDEr'])
else:
optimizer.scheduler_step(val_loss_mean)
optimizer.scheduler_step(loss_mean)

out = {
'val_loss': val_loss_mean
'loss': loss_mean
}
out.update(lang_stats)
out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -val_loss_mean
out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -loss_mean
else:
out = {}

@@ -263,23 +263,25 @@ def validation_epoch_end(self, outputs, split='val'):
# must all be tensors
out = {k: torch.tensor(v) if not torch.is_tensor(
v) else v for k, v in out.items()}
return {
'progress_bar': {'val_loss': out['val_loss']},
'log': out,
}

return out

def validation_epoch_end(self, outputs):
out = self.split_epoch_end(outputs, 'val')
out['val_loss'] = out.pop('loss')
for k,v in out.items():
self.log(k, v)
return out

def test_epoch_end(self, outputs):
out = self.validation_epoch_end(outputs, 'test')
out['progress_bar'] = {
'test_loss': out['progress_bar']['val_loss']
}
out['log']['test_loss'] = out['log']['val_loss']
del out['log']['val_loss']
del out['log']['to_monitor']
out['log'] = {'test_'+k if 'test' not in k else k:v \
for k,v in out['log'].items()}
out = self.split_epoch_end(outputs, 'test')
out['test_loss'] = out.pop('loss')
out = {'test_'+k if 'test' not in k else k: v
for k, v in out.items()}
for k,v in out.items():
self.log(k, v)
return out

def configure_optimizers(self):
opt = self.opt
model = self.model
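For context on the epoch-end refactor earlier in this hunk: in Lightning 1.3 validation_epoch_end and test_epoch_end no longer return 'log'/'progress_bar' dictionaries; aggregated metrics are reported with self.log, so the shared aggregation moved into split_epoch_end and each hook only renames and logs its keys. A compact sketch of the hook shape (illustrative, assuming validation_step returned a loss tensor):

import torch
import pytorch_lightning as pl

class EpochEndExample(pl.LightningModule):  # hypothetical module, for illustration only
    def validation_epoch_end(self, outputs):
        # `outputs` collects the per-batch return values of validation_step
        val_loss = torch.stack(outputs).mean()
        self.log('val_loss', val_loss)  # replaces returning {'log': ..., 'progress_bar': ...}

    def test_epoch_end(self, outputs):
        self.log('test_loss', torch.stack(outputs).mean())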
@@ -309,11 +311,11 @@ def optimizer_step(self, epoch, batch_idx, optimizer,
super().optimizer_step(epoch, batch_idx, optimizer,
optimizer_idx, *args, **kwargs)

def state_dict(self):
def state_dict(self, *args, **kwargs):
"""
Save the model state dict as well as opt and vocab
"""
state_dict = self.model.state_dict()
state_dict = self.model.state_dict(*args, **kwargs)
device = next(iter(state_dict.values())).device
assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
state_dict.update({
@@ -345,10 +347,16 @@ def load_state_dict(self, state_dict=None, strict=True):
raise KeyError
self.model.load_state_dict(state_dict, strict)

def get_progress_bar_dict(self):
# don't show the version number
items = super().get_progress_bar_dict()
items.pop("v_num", None)
return items


class OnEpochStartCallback(pl.Callback):

def on_epoch_start(self, trainer, pl_module):
def on_train_epoch_start(self, trainer, pl_module):
# Update lr/training stage/scheduled sampling prob etc.
opt = pl_module.opt
model = pl_module.model
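For context on the callback rename above: in Lightning 1.3 the generic on_epoch_start hook also runs around validation and test epochs, so per-training-epoch work such as the learning-rate and scheduled-sampling updates belongs in on_train_epoch_start. A minimal sketch of such a callback (illustrative, not from the repo):

import pytorch_lightning as pl

class ExampleEpochStartCallback(pl.Callback):  # hypothetical callback, for illustration only
    def on_train_epoch_start(self, trainer, pl_module):
        # runs once at the start of every training epoch
        print(f'starting epoch {trainer.current_epoch}')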
@@ -402,21 +410,32 @@ class ModelCheckpoint(pl.callbacks.ModelCheckpoint):

def on_keyboard_interrupt(self, trainer, pl_module):
# Save model when keyboard interrupt
filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
self._save_model(filepath)
filepath = os.path.join(self.dirpath, pl_module.opt.id+'_interrupt.ckpt')
self._save_model(trainer, filepath=filepath)


opt = opts.parse_opt()

checkpoint_callback = ModelCheckpoint(
filepath=opt.checkpoint_path,
dirpath=opt.checkpoint_path,
filename=opt.id+'_{epoch}-{step}',
save_last=True,
save_top_k=1,
verbose=True,
monitor='to_monitor',
mode='max',
prefix=opt.id+'_',
)
checkpoint_callback.CHECKPOINT_NAME_LAST = opt.id+"_last"
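For reference on the checkpoint setup above: Lightning 1.3's ModelCheckpoint takes dirpath and a filename template instead of the removed filepath and prefix arguments, and the name of the save_last checkpoint comes from the CHECKPOINT_NAME_LAST attribute. A sketch with illustrative values:

from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(
    dirpath='checkpoints/',           # replaces the removed `filepath`
    filename='myrun_{epoch}-{step}',  # the removed `prefix` is folded into the template
    save_last=True,
    save_top_k=1,
    monitor='to_monitor',
    mode='max',
    verbose=True,
)
ckpt.CHECKPOINT_NAME_LAST = 'myrun_last'  # file name used for the `save_last` checkpoint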


tb_logger = pl.loggers.TensorBoardLogger(opt.checkpoint_path +
'/lightning_logs/',
name='',
version=0)
wandb_logger = pl.loggers.WandbLogger(name=opt.id,
id=opt.id,
project='captioning',
log_model=True)

print("""
val_image_use,
@@ -438,29 +457,32 @@ def on_keyboard_interrupt(self, trainer, pl_module):
lit = LitModel(opt)
# warning grad_clip_mode is ignored.
trainer = pl.Trainer(
logger=[tb_logger, wandb_logger],
callbacks=[
OnEpochStartCallback(),
pl.callbacks.lr_logger.LearningRateLogger()
pl.callbacks.LearningRateMonitor(),
checkpoint_callback,
],
default_root_dir=opt.checkpoint_path,
resume_from_checkpoint=resume_from,
distributed_backend='ddp',
accelerator='ddp',
check_val_every_n_epoch=1,
max_epochs=opt.max_epochs,
gradient_clip_algorithm=opt.grad_clip_mode,
gradient_clip_val=opt.grad_clip_value,
gpus=torch.cuda.device_count(),
checkpoint_callback=checkpoint_callback,
log_gpu_memory='min_max',
log_save_interval=opt.losses_log_every,
profiler=True,
row_log_interval=10, # what is it?
log_every_n_steps=opt.losses_log_every,
profiler='simple',
num_sanity_val_steps=0,
# limit_train_batches=500,
# limit_train_batches=100,
# progress_bar_refresh_rate=0,
# fast_dev_run=True,
)
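For reference on the Trainer arguments above: several of them changed name or type in Lightning 1.3, which is what this block migrates to: accelerator instead of distributed_backend, log_every_n_steps instead of row_log_interval, a named profiler instead of a boolean, and the new gradient_clip_algorithm. A minimal sketch with illustrative values:

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator='ddp',               # replaces the deprecated `distributed_backend`
    gpus=1,
    max_epochs=10,
    gradient_clip_val=0.1,
    gradient_clip_algorithm='norm',  # added in 1.3: 'norm' or 'value'
    log_every_n_steps=50,            # replaces `row_log_interval`
    profiler='simple',               # a profiler name instead of `profiler=True`
    num_sanity_val_steps=0,
)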

if os.getenv('EVALUATE', '0') == '1':
lit.load_state_dict(
torch.load(resume_from, map_location='cpu')['state_dict'], strict=False)
trainer.test(lit)
else:
trainer.fit(lit)
