# train.py
import glob
import torch
import dataset
import numpy as np
from utils import *
from unet import UNet
from loss import loss_fn_kd
from metrics import dice_loss
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
teacher_weights = '/home/nirvi/Internship_2020/KDforUNET/teacher_checkpoints/32_final/CP_32_5.pth'
#student_weights = 'checkpoints/CP5.pth'
num_of_epochs = 5
summary_steps = 10
def fetch_teacher_outputs(teacher, train_loader):
    """Run the teacher once over the whole training set and cache its outputs."""
    print('-------Fetch teacher outputs-------')
    teacher.eval().cuda()

    # one output tensor per batch, in loader order
    teacher_outputs = []
    with torch.no_grad():
        # the loader yields one batch (batch_size images) per iteration,
        # so enumerate() walks the entire dataset batch by batch
        for i, (img, gt) in enumerate(train_loader):
            print(i, 'i')
            if torch.cuda.is_available():
                # `async=True` is a syntax error on Python >= 3.7; use non_blocking instead
                img = img.cuda(non_blocking=True)
                img = Variable(img)
            output = teacher(img)
            teacher_outputs.append(output)
    return teacher_outputs
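

# ---------------------------------------------------------------------------
# Reference sketch (assumption): the real loss_fn_kd is imported from loss.py
# and may differ. This only illustrates a common knowledge-distillation loss
# for a single-channel segmentation output: a distillation term that matches
# the student to the teacher's temperature-softened sigmoid output, plus a
# supervised BCE term against the ground-truth mask. The name, `alpha` and
# `temperature` are hypothetical and not taken from this repository.
# ---------------------------------------------------------------------------
import torch.nn.functional as F

def kd_loss_reference(student_logits, teacher_logits, gt, alpha=0.9, temperature=4.0):
    # soft targets from the teacher; detach so no gradient flows into the teacher
    soft_teacher = torch.sigmoid(teacher_logits.detach() / temperature)
    # distillation term: BCE between the student's softened logits and the soft targets
    distill = F.binary_cross_entropy_with_logits(student_logits / temperature, soft_teacher)
    # supervised term: BCE between the student's logits and the ground-truth mask
    supervised = F.binary_cross_entropy_with_logits(student_logits, gt)
    return alpha * distill + (1.0 - alpha) * supervised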
def train_student(student, teacher_outputs, optimizer, train_loader):
    """Train the student for one epoch against the cached teacher outputs."""
    print('-------Train student-------')
    student.train().cuda()

    summ = []
    for i, (img, gt) in enumerate(train_loader):
        # teacher_outputs[i] matches batch i because the loader is not shuffled
        teacher_output = teacher_outputs[i]
        if torch.cuda.is_available():
            img, gt = img.cuda(), gt.cuda()
            teacher_output = teacher_output.cuda()
        img, gt = Variable(img), Variable(gt)
        teacher_output = Variable(teacher_output)

        output = student(img)
        #TODO: loss is wrong
        loss = loss_fn_kd(output, teacher_output, gt)

        # clear previous gradients, compute gradients of all variables wrt loss
        optimizer.zero_grad()
        loss.backward()
        # perform updates using the calculated gradients
        optimizer.step()

        if i % summary_steps == 0:
            # .item() already returns a Python float on the CPU, so no explicit .cpu() is needed
            metric = dice_loss(output, gt)
            summary = {'metric': metric.item(), 'loss': loss.item()}
            summ.append(summary)

    mean_dice_coeff = np.mean([x['metric'] for x in summ])
    mean_loss = np.mean([x['loss'] for x in summ])
    print('- Train metrics:\n' + '\tMetric:{}\n\tLoss:{}'.format(mean_dice_coeff, mean_loss))
def evaluate_kd(student, val_loader):
    """Evaluate the student on the validation set; returns the mean Dice loss."""
    print('-------Evaluate student-------')
    student.eval().cuda()

    loss_summ = []
    with torch.no_grad():
        for i, (img, gt) in enumerate(val_loader):
            if torch.cuda.is_available():
                img, gt = img.cuda(), gt.cuda()
            img, gt = Variable(img), Variable(gt)

            output = student(img)
            # clamp raw outputs into [0, 1] before computing the Dice loss
            output = output.clamp(min=0, max=1)
            loss = dice_loss(output, gt)
            loss_summ.append(loss.item())

    mean_loss = np.mean(loss_summ)
    print('- Eval metrics:\n\tAverage Dice loss:{}'.format(mean_loss))
    return mean_loss
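

# ---------------------------------------------------------------------------
# Reference sketch (assumption): the real dice_loss is imported from
# metrics.py and may differ. A common definition is 1 - Dice coefficient,
# computed on probabilities in [0, 1]; `smooth` is a hypothetical constant
# that avoids division by zero on empty masks.
# ---------------------------------------------------------------------------
def dice_loss_reference(pred, target, smooth=1.0):
    # flatten each sample to a vector so the overlap is computed per image
    pred = pred.contiguous().view(pred.size(0), -1)
    target = target.contiguous().view(target.size(0), -1)
    intersection = (pred * target).sum(dim=1)
    dice = (2.0 * intersection + smooth) / (pred.sum(dim=1) + target.sum(dim=1) + smooth)
    return 1.0 - dice.mean()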
if __name__ == "__main__":
    min_loss = 100
    teacher = UNet(channel_depth=32, n_channels=3, n_classes=1)
    student = UNet(channel_depth=16, n_channels=3, n_classes=1)
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.2)

    # load the pretrained teacher weights; the student starts from scratch
    teacher.load_state_dict(torch.load(teacher_weights))
    #student.load_state_dict(torch.load(student_weights))

    train_list = glob.glob('/home/nirvi/Internship_2020/Carvana dataset/train/train1/*jpg')
    val_list = glob.glob('/home/nirvi/Internship_2020/Carvana dataset/val/val1/*jpg')

    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # for batch_size = 1: img is (1, 3, 320, 320) and gt is (1, 1, 320, 320)
    train_loader = torch.utils.data.DataLoader(
        dataset.listDataset(train_list, shuffle=False, transform=tf),
        batch_size=1
    )
    val_loader = torch.utils.data.DataLoader(
        dataset.listDataset(val_list, shuffle=False, transform=tf),
        batch_size=1
    )

    # get the teacher outputs for the whole training set as a list of tensors
    teacher_outputs = fetch_teacher_outputs(teacher, train_loader)
    print(len(teacher_outputs))

    for epoch in range(num_of_epochs):
        # train the student for one epoch
        print(' --- student training: epoch {}'.format(epoch+1))
        train_student(student, teacher_outputs, optimizer, train_loader)

        # evaluate on the validation set
        val = evaluate_kd(student, val_loader)
        if val < min_loss:
            min_loss = val
            #TODO: make min_loss start from the teacher's validation loss
            print('New best!!')
            # save a checkpoint whenever the validation metric improves
            torch.save(student.state_dict(), 'checkpoints/0.9/16/CP{}.pth'.format(epoch+1))
            print("Checkpoint {} saved!".format(epoch+1))

        scheduler.step()