diff --git a/.gitignore b/.gitignore index f2f2d90..220b6c4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /bin/ /certs/ +/.agents/ .env diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..162f7a7 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,164 @@ +# AGENTS.md + +Guidance for coding agents working in `github.com/EternisAI/silo-proxy`. + +## Scope +- Prefer minimal, safe changes that preserve behavior. +- Follow existing repo patterns over generic best practices. +- Keep API, service/domain, and infrastructure concerns separated. +- Avoid unrelated refactors in focused tasks. + +## Stack Snapshot +- Go `1.25.0`. +- HTTP: Gin. +- RPC: gRPC/protobuf (`proto/proxy.proto`). +- DB: PostgreSQL via `pgx/v5`, migrations via goose, queries via sqlc. +- Auth: JWT + API key middleware. +- Logging: `log/slog`. +- Tests: Go `testing` + `testify` + testcontainers system tests. + +## Repository Layout +- `cmd/silo-proxy-server`: server entrypoint, config, logger. +- `cmd/silo-proxy-agent`: agent entrypoint, config, logger. +- `internal/api/http`: router, handlers, middleware, per-agent HTTP management. +- `internal/grpc/server`: stream server and connection manager. +- `internal/grpc/client`: agent stream client and local forwarder. +- `internal/auth`, `internal/users`: business/domain services. +- `internal/db`: connection init, migrations, sqlc output. +- `proto`: protobuf schema and generated stubs. +- `systemtest`: integration tests against ephemeral Postgres. + +## Build / Run / Test Commands + +Prefer `make` targets when available. + +### Primary Make Targets +- `make build` - build both binaries. +- `make build-server` - build server binary. +- `make build-agent` - build agent binary. +- `make run` - run server locally. +- `make run-agent` - run agent locally. +- `make test` - run all tests (`go test -v ./...`). +- `make clean` - clean test cache and `bin/*`. +- `make generate` - run `sqlc generate`. 
+- `make protoc-gen` - regenerate protobuf/gRPC code. +- `make generate-certs` - generate local TLS certs. + +### Running Focused Tests (Important) +- Single package: + - `go test -v ./internal/grpc/server` +- Single test function: + - `go test -v ./internal/grpc/server -run '^TestConnectionManager_Register_WithServerManager$'` +- Single subtest: + - `go test -v ./systemtest -run '^TestSystemIntegration$/^Login$'` +- Re-run without cache: + - `go test -v -count=1 ./internal/api/http` +- Benchmarks: + - `go test -bench . ./internal/grpc/server` +- Compile-only check: + - `go test -run '^$' ./...` + +### Lint / Format / Validation +- No dedicated `make lint` target is currently defined. +- Format changed files with `gofmt -w path/to/file.go`. +- Run `go test ./...` as baseline verification. +- Optional static pass: `go vet ./...`. +- Prefer formatting only touched files to avoid noisy diffs. + +## Test Selection Guidance +- Start with the narrowest package/test that covers your change. +- Use exact `-run` regex anchors (`^...$`) to avoid accidental matches. +- Use `-count=1` if caching may hide flakiness. +- `systemtest` requires Docker/testcontainers available locally. +- Run full `make test` before finalizing broad changes. + +## Code Style Guidelines + +### Imports +- Keep imports `gofmt`-organized. +- Standard library first, then external/internal packages. +- Use aliases only when clarity improves (e.g. `grpcserver`). +- Avoid dot imports and unused imports. + +### Formatting and Structure +- Enforce `gofmt`. +- Prefer early returns over nested conditionals. +- Keep functions focused; avoid mixing concerns. +- Add comments only for non-obvious invariants/behavior. + +### Types and API Shapes +- Use explicit structs for service/DTO outputs (`RegisterResult`, `UserInfo`). +- Keep struct fields cohesive to a single responsibility. +- Prefer typed constants for repeated timing/protocol values. +- Use typed config structs with `mapstructure` tags. 
+ +### Naming Conventions +- Exported identifiers: `PascalCase`. +- Unexported identifiers: `camelCase`. +- Constructors follow `NewXxx(...) *Xxx`. +- Handler methods should be verb-driven (`Login`, `Register`, `DeleteUser`). +- Sentinel errors use `ErrXxx` naming. + +### Error Handling +- Return errors instead of panicking (except startup-fatal bootstrap failures). +- Wrap lower-level errors with context using `%w`. +- Branch on known causes using `errors.Is` / `errors.As`. +- Keep domain-level error contracts stable for handler mapping. +- Do not expose sensitive internals in HTTP error payloads. + +### Logging +- Use structured `slog` logs (`Info/Warn/Error/Debug`). +- Include stable keys like `agent_id`, `port`, `message_id`, `error`. +- Log lifecycle and recovery events (start/stop/retry/failure). +- Avoid noisy logs in hot loops unless debug-level is warranted. + +### Concurrency and Context +- Guard shared maps/state with `sync.RWMutex` where applicable. +- Use buffered channels for producer/consumer coordination. +- Use `context.WithTimeout` for shutdown/network operations. +- Preserve existing cancellation and cleanup semantics. + +### HTTP Layer Rules +- Keep handlers thin: bind/validate input, call service, map response. +- Return JSON errors in consistent shape: `{"error":"..."}`. +- Put cross-cutting concerns in middleware (auth, API key, logging). +- Treat read-only/status endpoints as side-effect free. + +### Service and Domain Rules +- Keep business logic in services, not in handlers/middleware. +- Keep DB-specific behavior at DB/service boundaries. +- Keep contracts explicit for auth/user and other domain services. + +### DB, Migrations, SQLC +- SQL files belong in `internal/db/queries`. +- sqlc-generated files live in `internal/db/sqlc` and must not be hand-edited. +- After SQL query/schema changes, run `make generate` and migrations. +- Keep migrations idempotent and ordered. 
+ +### Proto / gRPC Changes +- Edit `proto/proxy.proto`, then run `make protoc-gen`. +- Update both server and agent paths for schema/message changes. +- Preserve compatibility expectations when possible. + +## Configuration Conventions +- Config is loaded from `application.yml` with env overrides. +- Nested key env mapping uses underscore replacement. +- Server and agent keep separate config roots under `cmd/...`. +- Keep examples/defaults aligned with checked-in app configs. + +## CI Notes +- CI runs `make test` and `make build` on pull requests. +- Docker publish jobs execute only on `main` and version tags. +- If packaging/runtime behavior changes, validate make/docker targets. + +## Cursor / Copilot Rules Check +- No `.cursorrules` file found. +- No `.cursor/rules/` directory found. +- No `.github/copilot-instructions.md` file found. +- If added later, treat those files as higher-priority local instructions. + +## Practical Agent Workflow +1. Make the smallest viable change. +2. Run targeted tests first, then broaden scope. +3. Keep diffs tight and architecture boundaries intact. +4. Add/adjust logs and errors for operability. 
diff --git a/cmd/silo-proxy-agent/config.go b/cmd/silo-proxy-agent/config.go index d3713b5..8b578b0 100644 --- a/cmd/silo-proxy-agent/config.go +++ b/cmd/silo-proxy-agent/config.go @@ -18,9 +18,10 @@ type Config struct { } type GrpcConfig struct { - ServerAddress string `mapstructure:"server_address"` - AgentID string `mapstructure:"agent_id"` - TLS TLSConfig `mapstructure:"tls"` + ServerAddress string `mapstructure:"server_address"` + AgentID string `mapstructure:"agent_id"` + ProvisioningKey string `mapstructure:"provisioning_key"` + TLS TLSConfig `mapstructure:"tls"` } type TLSConfig struct { diff --git a/cmd/silo-proxy-agent/main.go b/cmd/silo-proxy-agent/main.go index d9c2fcd..ced7fc9 100644 --- a/cmd/silo-proxy-agent/main.go +++ b/cmd/silo-proxy-agent/main.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "os/signal" + "path/filepath" "sync" "syscall" "time" @@ -33,12 +34,51 @@ func main() { ServerNameOverride: config.Grpc.TLS.ServerNameOverride, } - grpcClient := grpcclient.NewClient(config.Grpc.ServerAddress, config.Grpc.AgentID, config.Local.ServiceURL, tlsConfig) + // Determine config path for persistence + configPath := "" + if config.Grpc.ProvisioningKey != "" { + // Try to find config file in common locations + possiblePaths := []string{ + "./application.yaml", + "./application.yml", + "./cmd/silo-proxy-agent/application.yaml", + "./cmd/silo-proxy-agent/application.yml", + } + for _, path := range possiblePaths { + if _, err := os.Stat(path); err == nil { + absPath, _ := filepath.Abs(path) + configPath = absPath + slog.Info("Config file found for persistence", "path", configPath) + break + } + } + if configPath == "" { + slog.Warn("Config file not found, agent_id will not be persisted") + } + } + + grpcClient := grpcclient.NewClient( + config.Grpc.ServerAddress, + config.Grpc.AgentID, + config.Grpc.ProvisioningKey, + config.Local.ServiceURL, + configPath, + tlsConfig, + ) if err := grpcClient.Start(); err != nil { slog.Error("Failed to start gRPC client", 
"error", err) os.Exit(1) } + if config.Grpc.ProvisioningKey != "" { + slog.Info("Agent started in provisioning mode") + } else if config.Grpc.AgentID != "" { + slog.Info("Agent started with agent_id", "agent_id", config.Grpc.AgentID) + } else { + slog.Error("Either agent_id or provisioning_key is required") + os.Exit(1) + } + gin.SetMode(gin.ReleaseMode) engine := gin.New() engine.Use(cors.New(cors.Config{ diff --git a/cmd/silo-proxy-server/main.go b/cmd/silo-proxy-server/main.go index 45b2caa..584b3e9 100644 --- a/cmd/silo-proxy-server/main.go +++ b/cmd/silo-proxy-server/main.go @@ -12,11 +12,13 @@ import ( "time" internalhttp "github.com/EternisAI/silo-proxy/internal/api/http" + "github.com/EternisAI/silo-proxy/internal/agents" "github.com/EternisAI/silo-proxy/internal/auth" "github.com/EternisAI/silo-proxy/internal/cert" "github.com/EternisAI/silo-proxy/internal/db" "github.com/EternisAI/silo-proxy/internal/db/sqlc" grpcserver "github.com/EternisAI/silo-proxy/internal/grpc/server" + "github.com/EternisAI/silo-proxy/internal/provisioning" "github.com/EternisAI/silo-proxy/internal/users" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" @@ -46,17 +48,9 @@ func main() { authService := auth.NewService(queries, config.JWT) userService := users.NewService(queries) - tlsConfig := &grpcserver.TLSConfig{ - Enabled: config.Grpc.TLS.Enabled, - CertFile: config.Grpc.TLS.CertFile, - KeyFile: config.Grpc.TLS.KeyFile, - CAFile: config.Grpc.TLS.CAFile, - ClientAuth: config.Grpc.TLS.ClientAuth, - } - + // Initialize provisioning and agent services var certService *cert.Service if config.Grpc.TLS.Enabled { - var err error certService, err = cert.New( config.Grpc.TLS.CAFile, @@ -73,7 +67,19 @@ func main() { } } + provisioningService := provisioning.NewService(queries, certService) + agentService := agents.NewService(queries) + + tlsConfig := &grpcserver.TLSConfig{ + Enabled: config.Grpc.TLS.Enabled, + CertFile: config.Grpc.TLS.CertFile, + KeyFile: config.Grpc.TLS.KeyFile, + 
CAFile: config.Grpc.TLS.CAFile, + ClientAuth: config.Grpc.TLS.ClientAuth, + } + grpcSrv := grpcserver.NewServer(config.Grpc.Port, tlsConfig) + grpcSrv.SetServices(provisioningService, agentService) portManager, err := internalhttp.NewPortManager( config.Http.AgentPortRange.Start, @@ -93,10 +99,12 @@ func main() { "pool_size", config.Http.AgentPortRange.End-config.Http.AgentPortRange.Start+1) services := &internalhttp.Services{ - GrpcServer: grpcSrv, - CertService: certService, - AuthService: authService, - UserService: userService, + GrpcServer: grpcSrv, + CertService: certService, + AuthService: authService, + UserService: userService, + ProvisioningService: provisioningService, + AgentService: agentService, } gin.SetMode(gin.ReleaseMode) diff --git a/docs/provisioning/Overview.md b/docs/provisioning/Overview.md new file mode 100644 index 0000000..00bad68 --- /dev/null +++ b/docs/provisioning/Overview.md @@ -0,0 +1,441 @@ +# Key-Based Agent Provisioning System + +## Overview + +This document describes the key-based agent provisioning system for Silo Proxy. The system enables secure, user-scoped agent registration through temporary, single-use provisioning keys. + +## Problem Statement + +**Current Issues:** +- Agents use static `agent_id` with no authentication/authorization +- Any agent can claim any ID (security risk) +- No user ownership tracking (prevents multi-tenant dashboards) + +## Solution + +The provisioning system works as follows: + +1. Dashboard users generate temporary, single-use provisioning keys +2. Users copy keys to client devices +3. Agents connect with key, receive permanent agent_id (UUID) + optional TLS cert +4. Future connections validated against database +5. 
All agents associated with user accounts + +## Architecture + +### Database Schema + +#### provisioning_keys Table +- Stores SHA-256 hashes of provisioning keys (never plaintext) +- Single-use by default (`max_uses=1`) +- Time-based expiration (24-48 hours recommended) +- Status lifecycle: active → exhausted/expired/revoked + +#### agents Table +- `id` (UUID) IS the agent identifier (no separate agent_id column) +- Tracks user ownership and provisioning key used +- Stores connection metadata and status + +#### agent_connection_logs Table +- Audit trail for all agent connections +- Tracks connection duration, IP address, disconnect reason + +### Provisioning Flow + +#### Case 1: New Agent (First Connection) + +``` +1. Agent → Server: ProxyMessage { + Type: PING + Metadata: { + "provisioning_key": "pk_abc123..." + } + } + +2. Server: + - Hash key (SHA-256) + - Lookup in provisioning_keys table + - Validate: status='active', used_count < max_uses, expires_at > NOW() + - Insert into agents table (id auto-generated) + - Increment provisioning_keys.used_count + - Register in ConnectionManager using agents.id + +3. Server → Agent: ProxyMessage { + Type: PONG + Metadata: { + "provisioning_status": "success", + "agent_id": "550e8400-..." // agents.id + } + } +``` + +#### Case 2: Established Agent + +``` +1. Agent → Server: ProxyMessage { + Type: PING + Metadata: { + "agent_id": "550e8400-..." + } + } + +2. Server: + - Lookup agents.id in database + - Validate: status = 'active' + - Update last_seen_at + - Register in ConnectionManager + +3. Server → Agent: PONG (normal flow) +``` + +#### Case 3: Legacy Agent (Auto-Migration) + +``` +1. Agent sends agent_id (old static ID) +2. Server checks if exists in agents table +3. If NOT found: Auto-register with default user +4. 
Connection proceeds normally +``` + +## Implementation Status + +### Phase 1: Database Schema ✅ COMPLETED + +**Deliverables:** +- ✅ 3 migration files created and tested +- ✅ SQLC queries defined for all tables +- ✅ Type-safe Go code generated via SQLC + +**Files Created:** +- `internal/db/migrations/0002_create_provisioning_keys.sql` +- `internal/db/migrations/0003_create_agents.sql` +- `internal/db/migrations/0004_create_agent_connection_logs.sql` +- `internal/db/queries/provisioning_keys.sql` +- `internal/db/queries/agents.sql` +- `internal/db/queries/agent_connection_logs.sql` +- `internal/db/sqlc/*.go` (generated) + +### Phase 2: Core Provisioning Logic ✅ COMPLETED + +**Deliverables:** +- ✅ Service layer (provisioning, agents) +- ✅ gRPC stream handler integration +- ✅ ConnectionManager database persistence +- ⏸️ E2E test: agent provisions via gRPC stream (will be done in Phase 3) + +**Files Created/Modified:** +- `internal/provisioning/service.go` (NEW) - Key generation, validation, agent provisioning +- `internal/provisioning/models.go` (NEW) - Domain models +- `internal/agents/service.go` (NEW) - Agent management, connection logging +- `internal/agents/models.go` (NEW) - Domain models +- `internal/grpc/server/stream_handler.go` (MODIFIED) - Provisioning handshake +- `internal/grpc/server/connection_manager.go` (MODIFIED) - DB persistence +- `internal/grpc/server/server.go` (MODIFIED) - Service initialization +- `internal/grpc/server/connection_manager_test.go` (MODIFIED) - Test updates +- `cmd/silo-proxy-server/main.go` (MODIFIED) - Service wiring + +### Phase 3: API & Client Integration ✅ COMPLETED + +**Deliverables:** +- ✅ HTTP API endpoints (POST/GET/DELETE keys, GET/DELETE agents) +- ✅ Agent config + client changes +- ✅ Config persistence after provisioning +- ⏸️ E2E test: full provisioning flow via dashboard (manual testing guide provided) + +**Files Created/Modified:** +- `internal/api/http/dto/provisioning.go` (NEW) - API DTOs +- 
`internal/api/http/handler/provisioning.go` (NEW) - Key management endpoints +- `internal/api/http/handler/agents.go` (NEW) - Agent management endpoints +- `internal/api/http/router.go` (MODIFIED) - Route registration +- `cmd/silo-proxy-server/main.go` (MODIFIED) - Service wiring +- `cmd/silo-proxy-agent/config.go` (MODIFIED) - Add provisioning_key field +- `cmd/silo-proxy-agent/main.go` (MODIFIED) - Config persistence logic +- `internal/grpc/client/client.go` (MODIFIED) - Provisioning handshake +- `docs/provisioning/Overview.md` (UPDATED) - Phase 3 completion + +## Security + +### Key Generation +- 32 bytes generated using `crypto/rand` +- Format: `pk_` prefix + base64url encoding +- Keys never logged in plaintext + +### Storage +- Only SHA-256 hash stored in database +- Original key shown once to user during creation +- Keys are single-use by default + +### Validation +- Key hash lookup in database +- Expiration time check +- Usage count enforcement +- User ownership validation for all operations + +### Rate Limiting +- 5 requests/second on provisioning endpoint +- Prevents brute-force attacks + +### Audit Logging +- All provisioning attempts logged +- Connection history tracked +- Legacy agent connections flagged + +## Configuration + +### Server Configuration + +**Database Connection:** +```yaml +database: + url: "postgres://user:password@localhost:5432/silo-proxy" + schema: "public" # Optional, defaults to "public" +``` + +### Agent Configuration + +**Provisioning (First Connection):** +```yaml +grpc: + server_address: "server.example.com:9090" + provisioning_key: "pk_abc123..." # Provided by dashboard + tls: + enabled: true + ca_file: "ca.pem" +``` + +**Established Agent (Subsequent Connections):** +```yaml +grpc: + server_address: "server.example.com:9090" + agent_id: "550e8400-e29b-41d4-a716-446655440000" # Auto-saved after provisioning + tls: + enabled: true + ca_file: "ca.pem" +``` + +## Testing + +### Phase 1 Verification +```bash +# 1. 
Run migrations (embedded, automatic on server start) +make build + +# 2. Verify SQLC generation +make generate + +# 3. Verify tables exist (requires running database) +psql -d silo-proxy -c "\dt" + +# Expected output: +# provisioning_keys +# agents +# agent_connection_logs +# users +``` + +### Phase 3 Manual Testing Guide + +#### Prerequisites +- Running PostgreSQL database +- Server built and configured +- Agent built + +#### Step 1: Start the Server +```bash +# Make sure database URL is configured +export DB_URL="postgres://user:password@localhost:5432/silo-proxy?sslmode=disable" + +# Start server +./bin/silo-proxy-server +``` + +#### Step 2: Create a User Account +```bash +# Register a new user +curl -X POST http://localhost:8080/auth/register \ + -H "Content-Type: application/json" \ + -d '{ + "username": "testuser", + "password": "testpass123", + "role": "User" + }' + +# Login and get JWT token +TOKEN=$(curl -X POST http://localhost:8080/auth/login \ + -H "Content-Type: application/json" \ + -d '{ + "username": "testuser", + "password": "testpass123" + }' | jq -r .token) + +echo "JWT Token: $TOKEN" +``` + +#### Step 3: Generate a Provisioning Key +```bash +# Create a single-use provisioning key (expires in 24 hours) +RESPONSE=$(curl -X POST http://localhost:8080/provisioning-keys \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "max_uses": 1, + "expires_in_hours": 24, + "notes": "Test key for agent-1" + }') + +echo "Provisioning Response: $RESPONSE" + +# Extract the provisioning key +KEY=$(echo $RESPONSE | jq -r .key) +echo "Provisioning Key: $KEY" +``` + +#### Step 4: Configure and Start Agent +```bash +# Create agent config with provisioning key +cat > cmd/silo-proxy-agent/application.yaml < 0 { + _ = json.Unmarshal(dbAgent.Metadata, &metadata) + } + + result := &Agent{ + ID: uuidToString(dbAgent.ID.Bytes), + UserID: uuidToString(dbAgent.UserID.Bytes), + Status: string(dbAgent.Status), + CertFingerprint: 
dbAgent.CertFingerprint.String, + RegisteredAt: dbAgent.RegisteredAt.Time, + LastSeenAt: dbAgent.LastSeenAt.Time, + Metadata: metadata, + Notes: dbAgent.Notes.String, + } + + if dbAgent.ProvisionedWithKeyID.Valid { + result.ProvisionedWithKeyID = uuidToString(dbAgent.ProvisionedWithKeyID.Bytes) + } + + if dbAgent.LastIpAddress != nil { + result.LastIPAddress = dbAgent.LastIpAddress.String() + } + + return result, nil +} + +// ListAgentsByUser retrieves all agents for a user +func (s *Service) ListAgentsByUser(ctx context.Context, userID string) ([]Agent, error) { + parsedUserID, err := uuid.Parse(userID) + if err != nil { + return nil, fmt.Errorf("invalid user ID: %w", err) + } + + dbAgents, err := s.queries.ListAgentsByUser(ctx, pgtype.UUID{Bytes: parsedUserID, Valid: true}) + if err != nil { + return nil, fmt.Errorf("failed to list agents: %w", err) + } + + result := make([]Agent, len(dbAgents)) + for i, a := range dbAgents { + var metadata map[string]interface{} + if len(a.Metadata) > 0 { + _ = json.Unmarshal(a.Metadata, &metadata) + } + + result[i] = Agent{ + ID: uuidToString(a.ID.Bytes), + UserID: uuidToString(a.UserID.Bytes), + Status: string(a.Status), + CertFingerprint: a.CertFingerprint.String, + RegisteredAt: a.RegisteredAt.Time, + LastSeenAt: a.LastSeenAt.Time, + Metadata: metadata, + Notes: a.Notes.String, + } + + if a.ProvisionedWithKeyID.Valid { + result[i].ProvisionedWithKeyID = uuidToString(a.ProvisionedWithKeyID.Bytes) + } + + if a.LastIpAddress != nil { + result[i].LastIPAddress = a.LastIpAddress.String() + } + } + + return result, nil +} + +// UpdateLastSeen updates the agent's last seen timestamp and IP address +func (s *Service) UpdateLastSeen(ctx context.Context, agentID string, timestamp time.Time, ipAddress string) error { + parsedID, err := uuid.Parse(agentID) + if err != nil { + return ErrInvalidAgentID + } + + var ipAddr *netip.Addr + if ipAddress != "" { + parsed, err := netip.ParseAddr(ipAddress) + if err == nil { + ipAddr = &parsed + } 
+ } + + if err := s.queries.UpdateAgentLastSeen(ctx, sqlc.UpdateAgentLastSeenParams{ + ID: pgtype.UUID{Bytes: parsedID, Valid: true}, + LastSeenAt: pgtype.Timestamp{Time: timestamp, Valid: true}, + LastIpAddress: ipAddr, + }); err != nil { + return fmt.Errorf("failed to update last seen: %w", err) + } + + return nil +} + +// UpdateStatus updates the agent's status +func (s *Service) UpdateStatus(ctx context.Context, agentID string, status string) error { + parsedID, err := uuid.Parse(agentID) + if err != nil { + return ErrInvalidAgentID + } + + var agentStatus sqlc.AgentStatus + switch status { + case "active": + agentStatus = sqlc.AgentStatusActive + case "inactive": + agentStatus = sqlc.AgentStatusInactive + case "suspended": + agentStatus = sqlc.AgentStatusSuspended + default: + return fmt.Errorf("invalid status: %s", status) + } + + if err := s.queries.UpdateAgentStatus(ctx, sqlc.UpdateAgentStatusParams{ + ID: pgtype.UUID{Bytes: parsedID, Valid: true}, + Status: agentStatus, + }); err != nil { + return fmt.Errorf("failed to update status: %w", err) + } + + slog.Info("Agent status updated", "agent_id", agentID, "status", status) + return nil +} + +// CreateConnectionLog creates a new connection log entry +func (s *Service) CreateConnectionLog(ctx context.Context, agentID string, connectedAt time.Time, ipAddress string) (string, error) { + parsedID, err := uuid.Parse(agentID) + if err != nil { + return "", ErrInvalidAgentID + } + + var ipAddr *netip.Addr + if ipAddress != "" { + parsed, err := netip.ParseAddr(ipAddress) + if err == nil { + ipAddr = &parsed + } + } + + dbLog, err := s.queries.CreateConnectionLog(ctx, sqlc.CreateConnectionLogParams{ + AgentID: pgtype.UUID{Bytes: parsedID, Valid: true}, + ConnectedAt: pgtype.Timestamp{Time: connectedAt, Valid: true}, + IpAddress: ipAddr, + }) + if err != nil { + return "", fmt.Errorf("failed to create connection log: %w", err) + } + + return uuidToString(dbLog.ID.Bytes), nil +} + +// UpdateConnectionLog updates a 
connection log with disconnect information +func (s *Service) UpdateConnectionLog(ctx context.Context, logID string, disconnectedAt time.Time, reason string) error { + parsedID, err := uuid.Parse(logID) + if err != nil { + return fmt.Errorf("invalid log ID: %w", err) + } + + if err := s.queries.UpdateConnectionLog(ctx, sqlc.UpdateConnectionLogParams{ + ID: pgtype.UUID{Bytes: parsedID, Valid: true}, + DisconnectedAt: pgtype.Timestamp{Time: disconnectedAt, Valid: true}, + DisconnectReason: pgtype.Text{String: reason, Valid: reason != ""}, + }); err != nil { + return fmt.Errorf("failed to update connection log: %w", err) + } + + return nil +} + +// GetAgentConnectionHistory retrieves connection history for an agent +func (s *Service) GetAgentConnectionHistory(ctx context.Context, agentID string, limit, offset int) ([]ConnectionLog, error) { + parsedID, err := uuid.Parse(agentID) + if err != nil { + return nil, ErrInvalidAgentID + } + + dbLogs, err := s.queries.GetAgentConnectionHistory(ctx, sqlc.GetAgentConnectionHistoryParams{ + AgentID: pgtype.UUID{Bytes: parsedID, Valid: true}, + Limit: int32(limit), + Offset: int32(offset), + }) + if err != nil { + return nil, fmt.Errorf("failed to get connection history: %w", err) + } + + result := make([]ConnectionLog, len(dbLogs)) + for i, l := range dbLogs { + result[i] = ConnectionLog{ + ID: uuidToString(l.ID.Bytes), + AgentID: uuidToString(l.AgentID.Bytes), + ConnectedAt: l.ConnectedAt.Time, + DurationSeconds: int(l.DurationSeconds.Int32), + DisconnectReason: l.DisconnectReason.String, + } + + if l.DisconnectedAt.Valid { + result[i].DisconnectedAt = &l.DisconnectedAt.Time + } + + if l.IpAddress != nil { + result[i].IPAddress = l.IpAddress.String() + } + } + + return result, nil +} + +func uuidToString(id [16]byte) string { + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + id[0:4], id[4:6], id[6:8], id[8:10], id[10:16]) +} diff --git a/internal/api/http/dto/provisioning.go b/internal/api/http/dto/provisioning.go new file 
mode 100644 index 0000000..24a2bf3 --- /dev/null +++ b/internal/api/http/dto/provisioning.go @@ -0,0 +1,44 @@ +package dto + +import "time" + +type CreateProvisioningKeyRequest struct { + MaxUses int `json:"max_uses" binding:"required,min=1"` + ExpiresInHours int `json:"expires_in_hours" binding:"required,min=1"` + Notes string `json:"notes"` +} + +type ProvisioningKeyResponse struct { + ID string `json:"id"` + Key string `json:"key,omitempty"` // Only returned on creation + Status string `json:"status"` + MaxUses int `json:"max_uses"` + UsedCount int `json:"used_count"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + RevokedAt *time.Time `json:"revoked_at,omitempty"` + Notes string `json:"notes,omitempty"` +} + +type ListProvisioningKeysResponse struct { + Keys []ProvisioningKeyResponse `json:"keys"` +} + +type AgentResponse struct { + ID string `json:"id"` + Status string `json:"status"` + RegisteredAt time.Time `json:"registered_at"` + LastSeenAt time.Time `json:"last_seen_at"` + LastIPAddress string `json:"last_ip_address,omitempty"` + Connected bool `json:"connected"` + Port int `json:"port,omitempty"` + CertFingerprint string `json:"cert_fingerprint,omitempty"` + ProvisionedWithKeyID string `json:"provisioned_with_key_id,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Notes string `json:"notes,omitempty"` +} + +type ListAgentsResponse struct { + Agents []AgentResponse `json:"agents"` +} diff --git a/internal/api/http/handler/agents.go b/internal/api/http/handler/agents.go new file mode 100644 index 0000000..4af9fb0 --- /dev/null +++ b/internal/api/http/handler/agents.go @@ -0,0 +1,175 @@ +package handler + +import ( + "log/slog" + "net/http" + + "github.com/EternisAI/silo-proxy/internal/agents" + "github.com/EternisAI/silo-proxy/internal/api/http/dto" + grpcserver "github.com/EternisAI/silo-proxy/internal/grpc/server" + "github.com/gin-gonic/gin" +) + 
+type AgentsHandler struct { + agentService *agents.Service + grpcServer *grpcserver.Server +} + +func NewAgentsHandler(agentService *agents.Service, grpcServer *grpcserver.Server) *AgentsHandler { + return &AgentsHandler{ + agentService: agentService, + grpcServer: grpcServer, + } +} + +// ListAgents returns all agents for the authenticated user +// GET /agents +func (h *AgentsHandler) ListAgents(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "user_id not found in context"}) + return + } + + agentList, err := h.agentService.ListAgentsByUser(c.Request.Context(), userID) + if err != nil { + slog.Error("Failed to list agents", "error", err, "user_id", userID) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list agents"}) + return + } + + // Get connection manager to check which agents are currently connected + connManager := h.grpcServer.GetConnectionManager() + + responses := make([]dto.AgentResponse, len(agentList)) + for i, a := range agentList { + response := dto.AgentResponse{ + ID: a.ID, + Status: a.Status, + RegisteredAt: a.RegisteredAt, + LastSeenAt: a.LastSeenAt, + LastIPAddress: a.LastIPAddress, + CertFingerprint: a.CertFingerprint, + ProvisionedWithKeyID: a.ProvisionedWithKeyID, + Metadata: a.Metadata, + Notes: a.Notes, + Connected: false, + } + + // Check if agent is currently connected + if conn, ok := connManager.GetConnection(a.ID); ok { + response.Connected = true + response.Port = conn.Port + } + + responses[i] = response + } + + c.JSON(http.StatusOK, dto.ListAgentsResponse{Agents: responses}) +} + +// GetAgent returns details for a specific agent +// GET /agents/:id +func (h *AgentsHandler) GetAgent(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "user_id not found in context"}) + return + } + + agentID := c.Param("id") + if agentID == "" { + c.JSON(http.StatusBadRequest, 
gin.H{"error": "agent_id is required"}) + return + } + + agent, err := h.agentService.GetAgentByID(c.Request.Context(), agentID) + if err != nil { + if err == agents.ErrAgentNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"}) + return + } + slog.Error("Failed to get agent", "error", err, "agent_id", agentID) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get agent"}) + return + } + + // Verify ownership + if agent.UserID != userID { + c.JSON(http.StatusForbidden, gin.H{"error": "access denied"}) + return + } + + // Get connection manager to check if agent is currently connected + connManager := h.grpcServer.GetConnectionManager() + response := dto.AgentResponse{ + ID: agent.ID, + Status: agent.Status, + RegisteredAt: agent.RegisteredAt, + LastSeenAt: agent.LastSeenAt, + LastIPAddress: agent.LastIPAddress, + CertFingerprint: agent.CertFingerprint, + ProvisionedWithKeyID: agent.ProvisionedWithKeyID, + Metadata: agent.Metadata, + Notes: agent.Notes, + Connected: false, + } + + if conn, ok := connManager.GetConnection(agent.ID); ok { + response.Connected = true + response.Port = conn.Port + } + + c.JSON(http.StatusOK, response) +} + +// DeregisterAgent soft-deletes an agent (sets status to inactive) and disconnects if online +// DELETE /agents/:id +func (h *AgentsHandler) DeregisterAgent(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "user_id not found in context"}) + return + } + + agentID := c.Param("id") + if agentID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "agent_id is required"}) + return + } + + // Get agent to verify ownership + agent, err := h.agentService.GetAgentByID(c.Request.Context(), agentID) + if err != nil { + if err == agents.ErrAgentNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"}) + return + } + slog.Error("Failed to get agent", "error", err, "agent_id", agentID) + 
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get agent"}) + return + } + + // Verify ownership + if agent.UserID != userID { + c.JSON(http.StatusForbidden, gin.H{"error": "access denied"}) + return + } + + // Update status to inactive + if err := h.agentService.UpdateStatus(c.Request.Context(), agentID, "inactive"); err != nil { + slog.Error("Failed to update agent status", "error", err, "agent_id", agentID) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to deregister agent"}) + return + } + + // Disconnect if currently online + connManager := h.grpcServer.GetConnectionManager() + if _, ok := connManager.GetConnection(agentID); ok { + connManager.Deregister(agentID) + slog.Info("Agent forcefully disconnected", "agent_id", agentID, "user_id", userID) + } + + slog.Info("Agent deregistered", "agent_id", agentID, "user_id", userID) + c.JSON(http.StatusOK, gin.H{"message": "agent deregistered"}) +} diff --git a/internal/api/http/handler/provisioning.go b/internal/api/http/handler/provisioning.go new file mode 100644 index 0000000..ac34d03 --- /dev/null +++ b/internal/api/http/handler/provisioning.go @@ -0,0 +1,117 @@ +package handler + +import ( + "log/slog" + "net/http" + + "github.com/EternisAI/silo-proxy/internal/api/http/dto" + "github.com/EternisAI/silo-proxy/internal/provisioning" + "github.com/gin-gonic/gin" +) + +type ProvisioningHandler struct { + provisioningService *provisioning.Service +} + +func NewProvisioningHandler(provisioningService *provisioning.Service) *ProvisioningHandler { + return &ProvisioningHandler{ + provisioningService: provisioningService, + } +} + +// CreateKey generates a new provisioning key +// POST /provisioning-keys +func (h *ProvisioningHandler) CreateKey(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "user_id not found in context"}) + return + } + + var req dto.CreateProvisioningKeyRequest + if err := 
c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + key, plaintextKey, err := h.provisioningService.CreateKey(c.Request.Context(), userID, req.MaxUses, req.ExpiresInHours, req.Notes) + if err != nil { + slog.Error("Failed to create provisioning key", "error", err, "user_id", userID) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create provisioning key"}) + return + } + + response := dto.ProvisioningKeyResponse{ + ID: key.ID, + Key: plaintextKey, // Only shown once! + Status: key.Status, + MaxUses: key.MaxUses, + UsedCount: key.UsedCount, + ExpiresAt: key.ExpiresAt, + CreatedAt: key.CreatedAt, + UpdatedAt: key.UpdatedAt, + Notes: key.Notes, + } + + slog.Info("Provisioning key created", "key_id", key.ID, "user_id", userID, "max_uses", key.MaxUses) + c.JSON(http.StatusCreated, response) +} + +// ListKeys returns all provisioning keys for the authenticated user +// GET /provisioning-keys +func (h *ProvisioningHandler) ListKeys(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "user_id not found in context"}) + return + } + + keys, err := h.provisioningService.ListUserKeys(c.Request.Context(), userID) + if err != nil { + slog.Error("Failed to list provisioning keys", "error", err, "user_id", userID) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list provisioning keys"}) + return + } + + responses := make([]dto.ProvisioningKeyResponse, len(keys)) + for i, k := range keys { + responses[i] = dto.ProvisioningKeyResponse{ + ID: k.ID, + Status: k.Status, + MaxUses: k.MaxUses, + UsedCount: k.UsedCount, + ExpiresAt: k.ExpiresAt, + CreatedAt: k.CreatedAt, + UpdatedAt: k.UpdatedAt, + RevokedAt: k.RevokedAt, + Notes: k.Notes, + } + } + + c.JSON(http.StatusOK, dto.ListProvisioningKeysResponse{Keys: responses}) +} + +// RevokeKey revokes a provisioning key +// DELETE /provisioning-keys/:id +func (h 
*ProvisioningHandler) RevokeKey(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "user_id not found in context"}) + return + } + + keyID := c.Param("id") + if keyID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "key_id is required"}) + return + } + + if err := h.provisioningService.RevokeKey(c.Request.Context(), keyID, userID); err != nil { + slog.Error("Failed to revoke provisioning key", "error", err, "key_id", keyID, "user_id", userID) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to revoke provisioning key"}) + return + } + + slog.Info("Provisioning key revoked", "key_id", keyID, "user_id", userID) + c.JSON(http.StatusOK, gin.H{"message": "provisioning key revoked"}) +} diff --git a/internal/api/http/router.go b/internal/api/http/router.go index 9814538..46fd350 100644 --- a/internal/api/http/router.go +++ b/internal/api/http/router.go @@ -1,20 +1,24 @@ package http import ( + "github.com/EternisAI/silo-proxy/internal/agents" "github.com/EternisAI/silo-proxy/internal/api/http/handler" "github.com/EternisAI/silo-proxy/internal/api/http/middleware" "github.com/EternisAI/silo-proxy/internal/auth" "github.com/EternisAI/silo-proxy/internal/cert" grpcserver "github.com/EternisAI/silo-proxy/internal/grpc/server" + "github.com/EternisAI/silo-proxy/internal/provisioning" "github.com/EternisAI/silo-proxy/internal/users" "github.com/gin-gonic/gin" ) type Services struct { - GrpcServer *grpcserver.Server - CertService *cert.Service - AuthService *auth.Service - UserService *users.Service + GrpcServer *grpcserver.Server + CertService *cert.Service + AuthService *auth.Service + UserService *users.Service + ProvisioningService *provisioning.Service + AgentService *agents.Service } func SetupRoute(engine *gin.Engine, srvs *Services, adminAPIKey string, jwtSecret string) { @@ -38,13 +42,38 @@ func SetupRoute(engine *gin.Engine, srvs *Services, adminAPIKey string, jwtSecre 
usersGroup.GET("", middleware.RequireRole("Admin"), userHandler.ListUsers) } + // Provisioning key management (authenticated) + if srvs.ProvisioningService != nil { + provisioningHandler := handler.NewProvisioningHandler(srvs.ProvisioningService) + provisioningRoutes := engine.Group("/provisioning-keys") + provisioningRoutes.Use(middleware.JWTAuth(jwtSecret)) + { + provisioningRoutes.POST("", provisioningHandler.CreateKey) + provisioningRoutes.GET("", provisioningHandler.ListKeys) + provisioningRoutes.DELETE("/:id", provisioningHandler.RevokeKey) + } + } + + // Agent management agents := engine.Group("/agents") { - if srvs.GrpcServer != nil { + // Authenticated agent management (user-scoped) + if srvs.AgentService != nil && srvs.GrpcServer != nil { + agentsHandler := handler.NewAgentsHandler(srvs.AgentService, srvs.GrpcServer) + agentRoutes := agents.Group("") + agentRoutes.Use(middleware.JWTAuth(jwtSecret)) + { + agentRoutes.GET("", agentsHandler.ListAgents) + agentRoutes.GET("/:id", agentsHandler.GetAgent) + agentRoutes.DELETE("/:id", agentsHandler.DeregisterAgent) + } + } else if srvs.GrpcServer != nil { + // Legacy admin handler (backward compatibility) adminHandler := handler.NewAdminHandler(srvs.GrpcServer) agents.GET("", adminHandler.ListAgents) } + // Certificate management (API key auth) certHandler := handler.NewCertHandler(srvs.CertService) certRoutes := agents.Group("") certRoutes.Use(middleware.APIKeyAuth(adminAPIKey)) diff --git a/internal/db/migrations/0002_create_provisioning_keys.sql b/internal/db/migrations/0002_create_provisioning_keys.sql new file mode 100644 index 0000000..4149b79 --- /dev/null +++ b/internal/db/migrations/0002_create_provisioning_keys.sql @@ -0,0 +1,31 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TYPE provisioning_key_status AS ENUM ('active', 'exhausted', 'expired', 'revoked'); + +CREATE TABLE IF NOT EXISTS provisioning_keys ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + key_hash VARCHAR(255) NOT NULL UNIQUE, + 
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + status provisioning_key_status NOT NULL DEFAULT 'active', + max_uses INT NOT NULL DEFAULT 1, + used_count INT NOT NULL DEFAULT 0, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP NOT NULL DEFAULT NOW(), + revoked_at TIMESTAMP, + notes TEXT +); + +CREATE INDEX IF NOT EXISTS idx_provisioning_keys_user_id ON provisioning_keys(user_id); +CREATE INDEX IF NOT EXISTS idx_provisioning_keys_status ON provisioning_keys(status); +CREATE INDEX IF NOT EXISTS idx_provisioning_keys_key_hash ON provisioning_keys(key_hash); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP INDEX IF EXISTS idx_provisioning_keys_key_hash; +DROP INDEX IF EXISTS idx_provisioning_keys_status; +DROP INDEX IF EXISTS idx_provisioning_keys_user_id; +DROP TABLE IF EXISTS provisioning_keys; +DROP TYPE IF EXISTS provisioning_key_status; +-- +goose StatementEnd diff --git a/internal/db/migrations/0003_create_agents.sql b/internal/db/migrations/0003_create_agents.sql new file mode 100644 index 0000000..49a48a0 --- /dev/null +++ b/internal/db/migrations/0003_create_agents.sql @@ -0,0 +1,30 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TYPE agent_status AS ENUM ('active', 'inactive', 'suspended'); + +CREATE TABLE IF NOT EXISTS agents ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provisioned_with_key_id UUID REFERENCES provisioning_keys(id) ON DELETE SET NULL, + status agent_status NOT NULL DEFAULT 'active', + cert_fingerprint VARCHAR(255), + registered_at TIMESTAMP NOT NULL DEFAULT NOW(), + last_seen_at TIMESTAMP NOT NULL DEFAULT NOW(), + last_ip_address INET, + metadata JSONB, + notes TEXT +); + +CREATE INDEX IF NOT EXISTS idx_agents_user_id ON agents(user_id); +CREATE INDEX IF NOT EXISTS idx_agents_status ON agents(status); +CREATE INDEX IF NOT EXISTS idx_agents_last_seen_at ON 
agents(last_seen_at); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP INDEX IF EXISTS idx_agents_last_seen_at; +DROP INDEX IF EXISTS idx_agents_status; +DROP INDEX IF EXISTS idx_agents_user_id; +DROP TABLE IF EXISTS agents; +DROP TYPE IF EXISTS agent_status; +-- +goose StatementEnd diff --git a/internal/db/migrations/0004_create_agent_connection_logs.sql b/internal/db/migrations/0004_create_agent_connection_logs.sql new file mode 100644 index 0000000..8ed397a --- /dev/null +++ b/internal/db/migrations/0004_create_agent_connection_logs.sql @@ -0,0 +1,22 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE IF NOT EXISTS agent_connection_logs ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE, + connected_at TIMESTAMP NOT NULL DEFAULT NOW(), + disconnected_at TIMESTAMP, + duration_seconds INT, + ip_address INET, + disconnect_reason VARCHAR(255) +); + +CREATE INDEX IF NOT EXISTS idx_connection_logs_agent_id ON agent_connection_logs(agent_id); +CREATE INDEX IF NOT EXISTS idx_connection_logs_connected_at ON agent_connection_logs(connected_at); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP INDEX IF EXISTS idx_connection_logs_connected_at; +DROP INDEX IF EXISTS idx_connection_logs_agent_id; +DROP TABLE IF EXISTS agent_connection_logs; +-- +goose StatementEnd diff --git a/internal/db/queries/agent_connection_logs.sql b/internal/db/queries/agent_connection_logs.sql new file mode 100644 index 0000000..0a3ff1d --- /dev/null +++ b/internal/db/queries/agent_connection_logs.sql @@ -0,0 +1,16 @@ +-- name: CreateConnectionLog :one +INSERT INTO agent_connection_logs (agent_id, connected_at, ip_address) +VALUES ($1, $2, $3) RETURNING *; + +-- name: UpdateConnectionLog :exec +UPDATE agent_connection_logs +SET disconnected_at = $2, + duration_seconds = EXTRACT(EPOCH FROM ($2 - connected_at))::INT, + disconnect_reason = $3 +WHERE id = $1; + +-- name: 
GetAgentConnectionHistory :many +SELECT * FROM agent_connection_logs +WHERE agent_id = $1 +ORDER BY connected_at DESC +LIMIT $2 OFFSET $3; diff --git a/internal/db/queries/agents.sql b/internal/db/queries/agents.sql new file mode 100644 index 0000000..fc894d1 --- /dev/null +++ b/internal/db/queries/agents.sql @@ -0,0 +1,18 @@ +-- name: CreateAgent :one +INSERT INTO agents (user_id, provisioned_with_key_id, metadata, notes) +VALUES ($1, $2, $3, $4) RETURNING *; + +-- name: GetAgentByID :one +SELECT * FROM agents WHERE id = $1 LIMIT 1; + +-- name: ListAgentsByUser :many +SELECT * FROM agents WHERE user_id = $1 ORDER BY registered_at DESC; + +-- name: UpdateAgentLastSeen :exec +UPDATE agents SET last_seen_at = $2, last_ip_address = $3 WHERE id = $1; + +-- name: UpdateAgentStatus :exec +UPDATE agents SET status = $2 WHERE id = $1; + +-- name: UpdateAgentCertFingerprint :exec +UPDATE agents SET cert_fingerprint = $2 WHERE id = $1; diff --git a/internal/db/queries/provisioning_keys.sql b/internal/db/queries/provisioning_keys.sql new file mode 100644 index 0000000..60772ee --- /dev/null +++ b/internal/db/queries/provisioning_keys.sql @@ -0,0 +1,30 @@ +-- name: CreateProvisioningKey :one +INSERT INTO provisioning_keys (key_hash, user_id, max_uses, expires_at, notes) +VALUES ($1, $2, $3, $4, $5) RETURNING *; + +-- name: GetProvisioningKeyByHash :one +SELECT * FROM provisioning_keys +WHERE key_hash = $1 AND status = 'active' LIMIT 1; + +-- name: ListProvisioningKeysByUser :many +SELECT * FROM provisioning_keys WHERE user_id = $1 ORDER BY created_at DESC; + +-- name: IncrementKeyUsage :one +UPDATE provisioning_keys +SET used_count = used_count + 1, + status = CASE WHEN used_count + 1 >= max_uses + THEN 'exhausted'::provisioning_key_status + ELSE status END, + updated_at = NOW() +WHERE id = $1 AND used_count < max_uses AND status = 'active' +RETURNING *; + +-- name: RevokeProvisioningKey :exec +UPDATE provisioning_keys +SET status = 'revoked'::provisioning_key_status, 
revoked_at = NOW(), updated_at = NOW() +WHERE id = $1 AND user_id = $2; + +-- name: ExpireOldKeys :exec +UPDATE provisioning_keys +SET status = 'expired'::provisioning_key_status, updated_at = NOW() +WHERE status = 'active' AND expires_at < NOW(); diff --git a/internal/db/sqlc/agent_connection_logs.sql.go b/internal/db/sqlc/agent_connection_logs.sql.go new file mode 100644 index 0000000..1bc9549 --- /dev/null +++ b/internal/db/sqlc/agent_connection_logs.sql.go @@ -0,0 +1,99 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: agent_connection_logs.sql + +package sqlc + +import ( + "context" + "net/netip" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createConnectionLog = `-- name: CreateConnectionLog :one +INSERT INTO agent_connection_logs (agent_id, connected_at, ip_address) +VALUES ($1, $2, $3) RETURNING id, agent_id, connected_at, disconnected_at, duration_seconds, ip_address, disconnect_reason +` + +type CreateConnectionLogParams struct { + AgentID pgtype.UUID `json:"agent_id"` + ConnectedAt pgtype.Timestamp `json:"connected_at"` + IpAddress *netip.Addr `json:"ip_address"` +} + +func (q *Queries) CreateConnectionLog(ctx context.Context, arg CreateConnectionLogParams) (AgentConnectionLog, error) { + row := q.db.QueryRow(ctx, createConnectionLog, arg.AgentID, arg.ConnectedAt, arg.IpAddress) + var i AgentConnectionLog + err := row.Scan( + &i.ID, + &i.AgentID, + &i.ConnectedAt, + &i.DisconnectedAt, + &i.DurationSeconds, + &i.IpAddress, + &i.DisconnectReason, + ) + return i, err +} + +const getAgentConnectionHistory = `-- name: GetAgentConnectionHistory :many +SELECT id, agent_id, connected_at, disconnected_at, duration_seconds, ip_address, disconnect_reason FROM agent_connection_logs +WHERE agent_id = $1 +ORDER BY connected_at DESC +LIMIT $2 OFFSET $3 +` + +type GetAgentConnectionHistoryParams struct { + AgentID pgtype.UUID `json:"agent_id"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` +} + +func (q *Queries) 
GetAgentConnectionHistory(ctx context.Context, arg GetAgentConnectionHistoryParams) ([]AgentConnectionLog, error) { + rows, err := q.db.Query(ctx, getAgentConnectionHistory, arg.AgentID, arg.Limit, arg.Offset) + if err != nil { + return nil, err + } + defer rows.Close() + items := []AgentConnectionLog{} + for rows.Next() { + var i AgentConnectionLog + if err := rows.Scan( + &i.ID, + &i.AgentID, + &i.ConnectedAt, + &i.DisconnectedAt, + &i.DurationSeconds, + &i.IpAddress, + &i.DisconnectReason, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateConnectionLog = `-- name: UpdateConnectionLog :exec +UPDATE agent_connection_logs +SET disconnected_at = $2, + duration_seconds = EXTRACT(EPOCH FROM ($2 - connected_at))::INT, + disconnect_reason = $3 +WHERE id = $1 +` + +type UpdateConnectionLogParams struct { + ID pgtype.UUID `json:"id"` + DisconnectedAt pgtype.Timestamp `json:"disconnected_at"` + DisconnectReason pgtype.Text `json:"disconnect_reason"` +} + +func (q *Queries) UpdateConnectionLog(ctx context.Context, arg UpdateConnectionLogParams) error { + _, err := q.db.Exec(ctx, updateConnectionLog, arg.ID, arg.DisconnectedAt, arg.DisconnectReason) + return err +} diff --git a/internal/db/sqlc/agents.sql.go b/internal/db/sqlc/agents.sql.go new file mode 100644 index 0000000..0998286 --- /dev/null +++ b/internal/db/sqlc/agents.sql.go @@ -0,0 +1,148 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.29.0 +// source: agents.sql + +package sqlc + +import ( + "context" + "net/netip" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createAgent = `-- name: CreateAgent :one +INSERT INTO agents (user_id, provisioned_with_key_id, metadata, notes) +VALUES ($1, $2, $3, $4) RETURNING id, user_id, provisioned_with_key_id, status, cert_fingerprint, registered_at, last_seen_at, last_ip_address, metadata, notes +` + +type CreateAgentParams struct { + UserID pgtype.UUID `json:"user_id"` + ProvisionedWithKeyID pgtype.UUID `json:"provisioned_with_key_id"` + Metadata []byte `json:"metadata"` + Notes pgtype.Text `json:"notes"` +} + +func (q *Queries) CreateAgent(ctx context.Context, arg CreateAgentParams) (Agent, error) { + row := q.db.QueryRow(ctx, createAgent, + arg.UserID, + arg.ProvisionedWithKeyID, + arg.Metadata, + arg.Notes, + ) + var i Agent + err := row.Scan( + &i.ID, + &i.UserID, + &i.ProvisionedWithKeyID, + &i.Status, + &i.CertFingerprint, + &i.RegisteredAt, + &i.LastSeenAt, + &i.LastIpAddress, + &i.Metadata, + &i.Notes, + ) + return i, err +} + +const getAgentByID = `-- name: GetAgentByID :one +SELECT id, user_id, provisioned_with_key_id, status, cert_fingerprint, registered_at, last_seen_at, last_ip_address, metadata, notes FROM agents WHERE id = $1 LIMIT 1 +` + +func (q *Queries) GetAgentByID(ctx context.Context, id pgtype.UUID) (Agent, error) { + row := q.db.QueryRow(ctx, getAgentByID, id) + var i Agent + err := row.Scan( + &i.ID, + &i.UserID, + &i.ProvisionedWithKeyID, + &i.Status, + &i.CertFingerprint, + &i.RegisteredAt, + &i.LastSeenAt, + &i.LastIpAddress, + &i.Metadata, + &i.Notes, + ) + return i, err +} + +const listAgentsByUser = `-- name: ListAgentsByUser :many +SELECT id, user_id, provisioned_with_key_id, status, cert_fingerprint, registered_at, last_seen_at, last_ip_address, metadata, notes FROM agents WHERE user_id = $1 ORDER BY registered_at DESC +` + +func (q *Queries) ListAgentsByUser(ctx context.Context, userID pgtype.UUID) 
([]Agent, error) { + rows, err := q.db.Query(ctx, listAgentsByUser, userID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Agent{} + for rows.Next() { + var i Agent + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.ProvisionedWithKeyID, + &i.Status, + &i.CertFingerprint, + &i.RegisteredAt, + &i.LastSeenAt, + &i.LastIpAddress, + &i.Metadata, + &i.Notes, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateAgentCertFingerprint = `-- name: UpdateAgentCertFingerprint :exec +UPDATE agents SET cert_fingerprint = $2 WHERE id = $1 +` + +type UpdateAgentCertFingerprintParams struct { + ID pgtype.UUID `json:"id"` + CertFingerprint pgtype.Text `json:"cert_fingerprint"` +} + +func (q *Queries) UpdateAgentCertFingerprint(ctx context.Context, arg UpdateAgentCertFingerprintParams) error { + _, err := q.db.Exec(ctx, updateAgentCertFingerprint, arg.ID, arg.CertFingerprint) + return err +} + +const updateAgentLastSeen = `-- name: UpdateAgentLastSeen :exec +UPDATE agents SET last_seen_at = $2, last_ip_address = $3 WHERE id = $1 +` + +type UpdateAgentLastSeenParams struct { + ID pgtype.UUID `json:"id"` + LastSeenAt pgtype.Timestamp `json:"last_seen_at"` + LastIpAddress *netip.Addr `json:"last_ip_address"` +} + +func (q *Queries) UpdateAgentLastSeen(ctx context.Context, arg UpdateAgentLastSeenParams) error { + _, err := q.db.Exec(ctx, updateAgentLastSeen, arg.ID, arg.LastSeenAt, arg.LastIpAddress) + return err +} + +const updateAgentStatus = `-- name: UpdateAgentStatus :exec +UPDATE agents SET status = $2 WHERE id = $1 +` + +type UpdateAgentStatusParams struct { + ID pgtype.UUID `json:"id"` + Status AgentStatus `json:"status"` +} + +func (q *Queries) UpdateAgentStatus(ctx context.Context, arg UpdateAgentStatusParams) error { + _, err := q.db.Exec(ctx, updateAgentStatus, arg.ID, arg.Status) + return err +} diff --git a/internal/db/sqlc/models.go 
b/internal/db/sqlc/models.go index 90880db..2879e4a 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -7,10 +7,98 @@ package sqlc import ( "database/sql/driver" "fmt" + "net/netip" "github.com/jackc/pgx/v5/pgtype" ) +type AgentStatus string + +const ( + AgentStatusActive AgentStatus = "active" + AgentStatusInactive AgentStatus = "inactive" + AgentStatusSuspended AgentStatus = "suspended" +) + +func (e *AgentStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AgentStatus(s) + case string: + *e = AgentStatus(s) + default: + return fmt.Errorf("unsupported scan type for AgentStatus: %T", src) + } + return nil +} + +type NullAgentStatus struct { + AgentStatus AgentStatus `json:"agent_status"` + Valid bool `json:"valid"` // Valid is true if AgentStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAgentStatus) Scan(value interface{}) error { + if value == nil { + ns.AgentStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AgentStatus.Scan(value) +} + +// Value implements the driver Valuer interface. 
+func (ns NullAgentStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AgentStatus), nil +} + +type ProvisioningKeyStatus string + +const ( + ProvisioningKeyStatusActive ProvisioningKeyStatus = "active" + ProvisioningKeyStatusExhausted ProvisioningKeyStatus = "exhausted" + ProvisioningKeyStatusExpired ProvisioningKeyStatus = "expired" + ProvisioningKeyStatusRevoked ProvisioningKeyStatus = "revoked" +) + +func (e *ProvisioningKeyStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ProvisioningKeyStatus(s) + case string: + *e = ProvisioningKeyStatus(s) + default: + return fmt.Errorf("unsupported scan type for ProvisioningKeyStatus: %T", src) + } + return nil +} + +type NullProvisioningKeyStatus struct { + ProvisioningKeyStatus ProvisioningKeyStatus `json:"provisioning_key_status"` + Valid bool `json:"valid"` // Valid is true if ProvisioningKeyStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullProvisioningKeyStatus) Scan(value interface{}) error { + if value == nil { + ns.ProvisioningKeyStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.ProvisioningKeyStatus.Scan(value) +} + +// Value implements the driver Valuer interface. 
+func (ns NullProvisioningKeyStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.ProvisioningKeyStatus), nil +} + type UserRole string const ( @@ -53,6 +141,43 @@ func (ns NullUserRole) Value() (driver.Value, error) { return string(ns.UserRole), nil } +type Agent struct { + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` + ProvisionedWithKeyID pgtype.UUID `json:"provisioned_with_key_id"` + Status AgentStatus `json:"status"` + CertFingerprint pgtype.Text `json:"cert_fingerprint"` + RegisteredAt pgtype.Timestamp `json:"registered_at"` + LastSeenAt pgtype.Timestamp `json:"last_seen_at"` + LastIpAddress *netip.Addr `json:"last_ip_address"` + Metadata []byte `json:"metadata"` + Notes pgtype.Text `json:"notes"` +} + +type AgentConnectionLog struct { + ID pgtype.UUID `json:"id"` + AgentID pgtype.UUID `json:"agent_id"` + ConnectedAt pgtype.Timestamp `json:"connected_at"` + DisconnectedAt pgtype.Timestamp `json:"disconnected_at"` + DurationSeconds pgtype.Int4 `json:"duration_seconds"` + IpAddress *netip.Addr `json:"ip_address"` + DisconnectReason pgtype.Text `json:"disconnect_reason"` +} + +type ProvisioningKey struct { + ID pgtype.UUID `json:"id"` + KeyHash string `json:"key_hash"` + UserID pgtype.UUID `json:"user_id"` + Status ProvisioningKeyStatus `json:"status"` + MaxUses int32 `json:"max_uses"` + UsedCount int32 `json:"used_count"` + ExpiresAt pgtype.Timestamp `json:"expires_at"` + CreatedAt pgtype.Timestamp `json:"created_at"` + UpdatedAt pgtype.Timestamp `json:"updated_at"` + RevokedAt pgtype.Timestamp `json:"revoked_at"` + Notes pgtype.Text `json:"notes"` +} + type User struct { ID pgtype.UUID `json:"id"` Username string `json:"username"` diff --git a/internal/db/sqlc/provisioning_keys.sql.go b/internal/db/sqlc/provisioning_keys.sql.go new file mode 100644 index 0000000..4222f9c --- /dev/null +++ b/internal/db/sqlc/provisioning_keys.sql.go @@ -0,0 +1,167 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.29.0 +// source: provisioning_keys.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createProvisioningKey = `-- name: CreateProvisioningKey :one +INSERT INTO provisioning_keys (key_hash, user_id, max_uses, expires_at, notes) +VALUES ($1, $2, $3, $4, $5) RETURNING id, key_hash, user_id, status, max_uses, used_count, expires_at, created_at, updated_at, revoked_at, notes +` + +type CreateProvisioningKeyParams struct { + KeyHash string `json:"key_hash"` + UserID pgtype.UUID `json:"user_id"` + MaxUses int32 `json:"max_uses"` + ExpiresAt pgtype.Timestamp `json:"expires_at"` + Notes pgtype.Text `json:"notes"` +} + +func (q *Queries) CreateProvisioningKey(ctx context.Context, arg CreateProvisioningKeyParams) (ProvisioningKey, error) { + row := q.db.QueryRow(ctx, createProvisioningKey, + arg.KeyHash, + arg.UserID, + arg.MaxUses, + arg.ExpiresAt, + arg.Notes, + ) + var i ProvisioningKey + err := row.Scan( + &i.ID, + &i.KeyHash, + &i.UserID, + &i.Status, + &i.MaxUses, + &i.UsedCount, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.RevokedAt, + &i.Notes, + ) + return i, err +} + +const expireOldKeys = `-- name: ExpireOldKeys :exec +UPDATE provisioning_keys +SET status = 'expired'::provisioning_key_status, updated_at = NOW() +WHERE status = 'active' AND expires_at < NOW() +` + +func (q *Queries) ExpireOldKeys(ctx context.Context) error { + _, err := q.db.Exec(ctx, expireOldKeys) + return err +} + +const getProvisioningKeyByHash = `-- name: GetProvisioningKeyByHash :one +SELECT id, key_hash, user_id, status, max_uses, used_count, expires_at, created_at, updated_at, revoked_at, notes FROM provisioning_keys +WHERE key_hash = $1 AND status = 'active' LIMIT 1 +` + +func (q *Queries) GetProvisioningKeyByHash(ctx context.Context, keyHash string) (ProvisioningKey, error) { + row := q.db.QueryRow(ctx, getProvisioningKeyByHash, keyHash) + var i ProvisioningKey + err := row.Scan( + &i.ID, + &i.KeyHash, + &i.UserID, 
+ &i.Status, + &i.MaxUses, + &i.UsedCount, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.RevokedAt, + &i.Notes, + ) + return i, err +} + +const incrementKeyUsage = `-- name: IncrementKeyUsage :one +UPDATE provisioning_keys +SET used_count = used_count + 1, + status = CASE WHEN used_count + 1 >= max_uses + THEN 'exhausted'::provisioning_key_status + ELSE status END, + updated_at = NOW() +WHERE id = $1 AND used_count < max_uses AND status = 'active' +RETURNING id, key_hash, user_id, status, max_uses, used_count, expires_at, created_at, updated_at, revoked_at, notes +` + +func (q *Queries) IncrementKeyUsage(ctx context.Context, id pgtype.UUID) (ProvisioningKey, error) { + row := q.db.QueryRow(ctx, incrementKeyUsage, id) + var i ProvisioningKey + err := row.Scan( + &i.ID, + &i.KeyHash, + &i.UserID, + &i.Status, + &i.MaxUses, + &i.UsedCount, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.RevokedAt, + &i.Notes, + ) + return i, err +} + +const listProvisioningKeysByUser = `-- name: ListProvisioningKeysByUser :many +SELECT id, key_hash, user_id, status, max_uses, used_count, expires_at, created_at, updated_at, revoked_at, notes FROM provisioning_keys WHERE user_id = $1 ORDER BY created_at DESC +` + +func (q *Queries) ListProvisioningKeysByUser(ctx context.Context, userID pgtype.UUID) ([]ProvisioningKey, error) { + rows, err := q.db.Query(ctx, listProvisioningKeysByUser, userID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []ProvisioningKey{} + for rows.Next() { + var i ProvisioningKey + if err := rows.Scan( + &i.ID, + &i.KeyHash, + &i.UserID, + &i.Status, + &i.MaxUses, + &i.UsedCount, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.RevokedAt, + &i.Notes, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const revokeProvisioningKey = `-- name: RevokeProvisioningKey :exec +UPDATE provisioning_keys +SET status = 
'revoked'::provisioning_key_status, revoked_at = NOW(), updated_at = NOW() +WHERE id = $1 AND user_id = $2 +` + +type RevokeProvisioningKeyParams struct { + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` +} + +func (q *Queries) RevokeProvisioningKey(ctx context.Context, arg RevokeProvisioningKeyParams) error { + _, err := q.db.Exec(ctx, revokeProvisioningKey, arg.ID, arg.UserID) + return err +} diff --git a/internal/db/sqlc/querier.go b/internal/db/sqlc/querier.go index 7595c40..2ee812a 100644 --- a/internal/db/sqlc/querier.go +++ b/internal/db/sqlc/querier.go @@ -12,11 +12,26 @@ import ( type Querier interface { CountUsers(ctx context.Context) (int64, error) + CreateAgent(ctx context.Context, arg CreateAgentParams) (Agent, error) + CreateConnectionLog(ctx context.Context, arg CreateConnectionLogParams) (AgentConnectionLog, error) + CreateProvisioningKey(ctx context.Context, arg CreateProvisioningKeyParams) (ProvisioningKey, error) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) DeleteUser(ctx context.Context, id pgtype.UUID) error + ExpireOldKeys(ctx context.Context) error + GetAgentByID(ctx context.Context, id pgtype.UUID) (Agent, error) + GetAgentConnectionHistory(ctx context.Context, arg GetAgentConnectionHistoryParams) ([]AgentConnectionLog, error) + GetProvisioningKeyByHash(ctx context.Context, keyHash string) (ProvisioningKey, error) GetUser(ctx context.Context, id pgtype.UUID) (User, error) GetUserByUsername(ctx context.Context, username string) (User, error) + IncrementKeyUsage(ctx context.Context, id pgtype.UUID) (ProvisioningKey, error) + ListAgentsByUser(ctx context.Context, userID pgtype.UUID) ([]Agent, error) + ListProvisioningKeysByUser(ctx context.Context, userID pgtype.UUID) ([]ProvisioningKey, error) ListUsersPaginated(ctx context.Context, arg ListUsersPaginatedParams) ([]User, error) + RevokeProvisioningKey(ctx context.Context, arg RevokeProvisioningKeyParams) error + UpdateAgentCertFingerprint(ctx 
context.Context, arg UpdateAgentCertFingerprintParams) error + UpdateAgentLastSeen(ctx context.Context, arg UpdateAgentLastSeenParams) error + UpdateAgentStatus(ctx context.Context, arg UpdateAgentStatusParams) error + UpdateConnectionLog(ctx context.Context, arg UpdateConnectionLogParams) error } var _ Querier = (*Queries)(nil) diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index 1f04e82..5e78e23 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "log/slog" + "os" "sync" "time" @@ -12,6 +13,7 @@ import ( "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "gopkg.in/yaml.v3" grpctls "github.com/EternisAI/silo-proxy/internal/grpc/tls" ) @@ -25,11 +27,13 @@ const ( ) type Client struct { - serverAddr string - agentID string - tlsConfig *TLSConfig - conn *grpc.ClientConn - stream proto.ProxyService_StreamClient + serverAddr string + agentID string + provisioningKey string + configPath string // Path to config file for persistence + tlsConfig *TLSConfig + conn *grpc.ClientConn + stream proto.ProxyService_StreamClient sendCh chan *proto.ProxyMessage stopCh chan struct{} @@ -53,11 +57,13 @@ type TLSConfig struct { ServerNameOverride string } -func NewClient(serverAddr, agentID, localURL string, tlsConfig *TLSConfig) *Client { +func NewClient(serverAddr, agentID, provisioningKey, localURL, configPath string, tlsConfig *TLSConfig) *Client { ctx, cancel := context.WithCancel(context.Background()) return &Client{ serverAddr: serverAddr, agentID: agentID, + provisioningKey: provisioningKey, + configPath: configPath, tlsConfig: tlsConfig, sendCh: make(chan *proto.ProxyMessage, sendChannelBuffer), stopCh: make(chan struct{}), @@ -172,12 +178,25 @@ func (c *Client) connect() error { return fmt.Errorf("failed to create stream: %w", err) } + // Build first message with either provisioning_key or agent_id firstMsg := &proto.ProxyMessage{ - 
Id: uuid.New().String(), - Type: proto.MessageType_PING, - Metadata: map[string]string{ - "agent_id": c.agentID, - }, + Id: uuid.New().String(), + Type: proto.MessageType_PING, + Metadata: make(map[string]string), + } + + if c.provisioningKey != "" { + // Provisioning flow: send provisioning_key + firstMsg.Metadata["provisioning_key"] = c.provisioningKey + slog.Info("Attempting to provision agent with key") + } else if c.agentID != "" { + // Established agent: send agent_id + firstMsg.Metadata["agent_id"] = c.agentID + slog.Info("Connecting with agent_id", "agent_id", c.agentID) + } else { + stream.CloseSend() + conn.Close() + return fmt.Errorf("either agent_id or provisioning_key is required") } if err := stream.Send(firstMsg); err != nil { @@ -186,6 +205,24 @@ func (c *Client) connect() error { return fmt.Errorf("failed to send first message: %w", err) } + // Wait for response if provisioning + if c.provisioningKey != "" { + resp, err := stream.Recv() + if err != nil { + stream.CloseSend() + conn.Close() + return fmt.Errorf("failed to receive provisioning response: %w", err) + } + + if err := c.handleProvisioningResponse(resp); err != nil { + stream.CloseSend() + conn.Close() + return fmt.Errorf("provisioning failed: %w", err) + } + + slog.Info("Agent provisioned successfully", "agent_id", c.agentID) + } + c.mu.Lock() c.conn = conn c.stream = stream @@ -360,3 +397,93 @@ func (c *Client) handleRequest(msg *proto.ProxyMessage) { slog.Error("Failed to send response", "error", err, "message_id", msg.Id) } } + +func (c *Client) handleProvisioningResponse(msg *proto.ProxyMessage) error { + status := msg.Metadata["provisioning_status"] + if status != "success" { + errorMsg := msg.Metadata["error"] + if errorMsg == "" { + errorMsg = "unknown provisioning error" + } + return fmt.Errorf("provisioning failed: %s", errorMsg) + } + + agentID := msg.Metadata["agent_id"] + if agentID == "" { + return fmt.Errorf("provisioning response missing agent_id") + } + + // Update agent 
ID + c.mu.Lock() + c.agentID = agentID + c.mu.Unlock() + + // Persist to config file + if c.configPath != "" { + if err := c.saveAgentIDToConfig(agentID); err != nil { + slog.Error("Failed to persist agent_id to config", "error", err) + // Don't fail provisioning, agent can reconnect with same key + } else { + slog.Info("Agent ID persisted to config", "config_path", c.configPath) + } + } + + return nil +} + +func (c *Client) saveAgentIDToConfig(agentID string) error { + if c.configPath == "" { + return fmt.Errorf("config path not set") + } + + // Read current config file + data, err := os.ReadFile(c.configPath) + if err != nil { + return fmt.Errorf("failed to read config file: %w", err) + } + + // Parse YAML + var config map[string]interface{} + if err := yaml.Unmarshal(data, &config); err != nil { + return fmt.Errorf("failed to parse config: %w", err) + } + + // Update grpc section + grpcConfig, ok := config["grpc"].(map[string]interface{}) + if !ok { + grpcConfig = make(map[string]interface{}) + config["grpc"] = grpcConfig + } + + // Set agent_id and remove provisioning_key + grpcConfig["agent_id"] = agentID + delete(grpcConfig, "provisioning_key") + + // Convert back to YAML + updatedData, err := yaml.Marshal(config) + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + + // Add comment at the top + comment := "# Agent provisioned successfully on " + time.Now().Format(time.RFC3339) + "\n" + finalData := comment + string(updatedData) + + // Write back to file + if err := os.WriteFile(c.configPath, []byte(finalData), 0644); err != nil { + return fmt.Errorf("failed to write config file: %w", err) + } + + // Clear provisioning_key from memory + c.mu.Lock() + c.provisioningKey = "" + c.mu.Unlock() + + return nil +} + +func (c *Client) GetAgentID() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.agentID +} diff --git a/internal/grpc/server/connection_manager.go b/internal/grpc/server/connection_manager.go index 1276b07..aa8e752 
100644 --- a/internal/grpc/server/connection_manager.go +++ b/internal/grpc/server/connection_manager.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/EternisAI/silo-proxy/internal/agents" "github.com/EternisAI/silo-proxy/proto" ) @@ -41,16 +42,18 @@ type ConnectionManager struct { mu sync.RWMutex stopCh chan struct{} agentServerManager AgentServerManager // Optional: manages per-agent HTTP servers + agentService *agents.Service // Optional: for database persistence } // NewConnectionManager creates a new ConnectionManager. // The agentServerManager parameter is optional (can be nil) and enables // per-agent HTTP server management when provided. -func NewConnectionManager(agentServerManager AgentServerManager) *ConnectionManager { +func NewConnectionManager(agentServerManager AgentServerManager, agentService *agents.Service) *ConnectionManager { cm := &ConnectionManager{ agents: make(map[string]*AgentConnection), stopCh: make(chan struct{}), agentServerManager: agentServerManager, + agentService: agentService, } go cm.cleanupStaleConnections() return cm @@ -182,11 +185,27 @@ func (cm *ConnectionManager) SendToAgent(agentID string, msg *proto.ProxyMessage func (cm *ConnectionManager) UpdateLastSeen(agentID string) { cm.mu.Lock() - defer cm.mu.Unlock() - - if conn, ok := cm.agents[agentID]; ok { + conn, ok := cm.agents[agentID] + if ok { conn.LastSeen = time.Now() - slog.Debug("Agent last seen updated", "agent_id", agentID) + } + cm.mu.Unlock() + + if !ok { + return + } + + slog.Debug("Agent last seen updated", "agent_id", agentID) + + // Async DB update (non-blocking) + if cm.agentService != nil { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := cm.agentService.UpdateLastSeen(ctx, agentID, conn.LastSeen, ""); err != nil { + slog.Debug("Failed to update last seen in database", "agent_id", agentID, "error", err) + } + }() } } diff --git a/internal/grpc/server/connection_manager_test.go 
b/internal/grpc/server/connection_manager_test.go index 2d44560..6f52cba 100644 --- a/internal/grpc/server/connection_manager_test.go +++ b/internal/grpc/server/connection_manager_test.go @@ -78,7 +78,7 @@ func (m *MockStream) RecvMsg(msg interface{}) error { } func TestNewConnectionManager(t *testing.T) { - cm := NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) assert.NotNil(t, cm) assert.NotNil(t, cm.agents) assert.NotNil(t, cm.stopCh) @@ -89,7 +89,7 @@ func TestNewConnectionManager(t *testing.T) { func TestNewConnectionManager_WithAgentServerManager(t *testing.T) { mockASM := new(MockAgentServerManager) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) assert.NotNil(t, cm) assert.NotNil(t, cm.agentServerManager) @@ -99,7 +99,7 @@ func TestNewConnectionManager_WithAgentServerManager(t *testing.T) { } func TestConnectionManager_Register_WithoutServerManager(t *testing.T) { - cm := NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) defer cm.Stop() mockStream := NewMockStream() @@ -118,7 +118,7 @@ func TestConnectionManager_Register_WithServerManager(t *testing.T) { mockASM.On("StopAgentServer", "agent-1").Return(nil) mockASM.On("Shutdown").Return(nil) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) mockStream := NewMockStream() conn, err := cm.Register("agent-1", mockStream) @@ -138,7 +138,7 @@ func TestConnectionManager_Register_ServerManagerError(t *testing.T) { mockASM.On("StartAgentServer", "agent-1").Return(0, assert.AnError) mockASM.On("Shutdown").Return(nil) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) mockStream := NewMockStream() conn, err := cm.Register("agent-1", mockStream) @@ -160,7 +160,7 @@ func TestConnectionManager_Register_ReplaceExisting(t *testing.T) { mockASM.On("StopAgentServer", "agent-1").Return(nil).Once() // for cm.Stop() mockASM.On("Shutdown").Return(nil) - cm := NewConnectionManager(mockASM) + cm := 
NewConnectionManager(mockASM, nil) // Register first connection mockStream1 := NewMockStream() @@ -184,7 +184,7 @@ func TestConnectionManager_Register_ReplaceExisting(t *testing.T) { } func TestConnectionManager_Deregister_WithoutServerManager(t *testing.T) { - cm := NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) defer cm.Stop() mockStream := NewMockStream() @@ -204,7 +204,7 @@ func TestConnectionManager_Deregister_WithServerManager(t *testing.T) { mockASM.On("StartAgentServer", "agent-1").Return(8100, nil) mockASM.On("StopAgentServer", "agent-1").Return(nil) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) defer func() { mockASM.On("Shutdown").Return(nil) cm.Stop() @@ -225,7 +225,7 @@ func TestConnectionManager_Deregister_WithServerManager(t *testing.T) { } func TestConnectionManager_Deregister_NonExistent(t *testing.T) { - cm := NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) defer cm.Stop() // Should not panic @@ -240,7 +240,7 @@ func TestConnectionManager_Stop_WithServerManager(t *testing.T) { mockASM.On("StopAgentServer", "agent-2").Return(nil) mockASM.On("Shutdown").Return(nil) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) mockStream1 := NewMockStream() _, err := cm.Register("agent-1", mockStream1) @@ -260,7 +260,7 @@ func TestConnectionManager_Stop_WithServerManager(t *testing.T) { } func TestConnectionManager_GetConnection(t *testing.T) { - cm := NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) defer cm.Stop() mockStream := NewMockStream() @@ -276,7 +276,7 @@ func TestConnectionManager_GetConnection(t *testing.T) { } func TestConnectionManager_ListConnections(t *testing.T) { - cm := NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) defer cm.Stop() // Initially empty @@ -299,7 +299,7 @@ func TestConnectionManager_ListConnections(t *testing.T) { } func TestConnectionManager_UpdateLastSeen(t *testing.T) { - cm := 
NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) defer cm.Stop() mockStream := NewMockStream() @@ -326,7 +326,7 @@ func TestConnectionManager_RemoveStaleConnections_WithServerManager(t *testing.T mockASM.On("StartAgentServer", "agent-1").Return(8100, nil) mockASM.On("StopAgentServer", "agent-1").Return(nil) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) defer func() { mockASM.On("Shutdown").Return(nil) cm.Stop() @@ -363,7 +363,7 @@ func TestConnectionManager_ConcurrentRegistration(t *testing.T) { } mockASM.On("Shutdown").Return(nil) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) // Register 10 agents concurrently done := make(chan bool, 10) @@ -397,7 +397,7 @@ func TestConnectionManager_PortFieldPersistence(t *testing.T) { mockASM.On("StopAgentServer", "agent-1").Return(nil) mockASM.On("Shutdown").Return(nil) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) mockStream := NewMockStream() _, err := cm.Register("agent-1", mockStream) @@ -420,7 +420,7 @@ func TestConnectionManager_PortFieldPersistence(t *testing.T) { } func TestAgentConnection_PortFieldZeroWithoutManager(t *testing.T) { - cm := NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) defer cm.Stop() mockStream := NewMockStream() @@ -432,7 +432,7 @@ func TestAgentConnection_PortFieldZeroWithoutManager(t *testing.T) { } func BenchmarkConnectionManager_Register(b *testing.B) { - cm := NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) defer cm.Stop() b.ResetTimer() @@ -450,7 +450,7 @@ func BenchmarkConnectionManager_RegisterWithServerManager(b *testing.B) { mockASM.On("StopAgentServer", mock.Anything).Return(nil) mockASM.On("Shutdown").Return(nil) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) defer cm.Stop() b.ResetTimer() diff --git a/internal/grpc/server/server.go b/internal/grpc/server/server.go index 1cbee04..78be688 100644 --- 
a/internal/grpc/server/server.go +++ b/internal/grpc/server/server.go @@ -8,6 +8,8 @@ import ( "sync" "time" + "github.com/EternisAI/silo-proxy/internal/agents" + "github.com/EternisAI/silo-proxy/internal/provisioning" "github.com/EternisAI/silo-proxy/proto" "google.golang.org/grpc" @@ -39,7 +41,7 @@ type TLSConfig struct { } func NewServer(port int, tlsConfig *TLSConfig) *Server { - connManager := NewConnectionManager(nil) + connManager := NewConnectionManager(nil, nil) s := &Server{ connManager: connManager, @@ -48,9 +50,7 @@ func NewServer(port int, tlsConfig *TLSConfig) *Server { pendingRequests: make(map[string]chan *proto.ProxyMessage), } - streamHandler := NewStreamHandler(connManager, s) - s.streamHandler = streamHandler - + // streamHandler will be initialized later via SetServices return s } @@ -178,3 +178,12 @@ func (s *Server) GetConnectionManager() *ConnectionManager { func (s *Server) SetAgentServerManager(asm AgentServerManager) { s.connManager.SetAgentServerManager(asm) } + +// SetServices initializes the stream handler with provisioning and agent services +func (s *Server) SetServices(provisioningService *provisioning.Service, agentService *agents.Service) { + // Update connection manager with agent service for DB persistence + s.connManager.agentService = agentService + + // Initialize stream handler with all services + s.streamHandler = NewStreamHandler(s.connManager, s, provisioningService, agentService) +} diff --git a/internal/grpc/server/stream_handler.go b/internal/grpc/server/stream_handler.go index d34af4c..975a119 100644 --- a/internal/grpc/server/stream_handler.go +++ b/internal/grpc/server/stream_handler.go @@ -1,39 +1,103 @@ package server import ( + "context" "fmt" "io" "log/slog" + "time" + "github.com/EternisAI/silo-proxy/internal/agents" + "github.com/EternisAI/silo-proxy/internal/provisioning" "github.com/EternisAI/silo-proxy/proto" "github.com/google/uuid" ) type StreamHandler struct { - connManager *ConnectionManager - server 
*Server + connManager *ConnectionManager + server *Server + provisioningService *provisioning.Service + agentService *agents.Service } -func NewStreamHandler(connManager *ConnectionManager, server *Server) *StreamHandler { +func NewStreamHandler( + connManager *ConnectionManager, + server *Server, + provisioningService *provisioning.Service, + agentService *agents.Service, +) *StreamHandler { return &StreamHandler{ - connManager: connManager, - server: server, + connManager: connManager, + server: server, + provisioningService: provisioningService, + agentService: agentService, } } func (sh *StreamHandler) HandleStream(stream proto.ProxyService_StreamServer) error { + ctx := stream.Context() + firstMsg, err := stream.Recv() if err != nil { return fmt.Errorf("failed to receive first message: %w", err) } + // Extract remote IP from context (if available) + remoteIP := extractRemoteIP(ctx) + + // Extract provisioning key and agent_id from metadata + provisioningKey := firstMsg.Metadata["provisioning_key"] agentID := firstMsg.Metadata["agent_id"] - if agentID == "" { - return fmt.Errorf("agent_id not found in first message metadata") + + // Connection log ID for tracking + var connectionLogID string + + if provisioningKey != "" { + // NEW: Provisioning flow + slog.Info("Agent provisioning request received", "remote_ip", remoteIP) + + result, err := sh.provisioningService.ProvisionAgent(ctx, provisioningKey, remoteIP) + if err != nil { + sh.sendProvisioningError(stream, err) + return fmt.Errorf("provisioning failed: %w", err) + } + + sh.sendProvisioningSuccess(stream, result) + agentID = result.AgentID + + slog.Info("Agent provisioned successfully", "agent_id", agentID, "remote_ip", remoteIP) + + } else if agentID != "" { + // Established agent: validate against DB + agent, err := sh.agentService.GetAgentByID(ctx, agentID) + if err != nil { + return fmt.Errorf("failed to get agent: %w", err) + } + + if agent.Status != "active" { + slog.Warn("Agent connection rejected, 
status not active", + "agent_id", agentID, + "status", agent.Status) + return fmt.Errorf("agent suspended or inactive") + } + + slog.Info("Agent authenticated", "agent_id", agentID, "remote_ip", remoteIP) + } else { + return fmt.Errorf("either provisioning_key or agent_id required in first message metadata") + } + + // Create connection log + logID, err := sh.agentService.CreateConnectionLog(ctx, agentID, time.Now(), remoteIP) + if err != nil { + slog.Error("Failed to create connection log", "agent_id", agentID, "error", err) + // Don't fail the connection, just log the error + } else { + connectionLogID = logID } slog.Info("Agent connection established", "agent_id", agentID) + // Register with ConnectionManager conn, err := sh.connManager.Register(agentID, stream) if err != nil { return fmt.Errorf("failed to register agent: %w", err) @@ -41,11 +105,26 @@ func (sh *StreamHandler) HandleStream(stream proto.ProxyService_StreamServer) er defer func() { sh.connManager.Deregister(agentID) + + // Update connection log with disconnect information + if connectionLogID != "" { + if err := sh.agentService.UpdateConnectionLog(ctx, connectionLogID, time.Now(), "normal disconnect"); err != nil { + slog.Error("Failed to update connection log", "log_id", connectionLogID, "error", err) + } + } + slog.Info("Agent disconnected", "agent_id", agentID) }() sh.connManager.UpdateLastSeen(agentID) + // Update agent last seen in database (async) + go func() { + if err := sh.agentService.UpdateLastSeen(context.Background(), agentID, time.Now(), remoteIP); err != nil { + slog.Error("Failed to update agent last seen", "agent_id", agentID, "error", err) + } + }() + if err := sh.processMessage(agentID, firstMsg); err != nil { slog.Error("Failed to process first message", "agent_id", agentID, "error", err) } @@ -143,3 +222,45 @@ func (sh *StreamHandler) processMessage(agentID string, msg *proto.ProxyMessage) return nil } + +func (sh *StreamHandler) sendProvisioningSuccess(stream 
proto.ProxyService_StreamServer, result *provisioning.AgentProvisionResult) error { + msg := &proto.ProxyMessage{ + Id: uuid.New().String(), + Type: proto.MessageType_PONG, + Metadata: map[string]string{ + "provisioning_status": "success", + "agent_id": result.AgentID, + }, + } + + if result.CertFingerprint != "" { + msg.Metadata["cert_fingerprint"] = result.CertFingerprint + } + + if err := stream.Send(msg); err != nil { + return fmt.Errorf("failed to send provisioning success: %w", err) + } + + return nil +} + +func (sh *StreamHandler) sendProvisioningError(stream proto.ProxyService_StreamServer, err error) { + msg := &proto.ProxyMessage{ + Id: uuid.New().String(), + Type: proto.MessageType_PONG, + Metadata: map[string]string{ + "provisioning_status": "failed", + "error": err.Error(), + }, + } + + if sendErr := stream.Send(msg); sendErr != nil { + slog.Error("Failed to send provisioning error message", "error", sendErr) + } +} + +func extractRemoteIP(ctx context.Context) string { + // This is a placeholder - actual implementation depends on gRPC metadata + // For now, return empty string + return "" +} diff --git a/internal/provisioning/models.go b/internal/provisioning/models.go new file mode 100644 index 0000000..8f9cfe2 --- /dev/null +++ b/internal/provisioning/models.go @@ -0,0 +1,24 @@ +package provisioning + +import ( + "time" +) + +type ProvisioningKey struct { + ID string + KeyHash string + UserID string + Status string + MaxUses int + UsedCount int + ExpiresAt time.Time + CreatedAt time.Time + UpdatedAt time.Time + RevokedAt *time.Time + Notes string +} + +type AgentProvisionResult struct { + AgentID string + CertFingerprint string // Optional: TLS certificate fingerprint +} diff --git a/internal/provisioning/service.go b/internal/provisioning/service.go new file mode 100644 index 0000000..28e64a7 --- /dev/null +++ b/internal/provisioning/service.go @@ -0,0 +1,241 @@ +package provisioning + +import ( + "context" + "crypto/rand" + "crypto/sha256" + 
"encoding/base64" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/EternisAI/silo-proxy/internal/cert" + "github.com/EternisAI/silo-proxy/internal/db/sqlc" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +const ( + keyPrefix = "pk_" + keyLength = 32 // 32 bytes = 256 bits +) + +var ( + ErrKeyNotFound = errors.New("provisioning key not found") + ErrKeyExpired = errors.New("provisioning key expired") + ErrKeyExhausted = errors.New("provisioning key exhausted") + ErrKeyInvalid = errors.New("provisioning key invalid") +) + +type Service struct { + queries *sqlc.Queries + certService *cert.Service +} + +func NewService(queries *sqlc.Queries, certService *cert.Service) *Service { + return &Service{ + queries: queries, + certService: certService, + } +} + +// GenerateKey creates a new provisioning key with crypto/rand +func GenerateKey() (string, error) { + bytes := make([]byte, keyLength) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + + // Use base64 URL-safe encoding + encoded := base64.RawURLEncoding.EncodeToString(bytes) + return keyPrefix + encoded, nil +} + +// HashKey computes SHA-256 hash of the key +func HashKey(key string) string { + hash := sha256.Sum256([]byte(key)) + return fmt.Sprintf("%x", hash) +} + +// CreateKey generates and stores a new provisioning key +func (s *Service) CreateKey(ctx context.Context, userID string, maxUses int, expiresInHours int, notes string) (*ProvisioningKey, string, error) { + // Generate key + key, err := GenerateKey() + if err != nil { + return nil, "", fmt.Errorf("failed to generate key: %w", err) + } + + // Hash the key for storage + keyHash := HashKey(key) + + // Parse user ID + parsedUserID, err := uuid.Parse(userID) + if err != nil { + return nil, "", fmt.Errorf("invalid user ID: %w", err) + } + + // Calculate expiration time + expiresAt := time.Now().Add(time.Duration(expiresInHours) * time.Hour) 
+ + // Store in database + dbKey, err := s.queries.CreateProvisioningKey(ctx, sqlc.CreateProvisioningKeyParams{ + KeyHash: keyHash, + UserID: pgtype.UUID{Bytes: parsedUserID, Valid: true}, + MaxUses: int32(maxUses), + ExpiresAt: pgtype.Timestamp{Time: expiresAt, Valid: true}, + Notes: pgtype.Text{String: notes, Valid: notes != ""}, + }) + if err != nil { + return nil, "", fmt.Errorf("failed to store key: %w", err) + } + + result := &ProvisioningKey{ + ID: uuidToString(dbKey.ID.Bytes), + KeyHash: dbKey.KeyHash, + UserID: uuidToString(dbKey.UserID.Bytes), + Status: string(dbKey.Status), + MaxUses: int(dbKey.MaxUses), + UsedCount: int(dbKey.UsedCount), + ExpiresAt: dbKey.ExpiresAt.Time, + CreatedAt: dbKey.CreatedAt.Time, + UpdatedAt: dbKey.UpdatedAt.Time, + Notes: dbKey.Notes.String, + } + + // Return both the model and the plaintext key (only shown once) + return result, key, nil +} + +// ProvisionAgent validates a provisioning key and creates a new agent +func (s *Service) ProvisionAgent(ctx context.Context, key string, remoteIP string) (*AgentProvisionResult, error) { + // Hash the provided key + keyHash := HashKey(key) + + // Lookup key in database (only returns active keys) + dbKey, err := s.queries.GetProvisioningKeyByHash(ctx, keyHash) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + slog.Warn("Provisioning attempt with invalid key", "remote_ip", remoteIP) + return nil, ErrKeyNotFound + } + return nil, fmt.Errorf("failed to lookup key: %w", err) + } + + // Validate expiration + if time.Now().After(dbKey.ExpiresAt.Time) { + slog.Warn("Provisioning attempt with expired key", + "key_id", uuidToString(dbKey.ID.Bytes), + "expires_at", dbKey.ExpiresAt.Time, + "remote_ip", remoteIP) + return nil, ErrKeyExpired + } + + // Atomically increment key usage count. + // The WHERE clause (used_count < max_uses AND status = 'active') ensures + // concurrent requests cannot exceed max_uses — only one will succeed. 
+ _, err = s.queries.IncrementKeyUsage(ctx, dbKey.ID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + slog.Warn("Provisioning attempt with exhausted key", + "key_id", uuidToString(dbKey.ID.Bytes), + "remote_ip", remoteIP) + return nil, ErrKeyExhausted + } + return nil, fmt.Errorf("failed to increment key usage: %w", err) + } + + // Create agent in database (ID is auto-generated) + dbAgent, err := s.queries.CreateAgent(ctx, sqlc.CreateAgentParams{ + UserID: dbKey.UserID, + ProvisionedWithKeyID: pgtype.UUID{Bytes: dbKey.ID.Bytes, Valid: true}, + Metadata: []byte("{}"), + Notes: pgtype.Text{String: "", Valid: false}, + }) + if err != nil { + return nil, fmt.Errorf("failed to create agent: %w", err) + } + + result := &AgentProvisionResult{ + AgentID: uuidToString(dbAgent.ID.Bytes), + } + + slog.Info("Agent provisioned successfully", + "agent_id", result.AgentID, + "user_id", uuidToString(dbKey.UserID.Bytes), + "key_id", uuidToString(dbKey.ID.Bytes), + "remote_ip", remoteIP) + + return result, nil +} + +// ListUserKeys returns all provisioning keys for a user +func (s *Service) ListUserKeys(ctx context.Context, userID string) ([]ProvisioningKey, error) { + parsedUserID, err := uuid.Parse(userID) + if err != nil { + return nil, fmt.Errorf("invalid user ID: %w", err) + } + + dbKeys, err := s.queries.ListProvisioningKeysByUser(ctx, pgtype.UUID{Bytes: parsedUserID, Valid: true}) + if err != nil { + return nil, fmt.Errorf("failed to list keys: %w", err) + } + + result := make([]ProvisioningKey, len(dbKeys)) + for i, k := range dbKeys { + result[i] = ProvisioningKey{ + ID: uuidToString(k.ID.Bytes), + KeyHash: k.KeyHash, + UserID: uuidToString(k.UserID.Bytes), + Status: string(k.Status), + MaxUses: int(k.MaxUses), + UsedCount: int(k.UsedCount), + ExpiresAt: k.ExpiresAt.Time, + CreatedAt: k.CreatedAt.Time, + UpdatedAt: k.UpdatedAt.Time, + Notes: k.Notes.String, + } + if k.RevokedAt.Valid { + result[i].RevokedAt = &k.RevokedAt.Time + } + } + + return result, nil +} + 
+// RevokeKey revokes a provisioning key +func (s *Service) RevokeKey(ctx context.Context, keyID string, userID string) error { + parsedKeyID, err := uuid.Parse(keyID) + if err != nil { + return fmt.Errorf("invalid key ID: %w", err) + } + + parsedUserID, err := uuid.Parse(userID) + if err != nil { + return fmt.Errorf("invalid user ID: %w", err) + } + + if err := s.queries.RevokeProvisioningKey(ctx, sqlc.RevokeProvisioningKeyParams{ + ID: pgtype.UUID{Bytes: parsedKeyID, Valid: true}, + UserID: pgtype.UUID{Bytes: parsedUserID, Valid: true}, + }); err != nil { + return fmt.Errorf("failed to revoke key: %w", err) + } + + slog.Info("Provisioning key revoked", "key_id", keyID, "user_id", userID) + return nil +} + +// ExpireOldKeys marks expired keys as expired (cleanup task) +func (s *Service) ExpireOldKeys(ctx context.Context) error { + if err := s.queries.ExpireOldKeys(ctx); err != nil { + return fmt.Errorf("failed to expire old keys: %w", err) + } + return nil +} + +func uuidToString(id [16]byte) string { + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + id[0:4], id[4:6], id[6:8], id[8:10], id[10:16]) +}