Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions internal/api/security_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,12 @@ func (s *Server) updateDeploymentSecurity(c *gin.Context) {
}

vhostUpdated := false
var hookStatus interface{}

if s.proxyOrchestrator != nil && s.proxyOrchestrator.NginxManager().VirtualHostExists(name) {
if err := s.proxyOrchestrator.NginxManager().UpdateVirtualHost(deployment); err != nil {
nginxMgr := s.proxyOrchestrator.NginxManager()

if err := nginxMgr.UpdateVirtualHost(deployment); err != nil {
c.JSON(http.StatusOK, gin.H{
"security": securityConfig,
"vhost_updated": false,
Expand All @@ -405,30 +409,49 @@ func (s *Server) updateDeploymentSecurity(c *gin.Context) {
return
}

if err := s.proxyOrchestrator.NginxManager().TestConfig(); err != nil {
if err := nginxMgr.ValidateSecurityHooks(name, securityConfig.Enabled); err != nil {
c.JSON(http.StatusOK, gin.H{
"security": securityConfig,
"vhost_updated": true,
"validation_error": err.Error(),
"warning": "Vhost updated but security hook validation failed",
})
return
}

hookStatus, _ = nginxMgr.GetSecurityHookStatus(name)

if err := nginxMgr.TestConfig(); err != nil {
c.JSON(http.StatusOK, gin.H{
"security": securityConfig,
"vhost_updated": false,
"vhost_updated": true,
"hook_status": hookStatus,
"warning": "Security config saved but nginx config test failed: " + err.Error(),
})
return
}

if err := s.proxyOrchestrator.NginxManager().Reload(); err != nil {
if err := nginxMgr.Reload(); err != nil {
c.JSON(http.StatusOK, gin.H{
"security": securityConfig,
"vhost_updated": true,
"hook_status": hookStatus,
"warning": "Nginx reload failed (may need manual reload): " + err.Error(),
})
return
}
vhostUpdated = true
}

c.JSON(http.StatusOK, gin.H{
response := gin.H{
"security": securityConfig,
"vhost_updated": vhostUpdated,
})
}
if hookStatus != nil {
response["hook_status"] = hookStatus
}

c.JSON(http.StatusOK, response)
}

// getDeploymentSecurityEvents returns security events for a deployment
Expand Down Expand Up @@ -675,3 +698,31 @@ func (s *Server) updateSecuritySettings(c *gin.Context) {

c.JSON(http.StatusOK, result)
}

// refreshSecurityScripts regenerates Lua scripts with correct agent IP and reloads nginx
func (s *Server) refreshSecurityScripts(c *gin.Context) {
if !s.config.Security.Enabled {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Security module not enabled",
})
return
}

if !s.infraManager.IsNginxRunning() {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": "Nginx container is not running",
})
return
}

result, err := s.infraManager.RefreshSecurityScripts()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": err.Error(),
"result": result,
})
return
}

c.JSON(http.StatusOK, result)
}
45 changes: 44 additions & 1 deletion internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/flatrun/agent/internal/proxy"
"github.com/flatrun/agent/internal/security"
"github.com/flatrun/agent/internal/system"
"github.com/flatrun/agent/internal/traffic"
"github.com/flatrun/agent/pkg/config"
"github.com/flatrun/agent/pkg/models"
"github.com/flatrun/agent/pkg/plugins"
Expand Down Expand Up @@ -53,6 +54,7 @@ type Server struct {
infraManager *infra.Manager
credentialsManager *credentials.Manager
securityManager *security.Manager
trafficManager *traffic.Manager
}

func New(cfg *config.Config, configPath string) *Server {
Expand Down Expand Up @@ -99,6 +101,12 @@ func New(cfg *config.Config, configPath string) *Server {
}
}

var trafficManager *traffic.Manager
trafficManager, err := traffic.NewManager(cfg.DeploymentsPath, 7)
if err != nil {
log.Printf("Warning: Failed to initialize traffic manager: %v", err)
}

