forked from zhiqwang/yolort
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.cpp
219 lines (179 loc) · 6.42 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#ifdef _WIN32 // or _MSC_VER, as you wish
#include <windows.h>
#endif

#include <chrono>
#include <cstddef>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <opencv2/core.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#include <torch/script.h>
#include <torch/torch.h>
#include <torchvision/ops/nms.h>
#include <torchvision/vision.h>

#include "cmdline.h"
/// Load class names from a label-map file, one name per line.
///
/// Line i of the file becomes element i of the result, so the index matches
/// the class index used when looking names up for drawing.
///
/// @param path  Path of the label-map text file.
/// @return The names in file order. Empty if the file cannot be opened (an
///         error is printed to stderr in that case) or if the file is empty.
std::vector<std::string> LoadNames(const std::string& path) {
  std::vector<std::string> class_names;
  std::ifstream infile(path);
  if (!infile.good()) {
    std::cerr << "ERROR: Failed to access class name path: " << path
              << "\n\tDoes the file exist? Permission to read it?\n";
    return class_names;
  }

  std::string line;
  while (std::getline(infile, line)) {
    // getline strips only '\n'; files saved with CRLF endings would otherwise
    // leave a trailing '\r' in every name (visible in rendered labels).
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }
    // Safe to move: getline clears `line` before reading the next record.
    class_names.emplace_back(std::move(line));
  }
  // infile is closed by its destructor.
  return class_names;
}
/// Read an image file and convert it to a normalized float tensor.
///
/// The image is decoded by OpenCV with the default `cv::imread` flags, which
/// give a 3-channel BGR image. It is converted to RGB and scaled from
/// [0, 255] to [0, 1].
///
/// @param loc  Path of the image file.
/// @return A float32 tensor of shape 3 x H x W that owns its own memory.
/// @throws std::runtime_error if OpenCV cannot read or decode the file.
torch::Tensor ReadImage(const std::string& loc) {
  cv::Mat img = cv::imread(loc);
  // imread reports failure by returning an empty Mat, not by throwing; without
  // this check cvtColor would fail with an opaque OpenCV assertion.
  if (img.empty()) {
    throw std::runtime_error("ERROR: Failed to decode image file: " + loc);
  }
  cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
  img.convertTo(img, CV_32FC3, 1.0f / 255.0f); // normalization 1/255

  // from_blob does not copy: this tensor aliases img's buffer, which is
  // released when img goes out of scope. That is why clone() is required below.
  torch::Tensor img_tensor = torch::from_blob(img.data, {img.rows, img.cols, 3});
  img_tensor = img_tensor.permute({2, 0, 1}); // H x W x C -> C x H x W
  return img_tensor.clone();
}
/// One box to be drawn by OverlayBoxes.
///
/// Scalar members have default initializers so that a default-constructed
/// Detection never carries indeterminate values.
struct Detection {
  cv::Rect bbox;      // rectangle drawn on the target image
  float score{0.0f};  // printed with two decimals in the box label
  int class_idx{0};   // index into the class-name list passed to OverlayBoxes
};
/// Draw detections onto an image and write the result to disk.
///
/// Each box is drawn as a red outline. When `label` is true, a filled label
/// showing "<class name> <score with 2 decimals>" is drawn on top of it.
///
/// @param img          Image to draw on; modified in place.
/// @param detections   Boxes to draw.
/// @param class_names  Names indexed by Detection::class_idx. An index outside
///                     this list is shown as the raw number instead.
/// @param img_name     Output path handed to cv::imwrite. A write failure is
///                     reported on stderr.
/// @param label        Whether to draw the text label for each box.
void OverlayBoxes(
    cv::Mat& img,
    const std::vector<Detection>& detections,
    const std::vector<std::string>& class_names,
    const std::string& img_name,
    bool label = true) {
  const cv::Scalar box_color(0, 0, 255);
  const cv::Scalar text_color(255, 255, 255);
  const int font_face = cv::FONT_HERSHEY_DUPLEX;
  const double font_scale = 1.0;
  const int thickness = 1;
  const int label_pad = 5; // gap between the text baseline and the label's bottom edge

  for (const auto& detection : detections) {
    const cv::Rect& box = detection.bbox;
    cv::rectangle(img, box, box_color, 2);
    if (!label) {
      continue;
    }

    // Bounds-check the lookup: an index not covered by the label map (e.g. a
    // model trained on a different dataset) would otherwise read out of range.
    const int class_idx = detection.class_idx;
    const bool known_class =
        class_idx >= 0 && static_cast<std::size_t>(class_idx) < class_names.size();
    std::ostringstream text_stream;
    if (known_class) {
      text_stream << class_names[static_cast<std::size_t>(class_idx)];
    } else {
      text_stream << class_idx;
    }
    text_stream << ' ' << std::fixed << std::setprecision(2) << detection.score;
    const std::string text = text_stream.str();

    int baseline = 0;
    const cv::Size text_size =
        cv::getTextSize(text, font_face, font_scale, thickness, &baseline);

    // The label normally sits just above the box. If that would go past the
    // top edge of the image, move it down so it is drawn inside the box.
    int label_bottom = box.y;
    if (label_bottom - text_size.height - label_pad < 0) {
      label_bottom = text_size.height + label_pad;
    }
    cv::rectangle(
        img,
        cv::Point(box.x, label_bottom - text_size.height - label_pad),
        cv::Point(box.x + text_size.width, label_bottom),
        box_color,
        -1); // negative thickness: filled rectangle
    cv::putText(
        img,
        text,
        cv::Point(box.x, label_bottom - label_pad),
        font_face,
        font_scale,
        text_color,
        thickness);
  }

  // imwrite signals failure through its return value; don't drop it silently.
  if (!cv::imwrite(img_name, img)) {
    std::cerr << "ERROR: Failed to write image: " << img_name << "\n";
  }
}
int main(int argc, char* argv[]) {
cmdline::parser cmd;
cmd.add<std::string>(
"checkpoint", 'c', "path of the generated torchscript file", true, "yolov5.torchscript.pt");
cmd.add<std::string>("input_source", 'i', "image source to be detected", true, "bus.jpg");
cmd.add<std::string>("labelmap", 'l', "path of dataset labels", true, "coco.names");
cmd.add("gpu", '\0', "Enable cuda device or cpu");
#ifdef _WIN32
cmd.parse_check(GetCommandLineA());
#else
cmd.parse_check(argc, argv);
#endif
// check if gpu flag is set
bool is_gpu = cmd.exist("gpu");
// set device type - CPU/GPU
torch::DeviceType device_type;
if (torch::cuda::is_available() && is_gpu) {
std::cout << "Set GPU mode" << std::endl;
device_type = torch::kCUDA;
} else {
std::cout << "Set CPU mode" << std::endl;
device_type = torch::kCPU;
}
// load class names from dataset for visualization
std::string labelmap = cmd.get<std::string>("labelmap");
std::vector<std::string> class_names = LoadNames(labelmap);
if (class_names.empty()) {
return -1;
}
// load input image
std::string image_path = cmd.get<std::string>("input_source");
if (std::ifstream(image_path).fail()) {
std::cerr << "ERROR: Failed to access image file path: " << image_path
<< "\n\tDoes the file exist? Permission to read it?\n";
return -1;
}
torch::jit::script::Module module;
try {
std::cout << "Loading model" << std::endl;
// Deserialize the ScriptModule from a file using torch::jit::load().
std::string weights = cmd.get<std::string>("checkpoint");
if (std::ifstream(weights).fail()) {
std::cerr << "ERROR: Failed to access checkpoint file path: " << weights
<< "\n\tDoes the file exist? Permission to read it?\n";
return -1;
}
module = torch::jit::load(weights);
module.to(device_type);
module.eval();
std::cout << "Model loaded" << std::endl;
} catch (const torch::Error& e) {
std::cout << "Error loading the model: " << e.what() << std::endl;
return -1;
} catch (const std::exception& e) {
std::cout << "Other error: " << e.what() << std::endl;
return -1;
}
// TorchScript models require a List[IValue] as input
std::vector<torch::jit::IValue> inputs;
// YOLO accepts a List[Tensor] as main input
std::vector<torch::Tensor> images;
torch::TensorOptions options = torch::TensorOptions{device_type};
// Run once to warm up
std::cout << "Run once on empty image" << std::endl;
auto img_dumy = torch::rand({3, 416, 320}, options);
images.push_back(img_dumy);
inputs.push_back(images);
auto output = module.forward(inputs);
images.clear();
inputs.clear();
/*** Pre-process ***/
auto start = std::chrono::high_resolution_clock::now();
// Read image
auto img = ReadImage(image_path);
img = img.to(device_type);
images.push_back(img);
inputs.push_back(images);
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
std::cout << "Pre-process takes : " << duration.count() << " ms" << std::endl;
// Run once to warm up
output = module.forward(inputs);
/*** Inference ***/
// TODO: add synchronize point
start = std::chrono::high_resolution_clock::now();
output = module.forward(inputs);
end = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
// It should be known that it takes longer time at first time
std::cout << "Inference takes : " << duration.count() << " ms" << std::endl;
auto detections = output.toTuple()->elements().at(1).toList().get(0).toGenericDict();
std::cout << "Detected labels: " << detections.at("labels") << std::endl;
std::cout << "Detected boxes: " << detections.at("boxes") << std::endl;
std::cout << "Detected scores: " << detections.at("scores") << std::endl;
return 0;
}