mirror of
1
Fork 0

[chore]: Bump github.com/jackc/pgx/v5 from 5.5.1 to 5.5.2 (#2532)

Bumps [github.com/jackc/pgx/v5](https://github.com/jackc/pgx) from 5.5.1 to 5.5.2.
- [Changelog](https://github.com/jackc/pgx/blob/master/CHANGELOG.md)
- [Commits](https://github.com/jackc/pgx/compare/v5.5.1...v5.5.2)

---
updated-dependencies:
- dependency-name: github.com/jackc/pgx/v5
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
This commit is contained in:
dependabot[bot] 2024-01-15 14:02:02 +01:00 committed by GitHub
parent b70ec68499
commit 637a57f2de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 142 additions and 50 deletions

2
go.mod
View File

@ -34,7 +34,7 @@ require (
github.com/gorilla/feeds v1.1.2 github.com/gorilla/feeds v1.1.2
github.com/gorilla/websocket v1.5.1 github.com/gorilla/websocket v1.5.1
github.com/h2non/filetype v1.1.3 github.com/h2non/filetype v1.1.3
github.com/jackc/pgx/v5 v5.5.1 github.com/jackc/pgx/v5 v5.5.2
github.com/microcosm-cc/bluemonday v1.0.26 github.com/microcosm-cc/bluemonday v1.0.26
github.com/miekg/dns v1.1.57 github.com/miekg/dns v1.1.57
github.com/minio/minio-go/v7 v7.0.66 github.com/minio/minio-go/v7 v7.0.66

4
go.sum
View File

@ -346,8 +346,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI= github.com/jackc/pgx/v5 v5.5.2 h1:iLlpgp4Cp/gC9Xuscl7lFL1PhhW+ZLtXZcrfCt4C3tA=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jackc/pgx/v5 v5.5.2/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=

View File

@ -1,3 +1,14 @@
# 5.5.2 (January 13, 2024)
* Allow NamedArgs to start with underscore
* pgproto3: Maximum message body length support (jeremy.spriet)
* Upgrade golang.org/x/crypto to v0.17.0
* Add snake_case support to RowToStructByName (Tikhon Fedulov)
* Fix: update description cache after exec prepare (James Hartig)
* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler)
* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer)
* Add OnPgError for easier centralized error handling (James Hartig)
# 5.5.1 (December 9, 2023) # 5.5.1 (December 9, 2023)
* Add CopyFromFunc helper function. (robford) * Add CopyFromFunc helper function. (robford)

View File

@ -513,6 +513,7 @@ optionLoop:
if err != nil { if err != nil {
return pgconn.CommandTag{}, err return pgconn.CommandTag{}, err
} }
c.descriptionCache.Put(sd)
} }
return c.execParams(ctx, sd, arguments) return c.execParams(ctx, sd, arguments)

View File

@ -187,7 +187,7 @@ implemented on top of pgconn. The Conn.PgConn() method can be used to access thi
PgBouncer PgBouncer
By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode. disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
*/ */
package pgx package pgx

View File

@ -67,6 +67,10 @@ type LargeObject struct {
} }
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
//
// Write is implemented with a single call to lowrite. The PostgreSQL wire protocol has a limit of 1 GB - 1 per message.
// See definition of PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data in the message,
// len(p) should be no larger than 1 GB - 1 KB.
func (o *LargeObject) Write(p []byte) (int, error) { func (o *LargeObject) Write(p []byte) (int, error) {
var n int var n int
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n) err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n)
@ -82,6 +86,10 @@ func (o *LargeObject) Write(p []byte) (int, error) {
} }
// Read reads up to len(p) bytes into p returning the number of bytes read. // Read reads up to len(p) bytes into p returning the number of bytes read.
//
// Read is implemented with a single call to loread. PostgreSQL internally allocates a single buffer for the response.
// The largest buffer PostgreSQL will allocate is 1 GB - 1. See definition of MaxAllocSize in the PostgreSQL source
// code. To allow for the other data in the message, len(p) should be no larger than 1 GB - 1 KB.
func (o *LargeObject) Read(p []byte) (int, error) { func (o *LargeObject) Read(p []byte) (int, error) {
var res []byte var res []byte
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res) err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res)

View File

