diff --git a/client.go b/client.go index aa0832c..d2a5c74 100644 --- a/client.go +++ b/client.go @@ -5,24 +5,20 @@ import ( "net/http" ) -const ( - defaultMaxRetryCount = 5 -) - // NewDefaultClient returns a default http client with retry functionality wrapped around the Roundtripper (client.Transport). // // You should not replace the client.Transport field, otherwise you will lose the retry functionality. // // If you need to set / change the original client.Transport field you have two options: // -// 1. create your own http client and use NewCustomClient() function to enrich the client with retry functionality. -// client := &http.Client{} -// client.Transport = &http.Transport{ ... } -// retryClient := httpretry.NewCustomClient(client) -// 2. use one of the helper functions (e.g. httpretry.ModifyOriginalTransport(retryClient)) to retrieve and change the Transport. -// retryClient := httpretry.NewDefaultClient() -// err := httpretry.ModifyOriginalTransport(retryClient, func(t *http.Transport){t.TLSHandshakeTimeout = 5 * time.Second}) -// if err != nil { ... } // will be nil if embedded Roundtripper was not of type http.Transport +// 1. create your own http client and use NewCustomClient() function to enrich the client with retry functionality. +// client := &http.Client{} +// client.Transport = &http.Transport{ ... } +// retryClient := httpretry.NewCustomClient(client) +// 2. use one of the helper functions (e.g. httpretry.ModifyOriginalTransport(retryClient)) to retrieve and change the Transport. +// retryClient := httpretry.NewDefaultClient() +// err := httpretry.ModifyOriginalTransport(retryClient, func(t *http.Transport){t.TLSHandshakeTimeout = 5 * time.Second}) +// if err != nil { ... } // will be nil if embedded Roundtripper was not of type http.Transport func NewDefaultClient(opts ...Option) *http.Client { return NewCustomClient(&http.Client{}, opts...) } @@ -33,32 +29,14 @@ func NewDefaultClient(opts ...Option) *http.Client { // // If you need to change the original client.Transport field you may use the helper functions: // -// err := httpretry.ModifyTransport(retryClient, func(t *http.Transport){t.TLSHandshakeTimeout = 5 * time.Second}) -// if err != nil { ... } // will be nil if embedded Roundtripper was not of type http.Transport +// err := httpretry.ModifyTransport(retryClient, func(t *http.Transport){t.TLSHandshakeTimeout = 5 * time.Second}) +// if err != nil { ... } // will be nil if embedded Roundtripper was not of type http.Transport func NewCustomClient(client *http.Client, opts ...Option) *http.Client { if client == nil { panic("client must not be nil") } - nextRoundtripper := client.Transport - if nextRoundtripper == nil { - nextRoundtripper = http.DefaultTransport - } - - // set defaults - retryRoundtripper := &RetryRoundtripper{ - Next: nextRoundtripper, - MaxRetryCount: defaultMaxRetryCount, - ShouldRetry: defaultRetryPolicy, - CalculateBackoff: defaultBackoffPolicy, - } - - // overwrite defaults with user provided configuration - for _, o := range opts { - o(retryRoundtripper) - } - - client.Transport = retryRoundtripper + client.Transport = NewRoundtripper(client.Transport, opts...) return client } diff --git a/roundtripper.go b/roundtripper.go index 2b769b6..aa772ae 100644 --- a/roundtripper.go +++ b/roundtripper.go @@ -7,6 +7,10 @@ import ( "time" ) +const ( + defaultMaxRetryCount = 5 +) + // RetryRoundtripper is the roundtripper that will wrap around the actual http.Transport roundtripper // to enrich the http client with retry functionality. type RetryRoundtripper struct { @@ -16,6 +20,27 @@ type RetryRoundtripper struct { CalculateBackoff BackoffPolicy } +// NewRoundtripper creates a new RetryRoundtripper with the provided options. +// If no next [net/http.RoundTripper] is provided, the default [net/http.DefaultTransport] will be used. +func NewRoundtripper(next http.RoundTripper, opts ...Option) *RetryRoundtripper { + if next == nil { + next = http.DefaultTransport + } + + roundTripper := &RetryRoundtripper{ + Next: next, + MaxRetryCount: defaultMaxRetryCount, + ShouldRetry: defaultRetryPolicy, + CalculateBackoff: defaultBackoffPolicy, + } + + for _, o := range opts { + o(roundTripper) + } + + return roundTripper +} + // RoundTrip implements the actual roundtripper interface (http.RoundTripper). func (r *RetryRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { var ( diff --git a/roundtripper_test.go b/roundtripper_test.go index 19d0372..683ab9b 100644 --- a/roundtripper_test.go +++ b/roundtripper_test.go @@ -3,15 +3,42 @@ package httpretry import ( "bytes" "context" - "github.com/stretchr/testify/assert" "io" - "io/ioutil" "net/http" "strings" "testing" "time" + + "github.com/stretchr/testify/assert" ) +func TestNewNewRoundtripper(t *testing.T) { + check := assert.New(t) + + t.Run("should create default rountriper", func(t *testing.T) { + customTransport := &http.Transport{} + roundTripper := NewRoundtripper(customTransport) + + check.Equal(customTransport, roundTripper.Next) + check.Equal(5, roundTripper.MaxRetryCount) + check.NotNil(roundTripper.CalculateBackoff) + check.NotNil(roundTripper.ShouldRetry) + }) + + t.Run("should use default http transport if nil next provided", func(t *testing.T) { + roundTripper := NewRoundtripper(nil) + check.Equal(http.DefaultTransport, roundTripper.Next) + }) + + t.Run("should apply options", func(t *testing.T) { + maxRetryCount := 2 + + roundTripper := NewRoundtripper(nil, WithMaxRetryCount(maxRetryCount)) + + check.Equal(maxRetryCount, roundTripper.MaxRetryCount) + }) +} + func TestRetryRoundtripperSimple(t *testing.T) { check := assert.New(t) @@ -195,7 +222,7 @@ func TestRetryRoundtripperWithBody(t *testing.T) { func readerContains(t *testing.T, r io.Reader, substring string) bool { t.Helper() - d, err := ioutil.ReadAll(r) + d, err := io.ReadAll(r) if err != nil { t.Fatal("could not read body: ", err.Error()) } @@ -268,7 +295,7 @@ func FakeResponse(req *http.Request, code int, body []byte) *http.Response { var contentLength int64 = -1 if len(body) != 0 { - bodyReadCloser = ioutil.NopCloser(bytes.NewReader(body)) + bodyReadCloser = io.NopCloser(bytes.NewReader(body)) contentLength = int64(len(body)) }