Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safer wrapping for FFI clients. #2906

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions csharp/lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use redis::{FromRedisValue, RedisResult};
use std::{
ffi::{c_void, CStr, CString},
os::raw::c_char,
sync::Arc,
};
use tokio::runtime::Builder;
use tokio::runtime::Runtime;
Expand Down Expand Up @@ -68,7 +69,8 @@ fn create_client_internal(
})
}

/// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns. All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool.
/// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns.
/// All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool.
#[no_mangle]
pub extern "C" fn create_client(
host: *const c_char,
Expand All @@ -79,38 +81,48 @@ pub extern "C" fn create_client(
) -> *const c_void {
match create_client_internal(host, port, use_tls, success_callback, failure_callback) {
Err(_) => std::ptr::null(), // TODO - log errors
Ok(client) => Box::into_raw(Box::new(client)) as *const c_void,
Ok(client) => Arc::into_raw(Arc::new(client)) as *const c_void,
}
}

/// # Safety
///
/// This function should only be called once per pointer created by [create_client]. After calling this function
/// the `client_ptr` is not in a valid state.
#[no_mangle]
pub extern "C" fn close_client(client_ptr: *const c_void) {
let client_ptr = unsafe { Box::from_raw(client_ptr as *mut Client) };
let _runtime_handle = client_ptr.runtime.enter();
drop(client_ptr);
let count = Arc::strong_count(&unsafe { Arc::from_raw(client_ptr as *mut Client) });
assert!(count == 1, "Client is still in use.");
Yury-Fridlyand marked this conversation as resolved.
Show resolved Hide resolved
}

/// Expects that key and value will be kept valid until the callback is called.
///
/// # Safety
///
/// This function should only be called should with a pointer created by [create_client], before [close_client] was called with the pointer.
#[no_mangle]
pub extern "C" fn command(
pub unsafe extern "C" fn command(
client_ptr: *const c_void,
callback_index: usize,
request_type: RequestType,
args: *const *mut c_char,
arg_count: u32,
) {
let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) };
let client = unsafe {
// we increment the strong count to ensure that the client is not dropped just because we turned it into an Arc.
Arc::increment_strong_count(client_ptr);
Yury-Fridlyand marked this conversation as resolved.
Show resolved Hide resolved
Arc::from_raw(client_ptr as *mut Client)
};
let core_client_clone = client.clone();

// The safety of these needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed.
let ptr_address = client_ptr as usize;
let args_address = args as usize;

let mut client_clone = client.client.clone();
client.runtime.spawn(async move {
let Some(mut cmd) = request_type.get_command() else {
unsafe {
let client = Box::leak(Box::from_raw(ptr_address as *mut Client));
(client.failure_callback)(callback_index); // TODO - report errors
(core_client_clone.failure_callback)(callback_index); // TODO - report errors
return;
}
};
Expand All @@ -128,11 +140,12 @@ pub extern "C" fn command(
.await
.and_then(Option::<CString>::from_owned_redis_value);
unsafe {
let client = Box::leak(Box::from_raw(ptr_address as *mut Client));
match result {
Ok(None) => (client.success_callback)(callback_index, std::ptr::null()),
Ok(Some(c_str)) => (client.success_callback)(callback_index, c_str.as_ptr()),
Err(_) => (client.failure_callback)(callback_index), // TODO - report errors
Ok(None) => (core_client_clone.success_callback)(callback_index, std::ptr::null()),
Ok(Some(c_str)) => {
(core_client_clone.success_callback)(callback_index, c_str.as_ptr())
}
Err(_) => (core_client_clone.failure_callback)(callback_index), // TODO - report errors
};
}
});
Expand Down
39 changes: 22 additions & 17 deletions go/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use redis::cluster_routing::{
use redis::cluster_routing::{ResponsePolicy, Routable};
use redis::{Cmd, RedisResult, Value};
use std::slice::from_raw_parts;
use std::sync::Arc;
use std::{
ffi::{c_void, CString},
mem,
Expand Down Expand Up @@ -180,8 +181,8 @@ fn create_client_internal(
///
/// * `connection_request_bytes` must point to `connection_request_len` consecutive properly initialized bytes. It must be a well-formed Protobuf `ConnectionRequest` object. The array must be allocated by the caller and subsequently freed by the caller after this function returns.
/// * `connection_request_len` must not be greater than the length of the connection request bytes array. It must also not be greater than the max value of a signed pointer-sized integer.
/// * The `conn_ptr` pointer in the returned `ConnectionResponse` must live while the client is open/active and must be explicitly freed by calling [`close_client`].
/// * The `connection_error_message` pointer in the returned `ConnectionResponse` must live until the returned `ConnectionResponse` pointer is passed to [`free_connection_response`].
/// * The `conn_ptr` pointer in the returned `ConnectionResponse` must live while the client is open/active and must be explicitly freed by calling [close_client].
/// * The `connection_error_message` pointer in the returned `ConnectionResponse` must live until the returned `ConnectionResponse` pointer is passed to [free_connection_response].
/// * Both the `success_callback` and `failure_callback` function pointers need to live while the client is open/active. The caller is responsible for freeing both callbacks.
// TODO: Consider making this async
#[no_mangle]
Expand All @@ -201,7 +202,7 @@ pub unsafe extern "C" fn create_client(
),
},
Ok(client) => ConnectionResponse {
conn_ptr: Box::into_raw(Box::new(client)) as *const c_void,
conn_ptr: Arc::into_raw(Arc::new(client)) as *const c_void,
connection_error_message: std::ptr::null(),
},
};
Expand All @@ -220,13 +221,15 @@ pub unsafe extern "C" fn create_client(
///
/// * `close_client` can only be called once per client. Calling it twice is undefined behavior, since the address will be freed twice.
/// * `close_client` must be called after `free_connection_response` has been called to avoid creating a dangling pointer in the `ConnectionResponse`.
/// * `client_adapter_ptr` must be obtained from the `ConnectionResponse` returned from [`create_client`].
/// * `client_adapter_ptr` must be obtained from the `ConnectionResponse` returned from [create_client].
/// * `client_adapter_ptr` must be valid until `close_client` is called.
// TODO: Ensure safety when command has not completed yet
#[no_mangle]
pub unsafe extern "C" fn close_client(client_adapter_ptr: *const c_void) {
assert!(!client_adapter_ptr.is_null());
drop(unsafe { Box::from_raw(client_adapter_ptr as *mut ClientAdapter) });
let client_adapter = unsafe { Arc::from_raw(client_adapter_ptr as *mut ClientAdapter) };
let count = Arc::strong_count(&client_adapter);
assert!(count == 1, "Client is still in use.");
}

/// Deallocates a `ConnectionResponse`.
Expand Down Expand Up @@ -512,7 +515,7 @@ fn valkey_value_to_command_response(value: Value) -> RedisResult<CommandResponse
///
/// # Safety
///
/// * TODO: finish safety section.
/// This function should only be called should with a pointer created by [create_client], before [close_client] was called with the pointer.
#[no_mangle]
pub unsafe extern "C" fn command(
client_adapter_ptr: *const c_void,
Expand All @@ -524,11 +527,13 @@ pub unsafe extern "C" fn command(
route_bytes: *const u8,
route_bytes_len: usize,
) {
let client_adapter =
unsafe { Box::leak(Box::from_raw(client_adapter_ptr as *mut ClientAdapter)) };
// The safety of this needs to be ensured by the calling code. Cannot dispose of the pointer before
// all operations have completed.
let ptr_address = client_adapter_ptr as usize;
let client_adapter = unsafe {
// we increment the strong count to ensure that the client is not dropped just because we turned it into an Arc.
Arc::increment_strong_count(client_adapter_ptr);
Arc::from_raw(client_adapter_ptr as *mut ClientAdapter)
};

let client_adapter_clone = client_adapter.clone();

let arg_vec =
unsafe { convert_double_pointer_to_vec(args as *const *const c_void, arg_count, args_len) };
Expand All @@ -552,7 +557,6 @@ pub unsafe extern "C" fn command(
let result = client_clone
.send_command(&cmd, get_route(route, Some(&cmd)))
.await;
let client_adapter = unsafe { Box::leak(Box::from_raw(ptr_address as *mut ClientAdapter)) };
let value = match result {
Ok(value) => value,
Err(err) => {
Expand All @@ -562,7 +566,7 @@ pub unsafe extern "C" fn command(
let c_err_str = CString::into_raw(
CString::new(message).expect("Couldn't convert error message to CString"),
);
unsafe { (client_adapter.failure_callback)(channel, c_err_str, error_type) };
unsafe { (client_adapter_clone.failure_callback)(channel, c_err_str, error_type) };
return;
}
};
Expand All @@ -571,17 +575,18 @@ pub unsafe extern "C" fn command(

unsafe {
match result {
Ok(message) => {
(client_adapter.success_callback)(channel, Box::into_raw(Box::new(message)))
}
Ok(message) => (client_adapter_clone.success_callback)(
channel,
Box::into_raw(Box::new(message)),
),
Err(err) => {
let message = errors::error_message(&err);
let error_type = errors::error_type(&err);

let c_err_str = CString::into_raw(
CString::new(message).expect("Couldn't convert error message to CString"),
);
(client_adapter.failure_callback)(channel, c_err_str, error_type);
(client_adapter_clone.failure_callback)(channel, c_err_str, error_type);
}
};
}
Expand Down
Loading