From a013907d65a864bf9a72a7cce2266cb57d536051 Mon Sep 17 00:00:00 2001 From: Ian Rumac Date: Wed, 21 Jan 2026 16:31:12 +0100 Subject: [PATCH] Add custom callbacks message and interface --- .../main/java/com/superwall/sdk/Superwall.kt | 8 + .../sdk/dependencies/DependencyContainer.kt | 4 + .../paywall/presentation/CustomCallback.kt | 72 +++++ .../presentation/CustomCallbackRegistry.kt | 51 +++ .../PaywallPresentationHandler.kt | 34 ++ .../presentation/PublicPresentation.kt | 17 + .../view/webview/messaging/PaywallMessage.kt | 33 ++ .../messaging/PaywallMessageHandler.kt | 114 +++++++ .../view/webview/messaging/PaywallWebEvent.kt | 9 + .../CustomCallbackRegistryTest.kt | 241 +++++++++++++++ .../paywall/view/PaywallMessageHandlerTest.kt | 3 + .../sdk/paywall/view/PaywallViewTest.kt | 3 + .../view/webview/PaywallMessageHandlerTest.kt | 3 + .../PaywallMessageHandlerEdgeCasesTest.kt | 292 ++++++++++++++++++ 14 files changed, 884 insertions(+) create mode 100644 superwall/src/main/java/com/superwall/sdk/paywall/presentation/CustomCallback.kt create mode 100644 superwall/src/main/java/com/superwall/sdk/paywall/presentation/CustomCallbackRegistry.kt create mode 100644 superwall/src/test/java/com/superwall/sdk/paywall/presentation/CustomCallbackRegistryTest.kt diff --git a/superwall/src/main/java/com/superwall/sdk/Superwall.kt b/superwall/src/main/java/com/superwall/sdk/Superwall.kt index ff7a45d9d..6f252c943 100644 --- a/superwall/src/main/java/com/superwall/sdk/Superwall.kt +++ b/superwall/src/main/java/com/superwall/sdk/Superwall.kt @@ -1417,6 +1417,14 @@ class Superwall( message = "Permission requested: ${paywallEvent.permissionType.rawValue}", ) } + + is PaywallWebEvent.RequestCallback -> { + Logger.debug( + LogLevel.debug, + LogScope.paywallView, + message = "Custom callback requested: ${paywallEvent.name}", + ) + } } } } diff --git a/superwall/src/main/java/com/superwall/sdk/dependencies/DependencyContainer.kt b/superwall/src/main/java/com/superwall/sdk/dependencies/DependencyContainer.kt index 297ec2d79..60e16213b 100644 --- a/superwall/src/main/java/com/superwall/sdk/dependencies/DependencyContainer.kt +++ b/superwall/src/main/java/com/superwall/sdk/dependencies/DependencyContainer.kt @@ -66,6 +66,7 @@ import com.superwall.sdk.network.device.DeviceInfo import com.superwall.sdk.network.session.CustomHttpUrlConnection import com.superwall.sdk.paywall.manager.PaywallManager import com.superwall.sdk.paywall.manager.PaywallViewCache +import com.superwall.sdk.paywall.presentation.CustomCallbackRegistry import com.superwall.sdk.paywall.presentation.PaywallInfo import com.superwall.sdk.paywall.presentation.dismiss import com.superwall.sdk.paywall.presentation.get_presentation_result.internallyGetPresentationResult @@ -190,6 +191,7 @@ class DependencyContainer( val googleBillingWrapper: GoogleBillingWrapper internal val reviewManager: ReviewManager internal val userPermissions: UserPermissions + internal val customCallbackRegistry: CustomCallbackRegistry var entitlements: Entitlements internal lateinit var customerInfoManager: CustomerInfoManager @@ -597,6 +599,7 @@ class DependencyContainer( } userPermissions = UserPermissionsImpl(context) + customCallbackRegistry = CustomCallbackRegistry() deepLinkRouter = DeepLinkRouter( @@ -709,6 +712,7 @@ class DependencyContainer( }, userPermissions = userPermissions, getActivity = { activityProvider?.getCurrentActivity() }, + customCallbackRegistry = customCallbackRegistry, ) val state = diff --git a/superwall/src/main/java/com/superwall/sdk/paywall/presentation/CustomCallback.kt b/superwall/src/main/java/com/superwall/sdk/paywall/presentation/CustomCallback.kt new file mode 100644 index 000000000..8361f7879 --- /dev/null +++ b/superwall/src/main/java/com/superwall/sdk/paywall/presentation/CustomCallback.kt @@ -0,0 +1,72 @@ +package com.superwall.sdk.paywall.presentation + +/** + * Defines how the paywall waits for a custom callback response. + */ +enum class CustomCallbackBehavior( + val rawValue: String, +) { + /** + * The paywall waits for the callback to complete before continuing + * the tap action chain. + */ + BLOCKING("blocking"), + + /** + * The paywall continues immediately; the response still triggers + * onSuccess/onFailure handlers in the paywall. + */ + NON_BLOCKING("non-blocking"), + ; + + companion object { + fun fromRaw(rawValue: String): CustomCallbackBehavior? = entries.find { it.rawValue == rawValue } + } +} + +/** + * Represents a custom callback request from the paywall. + * + * @property name The name of the callback being requested. + * @property variables Optional key-value pairs passed from the paywall. + * Values are type-preserved (string/number/boolean). + */ +data class CustomCallback( + val name: String, + val variables: Map?, +) + +/** + * The result status of a custom callback. + */ +enum class CustomCallbackResultStatus( + val rawValue: String, +) { + SUCCESS("success"), + FAILURE("failure"), +} + +/** + * The result to return from a custom callback handler. + * + * @property status Whether the callback succeeded or failed. + * Determines which branch (onSuccess/onFailure) executes in the paywall. + * @property data Optional key-value pairs to return to the paywall. + * Values are type-preserved and accessible as callbacks..data.. + */ +data class CustomCallbackResult( + val status: CustomCallbackResultStatus, + val data: Map? = null, +) { + companion object { + /** + * Creates a success result with optional data. + */ + fun success(data: Map? = null): CustomCallbackResult = CustomCallbackResult(CustomCallbackResultStatus.SUCCESS, data) + + /** + * Creates a failure result with optional data. + */ + fun failure(data: Map? = null): CustomCallbackResult = CustomCallbackResult(CustomCallbackResultStatus.FAILURE, data) + } +} diff --git a/superwall/src/main/java/com/superwall/sdk/paywall/presentation/CustomCallbackRegistry.kt b/superwall/src/main/java/com/superwall/sdk/paywall/presentation/CustomCallbackRegistry.kt new file mode 100644 index 000000000..dde261c16 --- /dev/null +++ b/superwall/src/main/java/com/superwall/sdk/paywall/presentation/CustomCallbackRegistry.kt @@ -0,0 +1,51 @@ +package com.superwall.sdk.paywall.presentation + +import java.util.concurrent.ConcurrentHashMap + +/** + * Registry for custom callback handlers associated with paywall presentations. + * + * Handlers are stored by paywall identifier and should be cleaned up when + * the paywall is dismissed to prevent memory leaks. + */ +class CustomCallbackRegistry { + private val handlers = ConcurrentHashMap CustomCallbackResult>() + + /** + * Registers a custom callback handler for a paywall. + * + * @param paywallIdentifier The unique identifier of the paywall + * @param handler The callback handler to register + */ + fun register( + paywallIdentifier: String, + handler: suspend (CustomCallback) -> CustomCallbackResult, + ) { + handlers[paywallIdentifier] = handler + } + + /** + * Unregisters the custom callback handler for a paywall. + * Should be called when the paywall is dismissed. + * + * @param paywallIdentifier The unique identifier of the paywall + */ + fun unregister(paywallIdentifier: String) { + handlers.remove(paywallIdentifier) + } + + /** + * Gets the custom callback handler for a paywall, if registered. + * + * @param paywallIdentifier The unique identifier of the paywall + * @return The handler, or null if not registered + */ + fun getHandler(paywallIdentifier: String): (suspend (CustomCallback) -> CustomCallbackResult)? = handlers[paywallIdentifier] + + /** + * Clears all registered handlers. + */ + fun clear() { + handlers.clear() + } +} diff --git a/superwall/src/main/java/com/superwall/sdk/paywall/presentation/PaywallPresentationHandler.kt b/superwall/src/main/java/com/superwall/sdk/paywall/presentation/PaywallPresentationHandler.kt index 45371ec63..453851bea 100644 --- a/superwall/src/main/java/com/superwall/sdk/paywall/presentation/PaywallPresentationHandler.kt +++ b/superwall/src/main/java/com/superwall/sdk/paywall/presentation/PaywallPresentationHandler.kt @@ -16,6 +16,9 @@ class PaywallPresentationHandler { // A block called when a paywall is skipped, but no error has occurred internal var onSkipHandler: ((PaywallSkippedReason) -> Unit)? = null + // A block called when the paywall requests a custom callback + internal var onCustomCallbackHandler: (suspend (CustomCallback) -> CustomCallbackResult)? = null + // Sets the handler that will be called when the paywall did present fun onPresent(handler: (PaywallInfo) -> Unit) { onPresentHandler = handler @@ -35,4 +38,35 @@ class PaywallPresentationHandler { fun onSkip(handler: (PaywallSkippedReason) -> Unit) { onSkipHandler = handler } + + /** + * Sets the handler that will be called when the paywall requests a custom callback. + * + * Custom callbacks allow paywalls to request arbitrary actions from the app and + * receive results that determine which branch (onSuccess/onFailure) executes. + * + * @param handler A function that receives a [CustomCallback] containing the callback + * name and optional variables, and returns a [CustomCallbackResult] + * indicating success/failure with optional data. + * + * Example: + * ``` + * handler.onCustomCallback { callback -> + * when (callback.name) { + * "validate_email" -> { + * val email = callback.variables?.get("email") as? String + * if (isValidEmail(email)) { + * CustomCallbackResult.success(mapOf("validated" to true)) + * } else { + * CustomCallbackResult.failure(mapOf("error" to "Invalid email")) + * } + * } + * else -> CustomCallbackResult.failure() + * } + * } + * ``` + */ + fun onCustomCallback(handler: suspend (CustomCallback) -> CustomCallbackResult) { + onCustomCallbackHandler = handler + } } diff --git a/superwall/src/main/java/com/superwall/sdk/paywall/presentation/PublicPresentation.kt b/superwall/src/main/java/com/superwall/sdk/paywall/presentation/PublicPresentation.kt index 52594e38a..29fbb9fb9 100644 --- a/superwall/src/main/java/com/superwall/sdk/paywall/presentation/PublicPresentation.kt +++ b/superwall/src/main/java/com/superwall/sdk/paywall/presentation/PublicPresentation.kt @@ -118,6 +118,8 @@ internal data class RegisterContext( val collectionScope: CoroutineScope, val serialTaskManager: SerialTaskManager, val trackAndPresentContext: TrackAndPresentContext, + val registerCustomCallback: ((String, suspend (CustomCallback) -> CustomCallbackResult) -> Unit)? = null, + val unregisterCustomCallback: ((String) -> Unit)? = null, ) internal data class RegisterRequest( @@ -142,11 +144,20 @@ internal fun registerPaywall( withErrorTracking { when (state) { is PaywallState.Presented -> { + // Register custom callback handler if provided + request.handler?.onCustomCallbackHandler?.let { callbackHandler -> + context.registerCustomCallback?.invoke( + state.paywallInfo.identifier, + callbackHandler, + ) + } request.handler?.onPresentHandler?.invoke(state.paywallInfo) } is PaywallState.Dismissed -> { val (paywallInfo, paywallResult) = state + // Unregister custom callback handler + context.unregisterCustomCallback?.invoke(paywallInfo.identifier) request.handler?.onDismissHandler?.invoke(paywallInfo, paywallResult) when (paywallResult) { is Purchased, is Restored -> { @@ -313,6 +324,12 @@ private fun Superwall.internallyRegister( isPaywallPresented = { isPaywallPresented }, present = { request, publisher -> internallyPresent(request, publisher) }, ), + registerCustomCallback = { paywallId, callbackHandler -> + dependencyContainer.customCallbackRegistry.register(paywallId, callbackHandler) + }, + unregisterCustomCallback = { paywallId -> + dependencyContainer.customCallbackRegistry.unregister(paywallId) + }, ) val registerRequest = diff --git a/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessage.kt b/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessage.kt index dc9a8d381..a1875119c 100644 --- a/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessage.kt +++ b/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessage.kt @@ -4,6 +4,7 @@ import com.superwall.sdk.logger.LogLevel import com.superwall.sdk.logger.LogScope import com.superwall.sdk.logger.Logger import com.superwall.sdk.models.paywall.LocalNotificationType +import com.superwall.sdk.paywall.presentation.CustomCallbackBehavior import com.superwall.sdk.permissions.PermissionType import com.superwall.sdk.storage.core_data.convertFromJsonElement import kotlinx.serialization.json.Json @@ -114,6 +115,13 @@ sealed class PaywallMessage { val permissionType: PermissionType, val requestId: String, ) : PaywallMessage() + + data class RequestCallback( + val requestId: String, + val name: String, + val behavior: CustomCallbackBehavior, + val variables: Map?, + ) : PaywallMessage() } fun parseWrappedPaywallMessages(jsonString: String): Result = @@ -220,6 +228,31 @@ private fun parsePaywallMessage(json: JsonObject): PaywallMessage { ) } + "request_callback" -> { + val requestId = + json["request_id"]?.jsonPrimitive?.contentOrNull + ?: throw IllegalArgumentException("request_callback missing request_id") + val name = + json["name"]?.jsonPrimitive?.contentOrNull + ?: throw IllegalArgumentException("request_callback missing name") + val behaviorRaw = + json["behavior"]?.jsonPrimitive?.contentOrNull + ?: throw IllegalArgumentException("request_callback missing behavior") + val behavior = + CustomCallbackBehavior.fromRaw(behaviorRaw) + ?: throw IllegalArgumentException("Unknown behavior: $behaviorRaw") + val variables = + json["variables"]?.jsonObject?.let { variablesJson -> + variablesJson.convertFromJsonElement() as? Map + } + PaywallMessage.RequestCallback( + requestId = requestId, + name = name, + behavior = behavior, + variables = variables, + ) + } + else -> { throw IllegalArgumentException("Unknown event name: $eventName") } diff --git a/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessageHandler.kt b/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessageHandler.kt index a919c1e7b..71580d7ac 100644 --- a/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessageHandler.kt +++ b/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessageHandler.kt @@ -14,6 +14,10 @@ import com.superwall.sdk.logger.Logger import com.superwall.sdk.misc.MainScope import com.superwall.sdk.models.paywall.LocalNotification import com.superwall.sdk.models.paywall.Paywall +import com.superwall.sdk.paywall.presentation.CustomCallback +import com.superwall.sdk.paywall.presentation.CustomCallbackRegistry +import com.superwall.sdk.paywall.presentation.CustomCallbackResult +import com.superwall.sdk.paywall.presentation.CustomCallbackResultStatus import com.superwall.sdk.paywall.view.PaywallView import com.superwall.sdk.paywall.view.PaywallViewState import com.superwall.sdk.paywall.view.delegate.PaywallLoadingState @@ -68,6 +72,7 @@ class PaywallMessageHandler( private val encodeToB64: (String) -> String, private val userPermissions: UserPermissions, private val getActivity: () -> Activity?, + private val customCallbackRegistry: CustomCallbackRegistry, ) : SendPaywallMessages { private companion object { val selectionString = @@ -233,6 +238,8 @@ class PaywallMessageHandler( is PaywallMessage.RequestPermission -> handleRequestPermission(message) + is PaywallMessage.RequestCallback -> handleRequestCallback(message) + else -> { Logger.debug( LogLevel.error, @@ -618,6 +625,113 @@ class PaywallMessageHandler( passMessageToWebView(base64String = encodeToB64(jsonString)) } + private fun handleRequestCallback(request: PaywallMessage.RequestCallback) { + val paywallIdentifier = messageHandler?.state?.paywall?.identifier ?: "" + + // Emit event to listeners + messageHandler?.eventDidOccur( + PaywallWebEvent.RequestCallback( + name = request.name, + behavior = request.behavior, + requestId = request.requestId, + variables = request.variables, + ), + ) + + // Get the callback handler from the registry + val callbackHandler = customCallbackRegistry.getHandler(paywallIdentifier) + + if (callbackHandler == null) { + Logger.debug( + LogLevel.warn, + LogScope.superwallCore, + "No custom callback handler registered for callback: ${request.name}", + ) + // Send failure response if no handler is registered + ioScope.launch { + sendCallbackResult( + requestId = request.requestId, + name = request.name, + status = CustomCallbackResultStatus.FAILURE, + data = null, + ) + } + return + } + + ioScope.launch { + val result = + try { + val callback = + CustomCallback( + name = request.name, + variables = request.variables, + ) + callbackHandler(callback) + } catch (e: Exception) { + Logger.debug( + LogLevel.error, + LogScope.superwallCore, + "Error executing custom callback: ${e.message}", + error = e, + ) + CustomCallbackResult.failure() + } + + sendCallbackResult( + requestId = request.requestId, + name = request.name, + status = result.status, + data = result.data, + ) + } + } + + /** + * Send a callback_result message back to the webview + */ + private suspend fun sendCallbackResult( + requestId: String, + name: String, + status: CustomCallbackResultStatus, + data: Map?, + ) { + val eventMap = + mutableMapOf( + "event_name" to "callback_result", + "request_id" to requestId, + "name" to name, + "status" to status.rawValue, + ) + + if (data != null) { + eventMap["data"] = data + } + + val eventList = listOf(eventMap) + + val jsonString = + try { + json.encodeToString(eventList.convertToJsonElement()) + } catch (e: Throwable) { + Logger.debug( + LogLevel.error, + LogScope.superwallCore, + "Error encoding callback result: ${e.message}", + error = e, + ) + return + } + + Logger.debug( + LogLevel.debug, + LogScope.superwallCore, + "Sending callback_result: $jsonString", + ) + + passMessageToWebView(base64String = encodeToB64(jsonString)) + } + private fun detectHiddenPaywallEvent( eventName: String, userInfo: Map? = null, diff --git a/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallWebEvent.kt b/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallWebEvent.kt index 71b0ca9f1..6c4f32c4d 100644 --- a/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallWebEvent.kt +++ b/superwall/src/main/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallWebEvent.kt @@ -2,6 +2,7 @@ package com.superwall.sdk.paywall.view.webview.messaging import android.net.Uri import com.superwall.sdk.models.paywall.LocalNotification +import com.superwall.sdk.paywall.presentation.CustomCallbackBehavior import com.superwall.sdk.permissions.PermissionType import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -70,4 +71,12 @@ sealed class PaywallWebEvent { val permissionType: PermissionType, val requestId: String, ) : PaywallWebEvent() + + @SerialName("request_callback") + data class RequestCallback( + val name: String, + val behavior: CustomCallbackBehavior, + val requestId: String, + val variables: Map?, + ) : PaywallWebEvent() } diff --git a/superwall/src/test/java/com/superwall/sdk/paywall/presentation/CustomCallbackRegistryTest.kt b/superwall/src/test/java/com/superwall/sdk/paywall/presentation/CustomCallbackRegistryTest.kt new file mode 100644 index 000000000..ada9a129e --- /dev/null +++ b/superwall/src/test/java/com/superwall/sdk/paywall/presentation/CustomCallbackRegistryTest.kt @@ -0,0 +1,241 @@ +package com.superwall.sdk.paywall.presentation + +import com.superwall.sdk.Given +import com.superwall.sdk.Then +import com.superwall.sdk.When +import kotlinx.coroutines.test.runTest +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNotNull +import org.junit.Assert.assertNull +import org.junit.Before +import org.junit.Test + +class CustomCallbackRegistryTest { + private lateinit var registry: CustomCallbackRegistry + + @Before + fun setUp() { + registry = CustomCallbackRegistry() + } + + @Test + fun register_stores_handler_correctly() = + runTest { + Given("a CustomCallbackRegistry") { + val handlerInvocations = mutableListOf() + val handler: suspend (CustomCallback) -> CustomCallbackResult = { callback -> + handlerInvocations.add(callback) + CustomCallbackResult.success() + } + + When("a handler is registered for a paywall identifier") { + registry.register("paywall_123", handler) + + Then("the handler can be retrieved") { + val retrieved = registry.getHandler("paywall_123") + assertNotNull(retrieved) + } + } + } + } + + @Test + fun getHandler_returns_null_for_unregistered_paywall() = + runTest { + Given("a CustomCallbackRegistry with no handlers") { + When("getting a handler for an unregistered paywall") { + val result = registry.getHandler("unknown_paywall") + + Then("it returns null") { + assertNull(result) + } + } + } + } + + @Test + fun unregister_removes_handler() = + runTest { + Given("a registry with a registered handler") { + val handler: suspend (CustomCallback) -> CustomCallbackResult = { + CustomCallbackResult.success() + } + registry.register("paywall_456", handler) + + When("the handler is unregistered") { + registry.unregister("paywall_456") + + Then("getHandler returns null for that paywall") { + assertNull(registry.getHandler("paywall_456")) + } + } + } + } + + @Test + fun unregister_does_not_affect_other_handlers() = + runTest { + Given("a registry with multiple handlers") { + val handler1: suspend (CustomCallback) -> CustomCallbackResult = { + CustomCallbackResult.success(mapOf("source" to "handler1")) + } + val handler2: suspend (CustomCallback) -> CustomCallbackResult = { + CustomCallbackResult.success(mapOf("source" to "handler2")) + } + registry.register("paywall_1", handler1) + registry.register("paywall_2", handler2) + + When("one handler is unregistered") { + registry.unregister("paywall_1") + + Then("the other handler is still available") { + assertNull(registry.getHandler("paywall_1")) + assertNotNull(registry.getHandler("paywall_2")) + } + } + } + } + + @Test + fun clear_removes_all_handlers() = + runTest { + Given("a registry with multiple handlers") { + val handler: suspend (CustomCallback) -> CustomCallbackResult = { + CustomCallbackResult.success() + } + registry.register("paywall_a", handler) + registry.register("paywall_b", handler) + registry.register("paywall_c", handler) + + When("clear is called") { + registry.clear() + + Then("all handlers are removed") { + assertNull(registry.getHandler("paywall_a")) + assertNull(registry.getHandler("paywall_b")) + assertNull(registry.getHandler("paywall_c")) + } + } + } + } + + @Test + fun register_overwrites_existing_handler() = + runTest { + Given("a registry with an existing handler") { + var firstHandlerCalled = false + var secondHandlerCalled = false + + val firstHandler: suspend (CustomCallback) -> CustomCallbackResult = { + firstHandlerCalled = true + CustomCallbackResult.success(mapOf("handler" to "first")) + } + val secondHandler: suspend (CustomCallback) -> CustomCallbackResult = { + secondHandlerCalled = true + CustomCallbackResult.success(mapOf("handler" to "second")) + } + + registry.register("paywall_x", firstHandler) + + When("a new handler is registered for the same paywall") { + registry.register("paywall_x", secondHandler) + + Then("the new handler replaces the old one") { + val retrieved = registry.getHandler("paywall_x") + assertNotNull(retrieved) + + val callback = CustomCallback(name = "test", variables = null) + val result = retrieved!!.invoke(callback) + + assertEquals(false, firstHandlerCalled) + assertEquals(true, secondHandlerCalled) + assertEquals("second", result.data?.get("handler")) + } + } + } + } + + @Test + fun registered_handler_receives_correct_callback_data() = + runTest { + Given("a registry with a handler that captures callback data") { + var capturedCallback: CustomCallback? = null + + val handler: suspend (CustomCallback) -> CustomCallbackResult = { callback -> + capturedCallback = callback + CustomCallbackResult.success() + } + registry.register("paywall_capture", handler) + + When("the handler is invoked with callback data") { + val callback = + CustomCallback( + name = "validate_email", + variables = mapOf("email" to "test@example.com", "count" to 42), + ) + + val retrieved = registry.getHandler("paywall_capture") + retrieved!!.invoke(callback) + + Then("the handler receives the correct callback data") { + assertNotNull(capturedCallback) + assertEquals("validate_email", capturedCallback!!.name) + assertEquals("test@example.com", capturedCallback!!.variables?.get("email")) + assertEquals(42, capturedCallback!!.variables?.get("count")) + } + } + } + } + + @Test + fun handler_can_return_success_with_data() = + runTest { + Given("a handler that returns success with data") { + val handler: suspend (CustomCallback) -> CustomCallbackResult = { + CustomCallbackResult.success( + mapOf( + "validated" to true, + "score" to 100, + ), + ) + } + registry.register("paywall_success", handler) + + When("the handler is invoked") { + val callback = CustomCallback(name = "test", variables = null) + val retrieved = registry.getHandler("paywall_success") + val result = retrieved!!.invoke(callback) + + Then("it returns success status with the data") { + assertEquals(CustomCallbackResultStatus.SUCCESS, result.status) + assertEquals(true, result.data?.get("validated")) + assertEquals(100, result.data?.get("score")) + } + } + } + } + + @Test + fun handler_can_return_failure_with_data() = + runTest { + Given("a handler that returns failure with error data") { + val handler: suspend (CustomCallback) -> CustomCallbackResult = { + CustomCallbackResult.failure( + mapOf("error" to "Invalid input"), + ) + } + registry.register("paywall_failure", handler) + + When("the handler is invoked") { + val callback = CustomCallback(name = "test", variables = null) + val retrieved = registry.getHandler("paywall_failure") + val result = retrieved!!.invoke(callback) + + Then("it returns failure status with error data") { + assertEquals(CustomCallbackResultStatus.FAILURE, result.status) + assertEquals("Invalid input", result.data?.get("error")) + } + } + } + } +} diff --git a/superwall/src/test/java/com/superwall/sdk/paywall/view/PaywallMessageHandlerTest.kt b/superwall/src/test/java/com/superwall/sdk/paywall/view/PaywallMessageHandlerTest.kt index 147396cdf..2d865005e 100644 --- a/superwall/src/test/java/com/superwall/sdk/paywall/view/PaywallMessageHandlerTest.kt +++ b/superwall/src/test/java/com/superwall/sdk/paywall/view/PaywallMessageHandlerTest.kt @@ -311,6 +311,9 @@ class PaywallMessageHandlerTest { encodeToB64 = { it }, userPermissions = fakeUserPermissions, getActivity = { null }, + customCallbackRegistry = + com.superwall.sdk.paywall.presentation + .CustomCallbackRegistry(), ) val state = diff --git a/superwall/src/test/java/com/superwall/sdk/paywall/view/PaywallViewTest.kt b/superwall/src/test/java/com/superwall/sdk/paywall/view/PaywallViewTest.kt index dbed79195..cc0992fcb 100644 --- a/superwall/src/test/java/com/superwall/sdk/paywall/view/PaywallViewTest.kt +++ b/superwall/src/test/java/com/superwall/sdk/paywall/view/PaywallViewTest.kt @@ -448,6 +448,9 @@ class PaywallViewTest { encodeToB64 = { it }, userPermissions = fakeUserPermissions, getActivity = { null }, + customCallbackRegistry = + com.superwall.sdk.paywall.presentation + .CustomCallbackRegistry(), ) val state = diff --git a/superwall/src/test/java/com/superwall/sdk/paywall/view/webview/PaywallMessageHandlerTest.kt b/superwall/src/test/java/com/superwall/sdk/paywall/view/webview/PaywallMessageHandlerTest.kt index 615ac5a5f..4bc93e973 100644 --- a/superwall/src/test/java/com/superwall/sdk/paywall/view/webview/PaywallMessageHandlerTest.kt +++ b/superwall/src/test/java/com/superwall/sdk/paywall/view/webview/PaywallMessageHandlerTest.kt @@ -120,6 +120,9 @@ class PaywallMessageHandlerTest { encodeToB64 = encodeToB64, userPermissions = FakeUserPermissions(), getActivity = { null }, + customCallbackRegistry = + com.superwall.sdk.paywall.presentation + .CustomCallbackRegistry(), ) @Test diff --git a/superwall/src/test/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessageHandlerEdgeCasesTest.kt b/superwall/src/test/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessageHandlerEdgeCasesTest.kt index 40eede2ba..0709817f6 100644 --- a/superwall/src/test/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessageHandlerEdgeCasesTest.kt +++ b/superwall/src/test/java/com/superwall/sdk/paywall/view/webview/messaging/PaywallMessageHandlerEdgeCasesTest.kt @@ -13,6 +13,10 @@ import com.superwall.sdk.models.config.ComputedPropertyRequest import com.superwall.sdk.models.events.EventData import com.superwall.sdk.models.paywall.Paywall import com.superwall.sdk.models.product.ProductVariable +import com.superwall.sdk.paywall.presentation.CustomCallback +import com.superwall.sdk.paywall.presentation.CustomCallbackBehavior +import com.superwall.sdk.paywall.presentation.CustomCallbackRegistry +import com.superwall.sdk.paywall.presentation.CustomCallbackResult import com.superwall.sdk.paywall.view.PaywallViewState import com.superwall.sdk.paywall.view.webview.templating.models.JsonVariables import com.superwall.sdk.paywall.view.webview.templating.models.Variables @@ -27,6 +31,7 @@ import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.setMain import org.junit.After import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue import org.junit.Before import org.junit.Test @@ -121,6 +126,7 @@ class PaywallMessageHandlerEdgeCasesTest { encodeToB64 = { it }, userPermissions = FakeUserPermissions(), getActivity = { null }, + customCallbackRegistry = CustomCallbackRegistry(), ) // Note: messageHandler is not set @@ -169,6 +175,7 @@ class PaywallMessageHandlerEdgeCasesTest { encodeToB64 = { it }, userPermissions = FakeUserPermissions(), getActivity = { null }, + customCallbackRegistry = CustomCallbackRegistry(), ) handler.messageHandler = delegate @@ -211,6 +218,7 @@ class PaywallMessageHandlerEdgeCasesTest { encodeToB64 = { it }, userPermissions = FakeUserPermissions(), getActivity = { null }, + customCallbackRegistry = CustomCallbackRegistry(), ) handler.messageHandler = delegate @@ -253,6 +261,7 @@ class PaywallMessageHandlerEdgeCasesTest { encodeToB64 = { it }, userPermissions = FakeUserPermissions(), getActivity = { null }, + customCallbackRegistry = CustomCallbackRegistry(), ) handler.messageHandler = delegate @@ -289,6 +298,7 @@ class PaywallMessageHandlerEdgeCasesTest { encodeToB64 = { it }, userPermissions = FakeUserPermissions(), getActivity = { null }, + customCallbackRegistry = CustomCallbackRegistry(), ) handler.messageHandler = delegate @@ -331,6 +341,7 @@ class PaywallMessageHandlerEdgeCasesTest { encodeToB64 = { it }, userPermissions = FakeUserPermissions(), getActivity = { null }, + customCallbackRegistry = CustomCallbackRegistry(), ) handler.messageHandler = delegate @@ -376,6 +387,7 @@ class PaywallMessageHandlerEdgeCasesTest { encodeToB64 = { it }, userPermissions = fakePermissions, getActivity = { null }, + customCallbackRegistry = CustomCallbackRegistry(), ) handler.messageHandler = delegate @@ -425,6 +437,7 @@ class PaywallMessageHandlerEdgeCasesTest { }, userPermissions = FakeUserPermissions(), getActivity = { null }, // No activity available + customCallbackRegistry = CustomCallbackRegistry(), ) handler.messageHandler = delegate @@ -445,4 +458,283 @@ class PaywallMessageHandlerEdgeCasesTest { } } } + + @Test + fun requestCallback_emits_event_with_correct_data() = + runTest { + Given("a handler with a delegate") { + val paywall = Paywall.stub() + val state = PaywallViewState(paywall = paywall, locale = "en-US") + val delegate = RecordingDelegate(state) + val registry = CustomCallbackRegistry() + + val handler = + PaywallMessageHandler( + factory = FakeVariablesFactory(), + options = + object : OptionsFactory { + override fun makeSuperwallOptions(): SuperwallOptions = SuperwallOptions() + }, + track = { _ -> }, + setAttributes = { _ -> }, + getView = { null }, + mainScope = MainScope(Dispatchers.Unconfined), + ioScope = IOScope(Dispatchers.Unconfined), + encodeToB64 = { it }, + userPermissions = FakeUserPermissions(), + getActivity = { null }, + customCallbackRegistry = registry, + ) + handler.messageHandler = delegate + + When("a RequestCallback message is handled") { + handler.handle( + PaywallMessage.RequestCallback( + requestId = "req-123", + name = "validate_email", + behavior = CustomCallbackBehavior.BLOCKING, + variables = mapOf("email" to "test@example.com"), + ), + ) + advanceUntilIdle() + + Then("it emits RequestCallback event with correct data") { + assertEquals(1, delegate.events.size) + val event = delegate.events[0] as PaywallWebEvent.RequestCallback + assertEquals("validate_email", event.name) + assertEquals("req-123", event.requestId) + assertEquals(CustomCallbackBehavior.BLOCKING, event.behavior) + assertEquals("test@example.com", event.variables?.get("email")) + } + } + } + } + + @Test + fun requestCallback_without_registered_handler_sends_failure_result() = + runTest { + Given("a handler with no registered callback handler") { + val paywall = Paywall.stub() + val state = PaywallViewState(paywall = paywall, locale = "en-US") + val delegate = RecordingDelegate(state) + val registry = CustomCallbackRegistry() + val encodedMessages = mutableListOf() + + val handler = + PaywallMessageHandler( + factory = FakeVariablesFactory(), + options = + object : OptionsFactory { + override fun makeSuperwallOptions(): SuperwallOptions = SuperwallOptions() + }, + track = { _ -> }, + setAttributes = { _ -> }, + getView = { null }, + mainScope = MainScope(Dispatchers.Unconfined), + ioScope = IOScope(Dispatchers.Unconfined), + encodeToB64 = { msg -> + encodedMessages.add(msg) + msg + }, + userPermissions = FakeUserPermissions(), + getActivity = { null }, + customCallbackRegistry = registry, + ) + handler.messageHandler = delegate + + When("a RequestCallback message is handled") { + handler.handle( + PaywallMessage.RequestCallback( + requestId = "req-456", + name = "some_callback", + behavior = CustomCallbackBehavior.BLOCKING, + variables = null, + ), + ) + advanceUntilIdle() + + Then("it sends failure result back to webview") { + assertTrue(encodedMessages.isNotEmpty()) + val lastMessage = encodedMessages.last() + assertTrue(lastMessage.contains("callback_result")) + assertTrue(lastMessage.contains("failure")) + assertTrue(lastMessage.contains("req-456")) + } + } + } + } + + @Test + fun requestCallback_invokes_registered_handler_with_correct_data() = + runTest { + Given("a handler with a registered callback handler") { + val paywall = Paywall.stub() + val state = PaywallViewState(paywall = paywall, locale = "en-US") + val delegate = RecordingDelegate(state) + val registry = CustomCallbackRegistry() + var capturedCallback: CustomCallback? = null + + // Register a handler that captures the callback + registry.register(paywall.identifier) { callback -> + capturedCallback = callback + CustomCallbackResult.success() + } + + val handler = + PaywallMessageHandler( + factory = FakeVariablesFactory(), + options = + object : OptionsFactory { + override fun makeSuperwallOptions(): SuperwallOptions = SuperwallOptions() + }, + track = { _ -> }, + setAttributes = { _ -> }, + getView = { null }, + mainScope = MainScope(Dispatchers.Unconfined), + ioScope = IOScope(Dispatchers.Unconfined), + encodeToB64 = { it }, + userPermissions = FakeUserPermissions(), + getActivity = { null }, + customCallbackRegistry = registry, + ) + handler.messageHandler = delegate + + When("a RequestCallback message is handled") { + handler.handle( + PaywallMessage.RequestCallback( + requestId = "req-789", + name = "validate_user", + behavior = CustomCallbackBehavior.NON_BLOCKING, + variables = mapOf("userId" to "user123", "count" to 42), + ), + ) + advanceUntilIdle() + + Then("the registered handler receives the correct callback data") { + assertEquals("validate_user", capturedCallback?.name) + assertEquals("user123", capturedCallback?.variables?.get("userId")) + assertEquals(42, capturedCallback?.variables?.get("count")) + } + } + } + } + + @Test + fun requestCallback_sends_success_result_back_to_webview() = + runTest { + Given("a handler with a registered callback handler that returns success") { + val paywall = Paywall.stub() + val state = PaywallViewState(paywall = paywall, locale = "en-US") + val delegate = RecordingDelegate(state) + val registry = CustomCallbackRegistry() + val encodedMessages = mutableListOf() + + // Register a handler that returns success with data + registry.register(paywall.identifier) { + CustomCallbackResult.success(mapOf("validated" to true, "score" to 100)) + } + + val handler = + PaywallMessageHandler( + factory = FakeVariablesFactory(), + options = + object : OptionsFactory { + override fun makeSuperwallOptions(): SuperwallOptions = SuperwallOptions() + }, + track = { _ -> }, + setAttributes = { _ -> }, + getView = { null }, + mainScope = MainScope(Dispatchers.Unconfined), + ioScope = IOScope(Dispatchers.Unconfined), + encodeToB64 = { msg -> + encodedMessages.add(msg) + msg + }, + userPermissions = FakeUserPermissions(), + getActivity = { null }, + customCallbackRegistry = registry, + ) + handler.messageHandler = delegate + + When("a RequestCallback message is handled") { + handler.handle( + PaywallMessage.RequestCallback( + requestId = "req-success", + name = "check_validation", + behavior = CustomCallbackBehavior.BLOCKING, + variables = null, + ), + ) + advanceUntilIdle() + + Then("it sends success result with data back to webview") { + assertTrue(encodedMessages.isNotEmpty()) + val lastMessage = encodedMessages.last() + assertTrue(lastMessage.contains("callback_result")) + assertTrue(lastMessage.contains("success")) + assertTrue(lastMessage.contains("req-success")) + assertTrue(lastMessage.contains("validated")) + } + } + } + } + + @Test + fun requestCallback_sends_failure_result_on_handler_exception() = + runTest { + Given("a handler with a registered callback handler that throws an exception") { + val paywall = Paywall.stub() + val state = PaywallViewState(paywall = paywall, locale = "en-US") + val delegate = RecordingDelegate(state) + val registry = CustomCallbackRegistry() + val encodedMessages = mutableListOf() + + // Register a handler that throws an exception + registry.register(paywall.identifier) { + throw RuntimeException("Handler error") + } + + val handler = + PaywallMessageHandler( + factory = FakeVariablesFactory(), + options = + object : OptionsFactory { + override fun makeSuperwallOptions(): SuperwallOptions = SuperwallOptions() + }, + track = { _ -> }, + setAttributes = { _ -> }, + getView = { null }, + mainScope = MainScope(Dispatchers.Unconfined), + ioScope = IOScope(Dispatchers.Unconfined), + encodeToB64 = { msg -> + encodedMessages.add(msg) + msg + }, + userPermissions = FakeUserPermissions(), + getActivity = { null }, + customCallbackRegistry = registry, + ) + handler.messageHandler = delegate + + When("a RequestCallback message is handled") { + handler.handle( + PaywallMessage.RequestCallback( + requestId = "req-error", + name = "failing_callback", + behavior = CustomCallbackBehavior.BLOCKING, + variables = null, + ), + ) + advanceUntilIdle() + + Then("it sends failure result back to webview") { + assertTrue(encodedMessages.isNotEmpty()) + val lastMessage = encodedMessages.last() + assertTrue(lastMessage.contains("callback_result")) + assertTrue(lastMessage.contains("failure")) + assertTrue(lastMessage.contains("req-error")) + } + } + } + } }