mirror of
1
Fork 0
gotosocial/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go

667 lines
13 KiB
Go
Raw Normal View History

package pgdialect
import (
"database/sql"
"database/sql/driver"
"encoding/hex"
"fmt"
"math"
"reflect"
"strconv"
"time"
"unicode/utf8"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
// ArrayValue pairs a reflected Go value with the appender and scanner
// functions that Array selected for the value's type.
type ArrayValue struct {
// v is the reflected value handed to Array.
v reflect.Value
// append formats v for a query; AppendQuery panics when it is nil.
append schema.AppenderFunc
// scan populates v from a database value; Scan returns an error when it is nil.
scan schema.ScannerFunc
}
// Array accepts a slice and returns a wrapper for working with PostgreSQL
// array data type.
//
// For struct fields you can use array tag:
//
// Emails []string `bun:",array"`
func Array(vi interface{}) *ArrayValue {
v := reflect.ValueOf(vi)
if !v.IsValid() {
panic(fmt.Errorf("bun: Array(nil)"))
}
return &ArrayValue{
v: v,
2021-10-24 13:14:37 +02:00
append: pgDialect.arrayAppender(v.Type()),
scan: arrayScanner(v.Type()),
}
}
// Compile-time checks that *ArrayValue satisfies the interfaces it is used as.
var _ schema.QueryAppender = (*ArrayValue)(nil)
var _ sql.Scanner = (*ArrayValue)(nil)
// AppendQuery appends the wrapped value to b using the appender chosen by
// Array. It panics when no appender exists for the wrapped type; the returned
// error is always nil.
func (a *ArrayValue) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) {
	appendFn := a.append
	if appendFn == nil {
		panic(fmt.Errorf("bun: Array(unsupported %s)", a.v.Type()))
	}
	out := appendFn(fmter, b, a.v)
	return out, nil
}
// Scan implements sql.Scanner by decoding src into the wrapped value. It
// returns an error when the wrapped type has no scanner or when the wrapped
// value is not a pointer.
func (a *ArrayValue) Scan(src interface{}) error {
	switch {
	case a.scan == nil:
		return fmt.Errorf("bun: Array(unsupported %s)", a.v.Type())
	case a.v.Kind() != reflect.Ptr:
		return fmt.Errorf("bun: Array(non-pointer %s)", a.v.Type())
	}
	return a.scan(a.v, src)
}
// Value returns the wrapped Go value, or nil when the wrapper holds no valid
// value.
func (a *ArrayValue) Value() interface{} {
	if !a.v.IsValid() {
		return nil
	}
	return a.v.Interface()
}
//------------------------------------------------------------------------------
// arrayAppender returns a function that formats a value of type typ as a
// quoted array literal ('{...}'). It returns nil for types that are neither
// pointers, slices, nor arrays, and panics when no element appender is
// available for the element type.
func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc {
	kind := typ.Kind()

	if kind == reflect.Ptr {
		if inner := d.arrayAppender(typ.Elem()); inner != nil {
			return schema.PtrAppender(inner)
		}
		// NOTE(review): a pointer whose target has no array appender falls
		// through and is handled by the generic closure below — confirm this
		// is intended.
	} else if kind != reflect.Slice && kind != reflect.Array {
		return nil
	}

	elemType := typ.Elem()

	// Slices of common element types use dedicated appenders.
	if kind == reflect.Slice {
		switch elemType {
		case stringType:
			return appendStringSliceValue
		case intType:
			return appendIntSliceValue
		case int64Type:
			return appendInt64SliceValue
		case float64Type:
			return appendFloat64SliceValue
		case timeType:
			return appendTimeSliceValue
		}
	}

	appendElem := d.arrayElemAppender(elemType)
	if appendElem == nil {
		panic(fmt.Errorf("pgdialect: %s is not supported", typ))
	}

	return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
		vkind := v.Kind()
		if (vkind == reflect.Ptr || vkind == reflect.Slice) && v.IsNil() {
			return dialect.AppendNull(b)
		}
		if vkind == reflect.Ptr {
			v = v.Elem()
		}

		b = append(b, "'{"...)
		for i, n := 0, v.Len(); i < n; i++ {
			if i > 0 {
				b = append(b, ',')
			}
			b = appendElem(fmter, b, v.Index(i))
		}
		return append(b, "}'"...)
	}
}
// arrayElemAppender picks the function used to format a single array element
// of type typ. driver.Valuer implementations, strings and byte slices get
// array-specific appenders; every other type is delegated to schema.Appender.
func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
	if typ.Implements(driverValuerType) {
		return arrayAppendDriverValue
	}

	kind := typ.Kind()
	if kind == reflect.String {
		return arrayAppendStringValue
	}
	if kind == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 {
		return arrayAppendBytesValue
	}
	return schema.Appender(d, typ)
}
// arrayAppend appends v, a value of one of the database/sql/driver.Value
// types, to b as a single array element.
//
// A nil value is appended as NULL. Strings, byte slices and times are written
// in double quotes; any other unsupported type is reported through
// dialect.AppendError. fmter is unused.
func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
	switch v := v.(type) {
	case nil:
		// A driver.Valuer may legitimately return nil to signal SQL NULL.
		return dialect.AppendNull(b)
	case int64:
		return strconv.AppendInt(b, v, 10)
	case float64:
		return arrayAppendFloat64(b, v)
	case bool:
		return dialect.AppendBool(b, v)
	case []byte:
		return arrayAppendBytes(b, v)
	case string:
		return arrayAppendString(b, v)
	case time.Time:
		b = append(b, '"')
		b = appendTime(b, v)
		b = append(b, '"')
		return b
	default:
		err := fmt.Errorf("pgdialect: can't append %T", v)
		return dialect.AppendError(b, err)
	}
}
// arrayAppendStringValue appends the string held by v as a quoted array
// element. fmter is unused.
func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
	s := v.String()
	return arrayAppendString(b, s)
}
// arrayAppendBytesValue appends the byte slice held by v as an array element.
// fmter is unused.
func arrayAppendBytesValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
	bs := v.Bytes()
	return arrayAppendBytes(b, bs)
}
// arrayAppendDriverValue calls Value on the driver.Valuer held by v and
// appends the result as an array element. A Value error is reported through
// dialect.AppendError.
func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
	valuer := v.Interface().(driver.Valuer)
	val, err := valuer.Value()
	if err != nil {
		return dialect.AppendError(b, err)
	}
	return arrayAppend(fmter, b, val)
}
// appendStringSliceValue converts v to []string and appends it as an array
// literal. fmter is unused.
func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
	converted := v.Convert(sliceStringType)
	return appendStringSlice(b, converted.Interface().([]string))
}
// appendStringSlice appends ss to b as a quoted array literal of
// double-quoted strings. A nil slice is appended as NULL; an empty non-nil
// slice becomes '{}'.
func appendStringSlice(b []byte, ss []string) []byte {
	if ss == nil {
		return dialect.AppendNull(b)
	}

	b = append(b, "'{"...)
	for i, s := range ss {
		if i > 0 {
			b = append(b, ',')
		}
		b = arrayAppendString(b, s)
	}
	return append(b, "}'"...)
}
// appendIntSliceValue converts v to []int and appends it as an array literal.
// fmter is unused.
func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
	converted := v.Convert(sliceIntType)
	return appendIntSlice(b, converted.Interface().([]int))
}
// appendIntSlice appends ints to b as a quoted array literal of base-10
// integers. A nil slice is appended as NULL; an empty non-nil slice becomes
// '{}'.
func appendIntSlice(b []byte, ints []int) []byte {
	if ints == nil {
		return dialect.AppendNull(b)
	}

	b = append(b, "'{"...)
	for i, n := range ints {
		if i > 0 {
			b = append(b, ',')
		}
		b = strconv.AppendInt(b, int64(n), 10)
	}
	return append(b, "}'"...)
}
// appendInt64SliceValue converts v to []int64 and appends it as an array
// literal. fmter is unused.
func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
	converted := v.Convert(sliceInt64Type)
	return appendInt64Slice(b, converted.Interface().([]int64))
}
// appendInt64Slice appends ints to b as a quoted array literal of base-10
// integers. A nil slice is appended as NULL; an empty non-nil slice becomes
// '{}'.
func appendInt64Slice(b []byte, ints []int64) []byte {
	if ints == nil {
		return dialect.AppendNull(b)
	}

	b = append(b, "'{"...)
	for i, n := range ints {
		if i > 0 {
			b = append(b, ',')
		}
		b = strconv.AppendInt(b, n, 10)
	}
	return append(b, "}'"...)
}
// appendFloat64SliceValue converts v to []float64 and appends it as an array
// literal. fmter is unused.
func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
	converted := v.Convert(sliceFloat64Type)
	return appendFloat64Slice(b, converted.Interface().([]float64))
}
// appendFloat64Slice appends floats to b as a quoted array literal, formatting
// each element with arrayAppendFloat64. A nil slice is appended as NULL; an
// empty non-nil slice becomes '{}'.
func appendFloat64Slice(b []byte, floats []float64) []byte {
	if floats == nil {
		return dialect.AppendNull(b)
	}

	b = append(b, "'{"...)
	for i, f := range floats {
		if i > 0 {
			b = append(b, ',')
		}
		b = arrayAppendFloat64(b, f)
	}
	return append(b, "}'"...)
}
// arrayAppendFloat64 appends num to b in plain decimal notation with the
// fewest digits that round-trip. NaN and the infinities are written as "NaN",
// "Infinity" and "-Infinity".
func arrayAppendFloat64(b []byte, num float64) []byte {
	if math.IsNaN(num) {
		return append(b, "NaN"...)
	}
	if math.IsInf(num, 1) {
		return append(b, "Infinity"...)
	}
	if math.IsInf(num, -1) {
		return append(b, "-Infinity"...)
	}
	return strconv.AppendFloat(b, num, 'f', -1, 64)
}
// appendTimeSliceValue converts v to []time.Time and appends it as an array
// literal.
func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
	converted := v.Convert(sliceTimeType)
	return appendTimeSlice(fmter, b, converted.Interface().([]time.Time))
}
// appendTimeSlice appends ts to b as a quoted array literal in which every
// timestamp, formatted by appendTime, is wrapped in double quotes. A nil slice
// is appended as NULL; an empty non-nil slice becomes '{}'. fmter is unused.
func appendTimeSlice(fmter schema.Formatter, b []byte, ts []time.Time) []byte {
	if ts == nil {
		return dialect.AppendNull(b)
	}

	b = append(b, "'{"...)
	for i, t := range ts {
		if i > 0 {
			b = append(b, ',')
		}
		b = append(b, '"')
		b = appendTime(b, t)
		b = append(b, '"')
	}
	return append(b, "}'"...)
}
//------------------------------------------------------------------------------
// arrayScanner returns a function that decodes an array value into a
// destination of type typ. It returns nil for types that are neither
// pointers, slices, nor arrays.
func arrayScanner(typ reflect.Type) schema.ScannerFunc {
	kind := typ.Kind()

	switch kind {
	case reflect.Ptr:
		if fn := arrayScanner(typ.Elem()); fn != nil {
			return schema.PtrScanner(fn)
		}
		// NOTE(review): a pointer whose target has no array scanner falls
		// through to the generic scanner below — confirm this is intended.
	case reflect.Slice, reflect.Array:
		// ok:
	default:
		return nil
	}

	elemType := typ.Elem()

	// Slices of common element types use dedicated scanners.
	if kind == reflect.Slice {
		switch elemType {
		case stringType:
			return scanStringSliceValue
		case intType:
			return scanIntSliceValue
		case int64Type:
			return scanInt64SliceValue
		case float64Type:
			return scanFloat64SliceValue
		}
	}

	scanElem := schema.Scanner(elemType)
	return func(dest reflect.Value, src interface{}) error {
		dest = reflect.Indirect(dest)
		if !dest.CanSet() {
			return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
		}

		kind := dest.Kind()

		// A nil source resets the destination to its zero value; a slice that
		// is already nil is left untouched.
		if src == nil {
			if kind != reflect.Slice || !dest.IsNil() {
				dest.Set(reflect.Zero(dest.Type()))
			}
			return nil
		}

		// Start from an empty, non-nil slice so that elements from an earlier
		// scan are dropped.
		if kind == reflect.Slice {
			if dest.IsNil() {
				dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
			} else if dest.Len() > 0 {
				dest.Set(dest.Slice(0, 0))
			}
		}

		b, err := toBytes(src)
		if err != nil {
			return err
		}

		p := newArrayParser(b)
		nextValue := internal.MakeSliceNextElemFunc(dest)
		for p.Next() {
			elem := p.Elem()
			elemValue := nextValue()
			if err := scanElem(elemValue, elem); err != nil {
				return fmt.Errorf("scanElem failed: %w", err)
			}
		}

		return p.Err()
	}
}
// scanStringSliceValue decodes src as a string array and stores the result in
// dest, following one level of pointer indirection. It fails when the target
// cannot be set.
func scanStringSliceValue(dest reflect.Value, src interface{}) error {
	target := reflect.Indirect(dest)
	if !target.CanSet() {
		return fmt.Errorf("bun: Scan(non-settable %s)", target.Type())
	}

	decoded, err := decodeStringSlice(src)
	if err != nil {
		return err
	}

	target.Set(reflect.ValueOf(decoded))
	return nil
}
// decodeStringSlice parses src (a string or []byte array value) into a
// []string. A nil src yields a nil slice; otherwise the result is non-nil even
// when it holds no elements.
func decodeStringSlice(src interface{}) ([]string, error) {
	if src == nil {
		return nil, nil
	}

	raw, err := toBytes(src)
	if err != nil {
		return nil, err
	}

	out := []string{}
	parser := newArrayParser(raw)
	for parser.Next() {
		out = append(out, string(parser.Elem()))
	}
	if err := parser.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// scanIntSliceValue decodes src as an integer array and stores the result in
// dest, following one level of pointer indirection. It fails when the target
// cannot be set.
func scanIntSliceValue(dest reflect.Value, src interface{}) error {
	target := reflect.Indirect(dest)
	if !target.CanSet() {
		return fmt.Errorf("bun: Scan(non-settable %s)", target.Type())
	}

	decoded, err := decodeIntSlice(src)
	if err != nil {
		return err
	}

	target.Set(reflect.ValueOf(decoded))
	return nil
}
// decodeIntSlice parses src (a string or []byte array value) into a []int. A
// nil src yields a nil slice. Elements the parser reports as nil are stored as
// 0; any element that is not a valid integer aborts decoding with an error.
func decodeIntSlice(src interface{}) ([]int, error) {
	if src == nil {
		return nil, nil
	}

	raw, err := toBytes(src)
	if err != nil {
		return nil, err
	}

	out := []int{}
	parser := newArrayParser(raw)
	for parser.Next() {
		var n int
		if elem := parser.Elem(); elem != nil {
			n, err = strconv.Atoi(bytesToString(elem))
			if err != nil {
				return nil, err
			}
		}
		out = append(out, n)
	}
	if err := parser.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// scanInt64SliceValue decodes src as an int64 array and stores the result in
// dest, following one level of pointer indirection. It fails when the target
// cannot be set.
func scanInt64SliceValue(dest reflect.Value, src interface{}) error {
	target := reflect.Indirect(dest)
	if !target.CanSet() {
		return fmt.Errorf("bun: Scan(non-settable %s)", target.Type())
	}

	decoded, err := decodeInt64Slice(src)
	if err != nil {
		return err
	}

	target.Set(reflect.ValueOf(decoded))
	return nil
}
// decodeInt64Slice parses src (a string or []byte array value) into a
// []int64. A nil src yields a nil slice. Elements the parser reports as nil
// are stored as 0; any element that is not a valid base-10 integer aborts
// decoding with an error.
func decodeInt64Slice(src interface{}) ([]int64, error) {
	if src == nil {
		return nil, nil
	}

	raw, err := toBytes(src)
	if err != nil {
		return nil, err
	}

	out := []int64{}
	parser := newArrayParser(raw)
	for parser.Next() {
		var n int64
		if elem := parser.Elem(); elem != nil {
			n, err = strconv.ParseInt(bytesToString(elem), 10, 64)
			if err != nil {
				return nil, err
			}
		}
		out = append(out, n)
	}
	if err := parser.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// scanFloat64SliceValue decodes src as a float64 array and stores the result
// in dest, following one level of pointer indirection. It fails when the
// target cannot be set.
func scanFloat64SliceValue(dest reflect.Value, src interface{}) error {
	target := reflect.Indirect(dest)
	if !target.CanSet() {
		return fmt.Errorf("bun: Scan(non-settable %s)", target.Type())
	}

	decoded, err := scanFloat64Slice(src)
	if err != nil {
		return err
	}

	target.Set(reflect.ValueOf(decoded))
	return nil
}
// scanFloat64Slice parses src (a string or []byte array value) into a
// []float64. A nil src yields a nil slice. Elements the parser reports as nil
// are stored as 0; any element that strconv.ParseFloat rejects aborts decoding
// with an error.
func scanFloat64Slice(src interface{}) ([]float64, error) {
	if src == nil {
		return nil, nil
	}

	raw, err := toBytes(src)
	if err != nil {
		return nil, err
	}

	out := []float64{}
	parser := newArrayParser(raw)
	for parser.Next() {
		var f float64
		if elem := parser.Elem(); elem != nil {
			f, err = strconv.ParseFloat(bytesToString(elem), 64)
			if err != nil {
				return nil, err
			}
		}
		out = append(out, f)
	}
	if err := parser.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// toBytes returns src as a byte slice. Strings are converted with
// stringToBytes and byte slices are returned as is; any other type is an
// error.
func toBytes(src interface{}) ([]byte, error) {
	if s, ok := src.(string); ok {
		return stringToBytes(s), nil
	}
	if bs, ok := src.([]byte); ok {
		return bs, nil
	}
	return nil, fmt.Errorf("pgdialect: got %T, wanted []byte or string", src)
}
//------------------------------------------------------------------------------
// arrayAppendBytes appends bs to b as a double-quoted array element made of
// the prefix \\x followed by the lowercase hex encoding of bs. A nil slice is
// appended as NULL.
func arrayAppendBytes(b []byte, bs []byte) []byte {
	if bs == nil {
		return dialect.AppendNull(b)
	}

	encoded := make([]byte, hex.EncodedLen(len(bs)))
	hex.Encode(encoded, bs)

	b = append(b, `"\\x`...)
	b = append(b, encoded...)
	return append(b, '"')
}
// arrayAppendString appends s to b as a double-quoted array element.
//
// NUL runes are dropped, a single quote is doubled, and double quotes and
// backslashes are prefixed with a backslash. All other runes are written as
// UTF-8.
func arrayAppendString(b []byte, s string) []byte {
	var enc [utf8.UTFMax]byte

	b = append(b, '"')
	for _, r := range s {
		switch {
		case r == 0:
			// ignore
		case r == '\'':
			b = append(b, "''"...)
		case r == '"':
			b = append(b, `\"`...)
		case r == '\\':
			b = append(b, `\\`...)
		case r < utf8.RuneSelf:
			b = append(b, byte(r))
		default:
			n := utf8.EncodeRune(enc[:], r)
			b = append(b, enc[:n]...)
		}
	}
	return append(b, '"')
}