'''
This file is a part of the official implementation of
1) "DISCO: accurate Discrete Scale Convolutions"
by Ivan Sosnovik, Artem Moskalev, Arnold Smeulders, BMVC 2021
arxiv: https://arxiv.org/abs/2106.02733
2) "How to Transform Kernels for Scale-Convolutions"
by Ivan Sosnovik, Artem Moskalev, Arnold Smeulders, ICCV VIPriors 2021
pdf: https://openaccess.thecvf.com/content/ICCV2021W/VIPriors/papers/Sosnovik_How_To_Transform_Kernels_for_Scale-Convolutions_ICCVW_2021_paper.pdf
---------------------------------------------------------------------------
MIT License. Copyright (c) 2021 Ivan Sosnovik, Artem Moskalev
'''
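# This script pre-computes a DISCO basis: it minimizes an equivariance loss
# over batches of images and saves the best basis tensor to --basis_save_dir.
#
# Example invocation (paths and scale values are illustrative only):
#   python calculate_disco_basis.py --basis_save_dir ./precomputed_basis \
#       --basis_size 7 --basis_effective_size 3 --basis_scales 1.0 1.41 2.0 --cuda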
import os
import time
from argparse import ArgumentParser
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
from models.basis import ApproximateProxyBasis
from models.basis.disco import get_basis_filename
from utils import loaders
from utils.train_utils import train_equi_loss
from utils.model_utils import get_num_parameters
#########################################
# arguments
#########################################
parser = ArgumentParser()
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--epochs', type=int, default=40)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--lr_steps', type=int, nargs='+', default=[20, 30])
parser.add_argument('--lr_gamma', type=float, default=0.1)
parser.add_argument('--cuda', action='store_true', default=False)
# basis hyperparameters
parser.add_argument('--basis_size', type=int, default=7)
parser.add_argument('--basis_effective_size', type=int, default=3)
parser.add_argument('--basis_scales', type=float, nargs='+', default=[1.0])
parser.add_argument('--basis_save_dir', type=str, required=True)
args = parser.parse_args()
print("Args:")
for k, v in vars(args).items():
print(" {}={}".format(k, v))
print(flush=True)
#########################################
# Data
#########################################
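# The equivariance objective below needs no labels, so the loader presumably
# yields batches of random images (see utils.loaders.random_loader).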
loader = loaders.random_loader(args.batch_size)
print('Dataset:')
print(loader.dataset)
#########################################
# Model
#########################################
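# ApproximateProxyBasis constructs a trainable proxy basis for scale
# convolutions; judging by the constructor arguments, `size` is the spatial
# extent of the filters, `effective_size` their effective support, and
# `scales` the set of scales the basis should cover.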
basis = ApproximateProxyBasis(size=args.basis_size, scales=args.basis_scales,
                              effective_size=args.basis_effective_size)
print('\nBasis:')
print(basis)
print()
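# Fall back to CPU when --cuda is passed but no GPU is available.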
use_cuda = args.cuda and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('Device: {}'.format(device))
if use_cuda:
    cudnn.enabled = True
    cudnn.benchmark = True
    print('CUDNN is enabled. CUDNN benchmark is enabled')
    basis.cuda()
print(flush=True)
#########################################
# optimizer
#########################################
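# Only the trainable basis parameters are optimized: Adam with a multi-step
# learning-rate schedule (decayed by --lr_gamma at each of --lr_steps).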
parameters = filter(lambda x: x.requires_grad, basis.parameters())
optimizer = optim.Adam(parameters, lr=args.lr)
print(optimizer)
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_steps, args.lr_gamma)
#########################################
# Paths
#########################################
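# The checkpoint filename encodes the basis hyperparameters, so differently
# configured bases do not overwrite each other.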
save_basis_postfix = get_basis_filename(size=args.basis_size,
                                        effective_size=args.basis_effective_size,
                                        scales=args.basis_scales)
save_basis_path = os.path.join(args.basis_save_dir, save_basis_postfix)
print('Basis path: ', save_basis_path)
print()
if not os.path.isdir(args.basis_save_dir):
    os.makedirs(args.basis_save_dir)
#########################################
# Training
#########################################
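# Each epoch minimizes the equivariance loss of the basis; whenever the loss
# improves, the current basis tensor is materialized (get_basis), moved to
# CPU, and checkpointed, so the saved file always holds the best basis so far.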
print('\nTraining\n' + '-' * 30)
start_time = time.time()
best_loss = float('inf')
for epoch in range(args.epochs):
    loss = train_equi_loss(basis, optimizer, loader, device)
    print('Epoch {:3d}/{:3d} | Loss: {:.2e}'.format(epoch + 1, args.epochs, loss), flush=True)
    if loss < best_loss:
        best_loss = loss
        with torch.no_grad():
            torch.save(basis.get_basis().cpu(), save_basis_path)
    lr_scheduler.step()
print('-' * 30)
print('Training is finished')
print('Best Loss: {:.2e}'.format(best_loss), flush=True)
end_time = time.time()
elapsed_time = end_time - start_time
time_per_epoch = elapsed_time / args.epochs
print('Total Time Elapsed: {:.2f} s'.format(elapsed_time))
print('Time per Epoch: {:.2f} s'.format(time_per_epoch))