diff --git a/internal/api/security_handlers.go b/internal/api/security_handlers.go index 8d8de60..83fee8a 100644 --- a/internal/api/security_handlers.go +++ b/internal/api/security_handlers.go @@ -264,7 +264,84 @@ func (s *Server) unblockIP(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "IP unblocked successfully"}) } -// getEventsByIP returns all security events for a specific IP +func (s *Server) listWhitelist(c *gin.Context) { + if s.securityManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Security module not enabled"}) + return + } + + entries, err := s.securityManager.GetWhitelist() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"whitelist": entries}) +} + +func (s *Server) addWhitelistEntry(c *gin.Context) { + if s.securityManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Security module not enabled"}) + return + } + + var req struct { + Value string `json:"value" binding:"required"` + Type string `json:"type" binding:"required"` + Reason string `json:"reason"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Type != "ip" && req.Type != "cidr" && req.Type != "path" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Type must be 'ip', 'cidr', or 'path'"}) + return + } + + id, err := s.securityManager.AddWhitelistEntry(req.Value, req.Type, req.Reason) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusCreated, gin.H{"id": id}) +} + +func (s *Server) removeWhitelistEntry(c *gin.Context) { + if s.securityManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Security module not enabled"}) + return + } + + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid ID"}) + return + } + + if err := s.securityManager.RemoveWhitelistEntry(id); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "Entry removed"}) +} + +func (s *Server) listWhitelistInternal(c *gin.Context) { + token := c.GetHeader("X-Internal-Token") + expectedToken := s.config.Security.InternalAPIToken + + if token == "" || token != expectedToken { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid internal token"}) + return + } + + s.listWhitelist(c) +} + func (s *Server) getEventsByIP(c *gin.Context) { if s.securityManager == nil { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Security module not enabled"}) diff --git a/internal/api/server.go b/internal/api/server.go index 8fdc324..fbaec7b 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -107,6 +107,11 @@ func New(cfg *config.Config, configPath string) *Server { if err := securityManager.InitNginxConfigs(nginxConfigPath); err != nil { log.Printf("Warning: Failed to initialize security nginx configs: %v", err) } + // Add Docker gateway IP to whitelist + gatewayIP := infraManager.GetDockerHostIP() + if err := securityManager.AddDockerGatewayToWhitelist(gatewayIP); err != nil { + log.Printf("Warning: Failed to add Docker gateway to whitelist: %v", err) + } } } @@ -284,6 +289,9 @@ func (s *Server) setupRoutes() { protected.POST("/security/protected-routes", s.addProtectedRoute) protected.PUT("/security/protected-routes/:id", s.updateProtectedRoute) protected.DELETE("/security/protected-routes/:id", s.deleteProtectedRoute) + protected.GET("/security/whitelist", s.listWhitelist) + protected.POST("/security/whitelist", s.addWhitelistEntry) + protected.DELETE("/security/whitelist/:id", s.removeWhitelistEntry) protected.GET("/security/realtime-capture", s.getRealtimeCaptureStatus) protected.PUT("/security/realtime-capture", s.setRealtimeCaptureStatus) protected.GET("/security/health", s.getSecurityHealth) @@ -295,6 +303,7 @@ func (s *Server) setupRoutes() { // Traffic endpoints protected.GET("/traffic/logs", s.getTrafficLogs) protected.GET("/traffic/stats", s.getTrafficStats) + protected.GET("/traffic/unknown-domains", s.getUnknownDomainStats) protected.POST("/traffic/cleanup", s.cleanupTrafficLogs) protected.GET("/deployments/:name/traffic", s.getDeploymentTrafficStats) } @@ -303,8 +312,9 @@ func (s *Server) setupRoutes() { api.POST("/security/events/ingest", s.ingestSecurityEvent) api.POST("/traffic/ingest", s.ingestTrafficLog) - // Internal nginx endpoint - token-authenticated for blocked IPs + // Internal nginx endpoints - token-authenticated api.GET("/_internal/blocked-ips", s.listBlockedIPsInternal) + api.GET("/_internal/whitelist", s.listWhitelistInternal) } } @@ -1185,12 +1195,13 @@ func (s *Server) getSettings(c *gin.Context) { "subdomain_style": s.config.Domain.SubdomainStyle, }, "nginx": gin.H{ - "enabled": s.config.Nginx.Enabled, - "image": s.config.Nginx.Image, - "container_name": s.config.Nginx.ContainerName, - "config_path": s.config.Nginx.ConfigPath, - "reload_command": s.config.Nginx.ReloadCommand, - "external": s.config.Nginx.External, + "enabled": s.config.Nginx.Enabled, + "image": s.config.Nginx.Image, + "container_name": s.config.Nginx.ContainerName, + "config_path": s.config.Nginx.ConfigPath, + "reload_command": s.config.Nginx.ReloadCommand, + "external": s.config.Nginx.External, + "reject_unknown_domains": s.config.Nginx.RejectUnknownDomains, }, "certbot": gin.H{ "enabled": s.config.Certbot.Enabled, @@ -1241,12 +1252,13 @@ func (s *Server) updateSettings(c *gin.Context) { SubdomainStyle string `json:"subdomain_style"` } `json:"domain,omitempty"` Nginx *struct { - Enabled bool `json:"enabled"` - Image string `json:"image"` - ContainerName string `json:"container_name"` - ConfigPath string `json:"config_path"` - ReloadCommand string `json:"reload_command"` - External bool `json:"external"` + Enabled bool `json:"enabled"` + Image string `json:"image"` + ContainerName string `json:"container_name"` + ConfigPath string `json:"config_path"` + ReloadCommand string `json:"reload_command"` + External bool `json:"external"` + RejectUnknownDomains *bool `json:"reject_unknown_domains"` } `json:"nginx,omitempty"` Certbot *struct { Enabled bool `json:"enabled"` @@ -1318,6 +1330,9 @@ func (s *Server) updateSettings(c *gin.Context) { if req.Nginx.ReloadCommand != "" { s.config.Nginx.ReloadCommand = req.Nginx.ReloadCommand } + if req.Nginx.RejectUnknownDomains != nil { + s.config.Nginx.RejectUnknownDomains = *req.Nginx.RejectUnknownDomains + } } if req.Certbot != nil { @@ -1426,12 +1441,13 @@ func (s *Server) updateSettings(c *gin.Context) { "subdomain_style": s.config.Domain.SubdomainStyle, }, "nginx": gin.H{ - "enabled": s.config.Nginx.Enabled, - "image": s.config.Nginx.Image, - "container_name": s.config.Nginx.ContainerName, - "config_path": s.config.Nginx.ConfigPath, - "reload_command": s.config.Nginx.ReloadCommand, - "external": s.config.Nginx.External, + "enabled": s.config.Nginx.Enabled, + "image": s.config.Nginx.Image, + "container_name": s.config.Nginx.ContainerName, + "config_path": s.config.Nginx.ConfigPath, + "reload_command": s.config.Nginx.ReloadCommand, + "external": s.config.Nginx.External, + "reject_unknown_domains": s.config.Nginx.RejectUnknownDomains, }, "certbot": gin.H{ "enabled": s.config.Certbot.Enabled, @@ -2793,6 +2809,16 @@ func (s *Server) getSystemStats(c *gin.Context) { imageStats, _ := s.networksManager.GetImageStats() volumeStats, _ := s.networksManager.GetVolumeStats() + var networkCount, portCount int + if networks, err := s.networksManager.ListNetworks(); err == nil { + networkCount = len(networks) + } + if containers, err := s.networksManager.ListContainers(); err == nil { + for _, container := range containers { + portCount += len(container.Ports) + } + } + systemStats, _ := system.GetSystemStats() c.JSON(http.StatusOK, gin.H{ @@ -2800,6 +2826,8 @@ func (s *Server) getSystemStats(c *gin.Context) { "containers": containerStats, "images": imageStats, "volumes": volumeStats, + "networks": gin.H{"total": networkCount}, + "ports": gin.H{"total": portCount}, "system": systemStats, }) } diff --git a/internal/api/traffic_handlers.go b/internal/api/traffic_handlers.go index fde2b9b..0e7fdc2 100644 --- a/internal/api/traffic_handlers.go +++ b/internal/api/traffic_handlers.go @@ -117,6 +117,39 @@ func (s *Server) getTrafficStats(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"stats": stats}) } +func (s *Server) getUnknownDomainStats(c *gin.Context) { + if s.trafficManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Traffic logging not enabled"}) + return + } + + since := 24 * time.Hour + if sinceStr := c.Query("since"); sinceStr != "" { + if d, err := time.ParseDuration(sinceStr); err == nil { + since = d + } + } + + deployments, err := s.manager.ListDeployments() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + var knownDeployments []string + for _, d := range deployments { + knownDeployments = append(knownDeployments, d.Name) + } + + stats, err := s.trafficManager.GetUnknownDomainStats(knownDeployments, since) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"stats": stats}) +} + // cleanupTrafficLogs removes old traffic logs func (s *Server) cleanupTrafficLogs(c *gin.Context) { if s.trafficManager == nil { diff --git a/internal/infra/manager.go b/internal/infra/manager.go index f957c54..472e47b 100644 --- a/internal/infra/manager.go +++ b/internal/infra/manager.go @@ -357,7 +357,9 @@ func (m *Manager) SetNginxRealtimeCaptureWithStatus(enabled bool) (map[string]in if enabled { // Write nginx.conf with Lua support - nginxConf, err := templates.GetNginxConfig(true) + nginxConf, err := templates.GetNginxConfigWithData(true, templates.NginxConfigData{ + RejectUnknownDomains: m.config.Nginx.RejectUnknownDomains, + }) if err != nil { errors = append(errors, fmt.Sprintf("failed to get nginx lua config template: %v", err)) } else { @@ -414,6 +416,13 @@ func (m *Manager) SetNginxRealtimeCaptureWithStatus(enabled bool) (map[string]in } result["conf_files_written"] = true } + + // Ensure ssl directory exists + sslDir := filepath.Join(nginxDir, "ssl") + if err := os.MkdirAll(sslDir, 0755); err != nil { + errors = append(errors, fmt.Sprintf("failed to create ssl directory: %v", err)) + } + } else { // Delete nginx.conf - container will use default from image if _, err := os.Stat(confPath); err == nil { @@ -1049,6 +1058,7 @@ func (m *Manager) checkNginxInternalAPIReachable() bool { var securityVolumeMounts = []string{ "./nginx.conf:/usr/local/openresty/nginx/conf/nginx.conf:ro", "./lua:/etc/nginx/lua:ro", + "./ssl:/etc/nginx/ssl:ro", } func (m *Manager) getNginxComposePath() string { @@ -1215,8 +1225,15 @@ func (m *Manager) RefreshSecurityScripts() (*RefreshSecurityScriptsResult, error result.Errors = append(result.Errors, fmt.Sprintf("failed to create conf.d directory: %v", err)) } + sslDir := filepath.Join(nginxDir, "ssl") + if err := os.MkdirAll(sslDir, 0755); err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to create ssl directory: %v", err)) + } + // Write nginx.conf with Lua support - nginxConf, err := templates.GetNginxConfig(true) + nginxConf, err := templates.GetNginxConfigWithData(true, templates.NginxConfigData{ + RejectUnknownDomains: m.config.Nginx.RejectUnknownDomains, + }) if err != nil { result.Errors = append(result.Errors, fmt.Sprintf("failed to get nginx lua config template: %v", err)) } else { diff --git a/internal/security/db.go b/internal/security/db.go index b7a0c6c..314cb8b 100644 --- a/internal/security/db.go +++ b/internal/security/db.go @@ -38,6 +38,12 @@ func NewDB(deploymentsPath string) (*DB, error) { return nil, err } + // Seed default whitelist entries + if err := db.SeedDefaultWhitelist(); err != nil { + conn.Close() + return nil, err + } + return db, nil } @@ -81,6 +87,18 @@ func (db *DB) migrate() error { CREATE INDEX IF NOT EXISTS idx_blocked_ips_ip ON blocked_ips(ip); CREATE INDEX IF NOT EXISTS idx_blocked_ips_expires ON blocked_ips(expires_at); + CREATE TABLE IF NOT EXISTS whitelist ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + value TEXT NOT NULL UNIQUE, + type TEXT NOT NULL CHECK (type IN ('ip', 'cidr', 'path')), + reason TEXT, + is_internal BOOLEAN DEFAULT FALSE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_whitelist_value ON whitelist(value); + CREATE INDEX IF NOT EXISTS idx_whitelist_type ON whitelist(type); + CREATE TABLE IF NOT EXISTS protected_routes ( id INTEGER PRIMARY KEY AUTOINCREMENT, path_pattern TEXT NOT NULL, @@ -339,6 +357,100 @@ func (db *DB) IsIPBlocked(ip string) (bool, error) { return count > 0, nil } +// GetWhitelist retrieves all whitelist entries +func (db *DB) GetWhitelist() ([]WhitelistEntry, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + rows, err := db.conn.Query(` + SELECT id, value, type, reason, is_internal, created_at + FROM whitelist + ORDER BY is_internal DESC, created_at DESC`) + if err != nil { + return nil, err + } + defer rows.Close() + + var entries []WhitelistEntry + for rows.Next() { + var e WhitelistEntry + var reason sql.NullString + if err := rows.Scan(&e.ID, &e.Value, &e.Type, &reason, &e.IsInternal, &e.CreatedAt); err != nil { + return nil, err + } + e.Reason = reason.String + entries = append(entries, e) + } + + return entries, nil +} + +// AddWhitelistEntry adds a new entry to the whitelist +func (db *DB) AddWhitelistEntry(value, entryType, reason string, isInternal bool) (int64, error) { + db.mu.Lock() + defer db.mu.Unlock() + + result, err := db.conn.Exec(` + INSERT INTO whitelist (value, type, reason, is_internal) + VALUES (?, ?, ?, ?) + ON CONFLICT(value) DO UPDATE SET reason = ?, is_internal = ?`, + value, entryType, reason, isInternal, reason, isInternal, + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +// RemoveWhitelistEntry removes an entry from the whitelist +func (db *DB) RemoveWhitelistEntry(id int64) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec("DELETE FROM whitelist WHERE id = ? AND is_internal = FALSE", id) + return err +} + +// IsWhitelisted checks if an IP or path is in the whitelist +func (db *DB) IsWhitelisted(value string) (bool, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var count int + err := db.conn.QueryRow("SELECT COUNT(*) FROM whitelist WHERE value = ?", value).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} + +// SeedDefaultWhitelist adds default internal whitelist entries if not present +func (db *DB) SeedDefaultWhitelist() error { + defaults := []struct { + Value string + Type string + Reason string + }{ + {"127.0.0.1", "ip", "Localhost"}, + {"10.0.0.0/8", "cidr", "Private network"}, + {"172.16.0.0/12", "cidr", "Docker/Private network"}, + {"192.168.0.0/16", "cidr", "Private network"}, + {"/_internal", "path", "Internal API"}, + {"/api/_internal", "path", "Internal API"}, + {"/api/health", "path", "Health check"}, + {"/api/security/events/ingest", "path", "Security ingest"}, + {"/api/traffic/ingest", "path", "Traffic ingest"}, + } + + for _, d := range defaults { + _, err := db.AddWhitelistEntry(d.Value, d.Type, d.Reason, true) + if err != nil { + return err + } + } + return nil +} + // GetProtectedRoutes retrieves all protected routes func (db *DB) GetProtectedRoutes() ([]ProtectedRoute, error) { db.mu.RLock() diff --git a/internal/security/detector.go b/internal/security/detector.go index a474e12..e9301b7 100644 --- a/internal/security/detector.go +++ b/internal/security/detector.go @@ -11,12 +11,12 @@ type Detector struct { mu sync.RWMutex // Thresholds - windowDuration time.Duration - rateThreshold int // high request rate - notFoundThreshold int // 404 responses - authFailureThreshold int // 401/403 responses - uniquePathsThreshold int // scanning many different paths - repeatedHitsThreshold int // hammering same path + windowDuration time.Duration + rateThreshold int // high request rate + notFoundThreshold int // 404 responses + authFailureThreshold int // 401/403 responses + uniquePathsThreshold int // scanning many different paths + repeatedHitsThreshold int // hammering same path } type requestWindow struct { @@ -31,11 +31,11 @@ func NewDetector() *Detector { return &Detector{ ipRequestCount: make(map[string]*requestWindow), windowDuration: 2 * time.Minute, - rateThreshold: 60, // 60 requests in 2 min - notFoundThreshold: 10, // 10 404s in 2 min - authFailureThreshold: 5, // 5 auth failures in 2 min - uniquePathsThreshold: 20, // 20 different paths in 2 min - repeatedHitsThreshold: 30, // 30 hits to same path in 2 min + rateThreshold: 60, // 60 requests in 2 min + notFoundThreshold: 10, // 10 404s in 2 min + authFailureThreshold: 5, // 5 auth failures in 2 min + uniquePathsThreshold: 20, // 20 different paths in 2 min + repeatedHitsThreshold: 30, // 30 hits to same path in 2 min } } diff --git a/internal/security/manager.go b/internal/security/manager.go index cae4cab..68b22f2 100644 --- a/internal/security/manager.go +++ b/internal/security/manager.go @@ -175,6 +175,30 @@ func (m *Manager) DeleteProtectedRoute(id int64) error { return m.db.DeleteProtectedRoute(id) } +func (m *Manager) GetWhitelist() ([]WhitelistEntry, error) { + return m.db.GetWhitelist() +} + +func (m *Manager) AddWhitelistEntry(value, entryType, reason string) (int64, error) { + return m.db.AddWhitelistEntry(value, entryType, reason, false) +} + +func (m *Manager) RemoveWhitelistEntry(id int64) error { + return m.db.RemoveWhitelistEntry(id) +} + +func (m *Manager) IsWhitelisted(value string) (bool, error) { + return m.db.IsWhitelisted(value) +} + +func (m *Manager) AddDockerGatewayToWhitelist(gatewayIP string) error { + if gatewayIP == "" { + return nil + } + _, err := m.db.AddWhitelistEntry(gatewayIP, "ip", "Docker gateway", true) + return err +} + // Cleanup removes old events and expired blocks func (m *Manager) Cleanup(retentionDays int) (int64, int64, error) { eventsDeleted, err := m.db.CleanupOldEvents(time.Duration(retentionDays) * 24 * time.Hour) diff --git a/internal/security/models.go b/internal/security/models.go index 31dc6b4..19641e5 100644 --- a/internal/security/models.go +++ b/internal/security/models.go @@ -46,6 +46,15 @@ type ProtectedRoute struct { CreatedAt time.Time `json:"created_at"` } +type WhitelistEntry struct { + ID int64 `json:"id"` + Value string `json:"value"` + Type string `json:"type"` // "ip", "cidr", or "path" + Reason string `json:"reason,omitempty"` + IsInternal bool `json:"is_internal"` + CreatedAt time.Time `json:"created_at"` +} + type SecurityStats struct { TotalEvents int `json:"total_events"` Last24Hours int `json:"last_24_hours"` diff --git a/internal/traffic/db.go b/internal/traffic/db.go index ff85a95..873e30a 100644 --- a/internal/traffic/db.go +++ b/internal/traffic/db.go @@ -378,3 +378,127 @@ func (db *DB) Cleanup(olderThan time.Duration) (int64, error) { } return result.RowsAffected() } + +func (db *DB) GetUnknownDomainStats(knownDeployments []string, since time.Duration) (*UnknownDomainStats, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + stats := &UnknownDomainStats{ + TopDomains: []UnknownDomainEntry{}, + TopIPs: []UnknownDomainIPEntry{}, + RecentLogs: []TrafficLog{}, + } + + cutoff := time.Now().Add(-since) + + placeholders := "" + args := []interface{}{cutoff} + for i, d := range knownDeployments { + if i > 0 { + placeholders += "," + } + placeholders += "?" + args = append(args, d) + } + + notInClause := "" + if len(knownDeployments) > 0 { + notInClause = " AND deployment_name NOT IN (" + placeholders + ")" + } + + // Total count + var total int64 + err := db.conn.QueryRow(` + SELECT COUNT(*) FROM traffic_logs + WHERE created_at >= ?`+notInClause, args...).Scan(&total) + if err != nil { + return nil, err + } + stats.TotalRequests = total + + // Top domains + rows, err := db.conn.Query(` + SELECT deployment_name, COUNT(*) as cnt, MAX(created_at) as last_seen + FROM traffic_logs + WHERE created_at >= ?`+notInClause+` + GROUP BY deployment_name + ORDER BY cnt DESC + LIMIT 20`, args...) + if err == nil { + defer rows.Close() + for rows.Next() { + var entry UnknownDomainEntry + if err := rows.Scan(&entry.Domain, &entry.RequestCount, &entry.LastSeen); err == nil { + stats.TopDomains = append(stats.TopDomains, entry) + } + } + } + + // Top IPs with domains they accessed + rows, err = db.conn.Query(` + SELECT source_ip, COUNT(*) as cnt, + GROUP_CONCAT(DISTINCT deployment_name) as domains, + MAX(created_at) as last_seen + FROM traffic_logs + WHERE created_at >= ?`+notInClause+` + GROUP BY source_ip + ORDER BY cnt DESC + LIMIT 20`, args...) + if err == nil { + defer rows.Close() + for rows.Next() { + var entry UnknownDomainIPEntry + var domainsStr string + if err := rows.Scan(&entry.IP, &entry.RequestCount, &domainsStr, &entry.LastSeen); err == nil { + if domainsStr != "" { + entry.Domains = append(entry.Domains, splitString(domainsStr, ",")...) + } + stats.TopIPs = append(stats.TopIPs, entry) + } + } + } + + // Recent logs + rows, err = db.conn.Query(` + SELECT id, deployment_name, request_path, request_method, status_code, + source_ip, response_time_ms, bytes_sent, request_length, upstream_time_ms, created_at + FROM traffic_logs + WHERE created_at >= ?`+notInClause+` + ORDER BY created_at DESC + LIMIT 50`, args...) + if err == nil { + defer rows.Close() + for rows.Next() { + var log TrafficLog + var upstreamTime sql.NullInt64 + if err := rows.Scan(&log.ID, &log.DeploymentName, &log.RequestPath, + &log.RequestMethod, &log.StatusCode, &log.SourceIP, + &log.ResponseTimeMs, &log.BytesSent, &log.RequestLength, + &upstreamTime, &log.CreatedAt); err == nil { + if upstreamTime.Valid { + t := int(upstreamTime.Int64) + log.UpstreamTimeMs = &t + } + stats.RecentLogs = append(stats.RecentLogs, log) + } + } + } + + return stats, nil +} + +func splitString(s, sep string) []string { + if s == "" { + return nil + } + var result []string + start := 0 + for i := 0; i < len(s); i++ { + if s[i] == sep[0] { + result = append(result, s[start:i]) + start = i + 1 + } + } + result = append(result, s[start:]) + return result +} diff --git a/internal/traffic/manager.go b/internal/traffic/manager.go index a1044e4..5b02215 100644 --- a/internal/traffic/manager.go +++ b/internal/traffic/manager.go @@ -75,6 +75,13 @@ func (m *Manager) GetStats(deploymentName string, since time.Duration) (*Traffic return m.db.GetStats(deploymentName, since) } +func (m *Manager) GetUnknownDomainStats(knownDeployments []string, since time.Duration) (*UnknownDomainStats, error) { + if since <= 0 { + since = 24 * time.Hour + } + return m.db.GetUnknownDomainStats(knownDeployments, since) +} + func (m *Manager) Cleanup(days int) (int64, error) { if days <= 0 { days = m.retentionDays diff --git a/internal/traffic/models.go b/internal/traffic/models.go index 6eddb71..83933f3 100644 --- a/internal/traffic/models.go +++ b/internal/traffic/models.go @@ -85,3 +85,23 @@ type IngestTrafficLog struct { UpstreamTimeMs *int `json:"upstream_time_ms,omitempty"` Timestamp int64 `json:"timestamp"` } + +type UnknownDomainStats struct { + TotalRequests int64 `json:"total_requests"` + TopDomains []UnknownDomainEntry `json:"top_domains"` + TopIPs []UnknownDomainIPEntry `json:"top_ips"` + RecentLogs []TrafficLog `json:"recent_logs"` +} + +type UnknownDomainEntry struct { + Domain string `json:"domain"` + RequestCount int64 `json:"request_count"` + LastSeen time.Time `json:"last_seen"` +} + +type UnknownDomainIPEntry struct { + IP string `json:"ip"` + RequestCount int64 `json:"request_count"` + Domains []string `json:"domains"` + LastSeen time.Time `json:"last_seen"` +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 6901a5d..7f67bef 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -51,6 +51,7 @@ type NginxConfig struct { ReloadCommand string `yaml:"reload_command" json:"reload_command"` External bool `yaml:"external" json:"external"` ContainerWebrootPath string `yaml:"container_webroot_path" json:"container_webroot_path"` + RejectUnknownDomains bool `yaml:"reject_unknown_domains" json:"reject_unknown_domains"` } type CertbotConfig struct { diff --git a/templates/infra/nginx/lua/security.lua b/templates/infra/nginx/lua/security.lua index 79303f9..55c40b2 100644 --- a/templates/infra/nginx/lua/security.lua +++ b/templates/infra/nginx/lua/security.lua @@ -11,10 +11,10 @@ local AGENT_IP = "{{.AgentIP}}" local AGENT_PORT = {{.AgentPort}} local INTERNAL_TOKEN = "{{.InternalAPIToken}}" --- Blocked IPs cache settings -local BLOCKED_IPS_CACHE_TTL = 30 -- seconds -local BLOCKED_IPS_CACHE_KEY = "blocked_ips_list" +-- Cache settings +local CACHE_TTL = 30 -- seconds local BLOCKED_IPS_LAST_FETCH = "blocked_ips_last_fetch" +local WHITELIST_LAST_FETCH = "whitelist_last_fetch" -- Suspicious paths patterns local suspicious_patterns = { @@ -63,14 +63,34 @@ local scanner_patterns = { "zgrab", } --- Check if an IP is blocked (with caching) +local function get_real_client_ip() + local cf_ip = ngx.var.http_cf_connecting_ip + if cf_ip and cf_ip ~= "" then + return cf_ip + end + + local xff = ngx.var.http_x_forwarded_for + if xff and xff ~= "" then + local first_ip = xff:match("^([^,]+)") + if first_ip then + return first_ip:match("^%s*(.-)%s*$") + end + end + + return ngx.var.remote_addr +end + +function _M.get_client_ip() + return get_real_client_ip() +end + function _M.is_blocked(ip) if not ip then return false end + if _M.is_whitelisted(ip, nil) then return false end local dict = ngx.shared.blocked_ips if not dict then return false end - -- Check if this specific IP is marked as blocked local is_blocked = dict:get("ip:" .. ip) if is_blocked ~= nil then return is_blocked @@ -80,7 +100,7 @@ function _M.is_blocked(ip) local last_fetch = dict:get(BLOCKED_IPS_LAST_FETCH) or 0 local now = ngx.time() - if now - last_fetch > BLOCKED_IPS_CACHE_TTL then + if now - last_fetch > CACHE_TTL then -- Refresh in background to not block the request ngx.timer.at(0, function() _M.refresh_blocked_ips() @@ -145,20 +165,215 @@ function _M.refresh_blocked_ips() local blocked_ips = data.blocked_ips or {} for _, entry in ipairs(blocked_ips) do if entry.ip then - dict:set("ip:" .. entry.ip, true, BLOCKED_IPS_CACHE_TTL * 2) + dict:set("ip:" .. entry.ip, true, CACHE_TTL * 2) end end ngx.log(ngx.INFO, "Refreshed blocked IPs cache: ", #blocked_ips, " IPs") end --- Initialize blocked IPs cache on worker start function _M.init_blocked_ips() ngx.timer.at(0, function() _M.refresh_blocked_ips() end) end +local function is_ipv6(ip) + return ip:find(":") ~= nil +end + +local function ipv4_to_int(ip_str) + local parts = {ip_str:match("^(%d+)%.(%d+)%.(%d+)%.(%d+)$")} + if #parts ~= 4 then return nil end + return tonumber(parts[1]) * 16777216 + tonumber(parts[2]) * 65536 + + tonumber(parts[3]) * 256 + tonumber(parts[4]) +end + +local function expand_ipv6(ip) + if ip:find("::") then + local left, right = ip:match("^(.-)::(.*)$") + left = left or "" + right = right or "" + local left_parts = {} + local right_parts = {} + for part in left:gmatch("[^:]+") do + left_parts[#left_parts + 1] = part + end + for part in right:gmatch("[^:]+") do + right_parts[#right_parts + 1] = part + end + local missing = 8 - #left_parts - #right_parts + local parts = {} + for _, p in ipairs(left_parts) do parts[#parts + 1] = p end + for _ = 1, missing do parts[#parts + 1] = "0" end + for _, p in ipairs(right_parts) do parts[#parts + 1] = p end + return parts + else + local parts = {} + for part in ip:gmatch("[^:]+") do + parts[#parts + 1] = part + end + return parts + end +end + +local function ipv6_parts_to_ints(parts) + if #parts ~= 8 then return nil end + local ints = {} + for i, p in ipairs(parts) do + ints[i] = tonumber(p, 16) or 0 + end + return ints +end + +local function ipv6_match_cidr(ip_ints, cidr_ints, bits) + local full_groups = math.floor(bits / 16) + local remaining_bits = bits % 16 + + for i = 1, full_groups do + if ip_ints[i] ~= cidr_ints[i] then return false end + end + + if remaining_bits > 0 and full_groups < 8 then + local mask = bit.lshift(0xFFFF, 16 - remaining_bits) + mask = bit.band(mask, 0xFFFF) + if bit.band(ip_ints[full_groups + 1], mask) ~= bit.band(cidr_ints[full_groups + 1], mask) then + return false + end + end + + return true +end + +local function is_ip_in_cidr(ip, cidr) + local cidr_ip, cidr_bits = cidr:match("^(.+)/(%d+)$") + if not cidr_ip then return ip == cidr end + + local bits = tonumber(cidr_bits) + local ip_is_v6 = is_ipv6(ip) + local cidr_is_v6 = is_ipv6(cidr_ip) + + if ip_is_v6 ~= cidr_is_v6 then return false end + + if ip_is_v6 then + local ip_parts = expand_ipv6(ip) + local cidr_parts = expand_ipv6(cidr_ip) + local ip_ints = ipv6_parts_to_ints(ip_parts) + local cidr_ints = ipv6_parts_to_ints(cidr_parts) + if not ip_ints or not cidr_ints then return false end + return ipv6_match_cidr(ip_ints, cidr_ints, bits) + else + local ip_int = ipv4_to_int(ip) + local cidr_int = ipv4_to_int(cidr_ip) + if not ip_int or not cidr_int then return false end + local mask = bits == 0 and 0 or (0xFFFFFFFF - (2^(32 - bits) - 1)) + return bit.band(ip_int, mask) == bit.band(cidr_int, mask) + end +end + +function _M.is_whitelisted(ip, path) + local dict = ngx.shared.whitelist + if not dict then return false end + + if ip then + if dict:get("ip:" .. ip) then return true end + local cidrs = dict:get("cidrs") + if cidrs then + for cidr in cidrs:gmatch("[^,]+") do + if is_ip_in_cidr(ip, cidr) then return true end + end + end + end + + if path then + local paths = dict:get("paths") + if paths then + for wpath in paths:gmatch("[^,]+") do + if path:sub(1, #wpath) == wpath then return true end + end + end + end + + local last_fetch = dict:get(WHITELIST_LAST_FETCH) or 0 + if ngx.time() - last_fetch > CACHE_TTL then + ngx.timer.at(0, function() _M.refresh_whitelist() end) + end + + return false +end + +function _M.refresh_whitelist() + local dict = ngx.shared.whitelist + if not dict then return end + + local httpc = http.new() + httpc:set_timeout(3000) + + local conn_ok, conn_err = httpc:connect({ + host = AGENT_IP, + port = AGENT_PORT, + scheme = "http", + }) + + if not conn_ok then + ngx.log(ngx.ERR, "Failed to connect to agent for whitelist: ", conn_err) + return + end + + local res, req_err = httpc:request({ + method = "GET", + path = "/api/_internal/whitelist", + headers = { + ["Host"] = AGENT_IP .. ":" .. AGENT_PORT, + ["X-Internal-Token"] = INTERNAL_TOKEN, + } + }) + + if not res then + ngx.log(ngx.ERR, "Failed to fetch whitelist: ", req_err) + httpc:close() + return + end + + local body = res:read_body() + httpc:close() + + if res.status ~= 200 then + ngx.log(ngx.ERR, "Whitelist API returned status: ", res.status) + return + end + + local data, decode_err = cjson.decode(body) + if not data then + ngx.log(ngx.ERR, "Failed to decode whitelist response: ", decode_err) + return + end + + dict:flush_all() + dict:set(WHITELIST_LAST_FETCH, ngx.time()) + + local ips, cidrs, paths = {}, {}, {} + for _, entry in ipairs(data.whitelist or {}) do + if entry.type == "ip" then + dict:set("ip:" .. entry.value, true, CACHE_TTL * 2) + table.insert(ips, entry.value) + elseif entry.type == "cidr" then + table.insert(cidrs, entry.value) + elseif entry.type == "path" then + table.insert(paths, entry.value) + end + end + + if #cidrs > 0 then dict:set("cidrs", table.concat(cidrs, ","), CACHE_TTL * 2) end + if #paths > 0 then dict:set("paths", table.concat(paths, ","), CACHE_TTL * 2) end + + ngx.log(ngx.INFO, "Refreshed whitelist: ", #ips, " IPs, ", #cidrs, " CIDRs, ", #paths, " paths") +end + +function _M.init_whitelist() + ngx.timer.at(0, function() _M.refresh_whitelist() end) +end + function _M.is_suspicious_path(uri) if not uri then return false end local uri_lower = string.lower(uri) @@ -184,12 +399,13 @@ end function _M.capture_event() local status = ngx.status local uri = ngx.var.uri - local ip = ngx.var.remote_addr + local ip = get_real_client_ip() local method = ngx.var.request_method local user_agent = ngx.var.http_user_agent or "" local host = ngx.var.host or "" - -- Only capture security-relevant events + if _M.is_whitelisted(ip, uri) then return end + local should_capture = false -- Scanner detection diff --git a/templates/infra/nginx/lua/traffic.lua b/templates/infra/nginx/lua/traffic.lua index a2e62bf..615434b 100644 --- a/templates/infra/nginx/lua/traffic.lua +++ b/templates/infra/nginx/lua/traffic.lua @@ -3,6 +3,7 @@ local cjson = require "cjson.safe" local http = require "resty.http" +local security = require "security" local _M = {} @@ -11,11 +12,14 @@ local AGENT_IP = "{{.AgentIP}}" local AGENT_PORT = {{.AgentPort}} function _M.log_request() + local uri = ngx.var.uri or "" + local ip = security.get_client_ip() + + if security.is_whitelisted(ip, uri) then return end + local status = ngx.status local host = ngx.var.host or "" - local uri = ngx.var.uri or "" local method = ngx.var.request_method or "" - local ip = ngx.var.remote_addr or "" local request_time = ngx.var.request_time or "0" local bytes_sent = ngx.var.bytes_sent or "0" local request_length = ngx.var.request_length or "0" diff --git a/templates/infra/nginx/nginx.lua.conf b/templates/infra/nginx/nginx.lua.conf index 830cc66..9e23892 100644 --- a/templates/infra/nginx/nginx.lua.conf +++ b/templates/infra/nginx/nginx.lua.conf @@ -38,10 +38,11 @@ http { # Lua package path lua_package_path "/etc/nginx/lua/?.lua;;"; - # Shared dictionary for security events + # Shared dictionaries for security lua_shared_dict security_events 10m; lua_shared_dict ip_rate_limit 10m; lua_shared_dict blocked_ips 5m; + lua_shared_dict whitelist 5m; # Load Lua modules init_by_lua_block { @@ -49,14 +50,15 @@ http { traffic = require "traffic" } - # Initialize blocked IPs cache on worker start + # Initialize caches on worker start init_worker_by_lua_block { security.init_blocked_ips() + security.init_whitelist() } - # Check blocked IPs on every request + # Check blocked IPs on every request (uses real client IP from headers) access_by_lua_block { - if security.is_blocked(ngx.var.remote_addr) then + if security.is_blocked(security.get_client_ip()) then ngx.exit(ngx.HTTP_FORBIDDEN) end } @@ -103,4 +105,16 @@ http { } } } +{{if .RejectUnknownDomains}} + # Reject requests to unknown domains + server { + listen 80 default_server; + listen 443 ssl http2 default_server; + server_name _; + + ssl_reject_handshake on; + + return 444; + } +{{end}} } diff --git a/templates/templates.go b/templates/templates.go index f838bbf..d88b694 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -79,6 +79,10 @@ func GetCategories() []Category { return Categories } +type NginxConfigData struct { + RejectUnknownDomains bool +} + func GetNginxConfig(luaEnabled bool) ([]byte, error) { if luaEnabled { return FS.ReadFile("infra/nginx/nginx.lua.conf") @@ -86,6 +90,32 @@ func GetNginxConfig(luaEnabled bool) ([]byte, error) { return FS.ReadFile("infra/nginx/nginx.conf") } +func GetNginxConfigWithData(luaEnabled bool, data NginxConfigData) ([]byte, error) { + var content []byte + var err error + + if luaEnabled { + content, err = FS.ReadFile("infra/nginx/nginx.lua.conf") + } else { + content, err = FS.ReadFile("infra/nginx/nginx.conf") + } + if err != nil { + return nil, err + } + + tmpl, err := template.New("nginx.conf").Parse(string(content)) + if err != nil { + return nil, err + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + func GetNginxSecurityLua() ([]byte, error) { return FS.ReadFile("infra/nginx/lua/security.lua") }