diff --git a/go.mod b/go.mod index 4ec4a35..09a6890 100644 --- a/go.mod +++ b/go.mod @@ -22,9 +22,11 @@ require ( github.com/testcontainers/testcontainers-go/modules/valkey v0.40.0 github.com/valkey-io/valkey-go v1.0.68 github.com/veqryn/slog-context v0.8.0 + github.com/zitadel/oidc/v3 v3.45.0 go.opentelemetry.io/otel v1.38.0 go.opentelemetry.io/otel/metric v1.38.0 go.opentelemetry.io/otel/trace v1.38.0 + golang.org/x/oauth2 v0.31.0 google.golang.org/grpc v1.76.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -62,6 +64,7 @@ require ( github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/swag v0.23.0 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect github.com/hashicorp/go-version v1.7.0 // indirect @@ -88,6 +91,7 @@ require ( github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/morikuni/aec v1.0.0 // indirect + github.com/muhlemmer/gu v0.3.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/oapi-codegen/oapi-codegen/v2 v2.5.0 // indirect github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 // indirect @@ -128,6 +132,8 @@ require ( github.com/veqryn/slog-context/otel v0.8.0 // indirect github.com/vmware-labs/yaml-jsonpath v0.3.2 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect + github.com/zitadel/logging v0.6.2 // indirect + github.com/zitadel/schema v1.3.1 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/collector/featuregate v1.45.0 // indirect go.opentelemetry.io/collector/pdata v1.45.0 // indirect diff --git a/go.sum b/go.sum index 0d15674..9d79d7e 100644 --- a/go.sum +++ b/go.sum @@ -13,7 +13,10 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7D github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bmatcuk/doublestar v1.1.1 h1:YroD6BJCZBYx06yYFEWvUuKVWQn3vLLQAVmDmvTSaiQ= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= +github.com/bmatcuk/doublestar/v4 v4.9.1 h1:X8jg9rRZmJd4yRy7ZeNDRnM+T3ZfHv15JiBJ/avrEXE= +github.com/bmatcuk/doublestar/v4 v4.9.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= @@ -74,6 +77,8 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/getkin/kin-openapi v0.132.0 h1:3ISeLMsQzcb5v26yeJrBcdTCEQTag36ZjaGk7MIRUwk= github.com/getkin/kin-openapi v0.132.0/go.mod h1:3OlG51PCYNsPByuiMB0t4fjnNlIDnaEDsjiKUV8nL58= +github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= +github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -111,9 +116,13 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc h1:GN2Lv3MGO7AS6PrRoT6yV5+wkrOpcszoIsO4+4ds248= github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= @@ -132,6 +141,8 @@ github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jeremija/gosubmit v0.2.8 h1:mmSITBz9JxVtu8eqbN+zmmwX7Ij2RidQxhcwRVI4wqA= +github.com/jeremija/gosubmit v0.2.8/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -188,6 +199,10 @@ github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM= +github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM= +github.com/muhlemmer/httpforwarded v0.1.0 h1:x4DLrzXdliq8mprgUMR0olDvHGkou5BJsK/vWUetyzY= +github.com/muhlemmer/httpforwarded v0.1.0/go.mod h1:yo9czKedo2pdZhoXe+yDkGVbU0TJ0q9oQ90BVoDEtw0= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= @@ -255,6 +270,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= @@ -328,6 +345,12 @@ github.com/vmware-labs/yaml-jsonpath v0.3.2/go.mod h1:U6whw1z03QyqgWdgXxvVnQ90zN github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +github.com/zitadel/logging v0.6.2 h1:MW2kDDR0ieQynPZ0KIZPrh9ote2WkxfBif5QoARDQcU= +github.com/zitadel/logging v0.6.2/go.mod h1:z6VWLWUkJpnNVDSLzrPSQSQyttysKZ6bCRongw0ROK4= +github.com/zitadel/oidc/v3 v3.45.0 h1:SaVJ2kdcJi/zdEWWlAns+81VxmfdYX4E+2mWFVIH7Ec= +github.com/zitadel/oidc/v3 v3.45.0/go.mod h1:UeK0iVOoqfMuDVgSfv56BqTz8YQC2M+tGRIXZ7Ii3VY= +github.com/zitadel/schema v1.3.1 h1:QT3kwiRIRXXLVAs6gCK/u044WmUVh6IlbLXUsn6yRQU= +github.com/zitadel/schema v1.3.1/go.mod h1:071u7D2LQacy1HAN+YnMd/mx1qVE2isb0Mjeqg46xnU= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/collector/featuregate v1.45.0 h1:D06hpf1F2KzKC+qXLmVv5e8IZpgCyZVeVVC8iOQxVmw= @@ -409,6 +432,8 @@ golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM= golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= +golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo= +golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/internal/business/business.go b/internal/business/business.go index 56e77eb..0ce5f4d 100644 --- a/internal/business/business.go +++ b/internal/business/business.go @@ -187,6 +187,10 @@ func initSessionManager(ctx context.Context, cfg *config.Config) (_ *session.Man return nil, nil, errors.New("CSRF secret must be at least 32 bytes") } + clientSecret, err := commoncfg.LoadValueFromSourceRef(cfg.SessionManager.ClientAuth.ClientSecret) + if err != nil { + return nil, nil, fmt.Errorf("loading client secret: %w", err) + } return session.NewManager( oidcProviderRepo, sessionRepo, @@ -197,9 +201,9 @@ func initSessionManager(ctx context.Context, cfg *config.Config) (_ *session.Man cfg.SessionManager.AdditionalAuthContextKeys, cfg.SessionManager.RedirectURI, clientID, + string(clientSecret), httpClient, string(csrfSecret), - cfg.SessionManager.JWSSigAlgs, ), valkeyClient.Close, nil } diff --git a/pkg/session/helper_test.go b/pkg/session/helper_test.go index ff41a22..aec1c8c 100644 --- a/pkg/session/helper_test.go +++ b/pkg/session/helper_test.go @@ -1,10 +1,17 @@ package session_test import ( + "crypto/rand" + "crypto/rsa" "encoding/json" "net/http" "net/http/httptest" "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/stretchr/testify/assert" "github.com/openkcm/session-manager/internal/oidc" "github.com/openkcm/session-manager/pkg/session" @@ -12,6 +19,9 @@ import ( func StartOIDCServer(t *testing.T, fail bool) *httptest.Server { t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + kid := "test-kid" var server *httptest.Server server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if fail { @@ -30,18 +40,46 @@ func StartOIDCServer(t *testing.T, fail bool) *httptest.Server { JwksURI: server.URL + "/.well-known/jwks.json", }) case "/.well-known/jwks.json": - _, _ = w.Write([]byte(`{"keys":[{"kty": "RSA", "e": "AQAB", "use": "sig", "kid": "7cdrxOwDtBcW6ZmoW1CHjx2f74xqS6GAwJXOUd_oECw", "alg": "RS256", "n": "nMds_LftGh9YWfCuKfTU9rHezOPOUzooalZXIXMBnj4Xd7EQieVH4acwIlGQDsy9FasnSUzok4eeuJR1nmz7I5d0qIDjw_SItsFe83KetfFBLPsoCrR3kzcuof8KG3_N7pTGWMyl9cb8QTMzRYgzSrfgMJgi1TCHQq5uE-CWdjaCTklJgvnUb9QjYoyf3CkGz6hjlfu1TPw2CQfVXy0fW5jT9S6d10zYfYXnfeYxZFiKBgv2YNUPtwnejs0mZcE7lLyURf1tgkgZheHNde6Nz8UC0HEbGKBT6I-WXaFUJsmI5GDsQXTNfp6YmdYk_s-rM4bz-Hg51XI0JWk4J2bUyQ"}]}`)) + jwk := jose.JSONWebKey{Key: &priv.PublicKey, KeyID: kid, Algorithm: string(jose.RS256), Use: "sig"} + jwkSet := jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwk}} + jwkSetBytes, err := json.Marshal(jwkSet) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + assert.NoError(t, err) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(jwkSetBytes) + case "/oauth2/token": + now := time.Now() + claims := map[string]any{ + "sub": "jwt-test", + "iss": server.URL, + "aud": []string{"client-id"}, + "iat": now.Unix(), + "exp": now.Add(time.Hour).Unix(), + "nbf": now.Unix(), + "jti": "test-jti", + } + signer, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.RS256, Key: priv}, + (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", kid), + ) + assert.NoError(t, err) + rawJWT, err := jwt.Signed(signer).Claims(claims).Serialize() + assert.NoError(t, err) + w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - tokenResponse := session.TokenResponse{ + _ = json.NewEncoder(w).Encode(session.TokenResponse{ AccessToken: "access-token", RefreshToken: "refresh-token", - IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjdjZHJ4T3dEdEJjVzZabW9XMUNIangyZjc0eHFTNkdBd0pYT1VkX29FQ3ciLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiJqd3QtdGVzdCIsImp0aSI6IjIzNDE0MzUiLCJuYmYiOjE3NjA1NzcwMjgsImV4cCI6MTc2MDU4NzgyOCwiaWF0IjoxNzYwNTc3MDI4LCJpc3MiOiJkYXJ3aW5MYWJzIiwiYXVkIjoiaHR0cDovL3d3dy5kYXJ3aW4tbGFicy5jb20ifQ.TbUiSRxNE-x2NYc_9CkLt59caV_CeOxaaHjbtBekWeKSnYXlZIOqf6qikdVhKwN3IdssUi5af6E2tVEvM4fAZuCGKy7qkHXqvitxm2XLfZPvQzscrN7L476rjUaEr2HcjqoOmhPwcgTfeJJRp9o_JIqvtb-NXhIZPbPBkinTWFIArLfcJ1WZx4fYbXY7nixunJfQqYYtZSP_OukzRbAK5qwPj55USPFhh3IBWrUsS4x_YOiF8PITldLhCYIFNmhI5vkT6KwaWVYAVZPnwLARSW0nZAKnv_qAuhwHbhP8Et746Qw-WF-5K2Ij3YlgsNG-6_c0ID2MwBhoqpg-1sFcug", + IDToken: rawJWT, TokenType: "Bearer", ExpiresIn: 3600, - } - _ = json.NewEncoder(w).Encode(tokenResponse) + }) + default: + http.NotFound(w, r) } })) diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 593114e..ebeee4d 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -1,48 +1,50 @@ package session import ( - "bytes" "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" "encoding/json" - "errors" "fmt" "net/http" "net/url" "strings" + "sync" "time" - "github.com/go-jose/go-jose/v4" - "github.com/go-jose/go-jose/v4/jwt" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/oauth2" otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" slogctx "github.com/veqryn/slog-context" - "github.com/openkcm/session-manager/internal/oidc" - "github.com/openkcm/session-manager/internal/pkce" + oidcprovider "github.com/openkcm/session-manager/internal/oidc" "github.com/openkcm/session-manager/internal/serviceerr" "github.com/openkcm/session-manager/pkg/csrf" ) type Manager struct { - oidc oidc.ProviderRepository + oidc oidcprovider.ProviderRepository sessions Repository - pkce pkce.Source audit *otlpaudit.AuditLogger sessionDuration time.Duration redirectURI string clientID string + clientSecret string secureClient *http.Client getParametersAuth []string getParametersToken []string authContextKeys []string - - csrfSecret []byte - jwsSigAlgs []jose.SignatureAlgorithm + csrfSecret []byte + relyingParty map[string]rp.RelyingParty } func NewManager( - oidc oidc.ProviderRepository, + oidc oidcprovider.ProviderRepository, sessions Repository, auditLogger *otlpaudit.AuditLogger, sessionDuration time.Duration, @@ -51,15 +53,10 @@ func NewManager( authContextKeys []string, redirectURI string, clientID string, + clientSecret string, httpClient *http.Client, csrfHMACSecret string, - jwsSigAlgs []string, ) *Manager { - algs := make([]jose.SignatureAlgorithm, 0, len(jwsSigAlgs)) - for _, alg := range jwsSigAlgs { - algs = append(algs, jose.SignatureAlgorithm(alg)) - } - return &Manager{ oidc: oidc, sessions: sessions, @@ -70,12 +67,41 @@ func NewManager( authContextKeys: authContextKeys, redirectURI: redirectURI, clientID: clientID, + clientSecret: clientSecret, secureClient: httpClient, csrfSecret: []byte(csrfHMACSecret), - jwsSigAlgs: algs, + relyingParty: make(map[string]rp.RelyingParty), } } +var ( + codeVerifierStore = make(map[string]string) + codeVerifierMu sync.Mutex +) + +// getRelyingParty creates or returns cached Zitadel OIDC client +func (m *Manager) getRelyingParty(ctx context.Context, provider oidcprovider.Provider) (rp.RelyingParty, error) { + if rpInst, exists := m.relyingParty[provider.IssuerURL]; exists { + return rpInst, nil + } + + scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, "groups"} + relyingParty, err := rp.NewRelyingPartyOIDC( + ctx, + provider.IssuerURL, + m.clientID, + m.clientSecret, + m.redirectURI, + scopes, + ) + if err != nil { + return nil, fmt.Errorf("creating relying party: %w", err) + } + + m.relyingParty[provider.IssuerURL] = relyingParty + return relyingParty, nil +} + // MakeAuthURI returns an OIDC authentication URI. func (m *Manager) MakeAuthURI(ctx context.Context, tenantID, fingerprint, requestURI string) (string, error) { provider, err := m.oidc.GetForTenant(ctx, tenantID) @@ -83,82 +109,49 @@ func (m *Manager) MakeAuthURI(ctx context.Context, tenantID, fingerprint, reques return "", fmt.Errorf("getting oidc provider: %w", err) } - openidConf, err := m.getOpenIDConfig(ctx, provider) + relyingParty, err := m.getRelyingParty(ctx, provider) if err != nil { return "", fmt.Errorf("getting an openid config: %w", err) } - stateID := m.pkce.State() - pkce := m.pkce.PKCE() - + stateID := generateStateID() state := State{ - ID: stateID, - TenantID: tenantID, - Fingerprint: fingerprint, - PKCEVerifier: pkce.Verifier, - RequestURI: requestURI, - Expiry: time.Now().Add(m.sessionDuration), + ID: stateID, + TenantID: tenantID, + Fingerprint: fingerprint, + RequestURI: requestURI, + Expiry: time.Now().Add(m.sessionDuration), } if err := m.sessions.StoreState(ctx, state); err != nil { return "", fmt.Errorf("storing session: %w", err) } - u, err := m.authURI(openidConf, state, pkce, provider.Properties) - if err != nil { - return "", fmt.Errorf("generating auth uri: %w", err) - } + // Generate code verifier and challenge for PKCE + codeVerifier := generateCodeVerifier() + codeChallenge := generateS256Challenge(codeVerifier) + storeCodeVerifier(stateID, codeVerifier) - return u, nil -} + authURL := rp.AuthURL(stateID, relyingParty, rp.WithCodeChallenge(codeChallenge)) -func (m *Manager) authURI(openidConf oidc.Configuration, state State, pkce pkce.PKCE, properties map[string]string) (string, error) { - u, err := url.Parse(openidConf.AuthorizationEndpoint) + // Add custom parameters from provider properties + u, err := url.Parse(authURL) if err != nil { - return "", fmt.Errorf("parsing authorisation endpoint url: %w", err) + return "", fmt.Errorf("parsing auth url: %w", err) } - q := u.Query() - q.Set("scope", "openid profile email groups") - q.Set("response_type", "code") - q.Set("client_id", m.clientID) - q.Set("state", state.ID) - q.Set("code_challenge", pkce.Challenge) - q.Set("code_challenge_method", pkce.Method) - q.Set("redirect_uri", m.redirectURI) for _, parameter := range m.getParametersAuth { - value, ok := properties[parameter] + value, ok := provider.Properties[parameter] if !ok { return "", fmt.Errorf("missing auth parameter: %s", parameter) } q.Set(parameter, value) } - u.RawQuery = q.Encode() return u.String(), nil } -func (m *Manager) getProviderKeySet(ctx context.Context, oidcConf oidc.Configuration) (*jose.JSONWebKeySet, error) { - var keySet jose.JSONWebKeySet - uri := oidcConf.JwksURI - req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil) - if err != nil { - return nil, fmt.Errorf("creating a new HTTP request: %w", err) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("executing an http request: %w", err) - } - - if err := json.NewDecoder(resp.Body).Decode(&keySet); err != nil { - return nil, fmt.Errorf("decoding keyset response: %w", err) - } - - return &keySet, nil -} - func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerprint string) (OIDCSessionData, error) { state, err := m.sessions.LoadState(ctx, stateID) if err != nil { @@ -180,39 +173,45 @@ func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerpr return OIDCSessionData{}, fmt.Errorf("getting oidc provider: %w", err) } - openidConf, err := m.getOpenIDConfig(ctx, provider) - if err != nil { - return OIDCSessionData{}, fmt.Errorf("getting openid configuration: %w", err) - } - - tokens, err := m.exchangeCode(ctx, openidConf, code, state.PKCEVerifier, provider.Properties) + relyingParty, err := m.getRelyingParty(ctx, provider) if err != nil { - return OIDCSessionData{}, fmt.Errorf("exchanging code for tokens: %w", err) + return OIDCSessionData{}, fmt.Errorf("getting relying party: %w", err) } - slogctx.Info(ctx, "Exchanged the auth code for tokens") + codeVerifier := getCodeVerifier(stateID) - sessionID := m.pkce.SessionID() - csrfToken := csrf.NewToken(sessionID, m.csrfSecret) + codeOpts := make([]rp.CodeExchangeOpt, 0, 1+len(m.getParametersToken)) + codeOpts = append(codeOpts, rp.WithCodeVerifier(codeVerifier)) - token, err := jwt.ParseSigned(tokens.IDToken, m.jwsSigAlgs) - if err != nil { - return OIDCSessionData{}, fmt.Errorf("parsing id token: %w", err) + // Add custom token parameters from provider properties + for _, parameter := range m.getParametersToken { + value, ok := provider.Properties[parameter] + if !ok { + return OIDCSessionData{}, fmt.Errorf("missing token parameter: %s", parameter) + } + codeOpts = append(codeOpts, func(key, val string) rp.CodeExchangeOpt { + return func() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam(key, val)} + } + }(parameter, value)) } - jws, err := jose.ParseSigned(tokens.IDToken, m.jwsSigAlgs) + tokens, err := rp.CodeExchange[*oidc.IDTokenClaims](ctx, code, relyingParty, codeOpts...) if err != nil { - return OIDCSessionData{}, fmt.Errorf("parsing JWS: %w", err) + return OIDCSessionData{}, fmt.Errorf("exchanging code for tokens: %w", err) } - keyset, err := m.getProviderKeySet(ctx, openidConf) + claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](ctx, tokens.AccessToken, tokens.IDToken, relyingParty.IDTokenVerifier()) if err != nil { - return OIDCSessionData{}, fmt.Errorf("getting jwks for a provider: %w", err) + return OIDCSessionData{}, fmt.Errorf("verifying tokens: %w", err) } - var claims jwt.Claims - if err := token.Claims(keyset, &claims); err != nil { - return OIDCSessionData{}, fmt.Errorf("getting JWT claims: %w", err) + sessionID := generateSessionID() + csrfToken := csrf.NewToken(sessionID, m.csrfSecret) + + userInfo, err := rp.Userinfo[*oidc.UserInfo](ctx, tokens.AccessToken, tokens.TokenType, claims.GetSubject(), relyingParty) + if err != nil { + slogctx.Warn(ctx, "Failed to get user info", "error", err) } // prepare the auth context used by ExtAuthZ @@ -234,16 +233,16 @@ func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerpr Fingerprint: fingerprint, CSRFToken: csrfToken, Issuer: provider.IssuerURL, - RawClaims: string(jws.UnsafePayloadWithoutVerification()), + RawClaims: tokens.IDToken, Claims: Claims{ - Subject: claims.Subject, - Email: "", // TODO: extract email from claims - Groups: []string{}, // TODO: extract groups from claims + Subject: claims.GetSubject(), + Email: getEmailFromUserInfo(userInfo), + Groups: getGroupsFromUserInfo(userInfo), }, AccessToken: tokens.AccessToken, RefreshToken: tokens.RefreshToken, - Expiry: time.Now().Add(m.sessionDuration), AuthContext: authContext, + Expiry: time.Now().Add(m.sessionDuration), } if err := m.sessions.StoreSession(ctx, session); err != nil { @@ -261,45 +260,6 @@ func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerpr }, nil } -func (m *Manager) exchangeCode(ctx context.Context, openidConf oidc.Configuration, code, codeVerifier string, properties map[string]string) (tokenResponse, error) { - data := url.Values{} - data.Set("grant_type", "authorization_code") - data.Set("code", code) - data.Set("code_verifier", codeVerifier) - data.Set("redirect_uri", m.redirectURI) - data.Set("client_id", m.clientID) - for _, parameter := range m.getParametersToken { - value, ok := properties[parameter] - if !ok { - return tokenResponse{}, fmt.Errorf("missing token parameter: %s", parameter) - } - data.Set(parameter, value) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, openidConf.TokenEndpoint, strings.NewReader(data.Encode())) - if err != nil { - return tokenResponse{}, fmt.Errorf("creating request: %w", err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := m.secureClient.Do(req) - if err != nil { - return tokenResponse{}, fmt.Errorf("executing request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return tokenResponse{}, fmt.Errorf("token exchange failed with status: %d", resp.StatusCode) - } - - var tokens tokenResponse - if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil { - return tokenResponse{}, fmt.Errorf("decoding response: %w", err) - } - - return tokens, nil -} - func (m *Manager) ValidateCSRFToken(token, sessionID string) bool { return csrf.Validate(token, sessionID, m.csrfSecret) } @@ -330,79 +290,99 @@ func (m *Manager) RefreshExpiringSessions(ctx context.Context) error { return nil } -// RefreshSession refreshes the access token using the given refresh token for the tenant. -func (m *Manager) RefreshSession(ctx context.Context, s *Session, provider oidc.Provider) error { - data := url.Values{} - data.Set("grant_type", "refresh_token") - data.Set("refresh_token", s.RefreshToken) - data.Set("client_id", m.clientID) - - tokenEndpoint, err := url.JoinPath(provider.IssuerURL, "/token") - if err != nil { - return fmt.Errorf("making issuer token path: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, bytes.NewBufferString(data.Encode())) +// RefreshSession using Zitadel library +func (m *Manager) RefreshSession(ctx context.Context, s *Session, provider oidcprovider.Provider) error { + relyingParty, err := m.getRelyingParty(ctx, provider) if err != nil { - return err + return fmt.Errorf("getting relying party: %w", err) } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := m.secureClient.Do(req) + newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](ctx, relyingParty, s.RefreshToken, "", "") if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return errors.New("token endpoint returned non-200 status") - } - - var respData struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` + return fmt.Errorf("refreshing tokens: %w", err) } - s.AccessToken = respData.AccessToken - s.RefreshToken = respData.RefreshToken - s.AccessTokenExpiry = time.Now().Add(time.Duration(respData.ExpiresIn)) + s.AccessToken = newTokens.AccessToken + s.RefreshToken = newTokens.RefreshToken + s.AccessTokenExpiry = newTokens.Expiry return nil } -func (m *Manager) getOpenIDConfig(ctx context.Context, provider oidc.Provider) (oidc.Configuration, error) { - const wellKnownOpenIDConfigPath = "/.well-known/openid-configuration" +func shouldRefresh(s Session) bool { + return time.Until(s.AccessTokenExpiry) < 5*time.Minute +} - u, err := url.JoinPath(provider.IssuerURL, wellKnownOpenIDConfigPath) - if err != nil { - return oidc.Configuration{}, fmt.Errorf("building path to the well-known openid-config endpoint: %w", err) +// Helper functions to implement +func generateStateID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return fmt.Sprintf("state-%d", time.Now().UnixNano()) } + return "state-" + hex.EncodeToString(b) +} - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) - if err != nil { - return oidc.Configuration{}, fmt.Errorf("creating an HTTP request: %w", err) +func generateSessionID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return fmt.Sprintf("session-%d", time.Now().UnixNano()) } + return "session-" + hex.EncodeToString(b) +} - resp, err := m.secureClient.Do(req) - if err != nil { - return oidc.Configuration{}, fmt.Errorf("doing an HTTP request: %w", err) +func generateCodeVerifier() string { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "dummy_verifier" } + return hex.EncodeToString(b) +} - var conf oidc.Configuration - if err := json.NewDecoder(resp.Body).Decode(&conf); err != nil { - return oidc.Configuration{}, fmt.Errorf("decoding a well-known openid config: %w", err) - } +func storeCodeVerifier(stateID, codeVerifier string) { + codeVerifierMu.Lock() + defer codeVerifierMu.Unlock() + codeVerifierStore[stateID] = codeVerifier +} + +func getCodeVerifier(stateID string) string { + codeVerifierMu.Lock() + defer codeVerifierMu.Unlock() + return codeVerifierStore[stateID] +} - // Validate the configuration - if conf.Issuer != provider.IssuerURL { - return oidc.Configuration{}, serviceerr.ErrInvalidOIDCProvider +func getEmailFromUserInfo(userInfo *oidc.UserInfo) string { + if userInfo != nil && userInfo.Email != "" { + return userInfo.Email } + return "" +} - return conf, nil +func generateS256Challenge(verifier string) string { + hash := sha256.Sum256([]byte(verifier)) + return strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=") } -func shouldRefresh(s Session) bool { - // refresh if token expires in less than 5 minutes - return time.Until(s.AccessTokenExpiry) < 5*time.Minute +func getGroupsFromUserInfo(userInfo *oidc.UserInfo) []string { + if userInfo == nil { + return nil + } + // Marshal userInfo to JSON, then unmarshal to map to access custom claims + data, err := json.Marshal(userInfo) + if err != nil { + return nil + } + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { + return nil + } + if groups, ok := raw["groups"].([]interface{}); ok { + result := make([]string, 0, len(groups)) + for _, g := range groups { + if s, ok := g.(string); ok { + result = append(result, s) + } + } + return result + } + return nil } diff --git a/pkg/session/manager_test.go b/pkg/session/manager_test.go index dfcf639..04e8707 100644 --- a/pkg/session/manager_test.go +++ b/pkg/session/manager_test.go @@ -58,6 +58,7 @@ func TestManager_Auth(t *testing.T) { sessions *sessionmock.Repository redirectURI string clientID string + clientSecret string tenantID string fingerprint string requestURI string @@ -74,36 +75,39 @@ func TestManager_Auth(t *testing.T) { sessions: sessionmock.NewInMemRepository(nil, nil, nil, nil, nil), redirectURI: redirectURI, clientID: "my-client-id", + clientSecret: "my-client-secret", tenantID: tenantID, fingerprint: "fingerprint", requestURI: requestURI, getParametersAuth: []string{"paramAuth1"}, - wantURL: oidcServer.URL + "/oauth2/authorize?client_id=my-client-id&code_challenge=someChallenge&code_challenge_method=S256&redirect_uri=" + redirectURI + "&response_type=code&scope=openid+profile+email+groups&state=someState¶mAuth1=paramAuth1", + wantURL: oidcServer.URL + "/oauth2/authorize?client_id=my-client-id&code_challenge=someChallenge&code_challenge_method=S256¶mAuth1=paramAuth1&redirect_uri=" + redirectURI + "&response_type=code&scope=openid+profile+email+groups&state=someState", errAssert: assert.NoError, }, { - name: "Get OIDC error", - oidc: newOIDCRepo(nil, errors.New("faield to get oidc provider"), nil, nil, nil), - sessions: sessionmock.NewInMemRepository(nil, nil, nil, nil, nil), - redirectURI: redirectURI, - clientID: "my-client-id", - tenantID: tenantID, - fingerprint: "fingerprint", - requestURI: requestURI, - wantURL: "", - errAssert: assert.Error, + name: "Get OIDC error", + oidc: newOIDCRepo(nil, errors.New("faield to get oidc provider"), nil, nil, nil), + sessions: sessionmock.NewInMemRepository(nil, nil, nil, nil, nil), + redirectURI: redirectURI, + clientID: "my-client-id", + clientSecret: "my-client-secret", + tenantID: tenantID, + fingerprint: "fingerprint", + requestURI: requestURI, + wantURL: "", + errAssert: assert.Error, }, { - name: "Save state error", - oidc: newOIDCRepo(nil, nil, nil, nil, nil), - sessions: sessionmock.NewInMemRepository(nil, errors.New("failed to save state"), nil, nil, nil), - redirectURI: redirectURI, - clientID: "my-client-id", - tenantID: tenantID, - fingerprint: "fingerprint", - requestURI: requestURI, - wantURL: "", - errAssert: assert.Error, + name: "Save state error", + oidc: newOIDCRepo(nil, nil, nil, nil, nil), + sessions: sessionmock.NewInMemRepository(nil, errors.New("failed to save state"), nil, nil, nil), + redirectURI: redirectURI, + clientID: "my-client-id", + clientSecret: "my-client-secret", + tenantID: tenantID, + fingerprint: "fingerprint", + requestURI: requestURI, + wantURL: "", + errAssert: assert.Error, }, } for _, tt := range tests { @@ -122,7 +126,7 @@ func TestManager_Auth(t *testing.T) { auditLogger, err := otlpaudit.NewLogger(&commoncfg.Audit{Endpoint: auditServer.URL}) require.NoError(t, err) - m := session.NewManager(tt.oidc, tt.sessions, auditLogger, time.Hour, tt.getParametersAuth, tt.getParametersToken, tt.authContextKeys, tt.redirectURI, tt.clientID, http.DefaultClient, testCSRFSecret, []string{"RS256"}) + m := session.NewManager(tt.oidc, tt.sessions, auditLogger, time.Hour, tt.getParametersAuth, tt.getParametersToken, tt.authContextKeys, tt.redirectURI, tt.clientID, tt.clientSecret, http.DefaultClient, testCSRFSecret) got, err := m.MakeAuthURI(t.Context(), tt.tenantID, tt.fingerprint, tt.requestURI) if !tt.errAssert(t, err, fmt.Sprintf("Manager.Auth() error = %v", err)) || err != nil { @@ -344,7 +348,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { IssuerURL: oidcServer.URL, Blocked: false, JWKSURIs: []string{jwksURI}, - Audiences: []string{requestURI}, + Audiences: []string{"client-id"}, Properties: map[string]string{ "getParamToken1": "getParamToken1", "authContextKey1": "authContextValue1", @@ -353,7 +357,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { tt.oidc.Add(tenantID, localOIDCProvider) - m := session.NewManager(tt.oidc, tt.sessions, auditLogger, time.Hour, tt.getParametersAuth, tt.getParametersToken, tt.authContextKeys, redirectURI, "client-id", http.DefaultClient, testCSRFSecret, []string{"RS256"}) + m := session.NewManager(tt.oidc, tt.sessions, auditLogger, time.Hour, tt.getParametersAuth, tt.getParametersToken, tt.authContextKeys, redirectURI, "client-id", "client-secret", http.DefaultClient, testCSRFSecret) result, err := m.FinaliseOIDCLogin(context.Background(), tt.stateID, tt.code, tt.fingerprint)