140 lines
3.2 KiB
Go
140 lines
3.2 KiB
Go
|
package pg
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"time"
|
||
|
|
||
|
"github.com/go-pg/pg/v10/orm"
|
||
|
)
|
||
|
|
||
|
type (
|
||
|
BeforeScanHook = orm.BeforeScanHook
|
||
|
AfterScanHook = orm.AfterScanHook
|
||
|
AfterSelectHook = orm.AfterSelectHook
|
||
|
BeforeInsertHook = orm.BeforeInsertHook
|
||
|
AfterInsertHook = orm.AfterInsertHook
|
||
|
BeforeUpdateHook = orm.BeforeUpdateHook
|
||
|
AfterUpdateHook = orm.AfterUpdateHook
|
||
|
BeforeDeleteHook = orm.BeforeDeleteHook
|
||
|
AfterDeleteHook = orm.AfterDeleteHook
|
||
|
)
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
// dummyFormatter is a query "formatter" that performs no formatting: the
// query text is copied through verbatim and no parameters are interpolated.
type dummyFormatter struct{}

// FormatQuery appends query to dst unchanged and returns the extended slice.
// The variadic params are accepted (presumably so the method matches a
// formatter signature used elsewhere) but are ignored.
func (dummyFormatter) FormatQuery(dst []byte, query string, _ ...interface{}) []byte {
	dst = append(dst, query...)
	return dst
}
|
||
|
|
||
|
// QueryEvent ...
|
||
|
type QueryEvent struct {
|
||
|
StartTime time.Time
|
||
|
DB orm.DB
|
||
|
Model interface{}
|
||
|
Query interface{}
|
||
|
Params []interface{}
|
||
|
fmtedQuery []byte
|
||
|
Result Result
|
||
|
Err error
|
||
|
|
||
|
Stash map[interface{}]interface{}
|
||
|
}
|
||
|
|
||
|
// QueryHook ...
|
||
|
type QueryHook interface {
|
||
|
BeforeQuery(context.Context, *QueryEvent) (context.Context, error)
|
||
|
AfterQuery(context.Context, *QueryEvent) error
|
||
|
}
|
||
|
|
||
|
// UnformattedQuery returns the unformatted query of a query event.
|
||
|
// The query is only valid until the query Result is returned to the user.
|
||
|
func (e *QueryEvent) UnformattedQuery() ([]byte, error) {
|
||
|
return queryString(e.Query)
|
||
|
}
|
||
|
|
||
|
func queryString(query interface{}) ([]byte, error) {
|
||
|
switch query := query.(type) {
|
||
|
case orm.TemplateAppender:
|
||
|
return query.AppendTemplate(nil)
|
||
|
case string:
|
||
|
return dummyFormatter{}.FormatQuery(nil, query), nil
|
||
|
default:
|
||
|
return nil, fmt.Errorf("pg: can't append %T", query)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// FormattedQuery returns the formatted query of a query event.
|
||
|
// The query is only valid until the query Result is returned to the user.
|
||
|
func (e *QueryEvent) FormattedQuery() ([]byte, error) {
|
||
|
return e.fmtedQuery, nil
|
||
|
}
|
||
|
|
||
|
// AddQueryHook adds a hook into query processing.
|
||
|
func (db *baseDB) AddQueryHook(hook QueryHook) {
|
||
|
db.queryHooks = append(db.queryHooks, hook)
|
||
|
}
|
||
|
|
||
|
func (db *baseDB) beforeQuery(
|
||
|
ctx context.Context,
|
||
|
ormDB orm.DB,
|
||
|
model, query interface{},
|
||
|
params []interface{},
|
||
|
fmtedQuery []byte,
|
||
|
) (context.Context, *QueryEvent, error) {
|
||
|
if len(db.queryHooks) == 0 {
|
||
|
return ctx, nil, nil
|
||
|
}
|
||
|
|
||
|
event := &QueryEvent{
|
||
|
StartTime: time.Now(),
|
||
|
DB: ormDB,
|
||
|
Model: model,
|
||
|
Query: query,
|
||
|
Params: params,
|
||
|
fmtedQuery: fmtedQuery,
|
||
|
}
|
||
|
|
||
|
for i, hook := range db.queryHooks {
|
||
|
var err error
|
||
|
ctx, err = hook.BeforeQuery(ctx, event)
|
||
|
if err != nil {
|
||
|
if err := db.afterQueryFromIndex(ctx, event, i); err != nil {
|
||
|
return ctx, nil, err
|
||
|
}
|
||
|
return ctx, nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return ctx, event, nil
|
||
|
}
|
||
|
|
||
|
func (db *baseDB) afterQuery(
|
||
|
ctx context.Context,
|
||
|
event *QueryEvent,
|
||
|
res Result,
|
||
|
err error,
|
||
|
) error {
|
||
|
if event == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
event.Err = err
|
||
|
event.Result = res
|
||
|
return db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1)
|
||
|
}
|
||
|
|
||
|
func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error {
|
||
|
for ; hookIndex >= 0; hookIndex-- {
|
||
|
if err := db.queryHooks[hookIndex].AfterQuery(ctx, event); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func copyQueryHooks(s []QueryHook) []QueryHook {
|
||
|
return s[:len(s):len(s)]
|
||
|
}
|