From 5a751f9126887317e8957d1a3221d50ac70f6108 Mon Sep 17 00:00:00 2001 From: Shachar Langbeheim Date: Tue, 31 Dec 2024 19:51:43 +0200 Subject: [PATCH] Safer wrapping for FFI clients. Use `Arc` instead of `Box` to wrap the `ClientAdapter`, so that the lifetime of the adapter will be extended even in cases where the client is closed during commands. This also saves on using `as usize` conversions to the pointer, which cause provenance to be lost and might lead to undefined behavior. Signed-off-by: Shachar Langbeheim --- csharp/lib/src/lib.rs | 41 +++++++++++++++++++++++++++-------------- go/src/lib.rs | 39 ++++++++++++++++++++++----------------- 2 files changed, 49 insertions(+), 31 deletions(-) diff --git a/csharp/lib/src/lib.rs b/csharp/lib/src/lib.rs index c497410e31..66b9bf6056 100644 --- a/csharp/lib/src/lib.rs +++ b/csharp/lib/src/lib.rs @@ -7,6 +7,7 @@ use redis::{FromRedisValue, RedisResult}; use std::{ ffi::{c_void, CStr, CString}, os::raw::c_char, + sync::Arc, }; use tokio::runtime::Builder; use tokio::runtime::Runtime; @@ -68,7 +69,8 @@ fn create_client_internal( }) } -/// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns. All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool. +/// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns. +/// All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool. #[no_mangle] pub extern "C" fn create_client( host: *const c_char, @@ -79,38 +81,48 @@ pub extern "C" fn create_client( ) -> *const c_void { match create_client_internal(host, port, use_tls, success_callback, failure_callback) { Err(_) => std::ptr::null(), // TODO - log errors - Ok(client) => Box::into_raw(Box::new(client)) as *const c_void, + Ok(client) => Arc::into_raw(Arc::new(client)) as *const c_void, } } +/// # Safety +/// +/// This function should only be called once per pointer created by [create_client]. After calling this function +/// the `client_ptr` is not in a valid state. #[no_mangle] pub extern "C" fn close_client(client_ptr: *const c_void) { - let client_ptr = unsafe { Box::from_raw(client_ptr as *mut Client) }; - let _runtime_handle = client_ptr.runtime.enter(); - drop(client_ptr); + let count = Arc::strong_count(&unsafe { Arc::from_raw(client_ptr as *mut Client) }); + assert!(count == 1, "Client is still in use."); } /// Expects that key and value will be kept valid until the callback is called. +/// +/// # Safety +/// +/// This function should only be called should with a pointer created by [create_client], before [close_client] was called with the pointer. #[no_mangle] -pub extern "C" fn command( +pub unsafe extern "C" fn command( client_ptr: *const c_void, callback_index: usize, request_type: RequestType, args: *const *mut c_char, arg_count: u32, ) { - let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) }; + let client = unsafe { + // we increment the strong count to ensure that the client is not dropped just because we turned it into an Arc. + Arc::increment_strong_count(client_ptr); + Arc::from_raw(client_ptr as *mut Client) + }; + let core_client_clone = client.clone(); // The safety of these needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed. - let ptr_address = client_ptr as usize; let args_address = args as usize; let mut client_clone = client.client.clone(); client.runtime.spawn(async move { let Some(mut cmd) = request_type.get_command() else { unsafe { - let client = Box::leak(Box::from_raw(ptr_address as *mut Client)); - (client.failure_callback)(callback_index); // TODO - report errors + (core_client_clone.failure_callback)(callback_index); // TODO - report errors return; } }; @@ -128,11 +140,12 @@ pub extern "C" fn command( .await .and_then(Option::::from_owned_redis_value); unsafe { - let client = Box::leak(Box::from_raw(ptr_address as *mut Client)); match result { - Ok(None) => (client.success_callback)(callback_index, std::ptr::null()), - Ok(Some(c_str)) => (client.success_callback)(callback_index, c_str.as_ptr()), - Err(_) => (client.failure_callback)(callback_index), // TODO - report errors + Ok(None) => (core_client_clone.success_callback)(callback_index, std::ptr::null()), + Ok(Some(c_str)) => { + (core_client_clone.success_callback)(callback_index, c_str.as_ptr()) + } + Err(_) => (core_client_clone.failure_callback)(callback_index), // TODO - report errors }; } }); diff --git a/go/src/lib.rs b/go/src/lib.rs index 376da58dfa..6415e9d1ba 100644 --- a/go/src/lib.rs +++ b/go/src/lib.rs @@ -16,6 +16,7 @@ use redis::cluster_routing::{ use redis::cluster_routing::{ResponsePolicy, Routable}; use redis::{Cmd, RedisResult, Value}; use std::slice::from_raw_parts; +use std::sync::Arc; use std::{ ffi::{c_void, CString}, mem, @@ -180,8 +181,8 @@ fn create_client_internal( /// /// * `connection_request_bytes` must point to `connection_request_len` consecutive properly initialized bytes. It must be a well-formed Protobuf `ConnectionRequest` object. The array must be allocated by the caller and subsequently freed by the caller after this function returns. /// * `connection_request_len` must not be greater than the length of the connection request bytes array. It must also not be greater than the max value of a signed pointer-sized integer. -/// * The `conn_ptr` pointer in the returned `ConnectionResponse` must live while the client is open/active and must be explicitly freed by calling [`close_client`]. -/// * The `connection_error_message` pointer in the returned `ConnectionResponse` must live until the returned `ConnectionResponse` pointer is passed to [`free_connection_response`]. +/// * The `conn_ptr` pointer in the returned `ConnectionResponse` must live while the client is open/active and must be explicitly freed by calling [close_client]. +/// * The `connection_error_message` pointer in the returned `ConnectionResponse` must live until the returned `ConnectionResponse` pointer is passed to [free_connection_response]. /// * Both the `success_callback` and `failure_callback` function pointers need to live while the client is open/active. The caller is responsible for freeing both callbacks. // TODO: Consider making this async #[no_mangle] @@ -201,7 +202,7 @@ pub unsafe extern "C" fn create_client( ), }, Ok(client) => ConnectionResponse { - conn_ptr: Box::into_raw(Box::new(client)) as *const c_void, + conn_ptr: Arc::into_raw(Arc::new(client)) as *const c_void, connection_error_message: std::ptr::null(), }, }; @@ -220,13 +221,15 @@ pub unsafe extern "C" fn create_client( /// /// * `close_client` can only be called once per client. Calling it twice is undefined behavior, since the address will be freed twice. /// * `close_client` must be called after `free_connection_response` has been called to avoid creating a dangling pointer in the `ConnectionResponse`. -/// * `client_adapter_ptr` must be obtained from the `ConnectionResponse` returned from [`create_client`]. +/// * `client_adapter_ptr` must be obtained from the `ConnectionResponse` returned from [create_client]. /// * `client_adapter_ptr` must be valid until `close_client` is called. // TODO: Ensure safety when command has not completed yet #[no_mangle] pub unsafe extern "C" fn close_client(client_adapter_ptr: *const c_void) { assert!(!client_adapter_ptr.is_null()); - drop(unsafe { Box::from_raw(client_adapter_ptr as *mut ClientAdapter) }); + let client_adapter = unsafe { Arc::from_raw(client_adapter_ptr as *mut ClientAdapter) }; + let count = Arc::strong_count(&client_adapter); + assert!(count == 1, "Client is still in use."); } /// Deallocates a `ConnectionResponse`. @@ -512,7 +515,7 @@ fn valkey_value_to_command_response(value: Value) -> RedisResult value, Err(err) => { @@ -562,7 +566,7 @@ pub unsafe extern "C" fn command( let c_err_str = CString::into_raw( CString::new(message).expect("Couldn't convert error message to CString"), ); - unsafe { (client_adapter.failure_callback)(channel, c_err_str, error_type) }; + unsafe { (client_adapter_clone.failure_callback)(channel, c_err_str, error_type) }; return; } }; @@ -571,9 +575,10 @@ pub unsafe extern "C" fn command( unsafe { match result { - Ok(message) => { - (client_adapter.success_callback)(channel, Box::into_raw(Box::new(message))) - } + Ok(message) => (client_adapter_clone.success_callback)( + channel, + Box::into_raw(Box::new(message)), + ), Err(err) => { let message = errors::error_message(&err); let error_type = errors::error_type(&err); @@ -581,7 +586,7 @@ pub unsafe extern "C" fn command( let c_err_str = CString::into_raw( CString::new(message).expect("Couldn't convert error message to CString"), ); - (client_adapter.failure_callback)(channel, c_err_str, error_type); + (client_adapter_clone.failure_callback)(channel, c_err_str, error_type); } }; }