Merge commit '4d1c7c871bb192bc05ce1cb6fb80773126373111'
This commit is contained in:
parent
4100fc683c
commit
7594fc7456
|
@ -101,6 +101,11 @@ func (c *AccountCache) Put(account *gtsmodel.Account) {
|
||||||
c.cache.Set(account.ID, copyAccount(account))
|
c.cache.Set(account.ID, copyAccount(account))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Invalidate removes (invalidates) one account from the cache by its ID.
|
||||||
|
func (c *AccountCache) Invalidate(id string) {
|
||||||
|
c.cache.Invalidate(id)
|
||||||
|
}
|
||||||
|
|
||||||
// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects.
|
// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects.
|
||||||
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
|
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
|
||||||
// this should be a relatively cheap process
|
// this should be a relatively cheap process
|
||||||
|
|
|
@ -48,6 +48,11 @@ type Account interface {
|
||||||
// UpdateAccount updates one account by ID.
|
// UpdateAccount updates one account by ID.
|
||||||
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
|
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
|
||||||
|
|
||||||
|
// DeleteAccount deletes one account from the database by its ID.
|
||||||
|
// DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the
|
||||||
|
// account as suspended instead, rather than deleting from the db entirely.
|
||||||
|
DeleteAccount(ctx context.Context, id string) Error
|
||||||
|
|
||||||
// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username.
|
// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username.
|
||||||
GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error)
|
GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error)
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,6 @@ package bundb
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
|
||||||
return a.cache.GetByID(id)
|
return a.cache.GetByID(id)
|
||||||
},
|
},
|
||||||
func(account *gtsmodel.Account) error {
|
func(account *gtsmodel.Account) error {
|
||||||
return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx)
|
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
|
||||||
return a.cache.GetByURI(uri)
|
return a.cache.GetByURI(uri)
|
||||||
},
|
},
|
||||||
func(account *gtsmodel.Account) error {
|
func(account *gtsmodel.Account) error {
|
||||||
return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx)
|
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
|
||||||
return a.cache.GetByURL(url)
|
return a.cache.GetByURL(url)
|
||||||
},
|
},
|
||||||
func(account *gtsmodel.Account) error {
|
func(account *gtsmodel.Account) error {
|
||||||
return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx)
|
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
|
||||||
q := a.newAccountQ(account)
|
q := a.newAccountQ(account)
|
||||||
|
|
||||||
if domain != "" {
|
if domain != "" {
|
||||||
q = q.Where("account.username = ?", username)
|
q = q.Where("? = ?", bun.Ident("account.username"), username)
|
||||||
q = q.Where("account.domain = ?", domain)
|
q = q.Where("? = ?", bun.Ident("account.domain"), domain)
|
||||||
} else {
|
} else {
|
||||||
q = q.Where("account.username = ?", strings.ToLower(username))
|
q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username))
|
||||||
q = q.Where("account.domain IS NULL")
|
q = q.Where("? IS NULL", bun.Ident("account.domain"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return q.Scan(ctx)
|
return q.Scan(ctx)
|
||||||
|
@ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
|
||||||
return a.cache.GetByPubkeyID(id)
|
return a.cache.GetByPubkeyID(id)
|
||||||
},
|
},
|
||||||
func(account *gtsmodel.Account) error {
|
func(account *gtsmodel.Account) error {
|
||||||
return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx)
|
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -170,8 +169,8 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
||||||
// create links between this account and any emojis it uses
|
// create links between this account and any emojis it uses
|
||||||
// first clear out any old emoji links
|
// first clear out any old emoji links
|
||||||
if _, err := tx.NewDelete().
|
if _, err := tx.NewDelete().
|
||||||
Model(&[]*gtsmodel.AccountToEmoji{}).
|
TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
|
||||||
Where("account_id = ?", account.ID).
|
Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -197,6 +196,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
|
||||||
|
if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||||
|
// clear out any emoji links
|
||||||
|
if _, err := tx.
|
||||||
|
NewDelete().
|
||||||
|
TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
|
||||||
|
Where("? = ?", bun.Ident("account_to_emoji.account_id"), id).
|
||||||
|
Exec(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete the account
|
||||||
|
_, err := tx.
|
||||||
|
NewUpdate().
|
||||||
|
Model(>smodel.Account{ID: id}).
|
||||||
|
WherePK().
|
||||||
|
Exec(ctx)
|
||||||
|
return err
|
||||||
|
}); err != nil {
|
||||||
|
return a.conn.ProcessError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
a.cache.Invalidate(id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
|
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
|
||||||
account := new(gtsmodel.Account)
|
account := new(gtsmodel.Account)
|
||||||
|
|
||||||
|
@ -204,11 +229,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
|
||||||
|
|
||||||
if domain != "" {
|
if domain != "" {
|
||||||
q = q.
|
q = q.
|
||||||
Where("account.username = ?", domain).
|
Where("? = ?", bun.Ident("account.username"), domain).
|
||||||
Where("account.domain = ?", domain)
|
Where("? = ?", bun.Ident("account.domain"), domain)
|
||||||
} else {
|
} else {
|
||||||
q = q.
|
q = q.
|
||||||
Where("account.username = ?", config.GetHost()).
|
Where("? = ?", bun.Ident("account.username"), config.GetHost()).
|
||||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,10 +249,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
|
||||||
q := a.conn.
|
q := a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(status).
|
Model(status).
|
||||||
Order("id DESC").
|
Column("status.created_at").
|
||||||
Limit(1).
|
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||||
Where("account_id = ?", accountID).
|
Order("status.id DESC").
|
||||||
Column("created_at")
|
Limit(1)
|
||||||
|
|
||||||
if err := q.Scan(ctx); err != nil {
|
if err := q.Scan(ctx); err != nil {
|
||||||
return time.Time{}, a.conn.ProcessError(err)
|
return time.Time{}, a.conn.ProcessError(err)
|
||||||
|
@ -240,12 +265,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
|
||||||
return errors.New("one media attachment cannot be both header and avatar")
|
return errors.New("one media attachment cannot be both header and avatar")
|
||||||
}
|
}
|
||||||
|
|
||||||
var headerOrAVI string
|
var column bun.Ident
|
||||||
switch {
|
switch {
|
||||||
case *mediaAttachment.Avatar:
|
case *mediaAttachment.Avatar:
|
||||||
headerOrAVI = "avatar"
|
column = bun.Ident("account.avatar_media_attachment_id")
|
||||||
case *mediaAttachment.Header:
|
case *mediaAttachment.Header:
|
||||||
headerOrAVI = "header"
|
column = bun.Ident("account.header_media_attachment_id")
|
||||||
default:
|
default:
|
||||||
return errors.New("given media attachment was neither a header nor an avatar")
|
return errors.New("given media attachment was neither a header nor an avatar")
|
||||||
}
|
}
|
||||||
|
@ -257,11 +282,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return a.conn.ProcessError(err)
|
return a.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := a.conn.
|
if _, err := a.conn.
|
||||||
NewUpdate().
|
NewUpdate().
|
||||||
Model(>smodel.Account{}).
|
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||||
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
|
Set("? = ?", column, mediaAttachment.ID).
|
||||||
Where("id = ?", accountID).
|
Where("? = ?", bun.Ident("account.id"), accountID).
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return a.conn.ProcessError(err)
|
return a.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
@ -284,7 +310,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
|
||||||
if err := a.conn.
|
if err := a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(faves).
|
Model(faves).
|
||||||
Where("account_id = ?", accountID).
|
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
|
||||||
Scan(ctx); err != nil {
|
Scan(ctx); err != nil {
|
||||||
return nil, a.conn.ProcessError(err)
|
return nil, a.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
@ -295,8 +321,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
|
||||||
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
|
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
|
||||||
return a.conn.
|
return a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.Status{}).
|
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||||
Where("account_id = ?", accountID).
|
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||||
Count(ctx)
|
Count(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -305,12 +331,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
||||||
|
|
||||||
q := a.conn.
|
q := a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Table("statuses").
|
Model(>smodel.Status{}).
|
||||||
Column("id").
|
Column("status.id").
|
||||||
Order("id DESC")
|
Order("status.id DESC")
|
||||||
|
|
||||||
if accountID != "" {
|
if accountID != "" {
|
||||||
q = q.Where("account_id = ?", accountID)
|
q = q.Where("? = ?", bun.Ident("status.account_id"), accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if limit != 0 {
|
if limit != 0 {
|
||||||
|
@ -321,27 +347,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
||||||
// include self-replies (threads)
|
// include self-replies (threads)
|
||||||
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
|
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
|
||||||
return q.
|
return q.
|
||||||
WhereOr("in_reply_to_account_id = ?", accountID).
|
WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID).
|
||||||
WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri"))
|
WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri"))
|
||||||
}
|
}
|
||||||
|
|
||||||
q = q.WhereGroup(" AND ", whereGroup)
|
q = q.WhereGroup(" AND ", whereGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
if excludeReblogs {
|
if excludeReblogs {
|
||||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id"))
|
q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if maxID != "" {
|
if maxID != "" {
|
||||||
q = q.Where("id < ?", maxID)
|
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if minID != "" {
|
if minID != "" {
|
||||||
q = q.Where("id > ?", minID)
|
q = q.Where("? > ?", bun.Ident("status.id"), minID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pinnedOnly {
|
if pinnedOnly {
|
||||||
q = q.Where("pinned = ?", true)
|
q = q.Where("? = ?", bun.Ident("status.pinned"), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
if mediaOnly {
|
if mediaOnly {
|
||||||
|
@ -352,15 +378,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
||||||
switch a.conn.Dialect().Name() {
|
switch a.conn.Dialect().Name() {
|
||||||
case dialect.PG:
|
case dialect.PG:
|
||||||
return q.
|
return q.
|
||||||
Where("? IS NOT NULL", bun.Ident("attachments")).
|
Where("? IS NOT NULL", bun.Ident("status.attachments")).
|
||||||
Where("? != '{}'", bun.Ident("attachments"))
|
Where("? != '{}'", bun.Ident("status.attachments"))
|
||||||
case dialect.SQLite:
|
case dialect.SQLite:
|
||||||
return q.
|
return q.
|
||||||
Where("? IS NOT NULL", bun.Ident("attachments")).
|
Where("? IS NOT NULL", bun.Ident("status.attachments")).
|
||||||
Where("? != ''", bun.Ident("attachments")).
|
Where("? != ''", bun.Ident("status.attachments")).
|
||||||
Where("? != 'null'", bun.Ident("attachments")).
|
Where("? != 'null'", bun.Ident("status.attachments")).
|
||||||
Where("? != '{}'", bun.Ident("attachments")).
|
Where("? != '{}'", bun.Ident("status.attachments")).
|
||||||
Where("? != '[]'", bun.Ident("attachments"))
|
Where("? != '[]'", bun.Ident("status.attachments"))
|
||||||
default:
|
default:
|
||||||
log.Panic("db dialect was neither pg nor sqlite")
|
log.Panic("db dialect was neither pg nor sqlite")
|
||||||
return q
|
return q
|
||||||
|
@ -369,7 +395,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
||||||
}
|
}
|
||||||
|
|
||||||
if publicOnly {
|
if publicOnly {
|
||||||
q = q.Where("visibility = ?", gtsmodel.VisibilityPublic)
|
q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := q.Scan(ctx, &statusIDs); err != nil {
|
if err := q.Scan(ctx, &statusIDs); err != nil {
|
||||||
|
@ -384,19 +410,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
|
||||||
|
|
||||||
q := a.conn.
|
q := a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Table("statuses").
|
Model(>smodel.Status{}).
|
||||||
Column("id").
|
Column("status.id").
|
||||||
Where("account_id = ?", accountID).
|
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||||
WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")).
|
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
|
||||||
WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")).
|
WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
|
||||||
Where("visibility = ?", gtsmodel.VisibilityPublic).
|
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
|
||||||
Where("federated = ?", true)
|
Where("? = ?", bun.Ident("status.federated"), true)
|
||||||
|
|
||||||
if maxID != "" {
|
if maxID != "" {
|
||||||
q = q.Where("id < ?", maxID)
|
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
|
||||||
}
|
}
|
||||||
|
|
||||||
q = q.Limit(limit).Order("id DESC")
|
q = q.Limit(limit).Order("status.id DESC")
|
||||||
|
|
||||||
if err := q.Scan(ctx, &statusIDs); err != nil {
|
if err := q.Scan(ctx, &statusIDs); err != nil {
|
||||||
return nil, a.conn.ProcessError(err)
|
return nil, a.conn.ProcessError(err)
|
||||||
|
@ -411,16 +437,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
|
||||||
fq := a.conn.
|
fq := a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(&blocks).
|
Model(&blocks).
|
||||||
Where("block.account_id = ?", accountID).
|
Where("? = ?", bun.Ident("block.account_id"), accountID).
|
||||||
Relation("TargetAccount").
|
Relation("TargetAccount").
|
||||||
Order("block.id DESC")
|
Order("block.id DESC")
|
||||||
|
|
||||||
if maxID != "" {
|
if maxID != "" {
|
||||||
fq = fq.Where("block.id < ?", maxID)
|
fq = fq.Where("? < ?", bun.Ident("block.id"), maxID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if sinceID != "" {
|
if sinceID != "" {
|
||||||
fq = fq.Where("block.id > ?", sinceID)
|
fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if limit > 0 {
|
if limit > 0 {
|
||||||
|
|
|
@ -22,7 +22,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
|
@ -37,21 +36,26 @@ import (
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/uris"
|
"github.com/superseriousbusiness/gotosocial/internal/uris"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// generate RSA keys of this length
|
||||||
|
const rsaKeyBits = 2048
|
||||||
|
|
||||||
type adminDB struct {
|
type adminDB struct {
|
||||||
conn *DBConn
|
conn *DBConn
|
||||||
userCache *cache.UserCache
|
userCache *cache.UserCache
|
||||||
|
accountCache *cache.AccountCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
|
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
|
||||||
q := a.conn.
|
q := a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.Account{}).
|
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||||
Where("username = ?", username).
|
Column("account.id").
|
||||||
Where("domain = ?", nil)
|
Where("? = ?", bun.Ident("account.username"), username).
|
||||||
|
Where("? IS NULL", bun.Ident("account.domain"))
|
||||||
return a.conn.NotExists(ctx, q)
|
return a.conn.NotExists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
|
||||||
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
|
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
|
||||||
|
|
||||||
// check if the email domain is blocked
|
// check if the email domain is blocked
|
||||||
if err := a.conn.
|
emailDomainBlockedQ := a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.EmailDomainBlock{}).
|
TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")).
|
||||||
Where("domain = ?", domain).
|
Column("email_domain_block.id").
|
||||||
Scan(ctx); err == nil {
|
Where("? = ?", bun.Ident("email_domain_block.domain"), domain)
|
||||||
// fail because we found something
|
emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if emailDomainBlocked {
|
||||||
return false, fmt.Errorf("email domain %s is blocked", domain)
|
return false, fmt.Errorf("email domain %s is blocked", domain)
|
||||||
} else if err != sql.ErrNoRows {
|
|
||||||
return false, a.conn.ProcessError(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if this email is associated with a user already
|
// check if this email is associated with a user already
|
||||||
q := a.conn.
|
q := a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.User{}).
|
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||||
Where("email = ?", email).
|
Column("user.id").
|
||||||
WhereOr("unconfirmed_email = ?", email)
|
Where("? = ?", bun.Ident("user.email"), email).
|
||||||
|
WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email)
|
||||||
return a.conn.NotExists(ctx, q)
|
return a.conn.NotExists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
|
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error creating new rsa key: %s", err)
|
log.Errorf("error creating new rsa key: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
||||||
|
|
||||||
// if something went wrong while creating a user, we might already have an account, so check here first...
|
// if something went wrong while creating a user, we might already have an account, so check here first...
|
||||||
acct := >smodel.Account{}
|
acct := >smodel.Account{}
|
||||||
q := a.conn.NewSelect().
|
if err := a.conn.
|
||||||
|
NewSelect().
|
||||||
Model(acct).
|
Model(acct).
|
||||||
Where("username = ?", username).
|
Where("? = ?", bun.Ident("account.username"), username).
|
||||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
WhereGroup(" AND ", whereEmptyOrNull("account.domain")).
|
||||||
|
Scan(ctx); err != nil {
|
||||||
|
err = a.conn.ProcessError(err)
|
||||||
|
if err != db.ErrNoEntries {
|
||||||
|
log.Errorf("error checking for existing account: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if err := q.Scan(ctx); err != nil {
|
// if we have db.ErrNoEntries, we just don't have an
|
||||||
// we just don't have an account yet so create one before we proceed
|
// account yet so create one before we proceed
|
||||||
accountURIs := uris.GenerateURIsForAccount(username)
|
accountURIs := uris.GenerateURIsForAccount(username)
|
||||||
accountID, err := id.NewRandomULID()
|
accountID, err := id.NewRandomULID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
||||||
FeaturedCollectionURI: accountURIs.CollectionURI,
|
FeaturedCollectionURI: accountURIs.CollectionURI,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// insert the new account!
|
||||||
if _, err = a.conn.
|
if _, err = a.conn.
|
||||||
NewInsert().
|
NewInsert().
|
||||||
Model(acct).
|
Model(acct).
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return nil, a.conn.ProcessError(err)
|
return nil, a.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
a.accountCache.Put(acct)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we either created or already had an account by now,
|
||||||
|
// so proceed with creating a user for that account
|
||||||
|
|
||||||
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error hashing password: %s", err)
|
return nil, fmt.Errorf("error hashing password: %s", err)
|
||||||
|
@ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
||||||
u.Moderator = &moderator
|
u.Moderator = &moderator
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// insert the user!
|
||||||
if _, err = a.conn.
|
if _, err = a.conn.
|
||||||
NewInsert().
|
NewInsert().
|
||||||
Model(u).
|
Model(u).
|
||||||
|
@ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
||||||
|
|
||||||
q := a.conn.
|
q := a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.Account{}).
|
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||||
Where("username = ?", username).
|
Column("account.id").
|
||||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
Where("? = ?", bun.Ident("account.username"), username).
|
||||||
|
WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
|
||||||
|
|
||||||
exists, err := a.conn.Exists(ctx, q)
|
exists, err := a.conn.Exists(ctx, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error creating new rsa key: %s", err)
|
log.Errorf("error creating new rsa key: %s", err)
|
||||||
return err
|
return err
|
||||||
|
@ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
||||||
return a.conn.ProcessError(err)
|
return a.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
a.accountCache.Put(acct)
|
||||||
log.Infof("instance account %s CREATED with id %s", username, acct.ID)
|
log.Infof("instance account %s CREATED with id %s", username, acct.ID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
|
||||||
// check if instance entry already exists
|
// check if instance entry already exists
|
||||||
q := a.conn.
|
q := a.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.Instance{}).
|
Column("instance.id").
|
||||||
Where("domain = ?", host)
|
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
|
||||||
|
Where("? = ?", bun.Ident("instance.domain"), host)
|
||||||
|
|
||||||
exists, err := a.conn.Exists(ctx, q)
|
exists, err := a.conn.Exists(ctx, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
gtsmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20211113114307_init"
|
||||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -30,6 +31,44 @@ type AdminTestSuite struct {
|
||||||
BunDBStandardTestSuite
|
BunDBStandardTestSuite
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (suite *AdminTestSuite) TestIsUsernameAvailableNo() {
|
||||||
|
available, err := suite.db.IsUsernameAvailable(context.Background(), "the_mighty_zork")
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.False(available)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AdminTestSuite) TestIsUsernameAvailableYes() {
|
||||||
|
available, err := suite.db.IsUsernameAvailable(context.Background(), "someone_completely_different")
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.True(available)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AdminTestSuite) TestIsEmailAvailableNo() {
|
||||||
|
available, err := suite.db.IsEmailAvailable(context.Background(), "zork@example.org")
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.False(available)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AdminTestSuite) TestIsEmailAvailableYes() {
|
||||||
|
available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com")
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.True(available)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {
|
||||||
|
if err := suite.db.Put(context.Background(), >smodel.EmailDomainBlock{
|
||||||
|
ID: "01GEEV2R2YC5GRSN96761YJE47",
|
||||||
|
Domain: "somewhere.com",
|
||||||
|
CreatedByAccountID: suite.testAccounts["admin_account"].ID,
|
||||||
|
}); err != nil {
|
||||||
|
suite.FailNow(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com")
|
||||||
|
suite.EqualError(err, "email domain somewhere.com is blocked")
|
||||||
|
suite.False(available)
|
||||||
|
}
|
||||||
|
|
||||||
func (suite *AdminTestSuite) TestCreateInstanceAccount() {
|
func (suite *AdminTestSuite) TestCreateInstanceAccount() {
|
||||||
// we need to take an empty db for this...
|
// we need to take an empty db for this...
|
||||||
testrig.StandardDBTeardown(suite.db)
|
testrig.StandardDBTeardown(suite.db)
|
||||||
|
|
|
@ -110,7 +110,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
|
||||||
|
|
||||||
updateWhere(q, where)
|
updateWhere(q, where)
|
||||||
|
|
||||||
q = q.Set("? = ?", bun.Safe(key), value)
|
q = q.Set("? = ?", bun.Ident(key), value)
|
||||||
|
|
||||||
_, err := q.Exec(ctx)
|
_, err := q.Exec(ctx)
|
||||||
return b.conn.ProcessError(err)
|
return b.conn.ProcessError(err)
|
||||||
|
|
|
@ -159,17 +159,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
|
||||||
return nil, fmt.Errorf("db migration error: %s", err)
|
return nil, fmt.Errorf("db migration error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create DB structs that require ptrs to each other
|
// Prepare caches required by more than one struct
|
||||||
accounts := &accountDB{conn: conn, cache: cache.NewAccountCache()}
|
userCache := cache.NewUserCache()
|
||||||
status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
|
accountCache := cache.NewAccountCache()
|
||||||
emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
|
|
||||||
timeline := &timelineDB{conn: conn}
|
|
||||||
|
|
||||||
// Setup DB cross-referencing
|
|
||||||
accounts.status = status
|
|
||||||
status.accounts = accounts
|
|
||||||
timeline.status = status
|
|
||||||
|
|
||||||
|
// Prepare other caches
|
||||||
// Prepare mentions cache
|
// Prepare mentions cache
|
||||||
// TODO: move into internal/cache
|
// TODO: move into internal/cache
|
||||||
mentionCache := grufcache.New[string, *gtsmodel.Mention]()
|
mentionCache := grufcache.New[string, *gtsmodel.Mention]()
|
||||||
|
@ -182,22 +176,30 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
|
||||||
notifCache.SetTTL(time.Minute*5, false)
|
notifCache.SetTTL(time.Minute*5, false)
|
||||||
notifCache.Start(time.Second * 10)
|
notifCache.Start(time.Second * 10)
|
||||||
|
|
||||||
// Prepare other caches
|
// Create DB structs that require ptrs to each other
|
||||||
blockCache := cache.NewDomainBlockCache()
|
accounts := &accountDB{conn: conn, cache: accountCache}
|
||||||
userCache := cache.NewUserCache()
|
status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
|
||||||
|
emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
|
||||||
|
timeline := &timelineDB{conn: conn}
|
||||||
|
|
||||||
|
// Setup DB cross-referencing
|
||||||
|
accounts.status = status
|
||||||
|
status.accounts = accounts
|
||||||
|
timeline.status = status
|
||||||
|
|
||||||
ps := &DBService{
|
ps := &DBService{
|
||||||
Account: accounts,
|
Account: accounts,
|
||||||
Admin: &adminDB{
|
Admin: &adminDB{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
userCache: userCache,
|
userCache: userCache,
|
||||||
|
accountCache: accountCache,
|
||||||
},
|
},
|
||||||
Basic: &basicDB{
|
Basic: &basicDB{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
},
|
},
|
||||||
Domain: &domainDB{
|
Domain: &domainDB{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
cache: blockCache,
|
cache: cache.NewDomainBlockCache(),
|
||||||
},
|
},
|
||||||
Emoji: emoji,
|
Emoji: emoji,
|
||||||
Instance: &instanceDB{
|
Instance: &instanceDB{
|
||||||
|
|
|
@ -40,6 +40,7 @@ type BunDBStandardTestSuite struct {
|
||||||
testStatuses map[string]*gtsmodel.Status
|
testStatuses map[string]*gtsmodel.Status
|
||||||
testTags map[string]*gtsmodel.Tag
|
testTags map[string]*gtsmodel.Tag
|
||||||
testMentions map[string]*gtsmodel.Mention
|
testMentions map[string]*gtsmodel.Mention
|
||||||
|
testFollows map[string]*gtsmodel.Follow
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *BunDBStandardTestSuite) SetupSuite() {
|
func (suite *BunDBStandardTestSuite) SetupSuite() {
|
||||||
|
@ -52,6 +53,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
|
||||||
suite.testStatuses = testrig.NewTestStatuses()
|
suite.testStatuses = testrig.NewTestStatuses()
|
||||||
suite.testTags = testrig.NewTestTags()
|
suite.testTags = testrig.NewTestTags()
|
||||||
suite.testMentions = testrig.NewTestMentions()
|
suite.testMentions = testrig.NewTestMentions()
|
||||||
|
suite.testFollows = testrig.NewTestFollows()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *BunDBStandardTestSuite) SetupTest() {
|
func (suite *BunDBStandardTestSuite) SetupTest() {
|
||||||
|
|
|
@ -28,6 +28,7 @@ import (
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -95,7 +96,7 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
|
||||||
q := d.conn.
|
q := d.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(block).
|
Model(block).
|
||||||
Where("domain = ?", domain).
|
Where("? = ?", bun.Ident("domain_block.domain"), domain).
|
||||||
Limit(1)
|
Limit(1)
|
||||||
|
|
||||||
// Query database for domain block
|
// Query database for domain block
|
||||||
|
@ -126,7 +127,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
|
||||||
// Attempt to delete domain block
|
// Attempt to delete domain block
|
||||||
if _, err := d.conn.NewDelete().
|
if _, err := d.conn.NewDelete().
|
||||||
Model((*gtsmodel.DomainBlock)(nil)).
|
Model((*gtsmodel.DomainBlock)(nil)).
|
||||||
Where("domain = ?", domain).
|
Where("? = ?", bun.Ident("domain_block.domain"), domain).
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return d.conn.ProcessError(err)
|
return d.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,12 +54,12 @@ func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Er
|
||||||
|
|
||||||
q := e.conn.
|
q := e.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Table("emojis").
|
TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")).
|
||||||
Column("id").
|
Column("emoji.id").
|
||||||
Where("visible_in_picker = true").
|
Where("? = ?", bun.Ident("emoji.visible_in_picker"), true).
|
||||||
Where("disabled = false").
|
Where("? = ?", bun.Ident("emoji.disabled"), false).
|
||||||
Where("domain IS NULL").
|
Where("? IS NULL", bun.Ident("emoji.domain")).
|
||||||
Order("shortcode ASC")
|
Order("emoji.shortcode ASC")
|
||||||
|
|
||||||
if err := q.Scan(ctx, &emojiIDs); err != nil {
|
if err := q.Scan(ctx, &emojiIDs); err != nil {
|
||||||
return nil, e.conn.ProcessError(err)
|
return nil, e.conn.ProcessError(err)
|
||||||
|
@ -75,7 +75,7 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji,
|
||||||
return e.cache.GetByID(id)
|
return e.cache.GetByID(id)
|
||||||
},
|
},
|
||||||
func(emoji *gtsmodel.Emoji) error {
|
func(emoji *gtsmodel.Emoji) error {
|
||||||
return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx)
|
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -87,7 +87,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
|
||||||
return e.cache.GetByURI(uri)
|
return e.cache.GetByURI(uri)
|
||||||
},
|
},
|
||||||
func(emoji *gtsmodel.Emoji) error {
|
func(emoji *gtsmodel.Emoji) error {
|
||||||
return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx)
|
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -102,11 +102,11 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin
|
||||||
q := e.newEmojiQ(emoji)
|
q := e.newEmojiQ(emoji)
|
||||||
|
|
||||||
if domain != "" {
|
if domain != "" {
|
||||||
q = q.Where("emoji.shortcode = ?", shortcode)
|
q = q.Where("? = ?", bun.Ident("emoji.shortcode"), shortcode)
|
||||||
q = q.Where("emoji.domain = ?", domain)
|
q = q.Where("? = ?", bun.Ident("emoji.domain"), domain)
|
||||||
} else {
|
} else {
|
||||||
q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode))
|
q = q.Where("? = ?", bun.Ident("emoji.shortcode"), strings.ToLower(shortcode))
|
||||||
q = q.Where("emoji.domain IS NULL")
|
q = q.Where("? IS NULL", bun.Ident("emoji.domain"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return q.Scan(ctx)
|
return q.Scan(ctx)
|
||||||
|
|
|
@ -24,7 +24,6 @@ import (
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,15 +34,16 @@ type instanceDB struct {
|
||||||
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
|
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
|
||||||
q := i.conn.
|
q := i.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(&[]*gtsmodel.Account{}).
|
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||||
Where("username != ?", domain).
|
Column("account.id").
|
||||||
Where("? IS NULL", bun.Ident("suspended_at"))
|
Where("? != ?", bun.Ident("account.username"), domain).
|
||||||
|
Where("? IS NULL", bun.Ident("account.suspended_at"))
|
||||||
|
|
||||||
if domain == config.GetHost() {
|
if domain == config.GetHost() || domain == config.GetAccountDomain() {
|
||||||
// if the domain is *this* domain, just count where the domain field is null
|
// if the domain is *this* domain, just count where the domain field is null
|
||||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
q = q.WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
|
||||||
} else {
|
} else {
|
||||||
q = q.Where("domain = ?", domain)
|
q = q.Where("? = ?", bun.Ident("account.domain"), domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := q.Count(ctx)
|
count, err := q.Count(ctx)
|
||||||
|
@ -56,15 +56,16 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
|
||||||
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
|
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
|
||||||
q := i.conn.
|
q := i.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(&[]*gtsmodel.Status{})
|
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status"))
|
||||||
|
|
||||||
if domain == config.GetHost() {
|
if domain == config.GetHost() || domain == config.GetAccountDomain() {
|
||||||
// if the domain is *this* domain, just count where local is true
|
// if the domain is *this* domain, just count where local is true
|
||||||
q = q.Where("local = ?", true)
|
q = q.Where("? = ?", bun.Ident("status.local"), true)
|
||||||
} else {
|
} else {
|
||||||
// join on the domain of the account
|
// join on the domain of the account
|
||||||
q = q.Join("JOIN accounts AS account ON account.id = status.account_id").
|
q = q.
|
||||||
Where("account.domain = ?", domain)
|
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("account.id"), bun.Ident("status.account_id")).
|
||||||
|
Where("? = ?", bun.Ident("account.domain"), domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := q.Count(ctx)
|
count, err := q.Count(ctx)
|
||||||
|
@ -77,14 +78,14 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
|
||||||
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
|
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
|
||||||
q := i.conn.
|
q := i.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(&[]*gtsmodel.Instance{})
|
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance"))
|
||||||
|
|
||||||
if domain == config.GetHost() {
|
if domain == config.GetHost() {
|
||||||
// if the domain is *this* domain, just count other instances it knows about
|
// if the domain is *this* domain, just count other instances it knows about
|
||||||
// exclude domains that are blocked
|
// exclude domains that are blocked
|
||||||
q = q.
|
q = q.
|
||||||
Where("domain != ?", domain).
|
Where("? != ?", bun.Ident("instance.domain"), domain).
|
||||||
Where("? IS NULL", bun.Ident("suspended_at"))
|
Where("? IS NULL", bun.Ident("instance.suspended_at"))
|
||||||
} else {
|
} else {
|
||||||
// TODO: implement federated domain counting properly for remote domains
|
// TODO: implement federated domain counting properly for remote domains
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
@ -103,10 +104,10 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
|
||||||
q := i.conn.
|
q := i.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(&instances).
|
Model(&instances).
|
||||||
Where("domain != ?", config.GetHost())
|
Where("? != ?", bun.Ident("instance.domain"), config.GetHost())
|
||||||
|
|
||||||
if !includeSuspended {
|
if !includeSuspended {
|
||||||
q = q.Where("? IS NULL", bun.Ident("suspended_at"))
|
q = q.Where("? IS NULL", bun.Ident("instance.suspended_at"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := q.Scan(ctx); err != nil {
|
if err := q.Scan(ctx); err != nil {
|
||||||
|
@ -117,17 +118,15 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
|
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
|
||||||
log.Debug("GetAccountsForInstance")
|
|
||||||
|
|
||||||
accounts := []*gtsmodel.Account{}
|
accounts := []*gtsmodel.Account{}
|
||||||
|
|
||||||
q := i.conn.NewSelect().
|
q := i.conn.NewSelect().
|
||||||
Model(&accounts).
|
Model(&accounts).
|
||||||
Where("domain = ?", domain).
|
Where("? = ?", bun.Ident("account.domain"), domain).
|
||||||
Order("id DESC")
|
Order("account.id DESC")
|
||||||
|
|
||||||
if maxID != "" {
|
if maxID != "" {
|
||||||
q = q.Where("id < ?", maxID)
|
q = q.Where("? < ?", bun.Ident("account.id"), maxID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if limit > 0 {
|
if limit > 0 {
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bundb_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InstanceTestSuite struct {
|
||||||
|
BunDBStandardTestSuite
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *InstanceTestSuite) TestCountInstanceUsers() {
|
||||||
|
count, err := suite.db.CountInstanceUsers(context.Background(), config.GetHost())
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(4, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() {
|
||||||
|
count, err := suite.db.CountInstanceUsers(context.Background(), "fossbros-anonymous.io")
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(1, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *InstanceTestSuite) TestCountInstanceStatuses() {
|
||||||
|
count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost())
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(16, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() {
|
||||||
|
count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io")
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(1, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *InstanceTestSuite) TestCountInstanceDomains() {
|
||||||
|
count, err := suite.db.CountInstanceDomains(context.Background(), config.GetHost())
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(2, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *InstanceTestSuite) TestGetInstancePeers() {
|
||||||
|
peers, err := suite.db.GetInstancePeers(context.Background(), false)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Len(peers, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *InstanceTestSuite) TestGetInstancePeersIncludeSuspended() {
|
||||||
|
peers, err := suite.db.GetInstancePeers(context.Background(), true)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Len(peers, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *InstanceTestSuite) TestGetInstanceAccounts() {
|
||||||
|
accounts, err := suite.db.GetInstanceAccounts(context.Background(), "fossbros-anonymous.io", "", 10)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Len(accounts, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInstanceTestSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(InstanceTestSuite))
|
||||||
|
}
|
|
@ -42,7 +42,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
|
||||||
attachment := >smodel.MediaAttachment{}
|
attachment := >smodel.MediaAttachment{}
|
||||||
|
|
||||||
q := m.newMediaQ(attachment).
|
q := m.newMediaQ(attachment).
|
||||||
Where("media_attachment.id = ?", id)
|
Where("? = ?", bun.Ident("media_attachment.id"), id)
|
||||||
|
|
||||||
if err := q.Scan(ctx); err != nil {
|
if err := q.Scan(ctx); err != nil {
|
||||||
return nil, m.conn.ProcessError(err)
|
return nil, m.conn.ProcessError(err)
|
||||||
|
@ -56,10 +56,10 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l
|
||||||
q := m.conn.
|
q := m.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(&attachments).
|
Model(&attachments).
|
||||||
Where("media_attachment.cached = true").
|
Where("? = ?", bun.Ident("media_attachment.cached"), true).
|
||||||
Where("media_attachment.avatar = false").
|
Where("? = ?", bun.Ident("media_attachment.avatar"), false).
|
||||||
Where("media_attachment.header = false").
|
Where("? = ?", bun.Ident("media_attachment.header"), false).
|
||||||
Where("media_attachment.created_at < ?", olderThan).
|
Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).
|
||||||
WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")).
|
WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")).
|
||||||
Order("media_attachment.created_at DESC")
|
Order("media_attachment.created_at DESC")
|
||||||
|
|
||||||
|
@ -79,13 +79,13 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
|
||||||
q := m.newMediaQ(&attachments).
|
q := m.newMediaQ(&attachments).
|
||||||
WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery {
|
WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery {
|
||||||
return innerQ.
|
return innerQ.
|
||||||
WhereOr("media_attachment.avatar = true").
|
WhereOr("? = ?", bun.Ident("media_attachment.avatar"), true).
|
||||||
WhereOr("media_attachment.header = true")
|
WhereOr("? = ?", bun.Ident("media_attachment.header"), true)
|
||||||
}).
|
}).
|
||||||
Order("media_attachment.id DESC")
|
Order("media_attachment.id DESC")
|
||||||
|
|
||||||
if maxID != "" {
|
if maxID != "" {
|
||||||
q = q.Where("media_attachment.id < ?", maxID)
|
q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if limit != 0 {
|
if limit != 0 {
|
||||||
|
@ -103,15 +103,15 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
|
||||||
attachments := []*gtsmodel.MediaAttachment{}
|
attachments := []*gtsmodel.MediaAttachment{}
|
||||||
|
|
||||||
q := m.newMediaQ(&attachments).
|
q := m.newMediaQ(&attachments).
|
||||||
Where("media_attachment.cached = true").
|
Where("? = ?", bun.Ident("media_attachment.cached"), true).
|
||||||
Where("media_attachment.avatar = false").
|
Where("? = ?", bun.Ident("media_attachment.avatar"), false).
|
||||||
Where("media_attachment.header = false").
|
Where("? = ?", bun.Ident("media_attachment.header"), false).
|
||||||
Where("media_attachment.created_at < ?", olderThan).
|
Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).
|
||||||
Where("media_attachment.remote_url IS NULL").
|
Where("? IS NULL", bun.Ident("media_attachment.remote_url")).
|
||||||
Where("media_attachment.status_id IS NULL")
|
Where("? IS NULL", bun.Ident("media_attachment.status_id"))
|
||||||
|
|
||||||
if maxID != "" {
|
if maxID != "" {
|
||||||
q = q.Where("media_attachment.id < ?", maxID)
|
q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if limit != 0 {
|
if limit != 0 {
|
||||||
|
|
|
@ -46,7 +46,7 @@ func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Ment
|
||||||
mention := gtsmodel.Mention{}
|
mention := gtsmodel.Mention{}
|
||||||
|
|
||||||
q := m.newMentionQ(&mention).
|
q := m.newMentionQ(&mention).
|
||||||
Where("mention.id = ?", id)
|
Where("? = ?", bun.Ident("mention.id"), id)
|
||||||
|
|
||||||
if err := q.Scan(ctx); err != nil {
|
if err := q.Scan(ctx); err != nil {
|
||||||
return nil, m.conn.ProcessError(err)
|
return nil, m.conn.ProcessError(err)
|
||||||
|
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
type notificationDB struct {
|
type notificationDB struct {
|
||||||
|
@ -67,24 +68,24 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
||||||
|
|
||||||
q := n.conn.
|
q := n.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Table("notifications").
|
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
|
||||||
Column("id")
|
Column("notification.id")
|
||||||
|
|
||||||
if maxID != "" {
|
if maxID != "" {
|
||||||
q = q.Where("id < ?", maxID)
|
q = q.Where("? < ?", bun.Ident("notification.id"), maxID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if sinceID != "" {
|
if sinceID != "" {
|
||||||
q = q.Where("id > ?", sinceID)
|
q = q.Where("? > ?", bun.Ident("notification.id"), sinceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, excludeType := range excludeTypes {
|
for _, excludeType := range excludeTypes {
|
||||||
q = q.Where("notification_type != ?", excludeType)
|
q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType)
|
||||||
}
|
}
|
||||||
|
|
||||||
q = q.
|
q = q.
|
||||||
Where("target_account_id = ?", accountID).
|
Where("? = ?", bun.Ident("notification.target_account_id"), accountID).
|
||||||
Order("id DESC")
|
Order("notification.id DESC")
|
||||||
|
|
||||||
if limit != 0 {
|
if limit != 0 {
|
||||||
q = q.Limit(limit)
|
q = q.Limit(limit)
|
||||||
|
@ -116,13 +117,12 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
||||||
func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error {
|
func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error {
|
||||||
if _, err := n.conn.
|
if _, err := n.conn.
|
||||||
NewDelete().
|
NewDelete().
|
||||||
Table("notifications").
|
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
|
||||||
Where("target_account_id = ?", accountID).
|
Where("? = ?", bun.Ident("notification.target_account_id"), accountID).
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return n.conn.ProcessError(err)
|
return n.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
n.cache.Clear()
|
n.cache.Clear()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
|
||||||
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
|
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
|
||||||
q := r.conn.
|
q := r.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.Block{}).
|
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||||
ExcludeColumn("id", "created_at", "updated_at", "uri").
|
Column("block.id")
|
||||||
Limit(1)
|
|
||||||
|
|
||||||
if eitherDirection {
|
if eitherDirection {
|
||||||
q = q.
|
q = q.
|
||||||
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
|
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
|
||||||
return inner.
|
return inner.
|
||||||
Where("account_id = ?", account1).
|
Where("? = ?", bun.Ident("block.account_id"), account1).
|
||||||
Where("target_account_id = ?", account2)
|
Where("? = ?", bun.Ident("block.target_account_id"), account2)
|
||||||
}).
|
}).
|
||||||
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
|
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
|
||||||
return inner.
|
return inner.
|
||||||
Where("account_id = ?", account2).
|
Where("? = ?", bun.Ident("block.account_id"), account2).
|
||||||
Where("target_account_id = ?", account1)
|
Where("? = ?", bun.Ident("block.target_account_id"), account1)
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
q = q.
|
q = q.
|
||||||
Where("account_id = ?", account1).
|
Where("? = ?", bun.Ident("block.account_id"), account1).
|
||||||
Where("target_account_id = ?", account2)
|
Where("? = ?", bun.Ident("block.target_account_id"), account2)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.conn.Exists(ctx, q)
|
return r.conn.Exists(ctx, q)
|
||||||
|
@ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
|
||||||
block := >smodel.Block{}
|
block := >smodel.Block{}
|
||||||
|
|
||||||
q := r.newBlockQ(block).
|
q := r.newBlockQ(block).
|
||||||
Where("block.account_id = ?", account1).
|
Where("? = ?", bun.Ident("block.account_id"), account1).
|
||||||
Where("block.target_account_id = ?", account2)
|
Where("? = ?", bun.Ident("block.target_account_id"), account2)
|
||||||
|
|
||||||
if err := q.Scan(ctx); err != nil {
|
if err := q.Scan(ctx); err != nil {
|
||||||
return nil, r.conn.ProcessError(err)
|
return nil, r.conn.ProcessError(err)
|
||||||
|
@ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
|
||||||
if err := r.conn.
|
if err := r.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(follow).
|
Model(follow).
|
||||||
Where("account_id = ?", requestingAccount).
|
Column("follow.show_reblogs", "follow.notify").
|
||||||
Where("target_account_id = ?", targetAccount).
|
Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
|
||||||
|
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
|
||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx); err != nil {
|
Scan(ctx); err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
|
||||||
// a proper error
|
return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
|
||||||
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
|
|
||||||
}
|
}
|
||||||
// no follow exists so these are all false
|
// no follow exists so these are all false
|
||||||
rel.Following = false
|
rel.Following = false
|
||||||
|
@ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if the target account follows the requesting account
|
// check if the target account follows the requesting account
|
||||||
count, err := r.conn.
|
followedByQ := r.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.Follow{}).
|
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||||
Where("account_id = ?", targetAccount).
|
Column("follow.id").
|
||||||
Where("target_account_id = ?", requestingAccount).
|
Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
|
||||||
Limit(1).
|
Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
|
||||||
Count(ctx)
|
followedBy, err := r.conn.Exists(ctx, followedByQ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
|
return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
|
||||||
}
|
}
|
||||||
rel.FollowedBy = count > 0
|
rel.FollowedBy = followedBy
|
||||||
|
|
||||||
// check if the requesting account blocks the target account
|
|
||||||
count, err = r.conn.NewSelect().
|
|
||||||
Model(>smodel.Block{}).
|
|
||||||
Where("account_id = ?", requestingAccount).
|
|
||||||
Where("target_account_id = ?", targetAccount).
|
|
||||||
Limit(1).
|
|
||||||
Count(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
|
|
||||||
}
|
|
||||||
rel.Blocking = count > 0
|
|
||||||
|
|
||||||
// check if the target account blocks the requesting account
|
|
||||||
count, err = r.conn.
|
|
||||||
NewSelect().
|
|
||||||
Model(>smodel.Block{}).
|
|
||||||
Where("account_id = ?", targetAccount).
|
|
||||||
Where("target_account_id = ?", requestingAccount).
|
|
||||||
Limit(1).
|
|
||||||
Count(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
|
|
||||||
}
|
|
||||||
rel.BlockedBy = count > 0
|
|
||||||
|
|
||||||
// check if there's a pending following request from requesting account to target account
|
// check if there's a pending following request from requesting account to target account
|
||||||
count, err = r.conn.
|
requestedQ := r.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.FollowRequest{}).
|
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||||
Where("account_id = ?", requestingAccount).
|
Column("follow_request.id").
|
||||||
Where("target_account_id = ?", targetAccount).
|
Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
|
||||||
Limit(1).
|
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
|
||||||
Count(ctx)
|
requested, err := r.conn.Exists(ctx, requestedQ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
|
return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
|
||||||
}
|
}
|
||||||
rel.Requested = count > 0
|
rel.Requested = requested
|
||||||
|
|
||||||
|
// check if the requesting account is blocking the target account
|
||||||
|
blockingQ := r.conn.
|
||||||
|
NewSelect().
|
||||||
|
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||||
|
Column("block.id").
|
||||||
|
Where("? = ?", bun.Ident("block.account_id"), requestingAccount).
|
||||||
|
Where("? = ?", bun.Ident("block.target_account_id"), targetAccount)
|
||||||
|
blocking, err := r.conn.Exists(ctx, blockingQ)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
|
||||||
|
}
|
||||||
|
rel.Blocking = blocking
|
||||||
|
|
||||||
|
// check if the requesting account is blocked by the target account
|
||||||
|
blockedByQ := r.conn.
|
||||||
|
NewSelect().
|
||||||
|
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||||
|
Column("block.id").
|
||||||
|
Where("? = ?", bun.Ident("block.account_id"), targetAccount).
|
||||||
|
Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount)
|
||||||
|
blockedBy, err := r.conn.Exists(ctx, blockedByQ)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
|
||||||
|
}
|
||||||
|
rel.BlockedBy = blockedBy
|
||||||
|
|
||||||
return rel, nil
|
return rel, nil
|
||||||
}
|
}
|
||||||
|
@ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
|
||||||
|
|
||||||
q := r.conn.
|
q := r.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.Follow{}).
|
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||||
Where("account_id = ?", sourceAccount.ID).
|
Column("follow.id").
|
||||||
Where("target_account_id = ?", targetAccount.ID).
|
Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
|
||||||
Limit(1)
|
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
|
||||||
|
|
||||||
return r.conn.Exists(ctx, q)
|
return r.conn.Exists(ctx, q)
|
||||||
}
|
}
|
||||||
|
@ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
|
||||||
|
|
||||||
q := r.conn.
|
q := r.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(>smodel.FollowRequest{}).
|
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||||
Where("account_id = ?", sourceAccount.ID).
|
Column("follow_request.id").
|
||||||
Where("target_account_id = ?", targetAccount.ID)
|
Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
|
||||||
|
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
|
||||||
|
|
||||||
return r.conn.Exists(ctx, q)
|
return r.conn.Exists(ctx, q)
|
||||||
}
|
}
|
||||||
|
@ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
|
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
|
||||||
// make sure the original follow request exists
|
var follow *gtsmodel.Follow
|
||||||
fr := >smodel.FollowRequest{}
|
|
||||||
if err := r.conn.
|
if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||||
|
// get original follow request
|
||||||
|
followRequest := >smodel.FollowRequest{}
|
||||||
|
if err := tx.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(fr).
|
Model(followRequest).
|
||||||
Where("account_id = ?", originAccountID).
|
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
|
||||||
Where("target_account_id = ?", targetAccountID).
|
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
|
||||||
Scan(ctx); err != nil {
|
Scan(ctx); err != nil {
|
||||||
return nil, r.conn.ProcessError(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a new follow to 'replace' the request with
|
// create a new follow to 'replace' the request with
|
||||||
follow := >smodel.Follow{
|
follow = >smodel.Follow{
|
||||||
ID: fr.ID,
|
ID: followRequest.ID,
|
||||||
AccountID: originAccountID,
|
AccountID: originAccountID,
|
||||||
TargetAccountID: targetAccountID,
|
TargetAccountID: targetAccountID,
|
||||||
URI: fr.URI,
|
URI: followRequest.URI,
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the follow already exists, just update the URI -- we don't need to do anything else
|
// if the follow already exists, just update the URI -- we don't need to do anything else
|
||||||
if _, err := r.conn.
|
if _, err := tx.
|
||||||
NewInsert().
|
NewInsert().
|
||||||
Model(follow).
|
Model(follow).
|
||||||
On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI).
|
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return nil, r.conn.ProcessError(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// now remove the follow request
|
// now remove the follow request
|
||||||
if _, err := r.conn.
|
if _, err := tx.
|
||||||
NewDelete().
|
NewDelete().
|
||||||
Model(>smodel.FollowRequest{}).
|
Model(followRequest).
|
||||||
Where("account_id = ?", originAccountID).
|
WherePK().
|
||||||
Where("target_account_id = ?", targetAccountID).
|
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
return nil, r.conn.ProcessError(err)
|
return nil, r.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return the new follow
|
||||||
return follow, nil
|
return follow, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
|
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
|
||||||
// first get the follow request out of the database
|
followRequest := >smodel.FollowRequest{}
|
||||||
fr := >smodel.FollowRequest{}
|
|
||||||
if err := r.conn.
|
if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||||
|
// get original follow request
|
||||||
|
if err := tx.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(fr).
|
Model(followRequest).
|
||||||
Where("account_id = ?", originAccountID).
|
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
|
||||||
Where("target_account_id = ?", targetAccountID).
|
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
|
||||||
Scan(ctx); err != nil {
|
Scan(ctx); err != nil {
|
||||||
return nil, r.conn.ProcessError(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// now delete it from the database by ID
|
// now delete it from the database by ID
|
||||||
if _, err := r.conn.
|
if _, err := tx.
|
||||||
NewDelete().
|
NewDelete().
|
||||||
Model(>smodel.FollowRequest{ID: fr.ID}).
|
Model(followRequest).
|
||||||
WherePK().
|
WherePK().
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
return nil, r.conn.ProcessError(err)
|
return nil, r.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// return the deleted follow request
|
// return the deleted follow request
|
||||||
return fr, nil
|
return followRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
|
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
|
||||||
followRequests := []*gtsmodel.FollowRequest{}
|
followRequests := []*gtsmodel.FollowRequest{}
|
||||||
|
|
||||||
q := r.newFollowQ(&followRequests).
|
q := r.newFollowQ(&followRequests).
|
||||||
Where("target_account_id = ?", accountID).
|
Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID).
|
||||||
Order("follow_request.updated_at DESC")
|
Order("follow_request.updated_at DESC")
|
||||||
|
|
||||||
if err := q.Scan(ctx); err != nil {
|
if err := q.Scan(ctx); err != nil {
|
||||||
return nil, r.conn.ProcessError(err)
|
return nil, r.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return followRequests, nil
|
return followRequests, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
|
||||||
follows := []*gtsmodel.Follow{}
|
follows := []*gtsmodel.Follow{}
|
||||||
|
|
||||||
q := r.newFollowQ(&follows).
|
q := r.newFollowQ(&follows).
|
||||||
Where("account_id = ?", accountID).
|
Where("? = ?", bun.Ident("follow.account_id"), accountID).
|
||||||
Order("follow.updated_at DESC")
|
Order("follow.updated_at DESC")
|
||||||
|
|
||||||
if err := q.Scan(ctx); err != nil {
|
if err := q.Scan(ctx); err != nil {
|
||||||
return nil, r.conn.ProcessError(err)
|
return nil, r.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return follows, nil
|
return follows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
|
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
|
||||||
return r.conn.
|
q := r.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(&[]*gtsmodel.Follow{}).
|
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
|
||||||
Where("account_id = ?", accountID).
|
|
||||||
Count(ctx)
|
if localOnly {
|
||||||
|
q = q.
|
||||||
|
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")).
|
||||||
|
Where("? = ?", bun.Ident("follow.account_id"), accountID).
|
||||||
|
Where("? IS NULL", bun.Ident("account.domain"))
|
||||||
|
} else {
|
||||||
|
q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return q.Count(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
|
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
|
||||||
|
@ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
|
||||||
Order("follow.updated_at DESC")
|
Order("follow.updated_at DESC")
|
||||||
|
|
||||||
if localOnly {
|
if localOnly {
|
||||||
q = q.ColumnExpr("follow.*").
|
q = q.
|
||||||
Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)").
|
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
|
||||||
Where("follow.target_account_id = ?", accountID).
|
Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
|
||||||
WhereGroup(" AND ", whereEmptyOrNull("a.domain"))
|
Where("? IS NULL", bun.Ident("account.domain"))
|
||||||
} else {
|
} else {
|
||||||
q = q.Where("target_account_id = ?", accountID)
|
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := q.Scan(ctx)
|
err := q.Scan(ctx)
|
||||||
|
@ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
|
func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
|
||||||
return r.conn.
|
q := r.conn.
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(&[]*gtsmodel.Follow{}).
|
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
|
||||||
Where("target_account_id = ?", accountID).
|
|
||||||
Count(ctx)
|
if localOnly {
|
||||||
|
q = q.
|
||||||
|
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
|
||||||
|
Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
|
||||||
|
Where("? IS NULL", bun.Ident("account.domain"))
|
||||||
|
} else {
|
||||||
|
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return q.Count(ctx)
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,6 @@ package bundb_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
@ -48,12 +47,14 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
|
||||||
suite.False(blocked)
|
suite.False(blocked)
|
||||||
|
|
||||||
// have account1 block account2
|
// have account1 block account2
|
||||||
suite.db.Put(ctx, >smodel.Block{
|
if err := suite.db.Put(ctx, >smodel.Block{
|
||||||
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
|
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
|
||||||
URI: "http://localhost:8080/some_block_uri_1",
|
URI: "http://localhost:8080/some_block_uri_1",
|
||||||
AccountID: account1,
|
AccountID: account1,
|
||||||
TargetAccountID: account2,
|
TargetAccountID: account2,
|
||||||
})
|
}); err != nil {
|
||||||
|
suite.FailNow(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
// account 1 now blocks account 2
|
// account 1 now blocks account 2
|
||||||
blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
|
blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
|
||||||
|
@ -75,62 +76,242 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *RelationshipTestSuite) TestGetBlock() {
|
func (suite *RelationshipTestSuite) TestGetBlock() {
|
||||||
suite.Suite.T().Skip("TODO: implement")
|
ctx := context.Background()
|
||||||
|
|
||||||
|
account1 := suite.testAccounts["local_account_1"].ID
|
||||||
|
account2 := suite.testAccounts["local_account_2"].ID
|
||||||
|
|
||||||
|
if err := suite.db.Put(ctx, >smodel.Block{
|
||||||
|
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
|
||||||
|
URI: "http://localhost:8080/some_block_uri_1",
|
||||||
|
AccountID: account1,
|
||||||
|
TargetAccountID: account2,
|
||||||
|
}); err != nil {
|
||||||
|
suite.FailNow(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := suite.db.GetBlock(ctx, account1, account2)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.NotNil(block)
|
||||||
|
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *RelationshipTestSuite) TestGetRelationship() {
|
func (suite *RelationshipTestSuite) TestGetRelationship() {
|
||||||
suite.Suite.T().Skip("TODO: implement")
|
requestingAccount := suite.testAccounts["local_account_1"]
|
||||||
|
targetAccount := suite.testAccounts["admin_account"]
|
||||||
|
|
||||||
|
relationship, err := suite.db.GetRelationship(context.Background(), requestingAccount.ID, targetAccount.ID)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.NotNil(relationship)
|
||||||
|
|
||||||
|
suite.True(relationship.Following)
|
||||||
|
suite.True(relationship.ShowingReblogs)
|
||||||
|
suite.False(relationship.Notifying)
|
||||||
|
suite.True(relationship.FollowedBy)
|
||||||
|
suite.False(relationship.Blocking)
|
||||||
|
suite.False(relationship.BlockedBy)
|
||||||
|
suite.False(relationship.Muting)
|
||||||
|
suite.False(relationship.MutingNotifications)
|
||||||
|
suite.False(relationship.Requested)
|
||||||
|
suite.False(relationship.DomainBlocking)
|
||||||
|
suite.False(relationship.Endorsed)
|
||||||
|
suite.Empty(relationship.Note)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *RelationshipTestSuite) TestIsFollowing() {
|
func (suite *RelationshipTestSuite) TestIsFollowingYes() {
|
||||||
suite.Suite.T().Skip("TODO: implement")
|
requestingAccount := suite.testAccounts["local_account_1"]
|
||||||
|
targetAccount := suite.testAccounts["admin_account"]
|
||||||
|
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.True(isFollowing)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestIsFollowingNo() {
|
||||||
|
requestingAccount := suite.testAccounts["admin_account"]
|
||||||
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
|
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.False(isFollowing)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
|
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
|
||||||
suite.Suite.T().Skip("TODO: implement")
|
requestingAccount := suite.testAccounts["local_account_1"]
|
||||||
|
targetAccount := suite.testAccounts["admin_account"]
|
||||||
|
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.True(isMutualFollowing)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *RelationshipTestSuite) AcceptFollowRequest() {
|
func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() {
|
||||||
for _, account := range suite.testAccounts {
|
requestingAccount := suite.testAccounts["local_account_1"]
|
||||||
_, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID")
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
|
||||||
suite.Suite.Fail("error accepting follow request: %v", err)
|
suite.NoError(err)
|
||||||
}
|
suite.True(isMutualFollowing)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *RelationshipTestSuite) GetAccountFollowRequests() {
|
func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() {
|
||||||
suite.Suite.T().Skip("TODO: implement")
|
ctx := context.Background()
|
||||||
|
account := suite.testAccounts["admin_account"]
|
||||||
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
|
|
||||||
|
followRequest := >smodel.FollowRequest{
|
||||||
|
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||||
|
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||||
|
AccountID: account.ID,
|
||||||
|
TargetAccountID: targetAccount.ID,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *RelationshipTestSuite) GetAccountFollows() {
|
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||||
suite.Suite.T().Skip("TODO: implement")
|
suite.FailNow(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *RelationshipTestSuite) CountAccountFollows() {
|
follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||||
suite.Suite.T().Skip("TODO: implement")
|
suite.NoError(err)
|
||||||
|
suite.NotNil(follow)
|
||||||
|
suite.Equal(followRequest.URI, follow.URI)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *RelationshipTestSuite) GetAccountFollowedBy() {
|
func (suite *RelationshipTestSuite) TestAcceptFollowRequestNotExisting() {
|
||||||
// TODO: more comprehensive tests here
|
ctx := context.Background()
|
||||||
|
account := suite.testAccounts["admin_account"]
|
||||||
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
|
|
||||||
for _, account := range suite.testAccounts {
|
follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||||
var err error
|
suite.ErrorIs(err, db.ErrNoEntries)
|
||||||
|
suite.Nil(follow)
|
||||||
_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
|
|
||||||
if err != nil {
|
|
||||||
suite.Suite.Fail("error checking accounts followed by: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
|
func (suite *RelationshipTestSuite) TestAcceptFollowRequestFollowAlreadyExists() {
|
||||||
if err != nil {
|
ctx := context.Background()
|
||||||
suite.Suite.Fail("error checking localOnly accounts followed by: %v", err)
|
account := suite.testAccounts["local_account_1"]
|
||||||
}
|
targetAccount := suite.testAccounts["admin_account"]
|
||||||
}
|
|
||||||
|
// follow already exists in the db from local_account_1 -> admin_account
|
||||||
|
existingFollow := >smodel.Follow{}
|
||||||
|
if err := suite.db.GetByID(ctx, suite.testFollows["local_account_1_admin_account"].ID, existingFollow); err != nil {
|
||||||
|
suite.FailNow(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *RelationshipTestSuite) CountAccountFollowedBy() {
|
followRequest := >smodel.FollowRequest{
|
||||||
suite.Suite.T().Skip("TODO: implement")
|
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||||
|
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||||
|
AccountID: account.ID,
|
||||||
|
TargetAccountID: targetAccount.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||||
|
suite.FailNow(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.NotNil(follow)
|
||||||
|
|
||||||
|
// uri should be equal to value of new/overlapping follow request
|
||||||
|
suite.NotEqual(followRequest.URI, existingFollow.URI)
|
||||||
|
suite.Equal(followRequest.URI, follow.URI)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
|
||||||
|
ctx := context.Background()
|
||||||
|
account := suite.testAccounts["admin_account"]
|
||||||
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
|
|
||||||
|
followRequest := >smodel.FollowRequest{
|
||||||
|
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||||
|
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||||
|
AccountID: account.ID,
|
||||||
|
TargetAccountID: targetAccount.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||||
|
suite.FailNow(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.NotNil(rejectedFollowRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() {
|
||||||
|
ctx := context.Background()
|
||||||
|
account := suite.testAccounts["admin_account"]
|
||||||
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
|
|
||||||
|
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||||
|
suite.ErrorIs(err, db.ErrNoEntries)
|
||||||
|
suite.Nil(rejectedFollowRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
|
||||||
|
ctx := context.Background()
|
||||||
|
account := suite.testAccounts["admin_account"]
|
||||||
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
|
|
||||||
|
followRequest := >smodel.FollowRequest{
|
||||||
|
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||||
|
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||||
|
AccountID: account.ID,
|
||||||
|
TargetAccountID: targetAccount.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||||
|
suite.FailNow(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Len(followRequests, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestGetAccountFollows() {
|
||||||
|
account := suite.testAccounts["local_account_1"]
|
||||||
|
follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Len(follows, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
|
||||||
|
account := suite.testAccounts["local_account_1"]
|
||||||
|
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, true)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(2, followsCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
|
||||||
|
account := suite.testAccounts["local_account_1"]
|
||||||
|
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, false)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(2, followsCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() {
|
||||||
|
account := suite.testAccounts["local_account_1"]
|
||||||
|
follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Len(follows, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestGetAccountFollowedByLocalOnly() {
|
||||||
|
account := suite.testAccounts["local_account_1"]
|
||||||
|
follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Len(follows, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() {
|
||||||
|
account := suite.testAccounts["local_account_1"]
|
||||||
|
followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, false)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(2, followsCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *RelationshipTestSuite) TestCountAccountFollowedByLocalOnly() {
|
||||||
|
account := suite.testAccounts["local_account_1"]
|
||||||
|
followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, true)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(2, followsCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRelationshipTestSuite(t *testing.T) {
|
func TestRelationshipTestSuite(t *testing.T) {
|
||||||
|
|
|
@ -67,7 +67,7 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db
|
||||||
return u.cache.GetByID(id)
|
return u.cache.GetByID(id)
|
||||||
},
|
},
|
||||||
func(user *gtsmodel.User) error {
|
func(user *gtsmodel.User) error {
|
||||||
return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx)
|
return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -79,7 +79,7 @@ func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gts
|
||||||
return u.cache.GetByAccountID(accountID)
|
return u.cache.GetByAccountID(accountID)
|
||||||
},
|
},
|
||||||
func(user *gtsmodel.User) error {
|
func(user *gtsmodel.User) error {
|
||||||
return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx)
|
return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -91,7 +91,7 @@ func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string)
|
||||||
return u.cache.GetByEmail(emailAddress)
|
return u.cache.GetByEmail(emailAddress)
|
||||||
},
|
},
|
||||||
func(user *gtsmodel.User) error {
|
func(user *gtsmodel.User) error {
|
||||||
return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx)
|
return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -103,7 +103,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationTok
|
||||||
return u.cache.GetByConfirmationToken(confirmationToken)
|
return u.cache.GetByConfirmationToken(confirmationToken)
|
||||||
},
|
},
|
||||||
func(user *gtsmodel.User) error {
|
func(user *gtsmodel.User) error {
|
||||||
return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx)
|
return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -85,14 +85,8 @@ func parseWhere(w db.Where) (query string, args []interface{}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if w.CaseInsensitive {
|
|
||||||
query = "LOWER(?) != LOWER(?)"
|
|
||||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
query = "? != ?"
|
query = "? != ?"
|
||||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
args = []interface{}{bun.Ident(w.Key), w.Value}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,13 +96,7 @@ func parseWhere(w db.Where) (query string, args []interface{}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if w.CaseInsensitive {
|
|
||||||
query = "LOWER(?) = LOWER(?)"
|
|
||||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
query = "? = ?"
|
query = "? = ?"
|
||||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
args = []interface{}{bun.Ident(w.Key), w.Value}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,9 +24,6 @@ type Where struct {
|
||||||
Key string
|
Key string
|
||||||
// The value to match.
|
// The value to match.
|
||||||
Value interface{}
|
Value interface{}
|
||||||
// Whether the value (if a string) should be case sensitive or not.
|
|
||||||
// Defaults to false.
|
|
||||||
CaseInsensitive bool
|
|
||||||
// If set, reverse the where.
|
// If set, reverse the where.
|
||||||
// `WHERE k = v` becomes `WHERE k != v`.
|
// `WHERE k = v` becomes `WHERE k != v`.
|
||||||
// `WHERE k IS NULL` becomes `WHERE k IS NOT NULL`
|
// `WHERE k IS NULL` becomes `WHERE k IS NOT NULL`
|
||||||
|
|
|
@ -133,8 +133,10 @@ func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we have an instance account for this instance, delete it
|
// if we have an instance account for this instance, delete it
|
||||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil {
|
if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
|
||||||
l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err)
|
if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil {
|
||||||
|
l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines)
|
// delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines)
|
||||||
|
|
|
@ -55,7 +55,7 @@ func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
|
||||||
// remove the domain block reference from the instance, if we have an entry for it
|
// remove the domain block reference from the instance, if we have an entry for it
|
||||||
i := >smodel.Instance{}
|
i := >smodel.Instance{}
|
||||||
if err := p.db.GetWhere(ctx, []db.Where{
|
if err := p.db.GetWhere(ctx, []db.Where{
|
||||||
{Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true},
|
{Key: "domain", Value: domainBlock.Domain},
|
||||||
{Key: "domain_block_id", Value: id},
|
{Key: "domain_block_id", Value: id},
|
||||||
}, i); err == nil {
|
}, i); err == nil {
|
||||||
updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"}
|
updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"}
|
||||||
|
|
Loading…
Reference in New Issue