diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..e28305f --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,34 @@ +# Forge - CLAUDE.md + +## Project Structure + +Multi-module Go workspace with three modules: +- `forge-core/` — Core library (registry, tools, security, channels, LLM) +- `forge-cli/` — CLI commands, TUI wizard, runtime +- `forge-plugins/` — Channel plugins (telegram, slack), markdown converter + +## Pre-Commit Requirements + +**Always run before committing:** + +```sh +# Format all modules +gofmt -w forge-core/ forge-cli/ forge-plugins/ + +# Lint all modules +golangci-lint run ./forge-core/... +golangci-lint run ./forge-cli/... +golangci-lint run ./forge-plugins/... +``` + +Fix any lint errors and formatting issues before creating commits. + +## Testing + +Run tests for affected modules before committing: + +```sh +cd forge-core && go test ./... +cd forge-cli && go test ./... +cd forge-plugins && go test ./... +``` diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index d507883..65486d1 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -187,6 +187,8 @@ func collectInteractive(opts *initOptions) error { DisplayName: s.DisplayName, Description: s.Description, RequiredEnv: s.RequiredEnv, + OneOfEnv: s.OneOfEnv, + OptionalEnv: s.OptionalEnv, RequiredBins: s.RequiredBins, EgressDomains: s.EgressDomains, }) @@ -194,23 +196,25 @@ func collectInteractive(opts *initOptions) error { } // Build the egress derivation callback (avoids circular import) - deriveEgressFn := func(provider string, channels, tools, skills []string) []string { + deriveEgressFn := func(provider string, channels, tools, skills []string, envVars map[string]string) []string { tmpOpts := &initOptions{ ModelProvider: provider, Channels: channels, BuiltinTools: tools, - EnvVars: make(map[string]string), + EnvVars: envVars, } selectedInfos := lookupSelectedSkills(skills) return deriveEgressDomains(tmpOpts, selectedInfos) } - // Build validation callbacks + // Build validation callback validateKeyFn := func(provider, key string) error { return validateProviderKey(provider, key) } - validatePerpFn := func(key string) error { - return validatePerplexityKey(key) + + // Build web search key validation callback + validateWebSearchKeyFn := func(provider, key string) error { + return validateWebSearchKey(provider, key) } // Build step list @@ -218,7 +222,7 @@ func collectInteractive(opts *initOptions) error { steps.NewNameStep(styles, opts.Name), steps.NewProviderStep(styles, validateKeyFn), steps.NewChannelStep(styles), - steps.NewToolsStep(styles, toolInfos, validatePerpFn), + steps.NewToolsStep(styles, toolInfos, validateWebSearchKeyFn), steps.NewSkillsStep(styles, skillInfos), steps.NewEgressStep(styles, deriveEgressFn), steps.NewReviewStep(styles), // scaffold is handled by the caller after collectInteractive returns @@ -541,6 +545,19 @@ func scaffold(opts *initOptions) error { if err := os.WriteFile(skillPath, content, 0o644); err != nil { return fmt.Errorf("writing skill file %s: %w", skillName, err) } + + // Vendor script if the skill has one + if skillreg.HasSkillScript(skillName) { + scriptContent, sErr := skillreg.LoadSkillScript(skillName) + if sErr == nil { + scriptDir := filepath.Join(dir, "skills", "scripts") + _ = os.MkdirAll(scriptDir, 0o755) + scriptPath := filepath.Join(scriptDir, skillName+".sh") + if wErr := os.WriteFile(scriptPath, scriptContent, 0o755); wErr != nil { + fmt.Printf("Warning: could not write script for %q: %s\n", skillName, wErr) + } + } + } } fmt.Printf("\nCreated agent project in ./%s\n", opts.AgentID) @@ -754,13 +771,24 @@ func buildEnvVars(opts *initOptions) []envVarEntry { vars = append(vars, envVarEntry{Key: "MODEL_API_KEY", Value: apiKeyVal, Comment: "Model provider API key"}) } - // Perplexity key if web_search selected + // Web search provider key if web_search selected if containsStr(opts.BuiltinTools, "web_search") { - val := opts.EnvVars["PERPLEXITY_API_KEY"] - if val == "" { - val = "your-perplexity-key-here" + provider := opts.EnvVars["WEB_SEARCH_PROVIDER"] + if provider == "perplexity" { + val := opts.EnvVars["PERPLEXITY_API_KEY"] + if val == "" { + val = "your-perplexity-key-here" + } + vars = append(vars, envVarEntry{Key: "PERPLEXITY_API_KEY", Value: val, Comment: "Perplexity API key for web_search"}) + vars = append(vars, envVarEntry{Key: "WEB_SEARCH_PROVIDER", Value: "perplexity", Comment: "Web search provider"}) + } else { + // Default to Tavily + val := opts.EnvVars["TAVILY_API_KEY"] + if val == "" { + val = "your-tavily-key-here" + } + vars = append(vars, envVarEntry{Key: "TAVILY_API_KEY", Value: val, Comment: "Tavily API key for web_search"}) } - vars = append(vars, envVarEntry{Key: "PERPLEXITY_API_KEY", Value: val, Comment: "Perplexity API key for web_search"}) } // Channel env vars @@ -777,21 +805,30 @@ func buildEnvVars(opts *initOptions) []envVarEntry { } } - // Skill env vars + // Skill env vars (skip keys already added above) + written := make(map[string]bool) + for _, v := range vars { + written[v.Key] = true + } for _, skillName := range opts.Skills { info := skillreg.GetSkillByName(skillName) if info == nil { continue } for _, env := range info.RequiredEnv { - val := opts.EnvVars[env] - if val == "" { - val = "" + if written[env] { + continue } + written[env] = true + val := opts.EnvVars[env] vars = append(vars, envVarEntry{Key: env, Value: val, Comment: fmt.Sprintf("Required by %s skill", skillName)}) } if len(info.OneOfEnv) > 0 { for _, env := range info.OneOfEnv { + if written[env] { + continue + } + written[env] = true val := opts.EnvVars[env] vars = append(vars, envVarEntry{ Key: env, diff --git a/forge-cli/cmd/init_egress.go b/forge-cli/cmd/init_egress.go index 6497a6c..2255759 100644 --- a/forge-cli/cmd/init_egress.go +++ b/forge-cli/cmd/init_egress.go @@ -38,9 +38,21 @@ func deriveEgressDomains(opts *initOptions, skills []skillreg.SkillInfo) []strin add(d) } - // 3. Tool domains - for _, d := range security.InferToolDomains(opts.BuiltinTools) { - add(d) + // 3. Tool domains (web_search filtered by provider) + for _, toolName := range opts.BuiltinTools { + if toolName == "web_search" || toolName == "web-search" { + provider := opts.EnvVars["WEB_SEARCH_PROVIDER"] + switch provider { + case "perplexity": + add("api.perplexity.ai") + default: + add("api.tavily.com") + } + continue + } + for _, d := range security.DefaultToolDomains[toolName] { + add(d) + } } // 4. Skill domains diff --git a/forge-cli/cmd/init_test.go b/forge-cli/cmd/init_test.go index 85a7f9c..3e0fe9c 100644 --- a/forge-cli/cmd/init_test.go +++ b/forge-cli/cmd/init_test.go @@ -455,8 +455,8 @@ func TestScaffold_EgressInForgeYAML(t *testing.T) { if !strings.Contains(yamlStr, "api.openai.com") { t.Error("forge.yaml missing api.openai.com in egress domains") } - if !strings.Contains(yamlStr, "api.perplexity.ai") { - t.Error("forge.yaml missing api.perplexity.ai in egress domains") + if !strings.Contains(yamlStr, "api.tavily.com") { + t.Error("forge.yaml missing api.tavily.com in egress domains") } } @@ -504,13 +504,13 @@ func TestDeriveEgressDomains(t *testing.T) { domains := deriveEgressDomains(opts, skillInfos) expected := map[string]bool{ - "api.openai.com": true, - "slack.com": true, - "hooks.slack.com": true, - "api.slack.com": true, - "api.perplexity.ai": true, - "api.github.com": true, - "github.com": true, + "api.openai.com": true, + "slack.com": true, + "hooks.slack.com": true, + "api.slack.com": true, + "api.tavily.com": true, + "api.github.com": true, + "github.com": true, } for _, d := range domains { if !expected[d] { @@ -550,8 +550,8 @@ func TestBuildEnvVars(t *testing.T) { if !found["OPENAI_API_KEY"] { t.Error("missing OPENAI_API_KEY") } - if !found["PERPLEXITY_API_KEY"] { - t.Error("missing PERPLEXITY_API_KEY") + if !found["TAVILY_API_KEY"] { + t.Error("missing TAVILY_API_KEY") } if !found["GH_TOKEN"] { t.Error("missing GH_TOKEN") diff --git a/forge-cli/cmd/init_validate.go b/forge-cli/cmd/init_validate.go index f814ceb..3fbb2b6 100644 --- a/forge-cli/cmd/init_validate.go +++ b/forge-cli/cmd/init_validate.go @@ -17,6 +17,7 @@ var ( anthropicValidationURL = "https://api.anthropic.com/v1/messages" geminiValidationURL = "https://generativelanguage.googleapis.com/v1beta/models" ollamaValidationURL = "http://localhost:11434/api/tags" + tavilyValidationURL = "https://api.tavily.com/search" perplexityValidationURL = "https://api.perplexity.ai/chat/completions" ) @@ -140,6 +141,52 @@ func validateOllamaConnection(ctx context.Context) error { return nil } +// validateWebSearchKey validates a web search API key based on the provider. +func validateWebSearchKey(provider, apiKey string) error { + switch provider { + case "tavily": + return validateTavilyKey(apiKey) + case "perplexity": + return validatePerplexityKey(apiKey) + default: + return fmt.Errorf("unknown web search provider %q", provider) + } +} + +// validateTavilyKey validates a Tavily API key with a minimal search request. +func validateTavilyKey(apiKey string) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + body := map[string]any{ + "query": "test", + "max_results": 1, + } + bodyBytes, _ := json.Marshal(body) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilyValidationURL, bytes.NewReader(bodyBytes)) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("connecting to Tavily: %w", err) + } + defer func() { _ = resp.Body.Close() }() + _, _ = io.Copy(io.Discard, resp.Body) + + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return fmt.Errorf("invalid Tavily API key (%d)", resp.StatusCode) + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("tavily API returned status %d", resp.StatusCode) + } + return nil +} + // validatePerplexityKey validates a Perplexity API key with a minimal request. func validatePerplexityKey(apiKey string) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) diff --git a/forge-cli/cmd/init_validate_test.go b/forge-cli/cmd/init_validate_test.go index da707d5..d87cdfb 100644 --- a/forge-cli/cmd/init_validate_test.go +++ b/forge-cli/cmd/init_validate_test.go @@ -129,6 +129,46 @@ func TestValidateProviderKey_Timeout(t *testing.T) { } } +func TestValidateTavilyKey_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer valid-key" { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"query":"test","results":[]}`)) + })) + defer server.Close() + + orig := tavilyValidationURL + tavilyValidationURL = server.URL + defer func() { tavilyValidationURL = orig }() + + err := validateWebSearchKey("tavily", "valid-key") + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } +} + +func TestValidateTavilyKey_Unauthorized(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + orig := tavilyValidationURL + tavilyValidationURL = server.URL + defer func() { tavilyValidationURL = orig }() + + err := validateWebSearchKey("tavily", "bad-key") + if err == nil { + t.Fatal("expected error for unauthorized key") + } + if !strings.Contains(err.Error(), "invalid") { + t.Errorf("expected error containing 'invalid', got: %v", err) + } +} + func TestValidatePerplexityKey_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer valid-key" { @@ -144,7 +184,7 @@ func TestValidatePerplexityKey_Success(t *testing.T) { perplexityValidationURL = server.URL defer func() { perplexityValidationURL = orig }() - err := validatePerplexityKey("valid-key") + err := validateWebSearchKey("perplexity", "valid-key") if err != nil { t.Fatalf("expected nil error, got: %v", err) } @@ -160,7 +200,7 @@ func TestValidatePerplexityKey_Unauthorized(t *testing.T) { perplexityValidationURL = server.URL defer func() { perplexityValidationURL = orig }() - err := validatePerplexityKey("bad-key") + err := validateWebSearchKey("perplexity", "bad-key") if err == nil { t.Fatal("expected error for unauthorized key") } diff --git a/forge-cli/cmd/skills.go b/forge-cli/cmd/skills.go index a1793c1..d651e84 100644 --- a/forge-cli/cmd/skills.go +++ b/forge-cli/cmd/skills.go @@ -1,13 +1,16 @@ package cmd import ( + "bufio" "fmt" "os" + "os/exec" "path/filepath" "strings" "github.com/initializ/forge/forge-cli/config" cliskills "github.com/initializ/forge/forge-cli/skills" + skillreg "github.com/initializ/forge/forge-core/registry" coreskills "github.com/initializ/forge/forge-core/skills" "github.com/spf13/cobra" ) @@ -23,8 +26,115 @@ var skillsValidateCmd = &cobra.Command{ RunE: runSkillsValidate, } +var skillsAddCmd = &cobra.Command{ + Use: "add ", + Short: "Add a registry skill to the current project", + Args: cobra.ExactArgs(1), + RunE: runSkillsAdd, +} + func init() { skillsCmd.AddCommand(skillsValidateCmd) + skillsCmd.AddCommand(skillsAddCmd) +} + +func runSkillsAdd(cmd *cobra.Command, args []string) error { + name := args[0] + + // Look up skill in registry + info := skillreg.GetSkillByName(name) + if info == nil { + return fmt.Errorf("skill %q not found in registry", name) + } + + wd, err := os.Getwd() + if err != nil { + return fmt.Errorf("getting working directory: %w", err) + } + + // Write skill markdown + skillDir := filepath.Join(wd, "skills") + if err := os.MkdirAll(skillDir, 0o755); err != nil { + return fmt.Errorf("creating skills directory: %w", err) + } + + content, err := skillreg.LoadSkillFile(name) + if err != nil { + return fmt.Errorf("loading skill file: %w", err) + } + + skillPath := filepath.Join(skillDir, name+".md") + if err := os.WriteFile(skillPath, content, 0o644); err != nil { + return fmt.Errorf("writing skill file: %w", err) + } + fmt.Printf(" Added skill file: skills/%s.md\n", name) + + // Write script if the skill has one + if skillreg.HasSkillScript(name) { + scriptContent, sErr := skillreg.LoadSkillScript(name) + if sErr == nil { + scriptDir := filepath.Join(skillDir, "scripts") + if mkErr := os.MkdirAll(scriptDir, 0o755); mkErr != nil { + fmt.Printf(" Warning: could not create scripts directory: %s\n", mkErr) + } else { + scriptPath := filepath.Join(scriptDir, name+".sh") + if wErr := os.WriteFile(scriptPath, scriptContent, 0o755); wErr != nil { + fmt.Printf(" Warning: could not write script: %s\n", wErr) + } else { + fmt.Printf(" Added script: skills/scripts/%s.sh\n", name) + } + } + } + } + + // Check binary requirements + if len(info.RequiredBins) > 0 { + fmt.Println("\n Binary requirements:") + for _, bin := range info.RequiredBins { + if _, lookErr := exec.LookPath(bin); lookErr != nil { + fmt.Printf(" %s — MISSING (not found in PATH)\n", bin) + } else { + fmt.Printf(" %s — ok\n", bin) + } + } + } + + // Check env var requirements + missingEnvs := []string{} + if len(info.RequiredEnv) > 0 { + fmt.Println("\n Environment requirements:") + for _, env := range info.RequiredEnv { + if os.Getenv(env) == "" { + fmt.Printf(" %s — NOT SET\n", env) + missingEnvs = append(missingEnvs, env) + } else { + fmt.Printf(" %s — ok\n", env) + } + } + } + + // Prompt for missing env vars + if len(missingEnvs) > 0 { + reader := bufio.NewReader(os.Stdin) + for _, env := range missingEnvs { + fmt.Printf("\n Enter value for %s (or press Enter to skip): ", env) + val, _ := reader.ReadString('\n') + val = strings.TrimSpace(val) + if val != "" { + // Append to .env file + envPath := filepath.Join(wd, ".env") + f, fErr := os.OpenFile(envPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if fErr == nil { + _, _ = fmt.Fprintf(f, "# Required by %s skill\n%s=%s\n", name, env, val) + _ = f.Close() + fmt.Printf(" Added %s to .env\n", env) + } + } + } + } + + fmt.Printf("\nSkill %q added successfully.\n", info.DisplayName) + return nil } func runSkillsValidate(cmd *cobra.Command, args []string) error { diff --git a/forge-cli/internal/tui/components/egress_display.go b/forge-cli/internal/tui/components/egress_display.go index 18b6646..5d03dcf 100644 --- a/forge-cli/internal/tui/components/egress_display.go +++ b/forge-cli/internal/tui/components/egress_display.go @@ -47,8 +47,9 @@ func NewEgressDisplay(domains []EgressDomain, primaryStyle, dimStyle, borderStyl } } -// Init returns no initial command. -func (e EgressDisplay) Init() tea.Cmd { +// Init resets done state so the component can be re-used after back-navigation. +func (e *EgressDisplay) Init() tea.Cmd { + e.done = false return nil } diff --git a/forge-cli/internal/tui/components/multi_select.go b/forge-cli/internal/tui/components/multi_select.go index 5c497e0..b989b76 100644 --- a/forge-cli/internal/tui/components/multi_select.go +++ b/forge-cli/internal/tui/components/multi_select.go @@ -53,8 +53,9 @@ func NewMultiSelect(items []MultiSelectItem, accentColor, accentDimColor, primar } } -// Init returns no initial command. -func (m MultiSelect) Init() tea.Cmd { +// Init resets done state so the component can be re-used after back-navigation. +func (m *MultiSelect) Init() tea.Cmd { + m.done = false return nil } diff --git a/forge-cli/internal/tui/components/single_select.go b/forge-cli/internal/tui/components/single_select.go index bd4aeb8..39b3d04 100644 --- a/forge-cli/internal/tui/components/single_select.go +++ b/forge-cli/internal/tui/components/single_select.go @@ -59,8 +59,9 @@ func NewSingleSelect(items []SingleSelectItem, accentColor, primaryColor, second } } -// Init returns no initial command. -func (s SingleSelect) Init() tea.Cmd { +// Init resets done state so the component can be re-used after back-navigation. +func (s *SingleSelect) Init() tea.Cmd { + s.done = false return nil } diff --git a/forge-cli/internal/tui/steps/egress_step.go b/forge-cli/internal/tui/steps/egress_step.go index 0d3e41a..d21101e 100644 --- a/forge-cli/internal/tui/steps/egress_step.go +++ b/forge-cli/internal/tui/steps/egress_step.go @@ -10,7 +10,7 @@ import ( ) // DeriveEgressFunc computes egress domains from wizard context. -type DeriveEgressFunc func(provider string, channels, tools, skills []string) []string +type DeriveEgressFunc func(provider string, channels, tools, skills []string, envVars map[string]string) []string // EgressStep handles egress domain review. type EgressStep struct { @@ -40,7 +40,7 @@ func (s *EgressStep) Prepare(ctx *tui.WizardContext) { s.domains = nil if s.deriveFn != nil { - s.domains = s.deriveFn(ctx.Provider, channels, ctx.BuiltinTools, ctx.Skills) + s.domains = s.deriveFn(ctx.Provider, channels, ctx.BuiltinTools, ctx.Skills, ctx.EnvVars) } s.empty = len(s.domains) == 0 @@ -149,6 +149,7 @@ func inferSource(domain string, ctx *tui.WizardContext) string { // Tool domains toolDomains := map[string]string{ + "api.tavily.com": "web_search tool", "api.perplexity.ai": "web_search tool", } if src, ok := toolDomains[domain]; ok { diff --git a/forge-cli/internal/tui/steps/provider_step.go b/forge-cli/internal/tui/steps/provider_step.go index 1d987f3..09acfde 100644 --- a/forge-cli/internal/tui/steps/provider_step.go +++ b/forge-cli/internal/tui/steps/provider_step.go @@ -351,6 +351,19 @@ func (s *ProviderStep) Apply(ctx *tui.WizardContext) { ctx.CustomBaseURL = s.customURL ctx.CustomModel = s.customModel ctx.CustomAPIKey = s.customAuth + + // Store the provider API key in EnvVars so later steps (e.g. skills) + // can detect it's already collected and skip re-prompting. + if s.apiKey != "" { + switch s.provider { + case "openai": + ctx.EnvVars["OPENAI_API_KEY"] = s.apiKey + case "anthropic": + ctx.EnvVars["ANTHROPIC_API_KEY"] = s.apiKey + case "gemini": + ctx.EnvVars["GEMINI_API_KEY"] = s.apiKey + } + } } func providerDisplayName(provider string) string { diff --git a/forge-cli/internal/tui/steps/review_step.go b/forge-cli/internal/tui/steps/review_step.go index 98b8a97..ce8b79c 100644 --- a/forge-cli/internal/tui/steps/review_step.go +++ b/forge-cli/internal/tui/steps/review_step.go @@ -45,7 +45,17 @@ func (s *ReviewStep) Prepare(ctx *tui.WizardContext) { } if len(ctx.BuiltinTools) > 0 { - rows = append(rows, components.SummaryRow{Key: "Tools", Value: strings.Join(ctx.BuiltinTools, ", ")}) + var toolNames []string + for _, name := range ctx.BuiltinTools { + if name == "web_search" { + if p := ctx.EnvVars["WEB_SEARCH_PROVIDER"]; p != "" { + toolNames = append(toolNames, fmt.Sprintf("web_search [%s]", p)) + continue + } + } + toolNames = append(toolNames, name) + } + rows = append(rows, components.SummaryRow{Key: "Tools", Value: strings.Join(toolNames, ", ")}) } if len(ctx.Skills) > 0 { @@ -83,8 +93,6 @@ func (s *ReviewStep) Update(msg tea.Msg) (tui.Step, tea.Cmd) { return s, func() tea.Msg { return tui.StepCompleteMsg{} } case "backspace": return s, func() tea.Msg { return tui.StepBackMsg{} } - case "esc": - return s, func() tea.Msg { return tui.StepBackMsg{} } } } return s, nil diff --git a/forge-cli/internal/tui/steps/skills_step.go b/forge-cli/internal/tui/steps/skills_step.go index c04888b..8d8fae0 100644 --- a/forge-cli/internal/tui/steps/skills_step.go +++ b/forge-cli/internal/tui/steps/skills_step.go @@ -2,6 +2,7 @@ package steps import ( "fmt" + "os" "strings" tea "github.com/charmbracelet/bubbletea" @@ -16,33 +17,60 @@ type SkillInfo struct { DisplayName string Description string RequiredEnv []string + OneOfEnv []string + OptionalEnv []string RequiredBins []string EgressDomains []string } +type skillsPhase int + +const ( + skillsSelectPhase skillsPhase = iota + skillsEnvPhase +) + +// envPrompt describes a single env var prompt to show. +type envPrompt struct { + envVar string + label string + allowSkip bool + skillName string + kind string // "required", "one_of", "optional" +} + // SkillsStep handles external skill selection. type SkillsStep struct { styles *tui.StyleSet + allSkills []SkillInfo multiSelect components.MultiSelect + phase skillsPhase complete bool selected []string - empty bool // true if no skills available + empty bool + + // Env prompting + envPrompts []envPrompt + currentPrompt int + keyInput components.SecretInput + envValues map[string]string + knownEnvVars map[string]string // env vars already collected by earlier steps } // NewSkillsStep creates a new skills selection step. func NewSkillsStep(styles *tui.StyleSet, skills []SkillInfo) *SkillsStep { if len(skills) == 0 { return &SkillsStep{ - styles: styles, - complete: false, - empty: true, + styles: styles, + complete: false, + empty: true, + envValues: make(map[string]string), } } var items []components.MultiSelectItem for _, sk := range skills { icon := skillIcon(sk.Name) - var reqLine string var reqs []string if len(sk.RequiredBins) > 0 { reqs = append(reqs, "bins: "+strings.Join(sk.RequiredBins, ", ")) @@ -50,6 +78,10 @@ func NewSkillsStep(styles *tui.StyleSet, skills []SkillInfo) *SkillsStep { if len(sk.RequiredEnv) > 0 { reqs = append(reqs, "env: "+strings.Join(sk.RequiredEnv, ", ")) } + if len(sk.OneOfEnv) > 0 { + reqs = append(reqs, "one of: "+strings.Join(sk.OneOfEnv, " / ")) + } + var reqLine string if len(reqs) > 0 { reqLine = strings.Join(reqs, " · ") } @@ -78,7 +110,17 @@ func NewSkillsStep(styles *tui.StyleSet, skills []SkillInfo) *SkillsStep { return &SkillsStep{ styles: styles, + allSkills: skills, multiSelect: ms, + envValues: make(map[string]string), + } +} + +// Prepare captures env vars already collected by earlier wizard steps. +func (s *SkillsStep) Prepare(ctx *tui.WizardContext) { + s.knownEnvVars = make(map[string]string) + for k, v := range ctx.EnvVars { + s.knownEnvVars[k] = v } } @@ -87,6 +129,10 @@ func (s *SkillsStep) Icon() string { return "📦" } func (s *SkillsStep) Init() tea.Cmd { s.complete = false + s.phase = skillsSelectPhase + s.currentPrompt = 0 + s.envPrompts = nil + s.envValues = make(map[string]string) if s.empty { s.complete = true return func() tea.Msg { return tui.StepCompleteMsg{} } @@ -99,23 +145,241 @@ func (s *SkillsStep) Update(msg tea.Msg) (tui.Step, tea.Cmd) { return s, nil } - updated, cmd := s.multiSelect.Update(msg) - s.multiSelect = updated + switch s.phase { + case skillsSelectPhase: + updated, cmd := s.multiSelect.Update(msg) + s.multiSelect = updated - if s.multiSelect.Done() { - s.selected = s.multiSelect.SelectedValues() - s.complete = true - return s, func() tea.Msg { return tui.StepCompleteMsg{} } + if s.multiSelect.Done() { + s.selected = s.multiSelect.SelectedValues() + + // Build env prompts for selected skills + s.buildEnvPrompts() + + if len(s.envPrompts) == 0 { + s.complete = true + return s, func() tea.Msg { return tui.StepCompleteMsg{} } + } + + // Start env prompting + s.phase = skillsEnvPhase + s.currentPrompt = 0 + s.initCurrentPrompt() + return s, s.keyInput.Init() + } + + return s, cmd + + case skillsEnvPhase: + updated, cmd := s.keyInput.Update(msg) + s.keyInput = updated + + if s.keyInput.Done() { + val := s.keyInput.Value() + prompt := s.envPrompts[s.currentPrompt] + if val != "" { + s.envValues[prompt.envVar] = val + } + + s.currentPrompt++ + + // Check if we're done with all prompts + if s.currentPrompt >= len(s.envPrompts) { + // Check one_of groups + if s.checkOneOfGroups() { + s.complete = true + return s, func() tea.Msg { return tui.StepCompleteMsg{} } + } + // One or more one_of groups unsatisfied — prompts were appended + } + + s.initCurrentPrompt() + return s, s.keyInput.Init() + } + + return s, cmd + } + + return s, nil +} + +// envAlreadyKnown returns true if the env var is already set in OS env or +// was collected by an earlier wizard step (provider key, web search key, etc.). +func (s *SkillsStep) envAlreadyKnown(env string) bool { + if os.Getenv(env) != "" { + return true + } + if v, ok := s.knownEnvVars[env]; ok && v != "" { + return true + } + return false +} + +// buildEnvPrompts creates the list of env prompts for selected skills. +func (s *SkillsStep) buildEnvPrompts() { + s.envPrompts = nil + seen := make(map[string]bool) + + for _, skillName := range s.selected { + sk := s.findSkill(skillName) + if sk == nil { + continue + } + + // Required env vars + for _, env := range sk.RequiredEnv { + if seen[env] || s.envAlreadyKnown(env) { + continue + } + seen[env] = true + s.envPrompts = append(s.envPrompts, envPrompt{ + envVar: env, + label: fmt.Sprintf("%s (required by %s)", env, sk.DisplayName), + allowSkip: false, + skillName: sk.Name, + kind: "required", + }) + } + + // One-of env vars + if len(sk.OneOfEnv) > 0 { + // Check if any one-of is already available + anySet := false + for _, env := range sk.OneOfEnv { + if s.envAlreadyKnown(env) { + anySet = true + break + } + } + if !anySet { + for _, env := range sk.OneOfEnv { + if seen[env] { + continue + } + seen[env] = true + s.envPrompts = append(s.envPrompts, envPrompt{ + envVar: env, + label: fmt.Sprintf("%s (one of %s — %s)", env, strings.Join(sk.OneOfEnv, " / "), sk.DisplayName), + allowSkip: true, // initially skippable, but group must have at least one + skillName: sk.Name, + kind: "one_of", + }) + } + } + } + + // Optional env vars + for _, env := range sk.OptionalEnv { + if seen[env] || s.envAlreadyKnown(env) { + continue + } + seen[env] = true + s.envPrompts = append(s.envPrompts, envPrompt{ + envVar: env, + label: fmt.Sprintf("%s (optional — %s)", env, sk.DisplayName), + allowSkip: true, + skillName: sk.Name, + kind: "optional", + }) + } + } +} + +// checkOneOfGroups verifies that all one_of groups have at least one value. +// If not, appends a mandatory re-prompt and returns false. +func (s *SkillsStep) checkOneOfGroups() bool { + // Collect one_of skills that need checking + type group struct { + skillName string + envVars []string } + seen := make(map[string]bool) + var groups []group - return s, cmd + for _, p := range s.envPrompts { + if p.kind != "one_of" || seen[p.skillName] { + continue + } + seen[p.skillName] = true + sk := s.findSkill(p.skillName) + if sk != nil { + groups = append(groups, group{skillName: p.skillName, envVars: sk.OneOfEnv}) + } + } + + allSatisfied := true + for _, g := range groups { + hasValue := false + for _, env := range g.envVars { + if v, ok := s.envValues[env]; ok && v != "" { + hasValue = true + break + } + } + if !hasValue { + // Re-prompt the first env var as required + sk := s.findSkill(g.skillName) + displayName := g.skillName + if sk != nil { + displayName = sk.DisplayName + } + label := fmt.Sprintf("%s (required — at least one needed for %s)", g.envVars[0], displayName) + s.envPrompts = append(s.envPrompts, envPrompt{ + envVar: g.envVars[0], + label: label, + allowSkip: false, + skillName: g.skillName, + kind: "required", + }) + allSatisfied = false + } + } + + return allSatisfied +} + +func (s *SkillsStep) initCurrentPrompt() { + if s.currentPrompt >= len(s.envPrompts) { + return + } + prompt := s.envPrompts[s.currentPrompt] + s.keyInput = components.NewSecretInput( + prompt.label, + prompt.allowSkip, + s.styles.Theme.Accent, + s.styles.Theme.Success, + s.styles.Theme.Error, + s.styles.Theme.Border, + s.styles.AccentTxt, + s.styles.InactiveBorder, + s.styles.SuccessTxt, + s.styles.ErrorTxt, + s.styles.DimTxt, + s.styles.KbdKey, + s.styles.KbdDesc, + ) +} + +func (s *SkillsStep) findSkill(name string) *SkillInfo { + for i := range s.allSkills { + if s.allSkills[i].Name == name { + return &s.allSkills[i] + } + } + return nil } func (s *SkillsStep) View(width int) string { if s.empty { return fmt.Sprintf(" %s\n", s.styles.DimTxt.Render("No skills available in registry.")) } - return s.multiSelect.View(width) + switch s.phase { + case skillsSelectPhase: + return s.multiSelect.View(width) + case skillsEnvPhase: + return s.keyInput.View(width) + } + return "" } func (s *SkillsStep) Complete() bool { @@ -131,13 +395,17 @@ func (s *SkillsStep) Summary() string { func (s *SkillsStep) Apply(ctx *tui.WizardContext) { ctx.Skills = s.selected + for k, v := range s.envValues { + ctx.EnvVars[k] = v + } } func skillIcon(name string) string { icons := map[string]string{ - "summarize": "🧾", - "github": "🐙", - "weather": "🌤️", + "summarize": "🧾", + "github": "🐙", + "weather": "🌤️", + "tavily-search": "🔍", } if icon, ok := icons[name]; ok { return icon diff --git a/forge-cli/internal/tui/steps/tools_step.go b/forge-cli/internal/tui/steps/tools_step.go index 58e6229..8fc547c 100644 --- a/forge-cli/internal/tui/steps/tools_step.go +++ b/forge-cli/internal/tui/steps/tools_step.go @@ -1,6 +1,7 @@ package steps import ( + "fmt" "os" "strings" @@ -16,31 +17,37 @@ type ToolInfo struct { Description string } +// ValidateWebSearchKeyFunc validates a web search API key for a given provider. +type ValidateWebSearchKeyFunc func(provider, key string) error + type toolsPhase int const ( toolsSelectPhase toolsPhase = iota - toolsPerplexityKeyPhase + toolsWebSearchProviderPhase + toolsWebSearchKeyPhase + toolsWebSearchValidatingPhase toolsDonePhase ) -// ValidatePerplexityFunc validates a Perplexity API key. -type ValidatePerplexityFunc func(key string) error - // ToolsStep handles builtin tool selection. type ToolsStep struct { - styles *tui.StyleSet - phase toolsPhase - multiSelect components.MultiSelect - keyInput components.SecretInput - complete bool - selected []string - perplexityKey string - validatePerp ValidatePerplexityFunc + styles *tui.StyleSet + phase toolsPhase + multiSelect components.MultiSelect + providerSelect components.SingleSelect + keyInput components.SecretInput + complete bool + selected []string + webSearchKey string + webSearchKeyName string // "TAVILY_API_KEY" or "PERPLEXITY_API_KEY" + webSearchProvider string // "tavily" or "perplexity" + validateFn ValidateWebSearchKeyFunc + validating bool } // NewToolsStep creates a new tools selection step. -func NewToolsStep(styles *tui.StyleSet, tools []ToolInfo, validatePerp ValidatePerplexityFunc) *ToolsStep { +func NewToolsStep(styles *tui.StyleSet, tools []ToolInfo, validateFn ValidateWebSearchKeyFunc) *ToolsStep { var items []components.MultiSelectItem for _, t := range tools { icon := toolIcon(t.Name) @@ -66,9 +73,9 @@ func NewToolsStep(styles *tui.StyleSet, tools []ToolInfo, validatePerp ValidateP ) return &ToolsStep{ - styles: styles, - multiSelect: ms, - validatePerp: validatePerp, + styles: styles, + multiSelect: ms, + validateFn: validateFn, } } @@ -92,25 +99,37 @@ func (s *ToolsStep) Update(msg tea.Msg) (tui.Step, tea.Cmd) { if s.multiSelect.Done() { s.selected = s.multiSelect.SelectedValues() - // Check if web_search selected and no perplexity key - if containsStr(s.selected, "web_search") && os.Getenv("PERPLEXITY_API_KEY") == "" { - s.phase = toolsPerplexityKeyPhase - s.keyInput = components.NewSecretInput( - "Perplexity API key for web_search", - true, + // Check if web_search selected and no key is already set + if containsStr(s.selected, "web_search") && + os.Getenv("TAVILY_API_KEY") == "" && + os.Getenv("PERPLEXITY_API_KEY") == "" { + // Show provider selection + s.phase = toolsWebSearchProviderPhase + s.providerSelect = components.NewSingleSelect( + []components.SingleSelectItem{ + {Label: "Tavily (Recommended)", Value: "tavily", Description: "LLM-optimized search with structured results", Icon: "🔍"}, + {Label: "Perplexity", Value: "perplexity", Description: "AI-powered search with citations", Icon: "🌐"}, + }, s.styles.Theme.Accent, - s.styles.Theme.Success, - s.styles.Theme.Error, + s.styles.Theme.Primary, + s.styles.Theme.Secondary, + s.styles.Theme.Dim, s.styles.Theme.Border, - s.styles.AccentTxt, - s.styles.InactiveBorder, - s.styles.SuccessTxt, - s.styles.ErrorTxt, - s.styles.DimTxt, + s.styles.Theme.Accent, + s.styles.Theme.AccentDim, s.styles.KbdKey, s.styles.KbdDesc, ) - return s, s.keyInput.Init() + return s, s.providerSelect.Init() + } + + // If a key is already set in env, detect the provider + if containsStr(s.selected, "web_search") { + if os.Getenv("TAVILY_API_KEY") != "" { + s.webSearchProvider = "tavily" + } else if os.Getenv("PERPLEXITY_API_KEY") != "" { + s.webSearchProvider = "perplexity" + } } s.complete = true @@ -119,27 +138,114 @@ func (s *ToolsStep) Update(msg tea.Msg) (tui.Step, tea.Cmd) { return s, cmd - case toolsPerplexityKeyPhase: + case toolsWebSearchProviderPhase: + updated, cmd := s.providerSelect.Update(msg) + s.providerSelect = updated + + if s.providerSelect.Done() { + _, s.webSearchProvider = s.providerSelect.Selected() + s.initKeyInput("") + return s, s.keyInput.Init() + } + + return s, cmd + + case toolsWebSearchKeyPhase: updated, cmd := s.keyInput.Update(msg) s.keyInput = updated if s.keyInput.Done() { - s.perplexityKey = s.keyInput.Value() + s.webSearchKey = s.keyInput.Value() + + // Run validation if we have a key and a validateFn + if s.webSearchKey != "" && s.validateFn != nil { + s.phase = toolsWebSearchValidatingPhase + s.validating = true + return s, s.runValidation() + } + s.complete = true return s, func() tea.Msg { return tui.StepCompleteMsg{} } } return s, cmd + + case toolsWebSearchValidatingPhase: + if msg, ok := msg.(tui.ValidationResultMsg); ok { + s.validating = false + if msg.Err != nil { + // Validation failed — go back to key input with error + s.initKeyInput(fmt.Sprintf("retry — %s", msg.Err)) + s.keyInput.SetState(components.SecretInputFailed, msg.Err.Error()) + return s, s.keyInput.Init() + } + // Success + s.complete = true + return s, func() tea.Msg { return tui.StepCompleteMsg{} } + } + + return s, nil } return s, nil } +// initKeyInput creates a fresh SecretInput for the web search API key. +func (s *ToolsStep) initKeyInput(suffix string) { + keyLabel := "Tavily API key for web_search" + s.webSearchKeyName = "TAVILY_API_KEY" + if s.webSearchProvider == "perplexity" { + keyLabel = "Perplexity API key for web_search" + s.webSearchKeyName = "PERPLEXITY_API_KEY" + } + if suffix != "" { + keyLabel = fmt.Sprintf("%s (%s)", keyLabel, suffix) + } + + s.phase = toolsWebSearchKeyPhase + s.keyInput = components.NewSecretInput( + keyLabel, + false, // required — cannot skip + s.styles.Theme.Accent, + s.styles.Theme.Success, + s.styles.Theme.Error, + s.styles.Theme.Border, + s.styles.AccentTxt, + s.styles.InactiveBorder, + s.styles.SuccessTxt, + s.styles.ErrorTxt, + s.styles.DimTxt, + s.styles.KbdKey, + s.styles.KbdDesc, + ) +} + +// runValidation runs the web search key validation asynchronously. +func (s *ToolsStep) runValidation() tea.Cmd { + provider := s.webSearchProvider + key := s.webSearchKey + validateFn := s.validateFn + return func() tea.Msg { + if validateFn == nil { + return tui.ValidationResultMsg{Err: nil} + } + err := validateFn(provider, key) + return tui.ValidationResultMsg{Err: err} + } +} + func (s *ToolsStep) View(width int) string { switch s.phase { case toolsSelectPhase: return s.multiSelect.View(width) - case toolsPerplexityKeyPhase: + case toolsWebSearchProviderPhase: + return s.providerSelect.View(width) + case toolsWebSearchKeyPhase: + return s.keyInput.View(width) + case toolsWebSearchValidatingPhase: + if s.validating { + return " " + s.styles.AccentTxt.Render("⣾ Validating...") + "\n" + } return s.keyInput.View(width) } return "" @@ -153,13 +259,24 @@ func (s *ToolsStep) Summary() string { if len(s.selected) == 0 { return "none" } - return strings.Join(s.selected, ", ") + var parts []string + for _, name := range s.selected { + if name == "web_search" && s.webSearchProvider != "" { + parts = append(parts, fmt.Sprintf("web_search [%s]", s.webSearchProvider)) + } else { + parts = append(parts, name) + } + } + return strings.Join(parts, ", ") } func (s *ToolsStep) Apply(ctx *tui.WizardContext) { ctx.BuiltinTools = s.selected - if s.perplexityKey != "" { - ctx.EnvVars["PERPLEXITY_API_KEY"] = s.perplexityKey + if s.webSearchKey != "" && s.webSearchKeyName != "" { + ctx.EnvVars[s.webSearchKeyName] = s.webSearchKey + } + if s.webSearchProvider != "" { + ctx.EnvVars["WEB_SEARCH_PROVIDER"] = s.webSearchProvider } } diff --git a/forge-cli/internal/tui/wizard.go b/forge-cli/internal/tui/wizard.go index 6224949..1ff8bca 100644 --- a/forge-cli/internal/tui/wizard.go +++ b/forge-cli/internal/tui/wizard.go @@ -93,12 +93,10 @@ func (w WizardModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return w, nil case tea.KeyMsg: - if msg.String() == "ctrl+c" { + if msg.String() == "ctrl+c" || msg.String() == "esc" { w.err = fmt.Errorf("wizard cancelled") return w, tea.Quit } - // Don't intercept esc globally — let steps handle it first. - // Only quit on esc if we're at the first step in its initial phase. case StepBackMsg: if w.current > 0 { diff --git a/forge-core/registry/index.json b/forge-core/registry/index.json index fbdb831..6b00870 100644 --- a/forge-core/registry/index.json +++ b/forge-core/registry/index.json @@ -26,5 +26,14 @@ "skill_file": "weather.md", "required_bins": ["curl"], "egress_domains": ["api.openweathermap.org", "api.weatherapi.com"] + }, + { + "name": "tavily-search", + "display_name": "Tavily Search", + "description": "Search the web using Tavily AI search API", + "skill_file": "tavily-search.md", + "required_env": ["TAVILY_API_KEY"], + "required_bins": ["curl", "jq"], + "egress_domains": ["api.tavily.com"] } ] diff --git a/forge-core/registry/registry.go b/forge-core/registry/registry.go index e21d5e6..98d3ebd 100644 --- a/forge-core/registry/registry.go +++ b/forge-core/registry/registry.go @@ -10,6 +10,9 @@ import ( //go:embed skills var skillFS embed.FS +//go:embed scripts +var scriptFS embed.FS + //go:embed index.json var indexJSON []byte @@ -53,3 +56,14 @@ func GetSkillByName(name string) *SkillInfo { } return nil } + +// LoadSkillScript reads an embedded script for a skill. +func LoadSkillScript(name string) ([]byte, error) { + return scriptFS.ReadFile("scripts/" + name + ".sh") +} + +// HasSkillScript checks if a skill has an embedded script. +func HasSkillScript(name string) bool { + _, err := scriptFS.ReadFile("scripts/" + name + ".sh") + return err == nil +} diff --git a/forge-core/registry/registry_test.go b/forge-core/registry/registry_test.go index 9034602..811ea41 100644 --- a/forge-core/registry/registry_test.go +++ b/forge-core/registry/registry_test.go @@ -29,7 +29,7 @@ func TestLoadIndex(t *testing.T) { } } - for _, expected := range []string{"summarize", "github", "weather"} { + for _, expected := range []string{"summarize", "github", "weather", "tavily-search"} { if !names[expected] { t.Errorf("expected skill %q not found in index", expected) } @@ -107,3 +107,66 @@ func TestWeatherSkillRequiredBins(t *testing.T) { t.Error("weather skill should require curl binary") } } + +func TestTavilySearchSkillRequirements(t *testing.T) { + s := GetSkillByName("tavily-search") + if s == nil { + t.Fatal("tavily-search skill not found") + } + if s.DisplayName != "Tavily Search" { + t.Errorf("expected display_name \"Tavily Search\", got %q", s.DisplayName) + } + if len(s.RequiredEnv) == 0 { + t.Error("tavily-search skill should have required_env") + } + foundKey := false + for _, env := range s.RequiredEnv { + if env == "TAVILY_API_KEY" { + foundKey = true + } + } + if !foundKey { + t.Error("tavily-search skill should require TAVILY_API_KEY") + } + if len(s.RequiredBins) < 2 { + t.Error("tavily-search skill should require curl and jq") + } + if len(s.EgressDomains) == 0 { + t.Error("tavily-search skill should have egress_domains") + } + foundDomain := false + for _, d := range s.EgressDomains { + if d == "api.tavily.com" { + foundDomain = true + } + } + if !foundDomain { + t.Error("tavily-search skill should have api.tavily.com egress domain") + } +} + +func TestLoadSkillScript(t *testing.T) { + // tavily-search should have a script + if !HasSkillScript("tavily-search") { + t.Fatal("HasSkillScript(\"tavily-search\") returned false") + } + + data, err := LoadSkillScript("tavily-search") + if err != nil { + t.Fatalf("LoadSkillScript(\"tavily-search\") error: %v", err) + } + if len(data) == 0 { + t.Error("LoadSkillScript(\"tavily-search\") returned empty content") + } + if !strings.Contains(string(data), "TAVILY_API_KEY") { + t.Error("tavily-search script should reference TAVILY_API_KEY") + } + + // Skills without scripts should return false + if HasSkillScript("github") { + t.Error("HasSkillScript(\"github\") should return false") + } + if HasSkillScript("nonexistent") { + t.Error("HasSkillScript(\"nonexistent\") should return false") + } +} diff --git a/forge-core/registry/scripts/tavily-search.sh b/forge-core/registry/scripts/tavily-search.sh new file mode 100755 index 0000000..8635a3f --- /dev/null +++ b/forge-core/registry/scripts/tavily-search.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash +# tavily-search.sh — Search the web using the Tavily API. +# Usage: ./tavily-search.sh '{"query": "search terms", "max_results": 5}' +# +# Requires: curl, jq, TAVILY_API_KEY environment variable. +set -euo pipefail + +# --- Validate environment --- +if [ -z "${TAVILY_API_KEY:-}" ]; then + echo '{"error": "TAVILY_API_KEY environment variable is not set"}' >&2 + exit 1 +fi + +# --- Read input --- +INPUT="${1:-}" +if [ -z "$INPUT" ]; then + echo '{"error": "usage: tavily-search.sh {\"query\": \"...\"}"}' >&2 + exit 1 +fi + +# Validate JSON +if ! echo "$INPUT" | jq empty 2>/dev/null; then + echo '{"error": "invalid JSON input"}' >&2 + exit 1 +fi + +# --- Check required fields --- +QUERY=$(echo "$INPUT" | jq -r '.query // empty') +if [ -z "$QUERY" ]; then + echo '{"error": "query field is required"}' >&2 + exit 1 +fi + +# --- Call Tavily API --- +RESPONSE=$(curl -s -w "\n%{http_code}" \ + -X POST "https://api.tavily.com/search" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${TAVILY_API_KEY}" \ + -d "$INPUT") + +# Split response body and status code +HTTP_CODE=$(echo "$RESPONSE" | tail -1) +BODY=$(echo "$RESPONSE" | sed '$d') + +if [ "$HTTP_CODE" -ne 200 ]; then + echo "{\"error\": \"Tavily API returned status $HTTP_CODE\", \"details\": $BODY}" >&2 + exit 1 +fi + +# --- Pretty-print response --- +echo "$BODY" | jq . diff --git a/forge-core/registry/skills/tavily-search.md b/forge-core/registry/skills/tavily-search.md new file mode 100644 index 0000000..08a08c5 --- /dev/null +++ b/forge-core/registry/skills/tavily-search.md @@ -0,0 +1,81 @@ +--- +name: tavily-search +description: Search the web using Tavily AI search API +metadata: + forge: + requires: + bins: + - curl + - jq + env: + required: + - TAVILY_API_KEY + one_of: [] + optional: [] +--- + +# Tavily Web Search Skill + +Search the web using the Tavily AI search API, optimized for LLM applications. + +## Authentication + +Set the `TAVILY_API_KEY` environment variable with your Tavily API key. +Get your key at https://tavily.com + +No OAuth or MCP configuration required. + +## Quick Start + +```bash +./scripts/tavily-search.sh '{"query": "latest AI news"}' +``` + +## Tool: tavily_search + +Search the web using Tavily AI. + +**Input:** + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| query | string | yes | The search query | +| search_depth | string | no | `basic` (fast) or `advanced` (thorough). Default: `basic` | +| max_results | integer | no | Maximum results to return (1-20). Default: 5 | +| time_range | string | no | Filter by time: `day`, `week`, `month`, `year` | +| include_domains | array | no | Only include results from these domains | +| exclude_domains | array | no | Exclude results from these domains | + +**Output:** JSON object with `query`, `answer`, `results` (array of `{title, url, content, score}`), and `response_time`. + +### Search Depth + +| Depth | Speed | Detail | Use Case | +|-------|-------|--------|----------| +| basic | Fast (~1s) | Standard snippets | Quick lookups, fact checks | +| advanced | Slower (~3s) | Detailed content | Research, analysis | + +### Response Format + +```json +{ + "query": "your search query", + "answer": "AI-generated summary answer", + "response_time": 0.5, + "results": [ + { + "title": "Page Title", + "url": "https://example.com", + "content": "Relevant content snippet...", + "score": 0.95 + } + ] +} +``` + +### Tips + +- Use `search_depth: advanced` for research tasks that need detailed content +- Use `include_domains` to restrict searches to trusted sources +- Use `time_range: day` for breaking news or very recent information +- The `answer` field provides a concise AI-generated summary when available diff --git a/forge-core/security/tool_domains.go b/forge-core/security/tool_domains.go index db65bd6..afaf0d7 100644 --- a/forge-core/security/tool_domains.go +++ b/forge-core/security/tool_domains.go @@ -2,8 +2,8 @@ package security // DefaultToolDomains maps tool names to their known required domains. var DefaultToolDomains = map[string][]string{ - "web_search": {"api.perplexity.ai"}, - "web-search": {"api.perplexity.ai"}, + "web_search": {"api.tavily.com", "api.perplexity.ai"}, + "web-search": {"api.tavily.com", "api.perplexity.ai"}, "http_request": {}, // dynamic — depends on user config "slack_notify": {"slack.com", "hooks.slack.com"}, "github_api": {"api.github.com", "github.com"}, diff --git a/forge-core/tools/builtins/builtins_test.go b/forge-core/tools/builtins/builtins_test.go index ea34a27..9add8be 100644 --- a/forge-core/tools/builtins/builtins_test.go +++ b/forge-core/tools/builtins/builtins_test.go @@ -181,11 +181,21 @@ func TestMathCalculateTool_DivisionByZero(t *testing.T) { } func TestWebSearchTool_NoKey(t *testing.T) { - orig := os.Getenv("PERPLEXITY_API_KEY") + origTavily := os.Getenv("TAVILY_API_KEY") + origPerp := os.Getenv("PERPLEXITY_API_KEY") + origProvider := os.Getenv("WEB_SEARCH_PROVIDER") + _ = os.Unsetenv("TAVILY_API_KEY") _ = os.Unsetenv("PERPLEXITY_API_KEY") + _ = os.Unsetenv("WEB_SEARCH_PROVIDER") defer func() { - if orig != "" { - _ = os.Setenv("PERPLEXITY_API_KEY", orig) + if origTavily != "" { + _ = os.Setenv("TAVILY_API_KEY", origTavily) + } + if origPerp != "" { + _ = os.Setenv("PERPLEXITY_API_KEY", origPerp) + } + if origProvider != "" { + _ = os.Setenv("WEB_SEARCH_PROVIDER", origProvider) } }() @@ -195,8 +205,172 @@ func TestWebSearchTool_NoKey(t *testing.T) { if err != nil { t.Fatalf("Execute error: %v", err) } - if !strings.Contains(result, "PERPLEXITY_API_KEY") { - t.Errorf("expected missing key message, got: %q", result) + if !strings.Contains(result, "TAVILY_API_KEY") || !strings.Contains(result, "PERPLEXITY_API_KEY") { + t.Errorf("expected error mentioning both API keys, got: %q", result) + } +} + +func TestWebSearchTool_TavilyProvider(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify authorization header + if auth := r.Header.Get("Authorization"); auth != "Bearer test-tavily-key" { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ //nolint:errcheck + "query": "test query", + "response_time": 0.5, + "answer": "Tavily answer", + "results": []map[string]any{ + { + "title": "Result 1", + "url": "https://example.com", + "content": "Example content", + "score": 0.95, + }, + }, + }) + })) + defer ts.Close() + + // Create a Tavily provider with test server URL + p := &tavilyProvider{apiKey: "test-tavily-key", baseURL: ts.URL} + result, err := p.search(context.Background(), "test query", webSearchOpts{MaxResults: 5}) + if err != nil { + t.Fatalf("search error: %v", err) + } + if !strings.Contains(result, "Tavily answer") { + t.Errorf("expected Tavily answer in result, got: %q", result) + } + if !strings.Contains(result, "Result 1") { + t.Errorf("expected result title in result, got: %q", result) + } +} + +func TestWebSearchTool_PerplexityProvider(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if auth := r.Header.Get("Authorization"); auth != "Bearer test-perplexity-key" { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ //nolint:errcheck + "choices": []map[string]any{ + { + "message": map[string]string{ + "content": "Perplexity answer", + }, + }, + }, + "citations": []string{"https://source.com"}, + }) + })) + defer ts.Close() + + p := &perplexityProvider{apiKey: "test-perplexity-key", baseURL: ts.URL} + result, err := p.search(context.Background(), "test query", webSearchOpts{}) + if err != nil { + t.Fatalf("search error: %v", err) + } + if !strings.Contains(result, "Perplexity answer") { + t.Errorf("expected Perplexity answer in result, got: %q", result) + } + if !strings.Contains(result, "source.com") { + t.Errorf("expected citation in result, got: %q", result) + } +} + +func TestWebSearchTool_ProviderOverride(t *testing.T) { + origTavily := os.Getenv("TAVILY_API_KEY") + origPerp := os.Getenv("PERPLEXITY_API_KEY") + origProvider := os.Getenv("WEB_SEARCH_PROVIDER") + _ = os.Unsetenv("TAVILY_API_KEY") + _ = os.Unsetenv("PERPLEXITY_API_KEY") + _ = os.Setenv("WEB_SEARCH_PROVIDER", "tavily") + defer func() { + if origTavily != "" { + _ = os.Setenv("TAVILY_API_KEY", origTavily) + } else { + _ = os.Unsetenv("TAVILY_API_KEY") + } + if origPerp != "" { + _ = os.Setenv("PERPLEXITY_API_KEY", origPerp) + } else { + _ = os.Unsetenv("PERPLEXITY_API_KEY") + } + if origProvider != "" { + _ = os.Setenv("WEB_SEARCH_PROVIDER", origProvider) + } else { + _ = os.Unsetenv("WEB_SEARCH_PROVIDER") + } + }() + + tool := GetByName("web_search") + args, _ := json.Marshal(map[string]string{"query": "test"}) + result, err := tool.Execute(context.Background(), args) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + // Should error because TAVILY_API_KEY is not set + if !strings.Contains(result, "TAVILY_API_KEY") { + t.Errorf("expected missing TAVILY_API_KEY message, got: %q", result) + } +} + +func TestWebSearchTool_ExplicitPerplexity(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ //nolint:errcheck + "choices": []map[string]any{ + {"message": map[string]string{"content": "Perplexity explicit"}}, + }, + }) + })) + defer ts.Close() + + // Both keys set, but WEB_SEARCH_PROVIDER=perplexity -> should use Perplexity + origTavily := os.Getenv("TAVILY_API_KEY") + origPerp := os.Getenv("PERPLEXITY_API_KEY") + origProvider := os.Getenv("WEB_SEARCH_PROVIDER") + _ = os.Setenv("TAVILY_API_KEY", "some-tavily-key") + _ = os.Setenv("PERPLEXITY_API_KEY", "test-perp-key") + _ = os.Setenv("WEB_SEARCH_PROVIDER", "perplexity") + defer func() { + if origTavily != "" { + _ = os.Setenv("TAVILY_API_KEY", origTavily) + } else { + _ = os.Unsetenv("TAVILY_API_KEY") + } + if origPerp != "" { + _ = os.Setenv("PERPLEXITY_API_KEY", origPerp) + } else { + _ = os.Unsetenv("PERPLEXITY_API_KEY") + } + if origProvider != "" { + _ = os.Setenv("WEB_SEARCH_PROVIDER", origProvider) + } else { + _ = os.Unsetenv("WEB_SEARCH_PROVIDER") + } + }() + + // Use the provider directly with test server + p := &perplexityProvider{apiKey: "test-perp-key", baseURL: ts.URL} + result, err := p.search(context.Background(), "test", webSearchOpts{}) + if err != nil { + t.Fatalf("search error: %v", err) + } + if !strings.Contains(result, "Perplexity explicit") { + t.Errorf("expected Perplexity response, got: %q", result) + } + + // Also verify resolveWebSearchProvider picks Perplexity + provider, resolveErr := resolveWebSearchProvider() + if resolveErr != nil { + t.Fatalf("resolveWebSearchProvider error: %v", resolveErr) + } + if provider.name() != "perplexity" { + t.Errorf("expected perplexity provider, got %q", provider.name()) } } diff --git a/forge-core/tools/builtins/web_search.go b/forge-core/tools/builtins/web_search.go index fde0f68..9d79a74 100644 --- a/forge-core/tools/builtins/web_search.go +++ b/forge-core/tools/builtins/web_search.go @@ -1,12 +1,9 @@ package builtins import ( - "bytes" "context" "encoding/json" "fmt" - "io" - "net/http" "os" "github.com/initializ/forge/forge-core/tools" @@ -15,7 +12,7 @@ import ( type webSearchTool struct{} func (t *webSearchTool) Name() string { return "web_search" } -func (t *webSearchTool) Description() string { return "Search the web using Perplexity AI" } +func (t *webSearchTool) Description() string { return "Search the web using Tavily or Perplexity AI" } func (t *webSearchTool) Category() tools.Category { return tools.CategoryBuiltin } func (t *webSearchTool) InputSchema() json.RawMessage { @@ -23,23 +20,26 @@ func (t *webSearchTool) InputSchema() json.RawMessage { "type": "object", "properties": { "query": {"type": "string", "description": "Search query"}, - "max_results": {"type": "integer", "description": "Maximum number of results (default 5)"} + "max_results": {"type": "integer", "description": "Maximum number of results (default 5)"}, + "search_depth": {"type": "string", "description": "Search depth: basic or advanced (Tavily only)", "enum": ["basic", "advanced"]}, + "time_range": {"type": "string", "description": "Time range filter: day, week, month, year (Tavily only)"}, + "include_domains": {"type": "array", "items": {"type": "string"}, "description": "Only include results from these domains (Tavily only)"}, + "exclude_domains": {"type": "array", "items": {"type": "string"}, "description": "Exclude results from these domains (Tavily only)"} }, "required": ["query"] }`) } type webSearchInput struct { - Query string `json:"query"` - MaxResults int `json:"max_results,omitempty"` + Query string `json:"query"` + MaxResults int `json:"max_results,omitempty"` + SearchDepth string `json:"search_depth,omitempty"` + TimeRange string `json:"time_range,omitempty"` + IncludeDomains []string `json:"include_domains,omitempty"` + ExcludeDomains []string `json:"exclude_domains,omitempty"` } func (t *webSearchTool) Execute(ctx context.Context, args json.RawMessage) (string, error) { - apiKey := os.Getenv("PERPLEXITY_API_KEY") - if apiKey == "" { - return `{"error": "PERPLEXITY_API_KEY is not set. Add it to your .env file to enable web search."}`, nil - } - var input webSearchInput if err := json.Unmarshal(args, &input); err != nil { return "", fmt.Errorf("parsing web_search input: %w", err) @@ -48,65 +48,53 @@ func (t *webSearchTool) Execute(ctx context.Context, args json.RawMessage) (stri return `{"error": "query is required"}`, nil } - // Build Perplexity chat completion request - reqBody := map[string]any{ - "model": "sonar", - "messages": []map[string]string{ - {"role": "user", "content": input.Query}, - }, - } - bodyBytes, err := json.Marshal(reqBody) - if err != nil { - return "", fmt.Errorf("marshalling request: %w", err) - } - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.perplexity.ai/chat/completions", bytes.NewReader(bodyBytes)) - if err != nil { - return "", fmt.Errorf("creating request: %w", err) - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+apiKey) - - resp, err := http.DefaultClient.Do(httpReq) + provider, err := resolveWebSearchProvider() if err != nil { - return "", fmt.Errorf("calling Perplexity API: %w", err) + return fmt.Sprintf(`{"error": %q}`, err.Error()), nil } - defer func() { _ = resp.Body.Close() }() - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("reading response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return fmt.Sprintf(`{"error": "Perplexity API returned status %d: %s"}`, resp.StatusCode, string(respBody)), nil - } - - // Extract the answer from the response - var pResp struct { - Choices []struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - } `json:"choices"` - Citations []string `json:"citations,omitempty"` - } - if err := json.Unmarshal(respBody, &pResp); err != nil { - return "", fmt.Errorf("parsing Perplexity response: %w", err) + opts := webSearchOpts{ + MaxResults: input.MaxResults, + SearchDepth: input.SearchDepth, + TimeRange: input.TimeRange, + IncludeDomains: input.IncludeDomains, + ExcludeDomains: input.ExcludeDomains, } - if len(pResp.Choices) == 0 { - return `{"error": "no results from Perplexity"}`, nil - } + return provider.search(ctx, input.Query, opts) +} - result := map[string]any{ - "query": input.Query, - "answer": pResp.Choices[0].Message.Content, +// resolveWebSearchProvider selects the web search provider based on environment. +// Priority: WEB_SEARCH_PROVIDER env > auto-detect (Tavily first, then Perplexity). +func resolveWebSearchProvider() (webSearchProvider, error) { + override := os.Getenv("WEB_SEARCH_PROVIDER") + + switch override { + case "tavily": + key := os.Getenv("TAVILY_API_KEY") + if key == "" { + return nil, fmt.Errorf("WEB_SEARCH_PROVIDER is set to tavily but TAVILY_API_KEY is not set") + } + return newTavilyProvider(key), nil + + case "perplexity": + key := os.Getenv("PERPLEXITY_API_KEY") + if key == "" { + return nil, fmt.Errorf("WEB_SEARCH_PROVIDER is set to perplexity but PERPLEXITY_API_KEY is not set") + } + return newPerplexityProvider(key), nil + + case "": + // Auto-detect: try Tavily first, then Perplexity + if key := os.Getenv("TAVILY_API_KEY"); key != "" { + return newTavilyProvider(key), nil + } + if key := os.Getenv("PERPLEXITY_API_KEY"); key != "" { + return newPerplexityProvider(key), nil + } + return nil, fmt.Errorf("no web search API key set. Set TAVILY_API_KEY or PERPLEXITY_API_KEY in your .env file to enable web search") + + default: + return nil, fmt.Errorf("unknown WEB_SEARCH_PROVIDER %q: must be tavily or perplexity", override) } - if len(pResp.Citations) > 0 { - result["citations"] = pResp.Citations - } - - out, _ := json.Marshal(result) - return string(out), nil } diff --git a/forge-core/tools/builtins/web_search_perplexity.go b/forge-core/tools/builtins/web_search_perplexity.go new file mode 100644 index 0000000..1193eda --- /dev/null +++ b/forge-core/tools/builtins/web_search_perplexity.go @@ -0,0 +1,90 @@ +package builtins + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +// perplexityProvider implements webSearchProvider using the Perplexity API. +type perplexityProvider struct { + apiKey string + baseURL string // defaults to "https://api.perplexity.ai" +} + +func newPerplexityProvider(apiKey string) *perplexityProvider { + return &perplexityProvider{apiKey: apiKey, baseURL: "https://api.perplexity.ai"} +} + +func (p *perplexityProvider) name() string { return "perplexity" } + +func (p *perplexityProvider) egressDomains() []string { + return []string{"api.perplexity.ai"} +} + +func (p *perplexityProvider) search(ctx context.Context, query string, opts webSearchOpts) (string, error) { + // Perplexity uses the chat completions API with the sonar model. + // Tavily-specific opts (search_depth, time_range, domains) are ignored gracefully. + reqBody := map[string]any{ + "model": "sonar", + "messages": []map[string]string{ + {"role": "user", "content": query}, + }, + } + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("marshalling Perplexity request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(bodyBytes)) + if err != nil { + return "", fmt.Errorf("creating Perplexity request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("calling Perplexity API: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading Perplexity response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Sprintf(`{"error": "Perplexity API returned status %d: %s"}`, resp.StatusCode, string(respBody)), nil + } + + var pResp struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + Citations []string `json:"citations,omitempty"` + } + if err := json.Unmarshal(respBody, &pResp); err != nil { + return "", fmt.Errorf("parsing Perplexity response: %w", err) + } + + if len(pResp.Choices) == 0 { + return `{"error": "no results from Perplexity"}`, nil + } + + result := map[string]any{ + "query": query, + "answer": pResp.Choices[0].Message.Content, + } + if len(pResp.Citations) > 0 { + result["citations"] = pResp.Citations + } + + out, _ := json.Marshal(result) + return string(out), nil +} diff --git a/forge-core/tools/builtins/web_search_provider.go b/forge-core/tools/builtins/web_search_provider.go new file mode 100644 index 0000000..06ef4cd --- /dev/null +++ b/forge-core/tools/builtins/web_search_provider.go @@ -0,0 +1,19 @@ +package builtins + +import "context" + +// webSearchProvider abstracts a web search backend (Tavily, Perplexity, etc.). +type webSearchProvider interface { + name() string + search(ctx context.Context, query string, opts webSearchOpts) (string, error) + egressDomains() []string +} + +// webSearchOpts holds optional parameters for a web search request. +type webSearchOpts struct { + MaxResults int `json:"max_results"` + SearchDepth string `json:"search_depth"` + TimeRange string `json:"time_range"` + IncludeDomains []string `json:"include_domains"` + ExcludeDomains []string `json:"exclude_domains"` +} diff --git a/forge-core/tools/builtins/web_search_tavily.go b/forge-core/tools/builtins/web_search_tavily.go new file mode 100644 index 0000000..8674d99 --- /dev/null +++ b/forge-core/tools/builtins/web_search_tavily.go @@ -0,0 +1,113 @@ +package builtins + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +// tavilyProvider implements webSearchProvider using the Tavily API. +type tavilyProvider struct { + apiKey string + baseURL string // defaults to "https://api.tavily.com" +} + +func newTavilyProvider(apiKey string) *tavilyProvider { + return &tavilyProvider{apiKey: apiKey, baseURL: "https://api.tavily.com"} +} + +func (p *tavilyProvider) name() string { return "tavily" } + +func (p *tavilyProvider) egressDomains() []string { + return []string{"api.tavily.com"} +} + +func (p *tavilyProvider) search(ctx context.Context, query string, opts webSearchOpts) (string, error) { + reqBody := map[string]any{ + "query": query, + } + if opts.MaxResults > 0 { + reqBody["max_results"] = opts.MaxResults + } + if opts.SearchDepth != "" { + reqBody["search_depth"] = opts.SearchDepth + } + if opts.TimeRange != "" { + reqBody["time_range"] = opts.TimeRange + } + if len(opts.IncludeDomains) > 0 { + reqBody["include_domains"] = opts.IncludeDomains + } + if len(opts.ExcludeDomains) > 0 { + reqBody["exclude_domains"] = opts.ExcludeDomains + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("marshalling Tavily request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/search", bytes.NewReader(bodyBytes)) + if err != nil { + return "", fmt.Errorf("creating Tavily request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("calling Tavily API: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading Tavily response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Sprintf(`{"error": "Tavily API returned status %d: %s"}`, resp.StatusCode, string(respBody)), nil + } + + // Parse the Tavily response + var tResp struct { + Query string `json:"query"` + ResponseTime float64 `json:"response_time"` + Answer string `json:"answer,omitempty"` + Results []struct { + Title string `json:"title"` + URL string `json:"url"` + Content string `json:"content"` + Score float64 `json:"score"` + } `json:"results"` + } + if err := json.Unmarshal(respBody, &tResp); err != nil { + return "", fmt.Errorf("parsing Tavily response: %w", err) + } + + result := map[string]any{ + "query": tResp.Query, + "response_time": tResp.ResponseTime, + } + if tResp.Answer != "" { + result["answer"] = tResp.Answer + } + if len(tResp.Results) > 0 { + var results []map[string]any + for _, r := range tResp.Results { + results = append(results, map[string]any{ + "title": r.Title, + "url": r.URL, + "content": r.Content, + "score": r.Score, + }) + } + result["results"] = results + } + + out, _ := json.Marshal(result) + return string(out), nil +} diff --git a/forge-plugins/channels/telegram/telegram.go b/forge-plugins/channels/telegram/telegram.go index d272e6b..173a466 100644 --- a/forge-plugins/channels/telegram/telegram.go +++ b/forge-plugins/channels/telegram/telegram.go @@ -136,7 +136,9 @@ func (p *Plugin) makeWebhookHandler(handler channels.EventHandler) http.HandlerF go func() { ctx := context.Background() + stopTyping := p.startTypingIndicator(ctx, event.WorkspaceID) resp, err := handler(ctx, event) + stopTyping() if err != nil { fmt.Printf("telegram: handler error: %v\n", err) return @@ -189,7 +191,9 @@ func (p *Plugin) startPolling(ctx context.Context, handler channels.EventHandler } go func() { + stopTyping := p.startTypingIndicator(ctx, event.WorkspaceID) resp, err := handler(ctx, event) + stopTyping() if err != nil { fmt.Printf("telegram: handler error: %v\n", err) return @@ -280,6 +284,67 @@ func (p *Plugin) SendResponse(event *channels.ChannelEvent, response *a2a.Messag return nil } +// sendChatAction sends a chat action (e.g. "typing") to indicate activity. +func (p *Plugin) sendChatAction(chatID, action string) error { + payload := map[string]string{ + "chat_id": chatID, + "action": action, + } + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("marshalling chat action: %w", err) + } + + url := fmt.Sprintf("%s/bot%s/sendChatAction", p.apiBase, p.botToken) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("creating chat action request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := p.client.Do(req) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + _, _ = io.ReadAll(resp.Body) + return nil +} + +// startTypingIndicator sends "typing" chat action repeatedly until the +// returned stop function is called. Telegram's typing indicator expires +// after ~5 seconds, so we resend every 4 seconds. +func (p *Plugin) startTypingIndicator(ctx context.Context, chatID string) (stop func()) { + done := make(chan struct{}) + stop = func() { + select { + case <-done: + default: + close(done) + } + } + + // Send the first typing indicator immediately. + _ = p.sendChatAction(chatID, "typing") + + go func() { + ticker := time.NewTicker(4 * time.Second) + defer ticker.Stop() + for { + select { + case <-done: + return + case <-ctx.Done(): + return + case <-ticker.C: + _ = p.sendChatAction(chatID, "typing") + } + } + }() + + return stop +} + // sendMessage posts a JSON payload to the Telegram sendMessage API. func (p *Plugin) sendMessage(payload map[string]any) error { body, err := json.Marshal(payload) diff --git a/forge-plugins/channels/telegram/telegram_test.go b/forge-plugins/channels/telegram/telegram_test.go index 92b5d34..b295fb5 100644 --- a/forge-plugins/channels/telegram/telegram_test.go +++ b/forge-plugins/channels/telegram/telegram_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/initializ/forge/forge-core/a2a" "github.com/initializ/forge/forge-core/channels" @@ -272,6 +273,66 @@ func TestInit_MissingToken(t *testing.T) { } } +func TestSendChatAction(t *testing.T) { + var receivedAction string + var receivedChatID string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var payload map[string]string + json.Unmarshal(body, &payload) //nolint:errcheck + receivedChatID = payload["chat_id"] + receivedAction = payload["action"] + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) //nolint:errcheck + })) + defer srv.Close() + + p := New() + p.botToken = "test-token" + p.apiBase = srv.URL + + err := p.sendChatAction("12345", "typing") + if err != nil { + t.Fatalf("sendChatAction() error: %v", err) + } + if receivedChatID != "12345" { + t.Errorf("chat_id = %q, want 12345", receivedChatID) + } + if receivedAction != "typing" { + t.Errorf("action = %q, want typing", receivedAction) + } +} + +func TestStartTypingIndicator(t *testing.T) { + var actionCount int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Count sendChatAction calls (path contains sendChatAction) + if strings.Contains(r.URL.Path, "sendChatAction") { + actionCount++ + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) //nolint:errcheck + })) + defer srv.Close() + + p := New() + p.botToken = "test-token" + p.apiBase = srv.URL + + ctx := context.Background() + stop := p.startTypingIndicator(ctx, "67890") + + // The first typing action is sent immediately + // Give it a moment to process + time.Sleep(100 * time.Millisecond) + + if actionCount < 1 { + t.Errorf("expected at least 1 typing action, got %d", actionCount) + } + + stop() +} + func TestInit_InvalidMode(t *testing.T) { p := New()