From d56fb07079c5e511ebe00749a1466394b516153d Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Sat, 7 Feb 2026 20:30:29 +0100 Subject: [PATCH] Add support for protocol 3.2 This adds support for the 3.2 protocol version, introduced with PostgreSQL 18. It follows postgres in the sense that the default is still 3.0, but this allows for allocations to allow the 3.2 version of the protocol with longer secret key data. This is to both improve security and to provide room for additional metadata for middleware. Co-authored-by: Martin Tournoij --- CHANGELOG.md | 7 +++ conn.go | 34 +++++++----- conn_go18.go | 4 +- connector.go | 61 ++++++++++++++++++++-- connector_test.go | 112 +++++++++++++++++++++++++++++++++++++++- internal/pqtest/fake.go | 31 +++++++++-- internal/proto/proto.go | 2 +- 7 files changed, 227 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fce02d2..237a93cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +unreleased +---------- +- Support protocol 3.2, and the `min_protocol_version` and + `max_protocol_version` DSN parameters ([#1258]). + +[#1258]: https://github.com/lib/pq/pull/1258 + v1.11.2 (2025-02-10) -------------------- This fixes two regressions: diff --git a/conn.go b/conn.go index 9e69b473..b1cf83cd 100644 --- a/conn.go +++ b/conn.go @@ -174,11 +174,12 @@ type conn struct { // (ErrBadConn) or getForNext(). err syncErr - processID, secretKey int // Cancellation key data for use with CancelRequest messages. - inCopy bool // If true this connection is in the middle of a COPY - noticeHandler func(*Error) // If not nil, notices will be synchronously sent here - notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here - gss GSS // GSSAPI context + secretKey []byte // Cancellation key for CancelRequest messages. + pid int // Cancellation PID. + inCopy bool // If true this connection is in the middle of a COPY + noticeHandler func(*Error) // If not nil, notices will be synchronously sent here + notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here + gss GSS // GSSAPI context } type syncErr struct { @@ -1186,7 +1187,10 @@ func (cn *conn) ssl(cfg Config) error { func (cn *conn) startup(cfg Config) error { w := cn.writeBuf(0) - w.int32(proto.ProtocolVersion30) + // Send maximum protocol version in startup; if the server doesn't support + // this version it responds with NegotiateProtocolVersion and the maximum + // version it supports (and will use). + w.int32(cfg.MaxProtocolVersion.proto()) if cfg.User != "" { w.string("user") @@ -1226,7 +1230,12 @@ func (cn *conn) startup(cfg Config) error { } switch t { case proto.BackendKeyData: - cn.processBackendKeyData(r) + cn.pid = r.int32() + if len(*r) > 256 { + return fmt.Errorf("pq: cancellation key longer than 256 bytes: %d bytes", len(*r)) + } + cn.secretKey = make([]byte, len(*r)) + copy(cn.secretKey, *r) case proto.ParameterStatus: cn.processParameterStatus(r) case proto.AuthenticationRequest: @@ -1234,6 +1243,12 @@ func (cn *conn) startup(cfg Config) error { if err != nil { return err } + case proto.NegotiateProtocolVersion: + newestMinor := r.int32() + serverVersion := proto.ProtocolVersion30&0xFFFF0000 | newestMinor + if serverVersion < cfg.MinProtocolVersion.proto() { + return fmt.Errorf("pq: protocol version mismatch: min_protocol_version=%s; server supports up to 3.%d", cfg.MinProtocolVersion, newestMinor) + } case proto.ReadyForQuery: cn.processReadyForQuery(r) return nil @@ -1566,11 +1581,6 @@ func (cn *conn) readReadyForQuery() error { } } -func (cn *conn) processBackendKeyData(r *readBuf) { - cn.processID = r.int32() - cn.secretKey = r.int32() -} - func (cn *conn) readParseResponse() error { t, r, err := cn.recv1() if err != nil { diff --git a/conn_go18.go b/conn_go18.go index 23a10aee..d776175e 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -151,8 +151,8 @@ func (cn *conn) cancel(ctx context.Context) error { w := cn2.writeBuf(0) w.int32(proto.CancelRequestCode) - w.int32(cn.processID) - w.int32(cn.secretKey) + w.int32(cn.pid) + w.bytes(cn.secretKey) if err := cn2.sendStartupPacket(w); err != nil { return err } diff --git a/connector.go b/connector.go index 1827fdbd..9b1c193a 100644 --- a/connector.go +++ b/connector.go @@ -19,6 +19,7 @@ import ( "unicode" "github.com/lib/pq/internal/pqutil" + "github.com/lib/pq/internal/proto" ) type ( @@ -33,6 +34,10 @@ type ( // LoadBalanceHosts is a load_balance_hosts setting. LoadBalanceHosts string + + // ProtocolVersion is a min_protocol_version or max_protocol_version + // setting. + ProtocolVersion string ) // Values for [SSLMode] that pq supports. @@ -110,6 +115,23 @@ const ( var loadBalanceHosts = []LoadBalanceHosts{LoadBalanceHostsDisable, LoadBalanceHostsRandom} +// Values for [ProtocolVersion] that pq supports. +const ( + // ProtocolVersion30 is the default protocol version, supported in + // PostgreSQL 3.0 and newer. + ProtocolVersion30 = ProtocolVersion("3.0") + + // ProtocolVersion32 uses a longer secret key length for query cancellation, + // supported in PostgreSQL 18 and newer. + ProtocolVersion32 = ProtocolVersion("3.2") + + // ProtocolVersionLatest is the latest protocol version that pq supports + // (which may not be supported by the server). + ProtocolVersionLatest = ProtocolVersion("latest") +) + +var protocolVersions = []ProtocolVersion{ProtocolVersion30, ProtocolVersion32, ProtocolVersionLatest} + // Connector represents a fixed configuration for the pq driver with a given // dsn. Connector satisfies the [database/sql/driver.Connector] interface and // can be used to create any number of DB Conn's via [sql.OpenDB]. @@ -148,6 +170,15 @@ func (c *Connector) Dialer(dialer Dialer) { c.dialer = dialer } // Driver returns the underlying driver of this Connector. func (c *Connector) Driver() driver.Driver { return &Driver{} } +func (p ProtocolVersion) proto() int { + switch p { + default: + return proto.ProtocolVersion30 + case ProtocolVersion32, ProtocolVersionLatest: + return proto.ProtocolVersion32 + } +} + // Config holds options pq supports when connecting to PostgreSQL. // // The postgres struct tag is used for the value from the DSN (e.g. @@ -303,6 +334,14 @@ type Config struct { // to the same server. LoadBalanceHosts LoadBalanceHosts `postgres:"load_balance_hosts" env:"PGLOADBALANCEHOSTS"` + // Minimum acceptable PostgreSQL protocol version. If the server does not + // support at least this version, the connection will fail. Defaults to + // "3.0". + MinProtocolVersion ProtocolVersion `postgres:"min_protocol_version" env:"PGMINPROTOCOLVERSION"` + + // Maximum PostgreSQL protocol version to request from the server. Defaults to "3.0". + MaxProtocolVersion ProtocolVersion `postgres:"max_protocol_version" env:"PGMAXPROTOCOLVERSION"` + // Runtime parameters: any unrecognized parameter in the DSN will be added // to this and sent to PostgreSQL during startup. Runtime map[string]string `postgres:"-" env:"-"` @@ -413,7 +452,13 @@ func (cfg Config) hosts() []Config { } func newConfig(dsn string, env []string) (Config, error) { - cfg := Config{Host: "localhost", Port: 5432, SSLSNI: true} + cfg := Config{ + Host: "localhost", + Port: 5432, + SSLSNI: true, + MinProtocolVersion: "3.0", + MaxProtocolVersion: "3.0", + } if err := cfg.fromEnv(env); err != nil { return Config{}, err } @@ -487,6 +532,11 @@ func newConfig(dsn string, env []string) (Config, error) { cfg.SSLMode = SSLModeDisable } + if cfg.MinProtocolVersion > cfg.MaxProtocolVersion { + return Config{}, fmt.Errorf("pq: min_protocol_version %q cannot be greater than max_protocol_version %q", + cfg.MinProtocolVersion, cfg.MaxProtocolVersion) + } + return cfg, nil } @@ -514,7 +564,7 @@ func (cfg *Config) fromEnv(env []string) error { case "PGREQUIREAUTH", "PGCHANNELBINDING", "PGSERVICE", "PGSERVICEFILE", "PGREALM", "PGSSLCERTMODE", "PGSSLCOMPRESSION", "PGREQUIRESSL", "PGSSLCRL", "PGREQUIREPEER", "PGSYSCONFDIR", "PGLOCALEDIR", "PGSSLCRLDIR", "PGSSLMINPROTOCOLVERSION", "PGSSLMAXPROTOCOLVERSION", - "PGGSSENCMODE", "PGGSSDELEGATION", "PGMINPROTOCOLVERSION", "PGMAXPROTOCOLVERSION", "PGGSSLIB": + "PGGSSENCMODE", "PGGSSDELEGATION", "PGGSSLIB": return fmt.Errorf("pq: environment variable $%s is not supported", k) case "PGKRBSRVNAME": if newGss == nil { @@ -654,6 +704,8 @@ func (cfg *Config) setFromTag(o map[string]string, tag string) error { sslnegotiation = (tag == "postgres" && k == "sslnegotiation") || (tag == "env" && k == "PGSSLNEGOTIATION") targetsessionattrs = (tag == "postgres" && k == "target_session_attrs") || (tag == "env" && k == "PGTARGETSESSIONATTRS") loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts") || (tag == "env" && k == "PGLOADBALANCEHOSTS") + minprotocolversion = (tag == "postgres" && k == "min_protocol_version") || (tag == "env" && k == "PGMINPROTOCOLVERSION") + maxprotocolversion = (tag == "postgres" && k == "max_protocol_version") || (tag == "env" && k == "PGMAXPROTOCOLVERSION") ) if k == "" || k == "-" { continue @@ -706,6 +758,9 @@ func (cfg *Config) setFromTag(o map[string]string, tag string) error { if loadbalancehosts && !slices.Contains(loadBalanceHosts, LoadBalanceHosts(v)) { return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(loadBalanceHosts)) } + if (minprotocolversion || maxprotocolversion) && !slices.Contains(protocolVersions, ProtocolVersion(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(protocolVersions)) + } if host { vv := strings.Split(v, ",") v = vv[0] @@ -833,7 +888,7 @@ func (cfg Config) string() string { switch k { case "datestyle", "client_encoding": continue - case "host", "port", "user", "sslsni": + case "host", "port", "user", "sslsni", "min_protocol_version", "max_protocol_version": if !cfg.isset(k) { continue } diff --git a/connector_test.go b/connector_test.go index 370f9b21..6e5afba5 100644 --- a/connector_test.go +++ b/connector_test.go @@ -100,7 +100,7 @@ func TestNewConnector(t *testing.T) { t.Fatal(err) } want := fmt.Sprintf( - `map[client_encoding:UTF8 connect_timeout:20 datestyle:ISO, MDY dbname:pqgo host:localhost port:%d search_path:foo sslmode:disable sslsni:yes user:pqgo]`, + `map[client_encoding:UTF8 connect_timeout:20 datestyle:ISO, MDY dbname:pqgo host:localhost max_protocol_version:3.0 min_protocol_version:3.0 port:%d search_path:foo sslmode:disable sslsni:yes user:pqgo]`, cfg.Port) if have := fmt.Sprintf("%v", c.cfg.tomap()); have != want { t.Errorf("\nhave: %s\nwant: %s", have, want) @@ -439,6 +439,19 @@ func TestNewConfig(t *testing.T) { {"host=a,b,c hostaddr=1.1.1.1,2.2.2.2", nil, "", "could not match 3 host names to 2 hostaddr values"}, {"host=a hostaddr=1.1.1.1,2.2.2.2", nil, "", "could not match 1 host names to 2 hostaddr values"}, {"", []string{"PGHOST=a,,b", "PGHOSTADDR=1.1.1.1,,2.2.2.2", "PGPORT=3,,4"}, "host=a,localhost,b hostaddr=1.1.1.1,,2.2.2.2 port=3,5432,4", ""}, + + // Protocol version + {"min_protocol_version=3.0", nil, "min_protocol_version=3.0", ""}, + {"max_protocol_version=3.2", nil, "max_protocol_version=3.2", ""}, + {"min_protocol_version=3.2 max_protocol_version=3.2", nil, "max_protocol_version=3.2 min_protocol_version=3.2", ""}, + {"min_protocol_version=latest max_protocol_version=latest", nil, "max_protocol_version=latest min_protocol_version=latest", ""}, + {"min_protocol_version=3.0 max_protocol_version=latest", nil, "max_protocol_version=latest min_protocol_version=3.0", ""}, + {"", []string{"PGMINPROTOCOLVERSION=3.0", "PGMAXPROTOCOLVERSION=3.2"}, "max_protocol_version=3.2 min_protocol_version=3.0", ""}, + {"min_protocol_version=bogus", nil, "", `pq: wrong value for "min_protocol_version": "bogus" is not supported`}, + {"max_protocol_version=bogus", nil, "", `pq: wrong value for "max_protocol_version": "bogus" is not supported`}, + {"", []string{"PGMINPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMINPROTOCOLVERSION: "bogus" is not supported`}, + {"", []string{"PGMAXPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMAXPROTOCOLVERSION: "bogus" is not supported`}, + {"min_protocol_version=3.2 max_protocol_version=3.0", nil, "", `min_protocol_version "3.2" cannot be greater than max_protocol_version "3.0"`}, } t.Parallel() @@ -508,7 +521,7 @@ func TestConnectMulti(t *testing.T) { connectedTo [3]bool accept = func(n int) func(pqtest.Fake, net.Conn) { return func(f pqtest.Fake, cn net.Conn) { - clientParams, ok := f.ReadStartup(cn) + _, clientParams, ok := f.ReadStartup(cn) if !ok { return } @@ -747,3 +760,98 @@ func TestConnectionTargetSessionAttrs(t *testing.T) { }) } } + +func TestProtocolVersion(t *testing.T) { + var ( + key30 = []byte{1, 2, 3, 4} + key32 = make([]byte, 32) + ) + for i := 0; i < 32; i++ { + key32[i] = byte(i) + } + accept := func(version float32) (*[]byte, func(f pqtest.Fake, cn net.Conn)) { + var kd []byte + return &kd, func(f pqtest.Fake, cn net.Conn) { + v, _, ok := f.ReadStartup(cn) + if !ok { + return + } + use := v + if v > version { + use = version + f.WriteNegotiateProtocolVersion(cn, int(version*10-30), nil) + } + + f.WriteMsg(cn, proto.AuthenticationRequest, "\x00\x00\x00\x00") + if use >= 3.2 { + kd = key32 + } else { + kd = key30 + } + f.WriteBackendKeyData(cn, 666, kd) + f.WriteMsg(cn, proto.ReadyForQuery, "I") + for { + code, _, ok := f.ReadMsg(cn) + if !ok { + return + } + switch code { + case proto.Query: + f.WriteMsg(cn, proto.EmptyQueryResponse, "") + f.WriteMsg(cn, proto.ReadyForQuery, "I") + case proto.Terminate: + cn.Close() + return + } + } + } + } + + tests := []struct { + serverVersion float32 + min, max string + wantKey []byte + wantErr string + }{ + {3.2, "", "", key30, ""}, + {3.2, "3.0", "3.0", key30, ""}, + {3.2, "3.2", "3.2", key32, ""}, + {3.2, "3.0", "latest", key32, ""}, + {3.2, "latest", "latest", key32, ""}, + + {3.0, "3.0", "3.2", key30, ""}, + {3.0, "3.2", "3.2", nil, `pq: protocol version mismatch: min_protocol_version=3.2; server supports up to 3.0`}, + + {3.2, "3.9", "3.0", nil, `"3.9" is not supported`}, + {3.2, "3.0", "3.9", nil, `"3.9" is not supported`}, + {3.2, "3.2", "3.0", nil, `min_protocol_version "3.2" cannot be greater than max_protocol_version "3.0"`}, + {3.2, "latest", "3.0", nil, `min_protocol_version "latest" cannot be greater than max_protocol_version "3.0"`}, + } + + for _, tt := range tests { + tt := tt + t.Run("", func(t *testing.T) { + t.Parallel() + have, a := accept(tt.serverVersion) + f := pqtest.NewFake(t, a) + defer f.Close() + + var extra []string + if tt.min != "" { + extra = append(extra, "min_protocol_version="+tt.min) + } + if tt.max != "" { + extra = append(extra, "max_protocol_version="+tt.max) + } + + db := pqtest.MustDB(t, f.DSN()+" "+strings.Join(extra, " ")) + err := db.Ping() + if !pqtest.ErrorContains(err, tt.wantErr) { + t.Fatalf("wrong error\nhave: %v\nwant: %v", err, tt.wantErr) + } + if tt.wantErr == "" && !reflect.DeepEqual(*have, tt.wantKey) { + t.Fatalf("wrong keydata\nhave: %v\nwant: %v", *have, tt.wantKey) + } + }) + } +} diff --git a/internal/pqtest/fake.go b/internal/pqtest/fake.go index dc003ed0..2507f08b 100644 --- a/internal/pqtest/fake.go +++ b/internal/pqtest/fake.go @@ -91,7 +91,7 @@ func (f Fake) accept(fun func(Fake, net.Conn)) { // Startup reads the startup message from the server with [f.ReadStartup] and // sends [proto.AuthenticationRequest] and [proto.ReadyForQuery]. func (f Fake) Startup(cn net.Conn, params map[string]string) { - if _, ok := f.ReadStartup(cn); !ok { + if _, _, ok := f.ReadStartup(cn); !ok { return } // Technically we don't *need* to send the AuthRequest, but the psql CLI @@ -104,7 +104,7 @@ func (f Fake) Startup(cn net.Conn, params map[string]string) { } // ReadStartup reads the startup message. -func (f Fake) ReadStartup(cn net.Conn) (map[string]string, bool) { +func (f Fake) ReadStartup(cn net.Conn) (float32, map[string]string, bool) { _, msg, ok := f.read(cn, true) var ( params = make(map[string]string) @@ -113,7 +113,7 @@ func (f Fake) ReadStartup(cn net.Conn) (map[string]string, bool) { for i := 0; i < len(m); i += 2 { params[m[i]] = m[i+1] } - return params, ok + return float32(msg[1]) + float32(msg[3])/10, params, ok } // WriteStartup writes startup parameters. @@ -137,7 +137,9 @@ func (f Fake) read(cn net.Conn, startup bool) (proto.RequestCode, []byte, bool) typ := make([]byte, sz) _, err := cn.Read(typ) if err != nil { - if errors.Is(err, io.EOF) { + // No need to error if connection got closed, which is most likely + // intentional. + if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "connection reset by peer") { return 0, nil, false } f.t.Errorf("reading: %s", err) @@ -246,3 +248,24 @@ func (f Fake) SimpleQuery(cn net.Conn, tag string, values ...any) { f.WriteMsg(cn, proto.CommandComplete, tag+"\x00") } + +// WriteBackendKeyData sends a BackendKeyData message with the given process ID +// and secret key (variable length). +func (f Fake) WriteBackendKeyData(cn net.Conn, pid int, secretKey []byte) { + b := make([]byte, 4+len(secretKey)) + binary.BigEndian.PutUint32(b[0:4], uint32(pid)) + copy(b[4:], secretKey) + f.WriteMsg(cn, proto.BackendKeyData, string(b)) +} + +// WriteNegotiateProtocolVersion sends a NegotiateProtocolVersion message. +func (f Fake) WriteNegotiateProtocolVersion(cn net.Conn, newestMinor int, options []string) { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b[0:4], uint32(newestMinor)) + binary.BigEndian.PutUint32(b[4:8], uint32(len(options))) + for _, o := range options { + b = append(b, o...) + b = append(b, 0) + } + f.WriteMsg(cn, proto.NegotiateProtocolVersion, string(b)) +} diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 318d180a..e8b4bc59 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -10,7 +10,7 @@ import ( // Constants from pqcomm.h const ( ProtocolVersion30 = (3 << 16) | 0 //lint:ignore SA4016 x - ProtocolVersion32 = (3 << 16) | 2 // PostgreSQL ≥18; not yet supported. + ProtocolVersion32 = (3 << 16) | 2 // PostgreSQL ≥18. CancelRequestCode = (1234 << 16) | 5678 NegotiateSSLCode = (1234 << 16) | 5679 NegotiateGSSCode = (1234 << 16) | 5680