-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
98 lines (80 loc) · 3.11 KB
/
train.py
1
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os
import time
from os.path import join, isdir

import torch
import torchvision
from torch.nn import functional

from loss.cross_entropy_loss import cross_entropy_loss2d
from loss.dynamic_focal_loss import dy_focal_loss
from loss.new_loss import sample_balance_loss
from utils import Averagvalue, save_checkpoint
from utils import log_lr

# Per-output loss weights shared by the "DFL" and "SBL" loss modes; they are
# paired with the model outputs in order via zip().
_OUTPUT_WEIGHTS = (0.5, 0.5, 0.5, 0.5, 0.5, 1.1)


def _compute_loss(args, outputs, label, KDlabel, epoch):
    """Return the summed loss over all model outputs for the mode in ``args.loss``."""
    loss = torch.zeros(1).cuda()
    if args.loss == "DFL":
        for o, w in zip(outputs, _OUTPUT_WEIGHTS):
            loss = loss + w * dy_focal_loss(o, label.cuda(), epoch, args)
    elif args.loss == "SBL":
        for o, w in zip(outputs, _OUTPUT_WEIGHTS):
            loss = loss + w * sample_balance_loss(o, KDlabel.cuda(), args, epoch)
    elif args.loss == "WCE":
        for o in outputs:
            loss += cross_entropy_loss2d(o, label.cuda())
    else:
        raise Exception("illegal loss function")
    return loss


def _save_visualization(args, outputs, label, KDlabel, save_dir, i):
    """Write the outputs plus the target map(s) of batch ``i`` to ``iter-<i>.jpg``."""
    panels = [o.cpu() for o in outputs]
    # Compare strings by value: an identity test (`is not`) depends on string
    # interning and could silently skip the KDlabel panel in "SBL" mode.
    if args.loss != "SBL":
        panels.append((label == 1).float())
    else:
        panels.extend([(label == 1).float(), KDlabel])
    all_results = torch.cat(panels, dim=0)
    # The image is saved as 1 - value, i.e. with inverted intensities.
    torchvision.utils.save_image(1 - all_results, join(save_dir, "iter-%d.jpg" % i))


def train(cfg, args, train_loader, model, optimizer, scheduler, epoch, save_dir):
    """Run one pass over ``train_loader`` and save a checkpoint.

    Gradients are accumulated: each batch loss is divided by ``cfg.itersize``
    and the optimizer steps once every ``cfg.itersize`` batches.

    Args:
        cfg: config object; ``itersize`` and ``msg_iter`` are read.
        args: run options; ``loss`` ("DFL", "SBL" or "WCE"), ``max_epoch``
            and ``start_epoch`` are read here, and ``args`` is forwarded to
            the DFL/SBL loss functions.
        train_loader: iterable yielding ``(image, label, KDlabel, pth)``.
        model: network called as ``model(image)``; its result is iterated
            as a sequence of output tensors.
        optimizer: optimizer stepped during the pass.
        scheduler: learning-rate scheduler stepped once after the pass.
        epoch: current epoch number (used for logging, the DFL/SBL losses
            and the checkpoint filename).
        save_dir: directory for visualizations and the checkpoint; created
            if missing.

    Returns:
        Tuple ``(losses.avg, epoch_loss)``: the running average kept by the
        ``Averagvalue`` meter and the list of per-batch (already scaled)
        loss values.

    Raises:
        Exception: if ``args.loss`` is not one of the supported modes.
    """
    # display and logging
    os.makedirs(save_dir, exist_ok=True)
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()

    # switch to train mode
    model.train()
    end = time.time()
    epoch_loss = []
    counter = 0
    for i, (image, label, KDlabel, pth) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        image = image.cuda()
        outputs = model(image)

        loss = _compute_loss(args, outputs, label, KDlabel, epoch)

        counter += 1
        # Scale so the accumulated gradient averages over cfg.itersize batches.
        loss = loss / cfg.itersize
        loss.backward()
        if counter == cfg.itersize:
            optimizer.step()
            optimizer.zero_grad()
            counter = 0
        # NOTE(review): when len(train_loader) is not a multiple of
        # cfg.itersize, the gradients of the trailing batches are neither
        # applied nor zeroed in this call -- confirm this is intended.

        # measure accuracy and record loss
        losses.update(loss.item(), image.size(0))
        epoch_loss.append(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()

        if i % cfg.msg_iter == 0:
            # Show the third-from-last path component of save_dir; fall back
            # to the whole path when it has fewer than three components
            # instead of raising IndexError in the middle of training.
            path_parts = save_dir.split('/')
            dir_tag = path_parts[-3] if len(path_parts) >= 3 else save_dir
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.max_epoch + args.start_epoch, i, len(train_loader)) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'Save_dir:{} '.format(dir_tag) + \
                   'Loss {loss.val:f} (avg:{loss.avg:f})'.format(loss=losses)
            print(info)
            _save_visualization(args, outputs, label, KDlabel, save_dir, i)

    # NOTE(review): the statements below are placed after the batch loop
    # (once per call), matching the per-epoch checkpoint filename -- confirm
    # against the original file layout.
    # adjust lr
    scheduler.step()
    log_lr(optimizer)
    # save checkpoint
    save_checkpoint({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, filename=join(save_dir, "epoch-%d-checkpoint.pth" % epoch))

    return losses.avg, epoch_loss