diff --git a/cmd/dbc/main.go b/cmd/dbc/main.go index 45117db..c03905c 100644 --- a/cmd/dbc/main.go +++ b/cmd/dbc/main.go @@ -154,6 +154,7 @@ type cmds struct { Auth *AuthCmd `arg:"subcommand" help:"Manage driver registry credentials"` Completion *completions.Cmd `arg:"subcommand,hidden"` Quiet bool `arg:"-q,--quiet" help:"Suppress all output"` + Proxy string `arg:"--proxy" help:"Proxy server URL for HTTP requests"` } func (cmds) Version() string { @@ -189,6 +190,13 @@ func main() { } } + if args.Proxy != "" { + if err := dbc.SetProxy(args.Proxy); err != nil { + fmt.Fprintf(os.Stderr, "Error setting proxy: %v\n", err) + os.Exit(1) + } + } + if p.Subcommand() == nil { p.WriteHelp(os.Stdout) os.Exit(1) diff --git a/drivers.go b/drivers.go index 28bcf08..e5e7047 100644 --- a/drivers.go +++ b/drivers.go @@ -85,6 +85,34 @@ func (u *uaRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return u.RoundTripper.RoundTrip(req) } +// SetProxy configures the HTTP client to use the specified proxy server. +// If proxy is empty, it uses the default transport (which may still respect HTTP_PROXY env var). +func SetProxy(proxy string) error { + var transport http.RoundTripper = http.DefaultTransport + if proxy != "" { + proxyURL, err := url.Parse(proxy) + if err != nil { + return fmt.Errorf("invalid proxy URL: %w", err) + } + transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} + } + + // Preserve the user agent from the current transport + var userAgent string + if ua, ok := DefaultClient.Transport.(*uaRoundTripper); ok { + userAgent = ua.userAgent + } else { + // Fallback, should not happen + userAgent = "dbc-cli" + } + + DefaultClient.Transport = &uaRoundTripper{ + RoundTripper: transport, + userAgent: userAgent, + } + return nil +} + func init() { info, ok := debug.ReadBuildInfo() if ok { diff --git a/drivers_test.go b/drivers_test.go new file mode 100644 index 0000000..26ece3b --- /dev/null +++ b/drivers_test.go @@ -0,0 +1,63 @@ +// Copyright 2026 Columnar Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dbc + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetProxy(t *testing.T) { + // Save original transport + originalTransport := DefaultClient.Transport + + t.Cleanup(func() { + DefaultClient.Transport = originalTransport + }) + + t.Run("valid proxy URL", func(t *testing.T) { + err := SetProxy("http://proxy.example.com:8080") + require.NoError(t, err) + + // Check that transport is set + ua, ok := DefaultClient.Transport.(*uaRoundTripper) + require.True(t, ok, "transport should be uaRoundTripper") + + transport, ok := ua.RoundTripper.(*http.Transport) + require.True(t, ok, "inner transport should be http.Transport") + + require.NotNil(t, transport.Proxy, "proxy should be set") + }) + + t.Run("empty proxy", func(t *testing.T) { + err := SetProxy("") + require.NoError(t, err) + + ua, ok := DefaultClient.Transport.(*uaRoundTripper) + require.True(t, ok) + + // Should use default transport + assert.Equal(t, http.DefaultTransport, ua.RoundTripper) + }) + + t.Run("invalid proxy URL", func(t *testing.T) { + err := SetProxy("://invalid") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid proxy URL") + }) +}