Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(macos): check WKURLSchemeTask is valid before using #1282

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/fix-macos-mitigate-async-command-panic.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"wry": patch
---

On macOS, mitigate an issue that could cause a panic when running an async command.
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,6 @@ pub enum Error {
#[cfg(target_os = "android")]
#[error(transparent)]
CrossBeamRecvError(#[from] crossbeam_channel::RecvError),
#[error("Custom protocol task is invalid.")]
CustomProtocolTaskInvalid,
}
123 changes: 82 additions & 41 deletions src/wkwebview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ mod util;
use cocoa::appkit::{NSView, NSViewHeightSizable, NSViewMinYMargin, NSViewWidthSizable};
use cocoa::{
base::{id, nil, NO, YES},
foundation::{NSDictionary, NSFastEnumeration, NSInteger},
foundation::{NSDictionary, NSFastEnumeration, NSInteger, NSUInteger},
};
use dpi::{LogicalPosition, LogicalSize};
use once_cell::sync::Lazy;
Expand All @@ -40,6 +40,7 @@ use core_graphics::{
use objc::{
declare::ClassDecl,
runtime::{Class, Object, Sel, BOOL},
Message,
};
use objc_id::Id;

Expand Down Expand Up @@ -82,6 +83,7 @@ const NS_JSON_WRITING_FRAGMENTS_ALLOWED: u64 = 4;

static COUNTER: Counter = Counter::new();
static WEBVIEW_IDS: Lazy<Mutex<HashSet<u32>>> = Lazy::new(Default::default);
static TASK_IDS: Lazy<Mutex<HashSet<NSUInteger>>> = Lazy::new(Default::default);

#[derive(Debug, Default, Copy, Clone)]
pub struct PrintMargin {
Expand Down Expand Up @@ -193,7 +195,7 @@ impl InnerWebView {
}

// Task handler for custom protocol
extern "C" fn start_task(this: &Object, _: Sel, _webview: id, task: id) {
extern "C" fn start_task(this: &Object, _: Sel, _webview: id, task: *mut Object) {
unsafe {
#[cfg(feature = "tracing")]
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
Expand Down Expand Up @@ -274,58 +276,94 @@ impl InnerWebView {
// send response
match http_request.body(sent_form_body) {
Ok(final_request) => {
// Place here to prevent task is dropped when responder is called
let task_id: NSUInteger = msg_send![task, hash];
let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> = Box::new(
move |sent_response| {
let content = sent_response.body();
// default: application/octet-stream, but should be provided by the client
let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
// default to 200
let wanted_status_code = sent_response.status().as_u16() as i32;
// default to HTTP/1.1
let wanted_version = format!("{:#?}", sent_response.version());

let dictionary: id = msg_send![class!(NSMutableDictionary), alloc];
let headers: id = msg_send![dictionary, initWithCapacity:1];
if let Some(mime) = wanted_mime {
let () = msg_send![headers, setObject:NSString::new(mime.to_str().unwrap()) forKey: NSString::new(CONTENT_TYPE.as_str())];
}
let () = msg_send![headers, setObject:NSString::new(&content.len().to_string()) forKey: NSString::new(CONTENT_LENGTH.as_str())];

// add headers
for (name, value) in sent_response.headers().iter() {
let header_key = name.as_str();
if let Ok(value) = value.to_str() {
let () = msg_send![headers, setObject:NSString::new(value) forKey: NSString::new(header_key)];
// Best-effort. OS may release task at any moment.
fn check_task_is_valid(webview_id: u32, task_id: u64) -> crate::Result<()> {
if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id)
|| !TASK_IDS.lock().unwrap().contains(&task_id)
{
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}

let urlresponse: id = msg_send![class!(NSHTTPURLResponse), alloc];
let response: id = msg_send![urlresponse, initWithURL:url statusCode: wanted_status_code HTTPVersion:NSString::new(&wanted_version) headerFields:headers];
if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
return;
}
let () = msg_send![task, didReceiveResponse: response];

// Send data
let bytes = content.as_ptr() as *mut c_void;
let data: id = msg_send![class!(NSData), alloc];
let data: id = msg_send![data, initWithBytesNoCopy:bytes length:content.len() freeWhenDone: if content.len() == 0 { NO } else { YES }];
unsafe fn response(
task: id,
task_id: NSUInteger,
webview_id: u32,
url: id, /* NSURL */
sent_response: HttpResponse<Cow<'_, [u8]>>,
) -> crate::Result<()> {
let content = sent_response.body();
// default: application/octet-stream, but should be provided by the client
let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
// default to 200
let wanted_status_code = sent_response.status().as_u16() as i32;
// default to HTTP/1.1
let wanted_version = format!("{:#?}", sent_response.version());

let dictionary: id = msg_send![class!(NSMutableDictionary), alloc];
let headers: id = msg_send![dictionary, initWithCapacity:1];
if let Some(mime) = wanted_mime {
let () = msg_send![headers, setObject:NSString::new(mime.to_str().unwrap()) forKey: NSString::new(CONTENT_TYPE.as_str())];
}
let () = msg_send![headers, setObject:NSString::new(&content.len().to_string()) forKey: NSString::new(CONTENT_LENGTH.as_str())];

// add headers
for (name, value) in sent_response.headers().iter() {
let header_key = name.as_str();
if let Ok(value) = value.to_str() {
let () = msg_send![headers, setObject:NSString::new(value) forKey: NSString::new(header_key)];
}
}

if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
return;
let urlresponse: id = msg_send![class!(NSHTTPURLResponse), alloc];
// url is part of the task, we need to check task is still valid
check_task_is_valid(webview_id, task_id)?;
let response: id = msg_send![urlresponse, initWithURL:url statusCode: wanted_status_code HTTPVersion:NSString::new(&wanted_version) headerFields:headers];

check_task_is_valid(webview_id, task_id)?;
(*task)
.send_message::<(id,), ()>(sel!(didReceiveResponse:), (response,))
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

// Send data
let bytes = content.as_ptr() as *mut c_void;
let data: id = msg_send![class!(NSData), alloc];
let data: id = msg_send![data, initWithBytesNoCopy:bytes length:content.len() freeWhenDone: if content.len() == 0 { NO } else { YES }];

check_task_is_valid(webview_id, task_id)?;
(*task)
.send_message::<(id,), ()>(sel!(didReceiveData:), (data,))
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

// Finish
check_task_is_valid(webview_id, task_id)?;
(*task)
.send_message::<(), ()>(sel!(didFinish), ())
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

Ok(())
}
let () = msg_send![task, didReceiveData: data];

// Finish
if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
return;
if check_task_is_valid(webview_id, task_id).is_ok() {
let _ = response(task, task_id, webview_id, url, sent_response);
}
let () = msg_send![task, didFinish];
TASK_IDS.lock().unwrap().remove(&task_id);
},
);

#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();

{
let mut task_ids = TASK_IDS.lock().unwrap();
task_ids.insert(task_id);
}

function(final_request, RequestAsyncResponder { responder });
}
Err(_) => respond_with_404(),
Expand All @@ -338,7 +376,10 @@ impl InnerWebView {
}
}
}
extern "C" fn stop_task(_: &Object, _: Sel, _webview: id, _task: id) {}
extern "C" fn stop_task(_: &Object, _: Sel, _webview: id, task: id) {
let task_id: NSUInteger = unsafe { msg_send![task, hash] };
TASK_IDS.lock().unwrap().remove(&task_id);
}

let mut wv_ids = WEBVIEW_IDS.lock().unwrap();
let webview_id = COUNTER.next();
Expand Down
Loading