From 62f190cb5bdcb91673d3bc5c137df27d7d0cc98f Mon Sep 17 00:00:00 2001 From: jamjamjon Date: Tue, 10 Sep 2024 23:16:34 +0800 Subject: [PATCH] too slow --- examples/clip/main.rs | 4 +- examples/svtr/main.rs | 4 +- examples/videos/main.rs | 51 +++---- examples/yolo/main.rs | 4 +- src/core/dataloader.rs | 312 ++++++++++++++-------------------------- src/core/hub.rs | 87 ++++++----- 6 files changed, 179 insertions(+), 283 deletions(-) diff --git a/examples/clip/main.rs b/examples/clip/main.rs index f4030bf..b9c1cc9 100644 --- a/examples/clip/main.rs +++ b/examples/clip/main.rs @@ -30,9 +30,7 @@ fn main() -> Result<(), Box> { let feats_text = model.encode_texts(&texts)?; // [n, ndim] // load image - let dl = DataLoader::default() - .with_batch(model.batch_visual()) - .load("./examples/clip/images")?; + let dl = DataLoader::new("./examples/clip/images")?.build()?; // loop for (images, paths) in dl { diff --git a/examples/svtr/main.rs b/examples/svtr/main.rs index 92c07c8..35fa7aa 100644 --- a/examples/svtr/main.rs +++ b/examples/svtr/main.rs @@ -11,9 +11,7 @@ fn main() -> Result<(), Box> { let mut model = SVTR::new(options)?; // load images - let dl = DataLoader::default() - .with_batch(1) - .load("./examples/svtr/images")?; + let dl = DataLoader::new("./examples/svtr/images")?.build()?; // run for (xs, paths) in dl { diff --git a/examples/videos/main.rs b/examples/videos/main.rs index c89a0b1..5ad3403 100644 --- a/examples/videos/main.rs +++ b/examples/videos/main.rs @@ -1,56 +1,47 @@ #![allow(unused)] -use usls::{ - models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_SKELETONS_16, -}; +use usls::{models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion}; fn main() -> anyhow::Result<()> { // let options = Options::default() // .with_cuda(0) - // .with_model("yolo/v8-m-pose-dyn.onnx")? + // .with_model("yolo/v8-n-dyn.onnx")? // .with_yolo_version(YOLOVersion::V8) - // .with_yolo_task(YOLOTask::Pose) + // .with_yolo_task(YOLOTask::Detect) // .with_i00((1, 1, 4).into()) // .with_i02((0, 640, 640).into()) // .with_i03((0, 640, 640).into()) - // .with_confs(&[0.2, 0.15]); + // .with_confs(&[0.2]); // let mut model = YOLO::new(options)?; // // build annotator // let annotator = Annotator::default() - // .with_skeletons(&COCO_SKELETONS_16) // .with_bboxes_thickness(4) // .with_saveout("YOLO-Video-Stream"); // build dataloader - // let dl = DataLoader::new( - // // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", - // // "rtsp://stream.pchome.com.tw/pc/pcmedia/1080p.mp4", - // // "rtsp://185.107.232.253:554/live/stream", - // // "/home/qweasd/Desktop/SourceVideos/3.mp4", - // "./assets/bus.jpg", - // // "/home/qweasd/Desktop/coco/val2017/images/test", - // // "https://github.com/jamjamjon/assets/releases/download/images/bus.jpg", - // )? - // .with_batch(1); - - // // run - // for (xs, _paths) in dl { - // let ys = model.forward(&xs, false)?; - // annotator.annotate(&xs, &ys); - // } - + // let image = DataLoader::try_read("images/car.jpg")?; let mut dl = DataLoader::new( // "https://github.com/jamjamjon/assets/releases/download/images/bus.jpg", - "/home/qweasd/Desktop/SourceVideos/3.mp4", - // "/home/qweasd/Desktop/coco/val2017/images/test", + // "rtsp://admin:zfsoft888@192.168.2.217:554/h265/ch1/", + // "rtsp://admin:KCNULU@192.168.2.193:554/h264/ch1/", + // "/home/qweasd/Desktop/coco/val2017/images/val2017", + // "../hall.mp4", + // "./assets/bus.jpg", + // "image/cat.jpg", + // "../set-negs", + "/home/qweasd/Desktop/SourceVideos/3.mp4", // 400-800 us,, 40ms + // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // ~2ms , 120us )? - .with_batch(4); - dl.commit(); + .with_batch(1) + .build()?; - // println!("Current buffer size: {}", dl.buffer_size()); + // let mut t0 = std::time::Instant::now(); for (xs, _paths) in dl { - println!("xs: {:?} | {:?}", xs.len(), _paths); + // let t1 = std::time::Instant::now(); + // println!("OOOO: {:?}", t1 - t0); + // t0 = t1; + // println!("xs: {:?} | {:?}", xs.len(), _paths); // let ys = model.forward(&xs, false)?; // annotator.annotate(&xs, &ys); } diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index edb6b1e..04ede12 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -184,9 +184,9 @@ fn main() -> Result<()> { let mut model = YOLO::new(options)?; // build dataloader - let dl = DataLoader::default() + let dl = DataLoader::new(&args.source)? .with_batch(model.batch() as _) - .load(args.source)?; + .build()?; // build annotator let annotator = Annotator::default() diff --git a/src/core/dataloader.rs b/src/core/dataloader.rs index d754ebf..bdf0eb6 100644 --- a/src/core/dataloader.rs +++ b/src/core/dataloader.rs @@ -2,30 +2,23 @@ use anyhow::{anyhow, Result}; use image::DynamicImage; use std::collections::VecDeque; use std::path::{Path, PathBuf}; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - mpsc, Arc, -}; +use std::sync::mpsc; use video_rs::{Decoder, Url}; -use walkdir::{DirEntry, WalkDir}; // TODO: remove -use crate::{Dir, Hub, Location, MediaType, CHECK_MARK}; +use crate::{Hub, Location, MediaType, CHECK_MARK}; type TempReturnType = (Vec, Vec); -/// Dataloader for loading images, videos and streams, -// #[derive(Debug, Clone)] +/// Dataloader for loading image, video and stream, pub struct DataLoader { - // source - // - image(local & remote) + // source could be: + // - image(local & remote(hub)) // - images(dir) // - video(local & remote) // - stream(local & remote) pub paths: VecDeque, pub media_type: MediaType, - pub recursive: bool, // TODO: remove pub batch_size: usize, - pub buffer_count: Arc, sender: Option>, receiver: mpsc::Receiver, pub decoder: Option, @@ -36,9 +29,7 @@ impl Default for DataLoader { Self { paths: VecDeque::new(), media_type: MediaType::Unknown, - recursive: false, batch_size: 1, - buffer_count: Arc::new(AtomicUsize::new(0)), sender: None, receiver: mpsc::channel().1, decoder: None, @@ -46,253 +37,139 @@ impl Default for DataLoader { } } +impl Iterator for DataLoader { + type Item = TempReturnType; + + fn next(&mut self) -> Option { + let t0 = std::time::Instant::now(); + match self.receiver.recv() { + Ok(batch) => { + let t1 = std::time::Instant::now(); + println!("==> {:?}", t1 - t0); + Some(batch) + } + Err(_) => None, + } + } +} + impl DataLoader { pub fn new(source: &str) -> Result { - let mut paths = VecDeque::new(); - - // local or remote + // paths & media_type let source_path = Path::new(source); - let media_type = if source_path.exists() { - // local - if source_path.is_file() { - paths.push_back(source_path.to_path_buf()); - MediaType::from_path(source_path) - } else if source_path.is_dir() { - // dir => only can be images - for entry in (source_path.read_dir().map_err(|e| e.to_string()).unwrap()).flatten() - { - if entry.path().is_file() { - paths.push_back(entry.path()); - } + let (paths, media_type) = match source_path.exists() { + false => { + // remote + ( + VecDeque::from([source_path.to_path_buf()]), + MediaType::from_url(source), + ) + } + true => { + // local + if source_path.is_file() { + ( + VecDeque::from([source_path.to_path_buf()]), + MediaType::from_path(source_path), + ) + } else if source_path.is_dir() { + let mut entries: Vec = std::fs::read_dir(source_path)? + .filter_map(|entry| entry.ok()) + .filter_map(|entry| { + let path = entry.path(); + if path.is_file() { + Some(path) + } else { + None + } + }) + .collect(); + entries.sort_by(|a, b| a.file_name().cmp(&b.file_name())); + (VecDeque::from(entries), MediaType::Image(Location::Local)) + } else { + (VecDeque::new(), MediaType::Unknown) } - MediaType::Image(Location::Local) - } else { - MediaType::Unknown } - } else { - // remote - paths.push_back(PathBuf::from(source)); - MediaType::from_url(source) }; + // mpsc let (sender, receiver) = mpsc::channel::(); - let buffer_count = Arc::new(AtomicUsize::new(0)); // decoder let decoder = match &media_type { - MediaType::Video(Location::Local) => { - let location: video_rs::location::Location = paths[0].clone().into(); - Some(Decoder::new(location).unwrap()) - } + MediaType::Video(Location::Local) => Some(Decoder::new(source_path)?), MediaType::Video(Location::Remote) | MediaType::Stream => { - let location: video_rs::location::Location = - paths[0].to_str().unwrap().parse::().unwrap().into(); + let location: video_rs::location::Location = source.parse::()?.into(); - Some(Decoder::new(location).unwrap()) + Some(Decoder::new(location)?) } _ => None, }; - println!("{CHECK_MARK} Media Type: {:?} x{}", media_type, paths.len()); + // summary + println!("{CHECK_MARK} Found {:?} x{}", media_type, paths.len()); + Ok(DataLoader { paths, media_type, - buffer_count, sender: Some(sender), receiver, - recursive: false, batch_size: 1, decoder, }) } - // Initialize the producer thread - pub fn commit(&mut self) { + // Build to initialize the producer thread + pub fn build(mut self) -> Result { let sender = self.sender.take().expect("Sender should be available"); - let buffer_count = Arc::clone(&self.buffer_count); let batch_size = self.batch_size; let data = self.paths.clone(); let media_type = self.media_type.clone(); let decoder = self.decoder.take(); std::thread::spawn(move || { - DataLoader::producer_thread( - sender, - buffer_count, - data, - batch_size, - media_type, - decoder, - ); + DataLoader::producer_thread(sender, data, batch_size, media_type, decoder); }); - } - pub fn buffer_size(&self) -> usize { - self.buffer_count.load(Ordering::SeqCst) - } - - pub fn load>(mut self, source: P) -> Result { - self.paths = match source.as_ref() { - s if s.is_file() => VecDeque::from([s.to_path_buf()]), - s if s.is_dir() => WalkDir::new(s) - .into_iter() - .filter_entry(|e| !Self::_is_hidden(e)) - .filter_map(|entry| match entry { - Err(_) => None, - Ok(entry) => { - if entry.file_type().is_dir() { - return None; - } - if !self.recursive && entry.depth() > 1 { - return None; - } - Some(entry.path().to_path_buf()) - } - }) - .collect::>(), - // s if s.starts_with("rtsp://") || s.starts_with("rtmp://") || s.starts_with("http://")|| s.starts_with("https://") => todo!(), - s if !s.exists() => { - // try download - let p = Hub::new()?.fetch(s.to_str().unwrap())?.commit()?; - let p = PathBuf::from(&p); - VecDeque::from([p.to_path_buf()]) - } - _ => todo!(), - }; - println!("{CHECK_MARK} Found file x{}", self.paths.len()); Ok(self) } - pub fn try_read>(path: P) -> Result { - let mut path = path.as_ref().to_path_buf(); - - // try to download - if !path.exists() { - let p = Hub::new()?.fetch(path.to_str().unwrap())?.commit()?; - path = PathBuf::from(&p); - } - let img = image::ImageReader::open(&path) - .map_err(|err| { - anyhow!( - "Failed to open image at {:?}. Error: {:?}", - path.display(), - err - ) - })? - .with_guessed_format() - .map_err(|err| { - anyhow!( - "Failed to make a format guess based on the content: {:?}. Error: {:?}", - path.display(), - err - ) - })? - .decode() - .map_err(|err| { - anyhow!( - "Failed to decode image at {:?}. Error: {:?}", - path.display(), - err - ) - })? - .into_rgb8(); - Ok(DynamicImage::from(img)) - } - - pub fn with_batch(mut self, x: usize) -> Self { - self.batch_size = x; - self - } - - pub fn with_recursive(mut self, x: bool) -> Self { - self.recursive = x; - self - } - - pub fn paths(&self) -> &VecDeque { - &self.paths - } - - fn _is_hidden(entry: &DirEntry) -> bool { - entry - .file_name() - .to_str() - .map(|s| s.starts_with('.')) - .unwrap_or(false) - } - fn producer_thread( sender: mpsc::Sender, - buffer_count: Arc, mut data: VecDeque, batch_size: usize, media_type: MediaType, mut decoder: Option, ) { - let mut yis: Vec = Vec::new(); - let mut yps: Vec = Vec::new(); + let mut yis: Vec = Vec::with_capacity(batch_size); + let mut yps: Vec = Vec::with_capacity(batch_size); match media_type { - MediaType::Image(Location::Local) => { + MediaType::Image(_) => { while let Some(path) = data.pop_front() { match Self::try_read(&path) { - Err(err) => { - println!("Error reading image from path {:?}: {:?}", path, err); + Err(_) => { continue; } Ok(img) => { yis.push(img); yps.push(path.clone()); - buffer_count.fetch_add(1, Ordering::SeqCst); } } - if yis.len() == batch_size && sender .send((std::mem::take(&mut yis), std::mem::take(&mut yps))) .is_err() { - println!("Receiver dropped, stopping production"); break; } } } - MediaType::Image(Location::Remote) => { - while let Some(path) = data.pop_front() { - let file_name = path.file_name().unwrap(); - let p_tmp = Dir::Cache.path(Some("tmp")).unwrap().join(file_name); - Hub::download(path.to_str().unwrap(), &p_tmp, None, None, None).unwrap(); - match Self::try_read(&p_tmp) { - Err(err) => { - println!( - "Error reading downloaded image from path {:?}: {:?}", - p_tmp, err - ); - continue; - } - Ok(x) => { - yis.push(x); - yps.push(path.clone()); - buffer_count.fetch_add(1, Ordering::SeqCst); - } - } - - if yis.len() == batch_size - && sender - .send((std::mem::take(&mut yis), std::mem::take(&mut yps))) - .is_err() - { - println!("Receiver dropped, stopping production"); - break; - } - } - } - MediaType::Video(_) => { + MediaType::Video(_) | MediaType::Stream => { if let Some(decoder) = decoder.as_mut() { let (w, h) = decoder.size(); let frames = decoder.decode_iter(); - // while let Some(frame) = frames.next() { for frame in frames { match frame { Ok((ts, frame)) => { @@ -308,14 +185,12 @@ impl DataLoader { let img = image::DynamicImage::from(rgb8); yis.push(img); yps.push(ts.to_string().into()); - buffer_count.fetch_add(1, Ordering::SeqCst); if yis.len() == batch_size && sender .send((std::mem::take(&mut yis), std::mem::take(&mut yps))) .is_err() { - println!("Receiver dropped, stopping production"); break; } } @@ -332,22 +207,45 @@ impl DataLoader { println!("Receiver dropped, stopping production"); } } -} -impl Iterator for DataLoader { - type Item = (Vec, Vec); + pub fn with_batch(mut self, x: usize) -> Self { + self.batch_size = x; + self + } - fn next(&mut self) -> Option { - match self.receiver.recv() { - Ok(batch) => { - let t0 = std::time::Instant::now(); - self.buffer_count - .fetch_sub(self.batch_size, Ordering::SeqCst); - let t1 = std::time::Instant::now(); - println!("==> {:?}", t1 - t0); - Some(batch) - } - Err(_) => None, + pub fn try_read>(path: P) -> Result { + let mut path = path.as_ref().to_path_buf(); + + // try to fetch from hub or local cache + if !path.exists() { + let p = Hub::new()?.fetch(path.to_str().unwrap())?.commit()?; + path = PathBuf::from(&p); } + let img = image::ImageReader::open(&path) + .map_err(|err| { + anyhow!( + "Failed to open image at {:?}. Error: {:?}", + path.display(), + err + ) + })? + .with_guessed_format() + .map_err(|err| { + anyhow!( + "Failed to make a format guess based on the content: {:?}. Error: {:?}", + path.display(), + err + ) + })? + .decode() + .map_err(|err| { + anyhow!( + "Failed to decode image at {:?}. Error: {:?}", + path.display(), + err + ) + })? + .into_rgb8(); + Ok(DynamicImage::from(img)) } } diff --git a/src/core/hub.rs b/src/core/hub.rs index 23d1b28..f449fa0 100644 --- a/src/core/hub.rs +++ b/src/core/hub.rs @@ -96,7 +96,7 @@ impl Default for Hub { file_size: None, releases: None, cache: PathBuf::new(), - timeout: 2000, + timeout: 3000, max_attempts: 3, ttl: std::time::Duration::from_secs(10 * 60), } @@ -147,62 +147,71 @@ impl Hub { } pub fn fetch(mut self, s: &str) -> Result { + // try to fetch from hub or local cache + let p = PathBuf::from(s); match p.exists() { true => self.path = p, false => { - match s.split_once('/') { - Some((tag, file_name)) => { - // Extract tag and file from input string - self.tag = Some(tag.to_string()); - self.file_name = Some(file_name.to_string()); - - // Check if releases are already loaded in memory - if self.releases.is_none() { - self.releases = Some(self.connect_remote()?); - } - - if let Some(releases) = &self.releases { - // Validate the tag - let tags: Vec<&str> = - releases.iter().map(|x| x.tag_name.as_str()).collect(); - if !tags.contains(&tag) { - anyhow::bail!( - "Tag '{}' not found in releases. Available tags: {:?}", - tag, - tags - ); + // check local cache 1st + let p_cache = self.cache.with_file_name(s); + if p_cache.exists() { + self.path = p_cache; + } else { + // check remote list then + match s.split_once('/') { + Some((tag, file_name)) => { + // Extract tag and file from input string + self.tag = Some(tag.to_string()); + self.file_name = Some(file_name.to_string()); + + // Check if releases are already loaded in memory + if self.releases.is_none() { + self.releases = Some(self.connect_remote()?); } - // Validate the file - if let Some(release) = releases.iter().find(|r| r.tag_name == tag) { - let files: Vec<&str> = - release.assets.iter().map(|x| x.name.as_str()).collect(); - if !files.contains(&file_name) { + if let Some(releases) = &self.releases { + // Validate the tag + let tags: Vec<&str> = + releases.iter().map(|x| x.tag_name.as_str()).collect(); + if !tags.contains(&tag) { anyhow::bail!( + "Tag '{}' not found in releases. Available tags: {:?}", + tag, + tags + ); + } + + // Validate the file + if let Some(release) = releases.iter().find(|r| r.tag_name == tag) { + let files: Vec<&str> = + release.assets.iter().map(|x| x.name.as_str()).collect(); + if !files.contains(&file_name) { + anyhow::bail!( "File '{}' not found in tag '{}'. Available files: {:?}", file_name, tag, files ); - } else { - for f_ in release.assets.iter() { - if f_.name.as_str() == file_name { - self.url = Some(f_.browser_download_url.clone()); - self.file_size = Some(f_.size); - - break; + } else { + for f_ in release.assets.iter() { + if f_.name.as_str() == file_name { + self.url = Some(f_.browser_download_url.clone()); + self.file_size = Some(f_.size); + + break; + } } } } + self.path = self.to.path(Some(tag))?.join(file_name); } - self.path = self.to.path(Some(tag))?.join(file_name); } - } - _ => anyhow::bail!( + _ => anyhow::bail!( "Download failed due to invalid format. Expected: /, got: {}", s ), + } } } } @@ -336,7 +345,7 @@ impl Hub { .with_context(|| format!("Failed to convert PathBuf: {:?} to String", self.path)) } - /// Download a file from a given URL to a specified path with a progress bar + /// Download a file from a github release to a specified path with a progress bar pub fn download + std::fmt::Debug>( src: &str, dst: P, @@ -344,6 +353,8 @@ impl Hub { timeout: Option, max_attempts: Option, ) -> Result<()> { + // TODO: other url, not just github release page + let max_attempts = max_attempts.unwrap_or(2); let timeout_duration = std::time::Duration::from_secs(timeout.unwrap_or(2000)); let agent = ureq::AgentBuilder::new().try_proxy_from_env(true).build();