
[chore]: Bump github.com/jackc/pgx/v5 from 5.5.5 to 5.6.0 (#2929)

dependabot[bot] 2024-05-27 09:35:41 +00:00 committed by GitHub
parent 3d3e99ae52
commit 0a18c0d802
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 969 additions and 561 deletions

go.mod

@@ -39,7 +39,7 @@ require (
 github.com/gorilla/feeds v1.1.2
 github.com/gorilla/websocket v1.5.1
 github.com/h2non/filetype v1.1.3
-github.com/jackc/pgx/v5 v5.5.5
+github.com/jackc/pgx/v5 v5.6.0
 github.com/microcosm-cc/bluemonday v1.0.26
 github.com/miekg/dns v1.1.59
 github.com/minio/minio-go/v7 v7.0.70

go.sum

@@ -374,8 +374,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
 github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
 github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
 github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
-github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
-github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
+github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
+github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
 github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
 github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
 github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=


@@ -1,3 +1,22 @@
+# 5.6.0 (May 25, 2024)
+
+* Add StrictNamedArgs (Tomas Zahradnicek)
+* Add support for macaddr8 type (Carlos Pérez-Aradros Herce)
+* Add SeverityUnlocalized field to PgError / Notice
+* Performance optimization of RowToStructByPos/Name (Zach Olstein)
+* Allow customizing context canceled behavior for pgconn
+* Add ScanLocation to pgtype.Timestamp[tz]Codec
+* Add custom data to pgconn.PgConn
+* Fix ResultReader.Read() to handle nil values
+* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce)
+* pgconn.SafeToRetry checks for wrapped errors (tjasko)
+* Failed connection attempts include all errors
+* Optimize LargeObject.Read (Mitar)
+* Add tracing for connection acquire and release from pool (ngavinsir)
+* Fix encode driver.Valuer not called when nil
+* Add support for custom JSON marshal and unmarshal (Mitar)
+* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck)
+
 # 5.5.5 (March 9, 2024)
 
 Use spaces instead of parentheses for SQL sanitization.
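
For context on the first changelog entry: StrictNamedArgs behaves like NamedArgs but errors out on missing or extra arguments. A minimal sketch of the difference, assuming a reachable database, an existing *pgx.Conn, and a hypothetical users table (imports: context, github.com/jackc/pgx/v5):

    func lookupUser(ctx context.Context, conn *pgx.Conn, id int64) (string, error) {
        // With plain NamedArgs a missing "id" entry would silently become NULL;
        // StrictNamedArgs instead fails the rewrite with an error such as
        // "argument id found in sql query but not present in StrictNamedArgs".
        var name string
        err := conn.QueryRow(ctx,
            "select name from users where id = @id",
            pgx.StrictNamedArgs{"id": id},
        ).Scan(&name)
        return name, err
    }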


@@ -29,6 +29,7 @@ Create and setup a test database:
 export PGDATABASE=pgx_test
 createdb
 psql -c 'create extension hstore;'
+psql -c 'create extension ltree;'
 psql -c 'create domain uint64 as numeric(20,0);'
 ```


@@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.
 ## Supported Go and PostgreSQL Versions
-pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.20 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
+pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
 ## Version Policy


@@ -12,7 +12,7 @@ import (
 type QueuedQuery struct {
 SQL string
 Arguments []any
-fn batchItemFunc
+Fn batchItemFunc
 sd *pgconn.StatementDescription
 }
@@ -20,7 +20,7 @@ type batchItemFunc func(br BatchResults) error
 // Query sets fn to be called when the response to qq is received.
 func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
-qq.fn = func(br BatchResults) error {
+qq.Fn = func(br BatchResults) error {
 rows, _ := br.Query()
 defer rows.Close()
@@ -36,7 +36,7 @@ func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
 // Query sets fn to be called when the response to qq is received.
 func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
-qq.fn = func(br BatchResults) error {
+qq.Fn = func(br BatchResults) error {
 row := br.QueryRow()
 return fn(row)
 }
@@ -44,7 +44,7 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
 // Exec sets fn to be called when the response to qq is received.
 func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
-qq.fn = func(br BatchResults) error {
+qq.Fn = func(br BatchResults) error {
 ct, err := br.Exec()
 if err != nil {
 return err
@@ -228,8 +228,8 @@ func (br *batchResults) Close() error {
 // Read and run fn for all remaining items
 for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
-if br.b.QueuedQueries[br.qqIdx].fn != nil {
-err := br.b.QueuedQueries[br.qqIdx].fn(br)
+if br.b.QueuedQueries[br.qqIdx].Fn != nil {
+err := br.b.QueuedQueries[br.qqIdx].Fn(br)
 if err != nil {
 br.err = err
 }
@@ -397,8 +397,8 @@ func (br *pipelineBatchResults) Close() error {
 // Read and run fn for all remaining items
 for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
-if br.b.QueuedQueries[br.qqIdx].fn != nil {
-err := br.b.QueuedQueries[br.qqIdx].fn(br)
+if br.b.QueuedQueries[br.qqIdx].Fn != nil {
+err := br.b.QueuedQueries[br.qqIdx].Fn(br)
 if err != nil {
 br.err = err
 }
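
The fn field these hunks export as Fn holds the per-query callback registered through QueuedQuery's Query, QueryRow, and Exec methods. A sketch of how those callbacks are used from the caller's side, assuming ctx, a *pgx.Conn named conn, and a hypothetical widgets table (imports: fmt, github.com/jackc/pgx/v5, github.com/jackc/pgx/v5/pgconn):

    batch := &pgx.Batch{}
    // Queue returns the *QueuedQuery, and QueryRow stores a callback in its Fn field.
    batch.Queue("select count(*) from widgets").QueryRow(func(row pgx.Row) error {
        var n int64
        return row.Scan(&n)
    })
    batch.Queue("update widgets set price = price * 2").Exec(func(ct pgconn.CommandTag) error {
        fmt.Println("updated", ct.RowsAffected(), "rows")
        return nil
    })
    // SendBatch sends everything in one round trip; Close reads the results and
    // runs each stored callback, exactly as the loops above do.
    err := conn.SendBatch(ctx, batch).Close()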


@@ -10,7 +10,6 @@ import (
 "strings"
 "time"
-"github.com/jackc/pgx/v5/internal/anynil"
 "github.com/jackc/pgx/v5/internal/sanitize"
 "github.com/jackc/pgx/v5/internal/stmtcache"
 "github.com/jackc/pgx/v5/pgconn"
@@ -624,7 +623,7 @@ const (
 // to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the
 // statement description on the first round trip and then uses it to execute the query on the second round trip. This
 // may cause problems with connection poolers that switch the underlying connection between round trips. It is safe
-// even when the the database schema is modified concurrently.
+// even when the database schema is modified concurrently.
 QueryExecModeDescribeExec
 // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
@@ -755,7 +754,6 @@ optionLoop:
 }
 c.eqb.reset()
-anynil.NormalizeSlice(args)
 rows := c.getRows(ctx, sql, args)
 var err error


@@ -11,9 +11,10 @@ The primary way of establishing a connection is with [pgx.Connect]:
 conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
-The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
-here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the connection with
-[ConnectConfig] to configure settings such as tracing that cannot be configured with a connection string.
+The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be
+specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the
+connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection
+string.
 Connection Pool
@@ -23,8 +24,8 @@ github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.
 Query Interface
 pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
-ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and
-rows.Err().
+ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(),
+rows.Scan, and rows.Err().
 CollectRows can be used collect all returned rows into a slice.
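
To make the CollectRows/ForEachRow point above concrete, a short sketch (assuming ctx, an established *pgx.Conn named conn, and a hypothetical users table; imports: context, github.com/jackc/pgx/v5):

    type user struct {
        ID   int64
        Name string
    }

    func listUsers(ctx context.Context, conn *pgx.Conn) ([]user, error) {
        rows, err := conn.Query(ctx, "select id, name from users")
        if err != nil {
            return nil, err
        }
        // CollectRows drives rows.Next/Scan/Err/Close itself; RowToStructByName maps
        // the id and name columns onto the struct fields by name.
        return pgx.CollectRows(rows, pgx.RowToStructByName[user])
    }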


@@ -1,10 +1,8 @@
 package pgx
 import (
-"database/sql/driver"
 "fmt"
-"github.com/jackc/pgx/v5/internal/anynil"
 "github.com/jackc/pgx/v5/pgconn"
 "github.com/jackc/pgx/v5/pgtype"
 )
@@ -23,10 +21,15 @@ type ExtendedQueryBuilder struct {
 func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
 eqb.reset()
-anynil.NormalizeSlice(args)
 if sd == nil {
-return eqb.appendParamsForQueryExecModeExec(m, args)
+for i := range args {
+err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
+if err != nil {
+err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
+return err
+}
+}
+return nil
 }
 if len(sd.ParamOIDs) != len(args) {
@@ -113,10 +116,6 @@ func (eqb *ExtendedQueryBuilder) reset() {
 }
 func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
-if anynil.Is(arg) {
-return nil, nil
-}
 if eqb.paramValueBytes == nil {
 eqb.paramValueBytes = make([]byte, 0, 128)
 }
@@ -145,74 +144,3 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui
 return m.FormatCodeForOID(oid)
 }
-// appendParamsForQueryExecModeExec appends the args to eqb.
-//
-// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
-// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
-// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
-// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
-// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
-// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
-// before converting it to date. This means that dates can be shifted by one day. In text format without that double
-// type conversion it takes the date directly and ignores time zone (i.e. it works).
-//
-// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
-// no way to safely use binary or to specify the parameter OIDs.
-func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
-for _, arg := range args {
-if arg == nil {
-err := eqb.appendParam(m, 0, TextFormatCode, arg)
-if err != nil {
-return err
-}
-} else {
-dt, ok := m.TypeForValue(arg)
-if !ok {
-var tv pgtype.TextValuer
-if tv, ok = arg.(pgtype.TextValuer); ok {
-t, err := tv.TextValue()
-if err != nil {
-return err
-}
-dt, ok = m.TypeForOID(pgtype.TextOID)
-if ok {
-arg = t
-}
-}
-}
-if !ok {
-var dv driver.Valuer
-if dv, ok = arg.(driver.Valuer); ok {
-v, err := dv.Value()
-if err != nil {
-return err
-}
-dt, ok = m.TypeForValue(v)
-if ok {
-arg = v
-}
-}
-}
-if !ok {
-var str fmt.Stringer
-if str, ok = arg.(fmt.Stringer); ok {
-dt, ok = m.TypeForOID(pgtype.TextOID)
-if ok {
-arg = str.String()
-}
-}
-}
-if !ok {
-return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
-}
-err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
-if err != nil {
-return err
-}
-}
-}
-return nil
-}


@@ -1,36 +0,0 @@
-package anynil
-import "reflect"
-// Is returns true if value is any type of nil. e.g. nil or []byte(nil).
-func Is(value any) bool {
-if value == nil {
-return true
-}
-refVal := reflect.ValueOf(value)
-switch refVal.Kind() {
-case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
-return refVal.IsNil()
-default:
-return false
-}
-}
-// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified.
-func Normalize(v any) any {
-if Is(v) {
-return nil
-}
-return v
-}
-// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is
-// mutated in place.
-func NormalizeSlice(s []any) {
-for i := range s {
-if Is(s[i]) {
-s[i] = nil
-}
-}
-}


@@ -4,6 +4,8 @@ import (
 "context"
 "errors"
 "io"
+"github.com/jackc/pgx/v5/pgtype"
 )
 // The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
@@ -115,9 +117,10 @@ func (o *LargeObject) Read(p []byte) (int, error) {
 expected = maxLargeObjectMessageLength
 }
-var res []byte
+res := pgtype.PreallocBytes(p[nTotal:])
 err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
-copy(p[nTotal:], res)
+// We compute expected so that it always fits into p, so it should never happen
+// that PreallocBytes's ScanBytes had to allocate a new slice.
 nTotal += len(res)
 if err != nil {
 return nTotal, err
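
The optimized Read above scans straight into the caller's buffer with pgtype.PreallocBytes rather than allocating and copying. The same technique in isolation, as a sketch assuming ctx, a *pgx.Conn named conn, and a hypothetical blobs table:

    // PreallocBytes asks the bytea scan plan to decode into the supplied slice when
    // the value fits; only oversized values fall back to a fresh allocation.
    buf := make([]byte, 64*1024)
    res := pgtype.PreallocBytes(buf)
    err := conn.QueryRow(ctx, "select data from blobs where id = $1", 1).Scan(&res)
    if err == nil {
        fmt.Printf("read %d bytes without an extra copy\n", len(res))
    }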


@@ -2,6 +2,7 @@ package pgx
 import (
 "context"
+"fmt"
 "strconv"
 "strings"
 "unicode/utf8"
@@ -21,6 +22,34 @@ type NamedArgs map[string]any
 // RewriteQuery implements the QueryRewriter interface.
 func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
+return rewriteQuery(na, sql, false)
+}
+// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
+// named arguments that the sql query uses, and no extra arguments.
+type StrictNamedArgs map[string]any
+// RewriteQuery implements the QueryRewriter interface.
+func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
+return rewriteQuery(sna, sql, true)
+}
+type namedArg string
+type sqlLexer struct {
+src string
+start int
+pos int
+nested int // multiline comment nesting level.
+stateFn stateFn
+parts []any
+nameToOrdinal map[namedArg]int
+}
+type stateFn func(*sqlLexer) stateFn
+func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
 l := &sqlLexer{
 src: sql,
 stateFn: rawState,
@@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
 newArgs = make([]any, len(l.nameToOrdinal))
 for name, ordinal := range l.nameToOrdinal {
-newArgs[ordinal-1] = na[string(name)]
+var found bool
+newArgs[ordinal-1], found = na[string(name)]
+if isStrict && !found {
+return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
+}
+}
+if isStrict {
+for name := range na {
+if _, found := l.nameToOrdinal[namedArg(name)]; !found {
+return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
+}
+}
 }
 return sb.String(), newArgs, nil
 }
-type namedArg string
-type sqlLexer struct {
-src string
-start int
-pos int
-nested int // multiline comment nesting level.
-stateFn stateFn
-parts []any
-nameToOrdinal map[namedArg]int
-}
-type stateFn func(*sqlLexer) stateFn
 func rawState(l *sqlLexer) stateFn {
 for {
 r, width := utf8.DecodeRuneInString(l.src[l.pos:])


@@ -19,6 +19,7 @@ import (
 "github.com/jackc/pgpassfile"
 "github.com/jackc/pgservicefile"
+"github.com/jackc/pgx/v5/pgconn/ctxwatch"
 "github.com/jackc/pgx/v5/pgproto3"
 )
@@ -39,7 +40,12 @@ type Config struct {
 DialFunc DialFunc // e.g. net.Dialer.DialContext
 LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
 BuildFrontend BuildFrontendFunc
-RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
+// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
+// when a context passed to a PgConn method is canceled.
+BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
+
+RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
 KerberosSrvName string
 KerberosSpn string
@@ -70,7 +76,7 @@ type Config struct {
 // ParseConfigOptions contains options that control how a config is built such as GetSSLPassword.
 type ParseConfigOptions struct {
-// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function
+// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function
 // PQsetSSLKeyPassHook_OpenSSL.
 GetSSLPassword GetSSLPasswordFunc
 }
@@ -112,6 +118,14 @@ type FallbackConfig struct {
 TLSConfig *tls.Config // nil disables TLS
 }
+// connectOneConfig is the configuration for a single attempt to connect to a single host.
+type connectOneConfig struct {
+network string
+address string
+originalHostname string // original hostname before resolving
+tlsConfig *tls.Config // nil disables TLS
+}
 // isAbsolutePath checks if the provided value is an absolute path either
 // beginning with a forward slash (as on Linux-based systems) or with a capital
 // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
@@ -146,11 +160,11 @@ func NetworkAddress(host string, port uint16) (network, address string) {
 // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
 // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
-// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
-// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
-// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
+// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See
+// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty
+// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
 //
-// # Example DSN
+// # Example Keyword/Value
 // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
 //
 // # Example URL
@@ -169,7 +183,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
 // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
 //
 // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
-// via database URL or DSN:
+// via database URL or keyword/value:
 //
 // PGHOST
 // PGPORT
@@ -233,16 +247,16 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
 connStringSettings := make(map[string]string)
 if connString != "" {
 var err error
-// connString may be a database URL or a DSN
+// connString may be a database URL or in PostgreSQL keyword/value format
 if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
 connStringSettings, err = parseURLSettings(connString)
 if err != nil {
 return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
 }
 } else {
-connStringSettings, err = parseDSNSettings(connString)
+connStringSettings, err = parseKeywordValueSettings(connString)
 if err != nil {
-return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err}
+return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err}
 }
 }
 }
@@ -266,6 +280,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
 BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
 return pgproto3.NewFrontend(r, w)
 },
+BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
+return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
+},
 OnPgError: func(_ *PgConn, pgErr *PgError) bool {
 // we want to automatically close any fatal errors
 if strings.EqualFold(pgErr.Severity, "FATAL") {
@@ -517,7 +534,7 @@ func isIPOnly(host string) bool {
 var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
-func parseDSNSettings(s string) (map[string]string, error) {
+func parseKeywordValueSettings(s string) (map[string]string, error) {
 settings := make(map[string]string)
 nameMap := map[string]string{
@@ -528,7 +545,7 @@ func parseDSNSettings(s string) (map[string]string, error) {
 var key, val string
 eqIdx := strings.IndexRune(s, '=')
 if eqIdx < 0 {
-return nil, errors.New("invalid dsn")
+return nil, errors.New("invalid keyword/value")
 }
 key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
@@ -580,7 +597,7 @@ func parseDSNSettings(s string) (map[string]string, error) {
 }
 if key == "" {
-return nil, errors.New("invalid dsn")
+return nil, errors.New("invalid keyword/value")
 }
 settings[key] = val
@@ -800,7 +817,8 @@ func parsePort(s string) (uint16, error) {
 }
 func makeDefaultDialer() *net.Dialer {
-return &net.Dialer{KeepAlive: 5 * time.Minute}
+// rely on GOLANG KeepAlive settings
+return &net.Dialer{}
 }
 func makeDefaultResolver() *net.Resolver {


@@ -8,9 +8,8 @@ import (
 // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
 // time.
 type ContextWatcher struct {
-onCancel func()
-onUnwatchAfterCancel func()
-unwatchChan chan struct{}
+handler Handler
+unwatchChan chan struct{}
 lock sync.Mutex
 watchInProgress bool
@@ -20,11 +19,10 @@ type ContextWatcher struct {
 // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
 // OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
 // onCancel called.
-func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
+func NewContextWatcher(handler Handler) *ContextWatcher {
 cw := &ContextWatcher{
-onCancel: onCancel,
-onUnwatchAfterCancel: onUnwatchAfterCancel,
-unwatchChan: make(chan struct{}),
+handler: handler,
+unwatchChan: make(chan struct{}),
 }
 return cw
@@ -46,7 +44,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) {
 go func() {
 select {
 case <-ctx.Done():
-cw.onCancel()
+cw.handler.HandleCancel(ctx)
 cw.onCancelWasCalled = true
 <-cw.unwatchChan
 case <-cw.unwatchChan:
@@ -66,8 +64,17 @@ func (cw *ContextWatcher) Unwatch() {
 if cw.watchInProgress {
 cw.unwatchChan <- struct{}{}
 if cw.onCancelWasCalled {
-cw.onUnwatchAfterCancel()
+cw.handler.HandleUnwatchAfterCancel()
 }
 cw.watchInProgress = false
 }
 }
+type Handler interface {
+// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
+// context that was canceled.
+HandleCancel(canceledCtx context.Context)
+// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
+HandleUnwatchAfterCancel()
+}
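
The Handler interface added here is what replaces the two stored callbacks. A minimal custom implementation, sketched with illustrative logging (imports: context, log, net, time):

    // loggingDeadlineHandler behaves like the old callbacks: it interrupts blocked
    // I/O by setting an immediate deadline, and clears it again on unwatch.
    type loggingDeadlineHandler struct {
        conn net.Conn
    }

    func (h *loggingDeadlineHandler) HandleCancel(canceledCtx context.Context) {
        log.Printf("query context canceled: %v", canceledCtx.Err())
        h.conn.SetDeadline(time.Unix(1, 0)) // a time in the past unblocks reads/writes
    }

    func (h *loggingDeadlineHandler) HandleUnwatchAfterCancel() {
        h.conn.SetDeadline(time.Time{}) // remove the deadline for the next operation
    }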


@@ -5,8 +5,8 @@ nearly the same level is the C library libpq.
 Establishing a Connection
-Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
-libpq style environment variables.
+Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the
+environment for libpq style environment variables.
 Executing a Query
@@ -20,13 +20,17 @@ result. The ReadAll method reads all query results into memory.
 Pipeline Mode
-Pipeline mode allows sending queries without having read the results of previously sent queries. It allows
-control of exactly how many and when network round trips occur.
+Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of
+exactly how many and when network round trips occur.
 Context Support
-All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
-method immediately returns. In most circumstances, this will close the underlying connection.
+All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the
+method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can
+be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior.
+This can be especially useful when queries that are frequently canceled and the overhead of creating new connections is
+a problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before
+interrupting the query in such a way as to close the connection.
 The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
 client to abort.
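
Wiring that customization through the new Config field might look like the following sketch; the one-second grace period is an arbitrary illustration and the field names are taken from this diff (imports: context, os, time, github.com/jackc/pgx/v5/pgconn, github.com/jackc/pgx/v5/pgconn/ctxwatch):

    cfg, err := pgconn.ParseConfig(os.Getenv("DATABASE_URL"))
    if err != nil {
        panic(err)
    }
    // A DeadlineContextWatcherHandler with a DeadlineDelay waits that long after a
    // cancellation before setting the connection deadline, so quick queries can
    // still complete instead of the connection being torn down immediately.
    cfg.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
        return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn(), DeadlineDelay: time.Second}
    }
    pgConn, err := pgconn.ConnectConfig(context.Background(), cfg)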


@@ -12,13 +12,14 @@ import (
 // SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
 func SafeToRetry(err error) bool {
-if e, ok := err.(interface{ SafeToRetry() bool }); ok {
-return e.SafeToRetry()
+var retryableErr interface{ SafeToRetry() bool }
+if errors.As(err, &retryableErr) {
+return retryableErr.SafeToRetry()
 }
 return false
 }
-// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
+// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
 // context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
 func Timeout(err error) bool {
 var timeoutErr *errTimeout
@@ -29,23 +30,24 @@ func Timeout(err error) bool {
 // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
 // detailed field description.
 type PgError struct {
 Severity string
+SeverityUnlocalized string
 Code string
 Message string
 Detail string
 Hint string
 Position int32
 InternalPosition int32
 InternalQuery string
 Where string
 SchemaName string
 TableName string
 ColumnName string
 DataTypeName string
 ConstraintName string
 File string
 Line int32
 Routine string
 }
 func (pe *PgError) Error() string {
@@ -60,23 +62,37 @@ func (pe *PgError) SQLState() string {
 // ConnectError is the error returned when a connection attempt fails.
 type ConnectError struct {
 Config *Config // The configuration that was used in the connection attempt.
-msg string
 err error
 }
 func (e *ConnectError) Error() string {
-sb := &strings.Builder{}
-fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg)
-if e.err != nil {
-fmt.Fprintf(sb, " (%s)", e.err.Error())
-}
-return sb.String()
+prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database)
+details := e.err.Error()
+if strings.Contains(details, "\n") {
+return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t")
+} else {
+return prefix + " " + details
+}
 }
 func (e *ConnectError) Unwrap() error {
 return e.err
 }
+type perDialConnectError struct {
+address string
+originalHostname string
+err error
+}
+func (e *perDialConnectError) Error() string {
+return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error())
+}
+func (e *perDialConnectError) Unwrap() error {
+return e.err
+}
 type connLockError struct {
 status string
 }
@@ -195,10 +211,10 @@ func redactPW(connString string) string {
 return redactURL(u)
 }
 }
-quotedDSN := regexp.MustCompile(`password='[^']*'`)
-connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
-plainDSN := regexp.MustCompile(`password=[^ ]*`)
-connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
+quotedKV := regexp.MustCompile(`password='[^']*'`)
+connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx")
+plainKV := regexp.MustCompile(`password=[^ ]*`)
+connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx")
 brokenURL := regexp.MustCompile(`:[^:@]+?@`)
 connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
 return connString
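
One consumer-visible effect of the new SeverityUnlocalized field: it always carries the English severity keyword, so it is safe to compare even when the server localizes Severity. A sketch, assuming ctx, a *pgx.Conn named conn, and a hypothetical accounts table (imports: errors, log, github.com/jackc/pgx/v5/pgconn):

    _, err := conn.Exec(ctx, "insert into accounts (id) values (1)")
    var pgErr *pgconn.PgError
    if errors.As(err, &pgErr) {
        // Severity may be translated by the server's lc_messages setting;
        // SeverityUnlocalized is always one of ERROR, FATAL, PANIC, etc.
        log.Printf("sqlstate=%s severity=%s (%s): %s",
            pgErr.Code, pgErr.SeverityUnlocalized, pgErr.Severity, pgErr.Message)
    }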


@@ -18,8 +18,8 @@ import (
 "github.com/jackc/pgx/v5/internal/iobufpool"
 "github.com/jackc/pgx/v5/internal/pgio"
+"github.com/jackc/pgx/v5/pgconn/ctxwatch"
 "github.com/jackc/pgx/v5/pgconn/internal/bgreader"
-"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
 "github.com/jackc/pgx/v5/pgproto3"
 )
@@ -82,6 +82,8 @@ type PgConn struct {
 slowWriteTimer *time.Timer
 bgReaderStarted chan struct{}
+customData map[string]any
 config *Config
 status byte // One of connStatus* constants
@@ -103,8 +105,9 @@ type PgConn struct {
 cleanupDone chan struct{}
 }
-// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
-// to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a connect attempt.
+// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value
+// format) to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a
+// connect attempt.
 func Connect(ctx context.Context, connString string) (*PgConn, error) {
 config, err := ParseConfig(connString)
 if err != nil {
@@ -114,9 +117,9 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
 return ConnectConfig(ctx, config)
 }
-// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
-// and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. ctx can be
-// used to cancel a connect attempt.
+// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value
+// format) and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details.
+// ctx can be used to cancel a connect attempt.
 func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
 config, err := ParseConfigWithOptions(connString, parseConfigOptions)
 if err != nil {
@@ -131,15 +134,46 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio
 //
 // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
 // authentication error will terminate the chain of attempts (like libpq:
-// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise,
-// if all attempts fail the last error is returned.
-func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) {
+// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error.
+func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) {
 // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
 // zero values.
 if !config.createdByParseConfig {
 panic("config must be created by ParseConfig")
 }
+var allErrors []error
+connectConfigs, errs := buildConnectOneConfigs(ctx, config)
+if len(errs) > 0 {
+allErrors = append(allErrors, errs...)
+}
+if len(connectConfigs) == 0 {
+return nil, &ConnectError{Config: config, err: fmt.Errorf("hostname resolving error: %w", errors.Join(allErrors...))}
+}
+pgConn, errs := connectPreferred(ctx, config, connectConfigs)
+if len(errs) > 0 {
+allErrors = append(allErrors, errs...)
+return nil, &ConnectError{Config: config, err: errors.Join(allErrors...)}
+}
+if config.AfterConnect != nil {
+err := config.AfterConnect(ctx, pgConn)
+if err != nil {
+pgConn.conn.Close()
+return nil, &ConnectError{Config: config, err: fmt.Errorf("AfterConnect error: %w", err)}
+}
+}
+return pgConn, nil
+}
+// buildConnectOneConfigs resolves hostnames and builds a list of connectOneConfigs to try connecting to. It returns a
+// slice of successfully resolved connectOneConfigs and a slice of errors. It is possible for both slices to contain
+// values if some hosts were successfully resolved and others were not.
+func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneConfig, []error) {
 // Simplify usage by treating primary config and fallbacks the same.
 fallbackConfigs := []*FallbackConfig{
 {
@@ -149,95 +183,28 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
 },
 }
 fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
-ctx := octx
-fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
-if err != nil {
-return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
-}
-if len(fallbackConfigs) == 0 {
-return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
-}
-foundBestServer := false
-var fallbackConfig *FallbackConfig
-for i, fc := range fallbackConfigs {
-// ConnectTimeout restricts the whole connection process.
-if config.ConnectTimeout != 0 {
-// create new context first time or when previous host was different
-if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) {
-var cancel context.CancelFunc
-ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
-defer cancel()
-}
-} else {
-ctx = octx
-}
-pgConn, err = connect(ctx, config, fc, false)
-if err == nil {
-foundBestServer = true
-break
-} else if pgerr, ok := err.(*PgError); ok {
-err = &ConnectError{Config: config, msg: "server error", err: pgerr}
-const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
-const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
-const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
-const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
-if pgerr.Code == ERRCODE_INVALID_PASSWORD ||
-pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil ||
-pgerr.Code == ERRCODE_INVALID_CATALOG_NAME ||
-pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
-break
-}
-} else if cerr, ok := err.(*ConnectError); ok {
-if _, ok := cerr.err.(*NotPreferredError); ok {
-fallbackConfig = fc
-}
-}
-}
-if !foundBestServer && fallbackConfig != nil {
-pgConn, err = connect(ctx, config, fallbackConfig, true)
-if pgerr, ok := err.(*PgError); ok {
-err = &ConnectError{Config: config, msg: "server error", err: pgerr}
-}
-}
-if err != nil {
-return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError
-}
-if config.AfterConnect != nil {
-err := config.AfterConnect(ctx, pgConn)
-if err != nil {
-pgConn.conn.Close()
-return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
-}
-}
-return pgConn, nil
-}
-func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) {
-var configs []*FallbackConfig
-var lookupErrors []error
-for _, fb := range fallbacks {
+var configs []*connectOneConfig
+var allErrors []error
+for _, fb := range fallbackConfigs {
 // skip resolve for unix sockets
 if isAbsolutePath(fb.Host) {
-configs = append(configs, &FallbackConfig{
-Host: fb.Host,
-Port: fb.Port,
-TLSConfig: fb.TLSConfig,
+network, address := NetworkAddress(fb.Host, fb.Port)
+configs = append(configs, &connectOneConfig{
+network: network,
+address: address,
+originalHostname: fb.Host,
+tlsConfig: fb.TLSConfig,
 })
 continue
 }
-ips, err := lookupFn(ctx, fb.Host)
+ips, err := config.LookupFunc(ctx, fb.Host)
 if err != nil {
-lookupErrors = append(lookupErrors, err)
+allErrors = append(allErrors, err)
 continue
 }
@@ -246,63 +213,126 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
 if err == nil {
 port, err := strconv.ParseUint(splitPort, 10, 16)
 if err != nil {
-return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)
+return nil, []error{fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)}
 }
-configs = append(configs, &FallbackConfig{
-Host: splitIP,
-Port: uint16(port),
-TLSConfig: fb.TLSConfig,
+network, address := NetworkAddress(splitIP, uint16(port))
+configs = append(configs, &connectOneConfig{
+network: network,
+address: address,
+originalHostname: fb.Host,
+tlsConfig: fb.TLSConfig,
 })
 } else {
-configs = append(configs, &FallbackConfig{
-Host: ip,
-Port: fb.Port,
-TLSConfig: fb.TLSConfig,
+network, address := NetworkAddress(ip, fb.Port)
+configs = append(configs, &connectOneConfig{
+network: network,
+address: address,
+originalHostname: fb.Host,
+tlsConfig: fb.TLSConfig,
 })
 }
 }
 }
-// See https://github.com/jackc/pgx/issues/1464. When Go 1.20 can be used in pgx consider using errors.Join so all
-// errors are reported.
-if len(configs) == 0 && len(lookupErrors) > 0 {
-return nil, lookupErrors[0]
-}
-return configs, nil
+return configs, allErrors
 }
-func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
+// connectPreferred attempts to connect to the preferred host from connectOneConfigs. The connections are attempted in
+// order. If a connection is successful it is returned. If no connection is successful then all errors are returned. If
+// a connection attempt returns a [NotPreferredError], then that host will be used if no other hosts are successful.
+func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*connectOneConfig) (*PgConn, []error) {
+octx := ctx
+var allErrors []error
+var fallbackConnectOneConfig *connectOneConfig
+for i, c := range connectOneConfigs {
+// ConnectTimeout restricts the whole connection process.
+if config.ConnectTimeout != 0 {
+// create new context first time or when previous host was different
+if i == 0 || (connectOneConfigs[i].address != connectOneConfigs[i-1].address) {
+var cancel context.CancelFunc
+ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
+defer cancel()
+}
+} else {
+ctx = octx
+}
+pgConn, err := connectOne(ctx, config, c, false)
+if pgConn != nil {
+return pgConn, nil
+}
+allErrors = append(allErrors, err)
+var pgErr *PgError
+if errors.As(err, &pgErr) {
+const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
+const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
+const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
+const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
+if pgErr.Code == ERRCODE_INVALID_PASSWORD ||
+pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil ||
+pgErr.Code == ERRCODE_INVALID_CATALOG_NAME ||
+pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
+return nil, allErrors
+}
+}
+var npErr *NotPreferredError
+if errors.As(err, &npErr) {
+fallbackConnectOneConfig = c
+}
+}
+if fallbackConnectOneConfig != nil {
+pgConn, err := connectOne(ctx, config, fallbackConnectOneConfig, true)
+if err == nil {
+return pgConn, nil
+}
+allErrors = append(allErrors, err)
+}
+return nil, allErrors
+}
+// connectOne makes one connection attempt to a single host.
+func connectOne(ctx context.Context, config *Config, connectConfig *connectOneConfig,
 ignoreNotPreferredErr bool,
 ) (*PgConn, error) {
 pgConn := new(PgConn)
 pgConn.config = config
 pgConn.cleanupDone = make(chan struct{})
+pgConn.customData = make(map[string]any)
 var err error
-network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
-netConn, err := config.DialFunc(ctx, network, address)
-if err != nil {
-return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
+newPerDialConnectError := func(msg string, err error) *perDialConnectError {
+err = normalizeTimeoutError(ctx, err)
+e := &perDialConnectError{address: connectConfig.address, originalHostname: connectConfig.originalHostname, err: fmt.Errorf("%s: %w", msg, err)}
+return e
 }
-pgConn.conn = netConn
-pgConn.contextWatcher = newContextWatcher(netConn)
-pgConn.contextWatcher.Watch(ctx)
-if fallbackConfig.TLSConfig != nil {
-nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
+pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address)
+if err != nil {
+return nil, newPerDialConnectError("dial error", err)
+}
+if connectConfig.tlsConfig != nil {
+pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
+pgConn.contextWatcher.Watch(ctx)
+tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig)
 pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
 if err != nil {
-netConn.Close()
-return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
+pgConn.conn.Close()
+return nil, newPerDialConnectError("tls error", err)
 }
-pgConn.conn = nbTLSConn
-pgConn.contextWatcher = newContextWatcher(nbTLSConn)
-pgConn.contextWatcher.Watch(ctx)
+pgConn.conn = tlsConn
 }
+pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn))
+pgConn.contextWatcher.Watch(ctx)
 defer pgConn.contextWatcher.Unwatch()
 pgConn.parameterStatuses = make(map[string]string)
@@ -336,7 +366,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
 pgConn.frontend.Send(&startupMsg)
 if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
 pgConn.conn.Close()
-return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
+return nil, newPerDialConnectError("failed to write startup message", err)
 }
 for {
@@ -344,9 +374,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
 if err != nil {
 pgConn.conn.Close()
 if err, ok := err.(*PgError); ok {
-return nil, err
+return nil, newPerDialConnectError("server error", err)
 }
-return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
+return nil, newPerDialConnectError("failed to receive message", err)
 }
 switch msg := msg.(type) {
@@ -359,26 +389,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
 err = pgConn.txPasswordMessage(pgConn.config.Password)
 if err != nil {
 pgConn.conn.Close()
-return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
+return nil, newPerDialConnectError("failed to write password message", err)
 }
 case *pgproto3.AuthenticationMD5Password:
 digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
 err = pgConn.txPasswordMessage(digestedPassword)
 if err != nil {
 pgConn.conn.Close()
-return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
+return nil, newPerDialConnectError("failed to write password message", err)
 }
 case *pgproto3.AuthenticationSASL:
 err = pgConn.scramAuth(msg.AuthMechanisms)
 if err != nil {
 pgConn.conn.Close()
-return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
+return nil, newPerDialConnectError("failed SASL auth", err)
 }
 case *pgproto3.AuthenticationGSS:
 err = pgConn.gssAuth()
 if err != nil {
 pgConn.conn.Close()
-return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
+return nil, newPerDialConnectError("failed GSS auth", err)
 }
 case *pgproto3.ReadyForQuery:
 pgConn.status = connStatusIdle
@@ -396,7 +426,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
 return pgConn, nil
 }
 pgConn.conn.Close()
-return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
+return nil, newPerDialConnectError("ValidateConnect failed", err)
 }
 }
 return pgConn, nil
@@ -404,21 +434,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
 // handled by ReceiveMessage
 case *pgproto3.ErrorResponse:
 pgConn.conn.Close()
-return nil, ErrorResponseToPgError(msg)
+return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg))
 default:
 pgConn.conn.Close()
-return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err}
+return nil, newPerDialConnectError("received unexpected message", err)
 }
 }
 }
-func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
-return ctxwatch.NewContextWatcher(
-func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
-func() { conn.SetDeadline(time.Time{}) },
-)
-}
 func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
 err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
 if err != nil {
@@ -928,23 +951,24 @@ func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error {
 // ErrorResponseToPgError converts a wire protocol error message to a *PgError.
 func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
 return &PgError{
 Severity: msg.Severity,
+SeverityUnlocalized: msg.SeverityUnlocalized,
 Code: string(msg.Code),
 Message: string(msg.Message),
 Detail: string(msg.Detail),
 Hint: msg.Hint,
 Position: msg.Position,
 InternalPosition: msg.InternalPosition,
 InternalQuery: string(msg.InternalQuery),
 Where: string(msg.Where),
 SchemaName: string(msg.SchemaName),
 TableName: string(msg.TableName),
 ColumnName: string(msg.ColumnName),
 DataTypeName: string(msg.DataTypeName),
 ConstraintName: msg.ConstraintName,
 File: string(msg.File),
 Line: msg.Line,
 Routine: string(msg.Routine),
 }
 }
@@ -987,10 +1011,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
 defer cancelConn.Close()
 if ctx != context.Background() {
-contextWatcher := ctxwatch.NewContextWatcher(
-func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
-func() { cancelConn.SetDeadline(time.Time{}) },
-)
+contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn})
 contextWatcher.Watch(ctx)
 defer contextWatcher.Unwatch()
 }
@ -1523,8 +1544,10 @@ func (rr *ResultReader) Read() *Result {
values := rr.Values() values := rr.Values()
row := make([][]byte, len(values)) row := make([][]byte, len(values))
for i := range row { for i := range row {
row[i] = make([]byte, len(values[i])) if values[i] != nil {
copy(row[i], values[i]) row[i] = make([]byte, len(values[i]))
copy(row[i], values[i])
}
} }
br.Rows = append(br.Rows, row) br.Rows = append(br.Rows, row)
} }
@ -1879,6 +1902,11 @@ func (pgConn *PgConn) SyncConn(ctx context.Context) error {
return errors.New("SyncConn: conn never synchronized") return errors.New("SyncConn: conn never synchronized")
} }
// CustomData returns a map that can be used to associate custom data with the connection.
func (pgConn *PgConn) CustomData() map[string]any {
return pgConn.customData
}
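
For orientation, a minimal sketch of how the new per-connection data map might be used; the key name and logging are illustrative, not part of the pgx API:

package main

import (
	"context"
	"log"
	"os"
	"time"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	ctx := context.Background()
	conn, err := pgconn.Connect(ctx, os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	// Attach arbitrary application data to the connection. The same map is
	// carried through Hijack/Construct, as the hunks below show.
	conn.CustomData()["connectedAt"] = time.Now()

	if v, ok := conn.CustomData()["connectedAt"].(time.Time); ok {
		log.Printf("connected at %s", v)
	}
}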
// HijackedConn is the result of hijacking a connection. // HijackedConn is the result of hijacking a connection.
// //
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
@ -1891,6 +1919,7 @@ type HijackedConn struct {
TxStatus byte TxStatus byte
Frontend *pgproto3.Frontend Frontend *pgproto3.Frontend
Config *Config Config *Config
CustomData map[string]any
} }
// Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately // Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately
@ -1913,6 +1942,7 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
TxStatus: pgConn.txStatus, TxStatus: pgConn.txStatus,
Frontend: pgConn.frontend, Frontend: pgConn.frontend,
Config: pgConn.config, Config: pgConn.config,
CustomData: pgConn.customData,
}, nil }, nil
} }
@ -1932,13 +1962,14 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
txStatus: hc.TxStatus, txStatus: hc.TxStatus,
frontend: hc.Frontend, frontend: hc.Frontend,
config: hc.Config, config: hc.Config,
customData: hc.CustomData,
status: connStatusIdle, status: connStatusIdle,
cleanupDone: make(chan struct{}), cleanupDone: make(chan struct{}),
} }
pgConn.contextWatcher = newContextWatcher(pgConn.conn) pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn))
pgConn.bgReader = bgreader.New(pgConn.conn) pgConn.bgReader = bgreader.New(pgConn.conn)
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
func() { func() {
@ -2245,3 +2276,71 @@ func (p *Pipeline) Close() error {
return p.err return p.err
} }
// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn.
type DeadlineContextWatcherHandler struct {
Conn net.Conn
// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled.
DeadlineDelay time.Duration
}
func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) {
h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay))
}
func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() {
h.Conn.SetDeadline(time.Time{})
}
// CancelRequestContextWatcherHandler handles canceled contexts by sending a cancel request to the server. It also sets
// a deadline on a net.Conn as a fallback.
type CancelRequestContextWatcherHandler struct {
Conn *PgConn
// CancelRequestDelay is the delay before sending the cancel request to the server.
CancelRequestDelay time.Duration
// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled.
DeadlineDelay time.Duration
cancelFinishedChan chan struct{}
handleUnwatchAfterCancelCalled func()
}
func (h *CancelRequestContextWatcherHandler) HandleCancel(context.Context) {
h.cancelFinishedChan = make(chan struct{})
var handleUnwatchedAfterCancelCalledCtx context.Context
handleUnwatchedAfterCancelCalledCtx, h.handleUnwatchAfterCancelCalled = context.WithCancel(context.Background())
deadline := time.Now().Add(h.DeadlineDelay)
h.Conn.conn.SetDeadline(deadline)
go func() {
defer close(h.cancelFinishedChan)
select {
case <-handleUnwatchedAfterCancelCalledCtx.Done():
return
case <-time.After(h.CancelRequestDelay):
}
cancelRequestCtx, cancel := context.WithDeadline(handleUnwatchedAfterCancelCalledCtx, deadline)
defer cancel()
h.Conn.CancelRequest(cancelRequestCtx)
// CancelRequest is inherently racy. Even though the cancel request has been received by the server at this point,
// it hasn't necessarily been delivered to the other connection. If we immediately return and the connection is
// immediately used then it is possible the CancelRequest will actually cancel our next query. The
// TestCancelRequestContextWatcherHandler Stress test can produce this error without the sleep below. The sleep time
// is arbitrary, but should be sufficient to prevent this error case.
time.Sleep(100 * time.Millisecond)
}()
}
func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() {
h.handleUnwatchAfterCancelCalled()
<-h.cancelFinishedChan
h.Conn.conn.SetDeadline(time.Time{})
}
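
As a rough usage sketch only: assuming the pgconn Config exposes the BuildContextWatcherHandler hook referenced in Construct above, and that the ctxwatch package lives at the import path shown, a caller could opt into the cancel-request behavior like this:

package main

import (
	"context"
	"time"

	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgconn/ctxwatch"
)

func connectWithCancelRequests(ctx context.Context, dsn string) (*pgconn.PgConn, error) {
	cfg, err := pgconn.ParseConfig(dsn)
	if err != nil {
		return nil, err
	}
	// Assumption: Config.BuildContextWatcherHandler chooses the handler used when a
	// context is canceled; here the server-side cancel request variant replaces the
	// default deadline-only behavior.
	cfg.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
		return &pgconn.CancelRequestContextWatcherHandler{
			Conn:               conn,
			CancelRequestDelay: 500 * time.Millisecond,
			DeadlineDelay:      time.Second,
		}
	}
	return pgconn.ConnectConfig(ctx, cfg)
}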


@ -99,7 +99,7 @@ func getValueFromJSON(v map[string]string) ([]byte, error) {
return nil, errors.New("unknown protocol representation") return nil, errors.New("unknown protocol representation")
} }
// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to // beginMessage begins a new message of type t. It appends the message type and a placeholder for the message length to
// dst. It returns the new buffer and the position of the message length placeholder. // dst. It returns the new buffer and the position of the message length placeholder.
func beginMessage(dst []byte, t byte) ([]byte, int) { func beginMessage(dst []byte, t byte) ([]byte, int) {
dst = append(dst, t) dst = append(dst, t)


@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -230,7 +229,7 @@ func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) Scan
// target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the // target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the
// scan of the elements. // scan of the elements.
if anynil.Is(target) { if isNil, _ := isNilDriverValuer(target); isNil {
arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter) arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter)
} }


@ -139,6 +139,16 @@ Compatibility with database/sql
pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer
interfaces. interfaces.
Encoding Typed Nils
pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec
system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil).
However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore,
driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See
https://github.com/golang/go/issues/8415 and
https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870.
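
A small illustration of the rule above; the two types are invented for the example and are not part of pgtype:

package example

import "database/sql/driver"

// Valuer implemented on the value type T: a nil *valueReceiver is encoded as SQL NULL
// and Value is never called.
type valueReceiver struct{}

func (valueReceiver) Value() (driver.Value, error) { return "v", nil }

// Valuer implemented on the pointer type *T: encoding a nil *pointerReceiver calls
// Value on the nil pointer, matching database/sql behavior.
type pointerReceiver struct{}

func (*pointerReceiver) Value() (driver.Value, error) { return "p", nil }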
Child Records Child Records
pgtype's support for arrays and composite records can be used to load records and their children in a single query. See pgtype's support for arrays and composite records can be used to load records and their children in a single query. See


@ -132,22 +132,31 @@ func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte,
if interval.Days != 0 { if interval.Days != 0 {
buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...) buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...)
buf = append(buf, " day "...) buf = append(buf, " day"...)
} }
absMicroseconds := interval.Microseconds if interval.Microseconds != 0 {
if absMicroseconds < 0 { buf = append(buf, " "...)
absMicroseconds = -absMicroseconds
buf = append(buf, '-') absMicroseconds := interval.Microseconds
if absMicroseconds < 0 {
absMicroseconds = -absMicroseconds
buf = append(buf, '-')
}
hours := absMicroseconds / microsecondsPerHour
minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute
seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond
timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds)
buf = append(buf, timeStr...)
microseconds := absMicroseconds % microsecondsPerSecond
if microseconds != 0 {
buf = append(buf, fmt.Sprintf(".%06d", microseconds)...)
}
} }
hours := absMicroseconds / microsecondsPerHour
minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute
seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond
microseconds := absMicroseconds % microsecondsPerSecond
timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds)
buf = append(buf, timeStr...)
return buf, nil return buf, nil
} }


@ -8,17 +8,20 @@ import (
"reflect" "reflect"
) )
type JSONCodec struct{} type JSONCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}
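
For context, the new Marshal and Unmarshal fields are how a caller supplies an alternative JSON implementation; a hedged sketch of re-registering the json codec on a connection's type map, with the standard library functions standing in for any custom ones:

package main

import (
	"context"
	"encoding/json"
	"log"
	"os"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	// Swap in custom (un)marshal functions; anything with these signatures works.
	conn.TypeMap().RegisterType(&pgtype.Type{
		Name: "json",
		OID:  pgtype.JSONOID,
		Codec: &pgtype.JSONCodec{
			Marshal:   json.Marshal,
			Unmarshal: json.Unmarshal,
		},
	})
}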
func (JSONCodec) FormatSupported(format int16) bool { func (*JSONCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode return format == TextFormatCode || format == BinaryFormatCode
} }
func (JSONCodec) PreferredFormat() int16 { func (*JSONCodec) PreferredFormat() int16 {
return TextFormatCode return TextFormatCode
} }
func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch value.(type) { switch value.(type) {
case string: case string:
return encodePlanJSONCodecEitherFormatString{} return encodePlanJSONCodecEitherFormatString{}
@ -44,7 +47,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
// //
// https://github.com/jackc/pgx/issues/1681 // https://github.com/jackc/pgx/issues/1681
case json.Marshaler: case json.Marshaler:
return encodePlanJSONCodecEitherFormatMarshal{} return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
} }
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
@ -61,7 +66,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
} }
} }
return encodePlanJSONCodecEitherFormatMarshal{} return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
} }
type encodePlanJSONCodecEitherFormatString struct{} type encodePlanJSONCodecEitherFormatString struct{}
@ -96,10 +103,12 @@ func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byt
return buf, nil return buf, nil
} }
type encodePlanJSONCodecEitherFormatMarshal struct{} type encodePlanJSONCodecEitherFormatMarshal struct {
marshal func(v any) ([]byte, error)
}
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes, err := json.Marshal(value) jsonBytes, err := e.marshal(value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -108,7 +117,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new
return buf, nil return buf, nil
} }
func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch target.(type) { switch target.(type) {
case *string: case *string:
return scanPlanAnyToString{} return scanPlanAnyToString{}
@ -141,7 +150,9 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
return &scanPlanSQLScanner{formatCode: format} return &scanPlanSQLScanner{formatCode: format}
} }
return scanPlanJSONToJSONUnmarshal{} return &scanPlanJSONToJSONUnmarshal{
unmarshal: c.Unmarshal,
}
} }
type scanPlanAnyToString struct{} type scanPlanAnyToString struct{}
@ -173,9 +184,11 @@ func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error {
return scanner.ScanBytes(src) return scanner.ScanBytes(src)
} }
type scanPlanJSONToJSONUnmarshal struct{} type scanPlanJSONToJSONUnmarshal struct {
unmarshal func(data []byte, v any) error
}
func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
if src == nil { if src == nil {
dstValue := reflect.ValueOf(dst) dstValue := reflect.ValueOf(dst)
if dstValue.Kind() == reflect.Ptr { if dstValue.Kind() == reflect.Ptr {
@ -193,10 +206,10 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
elem := reflect.ValueOf(dst).Elem() elem := reflect.ValueOf(dst).Elem()
elem.Set(reflect.Zero(elem.Type())) elem.Set(reflect.Zero(elem.Type()))
return json.Unmarshal(src, dst) return s.unmarshal(src, dst)
} }
func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
@ -206,12 +219,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
return dstBuf, nil return dstBuf, nil
} }
func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
var dst any var dst any
err := json.Unmarshal(src, &dst) err := c.Unmarshal(src, &dst)
return dst, err return dst, err
} }



@ -2,29 +2,31 @@ package pgtype
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json"
"fmt" "fmt"
) )
type JSONBCodec struct{} type JSONBCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}
func (JSONBCodec) FormatSupported(format int16) bool { func (*JSONBCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode return format == TextFormatCode || format == BinaryFormatCode
} }
func (JSONBCodec) PreferredFormat() int16 { func (*JSONBCodec) PreferredFormat() int16 {
return TextFormatCode return TextFormatCode
} }
func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value) plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value)
if plan != nil { if plan != nil {
return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan} return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan}
} }
case TextFormatCode: case TextFormatCode:
return JSONCodec{}.PlanEncode(m, oid, format, value) return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value)
} }
return nil return nil
@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne
return plan.textPlan.Encode(value, buf) return plan.textPlan.Encode(value, buf)
} }
func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target) plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target)
if plan != nil { if plan != nil {
return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan}
} }
case TextFormatCode: case TextFormatCode:
return JSONCodec{}.PlanScan(m, oid, format, target) return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target)
} }
return nil return nil
@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error {
return plan.textPlan.Scan(src[1:], dst) return plan.textPlan.Scan(src[1:], dst)
} }
func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
} }
} }
func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a
} }
var dst any var dst any
err := json.Unmarshal(src, &dst) err := c.Unmarshal(src, &dst)
return dst, err return dst, err
} }


@ -41,6 +41,7 @@ const (
CircleOID = 718 CircleOID = 718
CircleArrayOID = 719 CircleArrayOID = 719
UnknownOID = 705 UnknownOID = 705
Macaddr8OID = 774
MacaddrOID = 829 MacaddrOID = 829
InetOID = 869 InetOID = 869
BoolArrayOID = 1000 BoolArrayOID = 1000
@ -1330,7 +1331,7 @@ func (plan *derefPointerEncodePlan) Encode(value any, buf []byte) (newBuf []byte
} }
// TryWrapDerefPointerEncodePlan tries to dereference a pointer. e.g. If value was of type *string then a wrapper plan // TryWrapDerefPointerEncodePlan tries to dereference a pointer. e.g. If value was of type *string then a wrapper plan
// would be returned that derefences the value. // would be returned that dereferences the value.
func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) {
if _, ok := value.(driver.Valuer); ok { if _, ok := value.(driver.Valuer); ok {
return nil, nil, false return nil, nil, false
@ -1911,8 +1912,17 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error)
// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data
// written. // written.
func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) { func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) {
if value == nil { if isNil, callNilDriverValuer := isNilDriverValuer(value); isNil {
return nil, nil if callNilDriverValuer {
newBuf, err = (&encodePlanDriverValuer{m: m, oid: oid, formatCode: formatCode}).Encode(value, buf)
if err != nil {
return nil, newEncodeError(value, m, oid, formatCode, err)
}
return newBuf, nil
} else {
return nil, nil
}
} }
plan := m.PlanEncode(oid, formatCode, value) plan := m.PlanEncode(oid, formatCode, value)
@ -1967,3 +1977,55 @@ func (w *sqlScannerWrapper) Scan(src any) error {
return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v)
} }
// canBeNil returns true if value can be nil.
func canBeNil(value any) bool {
refVal := reflect.ValueOf(value)
kind := refVal.Kind()
switch kind {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
return true
default:
return false
}
}
// valuerReflectType is a reflect.Type for driver.Valuer. It has confusing syntax because reflect.TypeOf returns nil
// when its argument is a nil interface value. So we use a pointer to the interface and call Elem to get the actual
// type. Yuck.
//
// This can be simplified in Go 1.22 with reflect.TypeFor.
//
// var valuerReflectType = reflect.TypeFor[driver.Valuer]()
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement
// driver.Valuer if it is only implemented by T.
func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) {
if value == nil {
return true, false
}
refVal := reflect.ValueOf(value)
kind := refVal.Kind()
switch kind {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
if !refVal.IsNil() {
return false, false
}
if _, ok := value.(driver.Valuer); ok {
if kind == reflect.Ptr {
// The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on *T
// by checking if it is not implemented on T.
return true, !refVal.Type().Elem().Implements(valuerReflectType)
} else {
return true, true
}
}
return true, false
default:
return false, false
}
}


@ -65,11 +65,12 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
defaultMap.RegisterType(&Type{Name: "macaddr8", OID: Macaddr8OID, Codec: MacaddrCodec{}})
defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}})
defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}})
@ -81,8 +82,8 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}})
defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}})
defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}})
defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}})


@ -45,7 +45,12 @@ func (t *Time) Scan(src any) error {
switch src := src.(type) { switch src := src.(type) {
case string: case string:
return scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t) err := scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t)
if err != nil {
t.Microseconds = 0
t.Valid = false
}
return err
} }
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
@ -136,6 +141,8 @@ func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
switch target.(type) { switch target.(type) {
case TimeScanner: case TimeScanner:
return scanPlanBinaryTimeToTimeScanner{} return scanPlanBinaryTimeToTimeScanner{}
case TextScanner:
return scanPlanBinaryTimeToTextScanner{}
} }
case TextFormatCode: case TextFormatCode:
switch target.(type) { switch target.(type) {
@ -165,6 +172,34 @@ func (scanPlanBinaryTimeToTimeScanner) Scan(src []byte, dst any) error {
return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) return scanner.ScanTime(Time{Microseconds: usec, Valid: true})
} }
type scanPlanBinaryTimeToTextScanner struct{}
func (scanPlanBinaryTimeToTextScanner) Scan(src []byte, dst any) error {
ts, ok := (dst).(TextScanner)
if !ok {
return ErrScanTargetTypeChanged
}
if src == nil {
return ts.ScanText(Text{})
}
if len(src) != 8 {
return fmt.Errorf("invalid length for time: %v", len(src))
}
usec := int64(binary.BigEndian.Uint64(src))
tim := Time{Microseconds: usec, Valid: true}
buf, err := TimeCodec{}.PlanEncode(nil, 0, TextFormatCode, tim).Encode(tim, nil)
if err != nil {
return err
}
return ts.ScanText(Text{String: string(buf), Valid: true})
}
type scanPlanTextAnyToTimeScanner struct{} type scanPlanTextAnyToTimeScanner struct{}
func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error {
@ -176,7 +211,7 @@ func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error {
s := string(src) s := string(src)
if len(s) < 8 { if len(s) < 8 || s[2] != ':' || s[5] != ':' {
return fmt.Errorf("cannot decode %v into Time", s) return fmt.Errorf("cannot decode %v into Time", s)
} }
@ -199,6 +234,10 @@ func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error {
usec += seconds * microsecondsPerSecond usec += seconds * microsecondsPerSecond
if len(s) > 9 { if len(s) > 9 {
if s[8] != '.' || len(s) > 15 {
return fmt.Errorf("cannot decode %v into Time", s)
}
fraction := s[9:] fraction := s[9:]
n, err := strconv.ParseInt(fraction, 10, 64) n, err := strconv.ParseInt(fraction, 10, 64)
if err != nil { if err != nil {


@ -46,7 +46,7 @@ func (ts *Timestamp) Scan(src any) error {
switch src := src.(type) { switch src := src.(type) {
case string: case string:
return scanPlanTextTimestampToTimestampScanner{}.Scan([]byte(src), ts) return (&scanPlanTextTimestampToTimestampScanner{}).Scan([]byte(src), ts)
case time.Time: case time.Time:
*ts = Timestamp{Time: src, Valid: true} *ts = Timestamp{Time: src, Valid: true}
return nil return nil
@ -116,17 +116,21 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error {
return nil return nil
} }
type TimestampCodec struct{} type TimestampCodec struct {
// ScanLocation is the location that the time is assumed to be in for scanning. This is different from
// TimestamptzCodec.ScanLocation in that this setting does change the instant in time that the timestamp represents.
ScanLocation *time.Location
}
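
A hedged sketch of registering these codecs with a ScanLocation so scanned values come back in a chosen zone rather than UTC; the zone name is only an example, and an established *pgx.Conn plus the time, pgx, and pgtype imports are assumed:

func registerScanLocations(conn *pgx.Conn) error {
	loc, err := time.LoadLocation("Europe/Berlin")
	if err != nil {
		return err
	}
	// timestamp: the scanned wall-clock value is reinterpreted in loc, which shifts the
	// instant, as the comment above notes.
	conn.TypeMap().RegisterType(&pgtype.Type{
		Name:  "timestamp",
		OID:   pgtype.TimestampOID,
		Codec: &pgtype.TimestampCodec{ScanLocation: loc},
	})
	// timestamptz: the instant is unchanged, only the zone it is reported in.
	conn.TypeMap().RegisterType(&pgtype.Type{
		Name:  "timestamptz",
		OID:   pgtype.TimestamptzOID,
		Codec: &pgtype.TimestamptzCodec{ScanLocation: loc},
	})
	return nil
}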
func (TimestampCodec) FormatSupported(format int16) bool { func (*TimestampCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode return format == TextFormatCode || format == BinaryFormatCode
} }
func (TimestampCodec) PreferredFormat() int16 { func (*TimestampCodec) PreferredFormat() int16 {
return BinaryFormatCode return BinaryFormatCode
} }
func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { func (*TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
if _, ok := value.(TimestampValuer); !ok { if _, ok := value.(TimestampValuer); !ok {
return nil return nil
} }
@ -220,27 +224,27 @@ func discardTimeZone(t time.Time) time.Time {
return t return t
} }
func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (c *TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {
case TimestampScanner: case TimestampScanner:
return scanPlanBinaryTimestampToTimestampScanner{} return &scanPlanBinaryTimestampToTimestampScanner{location: c.ScanLocation}
} }
case TextFormatCode: case TextFormatCode:
switch target.(type) { switch target.(type) {
case TimestampScanner: case TimestampScanner:
return scanPlanTextTimestampToTimestampScanner{} return &scanPlanTextTimestampToTimestampScanner{location: c.ScanLocation}
} }
} }
return nil return nil
} }
type scanPlanBinaryTimestampToTimestampScanner struct{} type scanPlanBinaryTimestampToTimestampScanner struct{ location *time.Location }
func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { func (plan *scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error {
scanner := (dst).(TimestampScanner) scanner := (dst).(TimestampScanner)
if src == nil { if src == nil {
@ -264,15 +268,18 @@ func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error
microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000,
(microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000),
).UTC() ).UTC()
if plan.location != nil {
tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location)
}
ts = Timestamp{Time: tim, Valid: true} ts = Timestamp{Time: tim, Valid: true}
} }
return scanner.ScanTimestamp(ts) return scanner.ScanTimestamp(ts)
} }
type scanPlanTextTimestampToTimestampScanner struct{} type scanPlanTextTimestampToTimestampScanner struct{ location *time.Location }
func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { func (plan *scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error {
scanner := (dst).(TimestampScanner) scanner := (dst).(TimestampScanner)
if src == nil { if src == nil {
@ -302,13 +309,17 @@ func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error {
tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location())
} }
if plan.location != nil {
tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location)
}
ts = Timestamp{Time: tim, Valid: true} ts = Timestamp{Time: tim, Valid: true}
} }
return scanner.ScanTimestamp(ts) return scanner.ScanTimestamp(ts)
} }
func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { func (c *TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
@ -326,7 +337,7 @@ func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16,
return ts.Time, nil return ts.Time, nil
} }
func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { func (c *TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }


@ -54,7 +54,7 @@ func (tstz *Timestamptz) Scan(src any) error {
switch src := src.(type) { switch src := src.(type) {
case string: case string:
return scanPlanTextTimestamptzToTimestamptzScanner{}.Scan([]byte(src), tstz) return (&scanPlanTextTimestamptzToTimestamptzScanner{}).Scan([]byte(src), tstz)
case time.Time: case time.Time:
*tstz = Timestamptz{Time: src, Valid: true} *tstz = Timestamptz{Time: src, Valid: true}
return nil return nil
@ -124,17 +124,21 @@ func (tstz *Timestamptz) UnmarshalJSON(b []byte) error {
return nil return nil
} }
type TimestamptzCodec struct{} type TimestamptzCodec struct {
// ScanLocation is the location to return scanned timestamptz values in. This does not change the instant in time that
// the timestamptz represents.
ScanLocation *time.Location
}
func (TimestamptzCodec) FormatSupported(format int16) bool { func (*TimestamptzCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode return format == TextFormatCode || format == BinaryFormatCode
} }
func (TimestamptzCodec) PreferredFormat() int16 { func (*TimestamptzCodec) PreferredFormat() int16 {
return BinaryFormatCode return BinaryFormatCode
} }
func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { func (*TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
if _, ok := value.(TimestamptzValuer); !ok { if _, ok := value.(TimestamptzValuer); !ok {
return nil return nil
} }
@ -220,27 +224,27 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by
return buf, nil return buf, nil
} }
func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {
case TimestamptzScanner: case TimestamptzScanner:
return scanPlanBinaryTimestamptzToTimestamptzScanner{} return &scanPlanBinaryTimestamptzToTimestamptzScanner{location: c.ScanLocation}
} }
case TextFormatCode: case TextFormatCode:
switch target.(type) { switch target.(type) {
case TimestamptzScanner: case TimestamptzScanner:
return scanPlanTextTimestamptzToTimestamptzScanner{} return &scanPlanTextTimestamptzToTimestamptzScanner{location: c.ScanLocation}
} }
} }
return nil return nil
} }
type scanPlanBinaryTimestamptzToTimestamptzScanner struct{} type scanPlanBinaryTimestamptzToTimestamptzScanner struct{ location *time.Location }
func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { func (plan *scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
scanner := (dst).(TimestamptzScanner) scanner := (dst).(TimestamptzScanner)
if src == nil { if src == nil {
@ -264,15 +268,18 @@ func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) e
microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000,
(microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000),
) )
if plan.location != nil {
tim = tim.In(plan.location)
}
tstz = Timestamptz{Time: tim, Valid: true} tstz = Timestamptz{Time: tim, Valid: true}
} }
return scanner.ScanTimestamptz(tstz) return scanner.ScanTimestamptz(tstz)
} }
type scanPlanTextTimestamptzToTimestamptzScanner struct{} type scanPlanTextTimestamptzToTimestamptzScanner struct{ location *time.Location }
func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { func (plan *scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
scanner := (dst).(TimestamptzScanner) scanner := (dst).(TimestamptzScanner)
if src == nil { if src == nil {
@ -312,13 +319,17 @@ func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) err
tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location())
} }
if plan.location != nil {
tim = tim.In(plan.location)
}
tstz = Timestamptz{Time: tim, Valid: true} tstz = Timestamptz{Time: tim, Valid: true}
} }
return scanner.ScanTimestamptz(tstz) return scanner.ScanTimestamptz(tstz)
} }
func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { func (c *TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
@ -336,7 +347,7 @@ func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int1
return tstz.Time, nil return tstz.Time, nil
} }
func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { func (c *TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }


@ -26,6 +26,10 @@ func (c *Conn) Release() {
res := c.res res := c.res
c.res = nil c.res = nil
if c.p.releaseTracer != nil {
c.p.releaseTracer.TraceRelease(c.p, TraceReleaseData{Conn: conn})
}
if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' {
res.Destroy() res.Destroy()
// Signal to the health check to run since we just destroyed a connection // Signal to the health check to run since we just destroyed a connection


@ -8,7 +8,7 @@ The primary way of creating a pool is with [pgxpool.New]:
pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL"))
The database connection string can be in URL or DSN format. PostgreSQL settings, pgx settings, and pool settings can be The database connection string can be in URL or keyword/value format. PostgreSQL settings, pgx settings, and pool settings can be
specified here. In addition, a config struct can be created by [ParseConfig]. specified here. In addition, a config struct can be created by [ParseConfig].
config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL"))


@ -95,6 +95,9 @@ type Pool struct {
healthCheckChan chan struct{} healthCheckChan chan struct{}
acquireTracer AcquireTracer
releaseTracer ReleaseTracer
closeOnce sync.Once closeOnce sync.Once
closeChan chan struct{} closeChan chan struct{}
} }
@ -195,6 +198,14 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
closeChan: make(chan struct{}), closeChan: make(chan struct{}),
} }
if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok {
p.acquireTracer = t
}
if t, ok := config.ConnConfig.Tracer.(ReleaseTracer); ok {
p.releaseTracer = t
}
var err error var err error
p.p, err = puddle.NewPool( p.p, err = puddle.NewPool(
&puddle.Config[*connResource]{ &puddle.Config[*connResource]{
@ -279,7 +290,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
// //
// See Config for definitions of these arguments. // See Config for definitions of these arguments.
// //
// # Example DSN // # Example Keyword/Value
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10 // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10
// //
// # Example URL // # Example URL
@ -498,7 +509,18 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in
} }
// Acquire returns a connection (*Conn) from the Pool // Acquire returns a connection (*Conn) from the Pool
func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
if p.acquireTracer != nil {
ctx = p.acquireTracer.TraceAcquireStart(ctx, p, TraceAcquireStartData{})
defer func() {
var conn *pgx.Conn
if c != nil {
conn = c.Conn()
}
p.acquireTracer.TraceAcquireEnd(ctx, p, TraceAcquireEndData{Conn: conn, Err: err})
}()
}
for { for {
res, err := p.p.Acquire(ctx) res, err := p.p.Acquire(ctx)
if err != nil { if err != nil {

33 vendor/github.com/jackc/pgx/v5/pgxpool/tracer.go generated vendored Normal file

@ -0,0 +1,33 @@
package pgxpool
import (
"context"
"github.com/jackc/pgx/v5"
)
// AcquireTracer traces Acquire.
type AcquireTracer interface {
// TraceAcquireStart is called at the beginning of Acquire.
// The returned context is used for the rest of the call and will be passed to the TraceAcquireEnd.
TraceAcquireStart(ctx context.Context, pool *Pool, data TraceAcquireStartData) context.Context
// TraceAcquireEnd is called when a connection has been acquired.
TraceAcquireEnd(ctx context.Context, pool *Pool, data TraceAcquireEndData)
}
type TraceAcquireStartData struct{}
type TraceAcquireEndData struct {
Conn *pgx.Conn
Err error
}
// ReleaseTracer traces Release.
type ReleaseTracer interface {
// TraceRelease is called at the beginning of Release.
TraceRelease(pool *Pool, data TraceReleaseData)
}
type TraceReleaseData struct {
Conn *pgx.Conn
}
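
A minimal example tracer, with purely illustrative log messages: because pgxpool discovers the acquire and release tracers through ConnConfig.Tracer (a pgx.QueryTracer), the example also implements that interface so NewWithConfig, as wired above, picks it up.

package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

type poolTracer struct{}

func (poolTracer) TraceQueryStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceQueryStartData) context.Context {
	return ctx
}

func (poolTracer) TraceQueryEnd(context.Context, *pgx.Conn, pgx.TraceQueryEndData) {}

func (poolTracer) TraceAcquireStart(ctx context.Context, _ *pgxpool.Pool, _ pgxpool.TraceAcquireStartData) context.Context {
	return ctx
}

func (poolTracer) TraceAcquireEnd(_ context.Context, _ *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
	log.Printf("acquired connection, err=%v", data.Err)
}

func (poolTracer) TraceRelease(_ *pgxpool.Pool, _ pgxpool.TraceReleaseData) {
	log.Println("released connection")
}

func newPool(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
	cfg, err := pgxpool.ParseConfig(dsn)
	if err != nil {
		return nil, err
	}
	cfg.ConnConfig.Tracer = poolTracer{}
	return pgxpool.NewWithConfig(ctx, cfg)
}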


@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"time" "time"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
@ -418,6 +419,8 @@ type CollectableRow interface {
type RowToFunc[T any] func(row CollectableRow) (T, error) type RowToFunc[T any] func(row CollectableRow) (T, error)
// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T. // AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
//
// This function closes the rows automatically on return.
func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) { func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
defer rows.Close() defer rows.Close()
@ -437,12 +440,16 @@ func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
} }
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. // CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
//
// This function closes the rows automatically on return.
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
return AppendRows([]T{}, rows, fn) return AppendRows([]T{}, rows, fn)
} }
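
For orientation, a hedged usage sketch of these helpers; the table, columns, and struct are invented for the example, and the context, pgx, and pgxpool imports are assumed:

type item struct {
	ID   int64  `db:"id"`
	Name string `db:"name"`
}

func loadItems(ctx context.Context, db *pgxpool.Pool) ([]item, error) {
	rows, err := db.Query(ctx, "select id, name from items")
	if err != nil {
		return nil, err
	}
	// CollectRows closes rows on return, so no explicit rows.Close is needed here.
	return pgx.CollectRows(rows, pgx.RowToStructByName[item])
}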
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// CollectOneRow is to CollectRows as QueryRow is to Query. // CollectOneRow is to CollectRows as QueryRow is to Query.
//
// This function closes the rows automatically on return.
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
defer rows.Close() defer rows.Close()
@ -468,6 +475,8 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
// CollectExactlyOneRow calls fn for the first row in rows and returns the result. // CollectExactlyOneRow calls fn for the first row in rows and returns the result.
// - If no rows are found returns an error where errors.Is(ErrNoRows) is true. // - If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true. // - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true.
//
// This function closes the rows automatically on return.
func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
defer rows.Close() defer rows.Close()
@ -541,7 +550,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error {
// ignored. // ignored.
func RowToStructByPos[T any](row CollectableRow) (T, error) { func RowToStructByPos[T any](row CollectableRow) (T, error) {
var value T var value T
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return value, err return value, err
} }
@ -550,7 +559,7 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) {
// the field will be ignored. // the field will be ignored.
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
var value T var value T
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return &value, err return &value, err
} }
@ -558,46 +567,60 @@ type positionalStructRowScanner struct {
ptrToStruct any ptrToStruct any
} }
func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error {
dst := rs.ptrToStruct typ := reflect.TypeOf(rs.ptrToStruct).Elem()
dstValue := reflect.ValueOf(dst) fields := lookupStructFields(typ)
if dstValue.Kind() != reflect.Ptr { if len(rows.RawValues()) > len(fields) {
return fmt.Errorf("dst not a pointer") return fmt.Errorf(
"got %d values, but dst struct has only %d fields",
len(rows.RawValues()),
len(fields),
)
} }
scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
dstElemValue := dstValue.Elem()
scanTargets := rs.appendScanTargets(dstElemValue, nil)
if len(rows.RawValues()) > len(scanTargets) {
return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets))
}
return rows.Scan(scanTargets...) return rows.Scan(scanTargets...)
} }
func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any { // Map from reflect.Type -> []structRowField
dstElemType := dstElemValue.Type() var positionalStructFieldMap sync.Map
if scanTargets == nil { func lookupStructFields(t reflect.Type) []structRowField {
scanTargets = make([]any, 0, dstElemType.NumField()) if cached, ok := positionalStructFieldMap.Load(t); ok {
return cached.([]structRowField)
} }
for i := 0; i < dstElemType.NumField(); i++ { fieldStack := make([]int, 0, 1)
sf := dstElemType.Field(i) fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack)
fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields)
return fieldsIface.([]structRowField)
}
func computeStructFields(
t reflect.Type,
fields []structRowField,
fieldStack *[]int,
) []structRowField {
tail := len(*fieldStack)
*fieldStack = append(*fieldStack, 0)
for i := 0; i < t.NumField(); i++ {
sf := t.Field(i)
(*fieldStack)[tail] = i
// Handle anonymous struct embedding, but do not try to handle embedded pointers. // Handle anonymous struct embedding, but do not try to handle embedded pointers.
if sf.Anonymous && sf.Type.Kind() == reflect.Struct { if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) fields = computeStructFields(sf.Type, fields, fieldStack)
} else if sf.PkgPath == "" { } else if sf.PkgPath == "" {
dbTag, _ := sf.Tag.Lookup(structTagKey) dbTag, _ := sf.Tag.Lookup(structTagKey)
if dbTag == "-" { if dbTag == "-" {
// Field is ignored, skip it. // Field is ignored, skip it.
continue continue
} }
scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) fields = append(fields, structRowField{
path: append([]int(nil), *fieldStack...),
})
} }
} }
*fieldStack = (*fieldStack)[:tail]
return scanTargets return fields
} }
// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public // RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public
@ -605,7 +628,7 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByName[T any](row CollectableRow) (T, error) { func RowToStructByName[T any](row CollectableRow) (T, error) {
var value T var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return value, err return value, err
} }
@ -615,7 +638,7 @@ func RowToStructByName[T any](row CollectableRow) (T, error) {
// then the field will be ignored. // then the field will be ignored.
func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
var value T var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return &value, err return &value, err
} }
@ -624,7 +647,7 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByNameLax[T any](row CollectableRow) (T, error) { func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
var value T var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
return value, err return value, err
} }
@ -634,7 +657,7 @@ func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
// then the field will be ignored. // then the field will be ignored.
func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
var value T var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
return &value, err return &value, err
} }
@ -643,26 +666,152 @@ type namedStructRowScanner struct {
lax bool lax bool
} }
func (rs *namedStructRowScanner) ScanRow(rows Rows) error { func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error {
dst := rs.ptrToStruct typ := reflect.TypeOf(rs.ptrToStruct).Elem()
dstValue := reflect.ValueOf(dst) fldDescs := rows.FieldDescriptions()
if dstValue.Kind() != reflect.Ptr { namedStructFields, err := lookupNamedStructFields(typ, fldDescs)
return fmt.Errorf("dst not a pointer")
}
dstElemValue := dstValue.Elem()
scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
if err != nil { if err != nil {
return err return err
} }
if !rs.lax && namedStructFields.missingField != "" {
return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField)
}
fields := namedStructFields.fields
scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
return rows.Scan(scanTargets...)
}
for i, t := range scanTargets { // Map from namedStructFieldsKey -> *namedStructFields
if t == nil { var namedStructFieldMap sync.Map
return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name)
type namedStructFieldsKey struct {
t reflect.Type
colNames string
}
type namedStructFields struct {
fields []structRowField
// missingField is the first field from the struct without a corresponding row field.
// This is used to construct the correct error message for non-lax queries.
missingField string
}
func lookupNamedStructFields(
t reflect.Type,
fldDescs []pgconn.FieldDescription,
) (*namedStructFields, error) {
key := namedStructFieldsKey{
t: t,
colNames: joinFieldNames(fldDescs),
}
if cached, ok := namedStructFieldMap.Load(key); ok {
return cached.(*namedStructFields), nil
}
// We could probably do two-levels of caching, where we compute the key -> fields mapping
// for a type only once, cache it by type, then use that to compute the column -> fields
// mapping for a given set of columns.
fieldStack := make([]int, 0, 1)
fields, missingField := computeNamedStructFields(
fldDescs,
t,
make([]structRowField, len(fldDescs)),
&fieldStack,
)
for i, f := range fields {
if f.path == nil {
return nil, fmt.Errorf(
"struct doesn't have corresponding row field %s",
fldDescs[i].Name,
)
} }
} }
return rows.Scan(scanTargets...) fieldsIface, _ := namedStructFieldMap.LoadOrStore(
key,
&namedStructFields{fields: fields, missingField: missingField},
)
return fieldsIface.(*namedStructFields), nil
}
func joinFieldNames(fldDescs []pgconn.FieldDescription) string {
switch len(fldDescs) {
case 0:
return ""
case 1:
return fldDescs[0].Name
}
totalSize := len(fldDescs) - 1 // Space for separator bytes.
for _, d := range fldDescs {
totalSize += len(d.Name)
}
var b strings.Builder
b.Grow(totalSize)
b.WriteString(fldDescs[0].Name)
for _, d := range fldDescs[1:] {
b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character.
b.WriteString(d.Name)
}
return b.String()
}
func computeNamedStructFields(
fldDescs []pgconn.FieldDescription,
t reflect.Type,
fields []structRowField,
fieldStack *[]int,
) ([]structRowField, string) {
var missingField string
tail := len(*fieldStack)
*fieldStack = append(*fieldStack, 0)
for i := 0; i < t.NumField(); i++ {
sf := t.Field(i)
(*fieldStack)[tail] = i
if sf.PkgPath != "" && !sf.Anonymous {
// Field is unexported, skip it.
continue
}
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
var missingSubField string
fields, missingSubField = computeNamedStructFields(
fldDescs,
sf.Type,
fields,
fieldStack,
)
if missingField == "" {
missingField = missingSubField
}
} else {
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
if dbTagPresent {
dbTag, _, _ = strings.Cut(dbTag, ",")
}
if dbTag == "-" {
// Field is ignored, skip it.
continue
}
colName := dbTag
if !dbTagPresent {
colName = sf.Name
}
fpos := fieldPosByName(fldDescs, colName)
if fpos == -1 {
if missingField == "" {
missingField = colName
}
continue
}
fields[fpos] = structRowField{
path: append([]int(nil), *fieldStack...),
}
}
}
*fieldStack = (*fieldStack)[:tail]
return fields, missingField
} }
const structTagKey = "db" const structTagKey = "db"
@ -682,52 +831,21 @@ func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
 	return
 }
 
-func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) {
-	var err error
-	dstElemType := dstElemValue.Type()
-
-	if scanTargets == nil {
-		scanTargets = make([]any, len(fldDescs))
-	}
-
-	for i := 0; i < dstElemType.NumField(); i++ {
-		sf := dstElemType.Field(i)
-		if sf.PkgPath != "" && !sf.Anonymous {
-			// Field is unexported, skip it.
-			continue
-		}
-		// Handle anonymous struct embedding, but do not try to handle embedded pointers.
-		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
-			scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
-			if err != nil {
-				return nil, err
-			}
-		} else {
-			dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
-			if dbTagPresent {
-				dbTag, _, _ = strings.Cut(dbTag, ",")
-			}
-			if dbTag == "-" {
-				// Field is ignored, skip it.
-				continue
-			}
-			colName := dbTag
-			if !dbTagPresent {
-				colName = sf.Name
-			}
-
-			fpos := fieldPosByName(fldDescs, colName)
-			if fpos == -1 {
-				if rs.lax {
-					continue
-				}
-				return nil, fmt.Errorf("cannot find field %s in returned row", colName)
-			}
-			if fpos >= len(scanTargets) && !rs.lax {
-				return nil, fmt.Errorf("cannot find field %s in returned row", colName)
-			}
-			scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
-		}
-	}
-
-	return scanTargets, err
+// structRowField describes a field of a struct.
+//
+// TODO: It would be a bit more efficient to track the path using the pointer
+// offset within the (outermost) struct and use unsafe.Pointer arithmetic to
+// construct references when scanning rows. However, it's not clear it's worth
+// using unsafe for this.
+type structRowField struct {
+	path []int
+}
+
+func setupStructScanTargets(receiver any, fields []structRowField) []any {
+	scanTargets := make([]any, len(fields))
+	v := reflect.ValueOf(receiver).Elem()
+	for i, f := range fields {
+		scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface()
+	}
+	return scanTargets
 }
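For context (not part of the upstream diff): with 5.6.0 the column-to-field mapping above is computed once per (struct type, column list) pair and cached in a sync.Map, so repeated queries of the same shape skip the reflection walk. A minimal usage sketch of the scanner this code backs — the table, columns, and Account struct are invented for illustration:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/jackc/pgx/v5"
)

// Account is a made-up struct; columns are matched against the `db` tag when
// present, otherwise against the field name (case-insensitively).
type Account struct {
	ID       int64  `db:"id"`
	Username string `db:"username"`
	Internal string `db:"-"` // ignored by the scanner
}

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL"))
	if err != nil {
		panic(err)
	}
	defer conn.Close(ctx)

	rows, err := conn.Query(ctx, `SELECT id, username FROM accounts`)
	if err != nil {
		panic(err)
	}
	// RowToStructByName requires every selected column to map to a struct field;
	// RowToStructByNameLax additionally tolerates struct fields with no column.
	accounts, err := pgx.CollectRows(rows, pgx.RowToStructByName[Account])
	if err != nil {
		panic(err)
	}
	fmt.Println(len(accounts))
}
```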


@ -7,7 +7,7 @@
 //		return err
 //	}
 //
-// Or from a DSN string.
+// Or from a keyword/value string.
 //
 //	db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
 //	if err != nil {


@ -3,7 +3,6 @@ package pgx
 import (
 	"errors"
 
-	"github.com/jackc/pgx/v5/internal/anynil"
 	"github.com/jackc/pgx/v5/internal/pgio"
 	"github.com/jackc/pgx/v5/pgtype"
 )
@ -15,10 +14,6 @@
 )
 
 func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) {
-	if anynil.Is(arg) {
-		return nil, nil
-	}
-
 	buf, err := m.Encode(0, TextFormatCode, arg, []byte{})
 	if err != nil {
 		return nil, err
@ -30,10 +25,6 @@ func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) {
 }
 
 func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) {
-	if anynil.Is(arg) {
-		return pgio.AppendInt32(buf, -1), nil
-	}
-
 	sp := len(buf)
 	buf = pgio.AppendInt32(buf, -1)
 	argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf)
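The removed anynil.Is short-circuits line up with the changelog entry "Fix encode driver.Valuer not called when nil": a typed nil argument now flows through the normal encoding path, so a driver.Valuer implemented on a pointer type decides how nil is encoded instead of always being sent as NULL. A small illustrative sketch — NullableID is an invented type, not part of pgx:

```go
package main

import (
	"database/sql/driver"
	"fmt"
)

// NullableID is a hypothetical wrapper whose Value method tolerates a nil
// receiver, so the type itself chooses how a nil *NullableID is encoded.
type NullableID struct {
	ID int64
}

func (n *NullableID) Value() (driver.Value, error) {
	if n == nil {
		return nil, nil // still SQL NULL, but decided by the type
	}
	return n.ID, nil
}

func main() {
	var id *NullableID // typed nil; previously encoded as NULL without calling Value
	v, err := id.Value()
	fmt.Println(v, err) // <nil> <nil>
}
```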

7
vendor/modules.txt vendored

@ -421,17 +421,16 @@ github.com/jackc/pgpassfile
 # github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a
 ## explicit; go 1.14
 github.com/jackc/pgservicefile
-# github.com/jackc/pgx/v5 v5.5.5
-## explicit; go 1.19
+# github.com/jackc/pgx/v5 v5.6.0
+## explicit; go 1.20
 github.com/jackc/pgx/v5
-github.com/jackc/pgx/v5/internal/anynil
 github.com/jackc/pgx/v5/internal/iobufpool
 github.com/jackc/pgx/v5/internal/pgio
 github.com/jackc/pgx/v5/internal/sanitize
 github.com/jackc/pgx/v5/internal/stmtcache
 github.com/jackc/pgx/v5/pgconn
+github.com/jackc/pgx/v5/pgconn/ctxwatch
 github.com/jackc/pgx/v5/pgconn/internal/bgreader
-github.com/jackc/pgx/v5/pgconn/internal/ctxwatch
 github.com/jackc/pgx/v5/pgproto3
 github.com/jackc/pgx/v5/pgtype
 github.com/jackc/pgx/v5/pgxpool