diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fce02d2..237a93cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +unreleased +---------- +- Support protocol 3.2, and the `min_protocol_version` and + `max_protocol_version` DSN parameters ([#1258]). + +[#1258]: https://github.com/lib/pq/pull/1258 + v1.11.2 (2025-02-10) -------------------- This fixes two regressions: diff --git a/conn.go b/conn.go index 9e69b473..b1cf83cd 100644 --- a/conn.go +++ b/conn.go @@ -174,11 +174,12 @@ type conn struct { // (ErrBadConn) or getForNext(). err syncErr - processID, secretKey int // Cancellation key data for use with CancelRequest messages. - inCopy bool // If true this connection is in the middle of a COPY - noticeHandler func(*Error) // If not nil, notices will be synchronously sent here - notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here - gss GSS // GSSAPI context + secretKey []byte // Cancellation key for CancelRequest messages. + pid int // Cancellation PID. + inCopy bool // If true this connection is in the middle of a COPY + noticeHandler func(*Error) // If not nil, notices will be synchronously sent here + notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here + gss GSS // GSSAPI context } type syncErr struct { @@ -1186,7 +1187,10 @@ func (cn *conn) ssl(cfg Config) error { func (cn *conn) startup(cfg Config) error { w := cn.writeBuf(0) - w.int32(proto.ProtocolVersion30) + // Send maximum protocol version in startup; if the server doesn't support + // this version it responds with NegotiateProtocolVersion and the maximum + // version it supports (and will use). + w.int32(cfg.MaxProtocolVersion.proto()) if cfg.User != "" { w.string("user") @@ -1226,7 +1230,12 @@ func (cn *conn) startup(cfg Config) error { } switch t { case proto.BackendKeyData: - cn.processBackendKeyData(r) + cn.pid = r.int32() + if len(*r) > 256 { + return fmt.Errorf("pq: cancellation key longer than 256 bytes: %d bytes", len(*r)) + } + cn.secretKey = make([]byte, len(*r)) + copy(cn.secretKey, *r) case proto.ParameterStatus: cn.processParameterStatus(r) case proto.AuthenticationRequest: @@ -1234,6 +1243,12 @@ func (cn *conn) startup(cfg Config) error { if err != nil { return err } + case proto.NegotiateProtocolVersion: + newestMinor := r.int32() + serverVersion := proto.ProtocolVersion30&0xFFFF0000 | newestMinor + if serverVersion < cfg.MinProtocolVersion.proto() { + return fmt.Errorf("pq: protocol version mismatch: min_protocol_version=%s; server supports up to 3.%d", cfg.MinProtocolVersion, newestMinor) + } case proto.ReadyForQuery: cn.processReadyForQuery(r) return nil @@ -1566,11 +1581,6 @@ func (cn *conn) readReadyForQuery() error { } } -func (cn *conn) processBackendKeyData(r *readBuf) { - cn.processID = r.int32() - cn.secretKey = r.int32() -} - func (cn *conn) readParseResponse() error { t, r, err := cn.recv1() if err != nil { diff --git a/conn_go18.go b/conn_go18.go index 23a10aee..d776175e 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -151,8 +151,8 @@ func (cn *conn) cancel(ctx context.Context) error { w := cn2.writeBuf(0) w.int32(proto.CancelRequestCode) - w.int32(cn.processID) - w.int32(cn.secretKey) + w.int32(cn.pid) + w.bytes(cn.secretKey) if err := cn2.sendStartupPacket(w); err != nil { return err } diff --git a/connector.go b/connector.go index 1827fdbd..9b1c193a 100644 --- a/connector.go +++ b/connector.go @@ -19,6 +19,7 @@ import ( "unicode" "github.com/lib/pq/internal/pqutil" + "github.com/lib/pq/internal/proto" ) type ( @@ -33,6 +34,10 @@ type ( // LoadBalanceHosts is a load_balance_hosts setting. LoadBalanceHosts string + + // ProtocolVersion is a min_protocol_version or max_protocol_version + // setting. + ProtocolVersion string ) // Values for [SSLMode] that pq supports. @@ -110,6 +115,23 @@ const ( var loadBalanceHosts = []LoadBalanceHosts{LoadBalanceHostsDisable, LoadBalanceHostsRandom} +// Values for [ProtocolVersion] that pq supports. +const ( + // ProtocolVersion30 is the default protocol version, supported in + // PostgreSQL 3.0 and newer. + ProtocolVersion30 = ProtocolVersion("3.0") + + // ProtocolVersion32 uses a longer secret key length for query cancellation, + // supported in PostgreSQL 18 and newer. + ProtocolVersion32 = ProtocolVersion("3.2") + + // ProtocolVersionLatest is the latest protocol version that pq supports + // (which may not be supported by the server). + ProtocolVersionLatest = ProtocolVersion("latest") +) + +var protocolVersions = []ProtocolVersion{ProtocolVersion30, ProtocolVersion32, ProtocolVersionLatest} + // Connector represents a fixed configuration for the pq driver with a given // dsn. Connector satisfies the [database/sql/driver.Connector] interface and // can be used to create any number of DB Conn's via [sql.OpenDB]. @@ -148,6 +170,15 @@ func (c *Connector) Dialer(dialer Dialer) { c.dialer = dialer } // Driver returns the underlying driver of this Connector. func (c *Connector) Driver() driver.Driver { return &Driver{} } +func (p ProtocolVersion) proto() int { + switch p { + default: + return proto.ProtocolVersion30 + case ProtocolVersion32, ProtocolVersionLatest: + return proto.ProtocolVersion32 + } +} + // Config holds options pq supports when connecting to PostgreSQL. // // The postgres struct tag is used for the value from the DSN (e.g. @@ -303,6 +334,14 @@ type Config struct { // to the same server. LoadBalanceHosts LoadBalanceHosts `postgres:"load_balance_hosts" env:"PGLOADBALANCEHOSTS"` + // Minimum acceptable PostgreSQL protocol version. If the server does not + // support at least this version, the connection will fail. Defaults to + // "3.0". + MinProtocolVersion ProtocolVersion `postgres:"min_protocol_version" env:"PGMINPROTOCOLVERSION"` + + // Maximum PostgreSQL protocol version to request from the server. Defaults to "3.0". + MaxProtocolVersion ProtocolVersion `postgres:"max_protocol_version" env:"PGMAXPROTOCOLVERSION"` + // Runtime parameters: any unrecognized parameter in the DSN will be added // to this and sent to PostgreSQL during startup. Runtime map[string]string `postgres:"-" env:"-"` @@ -413,7 +452,13 @@ func (cfg Config) hosts() []Config { } func newConfig(dsn string, env []string) (Config, error) { - cfg := Config{Host: "localhost", Port: 5432, SSLSNI: true} + cfg := Config{ + Host: "localhost", + Port: 5432, + SSLSNI: true, + MinProtocolVersion: "3.0", + MaxProtocolVersion: "3.0", + } if err := cfg.fromEnv(env); err != nil { return Config{}, err } @@ -487,6 +532,11 @@ func newConfig(dsn string, env []string) (Config, error) { cfg.SSLMode = SSLModeDisable } + if cfg.MinProtocolVersion > cfg.MaxProtocolVersion { + return Config{}, fmt.Errorf("pq: min_protocol_version %q cannot be greater than max_protocol_version %q", + cfg.MinProtocolVersion, cfg.MaxProtocolVersion) + } + return cfg, nil } @@ -514,7 +564,7 @@ func (cfg *Config) fromEnv(env []string) error { case "PGREQUIREAUTH", "PGCHANNELBINDING", "PGSERVICE", "PGSERVICEFILE", "PGREALM", "PGSSLCERTMODE", "PGSSLCOMPRESSION", "PGREQUIRESSL", "PGSSLCRL", "PGREQUIREPEER", "PGSYSCONFDIR", "PGLOCALEDIR", "PGSSLCRLDIR", "PGSSLMINPROTOCOLVERSION", "PGSSLMAXPROTOCOLVERSION", - "PGGSSENCMODE", "PGGSSDELEGATION", "PGMINPROTOCOLVERSION", "PGMAXPROTOCOLVERSION", "PGGSSLIB": + "PGGSSENCMODE", "PGGSSDELEGATION", "PGGSSLIB": return fmt.Errorf("pq: environment variable $%s is not supported", k) case "PGKRBSRVNAME": if newGss == nil { @@ -654,6 +704,8 @@ func (cfg *Config) setFromTag(o map[string]string, tag string) error { sslnegotiation = (tag == "postgres" && k == "sslnegotiation") || (tag == "env" && k == "PGSSLNEGOTIATION") targetsessionattrs = (tag == "postgres" && k == "target_session_attrs") || (tag == "env" && k == "PGTARGETSESSIONATTRS") loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts") || (tag == "env" && k == "PGLOADBALANCEHOSTS") + minprotocolversion = (tag == "postgres" && k == "min_protocol_version") || (tag == "env" && k == "PGMINPROTOCOLVERSION") + maxprotocolversion = (tag == "postgres" && k == "max_protocol_version") || (tag == "env" && k == "PGMAXPROTOCOLVERSION") ) if k == "" || k == "-" { continue @@ -706,6 +758,9 @@ func (cfg *Config) setFromTag(o map[string]string, tag string) error { if loadbalancehosts && !slices.Contains(loadBalanceHosts, LoadBalanceHosts(v)) { return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(loadBalanceHosts)) } + if (minprotocolversion || maxprotocolversion) && !slices.Contains(protocolVersions, ProtocolVersion(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(protocolVersions)) + } if host { vv := strings.Split(v, ",") v = vv[0] @@ -833,7 +888,7 @@ func (cfg Config) string() string { switch k { case "datestyle", "client_encoding": continue - case "host", "port", "user", "sslsni": + case "host", "port", "user", "sslsni", "min_protocol_version", "max_protocol_version": if !cfg.isset(k) { continue } diff --git a/connector_test.go b/connector_test.go index 370f9b21..6e5afba5 100644 --- a/connector_test.go +++ b/connector_test.go @@ -100,7 +100,7 @@ func TestNewConnector(t *testing.T) { t.Fatal(err) } want := fmt.Sprintf( - `map[client_encoding:UTF8 connect_timeout:20 datestyle:ISO, MDY dbname:pqgo host:localhost port:%d search_path:foo sslmode:disable sslsni:yes user:pqgo]`, + `map[client_encoding:UTF8 connect_timeout:20 datestyle:ISO, MDY dbname:pqgo host:localhost max_protocol_version:3.0 min_protocol_version:3.0 port:%d search_path:foo sslmode:disable sslsni:yes user:pqgo]`, cfg.Port) if have := fmt.Sprintf("%v", c.cfg.tomap()); have != want { t.Errorf("\nhave: %s\nwant: %s", have, want) @@ -439,6 +439,19 @@ func TestNewConfig(t *testing.T) { {"host=a,b,c hostaddr=1.1.1.1,2.2.2.2", nil, "", "could not match 3 host names to 2 hostaddr values"}, {"host=a hostaddr=1.1.1.1,2.2.2.2", nil, "", "could not match 1 host names to 2 hostaddr values"}, {"", []string{"PGHOST=a,,b", "PGHOSTADDR=1.1.1.1,,2.2.2.2", "PGPORT=3,,4"}, "host=a,localhost,b hostaddr=1.1.1.1,,2.2.2.2 port=3,5432,4", ""}, + + // Protocol version + {"min_protocol_version=3.0", nil, "min_protocol_version=3.0", ""}, + {"max_protocol_version=3.2", nil, "max_protocol_version=3.2", ""}, + {"min_protocol_version=3.2 max_protocol_version=3.2", nil, "max_protocol_version=3.2 min_protocol_version=3.2", ""}, + {"min_protocol_version=latest max_protocol_version=latest", nil, "max_protocol_version=latest min_protocol_version=latest", ""}, + {"min_protocol_version=3.0 max_protocol_version=latest", nil, "max_protocol_version=latest min_protocol_version=3.0", ""}, + {"", []string{"PGMINPROTOCOLVERSION=3.0", "PGMAXPROTOCOLVERSION=3.2"}, "max_protocol_version=3.2 min_protocol_version=3.0", ""}, + {"min_protocol_version=bogus", nil, "", `pq: wrong value for "min_protocol_version": "bogus" is not supported`}, + {"max_protocol_version=bogus", nil, "", `pq: wrong value for "max_protocol_version": "bogus" is not supported`}, + {"", []string{"PGMINPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMINPROTOCOLVERSION: "bogus" is not supported`}, + {"", []string{"PGMAXPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMAXPROTOCOLVERSION: "bogus" is not supported`}, + {"min_protocol_version=3.2 max_protocol_version=3.0", nil, "", `min_protocol_version "3.2" cannot be greater than max_protocol_version "3.0"`}, } t.Parallel() @@ -508,7 +521,7 @@ func TestConnectMulti(t *testing.T) { connectedTo [3]bool accept = func(n int) func(pqtest.Fake, net.Conn) { return func(f pqtest.Fake, cn net.Conn) { - clientParams, ok := f.ReadStartup(cn) + _, clientParams, ok := f.ReadStartup(cn) if !ok { return } @@ -747,3 +760,98 @@ func TestConnectionTargetSessionAttrs(t *testing.T) { }) } } + +func TestProtocolVersion(t *testing.T) { + var ( + key30 = []byte{1, 2, 3, 4} + key32 = make([]byte, 32) + ) + for i := 0; i < 32; i++ { + key32[i] = byte(i) + } + accept := func(version float32) (*[]byte, func(f pqtest.Fake, cn net.Conn)) { + var kd []byte + return &kd, func(f pqtest.Fake, cn net.Conn) { + v, _, ok := f.ReadStartup(cn) + if !ok { + return + } + use := v + if v > version { + use = version + f.WriteNegotiateProtocolVersion(cn, int(version*10-30), nil) + } + + f.WriteMsg(cn, proto.AuthenticationRequest, "\x00\x00\x00\x00") + if use >= 3.2 { + kd = key32 + } else { + kd = key30 + } + f.WriteBackendKeyData(cn, 666, kd) + f.WriteMsg(cn, proto.ReadyForQuery, "I") + for { + code, _, ok := f.ReadMsg(cn) + if !ok { + return + } + switch code { + case proto.Query: + f.WriteMsg(cn, proto.EmptyQueryResponse, "") + f.WriteMsg(cn, proto.ReadyForQuery, "I") + case proto.Terminate: + cn.Close() + return + } + } + } + } + + tests := []struct { + serverVersion float32 + min, max string + wantKey []byte + wantErr string + }{ + {3.2, "", "", key30, ""}, + {3.2, "3.0", "3.0", key30, ""}, + {3.2, "3.2", "3.2", key32, ""}, + {3.2, "3.0", "latest", key32, ""}, + {3.2, "latest", "latest", key32, ""}, + + {3.0, "3.0", "3.2", key30, ""}, + {3.0, "3.2", "3.2", nil, `pq: protocol version mismatch: min_protocol_version=3.2; server supports up to 3.0`}, + + {3.2, "3.9", "3.0", nil, `"3.9" is not supported`}, + {3.2, "3.0", "3.9", nil, `"3.9" is not supported`}, + {3.2, "3.2", "3.0", nil, `min_protocol_version "3.2" cannot be greater than max_protocol_version "3.0"`}, + {3.2, "latest", "3.0", nil, `min_protocol_version "latest" cannot be greater than max_protocol_version "3.0"`}, + } + + for _, tt := range tests { + tt := tt + t.Run("", func(t *testing.T) { + t.Parallel() + have, a := accept(tt.serverVersion) + f := pqtest.NewFake(t, a) + defer f.Close() + + var extra []string + if tt.min != "" { + extra = append(extra, "min_protocol_version="+tt.min) + } + if tt.max != "" { + extra = append(extra, "max_protocol_version="+tt.max) + } + + db := pqtest.MustDB(t, f.DSN()+" "+strings.Join(extra, " ")) + err := db.Ping() + if !pqtest.ErrorContains(err, tt.wantErr) { + t.Fatalf("wrong error\nhave: %v\nwant: %v", err, tt.wantErr) + } + if tt.wantErr == "" && !reflect.DeepEqual(*have, tt.wantKey) { + t.Fatalf("wrong keydata\nhave: %v\nwant: %v", *have, tt.wantKey) + } + }) + } +} diff --git a/internal/pqtest/fake.go b/internal/pqtest/fake.go index dc003ed0..2507f08b 100644 --- a/internal/pqtest/fake.go +++ b/internal/pqtest/fake.go @@ -91,7 +91,7 @@ func (f Fake) accept(fun func(Fake, net.Conn)) { // Startup reads the startup message from the server with [f.ReadStartup] and // sends [proto.AuthenticationRequest] and [proto.ReadyForQuery]. func (f Fake) Startup(cn net.Conn, params map[string]string) { - if _, ok := f.ReadStartup(cn); !ok { + if _, _, ok := f.ReadStartup(cn); !ok { return } // Technically we don't *need* to send the AuthRequest, but the psql CLI @@ -104,7 +104,7 @@ func (f Fake) Startup(cn net.Conn, params map[string]string) { } // ReadStartup reads the startup message. -func (f Fake) ReadStartup(cn net.Conn) (map[string]string, bool) { +func (f Fake) ReadStartup(cn net.Conn) (float32, map[string]string, bool) { _, msg, ok := f.read(cn, true) var ( params = make(map[string]string) @@ -113,7 +113,7 @@ func (f Fake) ReadStartup(cn net.Conn) (map[string]string, bool) { for i := 0; i < len(m); i += 2 { params[m[i]] = m[i+1] } - return params, ok + return float32(msg[1]) + float32(msg[3])/10, params, ok } // WriteStartup writes startup parameters. @@ -137,7 +137,9 @@ func (f Fake) read(cn net.Conn, startup bool) (proto.RequestCode, []byte, bool) typ := make([]byte, sz) _, err := cn.Read(typ) if err != nil { - if errors.Is(err, io.EOF) { + // No need to error if connection got closed, which is most likely + // intentional. + if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "connection reset by peer") { return 0, nil, false } f.t.Errorf("reading: %s", err) @@ -246,3 +248,24 @@ func (f Fake) SimpleQuery(cn net.Conn, tag string, values ...any) { f.WriteMsg(cn, proto.CommandComplete, tag+"\x00") } + +// WriteBackendKeyData sends a BackendKeyData message with the given process ID +// and secret key (variable length). +func (f Fake) WriteBackendKeyData(cn net.Conn, pid int, secretKey []byte) { + b := make([]byte, 4+len(secretKey)) + binary.BigEndian.PutUint32(b[0:4], uint32(pid)) + copy(b[4:], secretKey) + f.WriteMsg(cn, proto.BackendKeyData, string(b)) +} + +// WriteNegotiateProtocolVersion sends a NegotiateProtocolVersion message. +func (f Fake) WriteNegotiateProtocolVersion(cn net.Conn, newestMinor int, options []string) { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b[0:4], uint32(newestMinor)) + binary.BigEndian.PutUint32(b[4:8], uint32(len(options))) + for _, o := range options { + b = append(b, o...) + b = append(b, 0) + } + f.WriteMsg(cn, proto.NegotiateProtocolVersion, string(b)) +} diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 318d180a..e8b4bc59 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -10,7 +10,7 @@ import ( // Constants from pqcomm.h const ( ProtocolVersion30 = (3 << 16) | 0 //lint:ignore SA4016 x - ProtocolVersion32 = (3 << 16) | 2 // PostgreSQL ≥18; not yet supported. + ProtocolVersion32 = (3 << 16) | 2 // PostgreSQL ≥18. CancelRequestCode = (1234 << 16) | 5678 NegotiateSSLCode = (1234 << 16) | 5679 NegotiateGSSCode = (1234 << 16) | 5680