Skip to content

Commit

Permalink
fix(macos): check WKURLSchemeTask valid before use (#1282)
Browse files Browse the repository at this point in the history
  • Loading branch information
pewsheen authored Jun 4, 2024
1 parent 8fc4ecf commit f7936b8
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 41 deletions.
5 changes: 5 additions & 0 deletions .changes/fix-macos-mitigate-async-command-panic.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"wry": patch
---

On macOS, mitigate an issue that could cause a panic when running an async command.
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,6 @@ pub enum Error {
#[cfg(target_os = "android")]
#[error(transparent)]
CrossBeamRecvError(#[from] crossbeam_channel::RecvError),
#[error("Custom protocol task is invalid.")]
CustomProtocolTaskInvalid,
}
123 changes: 82 additions & 41 deletions src/wkwebview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ mod util;
use cocoa::appkit::{NSView, NSViewHeightSizable, NSViewMinYMargin, NSViewWidthSizable};
use cocoa::{
base::{id, nil, NO, YES},
foundation::{NSDictionary, NSFastEnumeration, NSInteger},
foundation::{NSDictionary, NSFastEnumeration, NSInteger, NSUInteger},
};
use dpi::{LogicalPosition, LogicalSize};
use once_cell::sync::Lazy;
Expand All @@ -40,6 +40,7 @@ use core_graphics::{
use objc::{
declare::ClassDecl,
runtime::{Class, Object, Sel, BOOL},
Message,
};
use objc_id::Id;

Expand Down Expand Up @@ -82,6 +83,7 @@ const NS_JSON_WRITING_FRAGMENTS_ALLOWED: u64 = 4;

static COUNTER: Counter = Counter::new();
static WEBVIEW_IDS: Lazy<Mutex<HashSet<u32>>> = Lazy::new(Default::default);
static TASK_IDS: Lazy<Mutex<HashSet<NSUInteger>>> = Lazy::new(Default::default);

#[derive(Debug, Default, Copy, Clone)]
pub struct PrintMargin {
Expand Down Expand Up @@ -193,7 +195,7 @@ impl InnerWebView {
}

// Task handler for custom protocol
extern "C" fn start_task(this: &Object, _: Sel, _webview: id, task: id) {
extern "C" fn start_task(this: &Object, _: Sel, _webview: id, task: *mut Object) {
unsafe {
#[cfg(feature = "tracing")]
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
Expand Down Expand Up @@ -274,58 +276,94 @@ impl InnerWebView {
// send response
match http_request.body(sent_form_body) {
Ok(final_request) => {
// Place here to prevent task is dropped when responder is called
let task_id: NSUInteger = msg_send![task, hash];
let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> = Box::new(
move |sent_response| {
let content = sent_response.body();
// default: application/octet-stream, but should be provided by the client
let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
// default to 200
let wanted_status_code = sent_response.status().as_u16() as i32;
// default to HTTP/1.1
let wanted_version = format!("{:#?}", sent_response.version());

let dictionary: id = msg_send![class!(NSMutableDictionary), alloc];
let headers: id = msg_send![dictionary, initWithCapacity:1];
if let Some(mime) = wanted_mime {
let () = msg_send![headers, setObject:NSString::new(mime.to_str().unwrap()) forKey: NSString::new(CONTENT_TYPE.as_str())];
}
let () = msg_send![headers, setObject:NSString::new(&content.len().to_string()) forKey: NSString::new(CONTENT_LENGTH.as_str())];

// add headers
for (name, value) in sent_response.headers().iter() {
let header_key = name.as_str();
if let Ok(value) = value.to_str() {
let () = msg_send![headers, setObject:NSString::new(value) forKey: NSString::new(header_key)];
// Best-effort. OS may release task at any moment.
fn check_task_is_valid(webview_id: u32, task_id: u64) -> crate::Result<()> {
if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id)
|| !TASK_IDS.lock().unwrap().contains(&task_id)
{
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}

let urlresponse: id = msg_send![class!(NSHTTPURLResponse), alloc];
let response: id = msg_send![urlresponse, initWithURL:url statusCode: wanted_status_code HTTPVersion:NSString::new(&wanted_version) headerFields:headers];
if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
return;
}
let () = msg_send![task, didReceiveResponse: response];

// Send data
let bytes = content.as_ptr() as *mut c_void;
let data: id = msg_send![class!(NSData), alloc];
let data: id = msg_send![data, initWithBytesNoCopy:bytes length:content.len() freeWhenDone: if content.len() == 0 { NO } else { YES }];
unsafe fn response(
task: id,
task_id: NSUInteger,
webview_id: u32,
url: id, /* NSURL */
sent_response: HttpResponse<Cow<'_, [u8]>>,
) -> crate::Result<()> {
let content = sent_response.body();
// default: application/octet-stream, but should be provided by the client
let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
// default to 200
let wanted_status_code = sent_response.status().as_u16() as i32;
// default to HTTP/1.1
let wanted_version = format!("{:#?}", sent_response.version());

let dictionary: id = msg_send![class!(NSMutableDictionary), alloc];
let headers: id = msg_send![dictionary, initWithCapacity:1];
if let Some(mime) = wanted_mime {
let () = msg_send![headers, setObject:NSString::new(mime.to_str().unwrap()) forKey: NSString::new(CONTENT_TYPE.as_str())];
}
let () = msg_send![headers, setObject:NSString::new(&content.len().to_string()) forKey: NSString::new(CONTENT_LENGTH.as_str())];

// add headers
for (name, value) in sent_response.headers().iter() {
let header_key = name.as_str();
if let Ok(value) = value.to_str() {
let () = msg_send![headers, setObject:NSString::new(value) forKey: NSString::new(header_key)];
}
}

if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
return;
let urlresponse: id = msg_send![class!(NSHTTPURLResponse), alloc];
// url is part of the task, we need to check task is still valid
check_task_is_valid(webview_id, task_id)?;
let response: id = msg_send![urlresponse, initWithURL:url statusCode: wanted_status_code HTTPVersion:NSString::new(&wanted_version) headerFields:headers];

check_task_is_valid(webview_id, task_id)?;
(*task)
.send_message::<(id,), ()>(sel!(didReceiveResponse:), (response,))
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

// Send data
let bytes = content.as_ptr() as *mut c_void;
let data: id = msg_send![class!(NSData), alloc];
let data: id = msg_send![data, initWithBytesNoCopy:bytes length:content.len() freeWhenDone: if content.len() == 0 { NO } else { YES }];

check_task_is_valid(webview_id, task_id)?;
(*task)
.send_message::<(id,), ()>(sel!(didReceiveData:), (data,))
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

// Finish
check_task_is_valid(webview_id, task_id)?;
(*task)
.send_message::<(), ()>(sel!(didFinish), ())
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

Ok(())
}
let () = msg_send![task, didReceiveData: data];

// Finish
if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
return;
if check_task_is_valid(webview_id, task_id).is_ok() {
let _ = response(task, task_id, webview_id, url, sent_response);
}
let () = msg_send![task, didFinish];
TASK_IDS.lock().unwrap().remove(&task_id);
},
);

#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();

{
let mut task_ids = TASK_IDS.lock().unwrap();
task_ids.insert(task_id);
}

function(final_request, RequestAsyncResponder { responder });
}
Err(_) => respond_with_404(),
Expand All @@ -338,7 +376,10 @@ impl InnerWebView {
}
}
}
extern "C" fn stop_task(_: &Object, _: Sel, _webview: id, _task: id) {}
extern "C" fn stop_task(_: &Object, _: Sel, _webview: id, task: id) {
let task_id: NSUInteger = unsafe { msg_send![task, hash] };
TASK_IDS.lock().unwrap().remove(&task_id);
}

let mut wv_ids = WEBVIEW_IDS.lock().unwrap();
let webview_id = COUNTER.next();
Expand Down

0 comments on commit f7936b8

Please sign in to comment.