diff --git a/Dockerfile.router b/Dockerfile.router new file mode 100644 index 00000000..475c7441 --- /dev/null +++ b/Dockerfile.router @@ -0,0 +1,41 @@ +# Multi-stage build for agentcube-router +FROM golang:1.24.9-alpine AS builder + +# Build arguments for multi-architecture support +ARG TARGETOS=linux +ARG TARGETARCH + +WORKDIR /workspace + +# Copy go mod files +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source code +COPY cmd/ cmd/ +COPY pkg/ pkg/ +COPY client-go/ client-go/ + +# Build with dynamic architecture support +# Supports amd64, arm64, arm/v7, etc. +RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} \ + go build -o agentcube-router ./cmd/router + +# Runtime image +FROM alpine:3.19 + +RUN apk --no-cache add ca-certificates + +WORKDIR /app + +# Copy binary from builder +COPY --from=builder /workspace/agentcube-router . + +# Run as non-root user +RUN adduser -D -u 1000 router +USER router + +EXPOSE 8080 + +ENTRYPOINT ["/app/agentcube-router"] +CMD ["--port=8080", "--debug"] diff --git a/Makefile b/Makefile index 33bb16d4..63e750f6 100644 --- a/Makefile +++ b/Makefile @@ -72,6 +72,10 @@ build: generate ## Build agentcube-apiserver binary @echo "Building agentcube-apiserver..." go build -o bin/agentcube-apiserver ./cmd/workload-manager +build-router: ## Build agentcube-router binary + @echo "Building agentcube-router..." + go build -o bin/agentcube-router ./cmd/router + build-agentd: generate ## Build agentd binary @echo "Building agentd..." go build -o bin/agentd ./cmd/agentd @@ -80,30 +84,27 @@ build-test-tunnel: ## Build test-tunnel tool @echo "Building test-tunnel..." go build -o bin/test-tunnel ./cmd/test-tunnel -build-all: build build-agentd build-test-tunnel ## Build all binaries +build-all: build build-router build-agentd build-test-tunnel ## Build all binaries # Run server (development mode) run: @echo "Running agentcube-apiserver..." 
go run ./cmd/workload-manager/main.go \ --port=8080 \ - --ssh-username=sandbox \ - --ssh-port=22 + --debug # Run server (with kubeconfig) run-local: @echo "Running agentcube-apiserver with local kubeconfig..." go run ./cmd/workload-manager/main.go \ --port=8080 \ - --kubeconfig=${HOME}/.kube/config \ - --ssh-username=sandbox \ - --ssh-port=22 + --debug # Clean build artifacts clean: @echo "Cleaning..." rm -rf bin/ - rm -f agentcube-apiserver agentd + rm -f agentcube-router agentd # Install dependencies deps: @@ -138,11 +139,12 @@ lint: golangci-lint ## Run golangci-lint # Install to system install: build - @echo "Installing agentcube-apiserver..." - sudo cp bin/agentcube-apiserver /usr/local/bin/ + @echo "Installing agentcube-router..." + sudo cp bin/agentcube-router /usr/local/bin/ # Docker image variables APISERVER_IMAGE ?= agentcube-apiserver:latest +ROUTER_IMAGE ?= agentcube-router:latest IMAGE_REGISTRY ?= "" # Docker and Kubernetes targets @@ -150,6 +152,10 @@ docker-build: @echo "Building Docker image..." docker build -t $(APISERVER_IMAGE) . +docker-build-router: ## Build router Docker image + @echo "Building router Docker image..." + docker build -f Dockerfile.router -t $(ROUTER_IMAGE) . + # Multi-architecture build (supports amd64, arm64) docker-buildx: @echo "Building multi-architecture Docker image..." @@ -177,15 +183,15 @@ docker-push: docker-build k8s-deploy: @echo "Deploying to Kubernetes..." - kubectl apply -f k8s/agentcube-apiserver.yaml + kubectl apply -f k8s/agentcube-router.yaml k8s-delete: @echo "Deleting from Kubernetes..." - kubectl delete -f k8s/agentcube-apiserver.yaml + kubectl delete -f k8s/agentcube-router.yaml k8s-logs: @echo "Showing logs..." 
- kubectl logs -n agentcube -l app=agentcube-apiserver -f + kubectl logs -n agentcube -l app=agentcube-router -f # Load image to kind cluster kind-load: diff --git a/cmd/router/main.go b/cmd/router/main.go new file mode 100644 index 00000000..b55e7939 --- /dev/null +++ b/cmd/router/main.go @@ -0,0 +1,75 @@ +package main + +import ( + "context" + "flag" + "log" + "os" + "os/signal" + "syscall" + + "github.com/volcano-sh/agentcube/pkg/router" +) + +func main() { + var ( + port = flag.String("port", "8080", "Router API server port") + enableTLS = flag.Bool("enable-tls", false, "Enable TLS (HTTPS)") + tlsCert = flag.String("tls-cert", "", "Path to TLS certificate file") + tlsKey = flag.String("tls-key", "", "Path to TLS key file") + debug = flag.Bool("debug", true, "Enable debug mode") + maxConcurrentRequests = flag.Int("max-concurrent-requests", 1000, "Maximum number of concurrent requests") + requestTimeout = flag.Int("request-timeout", 30, "Request timeout in seconds") + maxIdleConns = flag.Int("max-idle-conns", 100, "Maximum number of idle connections") + maxConnsPerHost = flag.Int("max-conns-per-host", 10, "Maximum number of connections per host") + ) + + // Parse command line flags + flag.Parse() + + // Create Router API server configuration + config := &router.Config{ + Port: *port, + Debug: *debug, + EnableTLS: *enableTLS, + TLSCert: *tlsCert, + TLSKey: *tlsKey, + MaxConcurrentRequests: *maxConcurrentRequests, + RequestTimeout: *requestTimeout, + MaxIdleConns: *maxIdleConns, + MaxConnsPerHost: *maxConnsPerHost, + } + + // Create Router API server + server, err := router.NewServer(config) + if err != nil { + log.Fatalf("Failed to create Router API server: %v", err) + } + + // Setup signal handling with context cancellation + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + // Start Router API server in goroutine + errCh := make(chan error, 1) + go func() { + log.Printf("Starting agentcube Router server 
on port %s", *port) + if err := server.Start(ctx); err != nil { + errCh <- err + } + }() + + // Wait for signal or error + select { + case <-ctx.Done(): + log.Println("Received shutdown signal, shutting down gracefully...") + // Cancel the context to trigger server shutdown + cancel() + // Wait for server goroutine to exit after graceful shutdown is complete + <-errCh + case err := <-errCh: + log.Fatalf("Server error: %v", err) + } + + log.Println("Router server stopped") +} diff --git a/docs/proposal/router-design.md b/docs/proposal/router-design.md new file mode 100644 index 00000000..9df17add --- /dev/null +++ b/docs/proposal/router-design.md @@ -0,0 +1,295 @@ +# Router Submodule Design Document + +## 1. Overview + +Router apiserver is responsible for receiving user HTTP requests and forwarding them to the corresponding Sandbox. Router focuses on high-performance request routing, while session and sandbox management is handled by SessionManager. + +## 2. Architecture Design + +### 2.1 Overall Architecture Flow + +```mermaid +graph TB + Client[Client] --> Router[Router API Server] + Router --> SessionMgr[SessionManager Interface] + Router --> Sandbox1[Sandbox 1] + Router --> Sandbox2[Sandbox 2] + Router --> SandboxN[Sandbox N] + + subgraph "Router Core Components" + Router + SessionMgr + end + + subgraph "Sandbox Cluster" + Sandbox1 + Sandbox2 + SandboxN + end +``` + +### 2.2 Request Routing Flow + +```mermaid +sequenceDiagram + participant C as Client + participant R as Router + participant SM as SessionManager (Mock) + participant SB as Sandbox + + C->>R: HTTP Request with x-agentcube-session-id + + alt session-id exists + R->>SM: getSandboxInfoBySessionId(session-id) + SM->>R: return (endpoint, session-id, nil) + else session-id is empty + R->>SM: getSandboxInfoBySessionId("") + SM->>R: return (endpoint, new-session-id, nil) + end + + alt get sandbox success + R->>SB: forward request to sandbox endpoint + SB->>R: return response + R->>C: forward response + 
session-id header + else get sandbox failed + R->>C: return error response + end +``` + +## 3. Detailed Design + +### 3.1 SessionManager Interface Definition + +Router obtains Sandbox information through the SessionManager interface, which can be implemented with Mock: + +```go +type SessionManager interface { + // Get sandbox information based on session-id + // When sessionId is empty, create a new session + GetSandboxInfoBySessionId(sessionId string) (endpoint string, newSessionId string, err error) +} + +// Mock implementation example +type MockSessionManager struct { + sandboxEndpoints []string + currentIndex int +} + +func (m *MockSessionManager) GetSandboxInfoBySessionId(sessionId string) (string, string, error) { + if sessionId == "" { + sessionId = generateNewSessionId() + } + + // Simple round-robin sandbox selection + endpoint := m.sandboxEndpoints[m.currentIndex%len(m.sandboxEndpoints)] + m.currentIndex++ + + return endpoint, sessionId, nil +} +``` + +### 3.2 Supported Request Types + +Uses Gin framework to provide HTTP Server services, handling two types of requests: + +1. **Agent Invoke Requests** + ``` + :/v1/namespaces/{agentNamespace}/agent-runtimes/{agentName}/invocations/ + ``` + +2. 
**Code Interpreter Invoke Requests** + ``` + :/v1/namespaces/{namespace}/code-interpreters/{name}/invocations/ + ``` + +### 3.3 Request Processing Flow + +```mermaid +flowchart TD + Start([Receive HTTP Request]) --> ValidateReq{Validate Request Format} + ValidateReq -->|Invalid| ReturnBadRequest[Return 400 Bad Request] + ValidateReq -->|Valid| ExtractSessionId[Extract x-agentcube-session-id] + + ExtractSessionId --> GetSandbox[Call SessionMgr.GetSandboxInfoBySessionId] + GetSandbox --> CheckResult{Check Result} + + CheckResult -->|Success| ForwardRequest[Forward Request to Sandbox] + CheckResult -->|Interface Error| ReturnInternalError[Return 500 Internal Server Error] + + ForwardRequest --> CheckSandboxResponse{Check Sandbox Response} + CheckSandboxResponse -->|Success| ReturnSuccess[Return Success Response + Session ID] + CheckSandboxResponse -->|Timeout| ReturnTimeout[Return 504 Gateway Timeout] + CheckSandboxResponse -->|Connection Failed| ReturnBadGateway[Return 502 Bad Gateway] + CheckSandboxResponse -->|Other Error| ReturnSandboxError[Return Sandbox Error] + + ReturnBadRequest --> End([End]) + ReturnInternalError --> End + ReturnSuccess --> End + ReturnTimeout --> End + ReturnBadGateway --> End + ReturnSandboxError --> End +``` + +### 3.4 Core Requirements + +1. **High-Performance Routing**: Fast routing to corresponding sandbox based on session-id +2. **Session Integration**: Seamless collaboration with SessionManager, supporting dynamic sandbox creation +3. **Long Connection Support**: Support for long-running requests such as code execution and file operations +4. **Simple Design**: Focus on core routing functionality, avoid over-engineering +5. 
**Graceful Shutdown**: Reference E2B's graceful shutdown process to ensure no request loss + +### 3.5 Design Goals + +- **High Performance**: Millisecond-level routing latency, support for high concurrency +- **High Availability**: Stateless design, support for horizontal scaling +- **Observability**: Complete monitoring, logging, and tracing system + +## 4. HTTP Response Handling + +### 4.1 Success Responses + +| Status Code | Scenario | Response Headers | Response Body | +|-------------|----------|------------------|---------------| +| 200 OK | Request processed successfully | `x-agentcube-session-id: ` | Original response from Sandbox | +| 201 Created | Resource created successfully | `x-agentcube-session-id: ` | Created resource information | +| 202 Accepted | Async request accepted | `x-agentcube-session-id: ` | Task status information | + +### 4.2 Client Error Responses + +| Status Code | Scenario | Response Body Example | +|-------------|----------|----------------------| +| 400 Bad Request | Invalid request format, invalid parameters | `{"error": "invalid request format", "code": "INVALID_REQUEST"}` | +| 401 Unauthorized | Authentication failed | `{"error": "authentication required", "code": "AUTH_REQUIRED"}` | +| 403 Forbidden | Insufficient permissions | `{"error": "insufficient permissions", "code": "PERMISSION_DENIED"}` | +| 404 Not Found | Session or resource not found | `{"error": "session not found", "code": "SESSION_NOT_FOUND"}` | +| 409 Conflict | Resource conflict | `{"error": "resource conflict", "code": "RESOURCE_CONFLICT"}` | +| 429 Too Many Requests | Rate limit exceeded | `{"error": "rate limit exceeded", "code": "RATE_LIMIT_EXCEEDED"}` | + +### 4.3 Server Error Responses + +| Status Code | Scenario | Response Body Example | +|-------------|----------|----------------------| +| 500 Internal Server Error | Router internal error | `{"error": "internal server error", "code": "INTERNAL_ERROR"}` | +| 502 Bad Gateway | Sandbox connection failed | 
`{"error": "sandbox unreachable", "code": "SANDBOX_UNREACHABLE"}` | +| 503 Service Unavailable | Sandbox unavailable or overloaded | `{"error": "sandbox unavailable", "code": "SANDBOX_UNAVAILABLE"}` | +| 504 Gateway Timeout | Sandbox response timeout | `{"error": "sandbox timeout", "code": "SANDBOX_TIMEOUT"}` | + +### 4.4 Error Handling Flow + +```mermaid +flowchart TD + Error[Error Occurred] --> CheckErrorType{Check Error Type} + + CheckErrorType -->|Request Validation Error| ClientError[Client Error 4xx] + CheckErrorType -->|SessionMgr Error| CheckSessionError{Session Error Type} + CheckErrorType -->|Sandbox Error| CheckSandboxError{Sandbox Error Type} + CheckErrorType -->|Router Internal Error| ServerError[Server Error 5xx] + + CheckSessionError -->|Session Not Found| Return404[404 Not Found] + CheckSessionError -->|Session Creation Failed| Return500[500 Internal Error] + CheckSessionError -->|Permission Error| Return403[403 Forbidden] + + CheckSandboxError -->|Connection Failed| Return502[502 Bad Gateway] + CheckSandboxError -->|Response Timeout| Return504[504 Gateway Timeout] + CheckSandboxError -->|Service Unavailable| Return503[503 Service Unavailable] + CheckSandboxError -->|Sandbox Internal Error| ForwardError[Forward Sandbox Error] + + ClientError --> LogError[Log Error] + ServerError --> LogError + Return404 --> LogError + Return500 --> LogError + Return403 --> LogError + Return502 --> LogError + Return504 --> LogError + Return503 --> LogError + ForwardError --> LogError + + LogError --> UpdateMetrics[Update Metrics] + UpdateMetrics --> ReturnResponse[Return Error Response] +``` + +## 5. 
Performance and Monitoring + +### 5.1 Performance Metrics + +- **Routing Latency**: Target < 5ms (P99) +- **Throughput**: Target > 10,000 RPS +- **Concurrent Connections**: Support > 50,000 concurrent connections +- **Memory Usage**: < 1GB (steady state) + +### 5.2 Monitoring Metrics + +```mermaid +graph LR + subgraph "Request Metrics" + A[Total Requests] + B[Request Latency] + C[Error Rate] + D[Concurrency] + end + + subgraph "Routing Metrics" + E[Routing Success Rate] + F[Session Creation Rate] + G[Sandbox Hit Rate] + end + + subgraph "System Metrics" + H[CPU Usage] + I[Memory Usage] + J[Network I/O] + K[Connection Pool Status] + end +``` + +### 5.3 Logging + +- **Access Logs**: Record all HTTP requests +- **Error Logs**: Record all errors and exceptions +- **Performance Logs**: Record key performance metrics +- **Audit Logs**: Record important operations and state changes + +## 6. Deployment and Operations + +### 6.1 Deployment Architecture + +```mermaid +graph TB + LB[Load Balancer] --> R1[Router Instance 1] + LB --> R2[Router Instance 2] + LB --> R3[Router Instance N] + + R1 --> SM[SessionManager Interface] + R2 --> SM + R3 --> SM + + R1 --> SB1[Sandbox Cluster] + R2 --> SB1 + R3 --> SB1 + + subgraph "Router Layer" + R1 + R2 + R3 + end + + subgraph "Interface Layer" + SM + end + + subgraph "Sandbox Layer" + SB1 + end +``` + +### 6.2 Configuration Management + +- **Environment Configuration**: Support for multi-environment configuration (dev/staging/prod) +- **Dynamic Configuration**: Support for runtime configuration updates +- **Configuration Validation**: Validate configuration integrity at startup + +### 6.3 Health Checks + +- **Liveness Check**: `/health/live` +- **Readiness Check**: `/health/ready` +- **Dependency Check**: Verify connectivity to SessionManager and Sandbox diff --git a/go.mod b/go.mod index 72b43efb..8acc13c3 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( k8s.io/api v0.34.1 k8s.io/apimachinery v0.34.1 k8s.io/client-go v0.34.1 
+ k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 sigs.k8s.io/agent-sandbox v0.1.0 sigs.k8s.io/controller-runtime v0.22.2 ) @@ -97,7 +98,6 @@ require ( k8s.io/apiextensions-apiserver v0.34.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect - k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect diff --git a/k8s/agentcube-apiserver.yaml b/k8s/agentcube-router.yaml similarity index 77% rename from k8s/agentcube-apiserver.yaml rename to k8s/agentcube-router.yaml index 0ec7ac16..01ada16a 100644 --- a/k8s/agentcube-apiserver.yaml +++ b/k8s/agentcube-router.yaml @@ -1,5 +1,5 @@ -# All-in-one deployment file for agentcube-apiserver -# Apply with: kubectl apply -f k8s/agentcube-apiserver.yaml +# All-in-one deployment file for agentcube-router +# Apply with: kubectl apply -f k8s/agentcube-router.yaml --- apiVersion: v1 @@ -11,14 +11,14 @@ metadata: apiVersion: v1 kind: ServiceAccount metadata: - name: agentcube-apiserver + name: agentcube-router namespace: agentcube --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: - name: agentcube-apiserver + name: agentcube-router rules: - apiGroups: ["agents.x-k8s.io"] resources: ["sandboxes"] @@ -43,10 +43,13 @@ rules: verbs: ["update"] - apiGroups: ["runtime.agentcube.volcano.sh"] resources: ["agentruntimes"] - verbs: ["get", "list", "watch"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] - apiGroups: ["runtime.agentcube.volcano.sh"] resources: ["agentruntimes/status"] - verbs: ["update", "patch"] + verbs: ["get", "update", "patch"] + - apiGroups: ["runtime.agentcube.volcano.sh"] + resources: ["agentruntimes/finalizers"] + verbs: ["update"] - apiGroups: [""] resources: ["pods"] verbs: ["get", "list", "watch"] @@ -58,38 +61,38 @@ rules: apiVersion: 
rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding metadata: - name: agentcube-apiserver + name: agentcube-router roleRef: apiGroup: rbac.authorization.k8s.io kind: ClusterRole - name: agentcube-apiserver + name: agentcube-router subjects: - kind: ServiceAccount - name: agentcube-apiserver + name: agentcube-router namespace: agentcube --- apiVersion: apps/v1 kind: Deployment metadata: - name: agentcube-apiserver + name: agentcube-router namespace: agentcube labels: - app: agentcube-apiserver + app: agentcube-router spec: replicas: 1 selector: matchLabels: - app: agentcube-apiserver + app: agentcube-router template: metadata: labels: - app: agentcube-apiserver + app: agentcube-router spec: - serviceAccountName: agentcube-apiserver + serviceAccountName: agentcube-router containers: - - name: agentcube-apiserver - image: agentcube-apiserver:latest + - name: agentcube-router + image: agentcube-router:latest imagePullPolicy: IfNotPresent ports: - name: http @@ -100,9 +103,11 @@ spec: value: "127.0.0.1:6379" - name: REDIS_PASSWORD value: "" + - name: WORKLOAD_MGR_URL + value: "http://agentcube-workload-manager:8080" args: - --port=8080 - - --runtime-class-name= + - --debug resources: requests: cpu: 100m @@ -127,10 +132,10 @@ spec: apiVersion: v1 kind: Service metadata: - name: agentcube-apiserver + name: agentcube-router namespace: agentcube labels: - app: agentcube-apiserver + app: agentcube-router spec: type: ClusterIP ports: @@ -139,4 +144,4 @@ spec: protocol: TCP name: http selector: - app: agentcube-apiserver + app: agentcube-router diff --git a/pkg/common/types/sandbox.go b/pkg/common/types/sandbox.go index d69886e8..ec9ff360 100644 --- a/pkg/common/types/sandbox.go +++ b/pkg/common/types/sandbox.go @@ -26,11 +26,40 @@ type SandboxEntryPoints struct { } type CreateSandboxRequest struct { - Kind string `json:"kind"` - Name string `json:"name"` - Namespace string `json:"namespace"` - Auth Auth `json:"auth"` - Metadata map[string]string `json:"metadata"` + Kind 
// Probe-related request types. CreateSandboxRequest carries an optional
// *ReadinessProbe in its "readinessProbe" JSON field.
// NOTE(review): the field names mirror the Kubernetes core/v1 Probe type;
// presumably the consumer maps them onto it — confirm against the caller.
type (
	// ReadinessProbe holds the timing/threshold settings of a readiness
	// check together with at most three optional handlers (TCP, HTTP, exec).
	// Every field is dropped from the JSON encoding when it is zero/nil.
	ReadinessProbe struct {
		InitialDelaySeconds int32            `json:"initialDelaySeconds,omitempty"`
		PeriodSeconds       int32            `json:"periodSeconds,omitempty"`
		SuccessThreshold    int32            `json:"successThreshold,omitempty"`
		FailureThreshold    int32            `json:"failureThreshold,omitempty"`
		TimeoutSeconds      int32            `json:"timeoutSeconds,omitempty"`
		TCPSocket           *TCPSocketAction `json:"tcpSocket,omitempty"`
		HTTPGet             *HTTPGetAction   `json:"httpGet,omitempty"`
		Exec                *ExecAction      `json:"exec,omitempty"`
	}

	// TCPSocketAction identifies a TCP port (and optionally a host) to probe.
	// Port is always encoded, even when it is 0.
	TCPSocketAction struct {
		Port int    `json:"port"`
		Host string `json:"host,omitempty"`
	}

	// HTTPGetAction describes an HTTP GET probe: target path, port, optional
	// host and scheme, and extra request headers keyed by header name.
	// Port is always encoded, even when it is 0.
	HTTPGetAction struct {
		Path        string            `json:"path,omitempty"`
		Port        int               `json:"port"`
		Host        string            `json:"host,omitempty"`
		Scheme      string            `json:"scheme,omitempty"`
		HTTPHeaders map[string]string `json:"httpHeaders,omitempty"`
	}

	// ExecAction holds the command line of an exec-style probe.
	ExecAction struct {
		Command []string `json:"command,omitempty"`
	}
)
limits the number of concurrent requests (0 = unlimited) + MaxConcurrentRequests int + + // RequestTimeout sets the timeout for individual requests + RequestTimeout int // seconds + + // MaxIdleConns sets the maximum number of idle connections in the connection pool + MaxIdleConns int + + // MaxConnsPerHost sets the maximum number of connections per host + MaxConnsPerHost int + + // Redis configuration + // EnableRedis enables Redis session activity tracking + EnableRedis bool + + // RedisAddr is the Redis server address (e.g., "localhost:6379") + RedisAddr string + + // RedisPassword is the Redis password (optional) + RedisPassword string + + // RedisDB is the Redis database number + RedisDB int +} diff --git a/pkg/router/handlers.go b/pkg/router/handlers.go new file mode 100644 index 00000000..c540361d --- /dev/null +++ b/pkg/router/handlers.go @@ -0,0 +1,211 @@ +package router + +import ( + "fmt" + "log" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +// handleHealth handles health check requests +func (s *Server) handleHealth(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": "healthy", + }) +} + +// handleHealthLive handles liveness probe +func (s *Server) handleHealthLive(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": "alive", + }) +} + +// handleHealthReady handles readiness probe +func (s *Server) handleHealthReady(c *gin.Context) { + // Check if SessionManager is available + if s.sessionManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "status": "not ready", + "error": "session manager not available", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": "ready", + }) +} + +// handleInvoke is a private helper function that handles invocation requests for both agents and code interpreters +func (s *Server) handleInvoke(c *gin.Context, namespace, name, path, kind string) { + log.Printf("%s invoke request: namespace=%s, name=%s, path=%s", kind, namespace, 
name, path) + + // Extract session ID from header + sessionID := c.GetHeader("x-agentcube-session-id") + + // Get sandbox info from session manager + sandbox, err := s.sessionManager.GetSandboxBySession(c.Request.Context(), sessionID, namespace, name, kind) + if err != nil { + log.Printf("Failed to get sandbox info: %v, session id %s", err, sessionID) + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Invalid session id %s", sessionID), + "code": "BadRequest", + }) + return + } + + // Extract endpoint from sandbox - find matching entry point by path + var endpoint string + for _, ep := range sandbox.EntryPoints { + if strings.HasPrefix(path, ep.Path) { + // Only add protocol if not already present + if ep.Protocol != "" && !strings.Contains(ep.Endpoint, "://") { + endpoint = strings.ToLower(ep.Protocol) + "://" + ep.Endpoint + } else { + endpoint = ep.Endpoint + } + break + } + } + + // If no matching endpoint found, use the first one as fallback + if endpoint == "" { + if len(sandbox.EntryPoints) == 0 { + log.Printf("No entry points found for sandbox: %s", sandbox.SandboxID) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "internal server error", + "code": "INTERNAL_ERROR", + }) + return + } + // Only add protocol if not already present + if sandbox.EntryPoints[0].Protocol != "" && !strings.Contains(sandbox.EntryPoints[0].Endpoint, "://") { + endpoint = strings.ToLower(sandbox.EntryPoints[0].Protocol) + "://" + sandbox.EntryPoints[0].Endpoint + } else { + endpoint = sandbox.EntryPoints[0].Endpoint + } + } + + log.Printf("The selected entrypoint for session-id %s to sandbox is %s", sandbox.SessionID, endpoint) + + // Update session activity in Redis when receiving request + if sandbox.SessionID != "" && sandbox.SandboxID != "" { + if err := s.redisClient.UpdateSessionLastActivity(c.Request.Context(), sandbox.SessionID, time.Now()); err != nil { + log.Printf("Failed to update sandbox last activity for request: %v", err) + } + } + + // Forward 
request to sandbox with session ID + s.forwardToSandbox(c, endpoint, path, sandbox.SessionID) + + if err := s.redisClient.UpdateSessionLastActivity(c.Request.Context(), sandbox.SessionID, time.Now()); err != nil { + log.Printf("Failed to update sandbox last activity for request: %v", err) + } +} + +// handleAgentInvoke handles agent invocation requests +func (s *Server) handleAgentInvoke(c *gin.Context) { + namespace := c.Param("namespace") + name := c.Param("name") + path := c.Param("path") + s.handleInvoke(c, namespace, name, path, "AgentRuntime") +} + +// handleCodeInterpreterInvoke handles code interpreter invocation requests +func (s *Server) handleCodeInterpreterInvoke(c *gin.Context) { + namespace := c.Param("namespace") + name := c.Param("name") + path := c.Param("path") + s.handleInvoke(c, namespace, name, path, "CodeInterpreter") +} + +// forwardToSandbox forwards the request to the specified sandbox endpoint +func (s *Server) forwardToSandbox(c *gin.Context, endpoint, path, sessionID string) { + // Parse the target URL + targetURL, err := url.Parse(endpoint) + if err != nil { + log.Printf("Invalid sandbox endpoint: %s, error: %v", endpoint, err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "internal server error", + "code": "INTERNAL_ERROR", + }) + return + } + + // Create reverse proxy with reusable transport + proxy := httputil.NewSingleHostReverseProxy(targetURL) + + // Use the shared HTTP transport for connection pooling + proxy.Transport = s.httpTransport + + // Customize the director to modify the request + originalDirector := proxy.Director + proxy.Director = func(req *http.Request) { + originalDirector(req) + + // Set the target path + if path != "" && !strings.HasPrefix(path, "/") { + path = "/" + path + } + req.URL.Path = path + req.URL.RawPath = "" + + // Set the host header + req.Host = targetURL.Host + + // Add forwarding headers + req.Header.Set("X-Forwarded-Host", c.Request.Host) + req.Header.Set("X-Forwarded-Proto", "http") 
+ if c.Request.TLS != nil { + req.Header.Set("X-Forwarded-Proto", "https") + } + + log.Printf("Forwarding request to: %s%s", targetURL.String(), path) + } + + // Customize error handler + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + log.Printf("Proxy error: %v", err) + + // Determine error type and return appropriate response + if strings.Contains(err.Error(), "connection refused") { + c.JSON(http.StatusBadGateway, gin.H{ + "error": "sandbox unreachable", + "code": "SANDBOX_UNREACHABLE", + }) + } else if strings.Contains(err.Error(), "timeout") { + c.JSON(http.StatusGatewayTimeout, gin.H{ + "error": "sandbox timeout", + "code": "SANDBOX_TIMEOUT", + }) + } else { + c.JSON(http.StatusBadGateway, gin.H{ + "error": "sandbox unreachable", + "code": "SANDBOX_UNREACHABLE", + }) + } + } + + // Modify response + proxy.ModifyResponse = func(resp *http.Response) error { + // Always set session ID in response header + if sessionID != "" { + resp.Header.Set("x-agentcube-session-id", sessionID) + } + return nil + } + + // No timeout for invoke requests to allow long-running operations + // ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(s.config.RequestTimeout)*time.Second) + // defer cancel() + // c.Request = c.Request.WithContext(ctx) + + // Use the proxy to serve the request + proxy.ServeHTTP(c.Writer, c.Request) +} diff --git a/pkg/router/handlers_test.go b/pkg/router/handlers_test.go new file mode 100644 index 00000000..60b3cc4b --- /dev/null +++ b/pkg/router/handlers_test.go @@ -0,0 +1,485 @@ +package router + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/volcano-sh/agentcube/pkg/common/types" +) + +func init() { + // Set Gin to test mode + gin.SetMode(gin.TestMode) +} + +// Mock SessionManager for testing +type mockSessionManager struct { + sandbox *types.SandboxRedis + err error +} + +func (m *mockSessionManager) 
GetSandboxBySession(ctx context.Context, sessionID string, namespace string, name string, kind string) (*types.SandboxRedis, error) { + return m.sandbox, m.err +} + +func TestHandleHealth(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/health", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) + } + + expectedBody := `{"status":"healthy"}` + if w.Body.String() != expectedBody { + t.Errorf("Expected body %s, got %s", expectedBody, w.Body.String()) + } +} + +func TestHandleHealthLive(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/health/live", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) + } + + expectedBody := `{"status":"alive"}` + if w.Body.String() != expectedBody { + t.Errorf("Expected body %s, got %s", expectedBody, w.Body.String()) + } +} + +func TestHandleHealthReady(t *testing.T) { + // Set required environment 
variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + tests := []struct { + name string + sessionManager SessionManager + expectedStatusCode int + expectedBody string + }{ + { + name: "ready with session manager", + sessionManager: &mockSessionManager{}, + expectedStatusCode: http.StatusOK, + expectedBody: `{"status":"ready"}`, + }, + { + name: "not ready without session manager", + sessionManager: nil, + expectedStatusCode: http.StatusServiceUnavailable, + expectedBody: `{"error":"session manager not available","status":"not ready"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Override session manager for testing + server.sessionManager = tt.sessionManager + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/health/ready", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != tt.expectedStatusCode { + t.Errorf("Expected status code %d, got %d", tt.expectedStatusCode, w.Code) + } + + if w.Body.String() != tt.expectedBody { + t.Errorf("Expected body %s, got %s", tt.expectedBody, w.Body.String()) + } + }) + } +} + +func TestHandleInvoke_SessionManagerError(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Mock session 
manager that returns error + server.sessionManager = &mockSessionManager{ + err: errors.New("session manager error"), + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status code %d, got %d", http.StatusBadRequest, w.Code) + } +} + +func TestHandleInvoke_NoEntryPoints(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Mock session manager that returns sandbox with no entry points + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxRedis{ + SandboxID: "test-sandbox", + SessionID: "test-session", + EntryPoints: []types.SandboxEntryPoints{}, + }, + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status code %d, got %d", http.StatusInternalServerError, w.Code) + } +} + +func TestHandleAgentInvoke(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + // Create a test HTTP server to act as the sandbox + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"result":"success"}`)) + })) + defer testServer.Close() + + config := &Config{ + Port: "8080", + RequestTimeout: 30, + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Mock session manager that returns sandbox with test server endpoint + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxRedis{ + SandboxID: "test-sandbox", + SessionID: "test-session", + SandboxName: "test-sandbox", + EntryPoints: []types.SandboxEntryPoints{ + { + Endpoint: testServer.URL, + Path: "/test", + }, + }, + }, + } + + // Use real HTTP client instead of httptest.ResponseRecorder to avoid CloseNotifier panic + req, _ := http.NewRequest("POST", "/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", nil) + req.Header.Set("x-agentcube-session-id", "test-session") + + // Start a real test server + testRouterServer := httptest.NewServer(server.engine) + defer testRouterServer.Close() + + // Make real HTTP request + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Post(testRouterServer.URL+"/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", "application/json", nil) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check if session ID is set in response header + sessionID := resp.Header.Get("x-agentcube-session-id") + if sessionID != "test-session" { + t.Errorf("Expected session ID 'test-session', got '%s'", sessionID) + } +} + +func TestHandleCodeInterpreterInvoke(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + 
os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + // Create a test HTTP server to act as the sandbox + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"result":"success"}`)) + })) + defer testServer.Close() + + config := &Config{ + Port: "8080", + RequestTimeout: 30, + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Mock session manager that returns sandbox with test server endpoint + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxRedis{ + SandboxID: "test-sandbox", + SessionID: "test-session", + SandboxName: "test-sandbox", + EntryPoints: []types.SandboxEntryPoints{ + { + Endpoint: testServer.URL, + Path: "/execute", + }, + }, + }, + } + + // Use real HTTP client instead of httptest.ResponseRecorder to avoid CloseNotifier panic + testRouterServer := httptest.NewServer(server.engine) + defer testRouterServer.Close() + + // Make real HTTP request + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Post(testRouterServer.URL+"/v1/namespaces/default/code-interpreters/test-ci/invocations/execute", "application/json", nil) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check if session ID is set in response header + sessionID := resp.Header.Get("x-agentcube-session-id") + if sessionID != "test-session" { + t.Errorf("Expected session ID 'test-session', got '%s'", sessionID) + } +} + +func TestForwardToSandbox_InvalidEndpoint(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + 
os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + RequestTimeout: 30, + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Mock session manager that returns sandbox with invalid endpoint + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxRedis{ + SandboxID: "test-sandbox", + SessionID: "test-session", + SandboxName: "test-sandbox", + EntryPoints: []types.SandboxEntryPoints{ + { + Endpoint: "://invalid-url", + Path: "/test", + }, + }, + }, + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", nil) + server.engine.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status code %d, got %d", http.StatusInternalServerError, w.Code) + } +} + +func TestConcurrencyLimitMiddleware_Overload(t *testing.T) { + // Set required environment variables + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + MaxConcurrentRequests: 1, // Set to 1 to easily trigger overload + RequestTimeout: 30, + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Create a slow test server + slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer slowServer.Close() + + // Mock session manager with slow response + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxRedis{ + SandboxID: "test-sandbox", + SessionID: "test-session", + SandboxName: "test-sandbox", + 
EntryPoints: []types.SandboxEntryPoints{ + { + Endpoint: slowServer.URL, + Path: "/test", + }, + }, + }, + } + + // Start a real test server + testRouterServer := httptest.NewServer(server.engine) + defer testRouterServer.Close() + + // Start first request (will occupy the semaphore) + done := make(chan bool) + go func() { + client := &http.Client{Timeout: 5 * time.Second} + _, _ = client.Post(testRouterServer.URL+"/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", "application/json", nil) + done <- true + }() + + // Give first request time to acquire semaphore + time.Sleep(50 * time.Millisecond) + + // Try second request (should be rejected due to overload) + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Post(testRouterServer.URL+"/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", "application/json", nil) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("Expected status code %d, got %d", http.StatusServiceUnavailable, resp.StatusCode) + } + + // Wait for first request to complete + <-done +} diff --git a/pkg/router/server.go b/pkg/router/server.go new file mode 100644 index 00000000..85d6593d --- /dev/null +++ b/pkg/router/server.go @@ -0,0 +1,189 @@ +package router + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "time" + + "github.com/gin-gonic/gin" + redisv9 "github.com/redis/go-redis/v9" + + "github.com/volcano-sh/agentcube/pkg/redis" +) + +// Server is the main structure for Router apiserver +type Server struct { + config *Config + engine *gin.Engine + httpServer *http.Server + sessionManager SessionManager + redisClient redis.Client + semaphore chan struct{} // For limiting concurrent requests + httpTransport *http.Transport // Reusable HTTP transport for connection pooling +} + +// makeRedisOptions creates redis options from environment variables +func makeRedisOptions() 
(*redisv9.Options, error) { + redisAddr := os.Getenv("REDIS_ADDR") + if redisAddr == "" { + return nil, fmt.Errorf("missing env var REDIS_ADDR") + } + redisPassword := os.Getenv("REDIS_PASSWORD") + if redisPassword == "" { + return nil, fmt.Errorf("missing env var REDIS_PASSWORD") + } + redisOptions := &redisv9.Options{ + Addr: redisAddr, + Password: redisPassword, + } + return redisOptions, nil +} + +// NewServer creates a new Router API server instance +func NewServer(config *Config) (*Server, error) { + if config == nil { + return nil, fmt.Errorf("config cannot be nil") + } + + // Set default values for concurrency settings + if config.MaxConcurrentRequests <= 0 { + config.MaxConcurrentRequests = 1000 // Default limit + } + if config.RequestTimeout <= 0 { + config.RequestTimeout = 30 // Default 30 seconds + } + if config.MaxIdleConns <= 0 { + config.MaxIdleConns = 100 // Default 100 idle connections + } + if config.MaxConnsPerHost <= 0 { + config.MaxConnsPerHost = 10 // Default 10 connections per host + } + + // Initialize Redis client + redisOptions, err := makeRedisOptions() + if err != nil { + return nil, fmt.Errorf("make redis options failed: %w", err) + } + redisClient := redis.NewClient(redisOptions) + + // Create session manager with redis client + sessionManager, err := NewSessionManager(redisClient) + if err != nil { + return nil, fmt.Errorf("failed to create session manager: %w", err) + } + + // Set Gin mode based on environment + if config.Debug { + gin.SetMode(gin.DebugMode) + } else { + gin.SetMode(gin.ReleaseMode) + } + + // Create a reusable HTTP transport for connection pooling + httpTransport := &http.Transport{ + MaxIdleConns: config.MaxIdleConns, + MaxIdleConnsPerHost: config.MaxConnsPerHost, + IdleConnTimeout: 0, + DisableCompression: false, + ForceAttemptHTTP2: true, + } + + server := &Server{ + config: config, + sessionManager: sessionManager, + redisClient: redisClient, + semaphore: make(chan struct{}, config.MaxConcurrentRequests), + 
httpTransport: httpTransport, + } + + // Setup routes + server.setupRoutes() + + return server, nil +} + +// concurrencyLimitMiddleware limits the number of concurrent requests +func (s *Server) concurrencyLimitMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // Try to acquire a slot in the semaphore + select { + case s.semaphore <- struct{}{}: + // Successfully acquired a slot, continue processing + defer func() { + // Release the slot when done + <-s.semaphore + }() + c.Next() + default: + // No slots available, return 503 Service Unavailable + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": "server overloaded, please try again later", + "code": "SERVER_OVERLOADED", + }) + c.Abort() + } + } +} + +// setupRoutes configures HTTP routes using Gin +func (s *Server) setupRoutes() { + s.engine = gin.New() + + // Add middleware + s.engine.Use(gin.Logger()) + s.engine.Use(gin.Recovery()) + + // Health check endpoints (no authentication required, no concurrency limit) + s.engine.GET("/health", s.handleHealth) + s.engine.GET("/health/live", s.handleHealthLive) + s.engine.GET("/health/ready", s.handleHealthReady) + + // API v1 routes with concurrency limiting + v1 := s.engine.Group("/v1") + v1.Use(s.concurrencyLimitMiddleware()) // Apply concurrency limit to API routes + + // Agent invoke requests + v1.POST("/namespaces/:namespace/agent-runtimes/:name/invocations/*path", s.handleAgentInvoke) + + // Code interpreter invoke requests + v1.POST("/namespaces/:namespace/code-interpreters/:name/invocations/*path", s.handleCodeInterpreterInvoke) +} + +// Start starts the Router API server +func (s *Server) Start(ctx context.Context) error { + addr := ":" + s.config.Port + + s.httpServer = &http.Server{ + Addr: addr, + Handler: s.engine, + ReadTimeout: 30 * time.Second, // Longer timeout for potential long-running requests + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + } + + // Listen for shutdown signal in goroutine + go func() { + 
<-ctx.Done() + log.Println("Shutting down Router server...") + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := s.httpServer.Shutdown(shutdownCtx); err != nil { + log.Printf("Server shutdown error: %v", err) + } + }() + + log.Printf("Router server listening on %s", addr) + + // Start HTTP or HTTPS server + if s.config.EnableTLS { + if s.config.TLSCert == "" || s.config.TLSKey == "" { + return fmt.Errorf("TLS enabled but cert/key not provided") + } + return s.httpServer.ListenAndServeTLS(s.config.TLSCert, s.config.TLSKey) + } + + return s.httpServer.ListenAndServe() +} diff --git a/pkg/router/server_test.go b/pkg/router/server_test.go new file mode 100644 index 00000000..5932dd4a --- /dev/null +++ b/pkg/router/server_test.go @@ -0,0 +1,373 @@ +package router + +import ( + "context" + "os" + "testing" + "time" +) + +func TestNewServer(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + tests := []struct { + name string + config *Config + wantErr bool + }{ + { + name: "nil config", + config: nil, + wantErr: true, + }, + { + name: "valid config with defaults", + config: &Config{ + Port: "8080", + }, + wantErr: false, + }, + { + name: "valid config with custom values", + config: &Config{ + Port: "9090", + MaxConcurrentRequests: 500, + RequestTimeout: 60, + MaxIdleConns: 200, + MaxConnsPerHost: 20, + EnableRedis: true, + Debug: true, + }, + wantErr: false, + }, + { + name: "config with TLS enabled", + config: &Config{ + Port: "8443", + EnableTLS: true, + TLSCert: "/path/to/cert.pem", + TLSKey: "/path/to/key.pem", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server, err := 
NewServer(tt.config) + + if (err != nil) != tt.wantErr { + t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // Verify server was created + if server == nil { + t.Error("Expected non-nil server") + return + } + + // Verify config was set + if server.config != tt.config { + t.Error("Server config was not set correctly") + } + + // Verify session manager was created + if server.sessionManager == nil { + t.Error("Session manager was not created") + } + + // Verify redis client was created + if server.redisClient == nil { + t.Error("Redis client was not created") + } + + // Verify semaphore was created with correct capacity + expectedCapacity := tt.config.MaxConcurrentRequests + if expectedCapacity <= 0 { + expectedCapacity = 1000 // Default value + } + if cap(server.semaphore) != expectedCapacity { + t.Errorf("Expected semaphore capacity %d, got %d", expectedCapacity, cap(server.semaphore)) + } + + // Verify default values were set + if server.config.MaxConcurrentRequests <= 0 { + t.Error("MaxConcurrentRequests should have been set to default") + } + if server.config.RequestTimeout <= 0 { + t.Error("RequestTimeout should have been set to default") + } + if server.config.MaxIdleConns <= 0 { + t.Error("MaxIdleConns should have been set to default") + } + if server.config.MaxConnsPerHost <= 0 { + t.Error("MaxConnsPerHost should have been set to default") + } + } + }) + } +} + +func TestServer_DefaultValues(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + // Leave other values as zero to test defaults + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) 
+ } + + // Test default values + if server.config.MaxConcurrentRequests != 1000 { + t.Errorf("Expected default MaxConcurrentRequests 1000, got %d", server.config.MaxConcurrentRequests) + } + + if server.config.RequestTimeout != 30 { + t.Errorf("Expected default RequestTimeout 30, got %d", server.config.RequestTimeout) + } + + if server.config.MaxIdleConns != 100 { + t.Errorf("Expected default MaxIdleConns 100, got %d", server.config.MaxIdleConns) + } + + if server.config.MaxConnsPerHost != 10 { + t.Errorf("Expected default MaxConnsPerHost 10, got %d", server.config.MaxConnsPerHost) + } +} + +func TestServer_ConcurrencyLimitMiddleware(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + MaxConcurrentRequests: 2, // Small limit for testing + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + middleware := server.concurrencyLimitMiddleware() + + // Test that middleware function was created + if middleware == nil { + t.Error("Expected non-nil middleware function") + } + + // Note: Testing the actual middleware behavior would require setting up + // a full HTTP test environment, which is beyond the scope of unit tests. + // Integration tests would be more appropriate for testing middleware behavior. 
+} + +func TestServer_SetupRoutes(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Verify that engine was created during setupRoutes + if server.engine == nil { + t.Error("Expected non-nil Gin engine after setupRoutes") + } + + // Note: Testing specific routes would require HTTP testing, + // which is more appropriate for integration tests. +} + +func TestServer_StartContext(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "0", // Use port 0 to let the OS assign a free port + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Test context cancellation + ctx, cancel := context.WithCancel(context.Background()) + + // Start server in goroutine + errChan := make(chan error, 1) + go func() { + err := server.Start(ctx) + errChan <- err + }() + + // Give server a moment to start + time.Sleep(100 * time.Millisecond) + + // Cancel context to trigger shutdown + cancel() + + // Wait for server to shutdown + select { + case err := <-errChan: + // Server should shutdown gracefully, error might be http.ErrServerClosed + if err != nil && err.Error() != "http: Server closed" { + t.Errorf("Unexpected error during shutdown: %v", err) + } + case <-time.After(5 * time.Second): + 
t.Error("Server did not shutdown within timeout") + } +} + +func TestServer_TLSConfiguration(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + tests := []struct { + name string + config *Config + wantErr bool + errString string + }{ + { + name: "TLS enabled with cert and key", + config: &Config{ + Port: "8443", + EnableTLS: true, + TLSCert: "/path/to/cert.pem", + TLSKey: "/path/to/key.pem", + }, + wantErr: false, + }, + { + name: "TLS enabled without cert", + config: &Config{ + Port: "8443", + EnableTLS: true, + TLSKey: "/path/to/key.pem", + }, + wantErr: true, + errString: "TLS enabled but cert/key not provided", + }, + { + name: "TLS enabled without key", + config: &Config{ + Port: "8443", + EnableTLS: true, + TLSCert: "/path/to/cert.pem", + }, + wantErr: true, + errString: "TLS enabled but cert/key not provided", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server, err := NewServer(tt.config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Test Start method with a context that will be cancelled immediately + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately to avoid actually starting the server + + err = server.Start(ctx) + + if tt.wantErr { + if err == nil { + t.Error("Expected error but got none") + } else if tt.errString != "" && err.Error() != tt.errString { + t.Errorf("Expected error '%s', got '%s'", tt.errString, err.Error()) + } + } else { + // For TLS tests, we expect the server to fail to start due to invalid cert/key paths + // but the configuration validation should pass + if err != nil && err.Error() == tt.errString { + t.Errorf("Unexpected configuration error: %v", err) 
+ } + } + }) + } +} + +func TestServer_RedisIntegration(t *testing.T) { + // Set required environment variables for tests + os.Setenv("REDIS_ADDR", "localhost:6379") + os.Setenv("REDIS_PASSWORD", "test-password") + os.Setenv("WORKLOAD_MGR_URL", "http://localhost:8080") + defer func() { + os.Unsetenv("REDIS_ADDR") + os.Unsetenv("REDIS_PASSWORD") + os.Unsetenv("WORKLOAD_MGR_URL") + }() + + config := &Config{ + Port: "8080", + } + + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Verify Redis client was created + if server.redisClient == nil { + t.Error("Redis client was not created") + } +} diff --git a/pkg/router/session_manager.go b/pkg/router/session_manager.go new file mode 100644 index 00000000..5a7e56bb --- /dev/null +++ b/pkg/router/session_manager.go @@ -0,0 +1,153 @@ +package router + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "time" + + "github.com/volcano-sh/agentcube/pkg/common/types" + "github.com/volcano-sh/agentcube/pkg/redis" +) + +// SessionManager defines the session management behavior on top of Redis and the workload manager. +type SessionManager interface { + // GetSandboxBySession returns the sandbox associated with the given sessionID. + // When sessionID is empty, it creates a new sandbox by calling the external API. + // When sessionID is not empty, it queries Redis for the sandbox. + GetSandboxBySession(ctx context.Context, sessionID string, namespace string, name string, kind string) (*types.SandboxRedis, error) +} + +// manager is the default implementation of the SessionManager interface. +type manager struct { + redisClient redis.Client + workloadMgrURL string + httpClient *http.Client +} + +// NewSessionManager returns a SessionManager implementation. +// redisClient is used to query sandbox information from Redis. +// workloadMgrURL is read from the environment variable WORKLOAD_MGR_URL. 
+func NewSessionManager(redisClient redis.Client) (SessionManager, error) { + workloadMgrURL := os.Getenv("WORKLOAD_MGR_URL") + if workloadMgrURL == "" { + return nil, fmt.Errorf("WORKLOAD_MGR_URL environment variable is not set") + } + + return &manager{ + redisClient: redisClient, + workloadMgrURL: workloadMgrURL, + httpClient: &http.Client{ + Timeout: time.Minute, // No timeout for createSandbox requests + }, + }, nil +} + +// GetSandboxBySession returns the sandbox associated with the given sessionID. +// When sessionID is empty, it creates a new sandbox by calling the external API. +// When sessionID is not empty, it queries Redis for the sandbox. +func (m *manager) GetSandboxBySession(ctx context.Context, sessionID string, namespace string, name string, kind string) (*types.SandboxRedis, error) { + // When sessionID is empty, create a new sandbox + if sessionID == "" { + return m.createSandbox(ctx, namespace, name, kind) + } + + // When sessionID is not empty, query Redis + sandbox, err := m.redisClient.GetSandboxBySessionID(ctx, sessionID) + if err != nil { + if errors.Is(err, redis.ErrNotFound) { + return nil, ErrSessionNotFound + } + return nil, fmt.Errorf("failed to get sandbox from redis: %w", err) + } + + return sandbox, nil +} + +// createSandbox creates a new sandbox by calling the external workload manager API. 
+func (m *manager) createSandbox(ctx context.Context, namespace string, name string, kind string) (*types.SandboxRedis, error) { + // Determine the API endpoint based on kind + var endpoint string + switch kind { + case types.AgentRuntimeKind: + endpoint = m.workloadMgrURL + "/v1/agent-runtime" + case types.CodeInterpreterKind: + endpoint = m.workloadMgrURL + "/v1/code-interpreter" + default: + return nil, fmt.Errorf("unsupported kind: %s", kind) + } + + // Prepare the request body + reqBody := &types.CreateSandboxRequest{ + Kind: kind, + Name: name, + Namespace: namespace, + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + // Send the request + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrUpstreamUnavailable, err) + } + defer resp.Body.Close() + + // Read response body + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Check response status + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w: status code %d, body: %s", ErrCreateSandboxFailed, resp.StatusCode, string(respBody)) + } + + // Parse response + var createResp types.CreateSandboxResponse + if err := json.Unmarshal(respBody, &createResp); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Validate response + if createResp.SessionID == "" { + return nil, fmt.Errorf("%w: response with empty session id from workload manager", ErrCreateSandboxFailed) + } + + // Construct SandboxRedis from response + sandbox := &types.SandboxRedis{ + SandboxID: createResp.SandboxID, + 
SandboxName: createResp.SandboxName, + SessionID: createResp.SessionID, + EntryPoints: createResp.EntryPoints, + } + + return sandbox, nil +} + +var ( + // ErrSessionNotFound indicates that the session does not exist in redis. + ErrSessionNotFound = errors.New("sessionmgr: session not found") + + // ErrUpstreamUnavailable indicates that the workload manager is unavailable. + ErrUpstreamUnavailable = errors.New("sessionmgr: workload manager unavailable") + + // ErrCreateSandboxFailed indicates that the workload manager returned an error. + ErrCreateSandboxFailed = errors.New("sessionmgr: create sandbox failed") +) diff --git a/pkg/router/session_manager_test.go b/pkg/router/session_manager_test.go new file mode 100644 index 00000000..cb21c0f8 --- /dev/null +++ b/pkg/router/session_manager_test.go @@ -0,0 +1,400 @@ +package router + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/volcano-sh/agentcube/pkg/common/types" + "github.com/volcano-sh/agentcube/pkg/redis" +) + +// ---- fakes ---- + +type fakeRedisClient struct { + sandbox *types.SandboxRedis + err error + called bool + lastSessionID string + lastContextNil bool +} + +func (f *fakeRedisClient) GetSandboxBySessionID(ctx context.Context, sessionID string) (*types.SandboxRedis, error) { + f.called = true + f.lastSessionID = sessionID + f.lastContextNil = ctx == nil + return f.sandbox, f.err +} + +func (f *fakeRedisClient) SetSessionLockIfAbsent(ctx context.Context, sessionID string, ttl time.Duration) (bool, error) { + return false, nil +} + +func (f *fakeRedisClient) BindSessionWithSandbox(ctx context.Context, sessionID string, sandboxRedis *types.SandboxRedis, ttl time.Duration) error { + return nil +} + +func (f *fakeRedisClient) DeleteSessionBySandboxIDTx(ctx context.Context, sandboxID string) error { + return nil +} + +func (f *fakeRedisClient) DeleteSandboxBySessionIDTx(ctx context.Context, sessionID string) error { + return 
nil +} + +func (f *fakeRedisClient) UpdateSandbox(ctx context.Context, sandboxRedis *types.SandboxRedis, ttl time.Duration) error { + return nil +} + +func (f *fakeRedisClient) UpdateSessionLastActivity(ctx context.Context, sessionID string, at time.Time) error { + return nil +} + +func (f *fakeRedisClient) StoreSandbox(ctx context.Context, sandboxRedis *types.SandboxRedis, ttl time.Duration) error { + return nil +} + +func (f *fakeRedisClient) Ping(ctx context.Context) error { + return nil +} + +func (f *fakeRedisClient) ListExpiredSandboxes(ctx context.Context, before time.Time, limit int64) ([]*types.SandboxRedis, error) { + return nil, nil +} + +func (f *fakeRedisClient) ListInactiveSandboxes(ctx context.Context, before time.Time, limit int64) ([]*types.SandboxRedis, error) { + return nil, nil +} + +func (f *fakeRedisClient) UpdateSandboxLastActivity(ctx context.Context, sandboxID string, at time.Time) error { + return nil +} + +// ---- tests: GetSandboxBySession ---- + +func TestGetSandboxBySession_Success(t *testing.T) { + sb := &types.SandboxRedis{ + SandboxID: "sandbox-1", + SandboxName: "sandbox-1", + EntryPoints: []types.SandboxEntryPoints{ + {Endpoint: "10.0.0.1:9000"}, + }, + SessionID: "sess-1", + Status: "running", + } + + r := &fakeRedisClient{ + sandbox: sb, + } + m := &manager{ + redisClient: r, + } + + got, err := m.GetSandboxBySession(context.Background(), "sess-1", "default", "test", "AgentRuntime") + if err != nil { + t.Fatalf("GetSandboxBySession unexpected error: %v", err) + } + if !r.called { + t.Fatalf("expected RedisClient to be called") + } + if r.lastSessionID != "sess-1" { + t.Fatalf("expected RedisClient to be called with sessionID 'sess-1', got %q", r.lastSessionID) + } + if got == nil { + t.Fatalf("expected non-nil sandbox") + } + if got.SandboxID != "sandbox-1" { + t.Fatalf("unexpected SandboxID: got %q, want %q", got.SandboxID, "sandbox-1") + } +} + +func TestGetSandboxBySession_NotFound(t *testing.T) { + r := &fakeRedisClient{ + 
sandbox: nil, + err: redis.ErrNotFound, + } + m := &manager{ + redisClient: r, + } + + _, err := m.GetSandboxBySession(context.Background(), "sess-1", "default", "test", "AgentRuntime") + if err == nil { + t.Fatalf("expected error for not found session") + } + if !errors.Is(err, ErrSessionNotFound) { + t.Fatalf("expected ErrSessionNotFound, got %v", err) + } +} + +// ---- tests: GetSandboxBySession with empty sessionID (sandbox creation path) ---- + +func TestGetSandboxBySession_CreateSandbox_AgentRuntime_Success(t *testing.T) { + // Mock workload manager server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and path + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + if r.URL.Path != "/v1/agent-runtime" { + t.Errorf("expected path /v1/agent-runtime, got %s", r.URL.Path) + } + + // Verify request body + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + var req types.CreateSandboxRequest + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + if req.Kind != types.AgentRuntimeKind { + t.Errorf("expected kind %s, got %s", types.AgentRuntimeKind, req.Kind) + } + if req.Name != "test-runtime" { + t.Errorf("expected name test-runtime, got %s", req.Name) + } + if req.Namespace != "default" { + t.Errorf("expected namespace default, got %s", req.Namespace) + } + + // Send successful response + resp := types.CreateSandboxResponse{ + SessionID: "new-session-123", + SandboxID: "sandbox-456", + SandboxName: "sandbox-test", + EntryPoints: []types.SandboxEntryPoints{ + {Endpoint: "10.0.0.1:9000", Protocol: "http", Path: "/"}, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + 
workloadMgrURL: mockServer.URL, + httpClient: &http.Client{}, + } + + sandbox, err := m.GetSandboxBySession(context.Background(), "", "default", "test-runtime", types.AgentRuntimeKind) + if err != nil { + t.Fatalf("GetSandboxBySession unexpected error: %v", err) + } + if sandbox == nil { + t.Fatalf("expected non-nil sandbox") + } + if sandbox.SessionID != "new-session-123" { + t.Errorf("expected SessionID new-session-123, got %s", sandbox.SessionID) + } + if sandbox.SandboxID != "sandbox-456" { + t.Errorf("expected SandboxID sandbox-456, got %s", sandbox.SandboxID) + } + if sandbox.SandboxName != "sandbox-test" { + t.Errorf("expected SandboxName sandbox-test, got %s", sandbox.SandboxName) + } + if len(sandbox.EntryPoints) != 1 { + t.Fatalf("expected 1 entry point, got %d", len(sandbox.EntryPoints)) + } + if sandbox.EntryPoints[0].Endpoint != "10.0.0.1:9000" { + t.Errorf("expected endpoint 10.0.0.1:9000, got %s", sandbox.EntryPoints[0].Endpoint) + } +} + +func TestGetSandboxBySession_CreateSandbox_CodeInterpreter_Success(t *testing.T) { + // Mock workload manager server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and path + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + if r.URL.Path != "/v1/code-interpreter" { + t.Errorf("expected path /v1/code-interpreter, got %s", r.URL.Path) + } + + // Verify request body + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + var req types.CreateSandboxRequest + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + if req.Kind != types.CodeInterpreterKind { + t.Errorf("expected kind %s, got %s", types.CodeInterpreterKind, req.Kind) + } + + // Send successful response + resp := types.CreateSandboxResponse{ + SessionID: "ci-session-789", + SandboxID: "ci-sandbox-101", + SandboxName: 
"ci-sandbox-test", + EntryPoints: []types.SandboxEntryPoints{ + {Endpoint: "10.0.0.2:8080", Protocol: "http", Path: "/"}, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: mockServer.URL, + httpClient: &http.Client{}, + } + + sandbox, err := m.GetSandboxBySession(context.Background(), "", "default", "test-ci", types.CodeInterpreterKind) + if err != nil { + t.Fatalf("GetSandboxBySession unexpected error: %v", err) + } + if sandbox == nil { + t.Fatalf("expected non-nil sandbox") + } + if sandbox.SessionID != "ci-session-789" { + t.Errorf("expected SessionID ci-session-789, got %s", sandbox.SessionID) + } +} + +func TestGetSandboxBySession_CreateSandbox_UnsupportedKind(t *testing.T) { + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: "http://localhost:8080", + httpClient: &http.Client{}, + } + + _, err := m.GetSandboxBySession(context.Background(), "", "default", "test", "UnsupportedKind") + if err == nil { + t.Fatalf("expected error for unsupported kind") + } + if err.Error() != "unsupported kind: UnsupportedKind" { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestGetSandboxBySession_CreateSandbox_WorkloadManagerUnavailable(t *testing.T) { + // Mock workload manager server that closes connection immediately + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Close connection without sending response + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("webserver doesn't support hijacking") + } + conn, _, err := hj.Hijack() + if err != nil { + t.Fatal(err) + } + conn.Close() + })) + serverURL := mockServer.URL + mockServer.Close() // Close the server to make it unavailable + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: serverURL, + httpClient: 
&http.Client{}, + } + + _, err := m.GetSandboxBySession(context.Background(), "", "default", "test", types.AgentRuntimeKind) + if err == nil { + t.Fatalf("expected error for unavailable workload manager") + } + if !errors.Is(err, ErrUpstreamUnavailable) { + t.Errorf("expected ErrUpstreamUnavailable, got %v", err) + } +} + +func TestGetSandboxBySession_CreateSandbox_NonOKStatus(t *testing.T) { + // Mock workload manager server that returns error + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal server error")) + })) + defer mockServer.Close() + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: mockServer.URL, + httpClient: &http.Client{}, + } + + _, err := m.GetSandboxBySession(context.Background(), "", "default", "test", types.AgentRuntimeKind) + if err == nil { + t.Fatalf("expected error for non-OK status") + } + if !errors.Is(err, ErrCreateSandboxFailed) { + t.Errorf("expected ErrCreateSandboxFailed, got %v", err) + } +} + +func TestGetSandboxBySession_CreateSandbox_InvalidJSON(t *testing.T) { + // Mock workload manager server that returns invalid JSON + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte("invalid json")) + })) + defer mockServer.Close() + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: mockServer.URL, + httpClient: &http.Client{}, + } + + _, err := m.GetSandboxBySession(context.Background(), "", "default", "test", types.AgentRuntimeKind) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + if err.Error() == "" { + t.Errorf("expected error message for invalid JSON") + } +} + +func TestGetSandboxBySession_CreateSandbox_EmptySessionID(t *testing.T) { + // Mock workload manager server that returns empty 
sessionID + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := types.CreateSandboxResponse{ + SessionID: "", // Empty sessionID + SandboxID: "sandbox-456", + SandboxName: "sandbox-test", + EntryPoints: []types.SandboxEntryPoints{ + {Endpoint: "10.0.0.1:9000"}, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + r := &fakeRedisClient{} + m := &manager{ + redisClient: r, + workloadMgrURL: mockServer.URL, + httpClient: &http.Client{}, + } + + _, err := m.GetSandboxBySession(context.Background(), "", "default", "test", types.AgentRuntimeKind) + if err == nil { + t.Fatalf("expected error for empty sessionID in response") + } + if !errors.Is(err, ErrCreateSandboxFailed) { + t.Errorf("expected ErrCreateSandboxFailed, got %v", err) + } +} diff --git a/pkg/workloadmanager/garbage_collection.go b/pkg/workloadmanager/garbage_collection.go index 5a3383f1..77f4ed9c 100644 --- a/pkg/workloadmanager/garbage_collection.go +++ b/pkg/workloadmanager/garbage_collection.go @@ -6,10 +6,11 @@ import ( "log" "time" - "github.com/volcano-sh/agentcube/pkg/redis" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilerrors "k8s.io/apimachinery/pkg/util/errors" + + "github.com/volcano-sh/agentcube/pkg/redis" ) const ( diff --git a/pkg/workloadmanager/handlers.go b/pkg/workloadmanager/handlers.go index 4e624b3d..44c6f671 100644 --- a/pkg/workloadmanager/handlers.go +++ b/pkg/workloadmanager/handlers.go @@ -11,10 +11,11 @@ import ( "github.com/gin-gonic/gin" redisv9 "github.com/redis/go-redis/v9" - "github.com/volcano-sh/agentcube/pkg/common/types" - "github.com/volcano-sh/agentcube/pkg/redis" sandboxv1alpha1 "sigs.k8s.io/agent-sandbox/api/v1alpha1" extensionsv1alpha1 "sigs.k8s.io/agent-sandbox/extensions/api/v1alpha1" + + "github.com/volcano-sh/agentcube/pkg/common/types" + 
"github.com/volcano-sh/agentcube/pkg/redis" ) // handleHealth handles health check requests @@ -55,9 +56,9 @@ func (s *Server) handleCreateSandbox(c *gin.Context) { var err error switch createAgentRequest.Kind { case types.AgentRuntimeKind: - sandbox, externalInfo, err = buildSandboxByAgentRuntime(createAgentRequest.Namespace, createAgentRequest.Name, s.informers) + sandbox, externalInfo, err = buildSandboxByAgentRuntime(createAgentRequest.Namespace, createAgentRequest.Name, s.informers, createAgentRequest.ReadinessProbe) case types.CodeInterpreterKind: - sandbox, sandboxClaim, externalInfo, err = buildSandboxByCodeInterpreter(createAgentRequest.Namespace, createAgentRequest.Name, s.informers) + sandbox, sandboxClaim, externalInfo, err = buildSandboxByCodeInterpreter(createAgentRequest.Namespace, createAgentRequest.Name, s.informers, createAgentRequest.ReadinessProbe) default: log.Printf("invalid request kind: %v", createAgentRequest.Kind) respondError(c, http.StatusBadRequest, "INVALID_REQUEST", fmt.Sprintf("invalid request kind: %v", createAgentRequest.Kind)) diff --git a/pkg/workloadmanager/server.go b/pkg/workloadmanager/server.go index 67e78f04..ba01f812 100644 --- a/pkg/workloadmanager/server.go +++ b/pkg/workloadmanager/server.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" redisv9 "github.com/redis/go-redis/v9" + "github.com/volcano-sh/agentcube/pkg/redis" ) diff --git a/pkg/workloadmanager/workload_builder.go b/pkg/workloadmanager/workload_builder.go index e0d2c593..64b983b9 100644 --- a/pkg/workloadmanager/workload_builder.go +++ b/pkg/workloadmanager/workload_builder.go @@ -7,10 +7,12 @@ import ( "github.com/google/uuid" runtimev1alpha1 "github.com/volcano-sh/agentcube/pkg/apis/runtime/v1alpha1" + "github.com/volcano-sh/agentcube/pkg/common/types" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/intstr" 
"k8s.io/utils/ptr" sandboxv1alpha1 "sigs.k8s.io/agent-sandbox/api/v1alpha1" extensionsv1alpha1 "sigs.k8s.io/agent-sandbox/extensions/api/v1alpha1" @@ -26,6 +28,7 @@ type buildSandboxParams struct { podSpec corev1.PodSpec podLabels map[string]string podAnnotations map[string]string + readinessProbe *types.ReadinessProbe } type buildSandboxClaimParams struct { @@ -79,6 +82,68 @@ func buildSandboxObject(params *buildSandboxParams) *sandboxv1alpha1.Sandbox { } sandbox.Spec.PodTemplate.ObjectMeta.Labels[SessionIdLabelKey] = params.sessionID sandbox.Spec.PodTemplate.ObjectMeta.Labels["sandbox-name"] = params.sandboxName + + // Handle Readiness Probe + if len(sandbox.Spec.PodTemplate.Spec.Containers) > 0 { + container := &sandbox.Spec.PodTemplate.Spec.Containers[0] + + if params.readinessProbe != nil { + probe := &corev1.Probe{ + InitialDelaySeconds: params.readinessProbe.InitialDelaySeconds, + PeriodSeconds: params.readinessProbe.PeriodSeconds, + SuccessThreshold: params.readinessProbe.SuccessThreshold, + FailureThreshold: params.readinessProbe.FailureThreshold, + TimeoutSeconds: params.readinessProbe.TimeoutSeconds, + } + + if params.readinessProbe.TCPSocket != nil { + probe.ProbeHandler = corev1.ProbeHandler{ + TCPSocket: &corev1.TCPSocketAction{ + Port: intstr.FromInt(params.readinessProbe.TCPSocket.Port), + Host: params.readinessProbe.TCPSocket.Host, + }, + } + } else if params.readinessProbe.HTTPGet != nil { + headers := []corev1.HTTPHeader{} + for k, v := range params.readinessProbe.HTTPGet.HTTPHeaders { + headers = append(headers, corev1.HTTPHeader{Name: k, Value: v}) + } + probe.ProbeHandler = corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: params.readinessProbe.HTTPGet.Path, + Port: intstr.FromInt(params.readinessProbe.HTTPGet.Port), + Host: params.readinessProbe.HTTPGet.Host, + Scheme: corev1.URIScheme(params.readinessProbe.HTTPGet.Scheme), + HTTPHeaders: headers, + }, + } + } else if params.readinessProbe.Exec != nil { + probe.ProbeHandler = 
corev1.ProbeHandler{ + Exec: &corev1.ExecAction{ + Command: params.readinessProbe.Exec.Command, + }, + } + } + container.ReadinessProbe = probe + } else if container.ReadinessProbe == nil { + // Add default TCP probe if ports are available + if len(container.Ports) > 0 { + container.ReadinessProbe = &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + TCPSocket: &corev1.TCPSocketAction{ + Port: intstr.FromInt(int(container.Ports[0].ContainerPort)), + }, + }, + InitialDelaySeconds: 5, + PeriodSeconds: 10, + SuccessThreshold: 1, + FailureThreshold: 3, + TimeoutSeconds: 1, + } + } + } + } + return sandbox } @@ -103,7 +168,7 @@ func buildSandboxClaimObject(params *buildSandboxClaimParams) *extensionsv1alpha return sandboxClaim } -func buildSandboxByAgentRuntime(namespace string, name string, ifm *Informers) (*sandboxv1alpha1.Sandbox, *sandboxExternalInfo, error) { +func buildSandboxByAgentRuntime(namespace string, name string, ifm *Informers, readinessProbe *types.ReadinessProbe) (*sandboxv1alpha1.Sandbox, *sandboxExternalInfo, error) { agentRuntimeKey := namespace + "/" + name runtimeObj, exists, err := ifm.AgentRuntimeInformer.GetStore().GetByKey(agentRuntimeKey) if err != nil { @@ -131,11 +196,12 @@ func buildSandboxByAgentRuntime(namespace string, name string, ifm *Informers) ( sessionID := uuid.New().String() sandboxName := "agent-runtime-" + uuid.New().String() buildParams := &buildSandboxParams{ - namespace: namespace, - workloadName: name, - sandboxName: sandboxName, - sessionID: sessionID, - podSpec: agentRuntimeObj.Spec.Template.Spec, + namespace: namespace, + workloadName: name, + sandboxName: sandboxName, + sessionID: sessionID, + podSpec: agentRuntimeObj.Spec.Template.Spec, + readinessProbe: readinessProbe, } if agentRuntimeObj.Spec.MaxSessionDuration != nil { buildParams.ttl = agentRuntimeObj.Spec.MaxSessionDuration.Duration @@ -151,7 +217,7 @@ func buildSandboxByAgentRuntime(namespace string, name string, ifm *Informers) ( return sandbox, externalInfo, 
nil } -func buildSandboxByCodeInterpreter(namespace string, codeInterpreterName string, ifm *Informers) (*sandboxv1alpha1.Sandbox, *extensionsv1alpha1.SandboxClaim, *sandboxExternalInfo, error) { +func buildSandboxByCodeInterpreter(namespace string, codeInterpreterName string, ifm *Informers, readinessProbe *types.ReadinessProbe) (*sandboxv1alpha1.Sandbox, *extensionsv1alpha1.SandboxClaim, *sandboxExternalInfo, error) { codeInterpreterKey := namespace + "/" + codeInterpreterName runtimeObj, exists, err := ifm.CodeInterpreterInformer.GetStore().GetByKey(codeInterpreterKey) if err != nil { @@ -220,6 +286,7 @@ func buildSandboxByCodeInterpreter(namespace string, codeInterpreterName string, podSpec: podSpec, podLabels: codeInterpreterObj.Spec.Template.Labels, podAnnotations: codeInterpreterObj.Spec.Template.Annotations, + readinessProbe: readinessProbe, } if codeInterpreterObj.Spec.MaxSessionDuration != nil { buildParams.ttl = codeInterpreterObj.Spec.MaxSessionDuration.Duration