From d7d8a10c1ffccb04e3f92ae552bbdb1c7e82c026 Mon Sep 17 00:00:00 2001
From: Siqi Liu
Date: Sat, 18 Nov 2017 21:03:04 +1100
Subject: [PATCH] remove if True: in train.py

---
 train.py | 61 ++++++++++++++++++++++++++++-------------------------------
 1 file changed, 30 insertions(+), 31 deletions(-)

diff --git a/train.py b/train.py
index 7a626977..eafa5b4b 100644
--- a/train.py
+++ b/train.py
@@ -161,40 +161,39 @@ def train(opt):
                 current_score = - val_loss
 
             best_flag = False
-            if True: # if true
-                if best_val_score is None or current_score > best_val_score:
-                    best_val_score = current_score
-                    best_flag = True
-                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
+            if best_val_score is None or current_score > best_val_score:
+                best_val_score = current_score
+                best_flag = True
+            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
+            torch.save(model.state_dict(), checkpoint_path)
+            print("model saved to {}".format(checkpoint_path))
+            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
+            torch.save(optimizer.state_dict(), optimizer_path)
+
+            # Dump miscalleous informations
+            infos['iter'] = iteration
+            infos['epoch'] = epoch
+            infos['iterators'] = loader.iterators
+            infos['split_ix'] = loader.split_ix
+            infos['best_val_score'] = best_val_score
+            infos['opt'] = opt
+            infos['vocab'] = loader.get_vocab()
+
+            histories['val_result_history'] = val_result_history
+            histories['loss_history'] = loss_history
+            histories['lr_history'] = lr_history
+            histories['ss_prob_history'] = ss_prob_history
+            with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
+                cPickle.dump(infos, f)
+            with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
+                cPickle.dump(histories, f)
+
+            if best_flag:
+                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                 torch.save(model.state_dict(), checkpoint_path)
                 print("model saved to {}".format(checkpoint_path))
-                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
-                torch.save(optimizer.state_dict(), optimizer_path)
-
-                # Dump miscalleous informations
-                infos['iter'] = iteration
-                infos['epoch'] = epoch
-                infos['iterators'] = loader.iterators
-                infos['split_ix'] = loader.split_ix
-                infos['best_val_score'] = best_val_score
-                infos['opt'] = opt
-                infos['vocab'] = loader.get_vocab()
-
-                histories['val_result_history'] = val_result_history
-                histories['loss_history'] = loss_history
-                histories['lr_history'] = lr_history
-                histories['ss_prob_history'] = ss_prob_history
-                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
+                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                     cPickle.dump(infos, f)
-                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
-                    cPickle.dump(histories, f)
-
-                if best_flag:
-                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
-                    torch.save(model.state_dict(), checkpoint_path)
-                    print("model saved to {}".format(checkpoint_path))
-                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
-                        cPickle.dump(infos, f)
 
         # Stop if reaching max epochs
         if epoch >= opt.max_epochs and opt.max_epochs != -1: