From 19fa016ddbd6924d0b39539277ca31ee6395b0b7 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 26 Dec 2025 11:45:23 +0000 Subject: [PATCH 1/2] Refactor EventHandler and Transport traits for better performance and ergonomics - Updated `EventHandler::on_event` to take `event: String` (owned) instead of `&str`. This allows handlers to consume the event string without cloning if necessary. - Updated `Transport` trait methods (`get`, `post`) to accept headers as `Option<&[(&str, &str)]>` instead of `Option<(Vec<&str>, Vec<&str>)>`. This avoids unnecessary vector allocations and enforces key-value pairing. - Updated `ureq` backend implementation to support the new `Transport` interface. - Updated `DefaultRuntime` to use the new trait definitions and handle headers more efficiently. - Added error logging to the runtime loop for failed invocation fetches. - Addressed Clippy warnings in `src/api/handler.rs` (doc comment), `src/api/response.rs` (range matching), and `src/backends/ureq.rs` (error handling). - Updated `examples/echo-server.rs` to reflect the `EventHandler` trait change. --- examples/echo-server.rs | 2 +- src/api/handler.rs | 5 +++-- src/api/response.rs | 2 +- src/api/transport.rs | 4 ++-- src/backends/ureq.rs | 21 +++++++-------------- src/runtime/mod.rs | 24 +++++++++++++++++++----- 6 files changed, 33 insertions(+), 25 deletions(-) diff --git a/examples/echo-server.rs b/examples/echo-server.rs index 719f78d..2728fff 100644 --- a/examples/echo-server.rs +++ b/examples/echo-server.rs @@ -35,7 +35,7 @@ impl EventHandler for EchoEventHandler { fn on_event( &mut self, - event: &str, + event: String, context: &Ctx, ) -> Result { // Get the aws request id diff --git a/src/api/handler.rs b/src/api/handler.rs index e5f25bf..b9426fb 100644 --- a/src/api/handler.rs +++ b/src/api/handler.rs @@ -22,12 +22,13 @@ pub trait EventHandler: Sized { /// Processes each incoming lambda event and returns a [`Result`] with the lambda's output. /// # Arguments /// - /// * `event` - The JSON event as a string slice, should be deserialized by the implementation. + /// * `event` - The JSON event as a string, should be deserialized by the implementation. /// * `context` - A shared reference to the current event context. + /// /// `Ctx` Defines the context object type, typically a [`crate::data::context::EventContext`]. fn on_event( &mut self, - event: &str, + event: String, context: &Ctx, ) -> Result; } diff --git a/src/api/response.rs b/src/api/response.rs index 911e16b..6ef8f0a 100644 --- a/src/api/response.rs +++ b/src/api/response.rs @@ -37,6 +37,6 @@ pub trait LambdaAPIResponse { } fn is_err(&self) -> bool { - matches!(self.get_status_code(), 400..=499 | 500..=599) + matches!(self.get_status_code(), 400..=599) } } diff --git a/src/api/transport.rs b/src/api/transport.rs index 8845dfd..e0199e4 100644 --- a/src/api/transport.rs +++ b/src/api/transport.rs @@ -17,13 +17,13 @@ pub trait Transport: Default { &self, url: &str, body: Option<&str>, - headers: Option<(Vec<&str>, Vec<&str>)>, + headers: Option<&[(&str, &str)]>, ) -> Result; /// Sends an HTTP POST request to the specified `url` with the optional `body` and `headers`. fn post( &self, url: &str, body: Option<&str>, - headers: Option<(Vec<&str>, Vec<&str>)>, + headers: Option<&[(&str, &str)]>, ) -> Result; } diff --git a/src/backends/ureq.rs b/src/backends/ureq.rs index 7afca1b..9de9eea 100644 --- a/src/backends/ureq.rs +++ b/src/backends/ureq.rs @@ -31,13 +31,8 @@ impl LambdaAPIResponse for ureq::Response { #[inline] fn get_deadline(&self) -> Option { - match self.header(AWS_DEADLINE_MS) { - Some(ms) => match ms.parse::() { - Ok(val) => Some(val), - Err(_) => None, - }, - None => None, - } + self.header(AWS_DEADLINE_MS) + .and_then(|ms| ms.parse::().ok()) } #[inline] @@ -84,14 +79,12 @@ impl UreqTransport { method: &str, url: &str, body: Option<&str>, - headers: Option<(Vec<&str>, Vec<&str>)>, + headers: Option<&[(&str, &str)]>, ) -> Result { let mut req = self.agent.request(method, url); if let Some(headers) = headers { - let (keys, values) = headers; - let len = std::cmp::min(keys.len(), values.len()); - for i in 0..len { - req = req.set(keys[i], values[i]); + for (key, value) in headers { + req = req.set(key, value); } } if let Some(body) = body { @@ -110,7 +103,7 @@ impl Transport for UreqTransport { &self, url: &str, body: Option<&str>, - headers: Option<(Vec<&str>, Vec<&str>)>, + headers: Option<&[(&str, &str)]>, ) -> Result { self.request("GET", url, body, headers) } @@ -119,7 +112,7 @@ impl Transport for UreqTransport { &self, url: &str, body: Option<&str>, - headers: Option<(Vec<&str>, Vec<&str>)>, + headers: Option<&[(&str, &str)]>, ) -> Result { self.request("POST", url, body, headers) } diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 37cc587..71f728e 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -117,7 +117,10 @@ where // Failing to get the next event will either panic (on server error) or continue with an error (on client-error codes). let next_invo = match self.next_invocation() { // TODO - perhaps log the error - Err(_e) => continue, + Err(e) => { + eprintln!("Failed to get next invocation: {}", e); + continue; + } Ok(resp) => resp, }; @@ -126,12 +129,11 @@ where let event = next_invo.get_body().unwrap(); // Execute the event handler - // TODO - pass the event an an owned String let lambda_output = self .handler .as_mut() .unwrap() - .on_event(&event, &self.context); + .on_event(event, &self.context); let request_id = self.context.get_aws_request_id().unwrap(); // TODO - figure out what we'd like to do with the result returned from success/client-err api responses (e.g: log, run a user defined callback...) @@ -211,7 +213,13 @@ where "http://{}/{}/runtime/init/error", self.api_base, self.version ); - let headers = error_type.map(|et| (vec![AWS_FUNC_ERR_TYPE], vec![et])); + let headers_vec; + let headers = if let Some(et) = error_type { + headers_vec = [(AWS_FUNC_ERR_TYPE, et)]; + Some(&headers_vec[..]) + } else { + None + }; let resp = self.transport.post(&url, error_req, headers)?; handle_response!(resp); @@ -228,7 +236,13 @@ where "http://{}/{}/runtime/invocation/{}/error", self.api_base, self.version, request_id ); - let headers = error_type.map(|et| (vec![AWS_FUNC_ERR_TYPE], vec![et])); + let headers_vec; + let headers = if let Some(et) = error_type { + headers_vec = [(AWS_FUNC_ERR_TYPE, et)]; + Some(&headers_vec[..]) + } else { + None + }; let resp = self.transport.post(&url, error_req, headers)?; handle_response!(resp); From 7bb047052aed93a78622e55b42fd0e4a72020e66 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 26 Dec 2025 12:02:06 +0000 Subject: [PATCH 2/2] Add unit tests and refactor core traits for better testability and ergonomics - Added unit tests for `EventContext` in `src/data/tests.rs` covering environment variable parsing. - Added unit tests for `DefaultRuntime` in `src/runtime/tests.rs` using a `MockTransport` to verify runtime behavior without network calls. - Refactored `Transport` trait to accept headers as `Option<&[(&str, &str)]>` for better ergonomics and performance. - Refactored `EventHandler` trait to take `event: String` (owned) to allow flexible ownership transfer. - Updated `ureq` backend and `examples/echo-server.rs` to match API changes. - Added `serial_test` dependency to ensure thread safety for tests modifying environment variables. - Improved error logging in the runtime loop. --- Cargo.toml | 1 + src/backends/ureq.rs | 2 +- src/data/mod.rs | 2 + src/data/tests.rs | 47 ++++++++++ src/runtime/mod.rs | 21 ++--- src/runtime/tests.rs | 218 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 276 insertions(+), 15 deletions(-) create mode 100644 src/data/tests.rs create mode 100644 src/runtime/tests.rs diff --git a/Cargo.toml b/Cargo.toml index 74f0c23..0ac9992 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ serde_json = { version = "1.0" } [dev-dependencies] serde = { version = "1", features = ["derive"] } +serial_test = "2.0" [features] default = ["ureq"] diff --git a/src/backends/ureq.rs b/src/backends/ureq.rs index 9de9eea..00afdb2 100644 --- a/src/backends/ureq.rs +++ b/src/backends/ureq.rs @@ -83,7 +83,7 @@ impl UreqTransport { ) -> Result { let mut req = self.agent.request(method, url); if let Some(headers) = headers { - for (key, value) in headers { + for (key, value) in headers.iter() { req = req.set(key, value); } } diff --git a/src/data/mod.rs b/src/data/mod.rs index c3959f1..2e0e8a9 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -4,3 +4,5 @@ /// Defines the interface of the context object and provides an implementation for it. pub mod context; +#[cfg(test)] +mod tests; diff --git a/src/data/tests.rs b/src/data/tests.rs new file mode 100644 index 0000000..3a668b9 --- /dev/null +++ b/src/data/tests.rs @@ -0,0 +1,47 @@ +#[cfg(test)] +mod tests { + use crate::api::{InitializationType, LambdaContext, LambdaContextSetter, LambdaEnvVars}; + use crate::data::context::EventContext; + use serial_test::serial; + use std::env; + + #[test] + #[serial] + fn test_context_default_from_env() { + // Set some env vars + env::set_var("AWS_LAMBDA_FUNCTION_NAME", "my-func"); + env::set_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "128"); + env::set_var("AWS_LAMBDA_INITIALIZATION_TYPE", "on-demand"); + env::set_var("_HANDLER", "index.handler"); + + let ctx = EventContext::default(); + + assert_eq!(ctx.get_lambda_function_name(), Some("my-func")); + assert_eq!(ctx.get_lambda_function_memory_size(), Some(128)); + assert!(matches!( + ctx.get_lambda_initialization_type(), + InitializationType::OnDemand + )); + assert_eq!(ctx.get_handler_location(), Some("index.handler")); + + // Clean up + env::remove_var("AWS_LAMBDA_FUNCTION_NAME"); + env::remove_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE"); + env::remove_var("AWS_LAMBDA_INITIALIZATION_TYPE"); + env::remove_var("_HANDLER"); + } + + #[test] + fn test_context_setters() { + let mut ctx = EventContext::default(); + + ctx.set_aws_request_id(Some("req-123")); + assert_eq!(ctx.get_aws_request_id(), Some("req-123")); + + ctx.set_invoked_function_arn(Some("arn:aws:lambda:us-east-1:123456789012:function:my-func")); + assert_eq!( + ctx.get_invoked_function_arn(), + Some("arn:aws:lambda:us-east-1:123456789012:function:my-func") + ); + } +} diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 71f728e..5b62fc9 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -11,6 +11,9 @@ use crate::api::{ use crate::data::context::EventContext; use crate::error::{Error, CONTAINER_ERR}; +#[cfg(test)] +mod tests; + // Already handles any panic inducing errors macro_rules! handle_response { ($resp:expr) => { @@ -213,13 +216,8 @@ where "http://{}/{}/runtime/init/error", self.api_base, self.version ); - let headers_vec; - let headers = if let Some(et) = error_type { - headers_vec = [(AWS_FUNC_ERR_TYPE, et)]; - Some(&headers_vec[..]) - } else { - None - }; + let headers_storage = error_type.map(|et| [(AWS_FUNC_ERR_TYPE, et)]); + let headers = headers_storage.as_ref().map(|h| &h[..]); let resp = self.transport.post(&url, error_req, headers)?; handle_response!(resp); @@ -236,13 +234,8 @@ where "http://{}/{}/runtime/invocation/{}/error", self.api_base, self.version, request_id ); - let headers_vec; - let headers = if let Some(et) = error_type { - headers_vec = [(AWS_FUNC_ERR_TYPE, et)]; - Some(&headers_vec[..]) - } else { - None - }; + let headers_storage = error_type.map(|et| [(AWS_FUNC_ERR_TYPE, et)]); + let headers = headers_storage.as_ref().map(|h| &h[..]); let resp = self.transport.post(&url, error_req, headers)?; handle_response!(resp); diff --git a/src/runtime/tests.rs b/src/runtime/tests.rs new file mode 100644 index 0000000..44bce59 --- /dev/null +++ b/src/runtime/tests.rs @@ -0,0 +1,218 @@ +#[cfg(test)] +mod tests { + use crate::api::{EventHandler, LambdaContext, LambdaAPIResponse, LambdaRuntime, Transport}; + use crate::error::Error; + use crate::runtime::DefaultRuntime; + use std::cell::RefCell; + use std::collections::VecDeque; + use serde::Serialize; + use serial_test::serial; + + // --- Mock Transport --- + + #[derive(Clone)] + struct MockResponse { + body: String, + status: u16, + headers: Vec<(String, String)>, + } + + impl LambdaAPIResponse for MockResponse { + fn get_body(self) -> Result { + Ok(self.body) + } + fn get_status_code(&self) -> u16 { + self.status + } + fn get_aws_request_id(&self) -> Option<&str> { + self.headers.iter().find(|(k, _)| k == "Lambda-Runtime-Aws-Request-Id").map(|(_, v)| v.as_str()) + } + fn get_deadline(&self) -> Option { + None + } + fn get_invoked_function_arn(&self) -> Option<&str> { + None + } + fn get_x_ray_tracing_id(&self) -> Option<&str> { + None + } + fn get_client_context(&self) -> Option<&str> { + None + } + fn get_cognito_identity(&self) -> Option<&str> { + None + } + } + + #[derive(Default)] + struct MockTransport; + + thread_local! { + static MOCK_QUEUE: RefCell> = RefCell::new(VecDeque::new()); + static REQUEST_LOG: RefCell)>> = RefCell::new(Vec::new()); // (method, url, body) + } + + impl MockTransport { + fn push_response(body: &str, status: u16, request_id: Option<&str>) { + let mut headers = Vec::new(); + if let Some(rid) = request_id { + headers.push(("Lambda-Runtime-Aws-Request-Id".to_string(), rid.to_string())); + } + MOCK_QUEUE.with(|q| { + q.borrow_mut().push_back(MockResponse { + body: body.to_string(), + status, + headers, + }) + }); + } + + fn pop_request() -> Option<(String, String, Option)> { + REQUEST_LOG.with(|l| l.borrow_mut().pop()) + } + + fn clear() { + MOCK_QUEUE.with(|q| q.borrow_mut().clear()); + REQUEST_LOG.with(|l| l.borrow_mut().clear()); + } + } + + impl Transport for MockTransport { + type Response = MockResponse; + + fn get( + &self, + url: &str, + body: Option<&str>, + _headers: Option<&[(&str, &str)]>, + ) -> Result { + REQUEST_LOG.with(|l| l.borrow_mut().push(("GET".to_string(), url.to_string(), body.map(|s| s.to_string())))); + MOCK_QUEUE.with(|q| { + q.borrow_mut().pop_front().ok_or(Error::new("No mock response".to_string())) + }) + } + + fn post( + &self, + url: &str, + body: Option<&str>, + headers: Option<&[(&str, &str)]>, + ) -> Result { + let mut logged_headers = Vec::new(); + if let Some(h) = headers { + for (k, v) in h.iter() { + logged_headers.push((k.to_string(), v.to_string())); + } + } + // For now we just log the body/url, but if we wanted to check headers we could add another field to REQUEST_LOG + // Or append to body for verification in tests. + // Let's modify REQUEST_LOG to include headers? + // Or simpler: just log if we find the error header. + let mut body_str = body.map(|s| s.to_string()).unwrap_or_default(); + if let Some(h) = headers { + for (k, v) in h.iter() { + if *k == "Lambda-Runtime-Function-Error-Type" { + body_str.push_str(&format!("|Header:{}:{}", k, v)); + } + } + } + + REQUEST_LOG.with(|l| l.borrow_mut().push(("POST".to_string(), url.to_string(), Some(body_str)))); + MOCK_QUEUE.with(|q| { + q.borrow_mut().pop_front().ok_or(Error::new("No mock response".to_string())) + }) + } + } + + // --- Mock Handler --- + + struct TestEventHandler; + + #[derive(Serialize)] + struct TestOutput { + msg: String, + } + + impl EventHandler for TestEventHandler { + type EventOutput = TestOutput; + type EventError = String; + type InitError = String; + + fn initialize() -> Result { + Ok(TestEventHandler) + } + + fn on_event( + &mut self, + event: String, + _context: &Ctx, + ) -> Result { + if event == "\"fail\"" { + Err("Failed".to_string()) + } else { + Ok(TestOutput { msg: format!("Echo: {}", event) }) + } + } + } + + // --- Tests --- + + #[test] + #[serial] + fn test_runtime_next_invocation() { + MockTransport::clear(); + std::env::set_var("AWS_LAMBDA_RUNTIME_API", "localhost:8080"); + + MockTransport::push_response("\"hello\"", 200, Some("req-1")); + + let mut runtime = DefaultRuntime::::new("2018-06-01"); + + // This fails if MockTransport doesn't work or logic is wrong + let resp = runtime.next_invocation().expect("Failed to get next invocation"); + + assert_eq!(resp.get_aws_request_id(), Some("req-1")); + assert_eq!(resp.get_body().unwrap(), "\"hello\""); + } + + #[test] + #[serial] + fn test_runtime_invocation_response() { + MockTransport::clear(); + std::env::set_var("AWS_LAMBDA_RUNTIME_API", "localhost:8080"); + + MockTransport::push_response("", 202, None); + + let runtime = DefaultRuntime::::new("2018-06-01"); + + let output = TestOutput { msg: "success".to_string() }; + let resp = runtime.invocation_response("req-1", &output).expect("Failed to send response"); + + assert_eq!(resp.get_status_code(), 202); + + let req = MockTransport::pop_request().unwrap(); + assert_eq!(req.0, "POST"); + assert!(req.1.contains("/invocation/req-1/response")); + assert_eq!(req.2.unwrap(), "{\"msg\":\"success\"}"); + } + + #[test] + #[serial] + fn test_runtime_invocation_error() { + MockTransport::clear(); + std::env::set_var("AWS_LAMBDA_RUNTIME_API", "localhost:8080"); + + MockTransport::push_response("", 202, None); + + let runtime = DefaultRuntime::::new("2018-06-01"); + + let resp = runtime.invocation_error("req-1", Some("ErrorType"), Some("ErrorMsg")).expect("Failed to report error"); + + assert_eq!(resp.get_status_code(), 202); + + let req = MockTransport::pop_request().unwrap(); + assert_eq!(req.0, "POST"); + assert!(req.1.contains("/invocation/req-1/error")); + // Check for the error header which we appended to the body string in MockTransport::post + assert!(req.2.unwrap().contains("|Header:Lambda-Runtime-Function-Error-Type:ErrorType")); + } +}