package schema

import (
	"database/sql/driver"
	"fmt"
	"net"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/uptrace/bun/dialect"
	"github.com/uptrace/bun/dialect/sqltype"
	"github.com/uptrace/bun/extra/bunjson"
	"github.com/uptrace/bun/internal"
	"github.com/vmihailenco/msgpack/v5"
)

type (
	AppenderFunc   func(fmter Formatter, b []byte, v reflect.Value) []byte
	CustomAppender func(typ reflect.Type) AppenderFunc
)

var appenders = []AppenderFunc{
	reflect.Bool:          AppendBoolValue,
	reflect.Int:           AppendIntValue,
	reflect.Int8:          AppendIntValue,
	reflect.Int16:         AppendIntValue,
	reflect.Int32:         AppendIntValue,
	reflect.Int64:         AppendIntValue,
	reflect.Uint:          AppendUintValue,
	reflect.Uint8:         AppendUintValue,
	reflect.Uint16:        AppendUintValue,
	reflect.Uint32:        appendUint32Value,
	reflect.Uint64:        appendUint64Value,
	reflect.Uintptr:       nil,
	reflect.Float32:       AppendFloat32Value,
	reflect.Float64:       AppendFloat64Value,
	reflect.Complex64:     nil,
	reflect.Complex128:    nil,
	reflect.Array:         AppendJSONValue,
	reflect.Chan:          nil,
	reflect.Func:          nil,
	reflect.Interface:     nil,
	reflect.Map:           AppendJSONValue,
	reflect.Ptr:           nil,
	reflect.Slice:         AppendJSONValue,
	reflect.String:        AppendStringValue,
	reflect.Struct:        AppendJSONValue,
	reflect.UnsafePointer: nil,
}

var appenderMap sync.Map

func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
	if field.Tag.HasOption("msgpack") {
		return appendMsgpack
	}

	fieldType := field.StructField.Type

	switch strings.ToUpper(field.UserSQLType) {
	case sqltype.JSON, sqltype.JSONB:
		if fieldType.Implements(driverValuerType) {
			return appendDriverValue
		}

		if fieldType.Kind() != reflect.Ptr {
			if reflect.PtrTo(fieldType).Implements(driverValuerType) {
				return addrAppender(appendDriverValue)
			}
		}

		return AppendJSONValue
	}

	return Appender(dialect, fieldType)
}

func Appender(dialect Dialect, typ reflect.Type) AppenderFunc {
	if v, ok := appenderMap.Load(typ); ok {
		return v.(AppenderFunc)
	}

	fn := appender(dialect, typ)

	if v, ok := appenderMap.LoadOrStore(typ, fn); ok {
		return v.(AppenderFunc)
	}
	return fn
}

func appender(dialect Dialect, typ reflect.Type) AppenderFunc {
	switch typ {
	case bytesType:
		return appendBytesValue
	case timeType:
		return appendTimeValue
	case timePtrType:
		return PtrAppender(appendTimeValue)
	case ipType:
		return appendIPValue
	case ipNetType:
		return appendIPNetValue
	case jsonRawMessageType:
		return appendJSONRawMessageValue
	}

	kind := typ.Kind()

	if typ.Implements(queryAppenderType) {
		if kind == reflect.Ptr {
			return nilAwareAppender(appendQueryAppenderValue)
		}
		return appendQueryAppenderValue
	}
	if typ.Implements(driverValuerType) {
		if kind == reflect.Ptr {
			return nilAwareAppender(appendDriverValue)
		}
		return appendDriverValue
	}

	if kind != reflect.Ptr {
		ptr := reflect.PtrTo(typ)
		if ptr.Implements(queryAppenderType) {
			return addrAppender(appendQueryAppenderValue)
		}
		if ptr.Implements(driverValuerType) {
			return addrAppender(appendDriverValue)
		}
	}

	switch kind {
	case reflect.Interface:
		return ifaceAppenderFunc
	case reflect.Ptr:
		if typ.Implements(jsonMarshalerType) {
			return nilAwareAppender(AppendJSONValue)
		}
		if fn := Appender(dialect, typ.Elem()); fn != nil {
			return PtrAppender(fn)
		}
	case reflect.Slice:
		if typ.Elem().Kind() == reflect.Uint8 {
			return appendBytesValue
		}
	case reflect.Array:
		if typ.Elem().Kind() == reflect.Uint8 {
			return appendArrayBytesValue
		}
	}

	return appenders[typ.Kind()]
}

