diff --git a/Makefile b/Makefile index f68987d..56961e0 100644 --- a/Makefile +++ b/Makefile @@ -19,10 +19,11 @@ build: rm -rf build GOOS=linux GOARCH=${ARCH} go build -o build/extensions/firetail-extension-${ARCH} chmod +x build/extensions/firetail-extension-${ARCH} + cp firetail-wrapper.sh build/firetail-wrapper.sh .PHONY: package package: build - cd build && zip -r ../build/firetail-extension-${ARCH}-${VERSION}.zip extensions/ + cd build && zip -r ../build/firetail-extension-${ARCH}-${VERSION}.zip extensions/ firetail-wrapper.sh .PHONY: publish publish: @@ -34,4 +35,4 @@ public: .PHONY: add add: - aws lambda update-function-configuration --region ${AWS_REGION} --function-name ${FUNCTION_NAME} --layers ${LAYER_ARN} \ No newline at end of file + aws lambda update-function-configuration --region ${AWS_REGION} --function-name ${FUNCTION_NAME} --layers ${LAYER_ARN} diff --git a/examples/minimal-python/README.md b/examples/minimal-python/README.md index 66ae386..65e7b86 100644 --- a/examples/minimal-python/README.md +++ b/examples/minimal-python/README.md @@ -21,9 +21,8 @@ This example demonstrates how to setup a simple HTTP GET endpoint. Once you fetc ## Deploy ```bash -pip3 install -t src/vendor -r aws_requirements.txt npm install -serverless deploy +serverless deploy --param firetail-token=YOUR_API_TOKEN ``` The expected result should be similar to: diff --git a/examples/minimal-python/aws_requirements.txt b/examples/minimal-python/aws_requirements.txt deleted file mode 100644 index 94eb19f..0000000 --- a/examples/minimal-python/aws_requirements.txt +++ /dev/null @@ -1 +0,0 @@ -firetail-lambda \ No newline at end of file diff --git a/examples/minimal-python/handler.py b/examples/minimal-python/handler.py index 3e7684d..53b0267 100644 --- a/examples/minimal-python/handler.py +++ b/examples/minimal-python/handler.py @@ -3,21 +3,13 @@ import sys # Deps in src/vendor -sys.path.insert(0, 'src/vendor') +sys.path.insert(0, "src/vendor") -from firetail_lambda import firetail_handler, firetail_app # noqa: E402 -app = firetail_app() - -@firetail_handler(app) def endpoint(event, context): current_time = datetime.datetime.now().time() return { "statusCode": 200, - "body": json.dumps({ - "message": "Hello, the current time is %s" % current_time - }), - "headers": { - "Current-Time": "%s" % current_time - } + "body": json.dumps({"message": "Hello, the current time is %s" % current_time}), + "headers": {"Current-Time": "%s" % current_time}, } diff --git a/examples/minimal-python/serverless.yml b/examples/minimal-python/serverless.yml index 6d8ecb4..d4415c8 100644 --- a/examples/minimal-python/serverless.yml +++ b/examples/minimal-python/serverless.yml @@ -8,6 +8,7 @@ provider: environment: FIRETAIL_API_TOKEN: ${param:firetail-token} FIRETAIL_EXTENSION_DEBUG: TRUE + AWS_LAMBDA_EXEC_WRAPPER: /opt/firetail-wrapper.sh tracing: true iamRoleStatements: - Effect: "Allow" diff --git a/firetail-wrapper.sh b/firetail-wrapper.sh new file mode 100644 index 0000000..c46e6b3 --- /dev/null +++ b/firetail-wrapper.sh @@ -0,0 +1,4 @@ +#!/bin/bash +args=("$@") +export AWS_LAMBDA_RUNTIME_API="127.0.0.1:${FIRETAIL_LAMBDA_EXTENSION_PORT:-9009}" +exec "${args[@]}" diff --git a/firetail/record_receiver.go b/firetail/record_receiver.go new file mode 100644 index 0000000..c66eeb2 --- /dev/null +++ b/firetail/record_receiver.go @@ -0,0 +1,55 @@ +package firetail + +import "log" + +// recordReceiver receives records from the client into batches & passes them to the batch callback. If the batch callback +// returns an err, it does not remove the log entries from the batch. +func RecordReceiver(recordsChannel chan Record, maxBatchSize int, firetailApiUrl, firetailApiToken string) { + recordsBatch := []Record{} + + for { + newRecords, recordsRemaining := receiveRecords(recordsChannel, maxBatchSize-len(recordsBatch)) + recordsBatch = append(recordsBatch, newRecords...) + + // If the batch is empty, but there's records remaining, then we continue; else we return. + if len(recordsBatch) == 0 { + if recordsRemaining { + continue + } else { + return + } + } + + // Give the batch to the batch callback. If it errs, we continue + recordsSent, err := SendRecordsToSaaS(recordsBatch, firetailApiUrl, firetailApiToken) + if err != nil { + log.Println("Error sending records to Firetail:", err.Error()) + continue + } + log.Println("Successfully sent", recordsSent, "record(s) to Firetail.") + + // If the batch callback succeeded, we can clear the batch! + recordsBatch = []Record{} + } +} + +// ReceiveRecords returns a slice of firetail Records up to the size of `limit`, and a boolean indicating that the channel +// still has items to be read - it will only be `false` when the channel is closed & empty. It achieves this by continuously +// reading from the log server's recordsChannel until it's empty, or the size limit has been reached. +func receiveRecords(recordsChannel chan Record, limit int) ([]Record, bool) { + records := []Record{} + for { + select { + case record, open := <-recordsChannel: + if !open { + return records, false + } + records = append(records, record) + if len(records) == limit { + return records, true + } + default: + return records, true + } + } +} diff --git a/go.mod b/go.mod index 95ee593..c927461 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-chi/chi/v5 v5.2.1 github.com/hashicorp/errwrap v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index d2fbd9a..2d8206a 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/aws/aws-lambda-go v1.34.1/go.mod h1:jwFe2KmMsHmffA1X2R09hH6lFzJQxzI8q github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= +github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= diff --git a/main.go b/main.go index 2fccb26..29cbd8a 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,9 @@ package main import ( "firetail-lambda-extension/extensionsapi" + "firetail-lambda-extension/firetail" "firetail-lambda-extension/logsapi" + "firetail-lambda-extension/proxy" "fmt" "io/ioutil" "log" @@ -35,16 +37,36 @@ func main() { } log.Println("Registered extension, ID:", extensionClient.ExtensionID) - // Create a logsApiClient, start it & remember to shut it down when we're done - logsApiClient, err := logsapi.NewClient(logsapi.Options{ - ExtensionID: extensionClient.ExtensionID, - LogServerAddress: "sandbox:1234", - }) - if err != nil { - panic(err) + // In legacy mode, we use the logs API. Otherwise, we use the new proxy client. + if isLegacy, err := strconv.ParseBool(os.Getenv("FIRETAIL_EXTENSION_LEGACY")); err == nil && isLegacy { + // Create a logsApiClient, start it & remember to shut it down when we're done + logsApiClient, err := logsapi.NewClient(logsapi.Options{ + ExtensionID: extensionClient.ExtensionID, + LogServerAddress: "sandbox:1234", + }) + if err != nil { + panic(err) + } + go logsApiClient.Start(ctx) + defer logsApiClient.Shutdown(ctx) + } else { + firetailApiUrl, firetailApiUrlSet := os.LookupEnv("FIRETAIL_API_URL") + if !firetailApiUrlSet { + firetailApiUrl = logsapi.DefaultFiretailApiUrl + } + proxyServer, err := proxy.NewProxyServer() + if err != nil { + panic(err) + } + go proxyServer.ListenAndServe() + defer proxyServer.Shutdown(ctx) + go firetail.RecordReceiver( + proxyServer.RecordsChannel, + logsapi.DefaultMaxBatchSize, + firetailApiUrl, + os.Getenv("FIRETAIL_API_TOKEN"), + ) } - go logsApiClient.Start(ctx) - defer logsApiClient.Shutdown(ctx) // awaitShutdown will block until a shutdown event is received, or the context is cancelled reason, err := awaitShutdown(extensionClient, ctx) diff --git a/proxy/proxy.go b/proxy/proxy.go new file mode 100644 index 0000000..3d017c2 --- /dev/null +++ b/proxy/proxy.go @@ -0,0 +1,184 @@ +package proxy + +import ( + "context" + "encoding/json" + "firetail-lambda-extension/firetail" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" + "strconv" + "time" + + "github.com/go-chi/chi/v5" +) + +type ProxyServer struct { + runtimeEndpoint string + port int + server *http.Server + eventsChannel chan *http.Response + lambdaResponseChannel chan *http.Request + RecordsChannel chan firetail.Record +} + +func NewProxyServer() (*ProxyServer, error) { + portStr, portSet := os.LookupEnv("FIRETAIL_LAMBDA_EXTENSION_PORT") + var port int + var err error + if port, err = strconv.Atoi(portStr); err != nil || !portSet { + port = 9009 + } + + ps := &ProxyServer{ + runtimeEndpoint: os.Getenv("AWS_LAMBDA_RUNTIME_API"), + port: port, + eventsChannel: make(chan *http.Response), + lambdaResponseChannel: make(chan *http.Request), + RecordsChannel: make(chan firetail.Record), + } + + r := chi.NewRouter() + + handleError := func(w http.ResponseWriter, r *http.Request) { + http.Error(w, http.StatusText(404), 404) + } + r.NotFound(handleError) + r.MethodNotAllowed(handleError) + + initEndpoint, err := url.Parse( + fmt.Sprintf( + "http://%s/2018-06-01/runtime/init/error", + ps.runtimeEndpoint, + ), + ) + if err != nil { + return nil, err + } + initErrorHandler := getProxyHandler( + func(r *http.Request) (*url.URL, error) { + return initEndpoint, nil + }, + nil, + nil, + ) + r.Post("/2018-06-01/runtime/init/error", initErrorHandler) + + invokeErrorHandler := getProxyHandler( + func(r *http.Request) (*url.URL, error) { + return url.Parse( + fmt.Sprintf( + "http://%s/2018-06-01/runtime/invocation/%s/error", + ps.runtimeEndpoint, + chi.URLParam(r, "requestId"), + ), + ) + }, + nil, + nil, + ) + r.Post("/2018-06-01/runtime/invocation/{requestId}/error", invokeErrorHandler) + + nextEndpoint, err := url.Parse( + fmt.Sprintf( + "http://%s/2018-06-01/runtime/invocation/next", + ps.runtimeEndpoint, + ), + ) + if err != nil { + return nil, err + } + nextHandler := getProxyHandler( + func(r *http.Request) (*url.URL, error) { + return nextEndpoint, nil + }, + nil, + &ps.eventsChannel, + ) + r.Get("/2018-06-01/runtime/invocation/next", nextHandler) + + responseHandler := getProxyHandler( + func(r *http.Request) (*url.URL, error) { + return url.Parse( + fmt.Sprintf( + "http://%s/2018-06-01/runtime/invocation/%s/response", + ps.runtimeEndpoint, + chi.URLParam(r, "requestId"), + ), + ) + }, + &ps.lambdaResponseChannel, + nil, + ) + r.Post("/2018-06-01/runtime/invocation/{requestId}/response", responseHandler) + + ps.server = &http.Server{ + Addr: fmt.Sprintf(":%d", ps.port), + Handler: r, + } + + return ps, nil +} + +func (p *ProxyServer) recordAssembler() { + for { + // Events and lambda responses should come in pairs, event first and response second. + event, ok := <-p.eventsChannel + if !ok { + log.Println("Events channel closed, stopping record assembler.") + return + } + + // We can record the time between receiving the event and the response + // to calculate the execution time of the lambda function. + eventReceivedAt := time.Now() + + lambdaResponse, ok := <-p.lambdaResponseChannel + if !ok { + log.Println("Lambda response channel closed, stopping record assembler.") + return + } + + executionTime := time.Since(eventReceivedAt) + + eventBody, err := io.ReadAll(event.Body) + if err != nil { + log.Println("Error reading event body:", err) + continue + } + responseBody, err := io.ReadAll(lambdaResponse.Body) + if err != nil { + log.Println("Error reading response body:", err) + continue + } + + var recordResponse firetail.RecordResponse + if err := json.Unmarshal(responseBody, &recordResponse); err != nil { + log.Println("Error unmarshalling response body:", err) + continue + } + + p.RecordsChannel <- firetail.Record{ + Event: eventBody, + Response: recordResponse, + ExecutionTime: executionTime.Seconds(), + } + } +} + +func (p *ProxyServer) ListenAndServe() error { + go p.recordAssembler() + return p.server.ListenAndServe() +} + +func (p *ProxyServer) Shutdown(ctx context.Context) error { + if err := p.server.Shutdown(ctx); err != nil { + return err + } + close(p.eventsChannel) + close(p.lambdaResponseChannel) + return nil +} diff --git a/proxy/proxy_handler.go b/proxy/proxy_handler.go new file mode 100644 index 0000000..f3a6196 --- /dev/null +++ b/proxy/proxy_handler.go @@ -0,0 +1,64 @@ +package proxy + +import ( + "io" + "net/http" + "net/url" + "strings" +) + +func getProxyHandler(urlMappingFunc func(r *http.Request) (*url.URL, error), requestChannel *chan *http.Request, responseChannel *chan *http.Response) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Get the target URL from the mapping function + targetUrl, err := urlMappingFunc(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Set the request URL to the target URL + r.RequestURI = "" + r.Host = targetUrl.Host + r.URL = targetUrl + + // Make a copy of the request body + var requestBodyCopy strings.Builder + r.Body = io.NopCloser(io.TeeReader(r.Body, &requestBodyCopy)) + + // Do the request + resp, err := (&http.Client{}).Do(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Send the request to the requestChannel with the copied body if the channel was provided + if requestChannel != nil { + r.Body = io.NopCloser(strings.NewReader(requestBodyCopy.String())) + *requestChannel <- r + } + + // Make a copy of the response body + var responseBodyCopy strings.Builder + resp.Body = io.NopCloser(io.TeeReader(resp.Body, &responseBodyCopy)) + + // Write the response to the original response writer + defer resp.Body.Close() + for key, value := range resp.Header { + w.Header()[strings.ToLower(key)] = value + } + w.WriteHeader(resp.StatusCode) + body, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Write(body) + + // Send the response to the responseChannel with the copied body if the channel was provided + if responseChannel != nil { + resp.Body = io.NopCloser(strings.NewReader(responseBodyCopy.String())) + *responseChannel <- resp + } + } +}