mirror of
1
Fork 0

Merge commit '4d1c7c871bb192bc05ce1cb6fb80773126373111'

This commit is contained in:
f0x 2022-10-03 18:47:29 +00:00
parent 4100fc683c
commit 7594fc7456
22 changed files with 753 additions and 365 deletions

View File

@ -101,6 +101,11 @@ func (c *AccountCache) Put(account *gtsmodel.Account) {
c.cache.Set(account.ID, copyAccount(account)) c.cache.Set(account.ID, copyAccount(account))
} }
// Invalidate removes (invalidates) one account from the cache by its ID.
func (c *AccountCache) Invalidate(id string) {
c.cache.Invalidate(id)
}
// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects. // copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects.
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) // due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
// this should be a relatively cheap process // this should be a relatively cheap process

View File

@ -48,6 +48,11 @@ type Account interface {
// UpdateAccount updates one account by ID. // UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
// DeleteAccount deletes one account from the database by its ID.
// DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the
// account as suspended instead, rather than deleting from the db entirely.
DeleteAccount(ctx context.Context, id string) Error
// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username. // GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username.
GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error)

View File

@ -21,7 +21,6 @@ package bundb
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"strings" "strings"
"time" "time"
@ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
return a.cache.GetByID(id) return a.cache.GetByID(id)
}, },
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx) return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
}, },
) )
} }
@ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
return a.cache.GetByURI(uri) return a.cache.GetByURI(uri)
}, },
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx) return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
}, },
) )
} }
@ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
return a.cache.GetByURL(url) return a.cache.GetByURL(url)
}, },
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx) return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
}, },
) )
} }
@ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
q := a.newAccountQ(account) q := a.newAccountQ(account)
if domain != "" { if domain != "" {
q = q.Where("account.username = ?", username) q = q.Where("? = ?", bun.Ident("account.username"), username)
q = q.Where("account.domain = ?", domain) q = q.Where("? = ?", bun.Ident("account.domain"), domain)
} else { } else {
q = q.Where("account.username = ?", strings.ToLower(username)) q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username))
q = q.Where("account.domain IS NULL") q = q.Where("? IS NULL", bun.Ident("account.domain"))
} }
return q.Scan(ctx) return q.Scan(ctx)
@ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
return a.cache.GetByPubkeyID(id) return a.cache.GetByPubkeyID(id)
}, },
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx) return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
}, },
) )
} }
@ -170,8 +169,8 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
// create links between this account and any emojis it uses // create links between this account and any emojis it uses
// first clear out any old emoji links // first clear out any old emoji links
if _, err := tx.NewDelete(). if _, err := tx.NewDelete().
Model(&[]*gtsmodel.AccountToEmoji{}). TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
Where("account_id = ?", account.ID). Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return err return err
} }
@ -197,6 +196,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
return account, nil return account, nil
} }
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
// clear out any emoji links
if _, err := tx.
NewDelete().
TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
Where("? = ?", bun.Ident("account_to_emoji.account_id"), id).
Exec(ctx); err != nil {
return err
}
// delete the account
_, err := tx.
NewUpdate().
Model(&gtsmodel.Account{ID: id}).
WherePK().
Exec(ctx)
return err
}); err != nil {
return a.conn.ProcessError(err)
}
a.cache.Invalidate(id)
return nil
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account) account := new(gtsmodel.Account)
@ -204,11 +229,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
if domain != "" { if domain != "" {
q = q. q = q.
Where("account.username = ?", domain). Where("? = ?", bun.Ident("account.username"), domain).
Where("account.domain = ?", domain) Where("? = ?", bun.Ident("account.domain"), domain)
} else { } else {
q = q. q = q.
Where("account.username = ?", config.GetHost()). Where("? = ?", bun.Ident("account.username"), config.GetHost()).
WhereGroup(" AND ", whereEmptyOrNull("domain")) WhereGroup(" AND ", whereEmptyOrNull("domain"))
} }
@ -224,10 +249,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
q := a.conn. q := a.conn.
NewSelect(). NewSelect().
Model(status). Model(status).
Order("id DESC"). Column("status.created_at").
Limit(1). Where("? = ?", bun.Ident("status.account_id"), accountID).
Where("account_id = ?", accountID). Order("status.id DESC").
Column("created_at") Limit(1)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return time.Time{}, a.conn.ProcessError(err) return time.Time{}, a.conn.ProcessError(err)
@ -240,12 +265,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
return errors.New("one media attachment cannot be both header and avatar") return errors.New("one media attachment cannot be both header and avatar")
} }
var headerOrAVI string var column bun.Ident
switch { switch {
case *mediaAttachment.Avatar: case *mediaAttachment.Avatar:
headerOrAVI = "avatar" column = bun.Ident("account.avatar_media_attachment_id")
case *mediaAttachment.Header: case *mediaAttachment.Header:
headerOrAVI = "header" column = bun.Ident("account.header_media_attachment_id")
default: default:
return errors.New("given media attachment was neither a header nor an avatar") return errors.New("given media attachment was neither a header nor an avatar")
} }
@ -257,11 +282,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
Exec(ctx); err != nil { Exec(ctx); err != nil {
return a.conn.ProcessError(err) return a.conn.ProcessError(err)
} }
if _, err := a.conn. if _, err := a.conn.
NewUpdate(). NewUpdate().
Model(&gtsmodel.Account{}). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). Set("? = ?", column, mediaAttachment.ID).
Where("id = ?", accountID). Where("? = ?", bun.Ident("account.id"), accountID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return a.conn.ProcessError(err) return a.conn.ProcessError(err)
} }
@ -284,7 +310,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
if err := a.conn. if err := a.conn.
NewSelect(). NewSelect().
Model(faves). Model(faves).
Where("account_id = ?", accountID). Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
Scan(ctx); err != nil { Scan(ctx); err != nil {
return nil, a.conn.ProcessError(err) return nil, a.conn.ProcessError(err)
} }
@ -295,8 +321,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
return a.conn. return a.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.Status{}). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("account_id = ?", accountID). Where("? = ?", bun.Ident("status.account_id"), accountID).
Count(ctx) Count(ctx)
} }
@ -305,12 +331,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
q := a.conn. q := a.conn.
NewSelect(). NewSelect().
Table("statuses"). Model(&gtsmodel.Status{}).
Column("id"). Column("status.id").
Order("id DESC") Order("status.id DESC")
if accountID != "" { if accountID != "" {
q = q.Where("account_id = ?", accountID) q = q.Where("? = ?", bun.Ident("status.account_id"), accountID)
} }
if limit != 0 { if limit != 0 {
@ -321,27 +347,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
// include self-replies (threads) // include self-replies (threads)
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
return q. return q.
WhereOr("in_reply_to_account_id = ?", accountID). WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID).
WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri")) WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri"))
} }
q = q.WhereGroup(" AND ", whereGroup) q = q.WhereGroup(" AND ", whereGroup)
} }
if excludeReblogs { if excludeReblogs {
q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")) q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id"))
} }
if maxID != "" { if maxID != "" {
q = q.Where("id < ?", maxID) q = q.Where("? < ?", bun.Ident("status.id"), maxID)
} }
if minID != "" { if minID != "" {
q = q.Where("id > ?", minID) q = q.Where("? > ?", bun.Ident("status.id"), minID)
} }
if pinnedOnly { if pinnedOnly {
q = q.Where("pinned = ?", true) q = q.Where("? = ?", bun.Ident("status.pinned"), true)
} }
if mediaOnly { if mediaOnly {
@ -352,15 +378,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
switch a.conn.Dialect().Name() { switch a.conn.Dialect().Name() {
case dialect.PG: case dialect.PG:
return q. return q.
Where("? IS NOT NULL", bun.Ident("attachments")). Where("? IS NOT NULL", bun.Ident("status.attachments")).
Where("? != '{}'", bun.Ident("attachments")) Where("? != '{}'", bun.Ident("status.attachments"))
case dialect.SQLite: case dialect.SQLite:
return q. return q.
Where("? IS NOT NULL", bun.Ident("attachments")). Where("? IS NOT NULL", bun.Ident("status.attachments")).
Where("? != ''", bun.Ident("attachments")). Where("? != ''", bun.Ident("status.attachments")).
Where("? != 'null'", bun.Ident("attachments")). Where("? != 'null'", bun.Ident("status.attachments")).
Where("? != '{}'", bun.Ident("attachments")). Where("? != '{}'", bun.Ident("status.attachments")).
Where("? != '[]'", bun.Ident("attachments")) Where("? != '[]'", bun.Ident("status.attachments"))
default: default:
log.Panic("db dialect was neither pg nor sqlite") log.Panic("db dialect was neither pg nor sqlite")
return q return q
@ -369,7 +395,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
} }
if publicOnly { if publicOnly {
q = q.Where("visibility = ?", gtsmodel.VisibilityPublic) q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic)
} }
if err := q.Scan(ctx, &statusIDs); err != nil { if err := q.Scan(ctx, &statusIDs); err != nil {
@ -384,19 +410,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
q := a.conn. q := a.conn.
NewSelect(). NewSelect().
Table("statuses"). Model(&gtsmodel.Status{}).
Column("id"). Column("status.id").
Where("account_id = ?", accountID). Where("? = ?", bun.Ident("status.account_id"), accountID).
WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")). WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")). WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
Where("visibility = ?", gtsmodel.VisibilityPublic). Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
Where("federated = ?", true) Where("? = ?", bun.Ident("status.federated"), true)
if maxID != "" { if maxID != "" {
q = q.Where("id < ?", maxID) q = q.Where("? < ?", bun.Ident("status.id"), maxID)
} }
q = q.Limit(limit).Order("id DESC") q = q.Limit(limit).Order("status.id DESC")
if err := q.Scan(ctx, &statusIDs); err != nil { if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, a.conn.ProcessError(err) return nil, a.conn.ProcessError(err)
@ -411,16 +437,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
fq := a.conn. fq := a.conn.
NewSelect(). NewSelect().
Model(&blocks). Model(&blocks).
Where("block.account_id = ?", accountID). Where("? = ?", bun.Ident("block.account_id"), accountID).
Relation("TargetAccount"). Relation("TargetAccount").
Order("block.id DESC") Order("block.id DESC")
if maxID != "" { if maxID != "" {
fq = fq.Where("block.id < ?", maxID) fq = fq.Where("? < ?", bun.Ident("block.id"), maxID)
} }
if sinceID != "" { if sinceID != "" {
fq = fq.Where("block.id > ?", sinceID) fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID)
} }
if limit > 0 { if limit > 0 {

View File

@ -22,7 +22,6 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"database/sql"
"fmt" "fmt"
"net" "net"
"net/mail" "net/mail"
@ -37,21 +36,26 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/uris" "github.com/superseriousbusiness/gotosocial/internal/uris"
"github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
// generate RSA keys of this length
const rsaKeyBits = 2048
type adminDB struct { type adminDB struct {
conn *DBConn conn *DBConn
userCache *cache.UserCache userCache *cache.UserCache
accountCache *cache.AccountCache
} }
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
q := a.conn. q := a.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.Account{}). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Where("username = ?", username). Column("account.id").
Where("domain = ?", nil) Where("? = ?", bun.Ident("account.username"), username).
Where("? IS NULL", bun.Ident("account.domain"))
return a.conn.NotExists(ctx, q) return a.conn.NotExists(ctx, q)
} }
@ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked // check if the email domain is blocked
if err := a.conn. emailDomainBlockedQ := a.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.EmailDomainBlock{}). TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")).
Where("domain = ?", domain). Column("email_domain_block.id").
Scan(ctx); err == nil { Where("? = ?", bun.Ident("email_domain_block.domain"), domain)
// fail because we found something emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ)
if err != nil {
return false, err
}
if emailDomainBlocked {
return false, fmt.Errorf("email domain %s is blocked", domain) return false, fmt.Errorf("email domain %s is blocked", domain)
} else if err != sql.ErrNoRows {
return false, a.conn.ProcessError(err)
} }
// check if this email is associated with a user already // check if this email is associated with a user already
q := a.conn. q := a.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.User{}). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Where("email = ?", email). Column("user.id").
WhereOr("unconfirmed_email = ?", email) Where("? = ?", bun.Ident("user.email"), email).
WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email)
return a.conn.NotExists(ctx, q) return a.conn.NotExists(ctx, q)
} }
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
key, err := rsa.GenerateKey(rand.Reader, 2048) key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
if err != nil { if err != nil {
log.Errorf("error creating new rsa key: %s", err) log.Errorf("error creating new rsa key: %s", err)
return nil, err return nil, err
@ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
// if something went wrong while creating a user, we might already have an account, so check here first... // if something went wrong while creating a user, we might already have an account, so check here first...
acct := &gtsmodel.Account{} acct := &gtsmodel.Account{}
q := a.conn.NewSelect(). if err := a.conn.
NewSelect().
Model(acct). Model(acct).
Where("username = ?", username). Where("? = ?", bun.Ident("account.username"), username).
WhereGroup(" AND ", whereEmptyOrNull("domain")) WhereGroup(" AND ", whereEmptyOrNull("account.domain")).
Scan(ctx); err != nil {
err = a.conn.ProcessError(err)
if err != db.ErrNoEntries {
log.Errorf("error checking for existing account: %s", err)
return nil, err
}
if err := q.Scan(ctx); err != nil { // if we have db.ErrNoEntries, we just don't have an
// we just don't have an account yet so create one before we proceed // account yet so create one before we proceed
accountURIs := uris.GenerateURIsForAccount(username) accountURIs := uris.GenerateURIsForAccount(username)
accountID, err := id.NewRandomULID() accountID, err := id.NewRandomULID()
if err != nil { if err != nil {
@ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
FeaturedCollectionURI: accountURIs.CollectionURI, FeaturedCollectionURI: accountURIs.CollectionURI,
} }
// insert the new account!
if _, err = a.conn. if _, err = a.conn.
NewInsert(). NewInsert().
Model(acct). Model(acct).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return nil, a.conn.ProcessError(err) return nil, a.conn.ProcessError(err)
} }
a.accountCache.Put(acct)
} }
// we either created or already had an account by now,
// so proceed with creating a user for that account
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil { if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err) return nil, fmt.Errorf("error hashing password: %s", err)
@ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
u.Moderator = &moderator u.Moderator = &moderator
} }
// insert the user!
if _, err = a.conn. if _, err = a.conn.
NewInsert(). NewInsert().
Model(u). Model(u).
@ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
q := a.conn. q := a.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.Account{}). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Where("username = ?", username). Column("account.id").
WhereGroup(" AND ", whereEmptyOrNull("domain")) Where("? = ?", bun.Ident("account.username"), username).
WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
exists, err := a.conn.Exists(ctx, q) exists, err := a.conn.Exists(ctx, q)
if err != nil { if err != nil {
@ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
return nil return nil
} }
key, err := rsa.GenerateKey(rand.Reader, 2048) key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
if err != nil { if err != nil {
log.Errorf("error creating new rsa key: %s", err) log.Errorf("error creating new rsa key: %s", err)
return err return err
@ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
return a.conn.ProcessError(err) return a.conn.ProcessError(err)
} }
a.accountCache.Put(acct)
log.Infof("instance account %s CREATED with id %s", username, acct.ID) log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil return nil
} }
@ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
// check if instance entry already exists // check if instance entry already exists
q := a.conn. q := a.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.Instance{}). Column("instance.id").
Where("domain = ?", host) TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
Where("? = ?", bun.Ident("instance.domain"), host)
exists, err := a.conn.Exists(ctx, q) exists, err := a.conn.Exists(ctx, q)
if err != nil { if err != nil {

View File

@ -23,6 +23,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
gtsmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20211113114307_init"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -30,6 +31,44 @@ type AdminTestSuite struct {
BunDBStandardTestSuite BunDBStandardTestSuite
} }
func (suite *AdminTestSuite) TestIsUsernameAvailableNo() {
available, err := suite.db.IsUsernameAvailable(context.Background(), "the_mighty_zork")
suite.NoError(err)
suite.False(available)
}
func (suite *AdminTestSuite) TestIsUsernameAvailableYes() {
available, err := suite.db.IsUsernameAvailable(context.Background(), "someone_completely_different")
suite.NoError(err)
suite.True(available)
}
func (suite *AdminTestSuite) TestIsEmailAvailableNo() {
available, err := suite.db.IsEmailAvailable(context.Background(), "zork@example.org")
suite.NoError(err)
suite.False(available)
}
func (suite *AdminTestSuite) TestIsEmailAvailableYes() {
available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com")
suite.NoError(err)
suite.True(available)
}
func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {
if err := suite.db.Put(context.Background(), &gtsmodel.EmailDomainBlock{
ID: "01GEEV2R2YC5GRSN96761YJE47",
Domain: "somewhere.com",
CreatedByAccountID: suite.testAccounts["admin_account"].ID,
}); err != nil {
suite.FailNow(err.Error())
}
available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com")
suite.EqualError(err, "email domain somewhere.com is blocked")
suite.False(available)
}
func (suite *AdminTestSuite) TestCreateInstanceAccount() { func (suite *AdminTestSuite) TestCreateInstanceAccount() {
// we need to take an empty db for this... // we need to take an empty db for this...
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)

View File

@ -110,7 +110,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
updateWhere(q, where) updateWhere(q, where)
q = q.Set("? = ?", bun.Safe(key), value) q = q.Set("? = ?", bun.Ident(key), value)
_, err := q.Exec(ctx) _, err := q.Exec(ctx)
return b.conn.ProcessError(err) return b.conn.ProcessError(err)

View File

@ -159,17 +159,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
return nil, fmt.Errorf("db migration error: %s", err) return nil, fmt.Errorf("db migration error: %s", err)
} }
// Create DB structs that require ptrs to each other // Prepare caches required by more than one struct
accounts := &accountDB{conn: conn, cache: cache.NewAccountCache()} userCache := cache.NewUserCache()
status := &statusDB{conn: conn, cache: cache.NewStatusCache()} accountCache := cache.NewAccountCache()
emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
timeline := &timelineDB{conn: conn}
// Setup DB cross-referencing
accounts.status = status
status.accounts = accounts
timeline.status = status
// Prepare other caches
// Prepare mentions cache // Prepare mentions cache
// TODO: move into internal/cache // TODO: move into internal/cache
mentionCache := grufcache.New[string, *gtsmodel.Mention]() mentionCache := grufcache.New[string, *gtsmodel.Mention]()
@ -182,22 +176,30 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
notifCache.SetTTL(time.Minute*5, false) notifCache.SetTTL(time.Minute*5, false)
notifCache.Start(time.Second * 10) notifCache.Start(time.Second * 10)
// Prepare other caches // Create DB structs that require ptrs to each other
blockCache := cache.NewDomainBlockCache() accounts := &accountDB{conn: conn, cache: accountCache}
userCache := cache.NewUserCache() status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
timeline := &timelineDB{conn: conn}
// Setup DB cross-referencing
accounts.status = status
status.accounts = accounts
timeline.status = status
ps := &DBService{ ps := &DBService{
Account: accounts, Account: accounts,
Admin: &adminDB{ Admin: &adminDB{
conn: conn, conn: conn,
userCache: userCache, userCache: userCache,
accountCache: accountCache,
}, },
Basic: &basicDB{ Basic: &basicDB{
conn: conn, conn: conn,
}, },
Domain: &domainDB{ Domain: &domainDB{
conn: conn, conn: conn,
cache: blockCache, cache: cache.NewDomainBlockCache(),
}, },
Emoji: emoji, Emoji: emoji,
Instance: &instanceDB{ Instance: &instanceDB{

View File

@ -40,6 +40,7 @@ type BunDBStandardTestSuite struct {
testStatuses map[string]*gtsmodel.Status testStatuses map[string]*gtsmodel.Status
testTags map[string]*gtsmodel.Tag testTags map[string]*gtsmodel.Tag
testMentions map[string]*gtsmodel.Mention testMentions map[string]*gtsmodel.Mention
testFollows map[string]*gtsmodel.Follow
} }
func (suite *BunDBStandardTestSuite) SetupSuite() { func (suite *BunDBStandardTestSuite) SetupSuite() {
@ -52,6 +53,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
suite.testStatuses = testrig.NewTestStatuses() suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags() suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions() suite.testMentions = testrig.NewTestMentions()
suite.testFollows = testrig.NewTestFollows()
} }
func (suite *BunDBStandardTestSuite) SetupTest() { func (suite *BunDBStandardTestSuite) SetupTest() {

View File

@ -28,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
@ -95,7 +96,7 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
q := d.conn. q := d.conn.
NewSelect(). NewSelect().
Model(block). Model(block).
Where("domain = ?", domain). Where("? = ?", bun.Ident("domain_block.domain"), domain).
Limit(1) Limit(1)
// Query database for domain block // Query database for domain block
@ -126,7 +127,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
// Attempt to delete domain block // Attempt to delete domain block
if _, err := d.conn.NewDelete(). if _, err := d.conn.NewDelete().
Model((*gtsmodel.DomainBlock)(nil)). Model((*gtsmodel.DomainBlock)(nil)).
Where("domain = ?", domain). Where("? = ?", bun.Ident("domain_block.domain"), domain).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return d.conn.ProcessError(err) return d.conn.ProcessError(err)
} }

View File

@ -54,12 +54,12 @@ func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Er
q := e.conn. q := e.conn.
NewSelect(). NewSelect().
Table("emojis"). TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")).
Column("id"). Column("emoji.id").
Where("visible_in_picker = true"). Where("? = ?", bun.Ident("emoji.visible_in_picker"), true).
Where("disabled = false"). Where("? = ?", bun.Ident("emoji.disabled"), false).
Where("domain IS NULL"). Where("? IS NULL", bun.Ident("emoji.domain")).
Order("shortcode ASC") Order("emoji.shortcode ASC")
if err := q.Scan(ctx, &emojiIDs); err != nil { if err := q.Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err) return nil, e.conn.ProcessError(err)
@ -75,7 +75,7 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji,
return e.cache.GetByID(id) return e.cache.GetByID(id)
}, },
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx) return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
}, },
) )
} }
@ -87,7 +87,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
return e.cache.GetByURI(uri) return e.cache.GetByURI(uri)
}, },
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx) return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
}, },
) )
} }
@ -102,11 +102,11 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin
q := e.newEmojiQ(emoji) q := e.newEmojiQ(emoji)
if domain != "" { if domain != "" {
q = q.Where("emoji.shortcode = ?", shortcode) q = q.Where("? = ?", bun.Ident("emoji.shortcode"), shortcode)
q = q.Where("emoji.domain = ?", domain) q = q.Where("? = ?", bun.Ident("emoji.domain"), domain)
} else { } else {
q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode)) q = q.Where("? = ?", bun.Ident("emoji.shortcode"), strings.ToLower(shortcode))
q = q.Where("emoji.domain IS NULL") q = q.Where("? IS NULL", bun.Ident("emoji.domain"))
} }
return q.Scan(ctx) return q.Scan(ctx)

