From e24ae05ee60660cc167c828e90109a2710d110ab Mon Sep 17 00:00:00 2001 From: VanderChen Date: Sat, 29 Nov 2025 17:13:17 +0800 Subject: [PATCH 1/6] Implement router for invoke reuqest Signed-off-by: VanderChen --- Makefile | 19 +- cmd/router/main.go | 82 +++++ docs/proposal/router-design.md | 295 +++++++++++++++ go.mod | 2 +- ...e-apiserver.yaml => agentcube-router.yaml} | 35 +- pkg/agentd/agentd_test.go | 6 +- pkg/router/apiserver.go | 153 ++++++++ pkg/router/apiserver_test.go | 343 ++++++++++++++++++ pkg/router/config.go | 53 +++ pkg/router/handlers.go | 205 +++++++++++ pkg/router/redis_manager.go | 135 +++++++ pkg/router/redis_manager_test.go | 250 +++++++++++++ pkg/router/session_manager.go | 116 ++++++ pkg/router/session_manager_test.go | 305 ++++++++++++++++ pkg/router/utils.go | 60 +++ 15 files changed, 2027 insertions(+), 32 deletions(-) create mode 100644 cmd/router/main.go create mode 100644 docs/proposal/router-design.md rename k8s/{agentcube-apiserver.yaml => agentcube-router.yaml} (83%) create mode 100644 pkg/router/apiserver.go create mode 100644 pkg/router/apiserver_test.go create mode 100644 pkg/router/config.go create mode 100644 pkg/router/handlers.go create mode 100644 pkg/router/redis_manager.go create mode 100644 pkg/router/redis_manager_test.go create mode 100644 pkg/router/session_manager.go create mode 100644 pkg/router/session_manager_test.go create mode 100644 pkg/router/utils.go diff --git a/Makefile b/Makefile index 33bb16d4..b8649c91 100644 --- a/Makefile +++ b/Makefile @@ -87,23 +87,20 @@ run: @echo "Running agentcube-apiserver..." go run ./cmd/workload-manager/main.go \ --port=8080 \ - --ssh-username=sandbox \ - --ssh-port=22 + --debug # Run server (with kubeconfig) run-local: @echo "Running agentcube-apiserver with local kubeconfig..." go run ./cmd/workload-manager/main.go \ --port=8080 \ - --kubeconfig=${HOME}/.kube/config \ - --ssh-username=sandbox \ - --ssh-port=22 + --debug # Clean build artifacts clean: @echo "Cleaning..." rm -rf bin/ - rm -f agentcube-apiserver agentd + rm -f agentcube-router agentd # Install dependencies deps: @@ -138,8 +135,8 @@ lint: golangci-lint ## Run golangci-lint # Install to system install: build - @echo "Installing agentcube-apiserver..." - sudo cp bin/agentcube-apiserver /usr/local/bin/ + @echo "Installing agentcube-router..." + sudo cp bin/agentcube-router /usr/local/bin/ # Docker image variables APISERVER_IMAGE ?= agentcube-apiserver:latest @@ -177,15 +174,15 @@ docker-push: docker-build k8s-deploy: @echo "Deploying to Kubernetes..." - kubectl apply -f k8s/agentcube-apiserver.yaml + kubectl apply -f k8s/agentcube-router.yaml k8s-delete: @echo "Deleting from Kubernetes..." - kubectl delete -f k8s/agentcube-apiserver.yaml + kubectl delete -f k8s/agentcube-router.yaml k8s-logs: @echo "Showing logs..." - kubectl logs -n agentcube -l app=agentcube-apiserver -f + kubectl logs -n agentcube -l app=agentcube-router -f # Load image to kind cluster kind-load: diff --git a/cmd/router/main.go b/cmd/router/main.go new file mode 100644 index 00000000..7e0768de --- /dev/null +++ b/cmd/router/main.go @@ -0,0 +1,82 @@ +package main + +import ( + "context" + "flag" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/volcano-sh/agentcube/pkg/router" +) + +func main() { + var ( + port = flag.String("port", "8080", "Router API server port") + enableTLS = flag.Bool("enable-tls", false, "Enable TLS (HTTPS)") + tlsCert = flag.String("tls-cert", "", "Path to TLS certificate file") + tlsKey = flag.String("tls-key", "", "Path to TLS key file") + debug = flag.Bool("debug", true, "Enable debug mode") + maxConcurrentRequests = flag.Int("max-concurrent-requests", 1000, "Maximum number of concurrent requests") + requestTimeout = flag.Int("request-timeout", 30, "Request timeout in seconds") + maxIdleConns = flag.Int("max-idle-conns", 100, "Maximum number of idle connections") + maxConnsPerHost = flag.Int("max-conns-per-host", 10, "Maximum number of connections per host") + ) + + // Parse command line flags + flag.Parse() + + // Create Router API server configuration + config := &router.Config{ + Port: *port, + SandboxEndpoints: []string{ + "http://sandbox-1:8080", + "http://sandbox-2:8080", + "http://sandbox-3:8080", + }, // Default sandbox endpoints, can be configured via env vars + Debug: *debug, + EnableTLS: *enableTLS, + TLSCert: *tlsCert, + TLSKey: *tlsKey, + MaxConcurrentRequests: *maxConcurrentRequests, + RequestTimeout: *requestTimeout, + MaxIdleConns: *maxIdleConns, + MaxConnsPerHost: *maxConnsPerHost, + } + + // Create Router API server + server, err := router.NewServer(config) + if err != nil { + log.Fatalf("Failed to create Router API server: %v", err) + } + + // Setup signal handling + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + // Start Router API server in goroutine + errCh := make(chan error, 1) + go func() { + log.Printf("Starting agentcube Router server on port %s", *port) + if err := server.Start(ctx); err != nil { + errCh <- err + } + }() + + // Wait for signal or error + select { + case <-sigCh: + log.Println("Received shutdown signal, shutting down gracefully...") + cancel() + time.Sleep(2 * time.Second) // Give server time to shutdown gracefully + case err := <-errCh: + log.Fatalf("Server error: %v", err) + } + + log.Println("Router server stopped") +} diff --git a/docs/proposal/router-design.md b/docs/proposal/router-design.md new file mode 100644 index 00000000..9df17add --- /dev/null +++ b/docs/proposal/router-design.md @@ -0,0 +1,295 @@ +# Router Submodule Design Document + +## 1. Overview + +Router apiserver is responsible for receiving user HTTP requests and forwarding them to the corresponding Sandbox. Router focuses on high-performance request routing, while session and sandbox management is handled by SessionManager. + +## 2. Architecture Design + +### 2.1 Overall Architecture Flow + +```mermaid +graph TB + Client[Client] --> Router[Router API Server] + Router --> SessionMgr[SessionManager Interface] + Router --> Sandbox1[Sandbox 1] + Router --> Sandbox2[Sandbox 2] + Router --> SandboxN[Sandbox N] + + subgraph "Router Core Components" + Router + SessionMgr + end + + subgraph "Sandbox Cluster" + Sandbox1 + Sandbox2 + SandboxN + end +``` + +### 2.2 Request Routing Flow + +```mermaid +sequenceDiagram + participant C as Client + participant R as Router + participant SM as SessionManager (Mock) + participant SB as Sandbox + + C->>R: HTTP Request with x-agentcube-session-id + + alt session-id exists + R->>SM: getSandboxInfoBySessionId(session-id) + SM->>R: return (endpoint, session-id, nil) + else session-id is empty + R->>SM: getSandboxInfoBySessionId("") + SM->>R: return (endpoint, new-session-id, nil) + end + + alt get sandbox success + R->>SB: forward request to sandbox endpoint + SB->>R: return response + R->>C: forward response + session-id header + else get sandbox failed + R->>C: return error response + end +``` + +## 3. Detailed Design + +### 3.1 SessionManager Interface Definition + +Router obtains Sandbox information through the SessionManager interface, which can be implemented with Mock: + +```go +type SessionManager interface { + // Get sandbox information based on session-id + // When sessionId is empty, create a new session + GetSandboxInfoBySessionId(sessionId string) (endpoint string, newSessionId string, err error) +} + +// Mock implementation example +type MockSessionManager struct { + sandboxEndpoints []string + currentIndex int +} + +func (m *MockSessionManager) GetSandboxInfoBySessionId(sessionId string) (string, string, error) { + if sessionId == "" { + sessionId = generateNewSessionId() + } + + // Simple round-robin sandbox selection + endpoint := m.sandboxEndpoints[m.currentIndex%len(m.sandboxEndpoints)] + m.currentIndex++ + + return endpoint, sessionId, nil +} +``` + +### 3.2 Supported Request Types + +Uses Gin framework to provide HTTP Server services, handling two types of requests: + +1. **Agent Invoke Requests** + ``` + :/v1/namespaces/{agentNamespace}/agent-runtimes/{agentName}/invocations/ + ``` + +2. **Code Interpreter Invoke Requests** + ``` + :/v1/namespaces/{namespace}/code-interpreters/{name}/invocations/ + ``` + +### 3.3 Request Processing Flow + +```mermaid +flowchart TD + Start([Receive HTTP Request]) --> ValidateReq{Validate Request Format} + ValidateReq -->|Invalid| ReturnBadRequest[Return 400 Bad Request] + ValidateReq -->|Valid| ExtractSessionId[Extract x-agentcube-session-id] + + ExtractSessionId --> GetSandbox[Call SessionMgr.GetSandboxInfoBySessionId] + GetSandbox --> CheckResult{Check Result} + + CheckResult -->|Success| ForwardRequest[Forward Request to Sandbox] + CheckResult -->|Interface Error| ReturnInternalError[Return 500 Internal Server Error] + + ForwardRequest --> CheckSandboxResponse{Check Sandbox Response} + CheckSandboxResponse -->|Success| ReturnSuccess[Return Success Response + Session ID] + CheckSandboxResponse -->|Timeout| ReturnTimeout[Return 504 Gateway Timeout] + CheckSandboxResponse -->|Connection Failed| ReturnBadGateway[Return 502 Bad Gateway] + CheckSandboxResponse -->|Other Error| ReturnSandboxError[Return Sandbox Error] + + ReturnBadRequest --> End([End]) + ReturnInternalError --> End + ReturnSuccess --> End + ReturnTimeout --> End + ReturnBadGateway --> End + ReturnSandboxError --> End +``` + +### 3.4 Core Requirements + +1. **High-Performance Routing**: Fast routing to corresponding sandbox based on session-id +2. **Session Integration**: Seamless collaboration with SessionManager, supporting dynamic sandbox creation +3. **Long Connection Support**: Support for long-running requests such as code execution and file operations +4. **Simple Design**: Focus on core routing functionality, avoid over-engineering +5. **Graceful Shutdown**: Reference E2B's graceful shutdown process to ensure no request loss + +### 3.5 Design Goals + +- **High Performance**: Millisecond-level routing latency, support for high concurrency +- **High Availability**: Stateless design, support for horizontal scaling +- **Observability**: Complete monitoring, logging, and tracing system + +## 4. HTTP Response Handling + +### 4.1 Success Responses + +| Status Code | Scenario | Response Headers | Response Body | +|-------------|----------|------------------|---------------| +| 200 OK | Request processed successfully | `x-agentcube-session-id: ` | Original response from Sandbox | +| 201 Created | Resource created successfully | `x-agentcube-session-id: ` | Created resource information | +| 202 Accepted | Async request accepted | `x-agentcube-session-id: ` | Task status information | + +### 4.2 Client Error Responses + +| Status Code | Scenario | Response Body Example | +|-------------|----------|----------------------| +| 400 Bad Request | Invalid request format, invalid parameters | `{"error": "invalid request format", "code": "INVALID_REQUEST"}` | +| 401 Unauthorized | Authentication failed | `{"error": "authentication required", "code": "AUTH_REQUIRED"}` | +| 403 Forbidden | Insufficient permissions | `{"error": "insufficient permissions", "code": "PERMISSION_DENIED"}` | +| 404 Not Found | Session or resource not found | `{"error": "session not found", "code": "SESSION_NOT_FOUND"}` | +| 409 Conflict | Resource conflict | `{"error": "resource conflict", "code": "RESOURCE_CONFLICT"}` | +| 429 Too Many Requests | Rate limit exceeded | `{"error": "rate limit exceeded", "code": "RATE_LIMIT_EXCEEDED"}` | + +### 4.3 Server Error Responses + +| Status Code | Scenario | Response Body Example | +|-------------|----------|----------------------| +| 500 Internal Server Error | Router internal error | `{"error": "internal server error", "code": "INTERNAL_ERROR"}` | +| 502 Bad Gateway | Sandbox connection failed | `{"error": "sandbox unreachable", "code": "SANDBOX_UNREACHABLE"}` | +| 503 Service Unavailable | Sandbox unavailable or overloaded | `{"error": "sandbox unavailable", "code": "SANDBOX_UNAVAILABLE"}` | +| 504 Gateway Timeout | Sandbox response timeout | `{"error": "sandbox timeout", "code": "SANDBOX_TIMEOUT"}` | + +### 4.4 Error Handling Flow + +```mermaid +flowchart TD + Error[Error Occurred] --> CheckErrorType{Check Error Type} + + CheckErrorType -->|Request Validation Error| ClientError[Client Error 4xx] + CheckErrorType -->|SessionMgr Error| CheckSessionError{Session Error Type} + CheckErrorType -->|Sandbox Error| CheckSandboxError{Sandbox Error Type} + CheckErrorType -->|Router Internal Error| ServerError[Server Error 5xx] + + CheckSessionError -->|Session Not Found| Return404[404 Not Found] + CheckSessionError -->|Session Creation Failed| Return500[500 Internal Error] + CheckSessionError -->|Permission Error| Return403[403 Forbidden] + + CheckSandboxError -->|Connection Failed| Return502[502 Bad Gateway] + CheckSandboxError -->|Response Timeout| Return504[504 Gateway Timeout] + CheckSandboxError -->|Service Unavailable| Return503[503 Service Unavailable] + CheckSandboxError -->|Sandbox Internal Error| ForwardError[Forward Sandbox Error] + + ClientError --> LogError[Log Error] + ServerError --> LogError + Return404 --> LogError + Return500 --> LogError + Return403 --> LogError + Return502 --> LogError + Return504 --> LogError + Return503 --> LogError + ForwardError --> LogError + + LogError --> UpdateMetrics[Update Metrics] + UpdateMetrics --> ReturnResponse[Return Error Response] +``` + +## 5. Performance and Monitoring + +### 5.1 Performance Metrics + +- **Routing Latency**: Target < 5ms (P99) +- **Throughput**: Target > 10,000 RPS +- **Concurrent Connections**: Support > 50,000 concurrent connections +- **Memory Usage**: < 1GB (steady state) + +### 5.2 Monitoring Metrics + +```mermaid +graph LR + subgraph "Request Metrics" + A[Total Requests] + B[Request Latency] + C[Error Rate] + D[Concurrency] + end + + subgraph "Routing Metrics" + E[Routing Success Rate] + F[Session Creation Rate] + G[Sandbox Hit Rate] + end + + subgraph "System Metrics" + H[CPU Usage] + I[Memory Usage] + J[Network I/O] + K[Connection Pool Status] + end +``` + +### 5.3 Logging + +- **Access Logs**: Record all HTTP requests +- **Error Logs**: Record all errors and exceptions +- **Performance Logs**: Record key performance metrics +- **Audit Logs**: Record important operations and state changes + +## 6. Deployment and Operations + +### 6.1 Deployment Architecture + +```mermaid +graph TB + LB[Load Balancer] --> R1[Router Instance 1] + LB --> R2[Router Instance 2] + LB --> R3[Router Instance N] + + R1 --> SM[SessionManager Interface] + R2 --> SM + R3 --> SM + + R1 --> SB1[Sandbox Cluster] + R2 --> SB1 + R3 --> SB1 + + subgraph "Router Layer" + R1 + R2 + R3 + end + + subgraph "Interface Layer" + SM + end + + subgraph "Sandbox Layer" + SB1 + end +``` + +### 6.2 Configuration Management + +- **Environment Configuration**: Support for multi-environment configuration (dev/staging/prod) +- **Dynamic Configuration**: Support for runtime configuration updates +- **Configuration Validation**: Validate configuration integrity at startup + +### 6.3 Health Checks + +- **Liveness Check**: `/health/live` +- **Readiness Check**: `/health/ready` +- **Dependency Check**: Verify connectivity to SessionManager and Sandbox diff --git a/go.mod b/go.mod index 72b43efb..8acc13c3 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( k8s.io/api v0.34.1 k8s.io/apimachinery v0.34.1 k8s.io/client-go v0.34.1 + k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 sigs.k8s.io/agent-sandbox v0.1.0 sigs.k8s.io/controller-runtime v0.22.2 ) @@ -97,7 +98,6 @@ require ( k8s.io/apiextensions-apiserver v0.34.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect - k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect diff --git a/k8s/agentcube-apiserver.yaml b/k8s/agentcube-router.yaml similarity index 83% rename from k8s/agentcube-apiserver.yaml rename to k8s/agentcube-router.yaml index 0ec7ac16..3477b7d7 100644 --- a/k8s/agentcube-apiserver.yaml +++ b/k8s/agentcube-router.yaml @@ -1,5 +1,5 @@ -# All-in-one deployment file for agentcube-apiserver -# Apply with: kubectl apply -f k8s/agentcube-apiserver.yaml +# All-in-one deployment file for agentcube-router +# Apply with: kubectl apply -f k8s/agentcube-router.yaml --- apiVersion: v1 @@ -11,14 +11,14 @@ metadata: apiVersion: v1 kind: ServiceAccount metadata: - name: agentcube-apiserver + name: agentcube-router namespace: agentcube --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: - name: agentcube-apiserver + name: agentcube-router rules: - apiGroups: ["agents.x-k8s.io"] resources: ["sandboxes"] @@ -58,38 +58,38 @@ rules: apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding metadata: - name: agentcube-apiserver + name: agentcube-router roleRef: apiGroup: rbac.authorization.k8s.io kind: ClusterRole - name: agentcube-apiserver + name: agentcube-router subjects: - kind: ServiceAccount - name: agentcube-apiserver + name: agentcube-router namespace: agentcube --- apiVersion: apps/v1 kind: Deployment metadata: - name: agentcube-apiserver + name: agentcube-router namespace: agentcube labels: - app: agentcube-apiserver + app: agentcube-router spec: replicas: 1 selector: matchLabels: - app: agentcube-apiserver + app: agentcube-router template: metadata: labels: - app: agentcube-apiserver + app: agentcube-router spec: - serviceAccountName: agentcube-apiserver + serviceAccountName: agentcube-router containers: - - name: agentcube-apiserver - image: agentcube-apiserver:latest + - name: agentcube-router + image: agentcube-router:latest imagePullPolicy: IfNotPresent ports: - name: http @@ -103,6 +103,7 @@ spec: args: - --port=8080 - --runtime-class-name= + - --debug resources: requests: cpu: 100m @@ -127,10 +128,10 @@ spec: apiVersion: v1 kind: Service metadata: - name: agentcube-apiserver + name: agentcube-router namespace: agentcube labels: - app: agentcube-apiserver + app: agentcube-router spec: type: ClusterIP ports: @@ -139,4 +140,4 @@ spec: protocol: TCP name: http selector: - app: agentcube-apiserver + app: agentcube-router diff --git a/pkg/agentd/agentd_test.go b/pkg/agentd/agentd_test.go index da9b4d98..48e65e60 100644 --- a/pkg/agentd/agentd_test.go +++ b/pkg/agentd/agentd_test.go @@ -38,7 +38,7 @@ func TestReconciler_Reconcile_WithLastActivity(t *testing.T) { Name: "test-sandbox", Namespace: "default", Annotations: map[string]string{ - "last-activity-time": now.Add(-5 * time.Minute).Format(time.RFC3339), + "agentcube.volcano.sh/last-activity": now.Add(-5 * time.Minute).Format(time.RFC3339), }, }, Status: sandboxv1alpha1.SandboxStatus{ @@ -61,7 +61,7 @@ func TestReconciler_Reconcile_WithLastActivity(t *testing.T) { Name: "test-sandbox", Namespace: "default", Annotations: map[string]string{ - "last-activity-time": now.Add(-20 * time.Minute).Format(time.RFC3339), + "agentcube.volcano.sh/last-activity": now.Add(-20 * time.Minute).Format(time.RFC3339), }, }, Status: sandboxv1alpha1.SandboxStatus{ @@ -83,7 +83,7 @@ func TestReconciler_Reconcile_WithLastActivity(t *testing.T) { Name: "test-sandbox", Namespace: "default", Annotations: map[string]string{ - "last-activity-time": now.Add(-20 * time.Minute).Format(time.RFC3339), + "agentcube.volcano.sh/last-activity": now.Add(-20 * time.Minute).Format(time.RFC3339), }, }, Status: sandboxv1alpha1.SandboxStatus{ diff --git a/pkg/router/apiserver.go b/pkg/router/apiserver.go new file mode 100644 index 00000000..4df9c271 --- /dev/null +++ b/pkg/router/apiserver.go @@ -0,0 +1,153 @@ +package router + +import ( + "context" + "fmt" + "log" + "net/http" + "time" + + "github.com/gin-gonic/gin" +) + +// Server is the main structure for Router apiserver +type Server struct { + config *Config + engine *gin.Engine + httpServer *http.Server + sessionManager SessionManager + redisManager RedisManager + semaphore chan struct{} // For limiting concurrent requests +} + +// NewServer creates a new Router API server instance +func NewServer(config *Config) (*Server, error) { + if config == nil { + return nil, fmt.Errorf("config cannot be nil") + } + + // Set default values for concurrency settings + if config.MaxConcurrentRequests <= 0 { + config.MaxConcurrentRequests = 1000 // Default limit + } + if config.RequestTimeout <= 0 { + config.RequestTimeout = 30 // Default 30 seconds + } + if config.MaxIdleConns <= 0 { + config.MaxIdleConns = 100 // Default 100 idle connections + } + if config.MaxConnsPerHost <= 0 { + config.MaxConnsPerHost = 10 // Default 10 connections per host + } + if config.SessionExpireDuration <= 0 { + config.SessionExpireDuration = 3600 // Default 1 hour + } + + // Create session manager (using mock implementation) + sessionManager := NewMockSessionManager(config.SandboxEndpoints) + + // Create Redis manager (using mock implementation) + redisManager := NewMockRedisManager(config.EnableRedis) + + // Set Gin mode based on environment + if config.Debug { + gin.SetMode(gin.DebugMode) + } else { + gin.SetMode(gin.ReleaseMode) + } + + server := &Server{ + config: config, + sessionManager: sessionManager, + redisManager: redisManager, + semaphore: make(chan struct{}, config.MaxConcurrentRequests), + } + + // Setup routes + server.setupRoutes() + + return server, nil +} + +// concurrencyLimitMiddleware limits the number of concurrent requests +func (s *Server) concurrencyLimitMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // Try to acquire a slot in the semaphore + select { + case s.semaphore <- struct{}{}: + // Successfully acquired a slot, continue processing + defer func() { + // Release the slot when done + <-s.semaphore + }() + c.Next() + default: + // No slots available, return 503 Service Unavailable + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": "server overloaded, please try again later", + "code": "SERVER_OVERLOADED", + }) + c.Abort() + } + } +} + +// setupRoutes configures HTTP routes using Gin +func (s *Server) setupRoutes() { + s.engine = gin.New() + + // Add middleware + s.engine.Use(gin.Logger()) + s.engine.Use(gin.Recovery()) + + // Health check endpoints (no authentication required, no concurrency limit) + s.engine.GET("/health", s.handleHealth) + s.engine.GET("/health/live", s.handleHealthLive) + s.engine.GET("/health/ready", s.handleHealthReady) + + // API v1 routes with concurrency limiting + v1 := s.engine.Group("/v1") + v1.Use(s.concurrencyLimitMiddleware()) // Apply concurrency limit to API routes + + // Agent invoke requests + v1.Any("/namespaces/:agentNamespace/agent-runtimes/:agentName/invocations/*path", s.handleAgentInvoke) + + // Code interpreter invoke requests - use different base path to avoid conflicts + v1.Any("/code-namespaces/:namespace/code-interpreters/:name/invocations/*path", s.handleCodeInterpreterInvoke) +} + +// Start starts the Router API server +func (s *Server) Start(ctx context.Context) error { + addr := ":" + s.config.Port + + s.httpServer = &http.Server{ + Addr: addr, + Handler: s.engine, + ReadTimeout: 30 * time.Second, // Longer timeout for potential long-running requests + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + } + + // Listen for shutdown signal in goroutine + go func() { + <-ctx.Done() + log.Println("Shutting down Router server...") + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := s.httpServer.Shutdown(shutdownCtx); err != nil { + log.Printf("Server shutdown error: %v", err) + } + }() + + log.Printf("Router server listening on %s", addr) + + // Start HTTP or HTTPS server + if s.config.EnableTLS { + if s.config.TLSCert == "" || s.config.TLSKey == "" { + return fmt.Errorf("TLS enabled but cert/key not provided") + } + return s.httpServer.ListenAndServeTLS(s.config.TLSCert, s.config.TLSKey) + } + + return s.httpServer.ListenAndServe() +} diff --git a/pkg/router/apiserver_test.go b/pkg/router/apiserver_test.go new file mode 100644 index 00000000..0cbf8dfe --- /dev/null +++ b/pkg/router/apiserver_test.go @@ -0,0 +1,343 @@ +package router + +import ( + "context" + "testing" + "time" +) + +func TestNewServer(t *testing.T) { + tests := []struct { + name string + config *Config + wantErr bool + }{ + { + name: "nil config", + config: nil, + wantErr: true, + }, + { + name: "valid config with defaults", + config: &Config{ + Port: "8080", + }, + wantErr: false, + }, + { + name: "valid config with custom values", + config: &Config{ + Port: "9090", + MaxConcurrentRequests: 500, + RequestTimeout: 60, + MaxIdleConns: 200, + MaxConnsPerHost: 20, + SessionExpireDuration: 7200, + EnableRedis: true, + Debug: true, + }, + wantErr: false, + }, + { + name: "config with TLS enabled", + config: &Config{ + Port: "8443", + EnableTLS: true, + TLSCert: "/path/to/cert.pem", + TLSKey: "/path/to/key.pem", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server, err := NewServer(tt.config) + + if (err != nil) != tt.wantErr { + t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // Verify server was created + if server == nil { + t.Error("Expected non-nil server") + return + } + + // Verify config was set + if server.config != tt.config { + t.Error("Server config was not set correctly") + } + + // Verify session manager was created + if server.sessionManager == nil { + t.Error("Session manager was not created") + } + + // Verify redis manager was created + if server.redisManager == nil { + t.Error("Redis manager was not created") + } + + // Verify semaphore was created with correct capacity + expectedCapacity := tt.config.MaxConcurrentRequests + if expectedCapacity <= 0 { + expectedCapacity = 1000 // Default value + } + if cap(server.semaphore) != expectedCapacity { + t.Errorf("Expected semaphore capacity %d, got %d", expectedCapacity, cap(server.semaphore)) + } + + // Verify default values were set + if server.config.MaxConcurrentRequests <= 0 { + t.Error("MaxConcurrentRequests should have been set to default") + } + if server.config.RequestTimeout <= 0 { + t.Error("RequestTimeout should have been set to default") + } + if server.config.MaxIdleConns <= 0 { + t.Error("MaxIdleConns should have been set to default") + } + if server.config.MaxConnsPerHost <= 0 { + t.Error("MaxConnsPerHost should have been set to default") + } + if server.config.SessionExpireDuration <= 0 { + t.Error("SessionExpireDuration should have been set to default") + } + } + }) + } +} + +func TestServer_DefaultValues(t *testing.T) { + config := &Config{ + Port: "8080", + // Leave other values as zero to test defaults + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Test default values + if server.config.MaxConcurrentRequests != 1000 { + t.Errorf("Expected default MaxConcurrentRequests 1000, got %d", server.config.MaxConcurrentRequests) + } + + if server.config.RequestTimeout != 30 { + t.Errorf("Expected default RequestTimeout 30, got %d", server.config.RequestTimeout) + } + + if server.config.MaxIdleConns != 100 { + t.Errorf("Expected default MaxIdleConns 100, got %d", server.config.MaxIdleConns) + } + + if server.config.MaxConnsPerHost != 10 { + t.Errorf("Expected default MaxConnsPerHost 10, got %d", server.config.MaxConnsPerHost) + } + + if server.config.SessionExpireDuration != 3600 { + t.Errorf("Expected default SessionExpireDuration 3600, got %d", server.config.SessionExpireDuration) + } +} + +func TestServer_ConcurrencyLimitMiddleware(t *testing.T) { + config := &Config{ + Port: "8080", + MaxConcurrentRequests: 2, // Small limit for testing + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + middleware := server.concurrencyLimitMiddleware() + + // Test that middleware function was created + if middleware == nil { + t.Error("Expected non-nil middleware function") + } + + // Note: Testing the actual middleware behavior would require setting up + // a full HTTP test environment, which is beyond the scope of unit tests. + // Integration tests would be more appropriate for testing middleware behavior. +} + +func TestServer_SetupRoutes(t *testing.T) { + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Verify that engine was created during setupRoutes + if server.engine == nil { + t.Error("Expected non-nil Gin engine after setupRoutes") + } + + // Note: Testing specific routes would require HTTP testing, + // which is more appropriate for integration tests. +} + +func TestServer_StartContext(t *testing.T) { + config := &Config{ + Port: "0", // Use port 0 to let the OS assign a free port + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Test context cancellation + ctx, cancel := context.WithCancel(context.Background()) + + // Start server in goroutine + errChan := make(chan error, 1) + go func() { + err := server.Start(ctx) + errChan <- err + }() + + // Give server a moment to start + time.Sleep(100 * time.Millisecond) + + // Cancel context to trigger shutdown + cancel() + + // Wait for server to shutdown + select { + case err := <-errChan: + // Server should shutdown gracefully, error might be http.ErrServerClosed + if err != nil && err.Error() != "http: Server closed" { + t.Errorf("Unexpected error during shutdown: %v", err) + } + case <-time.After(5 * time.Second): + t.Error("Server did not shutdown within timeout") + } +} + +func TestServer_TLSConfiguration(t *testing.T) { + tests := []struct { + name string + config *Config + wantErr bool + errString string + }{ + { + name: "TLS enabled with cert and key", + config: &Config{ + Port: "8443", + EnableTLS: true, + TLSCert: "/path/to/cert.pem", + TLSKey: "/path/to/key.pem", + }, + wantErr: false, + }, + { + name: "TLS enabled without cert", + config: &Config{ + Port: "8443", + EnableTLS: true, + TLSKey: "/path/to/key.pem", + }, + wantErr: true, + errString: "TLS enabled but cert/key not provided", + }, + { + name: "TLS enabled without key", + config: &Config{ + Port: "8443", + EnableTLS: true, + TLSCert: "/path/to/cert.pem", + }, + wantErr: true, + errString: "TLS enabled but cert/key not provided", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server, err := NewServer(tt.config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Test Start method with a context that will be cancelled immediately + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately to avoid actually starting the server + + err = server.Start(ctx) + + if tt.wantErr { + if err == nil { + t.Error("Expected error but got none") + } else if tt.errString != "" && err.Error() != tt.errString { + t.Errorf("Expected error '%s', got '%s'", tt.errString, err.Error()) + } + } else { + // For TLS tests, we expect the server to fail to start due to invalid cert/key paths + // but the configuration validation should pass + if err != nil && err.Error() == tt.errString { + t.Errorf("Unexpected configuration error: %v", err) + } + } + }) + } +} + +func TestServer_RedisIntegration(t *testing.T) { + tests := []struct { + name string + enableRedis bool + }{ + { + name: "Redis enabled", + enableRedis: true, + }, + { + name: "Redis disabled", + enableRedis: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &Config{ + Port: "8080", + EnableRedis: tt.enableRedis, + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Verify Redis manager was created + if server.redisManager == nil { + t.Error("Redis manager was not created") + } + + // Test Redis manager functionality based on enabled state + err = server.redisManager.UpdateSessionActivity("test-session") + if tt.enableRedis { + if err != nil { + t.Errorf("Expected no error when Redis is enabled, got: %v", err) + } + } else { + // When disabled, UpdateSessionActivity should not return an error + // (it silently skips) + if err != nil { + t.Errorf("Expected no error when Redis is disabled, got: %v", err) + } + } + }) + } +} diff --git a/pkg/router/config.go b/pkg/router/config.go new file mode 100644 index 00000000..ab686daa --- /dev/null +++ b/pkg/router/config.go @@ -0,0 +1,53 @@ +package router + +// LastActivityAnnotationKey is the annotation key for tracking last activity +const LastActivityAnnotationKey = "agentcube.volcano.sh/last-activity" + +// Config contains configuration parameters for Router apiserver +type Config struct { + // Port is the port the API server listens on + Port string + + // SandboxEndpoints is the list of available sandbox endpoints + SandboxEndpoints []string + + // Debug enables debug mode + Debug bool + + // EnableTLS enables HTTPS + EnableTLS bool + + // TLSCert is the path to the TLS certificate file + TLSCert string + + // TLSKey is the path to the TLS private key file + TLSKey string + + // MaxConcurrentRequests limits the number of concurrent requests (0 = unlimited) + MaxConcurrentRequests int + + // RequestTimeout sets the timeout for individual requests + RequestTimeout int // seconds + + // MaxIdleConns sets the maximum number of idle connections in the connection pool + MaxIdleConns int + + // MaxConnsPerHost sets the maximum number of connections per host + MaxConnsPerHost int + + // Redis configuration + // EnableRedis enables Redis session activity tracking + EnableRedis bool + + // RedisAddr is the Redis server address (e.g., "localhost:6379") + RedisAddr string + + // RedisPassword is the Redis password (optional) + RedisPassword string + + // RedisDB is the Redis database number + RedisDB int + + // SessionExpireDuration is the duration after which inactive sessions expire + SessionExpireDuration int // seconds, default 3600 (1 hour) +} diff --git a/pkg/router/handlers.go b/pkg/router/handlers.go new file mode 100644 index 00000000..bb3aec1d --- /dev/null +++ b/pkg/router/handlers.go @@ -0,0 +1,205 @@ +package router + +import ( + "context" + "log" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +// handleHealth handles health check requests +func (s *Server) handleHealth(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": "healthy", + }) +} + +// handleHealthLive handles liveness probe +func (s *Server) handleHealthLive(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": "alive", + }) +} + +// handleHealthReady handles readiness probe +func (s *Server) handleHealthReady(c *gin.Context) { + // Check if SessionManager is available + if s.sessionManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "status": "not ready", + "error": "session manager not available", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": "ready", + }) +} + +// handleAgentInvoke handles agent invocation requests +func (s *Server) handleAgentInvoke(c *gin.Context) { + agentNamespace := c.Param("agentNamespace") + agentName := c.Param("agentName") + path := c.Param("path") + + log.Printf("Agent invoke request: namespace=%s, agent=%s, path=%s", agentNamespace, agentName, path) + + // Extract session ID from header + sessionID := c.GetHeader("x-agentcube-session-id") + + // Get sandbox info from session manager + endpoint, newSessionID, err := s.sessionManager.GetSandboxInfoBySessionId(sessionID, agentNamespace, agentName, KindAgent) + if err != nil { + log.Printf("Failed to get sandbox info: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "internal server error", + "code": "INTERNAL_ERROR", + }) + return + } + + // Update session activity in Redis when receiving request + if newSessionID != "" { + if err := s.redisManager.UpdateSessionActivity(newSessionID); err != nil { + log.Printf("Failed to update session activity for request: %v", err) + } + } + + // Forward request to sandbox with session ID + s.forwardToSandbox(c, endpoint, path, newSessionID) +} + +// handleCodeInterpreterInvoke handles code interpreter invocation requests +func (s *Server) handleCodeInterpreterInvoke(c *gin.Context) { + namespace := c.Param("namespace") + name := c.Param("name") + path := c.Param("path") + + log.Printf("Code interpreter invoke request: namespace=%s, name=%s, path=%s", namespace, name, path) + + // Extract session ID from header + sessionID := c.GetHeader("x-agentcube-session-id") + + // Get sandbox info from session manager + endpoint, newSessionID, err := s.sessionManager.GetSandboxInfoBySessionId(sessionID, namespace, name, KindCodeInterpreter) + if err != nil { + log.Printf("Failed to get sandbox info: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "internal server error", + "code": "INTERNAL_ERROR", + }) + return + } + + // Update session activity in Redis when receiving request + if newSessionID != "" { + if err := s.redisManager.UpdateSessionActivity(newSessionID); err != nil { + log.Printf("Failed to update session activity for request: %v", err) + } + } + + // Forward request to sandbox with session ID + s.forwardToSandbox(c, endpoint, path, newSessionID) +} + +// forwardToSandbox forwards the request to the specified sandbox endpoint +func (s *Server) forwardToSandbox(c *gin.Context, endpoint, path, sessionID string) { + // Parse the target URL + targetURL, err := url.Parse(endpoint) + if err != nil { + log.Printf("Invalid sandbox endpoint: %s, error: %v", endpoint, err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "internal server error", + "code": "INTERNAL_ERROR", + }) + return + } + + // Create reverse proxy with optimized transport + proxy := httputil.NewSingleHostReverseProxy(targetURL) + + // Configure HTTP transport for better concurrency + proxy.Transport = &http.Transport{ + MaxIdleConns: s.config.MaxIdleConns, + MaxIdleConnsPerHost: s.config.MaxConnsPerHost, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + ForceAttemptHTTP2: true, + } + + // Customize the director to modify the request + originalDirector := proxy.Director + proxy.Director = func(req *http.Request) { + originalDirector(req) + + // Set the target path + if path != "" && !strings.HasPrefix(path, "/") { + path = "/" + path + } + req.URL.Path = path + req.URL.RawPath = "" + + // Set the host header + req.Host = targetURL.Host + + // Add forwarding headers + req.Header.Set("X-Forwarded-Host", c.Request.Host) + req.Header.Set("X-Forwarded-Proto", "http") + if c.Request.TLS != nil { + req.Header.Set("X-Forwarded-Proto", "https") + } + + log.Printf("Forwarding request to: %s%s", targetURL.String(), path) + } + + // Customize error handler + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + log.Printf("Proxy error: %v", err) + + // Determine error type and return appropriate response + if strings.Contains(err.Error(), "connection refused") { + c.JSON(http.StatusBadGateway, gin.H{ + "error": "sandbox unreachable", + "code": "SANDBOX_UNREACHABLE", + }) + } else if strings.Contains(err.Error(), "timeout") { + c.JSON(http.StatusGatewayTimeout, gin.H{ + "error": "sandbox timeout", + "code": "SANDBOX_TIMEOUT", + }) + } else { + c.JSON(http.StatusBadGateway, gin.H{ + "error": "sandbox unreachable", + "code": "SANDBOX_UNREACHABLE", + }) + } + } + + // Modify response + proxy.ModifyResponse = func(resp *http.Response) error { + // Always set session ID in response header + if sessionID != "" { + resp.Header.Set("x-agentcube-session-id", sessionID) + + // Update session activity in Redis when returning response + if err := s.redisManager.UpdateSessionActivity(sessionID); err != nil { + log.Printf("Failed to update session activity for response: %v", err) + } + } + return nil + } + + // Set timeout for the proxy request using configured timeout + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(s.config.RequestTimeout)*time.Second) + defer cancel() + c.Request = c.Request.WithContext(ctx) + + // Use the proxy to serve the request + proxy.ServeHTTP(c.Writer, c.Request) +} diff --git a/pkg/router/redis_manager.go b/pkg/router/redis_manager.go new file mode 100644 index 00000000..cc90d9c3 --- /dev/null +++ b/pkg/router/redis_manager.go @@ -0,0 +1,135 @@ +package router + +import ( + "fmt" + "log" + "sync" + "time" +) + +// RedisManager interface for managing session activity in Redis +type RedisManager interface { + // UpdateSessionActivity updates the lastActive time for a session ID + UpdateSessionActivity(sessionID string) error + + // GetSessionLastActive gets the last active time for a session ID + GetSessionLastActive(sessionID string) (time.Time, error) + + // CleanupExpiredSessions removes sessions that haven't been active for a specified duration + CleanupExpiredSessions(expireDuration time.Duration) error +} + +// MockRedisManager is a mock implementation for testing +type MockRedisManager struct { + mu sync.RWMutex + sessions map[string]time.Time // sessionID -> lastActive time + enabled bool +} + +// NewMockRedisManager creates a new mock Redis manager +func NewMockRedisManager(enabled bool) *MockRedisManager { + return &MockRedisManager{ + sessions: make(map[string]time.Time), + enabled: enabled, + } +} + +// UpdateSessionActivity implements RedisManager interface +func (r *MockRedisManager) UpdateSessionActivity(sessionID string) error { + if !r.enabled { + return nil // Silently skip if disabled + } + + if sessionID == "" { + return fmt.Errorf("session ID cannot be empty") + } + + r.mu.Lock() + defer r.mu.Unlock() + + now := time.Now() + r.sessions[sessionID] = now + + log.Printf("Updated session activity for session %s at %s", sessionID, now.Format(time.RFC3339)) + return nil +} + +// GetSessionLastActive implements RedisManager interface +func (r *MockRedisManager) GetSessionLastActive(sessionID string) (time.Time, error) { + if !r.enabled { + return time.Time{}, fmt.Errorf("Redis manager is disabled") + } + + if sessionID == "" { + return time.Time{}, fmt.Errorf("session ID cannot be empty") + } + + r.mu.RLock() + defer r.mu.RUnlock() + + lastActive, exists := r.sessions[sessionID] + if !exists { + return time.Time{}, fmt.Errorf("session %s not found", sessionID) + } + + return lastActive, nil +} + +// CleanupExpiredSessions implements RedisManager interface +func (r *MockRedisManager) CleanupExpiredSessions(expireDuration time.Duration) error { + if !r.enabled { + return nil // Silently skip if disabled + } + + r.mu.Lock() + defer r.mu.Unlock() + + now := time.Now() + expiredSessions := make([]string, 0) + + for sessionID, lastActive := range r.sessions { + if now.Sub(lastActive) > expireDuration { + expiredSessions = append(expiredSessions, sessionID) + } + } + + // Remove expired sessions + for _, sessionID := range expiredSessions { + delete(r.sessions, sessionID) + log.Printf("Cleaned up expired session: %s", sessionID) + } + + if len(expiredSessions) > 0 { + log.Printf("Cleaned up %d expired sessions", len(expiredSessions)) + } + + return nil +} + +// GetActiveSessionCount returns the number of active sessions +func (r *MockRedisManager) GetActiveSessionCount() int { + if !r.enabled { + return 0 + } + + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.sessions) +} + +// GetAllActiveSessions returns all active session IDs (for debugging) +func (r *MockRedisManager) GetAllActiveSessions() []string { + if !r.enabled { + return nil + } + + r.mu.RLock() + defer r.mu.RUnlock() + + sessions := make([]string, 0, len(r.sessions)) + for sessionID := range r.sessions { + sessions = append(sessions, sessionID) + } + + return sessions +} diff --git a/pkg/router/redis_manager_test.go b/pkg/router/redis_manager_test.go new file mode 100644 index 00000000..502a462b --- /dev/null +++ b/pkg/router/redis_manager_test.go @@ -0,0 +1,250 @@ +package router + +import ( + "testing" + "time" +) + +func TestMockRedisManager_UpdateSessionActivity(t *testing.T) { + tests := []struct { + name string + enabled bool + sessionID string + wantErr bool + }{ + { + name: "valid session ID with enabled Redis", + enabled: true, + sessionID: "test-session-123", + wantErr: false, + }, + { + name: "empty session ID with enabled Redis", + enabled: true, + sessionID: "", + wantErr: true, + }, + { + name: "valid session ID with disabled Redis", + enabled: false, + sessionID: "test-session-123", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewMockRedisManager(tt.enabled) + err := r.UpdateSessionActivity(tt.sessionID) + + if (err != nil) != tt.wantErr { + t.Errorf("UpdateSessionActivity() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // If enabled and no error, check if session was stored by checking count + if tt.enabled && !tt.wantErr { + if count := r.GetActiveSessionCount(); count != 1 { + t.Errorf("Expected 1 session to be stored, got %d", count) + } + } + }) + } +} + +func TestMockRedisManager_GetSessionLastActive(t *testing.T) { + r := NewMockRedisManager(true) + sessionID := "test-session-123" + + // Test getting non-existent session + _, err := r.GetSessionLastActive(sessionID) + if err == nil { + t.Error("Expected error for non-existent session") + } + + // Add session and test retrieval + beforeTime := time.Now() + err = r.UpdateSessionActivity(sessionID) + if err != nil { + t.Fatalf("Failed to update session activity: %v", err) + } + afterTime := time.Now() + + lastActive, err := r.GetSessionLastActive(sessionID) + if err != nil { + t.Fatalf("Failed to get session last active: %v", err) + } + + if lastActive.Before(beforeTime) || lastActive.After(afterTime) { + t.Errorf("Last active time %v is not within expected range [%v, %v]", lastActive, beforeTime, afterTime) + } + + // Test with disabled Redis + rDisabled := NewMockRedisManager(false) + _, err = rDisabled.GetSessionLastActive(sessionID) + if err == nil { + t.Error("Expected error when Redis is disabled") + } +} + +func TestMockRedisManager_CleanupExpiredSessions(t *testing.T) { + r := NewMockRedisManager(true) + + // Add some sessions + oldSessionID := "old-session" + newSessionID := "new-session" + + // Add sessions - the old one will be cleaned up based on time + err := r.UpdateSessionActivity(oldSessionID) + if err != nil { + t.Fatalf("Failed to update old session activity: %v", err) + } + + // Manually set old session timestamp by accessing the struct field + r.sessions[oldSessionID] = time.Now().Add(-2 * time.Hour) + + // Add new session + err = r.UpdateSessionActivity(newSessionID) + if err != nil { + t.Fatalf("Failed to update new session activity: %v", err) + } + + // Get initial count + initialCount := r.GetActiveSessionCount() + if initialCount != 2 { + t.Errorf("Expected 2 initial sessions, got %d", initialCount) + } + + // Cleanup sessions older than 1 hour + err = r.CleanupExpiredSessions(1 * time.Hour) + if err != nil { + t.Fatalf("Failed to cleanup expired sessions: %v", err) + } + + // Check that count decreased + finalCount := r.GetActiveSessionCount() + if finalCount != 1 { + t.Errorf("Expected 1 session after cleanup, got %d", finalCount) + } + + // Test with disabled Redis + rDisabled := NewMockRedisManager(false) + err = rDisabled.CleanupExpiredSessions(1 * time.Hour) + if err != nil { + t.Errorf("Cleanup should not fail when Redis is disabled: %v", err) + } +} + +func TestMockRedisManager_GetActiveSessionCount(t *testing.T) { + r := NewMockRedisManager(true) + + // Initially should be 0 + if count := r.GetActiveSessionCount(); count != 0 { + t.Errorf("Expected 0 active sessions, got %d", count) + } + + // Add some sessions + sessions := []string{"session1", "session2", "session3"} + for _, sessionID := range sessions { + err := r.UpdateSessionActivity(sessionID) + if err != nil { + t.Fatalf("Failed to update session activity: %v", err) + } + } + + // Should have 3 sessions + if count := r.GetActiveSessionCount(); count != 3 { + t.Errorf("Expected 3 active sessions, got %d", count) + } + + // Test with disabled Redis + rDisabled := NewMockRedisManager(false) + if count := rDisabled.GetActiveSessionCount(); count != 0 { + t.Errorf("Expected 0 active sessions when disabled, got %d", count) + } +} + +func TestMockRedisManager_GetAllActiveSessions(t *testing.T) { + r := NewMockRedisManager(true) + + // Initially should be empty + if sessions := r.GetAllActiveSessions(); len(sessions) != 0 { + t.Errorf("Expected 0 active sessions, got %d", len(sessions)) + } + + // Add some sessions + expectedSessions := []string{"session1", "session2", "session3"} + for _, sessionID := range expectedSessions { + err := r.UpdateSessionActivity(sessionID) + if err != nil { + t.Fatalf("Failed to update session activity: %v", err) + } + } + + // Get all sessions + activeSessions := r.GetAllActiveSessions() + if len(activeSessions) != len(expectedSessions) { + t.Errorf("Expected %d active sessions, got %d", len(expectedSessions), len(activeSessions)) + } + + // Check that all expected sessions are present + sessionMap := make(map[string]bool) + for _, session := range activeSessions { + sessionMap[session] = true + } + + for _, expected := range expectedSessions { + if !sessionMap[expected] { + t.Errorf("Expected session %s not found in active sessions", expected) + } + } + + // Test with disabled Redis + rDisabled := NewMockRedisManager(false) + if sessions := rDisabled.GetAllActiveSessions(); sessions != nil { + t.Error("Expected nil when Redis is disabled") + } +} + +func TestMockRedisManager_ConcurrentAccess(t *testing.T) { + r := NewMockRedisManager(true) + sessionID := "concurrent-session" + + // Test concurrent updates + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + err := r.UpdateSessionActivity(sessionID) + if err != nil { + t.Errorf("Concurrent update failed: %v", err) + } + done <- true + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + // Check that session exists by checking count + if count := r.GetActiveSessionCount(); count != 1 { + t.Errorf("Expected 1 session after concurrent updates, got %d", count) + } + + // Test concurrent reads + for i := 0; i < 10; i++ { + go func() { + _, err := r.GetSessionLastActive(sessionID) + if err != nil { + t.Errorf("Concurrent read failed: %v", err) + } + done <- true + }() + } + + // Wait for all read goroutines to complete + for i := 0; i < 10; i++ { + <-done + } +} diff --git a/pkg/router/session_manager.go b/pkg/router/session_manager.go new file mode 100644 index 00000000..08d1e196 --- /dev/null +++ b/pkg/router/session_manager.go @@ -0,0 +1,116 @@ +package router + +import ( + "fmt" + "sync" + + "github.com/google/uuid" +) + +// Kind constants for sandbox types +const ( + KindAgent = "agent" + KindCodeInterpreter = "codeinterpreter" +) + +// SessionManager interface for managing sandbox sessions +type SessionManager interface { + // GetSandboxInfoBySessionId returns sandbox endpoint and session ID + // When sessionId is empty, creates a new session + // kind can be "agent" or "codeinterpreter" + GetSandboxInfoBySessionId(sessionId, namespace, name, kind string) (endpoint string, newSessionId string, err error) +} + +// MockSessionManager is a simple implementation for testing +type MockSessionManager struct { + mu sync.RWMutex + sandboxEndpoints []string + currentIndex int + sessions map[string]string // sessionId -> endpoint +} + +// NewMockSessionManager creates a new mock session manager +func NewMockSessionManager(sandboxEndpoints []string) *MockSessionManager { + if len(sandboxEndpoints) == 0 { + // Default sandbox endpoints for testing + sandboxEndpoints = []string{ + "http://sandbox-1:8080", + "http://sandbox-2:8080", + "http://sandbox-3:8080", + } + } + + return &MockSessionManager{ + sandboxEndpoints: sandboxEndpoints, + sessions: make(map[string]string), + } +} + +// GetSandboxInfoBySessionId implements SessionManager interface +func (m *MockSessionManager) GetSandboxInfoBySessionId(sessionId, namespace, name, kind string) (string, string, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // Validate kind parameter + if kind != "" && kind != KindAgent && kind != KindCodeInterpreter { + return "", "", fmt.Errorf("invalid kind: %s, must be '%s' or '%s'", kind, KindAgent, KindCodeInterpreter) + } + + // If sessionId is empty, create a new session + if sessionId == "" { + sessionId = m.generateNewSessionId() + } + + // Check if session already exists + if endpoint, exists := m.sessions[sessionId]; exists { + return endpoint, sessionId, nil + } + + // Create new session with round-robin endpoint selection + if len(m.sandboxEndpoints) == 0 { + return "", "", fmt.Errorf("no sandbox endpoints available") + } + + // Select endpoint based on kind if specified + var endpoint string + if kind == KindAgent { + // For agent kind, prefer agent-specific endpoints or use round-robin + endpoint = m.sandboxEndpoints[m.currentIndex%len(m.sandboxEndpoints)] + } else if kind == KindCodeInterpreter { + // For codeinterpreter kind, prefer code interpreter endpoints or use round-robin + endpoint = m.sandboxEndpoints[m.currentIndex%len(m.sandboxEndpoints)] + } else { + // Default behavior for backward compatibility + endpoint = m.sandboxEndpoints[m.currentIndex%len(m.sandboxEndpoints)] + } + + m.currentIndex++ + + // Store the session with additional metadata + sessionKey := sessionId + if namespace != "" && name != "" { + sessionKey = fmt.Sprintf("%s/%s/%s", namespace, name, sessionId) + } + m.sessions[sessionKey] = endpoint + + return endpoint, sessionId, nil +} + +// generateNewSessionId generates a new UUID-based session ID +func (m *MockSessionManager) generateNewSessionId() string { + return uuid.New().String() +} + +// RemoveSession removes a session (for cleanup) +func (m *MockSessionManager) RemoveSession(sessionId string) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.sessions, sessionId) +} + +// GetSessionCount returns the number of active sessions +func (m *MockSessionManager) GetSessionCount() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.sessions) +} diff --git a/pkg/router/session_manager_test.go b/pkg/router/session_manager_test.go new file mode 100644 index 00000000..9d328df6 --- /dev/null +++ b/pkg/router/session_manager_test.go @@ -0,0 +1,305 @@ +package router + +import ( + "fmt" + "testing" +) + +func TestMockSessionManager_GetSandboxInfoBySessionId(t *testing.T) { + endpoints := []string{ + "http://sandbox-1:8080", + "http://sandbox-2:8080", + "http://sandbox-3:8080", + } + sm := NewMockSessionManager(endpoints) + + tests := []struct { + name string + sessionID string + namespace string + agentName string + kind string + wantErr bool + }{ + { + name: "new session with agent kind", + sessionID: "", + namespace: "default", + agentName: "test-agent", + kind: KindAgent, + wantErr: false, + }, + { + name: "new session with code interpreter kind", + sessionID: "", + namespace: "default", + agentName: "test-interpreter", + kind: KindCodeInterpreter, + wantErr: false, + }, + { + name: "existing session", + sessionID: "existing-session-123", + namespace: "default", + agentName: "test-agent", + kind: KindAgent, + wantErr: false, + }, + { + name: "invalid kind", + sessionID: "", + namespace: "default", + agentName: "test-agent", + kind: "invalid-kind", + wantErr: true, + }, + { + name: "empty kind (should work for backward compatibility)", + sessionID: "", + namespace: "default", + agentName: "test-agent", + kind: "", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + endpoint, sessionID, err := sm.GetSandboxInfoBySessionId(tt.sessionID, tt.namespace, tt.agentName, tt.kind) + + if (err != nil) != tt.wantErr { + t.Errorf("GetSandboxInfoBySessionId() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // Check that endpoint is not empty + if endpoint == "" { + t.Error("Expected non-empty endpoint") + } + + // Check that sessionID is not empty + if sessionID == "" { + t.Error("Expected non-empty session ID") + } + + // Check that endpoint is one of the configured endpoints + found := false + for _, ep := range endpoints { + if endpoint == ep { + found = true + break + } + } + if !found { + t.Errorf("Endpoint %s not found in configured endpoints", endpoint) + } + + // If we provided a session ID, it should be returned unchanged + if tt.sessionID != "" && sessionID != tt.sessionID { + t.Errorf("Expected session ID %s, got %s", tt.sessionID, sessionID) + } + } + }) + } +} + +func TestMockSessionManager_SessionPersistence(t *testing.T) { + endpoints := []string{"http://sandbox-1:8080"} + sm := NewMockSessionManager(endpoints) + + // Create a new session + endpoint1, sessionID, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) + if err != nil { + t.Fatalf("Failed to create new session: %v", err) + } + + // Use the same session ID again + endpoint2, sessionID2, err := sm.GetSandboxInfoBySessionId(sessionID, "default", "test-agent", KindAgent) + if err != nil { + t.Fatalf("Failed to get existing session: %v", err) + } + + // Should return the same endpoint and session ID + if endpoint1 != endpoint2 { + t.Errorf("Expected same endpoint for existing session, got %s vs %s", endpoint1, endpoint2) + } + + if sessionID != sessionID2 { + t.Errorf("Expected same session ID, got %s vs %s", sessionID, sessionID2) + } +} + +func TestMockSessionManager_RoundRobinDistribution(t *testing.T) { + endpoints := []string{ + "http://sandbox-1:8080", + "http://sandbox-2:8080", + "http://sandbox-3:8080", + } + sm := NewMockSessionManager(endpoints) + + // Create multiple sessions and track endpoint distribution + endpointCount := make(map[string]int) + numSessions := 9 // Multiple of 3 to test round-robin + + for i := 0; i < numSessions; i++ { + endpoint, _, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) + if err != nil { + t.Fatalf("Failed to create session %d: %v", i, err) + } + endpointCount[endpoint]++ + } + + // Each endpoint should be used equally + expectedCount := numSessions / len(endpoints) + for _, endpoint := range endpoints { + if count := endpointCount[endpoint]; count != expectedCount { + t.Errorf("Expected endpoint %s to be used %d times, got %d", endpoint, expectedCount, count) + } + } +} + +func TestMockSessionManager_EmptyEndpoints(t *testing.T) { + // Test with empty endpoints (should use defaults) + sm := NewMockSessionManager([]string{}) + + endpoint, sessionID, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) + if err != nil { + t.Fatalf("Failed to create session with default endpoints: %v", err) + } + + if endpoint == "" { + t.Error("Expected non-empty endpoint with default configuration") + } + + if sessionID == "" { + t.Error("Expected non-empty session ID") + } +} + +func TestMockSessionManager_SessionKeyGeneration(t *testing.T) { + endpoints := []string{"http://sandbox-1:8080"} + sm := NewMockSessionManager(endpoints) + + // Test session with namespace and name + endpoint1, sessionID1, err := sm.GetSandboxInfoBySessionId("", "namespace1", "agent1", KindAgent) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Test session with different namespace but same name + _, sessionID2, err := sm.GetSandboxInfoBySessionId("", "namespace2", "agent1", KindAgent) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Should create different sessions + if sessionID1 == sessionID2 { + t.Error("Expected different session IDs for different namespaces") + } + + // Test reusing first session + endpoint3, sessionID3, err := sm.GetSandboxInfoBySessionId(sessionID1, "namespace1", "agent1", KindAgent) + if err != nil { + t.Fatalf("Failed to reuse session: %v", err) + } + + if endpoint1 != endpoint3 || sessionID1 != sessionID3 { + t.Error("Expected same endpoint and session ID when reusing session") + } +} + +func TestMockSessionManager_GetSessionCount(t *testing.T) { + endpoints := []string{"http://sandbox-1:8080"} + sm := NewMockSessionManager(endpoints) + + // Initially should have 0 sessions + if count := sm.GetSessionCount(); count != 0 { + t.Errorf("Expected 0 initial sessions, got %d", count) + } + + // Create some sessions + numSessions := 3 + for i := 0; i < numSessions; i++ { + _, _, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) + if err != nil { + t.Fatalf("Failed to create session %d: %v", i, err) + } + } + + // Should have the expected number of sessions + if count := sm.GetSessionCount(); count != numSessions { + t.Errorf("Expected %d sessions, got %d", numSessions, count) + } +} + +func TestMockSessionManager_RemoveSession(t *testing.T) { + endpoints := []string{"http://sandbox-1:8080"} + sm := NewMockSessionManager(endpoints) + + // Create a session + _, sessionID, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Verify session exists + if count := sm.GetSessionCount(); count != 1 { + t.Errorf("Expected 1 session, got %d", count) + } + + // Remove the session using the composite key (namespace/name/sessionId) + compositeKey := fmt.Sprintf("default/test-agent/%s", sessionID) + sm.RemoveSession(compositeKey) + + // Verify session was removed + if count := sm.GetSessionCount(); count != 0 { + t.Errorf("Expected 0 sessions after removal, got %d", count) + } + + // Test removing session without namespace/name (should not affect anything) + sm.RemoveSession(sessionID) + if count := sm.GetSessionCount(); count != 0 { + t.Errorf("Expected 0 sessions after second removal attempt, got %d", count) + } +} + +func TestMockSessionManager_ConcurrentAccess(t *testing.T) { + endpoints := []string{"http://sandbox-1:8080", "http://sandbox-2:8080"} + sm := NewMockSessionManager(endpoints) + + // Test concurrent session creation + done := make(chan bool, 10) + sessionIDs := make(chan string, 10) + + for i := 0; i < 10; i++ { + go func(index int) { + _, sessionID, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) + if err != nil { + t.Errorf("Concurrent session creation failed: %v", err) + } else { + sessionIDs <- sessionID + } + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + close(sessionIDs) + + // Verify all sessions were created with unique IDs + uniqueIDs := make(map[string]bool) + for sessionID := range sessionIDs { + if uniqueIDs[sessionID] { + t.Errorf("Duplicate session ID found: %s", sessionID) + } + uniqueIDs[sessionID] = true + } + + if len(uniqueIDs) != 10 { + t.Errorf("Expected 10 unique session IDs, got %d", len(uniqueIDs)) + } +} diff --git a/pkg/router/utils.go b/pkg/router/utils.go new file mode 100644 index 00000000..c75d2951 --- /dev/null +++ b/pkg/router/utils.go @@ -0,0 +1,60 @@ +package router + +import ( + "strconv" + "time" + + "github.com/gin-gonic/gin" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + sandboxv1alpha1 "sigs.k8s.io/agent-sandbox/api/v1alpha1" +) + +// ErrorResponse represents an API error response +type ErrorResponse struct { + Error string `json:"error"` + Message string `json:"message"` + Details map[string]interface{} `json:"details,omitempty"` + Timestamp time.Time `json:"timestamp"` + RequestID string `json:"requestId,omitempty"` +} + +// respondJSON sends a JSON response +func respondJSON(c *gin.Context, statusCode int, data interface{}) { + c.JSON(statusCode, data) +} + +// respondError sends an error response +func respondError(c *gin.Context, statusCode int, errorCode, message string) { + response := ErrorResponse{ + Error: errorCode, + Message: message, + Timestamp: time.Now(), + } + respondJSON(c, statusCode, response) +} + +// getIntQueryParam gets an integer value from query parameters, returns default value if not present +func getIntQueryParam(c *gin.Context, key string, defaultValue int) int { + valueStr := c.Query(key) + if valueStr == "" { + return defaultValue + } + + value, err := strconv.Atoi(valueStr) + if err != nil { + return defaultValue + } + + return value +} + +// getSandboxStatus extracts status from Sandbox CRD conditions +func getSandboxStatus(sandbox *sandboxv1alpha1.Sandbox) string { + // Check conditions for Ready status + for _, condition := range sandbox.Status.Conditions { + if condition.Type == string(sandboxv1alpha1.SandboxConditionReady) && condition.Status == metav1.ConditionTrue { + return "running" + } + } + return "paused" +} From b4cdaac85223ed79da4f912c008de023711d605a Mon Sep 17 00:00:00 2001 From: LeslieKuo <676365950@qq.com> Date: Mon, 1 Dec 2025 21:43:10 +0800 Subject: [PATCH 2/6] session manager implementation Signed-off-by: LeslieKuo <676365950@qq.com> --- pkg/sessionmgr/manager.go | 123 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 pkg/sessionmgr/manager.go diff --git a/pkg/sessionmgr/manager.go b/pkg/sessionmgr/manager.go new file mode 100644 index 00000000..4aaf7eba --- /dev/null +++ b/pkg/sessionmgr/manager.go @@ -0,0 +1,123 @@ +package sessionmgr + +import ( + "context" + "errors" + "fmt" + + "github.com/volcano-sh/agentcube/pkg/common/types" + "github.com/volcano-sh/agentcube/pkg/redis" +) + +// Manager defines the session management behavior on top of Redis and the sandbox manager. +type Manager interface { + // GetSandboxBySession returns the sandbox associated with the given sessionID. + GetSandboxBySession(ctx context.Context, sessionID string) (*types.SandboxRedis, error) + // CreateSandbox creates a new sandbox via the sandbox manager and returns a SandboxRedis view. + CreateSandbox(ctx context.Context, req *types.CreateSandboxRequest) (*types.SandboxRedis, error) +} + +// RedisClient is the subset of the redis.Client interface used by the session manager. +// A redis.Client returned by redis.NewClient satisfies this interface. +type RedisClient interface { + GetSandboxBySessionID(ctx context.Context, sessionID string) (*types.SandboxRedis, error) +} + +// SandboxManagerClient defines the sandbox manager operations used by the session manager. +type SandboxManagerClient interface { + CreateSandbox(ctx context.Context, req *types.CreateSandboxRequest) (*types.CreateSandboxResponse, error) +} + +// manager is the default implementation of the Manager interface. +type manager struct { + redis RedisClient + sandbox SandboxManagerClient +} + +// New returns a default Manager implementation. +// Redis and sandbox manager clients are injected from the outside to make testing +// and implementation swapping easier. +func New(redisClient RedisClient, sandboxClient SandboxManagerClient) Manager { + return &manager{ + redis: redisClient, + sandbox: sandboxClient, + } +} + +// GetSandboxBySession looks up the sandbox by sessionID using Redis. +func (m *manager) GetSandboxBySession(ctx context.Context, sessionID string) (*types.SandboxRedis, error) { + if sessionID == "" { + return nil, ErrInvalidArgument + } + + // For now we do not validate the SessionID format; any non-empty string is treated as valid. + + sb, err := m.redis.GetSandboxBySessionID(ctx, sessionID) + if err != nil { + // redis.ErrNotFound is mapped to the unified ErrSessionNotFound in session manager. + if errors.Is(err, redis.ErrNotFound) { + return nil, ErrSessionNotFound + } + // Other errors are wrapped and propagated for upper layers to log and map to 5xx. + return nil, fmt.Errorf("sessionmgr: get sandbox by sessionID %q from redis failed: %w", sessionID, err) + } + if sb == nil { + return nil, fmt.Errorf("sessionmgr: get sandbox by sessionID %q returned nil sandbox", sessionID) + } + + return sb, nil +} + +// CreateSandbox creates a new sandbox via the sandbox manager using the shared CreateSandboxRequest type. +func (m *manager) CreateSandbox(ctx context.Context, req *types.CreateSandboxRequest) (*types.SandboxRedis, error) { + if req == nil { + return nil, ErrInvalidArgument + } + // Basic argument validation: creating a sandbox requires at least Kind and Namespace. + if req.Kind == "" || req.Namespace == "" { + return nil, ErrInvalidArgument + } + + cResp, err := m.sandbox.CreateSandbox(ctx, req) + if err != nil { + // Upstream network/timeouts etc. are treated as ErrUpstreamUnavailable. + // Here we roughly distinguish by whether the error is our ErrCreateSandboxFailed. + if errors.Is(err, ErrCreateSandboxFailed) { + return nil, err + } + return nil, fmt.Errorf("%w: %v", ErrUpstreamUnavailable, err) + } + + if cResp == nil || cResp.SessionID == "" || cResp.SandboxID == "" { + return nil, fmt.Errorf("%w: invalid response from sandbox manager", ErrCreateSandboxFailed) + } + + // Construct a SandboxRedis view from the response so that callers + // see a consistent sandbox object. + sb := &types.SandboxRedis{ + SandboxID: cResp.SandboxID, + SandboxName: cResp.SandboxName, + EntryPoints: cResp.Accesses, + SessionID: cResp.SessionID, + // CreatedAt / ExpiresAt / Status can be filled later when they are available. + } + + return sb, nil +} + +var ( + // ErrInvalidArgument indicates that the request arguments are invalid + // (for example, missing kind/namespace when creating a sandbox). + ErrInvalidArgument = errors.New("sessionmgr: invalid argument") + + // ErrSessionNotFound indicates that the session does not exist in redis, + // and is typically mapped to HTTP 404/410. + ErrSessionNotFound = errors.New("sessionmgr: session not found") + + // ErrUpstreamUnavailable indicates that the sandbox manager is unavailable + // (e.g. due to network errors), and is typically mapped to HTTP 503. + ErrUpstreamUnavailable = errors.New("sessionmgr: sandbox manager unavailable") + + // ErrCreateSandboxFailed indicates that the sandbox manager returned a business-level error. + ErrCreateSandboxFailed = errors.New("sessionmgr: create sandbox failed") +) From 1aef50c02875ce0237622123db55b4d6ee0f66e0 Mon Sep 17 00:00:00 2001 From: LeslieKuo <676365950@qq.com> Date: Mon, 1 Dec 2025 21:51:04 +0800 Subject: [PATCH 3/6] add UT Signed-off-by: LeslieKuo <676365950@qq.com> --- pkg/sessionmgr/manager_test.go | 335 +++++++++++++++++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 pkg/sessionmgr/manager_test.go diff --git a/pkg/sessionmgr/manager_test.go b/pkg/sessionmgr/manager_test.go new file mode 100644 index 00000000..49caff7e --- /dev/null +++ b/pkg/sessionmgr/manager_test.go @@ -0,0 +1,335 @@ +package sessionmgr + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/volcano-sh/agentcube/pkg/common/types" + "github.com/volcano-sh/agentcube/pkg/redis" +) + +// ---- fakes ---- + +type fakeRedisClient struct { + sandbox *types.SandboxRedis + err error + called bool + lastSessionID string + lastContextNil bool +} + +func (f *fakeRedisClient) GetSandboxBySessionID(ctx context.Context, sessionID string) (*types.SandboxRedis, error) { + f.called = true + f.lastSessionID = sessionID + f.lastContextNil = ctx == nil + return f.sandbox, f.err +} + +type fakeSandboxManagerClient struct { + resp *types.CreateSandboxResponse + err error + called bool + lastReq *types.CreateSandboxRequest + lastCtxNil bool + calls int +} + +func (f *fakeSandboxManagerClient) CreateSandbox(ctx context.Context, req *types.CreateSandboxRequest) (*types.CreateSandboxResponse, error) { + f.called = true + f.calls++ + f.lastReq = req + f.lastCtxNil = ctx == nil + return f.resp, f.err +} + +// ---- tests: GetSandboxBySession ---- + +func TestGetSandboxBySession_Success(t *testing.T) { + ctx := context.Background() + + sb := &types.SandboxRedis{ + SandboxID: "sandbox-1", + SandboxName: "sandbox-1", + EntryPoints: []types.SandboxAccess{ + {Endpoint: "10.0.0.1:9000"}, + }, + SessionID: "sess-1", + Status: "running", + } + + r := &fakeRedisClient{ + sandbox: sb, + } + m := New(r, &fakeSandboxManagerClient{}) + + got, err := m.GetSandboxBySession(ctx, "sess-1") + if err != nil { + t.Fatalf("GetSandboxBySession unexpected error: %v", err) + } + if !r.called { + t.Fatalf("expected RedisClient to be called") + } + if r.lastSessionID != "sess-1" { + t.Fatalf("expected RedisClient to be called with sessionID 'sess-1', got %q", r.lastSessionID) + } + if got == nil { + t.Fatalf("expected non-nil sandbox") + } + if got.SandboxID != "sandbox-1" { + t.Fatalf("unexpected SandboxID: got %q, want %q", got.SandboxID, "sandbox-1") + } +} + +func TestGetSandboxBySession_EmptySessionID(t *testing.T) { + ctx := context.Background() + m := New(&fakeRedisClient{}, &fakeSandboxManagerClient{}) + + _, err := m.GetSandboxBySession(ctx, "") + if err == nil { + t.Fatalf("expected error for empty sessionID") + } + if !errors.Is(err, ErrInvalidArgument) { + t.Fatalf("expected ErrInvalidArgument, got %v", err) + } +} + +func TestGetSandboxBySession_NotFound(t *testing.T) { + ctx := context.Background() + r := &fakeRedisClient{ + sandbox: nil, + err: redis.ErrNotFound, + } + m := New(r, &fakeSandboxManagerClient{}) + + _, err := m.GetSandboxBySession(ctx, "sess-1") + if err == nil { + t.Fatalf("expected error for not found session") + } + if !errors.Is(err, ErrSessionNotFound) { + t.Fatalf("expected ErrSessionNotFound, got %v", err) + } +} + +func TestGetSandboxBySession_OtherErrorWrapped(t *testing.T) { + ctx := context.Background() + inner := errors.New("redis boom") + r := &fakeRedisClient{ + sandbox: nil, + err: inner, + } + m := New(r, &fakeSandboxManagerClient{}) + + _, err := m.GetSandboxBySession(ctx, "sess-1") + if err == nil { + t.Fatalf("expected error") + } + // Should wrap the inner error. + if !errors.Is(err, inner) { + t.Fatalf("expected error to wrap inner error, got %v", err) + } + if !strings.Contains(err.Error(), "sessionmgr: get sandbox by sessionID") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestGetSandboxBySession_NilSandbox(t *testing.T) { + ctx := context.Background() + r := &fakeRedisClient{ + sandbox: nil, + err: nil, + } + m := New(r, &fakeSandboxManagerClient{}) + + _, err := m.GetSandboxBySession(ctx, "sess-1") + if err == nil { + t.Fatalf("expected error when redis returns nil sandbox") + } + if !strings.Contains(err.Error(), "returned nil sandbox") { + t.Fatalf("unexpected error message: %v", err) + } +} + +// ---- tests: CreateSandbox ---- + +func TestCreateSandbox_Success(t *testing.T) { + ctx := context.Background() + + req := &types.CreateSandboxRequest{ + Kind: "agent", + Name: "sandbox-name", + Namespace: "default", + } + + resp := &types.CreateSandboxResponse{ + SessionID: "sess-1", + SandboxID: "sandbox-1", + SandboxName: "sandbox-name", + Accesses: []types.SandboxAccess{ + {Endpoint: "10.0.0.1:9000"}, + }, + } + + s := &fakeSandboxManagerClient{ + resp: resp, + } + m := New(&fakeRedisClient{}, s) + + got, err := m.CreateSandbox(ctx, req) + if err != nil { + t.Fatalf("CreateSandbox unexpected error: %v", err) + } + if !s.called { + t.Fatalf("expected SandboxManagerClient to be called") + } + if got == nil { + t.Fatalf("expected non-nil sandbox") + } + if got.SandboxID != "sandbox-1" { + t.Fatalf("unexpected SandboxID: got %q, want %q", got.SandboxID, "sandbox-1") + } + if got.SessionID != "sess-1" { + t.Fatalf("unexpected SessionID: got %q, want %q", got.SessionID, "sess-1") + } + if len(got.EntryPoints) != 1 || got.EntryPoints[0].Endpoint != "10.0.0.1:9000" { + t.Fatalf("unexpected EntryPoints: %+v", got.EntryPoints) + } +} + +func TestCreateSandbox_NilRequest(t *testing.T) { + ctx := context.Background() + s := &fakeSandboxManagerClient{} + m := New(&fakeRedisClient{}, s) + + _, err := m.CreateSandbox(ctx, nil) + if err == nil { + t.Fatalf("expected error for nil request") + } + if !errors.Is(err, ErrInvalidArgument) { + t.Fatalf("expected ErrInvalidArgument, got %v", err) + } + if s.called { + t.Fatalf("expected SandboxManagerClient not to be called for nil request") + } +} + +func TestCreateSandbox_InvalidRequest_MissingKindOrNamespace(t *testing.T) { + ctx := context.Background() + s := &fakeSandboxManagerClient{} + m := New(&fakeRedisClient{}, s) + + // Missing Kind. + _, err := m.CreateSandbox(ctx, &types.CreateSandboxRequest{ + Kind: "", + Name: "name", + Namespace: "ns", + }) + if err == nil { + t.Fatalf("expected error for missing Kind") + } + if !errors.Is(err, ErrInvalidArgument) { + t.Fatalf("expected ErrInvalidArgument, got %v", err) + } + if s.called { + t.Fatalf("expected SandboxManagerClient not to be called when Kind is empty") + } + + // Reset fake. + s.called = false + s.calls = 0 + + // Missing Namespace. + _, err = m.CreateSandbox(ctx, &types.CreateSandboxRequest{ + Kind: "agent", + Name: "name", + Namespace: "", + }) + if err == nil { + t.Fatalf("expected error for missing Namespace") + } + if !errors.Is(err, ErrInvalidArgument) { + t.Fatalf("expected ErrInvalidArgument, got %v", err) + } + if s.called { + t.Fatalf("expected SandboxManagerClient not to be called when Namespace is empty") + } +} + +func TestCreateSandbox_CreateSandboxFailed(t *testing.T) { + ctx := context.Background() + s := &fakeSandboxManagerClient{ + err: ErrCreateSandboxFailed, + } + m := New(&fakeRedisClient{}, s) + + req := &types.CreateSandboxRequest{ + Kind: "agent", + Name: "sandbox-name", + Namespace: "default", + } + + _, err := m.CreateSandbox(ctx, req) + if err == nil { + t.Fatalf("expected error") + } + // ErrCreateSandboxFailed should be propagated as is. + if !errors.Is(err, ErrCreateSandboxFailed) { + t.Fatalf("expected ErrCreateSandboxFailed, got %v", err) + } +} + +func TestCreateSandbox_UpstreamUnavailableWrapped(t *testing.T) { + ctx := context.Background() + inner := errors.New("upstream timeout") + s := &fakeSandboxManagerClient{ + err: inner, + } + m := New(&fakeRedisClient{}, s) + + req := &types.CreateSandboxRequest{ + Kind: "agent", + Name: "sandbox-name", + Namespace: "default", + } + + _, err := m.CreateSandbox(ctx, req) + if err == nil { + t.Fatalf("expected error") + } + // Should be wrapped as ErrUpstreamUnavailable. + if !errors.Is(err, ErrUpstreamUnavailable) { + t.Fatalf("expected error to be ErrUpstreamUnavailable, got %v", err) + } + // Inner error is included in the message, but not wrapped with %w, + // so errors.Is on inner should be false. + if errors.Is(err, inner) { + t.Fatalf("did not expect error to wrap inner error via errors.Is") + } + if !strings.Contains(err.Error(), "upstream timeout") { + t.Fatalf("expected error message to contain inner error, got %v", err) + } +} + +func TestCreateSandbox_InvalidResponse(t *testing.T) { + ctx := context.Background() + // Response missing SessionID and SandboxID. + s := &fakeSandboxManagerClient{ + resp: &types.CreateSandboxResponse{}, + } + m := New(&fakeRedisClient{}, s) + + req := &types.CreateSandboxRequest{ + Kind: "agent", + Name: "sandbox-name", + Namespace: "default", + } + + _, err := m.CreateSandbox(ctx, req) + if err == nil { + t.Fatalf("expected error for invalid CreateSandboxResponse") + } + if !errors.Is(err, ErrCreateSandboxFailed) { + t.Fatalf("expected ErrCreateSandboxFailed, got %v", err) + } +} From b7a3fe3c8e85bf5cfc58c1447924313efd75de01 Mon Sep 17 00:00:00 2001 From: VanderChen Date: Wed, 3 Dec 2025 09:25:46 +0800 Subject: [PATCH 4/6] Integrate session manager in router Signed-off-by: VanderChen --- cmd/router/main.go | 13 +- pkg/router/apiserver.go | 52 ++++- pkg/router/apiserver_test.go | 128 +++++++---- pkg/router/handlers.go | 81 +++++-- pkg/router/redis_manager.go | 135 ------------ pkg/router/redis_manager_test.go | 250 --------------------- pkg/router/session_manager.go | 201 ++++++++++------- pkg/router/session_manager_test.go | 339 +++++++---------------------- pkg/sessionmgr/manager.go | 123 ----------- pkg/sessionmgr/manager_test.go | 335 ---------------------------- 10 files changed, 385 insertions(+), 1272 deletions(-) delete mode 100644 pkg/router/redis_manager.go delete mode 100644 pkg/router/redis_manager_test.go delete mode 100644 pkg/sessionmgr/manager.go delete mode 100644 pkg/sessionmgr/manager_test.go diff --git a/cmd/router/main.go b/cmd/router/main.go index 7e0768de..f33cb84c 100644 --- a/cmd/router/main.go +++ b/cmd/router/main.go @@ -52,13 +52,10 @@ func main() { log.Fatalf("Failed to create Router API server: %v", err) } - // Setup signal handling - ctx, cancel := context.WithCancel(context.Background()) + // Setup signal handling with context cancellation + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer cancel() - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) - // Start Router API server in goroutine errCh := make(chan error, 1) go func() { @@ -70,10 +67,10 @@ func main() { // Wait for signal or error select { - case <-sigCh: + case <-ctx.Done(): log.Println("Received shutdown signal, shutting down gracefully...") - cancel() - time.Sleep(2 * time.Second) // Give server time to shutdown gracefully + // Wait for server to finish shutting down + time.Sleep(2 * time.Second) case err := <-errCh: log.Fatalf("Server error: %v", err) } diff --git a/pkg/router/apiserver.go b/pkg/router/apiserver.go index 4df9c271..608614a5 100644 --- a/pkg/router/apiserver.go +++ b/pkg/router/apiserver.go @@ -5,9 +5,12 @@ import ( "fmt" "log" "net/http" + "os" "time" "github.com/gin-gonic/gin" + redisv9 "github.com/redis/go-redis/v9" + "github.com/volcano-sh/agentcube/pkg/redis" ) // Server is the main structure for Router apiserver @@ -16,8 +19,26 @@ type Server struct { engine *gin.Engine httpServer *http.Server sessionManager SessionManager - redisManager RedisManager - semaphore chan struct{} // For limiting concurrent requests + redisClient redis.Client + semaphore chan struct{} // For limiting concurrent requests + httpTransport *http.Transport // Reusable HTTP transport for connection pooling +} + +// makeRedisOptions creates redis options from environment variables +func makeRedisOptions() (*redisv9.Options, error) { + redisAddr := os.Getenv("REDIS_ADDR") + if redisAddr == "" { + return nil, fmt.Errorf("missing env var REDIS_ADDR") + } + redisPassword := os.Getenv("REDIS_PASSWORD") + if redisPassword == "" { + return nil, fmt.Errorf("missing env var REDIS_PASSWORD") + } + redisOptions := &redisv9.Options{ + Addr: redisAddr, + Password: redisPassword, + } + return redisOptions, nil } // NewServer creates a new Router API server instance @@ -43,11 +64,18 @@ func NewServer(config *Config) (*Server, error) { config.SessionExpireDuration = 3600 // Default 1 hour } - // Create session manager (using mock implementation) - sessionManager := NewMockSessionManager(config.SandboxEndpoints) + // Initialize Redis client + redisOptions, err := makeRedisOptions() + if err != nil { + return nil, fmt.Errorf("make redis options failed: %w", err) + } + redisClient := redis.NewClient(redisOptions) - // Create Redis manager (using mock implementation) - redisManager := NewMockRedisManager(config.EnableRedis) + // Create session manager with redis client + sessionManager, err := NewSessionManager(redisClient) + if err != nil { + return nil, fmt.Errorf("failed to create session manager: %w", err) + } // Set Gin mode based on environment if config.Debug { @@ -56,11 +84,21 @@ func NewServer(config *Config) (*Server, error) { gin.SetMode(gin.ReleaseMode) } + // Create a reusable HTTP transport for connection pooling + httpTransport := &http.Transport{ + MaxIdleConns: config.MaxIdleConns, + MaxIdleConnsPerHost: config.MaxConnsPerHost, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + ForceAttemptHTTP2: true, + } + server := &Server{ config: config, sessionManager: sessionManager, - redisManager: redisManager, + redisClient: redisClient, semaphore: make(chan struct{}, config.MaxConcurrentRequests), + httpTransport: httpTransport, } // Setup routes diff --git a/pkg/router/apiserver_test.go b/pkg/router/apiserver_test.go index 0cbf8dfe..9bc8a087 100644 --- a/pkg/router/apiserver_test.go +++ b/pkg/router/apiserver_test.go @@ -2,11 +2,22 @@ package router import ( "context" + "os" "testing" "time" ) func TestNewServer(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + tests := []struct { name string config *Config @@ -76,9 +87,9 @@ func TestNewServer(t *testing.T) { t.Error("Session manager was not created") } - // Verify redis manager was created - if server.redisManager == nil { - t.Error("Redis manager was not created") + // Verify redis client was created + if server.redisClient == nil { + t.Error("Redis client was not created") } // Verify semaphore was created with correct capacity @@ -112,6 +123,16 @@ func TestNewServer(t *testing.T) { } func TestServer_DefaultValues(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + config := &Config{ Port: "8080", // Leave other values as zero to test defaults @@ -145,6 +166,16 @@ func TestServer_DefaultValues(t *testing.T) { } func TestServer_ConcurrencyLimitMiddleware(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + config := &Config{ Port: "8080", MaxConcurrentRequests: 2, // Small limit for testing @@ -168,6 +199,16 @@ func TestServer_ConcurrencyLimitMiddleware(t *testing.T) { } func TestServer_SetupRoutes(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + config := &Config{ Port: "8080", } @@ -187,6 +228,16 @@ func TestServer_SetupRoutes(t *testing.T) { } func TestServer_StartContext(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + config := &Config{ Port: "0", // Use port 0 to let the OS assign a free port } @@ -225,6 +276,16 @@ func TestServer_StartContext(t *testing.T) { } func TestServer_TLSConfiguration(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + tests := []struct { name string config *Config @@ -294,50 +355,27 @@ func TestServer_TLSConfiguration(t *testing.T) { } func TestServer_RedisIntegration(t *testing.T) { - tests := []struct { - name string - enableRedis bool - }{ - { - name: "Redis enabled", - enableRedis: true, - }, - { - name: "Redis disabled", - enableRedis: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - config := &Config{ - Port: "8080", - EnableRedis: tt.enableRedis, - } + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() - server, err := NewServer(config) - if err != nil { - t.Fatalf("Failed to create server: %v", err) - } + config := &Config{ + Port: "8080", + } - // Verify Redis manager was created - if server.redisManager == nil { - t.Error("Redis manager was not created") - } + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } - // Test Redis manager functionality based on enabled state - err = server.redisManager.UpdateSessionActivity("test-session") - if tt.enableRedis { - if err != nil { - t.Errorf("Expected no error when Redis is enabled, got: %v", err) - } - } else { - // When disabled, UpdateSessionActivity should not return an error - // (it silently skips) - if err != nil { - t.Errorf("Expected no error when Redis is disabled, got: %v", err) - } - } - }) + // Verify Redis client was created + if server.redisClient == nil { + t.Error("Redis client was not created") } } diff --git a/pkg/router/handlers.go b/pkg/router/handlers.go index bb3aec1d..12a36bf9 100644 --- a/pkg/router/handlers.go +++ b/pkg/router/handlers.go @@ -54,7 +54,7 @@ func (s *Server) handleAgentInvoke(c *gin.Context) { sessionID := c.GetHeader("x-agentcube-session-id") // Get sandbox info from session manager - endpoint, newSessionID, err := s.sessionManager.GetSandboxInfoBySessionId(sessionID, agentNamespace, agentName, KindAgent) + sandbox, err := s.sessionManager.GetSandboxBySession(sessionID, agentNamespace, agentName, "AgentRuntime") if err != nil { log.Printf("Failed to get sandbox info: %v", err) c.JSON(http.StatusInternalServerError, gin.H{ @@ -64,15 +64,37 @@ func (s *Server) handleAgentInvoke(c *gin.Context) { return } + // Extract endpoint from sandbox - find matching entry point by path + var endpoint string + for _, ep := range sandbox.EntryPoints { + if ep.Path == path || ep.Path == "" { + endpoint = ep.Endpoint + break + } + } + + // If no matching endpoint found, use the first one as fallback + if endpoint == "" { + if len(sandbox.EntryPoints) == 0 { + log.Printf("No entry points found for sandbox: %s", sandbox.SandboxID) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "internal server error", + "code": "INTERNAL_ERROR", + }) + return + } + endpoint = sandbox.EntryPoints[0].Endpoint + } + // Update session activity in Redis when receiving request - if newSessionID != "" { - if err := s.redisManager.UpdateSessionActivity(newSessionID); err != nil { - log.Printf("Failed to update session activity for request: %v", err) + if sandbox.SessionID != "" && sandbox.SandboxID != "" { + if err := s.redisClient.UpdateSandboxLastActivity(c.Request.Context(), sandbox.SandboxID, time.Now()); err != nil { + log.Printf("Failed to update sandbox last activity for request: %v", err) } } // Forward request to sandbox with session ID - s.forwardToSandbox(c, endpoint, path, newSessionID) + s.forwardToSandbox(c, endpoint, path, sandbox.SessionID) } // handleCodeInterpreterInvoke handles code interpreter invocation requests @@ -87,7 +109,7 @@ func (s *Server) handleCodeInterpreterInvoke(c *gin.Context) { sessionID := c.GetHeader("x-agentcube-session-id") // Get sandbox info from session manager - endpoint, newSessionID, err := s.sessionManager.GetSandboxInfoBySessionId(sessionID, namespace, name, KindCodeInterpreter) + sandbox, err := s.sessionManager.GetSandboxBySession(sessionID, namespace, name, "CodeInterpreter") if err != nil { log.Printf("Failed to get sandbox info: %v", err) c.JSON(http.StatusInternalServerError, gin.H{ @@ -97,15 +119,37 @@ func (s *Server) handleCodeInterpreterInvoke(c *gin.Context) { return } + // Extract endpoint from sandbox - find matching entry point by path + var endpoint string + for _, ep := range sandbox.EntryPoints { + if ep.Path == path || ep.Path == "" { + endpoint = ep.Endpoint + break + } + } + + // If no matching endpoint found, use the first one as fallback + if endpoint == "" { + if len(sandbox.EntryPoints) == 0 { + log.Printf("No entry points found for sandbox: %s", sandbox.SandboxID) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "internal server error", + "code": "INTERNAL_ERROR", + }) + return + } + endpoint = sandbox.EntryPoints[0].Endpoint + } + // Update session activity in Redis when receiving request - if newSessionID != "" { - if err := s.redisManager.UpdateSessionActivity(newSessionID); err != nil { - log.Printf("Failed to update session activity for request: %v", err) + if sandbox.SessionID != "" && sandbox.SandboxID != "" { + if err := s.redisClient.UpdateSandboxLastActivity(c.Request.Context(), sandbox.SandboxID, time.Now()); err != nil { + log.Printf("Failed to update sandbox last activity for request: %v", err) } } // Forward request to sandbox with session ID - s.forwardToSandbox(c, endpoint, path, newSessionID) + s.forwardToSandbox(c, endpoint, path, sandbox.SessionID) } // forwardToSandbox forwards the request to the specified sandbox endpoint @@ -121,17 +165,11 @@ func (s *Server) forwardToSandbox(c *gin.Context, endpoint, path, sessionID stri return } - // Create reverse proxy with optimized transport + // Create reverse proxy with reusable transport proxy := httputil.NewSingleHostReverseProxy(targetURL) - // Configure HTTP transport for better concurrency - proxy.Transport = &http.Transport{ - MaxIdleConns: s.config.MaxIdleConns, - MaxIdleConnsPerHost: s.config.MaxConnsPerHost, - IdleConnTimeout: 90 * time.Second, - DisableCompression: false, - ForceAttemptHTTP2: true, - } + // Use the shared HTTP transport for connection pooling + proxy.Transport = s.httpTransport // Customize the director to modify the request originalDirector := proxy.Director @@ -186,11 +224,6 @@ func (s *Server) forwardToSandbox(c *gin.Context, endpoint, path, sessionID stri // Always set session ID in response header if sessionID != "" { resp.Header.Set("x-agentcube-session-id", sessionID) - - // Update session activity in Redis when returning response - if err := s.redisManager.UpdateSessionActivity(sessionID); err != nil { - log.Printf("Failed to update session activity for response: %v", err) - } } return nil } diff --git a/pkg/router/redis_manager.go b/pkg/router/redis_manager.go deleted file mode 100644 index cc90d9c3..00000000 --- a/pkg/router/redis_manager.go +++ /dev/null @@ -1,135 +0,0 @@ -package router - -import ( - "fmt" - "log" - "sync" - "time" -) - -// RedisManager interface for managing session activity in Redis -type RedisManager interface { - // UpdateSessionActivity updates the lastActive time for a session ID - UpdateSessionActivity(sessionID string) error - - // GetSessionLastActive gets the last active time for a session ID - GetSessionLastActive(sessionID string) (time.Time, error) - - // CleanupExpiredSessions removes sessions that haven't been active for a specified duration - CleanupExpiredSessions(expireDuration time.Duration) error -} - -// MockRedisManager is a mock implementation for testing -type MockRedisManager struct { - mu sync.RWMutex - sessions map[string]time.Time // sessionID -> lastActive time - enabled bool -} - -// NewMockRedisManager creates a new mock Redis manager -func NewMockRedisManager(enabled bool) *MockRedisManager { - return &MockRedisManager{ - sessions: make(map[string]time.Time), - enabled: enabled, - } -} - -// UpdateSessionActivity implements RedisManager interface -func (r *MockRedisManager) UpdateSessionActivity(sessionID string) error { - if !r.enabled { - return nil // Silently skip if disabled - } - - if sessionID == "" { - return fmt.Errorf("session ID cannot be empty") - } - - r.mu.Lock() - defer r.mu.Unlock() - - now := time.Now() - r.sessions[sessionID] = now - - log.Printf("Updated session activity for session %s at %s", sessionID, now.Format(time.RFC3339)) - return nil -} - -// GetSessionLastActive implements RedisManager interface -func (r *MockRedisManager) GetSessionLastActive(sessionID string) (time.Time, error) { - if !r.enabled { - return time.Time{}, fmt.Errorf("Redis manager is disabled") - } - - if sessionID == "" { - return time.Time{}, fmt.Errorf("session ID cannot be empty") - } - - r.mu.RLock() - defer r.mu.RUnlock() - - lastActive, exists := r.sessions[sessionID] - if !exists { - return time.Time{}, fmt.Errorf("session %s not found", sessionID) - } - - return lastActive, nil -} - -// CleanupExpiredSessions implements RedisManager interface -func (r *MockRedisManager) CleanupExpiredSessions(expireDuration time.Duration) error { - if !r.enabled { - return nil // Silently skip if disabled - } - - r.mu.Lock() - defer r.mu.Unlock() - - now := time.Now() - expiredSessions := make([]string, 0) - - for sessionID, lastActive := range r.sessions { - if now.Sub(lastActive) > expireDuration { - expiredSessions = append(expiredSessions, sessionID) - } - } - - // Remove expired sessions - for _, sessionID := range expiredSessions { - delete(r.sessions, sessionID) - log.Printf("Cleaned up expired session: %s", sessionID) - } - - if len(expiredSessions) > 0 { - log.Printf("Cleaned up %d expired sessions", len(expiredSessions)) - } - - return nil -} - -// GetActiveSessionCount returns the number of active sessions -func (r *MockRedisManager) GetActiveSessionCount() int { - if !r.enabled { - return 0 - } - - r.mu.RLock() - defer r.mu.RUnlock() - return len(r.sessions) -} - -// GetAllActiveSessions returns all active session IDs (for debugging) -func (r *MockRedisManager) GetAllActiveSessions() []string { - if !r.enabled { - return nil - } - - r.mu.RLock() - defer r.mu.RUnlock() - - sessions := make([]string, 0, len(r.sessions)) - for sessionID := range r.sessions { - sessions = append(sessions, sessionID) - } - - return sessions -} diff --git a/pkg/router/redis_manager_test.go b/pkg/router/redis_manager_test.go deleted file mode 100644 index 502a462b..00000000 --- a/pkg/router/redis_manager_test.go +++ /dev/null @@ -1,250 +0,0 @@ -package router - -import ( - "testing" - "time" -) - -func TestMockRedisManager_UpdateSessionActivity(t *testing.T) { - tests := []struct { - name string - enabled bool - sessionID string - wantErr bool - }{ - { - name: "valid session ID with enabled Redis", - enabled: true, - sessionID: "test-session-123", - wantErr: false, - }, - { - name: "empty session ID with enabled Redis", - enabled: true, - sessionID: "", - wantErr: true, - }, - { - name: "valid session ID with disabled Redis", - enabled: false, - sessionID: "test-session-123", - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := NewMockRedisManager(tt.enabled) - err := r.UpdateSessionActivity(tt.sessionID) - - if (err != nil) != tt.wantErr { - t.Errorf("UpdateSessionActivity() error = %v, wantErr %v", err, tt.wantErr) - return - } - - // If enabled and no error, check if session was stored by checking count - if tt.enabled && !tt.wantErr { - if count := r.GetActiveSessionCount(); count != 1 { - t.Errorf("Expected 1 session to be stored, got %d", count) - } - } - }) - } -} - -func TestMockRedisManager_GetSessionLastActive(t *testing.T) { - r := NewMockRedisManager(true) - sessionID := "test-session-123" - - // Test getting non-existent session - _, err := r.GetSessionLastActive(sessionID) - if err == nil { - t.Error("Expected error for non-existent session") - } - - // Add session and test retrieval - beforeTime := time.Now() - err = r.UpdateSessionActivity(sessionID) - if err != nil { - t.Fatalf("Failed to update session activity: %v", err) - } - afterTime := time.Now() - - lastActive, err := r.GetSessionLastActive(sessionID) - if err != nil { - t.Fatalf("Failed to get session last active: %v", err) - } - - if lastActive.Before(beforeTime) || lastActive.After(afterTime) { - t.Errorf("Last active time %v is not within expected range [%v, %v]", lastActive, beforeTime, afterTime) - } - - // Test with disabled Redis - rDisabled := NewMockRedisManager(false) - _, err = rDisabled.GetSessionLastActive(sessionID) - if err == nil { - t.Error("Expected error when Redis is disabled") - } -} - -func TestMockRedisManager_CleanupExpiredSessions(t *testing.T) { - r := NewMockRedisManager(true) - - // Add some sessions - oldSessionID := "old-session" - newSessionID := "new-session" - - // Add sessions - the old one will be cleaned up based on time - err := r.UpdateSessionActivity(oldSessionID) - if err != nil { - t.Fatalf("Failed to update old session activity: %v", err) - } - - // Manually set old session timestamp by accessing the struct field - r.sessions[oldSessionID] = time.Now().Add(-2 * time.Hour) - - // Add new session - err = r.UpdateSessionActivity(newSessionID) - if err != nil { - t.Fatalf("Failed to update new session activity: %v", err) - } - - // Get initial count - initialCount := r.GetActiveSessionCount() - if initialCount != 2 { - t.Errorf("Expected 2 initial sessions, got %d", initialCount) - } - - // Cleanup sessions older than 1 hour - err = r.CleanupExpiredSessions(1 * time.Hour) - if err != nil { - t.Fatalf("Failed to cleanup expired sessions: %v", err) - } - - // Check that count decreased - finalCount := r.GetActiveSessionCount() - if finalCount != 1 { - t.Errorf("Expected 1 session after cleanup, got %d", finalCount) - } - - // Test with disabled Redis - rDisabled := NewMockRedisManager(false) - err = rDisabled.CleanupExpiredSessions(1 * time.Hour) - if err != nil { - t.Errorf("Cleanup should not fail when Redis is disabled: %v", err) - } -} - -func TestMockRedisManager_GetActiveSessionCount(t *testing.T) { - r := NewMockRedisManager(true) - - // Initially should be 0 - if count := r.GetActiveSessionCount(); count != 0 { - t.Errorf("Expected 0 active sessions, got %d", count) - } - - // Add some sessions - sessions := []string{"session1", "session2", "session3"} - for _, sessionID := range sessions { - err := r.UpdateSessionActivity(sessionID) - if err != nil { - t.Fatalf("Failed to update session activity: %v", err) - } - } - - // Should have 3 sessions - if count := r.GetActiveSessionCount(); count != 3 { - t.Errorf("Expected 3 active sessions, got %d", count) - } - - // Test with disabled Redis - rDisabled := NewMockRedisManager(false) - if count := rDisabled.GetActiveSessionCount(); count != 0 { - t.Errorf("Expected 0 active sessions when disabled, got %d", count) - } -} - -func TestMockRedisManager_GetAllActiveSessions(t *testing.T) { - r := NewMockRedisManager(true) - - // Initially should be empty - if sessions := r.GetAllActiveSessions(); len(sessions) != 0 { - t.Errorf("Expected 0 active sessions, got %d", len(sessions)) - } - - // Add some sessions - expectedSessions := []string{"session1", "session2", "session3"} - for _, sessionID := range expectedSessions { - err := r.UpdateSessionActivity(sessionID) - if err != nil { - t.Fatalf("Failed to update session activity: %v", err) - } - } - - // Get all sessions - activeSessions := r.GetAllActiveSessions() - if len(activeSessions) != len(expectedSessions) { - t.Errorf("Expected %d active sessions, got %d", len(expectedSessions), len(activeSessions)) - } - - // Check that all expected sessions are present - sessionMap := make(map[string]bool) - for _, session := range activeSessions { - sessionMap[session] = true - } - - for _, expected := range expectedSessions { - if !sessionMap[expected] { - t.Errorf("Expected session %s not found in active sessions", expected) - } - } - - // Test with disabled Redis - rDisabled := NewMockRedisManager(false) - if sessions := rDisabled.GetAllActiveSessions(); sessions != nil { - t.Error("Expected nil when Redis is disabled") - } -} - -func TestMockRedisManager_ConcurrentAccess(t *testing.T) { - r := NewMockRedisManager(true) - sessionID := "concurrent-session" - - // Test concurrent updates - done := make(chan bool, 10) - for i := 0; i < 10; i++ { - go func() { - err := r.UpdateSessionActivity(sessionID) - if err != nil { - t.Errorf("Concurrent update failed: %v", err) - } - done <- true - }() - } - - // Wait for all goroutines to complete - for i := 0; i < 10; i++ { - <-done - } - - // Check that session exists by checking count - if count := r.GetActiveSessionCount(); count != 1 { - t.Errorf("Expected 1 session after concurrent updates, got %d", count) - } - - // Test concurrent reads - for i := 0; i < 10; i++ { - go func() { - _, err := r.GetSessionLastActive(sessionID) - if err != nil { - t.Errorf("Concurrent read failed: %v", err) - } - done <- true - }() - } - - // Wait for all read goroutines to complete - for i := 0; i < 10; i++ { - <-done - } -} diff --git a/pkg/router/session_manager.go b/pkg/router/session_manager.go index 08d1e196..a273ec29 100644 --- a/pkg/router/session_manager.go +++ b/pkg/router/session_manager.go @@ -1,116 +1,155 @@ package router import ( + "bytes" + "context" + "encoding/json" + "errors" "fmt" - "sync" + "io" + "net/http" + "os" + "time" - "github.com/google/uuid" + "github.com/volcano-sh/agentcube/pkg/common/types" + "github.com/volcano-sh/agentcube/pkg/redis" ) -// Kind constants for sandbox types -const ( - KindAgent = "agent" - KindCodeInterpreter = "codeinterpreter" -) - -// SessionManager interface for managing sandbox sessions +// SessionManager defines the session management behavior on top of Redis and the workload manager. type SessionManager interface { - // GetSandboxInfoBySessionId returns sandbox endpoint and session ID - // When sessionId is empty, creates a new session - // kind can be "agent" or "codeinterpreter" - GetSandboxInfoBySessionId(sessionId, namespace, name, kind string) (endpoint string, newSessionId string, err error) + // GetSandboxBySession returns the sandbox associated with the given sessionID. + // When sessionID is empty, it creates a new sandbox by calling the external API. + // When sessionID is not empty, it queries Redis for the sandbox. + GetSandboxBySession(sessionID string, namespace string, name string, kind string) (*types.SandboxRedis, error) } -// MockSessionManager is a simple implementation for testing -type MockSessionManager struct { - mu sync.RWMutex - sandboxEndpoints []string - currentIndex int - sessions map[string]string // sessionId -> endpoint +// manager is the default implementation of the SessionManager interface. +type manager struct { + redisClient redis.Client + workloadMgrURL string + httpClient *http.Client } -// NewMockSessionManager creates a new mock session manager -func NewMockSessionManager(sandboxEndpoints []string) *MockSessionManager { - if len(sandboxEndpoints) == 0 { - // Default sandbox endpoints for testing - sandboxEndpoints = []string{ - "http://sandbox-1:8080", - "http://sandbox-2:8080", - "http://sandbox-3:8080", - } +// NewSessionManager returns a SessionManager implementation. +// redisClient is used to query sandbox information from Redis. +// workloadMgrURL is read from the environment variable WORKLOAD_MGR_URL. +func NewSessionManager(redisClient redis.Client) (SessionManager, error) { + workloadMgrURL := os.Getenv("WORKLOAD_MGR_URL") + if workloadMgrURL == "" { + return nil, fmt.Errorf("WORKLOAD_MGR_URL environment variable is not set") } - return &MockSessionManager{ - sandboxEndpoints: sandboxEndpoints, - sessions: make(map[string]string), + return &manager{ + redisClient: redisClient, + workloadMgrURL: workloadMgrURL, + httpClient: &http.Client{ + Timeout: 30 * time.Second, // Set a reasonable timeout to prevent hanging + }, + }, nil +} + +// GetSandboxBySession returns the sandbox associated with the given sessionID. +// When sessionID is empty, it creates a new sandbox by calling the external API. +// When sessionID is not empty, it queries Redis for the sandbox. +func (m *manager) GetSandboxBySession(sessionID string, namespace string, name string, kind string) (*types.SandboxRedis, error) { + ctx := context.Background() + + // When sessionID is empty, create a new sandbox + if sessionID == "" { + return m.createSandbox(ctx, namespace, name, kind) } + + // When sessionID is not empty, query Redis + sandbox, err := m.redisClient.GetSandboxBySessionID(ctx, sessionID) + if err != nil { + if errors.Is(err, redis.ErrNotFound) { + return nil, ErrSessionNotFound + } + return nil, fmt.Errorf("failed to get sandbox from redis: %w", err) + } + + return sandbox, nil } -// GetSandboxInfoBySessionId implements SessionManager interface -func (m *MockSessionManager) GetSandboxInfoBySessionId(sessionId, namespace, name, kind string) (string, string, error) { - m.mu.Lock() - defer m.mu.Unlock() +// createSandbox creates a new sandbox by calling the external workload manager API. +func (m *manager) createSandbox(ctx context.Context, namespace string, name string, kind string) (*types.SandboxRedis, error) { + // Determine the API endpoint based on kind + var endpoint string + switch kind { + case types.AgentRuntimeKind: + endpoint = m.workloadMgrURL + "/v1/agent-runtime" + case types.CodeInterpreterKind: + endpoint = m.workloadMgrURL + "/v1/code-interpreter" + default: + return nil, fmt.Errorf("unsupported kind: %s", kind) + } + + // Prepare the request body + reqBody := &types.CreateSandboxRequest{ + Kind: kind, + Name: name, + Namespace: namespace, + } - // Validate kind parameter - if kind != "" && kind != KindAgent && kind != KindCodeInterpreter { - return "", "", fmt.Errorf("invalid kind: %s, must be '%s' or '%s'", kind, KindAgent, KindCodeInterpreter) + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) } - // If sessionId is empty, create a new session - if sessionId == "" { - sessionId = m.generateNewSessionId() + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) } + req.Header.Set("Content-Type", "application/json") - // Check if session already exists - if endpoint, exists := m.sessions[sessionId]; exists { - return endpoint, sessionId, nil + // Send the request + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrUpstreamUnavailable, err) } + defer resp.Body.Close() - // Create new session with round-robin endpoint selection - if len(m.sandboxEndpoints) == 0 { - return "", "", fmt.Errorf("no sandbox endpoints available") + // Read response body + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) } - // Select endpoint based on kind if specified - var endpoint string - if kind == KindAgent { - // For agent kind, prefer agent-specific endpoints or use round-robin - endpoint = m.sandboxEndpoints[m.currentIndex%len(m.sandboxEndpoints)] - } else if kind == KindCodeInterpreter { - // For codeinterpreter kind, prefer code interpreter endpoints or use round-robin - endpoint = m.sandboxEndpoints[m.currentIndex%len(m.sandboxEndpoints)] - } else { - // Default behavior for backward compatibility - endpoint = m.sandboxEndpoints[m.currentIndex%len(m.sandboxEndpoints)] + // Check response status + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w: status code %d, body: %s", ErrCreateSandboxFailed, resp.StatusCode, string(respBody)) } - m.currentIndex++ + // Parse response + var createResp types.CreateSandboxResponse + if err := json.Unmarshal(respBody, &createResp); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } - // Store the session with additional metadata - sessionKey := sessionId - if namespace != "" && name != "" { - sessionKey = fmt.Sprintf("%s/%s/%s", namespace, name, sessionId) + // Validate response + if createResp.SessionID == "" || createResp.SandboxID == "" { + return nil, fmt.Errorf("%w: invalid response from workload manager", ErrCreateSandboxFailed) } - m.sessions[sessionKey] = endpoint - return endpoint, sessionId, nil -} + // Construct SandboxRedis from response + sandbox := &types.SandboxRedis{ + SandboxID: createResp.SandboxID, + SandboxName: createResp.SandboxName, + SessionID: createResp.SessionID, + EntryPoints: createResp.EntryPoints, + } -// generateNewSessionId generates a new UUID-based session ID -func (m *MockSessionManager) generateNewSessionId() string { - return uuid.New().String() + return sandbox, nil } -// RemoveSession removes a session (for cleanup) -func (m *MockSessionManager) RemoveSession(sessionId string) { - m.mu.Lock() - defer m.mu.Unlock() - delete(m.sessions, sessionId) -} +var ( + // ErrSessionNotFound indicates that the session does not exist in redis. + ErrSessionNotFound = errors.New("sessionmgr: session not found") -// GetSessionCount returns the number of active sessions -func (m *MockSessionManager) GetSessionCount() int { - m.mu.RLock() - defer m.mu.RUnlock() - return len(m.sessions) -} + // ErrUpstreamUnavailable indicates that the workload manager is unavailable. + ErrUpstreamUnavailable = errors.New("sessionmgr: workload manager unavailable") + + // ErrCreateSandboxFailed indicates that the workload manager returned an error. + ErrCreateSandboxFailed = errors.New("sessionmgr: create sandbox failed") +) diff --git a/pkg/router/session_manager_test.go b/pkg/router/session_manager_test.go index 9d328df6..c9f41c79 100644 --- a/pkg/router/session_manager_test.go +++ b/pkg/router/session_manager_test.go @@ -1,305 +1,116 @@ package router import ( - "fmt" + "context" + "errors" "testing" -) - -func TestMockSessionManager_GetSandboxInfoBySessionId(t *testing.T) { - endpoints := []string{ - "http://sandbox-1:8080", - "http://sandbox-2:8080", - "http://sandbox-3:8080", - } - sm := NewMockSessionManager(endpoints) - - tests := []struct { - name string - sessionID string - namespace string - agentName string - kind string - wantErr bool - }{ - { - name: "new session with agent kind", - sessionID: "", - namespace: "default", - agentName: "test-agent", - kind: KindAgent, - wantErr: false, - }, - { - name: "new session with code interpreter kind", - sessionID: "", - namespace: "default", - agentName: "test-interpreter", - kind: KindCodeInterpreter, - wantErr: false, - }, - { - name: "existing session", - sessionID: "existing-session-123", - namespace: "default", - agentName: "test-agent", - kind: KindAgent, - wantErr: false, - }, - { - name: "invalid kind", - sessionID: "", - namespace: "default", - agentName: "test-agent", - kind: "invalid-kind", - wantErr: true, - }, - { - name: "empty kind (should work for backward compatibility)", - sessionID: "", - namespace: "default", - agentName: "test-agent", - kind: "", - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - endpoint, sessionID, err := sm.GetSandboxInfoBySessionId(tt.sessionID, tt.namespace, tt.agentName, tt.kind) - - if (err != nil) != tt.wantErr { - t.Errorf("GetSandboxInfoBySessionId() error = %v, wantErr %v", err, tt.wantErr) - return - } + "time" - if !tt.wantErr { - // Check that endpoint is not empty - if endpoint == "" { - t.Error("Expected non-empty endpoint") - } - - // Check that sessionID is not empty - if sessionID == "" { - t.Error("Expected non-empty session ID") - } + "github.com/volcano-sh/agentcube/pkg/common/types" + "github.com/volcano-sh/agentcube/pkg/redis" +) - // Check that endpoint is one of the configured endpoints - found := false - for _, ep := range endpoints { - if endpoint == ep { - found = true - break - } - } - if !found { - t.Errorf("Endpoint %s not found in configured endpoints", endpoint) - } +// ---- fakes ---- - // If we provided a session ID, it should be returned unchanged - if tt.sessionID != "" && sessionID != tt.sessionID { - t.Errorf("Expected session ID %s, got %s", tt.sessionID, sessionID) - } - } - }) - } +type fakeRedisClient struct { + sandbox *types.SandboxRedis + err error + called bool + lastSessionID string + lastContextNil bool } -func TestMockSessionManager_SessionPersistence(t *testing.T) { - endpoints := []string{"http://sandbox-1:8080"} - sm := NewMockSessionManager(endpoints) - - // Create a new session - endpoint1, sessionID, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) - if err != nil { - t.Fatalf("Failed to create new session: %v", err) - } - - // Use the same session ID again - endpoint2, sessionID2, err := sm.GetSandboxInfoBySessionId(sessionID, "default", "test-agent", KindAgent) - if err != nil { - t.Fatalf("Failed to get existing session: %v", err) - } - - // Should return the same endpoint and session ID - if endpoint1 != endpoint2 { - t.Errorf("Expected same endpoint for existing session, got %s vs %s", endpoint1, endpoint2) - } - - if sessionID != sessionID2 { - t.Errorf("Expected same session ID, got %s vs %s", sessionID, sessionID2) - } +func (f *fakeRedisClient) GetSandboxBySessionID(ctx context.Context, sessionID string) (*types.SandboxRedis, error) { + f.called = true + f.lastSessionID = sessionID + f.lastContextNil = ctx == nil + return f.sandbox, f.err } -func TestMockSessionManager_RoundRobinDistribution(t *testing.T) { - endpoints := []string{ - "http://sandbox-1:8080", - "http://sandbox-2:8080", - "http://sandbox-3:8080", - } - sm := NewMockSessionManager(endpoints) - - // Create multiple sessions and track endpoint distribution - endpointCount := make(map[string]int) - numSessions := 9 // Multiple of 3 to test round-robin - - for i := 0; i < numSessions; i++ { - endpoint, _, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) - if err != nil { - t.Fatalf("Failed to create session %d: %v", i, err) - } - endpointCount[endpoint]++ - } - - // Each endpoint should be used equally - expectedCount := numSessions / len(endpoints) - for _, endpoint := range endpoints { - if count := endpointCount[endpoint]; count != expectedCount { - t.Errorf("Expected endpoint %s to be used %d times, got %d", endpoint, expectedCount, count) - } - } +func (f *fakeRedisClient) SetSessionLockIfAbsent(ctx context.Context, sessionID string, ttl time.Duration) (bool, error) { + return false, nil } -func TestMockSessionManager_EmptyEndpoints(t *testing.T) { - // Test with empty endpoints (should use defaults) - sm := NewMockSessionManager([]string{}) - - endpoint, sessionID, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) - if err != nil { - t.Fatalf("Failed to create session with default endpoints: %v", err) - } - - if endpoint == "" { - t.Error("Expected non-empty endpoint with default configuration") - } - - if sessionID == "" { - t.Error("Expected non-empty session ID") - } +func (f *fakeRedisClient) BindSessionWithSandbox(ctx context.Context, sessionID string, sandboxRedis *types.SandboxRedis, ttl time.Duration) error { + return nil } -func TestMockSessionManager_SessionKeyGeneration(t *testing.T) { - endpoints := []string{"http://sandbox-1:8080"} - sm := NewMockSessionManager(endpoints) +func (f *fakeRedisClient) DeleteSessionBySandboxIDTx(ctx context.Context, sandboxID string) error { + return nil +} - // Test session with namespace and name - endpoint1, sessionID1, err := sm.GetSandboxInfoBySessionId("", "namespace1", "agent1", KindAgent) - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } +func (f *fakeRedisClient) StoreSandbox(ctx context.Context, sandboxRedis *types.SandboxRedis, ttl time.Duration) error { + return nil +} - // Test session with different namespace but same name - _, sessionID2, err := sm.GetSandboxInfoBySessionId("", "namespace2", "agent1", KindAgent) - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } +func (f *fakeRedisClient) Ping(ctx context.Context) error { + return nil +} - // Should create different sessions - if sessionID1 == sessionID2 { - t.Error("Expected different session IDs for different namespaces") - } +func (f *fakeRedisClient) ListExpiredSandboxes(ctx context.Context, before time.Time, limit int64) ([]*types.SandboxRedis, error) { + return nil, nil +} - // Test reusing first session - endpoint3, sessionID3, err := sm.GetSandboxInfoBySessionId(sessionID1, "namespace1", "agent1", KindAgent) - if err != nil { - t.Fatalf("Failed to reuse session: %v", err) - } +func (f *fakeRedisClient) ListInactiveSandboxes(ctx context.Context, before time.Time, limit int64) ([]*types.SandboxRedis, error) { + return nil, nil +} - if endpoint1 != endpoint3 || sessionID1 != sessionID3 { - t.Error("Expected same endpoint and session ID when reusing session") - } +func (f *fakeRedisClient) UpdateSandboxLastActivity(ctx context.Context, sandboxID string, at time.Time) error { + return nil } -func TestMockSessionManager_GetSessionCount(t *testing.T) { - endpoints := []string{"http://sandbox-1:8080"} - sm := NewMockSessionManager(endpoints) +// ---- tests: GetSandboxBySession ---- - // Initially should have 0 sessions - if count := sm.GetSessionCount(); count != 0 { - t.Errorf("Expected 0 initial sessions, got %d", count) +func TestGetSandboxBySession_Success(t *testing.T) { + sb := &types.SandboxRedis{ + SandboxID: "sandbox-1", + SandboxName: "sandbox-1", + EntryPoints: []types.SandboxEntryPoints{ + {Endpoint: "10.0.0.1:9000"}, + }, + SessionID: "sess-1", + Status: "running", } - // Create some sessions - numSessions := 3 - for i := 0; i < numSessions; i++ { - _, _, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) - if err != nil { - t.Fatalf("Failed to create session %d: %v", i, err) - } + r := &fakeRedisClient{ + sandbox: sb, } - - // Should have the expected number of sessions - if count := sm.GetSessionCount(); count != numSessions { - t.Errorf("Expected %d sessions, got %d", numSessions, count) + m := &manager{ + redisClient: r, } -} -func TestMockSessionManager_RemoveSession(t *testing.T) { - endpoints := []string{"http://sandbox-1:8080"} - sm := NewMockSessionManager(endpoints) - - // Create a session - _, sessionID, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) + got, err := m.GetSandboxBySession("sess-1", "default", "test", "AgentRuntime") if err != nil { - t.Fatalf("Failed to create session: %v", err) + t.Fatalf("GetSandboxBySession unexpected error: %v", err) } - - // Verify session exists - if count := sm.GetSessionCount(); count != 1 { - t.Errorf("Expected 1 session, got %d", count) + if !r.called { + t.Fatalf("expected RedisClient to be called") } - - // Remove the session using the composite key (namespace/name/sessionId) - compositeKey := fmt.Sprintf("default/test-agent/%s", sessionID) - sm.RemoveSession(compositeKey) - - // Verify session was removed - if count := sm.GetSessionCount(); count != 0 { - t.Errorf("Expected 0 sessions after removal, got %d", count) + if r.lastSessionID != "sess-1" { + t.Fatalf("expected RedisClient to be called with sessionID 'sess-1', got %q", r.lastSessionID) } - - // Test removing session without namespace/name (should not affect anything) - sm.RemoveSession(sessionID) - if count := sm.GetSessionCount(); count != 0 { - t.Errorf("Expected 0 sessions after second removal attempt, got %d", count) + if got == nil { + t.Fatalf("expected non-nil sandbox") + } + if got.SandboxID != "sandbox-1" { + t.Fatalf("unexpected SandboxID: got %q, want %q", got.SandboxID, "sandbox-1") } } -func TestMockSessionManager_ConcurrentAccess(t *testing.T) { - endpoints := []string{"http://sandbox-1:8080", "http://sandbox-2:8080"} - sm := NewMockSessionManager(endpoints) - - // Test concurrent session creation - done := make(chan bool, 10) - sessionIDs := make(chan string, 10) - - for i := 0; i < 10; i++ { - go func(index int) { - _, sessionID, err := sm.GetSandboxInfoBySessionId("", "default", "test-agent", KindAgent) - if err != nil { - t.Errorf("Concurrent session creation failed: %v", err) - } else { - sessionIDs <- sessionID - } - done <- true - }(i) +func TestGetSandboxBySession_NotFound(t *testing.T) { + r := &fakeRedisClient{ + sandbox: nil, + err: redis.ErrNotFound, } - - // Wait for all goroutines to complete - for i := 0; i < 10; i++ { - <-done + m := &manager{ + redisClient: r, } - close(sessionIDs) - // Verify all sessions were created with unique IDs - uniqueIDs := make(map[string]bool) - for sessionID := range sessionIDs { - if uniqueIDs[sessionID] { - t.Errorf("Duplicate session ID found: %s", sessionID) - } - uniqueIDs[sessionID] = true + _, err := m.GetSandboxBySession("sess-1", "default", "test", "AgentRuntime") + if err == nil { + t.Fatalf("expected error for not found session") } - - if len(uniqueIDs) != 10 { - t.Errorf("Expected 10 unique session IDs, got %d", len(uniqueIDs)) + if !errors.Is(err, ErrSessionNotFound) { + t.Fatalf("expected ErrSessionNotFound, got %v", err) } } diff --git a/pkg/sessionmgr/manager.go b/pkg/sessionmgr/manager.go deleted file mode 100644 index 4aaf7eba..00000000 --- a/pkg/sessionmgr/manager.go +++ /dev/null @@ -1,123 +0,0 @@ -package sessionmgr - -import ( - "context" - "errors" - "fmt" - - "github.com/volcano-sh/agentcube/pkg/common/types" - "github.com/volcano-sh/agentcube/pkg/redis" -) - -// Manager defines the session management behavior on top of Redis and the sandbox manager. -type Manager interface { - // GetSandboxBySession returns the sandbox associated with the given sessionID. - GetSandboxBySession(ctx context.Context, sessionID string) (*types.SandboxRedis, error) - // CreateSandbox creates a new sandbox via the sandbox manager and returns a SandboxRedis view. - CreateSandbox(ctx context.Context, req *types.CreateSandboxRequest) (*types.SandboxRedis, error) -} - -// RedisClient is the subset of the redis.Client interface used by the session manager. -// A redis.Client returned by redis.NewClient satisfies this interface. -type RedisClient interface { - GetSandboxBySessionID(ctx context.Context, sessionID string) (*types.SandboxRedis, error) -} - -// SandboxManagerClient defines the sandbox manager operations used by the session manager. -type SandboxManagerClient interface { - CreateSandbox(ctx context.Context, req *types.CreateSandboxRequest) (*types.CreateSandboxResponse, error) -} - -// manager is the default implementation of the Manager interface. -type manager struct { - redis RedisClient - sandbox SandboxManagerClient -} - -// New returns a default Manager implementation. -// Redis and sandbox manager clients are injected from the outside to make testing -// and implementation swapping easier. -func New(redisClient RedisClient, sandboxClient SandboxManagerClient) Manager { - return &manager{ - redis: redisClient, - sandbox: sandboxClient, - } -} - -// GetSandboxBySession looks up the sandbox by sessionID using Redis. -func (m *manager) GetSandboxBySession(ctx context.Context, sessionID string) (*types.SandboxRedis, error) { - if sessionID == "" { - return nil, ErrInvalidArgument - } - - // For now we do not validate the SessionID format; any non-empty string is treated as valid. - - sb, err := m.redis.GetSandboxBySessionID(ctx, sessionID) - if err != nil { - // redis.ErrNotFound is mapped to the unified ErrSessionNotFound in session manager. - if errors.Is(err, redis.ErrNotFound) { - return nil, ErrSessionNotFound - } - // Other errors are wrapped and propagated for upper layers to log and map to 5xx. - return nil, fmt.Errorf("sessionmgr: get sandbox by sessionID %q from redis failed: %w", sessionID, err) - } - if sb == nil { - return nil, fmt.Errorf("sessionmgr: get sandbox by sessionID %q returned nil sandbox", sessionID) - } - - return sb, nil -} - -// CreateSandbox creates a new sandbox via the sandbox manager using the shared CreateSandboxRequest type. -func (m *manager) CreateSandbox(ctx context.Context, req *types.CreateSandboxRequest) (*types.SandboxRedis, error) { - if req == nil { - return nil, ErrInvalidArgument - } - // Basic argument validation: creating a sandbox requires at least Kind and Namespace. - if req.Kind == "" || req.Namespace == "" { - return nil, ErrInvalidArgument - } - - cResp, err := m.sandbox.CreateSandbox(ctx, req) - if err != nil { - // Upstream network/timeouts etc. are treated as ErrUpstreamUnavailable. - // Here we roughly distinguish by whether the error is our ErrCreateSandboxFailed. - if errors.Is(err, ErrCreateSandboxFailed) { - return nil, err - } - return nil, fmt.Errorf("%w: %v", ErrUpstreamUnavailable, err) - } - - if cResp == nil || cResp.SessionID == "" || cResp.SandboxID == "" { - return nil, fmt.Errorf("%w: invalid response from sandbox manager", ErrCreateSandboxFailed) - } - - // Construct a SandboxRedis view from the response so that callers - // see a consistent sandbox object. - sb := &types.SandboxRedis{ - SandboxID: cResp.SandboxID, - SandboxName: cResp.SandboxName, - EntryPoints: cResp.Accesses, - SessionID: cResp.SessionID, - // CreatedAt / ExpiresAt / Status can be filled later when they are available. - } - - return sb, nil -} - -var ( - // ErrInvalidArgument indicates that the request arguments are invalid - // (for example, missing kind/namespace when creating a sandbox). - ErrInvalidArgument = errors.New("sessionmgr: invalid argument") - - // ErrSessionNotFound indicates that the session does not exist in redis, - // and is typically mapped to HTTP 404/410. - ErrSessionNotFound = errors.New("sessionmgr: session not found") - - // ErrUpstreamUnavailable indicates that the sandbox manager is unavailable - // (e.g. due to network errors), and is typically mapped to HTTP 503. - ErrUpstreamUnavailable = errors.New("sessionmgr: sandbox manager unavailable") - - // ErrCreateSandboxFailed indicates that the sandbox manager returned a business-level error. - ErrCreateSandboxFailed = errors.New("sessionmgr: create sandbox failed") -) diff --git a/pkg/sessionmgr/manager_test.go b/pkg/sessionmgr/manager_test.go deleted file mode 100644 index 49caff7e..00000000 --- a/pkg/sessionmgr/manager_test.go +++ /dev/null @@ -1,335 +0,0 @@ -package sessionmgr - -import ( - "context" - "errors" - "strings" - "testing" - - "github.com/volcano-sh/agentcube/pkg/common/types" - "github.com/volcano-sh/agentcube/pkg/redis" -) - -// ---- fakes ---- - -type fakeRedisClient struct { - sandbox *types.SandboxRedis - err error - called bool - lastSessionID string - lastContextNil bool -} - -func (f *fakeRedisClient) GetSandboxBySessionID(ctx context.Context, sessionID string) (*types.SandboxRedis, error) { - f.called = true - f.lastSessionID = sessionID - f.lastContextNil = ctx == nil - return f.sandbox, f.err -} - -type fakeSandboxManagerClient struct { - resp *types.CreateSandboxResponse - err error - called bool - lastReq *types.CreateSandboxRequest - lastCtxNil bool - calls int -} - -func (f *fakeSandboxManagerClient) CreateSandbox(ctx context.Context, req *types.CreateSandboxRequest) (*types.CreateSandboxResponse, error) { - f.called = true - f.calls++ - f.lastReq = req - f.lastCtxNil = ctx == nil - return f.resp, f.err -} - -// ---- tests: GetSandboxBySession ---- - -func TestGetSandboxBySession_Success(t *testing.T) { - ctx := context.Background() - - sb := &types.SandboxRedis{ - SandboxID: "sandbox-1", - SandboxName: "sandbox-1", - EntryPoints: []types.SandboxAccess{ - {Endpoint: "10.0.0.1:9000"}, - }, - SessionID: "sess-1", - Status: "running", - } - - r := &fakeRedisClient{ - sandbox: sb, - } - m := New(r, &fakeSandboxManagerClient{}) - - got, err := m.GetSandboxBySession(ctx, "sess-1") - if err != nil { - t.Fatalf("GetSandboxBySession unexpected error: %v", err) - } - if !r.called { - t.Fatalf("expected RedisClient to be called") - } - if r.lastSessionID != "sess-1" { - t.Fatalf("expected RedisClient to be called with sessionID 'sess-1', got %q", r.lastSessionID) - } - if got == nil { - t.Fatalf("expected non-nil sandbox") - } - if got.SandboxID != "sandbox-1" { - t.Fatalf("unexpected SandboxID: got %q, want %q", got.SandboxID, "sandbox-1") - } -} - -func TestGetSandboxBySession_EmptySessionID(t *testing.T) { - ctx := context.Background() - m := New(&fakeRedisClient{}, &fakeSandboxManagerClient{}) - - _, err := m.GetSandboxBySession(ctx, "") - if err == nil { - t.Fatalf("expected error for empty sessionID") - } - if !errors.Is(err, ErrInvalidArgument) { - t.Fatalf("expected ErrInvalidArgument, got %v", err) - } -} - -func TestGetSandboxBySession_NotFound(t *testing.T) { - ctx := context.Background() - r := &fakeRedisClient{ - sandbox: nil, - err: redis.ErrNotFound, - } - m := New(r, &fakeSandboxManagerClient{}) - - _, err := m.GetSandboxBySession(ctx, "sess-1") - if err == nil { - t.Fatalf("expected error for not found session") - } - if !errors.Is(err, ErrSessionNotFound) { - t.Fatalf("expected ErrSessionNotFound, got %v", err) - } -} - -func TestGetSandboxBySession_OtherErrorWrapped(t *testing.T) { - ctx := context.Background() - inner := errors.New("redis boom") - r := &fakeRedisClient{ - sandbox: nil, - err: inner, - } - m := New(r, &fakeSandboxManagerClient{}) - - _, err := m.GetSandboxBySession(ctx, "sess-1") - if err == nil { - t.Fatalf("expected error") - } - // Should wrap the inner error. - if !errors.Is(err, inner) { - t.Fatalf("expected error to wrap inner error, got %v", err) - } - if !strings.Contains(err.Error(), "sessionmgr: get sandbox by sessionID") { - t.Fatalf("unexpected error message: %v", err) - } -} - -func TestGetSandboxBySession_NilSandbox(t *testing.T) { - ctx := context.Background() - r := &fakeRedisClient{ - sandbox: nil, - err: nil, - } - m := New(r, &fakeSandboxManagerClient{}) - - _, err := m.GetSandboxBySession(ctx, "sess-1") - if err == nil { - t.Fatalf("expected error when redis returns nil sandbox") - } - if !strings.Contains(err.Error(), "returned nil sandbox") { - t.Fatalf("unexpected error message: %v", err) - } -} - -// ---- tests: CreateSandbox ---- - -func TestCreateSandbox_Success(t *testing.T) { - ctx := context.Background() - - req := &types.CreateSandboxRequest{ - Kind: "agent", - Name: "sandbox-name", - Namespace: "default", - } - - resp := &types.CreateSandboxResponse{ - SessionID: "sess-1", - SandboxID: "sandbox-1", - SandboxName: "sandbox-name", - Accesses: []types.SandboxAccess{ - {Endpoint: "10.0.0.1:9000"}, - }, - } - - s := &fakeSandboxManagerClient{ - resp: resp, - } - m := New(&fakeRedisClient{}, s) - - got, err := m.CreateSandbox(ctx, req) - if err != nil { - t.Fatalf("CreateSandbox unexpected error: %v", err) - } - if !s.called { - t.Fatalf("expected SandboxManagerClient to be called") - } - if got == nil { - t.Fatalf("expected non-nil sandbox") - } - if got.SandboxID != "sandbox-1" { - t.Fatalf("unexpected SandboxID: got %q, want %q", got.SandboxID, "sandbox-1") - } - if got.SessionID != "sess-1" { - t.Fatalf("unexpected SessionID: got %q, want %q", got.SessionID, "sess-1") - } - if len(got.EntryPoints) != 1 || got.EntryPoints[0].Endpoint != "10.0.0.1:9000" { - t.Fatalf("unexpected EntryPoints: %+v", got.EntryPoints) - } -} - -func TestCreateSandbox_NilRequest(t *testing.T) { - ctx := context.Background() - s := &fakeSandboxManagerClient{} - m := New(&fakeRedisClient{}, s) - - _, err := m.CreateSandbox(ctx, nil) - if err == nil { - t.Fatalf("expected error for nil request") - } - if !errors.Is(err, ErrInvalidArgument) { - t.Fatalf("expected ErrInvalidArgument, got %v", err) - } - if s.called { - t.Fatalf("expected SandboxManagerClient not to be called for nil request") - } -} - -func TestCreateSandbox_InvalidRequest_MissingKindOrNamespace(t *testing.T) { - ctx := context.Background() - s := &fakeSandboxManagerClient{} - m := New(&fakeRedisClient{}, s) - - // Missing Kind. - _, err := m.CreateSandbox(ctx, &types.CreateSandboxRequest{ - Kind: "", - Name: "name", - Namespace: "ns", - }) - if err == nil { - t.Fatalf("expected error for missing Kind") - } - if !errors.Is(err, ErrInvalidArgument) { - t.Fatalf("expected ErrInvalidArgument, got %v", err) - } - if s.called { - t.Fatalf("expected SandboxManagerClient not to be called when Kind is empty") - } - - // Reset fake. - s.called = false - s.calls = 0 - - // Missing Namespace. - _, err = m.CreateSandbox(ctx, &types.CreateSandboxRequest{ - Kind: "agent", - Name: "name", - Namespace: "", - }) - if err == nil { - t.Fatalf("expected error for missing Namespace") - } - if !errors.Is(err, ErrInvalidArgument) { - t.Fatalf("expected ErrInvalidArgument, got %v", err) - } - if s.called { - t.Fatalf("expected SandboxManagerClient not to be called when Namespace is empty") - } -} - -func TestCreateSandbox_CreateSandboxFailed(t *testing.T) { - ctx := context.Background() - s := &fakeSandboxManagerClient{ - err: ErrCreateSandboxFailed, - } - m := New(&fakeRedisClient{}, s) - - req := &types.CreateSandboxRequest{ - Kind: "agent", - Name: "sandbox-name", - Namespace: "default", - } - - _, err := m.CreateSandbox(ctx, req) - if err == nil { - t.Fatalf("expected error") - } - // ErrCreateSandboxFailed should be propagated as is. - if !errors.Is(err, ErrCreateSandboxFailed) { - t.Fatalf("expected ErrCreateSandboxFailed, got %v", err) - } -} - -func TestCreateSandbox_UpstreamUnavailableWrapped(t *testing.T) { - ctx := context.Background() - inner := errors.New("upstream timeout") - s := &fakeSandboxManagerClient{ - err: inner, - } - m := New(&fakeRedisClient{}, s) - - req := &types.CreateSandboxRequest{ - Kind: "agent", - Name: "sandbox-name", - Namespace: "default", - } - - _, err := m.CreateSandbox(ctx, req) - if err == nil { - t.Fatalf("expected error") - } - // Should be wrapped as ErrUpstreamUnavailable. - if !errors.Is(err, ErrUpstreamUnavailable) { - t.Fatalf("expected error to be ErrUpstreamUnavailable, got %v", err) - } - // Inner error is included in the message, but not wrapped with %w, - // so errors.Is on inner should be false. - if errors.Is(err, inner) { - t.Fatalf("did not expect error to wrap inner error via errors.Is") - } - if !strings.Contains(err.Error(), "upstream timeout") { - t.Fatalf("expected error message to contain inner error, got %v", err) - } -} - -func TestCreateSandbox_InvalidResponse(t *testing.T) { - ctx := context.Background() - // Response missing SessionID and SandboxID. - s := &fakeSandboxManagerClient{ - resp: &types.CreateSandboxResponse{}, - } - m := New(&fakeRedisClient{}, s) - - req := &types.CreateSandboxRequest{ - Kind: "agent", - Name: "sandbox-name", - Namespace: "default", - } - - _, err := m.CreateSandbox(ctx, req) - if err == nil { - t.Fatalf("expected error for invalid CreateSandboxResponse") - } - if !errors.Is(err, ErrCreateSandboxFailed) { - t.Fatalf("expected ErrCreateSandboxFailed, got %v", err) - } -} From d1f24f86eeabb7d597617d1f51b6b1a3ea5a2eae Mon Sep 17 00:00:00 2001 From: VanderChen Date: Mon, 8 Dec 2025 21:33:51 +0800 Subject: [PATCH 5/6] Fix router path and set no timeout for invoke request Signed-off-by: VanderChen --- Dockerfile.router | 41 ++ Makefile | 11 +- cmd/router/main.go | 14 +- k8s/agentcube-router.yaml | 10 +- pkg/agentd/agentd_test.go | 6 +- pkg/router/config.go | 3 - pkg/router/handlers.go | 111 ++-- pkg/router/handlers_test.go | 485 ++++++++++++++++++ pkg/router/{apiserver.go => server.go} | 12 +- .../{apiserver_test.go => server_test.go} | 8 - pkg/router/session_manager.go | 12 +- pkg/router/session_manager_test.go | 288 ++++++++++- pkg/router/utils.go | 60 --- pkg/workloadmanager/garbage_collection.go | 3 +- pkg/workloadmanager/handlers.go | 5 +- pkg/workloadmanager/server.go | 1 + 16 files changed, 895 insertions(+), 175 deletions(-) create mode 100644 Dockerfile.router create mode 100644 pkg/router/handlers_test.go rename pkg/router/{apiserver.go => server.go} (91%) rename pkg/router/{apiserver_test.go => server_test.go} (96%) delete mode 100644 pkg/router/utils.go diff --git a/Dockerfile.router b/Dockerfile.router new file mode 100644 index 00000000..475c7441 --- /dev/null +++ b/Dockerfile.router @@ -0,0 +1,41 @@ +# Multi-stage build for agentcube-router +FROM golang:1.24.9-alpine AS builder + +# Build arguments for multi-architecture support +ARG TARGETOS=linux +ARG TARGETARCH + +WORKDIR /workspace + +# Copy go mod files +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source code +COPY cmd/ cmd/ +COPY pkg/ pkg/ +COPY client-go/ client-go/ + +# Build with dynamic architecture support +# Supports amd64, arm64, arm/v7, etc. +RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} \ + go build -o agentcube-router ./cmd/router + +# Runtime image +FROM alpine:3.19 + +RUN apk --no-cache add ca-certificates + +WORKDIR /app + +# Copy binary from builder +COPY --from=builder /workspace/agentcube-router . + +# Run as non-root user +RUN adduser -D -u 1000 router +USER router + +EXPOSE 8080 + +ENTRYPOINT ["/app/agentcube-router"] +CMD ["--port=8080", "--debug"] diff --git a/Makefile b/Makefile index b8649c91..63e750f6 100644 --- a/Makefile +++ b/Makefile @@ -72,6 +72,10 @@ build: generate ## Build agentcube-apiserver binary @echo "Building agentcube-apiserver..." go build -o bin/agentcube-apiserver ./cmd/workload-manager +build-router: ## Build agentcube-router binary + @echo "Building agentcube-router..." + go build -o bin/agentcube-router ./cmd/router + build-agentd: generate ## Build agentd binary @echo "Building agentd..." go build -o bin/agentd ./cmd/agentd @@ -80,7 +84,7 @@ build-test-tunnel: ## Build test-tunnel tool @echo "Building test-tunnel..." go build -o bin/test-tunnel ./cmd/test-tunnel -build-all: build build-agentd build-test-tunnel ## Build all binaries +build-all: build build-router build-agentd build-test-tunnel ## Build all binaries # Run server (development mode) run: @@ -140,6 +144,7 @@ install: build # Docker image variables APISERVER_IMAGE ?= agentcube-apiserver:latest +ROUTER_IMAGE ?= agentcube-router:latest IMAGE_REGISTRY ?= "" # Docker and Kubernetes targets @@ -147,6 +152,10 @@ docker-build: @echo "Building Docker image..." docker build -t $(APISERVER_IMAGE) . +docker-build-router: ## Build router Docker image + @echo "Building router Docker image..." + docker build -f Dockerfile.router -t $(ROUTER_IMAGE) . + # Multi-architecture build (supports amd64, arm64) docker-buildx: @echo "Building multi-architecture Docker image..." diff --git a/cmd/router/main.go b/cmd/router/main.go index f33cb84c..b55e7939 100644 --- a/cmd/router/main.go +++ b/cmd/router/main.go @@ -7,7 +7,6 @@ import ( "os" "os/signal" "syscall" - "time" "github.com/volcano-sh/agentcube/pkg/router" ) @@ -30,12 +29,7 @@ func main() { // Create Router API server configuration config := &router.Config{ - Port: *port, - SandboxEndpoints: []string{ - "http://sandbox-1:8080", - "http://sandbox-2:8080", - "http://sandbox-3:8080", - }, // Default sandbox endpoints, can be configured via env vars + Port: *port, Debug: *debug, EnableTLS: *enableTLS, TLSCert: *tlsCert, @@ -69,8 +63,10 @@ func main() { select { case <-ctx.Done(): log.Println("Received shutdown signal, shutting down gracefully...") - // Wait for server to finish shutting down - time.Sleep(2 * time.Second) + // Cancel the context to trigger server shutdown + cancel() + // Wait for server goroutine to exit after graceful shutdown is complete + <-errCh case err := <-errCh: log.Fatalf("Server error: %v", err) } diff --git a/k8s/agentcube-router.yaml b/k8s/agentcube-router.yaml index 3477b7d7..01ada16a 100644 --- a/k8s/agentcube-router.yaml +++ b/k8s/agentcube-router.yaml @@ -43,10 +43,13 @@ rules: verbs: ["update"] - apiGroups: ["runtime.agentcube.volcano.sh"] resources: ["agentruntimes"] - verbs: ["get", "list", "watch"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] - apiGroups: ["runtime.agentcube.volcano.sh"] resources: ["agentruntimes/status"] - verbs: ["update", "patch"] + verbs: ["get", "update", "patch"] + - apiGroups: ["runtime.agentcube.volcano.sh"] + resources: ["agentruntimes/finalizers"] + verbs: ["update"] - apiGroups: [""] resources: ["pods"] verbs: ["get", "list", "watch"] @@ -100,9 +103,10 @@ spec: value: "127.0.0.1:6379" - name: REDIS_PASSWORD value: "" + - name: WORKLOAD_MGR_URL + value: "http://agentcube-workload-manager:8080" args: - --port=8080 - - --runtime-class-name= - --debug resources: requests: diff --git a/pkg/agentd/agentd_test.go b/pkg/agentd/agentd_test.go index 48e65e60..da9b4d98 100644 --- a/pkg/agentd/agentd_test.go +++ b/pkg/agentd/agentd_test.go @@ -38,7 +38,7 @@ func TestReconciler_Reconcile_WithLastActivity(t *testing.T) { Name: "test-sandbox", Namespace: "default", Annotations: map[string]string{ - "agentcube.volcano.sh/last-activity": now.Add(-5 * time.Minute).Format(time.RFC3339), + "last-activity-time": now.Add(-5 * time.Minute).Format(time.RFC3339), }, }, Status: sandboxv1alpha1.SandboxStatus{ @@ -61,7 +61,7 @@ func TestReconciler_Reconcile_WithLastActivity(t *testing.T) { Name: "test-sandbox", Namespace: "default", Annotations: map[string]string{ - "agentcube.volcano.sh/last-activity": now.Add(-20 * time.Minute).Format(time.RFC3339), + "last-activity-time": now.Add(-20 * time.Minute).Format(time.RFC3339), }, }, Status: sandboxv1alpha1.SandboxStatus{ @@ -83,7 +83,7 @@ func TestReconciler_Reconcile_WithLastActivity(t *testing.T) { Name: "test-sandbox", Namespace: "default", Annotations: map[string]string{ - "agentcube.volcano.sh/last-activity": now.Add(-20 * time.Minute).Format(time.RFC3339), + "last-activity-time": now.Add(-20 * time.Minute).Format(time.RFC3339), }, }, Status: sandboxv1alpha1.SandboxStatus{ diff --git a/pkg/router/config.go b/pkg/router/config.go index ab686daa..a0ff58c7 100644 --- a/pkg/router/config.go +++ b/pkg/router/config.go @@ -47,7 +47,4 @@ type Config struct { // RedisDB is the Redis database number RedisDB int - - // SessionExpireDuration is the duration after which inactive sessions expire - SessionExpireDuration int // seconds, default 3600 (1 hour) } diff --git a/pkg/router/handlers.go b/pkg/router/handlers.go index 12a36bf9..c540361d 100644 --- a/pkg/router/handlers.go +++ b/pkg/router/handlers.go @@ -1,7 +1,7 @@ package router import ( - "context" + "fmt" "log" "net/http" "net/http/httputil" @@ -42,24 +42,20 @@ func (s *Server) handleHealthReady(c *gin.Context) { }) } -// handleAgentInvoke handles agent invocation requests -func (s *Server) handleAgentInvoke(c *gin.Context) { - agentNamespace := c.Param("agentNamespace") - agentName := c.Param("agentName") - path := c.Param("path") - - log.Printf("Agent invoke request: namespace=%s, agent=%s, path=%s", agentNamespace, agentName, path) +// handleInvoke is a private helper function that handles invocation requests for both agents and code interpreters +func (s *Server) handleInvoke(c *gin.Context, namespace, name, path, kind string) { + log.Printf("%s invoke request: namespace=%s, name=%s, path=%s", kind, namespace, name, path) // Extract session ID from header sessionID := c.GetHeader("x-agentcube-session-id") // Get sandbox info from session manager - sandbox, err := s.sessionManager.GetSandboxBySession(sessionID, agentNamespace, agentName, "AgentRuntime") + sandbox, err := s.sessionManager.GetSandboxBySession(c.Request.Context(), sessionID, namespace, name, kind) if err != nil { - log.Printf("Failed to get sandbox info: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "internal server error", - "code": "INTERNAL_ERROR", + log.Printf("Failed to get sandbox info: %v, session id %s", err, sessionID) + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Invalid session id %s", sessionID), + "code": "BadRequest", }) return } @@ -67,8 +63,13 @@ func (s *Server) handleAgentInvoke(c *gin.Context) { // Extract endpoint from sandbox - find matching entry point by path var endpoint string for _, ep := range sandbox.EntryPoints { - if ep.Path == path || ep.Path == "" { - endpoint = ep.Endpoint + if strings.HasPrefix(path, ep.Path) { + // Only add protocol if not already present + if ep.Protocol != "" && !strings.Contains(ep.Endpoint, "://") { + endpoint = strings.ToLower(ep.Protocol) + "://" + ep.Endpoint + } else { + endpoint = ep.Endpoint + } break } } @@ -83,18 +84,37 @@ func (s *Server) handleAgentInvoke(c *gin.Context) { }) return } - endpoint = sandbox.EntryPoints[0].Endpoint + // Only add protocol if not already present + if sandbox.EntryPoints[0].Protocol != "" && !strings.Contains(sandbox.EntryPoints[0].Endpoint, "://") { + endpoint = strings.ToLower(sandbox.EntryPoints[0].Protocol) + "://" + sandbox.EntryPoints[0].Endpoint + } else { + endpoint = sandbox.EntryPoints[0].Endpoint + } } + log.Printf("The selected entrypoint for session-id %s to sandbox is %s", sandbox.SessionID, endpoint) + // Update session activity in Redis when receiving request if sandbox.SessionID != "" && sandbox.SandboxID != "" { - if err := s.redisClient.UpdateSandboxLastActivity(c.Request.Context(), sandbox.SandboxID, time.Now()); err != nil { + if err := s.redisClient.UpdateSessionLastActivity(c.Request.Context(), sandbox.SessionID, time.Now()); err != nil { log.Printf("Failed to update sandbox last activity for request: %v", err) } } // Forward request to sandbox with session ID s.forwardToSandbox(c, endpoint, path, sandbox.SessionID) + + if err := s.redisClient.UpdateSessionLastActivity(c.Request.Context(), sandbox.SessionID, time.Now()); err != nil { + log.Printf("Failed to update sandbox last activity for request: %v", err) + } +} + +// handleAgentInvoke handles agent invocation requests +func (s *Server) handleAgentInvoke(c *gin.Context) { + namespace := c.Param("namespace") + name := c.Param("name") + path := c.Param("path") + s.handleInvoke(c, namespace, name, path, "AgentRuntime") } // handleCodeInterpreterInvoke handles code interpreter invocation requests @@ -102,54 +122,7 @@ func (s *Server) handleCodeInterpreterInvoke(c *gin.Context) { namespace := c.Param("namespace") name := c.Param("name") path := c.Param("path") - - log.Printf("Code interpreter invoke request: namespace=%s, name=%s, path=%s", namespace, name, path) - - // Extract session ID from header - sessionID := c.GetHeader("x-agentcube-session-id") - - // Get sandbox info from session manager - sandbox, err := s.sessionManager.GetSandboxBySession(sessionID, namespace, name, "CodeInterpreter") - if err != nil { - log.Printf("Failed to get sandbox info: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "internal server error", - "code": "INTERNAL_ERROR", - }) - return - } - - // Extract endpoint from sandbox - find matching entry point by path - var endpoint string - for _, ep := range sandbox.EntryPoints { - if ep.Path == path || ep.Path == "" { - endpoint = ep.Endpoint - break - } - } - - // If no matching endpoint found, use the first one as fallback - if endpoint == "" { - if len(sandbox.EntryPoints) == 0 { - log.Printf("No entry points found for sandbox: %s", sandbox.SandboxID) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "internal server error", - "code": "INTERNAL_ERROR", - }) - return - } - endpoint = sandbox.EntryPoints[0].Endpoint - } - - // Update session activity in Redis when receiving request - if sandbox.SessionID != "" && sandbox.SandboxID != "" { - if err := s.redisClient.UpdateSandboxLastActivity(c.Request.Context(), sandbox.SandboxID, time.Now()); err != nil { - log.Printf("Failed to update sandbox last activity for request: %v", err) - } - } - - // Forward request to sandbox with session ID - s.forwardToSandbox(c, endpoint, path, sandbox.SessionID) + s.handleInvoke(c, namespace, name, path, "CodeInterpreter") } // forwardToSandbox forwards the request to the specified sandbox endpoint @@ -228,10 +201,10 @@ func (s *Server) forwardToSandbox(c *gin.Context, endpoint, path, sessionID stri return nil } - // Set timeout for the proxy request using configured timeout - ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(s.config.RequestTimeout)*time.Second) - defer cancel() - c.Request = c.Request.WithContext(ctx) + // No timeout for invoke requests to allow long-running operations + // ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(s.config.RequestTimeout)*time.Second) + // defer cancel() + // c.Request = c.Request.WithContext(ctx) // Use the proxy to serve the request proxy.ServeHTTP(c.Writer, c.Request) diff --git a/pkg/router/handlers_test.go b/pkg/router/handlers_test.go new file mode 100644 index 00000000..60b3cc4b --- /dev/null +++ b/pkg/router/handlers_test.go @@ -0,0 +1,485 @@ +package router + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/volcano-sh/agentcube/pkg/common/types" +) + +func init() { + // Set Gin to test mode + gin.SetMode(gin.TestMode) +} + +// Mock SessionManager for testing +type mockSessionManager struct { + sandbox *types.SandboxRedis + err error +} + +func (m *mockSessionManager) GetSandboxBySession(ctx context.Context, sessionID string, namespace string, name string, kind string) (*types.SandboxRedis, error) { + return m.sandbox, m.err +} + +func TestHandleHealth(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/health", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) + } + + expectedBody := `{"status":"healthy"}` + if w.Body.String() != expectedBody { + t.Errorf("Expected body %s, got %s", expectedBody, w.Body.String()) + } +} + +func TestHandleHealthLive(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/health/live", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) + } + + expectedBody := `{"status":"alive"}` + if w.Body.String() != expectedBody { + t.Errorf("Expected body %s, got %s", expectedBody, w.Body.String()) + } +} + +func TestHandleHealthReady(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + tests := []struct { + name string + sessionManager SessionManager + expectedStatusCode int + expectedBody string + }{ + { + name: "ready with session manager", + sessionManager: &mockSessionManager{}, + expectedStatusCode: http.StatusOK, + expectedBody: `{"status":"ready"}`, + }, + { + name: "not ready without session manager", + sessionManager: nil, + expectedStatusCode: http.StatusServiceUnavailable, + expectedBody: `{"error":"session manager not available","status":"not ready"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Override session manager for testing + server.sessionManager = tt.sessionManager + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/health/ready", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != tt.expectedStatusCode { + t.Errorf("Expected status code %d, got %d", tt.expectedStatusCode, w.Code) + } + + if w.Body.String() != tt.expectedBody { + t.Errorf("Expected body %s, got %s", tt.expectedBody, w.Body.String()) + } + }) + } +} + +func TestHandleInvoke_SessionManagerError(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Mock session manager that returns error + server.sessionManager = &mockSessionManager{ + err: errors.New("session manager error"), + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status code %d, got %d", http.StatusBadRequest, w.Code) + } +} + +func TestHandleInvoke_NoEntryPoints(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Mock session manager that returns sandbox with no entry points + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxRedis{ + SandboxID: "test-sandbox", + SessionID: "test-session", + EntryPoints: []types.SandboxEntryPoints{}, + }, + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status code %d, got %d", http.StatusInternalServerError, w.Code) + } +} + +func TestHandleAgentInvoke(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + // Create a test HTTP server to act as the sandbox + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"result":"success"}`)) + })) + defer testServer.Close() + + config := &Config{ + Port: "8080", + RequestTimeout: 30, + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Mock session manager that returns sandbox with test server endpoint + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxRedis{ + SandboxID: "test-sandbox", + SessionID: "test-session", + SandboxName: "test-sandbox", + EntryPoints: []types.SandboxEntryPoints{ + { + Endpoint: testServer.URL, + Path: "/test", + }, + }, + }, + } + + // Use real HTTP client instead of httptest.ResponseRecorder to avoid CloseNotifier panic + req, _ := http.NewRequest("POST", "/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", nil) + req.Header.Set("x-agentcube-session-id", "test-session") + + // Start a real test server + testRouterServer := httptest.NewServer(server.engine) + defer testRouterServer.Close() + + // Make real HTTP request + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Post(testRouterServer.URL+"/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", "application/json", nil) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check if session ID is set in response header + sessionID := resp.Header.Get("x-agentcube-session-id") + if sessionID != "test-session" { + t.Errorf("Expected session ID 'test-session', got '%s'", sessionID) + } +} + +func TestHandleCodeInterpreterInvoke(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + // Create a test HTTP server to act as the sandbox + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"result":"success"}`)) + })) + defer testServer.Close() + + config := &Config{ + Port: "8080", + RequestTimeout: 30, + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Mock session manager that returns sandbox with test server endpoint + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxRedis{ + SandboxID: "test-sandbox", + SessionID: "test-session", + SandboxName: "test-sandbox", + EntryPoints: []types.SandboxEntryPoints{ + { + Endpoint: testServer.URL, + Path: "/execute", + }, + }, + }, + } + + // Use real HTTP client instead of httptest.ResponseRecorder to avoid CloseNotifier panic + testRouterServer := httptest.NewServer(server.engine) + defer testRouterServer.Close() + + // Make real HTTP request + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Post(testRouterServer.URL+"/v1/namespaces/default/code-interpreters/test-ci/invocations/execute", "application/json", nil) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check if session ID is set in response header + sessionID := resp.Header.Get("x-agentcube-session-id") + if sessionID != "test-session" { + t.Errorf("Expected session ID 'test-session', got '%s'", sessionID) + } +} + +func TestForwardToSandbox_InvalidEndpoint(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + RequestTimeout: 30, + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Mock session manager that returns sandbox with invalid endpoint + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxRedis{ + SandboxID: "test-sandbox", + SessionID: "test-session", + SandboxName: "test-sandbox", + EntryPoints: []types.SandboxEntryPoints{ + { + Endpoint: "://invalid-url", + Path: "/test", + }, + }, + }, + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status code %d, got %d", http.StatusInternalServerError, w.Code) + } +} + +func TestConcurrencyLimitMiddleware_Overload(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + MaxConcurrentRequests: 1, // Set to 1 to easily trigger overload + RequestTimeout: 30, + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Create a slow test server + slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer slowServer.Close() + + // Mock session manager with slow response + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxRedis{ + SandboxID: "test-sandbox", + SessionID: "test-session", + SandboxName: "test-sandbox", + EntryPoints: []types.SandboxEntryPoints{ + { + Endpoint: slowServer.URL, + Path: "/test", + }, + }, + }, + } + + // Start a real test server + testRouterServer := httptest.NewServer(server.engine) + defer testRouterServer.Close() + + // Start first request (will occupy the semaphore) + done := make(chan bool) + go func() { + client := &http.Client{Timeout: 5 * time.Second} + _, _ = client.Post(testRouterServer.URL+"/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", "application/json", nil) + done <- true + }() + + // Give first request time to acquire semaphore + time.Sleep(50 * time.Millisecond) + + // Try second request (should be rejected due to overload) + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Post(testRouterServer.URL+"/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", "application/json", nil) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("Expected status code %d, got %d", http.StatusServiceUnavailable, resp.StatusCode) + } + + // Wait for first request to complete + <-done +} diff --git a/pkg/router/apiserver.go b/pkg/router/server.go similarity index 91% rename from pkg/router/apiserver.go rename to pkg/router/server.go index 608614a5..85d6593d 100644 --- a/pkg/router/apiserver.go +++ b/pkg/router/server.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" redisv9 "github.com/redis/go-redis/v9" + "github.com/volcano-sh/agentcube/pkg/redis" ) @@ -60,9 +61,6 @@ func NewServer(config *Config) (*Server, error) { if config.MaxConnsPerHost <= 0 { config.MaxConnsPerHost = 10 // Default 10 connections per host } - if config.SessionExpireDuration <= 0 { - config.SessionExpireDuration = 3600 // Default 1 hour - } // Initialize Redis client redisOptions, err := makeRedisOptions() @@ -88,7 +86,7 @@ func NewServer(config *Config) (*Server, error) { httpTransport := &http.Transport{ MaxIdleConns: config.MaxIdleConns, MaxIdleConnsPerHost: config.MaxConnsPerHost, - IdleConnTimeout: 90 * time.Second, + IdleConnTimeout: 0, DisableCompression: false, ForceAttemptHTTP2: true, } @@ -148,10 +146,10 @@ func (s *Server) setupRoutes() { v1.Use(s.concurrencyLimitMiddleware()) // Apply concurrency limit to API routes // Agent invoke requests - v1.Any("/namespaces/:agentNamespace/agent-runtimes/:agentName/invocations/*path", s.handleAgentInvoke) + v1.POST("/namespaces/:namespace/agent-runtimes/:name/invocations/*path", s.handleAgentInvoke) - // Code interpreter invoke requests - use different base path to avoid conflicts - v1.Any("/code-namespaces/:namespace/code-interpreters/:name/invocations/*path", s.handleCodeInterpreterInvoke) + // Code interpreter invoke requests + v1.POST("/namespaces/:namespace/code-interpreters/:name/invocations/*path", s.handleCodeInterpreterInvoke) } // Start starts the Router API server diff --git a/pkg/router/apiserver_test.go b/pkg/router/server_test.go similarity index 96% rename from pkg/router/apiserver_test.go rename to pkg/router/server_test.go index 9bc8a087..5932dd4a 100644 --- a/pkg/router/apiserver_test.go +++ b/pkg/router/server_test.go @@ -43,7 +43,6 @@ func TestNewServer(t *testing.T) { RequestTimeout: 60, MaxIdleConns: 200, MaxConnsPerHost: 20, - SessionExpireDuration: 7200, EnableRedis: true, Debug: true, }, @@ -114,9 +113,6 @@ func TestNewServer(t *testing.T) { if server.config.MaxConnsPerHost <= 0 { t.Error("MaxConnsPerHost should have been set to default") } - if server.config.SessionExpireDuration <= 0 { - t.Error("SessionExpireDuration should have been set to default") - } } }) } @@ -159,10 +155,6 @@ func TestServer_DefaultValues(t *testing.T) { if server.config.MaxConnsPerHost != 10 { t.Errorf("Expected default MaxConnsPerHost 10, got %d", server.config.MaxConnsPerHost) } - - if server.config.SessionExpireDuration != 3600 { - t.Errorf("Expected default SessionExpireDuration 3600, got %d", server.config.SessionExpireDuration) - } } func TestServer_ConcurrencyLimitMiddleware(t *testing.T) { diff --git a/pkg/router/session_manager.go b/pkg/router/session_manager.go index a273ec29..5a7e56bb 100644 --- a/pkg/router/session_manager.go +++ b/pkg/router/session_manager.go @@ -20,7 +20,7 @@ type SessionManager interface { // GetSandboxBySession returns the sandbox associated with the given sessionID. // When sessionID is empty, it creates a new sandbox by calling the external API. // When sessionID is not empty, it queries Redis for the sandbox. - GetSandboxBySession(sessionID string, namespace string, name string, kind string) (*types.SandboxRedis, error) + GetSandboxBySession(ctx context.Context, sessionID string, namespace string, name string, kind string) (*types.SandboxRedis, error) } // manager is the default implementation of the SessionManager interface. @@ -43,7 +43,7 @@ func NewSessionManager(redisClient redis.Client) (SessionManager, error) { redisClient: redisClient, workloadMgrURL: workloadMgrURL, httpClient: &http.Client{ - Timeout: 30 * time.Second, // Set a reasonable timeout to prevent hanging + Timeout: time.Minute, // No timeout for createSandbox requests }, }, nil } @@ -51,9 +51,7 @@ func NewSessionManager(redisClient redis.Client) (SessionManager, error) { // GetSandboxBySession returns the sandbox associated with the given sessionID. // When sessionID is empty, it creates a new sandbox by calling the external API. // When sessionID is not empty, it queries Redis for the sandbox. -func (m *manager) GetSandboxBySession(sessionID string, namespace string, name string, kind string) (*types.SandboxRedis, error) { - ctx := context.Background() - +func (m *manager) GetSandboxBySession(ctx context.Context, sessionID string, namespace string, name string, kind string) (*types.SandboxRedis, error) { // When sessionID is empty, create a new sandbox if sessionID == "" { return m.createSandbox(ctx, namespace, name, kind) @@ -128,8 +126,8 @@ func (m *manager) createSandbox(ctx context.Context, namespace string, name stri } // Validate response - if createResp.SessionID == "" || createResp.SandboxID == "" { - return nil, fmt.Errorf("%w: invalid response from workload manager", ErrCreateSandboxFailed) + if createResp.SessionID == "" { + return nil, fmt.Errorf("%w: response with empty session id from workload manager", ErrCreateSandboxFailed) } // Construct SandboxRedis from response diff --git a/pkg/router/session_manager_test.go b/pkg/router/session_manager_test.go index c9f41c79..cb21c0f8 100644 --- a/pkg/router/session_manager_test.go +++ b/pkg/router/session_manager_test.go @@ -2,7 +2,11 @@ package router import ( "context" + "encoding/json" "errors" + "io" + "net/http" + "net/http/httptest" "testing" "time" @@ -39,6 +43,18 @@ func (f *fakeRedisClient) DeleteSessionBySandboxIDTx(ctx context.Context, sandbo return nil } +func (f *fakeRedisClient) DeleteSandboxBySessionIDTx(ctx context.Context, sessionID string) error { + return nil +} + +func (f *fakeRedisClient) UpdateSandbox(ctx context.Context, sandboxRedis *types.SandboxRedis, ttl time.Duration) error { + return nil +} + +func (f *fakeRedisClient) UpdateSessionLastActivity(ctx context.Context, sessionID string, at time.Time) error { + return nil +} + func (f *fakeRedisClient) StoreSandbox(ctx context.Context, sandboxRedis *types.SandboxRedis, ttl time.Duration) error { return nil } @@ -79,7 +95,7 @@ func TestGetSandboxBySession_Success(t *testing.T) { redisClient: r, } - got, err := m.GetSandboxBySession("sess-1", "default", "test", "AgentRuntime") + got, err := m.GetSandboxBySession(context.Background(), "sess-1", "default", "test", "AgentRuntime") if err != nil { t.Fatalf("GetSandboxBySession unexpected error: %v", err) } @@ -106,7 +122,7 @@ func TestGetSandboxBySession_NotFound(t *testing.T) { redisClient: r, } - _, err := m.GetSandboxBySession("sess-1", "default", "test", "AgentRuntime") + _, err := m.GetSandboxBySession(context.Background(), "sess-1", "default", "test", "AgentRuntime") if err == nil { t.Fatalf("expected error for not found session") } @@ -114,3 +130,271 @@ func TestGetSandboxBySession_NotFound(t *testing.T) { t.Fatalf("expected ErrSessionNotFound, got %v", err) } } + +// ---- tests: GetSandboxBySession with empty sessionID (sandbox creation path) ---- + +func TestGetSandboxBySession_CreateSandbox_AgentRuntime_Success(t *testing.T) { + // Mock workload manager server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and path + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + if r.URL.Path != "/v1/agent-runtime" { + t.Errorf("expected path /v1/agent-runtime, got %s", r.URL.Path) + } + + // Verify request body + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + var req types.CreateSandboxRequest + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + if req.Kind != types.AgentRuntimeKind { + t.Errorf("expected kind %s, got %s", types.AgentRuntimeKind, req.Kind) + } + if req.Name != "test-runtime" { + t.Errorf("expected name test-runtime, got %s", req.Name) + } + if req.Namespace != "default" { + t.Errorf("expected namespace default, got %s", req.Namespace) + } + + // Send successful response + resp := types.CreateSandboxResponse{ + SessionID: "new-session-123", + SandboxID: "sandbox-456", + SandboxName: "sandbox-test", + EntryPoints: []types.SandboxEntryPoints{ + {Endpoint: "10.0.0.1:9000", Protocol: "http", Path: "/"}, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: mockServer.URL, + httpClient: &http.Client{}, + } + + sandbox, err := m.GetSandboxBySession(context.Background(), "", "default", "test-runtime", types.AgentRuntimeKind) + if err != nil { + t.Fatalf("GetSandboxBySession unexpected error: %v", err) + } + if sandbox == nil { + t.Fatalf("expected non-nil sandbox") + } + if sandbox.SessionID != "new-session-123" { + t.Errorf("expected SessionID new-session-123, got %s", sandbox.SessionID) + } + if sandbox.SandboxID != "sandbox-456" { + t.Errorf("expected SandboxID sandbox-456, got %s", sandbox.SandboxID) + } + if sandbox.SandboxName != "sandbox-test" { + t.Errorf("expected SandboxName sandbox-test, got %s", sandbox.SandboxName) + } + if len(sandbox.EntryPoints) != 1 { + t.Fatalf("expected 1 entry point, got %d", len(sandbox.EntryPoints)) + } + if sandbox.EntryPoints[0].Endpoint != "10.0.0.1:9000" { + t.Errorf("expected endpoint 10.0.0.1:9000, got %s", sandbox.EntryPoints[0].Endpoint) + } +} + +func TestGetSandboxBySession_CreateSandbox_CodeInterpreter_Success(t *testing.T) { + // Mock workload manager server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and path + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + if r.URL.Path != "/v1/code-interpreter" { + t.Errorf("expected path /v1/code-interpreter, got %s", r.URL.Path) + } + + // Verify request body + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + var req types.CreateSandboxRequest + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + if req.Kind != types.CodeInterpreterKind { + t.Errorf("expected kind %s, got %s", types.CodeInterpreterKind, req.Kind) + } + + // Send successful response + resp := types.CreateSandboxResponse{ + SessionID: "ci-session-789", + SandboxID: "ci-sandbox-101", + SandboxName: "ci-sandbox-test", + EntryPoints: []types.SandboxEntryPoints{ + {Endpoint: "10.0.0.2:8080", Protocol: "http", Path: "/"}, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: mockServer.URL, + httpClient: &http.Client{}, + } + + sandbox, err := m.GetSandboxBySession(context.Background(), "", "default", "test-ci", types.CodeInterpreterKind) + if err != nil { + t.Fatalf("GetSandboxBySession unexpected error: %v", err) + } + if sandbox == nil { + t.Fatalf("expected non-nil sandbox") + } + if sandbox.SessionID != "ci-session-789" { + t.Errorf("expected SessionID ci-session-789, got %s", sandbox.SessionID) + } +} + +func TestGetSandboxBySession_CreateSandbox_UnsupportedKind(t *testing.T) { + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: "http://localhost:8080", + httpClient: &http.Client{}, + } + + _, err := m.GetSandboxBySession(context.Background(), "", "default", "test", "UnsupportedKind") + if err == nil { + t.Fatalf("expected error for unsupported kind") + } + if err.Error() != "unsupported kind: UnsupportedKind" { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestGetSandboxBySession_CreateSandbox_WorkloadManagerUnavailable(t *testing.T) { + // Mock workload manager server that closes connection immediately + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Close connection without sending response + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("webserver doesn't support hijacking") + } + conn, _, err := hj.Hijack() + if err != nil { + t.Fatal(err) + } + conn.Close() + })) + serverURL := mockServer.URL + mockServer.Close() // Close the server to make it unavailable + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: serverURL, + httpClient: &http.Client{}, + } + + _, err := m.GetSandboxBySession(context.Background(), "", "default", "test", types.AgentRuntimeKind) + if err == nil { + t.Fatalf("expected error for unavailable workload manager") + } + if !errors.Is(err, ErrUpstreamUnavailable) { + t.Errorf("expected ErrUpstreamUnavailable, got %v", err) + } +} + +func TestGetSandboxBySession_CreateSandbox_NonOKStatus(t *testing.T) { + // Mock workload manager server that returns error + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal server error")) + })) + defer mockServer.Close() + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: mockServer.URL, + httpClient: &http.Client{}, + } + + _, err := m.GetSandboxBySession(context.Background(), "", "default", "test", types.AgentRuntimeKind) + if err == nil { + t.Fatalf("expected error for non-OK status") + } + if !errors.Is(err, ErrCreateSandboxFailed) { + t.Errorf("expected ErrCreateSandboxFailed, got %v", err) + } +} + +func TestGetSandboxBySession_CreateSandbox_InvalidJSON(t *testing.T) { + // Mock workload manager server that returns invalid JSON + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte("invalid json")) + })) + defer mockServer.Close() + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: mockServer.URL, + httpClient: &http.Client{}, + } + + _, err := m.GetSandboxBySession(context.Background(), "", "default", "test", types.AgentRuntimeKind) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + if err.Error() == "" { + t.Errorf("expected error message for invalid JSON") + } +} + +func TestGetSandboxBySession_CreateSandbox_EmptySessionID(t *testing.T) { + // Mock workload manager server that returns empty sessionID + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := types.CreateSandboxResponse{ + SessionID: "", // Empty sessionID + SandboxID: "sandbox-456", + SandboxName: "sandbox-test", + EntryPoints: []types.SandboxEntryPoints{ + {Endpoint: "10.0.0.1:9000"}, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: mockServer.URL, + httpClient: &http.Client{}, + } + + _, err := m.GetSandboxBySession(context.Background(), "", "default", "test", types.AgentRuntimeKind) + if err == nil { + t.Fatalf("expected error for empty sessionID in response") + } + if !errors.Is(err, ErrCreateSandboxFailed) { + t.Errorf("expected ErrCreateSandboxFailed, got %v", err) + } +} diff --git a/pkg/router/utils.go b/pkg/router/utils.go deleted file mode 100644 index c75d2951..00000000 --- a/pkg/router/utils.go +++ /dev/null @@ -1,60 +0,0 @@ -package router - -import ( - "strconv" - "time" - - "github.com/gin-gonic/gin" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - sandboxv1alpha1 "sigs.k8s.io/agent-sandbox/api/v1alpha1" -) - -// ErrorResponse represents an API error response -type ErrorResponse struct { - Error string `json:"error"` - Message string `json:"message"` - Details map[string]interface{} `json:"details,omitempty"` - Timestamp time.Time `json:"timestamp"` - RequestID string `json:"requestId,omitempty"` -} - -// respondJSON sends a JSON response -func respondJSON(c *gin.Context, statusCode int, data interface{}) { - c.JSON(statusCode, data) -} - -// respondError sends an error response -func respondError(c *gin.Context, statusCode int, errorCode, message string) { - response := ErrorResponse{ - Error: errorCode, - Message: message, - Timestamp: time.Now(), - } - respondJSON(c, statusCode, response) -} - -// getIntQueryParam gets an integer value from query parameters, returns default value if not present -func getIntQueryParam(c *gin.Context, key string, defaultValue int) int { - valueStr := c.Query(key) - if valueStr == "" { - return defaultValue - } - - value, err := strconv.Atoi(valueStr) - if err != nil { - return defaultValue - } - - return value -} - -// getSandboxStatus extracts status from Sandbox CRD conditions -func getSandboxStatus(sandbox *sandboxv1alpha1.Sandbox) string { - // Check conditions for Ready status - for _, condition := range sandbox.Status.Conditions { - if condition.Type == string(sandboxv1alpha1.SandboxConditionReady) && condition.Status == metav1.ConditionTrue { - return "running" - } - } - return "paused" -} diff --git a/pkg/workloadmanager/garbage_collection.go b/pkg/workloadmanager/garbage_collection.go index 5a3383f1..77f4ed9c 100644 --- a/pkg/workloadmanager/garbage_collection.go +++ b/pkg/workloadmanager/garbage_collection.go @@ -6,10 +6,11 @@ import ( "log" "time" - "github.com/volcano-sh/agentcube/pkg/redis" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilerrors "k8s.io/apimachinery/pkg/util/errors" + + "github.com/volcano-sh/agentcube/pkg/redis" ) const ( diff --git a/pkg/workloadmanager/handlers.go b/pkg/workloadmanager/handlers.go index 4e624b3d..3c54454b 100644 --- a/pkg/workloadmanager/handlers.go +++ b/pkg/workloadmanager/handlers.go @@ -11,10 +11,11 @@ import ( "github.com/gin-gonic/gin" redisv9 "github.com/redis/go-redis/v9" - "github.com/volcano-sh/agentcube/pkg/common/types" - "github.com/volcano-sh/agentcube/pkg/redis" sandboxv1alpha1 "sigs.k8s.io/agent-sandbox/api/v1alpha1" extensionsv1alpha1 "sigs.k8s.io/agent-sandbox/extensions/api/v1alpha1" + + "github.com/volcano-sh/agentcube/pkg/common/types" + "github.com/volcano-sh/agentcube/pkg/redis" ) // handleHealth handles health check requests diff --git a/pkg/workloadmanager/server.go b/pkg/workloadmanager/server.go index 67e78f04..ba01f812 100644 --- a/pkg/workloadmanager/server.go +++ b/pkg/workloadmanager/server.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" redisv9 "github.com/redis/go-redis/v9" + "github.com/volcano-sh/agentcube/pkg/redis" ) From 5eab3347ddb2fb118585e8462c10952b4e5e0fb5 Mon Sep 17 00:00:00 2001 From: LeslieKuo <676365950@qq.com> Date: Thu, 11 Dec 2025 10:00:47 +0800 Subject: [PATCH 6/6] add readinessProbe --- pkg/common/types/sandbox.go | 39 ++++++++++-- pkg/workloadmanager/handlers.go | 4 +- pkg/workloadmanager/workload_builder.go | 81 ++++++++++++++++++++++--- 3 files changed, 110 insertions(+), 14 deletions(-) diff --git a/pkg/common/types/sandbox.go b/pkg/common/types/sandbox.go index d69886e8..ec9ff360 100644 --- a/pkg/common/types/sandbox.go +++ b/pkg/common/types/sandbox.go @@ -26,11 +26,40 @@ type SandboxEntryPoints struct { } type CreateSandboxRequest struct { - Kind string `json:"kind"` - Name string `json:"name"` - Namespace string `json:"namespace"` - Auth Auth `json:"auth"` - Metadata map[string]string `json:"metadata"` + Kind string `json:"kind"` + Name string `json:"name"` + Namespace string `json:"namespace"` + Auth Auth `json:"auth"` + Metadata map[string]string `json:"metadata"` + ReadinessProbe *ReadinessProbe `json:"readinessProbe,omitempty"` +} + +type ReadinessProbe struct { + InitialDelaySeconds int32 `json:"initialDelaySeconds,omitempty"` + PeriodSeconds int32 `json:"periodSeconds,omitempty"` + SuccessThreshold int32 `json:"successThreshold,omitempty"` + FailureThreshold int32 `json:"failureThreshold,omitempty"` + TimeoutSeconds int32 `json:"timeoutSeconds,omitempty"` + TCPSocket *TCPSocketAction `json:"tcpSocket,omitempty"` + HTTPGet *HTTPGetAction `json:"httpGet,omitempty"` + Exec *ExecAction `json:"exec,omitempty"` +} + +type TCPSocketAction struct { + Port int `json:"port"` + Host string `json:"host,omitempty"` +} + +type HTTPGetAction struct { + Path string `json:"path,omitempty"` + Port int `json:"port"` + Host string `json:"host,omitempty"` + Scheme string `json:"scheme,omitempty"` + HTTPHeaders map[string]string `json:"httpHeaders,omitempty"` +} + +type ExecAction struct { + Command []string `json:"command,omitempty"` } type Auth struct { diff --git a/pkg/workloadmanager/handlers.go b/pkg/workloadmanager/handlers.go index 3c54454b..44c6f671 100644 --- a/pkg/workloadmanager/handlers.go +++ b/pkg/workloadmanager/handlers.go @@ -56,9 +56,9 @@ func (s *Server) handleCreateSandbox(c *gin.Context) { var err error switch createAgentRequest.Kind { case types.AgentRuntimeKind: - sandbox, externalInfo, err = buildSandboxByAgentRuntime(createAgentRequest.Namespace, createAgentRequest.Name, s.informers) + sandbox, externalInfo, err = buildSandboxByAgentRuntime(createAgentRequest.Namespace, createAgentRequest.Name, s.informers, createAgentRequest.ReadinessProbe) case types.CodeInterpreterKind: - sandbox, sandboxClaim, externalInfo, err = buildSandboxByCodeInterpreter(createAgentRequest.Namespace, createAgentRequest.Name, s.informers) + sandbox, sandboxClaim, externalInfo, err = buildSandboxByCodeInterpreter(createAgentRequest.Namespace, createAgentRequest.Name, s.informers, createAgentRequest.ReadinessProbe) default: log.Printf("invalid request kind: %v", createAgentRequest.Kind) respondError(c, http.StatusBadRequest, "INVALID_REQUEST", fmt.Sprintf("invalid request kind: %v", createAgentRequest.Kind)) diff --git a/pkg/workloadmanager/workload_builder.go b/pkg/workloadmanager/workload_builder.go index e0d2c593..64b983b9 100644 --- a/pkg/workloadmanager/workload_builder.go +++ b/pkg/workloadmanager/workload_builder.go @@ -7,10 +7,12 @@ import ( "github.com/google/uuid" runtimev1alpha1 "github.com/volcano-sh/agentcube/pkg/apis/runtime/v1alpha1" + "github.com/volcano-sh/agentcube/pkg/common/types" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/utils/ptr" sandboxv1alpha1 "sigs.k8s.io/agent-sandbox/api/v1alpha1" extensionsv1alpha1 "sigs.k8s.io/agent-sandbox/extensions/api/v1alpha1" @@ -26,6 +28,7 @@ type buildSandboxParams struct { podSpec corev1.PodSpec podLabels map[string]string podAnnotations map[string]string + readinessProbe *types.ReadinessProbe } type buildSandboxClaimParams struct { @@ -79,6 +82,68 @@ func buildSandboxObject(params *buildSandboxParams) *sandboxv1alpha1.Sandbox { } sandbox.Spec.PodTemplate.ObjectMeta.Labels[SessionIdLabelKey] = params.sessionID sandbox.Spec.PodTemplate.ObjectMeta.Labels["sandbox-name"] = params.sandboxName + + // Handle Readiness Probe + if len(sandbox.Spec.PodTemplate.Spec.Containers) > 0 { + container := &sandbox.Spec.PodTemplate.Spec.Containers[0] + + if params.readinessProbe != nil { + probe := &corev1.Probe{ + InitialDelaySeconds: params.readinessProbe.InitialDelaySeconds, + PeriodSeconds: params.readinessProbe.PeriodSeconds, + SuccessThreshold: params.readinessProbe.SuccessThreshold, + FailureThreshold: params.readinessProbe.FailureThreshold, + TimeoutSeconds: params.readinessProbe.TimeoutSeconds, + } + + if params.readinessProbe.TCPSocket != nil { + probe.ProbeHandler = corev1.ProbeHandler{ + TCPSocket: &corev1.TCPSocketAction{ + Port: intstr.FromInt(params.readinessProbe.TCPSocket.Port), + Host: params.readinessProbe.TCPSocket.Host, + }, + } + } else if params.readinessProbe.HTTPGet != nil { + headers := []corev1.HTTPHeader{} + for k, v := range params.readinessProbe.HTTPGet.HTTPHeaders { + headers = append(headers, corev1.HTTPHeader{Name: k, Value: v}) + } + probe.ProbeHandler = corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: params.readinessProbe.HTTPGet.Path, + Port: intstr.FromInt(params.readinessProbe.HTTPGet.Port), + Host: params.readinessProbe.HTTPGet.Host, + Scheme: corev1.URIScheme(params.readinessProbe.HTTPGet.Scheme), + HTTPHeaders: headers, + }, + } + } else if params.readinessProbe.Exec != nil { + probe.ProbeHandler = corev1.ProbeHandler{ + Exec: &corev1.ExecAction{ + Command: params.readinessProbe.Exec.Command, + }, + } + } + container.ReadinessProbe = probe + } else if container.ReadinessProbe == nil { + // Add default TCP probe if ports are available + if len(container.Ports) > 0 { + container.ReadinessProbe = &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + TCPSocket: &corev1.TCPSocketAction{ + Port: intstr.FromInt(int(container.Ports[0].ContainerPort)), + }, + }, + InitialDelaySeconds: 5, + PeriodSeconds: 10, + SuccessThreshold: 1, + FailureThreshold: 3, + TimeoutSeconds: 1, + } + } + } + } + return sandbox } @@ -103,7 +168,7 @@ func buildSandboxClaimObject(params *buildSandboxClaimParams) *extensionsv1alpha return sandboxClaim } -func buildSandboxByAgentRuntime(namespace string, name string, ifm *Informers) (*sandboxv1alpha1.Sandbox, *sandboxExternalInfo, error) { +func buildSandboxByAgentRuntime(namespace string, name string, ifm *Informers, readinessProbe *types.ReadinessProbe) (*sandboxv1alpha1.Sandbox, *sandboxExternalInfo, error) { agentRuntimeKey := namespace + "/" + name runtimeObj, exists, err := ifm.AgentRuntimeInformer.GetStore().GetByKey(agentRuntimeKey) if err != nil { @@ -131,11 +196,12 @@ func buildSandboxByAgentRuntime(namespace string, name string, ifm *Informers) ( sessionID := uuid.New().String() sandboxName := "agent-runtime-" + uuid.New().String() buildParams := &buildSandboxParams{ - namespace: namespace, - workloadName: name, - sandboxName: sandboxName, - sessionID: sessionID, - podSpec: agentRuntimeObj.Spec.Template.Spec, + namespace: namespace, + workloadName: name, + sandboxName: sandboxName, + sessionID: sessionID, + podSpec: agentRuntimeObj.Spec.Template.Spec, + readinessProbe: readinessProbe, } if agentRuntimeObj.Spec.MaxSessionDuration != nil { buildParams.ttl = agentRuntimeObj.Spec.MaxSessionDuration.Duration @@ -151,7 +217,7 @@ func buildSandboxByAgentRuntime(namespace string, name string, ifm *Informers) ( return sandbox, externalInfo, nil } -func buildSandboxByCodeInterpreter(namespace string, codeInterpreterName string, ifm *Informers) (*sandboxv1alpha1.Sandbox, *extensionsv1alpha1.SandboxClaim, *sandboxExternalInfo, error) { +func buildSandboxByCodeInterpreter(namespace string, codeInterpreterName string, ifm *Informers, readinessProbe *types.ReadinessProbe) (*sandboxv1alpha1.Sandbox, *extensionsv1alpha1.SandboxClaim, *sandboxExternalInfo, error) { codeInterpreterKey := namespace + "/" + codeInterpreterName runtimeObj, exists, err := ifm.CodeInterpreterInformer.GetStore().GetByKey(codeInterpreterKey) if err != nil { @@ -220,6 +286,7 @@ func buildSandboxByCodeInterpreter(namespace string, codeInterpreterName string, podSpec: podSpec, podLabels: codeInterpreterObj.Spec.Template.Labels, podAnnotations: codeInterpreterObj.Spec.Template.Annotations, + readinessProbe: readinessProbe, } if codeInterpreterObj.Spec.MaxSessionDuration != nil { buildParams.ttl = codeInterpreterObj.Spec.MaxSessionDuration.Duration