diff --git a/examples/dl/main.rs b/examples/dataloader/main.rs similarity index 70% rename from examples/dl/main.rs rename to examples/dataloader/main.rs index 8d4b666..a0b62db 100644 --- a/examples/dl/main.rs +++ b/examples/dataloader/main.rs @@ -18,18 +18,16 @@ fn main() -> anyhow::Result<()> { // build dataloader let dl = DataLoader::new( - // "rtsp://admin:zfsoft888@192.168.2.217:554/h265/ch1/", - // "rtsp://admin:KCNULU@192.168.2.193:554/h264/ch1/", - // "../hall.mp4", - "./assets/bus.jpg", - // "images/car.jpg", - // "../set-negs", - // "/home/qweasd/Desktop/coco/val2017/images/test", - // "/home/qweasd/Desktop/SourceVideos/3.mp4", - // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", + "./assets/bus.jpg", // local image + // "images/bus.jpg", // remote image + // "../images", // image folder + // "../demo.mp4", // local video + // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video + // "rtsp://admin:xyz@192.168.2.217:554/h265/ch1/", // rtsp h264 stream )? .with_batch(1) .with_progress_bar(true) + .with_bound(100) .build()?; // // build annotator @@ -39,7 +37,7 @@ fn main() -> anyhow::Result<()> { // run for (xs, _) in dl { - // std::thread::sleep(std::time::Duration::from_millis(10000)); + // std::thread::sleep(std::time::Duration::from_millis(1000)); let ys = model.forward(&xs, false)?; annotator.annotate(&xs, &ys); } diff --git a/src/core/dataloader.rs b/src/core/dataloader.rs index 0d4cdca..c4dfc51 100644 --- a/src/core/dataloader.rs +++ b/src/core/dataloader.rs @@ -27,17 +27,21 @@ impl Iterator for DataLoaderIterator { fn next(&mut self) -> Option { match &self.progress_bar { None => self.receiver.recv().ok(), - Some(progress_bar) => { - progress_bar.inc(1); - match self.receiver.recv().ok() { - Some(item) => Some(item), - None => { - progress_bar.set_prefix(" Iterated"); - progress_bar.finish(); - None - } + Some(progress_bar) => match self.receiver.recv().ok() { + Some(item) => { + progress_bar.inc(1); + Some(item) } - } + None => { + progress_bar.set_prefix(" Iterated"); + progress_bar.set_style( + indicatif::ProgressStyle::with_template(crate::PROGRESS_BAR_STYLE_GREEN) + .unwrap(), + ); + progress_bar.finish(); + None + } + }, } } } @@ -52,7 +56,7 @@ impl IntoIterator for DataLoader { self.nf / self.batch_size as u64, " Iterating", Some(&format!("{:?}", self.media_type)), - "{prefix:.green.bold} {human_pos}/{human_len} |{bar}| {elapsed_precise} | {msg} ", + crate::PROGRESS_BAR_STYLE_CYAN, ) .ok() } else { @@ -66,31 +70,34 @@ impl IntoIterator for DataLoader { } } -/// Load images, video, stream +/// A structure designed to load and manage image, video, or stream data. +/// It handles local file paths, remote URLs, and live streams, supporting both batch processing +/// and optional progress bar display. The structure also supports video decoding through +/// `video_rs` for video and stream data. pub struct DataLoader { - pub paths: Option>, - pub media_type: MediaType, - pub batch_size: usize, - sender: Option>, + /// Queue of paths for images. + paths: Option>, + + /// Media type of the source (image, video, stream, etc.). + media_type: MediaType, + + /// Batch size for iteration, determining how many files are processed at once. + batch_size: usize, + + /// Buffer size for the channel, used to manage the buffer between producer and consumer. + bound: usize, + + /// Receiver for processed data. receiver: mpsc::Receiver, - pub decoder: Option, - nf: u64, // MAX means live stream - with_pb: bool, -} -impl Default for DataLoader { - fn default() -> Self { - Self { - paths: None, - media_type: MediaType::Unknown, - batch_size: 1, - sender: None, - receiver: mpsc::channel().1, - decoder: None, - nf: 0, - with_pb: true, - } - } + /// Video decoder for handling video or stream data. + decoder: Option, + + /// Number of images or frames; `u64::MAX` is used for live streams (indicating no limit). + nf: u64, + + /// Flag indicating whether to display a progress bar. + with_pb: bool, } impl DataLoader { @@ -137,10 +144,7 @@ impl DataLoader { anyhow::bail!("Could not locate the source path: {:?}", source_path); } - // mpsc - let (sender, receiver) = mpsc::channel::(); - - // decoder + // video decoder let decoder = match &media_type { MediaType::Video(Location::Local) => Some(Decoder::new(source_path)?), MediaType::Video(Location::Remote) | MediaType::Stream => { @@ -150,9 +154,13 @@ impl DataLoader { _ => None, }; - // get frames + // video & stream frames if let Some(decoder) = &decoder { - nf = decoder.frames().unwrap_or(u64::MAX); + nf = match decoder.frames() { + Err(_) => u64::MAX, + Ok(0) => u64::MAX, + Ok(x) => x, + } } // summary @@ -161,8 +169,8 @@ impl DataLoader { Ok(DataLoader { paths, media_type, - sender: Some(sender), - receiver, + bound: 50, + receiver: mpsc::sync_channel(1).1, batch_size: 1, decoder, nf, @@ -170,11 +178,26 @@ impl DataLoader { }) } + pub fn with_bound(mut self, x: usize) -> Self { + self.bound = x; + self + } + + pub fn with_batch(mut self, x: usize) -> Self { + self.batch_size = x; + self + } + + pub fn with_progress_bar(mut self, x: bool) -> Self { + self.with_pb = x; + self + } + pub fn build(mut self) -> Result { - let sender = self.sender.take().expect("Sender should be available"); + let (sender, receiver) = mpsc::sync_channel::(self.bound); + self.receiver = receiver; let batch_size = self.batch_size; let data = self.paths.take().unwrap_or_default(); - // let media_type = self.media_type.take().unwrap_or(MediaType::Unknown); let media_type = self.media_type.clone(); let decoder = self.decoder.take(); @@ -187,7 +210,7 @@ impl DataLoader { } fn producer_thread( - sender: mpsc::Sender, + sender: mpsc::SyncSender, mut data: VecDeque, batch_size: usize, media_type: MediaType, @@ -208,7 +231,7 @@ impl DataLoader { } Ok(img) => { yis.push(img); - yps.push(path.clone()); + yps.push(path); } } if yis.len() == batch_size @@ -263,16 +286,6 @@ impl DataLoader { } } - pub fn with_batch(mut self, x: usize) -> Self { - self.batch_size = x; - self - } - - pub fn with_progress_bar(mut self, x: bool) -> Self { - self.with_pb = x; - self - } - pub fn load_from_folder>(path: P) -> Result> { let mut paths: Vec = std::fs::read_dir(path)? .filter_map(|entry| entry.ok()) diff --git a/src/core/ort_engine.rs b/src/core/ort_engine.rs index 21ebc01..dd2994c 100644 --- a/src/core/ort_engine.rs +++ b/src/core/ort_engine.rs @@ -287,8 +287,8 @@ impl OrtEngine { let pb = build_progress_bar( self.num_dry_run as u64, " DryRun", - None, - "{prefix:.green.bold} {human_pos}/{human_len} |{bar}| {elapsed_precise} | {msg}", + Some(&format!("{:?}", self.device)), + crate::PROGRESS_BAR_STYLE_CYAN, )?; // dummy inputs @@ -311,7 +311,9 @@ impl OrtEngine { self.ts.clear(); // update - pb.set_message(format!("{:?}", self.device)); + pb.set_style(indicatif::ProgressStyle::with_template( + crate::PROGRESS_BAR_STYLE_GREEN, + )?); pb.finish(); } Ok(()) diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 19479d4..dc4aef1 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -12,7 +12,6 @@ pub use names::*; pub(crate) const CHECK_MARK: &str = "✅"; pub(crate) const CROSS_MARK: &str = "❌"; pub(crate) const SAFE_CROSS_MARK: &str = "❎"; - pub(crate) const NETWORK_PREFIXES: &[&str] = &[ "http://", "https://", "ftp://", "ftps://", "sftp://", "rtsp://", "mms://", "mmsh://", "rtmp://", "rtmps://", "file://", @@ -25,6 +24,10 @@ pub(crate) const AUDIO_EXTENSIONS: &[&str] = &["mp3", "wav", "flac", "aac", "ogg pub(crate) const STREAM_PROTOCOLS: &[&str] = &[ "rtsp://", "rtsps://", "rtspu://", "rtmp://", "rtmps://", "hls://", "http://", "https://", ]; +pub(crate) const PROGRESS_BAR_STYLE_CYAN: &str = + "{prefix:.cyan.bold} {msg} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; +pub(crate) const PROGRESS_BAR_STYLE_GREEN: &str = + "{prefix:.green.bold} {msg} {human_pos}/{human_len} |{bar}| {elapsed_precise}"; pub fn human_bytes(size: f64) -> String { let units = ["B", "KB", "MB", "GB", "TB", "PB", "EB"];