Skip to content

Conversation

@madhav-db
Copy link
Contributor

@madhav-db madhav-db commented Oct 30, 2025

Adds token federation for databricks sql go driver

Adds automatic token exchange (federation) and caching capabilities:

- CachedTokenProvider: Automatic token refresh with 5min buffer
- FederationProvider: Auto-detects and exchanges external JWT tokens
- Supports both user federation and SP-wide (M2M) federation
- Graceful fallback if token exchange unavailable
- Connector functions: WithFederatedTokenProvider, WithFederatedTokenProviderAndClientID
- Azure domain list updates for staging/dev environments

Token exchange follows RFC 8693 standard.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR implements token federation for the Go driver, enabling automatic token exchange for external identity provider tokens. The implementation includes a FederationProvider that wraps base token providers and intelligently determines when token exchange is needed by comparing JWT issuers with the Databricks host. It also adds a CachedTokenProvider to optimize token refresh operations.

Key changes:

  • Added federation provider with automatic token exchange detection and fallback
  • Implemented comprehensive test coverage for federation scenarios including real-world identity providers
  • Added caching layer for token providers to reduce unnecessary token refreshes

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
connector.go Added public API functions for configuring federated token providers with optional client ID support
auth/tokenprovider/exchange.go Implements core federation logic including JWT validation, host comparison, and token exchange protocol
auth/tokenprovider/federation_test.go Comprehensive test suite covering host comparison, token exchange, caching, and real-world identity providers
auth/tokenprovider/cached.go Generic token caching provider with thread-safe refresh logic
auth/oauth/oauth.go Reorganized Azure domain lists, moving staging/dev domains from tenant map to domain list

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

func (p *FederationProvider) Name() string {
baseName := p.baseProvider.Name()
if p.clientID != "" {
return fmt.Sprintf("federation[%s,sp:%s]", baseName, p.clientID[:8]) // Truncate client ID for readability
Copy link

Copilot AI Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential panic if clientID length is less than 8 characters. Add a length check before slicing or use a safe truncation approach.

Suggested change
return fmt.Sprintf("federation[%s,sp:%s]", baseName, p.clientID[:8]) // Truncate client ID for readability
clientIDDisplay := p.clientID
if len(p.clientID) >= 8 {
clientIDDisplay = p.clientID[:8]
}
return fmt.Sprintf("federation[%s,sp:%s]", baseName, clientIDDisplay) // Truncate client ID for readability

Copilot uses AI. Check for mistakes.
Comment on lines +248 to +250
fedProvider := NewFederationProviderWithClientID(baseProvider, "test.databricks.com", "client-12345678-more")
// Should truncate client ID to first 8 chars
assert.Equal(t, "federation[static,sp:client-1]", fedProvider.Name())
Copy link

Copilot AI Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test assumes client ID will always be at least 8 characters. Add test case for short client IDs (< 8 characters) to verify behavior matches the truncation logic in Name().

Copilot uses AI. Check for mistakes.
func NewCachedTokenProvider(provider TokenProvider) *CachedTokenProvider {
return &CachedTokenProvider{
provider: provider,
RefreshThreshold: 5 * time.Minute,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this quite high?

return &FederationProvider{
baseProvider: baseProvider,
databricksHost: databricksHost,
httpClient: &http.Client{Timeout: 30 * time.Second},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we not have any http clients in sql go that can be used here rather than creating a new one?

Comment on lines +32 to +39
p.mutex.RLock()
cached := p.cache
p.mutex.RUnlock()

if cached != nil && !p.shouldRefresh(cached) {
log.Debug().Msgf("cached token provider: using cached token for provider %s", p.provider.Name())
return cached, nil
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After releasing the read lock, another goroutine could modify cached.ExpiresAt, making the shouldRefresh check stale. Additionally, returning a pointer to the cached token is unsafe.

// Need to refresh
p.mutex.Lock()
defer p.mutex.Unlock()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we check if the context has been cancelled?

}

log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name())
token, err := p.provider.GetToken(ctx)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LLM suggests below:
Issue: Calling GetToken while holding the mutex lock is dangerous. If the underlying provider makes HTTP calls or does other slow operations, it blocks all other goroutines trying to access the cache. This can lead to deadlocks if the provider tries to acquire the same lock.

Suggested change:

// Mark as refreshing to prevent thundering herd
if p.refreshing {
        p.mutex.Unlock()
        time.Sleep(50 * time.Millisecond)
        p.mutex.Lock()
        if p.cache != nil && !p.shouldRefresh(p.cache) {
                return p.cache, nil
        }
}
p.refreshing = true
p.mutex.Unlock()

log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name())
token, err := p.provider.GetToken(ctx)

p.mutex.Lock()
p.refreshing = false
if err != nil {
        return nil, fmt.Errorf("cached token provider: failed to get token: %w", err)
}

p.cache = token
return token, nil

func (p *FederationProvider) tryTokenExchange(ctx context.Context, subjectToken string) (*Token, error) {
// Build exchange URL - add scheme if not present
exchangeURL := p.databricksHost
if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we allow http:// prefix, might be insecure for token exchange.

req.Header.Set("Accept", "*/*")

// Make request
resp, err := p.httpClient.Do(req)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should try to reuse an existing http client so that failures are handled by retries here

if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add validations on token resp here

u2, err2 := url.Parse(parsedURL2)

if err1 != nil || err2 != nil {
return false
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add some warn or debug logs here

require.NoError(t, err2)
assert.Equal(t, "databricks-token", token2.AccessToken)
assert.Equal(t, 1, callCount, "External provider should still be called only once (cached)")
assert.Equal(t, 1, exchangeCount, "Token should still be exchanged only once (cached)")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also add assert that token1 == token2

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants