Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Tensor from_raw_parts #198

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions crates/kornia-image/src/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,29 @@ impl<T, const C: usize> Image<T, C> {
Image::try_from(tensor)
}

/// Create a new image from raw parts.
///
/// # Arguments
///
/// * `size` - The size of the image in pixels.
/// * `data` - A pointer to the pixel data.
/// * `len` - The length of the pixel data.
///
/// # Returns
///
/// A new image created from the given size and pixel data.
pub unsafe fn from_raw_parts(
size: ImageSize,
data: *const T,
len: usize,
) -> Result<Self, ImageError>
where
T: Clone,
{
let tensor = Tensor::from_raw_parts([size.height, size.width, C], data, len, CpuAllocator)?;
Image::try_from(tensor)
}

/// Cast the pixel data of the image to a different type.
///
/// # Returns
Expand Down
4 changes: 2 additions & 2 deletions crates/kornia-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ log = { workspace = true }
thiserror = { workspace = true }

# optional dependencies
gst = { version = "0.23.0", package = "gstreamer", optional = true }
gst-app = { version = "0.23.0", package = "gstreamer-app", optional = true }
gst = { version = "0.23.4", package = "gstreamer", optional = true }
gst-app = { version = "0.23.4", package = "gstreamer-app", optional = true }
memmap2 = "0.9.4"
turbojpeg = { version = "1.0.0", optional = true }

Expand Down
123 changes: 107 additions & 16 deletions crates/kornia-io/src/stream/capture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,68 @@

use crate::stream::error::StreamCaptureError;
use gst::prelude::*;
use kornia_image::{Image, ImageSize};
use kornia_image::Image;

/// Utility struct to hold the frame buffer data for the last captured frame.
#[derive(Debug, Clone)]
struct FrameBuffer {
data: *const u8,
len: usize,
cols: usize,
rows: usize,
}

unsafe impl Send for FrameBuffer {}
unsafe impl Sync for FrameBuffer {}

struct BufferPool {
buffers: Vec<Option<FrameBuffer>>,
last_active_buffer_index: usize,
}

impl BufferPool {
const BUFFER_POOL_SIZE: usize = 4;

pub fn new() -> Self {
Self {
buffers: vec![None; Self::BUFFER_POOL_SIZE],
last_active_buffer_index: 0,
}
}

pub fn enqueue(&mut self, buffer: FrameBuffer) {
self.buffers[self.last_active_buffer_index] = Some(buffer);
self.last_active_buffer_index =
(self.last_active_buffer_index + 1) % Self::BUFFER_POOL_SIZE;
}

pub fn dequeue(&mut self) -> Option<&FrameBuffer> {
let buffer = &self.buffers[self.last_active_buffer_index];
if buffer.is_none() {
return None;
}

let buffer = buffer.as_ref();
buffer
}

pub fn print_debug(&self) {

Check warning on line 50 in crates/kornia-io/src/stream/capture.rs

View workflow job for this annotation

GitHub Actions / Check

method `print_debug` is never used

Check warning on line 50 in crates/kornia-io/src/stream/capture.rs

View workflow job for this annotation

GitHub Actions / Test Suite - x86_64-unknown-linux-gnu

method `print_debug` is never used

Check warning on line 50 in crates/kornia-io/src/stream/capture.rs

View workflow job for this annotation

GitHub Actions / Test Suite - i686-unknown-linux-gnu

method `print_debug` is never used

Check warning on line 50 in crates/kornia-io/src/stream/capture.rs

View workflow job for this annotation

GitHub Actions / Test Suite - aarch64-unknown-linux-gnu

method `print_debug` is never used
println!(">>> [StreamCapture] active buffers");
for i in 0..self.buffers.len() {
if let Some(buffer) = &self.buffers[i] {
println!("index: {}, {:?}", i, buffer.data);
}
}
println!(">>>");
}
}

/// Represents a stream capture pipeline using GStreamer.
pub struct StreamCapture {
pipeline: gst::Pipeline,
last_frame: Arc<Mutex<Option<Image<u8, 3>>>>,
running: bool,
handle: Option<std::thread::JoinHandle<()>>,
buffer_pool: Arc<Mutex<BufferPool>>,
}

