Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
unreleased
----------
- Support protocol 3.2, and the `min_protocol_version` and
`max_protocol_version` DSN parameters ([#1258]).

[#1258]: https://github.com/lib/pq/pull/1258

v1.11.2 (2025-02-10)
--------------------
This fixes two regressions:
Expand Down
34 changes: 22 additions & 12 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,12 @@ type conn struct {
// (ErrBadConn) or getForNext().
err syncErr

processID, secretKey int // Cancellation key data for use with CancelRequest messages.
inCopy bool // If true this connection is in the middle of a COPY
noticeHandler func(*Error) // If not nil, notices will be synchronously sent here
notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here
gss GSS // GSSAPI context
secretKey []byte // Cancellation key for CancelRequest messages.
pid int // Cancellation PID.
inCopy bool // If true this connection is in the middle of a COPY
noticeHandler func(*Error) // If not nil, notices will be synchronously sent here
notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here
gss GSS // GSSAPI context
}

type syncErr struct {
Expand Down Expand Up @@ -1186,7 +1187,10 @@ func (cn *conn) ssl(cfg Config) error {

func (cn *conn) startup(cfg Config) error {
w := cn.writeBuf(0)
w.int32(proto.ProtocolVersion30)
// Send maximum protocol version in startup; if the server doesn't support
// this version it responds with NegotiateProtocolVersion and the maximum
// version it supports (and will use).
w.int32(cfg.MaxProtocolVersion.proto())

if cfg.User != "" {
w.string("user")
Expand Down Expand Up @@ -1226,14 +1230,25 @@ func (cn *conn) startup(cfg Config) error {
}
switch t {
case proto.BackendKeyData:
cn.processBackendKeyData(r)
cn.pid = r.int32()
if len(*r) > 256 {
return fmt.Errorf("pq: cancellation key longer than 256 bytes: %d bytes", len(*r))
}
cn.secretKey = make([]byte, len(*r))
copy(cn.secretKey, *r)
case proto.ParameterStatus:
cn.processParameterStatus(r)
case proto.AuthenticationRequest:
err := cn.auth(r, cfg)
if err != nil {
return err
}
case proto.NegotiateProtocolVersion:
newestMinor := r.int32()
serverVersion := proto.ProtocolVersion30&0xFFFF0000 | newestMinor
if serverVersion < cfg.MinProtocolVersion.proto() {
return fmt.Errorf("pq: protocol version mismatch: min_protocol_version=%s; server supports up to 3.%d", cfg.MinProtocolVersion, newestMinor)
}
case proto.ReadyForQuery:
cn.processReadyForQuery(r)
return nil
Expand Down Expand Up @@ -1566,11 +1581,6 @@ func (cn *conn) readReadyForQuery() error {
}
}

func (cn *conn) processBackendKeyData(r *readBuf) {
cn.processID = r.int32()
cn.secretKey = r.int32()
}

func (cn *conn) readParseResponse() error {
t, r, err := cn.recv1()
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions conn_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ func (cn *conn) cancel(ctx context.Context) error {

w := cn2.writeBuf(0)
w.int32(proto.CancelRequestCode)
w.int32(cn.processID)
w.int32(cn.secretKey)
w.int32(cn.pid)
w.bytes(cn.secretKey)
if err := cn2.sendStartupPacket(w); err != nil {
return err
}
Expand Down
61 changes: 58 additions & 3 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"unicode"

"github.com/lib/pq/internal/pqutil"
"github.com/lib/pq/internal/proto"
)

type (
Expand All @@ -33,6 +34,10 @@ type (

// LoadBalanceHosts is a load_balance_hosts setting.
LoadBalanceHosts string

// ProtocolVersion is a min_protocol_version or max_protocol_version
// setting.
ProtocolVersion string
)

// Values for [SSLMode] that pq supports.
Expand Down Expand Up @@ -110,6 +115,23 @@ const (

var loadBalanceHosts = []LoadBalanceHosts{LoadBalanceHostsDisable, LoadBalanceHostsRandom}

// Values for [ProtocolVersion] that pq supports.
const (
// ProtocolVersion30 is the default protocol version, supported in
// PostgreSQL 3.0 and newer.
ProtocolVersion30 = ProtocolVersion("3.0")

// ProtocolVersion32 uses a longer secret key length for query cancellation,
// supported in PostgreSQL 18 and newer.
ProtocolVersion32 = ProtocolVersion("3.2")

// ProtocolVersionLatest is the latest protocol version that pq supports
// (which may not be supported by the server).
ProtocolVersionLatest = ProtocolVersion("latest")
)

var protocolVersions = []ProtocolVersion{ProtocolVersion30, ProtocolVersion32, ProtocolVersionLatest}

// Connector represents a fixed configuration for the pq driver with a given
// dsn. Connector satisfies the [database/sql/driver.Connector] interface and
// can be used to create any number of DB Conn's via [sql.OpenDB].
Expand Down Expand Up @@ -148,6 +170,15 @@ func (c *Connector) Dialer(dialer Dialer) { c.dialer = dialer }
// Driver returns the underlying driver of this Connector.
func (c *Connector) Driver() driver.Driver { return &Driver{} }

func (p ProtocolVersion) proto() int {
switch p {
default:
return proto.ProtocolVersion30
case ProtocolVersion32, ProtocolVersionLatest:
return proto.ProtocolVersion32
}
}

// Config holds options pq supports when connecting to PostgreSQL.
//
// The postgres struct tag is used for the value from the DSN (e.g.
Expand Down Expand Up @@ -303,6 +334,14 @@ type Config struct {
// to the same server.
LoadBalanceHosts LoadBalanceHosts `postgres:"load_balance_hosts" env:"PGLOADBALANCEHOSTS"`

// Minimum acceptable PostgreSQL protocol version. If the server does not
// support at least this version, the connection will fail. Defaults to
// "3.0".
MinProtocolVersion ProtocolVersion `postgres:"min_protocol_version" env:"PGMINPROTOCOLVERSION"`

// Maximum PostgreSQL protocol version to request from the server. Defaults to "3.0".
MaxProtocolVersion ProtocolVersion `postgres:"max_protocol_version" env:"PGMAXPROTOCOLVERSION"`

// Runtime parameters: any unrecognized parameter in the DSN will be added
// to this and sent to PostgreSQL during startup.
Runtime map[string]string `postgres:"-" env:"-"`
Expand Down Expand Up @@ -413,7 +452,13 @@ func (cfg Config) hosts() []Config {
}

func newConfig(dsn string, env []string) (Config, error) {
cfg := Config{Host: "localhost", Port: 5432, SSLSNI: true}
cfg := Config{
Host: "localhost",
Port: 5432,
SSLSNI: true,
MinProtocolVersion: "3.0",
MaxProtocolVersion: "3.0",
}
if err := cfg.fromEnv(env); err != nil {
return Config{}, err
}
Expand Down Expand Up @@ -487,6 +532,11 @@ func newConfig(dsn string, env []string) (Config, error) {
cfg.SSLMode = SSLModeDisable
}

if cfg.MinProtocolVersion > cfg.MaxProtocolVersion {
return Config{}, fmt.Errorf("pq: min_protocol_version %q cannot be greater than max_protocol_version %q",
cfg.MinProtocolVersion, cfg.MaxProtocolVersion)
}

return cfg, nil
}

Expand Down Expand Up @@ -514,7 +564,7 @@ func (cfg *Config) fromEnv(env []string) error {
case "PGREQUIREAUTH", "PGCHANNELBINDING", "PGSERVICE", "PGSERVICEFILE", "PGREALM",
"PGSSLCERTMODE", "PGSSLCOMPRESSION", "PGREQUIRESSL", "PGSSLCRL", "PGREQUIREPEER",
"PGSYSCONFDIR", "PGLOCALEDIR", "PGSSLCRLDIR", "PGSSLMINPROTOCOLVERSION", "PGSSLMAXPROTOCOLVERSION",
"PGGSSENCMODE", "PGGSSDELEGATION", "PGMINPROTOCOLVERSION", "PGMAXPROTOCOLVERSION", "PGGSSLIB":
"PGGSSENCMODE", "PGGSSDELEGATION", "PGGSSLIB":
return fmt.Errorf("pq: environment variable $%s is not supported", k)
case "PGKRBSRVNAME":
if newGss == nil {
Expand Down Expand Up @@ -654,6 +704,8 @@ func (cfg *Config) setFromTag(o map[string]string, tag string) error {
sslnegotiation = (tag == "postgres" && k == "sslnegotiation") || (tag == "env" && k == "PGSSLNEGOTIATION")
targetsessionattrs = (tag == "postgres" && k == "target_session_attrs") || (tag == "env" && k == "PGTARGETSESSIONATTRS")
loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts") || (tag == "env" && k == "PGLOADBALANCEHOSTS")
minprotocolversion = (tag == "postgres" && k == "min_protocol_version") || (tag == "env" && k == "PGMINPROTOCOLVERSION")
maxprotocolversion = (tag == "postgres" && k == "max_protocol_version") || (tag == "env" && k == "PGMAXPROTOCOLVERSION")
)
if k == "" || k == "-" {
continue
Expand Down Expand Up @@ -706,6 +758,9 @@ func (cfg *Config) setFromTag(o map[string]string, tag string) error {
if loadbalancehosts && !slices.Contains(loadBalanceHosts, LoadBalanceHosts(v)) {
return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(loadBalanceHosts))
}
if (minprotocolversion || maxprotocolversion) && !slices.Contains(protocolVersions, ProtocolVersion(v)) {
return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(protocolVersions))
}
if host {
vv := strings.Split(v, ",")
v = vv[0]
Expand Down Expand Up @@ -833,7 +888,7 @@ func (cfg Config) string() string {
switch k {
case "datestyle", "client_encoding":
continue
case "host", "port", "user", "sslsni":
case "host", "port", "user", "sslsni", "min_protocol_version", "max_protocol_version":
if !cfg.isset(k) {
continue
}
Expand Down
112 changes: 110 additions & 2 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestNewConnector(t *testing.T) {
t.Fatal(err)
}
want := fmt.Sprintf(
`map[client_encoding:UTF8 connect_timeout:20 datestyle:ISO, MDY dbname:pqgo host:localhost port:%d search_path:foo sslmode:disable sslsni:yes user:pqgo]`,
`map[client_encoding:UTF8 connect_timeout:20 datestyle:ISO, MDY dbname:pqgo host:localhost max_protocol_version:3.0 min_protocol_version:3.0 port:%d search_path:foo sslmode:disable sslsni:yes user:pqgo]`,
cfg.Port)
if have := fmt.Sprintf("%v", c.cfg.tomap()); have != want {
t.Errorf("\nhave: %s\nwant: %s", have, want)
Expand Down Expand Up @@ -439,6 +439,19 @@ func TestNewConfig(t *testing.T) {
{"host=a,b,c hostaddr=1.1.1.1,2.2.2.2", nil, "", "could not match 3 host names to 2 hostaddr values"},
{"host=a hostaddr=1.1.1.1,2.2.2.2", nil, "", "could not match 1 host names to 2 hostaddr values"},
{"", []string{"PGHOST=a,,b", "PGHOSTADDR=1.1.1.1,,2.2.2.2", "PGPORT=3,,4"}, "host=a,localhost,b hostaddr=1.1.1.1,,2.2.2.2 port=3,5432,4", ""},

// Protocol version
{"min_protocol_version=3.0", nil, "min_protocol_version=3.0", ""},
{"max_protocol_version=3.2", nil, "max_protocol_version=3.2", ""},
{"min_protocol_version=3.2 max_protocol_version=3.2", nil, "max_protocol_version=3.2 min_protocol_version=3.2", ""},
{"min_protocol_version=latest max_protocol_version=latest", nil, "max_protocol_version=latest min_protocol_version=latest", ""},
{"min_protocol_version=3.0 max_protocol_version=latest", nil, "max_protocol_version=latest min_protocol_version=3.0", ""},
{"", []string{"PGMINPROTOCOLVERSION=3.0", "PGMAXPROTOCOLVERSION=3.2"}, "max_protocol_version=3.2 min_protocol_version=3.0", ""},
{"min_protocol_version=bogus", nil, "", `pq: wrong value for "min_protocol_version": "bogus" is not supported`},
{"max_protocol_version=bogus", nil, "", `pq: wrong value for "max_protocol_version": "bogus" is not supported`},
{"", []string{"PGMINPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMINPROTOCOLVERSION: "bogus" is not supported`},
{"", []string{"PGMAXPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMAXPROTOCOLVERSION: "bogus" is not supported`},
{"min_protocol_version=3.2 max_protocol_version=3.0", nil, "", `min_protocol_version "3.2" cannot be greater than max_protocol_version "3.0"`},
}

t.Parallel()
Expand Down Expand Up @@ -508,7 +521,7 @@ func TestConnectMulti(t *testing.T) {
connectedTo [3]bool
accept = func(n int) func(pqtest.Fake, net.Conn) {
return func(f pqtest.Fake, cn net.Conn) {
clientParams, ok := f.ReadStartup(cn)
_, clientParams, ok := f.ReadStartup(cn)
if !ok {
return
}
Expand Down Expand Up @@ -747,3 +760,98 @@ func TestConnectionTargetSessionAttrs(t *testing.T) {
})
}
}

func TestProtocolVersion(t *testing.T) {
var (
key30 = []byte{1, 2, 3, 4}
key32 = make([]byte, 32)
)
for i := 0; i < 32; i++ {
key32[i] = byte(i)
}
accept := func(version float32) (*[]byte, func(f pqtest.Fake, cn net.Conn)) {
var kd []byte
return &kd, func(f pqtest.Fake, cn net.Conn) {
v, _, ok := f.ReadStartup(cn)
if !ok {
return
}
use := v
if v > version {
use = version
f.WriteNegotiateProtocolVersion(cn, int(version*10-30), nil)
}

f.WriteMsg(cn, proto.AuthenticationRequest, "\x00\x00\x00\x00")
if use >= 3.2 {
kd = key32
} else {
kd = key30
}
f.WriteBackendKeyData(cn, 666, kd)
f.WriteMsg(cn, proto.ReadyForQuery, "I")
for {
code, _, ok := f.ReadMsg(cn)
if !ok {
return
}
switch code {
case proto.Query:
f.WriteMsg(cn, proto.EmptyQueryResponse, "")
f.WriteMsg(cn, proto.ReadyForQuery, "I")
case proto.Terminate:
cn.Close()
return
}
}
}
}

tests := []struct {
serverVersion float32
min, max string
wantKey []byte
wantErr string
}{
{3.2, "", "", key30, ""},
{3.2, "3.0", "3.0", key30, ""},
{3.2, "3.2", "3.2", key32, ""},
{3.2, "3.0", "latest", key32, ""},
{3.2, "latest", "latest", key32, ""},

{3.0, "3.0", "3.2", key30, ""},
{3.0, "3.2", "3.2", nil, `pq: protocol version mismatch: min_protocol_version=3.2; server supports up to 3.0`},

{3.2, "3.9", "3.0", nil, `"3.9" is not supported`},
{3.2, "3.0", "3.9", nil, `"3.9" is not supported`},
{3.2, "3.2", "3.0", nil, `min_protocol_version "3.2" cannot be greater than max_protocol_version "3.0"`},
{3.2, "latest", "3.0", nil, `min_protocol_version "latest" cannot be greater than max_protocol_version "3.0"`},
}

for _, tt := range tests {
tt := tt
t.Run("", func(t *testing.T) {
t.Parallel()
have, a := accept(tt.serverVersion)
f := pqtest.NewFake(t, a)
defer f.Close()

var extra []string
if tt.min != "" {
extra = append(extra, "min_protocol_version="+tt.min)
}
if tt.max != "" {
extra = append(extra, "max_protocol_version="+tt.max)
}

db := pqtest.MustDB(t, f.DSN()+" "+strings.Join(extra, " "))
err := db.Ping()
if !pqtest.ErrorContains(err, tt.wantErr) {
t.Fatalf("wrong error\nhave: %v\nwant: %v", err, tt.wantErr)
}
if tt.wantErr == "" && !reflect.DeepEqual(*have, tt.wantKey) {
t.Fatalf("wrong keydata\nhave: %v\nwant: %v", *have, tt.wantKey)
}
})
}
}
Loading
Loading