diff --git a/src/models/yolo.rs b/src/models/yolo.rs index 40cd4d3..e651483 100644 --- a/src/models/yolo.rs +++ b/src/models/yolo.rs @@ -224,6 +224,7 @@ impl Vision for YOLO { slice_bboxes, slice_id, slice_clss, + slice_confs, slice_kpts, slice_coefs, slice_radians, @@ -245,7 +246,14 @@ impl Vision for YOLO { .into_iter() .enumerate() .max_by(|a, b| a.1.total_cmp(b.1))?; - (class_id, confidence) + + match &slice_confs { + None => (class_id, confidence), + Some(slice_confs) => { + (class_id, confidence * slice_confs[[i, 0]]) + } + } + // (class_id, confidence) } }; diff --git a/src/models/yolo_.rs b/src/models/yolo_.rs index e19af6e..78c1e6e 100644 --- a/src/models/yolo_.rs +++ b/src/models/yolo_.rs @@ -1,4 +1,4 @@ -use ndarray::{Array, ArrayBase, ArrayView, Axis, Dim, IxDyn, IxDynImpl, ViewRepr}; +use ndarray::{ArrayBase, ArrayView, Axis, Dim, IxDyn, IxDynImpl, ViewRepr}; #[derive(Debug, Clone, clap::ValueEnum)] pub enum YOLOTask { @@ -251,7 +251,8 @@ impl YOLOPreds { ) -> ( ArrayView, Option>, - Array, + ArrayView, + Option>, Option>, Option>, Option>, @@ -264,50 +265,59 @@ impl YOLOPreds { // get each tasks slices let (slice_bboxes, xs) = preds.split_at(Axis(1), 4); - let (slice_id, slice_clss, xs) = match self.clss { + let (slice_id, slice_clss, slice_confs, xs) = match self.clss { ClssType::ConfClss => { let slice_id = None; let (confs, xs) = xs.split_at(Axis(1), 1); let (clss, xs) = xs.split_at(Axis(1), nc); - let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); // 267ns - - let t = std::time::Instant::now(); - // TODO: par - let clss = &confs * &clss; - println!("2 > {:?}", std::time::Instant::now() - t); // 868.281µs - + // let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); // 267ns + // let clss = &confs * &clss; + // let slice_clss = clss.to_owned(); let slice_clss = clss; - (slice_id, slice_clss, xs) + let slice_confs = Some(confs); + (slice_id, slice_clss, slice_confs, xs) } ClssType::ClssConf => { let slice_id = None; let (clss, xs) = xs.split_at(Axis(1), nc); let (confs, xs) = xs.split_at(Axis(1), 1); - let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); + // let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); // TODO: par - let clss = &confs * &clss; + // let clss = &confs * &clss; + // let slice_clss = clss; + // let slice_clss = clss.to_owned(); let slice_clss = clss; - (slice_id, slice_clss, xs) + let slice_confs = Some(confs); + // (slice_id, slice_clss, xs) + (slice_id, slice_clss, slice_confs, xs) } ClssType::ConfCls => { let (clss, xs) = xs.split_at(Axis(1), 1); let (ids, xs) = xs.split_at(Axis(1), 1); let slice_id = Some(ids); - let slice_clss = clss.to_owned(); - (slice_id, slice_clss, xs) + // let slice_clss = clss.to_owned(); + let slice_clss = clss; + let slice_confs = None; + (slice_id, slice_clss, slice_confs, xs) } ClssType::ClsConf => { let (ids, xs) = xs.split_at(Axis(1), 1); let (clss, xs) = xs.split_at(Axis(1), 1); let slice_id = Some(ids); - let slice_clss = clss.to_owned(); - (slice_id, slice_clss, xs) + let slice_clss = clss; + // let slice_clss = clss.to_owned(); + // (slice_id, slice_clss, xs) + let slice_confs = None; + (slice_id, slice_clss, slice_confs, xs) } ClssType::Clss => { let slice_id = None; let (clss, xs) = xs.split_at(Axis(1), nc); - let slice_clss = clss.to_owned(); - (slice_id, slice_clss, xs) + // let slice_clss = clss.to_owned(); + let slice_clss = clss; + // (slice_id, slice_clss, xs) + let slice_confs = None; + (slice_id, slice_clss, slice_confs, xs) } }; let (slice_kpts, slice_coefs, slice_radians) = match self.task() { @@ -321,126 +331,10 @@ impl YOLOPreds { slice_bboxes, slice_id, slice_clss, + slice_confs, slice_kpts, slice_coefs, slice_radians, ) } - - // #[allow(clippy::type_complexity)] - // pub fn parse_preds<'a>( - // &'a self, - // preds: ArrayBase, Dim>, - // nc: usize, - // ) -> ( - // ArrayView, - // Option>, - // Array, - // Option>, - // Option>, - // Option>, - // ) { - // let preds = if self.is_anchors_first() { - // preds - // } else { - // preds.reversed_axes() - // }; - - // // get each tasks slices - // let (slice_bboxes, xs) = preds.split_at(Axis(1), 4); - // let (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) = if self.is_cls_type() { - // // box-[cls | conf -[kpts | coefs]] - // if self.is_conf_at_end() { - // // box-cls-conf-[kpts | coefs] - - // let (ids, xs) = xs.split_at(Axis(1), 1); - // let (clss, xs) = xs.split_at(Axis(1), 1); - // let slice_id = Some(ids); - // let slice_clss = clss.to_owned(); - - // let (slice_kpts, slice_coefs, slice_radians) = match self.task() { - // YOLOTask::Pose => (Some(xs), None, None), - // YOLOTask::Segment => (None, Some(xs), None), - // YOLOTask::Obb => (None, None, Some(xs)), - // _ => (None, None, None), - // }; - - // (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) - // } else { - // // box-conf-cls-[kpts | coefs] - - // let (clss, xs) = xs.split_at(Axis(1), 1); - // let (ids, xs) = xs.split_at(Axis(1), 1); - // let slice_id = Some(ids); - // let slice_clss = clss.to_owned(); - - // let (slice_kpts, slice_coefs, slice_radians) = match self.task() { - // YOLOTask::Pose => (Some(xs), None, None), - // YOLOTask::Segment => (None, Some(xs), None), - // YOLOTask::Obb => (None, None, Some(xs)), - // _ => (None, None, None), - // }; - // (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) - // } - // } else { - // // box-[clss | conf -[kpts | coefs]] - // if self.is_conf_independent() { - // if self.is_conf_at_end() { - // // box-clss-conf-[kpts | coefs] - - // let slice_id = None; - // let (clss, xs) = xs.split_at(Axis(1), nc); - // let (confs, xs) = xs.split_at(Axis(1), 1); - // let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); - // let clss = &confs * &clss; - // let slice_clss = clss; - - // let (slice_kpts, slice_coefs, slice_radians) = match self.task() { - // YOLOTask::Pose => (Some(xs), None, None), - // YOLOTask::Segment => (None, Some(xs), None), - // YOLOTask::Obb => (None, None, Some(xs)), - // _ => (None, None, None), - // }; - // (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) - // } else { - // // box-conf-clss-[kpts | coefs] - // let slice_id = None; - // let (confs, xs) = xs.split_at(Axis(1), 1); - // let (clss, xs) = xs.split_at(Axis(1), nc); - // let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); - // let clss = &confs * &clss; - // let slice_clss = clss; - - // let (slice_kpts, slice_coefs, slice_radians) = match self.task() { - // YOLOTask::Pose => (Some(xs), None, None), - // YOLOTask::Segment => (None, Some(xs), None), - // YOLOTask::Obb => (None, None, Some(xs)), - // _ => (None, None, None), - // }; - // (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) - // } - // } else { - // // box-[clss -[kpts | coefs]] - // let slice_id = None; - // let (clss, xs) = xs.split_at(Axis(1), nc); - // let slice_clss = clss.to_owned(); - - // let (slice_kpts, slice_coefs, slice_radians) = match self.task() { - // YOLOTask::Pose => (Some(xs), None, None), - // YOLOTask::Segment => (None, Some(xs), None), - // YOLOTask::Obb => (None, None, Some(xs)), - // _ => (None, None, None), - // }; - // (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) - // } - // }; - // ( - // slice_bboxes, - // slice_id, - // slice_clss, - // slice_kpts, - // slice_coefs, - // slice_radians, - // ) - // } }