impl StreamCapture {
Expand All @@ -35,16 +89,20 @@
.dynamic_cast::<gst_app::AppSink>()
.map_err(StreamCaptureError::DowncastPipelineError)?;

let last_frame = Arc::new(Mutex::new(None));
// create a buffer pool
let buffer_pool = Arc::new(Mutex::new(BufferPool::new()));

appsink.set_callbacks(
gst_app::AppSinkCallbacks::builder()
.new_sample({
let last_frame = last_frame.clone();
let buffer_pool = buffer_pool.clone();
move |sink| match Self::extract_image_frame(sink) {
Ok(frame) => {
// SAFETY: we have a lock on the last_frame
*last_frame.lock().unwrap() = Some(frame);
Ok(new_frame) => {
// SAFETY: we have a lock on the buffer_pool
if let Ok(mut buffer_pool) = buffer_pool.lock() {
buffer_pool.enqueue(new_frame);
}

Ok(gst::FlowSuccess::Ok)
}
Err(_) => Err(gst::FlowError::Error),
Expand All @@ -55,9 +113,9 @@

Ok(Self {
pipeline,
last_frame,
running: false,
handle: None,
buffer_pool,
})
}

Expand Down Expand Up @@ -102,13 +160,32 @@
/// # Returns
///
/// An Option containing the last captured Image or None if no image has been captured yet.
pub fn grab(&self) -> Result<Option<Image<u8, 3>>, StreamCaptureError> {
pub fn grab(&mut self) -> Result<Option<Image<u8, 3>>, StreamCaptureError> {
if !self.running {
return Err(StreamCaptureError::PipelineNotRunning);
}

// SAFETY: we have a lock on the last_frame
Ok(self.last_frame.lock().unwrap().take())
// SAFETY: we have a lock on the buffer_pool
let mut buffer_pool = self.buffer_pool.lock().unwrap();
let Some(frame_buffer) = buffer_pool.dequeue() else {
return Ok(None);
};

let frame_buffer = std::mem::ManuallyDrop::new(frame_buffer);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gbin having a bit of trouble by trying to prevent gstreamer to deallocate the buffer. Do you have any insights here ?


// SAFETY: this operation is safe because we know the frame_buffer is valid
let image = unsafe {
Image::from_raw_parts(
[frame_buffer.cols, frame_buffer.rows].into(),
frame_buffer.data,
frame_buffer.len,
)
.map_err(|_| StreamCaptureError::CreateImageFrameError)?
};

//buffer_pool.print_debug();

Ok(Some(image))
}

/// Closes the stream capture pipeline.
Expand All @@ -135,9 +212,9 @@
///
/// # Returns
///
/// A Result containing the extracted Image or a StreamCaptureError.
fn extract_image_frame(appsink: &gst_app::AppSink) -> Result<Image<u8, 3>, StreamCaptureError> {
let sample = appsink.pull_sample()?;
/// A Result containing the extracted FrameBuffer or a StreamCaptureError.
fn extract_image_frame(appsink: &gst_app::AppSink) -> Result<FrameBuffer, StreamCaptureError> {
let sample = appsink.pull_sample().map_err(|_| gst::FlowError::Eos)?;

let caps = sample
.caps()
Expand All @@ -160,8 +237,22 @@
.ok_or_else(|| StreamCaptureError::GetBufferError)?
.map_readable()?;

Image::<u8, 3>::new(ImageSize { width, height }, buffer.as_slice().to_vec())
.map_err(|_| StreamCaptureError::CreateImageFrameError)
//println!(
// "[StreamCapture] {:?} successfully mapped buffer",
// buffer.as_ptr()
//);

// SAFETY: we need to forget the buffer because we are not going to drop it
let buffer = std::mem::ManuallyDrop::new(buffer);

let frame_buffer = FrameBuffer {
data: buffer.as_ptr(),
len: buffer.len(),
cols: width,
rows: height,
};

Ok(frame_buffer)
}
}

Expand Down
8 changes: 7 additions & 1 deletion crates/kornia-tensor/src/allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,14 @@
///
/// The pointer must be non-null and the layout must be correct.
#[allow(clippy::not_unsafe_ptr_arg_deref)]
fn dealloc(&self, ptr: *mut u8, layout: Layout) {

Check warning on line 77 in crates/kornia-tensor/src/allocator.rs

View workflow job for this annotation

GitHub Actions / Check

unused variable: `layout`

Check warning on line 77 in crates/kornia-tensor/src/allocator.rs

View workflow job for this annotation

GitHub Actions / Test Suite - x86_64-unknown-linux-gnu

unused variable: `layout`

Check warning on line 77 in crates/kornia-tensor/src/allocator.rs

View workflow job for this annotation

GitHub Actions / Test Suite - i686-unknown-linux-gnu

unused variable: `layout`

Check warning on line 77 in crates/kornia-tensor/src/allocator.rs

View workflow job for this annotation

GitHub Actions / Test Suite - aarch64-unknown-linux-gnu

unused variable: `layout`
unsafe { alloc::dealloc(ptr, layout) }
//println!("[allocator] {:?} attempting to deallocate pointer", ptr);

// Check for null pointer
if !ptr.is_null() {
//unsafe { alloc::dealloc(ptr, layout) }
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gbin by commenting this line seems to do the job but i don't think it's correct. The behavior I'm seeing is that I never enter in this branch statement but gstreamer tries to allocate an invalid pointer. A bit weird -- might be missing something fundamental ?

//println!("[allocator] {:?} successfully deallocated pointer", ptr);
}
}
}

Expand Down
37 changes: 12 additions & 25 deletions crates/kornia-tensor/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ impl<T, A: TensorAllocator> TensorStorage<T, A> {
}
}

/// Creates a new tensor buffer from a raw pointer.
pub unsafe fn from_raw_parts(data: *const T, len: usize, alloc: A) -> Self {
let ptr = NonNull::new_unchecked(data as _);
let layout = Layout::from_size_align_unchecked(len, std::mem::size_of::<T>());
Self {
ptr,
len,
layout,
alloc,
}
}

/// Converts the `TensorStorage` into a `Vec<T>`.
///
/// Returns `Err(self)` if the buffer does not have the same layout as the destination Vec.
Expand All @@ -108,31 +120,6 @@ impl<T, A: TensorAllocator> TensorStorage<T, A> {
}
}

// TODO: pass the allocator to constructor
impl<T, A: TensorAllocator> From<Vec<T>> for TensorStorage<T, A>
where
A: Default,
{
/// Creates a new tensor buffer from a vector.
fn from(value: Vec<T>) -> Self {
// Safety
// Vec::as_ptr guaranteed to not be null
let ptr = unsafe { NonNull::new_unchecked(value.as_ptr() as *mut T) };
let len = value.len() * std::mem::size_of::<T>();
// Safety
// Vec guaranteed to have a valid layout matching that of `Layout::array`
// This is based on `RawVec::current_memory`
let layout = unsafe { Layout::array::<T>(value.capacity()).unwrap_unchecked() };
std::mem::forget(value);

Self {
ptr,
len,
layout,
alloc: A::default(),
}
}
}
// Safety:
// TensorStorage is thread safe if the allocator is thread safe.
unsafe impl<T, A: TensorAllocator> Send for TensorStorage<T, A> {}
Expand Down
38 changes: 38 additions & 0 deletions crates/kornia-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,35 @@ where
if numel != data.len() {
return Err(TensorError::InvalidShape(numel));
}

let storage = TensorStorage::from_vec(data.to_vec(), alloc);

let strides = get_strides_from_shape(shape);
Ok(Self {
storage,
shape,
strides,
})
}

/// Creates a new `Tensor` with the given shape and raw parts.
///
/// # Arguments
///
/// * `shape` - An array containing the shape of the tensor.
/// * `data` - A pointer to the data of the tensor.
/// * `len` - The length of the data.
/// * `alloc` - The allocator to use.
pub unsafe fn from_raw_parts(
shape: [usize; N],
data: *const T,
len: usize,
alloc: A,
) -> Result<Self, TensorError>
where
T: Clone,
{
let storage = TensorStorage::from_raw_parts(data, len, alloc);
let strides = get_strides_from_shape(shape);
Ok(Self {
storage,
Expand Down Expand Up @@ -1374,6 +1402,16 @@ mod tests {
Ok(())
}

#[test]
fn from_raw_parts() -> Result<(), TensorError> {
let data: Vec<u8> = vec![1, 2, 3, 4];
edgarriba marked this conversation as resolved.
Show resolved Hide resolved
let t = unsafe { Tensor::from_raw_parts([2, 2], data.as_ptr(), data.len(), CpuAllocator)? };
std::mem::forget(data);
assert_eq!(t.shape, [2, 2]);
assert_eq!(t.as_slice(), &[1, 2, 3, 4]);
Ok(())
}
edgarriba marked this conversation as resolved.
Show resolved Hide resolved

#[test]
fn display_2d() -> Result<(), TensorError> {
let data: [u8; 4] = [1, 2, 3, 4];
Expand Down
4 changes: 4 additions & 0 deletions examples/webcam/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
continue;
};

//println!("[main] {:?} successfully grabbed frame", img.as_ptr());

// lets resize the image to 256x256
imgproc::resize::resize_fast(
&img,
Expand Down Expand Up @@ -114,6 +116,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
"binary",
&rerun::Image::from_elements(bin.as_slice(), bin.size().into(), rerun::ColorModel::L),
)?;

//println!("[main] {:?} successfully logged frame", img.as_ptr());
}

// NOTE: this is important to close the webcam properly, otherwise the app will hang
Expand Down
Loading