Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion forge-cli/channels/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ func (r *Router) Handler() channels.EventHandler {
// forwardToA2A sends a tasks/send JSON-RPC request to the A2A server and
// extracts the agent's response message from the returned task.
func (r *Router) forwardToA2A(ctx context.Context, event *channels.ChannelEvent) (*a2a.Message, error) {
taskID := fmt.Sprintf("%s-%s-%d", event.Channel, event.WorkspaceID, time.Now().UnixMilli())
// Build a stable task ID so all messages in the same conversation share
// one session. Use thread ID when available (threaded replies), otherwise
// fall back to channel + workspace + user for DM-style conversations.
var taskID string
if event.ThreadID != "" {
taskID = fmt.Sprintf("%s-%s-%s", event.Channel, event.WorkspaceID, event.ThreadID)
} else {
taskID = fmt.Sprintf("%s-%s-%s", event.Channel, event.WorkspaceID, event.UserID)
}

params := a2a.SendTaskParams{
ID: taskID,
Expand Down
236 changes: 225 additions & 11 deletions forge-cli/runtime/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"strings"
"time"

"github.com/initializ/forge/forge-cli/server"
cliskills "github.com/initializ/forge/forge-cli/skills"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/initializ/forge/forge-core/llm"
"github.com/initializ/forge/forge-core/llm/oauth"
"github.com/initializ/forge/forge-core/llm/providers"
"github.com/initializ/forge/forge-core/memory"
coreruntime "github.com/initializ/forge/forge-core/runtime"
"github.com/initializ/forge/forge-core/tools"
"github.com/initializ/forge/forge-core/tools/builtins"
Expand Down Expand Up @@ -165,12 +167,71 @@ func (r *Runner) Run(ctx context.Context) error {
hooks := coreruntime.NewHookRegistry()
r.registerLoggingHooks(hooks)

executor = coreruntime.NewLLMExecutor(coreruntime.LLMExecutorConfig{
// Compute model-aware character budget.
charBudget := r.cfg.Config.Memory.CharBudget
if charBudget == 0 {
charBudget = coreruntime.ContextBudgetForModel(mc.Client.Model)
}

execCfg := coreruntime.LLMExecutorConfig{
Client: llmClient,
Tools: reg,
Hooks: hooks,
SystemPrompt: fmt.Sprintf("You are %s, an AI agent.", r.cfg.Config.AgentID),
})
Logger: r.logger,
ModelName: mc.Client.Model,
CharBudget: charBudget,
}

// Initialize memory persistence (enabled by default).
// Disable via FORGE_MEMORY_PERSISTENCE=false or memory.persistence: false in forge.yaml.
memPersistence := true
if r.cfg.Config.Memory.Persistence != nil {
memPersistence = *r.cfg.Config.Memory.Persistence
}
if os.Getenv("FORGE_MEMORY_PERSISTENCE") == "false" {
memPersistence = false
}
if memPersistence {
sessDir := r.cfg.Config.Memory.SessionsDir
if sessDir == "" {
sessDir = filepath.Join(r.cfg.WorkDir, ".forge", "sessions")
}
memStore, storeErr := coreruntime.NewMemoryStore(sessDir)
if storeErr != nil {
r.logger.Warn("failed to create memory store, persistence disabled", map[string]any{
"error": storeErr.Error(),
})
} else {
// Clean up old sessions on startup (7-day TTL).
deleted, _ := memStore.Cleanup(7 * 24 * time.Hour)
if deleted > 0 {
r.logger.Info("cleaned up old sessions", map[string]any{"deleted": deleted})
}

compactor := coreruntime.NewCompactor(coreruntime.CompactorConfig{
Client: llmClient,
Store: memStore,
Logger: r.logger,
CharBudget: charBudget,
TriggerRatio: r.cfg.Config.Memory.TriggerRatio,
})

execCfg.Store = memStore
execCfg.Compactor = compactor
r.logger.Info("memory persistence enabled", map[string]any{
"sessions_dir": sessDir,
})
}
}

// Initialize long-term memory if enabled.
memMgr := r.initLongTermMemory(ctx, mc, reg, execCfg.Compactor)
if memMgr != nil {
defer memMgr.Close() //nolint:errcheck
}

executor = coreruntime.NewLLMExecutor(execCfg)

r.logger.Info("using LLM executor", map[string]any{
"provider": mc.Provider,
Expand Down Expand Up @@ -248,11 +309,12 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent

r.logger.Info("tasks/send", map[string]any{"task_id": params.ID})

// Create task in submitted state
task := &a2a.Task{
ID: params.ID,
Status: a2a.TaskStatus{State: a2a.TaskStateSubmitted},
// Load existing task to preserve conversation history, or create new.
task := store.Get(params.ID)
if task == nil {
task = &a2a.Task{ID: params.ID}
}
task.Status = a2a.TaskStatus{State: a2a.TaskStateSubmitted}
store.Put(task)

// Guardrail check inbound
Expand All @@ -268,9 +330,12 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
return a2a.NewResponse(id, task)
}

// Append inbound user message to task history.
task.History = append(task.History, params.Message)

// Update to working
store.UpdateStatus(params.ID, a2a.TaskStatus{State: a2a.TaskStateWorking})
task.Status = a2a.TaskStatus{State: a2a.TaskStateWorking}
store.Put(task)

// Execute via executor
respMsg, err := executor.Execute(ctx, task, &params.Message)
Expand Down Expand Up @@ -302,6 +367,11 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
}
}

// Append agent response to task history.
if respMsg != nil {
task.History = append(task.History, *respMsg)
}

// Build completed task
task.Status = a2a.TaskStatus{
State: a2a.TaskStateCompleted,
Expand Down Expand Up @@ -330,11 +400,12 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent

r.logger.Info("tasks/sendSubscribe", map[string]any{"task_id": params.ID})

// Create task
task := &a2a.Task{
ID: params.ID,
Status: a2a.TaskStatus{State: a2a.TaskStateSubmitted},
// Load existing task to preserve conversation history, or create new.
task := store.Get(params.ID)
if task == nil {
task = &a2a.Task{ID: params.ID}
}
task.Status = a2a.TaskStatus{State: a2a.TaskStateSubmitted}
store.Put(task)
server.WriteSSEEvent(w, flusher, "status", task) //nolint:errcheck

Expand All @@ -352,6 +423,9 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
return
}

// Append inbound user message to task history.
task.History = append(task.History, params.Message)

// Update to working
task.Status = a2a.TaskStatus{State: a2a.TaskStateWorking}
store.Put(task)
Expand Down Expand Up @@ -387,6 +461,9 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
return
}

// Append agent response to task history.
task.History = append(task.History, *respMsg)

// Build completed result
task.Status = a2a.TaskStatus{
State: a2a.TaskStateCompleted,
Expand Down Expand Up @@ -721,6 +798,143 @@ func envFromOS() map[string]string {
return env
}

// initLongTermMemory sets up the long-term memory system if enabled.
// It resolves the embedder, creates a memory.Manager, registers memory tools,
// and starts background indexing. Returns the Manager (caller must Close) or nil.
//
// The function is deliberately best-effort: every failure path logs a warning
// and returns nil rather than aborting agent startup.
func (r *Runner) initLongTermMemory(ctx context.Context, mc *coreruntime.ModelConfig, reg *tools.Registry, compactor *coreruntime.Compactor) *memory.Manager {
	// Check if long-term memory is enabled. Disabled by default; forge.yaml
	// (Memory.LongTerm) can toggle it either way, and the env var can only
	// force it ON — note the asymmetry: FORGE_MEMORY_LONG_TERM=false is a
	// no-op, unlike FORGE_MEMORY_PERSISTENCE which can disable persistence.
	enabled := false
	if r.cfg.Config.Memory.LongTerm != nil {
		enabled = *r.cfg.Config.Memory.LongTerm
	}
	if os.Getenv("FORGE_MEMORY_LONG_TERM") == "true" {
		enabled = true
	}
	if !enabled {
		return nil
	}

	// Memory files live under the working directory unless overridden.
	memDir := r.cfg.Config.Memory.MemoryDir
	if memDir == "" {
		memDir = filepath.Join(r.cfg.WorkDir, ".forge", "memory")
	}

	// Resolve embedder. A nil embedder is a valid outcome: the manager then
	// runs in keyword-only mode (reflected in the "mode" log field below).
	embedder := r.resolveEmbedder(mc)

	// Build search config from forge.yaml. Only positive values override the
	// defaults from memory.DefaultSearchConfig(); zero/unset keeps defaults.
	searchCfg := memory.DefaultSearchConfig()
	if r.cfg.Config.Memory.VectorWeight > 0 {
		searchCfg.VectorWeight = r.cfg.Config.Memory.VectorWeight
	}
	if r.cfg.Config.Memory.KeywordWeight > 0 {
		searchCfg.KeywordWeight = r.cfg.Config.Memory.KeywordWeight
	}
	if r.cfg.Config.Memory.DecayHalfLifeDays > 0 {
		searchCfg.DecayHalfLife = time.Duration(r.cfg.Config.Memory.DecayHalfLifeDays) * 24 * time.Hour
	}

	mgr, err := memory.NewManager(memory.ManagerConfig{
		MemoryDir:    memDir,
		Embedder:     embedder,
		Logger:       r.logger,
		SearchConfig: searchCfg,
	})
	if err != nil {
		// Manager creation failure degrades gracefully to no long-term memory.
		r.logger.Warn("failed to create memory manager, long-term memory disabled", map[string]any{
			"error": err.Error(),
		})
		return nil
	}

	// Register memory tools. A registration failure is logged but does not
	// disable the manager — search/indexing still run without the tool.
	if regErr := reg.Register(builtins.NewMemorySearchTool(mgr)); regErr != nil {
		r.logger.Warn("failed to register memory_search tool", map[string]any{"error": regErr.Error()})
	}
	if regErr := reg.Register(builtins.NewMemoryGetTool(mgr)); regErr != nil {
		r.logger.Warn("failed to register memory_get tool", map[string]any{"error": regErr.Error()})
	}

	// Wire memory flusher into compactor (if compactor exists), so compaction
	// can flush conversation context into long-term memory.
	if compactor != nil {
		compactor.SetMemoryFlusher(mgr)
	}

	// Index memory files at startup in background so indexing does not delay
	// startup. NOTE(review): goroutine shutdown relies on IndexAll honoring
	// ctx cancellation — confirm against memory.Manager.IndexAll.
	go func() {
		if idxErr := mgr.IndexAll(ctx); idxErr != nil {
			r.logger.Warn("background memory indexing failed", map[string]any{"error": idxErr.Error()})
		}
	}()

	mode := "keyword-only"
	if embedder != nil {
		mode = "vector+keyword"
	}
	r.logger.Info("long-term memory enabled", map[string]any{
		"memory_dir": memDir,
		"mode":       mode,
	})

	return mgr
}

// resolveEmbedder creates an embedder from config or auto-detection.
// Returns nil if no embedder can be created (keyword-only mode).
//
// Provider resolution order: forge.yaml override, then the
// FORGE_EMBEDDING_PROVIDER env var, then the primary LLM provider.
func (r *Runner) resolveEmbedder(mc *coreruntime.ModelConfig) llm.Embedder {
	provider := r.cfg.Config.Memory.EmbeddingProvider
	if provider == "" {
		if env := os.Getenv("FORGE_EMBEDDING_PROVIDER"); env != "" {
			provider = env
		} else {
			provider = mc.Provider
		}
	}

	// Anthropic has no embedding API — scan the fallback providers for the
	// first embedding-capable one, or give up and use keyword-only search.
	if provider == "anthropic" {
		r.logger.Info("primary provider is anthropic (no embedding API), trying fallbacks for embeddings", nil)
		alt := ""
		for _, fb := range mc.Fallbacks {
			if fb.Provider != "anthropic" {
				alt = fb.Provider
				break
			}
		}
		if alt == "" {
			r.logger.Info("no embedding-capable provider found, using keyword-only search", nil)
			return nil
		}
		provider = alt
	}

	embCfg := providers.OpenAIEmbedderConfig{
		APIKey: mc.Client.APIKey,
		Model:  r.cfg.Config.Memory.EmbeddingModel,
	}

	// When embeddings run on a non-primary provider, borrow credentials and
	// base URL from the matching fallback entry (if one exists).
	if provider != mc.Provider {
		for _, fb := range mc.Fallbacks {
			if fb.Provider != provider {
				continue
			}
			embCfg.APIKey = fb.Client.APIKey
			embCfg.BaseURL = fb.Client.BaseURL
			break
		}
	}

	emb, err := providers.NewEmbedder(provider, embCfg)
	if err != nil {
		r.logger.Warn("failed to create embedder, using keyword-only search", map[string]any{
			"provider": provider,
			"error":    err.Error(),
		})
		return nil
	}

	return emb
}

func defaultStr(s, def string) string {
if s != "" {
return s
Expand Down
1 change: 1 addition & 0 deletions forge-core/channels/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type ChannelEvent struct {
WorkspaceID string `json:"workspace_id"`
UserID string `json:"user_id"`
ThreadID string `json:"thread_id,omitempty"`
MessageID string `json:"message_id,omitempty"` // per-message ID for reply targeting
Message string `json:"message"`
Attachments []Attachment `json:"attachments,omitempty"`
Raw json.RawMessage `json:"raw,omitempty"`
Expand Down
24 changes: 24 additions & 0 deletions forge-core/llm/embedder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package llm

import "context"

// EmbeddingRequest is a provider-agnostic request to generate embeddings.
type EmbeddingRequest struct {
	Texts []string // texts to embed; one embedding is requested per entry
	Model string   // optional model override; empty defers to the provider's configured model
}

// EmbeddingResponse is a provider-agnostic embedding response.
type EmbeddingResponse struct {
	Embeddings [][]float32 // one vector per requested text — presumably in request order; confirm with implementations
	Model      string      // model that actually produced the embeddings
	Usage      UsageInfo   // token/usage accounting reported by the provider
}

// Embedder generates vector embeddings from text.
type Embedder interface {
	// Embed produces embeddings for the given texts.
	// Implementations should honor ctx for cancellation of in-flight calls.
	Embed(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error)
	// Dimensions returns the dimensionality of the embedding vectors.
	Dimensions() int
}
28 changes: 28 additions & 0 deletions forge-core/llm/providers/embedder_factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package providers

import (
"fmt"

"github.com/initializ/forge/forge-core/llm"
)

// NewEmbedder creates an Embedder for the specified provider.
// Supported providers: "openai", "gemini", "ollama".
// Returns an error for "anthropic" (no embedding API).
func NewEmbedder(provider string, cfg OpenAIEmbedderConfig) (llm.Embedder, error) {
	switch provider {
	case "anthropic":
		// Anthropic ships no embeddings endpoint; callers must pick another provider.
		return nil, fmt.Errorf("anthropic does not provide an embedding API; configure an alternative embedding provider")
	case "ollama":
		return NewOllamaEmbedder(cfg), nil
	case "gemini":
		// Gemini speaks the OpenAI wire format via its compatibility endpoint.
		if cfg.BaseURL == "" {
			cfg.BaseURL = "https://generativelanguage.googleapis.com/v1beta/openai"
		}
		return NewOpenAIEmbedder(cfg), nil
	case "openai":
		return NewOpenAIEmbedder(cfg), nil
	}
	return nil, fmt.Errorf("unknown embedding provider: %q", provider)
}
Loading
Loading