mirror of
1
Fork 0
forgejo/vendor/gopkg.in/testfixtures.v2/postgresql.go

212 lines
4.5 KiB
Go

package testfixtures
import (
"database/sql"
"fmt"
)
// PostgreSQL is the PG helper for this package.
type PostgreSQL struct {
	baseHelper

	// UseAlterConstraint, if true, makes constraint disabling use the
	// ALTER CONSTRAINT syntax, which is only available in PostgreSQL >= 9.4.
	// If false, constraint disabling uses DISABLE TRIGGER ALL instead,
	// which requires SUPERUSER privileges.
	UseAlterConstraint bool

	// tables holds the base tables of the "public" schema; filled by init.
	tables []string
	// sequences holds the names of all sequences found in pg_class
	// (relkind 'S'); filled by init.
	sequences []string
	// nonDeferrableConstraints holds the foreign-key constraints whose
	// is_deferrable is 'NO'; filled by init.
	nonDeferrableConstraints []pgConstraint
}
// pgConstraint identifies a table constraint by the name of the table it
// belongs to and its own constraint name.
type pgConstraint struct {
	tableName, constraintName string
}
// init loads the catalog information the helper relies on later: the
// public tables, the sequences and the non-deferrable foreign-key
// constraints. It stops at, and returns, the first lookup error.
func (h *PostgreSQL) init(db *sql.DB) error {
	var err error
	if h.tables, err = h.getTables(db); err != nil {
		return err
	}
	if h.sequences, err = h.getSequences(db); err != nil {
		return err
	}
	h.nonDeferrableConstraints, err = h.getNonDeferrableConstraints(db)
	return err
}
// paramType reports the bind-parameter style used in PostgreSQL queries:
// paramTypeDollar (presumably $1, $2, ... placeholders).
func (h *PostgreSQL) paramType() int { return paramTypeDollar }
// databaseName returns the name reported by current_database() for the
// connection behind db. This is best effort: any query or scan error is
// ignored and an empty string is returned instead.
func (*PostgreSQL) databaseName(db *sql.DB) string {
	var name string
	row := db.QueryRow("SELECT current_database()")
	_ = row.Scan(&name) // deliberately ignored; name stays "" on failure
	return name
}
// getTables returns the names of the base tables in the "public" schema.
// Errors from running the query, scanning a row, or iterating the result
// set are all returned (with a nil slice).
func (h *PostgreSQL) getTables(db *sql.DB) ([]string, error) {
	const query = `
		SELECT table_name
		FROM information_schema.tables
		WHERE table_schema = 'public'
		  AND table_type = 'BASE TABLE';
	`
	rows, err := db.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var table string
		if err = rows.Scan(&table); err != nil {
			return nil, err
		}
		tables = append(tables, table)
	}
	// rows.Next also returns false on an iteration error; surface it
	// instead of returning a silently truncated list.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return tables, nil
}
// getSequences returns the relname of every sequence (pg_class relkind 'S').
// The query does not filter by schema and the names are not
// schema-qualified. Errors from running the query, scanning a row, or
// iterating the result set are returned (with a nil slice).
func (h *PostgreSQL) getSequences(db *sql.DB) ([]string, error) {
	const query = "SELECT relname FROM pg_class WHERE relkind = 'S'"
	rows, err := db.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var sequences []string
	for rows.Next() {
		var sequence string
		if err = rows.Scan(&sequence); err != nil {
			return nil, err
		}
		sequences = append(sequences, sequence)
	}
	// rows.Next also returns false on an iteration error; surface it.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return sequences, nil
}
// getNonDeferrableConstraints returns every FOREIGN KEY constraint whose
// is_deferrable is 'NO', as listed in information_schema.table_constraints.
// Errors from running the query, scanning a row, or iterating the result
// set are returned (with a nil slice).
func (*PostgreSQL) getNonDeferrableConstraints(db *sql.DB) ([]pgConstraint, error) {
	const query = `
		SELECT table_name, constraint_name
		FROM information_schema.table_constraints
		WHERE constraint_type = 'FOREIGN KEY'
		  AND is_deferrable = 'NO'`
	rows, err := db.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var constraints []pgConstraint
	for rows.Next() {
		var constraint pgConstraint
		if err = rows.Scan(&constraint.tableName, &constraint.constraintName); err != nil {
			return nil, err
		}
		constraints = append(constraints, constraint)
	}
	// rows.Next also returns false on an iteration error; surface it.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return constraints, nil
}
// disableTriggers runs loadFn inside a transaction in which
// "ALTER TABLE ... DISABLE TRIGGER ALL" has been applied to every table in
// h.tables. It commits on success and rolls back if disabling or loading
// fails. Afterwards, whether or not the load succeeded, the triggers are
// re-enabled on db (outside the transaction). An error from re-enabling is
// returned only when everything before it succeeded, so the primary error
// always wins.
func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) (err error) {
	defer func() {
		// Re-enable triggers after the load, even if it failed.
		if len(h.tables) == 0 {
			return
		}
		var enable string
		for _, table := range h.tables {
			enable += fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL;", h.quoteKeyword(table))
		}
		if _, enableErr := db.Exec(enable); enableErr != nil && err == nil {
			err = enableErr
		}
	}()

	tx, err := db.Begin()
	if err != nil {
		return err
	}

	if len(h.tables) > 0 {
		var disable string
		for _, table := range h.tables {
			disable += fmt.Sprintf("ALTER TABLE %s DISABLE TRIGGER ALL;", h.quoteKeyword(table))
		}
		if _, err = tx.Exec(disable); err != nil {
			// Release the transaction (and its pooled connection); the
			// Exec error is the one worth reporting.
			tx.Rollback()
			return err
		}
	}

	if err = loadFn(tx); err != nil {
		tx.Rollback()
		return err
	}
	return tx.Commit()
}
// makeConstraintsDeferrable marks every constraint in
// h.nonDeferrableConstraints as DEFERRABLE. It then runs loadFn inside a
// transaction that first executes SET CONSTRAINTS ALL DEFERRED, committing
// on success and rolling back on failure. Afterwards, whether or not the load
// succeeded, the constraints are made NOT DEFERRABLE again on db. An error
// from that restore step is returned only when everything before it
// succeeded, so the primary error always wins.
func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction) (err error) {
	defer func() {
		// Make the constraints non-deferrable again after the load.
		if len(h.nonDeferrableConstraints) == 0 {
			return
		}
		var restore string
		for _, constraint := range h.nonDeferrableConstraints {
			restore += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s NOT DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName))
		}
		if _, restoreErr := db.Exec(restore); restoreErr != nil && err == nil {
			err = restoreErr
		}
	}()

	if len(h.nonDeferrableConstraints) > 0 {
		var alter string
		for _, constraint := range h.nonDeferrableConstraints {
			alter += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName))
		}
		if _, err = db.Exec(alter); err != nil {
			return err
		}
	}

	tx, err := db.Begin()
	if err != nil {
		return err
	}
	if _, err = tx.Exec("SET CONSTRAINTS ALL DEFERRED"); err != nil {
		// Report the failure (instead of pretending success) and release
		// the transaction's connection.
		tx.Rollback()
		return err
	}
	if err = loadFn(tx); err != nil {
		tx.Rollback()
		return err
	}
	return tx.Commit()
}
// disableReferentialIntegrity runs loadFn with referential-integrity checks
// relaxed. It uses deferrable constraints when UseAlterConstraint is set and
// DISABLE TRIGGER ALL otherwise. Sequences are reset afterwards in every
// case; a reset error is returned only when the load itself succeeded, so
// the load error always takes precedence.
func (h *PostgreSQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
	// Ensure sequences are reset after the load, and don't lose a reset
	// failure that would otherwise be invisible to the caller.
	defer func() {
		if resetErr := h.resetSequences(db); resetErr != nil && err == nil {
			err = resetErr
		}
	}()

	if h.UseAlterConstraint {
		return h.makeConstraintsDeferrable(db, loadFn)
	}
	return h.disableTriggers(db, loadFn)
}
// resetSequences calls SETVAL on every sequence in h.sequences, setting it
// to resetSequencesTo. It stops at, and returns, the first error.
func (h *PostgreSQL) resetSequences(db *sql.DB) error {
	for _, sequence := range h.sequences {
		// The sequence identifier and the value are passed as bind
		// parameters rather than spliced into a string literal, so a quote
		// in a name cannot break the statement. The name is run through
		// quoteKeyword (the same quoting applied to table names in the
		// ALTER TABLE statements of this helper) so that the regclass lookup
		// does not fold mixed-case names to lower case.
		// NOTE(review): assumes quoteKeyword yields SQL identifier syntax,
		// as its use in ALTER TABLE implies — confirm.
		if _, err := db.Exec("SELECT SETVAL($1, $2)", h.quoteKeyword(sequence), resetSequencesTo); err != nil {
			return err
		}
	}
	return nil
}