diff --git a/clashroyalebuildabot/data/unit.onnx b/clashroyalebuildabot/data/unit.onnx index 93a61d9..981e191 100644 Binary files a/clashroyalebuildabot/data/unit.onnx and b/clashroyalebuildabot/data/unit.onnx differ diff --git a/clashroyalebuildabot/state/onnx_detector.py b/clashroyalebuildabot/state/onnx_detector.py index 05c0a4b..58e34eb 100644 --- a/clashroyalebuildabot/state/onnx_detector.py +++ b/clashroyalebuildabot/state/onnx_detector.py @@ -1,5 +1,5 @@ -import onnxruntime import numpy as np +import onnxruntime class OnnxDetector: @@ -43,22 +43,31 @@ def _nms(boxes, scores, thresh): return keep - def nms(self, prediction, conf_thres=0.35, iou_thres=0.45): + def nms(self, prediction, conf_thres=0.35, iou_thres=0.45, yolov8=False): """ Runs Non-Maximum Suppression (NMS) on inference results """ + if yolov8: + prediction = prediction.transpose((0, 2, 1)) output = [np.zeros((0, 6))] * len(prediction) for i in range(len(prediction)): - # Mask out predictions below the confidence threshold - mask = prediction[i, :, 4] > conf_thres - x = prediction[i][mask] + if yolov8: + x = prediction[i] + else: + # Mask out predictions below the confidence threshold + mask = prediction[i, :, 4] > conf_thres + x = prediction[i][mask] - if not x.shape[0]: - continue + if not x.shape[0]: + continue # Calculate the best scores - # score = object confidence * class confidence - scores = x[:, 4:5] * x[:, 5:] + if yolov8: + # score = class confidence + scores = x[:, 4:] + else: + # score = object confidence * class confidence + scores = x[:, 4:5] * x[:, 5:] best_scores_idx = np.argmax(scores, axis=1).reshape(-1, 1) best_scores = np.take_along_axis(scores, best_scores_idx, axis=1) @@ -76,8 +85,8 @@ def nms(self, prediction, conf_thres=0.35, iou_thres=0.45): # Keep only the best class best = np.hstack([boxes[keep], best_scores[keep], best_scores_idx[keep]]) - output[i] = best + output[i] = best return output def _post_process(self, pred): diff --git a/clashroyalebuildabot/state/unit_detector.py b/clashroyalebuildabot/state/unit_detector.py index 7a74d4e..fa25a35 100644 --- a/clashroyalebuildabot/state/unit_detector.py +++ b/clashroyalebuildabot/state/unit_detector.py @@ -90,7 +90,7 @@ def run(self, image): pred = self.sess.run([self.output_name], {self.input_name: np_image})[0] # Forced post-processing - pred = np.array(self.nms(pred)[0]) + pred = np.array(self.nms(pred, yolov8=True)[0]) pred[:, [0, 2]] *= width / UNIT_SIZE pred[:, [1, 3]] *= height / UNIT_SIZE