From 09a834e4d223ab99d035d621bdf0320f9528b363 Mon Sep 17 00:00:00 2001 From: Alexander Nicke Date: Fri, 29 Aug 2025 12:13:37 +0200 Subject: [PATCH 01/17] Implement hash-based routing (#505) This commit provides the basic implementation for hash-based routing. It does not consider the balance factor yet. Co-authored-by: Clemens Hoffmann Co-authored-by: Tamara Boehm Co-authored-by: Soha Alboghdady --- docs/03-how-to-add-new-route-option.md | 5 + .../round_tripper/proxy_round_tripper.go | 14 + .../round_tripper/proxy_round_tripper_test.go | 162 +++++++++++ .../gorouter/route/hash_based.go | 141 ++++++++++ .../gorouter/route/hash_based_test.go | 155 +++++++++++ .../gorouter/route/maglev.go | 212 +++++++++++++++ .../gorouter/route/maglev_test.go | 257 ++++++++++++++++++ .../gorouter/route/pool.go | 79 +++++- .../gorouter/route/pool_test.go | 40 +++ 9 files changed, 1061 insertions(+), 4 deletions(-) create mode 100644 src/code.cloudfoundry.org/gorouter/route/hash_based.go create mode 100644 src/code.cloudfoundry.org/gorouter/route/hash_based_test.go create mode 100644 src/code.cloudfoundry.org/gorouter/route/maglev.go create mode 100644 src/code.cloudfoundry.org/gorouter/route/maglev_test.go diff --git a/docs/03-how-to-add-new-route-option.md b/docs/03-how-to-add-new-route-option.md index 39ce25447..308213fcc 100644 --- a/docs/03-how-to-add-new-route-option.md +++ b/docs/03-how-to-add-new-route-option.md @@ -22,6 +22,11 @@ applications: - route: example2.com options: loadbalancing: least-connection + - route: example3.com + options: + loadbalancing: hash + hash_header: tenant-id + hash_balance: 1.25 ``` **NOTE**: In the implementation, the `options` property of a route represents per-route features. diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go index 88cfd20a5..84252262f 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go @@ -127,6 +127,20 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response stickyEndpointID, mustBeSticky := handlers.GetStickySession(request, rt.config.StickySessionCookieNames, rt.config.StickySessionsForAuthNegotiate) numberOfEndpoints := reqInfo.RoutePool.NumEndpoints() iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) + if reqInfo.RoutePool.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + if reqInfo.RoutePool.HashRoutingProperties == nil { + rt.logger.Error("hash-routing-properties-nil", slog.String("host", reqInfo.RoutePool.Host())) + + } else { + headerName := reqInfo.RoutePool.HashRoutingProperties.Header + headerValue := request.Header.Get(headerName) + if headerValue != "" { + iter.(*route.HashBased).HeaderValue = headerValue + } else { + iter = reqInfo.RoutePool.FallBackToDefaultLoadBalancing(rt.config.LoadBalance, rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) + } + } + } // The selectEndpointErr needs to be tracked separately. If we get an error // while selecting an endpoint we might just have run out of routes. 
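A minimal client-side sketch of the behaviour this hunk adds, assuming the `example3.com` route and `tenant-id` header from the docs example above (the tenant value is illustrative): requests that carry the configured header are hashed onto a fixed app instance, while requests without it fall back to the default load-balancing algorithm (round-robin or least-connection, per `FallBackToDefaultLoadBalancing` below).

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	// Carries the configured hash header: Gorouter hashes "tenant-42" via the
	// Maglev lookup table and keeps routing this tenant to the same instance.
	pinned, _ := http.NewRequest(http.MethodGet, "https://example3.com/", nil)
	pinned.Header.Set("tenant-id", "tenant-42")

	// No hash header: the round tripper falls back to the default
	// load-balancing algorithm instead of hash-based routing.
	unpinned, _ := http.NewRequest(http.MethodGet, "https://example3.com/", nil)

	fmt.Println(pinned.Header.Get("tenant-id"), unpinned.Header.Get("tenant-id"))
}
```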
In diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go index 9d270867c..6abe4d218 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "math/rand" "net" "net/http" "net/http/httptest" @@ -1700,6 +1701,167 @@ var _ = Describe("ProxyRoundTripper", func() { }) }) + Context("when load-balancing strategy is set to hash-based routing", func() { + JustBeforeEach(func() { + for i := 1; i <= 3; i++ { + endpoint = route.NewEndpoint(&route.EndpointOpts{ + AppId: fmt.Sprintf("appID%d", i), + Host: fmt.Sprintf("%d.%d.%d.%d", i, i, i, i), + Port: 9090, + PrivateInstanceId: fmt.Sprintf("instanceID%d", i), + PrivateInstanceIndex: fmt.Sprintf("%d", i), + AvailabilityZone: AZ, + LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + HashHeaderName: "X-Hash", + }) + + _ = routePool.Put(endpoint) + Expect(routePool.HashLookupTable).ToNot(BeNil()) + + } + }) + + It("routes requests with same hash header value to the same endpoint", func() { + req.Header.Set("X-Hash", "value") + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).ToNot(HaveOccurred()) + reqInfo.RoutePool = routePool + + var selectedEndpoints []*route.Endpoint + + // Make multiple requests with the same hash value + for i := 0; i < 5; i++ { + _, err = proxyRoundTripper.RoundTrip(req) + Expect(err).NotTo(HaveOccurred()) + selectedEndpoints = append(selectedEndpoints, reqInfo.RouteEndpoint) + } + + // All requests should go to the same endpoint + firstEndpoint := selectedEndpoints[0] + for _, ep := range selectedEndpoints[1:] { + Expect(ep.PrivateInstanceId).To(Equal(firstEndpoint.PrivateInstanceId)) + } + }) + + It("routes requests with different hash header values to potentially different endpoints", func() { + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).ToNot(HaveOccurred()) + reqInfo.RoutePool = routePool + + endpointDistribution := make(map[string]int) + + // Make requests with different hash values + for i := 0; i < 10; i++ { + req.Header.Set("X-Hash", fmt.Sprintf("value-%d", i)) + _, err = proxyRoundTripper.RoundTrip(req) + Expect(err).NotTo(HaveOccurred()) + endpointDistribution[reqInfo.RouteEndpoint.PrivateInstanceId]++ + } + + // Should distribute across multiple endpoints (not all to one) + Expect(len(endpointDistribution)).To(BeNumerically(">", 1)) + }) + + It("falls back to default load balancing algorithm when hash header is missing", func() { + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).ToNot(HaveOccurred()) + + reqInfo.RoutePool = routePool + + _, err = proxyRoundTripper.RoundTrip(req) + Expect(err).NotTo(HaveOccurred()) + + infoLogs := logger.Lines(zap.InfoLevel) + count := 0 + for i := 0; i < len(infoLogs); i++ { + if strings.Contains(infoLogs[i], "hash-based-routing-header-not-found") { + count++ + } + } + Expect(count).To(Equal(1)) + // Verify it still selects an endpoint + Expect(reqInfo.RouteEndpoint).ToNot(BeNil()) + }) + + Context("when sticky session cookies (JSESSIONID and VCAP_ID) are on the request", func() { + var ( + sessionCookie *http.Cookie + cookies []*http.Cookie + ) + + JustBeforeEach(func() { + sessionCookie = &http.Cookie{ + Name: StickyCookieKey, //JSESSIONID + } + transport.RoundTripStub = func(req *http.Request) (*http.Response, error) { + resp := 
&http.Response{StatusCode: http.StatusTeapot, Header: make(map[string][]string)} + //Attach the same JSESSIONID on to the response if it exists on the request + + if len(req.Cookies()) > 0 { + for _, cookie := range req.Cookies() { + if cookie.Name == StickyCookieKey { + resp.Header.Add(round_tripper.CookieHeader, cookie.String()) + return resp, nil + } + } + } + + sessionCookie.Value, _ = uuid.GenerateUUID() + resp.Header.Add(round_tripper.CookieHeader, sessionCookie.String()) + return resp, nil + } + resp, err := proxyRoundTripper.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + + cookies = resp.Cookies() + Expect(cookies).To(HaveLen(2)) + + }) + + Context("when there is a JSESSIONID and __VCAP_ID__ set on the request", func() { + It("will always route to the instance specified with the __VCAP_ID__ cookie", func() { + + // Generate 20 random values for the hash header, so chance that all go to instanceID1 + // by accident is 0.33^20 + for i := 0; i < 20; i++ { + randomStr := make([]byte, 8) + for j := range randomStr { + randomStr[j] = byte('a' + rand.Intn(26)) + } + + req.Header.Set("X-Hash", string(randomStr)) + reqInfo, err := handlers.ContextRequestInfo(req) + req.AddCookie(&http.Cookie{Name: round_tripper.VcapCookieId, Value: "instanceID1"}) + req.AddCookie(&http.Cookie{Name: StickyCookieKey, Value: "abc"}) + + Expect(err).ToNot(HaveOccurred()) + reqInfo.RoutePool = routePool + + resp, err := proxyRoundTripper.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + + new_cookies := resp.Cookies() + Expect(new_cookies).To(HaveLen(2)) + + for _, cookie := range new_cookies { + Expect(cookie.Name).To(SatisfyAny( + Equal(StickyCookieKey), + Equal(round_tripper.VcapCookieId), + )) + if cookie.Name == StickyCookieKey { + Expect(cookie.Value).To(Equal("abc")) + } else { + Expect(cookie.Value).To(Equal("instanceID1")) + } + } + + } + + }) + }) + }) + }) + Context("when endpoint timeout is not 0", func() { var reqCh chan *http.Request BeforeEach(func() { diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go new file mode 100644 index 000000000..39d551eb8 --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -0,0 +1,141 @@ +package route + +import ( + "context" + "errors" + "log/slog" + "sync" + + log "code.cloudfoundry.org/gorouter/logger" +) + +// HashBased load balancing algorithm distributes requests based on a hash of a specific header value. +// The sticky session cookie has precedence over hash-based routing and the request should be routed to the instance stored in the cookie. +// If requests do not contain the hash-related header set configured for the hash-based route option, use the default load-balancing algorithm. +type HashBased struct { + lock *sync.Mutex + + logger *slog.Logger + pool *EndpointPool + lastEndpoint *Endpoint + + stickyEndpointID string + mustBeSticky bool + + HeaderValue string +} + +// NewHashBased initializes an endpoint iterator that selects endpoints based on a hash of a header value. +// The global properties locallyOptimistic and localAvailabilityZone will be ignored when using Hash-Based Routing. 
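Before the constructor, a minimal sketch of how this iterator is driven, mirroring `hash_based_test.go` later in this patch; the header value `"tenant-1"`, the endpoint addresses, and the instance IDs are example values only.

```go
package route_test

import (
	"time"

	"code.cloudfoundry.org/gorouter/config"
	"code.cloudfoundry.org/gorouter/route"
	"code.cloudfoundry.org/gorouter/test_util"
)

func hashBasedSketch() {
	logger := test_util.NewTestLogger("example")
	pool := route.NewPool(&route.PoolOpts{
		Logger:                 logger.Logger,
		RetryAfterFailure:      2 * time.Minute,
		LoadBalancingAlgorithm: config.LOAD_BALANCE_HB,
		HashHeader:             "tenant-id",
	})
	pool.Put(route.NewEndpoint(&route.EndpointOpts{
		Host: "1.2.3.4", Port: 5678, PrivateInstanceId: "ID1",
		LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, HashHeaderName: "tenant-id",
	}))
	pool.Put(route.NewEndpoint(&route.EndpointOpts{
		Host: "2.2.3.4", Port: 5678, PrivateInstanceId: "ID2",
		LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, HashHeaderName: "tenant-id",
	}))

	iter := route.NewHashBased(logger.Logger, pool, "", false, false, "")
	iter.(*route.HashBased).HeaderValue = "tenant-1"
	first := iter.Next(0)  // endpoint chosen via the Maglev lookup table
	second := iter.Next(0) // same endpoint as first for the same header value
	_, _ = first, second

	// A non-empty sticky-session endpoint ID takes precedence over the hash.
	sticky := route.NewHashBased(logger.Logger, pool, "ID1", false, false, "")
	_ = sticky.Next(0) // returns the endpoint with PrivateInstanceId "ID1"
}
```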
+func NewHashBased(logger *slog.Logger, p *EndpointPool, initial string, mustBeSticky bool, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator { + return &HashBased{ + logger: logger, + pool: p, + lock: &sync.Mutex{}, + stickyEndpointID: initial, + mustBeSticky: mustBeSticky, + } +} + +// Next selects the next endpoint based on the hash of the header value. +// If a sticky session endpoint is available and not overloaded, it will be returned. +// If the request must be sticky and the sticky endpoint is unavailable or overloaded, nil will be returned. +// If no sticky session is present, the endpoint will be selected based on the hash of the header value. +// It returns the same endpoint for the same header value consistently. +// If the hash lookup fails or the endpoint is not found, nil will be returned. +func (h *HashBased) Next(attempt int) *Endpoint { + h.lock.Lock() + defer h.lock.Unlock() + + e := h.findEndpointIfStickySession() + if e == nil && h.mustBeSticky { + return nil + } + + if e != nil { + h.lastEndpoint = e + return e + } + + if h.pool.HashLookupTable == nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Lookup table is empty"))) + return nil + } + + id, err := h.pool.HashLookupTable.Get(h.HeaderValue) + + if err != nil { + h.logger.Error( + "hash-based-routing-failed", + slog.String("host", h.pool.host), + log.ErrAttr(err), + ) + return nil + } + + h.logger.Debug( + "hash-based-routing", + slog.String("hash header value", h.HeaderValue), + slog.String("endpoint-id", id), + ) + + endpointElem := h.pool.findById(id) + if endpointElem == nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Endpoint not found in pool")), slog.String("endpoint-id", id)) + return nil + } + + return endpointElem.endpoint +} + +// findEndpointIfStickySession checks if there is a sticky session endpoint and returns it if available. +// If the sticky session endpoint is overloaded, returns nil. +func (h *HashBased) findEndpointIfStickySession() *Endpoint { + var e *endpointElem + if h.stickyEndpointID != "" { + e = h.pool.findById(h.stickyEndpointID) + if e != nil && e.isOverloaded() { + if h.mustBeSticky { + if h.logger.Enabled(context.Background(), slog.LevelDebug) { + h.logger.Debug("endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...) + } + return nil + } + e = nil + } + + if e == nil && h.mustBeSticky { + h.logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", h.stickyEndpointID)) + return nil + } + + if !h.mustBeSticky { + h.logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", h.stickyEndpointID)) + h.stickyEndpointID = "" + } + } + + if e != nil { + e.RLock() + defer e.RUnlock() + return e.endpoint + } + return nil +} + +// EndpointFailed notifies the endpoint pool that the last selected endpoint has failed. +func (h *HashBased) EndpointFailed(err error) { + if h.lastEndpoint != nil { + h.pool.EndpointFailed(h.lastEndpoint, err) + } +} + +// PreRequest increments the in-flight request count for the selected endpoint from current Gorouter. +func (h *HashBased) PreRequest(e *Endpoint) { + e.Stats.NumberConnections.Increment() +} + +// PostRequest decrements the in-flight request count for the selected endpoint from current Gorouter. 
+func (h *HashBased) PostRequest(e *Endpoint) { + e.Stats.NumberConnections.Decrement() +} diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go new file mode 100644 index 000000000..1caaed19c --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go @@ -0,0 +1,155 @@ +package route_test + +import ( + "code.cloudfoundry.org/gorouter/config" + _ "errors" + "time" + + "code.cloudfoundry.org/gorouter/route" + "code.cloudfoundry.org/gorouter/test_util" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("HashBased", func() { + var ( + pool *route.EndpointPool + logger *test_util.TestLogger + ) + + BeforeEach(func() { + logger = test_util.NewTestLogger("test") + pool = route.NewPool(&route.PoolOpts{ + Logger: logger.Logger, + RetryAfterFailure: 2 * time.Minute, + Host: "", + ContextPath: "", + MaxConnsPerBackend: 0, + LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + }) + }) + + Describe("Next", func() { + + Context("when pool is empty", func() { + It("does not select an endpoint", func() { + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + Expect(iter.Next(0)).To(BeNil()) + }) + }) + + Context("when pool has endpoints", func() { + var ( + endpoints []*route.Endpoint + ) + BeforeEach(func() { + e1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID1"}) + e2 := route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID2"}) + endpoints = []*route.Endpoint{e1, e2} + for _, e := range endpoints { + pool.Put(e) + } + + }) + It("It returns the same endpoint for the same header value", func() { + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "tenant-1" + first := iter.Next(0) + second := iter.Next(0) + Expect(first).NotTo(BeNil()) + Expect(second).NotTo(BeNil()) + Expect(first).To(Equal(second)) + }) + + It("It selects another instance for other hash header value", func() { + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "example.com" + Expect(iter.Next(0)).NotTo(BeNil()) + Expect(iter.Next(0)).To(Equal(endpoints[1])) + Expect(iter.Next(0)).To(Equal(endpoints[1])) + Expect(iter.Next(0)).To(Equal(endpoints[1])) + }) + }) + + Context("when using sticky sessions", func() { + var ( + endpoints []*route.Endpoint + iter route.EndpointIterator + ) + + BeforeEach(func() { + e1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + e2 := route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"}) + e3 := route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + }) + + Context("when mustBeSticky is true", func() { + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", true, false, "") + }) + + It("returns the sticky endpoint when it exists", func() { + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + Expect(endpoint.PrivateInstanceId).To(Equal("ID1")) + }) + + It("returns nil 
when sticky endpoint doesn't exist", func() { + iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", true, false, "") + Expect(iter.Next(0)).To(BeNil()) + }) + }) + + Context("when mustBeSticky is false", func() { + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", false, false, "") + }) + + It("returns the sticky endpoint when it exists", func() { + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + Expect(endpoint.PrivateInstanceId).To(Equal("ID1")) + }) + + It("falls back to hash-based routing when sticky endpoint doesn't exist", func() { + iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", false, false, "") + hashIter := iter.(*route.HashBased) + hashIter.HeaderValue = "some-value" + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + }) + }) + }) + }) + + Context("when testing PreRequest and PostRequest", func() { + var ( + endpoint *route.Endpoint + iter route.EndpointIterator + ) + + BeforeEach(func() { + endpoint = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + pool.Put(endpoint) + iter = route.NewHashBased(logger.Logger, pool, "", false, false, "") + }) + + It("increments connection count on PreRequest", func() { + initialCount := endpoint.Stats.NumberConnections.Count() + iter.PreRequest(endpoint) + Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount + 1)) + }) + + It("decrements connection count on PostRequest", func() { + iter.PreRequest(endpoint) + initialCount := endpoint.Stats.NumberConnections.Count() + iter.PostRequest(endpoint) + Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount - 1)) + }) + }) + +}) diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev.go b/src/code.cloudfoundry.org/gorouter/route/maglev.go new file mode 100644 index 000000000..9b70aaa09 --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/maglev.go @@ -0,0 +1,212 @@ +package route + +import ( + "errors" + "fmt" + "hash/fnv" + "log/slog" + "sort" + "strconv" + "strings" + "sync" +) + +const ( + // lookupTableSize is prime number for the size of the maglev lookup table, which should be approximately 100x + // the number of expected endpoints + lookupTableSize uint64 = 1801 +) + +// Maglev implementation of consistent hashing algorithm described in "Maglev: A Fast and Reliable Software Network +// Load Balancer" (https://storage.googleapis.com/gweb-research2023-media/pubtools/2904.pdf) +type Maglev struct { + logger *slog.Logger + permutationTable [][]uint64 + lookupTable []int + endpointList []string + lock *sync.RWMutex +} + +// NewMaglev initializes an empty maglev lookupTable table +func NewMaglev(logger *slog.Logger) *Maglev { + return &Maglev{ + lock: &sync.RWMutex{}, + lookupTable: make([]int, lookupTableSize), + endpointList: make([]string, 0, 2), + permutationTable: make([][]uint64, 0, 2), + logger: logger, + } +} + +// Add a new endpoint to lookupTable if it's not already contained. 
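Before the mutation methods, a short illustrative sketch of how the table is used (backend and key names are invented): `Add` derives a per-endpoint permutation of the 1801-slot table from the endpoint's FNV-64a hash (`offset = h mod M`, `skip = h mod (M-1) + 1`, slot `j` maps to `(offset + j*skip) mod M`), `fillLookupTable` interleaves those permutations so every slot points at exactly one endpoint, and `Get` hashes the header value into a slot. A later commit in this series renames `Get` to `GetInstanceForHashHeader`.

```go
package route_test

import (
	"fmt"

	"code.cloudfoundry.org/gorouter/route"
	"code.cloudfoundry.org/gorouter/test_util"
)

func maglevSketch() {
	logger := test_util.NewTestLogger("example")
	m := route.NewMaglev(logger.Logger)

	m.Add("instance-a") // builds instance-a's permutation and refills the lookup table
	m.Add("instance-b")

	// The same header value always resolves to the same backend while the
	// endpoint set is unchanged.
	backend, err := m.Get("tenant-42")
	fmt.Println(backend, err)

	// Removing an endpoint rebuilds the table; most keys that hashed to
	// instance-b keep their mapping (the consistency property tested below).
	m.Remove("instance-a")
}
```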
+func (m *Maglev) Add(endpoint string) { + m.lock.Lock() + defer m.lock.Unlock() + + if lookupTableSize == uint64(len(m.endpointList)) { + m.logger.Warn("maglev-add-lookuptable-capacity-exceeded", slog.String("endpoint-id", endpoint)) + return + } + + index := sort.SearchStrings(m.endpointList, endpoint) + if index < len(m.endpointList) && m.endpointList[index] == endpoint { + m.logger.Debug("maglev-add-lookuptable-endpoint-exists", slog.String("endpoint-id", endpoint), slog.Int("current-endpoints", len(m.endpointList))) + return + } + + m.endpointList = append(m.endpointList, "") + copy(m.endpointList[index+1:], m.endpointList[index:]) + m.endpointList[index] = endpoint + + m.generatePermutation(endpoint) + m.fillLookupTable() +} + +// Remove an endpoint from lookupTable if it's contained. +func (m *Maglev) Remove(endpoint string) { + m.lock.Lock() + defer m.lock.Unlock() + + index := sort.SearchStrings(m.endpointList, endpoint) + if index >= len(m.endpointList) || m.endpointList[index] != endpoint { + m.logger.Debug("maglev-remove-endpoint-not-found", slog.String("endpoint-id", endpoint)) + return + } + + m.endpointList = append(m.endpointList[:index], m.endpointList[index+1:]...) + m.permutationTable = append(m.permutationTable[:index], m.permutationTable[index+1:]...) + + m.fillLookupTable() +} + +// Get endpoint by specified request header value +// Todo: Overload scenario: Get should return an index rather than an instance, +// so that we can iterate to the next endpoint in case it is overloaded (e.g. via another +// helper function that resolves the endpoint via the index) +func (m *Maglev) Get(headerValue string) (string, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + if len(m.endpointList) == 0 { + return "", errors.New("maglev-get-endpoint-no-endpoints") + } + key := m.hashKey(headerValue) + return m.endpointList[m.lookupTable[key%lookupTableSize]], nil +} + +func (m *Maglev) hashKey(headerValue string) uint64 { + return m.calculateFNVHash64(headerValue) +} + +// generatePermutation creates a permutationTable of the lookup table for each endpoint +func (m *Maglev) generatePermutation(endpoint string) { + pos := sort.SearchStrings(m.endpointList, endpoint) + if pos == len(m.endpointList) { + m.logger.Debug("maglev-permutation-no-endpoints") + return + } + + endpointHash := m.calculateFNVHash64(endpoint) + offset := endpointHash % lookupTableSize + skip := (endpointHash % (lookupTableSize - 1)) + 1 + + permutationForEndpoint := make([]uint64, lookupTableSize) + for j := uint64(0); j < lookupTableSize; j++ { + permutationForEndpoint[j] = (offset + j*skip) % lookupTableSize + } + + // insert permutationForEndpoint at position pos, shifting the rest to the right + m.permutationTable = append(m.permutationTable, nil) + copy(m.permutationTable[pos+1:], m.permutationTable[pos:]) + m.permutationTable[pos] = permutationForEndpoint + +} + +func (m *Maglev) fillLookupTable() { + if len(m.endpointList) == 0 { + return + } + + numberOfEndpoints := len(m.endpointList) + next := make([]int, numberOfEndpoints) + entry := make([]int, lookupTableSize) + for j := range entry { + entry[j] = -1 + } + + for n := uint64(0); n <= lookupTableSize; { + for i := 0; i < numberOfEndpoints; i++ { + candidate := m.findNextAvailableSlot(i, next, entry) + entry[candidate] = int(i) + next[i] = next[i] + 1 + n++ + + if n == lookupTableSize { + m.lookupTable = entry + return + } + } + } +} + +func (m *Maglev) findNextAvailableSlot(i int, next []int, entry []int) uint64 { + candidate := 
m.permutationTable[i][next[i]] + for entry[candidate] >= 0 { + next[i]++ + if next[i] >= len(m.permutationTable[i]) { + // This should not happen in a properly functioning Maglev algorithm, + // but we add this safety check to prevent panic + m.logger.Error("maglev-permutation-table-exhausted", + slog.Int("endpoint-index", i), + slog.Int("next-value", next[i]), + slog.Int("table-size", len(m.permutationTable[i]))) + // Reset to beginning of permutation table as fallback + next[i] = 0 + } + candidate = m.permutationTable[i][next[i]] + } + return candidate +} + +// Getters for unit tests +func (m *Maglev) GetEndpointList() []string { + m.lock.RLock() + defer m.lock.RUnlock() + return append([]string(nil), m.endpointList...) +} + +func (m *Maglev) GetLookupTable() []int { + m.lock.RLock() + defer m.lock.RUnlock() + return append([]int(nil), m.lookupTable...) +} + +func (m *Maglev) GetPermutationTable() [][]uint64 { + m.lock.RLock() + defer m.lock.RUnlock() + copied := make([][]uint64, len(m.permutationTable)) + for i, v := range m.permutationTable { + copied[i] = append([]uint64(nil), v...) + } + return copied +} + +func (m *Maglev) GetLookupTableSize() uint64 { + return lookupTableSize +} + +// TODO: Remove in final version +func (m *Maglev) PrintLookupTable() string { + strArr := make([]string, len(m.lookupTable)) + for i, value := range m.lookupTable { + strArr[i] = strconv.Itoa(value) + } + return fmt.Sprintf("[%s]", strings.Join(strArr, ", ")) +} + +// calculateFNVHash64 computes a hash using the non-cryptographic FNV hash algorithm. +func (m *Maglev) calculateFNVHash64(key string) uint64 { + h := fnv.New64a() + _, _ = h.Write([]byte(key)) + return h.Sum64() +} diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev_test.go b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go new file mode 100644 index 000000000..ae8af9d07 --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go @@ -0,0 +1,257 @@ +package route_test + +import ( + "fmt" + "strconv" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "code.cloudfoundry.org/gorouter/route" + "code.cloudfoundry.org/gorouter/test_util" +) + +var _ = Describe("Maglev", func() { + var ( + logger *test_util.TestLogger + maglev *route.Maglev + ) + + BeforeEach(func() { + logger = test_util.NewTestLogger("test") + + maglev = route.NewMaglev(logger.Logger) + }) + + Describe("NewMaglev", func() { + It("should create a new Maglev instance", func() { + Expect(maglev).NotTo(BeNil()) + }) + }) + + Describe("Add", func() { + Context("when adding a new backend", func() { + It("should add the backend successfully", func() { + maglev.Add("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + result, err := maglev.Get("test-key") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("backend1")) + }) + }) + + Context("when adding a backend twice", func() { + It("should skip adding subsequent adds", func() { + maglev.Add("backend1") + maglev.Add("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + result, err := maglev.Get("test-key") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("backend1")) + }) + }) + + Context("when adding multiple backends", func() { + It("should make all backends reachable", func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Add("backend3") + + Expect(maglev.GetEndpointList()).To(HaveLen(3)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(len(maglev.GetEndpointList()))) + for i := range len(maglev.GetEndpointList()) { + Expect(maglev.GetPermutationTable()[i]).To(HaveLen(int(maglev.GetLookupTableSize()))) + } + + backends := make(map[string]bool) + for i := 0; i < 1000; i++ { + result, err := maglev.Get(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + backends[result] = true + } + + Expect(backends["backend1"]).To(BeTrue()) + Expect(backends["backend2"]).To(BeTrue()) + Expect(backends["backend3"]).To(BeTrue()) + }) + }) + }) + + Describe("Remove", func() { + Context("when removing an existing backend", func() { + It("should remove the backend successfully", func() { + maglev.Add("backend1") + maglev.Add("backend2") + + maglev.Remove("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + }) + }) + + Context("when removing a non-existent backend", func() { + It("should handle gracefully without error", func() { + maglev.Add("backend1") + + Expect(func() { maglev.Remove("non-existent") }).NotTo(Panic()) + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + }) + }) + }) + + Describe("Get", func() { + Context("when no backends were added", func() { + It("should return an 
error", func() { + _, err := maglev.Get("test-key") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when backends are added", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + }) + + It("should return consistent results for the same key", func() { + var counter = make(map[string]int) + var result1 string + var err error + for _ = range 100 { + result1, err = maglev.Get("consistent-key") + Expect(err).NotTo(HaveOccurred()) + counter[result1]++ + } + + Expect(counter[result1]).To(Equal(100)) + }) + + It("should distribute keys across backends", func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Add("backend3") + + distribution := make(map[string]int) + for i := range 1000 { + result, err := maglev.Get(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + distribution[result]++ + } + + Expect(distribution["backend1"]).To(BeNumerically(">", 0)) + Expect(distribution["backend2"]).To(BeNumerically(">", 0)) + Expect(distribution["backend3"]).To(BeNumerically(">", 0)) + }) + }) + + Context("when backends are removed", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Remove("backend1") + }) + + It("should not return the removed backend", func() { + for _ = range 100 { + endpoint, err := maglev.Get("consistent-key") + Expect(err).NotTo(HaveOccurred()) + Expect(endpoint).To(Equal("backend2")) + } + }) + }) + }) + + Describe("Consistency", func() { + // We test that at most half the keys are reassigned to new backends, when one backend is added. + // This ensures a minimal level of consistency. + It("should minimize disruption when adding backends", func() { + for i := range 10 { + maglev.Add(fmt.Sprintf("backend%d", i+1)) + } + keys := make([]string, 1000) + for i := range keys { + keys[i] = fmt.Sprintf("key%d", i+1) + } + + initialMappings := make(map[string]string) + + for _, key := range keys { + backend, err := maglev.Get(key) + Expect(err).NotTo(HaveOccurred()) + initialMappings[key] = backend + } + + maglev.Add("newbackend") + + changedMappings := 0 + for _, key := range keys { + backend, err := maglev.Get(key) + Expect(err).NotTo(HaveOccurred()) + if initialMappings[key] != backend { + changedMappings++ + } + } + + Expect(changedMappings).To(BeNumerically("<=", len(keys)/2)) + }) + }) + + Describe("Concurrency", func() { + It("should handle concurrent reads safely", func() { + maglev.Add("backend1") + + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + defer GinkgoRecover() + for j := 0; j < 100; j++ { + _, err := maglev.Get("test-key") + Expect(err).NotTo(HaveOccurred()) + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + Eventually(done).Should(Receive()) + } + }) + It("should handle concurrent endpoint registrations safely", func() { + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + defer GinkgoRecover() + for j := 0; j < 100; j++ { + Expect(func() { maglev.Add("endpoint" + strconv.Itoa(j)) }).NotTo(Panic()) + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + Eventually(done).Should(Receive()) + } + Expect(len(maglev.GetEndpointList())).To(Equal(100)) + }) + + }) +}) diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index f089fc15b..b9f491798 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -74,6 +74,21 @@ type ProxyRoundTripper interface { CancelRequest(*http.Request) } +type HashRoutingProperties 
struct { + Header string + BalanceFactor float64 +} + +func (hrp *HashRoutingProperties) Equal(hrp2 *HashRoutingProperties) bool { + if hrp == nil && hrp2 == nil { + return true + } + if hrp == nil || hrp2 == nil { + return false + } + return hrp.Header == hrp2.Header && hrp.BalanceFactor == hrp2.BalanceFactor +} + type Endpoint struct { ApplicationId string AvailabilityZone string @@ -186,6 +201,8 @@ type EndpointPool struct { logger *slog.Logger updatedAt time.Time LoadBalancingAlgorithm string + HashRoutingProperties *HashRoutingProperties + HashLookupTable *Maglev } type EndpointOpts struct { @@ -248,10 +265,12 @@ type PoolOpts struct { MaxConnsPerBackend int64 Logger *slog.Logger LoadBalancingAlgorithm string + HashHeader string + HashBalanceFactor float64 } func NewPool(opts *PoolOpts) *EndpointPool { - return &EndpointPool{ + pool := &EndpointPool{ endpoints: make([]*endpointElem, 0, 1), index: make(map[string]*endpointElem), retryAfterFailure: opts.RetryAfterFailure, @@ -264,6 +283,14 @@ func NewPool(opts *PoolOpts) *EndpointPool { updatedAt: time.Now(), LoadBalancingAlgorithm: opts.LoadBalancingAlgorithm, } + if pool.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + pool.HashLookupTable = NewMaglev(opts.Logger) + pool.HashRoutingProperties = &HashRoutingProperties{ + Header: opts.HashHeader, + BalanceFactor: opts.HashBalanceFactor, + } + } + return pool } func PoolsMatch(p1, p2 *EndpointPool) bool { @@ -320,7 +347,6 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { // new one. e.Lock() defer e.Unlock() - oldEndpoint := e.endpoint e.endpoint = endpoint @@ -336,6 +362,9 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { p.RouteSvcUrl = e.endpoint.RouteServiceUrl p.setPoolLoadBalancingAlgorithm(e.endpoint) e.updated = time.Now() + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Add(e.endpoint.PrivateInstanceId) + } p.Update() return EndpointUpdated @@ -348,7 +377,6 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { updated: time.Now(), maxConnsPerBackend: p.maxConnsPerBackend, } - p.endpoints = append(p.endpoints, e) p.index[endpoint.CanonicalAddr()] = e @@ -356,6 +384,9 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { p.RouteSvcUrl = e.endpoint.RouteServiceUrl p.setPoolLoadBalancingAlgorithm(e.endpoint) + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Add(e.endpoint.PrivateInstanceId) + } p.Update() return EndpointAdded @@ -433,6 +464,11 @@ func (p *EndpointPool) removeEndpoint(e *endpointElem) { delete(p.index, e.endpoint.CanonicalAddr()) delete(p.index, e.endpoint.PrivateInstanceId) p.Update() + + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Remove(e.endpoint.PrivateInstanceId) + } + } func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { @@ -443,6 +479,9 @@ func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeStic case config.LOAD_BALANCE_RR: logger.Debug("endpoint-iterator-with-round-robin-lb-algo") return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + case config.LOAD_BALANCE_HB: + logger.Debug("endpoint-iterator-with-hash-based-lb-algo") + return NewHashBased(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) default: logger.Error("invalid-pool-load-balancing-algorithm", slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), @@ -452,6 +491,23 @@ 
func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeStic } } +func (p *EndpointPool) FallBackToDefaultLoadBalancing(defaultLBAlgo string, logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { + logger.Info("hash-based-routing-header-not-found", + slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), + slog.String("Host", p.host), + slog.String("Path", p.contextPath)) + + switch defaultLBAlgo { + case config.LOAD_BALANCE_LC: + logger.Debug("endpoint-iterator-with-least-connection-lb-algo") + return NewLeastConnection(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + case config.LOAD_BALANCE_RR: + logger.Debug("endpoint-iterator-with-round-robin-lb-algo") + return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + } + return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) +} + func (p *EndpointPool) NumEndpoints() int { p.Lock() defer p.Unlock() @@ -561,12 +617,13 @@ func (p *EndpointPool) MarshalJSON() ([]byte, error) { // setPoolLoadBalancingAlgorithm overwrites the load balancing algorithm of a pool by that of a specified endpoint, if that is valid. func (p *EndpointPool) setPoolLoadBalancingAlgorithm(endpoint *Endpoint) { - if len(endpoint.LoadBalancingAlgorithm) > 0 && endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { + if endpoint.LoadBalancingAlgorithm != "" && endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { if config.IsLoadBalancingAlgorithmValid(endpoint.LoadBalancingAlgorithm) { p.LoadBalancingAlgorithm = endpoint.LoadBalancingAlgorithm p.logger.Debug("setting-pool-load-balancing-algorithm-to-that-of-an-endpoint", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm)) + p.prepareHashBasedRouting(endpoint) } else { p.logger.Error("invalid-endpoint-load-balancing-algorithm-provided-keeping-pool-lb-algo", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), @@ -575,6 +632,20 @@ func (p *EndpointPool) setPoolLoadBalancingAlgorithm(endpoint *Endpoint) { } } +func (p *EndpointPool) prepareHashBasedRouting(endpoint *Endpoint) { + if p.LoadBalancingAlgorithm != config.LOAD_BALANCE_HB { + return + } + if p.HashLookupTable == nil { + p.HashLookupTable = NewMaglev(p.logger) + } + p.HashRoutingProperties = &HashRoutingProperties{ + Header: endpoint.HashHeaderName, + BalanceFactor: endpoint.HashBalanceFactor, + } + +} + func (e *endpointElem) failed() { t := time.Now() e.failedAt = &t diff --git a/src/code.cloudfoundry.org/gorouter/route/pool_test.go b/src/code.cloudfoundry.org/gorouter/route/pool_test.go index 31da6c8d7..7709a1d8b 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool_test.go @@ -428,6 +428,46 @@ var _ = Describe("EndpointPool", func() { Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_RR)) }) }) + + Context("When switching to hash-based routing", func() { + It("will create the maglev table and add the endpoint", func() { + pool := route.NewPool(&route.PoolOpts{ + Logger: logger.Logger, + LoadBalancingAlgorithm: config.LOAD_BALANCE_RR, + }) + + endpointOpts := route.EndpointOpts{ + Host: "host-1", + Port: 1234, + RouteServiceUrl: "url", + LoadBalancingAlgorithm: config.LOAD_BALANCE_RR, + } + + initalEndpoint := route.NewEndpoint(&endpointOpts) + + pool.Put(initalEndpoint) + 
Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_RR)) + + endpointOptsHash := route.EndpointOpts{ + Host: "host-1", + Port: 1234, + RouteServiceUrl: "url", + LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + HashBalanceFactor: 1.25, + HashHeaderName: "X-Tenant", + } + + hashEndpoint := route.NewEndpoint(&endpointOptsHash) + + pool.Put(hashEndpoint) + Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_HB)) + Expect(pool.HashLookupTable).ToNot(BeNil()) + Expect(pool.HashLookupTable.GetEndpointList()).To(HaveLen(1)) + Expect(pool.HashLookupTable.GetEndpointList()[0]).To(Equal(hashEndpoint.PrivateInstanceId)) + }) + + }) + }) Context("RouteServiceUrl", func() { From 4c6982ff38c021352cb6f60d4c818873cb63595e Mon Sep 17 00:00:00 2001 From: Clemens Hoffmann Date: Wed, 29 Oct 2025 14:32:37 +0100 Subject: [PATCH 02/17] Add LICENSE information for maglev.go --- src/code.cloudfoundry.org/gorouter/route/maglev.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev.go b/src/code.cloudfoundry.org/gorouter/route/maglev.go index 9b70aaa09..6085c7d0f 100644 --- a/src/code.cloudfoundry.org/gorouter/route/maglev.go +++ b/src/code.cloudfoundry.org/gorouter/route/maglev.go @@ -1,5 +1,15 @@ package route +/****************************************************************************** + * Original github.com/kkdai/maglev/maglev.go + * + * Copyright (c) 2019 Evan Lin (github.com/kkdai) + * + * This program and the accompanying materials are made available under + * the terms of the Apache License, Version 2.0 which is available at + * http://www.apache.org/licenses/LICENSE-2.0. + ******************************************************************************/ + import ( "errors" "fmt" From 5981bcb5958b1914f4959745a8dbc79d11cccfb8 Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Wed, 22 Oct 2025 08:54:13 +0200 Subject: [PATCH 03/17] Implement overflow traffic --- .../gorouter/route/hash_based.go | 128 +++++++-- .../gorouter/route/hash_based_test.go | 261 +++++++++++++++++- .../gorouter/route/maglev.go | 58 +++- .../gorouter/route/maglev_test.go | 128 +++++++-- .../gorouter/route/pool.go | 17 +- 5 files changed, 525 insertions(+), 67 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go index 39d551eb8..ff98f072a 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -15,9 +15,10 @@ import ( type HashBased struct { lock *sync.Mutex - logger *slog.Logger - pool *EndpointPool - lastEndpoint *Endpoint + logger *slog.Logger + pool *EndpointPool + lastEndpoint *Endpoint + lastLookupTableIndex uint64 stickyEndpointID string mustBeSticky bool @@ -47,14 +48,14 @@ func (h *HashBased) Next(attempt int) *Endpoint { h.lock.Lock() defer h.lock.Unlock() - e := h.findEndpointIfStickySession() - if e == nil && h.mustBeSticky { + endpoint := h.findEndpointIfStickySession() + if endpoint == nil && h.mustBeSticky { return nil } - if e != nil { - h.lastEndpoint = e - return e + if endpoint != nil { + h.lastEndpoint = endpoint + return endpoint } if h.pool.HashLookupTable == nil { @@ -62,30 +63,92 @@ func (h *HashBased) Next(attempt int) *Endpoint { return nil } - id, err := h.pool.HashLookupTable.Get(h.HeaderValue) + if attempt == 0 || h.lastLookupTableIndex == 0 { + initialLookupTableIndex, _, err := h.pool.HashLookupTable.GetInstanceForHashHeader(h.HeaderValue) - if err != nil { - 
h.logger.Error( - "hash-based-routing-failed", - slog.String("host", h.pool.host), - log.ErrAttr(err), - ) - return nil + if err != nil { + h.logger.Error( + "hash-based-routing-failed", + slog.String("host", h.pool.host), + log.ErrAttr(err), + ) + return nil + } + + endpoint = h.findEndpoint(initialLookupTableIndex, attempt) + } else { + // On retries, start looking from the next index in the lookup table + nextIndex := (h.lastLookupTableIndex + 1) % h.pool.HashLookupTable.GetLookupTableSize() + endpoint = h.findEndpoint(nextIndex, attempt) } - h.logger.Debug( - "hash-based-routing", - slog.String("hash header value", h.HeaderValue), - slog.String("endpoint-id", id), - ) + if endpoint != nil { + h.lastEndpoint = endpoint + } + return endpoint +} - endpointElem := h.pool.findById(id) - if endpointElem == nil { - h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Endpoint not found in pool")), slog.String("endpoint-id", id)) +func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { + maxIterations := len(h.pool.endpoints) + if maxIterations == 0 { return nil } - return endpointElem.endpoint + // Ensure we don't exceed the lookup table size + lookupTableSize := h.pool.HashLookupTable.GetLookupTableSize() + + // Normalize index + currentIndex := index % lookupTableSize + // Keep track of endpoints already visited, to avoid visiting them twice + visitedEndpoints := make(map[string]bool) + + numberOfEndpoints := len(h.pool.HashLookupTable.GetEndpointList()) + + lastEndpointPrivateId := "" + if attempt > 0 && h.lastEndpoint != nil { + lastEndpointPrivateId = h.lastEndpoint.PrivateInstanceId + } + + // abort when we have visited all available endpoints unsuccessfully + for len(visitedEndpoints) < numberOfEndpoints { + id := h.pool.HashLookupTable.GetEndpointId(currentIndex) + + if visitedEndpoints[id] || id == lastEndpointPrivateId { + currentIndex = (currentIndex + 1) % lookupTableSize + continue + } + visitedEndpoints[id] = true + + endpointElem := h.pool.findById(id) + if endpointElem == nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Endpoint not found in pool")), slog.String("endpoint-id", id)) + currentIndex = (currentIndex + 1) % lookupTableSize + continue + } + + lastEndpointPrivateId = id + + e := endpointElem.endpoint + if h.pool.HashRoutingProperties.BalanceFactor <= 0 || !h.isOverloaded(e) { + h.lastLookupTableIndex = currentIndex + return e + } + + currentIndex = (currentIndex + 1) % lookupTableSize + } + // All endpoints checked and overloaded or not found + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("All endpoints are overloaded"))) + return nil +} + +func (h *HashBased) isOverloaded(e *Endpoint) bool { + avgLoad := h.CalculateAverageLoad() + balanceFactor := h.pool.HashRoutingProperties.BalanceFactor + if float64(e.Stats.NumberConnections.Count())/avgLoad > balanceFactor { + h.logger.Info("hash-based-routing-endpoint-overloaded", slog.String("host", h.pool.host), slog.String("endpoint-id", e.PrivateInstanceId), slog.Int64("endpoint-connections", e.Stats.NumberConnections.Count()), slog.Float64("average-load", avgLoad)) + return true + } + return false } // findEndpointIfStickySession checks if there is a sticky session endpoint and returns it if available. 
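A small numeric sketch of the overload check introduced above (the connection counts are invented; `1.25` matches the `hash_balance` value from the docs example): with three endpoints carrying 6, 1 and 2 in-flight requests, the average load is 3, so any endpoint above 3 × 1.25 = 3.75 concurrent requests is skipped and the next entry in the Maglev table is tried; a balance factor ≤ 0 disables the check entirely.

```go
package main

import "fmt"

// overloaded mirrors HashBased.isOverloaded: an endpoint is skipped when its
// in-flight request count exceeds balanceFactor times the pool-wide average.
func overloaded(connections int64, avgLoad, balanceFactor float64) bool {
	return float64(connections)/avgLoad > balanceFactor
}

func main() {
	inFlight := []int64{6, 1, 2} // per-endpoint in-flight requests (example values)
	var total int64
	for _, c := range inFlight {
		total += c
	}
	avg := float64(total) / float64(len(inFlight)) // 3.0

	for i, c := range inFlight {
		fmt.Printf("endpoint %d: %d connections, overloaded=%v\n",
			i, c, overloaded(c, avg, 1.25))
	}
	// Prints:
	// endpoint 0: 6 connections, overloaded=true
	// endpoint 1: 1 connections, overloaded=false
	// endpoint 2: 2 connections, overloaded=false
}
```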
@@ -139,3 +202,18 @@ func (h *HashBased) PreRequest(e *Endpoint) { func (h *HashBased) PostRequest(e *Endpoint) { e.Stats.NumberConnections.Decrement() } + +func (h *HashBased) CalculateAverageLoad() float64 { + if len(h.pool.endpoints) == 0 { + return 0 + } + + var currentInFlightRequestCount int64 + for _, endpointElem := range h.pool.endpoints { + endpointElem.RLock() + currentInFlightRequestCount += endpointElem.endpoint.Stats.NumberConnections.Count() + endpointElem.RUnlock() + } + + return float64(currentInFlightRequestCount) / float64(len(h.pool.endpoints)) +} diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go index 1caaed19c..26df2c5b4 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go @@ -1,10 +1,12 @@ package route_test import ( - "code.cloudfoundry.org/gorouter/config" _ "errors" + "hash/fnv" "time" + "code.cloudfoundry.org/gorouter/config" + "code.cloudfoundry.org/gorouter/route" "code.cloudfoundry.org/gorouter/test_util" . "github.com/onsi/ginkgo/v2" @@ -24,8 +26,9 @@ var _ = Describe("HashBased", func() { RetryAfterFailure: 2 * time.Minute, Host: "", ContextPath: "", - MaxConnsPerBackend: 0, + MaxConnsPerBackend: 500, LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + HashHeader: "tenant-id", }) }) @@ -60,14 +63,110 @@ var _ = Describe("HashBased", func() { Expect(second).NotTo(BeNil()) Expect(first).To(Equal(second)) }) + }) + + Context("when endpoint overloaded", func() { + var ( + endpoints []*route.Endpoint + e1 *route.Endpoint + e2 *route.Endpoint + e3 *route.Endpoint + ) + It("It returns the next endpoint for the same header value when balancer factor set", func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "tenant-1" + first := iter.Next(0) + Expect(iter.Next(0)).To(Equal(first)) + for i := 0; i < 6; i++ { + iter.PreRequest(first) + } + second := iter.Next(0) + Expect(second).NotTo(Equal(first)) + }) + It("It returns the same overloaded endpoint for the same header value when balancer factor not set", func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 0, PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 0, PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 0, PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + iter := route.NewHashBased(logger.Logger, pool, 
"", false, false, "") + iter.(*route.HashBased).HeaderValue = "tenant-1" + first := iter.Next(0) + Expect(iter.Next(0)).To(Equal(first)) + for i := 0; i < 6; i++ { + iter.PreRequest(first) + } + second := iter.Next(0) + Expect(second).To(Equal(first)) + }) + + }) + + Context("with retries", func() { + var ( + endpoints []*route.Endpoint + e1 *route.Endpoint + e2 *route.Endpoint + e3 *route.Endpoint + e4 *route.Endpoint + MaglevLookupTable = []int{2, 2, 1, 0, 1, 0, 0, 0, 2, 0, 1, 3, 1, 0, 1, 0, 3, 0, 3, 0, 0, 0, 1, 0, 1, 2, 2, 0, 3, 2, 3, 0, 1, 0, 1, 0, 3, 3, 2, 0, 3, 1, 2, 0, 3, 0, 1, 0, 2, 3, 2, 3, 2, 0, 1, 2, 1, 0, 3, 2, 2, 1, 1, 2, 1, 3, 1, 2, 2, 0, 3, 2, 3, 1, 1, 3, 1, 3, 1, 0, 2, 1, 3, 1, 2, 2, 1, 3, 2, 2, 2, 3, 3, 1, 3, 0, 3, 2, 3, 3, 0} + ) + It("It returns next endpoint from maglev lookup table", func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID3"}) + e4 = route.NewEndpoint(&route.EndpointOpts{Host: "4.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID4"}) + + endpoints = []*route.Endpoint{e1, e2, e3, e4} + endpointIDList := make([]string, 0, 4) + for _, e := range endpoints { + pool.Put(e) + endpointIDList = append(endpointIDList, e.PrivateInstanceId) + } + maglevMock := NewMockHashLookupTable(MaglevLookupTable, endpointIDList) + pool.HashLookupTable = maglevMock + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "tenant-1" + // The returned endpoint has always ID3 according to the Maglev lookup table + first := iter.Next(0) + Expect(first).To(Equal(e4)) + second := iter.Next(1) + Expect(second).To(Equal(e1)) + third := iter.Next(2) + Expect(third).To(Equal(e4)) + }) + It("It returns the next not overloaded endpoint for the second attempt", func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID3"}) + e4 = route.NewEndpoint(&route.EndpointOpts{Host: "4.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID3"}) - It("It selects another instance for other hash header value", func() { + endpoints = []*route.Endpoint{e1, e2, e3, e4} + for _, e := range endpoints { + pool.Put(e) + } iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") - iter.(*route.HashBased).HeaderValue = "example.com" - Expect(iter.Next(0)).NotTo(BeNil()) - Expect(iter.Next(0)).To(Equal(endpoints[1])) - Expect(iter.Next(0)).To(Equal(endpoints[1])) - Expect(iter.Next(0)).To(Equal(endpoints[1])) + iter.(*route.HashBased).HeaderValue = "tenant-1" + firstAttemptResult := iter.Next(0) + 
Expect(iter.Next(0)).To(Equal(firstAttemptResult)) + for i := 0; i < 6; i++ { + // Simulate requests to overload the endpoints + iter.PreRequest(e1) + iter.PreRequest(e2) + } + secondAttemptResult := iter.Next(1) + Expect(secondAttemptResult).NotTo(Equal(firstAttemptResult)) + Expect(secondAttemptResult).NotTo(Equal(e1)) + Expect(secondAttemptResult).NotTo(Equal(e2)) }) }) @@ -102,6 +201,13 @@ var _ = Describe("HashBased", func() { iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", true, false, "") Expect(iter.Next(0)).To(BeNil()) }) + It("returns nil when sticky endpoint is overloaded and mustBeSticky is true", func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", true, false, "") + for i := 0; i < 1000; i++ { + iter.PreRequest(endpoints[0]) + } + Expect(iter.Next(0)).To(BeNil()) + }) }) Context("when mustBeSticky is false", func() { @@ -151,5 +257,144 @@ var _ = Describe("HashBased", func() { Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount - 1)) }) }) + Describe("CalculateAverageLoad", func() { + var iter *route.HashBased + var endpoints []*route.Endpoint + + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "", false, false, "").(*route.HashBased) + }) + + Context("when there are no endpoints", func() { + It("returns 0", func() { + Expect(iter.CalculateAverageLoad()).To(Equal(float64(0))) + }) + }) + + Context("when all endpoints have zero connections", func() { + BeforeEach(func() { + pool.Put(route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"})) + pool.Put(route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"})) + }) + It("returns 0", func() { + Expect(iter.CalculateAverageLoad()).To(Equal(float64(0))) + }) + }) + + Context("when endpoints have varying connection counts", func() { + var e1, e2, e3 *route.Endpoint + BeforeEach(func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + for i := 0; i < 2; i++ { + iter.PreRequest(e1) + } + for i := 0; i < 4; i++ { + iter.PreRequest(e2) + } + for i := 0; i < 6; i++ { + iter.PreRequest(e3) + } + }) + It("returns the correct average", func() { + // in general 12 in flight requests + Expect(iter.CalculateAverageLoad()).To(Equal(float64(4))) + }) + }) + + Context("when one endpoint has many connections", func() { + var e1, e2 *route.Endpoint + BeforeEach(func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"}) + endpoints = []*route.Endpoint{e1, e2} + for _, e := range endpoints { + pool.Put(e) + } + for i := 0; i < 10; i++ { + iter.PreRequest(e1) + } + }) + It("returns the correct average", func() { + Expect(iter.CalculateAverageLoad()).To(Equal(float64(5))) + }) + }) + }) }) + +// MockHashLookupTable provides a simple mock implementation of MaglevLookup interface for testing. 
+type MockHashLookupTable struct { + lookupTable []int + endpointList []string +} + +// NewMockHashLookupTable creates a new mock lookup table with predefined mappings +func NewMockHashLookupTable(lookupTable []int, endpointList []string) *MockHashLookupTable { + + return &MockHashLookupTable{ + lookupTable: lookupTable, + endpointList: endpointList, + } +} + +func (m *MockHashLookupTable) GetInstanceForHashHeader(hashHeaderValue string) (uint64, string, error) { + if len(m.endpointList) == 0 { + return 0, "", nil + } + h := fnv.New64a() + _, _ = h.Write([]byte(hashHeaderValue)) + key := h.Sum64() + index := key % m.GetLookupTableSize() + return index, m.endpointList[m.lookupTable[index]], nil + +} + +func (m *MockHashLookupTable) GetLookupTableSize() uint64 { + return uint64(len(m.lookupTable)) +} + +func (m *MockHashLookupTable) GetEndpointId(lookupTableIndex uint64) string { + return m.endpointList[m.lookupTable[lookupTableIndex]] +} + +func (m *MockHashLookupTable) Add(endpoint string) { + // Check if endpoint already exists + for _, existing := range m.endpointList { + if existing == endpoint { + return + } + } + m.endpointList = append(m.endpointList, endpoint) +} + +func (m *MockHashLookupTable) Remove(endpoint string) { + for i, existing := range m.endpointList { + if existing == endpoint { + m.endpointList = append(m.endpointList[:i], m.endpointList[i+1:]...) + return + } + } +} + +func (m *MockHashLookupTable) GetEndpointList() []string { + return append([]string(nil), m.endpointList...) // return a copy +} + +// GetLookupTable returns a copy of the current lookup table (for testing) +func (m *MockHashLookupTable) GetLookupTable() []int { + return m.lookupTable // return a copy +} + +// GetPermutationTable returns a copy of the current permutation table (for testing) +func (m *MockHashLookupTable) GetPermutationTable() [][]uint64 { + return nil // not implemented in mock +} + +// Compile-time check to ensure MockHashLookupTable implements MaglevLookup interface +var _ route.MaglevLookup = (*MockHashLookupTable)(nil) diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev.go b/src/code.cloudfoundry.org/gorouter/route/maglev.go index 6085c7d0f..23d46210f 100644 --- a/src/code.cloudfoundry.org/gorouter/route/maglev.go +++ b/src/code.cloudfoundry.org/gorouter/route/maglev.go @@ -27,6 +27,35 @@ const ( lookupTableSize uint64 = 1801 ) +// MaglevLookup defines the interface for consistent hashing lookup table implementations. +// This interface allows for different implementations of the Maglev algorithm and +// enables easy testing with mock implementations. 
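+// The lookup table maps the hash of a request header value to a backend instance ID, so requests that carry the same header value are consistently routed to the same endpoint.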
+type MaglevLookup interface { + // Add a new endpoint to the lookup table + Add(endpoint string) + + // Remove an endpoint from the lookup table + Remove(endpoint string) + + // GetInstanceForHashHeader endpoint by specified request header value + GetInstanceForHashHeader(hashHeaderValue string) (uint64, string, error) + + // GetEndpointId returns the endpoint ID by specified lookup table index + GetEndpointId(lookupTableIndex uint64) string + + // GetLookupTableSize returns the size of the lookup table + GetLookupTableSize() uint64 + + // GetEndpointList returns a copy of the current endpoint list (for testing) + GetEndpointList() []string + + // GetLookupTable returns a copy of the current lookup table (for testing) + GetLookupTable() []int + + // GetPermutationTable returns a copy of the current permutation table (for testing) + GetPermutationTable() [][]uint64 +} + // Maglev implementation of consistent hashing algorithm described in "Maglev: A Fast and Reliable Software Network // Load Balancer" (https://storage.googleapis.com/gweb-research2023-media/pubtools/2904.pdf) type Maglev struct { @@ -89,23 +118,29 @@ func (m *Maglev) Remove(endpoint string) { m.fillLookupTable() } -// Get endpoint by specified request header value -// Todo: Overload scenario: Get should return an index rather than an instance, -// so that we can iterate to the next endpoint in case it is overloaded (e.g. via another -// helper function that resolves the endpoint via the index) -func (m *Maglev) Get(headerValue string) (string, error) { +func (m *Maglev) hashKey(headerValue string) uint64 { + return m.calculateFNVHash64(headerValue) +} + +// GetInstanceForHashHeader lookup table index and private instance ID for the specified request header value +func (m *Maglev) GetInstanceForHashHeader(hashHeaderValue string) (uint64, string, error) { m.lock.RLock() defer m.lock.RUnlock() if len(m.endpointList) == 0 { - return "", errors.New("maglev-get-endpoint-no-endpoints") + return 0, "", errors.New("no endpoint available") } - key := m.hashKey(headerValue) - return m.endpointList[m.lookupTable[key%lookupTableSize]], nil + key := m.hashKey(hashHeaderValue) + index := key % lookupTableSize + return index, m.endpointList[m.lookupTable[key%lookupTableSize]], nil } -func (m *Maglev) hashKey(headerValue string) uint64 { - return m.calculateFNVHash64(headerValue) +// GetEndpointId by specified lookup table index +func (m *Maglev) GetEndpointId(lookupTableIndex uint64) string { + m.lock.RLock() + defer m.lock.RUnlock() + + return m.endpointList[m.lookupTable[lookupTableIndex]] } // generatePermutation creates a permutationTable of the lookup table for each endpoint @@ -220,3 +255,6 @@ func (m *Maglev) calculateFNVHash64(key string) uint64 { _, _ = h.Write([]byte(key)) return h.Sum64() } + +// Compile-time check to ensure Maglev implements MaglevLookup interface +var _ MaglevLookup = (*Maglev)(nil) diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev_test.go b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go index ae8af9d07..b72d13b3d 100644 --- a/src/code.cloudfoundry.org/gorouter/route/maglev_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go @@ -39,9 +39,9 @@ var _ = Describe("Maglev", func() { Expect(maglev.GetPermutationTable()).To(HaveLen(1)) Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) - result, err := maglev.Get("test-key") + _, backend, err := maglev.GetInstanceForHashHeader("test-key") Expect(err).NotTo(HaveOccurred()) - 
Expect(result).To(Equal("backend1")) + Expect(backend).To(Equal("backend1")) }) }) @@ -55,9 +55,9 @@ var _ = Describe("Maglev", func() { Expect(maglev.GetPermutationTable()).To(HaveLen(1)) Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) - result, err := maglev.Get("test-key") + _, backend, err := maglev.GetInstanceForHashHeader("test-key") Expect(err).NotTo(HaveOccurred()) - Expect(result).To(Equal("backend1")) + Expect(backend).To(Equal("backend1")) }) }) @@ -76,9 +76,9 @@ var _ = Describe("Maglev", func() { backends := make(map[string]bool) for i := 0; i < 1000; i++ { - result, err := maglev.Get(string(rune(i))) + _, backend, err := maglev.GetInstanceForHashHeader(string(rune(i))) Expect(err).NotTo(HaveOccurred()) - backends[result] = true + backends[backend] = true } Expect(backends["backend1"]).To(BeTrue()) @@ -121,7 +121,7 @@ var _ = Describe("Maglev", func() { Describe("Get", func() { Context("when no backends were added", func() { It("should return an error", func() { - _, err := maglev.Get("test-key") + _, _, err := maglev.GetInstanceForHashHeader("test-key") Expect(err).To(HaveOccurred()) }) }) @@ -134,15 +134,15 @@ var _ = Describe("Maglev", func() { It("should return consistent results for the same key", func() { var counter = make(map[string]int) - var result1 string + var result string var err error - for _ = range 100 { - result1, err = maglev.Get("consistent-key") + for range 100 { + _, result, err = maglev.GetInstanceForHashHeader("consistent-key") Expect(err).NotTo(HaveOccurred()) - counter[result1]++ + counter[result]++ } - Expect(counter[result1]).To(Equal(100)) + Expect(counter[result]).To(Equal(100)) }) It("should distribute keys across backends", func() { @@ -152,9 +152,9 @@ var _ = Describe("Maglev", func() { distribution := make(map[string]int) for i := range 1000 { - result, err := maglev.Get(string(rune(i))) + _, backend, err := maglev.GetInstanceForHashHeader(string(rune(i))) Expect(err).NotTo(HaveOccurred()) - distribution[result]++ + distribution[backend]++ } Expect(distribution["backend1"]).To(BeNumerically(">", 0)) @@ -171,10 +171,98 @@ var _ = Describe("Maglev", func() { }) It("should not return the removed backend", func() { - for _ = range 100 { - endpoint, err := maglev.Get("consistent-key") + for range 100 { + _, backend, err := maglev.GetInstanceForHashHeader("consistent-key") Expect(err).NotTo(HaveOccurred()) - Expect(endpoint).To(Equal("backend2")) + Expect(backend).To(Equal("backend2")) + } + }) + }) + }) + + Describe("GetInstanceForHashHeader", func() { + Context("when no backends were added", func() { + It("should return an error", func() { + _, _, err := maglev.GetInstanceForHashHeader("test-key") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when backends are added", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + }) + + It("should return consistent results for the same key", func() { + var counter = make(map[uint64]int) + var lookupTableIndex uint64 + var err error + for range 100 { + lookupTableIndex, _, err = maglev.GetInstanceForHashHeader("consistent-key") + Expect(err).NotTo(HaveOccurred()) + counter[lookupTableIndex]++ + } + + Expect(counter[lookupTableIndex]).To(Equal(100)) + }) + }) + }) + + Describe("GetEndpointId", func() { + Context("when backends are added", func() { + BeforeEach(func() { + maglev.Add("app_instance_1") + maglev.Add("app_instance_2") + }) + + It("should return consistent results for the same key", func() { + var counter = 
make(map[string]int) + var endpointID string + for range 100 { + lookupTableIndex, _, err := maglev.GetInstanceForHashHeader("consistent-key") + Expect(err).NotTo(HaveOccurred()) + endpointID = maglev.GetEndpointId(lookupTableIndex) + Expect(err).NotTo(HaveOccurred()) + counter[endpointID]++ + } + + Expect(counter[endpointID]).To(Equal(100)) + }) + + It("should distribute keys across backends", func() { + maglev.Add("app_instance_1") + maglev.Add("app_instance_2") + maglev.Add("app_instance_3") + + distribution := make(map[string]int) + for i := range 1000 { + lookupTableIndex, _, err := maglev.GetInstanceForHashHeader(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + endpointID := maglev.GetEndpointId(lookupTableIndex) + Expect(err).NotTo(HaveOccurred()) + distribution[endpointID]++ + } + + Expect(distribution["app_instance_1"]).To(BeNumerically(">", 0)) + Expect(distribution["app_instance_2"]).To(BeNumerically(">", 0)) + Expect(distribution["app_instance_3"]).To(BeNumerically(">", 0)) + }) + }) + + Context("when backends are removed", func() { + BeforeEach(func() { + maglev.Add("app_instance_1") + maglev.Add("app_instance_2") + maglev.Remove("app_instance_1") + }) + + It("should not return the removed backend", func() { + for i := range 1000 { + lookupTableIndex, _, err := maglev.GetInstanceForHashHeader(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + endpointID := maglev.GetEndpointId(lookupTableIndex) + Expect(endpointID).To(Equal("app_instance_2")) } }) }) @@ -195,7 +283,7 @@ var _ = Describe("Maglev", func() { initialMappings := make(map[string]string) for _, key := range keys { - backend, err := maglev.Get(key) + _, backend, err := maglev.GetInstanceForHashHeader(key) Expect(err).NotTo(HaveOccurred()) initialMappings[key] = backend } @@ -204,7 +292,7 @@ var _ = Describe("Maglev", func() { changedMappings := 0 for _, key := range keys { - backend, err := maglev.Get(key) + _, backend, err := maglev.GetInstanceForHashHeader(key) Expect(err).NotTo(HaveOccurred()) if initialMappings[key] != backend { changedMappings++ @@ -224,7 +312,7 @@ var _ = Describe("Maglev", func() { go func() { defer GinkgoRecover() for j := 0; j < 100; j++ { - _, err := maglev.Get("test-key") + _, _, err := maglev.GetInstanceForHashHeader("test-key") Expect(err).NotTo(HaveOccurred()) } done <- true diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index b9f491798..8217aefca 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -202,7 +202,7 @@ type EndpointPool struct { updatedAt time.Time LoadBalancingAlgorithm string HashRoutingProperties *HashRoutingProperties - HashLookupTable *Maglev + HashLookupTable MaglevLookup } type EndpointOpts struct { @@ -617,19 +617,24 @@ func (p *EndpointPool) MarshalJSON() ([]byte, error) { // setPoolLoadBalancingAlgorithm overwrites the load balancing algorithm of a pool by that of a specified endpoint, if that is valid. 
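+// Invalid algorithms are logged and ignored, so the pool keeps its current algorithm; hash-based routing properties are (re)initialized afterwards via prepareHashBasedRouting.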
func (p *EndpointPool) setPoolLoadBalancingAlgorithm(endpoint *Endpoint) { - if endpoint.LoadBalancingAlgorithm != "" && endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { + if endpoint.LoadBalancingAlgorithm == "" { + return + } + + if endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { if config.IsLoadBalancingAlgorithmValid(endpoint.LoadBalancingAlgorithm) { p.LoadBalancingAlgorithm = endpoint.LoadBalancingAlgorithm p.logger.Debug("setting-pool-load-balancing-algorithm-to-that-of-an-endpoint", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm)) - p.prepareHashBasedRouting(endpoint) + } else { p.logger.Error("invalid-endpoint-load-balancing-algorithm-provided-keeping-pool-lb-algo", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm)) } } + p.prepareHashBasedRouting(endpoint) } func (p *EndpointPool) prepareHashBasedRouting(endpoint *Endpoint) { @@ -639,11 +644,15 @@ func (p *EndpointPool) prepareHashBasedRouting(endpoint *Endpoint) { if p.HashLookupTable == nil { p.HashLookupTable = NewMaglev(p.logger) } - p.HashRoutingProperties = &HashRoutingProperties{ + + newProps := &HashRoutingProperties{ Header: endpoint.HashHeaderName, BalanceFactor: endpoint.HashBalanceFactor, } + if p.HashRoutingProperties == nil || !p.HashRoutingProperties.Equal(newProps) { + p.HashRoutingProperties = newProps + } } func (e *endpointElem) failed() { From 25d9395966f7162e0ce71b8a7433ff11e0d23b1b Mon Sep 17 00:00:00 2001 From: Clemens Hoffmann Date: Wed, 12 Nov 2025 16:14:35 +0100 Subject: [PATCH 04/17] * Minor improvements and refactoring --- .../gorouter/route/hash_based.go | 55 ++++++++++--- .../gorouter/route/hash_based_test.go | 81 ++++++++++++++++++- 2 files changed, 121 insertions(+), 15 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go index ff98f072a..db7aaac1a 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -58,6 +58,17 @@ func (h *HashBased) Next(attempt int) *Endpoint { return endpoint } + if len(h.pool.endpoints) == 0 { + h.logger.Warn("hash-based-routing-pool-empty", slog.String("host", h.pool.host)) + return nil + } + + endpoint = h.getSingleEndpoint() + if endpoint != nil { + h.lastEndpoint = endpoint + return endpoint + } + if h.pool.HashLookupTable == nil { h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Lookup table is empty"))) return nil @@ -89,11 +100,6 @@ func (h *HashBased) Next(attempt int) *Endpoint { } func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { - maxIterations := len(h.pool.endpoints) - if maxIterations == 0 { - return nil - } - // Ensure we don't exceed the lookup table size lookupTableSize := h.pool.HashLookupTable.GetLookupTableSize() @@ -128,10 +134,9 @@ func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { lastEndpointPrivateId = id - e := endpointElem.endpoint - if h.pool.HashRoutingProperties.BalanceFactor <= 0 || !h.isOverloaded(e) { + if h.pool.HashRoutingProperties.BalanceFactor <= 0 || !h.isImbalancedOrOverloaded(endpointElem) { h.lastLookupTableIndex = currentIndex - return e + return endpointElem.endpoint } currentIndex = (currentIndex + 1) % lookupTableSize @@ -141,11 +146,24 @@ func (h *HashBased) findEndpoint(index uint64, attempt int) 
*Endpoint { return nil } -func (h *HashBased) isOverloaded(e *Endpoint) bool { - avgLoad := h.CalculateAverageLoad() +func (h *HashBased) isImbalancedOrOverloaded(e *endpointElem) bool { + endpoint := e.endpoint + return h.IsImbalancedOrOverloaded(endpoint, e.isOverloaded()) +} + +func (h *HashBased) IsImbalancedOrOverloaded(endpoint *Endpoint, isEndpointOverloaded bool) bool { + avgNumberOfInFlightRequests := h.CalculateAverageLoad() + currentInFlightRequestCount := endpoint.Stats.NumberConnections.Count() balanceFactor := h.pool.HashRoutingProperties.BalanceFactor - if float64(e.Stats.NumberConnections.Count())/avgLoad > balanceFactor { - h.logger.Info("hash-based-routing-endpoint-overloaded", slog.String("host", h.pool.host), slog.String("endpoint-id", e.PrivateInstanceId), slog.Int64("endpoint-connections", e.Stats.NumberConnections.Count()), slog.Float64("average-load", avgLoad)) + + if isEndpointOverloaded { + h.logger.Debug("hash-based-routing-endpoint-overloaded", slog.String("host", h.pool.host), slog.String("endpoint-id", endpoint.PrivateInstanceId), slog.Int64("endpoint-connections", currentInFlightRequestCount)) + return true + } + + // Check if avgNumberOfInFlightRequests is 0 to avoid division by 0 + if avgNumberOfInFlightRequests == 0 || float64(currentInFlightRequestCount)/avgNumberOfInFlightRequests > balanceFactor { + h.logger.Debug("hash-based-routing-endpoint-imbalanced", slog.String("host", h.pool.host), slog.String("endpoint-id", endpoint.PrivateInstanceId), slog.Int64("endpoint-connections", endpoint.Stats.NumberConnections.Count()), slog.Float64("average-load", avgNumberOfInFlightRequests)) return true } return false @@ -203,6 +221,7 @@ func (h *HashBased) PostRequest(e *Endpoint) { e.Stats.NumberConnections.Decrement() } +// CalculateAverageLoad computes the average number of in-flight requests across all endpoints in the pool. func (h *HashBased) CalculateAverageLoad() float64 { if len(h.pool.endpoints) == 0 { return 0 @@ -217,3 +236,15 @@ func (h *HashBased) CalculateAverageLoad() float64 { return float64(currentInFlightRequestCount) / float64(len(h.pool.endpoints)) } + +func (h *HashBased) getSingleEndpoint() *Endpoint { + if len(h.pool.endpoints) == 1 { + e := h.pool.endpoints[0] + if e.isOverloaded() { + return nil + } + + return e.endpoint + } + return nil +} diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go index 26df2c5b4..c7fcd79a3 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go @@ -6,11 +6,11 @@ import ( "time" "code.cloudfoundry.org/gorouter/config" - "code.cloudfoundry.org/gorouter/route" "code.cloudfoundry.org/gorouter/test_util" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" ) var _ = Describe("HashBased", func() { @@ -257,7 +257,83 @@ var _ = Describe("HashBased", func() { Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount - 1)) }) }) - Describe("CalculateAverageLoad", func() { + Describe("IsImbalancedOrOverloaded", func() { + var iter *route.HashBased + var endpoints []*route.Endpoint + + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "", false, false, "").(*route.HashBased) + }) + + Context("when endpoints have a lot of in-flight requests", func() { + var e1, e2, e3 *route.Endpoint + BeforeEach(func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + + }) + It("mark the endpoint as overloaded", func() { + for i := 0; i < 500; i++ { + iter.PreRequest(e1) + } + // in general 500 in flight requests counted by e1 + Expect(iter.IsImbalancedOrOverloaded(e1, true)).To(BeTrue()) + }) + It("do not mark as imbalanced if every endpoint has 499 in-flight requests", func() { + for i := 0; i < 498; i++ { + iter.PreRequest(e1) + } + for i := 0; i < 498; i++ { + iter.PreRequest(e2) + } + for i := 0; i < 498; i++ { + iter.PreRequest(e3) + } + // in general 500 in flight requests counted by e1 + Expect(iter.IsImbalancedOrOverloaded(e1, false)).To(BeFalse()) + }) + + It("mark endpoint as overloaded if every endpoint has 500 in-flight requests", func() { + for i := 0; i < 499; i++ { + iter.PreRequest(e1) + } + for i := 0; i < 499; i++ { + iter.PreRequest(e2) + } + for i := 0; i < 499; i++ { + iter.PreRequest(e3) + } + // in general 500 in flight requests counted by e1 + Expect(iter.IsImbalancedOrOverloaded(e1, true)).To(BeTrue()) + Eventually(logger).Should(gbytes.Say("hash-based-routing-endpoint-overloaded")) + Expect(iter.IsImbalancedOrOverloaded(e2, true)).To(BeTrue()) + Expect(iter.IsImbalancedOrOverloaded(e3, true)).To(BeTrue()) + + }) + It("mark as imbalanced if it has more in-flight requests", func() { + for i := 0; i < 300; i++ { + iter.PreRequest(e1) + } + for i := 0; i < 200; i++ { + iter.PreRequest(e2) + } + for i := 0; i < 200; i++ { + iter.PreRequest(e3) + } + Expect(iter.IsImbalancedOrOverloaded(e1, false)).To(BeTrue()) + Eventually(logger).Should(gbytes.Say("hash-based-routing-endpoint-imbalanced")) + Expect(iter.IsImbalancedOrOverloaded(e2, false)).To(BeFalse()) + Expect(iter.IsImbalancedOrOverloaded(e3, false)).To(BeFalse()) + }) + }) + }) + + Describe("CalculateAverageNumberOfConnections", func() { var iter *route.HashBased var endpoints []*route.Endpoint @@ -336,7 +412,6 @@ type MockHashLookupTable struct { // NewMockHashLookupTable creates a new mock lookup table with predefined mappings func NewMockHashLookupTable(lookupTable []int, endpointList []string) *MockHashLookupTable { - return &MockHashLookupTable{ lookupTable: lookupTable, endpointList: endpointList, From bda2dbf750f62872d60200712b9e91a536284724 Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Fri, 28 
Nov 2025 13:19:15 +0100 Subject: [PATCH 05/17] Refactor pool.Endpoints to apply review feedback --- .../gorouter/handlers/helpers.go | 4 +- .../gorouter/handlers/max_request_size.go | 2 +- .../round_tripper/proxy_round_tripper.go | 16 +------- .../round_tripper/proxy_round_tripper_test.go | 4 +- .../gorouter/route/hash_based.go | 3 +- .../gorouter/route/hash_based_test.go | 41 ++++++++----------- .../gorouter/route/maglev.go | 25 +++++++---- .../gorouter/route/pool.go | 10 ++++- .../gorouter/route/pool_test.go | 12 +++--- 9 files changed, 56 insertions(+), 61 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/handlers/helpers.go b/src/code.cloudfoundry.org/gorouter/handlers/helpers.go index f1f048819..1eeebf648 100644 --- a/src/code.cloudfoundry.org/gorouter/handlers/helpers.go +++ b/src/code.cloudfoundry.org/gorouter/handlers/helpers.go @@ -63,13 +63,13 @@ func upgradeHeader(request *http.Request) string { return "" } -func EndpointIteratorForRequest(logger *slog.Logger, request *http.Request, stickySessionCookieNames config.StringSet, authNegotiateSticky bool, azPreference string, az string) (route.EndpointIterator, error) { +func EndpointIteratorForRequest(logger *slog.Logger, request *http.Request, stickySessionCookieNames config.StringSet, authNegotiateSticky bool, azPreference string, az string, globalLB string) (route.EndpointIterator, error) { reqInfo, err := ContextRequestInfo(request) if err != nil { return nil, fmt.Errorf("could not find reqInfo in context") } stickyEndpointID, mustBeSticky := GetStickySession(request, stickySessionCookieNames, authNegotiateSticky) - return reqInfo.RoutePool.Endpoints(logger, stickyEndpointID, mustBeSticky, azPreference, az), nil + return reqInfo.RoutePool.Endpoints(logger, stickyEndpointID, mustBeSticky, azPreference, az, globalLB, request), nil } func GetStickySession(request *http.Request, stickySessionCookieNames config.StringSet, authNegotiateSticky bool) (string, bool) { diff --git a/src/code.cloudfoundry.org/gorouter/handlers/max_request_size.go b/src/code.cloudfoundry.org/gorouter/handlers/max_request_size.go index d164b5e80..a4e44d63b 100644 --- a/src/code.cloudfoundry.org/gorouter/handlers/max_request_size.go +++ b/src/code.cloudfoundry.org/gorouter/handlers/max_request_size.go @@ -69,7 +69,7 @@ func (m *MaxRequestSize) ServeHTTP(rw http.ResponseWriter, r *http.Request, next if err != nil { logger.Error("request-info-err", log.ErrAttr(err)) } else { - endpointIterator, err := EndpointIteratorForRequest(logger, r, m.cfg.StickySessionCookieNames, m.cfg.StickySessionsForAuthNegotiate, m.cfg.LoadBalanceAZPreference, m.cfg.Zone) + endpointIterator, err := EndpointIteratorForRequest(logger, r, m.cfg.StickySessionCookieNames, m.cfg.StickySessionsForAuthNegotiate, m.cfg.LoadBalanceAZPreference, m.cfg.Zone, m.cfg.LoadBalance) if err != nil { logger.Error("failed-to-find-endpoint-for-req-during-431-short-circuit", log.ErrAttr(err)) } else { diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go index 84252262f..b17273ad1 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go @@ -126,21 +126,7 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response stickyEndpointID, mustBeSticky := handlers.GetStickySession(request, rt.config.StickySessionCookieNames, 
rt.config.StickySessionsForAuthNegotiate) numberOfEndpoints := reqInfo.RoutePool.NumEndpoints() - iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) - if reqInfo.RoutePool.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { - if reqInfo.RoutePool.HashRoutingProperties == nil { - rt.logger.Error("hash-routing-properties-nil", slog.String("host", reqInfo.RoutePool.Host())) - - } else { - headerName := reqInfo.RoutePool.HashRoutingProperties.Header - headerValue := request.Header.Get(headerName) - if headerValue != "" { - iter.(*route.HashBased).HeaderValue = headerValue - } else { - iter = reqInfo.RoutePool.FallBackToDefaultLoadBalancing(rt.config.LoadBalance, rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) - } - } - } + iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone, rt.config.LoadBalance, request) // The selectEndpointErr needs to be tracked separately. If we get an error // while selecting an endpoint we might just have run out of routes. In diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go index 6abe4d218..50176e0d2 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go @@ -275,7 +275,7 @@ var _ = Describe("ProxyRoundTripper", func() { res, err := proxyRoundTripper.RoundTrip(req) Expect(err).NotTo(HaveOccurred()) - iter := routePool.Endpoints(logger.Logger, "", false, AZPreference, AZ) + iter := routePool.Endpoints(logger.Logger, "", false, AZPreference, AZ, cfg.LoadBalance, req) ep1 := iter.Next(0) ep2 := iter.Next(1) Expect(ep1.PrivateInstanceId).To(Equal(ep2.PrivateInstanceId)) @@ -609,7 +609,7 @@ var _ = Describe("ProxyRoundTripper", func() { _, err := proxyRoundTripper.RoundTrip(req) Expect(err).To(MatchError(ContainSubstring("tls: handshake failure"))) - iter := routePool.Endpoints(logger.Logger, "", false, AZPreference, AZ) + iter := routePool.Endpoints(logger.Logger, "", false, AZPreference, AZ, cfg.LoadBalance, req) ep1 := iter.Next(0) ep2 := iter.Next(1) Expect(ep1).To(Equal(ep2)) diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go index db7aaac1a..4c6df58de 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -28,13 +28,14 @@ type HashBased struct { // NewHashBased initializes an endpoint iterator that selects endpoints based on a hash of a header value. // The global properties locallyOptimistic and localAvailabilityZone will be ignored when using Hash-Based Routing. 
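+// headerValue is the value of the configured hash header, taken from the incoming request by the pool before the iterator is created.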
-func NewHashBased(logger *slog.Logger, p *EndpointPool, initial string, mustBeSticky bool, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator { +func NewHashBased(logger *slog.Logger, p *EndpointPool, initial string, mustBeSticky bool, headerValue string) EndpointIterator { return &HashBased{ logger: logger, pool: p, lock: &sync.Mutex{}, stickyEndpointID: initial, mustBeSticky: mustBeSticky, + HeaderValue: headerValue, } } diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go index c7fcd79a3..09a286153 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go @@ -36,7 +36,7 @@ var _ = Describe("HashBased", func() { Context("when pool is empty", func() { It("does not select an endpoint", func() { - iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter := route.NewHashBased(logger.Logger, pool, "", false, "tenant-1") Expect(iter.Next(0)).To(BeNil()) }) }) @@ -55,8 +55,7 @@ var _ = Describe("HashBased", func() { }) It("It returns the same endpoint for the same header value", func() { - iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") - iter.(*route.HashBased).HeaderValue = "tenant-1" + iter := route.NewHashBased(logger.Logger, pool, "", false, "tenant-1") first := iter.Next(0) second := iter.Next(0) Expect(first).NotTo(BeNil()) @@ -80,8 +79,11 @@ var _ = Describe("HashBased", func() { for _, e := range endpoints { pool.Put(e) } - iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") - iter.(*route.HashBased).HeaderValue = "tenant-1" + iter := route.NewHashBased(logger.Logger, pool, "", false, "tenant-1") + // Simulate in-flight requests + for _, e := range endpoints { + iter.PreRequest(e) + } first := iter.Next(0) Expect(iter.Next(0)).To(Equal(first)) for i := 0; i < 6; i++ { @@ -98,7 +100,7 @@ var _ = Describe("HashBased", func() { for _, e := range endpoints { pool.Put(e) } - iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter := route.NewHashBased(logger.Logger, pool, "", false, "tenant-1") iter.(*route.HashBased).HeaderValue = "tenant-1" first := iter.Next(0) Expect(iter.Next(0)).To(Equal(first)) @@ -134,8 +136,7 @@ var _ = Describe("HashBased", func() { } maglevMock := NewMockHashLookupTable(MaglevLookupTable, endpointIDList) pool.HashLookupTable = maglevMock - iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") - iter.(*route.HashBased).HeaderValue = "tenant-1" + iter := route.NewHashBased(logger.Logger, pool, "", false, "tenant-1") // The returned endpoint has always ID3 according to the Maglev lookup table first := iter.Next(0) Expect(first).To(Equal(e4)) @@ -154,8 +155,7 @@ var _ = Describe("HashBased", func() { for _, e := range endpoints { pool.Put(e) } - iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") - iter.(*route.HashBased).HeaderValue = "tenant-1" + iter := route.NewHashBased(logger.Logger, pool, "", false, "tenant-1") firstAttemptResult := iter.Next(0) Expect(iter.Next(0)).To(Equal(firstAttemptResult)) for i := 0; i < 6; i++ { @@ -187,22 +187,19 @@ var _ = Describe("HashBased", func() { }) Context("when mustBeSticky is true", func() { - BeforeEach(func() { - iter = route.NewHashBased(logger.Logger, pool, "ID1", true, false, "") - }) - It("returns the sticky endpoint when it exists", func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", true, "abc") endpoint := 
iter.Next(0) Expect(endpoint).NotTo(BeNil()) Expect(endpoint.PrivateInstanceId).To(Equal("ID1")) }) It("returns nil when sticky endpoint doesn't exist", func() { - iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", true, false, "") + iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", true, "abc") Expect(iter.Next(0)).To(BeNil()) }) It("returns nil when sticky endpoint is overloaded and mustBeSticky is true", func() { - iter = route.NewHashBased(logger.Logger, pool, "ID1", true, false, "") + iter = route.NewHashBased(logger.Logger, pool, "ID1", true, "abc") for i := 0; i < 1000; i++ { iter.PreRequest(endpoints[0]) } @@ -212,7 +209,7 @@ var _ = Describe("HashBased", func() { Context("when mustBeSticky is false", func() { BeforeEach(func() { - iter = route.NewHashBased(logger.Logger, pool, "ID1", false, false, "") + iter = route.NewHashBased(logger.Logger, pool, "ID1", false, "some-value") }) It("returns the sticky endpoint when it exists", func() { @@ -222,9 +219,7 @@ var _ = Describe("HashBased", func() { }) It("falls back to hash-based routing when sticky endpoint doesn't exist", func() { - iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", false, false, "") - hashIter := iter.(*route.HashBased) - hashIter.HeaderValue = "some-value" + iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", false, "some-value") endpoint := iter.Next(0) Expect(endpoint).NotTo(BeNil()) }) @@ -241,7 +236,7 @@ var _ = Describe("HashBased", func() { BeforeEach(func() { endpoint = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) pool.Put(endpoint) - iter = route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter = route.NewHashBased(logger.Logger, pool, "", false, "abc") }) It("increments connection count on PreRequest", func() { @@ -262,7 +257,7 @@ var _ = Describe("HashBased", func() { var endpoints []*route.Endpoint BeforeEach(func() { - iter = route.NewHashBased(logger.Logger, pool, "", false, false, "").(*route.HashBased) + iter = route.NewHashBased(logger.Logger, pool, "", false, "abc").(*route.HashBased) }) Context("when endpoints have a lot of in-flight requests", func() { @@ -338,7 +333,7 @@ var _ = Describe("HashBased", func() { var endpoints []*route.Endpoint BeforeEach(func() { - iter = route.NewHashBased(logger.Logger, pool, "", false, false, "").(*route.HashBased) + iter = route.NewHashBased(logger.Logger, pool, "", false, "abc").(*route.HashBased) }) Context("when there are no endpoints", func() { diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev.go b/src/code.cloudfoundry.org/gorouter/route/maglev.go index 23d46210f..72731ec05 100644 --- a/src/code.cloudfoundry.org/gorouter/route/maglev.go +++ b/src/code.cloudfoundry.org/gorouter/route/maglev.go @@ -1,14 +1,21 @@ package route -/****************************************************************************** - * Original github.com/kkdai/maglev/maglev.go - * - * Copyright (c) 2019 Evan Lin (github.com/kkdai) - * - * This program and the accompanying materials are made available under - * the terms of the Apache License, Version 2.0 which is available at - * http://www.apache.org/licenses/LICENSE-2.0. 
- ******************************************************************************/ +// Original https://github.com/kkdai/maglev +// +// Copyright (c) 2019 Evan Lin (github.com/kkdai) +// +// This program and the accompanying materials are made available under +// the terms of the Apache License, Version 2.0 which is available at +// http://www.apache.org/licenses/LICENSE-2.0. +// +// CHANGES: +// - Modified for integration with CF GoRouter +// - Added MaglevLookup interface for testability and abstraction +// - Enhanced with structured logging using slog +// - Added thread-safe operations +// - Extended with getter methods for unit testing +// - Added error handling and safety checks +// - Customized for hash-based routing requirements import ( "errors" diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index 8217aefca..4df23ae4c 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -347,6 +347,7 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { // new one. e.Lock() defer e.Unlock() + oldEndpoint := e.endpoint e.endpoint = endpoint @@ -471,7 +472,7 @@ func (p *EndpointPool) removeEndpoint(e *endpointElem) { } -func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { +func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string, globalLB string, request *http.Request) EndpointIterator { switch p.LoadBalancingAlgorithm { case config.LOAD_BALANCE_LC: logger.Debug("endpoint-iterator-with-least-connection-lb-algo") @@ -480,8 +481,13 @@ func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeStic logger.Debug("endpoint-iterator-with-round-robin-lb-algo") return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) case config.LOAD_BALANCE_HB: + if p.HashRoutingProperties == nil || request.Header.Get(p.HashRoutingProperties.Header) == "" { + logger.Error("hash-routing-properties-missing", slog.String("host", p.Host())) + return p.FallBackToDefaultLoadBalancing(globalLB, logger, initial, mustBeSticky, azPreference, az) + } + headerValue := request.Header.Get(p.HashRoutingProperties.Header) logger.Debug("endpoint-iterator-with-hash-based-lb-algo") - return NewHashBased(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + return NewHashBased(logger, p, initial, mustBeSticky, headerValue) default: logger.Error("invalid-pool-load-balancing-algorithm", slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), diff --git a/src/code.cloudfoundry.org/gorouter/route/pool_test.go b/src/code.cloudfoundry.org/gorouter/route/pool_test.go index 7709a1d8b..050c2694a 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool_test.go @@ -246,7 +246,7 @@ var _ = Describe("EndpointPool", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: modTag2}) Expect(pool.Put(endpoint)).To(Equal(route.EndpointUpdated)) - Expect(pool.Endpoints(logger.Logger, "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag2)) + Expect(pool.Endpoints(logger.Logger, "", false, azPreference, az, config.LOAD_BALANCE_RR, nil).Next(0).ModificationTag).To(Equal(modTag2)) }) Context("when modification_tag is older", func() { @@ -261,7 +261,7 @@ var _ = 
Describe("EndpointPool", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: olderModTag}) Expect(pool.Put(endpoint)).To(Equal(route.EndpointUnmodified)) - Expect(pool.Endpoints(logger.Logger, "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag2)) + Expect(pool.Endpoints(logger.Logger, "", false, azPreference, az, config.LOAD_BALANCE_RR, nil).Next(0).ModificationTag).To(Equal(modTag2)) }) }) }) @@ -312,7 +312,7 @@ var _ = Describe("EndpointPool", func() { Logger: logger.Logger, LoadBalancingAlgorithm: "wrong-lb-algo", }) - iterator := poolWithLBAlgo2.Endpoints(logger.Logger, "", false, "none", "zone") + iterator := poolWithLBAlgo2.Endpoints(logger.Logger, "", false, "none", "zone", config.LOAD_BALANCE_RR, nil) Expect(iterator).To(BeAssignableToTypeOf(&route.RoundRobin{})) Eventually(logger).Should(gbytes.Say(`invalid-pool-load-balancing-algorithm`)) }) @@ -322,7 +322,7 @@ var _ = Describe("EndpointPool", func() { Logger: logger.Logger, LoadBalancingAlgorithm: config.LOAD_BALANCE_LC, }) - iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, "none", "az") + iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, "none", "az", config.LOAD_BALANCE_LC, nil) Expect(iterator).To(BeAssignableToTypeOf(&route.LeastConnection{})) Eventually(logger).Should(gbytes.Say(`endpoint-iterator-with-least-connection-lb-algo`)) }) @@ -332,7 +332,7 @@ var _ = Describe("EndpointPool", func() { Logger: logger.Logger, LoadBalancingAlgorithm: config.LOAD_BALANCE_RR, }) - iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, "none", "az") + iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, "none", "az", config.LOAD_BALANCE_RR, nil) Expect(iterator).To(BeAssignableToTypeOf(&route.RoundRobin{})) Eventually(logger).Should(gbytes.Say(`endpoint-iterator-with-round-robin-lb-algo`)) }) @@ -540,7 +540,7 @@ var _ = Describe("EndpointPool", func() { azPreference := "none" connectionResetError := &net.OpError{Op: "read", Err: errors.New("read: connection reset by peer")} pool.EndpointFailed(failedEndpoint, connectionResetError) - i := pool.Endpoints(logger.Logger, "", false, azPreference, az) + i := pool.Endpoints(logger.Logger, "", false, azPreference, az, config.LOAD_BALANCE_RR, nil) epOne := i.Next(0) epTwo := i.Next(1) Expect(epOne).To(Equal(epTwo)) From fbbfaef4db0fdc1ee66c59647fb381022faad6be Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Wed, 3 Dec 2025 10:31:42 +0100 Subject: [PATCH 06/17] Fix bug related to all endpoints are overloaded --- src/code.cloudfoundry.org/gorouter/route/hash_based.go | 2 +- src/code.cloudfoundry.org/gorouter/route/hash_based_test.go | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go index 4c6df58de..82bca1eef 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -163,7 +163,7 @@ func (h *HashBased) IsImbalancedOrOverloaded(endpoint *Endpoint, isEndpointOverl } // Check if avgNumberOfInFlightRequests is 0 to avoid division by 0 - if avgNumberOfInFlightRequests == 0 || float64(currentInFlightRequestCount)/avgNumberOfInFlightRequests > balanceFactor { + if avgNumberOfInFlightRequests != 0 && float64(currentInFlightRequestCount)/avgNumberOfInFlightRequests > balanceFactor { h.logger.Debug("hash-based-routing-endpoint-imbalanced", slog.String("host", h.pool.host), 
slog.String("endpoint-id", endpoint.PrivateInstanceId), slog.Int64("endpoint-connections", endpoint.Stats.NumberConnections.Count()), slog.Float64("average-load", avgNumberOfInFlightRequests)) return true } diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go index 09a286153..b3e93b4ce 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go @@ -80,10 +80,6 @@ var _ = Describe("HashBased", func() { pool.Put(e) } iter := route.NewHashBased(logger.Logger, pool, "", false, "tenant-1") - // Simulate in-flight requests - for _, e := range endpoints { - iter.PreRequest(e) - } first := iter.Next(0) Expect(iter.Next(0)).To(Equal(first)) for i := 0; i < 6; i++ { From 350038a233b8338d55ba1ae2cc88f385ba18d709 Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Wed, 3 Dec 2025 11:07:23 +0100 Subject: [PATCH 07/17] Fix compiler errors --- .../gorouter/registry/registry_test.go | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/registry/registry_test.go b/src/code.cloudfoundry.org/gorouter/registry/registry_test.go index 9ac5c1ce9..591c6f4dd 100644 --- a/src/code.cloudfoundry.org/gorouter/registry/registry_test.go +++ b/src/code.cloudfoundry.org/gorouter/registry/registry_test.go @@ -407,7 +407,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - Expect(p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) }) }) @@ -429,7 +429,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - Expect(p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) }) Context("updating an existing route with an older modification tag", func() { @@ -449,7 +449,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - ep := p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0) + ep := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(ep.ModificationTag).To(Equal(modTag)) Expect(ep).To(Equal(endpoint2)) }) @@ -468,7 +468,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - Expect(p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) }) }) }) @@ -813,7 +813,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumUris()).To(Equal(1)) p1 := r.Lookup("foo/bar") - iter := p1.Endpoints(logger.Logger, "", false, azPreference, az) + iter := p1.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) p2 := r.Lookup("foo") @@ -917,7 +917,7 @@ var _ = Describe("RouteRegistry", func() { p2 := r.Lookup("FOO") Expect(p1).To(Equal(p2)) - iter := 
p1.Endpoints(logger.Logger, "", false, azPreference, az) + iter := p1.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -936,7 +936,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("bar") Expect(p).ToNot(BeNil()) - e := p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0) + e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:123[4|5]")) @@ -951,13 +951,13 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("foo.wild.card") Expect(p).ToNot(BeNil()) - e := p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0) + e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.2:1234")) p = r.Lookup("foo.space.wild.card") Expect(p).ToNot(BeNil()) - e = p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0) + e = p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.2:1234")) }) @@ -971,7 +971,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("not.wild.card") Expect(p).ToNot(BeNil()) - e := p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0) + e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1003,7 +1003,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("dora.app.com/env?foo=bar") Expect(p).ToNot(BeNil()) - iter := p.Endpoints(logger.Logger, "", false, azPreference, az) + iter := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1012,7 +1012,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("dora.app.com/env/abc?foo=bar&baz=bing") Expect(p).ToNot(BeNil()) - iter := p.Endpoints(logger.Logger, "", false, azPreference, az) + iter := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) }) @@ -1032,7 +1032,7 @@ var _ = Describe("RouteRegistry", func() { p1 := r.Lookup("foo/extra/paths") Expect(p1).ToNot(BeNil()) - iter := p1.Endpoints(logger.Logger, "", false, azPreference, az) + iter := p1.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1044,7 +1044,7 @@ var _ = Describe("RouteRegistry", func() { p1 := r.Lookup("foo?fields=foo,bar") Expect(p1).ToNot(BeNil()) - iter := p1.Endpoints(logger.Logger, "", false, azPreference, az) + iter := p1.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1131,7 +1131,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(2)) p := r.LookupWithAppInstance("bar.com/foo", appId, appIndex) - e := p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0) + e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) 
Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:1234")) @@ -1152,7 +1152,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(2)) p := r.LookupWithAppInstance("bar.com/foo", appId, appIndex) - e := p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0) + e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:1234")) @@ -1260,7 +1260,7 @@ var _ = Describe("RouteRegistry", func() { p := r.LookupWithProcessInstance("bar.com/foo", processId, processIndex) Expect(p.NumEndpoints()).To(Equal(2)) - es := p.Endpoints(logger.Logger, "", false, azPreference, az) + es := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) e1 := es.Next(0) Expect(e1).ToNot(BeNil()) e2 := es.Next(0) @@ -1299,7 +1299,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(5)) p := r.LookupWithProcessInstance("bar.com/foo", processId, processIndex) - e := p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0) + e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.4:1237")) @@ -1506,7 +1506,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("foo") Expect(p).ToNot(BeNil()) - Expect(p.Endpoints(logger.Logger, "", false, azPreference, az).Next(0)).To(Equal(endpoint)) + Expect(p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0)).To(Equal(endpoint)) p = r.Lookup("bar") Expect(p).To(BeNil()) From 95edd7b402ff2d2fbb461503d6ce7472159dae81 Mon Sep 17 00:00:00 2001 From: Clemens Hoffmann Date: Mon, 8 Dec 2025 10:40:34 +0100 Subject: [PATCH 08/17] Update RouteSchema struct in route-emitter --- src/code.cloudfoundry.org/route-registrar/config/config.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/code.cloudfoundry.org/route-registrar/config/config.go b/src/code.cloudfoundry.org/route-registrar/config/config.go index 63eb01101..b1546195f 100644 --- a/src/code.cloudfoundry.org/route-registrar/config/config.go +++ b/src/code.cloudfoundry.org/route-registrar/config/config.go @@ -75,6 +75,8 @@ type RouteSchema struct { type Options struct { LoadBalancingAlgorithm LoadBalancingAlgorithm `json:"loadbalancing,omitempty" yaml:"loadbalancing,omitempty"` + HashHeader string `json:"hash_header,omitempty" yaml:"hash_header,omitempty"` + HashBalance float64 `json:"hash_balance,omitempty" yaml:"hash_balance,omitempty"` } type LoadBalancingAlgorithm string From f7910d72cc5b42916cf3de36cfc4c08d38fa8237 Mon Sep 17 00:00:00 2001 From: Clemens Hoffmann Date: Mon, 8 Dec 2025 15:28:07 +0100 Subject: [PATCH 09/17] Add json string property to hash_balance --- src/code.cloudfoundry.org/gorouter/mbus/subscriber.go | 2 +- src/code.cloudfoundry.org/gorouter/route/pool.go | 2 +- src/code.cloudfoundry.org/gorouter/route/pool_test.go | 2 +- src/code.cloudfoundry.org/route-registrar/config/config.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/mbus/subscriber.go b/src/code.cloudfoundry.org/gorouter/mbus/subscriber.go index af17a19b4..758a6c1a6 100644 --- a/src/code.cloudfoundry.org/gorouter/mbus/subscriber.go +++ b/src/code.cloudfoundry.org/gorouter/mbus/subscriber.go @@ -43,7 +43,7 @@ type RegistryMessage struct { type 
RegistryMessageOpts struct { LoadBalancingAlgorithm string `json:"loadbalancing"` HashHeaderName string `json:"hash_header"` - HashBalance float64 `json:"hash_balance"` + HashBalance float64 `json:"hash_balance,string"` } func (rm *RegistryMessage) makeEndpoint(http2Enabled bool) (*route.Endpoint, error) { diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index 4df23ae4c..7225a5825 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -688,7 +688,7 @@ func (e *Endpoint) MarshalJSON() ([]byte, error) { ServerCertDomainSAN string `json:"server_cert_domain_san,omitempty"` LoadBalancingAlgorithm string `json:"load_balancing_algorithm,omitempty"` HashHeader string `json:"hash_header,omitempty"` - HashBalance float64 `json:"hash_balance,omitempty"` + HashBalance float64 `json:"hash_balance,omitempty,string"` } jsonObj.Address = e.addr diff --git a/src/code.cloudfoundry.org/gorouter/route/pool_test.go b/src/code.cloudfoundry.org/gorouter/route/pool_test.go index 050c2694a..cdd7fbee4 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool_test.go @@ -967,7 +967,7 @@ var _ = Describe("EndpointPool", func() { json, err := pool.MarshalJSON() Expect(err).ToNot(HaveOccurred()) - Expect(string(json)).To(Equal(`[{"address":"1.2.3.4:5678","availability_zone":"az-meow","protocol":"http1","tls":false,"ttl":-1,"route_service_url":"https://my-rs.com","tags":null},{"address":"5.6.7.8:5678","availability_zone":"","protocol":"http2","tls":true,"ttl":-1,"tags":null,"private_instance_id":"pvt_test_instance_id","server_cert_domain_san":"pvt_test_san","load_balancing_algorithm":"hash","hash_header":"X-Header","hash_balance":1.25}]`)) + Expect(string(json)).To(Equal(`[{"address":"1.2.3.4:5678","availability_zone":"az-meow","protocol":"http1","tls":false,"ttl":-1,"route_service_url":"https://my-rs.com","tags":null},{"address":"5.6.7.8:5678","availability_zone":"","protocol":"http2","tls":true,"ttl":-1,"tags":null,"private_instance_id":"pvt_test_instance_id","server_cert_domain_san":"pvt_test_san","load_balancing_algorithm":"hash","hash_header":"X-Header","hash_balance":"1.25"}]`)) }) Context("when endpoints do not have empty tags", func() { diff --git a/src/code.cloudfoundry.org/route-registrar/config/config.go b/src/code.cloudfoundry.org/route-registrar/config/config.go index b1546195f..1c130de25 100644 --- a/src/code.cloudfoundry.org/route-registrar/config/config.go +++ b/src/code.cloudfoundry.org/route-registrar/config/config.go @@ -76,7 +76,7 @@ type RouteSchema struct { type Options struct { LoadBalancingAlgorithm LoadBalancingAlgorithm `json:"loadbalancing,omitempty" yaml:"loadbalancing,omitempty"` HashHeader string `json:"hash_header,omitempty" yaml:"hash_header,omitempty"` - HashBalance float64 `json:"hash_balance,omitempty" yaml:"hash_balance,omitempty"` + HashBalance float64 `json:"hash_balance,omitempty,string" yaml:"hash_balance,omitempty"` } type LoadBalancingAlgorithm string From b8b26079601a8ab0356b06f165ea7a79e5e1c0ef Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Tue, 13 Jan 2026 09:59:37 +0100 Subject: [PATCH 10/17] Refactor route/pool Endpoints function --- .../gorouter/route/pool.go | 62 +++++++++++-------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index 7225a5825..fa3817e9d 100644 --- 
a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -473,45 +473,53 @@ func (p *EndpointPool) removeEndpoint(e *endpointElem) { } func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string, globalLB string, request *http.Request) EndpointIterator { - switch p.LoadBalancingAlgorithm { - case config.LOAD_BALANCE_LC: - logger.Debug("endpoint-iterator-with-least-connection-lb-algo") - return NewLeastConnection(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) - case config.LOAD_BALANCE_RR: - logger.Debug("endpoint-iterator-with-round-robin-lb-algo") - return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) - case config.LOAD_BALANCE_HB: - if p.HashRoutingProperties == nil || request.Header.Get(p.HashRoutingProperties.Header) == "" { - logger.Error("hash-routing-properties-missing", slog.String("host", p.Host())) - return p.FallBackToDefaultLoadBalancing(globalLB, logger, initial, mustBeSticky, azPreference, az) + locallyOptimistic := azPreference == config.AZ_PREF_LOCAL + + // For hash-based routing, validate inputs and get header value + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + valid, headerValue := p.hashBasedInputsValid(request, p.HashRoutingProperties, logger) + if !valid { + logger.Info("hash-based-routing-header-not-found", + slog.String("Host", p.host), + slog.String("Path", p.contextPath)) + return p.createIterator(globalLB, logger, initial, mustBeSticky, locallyOptimistic, az) } - headerValue := request.Header.Get(p.HashRoutingProperties.Header) logger.Debug("endpoint-iterator-with-hash-based-lb-algo") return NewHashBased(logger, p, initial, mustBeSticky, headerValue) - default: - logger.Error("invalid-pool-load-balancing-algorithm", - slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), - slog.String("Host", p.host), - slog.String("Path", p.contextPath)) - return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) } + + return p.createIterator(p.LoadBalancingAlgorithm, logger, initial, mustBeSticky, locallyOptimistic, az) } -func (p *EndpointPool) FallBackToDefaultLoadBalancing(defaultLBAlgo string, logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { - logger.Info("hash-based-routing-header-not-found", - slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), - slog.String("Host", p.host), - slog.String("Path", p.contextPath)) +func (p *EndpointPool) hashBasedInputsValid(request *http.Request, hashProps *HashRoutingProperties, logger *slog.Logger) (bool, string) { + if hashProps == nil { + logger.Error("hash-routing-properties-missing", slog.String("host", p.Host())) + return false, "" + } + hashHeader := request.Header.Get(hashProps.Header) + if hashHeader == "" { + logger.Error("hash-based-routing-header-not-found", slog.String("host", p.Host())) + return false, "" + } + return true, hashHeader +} - switch defaultLBAlgo { +func (p *EndpointPool) createIterator(lbAlgo string, logger *slog.Logger, initial string, mustBeSticky bool, locallyOptimistic bool, az string) EndpointIterator { + switch lbAlgo { case config.LOAD_BALANCE_LC: logger.Debug("endpoint-iterator-with-least-connection-lb-algo") - return NewLeastConnection(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + return NewLeastConnection(logger, p, initial, mustBeSticky, locallyOptimistic, az) case 
config.LOAD_BALANCE_RR: logger.Debug("endpoint-iterator-with-round-robin-lb-algo") - return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + return NewRoundRobin(logger, p, initial, mustBeSticky, locallyOptimistic, az) + default: + logger.Error("invalid-pool-load-balancing-algorithm", + slog.String("poolLBAlgorithm", lbAlgo), + slog.String("Host", p.host), + slog.String("Path", p.contextPath)) + logger.Debug("endpoint-iterator-with-round-robin-lb-algo") + return NewRoundRobin(logger, p, initial, mustBeSticky, locallyOptimistic, az) } - return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) } func (p *EndpointPool) NumEndpoints() int { From d23f2851a85c0968cb5fa58d63fcb523ef7e1c7f Mon Sep 17 00:00:00 2001 From: Clemens Hoffmann Date: Tue, 13 Jan 2026 11:02:51 +0100 Subject: [PATCH 11/17] Refactor locallyOptimistic and pass request header to Endpoints --- .../gorouter/handlers/helpers.go | 4 +- .../gorouter/handlers/max_request_size.go | 3 +- .../round_tripper/proxy_round_tripper.go | 3 +- .../round_tripper/proxy_round_tripper_test.go | 4 +- .../gorouter/registry/registry_test.go | 43 ++++++++++--------- .../gorouter/route/pool.go | 27 +++++------- .../gorouter/route/pool_test.go | 22 +++++----- 7 files changed, 53 insertions(+), 53 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/handlers/helpers.go b/src/code.cloudfoundry.org/gorouter/handlers/helpers.go index 1eeebf648..86c3d9d89 100644 --- a/src/code.cloudfoundry.org/gorouter/handlers/helpers.go +++ b/src/code.cloudfoundry.org/gorouter/handlers/helpers.go @@ -63,13 +63,13 @@ func upgradeHeader(request *http.Request) string { return "" } -func EndpointIteratorForRequest(logger *slog.Logger, request *http.Request, stickySessionCookieNames config.StringSet, authNegotiateSticky bool, azPreference string, az string, globalLB string) (route.EndpointIterator, error) { +func EndpointIteratorForRequest(logger *slog.Logger, request *http.Request, stickySessionCookieNames config.StringSet, authNegotiateSticky bool, locallyOptimistic bool, az string, globalLB string) (route.EndpointIterator, error) { reqInfo, err := ContextRequestInfo(request) if err != nil { return nil, fmt.Errorf("could not find reqInfo in context") } stickyEndpointID, mustBeSticky := GetStickySession(request, stickySessionCookieNames, authNegotiateSticky) - return reqInfo.RoutePool.Endpoints(logger, stickyEndpointID, mustBeSticky, azPreference, az, globalLB, request), nil + return reqInfo.RoutePool.Endpoints(logger, stickyEndpointID, mustBeSticky, locallyOptimistic, az, globalLB, &request.Header), nil } func GetStickySession(request *http.Request, stickySessionCookieNames config.StringSet, authNegotiateSticky bool) (string, bool) { diff --git a/src/code.cloudfoundry.org/gorouter/handlers/max_request_size.go b/src/code.cloudfoundry.org/gorouter/handlers/max_request_size.go index a4e44d63b..e88654c74 100644 --- a/src/code.cloudfoundry.org/gorouter/handlers/max_request_size.go +++ b/src/code.cloudfoundry.org/gorouter/handlers/max_request_size.go @@ -69,7 +69,8 @@ func (m *MaxRequestSize) ServeHTTP(rw http.ResponseWriter, r *http.Request, next if err != nil { logger.Error("request-info-err", log.ErrAttr(err)) } else { - endpointIterator, err := EndpointIteratorForRequest(logger, r, m.cfg.StickySessionCookieNames, m.cfg.StickySessionsForAuthNegotiate, m.cfg.LoadBalanceAZPreference, m.cfg.Zone, m.cfg.LoadBalance) + locallyOptimistic := m.cfg.LoadBalanceAZPreference == config.AZ_PREF_LOCAL + 
endpointIterator, err := EndpointIteratorForRequest(logger, r, m.cfg.StickySessionCookieNames, m.cfg.StickySessionsForAuthNegotiate, locallyOptimistic, m.cfg.Zone, m.cfg.LoadBalance) if err != nil { logger.Error("failed-to-find-endpoint-for-req-during-431-short-circuit", log.ErrAttr(err)) } else { diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go index b17273ad1..ead7df427 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go @@ -126,7 +126,8 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response stickyEndpointID, mustBeSticky := handlers.GetStickySession(request, rt.config.StickySessionCookieNames, rt.config.StickySessionsForAuthNegotiate) numberOfEndpoints := reqInfo.RoutePool.NumEndpoints() - iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone, rt.config.LoadBalance, request) + locallyOptimistic := rt.config.LoadBalanceAZPreference == config.AZ_PREF_LOCAL + iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, locallyOptimistic, rt.config.Zone, rt.config.LoadBalance, &request.Header) // The selectEndpointErr needs to be tracked separately. If we get an error // while selecting an endpoint we might just have run out of routes. In diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go index 50176e0d2..3c05f535a 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go @@ -275,7 +275,7 @@ var _ = Describe("ProxyRoundTripper", func() { res, err := proxyRoundTripper.RoundTrip(req) Expect(err).NotTo(HaveOccurred()) - iter := routePool.Endpoints(logger.Logger, "", false, AZPreference, AZ, cfg.LoadBalance, req) + iter := routePool.Endpoints(logger.Logger, "", false, false, AZ, cfg.LoadBalance, &req.Header) ep1 := iter.Next(0) ep2 := iter.Next(1) Expect(ep1.PrivateInstanceId).To(Equal(ep2.PrivateInstanceId)) @@ -609,7 +609,7 @@ var _ = Describe("ProxyRoundTripper", func() { _, err := proxyRoundTripper.RoundTrip(req) Expect(err).To(MatchError(ContainSubstring("tls: handshake failure"))) - iter := routePool.Endpoints(logger.Logger, "", false, AZPreference, AZ, cfg.LoadBalance, req) + iter := routePool.Endpoints(logger.Logger, "", false, false, AZ, cfg.LoadBalance, &req.Header) ep1 := iter.Next(0) ep2 := iter.Next(1) Expect(ep1).To(Equal(ep2)) diff --git a/src/code.cloudfoundry.org/gorouter/registry/registry_test.go b/src/code.cloudfoundry.org/gorouter/registry/registry_test.go index 591c6f4dd..267e4b1ff 100644 --- a/src/code.cloudfoundry.org/gorouter/registry/registry_test.go +++ b/src/code.cloudfoundry.org/gorouter/registry/registry_test.go @@ -25,10 +25,11 @@ var _ = Describe("RouteRegistry", func() { var configObj *config.Config var logger *test_util.TestLogger - var azPreference, az string + var locallyOptimistic bool + var az string BeforeEach(func() { - azPreference = "none" + locallyOptimistic = false az = "meow-zone" logger = test_util.NewTestLogger("test") @@ -407,7 +408,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := 
r.Lookup("foo.com") - Expect(p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) }) }) @@ -429,7 +430,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - Expect(p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) }) Context("updating an existing route with an older modification tag", func() { @@ -449,7 +450,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - ep := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + ep := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(ep.ModificationTag).To(Equal(modTag)) Expect(ep).To(Equal(endpoint2)) }) @@ -468,7 +469,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - Expect(p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) }) }) }) @@ -813,7 +814,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumUris()).To(Equal(1)) p1 := r.Lookup("foo/bar") - iter := p1.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p1.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) p2 := r.Lookup("foo") @@ -917,7 +918,7 @@ var _ = Describe("RouteRegistry", func() { p2 := r.Lookup("FOO") Expect(p1).To(Equal(p2)) - iter := p1.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p1.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -936,7 +937,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("bar") Expect(p).ToNot(BeNil()) - e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:123[4|5]")) @@ -951,13 +952,13 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("foo.wild.card") Expect(p).ToNot(BeNil()) - e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.2:1234")) p = r.Lookup("foo.space.wild.card") Expect(p).ToNot(BeNil()) - e = p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e = 
p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.2:1234")) }) @@ -971,7 +972,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("not.wild.card") Expect(p).ToNot(BeNil()) - e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1003,7 +1004,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("dora.app.com/env?foo=bar") Expect(p).ToNot(BeNil()) - iter := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1012,7 +1013,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("dora.app.com/env/abc?foo=bar&baz=bing") Expect(p).ToNot(BeNil()) - iter := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) }) @@ -1032,7 +1033,7 @@ var _ = Describe("RouteRegistry", func() { p1 := r.Lookup("foo/extra/paths") Expect(p1).ToNot(BeNil()) - iter := p1.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p1.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1044,7 +1045,7 @@ var _ = Describe("RouteRegistry", func() { p1 := r.Lookup("foo?fields=foo,bar") Expect(p1).ToNot(BeNil()) - iter := p1.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p1.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1131,7 +1132,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(2)) p := r.LookupWithAppInstance("bar.com/foo", appId, appIndex) - e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:1234")) @@ -1152,7 +1153,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(2)) p := r.LookupWithAppInstance("bar.com/foo", appId, appIndex) - e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:1234")) @@ -1260,7 +1261,7 @@ var _ = Describe("RouteRegistry", func() { p := r.LookupWithProcessInstance("bar.com/foo", processId, processIndex) Expect(p.NumEndpoints()).To(Equal(2)) - es := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil) + es := p.Endpoints(logger.Logger, "", 
false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) e1 := es.Next(0) Expect(e1).ToNot(BeNil()) e2 := es.Next(0) @@ -1299,7 +1300,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(5)) p := r.LookupWithProcessInstance("bar.com/foo", processId, processIndex) - e := p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.4:1237")) @@ -1506,7 +1507,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("foo") Expect(p).ToNot(BeNil()) - Expect(p.Endpoints(logger.Logger, "", false, azPreference, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0)).To(Equal(endpoint)) + Expect(p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0)).To(Equal(endpoint)) p = r.Lookup("bar") Expect(p).To(BeNil()) diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index fa3817e9d..3402b6c62 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -472,36 +472,31 @@ func (p *EndpointPool) removeEndpoint(e *endpointElem) { } -func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string, globalLB string, request *http.Request) EndpointIterator { - locallyOptimistic := azPreference == config.AZ_PREF_LOCAL - +func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, locallyOptimistic bool, az string, globalLB string, header *http.Header) EndpointIterator { // For hash-based routing, validate inputs and get header value if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { - valid, headerValue := p.hashBasedInputsValid(request, p.HashRoutingProperties, logger) - if !valid { - logger.Info("hash-based-routing-header-not-found", - slog.String("Host", p.host), - slog.String("Path", p.contextPath)) + headerValue := p.hashBasedInputsValid(header, p.HashRoutingProperties, logger) + if headerValue == "" { return p.createIterator(globalLB, logger, initial, mustBeSticky, locallyOptimistic, az) } - logger.Debug("endpoint-iterator-with-hash-based-lb-algo") return NewHashBased(logger, p, initial, mustBeSticky, headerValue) } return p.createIterator(p.LoadBalancingAlgorithm, logger, initial, mustBeSticky, locallyOptimistic, az) } -func (p *EndpointPool) hashBasedInputsValid(request *http.Request, hashProps *HashRoutingProperties, logger *slog.Logger) (bool, string) { - if hashProps == nil { +func (p *EndpointPool) hashBasedInputsValid(header *http.Header, hashProps *HashRoutingProperties, logger *slog.Logger) string { + if hashProps == nil || hashProps.Header == "" { logger.Error("hash-routing-properties-missing", slog.String("host", p.Host())) - return false, "" + return "" } - hashHeader := request.Header.Get(hashProps.Header) + hashHeader := header.Get(hashProps.Header) if hashHeader == "" { - logger.Error("hash-based-routing-header-not-found", slog.String("host", p.Host())) - return false, "" + logger.Warn("hash-header-value-not-found", + slog.String("Host", p.host), + slog.String("Path", p.contextPath)) } - return true, hashHeader + return hashHeader } func (p *EndpointPool) createIterator(lbAlgo string, logger *slog.Logger, initial string, mustBeSticky bool, locallyOptimistic bool, az 
string) EndpointIterator { diff --git a/src/code.cloudfoundry.org/gorouter/route/pool_test.go b/src/code.cloudfoundry.org/gorouter/route/pool_test.go index cdd7fbee4..9c0aa63ee 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool_test.go @@ -198,8 +198,8 @@ var _ = Describe("EndpointPool", func() { Context("Put", func() { var ( - az = "meow-zone" - azPreference = "none" + az = "meow-zone" + locallyOptimistic = false ) It("adds endpoints", func() { @@ -246,7 +246,7 @@ var _ = Describe("EndpointPool", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: modTag2}) Expect(pool.Put(endpoint)).To(Equal(route.EndpointUpdated)) - Expect(pool.Endpoints(logger.Logger, "", false, azPreference, az, config.LOAD_BALANCE_RR, nil).Next(0).ModificationTag).To(Equal(modTag2)) + Expect(pool.Endpoints(logger.Logger, "", false, locallyOptimistic, az, config.LOAD_BALANCE_RR, nil).Next(0).ModificationTag).To(Equal(modTag2)) }) Context("when modification_tag is older", func() { @@ -261,7 +261,7 @@ var _ = Describe("EndpointPool", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: olderModTag}) Expect(pool.Put(endpoint)).To(Equal(route.EndpointUnmodified)) - Expect(pool.Endpoints(logger.Logger, "", false, azPreference, az, config.LOAD_BALANCE_RR, nil).Next(0).ModificationTag).To(Equal(modTag2)) + Expect(pool.Endpoints(logger.Logger, "", false, locallyOptimistic, az, config.LOAD_BALANCE_RR, nil).Next(0).ModificationTag).To(Equal(modTag2)) }) }) }) @@ -297,7 +297,9 @@ var _ = Describe("EndpointPool", func() { }) }) Context("Customizable Per Route Load Balancing", func() { - + var ( + locallyOptimistic = false + ) Context("Load Balancing Algorithm of a pool", func() { It("has a value specified in the pool options", func() { poolWithLBAlgo := route.NewPool(&route.PoolOpts{ @@ -312,7 +314,7 @@ var _ = Describe("EndpointPool", func() { Logger: logger.Logger, LoadBalancingAlgorithm: "wrong-lb-algo", }) - iterator := poolWithLBAlgo2.Endpoints(logger.Logger, "", false, "none", "zone", config.LOAD_BALANCE_RR, nil) + iterator := poolWithLBAlgo2.Endpoints(logger.Logger, "", false, locallyOptimistic, "zone", config.LOAD_BALANCE_RR, nil) Expect(iterator).To(BeAssignableToTypeOf(&route.RoundRobin{})) Eventually(logger).Should(gbytes.Say(`invalid-pool-load-balancing-algorithm`)) }) @@ -322,7 +324,7 @@ var _ = Describe("EndpointPool", func() { Logger: logger.Logger, LoadBalancingAlgorithm: config.LOAD_BALANCE_LC, }) - iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, "none", "az", config.LOAD_BALANCE_LC, nil) + iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, locallyOptimistic, "az", config.LOAD_BALANCE_LC, nil) Expect(iterator).To(BeAssignableToTypeOf(&route.LeastConnection{})) Eventually(logger).Should(gbytes.Say(`endpoint-iterator-with-least-connection-lb-algo`)) }) @@ -332,7 +334,7 @@ var _ = Describe("EndpointPool", func() { Logger: logger.Logger, LoadBalancingAlgorithm: config.LOAD_BALANCE_RR, }) - iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, "none", "az", config.LOAD_BALANCE_RR, nil) + iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, locallyOptimistic, "az", config.LOAD_BALANCE_RR, nil) Expect(iterator).To(BeAssignableToTypeOf(&route.RoundRobin{})) Eventually(logger).Should(gbytes.Say(`endpoint-iterator-with-round-robin-lb-algo`)) }) @@ -537,10 +539,10 @@ var _ = Describe("EndpointPool", 
func() { Context("when a read connection is reset", func() { It("marks the endpoint as failed", func() { az := "meow-zone" - azPreference := "none" + locallyOptimistic := false connectionResetError := &net.OpError{Op: "read", Err: errors.New("read: connection reset by peer")} pool.EndpointFailed(failedEndpoint, connectionResetError) - i := pool.Endpoints(logger.Logger, "", false, azPreference, az, config.LOAD_BALANCE_RR, nil) + i := pool.Endpoints(logger.Logger, "", false, locallyOptimistic, az, config.LOAD_BALANCE_RR, nil) epOne := i.Next(0) epTwo := i.Next(1) Expect(epOne).To(Equal(epTwo)) From 52f0f7cfb948061a850aa67b5c4bd2e88c42e816 Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Thu, 15 Jan 2026 10:05:57 +0100 Subject: [PATCH 12/17] WIP: Introduce routing properties --- .../gorouter/handlers/helpers.go | 8 +++- .../round_tripper/proxy_round_tripper.go | 9 +++- .../round_tripper/proxy_round_tripper_test.go | 19 ++++++-- .../gorouter/registry/registry_test.go | 48 +++++++++++-------- .../gorouter/route/pool.go | 25 ++++++---- .../gorouter/route/pool_test.go | 33 ++++++++++--- 6 files changed, 101 insertions(+), 41 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/handlers/helpers.go b/src/code.cloudfoundry.org/gorouter/handlers/helpers.go index 86c3d9d89..d2716613f 100644 --- a/src/code.cloudfoundry.org/gorouter/handlers/helpers.go +++ b/src/code.cloudfoundry.org/gorouter/handlers/helpers.go @@ -69,7 +69,13 @@ func EndpointIteratorForRequest(logger *slog.Logger, request *http.Request, stic return nil, fmt.Errorf("could not find reqInfo in context") } stickyEndpointID, mustBeSticky := GetStickySession(request, stickySessionCookieNames, authNegotiateSticky) - return reqInfo.RoutePool.Endpoints(logger, stickyEndpointID, mustBeSticky, locallyOptimistic, az, globalLB, &request.Header), nil + routingProperties := route.RoutingProperties{ + RequestHeaders: &request.Header, + LocallyOptimistic: locallyOptimistic, + GlobalLB: globalLB, + AZ: az, + } + return reqInfo.RoutePool.Endpoints(logger, stickyEndpointID, mustBeSticky, routingProperties), nil } func GetStickySession(request *http.Request, stickySessionCookieNames config.StringSet, authNegotiateSticky bool) (string, bool) { diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go index ead7df427..48e67221c 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go @@ -127,7 +127,14 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response stickyEndpointID, mustBeSticky := handlers.GetStickySession(request, rt.config.StickySessionCookieNames, rt.config.StickySessionsForAuthNegotiate) numberOfEndpoints := reqInfo.RoutePool.NumEndpoints() locallyOptimistic := rt.config.LoadBalanceAZPreference == config.AZ_PREF_LOCAL - iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, locallyOptimistic, rt.config.Zone, rt.config.LoadBalance, &request.Header) + routingProperties := route.RoutingProperties{ + RequestHeaders: &request.Header, + LocallyOptimistic: locallyOptimistic, + GlobalLB: rt.config.LoadBalance, + AZ: rt.config.Zone, + } + + iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, routingProperties) // The selectEndpointErr needs to be tracked separately. 
If we get an error
 	// while selecting an endpoint we might just have run out of routes. In
diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go
index 3c05f535a..dab18e3e4 100644
--- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go
+++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go
@@ -274,8 +274,14 @@ var _ = Describe("ProxyRoundTripper", func() {
 			It("logs the error and removes offending backend", func() {
 				res, err := proxyRoundTripper.RoundTrip(req)
 				Expect(err).NotTo(HaveOccurred())
+				routingProps := route.RoutingProperties{
+					LocallyOptimistic: false,
+					GlobalLB:          cfg.LoadBalance,
+					AZ:                AZ,
+					RequestHeaders:    &req.Header,
+				}
 
-				iter := routePool.Endpoints(logger.Logger, "", false, false, AZ, cfg.LoadBalance, &req.Header)
+				iter := routePool.Endpoints(logger.Logger, "", false, routingProps)
 				ep1 := iter.Next(0)
 				ep2 := iter.Next(1)
 				Expect(ep1.PrivateInstanceId).To(Equal(ep2.PrivateInstanceId))
@@ -603,13 +609,20 @@ var _ = Describe("ProxyRoundTripper", func() {
 					PrivateInstanceIndex: "2",
 				})
 
+				routingProps := route.RoutingProperties{
+					LocallyOptimistic: false,
+					GlobalLB:          cfg.LoadBalance,
+					AZ:                AZ,
+					RequestHeaders:    &req.Header,
+				}
+
 				added := routePool.Put(endpoint)
 				Expect(added).To(Equal(route.EndpointAdded))
 
 				_, err := proxyRoundTripper.RoundTrip(req)
 				Expect(err).To(MatchError(ContainSubstring("tls: handshake failure")))
 
-				iter := routePool.Endpoints(logger.Logger, "", false, false, AZ, cfg.LoadBalance, &req.Header)
+				iter := routePool.Endpoints(logger.Logger, "", false, routingProps)
 				ep1 := iter.Next(0)
 				ep2 := iter.Next(1)
 				Expect(ep1).To(Equal(ep2))
@@ -1774,7 +1787,7 @@ var _ = Describe("ProxyRoundTripper", func() {
 			infoLogs := logger.Lines(zap.InfoLevel)
 			count := 0
 			for i := 0; i < len(infoLogs); i++ {
-				if strings.Contains(infoLogs[i], "hash-based-routing-header-not-found") {
+				if strings.Contains(infoLogs[i], "hash-based-routing-header-value-not-found") {
 					count++
 				}
 			}
diff --git a/src/code.cloudfoundry.org/gorouter/registry/registry_test.go b/src/code.cloudfoundry.org/gorouter/registry/registry_test.go
index 267e4b1ff..eda8710a3 100644
--- a/src/code.cloudfoundry.org/gorouter/registry/registry_test.go
+++ b/src/code.cloudfoundry.org/gorouter/registry/registry_test.go
@@ -392,14 +392,20 @@ var _ = Describe("RouteRegistry", func() {
 
 	Context("Modification Tags", func() {
 		var (
-			endpoint *route.Endpoint
-			modTag   models.ModificationTag
+			endpoint     *route.Endpoint
+			modTag       models.ModificationTag
+			routingProps route.RoutingProperties
 		)
 
 		BeforeEach(func() {
 			modTag = models.ModificationTag{Guid: "abc"}
 			endpoint = route.NewEndpoint(&route.EndpointOpts{ModificationTag: modTag})
 			r.Register("foo.com", endpoint)
+			routingProps = route.RoutingProperties{
+				LocallyOptimistic: locallyOptimistic,
+				GlobalLB:          config.LOAD_BALANCE_RR,
+				AZ:                az,
+			}
 		})
 
 		Context("registering a new route", func() {
@@ -408,7 +414,7 @@ var _ = Describe("RouteRegistry", func() {
 				Expect(r.NumEndpoints()).To(Equal(1))
 
 				p := 
r.Lookup("foo.com") - Expect(p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints(logger.Logger, "", false, routingProps).Next(0).ModificationTag).To(Equal(modTag)) }) Context("updating an existing route with an older modification tag", func() { @@ -450,7 +456,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - ep := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + ep := p.Endpoints(logger.Logger, "", false, routingProps).Next(0) Expect(ep.ModificationTag).To(Equal(modTag)) Expect(ep).To(Equal(endpoint2)) }) @@ -469,7 +475,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - Expect(p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints(logger.Logger, "", false, routingProps).Next(0).ModificationTag).To(Equal(modTag)) }) }) }) @@ -814,7 +820,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumUris()).To(Equal(1)) p1 := r.Lookup("foo/bar") - iter := p1.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p1.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) p2 := r.Lookup("foo") @@ -918,7 +924,7 @@ var _ = Describe("RouteRegistry", func() { p2 := r.Lookup("FOO") Expect(p1).To(Equal(p2)) - iter := p1.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p1.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -937,7 +943,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("bar") Expect(p).ToNot(BeNil()) - e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:123[4|5]")) @@ -952,13 +958,13 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("foo.wild.card") Expect(p).ToNot(BeNil()) - e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.2:1234")) p = r.Lookup("foo.space.wild.card") Expect(p).ToNot(BeNil()) - e = p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e = p.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.2:1234")) }) @@ -972,7 +978,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("not.wild.card") Expect(p).ToNot(BeNil()) - e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, 
route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1004,7 +1010,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("dora.app.com/env?foo=bar") Expect(p).ToNot(BeNil()) - iter := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1013,7 +1019,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("dora.app.com/env/abc?foo=bar&baz=bing") Expect(p).ToNot(BeNil()) - iter := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) }) @@ -1033,7 +1039,7 @@ var _ = Describe("RouteRegistry", func() { p1 := r.Lookup("foo/extra/paths") Expect(p1).ToNot(BeNil()) - iter := p1.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p1.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1045,7 +1051,7 @@ var _ = Describe("RouteRegistry", func() { p1 := r.Lookup("foo?fields=foo,bar") Expect(p1).ToNot(BeNil()) - iter := p1.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) + iter := p1.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -1132,7 +1138,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(2)) p := r.LookupWithAppInstance("bar.com/foo", appId, appIndex) - e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:1234")) @@ -1153,7 +1159,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(2)) p := r.LookupWithAppInstance("bar.com/foo", appId, appIndex) - e := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:1234")) @@ -1261,7 +1267,7 @@ var _ = Describe("RouteRegistry", func() { p := r.LookupWithProcessInstance("bar.com/foo", processId, processIndex) Expect(p.NumEndpoints()).To(Equal(2)) - es := p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil) + es := p.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}) e1 := es.Next(0) Expect(e1).ToNot(BeNil()) e2 := es.Next(0) @@ -1300,7 +1306,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(5)) p := r.LookupWithProcessInstance("bar.com/foo", processId, processIndex) - e := p.Endpoints(logger.Logger, "", false, 
locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0) + e := p.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.4:1237")) @@ -1507,7 +1513,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("foo") Expect(p).ToNot(BeNil()) - Expect(p.Endpoints(logger.Logger, "", false, locallyOptimistic, az, r.DefaultLoadBalancingAlgorithm, nil).Next(0)).To(Equal(endpoint)) + Expect(p.Endpoints(logger.Logger, "", false, route.RoutingProperties{LocallyOptimistic: locallyOptimistic, AZ: az}).Next(0)).To(Equal(endpoint)) p = r.Lookup("bar") Expect(p).To(BeNil()) diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index 3402b6c62..f4da470b7 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -74,6 +74,13 @@ type ProxyRoundTripper interface { CancelRequest(*http.Request) } +type RoutingProperties struct { + RequestHeaders *http.Header + LocallyOptimistic bool + GlobalLB string + AZ string +} + type HashRoutingProperties struct { Header string BalanceFactor float64 @@ -472,17 +479,17 @@ func (p *EndpointPool) removeEndpoint(e *endpointElem) { } -func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, locallyOptimistic bool, az string, globalLB string, header *http.Header) EndpointIterator { +func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, routingProps RoutingProperties) EndpointIterator { // For hash-based routing, validate inputs and get header value if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { - headerValue := p.hashBasedInputsValid(header, p.HashRoutingProperties, logger) + headerValue := p.hashBasedInputsValid(routingProps.RequestHeaders, p.HashRoutingProperties, logger) if headerValue == "" { - return p.createIterator(globalLB, logger, initial, mustBeSticky, locallyOptimistic, az) + return p.createIterator(routingProps.GlobalLB, logger, initial, mustBeSticky, routingProps) } return NewHashBased(logger, p, initial, mustBeSticky, headerValue) } - return p.createIterator(p.LoadBalancingAlgorithm, logger, initial, mustBeSticky, locallyOptimistic, az) + return p.createIterator(p.LoadBalancingAlgorithm, logger, initial, mustBeSticky, routingProps) } func (p *EndpointPool) hashBasedInputsValid(header *http.Header, hashProps *HashRoutingProperties, logger *slog.Logger) string { @@ -492,28 +499,28 @@ func (p *EndpointPool) hashBasedInputsValid(header *http.Header, hashProps *Hash } hashHeader := header.Get(hashProps.Header) if hashHeader == "" { - logger.Warn("hash-header-value-not-found", + logger.Info("hash-based-routing-header-value-not-found", slog.String("Host", p.host), slog.String("Path", p.contextPath)) } return hashHeader } -func (p *EndpointPool) createIterator(lbAlgo string, logger *slog.Logger, initial string, mustBeSticky bool, locallyOptimistic bool, az string) EndpointIterator { +func (p *EndpointPool) createIterator(lbAlgo string, logger *slog.Logger, initial string, mustBeSticky bool, routingProps RoutingProperties) EndpointIterator { switch lbAlgo { case config.LOAD_BALANCE_LC: logger.Debug("endpoint-iterator-with-least-connection-lb-algo") - return NewLeastConnection(logger, p, initial, mustBeSticky, locallyOptimistic, az) + return NewLeastConnection(logger, p, initial, mustBeSticky, routingProps.LocallyOptimistic, 
routingProps.AZ) case config.LOAD_BALANCE_RR: logger.Debug("endpoint-iterator-with-round-robin-lb-algo") - return NewRoundRobin(logger, p, initial, mustBeSticky, locallyOptimistic, az) + return NewRoundRobin(logger, p, initial, mustBeSticky, routingProps.LocallyOptimistic, routingProps.AZ) default: logger.Error("invalid-pool-load-balancing-algorithm", slog.String("poolLBAlgorithm", lbAlgo), slog.String("Host", p.host), slog.String("Path", p.contextPath)) logger.Debug("endpoint-iterator-with-round-robin-lb-algo") - return NewRoundRobin(logger, p, initial, mustBeSticky, locallyOptimistic, az) + return NewRoundRobin(logger, p, initial, mustBeSticky, routingProps.LocallyOptimistic, routingProps.AZ) } } diff --git a/src/code.cloudfoundry.org/gorouter/route/pool_test.go b/src/code.cloudfoundry.org/gorouter/route/pool_test.go index 9c0aa63ee..bba5a20eb 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool_test.go @@ -233,6 +233,7 @@ var _ = Describe("EndpointPool", func() { Context("with modification tags", func() { var modTag models.ModificationTag var modTag2 models.ModificationTag + var routingProps route.RoutingProperties BeforeEach(func() { modTag = models.ModificationTag{} @@ -240,13 +241,19 @@ var _ = Describe("EndpointPool", func() { endpoint1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: modTag}) Expect(pool.Put(endpoint1)).To(Equal(route.EndpointAdded)) + + routingProps = route.RoutingProperties{ + LocallyOptimistic: locallyOptimistic, + GlobalLB: config.LOAD_BALANCE_RR, + AZ: az, + } }) It("updates an endpoint with modification tag", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: modTag2}) Expect(pool.Put(endpoint)).To(Equal(route.EndpointUpdated)) - Expect(pool.Endpoints(logger.Logger, "", false, locallyOptimistic, az, config.LOAD_BALANCE_RR, nil).Next(0).ModificationTag).To(Equal(modTag2)) + Expect(pool.Endpoints(logger.Logger, "", false, routingProps).Next(0).ModificationTag).To(Equal(modTag2)) }) Context("when modification_tag is older", func() { @@ -261,7 +268,7 @@ var _ = Describe("EndpointPool", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: olderModTag}) Expect(pool.Put(endpoint)).To(Equal(route.EndpointUnmodified)) - Expect(pool.Endpoints(logger.Logger, "", false, locallyOptimistic, az, config.LOAD_BALANCE_RR, nil).Next(0).ModificationTag).To(Equal(modTag2)) + Expect(pool.Endpoints(logger.Logger, "", false, routingProps).Next(0).ModificationTag).To(Equal(modTag2)) }) }) }) @@ -299,7 +306,16 @@ var _ = Describe("EndpointPool", func() { Context("Customizable Per Route Load Balancing", func() { var ( locallyOptimistic = false + routingProps route.RoutingProperties ) + + BeforeEach(func() { + routingProps = route.RoutingProperties{ + LocallyOptimistic: locallyOptimistic, + GlobalLB: config.LOAD_BALANCE_RR, + AZ: "az", + } + }) Context("Load Balancing Algorithm of a pool", func() { It("has a value specified in the pool options", func() { poolWithLBAlgo := route.NewPool(&route.PoolOpts{ @@ -314,7 +330,7 @@ var _ = Describe("EndpointPool", func() { Logger: logger.Logger, LoadBalancingAlgorithm: "wrong-lb-algo", }) - iterator := poolWithLBAlgo2.Endpoints(logger.Logger, "", false, locallyOptimistic, "zone", config.LOAD_BALANCE_RR, nil) + iterator := poolWithLBAlgo2.Endpoints(logger.Logger, "", false, routingProps) 
Expect(iterator).To(BeAssignableToTypeOf(&route.RoundRobin{})) Eventually(logger).Should(gbytes.Say(`invalid-pool-load-balancing-algorithm`)) }) @@ -324,7 +340,7 @@ var _ = Describe("EndpointPool", func() { Logger: logger.Logger, LoadBalancingAlgorithm: config.LOAD_BALANCE_LC, }) - iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, locallyOptimistic, "az", config.LOAD_BALANCE_LC, nil) + iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, routingProps) Expect(iterator).To(BeAssignableToTypeOf(&route.LeastConnection{})) Eventually(logger).Should(gbytes.Say(`endpoint-iterator-with-least-connection-lb-algo`)) }) @@ -334,7 +350,7 @@ var _ = Describe("EndpointPool", func() { Logger: logger.Logger, LoadBalancingAlgorithm: config.LOAD_BALANCE_RR, }) - iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, locallyOptimistic, "az", config.LOAD_BALANCE_RR, nil) + iterator := poolWithLBAlgoLC.Endpoints(logger.Logger, "", false, routingProps) Expect(iterator).To(BeAssignableToTypeOf(&route.RoundRobin{})) Eventually(logger).Should(gbytes.Say(`endpoint-iterator-with-round-robin-lb-algo`)) }) @@ -540,9 +556,14 @@ var _ = Describe("EndpointPool", func() { It("marks the endpoint as failed", func() { az := "meow-zone" locallyOptimistic := false + routingProps := route.RoutingProperties{ + LocallyOptimistic: locallyOptimistic, + GlobalLB: config.LOAD_BALANCE_RR, + AZ: az, + } connectionResetError := &net.OpError{Op: "read", Err: errors.New("read: connection reset by peer")} pool.EndpointFailed(failedEndpoint, connectionResetError) - i := pool.Endpoints(logger.Logger, "", false, locallyOptimistic, az, config.LOAD_BALANCE_RR, nil) + i := pool.Endpoints(logger.Logger, "", false, routingProps) epOne := i.Next(0) epTwo := i.Next(1) Expect(epOne).To(Equal(epTwo)) From 560f9358e371a229a8520f5b638790772cadbc1b Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Fri, 16 Jan 2026 11:36:12 +0100 Subject: [PATCH 13/17] Apply review feedback --- .../gorouter/route/pool.go | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index f4da470b7..15efcffb5 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -480,33 +480,16 @@ func (p *EndpointPool) removeEndpoint(e *endpointElem) { } func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, routingProps RoutingProperties) EndpointIterator { - // For hash-based routing, validate inputs and get header value - if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { - headerValue := p.hashBasedInputsValid(routingProps.RequestHeaders, p.HashRoutingProperties, logger) - if headerValue == "" { - return p.createIterator(routingProps.GlobalLB, logger, initial, mustBeSticky, routingProps) + lbAlgo := p.LoadBalancingAlgorithm + // Handle hash-based routing as special case + if lbAlgo == config.LOAD_BALANCE_HB { + headerValue := p.GetValidHashHeaderValue(routingProps.RequestHeaders, logger) + if headerValue != "" { + return NewHashBased(logger, p, initial, mustBeSticky, headerValue) } - return NewHashBased(logger, p, initial, mustBeSticky, headerValue) - } - - return p.createIterator(p.LoadBalancingAlgorithm, logger, initial, mustBeSticky, routingProps) -} - -func (p *EndpointPool) hashBasedInputsValid(header *http.Header, hashProps *HashRoutingProperties, logger *slog.Logger) string { - if hashProps == nil || 
hashProps.Header == "" { - logger.Error("hash-routing-properties-missing", slog.String("host", p.Host())) - return "" + lbAlgo = routingProps.GlobalLB } - hashHeader := header.Get(hashProps.Header) - if hashHeader == "" { - logger.Info("hash-based-routing-header-value-not-found", - slog.String("Host", p.host), - slog.String("Path", p.contextPath)) - } - return hashHeader -} -func (p *EndpointPool) createIterator(lbAlgo string, logger *slog.Logger, initial string, mustBeSticky bool, routingProps RoutingProperties) EndpointIterator { switch lbAlgo { case config.LOAD_BALANCE_LC: logger.Debug("endpoint-iterator-with-least-connection-lb-algo") @@ -524,6 +507,22 @@ func (p *EndpointPool) createIterator(lbAlgo string, logger *slog.Logger, initia } } +func (p *EndpointPool) GetValidHashHeaderValue(header *http.Header, logger *slog.Logger) string { + if p.HashRoutingProperties == nil || p.HashRoutingProperties.Header == "" { + logger.Error("hash-routing-properties-missing", slog.String("host", p.Host())) + return "" + } + + hashHeader := header.Get(p.HashRoutingProperties.Header) + if hashHeader == "" { + logger.Info("hash-based-routing-header-value-not-found", + slog.String("Host", p.host), + slog.String("Path", p.contextPath)) + return "" + } + return hashHeader +} + func (p *EndpointPool) NumEndpoints() int { p.Lock() defer p.Unlock() From b6811affc564503d9f8a8c1aee60a8f3871fb31e Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Mon, 19 Jan 2026 10:48:09 +0100 Subject: [PATCH 14/17] Refactor hash-based.go --- .../gorouter/route/hash_based.go | 43 +++++++++---------- .../gorouter/route/hash_based_test.go | 33 ++------------ 2 files changed, 25 insertions(+), 51 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go index 82bca1eef..81e4143db 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -71,7 +71,7 @@ func (h *HashBased) Next(attempt int) *Endpoint { } if h.pool.HashLookupTable == nil { - h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Lookup table is empty"))) + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("lookup table is empty"))) return nil } @@ -128,14 +128,17 @@ func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { endpointElem := h.pool.findById(id) if endpointElem == nil { - h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Endpoint not found in pool")), slog.String("endpoint-id", id)) + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("endpoint not found in pool")), slog.String("endpoint-id", id)) currentIndex = (currentIndex + 1) % lookupTableSize continue } lastEndpointPrivateId = id - if h.pool.HashRoutingProperties.BalanceFactor <= 0 || !h.isImbalancedOrOverloaded(endpointElem) { + if endpointElem.isOverloaded() { + // If the selected endpoint has reached the limit of max request per backend, log the info about it and try the next one in the lookup table + h.logger.Info("hash-based-routing-endpoint-overloaded", slog.String("host", h.pool.host), slog.String("endpoint-id", endpointElem.endpoint.PrivateInstanceId)) + } else if h.pool.HashRoutingProperties.BalanceFactor <= 0 || !h.IsImbalanced(endpointElem.endpoint) { h.lastLookupTableIndex = currentIndex return endpointElem.endpoint } @@ 
-143,28 +146,22 @@ func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { currentIndex = (currentIndex + 1) % lookupTableSize } // All endpoints checked and overloaded or not found - h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("All endpoints are overloaded"))) + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("all endpoints are overloaded"))) return nil } -func (h *HashBased) isImbalancedOrOverloaded(e *endpointElem) bool { - endpoint := e.endpoint - return h.IsImbalancedOrOverloaded(endpoint, e.isOverloaded()) -} - -func (h *HashBased) IsImbalancedOrOverloaded(endpoint *Endpoint, isEndpointOverloaded bool) bool { +func (h *HashBased) IsImbalanced(endpoint *Endpoint) bool { avgNumberOfInFlightRequests := h.CalculateAverageLoad() + // Check if avgNumberOfInFlightRequests is 0 to avoid division by 0 in the next if-condition + if avgNumberOfInFlightRequests == 0 { + return false + } + currentInFlightRequestCount := endpoint.Stats.NumberConnections.Count() balanceFactor := h.pool.HashRoutingProperties.BalanceFactor - if isEndpointOverloaded { - h.logger.Debug("hash-based-routing-endpoint-overloaded", slog.String("host", h.pool.host), slog.String("endpoint-id", endpoint.PrivateInstanceId), slog.Int64("endpoint-connections", currentInFlightRequestCount)) - return true - } - - // Check if avgNumberOfInFlightRequests is 0 to avoid division by 0 - if avgNumberOfInFlightRequests != 0 && float64(currentInFlightRequestCount)/avgNumberOfInFlightRequests > balanceFactor { - h.logger.Debug("hash-based-routing-endpoint-imbalanced", slog.String("host", h.pool.host), slog.String("endpoint-id", endpoint.PrivateInstanceId), slog.Int64("endpoint-connections", endpoint.Stats.NumberConnections.Count()), slog.Float64("average-load", avgNumberOfInFlightRequests)) + if float64(currentInFlightRequestCount)/avgNumberOfInFlightRequests > balanceFactor { + h.logger.Debug("hash-based-routing-endpoint-imbalanced", slog.String("host", h.pool.host), slog.String("endpoint-id", endpoint.PrivateInstanceId), slog.Int64("endpoint-connections", currentInFlightRequestCount), slog.Float64("average-load", avgNumberOfInFlightRequests)) return true } return false @@ -187,12 +184,16 @@ func (h *HashBased) findEndpointIfStickySession() *Endpoint { } if e == nil && h.mustBeSticky { - h.logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", h.stickyEndpointID)) + if h.logger.Enabled(context.Background(), slog.LevelDebug) { + h.logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", h.stickyEndpointID)) + } return nil } if !h.mustBeSticky { - h.logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", h.stickyEndpointID)) + if h.logger.Enabled(context.Background(), slog.LevelDebug) { + h.logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", h.stickyEndpointID)) + } h.stickyEndpointID = "" } } @@ -230,9 +231,7 @@ func (h *HashBased) CalculateAverageLoad() float64 { var currentInFlightRequestCount int64 for _, endpointElem := range h.pool.endpoints { - endpointElem.RLock() currentInFlightRequestCount += endpointElem.endpoint.Stats.NumberConnections.Count() - endpointElem.RUnlock() } return float64(currentInFlightRequestCount) / float64(len(h.pool.endpoints)) diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go 
b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go index b3e93b4ce..75f75e4d9 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go @@ -268,13 +268,6 @@ var _ = Describe("HashBased", func() { } }) - It("mark the endpoint as overloaded", func() { - for i := 0; i < 500; i++ { - iter.PreRequest(e1) - } - // in general 500 in flight requests counted by e1 - Expect(iter.IsImbalancedOrOverloaded(e1, true)).To(BeTrue()) - }) It("do not mark as imbalanced if every endpoint has 499 in-flight requests", func() { for i := 0; i < 498; i++ { iter.PreRequest(e1) @@ -286,25 +279,7 @@ var _ = Describe("HashBased", func() { iter.PreRequest(e3) } // in general 500 in flight requests counted by e1 - Expect(iter.IsImbalancedOrOverloaded(e1, false)).To(BeFalse()) - }) - - It("mark endpoint as overloaded if every endpoint has 500 in-flight requests", func() { - for i := 0; i < 499; i++ { - iter.PreRequest(e1) - } - for i := 0; i < 499; i++ { - iter.PreRequest(e2) - } - for i := 0; i < 499; i++ { - iter.PreRequest(e3) - } - // in general 500 in flight requests counted by e1 - Expect(iter.IsImbalancedOrOverloaded(e1, true)).To(BeTrue()) - Eventually(logger).Should(gbytes.Say("hash-based-routing-endpoint-overloaded")) - Expect(iter.IsImbalancedOrOverloaded(e2, true)).To(BeTrue()) - Expect(iter.IsImbalancedOrOverloaded(e3, true)).To(BeTrue()) - + Expect(iter.IsImbalanced(e1)).To(BeFalse()) }) It("mark as imbalanced if it has more in-flight requests", func() { for i := 0; i < 300; i++ { @@ -316,10 +291,10 @@ var _ = Describe("HashBased", func() { for i := 0; i < 200; i++ { iter.PreRequest(e3) } - Expect(iter.IsImbalancedOrOverloaded(e1, false)).To(BeTrue()) + Expect(iter.IsImbalanced(e1)).To(BeTrue()) Eventually(logger).Should(gbytes.Say("hash-based-routing-endpoint-imbalanced")) - Expect(iter.IsImbalancedOrOverloaded(e2, false)).To(BeFalse()) - Expect(iter.IsImbalancedOrOverloaded(e3, false)).To(BeFalse()) + Expect(iter.IsImbalanced(e2)).To(BeFalse()) + Expect(iter.IsImbalanced(e3)).To(BeFalse()) }) }) }) From 6b7daaab57031f9d3bfb7b3d7c089de31e2854cc Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Mon, 19 Jan 2026 11:40:53 +0100 Subject: [PATCH 15/17] refactor looking for endpoint for sticky session --- .../gorouter/route/hash_based.go | 57 ++++--------------- .../gorouter/route/leastconnection.go | 46 ++++----------- .../gorouter/route/pool.go | 43 ++++++++++++++ .../gorouter/route/roundrobin.go | 46 ++++----------- 4 files changed, 79 insertions(+), 113 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go index 81e4143db..0a82e4234 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -1,7 +1,6 @@ package route import ( - "context" "errors" "log/slog" "sync" @@ -49,16 +48,18 @@ func (h *HashBased) Next(attempt int) *Endpoint { h.lock.Lock() defer h.lock.Unlock() - endpoint := h.findEndpointIfStickySession() - if endpoint == nil && h.mustBeSticky { - return nil - } + endpoint := h.pool.FindStickyEndpoint(h.logger, &h.stickyEndpointID, h.mustBeSticky) if endpoint != nil { h.lastEndpoint = endpoint return endpoint } + if h.mustBeSticky { + return nil + } + + // Check for empty pool if len(h.pool.endpoints) == 0 { h.logger.Warn("hash-based-routing-pool-empty", slog.String("host", h.pool.host)) return nil @@ -138,7 +139,7 @@ func (h *HashBased) 
findEndpoint(index uint64, attempt int) *Endpoint { if endpointElem.isOverloaded() { // If the selected endpoint has reached the limit of max request per backend, log the info about it and try the next one in the lookup table h.logger.Info("hash-based-routing-endpoint-overloaded", slog.String("host", h.pool.host), slog.String("endpoint-id", endpointElem.endpoint.PrivateInstanceId)) - } else if h.pool.HashRoutingProperties.BalanceFactor <= 0 || !h.IsImbalanced(endpointElem.endpoint) { + } else if !h.IsImbalanced(endpointElem.endpoint) { h.lastLookupTableIndex = currentIndex return endpointElem.endpoint } @@ -151,6 +152,11 @@ func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { } func (h *HashBased) IsImbalanced(endpoint *Endpoint) bool { + // endpoint cannot be imbalanced if balance factor is not set + if h.pool.HashRoutingProperties.BalanceFactor <= 0 { + return false + } + avgNumberOfInFlightRequests := h.CalculateAverageLoad() // Check if avgNumberOfInFlightRequests is 0 to avoid division by 0 in the next if-condition if avgNumberOfInFlightRequests == 0 { @@ -167,45 +173,6 @@ func (h *HashBased) IsImbalanced(endpoint *Endpoint) bool { return false } -// findEndpointIfStickySession checks if there is a sticky session endpoint and returns it if available. -// If the sticky session endpoint is overloaded, returns nil. -func (h *HashBased) findEndpointIfStickySession() *Endpoint { - var e *endpointElem - if h.stickyEndpointID != "" { - e = h.pool.findById(h.stickyEndpointID) - if e != nil && e.isOverloaded() { - if h.mustBeSticky { - if h.logger.Enabled(context.Background(), slog.LevelDebug) { - h.logger.Debug("endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...) - } - return nil - } - e = nil - } - - if e == nil && h.mustBeSticky { - if h.logger.Enabled(context.Background(), slog.LevelDebug) { - h.logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", h.stickyEndpointID)) - } - return nil - } - - if !h.mustBeSticky { - if h.logger.Enabled(context.Background(), slog.LevelDebug) { - h.logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", h.stickyEndpointID)) - } - h.stickyEndpointID = "" - } - } - - if e != nil { - e.RLock() - defer e.RUnlock() - return e.endpoint - } - return nil -} - // EndpointFailed notifies the endpoint pool that the last selected endpoint has failed. func (h *HashBased) EndpointFailed(err error) { if h.lastEndpoint != nil { diff --git a/src/code.cloudfoundry.org/gorouter/route/leastconnection.go b/src/code.cloudfoundry.org/gorouter/route/leastconnection.go index d538b65a4..ee288f2ac 100644 --- a/src/code.cloudfoundry.org/gorouter/route/leastconnection.go +++ b/src/code.cloudfoundry.org/gorouter/route/leastconnection.go @@ -1,7 +1,6 @@ package route import ( - "context" "log/slog" "math/rand" "time" @@ -31,43 +30,22 @@ func NewLeastConnection(logger *slog.Logger, p *EndpointPool, initial string, mu } func (r *LeastConnection) Next(attempt int) *Endpoint { - var e *endpointElem - if r.initialEndpoint != "" { - e = r.pool.findById(r.initialEndpoint) - if e != nil && e.isOverloaded() { - if r.mustBeSticky { - if r.logger.Enabled(context.Background(), slog.LevelDebug) { - r.logger.Debug("endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...) 
- } - return nil - } - e = nil - } - - if e == nil && r.mustBeSticky { - r.logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", r.initialEndpoint)) - return nil - } - - if !r.mustBeSticky { - r.logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", r.initialEndpoint)) - r.initialEndpoint = "" - } + e := r.pool.FindStickyEndpoint(r.logger, &r.initialEndpoint, r.mustBeSticky) + if e != nil { + r.lastEndpoint = e + return e } - if e != nil { - e.RLock() - defer e.RUnlock() - r.lastEndpoint = e.endpoint - return e.endpoint + if r.mustBeSticky { + return nil } - e = r.next(attempt) - if e != nil { - e.RLock() - defer e.RUnlock() - r.lastEndpoint = e.endpoint - return e.endpoint + endpointElem := r.next(attempt) + if endpointElem != nil { + endpointElem.RLock() + defer endpointElem.RUnlock() + r.lastEndpoint = endpointElem.endpoint + return endpointElem.endpoint } r.lastEndpoint = nil diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index 15efcffb5..38044b253 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -1,6 +1,7 @@ package route import ( + "context" "encoding/json" "fmt" "log/slog" @@ -535,6 +536,48 @@ func (p *EndpointPool) findById(id string) *endpointElem { return p.index[id] } +// FindStickyEndpoint attempts to find and return a sticky session endpoint. +// If the endpoint is found and not overloaded, it returns the endpoint. +// If mustBeSticky is true and the endpoint is missing or overloaded, it returns nil. +// If mustBeSticky is false and the endpoint is missing or overloaded, it clears the stickyEndpointID and returns nil. +// The stickyEndpointID pointer is modified in place when the endpoint is not sticky. +func (p *EndpointPool) FindStickyEndpoint(logger *slog.Logger, stickyEndpointID *string, mustBeSticky bool) *Endpoint { + var e *endpointElem + if *stickyEndpointID != "" { + e = p.findById(*stickyEndpointID) + if e != nil && e.isOverloaded() { + if mustBeSticky { + if logger.Enabled(context.Background(), slog.LevelDebug) { + logger.Debug("endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...) 
+ } + return nil + } + e = nil + } + + if e == nil && mustBeSticky { + if logger.Enabled(context.Background(), slog.LevelDebug) { + logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", *stickyEndpointID)) + } + return nil + } + + if !mustBeSticky { + if logger.Enabled(context.Background(), slog.LevelDebug) { + logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", *stickyEndpointID)) + } + *stickyEndpointID = "" + } + } + + if e != nil { + e.RLock() + defer e.RUnlock() + return e.endpoint + } + return nil +} + func (p *EndpointPool) IsEmpty() bool { p.Lock() l := len(p.endpoints) diff --git a/src/code.cloudfoundry.org/gorouter/route/roundrobin.go b/src/code.cloudfoundry.org/gorouter/route/roundrobin.go index 9af2735a3..f9820fa0f 100644 --- a/src/code.cloudfoundry.org/gorouter/route/roundrobin.go +++ b/src/code.cloudfoundry.org/gorouter/route/roundrobin.go @@ -1,7 +1,6 @@ package route import ( - "context" "log/slog" "sync" "time" @@ -38,43 +37,22 @@ func (r *RoundRobin) Next(attempt int) *Endpoint { r.lock.Lock() defer r.lock.Unlock() - var e *endpointElem - if r.initialEndpoint != "" { - e = r.pool.findById(r.initialEndpoint) - if e != nil && e.isOverloaded() { - if r.mustBeSticky { - if r.logger.Enabled(context.Background(), slog.LevelDebug) { - r.logger.Debug("endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...) - } - return nil - } - e = nil - } - - if e == nil && r.mustBeSticky { - r.logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", r.initialEndpoint)) - return nil - } - - if !r.mustBeSticky { - r.logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", r.initialEndpoint)) - r.initialEndpoint = "" - } + e := r.pool.FindStickyEndpoint(r.logger, &r.initialEndpoint, r.mustBeSticky) + if e != nil { + r.lastEndpoint = e + return e } - if e != nil { - e.RLock() - defer e.RUnlock() - r.lastEndpoint = e.endpoint - return e.endpoint + if r.mustBeSticky { + return nil } - e = r.next(attempt) - if e != nil { - e.RLock() - defer e.RUnlock() - r.lastEndpoint = e.endpoint - return e.endpoint + endpointElem := r.next(attempt) + if endpointElem != nil { + endpointElem.RLock() + defer endpointElem.RUnlock() + r.lastEndpoint = endpointElem.endpoint + return endpointElem.endpoint } r.lastEndpoint = nil From c518e14f93d8d9bdc43338ecd8ce26fe9a534ad3 Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Mon, 19 Jan 2026 12:07:22 +0100 Subject: [PATCH 16/17] Refactor hash_based.go --- .../gorouter/route/hash_based.go | 49 +++++++++++-------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go index 0a82e4234..752a8b012 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -49,7 +49,6 @@ func (h *HashBased) Next(attempt int) *Endpoint { defer h.lock.Unlock() endpoint := h.pool.FindStickyEndpoint(h.logger, &h.stickyEndpointID, h.mustBeSticky) - if endpoint != nil { h.lastEndpoint = endpoint return endpoint @@ -71,34 +70,42 @@ func (h *HashBased) Next(attempt int) *Endpoint { return endpoint } + // Perform hash-based selection + endpoint = h.selectHashBasedEndpoint(attempt) + if endpoint != nil { + h.lastEndpoint = endpoint + } + return endpoint +} + +// selectHashBasedEndpoint performs hash-based endpoint selection using the lookup table. 
+func (h *HashBased) selectHashBasedEndpoint(attempt int) *Endpoint { if h.pool.HashLookupTable == nil { h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("lookup table is empty"))) return nil } - if attempt == 0 || h.lastLookupTableIndex == 0 { - initialLookupTableIndex, _, err := h.pool.HashLookupTable.GetInstanceForHashHeader(h.HeaderValue) - - if err != nil { - h.logger.Error( - "hash-based-routing-failed", - slog.String("host", h.pool.host), - log.ErrAttr(err), - ) - return nil - } - - endpoint = h.findEndpoint(initialLookupTableIndex, attempt) - } else { - // On retries, start looking from the next index in the lookup table - nextIndex := (h.lastLookupTableIndex + 1) % h.pool.HashLookupTable.GetLookupTableSize() - endpoint = h.findEndpoint(nextIndex, attempt) + startIndex, err := h.getStartingIndex(attempt) + if err != nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(err)) + return nil } - if endpoint != nil { - h.lastEndpoint = endpoint + return h.findEndpoint(startIndex, attempt) +} + +// getStartingIndex determines the starting index in the lookup table based on the attempt number. +// For the initial attempt, it uses the hash of the header value. +// For retries, it uses the next index after the last lookup. +func (h *HashBased) getStartingIndex(attempt int) (uint64, error) { + if attempt == 0 || h.lastLookupTableIndex == 0 { + index, _, err := h.pool.HashLookupTable.GetInstanceForHashHeader(h.HeaderValue) + return index, err } - return endpoint + + // On retries, start from the next index in the lookup table + nextIndex := (h.lastLookupTableIndex + 1) % h.pool.HashLookupTable.GetLookupTableSize() + return nextIndex, nil } func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { From b414a383c6bfe4eccf8104c89ef053d7422ce843 Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Thu, 22 Jan 2026 09:55:00 +0100 Subject: [PATCH 17/17] Add function for debug logs --- .../gorouter/route/pool.go | 55 ++++++++++--------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index 38044b253..a044e4905 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -493,17 +493,17 @@ func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeStic switch lbAlgo { case config.LOAD_BALANCE_LC: - logger.Debug("endpoint-iterator-with-least-connection-lb-algo") + logDebugIfEnabled(logger, "endpoint-iterator-with-least-connection-lb-algo") return NewLeastConnection(logger, p, initial, mustBeSticky, routingProps.LocallyOptimistic, routingProps.AZ) case config.LOAD_BALANCE_RR: - logger.Debug("endpoint-iterator-with-round-robin-lb-algo") + logDebugIfEnabled(logger, "endpoint-iterator-with-round-robin-lb-algo") return NewRoundRobin(logger, p, initial, mustBeSticky, routingProps.LocallyOptimistic, routingProps.AZ) default: logger.Error("invalid-pool-load-balancing-algorithm", slog.String("poolLBAlgorithm", lbAlgo), slog.String("Host", p.host), slog.String("Path", p.contextPath)) - logger.Debug("endpoint-iterator-with-round-robin-lb-algo") + logDebugIfEnabled(logger, "endpoint-iterator-with-round-robin-lb-algo") return NewRoundRobin(logger, p, initial, mustBeSticky, routingProps.LocallyOptimistic, routingProps.AZ) } } @@ -542,32 +542,30 @@ func (p *EndpointPool) findById(id string) *endpointElem { // If 
mustBeSticky is false and the endpoint is missing or overloaded, it clears the stickyEndpointID and returns nil. // The stickyEndpointID pointer is modified in place when the endpoint is not sticky. func (p *EndpointPool) FindStickyEndpoint(logger *slog.Logger, stickyEndpointID *string, mustBeSticky bool) *Endpoint { - var e *endpointElem - if *stickyEndpointID != "" { - e = p.findById(*stickyEndpointID) - if e != nil && e.isOverloaded() { - if mustBeSticky { - if logger.Enabled(context.Background(), slog.LevelDebug) { - logger.Debug("endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...) - } - return nil - } - e = nil - } + if *stickyEndpointID == "" { + return nil + } - if e == nil && mustBeSticky { - if logger.Enabled(context.Background(), slog.LevelDebug) { - logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", *stickyEndpointID)) - } + var e *endpointElem + e = p.findById(*stickyEndpointID) + if e != nil && e.isOverloaded() { + if mustBeSticky { + logDebugIfEnabled(logger, "endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...) return nil } + e = nil + } - if !mustBeSticky { - if logger.Enabled(context.Background(), slog.LevelDebug) { - logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", *stickyEndpointID)) - } - *stickyEndpointID = "" + if e == nil && mustBeSticky { + logDebugIfEnabled(logger, "endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", *stickyEndpointID)) + return nil + } + + if !mustBeSticky { + if e == nil { + logDebugIfEnabled(logger, "endpoint-missing-choosing-alternate", slog.String("requested-endpoint", *stickyEndpointID)) } + *stickyEndpointID = "" } if e != nil { @@ -578,6 +576,13 @@ func (p *EndpointPool) FindStickyEndpoint(logger *slog.Logger, stickyEndpointID return nil } +// logDebugIfEnabled logs a debug message only if debug level is enabled +func logDebugIfEnabled(logger *slog.Logger, msg string, args ...any) { + if logger.Enabled(context.Background(), slog.LevelDebug) { + logger.Debug(msg, args...) + } +} + func (p *EndpointPool) IsEmpty() bool { p.Lock() l := len(p.endpoints) @@ -682,7 +687,7 @@ func (p *EndpointPool) setPoolLoadBalancingAlgorithm(endpoint *Endpoint) { if endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { if config.IsLoadBalancingAlgorithmValid(endpoint.LoadBalancingAlgorithm) { p.LoadBalancingAlgorithm = endpoint.LoadBalancingAlgorithm - p.logger.Debug("setting-pool-load-balancing-algorithm-to-that-of-an-endpoint", + logDebugIfEnabled(p.logger, "setting-pool-load-balancing-algorithm-to-that-of-an-endpoint", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm))
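
For reference, a minimal, self-contained sketch of the balance-factor rule that IsImbalanced applies before an endpoint taken from the hash lookup table is used. This is an illustration only, not gorouter code: the function and variable names (isImbalanced, endpointLoad, avgLoad, balanceFactor) are placeholders, and the numbers mirror the 300/100/200 in-flight distribution exercised in hash_based_test.go with an assumed balance factor of 1.25.

    package main

    import "fmt"

    // isImbalanced mirrors the rule from hash_based.go: an endpoint is treated as
    // imbalanced when its in-flight request count exceeds balanceFactor times the
    // average load across the pool. A balanceFactor <= 0 disables the check, and a
    // zero average short-circuits to avoid division by zero.
    func isImbalanced(endpointLoad int64, avgLoad float64, balanceFactor float64) bool {
        if balanceFactor <= 0 {
            return false
        }
        if avgLoad == 0 {
            return false
        }
        return float64(endpointLoad)/avgLoad > balanceFactor
    }

    func main() {
        // Three endpoints with 300, 100 and 200 in-flight requests give an
        // average load of 200. With a balance factor of 1.25, the endpoint at
        // 300 is skipped (300/200 = 1.5 > 1.25); the other two are accepted.
        fmt.Println(isImbalanced(300, 200, 1.25)) // true
        fmt.Println(isImbalanced(100, 200, 1.25)) // false
        fmt.Println(isImbalanced(200, 200, 1.25)) // false
    }

When an endpoint is skipped for being imbalanced (or overloaded), findEndpoint advances to the next index in the lookup table, so requests with the same hash header value still land on a deterministic fallback endpoint rather than being rebalanced randomly.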