Token federation for Go driver (2/3) #291
Conversation
Adds automatic token exchange (federation) and caching capabilities:
- CachedTokenProvider: Automatic token refresh with 5min buffer
- FederationProvider: Auto-detects and exchanges external JWT tokens
- Supports both user federation and SP-wide (M2M) federation
- Graceful fallback if token exchange unavailable
- Connector functions: WithFederatedTokenProvider, WithFederatedTokenProviderAndClientID
- Azure domain list updates for staging/dev environments

Token exchange follows RFC 8693 standard.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
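For readers unfamiliar with RFC 8693, the following is a minimal sketch of the form body a token exchange request carries. The grant type and token type URNs come from the RFC. The helper name `buildExchangeForm` and the optional `client_id` field are assumptions for illustration and may not match what `exchange.go` actually sends.

```go
package main

import (
	"fmt"
	"net/url"
)

// buildExchangeForm is a hypothetical helper that assembles the form fields
// of an RFC 8693 token exchange request for a JWT subject token.
// Sending client_id for the SP-wide case is an assumption, not confirmed
// by the PR.
func buildExchangeForm(subjectToken, clientID string) url.Values {
	form := url.Values{}
	form.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
	form.Set("subject_token", subjectToken)
	form.Set("subject_token_type", "urn:ietf:params:oauth:token-type:jwt")
	if clientID != "" {
		form.Set("client_id", clientID)
	}
	return form
}

func main() {
	form := buildExchangeForm("header.payload.signature", "")
	// The encoded form would be sent as application/x-www-form-urlencoded.
	fmt.Println(form.Encode())
}
```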
Pull request overview
This PR implements token federation for the Go driver, enabling automatic token exchange for external identity provider tokens. The implementation includes a FederationProvider that wraps base token providers and intelligently determines when token exchange is needed by comparing JWT issuers with the Databricks host. It also adds a CachedTokenProvider to optimize token refresh operations.
Key changes:
- Added federation provider with automatic token exchange detection and fallback
- Implemented comprehensive test coverage for federation scenarios including real-world identity providers
- Added caching layer for token providers to reduce unnecessary token refreshes
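The issuer comparison described in the overview could look roughly like the sketch below. It reads the `iss` claim from the JWT payload without verifying the signature and compares its hostname with the Databricks host. All function names here (`issuerOf`, `hostOf`, `needsExchange`, `makeJWT`) are hypothetical and are not taken from the PR.

```go
package main

import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"net/url"
	"strings"
)

// issuerOf extracts the "iss" claim from a JWT. It does NOT verify the
// signature; it only decodes the payload segment.
func issuerOf(jwt string) (string, error) {
	parts := strings.Split(jwt, ".")
	if len(parts) != 3 {
		return "", errors.New("not a three-part JWT")
	}
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return "", fmt.Errorf("decode payload: %w", err)
	}
	var claims struct {
		Issuer string `json:"iss"`
	}
	if err := json.Unmarshal(payload, &claims); err != nil {
		return "", fmt.Errorf("parse claims: %w", err)
	}
	return claims.Issuer, nil
}

// hostOf returns the lower-cased hostname of raw, assuming https when no
// scheme is given. It returns "" if raw cannot be parsed.
func hostOf(raw string) string {
	if !strings.Contains(raw, "://") {
		raw = "https://" + raw
	}
	u, err := url.Parse(raw)
	if err != nil {
		return ""
	}
	return strings.ToLower(u.Hostname())
}

// needsExchange reports whether the token was issued by a host other than
// the Databricks host, which is the trigger for a token exchange.
func needsExchange(jwt, databricksHost string) (bool, error) {
	iss, err := issuerOf(jwt)
	if err != nil {
		return false, err
	}
	return hostOf(iss) != hostOf(databricksHost), nil
}

// makeJWT builds an unsigned token for demonstration purposes only.
func makeJWT(issuer string) string {
	payload, _ := json.Marshal(map[string]string{"iss": issuer})
	enc := base64.RawURLEncoding.EncodeToString
	return enc([]byte(`{"alg":"none"}`)) + "." + enc(payload) + ".sig"
}

func main() {
	external, _ := needsExchange(makeJWT("https://login.example.com/tenant"), "test.databricks.com")
	internal, _ := needsExchange(makeJWT("https://test.databricks.com/oidc"), "test.databricks.com")
	fmt.Println(external, internal) // prints: true false
}
```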
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| connector.go | Added public API functions for configuring federated token providers with optional client ID support |
| auth/tokenprovider/exchange.go | Implements core federation logic including JWT validation, host comparison, and token exchange protocol |
| auth/tokenprovider/federation_test.go | Comprehensive test suite covering host comparison, token exchange, caching, and real-world identity providers |
| auth/tokenprovider/cached.go | Generic token caching provider with thread-safe refresh logic |
| auth/oauth/oauth.go | Reorganized Azure domain lists, moving staging/dev domains from tenant map to domain list |
| func (p *FederationProvider) Name() string { | ||
| baseName := p.baseProvider.Name() | ||
| if p.clientID != "" { | ||
| return fmt.Sprintf("federation[%s,sp:%s]", baseName, p.clientID[:8]) // Truncate client ID for readability |
Copilot
AI
Dec 9, 2025
Potential panic if clientID length is less than 8 characters. Add a length check before slicing or use a safe truncation approach.
| return fmt.Sprintf("federation[%s,sp:%s]", baseName, p.clientID[:8]) // Truncate client ID for readability | |
| clientIDDisplay := p.clientID | |
| if len(p.clientID) >= 8 { | |
| clientIDDisplay = p.clientID[:8] | |
| } | |
| return fmt.Sprintf("federation[%s,sp:%s]", baseName, clientIDDisplay) // Truncate client ID for readability |
| fedProvider := NewFederationProviderWithClientID(baseProvider, "test.databricks.com", "client-12345678-more") | ||
| // Should truncate client ID to first 8 chars | ||
| assert.Equal(t, "federation[static,sp:client-1]", fedProvider.Name()) |
Copilot
AI
Dec 9, 2025
Test assumes client ID will always be at least 8 characters. Add test case for short client IDs (< 8 characters) to verify behavior matches the truncation logic in Name().
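A table of cases along the following lines would cover the short-ID path the comment asks for. This is a sketch: `displayClientID` is a hypothetical stand-in for the truncation inside `Name()`, and the cases other than the 20-character one from the existing test are invented for illustration.

```go
package main

import "fmt"

// displayClientID is a hypothetical stand-in for the truncation in Name().
// It slices bytes, which is fine for ASCII client IDs such as UUIDs.
func displayClientID(clientID string) string {
	const maxLen = 8
	if len(clientID) > maxLen {
		return clientID[:maxLen]
	}
	return clientID
}

func main() {
	cases := []struct{ in, want string }{
		{"client-12345678-more", "client-1"}, // long ID is truncated
		{"12345678", "12345678"},             // exactly 8 characters
		{"abc", "abc"},                       // short ID must not panic
		{"", ""},                             // empty ID
	}
	for _, c := range cases {
		got := displayClientID(c.in)
		fmt.Printf("%q -> %q (ok=%t)\n", c.in, got, got == c.want)
	}
}
```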
| func NewCachedTokenProvider(provider TokenProvider) *CachedTokenProvider { | ||
| return &CachedTokenProvider{ | ||
| provider: provider, | ||
| RefreshThreshold: 5 * time.Minute, |
isn't this quite high?
| return &FederationProvider{ | ||
| baseProvider: baseProvider, | ||
| databricksHost: databricksHost, | ||
| httpClient: &http.Client{Timeout: 30 * time.Second}, |
do we not have any http clients in sql go that can be used here rather than creating a new one?
| p.mutex.RLock() | ||
| cached := p.cache | ||
| p.mutex.RUnlock() | ||
| if cached != nil && !p.shouldRefresh(cached) { | ||
| log.Debug().Msgf("cached token provider: using cached token for provider %s", p.provider.Name()) | ||
| return cached, nil | ||
| } |
After releasing the read lock, another goroutine could modify cached.ExpiresAt, making the shouldRefresh check stale. Additionally, returning a pointer to the cached token is unsafe.
| // Need to refresh | ||
| p.mutex.Lock() | ||
| defer p.mutex.Unlock() | ||
should we check if the context has been cancelled?
| } | ||
| log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name()) | ||
| token, err := p.provider.GetToken(ctx) |
LLM suggests below:
Issue: Calling GetToken while holding the mutex lock is dangerous. If the underlying provider makes HTTP calls or does other slow operations, it blocks all other goroutines trying to access the cache. This can lead to deadlocks if the provider tries to acquire the same lock.
Suggested change:
// Mark as refreshing to prevent thundering herd
if p.refreshing {
    p.mutex.Unlock()
    time.Sleep(50 * time.Millisecond)
    p.mutex.Lock()
    if p.cache != nil && !p.shouldRefresh(p.cache) {
        return p.cache, nil
    }
}
p.refreshing = true
p.mutex.Unlock()
log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name())
token, err := p.provider.GetToken(ctx)
p.mutex.Lock()
p.refreshing = false
if err != nil {
    return nil, fmt.Errorf("cached token provider: failed to get token: %w", err)
}
p.cache = token
return token, nil
| func (p *FederationProvider) tryTokenExchange(ctx context.Context, subjectToken string) (*Token, error) { | ||
| // Build exchange URL - add scheme if not present | ||
| exchangeURL := p.databricksHost | ||
| if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") { |
Should we allow the http:// prefix? Sending tokens over plain HTTP might be insecure for token exchange.
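If the answer is to reject plain HTTP, the host normalization could be sketched as below. `normalizeHost` is a hypothetical helper. Tests that use a local `http://` test server would need an explicit opt-out, which this sketch does not include.

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeHost is a hypothetical helper that returns an https base URL for
// host. It defaults to https when no scheme is given and refuses http so
// that tokens are never sent in clear text.
func normalizeHost(host string) (string, error) {
	switch {
	case strings.HasPrefix(host, "https://"):
		return host, nil
	case strings.HasPrefix(host, "http://"):
		return "", fmt.Errorf("refusing plain-HTTP host %q for token exchange", host)
	default:
		return "https://" + host, nil
	}
}

func main() {
	for _, h := range []string{"test.databricks.com", "https://test.databricks.com", "http://test.databricks.com"} {
		u, err := normalizeHost(h)
		fmt.Printf("%q -> %q, err=%v\n", h, u, err)
	}
}
```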
| req.Header.Set("Accept", "*/*") | ||
| // Make request | ||
| resp, err := p.httpClient.Do(req) |
we should try to reuse an existing http client so that failures are handled by retries here
| if err := json.Unmarshal(body, &tokenResp); err != nil { | ||
| return nil, fmt.Errorf("failed to parse response: %w", err) | ||
| } | ||
should we add validations on token resp here
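Validation could be as small as the sketch below. The field names follow the OAuth token response format that RFC 8693 builds on, where `access_token` and `token_type` are required. The `tokenResponse` struct and `parseTokenResponse` function are assumptions and may differ from the struct used in `exchange.go`.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// tokenResponse is a hypothetical subset of the token endpoint response.
type tokenResponse struct {
	AccessToken string `json:"access_token"`
	TokenType   string `json:"token_type"`
	ExpiresIn   int    `json:"expires_in"`
}

// parseTokenResponse decodes body and rejects responses that parse as JSON
// but are unusable as a token.
func parseTokenResponse(body []byte) (*tokenResponse, error) {
	var resp tokenResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	if resp.AccessToken == "" {
		return nil, errors.New("token response is missing access_token")
	}
	if resp.TokenType == "" {
		return nil, errors.New("token response is missing token_type")
	}
	if resp.ExpiresIn < 0 {
		return nil, fmt.Errorf("token response has negative expires_in: %d", resp.ExpiresIn)
	}
	return &resp, nil
}

func main() {
	ok, err := parseTokenResponse([]byte(`{"access_token":"t","token_type":"Bearer","expires_in":3600}`))
	fmt.Println(ok.AccessToken, err)

	_, err = parseTokenResponse([]byte(`{"token_type":"Bearer"}`))
	fmt.Println(err)
}
```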
| u2, err2 := url.Parse(parsedURL2) | ||
| if err1 != nil || err2 != nil { | ||
| return false |
can we add some warn or debug logs here
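A sketch of what that logging might look like follows. It uses the standard library `log` package for self-containment, whereas the driver's own code uses a `log.Debug().Msgf(...)` style logger. `sameHost` is a hypothetical simplification of the host comparison in the diff.

```go
package main

import (
	"fmt"
	"log"
	"net/url"
	"strings"
)

// sameHost reports whether two URLs point at the same hostname. When either
// URL cannot be parsed it logs the reason before returning false, so a
// silent mismatch is distinguishable from a parse failure.
func sameHost(a, b string) bool {
	ua, errA := url.Parse(a)
	ub, errB := url.Parse(b)
	if errA != nil || errB != nil {
		log.Printf("host comparison: cannot parse URLs (first: %v, second: %v)", errA, errB)
		return false
	}
	return strings.EqualFold(ua.Hostname(), ub.Hostname())
}

func main() {
	fmt.Println(sameHost("https://Test.Databricks.com", "https://test.databricks.com/path"))
	fmt.Println(sameHost("http://[::1", "https://test.databricks.com")) // logs a parse error
}
```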
| require.NoError(t, err2) | ||
| assert.Equal(t, "databricks-token", token2.AccessToken) | ||
| assert.Equal(t, 1, callCount, "External provider should still be called only once (cached)") | ||
| assert.Equal(t, 1, exchangeCount, "Token should still be exchanged only once (cached)") |
should we also add assert that token1 == token2
Adds token federation for the Databricks SQL Go driver