diff --git a/proxy.go b/proxy.go
index 72ec90bb..e431afb9 100644
--- a/proxy.go
+++ b/proxy.go
@@ -238,7 +238,12 @@ func executeWithRetry(
 			// comment s.host.dec() line to avoid double increment; issue #322
 			// s.host.dec()
 			s.host.SetIsActive(false)
-			nextHost := s.cluster.getHost()
+			var nextHost *topology.Node
+			if s.replicaNum > 0 || s.nodeNum > 0 {
+				nextHost = s.cluster.getSpecificHost(s.replicaNum, s.nodeNum)
+			} else {
+				nextHost = s.cluster.getHost()
+			}
 			// The query could be retried if it has no stickiness to a certain server
 			if numRetry < maxRetry && nextHost.IsActive() && s.sessionId == "" {
 				// the query execution has been failed
@@ -917,6 +922,11 @@ func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) {
 		return nil, http.StatusForbidden, fmt.Errorf("cluster user %q is not allowed to access", cu.name)
 	}
 
-	s := newScope(req, u, c, cu, sessionId, sessionTimeout)
+	replicaNum, nodeNum, err := getSpecificHostNum(req, c)
+	if err != nil {
+		return nil, http.StatusBadRequest, err
+	}
+
+	s := newScope(req, u, c, cu, sessionId, sessionTimeout, replicaNum, nodeNum)
 	return s, 0, nil
 }
 
diff --git a/scope.go b/scope.go
index ef970565..cca556b3 100644
--- a/scope.go
+++ b/scope.go
@@ -45,6 +45,8 @@ type scope struct {
 	sessionId      string
 	sessionTimeout int
+	replicaNum     int
+	nodeNum        int
 
 	remoteAddr string
 	localAddr  string
@@ -57,10 +59,14 @@ type scope struct {
 	requestPacketSize int
 }
 
-func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int) *scope {
-	h := c.getHost()
+func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int, replicaNum, nodeNum int) *scope {
+	var h *topology.Node
 	if sessionId != "" {
 		h = c.getHostSticky(sessionId)
+	} else if replicaNum > 0 || nodeNum > 0 {
+		h = c.getSpecificHost(replicaNum, nodeNum)
+	} else {
+		h = c.getHost()
 	}
 	var localAddr string
 	if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
@@ -75,6 +81,8 @@ func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId
 		clusterUser:    cu,
 		sessionId:      sessionId,
 		sessionTimeout: sessionTimeout,
+		replicaNum:     replicaNum,
+		nodeNum:        nodeNum,
 
 		remoteAddr: req.RemoteAddr,
 		localAddr:  localAddr,
@@ -185,11 +193,13 @@ func (s *scope) waitUntilAllowStart(sleep time.Duration, deadline time.Time, lab
 	var h *topology.Node
 	// Choose new host, since the previous one may become obsolete
 	// after sleeping.
-	if s.sessionId == "" {
-		h = s.cluster.getHost()
-	} else {
+	if s.sessionId != "" {
 		// if request has session_id, set same host
 		h = s.cluster.getHostSticky(s.sessionId)
+	} else if s.replicaNum > 0 || s.nodeNum > 0 {
+		h = s.cluster.getSpecificHost(s.replicaNum, s.nodeNum)
+	} else {
+		h = s.cluster.getHost()
 	}
 	s.host = h
 
@@ -720,6 +730,8 @@ func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c
 			return nil, err
 		}
 		r.hosts = hosts
+		c.maxNodeNum = len(r.hosts)
+		c.maxReplicaNum = 1
 		return []*replica{r}, nil
 	}
 
@@ -735,7 +747,9 @@ func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c
 		}
 		r.hosts = hosts
 		replicas[i] = r
+		c.maxNodeNum = max(c.maxNodeNum, len(r.hosts))
 	}
+	c.maxReplicaNum = len(replicas)
 	return replicas, nil
 }
 
@@ -775,6 +789,9 @@ type cluster struct {
 	replicas       []*replica
 	nextReplicaIdx uint32
 
+	maxReplicaNum int
+	maxNodeNum    int
+
 	users map[string]*clusterUser
 
 	killQueryUserName string
@@ -937,6 +954,59 @@ func (r *replica) getHostSticky(sessionId string) *topology.Node {
 	return h
 }
 
+// getSpecificReplica returns the replica selected by the 1-based replicaNum,
+// or, when replicaNum is 0, the least loaded replica with at least nodeNum hosts.
+//
+// Always returns non-nil.
+func (c *cluster) getSpecificReplica(replicaNum, nodeNum int) *replica {
+	if replicaNum > 0 {
+		return c.replicas[replicaNum-1]
+	}
+	if nodeNum == 0 {
+		return c.getReplica()
+	}
+
+	idx := atomic.AddUint32(&c.nextReplicaIdx, 1)
+	n := uint32(len(c.replicas))
+	if n == 1 {
+		return c.replicas[0]
+	}
+
+	var r *replica
+	reqs := ^uint32(0)
+
+	// Scan all the replicas for the least loaded replica that has
+	// at least nodeNum hosts.
+	for i := uint32(0); i < n; i++ {
+		tmpIdx := (idx + i) % n
+		tmpR := c.replicas[tmpIdx]
+		if nodeNum > len(tmpR.hosts) {
+			continue
+		}
+		if tmpR.isActive() || r == nil {
+			tmpReqs := tmpR.load()
+			if r == nil || tmpReqs < reqs || !r.isActive() {
+				r = tmpR
+				reqs = tmpReqs
+			}
+		}
+	}
+
+	// The returned replica may be inactive. This is OK, since it means all
+	// the replicas with at least nodeNum hosts are inactive, so let's try
+	// proxying the request to one of them anyway.
+	return r
+}
+
+// getSpecificHost returns the host selected by the 1-based nodeNum from the replica.
+//
+// Always returns non-nil.
+func (r *replica) getSpecificHost(nodeNum int) *topology.Node {
+	if nodeNum > 0 {
+		return r.hosts[nodeNum-1]
+	}
+
+	return r.getHost()
+}
+
 // getHost returns least loaded + round-robin host from replica.
 //
 // Always returns non-nil.
@@ -991,6 +1061,16 @@ func (c *cluster) getHostSticky(sessionId string) *topology.Node {
 	return r.getHostSticky(sessionId)
 }
 
+// getSpecificHost returns the host selected by replicaNum/nodeNum from the cluster.
+// Both nums are 1-based; valid values are [0, maxReplicaNum] and [0, maxNodeNum],
+// where 0 means no specific num. If both are 0, it is equivalent to getHost.
+//
+// Always returns non-nil.
+func (c *cluster) getSpecificHost(replicaNum, nodeNum int) *topology.Node {
+	r := c.getSpecificReplica(replicaNum, nodeNum)
+	return r.getSpecificHost(nodeNum)
+}
+
 // getHost returns least loaded + round-robin host from cluster.
 //
 // Always returns non-nil.
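Note for reviewers (not part of the diff): newScope and waitUntilAllowStart now share a three-way host-selection precedence, and executeWithRetry applies the replicaNum/nodeNum branch when picking a retry host. A minimal standalone sketch of that precedence follows; pickHost is a hypothetical stand-in, not chproxy's real API:

    package main

    import "fmt"

    // pickHost mirrors the precedence used in newScope and
    // waitUntilAllowStart: session stickiness wins, then an explicit
    // replica/node num, then the default least-loaded choice.
    func pickHost(sessionId string, replicaNum, nodeNum int) string {
    	switch {
    	case sessionId != "":
    		return fmt.Sprintf("getHostSticky(%q)", sessionId)
    	case replicaNum > 0 || nodeNum > 0:
    		return fmt.Sprintf("getSpecificHost(%d, %d)", replicaNum, nodeNum)
    	default:
    		return "getHost()"
    	}
    }

    func main() {
    	fmt.Println(pickHost("abc", 1, 2)) // stickiness overrides explicit nums
    	fmt.Println(pickHost("", 2, 0))    // pin to replica 2, any node
    	fmt.Println(pickHost("", 0, 0))    // default least-loaded routing
    }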
diff --git a/scope_test.go b/scope_test.go
index 438966c6..007fe5d5 100644
--- a/scope_test.go
+++ b/scope_test.go
@@ -410,6 +410,78 @@ func TestGetHostSticky(t *testing.T) {
 	}
 }
 
+func TestGetSpecificHost(t *testing.T) {
+	c := testGetCluster()
+
+	t.Run("SpecifyReplicaNum", func(t *testing.T) {
+		h := c.getSpecificHost(1, 0)
+		if h.Host() != "127.0.0.11" && h.Host() != "127.0.0.22" {
+			t.Fatalf("Expected host from replica1, got: %s", h.Host())
+		}
+
+		h = c.getSpecificHost(2, 0)
+		if h.Host() != "127.0.0.33" && h.Host() != "127.0.0.44" {
+			t.Fatalf("Expected host from replica2, got: %s", h.Host())
+		}
+
+		h = c.getSpecificHost(3, 0)
+		if h.Host() != "127.0.0.55" && h.Host() != "127.0.0.66" {
+			t.Fatalf("Expected host from replica3, got: %s", h.Host())
+		}
+	})
+
+	t.Run("SpecifyNodeNum", func(t *testing.T) {
+		h := c.getSpecificHost(0, 1)
+		if h.Host() != "127.0.0.11" && h.Host() != "127.0.0.33" && h.Host() != "127.0.0.55" {
+			t.Fatalf("Expected first node from any replica, got: %s", h.Host())
+		}
+
+		h = c.getSpecificHost(0, 2)
+		if h.Host() != "127.0.0.22" && h.Host() != "127.0.0.44" && h.Host() != "127.0.0.66" {
+			t.Fatalf("Expected second node from any replica, got: %s", h.Host())
+		}
+	})
+
+	t.Run("SpecifyReplicaNumAndNodeNum", func(t *testing.T) {
+		h := c.getSpecificHost(1, 1)
+		if h.Host() != "127.0.0.11" {
+			t.Fatalf("Expected 127.0.0.11, got: %s", h.Host())
+		}
+
+		h = c.getSpecificHost(1, 2)
+		if h.Host() != "127.0.0.22" {
+			t.Fatalf("Expected 127.0.0.22, got: %s", h.Host())
+		}
+
+		h = c.getSpecificHost(2, 1)
+		if h.Host() != "127.0.0.33" {
+			t.Fatalf("Expected 127.0.0.33, got: %s", h.Host())
+		}
+	})
+
+	t.Run("SpecifyBothNumsZero", func(t *testing.T) {
+		h := c.getSpecificHost(0, 0)
+		if h == nil {
+			t.Fatalf("getSpecificHost(0, 0) returned nil")
+		}
+		found := false
+		for _, r := range c.replicas {
+			for _, node := range r.hosts {
+				if h.Host() == node.Host() {
+					found = true
+					break
+				}
+			}
+			if found {
+				break
+			}
+		}
+		if !found {
+			t.Fatalf("getSpecificHost(0, 0) returned unknown host: %s", h.Host())
+		}
+	})
+}
+
 func TestIncQueued(t *testing.T) {
 	u := testGetUser()
 	cu := testGetClusterUser()
@@ -485,6 +557,12 @@ func testGetCluster() *cluster {
 		topology.NewNode(&url.URL{Host: "127.0.0.66"}, nil, "", r3.name, topology.WithDefaultActiveState(true)),
 	}
 	r3.name = "replica3"
+
+	c.maxReplicaNum = len(c.replicas)
+	for _, r := range c.replicas {
+		c.maxNodeNum = max(c.maxNodeNum, len(r.hosts))
+	}
+
 	return c
 }
 
diff --git a/utils.go b/utils.go
index c3d97d14..067be860 100644
--- a/utils.go
+++ b/utils.go
@@ -63,6 +63,53 @@ func getSessionTimeout(req *http.Request) int {
 	return 60
 }
 
+// getSpecificHostNum retrieves the specific host nums (replica num and node num)
+// from the request query string. Nums start from 1; 0 means no specific num.
+// shard_num is an alias for node_num and overrides node_num if both are specified.
+func getSpecificHostNum(req *http.Request, c *cluster) (int, int, error) {
+	params := req.URL.Query()
+	var replicaNum, nodeNum int
+	var err error
+	// replica num
+	replicaNumStr := params.Get("replica_num")
+	if replicaNumStr != "" {
+		replicaNum, err = strconv.Atoi(replicaNumStr)
+		if err != nil {
+			return -1, -1, fmt.Errorf("invalid replica num %q", replicaNumStr)
+		}
+		if replicaNum < 0 || replicaNum > c.maxReplicaNum {
+			return -1, -1, fmt.Errorf("invalid replica num %q", replicaNumStr)
+		}
+	}
+	// node num (shard_num is an alias for node_num)
+	nodeNumStr := params.Get("node_num")
+	if nodeNumStr != "" {
+		nodeNum, err = strconv.Atoi(nodeNumStr)
+		if err != nil {
+			return -1, -1, fmt.Errorf("invalid node num %q", nodeNumStr)
+		}
+		if nodeNum < 0 || nodeNum > c.maxNodeNum {
+			return -1, -1, fmt.Errorf("invalid node num %q", nodeNumStr)
+		}
+	}
+	shardNumStr := params.Get("shard_num")
+	if shardNumStr != "" {
+		nodeNum, err = strconv.Atoi(shardNumStr)
+		if err != nil {
+			return -1, -1, fmt.Errorf("invalid shard num %q", shardNumStr)
+		}
+		if nodeNum < 0 || nodeNum > c.maxNodeNum {
+			return -1, -1, fmt.Errorf("invalid shard num %q", shardNumStr)
+		}
+	}
+	// validate the combination when both replicaNum and nodeNum are specified
+	if replicaNum > 0 && nodeNum > 0 && nodeNum > len(c.replicas[replicaNum-1].hosts) {
+		return -1, -1, fmt.Errorf("invalid host num (replica_num=%d, node_num=%d)", replicaNum, nodeNum)
+	}
+
+	return replicaNum, nodeNum, nil
+}
+
 // getQuerySnippet returns query snippet.
 //
 // getQuerySnippet must be called only for error reporting.
diff --git a/utils_test.go b/utils_test.go
index 9fbbc31d..bf749e69 100644
--- a/utils_test.go
+++ b/utils_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"compress/gzip"
 	"fmt"
+	"github.com/contentsquare/chproxy/internal/topology"
 	"github.com/stretchr/testify/assert"
 	"net/http"
 	"net/url"
@@ -374,3 +375,150 @@ func TestCalcMapHash(t *testing.T) {
 		})
 	}
 }
+
+func TestGetSpecificHostNum(t *testing.T) {
+	// Create a test cluster with 2 replicas, each having 3 nodes
+	testCluster := &cluster{
+		name: "test_cluster",
+		replicas: []*replica{
+			{
+				name:  "replica1",
+				hosts: []*topology.Node{{}, {}, {}},
+			},
+			{
+				name:  "replica2",
+				hosts: []*topology.Node{{}, {}, {}},
+			},
+		},
+		maxReplicaNum: 2,
+		maxNodeNum:    3,
+	}
+	// Set the cluster reference for each replica
+	for _, r := range testCluster.replicas {
+		r.cluster = testCluster
+	}
+
+	testCases := []struct {
+		name          string
+		params        map[string]string
+		expectedRN    int
+		expectedNN    int
+		expectedError bool
+	}{
+		{
+			"no parameters",
+			map[string]string{},
+			0,
+			0,
+			false,
+		},
+		{
+			"only replica_num",
+			map[string]string{"replica_num": "1"},
+			1,
+			0,
+			false,
+		},
+		{
+			"only node_num",
+			map[string]string{"node_num": "2"},
+			0,
+			2,
+			false,
+		},
+		{
+			"only shard_num",
+			map[string]string{"shard_num": "3"},
+			0,
+			3,
+			false,
+		},
+		{
+			"replica_num and node_num",
+			map[string]string{"replica_num": "1", "node_num": "2"},
+			1,
+			2,
+			false,
+		},
+		{
+			"invalid replica_num",
+			map[string]string{"replica_num": "invalid"},
+			0,
+			0,
+			true,
+		},
+		{
+			"invalid node_num",
+			map[string]string{"node_num": "-1"},
+			0,
+			0,
+			true,
+		},
+		{
+			"replica_num out of range",
+			map[string]string{"replica_num": "3"},
+			0,
+			0,
+			true,
+		},
+		{
+			"node_num out of range",
+			map[string]string{"node_num": "4"},
+			0,
+			0,
+			true,
+		},
+		{
+			"node_num out of range for specific replica",
+			map[string]string{"replica_num": "1", "node_num": "4"},
+			0,
+			0,
+			true,
+		},
+		{
+			"replica_num is zero",
+			map[string]string{"replica_num": "0"},
+			0,
+			0,
+			false,
+		},
+		{
+			"node_num is zero",
+			map[string]string{"node_num": "0"},
+			0,
+			0,
+			false,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			req, err := http.NewRequest("GET", "", nil)
+			checkErr(t, err)
+
+			// Set up the URL parameters
+			params := make(url.Values)
+			for k, v := range tc.params {
+				params.Set(k, v)
+			}
+			req.URL.RawQuery = params.Encode()
+
+			replicaNum, nodeNum, err := getSpecificHostNum(req, testCluster)
+			if tc.expectedError {
+				if err == nil {
+					t.Fatalf("expected error but got none")
+				}
+			} else {
+				if err != nil {
+					t.Fatalf("unexpected error: %v", err)
+				}
+				if replicaNum != tc.expectedRN {
+					t.Fatalf("unexpected replicaNum: got %d, expecting %d", replicaNum, tc.expectedRN)
+				}
+				if nodeNum != tc.expectedNN {
+					t.Fatalf("unexpected nodeNum: got %d, expecting %d", nodeNum, tc.expectedNN)
+				}
+			}
+		})
+	}
+}
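End-to-end usage sketch (not part of the diff): the new parameters are plain query-string arguments, so any HTTP client can pin a query to a specific replica and node. The endpoint, port, and query below are assumptions for illustration only; chproxy deployments vary. shard_num=2 would behave the same as node_num=2, since shard_num is an alias that overrides node_num when both are present.

    package main

    import (
    	"fmt"
    	"io"
    	"net/http"
    	"net/url"
    )

    func main() {
    	// Hypothetical chproxy endpoint; adjust host, port, and credentials
    	// to your deployment.
    	u := url.URL{
    		Scheme: "http",
    		Host:   "127.0.0.1:9090",
    		RawQuery: url.Values{
    			"query":       {"SELECT 1"},
    			"replica_num": {"1"}, // pin to the first replica (1-based)
    			"node_num":    {"2"}, // pin to its second node
    		}.Encode(),
    	}
    	resp, err := http.Get(u.String())
    	if err != nil {
    		panic(err)
    	}
    	defer resp.Body.Close()
    	body, _ := io.ReadAll(resp.Body)
    	fmt.Println(resp.Status, string(body))
    }

An out-of-range num (for example, a replica_num larger than the cluster's replica count) is rejected in getScope with HTTP 400 before any host is chosen, per the getSpecificHostNum validation above.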