Skip to content

Commit

Permalink
Merge pull request #62 from Pbatch/pb_unit_yolov8
Browse files Browse the repository at this point in the history
Yolov8 Unit Model
  • Loading branch information
Pbatch authored Jan 23, 2023
2 parents 19e43ed + 14c09e1 commit 0ae2f4b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
Binary file modified clashroyalebuildabot/data/unit.onnx
Binary file not shown.
29 changes: 19 additions & 10 deletions clashroyalebuildabot/state/onnx_detector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import onnxruntime
import numpy as np
import onnxruntime


class OnnxDetector:
Expand Down Expand Up @@ -43,22 +43,31 @@ def _nms(boxes, scores, thresh):

return keep

def nms(self, prediction, conf_thres=0.35, iou_thres=0.45):
def nms(self, prediction, conf_thres=0.35, iou_thres=0.45, yolov8=False):
"""
Runs Non-Maximum Suppression (NMS) on inference results
"""
if yolov8:
prediction = prediction.transpose((0, 2, 1))
output = [np.zeros((0, 6))] * len(prediction)
for i in range(len(prediction)):
# Mask out predictions below the confidence threshold
mask = prediction[i, :, 4] > conf_thres
x = prediction[i][mask]
if yolov8:
x = prediction[i]
else:
# Mask out predictions below the confidence threshold
mask = prediction[i, :, 4] > conf_thres
x = prediction[i][mask]

if not x.shape[0]:
continue
if not x.shape[0]:
continue

# Calculate the best scores
# score = object confidence * class confidence
scores = x[:, 4:5] * x[:, 5:]
if yolov8:
# score = class confidence
scores = x[:, 4:]
else:
# score = object confidence * class confidence
scores = x[:, 4:5] * x[:, 5:]
best_scores_idx = np.argmax(scores, axis=1).reshape(-1, 1)
best_scores = np.take_along_axis(scores, best_scores_idx, axis=1)

Expand All @@ -76,8 +85,8 @@ def nms(self, prediction, conf_thres=0.35, iou_thres=0.45):

# Keep only the best class
best = np.hstack([boxes[keep], best_scores[keep], best_scores_idx[keep]])
output[i] = best

output[i] = best
return output

def _post_process(self, pred):
Expand Down
2 changes: 1 addition & 1 deletion clashroyalebuildabot/state/unit_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def run(self, image):
pred = self.sess.run([self.output_name], {self.input_name: np_image})[0]

# Forced post-processing
pred = np.array(self.nms(pred)[0])
pred = np.array(self.nms(pred, yolov8=True)[0])
pred[:, [0, 2]] *= width / UNIT_SIZE
pred[:, [1, 3]] *= height / UNIT_SIZE

Expand Down

0 comments on commit 0ae2f4b

Please sign in to comment.