mirror of
1
Fork 0
gotosocial/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go

2347 lines
73 KiB
Go
Raw Normal View History

package pgconn
import (
"context"
"crypto/md5"
"crypto/tls"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"math"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/jackc/pgx/v5/pgproto3"
)
// Connection status values stored in PgConn.status.
//
// The declaration order matters: IsClosed reports every status below connStatusIdle (uninitialized, connecting and
// closed) as closed, so new values must be placed with that comparison in mind.
const (
	connStatusUninitialized = iota
	connStatusConnecting
	connStatusClosed
	connStatusIdle
	connStatusBusy
)
// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
// LISTEN/NOTIFY notification.
type Notice PgError

// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system.
type Notification struct {
	PID     uint32 // backend pid that sent the notification
	Channel string // channel from which notification was received
	Payload string // payload sent along with the notification
}

// DialFunc is a function that can be used to connect to a PostgreSQL server.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// LookupFunc is a function that can be used to lookup IPs addrs from host. Optionally an ip:port combination can be
// returned in order to override the connection string's port.
//
// Each returned entry is either a bare address, which is combined with the configured port, or an ip:port pair whose
// port replaces the configured one. Returning no addresses and a nil error yields no connection candidates for host.
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)

// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. During connect
// it is called with the connection's background reader as r and the net.Conn as w.
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend

// PgErrorHandler is a function that handles errors returned from Postgres. This function must return true to keep
// the connection open. Returning false will cause the connection to be closed immediately. You should return
// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is
// aware of the origin of the error, but it must not invoke any query method.
type PgErrorHandler func(*PgConn, *PgError) bool

// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
// notification.
type NoticeHandler func(*PgConn, *Notice)

// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications
// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is
// aware of the origin of the notification, but it must not invoke any query method. Be aware that this is distinct
// from a notice event.
type NotificationHandler func(*PgConn, *Notification)
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct {
	conn              net.Conn          // dialed connection, replaced by the TLS connection when TLS is negotiated
	pid               uint32            // backend pid
	secretKey         uint32            // key to use to send a cancel query message to the server
	parameterStatuses map[string]string // parameters that have been reported by the server
	txStatus          byte              // transaction status from the most recent ReadyForQuery message
	frontend          *pgproto3.Frontend
	bgReader          *bgreader.BGReader
	// slowWriteTimer's callback starts bgReader and then signals bgReaderStarted. NOTE(review): presumably armed by
	// flushWithPotentialWriteReadDeadlock, which is not visible here — confirm.
	slowWriteTimer  *time.Timer
	bgReaderStarted chan struct{}
	customData      map[string]any
	config          *Config
	status          byte // One of connStatus* constants

	// Background receive state used by signalMessage and collected by peekMessage. bufferingReceiveMux is held
	// while a background receive is in flight.
	bufferingReceive    bool
	bufferingReceiveMux sync.Mutex
	bufferingReceiveMsg pgproto3.BackendMessage
	bufferingReceiveErr error

	// peekedMsg is a message read by peekMessage that receiveMessage has not consumed yet.
	peekedMsg pgproto3.BackendMessage

	// Reusable / preallocated resources
	resultReader      ResultReader
	multiResultReader MultiResultReader // reinitialized and returned by every Exec call
	pipeline          Pipeline
	contextWatcher    *ctxwatch.ContextWatcher
	fieldDescriptions [16]FieldDescription // NOTE(review): presumably scratch storage for convertRowDescription — confirm

	// cleanupDone is closed once the underlying resources have been released (by Close, asyncClose, or when an
	// OnPgError handler requests the connection be closed).
	cleanupDone chan struct{}
}
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value
// format) to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a
// connect attempt.
//
// It is shorthand for parsing connString with [ParseConfig] and handing the result to [ConnectConfig]. A parse error
// is returned unchanged.
func Connect(ctx context.Context, connString string) (*PgConn, error) {
	cfg, parseErr := ParseConfig(connString)
	if parseErr != nil {
		return nil, parseErr
	}

	return ConnectConfig(ctx, cfg)
}
// ConnectWithOptions establishes a connection to a PostgreSQL server using the environment and connString (in URL or
// keyword/value format) and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig]
// for details. ctx can be used to cancel a connect attempt.
//
// It is shorthand for [ParseConfigWithOptions] followed by [ConnectConfig]. A parse error is returned unchanged.
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
	cfg, parseErr := ParseConfigWithOptions(connString, parseConfigOptions)
	if parseErr != nil {
		return nil, parseErr
	}

	return ConnectConfig(ctx, cfg)
}
// ConnectConfig establishes a connection to a PostgreSQL server using config. config must have been constructed with
// [ParseConfig]. ctx can be used to cancel a connect attempt.
//
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
// authentication error will terminate the chain of attempts (like libpq:
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error.
//
// If config.AfterConnect is set it is called with the new connection. If it returns an error the connection is closed
// and the error is returned wrapped in a *ConnectError. All other failures are also reported as a *ConnectError.
//
// ConnectConfig panics if config was not created by ParseConfig.
func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) {
	// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
	// zero values.
	if !config.createdByParseConfig {
		panic("config must be created by ParseConfig")
	}

	var allErrors []error

	connectConfigs, errs := buildConnectOneConfigs(ctx, config)
	allErrors = append(allErrors, errs...)
	if len(connectConfigs) == 0 {
		resolveErr := errors.Join(allErrors...)
		if resolveErr == nil {
			// A LookupFunc may return no addresses together with a nil error. Report that explicitly instead of
			// wrapping a nil error, which fmt would render as "%!w(<nil>)".
			resolveErr = errors.New("no addresses found")
		}
		return nil, &ConnectError{Config: config, err: fmt.Errorf("hostname resolving error: %w", resolveErr)}
	}

	pgConn, errs := connectPreferred(ctx, config, connectConfigs)
	if len(errs) > 0 {
		allErrors = append(allErrors, errs...)
		return nil, &ConnectError{Config: config, err: errors.Join(allErrors...)}
	}

	if config.AfterConnect != nil {
		err := config.AfterConnect(ctx, pgConn)
		if err != nil {
			pgConn.conn.Close()
			return nil, &ConnectError{Config: config, err: fmt.Errorf("AfterConnect error: %w", err)}
		}
	}

	return pgConn, nil
}
// buildConnectOneConfigs resolves hostnames and builds a list of connectOneConfigs to try connecting to. It returns a
// slice of successfully resolved connectOneConfigs and a slice of errors. It is possible for both slices to contain
// values if some hosts were successfully resolved and others were not.
//
// The primary host comes first, followed by config.Fallbacks in order. Absolute paths (unix sockets) are used without
// resolution. Every other host is resolved with config.LookupFunc. A lookup result may be an ip:port pair, in which case
// its port overrides the configured one. A lookup failure, or a lookup result with an unparsable port, is recorded as
// an error and only that candidate is skipped; the remaining hosts and addresses are still processed.
func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneConfig, []error) {
	// Simplify usage by treating primary config and fallbacks the same.
	fallbackConfigs := []*FallbackConfig{
		{
			Host:      config.Host,
			Port:      config.Port,
			TLSConfig: config.TLSConfig,
		},
	}
	fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)

	var configs []*connectOneConfig
	var allErrors []error

	for _, fb := range fallbackConfigs {
		// skip resolve for unix sockets
		if isAbsolutePath(fb.Host) {
			network, address := NetworkAddress(fb.Host, fb.Port)
			configs = append(configs, &connectOneConfig{
				network:          network,
				address:          address,
				originalHostname: fb.Host,
				tlsConfig:        fb.TLSConfig,
			})
			continue
		}

		ips, err := config.LookupFunc(ctx, fb.Host)
		if err != nil {
			allErrors = append(allErrors, err)
			continue
		}

		for _, ip := range ips {
			host := ip
			var port uint16 = fb.Port

			// An ip:port lookup result overrides the configured port.
			if splitIP, splitPort, err := net.SplitHostPort(ip); err == nil {
				parsedPort, err := strconv.ParseUint(splitPort, 10, 16)
				if err != nil {
					// Skip only this address; returning here would discard every host resolved so far.
					allErrors = append(allErrors, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err))
					continue
				}
				host, port = splitIP, uint16(parsedPort)
			}

			network, address := NetworkAddress(host, port)
			configs = append(configs, &connectOneConfig{
				network:          network,
				address:          address,
				originalHostname: fb.Host,
				tlsConfig:        fb.TLSConfig,
			})
		}
	}

	return configs, allErrors
}
// connectPreferred attempts to connect to the preferred host from connectOneConfigs. The connections are attempted in
// order. If a connection is successful it is returned. If no connection is successful then all errors are returned. If
// a connection attempt returns a [NotPreferredError], then that host will be used if no other hosts are successful.
//
// When config.ConnectTimeout is non-zero each distinct address gets its own timeout; consecutive entries with the same
// address share one. The final retry of the most recent not-preferred host also gets a fresh timeout. These server
// errors stop further attempts immediately:
//   - invalid password
//   - invalid authorization specification on an attempt with TLS
//   - nonexistent database
//   - insufficient privilege
func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*connectOneConfig) (*PgConn, []error) {
	octx := ctx
	var allErrors []error

	var fallbackConnectOneConfig *connectOneConfig
	for i, c := range connectOneConfigs {
		// ConnectTimeout restricts the whole connection process.
		if config.ConnectTimeout != 0 {
			// create new context first time or when previous host was different
			if i == 0 || (connectOneConfigs[i].address != connectOneConfigs[i-1].address) {
				var cancel context.CancelFunc
				ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
				defer cancel()
			}
		} else {
			ctx = octx
		}

		pgConn, err := connectOne(ctx, config, c, false)
		if pgConn != nil {
			return pgConn, nil
		}

		allErrors = append(allErrors, err)

		var pgErr *PgError
		if errors.As(err, &pgErr) {
			const ERRCODE_INVALID_PASSWORD = "28P01"                    // wrong password
			const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
			const ERRCODE_INVALID_CATALOG_NAME = "3D000"                // db does not exist
			const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501"              // missing connect privilege
			if pgErr.Code == ERRCODE_INVALID_PASSWORD ||
				pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil ||
				pgErr.Code == ERRCODE_INVALID_CATALOG_NAME ||
				pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
				return nil, allErrors
			}
		}

		var npErr *NotPreferredError
		if errors.As(err, &npErr) {
			fallbackConnectOneConfig = c
		}
	}

	if fallbackConnectOneConfig != nil {
		// After the loop ctx is the timeout context of the last host attempted. That host may be a different
		// address, and its budget may already be spent. Give the fallback attempt its own context.
		fallbackCtx := octx
		if config.ConnectTimeout != 0 {
			var cancel context.CancelFunc
			fallbackCtx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
			defer cancel()
		}

		pgConn, err := connectOne(fallbackCtx, config, fallbackConnectOneConfig, true)
		if err == nil {
			return pgConn, nil
		}
		allErrors = append(allErrors, err)
	}

	return nil, allErrors
}
// connectOne makes one connection attempt to a single host.
//
// The attempt proceeds as follows:
//   - Dial connectConfig.address.
//   - Negotiate TLS when connectConfig.tlsConfig is non-nil.
//   - Send the startup message.
//   - Handle server messages (authentication requests, BackendKeyData, ParameterStatus, NoticeResponse) until
//     ReadyForQuery arrives.
//   - Run config.ValidateConnect on the ready connection, if it is set.
//
// When ignoreNotPreferredErr is true, a *NotPreferredError from ValidateConnect (including a wrapped one) is ignored
// and the connection is returned anyway.
//
// On failure the returned error is a *perDialConnectError identifying the address and original hostname. Once the
// dial has succeeded, the network connection is closed before an error is returned.
func connectOne(ctx context.Context, config *Config, connectConfig *connectOneConfig,
	ignoreNotPreferredErr bool,
) (*PgConn, error) {
	pgConn := new(PgConn)
	pgConn.config = config
	pgConn.cleanupDone = make(chan struct{})
	pgConn.customData = make(map[string]any)

	var err error

	newPerDialConnectError := func(msg string, err error) *perDialConnectError {
		err = normalizeTimeoutError(ctx, err)
		e := &perDialConnectError{address: connectConfig.address, originalHostname: connectConfig.originalHostname, err: fmt.Errorf("%s: %w", msg, err)}
		return e
	}

	pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address)
	if err != nil {
		return nil, newPerDialConnectError("dial error", err)
	}

	if connectConfig.tlsConfig != nil {
		// Only the raw net.Conn exists during the TLS upgrade, so watch ctx with a handler bound to it.
		pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
		pgConn.contextWatcher.Watch(ctx)
		tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig)
		pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
		if err != nil {
			pgConn.conn.Close()
			return nil, newPerDialConnectError("tls error", err)
		}

		pgConn.conn = tlsConn
	}

	pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn))
	pgConn.contextWatcher.Watch(ctx)
	defer pgConn.contextWatcher.Unwatch()

	pgConn.parameterStatuses = make(map[string]string)
	pgConn.status = connStatusConnecting
	pgConn.bgReader = bgreader.New(pgConn.conn)
	// Create the timer with an effectively infinite duration and stop it immediately, so it stays inert until it is
	// explicitly reset. When it does fire, it starts the background reader and signals bgReaderStarted.
	pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
		func() {
			pgConn.bgReader.Start()
			pgConn.bgReaderStarted <- struct{}{}
		},
	)
	pgConn.slowWriteTimer.Stop()
	pgConn.bgReaderStarted = make(chan struct{})
	pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)

	startupMsg := pgproto3.StartupMessage{
		ProtocolVersion: pgproto3.ProtocolVersionNumber,
		Parameters:      make(map[string]string),
	}

	// Copy default run-time params
	for k, v := range config.RuntimeParams {
		startupMsg.Parameters[k] = v
	}

	startupMsg.Parameters["user"] = config.User
	if config.Database != "" {
		startupMsg.Parameters["database"] = config.Database
	}

	pgConn.frontend.Send(&startupMsg)
	if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
		pgConn.conn.Close()
		return nil, newPerDialConnectError("failed to write startup message", err)
	}

	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.conn.Close()
			if err, ok := err.(*PgError); ok {
				return nil, newPerDialConnectError("server error", err)
			}
			return nil, newPerDialConnectError("failed to receive message", err)
		}

		switch msg := msg.(type) {
		case *pgproto3.BackendKeyData:
			pgConn.pid = msg.ProcessID
			pgConn.secretKey = msg.SecretKey

		case *pgproto3.AuthenticationOk:
		case *pgproto3.AuthenticationCleartextPassword:
			err = pgConn.txPasswordMessage(pgConn.config.Password)
			if err != nil {
				pgConn.conn.Close()
				return nil, newPerDialConnectError("failed to write password message", err)
			}
		case *pgproto3.AuthenticationMD5Password:
			digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
			err = pgConn.txPasswordMessage(digestedPassword)
			if err != nil {
				pgConn.conn.Close()
				return nil, newPerDialConnectError("failed to write password message", err)
			}
		case *pgproto3.AuthenticationSASL:
			err = pgConn.scramAuth(msg.AuthMechanisms)
			if err != nil {
				pgConn.conn.Close()
				return nil, newPerDialConnectError("failed SASL auth", err)
			}
		case *pgproto3.AuthenticationGSS:
			err = pgConn.gssAuth()
			if err != nil {
				pgConn.conn.Close()
				return nil, newPerDialConnectError("failed GSS auth", err)
			}
		case *pgproto3.ReadyForQuery:
			pgConn.status = connStatusIdle
			if config.ValidateConnect != nil {
				// ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid
				// the watch already in progress panic. This is that last thing done by this method so there is no need to
				// restart the watch after ValidateConnect returns.
				//
				// See https://github.com/jackc/pgconn/issues/40.
				pgConn.contextWatcher.Unwatch()

				err := config.ValidateConnect(ctx, pgConn)
				if err != nil {
					// Use errors.As, matching how connectPreferred detects NotPreferredError. A type assertion
					// would miss a wrapped error, so the fallback retry would fail on the very error it is meant
					// to ignore.
					var npErr *NotPreferredError
					if ignoreNotPreferredErr && errors.As(err, &npErr) {
						return pgConn, nil
					}
					pgConn.conn.Close()
					return nil, newPerDialConnectError("ValidateConnect failed", err)
				}
			}
			return pgConn, nil
		case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse:
			// handled by ReceiveMessage
		case *pgproto3.ErrorResponse:
			pgConn.conn.Close()
			return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg))
		default:
			pgConn.conn.Close()
			// err is always nil here, so report the unexpected message's type instead of wrapping a nil error.
			return nil, newPerDialConnectError("received unexpected message", fmt.Errorf("%T", msg))
		}
	}
}
// startTLS asks the server to upgrade conn to TLS by sending an SSLRequest packet. If the server answers 'S', conn is
// wrapped in a TLS client using tlsConfig; any other answer is an error. The TLS handshake itself is not run here;
// crypto/tls performs it on the first Read or Write of the returned connection.
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
	// SSLRequest packet: Int32 message length (8) followed by the Int32 request code 80877103.
	var sslRequest [8]byte
	binary.BigEndian.PutUint32(sslRequest[0:4], 8)
	binary.BigEndian.PutUint32(sslRequest[4:8], 80877103)
	if _, err := conn.Write(sslRequest[:]); err != nil {
		return nil, err
	}

	// The server replies with a single byte before any TLS traffic.
	var reply [1]byte
	if _, err := io.ReadFull(conn, reply[:]); err != nil {
		return nil, err
	}

	if reply[0] == 'S' {
		return tls.Client(conn, tlsConfig), nil
	}
	return nil, errors.New("server refused TLS connection")
}
// txPasswordMessage queues a PasswordMessage carrying password (cleartext or already digested by the caller) and
// flushes it to the server, returning any flush error.
func (pgConn *PgConn) txPasswordMessage(password string) error {
	passwordMsg := &pgproto3.PasswordMessage{Password: password}
	pgConn.frontend.Send(passwordMsg)
	return pgConn.flushWithPotentialWriteReadDeadlock()
}
// hexMD5 returns the MD5 digest of s as a lowercase hexadecimal string.
func hexMD5(s string) string {
	h := md5.New()
	// Writes to a hash.Hash never fail, so the result can be discarded.
	_, _ = io.WriteString(h, s)
	var digest [md5.Size]byte
	return hex.EncodeToString(h.Sum(digest[:0]))
}
// signalMessage starts a background receive of the next backend message. It returns a channel that is closed once
// that receive has finished. The received message and error are stored in bufferingReceiveMsg and
// bufferingReceiveErr, where peekMessage later collects them.
//
// bufferingReceiveMux is locked here, before the goroutine starts, and the goroutine unlocks it only after storing the
// results. A subsequent peekMessage therefore blocks until the background receive is complete.
//
// It panics if a background receive is already in progress.
func (pgConn *PgConn) signalMessage() chan struct{} {
	if pgConn.bufferingReceive {
		panic("BUG: signalMessage when already in progress")
	}

	pgConn.bufferingReceive = true
	pgConn.bufferingReceiveMux.Lock()

	ch := make(chan struct{})
	go func() {
		pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
		pgConn.bufferingReceiveMux.Unlock()
		close(ch)
	}()

	return ch
}
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
// the OnNotification callback.
//
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
// See https://www.postgresql.org/docs/current/protocol.html.
//
// A receive failure is returned wrapped in a *pgconnError marked safe to retry.
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
	if lockErr := pgConn.lock(); lockErr != nil {
		return nil, lockErr
	}
	defer pgConn.unlock()

	// context.Background() can never be cancelled, so watching it would be wasted work.
	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	backendMsg, recvErr := pgConn.receiveMessage()
	if recvErr != nil {
		return backendMsg, &pgconnError{
			msg:         "receive message failed",
			err:         normalizeTimeoutError(ctx, recvErr),
			safeToRetry: true,
		}
	}
	return backendMsg, nil
}
// peekMessage peeks at the next message without setting up context cancellation.
//
// A message that was already peeked is returned again without reading. Otherwise, if signalMessage started a
// background receive, its result is collected, waiting for it if necessary. A timeout error from that background read
// is discarded and the read is retried once in the foreground. The message is stored in peekedMsg so the next
// peekMessage or receiveMessage call sees the same message.
//
// Any error other than a network timeout is treated as fatal and starts an asynchronous close of the connection.
func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
	if pgConn.peekedMsg != nil {
		return pgConn.peekedMsg, nil
	}

	var msg pgproto3.BackendMessage
	var err error
	if pgConn.bufferingReceive {
		// Blocks until the goroutine started by signalMessage has stored its result and released the mutex.
		pgConn.bufferingReceiveMux.Lock()
		msg = pgConn.bufferingReceiveMsg
		err = pgConn.bufferingReceiveErr
		pgConn.bufferingReceiveMux.Unlock()
		pgConn.bufferingReceive = false

		// If a timeout error happened in the background try the read again.
		var netErr net.Error
		if errors.As(err, &netErr) && netErr.Timeout() {
			msg, err = pgConn.frontend.Receive()
		}
	} else {
		msg, err = pgConn.frontend.Receive()
	}

	if err != nil {
		// Close on anything other than timeout error - everything else is fatal
		var netErr net.Error
		isNetErr := errors.As(err, &netErr)
		if !(isNetErr && netErr.Timeout()) {
			pgConn.asyncClose()
		}

		return nil, err
	}

	pgConn.peekedMsg = msg
	return msg, nil
}
// receiveMessage receives a message without setting up context cancellation.
//
// It consumes the message returned by peekMessage and applies its side effects on the connection state:
//   - ReadyForQuery updates txStatus.
//   - ParameterStatus records the value in parameterStatuses.
//   - ErrorResponse is passed to config.OnPgError, if set. If the handler returns false, the connection is closed
//     immediately and the *PgError is returned as the error. Otherwise the ErrorResponse is returned as an ordinary
//     message with a nil error.
//   - NoticeResponse and NotificationResponse are forwarded to config.OnNotice and config.OnNotification, if set.
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
	msg, err := pgConn.peekMessage()
	if err != nil {
		return nil, err
	}
	// The peeked message is now consumed.
	pgConn.peekedMsg = nil

	switch msg := msg.(type) {
	case *pgproto3.ReadyForQuery:
		pgConn.txStatus = msg.TxStatus
	case *pgproto3.ParameterStatus:
		pgConn.parameterStatuses[msg.Name] = msg.Value
	case *pgproto3.ErrorResponse:
		err := ErrorResponseToPgError(msg)
		if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) {
			pgConn.status = connStatusClosed
			pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
			close(pgConn.cleanupDone)
			return nil, err
		}
	case *pgproto3.NoticeResponse:
		if pgConn.config.OnNotice != nil {
			pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg))
		}
	case *pgproto3.NotificationResponse:
		if pgConn.config.OnNotification != nil {
			pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload})
		}
	}

	return msg, nil
}
// Conn returns the underlying net.Conn. This is rarely necessary. If the connection will be directly used for reading
// or writing then SyncConn should usually be called before Conn.
//
// If TLS was negotiated during connect, this is the TLS connection that wraps the dialed socket.
func (pgConn *PgConn) Conn() net.Conn {
	return pgConn.conn
}

