Skip to content

Commit

Permalink
Get rid of clone
Browse files Browse the repository at this point in the history
  • Loading branch information
jamjamjon committed Sep 10, 2024
1 parent 62f190c commit b3ec0be
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 158 deletions.
46 changes: 46 additions & 0 deletions examples/dl/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// #![allow(unused)]

use usls::{models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion};

fn main() -> anyhow::Result<()> {
let options = Options::default()
.with_cuda(0)
.with_model("yolo/v8-m-dyn.onnx")?
.with_yolo_version(YOLOVersion::V8)
.with_yolo_task(YOLOTask::Detect)
.with_i00((1, 1, 4).into())
.with_i02((0, 640, 640).into())
.with_i03((0, 640, 640).into())
.with_confs(&[0.2]);
let mut model = YOLO::new(options)?;

// build annotator
let annotator = Annotator::default()
.with_bboxes_thickness(4)
.with_saveout("YOLO-Video-Stream");

// build dataloader
// let image = DataLoader::try_read("images/car.jpg")?;
let dl = DataLoader::new(
// "https://github.com/jamjamjon/assets/releases/download/images/bus.jpg",
// "rtsp://admin:[email protected]:554/h265/ch1/",
// "rtsp://admin:[email protected]:554/h264/ch1/",
// "/home/qweasd/Desktop/coco/val2017/images/test",
// "../hall.mp4",
// "./assets/bus.jpg",
// "image/cat.jpg",
// "../set-negs",
"/home/qweasd/Desktop/SourceVideos/3.mp4", // 400-800 us
// "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // ~2ms
)?
.with_batch(3)
.build()?;

for (xs, _paths) in dl {
println!("xs: {:?} | {:?}", xs.len(), _paths);
let ys = model.forward(&xs, false)?;
annotator.annotate(&xs, &ys);
}

Ok(())
}
50 changes: 0 additions & 50 deletions examples/videos/main.rs

This file was deleted.

93 changes: 58 additions & 35 deletions src/core/dataloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,46 @@ use crate::{Hub, Location, MediaType, CHECK_MARK};

type TempReturnType = (Vec<DynamicImage>, Vec<PathBuf>);

/// Dataloader for loading image, video and stream,
// impl Iterator for DataLoader {
// type Item = TempReturnType;

// fn next(&mut self) -> Option<Self::Item> {
// self.receiver.recv().ok()
// }
// }

// Use IntoIterator trait
impl IntoIterator for DataLoader {
type Item = TempReturnType;
type IntoIter = DataLoaderIterator;

fn into_iter(self) -> Self::IntoIter {
DataLoaderIterator {
receiver: self.receiver,
}
}
}

pub struct DataLoaderIterator {
receiver: mpsc::Receiver<TempReturnType>,
}

impl Iterator for DataLoaderIterator {
type Item = TempReturnType;

fn next(&mut self) -> Option<Self::Item> {
self.receiver.recv().ok()
}
}

