-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
executable file
·96 lines (73 loc) · 2.84 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import sys
import skvideo.io
import json
import base64
import torch
from io import BytesIO, StringIO
from scipy import misc
from argparse import ArgumentParser
from torchvision.transforms import ToTensor
import numpy as np
from torch.autograd import Variable
# from PIL import Image
import timeit
import cv2
def eval(file, net, device, outfile):
    """Run per-frame segmentation on a video and emit the masks as JSON.

    Each frame is converted with ``ToTensor``, given a batch axis of size 1
    and passed through ``net``. The index of the maximum over dim 1 of the
    output gives a per-pixel class label map. Label 2 becomes the "car" mask
    and label 1 the "road" mask. Both are PNG-encoded as single-channel
    images with pixel values 0/1, base64-encoded, and stored under the frame
    number (starting at 1).

    Note: this function shadows the builtin ``eval``; the name is kept so
    existing callers keep working.

    Args:
        file: Path of the input video, streamed frame by frame with
            ``skvideo.io.vreader``.
        net: Segmentation model; it is switched to eval mode here.
            Assumes its output is (1, num_classes, H, W) — TODO confirm.
        device: Torch device (e.g. ``'cpu'`` or ``'cuda'``) to run on.
        outfile: Path of the JSON file to write, or ``None`` to print the
            JSON to stdout.

    Returns:
        The number of frames processed.

    Raises:
        RuntimeError: If OpenCV fails to PNG-encode a mask.
    """
    def encode(array):
        # PNG-encode a 2-D uint8 mask and return it as base64 text.
        ok, buffer = cv2.imencode('.png', array)
        if not ok:
            raise RuntimeError('cv2.imencode failed to encode a mask as PNG')
        return base64.b64encode(buffer).decode("utf-8")

    # ToTensor holds no per-call state, so one instance serves every frame.
    to_tensor = ToTensor()
    answer_key = {}
    net.eval()

    # Inference only: no_grad skips autograd bookkeeping (less memory, same
    # outputs), which also makes an explicit .detach() unnecessary below.
    with torch.no_grad():
        # vreader streams frames, so memory use does not grow with video
        # length (vread would decode the entire clip into one array).
        frames = skvideo.io.vreader(file)
        for frame, rgb_frame in enumerate(frames, start=1):
            img = to_tensor(rgb_frame).to(device).unsqueeze(0)
            # Class index of the highest score at every pixel, batch item 0.
            labels = net(img).max(1)[1][0]
            # Comparing the labels directly sizes the masks to the network's
            # output. uint8 is a depth PNG supports natively and yields the
            # same 0/1 pixel values the float masks were converted to.
            car_mask = (labels == 2).to(torch.uint8).cpu().numpy()
            road_mask = (labels == 1).to(torch.uint8).cpu().numpy()
            answer_key[frame] = [encode(car_mask), encode(road_mask)]

    # Emit the JSON; the json module serializes integer frame keys as strings.
    if outfile is None:
        print(json.dumps(answer_key))
    else:
        with open(outfile, 'w') as fh:
            json.dump(answer_key, fh)

    # One entry per decoded frame, so this is the true frame count.
    return len(answer_key)
def main(args):
    """Load the ERFNet checkpoint and segment ``args.video``.

    Args:
        args: Parsed command-line namespace. Reads ``device``, ``model``,
            ``video`` and ``output``.
    """
    # Project-local model definition, imported at call time as before.
    from erfnet import Net

    net = Net(num_classes=3).to(args.device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only run
    # (and vice versa).
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints.
    state = torch.load(args.model, map_location=args.device)
    net.load_state_dict(state)

    # Time only the inference pass; model loading is deliberately excluded
    # so the FPS figure reflects per-frame throughput.
    start = timeit.default_timer()
    frames_processed = eval(args.video, net, args.device, args.output)
    elapsed = timeit.default_timer() - start
    # FPS reporting stays disabled because stdout carries the JSON result.
    # To enable it, send it to stderr instead:
    # print("FPS: ", frames_processed / elapsed, file=sys.stderr)
if __name__ == '__main__':
    # Command-line interface, declared as (flags, keyword-options) pairs and
    # registered in order, so --help lists them exactly as before.
    # NOTE: --answer_file and --batch_size are accepted for compatibility
    # but are not read by main().
    option_specs = (
        (('--device',), {'action': 'store', 'default': 'cuda',
                         'help': 'choices={cpu, cuda, cuda:0, cuda:1, ...}'}),
        (('video',), {'type': str, 'metavar': 'input_video',
                      'help': 'Input video'}),
        (('--answer_file',), {'default': 'Example/results.json',
                              'help': 'Correct json file (answer key)'}),
        (('--model',), {'default': 'model.pth',
                        'help': 'model file to load'}),
        (('--batch_size',), {'type': int, 'default': 1,
                             'help': 'Batch size'}),
        (('--output',), {'help': 'file to save the output json'}),
    )
    cli = ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    cli_args = cli.parse_args()
    main(cli_args)