View File

@ -24,7 +24,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -35,15 +34,16 @@ type instanceDB struct {
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
q := i.conn. q := i.conn.
NewSelect(). NewSelect().
Model(&[]*gtsmodel.Account{}). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Where("username != ?", domain). Column("account.id").
Where("? IS NULL", bun.Ident("suspended_at")) Where("? != ?", bun.Ident("account.username"), domain).
Where("? IS NULL", bun.Ident("account.suspended_at"))
if domain == config.GetHost() { if domain == config.GetHost() || domain == config.GetAccountDomain() {
// if the domain is *this* domain, just count where the domain field is null // if the domain is *this* domain, just count where the domain field is null
q = q.WhereGroup(" AND ", whereEmptyOrNull("domain")) q = q.WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
} else { } else {
q = q.Where("domain = ?", domain) q = q.Where("? = ?", bun.Ident("account.domain"), domain)
} }
count, err := q.Count(ctx) count, err := q.Count(ctx)
@ -56,15 +56,16 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
q := i.conn. q := i.conn.
NewSelect(). NewSelect().
Model(&[]*gtsmodel.Status{}) TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status"))
if domain == config.GetHost() { if domain == config.GetHost() || domain == config.GetAccountDomain() {
// if the domain is *this* domain, just count where local is true // if the domain is *this* domain, just count where local is true
q = q.Where("local = ?", true) q = q.Where("? = ?", bun.Ident("status.local"), true)
} else { } else {
// join on the domain of the account // join on the domain of the account
q = q.Join("JOIN accounts AS account ON account.id = status.account_id"). q = q.
Where("account.domain = ?", domain) Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("account.id"), bun.Ident("status.account_id")).
Where("? = ?", bun.Ident("account.domain"), domain)
} }
count, err := q.Count(ctx) count, err := q.Count(ctx)
@ -77,14 +78,14 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
q := i.conn. q := i.conn.
NewSelect(). NewSelect().
Model(&[]*gtsmodel.Instance{}) TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance"))
if domain == config.GetHost() { if domain == config.GetHost() {
// if the domain is *this* domain, just count other instances it knows about // if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked // exclude domains that are blocked
q = q. q = q.
Where("domain != ?", domain). Where("? != ?", bun.Ident("instance.domain"), domain).
Where("? IS NULL", bun.Ident("suspended_at")) Where("? IS NULL", bun.Ident("instance.suspended_at"))
} else { } else {
// TODO: implement federated domain counting properly for remote domains // TODO: implement federated domain counting properly for remote domains
return 0, nil return 0, nil
@ -103,10 +104,10 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
q := i.conn. q := i.conn.
NewSelect(). NewSelect().
Model(&instances). Model(&instances).
Where("domain != ?", config.GetHost()) Where("? != ?", bun.Ident("instance.domain"), config.GetHost())
if !includeSuspended { if !includeSuspended {
q = q.Where("? IS NULL", bun.Ident("suspended_at")) q = q.Where("? IS NULL", bun.Ident("instance.suspended_at"))
} }
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
@ -117,17 +118,15 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
} }
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{} accounts := []*gtsmodel.Account{}
q := i.conn.NewSelect(). q := i.conn.NewSelect().
Model(&accounts). Model(&accounts).
Where("domain = ?", domain). Where("? = ?", bun.Ident("account.domain"), domain).
Order("id DESC") Order("account.id DESC")
if maxID != "" { if maxID != "" {
q = q.Where("id < ?", maxID) q = q.Where("? < ?", bun.Ident("account.id"), maxID)
} }
if limit > 0 { if limit > 0 {

View File

@ -0,0 +1,83 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
type InstanceTestSuite struct {
BunDBStandardTestSuite
}
func (suite *InstanceTestSuite) TestCountInstanceUsers() {
count, err := suite.db.CountInstanceUsers(context.Background(), config.GetHost())
suite.NoError(err)
suite.Equal(4, count)
}
func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() {
count, err := suite.db.CountInstanceUsers(context.Background(), "fossbros-anonymous.io")
suite.NoError(err)
suite.Equal(1, count)
}
func (suite *InstanceTestSuite) TestCountInstanceStatuses() {
count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost())
suite.NoError(err)
suite.Equal(16, count)
}
func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() {
count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io")
suite.NoError(err)
suite.Equal(1, count)
}
func (suite *InstanceTestSuite) TestCountInstanceDomains() {
count, err := suite.db.CountInstanceDomains(context.Background(), config.GetHost())
suite.NoError(err)
suite.Equal(2, count)
}
func (suite *InstanceTestSuite) TestGetInstancePeers() {
peers, err := suite.db.GetInstancePeers(context.Background(), false)
suite.NoError(err)
suite.Len(peers, 2)
}
func (suite *InstanceTestSuite) TestGetInstancePeersIncludeSuspended() {
peers, err := suite.db.GetInstancePeers(context.Background(), true)
suite.NoError(err)
suite.Len(peers, 2)
}
func (suite *InstanceTestSuite) TestGetInstanceAccounts() {
accounts, err := suite.db.GetInstanceAccounts(context.Background(), "fossbros-anonymous.io", "", 10)
suite.NoError(err)
suite.Len(accounts, 1)
}
func TestInstanceTestSuite(t *testing.T) {
suite.Run(t, new(InstanceTestSuite))
}

View File

@ -42,7 +42,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
attachment := &gtsmodel.MediaAttachment{} attachment := &gtsmodel.MediaAttachment{}
q := m.newMediaQ(attachment). q := m.newMediaQ(attachment).
Where("media_attachment.id = ?", id) Where("? = ?", bun.Ident("media_attachment.id"), id)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, m.conn.ProcessError(err) return nil, m.conn.ProcessError(err)
@ -56,10 +56,10 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l
q := m.conn. q := m.conn.
NewSelect(). NewSelect().
Model(&attachments). Model(&attachments).
Where("media_attachment.cached = true"). Where("? = ?", bun.Ident("media_attachment.cached"), true).
Where("media_attachment.avatar = false"). Where("? = ?", bun.Ident("media_attachment.avatar"), false).
Where("media_attachment.header = false"). Where("? = ?", bun.Ident("media_attachment.header"), false).
Where("media_attachment.created_at < ?", olderThan). Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).
WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")). WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")).
Order("media_attachment.created_at DESC") Order("media_attachment.created_at DESC")
@ -79,13 +79,13 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
q := m.newMediaQ(&attachments). q := m.newMediaQ(&attachments).
WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery { WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery {
return innerQ. return innerQ.
WhereOr("media_attachment.avatar = true"). WhereOr("? = ?", bun.Ident("media_attachment.avatar"), true).
WhereOr("media_attachment.header = true") WhereOr("? = ?", bun.Ident("media_attachment.header"), true)
}). }).
Order("media_attachment.id DESC") Order("media_attachment.id DESC")
if maxID != "" { if maxID != "" {
q = q.Where("media_attachment.id < ?", maxID) q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)
} }
if limit != 0 { if limit != 0 {
@ -103,15 +103,15 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
attachments := []*gtsmodel.MediaAttachment{} attachments := []*gtsmodel.MediaAttachment{}
q := m.newMediaQ(&attachments). q := m.newMediaQ(&attachments).
Where("media_attachment.cached = true"). Where("? = ?", bun.Ident("media_attachment.cached"), true).
Where("media_attachment.avatar = false"). Where("? = ?", bun.Ident("media_attachment.avatar"), false).
Where("media_attachment.header = false"). Where("? = ?", bun.Ident("media_attachment.header"), false).
Where("media_attachment.created_at < ?", olderThan). Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).
Where("media_attachment.remote_url IS NULL"). Where("? IS NULL", bun.Ident("media_attachment.remote_url")).
Where("media_attachment.status_id IS NULL") Where("? IS NULL", bun.Ident("media_attachment.status_id"))
if maxID != "" { if maxID != "" {
q = q.Where("media_attachment.id < ?", maxID) q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)
} }
if limit != 0 { if limit != 0 {

View File

@ -46,7 +46,7 @@ func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Ment
mention := gtsmodel.Mention{} mention := gtsmodel.Mention{}
q := m.newMentionQ(&mention). q := m.newMentionQ(&mention).
Where("mention.id = ?", id) Where("? = ?", bun.Ident("mention.id"), id)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, m.conn.ProcessError(err) return nil, m.conn.ProcessError(err)

View File

@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
) )
type notificationDB struct { type notificationDB struct {
@ -67,24 +68,24 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
q := n.conn. q := n.conn.
NewSelect(). NewSelect().
Table("notifications"). TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Column("id") Column("notification.id")
if maxID != "" { if maxID != "" {
q = q.Where("id < ?", maxID) q = q.Where("? < ?", bun.Ident("notification.id"), maxID)
} }
if sinceID != "" { if sinceID != "" {
q = q.Where("id > ?", sinceID) q = q.Where("? > ?", bun.Ident("notification.id"), sinceID)
} }
for _, excludeType := range excludeTypes { for _, excludeType := range excludeTypes {
q = q.Where("notification_type != ?", excludeType) q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType)
} }
q = q. q = q.
Where("target_account_id = ?", accountID). Where("? = ?", bun.Ident("notification.target_account_id"), accountID).
Order("id DESC") Order("notification.id DESC")
if limit != 0 { if limit != 0 {
q = q.Limit(limit) q = q.Limit(limit)
@ -116,13 +117,12 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error { func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error {
if _, err := n.conn. if _, err := n.conn.
NewDelete(). NewDelete().
Table("notifications"). TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Where("target_account_id = ?", accountID). Where("? = ?", bun.Ident("notification.target_account_id"), accountID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return n.conn.ProcessError(err) return n.conn.ProcessError(err)
} }
n.cache.Clear() n.cache.Clear()
return nil return nil
} }

View File

@ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
q := r.conn. q := r.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.Block{}). TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
ExcludeColumn("id", "created_at", "updated_at", "uri"). Column("block.id")
Limit(1)
if eitherDirection { if eitherDirection {
q = q. q = q.
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
return inner. return inner.
Where("account_id = ?", account1). Where("? = ?", bun.Ident("block.account_id"), account1).
Where("target_account_id = ?", account2) Where("? = ?", bun.Ident("block.target_account_id"), account2)
}). }).
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
return inner. return inner.
Where("account_id = ?", account2). Where("? = ?", bun.Ident("block.account_id"), account2).
Where("target_account_id = ?", account1) Where("? = ?", bun.Ident("block.target_account_id"), account1)
}) })
} else { } else {
q = q. q = q.
Where("account_id = ?", account1). Where("? = ?", bun.Ident("block.account_id"), account1).
Where("target_account_id = ?", account2) Where("? = ?", bun.Ident("block.target_account_id"), account2)
} }
return r.conn.Exists(ctx, q) return r.conn.Exists(ctx, q)
@ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
block := &gtsmodel.Block{} block := &gtsmodel.Block{}
q := r.newBlockQ(block). q := r.newBlockQ(block).
Where("block.account_id = ?", account1). Where("? = ?", bun.Ident("block.account_id"), account1).
Where("block.target_account_id = ?", account2) Where("? = ?", bun.Ident("block.target_account_id"), account2)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.conn.ProcessError(err)
@ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
if err := r.conn. if err := r.conn.
NewSelect(). NewSelect().
Model(follow). Model(follow).
Where("account_id = ?", requestingAccount). Column("follow.show_reblogs", "follow.notify").
Where("target_account_id = ?", targetAccount). Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
Limit(1). Limit(1).
Scan(ctx); err != nil { Scan(ctx); err != nil {
if err != sql.ErrNoRows { if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
// a proper error return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
} }
// no follow exists so these are all false // no follow exists so these are all false
rel.Following = false rel.Following = false
@ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
} }
// check if the target account follows the requesting account // check if the target account follows the requesting account
count, err := r.conn. followedByQ := r.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.Follow{}). TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Where("account_id = ?", targetAccount). Column("follow.id").
Where("target_account_id = ?", requestingAccount). Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
Limit(1). Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
Count(ctx) followedBy, err := r.conn.Exists(ctx, followedByQ)
if err != nil { if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
} }
rel.FollowedBy = count > 0 rel.FollowedBy = followedBy
// check if the requesting account blocks the target account
count, err = r.conn.NewSelect().
Model(&gtsmodel.Block{}).
Where("account_id = ?", requestingAccount).
Where("target_account_id = ?", targetAccount).
Limit(1).
Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
rel.Blocking = count > 0
// check if the target account blocks the requesting account
count, err = r.conn.
NewSelect().
Model(&gtsmodel.Block{}).
Where("account_id = ?", targetAccount).
Where("target_account_id = ?", requestingAccount).
Limit(1).
Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
rel.BlockedBy = count > 0
// check if there's a pending following request from requesting account to target account // check if there's a pending following request from requesting account to target account
count, err = r.conn. requestedQ := r.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.FollowRequest{}). TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Where("account_id = ?", requestingAccount). Column("follow_request.id").
Where("target_account_id = ?", targetAccount). Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
Limit(1). Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
Count(ctx) requested, err := r.conn.Exists(ctx, requestedQ)
if err != nil { if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
} }
rel.Requested = count > 0 rel.Requested = requested
// check if the requesting account is blocking the target account
blockingQ := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
Column("block.id").
Where("? = ?", bun.Ident("block.account_id"), requestingAccount).
Where("? = ?", bun.Ident("block.target_account_id"), targetAccount)
blocking, err := r.conn.Exists(ctx, blockingQ)
if err != nil {
return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
}
rel.Blocking = blocking
// check if the requesting account is blocked by the target account
blockedByQ := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
Column("block.id").
Where("? = ?", bun.Ident("block.account_id"), targetAccount).
Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount)
blockedBy, err := r.conn.Exists(ctx, blockedByQ)
if err != nil {
return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
}
rel.BlockedBy = blockedBy
return rel, nil return rel, nil
} }
@ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
q := r.conn. q := r.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.Follow{}). TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Where("account_id = ?", sourceAccount.ID). Column("follow.id").
Where("target_account_id = ?", targetAccount.ID). Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
Limit(1) Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
return r.conn.Exists(ctx, q) return r.conn.Exists(ctx, q)
} }
@ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
q := r.conn. q := r.conn.
NewSelect(). NewSelect().
Model(&gtsmodel.FollowRequest{}). TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Where("account_id = ?", sourceAccount.ID). Column("follow_request.id").
Where("target_account_id = ?", targetAccount.ID) Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
return r.conn.Exists(ctx, q) return r.conn.Exists(ctx, q)
} }
@ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
} }
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// make sure the original follow request exists var follow *gtsmodel.Follow
fr := &gtsmodel.FollowRequest{}
if err := r.conn. if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
// get original follow request
followRequest := &gtsmodel.FollowRequest{}
if err := tx.
NewSelect(). NewSelect().
Model(fr). Model(followRequest).
Where("account_id = ?", originAccountID). Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
Where("target_account_id = ?", targetAccountID). Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
Scan(ctx); err != nil { Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err) return err
} }
// create a new follow to 'replace' the request with // create a new follow to 'replace' the request with
follow := &gtsmodel.Follow{ follow = &gtsmodel.Follow{
ID: fr.ID, ID: followRequest.ID,
AccountID: originAccountID, AccountID: originAccountID,
TargetAccountID: targetAccountID, TargetAccountID: targetAccountID,
URI: fr.URI, URI: followRequest.URI,
} }
// if the follow already exists, just update the URI -- we don't need to do anything else // if the follow already exists, just update the URI -- we don't need to do anything else
if _, err := r.conn. if _, err := tx.
NewInsert(). NewInsert().
Model(follow). Model(follow).
On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI). On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err) return err
} }
// now remove the follow request // now remove the follow request
if _, err := r.conn. if _, err := tx.
NewDelete(). NewDelete().
Model(&gtsmodel.FollowRequest{}). Model(followRequest).
Where("account_id = ?", originAccountID). WherePK().
Where("target_account_id = ?", targetAccountID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return err
}
return nil
}); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.conn.ProcessError(err)
} }
// return the new follow
return follow, nil return follow, nil
} }
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) { func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
// first get the follow request out of the database followRequest := &gtsmodel.FollowRequest{}
fr := &gtsmodel.FollowRequest{}
if err := r.conn. if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
// get original follow request
if err := tx.
NewSelect(). NewSelect().
Model(fr). Model(followRequest).
Where("account_id = ?", originAccountID). Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
Where("target_account_id = ?", targetAccountID). Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
Scan(ctx); err != nil { Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err) return err
} }
// now delete it from the database by ID // now delete it from the database by ID
if _, err := r.conn. if _, err := tx.
NewDelete(). NewDelete().
Model(&gtsmodel.FollowRequest{ID: fr.ID}). Model(followRequest).
WherePK(). WherePK().
Exec(ctx); err != nil { Exec(ctx); err != nil {
return err
}
return nil
}); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.conn.ProcessError(err)
} }
// return the deleted follow request // return the deleted follow request
return fr, nil return followRequest, nil
} }
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) { func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
followRequests := []*gtsmodel.FollowRequest{} followRequests := []*gtsmodel.FollowRequest{}
q := r.newFollowQ(&followRequests). q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID). Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID).
Order("follow_request.updated_at DESC") Order("follow_request.updated_at DESC")
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.conn.ProcessError(err)
} }
return followRequests, nil return followRequests, nil
} }
@ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
follows := []*gtsmodel.Follow{} follows := []*gtsmodel.Follow{}
q := r.newFollowQ(&follows). q := r.newFollowQ(&follows).
Where("account_id = ?", accountID). Where("? = ?", bun.Ident("follow.account_id"), accountID).
Order("follow.updated_at DESC") Order("follow.updated_at DESC")
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.conn.ProcessError(err)
} }
return follows, nil return follows, nil
} }
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
return r.conn. q := r.conn.
NewSelect(). NewSelect().
Model(&[]*gtsmodel.Follow{}). TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
Where("account_id = ?", accountID).
Count(ctx) if localOnly {
q = q.
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")).
Where("? = ?", bun.Ident("follow.account_id"), accountID).
Where("? IS NULL", bun.Ident("account.domain"))
} else {
q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
}
return q.Count(ctx)
} }
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
@ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
Order("follow.updated_at DESC") Order("follow.updated_at DESC")
if localOnly { if localOnly {
q = q.ColumnExpr("follow.*"). q = q.
Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)"). Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
Where("follow.target_account_id = ?", accountID). Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
WhereGroup(" AND ", whereEmptyOrNull("a.domain")) Where("? IS NULL", bun.Ident("account.domain"))
} else { } else {
q = q.Where("target_account_id = ?", accountID) q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
} }
err := q.Scan(ctx) err := q.Scan(ctx)
@ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
} }
func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
return r.conn. q := r.conn.
NewSelect(). NewSelect().
Model(&[]*gtsmodel.Follow{}). TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
Where("target_account_id = ?", accountID).
Count(ctx) if localOnly {
q = q.
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
Where("? IS NULL", bun.Ident("account.domain"))
} else {
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
}
return q.Count(ctx)
} }

