Merge commit '4d1c7c871bb192bc05ce1cb6fb80773126373111'
This commit is contained in:
parent
4100fc683c
commit
7594fc7456
|
@ -101,6 +101,11 @@ func (c *AccountCache) Put(account *gtsmodel.Account) {
|
|||
c.cache.Set(account.ID, copyAccount(account))
|
||||
}
|
||||
|
||||
// Invalidate removes (invalidates) one account from the cache by its ID.
|
||||
func (c *AccountCache) Invalidate(id string) {
|
||||
c.cache.Invalidate(id)
|
||||
}
|
||||
|
||||
// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects.
|
||||
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
|
||||
// this should be a relatively cheap process
|
||||
|
|
|
@ -48,6 +48,11 @@ type Account interface {
|
|||
// UpdateAccount updates one account by ID.
|
||||
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
|
||||
|
||||
// DeleteAccount deletes one account from the database by its ID.
|
||||
// DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the
|
||||
// account as suspended instead, rather than deleting from the db entirely.
|
||||
DeleteAccount(ctx context.Context, id string) Error
|
||||
|
||||
// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username.
|
||||
GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error)
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
|
|||
return a.cache.GetByID(id)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
|
|||
return a.cache.GetByURI(uri)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
|
|||
return a.cache.GetByURL(url)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
|
|||
q := a.newAccountQ(account)
|
||||
|
||||
if domain != "" {
|
||||
q = q.Where("account.username = ?", username)
|
||||
q = q.Where("account.domain = ?", domain)
|
||||
q = q.Where("? = ?", bun.Ident("account.username"), username)
|
||||
q = q.Where("? = ?", bun.Ident("account.domain"), domain)
|
||||
} else {
|
||||
q = q.Where("account.username = ?", strings.ToLower(username))
|
||||
q = q.Where("account.domain IS NULL")
|
||||
q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username))
|
||||
q = q.Where("? IS NULL", bun.Ident("account.domain"))
|
||||
}
|
||||
|
||||
return q.Scan(ctx)
|
||||
|
@ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
|
|||
return a.cache.GetByPubkeyID(id)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -170,8 +169,8 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
|||
// create links between this account and any emojis it uses
|
||||
// first clear out any old emoji links
|
||||
if _, err := tx.NewDelete().
|
||||
Model(&[]*gtsmodel.AccountToEmoji{}).
|
||||
Where("account_id = ?", account.ID).
|
||||
TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
|
||||
Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -197,6 +196,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
|||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
|
||||
if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// clear out any emoji links
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
|
||||
Where("? = ?", bun.Ident("account_to_emoji.account_id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete the account
|
||||
_, err := tx.
|
||||
NewUpdate().
|
||||
Model(>smodel.Account{ID: id}).
|
||||
WherePK().
|
||||
Exec(ctx)
|
||||
return err
|
||||
}); err != nil {
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
a.cache.Invalidate(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
|
||||
account := new(gtsmodel.Account)
|
||||
|
||||
|
@ -204,11 +229,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
|
|||
|
||||
if domain != "" {
|
||||
q = q.
|
||||
Where("account.username = ?", domain).
|
||||
Where("account.domain = ?", domain)
|
||||
Where("? = ?", bun.Ident("account.username"), domain).
|
||||
Where("? = ?", bun.Ident("account.domain"), domain)
|
||||
} else {
|
||||
q = q.
|
||||
Where("account.username = ?", config.GetHost()).
|
||||
Where("? = ?", bun.Ident("account.username"), config.GetHost()).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
}
|
||||
|
||||
|
@ -224,10 +249,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
|
|||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(status).
|
||||
Order("id DESC").
|
||||
Limit(1).
|
||||
Where("account_id = ?", accountID).
|
||||
Column("created_at")
|
||||
Column("status.created_at").
|
||||
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||
Order("status.id DESC").
|
||||
Limit(1)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return time.Time{}, a.conn.ProcessError(err)
|
||||
|
@ -240,12 +265,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
|
|||
return errors.New("one media attachment cannot be both header and avatar")
|
||||
}
|
||||
|
||||
var headerOrAVI string
|
||||
var column bun.Ident
|
||||
switch {
|
||||
case *mediaAttachment.Avatar:
|
||||
headerOrAVI = "avatar"
|
||||
column = bun.Ident("account.avatar_media_attachment_id")
|
||||
case *mediaAttachment.Header:
|
||||
headerOrAVI = "header"
|
||||
column = bun.Ident("account.header_media_attachment_id")
|
||||
default:
|
||||
return errors.New("given media attachment was neither a header nor an avatar")
|
||||
}
|
||||
|
@ -257,11 +282,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
|
|||
Exec(ctx); err != nil {
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if _, err := a.conn.
|
||||
NewUpdate().
|
||||
Model(>smodel.Account{}).
|
||||
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
|
||||
Where("id = ?", accountID).
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Set("? = ?", column, mediaAttachment.ID).
|
||||
Where("? = ?", bun.Ident("account.id"), accountID).
|
||||
Exec(ctx); err != nil {
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
@ -284,7 +310,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
|
|||
if err := a.conn.
|
||||
NewSelect().
|
||||
Model(faves).
|
||||
Where("account_id = ?", accountID).
|
||||
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
|
@ -295,8 +321,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
|
|||
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
|
||||
return a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Status{}).
|
||||
Where("account_id = ?", accountID).
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||
Count(ctx)
|
||||
}
|
||||
|
||||
|
@ -305,12 +331,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Table("statuses").
|
||||
Column("id").
|
||||
Order("id DESC")
|
||||
Model(>smodel.Status{}).
|
||||
Column("status.id").
|
||||
Order("status.id DESC")
|
||||
|
||||
if accountID != "" {
|
||||
q = q.Where("account_id = ?", accountID)
|
||||
q = q.Where("? = ?", bun.Ident("status.account_id"), accountID)
|
||||
}
|
||||
|
||||
if limit != 0 {
|
||||
|
@ -321,27 +347,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
// include self-replies (threads)
|
||||
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
|
||||
return q.
|
||||
WhereOr("in_reply_to_account_id = ?", accountID).
|
||||
WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri"))
|
||||
WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID).
|
||||
WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri"))
|
||||
}
|
||||
|
||||
q = q.WhereGroup(" AND ", whereGroup)
|
||||
}
|
||||
|
||||
if excludeReblogs {
|
||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id"))
|
||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id"))
|
||||
}
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
|
||||
}
|
||||
|
||||
if minID != "" {
|
||||
q = q.Where("id > ?", minID)
|
||||
q = q.Where("? > ?", bun.Ident("status.id"), minID)
|
||||
}
|
||||
|
||||
if pinnedOnly {
|
||||
q = q.Where("pinned = ?", true)
|
||||
q = q.Where("? = ?", bun.Ident("status.pinned"), true)
|
||||
}
|
||||
|
||||
if mediaOnly {
|
||||
|
@ -352,15 +378,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
switch a.conn.Dialect().Name() {
|
||||
case dialect.PG:
|
||||
return q.
|
||||
Where("? IS NOT NULL", bun.Ident("attachments")).
|
||||
Where("? != '{}'", bun.Ident("attachments"))
|
||||
Where("? IS NOT NULL", bun.Ident("status.attachments")).
|
||||
Where("? != '{}'", bun.Ident("status.attachments"))
|
||||
case dialect.SQLite:
|
||||
return q.
|
||||
Where("? IS NOT NULL", bun.Ident("attachments")).
|
||||
Where("? != ''", bun.Ident("attachments")).
|
||||
Where("? != 'null'", bun.Ident("attachments")).
|
||||
Where("? != '{}'", bun.Ident("attachments")).
|
||||
Where("? != '[]'", bun.Ident("attachments"))
|
||||
Where("? IS NOT NULL", bun.Ident("status.attachments")).
|
||||
Where("? != ''", bun.Ident("status.attachments")).
|
||||
Where("? != 'null'", bun.Ident("status.attachments")).
|
||||
Where("? != '{}'", bun.Ident("status.attachments")).
|
||||
Where("? != '[]'", bun.Ident("status.attachments"))
|
||||
default:
|
||||
log.Panic("db dialect was neither pg nor sqlite")
|
||||
return q
|
||||
|
@ -369,7 +395,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
}
|
||||
|
||||
if publicOnly {
|
||||
q = q.Where("visibility = ?", gtsmodel.VisibilityPublic)
|
||||
q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic)
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx, &statusIDs); err != nil {
|
||||
|
@ -384,19 +410,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
|
|||
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Table("statuses").
|
||||
Column("id").
|
||||
Where("account_id = ?", accountID).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")).
|
||||
Where("visibility = ?", gtsmodel.VisibilityPublic).
|
||||
Where("federated = ?", true)
|
||||
Model(>smodel.Status{}).
|
||||
Column("status.id").
|
||||
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
|
||||
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
|
||||
Where("? = ?", bun.Ident("status.federated"), true)
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
|
||||
}
|
||||
|
||||
q = q.Limit(limit).Order("id DESC")
|
||||
q = q.Limit(limit).Order("status.id DESC")
|
||||
|
||||
if err := q.Scan(ctx, &statusIDs); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
|
@ -411,16 +437,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
|
|||
fq := a.conn.
|
||||
NewSelect().
|
||||
Model(&blocks).
|
||||
Where("block.account_id = ?", accountID).
|
||||
Where("? = ?", bun.Ident("block.account_id"), accountID).
|
||||
Relation("TargetAccount").
|
||||
Order("block.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
fq = fq.Where("block.id < ?", maxID)
|
||||
fq = fq.Where("? < ?", bun.Ident("block.id"), maxID)
|
||||
}
|
||||
|
||||
if sinceID != "" {
|
||||
fq = fq.Where("block.id > ?", sinceID)
|
||||
fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID)
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
|
|
|
@ -22,7 +22,6 @@ import (
|
|||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/mail"
|
||||
|
@ -37,21 +36,26 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/uris"
|
||||
"github.com/uptrace/bun"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// generate RSA keys of this length
|
||||
const rsaKeyBits = 2048
|
||||
|
||||
type adminDB struct {
|
||||
conn *DBConn
|
||||
userCache *cache.UserCache
|
||||
accountCache *cache.AccountCache
|
||||
}
|
||||
|
||||
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Account{}).
|
||||
Where("username = ?", username).
|
||||
Where("domain = ?", nil)
|
||||
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Column("account.id").
|
||||
Where("? = ?", bun.Ident("account.username"), username).
|
||||
Where("? IS NULL", bun.Ident("account.domain"))
|
||||
return a.conn.NotExists(ctx, q)
|
||||
}
|
||||
|
||||
|
@ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
|
|||
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
|
||||
|
||||
// check if the email domain is blocked
|
||||
if err := a.conn.
|
||||
emailDomainBlockedQ := a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.EmailDomainBlock{}).
|
||||
Where("domain = ?", domain).
|
||||
Scan(ctx); err == nil {
|
||||
// fail because we found something
|
||||
TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")).
|
||||
Column("email_domain_block.id").
|
||||
Where("? = ?", bun.Ident("email_domain_block.domain"), domain)
|
||||
emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if emailDomainBlocked {
|
||||
return false, fmt.Errorf("email domain %s is blocked", domain)
|
||||
} else if err != sql.ErrNoRows {
|
||||
return false, a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// check if this email is associated with a user already
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.User{}).
|
||||
Where("email = ?", email).
|
||||
WhereOr("unconfirmed_email = ?", email)
|
||||
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Column("user.id").
|
||||
Where("? = ?", bun.Ident("user.email"), email).
|
||||
WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email)
|
||||
return a.conn.NotExists(ctx, q)
|
||||
}
|
||||
|
||||
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
|
||||
if err != nil {
|
||||
log.Errorf("error creating new rsa key: %s", err)
|
||||
return nil, err
|
||||
|
@ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
|||
|
||||
// if something went wrong while creating a user, we might already have an account, so check here first...
|
||||
acct := >smodel.Account{}
|
||||
q := a.conn.NewSelect().
|
||||
if err := a.conn.
|
||||
NewSelect().
|
||||
Model(acct).
|
||||
Where("username = ?", username).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
Where("? = ?", bun.Ident("account.username"), username).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("account.domain")).
|
||||
Scan(ctx); err != nil {
|
||||
err = a.conn.ProcessError(err)
|
||||
if err != db.ErrNoEntries {
|
||||
log.Errorf("error checking for existing account: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
// we just don't have an account yet so create one before we proceed
|
||||
// if we have db.ErrNoEntries, we just don't have an
|
||||
// account yet so create one before we proceed
|
||||
accountURIs := uris.GenerateURIsForAccount(username)
|
||||
accountID, err := id.NewRandomULID()
|
||||
if err != nil {
|
||||
|
@ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
|||
FeaturedCollectionURI: accountURIs.CollectionURI,
|
||||
}
|
||||
|
||||
// insert the new account!
|
||||
if _, err = a.conn.
|
||||
NewInsert().
|
||||
Model(acct).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
a.accountCache.Put(acct)
|
||||
}
|
||||
|
||||
// we either created or already had an account by now,
|
||||
// so proceed with creating a user for that account
|
||||
|
||||
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error hashing password: %s", err)
|
||||
|
@ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
|||
u.Moderator = &moderator
|
||||
}
|
||||
|
||||
// insert the user!
|
||||
if _, err = a.conn.
|
||||
NewInsert().
|
||||
Model(u).
|
||||
|
@ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
|||
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Account{}).
|
||||
Where("username = ?", username).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Column("account.id").
|
||||
Where("? = ?", bun.Ident("account.username"), username).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
|
||||
|
||||
exists, err := a.conn.Exists(ctx, q)
|
||||
if err != nil {
|
||||
|
@ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
|||
return nil
|
||||
}
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
|
||||
if err != nil {
|
||||
log.Errorf("error creating new rsa key: %s", err)
|
||||
return err
|
||||
|
@ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
|||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
a.accountCache.Put(acct)
|
||||
log.Infof("instance account %s CREATED with id %s", username, acct.ID)
|
||||
return nil
|
||||
}
|
||||
|
@ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
|
|||
// check if instance entry already exists
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Instance{}).
|
||||
Where("domain = ?", host)
|
||||
Column("instance.id").
|
||||
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
|
||||
Where("? = ?", bun.Ident("instance.domain"), host)
|
||||
|
||||
exists, err := a.conn.Exists(ctx, q)
|
||||
if err != nil {
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
gtsmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20211113114307_init"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -30,6 +31,44 @@ type AdminTestSuite struct {
|
|||
BunDBStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestIsUsernameAvailableNo() {
|
||||
available, err := suite.db.IsUsernameAvailable(context.Background(), "the_mighty_zork")
|
||||
suite.NoError(err)
|
||||
suite.False(available)
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestIsUsernameAvailableYes() {
|
||||
available, err := suite.db.IsUsernameAvailable(context.Background(), "someone_completely_different")
|
||||
suite.NoError(err)
|
||||
suite.True(available)
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestIsEmailAvailableNo() {
|
||||
available, err := suite.db.IsEmailAvailable(context.Background(), "zork@example.org")
|
||||
suite.NoError(err)
|
||||
suite.False(available)
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestIsEmailAvailableYes() {
|
||||
available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com")
|
||||
suite.NoError(err)
|
||||
suite.True(available)
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {
|
||||
if err := suite.db.Put(context.Background(), >smodel.EmailDomainBlock{
|
||||
ID: "01GEEV2R2YC5GRSN96761YJE47",
|
||||
Domain: "somewhere.com",
|
||||
CreatedByAccountID: suite.testAccounts["admin_account"].ID,
|
||||
}); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com")
|
||||
suite.EqualError(err, "email domain somewhere.com is blocked")
|
||||
suite.False(available)
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestCreateInstanceAccount() {
|
||||
// we need to take an empty db for this...
|
||||
testrig.StandardDBTeardown(suite.db)
|
||||
|
|
|
@ -110,7 +110,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
|
|||
|
||||
updateWhere(q, where)
|
||||
|
||||
q = q.Set("? = ?", bun.Safe(key), value)
|
||||
q = q.Set("? = ?", bun.Ident(key), value)
|
||||
|
||||
_, err := q.Exec(ctx)
|
||||
return b.conn.ProcessError(err)
|
||||
|
|
|
@ -159,17 +159,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
|
|||
return nil, fmt.Errorf("db migration error: %s", err)
|
||||
}
|
||||
|
||||
// Create DB structs that require ptrs to each other
|
||||
accounts := &accountDB{conn: conn, cache: cache.NewAccountCache()}
|
||||
status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
|
||||
emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
|
||||
timeline := &timelineDB{conn: conn}
|
||||
|
||||
// Setup DB cross-referencing
|
||||
accounts.status = status
|
||||
status.accounts = accounts
|
||||
timeline.status = status
|
||||
// Prepare caches required by more than one struct
|
||||
userCache := cache.NewUserCache()
|
||||
accountCache := cache.NewAccountCache()
|
||||
|
||||
// Prepare other caches
|
||||
// Prepare mentions cache
|
||||
// TODO: move into internal/cache
|
||||
mentionCache := grufcache.New[string, *gtsmodel.Mention]()
|
||||
|
@ -182,22 +176,30 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
|
|||
notifCache.SetTTL(time.Minute*5, false)
|
||||
notifCache.Start(time.Second * 10)
|
||||
|
||||
// Prepare other caches
|
||||
blockCache := cache.NewDomainBlockCache()
|
||||
userCache := cache.NewUserCache()
|
||||
// Create DB structs that require ptrs to each other
|
||||
accounts := &accountDB{conn: conn, cache: accountCache}
|
||||
status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
|
||||
emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
|
||||
timeline := &timelineDB{conn: conn}
|
||||
|
||||
// Setup DB cross-referencing
|
||||
accounts.status = status
|
||||
status.accounts = accounts
|
||||
timeline.status = status
|
||||
|
||||
ps := &DBService{
|
||||
Account: accounts,
|
||||
Admin: &adminDB{
|
||||
conn: conn,
|
||||
userCache: userCache,
|
||||
accountCache: accountCache,
|
||||
},
|
||||
Basic: &basicDB{
|
||||
conn: conn,
|
||||
},
|
||||
Domain: &domainDB{
|
||||
conn: conn,
|
||||
cache: blockCache,
|
||||
cache: cache.NewDomainBlockCache(),
|
||||
},
|
||||
Emoji: emoji,
|
||||
Instance: &instanceDB{
|
||||
|
|
|
@ -40,6 +40,7 @@ type BunDBStandardTestSuite struct {
|
|||
testStatuses map[string]*gtsmodel.Status
|
||||
testTags map[string]*gtsmodel.Tag
|
||||
testMentions map[string]*gtsmodel.Mention
|
||||
testFollows map[string]*gtsmodel.Follow
|
||||
}
|
||||
|
||||
func (suite *BunDBStandardTestSuite) SetupSuite() {
|
||||
|
@ -52,6 +53,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
|
|||
suite.testStatuses = testrig.NewTestStatuses()
|
||||
suite.testTags = testrig.NewTestTags()
|
||||
suite.testMentions = testrig.NewTestMentions()
|
||||
suite.testFollows = testrig.NewTestFollows()
|
||||
}
|
||||
|
||||
func (suite *BunDBStandardTestSuite) SetupTest() {
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/uptrace/bun"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
|
@ -95,7 +96,7 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
|
|||
q := d.conn.
|
||||
NewSelect().
|
||||
Model(block).
|
||||
Where("domain = ?", domain).
|
||||
Where("? = ?", bun.Ident("domain_block.domain"), domain).
|
||||
Limit(1)
|
||||
|
||||
// Query database for domain block
|
||||
|
@ -126,7 +127,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
|
|||
// Attempt to delete domain block
|
||||
if _, err := d.conn.NewDelete().
|
||||
Model((*gtsmodel.DomainBlock)(nil)).
|
||||
Where("domain = ?", domain).
|
||||
Where("? = ?", bun.Ident("domain_block.domain"), domain).
|
||||
Exec(ctx); err != nil {
|
||||
return d.conn.ProcessError(err)
|
||||
}
|
||||
|
|
|
@ -54,12 +54,12 @@ func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Er
|
|||
|
||||
q := e.conn.
|
||||
NewSelect().
|
||||
Table("emojis").
|
||||
Column("id").
|
||||
Where("visible_in_picker = true").
|
||||
Where("disabled = false").
|
||||
Where("domain IS NULL").
|
||||
Order("shortcode ASC")
|
||||
TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")).
|
||||
Column("emoji.id").
|
||||
Where("? = ?", bun.Ident("emoji.visible_in_picker"), true).
|
||||
Where("? = ?", bun.Ident("emoji.disabled"), false).
|
||||
Where("? IS NULL", bun.Ident("emoji.domain")).
|
||||
Order("emoji.shortcode ASC")
|
||||
|
||||
if err := q.Scan(ctx, &emojiIDs); err != nil {
|
||||
return nil, e.conn.ProcessError(err)
|
||||
|
@ -75,7 +75,7 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji,
|
|||
return e.cache.GetByID(id)
|
||||
},
|
||||
func(emoji *gtsmodel.Emoji) error {
|
||||
return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx)
|
||||
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
|
|||
return e.cache.GetByURI(uri)
|
||||
},
|
||||
func(emoji *gtsmodel.Emoji) error {
|
||||
return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx)
|
||||
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -102,11 +102,11 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin
|
|||
q := e.newEmojiQ(emoji)
|
||||
|
||||
if domain != "" {
|
||||
q = q.Where("emoji.shortcode = ?", shortcode)
|
||||
q = q.Where("emoji.domain = ?", domain)
|
||||
q = q.Where("? = ?", bun.Ident("emoji.shortcode"), shortcode)
|
||||
q = q.Where("? = ?", bun.Ident("emoji.domain"), domain)
|
||||
} else {
|
||||
q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode))
|
||||
q = q.Where("emoji.domain IS NULL")
|
||||
q = q.Where("? = ?", bun.Ident("emoji.shortcode"), strings.ToLower(shortcode))
|
||||
q = q.Where("? IS NULL", bun.Ident("emoji.domain"))
|
||||
}
|
||||
|
||||
return q.Scan(ctx)
|
||||
|
|
|
@ -24,7 +24,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
|
@ -35,15 +34,16 @@ type instanceDB struct {
|
|||
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
|
||||
q := i.conn.
|
||||
NewSelect().
|
||||
Model(&[]*gtsmodel.Account{}).
|
||||
Where("username != ?", domain).
|
||||
Where("? IS NULL", bun.Ident("suspended_at"))
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Column("account.id").
|
||||
Where("? != ?", bun.Ident("account.username"), domain).
|
||||
Where("? IS NULL", bun.Ident("account.suspended_at"))
|
||||
|
||||
if domain == config.GetHost() {
|
||||
if domain == config.GetHost() || domain == config.GetAccountDomain() {
|
||||
// if the domain is *this* domain, just count where the domain field is null
|
||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
|
||||
} else {
|
||||
q = q.Where("domain = ?", domain)
|
||||
q = q.Where("? = ?", bun.Ident("account.domain"), domain)
|
||||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
|
@ -56,15 +56,16 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
|
|||
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
|
||||
q := i.conn.
|
||||
NewSelect().
|
||||
Model(&[]*gtsmodel.Status{})
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status"))
|
||||
|
||||
if domain == config.GetHost() {
|
||||
if domain == config.GetHost() || domain == config.GetAccountDomain() {
|
||||
// if the domain is *this* domain, just count where local is true
|
||||
q = q.Where("local = ?", true)
|
||||
q = q.Where("? = ?", bun.Ident("status.local"), true)
|
||||
} else {
|
||||
// join on the domain of the account
|
||||
q = q.Join("JOIN accounts AS account ON account.id = status.account_id").
|
||||
Where("account.domain = ?", domain)
|
||||
q = q.
|
||||
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("account.id"), bun.Ident("status.account_id")).
|
||||
Where("? = ?", bun.Ident("account.domain"), domain)
|
||||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
|
@ -77,14 +78,14 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
|
|||
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
|
||||
q := i.conn.
|
||||
NewSelect().
|
||||
Model(&[]*gtsmodel.Instance{})
|
||||
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance"))
|
||||
|
||||
if domain == config.GetHost() {
|
||||
// if the domain is *this* domain, just count other instances it knows about
|
||||
// exclude domains that are blocked
|
||||
q = q.
|
||||
Where("domain != ?", domain).
|
||||
Where("? IS NULL", bun.Ident("suspended_at"))
|
||||
Where("? != ?", bun.Ident("instance.domain"), domain).
|
||||
Where("? IS NULL", bun.Ident("instance.suspended_at"))
|
||||
} else {
|
||||
// TODO: implement federated domain counting properly for remote domains
|
||||
return 0, nil
|
||||
|
@ -103,10 +104,10 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
|
|||
q := i.conn.
|
||||
NewSelect().
|
||||
Model(&instances).
|
||||
Where("domain != ?", config.GetHost())
|
||||
Where("? != ?", bun.Ident("instance.domain"), config.GetHost())
|
||||
|
||||
if !includeSuspended {
|
||||
q = q.Where("? IS NULL", bun.Ident("suspended_at"))
|
||||
q = q.Where("? IS NULL", bun.Ident("instance.suspended_at"))
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
|
@ -117,17 +118,15 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
|
|||
}
|
||||
|
||||
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
|
||||
log.Debug("GetAccountsForInstance")
|
||||
|
||||
accounts := []*gtsmodel.Account{}
|
||||
|
||||
q := i.conn.NewSelect().
|
||||
Model(&accounts).
|
||||
Where("domain = ?", domain).
|
||||
Order("id DESC")
|
||||
Where("? = ?", bun.Ident("account.domain"), domain).
|
||||
Order("account.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("account.id"), maxID)
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package bundb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
||||
type InstanceTestSuite struct {
|
||||
BunDBStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceUsers() {
|
||||
count, err := suite.db.CountInstanceUsers(context.Background(), config.GetHost())
|
||||
suite.NoError(err)
|
||||
suite.Equal(4, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() {
|
||||
count, err := suite.db.CountInstanceUsers(context.Background(), "fossbros-anonymous.io")
|
||||
suite.NoError(err)
|
||||
suite.Equal(1, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceStatuses() {
|
||||
count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost())
|
||||
suite.NoError(err)
|
||||
suite.Equal(16, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() {
|
||||
count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io")
|
||||
suite.NoError(err)
|
||||
suite.Equal(1, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceDomains() {
|
||||
count, err := suite.db.CountInstanceDomains(context.Background(), config.GetHost())
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestGetInstancePeers() {
|
||||
peers, err := suite.db.GetInstancePeers(context.Background(), false)
|
||||
suite.NoError(err)
|
||||
suite.Len(peers, 2)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestGetInstancePeersIncludeSuspended() {
|
||||
peers, err := suite.db.GetInstancePeers(context.Background(), true)
|
||||
suite.NoError(err)
|
||||
suite.Len(peers, 2)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestGetInstanceAccounts() {
|
||||
accounts, err := suite.db.GetInstanceAccounts(context.Background(), "fossbros-anonymous.io", "", 10)
|
||||
suite.NoError(err)
|
||||
suite.Len(accounts, 1)
|
||||
}
|
||||
|
||||
func TestInstanceTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(InstanceTestSuite))
|
||||
}
|
|
@ -42,7 +42,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
|
|||
attachment := >smodel.MediaAttachment{}
|
||||
|
||||
q := m.newMediaQ(attachment).
|
||||
Where("media_attachment.id = ?", id)
|
||||
Where("? = ?", bun.Ident("media_attachment.id"), id)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, m.conn.ProcessError(err)
|
||||
|
@ -56,10 +56,10 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l
|
|||
q := m.conn.
|
||||
NewSelect().
|
||||
Model(&attachments).
|
||||
Where("media_attachment.cached = true").
|
||||
Where("media_attachment.avatar = false").
|
||||
Where("media_attachment.header = false").
|
||||
Where("media_attachment.created_at < ?", olderThan).
|
||||
Where("? = ?", bun.Ident("media_attachment.cached"), true).
|
||||
Where("? = ?", bun.Ident("media_attachment.avatar"), false).
|
||||
Where("? = ?", bun.Ident("media_attachment.header"), false).
|
||||
Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).
|
||||
WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")).
|
||||
Order("media_attachment.created_at DESC")
|
||||
|
||||
|
@ -79,13 +79,13 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
|
|||
q := m.newMediaQ(&attachments).
|
||||
WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery {
|
||||
return innerQ.
|
||||
WhereOr("media_attachment.avatar = true").
|
||||
WhereOr("media_attachment.header = true")
|
||||
WhereOr("? = ?", bun.Ident("media_attachment.avatar"), true).
|
||||
WhereOr("? = ?", bun.Ident("media_attachment.header"), true)
|
||||
}).
|
||||
Order("media_attachment.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("media_attachment.id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)
|
||||
}
|
||||
|
||||
if limit != 0 {
|
||||
|
@ -103,15 +103,15 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
|
|||
attachments := []*gtsmodel.MediaAttachment{}
|
||||
|
||||
q := m.newMediaQ(&attachments).
|
||||
Where("media_attachment.cached = true").
|
||||
Where("media_attachment.avatar = false").
|
||||
Where("media_attachment.header = false").
|
||||
Where("media_attachment.created_at < ?", olderThan).
|
||||
Where("media_attachment.remote_url IS NULL").
|
||||
Where("media_attachment.status_id IS NULL")
|
||||
Where("? = ?", bun.Ident("media_attachment.cached"), true).
|
||||
Where("? = ?", bun.Ident("media_attachment.avatar"), false).
|
||||
Where("? = ?", bun.Ident("media_attachment.header"), false).
|
||||
Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).
|
||||
Where("? IS NULL", bun.Ident("media_attachment.remote_url")).
|
||||
Where("? IS NULL", bun.Ident("media_attachment.status_id"))
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("media_attachment.id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)
|
||||
}
|
||||
|
||||
if limit != 0 {
|
||||
|
|
|
@ -46,7 +46,7 @@ func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Ment
|
|||
mention := gtsmodel.Mention{}
|
||||
|
||||
q := m.newMentionQ(&mention).
|
||||
Where("mention.id = ?", id)
|
||||
Where("? = ?", bun.Ident("mention.id"), id)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, m.conn.ProcessError(err)
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type notificationDB struct {
|
||||
|
@ -67,24 +68,24 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
|||
|
||||
q := n.conn.
|
||||
NewSelect().
|
||||
Table("notifications").
|
||||
Column("id")
|
||||
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
|
||||
Column("notification.id")
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("notification.id"), maxID)
|
||||
}
|
||||
|
||||
if sinceID != "" {
|
||||
q = q.Where("id > ?", sinceID)
|
||||
q = q.Where("? > ?", bun.Ident("notification.id"), sinceID)
|
||||
}
|
||||
|
||||
for _, excludeType := range excludeTypes {
|
||||
q = q.Where("notification_type != ?", excludeType)
|
||||
q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType)
|
||||
}
|
||||
|
||||
q = q.
|
||||
Where("target_account_id = ?", accountID).
|
||||
Order("id DESC")
|
||||
Where("? = ?", bun.Ident("notification.target_account_id"), accountID).
|
||||
Order("notification.id DESC")
|
||||
|
||||
if limit != 0 {
|
||||
q = q.Limit(limit)
|
||||
|
@ -116,13 +117,12 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
|||
func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error {
|
||||
if _, err := n.conn.
|
||||
NewDelete().
|
||||
Table("notifications").
|
||||
Where("target_account_id = ?", accountID).
|
||||
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
|
||||
Where("? = ?", bun.Ident("notification.target_account_id"), accountID).
|
||||
Exec(ctx); err != nil {
|
||||
return n.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
n.cache.Clear()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
|
|||
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Block{}).
|
||||
ExcludeColumn("id", "created_at", "updated_at", "uri").
|
||||
Limit(1)
|
||||
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||
Column("block.id")
|
||||
|
||||
if eitherDirection {
|
||||
q = q.
|
||||
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
|
||||
return inner.
|
||||
Where("account_id = ?", account1).
|
||||
Where("target_account_id = ?", account2)
|
||||
Where("? = ?", bun.Ident("block.account_id"), account1).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), account2)
|
||||
}).
|
||||
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
|
||||
return inner.
|
||||
Where("account_id = ?", account2).
|
||||
Where("target_account_id = ?", account1)
|
||||
Where("? = ?", bun.Ident("block.account_id"), account2).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), account1)
|
||||
})
|
||||
} else {
|
||||
q = q.
|
||||
Where("account_id = ?", account1).
|
||||
Where("target_account_id = ?", account2)
|
||||
Where("? = ?", bun.Ident("block.account_id"), account1).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), account2)
|
||||
}
|
||||
|
||||
return r.conn.Exists(ctx, q)
|
||||
|
@ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
|
|||
block := >smodel.Block{}
|
||||
|
||||
q := r.newBlockQ(block).
|
||||
Where("block.account_id = ?", account1).
|
||||
Where("block.target_account_id = ?", account2)
|
||||
Where("? = ?", bun.Ident("block.account_id"), account1).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), account2)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
|
@ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
|
|||
if err := r.conn.
|
||||
NewSelect().
|
||||
Model(follow).
|
||||
Where("account_id = ?", requestingAccount).
|
||||
Where("target_account_id = ?", targetAccount).
|
||||
Column("follow.show_reblogs", "follow.notify").
|
||||
Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
|
||||
Limit(1).
|
||||
Scan(ctx); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
// a proper error
|
||||
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
|
||||
if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
|
||||
return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
|
||||
}
|
||||
// no follow exists so these are all false
|
||||
rel.Following = false
|
||||
|
@ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
|
|||
}
|
||||
|
||||
// check if the target account follows the requesting account
|
||||
count, err := r.conn.
|
||||
followedByQ := r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Follow{}).
|
||||
Where("account_id = ?", targetAccount).
|
||||
Where("target_account_id = ?", requestingAccount).
|
||||
Limit(1).
|
||||
Count(ctx)
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||
Column("follow.id").
|
||||
Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
|
||||
followedBy, err := r.conn.Exists(ctx, followedByQ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
|
||||
return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
|
||||
}
|
||||
rel.FollowedBy = count > 0
|
||||
|
||||
// check if the requesting account blocks the target account
|
||||
count, err = r.conn.NewSelect().
|
||||
Model(>smodel.Block{}).
|
||||
Where("account_id = ?", requestingAccount).
|
||||
Where("target_account_id = ?", targetAccount).
|
||||
Limit(1).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
|
||||
}
|
||||
rel.Blocking = count > 0
|
||||
|
||||
// check if the target account blocks the requesting account
|
||||
count, err = r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Block{}).
|
||||
Where("account_id = ?", targetAccount).
|
||||
Where("target_account_id = ?", requestingAccount).
|
||||
Limit(1).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
|
||||
}
|
||||
rel.BlockedBy = count > 0
|
||||
rel.FollowedBy = followedBy
|
||||
|
||||
// check if there's a pending following request from requesting account to target account
|
||||
count, err = r.conn.
|
||||
requestedQ := r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.FollowRequest{}).
|
||||
Where("account_id = ?", requestingAccount).
|
||||
Where("target_account_id = ?", targetAccount).
|
||||
Limit(1).
|
||||
Count(ctx)
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Column("follow_request.id").
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
|
||||
requested, err := r.conn.Exists(ctx, requestedQ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
|
||||
return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
|
||||
}
|
||||
rel.Requested = count > 0
|
||||
rel.Requested = requested
|
||||
|
||||
// check if the requesting account is blocking the target account
|
||||
blockingQ := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||
Column("block.id").
|
||||
Where("? = ?", bun.Ident("block.account_id"), requestingAccount).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), targetAccount)
|
||||
blocking, err := r.conn.Exists(ctx, blockingQ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
|
||||
}
|
||||
rel.Blocking = blocking
|
||||
|
||||
// check if the requesting account is blocked by the target account
|
||||
blockedByQ := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||
Column("block.id").
|
||||
Where("? = ?", bun.Ident("block.account_id"), targetAccount).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount)
|
||||
blockedBy, err := r.conn.Exists(ctx, blockedByQ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
|
||||
}
|
||||
rel.BlockedBy = blockedBy
|
||||
|
||||
return rel, nil
|
||||
}
|
||||
|
@ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
|
|||
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Follow{}).
|
||||
Where("account_id = ?", sourceAccount.ID).
|
||||
Where("target_account_id = ?", targetAccount.ID).
|
||||
Limit(1)
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||
Column("follow.id").
|
||||
Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
|
||||
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
@ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
|
|||
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.FollowRequest{}).
|
||||
Where("account_id = ?", sourceAccount.ID).
|
||||
Where("target_account_id = ?", targetAccount.ID)
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Column("follow_request.id").
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
|
||||
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
@ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
|
|||
}
|
||||
|
||||
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
|
||||
// make sure the original follow request exists
|
||||
fr := >smodel.FollowRequest{}
|
||||
if err := r.conn.
|
||||
var follow *gtsmodel.Follow
|
||||
|
||||
if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// get original follow request
|
||||
followRequest := >smodel.FollowRequest{}
|
||||
if err := tx.
|
||||
NewSelect().
|
||||
Model(fr).
|
||||
Where("account_id = ?", originAccountID).
|
||||
Where("target_account_id = ?", targetAccountID).
|
||||
Model(followRequest).
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// create a new follow to 'replace' the request with
|
||||
follow := >smodel.Follow{
|
||||
ID: fr.ID,
|
||||
follow = >smodel.Follow{
|
||||
ID: followRequest.ID,
|
||||
AccountID: originAccountID,
|
||||
TargetAccountID: targetAccountID,
|
||||
URI: fr.URI,
|
||||
URI: followRequest.URI,
|
||||
}
|
||||
|
||||
// if the follow already exists, just update the URI -- we don't need to do anything else
|
||||
if _, err := r.conn.
|
||||
if _, err := tx.
|
||||
NewInsert().
|
||||
Model(follow).
|
||||
On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI).
|
||||
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// now remove the follow request
|
||||
if _, err := r.conn.
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
Model(>smodel.FollowRequest{}).
|
||||
Where("account_id = ?", originAccountID).
|
||||
Where("target_account_id = ?", targetAccountID).
|
||||
Model(followRequest).
|
||||
WherePK().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// return the new follow
|
||||
return follow, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
|
||||
// first get the follow request out of the database
|
||||
fr := >smodel.FollowRequest{}
|
||||
if err := r.conn.
|
||||
followRequest := >smodel.FollowRequest{}
|
||||
|
||||
if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// get original follow request
|
||||
if err := tx.
|
||||
NewSelect().
|
||||
Model(fr).
|
||||
Where("account_id = ?", originAccountID).
|
||||
Where("target_account_id = ?", targetAccountID).
|
||||
Model(followRequest).
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// now delete it from the database by ID
|
||||
if _, err := r.conn.
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
Model(>smodel.FollowRequest{ID: fr.ID}).
|
||||
Model(followRequest).
|
||||
WherePK().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// return the deleted follow request
|
||||
return fr, nil
|
||||
return followRequest, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
|
||||
followRequests := []*gtsmodel.FollowRequest{}
|
||||
|
||||
q := r.newFollowQ(&followRequests).
|
||||
Where("target_account_id = ?", accountID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID).
|
||||
Order("follow_request.updated_at DESC")
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return followRequests, nil
|
||||
}
|
||||
|
||||
|
@ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
|
|||
follows := []*gtsmodel.Follow{}
|
||||
|
||||
q := r.newFollowQ(&follows).
|
||||
Where("account_id = ?", accountID).
|
||||
Where("? = ?", bun.Ident("follow.account_id"), accountID).
|
||||
Order("follow.updated_at DESC")
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return follows, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
|
||||
return r.conn.
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
Model(&[]*gtsmodel.Follow{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Count(ctx)
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
|
||||
|
||||
if localOnly {
|
||||
q = q.
|
||||
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")).
|
||||
Where("? = ?", bun.Ident("follow.account_id"), accountID).
|
||||
Where("? IS NULL", bun.Ident("account.domain"))
|
||||
} else {
|
||||
q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
|
||||
}
|
||||
|
||||
return q.Count(ctx)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
|
||||
|
@ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
|
|||
Order("follow.updated_at DESC")
|
||||
|
||||
if localOnly {
|
||||
q = q.ColumnExpr("follow.*").
|
||||
Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)").
|
||||
Where("follow.target_account_id = ?", accountID).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("a.domain"))
|
||||
q = q.
|
||||
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
|
||||
Where("? IS NULL", bun.Ident("account.domain"))
|
||||
} else {
|
||||
q = q.Where("target_account_id = ?", accountID)
|
||||
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
|
||||
}
|
||||
|
||||
err := q.Scan(ctx)
|
||||
|
@ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
|
|||
}
|
||||
|
||||
func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
|
||||
return r.conn.
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
Model(&[]*gtsmodel.Follow{}).
|
||||
Where("target_account_id = ?", accountID).
|
||||
Count(ctx)
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
|
||||
|
||||
if localOnly {
|
||||
q = q.
|
||||
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
|
||||
Where("? IS NULL", bun.Ident("account.domain"))
|
||||
} else {
|
||||
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
|
||||
}
|
||||
|
||||
return q.Count(ctx)
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ package bundb_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
@ -48,12 +47,14 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
|
|||
suite.False(blocked)
|
||||
|
||||
// have account1 block account2
|
||||
suite.db.Put(ctx, >smodel.Block{
|
||||
if err := suite.db.Put(ctx, >smodel.Block{
|
||||
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
|
||||
URI: "http://localhost:8080/some_block_uri_1",
|
||||
AccountID: account1,
|
||||
TargetAccountID: account2,
|
||||
})
|
||||
}); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
// account 1 now blocks account 2
|
||||
blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
|
||||
|
@ -75,62 +76,242 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
|
|||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetBlock() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
ctx := context.Background()
|
||||
|
||||
account1 := suite.testAccounts["local_account_1"].ID
|
||||
account2 := suite.testAccounts["local_account_2"].ID
|
||||
|
||||
if err := suite.db.Put(ctx, >smodel.Block{
|
||||
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
|
||||
URI: "http://localhost:8080/some_block_uri_1",
|
||||
AccountID: account1,
|
||||
TargetAccountID: account2,
|
||||
}); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
block, err := suite.db.GetBlock(ctx, account1, account2)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(block)
|
||||
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetRelationship() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
|
||||
relationship, err := suite.db.GetRelationship(context.Background(), requestingAccount.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(relationship)
|
||||
|
||||
suite.True(relationship.Following)
|
||||
suite.True(relationship.ShowingReblogs)
|
||||
suite.False(relationship.Notifying)
|
||||
suite.True(relationship.FollowedBy)
|
||||
suite.False(relationship.Blocking)
|
||||
suite.False(relationship.BlockedBy)
|
||||
suite.False(relationship.Muting)
|
||||
suite.False(relationship.MutingNotifications)
|
||||
suite.False(relationship.Requested)
|
||||
suite.False(relationship.DomainBlocking)
|
||||
suite.False(relationship.Endorsed)
|
||||
suite.Empty(relationship.Note)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestIsFollowing() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
func (suite *RelationshipTestSuite) TestIsFollowingYes() {
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
suite.NoError(err)
|
||||
suite.True(isFollowing)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestIsFollowingNo() {
|
||||
requestingAccount := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
suite.NoError(err)
|
||||
suite.False(isFollowing)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
suite.NoError(err)
|
||||
suite.True(isMutualFollowing)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) AcceptFollowRequest() {
|
||||
for _, account := range suite.testAccounts {
|
||||
_, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID")
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
suite.Suite.Fail("error accepting follow request: %v", err)
|
||||
}
|
||||
}
|
||||
func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() {
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
suite.NoError(err)
|
||||
suite.True(isMutualFollowing)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) GetAccountFollowRequests() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
func (suite *RelationshipTestSuite) GetAccountFollows() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) CountAccountFollows() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) GetAccountFollowedBy() {
|
||||
// TODO: more comprehensive tests here
|
||||
|
||||
for _, account := range suite.testAccounts {
|
||||
var err error
|
||||
|
||||
_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
|
||||
if err != nil {
|
||||
suite.Suite.Fail("error checking accounts followed by: %v", err)
|
||||
followRequest := >smodel.FollowRequest{
|
||||
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||
AccountID: account.ID,
|
||||
TargetAccountID: targetAccount.ID,
|
||||
}
|
||||
|
||||
_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
|
||||
if err != nil {
|
||||
suite.Suite.Fail("error checking localOnly accounts followed by: %v", err)
|
||||
}
|
||||
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(follow)
|
||||
suite.Equal(followRequest.URI, follow.URI)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) CountAccountFollowedBy() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
func (suite *RelationshipTestSuite) TestAcceptFollowRequestNotExisting() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
suite.Nil(follow)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestAcceptFollowRequestFollowAlreadyExists() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
|
||||
// follow already exists in the db from local_account_1 -> admin_account
|
||||
existingFollow := >smodel.Follow{}
|
||||
if err := suite.db.GetByID(ctx, suite.testFollows["local_account_1_admin_account"].ID, existingFollow); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
followRequest := >smodel.FollowRequest{
|
||||
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||
AccountID: account.ID,
|
||||
TargetAccountID: targetAccount.ID,
|
||||
}
|
||||
|
||||
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(follow)
|
||||
|
||||
// uri should be equal to value of new/overlapping follow request
|
||||
suite.NotEqual(followRequest.URI, existingFollow.URI)
|
||||
suite.Equal(followRequest.URI, follow.URI)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
followRequest := >smodel.FollowRequest{
|
||||
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||
AccountID: account.ID,
|
||||
TargetAccountID: targetAccount.ID,
|
||||
}
|
||||
|
||||
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(rejectedFollowRequest)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
suite.Nil(rejectedFollowRequest)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
followRequest := >smodel.FollowRequest{
|
||||
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||
AccountID: account.ID,
|
||||
TargetAccountID: targetAccount.ID,
|
||||
}
|
||||
|
||||
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.Len(followRequests, 1)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollows() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
|
||||
suite.NoError(err)
|
||||
suite.Len(follows, 2)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, true)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, false)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
|
||||
suite.NoError(err)
|
||||
suite.Len(follows, 2)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollowedByLocalOnly() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
|
||||
suite.NoError(err)
|
||||
suite.Len(follows, 2)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, false)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollowedByLocalOnly() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, true)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func TestRelationshipTestSuite(t *testing.T) {
|
||||
|
|
|
@ -67,7 +67,7 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db
|
|||
return u.cache.GetByID(id)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx)
|
||||
return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -79,7 +79,7 @@ func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gts
|
|||
return u.cache.GetByAccountID(accountID)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx)
|
||||
return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -91,7 +91,7 @@ func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string)
|
|||
return u.cache.GetByEmail(emailAddress)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx)
|
||||
return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -103,7 +103,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationTok
|
|||
return u.cache.GetByConfirmationToken(confirmationToken)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx)
|
||||
return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
|
|
@ -85,14 +85,8 @@ func parseWhere(w db.Where) (query string, args []interface{}) {
|
|||
return
|
||||
}
|
||||
|
||||
if w.CaseInsensitive {
|
||||
query = "LOWER(?) != LOWER(?)"
|
||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
||||
return
|
||||
}
|
||||
|
||||
query = "? != ?"
|
||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
||||
args = []interface{}{bun.Ident(w.Key), w.Value}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -102,13 +96,7 @@ func parseWhere(w db.Where) (query string, args []interface{}) {
|
|||
return
|
||||
}
|
||||
|
||||
if w.CaseInsensitive {
|
||||
query = "LOWER(?) = LOWER(?)"
|
||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
||||
return
|
||||
}
|
||||
|
||||
query = "? = ?"
|
||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
||||
args = []interface{}{bun.Ident(w.Key), w.Value}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -24,9 +24,6 @@ type Where struct {
|
|||
Key string
|
||||
// The value to match.
|
||||
Value interface{}
|
||||
// Whether the value (if a string) should be case sensitive or not.
|
||||
// Defaults to false.
|
||||
CaseInsensitive bool
|
||||
// If set, reverse the where.
|
||||
// `WHERE k = v` becomes `WHERE k != v`.
|
||||
// `WHERE k IS NULL` becomes `WHERE k IS NOT NULL`
|
||||
|
|
|
@ -133,8 +133,10 @@ func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account
|
|||
}
|
||||
|
||||
// if we have an instance account for this instance, delete it
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil {
|
||||
l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err)
|
||||
if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
|
||||
if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil {
|
||||
l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines)
|
||||
|
|
|
@ -55,7 +55,7 @@ func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
|
|||
// remove the domain block reference from the instance, if we have an entry for it
|
||||
i := >smodel.Instance{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{
|
||||
{Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true},
|
||||
{Key: "domain", Value: domainBlock.Domain},
|
||||
{Key: "domain_block_id", Value: id},
|
||||
}, i); err == nil {
|
||||
updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"}
|
||||
|
|
Loading…
Reference in New Issue