Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions pkg/api/dialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package api

import (
"context"
"errors"
"net"
"os"
"strings"
"testing"
"time"
)

// DialContextFunc is a function that dials a context, network, and address.
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)

// RetryDialUntilSuccess will retry every `retryTimeout` until it succeeds.
func RetryDialUntilSuccess(retryTimeout time.Duration) DialContextFunc {
return func(ctx context.Context, network, address string) (net.Conn, error) {
for {
dialer := &net.Dialer{
Timeout: retryTimeout,
KeepAlive: 30 * time.Second, // Similar to the default HTTP dialer.
}
c, err := dialer.DialContext(ctx, network, address)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
continue
}
if errors.Is(err, os.ErrDeadlineExceeded) {
continue
}
Comment on lines +26 to +31
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

understanding check: these are the only errors we're retrying for and others will just get returned below?

// Testing hook.
if testing.Testing() && strings.Contains(err.Error(), "connection refused") {
continue
}
}
return c, err
}
}
}

// DialNoTimeout will block with no timeout or until the context is canceled.
func DialNoTimeout() DialContextFunc {
dialer := &net.Dialer{}
return dialer.DialContext
}

// DefaultHTTPDialer has the same options as the default HTTP dialer.
func DefaultHTTPDialer() DialContextFunc {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
return dialer.DialContext
}

// RacingDialer is a custom dialer that attempts to connect to a given address.
//
// It uses two different dialers.
// The dialer connects first is returned, and the other is canceled.
//
// The first has a short timeout (200 ms) and continues to retry until it succeeds.
// The second dialer has no timeout and will block until it either succeeds or fails.
//
// We are doing this because we see connection timeouts perhaps caused by some competing network routes.
// Our workaround is to use a short timeout dialer that will retry until it succeeds.
Comment on lines +59 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We dont need a retry limit bc the second one will eventually fail or succeed? Do we want to set some maximumum limit or log the timing so if this does end up being quite long we can see what the hold up is?

func RacingDialer(dialers ...DialContextFunc) DialContextFunc {
if len(dialers) == 0 {
return DialNoTimeout()
}

return func(ctx context.Context, network, address string) (net.Conn, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

type dialResult struct {
conn net.Conn
err error
}
resultCh := make(chan dialResult, len(dialers))
for _, dialer := range dialers {
go func(d DialContextFunc) {
c, err := d(ctx, network, address)
resultCh <- dialResult{conn: c, err: err}
}(dialer)
}

var connError error
for range len(dialers) {
res := <-resultCh
if res.err == nil {
cancel()
return res.conn, nil
} else {
connError = res.err
}
}
return nil, connError
}
}
98 changes: 98 additions & 0 deletions pkg/api/dialer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package api

import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
)

func TestRetryDialUntilSuccess(t *testing.T) {
// "reserve" a port.
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen on random port: %v", err)
}
port := ln.Addr().(*net.TCPAddr).Port
_ = ln.Close()

ctx, cancel := context.WithCancel(t.Context())
defer cancel()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
// Wait a a bit so that the dialer will retry a few times.
time.Sleep(50 * time.Millisecond)

// Start a server to listen on the reserved port.
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
cancel()
return
}
t.Log("listener", listener.Addr())
defer listener.Close()
conn, err := listener.Accept()
if err != nil {
cancel()
return
}
defer conn.Close()
}()

dialer := RetryDialUntilSuccess(10 * time.Millisecond)
conn, err := dialer(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
conn.Close()
wg.Wait()
}

func TestRacingDialer(t *testing.T) {
// "reserve" a port.
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen on random port: %v", err)
}
port := ln.Addr().(*net.TCPAddr).Port
_ = ln.Close()

ctx, cancel := context.WithCancel(t.Context())
defer cancel()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
// Wait a a bit so that the dialer will retry a few times.
time.Sleep(50 * time.Millisecond)

// Start a server to listen on the reserved port.
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
cancel()
return
}
t.Log("listener", listener.Addr())
defer listener.Close()
conn, err := listener.Accept()
if err != nil {
cancel()
return
}
defer conn.Close()
}()

dialer := RacingDialer(DialNoTimeout(), RetryDialUntilSuccess(10*time.Millisecond))
conn, err := dialer(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
conn.Close()
wg.Wait()
}
20 changes: 18 additions & 2 deletions pkg/api/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"os"
"strings"
"time"

"buf.build/gen/go/depot/api/connectrpc/go/depot/core/v1/corev1connect"
"connectrpc.com/connect"
Expand Down Expand Up @@ -84,8 +85,23 @@ func getHTTPClient(baseURL string) *http.Client {
},
}
}
// Use default client for HTTPS connections
return http.DefaultClient

t, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return http.DefaultClient
}

transport := t.Clone()
transport.DialContext = RacingDialer(
RetryDialUntilSuccess(500*time.Millisecond),
DefaultHTTPDialer(),
)

racingClient := &http.Client{
Transport: transport,
}

return racingClient
}

func getBaseURL() string {
Expand Down
Loading