-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathtest.py
75 lines (58 loc) · 1.99 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import os
import time
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
import scipy.io as scio
import scipy.misc as mc
import importlib
from datasets.sklarge import TestDataset
# Inference configuration: which module under `networks/` to load, and the GPU to use.
network = 'hed'
gpu_id = 0

# New float tensors default to CUDA, and the active CUDA device is pinned before
# any model or tensor is created (order matters for the lines below).
torch.set_default_tensor_type('torch.cuda.FloatTensor')
torch.cuda.set_device(gpu_id)

# Resolve `networks.<network>.Network` dynamically from the name above.
Network = importlib.import_module('networks.{}'.format(network)).Network

# Build the model on the selected GPU in inference mode, then restore the
# trained weights. The map_location callback returns each storage unchanged,
# i.e. the checkpoint is deserialized onto the CPU before being loaded.
net = Network()
net = net.cuda(gpu_id)
net.eval()
net.load_state_dict(
    torch.load(
        './weights/hed_sklarge/hed_30000.pth',
        map_location=lambda storage, _location: storage,
    )
)

# SK-LARGE test split: image directory plus the list file naming the test images.
root = './OriginalSKLARGE/images/test'
files = './OriginalSKLARGE/test.lst'
dataset = TestDataset(files, root)
# Every batch (batch_size=1) is materialized up front so samples can be indexed
# directly later on; this keeps the whole test set in memory.
dataloader = list(DataLoader(dataset=dataset, batch_size=1))
def plot_single_scale(scale_lst, size):
    """Draw each prediction map in ``scale_lst`` side by side in a new figure.

    Args:
        scale_lst: iterable of network outputs. Only element ``[0, 0]`` of each
            is drawn (first batch item, first channel). This assumes an
            (N, C, H, W) layout; TODO confirm against the network.
        size: base figure size. The global ``figure.figsize`` rcParam is set to
            ``(size / 2, size / 2.5)`` before the figure is created.

    The figure is not shown here; the caller calls ``plt.show()`` when ready.
    """
    maps = list(scale_lst)
    pylab.rcParams['figure.figsize'] = size / 2, size / 2.5
    plt.figure()
    # One row with one column per map. A hard-coded 1x1 grid would reject
    # subplot index 2, so more than one map could not be drawn.
    n_cols = max(len(maps), 1)
    for i, image in enumerate(maps):
        image = image.data[0, 0].cpu().numpy().astype(np.float32)
        s = plt.subplot(1, n_cols, i + 1)
        # Inverted under a reversed grey map, so high responses render dark
        # (presumably values lie in [0, 1]; verify against the network output).
        plt.imshow(1 - image, cmap=cm.Greys_r)
        # Hide tick labels and tick marks; only the map itself is of interest.
        s.set_xticklabels([])
        s.set_yticklabels([])
        s.yaxis.set_ticks_position('none')
        s.xaxis.set_ticks_position('none')
    plt.tight_layout()
# Qualitative sanity check: run the network on one test image and display its
# prediction map before the full export below.
idx = 16  # arbitrary sample; requires at least 17 entries in the test list
inp, fname = dataloader[idx]
# Inference only: no_grad skips building an autograd graph. The deprecated
# torch.autograd.Variable wrapper is a no-op since PyTorch 0.4, so the plain
# tensor is passed directly.
with torch.no_grad():
    inp = inp.cuda(gpu_id)
    out = net(inp)
scale_lst = [out]
plot_single_scale(scale_lst, 22)
plt.show()
# Export the prediction for every test image to `<output_dir>/<name>.mat`
# (MATLAB variable 'sym'), then report the average wall-clock time per image.
# The timing includes the .mat write for each image.
output_dir = 'outputs/hed/'
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(output_dir, exist_ok=True)
start_time = time.time()
# Inference only: no_grad skips autograd graph construction for every image.
with torch.no_grad():
    for inp, fname in dataloader:
        out = net(inp.cuda(gpu_id))
        # batch_size is 1, so fname[0] is this sample's name.
        fileName = os.path.join(output_dir, fname[0] + '.mat')
        # First batch item, first channel of the output map.
        scio.savemat(fileName, {'sym': out[0, 0].cpu().numpy()})
diff_time = time.time() - start_time
# Avoid ZeroDivisionError when the test list is empty.
n_images = max(len(dataloader), 1)
print('Detection took {:.5f}s per image'.format(diff_time / n_images))