diff --git a/pkg/api/dialer.go b/pkg/api/dialer.go new file mode 100644 index 00000000..f146597a --- /dev/null +++ b/pkg/api/dialer.go @@ -0,0 +1,100 @@ +package api + +import ( + "context" + "errors" + "net" + "os" + "strings" + "testing" + "time" +) + +// DialContextFunc is a function that dials a context, network, and address. +type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// RetryDialUntilSuccess will retry every `retryTimeout` until it succeeds. +func RetryDialUntilSuccess(retryTimeout time.Duration) DialContextFunc { + return func(ctx context.Context, network, address string) (net.Conn, error) { + for { + dialer := &net.Dialer{ + Timeout: retryTimeout, + KeepAlive: 30 * time.Second, // Similar to the default HTTP dialer. + } + c, err := dialer.DialContext(ctx, network, address) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + continue + } + if errors.Is(err, os.ErrDeadlineExceeded) { + continue + } + // Testing hook. + if testing.Testing() && strings.Contains(err.Error(), "connection refused") { + continue + } + } + return c, err + } + } +} + +// DialNoTimeout will block with no timeout or until the context is canceled. +func DialNoTimeout() DialContextFunc { + dialer := &net.Dialer{} + return dialer.DialContext +} + +// DefaultHTTPDialer has the same options as the default HTTP dialer. +func DefaultHTTPDialer() DialContextFunc { + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + return dialer.DialContext +} + +// RacingDialer is a custom dialer that attempts to connect to a given address. +// +// It uses two different dialers. +// The dialer connects first is returned, and the other is canceled. +// +// The first has a short timeout (200 ms) and continues to retry until it succeeds. +// The second dialer has no timeout and will block until it either succeeds or fails. +// +// We are doing this because we see connection timeouts perhaps caused by some competing network routes. +// Our workaround is to use a short timeout dialer that will retry until it succeeds. +func RacingDialer(dialers ...DialContextFunc) DialContextFunc { + if len(dialers) == 0 { + return DialNoTimeout() + } + + return func(ctx context.Context, network, address string) (net.Conn, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + type dialResult struct { + conn net.Conn + err error + } + resultCh := make(chan dialResult, len(dialers)) + for _, dialer := range dialers { + go func(d DialContextFunc) { + c, err := d(ctx, network, address) + resultCh <- dialResult{conn: c, err: err} + }(dialer) + } + + var connError error + for range len(dialers) { + res := <-resultCh + if res.err == nil { + cancel() + return res.conn, nil + } else { + connError = res.err + } + } + return nil, connError + } +} diff --git a/pkg/api/dialer_test.go b/pkg/api/dialer_test.go new file mode 100644 index 00000000..ae64383c --- /dev/null +++ b/pkg/api/dialer_test.go @@ -0,0 +1,98 @@ +package api + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" +) + +func TestRetryDialUntilSuccess(t *testing.T) { + // "reserve" a port. + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("failed to listen on random port: %v", err) + } + port := ln.Addr().(*net.TCPAddr).Port + _ = ln.Close() + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + // Wait a a bit so that the dialer will retry a few times. + time.Sleep(50 * time.Millisecond) + + // Start a server to listen on the reserved port. + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + cancel() + return + } + t.Log("listener", listener.Addr()) + defer listener.Close() + conn, err := listener.Accept() + if err != nil { + cancel() + return + } + defer conn.Close() + }() + + dialer := RetryDialUntilSuccess(10 * time.Millisecond) + conn, err := dialer(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + conn.Close() + wg.Wait() +} + +func TestRacingDialer(t *testing.T) { + // "reserve" a port. + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("failed to listen on random port: %v", err) + } + port := ln.Addr().(*net.TCPAddr).Port + _ = ln.Close() + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + // Wait a a bit so that the dialer will retry a few times. + time.Sleep(50 * time.Millisecond) + + // Start a server to listen on the reserved port. + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + cancel() + return + } + t.Log("listener", listener.Addr()) + defer listener.Close() + conn, err := listener.Accept() + if err != nil { + cancel() + return + } + defer conn.Close() + }() + + dialer := RacingDialer(DialNoTimeout(), RetryDialUntilSuccess(10*time.Millisecond)) + conn, err := dialer(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + conn.Close() + wg.Wait() +} diff --git a/pkg/api/rpc.go b/pkg/api/rpc.go index 07d911c9..543e7eeb 100644 --- a/pkg/api/rpc.go +++ b/pkg/api/rpc.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "strings" + "time" "buf.build/gen/go/depot/api/connectrpc/go/depot/core/v1/corev1connect" "connectrpc.com/connect" @@ -84,8 +85,23 @@ func getHTTPClient(baseURL string) *http.Client { }, } } - // Use default client for HTTPS connections - return http.DefaultClient + + t, ok := http.DefaultTransport.(*http.Transport) + if !ok { + return http.DefaultClient + } + + transport := t.Clone() + transport.DialContext = RacingDialer( + RetryDialUntilSuccess(500*time.Millisecond), + DefaultHTTPDialer(), + ) + + racingClient := &http.Client{ + Transport: transport, + } + + return racingClient } func getBaseURL() string {