-
Notifications
You must be signed in to change notification settings - Fork 384
/
Copy pathdemo.py
120 lines (99 loc) · 4.06 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import glob
import os
import time
import torch
from PIL import Image
from vizer.draw import draw_boxes
from ssd.config import cfg
from ssd.data.datasets import COCODataset, VOCDataset
import argparse
import numpy as np
from ssd.data.transforms import build_transforms
from ssd.modeling.detector import build_detection_model
from ssd.utils import mkdir
from ssd.utils.checkpoint import CheckPointer
def _format_meters(num_objects, load_time, inference_time):
    """Return the per-image summary line ("objects | load | inference | FPS").

    ``load_time`` and ``inference_time`` are in seconds. A zero
    ``inference_time`` (possible with a coarse clock) is reported as
    ``FPS inf`` instead of raising ZeroDivisionError.
    """
    fps = round(1.0 / inference_time) if inference_time > 0 else float('inf')
    return ' | '.join(
        [
            'objects {:02d}'.format(num_objects),
            'load {:03d}ms'.format(round(load_time * 1000)),
            'inference {:03d}ms'.format(round(inference_time * 1000)),
            'FPS {}'.format(fps),
        ]
    )


@torch.no_grad()
def run_demo(cfg, ckpt, score_threshold, images_dir, output_dir, dataset_type):
    """Run a trained SSD detector over every ``*.jpg`` in ``images_dir``.

    For each image, detections with a score strictly above ``score_threshold``
    are drawn onto the image, which is saved under ``output_dir`` with the
    same file name. A timing summary is printed per image. Gradient tracking
    is disabled for the whole call via ``torch.no_grad``.

    Args:
        cfg: config object; ``cfg.MODEL.DEVICE`` and ``cfg.OUTPUT_DIR`` are
            read here, and the whole object is passed to the model and
            transform builders.
        ckpt: checkpoint path, or None. When None, ``use_latest=True`` is
            passed to ``CheckPointer.load`` (checkpoints under
            ``cfg.OUTPUT_DIR``).
        score_threshold: detections with ``score > score_threshold`` are kept.
        images_dir: directory searched (non-recursively) for ``*.jpg`` files.
        output_dir: destination directory, passed to ``ssd.utils.mkdir``
            before any image is written.
        dataset_type: ``'voc'`` or ``'coco'``; selects the class-name list.

    Raises:
        NotImplementedError: if ``dataset_type`` is neither ``'voc'`` nor
            ``'coco'``.
    """
    if dataset_type == "voc":
        class_names = VOCDataset.class_names
    elif dataset_type == 'coco':
        class_names = COCODataset.class_names
    else:
        raise NotImplementedError('Not implemented now.')

    device = torch.device(cfg.MODEL.DEVICE)
    model = build_detection_model(cfg)
    model = model.to(device)
    checkpointer = CheckPointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.load(ckpt, use_latest=ckpt is None)
    weight_file = ckpt if ckpt else checkpointer.get_checkpoint_file()
    print('Loaded weights from {}'.format(weight_file))

    # glob() order is filesystem-dependent; sort so runs are reproducible.
    image_paths = sorted(glob.glob(os.path.join(images_dir, '*.jpg')))
    if not image_paths:
        print('No *.jpg images found in {}'.format(images_dir))
    mkdir(output_dir)

    cpu_device = torch.device("cpu")
    transforms = build_transforms(cfg, is_train=False)
    model.eval()
    for i, image_path in enumerate(image_paths):
        start = time.perf_counter()
        image_name = os.path.basename(image_path)
        # Use a context manager so each file handle is closed immediately
        # rather than leaking until garbage collection.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        height, width = image.shape[:2]
        # NOTE(review): transforms appears to return a tuple whose first item
        # is the image tensor; unsqueeze(0) adds a batch dimension of 1.
        images = transforms(image)[0].unsqueeze(0)
        load_time = time.perf_counter() - start

        start = time.perf_counter()
        result = model(images.to(device))[0]
        if device.type == 'cuda':
            # CUDA kernels execute asynchronously; wait for them so the
            # measured time covers the forward pass, not just its launch.
            torch.cuda.synchronize(device)
        inference_time = time.perf_counter() - start

        # Presumably rescales predictions to the original image size
        # (width, height) before moving them to CPU/numpy — confirm in the
        # result container's resize().
        result = result.resize((width, height)).to(cpu_device).numpy()
        boxes, labels, scores = result['boxes'], result['labels'], result['scores']
        keep = scores > score_threshold
        boxes = boxes[keep]
        labels = labels[keep]
        scores = scores[keep]

        meters = _format_meters(len(boxes), load_time, inference_time)
        print('({:04d}/{:04d}) {}: {}'.format(i + 1, len(image_paths), image_name, meters))
        drawn_image = draw_boxes(image, boxes, labels, scores, class_names).astype(np.uint8)
        Image.fromarray(drawn_image).save(os.path.join(output_dir, image_name))
def main():
    """Parse command-line arguments, build the config and run the SSD demo.

    The global ``cfg`` is updated from ``--config-file`` (when given) and then
    from the trailing ``opts`` via ``cfg.merge_from_list``, frozen, printed,
    and passed to ``run_demo`` along with the remaining arguments.
    """
    parser = argparse.ArgumentParser(description="SSD Demo.")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--ckpt", type=str, default=None, help="Trained weights.")
    parser.add_argument("--score_threshold", type=float, default=0.7)
    parser.add_argument("--images_dir", default='demo', type=str, help='Specify a image dir to do prediction.')
    parser.add_argument("--output_dir", default='demo/result', type=str, help='Specify a image dir to save predicted images.')
    # choices= rejects an unsupported dataset type at parse time, before the
    # (expensive) model is built and weights are loaded.
    parser.add_argument("--dataset_type", default="voc", type=str, choices=["voc", "coco"],
                        help='Specify dataset type. Currently support voc and coco.')
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    print(args)

    # --config-file defaults to "", and merging or opening "" would fail;
    # only use the file when one was actually supplied.
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    if args.config_file:
        print("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
            print(config_str)
    print("Running with config:\n{}".format(cfg))

    run_demo(cfg=cfg,
             ckpt=args.ckpt,
             score_threshold=args.score_threshold,
             images_dir=args.images_dir,
             output_dir=args.output_dir,
             dataset_type=args.dataset_type)


if __name__ == '__main__':
    main()