From 0c3fa5f40eebbacac7f61ea92810da315a0d98fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Guih=C3=A9neuf?= Date: Wed, 28 Jan 2026 15:53:27 +0100 Subject: [PATCH] Add w2p auth + new compliance and tags endpoints - w2p authent to support passwords generated by the Python Collector - New endpoints for : - GET /nodes/{node_id}/compliance/candidate_modulesets - GET /nodes/{node_id}/compliance/candidate_rulesets - GET /nodes/{node_id}/compliance/modulesets - DEL /nodes/{node_id}/compliance/modulesets/{mset_id} - POST /nodes/{node_id}/compliance/modulesets/{mset_id} - GET /nodes/{node_id]/compliance/rulesets - DEL /nodes/{node_id}/compliance/rulesets/{rset_id} - POST /nodes/{node_id}/compliance/rulesets/{rset_id} - GET /tags - GET /tags/{tag_id} - GET /tags/{tag_id}/nodes --- cdb/db_auth_filters.go | 305 ++++++++++ cdb/db_compliance.go | 546 ++++++++++++++++++ cdb/db_nodes.go | 89 +++ cdb/db_tags.go | 75 +++ cmd/conf.go | 1 + cmd/server.go | 2 +- server/api.yaml | 346 +++++++++++ server/codegen_server_gen.go | 351 ++++++++++- server/codegen_type_gen.go | 32 + .../delete_node_compliance_moduleset.go | 74 +++ .../delete_node_compliance_ruleset.go | 78 +++ ...et_node_compliance_candidate_modulesets.go | 41 ++ .../get_node_compliance_candidate_rulesets.go | 42 ++ server/handlers/get_node_compliance_logs.go | 16 + .../get_node_compliance_modulesets.go | 34 ++ .../handlers/get_node_compliance_rulesets.go | 34 ++ server/handlers/get_tag.go | 12 + server/handlers/get_tag_nodes.go | 16 + server/handlers/get_tags.go | 38 ++ server/handlers/log.go | 15 +- server/handlers/main.go | 11 + server/handlers/middleware.go | 73 +-- .../post_node_compliance_moduleset.go | 84 +++ .../handlers/post_node_compliance_ruleset.go | 90 +++ xauth/w2p.go | 216 +++++++ 25 files changed, 2561 insertions(+), 60 deletions(-) create mode 100644 cdb/db_auth_filters.go create mode 100644 cdb/db_tags.go create mode 100644 server/handlers/delete_node_compliance_moduleset.go create mode 100644 server/handlers/delete_node_compliance_ruleset.go create mode 100644 server/handlers/get_node_compliance_candidate_modulesets.go create mode 100644 server/handlers/get_node_compliance_candidate_rulesets.go create mode 100644 server/handlers/get_node_compliance_logs.go create mode 100644 server/handlers/get_node_compliance_modulesets.go create mode 100644 server/handlers/get_node_compliance_rulesets.go create mode 100644 server/handlers/get_tag.go create mode 100644 server/handlers/get_tag_nodes.go create mode 100644 server/handlers/get_tags.go create mode 100644 server/handlers/post_node_compliance_moduleset.go create mode 100644 server/handlers/post_node_compliance_ruleset.go create mode 100644 xauth/w2p.go diff --git a/cdb/db_auth_filters.go b/cdb/db_auth_filters.go new file mode 100644 index 0000000..e413922 --- /dev/null +++ b/cdb/db_auth_filters.go @@ -0,0 +1,305 @@ +package cdb + +import ( + "context" + "errors" + "strings" +) + +type QFilterInput struct { + BaseQuery string + BaseArgs []any + + SvcField string + NodeField string + GroupField string + AppField string + UserField string + + IsNode bool + IsSvc bool + IsManager bool + + NodeID string + NodeName string + App string + SvcID string + + UserGroups []string + + NodeSvcIDs []string + NodeApps []string + PublishedServices []string + PublishedNodes []string + PublishedApps []string + UserGroupIDs []int64 + + ResolveNodeSvcIDs func(ctx context.Context, nodeID, nodeName, app string) ([]string, error) + ResolveNodeApps func(ctx context.Context, nodeID string) ([]string, error) + ResolvePublishedServices func(ctx context.Context) ([]string, error) + ResolvePublishedNodes func(ctx context.Context) ([]string, error) + ResolvePublishedApps func(ctx context.Context) ([]string, error) + ResolveUserGroupIDs func(ctx context.Context) ([]int64, error) +} + +// NodeAccessFilter returns a SQL clause and args to restrict node access by team_responsible. +// If isManager is true, no filter is returned. +// If groups is empty, only nodes with team_responsible = 'Everybody' are allowed. +func NodeAccessFilter(groups []string, isManager bool) (string, []any) { + if isManager { + return "", nil + } + + clean := cleanGroups(groups) + if len(clean) == 0 { + return " AND nodes.team_responsible = 'Everybody'", nil + } + + clause := " AND (nodes.team_responsible = 'Everybody' OR nodes.team_responsible IN (?" + args := []any{clean[0]} + for i := 1; i < len(clean); i++ { + clause += ", ?" + args = append(args, clean[i]) + } + clause += "))" + + return clause, args +} + +// PublishedNodeIDsForGroups returns node ids accessible by the provided groups. +// It includes nodes with team_responsible = 'Everybody'. +func (oDb *DB) PublishedNodeIDsForGroups(ctx context.Context, groups []string) ([]string, error) { + clean := cleanGroups(groups) + + query := "SELECT node_id FROM nodes WHERE team_responsible = 'Everybody'" + args := []any{} + if len(clean) > 0 { + query += " OR team_responsible IN (" + Placeholders(len(clean)) + ")" + for _, g := range clean { + args = append(args, g) + } + } + + rows, err := oDb.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + ids := []string{} + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, err + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, err + } + return ids, nil +} + +// PublishedAppsForGroups returns app names accessible by the provided groups. +func (oDb *DB) PublishedAppsForGroups(ctx context.Context, groups []string) ([]string, error) { + clean := cleanGroups(groups) + + if len(clean) == 0 { + return []string{}, nil + } + + query := "SELECT DISTINCT apps.app FROM apps " + + "JOIN apps_responsibles ON apps.id = apps_responsibles.app_id " + + "JOIN auth_group ON apps_responsibles.group_id = auth_group.id " + + "WHERE auth_group.role IN (" + Placeholders(len(clean)) + ")" + args := []any{} + for _, g := range clean { + args = append(args, g) + } + + rows, err := oDb.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + apps := []string{} + for rows.Next() { + var app string + if err := rows.Scan(&app); err != nil { + return nil, err + } + apps = append(apps, app) + } + if err := rows.Err(); err != nil { + return nil, err + } + return apps, nil +} + +// QFilter builds a SQL WHERE clause to append to a query +func QFilter(ctx context.Context, in QFilterInput) (string, []any, error) { + var ( + q string + args []any + ) + + switch { + case in.SvcField != "": + switch { + case in.IsSvc: + if in.SvcID == "" { + return "", nil, errors.New("qfilter: missing svc id") + } + q, args = eqClause(in.SvcField, in.SvcID) + case in.IsNode: + ids := in.NodeSvcIDs + if len(ids) == 0 && in.ResolveNodeSvcIDs != nil { + var err error + ids, err = in.ResolveNodeSvcIDs(ctx, in.NodeID, in.NodeName, in.App) + if err != nil { + return "", nil, err + } + } + q, args = inClause(in.SvcField, toAnySlice(ids)) + case !in.IsManager: + ids := in.PublishedServices + if len(ids) == 0 && in.ResolvePublishedServices != nil { + var err error + ids, err = in.ResolvePublishedServices(ctx) + if err != nil { + return "", nil, err + } + } + q, args = inClause(in.SvcField, toAnySlice(ids)) + } + + case in.NodeField != "": + switch { + case in.IsNode: + if in.NodeID == "" { + return "", nil, errors.New("qfilter: missing node id") + } + q, args = eqClause(in.NodeField, in.NodeID) + case !in.IsManager: + ids := in.PublishedNodes + if len(ids) == 0 && in.ResolvePublishedNodes != nil { + var err error + ids, err = in.ResolvePublishedNodes(ctx) + if err != nil { + return "", nil, err + } + } + q, args = inClause(in.NodeField, toAnySlice(ids)) + } + + case in.AppField != "": + switch { + case in.IsNode: + apps := in.NodeApps + if len(apps) == 0 && in.ResolveNodeApps != nil { + var err error + apps, err = in.ResolveNodeApps(ctx, in.NodeID) + if err != nil { + return "", nil, err + } + } + q, args = inClause(in.AppField, toAnySlice(apps)) + case !in.IsManager: + apps := in.PublishedApps + if len(apps) == 0 && in.ResolvePublishedApps != nil { + var err error + apps, err = in.ResolvePublishedApps(ctx) + if err != nil { + return "", nil, err + } + } + q, args = inClause(in.AppField, toAnySlice(apps)) + } + } + + if in.GroupField != "" { + if !in.IsNode && !in.IsSvc && !in.IsManager { + q, args = inClause(in.GroupField, toAnySlice(in.UserGroups)) + } + } + + if in.UserField != "" { + if !in.IsNode && !in.IsSvc && !in.IsManager { + ids := in.UserGroupIDs + if len(ids) == 0 && in.ResolveUserGroupIDs != nil { + var err error + ids, err = in.ResolveUserGroupIDs(ctx) + if err != nil { + return "", nil, err + } + } + q, args = inClause(in.UserField, toAnyInt64Slice(ids)) + } + } + + return joinBaseQuery(in.BaseQuery, q), append(in.BaseArgs, args...), nil +} + +func cleanGroups(groups []string) []string { + clean := make([]string, 0, len(groups)) + for _, g := range groups { + g = strings.TrimSpace(g) + if g == "" || g == "Manager" { + continue + } + clean = append(clean, g) + } + return clean +} + +func joinBaseQuery(base string, q string) string { + switch { + case base == "" && q == "": + return "1=1" + case base == "": + return q + case q == "": + return base + default: + return "(" + base + ") AND (" + q + ")" + } +} + +func eqClause(field string, value any) (string, []any) { + return field + " = ?", []any{value} +} + +func inClause(field string, values []any) (string, []any) { + if len(values) == 0 { + return "1=0", nil + } + clause := field + " IN (?" + for i := 1; i < len(values); i++ { + clause += ", ?" + } + clause += ")" + return clause, values +} + +func toAnySlice(values []string) []any { + if len(values) == 0 { + return nil + } + args := make([]any, 0, len(values)) + for _, v := range values { + args = append(args, v) + } + return args +} + +func toAnyInt64Slice(values []int64) []any { + if len(values) == 0 { + return nil + } + args := make([]any, 0, len(values)) + for _, v := range values { + args = append(args, v) + } + return args +} diff --git a/cdb/db_compliance.go b/cdb/db_compliance.go index 421315a..b5f177e 100644 --- a/cdb/db_compliance.go +++ b/cdb/db_compliance.go @@ -2,8 +2,24 @@ package cdb import ( "context" + "database/sql" + "fmt" ) +type Moduleset struct { + ID int `json:"id"` + Name string `json:"modset_name"` + Author string `json:"modset_author"` + Updated string `json:"modset_updated"` +} + +type Ruleset struct { + ID int `json:"id"` + Name string `json:"ruleset_name"` + Public bool `json:"ruleset_public"` + Type string `json:"ruleset_type"` +} + func (oDb *DB) PurgeCompModulesetsNodes(ctx context.Context) error { var query = `DELETE FROM comp_node_moduleset @@ -169,3 +185,533 @@ func (oDb *DB) PurgeCompStatusSvcUnattached(ctx context.Context) error { } return nil } + +// returns the moduleset name for a given moduleset ID. +func (oDb *DB) CompModulesetName(ctx context.Context, modulesetID string) (string, error) { + const query = "SELECT modset_name FROM comp_moduleset WHERE id = ?" + var modulesetName string + + err := oDb.DB.QueryRowContext(ctx, query, modulesetID).Scan(&modulesetName) + if err != nil { + return "", fmt.Errorf("compModulesetName: %w", err) + } + return modulesetName, nil +} + +// check if a moduleset is already attached to a node +func (oDb *DB) CompModulesetAttached(ctx context.Context, nodeID, modulesetID string) (bool, error) { + const query = "SELECT EXISTS(SELECT 1 FROM comp_node_moduleset WHERE node_id = ? AND modset_id = ?)" + var exists bool + + err := oDb.DB.QueryRowContext(ctx, query, nodeID, modulesetID).Scan(&exists) + if err != nil { + return false, fmt.Errorf("compModulesetAttached: %w", err) + } + return exists, nil +} + +// detach moduleset(s) from a node +func (oDb *DB) CompModulesetDetachNode(ctx context.Context, nodeID string, modulesetIDs []string) (int64, error) { + if len(modulesetIDs) == 0 { + return 0, nil + } + + query := "DELETE FROM comp_node_moduleset WHERE node_id = ? AND modset_id IN (" + args := []any{nodeID} + + for i, id := range modulesetIDs { + if i > 0 { + query += "," + } + query += "?" + args = append(args, id) + } + query += ")" + + result, err := oDb.DB.ExecContext(ctx, query, args...) + if err != nil { + return 0, fmt.Errorf("compModulesetDetachNode: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return 0, fmt.Errorf("compModulesetDetachNode rowsAffected: %w", err) + } + + if rows > 0 { + oDb.SetChange("comp_node_moduleset") + oDb.Session.NotifyChanges(ctx) + } + + return rows, nil +} + +// returns the ruleset name for a given ruleset ID. +func (oDb *DB) CompRulesetName(ctx context.Context, rulesetID string) (string, error) { + const query = "SELECT ruleset_name FROM comp_rulesets WHERE id = ?" + var rulesetName string + + err := oDb.DB.QueryRowContext(ctx, query, rulesetID).Scan(&rulesetName) + if err != nil { + return "", fmt.Errorf("compRulesetName: %w", err) + } + return rulesetName, nil +} + +// check if a ruleset is already attached to a node +func (oDb *DB) CompRulesetAttached(ctx context.Context, nodeID, rulesetID string) (bool, error) { + const query = "SELECT EXISTS(SELECT 1 FROM comp_rulesets_nodes WHERE node_id = ? AND ruleset_id = ?)" + var exists bool + + err := oDb.DB.QueryRowContext(ctx, query, nodeID, rulesetID).Scan(&exists) + if err != nil { + return false, fmt.Errorf("compRulesetAttached: %w", err) + } + return exists, nil +} + +// detach ruleset(s) from a node +func (oDb *DB) CompRulesetDetachNode(ctx context.Context, nodeID string, rulesetIDs []string) (int64, error) { + if len(rulesetIDs) == 0 { + return 0, nil + } + + query := "DELETE FROM comp_rulesets_nodes WHERE node_id = ? AND ruleset_id IN (" + args := []any{nodeID} + + for i, id := range rulesetIDs { + if i > 0 { + query += "," + } + query += "?" + args = append(args, id) + } + query += ")" + + result, err := oDb.DB.ExecContext(ctx, query, args...) + if err != nil { + return 0, fmt.Errorf("compRulesetDetachNode: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return 0, fmt.Errorf("compRulesetDetachNode rowsAffected: %w", err) + } + + if rows > 0 { + oDb.SetChange("comp_rulesets_nodes") + oDb.Session.NotifyChanges(ctx) + } + + return rows, nil +} + +// find attached modulesets for a node +func (oDb *DB) CompNodeModulesets(ctx context.Context, nodeID string) (modulesets []int, err error) { + const query = `SELECT modset_id FROM comp_node_moduleset WHERE node_id = ?` + var rows *sql.Rows + rows, err = oDb.DB.QueryContext(ctx, query, nodeID) + if err != nil { + return + } + defer func() { _ = rows.Close() }() + for rows.Next() { + var modsetID sql.NullInt64 + err = rows.Scan(&modsetID) + if err != nil { + return + } + if modsetID.Valid { + modulesets = append(modulesets, int(modsetID.Int64)) + } + } + err = rows.Err() + return +} + +// find attached rulesets for a node +func (oDb *DB) CompNodeRulesets(ctx context.Context, nodeID string) (rulesets []int, err error) { + const query = `SELECT ruleset_id FROM comp_rulesets_nodes WHERE node_id = ?` + var rows *sql.Rows + rows, err = oDb.DB.QueryContext(ctx, query, nodeID) + if err != nil { + return + } + defer func() { _ = rows.Close() }() + for rows.Next() { + var rulesetID sql.NullInt64 + err = rows.Scan(&rulesetID) + if err != nil { + return + } + if rulesetID.Valid { + rulesets = append(rulesets, int(rulesetID.Int64)) + } + } + err = rows.Err() + return +} + +// get candidate modulesets for a node (modulesets that can be attached but are not yet attached) +func (oDb *DB) CompNodeCandidateModulesets(ctx context.Context, nodeID string, attachedModulesets []int, groups []string, isManager bool) ([]Moduleset, error) { + var query = ` + SELECT DISTINCT comp_moduleset.id, comp_moduleset.modset_name, comp_moduleset.modset_author, comp_moduleset.modset_updated + FROM comp_moduleset + JOIN comp_moduleset_team_publication ON comp_moduleset.id = comp_moduleset_team_publication.modset_id + JOIN auth_group ON auth_group.id = comp_moduleset_team_publication.group_id + JOIN nodes ON nodes.node_id = ? + WHERE (nodes.team_responsible = auth_group.role OR auth_group.role = 'Everybody') + ` + + args := []any{nodeID} + + filter, filterArgs, err := QFilter(ctx, QFilterInput{ + NodeField: "nodes.node_id", + IsManager: isManager, + UserGroups: groups, + ResolvePublishedNodes: func(ctx context.Context) ([]string, error) { + return oDb.PublishedNodeIDsForGroups(ctx, groups) + }, + }) + + if err != nil { + return nil, err + } + if filter != "" { + query += " AND (" + filter + ")" + args = append(args, filterArgs...) + } + + if len(attachedModulesets) > 0 { + query += " AND comp_moduleset.id NOT IN (?" + args = append(args, attachedModulesets[0]) + for i := 1; i < len(attachedModulesets); i++ { + query += ", ?" + args = append(args, attachedModulesets[i]) + } + query += ")" + } + + query += " ORDER BY comp_moduleset.modset_name" + + rows, err := oDb.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("compNodeCandidateModulesets: %w", err) + } + defer func() { _ = rows.Close() }() + + var candidates []Moduleset + for rows.Next() { + var candidate Moduleset + if err := rows.Scan(&candidate.ID, &candidate.Name, &candidate.Author, &candidate.Updated); err != nil { + return nil, fmt.Errorf("compNodeCandidateModulesets scan: %w", err) + } + candidates = append(candidates, candidate) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("compNodeCandidateModulesets rows: %w", err) + } + + return candidates, nil +} + +// get candidate rulesets for a node (rulesets that can be attached but are not yet attached) +func (oDb *DB) CompNodeCandidateRulesets(ctx context.Context, nodeID string, attachedRulesets []int, groups []string, isManager bool) ([]Ruleset, error) { + var query = ` + SELECT DISTINCT comp_rulesets.id, comp_rulesets.ruleset_name, comp_rulesets.ruleset_public, comp_rulesets.ruleset_type + FROM comp_rulesets + JOIN comp_ruleset_team_publication ON comp_rulesets.id = comp_ruleset_team_publication.ruleset_id + JOIN auth_group ON auth_group.id = comp_ruleset_team_publication.group_id + JOIN nodes ON nodes.node_id = ? + WHERE comp_rulesets.ruleset_type = 'explicit' + AND comp_rulesets.ruleset_public = 'T' + AND (nodes.team_responsible = auth_group.role OR auth_group.role = 'Everybody') + ` + + args := []any{nodeID} + filter, filterArgs, err := QFilter(ctx, QFilterInput{ + NodeField: "nodes.node_id", + IsManager: isManager, + UserGroups: groups, + ResolvePublishedNodes: func(ctx context.Context) ([]string, error) { + return oDb.PublishedNodeIDsForGroups(ctx, groups) + }, + }) + if err != nil { + return nil, err + } + if filter != "" { + query += " AND (" + filter + ")" + args = append(args, filterArgs...) + } + + if len(attachedRulesets) > 0 { + query += " AND comp_rulesets.id NOT IN (?" + args = append(args, attachedRulesets[0]) + for i := 1; i < len(attachedRulesets); i++ { + query += ", ?" + args = append(args, attachedRulesets[i]) + } + query += ")" + } + + query += " ORDER BY comp_rulesets.ruleset_name" + + rows, err := oDb.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("compNodeCandidateRulesets: %w", err) + } + defer func() { _ = rows.Close() }() + + var candidates []Ruleset + for rows.Next() { + var candidate Ruleset + var publicStr string + if err := rows.Scan(&candidate.ID, &candidate.Name, &publicStr, &candidate.Type); err != nil { + return nil, fmt.Errorf("compNodeCandidateRulesets scan: %w", err) + } + candidate.Public = (publicStr == "T") + candidates = append(candidates, candidate) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("compNodeCandidateRulesets rows: %w", err) + } + + return candidates, nil +} + +// get attached modulesets for a node with details +func (oDb *DB) CompNodeAttachedModulesets(ctx context.Context, nodeID string, groups []string, isManager bool) ([]Moduleset, error) { + query := ` + SELECT comp_moduleset.id, comp_moduleset.modset_name, comp_moduleset.modset_author, comp_moduleset.modset_updated + FROM comp_moduleset + JOIN comp_node_moduleset ON comp_moduleset.id = comp_node_moduleset.modset_id + WHERE comp_node_moduleset.node_id = ? + ` + + args := []any{nodeID} + filter, filterArgs, err := QFilter(ctx, QFilterInput{ + NodeField: "comp_node_moduleset.node_id", + IsManager: isManager, + UserGroups: groups, + ResolvePublishedNodes: func(ctx context.Context) ([]string, error) { + return oDb.PublishedNodeIDsForGroups(ctx, groups) + }, + }) + if err != nil { + return nil, err + } + if filter != "" { + query += " AND (" + filter + ")" + args = append(args, filterArgs...) + } + + query += " ORDER BY comp_moduleset.modset_name" + + rows, err := oDb.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("compNodeAttachedModulesets: %w", err) + } + defer func() { _ = rows.Close() }() + + var modulesets []Moduleset + for rows.Next() { + var moduleset Moduleset + if err := rows.Scan(&moduleset.ID, &moduleset.Name, &moduleset.Author, &moduleset.Updated); err != nil { + return nil, fmt.Errorf("compNodeAttachedModulesets scan: %w", err) + } + modulesets = append(modulesets, moduleset) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("compNodeAttachedModulesets rows: %w", err) + } + + return modulesets, nil +} + +// get attached rulesets for a node with details +func (oDb *DB) CompNodeAttachedRulesets(ctx context.Context, nodeID string, groups []string, isManager bool) ([]Ruleset, error) { + query := ` + SELECT comp_rulesets.id, comp_rulesets.ruleset_name, comp_rulesets.ruleset_public, comp_rulesets.ruleset_type + FROM comp_rulesets + JOIN comp_rulesets_nodes ON comp_rulesets.id = comp_rulesets_nodes.ruleset_id + WHERE comp_rulesets_nodes.node_id = ? + ` + + args := []any{nodeID} + filter, filterArgs, err := QFilter(ctx, QFilterInput{ + NodeField: "comp_rulesets_nodes.node_id", + IsManager: isManager, + UserGroups: groups, + ResolvePublishedNodes: func(ctx context.Context) ([]string, error) { + return oDb.PublishedNodeIDsForGroups(ctx, groups) + }, + }) + if err != nil { + return nil, err + } + if filter != "" { + query += " AND (" + filter + ")" + args = append(args, filterArgs...) + } + + query += " ORDER BY comp_rulesets.ruleset_name" + + rows, err := oDb.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("compNodeAttachedRulesets: %w", err) + } + defer func() { _ = rows.Close() }() + + var rulesets []Ruleset + for rows.Next() { + var ruleset Ruleset + var publicStr string + if err := rows.Scan(&ruleset.ID, &ruleset.Name, &publicStr, &ruleset.Type); err != nil { + return nil, fmt.Errorf("compNodeAttachedRulesets scan: %w", err) + } + ruleset.Public = (publicStr == "T") + rulesets = append(rulesets, ruleset) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("compNodeAttachedRulesets rows: %w", err) + } + + return rulesets, nil +} + +// checks if a moduleset can be attached to a node. +func (oDb *DB) CompModulesetAttachable(ctx context.Context, nodeID, modulesetID string) (bool, error) { + hasEveryBody, err := oDb.modulesetHasEverybodyPublication(ctx, modulesetID) + if err != nil { + return false, fmt.Errorf("compModulesetAttachable: %w", err) + } + if hasEveryBody { + return true, nil + } + + const query = ` + SELECT EXISTS( + SELECT 1 FROM nodes + JOIN auth_group ON nodes.team_responsible = auth_group.role + JOIN comp_moduleset_team_publication ON auth_group.id = comp_moduleset_team_publication.group_id + JOIN comp_moduleset ON comp_moduleset_team_publication.modset_id = comp_moduleset.id + WHERE comp_moduleset.id = ? + AND nodes.node_id = ? + ) + ` + + var attachable bool + err = oDb.DB.QueryRowContext(ctx, query, modulesetID, nodeID).Scan(&attachable) + if err != nil { + return false, fmt.Errorf("compModulesetAttachable: %w", err) + } + return attachable, nil +} + +// checks if a moduleset has "Everybody" publication rights. +func (oDb *DB) modulesetHasEverybodyPublication(ctx context.Context, modulesetID string) (bool, error) { + const query = ` + SELECT EXISTS( + SELECT 1 FROM auth_group + JOIN comp_moduleset_team_publication ON auth_group.id = comp_moduleset_team_publication.group_id + WHERE auth_group.role = 'Everybody' + AND comp_moduleset_team_publication.modset_id = ? + ) + ` + var exists bool + err := oDb.DB.QueryRowContext(ctx, query, modulesetID).Scan(&exists) + if err != nil { + return false, fmt.Errorf("modulesetHasEverybodyPublication: %w", err) + } + return exists, nil +} + +// checks if a ruleset can be attached to a node. +func (oDb *DB) CompRulesetAttachable(ctx context.Context, nodeID, rulesetID string) (bool, error) { + hasEveryBody, err := oDb.RulesetHasEverybodyPublication(ctx, rulesetID) + if err != nil { + return false, fmt.Errorf("compRulesetAttachable: %w", err) + } + if hasEveryBody { + return true, nil + } + + const query = ` + SELECT EXISTS( + SELECT 1 FROM nodes + JOIN auth_group ON nodes.team_responsible = auth_group.role + JOIN comp_ruleset_team_publication ON auth_group.id = comp_ruleset_team_publication.group_id + JOIN comp_rulesets ON comp_ruleset_team_publication.ruleset_id = comp_rulesets.id + WHERE comp_rulesets.id = ? + AND comp_rulesets.ruleset_public = "T" + AND comp_rulesets.ruleset_type = 'explicit' + AND nodes.node_id = ? + ) + ` + + var attachable bool + err = oDb.DB.QueryRowContext(ctx, query, rulesetID, nodeID).Scan(&attachable) + if err != nil { + return false, fmt.Errorf("compRulesetAttachable: %w", err) + } + return attachable, nil +} + +// attach a moduleset to a node +func (oDb *DB) CompModulesetAttachNode(ctx context.Context, nodeID, modulesetID string) (int64, error) { + const query = "INSERT INTO comp_node_moduleset (node_id, modset_id) VALUES (?, ?)" + + result, err := oDb.DB.ExecContext(ctx, query, nodeID, modulesetID) + if err != nil { + return 0, fmt.Errorf("compModulesetAttachNode: %w", err) + } + + if rows, err := result.RowsAffected(); err == nil && rows > 0 { + oDb.SetChange("comp_node_moduleset") + oDb.Session.NotifyChanges(ctx) + } + + id, _ := result.LastInsertId() + return id, nil +} + +// checks if a ruleset has "Everybody" publication rights. +func (oDb *DB) RulesetHasEverybodyPublication(ctx context.Context, rulesetID string) (bool, error) { + const query = ` + SELECT EXISTS( + SELECT 1 FROM auth_group + JOIN comp_ruleset_team_publication ON auth_group.id = comp_ruleset_team_publication.group_id + WHERE auth_group.role = 'Everybody' + AND comp_ruleset_team_publication.ruleset_id = ? + ) + ` + var exists bool + err := oDb.DB.QueryRowContext(ctx, query, rulesetID).Scan(&exists) + if err != nil { + return false, fmt.Errorf("rulesetHasEverybodyPublication: %w", err) + } + return exists, nil +} + +// attach a ruleset to a node +func (oDb *DB) CompRulesetAttachNode(ctx context.Context, nodeID, rulesetID string) (int64, error) { + const query = "INSERT INTO comp_rulesets_nodes (node_id, ruleset_id) VALUES (?, ?)" + + result, err := oDb.DB.ExecContext(ctx, query, nodeID, rulesetID) + if err != nil { + return 0, fmt.Errorf("compRulesetAttachNode: %w", err) + } + + if rows, err := result.RowsAffected(); err == nil && rows > 0 { + oDb.SetChange("comp_rulesets_nodes") + oDb.Session.NotifyChanges(ctx) + } + + id, _ := result.LastInsertId() + return id, nil +} diff --git a/cdb/db_nodes.go b/cdb/db_nodes.go index 99a66ad..98552e5 100644 --- a/cdb/db_nodes.go +++ b/cdb/db_nodes.go @@ -6,7 +6,10 @@ import ( "errors" "fmt" "log/slog" + "strings" "time" + + "github.com/google/uuid" ) type ( @@ -472,3 +475,89 @@ func (oDb *DB) UpdateVirtualAsset(ctx context.Context, svcID, nodeID string) err } return nil } + +func (oDb *DB) NodeByNodeIDOrNodename(ctx context.Context, nodeIdOrName string) (*DBNode, error) { + defer logDuration("nodeByNodeIDOrNodename", time.Now()) + if nodeIdOrName == "" { + return nil, fmt.Errorf("nodeByNodeIDOrNodename: called with empty node ID or name") + } + + // Valid UUID : should be a node_id + if _, err := uuid.Parse(nodeIdOrName); err == nil { + n, err := oDb.NodeByNodeID(ctx, nodeIdOrName) + if err != nil { + return nil, err + } + return n, nil + } + + // Otherwise treat it as a nodename and resolve the node_id first. + const query = `SELECT node_id FROM nodes WHERE nodename = ?` + rows, err := oDb.DB.QueryContext(ctx, query, nodeIdOrName) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var nodeIDs []string + for rows.Next() { + var nodeID sql.NullString + if err := rows.Scan(&nodeID); err != nil { + return nil, err + } + if nodeID.Valid { + nodeIDs = append(nodeIDs, nodeID.String) + } + } + if err := rows.Err(); err != nil { + return nil, err + } + + switch len(nodeIDs) { + case 0: + return nil, fmt.Errorf("node %s not found", nodeIdOrName) + case 1: + n, err := oDb.NodeByNodeID(ctx, nodeIDs[0]) + if err != nil { + return nil, err + } + return n, nil + default: + return nil, fmt.Errorf("nodeByNodeIDOrNodename: multiple node_ids found for nodename %s", nodeIdOrName) + } +} + +// check if a user is responsible for a given node +func (oDb *DB) NodeResponsible(ctx context.Context, nodeID string, groups []string, isManager bool) (bool, error) { + if nodeID == "" { + return false, fmt.Errorf("nodeResponsible: must have a not empty node_id parameter") + } + if isManager { + slog.Info("xxx - Manager") + return true, nil + } + + const query = `SELECT app FROM nodes WHERE node_id = ? LIMIT 1` + var app sql.NullString + if err := oDb.DB.QueryRowContext(ctx, query, nodeID).Scan(&app); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, fmt.Errorf("nodeResponsible: node %s does not exist", nodeID) + } + return false, fmt.Errorf("nodeResponsible: %w", err) + } + + allowedApps, err := oDb.PublishedAppsForGroups(ctx, groups) + if err != nil { + return false, fmt.Errorf("nodeResponsible: %w", err) + } + if len(allowedApps) == 0 { + return false, nil + } + + for _, a := range allowedApps { + if strings.EqualFold(a, app.String) { + return true, nil + } + } + return false, nil +} diff --git a/cdb/db_tags.go b/cdb/db_tags.go new file mode 100644 index 0000000..6f1079f --- /dev/null +++ b/cdb/db_tags.go @@ -0,0 +1,75 @@ +package cdb + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" +) + +type Tag struct { + ID int `json:"id"` + TagName string `json:"tag_name"` + TagCreated string `json:"tag_created"` + TagExclude string `json:"tag_exclude"` + TagData interface{} `json:"tag_data"` + TagID string `json:"tag_id"` +} + +// GetTags returns all tags with id > 0, or a specific tag if tagID is provided +func (oDb *DB) GetTags(ctx context.Context, tagID *int) ([]Tag, error) { + query := ` + SELECT id, tag_name, tag_created, tag_exclude, tag_data, tag_id + FROM tags + WHERE id > 0 + ` + var args []interface{} + + if tagID != nil { + query += " AND id = ?" + args = append(args, *tagID) + } + + query += " ORDER BY tag_name" + + rows, err := oDb.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("getTags: %w", err) + } + defer func() { _ = rows.Close() }() + + var tags []Tag + for rows.Next() { + var tag Tag + var tagCreated, tagExclude, tagData, tagID sql.NullString + if err := rows.Scan(&tag.ID, &tag.TagName, &tagCreated, &tagExclude, &tagData, &tagID); err != nil { + return nil, fmt.Errorf("getTags scan: %w", err) + } + if tagCreated.Valid { + tag.TagCreated = tagCreated.String + } + if tagExclude.Valid { + tag.TagExclude = tagExclude.String + } + if tagData.Valid && tagData.String != "" { + // Try to parse as JSON + var parsed any + if err := json.Unmarshal([]byte(tagData.String), &parsed); err == nil { + tag.TagData = parsed + } else { + // If not valid JSON, keep as string + tag.TagData = tagData.String + } + } + if tagID.Valid { + tag.TagID = tagID.String + } + tags = append(tags, tag) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("getTags rows: %w", err) + } + + return tags, nil +} diff --git a/cmd/conf.go b/cmd/conf.go index b3bb0d4..51a16f4 100644 --- a/cmd/conf.go +++ b/cmd/conf.go @@ -66,6 +66,7 @@ func initConfig() error { viper.SetDefault("scheduler.metrics.addr", "127.0.0.1:2111") viper.SetDefault("scheduler.task.trim.retention", 365) viper.SetDefault("scheduler.task.trim.batch_size", 1000) + viper.SetDefault("w2p_hmac", "sha512:7755f108-1b83-45dc-8302-54be8f3616a1") // config file viper.SetConfigName("config") diff --git a/cmd/server.go b/cmd/server.go index b229edc..631bca5 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -41,7 +41,7 @@ func listenAndServeServer(addr string) error { strategy := union.New( xauth.NewPublicStrategy("/oc3/api/public/", "/oc3/api/docs", "/oc3/api/version", "/oc3/api/openapi"), - xauth.NewBasicNode(db), + xauth.NewBasicWeb2py(db, viper.GetString("w2p_hmac")), ) if viper.GetBool("server.metrics.enable") { slog.Info("add handler /oc3/api/public/metrics") diff --git a/server/api.yaml b/server/api.yaml index 8da6c14..bed5e2a 100644 --- a/server/api.yaml +++ b/server/api.yaml @@ -27,6 +27,315 @@ paths: 500: $ref: '#/components/responses/500' + /nodes/{node_id}/compliance/candidate_modulesets: + get: + description: | + Get candidate modulesets for a node + operationId: GetNodeComplianceCandidateModulesets + parameters: + - $ref: '#/components/parameters/inPathNodeId' + responses: + 200: + description: List of candidate modulesets + content: + application/json: + schema: + type: array + items: + type: object + 404: + description: Node not found + 500: + description: Internal server error + security: + - basicAuth: [ ] + - bearerAuth: [ ] + tags: + - collector + + /nodes/{node_id}/compliance/candidate_rulesets: + get: + description: | + Get candidate rulesets for a node + operationId: GetNodeComplianceCandidateRulesets + parameters: + - $ref: '#/components/parameters/inPathNodeId' + responses: + 200: + description: List of candidate rulesets + content: + application/json: + schema: + type: array + items: + type: object + 404: + description: Node not found + 500: + description: Internal server error + security: + - basicAuth: [ ] + - bearerAuth: [ ] + tags: + - collector + + /nodes/{node_id}/compliance/logs: + get: + description: | + Get compliance logs for a node + operationId: GetNodeComplianceLogs + parameters: + - $ref: '#/components/parameters/inPathNodeId' + responses: + 200: + description: Compliance logs for the node + content: + application/json: + schema: + type: array + items: + type: object + 404: + description: Node not found + 500: + description: Internal server error + security: + - basicAuth: [ ] + - bearerAuth: [ ] + tags: + - collector + + /nodes/{node_id}/compliance/modulesets: + get: + description: | + Get modulesets attached to a node + operationId: GetNodeComplianceModulesets + parameters: + - $ref: '#/components/parameters/inPathNodeId' + responses: + 200: + description: List of attached modulesets + content: + application/json: + schema: + type: array + items: + type: object + 404: + description: Node not found + 500: + description: Internal server error + security: + - basicAuth: [ ] + - bearerAuth: [ ] + tags: + - collector + + /nodes/{node_id}/compliance/modulesets/{mset_id}: + delete: + description: | + Detach a moduleset from a node + operationId: DeleteNodeComplianceModuleset + parameters: + - $ref: '#/components/parameters/inPathNodeId' + - $ref: '#/components/parameters/inPathMsetId' + responses: + 204: + description: Moduleset successfully detached + 404: + description: Node or moduleset not found + 500: + description: Internal server error + security: + - basicAuth: [ ] + - bearerAuth: [ ] + tags: + - collector + post: + description: | + Attach a moduleset to a node + operationId: PostNodeComplianceModuleset + parameters: + - $ref: '#/components/parameters/inPathNodeId' + - $ref: '#/components/parameters/inPathMsetId' + requestBody: + required: true + content: + application/json: + schema: + type: object + responses: + 202: + description: Moduleset attached + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 500: + $ref: '#/components/responses/500' + security: + - basicAuth: [ ] + - bearerAuth: [ ] + tags: + - collector + + /nodes/{node_id}/compliance/rulesets: + get: + description: | + Get rulesets attached to a node + operationId: GetNodeComplianceRulesets + parameters: + - $ref: '#/components/parameters/inPathNodeId' + responses: + 200: + description: List of attached rulesets + content: + application/json: + schema: + type: array + items: + type: object + 404: + description: Node not found + 500: + description: Internal server error + security: + - basicAuth: [ ] + - bearerAuth: [ ] + tags: + - collector + + /nodes/{node_id}/compliance/rulesets/{rset_id}: + delete: + description: | + Detach a ruleset from a node + operationId: DeleteNodeComplianceRuleset + parameters: + - $ref: '#/components/parameters/inPathNodeId' + - $ref: '#/components/parameters/inPathRsetId' + responses: + 204: + description: Ruleset successfully detached + 404: + description: Node or ruleset not found + 500: + description: Internal server error + security: + - basicAuth: [ ] + - bearerAuth: [ ] + tags: + - collector + post: + description: | + Attach a ruleset to a node + operationId: PostNodeComplianceRuleset + parameters: + - $ref: '#/components/parameters/inPathNodeId' + - $ref: '#/components/parameters/inPathRsetId' + requestBody: + required: true + content: + application/json: + schema: + type: object + responses: + 202: + description: Ruleset attached + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 500: + $ref: '#/components/responses/500' + security: + - basicAuth: [ ] + - bearerAuth: [ ] + tags: + - collector + + /tags: + get: + operationId: GetTags + description: List existing tags + tags: + - collector + responses: + 200: + description: OK + content: + application/json: + schema: + type: array + items: + type: object + 500: + $ref: '#/components/responses/500' + security: + - basicAuth: [ ] + - bearerAuth: [ ] + + /tags/{tag_id}: + get: + operationId: GetTag + description: Display tag property + parameters: + - in: path + name: tag_id + required: true + description: ID of the tag + schema: + type: integer + tags: + - collector + responses: + 200: + description: OK + content: + application/json: + schema: + type: object + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + security: + - basicAuth: [ ] + - bearerAuth: [ ] + + /tags/{tag_id}/nodes: + get: + operationId: GetTagNodes + description: Get nodes where the tag is attached + parameters: + - in: path + name: tag_id + required: true + description: ID of the tag + schema: + type: integer + tags: + - collector + responses: + 200: + description: List of nodes with this tag + content: + application/json: + schema: + type: array + items: + type: object + 404: + description: Tag not found + 500: + $ref: '#/components/responses/500' + security: + - basicAuth: [ ] + - bearerAuth: [ ] + /version: get: operationId: GetVersion @@ -138,6 +447,43 @@ components: type: string example: "0.0.1" + parameters: + inPathMsetId: + in: path + name: mset_id + required: true + description: ID of the moduleset + schema: + type: string + + inPathNodeId: + in: path + name: node_id + required: true + description: ID of the node + schema: + type: string + + inPathRsetId: + in: path + name: rset_id + required: true + description: ID of the ruleset + schema: + type: string + + inQuerySync: + in: query + name: sync + schema: + type: boolean + + ObjectPathHeader: + name: OC3-ObjectPath + in: header + schema: + type: string + securitySchemes: basicAuth: type: http diff --git a/server/codegen_server_gen.go b/server/codegen_server_gen.go index 564f890..3c375e9 100644 --- a/server/codegen_server_gen.go +++ b/server/codegen_server_gen.go @@ -8,12 +8,14 @@ import ( "compress/gzip" "encoding/base64" "fmt" + "net/http" "net/url" "path" "strings" "github.com/getkin/kin-openapi/openapi3" "github.com/labstack/echo/v4" + "github.com/oapi-codegen/runtime" ) // ServerInterface represents all server handlers. @@ -22,6 +24,42 @@ type ServerInterface interface { // (GET /docs/openapi) GetSwagger(ctx echo.Context) error + // (GET /nodes/{node_id}/compliance/candidate_modulesets) + GetNodeComplianceCandidateModulesets(ctx echo.Context, nodeId InPathNodeId) error + + // (GET /nodes/{node_id}/compliance/candidate_rulesets) + GetNodeComplianceCandidateRulesets(ctx echo.Context, nodeId InPathNodeId) error + + // (GET /nodes/{node_id}/compliance/logs) + GetNodeComplianceLogs(ctx echo.Context, nodeId InPathNodeId) error + + // (GET /nodes/{node_id}/compliance/modulesets) + GetNodeComplianceModulesets(ctx echo.Context, nodeId InPathNodeId) error + + // (DELETE /nodes/{node_id}/compliance/modulesets/{mset_id}) + DeleteNodeComplianceModuleset(ctx echo.Context, nodeId InPathNodeId, msetId InPathMsetId) error + + // (POST /nodes/{node_id}/compliance/modulesets/{mset_id}) + PostNodeComplianceModuleset(ctx echo.Context, nodeId InPathNodeId, msetId InPathMsetId) error + + // (GET /nodes/{node_id}/compliance/rulesets) + GetNodeComplianceRulesets(ctx echo.Context, nodeId InPathNodeId) error + + // (DELETE /nodes/{node_id}/compliance/rulesets/{rset_id}) + DeleteNodeComplianceRuleset(ctx echo.Context, nodeId InPathNodeId, rsetId InPathRsetId) error + + // (POST /nodes/{node_id}/compliance/rulesets/{rset_id}) + PostNodeComplianceRuleset(ctx echo.Context, nodeId InPathNodeId, rsetId InPathRsetId) error + + // (GET /tags) + GetTags(ctx echo.Context) error + + // (GET /tags/{tag_id}) + GetTag(ctx echo.Context, tagId int) error + + // (GET /tags/{tag_id}/nodes) + GetTagNodes(ctx echo.Context, tagId int) error + // (GET /version) GetVersion(ctx echo.Context) error } @@ -40,6 +78,271 @@ func (w *ServerInterfaceWrapper) GetSwagger(ctx echo.Context) error { return err } +// GetNodeComplianceCandidateModulesets converts echo context to params. +func (w *ServerInterfaceWrapper) GetNodeComplianceCandidateModulesets(ctx echo.Context) error { + var err error + // ------------- Path parameter "node_id" ------------- + var nodeId InPathNodeId + + err = runtime.BindStyledParameterWithOptions("simple", "node_id", ctx.Param("node_id"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter node_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetNodeComplianceCandidateModulesets(ctx, nodeId) + return err +} + +// GetNodeComplianceCandidateRulesets converts echo context to params. +func (w *ServerInterfaceWrapper) GetNodeComplianceCandidateRulesets(ctx echo.Context) error { + var err error + // ------------- Path parameter "node_id" ------------- + var nodeId InPathNodeId + + err = runtime.BindStyledParameterWithOptions("simple", "node_id", ctx.Param("node_id"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter node_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetNodeComplianceCandidateRulesets(ctx, nodeId) + return err +} + +// GetNodeComplianceLogs converts echo context to params. +func (w *ServerInterfaceWrapper) GetNodeComplianceLogs(ctx echo.Context) error { + var err error + // ------------- Path parameter "node_id" ------------- + var nodeId InPathNodeId + + err = runtime.BindStyledParameterWithOptions("simple", "node_id", ctx.Param("node_id"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter node_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetNodeComplianceLogs(ctx, nodeId) + return err +} + +// GetNodeComplianceModulesets converts echo context to params. +func (w *ServerInterfaceWrapper) GetNodeComplianceModulesets(ctx echo.Context) error { + var err error + // ------------- Path parameter "node_id" ------------- + var nodeId InPathNodeId + + err = runtime.BindStyledParameterWithOptions("simple", "node_id", ctx.Param("node_id"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter node_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetNodeComplianceModulesets(ctx, nodeId) + return err +} + +// DeleteNodeComplianceModuleset converts echo context to params. +func (w *ServerInterfaceWrapper) DeleteNodeComplianceModuleset(ctx echo.Context) error { + var err error + // ------------- Path parameter "node_id" ------------- + var nodeId InPathNodeId + + err = runtime.BindStyledParameterWithOptions("simple", "node_id", ctx.Param("node_id"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter node_id: %s", err)) + } + + // ------------- Path parameter "mset_id" ------------- + var msetId InPathMsetId + + err = runtime.BindStyledParameterWithOptions("simple", "mset_id", ctx.Param("mset_id"), &msetId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter mset_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.DeleteNodeComplianceModuleset(ctx, nodeId, msetId) + return err +} + +// PostNodeComplianceModuleset converts echo context to params. +func (w *ServerInterfaceWrapper) PostNodeComplianceModuleset(ctx echo.Context) error { + var err error + // ------------- Path parameter "node_id" ------------- + var nodeId InPathNodeId + + err = runtime.BindStyledParameterWithOptions("simple", "node_id", ctx.Param("node_id"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter node_id: %s", err)) + } + + // ------------- Path parameter "mset_id" ------------- + var msetId InPathMsetId + + err = runtime.BindStyledParameterWithOptions("simple", "mset_id", ctx.Param("mset_id"), &msetId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter mset_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.PostNodeComplianceModuleset(ctx, nodeId, msetId) + return err +} + +// GetNodeComplianceRulesets converts echo context to params. +func (w *ServerInterfaceWrapper) GetNodeComplianceRulesets(ctx echo.Context) error { + var err error + // ------------- Path parameter "node_id" ------------- + var nodeId InPathNodeId + + err = runtime.BindStyledParameterWithOptions("simple", "node_id", ctx.Param("node_id"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter node_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetNodeComplianceRulesets(ctx, nodeId) + return err +} + +// DeleteNodeComplianceRuleset converts echo context to params. +func (w *ServerInterfaceWrapper) DeleteNodeComplianceRuleset(ctx echo.Context) error { + var err error + // ------------- Path parameter "node_id" ------------- + var nodeId InPathNodeId + + err = runtime.BindStyledParameterWithOptions("simple", "node_id", ctx.Param("node_id"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter node_id: %s", err)) + } + + // ------------- Path parameter "rset_id" ------------- + var rsetId InPathRsetId + + err = runtime.BindStyledParameterWithOptions("simple", "rset_id", ctx.Param("rset_id"), &rsetId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter rset_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.DeleteNodeComplianceRuleset(ctx, nodeId, rsetId) + return err +} + +// PostNodeComplianceRuleset converts echo context to params. +func (w *ServerInterfaceWrapper) PostNodeComplianceRuleset(ctx echo.Context) error { + var err error + // ------------- Path parameter "node_id" ------------- + var nodeId InPathNodeId + + err = runtime.BindStyledParameterWithOptions("simple", "node_id", ctx.Param("node_id"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter node_id: %s", err)) + } + + // ------------- Path parameter "rset_id" ------------- + var rsetId InPathRsetId + + err = runtime.BindStyledParameterWithOptions("simple", "rset_id", ctx.Param("rset_id"), &rsetId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter rset_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.PostNodeComplianceRuleset(ctx, nodeId, rsetId) + return err +} + +// GetTags converts echo context to params. +func (w *ServerInterfaceWrapper) GetTags(ctx echo.Context) error { + var err error + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetTags(ctx) + return err +} + +// GetTag converts echo context to params. +func (w *ServerInterfaceWrapper) GetTag(ctx echo.Context) error { + var err error + // ------------- Path parameter "tag_id" ------------- + var tagId int + + err = runtime.BindStyledParameterWithOptions("simple", "tag_id", ctx.Param("tag_id"), &tagId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter tag_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetTag(ctx, tagId) + return err +} + +// GetTagNodes converts echo context to params. +func (w *ServerInterfaceWrapper) GetTagNodes(ctx echo.Context) error { + var err error + // ------------- Path parameter "tag_id" ------------- + var tagId int + + err = runtime.BindStyledParameterWithOptions("simple", "tag_id", ctx.Param("tag_id"), &tagId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter tag_id: %s", err)) + } + + ctx.Set(BasicAuthScopes, []string{}) + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetTagNodes(ctx, tagId) + return err +} + // GetVersion converts echo context to params. func (w *ServerInterfaceWrapper) GetVersion(ctx echo.Context) error { var err error @@ -78,6 +381,18 @@ func RegisterHandlersWithBaseURL(router EchoRouter, si ServerInterface, baseURL } router.GET(baseURL+"/docs/openapi", wrapper.GetSwagger) + router.GET(baseURL+"/nodes/:node_id/compliance/candidate_modulesets", wrapper.GetNodeComplianceCandidateModulesets) + router.GET(baseURL+"/nodes/:node_id/compliance/candidate_rulesets", wrapper.GetNodeComplianceCandidateRulesets) + router.GET(baseURL+"/nodes/:node_id/compliance/logs", wrapper.GetNodeComplianceLogs) + router.GET(baseURL+"/nodes/:node_id/compliance/modulesets", wrapper.GetNodeComplianceModulesets) + router.DELETE(baseURL+"/nodes/:node_id/compliance/modulesets/:mset_id", wrapper.DeleteNodeComplianceModuleset) + router.POST(baseURL+"/nodes/:node_id/compliance/modulesets/:mset_id", wrapper.PostNodeComplianceModuleset) + router.GET(baseURL+"/nodes/:node_id/compliance/rulesets", wrapper.GetNodeComplianceRulesets) + router.DELETE(baseURL+"/nodes/:node_id/compliance/rulesets/:rset_id", wrapper.DeleteNodeComplianceRuleset) + router.POST(baseURL+"/nodes/:node_id/compliance/rulesets/:rset_id", wrapper.PostNodeComplianceRuleset) + router.GET(baseURL+"/tags", wrapper.GetTags) + router.GET(baseURL+"/tags/:tag_id", wrapper.GetTag) + router.GET(baseURL+"/tags/:tag_id/nodes", wrapper.GetTagNodes) router.GET(baseURL+"/version", wrapper.GetVersion) } @@ -85,19 +400,29 @@ func RegisterHandlersWithBaseURL(router EchoRouter, si ServerInterface, baseURL // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/7xU34/bRBD+V0YDD61k7KQpIJmnCjh6gLiKpPDQ5GG9nthb2bvL7Pi4a5T/HY19yeV+", - "tQKkPmWz8818n7+ZnR3a0MfgyUvCcodMKQafaPzzcjbXHxu8kBc9mhg7Z4244Iv3KXi9S7al3ujpS6Yt", - "lvhFcVuzmKKpeMOh6qjH/X6fYU3JsotaBkt8680gbWD3gWrcZ/hytvgctGeBK1fX5JXz69nsc3CeeyH2", - "poMl8SUx/MgcGBV3k6y1D/nlDiOHSCxu6kdNYlw3nU6rvoJ26I3/isnUpuoI6Cp2xo/aIUWybussSABp", - "XYJg7cBM3hKELUhLax8nxnztMUO5joQlJmHnG/UmiZEhPaRdtQSvV6s3MAHAhprg2bvfz77/9sVivslg", - "SXaU8M1zaMgTG6EaquuJM7BrnIc0GbEN/IQ6eEyc80INsaoTJx095klqA0t235o09L3h63vFQevmAOcC", - "y9cXb3/9Ye1/u1iBbY1vCLYc+lNhEp6WmQFdWYqy9vpJceAYEiUFdcGazn2YuvKM8ibPYEjON5pqrLhL", - "gpv5W3tPTRA3Yr+DRASP2LrIXz5/tGX7DJn+GhxTjeW7w9gcG3nwbHNMDNV7sqJuXhInNw373dk7CdCV", - "6aN6jrN8ls8/yX9Ifcinw0V2YCfXS53/iaoyydlXg7THJ6c54+0tVysSVXBFhokP6OnfWeDeCJb4858r", - "zE5KjNH7NVSF89ug+TfDhCGST5cWbOg6shIYTHR4Yg/O81k+UwEK1WCJi/Eqw2ikHT+kqINNxRGww4bG", - "1aK+jq09r7HEn0iWf5umGZXd2b8v/uVSuu/ug/Vz8cu0X+dPrbAjfaGg2138KeziZId+HKsgFSamSToc", - "cag6Z3Gjd8XJkN1YdVc/kwzstRdwgGYP3fzjGPpfbn5sxR/Yn/T4P3lhTTSV69z45jb76XnoetToDgfu", - "sMQi2EWh87Tf7P8JAAD//00dve6+BwAA", + "H4sIAAAAAAAC/+xZb2/bthP+KgR/vxctoFlOk22A96pNljZbmwaOu71IjIKmzhILiVSPpzSu4e8+kLJl", + "2ZYdJx6cJtirxNKJ9+d5eM9JHHNpstxo0GR5Z8xzgSIDAvS/lL4QlHywQGeR+x2BlahyUkbzDj87YWbI", + "KAGWmahIwQLxgCt3KxeU8IBrkQHv8MwCfVYRDzjC10IhRLxDWEDArUwgE25pGuXO1BIqHfPJJJg6PzcR", + "bHauTQTNft2dh/rt3pk0bkoZH5DyxBnb3GgLvvpH7bb7I40m0OT+FXmeKilcKOEX6+IZ19b7P8KQd/j/", + "wjmkYXnXhhdoBilkpZfFjN6IiHXhawGW+CTgR+2DfXj9pEVBiUH1HaLS7eE+3J4aHKgoAl36PNqHz3ND", + "7NQU2uf5835APdMEqEXKLgFvANnviAY9vacPu7Vnz7t9jyYHJFUyLwISKl0l/2uWFJnQPyGISAxSYHCb", + "p0L72JnNQaqhkowMo0RZZqQsEEFLmG6Za52XHlvXmgfL/A+4JUGFXXXbS4C96/UuWGnApImAvbjqnh7/", + "+urwoB+wS5A+hF9eshg0oCCI2GBU+jSoYqWZLQsxNLgmOtYUnNIEMaCLjhSl0FQTmxikYLk0tsgygaOl", + "xZlbt8XYGbHLdx8/vT+51ucfe0wmQsfAhmiyemBk1ocZMLiVkNO1dinlBebGgnVGqZEiVd9LVF5AK24F", + "rLBKx+5RIUndAJvy71priA0pb/sbswCsoayHraOXjZBN6u3takabCshZzfrVg2bwBaRvMzeAVpVkX+Re", + "7Qbciix3NeftVrt1cKf/2aOr/hy5QBaoaHTp+F+6Ggir5OuCkmrLuWf81bmvhCh3AQ9AIODMuvx1ajAT", + "xDv8j797PKgt4e8ur1GKy9D43l+SiZsctL2RTJo0BUkGmcgVr5WHH7TarbYLwJm6mx1+6C8FXnJ8ImFk", + "pA0rgzGPwbcWV1cPrVMy/hbo8puIYx/ZgtK8umdTWq7uSvv5+GdNSppaWOU+dEbz/n+X7WGth262dUYu", + "MBKxdeTIi0GqJO+7a6GbDGw4ng4IE79CqoSWEEqhIxUJgs/VVGNrJV3M8y0Qqx6Yj0HWtxnhRxO/b1aA", + "cGPNceX0eLbEh7nLYGEWu2pOd24SLoxLk/6OCCuCzDZAXVFaIIpRE/TvlSXXhZrKUhPdZYmM3CBHbLik", + "k2tUbdrMYa5q093tK1Xb11f9STBe2LtXfVecGSuqfXcfYuD9aIE7kqL7/CiBz4EQqYnvoEBly5ztPfF/", + "75Z/qpAfN6Revaw9ZdC3FIWaFAgiIROI3DS3PfrPSAmqAjwTIZinEY6nnzYmZYgpUMMbwgm4/JmY518O", + "+mvJcOIXWsOHHekQbGk//ejTQJ8G1KrgmC2kBGuHRZqOWAQl8JvRNlgrzGNDH/Dc2IY9/ZpWMNy0nS+M", + "pR8IP/9x542JRrtO+YsfsyYr3Hi1iRuzPlDSob3NuN/+YV4j9tBXthorcTdZ6T47UXkWoyRWgoL3ERR8", + "sJx099qMuvcRk+5OUoJPRUjwQTLyiLg9poh0/5OQeR8pr6/RCN8f4VZZUjpm3rRBFHrl9Ufq4eXXwUer", + "XTgmEc96bGMRT5TNUzFy5WPTr9OjNWVc3YnrzgvJGzecFZbRbHNUWJ1G7CzAW3/BPdpmLxz9GHiWCrtx", + "fPIW7FsCCDNQmJoPU2tAPvfrPk2kdx+1pjVTlJQnZy69dTLcE3Gj9O6XFbUzpEYmIFCBmolcsZlpA/B/", + "Vbd2AmDTCe7M+7/TJKtyiFwMVKr8kVp/UlbWzTglbQtMeYeHRh6GIld80p/8EwAA//9aFG5UBiIAAA==", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/server/codegen_type_gen.go b/server/codegen_type_gen.go index dc6f0a0..efc6072 100644 --- a/server/codegen_type_gen.go +++ b/server/codegen_type_gen.go @@ -3,6 +3,11 @@ // Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT. package server +const ( + BasicAuthScopes = "basicAuth.Scopes" + BearerAuthScopes = "bearerAuth.Scopes" +) + // Problem defines model for Problem. type Problem struct { // Detail A human-readable explanation specific to this occurrence of the @@ -25,11 +30,38 @@ type Version struct { Version string `json:"version"` } +// InPathMsetId defines model for inPathMsetId. +type InPathMsetId = string + +// InPathNodeId defines model for inPathNodeId. +type InPathNodeId = string + +// InPathRsetId defines model for inPathRsetId. +type InPathRsetId = string + +// N400 defines model for 400. +type N400 = Problem + // N401 defines model for 401. type N401 = Problem // N403 defines model for 403. type N403 = Problem +// N404 defines model for 404. +type N404 = Problem + // N500 defines model for 500. type N500 = Problem + +// PostNodeComplianceModulesetJSONBody defines parameters for PostNodeComplianceModuleset. +type PostNodeComplianceModulesetJSONBody = map[string]interface{} + +// PostNodeComplianceRulesetJSONBody defines parameters for PostNodeComplianceRuleset. +type PostNodeComplianceRulesetJSONBody = map[string]interface{} + +// PostNodeComplianceModulesetJSONRequestBody defines body for PostNodeComplianceModuleset for application/json ContentType. +type PostNodeComplianceModulesetJSONRequestBody = PostNodeComplianceModulesetJSONBody + +// PostNodeComplianceRulesetJSONRequestBody defines body for PostNodeComplianceRuleset for application/json ContentType. +type PostNodeComplianceRulesetJSONRequestBody = PostNodeComplianceRulesetJSONBody diff --git a/server/handlers/delete_node_compliance_moduleset.go b/server/handlers/delete_node_compliance_moduleset.go new file mode 100644 index 0000000..885b4ef --- /dev/null +++ b/server/handlers/delete_node_compliance_moduleset.go @@ -0,0 +1,74 @@ +package serverhandlers + +import ( + "context" + "fmt" + "net/http" + + "github.com/labstack/echo/v4" +) + +// DeleteNodeComplianceModuleset handles DELETE /nodes/{node_id}/compliance/modulesets/{mset_id} +func (a *Api) DeleteNodeComplianceModuleset(c echo.Context, nodeId string, msetId string) error { + log := getLog(c) + odb := a.cdbSession() + ctx := c.Request().Context() + odb.CreateTx(ctx, nil) + ctx, cancel := context.WithTimeout(ctx, a.SyncTimeout) + defer cancel() + + var success bool + + defer func() { + if success { + odb.Commit() + } else { + odb.Rollback() + } + }() + + log.Info("DeleteNodeComplianceModuleset called", "node_id", nodeId, "mset_id", msetId) + + responsible, err := odb.NodeResponsible(ctx, nodeId, UserGroupsFromContext(c), IsManager(c)) + if err != nil { + log.Error("DeleteNodeComplianceModuleset: cannot check if user is responsible for the node", "node_id", nodeId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot check if user is responsible for node %s", nodeId) + } + if !responsible { + log.Info("DeleteNodeComplianceModuleset: user is not responsible for this node", "node_id", nodeId) + return JSONProblemf(c, http.StatusForbidden, "Forbidden", "user is not responsible for node %s", nodeId) + } + + // get moduleset name + _, err = odb.CompModulesetName(ctx, msetId) + if err != nil { + log.Error("DeleteNodeComplianceModuleset: cannot find moduleset", "mset_id", msetId, "error", err) + return JSONProblemf(c, http.StatusNotFound, "NotFound", "moduleset %s not found", msetId) + } + + // check if the moduleset is attached to the node + attached, err := odb.CompModulesetAttached(ctx, nodeId, msetId) + if err != nil { + log.Error("DeleteNodeComplianceModuleset: cannot check if moduleset is attached", "node_id", nodeId, "mset_id", msetId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot check if moduleset %s is attached to node %s", msetId, nodeId) + } + if !attached { + log.Info("DeleteNodeComplianceModuleset: moduleset is not attached to this node", "node_id", nodeId, "mset_id", msetId) + return JSONProblemf(c, http.StatusConflict, "Conflict", "moduleset %s is not attached to this node", msetId) + } + + // detach moduleset from node + _, err = odb.CompModulesetDetachNode(ctx, nodeId, []string{msetId}) + if err != nil { + log.Error("DeleteNodeComplianceModuleset: cannot detach moduleset from node", "node_id", nodeId, "mset_id", msetId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot detach moduleset %s from node %s", msetId, nodeId) + } + + success = true + + response := map[string]string{ + "info": fmt.Sprintf("moduleset %s detached from node %s", msetId, nodeId), + } + + return c.JSON(http.StatusAccepted, response) +} diff --git a/server/handlers/delete_node_compliance_ruleset.go b/server/handlers/delete_node_compliance_ruleset.go new file mode 100644 index 0000000..9576eb1 --- /dev/null +++ b/server/handlers/delete_node_compliance_ruleset.go @@ -0,0 +1,78 @@ +package serverhandlers + +import ( + "context" + "fmt" + "net/http" + + "github.com/labstack/echo/v4" +) + +// DeleteNodeComplianceRuleset handles DELETE /nodes/{node_id}/compliance/rulesets/{rset_id} +func (a *Api) DeleteNodeComplianceRuleset(c echo.Context, nodeId string, rsetId string) error { + log := getLog(c) + odb := a.cdbSession() + ctx := c.Request().Context() + odb.CreateTx(ctx, nil) + ctx, cancel := context.WithTimeout(ctx, a.SyncTimeout) + defer cancel() + + var success bool + + defer func() { + if success { + odb.Commit() + } else { + odb.Rollback() + } + }() + + log.Info("DeleteNodeComplianceRuleset called", "node_id", nodeId, "rset_id", rsetId) + + responsible, err := odb.NodeResponsible(ctx, nodeId, UserGroupsFromContext(c), IsManager(c)) + if err != nil { + log.Error("DeleteNodeComplianceRuleset: cannot check if user is responsible for the node", "node_id", nodeId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot check if user is responsible for node %s", nodeId) + } + if !responsible { + log.Info("DeleteNodeComplianceRuleset: user is not responsible for this node", "node_id", nodeId) + return JSONProblemf(c, http.StatusForbidden, "Forbidden", "user is not responsible for node %s", nodeId) + } + + // get ruleset name + rset, err := odb.CompRulesetName(ctx, rsetId) + if err != nil { + log.Error("PostNodeComplianceRuleset: cannot find ruleset", "rset_id", rsetId, "error", err) + return JSONProblemf(c, http.StatusNotFound, "NotFound", "ruleset %s not found", rsetId) + } else { + log.Info("Detaching ruleset from node", "ruleset", rset, "node_id", nodeId) + } + + // check if the ruleset is attached to the node + attached, err := odb.CompRulesetAttached(ctx, nodeId, rsetId) + if err != nil { + log.Error("DeleteNodeComplianceRuleset: cannot check if ruleset is attached", "node_id", nodeId, "rset_id", rsetId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot check if ruleset %s is attached to node %s", rsetId, nodeId) + } + if !attached { + log.Info("DeleteNodeComplianceRuleset: ruleset is not attached to this node", "node_id", nodeId, "rset_id", rsetId) + return JSONProblemf(c, http.StatusConflict, "Conflict", "ruleset %s is not attached to this node", rsetId) + } else { + log.Info("DeleteNodeComplianceRuleset: ruleset is attached to this node, proceeding to detach", "node_id", nodeId, "rset_id", rsetId) + } + + // detach ruleset from node + _, err = odb.CompRulesetDetachNode(c.Request().Context(), nodeId, []string{rsetId}) + if err != nil { + log.Error("DeleteNodeComplianceRuleset: cannot detach ruleset from node", "node_id", nodeId, "rset_id", rsetId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot detach ruleset %s from node %s", rsetId, nodeId) + } + + success = true + + response := map[string]string{ + "info": fmt.Sprintf("ruleset %s detached from node %s", rsetId, nodeId), + } + + return c.JSON(http.StatusAccepted, response) +} diff --git a/server/handlers/get_node_compliance_candidate_modulesets.go b/server/handlers/get_node_compliance_candidate_modulesets.go new file mode 100644 index 0000000..1ff1423 --- /dev/null +++ b/server/handlers/get_node_compliance_candidate_modulesets.go @@ -0,0 +1,41 @@ +package serverhandlers + +import ( + "net/http" + + "github.com/labstack/echo/v4" +) + +// GetNodeComplianceCandidateModulesets handles GET /nodes/{node_id}/compliance/candidate_modulesets +func (a *Api) GetNodeComplianceCandidateModulesets(c echo.Context, nodeId string) error { + log := getLog(c) + odb := a.cdbSession() + ctx := c.Request().Context() + + log.Info("GetNodeComplianceCandidateModulesets called", "node_id", nodeId) + + // get node ID + node, err := odb.NodeByNodeIDOrNodename(ctx, nodeId) + if err != nil { + log.Error("GetNodeComplianceCandidateModulesets: cannot find node", "node", nodeId, "error", err) + return JSONProblemf(c, http.StatusNotFound, "NotFound", "node %s not found", nodeId) + } + + // get modulesets already attached to the node + attachedModulesets, err := odb.CompNodeModulesets(ctx, node.NodeID) + if err != nil { + log.Error("GetNodeComplianceCandidateModulesets: cannot get attached modulesets", "node_id", node.NodeID, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot get attached modulesets for node %s", node.NodeID) + } + + // get candidate modulesets + groups := UserGroupsFromContext(c) + isManager := IsManager(c) + candidates, err := odb.CompNodeCandidateModulesets(ctx, node.NodeID, attachedModulesets, groups, isManager) + if err != nil { + log.Error("GetNodeComplianceCandidateModulesets: cannot get candidate modulesets", "node_id", node.NodeID, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot get candidate modulesets for node %s", node.NodeID) + } + + return c.JSON(http.StatusOK, candidates) +} diff --git a/server/handlers/get_node_compliance_candidate_rulesets.go b/server/handlers/get_node_compliance_candidate_rulesets.go new file mode 100644 index 0000000..b43432a --- /dev/null +++ b/server/handlers/get_node_compliance_candidate_rulesets.go @@ -0,0 +1,42 @@ +package serverhandlers + +import ( + "net/http" + + "github.com/labstack/echo/v4" +) + +// GetNodeComplianceCandidateRulesets handles GET /nodes/{node_id}/compliance/candidate_rulesets +func (a *Api) GetNodeComplianceCandidateRulesets(c echo.Context, nodeId string) error { + log := getLog(c) + odb := a.cdbSession() + ctx := c.Request().Context() + + log.Info("GetNodeComplianceCandidateRulesets called", "node_id", nodeId) + + // get node ID + node, err := a.cdbSession().NodeByNodeIDOrNodename(c.Request().Context(), nodeId) + if err != nil { + log.Error("GetNodeComplianceCandidateRulesets: cannot find node", "node", nodeId, "error", err) + return JSONProblemf(c, 404, "NotFound", "node %s not found", nodeId) + } + + // get rulesets already attached to the node + attachedRulesets, err := odb.CompNodeRulesets(ctx, node.NodeID) + if err != nil { + log.Error("GetNodeComplianceCandidateRulesets: cannot get attached rulesets", "node_id", node.NodeID, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot get attached rulesets for node %s", node.NodeID) + } + + // get candidate rulesets + groups := UserGroupsFromContext(c) + isManager := IsManager(c) + candidates, err := odb.CompNodeCandidateRulesets(ctx, node.NodeID, attachedRulesets, groups, isManager) + if err != nil { + log.Error("GetNodeComplianceCandidateRulesets: cannot get candidate rulesets", "node_id", node.NodeID, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot get candidate rulesets for node %s", node.NodeID) + } + + return c.JSON(http.StatusOK, candidates) + +} diff --git a/server/handlers/get_node_compliance_logs.go b/server/handlers/get_node_compliance_logs.go new file mode 100644 index 0000000..3a8fdac --- /dev/null +++ b/server/handlers/get_node_compliance_logs.go @@ -0,0 +1,16 @@ +package serverhandlers + +import ( + "github.com/labstack/echo/v4" +) + +// GetNodeComplianceLogs handles GET /nodes/{node_id}/compliance/logs +func (a *Api) GetNodeComplianceLogs(c echo.Context, nodeId string) error { + log := getLog(c) + + log.Info("GetNodeComplianceLogs called", "node_id", nodeId) + + // TODO + + return c.JSON(200, []any{}) +} diff --git a/server/handlers/get_node_compliance_modulesets.go b/server/handlers/get_node_compliance_modulesets.go new file mode 100644 index 0000000..fbdb8c4 --- /dev/null +++ b/server/handlers/get_node_compliance_modulesets.go @@ -0,0 +1,34 @@ +package serverhandlers + +import ( + "net/http" + + "github.com/labstack/echo/v4" +) + +// GetNodeComplianceModulesets handles GET /nodes/{node_id}/compliance/modulesets +func (a *Api) GetNodeComplianceModulesets(c echo.Context, nodeId string) error { + log := getLog(c) + odb := a.cdbSession() + ctx := c.Request().Context() + + log.Info("GetNodeComplianceModulesets called", "node_id", nodeId) + + // get node ID + node, err := odb.NodeByNodeIDOrNodename(ctx, nodeId) + if err != nil { + log.Error("GetNodeComplianceModulesets: cannot find node", "node", nodeId, "error", err) + return JSONProblemf(c, http.StatusNotFound, "NotFound", "node %s not found", nodeId) + } + + // get attached modulesets with details + groups := UserGroupsFromContext(c) + isManager := IsManager(c) + modulesets, err := odb.CompNodeAttachedModulesets(ctx, node.NodeID, groups, isManager) + if err != nil { + log.Error("GetNodeComplianceModulesets: cannot get attached modulesets", "node_id", node.NodeID, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot get attached modulesets for node %s", node.NodeID) + } + + return c.JSON(http.StatusOK, modulesets) +} diff --git a/server/handlers/get_node_compliance_rulesets.go b/server/handlers/get_node_compliance_rulesets.go new file mode 100644 index 0000000..d7c1e26 --- /dev/null +++ b/server/handlers/get_node_compliance_rulesets.go @@ -0,0 +1,34 @@ +package serverhandlers + +import ( + "net/http" + + "github.com/labstack/echo/v4" +) + +// GetNodeComplianceRulesets handles GET /nodes/{node_id}/compliance/rulesets +func (a *Api) GetNodeComplianceRulesets(c echo.Context, nodeId string) error { + log := getLog(c) + odb := a.cdbSession() + ctx := c.Request().Context() + + log.Info("GetNodeComplianceRulesets called", "node_id", nodeId) + + // get node ID + node, err := odb.NodeByNodeIDOrNodename(ctx, nodeId) + if err != nil { + log.Error("GetNodeComplianceRulesets: cannot find node", "node", nodeId, "error", err) + return JSONProblemf(c, http.StatusNotFound, "NotFound", "node %s not found", nodeId) + } + + // get attached rulesets with details + groups := UserGroupsFromContext(c) + isManager := IsManager(c) + rulesets, err := odb.CompNodeAttachedRulesets(ctx, node.NodeID, groups, isManager) + if err != nil { + log.Error("GetNodeComplianceRulesets: cannot get attached rulesets", "node_id", node.NodeID, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot get attached rulesets for node %s", node.NodeID) + } + + return c.JSON(http.StatusOK, rulesets) +} diff --git a/server/handlers/get_tag.go b/server/handlers/get_tag.go new file mode 100644 index 0000000..88db96a --- /dev/null +++ b/server/handlers/get_tag.go @@ -0,0 +1,12 @@ +package serverhandlers + +import ( + "github.com/labstack/echo/v4" +) + +// GetTag handles GET /tags/{tag_id} +func (a *Api) GetTag(c echo.Context, tagIdParam int) error { + log := getLog(c) + log.Info("GetTag called", "tag_id", tagIdParam) + return a.handleGetTags(c, &tagIdParam) +} diff --git a/server/handlers/get_tag_nodes.go b/server/handlers/get_tag_nodes.go new file mode 100644 index 0000000..7f6be62 --- /dev/null +++ b/server/handlers/get_tag_nodes.go @@ -0,0 +1,16 @@ +package serverhandlers + +import ( + "github.com/labstack/echo/v4" +) + +// GetTagNodes handles GET /tags/{tag_id}/nodes +func (a *Api) GetTagNodes(c echo.Context, tagIdParam int) error { + log := getLog(c) + + log.Info("GetTagNodes called", "tag_id", tagIdParam) + + // TODO + + return c.JSON(200, []interface{}{}) +} diff --git a/server/handlers/get_tags.go b/server/handlers/get_tags.go new file mode 100644 index 0000000..ceec19a --- /dev/null +++ b/server/handlers/get_tags.go @@ -0,0 +1,38 @@ +package serverhandlers + +import ( + "net/http" + + "github.com/labstack/echo/v4" +) + +// handleGetTags is the common logic for getting tags +func (a *Api) handleGetTags(c echo.Context, tagID *int) error { + log := getLog(c) + odb := a.cdbSession() + ctx := c.Request().Context() + + tags, err := odb.GetTags(ctx, tagID) + if err != nil { + log.Error("handleGetTags: cannot get tags", "tag_id", tagID, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot get tags") + } + + if tagID != nil { + // Single tag requested + if len(tags) == 0 { + return JSONProblemf(c, http.StatusNotFound, "NotFound", "tag %d not found", *tagID) + } + return c.JSON(http.StatusOK, tags[0]) + } + + // All tags requested + return c.JSON(http.StatusOK, tags) +} + +// GetTags handles GET /tags +func (a *Api) GetTags(c echo.Context) error { + log := getLog(c) + log.Info("GetTags called") + return a.handleGetTags(c, nil) +} diff --git a/server/handlers/log.go b/server/handlers/log.go index 35442a4..5ba32b7 100644 --- a/server/handlers/log.go +++ b/server/handlers/log.go @@ -2,12 +2,17 @@ package serverhandlers import ( "log/slog" + + "github.com/labstack/echo/v4" ) -func log(args ...any) { - slog.Info("server", args...) -} +var ( + defaultLogger = slog.Default() +) -func logErr(args ...any) { - slog.Error("server", args...) +func getLog(c echo.Context) *slog.Logger { + if l, ok := c.Get("logger").(*slog.Logger); ok { + return l + } + return defaultLogger } diff --git a/server/handlers/main.go b/server/handlers/main.go index c85abf4..0f5e5ae 100644 --- a/server/handlers/main.go +++ b/server/handlers/main.go @@ -6,6 +6,7 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/go-redis/redis/v8" + "github.com/opensvc/oc3/cdb" "github.com/opensvc/oc3/server" ) @@ -17,6 +18,10 @@ type ( // SyncTimeout is the timeout for synchronous api calls SyncTimeout time.Duration + + Ev interface { + EventPublish(eventName string, data map[string]any) error + } } ) @@ -24,6 +29,12 @@ var ( SCHEMA openapi3.T ) +func (a *Api) cdbSession() *cdb.DB { + odb := cdb.New(a.DB) + odb.CreateSession(a.Ev) + return odb +} + func init() { if schema, err := server.GetSwagger(); err == nil { SCHEMA = *schema diff --git a/server/handlers/middleware.go b/server/handlers/middleware.go index 14fe638..e6370c7 100644 --- a/server/handlers/middleware.go +++ b/server/handlers/middleware.go @@ -11,9 +11,8 @@ import ( ) const ( - XClusterID = "XClusterID" - XNodeID = "XNodeID" - XNodename = "XNodename" + XUserID = "XUserID" + XUserEmail = "XUserEmail" ) // AuthMiddleware returns auth middleware that authenticate requests from strategies. @@ -21,65 +20,57 @@ func AuthMiddleware(strategies union.Union) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { _, user, err := strategies.AuthenticateRequest(c.Request()) + if err != nil { code := http.StatusUnauthorized return JSONProblem(c, code, http.StatusText(code), err.Error()) } - ext := user.GetExtensions() - if nodeID := ext.Get(xauth.XNodeID); nodeID != "" { - // request user is a node, sets node ID in echo context - c.Set(XNodeID, nodeID) - - if nodename := ext.Get(xauth.XNodename); nodename != "" { - c.Set(XNodename, nodename) - } - if clusterID := ext.Get(xauth.XClusterID); clusterID != "" { - // request user is a node with a cluster ID, sets cluster ID in echo context - c.Set(XClusterID, clusterID) - } + ext := user.GetExtensions() + if userEmail := ext.Get(xauth.XUserEmail); userEmail != "" { + c.Set(XUserEmail, userEmail) } + groups := user.GetGroups() + c.Set("groups", groups) c.Set("user", user) + return next(c) } } } -// nodeIDFromContext returns the nodeID from context or zero string -// if not found. -func nodeIDFromContext(c echo.Context) string { - user, ok := c.Get(XNodeID).(string) - if ok { - return user - } - return "" -} +// // userEmailFromContext returns the userEmail from context +// func userEmailFromContext(c echo.Context) string { +// user, ok := c.Get(XUserEmail).(string) +// if ok { +// return user +// } +// return "" +// } -// nodenameFromContext returns the nodename from context or zero string -// if not found. -func nodenameFromContext(c echo.Context) string { - nodename, ok := c.Get(XNodename).(string) +func UserInfoFromContext(c echo.Context) auth.Info { + user, ok := c.Get("user").(auth.Info) if ok { - return nodename + return user } - return "" + return nil } -// clusterIDFromContext returns the clusterID from context or zero string -// if not found. -func clusterIDFromContext(c echo.Context) string { - s, ok := c.Get(XClusterID).(string) +func UserGroupsFromContext(c echo.Context) []string { + groups, ok := c.Get("groups").([]string) if ok { - return s + return groups } - return "" + return nil } -func userInfoFromContext(c echo.Context) auth.Info { - user, ok := c.Get("user").(auth.Info) - if ok { - return user +func IsManager(c echo.Context) bool { + groups := UserGroupsFromContext(c) + for _, g := range groups { + if g == "Manager" { + return true + } } - return nil + return false } diff --git a/server/handlers/post_node_compliance_moduleset.go b/server/handlers/post_node_compliance_moduleset.go new file mode 100644 index 0000000..5ef7ea9 --- /dev/null +++ b/server/handlers/post_node_compliance_moduleset.go @@ -0,0 +1,84 @@ +package serverhandlers + +import ( + "context" + "fmt" + "net/http" + + "github.com/labstack/echo/v4" +) + +// PostNodeComplianceModuleset handles POST /nodes/{node_id}/compliance/modulesets/{mset_id} +func (a *Api) PostNodeComplianceModuleset(c echo.Context, nodeId string, msetId string) error { + log := getLog(c) + odb := a.cdbSession() + ctx := c.Request().Context() + odb.CreateTx(ctx, nil) + ctx, cancel := context.WithTimeout(ctx, a.SyncTimeout) + defer cancel() + + var success bool + + defer func() { + if success { + odb.Commit() + } else { + odb.Rollback() + } + }() + + log.Info("PostNodeComplianceModuleset called", "node_id", nodeId, "mset_id", msetId) + + responsible, err := odb.NodeResponsible(ctx, nodeId, UserGroupsFromContext(c), IsManager(c)) + if err != nil { + log.Error("PostNodeComplianceModuleset: cannot check if user is responsible for the node", "node_id", nodeId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot check if user is responsible for node %s", nodeId) + } + if !responsible { + log.Info("PostNodeComplianceModuleset: user is not responsible for this node", "node_id", nodeId) + return JSONProblemf(c, http.StatusForbidden, "Forbidden", "user is not responsible for node %s", nodeId) + } + + _, err = odb.CompModulesetName(ctx, msetId) + if err != nil { + log.Error("PostNodeComplianceModuleset: cannot find moduleset", "mset_id", msetId, "error", err) + return JSONProblemf(c, http.StatusNotFound, "NotFound", "moduleset %s not found", msetId) + } + + // check if the moduleset is already attached to the node + attached, err := odb.CompModulesetAttached(ctx, nodeId, msetId) + if err != nil { + log.Error("PostNodeComplianceModuleset: cannot check if moduleset is attached", "node_id", nodeId, "mset_id", msetId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot check if moduleset %s is attached to node %s", msetId, nodeId) + } + if attached { + log.Info("PostNodeComplianceModuleset: moduleset is already attached to this node", "node_id", nodeId, "mset_id", msetId) + return JSONProblemf(c, http.StatusConflict, "Conflict", "moduleset %s is already attached to this node", msetId) + } + + // check if the moduleset is attachable to the node + attachable, err := odb.CompModulesetAttachable(ctx, nodeId, msetId) + if err != nil { + log.Error("PostNodeComplianceModuleset: cannot check if moduleset is attachable", "node_id", nodeId, "mset_id", msetId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot check if moduleset %s is attachable to node %s", msetId, nodeId) + } + if !attachable { + log.Info("PostNodeComplianceModuleset: moduleset is not attachable to this node", "node_id", nodeId, "mset_id", msetId) + return JSONProblemf(c, http.StatusForbidden, "Forbidden", "moduleset %s is not attachable to this node", msetId) + } + + // attach moduleset to node + _, err = odb.CompModulesetAttachNode(ctx, nodeId, msetId) + if err != nil { + log.Error("PostNodeComplianceModuleset: cannot attach moduleset to node", "node_id", nodeId, "mset_id", msetId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot attach moduleset %s to node %s", msetId, nodeId) + } + + success = true + + response := map[string]string{ + "info": fmt.Sprintf("moduleset %s attached to node %s", msetId, nodeId), + } + + return c.JSON(http.StatusAccepted, response) +} diff --git a/server/handlers/post_node_compliance_ruleset.go b/server/handlers/post_node_compliance_ruleset.go new file mode 100644 index 0000000..19375c1 --- /dev/null +++ b/server/handlers/post_node_compliance_ruleset.go @@ -0,0 +1,90 @@ +package serverhandlers + +import ( + "context" + "fmt" + "net/http" + + "github.com/labstack/echo/v4" +) + +// PostNodeComplianceRuleset handles POST /nodes/{node_id}/compliance/rulesets/{rset_id} +func (a *Api) PostNodeComplianceRuleset(c echo.Context, nodeId string, rsetId string) error { + log := getLog(c) + odb := a.cdbSession() + ctx := c.Request().Context() + odb.CreateTx(ctx, nil) + ctx, cancel := context.WithTimeout(ctx, a.SyncTimeout) + defer cancel() + + var success bool + + defer func() { + if success { + odb.Commit() + } else { + odb.Rollback() + } + }() + + responsible, err := odb.NodeResponsible(ctx, nodeId, UserGroupsFromContext(c), IsManager(c)) + if err != nil { + log.Error("PostNodeComplianceRuleset: cannot check if user is responsible for the node", "node_id", nodeId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot check if user is responsible for node %s", nodeId) + } + if !responsible { + log.Info("PostNodeComplianceRuleset: user is not responsible for this node", "node_id", nodeId) + return JSONProblemf(c, http.StatusForbidden, "Forbidden", "user is not responsible for node %s", nodeId) + } + + log.Info("PostNodeComplianceRuleset called", "node_id", nodeId, "rset_id", rsetId) + + node, err := odb.NodeByNodeIDOrNodename(ctx, nodeId) + if err != nil { + log.Error("PostNodeComplianceRuleset: cannot find node", "node", nodeId, "error", err) + return JSONProblemf(c, http.StatusNotFound, "NotFound", "node %s not found", nodeId) + } + + rset, err := odb.CompRulesetName(ctx, rsetId) + if err != nil { + log.Error("PostNodeComplianceRuleset: cannot find ruleset", "rset_id", rsetId, "error", err) + return JSONProblemf(c, http.StatusNotFound, "NotFound", "ruleset %s not found", rsetId) + } + + // check if the ruleset is already attached + attached, err := odb.CompRulesetAttached(ctx, node.NodeID, rsetId) + if err != nil { + log.Error("PostNodeComplianceRuleset: cannot check if ruleset is already attached", "node_id", node.NodeID, "rset_id", rsetId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot check if ruleset %s is already attached to node %s", rsetId, node.NodeID) + } + if attached { + log.Info("PostNodeComplianceRuleset: ruleset is already attached to this node", "node_id", node.NodeID, "rset_id", rsetId) + return JSONProblemf(c, http.StatusConflict, "Conflict", "ruleset %s is already attached to this node", rsetId) + } + + // check if the ruleset is attachable to the node + attachable, err := odb.CompRulesetAttachable(ctx, node.NodeID, rsetId) + if err != nil { + log.Error("PostNodeComplianceRuleset: cannot check if ruleset is attachable", "node_id", node.NodeID, "rset_id", rsetId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot check if ruleset %s is attachable to node %s", rsetId, node.NodeID) + } + if !attachable { + log.Info("PostNodeComplianceRuleset: ruleset is not attachable to this node", "node_id", node.NodeID, "rset_id", rsetId) + return JSONProblemf(c, http.StatusForbidden, "Forbidden", "ruleset %s is not attachable to this node", rsetId) + } + + // attach ruleset to node + _, err = odb.CompRulesetAttachNode(ctx, node.NodeID, rsetId) + if err != nil { + log.Error("PostNodeComplianceRuleset: cannot attach ruleset to node", "node_id", node.NodeID, "rset_id", rsetId, "error", err) + return JSONProblemf(c, http.StatusInternalServerError, "InternalError", "cannot attach ruleset %s to node %s", rsetId, node.NodeID) + } + + success = true + + response := map[string]string{ + "info": fmt.Sprintf("ruleset %s(%s) attached", rset, rsetId), + } + + return c.JSON(http.StatusAccepted, response) +} diff --git a/xauth/w2p.go b/xauth/w2p.go new file mode 100644 index 0000000..a770b6a --- /dev/null +++ b/xauth/w2p.go @@ -0,0 +1,216 @@ +package xauth + +import ( + "context" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "database/sql" + "encoding/hex" + "fmt" + "net/http" + "strings" + + "github.com/shaj13/go-guardian/v2/auth" + "github.com/shaj13/go-guardian/v2/auth/strategies/basic" +) + +type ( + authWeb2py struct { + id string + email string + password string + } +) + +var ( + digestAlgBySize = map[int]string{ + 128 / 4: "md5", + 160 / 4: "sha1", + 224 / 4: "sha224", + 256 / 4: "sha256", + 384 / 4: "sha384", + 512 / 4: "sha512", + } +) + +const ( + XUserID string = "id" + XUserEmail string = "email" + XUserPassword string = "password" +) + +const ( + queryAuthWeb2py = `SELECT auth_user.id, auth_user.email, auth_user.password + FROM auth_user + WHERE auth_user.email = ?` + + queryUserGroups = `SELECT auth_group.role + FROM auth_membership + JOIN auth_group ON auth_group.id = auth_membership.group_id + JOIN auth_user ON auth_user.id = auth_membership.user_id + WHERE auth_user.email = ?` +) + +func NewBasicWeb2py(db *sql.DB, hmacKey string) auth.Strategy { + authFunc := func(ctx context.Context, r *http.Request, userName, password string) (auth.Info, error) { + u, err := authenticateWeb2py(ctx, db, userName, password, hmacKey) + if err != nil { + return nil, fmt.Errorf("invalid credentials") + } + return auth.NewUserInfo(userName, u.id, u.Groups(ctx, db, userName), u.extensions()), nil + } + return basic.New(authFunc) +} + +func authenticateWeb2py(ctx context.Context, db *sql.DB, email, password, hmacKey string) (*authWeb2py, error) { + var user authWeb2py + + err := db. + QueryRowContext(ctx, queryAuthWeb2py, email). + Scan(&user.id, &user.email, &user.password) + if err != nil { + return nil, fmt.Errorf("invalid credentials") + } + + if !verifyWeb2pyPassword(password, user.password, hmacKey) { + return nil, fmt.Errorf("invalid credentials") + } + + return &user, nil +} + +func (n *authWeb2py) extensions() auth.Extensions { + ext := make(auth.Extensions) + ext.Set(XUserID, n.id) + ext.Set(XUserEmail, n.email) + ext.Set(XUserPassword, n.password) + return ext +} + +func (n *authWeb2py) Groups(ctx context.Context, db *sql.DB, username string) []string { + rows, err := db.QueryContext(ctx, queryUserGroups, username) + if err != nil { + return []string{} + } + defer rows.Close() + + groups := []string{} + for rows.Next() { + var group string + if err := rows.Scan(&group); err != nil { + continue + } + groups = append(groups, group) + } + return groups +} + +func guessAlg(s string) string { + n := len(s) + alg, ok := digestAlgBySize[n] + if ok { + return alg + } + return "" +} + +func toMD5(b []byte) []byte { + a := md5.Sum(b) + return a[:] +} + +func toHMACSHA512(secret, b []byte) []byte { + h := hmac.New(sha512.New, secret) + h.Write(b) + return h.Sum(nil) +} + +func toSHA512(b []byte) []byte { + a := sha512.Sum512(b) + return a[:] +} + +func toSHA384(b []byte) []byte { + a := sha512.Sum384(b) + return a[:] +} + +func toSHA256(b []byte) []byte { + a := sha256.Sum256(b) + return a[:] +} + +func toSHA224(b []byte) []byte { + a := sha256.Sum224(b) + return a[:] +} + +func toSHA1(b []byte) []byte { + a := sha1.Sum(b) + return a[:] +} + +func verifyWeb2pyPassword(password, storedHash, hmacKey string) bool { + if storedHash == "" { + return false + } + + var alg, salt, prefix string + parts := strings.SplitN(storedHash, "$", 3) + switch len(parts) { + case 3: + alg = parts[0] + salt = parts[1] + prefix = parts[0] + "$" + parts[1] + "$" + case 2: + alg = parts[0] + prefix = parts[0] + "$" + default: + alg = guessAlg(parts[0]) + } + + var digestBytes []byte + if hmacKey != "" { + hmacAlg := "sha512" + keyPart := hmacKey + if strings.Contains(hmacKey, ":") { + keyParts := strings.SplitN(hmacKey, ":", 2) + hmacAlg = keyParts[0] + keyPart = keyParts[1] + } + + secretBytes := []byte(keyPart + salt) + textBytes := []byte(password) + switch hmacAlg { + case "sha512": + digestBytes = toHMACSHA512(secretBytes, textBytes) + default: + return false + } + } else { + text := password + salt + textBytes := []byte(text) + switch alg { + case "sha512": + digestBytes = toSHA512(textBytes) + case "sha384": + digestBytes = toSHA384(textBytes) + case "sha256": + digestBytes = toSHA256(textBytes) + case "sha224": + digestBytes = toSHA224(textBytes) + case "sha1": + digestBytes = toSHA1(textBytes) + case "md5": + digestBytes = toMD5(textBytes) + default: + return false + } + } + + computedHash := prefix + hex.EncodeToString(digestBytes) + return hmac.Equal([]byte(computedHash), []byte(storedHash)) +}