diff --git a/internal/api/security_handlers.go b/internal/api/security_handlers.go index b14bb89..7bfa6e2 100644 --- a/internal/api/security_handlers.go +++ b/internal/api/security_handlers.go @@ -395,8 +395,12 @@ func (s *Server) updateDeploymentSecurity(c *gin.Context) { } vhostUpdated := false + var hookStatus interface{} + if s.proxyOrchestrator != nil && s.proxyOrchestrator.NginxManager().VirtualHostExists(name) { - if err := s.proxyOrchestrator.NginxManager().UpdateVirtualHost(deployment); err != nil { + nginxMgr := s.proxyOrchestrator.NginxManager() + + if err := nginxMgr.UpdateVirtualHost(deployment); err != nil { c.JSON(http.StatusOK, gin.H{ "security": securityConfig, "vhost_updated": false, @@ -405,19 +409,33 @@ func (s *Server) updateDeploymentSecurity(c *gin.Context) { return } - if err := s.proxyOrchestrator.NginxManager().TestConfig(); err != nil { + if err := nginxMgr.ValidateSecurityHooks(name, securityConfig.Enabled); err != nil { + c.JSON(http.StatusOK, gin.H{ + "security": securityConfig, + "vhost_updated": true, + "validation_error": err.Error(), + "warning": "Vhost updated but security hook validation failed", + }) + return + } + + hookStatus, _ = nginxMgr.GetSecurityHookStatus(name) + + if err := nginxMgr.TestConfig(); err != nil { c.JSON(http.StatusOK, gin.H{ "security": securityConfig, - "vhost_updated": false, + "vhost_updated": true, + "hook_status": hookStatus, "warning": "Security config saved but nginx config test failed: " + err.Error(), }) return } - if err := s.proxyOrchestrator.NginxManager().Reload(); err != nil { + if err := nginxMgr.Reload(); err != nil { c.JSON(http.StatusOK, gin.H{ "security": securityConfig, "vhost_updated": true, + "hook_status": hookStatus, "warning": "Nginx reload failed (may need manual reload): " + err.Error(), }) return @@ -425,10 +443,15 @@ func (s *Server) updateDeploymentSecurity(c *gin.Context) { vhostUpdated = true } - c.JSON(http.StatusOK, gin.H{ + response := gin.H{ "security": securityConfig, "vhost_updated": vhostUpdated, - }) + } + if hookStatus != nil { + response["hook_status"] = hookStatus + } + + c.JSON(http.StatusOK, response) } // getDeploymentSecurityEvents returns security events for a deployment @@ -675,3 +698,31 @@ func (s *Server) updateSecuritySettings(c *gin.Context) { c.JSON(http.StatusOK, result) } + +// refreshSecurityScripts regenerates Lua scripts with correct agent IP and reloads nginx +func (s *Server) refreshSecurityScripts(c *gin.Context) { + if !s.config.Security.Enabled { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Security module not enabled", + }) + return + } + + if !s.infraManager.IsNginxRunning() { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": "Nginx container is not running", + }) + return + } + + result, err := s.infraManager.RefreshSecurityScripts() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": err.Error(), + "result": result, + }) + return + } + + c.JSON(http.StatusOK, result) +} diff --git a/internal/api/server.go b/internal/api/server.go index 9bd897a..0233f2e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -26,6 +26,7 @@ import ( "github.com/flatrun/agent/internal/proxy" "github.com/flatrun/agent/internal/security" "github.com/flatrun/agent/internal/system" + "github.com/flatrun/agent/internal/traffic" "github.com/flatrun/agent/pkg/config" "github.com/flatrun/agent/pkg/models" "github.com/flatrun/agent/pkg/plugins" @@ -53,6 +54,7 @@ type Server struct { infraManager *infra.Manager credentialsManager *credentials.Manager securityManager *security.Manager + trafficManager *traffic.Manager } func New(cfg *config.Config, configPath string) *Server { @@ -99,6 +101,12 @@ func New(cfg *config.Config, configPath string) *Server { } } + var trafficManager *traffic.Manager + trafficManager, err := traffic.NewManager(cfg.DeploymentsPath, 7) + if err != nil { + log.Printf("Warning: Failed to initialize traffic manager: %v", err) + } + s := &Server{ config: cfg, configPath: configPath, @@ -115,6 +123,7 @@ func New(cfg *config.Config, configPath string) *Server { infraManager: infraManager, credentialsManager: credentialsManager, securityManager: securityManager, + trafficManager: trafficManager, } s.setupRoutes() @@ -224,6 +233,7 @@ func (s *Server) setupRoutes() { protected.POST("/databases/tables/schema", s.describeTable) protected.POST("/databases/query", s.executeDatabaseQuery) protected.POST("/databases/users", s.listDatabaseUsers) + protected.POST("/databases/users/by-database", s.listUsersByDatabase) protected.POST("/databases/create", s.createDatabaseInServer) protected.POST("/databases/delete", s.deleteDatabaseInServer) protected.POST("/databases/users/create", s.createDatabaseUser) @@ -268,13 +278,21 @@ func (s *Server) setupRoutes() { protected.GET("/security/realtime-capture", s.getRealtimeCaptureStatus) protected.PUT("/security/realtime-capture", s.setRealtimeCaptureStatus) protected.GET("/security/health", s.getSecurityHealth) + protected.POST("/security/refresh", s.refreshSecurityScripts) protected.GET("/deployments/:name/security", s.getDeploymentSecurity) protected.PUT("/deployments/:name/security", s.updateDeploymentSecurity) protected.GET("/deployments/:name/security/events", s.getDeploymentSecurityEvents) + + // Traffic endpoints + protected.GET("/traffic/logs", s.getTrafficLogs) + protected.GET("/traffic/stats", s.getTrafficStats) + protected.POST("/traffic/cleanup", s.cleanupTrafficLogs) + protected.GET("/deployments/:name/traffic", s.getDeploymentTrafficStats) } - // Security event ingest endpoint (no auth - called by nginx Lua) + // Ingest endpoints (no auth - called by nginx Lua) api.POST("/security/events/ingest", s.ingestSecurityEvent) + api.POST("/traffic/ingest", s.ingestTrafficLog) } } @@ -3464,6 +3482,31 @@ func (s *Server) listDatabaseUsers(c *gin.Context) { }) } +func (s *Server) listUsersByDatabase(c *gin.Context) { + var req struct { + database.ConnectionConfig + Database string `json:"database" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": err.Error(), + }) + return + } + + users, err := s.databaseManager.ListDatabaseUsers(&req.ConnectionConfig, req.Database) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "users": users, + }) +} + func (s *Server) createDatabaseInServer(c *gin.Context) { var req struct { database.ConnectionConfig diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 687aa9d..992a5f1 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -1044,3 +1044,83 @@ func TestProxySyncResultFailed(t *testing.T) { t.Error("expected Created to be false for failed") } } + +func TestListUsersByDatabaseRequestStructure(t *testing.T) { + type listUsersByDatabaseRequest struct { + Type string `json:"type"` + Host string `json:"host"` + Port int `json:"port"` + Username string `json:"username"` + Password string `json:"password"` + Database string `json:"database"` + } + + tests := []struct { + name string + req listUsersByDatabaseRequest + wantMissing string + }{ + { + name: "valid request", + req: listUsersByDatabaseRequest{ + Type: "mysql", + Host: "localhost", + Port: 3306, + Username: "root", + Password: "secret", + Database: "testdb", + }, + wantMissing: "", + }, + { + name: "missing database", + req: listUsersByDatabaseRequest{ + Type: "mysql", + Host: "localhost", + Port: 3306, + Username: "root", + Password: "secret", + }, + wantMissing: "database", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.wantMissing == "database" && tt.req.Database != "" { + t.Error("test setup error: database should be empty for missing database test") + } + if tt.wantMissing == "" && tt.req.Database == "" { + t.Error("test setup error: database should not be empty for valid request test") + } + }) + } +} + +func TestListUsersByDatabaseResponseStructure(t *testing.T) { + type userInfo struct { + Name string `json:"name"` + Host string `json:"host,omitempty"` + } + + type response struct { + Users []userInfo `json:"users"` + } + + resp := response{ + Users: []userInfo{ + {Name: "app_user", Host: "%"}, + {Name: "readonly", Host: "localhost"}, + }, + } + + if len(resp.Users) != 2 { + t.Errorf("expected 2 users, got %d", len(resp.Users)) + } + if resp.Users[0].Name != "app_user" { + t.Errorf("expected first user 'app_user', got '%s'", resp.Users[0].Name) + } + if resp.Users[1].Host != "localhost" { + t.Errorf("expected second user host 'localhost', got '%s'", resp.Users[1].Host) + } +} diff --git a/internal/api/traffic_handlers.go b/internal/api/traffic_handlers.go new file mode 100644 index 0000000..fde2b9b --- /dev/null +++ b/internal/api/traffic_handlers.go @@ -0,0 +1,172 @@ +package api + +import ( + "net/http" + "strconv" + "time" + + "github.com/flatrun/agent/internal/traffic" + "github.com/gin-gonic/gin" +) + +// ingestTrafficLog handles real-time traffic log ingestion from nginx Lua +func (s *Server) ingestTrafficLog(c *gin.Context) { + if s.trafficManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Traffic logging not enabled"}) + return + } + + var ingest traffic.IngestTrafficLog + if err := c.ShouldBindJSON(&ingest); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + log, err := s.trafficManager.IngestLog(&ingest) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusCreated, gin.H{"logged": true, "id": log.ID}) +} + +// getTrafficLogs returns a paginated list of traffic logs +func (s *Server) getTrafficLogs(c *gin.Context) { + if s.trafficManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Traffic logging not enabled"}) + return + } + + filter := &traffic.TrafficFilter{ + DeploymentName: c.Query("deployment"), + RequestMethod: c.Query("method"), + StatusGroup: c.Query("status_group"), + SourceIP: c.Query("source_ip"), + RequestPath: c.Query("path"), + } + + if statusCode := c.Query("status_code"); statusCode != "" { + if code, err := strconv.Atoi(statusCode); err == nil { + filter.StatusCode = &code + } + } + + if limit := c.Query("limit"); limit != "" { + if l, err := strconv.Atoi(limit); err == nil { + filter.Limit = l + } + } else { + filter.Limit = 100 + } + + if offset := c.Query("offset"); offset != "" { + if o, err := strconv.Atoi(offset); err == nil { + filter.Offset = o + } + } + + if startTime := c.Query("start_time"); startTime != "" { + if t, err := time.Parse(time.RFC3339, startTime); err == nil { + filter.StartTime = t + } + } + + if endTime := c.Query("end_time"); endTime != "" { + if t, err := time.Parse(time.RFC3339, endTime); err == nil { + filter.EndTime = t + } + } + + logs, total, err := s.trafficManager.GetLogs(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "logs": logs, + "total": total, + "limit": filter.Limit, + "offset": filter.Offset, + }) +} + +// getTrafficStats returns aggregated traffic statistics +func (s *Server) getTrafficStats(c *gin.Context) { + if s.trafficManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Traffic logging not enabled"}) + return + } + + deploymentName := c.Query("deployment") + + since := 24 * time.Hour + if sinceStr := c.Query("since"); sinceStr != "" { + if d, err := time.ParseDuration(sinceStr); err == nil { + since = d + } + } + + stats, err := s.trafficManager.GetStats(deploymentName, since) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"stats": stats}) +} + +// cleanupTrafficLogs removes old traffic logs +func (s *Server) cleanupTrafficLogs(c *gin.Context) { + if s.trafficManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Traffic logging not enabled"}) + return + } + + var req struct { + Days int `json:"days"` + } + if err := c.ShouldBindJSON(&req); err != nil { + req.Days = 7 + } + if req.Days <= 0 { + req.Days = 7 + } + + deleted, err := s.trafficManager.Cleanup(req.Days) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"deleted": deleted}) +} + +// getDeploymentTrafficStats returns traffic stats for a specific deployment +func (s *Server) getDeploymentTrafficStats(c *gin.Context) { + if s.trafficManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Traffic logging not enabled"}) + return + } + + name := c.Param("name") + + since := 24 * time.Hour + if sinceStr := c.Query("since"); sinceStr != "" { + if d, err := time.ParseDuration(sinceStr); err == nil { + since = d + } + } + + stats, err := s.trafficManager.GetStats(name, since) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "deployment": name, + "stats": stats, + }) +} diff --git a/internal/database/manager.go b/internal/database/manager.go index 8f5d237..59ab1ff 100644 --- a/internal/database/manager.go +++ b/internal/database/manager.go @@ -347,6 +347,65 @@ func (m *Manager) ListUsers(cfg *ConnectionConfig) ([]UserInfo, error) { return users, nil } +func (m *Manager) ListDatabaseUsers(cfg *ConnectionConfig, database string) ([]UserInfo, error) { + driver := m.getDriver(cfg.Type) + if driver == "" { + return nil, fmt.Errorf("unsupported database type: %s", cfg.Type) + } + + dsn, err := m.buildDSN(cfg) + if err != nil { + return nil, err + } + + db, err := sql.Open(driver, dsn) + if err != nil { + return nil, err + } + defer db.Close() + + database = strings.ReplaceAll(database, "'", "") + database = strings.ReplaceAll(database, "\"", "") + database = strings.ReplaceAll(database, ";", "") + + var query string + switch cfg.Type { + case "mysql", "mariadb": + query = fmt.Sprintf(` + SELECT DISTINCT User, Host FROM mysql.db WHERE Db = '%s' + UNION + SELECT DISTINCT User, Host FROM mysql.tables_priv WHERE Db = '%s' + UNION + SELECT DISTINCT User, Host FROM mysql.columns_priv WHERE Db = '%s' + `, database, database, database) + case "postgresql": + query = fmt.Sprintf(` + SELECT DISTINCT grantee, '' as host + FROM information_schema.role_table_grants + WHERE table_catalog = '%s' + `, database) + default: + return nil, fmt.Errorf("unsupported database type: %s", cfg.Type) + } + + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + var users []UserInfo + for rows.Next() { + var u UserInfo + if err := rows.Scan(&u.Name, &u.Host); err != nil { + continue + } + users = append(users, u) + } + + return users, nil +} + func (m *Manager) CreateDatabase(cfg *ConnectionConfig, dbName string) error { driver := m.getDriver(cfg.Type) if driver == "" { diff --git a/internal/database/manager_test.go b/internal/database/manager_test.go index 201ba81..5499c17 100644 --- a/internal/database/manager_test.go +++ b/internal/database/manager_test.go @@ -293,3 +293,54 @@ func TestBuildDSN(t *testing.T) { }) } } + +func TestListDatabaseUsers_UnsupportedType(t *testing.T) { + m := NewManager() + + cfg := &ConnectionConfig{ + Type: "mongodb", + Host: "localhost", + Port: 27017, + Username: "admin", + Password: "secret", + } + + _, err := m.ListDatabaseUsers(cfg, "testdb") + if err == nil { + t.Error("expected error for unsupported database type") + } + if !strings.Contains(err.Error(), "unsupported") { + t.Errorf("error should mention 'unsupported', got: %v", err) + } +} + +func TestListDatabaseUsers_SanitizesDatabaseName(t *testing.T) { + m := NewManager() + + tests := []struct { + name string + database string + }{ + {"single quotes", "test'db"}, + {"double quotes", "test\"db"}, + {"semicolon", "test;db"}, + {"multiple special chars", "test';\"db"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &ConnectionConfig{ + Type: "mysql", + Host: "localhost", + Port: 3306, + Username: "root", + Password: "secret", + } + + _, err := m.ListDatabaseUsers(cfg, tt.database) + if err != nil && strings.Contains(err.Error(), "syntax") { + t.Errorf("SQL injection should be prevented, got syntax error: %v", err) + } + }) + } +} diff --git a/internal/infra/manager.go b/internal/infra/manager.go index b9eea51..825e4d3 100644 --- a/internal/infra/manager.go +++ b/internal/infra/manager.go @@ -368,11 +368,16 @@ func (m *Manager) SetNginxRealtimeCaptureWithStatus(enabled bool) (map[string]in } } - // Create lua directory and write security.lua + // Create lua directory and write security.lua with injected agent IP if err := os.MkdirAll(luaDir, 0755); err != nil { errors = append(errors, fmt.Sprintf("failed to create lua directory: %v", err)) } else { - securityLua, err := templates.GetNginxSecurityLua() + agentIP := m.GetDockerHostIP() + agentPort := m.GetAgentPort() + result["agent_ip"] = agentIP + result["agent_port"] = agentPort + + securityLua, err := templates.GetNginxSecurityLuaWithConfig(agentIP, agentPort) if err != nil { errors = append(errors, fmt.Sprintf("failed to get security.lua template: %v", err)) } else { @@ -383,6 +388,16 @@ func (m *Manager) SetNginxRealtimeCaptureWithStatus(enabled bool) (map[string]in result["lua_files_written"] = true } } + + trafficLua, err := templates.GetNginxTrafficLuaWithConfig(agentIP, agentPort) + if err != nil { + errors = append(errors, fmt.Sprintf("failed to get traffic.lua template: %v", err)) + } else { + luaPath := filepath.Join(luaDir, "traffic.lua") + if err := os.WriteFile(luaPath, trafficLua, 0644); err != nil { + errors = append(errors, fmt.Sprintf("failed to write traffic.lua: %v", err)) + } + } } // Ensure conf.d directory and security config files exist @@ -472,7 +487,58 @@ func (m *Manager) IsNginxRunning() bool { return strings.TrimSpace(string(output)) == "true" } +// GetDockerHostIP returns the IP address that containers can use to reach the host. +// It tries multiple methods and falls back to the default Docker bridge gateway. +func (m *Manager) GetDockerHostIP() string { + // Method 1: Try to get host.docker.internal from nginx container's /etc/hosts + if m.config.Nginx.ContainerName != "" && m.IsNginxRunning() { + cmd := exec.Command("docker", "exec", m.config.Nginx.ContainerName, "sh", "-c", + "getent hosts host.docker.internal 2>/dev/null | awk '{print $1}'") + if output, err := cmd.Output(); err == nil { + ip := strings.TrimSpace(string(output)) + if ip != "" && ip != "host.docker.internal" { + return ip + } + } + + // Also try grepping /etc/hosts + cmd = exec.Command("docker", "exec", m.config.Nginx.ContainerName, "sh", "-c", + "grep host.docker.internal /etc/hosts 2>/dev/null | awk '{print $1}'") + if output, err := cmd.Output(); err == nil { + ip := strings.TrimSpace(string(output)) + if ip != "" { + return ip + } + } + } + + // Method 2: Try to get the Docker bridge gateway IP + cmd := exec.Command("docker", "network", "inspect", "bridge", "-f", + "{{range .IPAM.Config}}{{.Gateway}}{{end}}") + if output, err := cmd.Output(); err == nil { + ip := strings.TrimSpace(string(output)) + if ip != "" { + return ip + } + } + + // Fallback: Default Docker bridge gateway + return "172.17.0.1" +} + +// GetAgentPort returns the port the agent API is listening on +func (m *Manager) GetAgentPort() int { + if m.config.API.Port > 0 { + return m.config.API.Port + } + return 8090 +} + func (m *Manager) reloadNginx() error { + if err := m.waitForContainerReady(5); err != nil { + return fmt.Errorf("container not ready: %w", err) + } + reloadCmd := m.config.Nginx.ReloadCommand if reloadCmd == "" { reloadCmd = "nginx -s reload" @@ -486,6 +552,31 @@ func (m *Manager) reloadNginx() error { return nil } +func (m *Manager) waitForContainerReady(maxRetries int) error { + containerName := m.config.Nginx.ContainerName + for i := 0; i < maxRetries; i++ { + cmd := exec.Command("docker", "inspect", "-f", "{{.State.Status}}", containerName) + output, err := cmd.Output() + if err != nil { + return fmt.Errorf("failed to get container status: %w", err) + } + + status := strings.TrimSpace(string(output)) + if status == "running" { + cmd = exec.Command("docker", "inspect", "-f", "{{.State.Restarting}}", containerName) + output, err = cmd.Output() + if err == nil && strings.TrimSpace(string(output)) == "false" { + return nil + } + } + + if i < maxRetries-1 { + time.Sleep(time.Second) + } + } + return fmt.Errorf("container %s not ready after %d attempts", containerName, maxRetries) +} + func (m *Manager) getNginxDir() string { configPath := m.config.Nginx.ConfigPath if configPath == "" { @@ -520,32 +611,94 @@ func (m *Manager) CheckSecurityHealth() *SecurityHealthCheck { result.Details["nginx_dir"] = nginxDir result.Details["nginx_container"] = m.config.Nginx.ContainerName - // Check 1: security.lua exists - luaPath := filepath.Join(nginxDir, "lua", "security.lua") - if _, err := os.Stat(luaPath); err == nil { + // Check 1: security.lua exists and has correct agent IP + securityLuaPath := filepath.Join(nginxDir, "lua", "security.lua") + if content, err := os.ReadFile(securityLuaPath); err == nil { result.Checks["security_lua_exists"] = true - result.Details["security_lua_path"] = luaPath + result.Details["security_lua_path"] = securityLuaPath + + // Check if agent IP is properly configured + if strings.Contains(string(content), "host.docker.internal") { + result.Checks["security_lua_ip_injected"] = false + result.Issues = append(result.Issues, "Agent connection not configured in security module") + result.Recommendations = append(result.Recommendations, "Click 'Regenerate Scripts' in Security settings to configure agent connection") + } else { + result.Checks["security_lua_ip_injected"] = true + } } else { result.Checks["security_lua_exists"] = false - result.Issues = append(result.Issues, "security.lua does not exist at "+luaPath) + result.Checks["security_lua_ip_injected"] = false + result.Issues = append(result.Issues, "security.lua does not exist at "+securityLuaPath) result.Recommendations = append(result.Recommendations, "Enable realtime capture in Security settings to deploy security.lua") } + // Check 1b: traffic.lua exists and has correct agent IP + trafficLuaPath := filepath.Join(nginxDir, "lua", "traffic.lua") + if content, err := os.ReadFile(trafficLuaPath); err == nil { + result.Checks["traffic_lua_exists"] = true + result.Details["traffic_lua_path"] = trafficLuaPath + + if strings.Contains(string(content), "host.docker.internal") { + result.Checks["traffic_lua_ip_injected"] = false + result.Issues = append(result.Issues, "Agent connection not configured in traffic module") + result.Recommendations = append(result.Recommendations, "Click 'Regenerate Scripts' in Security settings to configure agent connection") + } else { + result.Checks["traffic_lua_ip_injected"] = true + } + } else { + result.Checks["traffic_lua_exists"] = false + result.Checks["traffic_lua_ip_injected"] = false + result.Issues = append(result.Issues, "traffic.lua does not exist at "+trafficLuaPath) + result.Recommendations = append(result.Recommendations, "Enable realtime capture to deploy traffic.lua for request logging") + } + + // Check 1c: Agent IP detection works + agentIP := m.GetDockerHostIP() + result.Details["detected_agent_ip"] = agentIP + result.Details["agent_port"] = m.GetAgentPort() + if agentIP != "" { + result.Checks["agent_ip_detected"] = true + } else { + result.Checks["agent_ip_detected"] = false + result.Issues = append(result.Issues, "Unable to detect agent network address") + } + // Check 2: nginx.conf exists and has Lua initialization nginxConfPath := filepath.Join(nginxDir, "nginx.conf") result.Details["nginx_conf_path"] = nginxConfPath if content, err := os.ReadFile(nginxConfPath); err == nil { result.Checks["nginx_conf_exists"] = true - if strings.Contains(string(content), "init_by_lua_block") { + contentStr := string(content) + + if strings.Contains(contentStr, "init_by_lua_block") { result.Checks["nginx_conf_has_lua_init"] = true } else { result.Checks["nginx_conf_has_lua_init"] = false result.Issues = append(result.Issues, "nginx.conf does not have init_by_lua_block directive") result.Recommendations = append(result.Recommendations, "Enable realtime capture to generate Lua-enabled nginx.conf") } + + // Check for traffic module loading + if strings.Contains(contentStr, "traffic = require") || strings.Contains(contentStr, "traffic.log_request") { + result.Checks["nginx_conf_has_traffic_module"] = true + } else { + result.Checks["nginx_conf_has_traffic_module"] = false + result.Issues = append(result.Issues, "nginx.conf does not load traffic module for request logging") + result.Recommendations = append(result.Recommendations, "Use POST /api/security/refresh to regenerate nginx.conf with traffic logging") + } + + // Check for global traffic logging + if strings.Contains(contentStr, "log_by_lua_block") && strings.Contains(contentStr, "traffic.log_request") { + result.Checks["nginx_conf_has_global_traffic_logging"] = true + } else { + result.Checks["nginx_conf_has_global_traffic_logging"] = false + result.Issues = append(result.Issues, "nginx.conf does not have global traffic logging enabled") + } } else { result.Checks["nginx_conf_exists"] = false result.Checks["nginx_conf_has_lua_init"] = false + result.Checks["nginx_conf_has_traffic_module"] = false + result.Checks["nginx_conf_has_global_traffic_logging"] = false result.Issues = append(result.Issues, "nginx.conf does not exist at "+nginxConfPath) result.Recommendations = append(result.Recommendations, "Enable realtime capture in Security settings") } @@ -596,13 +749,13 @@ func (m *Manager) CheckSecurityHealth() *SecurityHealthCheck { "Add volume mount to nginx docker-compose: "+nginxConfPath+":/usr/local/openresty/nginx/conf/nginx.conf:ro") } - // Check if extra_hosts is configured (for Linux) + // Check if nginx can reach the agent hasExtraHosts := m.checkNginxExtraHosts() result.Checks["nginx_extra_hosts_configured"] = hasExtraHosts if !hasExtraHosts { - result.Issues = append(result.Issues, "Nginx container may not be able to reach host.docker.internal") + result.Issues = append(result.Issues, "Nginx container cannot reach the agent") result.Recommendations = append(result.Recommendations, - "Add extra_hosts to nginx docker-compose: - \"host.docker.internal:host-gateway\"") + "Configure network access in your nginx docker-compose file") } } } else { @@ -610,21 +763,43 @@ func (m *Manager) CheckSecurityHealth() *SecurityHealthCheck { result.Issues = append(result.Issues, "Nginx container name not configured") } - // Check 6: Vhosts have log_by_lua_block directive + // Check 6: Vhosts with security enabled have log_by_lua_block directive vhostsWithHook, vhostsWithoutHook := m.checkVhostsSecurityHook() + deploymentsWithSecurityEnabled := m.getDeploymentsWithSecurityEnabled() + result.Details["vhosts_with_security_hook"] = vhostsWithHook result.Details["vhosts_without_security_hook"] = vhostsWithoutHook - if len(vhostsWithoutHook) > 0 { + result.Details["deployments_with_security_enabled"] = deploymentsWithSecurityEnabled + + // Find vhosts that SHOULD have hooks but don't + var missingHooks []string + for _, dep := range deploymentsWithSecurityEnabled { + hasHook := false + for _, v := range vhostsWithHook { + if v == dep { + hasHook = true + break + } + } + if !hasHook { + missingHooks = append(missingHooks, dep) + } + } + + result.Details["vhosts_missing_required_hooks"] = missingHooks + + if len(missingHooks) > 0 { result.Checks["vhosts_have_security_hook"] = false result.Issues = append(result.Issues, - fmt.Sprintf("%d vhost(s) missing log_by_lua_block: %v", len(vhostsWithoutHook), vhostsWithoutHook)) + fmt.Sprintf("%d deployment(s) have security enabled but vhost missing hooks: %v", len(missingHooks), missingHooks)) result.Recommendations = append(result.Recommendations, - "Add log_by_lua_block { security.capture_event() } to vhost server blocks, or use the regenerate vhosts API") - } else if len(vhostsWithHook) > 0 { + "Use PUT /api/deployments/:name/security to regenerate vhost with security hooks") + } else if len(deploymentsWithSecurityEnabled) > 0 { result.Checks["vhosts_have_security_hook"] = true } else { - result.Checks["vhosts_have_security_hook"] = false - result.Issues = append(result.Issues, "No vhost configurations found") + // No deployments have security enabled - that's fine, hooks not required + result.Checks["vhosts_have_security_hook"] = true + result.Details["note"] = "No deployments have per-deployment security enabled (traffic logging still works globally)" } // Check 7: Lua directory is mounted in nginx container @@ -638,15 +813,40 @@ func (m *Manager) CheckSecurityHealth() *SecurityHealthCheck { } } + // Check 8: DNS/Connectivity - Can nginx reach the agent? + if m.config.Nginx.ContainerName != "" && result.Checks["nginx_container_running"] { + agentIP := m.GetDockerHostIP() + agentPort := m.GetAgentPort() + canReachAgent := m.checkNginxCanReachAgent(agentIP, agentPort) + result.Checks["nginx_can_reach_agent"] = canReachAgent + result.Details["connectivity_test_ip"] = agentIP + result.Details["connectivity_test_port"] = agentPort + + if !canReachAgent { + result.Issues = append(result.Issues, + fmt.Sprintf("Nginx container cannot reach agent at %s:%d - Lua scripts will fail to send events", agentIP, agentPort)) + result.Recommendations = append(result.Recommendations, + "1. Check if agent is running and listening on the correct port") + result.Recommendations = append(result.Recommendations, + "2. Ensure nginx container has network access to host (extra_hosts or host network mode)") + result.Recommendations = append(result.Recommendations, + "3. Use POST /api/security/refresh to regenerate scripts with correct IP") + } + } + // Determine overall status criticalChecks := []string{ "security_lua_exists", + "security_lua_ip_injected", + "traffic_lua_exists", + "traffic_lua_ip_injected", "nginx_conf_has_lua_init", "nginx_container_running", "nginx_lua_module_loaded", "nginx_conf_mounted", - "vhosts_have_security_hook", "lua_directory_mounted", + "nginx_can_reach_agent", + "vhosts_have_security_hook", } failedCritical := 0 @@ -717,6 +917,36 @@ func (m *Manager) checkNginxExtraHosts() bool { return err == nil } +// getDeploymentsWithSecurityEnabled reads deployment metadata to find which have security enabled +func (m *Manager) getDeploymentsWithSecurityEnabled() []string { + var enabled []string + + entries, err := os.ReadDir(m.config.DeploymentsPath) + if err != nil { + return enabled + } + + for _, entry := range entries { + if !entry.IsDir() || strings.HasPrefix(entry.Name(), ".") { + continue + } + + metadataPath := filepath.Join(m.config.DeploymentsPath, entry.Name(), "service.yml") + content, err := os.ReadFile(metadataPath) + if err != nil { + continue + } + + contentStr := string(content) + if strings.Contains(contentStr, "security:") && + (strings.Contains(contentStr, "enabled: true") || strings.Contains(contentStr, "enabled: \"true\"")) { + enabled = append(enabled, entry.Name()) + } + } + + return enabled +} + func (m *Manager) checkVhostsSecurityHook() (withHook []string, withoutHook []string) { nginxDir := m.getNginxDir() confDir := filepath.Join(nginxDir, "conf.d") @@ -754,7 +984,7 @@ func (m *Manager) checkVhostsSecurityHook() (withHook []string, withoutHook []st func (m *Manager) checkNginxLuaDirectoryMounted() bool { cmd := exec.Command("docker", "exec", m.config.Nginx.ContainerName, "sh", "-c", - "test -f /etc/nginx/lua/security.lua && echo yes") + "test -f /etc/nginx/lua/security.lua && test -f /etc/nginx/lua/traffic.lua && echo yes") output, err := cmd.Output() if err != nil { return false @@ -762,6 +992,24 @@ func (m *Manager) checkNginxLuaDirectoryMounted() bool { return strings.TrimSpace(string(output)) == "yes" } +// checkNginxCanReachAgent tests if nginx container can reach the agent API +func (m *Manager) checkNginxCanReachAgent(agentIP string, agentPort int) bool { + // Try to connect to agent health endpoint from nginx container + // Use wget or curl depending on what's available in the container + testCmd := fmt.Sprintf( + "wget -q -O /dev/null --timeout=2 http://%s:%d/api/health 2>/dev/null && echo yes || "+ + "curl -s --connect-timeout 2 http://%s:%d/api/health >/dev/null 2>&1 && echo yes || "+ + "echo no", + agentIP, agentPort, agentIP, agentPort) + + cmd := exec.Command("docker", "exec", m.config.Nginx.ContainerName, "sh", "-c", testCmd) + output, err := cmd.Output() + if err != nil { + return false + } + return strings.Contains(string(output), "yes") +} + // securityVolumeMounts are added when security is enabled and removed when disabled var securityVolumeMounts = []string{ "./nginx.conf:/usr/local/openresty/nginx/conf/nginx.conf:ro", @@ -878,3 +1126,138 @@ func (m *Manager) removeSecurityVolumeMountsInternal() (bool, error) { return modified, nil } + +// RefreshSecurityScriptsResult contains the result of refreshing security scripts +type RefreshSecurityScriptsResult struct { + Success bool `json:"success"` + AgentIP string `json:"agent_ip"` + AgentPort int `json:"agent_port"` + NginxConfWritten bool `json:"nginx_conf_written"` + LuaWritten bool `json:"lua_written"` + VolumesModified bool `json:"volumes_modified"` + ContainerRecreated bool `json:"container_recreated"` + NginxReloaded bool `json:"nginx_reloaded"` + VhostsUpdated []string `json:"vhosts_updated,omitempty"` + Errors []string `json:"errors,omitempty"` +} + +// RefreshSecurityScripts regenerates all security configs: nginx.conf, Lua scripts, and vhosts +func (m *Manager) RefreshSecurityScripts() (*RefreshSecurityScriptsResult, error) { + m.mu.Lock() + defer m.mu.Unlock() + + result := &RefreshSecurityScriptsResult{ + Success: true, + Errors: []string{}, + VhostsUpdated: []string{}, + } + + // Get agent IP and port + agentIP := m.GetDockerHostIP() + agentPort := m.GetAgentPort() + result.AgentIP = agentIP + result.AgentPort = agentPort + + nginxDir := m.getNginxDir() + if nginxDir == "" { + result.Errors = append(result.Errors, "nginx config path not configured") + result.Success = false + return result, fmt.Errorf("nginx config path not configured") + } + + luaDir := filepath.Join(nginxDir, "lua") + confPath := filepath.Join(nginxDir, "nginx.conf") + + // Create directories + if err := os.MkdirAll(luaDir, 0755); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to create lua directory: %v", err)) + result.Success = false + return result, err + } + + confDir := filepath.Join(nginxDir, "conf.d") + if err := os.MkdirAll(confDir, 0755); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to create conf.d directory: %v", err)) + } + + // Write nginx.conf with Lua support + nginxConf, err := templates.GetNginxConfig(true) + if err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to get nginx lua config template: %v", err)) + } else { + if err := os.WriteFile(confPath, nginxConf, 0644); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to write nginx.conf: %v", err)) + } else { + result.NginxConfWritten = true + } + } + + // Generate and write security.lua with injected IP + securityLua, err := templates.GetNginxSecurityLuaWithConfig(agentIP, agentPort) + if err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to generate security.lua: %v", err)) + result.Success = false + return result, err + } + + securityLuaPath := filepath.Join(luaDir, "security.lua") + if err := os.WriteFile(securityLuaPath, securityLua, 0644); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to write security.lua: %v", err)) + result.Success = false + return result, err + } + result.LuaWritten = true + + // Generate and write traffic.lua with injected IP + trafficLua, err := templates.GetNginxTrafficLuaWithConfig(agentIP, agentPort) + if err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to generate traffic.lua: %v", err)) + } else { + trafficLuaPath := filepath.Join(luaDir, "traffic.lua") + if err := os.WriteFile(trafficLuaPath, trafficLua, 0644); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to write traffic.lua: %v", err)) + } + } + + // Ensure blocked_ips.conf exists + blockedIPsPath := filepath.Join(confDir, "blocked_ips.conf") + if _, err := os.Stat(blockedIPsPath); os.IsNotExist(err) { + content := "# Auto-generated - No blocked IPs\n" + if err := os.WriteFile(blockedIPsPath, []byte(content), 0644); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to create blocked_ips.conf: %v", err)) + } + } + + // Ensure rate_limits.conf exists + rateLimitsPath := filepath.Join(confDir, "rate_limits.conf") + if _, err := os.Stat(rateLimitsPath); os.IsNotExist(err) { + content := "# Auto-generated - No rate limit zones\n" + if err := os.WriteFile(rateLimitsPath, []byte(content), 0644); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to create rate_limits.conf: %v", err)) + } + } + + // Add volume mounts to docker-compose if needed + volumesModified, volumeErr := m.addSecurityVolumeMountsInternal() + if volumeErr != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to modify volume mounts: %v", volumeErr)) + } + result.VolumesModified = volumesModified + + // Recreate or reload nginx container + if volumesModified { + if err := m.recreateNginxContainer(); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to recreate nginx container: %v", err)) + } else { + result.ContainerRecreated = true + } + } else if m.IsNginxRunning() { + if err := m.reloadNginx(); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to reload nginx: %v", err)) + } else { + result.NginxReloaded = true + } + } + + return result, nil +} diff --git a/internal/nginx/manager.go b/internal/nginx/manager.go index 6afd775..32a9102 100644 --- a/internal/nginx/manager.go +++ b/internal/nginx/manager.go @@ -9,6 +9,7 @@ import ( "strings" "sync" "text/template" + "time" "github.com/flatrun/agent/pkg/config" "github.com/flatrun/agent/pkg/models" @@ -224,6 +225,10 @@ func (m *Manager) Reload() error { return fmt.Errorf("nginx container name not configured") } + if err := m.waitForContainerReady(5); err != nil { + return fmt.Errorf("container not ready: %w", err) + } + reloadCmd := m.config.ReloadCommand if reloadCmd == "" { reloadCmd = "nginx -s reload" @@ -243,6 +248,10 @@ func (m *Manager) TestConfig() error { return fmt.Errorf("nginx container name not configured") } + if err := m.waitForContainerReady(5); err != nil { + return fmt.Errorf("container not ready: %w", err) + } + cmd := exec.Command("docker", "exec", m.config.ContainerName, "nginx", "-t") output, err := cmd.CombinedOutput() if err != nil { @@ -252,6 +261,31 @@ func (m *Manager) TestConfig() error { return nil } +func (m *Manager) waitForContainerReady(maxRetries int) error { + for i := 0; i < maxRetries; i++ { + cmd := exec.Command("docker", "inspect", "-f", "{{.State.Status}}", m.config.ContainerName) + output, err := cmd.Output() + if err != nil { + return fmt.Errorf("failed to get container status: %w", err) + } + + status := strings.TrimSpace(string(output)) + if status == "running" { + // Also check it's not in a restart loop + cmd = exec.Command("docker", "inspect", "-f", "{{.State.Restarting}}", m.config.ContainerName) + output, err = cmd.Output() + if err == nil && strings.TrimSpace(string(output)) == "false" { + return nil + } + } + + if i < maxRetries-1 { + time.Sleep(time.Second) + } + } + return fmt.Errorf("container not ready after %d attempts", maxRetries) +} + func (m *Manager) generateConfig(deployment *models.Deployment) (string, error) { net := deployment.Metadata.Networking ssl := deployment.Metadata.SSL @@ -606,3 +640,105 @@ func (m *Manager) updateRateLimitsInternal(deploymentName string, rateLimits []m func (m *Manager) RemoveDeploymentRateLimits(deploymentName string) error { return m.UpdateDeploymentRateLimits(deploymentName, nil) } + +// ValidateSecurityHooks checks that a vhost has the correct security hooks based on the expected state +func (m *Manager) ValidateSecurityHooks(deploymentName string, shouldHaveHooks bool) error { + content, err := m.GetVirtualHost(deploymentName) + if err != nil { + return fmt.Errorf("failed to read vhost: %w", err) + } + + hasHooks := strings.Contains(content, "security.capture_event()") + hasLogByLua := strings.Contains(content, "log_by_lua_block") + + if shouldHaveHooks { + if !hasHooks { + return fmt.Errorf("security enabled but vhost missing security.capture_event() call") + } + if !hasLogByLua { + return fmt.Errorf("security enabled but vhost missing log_by_lua_block") + } + + // Check hooks are inside location blocks + lines := strings.Split(content, "\n") + inLocation := false + foundHookInLocation := false + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "location ") { + inLocation = true + } + if inLocation && strings.Contains(trimmed, "security.capture_event()") { + foundHookInLocation = true + } + if trimmed == "}" && inLocation { + inLocation = false + } + } + + if !foundHookInLocation { + return fmt.Errorf("security hook not properly placed inside location block") + } + } else { + if hasHooks { + return fmt.Errorf("security disabled but vhost still contains security.capture_event()") + } + } + + return nil +} + +// SecurityHookStatus returns details about security hooks in a vhost +type SecurityHookStatus struct { + HasHooks bool `json:"has_hooks"` + HookLocations []string `json:"hook_locations"` + ProperlyConfigured bool `json:"properly_configured"` +} + +// GetSecurityHookStatus returns detailed info about security hooks in a vhost +func (m *Manager) GetSecurityHookStatus(deploymentName string) (*SecurityHookStatus, error) { + content, err := m.GetVirtualHost(deploymentName) + if err != nil { + return nil, err + } + + status := &SecurityHookStatus{ + HasHooks: strings.Contains(content, "security.capture_event()"), + HookLocations: []string{}, + } + + lines := strings.Split(content, "\n") + currentLocation := "" + inLocation := false + depth := 0 + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + if strings.HasPrefix(trimmed, "location ") { + inLocation = true + depth = 1 + parts := strings.Fields(trimmed) + if len(parts) >= 2 { + currentLocation = parts[1] + } + } + + if inLocation { + depth += strings.Count(trimmed, "{") - strings.Count(trimmed, "}") + if strings.Contains(trimmed, "security.capture_event()") { + status.HookLocations = append(status.HookLocations, currentLocation) + } + if depth <= 0 { + inLocation = false + currentLocation = "" + } + } + } + + // Properly configured if hooks are present and all found in location blocks + status.ProperlyConfigured = status.HasHooks && len(status.HookLocations) > 0 + + return status, nil +} diff --git a/internal/nginx/manager_test.go b/internal/nginx/manager_test.go index 2226f92..16df67d 100644 --- a/internal/nginx/manager_test.go +++ b/internal/nginx/manager_test.go @@ -1444,6 +1444,195 @@ func TestSanitizeZoneName(t *testing.T) { } } +func TestValidateSecurityHooks(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "nginx-validate-hooks-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + confDir := filepath.Join(tmpDir, "nginx", "conf.d") + if err := os.MkdirAll(confDir, 0755); err != nil { + t.Fatalf("failed to create conf.d: %v", err) + } + + cfg := &config.NginxConfig{} + m := NewManager(cfg, tmpDir, "") + + t.Run("returns error when vhost does not exist", func(t *testing.T) { + err := m.ValidateSecurityHooks("nonexistent", true) + if err == nil { + t.Error("expected error for nonexistent vhost") + } + }) + + t.Run("validates hooks are present when expected", func(t *testing.T) { + configWithHook := `server { + listen 80; + server_name test.example.com; + + location / { + proxy_pass http://test:8080; + log_by_lua_block { + security.capture_event() + } + } +}` + if err := os.WriteFile(filepath.Join(confDir, "with-hook.conf"), []byte(configWithHook), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + err := m.ValidateSecurityHooks("with-hook", true) + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + + t.Run("returns error when hooks missing but expected", func(t *testing.T) { + configWithoutHook := `server { + listen 80; + server_name test.example.com; + + location / { + proxy_pass http://test:8080; + } +}` + if err := os.WriteFile(filepath.Join(confDir, "without-hook.conf"), []byte(configWithoutHook), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + err := m.ValidateSecurityHooks("without-hook", true) + if err == nil { + t.Error("expected error when hooks are missing but expected") + } + }) + + t.Run("validates hooks are absent when not expected", func(t *testing.T) { + configWithoutHook := `server { + listen 80; + server_name test.example.com; + + location / { + proxy_pass http://test:8080; + } +}` + if err := os.WriteFile(filepath.Join(confDir, "clean.conf"), []byte(configWithoutHook), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + err := m.ValidateSecurityHooks("clean", false) + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + + t.Run("returns error when hooks present but not expected", func(t *testing.T) { + configWithHook := `server { + listen 80; + server_name test.example.com; + + location / { + proxy_pass http://test:8080; + log_by_lua_block { + security.capture_event() + } + } +}` + if err := os.WriteFile(filepath.Join(confDir, "unwanted-hook.conf"), []byte(configWithHook), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + err := m.ValidateSecurityHooks("unwanted-hook", false) + if err == nil { + t.Error("expected error when hooks present but not expected") + } + }) +} + +func TestGetSecurityHookStatus(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "nginx-hook-status-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + confDir := filepath.Join(tmpDir, "nginx", "conf.d") + if err := os.MkdirAll(confDir, 0755); err != nil { + t.Fatalf("failed to create conf.d: %v", err) + } + + cfg := &config.NginxConfig{} + m := NewManager(cfg, tmpDir, "") + + t.Run("returns status for vhost with hooks", func(t *testing.T) { + configWithHook := `server { + listen 80; + server_name test.example.com; + + location / { + proxy_pass http://test:8080; + log_by_lua_block { + security.capture_event() + } + } + + location /api { + proxy_pass http://test:8080; + log_by_lua_block { + security.capture_event() + } + } +}` + if err := os.WriteFile(filepath.Join(confDir, "multi-hook.conf"), []byte(configWithHook), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + status, err := m.GetSecurityHookStatus("multi-hook") + if err != nil { + t.Fatalf("GetSecurityHookStatus failed: %v", err) + } + + if !status.HasHooks { + t.Error("expected HasHooks to be true") + } + if len(status.HookLocations) != 2 { + t.Errorf("expected 2 hook locations, got %d", len(status.HookLocations)) + } + if !status.ProperlyConfigured { + t.Error("expected ProperlyConfigured to be true") + } + }) + + t.Run("returns status for vhost without hooks", func(t *testing.T) { + configWithoutHook := `server { + listen 80; + server_name test.example.com; + + location / { + proxy_pass http://test:8080; + } +}` + if err := os.WriteFile(filepath.Join(confDir, "no-hook.conf"), []byte(configWithoutHook), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + status, err := m.GetSecurityHookStatus("no-hook") + if err != nil { + t.Fatalf("GetSecurityHookStatus failed: %v", err) + } + + if status.HasHooks { + t.Error("expected HasHooks to be false") + } + if len(status.HookLocations) != 0 { + t.Errorf("expected 0 hook locations, got %d", len(status.HookLocations)) + } + if status.ProperlyConfigured { + t.Error("expected ProperlyConfigured to be false") + } + }) +} + func TestUpdateDeploymentRateLimits(t *testing.T) { tmpDir, err := os.MkdirTemp("", "nginx-rate-limits-test-*") if err != nil { diff --git a/internal/security/nginx.go b/internal/security/nginx.go index 1b6e562..09866d9 100644 --- a/internal/security/nginx.go +++ b/internal/security/nginx.go @@ -243,7 +243,9 @@ const LuaSecurityScript = `-- FlatRun Security Event Capture local cjson = require "cjson" local http = require "resty.http" -local AGENT_URL = "{{.AgentURL}}" +-- Configuration (injected by agent during deployment) +local AGENT_IP = "{{.AgentIP}}" +local AGENT_PORT = {{.AgentPort}} -- Suspicious paths patterns local suspicious_patterns = { @@ -301,9 +303,6 @@ local function capture_event() -- Send event to agent API (non-blocking) ngx.timer.at(0, function() - local httpc = http.new() - httpc:set_timeout(5000) - local body = cjson.encode({ source_ip = ip, request_path = uri, @@ -314,10 +313,27 @@ local function capture_event() timestamp = ngx.time() }) - local res, err = httpc:request_uri(AGENT_URL .. "/api/security/events/ingest", { + local httpc = http.new() + httpc:set_timeout(2000) + + -- Connect directly using injected IP and port + local conn_ok, conn_err = httpc:connect({ + host = AGENT_IP, + port = AGENT_PORT, + scheme = "http", + }) + + if not conn_ok then + ngx.log(ngx.ERR, "Failed to connect to agent: ", conn_err) + return + end + + local res, err = httpc:request({ method = "POST", + path = "/api/security/events/ingest", body = body, headers = { + ["Host"] = AGENT_IP .. ":" .. AGENT_PORT, ["Content-Type"] = "application/json", } }) @@ -336,8 +352,8 @@ return { } ` -// GenerateLuaScript generates the security.lua file -func GenerateLuaScript(agentURL string) (string, error) { +// GenerateLuaScript generates the security.lua file with injected agent IP and port +func GenerateLuaScript(agentIP string, agentPort int) (string, error) { tmpl, err := template.New("lua").Parse(LuaSecurityScript) if err != nil { return "", err @@ -345,9 +361,11 @@ func GenerateLuaScript(agentURL string) (string, error) { var buf bytes.Buffer err = tmpl.Execute(&buf, struct { - AgentURL string + AgentIP string + AgentPort int }{ - AgentURL: agentURL, + AgentIP: agentIP, + AgentPort: agentPort, }) if err != nil { return "", err diff --git a/internal/traffic/db.go b/internal/traffic/db.go new file mode 100644 index 0000000..3f56a0d --- /dev/null +++ b/internal/traffic/db.go @@ -0,0 +1,380 @@ +package traffic + +import ( + "database/sql" + "os" + "path/filepath" + "sync" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +type DB struct { + conn *sql.DB + path string + mu sync.RWMutex +} + +func NewDB(deploymentsPath string) (*DB, error) { + dbDir := filepath.Join(deploymentsPath, ".flatrun") + if err := os.MkdirAll(dbDir, 0755); err != nil { + return nil, err + } + + dbPath := filepath.Join(dbDir, "traffic.db") + conn, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_busy_timeout=5000") + if err != nil { + return nil, err + } + + conn.SetMaxOpenConns(1) + conn.SetMaxIdleConns(1) + conn.SetConnMaxLifetime(time.Hour) + + db := &DB{conn: conn, path: dbPath} + if err := db.migrate(); err != nil { + conn.Close() + return nil, err + } + + return db, nil +} + +func (db *DB) Close() error { + db.mu.Lock() + defer db.mu.Unlock() + return db.conn.Close() +} + +func (db *DB) migrate() error { + schema := ` + CREATE TABLE IF NOT EXISTS traffic_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + deployment_name TEXT NOT NULL, + request_path TEXT, + request_method TEXT, + status_code INTEGER, + source_ip TEXT, + response_time_ms INTEGER, + bytes_sent INTEGER, + request_length INTEGER, + upstream_time_ms INTEGER, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_traffic_deployment ON traffic_logs(deployment_name); + CREATE INDEX IF NOT EXISTS idx_traffic_created ON traffic_logs(created_at DESC); + CREATE INDEX IF NOT EXISTS idx_traffic_status ON traffic_logs(status_code); + CREATE INDEX IF NOT EXISTS idx_traffic_source_ip ON traffic_logs(source_ip); + ` + + _, err := db.conn.Exec(schema) + return err +} + +func (db *DB) InsertLog(log *TrafficLog) (int64, error) { + db.mu.Lock() + defer db.mu.Unlock() + + result, err := db.conn.Exec(` + INSERT INTO traffic_logs + (deployment_name, request_path, request_method, status_code, source_ip, response_time_ms, bytes_sent, request_length, upstream_time_ms, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + log.DeploymentName, log.RequestPath, log.RequestMethod, log.StatusCode, + log.SourceIP, log.ResponseTimeMs, log.BytesSent, log.RequestLength, + log.UpstreamTimeMs, log.CreatedAt, + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +func (db *DB) GetLogs(filter *TrafficFilter) ([]TrafficLog, int, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + query := "SELECT id, deployment_name, request_path, request_method, status_code, source_ip, response_time_ms, bytes_sent, request_length, upstream_time_ms, created_at FROM traffic_logs WHERE 1=1" + countQuery := "SELECT COUNT(*) FROM traffic_logs WHERE 1=1" + args := []interface{}{} + + if filter.DeploymentName != "" { + query += " AND deployment_name = ?" + countQuery += " AND deployment_name = ?" + args = append(args, filter.DeploymentName) + } + if filter.RequestMethod != "" { + query += " AND request_method = ?" + countQuery += " AND request_method = ?" + args = append(args, filter.RequestMethod) + } + if filter.StatusCode != nil { + query += " AND status_code = ?" + countQuery += " AND status_code = ?" + args = append(args, *filter.StatusCode) + } + if filter.StatusGroup != "" { + switch filter.StatusGroup { + case "2xx": + query += " AND status_code >= 200 AND status_code < 300" + countQuery += " AND status_code >= 200 AND status_code < 300" + case "3xx": + query += " AND status_code >= 300 AND status_code < 400" + countQuery += " AND status_code >= 300 AND status_code < 400" + case "4xx": + query += " AND status_code >= 400 AND status_code < 500" + countQuery += " AND status_code >= 400 AND status_code < 500" + case "5xx": + query += " AND status_code >= 500" + countQuery += " AND status_code >= 500" + } + } + if filter.SourceIP != "" { + query += " AND source_ip = ?" + countQuery += " AND source_ip = ?" + args = append(args, filter.SourceIP) + } + if filter.RequestPath != "" { + query += " AND request_path LIKE ?" + countQuery += " AND request_path LIKE ?" + args = append(args, "%"+filter.RequestPath+"%") + } + if !filter.StartTime.IsZero() { + query += " AND created_at >= ?" + countQuery += " AND created_at >= ?" + args = append(args, filter.StartTime) + } + if !filter.EndTime.IsZero() { + query += " AND created_at <= ?" + countQuery += " AND created_at <= ?" + args = append(args, filter.EndTime) + } + + var total int + if err := db.conn.QueryRow(countQuery, args...).Scan(&total); err != nil { + return nil, 0, err + } + + query += " ORDER BY created_at DESC" + if filter.Limit > 0 { + query += " LIMIT ?" + args = append(args, filter.Limit) + } + if filter.Offset > 0 { + query += " OFFSET ?" + args = append(args, filter.Offset) + } + + rows, err := db.conn.Query(query, args...) + if err != nil { + return nil, 0, err + } + defer rows.Close() + + var logs []TrafficLog + for rows.Next() { + var log TrafficLog + var upstreamTimeMs sql.NullInt64 + if err := rows.Scan(&log.ID, &log.DeploymentName, &log.RequestPath, &log.RequestMethod, + &log.StatusCode, &log.SourceIP, &log.ResponseTimeMs, &log.BytesSent, + &log.RequestLength, &upstreamTimeMs, &log.CreatedAt); err != nil { + return nil, 0, err + } + if upstreamTimeMs.Valid { + val := int(upstreamTimeMs.Int64) + log.UpstreamTimeMs = &val + } + logs = append(logs, log) + } + + return logs, total, nil +} + +func (db *DB) GetStats(deploymentName string, since time.Duration) (*TrafficStats, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + stats := &TrafficStats{ + ByStatusGroup: make(map[string]int64), + ByDeployment: make(map[string]int64), + ByMethod: make(map[string]int64), + } + + sinceTime := time.Now().Add(-since) + deploymentFilter := "" + args := []interface{}{sinceTime} + + if deploymentName != "" { + deploymentFilter = " AND deployment_name = ?" + args = append(args, deploymentName) + } + + // Total requests and bytes + var avgTime sql.NullFloat64 + err := db.conn.QueryRow(` + SELECT COUNT(*), COALESCE(SUM(bytes_sent), 0), AVG(response_time_ms) + FROM traffic_logs WHERE created_at >= ?`+deploymentFilter, args...). + Scan(&stats.TotalRequests, &stats.TotalBytes, &avgTime) + if err != nil { + return nil, err + } + if avgTime.Valid { + stats.AvgResponseTimeMs = avgTime.Float64 + } + + // By status group + rows, err := db.conn.Query(` + SELECT + CASE + WHEN status_code >= 200 AND status_code < 300 THEN '2xx' + WHEN status_code >= 300 AND status_code < 400 THEN '3xx' + WHEN status_code >= 400 AND status_code < 500 THEN '4xx' + WHEN status_code >= 500 THEN '5xx' + ELSE 'other' + END as status_group, + COUNT(*) as cnt + FROM traffic_logs + WHERE created_at >= ?`+deploymentFilter+` + GROUP BY status_group`, args...) + if err == nil { + defer rows.Close() + for rows.Next() { + var group string + var count int64 + if err := rows.Scan(&group, &count); err == nil { + stats.ByStatusGroup[group] = count + } + } + } + + // By deployment + rows, err = db.conn.Query(` + SELECT deployment_name, COUNT(*) as cnt + FROM traffic_logs + WHERE created_at >= ?`+deploymentFilter+` + GROUP BY deployment_name + ORDER BY cnt DESC + LIMIT 20`, args...) + if err == nil { + defer rows.Close() + for rows.Next() { + var name string + var count int64 + if err := rows.Scan(&name, &count); err == nil { + stats.ByDeployment[name] = count + } + } + } + + // By method + rows, err = db.conn.Query(` + SELECT request_method, COUNT(*) as cnt + FROM traffic_logs + WHERE created_at >= ?`+deploymentFilter+` + GROUP BY request_method`, args...) + if err == nil { + defer rows.Close() + for rows.Next() { + var method string + var count int64 + if err := rows.Scan(&method, &count); err == nil { + stats.ByMethod[method] = count + } + } + } + + // Top paths + rows, err = db.conn.Query(` + SELECT request_path, COUNT(*) as cnt, AVG(response_time_ms) as avg_time, + SUM(CASE WHEN status_code >= 400 THEN 1 ELSE 0 END) as errors + FROM traffic_logs + WHERE created_at >= ?`+deploymentFilter+` + GROUP BY request_path + ORDER BY cnt DESC + LIMIT 10`, args...) + if err == nil { + defer rows.Close() + for rows.Next() { + var p PathStats + if err := rows.Scan(&p.Path, &p.RequestCount, &p.AvgTimeMs, &p.ErrorCount); err == nil { + stats.TopPaths = append(stats.TopPaths, p) + } + } + } + + // Top IPs + rows, err = db.conn.Query(` + SELECT source_ip, COUNT(*) as cnt, SUM(bytes_sent) as bytes, MAX(created_at) as last_seen + FROM traffic_logs + WHERE created_at >= ?`+deploymentFilter+` + GROUP BY source_ip + ORDER BY cnt DESC + LIMIT 10`, args...) + if err == nil { + defer rows.Close() + for rows.Next() { + var ip IPTrafficStats + if err := rows.Scan(&ip.IP, &ip.RequestCount, &ip.BytesSent, &ip.LastSeen); err == nil { + stats.TopIPs = append(stats.TopIPs, ip) + } + } + } + + // Requests per hour (last 24 hours) + rows, err = db.conn.Query(` + SELECT strftime('%Y-%m-%d %H:00', created_at) as hour, COUNT(*) as cnt + FROM traffic_logs + WHERE created_at >= datetime('now', '-24 hours')`+deploymentFilter+` + GROUP BY hour + ORDER BY hour ASC`, args...) + if err == nil { + defer rows.Close() + for rows.Next() { + var h HourlyStats + if err := rows.Scan(&h.Hour, &h.RequestCount); err == nil { + stats.RequestsPerHour = append(stats.RequestsPerHour, h) + } + } + } + + // Deployment stats + rows, err = db.conn.Query(` + SELECT deployment_name, COUNT(*) as total, + AVG(response_time_ms) as avg_time, + SUM(CASE WHEN status_code >= 200 AND status_code < 300 THEN 1 ELSE 0 END) as s2xx, + SUM(CASE WHEN status_code >= 300 AND status_code < 400 THEN 1 ELSE 0 END) as s3xx, + SUM(CASE WHEN status_code >= 400 AND status_code < 500 THEN 1 ELSE 0 END) as s4xx, + SUM(CASE WHEN status_code >= 500 THEN 1 ELSE 0 END) as s5xx + FROM traffic_logs + WHERE created_at >= ?`+deploymentFilter+` + GROUP BY deployment_name + ORDER BY total DESC`, args...) + if err == nil { + defer rows.Close() + for rows.Next() { + var d DeploymentTrafficStats + if err := rows.Scan(&d.Name, &d.TotalRequests, &d.AvgResponseTime, + &d.Status2xx, &d.Status3xx, &d.Status4xx, &d.Status5xx); err == nil { + if d.TotalRequests > 0 { + d.ErrorRate = float64(d.Status4xx+d.Status5xx) / float64(d.TotalRequests) * 100 + } + stats.DeploymentStats = append(stats.DeploymentStats, d) + } + } + } + + return stats, nil +} + +func (db *DB) Cleanup(olderThan time.Duration) (int64, error) { + db.mu.Lock() + defer db.mu.Unlock() + + cutoff := time.Now().Add(-olderThan) + result, err := db.conn.Exec("DELETE FROM traffic_logs WHERE created_at < ?", cutoff) + if err != nil { + return 0, err + } + return result.RowsAffected() +} diff --git a/internal/traffic/manager.go b/internal/traffic/manager.go new file mode 100644 index 0000000..a1044e4 --- /dev/null +++ b/internal/traffic/manager.go @@ -0,0 +1,97 @@ +package traffic + +import ( + "log" + "time" +) + +type Manager struct { + db *DB + retentionDays int +} + +func NewManager(deploymentsPath string, retentionDays int) (*Manager, error) { + db, err := NewDB(deploymentsPath) + if err != nil { + return nil, err + } + + if retentionDays <= 0 { + retentionDays = 7 + } + + m := &Manager{ + db: db, + retentionDays: retentionDays, + } + + go m.cleanupLoop() + + return m, nil +} + +func (m *Manager) Close() error { + return m.db.Close() +} + +func (m *Manager) IngestLog(ingest *IngestTrafficLog) (*TrafficLog, error) { + log := &TrafficLog{ + DeploymentName: ingest.DeploymentName, + RequestPath: ingest.RequestPath, + RequestMethod: ingest.RequestMethod, + StatusCode: ingest.StatusCode, + SourceIP: ingest.SourceIP, + ResponseTimeMs: ingest.ResponseTimeMs, + BytesSent: ingest.BytesSent, + RequestLength: ingest.RequestLength, + UpstreamTimeMs: ingest.UpstreamTimeMs, + CreatedAt: time.Now(), + } + + if ingest.Timestamp > 0 { + log.CreatedAt = time.Unix(ingest.Timestamp, 0) + } + + id, err := m.db.InsertLog(log) + if err != nil { + return nil, err + } + log.ID = id + + return log, nil +} + +func (m *Manager) GetLogs(filter *TrafficFilter) ([]TrafficLog, int, error) { + if filter.Limit <= 0 { + filter.Limit = 100 + } + return m.db.GetLogs(filter) +} + +func (m *Manager) GetStats(deploymentName string, since time.Duration) (*TrafficStats, error) { + if since <= 0 { + since = 24 * time.Hour + } + return m.db.GetStats(deploymentName, since) +} + +func (m *Manager) Cleanup(days int) (int64, error) { + if days <= 0 { + days = m.retentionDays + } + return m.db.Cleanup(time.Duration(days) * 24 * time.Hour) +} + +func (m *Manager) cleanupLoop() { + ticker := time.NewTicker(6 * time.Hour) + defer ticker.Stop() + + for range ticker.C { + deleted, err := m.Cleanup(m.retentionDays) + if err != nil { + log.Printf("Traffic cleanup error: %v", err) + } else if deleted > 0 { + log.Printf("Traffic cleanup: deleted %d old logs", deleted) + } + } +} diff --git a/internal/traffic/models.go b/internal/traffic/models.go new file mode 100644 index 0000000..5408b18 --- /dev/null +++ b/internal/traffic/models.go @@ -0,0 +1,86 @@ +package traffic + +import "time" + +type TrafficLog struct { + ID int64 `json:"id"` + DeploymentName string `json:"deployment_name"` + RequestPath string `json:"request_path"` + RequestMethod string `json:"request_method"` + StatusCode int `json:"status_code"` + SourceIP string `json:"source_ip"` + ResponseTimeMs int `json:"response_time_ms"` + BytesSent int `json:"bytes_sent"` + RequestLength int `json:"request_length"` + UpstreamTimeMs *int `json:"upstream_time_ms,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type TrafficFilter struct { + DeploymentName string + RequestMethod string + StatusCode *int + StatusGroup string // "2xx", "3xx", "4xx", "5xx" + SourceIP string + RequestPath string + StartTime time.Time + EndTime time.Time + Limit int + Offset int +} + +type TrafficStats struct { + TotalRequests int64 `json:"total_requests"` + TotalBytes int64 `json:"total_bytes"` + AvgResponseTimeMs float64 `json:"avg_response_time_ms"` + ByStatusGroup map[string]int64 `json:"by_status_group"` + ByDeployment map[string]int64 `json:"by_deployment"` + ByMethod map[string]int64 `json:"by_method"` + TopPaths []PathStats `json:"top_paths"` + TopIPs []IPTrafficStats `json:"top_ips"` + RequestsPerHour []HourlyStats `json:"requests_per_hour"` + DeploymentStats []DeploymentTrafficStats `json:"deployment_stats"` +} + +type PathStats struct { + Path string `json:"path"` + RequestCount int64 `json:"request_count"` + AvgTimeMs float64 `json:"avg_time_ms"` + ErrorCount int64 `json:"error_count"` +} + +type IPTrafficStats struct { + IP string `json:"ip"` + RequestCount int64 `json:"request_count"` + BytesSent int64 `json:"bytes_sent"` + LastSeen time.Time `json:"last_seen"` +} + +type HourlyStats struct { + Hour string `json:"hour"` + RequestCount int64 `json:"request_count"` +} + +type DeploymentTrafficStats struct { + Name string `json:"name"` + TotalRequests int64 `json:"total_requests"` + AvgResponseTime float64 `json:"avg_response_time_ms"` + ErrorRate float64 `json:"error_rate"` + Status2xx int64 `json:"status_2xx"` + Status3xx int64 `json:"status_3xx"` + Status4xx int64 `json:"status_4xx"` + Status5xx int64 `json:"status_5xx"` +} + +type IngestTrafficLog struct { + DeploymentName string `json:"deployment_name"` + RequestPath string `json:"request_path"` + RequestMethod string `json:"request_method"` + StatusCode int `json:"status_code"` + SourceIP string `json:"source_ip"` + ResponseTimeMs int `json:"response_time_ms"` + BytesSent int `json:"bytes_sent"` + RequestLength int `json:"request_length"` + UpstreamTimeMs *int `json:"upstream_time_ms,omitempty"` + Timestamp int64 `json:"timestamp"` +} diff --git a/templates/infra/nginx/lua/security.lua b/templates/infra/nginx/lua/security.lua index 68b1414..9ce40bd 100644 --- a/templates/infra/nginx/lua/security.lua +++ b/templates/infra/nginx/lua/security.lua @@ -6,8 +6,9 @@ local http = require "resty.http" local _M = {} --- Configuration (will be set by the agent) -local AGENT_URL = os.getenv("FLATRUN_AGENT_URL") or "http://host.docker.internal:8090" +-- Configuration (injected by agent during deployment) +local AGENT_IP = "{{.AgentIP}}" +local AGENT_PORT = {{.AgentPort}} -- Suspicious paths patterns local suspicious_patterns = { @@ -120,9 +121,6 @@ function _M.capture_event() local ok, err = ngx.timer.at(0, function(premature) if premature then return end - local httpc = http.new() - httpc:set_timeout(5000) - local body, encode_err = cjson.encode({ source_ip = ip, request_path = uri, @@ -138,10 +136,27 @@ function _M.capture_event() return end - local res, req_err = httpc:request_uri(AGENT_URL .. "/api/security/events/ingest", { + local httpc = http.new() + httpc:set_timeout(2000) + + -- Connect directly using injected IP and port + local conn_ok, conn_err = httpc:connect({ + host = AGENT_IP, + port = AGENT_PORT, + scheme = "http", + }) + + if not conn_ok then + ngx.log(ngx.ERR, "Failed to connect to agent: ", conn_err) + return + end + + local res, req_err = httpc:request({ method = "POST", + path = "/api/security/events/ingest", body = body, headers = { + ["Host"] = AGENT_IP .. ":" .. AGENT_PORT, ["Content-Type"] = "application/json", } }) diff --git a/templates/infra/nginx/lua/traffic.lua b/templates/infra/nginx/lua/traffic.lua new file mode 100644 index 0000000..a2e62bf --- /dev/null +++ b/templates/infra/nginx/lua/traffic.lua @@ -0,0 +1,85 @@ +-- FlatRun Traffic Logging +-- This script logs all requests to the agent API for traffic statistics + +local cjson = require "cjson.safe" +local http = require "resty.http" + +local _M = {} + +-- Configuration (injected by agent during deployment) +local AGENT_IP = "{{.AgentIP}}" +local AGENT_PORT = {{.AgentPort}} + +function _M.log_request() + local status = ngx.status + local host = ngx.var.host or "" + local uri = ngx.var.uri or "" + local method = ngx.var.request_method or "" + local ip = ngx.var.remote_addr or "" + local request_time = ngx.var.request_time or "0" + local bytes_sent = ngx.var.bytes_sent or "0" + local request_length = ngx.var.request_length or "0" + local upstream_response_time = ngx.var.upstream_response_time or "" + + -- Extract deployment name from host (remove port if present) + local deployment_name = host:match("^([^:]+)") or host + + -- Non-blocking: fire and forget via timer + local ok, err = ngx.timer.at(0, function(premature) + if premature then return end + + local body, encode_err = cjson.encode({ + deployment_name = deployment_name, + request_path = uri, + request_method = method, + status_code = status, + source_ip = ip, + response_time_ms = math.floor((tonumber(request_time) or 0) * 1000), + bytes_sent = tonumber(bytes_sent) or 0, + request_length = tonumber(request_length) or 0, + upstream_time_ms = upstream_response_time ~= "" and math.floor((tonumber(upstream_response_time) or 0) * 1000) or nil, + timestamp = ngx.time() + }) + + if not body then + ngx.log(ngx.ERR, "Failed to encode traffic log: ", encode_err) + return + end + + local httpc = http.new() + httpc:set_timeout(2000) + + local conn_ok, conn_err = httpc:connect({ + host = AGENT_IP, + port = AGENT_PORT, + scheme = "http", + }) + + if not conn_ok then + ngx.log(ngx.ERR, "Failed to connect to agent for traffic log: ", conn_err) + return + end + + local res, req_err = httpc:request({ + method = "POST", + path = "/api/traffic/ingest", + body = body, + headers = { + ["Host"] = AGENT_IP .. ":" .. AGENT_PORT, + ["Content-Type"] = "application/json", + } + }) + + if not res then + ngx.log(ngx.ERR, "Failed to send traffic log: ", req_err) + end + + httpc:close() + end) + + if not ok then + ngx.log(ngx.ERR, "Failed to create timer for traffic log: ", err) + end +end + +return _M diff --git a/templates/infra/nginx/nginx.lua.conf b/templates/infra/nginx/nginx.lua.conf index caf5eab..af443d6 100644 --- a/templates/infra/nginx/nginx.lua.conf +++ b/templates/infra/nginx/nginx.lua.conf @@ -42,9 +42,15 @@ http { lua_shared_dict security_events 10m; lua_shared_dict ip_rate_limit 10m; - # Load security module + # Load Lua modules init_by_lua_block { security = require "security" + traffic = require "traffic" + } + + # Global traffic logging - logs ALL requests + log_by_lua_block { + traffic.log_request() } # Docker DNS resolver diff --git a/templates/templates.go b/templates/templates.go index 6159c7b..440ee8b 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -1,9 +1,11 @@ package templates import ( + "bytes" "embed" "io/fs" "path/filepath" + "text/template" ) //go:embed */metadata.yml */docker-compose.yml @@ -87,3 +89,59 @@ func GetNginxConfig(luaEnabled bool) ([]byte, error) { func GetNginxSecurityLua() ([]byte, error) { return FS.ReadFile("infra/nginx/lua/security.lua") } + +// LuaTemplateData contains the data for Lua template processing +type LuaTemplateData struct { + AgentIP string + AgentPort int +} + +// GetNginxSecurityLuaWithConfig returns the security.lua template processed with agent config +func GetNginxSecurityLuaWithConfig(agentIP string, agentPort int) ([]byte, error) { + content, err := FS.ReadFile("infra/nginx/lua/security.lua") + if err != nil { + return nil, err + } + + tmpl, err := template.New("security.lua").Parse(string(content)) + if err != nil { + return nil, err + } + + var buf bytes.Buffer + data := LuaTemplateData{ + AgentIP: agentIP, + AgentPort: agentPort, + } + + if err := tmpl.Execute(&buf, data); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// GetNginxTrafficLuaWithConfig returns the traffic.lua template processed with agent config +func GetNginxTrafficLuaWithConfig(agentIP string, agentPort int) ([]byte, error) { + content, err := FS.ReadFile("infra/nginx/lua/traffic.lua") + if err != nil { + return nil, err + } + + tmpl, err := template.New("traffic.lua").Parse(string(content)) + if err != nil { + return nil, err + } + + var buf bytes.Buffer + data := LuaTemplateData{ + AgentIP: agentIP, + AgentPort: agentPort, + } + + if err := tmpl.Execute(&buf, data); err != nil { + return nil, err + } + + return buf.Bytes(), nil +}