// PID returns the backend PID, as reported by the server in its BackendKeyData message during connect.
func (pgConn *PgConn) PID() uint32 {
	return pgConn.pid
}

// TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message.
//
// Possible return values:
//
//	'I' - idle / not in transaction
//	'T' - in a transaction
//	'E' - in a failed transaction
//
// See https://www.postgresql.org/docs/current/protocol-message-formats.html.
func (pgConn *PgConn) TxStatus() byte {
	return pgConn.txStatus
}

// SecretKey returns the backend secret key used to send a cancel query message to the server. It is reported by the
// server in its BackendKeyData message during connect.
func (pgConn *PgConn) SecretKey() uint32 {
	return pgConn.secretKey
}

// Frontend returns the underlying *pgproto3.Frontend. This is rarely necessary.
func (pgConn *PgConn) Frontend() *pgproto3.Frontend {
	return pgConn.frontend
}
// Close closes a connection. It is safe to call Close on an already closed connection. Close attempts a clean close by
// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
// underlying net.Conn.Close() will always be called regardless of any other errors.
//
// The returned error comes from closing the net.Conn. Errors from sending the Terminate message are ignored. When ctx
// is context.Background(), no context watch is set up.
func (pgConn *PgConn) Close(ctx context.Context) error {
	if pgConn.status == connStatusClosed {
		return nil
	}
	pgConn.status = connStatusClosed

	// Deferred calls run in reverse order: Unwatch (if set up below), then net.Conn.Close, then close(cleanupDone).
	defer close(pgConn.cleanupDone)
	// Guarantees the net.Conn is closed on every path. After the explicit Close at the end, this call's result is
	// discarded.
	defer pgConn.conn.Close()

	if ctx != context.Background() {
		// Close may be called while a cancellable query is in progress. This will most often be triggered by panic when
		// a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any
		// previous watch. It is safe to Unwatch regardless of whether a watch is already is progress.
		//
		// See https://github.com/jackc/pgconn/issues/29
		pgConn.contextWatcher.Unwatch()

		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Ignore any errors sending Terminate message and waiting for server to close connection.
	// This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully
	// ignores errors.
	//
	// See https://github.com/jackc/pgx/issues/637
	pgConn.frontend.Send(&pgproto3.Terminate{})
	pgConn.flushWithPotentialWriteReadDeadlock()

	return pgConn.conn.Close()
}
// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying
// connection.
//
// It returns immediately and does nothing if the connection is already closed. The background goroutine bounds the
// cancel request and the Terminate write by the same deadline, 15 seconds from now, and ignores all of their errors.
// It finally closes the net.Conn and then cleanupDone.
func (pgConn *PgConn) asyncClose() {
	if pgConn.status == connStatusClosed {
		return
	}
	pgConn.status = connStatusClosed

	go func() {
		defer close(pgConn.cleanupDone)
		defer pgConn.conn.Close()

		deadline := time.Now().Add(time.Second * 15)

		ctx, cancel := context.WithDeadline(context.Background(), deadline)
		defer cancel()

		pgConn.CancelRequest(ctx)

		// Bound the Terminate write by the same deadline as the cancel request.
		pgConn.conn.SetDeadline(deadline)

		pgConn.frontend.Send(&pgproto3.Terminate{})
		pgConn.flushWithPotentialWriteReadDeadlock()
	}()
}
// CleanupDone returns a channel that will be closed after all underlying resources have been cleaned up. A closed
// connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing
// yet. This is because certain errors such as a context cancellation require that the interrupted function call return
// immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are
// closed asynchronously.
//
// This is only likely to be useful to connection pools. It gives them a way to avoid establishing a new connection
// while an old connection is still being cleaned up and thereby exceeding the maximum pool size.
func (pgConn *PgConn) CleanupDone() chan (struct{}) {
	return pgConn.cleanupDone
}

// IsClosed reports if the connection has been closed.
//
// Every status ordered before connStatusIdle counts as closed, so this also reports true for a connection that is
// uninitialized or still connecting.
//
// CleanupDone() can be used to determine if all cleanup has been completed.
func (pgConn *PgConn) IsClosed() bool {
	return pgConn.status < connStatusIdle
}

// IsBusy reports if the connection is busy, i.e. an operation currently holds the connection lock.
func (pgConn *PgConn) IsBusy() bool {
	return pgConn.status == connStatusBusy
}
// lock marks the connection busy so a single operation can use it. It fails with a *connLockError when the connection
// is already busy, closed, or uninitialized. Any other status, such as idle, transitions to busy.
func (pgConn *PgConn) lock() error {
	var reason string
	switch pgConn.status {
	case connStatusBusy:
		reason = "conn busy" // This only should be possible in case of an application bug.
	case connStatusClosed:
		reason = "conn closed"
	case connStatusUninitialized:
		reason = "conn uninitialized"
	default:
		pgConn.status = connStatusBusy
		return nil
	}
	return &connLockError{status: reason}
}

// unlock releases the lock taken by lock, returning a busy connection to idle. A connection that was closed while
// locked stays closed. Unlocking in any other status is a bug in this package and panics.
func (pgConn *PgConn) unlock() {
	if pgConn.status == connStatusBusy {
		pgConn.status = connStatusIdle
		return
	}
	if pgConn.status != connStatusClosed {
		panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package.
	}
}
// ParameterStatus returns the value of a parameter reported by the server (e.g.
// server_version). Returns an empty string for unknown parameters.
//
// The value is the one from the most recent ParameterStatus message received for key.
func (pgConn *PgConn) ParameterStatus(key string) string {
	return pgConn.parameterStatuses[key]
}
// CommandTag is the status text returned by PostgreSQL for a query.
type CommandTag struct {
	s string
}

// NewCommandTag makes a CommandTag from s.
func NewCommandTag(s string) CommandTag {
	return CommandTag{s: s}
}

// RowsAffected returns the number of rows affected. If the CommandTag was not
// for a row affecting command (e.g. "CREATE TABLE") then it returns 0.
//
// The count is the run of ASCII digits at the very end of the tag, e.g. 5 for "INSERT 0 5".
func (ct CommandTag) RowsAffected() int64 {
	// Step back over the trailing digits to find where the number starts.
	start := len(ct.s)
	for start > 0 {
		if c := ct.s[start-1]; c < '0' || c > '9' {
			break
		}
		start--
	}

	// With no trailing digits start == len(ct.s), the loop does nothing, and the result is 0.
	var rows int64
	for i := start; i < len(ct.s); i++ {
		rows = rows*10 + int64(ct.s[i]-'0')
	}
	return rows
}

// String returns the raw command tag text.
func (ct CommandTag) String() string {
	return ct.s
}

// Insert is true if the command tag starts with "INSERT".
func (ct CommandTag) Insert() bool {
	const prefix = "INSERT"
	return strings.HasPrefix(ct.s, prefix)
}

// Update is true if the command tag starts with "UPDATE".
func (ct CommandTag) Update() bool {
	const prefix = "UPDATE"
	return strings.HasPrefix(ct.s, prefix)
}

// Delete is true if the command tag starts with "DELETE".
func (ct CommandTag) Delete() bool {
	const prefix = "DELETE"
	return strings.HasPrefix(ct.s, prefix)
}

// Select is true if the command tag starts with "SELECT".
func (ct CommandTag) Select() bool {
	const prefix = "SELECT"
	return strings.HasPrefix(ct.s, prefix)
}
// FieldDescription describes one column of a query result. Every field is copied from the corresponding field of a
// pgproto3.RowDescription entry by convertRowDescription. See the PostgreSQL RowDescription message documentation for
// the exact meaning of each value.
type FieldDescription struct {
	Name                 string
	TableOID             uint32 // OID of the source table as reported by the server
	TableAttributeNumber uint16 // column number within the source table as reported by the server
	DataTypeOID          uint32
	DataTypeSize         int16
	TypeModifier         int32
	Format               int16 // format code of the column as reported by the server
}
// convertRowDescription converts the fields of rd into FieldDescriptions. It reuses dst's backing array when dst has
// enough capacity and allocates a new slice otherwise. The result's length and capacity both equal len(rd.Fields), so
// appending to it never writes into the remainder of a reused dst.
func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) []FieldDescription {
	fieldCount := len(rd.Fields)
	if cap(dst) < fieldCount {
		dst = make([]FieldDescription, fieldCount)
	} else {
		dst = dst[:fieldCount:fieldCount]
	}

	for i, f := range rd.Fields {
		dst[i] = FieldDescription{
			Name:                 string(f.Name),
			TableOID:             f.TableOID,
			TableAttributeNumber: f.TableAttributeNumber,
			DataTypeOID:          f.DataTypeOID,
			DataTypeSize:         f.DataTypeSize,
			TypeModifier:         f.TypeModifier,
			Format:               f.Format,
		}
	}

	return dst
}
// StatementDescription describes a prepared statement as returned by Prepare.
type StatementDescription struct {
	Name      string             // statement name; empty for the anonymous prepared statement
	SQL       string             // SQL text passed to Prepare
	ParamOIDs []uint32           // parameter type OIDs from the server's ParameterDescription message
	Fields    []FieldDescription // result columns from the server's RowDescription; nil if none was received
}
// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This
// allows Prepare to also describe statements without creating a server-side prepared statement.
//
// Prepare does not send a PREPARE statement to the server. It uses the PostgreSQL Parse and Describe protocol messages
// directly.
//
// The returned StatementDescription holds the parameter OIDs and result fields reported by the server. If the server
// reports an error, the response is still read through ReadyForQuery and the error is then returned as a *PgError. A
// failure to send or receive closes the connection asynchronously.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
	if err := pgConn.lock(); err != nil {
		return nil, err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Parse + Describe(statement) + Sync: the server answers with ParameterDescription and RowDescription (or an
	// error) and always finishes with ReadyForQuery.
	pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
	pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
	pgConn.frontend.SendSync(&pgproto3.Sync{})
	err := pgConn.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		pgConn.asyncClose()
		return nil, err
	}

	psd := &StatementDescription{Name: name, SQL: sql}

	var parseErr error

readloop:
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return nil, normalizeTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.ParameterDescription:
			// Copy so psd does not alias the message's buffer.
			psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs))
			copy(psd.ParamOIDs, msg.ParameterOIDs)
		case *pgproto3.RowDescription:
			psd.Fields = pgConn.convertRowDescription(nil, msg)
		case *pgproto3.ErrorResponse:
			// Keep reading until ReadyForQuery so the connection stays in sync.
			parseErr = ErrorResponseToPgError(msg)
		case *pgproto3.ReadyForQuery:
			break readloop
		}
	}

	if parseErr != nil {
		return nil, parseErr
	}
	return psd, nil
}
// Deallocate deallocates a prepared statement.
//
// Deallocate does not send a DEALLOCATE statement to the server. It uses the PostgreSQL Close protocol message
// directly. This has slightly different behavior than executing DEALLOCATE statement.
//   - Deallocate can succeed in an aborted transaction.
//   - Deallocating a non-existent prepared statement is not an error.
//
// If the server reports an error, the response is still read through the terminating ReadyForQuery so the connection
// stays synchronized. The error is then returned as a *PgError. A failure to send or receive closes the connection
// asynchronously.
func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error {
	if err := pgConn.lock(); err != nil {
		return err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	pgConn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
	pgConn.frontend.SendSync(&pgproto3.Sync{})
	err := pgConn.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		pgConn.asyncClose()
		return err
	}

	// Because of the Sync above, the server ends its response with ReadyForQuery even when the Close fails (Prepare
	// relies on the same behavior). Returning as soon as an ErrorResponse arrives would leave that ReadyForQuery
	// unread, and the next command would mistake it for the end of its own response.
	var closeErr error
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			if closeErr != nil {
				// The server already reported why; prefer that over the follow-up read failure.
				return closeErr
			}
			return normalizeTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.ErrorResponse:
			closeErr = ErrorResponseToPgError(msg)
		case *pgproto3.ReadyForQuery:
			return closeErr
		}
	}
}
// ErrorResponseToPgError converts a wire protocol error message to a *PgError.
//
// Each field of the returned *PgError is copied from the field of msg with the same name. The result is never nil.
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
	return &PgError{
		Severity:            msg.Severity,
		SeverityUnlocalized: msg.SeverityUnlocalized,
		Code:                string(msg.Code),
		Message:             string(msg.Message),
		Detail:              string(msg.Detail),
		Hint:                msg.Hint,
		Position:            msg.Position,
		InternalPosition:    msg.InternalPosition,
		InternalQuery:       string(msg.InternalQuery),
		Where:               string(msg.Where),
		SchemaName:          string(msg.SchemaName),
		TableName:           string(msg.TableName),
		ColumnName:          string(msg.ColumnName),
		DataTypeName:        string(msg.DataTypeName),
		ConstraintName:      msg.ConstraintName,
		File:                string(msg.File),
		Line:                msg.Line,
		Routine:             string(msg.Routine),
	}
}
// noticeResponseToNotice converts a NoticeResponse into a *Notice. NoticeResponse and ErrorResponse share a layout, as
// do Notice and PgError, so the conversion is delegated to ErrorResponseToPgError.
func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
	return (*Notice)(ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)))
}
// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel
// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there
// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9
//
// The request is sent over a new connection opened with config.DialFunc. After writing the request, CancelRequest
// waits for the server to respond or close that connection, with any read result ignored. When ctx is not
// context.Background(), it is watched for the lifetime of that connection.
func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
	// Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
	// the connection config. This is important in high availability configurations where fallback connections may be
	// specified or DNS may be used to load balance.
	serverAddr := pgConn.conn.RemoteAddr()
	var serverNetwork, serverAddress string
	if serverAddr.Network() == "unix" {
		// for unix sockets, RemoteAddr() calls getpeername() which returns the name the
		// server passed to bind(). For Postgres, this is always a relative path "./.s.PGSQL.5432"
		// so connecting to it will fail. Fall back to the config's value
		serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port)
	} else {
		serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String()
	}

	// Unix sockets already use the config-derived address above. A failed dial is therefore not retried, because a
	// retry would dial exactly the same address again.
	cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress)
	if err != nil {
		return err
	}
	defer cancelConn.Close()

	if ctx != context.Background() {
		contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn})
		contextWatcher.Watch(ctx)
		defer contextWatcher.Unwatch()
	}

	// CancelRequest packet: Int32 length (16), Int32 cancel request code (80877102), Int32 backend PID, Int32 secret key.
	buf := make([]byte, 16)
	binary.BigEndian.PutUint32(buf[0:4], 16)
	binary.BigEndian.PutUint32(buf[4:8], 80877102)
	binary.BigEndian.PutUint32(buf[8:12], pgConn.pid)
	binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey)

	if _, err := cancelConn.Write(buf); err != nil {
		return fmt.Errorf("write to connection for cancellation: %w", err)
	}

	// Wait for the cancel request to be acknowledged by the server.
	// It copies the behavior of the libpq: https://github.com/postgres/postgres/blob/REL_16_0/src/interfaces/libpq/fe-connect.c#L4946-L4960
	_, _ = cancelConn.Read(buf)

	return nil
}
// WaitForNotification blocks until a LISTEN/NOTIFY NotificationResponse message arrives on the connection. It returns
// nil once one has been received, or an error if the connection cannot be locked, ctx is already done, or receiving a
// message fails. Other messages received while waiting are consumed and ignored here.
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
	if err := pgConn.lock(); err != nil {
		return err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			return normalizeTimeoutError(ctx, err)
		}
		if _, isNotification := msg.(*pgproto3.NotificationResponse); isNotification {
			return nil
		}
	}
}
// Exec runs sql using the PostgreSQL simple query protocol. sql may contain multiple queries. Unless a transaction is
// already in progress or sql contains transaction control statements, execution is implicitly wrapped in a
// transaction.
//
// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries.
func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
	if err := pgConn.lock(); err != nil {
		return &MultiResultReader{closed: true, err: err}
	}

	pgConn.multiResultReader = MultiResultReader{pgConn: pgConn, ctx: ctx}
	mrr := &pgConn.multiResultReader

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			mrr.closed = true
			mrr.err = newContextAlreadyDoneError(ctx)
			pgConn.unlock()
			return mrr
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
	}

	pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
	if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
		// A failed flush leaves the connection unusable: close it and release everything acquired above.
		pgConn.asyncClose()
		pgConn.contextWatcher.Unwatch()
		mrr.closed = true
		mrr.err = err
		pgConn.unlock()
	}
	return mrr
}
// ExecParams runs a single SQL command using the PostgreSQL extended query protocol.
//
// sql holds exactly one query. Parameters are referenced positionally as $1, $2, $3, etc.
//
// paramValues holds the parameter values, each encoded in the format selected by paramFormats.
//
// paramOIDs holds the data type OID of each parameter. A nil paramOIDs lets the server infer every parameter's type,
// and a 0 element lets it infer the type of that single parameter. ExecParams will panic if len(paramOIDs) is not 0,
// 1, or len(paramValues).
//
// paramFormats holds one format code per parameter choosing text or binary encoding. A nil paramFormats means all
// parameters are text. ExecParams will panic if len(paramFormats) is not 0, 1, or len(paramValues).
//
// resultFormats holds one format code per result column choosing text or binary encoding. A nil resultFormats means
// all results are text.
//
// The returned ResultReader must be closed before PgConn can be used again.
func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader {
	rr := pgConn.execExtendedPrefix(ctx, paramValues)
	if !rr.closed {
		pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
		pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
		pgConn.execExtendedSuffix(rr)
	}
	return rr
}
// ExecPrepared runs the prepared statement stmtName using the PostgreSQL extended query protocol.
//
// paramValues holds the parameter values, each encoded in the format selected by paramFormats.
//
// paramFormats holds one format code per parameter choosing text or binary encoding. A nil paramFormats means all
// parameters are text. ExecPrepared will panic if len(paramFormats) is not 0, 1, or len(paramValues).
//
// resultFormats holds one format code per result column choosing text or binary encoding. A nil resultFormats means
// all results are text.
//
// The returned ResultReader must be closed before PgConn can be used again.
func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader {
	rr := pgConn.execExtendedPrefix(ctx, paramValues)
	if !rr.closed {
		bind := &pgproto3.Bind{
			PreparedStatement:    stmtName,
			ParameterFormatCodes: paramFormats,
			Parameters:           paramValues,
			ResultFormatCodes:    resultFormats,
		}
		pgConn.frontend.SendBind(bind)
		pgConn.execExtendedSuffix(rr)
	}
	return rr
}
// execExtendedPrefix begins an extended protocol command. It resets pgConn.resultReader for ctx, locks the
// connection, rejects more than math.MaxUint16 parameter values, and starts watching ctx. If any step fails, the
// returned ResultReader is already closed with the error recorded, the connection is not left locked, and the caller
// must not send any messages.
func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader {
	pgConn.resultReader = ResultReader{pgConn: pgConn, ctx: ctx}
	rr := &pgConn.resultReader

	// fail records err on rr and marks it closed.
	fail := func(err error) *ResultReader {
		rr.concludeCommand(CommandTag{}, err)
		rr.closed = true
		return rr
	}

	if err := pgConn.lock(); err != nil {
		return fail(err)
	}

	if len(paramValues) > math.MaxUint16 {
		fail(fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
		pgConn.unlock()
		return rr
	}

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			fail(newContextAlreadyDoneError(ctx))
			pgConn.unlock()
			return rr
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
	}

	return rr
}
// execExtendedSuffix completes an extended protocol command started by execExtendedPrefix. It queues Describe
// (portal), Execute and Sync, flushes them, and on success reads until the row description is loaded. If the flush
// fails, the connection is closed and the error is recorded on result, which is then closed.
func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
	fe := pgConn.frontend
	fe.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
	fe.SendExecute(&pgproto3.Execute{})
	fe.SendSync(&pgproto3.Sync{})

	err := pgConn.flushWithPotentialWriteReadDeadlock()
	if err == nil {
		result.readUntilRowDescription()
		return
	}

	pgConn.asyncClose()
	result.concludeCommand(CommandTag{}, err)
	pgConn.contextWatcher.Unwatch()
	result.closed = true
	pgConn.unlock()
}
// CopyTo runs the copy command in sql and writes the payload of every CopyData message the server returns to w. It
// returns the command tag from CommandComplete along with the error from the last ErrorResponse, if any. A receive
// failure or a failed write to w closes the connection.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
	if err := pgConn.lock(); err != nil {
		return CommandTag{}, err
	}

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			pgConn.unlock()
			return CommandTag{}, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Send the copy command.
	pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
	if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
		pgConn.asyncClose()
		pgConn.unlock()
		return CommandTag{}, err
	}

	// Stream results until the server reports it is ready for the next query.
	var (
		tag       CommandTag
		serverErr error
	)
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return CommandTag{}, normalizeTimeoutError(ctx, err)
		}

		switch m := msg.(type) {
		case *pgproto3.CopyData:
			if _, err := w.Write(m.Data); err != nil {
				pgConn.asyncClose()
				return CommandTag{}, err
			}
		case *pgproto3.CommandComplete:
			tag = pgConn.makeCommandTag(m.CommandTag)
		case *pgproto3.ErrorResponse:
			serverErr = ErrorResponseToPgError(m)
		case *pgproto3.ReadyForQuery:
			pgConn.unlock()
			return tag, serverErr
		}
	}
}
// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server.
//
// The contents of r are streamed to the server as CopyData messages from a separate goroutine while the calling
// goroutine watches for server messages. Once r returns io.EOF, or the server reports an error, CopyDone is sent; if r
// returns any other error, CopyFail is sent carrying that error's text. The command tag from CommandComplete is
// returned together with the error from the most recent ErrorResponse, if any.
//
// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r
// could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
	if err := pgConn.lock(); err != nil {
		return CommandTag{}, err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		// Fail fast if ctx is already done; otherwise watch it for the duration of the copy.
		select {
		case <-ctx.Done():
			return CommandTag{}, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Send copy from query
	pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
	err := pgConn.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		pgConn.asyncClose()
		return CommandTag{}, err
	}

	// Send copy data
	abortCopyChan := make(chan struct{}) // closed to tell the copy goroutine to stop after its current iteration
	copyErrChan := make(chan error, 1)   // the goroutine's terminal error (io.EOF at end of r); it sends at most once, so the buffer keeps the send from blocking
	signalMessageChan := pgConn.signalMessage()
	var wg sync.WaitGroup
	wg.Add(1)

	go func() {
		defer wg.Done()
		buf := iobufpool.Get(65536)
		defer iobufpool.Put(buf)
		// Each chunk is framed in place as a CopyData message: byte 0 is the 'd' message type, bytes 1-4 hold the
		// length (payload + 4 for the length field itself), and the payload read from r starts at offset 5.
		(*buf)[0] = 'd'

		for {
			n, readErr := r.Read((*buf)[5:cap(*buf)])
			if n > 0 {
				*buf = (*buf)[0 : n+5]
				pgio.SetInt32((*buf)[1:], int32(n+4))

				writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
				if writeErr != nil {
					// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not
					// setting pgConn.status or closing pgConn.cleanupDone for the same reason.
					pgConn.conn.Close()

					copyErrChan <- writeErr
					return
				}
			}
			// Data returned alongside an error has already been sent above; now report the error (including io.EOF).
			if readErr != nil {
				copyErrChan <- readErr
				return
			}

			select {
			case <-abortCopyChan:
				return
			default:
			}
		}
	}()

	var pgErr error
	var copyErr error
	// Wait until the copy goroutine finishes (copyErr) or the server sends an ErrorResponse (pgErr). Any other server
	// message received meanwhile is consumed and waiting resumes.
	for copyErr == nil && pgErr == nil {
		select {
		case copyErr = <-copyErrChan:
		case <-signalMessageChan:
			// If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with
			// the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an
			// error is found then forcibly close the connection without sending the Terminate message.
			if err := pgConn.bufferingReceiveErr; err != nil {
				// NOTE(review): the copy goroutine is neither signalled via abortCopyChan nor waited for on this path. It
				// exits on its own once a read from r returns an error or its next write fails on the closed connection.
				pgConn.status = connStatusClosed
				pgConn.conn.Close()
				close(pgConn.cleanupDone)
				return CommandTag{}, normalizeTimeoutError(ctx, err)
			}
			msg, _ := pgConn.receiveMessage()

			switch msg := msg.(type) {
			case *pgproto3.ErrorResponse:
				pgErr = ErrorResponseToPgError(msg)
			default:
				// Re-arm the signal to wait for the next server message.
				signalMessageChan = pgConn.signalMessage()
			}
		}
	}
	close(abortCopyChan)
	// Make sure io goroutine finishes before writing.
	wg.Wait()

	// End the copy: CopyDone on a clean end of input or when the server already reported an error, otherwise CopyFail
	// with the reader's error. copyErr is non-nil in the else branch because the loop above only exits once copyErr or
	// pgErr is set. NOTE(review): io.EOF is compared with ==, so a wrapped io.EOF from r is treated as a failure.
	if copyErr == io.EOF || pgErr != nil {
		pgConn.frontend.Send(&pgproto3.CopyDone{})
	} else {
		pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
	}
	err = pgConn.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		pgConn.asyncClose()
		return CommandTag{}, err
	}

	// Read results
	var commandTag CommandTag
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return CommandTag{}, normalizeTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.ReadyForQuery:
			return commandTag, pgErr
		case *pgproto3.CommandComplete:
			commandTag = pgConn.makeCommandTag(msg.CommandTag)
		case *pgproto3.ErrorResponse:
			pgErr = ErrorResponseToPgError(msg)
		}
	}
}
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
	pgConn   *PgConn
	ctx      context.Context // used to normalize timeout errors and handed to each ResultReader created by NextResult
	pipeline *Pipeline       // when non-nil, ReadyForQuery decrements the pipeline's expected count instead of releasing the connection
	rr       *ResultReader   // the result most recently selected by NextResult
	closed   bool            // set on ReadyForQuery or when receiving a message fails
	err      error           // set from an ErrorResponse or a receive failure
}
// ReadAll drains every remaining result into memory, then closes the reader and returns the collected results along
// with the error from Close. Calling ReadAll is mutually exclusive with all other MultiResultReader methods.
func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
	var all []*Result
	for mrr.NextResult() {
		current := mrr.ResultReader()
		all = append(all, current.Read())
	}
	return all, mrr.Close()
}
// receiveMessage reads the next backend message and updates the reader's state from it. On a receive failure it stops
// watching the context, records the normalized error, closes the reader and the connection, and returns that error.
// ReadyForQuery closes the reader and either releases the connection or, inside a pipeline, decrements the number of
// outstanding sync points. An ErrorResponse is recorded in mrr.err but the message is still returned.
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
	conn := mrr.pgConn

	msg, err := conn.receiveMessage()
	if err != nil {
		conn.contextWatcher.Unwatch()
		mrr.err = normalizeTimeoutError(mrr.ctx, err)
		mrr.closed = true
		conn.asyncClose()
		return nil, mrr.err
	}

	switch m := msg.(type) {
	case *pgproto3.ReadyForQuery:
		mrr.closed = true
		if mrr.pipeline == nil {
			conn.contextWatcher.Unwatch()
			conn.unlock()
		} else {
			mrr.pipeline.expectedReadyForQueryCount--
		}
	case *pgproto3.ErrorResponse:
		mrr.err = ErrorResponseToPgError(m)
	}

	return msg, nil
}
// NextResult advances the MultiResultReader to the next result and reports whether one is available. It returns
// false once the reader is closed, an error has been recorded, or an EmptyQueryResponse is received.
func (mrr *MultiResultReader) NextResult() bool {
	for !mrr.closed && mrr.err == nil {
		msg, err := mrr.receiveMessage()
		if err != nil {
			return false
		}

		conn := mrr.pgConn
		var next ResultReader
		switch m := msg.(type) {
		case *pgproto3.RowDescription:
			// A result set follows; its rows are read through this MultiResultReader.
			next = ResultReader{
				pgConn:            conn,
				multiResultReader: mrr,
				ctx:               mrr.ctx,
				fieldDescriptions: conn.convertRowDescription(conn.fieldDescriptions[:], m),
			}
		case *pgproto3.CommandComplete:
			// A command without a result set: expose an already finished reader carrying only the tag.
			next = ResultReader{
				commandTag:       conn.makeCommandTag(m.CommandTag),
				commandConcluded: true,
				closed:           true,
			}
		case *pgproto3.EmptyQueryResponse:
			return false
		default:
			continue
		}

		conn.resultReader = next
		mrr.rr = &conn.resultReader
		return true
	}
	return false
}
// ResultReader returns the ResultReader for the result most recently selected by NextResult.
func (mrr *MultiResultReader) ResultReader() *ResultReader {
	current := mrr.rr
	return current
}
// Close consumes whatever remains of the response until the reader is closed and returns the error recorded on the
// MultiResultReader, if any.
func (mrr *MultiResultReader) Close() error {
	for !mrr.closed {
		if _, err := mrr.receiveMessage(); err != nil {
			break
		}
	}
	return mrr.err
}
// ResultReader is a reader for the result of a single query.
type ResultReader struct {
	pgConn            *PgConn
	multiResultReader *MultiResultReader // set when created by MultiResultReader.NextResult; messages are then received through it
	pipeline          *Pipeline          // when non-nil, Close leaves the trailing ReadyForQuery to the pipeline
	ctx               context.Context    // used to normalize timeout errors
	fieldDescriptions []FieldDescription // columns from the most recent RowDescription
	rowValues         [][]byte           // values of the current DataRow; cleared when the command concludes
	commandTag        CommandTag
	commandConcluded  bool // set by CommandComplete, EmptyQueryResponse, ErrorResponse or a receive failure
	closed            bool
	err               error // first error kept by concludeCommand; Close may replace it with a later ErrorResponse
}
// Result is the saved query response that is returned by calling Read on a ResultReader.
type Result struct {
	FieldDescriptions []FieldDescription // copied from the reader when the first row is read; nil if no rows were read
	Rows              [][][]byte         // one entry per row holding copies of the column values; nil values are kept nil
	CommandTag        CommandTag
	Err               error
}
// Read consumes the remaining rows and returns them, together with the field descriptions, command tag and error from
// Close, as a self-contained Result. Row values are copied so the Result stays valid after the reader is closed; nil
// values stay nil.
func (rr *ResultReader) Read() *Result {
	res := &Result{}
	for rr.NextRow() {
		if res.FieldDescriptions == nil {
			fds := rr.FieldDescriptions()
			res.FieldDescriptions = make([]FieldDescription, len(fds))
			copy(res.FieldDescriptions, fds)
		}

		src := rr.Values()
		dst := make([][]byte, len(src))
		for i, v := range src {
			if v == nil {
				continue
			}
			dst[i] = append(make([]byte, 0, len(v)), v...)
		}
		res.Rows = append(res.Rows, dst)
	}
	res.CommandTag, res.Err = rr.Close()
	return res
}
// NextRow advances to the next data row and reports whether one is available. It returns false once the command has
// concluded or receiving a message fails.
func (rr *ResultReader) NextRow() bool {
	for !rr.commandConcluded {
		msg, err := rr.receiveMessage()
		if err != nil {
			return false
		}
		if row, ok := msg.(*pgproto3.DataRow); ok {
			rr.rowValues = row.Values
			return true
		}
	}
	return false
}
// FieldDescriptions returns the columns of the current result set. The slice is only valid until the ResultReader is
// closed. It may be nil, for example when the query did not return a result set or an error was encountered.
func (rr *ResultReader) FieldDescriptions() []FieldDescription {
	fds := rr.fieldDescriptions
	return fds
}
// Values returns the values of the current row. NextRow must have been called previously. The returned [][]byte is
// only valid until the next NextRow call or until the ResultReader is closed.
func (rr *ResultReader) Values() [][]byte {
	row := rr.rowValues
	return row
}
// Close consumes any remaining result data and returns the command tag or error. Calling Close again returns the same
// values without reading further.
func (rr *ResultReader) Close() (CommandTag, error) {
	if rr.closed {
		return rr.commandTag, rr.err
	}
	rr.closed = true

	// Drain the rest of the current command.
	for !rr.commandConcluded {
		if _, err := rr.receiveMessage(); err != nil {
			return CommandTag{}, rr.err
		}
	}

	// Readers owned by a MultiResultReader or a Pipeline leave the trailing ReadyForQuery to their owner.
	if rr.multiResultReader != nil || rr.pipeline != nil {
		return rr.commandTag, rr.err
	}

	for {
		msg, err := rr.receiveMessage()
		if err != nil {
			return CommandTag{}, rr.err
		}

		switch m := msg.(type) {
		case *pgproto3.ErrorResponse:
			// A deferred constraint violation is reported with an ErrorResponse sent after CommandComplete.
			rr.err = ErrorResponseToPgError(m)
		case *pgproto3.ReadyForQuery:
			rr.pgConn.contextWatcher.Unwatch()
			rr.pgConn.unlock()
			return rr.commandTag, rr.err
		}
	}
}
// readUntilRowDescription reads messages until a RowDescription has been consumed, the next message is a DataRow, or
// the command concludes, so that the ResultReader's fieldDescriptions are loaded. It does not return an error as any
// error will be stored in the ResultReader.
func (rr *ResultReader) readUntilRowDescription() {
	for !rr.commandConcluded {
		// Peek first so that a DataRow is left in place when the result set has no RowDescription message. pgconn's own
		// queries never produce that, but it can happen when SendBytes and ReceiveResults are used manually to build a
		// query that does not issue a describe statement.
		next, _ := rr.pgConn.peekMessage()
		if _, isRow := next.(*pgproto3.DataRow); isRow {
			return
		}

		consumed, _ := rr.receiveMessage()
		if _, isDesc := consumed.(*pgproto3.RowDescription); isDesc {
			return
		}
	}
}
// receiveMessage reads the next backend message, either directly from the connection or through the owning
// MultiResultReader, and updates the reader's state from it. On a receive failure the command is concluded with the
// normalized error, the context is unwatched, the reader is closed, and the connection is closed when this reader is
// not owned by a MultiResultReader; the stored error is returned.
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
	if mrr := rr.multiResultReader; mrr != nil {
		msg, err = mrr.receiveMessage()
	} else {
		msg, err = rr.pgConn.receiveMessage()
	}

	if err != nil {
		rr.concludeCommand(CommandTag{}, normalizeTimeoutError(rr.ctx, err))
		rr.pgConn.contextWatcher.Unwatch()
		rr.closed = true
		if rr.multiResultReader == nil {
			rr.pgConn.asyncClose()
		}
		return nil, rr.err
	}

	switch m := msg.(type) {
	case *pgproto3.RowDescription:
		rr.fieldDescriptions = rr.pgConn.convertRowDescription(rr.pgConn.fieldDescriptions[:], m)
	case *pgproto3.CommandComplete:
		rr.concludeCommand(rr.pgConn.makeCommandTag(m.CommandTag), nil)
	case *pgproto3.EmptyQueryResponse:
		rr.concludeCommand(CommandTag{}, nil)
	case *pgproto3.ErrorResponse:
		rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(m))
	}

	return msg, nil
}
// concludeCommand marks the current command as finished with commandTag. The first non-nil err is always kept, even
// when the command had already concluded, so an error received after CommandComplete but before ReadyForQuery is not
// lost. The tag and row values are only updated the first time the command concludes.
func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
	if rr.err == nil && err != nil {
		rr.err = err
	}

	if !rr.commandConcluded {
		rr.commandTag = commandTag
		rr.rowValues = nil
		rr.commandConcluded = true
	}
}
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
type Batch struct {
	buf []byte // encoded protocol messages, appended to by ExecParams and ExecPrepared
	err error  // first encoding error; once set, further ExecParams/ExecPrepared calls do nothing
}
// ExecParams appends an ExecParams command (an unnamed Parse followed by the ExecPrepared messages) to the batch. See
// PgConn.ExecParams for parameter descriptions. It does nothing if the batch already holds an error.
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
	if batch.err == nil {
		parse := &pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}
		batch.buf, batch.err = parse.Encode(batch.buf)
		if batch.err == nil {
			batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
		}
	}
}
// ExecPrepared appends an ExecPrepared command (Bind, Describe portal, Execute) to the batch. See PgConn.ExecPrepared
// for parameter descriptions. Encoding stops at the first error, which is stored in the batch.
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
	if batch.err != nil {
		return
	}

	messages := [...]interface {
		Encode(dst []byte) ([]byte, error)
	}{
		&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats},
		&pgproto3.Describe{ObjectType: 'P'},
		&pgproto3.Execute{},
	}
	for _, msg := range messages {
		if batch.buf, batch.err = msg.Encode(batch.buf); batch.err != nil {
			return
		}
	}
}
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
// multiple queries in a single round trip than using pipeline mode.
//
// A Sync message is appended to batch before it is written, so a Batch should not be passed to ExecBatch more than once.
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
	if batch.err != nil {
		return &MultiResultReader{
			closed: true,
			err:    batch.err,
		}
	}

	if err := pgConn.lock(); err != nil {
		return &MultiResultReader{
			closed: true,
			err:    err,
		}
	}

	pgConn.multiResultReader = MultiResultReader{
		pgConn: pgConn,
		ctx:    ctx,
	}
	multiResult := &pgConn.multiResultReader

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			multiResult.closed = true
			multiResult.err = newContextAlreadyDoneError(ctx)
			pgConn.unlock()
			return multiResult
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
	}

	batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
	if batch.err != nil {
		// Nothing has been written, so the connection is still usable; stop watching ctx before releasing it.
		pgConn.contextWatcher.Unwatch()
		multiResult.closed = true
		multiResult.err = batch.err
		pgConn.unlock()
		return multiResult
	}

	// Write the batch directly to the net.Conn. The write is wrapped so the slow write timer is stopped before any
	// error handling runs, mirroring flushWithPotentialWriteReadDeadlock.
	_, err := func() (int, error) {
		pgConn.enterPotentialWriteReadDeadlock()
		defer pgConn.exitPotentialWriteReadDeadlock()
		return pgConn.conn.Write(batch.buf)
	}()
	if err != nil {
		// A failed write leaves the protocol stream in an unknown state, so the connection cannot be reused. Close it
		// and stop watching ctx, as Exec does when its flush fails.
		pgConn.asyncClose()
		pgConn.contextWatcher.Unwatch()
		multiResult.closed = true
		multiResult.err = normalizeTimeoutError(ctx, err)
		pgConn.unlock()
		return multiResult
	}

	return multiResult
}
// EscapeString escapes s by doubling every single quote so it can be safely interpolated into a SQL string literal.
// The surrounding single quotes are not added.
//
// The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these
// conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future.
func (pgConn *PgConn) EscapeString(s string) (string, error) {
	switch {
	case pgConn.ParameterStatus("standard_conforming_strings") != "on":
		return "", errors.New("EscapeString must be run with standard_conforming_strings=on")
	case pgConn.ParameterStatus("client_encoding") != "UTF8":
		return "", errors.New("EscapeString must be run with client_encoding=UTF8")
	}
	return strings.ReplaceAll(s, "'", "''"), nil
}
// CheckConn checks the underlying connection without writing any bytes, by attempting a read with a 1ms deadline. A
// TCP connection can be broken in a way that makes writes appear to succeed even though they never reach the server,
// and reading immediately before a write detects this. Doing so right before sending a query reduces the chance of
// sending a query that fails without the client knowing whether the server received it.
//
// Deprecated: CheckConn is deprecated in favor of Ping. CheckConn cannot detect all types of broken connections where
// the write would still appear to succeed. Prefer Ping unless on a high latency connection.
func (pgConn *PgConn) CheckConn() error {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
	defer cancel()

	// A timeout only means nothing was available to read within the deadline, which is not treated as a failure.
	if _, err := pgConn.ReceiveMessage(ctx); err != nil && !Timeout(err) {
		return err
	}
	return nil
}
// Ping verifies the server is reachable by executing the comment-only query "-- ping" and waiting for it to complete.
// A TCP connection can be broken such that a write appears to succeed even though it never reaches the server, so
// pinging immediately before sending a query reduces the chance of a query failing without the client knowing whether
// the server received it.
func (pgConn *PgConn) Ping(ctx context.Context) error {
	mrr := pgConn.Exec(ctx, "-- ping")
	return mrr.Close()
}
// makeCommandTag converts buf into a CommandTag. The bytes are copied into a new string, so the result does not retain
// buf or its underlying memory.
func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
	tag := CommandTag{s: string(buf)}
	return tag
}
// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously
// blocked writing to us. It arms the slow write timer and panics if the timer was already armed.
func (pgConn *PgConn) enterPotentialWriteReadDeadlock() {
	// The threshold is somewhat arbitrary. A write normally takes only as long as the syscall and the copy into the OS
	// outbound network buffer; it blocks only when that buffer is full. The value must be long enough for the normal
	// case yet short enough not to hurt performance when a block does occur. Windows' default timer resolution is
	// 15.6ms, so anything shorter would be ineffective there.
	const slowWriteThreshold = 15 * time.Millisecond

	if wasActive := pgConn.slowWriteTimer.Reset(slowWriteThreshold); wasActive {
		panic("BUG: slow write timer already active")
	}
}
// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. It disarms the slow
// write timer and, if the timer had already fired, stops the background reader it started.
func (pgConn *PgConn) exitPotentialWriteReadDeadlock() {
	if pgConn.slowWriteTimer.Stop() {
		// The timer was stopped before firing.
		return
	}

	// The timer runs its function in a separate goroutine, so wait until the background reader has actually started
	// before stopping it. Otherwise the reader might not be stopped. That alone is harmless, but the reader could then
	// start at an inopportune moment during a later query. For example, a canceled query may set a deadline on the
	// net.Conn to interrupt an in-progress read; if the background reader started after the read was interrupted but
	// before the deadline was cleared, it would read the deadline error, and the next query would receive an
	// unexpected deadline error.
	<-pgConn.bgReaderStarted
	pgConn.bgReader.Stop()
}
// flushWithPotentialWriteReadDeadlock flushes the frontend's buffered messages to the server, bracketing the write with
// enterPotentialWriteReadDeadlock and exitPotentialWriteReadDeadlock.
func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error {
	pgConn.enterPotentialWriteReadDeadlock()
	defer pgConn.exitPotentialWriteReadDeadlock()
	return pgConn.frontend.Flush()
}
// SyncConn prepares the underlying net.Conn for direct use. PgConn may buffer reads internally or perform background
// IO in goroutines, so direct use of the net.Conn could be corrupted by an already buffered read or one in progress.
// SyncConn waits until the read buffer is empty and background IO is stopped, pinging the server (at most 10 times) to
// get there. ctx can be used to cancel this operation. Call it before any operation that uses the underlying net.Conn
// directly, e.g. before Conn() or Hijack().
//
// This should not be confused with the PostgreSQL protocol Sync message.
func (pgConn *PgConn) SyncConn(ctx context.Context) error {
	const maxAttempts = 10
	for attempt := 0; attempt < maxAttempts; attempt++ {
		synced := pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0
		if synced {
			return nil
		}
		if err := pgConn.Ping(ctx); err != nil {
			return fmt.Errorf("SyncConn: Ping failed while syncing conn: %w", err)
		}
	}

	// This should never happen. The only plausible cause is a server that constantly sends data, such as LISTEN/NOTIFY
	// or log notifications, so the buffer is never observed empty.
	return errors.New("SyncConn: conn never synchronized")
}
// CustomData returns the map used to associate custom data with the connection. The map itself is returned, not a
// copy, so changes made through it are visible to later callers.
func (pgConn *PgConn) CustomData() map[string]any {
	m := pgConn.customData
	return m
}
// HijackedConn is the result of hijacking a connection.
//
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
// compatibility.
type HijackedConn struct {
	Conn              net.Conn           // the underlying network connection
	PID               uint32             // backend pid
	SecretKey         uint32             // key to use to send a cancel query message to the server
	ParameterStatuses map[string]string  // parameters that have been reported by the server
	TxStatus          byte               // transaction status recorded by the PgConn at the time of hijacking
	Frontend          *pgproto3.Frontend // protocol frontend; Construct replaces it with one built by Config.BuildFrontend
	Config            *Config            // configuration the PgConn was created with
	CustomData        map[string]any     // the map returned by PgConn.CustomData
}
// Hijack hands over the connection's internal state. pgConn must be idle, and SyncConn should be called immediately
// before Hijack. Once Hijack succeeds pgConn is marked closed and is unusable. Hijacking is typically only useful when
// pgconn is used to establish a connection and complete control of the raw connection is taken afterwards (e.g. a
// load balancer or proxy).
//
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
// compatibility.
func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
	if err := pgConn.lock(); err != nil {
		return nil, err
	}
	pgConn.status = connStatusClosed

	hc := &HijackedConn{
		Conn:              pgConn.conn,
		PID:               pgConn.pid,
		SecretKey:         pgConn.secretKey,
		ParameterStatuses: pgConn.parameterStatuses,
		TxStatus:          pgConn.txStatus,
		Frontend:          pgConn.frontend,
		Config:            pgConn.config,
		CustomData:        pgConn.customData,
	}
	return hc, nil
}
// Construct creates a PgConn from an already established connection to a PostgreSQL server. It is the inverse of
// PgConn.Hijack. The connection must be in an idle state.
//
// hc.Frontend is replaced by a new pgproto3.Frontend built by hc.Config.BuildFrontend.
//
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
// compatibility.
func Construct(hc *HijackedConn) (*PgConn, error) {
	pgConn := &PgConn{
		conn:              hc.Conn,
		pid:               hc.PID,
		secretKey:         hc.SecretKey,
		parameterStatuses: hc.ParameterStatuses,
		txStatus:          hc.TxStatus,
		frontend:          hc.Frontend,
		config:            hc.Config,
		customData:        hc.CustomData,
		status:            connStatusIdle,
		cleanupDone:       make(chan struct{}),
	}

	pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn))
	pgConn.bgReader = bgreader.New(pgConn.conn)

	// When fired, the slow write timer starts the background reader and signals bgReaderStarted. It is created with an
	// effectively infinite duration and stopped immediately; enterPotentialWriteReadDeadlock arms it when needed.
	startBackgroundReader := func() {
		pgConn.bgReader.Start()
		pgConn.bgReaderStarted <- struct{}{}
	}
	pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), startBackgroundReader)
	pgConn.slowWriteTimer.Stop()
	pgConn.bgReaderStarted = make(chan struct{})

	// Swap in a frontend that reads through the background reader.
	pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn)

	return pgConn, nil
}
// Pipeline represents a connection in pipeline mode.
//
// SendPrepare, SendQueryParams, and SendQueryPrepared queue requests to the server. These requests are not written until
// pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. Requests between
// synchronization points are implicitly transactional unless explicit transaction control statements have been issued.
//
// The context the pipeline was started with is in effect for the entire life of the Pipeline.
//
// For a deeper understanding of pipeline mode see the PostgreSQL documentation for the extended query protocol
// (https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY) and the libpq pipeline mode
// (https://www.postgresql.org/docs/current/libpq-pipeline-mode.html).
type Pipeline struct {
	conn                       *PgConn
	ctx                        context.Context // watched for the life of the pipeline; used to normalize timeout errors
	expectedReadyForQueryCount int             // number of Sync messages sent whose ReadyForQuery has not been received yet
	pendingSync                bool            // requests have been queued since the last successful Sync
	err                        error           // error that closed the pipeline, if any
	closed                     bool
}
// PipelineSync is returned by GetResults when a ReadyForQuery message is received, i.e. when the results for a
// synchronization point established by Sync have been fully read.
type PipelineSync struct{}
// CloseComplete is returned by GetResults when a CloseComplete message is received, e.g. in response to a Close queued
// by SendDeallocate.
type CloseComplete struct{}
// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent
// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection
// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except
// CancelRequest and Close. ctx is in effect for the entire life of the *Pipeline.
//
// Prefer ExecBatch when only sending one group of queries at once.
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
	if err := pgConn.lock(); err != nil {
		return &Pipeline{closed: true, err: err}
	}

	pgConn.pipeline = Pipeline{conn: pgConn, ctx: ctx}
	p := &pgConn.pipeline

	if ctx == context.Background() {
		return p
	}

	select {
	case <-ctx.Done():
		p.closed = true
		p.err = newContextAlreadyDoneError(ctx)
		pgConn.unlock()
		return p
	default:
	}
	pgConn.contextWatcher.Watch(ctx)
	return p
}
// SendPrepare is the pipeline version of *PgConn.Prepare. It queues a Parse for the named statement followed by a
// Describe of that statement. It does nothing if the pipeline is closed.
func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
	if !p.closed {
		p.pendingSync = true
		fe := p.conn.frontend
		fe.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
		fe.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
	}
}
// SendDeallocate queues a Close message that deallocates the prepared statement name. Like other pipeline requests it
// is written to the server by the next Flush or Sync. It does nothing if the pipeline is closed.
func (p *Pipeline) SendDeallocate(name string) {
	if !p.closed {
		p.pendingSync = true
		p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
	}
}
// SendQueryParams is the pipeline version of *PgConn.ExecParams. It queues Parse, Bind, Describe (portal) and Execute
// for sql. It does nothing if the pipeline is closed.
func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
	if p.closed {
		return
	}
	p.pendingSync = true

	fe := p.conn.frontend
	fe.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
	fe.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
	fe.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
	fe.SendExecute(&pgproto3.Execute{})
}
// SendQueryPrepared is the pipeline version of *PgConn.ExecPrepared. It queues Bind for stmtName followed by Describe
// (portal) and Execute. It does nothing if the pipeline is closed.
func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
	if p.closed {
		return
	}
	p.pendingSync = true

	fe := p.conn.frontend
	fe.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
	fe.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
	fe.SendExecute(&pgproto3.Execute{})
}
// Flush sends the queued requests to the server without establishing a synchronization point.
//
// On a closed pipeline it returns the error the pipeline was closed with, or a generic "pipeline closed" error.
// If the write fails, Flush does the following and then returns the error:
//   - normalizes the error for the pipeline's context;
//   - closes the connection;
//   - ends the context watch;
//   - releases the connection lock;
//   - marks the pipeline closed with that error.
//
// Later calls on the pipeline report the same error.
func (p *Pipeline) Flush() error {
	if p.closed {
		if p.err == nil {
			return errors.New("pipeline closed")
		}
		return p.err
	}

	flushErr := p.conn.flushWithPotentialWriteReadDeadlock()
	if flushErr == nil {
		return nil
	}

	// After a failed write the connection is torn down rather than reused.
	flushErr = normalizeTimeoutError(p.ctx, flushErr)
	p.conn.asyncClose()
	p.conn.contextWatcher.Unwatch()
	p.conn.unlock()
	p.closed = true
	p.err = flushErr
	return flushErr
}
// Sync establishes a synchronization point by queuing a Sync message, then flushes everything queued. On success the
// pipeline no longer has unsynced requests and expects one more ReadyForQuery from the server. A Flush error is
// returned as-is (Flush has already closed the pipeline in that case).
func (p *Pipeline) Sync() error {
	if p.closed {
		if p.err == nil {
			return errors.New("pipeline closed")
		}
		return p.err
	}

	p.conn.frontend.SendSync(&pgproto3.Sync{})
	if flushErr := p.Flush(); flushErr != nil {
		return flushErr
	}

	p.pendingSync = false
	p.expectedReadyForQueryCount++
	return nil
}
// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription,
// *CloseComplete, or *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will
// be a *PgError. If no results are available (every Sync sent so far has been answered), results and err will both be
// nil. On a closed pipeline it returns the error the pipeline was closed with, or a generic "pipeline closed" error.
func (p *Pipeline) GetResults() (results any, err error) {
	switch {
	case p.closed && p.err != nil:
		return nil, p.err
	case p.closed:
		return nil, errors.New("pipeline closed")
	case p.expectedReadyForQueryCount == 0:
		// No outstanding ReadyForQuery means nothing more can arrive.
		return nil, nil
	default:
		return p.getResults()
	}
}
// getResults reads messages from the server until the next pipeline result is available. Depending on what arrives it
// returns:
//   - a *ResultReader bound to this pipeline for a RowDescription, so the caller can read the rows that follow;
//   - an already-closed *ResultReader carrying the command tag for a CommandComplete;
//   - the *StatementDescription built by getResultsPrepare when a ParseComplete is followed by a ParameterDescription;
//   - a *CloseComplete for a CloseComplete;
//   - a *PipelineSync for a ReadyForQuery, after decrementing expectedReadyForQueryCount;
//   - a nil result and a *PgError for an ErrorResponse.
//
// Any other message is skipped, including a ParseComplete that is not followed by a ParameterDescription.
//
// A failure to read from the connection is fatal to the pipeline; see abortAfterReadError.
func (p *Pipeline) getResults() (results any, err error) {
	for {
		msg, err := p.conn.receiveMessage()
		if err != nil {
			return nil, p.abortAfterReadError(err)
		}

		switch msg := msg.(type) {
		case *pgproto3.RowDescription:
			p.conn.resultReader = ResultReader{
				pgConn:            p.conn,
				pipeline:          p,
				ctx:               p.ctx,
				fieldDescriptions: p.conn.convertRowDescription(p.conn.fieldDescriptions[:], msg),
			}
			return &p.conn.resultReader, nil
		case *pgproto3.CommandComplete:
			p.conn.resultReader = ResultReader{
				commandTag:       p.conn.makeCommandTag(msg.CommandTag),
				commandConcluded: true,
				closed:           true,
			}
			return &p.conn.resultReader, nil
		case *pgproto3.ParseComplete:
			// Only a ParseComplete followed by a ParameterDescription starts a statement description; otherwise
			// keep reading.
			peekedMsg, err := p.conn.peekMessage()
			if err != nil {
				return nil, p.abortAfterReadError(err)
			}
			if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok {
				return p.getResultsPrepare()
			}
		case *pgproto3.CloseComplete:
			return &CloseComplete{}, nil
		case *pgproto3.ReadyForQuery:
			p.expectedReadyForQueryCount--
			return &PipelineSync{}, nil
		case *pgproto3.ErrorResponse:
			pgErr := ErrorResponseToPgError(msg)
			return nil, pgErr
		}
	}
}

