diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index 946259d..573e9d5 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" "os/exec" @@ -16,6 +17,7 @@ import ( "github.com/initializ/forge/forge-cli/internal/tui/steps" "github.com/initializ/forge/forge-cli/skills" "github.com/initializ/forge/forge-cli/templates" + "github.com/initializ/forge/forge-core/llm/oauth" "github.com/initializ/forge/forge-core/tools/builtins" "github.com/initializ/forge/forge-core/util" "github.com/initializ/forge/forge-skills/contract" @@ -30,6 +32,7 @@ type initOptions struct { Language string ModelProvider string APIKey string // validated provider key + Fallbacks []tui.FallbackProvider Channels []string SkillsFile string Tools []toolEntry @@ -56,6 +59,7 @@ type templateData struct { Entrypoint string ModelProvider string ModelName string + Fallbacks []fallbackTmplData Channels []string Tools []toolEntry BuiltinTools []string @@ -64,6 +68,12 @@ type templateData struct { EnvVars []envVarEntry } +// fallbackTmplData holds template data for a fallback provider. +type fallbackTmplData struct { + Provider string + ModelName string +} + // skillTmplData holds template data for a registry skill. 
type skillTmplData struct { Name string @@ -103,6 +113,7 @@ func init() { initCmd.Flags().StringSlice("tools", nil, "builtin tools to enable (e.g., web_search,http_request)") initCmd.Flags().StringSlice("skills", nil, "registry skills to include (e.g., github,weather)") initCmd.Flags().String("api-key", "", "LLM provider API key") + initCmd.Flags().StringSlice("fallbacks", nil, "fallback LLM providers (e.g., openai,gemini)") initCmd.Flags().Bool("force", false, "overwrite existing directory") } @@ -128,6 +139,10 @@ func runInit(cmd *cobra.Command, args []string) error { opts.BuiltinTools, _ = cmd.Flags().GetStringSlice("tools") opts.Skills, _ = cmd.Flags().GetStringSlice("skills") opts.APIKey, _ = cmd.Flags().GetString("api-key") + fallbackProviders, _ := cmd.Flags().GetStringSlice("fallbacks") + for _, p := range fallbackProviders { + opts.Fallbacks = append(opts.Fallbacks, tui.FallbackProvider{Provider: p}) + } nonInteractive, _ := cmd.Flags().GetBool("non-interactive") opts.NonInteractive = nonInteractive @@ -216,6 +231,11 @@ func collectInteractive(opts *initOptions) error { return validateProviderKey(provider, key) } + // Build OAuth flow callback + oauthFlowFn := func(provider string) (string, error) { + return runOAuthFlow(provider) + } + // Build web search key validation callback validateWebSearchKeyFn := func(provider, key string) error { return validateWebSearchKey(provider, key) @@ -224,7 +244,8 @@ func collectInteractive(opts *initOptions) error { // Build step list wizardSteps := []tui.Step{ steps.NewNameStep(styles, opts.Name), - steps.NewProviderStep(styles, validateKeyFn), + steps.NewProviderStep(styles, validateKeyFn, oauthFlowFn), + steps.NewFallbackStep(styles, validateKeyFn), steps.NewChannelStep(styles), steps.NewToolsStep(styles, toolInfos, validateWebSearchKeyFn), steps.NewSkillsStep(styles, skillInfos), @@ -264,7 +285,12 @@ func collectInteractive(opts *initOptions) error { opts.ModelProvider = ctx.Provider opts.APIKey = ctx.APIKey + 
opts.Fallbacks = ctx.Fallbacks opts.CustomModel = ctx.CustomModel + // Use wizard-selected model name if available + if ctx.ModelName != "" { + opts.CustomModel = ctx.ModelName + } if ctx.Channel != "" && ctx.Channel != "none" { opts.Channels = []string{ctx.Channel} @@ -711,22 +737,20 @@ func buildTemplateData(opts *initOptions) templateData { } } - // Set default model name based on provider - switch opts.ModelProvider { - case "openai": - data.ModelName = "gpt-4o-mini" - case "anthropic": - data.ModelName = "claude-sonnet-4-20250514" - case "gemini": - data.ModelName = "gemini-2.5-flash" - case "ollama": - data.ModelName = "llama3" - default: - if opts.CustomModel != "" { - data.ModelName = opts.CustomModel - } else { - data.ModelName = "default" - } + // Set model name: use wizard-selected model, or fall back to provider default + if opts.CustomModel != "" { + data.ModelName = opts.CustomModel + } else { + data.ModelName = defaultModelNameForProvider(opts.ModelProvider) + } + + // Build fallback entries for templates + for _, fb := range opts.Fallbacks { + modelName := defaultModelNameForProvider(fb.Provider) + data.Fallbacks = append(data.Fallbacks, fallbackTmplData{ + Provider: fb.Provider, + ModelName: modelName, + }) } // Build skill entries for templates @@ -759,6 +783,22 @@ func buildTemplateData(opts *initOptions) templateData { return data } +// defaultModelNameForProvider returns the default model name for wizard templates. +func defaultModelNameForProvider(provider string) string { + switch provider { + case "openai": + return "gpt-5.2-2025-12-11" + case "anthropic": + return "claude-sonnet-4-20250514" + case "gemini": + return "gemini-2.5-flash" + case "ollama": + return "llama3" + default: + return "default" + } +} + // buildEnvVars builds the list of environment variables for the .env file. 
func buildEnvVars(opts *initOptions) []envVarEntry { var vars []envVarEntry @@ -831,6 +871,34 @@ func buildEnvVars(opts *initOptions) []envVarEntry { } } + // Fallback provider env vars + fallbackKeyMap := map[string]string{ + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "gemini": "GEMINI_API_KEY", + } + for _, fb := range opts.Fallbacks { + envKey, ok := fallbackKeyMap[fb.Provider] + if !ok || fb.APIKey == "" { + continue + } + // Skip if already written (e.g., primary provider) + alreadyWritten := false + for _, v := range vars { + if v.Key == envKey { + alreadyWritten = true + break + } + } + if !alreadyWritten { + vars = append(vars, envVarEntry{ + Key: envKey, + Value: fb.APIKey, + Comment: fmt.Sprintf("%s API key (fallback)", titleCase(fb.Provider)), + }) + } + } + // Skill env vars (skip keys already added above) written := make(map[string]bool) for _, v := range vars { @@ -882,6 +950,25 @@ func containsStr(slice []string, val string) bool { return false } +// runOAuthFlow executes the OAuth browser flow for a provider and returns the access token. +func runOAuthFlow(provider string) (string, error) { + var config oauth.ProviderConfig + switch provider { + case "openai": + config = oauth.OpenAIConfig() + default: + return "", fmt.Errorf("OAuth not supported for provider %q", provider) + } + + flow := oauth.NewFlow(config) + token, err := flow.Execute(context.Background(), provider) + if err != nil { + return "", err + } + + return token.AccessToken, nil +} + // titleCase capitalizes the first letter of a string. func titleCase(s string) string { if s == "" { diff --git a/forge-cli/cmd/init_egress.go b/forge-cli/cmd/init_egress.go index 48ef33e..77fc237 100644 --- a/forge-cli/cmd/init_egress.go +++ b/forge-cli/cmd/init_egress.go @@ -28,10 +28,15 @@ func deriveEgressDomains(opts *initOptions, skills []contract.SkillDescriptor) [ } } - // 1. Provider domain + // 1. 
Provider domains (primary + fallbacks) if d, ok := providerDomains[opts.ModelProvider]; ok { add(d) } + for _, fb := range opts.Fallbacks { + if d, ok := providerDomains[fb.Provider]; ok { + add(d) + } + } // 2. Channel domains for _, d := range security.ResolveCapabilities(opts.Channels) { diff --git a/forge-cli/cmd/init_test.go b/forge-cli/cmd/init_test.go index 26db3e1..c54dc5d 100644 --- a/forge-cli/cmd/init_test.go +++ b/forge-cli/cmd/init_test.go @@ -598,7 +598,7 @@ func TestBuildTemplateData_DefaultModels(t *testing.T) { provider string expectedModel string }{ - {"openai", "gpt-4o-mini"}, + {"openai", "gpt-5.2-2025-12-11"}, {"anthropic", "claude-sonnet-4-20250514"}, {"gemini", "gemini-2.5-flash"}, {"ollama", "llama3"}, diff --git a/forge-cli/internal/tui/components/multi_select.go b/forge-cli/internal/tui/components/multi_select.go index b989b76..54340b2 100644 --- a/forge-cli/internal/tui/components/multi_select.go +++ b/forge-cli/internal/tui/components/multi_select.go @@ -78,6 +78,18 @@ func (m MultiSelect) Update(msg tea.Msg) (MultiSelect, tea.Cmd) { case " ": m.Items[m.cursor].Checked = !m.Items[m.cursor].Checked case "enter": + // If nothing is checked, auto-check the cursor item so the user + // doesn't have to remember Space+Enter for a single selection. 
+ anyChecked := false + for _, item := range m.Items { + if item.Checked { + anyChecked = true + break + } + } + if !anyChecked && len(m.Items) > 0 { + m.Items[m.cursor].Checked = true + } m.done = true } } diff --git a/forge-cli/internal/tui/steps/fallback_step.go b/forge-cli/internal/tui/steps/fallback_step.go new file mode 100644 index 0000000..ee365f6 --- /dev/null +++ b/forge-cli/internal/tui/steps/fallback_step.go @@ -0,0 +1,350 @@ +package steps + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + + "github.com/initializ/forge/forge-cli/internal/tui" + "github.com/initializ/forge/forge-cli/internal/tui/components" +) + +type fallbackPhase int + +const ( + fallbackAskPhase fallbackPhase = iota + fallbackSelectPhase + fallbackKeyPhase + fallbackDonePhase +) + +// FallbackStep handles fallback provider selection and API key collection. +type FallbackStep struct { + styles *tui.StyleSet + phase fallbackPhase + askSelector components.SingleSelect + multiSelector components.MultiSelect + keyInput components.SecretInput + complete bool + primaryProv string + selected []string // providers selected by user + collected []tui.FallbackProvider + keyIndex int // which selected provider we're collecting a key for + validateFn ValidateKeyFunc + validating bool +} + +// NewFallbackStep creates a new fallback provider wizard step. +func NewFallbackStep(styles *tui.StyleSet, validateFn ValidateKeyFunc) *FallbackStep { + return &FallbackStep{ + styles: styles, + validateFn: validateFn, + } +} + +// Prepare is called by the wizard to provide the primary provider context. 
+func (s *FallbackStep) Prepare(ctx *tui.WizardContext) { + s.primaryProv = ctx.Provider + s.complete = false + s.phase = fallbackAskPhase + s.selected = nil + s.collected = nil + s.keyIndex = 0 +} + +func (s *FallbackStep) Title() string { return "Fallback Providers" } +func (s *FallbackStep) Icon() string { return "๐" } + +func (s *FallbackStep) Init() tea.Cmd { + // Build the "Add fallback providers?" selector + items := []components.SingleSelectItem{ + {Label: "No", Value: "no", Description: "Use only the primary provider", Icon: "โญ๏ธ"}, + {Label: "Yes", Value: "yes", Description: "Configure fallback providers for reliability", Icon: "โ "}, + } + s.askSelector = components.NewSingleSelect( + items, + s.styles.Theme.Accent, + s.styles.Theme.Primary, + s.styles.Theme.Secondary, + s.styles.Theme.Dim, + s.styles.Theme.Border, + s.styles.Theme.ActiveBorder, + s.styles.Theme.ActiveBg, + s.styles.KbdKey, + s.styles.KbdDesc, + ) + return s.askSelector.Init() +} + +func (s *FallbackStep) Update(msg tea.Msg) (tui.Step, tea.Cmd) { + if s.complete { + return s, nil + } + + switch s.phase { + case fallbackAskPhase: + return s.updateAskPhase(msg) + case fallbackSelectPhase: + return s.updateSelectPhase(msg) + case fallbackKeyPhase: + return s.updateKeyPhase(msg) + } + + return s, nil +} + +func (s *FallbackStep) updateAskPhase(msg tea.Msg) (tui.Step, tea.Cmd) { + // Handle backspace for back navigation + if msg, ok := msg.(tea.KeyMsg); ok && msg.String() == "backspace" { + return s, func() tea.Msg { return tui.StepBackMsg{} } + } + + updated, cmd := s.askSelector.Update(msg) + s.askSelector = updated + + if s.askSelector.Done() { + _, val := s.askSelector.Selected() + if val == "no" { + s.complete = true + return s, func() tea.Msg { return tui.StepCompleteMsg{} } + } + // Yes โ show multi-select of providers + s.phase = fallbackSelectPhase + s.buildMultiSelect() + return s, s.multiSelector.Init() + } + + return s, cmd +} + +func (s *FallbackStep) buildMultiSelect() { + 
allProviders := []struct { + label, value, desc, icon string + }{ + {"OpenAI", "openai", "GPT-4o, GPT-4o-mini", "๐ท"}, + {"Anthropic", "anthropic", "Claude Sonnet, Haiku, Opus", "๐ "}, + {"Google Gemini", "gemini", "Gemini 2.5 Flash, Pro", "๐ต"}, + {"Ollama (local)", "ollama", "Run models locally, no API key needed", "๐ฆ"}, + } + + var items []components.MultiSelectItem + for _, p := range allProviders { + if p.value == s.primaryProv { + continue // exclude primary + } + items = append(items, components.MultiSelectItem{ + Label: p.label, + Value: p.value, + Description: p.desc, + Icon: p.icon, + }) + } + + s.multiSelector = components.NewMultiSelect( + items, + s.styles.Theme.Accent, + s.styles.Theme.AccentDim, + s.styles.Theme.Primary, + s.styles.Theme.Secondary, + s.styles.Theme.Dim, + s.styles.ActiveBorder, + s.styles.InactiveBorder, + s.styles.KbdKey, + s.styles.KbdDesc, + ) +} + +func (s *FallbackStep) updateSelectPhase(msg tea.Msg) (tui.Step, tea.Cmd) { + // Handle backspace to go back to ask phase + if msg, ok := msg.(tea.KeyMsg); ok && msg.String() == "backspace" { + s.phase = fallbackAskPhase + s.askSelector.Reset() + return s, s.askSelector.Init() + } + + updated, cmd := s.multiSelector.Update(msg) + s.multiSelector = updated + + if s.multiSelector.Done() { + s.selected = s.multiSelector.SelectedValues() + if len(s.selected) == 0 { + // No providers selected โ skip + s.complete = true + return s, func() tea.Msg { return tui.StepCompleteMsg{} } + } + // Start collecting API keys + s.keyIndex = 0 + return s.advanceToNextKey() + } + + return s, cmd +} + +func (s *FallbackStep) advanceToNextKey() (tui.Step, tea.Cmd) { + // Skip providers that don't need keys + for s.keyIndex < len(s.selected) { + if s.selected[s.keyIndex] == "ollama" { + s.collected = append(s.collected, tui.FallbackProvider{ + Provider: "ollama", + }) + s.keyIndex++ + continue + } + break + } + + if s.keyIndex >= len(s.selected) { + // All keys collected + s.complete = true + return s, 
func() tea.Msg { return tui.StepCompleteMsg{} } + } + + provider := s.selected[s.keyIndex] + s.phase = fallbackKeyPhase + label := fmt.Sprintf("%s API Key (fallback)", providerDisplayName(provider)) + s.keyInput = components.NewSecretInput( + label, true, + s.styles.Theme.Accent, + s.styles.Theme.Success, + s.styles.Theme.Error, + s.styles.Theme.Border, + s.styles.AccentTxt, + s.styles.InactiveBorder, + s.styles.SuccessTxt, + s.styles.ErrorTxt, + s.styles.DimTxt, + s.styles.KbdKey, + s.styles.KbdDesc, + ) + return s, s.keyInput.Init() +} + +func (s *FallbackStep) updateKeyPhase(msg tea.Msg) (tui.Step, tea.Cmd) { + // Handle validation result + if msg, ok := msg.(tui.ValidationResultMsg); ok { + s.validating = false + provider := s.selected[s.keyIndex] + if msg.Err != nil { + // Retry key input + label := fmt.Sprintf("%s API Key (retry โ %s)", providerDisplayName(provider), msg.Err) + s.keyInput = components.NewSecretInput( + label, true, + s.styles.Theme.Accent, + s.styles.Theme.Success, + s.styles.Theme.Error, + s.styles.Theme.Border, + s.styles.AccentTxt, + s.styles.InactiveBorder, + s.styles.SuccessTxt, + s.styles.ErrorTxt, + s.styles.DimTxt, + s.styles.KbdKey, + s.styles.KbdDesc, + ) + s.keyInput.SetState(components.SecretInputFailed, msg.Err.Error()) + return s, s.keyInput.Init() + } + // Success + s.collected = append(s.collected, tui.FallbackProvider{ + Provider: provider, + APIKey: s.keyInput.Value(), + }) + s.keyIndex++ + return s.advanceToNextKey() + } + + // Handle backspace at empty to go back + if msg, ok := msg.(tea.KeyMsg); ok && msg.String() == "backspace" { + if s.keyInput.Value() == "" { + s.phase = fallbackSelectPhase + s.multiSelector.Reset() + s.collected = nil + s.keyIndex = 0 + return s, s.multiSelector.Init() + } + } + + updated, cmd := s.keyInput.Update(msg) + s.keyInput = updated + + if s.keyInput.Done() { + key := s.keyInput.Value() + provider := s.selected[s.keyIndex] + if key == "" { + // Skipped โ add without key + s.collected = 
append(s.collected, tui.FallbackProvider{ + Provider: provider, + }) + s.keyIndex++ + return s.advanceToNextKey() + } + // Validate the key + if s.validateFn != nil { + s.validating = true + validateFn := s.validateFn + return s, func() tea.Msg { + err := validateFn(provider, key) + return tui.ValidationResultMsg{Err: err} + } + } + // No validation function โ accept + s.collected = append(s.collected, tui.FallbackProvider{ + Provider: provider, + APIKey: key, + }) + s.keyIndex++ + return s.advanceToNextKey() + } + + return s, cmd +} + +func (s *FallbackStep) View(width int) string { + switch s.phase { + case fallbackAskPhase: + return s.askSelector.View(width) + case fallbackSelectPhase: + return s.multiSelector.View(width) + case fallbackKeyPhase: + if s.validating { + return " " + s.styles.AccentTxt.Render("โฃพ Validating...") + "\n" + } + return s.keyInput.View(width) + } + return "" +} + +func (s *FallbackStep) Complete() bool { + return s.complete +} + +func (s *FallbackStep) Summary() string { + if len(s.collected) == 0 { + return "none" + } + var names []string + for _, fb := range s.collected { + names = append(names, providerDisplayName(fb.Provider)) + } + return strings.Join(names, ", ") +} + +func (s *FallbackStep) Apply(ctx *tui.WizardContext) { + ctx.Fallbacks = s.collected + + // Store fallback provider keys in env vars + for _, fb := range s.collected { + if fb.APIKey == "" { + continue + } + switch fb.Provider { + case "openai": + ctx.EnvVars["OPENAI_API_KEY"] = fb.APIKey + case "anthropic": + ctx.EnvVars["ANTHROPIC_API_KEY"] = fb.APIKey + case "gemini": + ctx.EnvVars["GEMINI_API_KEY"] = fb.APIKey + } + } +} diff --git a/forge-cli/internal/tui/steps/provider_step.go b/forge-cli/internal/tui/steps/provider_step.go index 09acfde..cc8687a 100644 --- a/forge-cli/internal/tui/steps/provider_step.go +++ b/forge-cli/internal/tui/steps/provider_step.go @@ -13,39 +13,73 @@ type providerPhase int const ( providerSelectPhase providerPhase = iota + 
providerAuthMethodPhase providerKeyPhase providerValidatingPhase + providerOAuthPhase + providerModelPhase providerCustomURLPhase providerCustomModelPhase providerCustomAuthPhase providerDonePhase ) +// OAuthFlowFunc is a function that runs the OAuth flow and returns the access token. +type OAuthFlowFunc func(provider string) (accessToken string, err error) + // ValidateKeyFunc validates an API key for a provider. type ValidateKeyFunc func(provider, key string) error +// modelOption maps a user-friendly display name to the actual model ID. +type modelOption struct { + DisplayName string + ModelID string +} + +// openAIOAuthModels are available when using browser-based OAuth login. +var openAIOAuthModels = []modelOption{ + {DisplayName: "GPT 5.3 Codex", ModelID: "gpt-5.3-codex"}, + {DisplayName: "GPT 5.2", ModelID: "gpt-5.2-2025-12-11"}, + {DisplayName: "GPT 5.2 Codex", ModelID: "gpt-5.2-codex"}, +} + +// openAIAPIKeyModels are available when using an API key. +var openAIAPIKeyModels = []modelOption{ + {DisplayName: "GPT 5.2", ModelID: "gpt-5.2-2025-12-11"}, + {DisplayName: "GPT 5 Mini", ModelID: "gpt-5-mini-2025-08-07"}, + {DisplayName: "GPT 5 Nano", ModelID: "gpt-5-nano-2025-08-07"}, + {DisplayName: "GPT 4.1 Mini", ModelID: "gpt-4.1-mini-2025-04-14"}, +} + // ProviderStep handles model provider selection and API key entry. 
type ProviderStep struct { - styles *tui.StyleSet - phase providerPhase - selector components.SingleSelect - keyInput components.SecretInput - textInput components.TextInput - complete bool - provider string - apiKey string - customURL string - customModel string - customAuth string - validateFn ValidateKeyFunc - validating bool - valErr error + styles *tui.StyleSet + phase providerPhase + selector components.SingleSelect + authMethodSelector components.SingleSelect + modelSelector components.SingleSelect + keyInput components.SecretInput + textInput components.TextInput + complete bool + provider string + apiKey string + authMethod string // "apikey" or "oauth" + modelID string // selected model ID + customURL string + customModel string + customAuth string + validateFn ValidateKeyFunc + oauthFn OAuthFlowFunc + validating bool + valErr error + oauthRunning bool } // NewProviderStep creates a new provider selection step. -func NewProviderStep(styles *tui.StyleSet, validateFn ValidateKeyFunc) *ProviderStep { +// oauthFn is optional โ pass nil to disable OAuth login. 
+func NewProviderStep(styles *tui.StyleSet, validateFn ValidateKeyFunc, oauthFn ...OAuthFlowFunc) *ProviderStep { items := []components.SingleSelectItem{ - {Label: "OpenAI", Value: "openai", Description: "GPT-4o, GPT-4o-mini", Icon: "๐ท"}, + {Label: "OpenAI", Value: "openai", Description: "GPT 5.3 Codex, GPT 5.2, GPT 5 Mini", Icon: "๐ท"}, {Label: "Anthropic", Value: "anthropic", Description: "Claude Sonnet, Haiku, Opus", Icon: "๐ "}, {Label: "Google Gemini", Value: "gemini", Description: "Gemini 2.5 Flash, Pro", Icon: "๐ต"}, {Label: "Ollama (local)", Value: "ollama", Description: "Run models locally, no API key needed", Icon: "๐ฆ"}, @@ -65,10 +99,16 @@ func NewProviderStep(styles *tui.StyleSet, validateFn ValidateKeyFunc) *Provider styles.KbdDesc, ) + var oFn OAuthFlowFunc + if len(oauthFn) > 0 { + oFn = oauthFn[0] + } + return &ProviderStep{ styles: styles, selector: selector, validateFn: validateFn, + oauthFn: oFn, } } @@ -87,10 +127,16 @@ func (s *ProviderStep) Update(msg tea.Msg) (tui.Step, tea.Cmd) { switch s.phase { case providerSelectPhase: return s.updateSelectPhase(msg) + case providerAuthMethodPhase: + return s.updateAuthMethodPhase(msg) case providerKeyPhase: return s.updateKeyPhase(msg) case providerValidatingPhase: return s.updateValidatingPhase(msg) + case providerOAuthPhase: + return s.updateOAuthPhase(msg) + case providerModelPhase: + return s.updateModelPhase(msg) case providerCustomURLPhase: return s.updateCustomURLPhase(msg) case providerCustomModelPhase: @@ -131,6 +177,31 @@ func (s *ProviderStep) updateSelectPhase(msg tea.Msg) (tui.Step, tea.Cmd) { s.styles.KbdDesc, ) return s, s.textInput.Init() + case "openai": + // If OAuth is available, show auth method choice + if s.oauthFn != nil { + s.phase = providerAuthMethodPhase + items := []components.SingleSelectItem{ + {Label: "Enter API Key", Value: "apikey", Description: "Paste your OpenAI API key", Icon: "๐"}, + {Label: "Login with OpenAI", Value: "oauth", Description: "Browser-based login 
(OAuth)", Icon: "๐"}, + } + s.authMethodSelector = components.NewSingleSelect( + items, + s.styles.Theme.Accent, + s.styles.Theme.Primary, + s.styles.Theme.Secondary, + s.styles.Theme.Dim, + s.styles.Theme.Border, + s.styles.Theme.ActiveBorder, + s.styles.Theme.ActiveBg, + s.styles.KbdKey, + s.styles.KbdDesc, + ) + return s, s.authMethodSelector.Init() + } + // No OAuth โ fall through to API key + s.authMethod = "apikey" + fallthrough default: // openai, anthropic, gemini โ ask for key s.phase = providerKeyPhase @@ -156,6 +227,85 @@ func (s *ProviderStep) updateSelectPhase(msg tea.Msg) (tui.Step, tea.Cmd) { return s, cmd } +func (s *ProviderStep) updateAuthMethodPhase(msg tea.Msg) (tui.Step, tea.Cmd) { + // Handle backspace to go back to provider selector + if msg, ok := msg.(tea.KeyMsg); ok && msg.String() == "backspace" { + s.phase = providerSelectPhase + s.provider = "" + s.selector.Reset() + return s, s.selector.Init() + } + + updated, cmd := s.authMethodSelector.Update(msg) + s.authMethodSelector = updated + + if s.authMethodSelector.Done() { + _, val := s.authMethodSelector.Selected() + s.authMethod = val + if val == "oauth" { + // Run OAuth flow + s.phase = providerOAuthPhase + s.oauthRunning = true + oauthFn := s.oauthFn + return s, func() tea.Msg { + _, err := oauthFn("openai") + return tui.ValidationResultMsg{Err: err} + } + } + // API key method + s.phase = providerKeyPhase + label := fmt.Sprintf("%s API Key", providerDisplayName(s.provider)) + s.keyInput = components.NewSecretInput( + label, true, + s.styles.Theme.Accent, + s.styles.Theme.Success, + s.styles.Theme.Error, + s.styles.Theme.Border, + s.styles.AccentTxt, + s.styles.InactiveBorder, + s.styles.SuccessTxt, + s.styles.ErrorTxt, + s.styles.DimTxt, + s.styles.KbdKey, + s.styles.KbdDesc, + ) + return s, s.keyInput.Init() + } + + return s, cmd +} + +func (s *ProviderStep) updateOAuthPhase(msg tea.Msg) (tui.Step, tea.Cmd) { + if msg, ok := msg.(tui.ValidationResultMsg); ok { + s.oauthRunning = false 
+ if msg.Err != nil { + // OAuth failed โ fall back to API key entry + s.phase = providerKeyPhase + label := fmt.Sprintf("%s API Key (OAuth failed โ %s)", providerDisplayName(s.provider), msg.Err) + s.keyInput = components.NewSecretInput( + label, true, + s.styles.Theme.Accent, + s.styles.Theme.Success, + s.styles.Theme.Error, + s.styles.Theme.Border, + s.styles.AccentTxt, + s.styles.InactiveBorder, + s.styles.SuccessTxt, + s.styles.ErrorTxt, + s.styles.DimTxt, + s.styles.KbdKey, + s.styles.KbdDesc, + ) + return s, s.keyInput.Init() + } + // OAuth succeeded โ show model selection + s.apiKey = "__oauth__" + return s, s.showModelSelector() + } + + return s, nil +} + func (s *ProviderStep) updateKeyPhase(msg tea.Msg) (tui.Step, tea.Cmd) { // Handle backspace at empty input โ go back to provider selector (internal back) if msg, ok := msg.(tea.KeyMsg); ok && msg.String() == "backspace" { @@ -216,6 +366,10 @@ func (s *ProviderStep) updateValidatingPhase(msg tea.Msg) (tui.Step, tea.Cmd) { s.complete = true return s, func() tea.Msg { return tui.StepCompleteMsg{} } } + // Validation passed โ show model selection for OpenAI + if s.provider == "openai" { + return s, s.showModelSelector() + } s.complete = true return s, func() tea.Msg { return tui.StepCompleteMsg{} } } @@ -223,6 +377,53 @@ func (s *ProviderStep) updateValidatingPhase(msg tea.Msg) (tui.Step, tea.Cmd) { return s, nil } +// showModelSelector sets up the model selection phase for OpenAI. 
+func (s *ProviderStep) showModelSelector() tea.Cmd { + var models []modelOption + if s.authMethod == "oauth" { + models = openAIOAuthModels + } else { + models = openAIAPIKeyModels + } + + items := make([]components.SingleSelectItem, len(models)) + for i, m := range models { + items[i] = components.SingleSelectItem{ + Label: m.DisplayName, + Value: m.ModelID, + } + } + + s.modelSelector = components.NewSingleSelect( + items, + s.styles.Theme.Accent, + s.styles.Theme.Primary, + s.styles.Theme.Secondary, + s.styles.Theme.Dim, + s.styles.Theme.Border, + s.styles.Theme.ActiveBorder, + s.styles.Theme.ActiveBg, + s.styles.KbdKey, + s.styles.KbdDesc, + ) + s.phase = providerModelPhase + return s.modelSelector.Init() +} + +func (s *ProviderStep) updateModelPhase(msg tea.Msg) (tui.Step, tea.Cmd) { + updated, cmd := s.modelSelector.Update(msg) + s.modelSelector = updated + + if s.modelSelector.Done() { + _, val := s.modelSelector.Selected() + s.modelID = val + s.complete = true + return s, func() tea.Msg { return tui.StepCompleteMsg{} } + } + + return s, cmd +} + func (s *ProviderStep) updateCustomURLPhase(msg tea.Msg) (tui.Step, tea.Cmd) { updated, cmd := s.textInput.Update(msg) s.textInput = updated @@ -306,6 +507,8 @@ func (s *ProviderStep) View(width int) string { switch s.phase { case providerSelectPhase: return s.selector.View(width) + case providerAuthMethodPhase: + return s.authMethodSelector.View(width) case providerKeyPhase: return s.keyInput.View(width) case providerValidatingPhase: @@ -313,6 +516,13 @@ func (s *ProviderStep) View(width int) string { return " " + s.styles.AccentTxt.Render("โฃพ Validating...") + "\n" } return s.keyInput.View(width) + case providerOAuthPhase: + if s.oauthRunning { + return " " + s.styles.AccentTxt.Render("โฃพ Waiting for browser authorization...") + "\n" + } + return "" + case providerModelPhase: + return s.modelSelector.View(width) case providerCustomURLPhase, providerCustomModelPhase: return s.textInput.View(width) case 
providerCustomAuthPhase: @@ -327,13 +537,16 @@ func (s *ProviderStep) Complete() bool { func (s *ProviderStep) Summary() string { name := providerDisplayName(s.provider) + if s.modelID != "" { + return name + " ยท " + modelDisplayName(s.modelID) + } switch s.provider { case "openai": - return name + " ยท gpt-4o-mini" + return name + " ยท GPT 5.2" case "anthropic": - return name + " ยท claude-sonnet-4-20250514" + return name + " ยท Claude Sonnet 4" case "gemini": - return name + " ยท gemini-2.5-flash" + return name + " ยท Gemini 2.5 Flash" case "ollama": return name + " ยท llama3" case "custom": @@ -348,6 +561,8 @@ func (s *ProviderStep) Summary() string { func (s *ProviderStep) Apply(ctx *tui.WizardContext) { ctx.Provider = s.provider ctx.APIKey = s.apiKey + ctx.AuthMethod = s.authMethod + ctx.ModelName = s.modelID ctx.CustomBaseURL = s.customURL ctx.CustomModel = s.customModel ctx.CustomAPIKey = s.customAuth @@ -366,6 +581,22 @@ func (s *ProviderStep) Apply(ctx *tui.WizardContext) { } } +// modelDisplayName returns the user-friendly name for a model ID. 
+func modelDisplayName(modelID string) string { + // Check all model lists + for _, m := range openAIOAuthModels { + if m.ModelID == modelID { + return m.DisplayName + } + } + for _, m := range openAIAPIKeyModels { + if m.ModelID == modelID { + return m.DisplayName + } + } + return modelID +} + func providerDisplayName(provider string) string { switch provider { case "openai": diff --git a/forge-cli/internal/tui/steps/review_step.go b/forge-cli/internal/tui/steps/review_step.go index ce8b79c..080a882 100644 --- a/forge-cli/internal/tui/steps/review_step.go +++ b/forge-cli/internal/tui/steps/review_step.go @@ -40,6 +40,14 @@ func (s *ReviewStep) Prepare(ctx *tui.WizardContext) { rows = append(rows, components.SummaryRow{Key: "Name", Value: ctx.Name}) rows = append(rows, components.SummaryRow{Key: "Provider", Value: providerDisplayName(ctx.Provider)}) + if len(ctx.Fallbacks) > 0 { + var fbNames []string + for _, fb := range ctx.Fallbacks { + fbNames = append(fbNames, providerDisplayName(fb.Provider)) + } + rows = append(rows, components.SummaryRow{Key: "Fallbacks", Value: strings.Join(fbNames, ", ")}) + } + if ctx.Channel != "" && ctx.Channel != "none" { rows = append(rows, components.SummaryRow{Key: "Channel", Value: ctx.Channel}) } diff --git a/forge-cli/internal/tui/wizard.go b/forge-cli/internal/tui/wizard.go index 1ff8bca..e455e67 100644 --- a/forge-cli/internal/tui/wizard.go +++ b/forge-cli/internal/tui/wizard.go @@ -6,11 +6,20 @@ import ( tea "github.com/charmbracelet/bubbletea" ) +// FallbackProvider holds a fallback provider selection from the wizard. +type FallbackProvider struct { + Provider string + APIKey string +} + // WizardContext accumulates all data across wizard steps. type WizardContext struct { Name string Provider string APIKey string + AuthMethod string // "apikey" or "oauth" โ how the user authenticated + ModelName string // selected model ID (e.g. 
"gpt-5.3-codex") + Fallbacks []FallbackProvider Channel string ChannelTokens map[string]string BuiltinTools []string diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 2177af9..4b8e06a 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -14,6 +14,8 @@ import ( clitools "github.com/initializ/forge/forge-cli/tools" "github.com/initializ/forge/forge-core/a2a" "github.com/initializ/forge/forge-core/agentspec" + "github.com/initializ/forge/forge-core/llm" + "github.com/initializ/forge/forge-core/llm/oauth" "github.com/initializ/forge/forge-core/llm/providers" coreruntime "github.com/initializ/forge/forge-core/runtime" "github.com/initializ/forge/forge-core/tools" @@ -42,6 +44,7 @@ type Runner struct { cfg RunnerConfig logger coreruntime.Logger cliExecTool *clitools.CLIExecuteTool + modelConfig *coreruntime.ModelConfig // resolved model config (for banner) } // NewRunner creates a Runner from the given config. @@ -152,7 +155,8 @@ func (r *Runner) Run(ctx context.Context) error { // Try LLM executor, fall back to stub mc := coreruntime.ResolveModelConfig(r.cfg.Config, envVars, r.cfg.ProviderOverride) if mc != nil { - llmClient, llmErr := providers.NewClient(mc.Provider, mc.Client) + r.modelConfig = mc + llmClient, llmErr := r.buildLLMClient(mc) if llmErr != nil { r.logger.Warn("failed to create LLM client, using stub", map[string]any{"error": llmErr.Error()}) executor = NewStubExecutor(r.cfg.Config.Framework) @@ -167,10 +171,12 @@ func (r *Runner) Run(ctx context.Context) error { Hooks: hooks, SystemPrompt: fmt.Sprintf("You are %s, an AI agent.", r.cfg.Config.AgentID), }) + r.logger.Info("using LLM executor", map[string]any{ - "provider": mc.Provider, - "model": mc.Client.Model, - "tools": len(toolNames), + "provider": mc.Provider, + "model": mc.Client.Model, + "tools": len(toolNames), + "fallbacks": len(mc.Fallbacks), }) } } else { @@ -506,6 +512,70 @@ func (r *Runner) registerLoggingHooks(hooks 
*coreruntime.HookRegistry) { }) } +// buildLLMClient creates the LLM client from the resolved model config. +// If fallback providers are configured, wraps them in a FallbackChain. +func (r *Runner) buildLLMClient(mc *coreruntime.ModelConfig) (llm.Client, error) { + primaryClient, err := r.createProviderClient(mc.Provider, mc.Client) + if err != nil { + return nil, err + } + + // No fallbacks — return primary client directly + if len(mc.Fallbacks) == 0 { + return primaryClient, nil + } + + // Build fallback chain + candidates := []llm.FallbackCandidate{ + {Provider: mc.Provider, Model: mc.Client.Model, Client: primaryClient}, + } + for _, fb := range mc.Fallbacks { + fbClient, fbErr := r.createProviderClient(fb.Provider, fb.Client) + if fbErr != nil { + r.logger.Warn("skipping fallback provider", map[string]any{ + "provider": fb.Provider, "error": fbErr.Error(), + }) + continue + } + candidates = append(candidates, llm.FallbackCandidate{ + Provider: fb.Provider, + Model: fb.Client.Model, + Client: fbClient, + }) + } + + return llm.NewFallbackChain(candidates), nil +} + +// createProviderClient creates an LLM client for a provider, using OAuth +// credentials if available for supported providers. +func (r *Runner) createProviderClient(provider string, cfg llm.ClientConfig) (llm.Client, error) { + // Check for stored OAuth credentials — but only if no API key is already + // configured. A real API key means the user chose API-key auth and we + // should use the standard OpenAI Chat Completions endpoint, not the + // Codex Responses endpoint that OAuth tokens require. 
+ if provider == "openai" && cfg.APIKey == "" { + token, err := oauth.LoadCredentials(provider) + if err == nil && token != nil && token.RefreshToken != "" { + oauthCfg := oauth.OpenAIConfig() + // Use token's base URL, or fall back to the OAuth config default + baseURL := token.BaseURL + if baseURL == "" { + baseURL = oauthCfg.BaseURL + } + r.logger.Info("using OAuth credentials for provider", map[string]any{ + "provider": provider, + "base_url": baseURL, + }) + cfg.APIKey = token.AccessToken + cfg.BaseURL = baseURL + return providers.NewOAuthClient(cfg, provider, oauthCfg), nil + } + } + + return providers.NewClient(provider, cfg) +} + func (r *Runner) printBanner() { fmt.Fprintf(os.Stderr, "\n") fmt.Fprintf(os.Stderr, " Forge Dev Server\n") @@ -518,6 +588,17 @@ func (r *Runner) printBanner() { } else { fmt.Fprintf(os.Stderr, " Entrypoint: %s\n", r.cfg.Config.Entrypoint) } + // Model info + if r.modelConfig != nil { + fmt.Fprintf(os.Stderr, " Model: %s/%s\n", r.modelConfig.Provider, r.modelConfig.Client.Model) + if len(r.modelConfig.Fallbacks) > 0 { + var fbNames []string + for _, fb := range r.modelConfig.Fallbacks { + fbNames = append(fbNames, fb.Provider+"/"+fb.Client.Model) + } + fmt.Fprintf(os.Stderr, " Fallbacks: %s\n", strings.Join(fbNames, ", ")) + } + } // Tools if len(r.cfg.Config.Tools) > 0 { names := make([]string, 0, len(r.cfg.Config.Tools)) diff --git a/forge-cli/templates/init/forge.yaml.tmpl b/forge-cli/templates/init/forge.yaml.tmpl index d672e13..26be970 100644 --- a/forge-cli/templates/init/forge.yaml.tmpl +++ b/forge-cli/templates/init/forge.yaml.tmpl @@ -7,6 +7,13 @@ model: provider: {{.ModelProvider}} name: {{.ModelName}} version: "latest" +{{- if .Fallbacks}} + fallbacks: +{{- range .Fallbacks}} + - provider: {{.Provider}} + name: {{.ModelName}} +{{- end}} +{{- end}} {{- if .Tools}} tools: diff --git a/forge-core/llm/cooldown.go b/forge-core/llm/cooldown.go new file mode 100644 index 0000000..e09f55d --- /dev/null +++ 
b/forge-core/llm/cooldown.go @@ -0,0 +1,112 @@ +package llm + +import ( + "sync" + "time" +) + +// cooldownEntry tracks failure state for a single provider. +type cooldownEntry struct { + count int + reason FailoverReason + lastFail time.Time +} + +// CooldownTracker manages per-provider cooldown state with exponential backoff. +type CooldownTracker struct { + mu sync.RWMutex + entries map[string]*cooldownEntry + nowFunc func() time.Time +} + +// NewCooldownTracker creates a new cooldown tracker. +func NewCooldownTracker() *CooldownTracker { + return &CooldownTracker{ + entries: make(map[string]*cooldownEntry), + nowFunc: time.Now, + } +} + +// MarkFailure records a failure for the given provider. +func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) { + ct.mu.Lock() + defer ct.mu.Unlock() + + e, ok := ct.entries[provider] + if !ok { + e = &cooldownEntry{} + ct.entries[provider] = e + } + e.count++ + e.reason = reason + e.lastFail = ct.nowFunc() +} + +// MarkSuccess resets all cooldown state for a provider. +func (ct *CooldownTracker) MarkSuccess(provider string) { + ct.mu.Lock() + defer ct.mu.Unlock() + delete(ct.entries, provider) +} + +// IsAvailable returns true if the provider is not currently in cooldown. +func (ct *CooldownTracker) IsAvailable(provider string) bool { + ct.mu.RLock() + defer ct.mu.RUnlock() + + e, ok := ct.entries[provider] + if !ok { + return true + } + + dur := cooldownDuration(e.reason, e.count) + return ct.nowFunc().After(e.lastFail.Add(dur)) +} + +// cooldownDuration returns the cooldown period based on reason and failure count. 
+// +// Standard errors (rate_limit, overloaded, timeout, unknown): +// +// count 1: 1 min, count 2: 5 min, count 3: 25 min, count 4+: 1 hour (cap) +// +// Billing errors: +// +// count 1: 5 hours, count 2: 10 hours, count 3: 20 hours, count 4+: 24 hours (cap) +// +// Auth errors: +// +// Always 24 hours (credentials won't fix themselves mid-session) +func cooldownDuration(reason FailoverReason, count int) time.Duration { + if count <= 0 { + return 0 + } + + switch reason { + case FailoverAuth: + return 24 * time.Hour + + case FailoverBilling: + // 5h * 2^(count-1), capped at 24h + base := 5 * time.Hour + d := base + for i := 1; i < count; i++ { + d *= 2 + } + if d > 24*time.Hour { + d = 24 * time.Hour + } + return d + + default: + // Standard: 1min * 5^(count-1), capped at 1h + base := time.Minute + d := base + for i := 1; i < count; i++ { + d *= 5 + } + if d > time.Hour { + d = time.Hour + } + return d + } +} diff --git a/forge-core/llm/cooldown_test.go b/forge-core/llm/cooldown_test.go new file mode 100644 index 0000000..7289c08 --- /dev/null +++ b/forge-core/llm/cooldown_test.go @@ -0,0 +1,240 @@ +package llm + +import ( + "sync" + "testing" + "time" +) + +func TestCooldownTracker_NewProviderAvailable(t *testing.T) { + ct := NewCooldownTracker() + if !ct.IsAvailable("openai") { + t.Error("new provider should be available") + } +} + +func TestCooldownTracker_MarkFailurePutsCooldown(t *testing.T) { + ct := NewCooldownTracker() + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ct.nowFunc = func() time.Time { return now } + + ct.MarkFailure("openai", FailoverRateLimit) + + // Immediately after failure, provider should be unavailable + if ct.IsAvailable("openai") { + t.Error("provider should be in cooldown immediately after failure") + } + + // After 30 seconds, still unavailable (1 min cooldown) + ct.nowFunc = func() time.Time { return now.Add(30 * time.Second) } + if ct.IsAvailable("openai") { + t.Error("provider should still be in cooldown after 30s") + } + 
+ // After 61 seconds, available again + ct.nowFunc = func() time.Time { return now.Add(61 * time.Second) } + if !ct.IsAvailable("openai") { + t.Error("provider should be available after cooldown expires") + } +} + +func TestCooldownTracker_ExponentialBackoff(t *testing.T) { + ct := NewCooldownTracker() + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ct.nowFunc = func() time.Time { return now } + + // First failure: 1 min cooldown + ct.MarkFailure("openai", FailoverRateLimit) + ct.nowFunc = func() time.Time { return now.Add(61 * time.Second) } + if !ct.IsAvailable("openai") { + t.Error("should be available after 1 min cooldown") + } + + // Second failure: 5 min cooldown + now = now.Add(2 * time.Minute) + ct.nowFunc = func() time.Time { return now } + ct.MarkFailure("openai", FailoverRateLimit) + + ct.nowFunc = func() time.Time { return now.Add(4 * time.Minute) } + if ct.IsAvailable("openai") { + t.Error("should still be in cooldown (5 min required, 4 min elapsed)") + } + + ct.nowFunc = func() time.Time { return now.Add(6 * time.Minute) } + if !ct.IsAvailable("openai") { + t.Error("should be available after 5 min cooldown") + } + + // Third failure: 25 min cooldown + now = now.Add(10 * time.Minute) + ct.nowFunc = func() time.Time { return now } + ct.MarkFailure("openai", FailoverRateLimit) + + ct.nowFunc = func() time.Time { return now.Add(24 * time.Minute) } + if ct.IsAvailable("openai") { + t.Error("should still be in cooldown (25 min required)") + } + + ct.nowFunc = func() time.Time { return now.Add(26 * time.Minute) } + if !ct.IsAvailable("openai") { + t.Error("should be available after 25 min cooldown") + } + + // Fourth failure: capped at 1 hour + now = now.Add(30 * time.Minute) + ct.nowFunc = func() time.Time { return now } + ct.MarkFailure("openai", FailoverRateLimit) + + ct.nowFunc = func() time.Time { return now.Add(59 * time.Minute) } + if ct.IsAvailable("openai") { + t.Error("should still be in cooldown (1 hour cap)") + } + + ct.nowFunc = func() 
time.Time { return now.Add(61 * time.Minute) } + if !ct.IsAvailable("openai") { + t.Error("should be available after 1 hour cap") + } +} + +func TestCooldownTracker_AuthAlways24Hours(t *testing.T) { + ct := NewCooldownTracker() + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ct.nowFunc = func() time.Time { return now } + + ct.MarkFailure("openai", FailoverAuth) + + // 23 hours later: still unavailable + ct.nowFunc = func() time.Time { return now.Add(23 * time.Hour) } + if ct.IsAvailable("openai") { + t.Error("auth failure should have 24h cooldown") + } + + // 25 hours later: available + ct.nowFunc = func() time.Time { return now.Add(25 * time.Hour) } + if !ct.IsAvailable("openai") { + t.Error("should be available after 24h auth cooldown") + } +} + +func TestCooldownTracker_BillingBackoff(t *testing.T) { + ct := NewCooldownTracker() + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ct.nowFunc = func() time.Time { return now } + + // First billing failure: 5 hours + ct.MarkFailure("openai", FailoverBilling) + + ct.nowFunc = func() time.Time { return now.Add(4 * time.Hour) } + if ct.IsAvailable("openai") { + t.Error("should be in cooldown (5h required)") + } + + ct.nowFunc = func() time.Time { return now.Add(6 * time.Hour) } + if !ct.IsAvailable("openai") { + t.Error("should be available after 5h billing cooldown") + } +} + +func TestCooldownTracker_MarkSuccessResets(t *testing.T) { + ct := NewCooldownTracker() + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ct.nowFunc = func() time.Time { return now } + + // Build up failures + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverRateLimit) + + // Mark success resets everything + ct.MarkSuccess("openai") + + if !ct.IsAvailable("openai") { + t.Error("should be available after MarkSuccess") + } + + // Next failure should be back to count=1 (1 min) + ct.MarkFailure("openai", FailoverRateLimit) + ct.nowFunc = func() time.Time { return now.Add(61 * time.Second) } + if 
!ct.IsAvailable("openai") { + t.Error("after reset, first failure should have 1 min cooldown") + } +} + +func TestCooldownTracker_IndependentProviders(t *testing.T) { + ct := NewCooldownTracker() + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ct.nowFunc = func() time.Time { return now } + + ct.MarkFailure("openai", FailoverRateLimit) + + if !ct.IsAvailable("anthropic") { + t.Error("different provider should not be affected") + } + if ct.IsAvailable("openai") { + t.Error("failed provider should be in cooldown") + } +} + +func TestCooldownTracker_ConcurrentAccess(t *testing.T) { + ct := NewCooldownTracker() + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ct.nowFunc = func() time.Time { return now } + + var wg sync.WaitGroup + for range 100 { + wg.Add(3) + go func() { + defer wg.Done() + ct.MarkFailure("openai", FailoverRateLimit) + }() + go func() { + defer wg.Done() + ct.IsAvailable("openai") + }() + go func() { + defer wg.Done() + ct.MarkSuccess("openai") + }() + } + wg.Wait() +} + +func TestCooldownDuration(t *testing.T) { + tests := []struct { + reason FailoverReason + count int + want time.Duration + }{ + // Standard + {FailoverRateLimit, 1, time.Minute}, + {FailoverRateLimit, 2, 5 * time.Minute}, + {FailoverRateLimit, 3, 25 * time.Minute}, + {FailoverRateLimit, 4, time.Hour}, + {FailoverRateLimit, 10, time.Hour}, // capped + {FailoverOverloaded, 1, time.Minute}, + {FailoverTimeout, 1, time.Minute}, + {FailoverUnknown, 1, time.Minute}, + + // Billing + {FailoverBilling, 1, 5 * time.Hour}, + {FailoverBilling, 2, 10 * time.Hour}, + {FailoverBilling, 3, 20 * time.Hour}, + {FailoverBilling, 4, 24 * time.Hour}, + {FailoverBilling, 10, 24 * time.Hour}, // capped + + // Auth + {FailoverAuth, 1, 24 * time.Hour}, + {FailoverAuth, 5, 24 * time.Hour}, + + // Zero count + {FailoverRateLimit, 0, 0}, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + got := cooldownDuration(tt.reason, tt.count) + if got != tt.want { + 
t.Errorf("cooldownDuration(%s, %d) = %v, want %v", tt.reason, tt.count, got, tt.want) + } + }) + } +} diff --git a/forge-core/llm/errors.go b/forge-core/llm/errors.go new file mode 100644 index 0000000..ebccea6 --- /dev/null +++ b/forge-core/llm/errors.go @@ -0,0 +1,129 @@ +package llm + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +// FailoverReason describes why a provider failed. +type FailoverReason string + +const ( + FailoverAuth FailoverReason = "auth" // 401/403 + FailoverRateLimit FailoverReason = "rate_limit" // 429 + FailoverBilling FailoverReason = "billing" // 402 + FailoverTimeout FailoverReason = "timeout" // 408/504/deadline + FailoverOverloaded FailoverReason = "overloaded" // 500/502/503/529 + FailoverFormat FailoverReason = "format" // 400 + FailoverUnknown FailoverReason = "unknown" // unknown error (treated as retriable) +) + +// FailoverError wraps an LLM provider error with classification metadata. +type FailoverError struct { + Reason FailoverReason + Provider string + Model string + Status int + Wrapped error +} + +func (e *FailoverError) Error() string { + if e.Status > 0 { + return fmt.Sprintf("%s/%s failover (%s, status %d): %v", + e.Provider, e.Model, e.Reason, e.Status, e.Wrapped) + } + return fmt.Sprintf("%s/%s failover (%s): %v", + e.Provider, e.Model, e.Reason, e.Wrapped) +} + +func (e *FailoverError) Unwrap() error { + return e.Wrapped +} + +// IsRetriable returns true if this error should trigger a fallback attempt. +// Auth and format errors are never retriable. +func (e *FailoverError) IsRetriable() bool { + return e.Reason != FailoverFormat && e.Reason != FailoverAuth && e.Reason != FailoverBilling +} + +// FallbackExhaustedError is returned when all candidates have been tried and failed. 
+type FallbackExhaustedError struct { + Errors []*FailoverError +} + +func (e *FallbackExhaustedError) Error() string { + if len(e.Errors) == 0 { + return "all fallback candidates exhausted" + } + parts := make([]string, len(e.Errors)) + for i, fe := range e.Errors { + parts[i] = fe.Error() + } + return fmt.Sprintf("all fallback candidates exhausted: [%s]", strings.Join(parts, "; ")) +} + +// statusRegex matches provider error patterns like "openai error (status 429): ..." +var statusRegex = regexp.MustCompile(`\(status (\d+)\)`) + +// ClassifyError wraps a raw provider error into a FailoverError with the +// appropriate reason. It extracts HTTP status codes from known provider error +// formats and falls back to message pattern matching. +func ClassifyError(err error, provider, model string) *FailoverError { + fe := &FailoverError{ + Provider: provider, + Model: model, + Wrapped: err, + } + + msg := err.Error() + + // Try to extract HTTP status code from provider error format + if matches := statusRegex.FindStringSubmatch(msg); len(matches) == 2 { + if status, parseErr := strconv.Atoi(matches[1]); parseErr == nil { + fe.Status = status + fe.Reason = reasonFromStatus(status) + return fe + } + } + + // Fallback: message pattern matching + lower := strings.ToLower(msg) + switch { + case strings.Contains(lower, "unauthorized") || strings.Contains(lower, "authentication") || + strings.Contains(lower, "invalid api key") || strings.Contains(lower, "permission denied"): + fe.Reason = FailoverAuth + case strings.Contains(lower, "rate limit") || strings.Contains(lower, "too many requests"): + fe.Reason = FailoverRateLimit + case strings.Contains(lower, "timeout") || strings.Contains(lower, "deadline exceeded") || + strings.Contains(lower, "context deadline"): + fe.Reason = FailoverTimeout + case strings.Contains(lower, "overloaded") || strings.Contains(lower, "service unavailable") || + strings.Contains(lower, "bad gateway"): + fe.Reason = FailoverOverloaded + default: + 
fe.Reason = FailoverUnknown + } + + return fe +} + +func reasonFromStatus(status int) FailoverReason { + switch status { + case 400: + return FailoverFormat + case 401, 403: + return FailoverAuth + case 402: + return FailoverBilling + case 429: + return FailoverRateLimit + case 408, 504: + return FailoverTimeout + case 500, 502, 503, 529: + return FailoverOverloaded + default: + return FailoverUnknown + } +} diff --git a/forge-core/llm/errors_test.go b/forge-core/llm/errors_test.go new file mode 100644 index 0000000..aa29b69 --- /dev/null +++ b/forge-core/llm/errors_test.go @@ -0,0 +1,248 @@ +package llm + +import ( + "fmt" + "testing" +) + +func TestClassifyError_StatusCodes(t *testing.T) { + tests := []struct { + name string + err error + wantReason FailoverReason + wantStatus int + }{ + { + name: "openai 429 rate limit", + err: fmt.Errorf("openai error (status 429): rate limit exceeded"), + wantReason: FailoverRateLimit, + wantStatus: 429, + }, + { + name: "anthropic 503 overloaded", + err: fmt.Errorf("anthropic error (status 503): service unavailable"), + wantReason: FailoverOverloaded, + wantStatus: 503, + }, + { + name: "openai stream 429", + err: fmt.Errorf("openai stream error (status 429): too many requests"), + wantReason: FailoverRateLimit, + wantStatus: 429, + }, + { + name: "anthropic stream 503", + err: fmt.Errorf("anthropic stream error (status 503): overloaded"), + wantReason: FailoverOverloaded, + wantStatus: 503, + }, + { + name: "400 bad request", + err: fmt.Errorf("openai error (status 400): invalid request"), + wantReason: FailoverFormat, + wantStatus: 400, + }, + { + name: "401 unauthorized", + err: fmt.Errorf("openai error (status 401): invalid api key"), + wantReason: FailoverAuth, + wantStatus: 401, + }, + { + name: "403 forbidden", + err: fmt.Errorf("anthropic error (status 403): forbidden"), + wantReason: FailoverAuth, + wantStatus: 403, + }, + { + name: "402 billing", + err: fmt.Errorf("openai error (status 402): payment required"), + 
wantReason: FailoverBilling, + wantStatus: 402, + }, + { + name: "500 internal server error", + err: fmt.Errorf("openai error (status 500): internal error"), + wantReason: FailoverOverloaded, + wantStatus: 500, + }, + { + name: "502 bad gateway", + err: fmt.Errorf("openai error (status 502): bad gateway"), + wantReason: FailoverOverloaded, + wantStatus: 502, + }, + { + name: "529 overloaded", + err: fmt.Errorf("anthropic error (status 529): overloaded"), + wantReason: FailoverOverloaded, + wantStatus: 529, + }, + { + name: "408 timeout", + err: fmt.Errorf("openai error (status 408): request timeout"), + wantReason: FailoverTimeout, + wantStatus: 408, + }, + { + name: "504 gateway timeout", + err: fmt.Errorf("openai error (status 504): gateway timeout"), + wantReason: FailoverTimeout, + wantStatus: 504, + }, + { + name: "unknown status", + err: fmt.Errorf("openai error (status 418): I'm a teapot"), + wantReason: FailoverUnknown, + wantStatus: 418, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fe := ClassifyError(tt.err, "test-provider", "test-model") + if fe.Reason != tt.wantReason { + t.Errorf("reason = %q, want %q", fe.Reason, tt.wantReason) + } + if fe.Status != tt.wantStatus { + t.Errorf("status = %d, want %d", fe.Status, tt.wantStatus) + } + }) + } +} + +func TestClassifyError_MessagePatterns(t *testing.T) { + tests := []struct { + name string + err error + wantReason FailoverReason + }{ + { + name: "timeout message", + err: fmt.Errorf("openai request: context deadline exceeded"), + wantReason: FailoverTimeout, + }, + { + name: "rate limit message", + err: fmt.Errorf("rate limit exceeded, try again later"), + wantReason: FailoverRateLimit, + }, + { + name: "unauthorized message", + err: fmt.Errorf("unauthorized: invalid api key"), + wantReason: FailoverAuth, + }, + { + name: "service unavailable message", + err: fmt.Errorf("service unavailable"), + wantReason: FailoverOverloaded, + }, + { + name: "unknown error", + err: 
fmt.Errorf("something completely unexpected"), + wantReason: FailoverUnknown, + }, + { + name: "deadline exceeded", + err: fmt.Errorf("deadline exceeded while waiting"), + wantReason: FailoverTimeout, + }, + { + name: "too many requests message", + err: fmt.Errorf("too many requests"), + wantReason: FailoverRateLimit, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fe := ClassifyError(tt.err, "test-provider", "test-model") + if fe.Reason != tt.wantReason { + t.Errorf("reason = %q, want %q", fe.Reason, tt.wantReason) + } + if fe.Status != 0 { + t.Errorf("status = %d, want 0 (no status in message)", fe.Status) + } + }) + } +} + +func TestFailoverError_IsRetriable(t *testing.T) { + tests := []struct { + reason FailoverReason + retriable bool + }{ + {FailoverRateLimit, true}, + {FailoverOverloaded, true}, + {FailoverTimeout, true}, + {FailoverUnknown, true}, + {FailoverAuth, false}, + {FailoverFormat, false}, + {FailoverBilling, false}, + } + + for _, tt := range tests { + t.Run(string(tt.reason), func(t *testing.T) { + fe := &FailoverError{Reason: tt.reason} + if fe.IsRetriable() != tt.retriable { + t.Errorf("IsRetriable() = %v, want %v", fe.IsRetriable(), tt.retriable) + } + }) + } +} + +func TestFailoverError_Error(t *testing.T) { + fe := &FailoverError{ + Reason: FailoverRateLimit, + Provider: "openai", + Model: "gpt-4o", + Status: 429, + Wrapped: fmt.Errorf("rate limit exceeded"), + } + got := fe.Error() + want := "openai/gpt-4o failover (rate_limit, status 429): rate limit exceeded" + if got != want { + t.Errorf("Error() = %q, want %q", got, want) + } + + // Without status + fe2 := &FailoverError{ + Reason: FailoverTimeout, + Provider: "anthropic", + Model: "claude", + Wrapped: fmt.Errorf("deadline exceeded"), + } + got2 := fe2.Error() + want2 := "anthropic/claude failover (timeout): deadline exceeded" + if got2 != want2 { + t.Errorf("Error() = %q, want %q", got2, want2) + } +} + +func TestFailoverError_Unwrap(t *testing.T) { + 
inner := fmt.Errorf("original error") + fe := &FailoverError{Wrapped: inner} + if fe.Unwrap() != inner { + t.Error("Unwrap() did not return the wrapped error") + } +} + +func TestFallbackExhaustedError(t *testing.T) { + // Empty errors + e := &FallbackExhaustedError{} + if e.Error() != "all fallback candidates exhausted" { + t.Errorf("unexpected error: %s", e.Error()) + } + + // With errors + e2 := &FallbackExhaustedError{ + Errors: []*FailoverError{ + {Reason: FailoverRateLimit, Provider: "openai", Model: "gpt-4o", Status: 429, Wrapped: fmt.Errorf("rate limited")}, + {Reason: FailoverOverloaded, Provider: "anthropic", Model: "claude", Status: 503, Wrapped: fmt.Errorf("overloaded")}, + }, + } + got := e2.Error() + if got == "" { + t.Error("expected non-empty error message") + } +} diff --git a/forge-core/llm/fallback.go b/forge-core/llm/fallback.go new file mode 100644 index 0000000..5e31fc0 --- /dev/null +++ b/forge-core/llm/fallback.go @@ -0,0 +1,130 @@ +package llm + +import ( + "context" + "fmt" +) + +// FallbackCandidate pairs a provider/model label with its LLM client. +type FallbackCandidate struct { + Provider string + Model string + Client Client +} + +// FallbackChain implements the Client interface by trying multiple LLM +// providers in order. When the primary provider fails with a retriable error +// (429, 503, timeouts), the chain moves to the next candidate. Non-retriable +// errors (400 bad request, 401 auth) abort immediately. +// +// When there is only one candidate, FallbackChain delegates directly without +// error classification to preserve exact current behavior. +type FallbackChain struct { + candidates []FallbackCandidate + cooldown *CooldownTracker +} + +// NewFallbackChain creates a new fallback chain from the given candidates. +// At least one candidate is required. 
+func NewFallbackChain(candidates []FallbackCandidate) *FallbackChain { + return &FallbackChain{ + candidates: candidates, + cooldown: NewCooldownTracker(), + } +} + +// Chat tries each candidate in order until one succeeds or all are exhausted. +func (fc *FallbackChain) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { + // Single-candidate optimization: delegate directly, no classification. + if len(fc.candidates) == 1 { + return fc.candidates[0].Client.Chat(ctx, req) + } + + var errors []*FailoverError + + for _, c := range fc.candidates { + // Check context cancellation + if err := ctx.Err(); err != nil { + return nil, err + } + + // Skip providers in cooldown + if !fc.cooldown.IsAvailable(c.Provider) { + continue + } + + resp, err := c.Client.Chat(ctx, req) + if err == nil { + fc.cooldown.MarkSuccess(c.Provider) + return resp, nil + } + + fe := ClassifyError(err, c.Provider, c.Model) + errors = append(errors, fe) + + // Non-retriable errors abort immediately + if !fe.IsRetriable() { + return nil, fe + } + + // Retriable — mark failure and try next + fc.cooldown.MarkFailure(c.Provider, fe.Reason) + } + + if len(errors) == 0 { + return nil, fmt.Errorf("all fallback candidates in cooldown") + } + return nil, &FallbackExhaustedError{Errors: errors} +} + +// ChatStream tries each candidate in order for streaming requests. +func (fc *FallbackChain) ChatStream(ctx context.Context, req *ChatRequest) (<-chan StreamDelta, error) { + // Single-candidate optimization: delegate directly, no classification. 
+ if len(fc.candidates) == 1 { + return fc.candidates[0].Client.ChatStream(ctx, req) + } + + var errors []*FailoverError + + for _, c := range fc.candidates { + // Check context cancellation + if err := ctx.Err(); err != nil { + return nil, err + } + + // Skip providers in cooldown + if !fc.cooldown.IsAvailable(c.Provider) { + continue + } + + ch, err := c.Client.ChatStream(ctx, req) + if err == nil { + fc.cooldown.MarkSuccess(c.Provider) + return ch, nil + } + + fe := ClassifyError(err, c.Provider, c.Model) + errors = append(errors, fe) + + // Non-retriable errors abort immediately + if !fe.IsRetriable() { + return nil, fe + } + + // Retriable — mark failure and try next + fc.cooldown.MarkFailure(c.Provider, fe.Reason) + } + + if len(errors) == 0 { + return nil, fmt.Errorf("all fallback candidates in cooldown") + } + return nil, &FallbackExhaustedError{Errors: errors} +} + +// ModelID returns the primary candidate's model identifier. +func (fc *FallbackChain) ModelID() string { + if len(fc.candidates) > 0 { + return fc.candidates[0].Client.ModelID() + } + return "" +} diff --git a/forge-core/llm/fallback_test.go b/forge-core/llm/fallback_test.go new file mode 100644 index 0000000..cf15644 --- /dev/null +++ b/forge-core/llm/fallback_test.go @@ -0,0 +1,402 @@ +package llm + +import ( + "context" + "fmt" + "testing" +) + +// mockClient is a test double for llm.Client. 
+type mockClient struct { + chatFunc func(ctx context.Context, req *ChatRequest) (*ChatResponse, error) + chatStreamFunc func(ctx context.Context, req *ChatRequest) (<-chan StreamDelta, error) + modelID string +} + +func (m *mockClient) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { + return m.chatFunc(ctx, req) +} + +func (m *mockClient) ChatStream(ctx context.Context, req *ChatRequest) (<-chan StreamDelta, error) { + if m.chatStreamFunc != nil { + return m.chatStreamFunc(ctx, req) + } + return nil, fmt.Errorf("not implemented") +} + +func (m *mockClient) ModelID() string { + return m.modelID +} + +func okClient(model string) *mockClient { + return &mockClient{ + modelID: model, + chatFunc: func(_ context.Context, _ *ChatRequest) (*ChatResponse, error) { + return &ChatResponse{Message: ChatMessage{Content: "ok from " + model}}, nil + }, + chatStreamFunc: func(_ context.Context, _ *ChatRequest) (<-chan StreamDelta, error) { + ch := make(chan StreamDelta, 1) + ch <- StreamDelta{Content: "ok from " + model, Done: true} + close(ch) + return ch, nil + }, + } +} + +func errorClient(model string, err error) *mockClient { + return &mockClient{ + modelID: model, + chatFunc: func(_ context.Context, _ *ChatRequest) (*ChatResponse, error) { + return nil, err + }, + chatStreamFunc: func(_ context.Context, _ *ChatRequest) (<-chan StreamDelta, error) { + return nil, err + }, + } +} + +func TestFallbackChain_SingleCandidate_Success(t *testing.T) { + c := okClient("gpt-4o") + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: c}, + }) + + resp, err := fc.Chat(context.Background(), &ChatRequest{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Message.Content != "ok from gpt-4o" { + t.Errorf("unexpected response: %s", resp.Message.Content) + } +} + +func TestFallbackChain_SingleCandidate_PassthroughError(t *testing.T) { + rawErr := fmt.Errorf("openai error (status 429): rate limited") + c := 
errorClient("gpt-4o", rawErr) + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: c}, + }) + + _, err := fc.Chat(context.Background(), &ChatRequest{}) + if err == nil { + t.Fatal("expected error") + } + // Single-candidate passes through raw error, not classified + if err != rawErr { + t.Errorf("expected raw error passthrough, got: %v", err) + } +} + +func TestFallbackChain_PrimarySuccess(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: okClient("gpt-4o")}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + resp, err := fc.Chat(context.Background(), &ChatRequest{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Message.Content != "ok from gpt-4o" { + t.Errorf("expected primary response, got: %s", resp.Message.Content) + } +} + +func TestFallbackChain_FallbackOn429(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: errorClient("gpt-4o", + fmt.Errorf("openai error (status 429): rate limited"))}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + resp, err := fc.Chat(context.Background(), &ChatRequest{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Message.Content != "ok from claude" { + t.Errorf("expected fallback response, got: %s", resp.Message.Content) + } +} + +func TestFallbackChain_FallbackOn503(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: errorClient("gpt-4o", + fmt.Errorf("openai error (status 503): service unavailable"))}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + resp, err := fc.Chat(context.Background(), &ChatRequest{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Message.Content != "ok from claude" { + t.Errorf("expected fallback response, got: %s", 
resp.Message.Content) + } +} + +func TestFallbackChain_AbortOnAuthError(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: errorClient("gpt-4o", + fmt.Errorf("openai error (status 401): invalid api key"))}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + _, err := fc.Chat(context.Background(), &ChatRequest{}) + if err == nil { + t.Fatal("expected error for auth failure") + } + fe, ok := err.(*FailoverError) + if !ok { + t.Fatalf("expected FailoverError, got %T", err) + } + if fe.Reason != FailoverAuth { + t.Errorf("reason = %q, want %q", fe.Reason, FailoverAuth) + } +} + +func TestFallbackChain_AbortOnFormatError(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: errorClient("gpt-4o", + fmt.Errorf("openai error (status 400): bad request"))}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + _, err := fc.Chat(context.Background(), &ChatRequest{}) + if err == nil { + t.Fatal("expected error for format failure") + } + fe, ok := err.(*FailoverError) + if !ok { + t.Fatalf("expected FailoverError, got %T", err) + } + if fe.Reason != FailoverFormat { + t.Errorf("reason = %q, want %q", fe.Reason, FailoverFormat) + } +} + +func TestFallbackChain_AllFail(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: errorClient("gpt-4o", + fmt.Errorf("openai error (status 429): rate limited"))}, + {Provider: "anthropic", Model: "claude", Client: errorClient("claude", + fmt.Errorf("anthropic error (status 503): overloaded"))}, + }) + + _, err := fc.Chat(context.Background(), &ChatRequest{}) + if err == nil { + t.Fatal("expected error when all candidates fail") + } + exhausted, ok := err.(*FallbackExhaustedError) + if !ok { + t.Fatalf("expected FallbackExhaustedError, got %T: %v", err, err) + } + if len(exhausted.Errors) != 2 { + t.Errorf("expected 2 
errors, got %d", len(exhausted.Errors)) + } +} + +func TestFallbackChain_ContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: okClient("gpt-4o")}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + _, err := fc.Chat(ctx, &ChatRequest{}) + if err == nil { + t.Fatal("expected context error") + } + if err != context.Canceled { + t.Errorf("expected context.Canceled, got: %v", err) + } +} + +func TestFallbackChain_CooldownSkip(t *testing.T) { + callCount := 0 + slowClient := &mockClient{ + modelID: "gpt-4o", + chatFunc: func(_ context.Context, _ *ChatRequest) (*ChatResponse, error) { + callCount++ + return nil, fmt.Errorf("openai error (status 429): rate limited") + }, + } + + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: slowClient}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + // First call: primary fails, fallback succeeds + resp, err := fc.Chat(context.Background(), &ChatRequest{}) + if err != nil { + t.Fatalf("first call: unexpected error: %v", err) + } + if resp.Message.Content != "ok from claude" { + t.Errorf("first call: expected fallback, got: %s", resp.Message.Content) + } + if callCount != 1 { + t.Errorf("first call: expected 1 primary call, got %d", callCount) + } + + // Second call: primary should be in cooldown, skipped entirely + resp, err = fc.Chat(context.Background(), &ChatRequest{}) + if err != nil { + t.Fatalf("second call: unexpected error: %v", err) + } + if resp.Message.Content != "ok from claude" { + t.Errorf("second call: expected fallback, got: %s", resp.Message.Content) + } + if callCount != 1 { + t.Errorf("second call: primary should have been skipped, call count = %d", callCount) + } +} + +func TestFallbackChain_ModelID(t *testing.T) { + fc := 
NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: okClient("gpt-4o")}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + if fc.ModelID() != "gpt-4o" { + t.Errorf("ModelID() = %q, want %q", fc.ModelID(), "gpt-4o") + } +} + +func TestFallbackChain_ModelID_Empty(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{}) + if fc.ModelID() != "" { + t.Errorf("ModelID() = %q, want empty", fc.ModelID()) + } +} + +func TestFallbackChain_ChatStream_FallbackOn429(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: errorClient("gpt-4o", + fmt.Errorf("openai stream error (status 429): rate limited"))}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + ch, err := fc.ChatStream(context.Background(), &ChatRequest{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + delta := <-ch + if delta.Content != "ok from claude" { + t.Errorf("expected fallback stream, got: %s", delta.Content) + } +} + +func TestFallbackChain_ChatStream_SingleCandidate(t *testing.T) { + rawErr := fmt.Errorf("openai stream error (status 429): rate limited") + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: errorClient("gpt-4o", rawErr)}, + }) + + _, err := fc.ChatStream(context.Background(), &ChatRequest{}) + if err == nil { + t.Fatal("expected error") + } + // Single candidate: raw error passthrough + if err != rawErr { + t.Errorf("expected raw error passthrough, got: %v", err) + } +} + +func TestFallbackChain_ChatStream_AbortOnAuth(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: errorClient("gpt-4o", + fmt.Errorf("openai stream error (status 401): unauthorized"))}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + _, err := fc.ChatStream(context.Background(), &ChatRequest{}) + if err == nil { + 
t.Fatal("expected error for auth failure") + } + fe, ok := err.(*FailoverError) + if !ok { + t.Fatalf("expected FailoverError, got %T", err) + } + if fe.Reason != FailoverAuth { + t.Errorf("reason = %q, want %q", fe.Reason, FailoverAuth) + } +} + +func TestFallbackChain_ChatStream_AllFail(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: errorClient("gpt-4o", + fmt.Errorf("openai stream error (status 429): rate limited"))}, + {Provider: "anthropic", Model: "claude", Client: errorClient("claude", + fmt.Errorf("anthropic stream error (status 503): overloaded"))}, + }) + + _, err := fc.ChatStream(context.Background(), &ChatRequest{}) + if err == nil { + t.Fatal("expected error when all candidates fail") + } + _, ok := err.(*FallbackExhaustedError) + if !ok { + t.Fatalf("expected FallbackExhaustedError, got %T", err) + } +} + +func TestFallbackChain_AllInCooldown(t *testing.T) { + fc := NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: errorClient("gpt-4o", + fmt.Errorf("openai error (status 429): rate limited"))}, + {Provider: "anthropic", Model: "claude", Client: errorClient("claude", + fmt.Errorf("anthropic error (status 503): overloaded"))}, + }) + + // First call exhausts all candidates and puts them in cooldown + _, _ = fc.Chat(context.Background(), &ChatRequest{}) + + // Second call: all in cooldown, no candidates tried + _, err := fc.Chat(context.Background(), &ChatRequest{}) + if err == nil { + t.Fatal("expected error when all in cooldown") + } +} + +func TestFallbackChain_SuccessResetsCooldown(t *testing.T) { + callCount := 0 + flaky := &mockClient{ + modelID: "gpt-4o", + chatFunc: func(_ context.Context, _ *ChatRequest) (*ChatResponse, error) { + callCount++ + if callCount == 1 { + return nil, fmt.Errorf("openai error (status 503): temporary") + } + return &ChatResponse{Message: ChatMessage{Content: "recovered"}}, nil + }, + } + + fc := 
NewFallbackChain([]FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", Client: flaky}, + {Provider: "anthropic", Model: "claude", Client: okClient("claude")}, + }) + + // First: primary fails, fallback succeeds + resp, err := fc.Chat(context.Background(), &ChatRequest{}) + if err != nil { + t.Fatalf("first call: %v", err) + } + if resp.Message.Content != "ok from claude" { + t.Errorf("first: expected fallback, got %s", resp.Message.Content) + } + + // Manually reset cooldown for primary (simulates time passing) + fc.cooldown.MarkSuccess("openai") + + // Second: primary succeeds now + resp, err = fc.Chat(context.Background(), &ChatRequest{}) + if err != nil { + t.Fatalf("second call: %v", err) + } + if resp.Message.Content != "recovered" { + t.Errorf("second: expected recovered, got %s", resp.Message.Content) + } +} diff --git a/forge-core/llm/oauth/flow.go b/forge-core/llm/oauth/flow.go new file mode 100644 index 0000000..edf16c3 --- /dev/null +++ b/forge-core/llm/oauth/flow.go @@ -0,0 +1,161 @@ +package oauth + +import ( + "context" + "fmt" + "net/url" + "os/exec" + "runtime" + "time" +) + +// ProviderConfig holds the OAuth configuration for a provider. +type ProviderConfig struct { + AuthURL string + TokenURL string + ClientID string + Scopes string + RedirectURI string + BaseURL string // API base URL to use with the obtained token + ExtraParams map[string]string // additional query params for the auth URL +} + +// OpenAIConfig returns the OAuth configuration for OpenAI. +// Uses the same public client ID and endpoints as the official Codex CLI. +// ChatGPT OAuth tokens are scoped to the ChatGPT backend API, not the +// standard OpenAI API (api.openai.com). The base URL is set accordingly. 
+func OpenAIConfig() ProviderConfig { + return ProviderConfig{ + AuthURL: "https://auth.openai.com/oauth/authorize", + TokenURL: "https://auth.openai.com/oauth/token", + ClientID: "app_EMoamEEZ73f0CkXaXp7hrann", + Scopes: "openid profile email offline_access", + RedirectURI: "http://localhost:1455/auth/callback", + BaseURL: "https://chatgpt.com/backend-api/codex", + ExtraParams: map[string]string{ + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + }, + } +} + +// Flow orchestrates the OAuth authorization code flow with PKCE. +type Flow struct { + Config ProviderConfig + Timeout time.Duration // default: 2 minutes +} + +// NewFlow creates a new OAuth flow with the given provider config. +func NewFlow(config ProviderConfig) *Flow { + return &Flow{ + Config: config, + Timeout: 2 * time.Minute, + } +} + +// Execute runs the full OAuth flow: +// 1. Generate PKCE params and state +// 2. Start local callback server +// 3. Open browser to authorization URL +// 4. Wait for authorization code +// 5. Exchange code for tokens +// 6. 
Store credentials +func (f *Flow) Execute(ctx context.Context, provider string) (*Token, error) { + // Generate PKCE + pkce, err := GeneratePKCE() + if err != nil { + return nil, fmt.Errorf("generating PKCE: %w", err) + } + + state, err := GenerateState() + if err != nil { + return nil, fmt.Errorf("generating state: %w", err) + } + + // Start callback server on port 1455 (matching redirect_uri) + callbackServer := NewCallbackServer(1455) + if err := callbackServer.Start(); err != nil { + return nil, fmt.Errorf("starting callback server: %w", err) + } + defer callbackServer.Stop() + + // Build authorization URL + authURL := f.buildAuthURL(pkce, state) + + // Open browser + if err := openBrowser(authURL); err != nil { + return nil, fmt.Errorf("opening browser: %w\n\nPlease open this URL manually:\n%s", err, authURL) + } + + // Wait for code + timeout := f.Timeout + if timeout == 0 { + timeout = 2 * time.Minute + } + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + result, err := callbackServer.WaitForCode(waitCtx) + if err != nil { + return nil, err + } + + // Verify state + if result.State != state { + return nil, fmt.Errorf("state mismatch: possible CSRF attack") + } + + // Exchange code for tokens + token, err := ExchangeCode( + f.Config.TokenURL, + f.Config.ClientID, + result.Code, + f.Config.RedirectURI, + pkce.Verifier, + ) + if err != nil { + return nil, fmt.Errorf("exchanging code: %w", err) + } + + // Persist the API base URL from config so the correct endpoint is used at runtime + token.BaseURL = f.Config.BaseURL + + // Store credentials + if err := SaveCredentials(provider, token); err != nil { + return nil, fmt.Errorf("saving credentials: %w", err) + } + + return token, nil +} + +func (f *Flow) buildAuthURL(pkce *PKCEParams, state string) string { + params := url.Values{ + "response_type": {"code"}, + "client_id": {f.Config.ClientID}, + "redirect_uri": {f.Config.RedirectURI}, + "scope": {f.Config.Scopes}, + "state": {state}, + 
"code_challenge": {pkce.Challenge}, + "code_challenge_method": {pkce.Method}, + } + for k, v := range f.Config.ExtraParams { + params.Set(k, v) + } + return f.Config.AuthURL + "?" + params.Encode() +} + +// openBrowser opens the given URL in the default browser. +func openBrowser(url string) error { + var cmd *exec.Cmd + switch runtime.GOOS { + case "darwin": + cmd = exec.Command("open", url) + case "linux": + cmd = exec.Command("xdg-open", url) + case "windows": + cmd = exec.Command("cmd", "/c", "start", url) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } + return cmd.Start() +} diff --git a/forge-core/llm/oauth/pkce.go b/forge-core/llm/oauth/pkce.go new file mode 100644 index 0000000..bbfbf8b --- /dev/null +++ b/forge-core/llm/oauth/pkce.go @@ -0,0 +1,42 @@ +package oauth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" +) + +// PKCEParams holds the PKCE code verifier and challenge for OAuth flows. +type PKCEParams struct { + Verifier string + Challenge string + Method string // always "S256" +} + +// GeneratePKCE creates a new PKCE code verifier (32 random bytes, base64url-encoded) +// and its corresponding S256 challenge. +func GeneratePKCE() (*PKCEParams, error) { + verifierBytes := make([]byte, 32) + if _, err := rand.Read(verifierBytes); err != nil { + return nil, err + } + verifier := base64.RawURLEncoding.EncodeToString(verifierBytes) + + hash := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(hash[:]) + + return &PKCEParams{ + Verifier: verifier, + Challenge: challenge, + Method: "S256", + }, nil +} + +// GenerateState creates a random state parameter for OAuth flows. 
+func GenerateState() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} diff --git a/forge-core/llm/oauth/pkce_test.go b/forge-core/llm/oauth/pkce_test.go new file mode 100644 index 0000000..cd9af04 --- /dev/null +++ b/forge-core/llm/oauth/pkce_test.go @@ -0,0 +1,53 @@ +package oauth + +import ( + "crypto/sha256" + "encoding/base64" + "testing" +) + +func TestGeneratePKCE(t *testing.T) { + pkce, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + // Verifier should be base64url-encoded 32 bytes = 43 chars + if len(pkce.Verifier) != 43 { + t.Errorf("expected verifier length 43, got %d", len(pkce.Verifier)) + } + + // Method must be S256 + if pkce.Method != "S256" { + t.Errorf("expected method S256, got %s", pkce.Method) + } + + // Verify the challenge matches the verifier + hash := sha256.Sum256([]byte(pkce.Verifier)) + expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) + if pkce.Challenge != expectedChallenge { + t.Errorf("challenge mismatch: got %s, want %s", pkce.Challenge, expectedChallenge) + } +} + +func TestGeneratePKCE_Uniqueness(t *testing.T) { + p1, _ := GeneratePKCE() + p2, _ := GeneratePKCE() + if p1.Verifier == p2.Verifier { + t.Error("two PKCE params should not have the same verifier") + } +} + +func TestGenerateState(t *testing.T) { + state, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState() error: %v", err) + } + if state == "" { + t.Error("state should not be empty") + } + // 16 bytes base64url = 22 chars + if len(state) != 22 { + t.Errorf("expected state length 22, got %d", len(state)) + } +} diff --git a/forge-core/llm/oauth/server.go b/forge-core/llm/oauth/server.go new file mode 100644 index 0000000..3ef95b7 --- /dev/null +++ b/forge-core/llm/oauth/server.go @@ -0,0 +1,102 @@ +package oauth + +import ( + "context" + "fmt" + "net" + "net/http" + "sync" +) + +// 
CallbackResult holds the result from the OAuth callback. +type CallbackResult struct { + Code string + State string + Error string +} + +// CallbackServer is a local HTTP server that receives the OAuth authorization code. +type CallbackServer struct { + port int + resultCh chan CallbackResult + server *http.Server + mu sync.Mutex +} + +// NewCallbackServer creates a callback server on the given port. +func NewCallbackServer(port int) *CallbackServer { + return &CallbackServer{ + port: port, + resultCh: make(chan CallbackResult, 1), + } +} + +// Start starts the callback server and returns immediately. +func (s *CallbackServer) Start() error { + mux := http.NewServeMux() + mux.HandleFunc("/auth/callback", s.handleCallback) + + s.mu.Lock() + s.server = &http.Server{ + Handler: mux, + } + s.mu.Unlock() + + ln, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port)) + if err != nil { + return fmt.Errorf("starting callback server on port %d: %w", s.port, err) + } + + go func() { + if err := s.server.Serve(ln); err != nil && err != http.ErrServerClosed { + s.resultCh <- CallbackResult{Error: err.Error()} + } + }() + + return nil +} + +// WaitForCode blocks until an authorization code is received or the context expires. +func (s *CallbackServer) WaitForCode(ctx context.Context) (CallbackResult, error) { + select { + case result := <-s.resultCh: + if result.Error != "" { + return result, fmt.Errorf("oauth callback error: %s", result.Error) + } + return result, nil + case <-ctx.Done(): + return CallbackResult{}, fmt.Errorf("timed out waiting for authorization") + } +} + +// Stop shuts down the callback server. 
+func (s *CallbackServer) Stop() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.server != nil {
+		_ = s.server.Close()
+	}
+}
+
+func (s *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) {
+	query := r.URL.Query()
+
+	if errMsg := query.Get("error"); errMsg != "" {
+		desc := query.Get("error_description")
+		s.resultCh <- CallbackResult{Error: fmt.Sprintf("%s: %s", errMsg, desc)}
+		_, _ = fmt.Fprintf(w, "%s\nYou can close this tab.\n", desc)
+		return
+	}
+
+	code := query.Get("code")
+	state := query.Get("state")
+
+	if code == "" {
+		s.resultCh <- CallbackResult{Error: "no code in callback"}
+		_, _ = fmt.Fprint(w, "No authorization code received.\n")
+		return
+	}
+
+	s.resultCh <- CallbackResult{Code: code, State: state}
+	_, _ = fmt.Fprint(w, "You can close this tab and return to the terminal.\n
") +} diff --git a/forge-core/llm/oauth/store.go b/forge-core/llm/oauth/store.go new file mode 100644 index 0000000..7ba3e02 --- /dev/null +++ b/forge-core/llm/oauth/store.go @@ -0,0 +1,80 @@ +package oauth + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// DefaultCredentialsDir returns the default directory for OAuth credentials. +func DefaultCredentialsDir() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("getting home directory: %w", err) + } + return filepath.Join(home, ".forge", "credentials"), nil +} + +// SaveCredentials stores OAuth token data to disk with restricted permissions. +func SaveCredentials(provider string, token *Token) error { + dir, err := DefaultCredentialsDir() + if err != nil { + return err + } + + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("creating credentials directory: %w", err) + } + + data, err := json.MarshalIndent(token, "", " ") + if err != nil { + return fmt.Errorf("marshaling token: %w", err) + } + + path := filepath.Join(dir, provider+".json") + if err := os.WriteFile(path, data, 0o600); err != nil { + return fmt.Errorf("writing credentials: %w", err) + } + + return nil +} + +// LoadCredentials loads OAuth token data from disk. +func LoadCredentials(provider string) (*Token, error) { + dir, err := DefaultCredentialsDir() + if err != nil { + return nil, err + } + + path := filepath.Join(dir, provider+".json") + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil // no credentials stored + } + return nil, fmt.Errorf("reading credentials: %w", err) + } + + var token Token + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("parsing credentials: %w", err) + } + + return &token, nil +} + +// DeleteCredentials removes stored OAuth credentials for a provider. 
+func DeleteCredentials(provider string) error { + dir, err := DefaultCredentialsDir() + if err != nil { + return err + } + + path := filepath.Join(dir, provider+".json") + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("removing credentials: %w", err) + } + + return nil +} diff --git a/forge-core/llm/oauth/store_test.go b/forge-core/llm/oauth/store_test.go new file mode 100644 index 0000000..84b8269 --- /dev/null +++ b/forge-core/llm/oauth/store_test.go @@ -0,0 +1,85 @@ +package oauth + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestSaveAndLoadCredentials(t *testing.T) { + // Use a temp directory + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer func() { _ = os.Setenv("HOME", origHome) }() + + token := &Token{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + // Save + err := SaveCredentials("testprovider", token) + if err != nil { + t.Fatalf("SaveCredentials() error: %v", err) + } + + // Verify file exists with correct permissions + credPath := filepath.Join(tmpDir, ".forge", "credentials", "testprovider.json") + info, err := os.Stat(credPath) + if err != nil { + t.Fatalf("credential file not found: %v", err) + } + if info.Mode().Perm() != 0o600 { + t.Errorf("expected permissions 0600, got %o", info.Mode().Perm()) + } + + // Load + loaded, err := LoadCredentials("testprovider") + if err != nil { + t.Fatalf("LoadCredentials() error: %v", err) + } + if loaded == nil { + t.Fatal("expected non-nil token") + } + if loaded.AccessToken != "test-access-token" { + t.Errorf("expected access token 'test-access-token', got %q", loaded.AccessToken) + } + if loaded.RefreshToken != "test-refresh-token" { + t.Errorf("expected refresh token 'test-refresh-token', got %q", loaded.RefreshToken) + } +} + +func TestLoadCredentials_NotFound(t *testing.T) { + tmpDir := t.TempDir() + 
t.Setenv("HOME", tmpDir) + + token, err := LoadCredentials("nonexistent") + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if token != nil { + t.Error("expected nil token for nonexistent provider") + } +} + +func TestDeleteCredentials(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + + token := &Token{AccessToken: "delete-me"} + _ = SaveCredentials("deletable", token) + + err := DeleteCredentials("deletable") + if err != nil { + t.Fatalf("DeleteCredentials() error: %v", err) + } + + loaded, _ := LoadCredentials("deletable") + if loaded != nil { + t.Error("expected nil after deletion") + } +} diff --git a/forge-core/llm/oauth/token.go b/forge-core/llm/oauth/token.go new file mode 100644 index 0000000..7430e96 --- /dev/null +++ b/forge-core/llm/oauth/token.go @@ -0,0 +1,111 @@ +package oauth + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// Token holds the OAuth token data. +type Token struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + ExpiresAt time.Time `json:"expires_at"` + Scope string `json:"scope,omitempty"` + BaseURL string `json:"base_url,omitempty"` // API base URL for this token +} + +// IsExpired returns true if the token has expired or will expire within the +// given buffer duration (default 5 minutes). +func (t *Token) IsExpired() bool { + return t.IsExpiredWithBuffer(5 * time.Minute) +} + +// IsExpiredWithBuffer returns true if the token expires within the buffer. +func (t *Token) IsExpiredWithBuffer(buffer time.Duration) bool { + if t.ExpiresAt.IsZero() { + return true + } + return time.Now().Add(buffer).After(t.ExpiresAt) +} + +// tokenResponse is the raw response from the OAuth token endpoint. 
+type tokenResponse struct {
+	AccessToken  string `json:"access_token"`
+	RefreshToken string `json:"refresh_token,omitempty"`
+	TokenType    string `json:"token_type"`
+	ExpiresIn    int    `json:"expires_in"`
+	Scope        string `json:"scope,omitempty"`
+	Error        string `json:"error,omitempty"`
+	ErrorDesc    string `json:"error_description,omitempty"`
+}
+
+// ExchangeCode exchanges an authorization code for tokens.
+func ExchangeCode(tokenURL, clientID, code, redirectURI, codeVerifier string) (*Token, error) {
+	data := url.Values{
+		"grant_type":    {"authorization_code"},
+		"client_id":     {clientID},
+		"code":          {code},
+		"redirect_uri":  {redirectURI},
+		"code_verifier": {codeVerifier},
+	}
+
+	return doTokenRequest(tokenURL, data)
+}
+
+// RefreshToken exchanges a refresh token for new access and refresh tokens.
+func RefreshToken(tokenURL, clientID, refreshToken string) (*Token, error) {
+	data := url.Values{
+		"grant_type":    {"refresh_token"},
+		"client_id":     {clientID},
+		"refresh_token": {refreshToken},
+	}
+
+	return doTokenRequest(tokenURL, data)
+}
+
+func doTokenRequest(tokenURL string, data url.Values) (*Token, error) {
+	resp, err := http.Post(tokenURL, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) //nolint:gosec
+	if err != nil {
+		return nil, fmt.Errorf("token request failed: %w", err)
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("reading token response: %w", err)
+	}
+
+	var tr tokenResponse
+	if err := json.Unmarshal(body, &tr); err != nil {
+		return nil, fmt.Errorf("parsing token response: %w", err)
+	}
+
+	if tr.Error != "" {
+		return nil, fmt.Errorf("oauth error: %s — %s", tr.Error, tr.ErrorDesc)
+	}
+
+	if tr.AccessToken == "" {
+		return nil, fmt.Errorf("no access token in response")
+	}
+
+	token := &Token{
+		AccessToken:  tr.AccessToken,
+		RefreshToken: tr.RefreshToken,
+		TokenType:    tr.TokenType,
+		ExpiresIn:    tr.ExpiresIn,
+		Scope:        tr.Scope,
+	}
+
+	if 
tr.ExpiresIn > 0 { + token.ExpiresAt = time.Now().Add(time.Duration(tr.ExpiresIn) * time.Second) + } + + return token, nil +} diff --git a/forge-core/llm/oauth/token_test.go b/forge-core/llm/oauth/token_test.go new file mode 100644 index 0000000..186f633 --- /dev/null +++ b/forge-core/llm/oauth/token_test.go @@ -0,0 +1,60 @@ +package oauth + +import ( + "testing" + "time" +) + +func TestToken_IsExpired(t *testing.T) { + tests := []struct { + name string + token Token + expired bool + }{ + { + name: "zero expiry is expired", + token: Token{}, + expired: true, + }, + { + name: "future expiry is not expired", + token: Token{ExpiresAt: time.Now().Add(10 * time.Minute)}, + expired: false, + }, + { + name: "past expiry is expired", + token: Token{ExpiresAt: time.Now().Add(-1 * time.Minute)}, + expired: true, + }, + { + name: "within 5min buffer is expired", + token: Token{ExpiresAt: time.Now().Add(3 * time.Minute)}, + expired: true, + }, + { + name: "just outside 5min buffer is not expired", + token: Token{ExpiresAt: time.Now().Add(6 * time.Minute)}, + expired: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.token.IsExpired() + if got != tt.expired { + t.Errorf("IsExpired() = %v, want %v", got, tt.expired) + } + }) + } +} + +func TestToken_IsExpiredWithBuffer(t *testing.T) { + token := Token{ExpiresAt: time.Now().Add(30 * time.Second)} + + if !token.IsExpiredWithBuffer(1 * time.Minute) { + t.Error("should be expired with 1 minute buffer") + } + if token.IsExpiredWithBuffer(10 * time.Second) { + t.Error("should not be expired with 10 second buffer") + } +} diff --git a/forge-core/llm/providers/oauth_client.go b/forge-core/llm/providers/oauth_client.go new file mode 100644 index 0000000..6105bda --- /dev/null +++ b/forge-core/llm/providers/oauth_client.go @@ -0,0 +1,103 @@ +package providers + +import ( + "context" + "fmt" + "sync" + + "github.com/initializ/forge/forge-core/llm" + 
"github.com/initializ/forge/forge-core/llm/oauth" +) + +// OAuthClient wraps a ResponsesClient with automatic OAuth token refresh. +// It implements llm.Client and transparently refreshes expired tokens +// before each API call. ChatGPT OAuth tokens are scoped to the Responses API, +// not the Chat Completions API, so this client uses the Responses API format. +type OAuthClient struct { + inner *ResponsesClient + provider string + config oauth.ProviderConfig + mu sync.Mutex +} + +// NewOAuthClient creates a new OAuth-aware client that uses the Responses API. +// The token is loaded from stored credentials and refreshed automatically. +func NewOAuthClient(cfg llm.ClientConfig, provider string, oauthConfig oauth.ProviderConfig) *OAuthClient { + inner := NewResponsesClient(cfg) + // ChatGPT Codex backend requires store=false + inner.disableStore = true + return &OAuthClient{ + inner: inner, + provider: provider, + config: oauthConfig, + } +} + +// Chat sends a Responses API request, refreshing the token if needed. +func (c *OAuthClient) Chat(ctx context.Context, req *llm.ChatRequest) (*llm.ChatResponse, error) { + if err := c.ensureValidToken(); err != nil { + return nil, err + } + return c.inner.Chat(ctx, req) +} + +// ChatStream sends a streaming Responses API request, refreshing the token if needed. +func (c *OAuthClient) ChatStream(ctx context.Context, req *llm.ChatRequest) (<-chan llm.StreamDelta, error) { + if err := c.ensureValidToken(); err != nil { + return nil, err + } + return c.inner.ChatStream(ctx, req) +} + +// ModelID returns the model identifier. +func (c *OAuthClient) ModelID() string { + return c.inner.ModelID() +} + +// ensureValidToken checks if the stored token is expired and refreshes it. 
+func (c *OAuthClient) ensureValidToken() error { + c.mu.Lock() + defer c.mu.Unlock() + + token, err := oauth.LoadCredentials(c.provider) + if err != nil { + return fmt.Errorf("loading OAuth credentials: %w", err) + } + if token == nil { + return fmt.Errorf("no OAuth credentials found for %s", c.provider) + } + + if !token.IsExpired() { + return nil + } + + if token.RefreshToken == "" { + return fmt.Errorf("OAuth token expired and no refresh token available for %s", c.provider) + } + + // Refresh the token + newToken, err := oauth.RefreshToken(c.config.TokenURL, c.config.ClientID, token.RefreshToken) + if err != nil { + return fmt.Errorf("refreshing OAuth token: %w", err) + } + + // Preserve refresh token if not returned in refresh response + if newToken.RefreshToken == "" { + newToken.RefreshToken = token.RefreshToken + } + + // Preserve the base URL from the original token + if newToken.BaseURL == "" { + newToken.BaseURL = token.BaseURL + } + + // Persist the new token + if err := oauth.SaveCredentials(c.provider, newToken); err != nil { + return fmt.Errorf("saving refreshed token: %w", err) + } + + // Update the inner client's API key + c.inner.apiKey = newToken.AccessToken + + return nil +} diff --git a/forge-core/llm/providers/oauth_client_test.go b/forge-core/llm/providers/oauth_client_test.go new file mode 100644 index 0000000..55e0b25 --- /dev/null +++ b/forge-core/llm/providers/oauth_client_test.go @@ -0,0 +1,66 @@ +package providers + +import ( + "os" + "testing" + "time" + + "github.com/initializ/forge/forge-core/llm" + "github.com/initializ/forge/forge-core/llm/oauth" +) + +func TestOAuthClient_ModelID(t *testing.T) { + cfg := llm.ClientConfig{ + APIKey: "test-token", + Model: "gpt-4o", + } + client := NewOAuthClient(cfg, "openai", oauth.OpenAIConfig()) + if client.ModelID() != "gpt-4o" { + t.Errorf("expected model gpt-4o, got %s", client.ModelID()) + } +} + +func TestOAuthClient_EnsureValidToken_NoCredentials(t *testing.T) { + // Use a temp directory 
with no credentials + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + + cfg := llm.ClientConfig{ + APIKey: "test-token", + Model: "gpt-4o", + } + client := NewOAuthClient(cfg, "testprovider", oauth.OpenAIConfig()) + + err := client.ensureValidToken() + if err == nil { + t.Error("expected error when no credentials exist") + } +} + +func TestOAuthClient_EnsureValidToken_ValidToken(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer func() { _ = os.Setenv("HOME", origHome) }() + + // Store a valid token + token := &oauth.Token{ + AccessToken: "valid-access-token", + RefreshToken: "valid-refresh-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + if err := oauth.SaveCredentials("testprovider2", token); err != nil { + t.Fatalf("failed to save credentials: %v", err) + } + + cfg := llm.ClientConfig{ + APIKey: "old-token", + Model: "gpt-4o", + } + client := NewOAuthClient(cfg, "testprovider2", oauth.OpenAIConfig()) + + err := client.ensureValidToken() + if err != nil { + t.Errorf("expected no error for valid token, got: %v", err) + } +} diff --git a/forge-core/llm/providers/responses.go b/forge-core/llm/providers/responses.go new file mode 100644 index 0000000..1045ee4 --- /dev/null +++ b/forge-core/llm/providers/responses.go @@ -0,0 +1,423 @@ +package providers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/initializ/forge/forge-core/llm" +) + +// ResponsesClient implements llm.Client using the OpenAI Responses API. +// This is used with ChatGPT OAuth tokens which are scoped to the Responses API +// endpoint (chatgpt.com/backend-api) rather than the Chat Completions API. +type ResponsesClient struct { + apiKey string + baseURL string + model string + client *http.Client + disableStore bool // set store=false in requests (required for ChatGPT Codex backend) +} + +// NewResponsesClient creates a new Responses API client. 
+// BaseURL defaults to the public OpenAI endpoint; TimeoutSecs defaults to 120s.
+func NewResponsesClient(cfg llm.ClientConfig) *ResponsesClient {
+	baseURL := cfg.BaseURL
+	if baseURL == "" {
+		baseURL = "https://api.openai.com/v1"
+	}
+	timeout := time.Duration(cfg.TimeoutSecs) * time.Second
+	if timeout == 0 {
+		timeout = 120 * time.Second
+	}
+	return &ResponsesClient{
+		apiKey:  cfg.APIKey,
+		baseURL: strings.TrimRight(baseURL, "/"),
+		model:   cfg.Model,
+		client:  &http.Client{Timeout: timeout},
+	}
+}
+
+// ModelID returns the configured model identifier.
+func (c *ResponsesClient) ModelID() string { return c.model }
+
+// Chat sends a Responses API request. The ChatGPT Codex backend requires
+// streaming, so this method always uses stream=true internally and collects
+// the full response from the streamed deltas.
+func (c *ResponsesClient) Chat(ctx context.Context, req *llm.ChatRequest) (*llm.ChatResponse, error) {
+	ch, err := c.ChatStream(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+
+	// Collect streamed deltas into a single response
+	result := &llm.ChatResponse{
+		Message: llm.ChatMessage{Role: llm.RoleAssistant},
+	}
+	// Track tool calls being assembled (keyed by ID)
+	toolCallMap := make(map[string]*llm.ToolCall)
+	var toolCallOrder []string
+
+	for delta := range ch {
+		if delta.Content != "" {
+			result.Message.Content += delta.Content
+		}
+		for _, tc := range delta.ToolCalls {
+			existing, ok := toolCallMap[tc.ID]
+			if !ok {
+				newTC := llm.ToolCall{
+					ID:   tc.ID,
+					Type: tc.Type,
+					Function: llm.FunctionCall{
+						Name:      tc.Function.Name,
+						Arguments: tc.Function.Arguments,
+					},
+				}
+				toolCallMap[tc.ID] = &newTC
+				toolCallOrder = append(toolCallOrder, tc.ID)
+			} else {
+				// Append streamed argument deltas
+				existing.Function.Arguments += tc.Function.Arguments
+			}
+		}
+		// Last writer wins for finish reason and usage (final events carry them).
+		if delta.FinishReason != "" {
+			result.FinishReason = delta.FinishReason
+		}
+		if delta.Usage != nil {
+			result.Usage = *delta.Usage
+		}
+	}
+
+	// Build ordered tool calls slice
+	for _, id := range toolCallOrder {
+		result.Message.ToolCalls = append(result.Message.ToolCalls, *toolCallMap[id])
+	}
+
+	return result, nil
+}
+
+// ChatStream sends a streaming Responses API request. The returned channel is
+// closed by a background goroutine when the SSE stream ends; the goroutine
+// also owns closing the response body.
+func (c *ResponsesClient) ChatStream(ctx context.Context, req *llm.ChatRequest) (<-chan llm.StreamDelta, error) {
+	body := c.buildRequest(req, true)
+	data, err := json.Marshal(body)
+	if err != nil {
+		return nil, fmt.Errorf("marshalling request: %w", err)
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/responses", bytes.NewReader(data))
+	if err != nil {
+		return nil, err
+	}
+	c.setHeaders(httpReq)
+
+	resp, err := c.client.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("responses api stream request: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		respBody, _ := io.ReadAll(resp.Body)
+		_ = resp.Body.Close()
+		return nil, fmt.Errorf("responses api stream error (status %d): %s", resp.StatusCode, string(respBody))
+	}
+
+	ch := make(chan llm.StreamDelta, 32)
+	go func() {
+		defer func() { _ = resp.Body.Close() }()
+		defer close(ch)
+		c.readStream(resp.Body, ch)
+	}()
+
+	return ch, nil
+}
+
+// setHeaders applies the JSON content type and, when configured, the bearer token.
+func (c *ResponsesClient) setHeaders(req *http.Request) {
+	req.Header.Set("Content-Type", "application/json")
+	if c.apiKey != "" {
+		req.Header.Set("Authorization", "Bearer "+c.apiKey)
+	}
+}
+
+// --- Request types ---
+
+type responsesRequest struct {
+	Model        string           `json:"model"`
+	Instructions string           `json:"instructions,omitempty"`
+	Input        []responsesInput `json:"input"`
+	Tools        []responsesTool  `json:"tools,omitempty"`
+	Temperature  *float64         `json:"temperature,omitempty"`
+	MaxTokens    int              `json:"max_output_tokens,omitempty"`
+	Stream       bool             `json:"stream,omitempty"`
+	Store        *bool            `json:"store,omitempty"`
+}
+
+// responsesInput is a union type for Responses API input items.
+// It can be a message (role+content) or a function_call_output.
+type responsesInput struct {
+	// For messages
+	Role    string `json:"role,omitempty"`
+	Content string `json:"content,omitempty"`
+
+	// For function_call items from assistant
+	Type      string `json:"type,omitempty"`
+	CallID    string `json:"call_id,omitempty"`
+	ID        string `json:"id,omitempty"`
+	Name      string `json:"name,omitempty"`
+	Arguments string `json:"arguments,omitempty"`
+
+	// For function_call_output
+	Output string `json:"output,omitempty"`
+}
+
+// responsesTool is the Responses API tool format (flat, not nested under "function").
+type responsesTool struct {
+	Type        string          `json:"type"`
+	Name        string          `json:"name"`
+	Description string          `json:"description,omitempty"`
+	Parameters  json.RawMessage `json:"parameters,omitempty"`
+}
+
+// buildRequest maps a Chat Completions-style request onto the Responses API
+// shape: system messages fold into Instructions, tool results become
+// function_call_output items, and assistant tool calls become function_call items.
+func (c *ResponsesClient) buildRequest(req *llm.ChatRequest, stream bool) responsesRequest {
+	model := req.Model
+	if model == "" {
+		model = c.model
+	}
+
+	var instructions string
+	var inputs []responsesInput
+
+	for _, msg := range req.Messages {
+		switch msg.Role {
+		case llm.RoleSystem:
+			// System messages become the instructions field
+			if instructions != "" {
+				instructions += "\n"
+			}
+			instructions += msg.Content
+
+		case llm.RoleUser:
+			inputs = append(inputs, responsesInput{
+				Role:    "user",
+				Content: msg.Content,
+			})
+
+		case llm.RoleAssistant:
+			if msg.Content != "" {
+				inputs = append(inputs, responsesInput{
+					Role:    "assistant",
+					Content: msg.Content,
+				})
+			}
+			// If assistant had tool calls, add them as function_call items
+			for _, tc := range msg.ToolCalls {
+				inputs = append(inputs, responsesInput{
+					Type:      "function_call",
+					CallID:    tc.ID,
+					ID:        tc.ID,
+					Name:      tc.Function.Name,
+					Arguments: tc.Function.Arguments,
+				})
+			}
+
+		case llm.RoleTool:
+			// Tool result messages become function_call_output
+			inputs = append(inputs, responsesInput{
+				Type:   "function_call_output",
+				CallID: msg.ToolCallID,
+				Output: msg.Content,
+			})
+		}
+	}
+
+	// Convert tools from Chat Completions format to Responses API format
+	var tools []responsesTool
+	for _, t := range req.Tools {
+		tools = append(tools, responsesTool{
+			Type:        "function",
+			Name:        t.Function.Name,
+			Description: t.Function.Description,
+			Parameters:  t.Function.Parameters,
+		})
+	}
+
+	r := responsesRequest{
+		Model:        model,
+		Instructions: instructions,
+		Input:        inputs,
+		Tools:        tools,
+		Temperature:  req.Temperature,
+		MaxTokens:    req.MaxTokens,
+		Stream:       stream,
+	}
+
+	if c.disableStore {
+		f := false
+		r.Store = &f
+	}
+
+	return r
+}
+
+// --- Response types ---
+
+type responsesResponse struct {
+	ID     string            `json:"id"`
+	Status string            `json:"status"`
+	Output []responsesOutput `json:"output"`
+	Usage  *responsesUsage   `json:"usage,omitempty"`
+}
+
+type responsesOutput struct {
+	Type    string                 `json:"type"` // "message" or "function_call"
+	Role    string                 `json:"role,omitempty"`
+	Content []responsesContentPart `json:"content,omitempty"`
+
+	// For function_call outputs
+	ID        string `json:"id,omitempty"`
+	CallID    string `json:"call_id,omitempty"`
+	Name      string `json:"name,omitempty"`
+	Arguments string `json:"arguments,omitempty"`
+}
+
+type responsesContentPart struct {
+	Type string `json:"type"` // "output_text"
+	Text string `json:"text"`
+}
+
+type responsesUsage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+	TotalTokens  int `json:"total_tokens"`
+}
+
+// --- Streaming ---
+
+type streamOutputItemAdded struct {
+	OutputIndex int             `json:"output_index"`
+	Item        responsesOutput `json:"item"`
+}
+
+type streamTextDelta struct {
+	OutputIndex  int    `json:"output_index"`
+	ContentIndex int    `json:"content_index"`
+	Delta        string `json:"delta"`
+}
+
+type streamFnArgsDelta struct {
+	OutputIndex int    `json:"output_index"`
+	Delta       string `json:"delta"`
+}
+
+type streamCompleted struct {
+	Response responsesResponse `json:"response"`
+}
+
+// readStream parses the SSE stream and forwards deltas on ch until the
+// response.completed event (or EOF).
+func (c *ResponsesClient) readStream(r io.Reader, ch chan<- llm.StreamDelta) {
+	// Track function calls being built so we can emit them with correct IDs
type pendingFC struct { + id string + name string + } + pendingFCs := make(map[int]*pendingFC) + + scanner := bufio.NewScanner(r) + var currentEvent string + + for scanner.Scan() { + line := scanner.Text() + + if after, ok := strings.CutPrefix(line, "event: "); ok { + currentEvent = after + continue + } + + data, ok := strings.CutPrefix(line, "data: ") + if !ok { + continue + } + + switch currentEvent { + case "response.output_text.delta": + var ev streamTextDelta + if err := json.Unmarshal([]byte(data), &ev); err != nil { + continue + } + ch <- llm.StreamDelta{Content: ev.Delta} + + case "response.output_item.added": + var ev streamOutputItemAdded + if err := json.Unmarshal([]byte(data), &ev); err != nil { + continue + } + if ev.Item.Type == "function_call" { + pendingFCs[ev.OutputIndex] = &pendingFC{ + id: ev.Item.CallID, + name: ev.Item.Name, + } + // Emit initial tool call with name + ch <- llm.StreamDelta{ + ToolCalls: []llm.ToolCall{{ + ID: ev.Item.CallID, + Type: "function", + Function: llm.FunctionCall{ + Name: ev.Item.Name, + Arguments: "", + }, + }}, + } + } + + case "response.function_call_arguments.delta": + var ev streamFnArgsDelta + if err := json.Unmarshal([]byte(data), &ev); err != nil { + continue + } + fc := pendingFCs[ev.OutputIndex] + if fc == nil { + continue + } + ch <- llm.StreamDelta{ + ToolCalls: []llm.ToolCall{{ + ID: fc.id, + Type: "function", + Function: llm.FunctionCall{ + Name: fc.name, + Arguments: ev.Delta, + }, + }}, + } + + case "response.completed": + var ev streamCompleted + if err := json.Unmarshal([]byte(data), &ev); err != nil { + continue + } + delta := llm.StreamDelta{Done: true} + if ev.Response.Usage != nil { + delta.Usage = &llm.UsageInfo{ + PromptTokens: ev.Response.Usage.InputTokens, + CompletionTokens: ev.Response.Usage.OutputTokens, + TotalTokens: ev.Response.Usage.TotalTokens, + } + } + // Determine finish reason from output + for _, out := range ev.Response.Output { + if out.Type == "function_call" { + 
delta.FinishReason = "tool_calls"
+					break
+				}
+			}
+			if delta.FinishReason == "" {
+				delta.FinishReason = "stop"
+			}
+			ch <- delta
+			// response.completed is terminal: stop reading the stream.
+			return
+		}
+
+		// Each processed data line consumes the pending event name.
+		currentEvent = ""
+	}
+}
diff --git a/forge-core/runtime/config.go b/forge-core/runtime/config.go
index ef84431..089b07d 100644
--- a/forge-core/runtime/config.go
+++ b/forge-core/runtime/config.go
@@ -1,12 +1,21 @@
 package runtime
 
 import (
+	"strings"
+
 	"github.com/initializ/forge/forge-core/llm"
 	"github.com/initializ/forge/forge-core/types"
 )
 
 // ModelConfig holds the resolved model provider and configuration.
 type ModelConfig struct {
+	Provider  string
+	Client    llm.ClientConfig
+	Fallbacks []FallbackModelConfig
+}
+
+// FallbackModelConfig holds a resolved fallback provider's configuration.
+type FallbackModelConfig struct {
 	Provider string
 	Client   llm.ClientConfig
 }
@@ -77,21 +86,128 @@ func ResolveModelConfig(cfg *types.ForgeConfig, envVars map[string]string, provi
 	// Set default models per provider if not specified
 	if mc.Client.Model == "" {
-		switch mc.Provider {
-		case "openai":
-			mc.Client.Model = "gpt-4o"
-		case "anthropic":
-			mc.Client.Model = "claude-sonnet-4-20250514"
-		case "gemini":
-			mc.Client.Model = "gemini-2.5-flash"
-		case "ollama":
-			mc.Client.Model = "llama3"
-		}
+		mc.Client.Model = defaultModelForProvider(mc.Provider)
 	}
 
+	// Resolve fallback providers
+	mc.Fallbacks = resolveFallbacks(cfg, envVars, mc.Provider)
+
 	return mc
 }
+
+// defaultModelForProvider returns the default model name for a given provider.
+// NOTE(review): this change also bumps the openai default from gpt-4o to
+// gpt-5.2-2025-12-11 — confirm this model ID is valid for all consumers.
+func defaultModelForProvider(provider string) string {
+	switch provider {
+	case "openai":
+		return "gpt-5.2-2025-12-11"
+	case "anthropic":
+		return "claude-sonnet-4-20250514"
+	case "gemini":
+		return "gemini-2.5-flash"
+	case "ollama":
+		return "llama3"
+	default:
+		return ""
+	}
+}
+
+// resolveFallbacks resolves fallback provider configurations from multiple sources:
+// 1. forge.yaml model.fallbacks
+// 2. FORGE_MODEL_FALLBACKS env var (format: "openai:gpt-4o,gemini:gemini-2.5-flash")
+// 3.
Auto-detection from available API keys +func resolveFallbacks(cfg *types.ForgeConfig, envVars map[string]string, primaryProvider string) []FallbackModelConfig { + seen := map[string]bool{primaryProvider: true} + var fallbacks []FallbackModelConfig + + addFallback := func(provider, model string) { + if seen[provider] { + return + } + apiKey := resolveFallbackAPIKey(provider, envVars) + if apiKey == "" && provider != "ollama" { + return // skip providers without API keys + } + seen[provider] = true + if model == "" { + model = defaultModelForProvider(provider) + } + fc := FallbackModelConfig{ + Provider: provider, + Client: llm.ClientConfig{ + APIKey: apiKey, + Model: model, + }, + } + if provider == "ollama" && apiKey == "" { + fc.Client.APIKey = "ollama" + } + // Apply base URL overrides + fc.Client.BaseURL = resolveFallbackBaseURL(provider, envVars) + fallbacks = append(fallbacks, fc) + } + + // Source 1: forge.yaml model.fallbacks + for _, fb := range cfg.Model.Fallbacks { + addFallback(fb.Provider, fb.Name) + } + + // Source 2: FORGE_MODEL_FALLBACKS env var + if envFallbacks := envVars["FORGE_MODEL_FALLBACKS"]; envFallbacks != "" { + for _, entry := range strings.Split(envFallbacks, ",") { + entry = strings.TrimSpace(entry) + if entry == "" { + continue + } + provider, model, _ := strings.Cut(entry, ":") + addFallback(provider, model) + } + } + + // Source 3: Auto-detect from available API keys + providerKeys := map[string]string{ + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "gemini": "GEMINI_API_KEY", + } + for provider, keyName := range providerKeys { + if envVars[keyName] != "" { + addFallback(provider, "") + } + } + + return fallbacks +} + +// resolveFallbackAPIKey resolves the API key for a fallback provider. 
+func resolveFallbackAPIKey(provider string, envVars map[string]string) string {
+	switch provider {
+	case "openai":
+		return envVars["OPENAI_API_KEY"]
+	case "anthropic":
+		return envVars["ANTHROPIC_API_KEY"]
+	case "gemini":
+		return envVars["GEMINI_API_KEY"]
+	case "ollama":
+		// Local ollama needs no real key; return a placeholder.
+		return "ollama"
+	default:
+		// Unknown providers fall back to the generic key.
+		return envVars["LLM_API_KEY"]
+	}
+}
+
+// resolveFallbackBaseURL resolves the base URL for a fallback provider.
+func resolveFallbackBaseURL(provider string, envVars map[string]string) string {
+	switch provider {
+	case "openai":
+		return envVars["OPENAI_BASE_URL"]
+	case "anthropic":
+		return envVars["ANTHROPIC_BASE_URL"]
+	case "ollama":
+		return envVars["OLLAMA_BASE_URL"]
+	default:
+		return ""
+	}
+}
+
 func resolveAPIKey(mc *ModelConfig, envVars map[string]string) {
 	switch mc.Provider {
 	case "openai":
diff --git a/forge-core/runtime/config_test.go b/forge-core/runtime/config_test.go
new file mode 100644
index 0000000..ca9aede
--- /dev/null
+++ b/forge-core/runtime/config_test.go
@@ -0,0 +1,196 @@
+package runtime
+
+import (
+	"testing"
+
+	"github.com/initializ/forge/forge-core/types"
+)
+
+// TestResolveModelConfig_FallbacksFromYAML checks that fallbacks declared in
+// forge.yaml resolve in order, with defaults filled in for missing model names.
+func TestResolveModelConfig_FallbacksFromYAML(t *testing.T) {
+	cfg := &types.ForgeConfig{
+		Model: types.ModelRef{
+			Provider: "anthropic",
+			Name:     "claude-sonnet-4-20250514",
+			Fallbacks: []types.ModelFallback{
+				{Provider: "openai", Name: "gpt-4o"},
+				{Provider: "gemini"},
+			},
+		},
+	}
+	envVars := map[string]string{
+		"ANTHROPIC_API_KEY": "sk-ant-test",
+		"OPENAI_API_KEY":    "sk-openai-test",
+		"GEMINI_API_KEY":    "gemini-test",
+	}
+
+	mc := ResolveModelConfig(cfg, envVars, "")
+	if mc == nil {
+		t.Fatal("expected non-nil ModelConfig")
+	}
+	if mc.Provider != "anthropic" {
+		t.Fatalf("expected primary provider anthropic, got %s", mc.Provider)
+	}
+	if len(mc.Fallbacks) != 2 {
+		t.Fatalf("expected 2 fallbacks, got %d", len(mc.Fallbacks))
+	}
+	if mc.Fallbacks[0].Provider != "openai" {
+		t.Errorf("expected first fallback openai, got %s", mc.Fallbacks[0].Provider)
+	}
+	if mc.Fallbacks[0].Client.Model != "gpt-4o" {
+		t.Errorf("expected first fallback model gpt-4o, got %s", mc.Fallbacks[0].Client.Model)
+	}
+	if mc.Fallbacks[1].Provider != "gemini" {
+		t.Errorf("expected second fallback gemini, got %s", mc.Fallbacks[1].Provider)
+	}
+	if mc.Fallbacks[1].Client.Model != "gemini-2.5-flash" {
+		t.Errorf("expected default gemini model, got %s", mc.Fallbacks[1].Client.Model)
+	}
+}
+
+// TestResolveModelConfig_FallbacksFromEnvVar checks the FORGE_MODEL_FALLBACKS
+// "provider:model,..." syntax is parsed in declaration order.
+func TestResolveModelConfig_FallbacksFromEnvVar(t *testing.T) {
+	cfg := &types.ForgeConfig{
+		Model: types.ModelRef{
+			Provider: "anthropic",
+			Name:     "claude-sonnet-4-20250514",
+		},
+	}
+	envVars := map[string]string{
+		"ANTHROPIC_API_KEY":     "sk-ant-test",
+		"OPENAI_API_KEY":        "sk-openai-test",
+		"GEMINI_API_KEY":        "gemini-test",
+		"FORGE_MODEL_FALLBACKS": "openai:gpt-4o-mini,gemini:gemini-2.5-pro",
+	}
+
+	mc := ResolveModelConfig(cfg, envVars, "")
+	if mc == nil {
+		t.Fatal("expected non-nil ModelConfig")
+	}
+	if len(mc.Fallbacks) < 2 {
+		t.Fatalf("expected at least 2 fallbacks, got %d", len(mc.Fallbacks))
+	}
+	if mc.Fallbacks[0].Provider != "openai" {
+		t.Errorf("expected first fallback openai, got %s", mc.Fallbacks[0].Provider)
+	}
+	if mc.Fallbacks[0].Client.Model != "gpt-4o-mini" {
+		t.Errorf("expected model gpt-4o-mini, got %s", mc.Fallbacks[0].Client.Model)
+	}
+	if mc.Fallbacks[1].Provider != "gemini" {
+		t.Errorf("expected second fallback gemini, got %s", mc.Fallbacks[1].Provider)
+	}
+	if mc.Fallbacks[1].Client.Model != "gemini-2.5-pro" {
+		t.Errorf("expected model gemini-2.5-pro, got %s", mc.Fallbacks[1].Client.Model)
+	}
+}
+
+func TestResolveModelConfig_AutoDetectFallbacks(t *testing.T) {
+	cfg := &types.ForgeConfig{
+		Model: types.ModelRef{
+			Provider: "anthropic",
+			Name:     "claude-sonnet-4-20250514",
+		},
+	}
+	envVars := map[string]string{
+		"ANTHROPIC_API_KEY": "sk-ant-test",
+		"OPENAI_API_KEY":    "sk-openai-test",
+	}
+
+	mc := ResolveModelConfig(cfg, envVars, "")
+	if mc == nil {
+		t.Fatal("expected non-nil ModelConfig")
+	}
+	if len(mc.Fallbacks) != 1 {
+		t.Fatalf("expected 1 auto-detected fallback, got %d", len(mc.Fallbacks))
+	}
+	if mc.Fallbacks[0].Provider != "openai" {
+		t.Errorf("expected auto-detected fallback openai, got %s", mc.Fallbacks[0].Provider)
+	}
+}
+
+func TestResolveModelConfig_PrimaryNotDuplicated(t *testing.T) {
+	cfg := &types.ForgeConfig{
+		Model: types.ModelRef{
+			Provider: "openai",
+			Name:     "gpt-4o",
+			Fallbacks: []types.ModelFallback{
+				{Provider: "openai", Name: "gpt-4o-mini"},
+			},
+		},
+	}
+	envVars := map[string]string{
+		"OPENAI_API_KEY": "sk-openai-test",
+	}
+
+	mc := ResolveModelConfig(cfg, envVars, "")
+	if mc == nil {
+		t.Fatal("expected non-nil ModelConfig")
+	}
+	// Primary provider should not appear in fallbacks
+	for _, fb := range mc.Fallbacks {
+		if fb.Provider == "openai" {
+			t.Errorf("primary provider openai should not appear in fallbacks")
+		}
+	}
+}
+
+func TestResolveModelConfig_MissingAPIKeySkipsFallback(t *testing.T) {
+	cfg := &types.ForgeConfig{
+		Model: types.ModelRef{
+			Provider: "anthropic",
+			Name:     "claude-sonnet-4-20250514",
+			Fallbacks: []types.ModelFallback{
+				{Provider: "openai", Name: "gpt-4o"},
+			},
+		},
+	}
+	envVars := map[string]string{
+		"ANTHROPIC_API_KEY": "sk-ant-test",
+		// No OPENAI_API_KEY
+	}
+
+	mc := ResolveModelConfig(cfg, envVars, "")
+	if mc == nil {
+		t.Fatal("expected non-nil ModelConfig")
+	}
+	if len(mc.Fallbacks) != 0 {
+		t.Fatalf("expected 0 fallbacks (missing API key), got %d", len(mc.Fallbacks))
+	}
+}
+
+func TestResolveModelConfig_NoFallbacksWhenSingleProvider(t *testing.T) {
+	cfg := &types.ForgeConfig{
+		Model: types.ModelRef{
+			Provider: "anthropic",
+			Name:     "claude-sonnet-4-20250514",
+		},
+	}
+	envVars := map[string]string{
+		"ANTHROPIC_API_KEY": "sk-ant-test",
+	}
+
+	mc := ResolveModelConfig(cfg, envVars, "")
+	if mc == nil {
+		t.Fatal("expected non-nil ModelConfig")
+	}
+	if len(mc.Fallbacks) != 0 {
+		t.Fatalf("expected 0 fallbacks, got %d", len(mc.Fallbacks))
+	}
+}
+
+func TestDefaultModelForProvider(t *testing.T) {
+	tests := []struct {
+		provider string
+		expected string
+	}{
+		{"openai", "gpt-5.2-2025-12-11"},
+		{"anthropic", "claude-sonnet-4-20250514"},
+		{"gemini", "gemini-2.5-flash"},
+		{"ollama", "llama3"},
+		{"unknown", ""},
+	}
+	for _, tt := range tests {
+		got := defaultModelForProvider(tt.provider)
+		if got != tt.expected {
+			t.Errorf("defaultModelForProvider(%q) = %q, want %q", tt.provider, got, tt.expected)
+		}
+	}
+}
diff --git a/forge-core/types/config.go b/forge-core/types/config.go
index 0770661..856e578 100644
--- a/forge-core/types/config.go
+++ b/forge-core/types/config.go
@@ -36,9 +36,16 @@ type SkillsRef struct {
 
 // ModelRef identifies the model an agent uses.
 type ModelRef struct {
+	Provider  string          `yaml:"provider"`
+	Name      string          `yaml:"name"`
+	Version   string          `yaml:"version,omitempty"`
+	Fallbacks []ModelFallback `yaml:"fallbacks,omitempty"`
+}
+
+// ModelFallback identifies an alternative LLM provider for fallback.
+type ModelFallback struct {
 	Provider string `yaml:"provider"`
-	Name     string `yaml:"name"`
-	Version  string `yaml:"version,omitempty"`
+	Name     string `yaml:"name,omitempty"`
 }
 
 // ToolRef is a lightweight reference to a tool in forge.yaml.