From 093eebffe8c7e3f727982d5b088d1e6606dbd594 Mon Sep 17 00:00:00 2001
From: Jason Lee
Date: Mon, 9 Feb 2026 11:25:07 +0800
Subject: [PATCH 1/5] add AGENTS.md

---
 .gitignore |   1 +
 AGENTS.md  | 164 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 165 insertions(+)
 create mode 100644 AGENTS.md

diff --git a/.gitignore b/.gitignore
index f2f2d90..220b6c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 /bin/
 /certs/
+/.agents/
 .env

diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 0000000..162f7a7
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,164 @@
+# AGENTS.md
+
+Guidance for coding agents working in `github.com/EternisAI/silo-proxy`.
+
+## Scope
+- Prefer minimal, safe changes that preserve behavior.
+- Follow existing repo patterns over generic best practices.
+- Keep API, service/domain, and infrastructure concerns separated.
+- Avoid unrelated refactors in focused tasks.
+
+## Stack Snapshot
+- Go `1.25.0`.
+- HTTP: Gin.
+- RPC: gRPC/protobuf (`proto/proxy.proto`).
+- DB: PostgreSQL via `pgx/v5`, migrations via goose, queries via sqlc.
+- Auth: JWT + API key middleware.
+- Logging: `log/slog`.
+- Tests: Go `testing` + `testify` + testcontainers system tests.
+
+## Repository Layout
+- `cmd/silo-proxy-server`: server entrypoint, config, logger.
+- `cmd/silo-proxy-agent`: agent entrypoint, config, logger.
+- `internal/api/http`: router, handlers, middleware, per-agent HTTP management.
+- `internal/grpc/server`: stream server and connection manager.
+- `internal/grpc/client`: agent stream client and local forwarder.
+- `internal/auth`, `internal/users`: business/domain services.
+- `internal/db`: connection init, migrations, sqlc output.
+- `proto`: protobuf schema and generated stubs.
+- `systemtest`: integration tests against ephemeral Postgres.
+
+## Build / Run / Test Commands
+
+Prefer `make` targets when available.
+
+### Primary Make Targets
+- `make build` - build both binaries.
+- `make build-server` - build server binary.
+- `make build-agent` - build agent binary.
+- `make run` - run server locally.
+- `make run-agent` - run agent locally.
+- `make test` - run all tests (`go test -v ./...`).
+- `make clean` - clean test cache and `bin/*`.
+- `make generate` - run `sqlc generate`.
+- `make protoc-gen` - regenerate protobuf/gRPC code.
+- `make generate-certs` - generate local TLS certs.
+
+### Running Focused Tests (Important)
+- Single package:
+  - `go test -v ./internal/grpc/server`
+- Single test function:
+  - `go test -v ./internal/grpc/server -run '^TestConnectionManager_Register_WithServerManager$'`
+- Single subtest:
+  - `go test -v ./systemtest -run '^TestSystemIntegration$/^Login$'`
+- Re-run without cache:
+  - `go test -v -count=1 ./internal/api/http`
+- Benchmarks:
+  - `go test -bench . ./internal/grpc/server`
+- Compile-only check:
+  - `go test -run '^$' ./...`
+
+### Lint / Format / Validation
+- No dedicated `make lint` target is currently defined.
+- Format changed files with `gofmt -w <file>`.
+- Run `go test ./...` as baseline verification.
+- Optional static pass: `go vet ./...`.
+- Prefer formatting only touched files to avoid noisy diffs.
+
+## Test Selection Guidance
+- Start with the narrowest package/test that covers your change.
+- Use exact `-run` regex anchors (`^...$`) to avoid accidental matches.
+- Use `-count=1` if caching may hide flakiness.
+- `systemtest` requires Docker/testcontainers available locally.
+- Run full `make test` before finalizing broad changes.
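+
+A typical narrow-to-broad loop, using only commands already listed in this file (the package and test names are the examples above, not a prescription):
+
+```bash
+# compile-only sanity check across the repo
+go test -run '^$' ./...
+
+# run the single test closest to the change, bypassing the cache
+go test -v -count=1 ./internal/grpc/server -run '^TestConnectionManager_Register_WithServerManager$'
+
+# widen to the owning package, then the full suite
+go test -v ./internal/grpc/server
+make test
+```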
+
+## Code Style Guidelines
+
+### Imports
+- Keep imports `gofmt`-organized.
+- Standard library first, then external/internal packages.
+- Use aliases only where they improve clarity (e.g. `grpcserver`).
+- Avoid dot imports and unused imports.
+
+### Formatting and Structure
+- Enforce `gofmt`.
+- Prefer early returns over nested conditionals.
+- Keep functions focused; avoid mixing concerns.
+- Add comments only for non-obvious invariants/behavior.
+
+### Types and API Shapes
+- Use explicit structs for service/DTO outputs (`RegisterResult`, `UserInfo`).
+- Keep struct fields cohesive to a single responsibility.
+- Prefer typed constants for repeated timing/protocol values.
+- Use typed config structs with `mapstructure` tags.
+
+### Naming Conventions
+- Exported identifiers: `PascalCase`.
+- Unexported identifiers: `camelCase`.
+- Constructors follow `NewXxx(...) *Xxx`.
+- Handler methods should be verb-driven (`Login`, `Register`, `DeleteUser`).
+- Sentinel errors use `ErrXxx` naming.
+
+### Error Handling
+- Return errors instead of panicking (except startup-fatal bootstrap failures).
+- Wrap lower-level errors with context using `%w`.
+- Branch on known causes using `errors.Is` / `errors.As`.
+- Keep domain-level error contracts stable for handler mapping.
+- Do not expose sensitive internals in HTTP error payloads.
+
+### Logging
+- Use structured `slog` logs (`Info/Warn/Error/Debug`).
+- Include stable keys like `agent_id`, `port`, `message_id`, `error`.
+- Log lifecycle and recovery events (start/stop/retry/failure).
+- Avoid noisy logs in hot loops unless debug-level is warranted.
+
+### Concurrency and Context
+- Guard shared maps/state with `sync.RWMutex` where applicable.
+- Use buffered channels for producer/consumer coordination.
+- Use `context.WithTimeout` for shutdown/network operations.
+- Preserve existing cancellation and cleanup semantics.
+
+### HTTP Layer Rules
+- Keep handlers thin: bind/validate input, call service, map response.
+- Return JSON errors in consistent shape: `{"error":"..."}`.
+- Put cross-cutting concerns in middleware (auth, API key, logging).
+- Treat read-only/status endpoints as side-effect free.
+
+### Service and Domain Rules
+- Keep business logic in services, not in handlers/middleware.
+- Keep DB-specific behavior at DB/service boundaries.
+- Keep contracts explicit for auth/user and other domain services.
+
+### DB, Migrations, SQLC
+- SQL files belong in `internal/db/queries`.
+- sqlc-generated files live in `internal/db/sqlc` and must not be hand-edited.
+- After SQL query/schema changes, run `make generate` and migrations.
+- Keep migrations idempotent and ordered.
+
+### Proto / gRPC Changes
+- Edit `proto/proxy.proto`, then run `make protoc-gen`.
+- Update both server and agent paths for schema/message changes.
+- Preserve compatibility expectations when possible.
+
+## Configuration Conventions
+- Config is loaded from `application.yml` with env overrides.
+- Nested key env mapping uses underscore replacement.
+- Server and agent keep separate config roots under `cmd/...`.
+- Keep examples/defaults aligned with checked-in app configs.
+
+## CI Notes
+- CI runs `make test` and `make build` on pull requests.
+- Docker publish jobs execute only on `main` and version tags.
+- If packaging/runtime behavior changes, validate make/docker targets.
+
+## Cursor / Copilot Rules Check
+- No `.cursorrules` file found.
+- No `.cursor/rules/` directory found.
+- No `.github/copilot-instructions.md` file found.
+- If added later, treat those files as higher-priority local instructions.
+
+## Practical Agent Workflow
+1. Make the smallest viable change.
+2. Run targeted tests first, then broaden scope.
+3. Keep diffs tight and architecture boundaries intact.
+4. Add/adjust logs and errors for operability.

From cd3c300b52974669350212a000db96027f621c10 Mon Sep 17 00:00:00 2001
From: Jason Lee
Date: Mon, 9 Feb 2026 12:31:25 +0800
Subject: [PATCH 2/5] add provisioning system database schema (Phase 1)

- Create provisioning_keys table with SHA-256 key hashing and status lifecycle
- Create agents table where id IS the agent identifier (no separate agent_id)
- Create agent_connection_logs table for audit trail
- Add SQLC queries for provisioning keys, agents, and connection logs
- Generate type-safe database access layer via SQLC
- Add provisioning system documentation
---
 docs/provisioning/Overview.md                 | 239 ++++++++++++++++++
 .../0002_create_provisioning_keys.sql         |  31 +++
 internal/db/migrations/0003_create_agents.sql |  30 +++
 .../0004_create_agent_connection_logs.sql     |  22 ++
 internal/db/queries/agent_connection_logs.sql |  16 ++
 internal/db/queries/agents.sql                |  18 ++
 internal/db/queries/provisioning_keys.sql     |  29 +++
 internal/db/sqlc/agent_connection_logs.sql.go |  99 ++++++++
 internal/db/sqlc/agents.sql.go                | 148 +++++++++++
 internal/db/sqlc/models.go                    | 125 +++++++++
 internal/db/sqlc/provisioning_keys.sql.go     | 152 +++++++++++
 internal/db/sqlc/querier.go                   |  15 ++
 12 files changed, 924 insertions(+)
 create mode 100644 docs/provisioning/Overview.md
 create mode 100644 internal/db/migrations/0002_create_provisioning_keys.sql
 create mode 100644 internal/db/migrations/0003_create_agents.sql
 create mode 100644 internal/db/migrations/0004_create_agent_connection_logs.sql
 create mode 100644 internal/db/queries/agent_connection_logs.sql
 create mode 100644 internal/db/queries/agents.sql
 create mode 100644 internal/db/queries/provisioning_keys.sql
 create mode 100644 internal/db/sqlc/agent_connection_logs.sql.go
 create mode 100644 internal/db/sqlc/agents.sql.go
 create mode 100644 internal/db/sqlc/provisioning_keys.sql.go

diff --git a/docs/provisioning/Overview.md b/docs/provisioning/Overview.md
new file mode 100644
index 0000000..37fcfcd
--- /dev/null
+++ b/docs/provisioning/Overview.md
@@ -0,0 +1,239 @@
+# Key-Based Agent Provisioning System
+
+## Overview
+
+This document describes the key-based agent provisioning system for Silo Proxy. The system enables secure, user-scoped agent registration through temporary, single-use provisioning keys.
+
+## Problem Statement
+
+**Current Issues:**
+- Agents use static `agent_id` with no authentication/authorization
+- Any agent can claim any ID (security risk)
+- No user ownership tracking (prevents multi-tenant dashboards)
+
+## Solution
+
+The provisioning system works as follows:
+
+1. Dashboard users generate temporary, single-use provisioning keys
+2. Users copy keys to client devices
+3. Agents connect with key, receive permanent agent_id (UUID) + optional TLS cert
+4. Future connections validated against database
+5. All agents associated with user accounts
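+
+A minimal sketch of step 1 in Go, matching the key format described under Security below (`pk_` prefix, base64url, SHA-256, 32 random bytes); the function name and exact layout are illustrative, not the repo's actual code:
+
+```go
+package provisioning
+
+import (
+	"crypto/rand"
+	"crypto/sha256"
+	"encoding/base64"
+	"encoding/hex"
+	"fmt"
+)
+
+// generateKey returns the plaintext key (shown to the user exactly once)
+// and the SHA-256 hash that is persisted; the plaintext is never stored.
+func generateKey() (plaintext, hash string, err error) {
+	buf := make([]byte, 32) // 32 random bytes via crypto/rand
+	if _, err := rand.Read(buf); err != nil {
+		return "", "", fmt.Errorf("failed to generate key: %w", err)
+	}
+	plaintext = "pk_" + base64.RawURLEncoding.EncodeToString(buf)
+	sum := sha256.Sum256([]byte(plaintext))
+	return plaintext, hex.EncodeToString(sum[:]), nil
+}
+```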
+
+## Architecture
+
+### Database Schema
+
+#### provisioning_keys Table
+- Stores SHA-256 hashes of provisioning keys (never plaintext)
+- Single-use by default (`max_uses=1`)
+- Time-based expiration (24-48 hours recommended)
+- Status lifecycle: active → exhausted/expired/revoked
+
+#### agents Table
+- `id` (UUID) IS the agent identifier (no separate agent_id column)
+- Tracks user ownership and provisioning key used
+- Stores connection metadata and status
+
+#### agent_connection_logs Table
+- Audit trail for all agent connections
+- Tracks connection duration, IP address, disconnect reason
+
+### Provisioning Flow
+
+#### Case 1: New Agent (First Connection)
+
+```
+1. Agent → Server: ProxyMessage {
+     Type: PING
+     Metadata: {
+       "provisioning_key": "pk_abc123..."
+     }
+   }
+
+2. Server:
+   - Hash key (SHA-256)
+   - Lookup in provisioning_keys table
+   - Validate: status='active', used_count < max_uses, expires_at > NOW()
+   - Insert into agents table (id auto-generated)
+   - Increment provisioning_keys.used_count
+   - Register in ConnectionManager using agents.id
+
+3. Server → Agent: ProxyMessage {
+     Type: PONG
+     Metadata: {
+       "provisioning_status": "success",
+       "agent_id": "550e8400-..." // agents.id
+     }
+   }
+```
+
+#### Case 2: Established Agent
+
+```
+1. Agent → Server: ProxyMessage {
+     Type: PING
+     Metadata: {
+       "agent_id": "550e8400-..."
+     }
+   }
+
+2. Server:
+   - Lookup agents.id in database
+   - Validate: status = 'active'
+   - Update last_seen_at
+   - Register in ConnectionManager
+
+3. Server → Agent: PONG (normal flow)
+```
+
+#### Case 3: Legacy Agent (Auto-Migration)
+
+```
+1. Agent sends agent_id (old static ID)
+2. Server checks if exists in agents table
+3. If NOT found: Auto-register with default user
+4. Connection proceeds normally
+```
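+
+The three cases reduce to one dispatch on the PING metadata. A condensed sketch of that server-side decision, assuming the service layer from this patch series; `resolveAgentID`, the `provisioner` interface, and `ProvisionAgent` are illustrative names, not the actual stream-handler code:
+
+```go
+package server
+
+import (
+	"context"
+	"errors"
+
+	"github.com/EternisAI/silo-proxy/internal/agents"
+)
+
+// provisioner is a narrowed view of the provisioning service used by the
+// handshake; ProvisionAgent is an assumed method name, not repo API.
+type provisioner interface {
+	ProvisionAgent(ctx context.Context, key, remoteIP string) (agentID string, err error)
+}
+
+// resolveAgentID condenses Cases 1-3 above into a single decision.
+func resolveAgentID(ctx context.Context, prov provisioner, agentSvc *agents.Service,
+	defaultUserID string, md map[string]string, remoteIP string) (string, error) {
+	// Case 1: a new agent presents a provisioning key.
+	if key, ok := md["provisioning_key"]; ok {
+		return prov.ProvisionAgent(ctx, key, remoteIP)
+	}
+
+	agentID := md["agent_id"]
+
+	// Case 2: an established agent with a known UUID.
+	agent, err := agentSvc.GetAgentByID(ctx, agentID)
+	if err == nil {
+		return agent.ID, nil
+	}
+
+	// Case 3: unknown or non-UUID ID, auto-registered as a legacy agent
+	// under the default user.
+	if errors.Is(err, agents.ErrAgentNotFound) || errors.Is(err, agents.ErrInvalidAgentID) {
+		migrated, cErr := agentSvc.CreateLegacyAgent(ctx, agentID, defaultUserID)
+		if cErr != nil {
+			return "", cErr
+		}
+		return migrated.ID, nil
+	}
+	return "", err
+}
+```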
+
+## Implementation Status
+
+### Phase 1: Database Schema ✅ COMPLETED
+
+**Deliverables:**
+- ✅ 3 migration files created and tested
+- ✅ SQLC queries defined for all tables
+- ✅ Type-safe Go code generated via SQLC
+
+**Files Created:**
+- `internal/db/migrations/0002_create_provisioning_keys.sql`
+- `internal/db/migrations/0003_create_agents.sql`
+- `internal/db/migrations/0004_create_agent_connection_logs.sql`
+- `internal/db/queries/provisioning_keys.sql`
+- `internal/db/queries/agents.sql`
+- `internal/db/queries/agent_connection_logs.sql`
+- `internal/db/sqlc/*.go` (generated)
+
+### Phase 2: Core Provisioning Logic 🚧 IN PROGRESS
+
+**Deliverables:**
+- Service layer (provisioning, agents)
+- gRPC stream handler integration
+- ConnectionManager database persistence
+- E2E test: agent provisions via gRPC stream
+
+**Files to Create/Modify:**
+- `internal/provisioning/service.go` (NEW)
+- `internal/provisioning/models.go` (NEW)
+- `internal/agents/service.go` (NEW)
+- `internal/agents/models.go` (NEW)
+- `internal/grpc/server/stream_handler.go` (MODIFY)
+- `internal/grpc/server/connection_manager.go` (MODIFY)
+- `cmd/silo-proxy-server/main.go` (MODIFY)
+
+### Phase 3: API & Client Integration ⏸️ PLANNED
+
+**Deliverables:**
+- HTTP API endpoints (POST/GET/DELETE keys, GET/DELETE agents)
+- Agent config + client changes
+- E2E test: full provisioning flow via dashboard
+- Documentation updates
+
+## Security
+
+### Key Generation
+- 32 bytes generated using `crypto/rand`
+- Format: `pk_` prefix + base64url encoding
+- Keys never logged in plaintext
+
+### Storage
+- Only SHA-256 hash stored in database
+- Original key shown once to user during creation
+- Keys are single-use by default
+
+### Validation
+- Key hash lookup in database
+- Expiration time check
+- Usage count enforcement
+- User ownership validation for all operations
+
+### Rate Limiting
+- 5 requests/second on provisioning endpoint
+- Prevents brute-force attacks
+
+### Audit Logging
+- All provisioning attempts logged
+- Connection history tracked
+- Legacy agent connections flagged
+
+## Configuration
+
+### Server Configuration
+
+**Database Connection:**
+```yaml
+database:
+  url: "postgres://user:password@localhost:5432/silo-proxy"
+  schema: "public" # Optional, defaults to "public"
+```
+
+### Agent Configuration
+
+**Provisioning (First Connection):**
+```yaml
+grpc:
+  server_address: "server.example.com:9090"
+  provisioning_key: "pk_abc123..." # Provided by dashboard
+  tls:
+    enabled: true
+    ca_file: "ca.pem"
+```
+
+**Established Agent (Subsequent Connections):**
+```yaml
+grpc:
+  server_address: "server.example.com:9090"
+  agent_id: "550e8400-e29b-41d4-a716-446655440000" # Auto-saved after provisioning
+  tls:
+    enabled: true
+    ca_file: "ca.pem"
+```
+
+## Testing
+
+### Phase 1 Verification
+```bash
+# 1. Run migrations (embedded, automatic on server start)
+make build
+
+# 2. Verify SQLC generation
+make generate
+
+# 3. Verify tables exist (requires running database)
+psql -d silo-proxy -c "\dt"
+
+# Expected output:
+# provisioning_keys
+# agents
+# agent_connection_logs
+# users
+```
+
+## Future Enhancements
+
+- Multi-use keys with configurable limits
+- Key rotation policies
+- Certificate-based authentication
+- Agent grouping/tagging
+- Webhook notifications on agent events
+- Advanced audit logging with retention policies
+
+## Design Principles
+
+1. **agents.id is the agent identifier** - No separate agent_id column, simpler schema
+2. 
**Single-use keys by default** - max_uses=1, configurable per key +3. **Connection audit logs** - Implemented from day 1 for compliance +4. **Legacy auto-migration** - Existing agents auto-registered with default user +5. **SHA-256 key hashing** - Keys never stored in plaintext +6. **JWT-based API auth** - Reuse existing middleware diff --git a/internal/db/migrations/0002_create_provisioning_keys.sql b/internal/db/migrations/0002_create_provisioning_keys.sql new file mode 100644 index 0000000..4149b79 --- /dev/null +++ b/internal/db/migrations/0002_create_provisioning_keys.sql @@ -0,0 +1,31 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TYPE provisioning_key_status AS ENUM ('active', 'exhausted', 'expired', 'revoked'); + +CREATE TABLE IF NOT EXISTS provisioning_keys ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + key_hash VARCHAR(255) NOT NULL UNIQUE, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + status provisioning_key_status NOT NULL DEFAULT 'active', + max_uses INT NOT NULL DEFAULT 1, + used_count INT NOT NULL DEFAULT 0, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP NOT NULL DEFAULT NOW(), + revoked_at TIMESTAMP, + notes TEXT +); + +CREATE INDEX IF NOT EXISTS idx_provisioning_keys_user_id ON provisioning_keys(user_id); +CREATE INDEX IF NOT EXISTS idx_provisioning_keys_status ON provisioning_keys(status); +CREATE INDEX IF NOT EXISTS idx_provisioning_keys_key_hash ON provisioning_keys(key_hash); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP INDEX IF EXISTS idx_provisioning_keys_key_hash; +DROP INDEX IF EXISTS idx_provisioning_keys_status; +DROP INDEX IF EXISTS idx_provisioning_keys_user_id; +DROP TABLE IF EXISTS provisioning_keys; +DROP TYPE IF EXISTS provisioning_key_status; +-- +goose StatementEnd diff --git a/internal/db/migrations/0003_create_agents.sql b/internal/db/migrations/0003_create_agents.sql new file mode 100644 index 0000000..49a48a0 --- /dev/null +++ b/internal/db/migrations/0003_create_agents.sql @@ -0,0 +1,30 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TYPE agent_status AS ENUM ('active', 'inactive', 'suspended'); + +CREATE TABLE IF NOT EXISTS agents ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provisioned_with_key_id UUID REFERENCES provisioning_keys(id) ON DELETE SET NULL, + status agent_status NOT NULL DEFAULT 'active', + cert_fingerprint VARCHAR(255), + registered_at TIMESTAMP NOT NULL DEFAULT NOW(), + last_seen_at TIMESTAMP NOT NULL DEFAULT NOW(), + last_ip_address INET, + metadata JSONB, + notes TEXT +); + +CREATE INDEX IF NOT EXISTS idx_agents_user_id ON agents(user_id); +CREATE INDEX IF NOT EXISTS idx_agents_status ON agents(status); +CREATE INDEX IF NOT EXISTS idx_agents_last_seen_at ON agents(last_seen_at); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP INDEX IF EXISTS idx_agents_last_seen_at; +DROP INDEX IF EXISTS idx_agents_status; +DROP INDEX IF EXISTS idx_agents_user_id; +DROP TABLE IF EXISTS agents; +DROP TYPE IF EXISTS agent_status; +-- +goose StatementEnd diff --git a/internal/db/migrations/0004_create_agent_connection_logs.sql b/internal/db/migrations/0004_create_agent_connection_logs.sql new file mode 100644 index 0000000..8ed397a --- /dev/null +++ b/internal/db/migrations/0004_create_agent_connection_logs.sql @@ -0,0 +1,22 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE IF NOT EXISTS agent_connection_logs ( + id UUID PRIMARY 
KEY DEFAULT gen_random_uuid(), + agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE, + connected_at TIMESTAMP NOT NULL DEFAULT NOW(), + disconnected_at TIMESTAMP, + duration_seconds INT, + ip_address INET, + disconnect_reason VARCHAR(255) +); + +CREATE INDEX IF NOT EXISTS idx_connection_logs_agent_id ON agent_connection_logs(agent_id); +CREATE INDEX IF NOT EXISTS idx_connection_logs_connected_at ON agent_connection_logs(connected_at); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP INDEX IF EXISTS idx_connection_logs_connected_at; +DROP INDEX IF EXISTS idx_connection_logs_agent_id; +DROP TABLE IF EXISTS agent_connection_logs; +-- +goose StatementEnd diff --git a/internal/db/queries/agent_connection_logs.sql b/internal/db/queries/agent_connection_logs.sql new file mode 100644 index 0000000..0a3ff1d --- /dev/null +++ b/internal/db/queries/agent_connection_logs.sql @@ -0,0 +1,16 @@ +-- name: CreateConnectionLog :one +INSERT INTO agent_connection_logs (agent_id, connected_at, ip_address) +VALUES ($1, $2, $3) RETURNING *; + +-- name: UpdateConnectionLog :exec +UPDATE agent_connection_logs +SET disconnected_at = $2, + duration_seconds = EXTRACT(EPOCH FROM ($2 - connected_at))::INT, + disconnect_reason = $3 +WHERE id = $1; + +-- name: GetAgentConnectionHistory :many +SELECT * FROM agent_connection_logs +WHERE agent_id = $1 +ORDER BY connected_at DESC +LIMIT $2 OFFSET $3; diff --git a/internal/db/queries/agents.sql b/internal/db/queries/agents.sql new file mode 100644 index 0000000..fc894d1 --- /dev/null +++ b/internal/db/queries/agents.sql @@ -0,0 +1,18 @@ +-- name: CreateAgent :one +INSERT INTO agents (user_id, provisioned_with_key_id, metadata, notes) +VALUES ($1, $2, $3, $4) RETURNING *; + +-- name: GetAgentByID :one +SELECT * FROM agents WHERE id = $1 LIMIT 1; + +-- name: ListAgentsByUser :many +SELECT * FROM agents WHERE user_id = $1 ORDER BY registered_at DESC; + +-- name: UpdateAgentLastSeen :exec +UPDATE agents SET last_seen_at = $2, last_ip_address = $3 WHERE id = $1; + +-- name: UpdateAgentStatus :exec +UPDATE agents SET status = $2 WHERE id = $1; + +-- name: UpdateAgentCertFingerprint :exec +UPDATE agents SET cert_fingerprint = $2 WHERE id = $1; diff --git a/internal/db/queries/provisioning_keys.sql b/internal/db/queries/provisioning_keys.sql new file mode 100644 index 0000000..00a8118 --- /dev/null +++ b/internal/db/queries/provisioning_keys.sql @@ -0,0 +1,29 @@ +-- name: CreateProvisioningKey :one +INSERT INTO provisioning_keys (key_hash, user_id, max_uses, expires_at, notes) +VALUES ($1, $2, $3, $4, $5) RETURNING *; + +-- name: GetProvisioningKeyByHash :one +SELECT * FROM provisioning_keys +WHERE key_hash = $1 AND status = 'active' LIMIT 1; + +-- name: ListProvisioningKeysByUser :many +SELECT * FROM provisioning_keys WHERE user_id = $1 ORDER BY created_at DESC; + +-- name: IncrementKeyUsage :exec +UPDATE provisioning_keys +SET used_count = used_count + 1, + status = CASE WHEN used_count + 1 >= max_uses + THEN 'exhausted'::provisioning_key_status + ELSE status END, + updated_at = NOW() +WHERE id = $1; + +-- name: RevokeProvisioningKey :exec +UPDATE provisioning_keys +SET status = 'revoked'::provisioning_key_status, revoked_at = NOW(), updated_at = NOW() +WHERE id = $1 AND user_id = $2; + +-- name: ExpireOldKeys :exec +UPDATE provisioning_keys +SET status = 'expired'::provisioning_key_status, updated_at = NOW() +WHERE status = 'active' AND expires_at < NOW(); diff --git a/internal/db/sqlc/agent_connection_logs.sql.go 
b/internal/db/sqlc/agent_connection_logs.sql.go new file mode 100644 index 0000000..1bc9549 --- /dev/null +++ b/internal/db/sqlc/agent_connection_logs.sql.go @@ -0,0 +1,99 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: agent_connection_logs.sql + +package sqlc + +import ( + "context" + "net/netip" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createConnectionLog = `-- name: CreateConnectionLog :one +INSERT INTO agent_connection_logs (agent_id, connected_at, ip_address) +VALUES ($1, $2, $3) RETURNING id, agent_id, connected_at, disconnected_at, duration_seconds, ip_address, disconnect_reason +` + +type CreateConnectionLogParams struct { + AgentID pgtype.UUID `json:"agent_id"` + ConnectedAt pgtype.Timestamp `json:"connected_at"` + IpAddress *netip.Addr `json:"ip_address"` +} + +func (q *Queries) CreateConnectionLog(ctx context.Context, arg CreateConnectionLogParams) (AgentConnectionLog, error) { + row := q.db.QueryRow(ctx, createConnectionLog, arg.AgentID, arg.ConnectedAt, arg.IpAddress) + var i AgentConnectionLog + err := row.Scan( + &i.ID, + &i.AgentID, + &i.ConnectedAt, + &i.DisconnectedAt, + &i.DurationSeconds, + &i.IpAddress, + &i.DisconnectReason, + ) + return i, err +} + +const getAgentConnectionHistory = `-- name: GetAgentConnectionHistory :many +SELECT id, agent_id, connected_at, disconnected_at, duration_seconds, ip_address, disconnect_reason FROM agent_connection_logs +WHERE agent_id = $1 +ORDER BY connected_at DESC +LIMIT $2 OFFSET $3 +` + +type GetAgentConnectionHistoryParams struct { + AgentID pgtype.UUID `json:"agent_id"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` +} + +func (q *Queries) GetAgentConnectionHistory(ctx context.Context, arg GetAgentConnectionHistoryParams) ([]AgentConnectionLog, error) { + rows, err := q.db.Query(ctx, getAgentConnectionHistory, arg.AgentID, arg.Limit, arg.Offset) + if err != nil { + return nil, err + } + defer rows.Close() + items := []AgentConnectionLog{} + for rows.Next() { + var i AgentConnectionLog + if err := rows.Scan( + &i.ID, + &i.AgentID, + &i.ConnectedAt, + &i.DisconnectedAt, + &i.DurationSeconds, + &i.IpAddress, + &i.DisconnectReason, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateConnectionLog = `-- name: UpdateConnectionLog :exec +UPDATE agent_connection_logs +SET disconnected_at = $2, + duration_seconds = EXTRACT(EPOCH FROM ($2 - connected_at))::INT, + disconnect_reason = $3 +WHERE id = $1 +` + +type UpdateConnectionLogParams struct { + ID pgtype.UUID `json:"id"` + DisconnectedAt pgtype.Timestamp `json:"disconnected_at"` + DisconnectReason pgtype.Text `json:"disconnect_reason"` +} + +func (q *Queries) UpdateConnectionLog(ctx context.Context, arg UpdateConnectionLogParams) error { + _, err := q.db.Exec(ctx, updateConnectionLog, arg.ID, arg.DisconnectedAt, arg.DisconnectReason) + return err +} diff --git a/internal/db/sqlc/agents.sql.go b/internal/db/sqlc/agents.sql.go new file mode 100644 index 0000000..0998286 --- /dev/null +++ b/internal/db/sqlc/agents.sql.go @@ -0,0 +1,148 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.29.0 +// source: agents.sql + +package sqlc + +import ( + "context" + "net/netip" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createAgent = `-- name: CreateAgent :one +INSERT INTO agents (user_id, provisioned_with_key_id, metadata, notes) +VALUES ($1, $2, $3, $4) RETURNING id, user_id, provisioned_with_key_id, status, cert_fingerprint, registered_at, last_seen_at, last_ip_address, metadata, notes +` + +type CreateAgentParams struct { + UserID pgtype.UUID `json:"user_id"` + ProvisionedWithKeyID pgtype.UUID `json:"provisioned_with_key_id"` + Metadata []byte `json:"metadata"` + Notes pgtype.Text `json:"notes"` +} + +func (q *Queries) CreateAgent(ctx context.Context, arg CreateAgentParams) (Agent, error) { + row := q.db.QueryRow(ctx, createAgent, + arg.UserID, + arg.ProvisionedWithKeyID, + arg.Metadata, + arg.Notes, + ) + var i Agent + err := row.Scan( + &i.ID, + &i.UserID, + &i.ProvisionedWithKeyID, + &i.Status, + &i.CertFingerprint, + &i.RegisteredAt, + &i.LastSeenAt, + &i.LastIpAddress, + &i.Metadata, + &i.Notes, + ) + return i, err +} + +const getAgentByID = `-- name: GetAgentByID :one +SELECT id, user_id, provisioned_with_key_id, status, cert_fingerprint, registered_at, last_seen_at, last_ip_address, metadata, notes FROM agents WHERE id = $1 LIMIT 1 +` + +func (q *Queries) GetAgentByID(ctx context.Context, id pgtype.UUID) (Agent, error) { + row := q.db.QueryRow(ctx, getAgentByID, id) + var i Agent + err := row.Scan( + &i.ID, + &i.UserID, + &i.ProvisionedWithKeyID, + &i.Status, + &i.CertFingerprint, + &i.RegisteredAt, + &i.LastSeenAt, + &i.LastIpAddress, + &i.Metadata, + &i.Notes, + ) + return i, err +} + +const listAgentsByUser = `-- name: ListAgentsByUser :many +SELECT id, user_id, provisioned_with_key_id, status, cert_fingerprint, registered_at, last_seen_at, last_ip_address, metadata, notes FROM agents WHERE user_id = $1 ORDER BY registered_at DESC +` + +func (q *Queries) ListAgentsByUser(ctx context.Context, userID pgtype.UUID) ([]Agent, error) { + rows, err := q.db.Query(ctx, listAgentsByUser, userID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Agent{} + for rows.Next() { + var i Agent + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.ProvisionedWithKeyID, + &i.Status, + &i.CertFingerprint, + &i.RegisteredAt, + &i.LastSeenAt, + &i.LastIpAddress, + &i.Metadata, + &i.Notes, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateAgentCertFingerprint = `-- name: UpdateAgentCertFingerprint :exec +UPDATE agents SET cert_fingerprint = $2 WHERE id = $1 +` + +type UpdateAgentCertFingerprintParams struct { + ID pgtype.UUID `json:"id"` + CertFingerprint pgtype.Text `json:"cert_fingerprint"` +} + +func (q *Queries) UpdateAgentCertFingerprint(ctx context.Context, arg UpdateAgentCertFingerprintParams) error { + _, err := q.db.Exec(ctx, updateAgentCertFingerprint, arg.ID, arg.CertFingerprint) + return err +} + +const updateAgentLastSeen = `-- name: UpdateAgentLastSeen :exec +UPDATE agents SET last_seen_at = $2, last_ip_address = $3 WHERE id = $1 +` + +type UpdateAgentLastSeenParams struct { + ID pgtype.UUID `json:"id"` + LastSeenAt pgtype.Timestamp `json:"last_seen_at"` + LastIpAddress *netip.Addr `json:"last_ip_address"` +} + +func (q *Queries) UpdateAgentLastSeen(ctx context.Context, arg UpdateAgentLastSeenParams) error { + _, err := q.db.Exec(ctx, updateAgentLastSeen, arg.ID, arg.LastSeenAt, arg.LastIpAddress) + return err 
+} + +const updateAgentStatus = `-- name: UpdateAgentStatus :exec +UPDATE agents SET status = $2 WHERE id = $1 +` + +type UpdateAgentStatusParams struct { + ID pgtype.UUID `json:"id"` + Status AgentStatus `json:"status"` +} + +func (q *Queries) UpdateAgentStatus(ctx context.Context, arg UpdateAgentStatusParams) error { + _, err := q.db.Exec(ctx, updateAgentStatus, arg.ID, arg.Status) + return err +} diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 90880db..2879e4a 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -7,10 +7,98 @@ package sqlc import ( "database/sql/driver" "fmt" + "net/netip" "github.com/jackc/pgx/v5/pgtype" ) +type AgentStatus string + +const ( + AgentStatusActive AgentStatus = "active" + AgentStatusInactive AgentStatus = "inactive" + AgentStatusSuspended AgentStatus = "suspended" +) + +func (e *AgentStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AgentStatus(s) + case string: + *e = AgentStatus(s) + default: + return fmt.Errorf("unsupported scan type for AgentStatus: %T", src) + } + return nil +} + +type NullAgentStatus struct { + AgentStatus AgentStatus `json:"agent_status"` + Valid bool `json:"valid"` // Valid is true if AgentStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAgentStatus) Scan(value interface{}) error { + if value == nil { + ns.AgentStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AgentStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullAgentStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AgentStatus), nil +} + +type ProvisioningKeyStatus string + +const ( + ProvisioningKeyStatusActive ProvisioningKeyStatus = "active" + ProvisioningKeyStatusExhausted ProvisioningKeyStatus = "exhausted" + ProvisioningKeyStatusExpired ProvisioningKeyStatus = "expired" + ProvisioningKeyStatusRevoked ProvisioningKeyStatus = "revoked" +) + +func (e *ProvisioningKeyStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ProvisioningKeyStatus(s) + case string: + *e = ProvisioningKeyStatus(s) + default: + return fmt.Errorf("unsupported scan type for ProvisioningKeyStatus: %T", src) + } + return nil +} + +type NullProvisioningKeyStatus struct { + ProvisioningKeyStatus ProvisioningKeyStatus `json:"provisioning_key_status"` + Valid bool `json:"valid"` // Valid is true if ProvisioningKeyStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullProvisioningKeyStatus) Scan(value interface{}) error { + if value == nil { + ns.ProvisioningKeyStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.ProvisioningKeyStatus.Scan(value) +} + +// Value implements the driver Valuer interface. 
+func (ns NullProvisioningKeyStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.ProvisioningKeyStatus), nil +} + type UserRole string const ( @@ -53,6 +141,43 @@ func (ns NullUserRole) Value() (driver.Value, error) { return string(ns.UserRole), nil } +type Agent struct { + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` + ProvisionedWithKeyID pgtype.UUID `json:"provisioned_with_key_id"` + Status AgentStatus `json:"status"` + CertFingerprint pgtype.Text `json:"cert_fingerprint"` + RegisteredAt pgtype.Timestamp `json:"registered_at"` + LastSeenAt pgtype.Timestamp `json:"last_seen_at"` + LastIpAddress *netip.Addr `json:"last_ip_address"` + Metadata []byte `json:"metadata"` + Notes pgtype.Text `json:"notes"` +} + +type AgentConnectionLog struct { + ID pgtype.UUID `json:"id"` + AgentID pgtype.UUID `json:"agent_id"` + ConnectedAt pgtype.Timestamp `json:"connected_at"` + DisconnectedAt pgtype.Timestamp `json:"disconnected_at"` + DurationSeconds pgtype.Int4 `json:"duration_seconds"` + IpAddress *netip.Addr `json:"ip_address"` + DisconnectReason pgtype.Text `json:"disconnect_reason"` +} + +type ProvisioningKey struct { + ID pgtype.UUID `json:"id"` + KeyHash string `json:"key_hash"` + UserID pgtype.UUID `json:"user_id"` + Status ProvisioningKeyStatus `json:"status"` + MaxUses int32 `json:"max_uses"` + UsedCount int32 `json:"used_count"` + ExpiresAt pgtype.Timestamp `json:"expires_at"` + CreatedAt pgtype.Timestamp `json:"created_at"` + UpdatedAt pgtype.Timestamp `json:"updated_at"` + RevokedAt pgtype.Timestamp `json:"revoked_at"` + Notes pgtype.Text `json:"notes"` +} + type User struct { ID pgtype.UUID `json:"id"` Username string `json:"username"` diff --git a/internal/db/sqlc/provisioning_keys.sql.go b/internal/db/sqlc/provisioning_keys.sql.go new file mode 100644 index 0000000..63380a7 --- /dev/null +++ b/internal/db/sqlc/provisioning_keys.sql.go @@ -0,0 +1,152 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.29.0 +// source: provisioning_keys.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createProvisioningKey = `-- name: CreateProvisioningKey :one +INSERT INTO provisioning_keys (key_hash, user_id, max_uses, expires_at, notes) +VALUES ($1, $2, $3, $4, $5) RETURNING id, key_hash, user_id, status, max_uses, used_count, expires_at, created_at, updated_at, revoked_at, notes +` + +type CreateProvisioningKeyParams struct { + KeyHash string `json:"key_hash"` + UserID pgtype.UUID `json:"user_id"` + MaxUses int32 `json:"max_uses"` + ExpiresAt pgtype.Timestamp `json:"expires_at"` + Notes pgtype.Text `json:"notes"` +} + +func (q *Queries) CreateProvisioningKey(ctx context.Context, arg CreateProvisioningKeyParams) (ProvisioningKey, error) { + row := q.db.QueryRow(ctx, createProvisioningKey, + arg.KeyHash, + arg.UserID, + arg.MaxUses, + arg.ExpiresAt, + arg.Notes, + ) + var i ProvisioningKey + err := row.Scan( + &i.ID, + &i.KeyHash, + &i.UserID, + &i.Status, + &i.MaxUses, + &i.UsedCount, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.RevokedAt, + &i.Notes, + ) + return i, err +} + +const expireOldKeys = `-- name: ExpireOldKeys :exec +UPDATE provisioning_keys +SET status = 'expired'::provisioning_key_status, updated_at = NOW() +WHERE status = 'active' AND expires_at < NOW() +` + +func (q *Queries) ExpireOldKeys(ctx context.Context) error { + _, err := q.db.Exec(ctx, expireOldKeys) + return err +} + +const getProvisioningKeyByHash = `-- name: GetProvisioningKeyByHash :one +SELECT id, key_hash, user_id, status, max_uses, used_count, expires_at, created_at, updated_at, revoked_at, notes FROM provisioning_keys +WHERE key_hash = $1 AND status = 'active' LIMIT 1 +` + +func (q *Queries) GetProvisioningKeyByHash(ctx context.Context, keyHash string) (ProvisioningKey, error) { + row := q.db.QueryRow(ctx, getProvisioningKeyByHash, keyHash) + var i ProvisioningKey + err := row.Scan( + &i.ID, + &i.KeyHash, + &i.UserID, + &i.Status, + &i.MaxUses, + &i.UsedCount, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.RevokedAt, + &i.Notes, + ) + return i, err +} + +const incrementKeyUsage = `-- name: IncrementKeyUsage :exec +UPDATE provisioning_keys +SET used_count = used_count + 1, + status = CASE WHEN used_count + 1 >= max_uses + THEN 'exhausted'::provisioning_key_status + ELSE status END, + updated_at = NOW() +WHERE id = $1 +` + +func (q *Queries) IncrementKeyUsage(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, incrementKeyUsage, id) + return err +} + +const listProvisioningKeysByUser = `-- name: ListProvisioningKeysByUser :many +SELECT id, key_hash, user_id, status, max_uses, used_count, expires_at, created_at, updated_at, revoked_at, notes FROM provisioning_keys WHERE user_id = $1 ORDER BY created_at DESC +` + +func (q *Queries) ListProvisioningKeysByUser(ctx context.Context, userID pgtype.UUID) ([]ProvisioningKey, error) { + rows, err := q.db.Query(ctx, listProvisioningKeysByUser, userID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []ProvisioningKey{} + for rows.Next() { + var i ProvisioningKey + if err := rows.Scan( + &i.ID, + &i.KeyHash, + &i.UserID, + &i.Status, + &i.MaxUses, + &i.UsedCount, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.RevokedAt, + &i.Notes, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const revokeProvisioningKey = `-- name: RevokeProvisioningKey :exec 
+UPDATE provisioning_keys +SET status = 'revoked'::provisioning_key_status, revoked_at = NOW(), updated_at = NOW() +WHERE id = $1 AND user_id = $2 +` + +type RevokeProvisioningKeyParams struct { + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` +} + +func (q *Queries) RevokeProvisioningKey(ctx context.Context, arg RevokeProvisioningKeyParams) error { + _, err := q.db.Exec(ctx, revokeProvisioningKey, arg.ID, arg.UserID) + return err +} diff --git a/internal/db/sqlc/querier.go b/internal/db/sqlc/querier.go index 7595c40..b1089cd 100644 --- a/internal/db/sqlc/querier.go +++ b/internal/db/sqlc/querier.go @@ -12,11 +12,26 @@ import ( type Querier interface { CountUsers(ctx context.Context) (int64, error) + CreateAgent(ctx context.Context, arg CreateAgentParams) (Agent, error) + CreateConnectionLog(ctx context.Context, arg CreateConnectionLogParams) (AgentConnectionLog, error) + CreateProvisioningKey(ctx context.Context, arg CreateProvisioningKeyParams) (ProvisioningKey, error) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) DeleteUser(ctx context.Context, id pgtype.UUID) error + ExpireOldKeys(ctx context.Context) error + GetAgentByID(ctx context.Context, id pgtype.UUID) (Agent, error) + GetAgentConnectionHistory(ctx context.Context, arg GetAgentConnectionHistoryParams) ([]AgentConnectionLog, error) + GetProvisioningKeyByHash(ctx context.Context, keyHash string) (ProvisioningKey, error) GetUser(ctx context.Context, id pgtype.UUID) (User, error) GetUserByUsername(ctx context.Context, username string) (User, error) + IncrementKeyUsage(ctx context.Context, id pgtype.UUID) error + ListAgentsByUser(ctx context.Context, userID pgtype.UUID) ([]Agent, error) + ListProvisioningKeysByUser(ctx context.Context, userID pgtype.UUID) ([]ProvisioningKey, error) ListUsersPaginated(ctx context.Context, arg ListUsersPaginatedParams) ([]User, error) + RevokeProvisioningKey(ctx context.Context, arg RevokeProvisioningKeyParams) error + UpdateAgentCertFingerprint(ctx context.Context, arg UpdateAgentCertFingerprintParams) error + UpdateAgentLastSeen(ctx context.Context, arg UpdateAgentLastSeenParams) error + UpdateAgentStatus(ctx context.Context, arg UpdateAgentStatusParams) error + UpdateConnectionLog(ctx context.Context, arg UpdateConnectionLogParams) error } var _ Querier = (*Queries)(nil) From 92bc88438803c92d51290be70294bc676a707f03 Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Mon, 9 Feb 2026 12:53:54 +0800 Subject: [PATCH 3/5] implement provisioning logic and gRPC integration (Phase 2) - Add provisioning service with key generation (crypto/rand) and SHA-256 hashing - Add agent service for lifecycle management and connection logging - Integrate provisioning handshake into gRPC stream handler - Support legacy agent auto-migration with default user - Add database persistence for agent last_seen updates - Update connection manager to accept agent service for DB persistence - Wire up services in server initialization - Update tests to reflect new ConnectionManager signature - Update documentation with Phase 2 completion status --- cmd/silo-proxy-server/main.go | 36 +- docs/provisioning/Overview.md | 30 +- internal/agents/models.go | 28 ++ internal/agents/service.go | 312 ++++++++++++++++++ internal/grpc/server/connection_manager.go | 29 +- .../grpc/server/connection_manager_test.go | 38 +-- internal/grpc/server/server.go | 17 +- internal/grpc/server/stream_handler.go | 158 ++++++++- internal/provisioning/models.go | 24 ++ internal/provisioning/service.go | 265 +++++++++++++++ 10 
files changed, 879 insertions(+), 58 deletions(-) create mode 100644 internal/agents/models.go create mode 100644 internal/agents/service.go create mode 100644 internal/provisioning/models.go create mode 100644 internal/provisioning/service.go diff --git a/cmd/silo-proxy-server/main.go b/cmd/silo-proxy-server/main.go index 45b2caa..af10798 100644 --- a/cmd/silo-proxy-server/main.go +++ b/cmd/silo-proxy-server/main.go @@ -12,11 +12,13 @@ import ( "time" internalhttp "github.com/EternisAI/silo-proxy/internal/api/http" + "github.com/EternisAI/silo-proxy/internal/agents" "github.com/EternisAI/silo-proxy/internal/auth" "github.com/EternisAI/silo-proxy/internal/cert" "github.com/EternisAI/silo-proxy/internal/db" "github.com/EternisAI/silo-proxy/internal/db/sqlc" grpcserver "github.com/EternisAI/silo-proxy/internal/grpc/server" + "github.com/EternisAI/silo-proxy/internal/provisioning" "github.com/EternisAI/silo-proxy/internal/users" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" @@ -46,17 +48,9 @@ func main() { authService := auth.NewService(queries, config.JWT) userService := users.NewService(queries) - tlsConfig := &grpcserver.TLSConfig{ - Enabled: config.Grpc.TLS.Enabled, - CertFile: config.Grpc.TLS.CertFile, - KeyFile: config.Grpc.TLS.KeyFile, - CAFile: config.Grpc.TLS.CAFile, - ClientAuth: config.Grpc.TLS.ClientAuth, - } - + // Initialize provisioning and agent services var certService *cert.Service if config.Grpc.TLS.Enabled { - var err error certService, err = cert.New( config.Grpc.TLS.CAFile, @@ -73,7 +67,31 @@ func main() { } } + provisioningService := provisioning.NewService(queries, certService) + agentService := agents.NewService(queries) + + // Get or create default user for legacy agent migration + defaultUser, err := queries.GetUserByUsername(context.Background(), "admin") + if err != nil { + slog.Warn("Default user 'admin' not found, legacy agent migration may fail") + } + defaultUserID := "" + if defaultUser.ID.Valid { + defaultUserID = fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + defaultUser.ID.Bytes[0:4], defaultUser.ID.Bytes[4:6], defaultUser.ID.Bytes[6:8], + defaultUser.ID.Bytes[8:10], defaultUser.ID.Bytes[10:16]) + } + + tlsConfig := &grpcserver.TLSConfig{ + Enabled: config.Grpc.TLS.Enabled, + CertFile: config.Grpc.TLS.CertFile, + KeyFile: config.Grpc.TLS.KeyFile, + CAFile: config.Grpc.TLS.CAFile, + ClientAuth: config.Grpc.TLS.ClientAuth, + } + grpcSrv := grpcserver.NewServer(config.Grpc.Port, tlsConfig) + grpcSrv.SetServices(provisioningService, agentService, defaultUserID) portManager, err := internalhttp.NewPortManager( config.Http.AgentPortRange.Start, diff --git a/docs/provisioning/Overview.md b/docs/provisioning/Overview.md index 37fcfcd..b1664c9 100644 --- a/docs/provisioning/Overview.md +++ b/docs/provisioning/Overview.md @@ -115,22 +115,24 @@ The provisioning system works as follows: - `internal/db/queries/agent_connection_logs.sql` - `internal/db/sqlc/*.go` (generated) -### Phase 2: Core Provisioning Logic 🚧 IN PROGRESS +### Phase 2: Core Provisioning Logic ✅ COMPLETED **Deliverables:** -- Service layer (provisioning, agents) -- gRPC stream handler integration -- ConnectionManager database persistence -- E2E test: agent provisions via gRPC stream - -**Files to Create/Modify:** -- `internal/provisioning/service.go` (NEW) -- `internal/provisioning/models.go` (NEW) -- `internal/agents/service.go` (NEW) -- `internal/agents/models.go` (NEW) -- `internal/grpc/server/stream_handler.go` (MODIFY) -- `internal/grpc/server/connection_manager.go` (MODIFY) -- 
`cmd/silo-proxy-server/main.go` (MODIFY) +- ✅ Service layer (provisioning, agents) +- ✅ gRPC stream handler integration +- ✅ ConnectionManager database persistence +- ⏸️ E2E test: agent provisions via gRPC stream (will be done in Phase 3) + +**Files Created/Modified:** +- `internal/provisioning/service.go` (NEW) - Key generation, validation, agent provisioning +- `internal/provisioning/models.go` (NEW) - Domain models +- `internal/agents/service.go` (NEW) - Agent management, connection logging +- `internal/agents/models.go` (NEW) - Domain models +- `internal/grpc/server/stream_handler.go` (MODIFIED) - Provisioning handshake +- `internal/grpc/server/connection_manager.go` (MODIFIED) - DB persistence +- `internal/grpc/server/server.go` (MODIFIED) - Service initialization +- `internal/grpc/server/connection_manager_test.go` (MODIFIED) - Test updates +- `cmd/silo-proxy-server/main.go` (MODIFIED) - Service wiring ### Phase 3: API & Client Integration ⏸️ PLANNED diff --git a/internal/agents/models.go b/internal/agents/models.go new file mode 100644 index 0000000..112a090 --- /dev/null +++ b/internal/agents/models.go @@ -0,0 +1,28 @@ +package agents + +import ( + "time" +) + +type Agent struct { + ID string + UserID string + ProvisionedWithKeyID string + Status string + CertFingerprint string + RegisteredAt time.Time + LastSeenAt time.Time + LastIPAddress string + Metadata map[string]interface{} + Notes string +} + +type ConnectionLog struct { + ID string + AgentID string + ConnectedAt time.Time + DisconnectedAt *time.Time + DurationSeconds int + IPAddress string + DisconnectReason string +} diff --git a/internal/agents/service.go b/internal/agents/service.go new file mode 100644 index 0000000..e8aab69 --- /dev/null +++ b/internal/agents/service.go @@ -0,0 +1,312 @@ +package agents + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/netip" + "time" + + "github.com/EternisAI/silo-proxy/internal/db/sqlc" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +var ( + ErrAgentNotFound = errors.New("agent not found") + ErrInvalidAgentID = errors.New("invalid agent ID") +) + +type Service struct { + queries *sqlc.Queries +} + +func NewService(queries *sqlc.Queries) *Service { + return &Service{ + queries: queries, + } +} + +// CreateLegacyAgent creates an agent for legacy migration (no provisioning key) +func (s *Service) CreateLegacyAgent(ctx context.Context, legacyID string, userID string) (*Agent, error) { + parsedUserID, err := uuid.Parse(userID) + if err != nil { + return nil, fmt.Errorf("invalid user ID: %w", err) + } + + // Parse legacy ID as UUID - if it's not a valid UUID, generate a new one + var agentUUID uuid.UUID + parsedAgentID, err := uuid.Parse(legacyID) + if err != nil { + // Not a valid UUID, generate a new one + agentUUID = uuid.New() + slog.Info("Legacy agent ID is not a UUID, generating new ID", + "legacy_id", legacyID, + "new_agent_id", agentUUID.String()) + } else { + agentUUID = parsedAgentID + } + + metadata := map[string]interface{}{ + "legacy": true, + "original_id": legacyID, + "migrated_at": time.Now().Format(time.RFC3339), + } + metadataJSON, _ := json.Marshal(metadata) + + dbAgent, err := s.queries.CreateAgent(ctx, sqlc.CreateAgentParams{ + UserID: pgtype.UUID{Bytes: parsedUserID, Valid: true}, + ProvisionedWithKeyID: pgtype.UUID{Valid: false}, // NULL for legacy agents + Metadata: metadataJSON, + Notes: pgtype.Text{String: "Auto-migrated legacy agent", Valid: true}, + }) + if err != nil { + return nil, 
fmt.Errorf("failed to create legacy agent: %w", err) + } + + result := &Agent{ + ID: uuidToString(dbAgent.ID.Bytes), + UserID: uuidToString(dbAgent.UserID.Bytes), + Status: string(dbAgent.Status), + RegisteredAt: dbAgent.RegisteredAt.Time, + LastSeenAt: dbAgent.LastSeenAt.Time, + } + + slog.Info("Legacy agent auto-migrated", + "legacy_id", legacyID, + "agent_id", result.ID, + "user_id", userID) + + return result, nil +} + +// GetAgentByID retrieves an agent by ID +func (s *Service) GetAgentByID(ctx context.Context, agentID string) (*Agent, error) { + parsedID, err := uuid.Parse(agentID) + if err != nil { + return nil, ErrInvalidAgentID + } + + dbAgent, err := s.queries.GetAgentByID(ctx, pgtype.UUID{Bytes: parsedID, Valid: true}) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrAgentNotFound + } + return nil, fmt.Errorf("failed to get agent: %w", err) + } + + var metadata map[string]interface{} + if len(dbAgent.Metadata) > 0 { + _ = json.Unmarshal(dbAgent.Metadata, &metadata) + } + + result := &Agent{ + ID: uuidToString(dbAgent.ID.Bytes), + UserID: uuidToString(dbAgent.UserID.Bytes), + Status: string(dbAgent.Status), + CertFingerprint: dbAgent.CertFingerprint.String, + RegisteredAt: dbAgent.RegisteredAt.Time, + LastSeenAt: dbAgent.LastSeenAt.Time, + Metadata: metadata, + Notes: dbAgent.Notes.String, + } + + if dbAgent.ProvisionedWithKeyID.Valid { + result.ProvisionedWithKeyID = uuidToString(dbAgent.ProvisionedWithKeyID.Bytes) + } + + if dbAgent.LastIpAddress != nil { + result.LastIPAddress = dbAgent.LastIpAddress.String() + } + + return result, nil +} + +// ListAgentsByUser retrieves all agents for a user +func (s *Service) ListAgentsByUser(ctx context.Context, userID string) ([]Agent, error) { + parsedUserID, err := uuid.Parse(userID) + if err != nil { + return nil, fmt.Errorf("invalid user ID: %w", err) + } + + dbAgents, err := s.queries.ListAgentsByUser(ctx, pgtype.UUID{Bytes: parsedUserID, Valid: true}) + if err != nil { + return nil, fmt.Errorf("failed to list agents: %w", err) + } + + result := make([]Agent, len(dbAgents)) + for i, a := range dbAgents { + var metadata map[string]interface{} + if len(a.Metadata) > 0 { + _ = json.Unmarshal(a.Metadata, &metadata) + } + + result[i] = Agent{ + ID: uuidToString(a.ID.Bytes), + UserID: uuidToString(a.UserID.Bytes), + Status: string(a.Status), + CertFingerprint: a.CertFingerprint.String, + RegisteredAt: a.RegisteredAt.Time, + LastSeenAt: a.LastSeenAt.Time, + Metadata: metadata, + Notes: a.Notes.String, + } + + if a.ProvisionedWithKeyID.Valid { + result[i].ProvisionedWithKeyID = uuidToString(a.ProvisionedWithKeyID.Bytes) + } + + if a.LastIpAddress != nil { + result[i].LastIPAddress = a.LastIpAddress.String() + } + } + + return result, nil +} + +// UpdateLastSeen updates the agent's last seen timestamp and IP address +func (s *Service) UpdateLastSeen(ctx context.Context, agentID string, timestamp time.Time, ipAddress string) error { + parsedID, err := uuid.Parse(agentID) + if err != nil { + return ErrInvalidAgentID + } + + var ipAddr *netip.Addr + if ipAddress != "" { + parsed, err := netip.ParseAddr(ipAddress) + if err == nil { + ipAddr = &parsed + } + } + + if err := s.queries.UpdateAgentLastSeen(ctx, sqlc.UpdateAgentLastSeenParams{ + ID: pgtype.UUID{Bytes: parsedID, Valid: true}, + LastSeenAt: pgtype.Timestamp{Time: timestamp, Valid: true}, + LastIpAddress: ipAddr, + }); err != nil { + return fmt.Errorf("failed to update last seen: %w", err) + } + + return nil +} + +// UpdateStatus updates the agent's status +func (s 
*Service) UpdateStatus(ctx context.Context, agentID string, status string) error { + parsedID, err := uuid.Parse(agentID) + if err != nil { + return ErrInvalidAgentID + } + + var agentStatus sqlc.AgentStatus + switch status { + case "active": + agentStatus = sqlc.AgentStatusActive + case "inactive": + agentStatus = sqlc.AgentStatusInactive + case "suspended": + agentStatus = sqlc.AgentStatusSuspended + default: + return fmt.Errorf("invalid status: %s", status) + } + + if err := s.queries.UpdateAgentStatus(ctx, sqlc.UpdateAgentStatusParams{ + ID: pgtype.UUID{Bytes: parsedID, Valid: true}, + Status: agentStatus, + }); err != nil { + return fmt.Errorf("failed to update status: %w", err) + } + + slog.Info("Agent status updated", "agent_id", agentID, "status", status) + return nil +} + +// CreateConnectionLog creates a new connection log entry +func (s *Service) CreateConnectionLog(ctx context.Context, agentID string, connectedAt time.Time, ipAddress string) (string, error) { + parsedID, err := uuid.Parse(agentID) + if err != nil { + return "", ErrInvalidAgentID + } + + var ipAddr *netip.Addr + if ipAddress != "" { + parsed, err := netip.ParseAddr(ipAddress) + if err == nil { + ipAddr = &parsed + } + } + + dbLog, err := s.queries.CreateConnectionLog(ctx, sqlc.CreateConnectionLogParams{ + AgentID: pgtype.UUID{Bytes: parsedID, Valid: true}, + ConnectedAt: pgtype.Timestamp{Time: connectedAt, Valid: true}, + IpAddress: ipAddr, + }) + if err != nil { + return "", fmt.Errorf("failed to create connection log: %w", err) + } + + return uuidToString(dbLog.ID.Bytes), nil +} + +// UpdateConnectionLog updates a connection log with disconnect information +func (s *Service) UpdateConnectionLog(ctx context.Context, logID string, disconnectedAt time.Time, reason string) error { + parsedID, err := uuid.Parse(logID) + if err != nil { + return fmt.Errorf("invalid log ID: %w", err) + } + + if err := s.queries.UpdateConnectionLog(ctx, sqlc.UpdateConnectionLogParams{ + ID: pgtype.UUID{Bytes: parsedID, Valid: true}, + DisconnectedAt: pgtype.Timestamp{Time: disconnectedAt, Valid: true}, + DisconnectReason: pgtype.Text{String: reason, Valid: reason != ""}, + }); err != nil { + return fmt.Errorf("failed to update connection log: %w", err) + } + + return nil +} + +// GetAgentConnectionHistory retrieves connection history for an agent +func (s *Service) GetAgentConnectionHistory(ctx context.Context, agentID string, limit, offset int) ([]ConnectionLog, error) { + parsedID, err := uuid.Parse(agentID) + if err != nil { + return nil, ErrInvalidAgentID + } + + dbLogs, err := s.queries.GetAgentConnectionHistory(ctx, sqlc.GetAgentConnectionHistoryParams{ + AgentID: pgtype.UUID{Bytes: parsedID, Valid: true}, + Limit: int32(limit), + Offset: int32(offset), + }) + if err != nil { + return nil, fmt.Errorf("failed to get connection history: %w", err) + } + + result := make([]ConnectionLog, len(dbLogs)) + for i, l := range dbLogs { + result[i] = ConnectionLog{ + ID: uuidToString(l.ID.Bytes), + AgentID: uuidToString(l.AgentID.Bytes), + ConnectedAt: l.ConnectedAt.Time, + DurationSeconds: int(l.DurationSeconds.Int32), + DisconnectReason: l.DisconnectReason.String, + } + + if l.DisconnectedAt.Valid { + result[i].DisconnectedAt = &l.DisconnectedAt.Time + } + + if l.IpAddress != nil { + result[i].IPAddress = l.IpAddress.String() + } + } + + return result, nil +} + +func uuidToString(id [16]byte) string { + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + id[0:4], id[4:6], id[6:8], id[8:10], id[10:16]) +} diff --git 
a/internal/grpc/server/connection_manager.go b/internal/grpc/server/connection_manager.go index 1276b07..aa8e752 100644 --- a/internal/grpc/server/connection_manager.go +++ b/internal/grpc/server/connection_manager.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/EternisAI/silo-proxy/internal/agents" "github.com/EternisAI/silo-proxy/proto" ) @@ -41,16 +42,18 @@ type ConnectionManager struct { mu sync.RWMutex stopCh chan struct{} agentServerManager AgentServerManager // Optional: manages per-agent HTTP servers + agentService *agents.Service // Optional: for database persistence } // NewConnectionManager creates a new ConnectionManager. // The agentServerManager parameter is optional (can be nil) and enables // per-agent HTTP server management when provided. -func NewConnectionManager(agentServerManager AgentServerManager) *ConnectionManager { +func NewConnectionManager(agentServerManager AgentServerManager, agentService *agents.Service) *ConnectionManager { cm := &ConnectionManager{ agents: make(map[string]*AgentConnection), stopCh: make(chan struct{}), agentServerManager: agentServerManager, + agentService: agentService, } go cm.cleanupStaleConnections() return cm @@ -182,11 +185,27 @@ func (cm *ConnectionManager) SendToAgent(agentID string, msg *proto.ProxyMessage func (cm *ConnectionManager) UpdateLastSeen(agentID string) { cm.mu.Lock() - defer cm.mu.Unlock() - - if conn, ok := cm.agents[agentID]; ok { + conn, ok := cm.agents[agentID] + if ok { conn.LastSeen = time.Now() - slog.Debug("Agent last seen updated", "agent_id", agentID) + } + cm.mu.Unlock() + + if !ok { + return + } + + slog.Debug("Agent last seen updated", "agent_id", agentID) + + // Async DB update (non-blocking) + if cm.agentService != nil { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := cm.agentService.UpdateLastSeen(ctx, agentID, conn.LastSeen, ""); err != nil { + slog.Debug("Failed to update last seen in database", "agent_id", agentID, "error", err) + } + }() } } diff --git a/internal/grpc/server/connection_manager_test.go b/internal/grpc/server/connection_manager_test.go index 2d44560..6f52cba 100644 --- a/internal/grpc/server/connection_manager_test.go +++ b/internal/grpc/server/connection_manager_test.go @@ -78,7 +78,7 @@ func (m *MockStream) RecvMsg(msg interface{}) error { } func TestNewConnectionManager(t *testing.T) { - cm := NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) assert.NotNil(t, cm) assert.NotNil(t, cm.agents) assert.NotNil(t, cm.stopCh) @@ -89,7 +89,7 @@ func TestNewConnectionManager(t *testing.T) { func TestNewConnectionManager_WithAgentServerManager(t *testing.T) { mockASM := new(MockAgentServerManager) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) assert.NotNil(t, cm) assert.NotNil(t, cm.agentServerManager) @@ -99,7 +99,7 @@ func TestNewConnectionManager_WithAgentServerManager(t *testing.T) { } func TestConnectionManager_Register_WithoutServerManager(t *testing.T) { - cm := NewConnectionManager(nil) + cm := NewConnectionManager(nil, nil) defer cm.Stop() mockStream := NewMockStream() @@ -118,7 +118,7 @@ func TestConnectionManager_Register_WithServerManager(t *testing.T) { mockASM.On("StopAgentServer", "agent-1").Return(nil) mockASM.On("Shutdown").Return(nil) - cm := NewConnectionManager(mockASM) + cm := NewConnectionManager(mockASM, nil) mockStream := NewMockStream() conn, err := cm.Register("agent-1", mockStream) @@ -138,7 +138,7 @@ func 
@@ -138,7 +138,7 @@ func TestConnectionManager_Register_ServerManagerError(t *testing.T) {
 	mockASM.On("StartAgentServer", "agent-1").Return(0, assert.AnError)
 	mockASM.On("Shutdown").Return(nil)
 
-	cm := NewConnectionManager(mockASM)
+	cm := NewConnectionManager(mockASM, nil)
 
 	mockStream := NewMockStream()
 	conn, err := cm.Register("agent-1", mockStream)
@@ -160,7 +160,7 @@ func TestConnectionManager_Register_ReplaceExisting(t *testing.T) {
 	mockASM.On("StopAgentServer", "agent-1").Return(nil).Once() // for cm.Stop()
 	mockASM.On("Shutdown").Return(nil)
 
-	cm := NewConnectionManager(mockASM)
+	cm := NewConnectionManager(mockASM, nil)
 
 	// Register first connection
 	mockStream1 := NewMockStream()
@@ -184,7 +184,7 @@
 }
 
 func TestConnectionManager_Deregister_WithoutServerManager(t *testing.T) {
-	cm := NewConnectionManager(nil)
+	cm := NewConnectionManager(nil, nil)
 	defer cm.Stop()
 
 	mockStream := NewMockStream()
@@ -204,7 +204,7 @@ func TestConnectionManager_Deregister_WithServerManager(t *testing.T) {
 	mockASM.On("StartAgentServer", "agent-1").Return(8100, nil)
 	mockASM.On("StopAgentServer", "agent-1").Return(nil)
 
-	cm := NewConnectionManager(mockASM)
+	cm := NewConnectionManager(mockASM, nil)
 	defer func() {
 		mockASM.On("Shutdown").Return(nil)
 		cm.Stop()
@@ -225,7 +225,7 @@
 }
 
 func TestConnectionManager_Deregister_NonExistent(t *testing.T) {
-	cm := NewConnectionManager(nil)
+	cm := NewConnectionManager(nil, nil)
 	defer cm.Stop()
 
 	// Should not panic
@@ -240,7 +240,7 @@ func TestConnectionManager_Stop_WithServerManager(t *testing.T) {
 	mockASM.On("StopAgentServer", "agent-2").Return(nil)
 	mockASM.On("Shutdown").Return(nil)
 
-	cm := NewConnectionManager(mockASM)
+	cm := NewConnectionManager(mockASM, nil)
 
 	mockStream1 := NewMockStream()
 	_, err := cm.Register("agent-1", mockStream1)
@@ -260,7 +260,7 @@
 }
 
 func TestConnectionManager_GetConnection(t *testing.T) {
-	cm := NewConnectionManager(nil)
+	cm := NewConnectionManager(nil, nil)
 	defer cm.Stop()
 
 	mockStream := NewMockStream()
@@ -276,7 +276,7 @@
 }
 
 func TestConnectionManager_ListConnections(t *testing.T) {
-	cm := NewConnectionManager(nil)
+	cm := NewConnectionManager(nil, nil)
 	defer cm.Stop()
 
 	// Initially empty
@@ -299,7 +299,7 @@
 }
 
 func TestConnectionManager_UpdateLastSeen(t *testing.T) {
-	cm := NewConnectionManager(nil)
+	cm := NewConnectionManager(nil, nil)
 	defer cm.Stop()
 
 	mockStream := NewMockStream()
@@ -326,7 +326,7 @@ func TestConnectionManager_RemoveStaleConnections_WithServerManager(t *testing.T
 	mockASM.On("StartAgentServer", "agent-1").Return(8100, nil)
 	mockASM.On("StopAgentServer", "agent-1").Return(nil)
 
-	cm := NewConnectionManager(mockASM)
+	cm := NewConnectionManager(mockASM, nil)
 	defer func() {
 		mockASM.On("Shutdown").Return(nil)
 		cm.Stop()
@@ -363,7 +363,7 @@ func TestConnectionManager_ConcurrentRegistration(t *testing.T) {
 	}
 	mockASM.On("Shutdown").Return(nil)
 
-	cm := NewConnectionManager(mockASM)
+	cm := NewConnectionManager(mockASM, nil)
 
 	// Register 10 agents concurrently
 	done := make(chan bool, 10)
@@ -397,7 +397,7 @@ func TestConnectionManager_PortFieldPersistence(t *testing.T) {
 	mockASM.On("StopAgentServer", "agent-1").Return(nil)
 	mockASM.On("Shutdown").Return(nil)
 
-	cm := NewConnectionManager(mockASM)
+	cm := NewConnectionManager(mockASM, nil)
 	mockStream := NewMockStream()
 	_, err := cm.Register("agent-1", mockStream)
@@ -420,7 +420,7 @@
 }
 
 func TestAgentConnection_PortFieldZeroWithoutManager(t *testing.T) {
-	cm := NewConnectionManager(nil)
+	cm := NewConnectionManager(nil, nil)
 	defer cm.Stop()
 
 	mockStream := NewMockStream()
@@ -432,7 +432,7 @@
 }
 
 func BenchmarkConnectionManager_Register(b *testing.B) {
-	cm := NewConnectionManager(nil)
+	cm := NewConnectionManager(nil, nil)
 	defer cm.Stop()
 
 	b.ResetTimer()
@@ -450,7 +450,7 @@ func BenchmarkConnectionManager_RegisterWithServerManager(b *testing.B) {
 	mockASM.On("StopAgentServer", mock.Anything).Return(nil)
 	mockASM.On("Shutdown").Return(nil)
 
-	cm := NewConnectionManager(mockASM)
+	cm := NewConnectionManager(mockASM, nil)
 	defer cm.Stop()
 
 	b.ResetTimer()
diff --git a/internal/grpc/server/server.go b/internal/grpc/server/server.go
index 1cbee04..35fc12d 100644
--- a/internal/grpc/server/server.go
+++ b/internal/grpc/server/server.go
@@ -8,6 +8,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/EternisAI/silo-proxy/internal/agents"
+	"github.com/EternisAI/silo-proxy/internal/provisioning"
 	"github.com/EternisAI/silo-proxy/proto"
 
 	"google.golang.org/grpc"
@@ -39,7 +41,7 @@ type TLSConfig struct {
 }
 
 func NewServer(port int, tlsConfig *TLSConfig) *Server {
-	connManager := NewConnectionManager(nil)
+	connManager := NewConnectionManager(nil, nil)
 
 	s := &Server{
 		connManager: connManager,
@@ -48,9 +50,7 @@ func NewServer(port int, tlsConfig *TLSConfig) *Server {
 		pendingRequests: make(map[string]chan *proto.ProxyMessage),
 	}
 
-	streamHandler := NewStreamHandler(connManager, s)
-	s.streamHandler = streamHandler
-
+	// streamHandler will be initialized later via SetServices
 	return s
 }
@@ -178,3 +178,12 @@ func (s *Server) GetConnectionManager() *ConnectionManager {
 func (s *Server) SetAgentServerManager(asm AgentServerManager) {
 	s.connManager.SetAgentServerManager(asm)
 }
+
+// SetServices initializes the stream handler with provisioning and agent services
+func (s *Server) SetServices(provisioningService *provisioning.Service, agentService *agents.Service, defaultUserID string) {
+	// Update connection manager with agent service for DB persistence
+	s.connManager.agentService = agentService
+
+	// Initialize stream handler with all services
+	s.streamHandler = NewStreamHandler(s.connManager, s, provisioningService, agentService, defaultUserID)
+}
diff --git a/internal/grpc/server/stream_handler.go b/internal/grpc/server/stream_handler.go
index d34af4c..21f0934 100644
--- a/internal/grpc/server/stream_handler.go
+++ b/internal/grpc/server/stream_handler.go
@@ -1,39 +1,126 @@
 package server
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"io"
 	"log/slog"
+	"time"
 
+	"github.com/EternisAI/silo-proxy/internal/agents"
+	"github.com/EternisAI/silo-proxy/internal/provisioning"
 	"github.com/EternisAI/silo-proxy/proto"
 	"github.com/google/uuid"
 )
 
 type StreamHandler struct {
-	connManager *ConnectionManager
-	server      *Server
+	connManager         *ConnectionManager
+	server              *Server
+	provisioningService *provisioning.Service
+	agentService        *agents.Service
+	defaultUserID       string // Default user for legacy agent migration
 }
 
-func NewStreamHandler(connManager *ConnectionManager, server *Server) *StreamHandler {
+func NewStreamHandler(
+	connManager *ConnectionManager,
+	server *Server,
+	provisioningService *provisioning.Service,
+	agentService *agents.Service,
+	defaultUserID string,
+) *StreamHandler {
 	return &StreamHandler{
-		connManager: connManager,
-		server:      server,
+		connManager:         connManager,
+		server:              server,
+		provisioningService: provisioningService,
+		agentService:        agentService,
+		defaultUserID:       defaultUserID,
 	}
 }
 
 func (sh *StreamHandler) HandleStream(stream proto.ProxyService_StreamServer) error {
+	ctx := stream.Context()
+
 	firstMsg, err := stream.Recv()
 	if err != nil {
 		return fmt.Errorf("failed to receive first message: %w", err)
 	}
 
+	// Extract remote IP from context (if available)
+	remoteIP := extractRemoteIP(ctx)
+
+	// Extract provisioning key and agent_id from metadata
+	provisioningKey := firstMsg.Metadata["provisioning_key"]
 	agentID := firstMsg.Metadata["agent_id"]
-	if agentID == "" {
-		return fmt.Errorf("agent_id not found in first message metadata")
+
+	// Connection log ID for tracking
+	var connectionLogID string
+
+	if provisioningKey != "" {
+		// NEW: Provisioning flow
+		slog.Info("Agent provisioning request received", "remote_ip", remoteIP)
+
+		result, err := sh.provisioningService.ProvisionAgent(ctx, provisioningKey, remoteIP)
+		if err != nil {
+			sh.sendProvisioningError(stream, err)
+			return fmt.Errorf("provisioning failed: %w", err)
+		}
+
+		sh.sendProvisioningSuccess(stream, result)
+		agentID = result.AgentID
+
+		slog.Info("Agent provisioned successfully", "agent_id", agentID, "remote_ip", remoteIP)
+
+	} else if agentID != "" {
+		// Established agent: validate against DB
+		agent, err := sh.agentService.GetAgentByID(ctx, agentID)
+		if err != nil {
+			if errors.Is(err, agents.ErrAgentNotFound) {
+				// Legacy agent auto-migration
+				slog.Warn("Legacy agent detected, auto-migrating", "agent_id", agentID)
+
+				agent, err = sh.agentService.CreateLegacyAgent(ctx, agentID, sh.defaultUserID)
+				if err != nil {
+					return fmt.Errorf("failed to create legacy agent: %w", err)
+				}
+
+				slog.Info("Legacy agent auto-migrated",
+					"agent_id", agent.ID,
+					"legacy_id", agentID,
+					"user_id", sh.defaultUserID)
+
+				// Use the new agent ID
+				agentID = agent.ID
+			} else {
+				return fmt.Errorf("failed to get agent: %w", err)
+			}
+		} else {
+			// Validate agent status
+			if agent.Status != "active" {
+				slog.Warn("Agent connection rejected, status not active",
+					"agent_id", agentID,
+					"status", agent.Status)
+				return fmt.Errorf("agent suspended or inactive")
+			}
+		}
+
+		slog.Info("Agent authenticated", "agent_id", agentID, "remote_ip", remoteIP)
+	} else {
+		return fmt.Errorf("either provisioning_key or agent_id required in first message metadata")
+	}
+
+	// Create connection log
+	logID, err := sh.agentService.CreateConnectionLog(ctx, agentID, time.Now(), remoteIP)
+	if err != nil {
+		slog.Error("Failed to create connection log", "agent_id", agentID, "error", err)
+		// Don't fail the connection, just log the error
+	} else {
+		connectionLogID = logID
 	}
 
 	slog.Info("Agent connection established", "agent_id", agentID)
 
+	// Register with ConnectionManager
 	conn, err := sh.connManager.Register(agentID, stream)
 	if err != nil {
 		return fmt.Errorf("failed to register agent: %w", err)
@@ -41,11 +128,30 @@ func (sh *StreamHandler) HandleStream(stream proto.ProxyService_StreamServer) er
 
 	defer func() {
 		sh.connManager.Deregister(agentID)
+
+		// Update connection log with disconnect information
+		if connectionLogID != "" {
+			// The stream context is usually canceled by the time this defer
+			// runs, so use a short-lived background context for the DB write.
+			dbCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+			defer cancel()
+			if err := sh.agentService.UpdateConnectionLog(dbCtx, connectionLogID, time.Now(), "normal disconnect"); err != nil {
+				slog.Error("Failed to update connection log", "log_id", connectionLogID, "error", err)
+			}
+		}
+
 		slog.Info("Agent disconnected", "agent_id", agentID)
 	}()
 
 	sh.connManager.UpdateLastSeen(agentID)
+	// Update agent last seen in database (async)
+	go func() {
+		if err := sh.agentService.UpdateLastSeen(context.Background(), agentID, time.Now(), remoteIP); err != nil {
+			slog.Error("Failed to update agent last seen", "agent_id", agentID, "error", err)
+		}
+	}()
+
 	if err := sh.processMessage(agentID, firstMsg); err != nil {
 		slog.Error("Failed to process first message", "agent_id", agentID, "error", err)
 	}
@@ -143,3 +249,45 @@ func (sh *StreamHandler) processMessage(agentID string, msg *proto.ProxyMessage)
 
 	return nil
 }
+
+func (sh *StreamHandler) sendProvisioningSuccess(stream proto.ProxyService_StreamServer, result *provisioning.AgentProvisionResult) error {
+	msg := &proto.ProxyMessage{
+		Id:   uuid.New().String(),
+		Type: proto.MessageType_PONG,
+		Metadata: map[string]string{
+			"provisioning_status": "success",
+			"agent_id":            result.AgentID,
+		},
+	}
+
+	if result.CertFingerprint != "" {
+		msg.Metadata["cert_fingerprint"] = result.CertFingerprint
+	}
+
+	if err := stream.Send(msg); err != nil {
+		return fmt.Errorf("failed to send provisioning success: %w", err)
+	}
+
+	return nil
+}
+
+func (sh *StreamHandler) sendProvisioningError(stream proto.ProxyService_StreamServer, err error) {
+	msg := &proto.ProxyMessage{
+		Id:   uuid.New().String(),
+		Type: proto.MessageType_PONG,
+		Metadata: map[string]string{
+			"provisioning_status": "failed",
+			"error":               err.Error(),
+		},
+	}
+
+	if sendErr := stream.Send(msg); sendErr != nil {
+		slog.Error("Failed to send provisioning error message", "error", sendErr)
+	}
+}
+
+func extractRemoteIP(ctx context.Context) string {
+	// This is a placeholder - actual implementation depends on gRPC metadata
+	// For now, return empty string
+	return ""
+}
diff --git a/internal/provisioning/models.go b/internal/provisioning/models.go
new file mode 100644
index 0000000..8f9cfe2
--- /dev/null
+++ b/internal/provisioning/models.go
@@ -0,0 +1,24 @@
+package provisioning
+
+import (
+	"time"
+)
+
+type ProvisioningKey struct {
+	ID        string
+	KeyHash   string
+	UserID    string
+	Status    string
+	MaxUses   int
+	UsedCount int
+	ExpiresAt time.Time
+	CreatedAt time.Time
+	UpdatedAt time.Time
+	RevokedAt *time.Time
+	Notes     string
+}
+
+type AgentProvisionResult struct {
+	AgentID         string
+	CertFingerprint string // Optional: TLS certificate fingerprint
+}
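Review aside, not part of the patch: the handshake above answers with an ordinary `PONG` whose metadata carries the outcome, so the agent must branch on `provisioning_status` when it reads its first reply. A minimal sketch of that branch; `handleProvisioningReply` is a hypothetical name, and the real client-side handling lands in `internal/grpc/client/client.go` in the next patch of this series:

```go
package agentclient

import (
	"fmt"

	"github.com/EternisAI/silo-proxy/proto"
)

// handleProvisioningReply interprets the server's first reply during a
// provisioning handshake, per the metadata contract shown above.
func handleProvisioningReply(msg *proto.ProxyMessage) (string, error) {
	switch msg.Metadata["provisioning_status"] {
	case "success":
		// Server-assigned ID; the agent must persist this for reconnects.
		return msg.Metadata["agent_id"], nil
	case "failed":
		return "", fmt.Errorf("provisioning rejected: %s", msg.Metadata["error"])
	default:
		return "", fmt.Errorf("unexpected first reply of type %v", msg.Type)
	}
}
```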
diff --git a/internal/provisioning/service.go b/internal/provisioning/service.go
new file mode 100644
index 0000000..6c9b02e
--- /dev/null
+++ b/internal/provisioning/service.go
@@ -0,0 +1,265 @@
+package provisioning
+
+import (
+	"context"
+	"crypto/rand"
+	"crypto/sha256"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"log/slog"
+	"time"
+
+	"github.com/EternisAI/silo-proxy/internal/cert"
+	"github.com/EternisAI/silo-proxy/internal/db/sqlc"
+	"github.com/google/uuid"
+	"github.com/jackc/pgx/v5"
+	"github.com/jackc/pgx/v5/pgtype"
+)
+
+const (
+	keyPrefix = "pk_"
+	keyLength = 32 // 32 bytes = 256 bits
+)
+
+var (
+	ErrKeyNotFound  = errors.New("provisioning key not found")
+	ErrKeyExpired   = errors.New("provisioning key expired")
+	ErrKeyExhausted = errors.New("provisioning key exhausted")
+	ErrKeyInvalid   = errors.New("provisioning key invalid")
+)
+
+type Service struct {
+	queries     *sqlc.Queries
+	certService *cert.Service
+}
+
+func NewService(queries *sqlc.Queries, certService *cert.Service) *Service {
+	return &Service{
+		queries:     queries,
+		certService: certService,
+	}
+}
+
+// GenerateKey creates a new provisioning key with crypto/rand
+func GenerateKey() (string, error) {
+	bytes := make([]byte, keyLength)
+	if _, err := rand.Read(bytes); err != nil {
+		return "", fmt.Errorf("failed to generate random bytes: %w", err)
+	}
+
+	// Use base64 URL-safe encoding
+	encoded := base64.RawURLEncoding.EncodeToString(bytes)
+	return keyPrefix + encoded, nil
+}
+
+// HashKey computes SHA-256 hash of the key
+func HashKey(key string) string {
+	hash := sha256.Sum256([]byte(key))
+	return fmt.Sprintf("%x", hash)
+}
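Review aside: a runnable sanity check for `GenerateKey` and `HashKey` as defined above, inlined so it compiles standalone. The length arithmetic is the only addition: 32 random bytes encode to 43 base64url characters, plus the 3-character `pk_` prefix, and SHA-256 hex is always 64 characters.

```go
package main

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"strings"
)

const (
	keyPrefix = "pk_"
	keyLength = 32
)

// Copies of the two helpers above, so the check runs without the repo.
func GenerateKey() (string, error) {
	bytes := make([]byte, keyLength)
	if _, err := rand.Read(bytes); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	return keyPrefix + base64.RawURLEncoding.EncodeToString(bytes), nil
}

func HashKey(key string) string {
	hash := sha256.Sum256([]byte(key))
	return fmt.Sprintf("%x", hash)
}

func main() {
	key, err := GenerateKey()
	if err != nil {
		panic(err)
	}
	// 32 bytes -> 43 base64url chars, plus the "pk_" prefix: 46 total.
	fmt.Println(strings.HasPrefix(key, keyPrefix), len(key) == 46)
	// Only the 64-hex-char SHA-256 digest is stored server-side; the
	// plaintext key is shown to the user exactly once.
	fmt.Println(len(HashKey(key)) == 64)
}
```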
+// CreateKey generates and stores a new provisioning key
+func (s *Service) CreateKey(ctx context.Context, userID string, maxUses int, expiresInHours int, notes string) (*ProvisioningKey, string, error) {
+	// Generate key
+	key, err := GenerateKey()
+	if err != nil {
+		return nil, "", fmt.Errorf("failed to generate key: %w", err)
+	}
+
+	// Hash the key for storage
+	keyHash := HashKey(key)
+
+	// Parse user ID
+	parsedUserID, err := uuid.Parse(userID)
+	if err != nil {
+		return nil, "", fmt.Errorf("invalid user ID: %w", err)
+	}
+
+	// Calculate expiration time
+	expiresAt := time.Now().Add(time.Duration(expiresInHours) * time.Hour)
+
+	// Store in database
+	dbKey, err := s.queries.CreateProvisioningKey(ctx, sqlc.CreateProvisioningKeyParams{
+		KeyHash:   keyHash,
+		UserID:    pgtype.UUID{Bytes: parsedUserID, Valid: true},
+		MaxUses:   int32(maxUses),
+		ExpiresAt: pgtype.Timestamp{Time: expiresAt, Valid: true},
+		Notes:     pgtype.Text{String: notes, Valid: notes != ""},
+	})
+	if err != nil {
+		return nil, "", fmt.Errorf("failed to store key: %w", err)
+	}
+
+	result := &ProvisioningKey{
+		ID:        uuidToString(dbKey.ID.Bytes),
+		KeyHash:   dbKey.KeyHash,
+		UserID:    uuidToString(dbKey.UserID.Bytes),
+		Status:    string(dbKey.Status),
+		MaxUses:   int(dbKey.MaxUses),
+		UsedCount: int(dbKey.UsedCount),
+		ExpiresAt: dbKey.ExpiresAt.Time,
+		CreatedAt: dbKey.CreatedAt.Time,
+		UpdatedAt: dbKey.UpdatedAt.Time,
+		Notes:     dbKey.Notes.String,
+	}
+
+	// Return both the model and the plaintext key (only shown once)
+	return result, key, nil
+}
+
+// ProvisionAgent validates a provisioning key and creates a new agent
+func (s *Service) ProvisionAgent(ctx context.Context, key string, remoteIP string) (*AgentProvisionResult, error) {
+	// Hash the provided key
+	keyHash := HashKey(key)
+
+	// Lookup key in database
+	dbKey, err := s.queries.GetProvisioningKeyByHash(ctx, keyHash)
+	if err != nil {
+		if errors.Is(err, pgx.ErrNoRows) {
+			slog.Warn("Provisioning attempt with invalid key", "remote_ip", remoteIP)
+			return nil, ErrKeyNotFound
+		}
+		return nil, fmt.Errorf("failed to lookup key: %w", err)
+	}
+
+	// Validate key status
+	if dbKey.Status != sqlc.ProvisioningKeyStatusActive {
+		slog.Warn("Provisioning attempt with non-active key",
+			"key_id", uuidToString(dbKey.ID.Bytes),
+			"status", dbKey.Status,
+			"remote_ip", remoteIP)
+
+		switch dbKey.Status {
+		case sqlc.ProvisioningKeyStatusExpired:
+			return nil, ErrKeyExpired
+		case sqlc.ProvisioningKeyStatusExhausted:
+			return nil, ErrKeyExhausted
+		default:
+			return nil, ErrKeyInvalid
+		}
+	}
+
+	// Validate expiration
+	if time.Now().After(dbKey.ExpiresAt.Time) {
+		slog.Warn("Provisioning attempt with expired key",
+			"key_id", uuidToString(dbKey.ID.Bytes),
+			"expires_at", dbKey.ExpiresAt.Time,
+			"remote_ip", remoteIP)
+		return nil, ErrKeyExpired
+	}
+
+	// Validate usage count
+	if dbKey.UsedCount >= dbKey.MaxUses {
+		slog.Warn("Provisioning attempt with exhausted key",
+			"key_id", uuidToString(dbKey.ID.Bytes),
+			"used_count", dbKey.UsedCount,
+			"max_uses", dbKey.MaxUses,
+			"remote_ip", remoteIP)
+		return nil, ErrKeyExhausted
+	}
+
+	// Create agent in database (ID is auto-generated)
+	dbAgent, err := s.queries.CreateAgent(ctx, sqlc.CreateAgentParams{
+		UserID:               dbKey.UserID,
+		ProvisionedWithKeyID: pgtype.UUID{Bytes: dbKey.ID.Bytes, Valid: true},
+		Metadata:             []byte("{}"), // Empty JSON object
+		Notes:                pgtype.Text{String: "", Valid: false},
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create agent: %w", err)
+	}
+
+	// Increment key usage count
+	if err := s.queries.IncrementKeyUsage(ctx, dbKey.ID); err != nil {
+		slog.Error("Failed to increment key usage count",
+			"key_id", uuidToString(dbKey.ID.Bytes),
+			"error", err)
+		// Don't fail the provisioning, but log the error
+	}
+
+	result := &AgentProvisionResult{
+		AgentID: uuidToString(dbAgent.ID.Bytes),
+	}
+
+	// TODO: Generate TLS certificate if cert service is available
+	// This will be implemented in a future phase when GenerateAgentCertificate method is added
+
+	slog.Info("Agent provisioned successfully",
+		"agent_id", result.AgentID,
+		"user_id", uuidToString(dbKey.UserID.Bytes),
+		"key_id", uuidToString(dbKey.ID.Bytes),
+		"remote_ip", remoteIP)
+
+	return result, nil
+}
+
+// ListUserKeys returns all provisioning keys for a user
+func (s *Service) ListUserKeys(ctx context.Context, userID string) ([]ProvisioningKey, error) {
+	parsedUserID, err := uuid.Parse(userID)
+	if err != nil {
+		return nil, fmt.Errorf("invalid user ID: %w", err)
+	}
+
+	dbKeys, err := s.queries.ListProvisioningKeysByUser(ctx, pgtype.UUID{Bytes: parsedUserID, Valid: true})
+	if err != nil {
+		return nil, fmt.Errorf("failed to list keys: %w", err)
+	}
+
+	result := make([]ProvisioningKey, len(dbKeys))
+	for i, k := range dbKeys {
+		result[i] = ProvisioningKey{
+			ID:        uuidToString(k.ID.Bytes),
+			KeyHash:   k.KeyHash,
+			UserID:    uuidToString(k.UserID.Bytes),
+			Status:    string(k.Status),
+			MaxUses:   int(k.MaxUses),
+			UsedCount: int(k.UsedCount),
+			ExpiresAt: k.ExpiresAt.Time,
+			CreatedAt: k.CreatedAt.Time,
+			UpdatedAt: k.UpdatedAt.Time,
+			Notes:     k.Notes.String,
+		}
+		if k.RevokedAt.Valid {
+			result[i].RevokedAt = &k.RevokedAt.Time
+		}
+	}
+
+	return result, nil
+}
+
+// RevokeKey revokes a provisioning key
+func (s *Service) RevokeKey(ctx context.Context, keyID string, userID string) error {
+	parsedKeyID, err := uuid.Parse(keyID)
+	if err != nil {
+		return fmt.Errorf("invalid key ID: %w", err)
+	}
+
+	parsedUserID, err := uuid.Parse(userID)
+	if err != nil {
+		return fmt.Errorf("invalid user ID: %w", err)
+	}
+
+	if err := s.queries.RevokeProvisioningKey(ctx, sqlc.RevokeProvisioningKeyParams{
+		ID:     pgtype.UUID{Bytes: parsedKeyID, Valid: true},
+		UserID: pgtype.UUID{Bytes: parsedUserID, Valid: true},
+	}); err != nil {
+		return fmt.Errorf("failed to revoke key: %w", err)
+	}
+
+	slog.Info("Provisioning key revoked", "key_id", keyID, "user_id", userID)
+	return nil
+}
+
+// ExpireOldKeys marks expired keys as expired (cleanup task)
+func (s *Service) ExpireOldKeys(ctx context.Context) error {
+	if err := s.queries.ExpireOldKeys(ctx); err != nil {
+		return fmt.Errorf("failed to expire old keys: %w", err)
+	}
+	return nil
+}
+
+func uuidToString(id [16]byte) string {
+	return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
+		id[0:4], id[4:6], id[6:8], id[8:10], id[10:16])
+}
From d621752e1d84d1106cd449231e7d0d2d51b10f92 Mon Sep 17 00:00:00 2001
From: Jason Lee
Date: Mon, 9 Feb 2026 13:00:12 +0800
Subject: [PATCH 4/5] add HTTP API and agent client provisioning flow (Phase 3)

- Add provisioning key management endpoints (POST/GET/DELETE)
- Add agent management endpoints (GET/DELETE)
- Implement JWT-based authentication for all provisioning endpoints
- Add agent client provisioning handshake logic
- Implement automatic config file persistence after provisioning
- Agent removes provisioning_key and saves agent_id to config
- Add comprehensive manual testing guide in documentation
- Update router to expose new endpoints with authentication
- Wire services into HTTP layer

API Endpoints:
- POST /provisioning-keys - Create provisioning key
- GET /provisioning-keys - List user's keys
- DELETE /provisioning-keys/:id - Revoke key
- GET /agents - List user's agents with connection status
- GET /agents/:id - Get agent details
- DELETE /agents/:id - Deregister agent (soft delete)

Agent Client:
- Accepts provisioning_key in config
- Sends provisioning_key on first connection
- Receives agent_id from server
- Persists agent_id to YAML config (see the sketch after the config diff below)
- Removes provisioning_key from config
- Subsequent connections use agent_id
---
 cmd/silo-proxy-agent/config.go            |   7 +-
 cmd/silo-proxy-agent/main.go              |  42 ++++-
 cmd/silo-proxy-server/main.go             |  10 +-
 docs/provisioning/Overview.md             | 210 +++++++++++++++++++++-
 internal/api/http/dto/provisioning.go     |  44 +++++
 internal/api/http/handler/agents.go       | 175 ++++++++++++++++++
 internal/api/http/handler/provisioning.go | 117 ++++++++++++
 internal/api/http/router.go               |  39 +++-
 internal/grpc/client/client.go            | 149 +++++++++++++--
 9 files changed, 764 insertions(+), 29 deletions(-)
 create mode 100644 internal/api/http/dto/provisioning.go
 create mode 100644 internal/api/http/handler/agents.go
 create mode 100644 internal/api/http/handler/provisioning.go

diff --git a/cmd/silo-proxy-agent/config.go b/cmd/silo-proxy-agent/config.go
index d3713b5..8b578b0 100644
--- a/cmd/silo-proxy-agent/config.go
+++ b/cmd/silo-proxy-agent/config.go
@@ -18,9 +18,10 @@ type Config struct {
 }
 
 type GrpcConfig struct {
-	ServerAddress string    `mapstructure:"server_address"`
-	AgentID       string    `mapstructure:"agent_id"`
-	TLS           TLSConfig `mapstructure:"tls"`
+	ServerAddress   string    `mapstructure:"server_address"`
+	AgentID         string    `mapstructure:"agent_id"`
+	ProvisioningKey string    `mapstructure:"provisioning_key"`
+	TLS             TLSConfig `mapstructure:"tls"`
 }
 
 type TLSConfig struct {
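Review aside: the commit message promises that after a successful handshake the agent persists `agent_id` and drops `provisioning_key` from its YAML config. The actual logic lives in `internal/grpc/client/client.go` (not reproduced in full here); the following is only a guess at its shape, with `persistAgentID` as a hypothetical helper and `gopkg.in/yaml.v3` as an assumed dependency:

```go
package agentclient

import (
	"fmt"
	"os"

	"gopkg.in/yaml.v3"
)

// persistAgentID rewrites the agent config after provisioning: it stores
// the server-assigned agent_id and removes the single-use provisioning_key.
func persistAgentID(configPath, agentID string) error {
	raw, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}

	var cfg map[string]any
	if err := yaml.Unmarshal(raw, &cfg); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}

	grpcSection, _ := cfg["grpc"].(map[string]any)
	if grpcSection == nil {
		grpcSection = map[string]any{}
		cfg["grpc"] = grpcSection
	}
	grpcSection["agent_id"] = agentID       // subsequent connections reuse this
	delete(grpcSection, "provisioning_key") // the credential is single-use

	out, err := yaml.Marshal(cfg)
	if err != nil {
		return fmt.Errorf("marshal config: %w", err)
	}
	// 0600: the file held a credential, so keep it private.
	return os.WriteFile(configPath, out, 0o600)
}
```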
diff --git a/cmd/silo-proxy-agent/main.go b/cmd/silo-proxy-agent/main.go
index d9c2fcd..ced7fc9 100644
--- a/cmd/silo-proxy-agent/main.go
+++ b/cmd/silo-proxy-agent/main.go
@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"os"
 	"os/signal"
+	"path/filepath"
 	"sync"
 	"syscall"
 	"time"
@@ -33,12 +34,51 @@ func main() {
 		ServerNameOverride: config.Grpc.TLS.ServerNameOverride,
 	}
 
-	grpcClient := grpcclient.NewClient(config.Grpc.ServerAddress, config.Grpc.AgentID, config.Local.ServiceURL, tlsConfig)
+	// Determine config path for persistence
+	configPath := ""
+	if config.Grpc.ProvisioningKey != "" {
+		// Try to find config file in common locations
+		possiblePaths := []string{
+			"./application.yaml",
+			"./application.yml",
+			"./cmd/silo-proxy-agent/application.yaml",
+			"./cmd/silo-proxy-agent/application.yml",
+		}
+		for _, path := range possiblePaths {
+			if _, err := os.Stat(path); err == nil {
+				absPath, _ := filepath.Abs(path)
+				configPath = absPath
+				slog.Info("Config file found for persistence", "path", configPath)
+				break
+			}
+		}
+		if configPath == "" {
+			slog.Warn("Config file not found, agent_id will not be persisted")
+		}
+	}
+
+	grpcClient := grpcclient.NewClient(
+		config.Grpc.ServerAddress,
+		config.Grpc.AgentID,
+		config.Grpc.ProvisioningKey,
+		config.Local.ServiceURL,
+		configPath,
+		tlsConfig,
+	)
 
 	if err := grpcClient.Start(); err != nil {
 		slog.Error("Failed to start gRPC client", "error", err)
 		os.Exit(1)
 	}
 
+	if config.Grpc.ProvisioningKey != "" {
+		slog.Info("Agent started in provisioning mode")
+	} else if config.Grpc.AgentID != "" {
+		slog.Info("Agent started with agent_id", "agent_id", config.Grpc.AgentID)
+	} else {
+		slog.Error("Either agent_id or provisioning_key is required")
+		os.Exit(1)
+	}
+
 	gin.SetMode(gin.ReleaseMode)
 	engine := gin.New()
 	engine.Use(cors.New(cors.Config{
diff --git a/cmd/silo-proxy-server/main.go b/cmd/silo-proxy-server/main.go
index af10798..acebf4d 100644
--- a/cmd/silo-proxy-server/main.go
+++ b/cmd/silo-proxy-server/main.go
@@ -111,10 +111,12 @@ func main() {
 		"pool_size", config.Http.AgentPortRange.End-config.Http.AgentPortRange.Start+1)
 
 	services := &internalhttp.Services{
-		GrpcServer:  grpcSrv,
-		CertService: certService,
-		AuthService: authService,
-		UserService: userService,
+		GrpcServer:          grpcSrv,
+		CertService:         certService,
+		AuthService:         authService,
+		UserService:         userService,
+		ProvisioningService: provisioningService,
+		AgentService:        agentService,
 	}
 
 	gin.SetMode(gin.ReleaseMode)
diff --git a/docs/provisioning/Overview.md b/docs/provisioning/Overview.md
index b1664c9..00bad68 100644
--- a/docs/provisioning/Overview.md
+++ b/docs/provisioning/Overview.md
@@ -134,13 +134,24 @@ The provisioning system works as follows:
 - `internal/grpc/server/connection_manager_test.go` (MODIFIED) - Test updates
 - `cmd/silo-proxy-server/main.go` (MODIFIED) - Service wiring
 
-### Phase 3: API & Client Integration ⏸️ PLANNED
+### Phase 3: API & Client Integration ✅ COMPLETED
 
 **Deliverables:**
-- HTTP API endpoints (POST/GET/DELETE keys, GET/DELETE agents)
-- Agent config + client changes
-- E2E test: full provisioning flow via dashboard
-- Documentation updates
+- ✅ HTTP API endpoints (POST/GET/DELETE keys, GET/DELETE agents)
+- ✅ Agent config + client changes
+- ✅ Config persistence after provisioning
+- ⏸️ E2E test: full provisioning flow via dashboard (manual testing guide provided)
+
+**Files Created/Modified:**
+- `internal/api/http/dto/provisioning.go` (NEW) - API DTOs
+- `internal/api/http/handler/provisioning.go` (NEW) - Key management endpoints
+- `internal/api/http/handler/agents.go` (NEW) - Agent management endpoints
+- `internal/api/http/router.go` (MODIFIED) - Route registration
+- `cmd/silo-proxy-server/main.go` (MODIFIED) - Service wiring
+- `cmd/silo-proxy-agent/config.go` (MODIFIED) - Add provisioning_key field
+- `cmd/silo-proxy-agent/main.go` (MODIFIED) - Config persistence logic
+- `internal/grpc/client/client.go` (MODIFIED) - Provisioning handshake
+- `docs/provisioning/Overview.md` (UPDATED) - Phase 3 completion
 
 ## Security
@@ -222,6 +233,195 @@
 psql -d silo-proxy -c "\dt"
 # users
 ```
 
+### Phase 3 Manual Testing Guide
+
+#### Prerequisites
+- Running PostgreSQL database
+- Server built and configured
+- Agent built
+
+#### Step 1: Start the Server
+```bash
+# Make sure database URL is configured
+export DB_URL="postgres://user:password@localhost:5432/silo-proxy?sslmode=disable"
+
+# Start server
+./bin/silo-proxy-server
+```
+
+#### Step 2: Create a User Account
+```bash
+# Register a new user
+curl -X POST http://localhost:8080/auth/register \
+  -H "Content-Type: application/json" \
+  -d '{
+    "username": "testuser",
+    "password": "testpass123",
+    "role": "User"
+  }'
+
+# Login and get JWT token
+TOKEN=$(curl -X POST http://localhost:8080/auth/login \
+  -H "Content-Type: application/json" \
+  -d '{
+    "username": "testuser",
+    "password": "testpass123"
+  }' | jq -r .token)
+
+echo "JWT Token: $TOKEN"
+```
+
+#### Step 3: Generate a Provisioning Key
+```bash
+# Create a single-use provisioning key (expires in 24 hours)
+RESPONSE=$(curl -X POST http://localhost:8080/provisioning-keys \
+  -H "Authorization: Bearer $TOKEN" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "max_uses": 1,
+    "expires_in_hours": 24,
+    "notes": "Test key for agent-1"
+  }')
+
+echo "Provisioning Response: $RESPONSE"
+
+# Extract the provisioning key
+KEY=$(echo $RESPONSE | jq -r .key)
+echo "Provisioning Key: $KEY"
+```
+
+#### Step 4: Configure and Start Agent
+```bash
+# Create agent config with provisioning key
+cat > cmd/silo-proxy-agent/application.yaml <

Date: Mon, 9 Feb 2026 16:06:36 +0800
Subject: [PATCH 5/5] fix provisioning race condition and remove legacy agent migration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Make IncrementKeyUsage atomic with a WHERE used_count < max_uses guard so
concurrent requests cannot exceed max_uses. Reorder ProvisionAgent to
claim a key use before creating the agent.

Remove legacy agent auto-migration since the service has not been
deployed to production — unknown agent_id now returns an error instead
of silently creating agents under the admin user.
---
 cmd/silo-proxy-server/main.go             | 14 +-----
 internal/agents/service.go                | 53 -----------------------
 internal/db/queries/provisioning_keys.sql |  5 ++-
 internal/db/sqlc/provisioning_keys.sql.go | 25 ++++++++---
 internal/db/sqlc/querier.go               |  2 +-
 internal/grpc/server/server.go            |  4 +-
 internal/grpc/server/stream_handler.go    | 37 +++-------------
 internal/provisioning/service.go          | 52 ++++++-----------------
 8 files changed, 48 insertions(+), 144 deletions(-)

diff --git a/cmd/silo-proxy-server/main.go b/cmd/silo-proxy-server/main.go
index acebf4d..584b3e9 100644
--- a/cmd/silo-proxy-server/main.go
+++ b/cmd/silo-proxy-server/main.go
@@ -70,18 +70,6 @@ func main() {
 	provisioningService := provisioning.NewService(queries, certService)
 	agentService := agents.NewService(queries)
 
-	// Get or create default user for legacy agent migration
-	defaultUser, err := queries.GetUserByUsername(context.Background(), "admin")
-	if err != nil {
-		slog.Warn("Default user 'admin' not found, legacy agent migration may fail")
-	}
-	defaultUserID := ""
-	if defaultUser.ID.Valid {
-		defaultUserID = fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
-			defaultUser.ID.Bytes[0:4], defaultUser.ID.Bytes[4:6], defaultUser.ID.Bytes[6:8],
-			defaultUser.ID.Bytes[8:10], defaultUser.ID.Bytes[10:16])
-	}
-
 	tlsConfig := &grpcserver.TLSConfig{
 		Enabled:            config.Grpc.TLS.Enabled,
 		CertFile:           config.Grpc.TLS.CertFile,
@@ -91,7 +79,7 @@ func main() {
 	}
 
 	grpcSrv := grpcserver.NewServer(config.Grpc.Port, tlsConfig)
-	grpcSrv.SetServices(provisioningService, agentService, defaultUserID)
+	grpcSrv.SetServices(provisioningService, agentService)
 
 	portManager, err := internalhttp.NewPortManager(
 		config.Http.AgentPortRange.Start,
diff --git a/internal/agents/service.go b/internal/agents/service.go
index e8aab69..102d8b8 100644
--- a/internal/agents/service.go
+++ b/internal/agents/service.go
@@ -30,59 +30,6 @@ func NewService(queries *sqlc.Queries) *Service {
 	}
 }
 
-// CreateLegacyAgent creates an agent for legacy migration (no provisioning key)
-func (s *Service) CreateLegacyAgent(ctx context.Context, legacyID string, userID string) (*Agent, error) {
-	parsedUserID, err := uuid.Parse(userID)
-	if err != nil {
-		return nil, fmt.Errorf("invalid user ID: %w", err)
-	}
-
-	// Parse legacy ID as UUID - if it's not a valid UUID, generate a new one
-	var agentUUID uuid.UUID
-	parsedAgentID, err := uuid.Parse(legacyID)
-	if err != nil {
-		// Not a valid UUID, generate a new one
-		agentUUID = uuid.New()
-		slog.Info("Legacy agent ID is not a UUID, generating new ID",
-			"legacy_id", legacyID,
-			"new_agent_id", agentUUID.String())
-	} else {
-		agentUUID = parsedAgentID
-	}
-
-	metadata := map[string]interface{}{
-		"legacy":      true,
-		"original_id": legacyID,
-		"migrated_at": time.Now().Format(time.RFC3339),
-	}
-	metadataJSON, _ := json.Marshal(metadata)
-
-	dbAgent, err := s.queries.CreateAgent(ctx, sqlc.CreateAgentParams{
-		UserID:               pgtype.UUID{Bytes: parsedUserID, Valid: true},
-		ProvisionedWithKeyID: pgtype.UUID{Valid: false}, // NULL for legacy agents
-		Metadata:             metadataJSON,
-		Notes:                pgtype.Text{String: "Auto-migrated legacy agent", Valid: true},
-	})
-	if err != nil {
-		return nil, fmt.Errorf("failed to create legacy agent: %w", err)
-	}
-
-	result := &Agent{
-		ID:           uuidToString(dbAgent.ID.Bytes),
-		UserID:       uuidToString(dbAgent.UserID.Bytes),
-		Status:       string(dbAgent.Status),
-		RegisteredAt: dbAgent.RegisteredAt.Time,
-		LastSeenAt:   dbAgent.LastSeenAt.Time,
-	}
-
-	slog.Info("Legacy agent auto-migrated",
-		"legacy_id", legacyID,
-		"agent_id", result.ID,
-		"user_id", userID)
-
-	return result, nil
-}
-
 // GetAgentByID retrieves an agent by ID
 func (s *Service) GetAgentByID(ctx context.Context, agentID string) (*Agent, error) {
 	parsedID, err := uuid.Parse(agentID)
diff --git a/internal/db/queries/provisioning_keys.sql b/internal/db/queries/provisioning_keys.sql
index 00a8118..60772ee 100644
--- a/internal/db/queries/provisioning_keys.sql
+++ b/internal/db/queries/provisioning_keys.sql
@@ -9,14 +9,15 @@ WHERE key_hash = $1 AND status = 'active' LIMIT 1;
 
 -- name: ListProvisioningKeysByUser :many
 SELECT * FROM provisioning_keys
 WHERE user_id = $1
 ORDER BY created_at DESC;
 
--- name: IncrementKeyUsage :exec
+-- name: IncrementKeyUsage :one
 UPDATE provisioning_keys
 SET used_count = used_count + 1,
     status = CASE WHEN used_count + 1 >= max_uses THEN 'exhausted'::provisioning_key_status ELSE status END,
     updated_at = NOW()
-WHERE id = $1;
+WHERE id = $1 AND used_count < max_uses AND status = 'active'
+RETURNING *;
 
 -- name: RevokeProvisioningKey :exec
 UPDATE provisioning_keys
diff --git a/internal/db/sqlc/provisioning_keys.sql.go b/internal/db/sqlc/provisioning_keys.sql.go
index 63380a7..4222f9c 100644
--- a/internal/db/sqlc/provisioning_keys.sql.go
+++ b/internal/db/sqlc/provisioning_keys.sql.go
@@ -84,19 +84,34 @@ func (q *Queries) GetProvisioningKeyByHash(ctx context.Context, keyHash string)
 	return i, err
 }
 
-const incrementKeyUsage = `-- name: IncrementKeyUsage :exec
+const incrementKeyUsage = `-- name: IncrementKeyUsage :one
 UPDATE provisioning_keys
 SET used_count = used_count + 1,
     status = CASE WHEN used_count + 1 >= max_uses THEN 'exhausted'::provisioning_key_status ELSE status END,
     updated_at = NOW()
-WHERE id = $1
+WHERE id = $1 AND used_count < max_uses AND status = 'active'
+RETURNING id, key_hash, user_id, status, max_uses, used_count, expires_at, created_at, updated_at, revoked_at, notes
 `
 
-func (q *Queries) IncrementKeyUsage(ctx context.Context, id pgtype.UUID) error {
-	_, err := q.db.Exec(ctx, incrementKeyUsage, id)
-	return err
+func (q *Queries) IncrementKeyUsage(ctx context.Context, id pgtype.UUID) (ProvisioningKey, error) {
+	row := q.db.QueryRow(ctx, incrementKeyUsage, id)
+	var i ProvisioningKey
+	err := row.Scan(
+		&i.ID,
+		&i.KeyHash,
+		&i.UserID,
+		&i.Status,
+		&i.MaxUses,
+		&i.UsedCount,
+		&i.ExpiresAt,
+		&i.CreatedAt,
+		&i.UpdatedAt,
+		&i.RevokedAt,
+		&i.Notes,
+	)
+	return i, err
 }
 
 const listProvisioningKeysByUser = `-- name: ListProvisioningKeysByUser :many
diff --git a/internal/db/sqlc/querier.go b/internal/db/sqlc/querier.go
index b1089cd..2ee812a 100644
--- a/internal/db/sqlc/querier.go
+++ b/internal/db/sqlc/querier.go
@@ -23,7 +23,7 @@ type Querier interface {
 	GetProvisioningKeyByHash(ctx context.Context, keyHash string) (ProvisioningKey, error)
 	GetUser(ctx context.Context, id pgtype.UUID) (User, error)
 	GetUserByUsername(ctx context.Context, username string) (User, error)
-	IncrementKeyUsage(ctx context.Context, id pgtype.UUID) error
+	IncrementKeyUsage(ctx context.Context, id pgtype.UUID) (ProvisioningKey, error)
 	ListAgentsByUser(ctx context.Context, userID pgtype.UUID) ([]Agent, error)
 	ListProvisioningKeysByUser(ctx context.Context, userID pgtype.UUID) ([]ProvisioningKey, error)
 	ListUsersPaginated(ctx context.Context, arg ListUsersPaginatedParams) ([]User, error)
diff --git a/internal/grpc/server/server.go b/internal/grpc/server/server.go
index 35fc12d..78be688 100644
--- a/internal/grpc/server/server.go
+++ b/internal/grpc/server/server.go
@@ -180,10 +180,10 @@ func (s *Server) SetAgentServerManager(asm AgentServerManager) {
 }
 
 // SetServices initializes the stream handler with provisioning and agent services
-func (s *Server) SetServices(provisioningService *provisioning.Service, agentService *agents.Service, defaultUserID string) {
+func (s *Server) SetServices(provisioningService *provisioning.Service, agentService *agents.Service) {
 	// Update connection manager with agent service for DB persistence
 	s.connManager.agentService = agentService
 
 	// Initialize stream handler with all services
-	s.streamHandler = NewStreamHandler(s.connManager, s, provisioningService, agentService, defaultUserID)
+	s.streamHandler = NewStreamHandler(s.connManager, s, provisioningService, agentService)
 }
diff --git a/internal/grpc/server/stream_handler.go b/internal/grpc/server/stream_handler.go
index 21f0934..975a119 100644
--- a/internal/grpc/server/stream_handler.go
+++ b/internal/grpc/server/stream_handler.go
@@ -2,7 +2,6 @@ package server
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"io"
 	"log/slog"
@@ -19,7 +18,6 @@ type StreamHandler struct {
 	server              *Server
 	provisioningService *provisioning.Service
 	agentService        *agents.Service
-	defaultUserID       string // Default user for legacy agent migration
 }
 
 func NewStreamHandler(
@@ -27,14 +25,12 @@ func NewStreamHandler(
 	server *Server,
 	provisioningService *provisioning.Service,
 	agentService *agents.Service,
-	defaultUserID string,
 ) *StreamHandler {
 	return &StreamHandler{
 		connManager:         connManager,
 		server:              server,
 		provisioningService: provisioningService,
 		agentService:        agentService,
-		defaultUserID:       defaultUserID,
 	}
 }
 
@@ -75,33 +71,14 @@ func (sh *StreamHandler) HandleStream(stream proto.ProxyService_StreamServer) er
 		// Established agent: validate against DB
 		agent, err := sh.agentService.GetAgentByID(ctx, agentID)
 		if err != nil {
-			if errors.Is(err, agents.ErrAgentNotFound) {
-				// Legacy agent auto-migration
-				slog.Warn("Legacy agent detected, auto-migrating", "agent_id", agentID)
-
-				agent, err = sh.agentService.CreateLegacyAgent(ctx, agentID, sh.defaultUserID)
-				if err != nil {
-					return fmt.Errorf("failed to create legacy agent: %w", err)
-				}
-
-				slog.Info("Legacy agent auto-migrated",
-					"agent_id", agent.ID,
-					"legacy_id", agentID,
-					"user_id", sh.defaultUserID)
+			return fmt.Errorf("failed to get agent: %w", err)
+		}
 
-				// Use the new agent ID
-				agentID = agent.ID
-			} else {
agent: %w", err) - } - } else { - // Validate agent status - if agent.Status != "active" { - slog.Warn("Agent connection rejected, status not active", - "agent_id", agentID, - "status", agent.Status) - return fmt.Errorf("agent suspended or inactive") - } + if agent.Status != "active" { + slog.Warn("Agent connection rejected, status not active", + "agent_id", agentID, + "status", agent.Status) + return fmt.Errorf("agent suspended or inactive") } slog.Info("Agent authenticated", "agent_id", agentID, "remote_ip", remoteIP) diff --git a/internal/provisioning/service.go b/internal/provisioning/service.go index 6c9b02e..28e64a7 100644 --- a/internal/provisioning/service.go +++ b/internal/provisioning/service.go @@ -113,7 +113,7 @@ func (s *Service) ProvisionAgent(ctx context.Context, key string, remoteIP strin // Hash the provided key keyHash := HashKey(key) - // Lookup key in database + // Lookup key in database (only returns active keys) dbKey, err := s.queries.GetProvisioningKeyByHash(ctx, keyHash) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -123,23 +123,6 @@ func (s *Service) ProvisionAgent(ctx context.Context, key string, remoteIP strin return nil, fmt.Errorf("failed to lookup key: %w", err) } - // Validate key status - if dbKey.Status != sqlc.ProvisioningKeyStatusActive { - slog.Warn("Provisioning attempt with non-active key", - "key_id", uuidToString(dbKey.ID.Bytes), - "status", dbKey.Status, - "remote_ip", remoteIP) - - switch dbKey.Status { - case sqlc.ProvisioningKeyStatusExpired: - return nil, ErrKeyExpired - case sqlc.ProvisioningKeyStatusExhausted: - return nil, ErrKeyExhausted - default: - return nil, ErrKeyInvalid - } - } - // Validate expiration if time.Now().After(dbKey.ExpiresAt.Time) { slog.Warn("Provisioning attempt with expired key", @@ -149,42 +132,35 @@ func (s *Service) ProvisionAgent(ctx context.Context, key string, remoteIP strin return nil, ErrKeyExpired } - // Validate usage count - if dbKey.UsedCount >= dbKey.MaxUses { - slog.Warn("Provisioning attempt with exhausted key", - "key_id", uuidToString(dbKey.ID.Bytes), - "used_count", dbKey.UsedCount, - "max_uses", dbKey.MaxUses, - "remote_ip", remoteIP) - return nil, ErrKeyExhausted + // Atomically increment key usage count. + // The WHERE clause (used_count < max_uses AND status = 'active') ensures + // concurrent requests cannot exceed max_uses — only one will succeed. 
+	_, err = s.queries.IncrementKeyUsage(ctx, dbKey.ID)
+	if err != nil {
+		if errors.Is(err, pgx.ErrNoRows) {
+			slog.Warn("Provisioning attempt with exhausted key",
+				"key_id", uuidToString(dbKey.ID.Bytes),
+				"remote_ip", remoteIP)
+			return nil, ErrKeyExhausted
+		}
+		return nil, fmt.Errorf("failed to increment key usage: %w", err)
 	}
 
 	// Create agent in database (ID is auto-generated)
 	dbAgent, err := s.queries.CreateAgent(ctx, sqlc.CreateAgentParams{
 		UserID:               dbKey.UserID,
 		ProvisionedWithKeyID: pgtype.UUID{Bytes: dbKey.ID.Bytes, Valid: true},
-		Metadata:             []byte("{}"), // Empty JSON object
+		Metadata:             []byte("{}"),
 		Notes:                pgtype.Text{String: "", Valid: false},
 	})
 	if err != nil {
 		return nil, fmt.Errorf("failed to create agent: %w", err)
 	}
 
-	// Increment key usage count
-	if err := s.queries.IncrementKeyUsage(ctx, dbKey.ID); err != nil {
-		slog.Error("Failed to increment key usage count",
-			"key_id", uuidToString(dbKey.ID.Bytes),
-			"error", err)
-		// Don't fail the provisioning, but log the error
-	}
-
 	result := &AgentProvisionResult{
 		AgentID: uuidToString(dbAgent.ID.Bytes),
 	}
 
-	// TODO: Generate TLS certificate if cert service is available
-	// This will be implemented in a future phase when GenerateAgentCertificate method is added
-
 	slog.Info("Agent provisioned successfully",
 		"agent_id", result.AgentID,
 		"user_id", uuidToString(dbKey.UserID.Bytes),