Skip to content

Commit

Permalink
Safer wrapping for FFI clients.
Browse files Browse the repository at this point in the history
Use `Arc` instead of `Box` to wrap the `ClientAdapter`, so that the lifetime of the adapter will be extended even in cases where the client is closed during commands.
This also saves on using `as usize` conversions to the pointer, which cause provenance to be lost and might lead to undefined behavior.

Signed-off-by: Shachar Langbeheim <[email protected]>
  • Loading branch information
nihohit committed Jan 21, 2025
1 parent abec885 commit 5a751f9
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 31 deletions.
41 changes: 27 additions & 14 deletions csharp/lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use redis::{FromRedisValue, RedisResult};
use std::{
ffi::{c_void, CStr, CString},
os::raw::c_char,
sync::Arc,
};
use tokio::runtime::Builder;
use tokio::runtime::Runtime;
Expand Down Expand Up @@ -68,7 +69,8 @@ fn create_client_internal(
})
}

/// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns. All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool.
/// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns.
/// All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool.
#[no_mangle]
pub extern "C" fn create_client(
host: *const c_char,
Expand All @@ -79,38 +81,48 @@ pub extern "C" fn create_client(
) -> *const c_void {
match create_client_internal(host, port, use_tls, success_callback, failure_callback) {
Err(_) => std::ptr::null(), // TODO - log errors
Ok(client) => Box::into_raw(Box::new(client)) as *const c_void,
Ok(client) => Arc::into_raw(Arc::new(client)) as *const c_void,
}
}

/// # Safety
///
/// This function should only be called once per pointer created by [create_client]. After calling this function
/// the `client_ptr` is not in a valid state.
#[no_mangle]
pub extern "C" fn close_client(client_ptr: *const c_void) {
let client_ptr = unsafe { Box::from_raw(client_ptr as *mut Client) };
let _runtime_handle = client_ptr.runtime.enter();
drop(client_ptr);
let count = Arc::strong_count(&unsafe { Arc::from_raw(client_ptr as *mut Client) });
assert!(count == 1, "Client is still in use.");
}

/// Expects that key and value will be kept valid until the callback is called.
///
/// # Safety
///
/// This function should only be called should with a pointer created by [create_client], before [close_client] was called with the pointer.
#[no_mangle]
pub extern "C" fn command(
pub unsafe extern "C" fn command(
client_ptr: *const c_void,
callback_index: usize,
request_type: RequestType,
args: *const *mut c_char,
arg_count: u32,
) {
let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) };
let client = unsafe {
// we increment the strong count to ensure that the client is not dropped just because we turned it into an Arc.
Arc::increment_strong_count(client_ptr);
Arc::from_raw(client_ptr as *mut Client)
};
let core_client_clone = client.clone();

// The safety of these needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed.
let ptr_address = client_ptr as usize;
let args_address = args as usize;

let mut client_clone = client.client.clone();
client.runtime.spawn(async move {
let Some(mut cmd) = request_type.get_command() else {
unsafe {
let client = Box::leak(Box::from_raw(ptr_address as *mut Client));
(client.failure_callback)(callback_index); // TODO - report errors
(core_client_clone.failure_callback)(callback_index); // TODO - report errors
return;
}
};
Expand All @@ -128,11 +140,12 @@ pub extern "C" fn command(
.await
.and_then(Option::<CString>::from_owned_redis_value);
unsafe {
let client = Box::leak(Box::from_raw(ptr_address as *mut Client));
match result {
Ok(None) => (client.success_callback)(callback_index, std::ptr::null()),
Ok(Some(c_str)) => (client.success_callback)(callback_index, c_str.as_ptr()),
Err(_) => (client.failure_callback)(callback_index), // TODO - report errors
Ok(None) => (core_client_clone.success_callback)(callback_index, std::ptr::null()),
Ok(Some(c_str)) => {
(core_client_clone.success_callback)(callback_index, c_str.as_ptr())
}
Err(_) => (core_client_clone.failure_callback)(callback_index), // TODO - report errors
};
}
});
Expand Down
39 changes: 22 additions & 17 deletions go/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use redis::cluster_routing::{
use redis::cluster_routing::{ResponsePolicy, Routable};
use redis::{Cmd, RedisResult, Value};
use std::slice::from_raw_parts;
use std::sync::Arc;
use std::{
ffi::{c_void, CString},
mem,
Expand Down Expand Up @@ -180,8 +181,8 @@ fn create_client_internal(
///
/// * `connection_request_bytes` must point to `connection_request_len` consecutive properly initialized bytes. It must be a well-formed Protobuf `ConnectionRequest` object. The array must be allocated by the caller and subsequently freed by the caller after this function returns.
/// * `connection_request_len` must not be greater than the length of the connection request bytes array. It must also not be greater than the max value of a signed pointer-sized integer.
/// * The `conn_ptr` pointer in the returned `ConnectionResponse` must live while the client is open/active and must be explicitly freed by calling [`close_client`].
/// * The `connection_error_message` pointer in the returned `ConnectionResponse` must live until the returned `ConnectionResponse` pointer is passed to [`free_connection_response`].
/// * The `conn_ptr` pointer in the returned `ConnectionResponse` must live while the client is open/active and must be explicitly freed by calling [close_client].
/// * The `connection_error_message` pointer in the returned `ConnectionResponse` must live until the returned `ConnectionResponse` pointer is passed to [free_connection_response].
/// * Both the `success_callback` and `failure_callback` function pointers need to live while the client is open/active. The caller is responsible for freeing both callbacks.
// TODO: Consider making this async
#[no_mangle]
Expand All @@ -201,7 +202,7 @@ pub unsafe extern "C" fn create_client(
),
},
Ok(client) => ConnectionResponse {
conn_ptr: Box::into_raw(Box::new(client)) as *const c_void,
conn_ptr: Arc::into_raw(Arc::new(client)) as *const c_void,
connection_error_message: std::ptr::null(),
},
};
Expand All @@ -220,13 +221,15 @@ pub unsafe extern "C" fn create_client(
///
/// * `close_client` can only be called once per client. Calling it twice is undefined behavior, since the address will be freed twice.
/// * `close_client` must be called after `free_connection_response` has been called to avoid creating a dangling pointer in the `ConnectionResponse`.
/// * `client_adapter_ptr` must be obtained from the `ConnectionResponse` returned from [`create_client`].
/// * `client_adapter_ptr` must be obtained from the `ConnectionResponse` returned from [create_client].
/// * `client_adapter_ptr` must be valid until `close_client` is called.
// TODO: Ensure safety when command has not completed yet
#[no_mangle]
pub unsafe extern "C" fn close_client(client_adapter_ptr: *const c_void) {
assert!(!client_adapter_ptr.is_null());
drop(unsafe { Box::from_raw(client_adapter_ptr as *mut ClientAdapter) });
let client_adapter = unsafe { Arc::from_raw(client_adapter_ptr as *mut ClientAdapter) };
let count = Arc::strong_count(&client_adapter);
assert!(count == 1, "Client is still in use.");
}

/// Deallocates a `ConnectionResponse`.
Expand Down Expand Up @@ -512,7 +515,7 @@ fn valkey_value_to_command_response(value: Value) -> RedisResult<CommandResponse
///
/// # Safety
///
/// * TODO: finish safety section.
/// This function should only be called should with a pointer created by [create_client], before [close_client] was called with the pointer.
#[no_mangle]
pub unsafe extern "C" fn command(
client_adapter_ptr: *const c_void,
Expand All @@ -524,11 +527,13 @@ pub unsafe extern "C" fn command(
route_bytes: *const u8,
route_bytes_len: usize,
) {
let client_adapter =
unsafe { Box::leak(Box::from_raw(client_adapter_ptr as *mut ClientAdapter)) };
// The safety of this needs to be ensured by the calling code. Cannot dispose of the pointer before
// all operations have completed.
let ptr_address = client_adapter_ptr as usize;
let client_adapter = unsafe {
// we increment the strong count to ensure that the client is not dropped just because we turned it into an Arc.
Arc::increment_strong_count(client_adapter_ptr);
Arc::from_raw(client_adapter_ptr as *mut ClientAdapter)
};

let client_adapter_clone = client_adapter.clone();

let arg_vec =
unsafe { convert_double_pointer_to_vec(args as *const *const c_void, arg_count, args_len) };
Expand All @@ -552,7 +557,6 @@ pub unsafe extern "C" fn command(
let result = client_clone
.send_command(&cmd, get_route(route, Some(&cmd)))
.await;
let client_adapter = unsafe { Box::leak(Box::from_raw(ptr_address as *mut ClientAdapter)) };
let value = match result {
Ok(value) => value,
Err(err) => {
Expand All @@ -562,7 +566,7 @@ pub unsafe extern "C" fn command(
let c_err_str = CString::into_raw(
CString::new(message).expect("Couldn't convert error message to CString"),
);
unsafe { (client_adapter.failure_callback)(channel, c_err_str, error_type) };
unsafe { (client_adapter_clone.failure_callback)(channel, c_err_str, error_type) };
return;
}
};
Expand All @@ -571,17 +575,18 @@ pub unsafe extern "C" fn command(

unsafe {
match result {
Ok(message) => {
(client_adapter.success_callback)(channel, Box::into_raw(Box::new(message)))
}
Ok(message) => (client_adapter_clone.success_callback)(
channel,
Box::into_raw(Box::new(message)),
),
Err(err) => {
let message = errors::error_message(&err);
let error_type = errors::error_type(&err);

let c_err_str = CString::into_raw(
CString::new(message).expect("Couldn't convert error message to CString"),
);
(client_adapter.failure_callback)(channel, c_err_str, error_type);
(client_adapter_clone.failure_callback)(channel, c_err_str, error_type);
}
};
}
Expand Down

0 comments on commit 5a751f9

Please sign in to comment.