@ -14,6 +14,9 @@ import (
// //
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}) // conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2) // conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
//
// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
// letters, numbers, or underscores.
type NamedArgs map[string]any type NamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface. // RewriteQuery implements the QueryRewriter interface.
@ -80,7 +83,7 @@ func rawState(l *sqlLexer) stateFn {
return doubleQuoteState return doubleQuoteState
case '@': case '@':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if isLetter(nextRune) { if isLetter(nextRune) || nextRune == '_' {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width]) l.parts = append(l.parts, l.src[l.start:l.pos-width])
} }

View File

@ -60,6 +60,11 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler OnNotification NotificationHandler
// OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule. createdByParseConfig bool // Used to enforce created by ParseConfig rule.
} }
@ -232,12 +237,12 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString) connStringSettings, err = parseURLSettings(connString)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
} }
} else { } else {
connStringSettings, err = parseDSNSettings(connString) connStringSettings, err = parseDSNSettings(connString)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err}
} }
} }
} }
@ -246,7 +251,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if service, present := settings["service"]; present { if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service) serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
} }
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
@ -261,12 +266,19 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w) return pgproto3.NewFrontend(r, w)
}, },
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
// we want to automatically close any fatal errors
if strings.EqualFold(pgErr.Severity, "FATAL") {
return false
}
return true
},
} }
if connectTimeoutSetting, present := settings["connect_timeout"]; present { if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
} }
config.ConnectTimeout = connectTimeout config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
@ -328,7 +340,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
port, err := parsePort(portStr) port, err := parsePort(portStr)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
} }
var tlsConfigs []*tls.Config var tlsConfigs []*tls.Config
@ -340,7 +352,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
var err error var err error
tlsConfigs, err = configTLS(settings, host, options) tlsConfigs, err = configTLS(settings, host, options)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
} }
} }
@ -384,7 +396,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "any": case "any":
// do nothing // do nothing
default: default:
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
} }
return config, nil return config, nil

View File

@ -57,22 +57,23 @@ func (pe *PgError) SQLState() string {
return pe.Code return pe.Code
} }
type connectError struct { // ConnectError is the error returned when a connection attempt fails.
config *Config type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
msg string msg string
err error err error
} }
func (e *connectError) Error() string { func (e *ConnectError) Error() string {
sb := &strings.Builder{} sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg)
if e.err != nil { if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error()) fmt.Fprintf(sb, " (%s)", e.err.Error())
} }
return sb.String() return sb.String()
} }
func (e *connectError) Unwrap() error { func (e *ConnectError) Unwrap() error {
return e.err return e.err
} }
@ -88,33 +89,38 @@ func (e *connLockError) Error() string {
return e.status return e.status
} }
type parseConfigError struct { // ParseConfigError is the error returned when a connection string cannot be parsed.
connString string type ParseConfigError struct {
ConnString string // The connection string that could not be parsed.
msg string msg string
err error err error
} }
func (e *parseConfigError) Error() string { func (e *ParseConfigError) Error() string {
connString := redactPW(e.connString) // Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would
// allow access to the original string if desired and Unwrap would allow access to the underlying error.
connString := redactPW(e.ConnString)
if e.err == nil { if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
} }
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
} }
func (e *parseConfigError) Unwrap() error { func (e *ParseConfigError) Unwrap() error {
return e.err return e.err
} }
func normalizeTimeoutError(ctx context.Context, err error) error { func normalizeTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() { var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
if ctx.Err() == context.Canceled { if ctx.Err() == context.Canceled {
// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error. // Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
return context.Canceled return context.Canceled
} else if ctx.Err() == context.DeadlineExceeded { } else if ctx.Err() == context.DeadlineExceeded {
return &errTimeout{err: ctx.Err()} return &errTimeout{err: ctx.Err()}
} else { } else {
return &errTimeout{err: err} return &errTimeout{err: netErr}
} }
} }
return err return err

View File