// abortAfterReadError tears the pipeline down after a failed read from the server. It does the same as Flush does
// after a failed write:
//   - normalizes err for the pipeline's context;
//   - closes the connection;
//   - ends the context watch;
//   - releases the connection lock;
//   - marks the pipeline closed and stores the normalized error in p.err.
//
// It returns the normalized error, so the caller and every later call on the pipeline see the same error.
//
// Releasing here matters because Close returns p.err immediately on a closed pipeline. If the pipeline were only
// marked closed, the context watch would never be ended. When Close itself is draining and hits this path, its own
// trailing Unwatch/unlock still run. This file already calls Unwatch without an active watch and unlock after
// asyncClose.
func (p *Pipeline) abortAfterReadError(err error) error {
	err = normalizeTimeoutError(p.ctx, err)
	p.conn.asyncClose()
	p.conn.contextWatcher.Unwatch()
	p.conn.unlock()
	p.closed = true
	p.err = err
	return err
}
// getResultsPrepare builds the StatementDescription for a SendPrepare after its ParseComplete has been consumed.
//
// It records the parameter OIDs from the ParameterDescription. It returns once a RowDescription (the result columns)
// or a NoData (no result columns) arrives. An ErrorResponse is returned as a *PgError.
//
// The connection is closed in two cases:
//   - a read fails; the error is returned normalized for the pipeline's context;
//   - an unexpected CommandComplete or ReadyForQuery arrives.
func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
	sd := &StatementDescription{}

	for {
		msg, readErr := p.conn.receiveMessage()
		if readErr != nil {
			p.conn.asyncClose()
			return nil, normalizeTimeoutError(p.ctx, readErr)
		}

		switch m := msg.(type) {
		case *pgproto3.ParameterDescription:
			// Copy the OIDs rather than keeping a reference to the message's slice.
			oids := make([]uint32, len(m.ParameterOIDs))
			copy(oids, m.ParameterOIDs)
			sd.ParamOIDs = oids
		case *pgproto3.RowDescription:
			sd.Fields = p.conn.convertRowDescription(nil, m)
			return sd, nil
		case *pgproto3.NoData:
			// Sent instead of RowDescription when there is no expected result, e.g. an INSERT without a RETURNING
			// clause.
			return sd, nil
		case *pgproto3.ErrorResponse:
			pgErr := ErrorResponseToPgError(m)
			return nil, pgErr
		case *pgproto3.CommandComplete:
			// Should never happen here; close the connection rather than risk a deadlock.
			p.conn.asyncClose()
			return nil, errors.New("BUG: received CommandComplete while handling Describe")
		case *pgproto3.ReadyForQuery:
			// Should never happen here; close the connection rather than risk a deadlock.
			p.conn.asyncClose()
			return nil, errors.New("BUG: received ReadyForQuery while handling Describe")
		}
	}
}
// Close closes the pipeline and returns the connection to normal mode.
//
// If requests were queued after the last Sync, the connection is closed and an "unsynced requests" error is returned.
// Otherwise Close reads every result still owed for Syncs already sent and returns the last error it saw:
//   - a *PgError does not stop the draining;
//   - any other error closes the connection and stops it.
//
// In both cases the context watch is ended and the connection lock released before returning. On an already-closed
// pipeline Close only returns the error it was closed with.
func (p *Pipeline) Close() error {
	if p.closed {
		return p.err
	}
	p.closed = true

	if p.pendingSync {
		p.conn.asyncClose()
		p.err = errors.New("pipeline has unsynced requests")
	} else {
		for p.expectedReadyForQueryCount > 0 {
			_, drainErr := p.getResults()
			if drainErr == nil {
				continue
			}
			p.err = drainErr

			// Server-reported errors leave the connection usable, so keep draining.
			var pgErr *PgError
			if errors.As(drainErr, &pgErr) {
				continue
			}
			p.conn.asyncClose()
			break
		}
	}

	p.conn.contextWatcher.Unwatch()
	p.conn.unlock()
	return p.err
}
// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn, so that blocked and
// future reads and writes on it fail once the deadline passes.
type DeadlineContextWatcherHandler struct {
	// Conn is the connection whose deadline is adjusted.
	Conn net.Conn

	// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled.
	DeadlineDelay time.Duration
}

