diff --git a/README.md b/README.md index c02bbe1..d2f53d6 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,28 @@ To manually process requests: } ``` +For both cases you can choose between different strategies for how the recorded requests (including +the body) and the intercepted requests are matched together. By default, only the url is used. If +you want to use a different strategy, for example if you have parallel requests to the same url with +different bodies (e.g. GraphQL queries), you can pass a custom `RequestMatcher` to the constructor +of `RequestInspectorWebViewClient`: + +```kotlin + val webView = WebView(this) + webView.webViewClient = RequestInspectorWebViewClient( + webView, + matcher = RequestUuidInHeaderMatcher() + ) +``` + +Currently available matchers are `RequestUuidInHeaderMatcher` and `RequestUuidInUrlMatcher`, which +both create an UUID and add it to the request before it's recorded and sent. They only differ by how +the attach the UUID to the request, as an additional header or as an additional query param. But +both clean up the request before it's been sent. + +If you want to implement your own matching strategy, you can implement the `RequestMatcher` +interface and pass an instance of it to the constructor of `RequestInspectorWebViewClient`. + Known limitations === diff --git a/app/build.gradle.kts b/app/build.gradle.kts index e824808..7d176c7 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -15,7 +15,12 @@ android { defaultConfig { minSdk = 21 - targetSdk = 31 + testOptions { + targetSdk = 31 + } + lint { + targetSdk = 31 + } version = currentVersion } @@ -28,8 +33,8 @@ android { } compileOptions { - sourceCompatibility = JavaVersion.VERSION_11 - targetCompatibility = JavaVersion.VERSION_11 + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 } publishing { @@ -52,3 +57,7 @@ publishing { } } } + +dependencies { + implementation("androidx.core:core-ktx:1.17.0") +} diff --git a/app/src/main/java/com/acsbendi/requestinspectorwebview/RequestInspectorJavaScriptInterface.kt b/app/src/main/java/com/acsbendi/requestinspectorwebview/RequestInspectorJavaScriptInterface.kt index bf09d84..d3c3f30 100644 --- a/app/src/main/java/com/acsbendi/requestinspectorwebview/RequestInspectorJavaScriptInterface.kt +++ b/app/src/main/java/com/acsbendi/requestinspectorwebview/RequestInspectorJavaScriptInterface.kt @@ -1,38 +1,31 @@ package com.acsbendi.requestinspectorwebview +import android.net.Uri import android.util.Log import android.webkit.JavascriptInterface +import android.webkit.WebResourceRequest import android.webkit.WebView +import androidx.core.net.toUri +import com.acsbendi.requestinspectorwebview.matcher.RequestMatcher import org.intellij.lang.annotations.Language import org.json.JSONArray +import org.json.JSONException import org.json.JSONObject import java.net.URLEncoder -internal class RequestInspectorJavaScriptInterface(webView: WebView) { +class RequestInspectorJavaScriptInterface(webView: WebView, val matcher: RequestMatcher) { init { webView.addJavascriptInterface(this, INTERFACE_NAME) } - private val recordedRequests = ArrayList() - - fun findRecordedRequestForUrl(url: String): RecordedRequest? { - return synchronized(recordedRequests) { - // use findLast instead of find to find the last added query matching a URL - - // they are included at the end of the list when written. - recordedRequests.findLast { recordedRequest -> - // Added search by exact URL to find the actual request body - url == recordedRequest.url - } ?: recordedRequests.findLast { recordedRequest -> - // Previously, there was only a search by contains, and because of this, sometimes the wrong request body was found - url.contains(recordedRequest.url) - } - } + fun createWebViewRequest(request: WebResourceRequest): WebViewRequest { + return matcher.createWebViewRequest(request) } data class RecordedRequest( val type: WebViewRequestType, - val url: String, + val url: Uri, val method: String, val body: String, val formParameters: Map, @@ -80,7 +73,7 @@ internal class RequestInspectorJavaScriptInterface(webView: WebView) { addRecordedRequest( RecordedRequest( WebViewRequestType.FORM, - url, + url.toUri(), method, body, formParameterMap, @@ -98,7 +91,7 @@ internal class RequestInspectorJavaScriptInterface(webView: WebView) { addRecordedRequest( RecordedRequest( WebViewRequestType.XML_HTTP, - url, + url.toUri(), method, body, mapOf(), @@ -116,7 +109,7 @@ internal class RequestInspectorJavaScriptInterface(webView: WebView) { addRecordedRequest( RecordedRequest( WebViewRequestType.FETCH, - url, + url.toUri(), method, body, mapOf(), @@ -127,14 +120,28 @@ internal class RequestInspectorJavaScriptInterface(webView: WebView) { ) } + @JavascriptInterface + fun getAdditionalHeaders(url: String): String { + return matcher.getAdditionalHeaders(url).toString() + } + + @JavascriptInterface + fun getAdditionalQueryParam(): String { + return matcher.getAdditionalQueryParams() + } + private fun addRecordedRequest(recordedRequest: RecordedRequest) { - synchronized(recordedRequests) { - recordedRequests.add(recordedRequest) - } + matcher.addRecordedRequest(recordedRequest) } private fun getHeadersAsMap(headersString: String): MutableMap { - val headersObject = JSONObject(headersString) + val headersObject = try { + JSONObject(headersString) + } catch (_: JSONException) { + // When during the creation of a JSONObject from the string a JSONException is thrown, we simply return an + // empty JSONObject. This happens e.g. when JS send "undefined" or an empty string as headers. + JSONObject() + } val map = HashMap() for (key in headersObject.keys()) { val lowercaseHeader = key.lowercase() @@ -158,7 +165,6 @@ internal class RequestInspectorJavaScriptInterface(webView: WebView) { return map } - private fun getUrlEncodedFormBody(formParameterJsonArray: JSONArray): String { val resultStringBuilder = StringBuilder() repeat(formParameterJsonArray.length()) { i -> @@ -250,6 +256,22 @@ function getFullUrl(url) { } } +function appendAdditionalQueryParams(url) { + try { + var extraQueryParam = $INTERFACE_NAME.getAdditionalQueryParam(); + if (extraQueryParam) { + if (url.indexOf('?') === -1) { + url += '?' + extraQueryParam; + } else { + url += '&' + extraQueryParam; + } + } + } catch (e) { + console.warn('Failed to inject query param from Kotlin:', e); + } + return url; +} + function recordFormSubmission(form) { var jsonArr = []; for (i = 0; i < form.elements.length; i++) { @@ -270,7 +292,7 @@ function recordFormSubmission(form) { const path = form.attributes['action'] === undefined ? "/" : form.attributes['action'].nodeValue; const method = form.attributes['method'] === undefined ? "GET" : form.attributes['method'].nodeValue; - const url = getFullUrl(path); + const url = appendAdditionalQueryParams(getFullUrl(path)); const encType = form.attributes['enctype'] === undefined ? "application/x-www-form-urlencoded" : form.attributes['enctype'].nodeValue; const err = new Error(); $INTERFACE_NAME.recordFormSubmission( @@ -302,9 +324,9 @@ let xmlhttpRequestUrl = null; XMLHttpRequest.prototype._open = XMLHttpRequest.prototype.open; XMLHttpRequest.prototype.open = function (method, url, async, user, password) { lastXmlhttpRequestPrototypeMethod = method; - xmlhttpRequestUrl = url; + xmlhttpRequestUrl = appendAdditionalQueryParams(url); const asyncWithDefault = async === undefined ? true : async; - this._open(method, url, asyncWithDefault, user, password); + this._open(method, xmlhttpRequestUrl, asyncWithDefault, user, password); }; XMLHttpRequest.prototype._setRequestHeader = XMLHttpRequest.prototype.setRequestHeader; XMLHttpRequest.prototype.setRequestHeader = function (header, value) { @@ -314,7 +336,18 @@ XMLHttpRequest.prototype.setRequestHeader = function (header, value) { XMLHttpRequest.prototype._send = XMLHttpRequest.prototype.send; XMLHttpRequest.prototype.send = function (body) { const err = new Error(); - const url = getFullUrl(xmlhttpRequestUrl); + let url = getFullUrl(xmlhttpRequestUrl); + // Inject headers from Kotlin if any + try { + var extraHeaders = JSON.parse($INTERFACE_NAME.getAdditionalHeaders(url)); + for (var h in extraHeaders) { + if (extraHeaders.hasOwnProperty(h)) { + this.setRequestHeader(h, extraHeaders[h]); + } + } + } catch (e) { + console.warn('Failed to inject headers from Kotlin (XHR):', e); + } $INTERFACE_NAME.recordXhr( url, lastXmlhttpRequestPrototypeMethod, @@ -331,22 +364,39 @@ XMLHttpRequest.prototype.send = function (body) { window._fetch = window.fetch; window.fetch = function () { const firstArgument = arguments[0]; - let url; - let method; - let body; - let headers; + let url, method, body, headers; if (typeof firstArgument === 'string') { - url = firstArgument; + url = appendAdditionalQueryParams(firstArgument); + if (!arguments[1]) arguments[1] = {}; method = arguments[1] && 'method' in arguments[1] ? arguments[1]['method'] : "GET"; body = arguments[1] && 'body' in arguments[1] ? arguments[1]['body'] : ""; - headers = JSON.stringify(arguments[1] && 'headers' in arguments[1] ? arguments[1]['headers'] : {}); + headers = arguments[1] && 'headers' in arguments[1] ? arguments[1]['headers'] : {}; + // Inject headers from Kotlin if any + try { + var extraHeaders = JSON.parse($INTERFACE_NAME.getAdditionalHeaders(url)); + arguments[1].headers = Object.assign({}, extraHeaders, headers || {}); + } catch (e) { + console.warn('Failed to inject headers from Kotlin (fetch):', e); + } + arguments[0] = url; } else { // Request object - url = firstArgument.url; + url = appendAdditionalQueryParams(firstArgument.url); method = firstArgument.method; body = firstArgument.body; - headers = JSON.stringify(Object.fromEntries(firstArgument.headers.entries())); + headers = Object.fromEntries(firstArgument.headers.entries()); + // Inject headers from Kotlin if any + try { + var extraHeaders = JSON.parse($INTERFACE_NAME.getAdditionalHeaders(url)); + for (var h in extraHeaders) { + firstArgument.headers.set ? firstArgument.headers.set(h, extraHeaders[h]) : firstArgument.headers[h] = extraHeaders[h]; + } + } catch (e) { + console.warn('Failed to inject headers from Kotlin (fetch):', e); + } + firstArgument.url = url; } + const fullUrl = getFullUrl(url); const err = new Error(); $INTERFACE_NAME.recordFetch(fullUrl, method, body, headers, err.stack); diff --git a/app/src/main/java/com/acsbendi/requestinspectorwebview/RequestInspectorWebViewClient.kt b/app/src/main/java/com/acsbendi/requestinspectorwebview/RequestInspectorWebViewClient.kt index 3b895fe..aab9f68 100644 --- a/app/src/main/java/com/acsbendi/requestinspectorwebview/RequestInspectorWebViewClient.kt +++ b/app/src/main/java/com/acsbendi/requestinspectorwebview/RequestInspectorWebViewClient.kt @@ -7,14 +7,17 @@ import android.webkit.WebResourceRequest import android.webkit.WebResourceResponse import android.webkit.WebView import android.webkit.WebViewClient +import com.acsbendi.requestinspectorwebview.matcher.RequestMatcher +import com.acsbendi.requestinspectorwebview.matcher.RequestUrlMatcher @SuppressLint("SetJavaScriptEnabled") open class RequestInspectorWebViewClient @JvmOverloads constructor( webView: WebView, + val matcher: RequestMatcher = RequestUrlMatcher(), private val options: RequestInspectorOptions = RequestInspectorOptions() ) : WebViewClient() { - private val interceptionJavascriptInterface = RequestInspectorJavaScriptInterface(webView) + private val interceptionJavascriptInterface = RequestInspectorJavaScriptInterface(webView, matcher) init { val webSettings = webView.settings @@ -26,10 +29,7 @@ open class RequestInspectorWebViewClient @JvmOverloads constructor( view: WebView, request: WebResourceRequest ): WebResourceResponse? { - val recordedRequest = interceptionJavascriptInterface.findRecordedRequestForUrl( - request.url.toString() - ) - val webViewRequest = WebViewRequest.create(request, recordedRequest) + val webViewRequest = interceptionJavascriptInterface.createWebViewRequest(request) return shouldInterceptRequest(view, webViewRequest) } @@ -48,6 +48,7 @@ open class RequestInspectorWebViewClient @JvmOverloads constructor( override fun onPageStarted(view: WebView, url: String, favicon: Bitmap?) { Log.i(LOG_TAG, "Page started loading, enabling request inspection. URL: $url") + matcher.onPageStarted(url) RequestInspectorJavaScriptInterface.enabledRequestInspection( view, options.extraJavaScriptToInject diff --git a/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestMatcher.kt b/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestMatcher.kt new file mode 100644 index 0000000..73e7e6f --- /dev/null +++ b/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestMatcher.kt @@ -0,0 +1,15 @@ +package com.acsbendi.requestinspectorwebview.matcher + +import com.acsbendi.requestinspectorwebview.RequestInspectorJavaScriptInterface.RecordedRequest +import android.webkit.WebResourceRequest +import com.acsbendi.requestinspectorwebview.WebViewRequest +import org.json.JSONObject + +interface RequestMatcher { + fun addRecordedRequest(recordedRequest: RecordedRequest) + fun createWebViewRequest(request: WebResourceRequest): WebViewRequest + fun getAdditionalHeaders(url: String): JSONObject = JSONObject() + fun getAdditionalQueryParams(): String = "" + fun onPageStarted(url: String) {} +} + diff --git a/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUrlMatcher.kt b/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUrlMatcher.kt new file mode 100644 index 0000000..56151ef --- /dev/null +++ b/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUrlMatcher.kt @@ -0,0 +1,35 @@ +package com.acsbendi.requestinspectorwebview.matcher + +import com.acsbendi.requestinspectorwebview.RequestInspectorJavaScriptInterface.RecordedRequest +import android.webkit.WebResourceRequest +import com.acsbendi.requestinspectorwebview.WebViewRequest + +class RequestUrlMatcher : RequestMatcher { + private val recordedRequests = ArrayList() + + override fun addRecordedRequest(recordedRequest: RecordedRequest) { + synchronized(recordedRequests) { + recordedRequests.add(recordedRequest) + } + } + + override fun createWebViewRequest(request: WebResourceRequest): WebViewRequest { + val recordedRequest = findRecordedRequest(request) + return WebViewRequest.create(request, recordedRequest) + } + + private fun findRecordedRequest(request: WebResourceRequest): RecordedRequest? { + return synchronized(recordedRequests) { + val url = request.url.toString() + // use findLast instead of find to find the last added query matching a URL - + // they are included at the end of the list when written. + recordedRequests.findLast { recordedRequest -> + // Added search by exact URL to find the actual request body + url == recordedRequest.url.toString() + } ?: recordedRequests.findLast { recordedRequest -> + // Previously, there was only a search by contains, and because of this, sometimes the wrong request body was found + url.contains(recordedRequest.url.toString()) + } + } + } +} diff --git a/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUuidInHeaderMatcher.kt b/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUuidInHeaderMatcher.kt new file mode 100644 index 0000000..275ca72 --- /dev/null +++ b/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUuidInHeaderMatcher.kt @@ -0,0 +1,65 @@ +package com.acsbendi.requestinspectorwebview.matcher + +import android.util.Log +import android.webkit.WebResourceRequest +import androidx.core.net.toUri +import com.acsbendi.requestinspectorwebview.RequestInspectorJavaScriptInterface.RecordedRequest +import org.json.JSONObject +import java.util.UUID + +/** + * This matcher only works for NON CORS requests. It adds a unique UUID header to each request + * originating from the WebView, and matches recorded requests based on that header. + * + * It doesn't work for CORS requests, because it changes the headers of the request, which influences the preflight + * request checking for allowed headers. Even when cleaning up the headers after the request is matched with it's body, + * the CORS request will fail because the browser engine only knows about the adapted header and doesn't execute the + * CORS request, because the preflight check doesn't return the custom header as allowed. + */ +class RequestUuidInHeaderMatcher() : RequestUuidMatcher() { + + private var origin: String = "" + + override fun getUuidFromRequest(recordedRequest: RecordedRequest): String? = + recordedRequest.headers[REQUEST_INSPECTOR_ID] + + override fun getUuidFromRequest(webResourceRequest: WebResourceRequest): String? = + webResourceRequest.requestHeaders[REQUEST_INSPECTOR_ID] + + override fun removeUuidFromRequests(request: WebResourceRequest, recordedRequest: RecordedRequest?): Pair { + // Clean up headers by removing REQUEST_ID_HEADER from both requests + val cleanedRequest = object : WebResourceRequest by request { + override fun getRequestHeaders(): Map = + request.requestHeaders.filter { (key, _) -> key != REQUEST_INSPECTOR_ID } + } + val cleanedRecordedRequest = recordedRequest?.copy( + headers = recordedRequest.headers.filter { (key, _) -> key != REQUEST_INSPECTOR_ID } + ) + return cleanedRequest to cleanedRecordedRequest + } + + override fun getAdditionalHeaders(url: String): JSONObject { + val headersJson = JSONObject() + if (getOrigin(url) == origin) { + val uuid = UUID.randomUUID().toString() + headersJson.put(REQUEST_INSPECTOR_ID, uuid) + } else { + Log.i(LOG_TAG, "Recorded CORS to $url, not adding $REQUEST_INSPECTOR_ID") + } + return headersJson + } + + override fun onPageStarted(url: String) { + origin = getOrigin(url) + } + + private fun getOrigin(url: String): String { + val uri = url.toUri() + val port = if (uri.port != -1) ":${uri.port}" else "" + return "${uri.scheme}://${uri.host}$port" + } + + companion object { + private const val LOG_TAG = "RequestUuidMatcher" + } +} diff --git a/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUuidInQueryParamMatcher.kt b/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUuidInQueryParamMatcher.kt new file mode 100644 index 0000000..bdc0807 --- /dev/null +++ b/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUuidInQueryParamMatcher.kt @@ -0,0 +1,39 @@ +package com.acsbendi.requestinspectorwebview.matcher + +import android.webkit.WebResourceRequest +import com.acsbendi.requestinspectorwebview.RequestInspectorJavaScriptInterface +import java.util.UUID + +class RequestUuidInQueryParamMatcher : RequestUuidMatcher() { + + override fun getUuidFromRequest(recordedRequest: RequestInspectorJavaScriptInterface.RecordedRequest): String? = + recordedRequest.url.getQueryParameter(REQUEST_INSPECTOR_ID) + + override fun getUuidFromRequest(webResourceRequest: WebResourceRequest): String? = + webResourceRequest.url.getQueryParameter(REQUEST_INSPECTOR_ID) + + override fun removeUuidFromRequests( + request: WebResourceRequest, + recordedRequest: RequestInspectorJavaScriptInterface.RecordedRequest? + ): Pair { + val originalUrl = request.url + val cleanedUrlBuilder = originalUrl.buildUpon().clearQuery() + for (key in originalUrl.queryParameterNames.filter { it != REQUEST_INSPECTOR_ID }) { + originalUrl.getQueryParameters(key).forEach { paramValue -> + cleanedUrlBuilder.appendQueryParameter(key, paramValue) + } + } + val cleanedUrl = cleanedUrlBuilder.build() + + val cleanedWebResourceRequest = object : WebResourceRequest by request { + override fun getUrl() = cleanedUrl + } + val cleanedRecordedRequest = recordedRequest?.copy(url = cleanedUrl) + return cleanedWebResourceRequest to cleanedRecordedRequest + } + + override fun getAdditionalQueryParams(): String { + val uuid = UUID.randomUUID().toString() + return "$REQUEST_INSPECTOR_ID=$uuid" + } +} diff --git a/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUuidMatcher.kt b/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUuidMatcher.kt new file mode 100644 index 0000000..8e7de8a --- /dev/null +++ b/app/src/main/java/com/acsbendi/requestinspectorwebview/matcher/RequestUuidMatcher.kt @@ -0,0 +1,46 @@ +package com.acsbendi.requestinspectorwebview.matcher + +import android.webkit.WebResourceRequest +import com.acsbendi.requestinspectorwebview.RequestInspectorJavaScriptInterface.RecordedRequest +import com.acsbendi.requestinspectorwebview.WebViewRequest + +abstract class RequestUuidMatcher : RequestMatcher { + + private val recordedRequests = mutableMapOf() + + abstract fun getUuidFromRequest(recordedRequest: RecordedRequest): String? + abstract fun getUuidFromRequest(webResourceRequest: WebResourceRequest): String? + abstract fun removeUuidFromRequests( + request: WebResourceRequest, + recordedRequest: RecordedRequest? + ): Pair + + final override fun addRecordedRequest(recordedRequest: RecordedRequest) { + val id = getUuidFromRequest(recordedRequest) ?: return + + synchronized(recordedRequests) { + recordedRequests[id] = recordedRequest + } + } + + override fun createWebViewRequest(request: WebResourceRequest): WebViewRequest { + val recordedRequest = findRecordedRequest(request) + val (cleanedRequest, cleanedRecordedRequest) = removeUuidFromRequests(request, recordedRequest) + return WebViewRequest.create(cleanedRequest, cleanedRecordedRequest) + } + + + private fun findRecordedRequest(request: WebResourceRequest): RecordedRequest? { + val id = getUuidFromRequest(request) ?: return null + val recordedRequest = synchronized(recordedRequests) { + recordedRequests.remove(id) + } + return recordedRequest + } + + override fun onPageStarted(url: String) {} + + companion object { + const val REQUEST_INSPECTOR_ID = "x-request-inspector-id" + } +} diff --git a/build.gradle.kts b/build.gradle.kts index d60edc0..ab0f969 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -6,8 +6,8 @@ buildscript { mavenCentral() } dependencies { - classpath("com.android.tools.build:gradle:8.4.0") - classpath(kotlin("gradle-plugin", version = "1.6.21")) + classpath("com.android.tools.build:gradle:8.4.2") + classpath(kotlin("gradle-plugin", version = "2.0.21")) // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files @@ -15,5 +15,5 @@ buildscript { } tasks.register("clean", Delete::class) { - delete(rootProject.buildDir) + delete(rootProject.layout.buildDirectory.get().asFile) }