Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 11 additions & 33 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,20 @@ import (
"net/http"
)

const (
defaultMaxRetryCount = 5
)

// NewDefaultClient returns a default http client with retry functionality wrapped around the Roundtripper (client.Transport).
//
// You should not replace the client.Transport field, otherwise you will lose the retry functionality.
//
// If you need to set / change the original client.Transport field you have two options:
//
// 1. create your own http client and use NewCustomClient() function to enrich the client with retry functionality.
// client := &http.Client{}
// client.Transport = &http.Transport{ ... }
// retryClient := httpretry.NewCustomClient(client)
// 2. use one of the helper functions (e.g. httpretry.ModifyOriginalTransport(retryClient)) to retrieve and change the Transport.
// retryClient := httpretry.NewDefaultClient()
// err := httpretry.ModifyOriginalTransport(retryClient, func(t *http.Transport){t.TLSHandshakeTimeout = 5 * time.Second})
// if err != nil { ... } // will be nil if embedded Roundtripper was not of type http.Transport
// 1. create your own http client and use NewCustomClient() function to enrich the client with retry functionality.
// client := &http.Client{}
// client.Transport = &http.Transport{ ... }
// retryClient := httpretry.NewCustomClient(client)
// 2. use one of the helper functions (e.g. httpretry.ModifyOriginalTransport(retryClient)) to retrieve and change the Transport.
// retryClient := httpretry.NewDefaultClient()
// err := httpretry.ModifyOriginalTransport(retryClient, func(t *http.Transport){t.TLSHandshakeTimeout = 5 * time.Second})
// if err != nil { ... } // will be nil if embedded Roundtripper was not of type http.Transport
func NewDefaultClient(opts ...Option) *http.Client {
return NewCustomClient(&http.Client{}, opts...)
}
Expand All @@ -33,32 +29,14 @@ func NewDefaultClient(opts ...Option) *http.Client {
//
// If you need to change the original client.Transport field you may use the helper functions:
//
// err := httpretry.ModifyTransport(retryClient, func(t *http.Transport){t.TLSHandshakeTimeout = 5 * time.Second})
// if err != nil { ... } // will be nil if embedded Roundtripper was not of type http.Transport
// err := httpretry.ModifyTransport(retryClient, func(t *http.Transport){t.TLSHandshakeTimeout = 5 * time.Second})
// if err != nil { ... } // will be nil if embedded Roundtripper was not of type http.Transport
func NewCustomClient(client *http.Client, opts ...Option) *http.Client {
if client == nil {
panic("client must not be nil")
}

nextRoundtripper := client.Transport
if nextRoundtripper == nil {
nextRoundtripper = http.DefaultTransport
}

// set defaults
retryRoundtripper := &RetryRoundtripper{
Next: nextRoundtripper,
MaxRetryCount: defaultMaxRetryCount,
ShouldRetry: defaultRetryPolicy,
CalculateBackoff: defaultBackoffPolicy,
}

// overwrite defaults with user provided configuration
for _, o := range opts {
o(retryRoundtripper)
}

client.Transport = retryRoundtripper
client.Transport = NewRoundtripper(client.Transport, opts...)

return client
}
Expand Down
25 changes: 25 additions & 0 deletions roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import (
"time"
)

const (
defaultMaxRetryCount = 5
)

// RetryRoundtripper is the roundtripper that will wrap around the actual http.Transport roundtripper
// to enrich the http client with retry functionality.
type RetryRoundtripper struct {
Expand All @@ -16,6 +20,27 @@ type RetryRoundtripper struct {
CalculateBackoff BackoffPolicy
}

// NewRoundtripper creates a new RetryRoundtripper with the provided options.
// If no next [net/http.RoundTripper] is provided, the default [net/http.DefaultTransport] will be used.
func NewRoundtripper(next http.RoundTripper, opts ...Option) *RetryRoundtripper {
if next == nil {
next = http.DefaultTransport
}

roundTripper := &RetryRoundtripper{
Next: next,
MaxRetryCount: defaultMaxRetryCount,
ShouldRetry: defaultRetryPolicy,
CalculateBackoff: defaultBackoffPolicy,
}

for _, o := range opts {
o(roundTripper)
}

return roundTripper
}

// RoundTrip implements the actual roundtripper interface (http.RoundTripper).
func (r *RetryRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
var (
Expand Down
35 changes: 31 additions & 4 deletions roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,42 @@ package httpretry
import (
"bytes"
"context"
"github.com/stretchr/testify/assert"
"io"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestNewNewRoundtripper(t *testing.T) {
check := assert.New(t)

t.Run("should create default rountriper", func(t *testing.T) {
customTransport := &http.Transport{}
roundTripper := NewRoundtripper(customTransport)

check.Equal(customTransport, roundTripper.Next)
check.Equal(5, roundTripper.MaxRetryCount)
check.NotNil(roundTripper.CalculateBackoff)
check.NotNil(roundTripper.ShouldRetry)
})

t.Run("should use default http transport if nil next provided", func(t *testing.T) {
roundTripper := NewRoundtripper(nil)
check.Equal(http.DefaultTransport, roundTripper.Next)
})

t.Run("should apply options", func(t *testing.T) {
maxRetryCount := 2

roundTripper := NewRoundtripper(nil, WithMaxRetryCount(maxRetryCount))

check.Equal(maxRetryCount, roundTripper.MaxRetryCount)
})
}

func TestRetryRoundtripperSimple(t *testing.T) {
check := assert.New(t)

Expand Down Expand Up @@ -195,7 +222,7 @@ func TestRetryRoundtripperWithBody(t *testing.T) {

func readerContains(t *testing.T, r io.Reader, substring string) bool {
t.Helper()
d, err := ioutil.ReadAll(r)
d, err := io.ReadAll(r)
if err != nil {
t.Fatal("could not read body: ", err.Error())
}
Expand Down Expand Up @@ -268,7 +295,7 @@ func FakeResponse(req *http.Request, code int, body []byte) *http.Response {
var contentLength int64 = -1

if len(body) != 0 {
bodyReadCloser = ioutil.NopCloser(bytes.NewReader(body))
bodyReadCloser = io.NopCloser(bytes.NewReader(body))
contentLength = int64(len(body))
}

Expand Down