@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend
// PgErrorHandler is a function that handles errors returned from Postgres. This function must return true to keep
// the connection open. Returning false will cause the connection to be closed immediately. You should return
// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is
// aware of the origin of the error, but it must not invoke any query method.
type PgErrorHandler func(*PgConn, *PgError) bool
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
@ -146,11 +152,11 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
ctx := octx ctx := octx
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
if err != nil { if err != nil {
return nil, &connectError{config: config, msg: "hostname resolving error", err: err} return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
} }
if len(fallbackConfigs) == 0 { if len(fallbackConfigs) == 0 {
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
} }
foundBestServer := false foundBestServer := false
@ -172,7 +178,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
foundBestServer = true foundBestServer = true
break break
} else if pgerr, ok := err.(*PgError); ok { } else if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr} err = &ConnectError{Config: config, msg: "server error", err: pgerr}
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
@ -183,7 +189,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break break
} }
} else if cerr, ok := err.(*connectError); ok { } else if cerr, ok := err.(*ConnectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok { if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc fallbackConfig = fc
} }
@ -193,7 +199,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
if !foundBestServer && fallbackConfig != nil { if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true) pgConn, err = connect(ctx, config, fallbackConfig, true)
if pgerr, ok := err.(*PgError); ok { if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr} err = &ConnectError{Config: config, msg: "server error", err: pgerr}
} }
} }
@ -205,7 +211,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
err := config.AfterConnect(ctx, pgConn) err := config.AfterConnect(ctx, pgConn)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "AfterConnect error", err: err} return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
} }
} }
@ -277,7 +283,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
netConn, err := config.DialFunc(ctx, network, address) netConn, err := config.DialFunc(ctx, network, address)
if err != nil { if err != nil {
return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
} }
pgConn.conn = netConn pgConn.conn = netConn
@ -289,7 +295,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil { if err != nil {
netConn.Close() netConn.Close()
return nil, &connectError{config: config, msg: "tls error", err: err} return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
} }
pgConn.conn = nbTLSConn pgConn.conn = nbTLSConn
@ -330,7 +336,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.frontend.Send(&startupMsg) pgConn.frontend.Send(&startupMsg)
if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
} }
for { for {
@ -340,7 +346,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err, ok := err.(*PgError); ok { if err, ok := err.(*PgError); ok {
return nil, err return nil, err
} }
return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -353,26 +359,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err = pgConn.txPasswordMessage(pgConn.config.Password) err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err} return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
} }
case *pgproto3.AuthenticationMD5Password: case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword) err = pgConn.txPasswordMessage(digestedPassword)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err} return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
} }
case *pgproto3.AuthenticationSASL: case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms) err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed SASL auth", err: err} return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
} }
case *pgproto3.AuthenticationGSS: case *pgproto3.AuthenticationGSS:
err = pgConn.gssAuth() err = pgConn.gssAuth()
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed GSS auth", err: err} return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
} }
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle pgConn.status = connStatusIdle
@ -390,7 +396,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return pgConn, nil return pgConn, nil
} }
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
} }
} }
return pgConn, nil return pgConn, nil
@ -401,7 +407,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, ErrorResponseToPgError(msg) return nil, ErrorResponseToPgError(msg)
default: default:
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "received unexpected message", err: err} return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err}
} }
} }
} }
@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
case *pgproto3.ParameterStatus: case *pgproto3.ParameterStatus:
pgConn.parameterStatuses[msg.Name] = msg.Value pgConn.parameterStatuses[msg.Name] = msg.Value
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
if msg.Severity == "FATAL" { err := ErrorResponseToPgError(msg)
if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) {
pgConn.status = connStatusClosed pgConn.status = connStatusClosed
pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
close(pgConn.cleanupDone) close(pgConn.cleanupDone)
return nil, ErrorResponseToPgError(msg) return nil, err
} }
case *pgproto3.NoticeResponse: case *pgproto3.NoticeResponse:
if pgConn.config.OnNotice != nil { if pgConn.config.OnNotice != nil {
@ -2046,6 +2053,13 @@ func (p *Pipeline) Flush() error {
// Sync establishes a synchronization point and flushes the queued requests. // Sync establishes a synchronization point and flushes the queued requests.
func (p *Pipeline) Sync() error { func (p *Pipeline) Sync() error {
if p.closed {
if p.err != nil {
return p.err
}
return errors.New("pipeline closed")
}
p.conn.frontend.SendSync(&pgproto3.Sync{}) p.conn.frontend.SendSync(&pgproto3.Sync{})
err := p.Flush() err := p.Flush()
if err != nil { if err != nil {
@ -2062,10 +2076,21 @@ func (p *Pipeline) Sync() error {
// *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no // *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no
// results are available, results and err will both be nil. // results are available, results and err will both be nil.
func (p *Pipeline) GetResults() (results any, err error) { func (p *Pipeline) GetResults() (results any, err error) {
if p.closed {
if p.err != nil {
return nil, p.err
}
return nil, errors.New("pipeline closed")
}
if p.expectedReadyForQueryCount == 0 { if p.expectedReadyForQueryCount == 0 {
return nil, nil return nil, nil
} }
return p.getResults()
}
func (p *Pipeline) getResults() (results any, err error) {
for { for {
msg, err := p.conn.receiveMessage() msg, err := p.conn.receiveMessage()
if err != nil { if err != nil {
@ -2092,7 +2117,8 @@ func (p *Pipeline) GetResults() (results any, err error) {
case *pgproto3.ParseComplete: case *pgproto3.ParseComplete:
peekedMsg, err := p.conn.peekMessage() peekedMsg, err := p.conn.peekMessage()
if err != nil { if err != nil {
return nil, err p.conn.asyncClose()
return nil, normalizeTimeoutError(p.ctx, err)
} }
if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok {
return p.getResultsPrepare() return p.getResultsPrepare()
@ -2152,6 +2178,7 @@ func (p *Pipeline) Close() error {
if p.closed { if p.closed {
return p.err return p.err
} }
p.closed = true p.closed = true
if p.pendingSync { if p.pendingSync {
@ -2164,7 +2191,7 @@ func (p *Pipeline) Close() error {
} }
for p.expectedReadyForQueryCount > 0 { for p.expectedReadyForQueryCount > 0 {
_, err := p.GetResults() _, err := p.getResults()
if err != nil { if err != nil {
p.err = err p.err = err
var pgErr *PgError var pgErr *PgError

View File

@ -38,6 +38,7 @@ type Backend struct {
terminate Terminate terminate Terminate
bodyLen int bodyLen int
maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
msgType byte msgType byte
partialMsg bool partialMsg bool
authType uint32 authType uint32
@ -158,6 +159,9 @@ func (b *Backend) Receive() (FrontendMessage, error) {
b.msgType = header[0] b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen {
return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen}
}
b.partialMsg = true b.partialMsg = true
} }
@ -260,3 +264,12 @@ func (b *Backend) SetAuthType(authType uint32) error {
return nil return nil
} }
// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return
// an error. This is useful for protecting against malicious clients that send large messages with the intent of
// causing memory exhaustion.
// The default value is 0.
// If maxBodyLen is 0, then no maximum is enforced.
func (b *Backend) SetMaxBodyLen(maxBodyLen int) {
b.maxBodyLen = maxBodyLen
}

View File

@ -70,6 +70,15 @@ func (e *writeError) Unwrap() error {
return e.err return e.err
} }
type ExceededMaxBodyLenErr struct {
MaxExpectedBodyLen int
ActualBodyLen int
}
func (e *ExceededMaxBodyLenErr) Error() string {
return fmt.Sprintf("invalid body length: expected at most %d, but got %d", e.MaxExpectedBodyLen, e.ActualBodyLen)
}
// getValueFromJSON gets the value from a protocol message representation in JSON. // getValueFromJSON gets the value from a protocol message representation in JSON.
func getValueFromJSON(v map[string]string) ([]byte, error) { func getValueFromJSON(v map[string]string) ([]byte, error) {
if v == nil { if v == nil {

View File

@ -667,7 +667,12 @@ const structTagKey = "db"
func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
i = -1 i = -1
for i, desc := range fldDescs { for i, desc := range fldDescs {
if strings.EqualFold(desc.Name, field) {
// Snake case support.
field = strings.ReplaceAll(field, "_", "")
descName := strings.ReplaceAll(desc.Name, "_", "")
if strings.EqualFold(descName, field) {
return i return i
} }
} }

View File

@ -21,10 +21,7 @@
// return err // return err
// } // }
// //
// db, err := stdlib.OpenDBFromPool(pool) // db := stdlib.OpenDBFromPool(pool)
// if err != nil {
// return err
// }
// //
// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used

2
vendor/modules.txt vendored
View File

@ -322,7 +322,7 @@ github.com/jackc/pgpassfile
# github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a # github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a
## explicit; go 1.14 ## explicit; go 1.14
github.com/jackc/pgservicefile github.com/jackc/pgservicefile
# github.com/jackc/pgx/v5 v5.5.1 # github.com/jackc/pgx/v5 v5.5.2
## explicit; go 1.19 ## explicit; go 1.19
github.com/jackc/pgx/v5 github.com/jackc/pgx/v5
github.com/jackc/pgx/v5/internal/anynil github.com/jackc/pgx/v5/internal/anynil