pub struct DataLoader {
// source could be:
// - image(local & remote(hub))
// - image(local & hub)
// - images(dir)
// - video(local & remote)
// - stream(local & remote)
pub paths: VecDeque<PathBuf>,
pub media_type: MediaType,
pub paths: Option<VecDeque<PathBuf>>,
pub media_type: Option<MediaType>,
pub batch_size: usize,
sender: Option<mpsc::Sender<TempReturnType>>,
receiver: mpsc::Receiver<TempReturnType>,
Expand All @@ -27,8 +58,8 @@ pub struct DataLoader {
impl Default for DataLoader {
fn default() -> Self {
Self {
paths: VecDeque::new(),
media_type: MediaType::Unknown,
paths: None,
media_type: Some(MediaType::Unknown),
batch_size: 1,
sender: None,
receiver: mpsc::channel().1,
Expand All @@ -37,22 +68,6 @@ impl Default for DataLoader {
}
}

impl Iterator for DataLoader {
type Item = TempReturnType;

fn next(&mut self) -> Option<Self::Item> {
let t0 = std::time::Instant::now();
match self.receiver.recv() {
Ok(batch) => {
let t1 = std::time::Instant::now();
println!("==> {:?}", t1 - t0);
Some(batch)
}
Err(_) => None,
}
}
}

impl DataLoader {
pub fn new(source: &str) -> Result<Self> {
// paths & media_type
Expand All @@ -61,16 +76,16 @@ impl DataLoader {
false => {
// remote
(
VecDeque::from([source_path.to_path_buf()]),
MediaType::from_url(source),
Some(VecDeque::from([source_path.to_path_buf()])),
Some(MediaType::from_url(source)),
)
}
true => {
// local
if source_path.is_file() {
(
VecDeque::from([source_path.to_path_buf()]),
MediaType::from_path(source_path),
Some(VecDeque::from([source_path.to_path_buf()])),
Some(MediaType::from_path(source_path)),
)
} else if source_path.is_dir() {
let mut entries: Vec<PathBuf> = std::fs::read_dir(source_path)?
Expand All @@ -84,10 +99,15 @@ impl DataLoader {
}
})
.collect();

// TODO: natural order
entries.sort_by(|a, b| a.file_name().cmp(&b.file_name()));
(VecDeque::from(entries), MediaType::Image(Location::Local))
(
Some(VecDeque::from(entries)),
Some(MediaType::Image(Location::Local)),
)
} else {
(VecDeque::new(), MediaType::Unknown)
(None, Some(MediaType::Unknown))
}
}
};
Expand All @@ -97,17 +117,20 @@ impl DataLoader {

// decoder
let decoder = match &media_type {
MediaType::Video(Location::Local) => Some(Decoder::new(source_path)?),
MediaType::Video(Location::Remote) | MediaType::Stream => {
Some(MediaType::Video(Location::Local)) => Some(Decoder::new(source_path)?),
Some(MediaType::Video(Location::Remote)) | Some(MediaType::Stream) => {
let location: video_rs::location::Location = source.parse::<Url>()?.into();

Some(Decoder::new(location)?)
}
_ => None,
};

// summary
println!("{CHECK_MARK} Found {:?} x{}", media_type, paths.len());
println!(
"{CHECK_MARK} Found {:?} x{}",
media_type.as_ref().unwrap_or(&MediaType::Unknown),
paths.as_ref().map_or(0, |p| p.len())
);

Ok(DataLoader {
paths,
Expand All @@ -119,14 +142,14 @@ impl DataLoader {
})
}

// Build to initialize the producer thread
pub fn build(mut self) -> Result<Self> {
let sender = self.sender.take().expect("Sender should be available");
let batch_size = self.batch_size;
let data = self.paths.clone();
let media_type = self.media_type.clone();
let data = self.paths.take().unwrap_or_default();
let media_type = self.media_type.take().unwrap_or(MediaType::Unknown);
let decoder = self.decoder.take();

// Spawn the producer thread
std::thread::spawn(move || {
DataLoader::producer_thread(sender, data, batch_size, media_type, decoder);
});
Expand Down
73 changes: 0 additions & 73 deletions src/core/media.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use crate::{AUDIO_EXTENSIONS, IMAGE_EXTENSIONS, STREAM_PROTOCOLS, VIDEO_EXTENSIONS};

// ----------------------------------
// try this ???
// ----------------------------------
#[derive(Debug, Clone)]
pub enum MediaType {
Image(Location),
Expand Down Expand Up @@ -65,73 +62,3 @@ impl MediaType {
}
}
}
// ----------------------------

// #[derive(Debug, Clone)]
// pub enum MediaType {
// Image(ImageType),
// Video(VideoType),
// Audio(AudioType),
// Unknown,
// }

// #[derive(Debug, Clone)]
// pub enum ImageType {
// Local,
// Remote,
// }

// #[derive(Debug, Clone)]
// pub enum VideoType {
// Local,
// Remote,
// Stream,
// }

// #[derive(Debug, Clone)]
// pub enum AudioType {
// Local,
// Remote,
// }

// impl MediaType {
// pub fn from_path<P: AsRef<std::path::Path>>(path: P) -> Self {
// let extension = path
// .as_ref()
// .extension()
// .and_then(|ext| ext.to_str())
// .unwrap_or("")
// .to_lowercase();

// if IMAGE_EXTENSIONS.contains(&extension.as_str()) {
// MediaType::Image(ImageType::Local)
// } else if VIDEO_EXTENSIONS.contains(&extension.as_str()) {
// MediaType::Video(VideoType::Local)
// } else if AUDIO_EXTENSIONS.contains(&extension.as_str()) {
// MediaType::Audio(AudioType::Local)
// } else {
// MediaType::Unknown
// }
// }

// pub fn from_url(url: &str) -> Self {
// if IMAGE_EXTENSIONS
// .iter()
// .any(|&ext| url.ends_with(&format!(".{}", ext)))
// {
// MediaType::Image(ImageType::Remote)
// } else if VIDEO_EXTENSIONS
// .iter()
// .any(|&ext| url.ends_with(&format!(".{}", ext)))
// {
// MediaType::Video(VideoType::Remote)
// } else if STREAM_PROTOCOLS
// .iter()
// .any(|&protocol| url.contains(protocol))
// {
// MediaType::Video(VideoType::Stream)
// } else {
// MediaType::Unknown
// }
// }
// }

0 comments on commit b3ec0be

Please sign in to comment.