View File

@ -20,7 +20,6 @@ package bundb_test
import ( import (
"context" "context"
"errors"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -48,12 +47,14 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
suite.False(blocked) suite.False(blocked)
// have account1 block account2 // have account1 block account2
suite.db.Put(ctx, &gtsmodel.Block{ if err := suite.db.Put(ctx, &gtsmodel.Block{
ID: "01G202BCSXXJZ70BHB5KCAHH8C", ID: "01G202BCSXXJZ70BHB5KCAHH8C",
URI: "http://localhost:8080/some_block_uri_1", URI: "http://localhost:8080/some_block_uri_1",
AccountID: account1, AccountID: account1,
TargetAccountID: account2, TargetAccountID: account2,
}) }); err != nil {
suite.FailNow(err.Error())
}
// account 1 now blocks account 2 // account 1 now blocks account 2
blocked, err = suite.db.IsBlocked(ctx, account1, account2, false) blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
@ -75,62 +76,242 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
} }
func (suite *RelationshipTestSuite) TestGetBlock() { func (suite *RelationshipTestSuite) TestGetBlock() {
suite.Suite.T().Skip("TODO: implement") ctx := context.Background()
account1 := suite.testAccounts["local_account_1"].ID
account2 := suite.testAccounts["local_account_2"].ID
if err := suite.db.Put(ctx, &gtsmodel.Block{
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
URI: "http://localhost:8080/some_block_uri_1",
AccountID: account1,
TargetAccountID: account2,
}); err != nil {
suite.FailNow(err.Error())
}
block, err := suite.db.GetBlock(ctx, account1, account2)
suite.NoError(err)
suite.NotNil(block)
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
} }
func (suite *RelationshipTestSuite) TestGetRelationship() { func (suite *RelationshipTestSuite) TestGetRelationship() {
suite.Suite.T().Skip("TODO: implement") requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
relationship, err := suite.db.GetRelationship(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.NotNil(relationship)
suite.True(relationship.Following)
suite.True(relationship.ShowingReblogs)
suite.False(relationship.Notifying)
suite.True(relationship.FollowedBy)
suite.False(relationship.Blocking)
suite.False(relationship.BlockedBy)
suite.False(relationship.Muting)
suite.False(relationship.MutingNotifications)
suite.False(relationship.Requested)
suite.False(relationship.DomainBlocking)
suite.False(relationship.Endorsed)
suite.Empty(relationship.Note)
} }
func (suite *RelationshipTestSuite) TestIsFollowing() { func (suite *RelationshipTestSuite) TestIsFollowingYes() {
suite.Suite.T().Skip("TODO: implement") requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
suite.NoError(err)
suite.True(isFollowing)
}
func (suite *RelationshipTestSuite) TestIsFollowingNo() {
requestingAccount := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
suite.NoError(err)
suite.False(isFollowing)
} }
func (suite *RelationshipTestSuite) TestIsMutualFollowing() { func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
suite.Suite.T().Skip("TODO: implement") requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
suite.NoError(err)
suite.True(isMutualFollowing)
} }
func (suite *RelationshipTestSuite) AcceptFollowRequest() { func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() {
for _, account := range suite.testAccounts { requestingAccount := suite.testAccounts["local_account_1"]
_, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID") targetAccount := suite.testAccounts["local_account_2"]
if err != nil && !errors.Is(err, db.ErrNoEntries) { isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
suite.Suite.Fail("error accepting follow request: %v", err) suite.NoError(err)
} suite.True(isMutualFollowing)
}
} }
func (suite *RelationshipTestSuite) GetAccountFollowRequests() { func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() {
suite.Suite.T().Skip("TODO: implement") ctx := context.Background()
account := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
followRequest := &gtsmodel.FollowRequest{
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
AccountID: account.ID,
TargetAccountID: targetAccount.ID,
} }
func (suite *RelationshipTestSuite) GetAccountFollows() { if err := suite.db.Put(ctx, followRequest); err != nil {
suite.Suite.T().Skip("TODO: implement") suite.FailNow(err.Error())
} }
func (suite *RelationshipTestSuite) CountAccountFollows() { follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
suite.Suite.T().Skip("TODO: implement") suite.NoError(err)
suite.NotNil(follow)
suite.Equal(followRequest.URI, follow.URI)
} }
func (suite *RelationshipTestSuite) GetAccountFollowedBy() { func (suite *RelationshipTestSuite) TestAcceptFollowRequestNotExisting() {
// TODO: more comprehensive tests here ctx := context.Background()
account := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
for _, account := range suite.testAccounts { follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
var err error suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(follow)
_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
if err != nil {
suite.Suite.Fail("error checking accounts followed by: %v", err)
} }
_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) func (suite *RelationshipTestSuite) TestAcceptFollowRequestFollowAlreadyExists() {
if err != nil { ctx := context.Background()
suite.Suite.Fail("error checking localOnly accounts followed by: %v", err) account := suite.testAccounts["local_account_1"]
} targetAccount := suite.testAccounts["admin_account"]
}
// follow already exists in the db from local_account_1 -> admin_account
existingFollow := &gtsmodel.Follow{}
if err := suite.db.GetByID(ctx, suite.testFollows["local_account_1_admin_account"].ID, existingFollow); err != nil {
suite.FailNow(err.Error())
} }
func (suite *RelationshipTestSuite) CountAccountFollowedBy() { followRequest := &gtsmodel.FollowRequest{
suite.Suite.T().Skip("TODO: implement") ID: "01GEF753FWHCHRDWR0QEHBXM8W",
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
AccountID: account.ID,
TargetAccountID: targetAccount.ID,
}
if err := suite.db.Put(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
suite.NoError(err)
suite.NotNil(follow)
// uri should be equal to value of new/overlapping follow request
suite.NotEqual(followRequest.URI, existingFollow.URI)
suite.Equal(followRequest.URI, follow.URI)
}
func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
ctx := context.Background()
account := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
followRequest := &gtsmodel.FollowRequest{
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
AccountID: account.ID,
TargetAccountID: targetAccount.ID,
}
if err := suite.db.Put(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
suite.NoError(err)
suite.NotNil(rejectedFollowRequest)
}
func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() {
ctx := context.Background()
account := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(rejectedFollowRequest)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
ctx := context.Background()
account := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
followRequest := &gtsmodel.FollowRequest{
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
AccountID: account.ID,
TargetAccountID: targetAccount.ID,
}
if err := suite.db.Put(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
suite.NoError(err)
suite.Len(followRequests, 1)
}
func (suite *RelationshipTestSuite) TestGetAccountFollows() {
account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Len(follows, 2)
}
func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, true)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, false)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() {
account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
suite.NoError(err)
suite.Len(follows, 2)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowedByLocalOnly() {
account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
suite.NoError(err)
suite.Len(follows, 2)
}
func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, false)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestCountAccountFollowedByLocalOnly() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, true)
suite.NoError(err)
suite.Equal(2, followsCount)
} }
func TestRelationshipTestSuite(t *testing.T) { func TestRelationshipTestSuite(t *testing.T) {

View File

@ -67,7 +67,7 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db
return u.cache.GetByID(id) return u.cache.GetByID(id)
}, },
func(user *gtsmodel.User) error { func(user *gtsmodel.User) error {
return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx) return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx)
}, },
) )
} }
@ -79,7 +79,7 @@ func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gts
return u.cache.GetByAccountID(accountID) return u.cache.GetByAccountID(accountID)
}, },
func(user *gtsmodel.User) error { func(user *gtsmodel.User) error {
return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx) return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx)
}, },
) )
} }
@ -91,7 +91,7 @@ func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string)
return u.cache.GetByEmail(emailAddress) return u.cache.GetByEmail(emailAddress)
}, },
func(user *gtsmodel.User) error { func(user *gtsmodel.User) error {
return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx) return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx)
}, },
) )
} }
@ -103,7 +103,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationTok
return u.cache.GetByConfirmationToken(confirmationToken) return u.cache.GetByConfirmationToken(confirmationToken)
}, },
func(user *gtsmodel.User) error { func(user *gtsmodel.User) error {
return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx) return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx)
}, },
) )
} }

