Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cmd/dbc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ type cmds struct {
Auth *AuthCmd `arg:"subcommand" help:"Manage driver registry credentials"`
Completion *completions.Cmd `arg:"subcommand,hidden"`
Quiet bool `arg:"-q,--quiet" help:"Suppress all output"`
Proxy string `arg:"--proxy" help:"Proxy server URL for HTTP requests"`
}

func (cmds) Version() string {
Expand Down Expand Up @@ -189,6 +190,13 @@ func main() {
}
}

if args.Proxy != "" {
if err := dbc.SetProxy(args.Proxy); err != nil {
fmt.Fprintf(os.Stderr, "Error setting proxy: %v\n", err)
os.Exit(1)
}
}

if p.Subcommand() == nil {
p.WriteHelp(os.Stdout)
os.Exit(1)
Expand Down
28 changes: 28 additions & 0 deletions drivers.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,34 @@ func (u *uaRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return u.RoundTripper.RoundTrip(req)
}

// SetProxy configures the HTTP client to use the specified proxy server.
// If proxy is empty, it uses the default transport (which may still respect HTTP_PROXY env var).
func SetProxy(proxy string) error {
var transport http.RoundTripper = http.DefaultTransport
if proxy != "" {
proxyURL, err := url.Parse(proxy)
if err != nil {
return fmt.Errorf("invalid proxy URL: %w", err)
}
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}

// Preserve the user agent from the current transport
var userAgent string
if ua, ok := DefaultClient.Transport.(*uaRoundTripper); ok {
userAgent = ua.userAgent
} else {
// Fallback, should not happen
userAgent = "dbc-cli"
}

DefaultClient.Transport = &uaRoundTripper{
RoundTripper: transport,
userAgent: userAgent,
}
return nil
}

func init() {
info, ok := debug.ReadBuildInfo()
if ok {
Expand Down
63 changes: 63 additions & 0 deletions drivers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2026 Columnar Technologies Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dbc

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSetProxy(t *testing.T) {
// Save original transport
originalTransport := DefaultClient.Transport

t.Cleanup(func() {
DefaultClient.Transport = originalTransport
})

t.Run("valid proxy URL", func(t *testing.T) {
err := SetProxy("http://proxy.example.com:8080")
require.NoError(t, err)

// Check that transport is set
ua, ok := DefaultClient.Transport.(*uaRoundTripper)
require.True(t, ok, "transport should be uaRoundTripper")

transport, ok := ua.RoundTripper.(*http.Transport)
require.True(t, ok, "inner transport should be http.Transport")

require.NotNil(t, transport.Proxy, "proxy should be set")
})

t.Run("empty proxy", func(t *testing.T) {
err := SetProxy("")
require.NoError(t, err)

ua, ok := DefaultClient.Transport.(*uaRoundTripper)
require.True(t, ok)

// Should use default transport
assert.Equal(t, http.DefaultTransport, ua.RoundTripper)
})

t.Run("invalid proxy URL", func(t *testing.T) {
err := SetProxy("://invalid")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid proxy URL")
})
}
Loading