From 49a1808f7bb042d8287735ae673fd866ed53b582 Mon Sep 17 00:00:00 2001 From: Maliheh Date: Mon, 17 Dec 2018 12:26:12 -0800 Subject: [PATCH 1/5] Proxy shim for simplifying development environment workflow --- Makefile | 6 +- cmd/sso-devproxy/main.go | 48 + internal/devproxy/collector/collector.go | 102 ++ internal/devproxy/dev_config.go | 373 +++++++ internal/devproxy/devproxy.go | 392 ++++++++ internal/devproxy/devproxy_test.go | 933 ++++++++++++++++++ internal/devproxy/logging_handler.go | 131 +++ internal/devproxy/metrics.go | 54 + internal/devproxy/middleware.go | 69 ++ internal/devproxy/options.go | 143 +++ internal/devproxy/providers/http_client.go | 17 + internal/devproxy/providers/internal_util.go | 85 ++ .../devproxy/providers/internal_util_test.go | 132 +++ internal/devproxy/providers/provider_data.go | 31 + .../devproxy/providers/provider_default.go | 146 +++ .../providers/provider_default_test.go | 19 + internal/devproxy/providers/providers.go | 25 + internal/devproxy/providers/session_state.go | 61 ++ .../devproxy/providers/session_state_test.go | 72 ++ .../providers/singleflight_middleware.go | 145 +++ internal/devproxy/providers/sso.go | 385 ++++++++ internal/devproxy/providers/sso_test.go | 550 +++++++++++ internal/devproxy/templates.go | 126 +++ .../devproxy/testdata/upstream_configs.yml | 10 + 24 files changed, 4054 insertions(+), 1 deletion(-) create mode 100644 cmd/sso-devproxy/main.go create mode 100644 internal/devproxy/collector/collector.go create mode 100644 internal/devproxy/dev_config.go create mode 100644 internal/devproxy/devproxy.go create mode 100644 internal/devproxy/devproxy_test.go create mode 100644 internal/devproxy/logging_handler.go create mode 100644 internal/devproxy/metrics.go create mode 100644 internal/devproxy/middleware.go create mode 100644 internal/devproxy/options.go create mode 100644 internal/devproxy/providers/http_client.go create mode 100644 internal/devproxy/providers/internal_util.go create mode 100644 internal/devproxy/providers/internal_util_test.go create mode 100644 internal/devproxy/providers/provider_data.go create mode 100644 internal/devproxy/providers/provider_default.go create mode 100644 internal/devproxy/providers/provider_default_test.go create mode 100644 internal/devproxy/providers/providers.go create mode 100644 internal/devproxy/providers/session_state.go create mode 100644 internal/devproxy/providers/session_state_test.go create mode 100644 internal/devproxy/providers/singleflight_middleware.go create mode 100644 internal/devproxy/providers/sso.go create mode 100644 internal/devproxy/providers/sso_test.go create mode 100644 internal/devproxy/templates.go create mode 100644 internal/devproxy/testdata/upstream_configs.yml diff --git a/Makefile b/Makefile index 182f11b5..f93658f3 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ version := "v1.0.0" commit := $(shell git rev-parse --short HEAD) -build: dist/sso-auth dist/sso-proxy +build: dist/sso-auth dist/sso-proxy dist/sso-devproxy dist/sso-auth: mkdir -p dist @@ -12,6 +12,10 @@ dist/sso-proxy: mkdir -p dist go build -o dist/sso-proxy ./cmd/sso-proxy +dist/sso-devproxy: + mkdir -p dist + go build -o dist/sso-devproxy ./cmd/sso-devproxy + test: ./scripts/test diff --git a/cmd/sso-devproxy/main.go b/cmd/sso-devproxy/main.go new file mode 100644 index 00000000..a0b7ffb1 --- /dev/null +++ b/cmd/sso-devproxy/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "fmt" + "net/http" + "os" + + log "github.com/buzzfeed/sso/internal/pkg/logging" + // "github.com/buzzfeed/sso/internal/pkg/options" + "github.com/buzzfeed/sso/internal/devproxy" + "github.com/kelseyhightower/envconfig" +) + +func init() { + log.SetServiceName("sso-dev-proxy") +} + +func main() { + logger := log.NewLogEntry() + + opts := devproxy.NewOptions() + + err := envconfig.Process("", opts) + if err != nil { + logger.Error(err, "error loading in env vars") + os.Exit(1) + } + + err = opts.Validate() + if err != nil { + logger.Error(err, "error validing options") + os.Exit(1) + } + + proxy, err := devproxy.NewDevProxy(opts) + if err != nil { + logger.Error(err, "error creating devproxy") + os.Exit(1) + } + + s := &http.Server{ + Addr: fmt.Sprintf(":%d", opts.Port), + ReadTimeout: opts.TCPReadTimeout, + WriteTimeout: opts.TCPWriteTimeout, + Handler: devproxy.NewLoggingHandler(os.Stdout, proxy.Handler(), opts.RequestLogging), //, devproxy.StatsdClient), + } + logger.Fatal(s.ListenAndServe()) +} diff --git a/internal/devproxy/collector/collector.go b/internal/devproxy/collector/collector.go new file mode 100644 index 00000000..02514934 --- /dev/null +++ b/internal/devproxy/collector/collector.go @@ -0,0 +1,102 @@ +package collector + +import ( + "runtime" + "time" + + "github.com/datadog/datadog-go/statsd" +) + +// Collector ticks periodically and emits runtime stats to datadog +type Collector struct { + // interval represents the interval inbetween ticks for stats collection + interval time.Duration + + // done, when closed, is used to signal the closure of the runtime polling goroutine + done chan struct{} + + // statsd client used to send metrics + client *statsd.Client +} + +// New creates a new collector that will periodically emit runtime statistics to datadog. +func New(client *statsd.Client, interval time.Duration) *Collector { + return &Collector{ + interval: interval, + client: client, + done: make(chan struct{}), + } +} + +// Run gathers statistics from package runtime and emits them to statsd via client +func (c *Collector) Run() { + tick := time.NewTicker(c.interval) + defer tick.Stop() + for { + select { + case <-c.done: + return + case <-tick.C: + c.emitStats() + } + } +} + +// Close signals the collector to close the polling goroutine, use for graceful shutdowns +func (c *Collector) Close() { + close(c.done) +} + +func (c *Collector) emitStats() { + c.emitCPUStats() + c.emitMemStats() +} + +func (c *Collector) emitCPUStats() { + c.gauge("cpu.goroutines", uint64(runtime.NumGoroutine())) + c.gauge("cpu.cgo_calls", uint64(runtime.NumCgoCall())) +} + +func (c *Collector) emitMemStats() { + m := &runtime.MemStats{} + runtime.ReadMemStats(m) + + // General + c.gauge("mem.alloc", m.Alloc) + c.gauge("mem.total", m.TotalAlloc) + c.gauge("mem.sys", m.Sys) + c.gauge("mem.lookups", m.Lookups) + c.gauge("mem.malloc", m.Mallocs) + c.gauge("mem.frees", m.Frees) + + // Heap + c.gauge("mem.heap.alloc", m.HeapAlloc) + c.gauge("mem.heap.sys", m.HeapSys) + c.gauge("mem.heap.idle", m.HeapIdle) + c.gauge("mem.heap.inuse", m.HeapInuse) + c.gauge("mem.heap.released", m.HeapReleased) + c.gauge("mem.heap.objects", m.HeapObjects) + + // Stack + c.gauge("mem.stack.inuse", m.StackInuse) + c.gauge("mem.stack.sys", m.StackSys) + c.gauge("mem.stack.mspan_inuse", m.MSpanInuse) + c.gauge("mem.stack.mspan_sys", m.MSpanSys) + c.gauge("mem.stack.mcache_inuse", m.MCacheInuse) + c.gauge("mem.stack.mcache_sys", m.MCacheSys) + + // Garbage Collection + c.gauge("mem.gc.sys", m.GCSys) + c.gauge("mem.gc.next", m.NextGC) + c.gauge("mem.gc.last", m.LastGC) + c.gauge("mem.gc.pause_total", m.PauseTotalNs) + c.gauge("mem.gc.pause", m.PauseNs[(m.NumGC+255)%256]) + c.gauge("mem.gc.count", uint64(m.NumGC)) + + // Other + c.gauge("mem.othersys", m.OtherSys) +} + +func (c *Collector) gauge(key string, val uint64) { + c.client.Gauge(key, float64(val), nil, 1.0) +} diff --git a/internal/devproxy/dev_config.go b/internal/devproxy/dev_config.go new file mode 100644 index 00000000..295ac3fe --- /dev/null +++ b/internal/devproxy/dev_config.go @@ -0,0 +1,373 @@ +package devproxy + +import ( + "fmt" + "net/url" + "regexp" + "strings" + "time" + + // "github.com/18F/hmacauth" + "github.com/imdario/mergo" + "gopkg.in/yaml.v2" +) + +const ( + simple = "simple" + rewrite = "rewrite" +) + +var ( + space = regexp.MustCompile(`\s+`) +) + +// ServiceConfig represents the configuration for a given service +type ServiceConfig struct { + Service string `yaml:"service"` + ClusterConfigs map[string]*UpstreamConfig `yaml:",inline"` +} + +// SimpleRoute contains a FromURL and ToURL used to construct simple routes in the reverse proxy. +type SimpleRoute struct { + FromURL *url.URL + ToURL *url.URL +} + +// RewriteRoute contains a FromRegex and ToTemplate used to construct rewrite routes in the reverse proxy. +type RewriteRoute struct { + FromRegex *regexp.Regexp + ToTemplate *url.URL +} + +// UpstreamConfig represents the configuration for a given cluster in a given service +type UpstreamConfig struct { + Service string + + RouteConfig RouteConfig `yaml:",inline"` + + ExtraRoutes []*RouteConfig `yaml:"extra_routes"` + + // Generated at Parse Time + Route interface{} // note: :/ + + // SkipAuthCompiledRegex []*regexp.Regexp + // AllowedGroups []string + Timeout time.Duration + FlushInterval time.Duration + HeaderOverrides map[string]string +} + +// RouteConfig maps to the yaml config fields, +// * "from" - the domain that will be used to access the service +// * "to" - the cname of the proxied service (this tells sso proxy where to proxy requests that come in on the from field) +type RouteConfig struct { + From string `yaml:"from"` + To string `yaml:"to"` + Type string `yaml:"type"` + Options *OptionsConfig `yaml:"options"` +} + +// OptionsConfig maps to the yaml config fields: +// * header_overrides - overrides any heads set either by sso proxy itself or upstream applications. +// This can be useful for modifying browser security headers. +// * skip_auth_regex - skips authentication for paths matching these regular expressions. +// * allowed_groups - optional list of authorized google groups that can access the service. +// * timeout - duration before timing out request. +// * flush_interval - interval at which the proxy should flush data to the browser +type OptionsConfig struct { + HeaderOverrides map[string]string `yaml:"header_overrides"` + // SkipAuthRegex []string `yaml:"skip_auth_regex"` + // AllowedGroups []string `yaml:"allowed_groups"` + Timeout time.Duration `yaml:"timeout"` + FlushInterval time.Duration `yaml:"flush_interval"` +} + +// ErrParsingConfig is an error specific to config parsing. +type ErrParsingConfig struct { + Message string + Err error +} + +// Error() implements the error interface, returning a string representation of the error. +func (e *ErrParsingConfig) Error() string { + if e.Err != nil { + return fmt.Sprintf("%s error=%s", e.Message, e.Err) + } + return e.Message +} + +func loadServiceConfigs(raw []byte, cluster, scheme string, configVars map[string]string) ([]*UpstreamConfig, error) { + // We fill in all templated values and resolve overrides + rawTemplated := resolveTemplates(raw, configVars) + + serviceConfigs, err := parseServiceConfigs(rawTemplated) + if err != nil { + return nil, err + } + + // we don't set this to the len(serviceConfig) since not all service configs + // are configured for all clusters, leaving nil tail pointers in the slice. + configs := make([]*UpstreamConfig, 0) + // resovle overrides + for _, service := range serviceConfigs { + proxy, err := resolveUpstreamConfig(service, cluster) + if err != nil { + return nil, err + } + + // if we don't resolve a upstream config, this cluster is not configured for this upstream + // so the proxy struct will be nil and we skip adding it to our running config + if proxy != nil { + configs = append(configs, proxy) + } + } + + extraRoutes := make([]*UpstreamConfig, 0) + for _, proxy := range configs { + if len(proxy.ExtraRoutes) == 0 { + continue + } + + for _, extra := range proxy.ExtraRoutes { + resolvedProxy, err := resolveExtraRoute(extra, proxy) + if err != nil { + return nil, err + } + extraRoutes = append(extraRoutes, resolvedProxy) + } + // for completeness, we set this to nil now that we've processed extra routes + proxy.ExtraRoutes = nil + } + + configs = append(configs, extraRoutes...) + + // We verify the config has necessary values + for _, proxy := range configs { + err := validateUpstreamConfig(proxy) + if err != nil { + return nil, err + } + } + + // We compose the URLs for all our finalized domains + for _, proxy := range configs { + switch proxy.RouteConfig.Type { + case simple, "": + route, err := simpleRoute(scheme, proxy.RouteConfig) + if err != nil { + return nil, err + } + proxy.Route = route + case rewrite: + route, err := rewriteRoute(scheme, proxy.RouteConfig) + if err != nil { + return nil, err + } + proxy.Route = route + default: + return nil, &ErrParsingConfig{ + Message: fmt.Sprintf("unknown routing config type %q", proxy.RouteConfig.Type), + Err: nil, + } + } + } + + // We validate OptionsConfig + for _, proxy := range configs { + err := parseOptionsConfig(proxy) + if err != nil { + return nil, err + } + } + + // for _, proxy := range configs { + // key := fmt.Sprintf("%s_signing_key", proxy.Service) + // signingKey, ok := configVars[key] + // if !ok { + // continue + // } + // auth, err := generateHmacAuth(signingKey) + // if err != nil { + // return nil, &ErrParsingConfig{ + // Message: fmt.Sprintf("unable to generate hmac auth for %s", proxy.Service), + // Err: err, + // } + // } + // proxy.HMACAuth = auth + + // } + + return configs, nil +} + +func rewriteRoute(scheme string, routeConfig RouteConfig) (*RewriteRoute, error) { + compiled, err := regexp.Compile(routeConfig.From) + if err != nil { + return nil, &ErrParsingConfig{ + Message: "unable to compile rewrite from regex", + Err: err, + } + } + + toURL := &url.URL{ + Scheme: scheme, + Opaque: routeConfig.To, // we use opaque since the template value may not be a parsable URL + } + + return &RewriteRoute{ + FromRegex: compiled, + ToTemplate: toURL, + }, nil +} + +func simpleRoute(scheme string, routeConfig RouteConfig) (*SimpleRoute, error) { + // url parse domain + fromURL, err := urlParse(scheme, routeConfig.From) + if err != nil { + return nil, &ErrParsingConfig{ + Message: "unable to url parse `from` parameter", + Err: err, + } + } + + // url parse to url + toURL, err := urlParse(scheme, routeConfig.To) + if err != nil { + return nil, &ErrParsingConfig{ + Message: "unable to url parse `to` parameter", + Err: err, + } + } + + return &SimpleRoute{ + FromURL: fromURL, + ToURL: toURL, + }, nil +} + +func urlParse(scheme, uri string) (*url.URL, error) { + // NOTE: This is done intentionally to add a scheme so it is valid to parse. + // + // From https://golang.org/pkg/net/url/#Parse + // > Trying to parse a hostname and path without a scheme is invalid + // > but may not necessarily return an error, due to parsing ambiguities. + if !strings.Contains(uri, "://") { + uri = fmt.Sprintf("%s://%s", scheme, uri) + } + return url.Parse(uri) +} + +func parseServiceConfigs(data []byte) ([]*ServiceConfig, error) { + serviceConfigs := make([]*ServiceConfig, 0) + err := yaml.Unmarshal(data, &serviceConfigs) + if err != nil { + return nil, &ErrParsingConfig{ + Message: "failed to parse yaml", + Err: err, + } + } + + return serviceConfigs, err +} + +func resolveExtraRoute(routeConfig *RouteConfig, src *UpstreamConfig) (*UpstreamConfig, error) { + dst := &UpstreamConfig{RouteConfig: *routeConfig} + + err := mergo.Merge(dst, *src) + if err != nil { + return nil, err + } + + dst.ExtraRoutes = nil + + return dst, nil +} + +func resolveUpstreamConfig(service *ServiceConfig, override string) (*UpstreamConfig, error) { + dst, dstOk := service.ClusterConfigs["default"] + src, srcOk := service.ClusterConfigs[override] + + if !(dstOk || srcOk) { + // no default or cluster is configured, which we allow + return nil, nil + } + + if dst == nil { + dst = &UpstreamConfig{} + } + + if src == nil { + src = &UpstreamConfig{} + } + + err := mergo.Merge(dst, *src, mergo.WithOverride) + if err != nil { + return nil, err + } + + dst.Service = cleanWhiteSpace(service.Service) + return dst, nil +} + +func validateUpstreamConfig(proxy *UpstreamConfig) error { + if proxy.Service == "" { + return &ErrParsingConfig{ + Message: "missing `service` parameter", + } + } + + if proxy.RouteConfig.From == "" { + return &ErrParsingConfig{ + Message: "missing `from` parameter", + } + } + + if proxy.RouteConfig.To == "" { + return &ErrParsingConfig{ + Message: "missing `to` parameter", + } + } + + return nil +} + +func resolveTemplates(raw []byte, templateVars map[string]string) []byte { + rawString := string(raw) + for k, v := range templateVars { + templated := fmt.Sprintf("{{%s}}", k) + rawString = strings.Replace(rawString, templated, v, -1) + } + return []byte(rawString) +} + +func parseOptionsConfig(proxy *UpstreamConfig) error { + if proxy.RouteConfig.Options == nil { + return nil + } + + // We compile all the regexes in SkipAuth Regex + // for _, uncompiled := range proxy.RouteConfig.Options.SkipAuthRegex { + // compiled, err := regexp.Compile(uncompiled) + // if err != nil { + // return &ErrParsingConfig{ + // Message: "unable to compile skip auth regex", + // Err: err, + // } + // } + // proxy.SkipAuthCompiledRegex = append(proxy.SkipAuthCompiledRegex, compiled) + // } + + // proxy.AllowedGroups = proxy.RouteConfig.Options.AllowedGroups + proxy.Timeout = proxy.RouteConfig.Options.Timeout + proxy.FlushInterval = proxy.RouteConfig.Options.FlushInterval + proxy.HeaderOverrides = proxy.RouteConfig.Options.HeaderOverrides + + proxy.RouteConfig.Options = nil + + return nil +} + +func cleanWhiteSpace(s string) string { + // This trims all white space from a service name and collapses all remaining space to `_` + return space.ReplaceAllString(strings.TrimSpace(s), "_") // +} diff --git a/internal/devproxy/devproxy.go b/internal/devproxy/devproxy.go new file mode 100644 index 00000000..9a86b7b1 --- /dev/null +++ b/internal/devproxy/devproxy.go @@ -0,0 +1,392 @@ +package devproxy + +import ( + "encoding/json" + "time" + + // "errors" + "fmt" + "html/template" + "io" + + // "net" + "net/http" + "net/http/httputil" + "net/url" + + // "reflect" + "regexp" + "strings" + + // "time" + + // "github.com/buzzfeed/sso/internal/pkg/aead" + log "github.com/buzzfeed/sso/internal/pkg/logging" + // "github.com/buzzfeed/sso/internal/dev/collector" + // "github.com/18F/hmacauth" + // "github.com/datadog/datadog-go/statsd" +) + +// SignatureHeader is the header name where the signed request header is stored. +const SignatureHeader = "Gap-Signature" + +// SignatureHeaders are the headers that are valid in the request. +var SignatureHeaders = []string{ + "X-Forwarded-User", + "X-Forwarded-Email", + "X-Forwarded-Groups", + "X-Forwarded-Access-Token", +} + +const statusInvalidHost = 421 + +// DevProxy stores all the information associated with proxying the request. +type DevProxy struct { + redirectURL *url.URL // the url to receive requests at + skipAuthPreflight bool + templates *template.Template + mux map[string]*route + regexRoutes []*route +} + +type route struct { + upstreamConfig *UpstreamConfig + handler http.Handler + tags []string + + // only used for ones that have regex + regex *regexp.Regexp +} + +// StateParameter holds the redirect id along with the session id. +type StateParameter struct { + SessionID string `json:"session_id"` + RedirectURI string `json:"redirect_uri"` +} + +// UpstreamProxy stores information necessary for proxying the request back to the upstream. +type UpstreamProxy struct { + name string + handler http.Handler +} + +// upstreamTransport is used to ensure that upstreams cannot override the +// security headers applied by dev_proxy +type upstreamTransport struct{} + +// RoundTrip round trips the request and deletes security headers before returning the response. +func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := http.DefaultTransport.RoundTrip(req) + if err != nil { + logger := log.NewLogEntry() + logger.Error(err, "error in upstreamTransport RoundTrip") + return nil, err + } + for key := range securityHeaders { + resp.Header.Del(key) + } + return resp, err +} + +// ServeHTTP calls the upstream's ServeHTTP function. +func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + + start := time.Now() + u.handler.ServeHTTP(w, r) + duration := time.Now().Sub(start) + + fmt.Sprintf("service_name:%s, duation:%s", u.name, duration) +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewReverseProxy creates a reverse proxy to a specified url. +// It adds an X-Forwarded-Host header that is the request's host. +func NewReverseProxy(to *url.URL) *httputil.ReverseProxy { + targetQuery := to.RawQuery + director := func(req *http.Request) { + req.URL.Scheme = to.Scheme + req.URL.Host = to.Host + req.URL.Path = singleJoiningSlash(to.Path, req.URL.Path) + + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + if _, ok := req.Header["User-Agent"]; !ok { + // explicitly disable User-Agent so it's not set to default value + req.Header.Set("User-Agent", "") + } + } + proxy := &httputil.ReverseProxy{Director: director} + proxy.Transport = &upstreamTransport{} + dir := proxy.Director + + proxy.Director = func(req *http.Request) { + req.Header.Add("X-Forwarded-Host", req.Host) + req.Header.Set("X-Forwarded-User", req.Header.Get("User")) + req.Header.Set("X-Forwarded-Email", req.Header.Get("Email")) + req.Header.Set("X-Forwarded-Groups", req.Header.Get("Groups")) + req.Header.Set("X-Forwarded-Access-Token", "") + dir(req) + req.Host = to.Host + } + return proxy +} + +// NewRewriteReverseProxy creates a reverse proxy that is capable of creating upstream +// urls on the fly based on a from regex and a templated to field. +// It adds an X-Forwarded-Host header to the the upstream's request. +func NewRewriteReverseProxy(route *RewriteRoute) *httputil.ReverseProxy { + proxy := &httputil.ReverseProxy{} + proxy.Transport = &upstreamTransport{} + proxy.Director = func(req *http.Request) { + // we do this to rewrite requests + rewritten := route.FromRegex.ReplaceAllString(req.Host, route.ToTemplate.Opaque) + + // we use to favor scheme's used in the regex, else we use the default passed in via the template + target, err := urlParse(route.ToTemplate.Scheme, rewritten) + if err != nil { + logger := log.NewLogEntry() + // we aren't in an error handling context so we have to fake it(thanks stdlib!) + logger.WithRequestHost(req.Host).WithRewriteRoute(route).Error( + err, "unable to parse and replace rewrite url") + req.URL = nil // this will raise an error in http.RoundTripper + return + } + director := httputil.NewSingleHostReverseProxy(target).Director + + req.Header.Add("X-Forwarded-Host", req.Host) + req.Header.Set("X-Forwarded-User", req.Header.Get("User")) + req.Header.Set("X-Forwarded-Email", req.Header.Get("Email")) + req.Header.Set("X-Forwarded-Groups", req.Header.Get("Groups")) + req.Header.Set("X-Forwarded-Access-Token", "") + director(req) + req.Host = target.Host + } + return proxy +} + +// NewReverseProxyHandler creates a new http.Handler given a httputil.ReverseProxy +func NewReverseProxyHandler(reverseProxy *httputil.ReverseProxy, opts *Options, config *UpstreamConfig) (http.Handler, []string) { + upstreamProxy := &UpstreamProxy{ + name: config.Service, + handler: reverseProxy, + } + if config.FlushInterval != 0 { + return NewStreamingHandler(upstreamProxy, opts, config), []string{"handler:streaming"} + } + return NewTimeoutHandler(upstreamProxy, opts, config), []string{"handler:timeout"} +} + +// NewTimeoutHandler creates a new handler with a configure timeout. +func NewTimeoutHandler(handler http.Handler, opts *Options, config *UpstreamConfig) http.Handler { + timeout := opts.DefaultUpstreamTimeout + if config.Timeout != 0 { + timeout = config.Timeout + } + timeoutMsg := fmt.Sprintf( + "%s failed to respond within the %s timeout period", config.Service, timeout) + return http.TimeoutHandler(handler, timeout, timeoutMsg) +} + +// NewStreamingHandler creates a new handler capable of proxying a stream +func NewStreamingHandler(handler http.Handler, opts *Options, config *UpstreamConfig) http.Handler { + upstreamProxy := handler.(*UpstreamProxy) + reverseProxy := upstreamProxy.handler.(*httputil.ReverseProxy) + reverseProxy.FlushInterval = config.FlushInterval + return upstreamProxy +} + +// NewDevProxy creates a new DevProxy struct. +func NewDevProxy(opts *Options, optFuncs ...func(*DevProxy) error) (*DevProxy, error) { + logger := log.NewLogEntry() + logger.Info("NewDevProxy...") + + p := &DevProxy{ + // these fields make up the routing mechanism + mux: make(map[string]*route), + regexRoutes: make([]*route, 0), + + redirectURL: &url.URL{Path: "/oauth2/callback"}, + // skipAuthPreflight: opts.SkipAuthPreflight, + templates: getTemplates(), + } + + for _, optFunc := range optFuncs { + err := optFunc(p) + if err != nil { + return nil, err + } + } + + for _, upstreamConfig := range opts.upstreamConfigs { + switch route := upstreamConfig.Route.(type) { + case *SimpleRoute: + reverseProxy := NewReverseProxy(route.ToURL) + handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig) + p.Handle(route.FromURL.Host, handler, tags, upstreamConfig) + case *RewriteRoute: + reverseProxy := NewRewriteReverseProxy(route) + handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig) + p.HandleRegex(route.FromRegex, handler, tags, upstreamConfig) + default: + return nil, fmt.Errorf("unkown route type") + } + } + + return p, nil +} + +// Handler returns a http handler for an DevProxy +func (p *DevProxy) Handler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/favicon.ico", p.Favicon) + mux.HandleFunc("/robots.txt", p.RobotsTxt) + // mux.HandleFunc("/oauth2/callback", p.DevCallback) + mux.HandleFunc("/", p.Proxy) + + // Global middleware, which will be applied to each request in reverse + // order as applied here (i.e., we want to validate the host _first_ when + // processing a request) + var handler http.Handler = mux + // if p.CookieSecure { + // handler = requireHTTPS(handler) + // } + handler = p.setResponseHeaderOverrides(handler) + handler = setSecurityHeaders(handler) + handler = p.validateHost(handler) + + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + // Skip host validation for /ping requests because they hit the LB directly. + if req.URL.Path == "/ping" { + p.PingPage(rw, req) + return + } + handler.ServeHTTP(rw, req) + }) +} + +// RobotsTxt sets the User-Agent header in the response to be "Disallow" +func (p *DevProxy) RobotsTxt(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusOK) + fmt.Fprintf(rw, "User-agent: *\nDisallow: /") +} + +// Favicon will proxy the request as usual +func (p *DevProxy) Favicon(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + p.Proxy(rw, req) +} + +// PingPage send back a 200 OK response. +func (p *DevProxy) PingPage(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusOK) + fmt.Fprintf(rw, "OK") +} + +// ErrorPage renders an error page with a given status code, title, and message. +func (p *DevProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, title string, message string) { + if p.isXMLHTTPRequest(req) { + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(code) + err := json.NewEncoder(rw).Encode(struct { + Error string `json:"error"` + }{ + Error: message, + }) + if err != nil { + io.WriteString(rw, err.Error()) + } + } else { + logger := log.NewLogEntry() + logger.WithHTTPStatus(code).WithPageTitle(title).WithPageMessage(message).Info( + "error page") + rw.WriteHeader(code) + t := struct { + Code int + Title string + Message string + }{ + Code: code, + Title: title, + Message: message, + } + p.templates.ExecuteTemplate(rw, "error.html", t) + } +} + +func (p *DevProxy) isXMLHTTPRequest(req *http.Request) bool { + return req.Header.Get("X-Requested-With") == "XMLHttpRequest" +} + +// Proxy forwards the request. +func (p *DevProxy) Proxy(rw http.ResponseWriter, req *http.Request) { + + logger := log.NewLogEntry() + // start := time.Now() + logger.Info("Proxy...") + + // We have validated the users request and now proxy their request to the provided upstream. + route, ok := p.router(req) + if !ok { + p.UnknownHost(rw, req) + return + } + + route.handler.ServeHTTP(rw, req) +} + +// UnknownHost returns an http error for unknown or invalid hosts +func (p *DevProxy) UnknownHost(rw http.ResponseWriter, req *http.Request) { + logger := log.NewLogEntry() + + // tags := []string{ + // fmt.Sprintf("action:%s", GetActionTag(req)), + // "error:unknown_host", + // } + // p.StatsdClient.Incr("application_error", tags, 1.0) + logger.WithRequestHost(req.Host).Error("unknown host") + http.Error(rw, "", statusInvalidHost) +} + +// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig +func (p *DevProxy) Handle(host string, handler http.Handler, tags []string, upstreamConfig *UpstreamConfig) { + tags = append(tags, "route:simple") + p.mux[host] = &route{handler: handler, upstreamConfig: upstreamConfig, tags: tags} +} + +// HandleRegex constructs a route from the given regexp and matches it to the provided http.Handler and UpstreamConfig +func (p *DevProxy) HandleRegex(regex *regexp.Regexp, handler http.Handler, tags []string, upstreamConfig *UpstreamConfig) { + tags = append(tags, "route:rewrite") + p.regexRoutes = append(p.regexRoutes, &route{regex: regex, handler: handler, upstreamConfig: upstreamConfig, tags: tags}) +} + +// router attempts to find a route for a request. If a route is successfully matched, +// it returns the route information and a bool value of `true`. If a route can not be matched, +//a nil value for the route and false bool value is returned. +func (p *DevProxy) router(req *http.Request) (*route, bool) { + route, ok := p.mux[req.Host] + if ok { + return route, true + } + + for _, route := range p.regexRoutes { + if route.regex.MatchString(req.Host) { + return route, true + } + } + + return nil, false +} diff --git a/internal/devproxy/devproxy_test.go b/internal/devproxy/devproxy_test.go new file mode 100644 index 00000000..adab4134 --- /dev/null +++ b/internal/devproxy/devproxy_test.go @@ -0,0 +1,933 @@ +package devproxy + +import ( + // "crypto" + // "encoding/base64" + "encoding/json" + // "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "regexp" + "strings" + "testing" + "time" + + "github.com/buzzfeed/sso/internal/pkg/testutil" +) + +func init() { + log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) + +} + +func testValidatorFunc(valid bool) func(*DevProxy) error { + return func(p *DevProxy) error { + return nil + } +} + +func TestNewReverseProxy(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + hostname, _, _ := net.SplitHostPort(r.Host) + w.Write([]byte(hostname)) + })) + defer backend.Close() + + backendURL, _ := url.Parse(backend.URL) + backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) + backendHost := net.JoinHostPort(backendHostname, backendPort) + proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") + + proxyHandler := NewReverseProxy(proxyURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + res, _ := http.DefaultClient.Do(getReq) + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendHostname; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +func TestNewRewriteReverseProxy(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(200) + rw.Write([]byte(req.Host)) + })) + defer upstream.Close() + + parsedUpstreamURL, err := url.Parse(upstream.URL) + if err != nil { + t.Fatalf("expected to parse upstream URL err:%q", err) + } + + route := &RewriteRoute{ + FromRegex: regexp.MustCompile("(.*)"), + ToTemplate: &url.URL{ + Scheme: parsedUpstreamURL.Scheme, + Opaque: parsedUpstreamURL.Host, + }, + } + + rewriteProxy := NewRewriteReverseProxy(route) + + frontend := httptest.NewServer(rewriteProxy) + defer frontend.Close() + + resp, err := http.Get(frontend.URL) + if err != nil { + t.Fatalf("expected to make successful request err:%q", err) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("expected to read body err:%q", err) + } + + if string(body) != parsedUpstreamURL.Host { + t.Logf("got %v", string(body)) + t.Logf("want %v", parsedUpstreamURL.Host) + t.Fatalf("got unexpected response from upstream") + } +} + +func TestNewReverseProxyHostname(t *testing.T) { + type respStruct struct { + Host string `json:"host"` + XForwardedHost string `json:"x-forwarded-host"` + // XForwardedEmail string `json:"x-forwarded-email"` + // XForwardedUser string `json:"x-forwarded-user"` + // XForwardedGroups string `json:"x-forwarded-groups"` + // XForwardedAccessToken string `json:"x-forwarded-access-token"` + } + + to := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + body, err := json.Marshal( + &respStruct{ + Host: r.Host, + XForwardedHost: r.Header.Get("X-Forwarded-Host"), + // XForwardedEmail: r.Header.Get(`json:"x-forwarded-email"`), + // XForwardedUser: r.Header.Get(`json:"x-forwarded-user"`), + // XForwardedGroups: r.Header.Get(`json:"x-forwarded-groups"`), + // XForwardedAccessToken: r.Header.Get(`json:"x-forwarded-access-token"`), + }, + ) + if err != nil { + t.Fatalf("expected to marshal json: %s", err) + } + rw.Write(body) + })) + defer to.Close() + + toURL, err := url.Parse(to.URL) + if err != nil { + t.Fatalf("expected to parse to url: %s", err) + } + + reverseProxy := NewReverseProxy(toURL) + from := httptest.NewServer(reverseProxy) + defer from.Close() + + fromURL, err := url.Parse(from.URL) + if err != nil { + t.Fatalf("expected to parse from url: %s", err) + } + + want := &respStruct{ + Host: toURL.Host, + XForwardedHost: fromURL.Host, + } + + res, err := http.Get(from.URL) + if err != nil { + t.Fatalf("expected to be able to make req: %s", err) + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("expected to read body: %s", err) + } + + got := &respStruct{} + err = json.Unmarshal(body, got) + if err != nil { + t.Fatalf("expected to decode json: %s", err) + } + + if !reflect.DeepEqual(want, got) { + t.Logf(" got host: %v", got.Host) + t.Logf("want host: %v", want.Host) + + t.Logf(" got X-Forwarded-Host: %v", got.XForwardedHost) + t.Logf("want X-Forwarded-Host: %v", want.XForwardedHost) + + t.Errorf("got unexpected response for Host or X-Forwarded-Host header") + } + +} + +func TestRoundTrip(t *testing.T) { + testCases := []struct { + name string + url string + expectedError bool + }{ + { + name: "no error", + url: "https://www.example.com/", + }, + { + name: "with error", + url: "/", + expectedError: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tc.url, nil) + ut := upstreamTransport{} + resp, err := ut.RoundTrip(req) + if err == nil && tc.expectedError { + t.Errorf("expected error but error was nil") + } + if err != nil && !tc.expectedError { + t.Errorf("unexpected error %s", err.Error()) + } + if err != nil { + return + } + for key := range securityHeaders { + if resp.Header.Get(key) != "" { + t.Errorf("security header %s expected to be deleted but was %s", key, resp.Header.Get(key)) + } + } + }) + } +} + +func generateTestUpstreamConfigs(to string) []*UpstreamConfig { + if !strings.Contains(to, "://") { + to = fmt.Sprintf("%s://%s", "http", to) + } + parsed, err := url.Parse(to) + if err != nil { + panic(err) + } + templateVars := map[string]string{ + "root_domain": "dev", + "cluster": "sso", + } + upstreamConfigs, err := loadServiceConfigs([]byte(fmt.Sprintf(` +- service: foo + default: + from: foo.sso.dev + to: %s +`, parsed)), "DevProxy", "http", templateVars) + if err != nil { + panic(err) + } + return upstreamConfigs +} + +func TestRobotsTxt(t *testing.T) { + opts := NewOptions() + opts.upstreamConfigs = generateTestUpstreamConfigs("httpheader.net/") + opts.Validate() + + proxy, err := NewDevProxy(opts) + if err != nil { + t.Errorf("unexpected error %s", err) + } + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "https://foo.sso.dev/robots.txt", nil) + proxy.Handler().ServeHTTP(rw, req) + testutil.Equal(t, 200, rw.Code) + testutil.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String()) +} + +func TestFavicon(t *testing.T) { + opts := NewOptions() + opts.upstreamConfigs = generateTestUpstreamConfigs("httpheader.net/") + opts.Validate() + + proxy, _ := NewDevProxy(opts) + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "https://foo.sso.dev/favicon.ico", nil) + proxy.Handler().ServeHTTP(rw, req) + testutil.Equal(t, http.StatusOK, rw.Code) +} + +type SignatureTest struct { + opts *Options + upstream *httptest.Server + upstreamHost string + header http.Header + rw *httptest.ResponseRecorder +} + +func generateSignatureTestUpstreamConfigs(key, to string) []*UpstreamConfig { + + if !strings.Contains(to, "://") { + to = fmt.Sprintf("%s://%s", "http", to) + } + parsed, err := url.Parse(to) + if err != nil { + panic(err) + } + templateVars := map[string]string{ + "root_domain": "dev", + "cluster": "sso", + "foo_signing_key": key, + } + upstreamConfigs, err := loadServiceConfigs([]byte(fmt.Sprintf(` +- service: foo + default: + from: foo.{{cluster}}.{{root_domain}} + to: %s +`, parsed)), "DevProxy", "http", templateVars) + if err != nil { + panic(err) + } + + return upstreamConfigs +} + +func (st *SignatureTest) Close() { + st.upstream.Close() +} + +func TestPing(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("upstream")) + })) + defer upstream.Close() + + opts := NewOptions() + + opts.upstreamConfigs = generateTestUpstreamConfigs(upstream.URL) + opts.Validate() + + proxy, _ := NewDevProxy(opts) + + testCases := []struct { + name string + url string + host string + authenticated bool + expectedCode int + }{ + { + name: "ping never reaches upstream", + url: "http://foo.sso.dev/ping", + authenticated: true, + expectedCode: http.StatusOK, + }, + { + name: "ping skips host check with no host set", + url: "/ping", + expectedCode: http.StatusOK, + }, + { + name: "ping skips host check with unknown host set", + url: "/ping", + host: "example.com", + expectedCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", tc.url, nil) + + proxy.Handler().ServeHTTP(rw, req) + + if tc.expectedCode != rw.Code { + t.Errorf("expected code %d, got %d", tc.expectedCode, rw.Code) + } + if rw.Body.String() != "OK" { + t.Errorf("expected body = %q, got %q", "OK", rw.Body.String()) + } + }) + } +} + +func TestSecurityHeaders(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/add-header": + w.Header().Set("X-Test-Header", "true") + case "/override-security-header": + w.Header().Set("X-Frame-Options", "OVERRIDE") + } + w.WriteHeader(200) + w.Write([]byte(r.URL.RequestURI())) + })) + defer upstream.Close() + + opts := NewOptions() + opts.upstreamConfigs = generateTestUpstreamConfigs(upstream.URL) + opts.Validate() + + proxy, _ := NewDevProxy(opts, testValidatorFunc(true)) + + testCases := []struct { + name string + path string + expectedCode int + expectedHeaders map[string]string + }{ + { + name: "security headers are added to authenticated requests", + path: "/", + expectedCode: http.StatusOK, + expectedHeaders: securityHeaders, + }, + // { + // name: "security headers are added to unauthenticated requests", + // path: "/", + // expectedCode: http.StatusFound, + // expectedHeaders: securityHeaders, + // }, + { + name: "additional headers set by upstream are proxied", + path: "/add-header", + expectedCode: http.StatusOK, + expectedHeaders: map[string]string{ + "X-Test-Header": "true", + }, + }, + { + name: "security headers may NOT be overridden by upstream", + path: "/override-security-header", + expectedCode: http.StatusOK, + expectedHeaders: map[string]string{ + "X-Frame-Options": "SAMEORIGIN", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", fmt.Sprintf("http://foo.sso.dev%s", tc.path), nil) + + proxy.Handler().ServeHTTP(rw, req) + + if tc.expectedCode != rw.Code { + t.Errorf("expected code %d, got %d", tc.expectedCode, rw.Code) + out, _ := json.Marshal(tc) + fmt.Println(string(out)) + } + if tc.expectedCode == http.StatusOK { + if rw.Body.String() != tc.path { + t.Errorf("expected body = %q, got %q", tc.path, rw.Body.String()) + } + } + for key, val := range tc.expectedHeaders { + vals, found := rw.HeaderMap[http.CanonicalHeaderKey(key)] + if !found { + t.Errorf("expected header %s not found", key) + } else if len(vals) > 1 { + t.Errorf("got duplicate values for headers %s: %v", key, vals) + } else if vals[0] != val { + t.Errorf("expected header %s=%q, got %s=%q\n", key, val, key, vals[0]) + } + } + }) + } +} + +func makeUpstreamConfigWithHeaderOverrides(overrides map[string]string) []*UpstreamConfig { + templateVars := map[string]string{ + "root_domain": "dev", + "cluster": "sso", + } + upstreamConfigs, err := loadServiceConfigs([]byte(fmt.Sprintf(` +- service: foo + default: + from: foo.sso.dev + to: httpheader.net/ + +- service: bar + default: + from: bar.sso.dev + to: bar-internal.sso.dev +`)), "DevProxy", "http", templateVars) + if err != nil { + panic(err) + } + upstreamConfigs[0].HeaderOverrides = overrides // we override foo and not bar + return upstreamConfigs +} + +func TestHeaderOverrides(t *testing.T) { + testCases := []struct { + name string + overrides map[string]string + expectedCode int + expectedHeaders map[string]string + }{ + { + name: "security headers are added to requests", + overrides: nil, + expectedCode: http.StatusOK, + expectedHeaders: securityHeaders, + }, + { + name: "security headers are overridden by config", + overrides: map[string]string{ + "X-Frame-Options": "ALLOW-FROM nsa.gov", + }, + expectedCode: http.StatusOK, + expectedHeaders: map[string]string{ + "X-Frame-Options": "ALLOW-FROM nsa.gov", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts := NewOptions() + opts.upstreamConfigs = makeUpstreamConfigWithHeaderOverrides(tc.overrides) + opts.Validate() + + proxy, _ := NewDevProxy(opts, testValidatorFunc(true)) + + // Check Foo + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://foo.sso.dev/", nil) + proxy.Handler().ServeHTTP(rw, req) + for key, val := range tc.expectedHeaders { + vals, found := rw.HeaderMap[http.CanonicalHeaderKey(key)] + if !found { + t.Errorf("expected header %s not found", key) + } else if len(vals) > 1 { + t.Errorf("got duplicate values for headers %s: %v", key, vals) + } else if vals[0] != val { + t.Errorf("expected header %s=%q, got %s=%q\n", key, val, key, vals[0]) + } + } + + // Check Bar + rwBar := httptest.NewRecorder() + reqBar, _ := http.NewRequest("GET", "http://bar.sso.dev/", nil) + proxy.Handler().ServeHTTP(rwBar, reqBar) + for key, val := range securityHeaders { + vals, found := rwBar.HeaderMap[http.CanonicalHeaderKey(key)] + if !found { + t.Errorf("expected header %s not found", key) + } else if len(vals) > 1 { + t.Errorf("got duplicate values for headers %s: %v", key, vals) + } else if vals[0] != val { + t.Errorf("expected header %s=%q, got %s=%q\n", key, val, key, vals[0]) + } + } + }) + } +} + +// func TestHTTPSRedirect(t *testing.T) { +// upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte(r.URL.String())) +// })) +// defer upstream.Close() + +// testCases := []struct { +// name string +// url string +// host string +// requestHeaders map[string]string +// expectedCode int +// expectedLocation string // must match entire Location header +// expectedLocationHost string // just match hostname of Location header +// expectSTS bool // should we get a Strict-Transport-Security header? +// }{ +// { +// name: "no https redirect with http ", +// url: "http://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectSTS: false, +// }, +// { +// name: "no https redirect with https ", +// url: "https://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectSTS: false, +// }, +// { +// name: "https redirect ", +// url: "http://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectedLocation: "https://foo.sso.dev/", +// expectSTS: true, +// }, +// { +// name: "https redirect", +// url: "http://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectedLocation: "https://foo.sso.dev/", +// expectSTS: true, +// }, +// { +// name: "no https redirect ", +// url: "https://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectSTS: true, +// }, +// { +// name: "no https redirect ", +// url: "https://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectSTS: true, +// }, +// { +// name: "request path and query are preserved in redirect", +// url: "http://foo.sso.dev/foo/bar.html?a=1&b=2&c=3", +// expectedCode: http.StatusOK, +// expectedLocation: "https://foo.sso.dev/foo/bar.html?a=1&b=2&c=3", +// expectSTS: true, +// }, +// { +// name: "no https redirect with http and X-Forwarded-Proto=https", +// url: "http://foo.sso.dev/", +// requestHeaders: map[string]string{"X-Forwarded-Proto": "https"}, +// expectedCode: http.StatusOK, +// expectSTS: true, +// }, +// { +// name: "correct host name with relative URL", +// url: "/", +// host: "foo.sso.dev", +// expectedLocation: "https://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectSTS: true, +// }, +// { +// name: "host validation is applied before https redirect", +// url: "http://bar.sso.dev/", +// expectedCode: statusInvalidHost, +// expectSTS: false, +// }, +// } + +// for _, tc := range testCases { +// t.Run(tc.name, func(t *testing.T) { +// opts := NewOptions() +// opts.upstreamConfigs = generateTestUpstreamConfigs(upstream.URL) +// opts.Validate() + +// proxy, _ := NewDevProxy(opts, testValidatorFunc(true)) + +// rw := httptest.NewRecorder() +// req, _ := http.NewRequest("GET", tc.url, nil) + +// for key, val := range tc.requestHeaders { +// req.Header.Set(key, val) +// } + +// if tc.host != "" { +// req.Host = tc.host +// } + +// proxy.Handler().ServeHTTP(rw, req) + +// if tc.expectedCode != rw.Code { +// t.Errorf("expected code %d, got %d", tc.expectedCode, rw.Code) +// } + +// location := rw.Header().Get("Location") +// locationURL, err := url.Parse(location) +// if err != nil { +// t.Errorf("error parsing location %q: %s", location, err) +// } +// if tc.expectedLocation != "" && location != tc.expectedLocation { +// t.Errorf("expected Location=%q, got Location=%q", tc.expectedLocation, location) +// } +// if tc.expectedLocationHost != "" && locationURL.Hostname() != tc.expectedLocationHost { +// t.Errorf("expected location host = %q, got %q", tc.expectedLocationHost, locationURL.Hostname()) +// } + +// stsKey := http.CanonicalHeaderKey("Strict-Transport-Security") +// if tc.expectSTS { +// val := rw.Header().Get(stsKey) +// expectedVal := "max-age=31536000" +// if val != expectedVal { +// t.Errorf("expected %s=%q, got %q", stsKey, expectedVal, val) +// } +// } else { +// _, found := rw.HeaderMap[stsKey] +// if found { +// t.Errorf("%s header should not be present, got %q", stsKey, rw.Header().Get(stsKey)) +// } +// } +// }) +// } +// } + +func TestTimeoutHandler(t *testing.T) { + testCases := []struct { + name string + config *UpstreamConfig + defaultTimeout time.Duration + globalTimeout time.Duration + ExpectedStatusCode int + ExpectedBody string + ExpectedErr error + }{ + { + name: "does not timeout", + config: &UpstreamConfig{ + Timeout: time.Duration(100) * time.Millisecond, + }, + defaultTimeout: time.Duration(100) * time.Millisecond, + globalTimeout: time.Duration(100) * time.Millisecond, + ExpectedStatusCode: 200, + ExpectedBody: "OK", + }, + { + name: "times out using upstream config timeout", + config: &UpstreamConfig{ + Service: "service-test", + Timeout: time.Duration(10) * time.Millisecond, + }, + defaultTimeout: time.Duration(100) * time.Millisecond, + globalTimeout: time.Duration(100) * time.Millisecond, + ExpectedStatusCode: 503, + ExpectedBody: fmt.Sprintf("service-test failed to respond within the 10ms timeout period"), + }, + { + name: "times out using default upstream config timeout", + config: &UpstreamConfig{ + Service: "service-test", + }, + defaultTimeout: time.Duration(10) * time.Millisecond, + globalTimeout: time.Duration(100) * time.Millisecond, + ExpectedStatusCode: 503, + ExpectedBody: fmt.Sprintf("service-test failed to respond within the 10ms timeout period"), + }, + { + name: "times out using global write timeout", + config: &UpstreamConfig{ + Service: "service-test", + }, + defaultTimeout: time.Duration(100) * time.Millisecond, + globalTimeout: time.Duration(10) * time.Millisecond, + ExpectedErr: &url.Error{ + Err: io.EOF, + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts := NewOptions() + opts.DefaultUpstreamTimeout = tc.defaultTimeout + + baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + timer := time.NewTimer(time.Duration(50) * time.Millisecond) + <-timer.C + w.Write([]byte("OK")) + }) + timeoutHandler := NewTimeoutHandler(baseHandler, opts, tc.config) + + srv := httptest.NewUnstartedServer(timeoutHandler) + srv.Config.WriteTimeout = tc.globalTimeout + srv.Start() + defer srv.Close() + + res, err := http.Get(srv.URL) + if err != nil { + if tc.ExpectedErr == nil { + t.Fatalf("got unexpected err=%v", err) + } + urlErr, ok := err.(*url.Error) + if !ok { + t.Fatalf("got unexpected err=%v", err) + } + if urlErr.Err != io.EOF { + t.Fatalf("got unexpected err=%v", err) + } + // We got the error we expected, exit + return + } + + if res.StatusCode != tc.ExpectedStatusCode { + t.Errorf(" got=%v", res.StatusCode) + t.Errorf("want=%v", tc.ExpectedStatusCode) + t.Fatalf("got unexpcted status code") + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("got unexpected err=%q", err) + } + + if string(body) != tc.ExpectedBody { + t.Errorf(" got=%q", body) + t.Errorf("want=%q", tc.ExpectedBody) + t.Fatalf("got unexpcted body") + } + }) + } +} + +func generateTestRewriteUpstreamConfigs(fromRegex, toTemplate string) []*UpstreamConfig { + templateVars := map[string]string{ + "root_domain": "dev", + "cluster": "sso", + } + upstreamConfigs, err := loadServiceConfigs([]byte(fmt.Sprintf(` +- service: foo + default: + from: %s + to: %s + type: rewrite +`, fromRegex, toTemplate)), "DevProxy", "http", templateVars) + if err != nil { + panic(err) + } + return upstreamConfigs +} + +func TestRewriteRoutingHandling(t *testing.T) { + type response struct { + Host string `json:"host"` + XForwardedHost string `json:"x-forwarded-host"` + XForwardedEmail string `json:"x-forwarded-email"` + XForwardedUser string `json:"x-forwarded-user"` + XForwardedGroups string `json:"x-forwarded-groups"` + XForwardedAccessToken string `json:"x-forwarded-access-token"` + } + + upstream := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + body, err := json.Marshal( + &response{ + Host: r.Host, + XForwardedHost: r.Header.Get("X-Forwarded-Host"), + XForwardedEmail: r.Header.Get(`json:"x-forwarded-email"`), + XForwardedUser: r.Header.Get(`json:"x-forwarded-user"`), + XForwardedGroups: r.Header.Get(`json:"x-forwarded-groups"`), + XForwardedAccessToken: r.Header.Get(`json:"x-forwarded-access-token"`), + }, + ) + if err != nil { + t.Fatalf("expected to marshal json: %s", err) + } + rw.Write(body) + })) + defer upstream.Close() + + parsedUpstreamURL, err := url.Parse(upstream.URL) + if err != nil { + t.Fatalf("expected to parse upstream URL err:%q", err) + } + + upstreamHost, upstreamPort, err := net.SplitHostPort(parsedUpstreamURL.Host) + if err != nil { + t.Fatalf("expected to split host/hort err:%q", err) + } + + testCases := []struct { + Name string + TestHost string + TestUser string + TestEmail string + TestGroups string + TestAccessToken string + FromRegex string + ToTemplate string + ExpectedCode int + ExpectedResponse *response + }{ + { + Name: "everything should work in the normal case", + TestHost: "foo.sso.dev", + FromRegex: "(.*)", + ToTemplate: parsedUpstreamURL.Host, + ExpectedCode: http.StatusOK, + ExpectedResponse: &response{ + Host: parsedUpstreamURL.Host, + XForwardedHost: "foo.sso.dev", + }, + }, + { + Name: "it should not match a non-matching regex", + TestHost: "foo.sso.dev", + FromRegex: "bar", + ToTemplate: parsedUpstreamURL.Host, + ExpectedCode: statusInvalidHost, + }, + { + Name: "it should match and replace using regex/template to find port in embeded domain", + TestHost: fmt.Sprintf("somedomain--%s", upstreamPort), + FromRegex: "somedomain--(.*)", // capture port + ToTemplate: fmt.Sprintf("%s:$1", upstreamHost), // add port to dest + ExpectedCode: http.StatusOK, + ExpectedResponse: &response{ + Host: parsedUpstreamURL.Host, + XForwardedHost: fmt.Sprintf("somedomain--%s", upstreamPort), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + opts := NewOptions() + opts.upstreamConfigs = generateTestRewriteUpstreamConfigs(tc.FromRegex, tc.ToTemplate) + opts.Validate() + proxy, err := NewDevProxy(opts, testValidatorFunc(true)) + if err != nil { + t.Fatalf("unexpected err provisioning dev proxy err:%q", err) + } + + req, err := http.NewRequest("GET", fmt.Sprintf("https://%s/", tc.TestHost), strings.NewReader("")) + if err != nil { + t.Fatalf("unexpected err creating request err:%s", err) + } + + rw := httptest.NewRecorder() + + proxy.Handler().ServeHTTP(rw, req) + + if tc.ExpectedCode != rw.Code { + t.Errorf("expected code %d, got %d", tc.ExpectedCode, rw.Code) + } + + if tc.ExpectedResponse == nil { + // we've passed our test, we didn't expect a body, exit early + return + } + + body, err := ioutil.ReadAll(rw.Body) + if err != nil { + t.Fatalf("expected to read body: %s", err) + } + + got := &response{} + err = json.Unmarshal(body, got) + if err != nil { + t.Fatalf("expected to decode json: %s", err) + } + + if !reflect.DeepEqual(tc.ExpectedResponse, got) { + t.Logf(" got host: %v", got.Host) + t.Logf("want host: %v", tc.ExpectedResponse.Host) + + t.Logf(" got X-Forwarded-Host: %v", got.XForwardedHost) + t.Logf("want X-Forwarded-Host: %v", tc.ExpectedResponse.XForwardedHost) + + t.Errorf("got unexpected response for Host or X-Forwarded-Host header") + } + }) + } +} diff --git a/internal/devproxy/logging_handler.go b/internal/devproxy/logging_handler.go new file mode 100644 index 00000000..e81248de --- /dev/null +++ b/internal/devproxy/logging_handler.go @@ -0,0 +1,131 @@ +// largely adapted from https://github.com/gorilla/handlers/blob/master/handlers.go +// to add logging of request duration as last value (and drop referrer) + +package devproxy + +import ( + "io" + "net/http" + + // "net/url" + "strings" + // "time" + // log "github.com/buzzfeed/sso/internal/pkg/logging" + // "github.com/datadog/datadog-go/statsd" +) + +// Used to stash the authenticated user in the response for access when logging requests. +// const loggingUserHeader = "SSO-Authenticated-User" + +// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status +// code and body size +type responseLogger struct { + w http.ResponseWriter + status int + size int + // authInfo string +} + +func (l *responseLogger) Header() http.Header { + return l.w.Header() +} + +// func (l *responseLogger) extractUser() { +// authInfo := l.w.Header().Get(loggingUserHeader) +// if authInfo != "" { +// l.authInfo = authInfo +// l.w.Header().Del(loggingUserHeader) +// } +// } + +func (l *responseLogger) Write(b []byte) (int, error) { + if l.status == 0 { + // The status will be StatusOK if WriteHeader has not been called yet + l.status = http.StatusOK + } + // l.extractUser() + size, err := l.w.Write(b) + l.size += size + return size, err +} + +func (l *responseLogger) WriteHeader(s int) { + // l.extractUser() + l.w.WriteHeader(s) + l.status = s +} + +func (l *responseLogger) Status() int { + return l.status +} + +func (l *responseLogger) Size() int { + return l.size +} + +func (l *responseLogger) Flush() { + f := l.w.(http.Flusher) + f.Flush() +} + +// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends +type loggingHandler struct { + writer io.Writer + handler http.Handler + // StatsdClient *statsd.Client + enabled bool +} + +// NewLoggingHandler returns a new loggingHandler that wraps a handler, statsd client, and writer. +func NewLoggingHandler(out io.Writer, h http.Handler, v bool /*, StatsdClient *statsd.Client*/) http.Handler { + return loggingHandler{writer: out, + handler: h, + enabled: v, + // StatsdClient: StatsdClient, + } +} + +func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // now := time.Now() + // url := *req.URL + logger := &responseLogger{w: w} + h.handler.ServeHTTP(logger, req) + if !h.enabled { + return + } + // logRequest(logger.authInfo, req, url, now, logger.Status(), h.StatsdClient) +} + +// logRequest logs information about a request +// func logRequest(username string, req *http.Request, url url.URL, ts time.Time, status int, StatsdClient *statsd.Client) { +// duration := time.Now().Sub(ts) + +// // Convert duration to floating point milliseconds +// // https://github.com/golang/go/issues/5491#issuecomment-66079585 +// durationMS := duration.Seconds() * 1e3 + +// uri := req.Host + url.RequestURI() + +// logger := log.NewLogEntry() +// logger.WithHTTPStatus(status).WithRequestMethod(req.Method).WithRequestURI( +// uri).WithUserAgent(req.Header.Get("User-Agent")).WithRemoteAddress( +// getRemoteAddr(req)).WithRequestDurationMs(durationMS).WithUser( +// username).WithAction(GetActionTag(req)).Info() +// logRequestMetrics(req, duration, status, StatsdClient) +// } + +// getRemoteAddr returns the client IP address from a request. If present, the +// X-Forwarded-For header is assumed to be set by a load balancer, and its +// rightmost entry (the client IP that connected to the LB) is returned. +func getRemoteAddr(req *http.Request) string { + addr := req.RemoteAddr + forwardedHeader := req.Header.Get("X-Forwarded-For") + if forwardedHeader != "" { + forwardedList := strings.Split(forwardedHeader, ",") + forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1]) + if forwardedAddr != "" { + addr = forwardedAddr + } + } + return addr +} diff --git a/internal/devproxy/metrics.go b/internal/devproxy/metrics.go new file mode 100644 index 00000000..bc7b5795 --- /dev/null +++ b/internal/devproxy/metrics.go @@ -0,0 +1,54 @@ +package devproxy + +import ( + "fmt" + // "net" + "net/http" + // "strconv" + "time" + // "github.com/datadog/datadog-go/statsd" +) + +// GetActionTag returns the action triggered by an http.Request . +func GetActionTag(req *http.Request) string { + // only log metrics for these paths and actions + pathToAction := map[string]string{ + "/favicon.ico": "favicon", + "/oauth2/sign_out": "sign_out", + "/oauth2/callback": "callback", + "/oauth2/auth": "auth", + "/ping": "ping", + "/robots.txt": "robots", + } + // get the action from the url path + path := req.URL.Path + if action, ok := pathToAction[path]; ok { + return action + } + return "proxy" +} + +// logMetrics logs all metrics surrounding a given request to the metricsWriter +func logRequestMetrics(req *http.Request, requestDuration time.Duration, status int) { //, StatsdClient *statsd.Client) { + // Normalize proxyHost for a) invalid requests or b) LB health checks to + // avoid polluting the proxy_host tag's value space + proxyHost := req.Host + if status == statusInvalidHost { + proxyHost = "_unknown" + } + if req.URL.Path == "/ping" { + proxyHost = "_healthcheck" + } + + // tags := []string{ + fmt.Sprintf("method:%s", req.Method) + fmt.Sprintf("status_code:%d", status) + fmt.Sprintf("status_category:%dxx", status/100) + fmt.Sprintf("action:%s", GetActionTag(req)) + fmt.Sprintf("proxy_host:%s", proxyHost) + // } + + // TODO: eventually make rates configurable + // StatsdClient.Timing("request", requestDuration, tags, 1.0) + +} diff --git a/internal/devproxy/middleware.go b/internal/devproxy/middleware.go new file mode 100644 index 00000000..99192595 --- /dev/null +++ b/internal/devproxy/middleware.go @@ -0,0 +1,69 @@ +package devproxy + +import ( + "net/http" + "net/url" +) + +// With inspiration from https://github.com/unrolled/secure +// +// TODO: Add Content-Security-Report header? +var securityHeaders = map[string]string{ + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "SAMEORIGIN", + "X-XSS-Protection": "1; mode=block", +} + +// setHeaders ensures that every response includes some basic security headers. +// +// Note: the Strict-Transport-Security header is set by the requireHTTPS +// middleware below, to avoid issues with development environments that must +// allow plain HTTP. +func setSecurityHeaders(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + for key, val := range securityHeaders { + rw.Header().Set(key, val) + } + h.ServeHTTP(rw, req) + }) +} + +func (p *DevProxy) setResponseHeaderOverrides(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + route, ok := p.router(req) + if ok && route.upstreamConfig.HeaderOverrides != nil { + for key, val := range route.upstreamConfig.HeaderOverrides { + rw.Header().Set(key, val) + } + } + h.ServeHTTP(rw, req) + }) +} + +// validateHost ensures that each request's host is valid +func (p *DevProxy) validateHost(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if _, ok := p.router(req); !ok { + p.UnknownHost(rw, req) + return + } + h.ServeHTTP(rw, req) + }) +} + +func requireHTTPS(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("Strict-Transport-Security", "max-age=31536000") + if req.URL.Scheme != "https" && req.Header.Get("X-Forwarded-Proto") != "https" { + dest := &url.URL{ + Scheme: "https", + Host: req.Host, + Path: req.URL.Path, + RawQuery: req.URL.RawQuery, + } + http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently) + return + } + h.ServeHTTP(rw, req) + }) +} diff --git a/internal/devproxy/options.go b/internal/devproxy/options.go new file mode 100644 index 00000000..0a02ef1a --- /dev/null +++ b/internal/devproxy/options.go @@ -0,0 +1,143 @@ +package devproxy + +import ( + // "encoding/base64" + // "errors" + "fmt" + "io/ioutil" + + // "net/http" + "net/url" + "os" + "strings" + "time" + // "github.com/datadog/datadog-go/statsd" +) + +// Options are configuration options that can be set by Environment Variables +// Port - int - port to listen on for HTTP clients +// ProviderURLString - the URL for the provider in this environment: "https://sso-auth.example.com" +// UpstreamConfigsFile - the path to upstream configs file +// Cluster - the cluster in which this is running, used for upstream configs +// Scheme - the default scheme, used for upstream configs +// DefaultUpstreamTimeout - the default time period to wait for a response from an upstream +// TCPWriteTimeout - http server tcp write timeout +// TCPReadTimeout - http server tcp read timeout +// SessionLifetimeTTL - time to live for a session lifetime +// SessionValidTTL - time to live for a valid session +// GracePeriodTTL - time to reuse session data when provider unavailable +// RequestLoging - boolean whether or not to log requests +// StatsdHost - host addr for statsd client to listen on +// StatsdPort - port for statsdclient to listen on +type Options struct { + Port int `envconfig:"PORT" default:"4180"` + + UpstreamConfigsFile string `envconfig:"UPSTREAM_CONFIGS"` + Cluster string `envconfig:"CLUSTER"` + Scheme string `envconfig:"SCHEME" default:"https"` + + Host string `envconfig:"HOST"` + + DefaultUpstreamTimeout time.Duration `envconfig:"DEFAULT_UPSTREAM_TIMEOUT" default:"10s"` + + TCPWriteTimeout time.Duration `envconfig:"TCP_WRITE_TIMEOUT" default:"30s"` + TCPReadTimeout time.Duration `envconfig:"TCP_READ_TIMEOUT" default:"30s"` + + RequestLogging bool `envconfig:"REQUEST_LOGGING" default:"true"` + + // StatsdHost string `envconfig:"STATSD_HOST"` + // StatsdPort int `envconfig:"STATSD_PORT"` + + // StatsdClient *statsd.Client + + // This is an override for supplying template vars at test time + testTemplateVars map[string]string + + // internal values that are set after config validation + upstreamConfigs []*UpstreamConfig +} + +// NewOptions returns a new options struct +func NewOptions() *Options { + return &Options{ + RequestLogging: true, + DefaultUpstreamTimeout: time.Duration(1) * time.Second, + } +} + +func parseURL(toParse string, urltype string, msgs []string) (*url.URL, []string) { + parsed, err := url.Parse(toParse) + if err != nil { + return nil, append(msgs, fmt.Sprintf( + "error parsing %s-url=%q %s", urltype, toParse, err)) + } + return parsed, msgs +} + +// Validate validates options +func (o *Options) Validate() error { + msgs := make([]string, 0) + if o.Cluster == "" { + msgs = append(msgs, "missing setting: cluster") + } + if o.UpstreamConfigsFile == "" { + msgs = append(msgs, "missing setting: upstream-configs") + } + + // if o.StatsdHost == "" { + // msgs = append(msgs, "missing setting: statsd-host") + // } + + // if o.StatsdPort == 0 { + // msgs = append(msgs, "missing setting: statsd-port") + // } + // if o.StatsdHost != "" && o.StatsdPort != 0 { + // StatsdClient, err := newStatsdClient(o) + // if err != nil { + // msgs = append(msgs, fmt.Sprintf("error creating statsd client error=%q", err)) + // } + // o.StatsdClient = StatsdClient + // } + + if o.UpstreamConfigsFile != "" { + rawBytes, err := ioutil.ReadFile(o.UpstreamConfigsFile) + if err != nil { + msgs = append(msgs, fmt.Sprintf("error reading upstream configs file: %s", err)) + } + + templateVars := parseEnvironment(os.Environ()) + if o.testTemplateVars != nil { + templateVars = o.testTemplateVars + } + + o.upstreamConfigs, err = loadServiceConfigs(rawBytes, o.Cluster, o.Scheme, templateVars) + if err != nil { + msgs = append(msgs, fmt.Sprintf("error parsing upstream configs file %s", err)) + } + } + + if len(msgs) != 0 { + return fmt.Errorf("Invalid configuration:\n %s", + strings.Join(msgs, "\n ")) + } + return nil +} + +func parseEnvironment(environ []string) map[string]string { + envPrefix := "DEV_CONFIG_" + env := make(map[string]string) + if len(environ) == 0 { + return env + } + for _, e := range environ { + // we only include env keys that have the SSO_CONFIG_ prefix + if !strings.HasPrefix(e, envPrefix) { + continue + } + + split := strings.SplitN(e, "=", 2) + key := strings.ToLower(strings.TrimPrefix(split[0], envPrefix)) + env[key] = split[1] + } + return env +} diff --git a/internal/devproxy/providers/http_client.go b/internal/devproxy/providers/http_client.go new file mode 100644 index 00000000..cee3d9bd --- /dev/null +++ b/internal/devproxy/providers/http_client.go @@ -0,0 +1,17 @@ +package providers + +import ( + "net" + "net/http" + "time" +) + +var httpClient = &http.Client{ + Timeout: time.Second * 5, + Transport: &http.Transport{ + Dial: (&net.Dialer{ + Timeout: 2 * time.Second, + }).Dial, + TLSHandshakeTimeout: 2 * time.Second, + }, +} diff --git a/internal/devproxy/providers/internal_util.go b/internal/devproxy/providers/internal_util.go new file mode 100644 index 00000000..c3714427 --- /dev/null +++ b/internal/devproxy/providers/internal_util.go @@ -0,0 +1,85 @@ +package providers + +import ( + "io/ioutil" + "net/http" + "net/url" + + log "github.com/buzzfeed/sso/internal/pkg/logging" +) + +// stripToken is a helper function to obfuscate "access_token" +// query parameters +func stripToken(endpoint string) string { + return stripParam("access_token", endpoint) +} + +// stripParam generalizes the obfuscation of a particular +// query parameter - typically 'access_token' or 'client_secret' +// The parameter's second half is replaced by '...' and returned +// as part of the encoded query parameters. +// If the target parameter isn't found, the endpoint is returned +// unmodified. +func stripParam(param, endpoint string) string { + logger := log.NewLogEntry() + + u, err := url.Parse(endpoint) + if err != nil { + logger.WithURLParam(param).Error(err, "error attempting to strip parameter") + return endpoint + } + + if u.RawQuery != "" { + values, err := url.ParseQuery(u.RawQuery) + if err != nil { + logger.WithURLParam(param).Error("error attempting to strip parameter") + return u.String() + } + + if val := values.Get(param); val != "" { + values.Set(param, val[:(len(val)/2)]+"...") + u.RawQuery = values.Encode() + return u.String() + } + } + + return endpoint +} + +// validateToken returns true if token is valid +func validateToken(p Provider, accessToken string, header http.Header) bool { + logger := log.NewLogEntry() + + if accessToken == "" || p.Data().ValidateURL == nil { + return false + } + endpoint := p.Data().ValidateURL.String() + if len(header) == 0 { + params := url.Values{"access_token": {accessToken}} + endpoint = endpoint + "?" + params.Encode() + } + + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + logger.Error(err, "token validation request failed") + return false + } + req.Header = header + + resp, err := httpClient.Do(req) + if err != nil { + logger.Error(err, "token validation request failed") + return false + } + + body, _ := ioutil.ReadAll(resp.Body) + resp.Body.Close() + logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) + + if resp.StatusCode == 200 { + return true + } + logger.WithHTTPStatus(resp.StatusCode).WithResponseBody(body).Info( + "token validation request failed") + return false +} diff --git a/internal/devproxy/providers/internal_util_test.go b/internal/devproxy/providers/internal_util_test.go new file mode 100644 index 00000000..ca5a070b --- /dev/null +++ b/internal/devproxy/providers/internal_util_test.go @@ -0,0 +1,132 @@ +package providers + +import ( + "errors" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/buzzfeed/sso/internal/pkg/testutil" +) + +type ValidateSessionStateTestProvider struct { + *ProviderData +} + +func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { + return "", errors.New("not implemented") +} + +// Note that we're testing the internal validateToken() used to implement +// several Provider's ValidateSessionState() implementations +func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState, g []string) bool { + return false +} + +type ValidateSessionStateTest struct { + backend *httptest.Server + responseCode int + provider *ValidateSessionStateTestProvider + header http.Header +} + +func NewValidateSessionStateTest() *ValidateSessionStateTest { + var vtTest ValidateSessionStateTest + + vtTest.backend = httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/tokeninfo" { + w.WriteHeader(500) + w.Write([]byte("unknown URL")) + } + tokenParam := r.FormValue("access_token") + if tokenParam == "" { + missing := false + receivedHeaders := r.Header + for k := range vtTest.header { + received := receivedHeaders.Get(k) + expected := vtTest.header.Get(k) + if received == "" || received != expected { + missing = true + } + } + if missing { + w.WriteHeader(500) + w.Write([]byte("no token param and missing or incorrect headers")) + } + } + w.WriteHeader(vtTest.responseCode) + w.Write([]byte("only code matters; contents disregarded")) + + })) + backendURL, _ := url.Parse(vtTest.backend.URL) + vtTest.provider = &ValidateSessionStateTestProvider{ + ProviderData: &ProviderData{ + ValidateURL: &url.URL{ + Scheme: "http", + Host: backendURL.Host, + Path: "/oauth/tokeninfo", + }, + }, + } + vtTest.responseCode = 200 + return &vtTest +} + +func (vtTest *ValidateSessionStateTest) Close() { + vtTest.backend.Close() +} + +func TestValidateSessionStateValidToken(t *testing.T) { + vtTest := NewValidateSessionStateTest() + defer vtTest.Close() + testutil.Equal(t, true, validateToken(vtTest.provider, "foobar", nil)) +} + +func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { + vtTest := NewValidateSessionStateTest() + defer vtTest.Close() + vtTest.header = make(http.Header) + vtTest.header.Set("Authorization", "Bearer foobar") + testutil.Equal(t, true, + validateToken(vtTest.provider, "foobar", vtTest.header)) +} + +func TestValidateSessionStateEmptyToken(t *testing.T) { + vtTest := NewValidateSessionStateTest() + defer vtTest.Close() + testutil.Equal(t, false, validateToken(vtTest.provider, "", nil)) +} + +func TestValidateSessionStateEmptyValidateURL(t *testing.T) { + vtTest := NewValidateSessionStateTest() + defer vtTest.Close() + vtTest.provider.Data().ValidateURL = nil + testutil.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) +} + +func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { + vtTest := NewValidateSessionStateTest() + // Close immediately to simulate a network failure + vtTest.Close() + testutil.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) +} + +func TestValidateSessionStateExpiredToken(t *testing.T) { + vtTest := NewValidateSessionStateTest() + defer vtTest.Close() + vtTest.responseCode = 401 + testutil.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) +} + +func TestStripTokenNotPresent(t *testing.T) { + test := "http://local.test/api/test?a=1&b=2" + testutil.Equal(t, test, stripToken(test)) +} + +func TestStripToken(t *testing.T) { + test := "http://local.test/api/test?access_token=deadbeef&b=1&c=2" + expected := "http://local.test/api/test?access_token=dead...&b=1&c=2" + testutil.Equal(t, expected, stripToken(test)) +} diff --git a/internal/devproxy/providers/provider_data.go b/internal/devproxy/providers/provider_data.go new file mode 100644 index 00000000..afaddbf0 --- /dev/null +++ b/internal/devproxy/providers/provider_data.go @@ -0,0 +1,31 @@ +package providers + +import ( + "net/url" + "time" +) + +// ProviderData holds the fields associated with providers +// necessary to implement the Provider interface. +type ProviderData struct { + ProviderName string + ProviderURL *url.URL + ClientID string + ClientSecret string + SignInURL *url.URL + SignOutURL *url.URL + RedeemURL *url.URL + RefreshURL *url.URL + ProfileURL *url.URL + ProtectedResource *url.URL + ValidateURL *url.URL + Scope string + ApprovalPrompt string + + SessionValidTTL time.Duration + SessionLifetimeTTL time.Duration + GracePeriodTTL time.Duration +} + +// Data returns the ProviderData struct +func (p *ProviderData) Data() *ProviderData { return p } diff --git a/internal/devproxy/providers/provider_default.go b/internal/devproxy/providers/provider_default.go new file mode 100644 index 00000000..02496b68 --- /dev/null +++ b/internal/devproxy/providers/provider_default.go @@ -0,0 +1,146 @@ +package providers + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "time" +) + +// Redeem takes a redirectURL and code, creates some params and redeems the request +func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { + if code == "" { + err = errors.New("missing code") + return + } + + params := url.Values{} + params.Add("redirect_uri", redirectURL) + params.Add("client_id", p.ClientID) + params.Add("client_secret", p.ClientSecret) + params.Add("code", code) + params.Add("grant_type", "authorization_code") + if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { + params.Add("resource", p.ProtectedResource.String()) + } + + var req *http.Request + req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) + if err != nil { + return + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + var resp *http.Response + resp, err = httpClient.Do(req) + if err != nil { + return nil, err + } + var body []byte + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return + } + + if resp.StatusCode != 200 { + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) + return + } + + // blindly try json and x-www-form-urlencoded + var jsonResponse struct { + AccessToken string `json:"access_token"` + } + err = json.Unmarshal(body, &jsonResponse) + if err == nil { + s = &SessionState{ + AccessToken: jsonResponse.AccessToken, + } + return + } + + var v url.Values + v, err = url.ParseQuery(string(body)) + if err != nil { + return + } + if a := v.Get("access_token"); a != "" { + s = &SessionState{AccessToken: a} + } else { + err = fmt.Errorf("no access token found %s", body) + } + return +} + +// GetSignInURL with typical oauth parameters +func (p *ProviderData) GetSignInURL(redirectURL *url.URL, state string) *url.URL { + var a url.URL + a = *p.SignInURL + now := time.Now() + rawRedirect := redirectURL.String() + params, _ := url.ParseQuery(a.RawQuery) + params.Set("redirect_uri", rawRedirect) + params.Add("scope", p.Scope) + params.Set("client_id", p.ClientID) + params.Set("response_type", "code") + params.Add("state", state) + params.Set("ts", fmt.Sprint(now.Unix())) + params.Set("sig", p.signRedirectURL(rawRedirect, now)) + a.RawQuery = params.Encode() + return &a +} + +// GetSignOutURL creates and returns the sign out URL, given a redirectURL +func (p *ProviderData) GetSignOutURL(redirectURL *url.URL) *url.URL { + var a url.URL + a = *p.SignOutURL + now := time.Now() + rawRedirect := redirectURL.String() + params, _ := url.ParseQuery(a.RawQuery) + params.Add("redirect_uri", rawRedirect) + params.Set("ts", fmt.Sprint(now.Unix())) + params.Set("sig", p.signRedirectURL(rawRedirect, now)) + a.RawQuery = params.Encode() + return &a +} + +// signRedirectURL signs the redirect url string, given a timestamp, and returns it +func (p *ProviderData) signRedirectURL(rawRedirect string, timestamp time.Time) string { + h := hmac.New(sha256.New, []byte(p.ClientSecret)) + h.Write([]byte(rawRedirect)) + h.Write([]byte(fmt.Sprint(timestamp.Unix()))) + return base64.URLEncoding.EncodeToString(h.Sum(nil)) +} + +// GetEmailAddress returns an email address or error +func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { + return "", errors.New("not implemented") +} + +// ValidateGroup validates that the provided email exists in the configured provider email group(s). +func (p *ProviderData) ValidateGroup(_ string, _ []string) ([]string, bool, error) { + return []string{}, true, nil +} + +// UserGroups returns a list of users +func (p *ProviderData) UserGroups(string, []string) ([]string, error) { + return []string{}, nil +} + +// ValidateSessionState calls to validate the token given the session and groups +func (p *ProviderData) ValidateSessionState(s *SessionState, groups []string) bool { + return validateToken(p, s.AccessToken, nil) +} + +// RefreshSession returns a boolean or error +func (p *ProviderData) RefreshSession(s *SessionState, group []string) (bool, error) { + return false, nil +} diff --git a/internal/devproxy/providers/provider_default_test.go b/internal/devproxy/providers/provider_default_test.go new file mode 100644 index 00000000..2aa48bfe --- /dev/null +++ b/internal/devproxy/providers/provider_default_test.go @@ -0,0 +1,19 @@ +package providers + +import ( + "testing" + "time" + + "github.com/buzzfeed/sso/internal/pkg/testutil" +) + +func TestRefresh(t *testing.T) { + p := &ProviderData{} + refreshed, err := p.RefreshSession(&SessionState{ + RefreshDeadline: time.Now().Add(time.Duration(-11) * time.Minute), + }, + []string{}, + ) + testutil.Equal(t, false, refreshed) + testutil.Equal(t, nil, err) +} diff --git a/internal/devproxy/providers/providers.go b/internal/devproxy/providers/providers.go new file mode 100644 index 00000000..4d33a796 --- /dev/null +++ b/internal/devproxy/providers/providers.go @@ -0,0 +1,25 @@ +package providers + +import ( + "net/url" + + "github.com/datadog/datadog-go/statsd" +) + +// Provider is an interface exposing functions necessary to authenticate with a given provider. +type Provider interface { + Data() *ProviderData + GetEmailAddress(*SessionState) (string, error) + Redeem(string, string) (*SessionState, error) + ValidateGroup(string, []string) ([]string, bool, error) + UserGroups(string, []string) ([]string, error) + ValidateSessionState(*SessionState, []string) bool + GetSignInURL(redirectURL *url.URL, finalRedirect string) *url.URL + GetSignOutURL(redirectURL *url.URL) *url.URL + RefreshSession(*SessionState, []string) (bool, error) +} + +// New returns a new sso Provider +func New(provider string, p *ProviderData, sc *statsd.Client) Provider { + return NewSSOProvider(p, sc) +} diff --git a/internal/devproxy/providers/session_state.go b/internal/devproxy/providers/session_state.go new file mode 100644 index 00000000..c8f9d8c6 --- /dev/null +++ b/internal/devproxy/providers/session_state.go @@ -0,0 +1,61 @@ +package providers + +import ( + "time" + + "github.com/buzzfeed/sso/internal/pkg/aead" +) + +// SessionState is our object that keeps track of a user's session state +type SessionState struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + + RefreshDeadline time.Time `json:"refresh_deadline"` + LifetimeDeadline time.Time `json:"lifetime_deadline"` + ValidDeadline time.Time `json:"valid_deadline"` + GracePeriodStart time.Time `json:"grace_period_start"` + + Email string `json:"email"` + User string `json:"user"` + Groups []string `json:"groups"` +} + +// LifetimePeriodExpired returns true if the lifetime has expired +func (s *SessionState) LifetimePeriodExpired() bool { + return isExpired(s.LifetimeDeadline) +} + +// RefreshPeriodExpired returns true if the refresh period has expired +func (s *SessionState) RefreshPeriodExpired() bool { + return isExpired(s.RefreshDeadline) +} + +// ValidationPeriodExpired returns true if the validation period has expired +func (s *SessionState) ValidationPeriodExpired() bool { + return isExpired(s.ValidDeadline) +} + +func isExpired(t time.Time) bool { + if t.Before(time.Now()) { + return true + } + return false +} + +// MarshalSession marshals the session state as JSON, encrypts the JSON using the +// given cipher, and base64-encodes the result +func MarshalSession(s *SessionState, c aead.Cipher) (string, error) { + return c.Marshal(s) +} + +// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the +// byte slice using the pased cipher, and unmarshals the resulting JSON into a session state struct +func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) { + s := &SessionState{} + err := c.Unmarshal(value, s) + if err != nil { + return nil, err + } + return s, nil +} diff --git a/internal/devproxy/providers/session_state_test.go b/internal/devproxy/providers/session_state_test.go new file mode 100644 index 00000000..24cd138a --- /dev/null +++ b/internal/devproxy/providers/session_state_test.go @@ -0,0 +1,72 @@ +package providers + +import ( + "reflect" + "testing" + "time" + + "github.com/buzzfeed/sso/internal/pkg/aead" +) + +func TestSessionStateSerialization(t *testing.T) { + secret := "0123456789abcdefghijklmnopqrstuv" + + c, err := aead.NewMiscreantCipher([]byte(secret)) + if err != nil { + t.Fatalf("expected to be able to create cipher: %v", err) + } + + want := &SessionState{ + AccessToken: "token1234", + RefreshToken: "refresh4321", + + LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), + RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), + ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(), + + Email: "user@domain.com", + User: "user", + } + + ciphertext, err := MarshalSession(want, c) + if err != nil { + t.Fatalf("expected to be encode session: %v", err) + } + + got, err := UnmarshalSession(ciphertext, c) + if err != nil { + t.Fatalf("expected to be decode session: %v", err) + } + + if !reflect.DeepEqual(want, got) { + t.Logf("want: %#v", want) + t.Logf(" got: %#v", got) + t.Errorf("encoding and decoding session resulted in unexpected output") + } +} + +func TestSessionStateExpirations(t *testing.T) { + session := &SessionState{ + AccessToken: "token1234", + RefreshToken: "refresh4321", + + LifetimeDeadline: time.Now().Add(-1 * time.Hour), + RefreshDeadline: time.Now().Add(-1 * time.Hour), + ValidDeadline: time.Now().Add(-1 * time.Minute), + + Email: "user@domain.com", + User: "user", + } + + if !session.LifetimePeriodExpired() { + t.Errorf("expcted lifetime period to be expired") + } + + if !session.RefreshPeriodExpired() { + t.Errorf("expcted lifetime period to be expired") + } + + if !session.ValidationPeriodExpired() { + t.Errorf("expcted lifetime period to be expired") + } +} diff --git a/internal/devproxy/providers/singleflight_middleware.go b/internal/devproxy/providers/singleflight_middleware.go new file mode 100644 index 00000000..6f583bab --- /dev/null +++ b/internal/devproxy/providers/singleflight_middleware.go @@ -0,0 +1,145 @@ +package providers + +import ( + "errors" + "fmt" + "net/url" + "sort" + "strings" + + "github.com/buzzfeed/sso/internal/proxy/singleflight" + + "github.com/datadog/datadog-go/statsd" +) + +var ( + // This is a compile-time check to make sure our types correctly implement the interface: + // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae + _ Provider = &SingleFlightProvider{} +) + +// Error message for ErrUnexpectedReturnType +var ( + ErrUnexpectedReturnType = errors.New("received unexpected return type from single flight func call") +) + +// SingleFlightProvider middleware provider that multiple requests for the same object +// to be processed as a single request. This is often called request collpasing or coalesce. +// This middleware leverages the golang singlelflight provider, with modifications for metrics. +// +// It's common among HTTP reverse proxy cache servers such as nginx, Squid or Varnish - they all call it something else but works similarly. +// +// * https://www.varnish-cache.org/docs/3.0/tutorial/handling_misbehaving_servers.html +// * http://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_cache_lock +// * http://wiki.squid-cache.org/Features/CollapsedForwarding +type SingleFlightProvider struct { + StatsdClient *statsd.Client + + provider Provider + + single *singleflight.Group +} + +// NewSingleFlightProvider instatiates a SingleFlightProvider given a provider and statsdClient +func NewSingleFlightProvider(provider Provider, StatsdClient *statsd.Client) *SingleFlightProvider { + return &SingleFlightProvider{ + provider: provider, + single: &singleflight.Group{}, + StatsdClient: StatsdClient, + } +} + +func (p *SingleFlightProvider) do(endpoint, key string, fn func() (interface{}, error)) (interface{}, error) { + compositeKey := fmt.Sprintf("%s/%s", endpoint, key) + resp, shared, err := p.single.Do(compositeKey, fn) + if shared > 0 { + tags := []string{fmt.Sprintf("endpoint:%s", endpoint)} + p.StatsdClient.Incr("provider.singleflight", tags, float64(shared)) + } + return resp, err +} + +// Data calls the provider's Data function +func (p *SingleFlightProvider) Data() *ProviderData { + return p.provider.Data() +} + +// GetEmailAddress calls the provider function getEmailAddress +func (p *SingleFlightProvider) GetEmailAddress(s *SessionState) (string, error) { + return p.provider.GetEmailAddress(s) +} + +// Redeem takes the redirectURL and a code and calls the provider function Redeem +func (p *SingleFlightProvider) Redeem(redirectURL, code string) (*SessionState, error) { + return p.provider.Redeem(redirectURL, code) +} + +// ValidateGroup takes an email, allowedGroups, and userGroups and passes it to the provider's ValidateGroup function and returns the response +func (p *SingleFlightProvider) ValidateGroup(email string, allowedGroups []string) ([]string, bool, error) { + return p.provider.ValidateGroup(email, allowedGroups) +} + +// UserGroups takes an email and passes it to the provider's UserGroups function and returns the response +func (p *SingleFlightProvider) UserGroups(email string, groups []string) ([]string, error) { + // sort the groups so that other requests may be able to use the cached request + sort.Strings(groups) + response, err := p.do("UserGroups", fmt.Sprintf("%s:%s", email, strings.Join(groups, ",")), func() (interface{}, error) { + return p.provider.UserGroups(email, groups) + }) + if err != nil { + return nil, err + } + + groups, ok := response.([]string) + if !ok { + return nil, ErrUnexpectedReturnType + } + + return groups, nil +} + +// ValidateSessionState calls the provider's ValidateSessionState function and returns the response +func (p *SingleFlightProvider) ValidateSessionState(s *SessionState, allowedGroups []string) bool { + response, err := p.do("ValidateSessionState", s.AccessToken, func() (interface{}, error) { + valid := p.provider.ValidateSessionState(s, allowedGroups) + return valid, nil + }) + if err != nil { + return false + } + + valid, ok := response.(bool) + if !ok { + return false + } + + return valid +} + +// GetSignInURL calls the GetSignInURL for the provider, which will return the sign in url +func (p *SingleFlightProvider) GetSignInURL(redirectURI *url.URL, finalRedirect string) *url.URL { + return p.provider.GetSignInURL(redirectURI, finalRedirect) +} + +// GetSignOutURL calls the GetSignOutURL for the provider, which will return the sign out url +func (p *SingleFlightProvider) GetSignOutURL(redirectURI *url.URL) *url.URL { + return p.provider.GetSignOutURL(redirectURI) +} + +// RefreshSession takes in a SessionState and allowedGroups and +// returns false if the session is not refreshed and true if it is. +func (p *SingleFlightProvider) RefreshSession(s *SessionState, allowedGroups []string) (bool, error) { + response, err := p.do("RefreshSession", s.RefreshToken, func() (interface{}, error) { + return p.provider.RefreshSession(s, allowedGroups) + }) + if err != nil { + return false, err + } + + r, ok := response.(bool) + if !ok { + return false, ErrUnexpectedReturnType + } + + return r, nil +} diff --git a/internal/devproxy/providers/sso.go b/internal/devproxy/providers/sso.go new file mode 100644 index 00000000..ffdb7e1c --- /dev/null +++ b/internal/devproxy/providers/sso.go @@ -0,0 +1,385 @@ +package providers + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "os" + "strings" + "time" + + log "github.com/buzzfeed/sso/internal/pkg/logging" + "github.com/datadog/datadog-go/statsd" +) + +var ( + // This is a compile-time check to make sure our types correctly implement the interface: + // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae + _ Provider = &SSOProvider{} +) + +// Errors +var ( + ErrMissingRefreshToken = errors.New("missing refresh token") + ErrAuthProviderUnavailable = errors.New("auth provider unavailable") +) + +var userAgentString string + +// SSOProvider holds the data associated with the SSOProviders +// necessary to implement a SSOProvider interface. +type SSOProvider struct { + *ProviderData + + StatsdClient *statsd.Client +} + +func init() { + version := os.Getenv("RIG_IMAGE_VERSION") + if version == "" { + version = "HEAD" + } else { + version = strings.Trim(version, `"`) + } + userAgentString = fmt.Sprintf("sso_proxy/%s", version) +} + +// NewSSOProvider instantiates a new SSOProvider with provider data and +// a statsd client. +func NewSSOProvider(p *ProviderData, sc *statsd.Client) *SSOProvider { + p.ProviderName = "SSO" + base := p.ProviderURL + p.SignInURL = base.ResolveReference(&url.URL{Path: "/sign_in"}) + p.SignOutURL = base.ResolveReference(&url.URL{Path: "/sign_out"}) + p.RedeemURL = base.ResolveReference(&url.URL{Path: "/redeem"}) + p.RefreshURL = base.ResolveReference(&url.URL{Path: "/refresh"}) + p.ValidateURL = base.ResolveReference(&url.URL{Path: "/validate"}) + p.ProfileURL = base.ResolveReference(&url.URL{Path: "/profile"}) + return &SSOProvider{ + ProviderData: p, + StatsdClient: sc, + } +} + +func newRequest(method, url string, body io.Reader) (*http.Request, error) { + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", userAgentString) + req.Header.Set("Accept", "application/json") + return req, nil +} + +func isProviderUnavailable(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable +} + +func extendDeadline(ttl time.Duration) time.Time { + return time.Now().Add(ttl).Truncate(time.Second) +} + +func (p *SSOProvider) withinGracePeriod(s *SessionState) bool { + if s.GracePeriodStart.IsZero() { + s.GracePeriodStart = time.Now() + } + return s.GracePeriodStart.Add(p.GracePeriodTTL).After(time.Now()) +} + +// Redeem takes a redirectURL and code and redeems the SessionState +func (p *SSOProvider) Redeem(redirectURL, code string) (*SessionState, error) { + if code == "" { + return nil, errors.New("missing code") + } + + // TODO: remove "grant_type" and "redirect_uri", unused by authenticator + params := url.Values{} + params.Add("redirect_uri", redirectURL) + params.Add("client_id", p.ClientID) + params.Add("client_secret", p.ClientSecret) + params.Add("code", code) + params.Add("grant_type", "authorization_code") + + req, err := newRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, err + } + + if resp.StatusCode != 200 { + if isProviderUnavailable(resp.StatusCode) { + return nil, ErrAuthProviderUnavailable + } + return nil, fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) + } + + var jsonResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + Email string `json:"email"` + } + err = json.Unmarshal(body, &jsonResponse) + if err != nil { + return nil, err + } + + user := strings.Split(jsonResponse.Email, "@")[0] + return &SessionState{ + AccessToken: jsonResponse.AccessToken, + RefreshToken: jsonResponse.RefreshToken, + + RefreshDeadline: extendDeadline(time.Duration(jsonResponse.ExpiresIn) * time.Second), + LifetimeDeadline: extendDeadline(p.SessionLifetimeTTL), + ValidDeadline: extendDeadline(p.SessionValidTTL), + + Email: jsonResponse.Email, + User: user, + }, nil +} + +// ValidateGroup does a GET request to the profile url and returns true if the user belongs to +// an authorized group. +func (p *SSOProvider) ValidateGroup(email string, allowedGroups []string) ([]string, bool, error) { + logger := log.NewLogEntry() + + logger.WithUser(email).WithAllowedGroups(allowedGroups).Info("validating groups") + inGroups := []string{} + if len(allowedGroups) == 0 { + return inGroups, true, nil + } + + userGroups, err := p.UserGroups(email, allowedGroups) + if err != nil { + return nil, false, err + } + + allowed := false + for _, userGroup := range userGroups { + for _, allowedGroup := range allowedGroups { + if userGroup == allowedGroup { + inGroups = append(inGroups, userGroup) + allowed = true + } + } + } + + return inGroups, allowed, nil +} + +// UserGroups takes an email and returns the UserGroups for that email +func (p *SSOProvider) UserGroups(email string, groups []string) ([]string, error) { + params := url.Values{} + params.Add("email", email) + params.Add("client_id", p.ClientID) + params.Add("groups", strings.Join(groups, ",")) + + req, err := newRequest("GET", fmt.Sprintf("%s?%s", p.ProfileURL.String(), params.Encode()), nil) + if err != nil { + return nil, err + } + req.Header.Set("X-Client-Secret", p.ClientSecret) + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + + var body []byte + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, err + } + + if resp.StatusCode != 200 { + if isProviderUnavailable(resp.StatusCode) { + return nil, ErrAuthProviderUnavailable + } + return nil, fmt.Errorf("got %d from %q %s", resp.StatusCode, p.ProfileURL.String(), body) + } + + var jsonResponse struct { + Email string `json:"email"` + Groups []string `json:"groups"` + } + if err := json.Unmarshal(body, &jsonResponse); err != nil { + return nil, err + } + + return jsonResponse.Groups, nil +} + +// RefreshSession takes a SessionState and allowedGroups and refreshes the session access token, +// returns `true` on success, and `false` on error +func (p *SSOProvider) RefreshSession(s *SessionState, allowedGroups []string) (bool, error) { + logger := log.NewLogEntry() + + if s.RefreshToken == "" { + return false, ErrMissingRefreshToken + } + + newToken, duration, err := p.redeemRefreshToken(s.RefreshToken) + if err != nil { + // When we detect that the auth provider is not explicitly denying + // authentication, and is merely unavailable, we refresh and continue + // as normal during the "grace period" + if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) { + tags := []string{"action:refresh_session", "error:redeem_token_failed"} + p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) + s.RefreshDeadline = extendDeadline(p.SessionValidTTL) + return true, nil + } + return false, err + } + + inGroups, validGroup, err := p.ValidateGroup(s.Email, allowedGroups) + if err != nil { + // When we detect that the auth provider is not explicitly denying + // authentication, and is merely unavailable, we refresh and continue + // as normal during the "grace period" + if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) { + tags := []string{"action:refresh_session", "error:user_groups_failed"} + p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) + s.RefreshDeadline = extendDeadline(p.SessionValidTTL) + return true, nil + } + return false, err + } + if !validGroup { + return false, errors.New("Group membership revoked") + } + s.Groups = inGroups + + s.AccessToken = newToken + s.RefreshDeadline = extendDeadline(duration) + s.GracePeriodStart = time.Time{} + logger.WithUser(s.Email).WithRefreshDeadline(s.RefreshDeadline).Info("refreshed session access token") + return true, nil +} + +func (p *SSOProvider) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) { + // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh + params := url.Values{} + params.Add("client_id", p.ClientID) + params.Add("client_secret", p.ClientSecret) + params.Add("refresh_token", refreshToken) + var req *http.Request + req, err = newRequest("POST", p.RefreshURL.String(), bytes.NewBufferString(params.Encode())) + if err != nil { + return + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := httpClient.Do(req) + if err != nil { + return + } + var body []byte + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return + } + + if resp.StatusCode != http.StatusCreated { + if isProviderUnavailable(resp.StatusCode) { + err = ErrAuthProviderUnavailable + } else { + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RefreshURL.String(), body) + } + return + } + + var data struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + } + err = json.Unmarshal(body, &data) + if err != nil { + return + } + token = data.AccessToken + expires = time.Duration(data.ExpiresIn) * time.Second + return +} + +// ValidateSessionState takes a sessionState and allowedGroups and validates the session state +func (p *SSOProvider) ValidateSessionState(s *SessionState, allowedGroups []string) bool { + logger := log.NewLogEntry() + + // we validate the user's access token is valid + params := url.Values{} + params.Add("client_id", p.ClientID) + req, err := newRequest("GET", fmt.Sprintf("%s?%s", p.ValidateURL.String(), params.Encode()), nil) + if err != nil { + logger.WithUser(s.Email).Error(err, "error validating session state") + return false + } + req.Header.Set("X-Client-Secret", p.ClientSecret) + req.Header.Set("X-Access-Token", s.AccessToken) + + resp, err := httpClient.Do(req) + if err != nil { + logger.WithUser(s.Email).Error("error making request to validate access token") + return false + } + + if resp.StatusCode != 200 { + // When we detect that the auth provider is not explicitly denying + // authentication, and is merely unavailable, we validate and continue + // as normal during the "grace period" + if isProviderUnavailable(resp.StatusCode) && p.withinGracePeriod(s) { + tags := []string{"action:validate_session", "error:validation_failed"} + p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) + s.ValidDeadline = extendDeadline(p.SessionValidTTL) + return true + } + logger.WithUser(s.Email).WithHTTPStatus(resp.StatusCode).Info( + "could not validate user access token") + return false + } + + // check the user is in the proper group(s) + inGroups, validGroup, err := p.ValidateGroup(s.Email, allowedGroups) + if err != nil { + // When we detect that the auth provider is not explicitly denying + // authentication, and is merely unavailable, we validate and continue + // as normal during the "grace period" + if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) { + tags := []string{"action:validate_session", "error:user_groups_failed"} + p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) + s.ValidDeadline = extendDeadline(p.SessionValidTTL) + return true + } + logger.WithUser(s.Email).Error(err, "error fetching group memberships") + return false + } + + if !validGroup { + logger.WithUser(s.Email).WithAllowedGroups(allowedGroups).Info( + "user is no longer in valid groups") + return false + } + s.Groups = inGroups + + s.ValidDeadline = extendDeadline(p.SessionValidTTL) + s.GracePeriodStart = time.Time{} + + logger.WithUser(s.Email).WithSessionValid(s.ValidDeadline).Info("validated session") + + return true +} diff --git a/internal/devproxy/providers/sso_test.go b/internal/devproxy/providers/sso_test.go new file mode 100644 index 00000000..88a01f81 --- /dev/null +++ b/internal/devproxy/providers/sso_test.go @@ -0,0 +1,550 @@ +package providers + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/buzzfeed/sso/internal/pkg/testutil" +) + +func newTestServer(status int, body []byte) (*url.URL, *httptest.Server) { + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(status) + rw.Write(body) + })) + u, _ := url.Parse(s.URL) + return u, s +} + +func newCodeTestServer(code int) (*url.URL, *httptest.Server) { + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(code) + })) + u, _ := url.Parse(s.URL) + return u, s +} + +func newSSOProvider() *SSOProvider { + return NewSSOProvider( + &ProviderData{ + ProviderURL: &url.URL{ + Scheme: "https", + Host: "auth.example.com", + }, + }, nil) +} + +func TestNewRequest(t *testing.T) { + testCases := []struct { + name string + url string + expectedError bool + }{ + { + name: "error on new request", + url: ":", + expectedError: true, + }, + { + name: "optional headers set", + url: "/", + expectedError: false, + }, + } + os.Setenv("RIG_IMAGE_VERSION", "testVersion") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p := newSSOProvider() + if p == nil { + t.Fatalf("expected provider to not be nil but was") + } + req, err := newRequest("GET", tc.url, nil) + if tc.expectedError && err == nil { + t.Errorf("expected error but error was nil") + } + if !tc.expectedError && err != nil { + t.Errorf("unexpected error %s", err.Error()) + } + if err != nil { + return + } + if req.Header.Get("User-Agent") == "testVersion" { + t.Errorf("expected User-Agent header to be set but it was not") + } + + }) + } + +} + +func TestSSOProviderDefaults(t *testing.T) { + p := newSSOProvider() + testutil.NotEqual(t, nil, p) + + data := p.Data() + testutil.Equal(t, "SSO", data.ProviderName) + + base := fmt.Sprintf("%s://%s", data.ProviderURL.Scheme, data.ProviderURL.Host) + testutil.Equal(t, fmt.Sprintf("%s/sign_in", base), data.SignInURL.String()) + testutil.Equal(t, fmt.Sprintf("%s/sign_out", base), data.SignOutURL.String()) + testutil.Equal(t, fmt.Sprintf("%s/redeem", base), data.RedeemURL.String()) + testutil.Equal(t, fmt.Sprintf("%s/refresh", base), data.RefreshURL.String()) + testutil.Equal(t, fmt.Sprintf("%s/validate", base), data.ValidateURL.String()) + testutil.Equal(t, fmt.Sprintf("%s/profile", base), data.ProfileURL.String()) +} + +type redeemResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + Email string `json:"email"` +} + +type refreshResponse struct { + Code int + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` +} + +type profileResponse struct { + Email string `json:"email"` + Groups []string `json:"groups"` +} + +func TestSSOProviderGroups(t *testing.T) { + testCases := []struct { + Name string + Email string + Groups []string + ProxyGroupIds []string + ExpectedValid bool + ExpectedInGroups []string + ExpectError error + ProfileStatus int + }{ + { + Name: "valid when no group id set", + Email: "michael.bland@gsa.gov", + Groups: []string{}, + ProxyGroupIds: []string{}, + ExpectedValid: true, + ExpectedInGroups: []string{}, + ExpectError: nil, + }, + { + Name: "valid when the group id exists", + Email: "michael.bland@gsa.gov", + Groups: []string{"user-in-this-group", "random-group"}, + ProxyGroupIds: []string{"user-in-this-group", "user-not-in-this-group"}, + ExpectedValid: true, + ExpectedInGroups: []string{"user-in-this-group"}, + ExpectError: nil, + }, + { + Name: "valid when the multiple group id exists", + Email: "michael.bland@gsa.gov", + Groups: []string{"user-in-this-group", "user-also-in-this-group"}, + ProxyGroupIds: []string{"user-in-this-group", "user-also-in-this-group"}, + ExpectedValid: true, + ExpectedInGroups: []string{"user-in-this-group", "user-also-in-this-group"}, + ExpectError: nil, + }, + { + Name: "invalid when the group id isn't in user groups", + Email: "michael.bland@gsa.gov", + Groups: []string{}, + ProxyGroupIds: []string{"test1"}, + ExpectedValid: false, + ExpectedInGroups: []string{}, + ExpectError: nil, + }, + { + Name: "invalid if can't access groups", + Email: "michael.bland@gsa.gov", + Groups: []string{}, + ProxyGroupIds: []string{"session-group"}, + ProfileStatus: http.StatusTooManyRequests, + ExpectError: ErrAuthProviderUnavailable, + }, + } + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + p := newSSOProvider() + body, err := json.Marshal(profileResponse{ + Email: tc.Email, + Groups: tc.Groups, + }) + testutil.Equal(t, nil, err) + var server *httptest.Server + profileStatus := http.StatusOK + if tc.ProfileStatus != 0 { + profileStatus = tc.ProfileStatus + } + p.ProfileURL, server = newTestServer(profileStatus, body) + defer server.Close() + inGroups, valid, err := p.ValidateGroup(tc.Email, tc.ProxyGroupIds) + testutil.Equal(t, tc.ExpectError, err) + if err == nil { + testutil.Equal(t, tc.ExpectedValid, valid) + testutil.Equal(t, tc.ExpectedInGroups, inGroups) + } + }) + } +} + +func TestSSOProviderGetEmailAddress(t *testing.T) { + testCases := []struct { + Name string + Code string + ExpectedError string + RedeemResponse *redeemResponse + ProfileResponse *profileResponse + }{ + { + Name: "redeem fails without code", + ExpectedError: "missing code", + }, + { + Name: "redeem fails if redemption server not responding", + Code: "code1234", + ExpectedError: "got 400", + }, + { + Name: "redeem successful", + Code: "code1234", + RedeemResponse: &redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + Email: "michael.bland@gsa.gov", + }, + ProfileResponse: &profileResponse{ + Email: "michael.bland@gsa.gov", + Groups: []string{"core@gsa.gov"}, + }, + }, + } + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + p := newSSOProvider() + + var redeemServer *httptest.Server + // set up redemption resource + if tc.RedeemResponse != nil { + body, err := json.Marshal(tc.RedeemResponse) + testutil.Equal(t, nil, err) + p.RedeemURL, redeemServer = newTestServer(http.StatusOK, body) + } else { + p.RedeemURL, redeemServer = newCodeTestServer(400) + } + defer redeemServer.Close() + + var profileServer *httptest.Server + if tc.ProfileResponse != nil { + body, err := json.Marshal(tc.ProfileResponse) + testutil.Equal(t, nil, err) + p.ProfileURL, profileServer = newTestServer(http.StatusOK, body) + } else { + p.RedeemURL, profileServer = newCodeTestServer(400) + } + defer profileServer.Close() + + session, err := p.Redeem("http://redirect/", tc.Code) + if tc.RedeemResponse != nil { + testutil.Equal(t, nil, err) + testutil.NotEqual(t, session, nil) + testutil.Equal(t, tc.RedeemResponse.Email, session.Email) + testutil.Equal(t, tc.RedeemResponse.AccessToken, session.AccessToken) + testutil.Equal(t, tc.RedeemResponse.RefreshToken, session.RefreshToken) + } + if tc.ExpectedError != "" && !strings.Contains(err.Error(), tc.ExpectedError) { + t.Errorf("got unexpected result.\nwant=%v\ngot=%v\n", tc.ExpectedError, err.Error()) + } + }) + } +} + +func TestSSOProviderValidateSessionState(t *testing.T) { + testCases := []struct { + Name string + SessionState *SessionState + ProviderResponse int + Groups []string + ProxyGroupIds []string + ExpectedValid bool + }{ + { + Name: "valid when no group id set", + SessionState: &SessionState{ + AccessToken: "abc", + Email: "michael.bland@gsa.gov", + }, + ProviderResponse: http.StatusOK, + Groups: []string{}, + ProxyGroupIds: []string{}, + ExpectedValid: true, + }, + { + Name: "invalid when response is is not 200", + SessionState: &SessionState{ + AccessToken: "abc", + Email: "michael.bland@gsa.gov", + }, + ProviderResponse: http.StatusForbidden, + Groups: []string{}, + ProxyGroupIds: []string{}, + ExpectedValid: false, + }, + { + Name: "valid when the group id exists", + SessionState: &SessionState{ + AccessToken: "abc", + Email: "michael.bland@gsa.gov", + }, + ProviderResponse: http.StatusOK, + Groups: []string{"test1", "test2"}, + ProxyGroupIds: []string{"test1"}, + ExpectedValid: true, + }, + { + Name: "invalid when the group id isn't in user groups", + SessionState: &SessionState{ + AccessToken: "abc", + Email: "michael.bland@gsa.gov", + }, + ProviderResponse: http.StatusOK, + Groups: []string{}, + ProxyGroupIds: []string{"test1"}, + ExpectedValid: false, + }, + { + Name: "valid when provider unavailable, but grace period active", + SessionState: &SessionState{ + GracePeriodStart: time.Now().Add(time.Duration(-1) * time.Hour), + }, + ProviderResponse: http.StatusTooManyRequests, + ExpectedValid: true, + }, + { + Name: "invalid when provider unavailable and grace period inactive", + SessionState: &SessionState{ + GracePeriodStart: time.Now().Add(time.Duration(-4) * time.Hour), + }, + ProviderResponse: http.StatusServiceUnavailable, + ExpectedValid: false, + }, + } + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + p := newSSOProvider() + p.GracePeriodTTL = time.Duration(3) * time.Hour + + // setup group endpoint + body, err := json.Marshal(profileResponse{ + Email: tc.SessionState.Email, + Groups: tc.Groups, + }) + testutil.Equal(t, nil, err) + var profileServer *httptest.Server + p.ProfileURL, profileServer = newTestServer(http.StatusOK, body) + defer profileServer.Close() + + validateServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + accessToken := r.Header.Get("X-Access-Token") + if accessToken != tc.SessionState.AccessToken { + t.Logf("want: %v", tc.SessionState.AccessToken) + t.Logf(" got: %v", accessToken) + t.Fatalf("unexpected access token value") + } + rw.WriteHeader(tc.ProviderResponse) + })) + p.ValidateURL, _ = url.Parse(validateServer.URL) + defer validateServer.Close() + + valid := p.ValidateSessionState(tc.SessionState, tc.ProxyGroupIds) + if valid != tc.ExpectedValid { + t.Errorf("got unexpected result. want=%v got=%v", tc.ExpectedValid, valid) + } + }) + } +} + +func TestSSOProviderRefreshSession(t *testing.T) { + testCases := []struct { + Name string + SessionState *SessionState + UserGroups []string + ProxyGroups []string + RefreshResponse *refreshResponse + ExpectedRefresh bool + ExpectedError string + }{ + { + Name: "no refresh if no refresh token", + SessionState: &SessionState{ + Email: "user@domain.com", + AccessToken: "token1234", + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + }, + RefreshResponse: &refreshResponse{ + Code: http.StatusBadRequest, + }, + ExpectedRefresh: false, + }, + { + Name: "no refresh if not yet expired", + SessionState: &SessionState{ + Email: "user@domain.com", + AccessToken: "token1234", + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + RefreshToken: "refresh1234", + }, + RefreshResponse: &refreshResponse{ + Code: http.StatusBadRequest, + }, + ExpectedRefresh: false, + }, + { + Name: "no refresh if redeem not responding", + SessionState: &SessionState{ + Email: "user@domain.com", + AccessToken: "token1234", + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + RefreshToken: "refresh1234", + }, + RefreshResponse: &refreshResponse{ + Code: http.StatusBadRequest, + }, + ExpectedRefresh: false, + ExpectedError: "got 400", + }, + { + Name: "no refresh if profile not responding", + SessionState: &SessionState{ + Email: "user@domain.com", + AccessToken: "token1234", + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + RefreshToken: "refresh1234", + }, + RefreshResponse: &refreshResponse{ + Code: http.StatusCreated, + ExpiresIn: 10, + AccessToken: "newToken1234", + }, + ProxyGroups: []string{"test1"}, + ExpectedRefresh: false, + ExpectedError: "got 500", + }, + { + Name: "no refresh if user no longer in group", + SessionState: &SessionState{ + Email: "user@domain.com", + AccessToken: "token1234", + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + RefreshToken: "refresh1234", + }, + UserGroups: []string{"useless"}, + ProxyGroups: []string{"test1"}, + RefreshResponse: &refreshResponse{ + Code: http.StatusCreated, + ExpiresIn: 10, + AccessToken: "newToken1234", + }, + ExpectedRefresh: false, + ExpectedError: "Group membership revoked", + }, + { + Name: "successful refresh if can redeem and user in group", + SessionState: &SessionState{ + Email: "user@domain.com", + AccessToken: "token1234", + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + RefreshToken: "refresh1234", + }, + UserGroups: []string{"test1"}, + ProxyGroups: []string{"test1"}, + RefreshResponse: &refreshResponse{ + Code: http.StatusCreated, + ExpiresIn: 10, + AccessToken: "newToken1234", + }, + ExpectedRefresh: true, + }, + { + Name: "successful refresh if provider unavailable but within grace period", + SessionState: &SessionState{ + GracePeriodStart: time.Now().Add(time.Duration(-1) * time.Hour), + RefreshToken: "refresh1234", + }, + RefreshResponse: &refreshResponse{ + Code: http.StatusTooManyRequests, + }, + ExpectedRefresh: true, + }, + { + Name: "failed refresh if provider unavailable and outside grace period", + SessionState: &SessionState{ + GracePeriodStart: time.Now().Add(time.Duration(-4) * time.Hour), + RefreshToken: "refresh1234", + }, + RefreshResponse: &refreshResponse{ + Code: http.StatusTooManyRequests, + }, + ExpectedRefresh: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + p := newSSOProvider() + p.GracePeriodTTL = time.Duration(3) * time.Hour + + groups := []string{} + if tc.ProxyGroups != nil { + groups = tc.ProxyGroups + } + + // set up redeem resource + var refreshServer *httptest.Server + body, err := json.Marshal(tc.RefreshResponse) + testutil.Equal(t, nil, err) + p.RefreshURL, refreshServer = newTestServer(tc.RefreshResponse.Code, body) + defer refreshServer.Close() + + // set up groups resource + var groupsServer *httptest.Server + if tc.UserGroups != nil { + body, err := json.Marshal(profileResponse{ + Email: tc.SessionState.Email, + Groups: tc.UserGroups, + }) + testutil.Equal(t, nil, err) + p.ProfileURL, groupsServer = newTestServer(http.StatusOK, body) + } else { + p.ProfileURL, groupsServer = newCodeTestServer(500) + } + defer groupsServer.Close() + + // run the endpoint + actualRefresh, err := p.RefreshSession(tc.SessionState, groups) + if tc.ExpectedRefresh != actualRefresh { + t.Fatalf("got unexpected refresh behavior. want=%v got=%v", tc.ExpectedRefresh, actualRefresh) + } + + if tc.ExpectedError != "" && err == nil { + t.Fatalf("expected error: %v got: %v", tc.ExpectedError, err) + } + + if tc.ExpectedError != "" && !strings.Contains(err.Error(), tc.ExpectedError) { + t.Fatalf("got unexpected result.\nwant=%v\ngot=%v\n", tc.ExpectedError, err.Error()) + } + }) + } +} diff --git a/internal/devproxy/templates.go b/internal/devproxy/templates.go new file mode 100644 index 00000000..3b696b53 --- /dev/null +++ b/internal/devproxy/templates.go @@ -0,0 +1,126 @@ +package devproxy + +import ( + "html/template" +) + +func getTemplates() *template.Template { + t := template.New("foo") + t = template.Must(t.Parse(`{{define "error.html"}} + + + + Error + + + + + +
+
+
+

