Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 104 additions & 17 deletions forge-cli/cmd/init.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"fmt"
"os"
"os/exec"
Expand All @@ -16,6 +17,7 @@ import (
"github.com/initializ/forge/forge-cli/internal/tui/steps"
"github.com/initializ/forge/forge-cli/skills"
"github.com/initializ/forge/forge-cli/templates"
"github.com/initializ/forge/forge-core/llm/oauth"
"github.com/initializ/forge/forge-core/tools/builtins"
"github.com/initializ/forge/forge-core/util"
"github.com/initializ/forge/forge-skills/contract"
Expand All @@ -30,6 +32,7 @@ type initOptions struct {
Language string
ModelProvider string
APIKey string // validated provider key
Fallbacks []tui.FallbackProvider
Channels []string
SkillsFile string
Tools []toolEntry
Expand All @@ -56,6 +59,7 @@ type templateData struct {
Entrypoint string
ModelProvider string
ModelName string
Fallbacks []fallbackTmplData
Channels []string
Tools []toolEntry
BuiltinTools []string
Expand All @@ -64,6 +68,12 @@ type templateData struct {
EnvVars []envVarEntry
}

// fallbackTmplData holds template data for a fallback provider.
type fallbackTmplData struct {
Provider string
ModelName string
}

// skillTmplData holds template data for a registry skill.
type skillTmplData struct {
Name string
Expand Down Expand Up @@ -103,6 +113,7 @@ func init() {
initCmd.Flags().StringSlice("tools", nil, "builtin tools to enable (e.g., web_search,http_request)")
initCmd.Flags().StringSlice("skills", nil, "registry skills to include (e.g., github,weather)")
initCmd.Flags().String("api-key", "", "LLM provider API key")
initCmd.Flags().StringSlice("fallbacks", nil, "fallback LLM providers (e.g., openai,gemini)")
initCmd.Flags().Bool("force", false, "overwrite existing directory")
}

Expand All @@ -128,6 +139,10 @@ func runInit(cmd *cobra.Command, args []string) error {
opts.BuiltinTools, _ = cmd.Flags().GetStringSlice("tools")
opts.Skills, _ = cmd.Flags().GetStringSlice("skills")
opts.APIKey, _ = cmd.Flags().GetString("api-key")
fallbackProviders, _ := cmd.Flags().GetStringSlice("fallbacks")
for _, p := range fallbackProviders {
opts.Fallbacks = append(opts.Fallbacks, tui.FallbackProvider{Provider: p})
}

nonInteractive, _ := cmd.Flags().GetBool("non-interactive")
opts.NonInteractive = nonInteractive
Expand Down Expand Up @@ -216,6 +231,11 @@ func collectInteractive(opts *initOptions) error {
return validateProviderKey(provider, key)
}

// Build OAuth flow callback
oauthFlowFn := func(provider string) (string, error) {
return runOAuthFlow(provider)
}

// Build web search key validation callback
validateWebSearchKeyFn := func(provider, key string) error {
return validateWebSearchKey(provider, key)
Expand All @@ -224,7 +244,8 @@ func collectInteractive(opts *initOptions) error {
// Build step list
wizardSteps := []tui.Step{
steps.NewNameStep(styles, opts.Name),
steps.NewProviderStep(styles, validateKeyFn),
steps.NewProviderStep(styles, validateKeyFn, oauthFlowFn),
steps.NewFallbackStep(styles, validateKeyFn),
steps.NewChannelStep(styles),
steps.NewToolsStep(styles, toolInfos, validateWebSearchKeyFn),
steps.NewSkillsStep(styles, skillInfos),
Expand Down Expand Up @@ -264,7 +285,12 @@ func collectInteractive(opts *initOptions) error {

opts.ModelProvider = ctx.Provider
opts.APIKey = ctx.APIKey
opts.Fallbacks = ctx.Fallbacks
opts.CustomModel = ctx.CustomModel
// Use wizard-selected model name if available
if ctx.ModelName != "" {
opts.CustomModel = ctx.ModelName
}

if ctx.Channel != "" && ctx.Channel != "none" {
opts.Channels = []string{ctx.Channel}
Expand Down Expand Up @@ -711,22 +737,20 @@ func buildTemplateData(opts *initOptions) templateData {
}
}

// Set default model name based on provider
switch opts.ModelProvider {
case "openai":
data.ModelName = "gpt-4o-mini"
case "anthropic":
data.ModelName = "claude-sonnet-4-20250514"
case "gemini":
data.ModelName = "gemini-2.5-flash"
case "ollama":
data.ModelName = "llama3"
default:
if opts.CustomModel != "" {
data.ModelName = opts.CustomModel
} else {
data.ModelName = "default"
}
// Set model name: use wizard-selected model, or fall back to provider default
if opts.CustomModel != "" {
data.ModelName = opts.CustomModel
} else {
data.ModelName = defaultModelNameForProvider(opts.ModelProvider)
}

// Build fallback entries for templates
for _, fb := range opts.Fallbacks {
modelName := defaultModelNameForProvider(fb.Provider)
data.Fallbacks = append(data.Fallbacks, fallbackTmplData{
Provider: fb.Provider,
ModelName: modelName,
})
}

// Build skill entries for templates
Expand Down Expand Up @@ -759,6 +783,22 @@ func buildTemplateData(opts *initOptions) templateData {
return data
}

// defaultModelNameForProvider returns the default model name for wizard templates.
func defaultModelNameForProvider(provider string) string {
switch provider {
case "openai":
return "gpt-5.2-2025-12-11"
case "anthropic":
return "claude-sonnet-4-20250514"
case "gemini":
return "gemini-2.5-flash"
case "ollama":
return "llama3"
default:
return "default"
}
}

// buildEnvVars builds the list of environment variables for the .env file.
func buildEnvVars(opts *initOptions) []envVarEntry {
var vars []envVarEntry
Expand Down Expand Up @@ -831,6 +871,34 @@ func buildEnvVars(opts *initOptions) []envVarEntry {
}
}

// Fallback provider env vars
fallbackKeyMap := map[string]string{
"openai": "OPENAI_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"gemini": "GEMINI_API_KEY",
}
for _, fb := range opts.Fallbacks {
envKey, ok := fallbackKeyMap[fb.Provider]
if !ok || fb.APIKey == "" {
continue
}
// Skip if already written (e.g., primary provider)
alreadyWritten := false
for _, v := range vars {
if v.Key == envKey {
alreadyWritten = true
break
}
}
if !alreadyWritten {
vars = append(vars, envVarEntry{
Key: envKey,
Value: fb.APIKey,
Comment: fmt.Sprintf("%s API key (fallback)", titleCase(fb.Provider)),
})
}
}

// Skill env vars (skip keys already added above)
written := make(map[string]bool)
for _, v := range vars {
Expand Down Expand Up @@ -882,6 +950,25 @@ func containsStr(slice []string, val string) bool {
return false
}

// runOAuthFlow executes the OAuth browser flow for a provider and returns the access token.
func runOAuthFlow(provider string) (string, error) {
var config oauth.ProviderConfig
switch provider {
case "openai":
config = oauth.OpenAIConfig()
default:
return "", fmt.Errorf("OAuth not supported for provider %q", provider)
}

flow := oauth.NewFlow(config)
token, err := flow.Execute(context.Background(), provider)
if err != nil {
return "", err
}

return token.AccessToken, nil
}

// titleCase capitalizes the first letter of a string.
func titleCase(s string) string {
if s == "" {
Expand Down
7 changes: 6 additions & 1 deletion forge-cli/cmd/init_egress.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,15 @@ func deriveEgressDomains(opts *initOptions, skills []contract.SkillDescriptor) [
}
}

// 1. Provider domain
// 1. Provider domains (primary + fallbacks)
if d, ok := providerDomains[opts.ModelProvider]; ok {
add(d)
}
for _, fb := range opts.Fallbacks {
if d, ok := providerDomains[fb.Provider]; ok {
add(d)
}
}

// 2. Channel domains
for _, d := range security.ResolveCapabilities(opts.Channels) {
Expand Down
2 changes: 1 addition & 1 deletion forge-cli/cmd/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ func TestBuildTemplateData_DefaultModels(t *testing.T) {
provider string
expectedModel string
}{
{"openai", "gpt-4o-mini"},
{"openai", "gpt-5.2-2025-12-11"},
{"anthropic", "claude-sonnet-4-20250514"},
{"gemini", "gemini-2.5-flash"},
{"ollama", "llama3"},
Expand Down
12 changes: 12 additions & 0 deletions forge-cli/internal/tui/components/multi_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@ func (m MultiSelect) Update(msg tea.Msg) (MultiSelect, tea.Cmd) {
case " ":
m.Items[m.cursor].Checked = !m.Items[m.cursor].Checked
case "enter":
// If nothing is checked, auto-check the cursor item so the user
// doesn't have to remember Space+Enter for a single selection.
anyChecked := false
for _, item := range m.Items {
if item.Checked {
anyChecked = true
break
}
}
if !anyChecked && len(m.Items) > 0 {
m.Items[m.cursor].Checked = true
}
m.done = true
}
}
Expand Down
Loading
Loading