View File

@ -85,14 +85,8 @@ func parseWhere(w db.Where) (query string, args []interface{}) {
return return
} }
if w.CaseInsensitive {
query = "LOWER(?) != LOWER(?)"
args = []interface{}{bun.Safe(w.Key), w.Value}
return
}
query = "? != ?" query = "? != ?"
args = []interface{}{bun.Safe(w.Key), w.Value} args = []interface{}{bun.Ident(w.Key), w.Value}
return return
} }
@ -102,13 +96,7 @@ func parseWhere(w db.Where) (query string, args []interface{}) {
return return
} }
if w.CaseInsensitive {
query = "LOWER(?) = LOWER(?)"
args = []interface{}{bun.Safe(w.Key), w.Value}
return
}
query = "? = ?" query = "? = ?"
args = []interface{}{bun.Safe(w.Key), w.Value} args = []interface{}{bun.Ident(w.Key), w.Value}
return return
} }

View File

@ -24,9 +24,6 @@ type Where struct {
Key string Key string
// The value to match. // The value to match.
Value interface{} Value interface{}
// Whether the value (if a string) should be case sensitive or not.
// Defaults to false.
CaseInsensitive bool
// If set, reverse the where. // If set, reverse the where.
// `WHERE k = v` becomes `WHERE k != v`. // `WHERE k = v` becomes `WHERE k != v`.
// `WHERE k IS NULL` becomes `WHERE k IS NOT NULL` // `WHERE k IS NULL` becomes `WHERE k IS NOT NULL`

View File

@ -133,8 +133,10 @@ func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account
} }
// if we have an instance account for this instance, delete it // if we have an instance account for this instance, delete it
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, &gtsmodel.Account{}); err != nil { if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err) if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err)
}
} }
// delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines) // delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines)

View File

@ -55,7 +55,7 @@ func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
// remove the domain block reference from the instance, if we have an entry for it // remove the domain block reference from the instance, if we have an entry for it
i := &gtsmodel.Instance{} i := &gtsmodel.Instance{}
if err := p.db.GetWhere(ctx, []db.Where{ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true}, {Key: "domain", Value: domainBlock.Domain},
{Key: "domain_block_id", Value: id}, {Key: "domain_block_id", Value: id},
}, i); err == nil { }, i); err == nil {
updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"} updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"}