s := &Server{
config: cfg,
configPath: configPath,
Expand All @@ -115,6 +123,7 @@ func New(cfg *config.Config, configPath string) *Server {
infraManager: infraManager,
credentialsManager: credentialsManager,
securityManager: securityManager,
trafficManager: trafficManager,
}

s.setupRoutes()
Expand Down Expand Up @@ -224,6 +233,7 @@ func (s *Server) setupRoutes() {
protected.POST("/databases/tables/schema", s.describeTable)
protected.POST("/databases/query", s.executeDatabaseQuery)
protected.POST("/databases/users", s.listDatabaseUsers)
protected.POST("/databases/users/by-database", s.listUsersByDatabase)
protected.POST("/databases/create", s.createDatabaseInServer)
protected.POST("/databases/delete", s.deleteDatabaseInServer)
protected.POST("/databases/users/create", s.createDatabaseUser)
Expand Down Expand Up @@ -268,13 +278,21 @@ func (s *Server) setupRoutes() {
protected.GET("/security/realtime-capture", s.getRealtimeCaptureStatus)
protected.PUT("/security/realtime-capture", s.setRealtimeCaptureStatus)
protected.GET("/security/health", s.getSecurityHealth)
protected.POST("/security/refresh", s.refreshSecurityScripts)
protected.GET("/deployments/:name/security", s.getDeploymentSecurity)
protected.PUT("/deployments/:name/security", s.updateDeploymentSecurity)
protected.GET("/deployments/:name/security/events", s.getDeploymentSecurityEvents)

// Traffic endpoints
protected.GET("/traffic/logs", s.getTrafficLogs)
protected.GET("/traffic/stats", s.getTrafficStats)
protected.POST("/traffic/cleanup", s.cleanupTrafficLogs)
protected.GET("/deployments/:name/traffic", s.getDeploymentTrafficStats)
}

// Security event ingest endpoint (no auth - called by nginx Lua)
// Ingest endpoints (no auth - called by nginx Lua)
api.POST("/security/events/ingest", s.ingestSecurityEvent)
api.POST("/traffic/ingest", s.ingestTrafficLog)
}
}

Expand Down Expand Up @@ -3464,6 +3482,31 @@ func (s *Server) listDatabaseUsers(c *gin.Context) {
})
}

func (s *Server) listUsersByDatabase(c *gin.Context) {
var req struct {
database.ConnectionConfig
Database string `json:"database" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": err.Error(),
})
return
}

users, err := s.databaseManager.ListDatabaseUsers(&req.ConnectionConfig, req.Database)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": err.Error(),
})
return
}

c.JSON(http.StatusOK, gin.H{
"users": users,
})
}

func (s *Server) createDatabaseInServer(c *gin.Context) {
var req struct {
database.ConnectionConfig
Expand Down
80 changes: 80 additions & 0 deletions internal/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1044,3 +1044,83 @@ func TestProxySyncResultFailed(t *testing.T) {
t.Error("expected Created to be false for failed")
}
}

func TestListUsersByDatabaseRequestStructure(t *testing.T) {
type listUsersByDatabaseRequest struct {
Type string `json:"type"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
Database string `json:"database"`
}

tests := []struct {
name string
req listUsersByDatabaseRequest
wantMissing string
}{
{
name: "valid request",
req: listUsersByDatabaseRequest{
Type: "mysql",
Host: "localhost",
Port: 3306,
Username: "root",
Password: "secret",
Database: "testdb",
},
wantMissing: "",
},
{
name: "missing database",
req: listUsersByDatabaseRequest{
Type: "mysql",
Host: "localhost",
Port: 3306,
Username: "root",
Password: "secret",
},
wantMissing: "database",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.wantMissing == "database" && tt.req.Database != "" {
t.Error("test setup error: database should be empty for missing database test")
}
if tt.wantMissing == "" && tt.req.Database == "" {
t.Error("test setup error: database should not be empty for valid request test")
}
})
}
}

func TestListUsersByDatabaseResponseStructure(t *testing.T) {
type userInfo struct {
Name string `json:"name"`
Host string `json:"host,omitempty"`
}

type response struct {
Users []userInfo `json:"users"`
}

resp := response{
Users: []userInfo{
{Name: "app_user", Host: "%"},
{Name: "readonly", Host: "localhost"},
},
}

if len(resp.Users) != 2 {
t.Errorf("expected 2 users, got %d", len(resp.Users))
}
if resp.Users[0].Name != "app_user" {
t.Errorf("expected first user 'app_user', got '%s'", resp.Users[0].Name)
}
if resp.Users[1].Host != "localhost" {
t.Errorf("expected second user host 'localhost', got '%s'", resp.Users[1].Host)
}
}
Loading
Loading