Skip to content

Commit

Permalink
feat(android): enhance initialization scripts (#1076)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasfernog authored Nov 13, 2023
1 parent dbea0f3 commit 4d6f08e
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 81 deletions.
5 changes: 5 additions & 0 deletions .changes/enhance-init-scripts-android.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"wry": patch
---

Enhance initalization script implementation on Android supporting any kind of URL.
19 changes: 14 additions & 5 deletions src/android/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ macro_rules! android_binding {
$package,
RustWebViewClient,
handleRequest,
[JObject],
[JObject, jboolean],
jobject
);
android_fn!(
Expand Down Expand Up @@ -97,7 +97,11 @@ macro_rules! android_binding {
}};
}

fn handle_request(env: &mut JNIEnv, request: JObject) -> JniResult<jobject> {
fn handle_request(
env: &mut JNIEnv,
request: JObject,
is_document_start_script_enabled: jboolean,
) -> JniResult<jobject> {
if let Some(handler) = REQUEST_HANDLER.get() {
let mut request_builder = Request::builder();

Expand Down Expand Up @@ -148,7 +152,7 @@ fn handle_request(env: &mut JNIEnv, request: JObject) -> JniResult<jobject> {
}
};

let response = (handler.handler)(final_request);
let response = (handler.handler)(final_request, is_document_start_script_enabled != 0);
if let Some(response) = response {
let status = response.status();
let status_code = status.as_u16() as i32;
Expand Down Expand Up @@ -226,8 +230,13 @@ fn handle_request(env: &mut JNIEnv, request: JObject) -> JniResult<jobject> {
}

#[allow(non_snake_case)]
pub unsafe fn handleRequest(mut env: JNIEnv, _: JClass, request: JObject) -> jobject {
match handle_request(&mut env, request) {
pub unsafe fn handleRequest(
mut env: JNIEnv,
_: JClass,
request: JObject,
is_document_start_script_enabled: jboolean,
) -> jobject {
match handle_request(&mut env, request, is_document_start_script_enabled) {
Ok(response) => response,
Err(e) => {
log::warn!("Failed to handle request: {}", e);
Expand Down
18 changes: 17 additions & 1 deletion src/android/kotlin/RustWebView.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,34 @@

package {{package}}

import android.annotation.SuppressLint
import android.webkit.*
import android.content.Context
import androidx.webkit.WebViewCompat
import androidx.webkit.WebViewFeature
import kotlin.collections.Map

class RustWebView(context: Context): WebView(context) {
@SuppressLint("RestrictedApi")
class RustWebView(context: Context, val initScripts: Array<String>): WebView(context) {
val isDocumentStartScriptEnabled: Boolean

init {
settings.javaScriptEnabled = true
settings.domStorageEnabled = true
settings.setGeolocationEnabled(true)
settings.databaseEnabled = true
settings.mediaPlaybackRequiresUserGesture = false
settings.javaScriptCanOpenWindowsAutomatically = true

if (WebViewFeature.isFeatureSupported(WebViewFeature.DOCUMENT_START_SCRIPT)) {
isDocumentStartScriptEnabled = true
for (script in initScripts) {
WebViewCompat.addDocumentStartJavaScript(this, script, setOf("*"));
}
} else {
isDocumentStartScriptEnabled = false
}

{{class-init}}
}

Expand Down
18 changes: 14 additions & 4 deletions src/android/kotlin/RustWebViewClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import android.graphics.Bitmap
import androidx.webkit.WebViewAssetLoader

class RustWebViewClient(context: Context): WebViewClient() {
private val interceptedState = mutableMapOf<String, Boolean>()

private val assetLoader = WebViewAssetLoader.Builder()
.setDomain(assetLoaderDomain())
.addPathHandler("/", WebViewAssetLoader.AssetsPathHandler(context))
Expand All @@ -22,7 +24,9 @@ class RustWebViewClient(context: Context): WebViewClient() {
return if (withAssetLoader()) {
assetLoader.shouldInterceptRequest(request.url)
} else {
handleRequest(request)
val response = handleRequest(request, (view as RustWebView).isDocumentStartScriptEnabled)
interceptedState[request.url.toString()] = response != null
return response
}
}

Expand All @@ -33,11 +37,17 @@ class RustWebViewClient(context: Context): WebViewClient() {
return shouldOverride(request.url.toString())
}

override fun onPageStarted(view: WebView, url: String, favicon: Bitmap?): Unit {
override fun onPageStarted(view: WebView, url: String, favicon: Bitmap?) {
if (interceptedState[url] == false) {
val webView = view as RustWebView
for (script in webView.initScripts) {
view.evaluateJavascript(script, null)
}
}
return onPageLoading(url)
}

override fun onPageFinished(view: WebView, url: String): Unit {
override fun onPageFinished(view: WebView, url: String) {
return onPageLoaded(url)
}

Expand All @@ -50,7 +60,7 @@ class RustWebViewClient(context: Context): WebViewClient() {

private external fun assetLoaderDomain(): String
private external fun withAssetLoader(): Boolean
private external fun handleRequest(request: WebResourceRequest): WebResourceResponse?
private external fun handleRequest(request: WebResourceRequest, isDocumentStartScriptEnabled: Boolean): WebResourceResponse?
private external fun shouldOverride(url: String): Boolean
private external fun onPageLoading(url: String)
private external fun onPageLoaded(url: String)
Expand Down
21 changes: 19 additions & 2 deletions src/android/main_pipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,24 @@ impl<'a> MainPipe<'a> {
on_webview_created,
autoplay,
user_agent,
initialization_scripts,
..
} = attrs;

let string_class = self.env.find_class("java/lang/String")?;
let initialization_scripts_array = self.env.new_object_array(
initialization_scripts.len() as i32,
string_class,
self.env.new_string("")?,
)?;
for (i, script) in initialization_scripts.into_iter().enumerate() {
self.env.set_object_array_element(
&initialization_scripts_array,
i as i32,
self.env.new_string(script)?,
)?;
}

// Create webview
let rust_webview_class = find_class(
&mut self.env,
Expand All @@ -62,8 +78,8 @@ impl<'a> MainPipe<'a> {
)?;
let webview = self.env.new_object(
&rust_webview_class,
"(Landroid/content/Context;)V",
&[activity.into()],
"(Landroid/content/Context;[Ljava/lang/String;)V",
&[activity.into(), (&initialization_scripts_array).into()],
)?;

// set media autoplay
Expand Down Expand Up @@ -344,4 +360,5 @@ pub(crate) struct CreateWebViewAttributes {
pub autoplay: bool,
pub on_webview_created: Option<Box<dyn Fn(super::Context) -> JniResult<()> + Send>>,
pub user_agent: Option<String>,
pub initialization_scripts: Vec<String>,
}
145 changes: 78 additions & 67 deletions src/android/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ macro_rules! define_static_handlers {

define_static_handlers! {
IPC = UnsafeIpc { handler: Box<dyn Fn(String)> };
REQUEST_HANDLER = UnsafeRequestHandler { handler: Box<dyn Fn(Request<Vec<u8>>) -> Option<HttpResponse<Cow<'static, [u8]>>>> };
REQUEST_HANDLER = UnsafeRequestHandler { handler: Box<dyn Fn(Request<Vec<u8>>, bool) -> Option<HttpResponse<Cow<'static, [u8]>>>> };
TITLE_CHANGE_HANDLER = UnsafeTitleHandler { handler: Box<dyn Fn(String)> };
URL_LOADING_OVERRIDE = UnsafeUrlLoadingOverride { handler: Box<dyn Fn(String) -> bool> };
ON_LOAD_HANDLER = UnsafeOnPageLoadHandler { handler: Box<dyn Fn(PageLoadEvent, String)> };
Expand Down Expand Up @@ -179,6 +179,7 @@ impl InnerWebView {
on_webview_created,
autoplay,
user_agent,
initialization_scripts: initialization_scripts.clone(),
}));

WITH_ASSET_LOADER.get_or_init(move || with_asset_loader);
Expand All @@ -187,77 +188,87 @@ impl InnerWebView {
}

REQUEST_HANDLER.get_or_init(move || {
UnsafeRequestHandler::new(Box::new(move |mut request| {
if let Some(custom_protocol) = custom_protocols.iter().find(|(name, _)| {
request
.uri()
.to_string()
.starts_with(&format!("{custom_protocol_scheme}://{}.", name))
}) {
*request.uri_mut() = request
.uri()
.to_string()
.replace(
&format!("{custom_protocol_scheme}://{}.", custom_protocol.0),
&format!("{}://", custom_protocol.0),
)
.parse()
.unwrap();

let (tx, rx) = channel();
let initialization_scripts = initialization_scripts.clone();
let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
Box::new(move |mut response| {
let should_inject_scripts = response
.headers()
.get(CONTENT_TYPE)
// Content-Type must begin with the media type, but is case-insensitive.
// It may also be followed by any number of semicolon-delimited key value pairs.
// We don't care about these here.
// source: https://httpwg.org/specs/rfc9110.html#rfc.section.8.3.1
.and_then(|content_type| content_type.to_str().ok())
.map(|content_type_str| content_type_str.to_lowercase().starts_with("text/html"))
.unwrap_or_default();

if should_inject_scripts && !initialization_scripts.is_empty() {
let mut document =
kuchiki::parse_html().one(String::from_utf8_lossy(response.body()).into_owned());
let csp = response.headers_mut().get_mut(CONTENT_SECURITY_POLICY);
let mut hashes = Vec::new();
with_html_head(&mut document, |head| {
// iterate in reverse order since we are prepending each script to the head tag
for script in initialization_scripts.iter().rev() {
let script_el =
NodeRef::new_element(QualName::new(None, ns!(html), "script".into()), None);
script_el.append(NodeRef::new_text(script));
head.prepend(script_el);
if csp.is_some() {
hashes.push(hash_script(script));
UnsafeRequestHandler::new(Box::new(
move |mut request, is_document_start_script_enabled| {
if let Some(custom_protocol) = custom_protocols.iter().find(|(name, _)| {
request
.uri()
.to_string()
.starts_with(&format!("{custom_protocol_scheme}://{}.", name))
}) {
*request.uri_mut() = request
.uri()
.to_string()
.replace(
&format!("{custom_protocol_scheme}://{}.", custom_protocol.0),
&format!("{}://", custom_protocol.0),
)
.parse()
.unwrap();

let (tx, rx) = channel();
let initialization_scripts = initialization_scripts.clone();
let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
Box::new(move |mut response| {
if !is_document_start_script_enabled {
log::info!("`addDocumentStartJavaScript` is not supported; injecting initialization scripts via custom protocol handler");
let should_inject_scripts = response
.headers()
.get(CONTENT_TYPE)
// Content-Type must begin with the media type, but is case-insensitive.
// It may also be followed by any number of semicolon-delimited key value pairs.
// We don't care about these here.
// source: https://httpwg.org/specs/rfc9110.html#rfc.section.8.3.1
.and_then(|content_type| content_type.to_str().ok())
.map(|content_type_str| {
content_type_str.to_lowercase().starts_with("text/html")
})
.unwrap_or_default();

if should_inject_scripts && !initialization_scripts.is_empty() {
let mut document = kuchiki::parse_html()
.one(String::from_utf8_lossy(response.body()).into_owned());
let csp = response.headers_mut().get_mut(CONTENT_SECURITY_POLICY);
let mut hashes = Vec::new();
with_html_head(&mut document, |head| {
// iterate in reverse order since we are prepending each script to the head tag
for script in initialization_scripts.iter().rev() {
let script_el = NodeRef::new_element(
QualName::new(None, ns!(html), "script".into()),
None,
);
script_el.append(NodeRef::new_text(script));
head.prepend(script_el);
if csp.is_some() {
hashes.push(hash_script(script));
}
}
});

if let Some(csp) = csp {
let csp_string = csp.to_str().unwrap().to_string();
let csp_string = if csp_string.contains("script-src") {
csp_string
.replace("script-src", &format!("script-src {}", hashes.join(" ")))
} else {
format!("{} script-src {}", csp_string, hashes.join(" "))
};
*csp = HeaderValue::from_str(&csp_string).unwrap();
}

*response.body_mut() = document.to_string().into_bytes().into();
}
});

if let Some(csp) = csp {
let csp_string = csp.to_str().unwrap().to_string();
let csp_string = if csp_string.contains("script-src") {
csp_string.replace("script-src", &format!("script-src {}", hashes.join(" ")))
} else {
format!("{} script-src {}", csp_string, hashes.join(" "))
};
*csp = HeaderValue::from_str(&csp_string).unwrap();
}

*response.body_mut() = document.to_string().into_bytes().into();
}

tx.send(response).unwrap();
});
tx.send(response).unwrap();
});

(custom_protocol.1)(request, RequestAsyncResponder { responder });
return Some(rx.recv().unwrap());
}
None
}))
(custom_protocol.1)(request, RequestAsyncResponder { responder });
return Some(rx.recv().unwrap());
}
None
},
))
});

if let Some(i) = ipc_handler {
Expand Down
8 changes: 6 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,12 @@ impl<'a> WebViewBuilder<'a> {
///
/// ## Platform-specific
///
/// - **Android:** The Android WebView does not provide an API for initialization scripts,
/// so we prepend them to each HTML head. They are only implemented on custom protocol URLs.
/// - **Android:** When [addDocumentStartJavaScript] is not supported,
/// we prepend them to each HTML head (implementation only supported on custom protocol URLs).
/// For remote URLs, we use [onPageStarted] which is not guaranteed to run before other scripts.
///
/// [addDocumentStartJavaScript]: https://developer.android.com/reference/androidx/webkit/WebViewCompat#addDocumentStartJavaScript(android.webkit.WebView,java.lang.String,java.util.Set%3Cjava.lang.String%3E)
/// [onPageStarted]: https://developer.android.com/reference/android/webkit/WebViewClient#onPageStarted(android.webkit.WebView,%20java.lang.String,%20android.graphics.Bitmap)
pub fn with_initialization_script(mut self, js: &str) -> Self {
if !js.is_empty() {
self.attrs.initialization_scripts.push(js.to_string());
Expand Down

0 comments on commit 4d6f08e

Please sign in to comment.