// HandleCancel sets Conn's read and write deadline to DeadlineDelay from now. Any SetDeadline error is ignored.
func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) {
	cutoff := time.Now().Add(h.DeadlineDelay)
	_ = h.Conn.SetDeadline(cutoff)
}

// HandleUnwatchAfterCancel removes the deadline installed by HandleCancel. Any SetDeadline error is ignored.
func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() {
	var noDeadline time.Time // the zero Time means "no deadline" for net.Conn
	_ = h.Conn.SetDeadline(noDeadline)
}
// CancelRequestContextWatcherHandler handles canceled contexts by sending a cancel request to the server. It also sets
// a deadline on a net.Conn as a fallback.
type CancelRequestContextWatcherHandler struct {
	// Conn is the connection whose in-progress query is canceled.
	Conn *PgConn

	// CancelRequestDelay is the delay before sending the cancel request to the server.
	CancelRequestDelay time.Duration

	// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled.
	DeadlineDelay time.Duration

	// cancelFinishedChan is closed when the goroutine started by HandleCancel exits.
	cancelFinishedChan chan struct{}

	// handleUnwatchAfterCancelCalled cancels the context that tells that goroutine to stop early.
	handleUnwatchAfterCancelCalled func()
}

// HandleCancel sets the connection deadline to DeadlineDelay from now. It then starts a goroutine that waits
// CancelRequestDelay and sends a cancel request bounded by that same deadline. The goroutine exits early if
// HandleUnwatchAfterCancel is called before the delay elapses.
func (h *CancelRequestContextWatcherHandler) HandleCancel(context.Context) {
	h.cancelFinishedChan = make(chan struct{})

	stopCtx, stop := context.WithCancel(context.Background())
	h.handleUnwatchAfterCancelCalled = stop

	cutoff := time.Now().Add(h.DeadlineDelay)
	h.Conn.conn.SetDeadline(cutoff)

	go func() {
		defer close(h.cancelFinishedChan)

		select {
		case <-stopCtx.Done():
			// Unwatched before the delay elapsed: no cancel request is needed.
			return
		case <-time.After(h.CancelRequestDelay):
		}

		// Derived from stopCtx so that HandleUnwatchAfterCancel also aborts an in-flight cancel request.
		reqCtx, cancelReq := context.WithDeadline(stopCtx, cutoff)
		defer cancelReq()
		h.Conn.CancelRequest(reqCtx)

		// CancelRequest is inherently racy. Even though the cancel request has been received by the server at this point,
		// it hasn't necessarily been delivered to the other connection. If we immediately return and the connection is
		// immediately used then it is possible the CancelRequest will actually cancel our next query. The
		// TestCancelRequestContextWatcherHandler Stress test can produce this error without the sleep below. The sleep time
		// is arbitrary, but should be sufficient to prevent this error case.
		time.Sleep(100 * time.Millisecond)
	}()
}

// HandleUnwatchAfterCancel stops the goroutine started by HandleCancel and waits for it to exit. If a cancel request
// was already sent, this includes the goroutine's 100 ms sleep. It then clears the connection deadline.
func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() {
	h.handleUnwatchAfterCancelCalled()
	<-h.cancelFinishedChan

	var noDeadline time.Time // the zero Time means "no deadline" for net.Conn
	h.Conn.conn.SetDeadline(noDeadline)
}