Skip to content

Commit

Permalink
Optimize conf * clss for YOLOv5, YOLOv6, and YOLOv7
Browse files (browse the repository at this point in the history)
  • Loading branch information
jamjamjon committed Jul 9, 2024
1 parent b5f031a commit fcef4f4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 138 deletions.
10 changes: 9 additions & 1 deletion src/models/yolo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ impl Vision for YOLO {
slice_bboxes,
slice_id,
slice_clss,
slice_confs,
slice_kpts,
slice_coefs,
slice_radians,
Expand All @@ -245,7 +246,14 @@ impl Vision for YOLO {
.into_iter()
.enumerate()
.max_by(|a, b| a.1.total_cmp(b.1))?;
(class_id, confidence)

match &slice_confs {
None => (class_id, confidence),
Some(slice_confs) => {
(class_id, confidence * slice_confs[[i, 0]])
}
}
// (class_id, confidence)
}
};

Expand Down
168 changes: 31 additions & 137 deletions src/models/yolo_.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ndarray::{Array, ArrayBase, ArrayView, Axis, Dim, IxDyn, IxDynImpl, ViewRepr};
use ndarray::{ArrayBase, ArrayView, Axis, Dim, IxDyn, IxDynImpl, ViewRepr};

#[derive(Debug, Clone, clap::ValueEnum)]
pub enum YOLOTask {
Expand Down Expand Up @@ -251,7 +251,8 @@ impl YOLOPreds {
) -> (
ArrayView<f32, IxDyn>,
Option<ArrayView<f32, IxDyn>>,
Array<f32, IxDyn>,
ArrayView<f32, IxDyn>,
Option<ArrayView<f32, IxDyn>>,
Option<ArrayView<f32, IxDyn>>,
Option<ArrayView<f32, IxDyn>>,
Option<ArrayView<f32, IxDyn>>,
Expand All @@ -264,50 +265,59 @@ impl YOLOPreds {

// get each tasks slices
let (slice_bboxes, xs) = preds.split_at(Axis(1), 4);
let (slice_id, slice_clss, xs) = match self.clss {
let (slice_id, slice_clss, slice_confs, xs) = match self.clss {
ClssType::ConfClss => {
let slice_id = None;
let (confs, xs) = xs.split_at(Axis(1), 1);
let (clss, xs) = xs.split_at(Axis(1), nc);
let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); // 267ns

let t = std::time::Instant::now();
// TODO: par
let clss = &confs * &clss;
println!("2 > {:?}", std::time::Instant::now() - t); // 868.281µs

// let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); // 267ns
// let clss = &confs * &clss;
// let slice_clss = clss.to_owned();
let slice_clss = clss;
(slice_id, slice_clss, xs)
let slice_confs = Some(confs);
(slice_id, slice_clss, slice_confs, xs)
}
ClssType::ClssConf => {
let slice_id = None;
let (clss, xs) = xs.split_at(Axis(1), nc);
let (confs, xs) = xs.split_at(Axis(1), 1);
let confs = confs.broadcast((confs.shape()[0], nc)).unwrap();
// let confs = confs.broadcast((confs.shape()[0], nc)).unwrap();
// TODO: par
let clss = &confs * &clss;
// let clss = &confs * &clss;
// let slice_clss = clss;
// let slice_clss = clss.to_owned();
let slice_clss = clss;
(slice_id, slice_clss, xs)
let slice_confs = Some(confs);
// (slice_id, slice_clss, xs)
(slice_id, slice_clss, slice_confs, xs)
}
ClssType::ConfCls => {
let (clss, xs) = xs.split_at(Axis(1), 1);
let (ids, xs) = xs.split_at(Axis(1), 1);
let slice_id = Some(ids);
let slice_clss = clss.to_owned();
(slice_id, slice_clss, xs)
// let slice_clss = clss.to_owned();
let slice_clss = clss;
let slice_confs = None;
(slice_id, slice_clss, slice_confs, xs)
}
ClssType::ClsConf => {
let (ids, xs) = xs.split_at(Axis(1), 1);
let (clss, xs) = xs.split_at(Axis(1), 1);
let slice_id = Some(ids);
let slice_clss = clss.to_owned();
(slice_id, slice_clss, xs)
let slice_clss = clss;
// let slice_clss = clss.to_owned();
// (slice_id, slice_clss, xs)
let slice_confs = None;
(slice_id, slice_clss, slice_confs, xs)
}
ClssType::Clss => {
let slice_id = None;
let (clss, xs) = xs.split_at(Axis(1), nc);
let slice_clss = clss.to_owned();
(slice_id, slice_clss, xs)
// let slice_clss = clss.to_owned();
let slice_clss = clss;
// (slice_id, slice_clss, xs)
let slice_confs = None;
(slice_id, slice_clss, slice_confs, xs)
}
};
let (slice_kpts, slice_coefs, slice_radians) = match self.task() {
Expand All @@ -321,126 +331,10 @@ impl YOLOPreds {
slice_bboxes,
slice_id,
slice_clss,
slice_confs,
slice_kpts,
slice_coefs,
slice_radians,
)
}

// #[allow(clippy::type_complexity)]
// pub fn parse_preds<'a>(
// &'a self,
// preds: ArrayBase<ViewRepr<&'a f32>, Dim<IxDynImpl>>,
// nc: usize,
// ) -> (
// ArrayView<f32, IxDyn>,
// Option<ArrayView<f32, IxDyn>>,
// Array<f32, IxDyn>,
// Option<ArrayView<f32, IxDyn>>,
// Option<ArrayView<f32, IxDyn>>,
// Option<ArrayView<f32, IxDyn>>,
// ) {
// let preds = if self.is_anchors_first() {
// preds
// } else {
// preds.reversed_axes()
// };

// // get each tasks slices
// let (slice_bboxes, xs) = preds.split_at(Axis(1), 4);
// let (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) = if self.is_cls_type() {
// // box-[cls | conf -[kpts | coefs]]
// if self.is_conf_at_end() {
// // box-cls-conf-[kpts | coefs]

// let (ids, xs) = xs.split_at(Axis(1), 1);
// let (clss, xs) = xs.split_at(Axis(1), 1);
// let slice_id = Some(ids);
// let slice_clss = clss.to_owned();

// let (slice_kpts, slice_coefs, slice_radians) = match self.task() {
// YOLOTask::Pose => (Some(xs), None, None),
// YOLOTask::Segment => (None, Some(xs), None),
// YOLOTask::Obb => (None, None, Some(xs)),
// _ => (None, None, None),
// };

// (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians)
// } else {
// // box-conf-cls-[kpts | coefs]

// let (clss, xs) = xs.split_at(Axis(1), 1);
// let (ids, xs) = xs.split_at(Axis(1), 1);
// let slice_id = Some(ids);
// let slice_clss = clss.to_owned();

// let (slice_kpts, slice_coefs, slice_radians) = match self.task() {
// YOLOTask::Pose => (Some(xs), None, None),
// YOLOTask::Segment => (None, Some(xs), None),
// YOLOTask::Obb => (None, None, Some(xs)),
// _ => (None, None, None),
// };
// (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians)
// }
// } else {
// // box-[clss | conf -[kpts | coefs]]
// if self.is_conf_independent() {
// if self.is_conf_at_end() {
// // box-clss-conf-[kpts | coefs]

// let slice_id = None;
// let (clss, xs) = xs.split_at(Axis(1), nc);
// let (confs, xs) = xs.split_at(Axis(1), 1);
// let confs = confs.broadcast((confs.shape()[0], nc)).unwrap();
// let clss = &confs * &clss;
// let slice_clss = clss;

// let (slice_kpts, slice_coefs, slice_radians) = match self.task() {
// YOLOTask::Pose => (Some(xs), None, None),
// YOLOTask::Segment => (None, Some(xs), None),
// YOLOTask::Obb => (None, None, Some(xs)),
// _ => (None, None, None),
// };
// (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians)
// } else {
// // box-conf-clss-[kpts | coefs]
// let slice_id = None;
// let (confs, xs) = xs.split_at(Axis(1), 1);
// let (clss, xs) = xs.split_at(Axis(1), nc);
// let confs = confs.broadcast((confs.shape()[0], nc)).unwrap();
// let clss = &confs * &clss;
// let slice_clss = clss;

// let (slice_kpts, slice_coefs, slice_radians) = match self.task() {
// YOLOTask::Pose => (Some(xs), None, None),
// YOLOTask::Segment => (None, Some(xs), None),
// YOLOTask::Obb => (None, None, Some(xs)),
// _ => (None, None, None),
// };
// (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians)
// }
// } else {
// // box-[clss -[kpts | coefs]]
// let slice_id = None;
// let (clss, xs) = xs.split_at(Axis(1), nc);
// let slice_clss = clss.to_owned();

// let (slice_kpts, slice_coefs, slice_radians) = match self.task() {
// YOLOTask::Pose => (Some(xs), None, None),
// YOLOTask::Segment => (None, Some(xs), None),
// YOLOTask::Obb => (None, None, Some(xs)),
// _ => (None, None, None),
// };
// (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians)
// }
// };
// (
// slice_bboxes,
// slice_id,
// slice_clss,
// slice_kpts,
// slice_coefs,
// slice_radians,
// )
// }
}

0 comments on commit fcef4f4

Please sign in to comment.