Skip to content

Commit

Permalink
Add support for restricting detection classes (#45)
Browse files Browse the repository at this point in the history
* Add support for restricting detection classes in `Options`
  • Loading branch information
jamjamjon authored Oct 5, 2024
1 parent 0102c15 commit 1d59638
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "usls"
version = "0.0.17"
version = "0.0.18"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
Expand Down
2 changes: 2 additions & 0 deletions examples/yolo/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ fn main() -> Result<()> {
// .with_names(&COCO_CLASS_NAMES_80)
// .with_names2(&COCO_KEYPOINTS_17)
.with_find_contours(!args.no_contours) // find contours or not
.exclude_classes(&[0])
// .retain_classes(&[0, 5])
.with_profile(args.profile);

// build model
Expand Down
16 changes: 16 additions & 0 deletions src/core/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub struct Options {
pub sam_kind: Option<SamKind>,
pub use_low_res_mask: Option<bool>,
pub sapiens_task: Option<SapiensTask>,
pub classes_excluded: Vec<isize>,
pub classes_retained: Vec<isize>,
}

impl Default for Options {
Expand Down Expand Up @@ -88,6 +90,8 @@ impl Default for Options {
use_low_res_mask: None,
sapiens_task: None,
task: Task::Untitled,
classes_excluded: vec![],
classes_retained: vec![],
}
}
}
Expand Down Expand Up @@ -276,4 +280,16 @@ impl Options {
self.iiixs.push(Iiix::from((i, ii, x)));
self
}

pub fn exclude_classes(mut self, xs: &[isize]) -> Self {
self.classes_retained.clear();
self.classes_excluded.extend_from_slice(xs);
self
}

pub fn retain_classes(mut self, xs: &[isize]) -> Self {
self.classes_excluded.clear();
self.classes_retained.extend_from_slice(xs);
self
}
}
22 changes: 21 additions & 1 deletion src/models/yolo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub struct YOLO {
layout: YOLOPreds,
find_contours: bool,
version: Option<YOLOVersion>,
classes_excluded: Vec<isize>,
classes_retained: Vec<isize>,
}

impl Vision for YOLO {
Expand Down Expand Up @@ -157,6 +159,10 @@ impl Vision for YOLO {
let kconfs = DynConf::new(&options.kconfs, nk);
let iou = options.iou.unwrap_or(0.45);

// Classes excluded and retained
let classes_excluded = options.classes_excluded;
let classes_retained = options.classes_retained;

// Summary
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);

Expand All @@ -179,6 +185,8 @@ impl Vision for YOLO {
layout,
version,
find_contours: options.find_contours,
classes_excluded,
classes_retained,
})
}

Expand Down Expand Up @@ -276,7 +284,19 @@ impl Vision for YOLO {
}
};

// filtering
// filtering by class id
if !self.classes_excluded.is_empty()
&& self.classes_excluded.contains(&(class_id as isize))
{
return None;
}
if !self.classes_retained.is_empty()
&& !self.classes_retained.contains(&(class_id as isize))
{
return None;
}

// filtering by conf
if confidence < self.confs[class_id] {
return None;
}
Expand Down

0 comments on commit 1d59638

Please sign in to comment.