{{.Title}}

+
+

+ {{.Message}}
+ HTTP {{.Code}} +

+ {{if ne .Code 403 }} +
+ +
+ {{end}} +
+
Secured by SSO
+
+ +{{end}}`)) + return t +} diff --git a/internal/devproxy/testdata/upstream_configs.yml b/internal/devproxy/testdata/upstream_configs.yml new file mode 100644 index 00000000..41eade61 --- /dev/null +++ b/internal/devproxy/testdata/upstream_configs.yml @@ -0,0 +1,10 @@ +- service: httpbin + default: + from: httpbin.sso.localtest.me + to: http://httpheader.net + +# - service: hello-world +# default: +# from: hello-world.sso.localtest.me +# to: http://httpheader.net + From 7964ce353c0989c823edb3468d4f9c01f13aedb9 Mon Sep 17 00:00:00 2001 From: Maliheh Date: Fri, 4 Jan 2019 11:17:49 -0800 Subject: [PATCH 2/5] removed comments, statsd references, and providers --- cmd/sso-devproxy/main.go | 5 +- internal/devproxy/collector/collector.go | 102 ---- internal/devproxy/dev_config.go | 42 +- internal/devproxy/devproxy.go | 22 +- internal/devproxy/devproxy_test.go | 3 - internal/devproxy/logging_handler.go | 63 +- internal/devproxy/metrics.go | 30 - internal/devproxy/options.go | 32 - internal/devproxy/providers/http_client.go | 17 - internal/devproxy/providers/internal_util.go | 85 --- .../devproxy/providers/internal_util_test.go | 132 ----- internal/devproxy/providers/provider_data.go | 31 - .../devproxy/providers/provider_default.go | 146 ----- .../providers/provider_default_test.go | 19 - internal/devproxy/providers/providers.go | 25 - internal/devproxy/providers/session_state.go | 61 -- .../devproxy/providers/session_state_test.go | 72 --- .../providers/singleflight_middleware.go | 145 ----- internal/devproxy/providers/sso.go | 385 ------------ internal/devproxy/providers/sso_test.go | 550 ------------------ 20 files changed, 29 insertions(+), 1938 deletions(-) delete mode 100644 internal/devproxy/collector/collector.go delete mode 100644 internal/devproxy/providers/http_client.go delete mode 100644 internal/devproxy/providers/internal_util.go delete mode 100644 internal/devproxy/providers/internal_util_test.go delete mode 100644 internal/devproxy/providers/provider_data.go delete mode 100644 internal/devproxy/providers/provider_default.go delete mode 100644 internal/devproxy/providers/provider_default_test.go delete mode 100644 internal/devproxy/providers/providers.go delete mode 100644 internal/devproxy/providers/session_state.go delete mode 100644 internal/devproxy/providers/session_state_test.go delete mode 100644 internal/devproxy/providers/singleflight_middleware.go delete mode 100644 internal/devproxy/providers/sso.go delete mode 100644 internal/devproxy/providers/sso_test.go diff --git a/cmd/sso-devproxy/main.go b/cmd/sso-devproxy/main.go index a0b7ffb1..b739a3db 100644 --- a/cmd/sso-devproxy/main.go +++ b/cmd/sso-devproxy/main.go @@ -5,9 +5,8 @@ import ( "net/http" "os" - log "github.com/buzzfeed/sso/internal/pkg/logging" - // "github.com/buzzfeed/sso/internal/pkg/options" "github.com/buzzfeed/sso/internal/devproxy" + log "github.com/buzzfeed/sso/internal/pkg/logging" "github.com/kelseyhightower/envconfig" ) @@ -42,7 +41,7 @@ func main() { Addr: fmt.Sprintf(":%d", opts.Port), ReadTimeout: opts.TCPReadTimeout, WriteTimeout: opts.TCPWriteTimeout, - Handler: devproxy.NewLoggingHandler(os.Stdout, proxy.Handler(), opts.RequestLogging), //, devproxy.StatsdClient), + Handler: devproxy.NewLoggingHandler(os.Stdout, proxy.Handler(), opts.RequestLogging), } logger.Fatal(s.ListenAndServe()) } diff --git a/internal/devproxy/collector/collector.go b/internal/devproxy/collector/collector.go deleted file mode 100644 index 02514934..00000000 --- a/internal/devproxy/collector/collector.go +++ /dev/null @@ -1,102 +0,0 @@ -package collector - -import ( - "runtime" - "time" - - "github.com/datadog/datadog-go/statsd" -) - -// Collector ticks periodically and emits runtime stats to datadog -type Collector struct { - // interval represents the interval inbetween ticks for stats collection - interval time.Duration - - // done, when closed, is used to signal the closure of the runtime polling goroutine - done chan struct{} - - // statsd client used to send metrics - client *statsd.Client -} - -// New creates a new collector that will periodically emit runtime statistics to datadog. -func New(client *statsd.Client, interval time.Duration) *Collector { - return &Collector{ - interval: interval, - client: client, - done: make(chan struct{}), - } -} - -// Run gathers statistics from package runtime and emits them to statsd via client -func (c *Collector) Run() { - tick := time.NewTicker(c.interval) - defer tick.Stop() - for { - select { - case <-c.done: - return - case <-tick.C: - c.emitStats() - } - } -} - -// Close signals the collector to close the polling goroutine, use for graceful shutdowns -func (c *Collector) Close() { - close(c.done) -} - -func (c *Collector) emitStats() { - c.emitCPUStats() - c.emitMemStats() -} - -func (c *Collector) emitCPUStats() { - c.gauge("cpu.goroutines", uint64(runtime.NumGoroutine())) - c.gauge("cpu.cgo_calls", uint64(runtime.NumCgoCall())) -} - -func (c *Collector) emitMemStats() { - m := &runtime.MemStats{} - runtime.ReadMemStats(m) - - // General - c.gauge("mem.alloc", m.Alloc) - c.gauge("mem.total", m.TotalAlloc) - c.gauge("mem.sys", m.Sys) - c.gauge("mem.lookups", m.Lookups) - c.gauge("mem.malloc", m.Mallocs) - c.gauge("mem.frees", m.Frees) - - // Heap - c.gauge("mem.heap.alloc", m.HeapAlloc) - c.gauge("mem.heap.sys", m.HeapSys) - c.gauge("mem.heap.idle", m.HeapIdle) - c.gauge("mem.heap.inuse", m.HeapInuse) - c.gauge("mem.heap.released", m.HeapReleased) - c.gauge("mem.heap.objects", m.HeapObjects) - - // Stack - c.gauge("mem.stack.inuse", m.StackInuse) - c.gauge("mem.stack.sys", m.StackSys) - c.gauge("mem.stack.mspan_inuse", m.MSpanInuse) - c.gauge("mem.stack.mspan_sys", m.MSpanSys) - c.gauge("mem.stack.mcache_inuse", m.MCacheInuse) - c.gauge("mem.stack.mcache_sys", m.MCacheSys) - - // Garbage Collection - c.gauge("mem.gc.sys", m.GCSys) - c.gauge("mem.gc.next", m.NextGC) - c.gauge("mem.gc.last", m.LastGC) - c.gauge("mem.gc.pause_total", m.PauseTotalNs) - c.gauge("mem.gc.pause", m.PauseNs[(m.NumGC+255)%256]) - c.gauge("mem.gc.count", uint64(m.NumGC)) - - // Other - c.gauge("mem.othersys", m.OtherSys) -} - -func (c *Collector) gauge(key string, val uint64) { - c.client.Gauge(key, float64(val), nil, 1.0) -} diff --git a/internal/devproxy/dev_config.go b/internal/devproxy/dev_config.go index 295ac3fe..0afa0866 100644 --- a/internal/devproxy/dev_config.go +++ b/internal/devproxy/dev_config.go @@ -7,7 +7,6 @@ import ( "strings" "time" - // "github.com/18F/hmacauth" "github.com/imdario/mergo" "gopkg.in/yaml.v2" ) @@ -48,10 +47,7 @@ type UpstreamConfig struct { ExtraRoutes []*RouteConfig `yaml:"extra_routes"` // Generated at Parse Time - Route interface{} // note: :/ - - // SkipAuthCompiledRegex []*regexp.Regexp - // AllowedGroups []string + Route interface{} // note: :/ Timeout time.Duration FlushInterval time.Duration HeaderOverrides map[string]string @@ -76,10 +72,8 @@ type RouteConfig struct { // * flush_interval - interval at which the proxy should flush data to the browser type OptionsConfig struct { HeaderOverrides map[string]string `yaml:"header_overrides"` - // SkipAuthRegex []string `yaml:"skip_auth_regex"` - // AllowedGroups []string `yaml:"allowed_groups"` - Timeout time.Duration `yaml:"timeout"` - FlushInterval time.Duration `yaml:"flush_interval"` + Timeout time.Duration `yaml:"timeout"` + FlushInterval time.Duration `yaml:"flush_interval"` } // ErrParsingConfig is an error specific to config parsing. @@ -180,23 +174,6 @@ func loadServiceConfigs(raw []byte, cluster, scheme string, configVars map[strin } } - // for _, proxy := range configs { - // key := fmt.Sprintf("%s_signing_key", proxy.Service) - // signingKey, ok := configVars[key] - // if !ok { - // continue - // } - // auth, err := generateHmacAuth(signingKey) - // if err != nil { - // return nil, &ErrParsingConfig{ - // Message: fmt.Sprintf("unable to generate hmac auth for %s", proxy.Service), - // Err: err, - // } - // } - // proxy.HMACAuth = auth - - // } - return configs, nil } @@ -345,19 +322,6 @@ func parseOptionsConfig(proxy *UpstreamConfig) error { return nil } - // We compile all the regexes in SkipAuth Regex - // for _, uncompiled := range proxy.RouteConfig.Options.SkipAuthRegex { - // compiled, err := regexp.Compile(uncompiled) - // if err != nil { - // return &ErrParsingConfig{ - // Message: "unable to compile skip auth regex", - // Err: err, - // } - // } - // proxy.SkipAuthCompiledRegex = append(proxy.SkipAuthCompiledRegex, compiled) - // } - - // proxy.AllowedGroups = proxy.RouteConfig.Options.AllowedGroups proxy.Timeout = proxy.RouteConfig.Options.Timeout proxy.FlushInterval = proxy.RouteConfig.Options.FlushInterval proxy.HeaderOverrides = proxy.RouteConfig.Options.HeaderOverrides diff --git a/internal/devproxy/devproxy.go b/internal/devproxy/devproxy.go index 9a86b7b1..74d8d1eb 100644 --- a/internal/devproxy/devproxy.go +++ b/internal/devproxy/devproxy.go @@ -2,29 +2,17 @@ package devproxy import ( "encoding/json" - "time" - - // "errors" "fmt" "html/template" "io" - - // "net" "net/http" "net/http/httputil" "net/url" - - // "reflect" "regexp" "strings" + "time" - // "time" - - // "github.com/buzzfeed/sso/internal/pkg/aead" log "github.com/buzzfeed/sso/internal/pkg/logging" - // "github.com/buzzfeed/sso/internal/dev/collector" - // "github.com/18F/hmacauth" - // "github.com/datadog/datadog-go/statsd" ) // SignatureHeader is the header name where the signed request header is stored. @@ -261,9 +249,6 @@ func (p *DevProxy) Handler() http.Handler { // order as applied here (i.e., we want to validate the host _first_ when // processing a request) var handler http.Handler = mux - // if p.CookieSecure { - // handler = requireHTTPS(handler) - // } handler = p.setResponseHeaderOverrides(handler) handler = setSecurityHeaders(handler) handler = p.validateHost(handler) @@ -352,11 +337,6 @@ func (p *DevProxy) Proxy(rw http.ResponseWriter, req *http.Request) { func (p *DevProxy) UnknownHost(rw http.ResponseWriter, req *http.Request) { logger := log.NewLogEntry() - // tags := []string{ - // fmt.Sprintf("action:%s", GetActionTag(req)), - // "error:unknown_host", - // } - // p.StatsdClient.Incr("application_error", tags, 1.0) logger.WithRequestHost(req.Host).Error("unknown host") http.Error(rw, "", statusInvalidHost) } diff --git a/internal/devproxy/devproxy_test.go b/internal/devproxy/devproxy_test.go index adab4134..c734e655 100644 --- a/internal/devproxy/devproxy_test.go +++ b/internal/devproxy/devproxy_test.go @@ -1,10 +1,7 @@ package devproxy import ( - // "crypto" - // "encoding/base64" "encoding/json" - // "errors" "fmt" "io" "io/ioutil" diff --git a/internal/devproxy/logging_handler.go b/internal/devproxy/logging_handler.go index e81248de..44c26004 100644 --- a/internal/devproxy/logging_handler.go +++ b/internal/devproxy/logging_handler.go @@ -6,16 +6,12 @@ package devproxy import ( "io" "net/http" - - // "net/url" + "net/url" "strings" - // "time" - // log "github.com/buzzfeed/sso/internal/pkg/logging" - // "github.com/datadog/datadog-go/statsd" -) + "time" -// Used to stash the authenticated user in the response for access when logging requests. -// const loggingUserHeader = "SSO-Authenticated-User" + log "github.com/buzzfeed/sso/internal/pkg/logging" +) // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status // code and body size @@ -30,14 +26,6 @@ func (l *responseLogger) Header() http.Header { return l.w.Header() } -// func (l *responseLogger) extractUser() { -// authInfo := l.w.Header().Get(loggingUserHeader) -// if authInfo != "" { -// l.authInfo = authInfo -// l.w.Header().Del(loggingUserHeader) -// } -// } - func (l *responseLogger) Write(b []byte) (int, error) { if l.status == 0 { // The status will be StatusOK if WriteHeader has not been called yet @@ -72,47 +60,42 @@ func (l *responseLogger) Flush() { type loggingHandler struct { writer io.Writer handler http.Handler - // StatsdClient *statsd.Client enabled bool } -// NewLoggingHandler returns a new loggingHandler that wraps a handler, statsd client, and writer. -func NewLoggingHandler(out io.Writer, h http.Handler, v bool /*, StatsdClient *statsd.Client*/) http.Handler { +// NewLoggingHandler returns a new loggingHandler that wraps a handler, and writer. +func NewLoggingHandler(out io.Writer, h http.Handler, v bool) http.Handler { return loggingHandler{writer: out, handler: h, enabled: v, - // StatsdClient: StatsdClient, } } func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - // now := time.Now() - // url := *req.URL + t := time.Now() + url := *req.URL logger := &responseLogger{w: w} h.handler.ServeHTTP(logger, req) if !h.enabled { return } - // logRequest(logger.authInfo, req, url, now, logger.Status(), h.StatsdClient) + requestDuration := time.Now().Sub(t) + logRequest(req, url, requestDuration, logger.Status()) } -// logRequest logs information about a request -// func logRequest(username string, req *http.Request, url url.URL, ts time.Time, status int, StatsdClient *statsd.Client) { -// duration := time.Now().Sub(ts) - -// // Convert duration to floating point milliseconds -// // https://github.com/golang/go/issues/5491#issuecomment-66079585 -// durationMS := duration.Seconds() * 1e3 - -// uri := req.Host + url.RequestURI() - -// logger := log.NewLogEntry() -// logger.WithHTTPStatus(status).WithRequestMethod(req.Method).WithRequestURI( -// uri).WithUserAgent(req.Header.Get("User-Agent")).WithRemoteAddress( -// getRemoteAddr(req)).WithRequestDurationMs(durationMS).WithUser( -// username).WithAction(GetActionTag(req)).Info() -// logRequestMetrics(req, duration, status, StatsdClient) -// } +// logRequests creates a log message from the request status, method, url, proxy host and duration of the request +func logRequest(req *http.Request, url url.URL, requestDuration time.Duration, status int) { + // Convert duration to floating point milliseconds + // https://github.com/golang/go/issues/5491#issuecomment-66079585 + durationMS := requestDuration.Seconds() * 1e3 + + logger := log.NewLogEntry() + logger.WithHTTPStatus(status).WithRequestMethod(req.Method).WithRequestURI( + url.RequestURI()).WithUserAgent( + req.Header.Get("User-Agent")).WithRemoteAddress( + getRemoteAddr(req)).WithRequestDurationMs( + durationMS).WithAction(GetActionTag(req)).Info() +} // getRemoteAddr returns the client IP address from a request. If present, the // X-Forwarded-For header is assumed to be set by a load balancer, and its diff --git a/internal/devproxy/metrics.go b/internal/devproxy/metrics.go index bc7b5795..68a56048 100644 --- a/internal/devproxy/metrics.go +++ b/internal/devproxy/metrics.go @@ -1,12 +1,7 @@ package devproxy import ( - "fmt" - // "net" "net/http" - // "strconv" - "time" - // "github.com/datadog/datadog-go/statsd" ) // GetActionTag returns the action triggered by an http.Request . @@ -27,28 +22,3 @@ func GetActionTag(req *http.Request) string { } return "proxy" } - -// logMetrics logs all metrics surrounding a given request to the metricsWriter -func logRequestMetrics(req *http.Request, requestDuration time.Duration, status int) { //, StatsdClient *statsd.Client) { - // Normalize proxyHost for a) invalid requests or b) LB health checks to - // avoid polluting the proxy_host tag's value space - proxyHost := req.Host - if status == statusInvalidHost { - proxyHost = "_unknown" - } - if req.URL.Path == "/ping" { - proxyHost = "_healthcheck" - } - - // tags := []string{ - fmt.Sprintf("method:%s", req.Method) - fmt.Sprintf("status_code:%d", status) - fmt.Sprintf("status_category:%dxx", status/100) - fmt.Sprintf("action:%s", GetActionTag(req)) - fmt.Sprintf("proxy_host:%s", proxyHost) - // } - - // TODO: eventually make rates configurable - // StatsdClient.Timing("request", requestDuration, tags, 1.0) - -} diff --git a/internal/devproxy/options.go b/internal/devproxy/options.go index 0a02ef1a..5465f655 100644 --- a/internal/devproxy/options.go +++ b/internal/devproxy/options.go @@ -1,34 +1,22 @@ package devproxy import ( - // "encoding/base64" - // "errors" "fmt" "io/ioutil" - - // "net/http" "net/url" "os" "strings" "time" - // "github.com/datadog/datadog-go/statsd" ) // Options are configuration options that can be set by Environment Variables // Port - int - port to listen on for HTTP clients -// ProviderURLString - the URL for the provider in this environment: "https://sso-auth.example.com" // UpstreamConfigsFile - the path to upstream configs file // Cluster - the cluster in which this is running, used for upstream configs // Scheme - the default scheme, used for upstream configs // DefaultUpstreamTimeout - the default time period to wait for a response from an upstream // TCPWriteTimeout - http server tcp write timeout // TCPReadTimeout - http server tcp read timeout -// SessionLifetimeTTL - time to live for a session lifetime -// SessionValidTTL - time to live for a valid session -// GracePeriodTTL - time to reuse session data when provider unavailable -// RequestLoging - boolean whether or not to log requests -// StatsdHost - host addr for statsd client to listen on -// StatsdPort - port for statsdclient to listen on type Options struct { Port int `envconfig:"PORT" default:"4180"` @@ -45,11 +33,6 @@ type Options struct { RequestLogging bool `envconfig:"REQUEST_LOGGING" default:"true"` - // StatsdHost string `envconfig:"STATSD_HOST"` - // StatsdPort int `envconfig:"STATSD_PORT"` - - // StatsdClient *statsd.Client - // This is an override for supplying template vars at test time testTemplateVars map[string]string @@ -84,21 +67,6 @@ func (o *Options) Validate() error { msgs = append(msgs, "missing setting: upstream-configs") } - // if o.StatsdHost == "" { - // msgs = append(msgs, "missing setting: statsd-host") - // } - - // if o.StatsdPort == 0 { - // msgs = append(msgs, "missing setting: statsd-port") - // } - // if o.StatsdHost != "" && o.StatsdPort != 0 { - // StatsdClient, err := newStatsdClient(o) - // if err != nil { - // msgs = append(msgs, fmt.Sprintf("error creating statsd client error=%q", err)) - // } - // o.StatsdClient = StatsdClient - // } - if o.UpstreamConfigsFile != "" { rawBytes, err := ioutil.ReadFile(o.UpstreamConfigsFile) if err != nil { diff --git a/internal/devproxy/providers/http_client.go b/internal/devproxy/providers/http_client.go deleted file mode 100644 index cee3d9bd..00000000 --- a/internal/devproxy/providers/http_client.go +++ /dev/null @@ -1,17 +0,0 @@ -package providers - -import ( - "net" - "net/http" - "time" -) - -var httpClient = &http.Client{ - Timeout: time.Second * 5, - Transport: &http.Transport{ - Dial: (&net.Dialer{ - Timeout: 2 * time.Second, - }).Dial, - TLSHandshakeTimeout: 2 * time.Second, - }, -} diff --git a/internal/devproxy/providers/internal_util.go b/internal/devproxy/providers/internal_util.go deleted file mode 100644 index c3714427..00000000 --- a/internal/devproxy/providers/internal_util.go +++ /dev/null @@ -1,85 +0,0 @@ -package providers - -import ( - "io/ioutil" - "net/http" - "net/url" - - log "github.com/buzzfeed/sso/internal/pkg/logging" -) - -// stripToken is a helper function to obfuscate "access_token" -// query parameters -func stripToken(endpoint string) string { - return stripParam("access_token", endpoint) -} - -// stripParam generalizes the obfuscation of a particular -// query parameter - typically 'access_token' or 'client_secret' -// The parameter's second half is replaced by '...' and returned -// as part of the encoded query parameters. -// If the target parameter isn't found, the endpoint is returned -// unmodified. -func stripParam(param, endpoint string) string { - logger := log.NewLogEntry() - - u, err := url.Parse(endpoint) - if err != nil { - logger.WithURLParam(param).Error(err, "error attempting to strip parameter") - return endpoint - } - - if u.RawQuery != "" { - values, err := url.ParseQuery(u.RawQuery) - if err != nil { - logger.WithURLParam(param).Error("error attempting to strip parameter") - return u.String() - } - - if val := values.Get(param); val != "" { - values.Set(param, val[:(len(val)/2)]+"...") - u.RawQuery = values.Encode() - return u.String() - } - } - - return endpoint -} - -// validateToken returns true if token is valid -func validateToken(p Provider, accessToken string, header http.Header) bool { - logger := log.NewLogEntry() - - if accessToken == "" || p.Data().ValidateURL == nil { - return false - } - endpoint := p.Data().ValidateURL.String() - if len(header) == 0 { - params := url.Values{"access_token": {accessToken}} - endpoint = endpoint + "?" + params.Encode() - } - - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - logger.Error(err, "token validation request failed") - return false - } - req.Header = header - - resp, err := httpClient.Do(req) - if err != nil { - logger.Error(err, "token validation request failed") - return false - } - - body, _ := ioutil.ReadAll(resp.Body) - resp.Body.Close() - logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) - - if resp.StatusCode == 200 { - return true - } - logger.WithHTTPStatus(resp.StatusCode).WithResponseBody(body).Info( - "token validation request failed") - return false -} diff --git a/internal/devproxy/providers/internal_util_test.go b/internal/devproxy/providers/internal_util_test.go deleted file mode 100644 index ca5a070b..00000000 --- a/internal/devproxy/providers/internal_util_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package providers - -import ( - "errors" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/buzzfeed/sso/internal/pkg/testutil" -) - -type ValidateSessionStateTestProvider struct { - *ProviderData -} - -func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { - return "", errors.New("not implemented") -} - -// Note that we're testing the internal validateToken() used to implement -// several Provider's ValidateSessionState() implementations -func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState, g []string) bool { - return false -} - -type ValidateSessionStateTest struct { - backend *httptest.Server - responseCode int - provider *ValidateSessionStateTestProvider - header http.Header -} - -func NewValidateSessionStateTest() *ValidateSessionStateTest { - var vtTest ValidateSessionStateTest - - vtTest.backend = httptest.NewServer( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/oauth/tokeninfo" { - w.WriteHeader(500) - w.Write([]byte("unknown URL")) - } - tokenParam := r.FormValue("access_token") - if tokenParam == "" { - missing := false - receivedHeaders := r.Header - for k := range vtTest.header { - received := receivedHeaders.Get(k) - expected := vtTest.header.Get(k) - if received == "" || received != expected { - missing = true - } - } - if missing { - w.WriteHeader(500) - w.Write([]byte("no token param and missing or incorrect headers")) - } - } - w.WriteHeader(vtTest.responseCode) - w.Write([]byte("only code matters; contents disregarded")) - - })) - backendURL, _ := url.Parse(vtTest.backend.URL) - vtTest.provider = &ValidateSessionStateTestProvider{ - ProviderData: &ProviderData{ - ValidateURL: &url.URL{ - Scheme: "http", - Host: backendURL.Host, - Path: "/oauth/tokeninfo", - }, - }, - } - vtTest.responseCode = 200 - return &vtTest -} - -func (vtTest *ValidateSessionStateTest) Close() { - vtTest.backend.Close() -} - -func TestValidateSessionStateValidToken(t *testing.T) { - vtTest := NewValidateSessionStateTest() - defer vtTest.Close() - testutil.Equal(t, true, validateToken(vtTest.provider, "foobar", nil)) -} - -func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { - vtTest := NewValidateSessionStateTest() - defer vtTest.Close() - vtTest.header = make(http.Header) - vtTest.header.Set("Authorization", "Bearer foobar") - testutil.Equal(t, true, - validateToken(vtTest.provider, "foobar", vtTest.header)) -} - -func TestValidateSessionStateEmptyToken(t *testing.T) { - vtTest := NewValidateSessionStateTest() - defer vtTest.Close() - testutil.Equal(t, false, validateToken(vtTest.provider, "", nil)) -} - -func TestValidateSessionStateEmptyValidateURL(t *testing.T) { - vtTest := NewValidateSessionStateTest() - defer vtTest.Close() - vtTest.provider.Data().ValidateURL = nil - testutil.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) -} - -func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { - vtTest := NewValidateSessionStateTest() - // Close immediately to simulate a network failure - vtTest.Close() - testutil.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) -} - -func TestValidateSessionStateExpiredToken(t *testing.T) { - vtTest := NewValidateSessionStateTest() - defer vtTest.Close() - vtTest.responseCode = 401 - testutil.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) -} - -func TestStripTokenNotPresent(t *testing.T) { - test := "http://local.test/api/test?a=1&b=2" - testutil.Equal(t, test, stripToken(test)) -} - -func TestStripToken(t *testing.T) { - test := "http://local.test/api/test?access_token=deadbeef&b=1&c=2" - expected := "http://local.test/api/test?access_token=dead...&b=1&c=2" - testutil.Equal(t, expected, stripToken(test)) -} diff --git a/internal/devproxy/providers/provider_data.go b/internal/devproxy/providers/provider_data.go deleted file mode 100644 index afaddbf0..00000000 --- a/internal/devproxy/providers/provider_data.go +++ /dev/null @@ -1,31 +0,0 @@ -package providers - -import ( - "net/url" - "time" -) - -// ProviderData holds the fields associated with providers -// necessary to implement the Provider interface. -type ProviderData struct { - ProviderName string - ProviderURL *url.URL - ClientID string - ClientSecret string - SignInURL *url.URL - SignOutURL *url.URL - RedeemURL *url.URL - RefreshURL *url.URL - ProfileURL *url.URL - ProtectedResource *url.URL - ValidateURL *url.URL - Scope string - ApprovalPrompt string - - SessionValidTTL time.Duration - SessionLifetimeTTL time.Duration - GracePeriodTTL time.Duration -} - -// Data returns the ProviderData struct -func (p *ProviderData) Data() *ProviderData { return p } diff --git a/internal/devproxy/providers/provider_default.go b/internal/devproxy/providers/provider_default.go deleted file mode 100644 index 02496b68..00000000 --- a/internal/devproxy/providers/provider_default.go +++ /dev/null @@ -1,146 +0,0 @@ -package providers - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "net/http" - "net/url" - "time" -) - -// Redeem takes a redirectURL and code, creates some params and redeems the request -func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { - if code == "" { - err = errors.New("missing code") - return - } - - params := url.Values{} - params.Add("redirect_uri", redirectURL) - params.Add("client_id", p.ClientID) - params.Add("client_secret", p.ClientSecret) - params.Add("code", code) - params.Add("grant_type", "authorization_code") - if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { - params.Add("resource", p.ProtectedResource.String()) - } - - var req *http.Request - req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = httpClient.Do(req) - if err != nil { - return nil, err - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } - - // blindly try json and x-www-form-urlencoded - var jsonResponse struct { - AccessToken string `json:"access_token"` - } - err = json.Unmarshal(body, &jsonResponse) - if err == nil { - s = &SessionState{ - AccessToken: jsonResponse.AccessToken, - } - return - } - - var v url.Values - v, err = url.ParseQuery(string(body)) - if err != nil { - return - } - if a := v.Get("access_token"); a != "" { - s = &SessionState{AccessToken: a} - } else { - err = fmt.Errorf("no access token found %s", body) - } - return -} - -// GetSignInURL with typical oauth parameters -func (p *ProviderData) GetSignInURL(redirectURL *url.URL, state string) *url.URL { - var a url.URL - a = *p.SignInURL - now := time.Now() - rawRedirect := redirectURL.String() - params, _ := url.ParseQuery(a.RawQuery) - params.Set("redirect_uri", rawRedirect) - params.Add("scope", p.Scope) - params.Set("client_id", p.ClientID) - params.Set("response_type", "code") - params.Add("state", state) - params.Set("ts", fmt.Sprint(now.Unix())) - params.Set("sig", p.signRedirectURL(rawRedirect, now)) - a.RawQuery = params.Encode() - return &a -} - -// GetSignOutURL creates and returns the sign out URL, given a redirectURL -func (p *ProviderData) GetSignOutURL(redirectURL *url.URL) *url.URL { - var a url.URL - a = *p.SignOutURL - now := time.Now() - rawRedirect := redirectURL.String() - params, _ := url.ParseQuery(a.RawQuery) - params.Add("redirect_uri", rawRedirect) - params.Set("ts", fmt.Sprint(now.Unix())) - params.Set("sig", p.signRedirectURL(rawRedirect, now)) - a.RawQuery = params.Encode() - return &a -} - -// signRedirectURL signs the redirect url string, given a timestamp, and returns it -func (p *ProviderData) signRedirectURL(rawRedirect string, timestamp time.Time) string { - h := hmac.New(sha256.New, []byte(p.ClientSecret)) - h.Write([]byte(rawRedirect)) - h.Write([]byte(fmt.Sprint(timestamp.Unix()))) - return base64.URLEncoding.EncodeToString(h.Sum(nil)) -} - -// GetEmailAddress returns an email address or error -func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { - return "", errors.New("not implemented") -} - -// ValidateGroup validates that the provided email exists in the configured provider email group(s). -func (p *ProviderData) ValidateGroup(_ string, _ []string) ([]string, bool, error) { - return []string{}, true, nil -} - -// UserGroups returns a list of users -func (p *ProviderData) UserGroups(string, []string) ([]string, error) { - return []string{}, nil -} - -// ValidateSessionState calls to validate the token given the session and groups -func (p *ProviderData) ValidateSessionState(s *SessionState, groups []string) bool { - return validateToken(p, s.AccessToken, nil) -} - -// RefreshSession returns a boolean or error -func (p *ProviderData) RefreshSession(s *SessionState, group []string) (bool, error) { - return false, nil -} diff --git a/internal/devproxy/providers/provider_default_test.go b/internal/devproxy/providers/provider_default_test.go deleted file mode 100644 index 2aa48bfe..00000000 --- a/internal/devproxy/providers/provider_default_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package providers - -import ( - "testing" - "time" - - "github.com/buzzfeed/sso/internal/pkg/testutil" -) - -func TestRefresh(t *testing.T) { - p := &ProviderData{} - refreshed, err := p.RefreshSession(&SessionState{ - RefreshDeadline: time.Now().Add(time.Duration(-11) * time.Minute), - }, - []string{}, - ) - testutil.Equal(t, false, refreshed) - testutil.Equal(t, nil, err) -} diff --git a/internal/devproxy/providers/providers.go b/internal/devproxy/providers/providers.go deleted file mode 100644 index 4d33a796..00000000 --- a/internal/devproxy/providers/providers.go +++ /dev/null @@ -1,25 +0,0 @@ -package providers - -import ( - "net/url" - - "github.com/datadog/datadog-go/statsd" -) - -// Provider is an interface exposing functions necessary to authenticate with a given provider. -type Provider interface { - Data() *ProviderData - GetEmailAddress(*SessionState) (string, error) - Redeem(string, string) (*SessionState, error) - ValidateGroup(string, []string) ([]string, bool, error) - UserGroups(string, []string) ([]string, error) - ValidateSessionState(*SessionState, []string) bool - GetSignInURL(redirectURL *url.URL, finalRedirect string) *url.URL - GetSignOutURL(redirectURL *url.URL) *url.URL - RefreshSession(*SessionState, []string) (bool, error) -} - -// New returns a new sso Provider -func New(provider string, p *ProviderData, sc *statsd.Client) Provider { - return NewSSOProvider(p, sc) -} diff --git a/internal/devproxy/providers/session_state.go b/internal/devproxy/providers/session_state.go deleted file mode 100644 index c8f9d8c6..00000000 --- a/internal/devproxy/providers/session_state.go +++ /dev/null @@ -1,61 +0,0 @@ -package providers - -import ( - "time" - - "github.com/buzzfeed/sso/internal/pkg/aead" -) - -// SessionState is our object that keeps track of a user's session state -type SessionState struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - - RefreshDeadline time.Time `json:"refresh_deadline"` - LifetimeDeadline time.Time `json:"lifetime_deadline"` - ValidDeadline time.Time `json:"valid_deadline"` - GracePeriodStart time.Time `json:"grace_period_start"` - - Email string `json:"email"` - User string `json:"user"` - Groups []string `json:"groups"` -} - -// LifetimePeriodExpired returns true if the lifetime has expired -func (s *SessionState) LifetimePeriodExpired() bool { - return isExpired(s.LifetimeDeadline) -} - -// RefreshPeriodExpired returns true if the refresh period has expired -func (s *SessionState) RefreshPeriodExpired() bool { - return isExpired(s.RefreshDeadline) -} - -// ValidationPeriodExpired returns true if the validation period has expired -func (s *SessionState) ValidationPeriodExpired() bool { - return isExpired(s.ValidDeadline) -} - -func isExpired(t time.Time) bool { - if t.Before(time.Now()) { - return true - } - return false -} - -// MarshalSession marshals the session state as JSON, encrypts the JSON using the -// given cipher, and base64-encodes the result -func MarshalSession(s *SessionState, c aead.Cipher) (string, error) { - return c.Marshal(s) -} - -// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the -// byte slice using the pased cipher, and unmarshals the resulting JSON into a session state struct -func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) { - s := &SessionState{} - err := c.Unmarshal(value, s) - if err != nil { - return nil, err - } - return s, nil -} diff --git a/internal/devproxy/providers/session_state_test.go b/internal/devproxy/providers/session_state_test.go deleted file mode 100644 index 24cd138a..00000000 --- a/internal/devproxy/providers/session_state_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package providers - -import ( - "reflect" - "testing" - "time" - - "github.com/buzzfeed/sso/internal/pkg/aead" -) - -func TestSessionStateSerialization(t *testing.T) { - secret := "0123456789abcdefghijklmnopqrstuv" - - c, err := aead.NewMiscreantCipher([]byte(secret)) - if err != nil { - t.Fatalf("expected to be able to create cipher: %v", err) - } - - want := &SessionState{ - AccessToken: "token1234", - RefreshToken: "refresh4321", - - LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), - RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), - ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(), - - Email: "user@domain.com", - User: "user", - } - - ciphertext, err := MarshalSession(want, c) - if err != nil { - t.Fatalf("expected to be encode session: %v", err) - } - - got, err := UnmarshalSession(ciphertext, c) - if err != nil { - t.Fatalf("expected to be decode session: %v", err) - } - - if !reflect.DeepEqual(want, got) { - t.Logf("want: %#v", want) - t.Logf(" got: %#v", got) - t.Errorf("encoding and decoding session resulted in unexpected output") - } -} - -func TestSessionStateExpirations(t *testing.T) { - session := &SessionState{ - AccessToken: "token1234", - RefreshToken: "refresh4321", - - LifetimeDeadline: time.Now().Add(-1 * time.Hour), - RefreshDeadline: time.Now().Add(-1 * time.Hour), - ValidDeadline: time.Now().Add(-1 * time.Minute), - - Email: "user@domain.com", - User: "user", - } - - if !session.LifetimePeriodExpired() { - t.Errorf("expcted lifetime period to be expired") - } - - if !session.RefreshPeriodExpired() { - t.Errorf("expcted lifetime period to be expired") - } - - if !session.ValidationPeriodExpired() { - t.Errorf("expcted lifetime period to be expired") - } -} diff --git a/internal/devproxy/providers/singleflight_middleware.go b/internal/devproxy/providers/singleflight_middleware.go deleted file mode 100644 index 6f583bab..00000000 --- a/internal/devproxy/providers/singleflight_middleware.go +++ /dev/null @@ -1,145 +0,0 @@ -package providers - -import ( - "errors" - "fmt" - "net/url" - "sort" - "strings" - - "github.com/buzzfeed/sso/internal/proxy/singleflight" - - "github.com/datadog/datadog-go/statsd" -) - -var ( - // This is a compile-time check to make sure our types correctly implement the interface: - // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae - _ Provider = &SingleFlightProvider{} -) - -// Error message for ErrUnexpectedReturnType -var ( - ErrUnexpectedReturnType = errors.New("received unexpected return type from single flight func call") -) - -// SingleFlightProvider middleware provider that multiple requests for the same object -// to be processed as a single request. This is often called request collpasing or coalesce. -// This middleware leverages the golang singlelflight provider, with modifications for metrics. -// -// It's common among HTTP reverse proxy cache servers such as nginx, Squid or Varnish - they all call it something else but works similarly. -// -// * https://www.varnish-cache.org/docs/3.0/tutorial/handling_misbehaving_servers.html -// * http://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_cache_lock -// * http://wiki.squid-cache.org/Features/CollapsedForwarding -type SingleFlightProvider struct { - StatsdClient *statsd.Client - - provider Provider - - single *singleflight.Group -} - -// NewSingleFlightProvider instatiates a SingleFlightProvider given a provider and statsdClient -func NewSingleFlightProvider(provider Provider, StatsdClient *statsd.Client) *SingleFlightProvider { - return &SingleFlightProvider{ - provider: provider, - single: &singleflight.Group{}, - StatsdClient: StatsdClient, - } -} - -func (p *SingleFlightProvider) do(endpoint, key string, fn func() (interface{}, error)) (interface{}, error) { - compositeKey := fmt.Sprintf("%s/%s", endpoint, key) - resp, shared, err := p.single.Do(compositeKey, fn) - if shared > 0 { - tags := []string{fmt.Sprintf("endpoint:%s", endpoint)} - p.StatsdClient.Incr("provider.singleflight", tags, float64(shared)) - } - return resp, err -} - -// Data calls the provider's Data function -func (p *SingleFlightProvider) Data() *ProviderData { - return p.provider.Data() -} - -// GetEmailAddress calls the provider function getEmailAddress -func (p *SingleFlightProvider) GetEmailAddress(s *SessionState) (string, error) { - return p.provider.GetEmailAddress(s) -} - -// Redeem takes the redirectURL and a code and calls the provider function Redeem -func (p *SingleFlightProvider) Redeem(redirectURL, code string) (*SessionState, error) { - return p.provider.Redeem(redirectURL, code) -} - -// ValidateGroup takes an email, allowedGroups, and userGroups and passes it to the provider's ValidateGroup function and returns the response -func (p *SingleFlightProvider) ValidateGroup(email string, allowedGroups []string) ([]string, bool, error) { - return p.provider.ValidateGroup(email, allowedGroups) -} - -// UserGroups takes an email and passes it to the provider's UserGroups function and returns the response -func (p *SingleFlightProvider) UserGroups(email string, groups []string) ([]string, error) { - // sort the groups so that other requests may be able to use the cached request - sort.Strings(groups) - response, err := p.do("UserGroups", fmt.Sprintf("%s:%s", email, strings.Join(groups, ",")), func() (interface{}, error) { - return p.provider.UserGroups(email, groups) - }) - if err != nil { - return nil, err - } - - groups, ok := response.([]string) - if !ok { - return nil, ErrUnexpectedReturnType - } - - return groups, nil -} - -// ValidateSessionState calls the provider's ValidateSessionState function and returns the response -func (p *SingleFlightProvider) ValidateSessionState(s *SessionState, allowedGroups []string) bool { - response, err := p.do("ValidateSessionState", s.AccessToken, func() (interface{}, error) { - valid := p.provider.ValidateSessionState(s, allowedGroups) - return valid, nil - }) - if err != nil { - return false - } - - valid, ok := response.(bool) - if !ok { - return false - } - - return valid -} - -// GetSignInURL calls the GetSignInURL for the provider, which will return the sign in url -func (p *SingleFlightProvider) GetSignInURL(redirectURI *url.URL, finalRedirect string) *url.URL { - return p.provider.GetSignInURL(redirectURI, finalRedirect) -} - -// GetSignOutURL calls the GetSignOutURL for the provider, which will return the sign out url -func (p *SingleFlightProvider) GetSignOutURL(redirectURI *url.URL) *url.URL { - return p.provider.GetSignOutURL(redirectURI) -} - -// RefreshSession takes in a SessionState and allowedGroups and -// returns false if the session is not refreshed and true if it is. -func (p *SingleFlightProvider) RefreshSession(s *SessionState, allowedGroups []string) (bool, error) { - response, err := p.do("RefreshSession", s.RefreshToken, func() (interface{}, error) { - return p.provider.RefreshSession(s, allowedGroups) - }) - if err != nil { - return false, err - } - - r, ok := response.(bool) - if !ok { - return false, ErrUnexpectedReturnType - } - - return r, nil -} diff --git a/internal/devproxy/providers/sso.go b/internal/devproxy/providers/sso.go deleted file mode 100644 index ffdb7e1c..00000000 --- a/internal/devproxy/providers/sso.go +++ /dev/null @@ -1,385 +0,0 @@ -package providers - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "io/ioutil" - "net/http" - "net/url" - "os" - "strings" - "time" - - log "github.com/buzzfeed/sso/internal/pkg/logging" - "github.com/datadog/datadog-go/statsd" -) - -var ( - // This is a compile-time check to make sure our types correctly implement the interface: - // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae - _ Provider = &SSOProvider{} -) - -// Errors -var ( - ErrMissingRefreshToken = errors.New("missing refresh token") - ErrAuthProviderUnavailable = errors.New("auth provider unavailable") -) - -var userAgentString string - -// SSOProvider holds the data associated with the SSOProviders -// necessary to implement a SSOProvider interface. -type SSOProvider struct { - *ProviderData - - StatsdClient *statsd.Client -} - -func init() { - version := os.Getenv("RIG_IMAGE_VERSION") - if version == "" { - version = "HEAD" - } else { - version = strings.Trim(version, `"`) - } - userAgentString = fmt.Sprintf("sso_proxy/%s", version) -} - -// NewSSOProvider instantiates a new SSOProvider with provider data and -// a statsd client. -func NewSSOProvider(p *ProviderData, sc *statsd.Client) *SSOProvider { - p.ProviderName = "SSO" - base := p.ProviderURL - p.SignInURL = base.ResolveReference(&url.URL{Path: "/sign_in"}) - p.SignOutURL = base.ResolveReference(&url.URL{Path: "/sign_out"}) - p.RedeemURL = base.ResolveReference(&url.URL{Path: "/redeem"}) - p.RefreshURL = base.ResolveReference(&url.URL{Path: "/refresh"}) - p.ValidateURL = base.ResolveReference(&url.URL{Path: "/validate"}) - p.ProfileURL = base.ResolveReference(&url.URL{Path: "/profile"}) - return &SSOProvider{ - ProviderData: p, - StatsdClient: sc, - } -} - -func newRequest(method, url string, body io.Reader) (*http.Request, error) { - req, err := http.NewRequest(method, url, body) - if err != nil { - return nil, err - } - req.Header.Set("User-Agent", userAgentString) - req.Header.Set("Accept", "application/json") - return req, nil -} - -func isProviderUnavailable(statusCode int) bool { - return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable -} - -func extendDeadline(ttl time.Duration) time.Time { - return time.Now().Add(ttl).Truncate(time.Second) -} - -func (p *SSOProvider) withinGracePeriod(s *SessionState) bool { - if s.GracePeriodStart.IsZero() { - s.GracePeriodStart = time.Now() - } - return s.GracePeriodStart.Add(p.GracePeriodTTL).After(time.Now()) -} - -// Redeem takes a redirectURL and code and redeems the SessionState -func (p *SSOProvider) Redeem(redirectURL, code string) (*SessionState, error) { - if code == "" { - return nil, errors.New("missing code") - } - - // TODO: remove "grant_type" and "redirect_uri", unused by authenticator - params := url.Values{} - params.Add("redirect_uri", redirectURL) - params.Add("client_id", p.ClientID) - params.Add("client_secret", p.ClientSecret) - params.Add("code", code) - params.Add("grant_type", "authorization_code") - - req, err := newRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := httpClient.Do(req) - if err != nil { - return nil, err - } - - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return nil, err - } - - if resp.StatusCode != 200 { - if isProviderUnavailable(resp.StatusCode) { - return nil, ErrAuthProviderUnavailable - } - return nil, fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - } - - var jsonResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int64 `json:"expires_in"` - Email string `json:"email"` - } - err = json.Unmarshal(body, &jsonResponse) - if err != nil { - return nil, err - } - - user := strings.Split(jsonResponse.Email, "@")[0] - return &SessionState{ - AccessToken: jsonResponse.AccessToken, - RefreshToken: jsonResponse.RefreshToken, - - RefreshDeadline: extendDeadline(time.Duration(jsonResponse.ExpiresIn) * time.Second), - LifetimeDeadline: extendDeadline(p.SessionLifetimeTTL), - ValidDeadline: extendDeadline(p.SessionValidTTL), - - Email: jsonResponse.Email, - User: user, - }, nil -} - -// ValidateGroup does a GET request to the profile url and returns true if the user belongs to -// an authorized group. -func (p *SSOProvider) ValidateGroup(email string, allowedGroups []string) ([]string, bool, error) { - logger := log.NewLogEntry() - - logger.WithUser(email).WithAllowedGroups(allowedGroups).Info("validating groups") - inGroups := []string{} - if len(allowedGroups) == 0 { - return inGroups, true, nil - } - - userGroups, err := p.UserGroups(email, allowedGroups) - if err != nil { - return nil, false, err - } - - allowed := false - for _, userGroup := range userGroups { - for _, allowedGroup := range allowedGroups { - if userGroup == allowedGroup { - inGroups = append(inGroups, userGroup) - allowed = true - } - } - } - - return inGroups, allowed, nil -} - -// UserGroups takes an email and returns the UserGroups for that email -func (p *SSOProvider) UserGroups(email string, groups []string) ([]string, error) { - params := url.Values{} - params.Add("email", email) - params.Add("client_id", p.ClientID) - params.Add("groups", strings.Join(groups, ",")) - - req, err := newRequest("GET", fmt.Sprintf("%s?%s", p.ProfileURL.String(), params.Encode()), nil) - if err != nil { - return nil, err - } - req.Header.Set("X-Client-Secret", p.ClientSecret) - resp, err := httpClient.Do(req) - if err != nil { - return nil, err - } - - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return nil, err - } - - if resp.StatusCode != 200 { - if isProviderUnavailable(resp.StatusCode) { - return nil, ErrAuthProviderUnavailable - } - return nil, fmt.Errorf("got %d from %q %s", resp.StatusCode, p.ProfileURL.String(), body) - } - - var jsonResponse struct { - Email string `json:"email"` - Groups []string `json:"groups"` - } - if err := json.Unmarshal(body, &jsonResponse); err != nil { - return nil, err - } - - return jsonResponse.Groups, nil -} - -// RefreshSession takes a SessionState and allowedGroups and refreshes the session access token, -// returns `true` on success, and `false` on error -func (p *SSOProvider) RefreshSession(s *SessionState, allowedGroups []string) (bool, error) { - logger := log.NewLogEntry() - - if s.RefreshToken == "" { - return false, ErrMissingRefreshToken - } - - newToken, duration, err := p.redeemRefreshToken(s.RefreshToken) - if err != nil { - // When we detect that the auth provider is not explicitly denying - // authentication, and is merely unavailable, we refresh and continue - // as normal during the "grace period" - if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) { - tags := []string{"action:refresh_session", "error:redeem_token_failed"} - p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) - s.RefreshDeadline = extendDeadline(p.SessionValidTTL) - return true, nil - } - return false, err - } - - inGroups, validGroup, err := p.ValidateGroup(s.Email, allowedGroups) - if err != nil { - // When we detect that the auth provider is not explicitly denying - // authentication, and is merely unavailable, we refresh and continue - // as normal during the "grace period" - if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) { - tags := []string{"action:refresh_session", "error:user_groups_failed"} - p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) - s.RefreshDeadline = extendDeadline(p.SessionValidTTL) - return true, nil - } - return false, err - } - if !validGroup { - return false, errors.New("Group membership revoked") - } - s.Groups = inGroups - - s.AccessToken = newToken - s.RefreshDeadline = extendDeadline(duration) - s.GracePeriodStart = time.Time{} - logger.WithUser(s.Email).WithRefreshDeadline(s.RefreshDeadline).Info("refreshed session access token") - return true, nil -} - -func (p *SSOProvider) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) { - // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh - params := url.Values{} - params.Add("client_id", p.ClientID) - params.Add("client_secret", p.ClientSecret) - params.Add("refresh_token", refreshToken) - var req *http.Request - req, err = newRequest("POST", p.RefreshURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := httpClient.Do(req) - if err != nil { - return - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != http.StatusCreated { - if isProviderUnavailable(resp.StatusCode) { - err = ErrAuthProviderUnavailable - } else { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RefreshURL.String(), body) - } - return - } - - var data struct { - AccessToken string `json:"access_token"` - ExpiresIn int64 `json:"expires_in"` - } - err = json.Unmarshal(body, &data) - if err != nil { - return - } - token = data.AccessToken - expires = time.Duration(data.ExpiresIn) * time.Second - return -} - -// ValidateSessionState takes a sessionState and allowedGroups and validates the session state -func (p *SSOProvider) ValidateSessionState(s *SessionState, allowedGroups []string) bool { - logger := log.NewLogEntry() - - // we validate the user's access token is valid - params := url.Values{} - params.Add("client_id", p.ClientID) - req, err := newRequest("GET", fmt.Sprintf("%s?%s", p.ValidateURL.String(), params.Encode()), nil) - if err != nil { - logger.WithUser(s.Email).Error(err, "error validating session state") - return false - } - req.Header.Set("X-Client-Secret", p.ClientSecret) - req.Header.Set("X-Access-Token", s.AccessToken) - - resp, err := httpClient.Do(req) - if err != nil { - logger.WithUser(s.Email).Error("error making request to validate access token") - return false - } - - if resp.StatusCode != 200 { - // When we detect that the auth provider is not explicitly denying - // authentication, and is merely unavailable, we validate and continue - // as normal during the "grace period" - if isProviderUnavailable(resp.StatusCode) && p.withinGracePeriod(s) { - tags := []string{"action:validate_session", "error:validation_failed"} - p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) - s.ValidDeadline = extendDeadline(p.SessionValidTTL) - return true - } - logger.WithUser(s.Email).WithHTTPStatus(resp.StatusCode).Info( - "could not validate user access token") - return false - } - - // check the user is in the proper group(s) - inGroups, validGroup, err := p.ValidateGroup(s.Email, allowedGroups) - if err != nil { - // When we detect that the auth provider is not explicitly denying - // authentication, and is merely unavailable, we validate and continue - // as normal during the "grace period" - if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) { - tags := []string{"action:validate_session", "error:user_groups_failed"} - p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) - s.ValidDeadline = extendDeadline(p.SessionValidTTL) - return true - } - logger.WithUser(s.Email).Error(err, "error fetching group memberships") - return false - } - - if !validGroup { - logger.WithUser(s.Email).WithAllowedGroups(allowedGroups).Info( - "user is no longer in valid groups") - return false - } - s.Groups = inGroups - - s.ValidDeadline = extendDeadline(p.SessionValidTTL) - s.GracePeriodStart = time.Time{} - - logger.WithUser(s.Email).WithSessionValid(s.ValidDeadline).Info("validated session") - - return true -} diff --git a/internal/devproxy/providers/sso_test.go b/internal/devproxy/providers/sso_test.go deleted file mode 100644 index 88a01f81..00000000 --- a/internal/devproxy/providers/sso_test.go +++ /dev/null @@ -1,550 +0,0 @@ -package providers - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "os" - "strings" - "testing" - "time" - - "github.com/buzzfeed/sso/internal/pkg/testutil" -) - -func newTestServer(status int, body []byte) (*url.URL, *httptest.Server) { - s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.WriteHeader(status) - rw.Write(body) - })) - u, _ := url.Parse(s.URL) - return u, s -} - -func newCodeTestServer(code int) (*url.URL, *httptest.Server) { - s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.WriteHeader(code) - })) - u, _ := url.Parse(s.URL) - return u, s -} - -func newSSOProvider() *SSOProvider { - return NewSSOProvider( - &ProviderData{ - ProviderURL: &url.URL{ - Scheme: "https", - Host: "auth.example.com", - }, - }, nil) -} - -func TestNewRequest(t *testing.T) { - testCases := []struct { - name string - url string - expectedError bool - }{ - { - name: "error on new request", - url: ":", - expectedError: true, - }, - { - name: "optional headers set", - url: "/", - expectedError: false, - }, - } - os.Setenv("RIG_IMAGE_VERSION", "testVersion") - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - p := newSSOProvider() - if p == nil { - t.Fatalf("expected provider to not be nil but was") - } - req, err := newRequest("GET", tc.url, nil) - if tc.expectedError && err == nil { - t.Errorf("expected error but error was nil") - } - if !tc.expectedError && err != nil { - t.Errorf("unexpected error %s", err.Error()) - } - if err != nil { - return - } - if req.Header.Get("User-Agent") == "testVersion" { - t.Errorf("expected User-Agent header to be set but it was not") - } - - }) - } - -} - -func TestSSOProviderDefaults(t *testing.T) { - p := newSSOProvider() - testutil.NotEqual(t, nil, p) - - data := p.Data() - testutil.Equal(t, "SSO", data.ProviderName) - - base := fmt.Sprintf("%s://%s", data.ProviderURL.Scheme, data.ProviderURL.Host) - testutil.Equal(t, fmt.Sprintf("%s/sign_in", base), data.SignInURL.String()) - testutil.Equal(t, fmt.Sprintf("%s/sign_out", base), data.SignOutURL.String()) - testutil.Equal(t, fmt.Sprintf("%s/redeem", base), data.RedeemURL.String()) - testutil.Equal(t, fmt.Sprintf("%s/refresh", base), data.RefreshURL.String()) - testutil.Equal(t, fmt.Sprintf("%s/validate", base), data.ValidateURL.String()) - testutil.Equal(t, fmt.Sprintf("%s/profile", base), data.ProfileURL.String()) -} - -type redeemResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int64 `json:"expires_in"` - Email string `json:"email"` -} - -type refreshResponse struct { - Code int - AccessToken string `json:"access_token"` - ExpiresIn int64 `json:"expires_in"` -} - -type profileResponse struct { - Email string `json:"email"` - Groups []string `json:"groups"` -} - -func TestSSOProviderGroups(t *testing.T) { - testCases := []struct { - Name string - Email string - Groups []string - ProxyGroupIds []string - ExpectedValid bool - ExpectedInGroups []string - ExpectError error - ProfileStatus int - }{ - { - Name: "valid when no group id set", - Email: "michael.bland@gsa.gov", - Groups: []string{}, - ProxyGroupIds: []string{}, - ExpectedValid: true, - ExpectedInGroups: []string{}, - ExpectError: nil, - }, - { - Name: "valid when the group id exists", - Email: "michael.bland@gsa.gov", - Groups: []string{"user-in-this-group", "random-group"}, - ProxyGroupIds: []string{"user-in-this-group", "user-not-in-this-group"}, - ExpectedValid: true, - ExpectedInGroups: []string{"user-in-this-group"}, - ExpectError: nil, - }, - { - Name: "valid when the multiple group id exists", - Email: "michael.bland@gsa.gov", - Groups: []string{"user-in-this-group", "user-also-in-this-group"}, - ProxyGroupIds: []string{"user-in-this-group", "user-also-in-this-group"}, - ExpectedValid: true, - ExpectedInGroups: []string{"user-in-this-group", "user-also-in-this-group"}, - ExpectError: nil, - }, - { - Name: "invalid when the group id isn't in user groups", - Email: "michael.bland@gsa.gov", - Groups: []string{}, - ProxyGroupIds: []string{"test1"}, - ExpectedValid: false, - ExpectedInGroups: []string{}, - ExpectError: nil, - }, - { - Name: "invalid if can't access groups", - Email: "michael.bland@gsa.gov", - Groups: []string{}, - ProxyGroupIds: []string{"session-group"}, - ProfileStatus: http.StatusTooManyRequests, - ExpectError: ErrAuthProviderUnavailable, - }, - } - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - p := newSSOProvider() - body, err := json.Marshal(profileResponse{ - Email: tc.Email, - Groups: tc.Groups, - }) - testutil.Equal(t, nil, err) - var server *httptest.Server - profileStatus := http.StatusOK - if tc.ProfileStatus != 0 { - profileStatus = tc.ProfileStatus - } - p.ProfileURL, server = newTestServer(profileStatus, body) - defer server.Close() - inGroups, valid, err := p.ValidateGroup(tc.Email, tc.ProxyGroupIds) - testutil.Equal(t, tc.ExpectError, err) - if err == nil { - testutil.Equal(t, tc.ExpectedValid, valid) - testutil.Equal(t, tc.ExpectedInGroups, inGroups) - } - }) - } -} - -func TestSSOProviderGetEmailAddress(t *testing.T) { - testCases := []struct { - Name string - Code string - ExpectedError string - RedeemResponse *redeemResponse - ProfileResponse *profileResponse - }{ - { - Name: "redeem fails without code", - ExpectedError: "missing code", - }, - { - Name: "redeem fails if redemption server not responding", - Code: "code1234", - ExpectedError: "got 400", - }, - { - Name: "redeem successful", - Code: "code1234", - RedeemResponse: &redeemResponse{ - AccessToken: "a1234", - ExpiresIn: 10, - RefreshToken: "refresh12345", - Email: "michael.bland@gsa.gov", - }, - ProfileResponse: &profileResponse{ - Email: "michael.bland@gsa.gov", - Groups: []string{"core@gsa.gov"}, - }, - }, - } - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - p := newSSOProvider() - - var redeemServer *httptest.Server - // set up redemption resource - if tc.RedeemResponse != nil { - body, err := json.Marshal(tc.RedeemResponse) - testutil.Equal(t, nil, err) - p.RedeemURL, redeemServer = newTestServer(http.StatusOK, body) - } else { - p.RedeemURL, redeemServer = newCodeTestServer(400) - } - defer redeemServer.Close() - - var profileServer *httptest.Server - if tc.ProfileResponse != nil { - body, err := json.Marshal(tc.ProfileResponse) - testutil.Equal(t, nil, err) - p.ProfileURL, profileServer = newTestServer(http.StatusOK, body) - } else { - p.RedeemURL, profileServer = newCodeTestServer(400) - } - defer profileServer.Close() - - session, err := p.Redeem("http://redirect/", tc.Code) - if tc.RedeemResponse != nil { - testutil.Equal(t, nil, err) - testutil.NotEqual(t, session, nil) - testutil.Equal(t, tc.RedeemResponse.Email, session.Email) - testutil.Equal(t, tc.RedeemResponse.AccessToken, session.AccessToken) - testutil.Equal(t, tc.RedeemResponse.RefreshToken, session.RefreshToken) - } - if tc.ExpectedError != "" && !strings.Contains(err.Error(), tc.ExpectedError) { - t.Errorf("got unexpected result.\nwant=%v\ngot=%v\n", tc.ExpectedError, err.Error()) - } - }) - } -} - -func TestSSOProviderValidateSessionState(t *testing.T) { - testCases := []struct { - Name string - SessionState *SessionState - ProviderResponse int - Groups []string - ProxyGroupIds []string - ExpectedValid bool - }{ - { - Name: "valid when no group id set", - SessionState: &SessionState{ - AccessToken: "abc", - Email: "michael.bland@gsa.gov", - }, - ProviderResponse: http.StatusOK, - Groups: []string{}, - ProxyGroupIds: []string{}, - ExpectedValid: true, - }, - { - Name: "invalid when response is is not 200", - SessionState: &SessionState{ - AccessToken: "abc", - Email: "michael.bland@gsa.gov", - }, - ProviderResponse: http.StatusForbidden, - Groups: []string{}, - ProxyGroupIds: []string{}, - ExpectedValid: false, - }, - { - Name: "valid when the group id exists", - SessionState: &SessionState{ - AccessToken: "abc", - Email: "michael.bland@gsa.gov", - }, - ProviderResponse: http.StatusOK, - Groups: []string{"test1", "test2"}, - ProxyGroupIds: []string{"test1"}, - ExpectedValid: true, - }, - { - Name: "invalid when the group id isn't in user groups", - SessionState: &SessionState{ - AccessToken: "abc", - Email: "michael.bland@gsa.gov", - }, - ProviderResponse: http.StatusOK, - Groups: []string{}, - ProxyGroupIds: []string{"test1"}, - ExpectedValid: false, - }, - { - Name: "valid when provider unavailable, but grace period active", - SessionState: &SessionState{ - GracePeriodStart: time.Now().Add(time.Duration(-1) * time.Hour), - }, - ProviderResponse: http.StatusTooManyRequests, - ExpectedValid: true, - }, - { - Name: "invalid when provider unavailable and grace period inactive", - SessionState: &SessionState{ - GracePeriodStart: time.Now().Add(time.Duration(-4) * time.Hour), - }, - ProviderResponse: http.StatusServiceUnavailable, - ExpectedValid: false, - }, - } - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - p := newSSOProvider() - p.GracePeriodTTL = time.Duration(3) * time.Hour - - // setup group endpoint - body, err := json.Marshal(profileResponse{ - Email: tc.SessionState.Email, - Groups: tc.Groups, - }) - testutil.Equal(t, nil, err) - var profileServer *httptest.Server - p.ProfileURL, profileServer = newTestServer(http.StatusOK, body) - defer profileServer.Close() - - validateServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - accessToken := r.Header.Get("X-Access-Token") - if accessToken != tc.SessionState.AccessToken { - t.Logf("want: %v", tc.SessionState.AccessToken) - t.Logf(" got: %v", accessToken) - t.Fatalf("unexpected access token value") - } - rw.WriteHeader(tc.ProviderResponse) - })) - p.ValidateURL, _ = url.Parse(validateServer.URL) - defer validateServer.Close() - - valid := p.ValidateSessionState(tc.SessionState, tc.ProxyGroupIds) - if valid != tc.ExpectedValid { - t.Errorf("got unexpected result. want=%v got=%v", tc.ExpectedValid, valid) - } - }) - } -} - -func TestSSOProviderRefreshSession(t *testing.T) { - testCases := []struct { - Name string - SessionState *SessionState - UserGroups []string - ProxyGroups []string - RefreshResponse *refreshResponse - ExpectedRefresh bool - ExpectedError string - }{ - { - Name: "no refresh if no refresh token", - SessionState: &SessionState{ - Email: "user@domain.com", - AccessToken: "token1234", - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - }, - RefreshResponse: &refreshResponse{ - Code: http.StatusBadRequest, - }, - ExpectedRefresh: false, - }, - { - Name: "no refresh if not yet expired", - SessionState: &SessionState{ - Email: "user@domain.com", - AccessToken: "token1234", - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - RefreshToken: "refresh1234", - }, - RefreshResponse: &refreshResponse{ - Code: http.StatusBadRequest, - }, - ExpectedRefresh: false, - }, - { - Name: "no refresh if redeem not responding", - SessionState: &SessionState{ - Email: "user@domain.com", - AccessToken: "token1234", - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - RefreshToken: "refresh1234", - }, - RefreshResponse: &refreshResponse{ - Code: http.StatusBadRequest, - }, - ExpectedRefresh: false, - ExpectedError: "got 400", - }, - { - Name: "no refresh if profile not responding", - SessionState: &SessionState{ - Email: "user@domain.com", - AccessToken: "token1234", - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - RefreshToken: "refresh1234", - }, - RefreshResponse: &refreshResponse{ - Code: http.StatusCreated, - ExpiresIn: 10, - AccessToken: "newToken1234", - }, - ProxyGroups: []string{"test1"}, - ExpectedRefresh: false, - ExpectedError: "got 500", - }, - { - Name: "no refresh if user no longer in group", - SessionState: &SessionState{ - Email: "user@domain.com", - AccessToken: "token1234", - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - RefreshToken: "refresh1234", - }, - UserGroups: []string{"useless"}, - ProxyGroups: []string{"test1"}, - RefreshResponse: &refreshResponse{ - Code: http.StatusCreated, - ExpiresIn: 10, - AccessToken: "newToken1234", - }, - ExpectedRefresh: false, - ExpectedError: "Group membership revoked", - }, - { - Name: "successful refresh if can redeem and user in group", - SessionState: &SessionState{ - Email: "user@domain.com", - AccessToken: "token1234", - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - RefreshToken: "refresh1234", - }, - UserGroups: []string{"test1"}, - ProxyGroups: []string{"test1"}, - RefreshResponse: &refreshResponse{ - Code: http.StatusCreated, - ExpiresIn: 10, - AccessToken: "newToken1234", - }, - ExpectedRefresh: true, - }, - { - Name: "successful refresh if provider unavailable but within grace period", - SessionState: &SessionState{ - GracePeriodStart: time.Now().Add(time.Duration(-1) * time.Hour), - RefreshToken: "refresh1234", - }, - RefreshResponse: &refreshResponse{ - Code: http.StatusTooManyRequests, - }, - ExpectedRefresh: true, - }, - { - Name: "failed refresh if provider unavailable and outside grace period", - SessionState: &SessionState{ - GracePeriodStart: time.Now().Add(time.Duration(-4) * time.Hour), - RefreshToken: "refresh1234", - }, - RefreshResponse: &refreshResponse{ - Code: http.StatusTooManyRequests, - }, - ExpectedRefresh: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - p := newSSOProvider() - p.GracePeriodTTL = time.Duration(3) * time.Hour - - groups := []string{} - if tc.ProxyGroups != nil { - groups = tc.ProxyGroups - } - - // set up redeem resource - var refreshServer *httptest.Server - body, err := json.Marshal(tc.RefreshResponse) - testutil.Equal(t, nil, err) - p.RefreshURL, refreshServer = newTestServer(tc.RefreshResponse.Code, body) - defer refreshServer.Close() - - // set up groups resource - var groupsServer *httptest.Server - if tc.UserGroups != nil { - body, err := json.Marshal(profileResponse{ - Email: tc.SessionState.Email, - Groups: tc.UserGroups, - }) - testutil.Equal(t, nil, err) - p.ProfileURL, groupsServer = newTestServer(http.StatusOK, body) - } else { - p.ProfileURL, groupsServer = newCodeTestServer(500) - } - defer groupsServer.Close() - - // run the endpoint - actualRefresh, err := p.RefreshSession(tc.SessionState, groups) - if tc.ExpectedRefresh != actualRefresh { - t.Fatalf("got unexpected refresh behavior. want=%v got=%v", tc.ExpectedRefresh, actualRefresh) - } - - if tc.ExpectedError != "" && err == nil { - t.Fatalf("expected error: %v got: %v", tc.ExpectedError, err) - } - - if tc.ExpectedError != "" && !strings.Contains(err.Error(), tc.ExpectedError) { - t.Fatalf("got unexpected result.\nwant=%v\ngot=%v\n", tc.ExpectedError, err.Error()) - } - }) - } -} From 6546833a53458278351924e395bd4248f2f9bd03 Mon Sep 17 00:00:00 2001 From: Maliheh Date: Tue, 15 Jan 2019 16:10:22 -0800 Subject: [PATCH 3/5] added request signing --- internal/devproxy/dev_config.go | 40 +++- internal/devproxy/devproxy.go | 117 +++++++++-- internal/devproxy/devproxy_test.go | 43 +++- internal/devproxy/logging_handler.go | 4 +- internal/devproxy/middleware.go | 2 - internal/devproxy/options.go | 2 + internal/devproxy/request_signer.go | 191 ++++++++++++++++++ internal/devproxy/request_signer_test.go | 179 ++++++++++++++++ .../devproxy/testdata/upstream_configs.yml | 4 - 9 files changed, 541 insertions(+), 41 deletions(-) create mode 100644 internal/devproxy/request_signer.go create mode 100644 internal/devproxy/request_signer_test.go diff --git a/internal/devproxy/dev_config.go b/internal/devproxy/dev_config.go index 0afa0866..1e09638b 100644 --- a/internal/devproxy/dev_config.go +++ b/internal/devproxy/dev_config.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/18F/hmacauth" "github.com/imdario/mergo" "gopkg.in/yaml.v2" ) @@ -47,10 +48,13 @@ type UpstreamConfig struct { ExtraRoutes []*RouteConfig `yaml:"extra_routes"` // Generated at Parse Time - Route interface{} // note: :/ - Timeout time.Duration - FlushInterval time.Duration - HeaderOverrides map[string]string + Route interface{} // note: :/ + HMACAuth hmacauth.HmacAuth + Timeout time.Duration + FlushInterval time.Duration + HeaderOverrides map[string]string + TLSSkipVerify bool + SkipRequestSigning bool } // RouteConfig maps to the yaml config fields, @@ -71,9 +75,10 @@ type RouteConfig struct { // * timeout - duration before timing out request. // * flush_interval - interval at which the proxy should flush data to the browser type OptionsConfig struct { - HeaderOverrides map[string]string `yaml:"header_overrides"` - Timeout time.Duration `yaml:"timeout"` - FlushInterval time.Duration `yaml:"flush_interval"` + HeaderOverrides map[string]string `yaml:"header_overrides"` + Timeout time.Duration `yaml:"timeout"` + FlushInterval time.Duration `yaml:"flush_interval"` + SkipRequestSigning bool `yaml:"skip_request_signing"` } // ErrParsingConfig is an error specific to config parsing. @@ -102,7 +107,7 @@ func loadServiceConfigs(raw []byte, cluster, scheme string, configVars map[strin // we don't set this to the len(serviceConfig) since not all service configs // are configured for all clusters, leaving nil tail pointers in the slice. configs := make([]*UpstreamConfig, 0) - // resovle overrides + // resolve overrides for _, service := range serviceConfigs { proxy, err := resolveUpstreamConfig(service, cluster) if err != nil { @@ -174,6 +179,23 @@ func loadServiceConfigs(raw []byte, cluster, scheme string, configVars map[strin } } + for _, proxy := range configs { + key := fmt.Sprintf("%s_signing_key", proxy.Service) + signingKey, ok := configVars[key] + if !ok { + continue + } + auth, err := generateHmacAuth(signingKey) + if err != nil { + return nil, &ErrParsingConfig{ + Message: fmt.Sprintf("unable to generate hmac auth for %s", proxy.Service), + Err: err, + } + } + proxy.HMACAuth = auth + + } + return configs, nil } @@ -325,7 +347,7 @@ func parseOptionsConfig(proxy *UpstreamConfig) error { proxy.Timeout = proxy.RouteConfig.Options.Timeout proxy.FlushInterval = proxy.RouteConfig.Options.FlushInterval proxy.HeaderOverrides = proxy.RouteConfig.Options.HeaderOverrides - + proxy.SkipRequestSigning = proxy.RouteConfig.Options.SkipRequestSigning proxy.RouteConfig.Options = nil return nil diff --git a/internal/devproxy/devproxy.go b/internal/devproxy/devproxy.go index 74d8d1eb..03b8876c 100644 --- a/internal/devproxy/devproxy.go +++ b/internal/devproxy/devproxy.go @@ -1,10 +1,12 @@ package devproxy import ( + "crypto/tls" "encoding/json" "fmt" "html/template" "io" + "net" "net/http" "net/http/httputil" "net/url" @@ -12,11 +14,12 @@ import ( "strings" "time" + "github.com/18F/hmacauth" log "github.com/buzzfeed/sso/internal/pkg/logging" ) -// SignatureHeader is the header name where the signed request header is stored. -const SignatureHeader = "Gap-Signature" +// HMACSignatureHeader is the header name where the signed request header is stored. +const HMACSignatureHeader = "Gap-Signature" // SignatureHeaders are the headers that are valid in the request. var SignatureHeaders = []string{ @@ -35,6 +38,8 @@ type DevProxy struct { templates *template.Template mux map[string]*route regexRoutes []*route + requestSigner *RequestSigner + publicCertsJSON []byte } type route struct { @@ -54,13 +59,16 @@ type StateParameter struct { // UpstreamProxy stores information necessary for proxying the request back to the upstream. type UpstreamProxy struct { - name string - handler http.Handler + name string + handler http.Handler + requestSigner *RequestSigner } // upstreamTransport is used to ensure that upstreams cannot override the // security headers applied by dev_proxy -type upstreamTransport struct{} +type upstreamTransport struct { + transport *http.Transport +} // RoundTrip round trips the request and deletes security headers before returning the response. func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error) { @@ -76,8 +84,29 @@ func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error) return resp, err } +func newUpstreamTransport(insecureSkipVerify bool) *upstreamTransport { + return &upstreamTransport{ + transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: insecureSkipVerify}, + ExpectContinueTimeout: 1 * time.Second, + }, + } +} + // ServeHTTP calls the upstream's ServeHTTP function. func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if u.requestSigner != nil { + u.requestSigner.Sign(r) + } start := time.Now() u.handler.ServeHTTP(w, r) @@ -100,7 +129,7 @@ func singleJoiningSlash(a, b string) string { // NewReverseProxy creates a reverse proxy to a specified url. // It adds an X-Forwarded-Host header that is the request's host. -func NewReverseProxy(to *url.URL) *httputil.ReverseProxy { +func NewReverseProxy(to *url.URL, config *UpstreamConfig) *httputil.ReverseProxy { targetQuery := to.RawQuery director := func(req *http.Request) { req.URL.Scheme = to.Scheme @@ -118,7 +147,7 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy { } } proxy := &httputil.ReverseProxy{Director: director} - proxy.Transport = &upstreamTransport{} + proxy.Transport = newUpstreamTransport(config.TLSSkipVerify) dir := proxy.Director proxy.Director = func(req *http.Request) { @@ -136,9 +165,9 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy { // NewRewriteReverseProxy creates a reverse proxy that is capable of creating upstream // urls on the fly based on a from regex and a templated to field. // It adds an X-Forwarded-Host header to the the upstream's request. -func NewRewriteReverseProxy(route *RewriteRoute) *httputil.ReverseProxy { +func NewRewriteReverseProxy(route *RewriteRoute, config *UpstreamConfig) *httputil.ReverseProxy { proxy := &httputil.ReverseProxy{} - proxy.Transport = &upstreamTransport{} + proxy.Transport = newUpstreamTransport(config.TLSSkipVerify) proxy.Director = func(req *http.Request) { // we do this to rewrite requests rewritten := route.FromRegex.ReplaceAllString(req.Host, route.ToTemplate.Opaque) @@ -167,11 +196,16 @@ func NewRewriteReverseProxy(route *RewriteRoute) *httputil.ReverseProxy { } // NewReverseProxyHandler creates a new http.Handler given a httputil.ReverseProxy -func NewReverseProxyHandler(reverseProxy *httputil.ReverseProxy, opts *Options, config *UpstreamConfig) (http.Handler, []string) { +func NewReverseProxyHandler(reverseProxy *httputil.ReverseProxy, opts *Options, config *UpstreamConfig, signer *RequestSigner) (http.Handler, []string) { upstreamProxy := &UpstreamProxy{ - name: config.Service, - handler: reverseProxy, + name: config.Service, + handler: reverseProxy, + requestSigner: signer, + } + if config.SkipRequestSigning { + upstreamProxy.requestSigner = nil } + if config.FlushInterval != 0 { return NewStreamingHandler(upstreamProxy, opts, config), []string{"handler:streaming"} } @@ -197,19 +231,56 @@ func NewStreamingHandler(handler http.Handler, opts *Options, config *UpstreamCo return upstreamProxy } +func generateHmacAuth(signatureKey string) (hmacauth.HmacAuth, error) { + components := strings.Split(signatureKey, ":") + if len(components) != 2 { + return nil, fmt.Errorf("invalid signature hash:key spec") + } + + algorithm, secret := components[0], components[1] + hash, err := hmacauth.DigestNameToCryptoHash(algorithm) + if err != nil { + return nil, fmt.Errorf("unsupported signature hash algorithm: %s", algorithm) + } + auth := hmacauth.NewHmacAuth(hash, []byte(secret), HMACSignatureHeader, SignatureHeaders) + return auth, nil +} + // NewDevProxy creates a new DevProxy struct. func NewDevProxy(opts *Options, optFuncs ...func(*DevProxy) error) (*DevProxy, error) { logger := log.NewLogEntry() logger.Info("NewDevProxy...") + // Configure the RequestSigner (used to sign requests with `Sso-Signature` header). + // Also build the `certs` static JSON-string which will be served from a public endpoint. + // The key published at this endpoint allows upstreams to decrypt the `Sso-Signature` + // header, and validate the integrity and authenticity of a request. + certs := make(map[string]string) + var requestSigner *RequestSigner + if len(opts.RequestSigningKey) > 0 { + requestSigner, err := NewRequestSigner(opts.RequestSigningKey) + if err != nil { + return nil, fmt.Errorf("could not build RequestSigner: %s", err) + } + id, key := requestSigner.PublicKey() + certs[id] = key + } else { + logger.Warn("Running DevProxy without signing key. Requests will not be signed.") + } + certsAsStr, err := json.MarshalIndent(certs, "", " ") + if err != nil { + return nil, fmt.Errorf("could not marshal public certs as JSON: %s", err) + } + p := &DevProxy{ // these fields make up the routing mechanism mux: make(map[string]*route), regexRoutes: make([]*route, 0), - redirectURL: &url.URL{Path: "/oauth2/callback"}, - // skipAuthPreflight: opts.SkipAuthPreflight, - templates: getTemplates(), + redirectURL: &url.URL{Path: "/oauth2/callback"}, + templates: getTemplates(), + requestSigner: requestSigner, + publicCertsJSON: certsAsStr, } for _, optFunc := range optFuncs { @@ -222,12 +293,12 @@ func NewDevProxy(opts *Options, optFuncs ...func(*DevProxy) error) (*DevProxy, e for _, upstreamConfig := range opts.upstreamConfigs { switch route := upstreamConfig.Route.(type) { case *SimpleRoute: - reverseProxy := NewReverseProxy(route.ToURL) - handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig) + reverseProxy := NewReverseProxy(route.ToURL, upstreamConfig) + handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig, requestSigner) p.Handle(route.FromURL.Host, handler, tags, upstreamConfig) case *RewriteRoute: - reverseProxy := NewRewriteReverseProxy(route) - handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig) + reverseProxy := NewRewriteReverseProxy(route, upstreamConfig) + handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig, requestSigner) p.HandleRegex(route.FromRegex, handler, tags, upstreamConfig) default: return nil, fmt.Errorf("unkown route type") @@ -242,7 +313,7 @@ func (p *DevProxy) Handler() http.Handler { mux := http.NewServeMux() mux.HandleFunc("/favicon.ico", p.Favicon) mux.HandleFunc("/robots.txt", p.RobotsTxt) - // mux.HandleFunc("/oauth2/callback", p.DevCallback) + mux.HandleFunc("/oauth2/v1/certs", p.Certs) mux.HandleFunc("/", p.Proxy) // Global middleware, which will be applied to each request in reverse @@ -370,3 +441,9 @@ func (p *DevProxy) router(req *http.Request) (*route, bool) { return nil, false } + +// Certs publishes the public key necessary for upstream services to validate the digital signature +// used to sign each request. +func (p *DevProxy) Certs(rw http.ResponseWriter, _ *http.Request) { + rw.Write(p.publicCertsJSON) +} diff --git a/internal/devproxy/devproxy_test.go b/internal/devproxy/devproxy_test.go index c734e655..268d8019 100644 --- a/internal/devproxy/devproxy_test.go +++ b/internal/devproxy/devproxy_test.go @@ -1,6 +1,8 @@ package devproxy import ( + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" @@ -43,7 +45,7 @@ func TestNewReverseProxy(t *testing.T) { backendHost := net.JoinHostPort(backendHostname, backendPort) proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") - proxyHandler := NewReverseProxy(proxyURL) + proxyHandler := NewReverseProxy(proxyURL, &UpstreamConfig{TLSSkipVerify: false}) frontend := httptest.NewServer(proxyHandler) defer frontend.Close() @@ -75,7 +77,7 @@ func TestNewRewriteReverseProxy(t *testing.T) { }, } - rewriteProxy := NewRewriteReverseProxy(route) + rewriteProxy := NewRewriteReverseProxy(route, &UpstreamConfig{TLSSkipVerify: false}) frontend := httptest.NewServer(rewriteProxy) defer frontend.Close() @@ -130,7 +132,7 @@ func TestNewReverseProxyHostname(t *testing.T) { t.Fatalf("expected to parse to url: %s", err) } - reverseProxy := NewReverseProxy(toURL) + reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{TLSSkipVerify: false}) from := httptest.NewServer(reverseProxy) defer from.Close() @@ -251,6 +253,41 @@ func TestRobotsTxt(t *testing.T) { testutil.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String()) } +func TestCerts(t *testing.T) { + opts := NewOptions() + opts.upstreamConfigs = generateTestUpstreamConfigs("foo-internal.sso.dev") + + requestSigningKey, err := ioutil.ReadFile("testdata/private_key.pem") + testutil.Assert(t, err == nil, "could not read private key from testdata: %s", err) + opts.RequestSigningKey = string(requestSigningKey) + opts.Validate() + + expectedPublicKey, err := ioutil.ReadFile("testdata/public_key.pub") + testutil.Assert(t, err == nil, "could not read public key from testdata: %s", err) + + var keyHash []byte + hasher := sha256.New() + _, _ = hasher.Write(expectedPublicKey) + keyHash = hasher.Sum(keyHash) + + proxy, err := NewDevProxy(opts) + if err != nil { + t.Errorf("unexpected error %s", err) + return + } + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "https://foo.sso.dev/oauth2/v1/certs", nil) + proxy.Handler().ServeHTTP(rw, req) + testutil.Equal(t, 200, rw.Code) + + var certs map[string]string + if err := json.Unmarshal([]byte(rw.Body.String()), &certs); err != nil { + t.Errorf("failed to unmarshal certs from json response: %s", err) + return + } + testutil.Equal(t, string(expectedPublicKey), certs[hex.EncodeToString(keyHash)]) +} + func TestFavicon(t *testing.T) { opts := NewOptions() opts.upstreamConfigs = generateTestUpstreamConfigs("httpheader.net/") diff --git a/internal/devproxy/logging_handler.go b/internal/devproxy/logging_handler.go index 44c26004..f11825d7 100644 --- a/internal/devproxy/logging_handler.go +++ b/internal/devproxy/logging_handler.go @@ -19,7 +19,6 @@ type responseLogger struct { w http.ResponseWriter status int size int - // authInfo string } func (l *responseLogger) Header() http.Header { @@ -31,14 +30,13 @@ func (l *responseLogger) Write(b []byte) (int, error) { // The status will be StatusOK if WriteHeader has not been called yet l.status = http.StatusOK } - // l.extractUser() + size, err := l.w.Write(b) l.size += size return size, err } func (l *responseLogger) WriteHeader(s int) { - // l.extractUser() l.w.WriteHeader(s) l.status = s } diff --git a/internal/devproxy/middleware.go b/internal/devproxy/middleware.go index 99192595..622e7605 100644 --- a/internal/devproxy/middleware.go +++ b/internal/devproxy/middleware.go @@ -6,8 +6,6 @@ import ( ) // With inspiration from https://github.com/unrolled/secure -// -// TODO: Add Content-Security-Report header? var securityHeaders = map[string]string{ "X-Content-Type-Options": "nosniff", "X-Frame-Options": "SAMEORIGIN", diff --git a/internal/devproxy/options.go b/internal/devproxy/options.go index 5465f655..eaf279fa 100644 --- a/internal/devproxy/options.go +++ b/internal/devproxy/options.go @@ -33,6 +33,8 @@ type Options struct { RequestLogging bool `envconfig:"REQUEST_LOGGING" default:"true"` + RequestSigningKey string `envconfig:"REQUEST_SIGNATURE_KEY"` + // This is an override for supplying template vars at test time testTemplateVars map[string]string diff --git a/internal/devproxy/request_signer.go b/internal/devproxy/request_signer.go new file mode 100644 index 00000000..c3965525 --- /dev/null +++ b/internal/devproxy/request_signer.go @@ -0,0 +1,191 @@ +package devproxy + +import ( + "bytes" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "fmt" + "hash" + "io/ioutil" + "net/http" + "strings" +) + +// Only headers enumerated in this list are used to compute the signature of a request. +var signedHeaders = []string{ + "Content-Length", + "Content-Md5", + "Content-Type", + "Date", + "Authorization", + "X-Forwarded-User", + "X-Forwarded-Email", + "X-Forwarded-Groups", + "Cookie", +} + +// Name of the header used to transmit the signature computed for the request. +var signatureHeader = "Sso-Signature" +var signingKeyHeader = "kid" + +// RequestSigner exposes an interface for digitally signing requests using an RSA private key. +// See comments for the Sign() method below, for more on how this signature is constructed. +type RequestSigner struct { + hasher hash.Hash + signingKey crypto.Signer + publicKeyStr string + publicKeyID string +} + +// NewRequestSigner constructs a RequestSigner object from a PEM+PKCS8 encoded RSA public key. +func NewRequestSigner(signingKeyPemStr string) (*RequestSigner, error) { + var privateKey crypto.Signer + var publicKeyPEM []byte + + // Strip PEM encoding from private key. + block, _ := pem.Decode([]byte(signingKeyPemStr)) + if block == nil { + return nil, fmt.Errorf("could not read PEM block from signing key") + } + + // Extract private key as a crypto.Signer object. + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("could not read key from signing key bytes: %s", err) + } + privateKey = key.(crypto.Signer) + + // Derive public key. + rsaPublicKey, ok := privateKey.Public().(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("only RSA public keys are currently supported") + } + publicKeyPEM = pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: x509.MarshalPKCS1PublicKey(rsaPublicKey), + }) + + var keyHash []byte + hasher := sha256.New() + _, _ = hasher.Write(publicKeyPEM) + keyHash = hasher.Sum(keyHash) + + return &RequestSigner{ + hasher: sha256.New(), + signingKey: privateKey, + publicKeyStr: string(publicKeyPEM), + publicKeyID: hex.EncodeToString(keyHash), + }, nil +} + +// mapRequestToHashInput returns a string representation of a Request, formatted as a +// newline-separated sequence of entries from the request. Any two Requests sharing the same +// representation are considered "equivalent" for purposes of verifying the integrity of a request. +// +// Representations are formatted as follows: +// +// ... +// +// +// +// where: +// is the ','-joined concatenation of all header values of `signedHeaders[k]`; all +// other headers in the request are ignored, +// is the string "(?)(#FRAGMENT)", where "?" and "#" are +// ommitted if the associated components are absent from the request URL, +// is the body of the Request (may be `nil`; e.g. for GET requests). +// +// Receiving endpoints authenticating the integrity of a request should reconstruct this document +// exactly, when verifying the contents of a received request. +func mapRequestToHashInput(req *http.Request) (string, error) { + entries := []string{} + + // Add signed headers. + for _, hdr := range signedHeaders { + if hdrValues := req.Header[hdr]; len(hdrValues) > 0 { + entries = append(entries, strings.Join(hdrValues, ",")) + } + } + + // Add canonical URL representation. Ignore URL {scheme, host, port, etc}. + entries = append(entries, func() string { + url := req.URL.Path + if len(req.URL.RawQuery) > 0 { + url += ("?" + req.URL.RawQuery) + } + if len(req.URL.Fragment) > 0 { + url += ("#" + req.URL.Fragment) + } + return url + }()) + + // Add request body, if present (may be absent for GET requests, etc). + if req.Body != nil { + body, _ := ioutil.ReadAll(req.Body) + req.Body = ioutil.NopCloser(bytes.NewBuffer(body)) + entries = append(entries, string(body)) + } + + // Return the join of all entries, with each separated by a newline. + return strings.Join(entries, "\n"), nil +} + +// Sign appends a header to the request, with a public-key encrypted signature derive from +// a subset of the request headers, together with the request URL and body. +// +// Signature is computed as: +// repr := Representation(request) <- Computed by mapRequestToHashInput() +// hash := SHA256(repr) +// sig := SIGN(hash, SigningKey) +// final := WEB_SAFE_BASE64(sig) +// The header `Sso-Signature` is given the value of `final`. +// +// Receiving endpoints authenticating the integrity of a request should: +// 1. Strip the WEB_SAFE_BASE64 encoding from the value of `signatureHeader`, +// 2. Decrypt the resulting value using the public key published by sso_proxy, thus obtaining the +// hash of the request representation, +// 3. Compute the request representation from the received request, using the same format as the +// mapRequestToHashInput() function above, +// 4. Apply SHA256 hash to the recomputed representation, and verify that it matches the decrypted +// hash value received through the `Sso-Signature` of the request. +// +// Any requests failing this check should be considered tampered with, and rejected. +func (signer RequestSigner) Sign(req *http.Request) error { + // Generate the request representation that will serve as hash input. + repr, err := mapRequestToHashInput(req) + if err != nil { + return fmt.Errorf("could not generate representation for request: %s", err) + } + + // Generate hash of the document buffer. + var documentHash []byte + signer.hasher.Reset() + _, _ = signer.hasher.Write([]byte(repr)) + documentHash = signer.hasher.Sum(documentHash) + + // Sign the documentHash with the signing key. + signatureBytes, err := signer.signingKey.Sign(rand.Reader, documentHash, crypto.SHA256) + if err != nil { + return fmt.Errorf("failed signing document hash with signing key: %s", err) + } + signature := base64.URLEncoding.EncodeToString(signatureBytes) + + // Set the signature and signing-key request headers. Return nil to indicate no error. + req.Header.Set(signatureHeader, signature) + req.Header.Set(signingKeyHeader, signer.publicKeyID) + return nil +} + +// PublicKey returns a pair (KeyID, Key), where: +// - KeyID is a unique identifier (currently the SHA256 hash of Key), +// - Key is the (PEM+PKCS1)-encoding of a public key, usable for validating signed requests. +func (signer RequestSigner) PublicKey() (string, string) { + return signer.publicKeyID, signer.publicKeyStr +} diff --git a/internal/devproxy/request_signer_test.go b/internal/devproxy/request_signer_test.go new file mode 100644 index 00000000..99dd1688 --- /dev/null +++ b/internal/devproxy/request_signer_test.go @@ -0,0 +1,179 @@ +package devproxy + +import ( + "crypto" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "fmt" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/buzzfeed/sso/internal/pkg/testutil" +) + +// Convenience variables and utilities. +var urlExample = "https://foo.sso.example.com/path" + +func addHeaders(req *http.Request, examples []string, extras map[string][]string) { + var signedHeaderExamples = map[string][]string{ + "Content-Length": {"1234"}, + "Content-Md5": {"F00D"}, + "Content-Type": {"application/json"}, + "Date": {"2018-11-08"}, + "Authorization": {"Bearer ab12cd34"}, + "X-Forwarded-User": {"octoboi"}, + "X-Forwarded-Email": {"octoboi@example.com"}, + "X-Forwarded-Groups": {"molluscs", "security_applications"}, + } + + for _, signedHdr := range examples { + for _, value := range signedHeaderExamples[signedHdr] { + req.Header.Add(signedHdr, value) + } + } + for extraHdr, values := range extras { + for _, value := range values { + req.Header.Add(extraHdr, value) + } + } +} + +func TestRepr_UrlRepresentation(t *testing.T) { + testURL := func(url string, expect string) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Errorf("could not build request: %s", err) + } + + repr, err := mapRequestToHashInput(req) + if err != nil { + t.Errorf("could not map request to hash input: %s", err) + } + testutil.Equal(t, expect, repr) + } + + testURL("http://foo.sso.example.com/path/to/resource", "/path/to/resource") + testURL("http://foo.sso.example.com/path?", "/path") + testURL("http://foo.sso.example.com/path/to?query#fragment", "/path/to?query#fragment") + testURL("https://foo.sso.example.com:4321/path#fragment", "/path#fragment") + testURL("http://foo.sso.example.com/path?query¶m=value#", "/path?query¶m=value") +} + +func TestRepr_HeaderRepresentation(t *testing.T) { + testHeaders := func(include []string, extra map[string][]string, expect string) { + req, err := http.NewRequest("GET", urlExample, nil) + if err != nil { + t.Errorf("could not build request: %s", err) + } + addHeaders(req, include, extra) + repr, err := mapRequestToHashInput(req) + if err != nil { + t.Errorf("could not map request to hash input: %s", err) + } + testutil.Equal(t, expect, repr) + } + + // Partial set of signed headers. + testHeaders([]string{"Authorization", "X-Forwarded-Groups"}, nil, + "Bearer ab12cd34\n"+ + "molluscs,security_applications\n"+ + "/path") + + // Full set of signed headers. + testHeaders(signedHeaders, nil, + "1234\n"+ + "F00D\n"+ + "application/json\n"+ + "2018-11-08\n"+ + "Bearer ab12cd34\n"+ + "octoboi\n"+ + "octoboi@example.com\n"+ + "molluscs,security_applications\n"+ + "/path") + + // Partial set of signed headers, plus another header (should not appear in representation). + testHeaders([]string{"Authorization", "X-Forwarded-Email"}, + map[string][]string{"X-Octopus-Stuff": {"54321"}}, + "Bearer ab12cd34\n"+ + "octoboi@example.com\n"+ + "/path") + + // Only unsigned headers. + testHeaders(nil, map[string][]string{"X-Octopus-Stuff": {"83721"}}, "/path") +} + +func TestRepr_PostWithBody(t *testing.T) { + req, err := http.NewRequest("POST", urlExample, strings.NewReader("something\nor other")) + if err != nil { + t.Errorf("could not build request: %s", err) + } + addHeaders(req, []string{"X-Forwarded-Email", "X-Forwarded-Groups"}, + map[string][]string{"X-Octopus-Stuff": {"54321"}}) + + repr, err := mapRequestToHashInput(req) + if err != nil { + t.Errorf("could not map request to hash input: %s", err) + } + testutil.Equal(t, + "octoboi@example.com\n"+ + "molluscs,security_applications\n"+ + "/path\n"+ + "something\n"+ + "or other", + repr) +} + +func TestSignatureRoundTripDecoding(t *testing.T) { + // Keys used for signing/validating request. + privateKey, err := ioutil.ReadFile("testdata/private_key.pem") + testutil.Assert(t, err == nil, "error reading private key from testdata") + + publicKey, err := ioutil.ReadFile("testdata/public_key.pub") + testutil.Assert(t, err == nil, "error reading public key from testdata") + + // Build the RequestSigner object used to generate the request signature header. + requestSigner, err := NewRequestSigner(string(privateKey)) + testutil.Assert(t, err == nil, "could not initialize request signer: %s", err) + + // And build the rsa.PublicKey object that will help verify the signature. + verifierKey, err := func() (*rsa.PublicKey, error) { + if block, _ := pem.Decode(publicKey); block == nil { + return nil, fmt.Errorf("could not read PEM block from public key") + } else if key, err := x509.ParsePKCS1PublicKey(block.Bytes); err != nil { + return nil, fmt.Errorf("could not read key from public key bytes: %s", err) + } else { + return key, nil + } + }() + testutil.Assert(t, err == nil, "could not construct public key: %s", err) + + // Build the Request to be signed. + req, err := http.NewRequest("POST", urlExample, strings.NewReader("something\nor other")) + testutil.Assert(t, err == nil, "could not construct request: %s", err) + addHeaders(req, []string{"X-Forwarded-Email", "X-Forwarded-Groups"}, + map[string][]string{"X-Octopus-Stuff": {"54321"}}) + + // Sign the request, and extract its signature from the header. + err = requestSigner.Sign(req) + testutil.Assert(t, err == nil, "could not sign request: %s", err) + sig, _ := base64.URLEncoding.DecodeString(req.Header.Get("Sso-Signature")) + + // Hardcoded expected hash, computed from the request. + expectedHash, _ := hex.DecodeString( + "04158c00fbecccd8b5dca58634a0a7f28bf5ad908f19cb1b404bdd37bb4485a9") + err = rsa.VerifyPKCS1v15(verifierKey, crypto.SHA256, expectedHash, sig) + testutil.Assert(t, err == nil, "could not verify request signature: %s", err) + + // Verify that the signing-key header is the hash of the public-key. + var pubKeyHash []byte + hasher := sha256.New() + _, _ = hasher.Write(publicKey) + pubKeyHash = hasher.Sum(pubKeyHash) + testutil.Equal(t, hex.EncodeToString(pubKeyHash), req.Header.Get("kid")) +} diff --git a/internal/devproxy/testdata/upstream_configs.yml b/internal/devproxy/testdata/upstream_configs.yml index 41eade61..21bfcb9b 100644 --- a/internal/devproxy/testdata/upstream_configs.yml +++ b/internal/devproxy/testdata/upstream_configs.yml @@ -3,8 +3,4 @@ from: httpbin.sso.localtest.me to: http://httpheader.net -# - service: hello-world -# default: -# from: hello-world.sso.localtest.me -# to: http://httpheader.net From c3b77624d29e479ac55729eca387f5134c67829f Mon Sep 17 00:00:00 2001 From: Maliheh Date: Wed, 20 Feb 2019 11:09:48 -0800 Subject: [PATCH 4/5] fixed the bug in request_signer --- internal/devproxy/devproxy.go | 45 ++++++++++++------- internal/devproxy/devproxy_test.go | 10 ++--- internal/devproxy/metrics.go | 9 ++-- internal/devproxy/request_signer.go | 18 ++++++-- internal/devproxy/request_signer_test.go | 2 +- internal/devproxy/testdata/.env | 10 +++++ internal/devproxy/testdata/private_key.pem | 28 ++++++++++++ internal/devproxy/testdata/public_key.pem | 8 ++++ .../devproxy/testdata/upstream_configs.yml | 8 ++-- 9 files changed, 103 insertions(+), 35 deletions(-) create mode 100644 internal/devproxy/testdata/.env create mode 100644 internal/devproxy/testdata/private_key.pem create mode 100644 internal/devproxy/testdata/public_key.pem diff --git a/internal/devproxy/devproxy.go b/internal/devproxy/devproxy.go index 03b8876c..da9f1b9f 100644 --- a/internal/devproxy/devproxy.go +++ b/internal/devproxy/devproxy.go @@ -33,7 +33,7 @@ const statusInvalidHost = 421 // DevProxy stores all the information associated with proxying the request. type DevProxy struct { - redirectURL *url.URL // the url to receive requests at + // redirectURL *url.URL // the url to receive requests at skipAuthPreflight bool templates *template.Template mux map[string]*route @@ -104,7 +104,9 @@ func newUpstreamTransport(insecureSkipVerify bool) *upstreamTransport { // ServeHTTP calls the upstream's ServeHTTP function. func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if u.requestSigner != nil { + u.requestSigner.Sign(r) } @@ -112,7 +114,7 @@ func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { u.handler.ServeHTTP(w, r) duration := time.Now().Sub(start) - fmt.Sprintf("service_name:%s, duation:%s", u.name, duration) + fmt.Printf("service_name:%s, duration:%s", u.name, duration) } func singleJoiningSlash(a, b string) string { @@ -152,10 +154,6 @@ func NewReverseProxy(to *url.URL, config *UpstreamConfig) *httputil.ReverseProxy proxy.Director = func(req *http.Request) { req.Header.Add("X-Forwarded-Host", req.Host) - req.Header.Set("X-Forwarded-User", req.Header.Get("User")) - req.Header.Set("X-Forwarded-Email", req.Header.Get("Email")) - req.Header.Set("X-Forwarded-Groups", req.Header.Get("Groups")) - req.Header.Set("X-Forwarded-Access-Token", "") dir(req) req.Host = to.Host } @@ -185,10 +183,6 @@ func NewRewriteReverseProxy(route *RewriteRoute, config *UpstreamConfig) *httput director := httputil.NewSingleHostReverseProxy(target).Director req.Header.Add("X-Forwarded-Host", req.Host) - req.Header.Set("X-Forwarded-User", req.Header.Get("User")) - req.Header.Set("X-Forwarded-Email", req.Header.Get("Email")) - req.Header.Set("X-Forwarded-Groups", req.Header.Get("Groups")) - req.Header.Set("X-Forwarded-Access-Token", "") director(req) req.Host = target.Host } @@ -202,6 +196,7 @@ func NewReverseProxyHandler(reverseProxy *httputil.ReverseProxy, opts *Options, handler: reverseProxy, requestSigner: signer, } + if config.SkipRequestSigning { upstreamProxy.requestSigner = nil } @@ -255,18 +250,23 @@ func NewDevProxy(opts *Options, optFuncs ...func(*DevProxy) error) (*DevProxy, e // Also build the `certs` static JSON-string which will be served from a public endpoint. // The key published at this endpoint allows upstreams to decrypt the `Sso-Signature` // header, and validate the integrity and authenticity of a request. + certs := make(map[string]string) var requestSigner *RequestSigner + var err error if len(opts.RequestSigningKey) > 0 { - requestSigner, err := NewRequestSigner(opts.RequestSigningKey) + requestSigner, err = NewRequestSigner(opts.RequestSigningKey) + if err != nil { return nil, fmt.Errorf("could not build RequestSigner: %s", err) } id, key := requestSigner.PublicKey() certs[id] = key + } else { logger.Warn("Running DevProxy without signing key. Requests will not be signed.") } + certsAsStr, err := json.MarshalIndent(certs, "", " ") if err != nil { return nil, fmt.Errorf("could not marshal public certs as JSON: %s", err) @@ -277,7 +277,7 @@ func NewDevProxy(opts *Options, optFuncs ...func(*DevProxy) error) (*DevProxy, e mux: make(map[string]*route), regexRoutes: make([]*route, 0), - redirectURL: &url.URL{Path: "/oauth2/callback"}, + // redirectURL: &url.URL{Path: "/oauth2/callback"}, templates: getTemplates(), requestSigner: requestSigner, publicCertsJSON: certsAsStr, @@ -301,7 +301,7 @@ func NewDevProxy(opts *Options, optFuncs ...func(*DevProxy) error) (*DevProxy, e handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig, requestSigner) p.HandleRegex(route.FromRegex, handler, tags, upstreamConfig) default: - return nil, fmt.Errorf("unkown route type") + return nil, fmt.Errorf("unknown route type") } } @@ -342,6 +342,11 @@ func (p *DevProxy) RobotsTxt(rw http.ResponseWriter, _ *http.Request) { // Favicon will proxy the request as usual func (p *DevProxy) Favicon(rw http.ResponseWriter, req *http.Request) { + err := p.setProxyHeaders(rw, req) + if err != nil { + rw.WriteHeader(http.StatusNotFound) + return + } rw.WriteHeader(http.StatusOK) p.Proxy(rw, req) } @@ -391,10 +396,10 @@ func (p *DevProxy) isXMLHTTPRequest(req *http.Request) bool { func (p *DevProxy) Proxy(rw http.ResponseWriter, req *http.Request) { logger := log.NewLogEntry() - // start := time.Now() - logger.Info("Proxy...") + p.setProxyHeaders(rw, req) - // We have validated the users request and now proxy their request to the provided upstream. + logger.Info("Proxy...") + // We now proxy their request to the provided upstream. route, ok := p.router(req) if !ok { p.UnknownHost(rw, req) @@ -424,6 +429,14 @@ func (p *DevProxy) HandleRegex(regex *regexp.Regexp, handler http.Handler, tags p.regexRoutes = append(p.regexRoutes, &route{regex: regex, handler: handler, upstreamConfig: upstreamConfig, tags: tags}) } +func (p *DevProxy) setProxyHeaders(rw http.ResponseWriter, req *http.Request) (err error) { + req.Header.Set("X-Forwarded-User", req.Header.Get("User")) + req.Header.Set("X-Forwarded-Email", req.Header.Get("Email")) + req.Header.Set("X-Forwarded-Groups", req.Header.Get("groups")) + // req.Header.set("X-Forwarded-Access-Token", "") + return nil +} + // router attempts to find a route for a request. If a route is successfully matched, // it returns the route information and a bool value of `true`. If a route can not be matched, //a nil value for the route and false bool value is returned. diff --git a/internal/devproxy/devproxy_test.go b/internal/devproxy/devproxy_test.go index 268d8019..842a235f 100644 --- a/internal/devproxy/devproxy_test.go +++ b/internal/devproxy/devproxy_test.go @@ -262,7 +262,7 @@ func TestCerts(t *testing.T) { opts.RequestSigningKey = string(requestSigningKey) opts.Validate() - expectedPublicKey, err := ioutil.ReadFile("testdata/public_key.pub") + expectedPublicKey, err := ioutil.ReadFile("testdata/public_key.pem") testutil.Assert(t, err == nil, "could not read public key from testdata: %s", err) var keyHash []byte @@ -798,7 +798,7 @@ func TestTimeoutHandler(t *testing.T) { if res.StatusCode != tc.ExpectedStatusCode { t.Errorf(" got=%v", res.StatusCode) t.Errorf("want=%v", tc.ExpectedStatusCode) - t.Fatalf("got unexpcted status code") + t.Fatalf("got unexpected status code") } body, err := ioutil.ReadAll(res.Body) @@ -809,7 +809,7 @@ func TestTimeoutHandler(t *testing.T) { if string(body) != tc.ExpectedBody { t.Errorf(" got=%q", body) t.Errorf("want=%q", tc.ExpectedBody) - t.Fatalf("got unexpcted body") + t.Fatalf("got unexpected body") } }) } @@ -868,7 +868,7 @@ func TestRewriteRoutingHandling(t *testing.T) { upstreamHost, upstreamPort, err := net.SplitHostPort(parsedUpstreamURL.Host) if err != nil { - t.Fatalf("expected to split host/hort err:%q", err) + t.Fatalf("expected to split host/port err:%q", err) } testCases := []struct { @@ -902,7 +902,7 @@ func TestRewriteRoutingHandling(t *testing.T) { ExpectedCode: statusInvalidHost, }, { - Name: "it should match and replace using regex/template to find port in embeded domain", + Name: "it should match and replace using regex/template to find port in embedded domain", TestHost: fmt.Sprintf("somedomain--%s", upstreamPort), FromRegex: "somedomain--(.*)", // capture port ToTemplate: fmt.Sprintf("%s:$1", upstreamHost), // add port to dest diff --git a/internal/devproxy/metrics.go b/internal/devproxy/metrics.go index 68a56048..228c9b1c 100644 --- a/internal/devproxy/metrics.go +++ b/internal/devproxy/metrics.go @@ -8,12 +8,9 @@ import ( func GetActionTag(req *http.Request) string { // only log metrics for these paths and actions pathToAction := map[string]string{ - "/favicon.ico": "favicon", - "/oauth2/sign_out": "sign_out", - "/oauth2/callback": "callback", - "/oauth2/auth": "auth", - "/ping": "ping", - "/robots.txt": "robots", + "/favicon.ico": "favicon", + "/ping": "ping", + "/robots.txt": "robots", } // get the action from the url path path := req.URL.Path diff --git a/internal/devproxy/request_signer.go b/internal/devproxy/request_signer.go index c3965525..323830cc 100644 --- a/internal/devproxy/request_signer.go +++ b/internal/devproxy/request_signer.go @@ -27,7 +27,6 @@ var signedHeaders = []string{ "X-Forwarded-User", "X-Forwarded-Email", "X-Forwarded-Groups", - "Cookie", } // Name of the header used to transmit the signature computed for the request. @@ -96,8 +95,8 @@ func NewRequestSigner(signingKeyPemStr string) (*RequestSigner, error) { // // // where: -// is the ','-joined concatenation of all header values of `signedHeaders[k]`; all -// other headers in the request are ignored, +// is the ','-joined concatenation of all header values of `signedHeaders[k]`; empty +// values such as '' and all other headers in the request are ignored, // is the string "(?)(#FRAGMENT)", where "?" and "#" are // ommitted if the associated components are absent from the request URL, // is the body of the Request (may be `nil`; e.g. for GET requests). @@ -109,7 +108,8 @@ func mapRequestToHashInput(req *http.Request) (string, error) { // Add signed headers. for _, hdr := range signedHeaders { - if hdrValues := req.Header[hdr]; len(hdrValues) > 0 { + hdrValues := removeEmpty(req.Header[hdr]) + if len(hdrValues) > 0 { entries = append(entries, strings.Join(hdrValues, ",")) } } @@ -189,3 +189,13 @@ func (signer RequestSigner) Sign(req *http.Request) error { func (signer RequestSigner) PublicKey() (string, string) { return signer.publicKeyID, signer.publicKeyStr } + +func removeEmpty(s []string) []string { + r := []string{} + for _, str := range s { + if len(str) > 0 { + r = append(r, str) + } + } + return r +} diff --git a/internal/devproxy/request_signer_test.go b/internal/devproxy/request_signer_test.go index 99dd1688..ced5b190 100644 --- a/internal/devproxy/request_signer_test.go +++ b/internal/devproxy/request_signer_test.go @@ -134,7 +134,7 @@ func TestSignatureRoundTripDecoding(t *testing.T) { privateKey, err := ioutil.ReadFile("testdata/private_key.pem") testutil.Assert(t, err == nil, "error reading private key from testdata") - publicKey, err := ioutil.ReadFile("testdata/public_key.pub") + publicKey, err := ioutil.ReadFile("testdata/public_key.pem") testutil.Assert(t, err == nil, "error reading public key from testdata") // Build the RequestSigner object used to generate the request signature header. diff --git a/internal/devproxy/testdata/.env b/internal/devproxy/testdata/.env new file mode 100644 index 00000000..91c839bb --- /dev/null +++ b/internal/devproxy/testdata/.env @@ -0,0 +1,10 @@ +export PORT=4888 +export SCHEME=http +export HOST=http://localhost/ +export UPSTREAM_CONFIGS=/path/to/upstream_configs.yml +export CLUSTER=sso-dev +export DEFAULT_UPSTREAM_TIMEOUT=10s +export TCP_WRITE_TIMEOUT=30s +export TCP_READ_TIMEOUT=30s +export REQUEST_LOGGING=true +export REQUEST_SIGNATURE_KEY=$(cat /path/to/devproxy/testdata/private_key.pem) diff --git a/internal/devproxy/testdata/private_key.pem b/internal/devproxy/testdata/private_key.pem new file mode 100644 index 00000000..03a16bd3 --- /dev/null +++ b/internal/devproxy/testdata/private_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCy38IQCH8QyeNF +s1zA0XuIyqnTcSfYZg0nPfB+K//pFy7tIOAwmR6th8NykrxFhEQDHKNCmLXt4j8V +FDHQZtGjUBHRmAXZW8NOQ0EI1vc/Dpt09sU40JQlXZZeL+9/7iAxEfSE3TQr1k7P +Xwxpjm9rsLSn7FoLnvXco0mc6+d2jjxf4cMgJIaQLKOd783KUQzLVEvBQJ05JnpI +2xMjS0q33ltMTMGF3QZQN9i4bZKgnItomKxTJbfxftO11FTNLB7og94sWmlThAY5 +/UMjZaWYJ1g89+WUJ+KpVYyJsHPBBkaQG+NYazcLDyIowpzJ1WVkInysshpTqwT+ +UPV4at+jAgMBAAECggEAX8lxK5LRMJVcLlwRZHQJekRE0yS6WKi1jHkfywEW5qRy +jatYQs4MXpLgN/+Z8IQWw6/XQXdznTLV4xzQXDBjPNhI4ntNTotUOBnNvsUW296f +ou/uxzDy1FuchU2YLGLBPGXIEko+gOcfhu74P6J1yi5zX6UyxxxVvtR2PCEb7yDw +m2881chwMblZ5Z8uyF++ajkK3/rqLk64w29+K4ZTDbTcCp5NtBYx2qSEU7yp12rc +qscUGqxG00Abx+osI3cUn0kOq7356LeR1rfA15yZwOb+s28QYp2WPlVB2hOiYXQv ++ttEOpt0x1QJhBAsFgwY173sD5w2MryRQb1RCwBvqQKBgQDeTdbRzxzAl83h/mAq +5I+pNEz57veAFVO+iby7TbZ/0w6q+QeT+bHF+TjGHiSlbtg3nd9NPrex2UjiN7ej ++DrxhsSLsP1ZfwDNv6f1Ii1HluJclUFSUNU/LntBjqqCJ959lniNp1y5+ZQ/j2Rf ++ZraVsHRB0itilFeAl5+n7CfxwKBgQDN/K+E1TCbp1inU60Lc9zeb8fqTEP6Mp36 +qQ0Dp+KMLPJ0xQSXFq9ILr4hTJlBqfmTkfmQUcQuwercZ3LNQPbsuIg96bPW73R1 +toXjokd6jUn5sJXCOE0RDumcJrL1VRf9RN1AmM4CgCc/adUMjws3pBc5R4An7UyU +ouRQhN+5RQKBgFOVTrzqM3RSX22mWAAomb9T09FxQQueeTM91IFUMdcTwwMTyP6h +Nm8qSmdrM/ojmBYpPKlteGHdQaMUse5rybXAJywiqs84ilPRyNPJOt8c4xVOZRYP +IG62Ck/W1VNErEnqBn+0OpAOP+g6ANJ5JfkL/6mZJIFjbT58g4z2e9FHAoGBAM3f +uBkd7lgTuLJ8Gh6xLVYQCJHuqZ49ytFE9qHpwK5zGdyFMSJE5OlS9mpXoXEUjkHk +iraoUlidLbwdlIr6XBCaGmku07SFXTNtOoIZpjEhV4c762HTXYsoCWos733uD2zt +z+iJEJVFOnTRtMK5kO+KjD+Oa9L8BCcmauTi+Ku1AoGAZBUzi95THA60hPXI0hm/ +o0J5mfLkFPfhpUmDAMaEpv3bM4byA+IGXSZVc1IZO6cGoaeUHD2Yl1m9a5tv5rF+ +FS9Ht+IgATvGojah+xxQy+kf6tRB9Hn4scyq+64AesXlDbWDEagomQ0hyV/JKSS6 +LQatvnCmBd9omRT2uwYUo+o= +-----END PRIVATE KEY----- diff --git a/internal/devproxy/testdata/public_key.pem b/internal/devproxy/testdata/public_key.pem new file mode 100644 index 00000000..cccac43b --- /dev/null +++ b/internal/devproxy/testdata/public_key.pem @@ -0,0 +1,8 @@ +-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEAst/CEAh/EMnjRbNcwNF7iMqp03En2GYNJz3wfiv/6Rcu7SDgMJke +rYfDcpK8RYREAxyjQpi17eI/FRQx0GbRo1AR0ZgF2VvDTkNBCNb3Pw6bdPbFONCU +JV2WXi/vf+4gMRH0hN00K9ZOz18MaY5va7C0p+xaC5713KNJnOvndo48X+HDICSG +kCyjne/NylEMy1RLwUCdOSZ6SNsTI0tKt95bTEzBhd0GUDfYuG2SoJyLaJisUyW3 +8X7TtdRUzSwe6IPeLFppU4QGOf1DI2WlmCdYPPfllCfiqVWMibBzwQZGkBvjWGs3 +Cw8iKMKcydVlZCJ8rLIaU6sE/lD1eGrfowIDAQAB +-----END RSA PUBLIC KEY----- diff --git a/internal/devproxy/testdata/upstream_configs.yml b/internal/devproxy/testdata/upstream_configs.yml index 21bfcb9b..a2ad2e32 100644 --- a/internal/devproxy/testdata/upstream_configs.yml +++ b/internal/devproxy/testdata/upstream_configs.yml @@ -1,6 +1,8 @@ -- service: httpbin +- service: dev-shim default: - from: httpbin.sso.localtest.me - to: http://httpheader.net + from: http://localhost:4888 + to: http://localhost:4810 + options: + skip_request_signing: false From 21f37ae326d09743a9c05385f422fb98f3d843b9 Mon Sep 17 00:00:00 2001 From: Maliheh Date: Tue, 21 May 2019 16:11:59 -0700 Subject: [PATCH 5/5] adding authentication --- .gitignore | 3 +- internal/devproxy/dev_config.go | 35 +++++++++++++++--- internal/devproxy/devproxy.go | 36 ++++++++++++------- internal/devproxy/options.go | 32 ++++++----------- internal/devproxy/request_signer.go | 1 + internal/devproxy/testdata/.env | 5 +-- .../devproxy/testdata/upstream_configs.yml | 7 ++-- 7 files changed, 77 insertions(+), 42 deletions(-) diff --git a/.gitignore b/.gitignore index 98b2c19f..846cab3a 100644 --- a/.gitignore +++ b/.gitignore @@ -51,8 +51,9 @@ __pycache__/ # C extensions *.so -# Test binary, build with `go test -c` +# Test binaries *.test +sso-devproxy # Output of the go coverage tool, specifically when used with LiteIDE *.out diff --git a/internal/devproxy/dev_config.go b/internal/devproxy/dev_config.go index 1e09638b..7400c322 100644 --- a/internal/devproxy/dev_config.go +++ b/internal/devproxy/dev_config.go @@ -55,16 +55,27 @@ type UpstreamConfig struct { HeaderOverrides map[string]string TLSSkipVerify bool SkipRequestSigning bool + User string + Groups string + Email string } // RouteConfig maps to the yaml config fields, // * "from" - the domain that will be used to access the service // * "to" - the cname of the proxied service (this tells sso proxy where to proxy requests that come in on the from field) type RouteConfig struct { - From string `yaml:"from"` - To string `yaml:"to"` - Type string `yaml:"type"` - Options *OptionsConfig `yaml:"options"` + From string `yaml:"from"` + To string `yaml:"to"` + Type string `yaml:"type"` + Options *OptionsConfig `yaml:"options"` + UserInfo *UserInfo `yaml:"user_info"` +} + +//UserInfo is going to be injected into the header +type UserInfo struct { + User string `yaml:"user"` + Groups string `yaml:"groups"` + Email string `yaml:"email"` } // OptionsConfig maps to the yaml config fields: @@ -177,6 +188,10 @@ func loadServiceConfigs(raw []byte, cluster, scheme string, configVars map[strin if err != nil { return nil, err } + err = parseUserInfoConfig(proxy) + if err != nil { + return nil, err + } } for _, proxy := range configs { @@ -353,6 +368,18 @@ func parseOptionsConfig(proxy *UpstreamConfig) error { return nil } +func parseUserInfoConfig(proxy *UpstreamConfig) error { + if proxy.RouteConfig.UserInfo == nil { + return nil + } + + proxy.User = proxy.RouteConfig.UserInfo.User + proxy.Groups = proxy.RouteConfig.UserInfo.Groups + proxy.Email = proxy.RouteConfig.UserInfo.Email + + return nil +} + func cleanWhiteSpace(s string) string { // This trims all white space from a service name and collapses all remaining space to `_` return space.ReplaceAllString(strings.TrimSpace(s), "_") // diff --git a/internal/devproxy/devproxy.go b/internal/devproxy/devproxy.go index da9f1b9f..43ac0d92 100644 --- a/internal/devproxy/devproxy.go +++ b/internal/devproxy/devproxy.go @@ -23,23 +23,30 @@ const HMACSignatureHeader = "Gap-Signature" // SignatureHeaders are the headers that are valid in the request. var SignatureHeaders = []string{ + "Content-Length", + "Content-Md5", + "Content-Type", + "Date", + "Authorization", "X-Forwarded-User", "X-Forwarded-Email", "X-Forwarded-Groups", - "X-Forwarded-Access-Token", + "Cookie", } const statusInvalidHost = 421 // DevProxy stores all the information associated with proxying the request. type DevProxy struct { - // redirectURL *url.URL // the url to receive requests at skipAuthPreflight bool templates *template.Template mux map[string]*route regexRoutes []*route requestSigner *RequestSigner publicCertsJSON []byte + user string + groups string + email string } type route struct { @@ -61,6 +68,7 @@ type StateParameter struct { type UpstreamProxy struct { name string handler http.Handler + auth hmacauth.HmacAuth requestSigner *RequestSigner } @@ -104,9 +112,10 @@ func newUpstreamTransport(insecureSkipVerify bool) *upstreamTransport { // ServeHTTP calls the upstream's ServeHTTP function. func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - + if u.auth != nil { + u.auth.SignRequest(r) + } if u.requestSigner != nil { - u.requestSigner.Sign(r) } @@ -157,6 +166,7 @@ func NewReverseProxy(to *url.URL, config *UpstreamConfig) *httputil.ReverseProxy dir(req) req.Host = to.Host } + return proxy } @@ -194,6 +204,7 @@ func NewReverseProxyHandler(reverseProxy *httputil.ReverseProxy, opts *Options, upstreamProxy := &UpstreamProxy{ name: config.Service, handler: reverseProxy, + auth: config.HMACAuth, requestSigner: signer, } @@ -274,10 +285,8 @@ func NewDevProxy(opts *Options, optFuncs ...func(*DevProxy) error) (*DevProxy, e p := &DevProxy{ // these fields make up the routing mechanism - mux: make(map[string]*route), - regexRoutes: make([]*route, 0), - - // redirectURL: &url.URL{Path: "/oauth2/callback"}, + mux: make(map[string]*route), + regexRoutes: make([]*route, 0), templates: getTemplates(), requestSigner: requestSigner, publicCertsJSON: certsAsStr, @@ -289,8 +298,10 @@ func NewDevProxy(opts *Options, optFuncs ...func(*DevProxy) error) (*DevProxy, e return nil, err } } - for _, upstreamConfig := range opts.upstreamConfigs { + p.user = upstreamConfig.User + p.email = upstreamConfig.Email + p.groups = upstreamConfig.Groups switch route := upstreamConfig.Route.(type) { case *SimpleRoute: reverseProxy := NewReverseProxy(route.ToURL, upstreamConfig) @@ -419,6 +430,7 @@ func (p *DevProxy) UnknownHost(rw http.ResponseWriter, req *http.Request) { // Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig func (p *DevProxy) Handle(host string, handler http.Handler, tags []string, upstreamConfig *UpstreamConfig) { + tags = append(tags, "route:simple") p.mux[host] = &route{handler: handler, upstreamConfig: upstreamConfig, tags: tags} } @@ -430,9 +442,9 @@ func (p *DevProxy) HandleRegex(regex *regexp.Regexp, handler http.Handler, tags } func (p *DevProxy) setProxyHeaders(rw http.ResponseWriter, req *http.Request) (err error) { - req.Header.Set("X-Forwarded-User", req.Header.Get("User")) - req.Header.Set("X-Forwarded-Email", req.Header.Get("Email")) - req.Header.Set("X-Forwarded-Groups", req.Header.Get("groups")) + req.Header.Set("X-Forwarded-User", p.user) + req.Header.Set("X-Forwarded-Email", p.email) + req.Header.Set("X-Forwarded-Groups", p.groups) // req.Header.set("X-Forwarded-Access-Token", "") return nil } diff --git a/internal/devproxy/options.go b/internal/devproxy/options.go index eaf279fa..00c5c367 100644 --- a/internal/devproxy/options.go +++ b/internal/devproxy/options.go @@ -18,26 +18,17 @@ import ( // TCPWriteTimeout - http server tcp write timeout // TCPReadTimeout - http server tcp read timeout type Options struct { - Port int `envconfig:"PORT" default:"4180"` - - UpstreamConfigsFile string `envconfig:"UPSTREAM_CONFIGS"` - Cluster string `envconfig:"CLUSTER"` - Scheme string `envconfig:"SCHEME" default:"https"` - - Host string `envconfig:"HOST"` - + Port int `envconfig:"PORT" default:"4180"` + UpstreamConfigsFile string `envconfig:"UPSTREAM_CONFIGS"` + Scheme string `envconfig:"SCHEME" default:"https"` + Host string `envconfig:"HOST"` DefaultUpstreamTimeout time.Duration `envconfig:"DEFAULT_UPSTREAM_TIMEOUT" default:"10s"` - - TCPWriteTimeout time.Duration `envconfig:"TCP_WRITE_TIMEOUT" default:"30s"` - TCPReadTimeout time.Duration `envconfig:"TCP_READ_TIMEOUT" default:"30s"` - - RequestLogging bool `envconfig:"REQUEST_LOGGING" default:"true"` - - RequestSigningKey string `envconfig:"REQUEST_SIGNATURE_KEY"` - + TCPWriteTimeout time.Duration `envconfig:"TCP_WRITE_TIMEOUT" default:"30s"` + TCPReadTimeout time.Duration `envconfig:"TCP_READ_TIMEOUT" default:"30s"` + RequestLogging bool `envconfig:"REQUEST_LOGGING" default:"true"` + RequestSigningKey string `envconfig:"REQUEST_SIGNATURE_KEY"` // This is an override for supplying template vars at test time testTemplateVars map[string]string - // internal values that are set after config validation upstreamConfigs []*UpstreamConfig } @@ -62,11 +53,10 @@ func parseURL(toParse string, urltype string, msgs []string) (*url.URL, []string // Validate validates options func (o *Options) Validate() error { msgs := make([]string, 0) - if o.Cluster == "" { - msgs = append(msgs, "missing setting: cluster") - } + if o.UpstreamConfigsFile == "" { msgs = append(msgs, "missing setting: upstream-configs") + o.UpstreamConfigsFile = "internal/devproxy/testdata/upstream_configs.yml" } if o.UpstreamConfigsFile != "" { @@ -80,7 +70,7 @@ func (o *Options) Validate() error { templateVars = o.testTemplateVars } - o.upstreamConfigs, err = loadServiceConfigs(rawBytes, o.Cluster, o.Scheme, templateVars) + o.upstreamConfigs, err = loadServiceConfigs(rawBytes, "default", o.Scheme, templateVars) if err != nil { msgs = append(msgs, fmt.Sprintf("error parsing upstream configs file %s", err)) } diff --git a/internal/devproxy/request_signer.go b/internal/devproxy/request_signer.go index 323830cc..7184a35e 100644 --- a/internal/devproxy/request_signer.go +++ b/internal/devproxy/request_signer.go @@ -27,6 +27,7 @@ var signedHeaders = []string{ "X-Forwarded-User", "X-Forwarded-Email", "X-Forwarded-Groups", + "Cookie", } // Name of the header used to transmit the signature computed for the request. diff --git a/internal/devproxy/testdata/.env b/internal/devproxy/testdata/.env index 91c839bb..9e9b4a15 100644 --- a/internal/devproxy/testdata/.env +++ b/internal/devproxy/testdata/.env @@ -1,10 +1,11 @@ export PORT=4888 +export UPSTREAM_CONFIGS=/path/to/upstream_configs.yml export SCHEME=http export HOST=http://localhost/ -export UPSTREAM_CONFIGS=/path/to/upstream_configs.yml export CLUSTER=sso-dev export DEFAULT_UPSTREAM_TIMEOUT=10s export TCP_WRITE_TIMEOUT=30s export TCP_READ_TIMEOUT=30s export REQUEST_LOGGING=true -export REQUEST_SIGNATURE_KEY=$(cat /path/to/devproxy/testdata/private_key.pem) +export REQUEST_SIGNATURE_KEY=$(cat /path/to/private_key.pem) +export DEV_CONFIG_DEVSHIM_SIGNING_KEY="sha256:shared-secret-value" diff --git a/internal/devproxy/testdata/upstream_configs.yml b/internal/devproxy/testdata/upstream_configs.yml index a2ad2e32..8c8b4767 100644 --- a/internal/devproxy/testdata/upstream_configs.yml +++ b/internal/devproxy/testdata/upstream_configs.yml @@ -1,8 +1,11 @@ -- service: dev-shim +- service: devshim default: from: http://localhost:4888 to: http://localhost:4810 options: skip_request_signing: false - + user_info: + user: testUser + groups: team + email: testtest@remitly.com