From b5f031a4ee93a3a3ad46e44677a0693282ad3327 Mon Sep 17 00:00:00 2001 From: jamjamjon Date: Tue, 9 Jul 2024 22:10:45 +0800 Subject: [PATCH] Update --- examples/yolo/main.rs | 37 ++- src/core/options.rs | 30 +-- src/models/mod.rs | 7 +- src/models/yolo.rs | 123 ++++----- src/models/yolo_.rs | 446 ++++++++++++++++++++++++++++++++ src/models/yolo_format.rs | 532 +++++++++++++++++++++++--------------- src/models/yolop.rs | 6 +- src/ys/mbr.rs | 10 +- src/ys/mod.rs | 1 - 9 files changed, 871 insertions(+), 321 deletions(-) create mode 100644 src/models/yolo_.rs diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index 9be7f52..a11ee3a 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -1,8 +1,7 @@ +use anyhow::Result; use clap::Parser; -use usls::{ - coco, models::YOLO, Annotator, DataLoader, Options, Vision, YOLOFormat, YOLOTask, YOLOVersion, -}; +use usls::{coco, models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion}; #[derive(Parser, Clone)] #[command(author, version, about, long_about = None)] @@ -16,9 +15,8 @@ pub struct Args { #[arg(long, value_enum, default_value_t = YOLOVersion::V8)] pub version: YOLOVersion, - #[arg(long, value_enum, default_value_t = YOLOFormat::NCxcywhClssA)] - pub format: YOLOFormat, - + // #[arg(long, value_enum, default_value_t = YOLOFormat::NCxcywhClssA)] + // pub format: YOLOFormat, #[arg(long, default_value_t = 224)] pub width_min: isize, @@ -37,6 +35,9 @@ pub struct Args { #[arg(long, default_value_t = 800)] pub height_max: isize, + #[arg(long, default_value_t = 80)] + pub nc: usize, + #[arg(long)] pub trt: bool, @@ -59,7 +60,7 @@ pub struct Args { pub plot: bool, } -fn main() -> Result<(), Box> { +fn main() -> Result<()> { let args = Args::parse(); // build options @@ -68,13 +69,21 @@ fn main() -> Result<(), Box> { // version & task let options = match args.version { YOLOVersion::V5 => match args.task { - YOLOTask::Classify => options.with_model("../models/yolov5s-cls.onnx")?, - YOLOTask::Detect => options.with_model("../models/yolov5s.onnx")?, - YOLOTask::Segment => options.with_model("../models/yolov5s.onnx")?, - t => todo!("{t:?} is unsupported for {:?}", args.version), + YOLOTask::Classify => options.with_model("yolov5n-cls-dyn.onnx")?, + YOLOTask::Detect => options.with_model("yolov5n-dyn.onnx")?, + YOLOTask::Segment => options.with_model("yolov5n-seg-dyn.onnx")?, + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), + }, + YOLOVersion::V6 => match args.task { + YOLOTask::Detect => options.with_model("yolov6n-dyn.onnx")?.with_nc(args.nc), + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), + }, + YOLOVersion::V7 => match args.task { + YOLOTask::Detect => options.with_model("yolov7-tiny-dyn.onnx")?.with_nc(args.nc), + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), }, YOLOVersion::V8 => match args.task { - YOLOTask::Classify => options.with_model("yolov8m-cls-dyn-cls.onnx")?, + YOLOTask::Classify => options.with_model("yolov8m-cls-dyn.onnx")?, YOLOTask::Detect => options.with_model("yolov8m-dyn.onnx")?, YOLOTask::Segment => options.with_model("yolov8m-seg-dyn.onnx")?, YOLOTask::Pose => options.with_model("yolov8m-pose-dyn.onnx")?, @@ -82,11 +91,11 @@ fn main() -> Result<(), Box> { }, YOLOVersion::V9 => match args.task { YOLOTask::Detect => options.with_model("yolov9-c-dyn-f16.onnx")?, - t => todo!("{t:?} is unsupported for {:?}", args.version), + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), }, YOLOVersion::V10 => match args.task { YOLOTask::Detect => options.with_model("yolov10n-dyn.onnx")?, - t => todo!("{t:?} is unsupported for {:?}", args.version), + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), }, } .with_yolo_version(args.version) diff --git a/src/core/options.rs b/src/core/options.rs index 0c8035f..cb6e866 100644 --- a/src/core/options.rs +++ b/src/core/options.rs @@ -4,7 +4,7 @@ use anyhow::Result; use crate::{ auto_load, - models::{YOLOFormat, YOLOTask, YOLOVersion}, + models::{YOLOPreds, YOLOTask, YOLOVersion}, Device, MinOptMax, }; @@ -51,8 +51,7 @@ pub struct Options { pub nm: Option, pub confs: Vec, pub kconfs: Vec, - pub iou: f32, - pub apply_nms: bool, + pub iou: Option, pub tokenizer: Option, pub vocab: Option, pub names: Option>, // names @@ -61,10 +60,9 @@ pub struct Options { pub min_width: Option, pub min_height: Option, pub unclip_ratio: f32, // DB - pub apply_probs_softmax: bool, pub yolo_task: Option, pub yolo_version: Option, - pub yolo_format: Option, + pub yolo_preds: Option, } impl Default for Options { @@ -106,8 +104,7 @@ impl Default for Options { nm: None, confs: vec![0.4f32], kconfs: vec![0.5f32], - iou: 0.45f32, - apply_nms: true, + iou: None, tokenizer: None, vocab: None, names: None, @@ -118,8 +115,7 @@ impl Default for Options { unclip_ratio: 1.5, yolo_task: None, yolo_version: None, - apply_probs_softmax: false, - yolo_format: None, + yolo_preds: None, } } } @@ -170,11 +166,6 @@ impl Options { self } - pub fn apply_probs_softmax(mut self, x: bool) -> Self { - self.apply_probs_softmax = x; - self - } - pub fn with_profile(mut self, profile: bool) -> Self { self.profile = profile; self @@ -220,13 +211,8 @@ impl Options { self } - pub fn with_yolo_format(mut self, x: YOLOFormat) -> Self { - self.yolo_format = Some(x); - self - } - - pub fn with_nms(mut self, apply_nms: bool) -> Self { - self.apply_nms = apply_nms; + pub fn with_yolo_preds(mut self, x: YOLOPreds) -> Self { + self.yolo_preds = Some(x); self } @@ -241,7 +227,7 @@ impl Options { } pub fn with_iou(mut self, x: f32) -> Self { - self.iou = x; + self.iou = Some(x); self } diff --git a/src/models/mod.rs b/src/models/mod.rs index 297ba50..df8c9d9 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -10,7 +10,7 @@ mod rtdetr; mod rtmo; mod svtr; mod yolo; -mod yolo_format; +mod yolo_; mod yolop; pub use blip::Blip; @@ -23,5 +23,8 @@ pub use rtdetr::RTDETR; pub use rtmo::RTMO; pub use svtr::SVTR; pub use yolo::YOLO; -pub use yolo_format::{BoxType, YOLOFormat, YOLOTask, YOLOVersion}; +pub use yolo_::*; +// { +// AnchorsPosition, BoxType, ClssType, KptsType, YOLOFormat, YOLOPreds, YOLOTask, YOLOVersion, +// }; pub use yolop::YOLOPv2; diff --git a/src/models/yolo.rs b/src/models/yolo.rs index 8a93518..40cd4d3 100644 --- a/src/models/yolo.rs +++ b/src/models/yolo.rs @@ -6,7 +6,7 @@ use regex::Regex; use crate::{ Bbox, BoxType, DynConf, Keypoint, Mask, Mbr, MinOptMax, Ops, Options, OrtEngine, Polygon, Prob, - Vision, YOLOFormat, YOLOTask, YOLOVersion, X, Y, + Vision, YOLOPreds, YOLOTask, YOLOVersion, X, Y, }; #[derive(Debug)] @@ -25,7 +25,7 @@ pub struct YOLO { apply_nms: bool, apply_probs_softmax: bool, task: YOLOTask, - yolo_format: YOLOFormat, + yolo_preds: YOLOPreds, version: Option, } @@ -55,47 +55,47 @@ impl Vision for YOLO { } })); - // YOLO Output Format - let (version, yolo_format, apply_nms, apply_probs_softmax) = match options.yolo_version { + // YOLO Outputs Format + let (version, yolo_preds) = match options.yolo_version { Some(ver) => match &task { - None => anyhow::bail!("No clear YOLO Task specified."), + None => anyhow::bail!("No clear YOLO Task specified for Version: {ver:?}."), Some(task) => match task { YOLOTask::Classify => match ver { - YOLOVersion::V5 => (Some(ver), YOLOFormat::NClss, None, Some(true)), - YOLOVersion::V8 => (Some(ver), YOLOFormat::NClss, None, Some(false)), - x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_format()` for customization.") + YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)), + YOLOVersion::V8 => (Some(ver), YOLOPreds::n_clss()), + x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") } YOLOTask::Detect => match ver { - v @ YOLOVersion::V5 => (Some(v),YOLOFormat::NACxcywhConfClss, Some(true), None), - YOLOVersion::V8 => (Some(ver),YOLOFormat::NCxcywhClssA, Some(true), None), - YOLOVersion::V9 => (Some(ver),YOLOFormat::NCxcywhClssA, Some(true), None), - YOLOVersion::V10 => (Some(ver),YOLOFormat::NAXyxyConfCls, Some(false), None), + YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver),YOLOPreds::n_a_cxcywh_confclss()), + YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()), + YOLOVersion::V9 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()), + YOLOVersion::V10 => (Some(ver),YOLOPreds::n_a_xyxy_confcls().apply_nms(false)), } YOLOTask::Pose => match ver { - YOLOVersion::V8 => (Some(ver),YOLOFormat::NCxcywhClssXycsA, Some(true), None), - x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_format()` for customization.") + YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_xycs_a()), + x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") } YOLOTask::Segment => match ver { - YOLOVersion::V5 => (Some(ver), YOLOFormat::NACxcywhConfClssCoefs, Some(true), None), - YOLOVersion::V8 => (Some(ver), YOLOFormat::NCxcywhClssCoefsA, Some(true), None), - x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_format()` for customization.") + YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()), + YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()), + x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") } YOLOTask::Obb => match ver { - YOLOVersion::V8 => (Some(ver), YOLOFormat::NCxcywhClssRA, Some(true), None), - x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_format()` for customization.") + YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()), + x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") } } } - None => match options.yolo_format { + None => match options.yolo_preds { None => anyhow::bail!("No clear YOLO version or YOLO Format specified."), - Some(fmt) => (None, fmt, None, None) + Some(fmt) => (None, fmt) } }; - let task = task.unwrap_or(yolo_format.task()); - let apply_nms = apply_nms.unwrap_or(options.apply_nms); - let apply_probs_softmax = apply_probs_softmax.unwrap_or(options.apply_probs_softmax); - // Try from custom class names, and then model metadata + let task = task.unwrap_or(yolo_preds.task()); + let (apply_nms, apply_probs_softmax) = (yolo_preds.apply_nms, yolo_preds.apply_softmax); + + // Class names let mut names = options.names.or(Self::fetch_names(&engine)); let nc = match options.nc { Some(nc) => { @@ -114,13 +114,15 @@ impl Vision for YOLO { None => match &names { Some(names) => names.len(), None => panic!( - "Can not parse model without `nc` and `class names`. Try to make it explicit." + "Can not parse model without `nc` and `class names`. Try to make it explicit with `options.with_nc(80)`" ), }, }; - let names_kpt = options.names2.or(None); - // Try from model metadata + // Keypoints names + let names_kpt = options.names2; + + // The number of keypoints let nk = engine .try_fetch("kpt_shape") .map(|kpt_string| { @@ -131,12 +133,10 @@ impl Vision for YOLO { .unwrap_or(0_usize); let confs = DynConf::new(&options.confs, nc); let kconfs = DynConf::new(&options.kconfs, nk); + let iou = options.iou.unwrap_or(0.45); // Summary - println!( - "Task: {:?}, Version: {:?}, Outputs Format: {:?}, Apply NMS: {:?}", - task, version, yolo_format, apply_nms - ); + println!("YOLO Task: {:?}, Version: {:?}", task, version); engine.dry_run()?; @@ -144,7 +144,7 @@ impl Vision for YOLO { engine, confs, kconfs, - iou: options.iou, + iou, nc, nk, height, @@ -155,7 +155,7 @@ impl Vision for YOLO { names_kpt, apply_nms, apply_probs_softmax, - yolo_format, + yolo_preds, version, }) } @@ -227,7 +227,7 @@ impl Vision for YOLO { slice_kpts, slice_coefs, slice_radians, - ) = self.yolo_format.parse_preds(preds, self.nc); + ) = self.yolo_preds.parse_preds(preds, self.nc); let mut y = Y::default(); let (y_bboxes, y_mbrs) = @@ -255,28 +255,29 @@ impl Vision for YOLO { } // Bboxes - let (cx, cy, x, y, w, h) = match self.yolo_format.box_type() { - BoxType::Cxcywh => { - let cx = bbox[0] / ratio; - let cy = bbox[1] / ratio; - let w = bbox[2] / ratio; - let h = bbox[3] / ratio; - let x = (cx - w / 2.).clamp(0.0, image_width); - let y = (cy - h / 2.).clamp(0.0, image_height); - (cx, cy, x, y, w, h) - } - BoxType::Xyxy => { - let x = bbox[0] / ratio; - let y = bbox[1] / ratio; - let x2 = bbox[2] / ratio; - let y2 = bbox[3] / ratio; - let (w, h) = (x2 - x, y2 - y); - let cx = x + w / 2.; - let cy = y + h / 2.; - (cx, cy, x, y, w, h) - } - _ => todo!(), - }; + let (cx, cy, x, y, w, h) = + match self.yolo_preds.bbox.as_ref()? { + BoxType::Cxcywh => { + let cx = bbox[0] / ratio; + let cy = bbox[1] / ratio; + let w = bbox[2] / ratio; + let h = bbox[3] / ratio; + let x = (cx - w / 2.).clamp(0.0, image_width); + let y = (cy - h / 2.).clamp(0.0, image_height); + (cx, cy, x, y, w, h) + } + BoxType::Xyxy => { + let x = bbox[0] / ratio; + let y = bbox[1] / ratio; + let x2 = bbox[2] / ratio; + let y2 = bbox[3] / ratio; + let (w, h) = (x2 - x, y2 - y); + let cx = x + w / 2.; + let cy = y + h / 2.; + (cx, cy, x, y, w, h) + } + _ => todo!(), + }; let (y_bbox, y_mbr) = match &slice_radians { @@ -347,7 +348,7 @@ impl Vision for YOLO { // Pose if let Some(pred_kpts) = slice_kpts { - let kpt_step = self.yolo_format.kpt_step().unwrap_or(3); + let kpt_step = self.yolo_preds.kpt_step().unwrap_or(3); if let Some(bboxes) = y.bboxes() { let y_kpts = bboxes .into_par_iter() @@ -493,9 +494,9 @@ impl YOLO { &self.task } - pub fn yolo_format(&self) -> &YOLOFormat { - &self.yolo_format - } + // pub fn yolo_preds(&self) -> &YOLOPreds { + // &self.yolo_preds + // } fn fetch_names(engine: &OrtEngine) -> Option> { // fetch class names from onnx metadata diff --git a/src/models/yolo_.rs b/src/models/yolo_.rs new file mode 100644 index 0000000..e19af6e --- /dev/null +++ b/src/models/yolo_.rs @@ -0,0 +1,446 @@ +use ndarray::{Array, ArrayBase, ArrayView, Axis, Dim, IxDyn, IxDynImpl, ViewRepr}; + +#[derive(Debug, Clone, clap::ValueEnum)] +pub enum YOLOTask { + Classify, + Detect, + Pose, + Segment, + Obb, +} + +#[derive(Debug, Copy, Clone, clap::ValueEnum)] +pub enum YOLOVersion { + V5, + V6, + V7, + V8, + V9, + V10, + // YOLOX, + // YOLOv3, + // YOLOv4, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum BoxType { + /// 1 + Cxcywh, + + /// 2 Cxcybr + Cxcyxy, + + /// 3 Tlbr + Xyxy, + + /// 4 Tlwh + Xywh, + + /// 5 Tlcxcy + XyCxcy, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ClssType { + Clss, + ConfCls, + ClsConf, + ConfClss, + ClssConf, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum KptsType { + Xys, + Xycs, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum AnchorsPosition { + Before, + After, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct YOLOPreds { + pub clss: ClssType, + pub bbox: Option, + pub kpts: Option, + pub coefs: Option, + pub obb: Option, + pub anchors: Option, + pub is_bbox_normalized: bool, + pub apply_nms: bool, + pub apply_softmax: bool, +} + +impl Default for YOLOPreds { + fn default() -> Self { + Self { + clss: ClssType::Clss, + bbox: None, + kpts: None, + coefs: None, + obb: None, + anchors: None, + is_bbox_normalized: false, + apply_nms: true, + apply_softmax: false, + } + } +} + +impl YOLOPreds { + pub fn apply_nms(mut self, x: bool) -> Self { + self.apply_nms = x; + self + } + + pub fn apply_softmax(mut self, x: bool) -> Self { + self.apply_softmax = x; + self + } + + pub fn n_clss() -> Self { + // Classification: NClss + Self { + clss: ClssType::Clss, + ..Default::default() + } + } + + pub fn n_a_cxcywh_confclss() -> Self { + // YOLOv5 | YOLOv6 | YOLOv7 | YOLOX : NACxcywhConfClss + Self { + bbox: Some(BoxType::Cxcywh), + clss: ClssType::ConfClss, + anchors: Some(AnchorsPosition::Before), + ..Default::default() + } + } + + pub fn n_a_cxcywh_confclss_coefs() -> Self { + // YOLOv5 Segment : NACxcywhConfClssCoefs + Self { + bbox: Some(BoxType::Cxcywh), + clss: ClssType::ConfClss, + coefs: Some(true), + anchors: Some(AnchorsPosition::Before), + ..Default::default() + } + } + + pub fn n_cxcywh_clss_a() -> Self { + // YOLOv8 | YOLOv9 : NCxcywhClssA + Self { + bbox: Some(BoxType::Cxcywh), + clss: ClssType::Clss, + anchors: Some(AnchorsPosition::After), + ..Default::default() + } + } + + pub fn n_a_xyxy_confcls() -> Self { + // YOLOv10 : NAXyxyConfCls + Self { + bbox: Some(BoxType::Xyxy), + clss: ClssType::ConfCls, + anchors: Some(AnchorsPosition::Before), + ..Default::default() + } + } + + pub fn n_cxcywh_clss_a_n() -> Self { + // RTDETR + Self { + bbox: Some(BoxType::Cxcywh), + clss: ClssType::Clss, + anchors: Some(AnchorsPosition::After), + is_bbox_normalized: true, + ..Default::default() + } + } + + pub fn n_cxcywh_clss_xycs_a() -> Self { + // YOLOv8 Pose : NCxcywhClssXycsA + Self { + bbox: Some(BoxType::Cxcywh), + clss: ClssType::Clss, + kpts: Some(KptsType::Xycs), + anchors: Some(AnchorsPosition::After), + ..Default::default() + } + } + + pub fn n_cxcywh_clss_coefs_a() -> Self { + // YOLOv8 Segment : NCxcywhClssCoefsA + Self { + bbox: Some(BoxType::Cxcywh), + clss: ClssType::Clss, + coefs: Some(true), + anchors: Some(AnchorsPosition::After), + ..Default::default() + } + } + + pub fn n_cxcywh_clss_r_a() -> Self { + // YOLOv8 Obb : NCxcywhClssRA + Self { + bbox: Some(BoxType::Cxcywh), + clss: ClssType::Clss, + obb: Some(true), + anchors: Some(AnchorsPosition::After), + ..Default::default() + } + } + + pub fn task(&self) -> YOLOTask { + match self.obb { + Some(_) => YOLOTask::Obb, + None => match self.coefs { + Some(_) => YOLOTask::Segment, + None => match self.kpts { + Some(_) => YOLOTask::Pose, + None => match self.bbox { + Some(_) => YOLOTask::Detect, + None => YOLOTask::Classify, + }, + }, + }, + } + } + + pub fn is_anchors_first(&self) -> bool { + matches!(self.anchors, Some(AnchorsPosition::Before)) + } + + pub fn is_cls_type(&self) -> bool { + matches!(self.clss, ClssType::ClsConf | ClssType::ConfCls) + } + + pub fn is_clss_type(&self) -> bool { + matches!( + self.clss, + ClssType::ClssConf | ClssType::ConfClss | ClssType::Clss + ) + } + + pub fn is_conf_at_end(&self) -> bool { + matches!(self.clss, ClssType::ClssConf | ClssType::ClsConf) + } + + pub fn is_conf_independent(&self) -> bool { + !matches!(self.clss, ClssType::Clss) + } + + pub fn kpt_step(&self) -> Option { + match &self.kpts { + Some(x) => match x { + KptsType::Xycs => Some(3), + KptsType::Xys => Some(2), + }, + None => None, + } + } + + #[allow(clippy::type_complexity)] + pub fn parse_preds<'a>( + &'a self, + preds: ArrayBase, Dim>, + nc: usize, + ) -> ( + ArrayView, + Option>, + Array, + Option>, + Option>, + Option>, + ) { + let preds = if self.is_anchors_first() { + preds + } else { + preds.reversed_axes() + }; + + // get each tasks slices + let (slice_bboxes, xs) = preds.split_at(Axis(1), 4); + let (slice_id, slice_clss, xs) = match self.clss { + ClssType::ConfClss => { + let slice_id = None; + let (confs, xs) = xs.split_at(Axis(1), 1); + let (clss, xs) = xs.split_at(Axis(1), nc); + let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); // 267ns + + let t = std::time::Instant::now(); + // TODO: par + let clss = &confs * &clss; + println!("2 > {:?}", std::time::Instant::now() - t); // 868.281µs + + let slice_clss = clss; + (slice_id, slice_clss, xs) + } + ClssType::ClssConf => { + let slice_id = None; + let (clss, xs) = xs.split_at(Axis(1), nc); + let (confs, xs) = xs.split_at(Axis(1), 1); + let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); + // TODO: par + let clss = &confs * &clss; + let slice_clss = clss; + (slice_id, slice_clss, xs) + } + ClssType::ConfCls => { + let (clss, xs) = xs.split_at(Axis(1), 1); + let (ids, xs) = xs.split_at(Axis(1), 1); + let slice_id = Some(ids); + let slice_clss = clss.to_owned(); + (slice_id, slice_clss, xs) + } + ClssType::ClsConf => { + let (ids, xs) = xs.split_at(Axis(1), 1); + let (clss, xs) = xs.split_at(Axis(1), 1); + let slice_id = Some(ids); + let slice_clss = clss.to_owned(); + (slice_id, slice_clss, xs) + } + ClssType::Clss => { + let slice_id = None; + let (clss, xs) = xs.split_at(Axis(1), nc); + let slice_clss = clss.to_owned(); + (slice_id, slice_clss, xs) + } + }; + let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + YOLOTask::Pose => (Some(xs), None, None), + YOLOTask::Segment => (None, Some(xs), None), + YOLOTask::Obb => (None, None, Some(xs)), + _ => (None, None, None), + }; + + ( + slice_bboxes, + slice_id, + slice_clss, + slice_kpts, + slice_coefs, + slice_radians, + ) + } + + // #[allow(clippy::type_complexity)] + // pub fn parse_preds<'a>( + // &'a self, + // preds: ArrayBase, Dim>, + // nc: usize, + // ) -> ( + // ArrayView, + // Option>, + // Array, + // Option>, + // Option>, + // Option>, + // ) { + // let preds = if self.is_anchors_first() { + // preds + // } else { + // preds.reversed_axes() + // }; + + // // get each tasks slices + // let (slice_bboxes, xs) = preds.split_at(Axis(1), 4); + // let (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) = if self.is_cls_type() { + // // box-[cls | conf -[kpts | coefs]] + // if self.is_conf_at_end() { + // // box-cls-conf-[kpts | coefs] + + // let (ids, xs) = xs.split_at(Axis(1), 1); + // let (clss, xs) = xs.split_at(Axis(1), 1); + // let slice_id = Some(ids); + // let slice_clss = clss.to_owned(); + + // let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + // YOLOTask::Pose => (Some(xs), None, None), + // YOLOTask::Segment => (None, Some(xs), None), + // YOLOTask::Obb => (None, None, Some(xs)), + // _ => (None, None, None), + // }; + + // (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) + // } else { + // // box-conf-cls-[kpts | coefs] + + // let (clss, xs) = xs.split_at(Axis(1), 1); + // let (ids, xs) = xs.split_at(Axis(1), 1); + // let slice_id = Some(ids); + // let slice_clss = clss.to_owned(); + + // let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + // YOLOTask::Pose => (Some(xs), None, None), + // YOLOTask::Segment => (None, Some(xs), None), + // YOLOTask::Obb => (None, None, Some(xs)), + // _ => (None, None, None), + // }; + // (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) + // } + // } else { + // // box-[clss | conf -[kpts | coefs]] + // if self.is_conf_independent() { + // if self.is_conf_at_end() { + // // box-clss-conf-[kpts | coefs] + + // let slice_id = None; + // let (clss, xs) = xs.split_at(Axis(1), nc); + // let (confs, xs) = xs.split_at(Axis(1), 1); + // let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); + // let clss = &confs * &clss; + // let slice_clss = clss; + + // let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + // YOLOTask::Pose => (Some(xs), None, None), + // YOLOTask::Segment => (None, Some(xs), None), + // YOLOTask::Obb => (None, None, Some(xs)), + // _ => (None, None, None), + // }; + // (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) + // } else { + // // box-conf-clss-[kpts | coefs] + // let slice_id = None; + // let (confs, xs) = xs.split_at(Axis(1), 1); + // let (clss, xs) = xs.split_at(Axis(1), nc); + // let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); + // let clss = &confs * &clss; + // let slice_clss = clss; + + // let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + // YOLOTask::Pose => (Some(xs), None, None), + // YOLOTask::Segment => (None, Some(xs), None), + // YOLOTask::Obb => (None, None, Some(xs)), + // _ => (None, None, None), + // }; + // (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) + // } + // } else { + // // box-[clss -[kpts | coefs]] + // let slice_id = None; + // let (clss, xs) = xs.split_at(Axis(1), nc); + // let slice_clss = clss.to_owned(); + + // let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + // YOLOTask::Pose => (Some(xs), None, None), + // YOLOTask::Segment => (None, Some(xs), None), + // YOLOTask::Obb => (None, None, Some(xs)), + // _ => (None, None, None), + // }; + // (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) + // } + // }; + // ( + // slice_bboxes, + // slice_id, + // slice_clss, + // slice_kpts, + // slice_coefs, + // slice_radians, + // ) + // } +} diff --git a/src/models/yolo_format.rs b/src/models/yolo_format.rs index 3d0a019..e2aab2d 100644 --- a/src/models/yolo_format.rs +++ b/src/models/yolo_format.rs @@ -22,9 +22,15 @@ pub enum YOLOTask { #[derive(Debug, Copy, Clone, ValueEnum)] pub enum YOLOVersion { V5, + V6, + V7, V8, V9, V10, + // TODO: + // YOLOX, + // YOLOv3, + // YOLOv4, } /// Enumeration of various YOLO output formats. @@ -38,87 +44,309 @@ pub enum YOLOFormat { /// Classification NClss, - // Detections - // Batch - Anchors - Bbox - Clss + /// Detections: Batch - Anchors - Bbox - Clss NACxcywhClss, NACxcyxyClss, NAXyxyClss, NAXywhClss, - // Batch - Bbox - Clss - Anchors + /// Detections: Batch - Bbox - Clss - Anchors NCxcywhClssA, NCxcyxyClssA, NXyxyClssA, NXywhClssA, - // Batch - Anchors - Bbox - Conf - Clss + /// Detections: Batch - Anchors - Bbox - Conf - Clss NACxcywhConfClss, NACxcyxyConfClss, NAXyxyConfClss, NAXywhConfClss, - // Batch - Bbox - Conf - Clss - Anchors + /// Detections: Batch - Bbox - Conf - Clss - Anchors NCxcywhConfClssA, NCxcyxyConfClssA, NXyxyConfClssA, NXywhConfClssA, - // Batch - Anchors - Bbox - Conf - Cls + /// Detections: Batch - Anchors - Bbox - Conf - Cls NACxcywhConfCls, NACxcyxyConfCls, NAXyxyConfCls, NAXywhConfCls, - // Batch - Bbox - Conf - Cls - Anchors + /// Detections: Batch - Bbox - Conf - Cls - Anchors NCxcywhConfClsA, NCxcyxyConfClsA, NXyxyConfClsA, NXywhConfClsA, - // anchor first, one top class, Confidence Independent - // Batch - Anchors - Bbox - Cls - Conf + /// Detections: Batch - Anchors - Bbox - Cls - Conf NACxcywhClsConf, NACxcyxyClsConf, NAXyxyClsConf, NAXywhClsConf, - // anchor later, one top class, Confidence Independent - // Batch - Bbox - Cls - Conf - Anchors + /// Detections: Batch - Bbox - Cls - Conf - Anchors NCxcywhClsConfA, NCxcyxyClsConfA, NXyxyClsConfA, NXywhClsConfA, - // Batch - Anchors - Bbox - Clss - Conf + /// Detections: Batch - Anchors - Bbox - Clss - Conf NACxcywhClssConf, NACxcyxyClssConf, NAXyxyClssConf, NAXywhClssConf, - // Batch - Bbox - Clss - Conf - Anchors + /// Detections: Batch - Bbox - Clss - Conf - Anchors NCxcywhClssConfA, NCxcyxyClssConfA, NXyxyClssConfA, NXywhClssConfA, - // ===> TODO: Keypoints: Xycs/Xys must be at the end // xys => xy, xy, ..., No keypoint confidence // xycs => xyc, xyc, ..., Has keypoint confidence + /// Keypoints: Batch - Anchors - Bbox - Clss - Xys NACxcywhClssXys, + NACxcyxyClssXys, + NAXyxyClssXys, + NAXywhClssXys, + + /// Keypoints: Batch - Bbox - Clss - Anchors - Xys + NCxcywhClssXysA, + NCxcyxyClssXysA, + NXyxyClssXysA, + NXywhClssXysA, + + /// Keypoints: Batch - Anchors - Bbox - Conf - Clss - Xys + NACxcywhConfClssXys, + NACxcyxyConfClssXys, + NAXyxyConfClssXys, + NAXywhConfClssXys, + + /// Keypoints: Batch - Bbox - Conf - Clss - Anchors - Xys + NCxcywhConfClssXysA, + NCxcyxyConfClssXysA, + NXyxyConfClssXysA, + NXywhConfClssXysA, + + /// Keypoints: Batch - Anchors - Bbox - Conf - Cls - Xys + NACxcywhConfClsXys, + NACxcyxyConfClsXys, + NAXyxyConfClsXys, + NAXywhConfClsXys, + + /// Keypoints: Batch - Bbox - Conf - Cls - Anchors - Xys + NCxcywhConfClsXysA, + NCxcyxyConfClsXysA, + NXyxyConfClsXysA, + NXywhConfClsXysA, + + /// Keypoints: Batch - Anchors - Bbox - Cls - Conf - Xys + NACxcywhClsConfXys, + NACxcyxyClsConfXys, + NAXyxyClsConfXys, + NAXywhClsConfXys, + + /// Keypoints: Batch - Bbox - Cls - Conf - Anchors - Xys + NCxcywhClsConfXysA, + NCxcyxyClsConfXysA, + NXyxyClsConfXysA, + NXywhClsConfXysA, + + /// Keypoints: Batch - Anchors - Bbox - Clss - Conf - Xys + NACxcywhClssConfXys, + NACxcyxyClssConfXys, + NAXyxyClssConfXys, + NAXywhClssConfXys, + + /// Keypoints: Batch - Bbox - Clss - Conf - Anchors - Xys + NCxcywhClssConfXysA, + NCxcyxyClssConfXysA, + NXyxyClssConfXysA, + NXywhClssConfXysA, + + /// Keypoints: Batch - Anchors - Bbox - Clss - Xycs NACxcywhClssXycs, NACxcyxyClssXycs, NAXyxyClssXycs, - NCxcywhClssXycsA, + NAXywhClssXycs, - // ===> TODO: OBB - NCxcywhClssRA, // R => radians - NACxcywhClssR, // R => radians + /// Keypoints: Batch - Bbox - Clss - Anchors - Xycs + NCxcywhClssXycsA, + NCxcyxyClssXycsA, + NXyxyClssXycsA, + NXywhClssXycsA, + + /// Keypoints: Batch - Anchors - Bbox - Conf - Clss - Xycs + NACxcywhConfClssXycs, + NACxcyxyConfClssXycs, + NAXyxyConfClssXycs, + NAXywhConfClssXycs, + + /// Keypoints: Batch - Bbox - Conf - Clss - Anchors - Xycs + NCxcywhConfClssXycsA, + NCxcyxyConfClssXycsA, + NXyxyConfClssXycsA, + NXywhConfClssXycsA, + + /// Keypoints: Batch - Anchors - Bbox - Conf - Cls - Xycs + NACxcywhConfClsXycs, + NACxcyxyConfClsXycs, + NAXyxyConfClsXycs, + NAXywhConfClsXycs, + + /// Keypoints: Batch - Bbox - Conf - Cls - Anchors - Xycs + NCxcywhConfClsXycsA, + NCxcyxyConfClsXycsA, + NXyxyConfClsXycsA, + NXywhConfClsXycsA, + + /// Keypoints: Batch - Anchors - Bbox - Cls - Conf - Xycs + NACxcywhClsConfXycs, + NACxcyxyClsConfXycs, + NAXyxyClsConfXycs, + NAXywhClsConfXycs, + + // anchor later, one top class, Confidence Independent - Xycs + /// Keypoints: Batch - Bbox - Cls - Conf - Anchors - Xycs + NCxcywhClsConfXycsA, + NCxcyxyClsConfXycsA, + NXyxyClsConfXycsA, + NXywhClsConfXycsA, + + /// Keypoints: Batch - Anchors - Bbox - Clss - Conf - Xycs + NACxcywhClssConfXycs, + NACxcyxyClssConfXycs, + NAXyxyClssConfXycs, + NAXywhClssConfXycs, + + /// Keypoints: Batch - Bbox - Clss - Conf - Anchors - Xycs + NCxcywhClssConfXycsA, + NCxcyxyClssConfXycsA, + NXyxyClssConfXycsA, + NXywhClssConfXycsA, + + // R => radians + /// OBB: Batch - Anchors - Bbox - Clss - R + NACxcywhClssR, + NACxcyxyClssR, + NAXyxyClssR, + NAXywhClssR, + + /// OBB: Batch - Bbox - Clss - Anchors - R + NCxcywhClssRA, + NCxcyxyClssRA, + NXyxyClssRA, + NXywhClssRA, + + /// OBB: Batch - Anchors - Bbox - Conf - Clss - R + NACxcywhConfClssR, + NACxcyxyConfClssR, + NAXyxyConfClssR, + NAXywhConfClssR, + + /// OBB: Batch - Bbox - Conf - Clss - Anchors - R + NCxcywhConfClssRA, + NCxcyxyConfClssRA, + NXyxyConfClssRA, + NXywhConfClssRA, + + /// OBB: Batch - Anchors - Bbox - Conf - Cls - R + NACxcywhConfClsR, + NACxcyxyConfClsR, + NAXyxyConfClsR, + NAXywhConfClsR, + + /// OBB: Batch - Bbox - Conf - Cls - Anchors - R + NCxcywhConfClsRA, + NCxcyxyConfClsRA, + NXyxyConfClsRA, + NXywhConfClsRA, + + /// OBB: Batch - Anchors - Bbox - Cls - Conf - R + NACxcywhClsConfR, + NACxcyxyClsConfR, + NAXyxyClsConfR, + NAXywhClsConfR, + + /// OBB: Batch - Bbox - Cls - Conf - Anchors - R + NCxcywhClsConfRA, + NCxcyxyClsConfRA, + NXyxyClsConfRA, + NXywhClsConfRA, + + /// OBB: Batch - Anchors - Bbox - Clss - Conf - R + NACxcywhClssConfR, + NACxcyxyClssConfR, + NAXyxyClssConfR, + NAXywhClssConfR, + + /// OBB: Batch - Bbox - Clss - Conf - Anchors - R + NCxcywhClssConfRA, + NCxcyxyClssConfRA, + NXyxyClssConfRA, + NXywhClssConfRA, + + /// Instance Segment: Batch - Anchors - Bbox - Clss - Coefs + NACxcywhClssCoefs, + NACxcyxyClssCoefs, + NAXyxyClssCoefs, + NAXywhClssCoefs, - // ===> TODO: instance segment + /// Instance Segment: Batch - Bbox - Clss - Anchors - Coefs NCxcywhClssCoefsA, - NACxcywhClssCoefs, + NCxcyxyClssCoefsA, + NXyxyClssCoefsA, + NXywhClssCoefsA, + + /// Instance Segment: Batch - Anchors - Bbox - Conf - Clss - Coefs NACxcywhConfClssCoefs, + NACxcyxyConfClssCoefs, + NAXyxyConfClssCoefs, + NAXywhConfClssCoefs, + + /// Instance Segment: Batch - Bbox - Conf - Clss - Anchors - Coefs NCxcywhConfClssCoefsA, + NCxcyxyConfClssCoefsA, + NXyxyConfClssCoefsA, + NXywhConfClssCoefsA, + + /// Instance Segment: Batch - Anchors - Bbox - Conf - Cls - Coefs + NACxcywhConfClsCoefs, + NACxcyxyConfClsCoefs, + NAXyxyConfClsCoefs, + NAXywhConfClsCoefs, + + /// Instance Segment: Batch - Bbox - Conf - Cls - Anchors - Coefs + NCxcywhConfClsCoefsA, + NCxcyxyConfClsCoefsA, + NXyxyConfClsCoefsA, + NXywhConfClsCoefsA, + + /// Instance Segment: Batch - Anchors - Bbox - Cls - Conf - Coefs + NACxcywhClsConfCoefs, + NACxcyxyClsConfCoefs, + NAXyxyClsConfCoefs, + NAXywhClsConfCoefs, + + /// Instance Segment: Batch - Bbox - Cls - Conf - Anchors - Coefs + NCxcywhClsConfCoefsA, + NCxcyxyClsConfCoefsA, + NXyxyClsConfCoefsA, + NXywhClsConfCoefsA, + + /// Instance Segment: Batch - Anchors - Bbox - Clss - Conf - Coefs + NACxcywhClssConfCoefs, + NACxcyxyClssConfCoefs, + NAXyxyClssConfCoefs, + NAXywhClssConfCoefs, + + /// Instance Segment: Batch - Bbox - Clss - Conf - Anchors - Coefs + NCxcywhClssConfCoefsA, + NCxcyxyClssConfCoefsA, + NXyxyClssConfCoefsA, + NXywhClssConfCoefsA, } impl fmt::Display for YOLOFormat { @@ -129,6 +357,7 @@ impl fmt::Display for YOLOFormat { impl YOLOFormat { pub fn box_type(&self) -> BoxType { + // TODO: matches! let s = self.to_string(); if s.contains("Cxcywh") { BoxType::Cxcywh @@ -142,140 +371,35 @@ impl YOLOFormat { } pub fn is_anchors_first(&self) -> bool { + // TODO: matches! !self.to_string().ends_with('A') - // match self { - // YOLOFormat::NACxcywhClss - // | YOLOFormat::NACxcyxyClss - // | YOLOFormat::NAXyxyClss - // _ => false, - // } } pub fn is_conf_independent(&self) -> bool { + // TODO: matches! self.to_string().contains("Conf") - // matches!( - // self, - // YOLOFormat::NAXywhConfCls - // | YOLOFormat::NACxcywhClsConf - // | YOLOFormat::NXyxyClsConfA - // | YOLOFormat::NXywhClsConfA - // | YOLOFormat::NACxcywhConfClssCoefs - // | YOLOFormat::NCxcywhConfClssCoefsA - // ) } - pub fn is_conf_last(&self) -> bool { + pub fn is_conf_at_end(&self) -> bool { + // TODO: matches! let s = self.to_string(); let pos_conf = s.find("Conf").unwrap(); let pos_clss = s.find("Cls").unwrap(); pos_conf > pos_clss - // matches!( - // self, - // YOLOFormat::NACxcywhClsConf - // | YOLOFormat::NACxcyxyClsConf - // | YOLOFormat::NAXyxyClsConf - // | YOLOFormat::NAXywhClsConf - // | YOLOFormat::NCxcywhClsConfA - // | YOLOFormat::NCxcyxyClsConfA - // | YOLOFormat::NXyxyClsConfA - // | YOLOFormat::NXywhClsConfA - // ) - } - - pub fn is_cls(&self) -> bool { - matches!( - self, - YOLOFormat::NAXywhConfCls - | YOLOFormat::NACxcywhClsConf - | YOLOFormat::NACxcyxyClsConf - | YOLOFormat::NAXyxyClsConf - | YOLOFormat::NAXywhClsConf - | YOLOFormat::NACxcywhConfCls - | YOLOFormat::NACxcyxyConfCls - | YOLOFormat::NAXyxyConfCls - | YOLOFormat::NCxcywhConfClsA - | YOLOFormat::NCxcyxyConfClsA - | YOLOFormat::NXyxyConfClsA - | YOLOFormat::NXywhConfClsA - | YOLOFormat::NCxcywhClsConfA - | YOLOFormat::NCxcyxyClsConfA - | YOLOFormat::NXyxyClsConfA - | YOLOFormat::NXywhClsConfA - ) } - pub fn is_clss(&self) -> bool { - !self.is_cls() + pub fn is_cls_type(&self) -> bool { + // TODO: matches! + !self.is_clss_type() } - pub fn is_cxcywh(&self) -> bool { - let s = format!("{:?}", self); - s.contains("Cxcywh") - // matches!( - // self, - // YOLOFormat::NACxcywhClsConf - // | YOLOFormat::NACxcywhConfCls - // | YOLOFormat::NACxcywhClss - // | YOLOFormat::NACxcywhConfClss - // | YOLOFormat::NACxcywhClssXycs // kpt - // | YOLOFormat::NCxcywhClssA - // | YOLOFormat::NCxcywhConfClssA - // | YOLOFormat::NCxcywhConfClsA - // | YOLOFormat::NCxcywhClsConfA - // | YOLOFormat::NCxcywhClssXycsA // kpt - // | YOLOFormat::NCxcywhClssCoefsA - // | YOLOFormat::NACxcywhClssCoefs - // | YOLOFormat::NACxcywhConfClssCoefs - // | YOLOFormat::NCxcywhConfClssCoefsA - // | YOLOFormat::NACxcywhClssR - // | YOLOFormat::NCxcywhClssRA - // ) - } - - pub fn is_xywh(&self) -> bool { - matches!( - self, - YOLOFormat::NAXywhConfCls - | YOLOFormat::NAXywhClsConf - | YOLOFormat::NAXywhClss - | YOLOFormat::NAXywhConfClss - | YOLOFormat::NXywhClssA - | YOLOFormat::NXywhConfClssA - | YOLOFormat::NXywhConfClsA - | YOLOFormat::NXywhClsConfA - ) - } - - pub fn is_cxcyxy(&self) -> bool { - matches!( - self, - YOLOFormat::NACxcyxyClsConf - | YOLOFormat::NACxcyxyConfCls - | YOLOFormat::NACxcyxyConfClss - // | YOLOFormat::NACxcyxyClssConf // TODO - | YOLOFormat::NCxcyxyClssA - | YOLOFormat::NCxcyxyConfClssA - | YOLOFormat::NCxcyxyConfClsA - | YOLOFormat::NCxcyxyClsConfA - ) - } - - pub fn is_xyxy(&self) -> bool { - matches!( - self, - YOLOFormat::NAXyxyClsConf - | YOLOFormat::NAXyxyConfCls - | YOLOFormat::NAXyxyClss - | YOLOFormat::NAXyxyConfClss - | YOLOFormat::NXyxyClssA - | YOLOFormat::NXyxyConfClssA - | YOLOFormat::NXyxyConfClsA - | YOLOFormat::NXyxyClsConfA - ) + pub fn is_clss_type(&self) -> bool { + // TODO: matches! + self.to_string().contains("Clss") } pub fn task(&self) -> YOLOTask { - // TODO: + // TODO: matches! match self { YOLOFormat::NACxcywhClssXycs | YOLOFormat::NCxcywhClssXycsA => YOLOTask::Pose, YOLOFormat::NCxcywhClssCoefsA @@ -287,39 +411,40 @@ impl YOLOFormat { _ => YOLOTask::Detect, } } - pub fn is_clssification_task(&self) -> bool { - matches!(self, YOLOFormat::NClss) - } - - pub fn is_obb_task(&self) -> bool { - matches!(self, YOLOFormat::NACxcywhClssR | YOLOFormat::NCxcywhClssRA) - } - - pub fn is_kpt_task(&self) -> bool { - matches!( - self, - YOLOFormat::NACxcywhClssXycs | YOLOFormat::NCxcywhClssXycsA - ) - } - - pub fn is_seg_task(&self) -> bool { - matches!( - self, - YOLOFormat::NCxcywhClssCoefsA - | YOLOFormat::NACxcywhClssCoefs - | YOLOFormat::NACxcywhConfClssCoefs - | YOLOFormat::NCxcywhConfClssCoefsA - ) - } + // pub fn is_clssification_task(&self) -> bool { + // matches!(self, YOLOFormat::NClss) + // } + + // pub fn is_obb_task(&self) -> bool { + // matches!(self, YOLOFormat::NACxcywhClssR | YOLOFormat::NCxcywhClssRA) + // } + + // pub fn is_kpt_task(&self) -> bool { + // matches!( + // self, + // YOLOFormat::NACxcywhClssXycs | YOLOFormat::NCxcywhClssXycsA + // ) + // } + + // pub fn is_seg_task(&self) -> bool { + // matches!( + // self, + // YOLOFormat::NCxcywhClssCoefsA + // | YOLOFormat::NACxcywhClssCoefs + // | YOLOFormat::NACxcywhConfClssCoefs + // | YOLOFormat::NCxcywhConfClssCoefsA + // ) + // } pub fn kpt_step(&self) -> Option { - match self { - YOLOFormat::NACxcywhClssXys => Some(2), - YOLOFormat::NACxcywhClssXycs - | YOLOFormat::NACxcyxyClssXycs - | YOLOFormat::NAXyxyClssXycs - | YOLOFormat::NCxcywhClssXycsA => Some(3), - _ => None, + // TODO: matches! + let s = self.to_string(); + if s.contains("Xys") { + Some(2) + } else if s.contains("Xycs") { + Some(3) + } else { + None } } @@ -344,9 +469,9 @@ impl YOLOFormat { // get each tasks slices let (slice_bboxes, xs) = preds.split_at(Axis(1), 4); - let (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) = if self.is_cls() { + let (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) = if self.is_cls_type() { // box-[cls | conf -[kpts | coefs]] - if self.is_conf_last() { + if self.is_conf_at_end() { // box-cls-conf-[kpts | coefs] let (ids, xs) = xs.split_at(Axis(1), 1); @@ -354,15 +479,13 @@ impl YOLOFormat { let slice_id = Some(ids); let slice_clss = clss.to_owned(); - let (slice_kpts, slice_coefs, slice_radians) = if self.is_kpt_task() { - (Some(xs), None, None) - } else if self.is_seg_task() { - (None, Some(xs), None) - } else if self.is_obb_task() { - (None, None, Some(xs)) - } else { - (None, None, None) + let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + YOLOTask::Pose => (Some(xs), None, None), + YOLOTask::Segment => (None, Some(xs), None), + YOLOTask::Obb => (None, None, Some(xs)), + _ => (None, None, None), }; + (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) } else { // box-conf-cls-[kpts | coefs] @@ -372,21 +495,18 @@ impl YOLOFormat { let slice_id = Some(ids); let slice_clss = clss.to_owned(); - let (slice_kpts, slice_coefs, slice_radians) = if self.is_kpt_task() { - (Some(xs), None, None) - } else if self.is_seg_task() { - (None, Some(xs), None) - } else if self.is_obb_task() { - (None, None, Some(xs)) - } else { - (None, None, None) + let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + YOLOTask::Pose => (Some(xs), None, None), + YOLOTask::Segment => (None, Some(xs), None), + YOLOTask::Obb => (None, None, Some(xs)), + _ => (None, None, None), }; (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) } } else { // box-[clss | conf -[kpts | coefs]] if self.is_conf_independent() { - if self.is_conf_last() { + if self.is_conf_at_end() { // box-clss-conf-[kpts | coefs] let slice_id = None; @@ -396,14 +516,11 @@ impl YOLOFormat { let clss = &confs * &clss; let slice_clss = clss; - let (slice_kpts, slice_coefs, slice_radians) = if self.is_kpt_task() { - (Some(xs), None, None) - } else if self.is_seg_task() { - (None, Some(xs), None) - } else if self.is_obb_task() { - (None, None, Some(xs)) - } else { - (None, None, None) + let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + YOLOTask::Pose => (Some(xs), None, None), + YOLOTask::Segment => (None, Some(xs), None), + YOLOTask::Obb => (None, None, Some(xs)), + _ => (None, None, None), }; (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) } else { @@ -414,14 +531,12 @@ impl YOLOFormat { let confs = confs.broadcast((confs.shape()[0], nc)).unwrap(); let clss = &confs * &clss; let slice_clss = clss; - let (slice_kpts, slice_coefs, slice_radians) = if self.is_kpt_task() { - (Some(xs), None, None) - } else if self.is_seg_task() { - (None, Some(xs), None) - } else if self.is_obb_task() { - (None, None, Some(xs)) - } else { - (None, None, None) + + let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + YOLOTask::Pose => (Some(xs), None, None), + YOLOTask::Segment => (None, Some(xs), None), + YOLOTask::Obb => (None, None, Some(xs)), + _ => (None, None, None), }; (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) } @@ -431,14 +546,11 @@ impl YOLOFormat { let (clss, xs) = xs.split_at(Axis(1), nc); let slice_clss = clss.to_owned(); - let (slice_kpts, slice_coefs, slice_radians) = if self.is_kpt_task() { - (Some(xs), None, None) - } else if self.is_seg_task() { - (None, Some(xs), None) - } else if self.is_obb_task() { - (None, None, Some(xs)) - } else { - (None, None, None) + let (slice_kpts, slice_coefs, slice_radians) = match self.task() { + YOLOTask::Pose => (Some(xs), None, None), + YOLOTask::Segment => (None, Some(xs), None), + YOLOTask::Obb => (None, None, Some(xs)), + _ => (None, None, None), }; (slice_id, slice_clss, slice_kpts, slice_coefs, slice_radians) } diff --git a/src/models/yolop.rs b/src/models/yolop.rs index 9e28319..8a1de77 100644 --- a/src/models/yolop.rs +++ b/src/models/yolop.rs @@ -22,8 +22,8 @@ impl YOLOPv2 { engine.height().to_owned(), engine.width().to_owned(), ); - let nc = 80; - let confs = DynConf::new(&options.kconfs, nc); + let confs = DynConf::new(&options.kconfs, 80); + let iou = options.iou.unwrap_or(0.45f32); engine.dry_run()?; Ok(Self { @@ -32,7 +32,7 @@ impl YOLOPv2 { height, width, batch, - iou: options.iou, + iou, }) } diff --git a/src/ys/mbr.rs b/src/ys/mbr.rs index 12f5bf8..552fb6e 100644 --- a/src/ys/mbr.rs +++ b/src/ys/mbr.rs @@ -10,6 +10,7 @@ pub struct Mbr { confidence: f32, name: Option, } + impl Nms for Mbr { /// Returns the confidence score of the bounding box. fn confidence(&self) -> f32 { @@ -113,10 +114,6 @@ impl Mbr { self.name.as_ref() } - // pub fn confidence(&self) -> f32 { - // self.confidence - // } - pub fn label(&self, with_name: bool, with_conf: bool, decimal_places: usize) -> String { let mut label = String::new(); if with_name { @@ -208,15 +205,12 @@ impl Mbr { let p2 = Polygon::new(other.ls.clone(), vec![]); p1.union(&p2).unsigned_area() as f32 } - - // pub fn iou(&self, other: &Mbr) -> f32 { - // self.intersect(other) / self.union(other) - // } } #[cfg(test)] mod tests_mbr { use super::Mbr; + use crate::Nms; use geo::{coord, line_string}; #[test] diff --git a/src/ys/mod.rs b/src/ys/mod.rs index 10c4fd4..f07d38a 100644 --- a/src/ys/mod.rs +++ b/src/ys/mod.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] mod bbox; mod embedding; mod keypoint;