func ifaceAppenderFunc(fmter Formatter, b []byte, v reflect.Value) []byte {
	if v.IsNil() {
		return dialect.AppendNull(b)
	}
	elem := v.Elem()
	appender := Appender(fmter.Dialect(), elem.Type())
	return appender(fmter, b, elem)
}

func nilAwareAppender(fn AppenderFunc) AppenderFunc {
	return func(fmter Formatter, b []byte, v reflect.Value) []byte {
		if v.IsNil() {
			return dialect.AppendNull(b)
		}
		return fn(fmter, b, v)
	}
}

func PtrAppender(fn AppenderFunc) AppenderFunc {
	return func(fmter Formatter, b []byte, v reflect.Value) []byte {
		if v.IsNil() {
			return dialect.AppendNull(b)
		}
		return fn(fmter, b, v.Elem())
	}
}

func AppendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	return fmter.Dialect().AppendBool(b, v.Bool())
}

func AppendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	return strconv.AppendInt(b, v.Int(), 10)
}

func AppendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	return strconv.AppendUint(b, v.Uint(), 10)
}

func appendUint32Value(fmter Formatter, b []byte, v reflect.Value) []byte {
	return fmter.Dialect().AppendUint32(b, uint32(v.Uint()))
}

func appendUint64Value(fmter Formatter, b []byte, v reflect.Value) []byte {
	return fmter.Dialect().AppendUint64(b, v.Uint())
}

func AppendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte {
	return dialect.AppendFloat32(b, float32(v.Float()))
}

func AppendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte {
	return dialect.AppendFloat64(b, float64(v.Float()))
}

func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	return fmter.Dialect().AppendBytes(b, v.Bytes())
}

func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	if v.CanAddr() {
		return fmter.Dialect().AppendBytes(b, v.Slice(0, v.Len()).Bytes())
	}

	tmp := make([]byte, v.Len())
	reflect.Copy(reflect.ValueOf(tmp), v)
	b = fmter.Dialect().AppendBytes(b, tmp)
	return b
}

func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	return fmter.Dialect().AppendString(b, v.String())
}

func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	bb, err := bunjson.Marshal(v.Interface())
	if err != nil {
		return dialect.AppendError(b, err)
	}

	if len(bb) > 0 && bb[len(bb)-1] == '\n' {
		bb = bb[:len(bb)-1]
	}

	return fmter.Dialect().AppendJSON(b, bb)
}

func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	tm := v.Interface().(time.Time)
	return fmter.Dialect().AppendTime(b, tm)
}

func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	ip := v.Interface().(net.IP)
	return fmter.Dialect().AppendString(b, ip.String())
}

func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	ipnet := v.Interface().(net.IPNet)
	return fmter.Dialect().AppendString(b, ipnet.String())
}

func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	bytes := v.Bytes()
	if bytes == nil {
		return dialect.AppendNull(b)
	}
	return fmter.Dialect().AppendString(b, internal.String(bytes))
}

func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender))
}

func appendDriverValue(fmter Formatter, b []byte, v reflect.Value) []byte {
	value, err := v.Interface().(driver.Valuer).Value()
	if err != nil {
		return dialect.AppendError(b, err)
	}
	if _, ok := value.(driver.Valuer); ok {
		return dialect.AppendError(b, fmt.Errorf("driver.Valuer returns unsupported type %T", value))
	}
	return Append(fmter, b, value)
}

func addrAppender(fn AppenderFunc) AppenderFunc {
	return func(fmter Formatter, b []byte, v reflect.Value) []byte {
		if !v.CanAddr() {
			err := fmt.Errorf("bun: Append(nonaddressable %T)", v.Interface())
			return dialect.AppendError(b, err)
		}
		return fn(fmter, b, v.Addr())
	}
}

func appendMsgpack(fmter Formatter, b []byte, v reflect.Value) []byte {
	hexEnc := internal.NewHexEncoder(b)

	enc := msgpack.GetEncoder()
	defer msgpack.PutEncoder(enc)

	enc.Reset(hexEnc)
	if err := enc.EncodeValue(v); err != nil {
		return dialect.AppendError(b, err)
	}

	if err := hexEnc.Close(); err != nil {
		return dialect.AppendError(b, err)
	}

	return hexEnc.Bytes()
}

func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte {
	bb, err := app.AppendQuery(fmter, b)
	if err != nil {
		return dialect.AppendError(b, err)
	}
	return bb
}