Skip to content

Commit

Permalink
Using sync_channel
Browse files Browse the repository at this point in the history
  • Loading branch information
jamjamjon committed Sep 15, 2024
1 parent 97af971 commit 1c4f428
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 69 deletions.
18 changes: 8 additions & 10 deletions examples/dl/main.rs → examples/dataloader/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@ fn main() -> anyhow::Result<()> {

// build dataloader
let dl = DataLoader::new(
// "rtsp://admin:[email protected]:554/h265/ch1/",
// "rtsp://admin:[email protected]:554/h264/ch1/",
// "../hall.mp4",
"./assets/bus.jpg",
// "images/car.jpg",
// "../set-negs",
// "/home/qweasd/Desktop/coco/val2017/images/test",
// "/home/qweasd/Desktop/SourceVideos/3.mp4",
// "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4",
"./assets/bus.jpg", // local image
// "images/bus.jpg", // remote image
// "../images", // image folder
// "../demo.mp4", // local video
// "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video
// "rtsp://admin:[email protected]:554/h265/ch1/", // rtsp h264 stream
)?
.with_batch(1)
.with_progress_bar(true)
.with_bound(100)
.build()?;

// // build annotator
Expand All @@ -39,7 +37,7 @@ fn main() -> anyhow::Result<()> {

// run
for (xs, _) in dl {
// std::thread::sleep(std::time::Duration::from_millis(10000));
// std::thread::sleep(std::time::Duration::from_millis(1000));
let ys = model.forward(&xs, false)?;
annotator.annotate(&xs, &ys);
}
Expand Down
123 changes: 68 additions & 55 deletions src/core/dataloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,21 @@ impl Iterator for DataLoaderIterator {
fn next(&mut self) -> Option<Self::Item> {
match &self.progress_bar {
None => self.receiver.recv().ok(),
Some(progress_bar) => {
progress_bar.inc(1);
match self.receiver.recv().ok() {
Some(item) => Some(item),
None => {
progress_bar.set_prefix(" Iterated");
progress_bar.finish();
None
}
Some(progress_bar) => match self.receiver.recv().ok() {
Some(item) => {
progress_bar.inc(1);
Some(item)
}
}
None => {
progress_bar.set_prefix(" Iterated");
progress_bar.set_style(
indicatif::ProgressStyle::with_template(crate::PROGRESS_BAR_STYLE_GREEN)
.unwrap(),
);
progress_bar.finish();
None
}
},
}
}
}
Expand All @@ -52,7 +56,7 @@ impl IntoIterator for DataLoader {
self.nf / self.batch_size as u64,
" Iterating",
Some(&format!("{:?}", self.media_type)),
"{prefix:.green.bold} {human_pos}/{human_len} |{bar}| {elapsed_precise} | {msg} ",
crate::PROGRESS_BAR_STYLE_CYAN,
)
.ok()
} else {
Expand All @@ -66,31 +70,34 @@ impl IntoIterator for DataLoader {
}
}

/// Load images, video, stream
/// A structure designed to load and manage image, video, or stream data.
/// It handles local file paths, remote URLs, and live streams, supporting both batch processing
/// and optional progress bar display. The structure also supports video decoding through
/// `video_rs` for video and stream data.
pub struct DataLoader {
pub paths: Option<VecDeque<PathBuf>>,
pub media_type: MediaType,
pub batch_size: usize,
sender: Option<mpsc::Sender<TempReturnType>>,
/// Queue of paths for images.
paths: Option<VecDeque<PathBuf>>,

/// Media type of the source (image, video, stream, etc.).
media_type: MediaType,

/// Batch size for iteration, determining how many files are processed at once.
batch_size: usize,

/// Buffer size for the channel, used to manage the buffer between producer and consumer.
bound: usize,

/// Receiver for processed data.
receiver: mpsc::Receiver<TempReturnType>,
pub decoder: Option<video_rs::decode::Decoder>,
nf: u64, // MAX means live stream
with_pb: bool,
}

impl Default for DataLoader {
fn default() -> Self {
Self {
paths: None,
media_type: MediaType::Unknown,
batch_size: 1,
sender: None,
receiver: mpsc::channel().1,
decoder: None,
nf: 0,
with_pb: true,
}
}
/// Video decoder for handling video or stream data.
decoder: Option<video_rs::decode::Decoder>,

/// Number of images or frames; `u64::MAX` is used for live streams (indicating no limit).
nf: u64,

/// Flag indicating whether to display a progress bar.
with_pb: bool,
}

impl DataLoader {
Expand Down Expand Up @@ -137,10 +144,7 @@ impl DataLoader {
anyhow::bail!("Could not locate the source path: {:?}", source_path);
}

// mpsc
let (sender, receiver) = mpsc::channel::<TempReturnType>();

// decoder
// video decoder
let decoder = match &media_type {
MediaType::Video(Location::Local) => Some(Decoder::new(source_path)?),
MediaType::Video(Location::Remote) | MediaType::Stream => {
Expand All @@ -150,9 +154,13 @@ impl DataLoader {
_ => None,
};

// get frames
// video & stream frames
if let Some(decoder) = &decoder {
nf = decoder.frames().unwrap_or(u64::MAX);
nf = match decoder.frames() {
Err(_) => u64::MAX,
Ok(0) => u64::MAX,
Ok(x) => x,
}
}

// summary
Expand All @@ -161,20 +169,35 @@ impl DataLoader {
Ok(DataLoader {
paths,
media_type,
sender: Some(sender),
receiver,
bound: 50,
receiver: mpsc::sync_channel(1).1,
batch_size: 1,
decoder,
nf,
with_pb: true,
})
}

pub fn with_bound(mut self, x: usize) -> Self {
self.bound = x;
self
}

pub fn with_batch(mut self, x: usize) -> Self {
self.batch_size = x;
self
}

pub fn with_progress_bar(mut self, x: bool) -> Self {
self.with_pb = x;
self
}

pub fn build(mut self) -> Result<Self> {
let sender = self.sender.take().expect("Sender should be available");
let (sender, receiver) = mpsc::sync_channel::<TempReturnType>(self.bound);
self.receiver = receiver;
let batch_size = self.batch_size;
let data = self.paths.take().unwrap_or_default();
// let media_type = self.media_type.take().unwrap_or(MediaType::Unknown);
let media_type = self.media_type.clone();
let decoder = self.decoder.take();

Expand All @@ -187,7 +210,7 @@ impl DataLoader {
}

fn producer_thread(
sender: mpsc::Sender<TempReturnType>,
sender: mpsc::SyncSender<TempReturnType>,
mut data: VecDeque<PathBuf>,
batch_size: usize,
media_type: MediaType,
Expand All @@ -208,7 +231,7 @@ impl DataLoader {
}
Ok(img) => {
yis.push(img);
yps.push(path.clone());
yps.push(path);
}
}
if yis.len() == batch_size
Expand Down Expand Up @@ -263,16 +286,6 @@ impl DataLoader {
}
}

pub fn with_batch(mut self, x: usize) -> Self {
self.batch_size = x;
self
}

pub fn with_progress_bar(mut self, x: bool) -> Self {
self.with_pb = x;
self
}

pub fn load_from_folder<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<std::path::PathBuf>> {
let mut paths: Vec<PathBuf> = std::fs::read_dir(path)?
.filter_map(|entry| entry.ok())
Expand Down
8 changes: 5 additions & 3 deletions src/core/ort_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ impl OrtEngine {
let pb = build_progress_bar(
self.num_dry_run as u64,
" DryRun",
None,
"{prefix:.green.bold} {human_pos}/{human_len} |{bar}| {elapsed_precise} | {msg}",
Some(&format!("{:?}", self.device)),
crate::PROGRESS_BAR_STYLE_CYAN,
)?;

// dummy inputs
Expand All @@ -311,7 +311,9 @@ impl OrtEngine {
self.ts.clear();

// update
pb.set_message(format!("{:?}", self.device));
pb.set_style(indicatif::ProgressStyle::with_template(
crate::PROGRESS_BAR_STYLE_GREEN,
)?);
pb.finish();
}
Ok(())
Expand Down
5 changes: 4 additions & 1 deletion src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ pub use names::*;
pub(crate) const CHECK_MARK: &str = "✅";
pub(crate) const CROSS_MARK: &str = "❌";
pub(crate) const SAFE_CROSS_MARK: &str = "❎";

pub(crate) const NETWORK_PREFIXES: &[&str] = &[
"http://", "https://", "ftp://", "ftps://", "sftp://", "rtsp://", "mms://", "mmsh://",
"rtmp://", "rtmps://", "file://",
Expand All @@ -25,6 +24,10 @@ pub(crate) const AUDIO_EXTENSIONS: &[&str] = &["mp3", "wav", "flac", "aac", "ogg
pub(crate) const STREAM_PROTOCOLS: &[&str] = &[
"rtsp://", "rtsps://", "rtspu://", "rtmp://", "rtmps://", "hls://", "http://", "https://",
];
pub(crate) const PROGRESS_BAR_STYLE_CYAN: &str =
"{prefix:.cyan.bold} {msg} {human_pos}/{human_len} |{bar}| {elapsed_precise}";
pub(crate) const PROGRESS_BAR_STYLE_GREEN: &str =
"{prefix:.green.bold} {msg} {human_pos}/{human_len} |{bar}| {elapsed_precise}";

pub fn human_bytes(size: f64) -> String {
let units = ["B", "KB", "MB", "GB", "TB", "PB", "EB"];
Expand Down

0 comments on commit 1c4f428

Please sign in to comment.